# Ellipse Regression Training - Google Colab

This notebook trains an ellipse regression model for pupil detection on the OpenEDS dataset.
The model predicts ellipse parameters (center x, center y, radius x, radius y) instead of full semantic segmentation.

**Requirements:**
- GPU runtime (T4 or better recommended)
- ~8GB GPU memory

**Dataset:** [Conner/openeds-precomputed](https://huggingface.co/datasets/Conner/openeds-precomputed)

## 1. Setup and GPU Check

In [None]:
# Check GPU availability
import subprocess

# nvidia-smi can be absent (FileNotFoundError) or present but exit with a
# non-zero status; in both cases warn instead of only echoing its output.
try:
    result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)
except FileNotFoundError:
    result = None

if result is not None and result.returncode == 0:
    print(result.stdout)
else:
    if result is not None:
        # Show whatever diagnostic the tool produced before the warning.
        print(result.stdout or result.stderr)
    print("WARNING: No GPU detected!")
    print("Go to Runtime -> Change runtime type -> Hardware accelerator -> GPU")

import torch
print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

In [None]:
# Install dependencies (run once)
!pip install -q torch torchvision opencv-python-headless datasets huggingface_hub tqdm matplotlib pillow scikit-learn

In [None]:
# Imports
import os
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import cv2
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from datasets import load_dataset, load_from_disk

# Seed the Python, NumPy and PyTorch generators with the same value (12)
# so that repeated runs draw the same random numbers.
random.seed(12)
np.random.seed(12)
torch.manual_seed(12)

print("Imports complete!")

## 2. Constants and Configuration

In [None]:
# Dataset and image configuration
HF_DATASET_REPO = "Conner/openeds-precomputed"
IMAGE_HEIGHT, IMAGE_WIDTH = 400, 640

# Radii are normalized by half of the image diagonal
MAX_RADIUS = math.sqrt(IMAGE_WIDTH**2 + IMAGE_HEIGHT**2) / 2

# Training hyperparameters
BATCH_SIZE = 16  # Reduced for Colab free tier (increase to 32 if using Colab Pro)
EPOCHS = 15
LEARNING_RATE = 1e-3
NUM_WORKERS = 2  # Colab has limited CPU cores

# Device configuration: prefer CUDA, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Performance optimizations
USE_AMP = torch.cuda.is_available()  # Automatic mixed precision only makes sense on CUDA
USE_CHANNELS_LAST = True

if torch.cuda.is_available():
    torch.cuda.manual_seed(12)
    # Allow TF32 kernels and let cuDNN auto-tune convolution algorithms.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True

## 3. Load Dataset from HuggingFace

In [None]:
# Pick the dataset cache location: a mounted Google Drive survives runtime
# disconnects, the local /content directory does not.
DRIVE_CACHE = "/content/drive/MyDrive/openeds_cache"
USE_DRIVE_CACHE = os.path.exists("/content/drive")

CACHE_DIR = DRIVE_CACHE if USE_DRIVE_CACHE else "/content/dataset_cache"
# Marker file written only after the dataset has been saved completely
CACHE_MARKER = os.path.join(CACHE_DIR, ".cache_complete")

if not USE_DRIVE_CACHE:
    print("No Google Drive mounted. Cache will be lost on runtime disconnect.")
    print("To enable persistent cache, run: from google.colab import drive; drive.mount('/content/drive')")
else:
    print("Google Drive detected! Using Drive for persistent cache.")

In [None]:
# Load dataset (with caching)
hf_dataset = None

# The marker file signals that a previous save_to_disk finished.
if os.path.exists(CACHE_MARKER):
    print(f"Loading cached dataset from: {CACHE_DIR}")
    try:
        hf_dataset = load_from_disk(CACHE_DIR)
        print("Loaded from cache!")
    except Exception as e:
        # Best effort: an unreadable cache is discarded and rebuilt below.
        print(f"Cache corrupted, re-downloading: {e}")
        import shutil
        shutil.rmtree(CACHE_DIR, ignore_errors=True)
        hf_dataset = None

if hf_dataset is None:
    print(f"Downloading from HuggingFace: {HF_DATASET_REPO}")
    print("This may take 10-20 minutes on first run...")
    hf_dataset = load_dataset(HF_DATASET_REPO)

    # Persist the download, then drop the marker as the very last step.
    os.makedirs(CACHE_DIR, exist_ok=True)
    hf_dataset.save_to_disk(CACHE_DIR)
    with open(CACHE_MARKER, "w") as marker_file:
        marker_file.write(f"Cached from {HF_DATASET_REPO}\n")
    print("Dataset cached!")

print(f"\nTrain samples: {len(hf_dataset['train'])}")
print(f"Validation samples: {len(hf_dataset['validation'])}")

In [None]:
# Show the first four training samples: images on the top row, labels below
fig, axes = plt.subplots(2, 4, figsize=(16, 8))

for i in range(4):
    sample = hf_dataset['train'][i]
    # Samples are reshaped to (IMAGE_HEIGHT, IMAGE_WIDTH) before display.
    image = np.array(sample['image'], dtype=np.uint8).reshape(IMAGE_HEIGHT, IMAGE_WIDTH)
    label = np.array(sample['label'], dtype=np.uint8).reshape(IMAGE_HEIGHT, IMAGE_WIDTH)

    image_ax, label_ax = axes[0, i], axes[1, i]

    image_ax.imshow(image, cmap='gray')
    image_ax.set_title(f"Image {i}")
    image_ax.axis('off')

    label_ax.imshow(label, cmap='jet', vmin=0, vmax=1)
    label_ax.set_title(f"Label {i}")
    label_ax.axis('off')

plt.tight_layout()
plt.show()
print("Sample images from the training set")

## 4. Ellipse Parameter Utilities

In [None]:
def _moment_circle_params(contour):
    """Approximate a contour by a circle: centroid from moments, radius from area."""
    M = cv2.moments(contour)
    if M["m00"] > 0:
        cx = M["m10"] / M["m00"]
        cy = M["m01"] / M["m00"]
        # Radius of the circle whose area equals the contour area.
        radius = math.sqrt(cv2.contourArea(contour) / math.pi)
        return cx, cy, radius, radius, 0.0
    return 0.0, 0.0, 0.0, 0.0, 0.0


def extract_ellipse_params(mask):
    """
    Extract ellipse parameters from a binary mask.

    The largest external contour of the mask is fitted with cv2.fitEllipse.
    When the contour has fewer than 5 points, the fit raises cv2.error, or the
    fit yields non-finite values, an equal-area circle centred on the contour
    centroid is returned instead.

    Returns: (cx, cy, rx, ry, angle) where:
        - cx, cy: center coordinates (pixels)
        - rx, ry: semi-axes lengths (pixels), i.e. half of the fitted
          (width, height)
        - angle: rotation angle in degrees as reported by cv2.fitEllipse
    If no valid contour found, returns zeros.

    NOTE(review): rx/ry are measured along the fitted ellipse's own (rotated)
    axes. Callers that discard `angle` and treat them as axis-aligned radii
    should confirm that this is acceptable.
    """
    contours, _ = cv2.findContours(
        mask.astype(np.uint8),
        cv2.RETR_EXTERNAL,
        cv2.CHAIN_APPROX_SIMPLE,
    )

    if len(contours) == 0:
        return 0.0, 0.0, 0.0, 0.0, 0.0

    largest_contour = max(contours, key=cv2.contourArea)

    # Contours with fewer than 5 points skip the ellipse fit entirely.
    if len(largest_contour) >= 5:
        try:
            (cx, cy), (width, height), angle = cv2.fitEllipse(largest_contour)
        except cv2.error:
            pass
        else:
            params = (cx, cy, width / 2.0, height / 2.0, angle)
            # Reject NaN/inf fits so they never become regression targets.
            if all(math.isfinite(value) for value in params):
                return params

    return _moment_circle_params(largest_contour)


def normalize_ellipse_params(cx, cy, rx, ry):
    """
    Scale pixel-space ellipse parameters to normalized units.

    The center is divided by the image width/height and both radii by
    MAX_RADIUS. Returns (cx_norm, cy_norm, rx_norm, ry_norm).
    """
    return (
        cx / IMAGE_WIDTH,
        cy / IMAGE_HEIGHT,
        rx / MAX_RADIUS,
        ry / MAX_RADIUS,
    )


def denormalize_ellipse_params(cx_norm, cy_norm, rx_norm, ry_norm):
    """
    Map normalized ellipse parameters back to pixel values.

    Inverse of normalize_ellipse_params: the center is multiplied by the image
    width/height and both radii by MAX_RADIUS. Returns (cx, cy, rx, ry).
    """
    return (
        cx_norm * IMAGE_WIDTH,
        cy_norm * IMAGE_HEIGHT,
        rx_norm * MAX_RADIUS,
        ry_norm * MAX_RADIUS,
    )


def render_ellipse_mask(cx, cy, rx, ry, height=IMAGE_HEIGHT, width=IMAGE_WIDTH):
    """
    Rasterize an axis-aligned filled ellipse into a (height, width) uint8 mask.

    Pixels inside the ellipse are 1, all others 0. Center and radii are rounded
    to the nearest integer pixel. If either radius is not positive the mask is
    returned empty.
    """
    mask = np.zeros((height, width), dtype=np.uint8)
    if not (rx > 0 and ry > 0):
        return mask

    center = (int(round(cx)), int(round(cy)))
    axes = (int(round(rx)), int(round(ry)))
    # thickness=-1 fills the shape; the full 0..360 arc draws a closed ellipse.
    cv2.ellipse(
        mask,
        center=center,
        axes=axes,
        angle=0,
        startAngle=0,
        endAngle=360,
        color=1,
        thickness=-1,
    )
    return mask

print("Ellipse utilities defined!")

## 5. Loss Function and Metrics

In [None]:
class EllipseRegressionLoss(nn.Module):
    """
    Weighted sum of regression terms on normalized ellipse parameters:
    - Smooth L1 on the center (cx, cy)
    - Smooth L1 on the radii (rx, ry)
    - Optional "IoU proxy" term: the mean squared error over all four
      parameters (no ellipse is rendered and no actual IoU is computed)
    """

    def __init__(self, center_weight=1.0, radius_weight=1.0, iou_weight=0.5):
        super().__init__()
        self.center_weight = center_weight
        self.radius_weight = radius_weight
        self.iou_weight = iou_weight
        self.smooth_l1 = nn.SmoothL1Loss(reduction="mean")

    def forward(self, pred, target, compute_iou=True):
        """
        pred: (B, 4) - cx, cy, rx, ry (normalized)
        target: (B, 4) - cx, cy, rx, ry (normalized)
        compute_iou: when False the proxy term is left out of the total

        Returns (total_loss, center_loss, radius_loss).
        """
        pred_center, pred_radii = pred[:, :2], pred[:, 2:]
        target_center, target_radii = target[:, :2], target[:, 2:]

        center_loss = self.smooth_l1(pred_center, target_center)
        radius_loss = self.smooth_l1(pred_radii, target_radii)
        total_loss = self.center_weight * center_loss + self.radius_weight * radius_loss

        if compute_iou and self.iou_weight > 0:
            # Per-sample MSE over the four parameters, then averaged over the batch.
            per_sample_sq_err = (pred - target).pow(2).mean(dim=1)
            total_loss = total_loss + self.iou_weight * per_sample_sq_err.mean()

        return total_loss, center_loss, radius_loss


def compute_center_error(pred, target):
    """
    Compute mean center error in pixels.

    pred, target: (B, >=2) tensors whose first two columns are the normalized
    center (cx, cy). Returns the batch mean of the Euclidean distance between
    predicted and target centers as a Python float.
    """
    # Upcast first: half-precision inputs would otherwise be scaled to pixel
    # units in fp16, quantizing the reported error.
    pred = pred.float()
    target = target.float()
    dx = (pred[:, 0] - target[:, 0]) * IMAGE_WIDTH
    dy = (pred[:, 1] - target[:, 1]) * IMAGE_HEIGHT
    return torch.sqrt(dx ** 2 + dy ** 2).mean().item()


def compute_radius_error(pred, target):
    """
    Compute mean radius error in pixels.

    pred, target: (B, 4) tensors whose columns 2 and 3 are the normalized radii
    (rx, ry). Returns the batch mean of the average absolute error of the two
    radii as a Python float.
    """
    # Upcast first so half-precision inputs are not scaled to pixels in fp16.
    radius_diff = (pred[:, 2:4].float() - target[:, 2:4].float()) * MAX_RADIUS
    rx_error = torch.abs(radius_diff[:, 0])
    ry_error = torch.abs(radius_diff[:, 1])
    return ((rx_error + ry_error) / 2).mean().item()


def _binary_iou(pred_region, target_region):
    """IoU of two boolean arrays; an empty union yields 0 (denominator clamped to 1)."""
    intersection = np.logical_and(pred_region, target_region).sum()
    union = np.logical_or(pred_region, target_region).sum()
    return intersection / max(union, 1)


def compute_iou_with_gt_mask(pred, gt_masks, device):
    """
    Compute IoU between predicted ellipses and ground truth masks.

    pred: (B, 4) normalized (cx, cy, rx, ry) predictions
    gt_masks: (B, H, W) ground truth binary masks
    device: unused; kept so existing callers keep working

    Each prediction is denormalized and rendered as an axis-aligned ellipse
    mask, then compared with the ground truth for class 1 (pupil) and class 0
    (background).

    Returns (mean_iou, mean_bg_iou, mean_pupil_iou), where mean_iou is the
    average of the two per-class batch means.
    """
    # Upcast before leaving torch so half-precision predictions are
    # denormalized in float32 rather than float16.
    pred_np = pred.detach().float().cpu().numpy()
    gt_masks_np = gt_masks.cpu().numpy()

    batch_size = pred_np.shape[0]
    ious_bg = []
    ious_pupil = []

    for i in range(batch_size):
        pred_cx, pred_cy, pred_rx, pred_ry = denormalize_ellipse_params(
            pred_np[i, 0], pred_np[i, 1], pred_np[i, 2], pred_np[i, 3]
        )

        pred_mask = render_ellipse_mask(pred_cx, pred_cy, pred_rx, pred_ry)
        target_mask = gt_masks_np[i]

        # Pupil IoU (class 1)
        ious_pupil.append(_binary_iou(pred_mask == 1, target_mask == 1))

        # Background IoU (class 0)
        ious_bg.append(_binary_iou(pred_mask == 0, target_mask == 0))

    mean_bg_iou = np.mean(ious_bg)
    mean_pupil_iou = np.mean(ious_pupil)
    mean_iou = (mean_bg_iou + mean_pupil_iou) / 2

    return mean_iou, mean_bg_iou, mean_pupil_iou


print("Loss and metrics defined!")

## 6. Model Architecture

In [None]:
class DownBlock(nn.Module):
    """
    Densely connected encoder block built from depthwise-separable convolutions.

    Optionally average-pools the input by `down_size`, then runs three
    depthwise+pointwise stages. The input of the second and third stage is the
    channel-wise concatenation of the block input with the earlier stage
    outputs, reduced by a 1x1 convolution. The result is batch-normalized.
    """

    def __init__(self, input_channels, output_channels, down_size, dropout=False, prob=0):
        super().__init__()
        cin, cout = input_channels, output_channels

        # Stage 1: depthwise 3x3 followed by pointwise 1x1
        self.depthwise_conv1 = nn.Conv2d(cin, cin, kernel_size=3, padding=1, groups=cin)
        self.pointwise_conv1 = nn.Conv2d(cin, cout, kernel_size=1)

        # Stage 2: 1x1 reduction of [input, stage 1], then depthwise + pointwise
        self.conv21 = nn.Conv2d(cin + cout, cout, kernel_size=1, padding=0)
        self.depthwise_conv22 = nn.Conv2d(cout, cout, kernel_size=3, padding=1, groups=cout)
        self.pointwise_conv22 = nn.Conv2d(cout, cout, kernel_size=1)

        # Stage 3: 1x1 reduction of [input, stage 1, stage 2], then depthwise + pointwise
        self.conv31 = nn.Conv2d(cin + 2 * cout, cout, kernel_size=1, padding=0)
        self.depthwise_conv32 = nn.Conv2d(cout, cout, kernel_size=3, padding=1, groups=cout)
        self.pointwise_conv32 = nn.Conv2d(cout, cout, kernel_size=1)

        # Despite the attribute name, downsampling uses average pooling.
        self.max_pool = nn.AvgPool2d(kernel_size=down_size) if down_size else None
        self.relu = nn.LeakyReLU()
        self.down_size = down_size
        self.dropout = dropout
        self.dropout1 = nn.Dropout(p=prob)
        self.dropout2 = nn.Dropout(p=prob)
        self.dropout3 = nn.Dropout(p=prob)
        self.bn = torch.nn.BatchNorm2d(num_features=cout)

    def forward(self, x):
        if self.max_pool is not None:
            x = self.max_pool(x)

        # Dropout (when enabled) sits between each pointwise conv and its activation.
        stage1 = self.pointwise_conv1(self.depthwise_conv1(x))
        if self.dropout:
            stage1 = self.dropout1(stage1)
        stage1 = self.relu(stage1)

        dense1 = torch.cat((x, stage1), dim=1)
        stage2 = self.pointwise_conv22(self.depthwise_conv22(self.conv21(dense1)))
        if self.dropout:
            stage2 = self.dropout2(stage2)
        stage2 = self.relu(stage2)

        dense2 = torch.cat((dense1, stage2), dim=1)
        stage3 = self.pointwise_conv32(self.depthwise_conv32(self.conv31(dense2)))
        if self.dropout:
            stage3 = self.dropout3(stage3)
        out = self.relu(stage3)

        return self.bn(out)


class EllipseRegressionNet(nn.Module):
    """
    Lightweight CNN for ellipse parameter regression.

    Four DownBlock encoder stages (the last three downsample by 2), global
    average pooling, and a three-layer fully connected head ending in a
    sigmoid. Predicts (cx, cy, rx, ry) - center and semi-axes of the pupil
    ellipse - each in [0, 1].
    """

    def __init__(self, in_channels=1, channel_size=32, dropout=False, prob=0):
        super().__init__()
        drop_cfg = dict(dropout=dropout, prob=prob)
        wide = channel_size * 2

        # Encoder blocks
        self.down_block1 = DownBlock(
            input_channels=in_channels, output_channels=channel_size, down_size=None, **drop_cfg
        )
        self.down_block2 = DownBlock(
            input_channels=channel_size, output_channels=channel_size, down_size=(2, 2), **drop_cfg
        )
        self.down_block3 = DownBlock(
            input_channels=channel_size, output_channels=wide, down_size=(2, 2), **drop_cfg
        )
        self.down_block4 = DownBlock(
            input_channels=wide, output_channels=wide, down_size=(2, 2), **drop_cfg
        )

        # Collapse the spatial dimensions to a single value per channel
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # Regression head: wide -> 128 -> 64 -> 4, sigmoid keeps outputs in [0, 1]
        self.fc = nn.Sequential(
            nn.Linear(wide, 128),
            nn.LeakyReLU(),
            nn.Dropout(p=prob) if dropout else nn.Identity(),
            nn.Linear(128, 64),
            nn.LeakyReLU(),
            nn.Dropout(p=prob) if dropout else nn.Identity(),
            nn.Linear(64, 4),  # cx, cy, rx, ry
            nn.Sigmoid(),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialize conv, batch-norm and linear layers in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Pick the fan used in the He-style std sqrt(2 / fan).
                is_depthwise = m.groups == m.in_channels and m.in_channels == m.out_channels
                if is_depthwise:
                    fan = m.kernel_size[0] * m.kernel_size[1]
                elif m.kernel_size == (1, 1):
                    fan = m.in_channels
                else:
                    fan = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / fan))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x):
        """Map a (B, in_channels, H, W) batch to (B, 4) normalized ellipse parameters."""
        for block in (self.down_block1, self.down_block2, self.down_block3, self.down_block4):
            x = block(x)
        pooled = self.global_pool(x)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)


def get_nparams(model):
    """Return the total number of elements in the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


# Build the network on the target device and report its size
model = EllipseRegressionNet(in_channels=1, channel_size=32, dropout=True, prob=0.2)
model = model.to(device)

nparams = get_nparams(model)
print(f"Model parameters: {nparams:,}")
# Size estimate assumes 4 bytes per parameter
print(f"Model size: {nparams * 4 / 1024 / 1024:.2f} MB")

if USE_CHANNELS_LAST:
    model = model.to(memory_format=torch.channels_last)
    print("Model converted to channels_last memory format")

In [None]:
# Verify forward pass
print(f"Verifying forward pass with batch_size={BATCH_SIZE}...")

# Run the check in eval mode: in train mode the random test batch would update
# the BatchNorm running statistics and dropout would be active.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        test_input = torch.randn(BATCH_SIZE, 1, IMAGE_HEIGHT, IMAGE_WIDTH).to(device)
        if USE_CHANNELS_LAST:
            test_input = test_input.to(memory_format=torch.channels_last)
        with torch.amp.autocast("cuda", enabled=USE_AMP):
            test_output = model(test_input)
        print(f"Input shape: {test_input.shape}")
        print(f"Output shape: {test_output.shape}")
        assert test_output.shape == (BATCH_SIZE, 4), f"Expected ({BATCH_SIZE}, 4), got {test_output.shape}"
        print("Forward pass verification: PASSED")
finally:
    # Leave the model in whatever mode it was in before the check.
    model.train(was_training)

## 7. Data Augmentation and Dataset

In [None]:
class RandomHorizontalFlip:
    """Mirror an image/label pair left-to-right with probability 0.5."""

    def __call__(self, img, label):
        # One draw decides for both, so image and label always stay aligned.
        if random.random() >= 0.5:
            return img, label
        flipped_img = img.transpose(Image.FLIP_LEFT_RIGHT)
        flipped_label = label.transpose(Image.FLIP_LEFT_RIGHT)
        return flipped_img, flipped_label


class Gaussian_blur:
    """Blur an image with a 7x7 Gaussian kernel and a random integer sigma in [2, 6]."""

    def __call__(self, img):
        sigma_value = np.random.randint(2, 7)  # upper bound is exclusive
        blurred = cv2.GaussianBlur(np.array(img), (7, 7), sigma_value)
        return Image.fromarray(blurred)


class Line_augment:
    """
    Draw 1-9 random bright line segments (value 255, thickness 4) through a
    point in the central region of a 2-D image and return the result as a
    uint8 PIL image.

    Random numbers are drawn as scalars, in the same order as before, so the
    endpoints can be converted with int() without relying on the deprecated
    conversion of 1-element arrays to Python scalars.
    """

    # Near-vertical lines make tan(theta) explode; clamp the y endpoints so
    # they always fit in a C int when handed to cv2.line.
    _COORD_LIMIT = 2 ** 20

    @staticmethod
    def _random_sign():
        """Return +1 or -1 with equal probability."""
        return 1 if np.random.rand() < 0.5 else -1

    def __call__(self, base):
        # Common anchor point at 30%-70% of the height and width (same fraction for both).
        fraction = 0.3 + 0.4 * np.random.rand()
        yc = fraction * base.shape[0]
        xc = fraction * base.shape[1]

        aug_base = np.copy(base)
        num_lines = np.random.randint(1, 10)
        for _ in range(num_lines):
            slope = np.tan(np.pi * np.random.rand())
            # First endpoint within 50 px of the anchor (in x), second 50-200 px away.
            x1 = xc - 50 * np.random.rand() * self._random_sign()
            y1 = (x1 - xc) * slope + yc
            x2 = xc - (150 * np.random.rand() + 50) * self._random_sign()
            y2 = (x2 - xc) * slope + yc

            y1 = float(np.clip(y1, -self._COORD_LIMIT, self._COORD_LIMIT))
            y2 = float(np.clip(y2, -self._COORD_LIMIT, self._COORD_LIMIT))

            aug_base = cv2.line(
                aug_base, (int(x1), int(y1)), (int(x2), int(y2)), (255, 255, 255), 4
            )
        aug_base = aug_base.astype(np.uint8)
        return Image.fromarray(aug_base)


class MaskToTensor:
    """Convert a label image (or array-like) into an int64 torch tensor."""

    def __call__(self, img):
        # np.array always copies, so the tensor never aliases the input.
        label_array = np.array(img, dtype=np.int64)
        label_tensor = torch.from_numpy(label_array)
        return label_tensor.long()


class IrisDataset(Dataset):
    """
    Torch dataset over one split of the precomputed HuggingFace dataset.

    Each item is the tuple
        (img, label_tensor, ellipse_params, filename, spatial_weights, dist_map)
    where ellipse_params holds the normalized (cx, cy, rx, ry) fitted to the
    label mask. With split == "train" and a transform set, random line/blur
    augmentation and a random horizontal flip are applied; the flip is applied
    consistently to the image, label, spatial weights, distance map and the
    ellipse center.
    """

    def __init__(self, hf_dataset, split="train", transform=None):
        self.transform = transform
        self.split = split
        self.clahe = cv2.createCLAHE(clipLimit=1.5, tileGridSize=(8, 8))
        # Gamma (0.8) lookup table applied to the raw image before CLAHE
        self.gamma_table = 255.0 * (np.linspace(0, 1, 256) ** 0.8)
        self.dataset = hf_dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]

        image = np.array(sample["image"], dtype=np.uint8).reshape(IMAGE_HEIGHT, IMAGE_WIDTH)
        label = np.array(sample["label"], dtype=np.uint8).reshape(IMAGE_HEIGHT, IMAGE_WIDTH)
        spatial_weights = np.array(sample["spatial_weights"], dtype=np.float32).reshape(
            IMAGE_HEIGHT, IMAGE_WIDTH
        )
        dist_map = np.array(sample["dist_map"], dtype=np.float32).reshape(
            2, IMAGE_HEIGHT, IMAGE_WIDTH
        )
        filename = sample["filename"]

        # Ellipse parameters (pixels) from the un-flipped mask; the fitted angle is discarded
        cx, cy, rx, ry, _ = extract_ellipse_params(label)

        # Image preprocessing
        pilimg = cv2.LUT(image, self.gamma_table)

        if self.transform is not None and self.split == "train":
            if random.random() < 0.2:
                pilimg = Line_augment()(np.array(pilimg))
            if random.random() < 0.2:
                pilimg = Gaussian_blur()(np.array(pilimg))

        img = self.clahe.apply(np.array(np.uint8(pilimg)))
        img = Image.fromarray(img)
        label_pil = Image.fromarray(label)

        if self.transform is not None:
            # Decide the flip here instead of inferring it afterwards from a
            # single label pixel: the top-left pixel is usually background both
            # before and after mirroring, so that check left the targets
            # un-mirrored while the image was flipped.
            if self.split == "train" and random.random() < 0.5:
                img = img.transpose(Image.FLIP_LEFT_RIGHT)
                label_pil = label_pil.transpose(Image.FLIP_LEFT_RIGHT)
                spatial_weights = np.fliplr(spatial_weights).copy()
                dist_map = np.flip(dist_map, axis=2).copy()
                # Column j moves to IMAGE_WIDTH - 1 - j. The all-zero
                # "no ellipse" result is left untouched.
                if rx > 0 or ry > 0:
                    cx = (IMAGE_WIDTH - 1) - cx

            img = self.transform(img)

        ellipse_params = torch.tensor(
            normalize_ellipse_params(cx, cy, rx, ry), dtype=torch.float32
        )
        label_tensor = MaskToTensor()(label_pil)

        return (img, label_tensor, ellipse_params, filename, spatial_weights, dist_map)


print("Dataset class defined!")

In [None]:
# Create datasets and dataloaders
# ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
)

train_dataset = IrisDataset(hf_dataset["train"], split="train", transform=transform)
valid_dataset = IrisDataset(hf_dataset["validation"], split="validation", transform=transform)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(valid_dataset)}")

# Training batches are shuffled and the last incomplete batch is dropped.
trainloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
    shuffle=True,
    drop_last=True,
)

# Validation keeps dataset order and every sample.
validloader = DataLoader(
    valid_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
    shuffle=False,
)

print(f"\nTraining batches: {len(trainloader)}")
print(f"Validation batches: {len(validloader)}")

## 8. Training Loop

In [None]:
# Initialize optimizer, scheduler, and loss
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# The scheduler is stepped on a quantity to minimize, with a patience of 5 epochs
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=5)
criterion = EllipseRegressionLoss(center_weight=1.0, radius_weight=1.0, iou_weight=0.5)
# A gradient scaler is only created when mixed precision is in use
scaler = None
if USE_AMP:
    scaler = torch.amp.GradScaler("cuda")

# Per-epoch metric histories
train_metrics = {
    "loss": [],
    "iou": [],
    "center_loss": [],
    "radius_loss": [],
    "center_error": [],
    "radius_error": [],
    "lr": [],
}
valid_metrics = {
    "loss": [],
    "iou": [],
    "center_loss": [],
    "radius_loss": [],
    "center_error": [],
    "radius_error": [],
}

best_valid_iou = 0.0
best_epoch = 0

print(
    "\n".join(
        [
            "\n" + "=" * 60,
            "Training Configuration:",
            "=" * 60,
            f"  Batch Size: {BATCH_SIZE}",
            f"  Epochs: {EPOCHS}",
            f"  Learning Rate: {LEARNING_RATE}",
            f"  Mixed Precision (AMP): {USE_AMP}",
            f"  Channels Last: {USE_CHANNELS_LAST}",
            "=" * 60,
        ]
    )
)

In [None]:
# Training loop
# Each epoch: one pass over trainloader with optimizer updates, one pass over
# validloader without gradients, a scheduler step on the validation loss, and a
# checkpoint whenever the validation mIoU improves.
print("\nStarting training...")

for epoch in range(EPOCHS):
    # Training phase
    model.train()
    # Running sums of the per-batch losses (averaged over batches at epoch end)
    train_loss_sum = 0.0
    train_center_loss_sum = 0.0
    train_radius_loss_sum = 0.0
    train_batch_count = 0
    # Per-batch metric values
    train_center_errors = []
    train_radius_errors = []
    train_ious = []

    pbar = tqdm(trainloader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for batchdata in pbar:
        # filename, spatial_weights and dist_map are not used by this loop
        img, labels, ellipse_params, _, _, _ = batchdata

        data = img.to(device, non_blocking=True)
        if USE_CHANNELS_LAST:
            data = data.to(memory_format=torch.channels_last)
        target_params = ellipse_params.to(device, non_blocking=True)
        target_labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        if USE_AMP:
            # Mixed precision: forward and loss under autocast, backward through the scaler
            with torch.amp.autocast("cuda"):
                output = model(data)
                total_loss, center_loss, radius_loss = criterion(output, target_params)
            scaler.scale(total_loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            output = model(data)
            total_loss, center_loss, radius_loss = criterion(output, target_params)
            total_loss.backward()
            optimizer.step()

        train_loss_sum += total_loss.item()
        train_center_loss_sum += center_loss.item()
        train_radius_loss_sum += radius_loss.item()
        train_batch_count += 1

        # Metrics use the same forward pass that produced the gradient step
        train_center_errors.append(compute_center_error(output, target_params))
        train_radius_errors.append(compute_radius_error(output, target_params))
        miou, _, _ = compute_iou_with_gt_mask(output, target_labels, device)
        train_ious.append(miou)

        pbar.set_postfix({"loss": f"{total_loss.item():.4f}", "iou": f"{miou:.4f}"})

    # Training metrics
    loss_train = train_loss_sum / train_batch_count
    miou_train = np.mean(train_ious)
    train_metrics["loss"].append(loss_train)
    train_metrics["iou"].append(miou_train)
    train_metrics["center_loss"].append(train_center_loss_sum / train_batch_count)
    train_metrics["radius_loss"].append(train_radius_loss_sum / train_batch_count)
    train_metrics["center_error"].append(np.mean(train_center_errors))
    train_metrics["radius_error"].append(np.mean(train_radius_errors))
    # Learning rate of the first param group, recorded before this epoch's scheduler step
    train_metrics["lr"].append(optimizer.param_groups[0]["lr"])

    # Validation phase
    model.eval()
    valid_loss_sum = 0.0
    valid_center_loss_sum = 0.0
    valid_radius_loss_sum = 0.0
    valid_batch_count = 0
    valid_center_errors = []
    valid_radius_errors = []
    valid_ious = []

    with torch.no_grad():
        for batchdata in validloader:
            img, labels, ellipse_params, _, _, _ = batchdata

            data = img.to(device, non_blocking=True)
            if USE_CHANNELS_LAST:
                data = data.to(memory_format=torch.channels_last)
            target_params = ellipse_params.to(device, non_blocking=True)
            target_labels = labels.to(device, non_blocking=True)

            # NOTE(review): compute_iou=False is passed only here, while the
            # training call uses the default, so the two reported losses are
            # computed with different arguments - confirm this is intended.
            if USE_AMP:
                with torch.amp.autocast("cuda"):
                    output = model(data)
                    total_loss, center_loss, radius_loss = criterion(
                        output, target_params, compute_iou=False
                    )
            else:
                output = model(data)
                total_loss, center_loss, radius_loss = criterion(
                    output, target_params, compute_iou=False
                )

            valid_loss_sum += total_loss.item()
            valid_center_loss_sum += center_loss.item()
            valid_radius_loss_sum += radius_loss.item()
            valid_batch_count += 1

            valid_center_errors.append(compute_center_error(output, target_params))
            valid_radius_errors.append(compute_radius_error(output, target_params))
            miou, _, _ = compute_iou_with_gt_mask(output, target_labels, device)
            valid_ious.append(miou)

    # Validation metrics
    # Averages are taken over batches, so every batch gets equal weight.
    loss_valid = valid_loss_sum / valid_batch_count
    miou_valid = np.mean(valid_ious)
    valid_metrics["loss"].append(loss_valid)
    valid_metrics["iou"].append(miou_valid)
    valid_metrics["center_loss"].append(valid_center_loss_sum / valid_batch_count)
    valid_metrics["radius_loss"].append(valid_radius_loss_sum / valid_batch_count)
    valid_metrics["center_error"].append(np.mean(valid_center_errors))
    valid_metrics["radius_error"].append(np.mean(valid_radius_errors))

    # The scheduler is driven by the validation loss
    scheduler.step(loss_valid)

    # Print epoch summary
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    print(f"  Train Loss: {loss_train:.4f} | Valid Loss: {loss_valid:.4f}")
    print(f"  Train mIoU: {miou_train:.4f} | Valid mIoU: {miou_valid:.4f}")
    print(f"  Center Error: {np.mean(valid_center_errors):.2f}px | Radius Error: {np.mean(valid_radius_errors):.2f}px")
    print(f"  Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")

    # Save best model
    # Model selection uses validation mIoU, not the validation loss.
    if miou_valid > best_valid_iou:
        best_valid_iou = miou_valid
        best_epoch = epoch + 1
        torch.save(model.state_dict(), "best_ellipse_model.pt")
        print(f"  New best model saved! Valid mIoU: {best_valid_iou:.4f}")

print("\n" + "=" * 60)
print("Training Complete!")
print(f"Best Valid mIoU: {best_valid_iou:.4f} (Epoch {best_epoch})")
print("=" * 60)

## 9. Training Visualization

In [None]:
# Plot training curves
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

epochs_range = range(1, len(train_metrics["loss"]) + 1)


def _plot_metric_panel(ax, key, ylabel, title, train_label, valid_label):
    """Draw the train (blue) and validation (red) history of one metric on `ax`."""
    ax.plot(epochs_range, train_metrics[key], "b-", label=train_label, linewidth=2)
    ax.plot(epochs_range, valid_metrics[key], "r-", label=valid_label, linewidth=2)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)


# Top row: loss and mIoU; bottom row: center and radius error in pixels
_plot_metric_panel(
    axes[0, 0], "loss", "Loss", "Training and Validation Loss", "Train Loss", "Valid Loss"
)
_plot_metric_panel(
    axes[0, 1], "iou", "mIoU", "Training and Validation mIoU", "Train mIoU", "Valid mIoU"
)
_plot_metric_panel(
    axes[1, 0], "center_error", "Error (pixels)", "Center Error", "Train", "Valid"
)
_plot_metric_panel(
    axes[1, 1], "radius_error", "Error (pixels)", "Radius Error", "Train", "Valid"
)

plt.tight_layout()
plt.savefig("training_curves.png", dpi=150, bbox_inches="tight")
plt.show()

print("Training curves saved to training_curves.png")

In [None]:
# Visualize predictions on validation set
# One row per batch: the first sample of each of the first `num_samples`
# validation batches is shown as input / ground truth / rendered prediction.
model.eval()
num_samples = 4

fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4 * num_samples))

with torch.no_grad():
    for i, batch in enumerate(validloader):
        if i >= num_samples:
            break

        img, labels, ellipse_params, _, _, _ = batch

        # Keep the batch dimension so the model sees a (1, C, H, W) input
        single_img = img[0:1].to(device)
        if USE_CHANNELS_LAST:
            single_img = single_img.to(memory_format=torch.channels_last)

        with torch.amp.autocast("cuda", enabled=USE_AMP):
            output = model(single_img)

        pred_params = output[0].cpu().numpy()
        pred_cx, pred_cy, pred_rx, pred_ry = denormalize_ellipse_params(
            pred_params[0], pred_params[1], pred_params[2], pred_params[3]
        )
        pred_mask = render_ellipse_mask(pred_cx, pred_cy, pred_rx, pred_ry)

        ax_input, ax_truth, ax_pred = axes[i, 0], axes[i, 1], axes[i, 2]

        ax_input.imshow(img[0].squeeze().numpy(), cmap="gray")
        ax_input.set_title("Input Image")
        ax_input.axis("off")

        ax_truth.imshow(labels[0].numpy(), cmap="jet", vmin=0, vmax=1)
        ax_truth.set_title("Ground Truth")
        ax_truth.axis("off")

        ax_pred.imshow(pred_mask, cmap="jet", vmin=0, vmax=1)
        ax_pred.set_title(f"Prediction\ncx={pred_cx:.1f}, cy={pred_cy:.1f}\nrx={pred_rx:.1f}, ry={pred_ry:.1f}")
        ax_pred.axis("off")

plt.tight_layout()
plt.savefig("predictions.png", dpi=150, bbox_inches="tight")
plt.show()

print("Predictions saved to predictions.png")

## 10. Display Final Metrics

In [None]:
# Final metrics summary: best checkpoint first, then the last epoch's numbers
print(
    "\n".join(
        [
            "=" * 60,
            "FINAL TRAINING SUMMARY",
            "=" * 60,
            f"\nBest Model Performance (Epoch {best_epoch}):",
            f"  Validation mIoU: {best_valid_iou:.4f}",
            "\nFinal Epoch Metrics:",
            f"  Train Loss: {train_metrics['loss'][-1]:.4f}",
            f"  Valid Loss: {valid_metrics['loss'][-1]:.4f}",
            f"  Train mIoU: {train_metrics['iou'][-1]:.4f}",
            f"  Valid mIoU: {valid_metrics['iou'][-1]:.4f}",
            f"  Center Error: {valid_metrics['center_error'][-1]:.2f} pixels",
            f"  Radius Error: {valid_metrics['radius_error'][-1]:.2f} pixels",
            "\nModel saved as: best_ellipse_model.pt",
            "=" * 60,
        ]
    )
)

## 11. Export to ONNX

In [None]:
# Load best model and export to ONNX
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU runtime also loads on a CPU-only runtime.
model.load_state_dict(torch.load("best_ellipse_model.pt", map_location=device))
model.eval()

# Convert to contiguous format for export
# (Module.to returns the same module, so `model` itself is converted as well.)
model_export = model.to(memory_format=torch.contiguous_format)

# Create dummy input: a single one-channel image of the training resolution
dummy_input = torch.randn(1, 1, IMAGE_HEIGHT, IMAGE_WIDTH).to(device)

# Export to ONNX with a dynamic batch dimension on both input and output
onnx_path = "ellipse_regression_model.onnx"
torch.onnx.export(
    model_export,
    dummy_input,
    onnx_path,
    export_params=True,
    opset_version=17,
    do_constant_folding=True,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={
        "input": {0: "batch_size"},
        "output": {0: "batch_size"},
    },
)

print(f"ONNX model exported to: {onnx_path}")
print(f"File size: {os.path.getsize(onnx_path) / 1024 / 1024:.2f} MB")

In [None]:
# Verify ONNX model
# onnx / onnxruntime are optional; without them only an install hint is printed.
try:
    import onnx
    import onnxruntime as ort

    # Load and check ONNX model
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    print("ONNX model validation: PASSED")

    # Test inference with ONNX Runtime
    # The execution provider is named explicitly: some onnxruntime builds that
    # ship additional providers refuse to create a session without one, and the
    # CPU provider is sufficient for this smoke test.
    ort_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    test_input = np.random.randn(1, 1, IMAGE_HEIGHT, IMAGE_WIDTH).astype(np.float32)
    ort_output = ort_session.run(None, {"input": test_input})
    print(f"ONNX Runtime inference test: PASSED")
    print(f"Output shape: {ort_output[0].shape}")
except ImportError:
    print("Install onnx and onnxruntime to verify: pip install onnx onnxruntime")

## 12. Download Trained Model

In [None]:
# Download files (Colab only)
# Outside Colab the google.colab import fails and the file names are listed instead.
try:
    from google.colab import files
    print("Downloading trained models...")
    for artifact in (
        "best_ellipse_model.pt",
        "ellipse_regression_model.onnx",
        "training_curves.png",
        "predictions.png",
    ):
        files.download(artifact)
except ImportError:
    print(
        "\n".join(
            [
                "Not running in Google Colab.",
                "\nModel files saved locally:",
                "  - best_ellipse_model.pt",
                "  - ellipse_regression_model.onnx",
                "  - training_curves.png",
                "  - predictions.png",
            ]
        )
    )

---

## Usage Notes

### Running on Google Colab
1. Go to [Google Colab](https://colab.research.google.com/)
2. Upload this notebook or open from GitHub
3. Set runtime to GPU: `Runtime -> Change runtime type -> T4 GPU`
4. Run all cells in order

### Persistent Storage (Optional)
To save the dataset cache across sessions:
```python
from google.colab import drive
drive.mount('/content/drive')
```

### Adjusting Hyperparameters
- **Colab Free (T4)**: `BATCH_SIZE=16`, `NUM_WORKERS=2`
- **Colab Pro (A100)**: `BATCH_SIZE=32-64`, `NUM_WORKERS=4`

### Model Output
The model predicts 4 normalized values: `[cx, cy, rx, ry]`
- `cx, cy`: Ellipse center (normalized by image dimensions)
- `rx, ry`: Semi-axis lengths (normalized by `MAX_RADIUS`, half of the image diagonal)