# Full Training Runs - Best Models

**Created**: 2025-11-06 17:45

## Purpose
Full-length training runs of the best models identified in `quick_experiments.ipynb`.

## Top Models Selected:
1. **Skip Connections + Focal Loss + Augmentation** (R²_z: 0.5703)
2. **Skip Connections + Spatial Attention** (R²_z: 0.5540)
3. **Skip Connections + Auxiliary Depth** (R²_z: 0.5362)
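For reference, R²_z is the coefficient of determination on the depth coordinate z. The NumPy sketch below shows the standard definition; the sample values are made up for illustration and are not results from these experiments. The `compute_metrics` helper later in this notebook divides the mean squared error by `torch.var`, which defaults to the sample variance (N-1 denominator), so on small batches its value comes out slightly higher than the standard R². The `ddof` argument below reproduces both variants.

```python
import numpy as np

def r2_score(z_true, z_pred, ddof=0):
    """R^2 = 1 - MSE / Var(z_true); ddof=0 gives the standard definition."""
    z_true = np.asarray(z_true, dtype=np.float64)
    z_pred = np.asarray(z_pred, dtype=np.float64)
    mse = np.mean((z_pred - z_true) ** 2)
    var = np.var(z_true, ddof=ddof)
    return 1.0 - mse / var

# Hypothetical normalized depths, for illustration only
z_true = [0.2, 0.4, 0.6, 0.8]
z_pred = [0.25, 0.35, 0.65, 0.75]

r2_standard = r2_score(z_true, z_pred, ddof=0)  # population variance
r2_sample = r2_score(z_true, z_pred, ddof=1)    # sample variance, like torch.var's default
print(f"standard R^2: {r2_standard:.4f}, sample-variance variant: {r2_sample:.4f}")
```

With a batch of N samples the two variants are related by `1 - r2_sample = (1 - r2_standard) * (N - 1) / N`, so the gap shrinks as the batch grows.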

## Training Configuration:
- **Stage A**: 20 epochs (vs 10 in quick experiments)
- **Stage B**: 30 epochs (vs 15 in quick experiments)
- **Stage C**: 40 epochs (vs 20 in quick experiments)
- **Early Stopping**: Enabled with patience=5
- **Checkpointing**: Save best model per stage
- **Target**: R²_z > 0.60
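The schedule above doubles each stage relative to the quick experiments. A small sketch of that budget, assuming a plain dictionary layout (the names `QUICK_EPOCHS` and `FULL_EPOCHS` are illustrative and do not appear in the notebook):

```python
# Epoch budgets per stage, taken from the configuration list above.
# The dictionary names are hypothetical; the notebook passes these as function arguments.
QUICK_EPOCHS = {"A": 10, "B": 15, "C": 20}
FULL_EPOCHS = {"A": 20, "B": 30, "C": 40}

# Each full stage runs twice as long as its quick counterpart.
ratios = {stage: FULL_EPOCHS[stage] / QUICK_EPOCHS[stage] for stage in FULL_EPOCHS}

# Upper bound on total epochs; early stopping (patience=5) may end a stage sooner.
max_total_epochs = sum(FULL_EPOCHS.values())

print(ratios)
print(max_total_epochs)
```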

## Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.models as models
import timm

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import json
import csv
from datetime import datetime
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')

# Create directories for outputs
CHECKPOINT_DIR = Path('checkpoints')
CHECKPOINT_DIR.mkdir(exist_ok=True)
print(f'Checkpoint directory: {CHECKPOINT_DIR}')

## Load Dataset

Dataset class and data loaders

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import glob
import os

# Training transformations - MATCH quick_experiments EXACTLY
train_transformations = [
    A.LongestMaxSize(max_size=224),
    A.PadIfNeeded(min_height=224, min_width=224),
    A.Normalize(),  # Uses default ImageNet normalization
    ToTensorV2()
]

train_transform = A.Compose(
    train_transformations,
    keypoint_params=A.KeypointParams(format='xy', remove_invisible=False),
    additional_targets={'image_right': 'image', 'depth_map': 'mask'}
)

# Validation transform (no augmentation) - MATCH quick_experiments
val_transformations = [
    A.LongestMaxSize(max_size=224),
    A.PadIfNeeded(min_height=224, min_width=224),
    A.Normalize(),
    ToTensorV2()
]

val_transform = A.Compose(
    val_transformations,
    keypoint_params=A.KeypointParams(format='xy', remove_invisible=False),
    additional_targets={'image_right': 'image', 'depth_map': 'mask'}
)

print('✅ Augmentation transforms defined (MATCHING quick_experiments)')
print('   - Using LongestMaxSize + PadIfNeeded (preserves aspect ratio)')
print('   - Using default Normalize() parameters')
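The geometry behind `LongestMaxSize` followed by `PadIfNeeded` can be sketched with plain arithmetic. This is an approximation under stated assumptions: padding is centered, the 640x480 input size is hypothetical, and the library's exact rounding may differ by a pixel. The helper names are illustrative and not part of albumentations.

```python
def letterbox_geometry(width, height, max_size=224):
    """Approximate geometry of LongestMaxSize + centered PadIfNeeded (illustrative helper)."""
    scale = max_size / max(width, height)
    new_w, new_h = round(width * scale), round(height * scale)
    # Remaining pixels are split between both sides; the left/top side gets the floor.
    pad_left = (max_size - new_w) // 2
    pad_top = (max_size - new_h) // 2
    return scale, (new_w, new_h), (pad_left, pad_top)

def map_keypoint(x, y, scale, pad_left, pad_top):
    """Map an original-image keypoint into the resized and padded image."""
    return x * scale + pad_left, y * scale + pad_top

# Hypothetical 640x480 input: the long side shrinks to 224, the short side to 168.
scale, resized, (pad_left, pad_top) = letterbox_geometry(640, 480)
kp_x, kp_y = map_keypoint(320, 240, scale, pad_left, pad_top)
print(resized, (pad_left, pad_top), (kp_x, kp_y))
```

Because keypoints are passed through the same `A.Compose` call as the images, the library applies this mapping to the intersection point and the probe-axis points automatically; the sketch only makes the arithmetic visible.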

In [None]:
def compute_pca_axis(points_2d):
    """
    Compute the principal axis (origin, direction) from 2D probe points using PCA.
    
    Args:
        points_2d: (N, 2) array of 2D points
    
    Returns:
        origin: (2,) center point
        direction: (2,) unit direction vector
    """
    from sklearn.decomposition import PCA
    mean = points_2d.mean(axis=0)
    pca = PCA(n_components=1)
    pca.fit(points_2d)
    direction = pca.components_[0]
    return mean, direction


def find_valid_depth(depth_map, x_idx, y_idx, max_radius=10):
    """Find nearest valid depth if center point has invalid depth"""
    for r in range(1, max_radius):
        for dy in range(-r, r+1):
            for dx in range(-r, r+1):
                ny = np.clip(y_idx + dy, 0, depth_map.shape[0] - 1)
                nx = np.clip(x_idx + dx, 0, depth_map.shape[1] - 1)
                d = depth_map[ny, nx]
                if d > 0:
                    return d
    return 1e-6


class StereoIntersectionDataset(torch.utils.data.Dataset):
    """Dataset for stereo intersection detection with probe axis - EXACT match to quick_experiments"""
    
    def __init__(self, root_dir, transform=None, max_depth=220.0):
        self.left_img_paths = sorted(glob.glob(os.path.join(root_dir, "left", "images", "*.jpg")))
        self.right_img_paths = sorted(glob.glob(os.path.join(root_dir, "right", "images", "*.jpg")))
        self.probe_axis_paths = sorted(glob.glob(os.path.join(root_dir, "left", "probe_axis", "*.txt")))
        self.depth_map_paths = sorted(glob.glob(os.path.join(root_dir, "left", "depth_labels", "*.npy")))
        
        self.transform = transform
        self.max_depth = max_depth
        self.split = os.path.basename(root_dir.rstrip("/"))

        # Load ground truth x,y from CenterPt.txt
        gt_xy = []
        with open(os.path.join(root_dir, "left", "labels", "CenterPt.txt"), 'r') as f:
            for line in f:
                _, x_str, y_str = line.strip().split(",")
                gt_xy.append((float(x_str), float(y_str)))
        self.gt_xy = np.array(gt_xy, dtype=np.float32)

    def __len__(self):
        return len(self.left_img_paths)

    def __getitem__(self, idx):
        left_img = cv2.imread(self.left_img_paths[idx])
        right_img = cv2.imread(self.right_img_paths[idx])
        left_img = cv2.cvtColor(left_img, cv2.COLOR_BGR2RGB)
        right_img = cv2.cvtColor(right_img, cv2.COLOR_BGR2RGB)
        depth_map = np.load(self.depth_map_paths[idx])
        points = np.loadtxt(self.probe_axis_paths[idx])

        if self.transform:
            transformed = self.transform(
                image=left_img,
                image_right=right_img,
                keypoints=[self.gt_xy[idx]] + points.tolist(),
                depth_map=depth_map
            )
            left_img = transformed["image"]
            right_img = transformed["image_right"]
            keypoints = transformed["keypoints"]
            depth_map = transformed["depth_map"]
        else:
            keypoints = [self.gt_xy[idx]] + points.tolist()
            # Convert to tensors manually if no transform
            left_img = torch.from_numpy(left_img).permute(2, 0, 1).float() / 255.0
            right_img = torch.from_numpy(right_img).permute(2, 0, 1).float() / 255.0
            depth_map = torch.from_numpy(depth_map).float()

        probe_axis_mean, direction = compute_pca_axis(np.array(keypoints[1:]))

        _, img_h, img_w = left_img.shape
        
        # Normalize coordinates by image dimensions
        probe_axis_mean = np.array(probe_axis_mean, dtype=np.float32) / np.array([img_w, img_h], dtype=np.float32)
        probe_axis = torch.tensor(probe_axis_mean, dtype=torch.float32)
        probe_dir = torch.tensor(direction, dtype=torch.float32)
        
        # Get depth at center point (use y, x order for numpy arrays!)
        x, y = keypoints[0]
        x_idx = int(np.clip(round(x), 0, depth_map.shape[1] - 1 if depth_map.dim() == 2 else depth_map.shape[2] - 1))
        y_idx = int(np.clip(round(y), 0, depth_map.shape[0] - 1 if depth_map.dim() == 2 else depth_map.shape[1] - 1))
        
        if depth_map.dim() == 3:
            z = depth_map[0, y_idx, x_idx].item()
        else:
            z = depth_map[y_idx, x_idx].item()
        
        # Handle invalid depth
        if z == 0.0:
            z = find_valid_depth(depth_map.numpy() if isinstance(depth_map, torch.Tensor) else depth_map, x_idx, y_idx)
        if z <= 0:
            z = 1e-6
        
        # Normalize intersection coordinates and depth
        intersect_norm = np.array(keypoints[0]) / np.array([img_w, img_h])
        target = torch.tensor(intersect_norm.tolist() + [z/self.max_depth], dtype=torch.float32)

        # Match quick_experiments return format
        return {
            "left_img": left_img,
            "right_img": right_img,
            "origin": probe_axis,
            "direction": probe_dir,
            "intersection": target,
            "depth_label": depth_map / self.max_depth  # Normalize depth map for auxiliary training
        }


print('✅ Dataset class defined (EXACT match to quick_experiments)')
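The probe axis computed by `compute_pca_axis` is the first principal direction of the probe points. A minimal NumPy sketch of the same idea via SVD follows, with made-up points on the line y = 2x. One caveat is worth keeping in mind: the sign of a principal direction is a convention of the solver rather than a property of the data, so `d` and `-d` describe the same axis.

```python
import numpy as np

def principal_axis(points_2d):
    """Return (mean, unit direction) of the best-fit line through 2D points."""
    pts = np.asarray(points_2d, dtype=np.float64)
    mean = pts.mean(axis=0)
    # Rows of vt are principal directions, sorted by decreasing singular value.
    _, _, vt = np.linalg.svd(pts - mean, full_matrices=False)
    return mean, vt[0]

# Hypothetical probe points lying exactly on y = 2x
points = np.array([[0.0, 0.0], [1.0, 2.0], [2.0, 4.0], [3.0, 6.0]])
origin, direction = principal_axis(points)

# Compare with the known line direction, ignoring the sign ambiguity.
expected = np.array([1.0, 2.0]) / np.sqrt(5.0)
alignment = abs(float(direction @ expected))
print(origin, direction, alignment)
```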

In [None]:
# Dataset configuration
DATA_DIR = Path('data/processed')
BATCH_SIZE = 16

print('Loading datasets...')

# Training dataset with augmentation
train_dataset = StereoIntersectionDataset(
    root_dir=str(DATA_DIR / 'train'),
    transform=train_transform,
    max_depth=220.0
)

# Validation dataset without augmentation
val_dataset = StereoIntersectionDataset(
    root_dir=str(DATA_DIR / 'val'),
    transform=val_transform,
    max_depth=220.0
)

# Data loaders
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    pin_memory=True
)

val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=True
)

print(f'Train samples: {len(train_dataset)}')
print(f'Val samples: {len(val_dataset)}')
print(f'Batch size: {BATCH_SIZE}')
print(f'Train batches: {len(train_loader)}')
print(f'Val batches: {len(val_loader)}')

## Model Architectures

Base class and top-performing model variants
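The model comments below refer to 1280 fused channels and a 7x7 feature map. A short sketch of where those numbers come from, assuming the timm ResNet18 feature pyramid with `features_only=True` exposes five stages with the channel counts and strides listed here (an assumption worth checking against `backbone.feature_info` in your environment):

```python
# Assumed timm ResNet18 feature pyramid: channels and strides per stage.
feature_dims = [64, 64, 128, 256, 512]
reductions = [2, 4, 8, 16, 32]

PROJ_CHANNELS = 128  # each scale is projected to 128 channels by a 1x1 conv
NUM_VIEWS = 2        # left and right images are concatenated per scale

# Same formula as in the model: 128 * number of scales * 2 views
fused_channels = PROJ_CHANNELS * len(feature_dims) * NUM_VIEWS

# All scales are pooled to the coarsest resolution: 224 / 32 = 7
coarsest_hw = 224 // reductions[-1]

print(fused_channels, coarsest_hw)
```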

In [None]:
class StereoTwoStageNet(nn.Module):
    """Baseline model - Original architecture without modifications"""
    def __init__(self, backbone_name="resnet18", pretrained=True, seed=42):
        super().__init__()
        
        self.backbone = timm.create_model(backbone_name, pretrained=pretrained, features_only=True)
        self.feature_dims = [f["num_chs"] for f in self.backbone.feature_info]

        self.proj = nn.ModuleList([nn.Conv2d(c, 128, 1) for c in self.feature_dims])

        fused_channels = 128 * len(self.feature_dims) * 2  
        self.fusion = nn.Sequential(
            nn.Conv2d(fused_channels, 256, 1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1)
        )

        self.axis_head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 4)  # (x0, y0, dx, dy)
        )

        self.offset_depth_head = nn.Sequential(
            nn.Linear(256 + 4, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 2)  # (t, z_raw)
        )

        self.softplus = nn.Softplus(beta=1.0)

    def _fused_vec(self, left_img, right_img):
        feats_l = self.backbone(left_img)
        feats_r = self.backbone(right_img)
        H, W = feats_l[-1].shape[2:]
        
        fused_scales = []
        for fl, fr, proj in zip(feats_l, feats_r, self.proj):
            fl = proj(fl)
            fr = proj(fr)
            fl = F.adaptive_avg_pool2d(fl, (H, W))
            fr = F.adaptive_avg_pool2d(fr, (H, W))
            fused_scales.append(torch.cat([fl, fr], dim=1))

        x = torch.cat(fused_scales, dim=1)
        x = self.fusion(x)
        v = x.view(x.size(0), -1)
        return v, x
    
    def forward(self, left_img, right_img):
        fused_vec, fused_4d = self._fused_vec(left_img, right_img)
        axis_params = self.axis_head(fused_4d)
        origin = axis_params[:, :2]
        direction = F.normalize(axis_params[:, 2:], dim=1)
        
        conditioned = torch.cat([fused_vec, origin.detach(), direction.detach()], dim=1)
        od = self.offset_depth_head(conditioned)
        offset_t = od[:, 0:1]
        depth_z = self.softplus(od[:, 1:2])
        
        xy_inter = origin + offset_t * direction
        intersection = torch.cat([xy_inter, depth_z], dim=1)
        
        return {
            "origin": origin,
            "direction": direction,
            "offset_t": offset_t,
            "depth_z": depth_z,
            "intersection": intersection
        }

    def forward_e2e(self, left_img, right_img):
        fused_vec, fused_4d = self._fused_vec(left_img, right_img)
        axis_params = self.axis_head(fused_4d)
        origin = axis_params[:, :2]
        direction = F.normalize(axis_params[:, 2:], dim=1)
        conditioned = torch.cat([fused_vec, origin, direction], dim=1)  # no detach
        od = self.offset_depth_head(conditioned)
        offset_t = od[:, 0:1]
        depth_z = self.softplus(od[:, 1:2])
        xy_inter = origin + offset_t * direction
        intersection = torch.cat([xy_inter, depth_z], dim=1)
        return {
            "origin": origin,
            "direction": direction,
            "offset_t": offset_t,
            "depth_z": depth_z,
            "intersection": intersection
        }

print("✅ StereoTwoStageNet model defined.")
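The forward pass constrains the predicted intersection to lie on the predicted probe axis: `xy = origin + t * direction`, with depth appended as a third coordinate. A NumPy restatement with made-up numbers follows. It omits the small epsilon that `F.normalize` uses to avoid division by zero, and the function name is illustrative.

```python
import numpy as np

def intersection_from_axis(origin, direction_raw, offset_t, depth_z):
    """Combine axis parameters into (x, y, z), mirroring the model's final step."""
    # Unit-normalize the raw direction, as F.normalize(..., dim=1) does.
    direction = direction_raw / np.linalg.norm(direction_raw, axis=1, keepdims=True)
    xy = origin + offset_t * direction  # a point on the axis at signed distance t
    return np.concatenate([xy, depth_z], axis=1)

# Hypothetical batch of one sample, in normalized image coordinates
origin = np.array([[0.5, 0.5]])
direction_raw = np.array([[3.0, 4.0]])  # normalizes to (0.6, 0.8)
offset_t = np.array([[0.1]])
depth_z = np.array([[0.3]])

result = intersection_from_axis(origin, direction_raw, offset_t, depth_z)
print(result)
```

In `forward`, the axis parameters are detached before they condition the offset/depth head, so that head's input does not backpropagate into the axis head; `forward_e2e` drops the detach for end-to-end fine-tuning.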

### Model 1: Skip Connections + Focal Loss + Augmentation

**Quick Experiment Result**: R²_z = 0.5703 (WINNER)

**Architecture**: Dual-path fusion with skip connections for better gradient flow
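The dual-path idea can be sketched without a deep learning framework. The NumPy example below uses toy sizes, random weights and no biases, so it only illustrates the data flow and shapes, not the trained layers. A 1x1 convolution is written as a per-pixel channel mix via `einsum`.

```python
import numpy as np

rng = np.random.default_rng(0)
B, C, H, W = 2, 8, 7, 7  # toy sizes; the notebook's comments give C = 1280 for ResNet18
HIDDEN, OUT = 6, 4       # stand-ins for the 512 and 256 channels in the model

def conv1x1(x, weight):
    """1x1 convolution without bias: mixes channels independently at each pixel."""
    return np.einsum("oc,bchw->bohw", weight, x)

def relu(x):
    return np.maximum(x, 0.0)

x = rng.standard_normal((B, C, H, W))
w1 = rng.standard_normal((HIDDEN, C))
w2 = rng.standard_normal((OUT, HIDDEN))
w_skip = rng.standard_normal((OUT, C))

# Main path: conv -> ReLU -> conv -> ReLU -> global average pool
main_vec = relu(conv1x1(relu(conv1x1(x, w1)), w2)).mean(axis=(2, 3))  # (B, OUT)

# Skip path: global average pool -> linear projection
skip_vec = x.mean(axis=(2, 3)) @ w_skip.T  # (B, OUT)

# Residual addition of the two pathways
combined = main_vec + skip_vec
print(combined.shape)
```

The skip path is purely linear in the pooled features, so it gives the heads a short route to the raw multi-scale features even when the ReLU path saturates.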

In [None]:
class SkipConnectionNet(StereoTwoStageNet):
    """
    Enhanced fusion with residual pathway for better gradient flow.
    
    Architecture:
    - Main path: Concatenated features → Conv 1280→512→256 → Pool → (B, 256)
    - Skip path: Concatenated features → Pool → Linear 1280→256 → (B, 256)
    - Output: main + skip (residual addition)
    """
    def __init__(self, backbone_name="resnet18", pretrained=True):
        super().__init__(backbone_name, pretrained)
        
        # Enhanced fusion with increased capacity
        fused_channels = 128 * len(self.feature_dims) * 2
        self.fusion = nn.Sequential(
            nn.Conv2d(fused_channels, 512, 1),  # Expanded capacity
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, 1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1)
        )
        
        # Skip connection path - processes raw concatenated features
        # This provides alternative gradient path and preserves feature information
        self.skip_fusion = nn.Linear(fused_channels, 256)
        
    def _fused_vec(self, left_img, right_img):
        """Extract fused features with residual connection"""
        feats_l = self.backbone(left_img)
        feats_r = self.backbone(right_img)
        H, W = feats_l[-1].shape[2:]
        
        # Project and fuse features at multiple scales (using inherited self.proj)
        fused_scales = []
        for fl, fr, proj in zip(feats_l, feats_r, self.proj):
            # Project to consistent channels
            fl_proj = proj(fl)
            fr_proj = proj(fr)
            # Resize to common spatial size
            fl_proj = F.adaptive_avg_pool2d(fl_proj, (H, W))
            fr_proj = F.adaptive_avg_pool2d(fr_proj, (H, W))
            # Concatenate left-right stereo features
            fused_scales.append(torch.cat([fl_proj, fr_proj], dim=1))

        # Concatenate all scales: (B, 1280, H, W) for ResNet18
        x = torch.cat(fused_scales, dim=1)
        
        # Main fusion path: Convolutional transformation
        fused_4d = self.fusion(x)
        fused_vec = fused_4d.view(fused_4d.size(0), -1)  # (B, 256)
        
        # Skip connection path: Direct linear transformation
        # Pool the raw concatenated features and project to same dimension
        skip_pooled = F.adaptive_avg_pool2d(x, 1).view(x.size(0), -1)  # (B, 1280)
        skip_vec = self.skip_fusion(skip_pooled)  # (B, 256)
        
        # Residual addition: Combine both pathways
        # This helps gradient flow and provides ensemble-like effect
        combined_vec = fused_vec + skip_vec
        
        return combined_vec, fused_4d

print("✅ SkipConnectionNet model defined!")

### Model 2: Skip Connections + Spatial Attention

**Quick Experiment Result**: R²_z = 0.5540 (WINNER)

**Architecture**: Combines skip connections with spatial attention mechanism
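Spatial attention produces one weight per pixel and rescales every channel at that pixel by it. The NumPy sketch below illustrates the broadcasting involved. As a simplification, it uses a single random 1x1 projection instead of the two-layer bottleneck in the model, and toy sizes throughout.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(0)
B, C, H, W = 2, 8, 7, 7  # toy sizes for illustration
x = rng.standard_normal((B, C, H, W))

# Simplification: one 1x1 projection to a single channel (the model uses C -> C/4 -> 1).
w_attn = rng.standard_normal((1, C))

logits = np.einsum("oc,bchw->bohw", w_attn, x)  # (B, 1, H, W)
attn_map = sigmoid(logits)                      # per-pixel weights in (0, 1)

# Broadcasting over the channel axis: every channel at a pixel shares one weight.
attended = x * attn_map
print(attn_map.shape, attended.shape)
```

Since the weights lie in (0, 1), attention can only attenuate features, never amplify them; pixels the module considers irrelevant are scaled toward zero.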

In [None]:
class SpatialAttentionModule(nn.Module):
    """Spatial attention to weight important regions for depth prediction"""
    def __init__(self, in_channels):
        super().__init__()
        # Channel reduction for attention map
        self.attention_conv = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 4, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // 4, 1, 1),
            nn.Sigmoid()  # Attention weights [0, 1]
        )
        
    def forward(self, x):
        # x: (B, C, H, W)
        attn_map = self.attention_conv(x)  # (B, 1, H, W)
        attended = x * attn_map  # Element-wise multiplication
        return attended, attn_map


class SpatialAttentionDepthNet(StereoTwoStageNet):
    """
    Enhanced model with spatial attention for depth prediction
    
    Key improvements:
    1. Applies learned attention to weight depth-relevant spatial regions
    2. Dual-path processing: attended global context + attended spatial features
    3. Returns attention maps for visualization and analysis
    """
    def __init__(self, backbone_name="resnet18", pretrained=True):
        super().__init__(backbone_name, pretrained)
        
        # Calculate fused channels
        fused_channels = 128 * len(self.feature_dims) * 2
        
        # Spatial attention module
        self.spatial_attention = SpatialAttentionModule(fused_channels)
        
        # Enhanced fusion with spatial pathway
        self.spatial_fusion = nn.Sequential(
            nn.Conv2d(fused_channels, 512, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, 3, padding=1),
            nn.ReLU(inplace=True),
        )
        
        # Keep original global fusion for axis prediction
        self.fusion = nn.Sequential(
            nn.Conv2d(fused_channels, 256, 1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1)
        )
        
        # Enhanced depth head with spatial features
        # Combines: global features (256) + spatial features (256) + axis (4) = 516
        self.offset_depth_head = nn.Sequential(
            nn.Linear(256 + 256 + 4, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),  # Regularization for larger head
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 2)  # (offset_t, depth_z_raw)
        )
        
    def _fused_vec(self, left_img, right_img):
        """Extract fused features with spatial attention
        
        Returns:
            fused_vec: Concatenated global + spatial features (B, 512)
            global_4d: Global features for axis head (B, 256, 1, 1)
            attn_map: Attention map for visualization (B, 1, H, W)
        """
        feats_l = self.backbone(left_img)
        feats_r = self.backbone(right_img)
        H, W = feats_l[-1].shape[2:]
        
        # Multi-scale fusion (same as baseline)
        fused_scales = []
        for fl, fr, proj in zip(feats_l, feats_r, self.proj):
            fl = proj(fl)
            fr = proj(fr)
            fl = F.adaptive_avg_pool2d(fl, (H, W))
            fr = F.adaptive_avg_pool2d(fr, (H, W))
            fused_scales.append(torch.cat([fl, fr], dim=1))

        x = torch.cat(fused_scales, dim=1)  # (B, 1280, H, W)
        
        # Apply spatial attention - FIX #2: Keep attention map for visualization
        attended_x, attn_map = self.spatial_attention(x)
        
        # Spatial pathway: preserve spatial features for depth
        spatial_features = self.spatial_fusion(attended_x)  # (B, 256, H, W)
        spatial_vec = F.adaptive_avg_pool2d(spatial_features, 1).view(spatial_features.size(0), -1)  # (B, 256)
        
        # Global pathway: for axis prediction - FIX #1: Use attended features
        global_4d = self.fusion(attended_x)  # FIXED: was self.fusion(x)
        global_vec = global_4d.view(global_4d.size(0), -1)
        
        fused_vec = torch.cat([global_vec, spatial_vec], dim=1)
        return fused_vec, global_4d, attn_map
    
    def forward(self, left_img, right_img):
        fused_vec, fused_4d, attn_map = self._fused_vec(left_img, right_img)
        
        # Stage 1: Axis prediction (uses attended global features)
        axis_params = self.axis_head(fused_4d)
        origin = axis_params[:, :2]
        direction = F.normalize(axis_params[:, 2:], dim=1)
        
        # Stage 2: Depth prediction (uses global + spatial + axis)
        conditioned = torch.cat([fused_vec, origin.detach(), direction.detach()], dim=1)
        
        od = self.offset_depth_head(conditioned)
        offset_t = od[:, 0:1]
        depth_z = self.softplus(od[:, 1:2])
        
        xy_inter = origin + offset_t * direction
        intersection = torch.cat([xy_inter, depth_z], dim=1)
        
        return {
            "origin": origin,
            "direction": direction,
            "offset_t": offset_t,
            "depth_z": depth_z,
            "intersection": intersection,
            "attn_map": attn_map  # Return attention map for analysis
        }

    def forward_e2e(self, left_img, right_img):
        """End-to-end forward pass (no gradient detachment)"""
        fused_vec, fused_4d, attn_map = self._fused_vec(left_img, right_img)
        
        axis_params = self.axis_head(fused_4d)
        origin = axis_params[:, :2]
        direction = F.normalize(axis_params[:, 2:], dim=1)
        
        conditioned = torch.cat([fused_vec, origin, direction], dim=1)
        
        od = self.offset_depth_head(conditioned)
        offset_t = od[:, 0:1]
        depth_z = self.softplus(od[:, 1:2])
        
        xy_inter = origin + offset_t * direction
        intersection = torch.cat([xy_inter, depth_z], dim=1)
        
        return {
            "origin": origin,
            "direction": direction,
            "offset_t": offset_t,
            "depth_z": depth_z,
            "intersection": intersection,
            "attn_map": attn_map  # Return attention map for analysis
        }

print("✅ SpatialAttentionDepthNet model defined!")

### Model 3: Skip Connections + Auxiliary Depth

**Quick Experiment Result**: R²_z = 0.5362 (DECENT)

**Architecture**: Skip connections with auxiliary depth map supervision
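The auxiliary depth head in this model upsamples a 7x7 feature map with three transposed convolutions. The sizes in its comments follow from the standard `ConvTranspose2d` output-size formula, sketched here in plain Python (the function name is illustrative):

```python
def conv_transpose_out(size, kernel_size, stride, padding, output_padding=0, dilation=1):
    """Output size of a transposed convolution along one spatial dimension."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1

# Three layers with kernel_size=4, stride=2, padding=1, starting from 7x7
size = 7
sizes = []
for _ in range(3):
    size = conv_transpose_out(size, kernel_size=4, stride=2, padding=1)
    sizes.append(size)

print(sizes)
```

The head therefore predicts a 56x56 map, while the dataset's depth labels are 224x224 after the transforms. The two presumably need to be brought to the same resolution before a per-pixel loss is applied; how the notebook does this is not shown in this section.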

In [None]:
class SkipConnectionAuxiliaryDepthNet(StereoTwoStageNet):
    """Skip connections + Auxiliary depth supervision"""
    
    def __init__(self, backbone_name="resnet18", pretrained=True):
        super().__init__(backbone_name, pretrained)
        
        # Skip connection fusion - preserve spatial dimensions
        fused_channels = 128 * len(self.feature_dims) * 2
        self.fusion_conv = nn.Sequential(
            nn.Conv2d(fused_channels, 512, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, 1),
            nn.ReLU(inplace=True),
        )
        self.skip_fusion = nn.Linear(fused_channels, 256)
        
        # Auxiliary depth head - starts from 7x7 feature map
        self.depth_map_head = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # 7x7 -> 14x14
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),   # 14x14 -> 28x28
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),    # 28x28 -> 56x56
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 1, kernel_size=3, padding=1),                        # 56x56 -> 56x56
            nn.Softplus(beta=1.0)  # Ensure positive depth values
        )
    
    def _fused_vec_with_4d(self, left_img, right_img):
        feats_l = self.backbone(left_img)
        feats_r = self.backbone(right_img)
        H, W = feats_l[-1].shape[2:]
        
        fused_scales = []
        for fl, fr, proj in zip(feats_l, feats_r, self.proj):
            fl_proj = proj(fl)
            fr_proj = proj(fr)
            fl_proj = F.adaptive_avg_pool2d(fl_proj, (H, W))
            fr_proj = F.adaptive_avg_pool2d(fr_proj, (H, W))
            fused_scales.append(torch.cat([fl_proj, fr_proj], dim=1))
        
        x = torch.cat(fused_scales, dim=1)  # (B, 1280, H, W) - H,W typically 7x7
        fused_4d = self.fusion_conv(x)      # (B, 256, H, W)
        
        # Skip connection
        skip_pooled = F.adaptive_avg_pool2d(x, 1).view(x.size(0), -1)
        skip_vec = self.skip_fusion(skip_pooled)
        
        # Pooled vector for intersection head
        fused_vec = F.adaptive_avg_pool2d(fused_4d, 1).view(fused_4d.size(0), -1)
        combined_vec = fused_vec + skip_vec  # Residual addition
        
        return combined_vec, fused_4d
    
    def forward(self, left_img, right_img):
        fused_vec, fused_4d = self._fused_vec_with_4d(left_img, right_img)
        
        pooled = F.adaptive_avg_pool2d(fused_4d, 1)
        axis_params = self.axis_head(pooled)
        origin = axis_params[:, :2]
        direction = F.normalize(axis_params[:, 2:], dim=1)
        
        conditioned = torch.cat([fused_vec, origin.detach(), direction.detach()], dim=1)
        od = self.offset_depth_head(conditioned)
        offset_t = od[:, 0:1]
        depth_z = self.softplus(od[:, 1:2])
        
        xy_inter = origin + offset_t * direction
        intersection = torch.cat([xy_inter, depth_z], dim=1)
        depth_map = self.depth_map_head(fused_4d)
        
        return {
            "origin": origin,
            "direction": direction,
            "offset_t": offset_t,
            "depth_z": depth_z,
            "intersection": intersection,
            "depth_map": depth_map
        }
    
    def forward_e2e(self, left_img, right_img):
        fused_vec, fused_4d = self._fused_vec_with_4d(left_img, right_img)
        
        pooled = F.adaptive_avg_pool2d(fused_4d, 1)
        axis_params = self.axis_head(pooled)
        origin = axis_params[:, :2]
        direction = F.normalize(axis_params[:, 2:], dim=1)
        
        conditioned = torch.cat([fused_vec, origin, direction], dim=1)
        od = self.offset_depth_head(conditioned)
        offset_t = od[:, 0:1]
        depth_z = self.softplus(od[:, 1:2])
        
        xy_inter = origin + offset_t * direction
        intersection = torch.cat([xy_inter, depth_z], dim=1)
        depth_map = self.depth_map_head(fused_4d)
        
        return {
            "origin": origin,
            "direction": direction,
            "offset_t": offset_t,
            "depth_z": depth_z,
            "intersection": intersection,
            "depth_map": depth_map
        }

print("✅ SkipConnectionAuxiliaryDepthNet model defined (FIXED)!")

## Training Infrastructure

Utilities for full training with checkpointing and early stopping
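Checkpoints are written per model and per stage. A small sketch of the resulting file layout, following the naming used by `CheckpointManager` in this section (the model name `skip_focal_aug` is a hypothetical example, and `checkpoint_paths` is an illustrative helper, not part of the notebook):

```python
from pathlib import Path

def checkpoint_paths(save_dir, model_name, stage):
    """Return the (latest, best) checkpoint paths for one training stage."""
    base = Path(save_dir) / model_name
    return base / f"stage_{stage}_latest.pt", base / f"stage_{stage}_best.pt"

# Hypothetical model name, for illustration only
latest, best = checkpoint_paths("checkpoints", "skip_focal_aug", "A")
print(latest.as_posix())
print(best.as_posix())
```

The latest file is overwritten every time a checkpoint is saved, whereas the best file changes only when the caller flags a new best score.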

In [None]:
class FocalLoss(nn.Module):
    """Focal-style weighting of squared error for regression: larger errors receive larger weights"""
    def __init__(self, alpha=0.25, gamma=1.5):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
    
    def forward(self, pred, target):
        mse = (pred - target) ** 2
        pt = torch.exp(-mse)
        focal_weight = self.alpha * (1 - pt) ** self.gamma
        loss = focal_weight * mse
        return loss.mean()


def auxiliary_depth_loss(pred_depth_map, target_depth_map, valid_mask=None):
    """
    Auxiliary depth map loss (log-space L1).
    
    Args:
        pred_depth_map: (B, 1, H, W) predicted depth
        target_depth_map: (B, 1, H, W) target depth
        valid_mask: (B, 1, H, W) mask for valid depth values
    """
    if valid_mask is None:
        valid_mask = target_depth_map > 0.1  # Filter invalid depths
    
    if valid_mask.sum() == 0:
        return torch.tensor(0.0, device=pred_depth_map.device)
    
    # Log-space L1 for numerical stability
    pred_log = torch.log(pred_depth_map[valid_mask] + 1e-6)
    target_log = torch.log(target_depth_map[valid_mask] + 1e-6)
    loss = F.l1_loss(pred_log, target_log)
    
    return loss


def compute_metrics(pred, target):
    """
    Compute evaluation metrics.
    
    Args:
        pred: dict with keys ['origin', 'direction', 'intersection']
        target: dict with keys ['origin', 'direction', 'intersection']
    
    Returns:
        dict with metrics
    """
    with torch.no_grad():
        # R² for depth (z coordinate)
        z_pred = pred['intersection'][:, 2]
        z_true = target['intersection'][:, 2]
        z_var = torch.var(z_true)
        z_mse = F.mse_loss(z_pred, z_true)
        r2_z = 1 - (z_mse / (z_var + 1e-8))
        
        # 3D Euclidean error (in the units of the inputs; mm only if de-normalized first)
        e3d = torch.norm(pred['intersection'] - target['intersection'], dim=1).mean()
        
        # 2D error (pixels only if coordinates are de-normalized first)
        e2d_origin = torch.norm(pred['origin'] - target['origin'], dim=1).mean()
        e2d_inter = torch.norm(pred['intersection'][:, :2] - target['intersection'][:, :2], dim=1).mean()
        e2d = (e2d_origin + e2d_inter) / 2
        
        # Angular error (degrees)
        # Normalize both directions first
        pred_dir = F.normalize(pred['direction'], dim=1)
        target_dir = F.normalize(target['direction'], dim=1)
        cos_sim = (pred_dir * target_dir).sum(dim=1)
        cos_sim = torch.clamp(cos_sim, -1.0 + 1e-7, 1.0 - 1e-7)
        ang_error = torch.acos(cos_sim).mean() * 180 / np.pi
    
    return {
        'r2_z': r2_z.item(),
        'e3d': e3d.item(),
        'e2d': e2d.item(),
        'ang_deg': ang_error.item()
    }

print('✅ Loss functions and metrics defined')
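To see how the focal weighting in `FocalLoss` behaves, the per-sample arithmetic can be reproduced with the standard library. The sketch below uses the same formula and default `alpha` and `gamma`; the prediction and target values are made up.

```python
import math

def focal_weighted_se(pred, target, alpha=0.25, gamma=1.5):
    """Return (weight, weighted loss) for one scalar prediction."""
    se = (pred - target) ** 2
    pt = math.exp(-se)                     # close to 1 for small errors
    weight = alpha * (1.0 - pt) ** gamma   # grows with the error
    return weight, weight * se

# Hypothetical normalized depths: a small error and a large one
small_w, small_loss = focal_weighted_se(0.50, 0.55)
large_w, large_loss = focal_weighted_se(0.50, 1.00)

print(f"small error: weight={small_w:.2e}, loss={small_loss:.2e}")
print(f"large error: weight={large_w:.2e}, loss={large_loss:.2e}")
```

For small squared errors, `1 - exp(-se)` is approximately `se`, so the loss behaves roughly like `alpha * se ** (1 + gamma)`. With targets normalized to roughly [0, 1], this makes the loss values much smaller than plain MSE, which is worth remembering when comparing loss curves or combining this term with others.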

In [None]:
class CheckpointManager:
    """Manage model checkpoints"""
    def __init__(self, save_dir, model_name):
        self.save_dir = Path(save_dir) / model_name
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.best_scores = {}
    
    def save(self, model, optimizer, epoch, stage, metrics, is_best=False):
        """Save checkpoint"""
        checkpoint = {
            'epoch': epoch,
            'stage': stage,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'metrics': metrics
        }
        
        # Save latest
        latest_path = self.save_dir / f'stage_{stage}_latest.pt'
        torch.save(checkpoint, latest_path)
        
        # Save best
        if is_best:
            best_path = self.save_dir / f'stage_{stage}_best.pt'
            torch.save(checkpoint, best_path)
            self.best_scores[stage] = metrics['r2_z']
            print(f'   💾 Saved best model for stage {stage} (R²_z: {metrics["r2_z"]:.4f})')
    
    def load_best(self, model, optimizer, stage):
        """Load best checkpoint for a stage"""
        best_path = self.save_dir / f'stage_{stage}_best.pt'
        if best_path.exists():
            checkpoint = torch.load(best_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            print(f'   📂 Loaded best model from stage {stage}')
            return checkpoint['metrics']
        return None


class EarlyStopping:
    """Early stopping to prevent overfitting"""
    def __init__(self, patience=5, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = None
        self.should_stop = False
    
    def __call__(self, score):
        if self.best_score is None:
            self.best_score = score
        elif score < self.best_score + self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
        else:
            self.best_score = score
            self.counter = 0
        
        return self.should_stop

print('✅ Checkpoint and early stopping managers defined')
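The stopping rule in `EarlyStopping` can be traced on a short score sequence. The function below restates the same logic in a standalone form; the name `first_stop_epoch` and the validation scores are illustrative.

```python
def first_stop_epoch(scores, patience=5, min_delta=0.001):
    """Return the 1-based epoch at which stopping triggers, or None if it never does."""
    best, counter = None, 0
    for epoch, score in enumerate(scores, start=1):
        if best is None:
            best = score
        elif score < best + min_delta:
            # Not a sufficient improvement: count it, and leave `best` unchanged.
            counter += 1
            if counter >= patience:
                return epoch
        else:
            best, counter = score, 0
    return None

# Hypothetical validation R²_z per epoch
val_r2 = [0.50, 0.52, 0.5205, 0.519, 0.53]
stop_epoch = first_stop_epoch(val_r2, patience=2)
print(stop_epoch)
```

Epoch 3 improves on the best score by only 0.0005, which is below `min_delta`, so it counts toward patience. Combined with the drop in epoch 4, stopping triggers before the better score in epoch 5 is ever seen. This is the usual trade-off when choosing `patience` and `min_delta`.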

In [None]:
def train_full_model(
    model,
    train_loader,
    val_loader,
    model_name,
    stage_A_epochs=20,
    stage_B_epochs=30,
    stage_C_epochs=40,
    use_focal=False,
    use_auxiliary_depth=False,
    w_depth_map=None,
    early_stopping_patience=5
):
    """
    Full three-stage training with checkpointing and early stopping.
    
    Args:
        model: The model to train
        train_loader: Training data loader
        val_loader: Validation data loader
        model_name: Name for saving checkpoints
        stage_A_epochs: Epochs for Stage A (axis pretraining)
        stage_B_epochs: Epochs for Stage B (intersection training)
        stage_C_epochs: Epochs for Stage C (end-to-end fine-tuning)
        use_focal: Whether to use Focal Loss
        use_auxiliary_depth: Whether model has auxiliary depth head
        w_depth_map: Dict with depth map weights per stage (for auxiliary)
        early_stopping_patience: Patience for early stopping
    
    Returns:
        dict: Final metrics and training history
    """
    
    model = model.to(device)
    checkpoint_mgr = CheckpointManager(CHECKPOINT_DIR, model_name)
    
    # Loss functions
    if use_focal:
        depth_loss_fn = FocalLoss(alpha=0.25, gamma=1.5)
    else:
        depth_loss_fn = nn.MSELoss()
    
    # Default depth map weights
    if w_depth_map is None:
        w_depth_map = {'A': 0.0, 'B': 0.5, 'C': 1.0}
    
    history = {'A': [], 'B': [], 'C': []}
    
    # ==================== STAGE A: Axis Pretraining ====================
    print("\n" + "="*80)
    print(f"STAGE A: Axis Pretraining ({stage_A_epochs} epochs)")
    print("="*80)
    print("Training: all parameters except offset_depth_head")
    print("Frozen: offset_depth_head")
    
    # Freeze intersection head
    for param in model.offset_depth_head.parameters():
        param.requires_grad = False
    
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=stage_A_epochs)
    early_stop = EarlyStopping(patience=early_stopping_patience)
    
    best_val_r2z = -float('inf')
    
    for epoch in range(stage_A_epochs):
        # Training
        model.train()
        train_loss = 0.0
        
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{stage_A_epochs}"):
            left_img = batch['left_img'].to(device)
            right_img = batch['right_img'].to(device)
            origin = batch['origin'].to(device)
            direction = batch['direction'].to(device)
            
            optimizer.zero_grad()
            pred = model(left_img, right_img)
            
            # Stage A loss: axis only
            loss_origin = F.mse_loss(pred['origin'], origin)
            loss_dir = F.mse_loss(pred['direction'], direction)
            loss = loss_origin + loss_dir
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            train_loss += loss.item()
        
        train_loss /= len(train_loader)
        
        # Validation
        model.eval()
        val_metrics = {'r2_z': 0, 'e3d': 0, 'e2d': 0, 'ang_deg': 0}
        
        with torch.no_grad():
            for batch in val_loader:
                left_img = batch['left_img'].to(device)
                right_img = batch['right_img'].to(device)
                
                target = {
                    'origin': batch['origin'].to(device),
                    'direction': batch['direction'].to(device),
                    'intersection': batch['intersection'].to(device)
                }
                
                pred = model(left_img, right_img)
                metrics = compute_metrics(pred, target)
                
                for k in val_metrics:
                    val_metrics[k] += metrics[k]
        
        for k in val_metrics:
            val_metrics[k] /= len(val_loader)
        
        history['A'].append({'epoch': epoch+1, 'train_loss': train_loss, **val_metrics})
        
        print(f"Epoch {epoch+1}/{stage_A_epochs} | Loss: {train_loss:.4f} | R²_z: {val_metrics['r2_z']:.4f} | Ang: {val_metrics['ang_deg']:.2f}°")
        
        # Save checkpoint
        is_best = val_metrics['r2_z'] > best_val_r2z
        if is_best:
            best_val_r2z = val_metrics['r2_z']
        checkpoint_mgr.save(model, optimizer, epoch, 'A', val_metrics, is_best)
        
        scheduler.step()
        
        # Early stopping
        if early_stop(val_metrics['r2_z']):
            print(f"   ⏸️ Early stopping triggered at epoch {epoch+1}")
            break
    
    # Load best Stage A model
    checkpoint_mgr.load_best(model, optimizer, 'A')
    
    print(f"\n✅ Stage A complete. Best R²_z: {best_val_r2z:.4f}")
    
    # ==================== STAGE B: Intersection Training ====================
    print("\n" + "="*80)
    print(f"STAGE B: Intersection Training ({stage_B_epochs} epochs)")
    print("="*80)
    print("Training: offset_depth_head only")
    print("Frozen: backbone, axis_head")
    
    # Freeze backbone and axis head
    for param in model.backbone.parameters():
        param.requires_grad = False
    for param in model.axis_head.parameters():
        param.requires_grad = False
    
    # Unfreeze intersection head
    for param in model.offset_depth_head.parameters():
        param.requires_grad = True
    
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-5)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=stage_B_epochs)
    early_stop = EarlyStopping(patience=early_stopping_patience)
    
    best_val_r2z = -float('inf')
    
    for epoch in range(stage_B_epochs):
        # Training
        model.train()
        train_loss = 0.0
        
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{stage_B_epochs}"):
            left_img = batch['left_img'].to(device)
            right_img = batch['right_img'].to(device)
            intersection = batch['intersection'].to(device)
            
            optimizer.zero_grad()
            pred = model(left_img, right_img)
            
            # Stage B loss: intersection with weighted depth
            loss_xy = F.mse_loss(pred['intersection'][:, :2], intersection[:, :2])
            loss_z = depth_loss_fn(pred['intersection'][:, 2:3], intersection[:, 2:3])
            loss = loss_xy + 5.0 * loss_z  # Match quick_experiments default
            
            # Auxiliary depth loss
            if use_auxiliary_depth and 'depth_map' in pred and 'depth_label' in batch:
                depth_label = batch['depth_label'].to(device)
                loss_depth_map = auxiliary_depth_loss(pred['depth_map'], depth_label)
                loss = loss + w_depth_map['B'] * loss_depth_map
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            train_loss += loss.item()
        
        train_loss /= len(train_loader)
        
        # Validation (same as Stage A)
        model.eval()
        val_metrics = {'r2_z': 0, 'e3d': 0, 'e2d': 0, 'ang_deg': 0}
        
        with torch.no_grad():
            for batch in val_loader:
                left_img = batch['left_img'].to(device)
                right_img = batch['right_img'].to(device)
                
                target = {
                    'origin': batch['origin'].to(device),
                    'direction': batch['direction'].to(device),
                    'intersection': batch['intersection'].to(device)
                }
                
                pred = model(left_img, right_img)
                metrics = compute_metrics(pred, target)
                
                for k in val_metrics:
                    val_metrics[k] += metrics[k]
        
        for k in val_metrics:
            val_metrics[k] /= len(val_loader)
        
        history['B'].append({'epoch': epoch+1, 'train_loss': train_loss, **val_metrics})
        
        print(f"Epoch {epoch+1}/{stage_B_epochs} | Loss: {train_loss:.4f} | R²_z: {val_metrics['r2_z']:.4f} | Ang: {val_metrics['ang_deg']:.2f}°")
        
        # Save checkpoint
        is_best = val_metrics['r2_z'] > best_val_r2z
        if is_best:
            best_val_r2z = val_metrics['r2_z']
        checkpoint_mgr.save(model, optimizer, epoch, 'B', val_metrics, is_best)
        
        scheduler.step()
        
        # Early stopping
        if early_stop(val_metrics['r2_z']):
            print(f"   ⏸️ Early stopping triggered at epoch {epoch+1}")
            break
    
    # Load best Stage B model
    checkpoint_mgr.load_best(model, optimizer, 'B')
    
    print(f"\n✅ Stage B complete. Best R²_z: {best_val_r2z:.4f}")
    
    # ==================== STAGE C: End-to-End Fine-tuning ====================
    print("\n" + "="*80)
    print(f"STAGE C: End-to-End Fine-tuning ({stage_C_epochs} epochs)")
    print("="*80)
    print("Training: All parameters")
    
    # Unfreeze all parameters
    for param in model.parameters():
        param.requires_grad = True
    
    optimizer = optim.AdamW(model.parameters(), lr=1e-5)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=stage_C_epochs)
    early_stop = EarlyStopping(patience=early_stopping_patience)
    
    best_val_r2z = -float('inf')
    
    for epoch in range(stage_C_epochs):
        # Training
        model.train()
        train_loss = 0.0
        
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{stage_C_epochs}"):
            left_img = batch['left_img'].to(device)
            right_img = batch['right_img'].to(device)
            origin = batch['origin'].to(device)
            direction = batch['direction'].to(device)
            intersection = batch['intersection'].to(device)
            
            optimizer.zero_grad()
            pred = model.forward_e2e(left_img, right_img)  # End-to-end forward
            
            # Full loss
            loss_origin = F.mse_loss(pred['origin'], origin)
            loss_dir = F.mse_loss(pred['direction'], direction)
            loss_xy = F.mse_loss(pred['intersection'][:, :2], intersection[:, :2])
            loss_z = depth_loss_fn(pred['intersection'][:, 2:3], intersection[:, 2:3])
            
            loss = loss_origin + loss_dir + loss_xy + 5.0 * loss_z  # Match quick_experiments default
            
            # Auxiliary depth loss
            if use_auxiliary_depth and 'depth_map' in pred and 'depth_label' in batch:
                depth_label = batch['depth_label'].to(device)
                loss_depth_map = auxiliary_depth_loss(pred['depth_map'], depth_label)
                loss = loss + w_depth_map['C'] * loss_depth_map
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            train_loss += loss.item()
        
        train_loss /= len(train_loader)
        
        # Validation
        model.eval()
        val_metrics = {'r2_z': 0, 'e3d': 0, 'e2d': 0, 'ang_deg': 0}
        
        with torch.no_grad():
            for batch in val_loader:
                left_img = batch['left_img'].to(device)
                right_img = batch['right_img'].to(device)
                
                target = {
                    'origin': batch['origin'].to(device),
                    'direction': batch['direction'].to(device),
                    'intersection': batch['intersection'].to(device)
                }
                
                pred = model.forward_e2e(left_img, right_img)
                metrics = compute_metrics(pred, target)
                
                for k in val_metrics:
                    val_metrics[k] += metrics[k]
        
        for k in val_metrics:
            val_metrics[k] /= len(val_loader)
        
        history['C'].append({'epoch': epoch+1, 'train_loss': train_loss, **val_metrics})
        
        print(f"Epoch {epoch+1}/{stage_C_epochs} | Loss: {train_loss:.4f} | R²_z: {val_metrics['r2_z']:.4f} | E3D: {val_metrics['e3d']:.2f}mm")
        
        # Save checkpoint
        is_best = val_metrics['r2_z'] > best_val_r2z
        if is_best:
            best_val_r2z = val_metrics['r2_z']
        checkpoint_mgr.save(model, optimizer, epoch, 'C', val_metrics, is_best)
        
        scheduler.step()
        
        # Early stopping
        if early_stop(val_metrics['r2_z']):
            print(f"   ⏸️ Early stopping triggered at epoch {epoch+1}")
            break
    
    # Load best Stage C model
    final_metrics = checkpoint_mgr.load_best(model, optimizer, 'C')
    
    print(f"\n✅ Stage C complete. Best R²_z: {best_val_r2z:.4f}")
    print("\n" + "="*80)
    print("TRAINING COMPLETE!")
    print("="*80)
    print("Final Results:")
    print(f"  R²_z: {final_metrics['r2_z']:.4f}")
    print(f"  3D Error: {final_metrics['e3d']:.2f} mm")
    print(f"  2D Error: {final_metrics['e2d']:.2f} px")
    print(f"  Angular Error: {final_metrics['ang_deg']:.2f}°")
    
    return {
        'model_name': model_name,
        'final_metrics': final_metrics,
        'history': history
    }

print('✅ Full training function defined')
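
The validation loops in `train_full_model` average the output of `compute_metrics` over batches. If `compute_metrics` returns a per-batch R² (an assumption, since its definition lies outside this section), the batch mean can differ from R² computed over the pooled validation set, especially with small batches. The sketch below uses plain Python and made-up numbers to show the gap; `r2_score` is a hypothetical helper, not the notebook's metric code.

```python
def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


# Two made-up batches of (targets, predictions) with very different depth ranges
batches = [
    ([1.0, 2.0, 3.0, 4.0], [1.5, 2.5, 2.5, 3.5]),
    ([10.0, 11.0, 12.0, 13.0], [10.5, 11.5, 11.5, 12.5]),
]

# What a per-batch average reports
batch_mean_r2 = sum(r2_score(t, p) for t, p in batches) / len(batches)

# What R² looks like when all samples are pooled first
pooled_true = [t for targets, _ in batches for t in targets]
pooled_pred = [p for _, preds in batches for p in preds]
pooled_r2 = r2_score(pooled_true, pooled_pred)

print(f"batch-mean R2: {batch_mean_r2:.4f}, pooled R2: {pooled_r2:.4f}")
```

Within each batch the target variance is small, so the per-batch score is lower than the pooled one even though the residuals are identical. When comparing against numbers from other notebooks, it helps to confirm both were computed the same way.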

### Resume/Extend Training

A function to continue training a single stage from its best checkpoint when more epochs are needed

In [None]:
def continue_training(
    model,
    train_loader,
    val_loader,
    model_name,
    stage,
    additional_epochs=10,
    use_focal=False,
    use_auxiliary_depth=False,
    w_depth_map=None,
    early_stopping_patience=5
):
    """
    Continue training from a saved checkpoint.
    
    Args:
        model: The model (must match saved checkpoint)
        train_loader: Training data loader
        val_loader: Validation data loader
        model_name: Name used when saving checkpoints
        stage: Which stage to continue ('A', 'B', or 'C')
        additional_epochs: How many more epochs to train
        use_focal: Whether to use Focal Loss
        use_auxiliary_depth: Whether model has auxiliary depth head
        w_depth_map: Dict with depth map weights per stage
        early_stopping_patience: Patience for early stopping
    
    Returns:
        dict: Updated metrics and training history
    """
    
    model = model.to(device)
    checkpoint_mgr = CheckpointManager(CHECKPOINT_DIR, model_name)
    
    # Load best checkpoint from the stage
    checkpoint_path = checkpoint_mgr.save_dir / f'stage_{stage}_best.pt'
    
    if not checkpoint_path.exists():
        print(f"❌ No checkpoint found for stage {stage}")
        print(f"   Looking for: {checkpoint_path}")
        return None
    
    # Load checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch']
    previous_best = checkpoint['metrics']['r2_z']
    
    print(f"\n📂 Loaded checkpoint from stage {stage}")
    print(f"   Starting from epoch: {start_epoch}")
    print(f"   Previous best R²_z: {previous_best:.4f}")
    print(f"   Training for {additional_epochs} more epochs...\n")
    
    # Setup loss functions
    if use_focal:
        depth_loss_fn = FocalLoss(alpha=0.25, gamma=1.5)
    else:
        depth_loss_fn = nn.MSELoss()
    
    if w_depth_map is None:
        w_depth_map = {'A': 0.0, 'B': 0.5, 'C': 1.0}
    
    # Configure training based on stage
    if stage == 'A':
        # Freeze intersection head
        for param in model.offset_depth_head.parameters():
            param.requires_grad = False
        lr = 1e-4
        print("Stage A: Training all parameters except offset_depth_head")
        
    elif stage == 'B':
        # Freeze backbone and axis head
        for param in model.backbone.parameters():
            param.requires_grad = False
        for param in model.axis_head.parameters():
            param.requires_grad = False
        for param in model.offset_depth_head.parameters():
            param.requires_grad = True
        lr = 5e-5
        print("Stage B: Training offset_depth_head only")
        
    elif stage == 'C':
        # Unfreeze all
        for param in model.parameters():
            param.requires_grad = True
        lr = 1e-5
        print("Stage C: Training all parameters (end-to-end)")
    else:
        raise ValueError(f"Invalid stage: {stage}. Must be 'A', 'B', or 'C'")
    
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    
    # Load optimizer state if available
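    if 'optimizer_state_dict' in checkpoint:
        try:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            print("✅ Loaded optimizer state")
        except Exception:
            # Avoid a bare except so KeyboardInterrupt still propagates;
            # a mismatch in parameter groups falls back to the fresh optimizer
            print("⚠️  Could not load optimizer state, using fresh optimizer")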
    
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=additional_epochs)
    early_stop = EarlyStopping(patience=early_stopping_patience)
    
    best_val_r2z = previous_best
    history = []
    
    print("\n" + "="*80)
    print(f"CONTINUING STAGE {stage} TRAINING")
    print("="*80)
    
    for epoch in range(additional_epochs):
        actual_epoch = start_epoch + epoch + 1
        
        # Training
        model.train()
        train_loss = 0.0
        
        for batch in tqdm(train_loader, desc=f"Epoch {actual_epoch} ({epoch+1}/{additional_epochs})"):
            left_img = batch['left_img'].to(device)
            right_img = batch['right_img'].to(device)
            
            optimizer.zero_grad()
            
            # Forward pass based on stage
            if stage == 'C':
                pred = model.forward_e2e(left_img, right_img)
            else:
                pred = model(left_img, right_img)
            
            # Compute loss based on stage
            if stage == 'A':
                origin = batch['origin'].to(device)
                direction = batch['direction'].to(device)
                loss_origin = F.mse_loss(pred['origin'], origin)
                loss_dir = F.mse_loss(pred['direction'], direction)
                loss = loss_origin + loss_dir
                
            elif stage == 'B':
                intersection = batch['intersection'].to(device)
                loss_xy = F.mse_loss(pred['intersection'][:, :2], intersection[:, :2])
                loss_z = depth_loss_fn(pred['intersection'][:, 2:3], intersection[:, 2:3])
                loss = loss_xy + 5.0 * loss_z  # Match quick_experiments
                
                if use_auxiliary_depth and 'depth_map' in pred and 'depth_label' in batch:
                    depth_label = batch['depth_label'].to(device)
                    loss_depth_map = auxiliary_depth_loss(pred['depth_map'], depth_label)
                    loss = loss + w_depth_map['B'] * loss_depth_map
                    
            else:  # stage == 'C'
                origin = batch['origin'].to(device)
                direction = batch['direction'].to(device)
                intersection = batch['intersection'].to(device)
                
                loss_origin = F.mse_loss(pred['origin'], origin)
                loss_dir = F.mse_loss(pred['direction'], direction)
                loss_xy = F.mse_loss(pred['intersection'][:, :2], intersection[:, :2])
                loss_z = depth_loss_fn(pred['intersection'][:, 2:3], intersection[:, 2:3])
                loss = loss_origin + loss_dir + loss_xy + 5.0 * loss_z  # Match quick_experiments
                
                if use_auxiliary_depth and 'depth_map' in pred and 'depth_label' in batch:
                    depth_label = batch['depth_label'].to(device)
                    loss_depth_map = auxiliary_depth_loss(pred['depth_map'], depth_label)
                    loss = loss + w_depth_map['C'] * loss_depth_map
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            train_loss += loss.item()
        
        train_loss /= len(train_loader)
        
        # Validation
        model.eval()
        val_metrics = {'r2_z': 0, 'e3d': 0, 'e2d': 0, 'ang_deg': 0}
        
        with torch.no_grad():
            for batch in val_loader:
                left_img = batch['left_img'].to(device)
                right_img = batch['right_img'].to(device)
                
                target = {
                    'origin': batch['origin'].to(device),
                    'direction': batch['direction'].to(device),
                    'intersection': batch['intersection'].to(device)
                }
                
                if stage == 'C':
                    pred = model.forward_e2e(left_img, right_img)
                else:
                    pred = model(left_img, right_img)
                    
                metrics = compute_metrics(pred, target)
                
                for k in val_metrics:
                    val_metrics[k] += metrics[k]
        
        for k in val_metrics:
            val_metrics[k] /= len(val_loader)
        
        history.append({'epoch': actual_epoch, 'train_loss': train_loss, **val_metrics})
        
        print(f"Epoch {actual_epoch} | Loss: {train_loss:.4f} | R²_z: {val_metrics['r2_z']:.4f} | E3D: {val_metrics['e3d']:.2f}mm")
        
        # Save checkpoint
        is_best = val_metrics['r2_z'] > best_val_r2z
        if is_best:
            best_val_r2z = val_metrics['r2_z']
            print(f"   🎯 New best R²_z: {best_val_r2z:.4f} (improvement: +{best_val_r2z - previous_best:.4f})")
        checkpoint_mgr.save(model, optimizer, actual_epoch, stage, val_metrics, is_best)
        
        scheduler.step()
        
        # Early stopping
        if early_stop(val_metrics['r2_z']):
            print(f"   ⏸️ Early stopping triggered at epoch {actual_epoch}")
            break
    
    # Load best model from extended training
    final_metrics = checkpoint_mgr.load_best(model, optimizer, stage)
    
    print("\n" + "="*80)
    print(f"EXTENDED TRAINING COMPLETE FOR STAGE {stage}")
    print("="*80)
    print(f"Previous best R²_z: {previous_best:.4f}")
    print(f"New best R²_z: {best_val_r2z:.4f}")
    print(f"Improvement: {best_val_r2z - previous_best:+.4f}")
    print("="*80)
    
    return {
        'model_name': model_name,
        'stage': stage,
        'previous_best': previous_best,
        'new_best': best_val_r2z,
        'improvement': best_val_r2z - previous_best,
        'final_metrics': final_metrics,
        'history': history
    }

print('✅ Continue training function defined')
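
`continue_training` builds a new `CosineAnnealingLR` with `T_max=additional_epochs`, so the learning rate anneals over the extension alone rather than resuming the original curve. The sketch below tabulates the closed-form cosine schedule with `eta_min=0` (the PyTorch default) for the Stage C base rate of `1e-5`. It assumes the schedule starts from the base rate; when the optimizer state is restored from the checkpoint, the actual starting value may differ. `cosine_lr` is a hypothetical helper for illustration.

```python
import math


def cosine_lr(base_lr, epoch, t_max, eta_min=0.0):
    """Closed-form cosine annealing value after `epoch` scheduler steps."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))


# Stage C base learning rate from this notebook, over 10 additional epochs
schedule = [cosine_lr(1e-5, epoch, t_max=10) for epoch in range(11)]
for epoch, lr in enumerate(schedule):
    print(f"epoch {epoch:2d}: lr = {lr:.2e}")
```

The rate reaches half of its base value at the midpoint and approaches zero at the end, so the last few epochs of a short extension make very small updates.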

## Run Full Training

Execute full training runs for each best model

### Train Model 1: Skip Connections + Focal Loss + Augmentation

In [None]:
# Initialize model
model1 = SkipConnectionNet(backbone_name="resnet18", pretrained=True)

# Run full training
result1 = train_full_model(
    model=model1,
    train_loader=train_loader,
    val_loader=val_loader,
    model_name="skip_connections_focal_augmentation_FULL",
    stage_A_epochs=20,
    stage_B_epochs=30,
    stage_C_epochs=40,
    use_focal=True,
    use_auxiliary_depth=False,
    early_stopping_patience=5
)

print("\n🎯 Model 1 Training Complete!")
print("Quick Experiment R²_z: 0.5703")
print(f"Full Training R²_z: {result1['final_metrics']['r2_z']:.4f}")
print(f"Improvement: {result1['final_metrics']['r2_z'] - 0.5703:+.4f}")

### Train Model 2: Skip Connections + Spatial Attention

In [None]:
# Initialize model
model2 = SkipConnectionSpatialAttentionNet(backbone_name="resnet18", pretrained=True)

# Run full training
result2 = train_full_model(
    model=model2,
    train_loader=train_loader,
    val_loader=val_loader,
    model_name="skip_connections_spatial_attention_FULL",
    stage_A_epochs=20,
    stage_B_epochs=30,
    stage_C_epochs=40,
    use_focal=False,
    use_auxiliary_depth=False,
    early_stopping_patience=5
)

print("\n🎯 Model 2 Training Complete!")
print("Quick Experiment R²_z: 0.5540")
print(f"Full Training R²_z: {result2['final_metrics']['r2_z']:.4f}")
print(f"Improvement: {result2['final_metrics']['r2_z'] - 0.5540:+.4f}")

### Train Model 3: Skip Connections + Auxiliary Depth

In [None]:
# Initialize model
model3 = SkipConnectionAuxiliaryDepthNet(backbone_name="resnet18", pretrained=True)

# Run full training
result3 = train_full_model(
    model=model3,
    train_loader=train_loader,
    val_loader=val_loader,
    model_name="skip_connections_auxiliary_depth_FULL",
    stage_A_epochs=20,
    stage_B_epochs=30,
    stage_C_epochs=40,
    use_focal=False,
    use_auxiliary_depth=True,
    w_depth_map={'A': 0.0, 'B': 0.5, 'C': 1.0},
    early_stopping_patience=5
)

print("\n🎯 Model 3 Training Complete!")
print("Quick Experiment R²_z: 0.5362")
print(f"Full Training R²_z: {result3['final_metrics']['r2_z']:.4f}")
print(f"Improvement: {result3['final_metrics']['r2_z'] - 0.5362:+.4f}")

## Results Comparison and Visualization

In [None]:
# Collect all results
all_results = [
    {'name': 'Model 1: Skip+Focal+Aug', 'quick': 0.5703, 'full': result1['final_metrics']['r2_z']},
    {'name': 'Model 2: Skip+Spatial', 'quick': 0.5540, 'full': result2['final_metrics']['r2_z']},
    {'name': 'Model 3: Skip+AuxDepth', 'quick': 0.5362, 'full': result3['final_metrics']['r2_z']}
]

# Print comparison table
print("\n" + "="*80)
print("FULL TRAINING RESULTS COMPARISON")
print("="*80)
print(f"{'Model':<30} {'Quick Exp':<12} {'Full Train':<12} {'Improvement':<12} {'Status'}")
print("-"*80)

for r in all_results:
    improvement = r['full'] - r['quick']
    status = '🟢 WINNER' if r['full'] > 0.60 else '🟡 DECENT' if r['full'] > 0.55 else '🔴 POOR'
    print(f"{r['name']:<30} {r['quick']:.4f}       {r['full']:.4f}       {improvement:+.4f}      {status}")

# Find best model
best = max(all_results, key=lambda x: x['full'])
print("\n" + "="*80)
print(f"🏆 BEST MODEL: {best['name']}")
print(f"   R²_z: {best['full']:.4f}")
print(f"   Improvement over quick experiment: {best['full'] - best['quick']:+.4f}")
print("="*80)

In [None]:
# Plot training curves
fig, axes = plt.subplots(3, 3, figsize=(18, 12))
fig.suptitle('Full Training Results: All Models', fontsize=16, fontweight='bold')

results = [result1, result2, result3]
titles = ['Model 1: Skip+Focal+Aug', 'Model 2: Skip+Spatial', 'Model 3: Skip+AuxDepth']

for row, (result, title) in enumerate(zip(results, titles)):
    # Combine all stages
    all_epochs = []
    all_r2z = []
    all_e3d = []
    all_loss = []
    
    offset = 0
    for stage in ['A', 'B', 'C']:
        for entry in result['history'][stage]:
            all_epochs.append(entry['epoch'] + offset)
            all_r2z.append(entry['r2_z'])
            all_e3d.append(entry['e3d'])
            all_loss.append(entry['train_loss'])
        offset += len(result['history'][stage])
    
    # Plot R²_z
    axes[row, 0].plot(all_epochs, all_r2z, 'b-', linewidth=2)
    axes[row, 0].set_xlabel('Epoch')
    axes[row, 0].set_ylabel('R²_z')
    axes[row, 0].set_title(f'{title} - R²_z')
    axes[row, 0].grid(True, alpha=0.3)
    axes[row, 0].axhline(y=0.60, color='g', linestyle='--', alpha=0.5, label='Target (0.60)')
    axes[row, 0].legend()
    
    # Plot 3D Error
    axes[row, 1].plot(all_epochs, all_e3d, 'r-', linewidth=2)
    axes[row, 1].set_xlabel('Epoch')
    axes[row, 1].set_ylabel('3D Error (mm)')
    axes[row, 1].set_title(f'{title} - 3D Error')
    axes[row, 1].grid(True, alpha=0.3)
    
    # Plot Loss
    axes[row, 2].plot(all_epochs, all_loss, 'purple', linewidth=2)
    axes[row, 2].set_xlabel('Epoch')
    axes[row, 2].set_ylabel('Training Loss')
    axes[row, 2].set_title(f'{title} - Loss')
    axes[row, 2].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('full_training_curves.png', dpi=150, bbox_inches='tight')
print('✅ Saved training curves to full_training_curves.png')
plt.show()
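
Because the curves above concatenate Stages A, B, and C on one epoch axis, marking where each stage ends can make them easier to read. The helper below is a hypothetical addition: it takes a `history` dict shaped like the one returned by `train_full_model` and returns the global epoch at which each stage ends. The example history uses made-up epoch counts.

```python
def stage_boundaries(history, stages=('A', 'B', 'C')):
    """Map each stage to the global epoch index at which it ends."""
    boundaries = {}
    total = 0
    for stage in stages:
        total += len(history[stage])
        boundaries[stage] = total
    return boundaries


# Made-up history: Stage A and C stopped early, Stage B ran all 30 epochs
example_history = {
    'A': [{'epoch': e} for e in range(1, 13)],
    'B': [{'epoch': e} for e in range(1, 31)],
    'C': [{'epoch': e} for e in range(1, 26)],
}

bounds = stage_boundaries(example_history)
print(bounds)
```

Each boundary value could then be passed to `axes[row, col].axvline(x=...)` inside the plotting loop to draw a vertical marker at the stage transition.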

In [None]:
# Save results to CSV
import csv
from datetime import datetime

results_file = 'full_training_results.csv'

# Check if file exists
file_exists = Path(results_file).exists()

with open(results_file, 'a', newline='') as f:
    fieldnames = ['model_name', 'r2_z', 'e3d_mm', 'e2d_px', 'ang_deg', 'total_epochs', 'timestamp']
    writer = csv.DictWriter(f, fieldnames=fieldnames)
    
    if not file_exists:
        writer.writeheader()
    
    # Write all results
    for result in [result1, result2, result3]:
        total_epochs = sum(len(result['history'][stage]) for stage in ['A', 'B', 'C'])
        writer.writerow({
            'model_name': result['model_name'],
            'r2_z': f"{result['final_metrics']['r2_z']:.6f}",
            'e3d_mm': f"{result['final_metrics']['e3d']:.6f}",
            'e2d_px': f"{result['final_metrics']['e2d']:.6f}",
            'ang_deg': f"{result['final_metrics']['ang_deg']:.6f}",
            'total_epochs': total_epochs,
            'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        })

print(f'✅ Results saved to {results_file}')
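
Because the CSV is opened in append mode, repeated runs accumulate rows for the same model. A minimal sketch of reading the file back and keeping the best row per model follows. An in-memory sample with placeholder values stands in for `full_training_results.csv`, and `best_rows_per_model` is a hypothetical helper.

```python
import csv
import io

# Placeholder rows using the same columns as the results file (values are made up)
sample_csv = io.StringIO(
    "model_name,r2_z,e3d_mm,e2d_px,ang_deg,total_epochs,timestamp\n"
    "model_a,0.500000,10.000000,5.000000,3.000000,60,2025-01-01 00:00:00\n"
    "model_b,0.700000,8.000000,4.000000,2.000000,75,2025-01-01 01:00:00\n"
    "model_a,0.600000,9.000000,4.500000,2.500000,80,2025-01-02 00:00:00\n"
)


def best_rows_per_model(f):
    """Keep the highest-r2_z row for each model_name, sorted best first."""
    best = {}
    for row in csv.DictReader(f):
        name = row['model_name']
        score = float(row['r2_z'])  # values were written as formatted strings
        if name not in best or score > float(best[name]['r2_z']):
            best[name] = row
    return sorted(best.values(), key=lambda r: float(r['r2_z']), reverse=True)


ranked = best_rows_per_model(sample_csv)
for row in ranked:
    print(f"{row['model_name']:<10} {row['r2_z']}  ({row['total_epochs']} epochs)")
```

For the real file, replace `sample_csv` with `open('full_training_results.csv', newline='')`.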

## Extend Training (Add More Epochs)

Use these cells if you need to add more epochs to any stage

### Example 1: Add 20 more epochs to Stage C

Most common use case - extend the final fine-tuning stage

In [None]:
# Example: Continue training Model 1 Stage C for 20 more epochs
# (Only run this if you've already trained model1)

extended_result = continue_training(
    model=model1,  # The model you want to continue training
    train_loader=train_loader,
    val_loader=val_loader,
    model_name="skip_connections_focal_augmentation_FULL",  # Must match original name
    stage='C',  # Which stage to continue
    additional_epochs=20,  # How many more epochs
    use_focal=True,  # Same settings as original training
    use_auxiliary_depth=False,
    early_stopping_patience=5
)

if extended_result:
    print(f"\n✅ Extended training complete!")
    print(f"   Improvement: {extended_result['improvement']:+.4f}")
    print(f"   New R²_z: {extended_result['new_best']:.4f}")

### Example 2: Add epochs to any stage

You can extend Stage A, B, or C independently

In [None]:
# Example: Extend Stage B (intersection training) for Model 2

# extended_result_B = continue_training(
#     model=model2,
#     train_loader=train_loader,
#     val_loader=val_loader,
#     model_name="skip_connections_spatial_attention_FULL",
#     stage='B',  # Continue Stage B
#     additional_epochs=15,
#     use_focal=False,
#     use_auxiliary_depth=False,
#     early_stopping_patience=5
# )

# For Model 3 (with auxiliary depth):
# extended_result_C = continue_training(
#     model=model3,
#     train_loader=train_loader,
#     val_loader=val_loader,
#     model_name="skip_connections_auxiliary_depth_FULL",
#     stage='C',
#     additional_epochs=20,
#     use_focal=False,
#     use_auxiliary_depth=True,
#     w_depth_map={'A': 0.0, 'B': 0.5, 'C': 1.0},
#     early_stopping_patience=5
# )

print("Uncomment the code above to extend training for any model")

### Example 3: Train new model with custom epochs

Start fresh with different epoch counts

In [None]:
# Example: Train a model with more epochs from the start

# model_custom = SkipConnectionNet(backbone_name="resnet18", pretrained=True)

# result_custom = train_full_model(
#     model=model_custom,
#     train_loader=train_loader,
#     val_loader=val_loader,
#     model_name="skip_connections_EXTENDED",
#     stage_A_epochs=30,  # More epochs!
#     stage_B_epochs=50,
#     stage_C_epochs=60,
#     use_focal=True,
#     use_auxiliary_depth=False,
#     early_stopping_patience=7  # Higher patience for longer training
# )

print("Uncomment the code above to train with custom epoch counts")

### Tips for Adding More Epochs

**When to add more epochs:**
- ✅ Training curves still improving (not plateaued)
- ✅ Validation R²_z increasing steadily
- ✅ No signs of overfitting (train/val gap small)

**When NOT to add more epochs:**
- ❌ Validation R²_z plateaued for 5+ epochs (early stopping will handle this)
- ❌ Large train/val gap (overfitting)
- ❌ Training loss still decreasing but val loss increasing

**How many epochs to add:**
- **Small improvement needed (+0.01)**: Add 10-15 epochs
- **Moderate improvement (+0.02-0.03)**: Add 20-30 epochs
- **Still learning**: Add 40-50 epochs
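
The "still improving" versus "plateaued" criteria above can be checked roughly against a stage's history list. The sketch below is one possible heuristic, not a rule used elsewhere in this notebook: `still_improving` is a hypothetical helper, and the `window` and `min_gain` defaults simply mirror the early-stopping settings. The example histories are made up.

```python
def still_improving(history, window=5, min_gain=0.001, key='r2_z'):
    """True if the best score in the last `window` epochs beats the earlier best."""
    if len(history) <= window:
        return True  # too short to judge a plateau
    earlier_best = max(entry[key] for entry in history[:-window])
    recent_best = max(entry[key] for entry in history[-window:])
    return recent_best > earlier_best + min_gain


# Made-up histories in the same shape as result['history'][stage]
rising = [{'r2_z': 0.40 + 0.01 * i} for i in range(12)]
flat = [{'r2_z': v} for v in
        [0.40, 0.50, 0.55, 0.56, 0.56, 0.56, 0.559, 0.56, 0.558, 0.56]]

print(still_improving(rising))
print(still_improving(flat))
```

A `True` result suggests extending the stage with `continue_training` may be worthwhile; a `False` result suggests the extra epochs are unlikely to help without other changes.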

**Important Notes:**
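1. The `continue_training()` function automatically:
   - Loads the best checkpoint from the specified stage
   - Re-applies that stage's parameter freezing and loss, restores the optimizer state when possible, and starts a new cosine schedule over `additional_epochs`
   - Saves new checkpoints if performance improves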

2. You can continue training multiple times:
   ```python
   # First extension: +20 epochs
   continue_training(model, ..., additional_epochs=20)
   
   # Second extension: +10 more epochs
   continue_training(model, ..., additional_epochs=10)
   ```

3. Model checkpoints are preserved:
   - Old best: Kept as backup
   - New best: Saved if performance improves
   - You can always load previous checkpoints manually


## Next Steps

Based on your full training results:

### If best model achieves R²_z > 0.60:
1. **Test on held-out test set** - Confirm the result generalizes beyond the validation split
2. **Create inference pipeline** - Package the model for production use
3. **Deploy the model** - Proceed once the test-set results hold up

### If best model is 0.55 < R²_z < 0.60:
1. **Build ensemble** - Combine top 2-3 models for 0.02-0.04 boost
2. **Try ResNet34/50** - Stronger backbone may help
3. **Hyperparameter tuning** - Fine-tune learning rates, weights
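
The ensemble suggestion above can be as simple as averaging the intersection predictions of several trained models. A minimal sketch in plain Python with made-up coordinates follows; `average_predictions` is a hypothetical helper, and whether averaging actually improves R²_z for these models is untested.

```python
def average_predictions(per_model_preds):
    """Element-wise mean of equally shaped [x, y, z] predictions from several models."""
    n_models = len(per_model_preds)
    n_samples = len(per_model_preds[0])
    return [
        [sum(preds[i][d] for preds in per_model_preds) / n_models for d in range(3)]
        for i in range(n_samples)
    ]


# Made-up intersection predictions from three models for two samples
model_a = [[1.0, 2.0, 30.0], [4.0, 5.0, 60.0]]
model_b = [[1.2, 2.2, 33.0], [4.2, 5.2, 63.0]]
model_c = [[0.8, 1.8, 27.0], [3.8, 4.8, 57.0]]

ensemble = average_predictions([model_a, model_b, model_c])
print(ensemble)
```

With tensors, the same average can be written as `torch.stack(list_of_predictions).mean(dim=0)`.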

### If best model < 0.55:
1. **Investigate training curves** - Look for overfitting/underfitting
2. **Check data quality** - Verify labels and augmentation
3. **Try unexplored combinations** - E.g., skip+spatial+focal+aug

### Model Checkpoints
All trained models are saved in: `checkpoints/`
- `stage_A_best.pt` - Best axis predictor
- `stage_B_best.pt` - Best intersection predictor  
- `stage_C_best.pt` - Best end-to-end model (use this for deployment)
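
A minimal sketch of restoring a checkpoint for inference follows. It uses the `model_state_dict`, `epoch`, and `metrics` keys that appear in the loading code earlier in this notebook. A tiny `nn.Linear` and a temporary directory stand in for the real network and the `checkpoints/` folder, and the stored values are placeholders, so treat this as an illustration rather than the deployment procedure.

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn

# Stand-in for the trained network; a real model class would be used instead
model = nn.Linear(4, 3)

with tempfile.TemporaryDirectory() as tmp:
    ckpt_path = Path(tmp) / 'stage_C_best.pt'

    # Write a checkpoint with the keys used in this notebook (placeholder values)
    torch.save({
        'epoch': 7,
        'model_state_dict': model.state_dict(),
        'metrics': {'r2_z': 0.0},
    }, ckpt_path)

    # Restore the weights into a fresh instance and switch to inference mode
    restored = nn.Linear(4, 3)
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    restored.load_state_dict(checkpoint['model_state_dict'])
    restored.eval()

x = torch.zeros(1, 4)
with torch.no_grad():
    same_output = torch.allclose(model(x), restored(x))
print(same_output)
```

The optimizer state is not needed for inference, so only `model_state_dict` is loaded here.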
