# TinyEfficientViT Segmentation Training

This notebook trains a TinyEfficientViT model for eye/pupil segmentation on the OpenEDS dataset.

**Features:**
- Free GPU access via Google Colab
- Interactive experimentation and visualization
- No external infrastructure required (Modal, MLflow, etc.)

**Model Constraints:**
- Target: <60k parameters for edge deployment
- Input: 640x400 (width x height) grayscale images
- Output: 2-class segmentation (background, pupil)

## 1. Setup and GPU Check

In [None]:
# Fail fast when the runtime has no CUDA device: the rest of the notebook assumes a GPU.
import torch

if not torch.cuda.is_available():
    print("WARNING: No GPU detected!")
    print("Go to Runtime -> Change runtime type -> GPU")
    raise RuntimeError("GPU required for training")

device_props = torch.cuda.get_device_properties(0)
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"GPU Memory: {device_props.total_memory / 1e9:.1f} GB")
print(f"CUDA Version: {torch.version.cuda}")

In [None]:
# Install dependencies
!pip install -q torch torchvision opencv-python-headless Pillow scikit-learn tqdm matplotlib datasets huggingface_hub

In [None]:
# Core imports
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import cv2
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from datasets import load_dataset

# Seed every RNG the notebook uses (Python `random`, NumPy, torch).
# NOTE: cudnn.benchmark (autotuned, possibly nondeterministic kernel choice) and TF32
# matmuls below trade bit-exact reproducibility for speed. Set benchmark=False and
# disable TF32 if exact reruns matter more than throughput.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 2. Dataset Loading

In [None]:
# Download the precomputed OpenEDS splits from the HuggingFace Hub.
HF_DATASET_REPO = "Conner/openeds-precomputed"

print(f"Loading dataset from HuggingFace: {HF_DATASET_REPO}")
print("This may take a few minutes on first run...")
hf_dataset = load_dataset(HF_DATASET_REPO)

print("\nDataset loaded!")
for split_key, split_title in (("train", "Train"), ("validation", "Validation")):
    print(f"{split_title} samples: {len(hf_dataset[split_key])}")

In [None]:
# Image geometry of the precomputed OpenEDS arrays. These constants are also used by
# IrisDataset and the forward-pass / ONNX cells below, so this cell must run first.
IMAGE_HEIGHT = 400
IMAGE_WIDTH = 640


def _show_example(row_axes, sample, row_number):
    """Draw image, label mask and a red pupil overlay for one dataset row on three axes."""
    image = np.array(sample['image'], dtype=np.uint8).reshape(IMAGE_HEIGHT, IMAGE_WIDTH)
    label = np.array(sample['label'], dtype=np.uint8).reshape(IMAGE_HEIGHT, IMAGE_WIDTH)

    overlay = cv2.cvtColor(image.copy(), cv2.COLOR_GRAY2RGB)
    overlay[label == 1] = [255, 0, 0]  # Pupil in red

    panels = (
        (image, f"Image {row_number}", {"cmap": 'gray'}),
        (label, f"Label {row_number}", {"cmap": 'jet', "vmin": 0, "vmax": 1}),
        (overlay, f"Overlay {row_number}", {}),
    )
    for ax, (pixels, title, imshow_kwargs) in zip(row_axes, panels):
        ax.imshow(pixels, **imshow_kwargs)
        ax.set_title(title)
        ax.axis('off')


fig, axes = plt.subplots(2, 3, figsize=(15, 8))
for row_idx in range(2):
    _show_example(axes[row_idx], hf_dataset['train'][row_idx * 100], row_idx + 1)

plt.tight_layout()
plt.show()

## 3. Model Architecture: TinyEfficientViT

In [None]:
class TinyConvNorm(nn.Module):
    """Conv2d followed by BatchNorm2d, with no activation.

    The convolution has no bias by default because the BatchNorm that follows
    provides its own learned shift.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, groups=1, bias=False):
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=bias)
        self.conv = nn.Conv2d(in_channels, out_channels, **conv_options)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        features = self.conv(x)
        return self.bn(features)


class TinyPatchEmbedding(nn.Module):
    """Stem that reduces H and W by 4x using two stride-2 ConvNorm + GELU layers."""

    def __init__(self, in_channels=1, embed_dim=8):
        super().__init__()
        # Intermediate width is half of embed_dim; fall back to 2 channels when embed_dim < 4.
        mid_dim = 2 if embed_dim < 4 else embed_dim // 2
        self.conv1 = TinyConvNorm(in_channels, mid_dim, kernel_size=3, stride=2, padding=1)
        self.act1 = nn.GELU()
        self.conv2 = TinyConvNorm(mid_dim, embed_dim, kernel_size=3, stride=2, padding=1)
        self.act2 = nn.GELU()

    def forward(self, x):
        for conv, activation in ((self.conv1, self.act1), (self.conv2, self.act2)):
            x = activation(conv(x))
        return x


class TinyCascadedGroupAttention(nn.Module):
    """Small multi-head self-attention over a token sequence.

    Despite the name, all heads are computed in parallel; there is no cascading
    between heads. Each head uses ``key_dim`` channels for queries/keys and
    ``d = int(attn_ratio * key_dim)`` channels for values. The concatenated head
    outputs are projected back to ``dim``.

    Input and output shape: (B, N, dim).
    """

    def __init__(self, dim, num_heads=1, key_dim=4, attn_ratio=2):
        super().__init__()
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.scale = key_dim**-0.5
        self.d = int(attn_ratio * key_dim)  # value width per head
        self.attn_ratio = attn_ratio

        # A single fused projection yields q and k (key_dim each) and v (d) for every head.
        fused_width = num_heads * (2 * key_dim + self.d)
        self.qkv = nn.Linear(dim, fused_width)
        self.proj = nn.Linear(num_heads * self.d, dim)

    def forward(self, x):
        batch, tokens, _ = x.shape
        heads = self.num_heads
        qk_width = heads * self.key_dim
        q, k, v = torch.split(self.qkv(x), [qk_width, qk_width, heads * self.d], dim=-1)

        # (B, N, heads*c) -> (B, heads, N, c)
        q = q.reshape(batch, tokens, heads, self.key_dim).transpose(1, 2)
        k = k.reshape(batch, tokens, heads, self.key_dim).transpose(1, 2)
        v = v.reshape(batch, tokens, heads, self.d).transpose(1, 2)

        weights = torch.softmax(q @ k.transpose(-2, -1) * self.scale, dim=-1)
        mixed = (weights @ v).transpose(1, 2).reshape(batch, tokens, heads * self.d)
        return self.proj(mixed)


class TinyLocalWindowAttention(nn.Module):
    """Run TinyCascadedGroupAttention independently inside non-overlapping square windows.

    Steps: zero-pad the (B, C, H, W) map on the bottom/right to a multiple of
    ``window_size``; cut it into ws x ws windows; attend within each window as a
    sequence of ws*ws tokens; stitch the windows back; crop to H x W. Padded
    positions are not masked, so they take part in attention inside border windows.
    """

    def __init__(self, dim, num_heads=1, key_dim=4, attn_ratio=2, window_size=7):
        super().__init__()
        self.window_size = window_size
        self.attn = TinyCascadedGroupAttention(dim=dim, num_heads=num_heads, key_dim=key_dim, attn_ratio=attn_ratio)

    def forward(self, x):
        B, C, H, W = x.shape
        ws = self.window_size

        extra_rows = (ws - H % ws) % ws
        extra_cols = (ws - W % ws) % ws
        padded = extra_rows > 0 or extra_cols > 0
        if padded:
            x = F.pad(x, (0, extra_cols, 0, extra_rows))
        _, _, Hp, Wp = x.shape
        n_rows, n_cols = Hp // ws, Wp // ws

        # Partition: (B, C, Hp, Wp) -> (B * n_rows * n_cols, ws * ws, C)
        windows = (
            x.view(B, C, n_rows, ws, n_cols, ws)
            .permute(0, 2, 4, 3, 5, 1)
            .contiguous()
            .view(B * n_rows * n_cols, ws * ws, C)
        )
        windows = self.attn(windows)

        # Inverse of the partition: back to (B, C, Hp, Wp)
        merged = (
            windows.view(B, n_rows, n_cols, ws, ws, C)
            .permute(0, 5, 1, 3, 2, 4)
            .contiguous()
            .view(B, C, Hp, Wp)
        )
        return merged[:, :, :H, :W] if padded else merged


class TinyMLP(nn.Module):
    """Position-wise feed-forward net: Linear(dim -> dim*ratio) -> GELU -> Linear(-> dim)."""

    def __init__(self, dim, expansion_ratio=2):
        super().__init__()
        hidden_width = int(dim * expansion_ratio)
        self.fc1 = nn.Linear(dim, hidden_width)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_width, dim)

    def forward(self, x):
        hidden = self.act(self.fc1(x))
        return self.fc2(hidden)


class TinyEfficientVitBlock(nn.Module):
    """EfficientViT-style block made of three pre-norm residual branches.

    1. BatchNorm -> depthwise 3x3 conv (local spatial mixing)
    2. BatchNorm -> windowed self-attention
    3. LayerNorm over channels -> MLP, applied at every spatial position

    Input and output shape: (B, dim, H, W).
    """

    def __init__(self, dim, num_heads=1, key_dim=4, attn_ratio=2, window_size=7, mlp_ratio=2):
        super().__init__()
        self.norm1 = nn.BatchNorm2d(dim)
        self.dw_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim)
        self.norm2 = nn.BatchNorm2d(dim)
        self.attn = TinyLocalWindowAttention(
            dim=dim, num_heads=num_heads, key_dim=key_dim, attn_ratio=attn_ratio, window_size=window_size
        )
        self.norm3 = nn.LayerNorm(dim)
        self.mlp = TinyMLP(dim, expansion_ratio=mlp_ratio)

    def forward(self, x):
        x = x + self.dw_conv(self.norm1(x))
        x = x + self.attn(self.norm2(x))

        # The MLP branch works on channel-last tokens: (B, C, H, W) -> (B, H*W, C) and back.
        batch, channels, height, width = x.shape
        tokens = x.permute(0, 2, 3, 1).reshape(batch, height * width, channels)
        tokens = tokens + self.mlp(self.norm3(tokens))
        return tokens.reshape(batch, height, width, channels).permute(0, 3, 1, 2)


class TinyEfficientVitStage(nn.Module):
    """One encoder stage: an optional transition layer followed by ``depth`` blocks.

    Transition (stored in ``self.downsample``):
      * downsample=True: stride-2 3x3 ConvNorm + GELU (halves H/W, in_dim -> out_dim)
      * downsample=False and in_dim != out_dim: 1x1 ConvNorm + GELU (channels only)
      * otherwise: None (identity)
    """

    def __init__(self, in_dim, out_dim, depth=1, num_heads=1, key_dim=4, attn_ratio=2, window_size=7, mlp_ratio=2, downsample=True):
        super().__init__()
        if downsample:
            transition = TinyConvNorm(in_dim, out_dim, kernel_size=3, stride=2, padding=1)
        elif in_dim != out_dim:
            transition = TinyConvNorm(in_dim, out_dim, kernel_size=1, stride=1, padding=0)
        else:
            transition = None
        self.downsample = None if transition is None else nn.Sequential(transition, nn.GELU())

        block_options = dict(
            dim=out_dim, num_heads=num_heads, key_dim=key_dim, attn_ratio=attn_ratio,
            window_size=window_size, mlp_ratio=mlp_ratio,
        )
        self.blocks = nn.ModuleList(TinyEfficientVitBlock(**block_options) for _ in range(depth))

    def forward(self, x):
        if self.downsample is not None:
            x = self.downsample(x)
        for block in self.blocks:
            x = block(x)
        return x


class TinyEfficientVitEncoder(nn.Module):
    """Patch-embedding stem plus three stages, returning features at strides 4, 8 and 16.

    Every per-stage hyperparameter is a 3-tuple indexed by stage.
    Returns (f1, f2, f3) with embed_dims[0], [1], [2] channels respectively.
    """

    def __init__(self, in_channels=1, embed_dims=(8, 16, 24), depths=(1, 1, 1), num_heads=(1, 1, 2), key_dims=(4, 4, 4), attn_ratios=(2, 2, 2), window_sizes=(7, 7, 7), mlp_ratios=(2, 2, 2)):
        super().__init__()
        self.patch_embed = TinyPatchEmbedding(in_channels=in_channels, embed_dim=embed_dims[0])

        def build_stage(idx, in_dim, downsample):
            return TinyEfficientVitStage(
                in_dim=in_dim, out_dim=embed_dims[idx], depth=depths[idx], num_heads=num_heads[idx],
                key_dim=key_dims[idx], attn_ratio=attn_ratios[idx], window_size=window_sizes[idx],
                mlp_ratio=mlp_ratios[idx], downsample=downsample,
            )

        self.stage1 = build_stage(0, embed_dims[0], downsample=False)
        self.stage2 = build_stage(1, embed_dims[0], downsample=True)
        self.stage3 = build_stage(2, embed_dims[1], downsample=True)

    def forward(self, x):
        f1 = self.stage1(self.patch_embed(x))
        f2 = self.stage2(f1)
        f3 = self.stage3(f2)
        return f1, f2, f3


class TinySegmentationDecoder(nn.Module):
    """Top-down FPN-style decoder with additive skip connections.

    Each encoder map is projected to ``decoder_dim`` by a 1x1 conv. The coarser level
    is bilinearly upsampled and added, and the sum is refined by depthwise 3x3 conv
    + BN + GELU. A 1x1 head produces class logits at f1's resolution, which are
    finally resized to ``target_size``.
    """

    def __init__(self, encoder_dims=(8, 16, 24), decoder_dim=16, num_classes=2):
        super().__init__()
        self.lateral3 = nn.Conv2d(encoder_dims[2], decoder_dim, kernel_size=1)
        self.lateral2 = nn.Conv2d(encoder_dims[1], decoder_dim, kernel_size=1)
        self.lateral1 = nn.Conv2d(encoder_dims[0], decoder_dim, kernel_size=1)

        self.smooth3 = self._depthwise_refine(decoder_dim)
        self.smooth2 = self._depthwise_refine(decoder_dim)
        self.smooth1 = self._depthwise_refine(decoder_dim)

        self.head = nn.Conv2d(decoder_dim, num_classes, kernel_size=1)

    @staticmethod
    def _depthwise_refine(channels):
        """Shape-preserving depthwise 3x3 conv -> BatchNorm -> GELU."""
        return nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, groups=channels),
            nn.BatchNorm2d(channels),
            nn.GELU(),
        )

    @staticmethod
    def _upsample_like(x, reference):
        """Bilinearly resize ``x`` to the spatial size of ``reference``."""
        return F.interpolate(x, size=reference.shape[2:], mode="bilinear", align_corners=False)

    def forward(self, f1, f2, f3, target_size):
        p3 = self.smooth3(self.lateral3(f3))
        p2 = self.smooth2(self.lateral2(f2) + self._upsample_like(p3, f2))
        p1 = self.smooth1(self.lateral1(f1) + self._upsample_like(p2, f1))
        logits = self.head(p1)
        return F.interpolate(logits, size=target_size, mode="bilinear", align_corners=False)


class TinyEfficientViTSeg(nn.Module):
    """TinyEfficientViT encoder + FPN decoder producing per-pixel class logits.

    Output shape is (B, num_classes, H, W), matching the input's spatial size.
    Designed to stay under 60k trainable parameters (checked in the config cell).
    """

    def __init__(self, in_channels=1, num_classes=2, embed_dims=(8, 16, 24), depths=(1, 1, 1), num_heads=(1, 1, 2), key_dims=(4, 4, 4), attn_ratios=(2, 2, 2), window_sizes=(7, 7, 7), mlp_ratios=(2, 2, 2), decoder_dim=16):
        super().__init__()
        self.encoder = TinyEfficientVitEncoder(
            in_channels=in_channels, embed_dims=embed_dims, depths=depths, num_heads=num_heads,
            key_dims=key_dims, attn_ratios=attn_ratios, window_sizes=window_sizes, mlp_ratios=mlp_ratios,
        )
        self.decoder = TinySegmentationDecoder(encoder_dims=embed_dims, decoder_dim=decoder_dim, num_classes=num_classes)
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize all submodules in place.

        Conv2d: Kaiming-normal (fan_out, relu gain). Linear: truncated normal, std 0.02.
        BatchNorm2d / LayerNorm: weight 1. Biases of all these layers are zeroed.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(module, nn.Linear):
                nn.init.trunc_normal_(module.weight, std=0.02)
            elif isinstance(module, (nn.BatchNorm2d, nn.LayerNorm)):
                nn.init.ones_(module.weight)
            else:
                continue
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x):
        output_size = (x.shape[2], x.shape[3])
        f1, f2, f3 = self.encoder(x)
        return self.decoder(f1, f2, f3, output_size)

## 4. Loss Functions and Metrics

In [None]:
class CombinedLoss(nn.Module):
    """Spatially weighted cross-entropy + generalized Dice + surface loss.

    total = CE_w + alpha * Dice + (1 - alpha) * Surface

    forward() args:
        logits: (B, C, H, W) raw class scores; C sets the one-hot width.
        target: (B, H, W) integer class indices in [0, C).
        spatial_weights: per-pixel weights broadcastable to (B, H, W); each pixel's
            CE term is scaled by (1 + weight).
        dist_map: tensor broadcastable to (B, C, H, W) that is multiplied with the
            softmax probabilities for the surface term.
        alpha: scalar blend between the Dice and surface terms.

    Returns:
        (total_loss, weighted_ce, dice_loss, surface_loss) as scalar tensors.
    """

    def __init__(self, epsilon=1e-5):
        super(CombinedLoss, self).__init__()
        self.epsilon = epsilon
        self.nll = nn.NLLLoss(reduction="none")

    def forward(self, logits, target, spatial_weights, dist_map, alpha):
        num_classes = logits.shape[1]
        probs = F.softmax(logits, dim=1)
        log_probs = F.log_softmax(logits, dim=1)

        ce_loss = self.nll(log_probs, target)
        weighted_ce = (ce_loss * (1.0 + spatial_weights)).mean()

        # Generalized Dice: each class is weighted by 1 / (pixel count)^2 so small classes
        # are not swamped by large ones. The clamp caps the weight of a class absent from
        # the batch at 1 / epsilon instead of dividing by zero.
        target_onehot = F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float()
        probs_flat = probs.flatten(start_dim=2)
        target_flat = target_onehot.flatten(start_dim=2)
        intersection = (probs_flat * target_flat).sum(dim=2)
        cardinality = (probs_flat + target_flat).sum(dim=2)
        class_weights = 1.0 / (target_flat.sum(dim=2) ** 2).clamp(min=self.epsilon)
        dice = 2.0 * (class_weights * intersection).sum(dim=1) / (class_weights * cardinality).sum(dim=1)
        dice_loss = (1.0 - dice.clamp(min=self.epsilon)).mean()

        surface_loss = (probs_flat * dist_map.flatten(start_dim=2)).mean(dim=2).mean(dim=1).mean()

        total_loss = weighted_ce + alpha * dice_loss + (1.0 - alpha) * surface_loss
        return total_loss, weighted_ce, dice_loss, surface_loss


def compute_iou_tensors(predictions, targets, num_classes=2):
    """Count per-class intersection and union pixels between two label maps.

    Returns two float tensors of length ``num_classes`` on ``predictions.device``:
    (|pred==c AND target==c|, |pred==c OR target==c|) for each class c.
    """
    class_ids = torch.arange(num_classes, device=predictions.device)
    class_ids = class_ids.view(-1, *([1] * predictions.dim()))
    pred_masks = predictions.unsqueeze(0) == class_ids
    target_masks = targets.unsqueeze(0) == class_ids
    intersection = (pred_masks & target_masks).flatten(1).sum(dim=1).float()
    union = (pred_masks | target_masks).flatten(1).sum(dim=1).float()
    return intersection, union


def finalize_iou(total_intersection, total_union):
    """Convert accumulated per-class counts into (mean IoU, list of per-class IoU).

    A class whose union is 0 gets IoU 0, because the denominator is clamped to >= 1.
    """
    per_class = total_intersection / torch.clamp(total_union, min=1)
    per_class = per_class.cpu().numpy()
    return float(per_class.mean()), per_class.tolist()


def get_predictions(output):
    """Return the per-pixel argmax class map (B, H, W) of logits shaped (B, C, H, W)."""
    batch, _, height, width = output.size()
    return output.argmax(dim=1).view(batch, height, width)


def get_nparams(model):
    """Return the total element count of all parameters with ``requires_grad=True``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

## 5. Data Augmentation and Dataset

In [None]:
class RandomHorizontalFlip:
    """Mirror an image and its label left-right together with probability 0.5.

    Both arguments must support PIL's ``transpose``. The decision uses Python's ``random``.
    """

    def __call__(self, img, label):
        if random.random() >= 0.5:
            return img, label
        mirror = Image.FLIP_LEFT_RIGHT
        return img.transpose(mirror), label.transpose(mirror)


class Gaussian_blur:
    """Blur a 2-D array with a 7x7 Gaussian kernel and return it as a PIL image.

    Sigma is an integer drawn uniformly from 2..6 using NumPy's global RNG.
    """

    def __call__(self, img):
        sigma = np.random.randint(2, 7)
        blurred = cv2.GaussianBlur(img, (7, 7), sigma)
        return Image.fromarray(blurred)


class Line_augment:
    """Draw 1-9 random white line segments (thickness 4) on a copy of a 2-D image.

    All segments lie on lines through one shared point, whose row/column are the same
    random fraction in [0.3, 0.7) of the image height/width. Returns a uint8 PIL image.
    """

    def __call__(self, base):
        # Scalar draws (rand() instead of rand(1)) consume the NumPy RNG stream identically,
        # but avoid NumPy's deprecated int() conversion of 1-element arrays.
        yc, xc = (0.3 + 0.4 * np.random.rand()) * np.array(base.shape)
        aug_base = np.copy(base)
        num_lines = np.random.randint(1, 10)
        for _ in range(num_lines):
            theta = np.pi * np.random.rand()
            x1 = xc - 50 * np.random.rand() * (1 if np.random.rand() < 0.5 else -1)
            y1 = (x1 - xc) * np.tan(theta) + yc
            x2 = xc - (150 * np.random.rand() + 50) * (1 if np.random.rand() < 0.5 else -1)
            y2 = (x2 - xc) * np.tan(theta) + yc
            aug_base = cv2.line(aug_base, (int(x1), int(y1)), (int(x2), int(y2)), (255, 255, 255), 4)
        return Image.fromarray(aug_base.astype(np.uint8))


class MaskToTensor:
    """Copy a mask (PIL image or array) into a new int64 torch tensor."""

    def __call__(self, img):
        mask = np.array(img, dtype=np.int64)
        return torch.from_numpy(mask).long()


class IrisDataset(Dataset):
    """Dataset for OpenEDS iris/pupil segmentation backed by a HuggingFace split.

    Each row must provide flat ``image``, ``label`` and ``spatial_weights`` arrays of
    IMAGE_HEIGHT * IMAGE_WIDTH values, a ``dist_map`` of 2 * IMAGE_HEIGHT * IMAGE_WIDTH
    values, and a ``filename``.

    Per-item pipeline:
      gamma LUT (exponent 0.8)
      -> [train split with a transform] line augmentation (p=0.2), Gaussian blur (p=0.2)
      -> CLAHE
      -> [train split with a transform] horizontal flip (p=0.5), applied jointly to
         image, label, spatial weights and distance map
      -> ``transform`` on the image (if given)

    Returns:
        (img, label_tensor, filename, spatial_weights, dist_map). ``img`` is the
        transformed image (a PIL image if ``transform`` is None), ``label_tensor`` is
        int64 (H, W), ``spatial_weights`` is float32 (H, W) and ``dist_map`` is
        float32 (2, H, W).
    """

    def __init__(self, hf_dataset, split="train", transform=None):
        self.transform = transform
        self.split = split
        self.clahe = cv2.createCLAHE(clipLimit=1.5, tileGridSize=(8, 8))
        # Gamma-correction lookup table (exponent 0.8 lifts mid-tones).
        self.gamma_table = 255.0 * (np.linspace(0, 1, 256) ** 0.8)
        self.dataset = hf_dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        image = np.array(sample["image"], dtype=np.uint8).reshape(IMAGE_HEIGHT, IMAGE_WIDTH)
        label = np.array(sample["label"], dtype=np.uint8).reshape(IMAGE_HEIGHT, IMAGE_WIDTH)
        spatial_weights = np.array(sample["spatial_weights"], dtype=np.float32).reshape(IMAGE_HEIGHT, IMAGE_WIDTH)
        dist_map = np.array(sample["dist_map"], dtype=np.float32).reshape(2, IMAGE_HEIGHT, IMAGE_WIDTH)
        filename = sample["filename"]

        augment = self.transform is not None and self.split == "train"

        pilimg = cv2.LUT(image, self.gamma_table)
        if augment:
            if random.random() < 0.2:
                pilimg = Line_augment()(np.array(pilimg))
            if random.random() < 0.2:
                pilimg = Gaussian_blur()(np.array(pilimg))
        img = self.clahe.apply(np.array(np.uint8(pilimg)))

        # Decide the flip once and apply it to every per-pixel array so they stay aligned.
        # (Inferring "did it flip?" from label[0, 0] misses flips whenever both top corners
        # share a class, which left the weights and distance maps unflipped.)
        # .copy() gives positive strides, which PIL and torch's default collate need.
        if augment and random.random() < 0.5:
            img = np.fliplr(img).copy()
            label = np.fliplr(label).copy()
            spatial_weights = np.fliplr(spatial_weights).copy()
            dist_map = np.flip(dist_map, axis=2).copy()

        img = Image.fromarray(img)
        label_pil = Image.fromarray(label)
        if self.transform is not None:
            img = self.transform(img)

        label_tensor = MaskToTensor()(label_pil)
        return img, label_tensor, filename, spatial_weights, dist_map

## 6. Visualization Utilities

In [None]:
def plot_training_curves(train_metrics, valid_metrics):
    """Draw a 2x2 dashboard of per-epoch training history.

    Panels: train/valid loss, train/valid mIoU, learning rate (only if ``"lr"`` is in
    ``train_metrics``), and per-class validation IoU (only if ``"background_iou"`` is
    in ``valid_metrics``). Both arguments map metric names to per-epoch lists; the
    epoch axis length comes from ``train_metrics["loss"]``.
    """
    epochs = range(1, len(train_metrics["loss"]) + 1)
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    def _decorate(ax, ylabel, title, with_legend=True, log_scale=False):
        ax.set_xlabel("Epoch")
        ax.set_ylabel(ylabel)
        ax.set_title(title, fontweight="bold")
        if log_scale:
            ax.set_yscale("log")
        if with_legend:
            ax.legend()
        ax.grid(True, alpha=0.3)

    # Train vs. validation comparisons share the same layout.
    for ax, key, ylabel, title in (
        (axes[0, 0], "loss", "Loss", "Training and Validation Loss"),
        (axes[0, 1], "iou", "mIoU", "Training and Validation mIoU"),
    ):
        ax.plot(epochs, train_metrics[key], "b-", label="Train", linewidth=2)
        ax.plot(epochs, valid_metrics[key], "r-", label="Valid", linewidth=2)
        _decorate(ax, ylabel, title)

    if "lr" in train_metrics:
        axes[1, 0].plot(epochs, train_metrics["lr"], "g-", linewidth=2)
        _decorate(axes[1, 0], "Learning Rate", "Learning Rate Schedule", with_legend=False, log_scale=True)

    if "background_iou" in valid_metrics:
        axes[1, 1].plot(epochs, valid_metrics["background_iou"], "b-", label="Background", linewidth=2)
        axes[1, 1].plot(epochs, valid_metrics["pupil_iou"], "r-", label="Pupil", linewidth=2)
        _decorate(axes[1, 1], "IoU", "Per-Class IoU (Validation)")

    plt.tight_layout()
    plt.show()


def visualize_predictions(model, dataloader, device, num_samples=4):
    """Plot input, ground truth and prediction for up to ``num_samples`` images.

    Uses the first image of each of the first ``num_samples`` batches. Puts ``model``
    in eval mode. If the loader runs out of batches early, only the samples actually
    collected are plotted.

    Raises:
        ValueError: if ``num_samples < 1`` or the loader yields no batches.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    model.eval()
    images_to_plot = []
    labels_to_plot = []
    preds_to_plot = []

    with torch.no_grad():
        for img, labels, _, _, _ in dataloader:
            single_img = img[0:1].to(device, memory_format=torch.channels_last)
            predictions = get_predictions(model(single_img))
            images_to_plot.append(img[0].cpu().squeeze().numpy())
            labels_to_plot.append(labels[0].long().cpu().numpy())
            preds_to_plot.append(predictions[0].cpu().numpy())
            # Stop as soon as enough samples are collected, so no extra batch is fetched.
            if len(images_to_plot) >= num_samples:
                break

    n_rows = len(images_to_plot)
    if n_rows == 0:
        raise ValueError("dataloader yielded no batches to visualize")

    # squeeze=False keeps `axes` 2-D even for a single row.
    fig, axes = plt.subplots(n_rows, 3, figsize=(12, 4 * n_rows), squeeze=False)
    for row_axes, image, label, pred in zip(axes, images_to_plot, labels_to_plot, preds_to_plot):
        row_axes[0].imshow(image, cmap="gray")
        row_axes[0].set_title("Input Image", fontweight="bold")
        row_axes[1].imshow(label, cmap="jet", vmin=0, vmax=1)
        row_axes[1].set_title("Ground Truth", fontweight="bold")
        row_axes[2].imshow(pred, cmap="jet", vmin=0, vmax=1)
        row_axes[2].set_title("Prediction", fontweight="bold")
        for ax in row_axes:
            ax.axis("off")

    plt.tight_layout()
    plt.show()

## 7. Training Configuration

In [None]:
# Training hyperparameters
BATCH_SIZE = 32  # Reduced for Colab memory constraints
EPOCHS = 50
LEARNING_RATE = 1e-2
NUM_WORKERS = 2  # Colab has limited workers
PARAM_BUDGET = 60000  # trainable-parameter limit for edge deployment

# Model configuration (the parameter budget is checked right after construction)
MODEL_CONFIG = dict(
    in_channels=1,
    num_classes=2,
    embed_dims=(16, 32, 64),
    depths=(1, 1, 1),
    num_heads=(1, 1, 2),
    key_dims=(4, 4, 4),
    attn_ratios=(2, 2, 2),
    window_sizes=(7, 7, 7),
    mlp_ratios=(2, 2, 2),
    decoder_dim=32,
)
model = TinyEfficientViTSeg(**MODEL_CONFIG).to(device)

nparams = get_nparams(model)
print(f"Model Parameters: {nparams:,}")
excess_params = nparams - PARAM_BUDGET
if excess_params >= 0:
    print(f"WARNING: Model has {nparams} parameters, exceeds 60k limit by {excess_params}")
else:
    print(f"Model is within 60k parameter budget: {nparams} < {PARAM_BUDGET}")

# channels_last memory format for faster GPU convolutions
if torch.cuda.is_available():
    model = model.to(memory_format=torch.channels_last)

# Mixed precision is enabled whenever a GPU is available
use_amp = torch.cuda.is_available()
print(f"Mixed Precision (AMP): {use_amp}")

In [None]:
# Verify forward pass: push one full-size batch through the model before training, so
# shape or memory problems surface here rather than mid-epoch.
num_output_classes = model.decoder.head.out_channels
expected_shape = (BATCH_SIZE, num_output_classes, IMAGE_HEIGHT, IMAGE_WIDTH)

print(f"Verifying forward pass with batch_size={BATCH_SIZE}...")
with torch.no_grad():
    test_input = torch.randn(BATCH_SIZE, 1, IMAGE_HEIGHT, IMAGE_WIDTH).to(device, memory_format=torch.channels_last)
    with torch.amp.autocast("cuda", enabled=use_amp):
        test_output = model(test_input)
    print(f"Input shape: {test_input.shape}")
    print(f"Output shape: {test_output.shape}")
    assert tuple(test_output.shape) == expected_shape, (
        f"Output shape mismatch! expected {expected_shape}, got {tuple(test_output.shape)}"
    )
    print("Forward pass verification: PASSED")

# Drop the test tensors so they do not keep device memory alive during training.
del test_input, test_output

In [None]:
# Optimizer, LR scheduler (lowers LR when validation loss plateaus), loss, AMP grad scaler
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=5)
criterion = CombinedLoss()
scaler = torch.amp.GradScaler("cuda") if use_amp else None

# Image transform: PIL -> float tensor in [0, 1] -> normalized to [-1, 1]
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

# Datasets
train_dataset = IrisDataset(hf_dataset["train"], split="train", transform=transform)
valid_dataset = IrisDataset(hf_dataset["validation"], split="validation", transform=transform)
print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(valid_dataset)}")

# DataLoaders share batch size / workers / pinning; only training shuffles and drops the last partial batch.
shared_loader_args = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)
trainloader = DataLoader(train_dataset, shuffle=True, drop_last=True, **shared_loader_args)
validloader = DataLoader(valid_dataset, shuffle=False, **shared_loader_args)

# Alpha weights Dice (alpha) against surface loss (1 - alpha). It decreases linearly from
# 1 - 1/R to 0 over the first R = min(125, EPOCHS) epochs, and any epoch past 125 uses 1.
ramp_epochs = min(125, EPOCHS)
alpha = np.zeros(EPOCHS)
alpha[:ramp_epochs] = 1 - np.arange(1, ramp_epochs + 1) / ramp_epochs
alpha[125:] = 1  # empty slice (no-op) when EPOCHS <= 125

divider = "=" * 60
print(f"\n{divider}")
print("Training Configuration:")
print(divider)
for config_line in (
    "  Model: TinyEfficientViT",
    f"  Parameters: {nparams:,}",
    f"  Batch Size: {BATCH_SIZE}",
    f"  Epochs: {EPOCHS}",
    f"  Learning Rate: {LEARNING_RATE}",
    f"  Mixed Precision (AMP): {use_amp}",
):
    print(config_line)
print(divider)

## 8. Training Loop

In [None]:
# Per-epoch metric history, consumed by the plotting / summary / checkpoint cells below
train_metrics = {
    "loss": [], "iou": [], "ce_loss": [], "dice_loss": [], "surface_loss": [],
    "alpha": [], "lr": [], "background_iou": [], "pupil_iou": [],
}
valid_metrics = {
    "loss": [], "iou": [], "ce_loss": [], "dice_loss": [], "surface_loss": [],
    "background_iou": [], "pupil_iou": [],
}
best_valid_iou = 0.0
best_epoch = 0


def run_epoch(loader, epoch_alpha, train, desc=None):
    """Run one pass of the global `model` over `loader` and return averaged metrics.

    With ``train=True`` the model is put in train mode and the optimizer is stepped
    (through the AMP GradScaler when ``use_amp``). Otherwise the model runs in eval
    mode with gradients disabled. A non-None ``desc`` shows a tqdm progress bar.

    Returns a dict with the batch-mean ``loss``, ``ce_loss``, ``dice_loss`` and
    ``surface_loss``, plus ``iou``, ``background_iou`` and ``pupil_iou`` computed
    from intersection/union counts accumulated over the whole pass.

    Raises:
        RuntimeError: if the loader yields no batches.
    """
    model.train(train)
    loss_sums = torch.zeros(4, device=device)  # total, ce, dice, surface
    intersection = torch.zeros(2, device=device)
    union = torch.zeros(2, device=device)
    batch_count = 0

    batches = tqdm(loader, desc=desc) if desc is not None else loader
    with torch.set_grad_enabled(train):
        for img, labels, _, spatial_weights, dist_map in batches:
            data = img.to(device, non_blocking=True, memory_format=torch.channels_last)
            target = labels.to(device, non_blocking=True).long()
            spatial_weights_gpu = spatial_weights.to(device, non_blocking=True).float()
            dist_map_gpu = dist_map.to(device, non_blocking=True)

            if train:
                optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast("cuda", enabled=use_amp):
                output = model(data)
                losses = criterion(output, target, spatial_weights_gpu, dist_map_gpu, epoch_alpha)
            if train:
                if use_amp:
                    scaler.scale(losses[0]).backward()
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    losses[0].backward()
                    optimizer.step()

            loss_sums += torch.stack([term.detach().float() for term in losses])
            batch_count += 1

            batch_intersection, batch_union = compute_iou_tensors(get_predictions(output), target)
            intersection += batch_intersection
            union += batch_union

    if batch_count == 0:
        raise RuntimeError("DataLoader yielded no batches (dataset smaller than BATCH_SIZE with drop_last=True?)")

    mean_loss, mean_ce, mean_dice, mean_surface = (loss_sums / batch_count).tolist()
    miou, per_class_ious = finalize_iou(intersection, union)
    return {
        "loss": mean_loss,
        "ce_loss": mean_ce,
        "dice_loss": mean_dice,
        "surface_loss": mean_surface,
        "iou": miou,
        "background_iou": per_class_ious[0],
        "pupil_iou": per_class_ious[1],
    }


print("\n" + "=" * 60)
print("Starting training")
print("=" * 60)

for epoch in range(EPOCHS):
    train_stats = run_epoch(trainloader, alpha[epoch], train=True, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for key, value in train_stats.items():
        train_metrics[key].append(value)
    train_metrics["alpha"].append(alpha[epoch])
    train_metrics["lr"].append(optimizer.param_groups[0]["lr"])

    valid_stats = run_epoch(validloader, alpha[epoch], train=False)
    for key, value in valid_stats.items():
        valid_metrics[key].append(value)
    loss_valid, miou_valid = valid_stats["loss"], valid_stats["iou"]

    # ReduceLROnPlateau tracks validation loss; note that alpha changes the loss mix every epoch.
    scheduler.step(loss_valid)

    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    print(f"Train Loss: {train_stats['loss']:.4f} | Valid Loss: {loss_valid:.4f}")
    print(f"Train mIoU: {train_stats['iou']:.4f} | Valid mIoU: {miou_valid:.4f}")
    print(f"Valid BG IoU: {valid_stats['background_iou']:.4f} | Valid Pupil IoU: {valid_stats['pupil_iou']:.4f}")
    print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")

    # The "best" checkpoint is chosen by validation mIoU, not by loss.
    if miou_valid > best_valid_iou:
        best_valid_iou = miou_valid
        best_epoch = epoch + 1
        torch.save(model.state_dict(), "best_efficientvit_model.pt")
        print(f"New best model! Valid mIoU: {best_valid_iou:.4f}")

print("\n" + "=" * 60)
print("Training completed!")
print(f"Final validation mIoU: {miou_valid:.4f}")
print(f"Best validation mIoU: {best_valid_iou:.4f} (epoch {best_epoch})")
print("=" * 60)

## 9. Results Visualization

In [None]:
# Loss, mIoU, learning-rate and per-class IoU history across all epochs
plot_training_curves(train_metrics=train_metrics, valid_metrics=valid_metrics)

In [None]:
# Reload the best-mIoU weights saved during training, then plot some validation predictions.
# weights_only=True limits unpickling to tensors/containers; map_location puts tensors on `device`.
best_state_dict = torch.load("best_efficientvit_model.pt", map_location=device, weights_only=True)
model.load_state_dict(best_state_dict)
visualize_predictions(model, validloader, device, num_samples=4)

In [None]:
# Final metrics summary
rule = "=" * 60
summary_lines = [
    "\n" + rule,
    "FINAL METRICS SUMMARY",
    rule,
    f"Model Parameters: {nparams:,}",
    f"Parameter Budget: {'PASSED' if nparams < 60000 else 'EXCEEDED'} (<60k)",
    "",
    f"Best Validation mIoU: {best_valid_iou:.4f} (epoch {best_epoch})",
    f"Final Validation mIoU: {valid_metrics['iou'][-1]:.4f}",
    f"Final Train mIoU: {train_metrics['iou'][-1]:.4f}",
    "",
    f"Final Validation Loss: {valid_metrics['loss'][-1]:.4f}",
    f"Final Train Loss: {train_metrics['loss'][-1]:.4f}",
    rule,
]
for summary_line in summary_lines:
    print(summary_line)

## 10. Export Model

In [None]:
# Save a training checkpoint.
# NOTE: if the visualization cell above was run, `model` holds the best-mIoU weights
# (from `best_epoch`), while 'optimizer_state_dict' is the optimizer state after the
# final epoch. 'best_epoch' is recorded so the two can be told apart.
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': EPOCHS,
    'best_epoch': best_epoch,
    'best_valid_iou': best_valid_iou,
    'train_metrics': train_metrics,
    'valid_metrics': valid_metrics,
    'nparams': nparams,
}, 'efficientvit_checkpoint.pt')

print("Checkpoint saved: efficientvit_checkpoint.pt")

In [None]:
# Export to ONNX from a CPU copy, so the training `model` keeps its device and
# memory format (nn.Module.to() moves a module in place).
import copy

model_cpu = copy.deepcopy(model).to("cpu").to(memory_format=torch.contiguous_format).eval()

dummy_input = torch.randn(1, 1, IMAGE_HEIGHT, IMAGE_WIDTH)

torch.onnx.export(
    model_cpu,
    dummy_input,
    "efficientvit_model.onnx",
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
)

print("ONNX model exported: efficientvit_model.onnx")
print(f"ONNX file size: {os.path.getsize('efficientvit_model.onnx') / 1024:.2f} KB")

In [None]:
# Trigger browser downloads of the trained artifacts (Colab only)
from google.colab import files

print("Downloading model files...")
for artifact_path in ('best_efficientvit_model.pt', 'efficientvit_checkpoint.pt', 'efficientvit_model.onnx'):
    files.download(artifact_path)

## Optional: Save to Google Drive

In [None]:
# Optionally mount Google Drive and copy the trained artifacts into it
from google.colab import drive
drive.mount('/content/drive')

import shutil
save_dir = '/content/drive/MyDrive/efficientvit_models'
os.makedirs(save_dir, exist_ok=True)

for artifact_path in ('best_efficientvit_model.pt', 'efficientvit_checkpoint.pt', 'efficientvit_model.onnx'):
    shutil.copy(artifact_path, save_dir)

print(f"Models saved to Google Drive: {save_dir}")