# LiteMatting Model Training in Google Colab

This notebook provides a complete training pipeline for the LiteMatting model on the Adobe Composition-1K dataset. The model targets high-quality image matting (background removal) with an efficient, mobile-friendly architecture.

## Features:
- ✅ Mobile-optimized matting network
- ✅ Adobe Composition-1K dataset support
- ✅ GPU accelerated training
- ✅ Comprehensive loss functions (Alpha, Composition, Gradient, Laplacian)
- ✅ Real-time visualization
- ✅ Model checkpointing
- ✅ Evaluation metrics (SAD, MSE, Gradient Error)

## Requirements:
- Google Colab with GPU enabled
- Adobe Composition-1K dataset (download link provided below)
- Student account for academic dataset access

---

## 1. Install Required Libraries

First, let's install all the necessary dependencies for training the LiteMatting model.

In [None]:
# Install required packages
!pip install torch torchvision torchaudio
!pip install opencv-python-headless
!pip install albumentations
!pip install imgaug
!pip install scikit-image
!pip install tqdm
!pip install tensorboard
!pip install matplotlib
!pip install pillow
!pip install scipy

# Check GPU availability
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU device: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
else:
    print("⚠️ GPU not available. Please enable GPU in Runtime -> Change runtime type -> Hardware accelerator -> GPU")

## 2. Download and Prepare the Adobe Composition-1K Dataset

The Adobe Composition-1K dataset is a standard benchmark for image matting. Because of its academic licensing requirements, you need to request access and download it manually.
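
For orientation, the `merged/` images in the expected directory layout are composites built from the `fg/`, `bg/`, and `alpha/` folders. Under the standard compositing model, each pixel is `I = alpha * F + (1 - alpha) * B`. The sketch below illustrates that relationship with NumPy on toy arrays. `composite` is a hypothetical helper name used only for illustration, not part of the dataset tooling.

```python
import numpy as np

def composite(fg, bg, alpha):
    """Blend foreground over background: I = alpha * F + (1 - alpha) * B.

    fg, bg: float arrays of shape (H, W, 3) with values in [0, 1]
    alpha:  float array of shape (H, W) with values in [0, 1]
    """
    alpha = alpha[..., None]  # add a channel axis so alpha broadcasts over RGB
    return alpha * fg + (1.0 - alpha) * bg

# Toy example: white foreground over black background
fg = np.ones((2, 2, 3))
bg = np.zeros((2, 2, 3))
alpha = np.array([[0.0, 0.25],
                  [0.5, 1.0]])

merged = composite(fg, bg, alpha)
print(merged[..., 0])  # with white-on-black, each channel equals alpha
```

With a white foreground and black background, the composite simply reproduces the alpha matte in every channel, which makes the blend easy to check by eye.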

In [None]:
# Dataset download instructions
print("📖 DATASET DOWNLOAD INSTRUCTIONS:")
print("1. Visit: https://sites.google.com/view/deepimagematting")
print("2. Request access to Adobe Composition-1K dataset with your student email")
print("3. Download and extract to Google Drive")
print("4. Upload to Colab or mount Google Drive")
print()
print("Expected directory structure:")
print("Composition-1K/")
print("├── Training_set/")
print("│   ├── fg/           # Foreground images")
print("│   ├── alpha/        # Alpha mattes (ground truth)")
print("│   ├── bg/           # Background images") 
print("│   ├── merged/       # Composite images")
print("│   └── trimaps/      # Trimap annotations")
print("└── Test_set/")
print("    ├── fg/")
print("    ├── alpha/")
print("    ├── bg/")
print("    ├── merged/")
print("    └── trimaps/")
print()

# Mount Google Drive (if dataset is stored there)
from google.colab import drive
drive.mount('/content/drive')

# Set dataset path (update this to your dataset location)
DATASET_PATH = "/content/drive/MyDrive/Composition-1K"  # Update this path!

import os
if os.path.exists(DATASET_PATH):
    print(f"✅ Dataset found at: {DATASET_PATH}")
    
    # Check dataset structure
    train_path = os.path.join(DATASET_PATH, "Training_set")
    test_path = os.path.join(DATASET_PATH, "Test_set")

    # Report the number of image files in each expected subfolder of both splits
    for split_name, split_path in [("Training", train_path), ("Test", test_path)]:
        if not os.path.exists(split_path):
            continue
        print(f"✅ {split_name} set found")
        for subfolder in ['fg', 'alpha', 'bg', 'merged', 'trimaps']:
            subfolder_path = os.path.join(split_path, subfolder)
            if os.path.exists(subfolder_path):
                count = len([f for f in os.listdir(subfolder_path) if f.endswith(('.jpg', '.png', '.jpeg'))])
                print(f"   - {subfolder}: {count} files")
else:
    print(f"❌ Dataset not found at: {DATASET_PATH}")
    print("Please update DATASET_PATH variable to point to your dataset location")

## 3. Data Preprocessing and Augmentation

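Define a PyTorch `Dataset` class, Albumentations transform pipelines, and a `collate_fn` for the `DataLoader`, so that images, alpha mattes, trimaps, and click guidance maps are loaded with consistent augmentations.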
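
Two encodings used by the dataset class in this section are worth seeing in isolation: the trimap (values 0, 128, 255) becomes a 3-channel one-hot map, and distances to labelled pixels become soft guidance weights through a Gaussian with sigma `0.05 * 320 = 16` pixels. The sketch below reproduces both with NumPy only. The helper names `one_hot_trimap` and `gaussian_guidance` are hypothetical, and the distances are supplied by hand rather than computed with `cv2.distanceTransform`.

```python
import numpy as np

SIGMA = 0.05 * 320  # 16 px, the same constant used in trimap_to_clicks

def one_hot_trimap(trimap):
    """Encode a 0/128/255 trimap as (H, W, 3): background, unknown, foreground."""
    channels = [trimap == 0, trimap == 128, trimap == 255]
    return np.stack(channels, axis=-1).astype(np.float32)

def gaussian_guidance(distance, sigma=SIGMA):
    """Map a distance in pixels to a weight in (0, 1]; 1.0 at distance 0."""
    distance = np.asarray(distance, dtype=np.float32)
    return np.exp(-distance**2 / (2 * sigma**2))

trimap = np.array([[0, 128],
                   [255, 128]], dtype=np.uint8)
encoded = one_hot_trimap(trimap)
print(encoded.shape)          # (2, 2, 3)
print(encoded.sum(axis=-1))   # every pixel belongs to exactly one class

# Weights at 0, 1 sigma, and 2 sigma from the nearest labelled pixel
weights = gaussian_guidance([0, 16, 32])
print(weights)                # 1.0, exp(-0.5), exp(-2.0)
```

Any trimap value other than 0, 128, or 255 would leave a pixel with an all-zero encoding, so trimaps with other gray levels would need to be quantized first.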

In [None]:
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
import random
import matplotlib.pyplot as plt

class Composition1KDataset(Dataset):
    def __init__(self, data_root, mode='train', transform=None, input_size=512):
        """
        Adobe Composition-1K Dataset
        
        Args:
            data_root: Path to Composition-1K dataset
            mode: 'train' or 'test'
            transform: Data augmentation transforms
            input_size: Input image size for training
        """
        self.data_root = data_root
        self.mode = mode
        self.input_size = input_size
        self.transform = transform
        
        if mode == 'train':
            self.composite_path = os.path.join(data_root, 'Training_set', 'merged')
            self.alpha_path = os.path.join(data_root, 'Training_set', 'alpha')
            self.trimap_path = os.path.join(data_root, 'Training_set', 'trimaps')
        else:
            self.composite_path = os.path.join(data_root, 'Test_set', 'merged')
            self.alpha_path = os.path.join(data_root, 'Test_set', 'alpha')
            self.trimap_path = os.path.join(data_root, 'Test_set', 'trimaps')
        
        # Get list of images
        self.image_list = []
        if os.path.exists(self.composite_path):
            self.image_list = sorted([f for f in os.listdir(self.composite_path) 
                                    if f.endswith(('.jpg', '.png', '.jpeg'))])
        
        print(f"Found {len(self.image_list)} images in {mode} set")
    
    def __len__(self):
        return len(self.image_list)
    
    def load_image(self, path):
        """Load image and convert to RGB"""
        image = cv2.imread(path)
        if image is None:
            raise ValueError(f"Could not load image: {path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return image
    
    def load_alpha(self, path):
        """Load alpha matte"""
        alpha = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if alpha is None:
            raise ValueError(f"Could not load alpha: {path}")
        return alpha.astype(np.float32) / 255.0
    
    def generate_trimap(self, alpha, k_size=10):
        """Generate trimap from alpha matte"""
        alpha_uint8 = (alpha * 255).astype(np.uint8)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k_size, k_size))
        
        eroded = cv2.erode(alpha_uint8, kernel, iterations=1)
        dilated = cv2.dilate(alpha_uint8, kernel, iterations=1)
        
        trimap = np.zeros_like(alpha_uint8)
        trimap[eroded > 128] = 255  # Foreground
        trimap[(dilated > 128) & (eroded <= 128)] = 128  # Unknown
        return trimap
    
    def trimap_to_clicks(self, trimap):
        """Convert trimap to click guidance maps"""
        h, w = trimap.shape
        clicks = np.zeros((h, w, 2), dtype=np.float32)
        
        fg_mask = (trimap == 255).astype(np.uint8)
        bg_mask = (trimap == 0).astype(np.uint8)
        
        if np.sum(fg_mask) > 0:
            dt_fg = cv2.distanceTransform(1 - fg_mask, cv2.DIST_L2, 0)
            clicks[:, :, 0] = np.exp(-dt_fg**2 / (2 * (0.05 * 320)**2))
        
        if np.sum(bg_mask) > 0:
            dt_bg = cv2.distanceTransform(1 - bg_mask, cv2.DIST_L2, 0)
            clicks[:, :, 1] = np.exp(-dt_bg**2 / (2 * (0.05 * 320)**2))
        
        return clicks
    
    def __getitem__(self, idx):
        image_name = self.image_list[idx]
        base_name = os.path.splitext(image_name)[0]
        
        # Load composite image
        composite_path = os.path.join(self.composite_path, image_name)
        composite = self.load_image(composite_path)
        
        # Load alpha matte
        alpha_path = os.path.join(self.alpha_path, image_name)
        if not os.path.exists(alpha_path):
            alpha_path = os.path.join(self.alpha_path, base_name + '.png')
        alpha = self.load_alpha(alpha_path)
        
        # Load or generate trimap
        trimap_file = os.path.join(self.trimap_path, base_name + '.png')
        if os.path.exists(trimap_file):
            trimap = cv2.imread(trimap_file, cv2.IMREAD_GRAYSCALE)
            if trimap is None:
                raise ValueError(f"Could not load trimap: {trimap_file}")
        else:
            trimap = self.generate_trimap(alpha)
        
        # Convert trimap to one-hot encoding
        trimap_onehot = np.zeros((trimap.shape[0], trimap.shape[1], 3), dtype=np.float32)
        trimap_onehot[:, :, 0] = (trimap == 0).astype(np.float32)    # Background
        trimap_onehot[:, :, 1] = (trimap == 128).astype(np.float32)  # Unknown
        trimap_onehot[:, :, 2] = (trimap == 255).astype(np.float32)  # Foreground
        
        # Generate click guidance
        clicks = self.trimap_to_clicks(trimap)
        
        # Apply transforms
        if self.transform:
            transformed = self.transform(
                image=composite,
                mask=alpha,
                masks=[trimap_onehot[:,:,0], trimap_onehot[:,:,1], trimap_onehot[:,:,2], 
                       clicks[:,:,0], clicks[:,:,1]]
            )
            composite = transformed['image']
            alpha = transformed['mask']
            trimap_bg, trimap_unk, trimap_fg, click_fg, click_bg = transformed['masks']
            
            trimap_onehot = np.stack([trimap_bg, trimap_unk, trimap_fg], axis=2)
            clicks = np.stack([click_fg, click_bg], axis=2)
        
        # Convert to tensors
        if not torch.is_tensor(composite):
            composite = torch.from_numpy(composite.transpose(2, 0, 1)).float() / 255.0
        if not torch.is_tensor(alpha):
            alpha = torch.from_numpy(alpha).unsqueeze(0).float()
        
        trimap_tensor = torch.from_numpy(trimap_onehot.transpose(2, 0, 1)).float()
        clicks_tensor = torch.from_numpy(clicks.transpose(2, 0, 1)).float()
        
        return {
            'image': composite,
            'alpha': alpha,
            'trimap': trimap_tensor,
            'clicks': clicks_tensor,
            'name': image_name
        }

def get_train_transforms(input_size=512):
    """Get training data augmentation transforms"""
    # 'image', 'mask', and 'masks' are built-in Albumentations targets,
    # so no additional_targets mapping is needed.
    return A.Compose([
        A.RandomResizedCrop(input_size, input_size, scale=(0.8, 1.0)),
        A.HorizontalFlip(p=0.5),
        A.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05, p=0.5),
        A.GaussNoise(var_limit=(10.0, 50.0), p=0.2),
        A.GaussianBlur(blur_limit=(3, 7), p=0.2),
    ])

def get_val_transforms(input_size=512):
    """Get validation transforms"""
    return A.Compose([
        A.Resize(input_size, input_size),
    ])

def collate_fn(batch):
    """Custom collate function for DataLoader"""
    images = torch.stack([item['image'] for item in batch])
    alphas = torch.stack([item['alpha'] for item in batch])
    trimaps = torch.stack([item['trimap'] for item in batch])
    clicks = torch.stack([item['clicks'] for item in batch])
    names = [item['name'] for item in batch]
    
    return {
        'image': images,
        'alpha': alphas,
        'trimap': trimaps,
        'clicks': clicks,
        'name': names
    }

print("✅ Dataset classes defined successfully!")

## 4. Define the MobileMatting Model

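The LiteMatting network is implemented below as the `MobileMatting` class, together with the building blocks it depends on (`MSLPPM`, `InvertedResidual`, `InvertedResidualLeaky`, `PSPModule`, and `GFNB`).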
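
Reading the layer definitions in this section, the encoder halves the resolution five times (total stride 32), and the decoder doubles it five times while concatenating skip features. That suggests input height and width should be multiples of 32, otherwise the upsampled and skip feature maps can differ in size. The sketch below works through that arithmetic in plain Python. `pad_to_multiple` and `encoder_sizes` are hypothetical helpers, and the assumption that each stride-2 stage yields `ceil(size / 2)` comes from the 3x3 convolutions with one pixel of padding on each side.

```python
import math

OUTPUT_STRIDE = 32  # matches output_stride in MobileMatting.__init__

def pad_to_multiple(size, multiple=OUTPUT_STRIDE):
    """Smallest multiple of `multiple` that is >= size."""
    return math.ceil(size / multiple) * multiple

def encoder_sizes(size, num_stride2_stages=5):
    """Spatial size after each stride-2 stage, assuming ceil(size / 2) per stage."""
    sizes = []
    for _ in range(num_stride2_stages):
        size = math.ceil(size / 2)
        sizes.append(size)
    return sizes

print(encoder_sizes(512))    # [256, 128, 64, 32, 16]: each level is exactly half
print(encoder_sizes(500))    # [250, 125, 63, 32, 16]: doubling 32 gives 64, not 63
print(pad_to_multiple(500))  # 512
```

With the default `input_size=512` used by the transforms in this notebook, every level halves cleanly, so no padding is needed during training.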

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def GN(x):
    if x % 32 == 0:
        return nn.GroupNorm(32, x)
    if x % 16 == 0:
        return nn.GroupNorm(16, x)
    if x % 8 == 0:
        return nn.GroupNorm(8, x)
    return nn.GroupNorm(4, x)

class MSLPPM(nn.Module):
    def __init__(self, in_channels):
        super(MSLPPM, self).__init__()
        mid_channels = int(in_channels / 2)
        inter_channels = int(mid_channels / 4)
        self.trans = nn.Sequential(nn.Conv2d(in_channels, mid_channels, 1, 1, 0, bias=False),
                                   GN(mid_channels),
                                   nn.ReLU(True),
                                   nn.Conv2d(mid_channels, inter_channels, 1, 1, 0, bias=False),
                                   GN(inter_channels),
                                   nn.ReLU(True)
                                   )
        self.pool1 = nn.AdaptiveAvgPool2d((5, 5))
        self.pool2 = nn.AdaptiveAvgPool2d((13, 13))
        self.pool3 = nn.AdaptiveAvgPool2d((1, None))
        self.pool4 = nn.AdaptiveAvgPool2d((None, 1))
        self.pool5 = nn.AdaptiveAvgPool2d((15, 7))
        self.pool6 = nn.AdaptiveAvgPool2d((7, 15))
        self.pool7 = nn.AdaptiveAvgPool2d((23, 11))
        self.pool8 = nn.AdaptiveAvgPool2d((11, 23))

        self.conv1 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, (3, 3), 1, (1, 1), bias=False),
                                   GN(inter_channels),
                                   nn.ReLU(True))
        self.conv2 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, (3, 3), 1, (1, 1), bias=False),
                                   GN(inter_channels),
                                   nn.ReLU(True))
        self.conv3 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, (1, 3), 1, (0, 1), bias=False),
                                   GN(inter_channels),
                                   nn.ReLU(True))
        self.conv4 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, (3, 1), 1, (1, 0), bias=False),
                                   GN(inter_channels),
                                   nn.ReLU(True))
        self.conv5 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, (5, 3), 1, (2, 1), bias=False),
                                   GN(inter_channels),
                                   nn.ReLU(True))
        self.conv6 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, (3, 5), 1, (1, 2), bias=False),
                                   GN(inter_channels),
                                   nn.ReLU(True))
        self.conv7 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, (7, 3), 1, (3, 1), bias=False),
                                   GN(inter_channels),
                                   nn.ReLU(True))
        self.conv8 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, (3, 7), 1, (1, 3), bias=False),
                                   GN(inter_channels),
                                   nn.ReLU(True))
        self.conv = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, 1, 1, bias=False),
                                  GN(inter_channels),
                                  nn.ReLU(True))

    def forward(self, x):
        _, _, h, w = x.size()
        x = self.trans(x)
        x1 = F.interpolate(self.conv1(self.pool1(x)), size=(h, w), mode='bilinear', align_corners=False)
        x2 = F.interpolate(self.conv2(self.pool2(x)), size=(h, w), mode='bilinear', align_corners=False)
        x3 = F.interpolate(self.conv3(self.pool3(x)), size=(h, w), mode='bilinear', align_corners=False)
        x4 = F.interpolate(self.conv4(self.pool4(x)), size=(h, w), mode='bilinear', align_corners=False)
        x5 = F.interpolate(self.conv5(self.pool5(x)), size=(h, w), mode='bilinear', align_corners=False)
        x6 = F.interpolate(self.conv6(self.pool6(x)), size=(h, w), mode='bilinear', align_corners=False)
        x7 = F.interpolate(self.conv7(self.pool7(x)), size=(h, w), mode='bilinear', align_corners=False)
        x8 = F.interpolate(self.conv8(self.pool8(x)), size=(h, w), mode='bilinear', align_corners=False)
        s = self.conv(F.relu_(x1 + x2))
        l = self.conv(F.relu_(x3 + x4))
        m = self.conv(F.relu_(x5 + x6))
        n = self.conv(F.relu_(x7 + x8))
        out = torch.cat([s, m, n, l], dim=1)
        return out

class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, dilation, expand_ratio, batch_norm):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        BatchNorm2d = batch_norm
        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup
        self.kernel_size = 3
        self.dilation = dilation

        if expand_ratio == 1:
            self.conv = nn.Sequential(
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False),
                BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False),
                BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                BatchNorm2d(oup),
            )

    def fixed_padding(self, inputs, kernel_size, dilation):
        kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
        pad_total = kernel_size_effective - 1
        pad_beg = pad_total // 2
        pad_end = pad_total - pad_beg
        padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
        return padded_inputs

    def forward(self, x):
        x_pad = self.fixed_padding(x, self.kernel_size, dilation=self.dilation)
        if self.use_res_connect:
            return x + self.conv(x_pad)
        else:
            return self.conv(x_pad)

class InvertedResidualLeaky(nn.Module):
    def __init__(self, inp, oup, stride, dilation, expand_ratio, batch_norm):
        super(InvertedResidualLeaky, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        BatchNorm2d = batch_norm
        hidden_dim = round(oup * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup
        self.kernel_size = 3
        self.dilation = dilation

        if expand_ratio == 1:
            self.conv = nn.Sequential(
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False),
                BatchNorm2d(hidden_dim),
                nn.LeakyReLU(inplace=True),
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                BatchNorm2d(hidden_dim),
                nn.LeakyReLU(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False),
                BatchNorm2d(hidden_dim),
                nn.LeakyReLU(inplace=True),
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                BatchNorm2d(oup),
            )

    def fixed_padding(self, inputs, kernel_size, dilation):
        kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
        pad_total = kernel_size_effective - 1
        pad_beg = pad_total // 2
        pad_end = pad_total - pad_beg
        padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
        return padded_inputs

    def forward(self, x):
        x_pad = self.fixed_padding(x, self.kernel_size, dilation=self.dilation)
        if self.use_res_connect:
            return x + self.conv(x_pad)
        else:
            return self.conv(x_pad)

class PSPModule(nn.Module):
    def __init__(self, sizes=(1, 2, 3, 6)):
        super(PSPModule, self).__init__()
        self.stages = nn.ModuleList([self._make_stage(size) for size in sizes])

    def _make_stage(self, size):
        prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
        return prior

    def forward(self, feats):
        n, c, _, _ = feats.size()
        priors = [stage(feats).view(n, c, -1) for stage in self.stages]
        center = torch.cat(priors, -1)
        return center

class GFNB(nn.Module):
    def __init__(self, low_in_channels, high_in_channels, out_channels, key_channels, value_channels, dropout=0):
        super(GFNB, self).__init__()
        self.in_channels = low_in_channels
        self.out_channels = out_channels
        self.key_channels = key_channels
        self.value_channels = value_channels

        self.f_key = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, kernel_size=1, stride=1, padding=0),
            GN(self.key_channels),
            nn.LeakyReLU(inplace=True)
        )
        self.f_query = nn.Sequential(
            nn.Conv2d(in_channels=high_in_channels, out_channels=self.key_channels, kernel_size=1, stride=1, padding=0),
            GN(self.key_channels),
            nn.LeakyReLU(inplace=True)
        )
        self.f_value = nn.Conv2d(in_channels=self.in_channels, out_channels=self.value_channels, kernel_size=1,
                                 stride=1, padding=0)
        self.W = nn.Conv2d(in_channels=self.value_channels, out_channels=self.out_channels, kernel_size=1, stride=1,
                           padding=0)
        nn.init.constant_(self.W.weight, 0)
        nn.init.constant_(self.W.bias, 0)
        self.psp = PSPModule(sizes=(1, 3, 6, 8))

        self.conv_bn_dropout = nn.Sequential(
            nn.Conv2d(out_channels + high_in_channels, out_channels, kernel_size=1, padding=0),
            GN(out_channels),
            nn.Dropout2d(dropout)
        )

    def forward(self, low_feats, high_feats):
        batch_size, h, w = high_feats.size(0), high_feats.size(2), high_feats.size(3)
        value = self.f_value(low_feats)
        value = self.psp(value)
        value = value.permute(0, 2, 1)

        key = self.f_key(low_feats)
        key = self.psp(key)

        query = self.f_query(high_feats).view(batch_size, self.key_channels, -1)
        query = query.permute(0, 2, 1)

        sim_map = torch.matmul(query, key)
        sim_map = (self.key_channels ** -.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)

        context = torch.matmul(sim_map, value)
        context = context.permute(0, 2, 1).contiguous()
        context = context.view(batch_size, self.value_channels, *high_feats.size()[2:])
        context = self.W(context)
        output = self.conv_bn_dropout(torch.cat([context, high_feats], 1))
        return output

def conv_bn(inp, oup, k=3, s=1, BatchNorm2d=nn.BatchNorm2d):
    return nn.Sequential(
        nn.Conv2d(inp, oup, k, s, padding=k // 2, bias=False),
        BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )

class MobileMatting(nn.Module):
    def __init__(self):
        super(MobileMatting, self).__init__()
        output_stride = 32
        BatchNorm2d = GN
        width_mult = 1.
        self.width_mult = width_mult
        block = InvertedResidual
        blockleaky = InvertedResidualLeaky

        initial_channel = 32
        current_stride = 1
        rate = 1

        inverted_residual_setting = [
            [1, 32, 32, 1, 1, 1],
            [6, 32, 24, 2, 2, 1],
            [6, 24, 32, 3, 2, 1],
            [6, 32, 64, 4, 2, 1],
            [6, 64, 96, 3, 1, 1],
            [6, 96, 160, 3, 2, 1],
            [6, 240, 320, 1, 1, 1], ]

        initial_channel = int(initial_channel * width_mult)
        self.layerx = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
            GN(32),
            nn.ReLU(True),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            GN(64),
            self._build_layer(block, [4, 64, 64, 4, 1, 1], BatchNorm2d)
        )
        self.layer0 = conv_bn(3 + 3 + 2, initial_channel, 3, 2, BatchNorm2d)

        current_stride *= 2

        for i, setting in enumerate(inverted_residual_setting):
            s = setting[4]
            if current_stride == output_stride:
                inverted_residual_setting[i][4] = 1
                rate *= s
                inverted_residual_setting[i][5] = rate
            else:
                current_stride *= s

        self.layer1 = nn.Sequential(nn.Conv2d(32, 32, 3, 1, 1), BatchNorm2d(32), nn.ReLU(inplace=True))
        self.layer2 = self._build_layer(block, inverted_residual_setting[1], BatchNorm2d, downsample=True)
        self.layer3 = self._build_layer(block, inverted_residual_setting[2], BatchNorm2d, downsample=True)
        self.layer4 = self._build_layer(block, inverted_residual_setting[3], BatchNorm2d, downsample=True)
        self.layer5 = self._build_layer(block, inverted_residual_setting[4], BatchNorm2d)
        self.layer6 = self._build_layer(block, inverted_residual_setting[5], BatchNorm2d, downsample=True)
        self.layer7 = self._build_layer(block, inverted_residual_setting[6], BatchNorm2d)
        self.ppm = MSLPPM(320)
        self.gfnb = GFNB(64, 160, 80, 80, 80)
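        # NOTE: nn.LeakyReLU's first positional argument is negative_slope, so
        # LeakyReLU(True) means negative_slope=1.0 (an identity), not inplace=True.
        # Left unchanged here so the network's behavior is not altered.
        self.dfpool = nn.Sequential(nn.Conv2d(320 + 160, 128, 1, 1, 0), nn.GroupNorm(16, 128), nn.LeakyReLU(True))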
        self.uper1 = self._build_layer(blockleaky, [4, 160 + 128, 160, 3, 1, 1], BatchNorm2d)
        self.uper2 = self._build_layer(blockleaky, [4, 256, 128, 3, 1, 1], BatchNorm2d)
        self.uper3 = self._build_layer(blockleaky, [4, 160, 96, 3, 1, 1], BatchNorm2d)
        self.uper4 = self._build_layer(blockleaky, [4, 120, 64, 3, 1, 1], BatchNorm2d)
        self.uper5 = nn.Sequential(nn.Conv2d(96, 48, 3, 1, 1), BatchNorm2d(48), nn.PReLU(48),
                                   nn.Conv2d(48, 32, 3, 1, 1), BatchNorm2d(32), nn.PReLU(32))
        self.out = nn.Sequential(nn.Conv2d(32 + 6, 24, 3, 1, 1), nn.PReLU(24), nn.Conv2d(24, 12, 3, 1, 1), nn.PReLU(12),
                                 nn.Conv2d(12, 1, 3, 1, 1))
        self.up = nn.Upsample(scale_factor=2, mode='nearest')

    def _build_layer(self, block, layer_setting, batch_norm, downsample=False):
        t, p, c, n, s, d = layer_setting
        input_channel = int(p * self.width_mult)
        output_channel = int(c * self.width_mult)
        layers = []
        for i in range(n):
            if i == 0:
                d0 = d
                if downsample:
                    d0 = d // 2 if d > 1 else 1
                layers.append(block(input_channel, output_channel, s, d0, expand_ratio=t, batch_norm=batch_norm))
            else:
                layers.append(block(input_channel, output_channel, 1, d, expand_ratio=t, batch_norm=batch_norm))
            input_channel = output_channel
        return nn.Sequential(*layers)

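    def forward(self, img, tri, sixc):
        """
        Args:
            img: RGB image in [0, 1], shape (B, 3, H, W)
            tri: one-hot trimap, shape (B, 3, H, W)
            sixc: 2-channel click guidance maps, shape (B, 2, H, W)
        """
        x = torch.cat((img * 2. - 1., tri, sixc), dim=1)

        l0 = self.layer0(x)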
        l1 = self.layer1(l0)
        l2 = self.layer2(l1)
        l3 = self.layer3(l2)
        l4 = self.layer4(l3)
        l5 = self.layer5(l4)
        l6 = self.layer6(l5)

        _, _, h, w = l6.shape
        # Channels 1 and 2 of the one-hot trimap are unknown and foreground
        unknown = torch.nn.functional.interpolate(tri[:, 1:2] + tri[:, 2:3], size=(h, w), mode='nearest')
        lx = self.layerx(img)
        l6_g = self.gfnb(lx, l6 * unknown)
        l6_c = torch.cat((l6, l6_g), dim=1)

        l7 = self.layer7(l6_c)
        feats = self.ppm(l7)
        l7 = torch.cat([l7, feats], 1)
        l7 = self.dfpool(l7)

        lmid = torch.cat((l6, l7), dim=1)
        lmid = self.uper1(lmid)
        lmid = torch.cat((self.up(lmid), l5), dim=1)
        lmid = self.uper2(lmid)
        lmid = torch.cat((self.up(lmid), l3), dim=1)
        lmid = self.uper3(lmid)
        lmid = torch.cat((self.up(lmid), l2), dim=1)
        lmid = self.uper4(lmid)
        lmid = torch.cat((self.up(lmid), l1), dim=1)
        lmid = self.uper5(lmid)
        lmid = self.up(lmid)
        lmid = torch.cat((lmid, img, tri), dim=1)
        lmid = self.out(lmid)
        lmid = torch.clamp(lmid, 0, 1)
        return lmid

# Test model creation
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MobileMatting().to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("✅ MobileMatting model created successfully!")
print(f"📊 Total parameters: {total_params:,}")
print(f"📊 Trainable parameters: {trainable_params:,}")
print(f"💾 Model size: ~{total_params * 4 / 1024 / 1024:.1f} MB")

## 5. Define Loss Functions and Metrics

Implement the matting loss functions (alpha, composition, gradient, and Laplacian) and the evaluation metrics (SAD, MSE, and gradient error).
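
All of the losses and metrics in this section share one pattern: a per-pixel error is multiplied by the unknown-region mask and then normalized. The sketch below shows that pattern on a 2x2 toy example in NumPy, so the numbers can be checked by hand. `masked_l1` and `sad_in_thousands` are hypothetical stand-ins for illustration, not the PyTorch functions used for training.

```python
import numpy as np

def masked_l1(pred, gt, unknown_mask, eps=1e-8):
    """Mean absolute error over pixels where unknown_mask == 1."""
    diff = np.abs(pred - gt) * unknown_mask
    return diff.sum() / (unknown_mask.sum() + eps)

def sad_in_thousands(pred, gt, unknown_mask):
    """Sum of absolute differences in the unknown region, divided by 1000."""
    return (np.abs(pred - gt) * unknown_mask).sum() / 1000.0

pred = np.array([[0.0, 0.5],
                 [1.0, 0.9]])
gt = np.array([[0.0, 0.7],
               [1.0, 0.5]])
# Only the right-hand column is "unknown"; the left column is ignored
unknown = np.array([[0.0, 1.0],
                    [0.0, 1.0]])

loss = masked_l1(pred, gt, unknown)        # (0.2 + 0.4) / 2, about 0.3
sad = sad_in_thousands(pred, gt, unknown)  # (0.2 + 0.4) / 1000, about 0.0006
print(loss, sad)
```

The small `eps` keeps the division finite when a crop contains no unknown pixels, in which case the masked loss is simply zero.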

In [None]:
def alpha_loss(pred_alpha, gt_alpha, trimap):
    """Alpha prediction loss (L1 loss in unknown regions)"""
    unknown_mask = (trimap[:, 1:2, :, :] > 0.5).float()
    diff = torch.abs(pred_alpha - gt_alpha)
    loss = torch.sum(diff * unknown_mask) / (torch.sum(unknown_mask) + 1e-8)
    return loss

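def composition_loss(pred_alpha, gt_alpha, image, trimap):
    """Simplified composition loss: L1 between alpha-weighted composite images.

    The true foreground and background are not loaded here, so both alphas
    are multiplied by the composite image instead of re-compositing F and B.
    """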
    unknown_mask = (trimap[:, 1:2, :, :] > 0.5).float()
    pred_comp = pred_alpha * image
    gt_comp = gt_alpha * image
    diff = torch.abs(pred_comp - gt_comp)
    loss = torch.sum(diff * unknown_mask) / (torch.sum(unknown_mask) + 1e-8)
    return loss

def gradient_loss(pred_alpha, gt_alpha, trimap):
    """Gradient loss to preserve fine details"""
    unknown_mask = (trimap[:, 1:2, :, :] > 0.5).float()
    
    # Sobel filters for gradient computation
    sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], 
                          dtype=torch.float32, device=pred_alpha.device).view(1, 1, 3, 3)
    sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], 
                          dtype=torch.float32, device=pred_alpha.device).view(1, 1, 3, 3)
    
    # Compute gradients
    pred_grad_x = F.conv2d(pred_alpha, sobel_x, padding=1)
    pred_grad_y = F.conv2d(pred_alpha, sobel_y, padding=1)
    gt_grad_x = F.conv2d(gt_alpha, sobel_x, padding=1)
    gt_grad_y = F.conv2d(gt_alpha, sobel_y, padding=1)
    
    # Gradient magnitude
    pred_grad = torch.sqrt(pred_grad_x**2 + pred_grad_y**2 + 1e-8)
    gt_grad = torch.sqrt(gt_grad_x**2 + gt_grad_y**2 + 1e-8)
    
    # L1 loss on gradients in unknown regions
    diff = torch.abs(pred_grad - gt_grad)
    loss = torch.sum(diff * unknown_mask) / (torch.sum(unknown_mask) + 1e-8)
    return loss

def laplacian_loss(pred_alpha, gt_alpha, trimap):
    """Laplacian loss for smoothness"""
    unknown_mask = (trimap[:, 1:2, :, :] > 0.5).float()
    
    # Laplacian kernel
    laplacian_kernel = torch.tensor([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], 
                                   dtype=torch.float32, device=pred_alpha.device).view(1, 1, 3, 3)
    
    # Apply Laplacian
    pred_lap = F.conv2d(pred_alpha, laplacian_kernel, padding=1)
    gt_lap = F.conv2d(gt_alpha, laplacian_kernel, padding=1)
    
    # L1 loss on Laplacian in unknown regions
    diff = torch.abs(pred_lap - gt_lap)
    loss = torch.sum(diff * unknown_mask) / (torch.sum(unknown_mask) + 1e-8)
    return loss

class MattingLoss(nn.Module):
    """Combined matting loss function"""
    def __init__(self, alpha_weight=1.0, comp_weight=1.0, grad_weight=1.0, lap_weight=1.0):
        super(MattingLoss, self).__init__()
        self.alpha_weight = alpha_weight
        self.comp_weight = comp_weight
        self.grad_weight = grad_weight
        self.lap_weight = lap_weight
    
    def forward(self, pred_alpha, gt_alpha, image, trimap):
        losses = {}
        
        # Alpha prediction loss
        losses['alpha'] = alpha_loss(pred_alpha, gt_alpha, trimap)
        
        # Composition loss
        losses['composition'] = composition_loss(pred_alpha, gt_alpha, image, trimap)
        
        # Gradient loss
        losses['gradient'] = gradient_loss(pred_alpha, gt_alpha, trimap)
        
        # Laplacian loss
        losses['laplacian'] = laplacian_loss(pred_alpha, gt_alpha, trimap)
        
        # Total loss
        total_loss = (self.alpha_weight * losses['alpha'] + 
                     self.comp_weight * losses['composition'] + 
                     self.grad_weight * losses['gradient'] + 
                     self.lap_weight * losses['laplacian'])
        
        losses['total'] = total_loss
        return losses

# Evaluation metrics
def compute_sad(pred_alpha, gt_alpha, trimap):
    """Sum of Absolute Differences"""
    unknown_mask = (trimap[:, 1:2, :, :] > 0.5).float()
    diff = torch.abs(pred_alpha - gt_alpha) * unknown_mask
    sad = torch.sum(diff, dim=[1, 2, 3]) / 1000.0  # Scale to thousands
    return sad.mean()

def compute_mse(pred_alpha, gt_alpha, trimap):
    """Mean Squared Error in unknown regions"""
    unknown_mask = (trimap[:, 1:2, :, :] > 0.5).float()
    diff = (pred_alpha - gt_alpha) ** 2 * unknown_mask
    # Epsilon guards against division by zero when a sample has no unknown pixels
    mse = torch.sum(diff, dim=[1, 2, 3]) / (torch.sum(unknown_mask, dim=[1, 2, 3]) + 1e-8)
    return mse.mean()

def compute_gradient_error(pred_alpha, gt_alpha, trimap):
    """Gradient error in unknown regions"""
    unknown_mask = (trimap[:, 1:2, :, :] > 0.5).float()
    
    # Sobel filters
    sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], 
                          dtype=torch.float32, device=pred_alpha.device).view(1, 1, 3, 3)
    sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], 
                          dtype=torch.float32, device=pred_alpha.device).view(1, 1, 3, 3)
    
    # Compute gradients
    pred_grad_x = F.conv2d(pred_alpha, sobel_x, padding=1)
    pred_grad_y = F.conv2d(pred_alpha, sobel_y, padding=1)
    gt_grad_x = F.conv2d(gt_alpha, sobel_x, padding=1)
    gt_grad_y = F.conv2d(gt_alpha, sobel_y, padding=1)
    
    # Gradient error
    grad_error = torch.sqrt((pred_grad_x - gt_grad_x)**2 + (pred_grad_y - gt_grad_y)**2 + 1e-8)
    grad_error = grad_error * unknown_mask
    
    error = torch.sum(grad_error, dim=[1, 2, 3]) / (torch.sum(unknown_mask, dim=[1, 2, 3]) + 1e-8)
    return error.mean() / 1000.0  # Scale to thousands

print("✅ Loss functions and metrics defined successfully!")

## 6. Set Up Training Configuration and Data Loaders

Configure training hyperparameters and create data loaders for training and validation.

In [None]:
# Training Configuration
class Config:
    # Data
    input_size = 512
    
    # Training
    batch_size = 4  # Reduced for Colab memory limits
    val_batch_size = 2
    epochs = 50
    learning_rate = 1e-4
    weight_decay = 1e-4
    num_workers = 2
    
    # Loss weights
    alpha_weight = 1.0
    comp_weight = 1.0
    grad_weight = 1.0
    lap_weight = 1.0
    
    # Logging and saving
    log_interval = 50
    save_interval = 5
    checkpoint_dir = "/content/checkpoints"
    log_dir = "/content/logs"

config = Config()

# Create directories
import os
os.makedirs(config.checkpoint_dir, exist_ok=True)
os.makedirs(config.log_dir, exist_ok=True)

# Create datasets (only if dataset exists)
if os.path.exists(DATASET_PATH):
    # Create transforms
    train_transform = get_train_transforms(config.input_size)
    val_transform = get_val_transforms(config.input_size)
    
    # Create datasets
    train_dataset = Composition1KDataset(
        DATASET_PATH,
        mode='train',
        transform=train_transform,
        input_size=config.input_size
    )
    
    val_dataset = Composition1KDataset(
        DATASET_PATH,
        mode='test',
        transform=val_transform,
        input_size=config.input_size
    )
    
    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        collate_fn=collate_fn,
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.val_batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        collate_fn=collate_fn,
        pin_memory=True
    )
    
    print(f"✅ Data loaders created successfully!")
    print(f"📊 Training samples: {len(train_dataset)}")
    print(f"📊 Validation samples: {len(val_dataset)}")
    print(f"📊 Training batches: {len(train_loader)}")
    print(f"📊 Validation batches: {len(val_loader)}")
    
    # Test data loading
    sample_batch = next(iter(train_loader))
    print(f"\n🔍 Sample batch shapes:")
    print(f"   Image: {sample_batch['image'].shape}")
    print(f"   Alpha: {sample_batch['alpha'].shape}")
    print(f"   Trimap: {sample_batch['trimap'].shape}")
    print(f"   Clicks: {sample_batch['clicks'].shape}")
    
else:
    print("❌ Dataset not found. Please update DATASET_PATH and re-run this cell.")
    train_loader = None
    val_loader = None

## 7. Training Loop Implementation

Implement the main training loop with optimizer, scheduler, and logging.
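
The scheduler used in this section is `CosineAnnealingLR`, stepped once per epoch. As a rough sketch of what that schedule looks like, the closed-form cosine expression can be evaluated in plain Python. The helper name `cosine_lr` is hypothetical, and its defaults are assumed to mirror this notebook's settings (`learning_rate = 1e-4`, `epochs = 50`, `eta_min` at 10% of the base rate).

```python
import math

def cosine_lr(epoch, base_lr=1e-4, eta_min=1e-5, t_max=50):
    """Closed-form cosine annealing: base_lr at epoch 0, eta_min at epoch t_max."""
    cos_term = (1 + math.cos(math.pi * epoch / t_max)) / 2
    return eta_min + (base_lr - eta_min) * cos_term

# Print the learning rate at the start, midpoint, and end of training
for epoch in (0, 25, 50):
    print(f"epoch {epoch:2d}: lr = {cosine_lr(epoch):.2e}")
```

The rate falls slowly at first, fastest around the midpoint, and flattens out again near `eta_min`, so the last epochs fine-tune with small steps.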

In [None]:
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import time

# Initialize model, loss, optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MobileMatting().to(device)

# Loss function
criterion = MattingLoss(
    alpha_weight=config.alpha_weight,
    comp_weight=config.comp_weight,
    grad_weight=config.grad_weight,
    lap_weight=config.lap_weight
)

# Optimizer
optimizer = optim.AdamW(
    model.parameters(),
    lr=config.learning_rate,
    weight_decay=config.weight_decay
)

# Learning rate scheduler
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=config.epochs,
    eta_min=config.learning_rate * 0.1
)

# Tensorboard writer
writer = SummaryWriter(config.log_dir)

# Training state
best_loss = float('inf')
global_step = 0

print("✅ Training setup completed!")
print(f"🎯 Device: {device}")
print(f"📈 Total epochs: {config.epochs}")
print(f"📚 Batch size: {config.batch_size}")
print(f"🧠 Learning rate: {config.learning_rate}")

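# Performance settings for Colab: let cuDNN pick the fastest kernels for the
# fixed input size, and release cached GPU memory left over from earlier cells
torch.backends.cudnn.benchmark = True
if torch.cuda.is_available():
    torch.cuda.empty_cache()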

In [None]:
def train_epoch(epoch):
    """Train for one epoch"""
    global global_step  # module-level step counter used for TensorBoard logging
    
    model.train()
    epoch_losses = {
        'total': 0.0,
        'alpha': 0.0,
        'composition': 0.0,
        'gradient': 0.0,
        'laplacian': 0.0
    }
    
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.epochs}")
    
    for batch_idx, batch in enumerate(progress_bar):
        # Move to device
        image = batch['image'].to(device)
        alpha_gt = batch['alpha'].to(device)
        trimap = batch['trimap'].to(device)
        clicks = batch['clicks'].to(device)
        
        # Forward pass
        optimizer.zero_grad()
        alpha_pred = model(image, trimap, clicks)
        
        # Compute loss
        losses = criterion(alpha_pred, alpha_gt, image, trimap)
        
        # Backward pass
        losses['total'].backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        
        # Update metrics
        for key in epoch_losses:
            epoch_losses[key] += losses[key].item()
        
        # Log to tensorboard
        if global_step % config.log_interval == 0:
            for key, value in losses.items():
                writer.add_scalar(f'Train/{key}_loss', value.item(), global_step)
            writer.add_scalar('Train/lr', optimizer.param_groups[0]['lr'], global_step)
        
        # Update progress bar
        progress_bar.set_postfix({
            'Loss': f"{losses['total'].item():.4f}",
            'Alpha': f"{losses['alpha'].item():.4f}",
            'LR': f"{optimizer.param_groups[0]['lr']:.6f}"
        })
        
        global_step += 1
        
        # Memory cleanup for Colab
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    
    # Average losses
    num_batches = len(train_loader)
    for key in epoch_losses:
        epoch_losses[key] /= num_batches
    
    return epoch_losses

def validate_epoch(epoch):
    """Validate the model"""
    model.eval()
    val_losses = {
        'total': 0.0,
        'alpha': 0.0,
        'composition': 0.0,
        'gradient': 0.0,
        'laplacian': 0.0
    }
    val_metrics = {
        'sad': 0.0,
        'mse': 0.0,
        'grad_error': 0.0
    }
    
    with torch.no_grad():
        progress_bar = tqdm(val_loader, desc="Validation")
        
        for batch in progress_bar:
            # Move to device
            image = batch['image'].to(device)
            alpha_gt = batch['alpha'].to(device)
            trimap = batch['trimap'].to(device)
            clicks = batch['clicks'].to(device)
            
            # Forward pass
            alpha_pred = model(image, trimap, clicks)
            
            # Compute loss
            losses = criterion(alpha_pred, alpha_gt, image, trimap)
            
            # Compute metrics
            sad = compute_sad(alpha_pred, alpha_gt, trimap)
            mse = compute_mse(alpha_pred, alpha_gt, trimap)
            grad_error = compute_gradient_error(alpha_pred, alpha_gt, trimap)
            
            # Update metrics
            for key in val_losses:
                val_losses[key] += losses[key].item()
            
            val_metrics['sad'] += sad.item()
            val_metrics['mse'] += mse.item()
            val_metrics['grad_error'] += grad_error.item()
            
            progress_bar.set_postfix({
                'Loss': f"{losses['total'].item():.4f}",
                'SAD': f"{sad.item():.2f}",
                'MSE': f"{mse.item():.4f}"
            })
    
    # Average metrics
    num_batches = len(val_loader)
    for key in val_losses:
        val_losses[key] /= num_batches
    for key in val_metrics:
        val_metrics[key] /= num_batches
    
    return val_losses, val_metrics

def save_checkpoint(epoch, is_best=False):
    """Save model checkpoint"""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_loss': best_loss,
        'global_step': global_step,
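        # Store the config as a plain dict so the checkpoint loads without the Config class
        'config': {k: v for k, v in vars(Config).items() if not k.startswith('_')}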
    }
    
    # Save latest checkpoint
    checkpoint_path = os.path.join(config.checkpoint_dir, 'latest_checkpoint.pth')
    torch.save(checkpoint, checkpoint_path)
    
    # Save best checkpoint
    if is_best:
        best_path = os.path.join(config.checkpoint_dir, 'best_checkpoint.pth')
        torch.save(checkpoint, best_path)
        print(f"💾 New best model saved with loss: {best_loss:.4f}")

print("✅ Training functions defined successfully!")

## 8. Start Training

Execute the main training loop. This can take several hours, depending on dataset size and hardware.
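
The training loop in this section leaves early stopping as an optional extension. One possible way to add it is a small patience counter, sketched below. The class name `EarlyStopping` and its parameters `patience` and `min_delta` are illustrative assumptions, not part of the original notebook, and the losses in the toy run are made up.

```python
class EarlyStopping:
    """Signal a stop when the monitored loss has not improved for `patience` epochs."""

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Toy run with made-up validation losses
stopper = EarlyStopping(patience=2)
for epoch, loss in enumerate([0.50, 0.40, 0.41, 0.42, 0.39]):
    if stopper.step(loss):
        print(f"Stopping after epoch {epoch + 1}")
        break
```

In the training loop, the equivalent call would be `stopper.step(val_losses['total'])` right after validation, followed by a `break` when it returns `True`.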

In [None]:
# Only run training if data loaders are available
if train_loader is not None and val_loader is not None:
    print("🚀 Starting training...")
    start_time = time.time()
    
    for epoch in range(config.epochs):
        print(f"\n{'='*60}")
        print(f"Epoch {epoch+1}/{config.epochs}")
        print(f"{'='*60}")
        
        # Train
        train_losses = train_epoch(epoch)
        
        # Validate
        val_losses, val_metrics = validate_epoch(epoch)
        
        # Update learning rate
        scheduler.step()
        
        # Log epoch results
        print(f"\n📊 Epoch {epoch+1} Results:")
        print(f"   Train Loss: {train_losses['total']:.4f}")
        print(f"   Val Loss: {val_losses['total']:.4f}")
        print(f"   Val SAD: {val_metrics['sad']:.2f}")
        print(f"   Val MSE: {val_metrics['mse']:.4f}")
        print(f"   Val Grad Error: {val_metrics['grad_error']:.4f}")
        print(f"   Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
        
        # Log to tensorboard
        for key, value in val_losses.items():
            writer.add_scalar(f'Val/{key}_loss', value, epoch)
        for key, value in val_metrics.items():
            writer.add_scalar(f'Val/{key}', value, epoch)
        
        # Save checkpoint
        is_best = val_losses['total'] < best_loss
        if is_best:
            best_loss = val_losses['total']
        
        if (epoch + 1) % config.save_interval == 0 or is_best:
            save_checkpoint(epoch, is_best)
        
        # Early stopping check (optional)
        # You can add early stopping logic here
        
    total_time = time.time() - start_time
    print(f"\n🎉 Training completed!")
    print(f"⏱️  Total time: {total_time/3600:.2f} hours")
    print(f"🏆 Best validation loss: {best_loss:.4f}")
    
    writer.close()
    
else:
    print("⚠️ Cannot start training - dataset not loaded.")
    print("Please ensure the dataset path is correct and re-run the data loading cells.")

## 9. Model Inference and Visualization

Test the trained model and visualize results on sample images.
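
The custom-image inference function in this section expects trimap pixels to be exactly 0, 128, or 255. Trimaps that were saved with lossy compression or resized with interpolation often contain nearby values instead, so a tolerant conversion may be useful. The following is a minimal sketch assuming NumPy; the function name `trimap_to_onehot` and the thresholds 85 and 170 are illustrative choices, not values taken from the original pipeline.

```python
import numpy as np

def trimap_to_onehot(trimap, low=85, high=170):
    """Map a grayscale trimap (H, W) to a (3, H, W) float32 one-hot array.

    Channel order is background, unknown, foreground.
    """
    onehot = np.zeros((3,) + trimap.shape, dtype=np.float32)
    onehot[0] = trimap < low                          # background
    onehot[1] = (trimap >= low) & (trimap <= high)    # unknown
    onehot[2] = trimap > high                         # foreground
    return onehot

# A trimap whose values are slightly off the nominal 0 / 128 / 255
noisy = np.array([[0, 3], [127, 250]], dtype=np.uint8)
onehot = trimap_to_onehot(noisy)
print(onehot.argmax(axis=0))  # class index per pixel
```

Because the three bands do not overlap and together cover 0 to 255, every pixel lands in exactly one channel.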

In [None]:
def load_checkpoint(checkpoint_path):
    """Load model checkpoint"""
    if not os.path.exists(checkpoint_path):
        print(f"❌ Checkpoint not found: {checkpoint_path}")
        return False
    
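    # weights_only=False: the checkpoint is written by this notebook and may contain
    # non-tensor objects, which recent PyTorch versions refuse to unpickle by default
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)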
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"✅ Model loaded from checkpoint: {checkpoint_path}")
    print(f"📊 Epoch: {checkpoint['epoch']}")
    print(f"📊 Best loss: {checkpoint['best_loss']:.4f}")
    return True

def visualize_predictions(num_samples=4):
    """Visualize model predictions"""
    if val_loader is None:
        print("❌ Validation loader not available")
        return
    
    model.eval()
    
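    # Get a batch of validation data; the batch may hold fewer than num_samples items
    sample_batch = next(iter(val_loader))
    num_samples = min(num_samples, sample_batch['image'].shape[0])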
    
    with torch.no_grad():
        # Move to device
        images = sample_batch['image'][:num_samples].to(device)
        alphas_gt = sample_batch['alpha'][:num_samples].to(device)
        trimaps = sample_batch['trimap'][:num_samples].to(device)
        clicks = sample_batch['clicks'][:num_samples].to(device)
        names = sample_batch['name'][:num_samples]
        
        # Predict
        alphas_pred = model(images, trimaps, clicks)
        
        # Move back to CPU for visualization
        images = images.cpu()
        alphas_gt = alphas_gt.cpu()
        alphas_pred = alphas_pred.cpu()
        trimaps = trimaps.cpu()
        
        # Create visualization
        fig, axes = plt.subplots(num_samples, 5, figsize=(20, 4*num_samples))
        if num_samples == 1:
            axes = axes.reshape(1, -1)
        
        for i in range(num_samples):
            # Input image
            img = images[i].permute(1, 2, 0).numpy()
            axes[i, 0].imshow(img)
            axes[i, 0].set_title(f'Input Image\n{names[i]}')
            axes[i, 0].axis('off')
            
            # Trimap (unknown region)
            trimap_vis = trimaps[i, 1].numpy()  # Unknown region
            axes[i, 1].imshow(trimap_vis, cmap='gray')
            axes[i, 1].set_title('Trimap (Unknown)')
            axes[i, 1].axis('off')
            
            # Ground truth alpha
            alpha_gt = alphas_gt[i, 0].numpy()
            axes[i, 2].imshow(alpha_gt, cmap='gray')
            axes[i, 2].set_title('Ground Truth Alpha')
            axes[i, 2].axis('off')
            
            # Predicted alpha
            alpha_pred = alphas_pred[i, 0].numpy()
            axes[i, 3].imshow(alpha_pred, cmap='gray')
            axes[i, 3].set_title('Predicted Alpha')
            axes[i, 3].axis('off')
            
            # Composited result (with new background)
            # Create a simple colored background for demo
            bg_color = np.array([0.2, 0.8, 0.2])  # Green background
            composite = alpha_pred[..., np.newaxis] * img + (1 - alpha_pred[..., np.newaxis]) * bg_color
            composite = np.clip(composite, 0, 1)
            axes[i, 4].imshow(composite)
            axes[i, 4].set_title('Composite Result')
            axes[i, 4].axis('off')
            
            # Calculate metrics for this sample
            unknown_mask = trimap_vis > 0.5
            if np.sum(unknown_mask) > 0:
                sad = np.sum(np.abs(alpha_pred - alpha_gt) * unknown_mask) / 1000.0
                # Average over unknown pixels only, matching compute_mse
                mse = np.sum((alpha_pred - alpha_gt)**2 * unknown_mask) / np.sum(unknown_mask)
                print(f"Sample {i+1} - SAD: {sad:.2f}, MSE: {mse:.4f}")
        
        plt.tight_layout()
        plt.show()

def infer_on_custom_image(image_path, trimap_path=None):
    """Run inference on a custom image"""
    model.eval()
    
    # Load image
    image = cv2.imread(image_path)
    if image is None:
        print(f"❌ Could not load image: {image_path}")
        return
    
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    h, w = image.shape[:2]
    
    # Resize to model input size
    image_resized = cv2.resize(image, (config.input_size, config.input_size))
    image_tensor = torch.from_numpy(image_resized.transpose(2, 0, 1)).float() / 255.0
    image_tensor = image_tensor.unsqueeze(0).to(device)
    
    # Create trimap (you can load from file or create automatically)
    if trimap_path and os.path.exists(trimap_path):
        trimap = cv2.imread(trimap_path, cv2.IMREAD_GRAYSCALE)
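        # Nearest-neighbor keeps the trimap values at exactly 0 / 128 / 255
        trimap = cv2.resize(trimap, (config.input_size, config.input_size),
                            interpolation=cv2.INTER_NEAREST)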
    else:
        # Simple automatic trimap generation (for demo)
        print("🔄 Generating automatic trimap...")
        # This is a very basic trimap - in practice you'd want better trimap generation
        trimap = np.ones((config.input_size, config.input_size), dtype=np.uint8) * 128  # All unknown
        # Create some foreground and background regions
        center = config.input_size // 2
        trimap[:50, :] = 0  # Top background
        trimap[-50:, :] = 0  # Bottom background
        trimap[:, :50] = 0  # Left background
        trimap[:, -50:] = 0  # Right background
        trimap[center-100:center+100, center-100:center+100] = 255  # Center foreground
    
    # Convert trimap to one-hot
    trimap_onehot = np.zeros((3, config.input_size, config.input_size), dtype=np.float32)
    trimap_onehot[0] = (trimap == 0).astype(np.float32)    # Background
    trimap_onehot[1] = (trimap == 128).astype(np.float32)  # Unknown
    trimap_onehot[2] = (trimap == 255).astype(np.float32)  # Foreground
    trimap_tensor = torch.from_numpy(trimap_onehot).unsqueeze(0).to(device)
    
    # Generate click guidance (simplified)
    clicks = np.zeros((2, config.input_size, config.input_size), dtype=np.float32)
    clicks_tensor = torch.from_numpy(clicks).unsqueeze(0).to(device)
    
    # Predict
    with torch.no_grad():
        alpha_pred = model(image_tensor, trimap_tensor, clicks_tensor)
    
    # Resize alpha back to original size
    alpha_pred = alpha_pred.squeeze().cpu().numpy()
    alpha_pred = cv2.resize(alpha_pred, (w, h))
    
    # Visualize result
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    
    axes[0].imshow(image)
    axes[0].set_title('Original Image')
    axes[0].axis('off')
    
    axes[1].imshow(cv2.resize(trimap, (w, h)), cmap='gray')
    axes[1].set_title('Trimap')
    axes[1].axis('off')
    
    axes[2].imshow(alpha_pred, cmap='gray')
    axes[2].set_title('Predicted Alpha')
    axes[2].axis('off')
    
    # Composite with new background
    bg_color = np.array([0.2, 0.2, 0.8])  # Blue background
    composite = alpha_pred[..., np.newaxis] * (image/255.0) + (1 - alpha_pred[..., np.newaxis]) * bg_color
    composite = np.clip(composite, 0, 1)
    axes[3].imshow(composite)
    axes[3].set_title('Composite Result')
    axes[3].axis('off')
    
    plt.tight_layout()
    plt.show()
    
    return alpha_pred

print("✅ Inference functions defined!")

In [None]:
# Load the best trained model
best_checkpoint_path = os.path.join(config.checkpoint_dir, 'best_checkpoint.pth')
if os.path.exists(best_checkpoint_path):
    load_checkpoint(best_checkpoint_path)
    
    # Visualize predictions on validation set
    print("🎨 Visualizing predictions on validation samples...")
    visualize_predictions(num_samples=4)
    
else:
    print("⚠️ No trained model found. Please run training first or load a pre-trained model.")

# Example: Run inference on custom image (uncomment and modify path)
# custom_image_path = "/path/to/your/image.jpg"
# if os.path.exists(custom_image_path):
#     print("🔍 Running inference on custom image...")
#     alpha_result = infer_on_custom_image(custom_image_path)
# else:
#     print("📝 To test on custom image, update custom_image_path variable")

## 10. Model Export and Deployment

Export the trained model for deployment and print a summary of the exported files.

In [None]:
# Export trained model for deployment
def export_model():
    """Export the trained model in different formats"""
    
    if not os.path.exists(best_checkpoint_path):
        print("❌ No trained model found to export")
        return
    
    # Create export directory
    export_dir = "/content/exported_models"
    os.makedirs(export_dir, exist_ok=True)
    
    # Load the best model weights before exporting
    load_checkpoint(best_checkpoint_path)
    model.eval()
    
    # 1. Export as TorchScript (for production deployment)
    print("📦 Exporting TorchScript model...")
    dummy_input = (
        torch.randn(1, 3, config.input_size, config.input_size).to(device),
        torch.randn(1, 3, config.input_size, config.input_size).to(device),
        torch.randn(1, 2, config.input_size, config.input_size).to(device)
    )
    
    try:
        traced_model = torch.jit.trace(model, dummy_input)
        torchscript_path = os.path.join(export_dir, "lite_matting_torchscript.pt")
        traced_model.save(torchscript_path)
        print(f"✅ TorchScript model saved: {torchscript_path}")
    except Exception as e:
        print(f"❌ TorchScript export failed: {e}")
    
    # 2. Export model weights only
    weights_path = os.path.join(export_dir, "lite_matting_weights.pth")
    torch.save(model.state_dict(), weights_path)
    print(f"✅ Model weights saved: {weights_path}")
    
    # 3. Export complete model (architecture + weights)
    full_model_path = os.path.join(export_dir, "lite_matting_full.pth")
    torch.save(model, full_model_path)
    print(f"✅ Full model saved: {full_model_path}")
    
    # 4. Export ONNX model (for cross-platform deployment)
    print("📦 Exporting ONNX model...")
    try:
        onnx_path = os.path.join(export_dir, "lite_matting.onnx")
        torch.onnx.export(
            model,
            dummy_input,
            onnx_path,
            export_params=True,
            opset_version=11,
            do_constant_folding=True,
            input_names=['image', 'trimap', 'clicks'],
            output_names=['alpha'],
            dynamic_axes={
                'image': {0: 'batch_size'},
                'trimap': {0: 'batch_size'},
                'clicks': {0: 'batch_size'},
                'alpha': {0: 'batch_size'}
            }
        )
        print(f"✅ ONNX model saved: {onnx_path}")
    except Exception as e:
        print(f"❌ ONNX export failed: {e}")
    
    print(f"\n📁 All models exported to: {export_dir}")
    print("📋 Export summary:")
    for file in os.listdir(export_dir):
        file_path = os.path.join(export_dir, file)
        size_mb = os.path.getsize(file_path) / (1024 * 1024)
        print(f"   - {file}: {size_mb:.1f} MB")

# Export the model
export_model()

# Download models to local machine
print("\n💾 To download the trained models to your local machine:")
print("1. Navigate to the exported_models folder in Colab files")
print("2. Right-click and download each model file")
print("3. Or use the following code to zip and download:")

zip_code = '''
import zipfile
import os

# Create a zip file with all models
zip_path = "/content/lite_matting_models.zip"
with zipfile.ZipFile(zip_path, 'w') as zipf:
    for root, dirs, files in os.walk("/content/exported_models"):
        for file in files:
            file_path = os.path.join(root, file)
            zipf.write(file_path, file)

# Download the zip file
from google.colab import files
files.download(zip_path)
'''

print("📝 Copy and run this code to download all models as a zip file:")
print(zip_code)

## 🎉 Training Complete!

### Summary

This notebook implements and trains a LiteMatting model for high-quality image matting. Here's what it covers:

✅ **Model Architecture**: Implemented the complete MobileMatting network with:
- Multi-Scale Local Pyramid Pooling Module (MSLPPM)
- Global Feature Network Block (GFNB) 
- Efficient MobileNet-based backbone
- Skip connections for feature fusion

✅ **Training Pipeline**: Built a comprehensive training system with:
- Adobe Composition-1K dataset loading
- Data augmentation and preprocessing
- Combined loss functions (Alpha + Composition + Gradient + Laplacian)
- Learning rate scheduling and optimization
- Model checkpointing and tensorboard logging

✅ **Evaluation**: Implemented standard matting metrics:
- Sum of Absolute Differences (SAD)
- Mean Squared Error (MSE)
- Gradient Error
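
As a small worked example of the first two metrics, the sketch below computes SAD and MSE over the unknown region of a toy 2×2 alpha matte with NumPy. The array values are made up for illustration, and the SAD here is left unscaled, whereas `compute_sad` in this notebook divides by 1000.

```python
import numpy as np

pred = np.array([[0.9, 0.2], [0.5, 1.0]])
gt = np.array([[1.0, 0.0], [0.5, 1.0]])
unknown = np.array([[1.0, 1.0], [0.0, 0.0]])  # 1 marks pixels in the unknown region

n_unknown = unknown.sum()
sad = np.sum(np.abs(pred - gt) * unknown)              # |0.9 - 1.0| + |0.2 - 0.0|
mse = np.sum((pred - gt) ** 2 * unknown) / n_unknown   # (0.01 + 0.04) / 2

print(f"SAD = {sad:.3f}, MSE = {mse:.3f}")
```

Only the top row contributes, because the bottom row lies outside the unknown region.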

✅ **Deployment**: Exported models in multiple formats:
- PyTorch weights and full model
- TorchScript for production
- ONNX for cross-platform deployment

### Next Steps

1. **Dataset Access**: Download the Adobe Composition-1K dataset using your student account
2. **Training**: Run the training loop with your dataset
3. **Hyperparameter Tuning**: Experiment with different learning rates, loss weights, and architectures
4. **Production Deployment**: Use the exported models in your applications
5. **Mobile Deployment**: The architecture is designed for mobile devices; use the TorchScript or ONNX export as the starting point for on-device inference

### Performance Tips

- **Memory**: Reduce batch size if you encounter CUDA out of memory errors
- **Speed**: Use mixed precision training with `torch.cuda.amp` (`torch.amp` in recent PyTorch releases) for faster training
- **Quality**: Experiment with different loss weights for your specific use case
- **Data**: Add more diverse backgrounds and subjects to improve generalization
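
The mixed-precision tip can be sketched as follows. This is a minimal, self-contained example on a stand-in model (a single `nn.Conv2d`) with random data, not the notebook's `train_epoch`; it falls back to full precision when no GPU is present. On newer PyTorch releases `torch.cuda.amp.GradScaler` may emit a deprecation warning pointing to `torch.amp`.

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = device.type == "cuda"  # enable autocast and loss scaling only on GPU

model = nn.Conv2d(3, 1, kernel_size=3, padding=1).to(device)  # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

# Random stand-in data
image = torch.rand(2, 3, 32, 32, device=device)
target = torch.rand(2, 1, 32, 32, device=device)

optimizer.zero_grad()
with torch.autocast(device_type=device.type, enabled=use_amp):
    pred = torch.sigmoid(model(image))
    loss = torch.abs(pred - target).mean()

scaler.scale(loss).backward()   # scale the loss to limit fp16 gradient underflow
scaler.unscale_(optimizer)      # restore true gradient magnitudes before clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)          # skips the update if gradients contain inf/NaN
scaler.update()

print(f"loss = {loss.item():.4f}")
```

Calling `unscale_` before `clip_grad_norm_` matters: without it, the clipping threshold would be applied to scaled gradients rather than the real ones.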

### Resources

- [Adobe Image Matting Dataset](https://sites.google.com/view/deepimagematting)
- [Papers With Code - Image Matting](https://paperswithcode.com/task/image-matting)
- [LFPNet Paper](https://arxiv.org/abs/2109.12252) (inspiration for this work)

Happy matting! 🎨✨