# Quick Experimentation Framework

Fast prototyping and testing of model improvements for sensing area detection.

**Goal**: Test 15+ approaches in hours instead of days
**Strategy**: Reduced epochs + subset data + early stopping + quick metrics

In [1]:
import os, math, copy, glob, random, time
from collections import defaultdict
from pathlib import Path
import pandas as pd
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from torch.amp import autocast, GradScaler

from sklearn.decomposition import PCA
from sklearn.metrics import r2_score

import albumentations as A
import cv2
import numpy as np
import timm

def set_seed(seed=42):
    """Reproducibility setup"""
    # Python random
    random.seed(seed)
    
    # NumPy random
    np.random.seed(seed)
    
    # PyTorch random
    torch.manual_seed(seed)

# Apply seeding
set_seed(42)

  from .autonotebook import tqdm as notebook_tqdm


## Quick Experiment Configuration

In [None]:
# EXPERIMENT CONFIGURATION
QUICK_CONFIG = {
    # Training schedule
    'epochs': {
        'A': 10,  # Axis pretraining
        'B': 15,  # Intersection training
        'C': 20   # End-to-end fine-tuning
    },
    'batch_size': 16,
    'early_stop_patience': 3,  # Note: Early stopping currently DISABLED
    
    # Learning rates per stage
    'lr': {
        'A': {'backbone': 1e-4, 'axis': 3e-4, 'inter': 0.0, 'spatial': 0.0},
        'B': {'backbone': 1e-4, 'axis': 0.0, 'inter': 3e-4, 'spatial': 3e-4},
        'C': {'backbone': 1e-5, 'axis': 1e-5, 'inter': 5e-5, 'spatial': 5e-5}
    },
    
    # Loss weights per stage (for training configuration)
    'stage_loss_weights': {
        'A': {'w_inter': 0.0, 'w_axis': 1.0, 'w_t': 0.0},
        'B': {'w_inter': 1.0, 'w_axis': 0.0, 'w_t': 0.0},
        'C': {'w_inter': 5.0, 'w_axis': 1.0, 'w_t': 5.0}
    },
    
    # Global loss weights (for stereo_two_stage_loss)
    'global_loss_weights': {
        'w_origin': 1.0,
        'w_dir': 5.0,
        'w_xy': 1.0,
        'w_z': 5.0,
        'use_log_depth': True
    },
    
    # Camera intrinsics
    'camera': {
        'alpha': 3.0374e+03,
        'beta': 3.0335e+03,
        'ox': 1.0001e+03,
        'oy': 1.0744e+03
    },
    
    # Thresholds (already defined in THRESHOLDS, kept for reference)
    'decision_threshold': 5.5,
    'max_depth': 220.0,
    
    # Optimizer
    'weight_decay': 1e-5,

    # Decision thresholds for experiment progression
    # NOTE: Adaptive Stage C scheduling is currently DISABLED - all experiments run full 20 epochs
    'thresholds': {
        'stage_B_marginal': 0.10,    # Minimum R¬≤_z to attempt Stage C (DISABLED - kept for reference)
        'stage_B_promising': 0.40,   # R¬≤_z threshold for full Stage C (DISABLED - kept for reference)
        'stage_C_winner': 0.55,      # R¬≤_z for "winner" classification
        'stage_C_decent': 0.45,      # R¬≤_z for "decent" classification
        'angular_threshold': 4.0     # Angular error threshold (degrees)
    }
}


In [3]:
# STAGE CONFIGURATION
STAGE_CONFIG = {
    'A': {
        'freeze': ['offset_depth_head', 'spatial_attention', 'spatial_fusion', 'depth_map_head'],
        'unfreeze': ['axis_head', 'backbone', 'fusion'],
        'e2e': False
    },
    'B': {
        'freeze': ['axis_head'],
        'unfreeze': ['offset_depth_head', 'backbone', 'fusion', 'spatial_attention', 'spatial_fusion', 'depth_map_head'],
        'e2e': False
    },
    'C': {
        'freeze': [],
        'unfreeze': 'all',
        'e2e': True
    }
}

def configure_model_for_stage(model, stage):
    """Configure model parameters for training stage"""
    config = STAGE_CONFIG[stage]

    # Freeze components
    for component_name in config['freeze']:
        if hasattr(model, component_name):
            for p in getattr(model, component_name).parameters():
                p.requires_grad = False

    # Unfreeze components
    if config['unfreeze'] == 'all':
        for p in model.parameters():
            p.requires_grad = True
    else:
        for component_name in config['unfreeze']:
            if hasattr(model, component_name):
                for p in getattr(model, component_name).parameters():
                    p.requires_grad = True

    return config['e2e']

print("‚úÖ Stage configuration created")

‚úÖ Stage configuration created


## Data Loading

In [4]:
# Data transforms
transformations = [A.LongestMaxSize(max_size=224),
                   A.PadIfNeeded(min_height=224, min_width=224),
                   A.Normalize(),
                   A.ToTensorV2()]

transform = A.Compose(transformations,
                      seed=42,
                      keypoint_params=A.KeypointParams(format='xy', remove_invisible=False),
                      additional_targets={
                          'image_right': 'image',
                          'depth_map': 'mask'
                      })

def compute_pca_axis(points):
    """Fit PCA to 2D/3D points and return origin, direction."""
    points_mean = points.mean(axis=0)
    pca = PCA(n_components=2)
    pca.fit(points)
    direction = pca.components_[0]  # unit vector
    origin = points_mean[:2]  # take x,y as origin
    return origin, direction

def find_valid_depth(depth_map, x, y, max_search=60):
    for r in range(1, max_search+1):
        for dx in range(-r, r+1):
            for dy in range(-r, r+1):
                nx, ny = x+dx, y+dy
                if 0 <= nx < depth_map.shape[1] and 0 <= ny < depth_map.shape[0]:
                    val = depth_map[ny, nx]
                    if val > 0:
                        return val
    return 0.0

In [None]:
# ============================================================
# AUGMENTATION OPTIONS - Toggle ON/OFF
# ============================================================

# Set use_augmentation = True or False
use_augmentation = False  # Change this to test with/without augmentation

if use_augmentation:
    print("üîÑ Using AUGMENTATION")
    transformations = [
        # Geometric augmentations
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(
            shift_limit=0.05,
            scale_limit=0.1,
            rotate_limit=10,
            border_mode=0,
            p=0.5
        ),
        
        # Color augmentations
        A.OneOf([
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=1.0),
            A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=20, p=1.0),
            A.RandomGamma(gamma_limit=(80, 120), p=1.0),
        ], p=0.5),
        
        # Noise and blur
        A.OneOf([
            A.GaussNoise(var_limit=(10.0, 50.0), p=1.0),
            A.GaussianBlur(blur_limit=(3, 5), p=1.0),
            A.MotionBlur(blur_limit=5, p=1.0),
        ], p=0.3),
        
        # Base transforms (always applied)
        A.LongestMaxSize(max_size=224),
        A.PadIfNeeded(min_height=224, min_width=224),
        A.Normalize(),
        A.ToTensorV2()
    ]
else:
    print("‚ùå NO AUGMENTATION - Basic transforms only")
    transformations = [
        A.LongestMaxSize(max_size=224),
        A.PadIfNeeded(min_height=224, min_width=224),
        A.Normalize(),
        A.ToTensorV2()
    ]

# Update transform
transform = A.Compose(
    transformations,
    seed=42,
    keypoint_params=A.KeypointParams(format='xy', remove_invisible=False),
    additional_targets={
        'image_right': 'image',
        'depth_map': 'mask'
    }
)

print(f"‚úÖ Transforms configured: {len(transformations)} steps")

In [5]:
class StereoIntersectionDataset(Dataset):
    def __init__(self, root_dir, transform=None, max_depth=220.0):
        self.left_img_paths = sorted(glob.glob(os.path.join(root_dir, "left", "images", "*.jpg")))
        self.right_img_paths = sorted(glob.glob(os.path.join(root_dir, "right", "images", "*.jpg")))
        self.probe_axis_paths = sorted(glob.glob(os.path.join(root_dir, "left", "probe_axis", "*.txt")))
        self.depth_map_paths = sorted(glob.glob(os.path.join(root_dir, "left", "depth_labels", "*.npy")))
        
        # Store filenames for optional return
        self.transform = transform
        self.max_depth = max_depth

        # Infer split from root_dir (e.g., "data/processed/train" -> "train")
        self.split = os.path.basename(root_dir.rstrip("/"))

        # Load ground truth x,y from CenterPt.txt
        gt_xy = []
        with open(os.path.join(root_dir, "left", "labels", "CenterPt.txt"), 'r') as f:
            for line in f:
                _, x_str, y_str = line.strip().split(",")
                gt_xy.append((float(x_str), float(y_str)))
        self.gt_xy = np.array(gt_xy, dtype=np.float32)

    def __len__(self):
        return len(self.left_img_paths)

    def __getitem__(self, idx):
        left_img = cv2.imread(self.left_img_paths[idx])
        right_img = cv2.imread(self.right_img_paths[idx])
        left_img = cv2.cvtColor(left_img, cv2.COLOR_BGR2RGB)
        right_img = cv2.cvtColor(right_img, cv2.COLOR_BGR2RGB)
        depth_map = np.load(self.depth_map_paths[idx]) 
        points = np.loadtxt(self.probe_axis_paths[idx])

        if self.transform:
            transformed = self.transform(
                image=left_img,
                image_right=right_img,
                keypoints=[self.gt_xy[idx]] + points.tolist(),
                depth_map=depth_map
            )
            left_img = transformed["image"]
            right_img = transformed["image_right"]
            keypoints = transformed["keypoints"]
            depth_map = transformed["depth_map"]

        probe_axis_mean, direction = compute_pca_axis(np.array(keypoints[1:]))

        _, img_h, img_w = left_img.shape
        
        probe_axis_mean = np.array(probe_axis_mean, dtype=np.float32) / np.array([img_w, img_h], dtype=np.float32)
        probe_axis = torch.tensor(probe_axis_mean, dtype=torch.float32)
        probe_dir = torch.tensor(direction, dtype=torch.float32)
        
        x, y = keypoints[0]
        x_idx = int(np.clip(round(x), 0, depth_map.shape[1] - 1))
        y_idx = int(np.clip(round(y), 0, depth_map.shape[0] - 1))
        z = depth_map[y_idx, x_idx]
        if z == 0.0:
            z = find_valid_depth(depth_map, x_idx, y_idx)
        if z <= 0:
            z = 1e-6
        intersect_norm = np.array(keypoints[0]) / np.array([img_w, img_h])
        target = torch.tensor(intersect_norm.tolist() + [z/self.max_depth], dtype=torch.float32)

        batch_dict = {
            "left": left_img,
            "right": right_img,
            "gt_origin": probe_axis,
            "gt_dir": probe_dir,
            "gt_intersection": target,
            "depth_map": depth_map
        }

        return batch_dict

In [6]:
# Load full datasets
train_dataset = StereoIntersectionDataset("data/processed/train", transform=transform)
val_dataset = StereoIntersectionDataset("data/processed/val", transform=transform)
test_dataset = StereoIntersectionDataset("data/processed/test", transform=transform)

# Create generator for reproducible shuffling
g_train = torch.Generator()
g_train.manual_seed(42)

train_loader = DataLoader(
    train_dataset, 
    batch_size=QUICK_CONFIG['batch_size'], 
    shuffle=True,
    generator=g_train,  # Reproducible shuffle
    num_workers=0  # For reproducibility
)

val_loader = DataLoader(
    val_dataset, 
    batch_size=QUICK_CONFIG['batch_size'], 
    shuffle=False,
    num_workers=0  # For reproducibility
)

test_loader = DataLoader(
    test_dataset, 
    batch_size=QUICK_CONFIG['batch_size'], 
    shuffle=False
)

print(f"Data loaded successfully")
print(f"Train: {len(train_dataset)} | Val: {len(val_dataset)}")

Data loaded successfully
Train: 940 | Val: 118


## Baseline Model

In [7]:
class StereoTwoStageNet(nn.Module):
    """Baseline model - Original architecture without modifications"""
    def __init__(self, backbone_name="resnet18", pretrained=True, seed=42):
        super().__init__()
        
        self.backbone = timm.create_model(backbone_name, pretrained=pretrained, features_only=True)
        self.feature_dims = [f["num_chs"] for f in self.backbone.feature_info]

        self.proj = nn.ModuleList([nn.Conv2d(c, 128, 1) for c in self.feature_dims])

        fused_channels = 128 * len(self.feature_dims) * 2  
        self.fusion = nn.Sequential(
            nn.Conv2d(fused_channels, 256, 1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1)
        )

        self.axis_head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 4)  # (x0, y0, dx, dy)
        )

        self.offset_depth_head = nn.Sequential(
            nn.Linear(256 + 4, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 2)  # (t, z_raw)
        )

        self.softplus = nn.Softplus(beta=1.0)

    def _fused_vec(self, left_img, right_img):
        feats_l = self.backbone(left_img)
        feats_r = self.backbone(right_img)
        H, W = feats_l[-1].shape[2:]
        
        fused_scales = []
        for fl, fr, proj in zip(feats_l, feats_r, self.proj):
            fl = proj(fl)
            fr = proj(fr)
            fl = F.adaptive_avg_pool2d(fl, (H, W))
            fr = F.adaptive_avg_pool2d(fr, (H, W))
            fused_scales.append(torch.cat([fl, fr], dim=1))

        x = torch.cat(fused_scales, dim=1)
        x = self.fusion(x)
        v = x.view(x.size(0), -1)
        return v, x
    
    def forward(self, left_img, right_img):
        fused_vec, fused_4d = self._fused_vec(left_img, right_img)
        axis_params = self.axis_head(fused_4d)
        origin = axis_params[:, :2]
        direction = F.normalize(axis_params[:, 2:], dim=1)
        
        conditioned = torch.cat([fused_vec, origin.detach(), direction.detach()], dim=1)
        od = self.offset_depth_head(conditioned)
        offset_t = od[:, 0:1]
        depth_z = self.softplus(od[:, 1:2])
        
        xy_inter = origin + offset_t * direction
        intersection = torch.cat([xy_inter, depth_z], dim=1)
        
        return {
            "origin": origin,
            "direction": direction,
            "offset_t": offset_t,
            "depth_z": depth_z,
            "intersection": intersection
        }

    def forward_e2e(self, left_img, right_img):
        fused_vec, fused_4d = self._fused_vec(left_img, right_img)
        axis_params = self.axis_head(fused_4d)
        origin = axis_params[:, :2]
        direction = F.normalize(axis_params[:, 2:], dim=1)
        conditioned = torch.cat([fused_vec, origin, direction], dim=1)  # no detach
        od = self.offset_depth_head(conditioned)
        offset_t = od[:, 0:1]
        depth_z = self.softplus(od[:, 1:2])
        xy_inter = origin + offset_t * direction
        intersection = torch.cat([xy_inter, depth_z], dim=1)
        return {
            "origin": origin,
            "direction": direction,
            "offset_t": offset_t,
            "depth_z": depth_z,
            "intersection": intersection
        }

print("‚úÖ StereoTwoStageNet model defined.")

‚úÖ StereoTwoStageNet model defined.


## Training Utilities

In [8]:
def berhu_loss(pred, target, eps=1e-6, use_focal=False, focal_alpha=0.25, focal_gamma=1.5):
    """
    BerHu (Reverse Huber) loss with optional focal weighting
    
    Args:
        pred: predictions
        target: ground truth
        eps: small epsilon for numerical stability
        use_focal: if True, apply focal weighting to focus on hard samples
        focal_alpha: focal loss alpha parameter (default: 0.25) - modulation strength
        focal_gamma: focal loss gamma parameter (default: 1.5) - focus sharpness
    
    Returns:
        Mean of the (optionally focal-weighted) BerHu loss (scalar)
    
    FIXED: Corrected focal loss formula for regression
    - Old: weight = alpha * (normalized_loss)^gamma (too aggressive, ignored easy samples)
    - New: weight = 1.0 + alpha * (normalized_loss)^gamma (balanced modulation)
    """
    x = pred - target
    abs_x = torch.abs(x)
    c = 0.2 * abs_x.max().clamp(min=eps)
    l1_mask = abs_x <= c
    l1_loss = abs_x
    l2_loss = (x ** 2 + c ** 2) / (2 * c + eps)
    loss = torch.where(l1_mask, l1_loss, l2_loss)

    # Apply focal weighting if requested
    if use_focal:
        # Normalize loss to [0, 1] range for stable weighting
        max_loss = loss.detach().max() + eps
        normalized_loss = loss / max_loss
        
        # Focal modulation: gradually up-weight hard samples
        # Easy samples (loss‚âà0): weight ‚âà 1.0
        # Hard samples (loss‚âàmax): weight ‚âà 1.0 + alpha
        focal_weight = 1.0 + focal_alpha * torch.pow(normalized_loss, focal_gamma)
        loss = focal_weight * loss

    return loss.mean()

def compute_gt_t(gt_inter, gt_origin, gt_dir, eps=1e-6):
    xy_diff = gt_inter[:, :2] - gt_origin[:, :2]
    gt_dir_norm = F.normalize(gt_dir, dim=1)
    t_gt = (xy_diff * gt_dir_norm).sum(dim=1, keepdim=True) / (gt_dir_norm.norm(dim=1, keepdim=True)**2 + eps)
    return t_gt

def stereo_two_stage_loss(outputs, targets, **kwargs):
    w_origin = kwargs.get('w_origin', 1.0)
    w_dir = kwargs.get('w_dir', 5.0)
    w_xy = kwargs.get('w_xy', 1.0)
    w_z = kwargs.get('w_z', 5.0)
    w_inter = kwargs.get('w_inter', 1.0)
    w_axis = kwargs.get('w_axis', 0.2)
    w_t = kwargs.get('w_t', 1.0)
    use_log_depth = kwargs.get('use_log_depth', True)
    eps = kwargs.get('eps', 1e-6)

    # Focal loss parameters
    use_focal = kwargs.get('use_focal', False)
    focal_alpha = kwargs.get('focal_alpha', 0.25)
    focal_gamma = kwargs.get('focal_gamma', 1.5)

    pred_inter = outputs["intersection"]
    pred_z = outputs["depth_z"]
    pred_origin = outputs["origin"]
    pred_dir = outputs["direction"]
    pred_t = outputs["offset_t"]

    gt_inter = targets["gt_intersection"]
    gt_origin = targets["gt_origin"]
    gt_dir = targets["gt_dir"]

    # Intersection losses
    loss_xy = berhu_loss(pred_inter[:, :2], gt_inter[:, :2], eps=eps, 
                        use_focal=use_focal, focal_alpha=focal_alpha, focal_gamma=focal_gamma)

    # Z loss
    if use_log_depth:
        pred_log_z = torch.log(pred_z + eps)
        gt_log_z = torch.log(gt_inter[:, 2:] + eps)
        loss_z = berhu_loss(pred_log_z, gt_log_z, eps=eps, use_focal=use_focal,
                           focal_alpha=focal_alpha, focal_gamma=focal_gamma)
    else:
        loss_z = berhu_loss(pred_z.squeeze(1), gt_inter[:, 2], eps=eps, use_focal=use_focal,
                           focal_alpha=focal_alpha, focal_gamma=focal_gamma)

    loss_inter = w_xy * loss_xy + w_z * loss_z

    # Axis losses
    lo = ((pred_origin - gt_origin) ** 2).sum(dim=1).mean()
    pd = F.normalize(pred_dir, dim=1)
    gd = F.normalize(gt_dir, dim=1)
    cos = (pd * gd).sum(dim=1).clamp(-1+eps, 1-eps)
    ld = (1.0 - cos).mean()
    loss_axis = w_origin * lo + w_dir * ld

    # Offset loss
    gt_t = compute_gt_t(gt_inter, gt_origin, gt_dir, eps=eps)
    loss_t = ((pred_t - gt_t) ** 2).mean()

    total = w_inter * loss_inter + w_axis * loss_axis + w_t * loss_t
    return total

print("‚úÖ Loss functions defined.")

‚úÖ Loss functions defined.


In [9]:
def batch_to_device(batch, device):
    """
    Move batch to device. Handles dict batches and tuple/list batches.
    
    When return_filename=True, DataLoader collates as:
    - batch = (batch_dict, [filenames...], (split, split, ...)) [3-element tuple]
      where split is collated into a tuple by DataLoader
    """
    # Handle tuple/list format from return_filename=True
    if isinstance(batch, (list, tuple)):
        # 3 elements: (batch_dict, filenames, splits)
        if len(batch) == 3 and isinstance(batch[0], dict):
            batch_dict, filenames, splits = batch
            
            # Move tensors to device
            out = {}
            for k, v in batch_dict.items():
                if torch.is_tensor(v):
                    out[k] = v.to(device)
                else:
                    out[k] = v
            
            # Add filenames
            out['filename'] = list(filenames) if not isinstance(filenames, list) else filenames
            
            # Handle splits - DataLoader collates them into a tuple/list
            # Convert to list if it's a tuple
            if isinstance(splits, (tuple, list)):
                out['split'] = list(splits)
            else:
                # Single split value - replicate for all items
                out['split'] = [splits] * len(out['filename'])
            
            return out
        
        # 2 elements: (batch_dict, filenames) - legacy support
        elif len(batch) == 2 and isinstance(batch[0], dict):
            batch_dict, filenames = batch
            
            # Move tensors to device
            out = {}
            for k, v in batch_dict.items():
                if torch.is_tensor(v):
                    out[k] = v.to(device)
                else:
                    out[k] = v
            
            # Add filenames and infer split as 'train' (fallback)
            out['filename'] = list(filenames) if not isinstance(filenames, list) else filenames
            out['split'] = ['train'] * len(out['filename'])
            
            return out
    
    # Handle standard dict format
    if isinstance(batch, dict):
        out = {}
        for k, v in batch.items():
            if torch.is_tensor(v):
                out[k] = v.to(device)
            else:
                out[k] = v
        return out
    
    # Fallback error
    raise TypeError(f"Unexpected batch type: {type(batch)} with {len(batch) if isinstance(batch, (list, tuple)) else 'N/A'} elements. Expected dict or tuple/list of (dict, filenames[, split])")

def transform2Dto3D_torch(Z, uv, alpha=3.0374e+03, beta=3.0335e+03, ox=1.0001e+03, oy=1.0744e+03):
    u, v = uv[:, 0], uv[:, 1]
    X = (Z * (u - ox)) / alpha
    Y = (Z * (v - oy)) / beta
    return torch.stack([X, Y, Z], dim=1)

@torch.no_grad()
def r2_score_torch(pred, target):
    """Torch-based R¬≤ score computation"""
    ss_res = ((pred - target)**2).sum(dim=0)
    mean_t = target.mean(dim=0, keepdim=True)
    ss_tot = ((target - mean_t)**2).sum(dim=0)
    r2 = 1.0 - ss_res / (ss_tot + 1e-12)
    return r2

@torch.no_grad()
def angle_deg(pred_dir, gt_dir):
    pd = F.normalize(pred_dir, dim=1)
    gd = F.normalize(gt_dir, dim=1)
    cos = (pd * gd).sum(dim=1).clamp(-1+1e-6, 1-1e-6)
    return torch.acos(cos).mean().item() * 180.0 / math.pi

@torch.no_grad()
def errors_2d_3d(pred_inter, gt_inter, img_w=256, img_h=256, max_depth=220.0):
    pred_px = torch.empty_like(pred_inter)
    gt_px = torch.empty_like(gt_inter)

    pred_px[:, 0] = pred_inter[:, 0] * img_w
    pred_px[:, 1] = pred_inter[:, 1] * img_h  
    pred_px[:, 2] = pred_inter[:, 2] * max_depth

    gt_px[:, 0] = gt_inter[:, 0] * img_w
    gt_px[:, 1] = gt_inter[:, 1] * img_h
    gt_px[:, 2] = gt_inter[:, 2] * max_depth

    e2 = torch.norm(pred_px[:, :2] - gt_px[:, :2], dim=1).mean()
    
    pred3d = transform2Dto3D_torch(pred_px[:, 2], pred_px[:, :2])
    gt3d = transform2Dto3D_torch(gt_px[:, 2], gt_px[:, :2])
    e3 = torch.norm(pred3d - gt3d, dim=1).mean()
    
    return e2.item(), e3.item()

In [10]:
def make_optim(model, stage, lr_spatial=None):
    """
    Create optimizer with stage-specific learning rates
    
    Args:
        model: Model to optimize
        stage: Training stage ('A', 'B', or 'C')
        lr_spatial: Optional learning rate for spatial attention components
    
    Returns:
        torch.optim.Adam optimizer
    """
    lr_config = QUICK_CONFIG['lr'][stage]
    wd = QUICK_CONFIG['weight_decay']

    # Base parameter groups
    param_groups = [
        {"params": model.backbone.parameters(),
         "lr": lr_config['backbone'], "weight_decay": wd},
        {"params": model.axis_head.parameters(),
         "lr": lr_config['axis'], "weight_decay": wd}
    ]

    # Handle multi-head depth predictor
    if hasattr(model, 'head_coarse'):
        # Multi-head model: Add each head with intersection learning rate
        param_groups.extend([
            {"params": model.head_coarse.parameters(),
             "lr": lr_config['inter'], "weight_decay": wd},
            {"params": model.head_medium.parameters(),
             "lr": lr_config['inter'], "weight_decay": wd},
            {"params": model.head_fine.parameters(),
             "lr": lr_config['inter'], "weight_decay": wd}
        ])
    elif hasattr(model, 'offset_depth_head'):
        # Single head model
        param_groups.append(
            {"params": model.offset_depth_head.parameters(),
             "lr": lr_config['inter'], "weight_decay": wd}
        )

    # Add spatial attention if present
    if hasattr(model, 'spatial_attention'):
        lr_spatial_actual = lr_spatial if lr_spatial is not None else lr_config.get('spatial', 0.0)
        if lr_spatial_actual > 0:
            param_groups.extend([
                {"params": model.spatial_attention.parameters(),
                 "lr": lr_spatial_actual, "weight_decay": wd},
                {"params": model.spatial_fusion.parameters(),
                 "lr": lr_spatial_actual, "weight_decay": wd}
            ])

    # Add skip connection fusion layer if present
    if hasattr(model, 'skip_fusion'):
        param_groups.append(
            {"params": model.skip_fusion.parameters(),
             "lr": lr_config['backbone'], "weight_decay": wd}
        )
    
    # Add depth map head for auxiliary depth models
    if hasattr(model, 'depth_map_head'):
        param_groups.append(
            {"params": model.depth_map_head.parameters(),
             "lr": lr_config['inter'], "weight_decay": wd}
        )
    
    return torch.optim.Adam(param_groups)

## ‚ö° Quick Experimentation Framework

In [11]:
@torch.no_grad()
def validate(model, loader, device, e2e=False, strategy='standard', **kwargs):
    """
    Validate model on dataset with error handling and detailed metrics.
    Args:
        model: Model to validate
        loader: DataLoader for validation data
        device: Device to run on
        e2e: If True, use end-to-end forward pass
        strategy: 'standard', 'MultiHead', or 'auxiliary' for loss handling
        **kwargs: Additional args for loss functions
    Returns:
        Dictionary of validation metrics
    """
    model.eval()
    
    # Validation accumulators
    metrics = {
        'val_loss': 0.0,
        'batch_count': 0,
        'samples_seen': 0,
        'head_metrics': defaultdict(lambda: defaultdict(list)) if strategy == 'MultiHead' else None,
        'depth_map_error': [] if strategy == 'auxiliary' else None  # Initialize for auxiliary
    }
    
    # Prediction collectors
    predictions = defaultdict(list)
    targets = defaultdict(list)
    
    try:
        for batch_idx, batch in enumerate(loader):
            batch = batch_to_device(batch, device)
            L, R = batch["left"], batch["right"]
            
            # Ground truth
            go, gd = batch["gt_origin"], batch["gt_dir"]
            gi = batch["gt_intersection"]

            # Load auxiliary depth maps if needed
            if strategy == 'auxiliary' and 'depth_map' in batch:
                batch['gt_depth_map'] = torch.stack([
                    load_depth_map_for_batch(depth).to(device)
                    for depth in batch['depth_map']
                ])
            
            # Forward pass with strategy-specific handling
            try:
                out = model.forward_e2e(L, R) if e2e else model(L, R)

                # Compute loss based on strategy
                if strategy == 'auxiliary':
                    loss = auxiliary_depth_loss(out, batch, **kwargs)
                    
                    # Track depth map error if available
                    if 'depth_map' in out and 'gt_depth_map' in batch:
                        pred_depth = out['depth_map']
                        gt_depth = batch['gt_depth_map']
                        valid_mask = gt_depth > 0.1
                        if valid_mask.sum() > 10:
                            depth_error = F.l1_loss(pred_depth[valid_mask], gt_depth[valid_mask])
                            metrics['depth_map_error'].append(depth_error.item())
                            
                elif strategy == 'MultiHead':
                    loss, head_losses = multihead_loss(out, batch, **kwargs)
                    
                    # Track per-head predictions
                    if 'head_preds' in out:
                        for head_name, preds in out['head_preds'].items():
                            # Get intersection point for this head
                            head_intersection = torch.cat([
                                out["origin"] + preds[:, 0:1] * F.normalize(out["direction"], dim=1),
                                F.softplus(preds[:, 1:2])
                            ], dim=1)
                            
                            # Store predictions and losses
                            metrics['head_metrics'][head_name]['predictions'].append(head_intersection)
                            metrics['head_metrics'][head_name]['losses'].append(head_losses[head_name].item())
                else:
                    loss = stereo_two_stage_loss(out, batch, **kwargs)
                
                # Update metrics
                metrics['val_loss'] += float(loss.item())
                metrics['batch_count'] += 1
                metrics['samples_seen'] += L.size(0)
                
                # Collect predictions and targets
                predictions['origin'].append(out["origin"])
                predictions['direction'].append(out["direction"])
                predictions['intersection'].append(out["intersection"])
                targets['origin'].append(go)
                targets['direction'].append(gd)
                targets['intersection'].append(gi)
                
            except RuntimeError as e:
                print(f"\nWarning: Error in batch {batch_idx}: {str(e)}")
                continue
            
    except Exception as e:
        print(f"\nError during validation: {str(e)}")
        return {
            "val_loss": float('inf'),
            "error": str(e)
        }
    
    # Compute final metrics
    try:
        # Concatenate collected tensors
        PO = torch.cat(predictions['origin'])
        PD = torch.cat(predictions['direction'])
        PI = torch.cat(predictions['intersection'])
        GO = torch.cat(targets['origin'])
        GD = torch.cat(targets['direction'])
        GI = torch.cat(targets['intersection'])
        
        # Main metrics
        final_metrics = {
            "val_loss": metrics['val_loss'] / max(1, metrics['batch_count']),
            "ang_deg": angle_deg(PD, GD),
            "r2o": r2_score_torch(PO, GO).tolist()[:2],
            "r2d": r2_score_torch(PD, F.normalize(GD, dim=1)).tolist()[:2],
            "r2xyz": r2_score_torch(PI, GI).tolist()[:3],
            "samples_validated": metrics['samples_seen']
        }

        # 2D/3D errors
        e2, e3 = errors_2d_3d(PI, GI)
        final_metrics.update({"e2d": e2, "e3d": e3})

        # Add auxiliary depth metrics if available
        if strategy == 'auxiliary' and metrics['depth_map_error']:
            final_metrics['depth_map_error'] = sum(metrics['depth_map_error']) / len(metrics['depth_map_error'])
        
        # Strategy-specific metrics
        if strategy == 'MultiHead' and metrics['head_metrics']:
            head_metrics = {}
            for head_name, head_data in metrics['head_metrics'].items():
                if head_data['predictions']:
                    P = torch.cat(head_data['predictions'])
                    avg_loss = sum(head_data['losses']) / len(head_data['losses'])
                    r2xyz = r2_score_torch(P, GI).tolist()[:3]
                    e2, e3 = errors_2d_3d(P, GI)
                    
                    head_metrics[head_name] = {
                        'loss': avg_loss,
                        'r2_z': r2xyz[2],
                        'r2_xy': (r2xyz[0] + r2xyz[1]) / 2,
                        'e2d': e2,
                        'e3d': e3
                    }
            final_metrics['head_metrics'] = head_metrics
            
        # Memory cleanup
        del predictions, targets
        if 'head_metrics' in metrics:
            del metrics['head_metrics']
        torch.cuda.empty_cache() if torch.cuda.is_available() else None
        
        return final_metrics
        
    except Exception as e:
        print(f"\nError computing final metrics: {str(e)}")
        return {
            "val_loss": metrics['val_loss'] / max(1, metrics['batch_count']),
            "error": str(e),
            "partial_results": True
        }

In [None]:
class EarlyStopper:
    """Early stopping based on validation metrics"""
    def __init__(self, patience=3, min_delta=0.001, stage='A'):
        self.patience = patience
        self.min_delta = min_delta
        self.stage = stage
        self.counter = 0
        self.best_loss = float('inf')
        self.best_metric = -float('inf') if stage in ['A', 'B'] else float('inf')
        
    def should_stop(self, logs):
        """Return True if should stop training"""
        if self.stage == 'A':
            # Stage A: Stop if angular error < 8¬∞ or no improvement
            metric = logs['ang_deg']
            if metric < 8.0:  # Good enough angular error
                return True
            improved = logs['val_loss'] < (self.best_loss - self.min_delta)
        elif self.stage == 'B':
            # Stage B: Stop if R¬≤_z > 0.3 or no improvement  
            metric = logs['r2xyz'][2]
            if metric > 0.3:  # Decent R¬≤_z
                return True
            improved = logs['val_loss'] < (self.best_loss - self.min_delta)
        else:  # Stage C
            # Stage C: Stop if trend is clear
            improved = logs['val_loss'] < (self.best_loss - self.min_delta)
            
        if improved:
            self.best_loss = logs['val_loss']
            self.counter = 0
        else:
            self.counter += 1
            
        return self.counter >= self.patience

### ‚ö†Ô∏è Early Stopping Status: DISABLED

Early stopping has been temporarily disabled to allow experiments to run for the full configured number of epochs. This ensures consistent training across all experiments and makes results more comparable.

The `EarlyStopper` class is still defined below but is not actively used in the training loop.

In [12]:
def print_epoch_progress(epoch, max_epochs, train_metrics, val_metrics, strategy, epoch_time, best_epoch, stage, current_weights=None):
    """Helper function to print epoch progress - single clean line per epoch"""
    avg_loss = train_metrics['loss'] / len(train_loader)
    
    if strategy == 'MultiHead':
        # MultiHead: Show angular error in Stage A, depth metrics in later stages
        if stage == 'A' or 'r2xyz' not in val_metrics:
            ang = val_metrics.get('ang_deg', 0)
            msg = f"Epoch {epoch}/{max_epochs} | Loss: {avg_loss:.4f} | Val: {val_metrics['val_loss']:.4f} | Ang: {ang:.2f}¬∞\n"
        else:
            r2_z = val_metrics['r2xyz'][2]
            e3d = val_metrics.get('e3d', 0)
            msg = f"Epoch {epoch}/{max_epochs} | Loss: {avg_loss:.4f} | Val: {val_metrics['val_loss']:.4f} | R¬≤_z: {r2_z:.3f} | 3D: {e3d:.2f}mm | Time: {epoch_time:.1f}s\n"
        print(msg)

        # Per-head metrics (only show in stages B/C when available)
        if 'head_metrics' in val_metrics and stage != 'A' and 'r2xyz' in val_metrics:
            head_summary = " | ".join([f"{name}: {metrics['r2_z']:.3f}" 
                                       for name, metrics in val_metrics['head_metrics'].items()])
            msg = f"  ‚Ü≥ Heads: {head_summary}\n"
            print(msg)
    
    elif strategy == 'curriculum':
        # Curriculum: Show unweighted metrics and current w_z
        if stage == 'A':
            ang = val_metrics.get('ang_deg', 0)
            msg = f"Epoch {epoch}/{max_epochs} | Loss: {avg_loss:.4f} | Val: {val_metrics['val_loss']:.4f} | Ang: {ang:.2f}¬∞\n"
        else:
            r2_z = val_metrics['r2xyz'][2]
            e3d = val_metrics.get('e3d', 0)
            
            # Get current w_z if available
            w_z = current_weights.get('w_z', 5.0) if current_weights else 5.0
            
            # Show both weighted loss and unweighted metrics
            # Note: Loss is weighted and NOT comparable across epochs
            msg = f"Epoch {epoch}/{max_epochs} | Loss: {avg_loss:.4f} (w_z={w_z:.1f}) | R¬≤_z: {r2_z:.3f} | 3D: {e3d:.2f}mm\n"
        print(msg)

    else:  # Standard progress - single line
        if stage == 'A':
            ang = val_metrics.get('ang_deg', 0)
            msg = f"Epoch {epoch}/{max_epochs} | Loss: {avg_loss:.4f} | Val: {val_metrics['val_loss']:.4f} | Ang: {ang:.2f}¬∞\n"
        else:
            r2_z = val_metrics['r2xyz'][2]
            e3d = val_metrics.get('e3d', 0)
            msg = f"Epoch {epoch}/{max_epochs} | Loss: {avg_loss:.4f} | Val: {val_metrics['val_loss']:.4f} | R¬≤_z: {r2_z:.3f} | 3D: {e3d:.2f}mm\n"
        print(msg)

In [13]:
def train_stage(model, train_loader, val_loader, stage, max_epochs, device,
                use_amp=False, use_focal=False, focal_alpha=0.25, focal_gamma=1.5,
                strategy='standard', **strategy_kwargs):
    """
    Enhanced training function with better progress tracking, memory management,
    and error handling.
    
    Args:
        model: Model to train
        train_loader: Training data loader
        val_loader: Validation data loader
        stage: Training stage ('A', 'B', or 'C')
        max_epochs: Maximum number of epochs
        device: Device to train on
        use_amp: Use automatic mixed precision
        use_focal: Use focal loss for depth prediction
        focal_alpha: Focal loss alpha parameter
        focal_gamma: Focal loss gamma parameter
        strategy: Training strategy (standard/progressive/auxiliary/curriculum/spatial)
        **strategy_kwargs: Strategy-specific parameters
    
    Returns:
        dict: Best validation metrics including epoch number
    """
    # Initialize tracking
    history = {
        'train_loss': [],
        'val_loss': [],
        'r2_z': [],
        'e3d': [],
        'lr': [],
        'speed': [],
        'epoch_times': []
    }
    
    if strategy == 'MultiHead':
        for head in ['coarse', 'medium', 'fine']:
            history[f'{head}_loss'] = []
            history[f'{head}_r2z'] = []
            history[f'{head}_e3d'] = []

    # Configure training
    e2e = configure_model_for_stage(model, stage)
    optim = make_optim(model, stage)
    scaler = GradScaler(enabled=use_amp)
    # early_stopper = EarlyStopper(patience=QUICK_CONFIG['early_stop_patience'], stage=stage)  # DISABLED
    
    # Get loss weights and update with strategy-specific settings
    loss_weights = QUICK_CONFIG['stage_loss_weights'][stage].copy()
    loss_weights.update(QUICK_CONFIG['global_loss_weights'])
    
    try:
        # Strategy-specific setup
        if strategy == 'progressive':
            if stage == 'C':
                total_epochs = max_epochs
            else:
                model.detach_alpha = torch.tensor(1.0, device=device)
        elif strategy == 'auxiliary':
            w_depth_map_dict = strategy_kwargs.get('w_depth_map', {
                'A': 0.0, 'B': 0.5, 'C': 1.0
            })
            w_depth_map = w_depth_map_dict.get(stage, 0.5)
            loss_weights['w_depth_map'] = w_depth_map
        elif strategy == 'MultiHead':
            head_weights = strategy_kwargs.get('head_weights', {
                'coarse': 0.2, 'medium': 0.3, 'fine': 0.5
            })
            if abs(sum(head_weights.values()) - 1.0) > 1e-6:
                print(f"‚ö†Ô∏è Head weights sum to {sum(head_weights.values()):.3f}, not 1.0")
            loss_weights['head_weights'] = head_weights
    except Exception as e:
        print(f"\n‚ö†Ô∏è Strategy setup error: {str(e)}")
        print("Falling back to standard training")
        strategy = 'standard'

    # Training state
    best_metrics = None
    best_epoch = 0
    best_head_metrics = None
    total_batches = len(train_loader)

    # Main training loop
    for epoch in range(1, max_epochs + 1):
        epoch_start = time.time()
        
        # Strategy updates
        if strategy == 'progressive' and stage == 'C':
            alpha = 1.0 - epoch / total_epochs
            if hasattr(model, 'detach_alpha'):
                model.detach_alpha = torch.tensor(alpha, device=device)
        elif strategy == 'curriculum':
            loss_weights['w_z'] = get_curriculum_weights(stage, epoch-1, max_epochs)['w_z']

        # Training
        model.train()
        train_metrics = defaultdict(float)
        batch_times = []
        
        for batch_idx, batch in enumerate(train_loader):
            batch_start = time.time()
            
            # Process batch
            batch = batch_to_device(batch, device)
            
            # Process depth maps if using auxiliary strategy
            if strategy == 'auxiliary' and 'depth_map' in batch:
                batch['gt_depth_map'] = torch.stack([
                    load_depth_map_for_batch(depth).to(device)
                    for depth in batch['depth_map']
                ])

            # Forward pass
            optim.zero_grad()
            with autocast(device_type=device.type, enabled=use_amp):
                out = model.forward_e2e(batch["left"], batch["right"]) if e2e else model(batch["left"], batch["right"])
                
                # Compute loss based on strategy
                if strategy == 'auxiliary':
                    loss = auxiliary_depth_loss(out, batch, **loss_weights)
                elif strategy == 'MultiHead':
                    loss, head_losses = multihead_loss(out, batch, **loss_weights)
                    for name, hloss in head_losses.items():
                        train_metrics[f'{name}_loss'] += hloss.item()
                else:
                    loss = stereo_two_stage_loss(out, batch, **loss_weights)

            # Backward pass
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
            
            # Update metrics
            train_metrics['loss'] += loss.item()
            batch_time = time.time() - batch_start
            batch_times.append(batch_time)

        # Validation (no batch progress, cleaner output)
        val_metrics = validate(
            model, val_loader, device,
            e2e=e2e, strategy=strategy,
            use_focal=use_focal,
            focal_alpha=focal_alpha,
            focal_gamma=focal_gamma,
            **loss_weights
        )
        
        # Update history
        epoch_time = time.time() - epoch_start
        avg_speed = len(train_loader) / epoch_time

        # Update history in train_stage()
        try:
            # Basic metrics with safe defaults
            history['train_loss'].append(train_metrics['loss'] / len(train_loader))
            history['val_loss'].append(val_metrics.get('val_loss', float('inf')))
            history['r2_z'].append(val_metrics.get('r2xyz', [0, 0, 0])[2] if 'r2xyz' in val_metrics else 0)
            history['e3d'].append(val_metrics.get('e3d', float('inf')))
            history['lr'].append(optim.param_groups[0]['lr'])
            history['speed'].append(avg_speed)
            history['epoch_times'].append(epoch_time)
            
            # Strategy-specific metrics
            if strategy == 'MultiHead':
                for head in ['coarse', 'medium', 'fine']:
                    if head in val_metrics.get('head_metrics', {}):
                        head_metrics = val_metrics['head_metrics'][head]
                        history[f'{head}_loss'].append(train_metrics.get(f'{head}_loss', 0) / len(train_loader))
                        history[f'{head}_r2z'].append(head_metrics.get('r2_z', 0))
                        history[f'{head}_e3d'].append(head_metrics.get('e3d', float('inf')))
            elif strategy == 'auxiliary':
                # Initialize depth_map_error list if not exists
                if 'depth_map_error' not in history:
                    history['depth_map_error'] = []
                if 'w_depth_map' not in history:
                    history['w_depth_map'] = []
                    
                history['depth_map_error'].append(val_metrics.get('depth_map_error', float('inf')))
                history['w_depth_map'].append(loss_weights.get('w_depth_map', 0))
            elif strategy == 'curriculum':
                # Track curriculum weight changes
                if 'w_z_curriculum' not in history:
                    history['w_z_curriculum'] = []
                history['w_z_curriculum'].append(loss_weights.get('w_z', 5.0))

        except Exception as e:
            print(f"\nWarning: Error updating history - {str(e)}")
            print("Continuing training...")

        # Update best metrics
        if best_metrics is None or val_metrics['val_loss'] < best_metrics['val_loss']:
            best_metrics = val_metrics.copy()
            best_epoch = epoch
            if 'head_metrics' in val_metrics:
                best_head_metrics = val_metrics['head_metrics'].copy()

        # Print progress
        print_epoch_progress(epoch, max_epochs, train_metrics, val_metrics, 
                    strategy, epoch_time, best_epoch, stage, current_weights=loss_weights)

        # Early stopping check - DISABLED for now
        # if early_stopper.should_stop(val_metrics):
        #     print(f"  ‚Üí Early stopping at epoch {epoch} (best: epoch {best_epoch})")
        #     break

    # Finalize training
    try:
        # Add training statistics
        best_metrics.update({
            'epoch': best_epoch,
            'epochs_trained': epoch,
            'training_completed': True,
            'training_history': history
        })
        
        # Add head-specific metrics for multi-head models
        if strategy == 'MultiHead' and best_head_metrics is not None:
            best_metrics['head_metrics'] = best_head_metrics
            best_metrics['best_head_r2z'] = best_head_metrics['fine']['r2_z']
            best_metrics['best_head_e3d'] = best_head_metrics['fine']['e3d']
        
        return best_metrics
        
    except Exception as e:
        print(f"\n‚ö†Ô∏è Error finalizing metrics: {str(e)}")
        return {
            'val_loss': float('inf'),
            'r2xyz': [0, 0, 0],
            'e3d': float('inf'),
            'epoch': epoch,
            'epochs_trained': epoch,
            'training_completed': False,
            'error': str(e)
        }

In [14]:
def experiment(model_name, model_class, train_loader, val_loader, description="", 
               use_focal=False, focal_alpha=0.25, focal_gamma=1.5, strategy="standard",
               **strategy_kwargs):
    """
    Conduct a quick experiment with the specified model and training strategy.
    Args:
        model_name: Name of the model
        model_class: Class of the model to instantiate
        train_loader: Training data loader
        val_loader: Validation data loader
        description: Description of the experiment
        use_focal: Use focal loss for depth prediction
        focal_alpha: Focal loss alpha parameter
        focal_gamma: Focal loss gamma parameter
        strategy: Training strategy (standard/progressive/auxiliary/curriculum/spatial/MultiHead)
        **strategy_kwargs: Additional strategy-specific parameters (e.g., head_weights for MultiHead)
    Returns:
        dict: Final evaluation metrics
    """
    # Reset seed for reproducibility
    set_seed(42)

    # Setup
    device = torch.device("mps" if torch.backends.mps.is_available() else
                         "cuda" if torch.cuda.is_available() else "cpu")

    print(f"\n{'='*60}\nüöÄ EXPERIMENT: {model_name}\n{'='*60}\nDescription: {description}\nDevice: {device} | Expected time: ~5-15 minutes\n")

    start_time = time.time()

    # Initialize model
    model = model_class().to(device)

    # For multi-head strategy, configure head weights
    if strategy == "MultiHead":
        head_weights = strategy_kwargs.get('head_weights', {
            'coarse': 0.2, 'medium': 0.3, 'fine': 0.5
        })
        print(f"\nüìä Multi-Head Configuration:")
        print(f"   Head Weights: coarse={head_weights['coarse']:.1f}, "
              f"medium={head_weights['medium']:.1f}, "
              f"fine={head_weights['fine']:.1f}")
        strategy_kwargs['head_weights'] = head_weights

    # Stage A: Quick axis pretraining
    print(f"\n{'='*60}\nSTAGE A: Axis Pretraining ({QUICK_CONFIG['epochs']['A']} epochs)\n{'='*60}\n")
    logs_A = train_stage(model, train_loader, val_loader, 'A', 
                        QUICK_CONFIG['epochs']['A'], device,
                        use_focal=use_focal, focal_alpha=focal_alpha, 
                        focal_gamma=focal_gamma,
                        strategy=strategy,
                        **strategy_kwargs)

    # Stage B: Quick intersection training
    print(f"\n{'='*60}\nSTAGE B: Intersection Training ({QUICK_CONFIG['epochs']['B']} epochs)\n{'='*60}\n")
    logs_B = train_stage(model, train_loader, val_loader, 'B',
                        QUICK_CONFIG['epochs']['B'], device,
                        use_focal=use_focal, focal_alpha=focal_alpha, 
                        focal_gamma=focal_gamma,
                        strategy=strategy,
                        **strategy_kwargs)

    # For multi-head, use best fine head R¬≤_z for decision
    if strategy == "MultiHead" and 'head_metrics' in logs_B:
        r2_z_B = logs_B['head_metrics']['fine']['r2_z']
        print(f"\nüìä Stage B Head Performance:")
        for head, metrics in logs_B['head_metrics'].items():
            print(f"   ‚Ä¢ {head}: R¬≤_z={metrics['r2_z']:.3f}, 3D Error={metrics['e3d']:.2f}mm")
    else:
        r2_z_B = logs_B['r2xyz'][2]
    
    # ALWAYS run full Stage C (adaptive scheduling DISABLED)
    print(f"\nüü¢ Stage B R¬≤_z: {r2_z_B:.3f}. Running FULL Stage C ({QUICK_CONFIG['epochs']['C']} epochs).")
    stage_C_epochs = QUICK_CONFIG['epochs']['C']

    # Stage C: End-to-end fine-tuning
    if stage_C_epochs > 0:
        print(f"\n{'='*60}\nSTAGE C: End-to-End Fine-tuning ({stage_C_epochs} epochs)\n{'='*60}\n")
        logs_C = train_stage(model, train_loader, val_loader, 'C',
                           stage_C_epochs, device,
                           use_focal=use_focal, focal_alpha=focal_alpha, 
                           focal_gamma=focal_gamma,
                           strategy=strategy,
                           **strategy_kwargs)

    # Final results
    runtime = time.time() - start_time
    
    # Use best fine head metrics for multi-head models
    if strategy == "MultiHead" and 'head_metrics' in logs_C:
        best_head = logs_C['head_metrics']['fine']
        r2_z_final = best_head['r2_z']
        e3d_final = best_head['e3d']
        e2d_final = best_head.get('e2d', 0.0)
    else:
        r2_z_final = logs_C['r2xyz'][2]
        e3d_final = logs_C['e3d']
        e2d_final = logs_C['e2d']

    result = {
        'model_name': model_name,
        'description': description,
        'r2_z': r2_z_final,
        'r2_x': logs_C['r2xyz'][0],
        'r2_y': logs_C['r2xyz'][1],
        '3d_error_mm': e3d_final,
        '2d_error_px': e2d_final,
        'val_loss': logs_C['val_loss'],
        'runtime_min': runtime / 60,
        'stage_A_epochs': logs_A.get('epoch', QUICK_CONFIG['epochs']['A']),  
        'stage_B_epochs': logs_B.get('epoch', QUICK_CONFIG['epochs']['B']),
        'stage_C_epochs': logs_C.get('epoch', stage_C_epochs) if stage_C_epochs > 0 else 0,
        'timestamp': pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S'),
        'decision': '',
        'color': ''
    }

    # For multi-head, include per-head metrics
    if strategy == "MultiHead" and 'head_metrics' in logs_C:
        result['head_metrics'] = logs_C['head_metrics']

    # Decision
    if r2_z_final > QUICK_CONFIG['thresholds']['stage_C_winner']:
        result['decision'] = "üü¢ WINNER"
        result['color'] = "green"
    elif r2_z_final > QUICK_CONFIG['thresholds']['stage_C_decent']:
        result['decision'] = "üü° DECENT"
        result['color'] = "yellow"
    else:
        result['decision'] = "üî¥ POOR"
        result['color'] = "red"

    # Print summary with head-specific metrics for multi-head models
    runtime_min = runtime / 60
    print(f"\n{'='*60}\n‚úÖ EXPERIMENT COMPLETE: {model_name}\n{result['decision']} | R¬≤_z: {r2_z_final:.3f} | 3D Error: {e3d_final:.2f}mm | Time: {runtime_min:.1f}min\n{'='*60}\n")

    if strategy == "MultiHead" and 'head_metrics' in result:
        print("Head Performance:")
        for head, metrics in result['head_metrics'].items():
            print(f"  ‚Ä¢ {head}: R¬≤_z={metrics['r2_z']:.3f}, 3D Error={metrics['e3d']:.2f}mm")

    return result

### üéØ Training Configuration Changes

**Current Settings:**
- ‚úÖ **Early Stopping**: DISABLED - All stages run for full configured epochs
- ‚úÖ **Adaptive Stage C**: DISABLED - Stage C always runs full 20 epochs regardless of Stage B performance

**Why These Changes?**
1. **Consistency**: All experiments get the same training budget for fair comparison
2. **Late Improvements**: Some models improve significantly in later epochs
3. **Reproducibility**: Fixed epoch counts make results more reproducible

**Training Schedule:**
- Stage A: 10 epochs (axis pretraining)
- Stage B: 15 epochs (intersection training)
- Stage C: **20 epochs** (end-to-end fine-tuning) - ALWAYS runs full 20

## üìä Experiment Tracking & Comparison

In [15]:
# Global experiment results tracker
experiment_results = []

# Global baseline reference - will be set after baseline experiment runs
BASELINE_REFERENCE = {
    'r2_z': None,
    '3d_error_mm': None,
    'model_name': None
}

def set_baseline_reference(result):
    """Set the baseline reference from baseline experiment result"""
    BASELINE_REFERENCE['r2_z'] = result['r2_z']
    BASELINE_REFERENCE['3d_error_mm'] = result['3d_error_mm']
    BASELINE_REFERENCE['model_name'] = result['model_name']
    print(f"‚úÖ Baseline reference set: R¬≤_z={result['r2_z']:.3f}, 3D Error={result['3d_error_mm']:.2f}mm")

def log_experiment(result):
    """Log experiment result and save to file"""
    experiment_results.append(result)
    
    # Auto-set baseline if this is the first baseline experiment
    if result['model_name'] == 'baseline' and BASELINE_REFERENCE['r2_z'] is None:
        set_baseline_reference(result)
    
    # Save to CSV for persistence
    df = pd.DataFrame(experiment_results)
    df.to_csv('quick_experiment_results.csv', index=False)
    
    print(f"‚úÖ Experiment logged: {result['model_name']}")

def show_leaderboard():
    """Display current leaderboard of experiments"""
    if not experiment_results:
        print("No experiments run yet!")
        return
    
    df = pd.DataFrame(experiment_results)
    df_sorted = df.sort_values('r2_z', ascending=False)
    
    print("\nüèÜ EXPERIMENT LEADERBOARD")
    print("=" * 80)
    print(f"{'Rank':<4} {'Model':<20} {'R¬≤_z':<8} {'3D_err':<8} {'2D_err':<8} {'Time':<8} {'Decision':<12}")
    print("-" * 80)
    
    for i, (_, row) in enumerate(df_sorted.iterrows(), 1):
        print(f"{i:<4} {row['model_name']:<20} {row['r2_z']:<8.3f} {row['3d_error_mm']:<8.1f} {row['2d_error_px']:<8.1f} {row['runtime_min']:<8.1f} {row['decision']:<12}")
    
    print("\nüìà BEST PERFORMERS:")
    top_3 = df_sorted.head(3)
    
    # Use actual baseline reference if available
    baseline_ref = BASELINE_REFERENCE['r2_z'] if BASELINE_REFERENCE['r2_z'] is not None else 0.597
    
    for _, row in top_3.iterrows():
        improvement = ((row['r2_z'] - baseline_ref) / baseline_ref * 100)
        print(f"   {row['decision']} {row['model_name']}: R¬≤_z={row['r2_z']:.3f} ({improvement:+.1f}% vs baseline)")

def compare_with_baseline(result, baseline_r2z=None, baseline_3d=None):
    """
    Compare result with baseline and return improvement metrics
    
    Args:
        result: Experiment result dict
        baseline_r2z: Baseline R¬≤_z (if None, uses BASELINE_REFERENCE)
        baseline_3d: Baseline 3D error (if None, uses BASELINE_REFERENCE)
    
    Returns:
        dict with improvement metrics
    """
    # Use actual baseline reference if available and not overridden
    if baseline_r2z is None:
        if BASELINE_REFERENCE['r2_z'] is not None:
            baseline_r2z = BASELINE_REFERENCE['r2_z']
            print(f"üìç Using actual baseline reference: R¬≤_z={baseline_r2z:.3f}")
        else:
            baseline_r2z = 0.597  # Fallback to documentation value
            print(f"‚ö†Ô∏è  No baseline reference set, using documentation value: R¬≤_z={baseline_r2z:.3f}")
    
    if baseline_3d is None:
        if BASELINE_REFERENCE['3d_error_mm'] is not None:
            baseline_3d = BASELINE_REFERENCE['3d_error_mm']
        else:
            baseline_3d = 6.65  # Fallback to documentation value
    
    r2z_improvement = result['r2_z'] - baseline_r2z
    r2z_improvement_pct = (r2z_improvement / baseline_r2z) * 100
    
    error_improvement = baseline_3d - result['3d_error_mm']
    error_improvement_pct = (error_improvement / baseline_3d) * 100
    
    return {
        'r2z_improvement': r2z_improvement,
        'r2z_improvement_pct': r2z_improvement_pct,
        'error_improvement_mm': error_improvement,
        'error_improvement_pct': error_improvement_pct,
        'is_better': r2z_improvement > QUICK_CONFIG['decision_threshold'],
        'baseline_r2z_used': baseline_r2z,
        'baseline_3d_used': baseline_3d
    }

print("üìä Experiment tracking system ready!")

üìä Experiment tracking system ready!


In [17]:
# Load existing experiment results from CSV (if they exist)
import os
if os.path.exists('quick_experiment_results.csv'):
    df_existing = pd.read_csv('quick_experiment_results.csv')
    # Filter out NaN entries
    df_existing = df_existing[df_existing['r2_z'].notna()]
    experiment_results = df_existing.to_dict('records')
    print(f"‚úÖ Loaded {len(experiment_results)} existing experiments from CSV")
else:
    experiment_results = []
    print("üìä Starting fresh - no existing results found")

# Display current leaderboard
show_leaderboard()

‚úÖ Loaded 33 existing experiments from CSV

üèÜ EXPERIMENT LEADERBOARD
Rank Model                R¬≤_z     3D_err   2D_err   Time     Decision    
--------------------------------------------------------------------------------
1    skip_connections + focal_loss + augmentation 0.570    7.6      26.3     24.0     üü¢ WINNER    
2    skip_connections + focal_loss (full training) 0.557    7.1      15.2     49.9     üü¢ WINNER    
3    skip_connections + spatial_attention 0.554    7.1      23.8     34.2     üü¢ WINNER    
4    depth_curriculum + skip_connections 0.537    7.6      24.8     26.0     üü° DECENT    
5    skip_connections + auxiliary_depth 0.536    7.6      24.3     31.9     üü° DECENT    
6    skip_connections + spatial_attention (full training) 0.534    7.0      9.6      60.0     üü° DECENT    
7    skip_connections + multihead_depth 0.528    7.6      23.5     27.0     üü° DECENT    
8    depth_curriculum     0.517    7.8      19.5     23.6     üü° DECENT    
9    s

## üéØ Test Framework with Baseline

In [None]:
# # Test the framework with baseline model first
# print("üî• Testing Quick Experiment Framework")
# print("Running baseline experiment to validate framework...")

# baseline_result = experiment(
#     model_name="baseline",
#     model_class=StereoTwoStageNet,
#     train_loader=train_loader,
#     val_loader=val_loader,
#     description="Original model, no modifications - framework validation"
# )

# log_experiment(baseline_result)
# show_leaderboard()

# print(f"\n‚úÖ Framework validation complete!")
# print(f"Baseline quick result: R¬≤_z={baseline_result['r2_z']:.3f} in {baseline_result['runtime_min']:.1f} minutes")
# print(f"Expected full baseline: R¬≤_z‚âà0.597 (from documentation)")
# print(f"Quick vs Full difference: ~{abs(baseline_result['r2_z'] - 0.597):.3f} R¬≤_z units")
# print(f"\nüéØ Framework is ready for testing the 15 improvement approaches!")

## üß™ Experiment 1: Focal Loss

In [None]:
class FocalLossNet(StereoTwoStageNet):
    """
    Wrapper around StereoTwoStageNet for focal loss experiments.
    
    Architecture: Identical to baseline StereoTwoStageNet
    Difference: Uses focal loss weighting during training (applied in loss function)
    """
    def __init__(self, backbone_name="resnet18", pretrained=True):
        super().__init__(backbone_name, pretrained)
        
print("‚úÖ FocalLossNet model defined!")

In [None]:
# Run focal loss experiment
focal_result = experiment(
    model_name="focal_loss w/ augmentation",
    model_class=FocalLossNet,
    train_loader=train_loader,
    val_loader=val_loader,
    description="Focal Loss for depth prediction to handle hard samples",
    use_focal=True,
    focal_alpha=0.25,
    focal_gamma=1.5
)

log_experiment(focal_result)
show_leaderboard()

# Compare with baseline
comparison = compare_with_baseline(focal_result)  # Use quick baseline
print(f"\nüìä FOCAL LOSS ANALYSIS:")
print(f"R¬≤_z change: {comparison['r2z_improvement']:+.3f} ({comparison['r2z_improvement_pct']:+.1f}%)")
print(f"3D error change: {comparison['error_improvement_mm']:+.2f}mm ({comparison['error_improvement_pct']:+.1f}%)")
print(f"Better than baseline? {'‚úÖ YES' if comparison['is_better'] else '‚ùå NO'}")

# Decision on next approach
if comparison['r2z_improvement'] > 0.05:  # >5% improvement
    print("üü¢ EXCELLENT! Focal loss significantly improved depth prediction.")
    print("   ‚Üí Next: Try Spatial Attention to further enhance feature representation")
elif comparison['r2z_improvement'] > 0.02:  # 2-5% improvement
    print("üü° GOOD! Focal loss shows promise.")
    print("   ‚Üí Next: Combine with Spatial Attention or try Multi-scale Depth Heads")
else:
    print("üî¥ MINIMAL IMPROVEMENT. Loss function change not sufficient.")
    print("   ‚Üí Next: Try architectural changes like Spatial Attention or Stronger Backbone")

## üß™ Experiment 2: Skip Connections

In [None]:
class SkipConnectionNet(StereoTwoStageNet):
    """
    Enhanced fusion with residual pathway for better gradient flow.
    
    Architecture:
    - Main path: Concatenated features ‚Üí Conv 1280‚Üí512‚Üí256 ‚Üí Pool ‚Üí (B, 256)
    - Skip path: Concatenated features ‚Üí Pool ‚Üí Linear 1280‚Üí256 ‚Üí (B, 256)
    - Output: main + skip (residual addition)
    """
    def __init__(self, backbone_name="resnet18", pretrained=True):
        super().__init__(backbone_name, pretrained)
        
        # Enhanced fusion with increased capacity
        fused_channels = 128 * len(self.feature_dims) * 2
        self.fusion = nn.Sequential(
            nn.Conv2d(fused_channels, 512, 1),  # Expanded capacity
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, 1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1)
        )
        
        # Skip connection path - processes raw concatenated features
        # This provides alternative gradient path and preserves feature information
        self.skip_fusion = nn.Linear(fused_channels, 256)
        
    def _fused_vec(self, left_img, right_img):
        """Extract fused features with residual connection"""
        feats_l = self.backbone(left_img)
        feats_r = self.backbone(right_img)
        H, W = feats_l[-1].shape[2:]
        
        # Project and fuse features at multiple scales (using inherited self.proj)
        fused_scales = []
        for fl, fr, proj in zip(feats_l, feats_r, self.proj):
            # Project to consistent channels
            fl_proj = proj(fl)
            fr_proj = proj(fr)
            # Resize to common spatial size
            fl_proj = F.adaptive_avg_pool2d(fl_proj, (H, W))
            fr_proj = F.adaptive_avg_pool2d(fr_proj, (H, W))
            # Concatenate left-right stereo features
            fused_scales.append(torch.cat([fl_proj, fr_proj], dim=1))

        # Concatenate all scales: (B, 1280, H, W) for ResNet18
        x = torch.cat(fused_scales, dim=1)
        
        # Main fusion path: Convolutional transformation
        fused_4d = self.fusion(x)
        fused_vec = fused_4d.view(fused_4d.size(0), -1)  # (B, 256)
        
        # Skip connection path: Direct linear transformation
        # Pool the raw concatenated features and project to same dimension
        skip_pooled = F.adaptive_avg_pool2d(x, 1).view(x.size(0), -1)  # (B, 1280)
        skip_vec = self.skip_fusion(skip_pooled)  # (B, 256)
        
        # Residual addition: Combine both pathways
        # This helps gradient flow and provides ensemble-like effect
        combined_vec = fused_vec + skip_vec
        
        return combined_vec, fused_4d

print("‚úÖ SkipConnectionNet model defined!")

In [None]:
# Run skip connections experiment
skip_result = experiment(
    model_name="skip_connections, focal_loss, augmentation",
    model_class=SkipConnectionNet,
    train_loader=train_loader,
    val_loader=val_loader,
    description="Skip connections around fusion for better gradient flow",
    use_focal=True,
    focal_alpha=0.25,
    focal_gamma=1.5
)

log_experiment(skip_result)
show_leaderboard()

# Compare with baseline
comparison = compare_with_baseline(skip_result)
print(f"\nüìä SKIP CONNECTIONS ANALYSIS:")
print(f"R¬≤_z improvement: {comparison['r2z_improvement']:+.3f} ({comparison['r2z_improvement_pct']:+.1f}%)")
print(f"3D error change: {comparison['error_improvement_mm']:+.2f}mm ({comparison['error_improvement_pct']:+.1f}%)")
print(f"Is better than baseline? {'‚úÖ YES' if comparison['is_better'] else '‚ùå NO'}")

if comparison['is_better']:
    print("üü¢ PROMISING! Consider full training run.")
else:
    print("üî¥ SKIP - Moving to next approach.")

## üß™ Experiment 3: Spatial Attention for Depth

In [18]:
class SpatialAttentionModule(nn.Module):
    """Spatial attention to weight important regions for depth prediction"""
    def __init__(self, in_channels):
        super().__init__()
        # Channel reduction for attention map
        self.attention_conv = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 4, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // 4, 1, 1),
            nn.Sigmoid()  # Attention weights [0, 1]
        )
        
    def forward(self, x):
        # x: (B, C, H, W)
        attn_map = self.attention_conv(x)  # (B, 1, H, W)
        attended = x * attn_map  # Element-wise multiplication
        return attended, attn_map


class SpatialAttentionDepthNet(StereoTwoStageNet):
    """
    Enhanced model with spatial attention for depth prediction
    
    Key improvements:
    1. Applies learned attention to weight depth-relevant spatial regions
    2. Dual-path processing: attended global context + attended spatial features
    3. Returns attention maps for visualization and analysis
    """
    def __init__(self, backbone_name="resnet18", pretrained=True):
        super().__init__(backbone_name, pretrained)
        
        # Calculate fused channels
        fused_channels = 128 * len(self.feature_dims) * 2
        
        # Spatial attention module
        self.spatial_attention = SpatialAttentionModule(fused_channels)
        
        # Enhanced fusion with spatial pathway
        self.spatial_fusion = nn.Sequential(
            nn.Conv2d(fused_channels, 512, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, 3, padding=1),
            nn.ReLU(inplace=True),
        )
        
        # Keep original global fusion for axis prediction
        self.fusion = nn.Sequential(
            nn.Conv2d(fused_channels, 256, 1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1)
        )
        
        # Enhanced depth head with spatial features
        # Combines: global features (256) + spatial features (256) + axis (4) = 516
        self.offset_depth_head = nn.Sequential(
            nn.Linear(256 + 256 + 4, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),  # Regularization for larger head
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 2)  # (offset_t, depth_z_raw)
        )
        
    def _fused_vec(self, left_img, right_img):
        """Extract fused features with spatial attention
        
        Returns:
            fused_vec: Concatenated global + spatial features (B, 512)
            global_4d: Global features for axis head (B, 256, 1, 1)
            attn_map: Attention map for visualization (B, 1, H, W)
        """
        feats_l = self.backbone(left_img)
        feats_r = self.backbone(right_img)
        H, W = feats_l[-1].shape[2:]
        
        # Multi-scale fusion (same as baseline)
        fused_scales = []
        for fl, fr, proj in zip(feats_l, feats_r, self.proj):
            fl = proj(fl)
            fr = proj(fr)
            fl = F.adaptive_avg_pool2d(fl, (H, W))
            fr = F.adaptive_avg_pool2d(fr, (H, W))
            fused_scales.append(torch.cat([fl, fr], dim=1))

        x = torch.cat(fused_scales, dim=1)  # (B, 1280, H, W)
        
        # Apply spatial attention - FIX #2: Keep attention map for visualization
        attended_x, attn_map = self.spatial_attention(x)
        
        # Spatial pathway: preserve spatial features for depth
        spatial_features = self.spatial_fusion(attended_x)  # (B, 256, H, W)
        spatial_vec = F.adaptive_avg_pool2d(spatial_features, 1).view(spatial_features.size(0), -1)  # (B, 256)
        
        # Global pathway: for axis prediction - FIX #1: Use attended features
        global_4d = self.fusion(attended_x)  # FIXED: was self.fusion(x)
        global_vec = global_4d.view(global_4d.size(0), -1)
        
        fused_vec = torch.cat([global_vec, spatial_vec], dim=1)
        return fused_vec, global_4d, attn_map
    
    def forward(self, left_img, right_img):
        fused_vec, fused_4d, attn_map = self._fused_vec(left_img, right_img)
        
        # Stage 1: Axis prediction (uses attended global features)
        axis_params = self.axis_head(fused_4d)
        origin = axis_params[:, :2]
        direction = F.normalize(axis_params[:, 2:], dim=1)
        
        # Stage 2: Depth prediction (uses global + spatial + axis)
        conditioned = torch.cat([fused_vec, origin.detach(), direction.detach()], dim=1)
        
        od = self.offset_depth_head(conditioned)
        offset_t = od[:, 0:1]
        depth_z = self.softplus(od[:, 1:2])
        
        xy_inter = origin + offset_t * direction
        intersection = torch.cat([xy_inter, depth_z], dim=1)
        
        return {
            "origin": origin,
            "direction": direction,
            "offset_t": offset_t,
            "depth_z": depth_z,
            "intersection": intersection,
            "attn_map": attn_map  # Return attention map for analysis
        }

    def forward_e2e(self, left_img, right_img):
        """End-to-end forward pass (no gradient detachment)"""
        fused_vec, fused_4d, attn_map = self._fused_vec(left_img, right_img)
        
        axis_params = self.axis_head(fused_4d)
        origin = axis_params[:, :2]
        direction = F.normalize(axis_params[:, 2:], dim=1)
        
        conditioned = torch.cat([fused_vec, origin, direction], dim=1)
        
        od = self.offset_depth_head(conditioned)
        offset_t = od[:, 0:1]
        depth_z = self.softplus(od[:, 1:2])
        
        xy_inter = origin + offset_t * direction
        intersection = torch.cat([xy_inter, depth_z], dim=1)
        
        return {
            "origin": origin,
            "direction": direction,
            "offset_t": offset_t,
            "depth_z": depth_z,
            "intersection": intersection,
            "attn_map": attn_map  # Return attention map for analysis
        }

print("‚úÖ SpatialAttentionDepthNet model defined!")

‚úÖ SpatialAttentionDepthNet model defined!


In [None]:
# Run spatial attention experiment
spatial_result = experiment(
    model_name="spatial_attention, focal_loss, augmentation",
    model_class=SpatialAttentionDepthNet,
    train_loader=train_loader,
    val_loader=val_loader,
    description="Spatial attention to preserve depth-relevant spatial features",
    strategy="spatial"
)

log_experiment(spatial_result)
show_leaderboard()

# Compare with baseline
comparison = compare_with_baseline(spatial_result)
print(f"\nüìä SPATIAL ATTENTION ANALYSIS:")
print(f"R¬≤_z improvement: {comparison['r2z_improvement']:+.3f} ({comparison['r2z_improvement_pct']:+.1f}%)")
print(f"3D error change: {comparison['error_improvement_mm']:+.2f}mm ({comparison['error_improvement_pct']:+.1f}%)")
print(f"Is better than baseline? {'‚úÖ YES' if comparison['is_better'] else '‚ùå NO'}")

if comparison['r2z_improvement'] > 0.05:
    print("\nüü¢ EXCELLENT! Spatial attention significantly improved depth prediction.")
    print("   ‚Üí This validates that preserving spatial features is critical for depth!")
    print("   ‚Üí Next: Try combining with multi-head depth predictor or remove gradient detachment")
elif comparison['r2z_improvement'] > 0.02:
    print("\nüü° GOOD! Spatial attention shows promise.")
    print("   ‚Üí Consider full training run to see true potential")
else:
    print("\nüî¥ NO IMPROVEMENT. Try alternative approaches.")
    print("   ‚Üí Next: Multi-head depth predictor or stronger backbone")

## üß™ Experiment 4: Progressive Gradient Unblocking

In [None]:
class ProgressiveGradientNet(StereoTwoStageNet):
    """
    Progressive Gradient Unblocking
    
    Gradually removes .detach() during Stage C to allow end-to-end learning
    while maintaining training stability.
    
    Key concept: Instead of abruptly switching from detached (Stage B) to
    end-to-end (Stage C), smoothly transition using alpha schedule:
    - Epoch 1: alpha=0.8 ‚Üí 80% detached, 20% gradients
    - Epoch 5: alpha=0.0 ‚Üí 0% detached, 100% gradients (full e2e)
    """
    def __init__(self, backbone_name="resnet18", pretrained=True):
        super().__init__(backbone_name, pretrained)
        # Register buffer for detach_alpha (not a trainable parameter)
        self.register_buffer('detach_alpha', torch.tensor(1.0))
        
    def set_detach_alpha(self, alpha):
        """Set the detachment strength (1.0 = full detach, 0.0 = no detach)"""
        self.detach_alpha = torch.tensor(alpha)
        
    def forward(self, left_img, right_img):
        fused_vec, fused_4d = self._fused_vec(left_img, right_img)
        axis_params = self.axis_head(fused_4d)
        origin = axis_params[:, :2]
        direction = F.normalize(axis_params[:, 2:], dim=1)
        
        # Progressive detachment: alpha controls gradient flow
        # alpha=1.0: fully detached (Stages A/B)
        # alpha=0.0: no detachment (end of Stage C)
        alpha = self.detach_alpha.item()
        origin_input = alpha * origin.detach() + (1 - alpha) * origin
        direction_input = alpha * direction.detach() + (1 - alpha) * direction
        
        conditioned = torch.cat([fused_vec, origin_input, direction_input], dim=1)
        od = self.offset_depth_head(conditioned)
        offset_t = od[:, 0:1]
        depth_z = self.softplus(od[:, 1:2])
        
        xy_inter = origin + offset_t * direction
        intersection = torch.cat([xy_inter, depth_z], dim=1)
        
        return {
            "origin": origin,
            "direction": direction,
            "offset_t": offset_t,
            "depth_z": depth_z,
            "intersection": intersection
        }

    def forward_e2e(self, left_img, right_img):
        """Same as forward() - progressive logic handles both cases"""
        return self.forward(left_img, right_img)

print("‚úÖ ProgressiveGradientNet model defined")

In [None]:
# Run progressive gradient unblocking experiment
progressive_result = experiment(
    model_name="progressive_gradient, focal_loss, augmentation",
    model_class=ProgressiveGradientNet,
    train_loader=train_loader,
    val_loader=val_loader,
    description="Progressive gradient unblocking: alpha 1.0‚Üí0.0 during Stage C",
    strategy="progressive",
    use_focal=True,
    focal_alpha=0.25,
    focal_gamma=1.5
)

log_experiment(progressive_result)
show_leaderboard()

# Compare with baseline
comparison = compare_with_baseline(progressive_result)
print(f"\nüìä PROGRESSIVE GRADIENT ANALYSIS:")
print(f"R¬≤_z improvement: {comparison['r2z_improvement']:+.3f} ({comparison['r2z_improvement_pct']:+.1f}%)")
print(f"3D error change: {comparison['error_improvement_mm']:+.2f}mm ({comparison['error_improvement_pct']:+.1f}%)")
print(f"Is better than baseline? {'‚úÖ YES' if comparison['is_better'] else '‚ùå NO'}")

if comparison['r2z_improvement'] > 0.05:
    print("\nüü¢ EXCELLENT! Progressive gradient unblocking significantly improved depth.")
    print("   ‚Üí End-to-end learning is working!")
    print("   ‚Üí Next: Try combining with stronger backbone or depth curriculum")
elif comparison['r2z_improvement'] > 0.02:
    print("\nüü° GOOD! Progressive unblocking shows promise.")
    print("   ‚Üí Consider full training run to see true potential")
else:
    print("\nüî¥ MINIMAL IMPROVEMENT.")
    print("   ‚Üí Gradient flow may not be the main bottleneck")
    print("   ‚Üí Next: Try depth loss curriculum or auxiliary depth supervision")


## üß™ Experiment 5: Auxiliary Depth Supervision

In [19]:
class AuxiliaryDepthNet(StereoTwoStageNet):
    """
    Auxiliary Depth Supervision

    Adds a depth map prediction head for direct depth supervision
    using the existing depth_labels/*.npy files.

    Expected impact: +10-15% R¬≤_z improvement
    Risk: Low
    """
    def __init__(self, backbone_name="resnet18", pretrained=True, seed=42):
        super().__init__(backbone_name, pretrained, seed)

        # Store fused features before pooling for depth map prediction
        # The fused features have shape (B, 256, H, W) before pooling

        # Depth map head: upsamples from pooled features to produce depth map
        # We'll use the 4D fused features before pooling
        self.depth_map_head = nn.Sequential(
            # Start from 256 channels at downsampled resolution
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # 2x upsample
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),   # 4x upsample
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),    # 8x upsample
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 1, kernel_size=3, padding=1),                        # Final depth map
            nn.Softplus(beta=1.0)  # Ensure positive depth values
        )

    def _fused_vec_with_4d(self, left_img, right_img):
        """Modified version that returns both vector and 4D features before pooling"""
        feats_l = self.backbone(left_img)
        feats_r = self.backbone(right_img)
        H, W = feats_l[-1].shape[2:]

        fused_scales = []
        for fl, fr, proj in zip(feats_l, feats_r, self.proj):
            fl = proj(fl)
            fr = proj(fr)
            fl = F.adaptive_avg_pool2d(fl, (H, W))
            fr = F.adaptive_avg_pool2d(fr, (H, W))
            fused_scales.append(torch.cat([fl, fr], dim=1))

        fused_4d = torch.cat(fused_scales, dim=1)  # (B, fused_channels, H, W)

        # Apply 1x1 conv before pooling
        fused_4d_reduced = self.fusion[0](fused_4d)  # (B, 256, H, W)
        fused_4d_reduced = self.fusion[1](fused_4d_reduced)  # ReLU

        # Pool to get vector
        pooled = self.fusion[2](fused_4d_reduced)  # AdaptiveAvgPool2d
        v = pooled.view(pooled.size(0), -1)

        return v, pooled, fused_4d_reduced

    def forward(self, left_img, right_img):
        # Get both vector and 4D features
        fused_vec, fused_pooled, fused_4d = self._fused_vec_with_4d(left_img, right_img)

        # Predict axis
        axis_params = self.axis_head(fused_pooled)
        origin = axis_params[:, :2]
        direction = F.normalize(axis_params[:, 2:], dim=1)

        # Predict intersection (offset + depth)
        conditioned = torch.cat([fused_vec, origin.detach(), direction.detach()], dim=1)
        od = self.offset_depth_head(conditioned)
        offset_t = od[:, 0:1]
        depth_z = self.softplus(od[:, 1:2])

        xy_inter = origin + offset_t * direction
        intersection = torch.cat([xy_inter, depth_z], dim=1)

        # Predict depth map from 4D features
        depth_map = self.depth_map_head(fused_4d)  # (B, 1, H_out, W_out)

        return {
            "origin": origin,
            "direction": direction,
            "offset_t": offset_t,
            "depth_z": depth_z,
            "intersection": intersection,
            "depth_map": depth_map  # Add depth map to outputs
        }

    def forward_e2e(self, left_img, right_img):
        # Same as forward but without detach
        fused_vec, fused_pooled, fused_4d = self._fused_vec_with_4d(left_img, right_img)

        axis_params = self.axis_head(fused_pooled)
        origin = axis_params[:, :2]
        direction = F.normalize(axis_params[:, 2:], dim=1)

        conditioned = torch.cat([fused_vec, origin, direction], dim=1)  # no detach
        od = self.offset_depth_head(conditioned)
        offset_t = od[:, 0:1]
        depth_z = self.softplus(od[:, 1:2])

        xy_inter = origin + offset_t * direction
        intersection = torch.cat([xy_inter, depth_z], dim=1)

        depth_map = self.depth_map_head(fused_4d)

        return {
            "origin": origin,
            "direction": direction,
            "offset_t": offset_t,
            "depth_z": depth_z,
            "intersection": intersection,
            "depth_map": depth_map
        }

print("‚úÖ AuxiliaryDepthNet model defined!")

‚úÖ AuxiliaryDepthNet model defined!


In [20]:
def load_depth_map_for_batch(depth_map, target_size=(56, 56)):
    """Process and resize depth map from batch to target size

    Args:
        depth_map (np.ndarray): The depth map from dataset batch
        target_size (tuple): Target size for resizing, default (56, 56)

    Returns:
        torch.Tensor: Processed depth map of shape (1, H, W)
    """
    # Convert to tensor if needed
    if not isinstance(depth_map, torch.Tensor):
        depth_map = torch.from_numpy(depth_map).float()
    
    # Add batch and channel dimensions if needed
    if depth_map.dim() == 2:  # (H, W)
        depth_map = depth_map.unsqueeze(0).unsqueeze(0)  # (1, 1, H, W)
    elif depth_map.dim() == 3:  # (1, H, W)
        depth_map = depth_map.unsqueeze(0)  # (1, 1, H, W)
    
    # Resize to match model output size (56x56 for 224x224 input with 8x upsampling)
    depth_resized = F.interpolate(depth_map, size=target_size, mode='bilinear', align_corners=False)
    
    return depth_resized.squeeze(0)  # (1, H, W)

In [21]:
def auxiliary_depth_loss(outputs, batch, **kwargs):
    """
    Compute loss with auxiliary depth map supervision
    """
    # Extract auxiliary-specific parameters
    w_depth_map = kwargs.get('w_depth_map', 0.5)
    use_log_depth = kwargs.get('use_log_depth', True)
    eps = kwargs.get('eps', 1e-6)

    # Get main intersection loss
    main_loss = stereo_two_stage_loss(outputs, batch, **kwargs)
    
    # Skip depth map loss if weight is 0
    if w_depth_map == 0:
        return main_loss
        
    # Compute depth map loss if we have the predictions and targets
    depth_map_loss = torch.tensor(0.0, device=main_loss.device)
    if 'depth_map' in outputs and 'gt_depth_map' in batch:
        pred_depth_map = outputs['depth_map']
        gt_depth_map = batch['gt_depth_map']
        
        # Only compute loss on valid depths (> 0.1)
        valid_mask = gt_depth_map > 0.1
        if valid_mask.sum() > 10:
            if use_log_depth:
                pred_log = torch.log(pred_depth_map[valid_mask] + eps)
                gt_log = torch.log(gt_depth_map[valid_mask] + eps)
                depth_map_loss = F.l1_loss(pred_log, gt_log)
            else:
                depth_map_loss = F.l1_loss(pred_depth_map[valid_mask], gt_depth_map[valid_mask])
    
    # Combine losses
    total_loss = main_loss + w_depth_map * depth_map_loss
    
    return total_loss

In [None]:
# Run auxiliary depth experiment
result = experiment(
    model_name="auxiliary_depth, augmentation",
    model_class=AuxiliaryDepthNet,
    train_loader=train_loader,
    val_loader=val_loader,
    description="Auxiliary depth map supervision using depth_labels/*.npy",
    strategy="auxiliary",
    w_depth_map={  # Strategy-specific weight scheduling
        'A': 0.0,    # No depth supervision in axis stage
        'B': 0.5,    # Moderate depth supervision in intersection stage
        'C': 1.0     # Full depth supervision in end-to-end stage
    }
)

log_experiment(result)
show_leaderboard()

## üß™ Experiment 7: Depth Loss Curriculum

**‚ö†Ô∏è Important Note on Loss Interpretation:**
- When using curriculum learning, the **loss value changes meaning across epochs** because w_z increases
- **DO NOT compare loss values** between epochs - they are NOT comparable!
- **Instead, focus on R¬≤_z and 3D error** which remain directly comparable
- Loss may increase even as performance improves due to changing weight schedule

In [None]:
# Depth Loss Curriculum
# Gradually increase depth weight during training for better convergence

def get_curriculum_weights(stage, epoch, total_epochs):
    """
    Progressive depth weight scheduling
    
    FIXES APPLIED:
    - Bug #1 FIXED: Now reaches target values by using (total_epochs - 1) in denominator
    
    Args:
        stage: Training stage ('A', 'B', or 'C')
        epoch: Current epoch (0-indexed)
        total_epochs: Total number of epochs
    
    Returns:
        dict with 'w_z' key containing the curriculum weight
    """
    if stage == 'A':
        # Stage A: Start with baseline, gradually increase (5.0 ‚Üí 7.0)
        progress = epoch / (total_epochs - 1) if total_epochs > 1 else 1.0
        w_z = 5.0 + progress * 2.0  # 5.0 ‚Üí 7.0
        return {'w_z': w_z}
    elif stage == 'B':
        # Stage B: Gradually increase depth importance (7.0 ‚Üí 10.0)
        progress = epoch / (total_epochs - 1) if total_epochs > 1 else 1.0
        w_z = 7.0 + progress * 3.0  # 7.0 ‚Üí 10.0
        return {'w_z': w_z}
    else:
        # Stage C: High depth emphasis, further increase (10.0 ‚Üí 15.0)
        progress = epoch / (total_epochs - 1) if total_epochs > 1 else 1.0
        w_z = 10.0 + progress * 2.0  # 10.0 ‚Üí 15.0
        return {'w_z': w_z}

print("‚úÖ Depth loss curriculum functions defined")

In [None]:
# Run Depth Loss Curriculum experiment
# The ONLY untested method from original Top 5!

result = experiment(
    model_name="depth_curriculum + skip_connections + focal_loss",
    model_class=SkipConnectionNet,
    train_loader=train_loader,
    val_loader=val_loader,
    description="Depth loss curriculum: w_z gradually increases (1‚Üí10‚Üí15)",
    strategy="curriculum",
    use_focal=True,
    focal_alpha=0.25,
    focal_gamma=1.5
)

log_experiment(result)
show_leaderboard()

### üìä Curriculum Learning Metrics Explanation

**Understanding the Output:**

For curriculum experiments, you'll see output like:
```
Epoch 1/15 | Loss: 0.3173 (w_z=1.0) | R¬≤_z: 0.048 | 3D: 12.88mm
Epoch 4/15 | Loss: 0.3523 (w_z=2.9) | R¬≤_z: 0.249 | 3D: 11.24mm
```

**Key Insights:**

1. **Loss Increases ‚â† Performance Degrades**
   - Loss = weighted sum of errors
   - As w_z grows from 1.0 ‚Üí 10.0, same depth error contributes more to total loss
   - Example: depth_error = 0.1 ‚Üí Loss contribution = w_z * 0.1
     - Epoch 1: 1.0 * 0.1 = 0.10
     - Epoch 4: 2.9 * 0.1 = 0.29 (nearly 3x larger!)

2. **Track These Instead:**
   - ‚úÖ **R¬≤_z**: Directly comparable across epochs (correlation-based, not loss-based)
   - ‚úÖ **3D Error (mm)**: Absolute metric, always comparable
   - ‚ùå **Loss**: Changes meaning as w_z changes, NOT comparable

3. **Why This Design Works:**
   - Early epochs: Low w_z ‚Üí model focuses on learning basic depth patterns
   - Later epochs: High w_z ‚Üí depth precision becomes more important
   - Progressive emphasis guides learning from coarse to fine

## üß™ Experiment 8: Multi-Head Depth Predictor

In [None]:
class MultiHeadDepthNet(StereoTwoStageNet):
    def __init__(self, backbone_name="resnet18", pretrained=True):
        super().__init__(backbone_name, pretrained)
        in_dim = 256 + 4  # [fused_vec, origin, direction]
        
        # Each head predicts (offset_t, depth_z)
        # Coarse head: Direct prediction with no hidden layers
        self.head_coarse = nn.Linear(in_dim, 2)
        
        self.head_medium = nn.Sequential(
            nn.Linear(in_dim, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Linear(64, 2)
        )
        
        self.head_fine = nn.Sequential(
            nn.Linear(in_dim, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 2)
        )
        
        # Remove original head since we're using multi-heads
        delattr(self, 'offset_depth_head')

    def forward(self, left_img, right_img):
        fused_vec, fused_4d = self._fused_vec(left_img, right_img)
        axis_params = self.axis_head(fused_4d)
        origin = axis_params[:, :2]
        direction = F.normalize(axis_params[:, 2:], dim=1)
        
        # Prepare input for heads
        conditioned = torch.cat([fused_vec, origin.detach(), direction.detach()], dim=1)
        
        # Get predictions from each head
        pred_coarse = self.head_coarse(conditioned)  # (B, 2)
        pred_medium = self.head_medium(conditioned)  # (B, 2)
        pred_fine = self.head_fine(conditioned)    # (B, 2)
        
        # Weighted average of predictions (fine head gets more weight)
        weights = torch.tensor([0.2, 0.3, 0.5], device=conditioned.device)
        offset_t = (weights[0] * pred_coarse[:, 0:1] + 
                   weights[1] * pred_medium[:, 0:1] + 
                   weights[2] * pred_fine[:, 0:1])
        
        depth_z = self.softplus(weights[0] * pred_coarse[:, 1:2] + 
                              weights[1] * pred_medium[:, 1:2] + 
                              weights[2] * pred_fine[:, 1:2])
        
        # Compute intersection
        xy_inter = origin + offset_t * direction
        intersection = torch.cat([xy_inter, depth_z], dim=1)
        
        return {
            "origin": origin,
            "direction": direction,
            "offset_t": offset_t,
            "depth_z": depth_z,
            "intersection": intersection,
            # Add individual predictions for analysis
            "head_preds": {
                "coarse": pred_coarse,
                "medium": pred_medium,
                "fine": pred_fine
            }
        }

    def forward_e2e(self, left_img, right_img):
        # Same logic but without detach()
        fused_vec, fused_4d = self._fused_vec(left_img, right_img)
        axis_params = self.axis_head(fused_4d)
        origin = axis_params[:, :2]
        direction = F.normalize(axis_params[:, 2:], dim=1)
        
        conditioned = torch.cat([fused_vec, origin, direction], dim=1)
        
        pred_coarse = self.head_coarse(conditioned)
        pred_medium = self.head_medium(conditioned)
        pred_fine = self.head_fine(conditioned)
        
        weights = torch.tensor([0.2, 0.3, 0.5], device=conditioned.device)
        offset_t = (weights[0] * pred_coarse[:, 0:1] + 
                   weights[1] * pred_medium[:, 0:1] + 
                   weights[2] * pred_fine[:, 0:1])
        
        depth_z = self.softplus(weights[0] * pred_coarse[:, 1:2] + 
                              weights[1] * pred_medium[:, 1:2] + 
                              weights[2] * pred_fine[:, 1:2])
        
        xy_inter = origin + offset_t * direction
        intersection = torch.cat([xy_inter, depth_z], dim=1)
        
        return {
            "origin": origin,
            "direction": direction,
            "offset_t": offset_t,
            "depth_z": depth_z,
            "intersection": intersection,
            "head_preds": {
                "coarse": pred_coarse,
                "medium": pred_medium,
                "fine": pred_fine
            }
        }

In [None]:
def multihead_loss(outputs, targets, head_weights=None, **kwargs):
    """Loss function for multi-head models with per-head supervision"""
    if head_weights is None:
        head_weights = {'coarse': 0.2, 'medium': 0.3, 'fine': 0.5}
    
    # Individual head losses
    head_losses = {}
    for head_name, preds in outputs['head_preds'].items():
        # Compute intersection from raw predictions
        # Note: depth_z needs softplus since preds are raw outputs
        offset_t = preds[:, 0:1]
        depth_z = F.softplus(preds[:, 1:2])
        
        head_out = {
            "origin": outputs["origin"],
            "direction": outputs["direction"],
            "offset_t": offset_t,
            "depth_z": depth_z,
            "intersection": torch.cat([
                outputs["origin"] + offset_t * F.normalize(outputs["direction"], dim=1),
                depth_z
            ], dim=1)
        }
        head_losses[head_name] = stereo_two_stage_loss(head_out, targets, **kwargs)
    
    # Use ONLY individual head losses (weighted combination)
    # The ensemble prediction is already a weighted combination of the heads,
    # so we don't need to add a separate ensemble loss
    total_loss = sum(w * head_losses[name] for name, w in head_weights.items())
    
    return total_loss, head_losses

In [None]:
# Run multi-head depth predictor experiment
multihead_result = experiment(
    model_name="multihead_depth, augmentation",
    model_class=MultiHeadDepthNet,
    train_loader=train_loader,
    val_loader=val_loader,
    description="Multi-head depth predictor: ensemble of coarse, medium, fine heads"
)

log_experiment(multihead_result)
show_leaderboard()

In [None]:
class SkipConnectionSpatialAttentionNet(StereoTwoStageNet):
    """
    Combined architecture: Skip connections + Spatial attention for depth prediction.
    
    Triple-path fusion:
    1. Main convolutional path (spatial attention applied)
    2. Skip connection path (direct linear transformation)
    3. Spatial attention maps for depth-relevant region weighting
    
    This combines the gradient flow benefits of skip connections with
    the region-focused learning of spatial attention.
    """
    def __init__(self, backbone_name="resnet18", pretrained=True):
        super().__init__(backbone_name, pretrained)
        
        # Calculate fused channels
        fused_channels = 128 * len(self.feature_dims) * 2
        
        # Spatial attention module
        self.spatial_attention = SpatialAttentionModule(fused_channels)
        
        # Enhanced fusion with increased capacity (for attention-weighted features)
        self.spatial_fusion = nn.Sequential(
            nn.Conv2d(fused_channels, 512, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, 1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1)
        )
        
        # Skip connection path - processes raw concatenated features
        self.skip_fusion = nn.Linear(fused_channels, 256)
        
    def _fused_vec(self, left_img, right_img):
        """Extract fused features with skip connection and spatial attention"""
        feats_l = self.backbone(left_img)
        feats_r = self.backbone(right_img)
        H, W = feats_l[-1].shape[2:]
        
        # Project and fuse features at multiple scales
        fused_scales = []
        for fl, fr, proj in zip(feats_l, feats_r, self.proj):
            fl_proj = proj(fl)
            fr_proj = proj(fr)
            fl_proj = F.adaptive_avg_pool2d(fl_proj, (H, W))
            fr_proj = F.adaptive_avg_pool2d(fr_proj, (H, W))
            fused_scales.append(torch.cat([fl_proj, fr_proj], dim=1))

        # Concatenate all scales: (B, 1280, H, W)
        x = torch.cat(fused_scales, dim=1)
        
        # Apply spatial attention to weight depth-relevant regions
        attended_x, attention_map = self.spatial_attention(x)
        
        # Main fusion path: Convolutional transformation on attended features
        fused_4d = self.spatial_fusion(attended_x)
        fused_vec = fused_4d.view(fused_4d.size(0), -1)  # (B, 256)
        
        # Skip connection path: Direct linear transformation of raw features
        skip_pooled = F.adaptive_avg_pool2d(x, 1).view(x.size(0), -1)  # (B, 1280)
        skip_vec = self.skip_fusion(skip_pooled)  # (B, 256)
        
        # Triple fusion: Combine main attended path + skip path
        combined_vec = fused_vec + skip_vec
        
        return combined_vec, fused_4d

print("‚úÖ SkipConnectionSpatialAttentionNet model defined!")

In [None]:
# Run skip_connections + spatial_attention experiment
skip_spatial_result = experiment(
    model_name="skip_connections + spatial_attention + focal_loss",
    model_class=SkipConnectionSpatialAttentionNet,
    train_loader=train_loader,
    val_loader=val_loader,
    description="Combined skip connections + spatial attention for enhanced depth prediction"
)

log_experiment(skip_spatial_result)
show_leaderboard()

In [None]:
# Run depth_curriculum + skip_connections experiment
depth_curriculum_skip_result = experiment(
    model_name="depth_curriculum + skip_connections",
    model_class=SkipConnectionMultiHeadNet,
    train_loader=train_loader,
    val_loader=val_loader,
    description="Depth loss curriculum: w_z gradually increases + skip connections for better gradient flow",
    strategy="curriculum"
)

log_experiment(depth_curriculum_skip_result)
show_leaderboard()

In [None]:
class SkipConnectionMultiHeadNet(StereoTwoStageNet):
    """
    Combined architecture: Skip connections + Multi-head depth prediction.

    Features:
    1. Skip connections for better gradient flow in fusion
    2. Three depth heads (coarse, medium, fine) for ensemble prediction
    3. Each head gets the skip-enhanced features

    This combines the gradient flow benefits of skip connections with
    the ensemble robustness of multi-head prediction.
    """
    def __init__(self, backbone_name="resnet18", pretrained=True):
        super().__init__(backbone_name, pretrained)

        # Calculate fused channels
        fused_channels = 128 * len(self.feature_dims) * 2

        # Enhanced fusion with skip connections (same as SkipConnectionNet)
        self.fusion = nn.Sequential(
            nn.Conv2d(fused_channels, 512, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, 1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1)
        )

        # Skip connection path
        self.skip_fusion = nn.Linear(fused_channels, 256)

        # Multi-head depth predictors (coarse, medium, fine)
        # Each head takes the skip-enhanced features and predicts depth
        self.depth_head_coarse = nn.Sequential(
            nn.Linear(256 + 4, 128),  # +4 for axis parameters
            nn.ReLU(inplace=True),
            nn.Linear(128, 2)  # (t, z_raw)
        )

        self.depth_head_medium = nn.Sequential(
            nn.Linear(256 + 4, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 2)
        )

        self.depth_head_fine = nn.Sequential(
            nn.Linear(256 + 4, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 2)
        )

    def _fused_vec(self, left_img, right_img):
        """Extract fused features with skip connection enhancement"""
        feats_l = self.backbone(left_img)
        feats_r = self.backbone(right_img)
        H, W = feats_l[-1].shape[2:]

        # Project and fuse features at multiple scales
        fused_scales = []
        for fl, fr, proj in zip(feats_l, feats_r, self.proj):
            fl_proj = proj(fl)
            fr_proj = proj(fr)
            fl_proj = F.adaptive_avg_pool2d(fl_proj, (H, W))
            fr_proj = F.adaptive_avg_pool2d(fr_proj, (H, W))
            fused_scales.append(torch.cat([fl_proj, fr_proj], dim=1))

        # Concatenate all scales: (B, 1280, H, W)
        x = torch.cat(fused_scales, dim=1)

        # Main fusion path: Convolutional transformation
        fused_4d = self.fusion(x)
        fused_vec = fused_4d.view(fused_4d.size(0), -1)  # (B, 256)

        # Skip connection path: Direct linear transformation
        skip_pooled = F.adaptive_avg_pool2d(x, 1).view(x.size(0), -1)  # (B, 1280)
        skip_vec = self.skip_fusion(skip_pooled)  # (B, 256)

        # Combine both pathways for enhanced features
        combined_vec = fused_vec + skip_vec

        return combined_vec, fused_4d

    def forward(self, left_img, right_img):
        """Forward pass with multi-head depth prediction"""
        fused_vec, fused_4d = self._fused_vec(left_img, right_img)

        # Axis prediction (shared across heads)
        axis_params = self.axis_head(fused_4d)
        origin = axis_params[:, :2]
        direction = F.normalize(axis_params[:, 2:], dim=1)

        # Condition each depth head on the axis and skip-enhanced features
        conditioned = torch.cat([fused_vec, origin.detach(), direction.detach()], dim=1)

        # Multi-head depth predictions (raw outputs, before softplus)
        depth_coarse_raw = self.depth_head_coarse(conditioned)
        depth_medium_raw = self.depth_head_medium(conditioned)
        depth_fine_raw = self.depth_head_fine(conditioned)

        # Compute ensemble depth prediction (weighted average of heads)
        # Each head outputs (B, 2): [offset_t_raw, depth_z_raw]
        depth_ensemble_raw = (0.2 * depth_coarse_raw + 0.3 * depth_medium_raw + 0.5 * depth_fine_raw)

        # Extract offset and depth from ensemble
        offset_t = depth_ensemble_raw[:, 0:1]
        depth_z_raw = depth_ensemble_raw[:, 1:2]
        depth_z = F.softplus(depth_z_raw)

        # Compute intersection point
        xy_inter = origin + offset_t * direction
        intersection = torch.cat([xy_inter, depth_z], dim=1)

        # Return structure expected by multihead_loss and validate
        return {
            "origin": origin,
            "direction": direction,
            "offset_t": offset_t,
            "depth_z": depth_z,
            "intersection": intersection,
            "head_preds": {
                "coarse": depth_coarse_raw,
                "medium": depth_medium_raw,
                "fine": depth_fine_raw
            }
        }

    def forward_e2e(self, left_img, right_img):
        """End-to-end forward pass with multi-head depth prediction (no gradient detachment)"""
        fused_vec, fused_4d = self._fused_vec(left_img, right_img)

        # Axis prediction (shared across heads)
        axis_params = self.axis_head(fused_4d)
        origin = axis_params[:, :2]
        direction = F.normalize(axis_params[:, 2:], dim=1)

        # Condition each depth head on the axis and skip-enhanced features (no detach for e2e)
        conditioned = torch.cat([fused_vec, origin, direction], dim=1)

        # Multi-head depth predictions (raw outputs, before softplus)
        depth_coarse_raw = self.depth_head_coarse(conditioned)
        depth_medium_raw = self.depth_head_medium(conditioned)
        depth_fine_raw = self.depth_head_fine(conditioned)

        # Compute ensemble depth prediction (weighted average of heads)
        # Each head outputs (B, 2): [offset_t_raw, depth_z_raw]
        depth_ensemble_raw = (0.2 * depth_coarse_raw + 0.3 * depth_medium_raw + 0.5 * depth_fine_raw)

        # Extract offset and depth from ensemble
        offset_t = depth_ensemble_raw[:, 0:1]
        depth_z_raw = depth_ensemble_raw[:, 1:2]
        depth_z = F.softplus(depth_z_raw)

        # Compute intersection point
        xy_inter = origin + offset_t * direction
        intersection = torch.cat([xy_inter, depth_z], dim=1)

        # Return structure expected by multihead_loss and validate
        return {
            "origin": origin,
            "direction": direction,
            "offset_t": offset_t,
            "depth_z": depth_z,
            "intersection": intersection,
            "head_preds": {
                "coarse": depth_coarse_raw,
                "medium": depth_medium_raw,
                "fine": depth_fine_raw
            }
        }

print("‚úÖ SkipConnectionMultiHeadNet model defined!")

In [None]:
# Run skip_connections + multihead_depth experiment
skip_multihead_result = experiment(
    model_name="skip_connections + multihead_depth",
    model_class=SkipConnectionMultiHeadNet,
    train_loader=train_loader,
    val_loader=val_loader,
    description="Skip connections + multi-head depth predictor: ensemble of coarse/medium/fine heads with gradient flow enhancement",
    strategy="MultiHead",
    head_weights={'coarse': 0.2, 'medium': 0.3, 'fine': 0.5},
    use_focal=True,
    focal_alpha=0.25,
    focal_gamma=1.5
)

log_experiment(skip_multihead_result)
show_leaderboard()

In [None]:
# Run skip_connections + multihead_depth experiment with curriculum learning
def experiment_multihead_curriculum(model_name, model_class, train_loader, val_loader, description="",
                                 head_weights={"coarse": 0.2, "medium": 0.3, "fine": 0.5},
                                 use_focal=False, focal_alpha=0.25, focal_gamma=1.5):
    """Combined MultiHead + Curriculum experiment"""
    # Reset seed for reproducibility
    set_seed(42)
    
    device = torch.device("mps" if torch.backends.mps.is_available() else
                         "cuda" if torch.cuda.is_available() else "cpu")

    print(f"{'='*60}\n üöÄ EXPERIMENT: {model_name}\n{'='*60}\nDescription: {description}\n        Device: {device}\n")

    start_time = time.time()
    model = model_class().to(device)
    
    # MultiHead + Curriculum setup
    print(f"üìä Multi-Head + Curriculum Configuration:")
    print(f"   Head Weights: coarse={head_weights['coarse']:.1f}, medium={head_weights['medium']:.1f}, fine={head_weights['fine']:.1f}")
    print(f"   Curriculum: Progressive depth weight increase (5.0‚Üí7.0‚Üí10.0‚Üí15.0)")
    
    # Stage A: Axis pretraining
    print(f"{'='*60}\nSTAGE A: Axis Pretraining ({QUICK_CONFIG['epochs']['A']} epochs)\n{'='*60}")
    logs_A = train_stage_curriculum_multihead(model, train_loader, val_loader, 'A', 
                                           QUICK_CONFIG['epochs']['A'], device,
                                           head_weights=head_weights,
                                           use_focal=use_focal, focal_alpha=focal_alpha, focal_gamma=focal_gamma)
    
    # Stage B: Depth training with curriculum
    print(f"{'='*60}\nSTAGE B: MultiHead + Curriculum ({QUICK_CONFIG['epochs']['B']} epochs)\n{'='*60}")
    logs_B = train_stage_curriculum_multihead(model, train_loader, val_loader, 'B', 
                                           QUICK_CONFIG['epochs']['B'], device,
                                           head_weights=head_weights,
                                           use_focal=use_focal, focal_alpha=focal_alpha, focal_gamma=focal_gamma)
    
    # Stage C: End-to-end fine-tuning
    print(f"{'='*60}\nSTAGE C: End-to-End Fine-tuning ({QUICK_CONFIG['epochs']['C']} epochs)\n{'='*60}")
    logs_C = train_stage_curriculum_multihead(model, train_loader, val_loader, 'C', 
                                           QUICK_CONFIG['epochs']['C'], device,
                                           head_weights=head_weights,
                                           use_focal=use_focal, focal_alpha=focal_alpha, focal_gamma=focal_gamma)
    
    # Final evaluation
    final_metrics = validate(model, val_loader, device)
    runtime = time.time() - start_time
    
    result = {
        "model_name": model_name,
        "r2_z": final_metrics["r2xyz"][2],
        "3d_error_mm": final_metrics.get("e3d", 0),
        "2d_error_px": final_metrics.get("e2d", 0),
        "runtime_min": runtime / 60,
        "decision": "COMBINED",
        "description": description
    }

    print(f"‚úÖ EXPERIMENT COMPLETE: {model_name}")
    print(f"   R¬≤_z: {result['r2_z']:.3f}, 3D Error: {result['3d_error_mm']:.1f}mm")
    print(f"   Runtime: {result['runtime_min']:.1f} minutes")
    
    return result

# Custom training stage for MultiHead + Curriculum
def train_stage_curriculum_multihead(model, train_loader, val_loader, stage, max_epochs, device,
                                   head_weights, use_amp=False, use_focal=False, focal_alpha=0.25, focal_gamma=1.5):
    """Train one stage with MultiHead loss and curriculum depth weighting"""
    e2e = configure_model_for_stage(model, stage)
    optim = make_optim(model, stage)
    scaler = GradScaler(enabled=use_amp)
    
    # Loss weights setup
    loss_weights = QUICK_CONFIG['stage_loss_weights'][stage].copy()
    loss_weights.update(QUICK_CONFIG['global_loss_weights'])
    loss_weights['head_weights'] = head_weights
    
    history = {'train_loss': [], 'val_loss': [], 'r2_z': []}
    
    for epoch in range(1, max_epochs + 1):
        # Apply curriculum weight for this epoch
        loss_weights['w_z'] = get_curriculum_weights(stage, epoch-1, max_epochs)['w_z']
        
        # Training loop
        model.train()
        train_metrics = defaultdict(float)
        
        for batch in train_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            
            # Forward pass
            if e2e:
                out = model.forward_e2e(batch['left'], batch['right'])
            else:
                out = model(batch['left'], batch['right'])
            
            # MultiHead loss with curriculum
            loss, head_losses = multihead_loss(out, batch, **loss_weights)
            
            # Track losses
            train_metrics['loss'] += loss.item()
            for name, hloss in head_losses.items():
                train_metrics[f"{name}_loss"] += hloss.item()
            
            # Backward pass
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
            optim.zero_grad()
        
        # Validation
        val_metrics = validate(model, val_loader, device)
        
        # Logging
        avg_loss = train_metrics['loss'] / len(train_loader)
        r2_z = val_metrics['r2xyz'][2]
        w_z = loss_weights['w_z']
        
        print(f"Epoch {epoch}/{max_epochs} | Loss: {avg_loss:.4f} | Val R¬≤_z: {r2_z:.3f} | w_z: {w_z:.1f}")
        
        history['train_loss'].append(avg_loss)
        history['val_loss'].append(val_metrics['val_loss'])
        history['r2_z'].append(r2_z)
    
    return history

# Run the combined experiment
skip_multihead_curriculum_result = experiment_multihead_curriculum(
    model_name="skip_connections + multihead_depth + curriculum",
    model_class=SkipConnectionMultiHeadNet,
    train_loader=train_loader,
    val_loader=val_loader,
    description="Skip connections + multi-head depth + curriculum learning: progressive depth weight increase",
    head_weights={'coarse': 0.2, 'medium': 0.3, 'fine': 0.5},
    use_focal=True,
    focal_alpha=0.25,
    focal_gamma=1.5
)

log_experiment(skip_multihead_curriculum_result)
show_leaderboard()


In [None]:
# Run skip_connections + multihead_depth experiment
skip_multihead_result = experiment(
    model_name="skip_connections + multihead_depth",
    model_class=SkipConnectionMultiHeadNet,
    train_loader=train_loader,
    val_loader=val_loader,
    description="Skip connections + multi-head depth predictor: ensemble of coarse/medium/fine heads with gradient flow enhancement",
    strategy="MultiHead",
    head_weights={'coarse': 0.2, 'medium': 0.3, 'fine': 0.5}
)

log_experiment(skip_multihead_result)
show_leaderboard()

## Experiment: Skip Connections + Auxiliary Depth

**Hypothesis**: Combining skip connections (better gradient flow) with auxiliary depth supervision (richer spatial features) should improve performance.

**Architecture**:
- Skip connections: Dual-path fusion (main conv path + skip linear path)
- Auxiliary depth: Separate depth head with transposed convolutions
- Expected R¬≤_z: 0.55+ (combining strengths of both approaches)

In [22]:
class SkipConnectionAuxiliaryDepthNet(StereoTwoStageNet):
    """Skip connections + Auxiliary depth supervision"""
    
    def __init__(self, backbone_name="resnet18", pretrained=True):
        super().__init__(backbone_name, pretrained)
        
        # Skip connection fusion - preserve spatial dimensions
        fused_channels = 128 * len(self.feature_dims) * 2
        self.fusion_conv = nn.Sequential(
            nn.Conv2d(fused_channels, 512, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, 1),
            nn.ReLU(inplace=True),
        )
        self.skip_fusion = nn.Linear(fused_channels, 256)
        
        # Auxiliary depth head - starts from 7x7 feature map
        self.depth_map_head = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # 7x7 -> 14x14
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),   # 14x14 -> 28x28
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),    # 28x28 -> 56x56
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 1, kernel_size=3, padding=1),                        # 56x56 -> 56x56
            nn.Softplus(beta=1.0)  # Ensure positive depth values
        )
    
    def _fused_vec_with_4d(self, left_img, right_img):
        feats_l = self.backbone(left_img)
        feats_r = self.backbone(right_img)
        H, W = feats_l[-1].shape[2:]
        
        fused_scales = []
        for fl, fr, proj in zip(feats_l, feats_r, self.proj):
            fl_proj = proj(fl)
            fr_proj = proj(fr)
            fl_proj = F.adaptive_avg_pool2d(fl_proj, (H, W))
            fr_proj = F.adaptive_avg_pool2d(fr_proj, (H, W))
            fused_scales.append(torch.cat([fl_proj, fr_proj], dim=1))
        
        x = torch.cat(fused_scales, dim=1)  # (B, 1280, H, W) - H,W typically 7x7
        fused_4d = self.fusion_conv(x)      # (B, 256, H, W)
        
        # Skip connection
        skip_pooled = F.adaptive_avg_pool2d(x, 1).view(x.size(0), -1)
        skip_vec = self.skip_fusion(skip_pooled)
        
        # Pooled vector for intersection head
        fused_vec = F.adaptive_avg_pool2d(fused_4d, 1).view(fused_4d.size(0), -1)
        combined_vec = fused_vec + skip_vec  # Residual addition
        
        return combined_vec, fused_4d
    
    def forward(self, left_img, right_img):
        fused_vec, fused_4d = self._fused_vec_with_4d(left_img, right_img)
        
        pooled = F.adaptive_avg_pool2d(fused_4d, 1)
        axis_params = self.axis_head(pooled)
        origin = axis_params[:, :2]
        direction = F.normalize(axis_params[:, 2:], dim=1)
        
        conditioned = torch.cat([fused_vec, origin.detach(), direction.detach()], dim=1)
        od = self.offset_depth_head(conditioned)
        offset_t = od[:, 0:1]
        depth_z = self.softplus(od[:, 1:2])
        
        xy_inter = origin + offset_t * direction
        intersection = torch.cat([xy_inter, depth_z], dim=1)
        depth_map = self.depth_map_head(fused_4d)
        
        return {
            "origin": origin,
            "direction": direction,
            "offset_t": offset_t,
            "depth_z": depth_z,
            "intersection": intersection,
            "depth_map": depth_map
        }
    
    def forward_e2e(self, left_img, right_img):
        fused_vec, fused_4d = self._fused_vec_with_4d(left_img, right_img)
        
        pooled = F.adaptive_avg_pool2d(fused_4d, 1)
        axis_params = self.axis_head(pooled)
        origin = axis_params[:, :2]
        direction = F.normalize(axis_params[:, 2:], dim=1)
        
        conditioned = torch.cat([fused_vec, origin, direction], dim=1)
        od = self.offset_depth_head(conditioned)
        offset_t = od[:, 0:1]
        depth_z = self.softplus(od[:, 1:2])
        
        xy_inter = origin + offset_t * direction
        intersection = torch.cat([xy_inter, depth_z], dim=1)
        depth_map = self.depth_map_head(fused_4d)
        
        return {
            "origin": origin,
            "direction": direction,
            "offset_t": offset_t,
            "depth_z": depth_z,
            "intersection": intersection,
            "depth_map": depth_map
        }

print("‚úÖ SkipConnectionAuxiliaryDepthNet model defined (FIXED)!")

‚úÖ SkipConnectionAuxiliaryDepthNet model defined (FIXED)!


In [23]:
# Run skip_connections + auxiliary_depth experiment
result = experiment(
    model_name="skip_connections + auxiliary_depth",
    model_class=SkipConnectionAuxiliaryDepthNet,
    train_loader=train_loader,
    val_loader=val_loader,
    description="Skip connections for gradient flow + auxiliary depth supervision",
    strategy="auxiliary",
    w_depth_map={'A': 0.0, 'B': 0.5, 'C': 1.0}
)

log_experiment(result)
show_leaderboard()
comparison = compare_with_baseline(result)


üöÄ EXPERIMENT: skip_connections + auxiliary_depth
Description: Skip connections for gradient flow + auxiliary depth supervision
Device: mps | Expected time: ~5-15 minutes


STAGE A: Axis Pretraining (20 epochs)

Epoch 1/20 | Loss: 3.1604 | Val: 1.6880 | Ang: 27.82¬∞

Epoch 2/20 | Loss: 1.3264 | Val: 1.4528 | Ang: 28.10¬∞

Epoch 3/20 | Loss: 1.1800 | Val: 1.3516 | Ang: 26.89¬∞

Epoch 4/20 | Loss: 0.9812 | Val: 0.6296 | Ang: 16.46¬∞

Epoch 5/20 | Loss: 0.2929 | Val: 0.2112 | Ang: 8.66¬∞

Epoch 6/20 | Loss: 0.1895 | Val: 0.1189 | Ang: 3.55¬∞

Epoch 7/20 | Loss: 0.1384 | Val: 0.1305 | Ang: 6.07¬∞

Epoch 8/20 | Loss: 0.1214 | Val: 0.1022 | Ang: 3.69¬∞

Epoch 9/20 | Loss: 0.1042 | Val: 0.0883 | Ang: 2.59¬∞

Epoch 10/20 | Loss: 0.0959 | Val: 0.0818 | Ang: 2.86¬∞

Epoch 11/20 | Loss: 0.0866 | Val: 0.0820 | Ang: 3.05¬∞

Epoch 12/20 | Loss: 0.0742 | Val: 0.0613 | Ang: 2.30¬∞

Epoch 13/20 | Loss: 0.0679 | Val: 0.0625 | Ang: 3.87¬∞

Epoch 14/20 | Loss: 0.0621 | Val: 0.0526 | Ang: 3.22¬∞

Epoch 