# MCResNet Training on NYU Depth V2 - Google Colab

**Complete end-to-end training pipeline for Google Colab with A100 GPU**

---

## 📋 Checklist Before Running:

- [ ] **Enable A100 GPU:** Runtime → Change runtime type → Hardware accelerator: GPU → GPU type: A100
- [ ] **Mount Google Drive:** Your code and dataset will be stored on Drive
- [ ] **Upload dataset to Drive:** `MyDrive/datasets/nyu_depth_v2_labeled.mat` (or we'll download it)
- [ ] **Expected Runtime:** ~2-3 hours for 90 epochs

---

## 🎯 What This Notebook Does:

1. ✅ Verify A100 GPU is available
2. ✅ Mount Google Drive
3. ✅ Clone your repository to local disk (fast I/O)
4. ✅ Download/copy NYU Depth V2 to local disk (10-20x faster than Drive)
5. ✅ Install dependencies
6. ✅ Train MCResNet with all optimizations
7. ✅ Save checkpoints to Drive (persistent storage)
8. ✅ Generate training curves and analysis

---

**Let's get started!** 🚀

## 1. Environment Setup & GPU Verification

In [None]:
# Check GPU availability and specs
import torch
import subprocess

print("=" * 60)
print("GPU VERIFICATION")
print("=" * 60)

# Check PyTorch and CUDA
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
    
    # Check if it's A100
    gpu_name = torch.cuda.get_device_name(0)
    if 'A100' in gpu_name:
        print("\n✅ A100 GPU detected - PERFECT for training!")
    elif 'V100' in gpu_name:
        print("\n✅ V100 GPU detected - Good for training (slower than A100)")
    elif 'T4' in gpu_name:
        print("\n⚠️  T4 GPU detected - Will be slower, consider upgrading to A100")
    else:
        print(f"\n⚠️  GPU: {gpu_name} - Consider using A100 for best performance")
else:
    print("\n❌ NO GPU DETECTED!")
    print("Please enable GPU: Runtime → Change runtime type → Hardware accelerator: GPU")
    raise RuntimeError("GPU is required for training")

print("\n" + "=" * 60)

In [None]:
# Detailed GPU info
!nvidia-smi

## 2. Mount Google Drive

In [None]:
from google.colab import drive
import os
from pathlib import Path

# Mount Google Drive
drive.mount('/content/drive')

print("\n✅ Google Drive mounted successfully!")
print(f"\nDrive contents:")
!ls -la /content/drive/MyDrive/ | head -20

## 3. Clone Repository to Local Disk (Fast I/O)

**Important:** We clone to `/content/` (local SSD) instead of Drive for 10-20x faster I/O

**Default:** Clone from GitHub (recommended - always gets latest code)
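
To check the local-versus-Drive speed claim on your own session, a rough sequential-read benchmark is enough. The sketch below is an illustration only: `read_throughput_mb_s` is a hypothetical helper, not part of the repository, and the demo reads a temporary file rather than a real dataset. Repeated reads of the same file are served from the OS page cache, so only the first read of a file gives a meaningful number.

```python
import os
import tempfile
import time


def read_throughput_mb_s(path, chunk_size=8 * 1024 * 1024):
    """Read a file sequentially and return the observed throughput in MB/s."""
    total_bytes = 0
    start = time.perf_counter()
    with open(path, "rb") as f:
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                break
            total_bytes += len(chunk)
    elapsed = time.perf_counter() - start
    # Guard against a zero elapsed time on very small files
    return (total_bytes / 1024**2) / max(elapsed, 1e-9)


# Demo on a temporary 16 MB file (hypothetical stand-in for a dataset file)
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(os.urandom(16 * 1024 * 1024))
    demo_path = tmp.name

demo_mb_s = read_throughput_mb_s(demo_path)
print(f"Sequential read: {demo_mb_s:.0f} MB/s")
os.remove(demo_path)
```

In Colab, call the helper once on a file under `/content/drive/MyDrive/` and once on a copy under `/content/` to compare the two locations.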

In [None]:
import os
from pathlib import Path

# Configuration
PROJECT_NAME = "Multi-Stream-Neural-Networks"
GITHUB_REPO = "https://github.com/clingergab/Multi-Stream-Neural-Networks.git"  # UPDATE THIS
LOCAL_REPO_PATH = f"/content/{PROJECT_NAME}"  # Local copy for fast I/O

print("=" * 60)
print("REPOSITORY SETUP")
print("=" * 60)

# Ensure we're in a valid directory
os.chdir('/content')
print(f"Starting in: {os.getcwd()}")

# Check if repo already exists (same session, rerunning cell)
if Path(LOCAL_REPO_PATH).exists() and Path(f"{LOCAL_REPO_PATH}/.git").exists():
    print(f"\n📁 Repo already exists: {LOCAL_REPO_PATH}")
    print(f"🔄 Pulling latest changes...")
    
    os.chdir(LOCAL_REPO_PATH)
    !git pull
    print("✅ Repo updated")

# Clone from GitHub (first run)
else:
    # Remove old incomplete copy if exists
    if Path(LOCAL_REPO_PATH).exists():
        print(f"\n🗑️  Removing incomplete repo copy...")
        !rm -rf {LOCAL_REPO_PATH}
    
    print(f"\n🔄 Cloning from GitHub...")
    print(f"   Repo: {GITHUB_REPO}")
    print(f"   Destination: {LOCAL_REPO_PATH}")
    
    !git clone {GITHUB_REPO} {LOCAL_REPO_PATH}
    
    # Verify clone succeeded
    if not Path(LOCAL_REPO_PATH).exists():
        raise RuntimeError(f"Failed to clone repository to {LOCAL_REPO_PATH}")
    
    print("✅ Repo cloned successfully")
    os.chdir(LOCAL_REPO_PATH)

# Verify repo structure
print(f"\n📂 Repository structure:")
!ls -la {LOCAL_REPO_PATH}

print(f"\n✅ Working directory: {os.getcwd()}")

## 4. Install Dependencies

In [None]:
# Install required packages
print("Installing dependencies...")

!pip install -q h5py tqdm matplotlib seaborn

# Verify installations
import h5py
import tqdm
import matplotlib
import seaborn

print("✅ All dependencies installed!")
print(f"   h5py: {h5py.__version__}")
print(f"   matplotlib: {matplotlib.__version__}")

## 5. Download NYU Depth V2 Dataset to Local Disk

**Performance Note:** Local disk I/O is ~10-20x faster than Drive!

**Options:**
- **Option A:** Copy from Drive if you already have it there
- **Option B:** Download directly to local disk (~2.8 GB, takes 2-3 min)
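
Whichever route supplies the file, an interrupted copy or a download that returned an error page can leave something at the dataset path that `h5py` cannot open. The sketch below is a cheap sanity check, assuming the labeled file is an HDF5-based MATLAB v7.3 file (which is why the notebook reads it with `h5py`). `has_hdf5_signature` is a hypothetical helper, and the demo files are synthetic.

```python
import os
import tempfile

HDF5_SIGNATURE = b"\x89HDF\r\n\x1a\n"
# The HDF5 format allows the superblock at byte 0 or at 512, 1024, 2048, ...
# MATLAB v7.3 files place a 512-byte text header in front of it.
CANDIDATE_OFFSETS = (0, 512, 1024, 2048)


def has_hdf5_signature(path):
    """Return True if an HDF5 signature is found at one of the usual offsets."""
    with open(path, "rb") as f:
        for offset in CANDIDATE_OFFSETS:
            f.seek(offset)
            if f.read(len(HDF5_SIGNATURE)) == HDF5_SIGNATURE:
                return True
    return False


# Demo with two synthetic files: a MAT-like header and an HTML error page
with tempfile.TemporaryDirectory() as tmp:
    good = os.path.join(tmp, "good.mat")
    bad = os.path.join(tmp, "bad.mat")
    with open(good, "wb") as f:
        f.write(b"MATLAB 7.3 MAT-file".ljust(512, b" ") + HDF5_SIGNATURE)
    with open(bad, "wb") as f:
        f.write(b"<html>404 Not Found</html>")
    good_ok = has_hdf5_signature(good)
    bad_ok = has_hdf5_signature(bad)

print(good_ok, bad_ok)
```

Calling `has_hdf5_signature(LOCAL_DATASET_PATH)` before creating the dataloaders catches the most common failure early. A passing check does not prove the file is complete, so comparing the reported size against the expected ~2.8 GB is still worthwhile.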

In [None]:
import urllib.request
from pathlib import Path
from tqdm import tqdm

# Paths
DRIVE_DATASET_PATH = "/content/drive/MyDrive/datasets/nyu_depth_v2_labeled.mat"
LOCAL_DATASET_PATH = "/content/nyu_depth_v2_labeled.mat"  # Local disk (FAST)
DATASET_URL = "http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/nyu_depth_v2_labeled.mat"

print("=" * 60)
print("NYU DEPTH V2 DATASET SETUP")
print("=" * 60)

# Check if already on local disk
if Path(LOCAL_DATASET_PATH).exists():
    print(f"✅ Dataset already on local disk: {LOCAL_DATASET_PATH}")
    print(f"   Size: {Path(LOCAL_DATASET_PATH).stat().st_size / 1024**3:.2f} GB")

# Option A: Copy from Drive (if available)
elif Path(DRIVE_DATASET_PATH).exists():
    print(f"📁 Found dataset on Drive: {DRIVE_DATASET_PATH}")
    print(f"📥 Copying to local disk for 10-20x faster training...")
    print(f"   This takes ~2-3 minutes but saves 60+ minutes during training!")
    
    !cp {DRIVE_DATASET_PATH} {LOCAL_DATASET_PATH}
    
    print(f"✅ Dataset copied to local disk")
    print(f"   Size: {Path(LOCAL_DATASET_PATH).stat().st_size / 1024**3:.2f} GB")

# Option B: Download from internet
else:
    print(f"📥 Downloading NYU Depth V2 dataset (~2.8 GB)...")
    print(f"   URL: {DATASET_URL}")
    print(f"   Destination: {LOCAL_DATASET_PATH}")
    print(f"   This will take ~2-3 minutes")
    
    # Download with progress bar
    class DownloadProgressBar(tqdm):
        def update_to(self, b=1, bsize=1, tsize=None):
            if tsize is not None:
                self.total = tsize
            self.update(b * bsize - self.n)
    
    with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc='NYU Depth V2') as t:
        urllib.request.urlretrieve(DATASET_URL, LOCAL_DATASET_PATH, reporthook=t.update_to)
    
    print(f"\n✅ Download complete!")
    print(f"   Size: {Path(LOCAL_DATASET_PATH).stat().st_size / 1024**3:.2f} GB")
    
    # Optionally save to Drive for future use
    save_to_drive = input("\nSave dataset to Drive for future sessions? (y/N): ")
    if save_to_drive.lower() == 'y':
        print("Saving to Drive...")
        !mkdir -p /content/drive/MyDrive/datasets/
        !cp {LOCAL_DATASET_PATH} {DRIVE_DATASET_PATH}
        print("✅ Saved to Drive")

print("\n" + "=" * 60)
print(f"Dataset ready at: {LOCAL_DATASET_PATH}")
print("=" * 60)

## 6. Setup Python Path & Import MCResNet

In [None]:
import sys
import os

# Remove cached modules
modules_to_reload = [k for k in sys.modules.keys() if k.startswith('src.')]
for module in modules_to_reload:
    del sys.modules[module]
    
# Add project to Python path
project_root = '/content/Multi-Stream-Neural-Networks'
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Verify project structure
print("Project structure:")
!ls -la {project_root}/src/models/

# Import MCResNet
print("\nImporting MCResNet...")
from src.models.multi_channel.mc_resnet import mc_resnet18, mc_resnet50
from src.data_utils.nyu_depth_dataset import create_nyu_dataloaders

print("✅ MCResNet imported successfully!")

## 7. Load NYU Depth V2 Dataset

In [None]:
# DEBUG: Check actual HDF5 structure
import h5py

print("=" * 60)
print("DEBUG: NYU DEPTH V2 DATASET STRUCTURE")
print("=" * 60)

with h5py.File('/content/nyu_depth_v2_labeled.mat', 'r') as f:
    print("\nAvailable keys:")
    for key in f.keys():
        print(f"  {key}")
    
    print("\nDataset shapes:")
    print(f"  images: {f['images'].shape}")
    print(f"  depths: {f['depths'].shape}")
    print(f"  labels: {f['labels'].shape}")
    
    # Check if scenes exists
    if 'scenes' in f:
        print(f"  scenes: {f['scenes'].shape}")
    
    # Sample first image to check format
    print("\nSample data inspection:")
    print(f"  images[0:3, 0:10, 0:10, 0] shape: {f['images'][0:3, 0:10, 0:10, 0].shape}")
    print(f"  images dtype: {f['images'].dtype}")
    print(f"  depths dtype: {f['depths'].dtype}")

print("\n" + "=" * 60)

In [None]:
print("=" * 60)
print("LOADING NYU DEPTH V2 DATASET")
print("=" * 60)

# Dataset configuration
DATASET_CONFIG = {
    'dataset_path': '/content/nyu_depth_v2_labeled.mat',
    'batch_size': 128,  # A100 can handle this with AMP
    'num_workers': 2,
    'target_size': (224, 224),
    'num_classes': 27   # NYU Depth V2 has 27 scene types (labels 0-26)
}

print(f"Configuration:")
for key, value in DATASET_CONFIG.items():
    print(f"  {key}: {value}")

print(f"\nLoading dataset from: {DATASET_CONFIG['dataset_path']}")

# Create dataloaders
train_loader, val_loader = create_nyu_dataloaders(
    h5_file_path=DATASET_CONFIG['dataset_path'],
    batch_size=DATASET_CONFIG['batch_size'],
    num_workers=DATASET_CONFIG['num_workers'],
    target_size=DATASET_CONFIG['target_size'],
    num_classes=DATASET_CONFIG['num_classes']
)

print(f"\n✅ Dataset loaded successfully!")
print(f"\nDataset Statistics:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")
print(f"  Train samples: {len(train_loader.dataset)}")
print(f"  Val samples: {len(val_loader.dataset)}")
print(f"  Batch size: {DATASET_CONFIG['batch_size']}")

# Test loading a batch
print(f"\nTesting batch loading...")
rgb_batch, depth_batch, label_batch = next(iter(train_loader))
print(f"  RGB shape: {rgb_batch.shape}")
print(f"  Depth shape: {depth_batch.shape}")
print(f"  Labels shape: {label_batch.shape}")
print(f"  Labels min: {label_batch.min().item()}, max: {label_batch.max().item()}")

print("\n" + "=" * 60)

## 8. Visualize Sample Data

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Visualize some samples
fig, axes = plt.subplots(2, 4, figsize=(14, 7))

for i in range(4):
    rgb = rgb_batch[i].cpu()
    depth = depth_batch[i].cpu()
    label = label_batch[i].item()
    
    # Denormalize for visualization
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    rgb_vis = rgb * std + mean
    rgb_vis = torch.clamp(rgb_vis, 0, 1)
    
    # Plot RGB
    axes[0, i].imshow(rgb_vis.permute(1, 2, 0))
    axes[0, i].set_title(f"RGB - Class {label}", fontsize=10)
    axes[0, i].axis('off')
    
    # Plot Depth
    axes[1, i].imshow(depth.squeeze(), cmap='viridis')
    axes[1, i].set_title(f"Depth - Class {label}", fontsize=10)
    axes[1, i].axis('off')

plt.suptitle('NYU Depth V2 Sample Data (RGB + Depth)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("✅ Sample visualization complete!")

## 9. Create & Compile MCResNet Model

In [None]:
print("=" * 60)
print("MODEL CREATION")
print("=" * 60)

# Model configuration with stream-specific optimization
MODEL_CONFIG = {
    'architecture': 'resnet18',  # or 'resnet50' for better accuracy
    'num_classes': 27,  # NYU Depth V2 has 27 scene types (labels 0-26)
    'stream1_channels': 3,  # RGB
    'stream2_channels': 1,  # Depth
    'fusion_type': 'weighted',  # 'concat', 'weighted', or 'gated'
    'dropout_p': 0.3,  # Dropout for regularization
    'device': 'cuda',
    'use_amp': True  # Automatic Mixed Precision (2x faster on A100)
}

print(f"Configuration:")
for key, value in MODEL_CONFIG.items():
    print(f"  {key}: {value}")

# Create model
print(f"\nCreating MCResNet-{MODEL_CONFIG['architecture'].upper()}...")

if MODEL_CONFIG['architecture'] == 'resnet18':
    model = mc_resnet18(
        num_classes=MODEL_CONFIG['num_classes'],
        stream1_input_channels=MODEL_CONFIG['stream1_channels'],
        stream2_input_channels=MODEL_CONFIG['stream2_channels'],
        fusion_type=MODEL_CONFIG['fusion_type'],
        dropout_p=MODEL_CONFIG['dropout_p'],
        device=MODEL_CONFIG['device'],
        use_amp=MODEL_CONFIG['use_amp']
    )
elif MODEL_CONFIG['architecture'] == 'resnet50':
    model = mc_resnet50(
        num_classes=MODEL_CONFIG['num_classes'],
        stream1_input_channels=MODEL_CONFIG['stream1_channels'],
        stream2_input_channels=MODEL_CONFIG['stream2_channels'],
        fusion_type=MODEL_CONFIG['fusion_type'],
        dropout_p=MODEL_CONFIG['dropout_p'],
        device=MODEL_CONFIG['device'],
        use_amp=MODEL_CONFIG['use_amp']
    )

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
fusion_params = sum(p.numel() for p in model.fusion.parameters())

print(f"\n✅ Model created successfully!")
print(f"\nModel Statistics:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Fusion parameters: {fusion_params:,}")
print(f"  Model size: {total_params * 4 / 1024**2:.2f} MB (FP32)")
print(f"  Fusion strategy: {model.fusion_strategy}")
print(f"  Device: {MODEL_CONFIG['device']}")
print(f"  AMP enabled: {MODEL_CONFIG['use_amp']}")

print("\n" + "=" * 60)

In [None]:
# Training configuration
TRAINING_CONFIG = {
    'optimizer': 'sgd',
    'learning_rate': 0.1,  # Standard for ImageNet-style training
    'weight_decay': 1e-4,
    'momentum': 0.9,
    'loss': 'cross_entropy',
    'scheduler': 'cosine'
}

print("Compiling model...")
print(f"\nTraining configuration:")
for key, value in TRAINING_CONFIG.items():
    print(f"  {key}: {value}")

# Compile
model.compile(
    optimizer=TRAINING_CONFIG['optimizer'],
    learning_rate=TRAINING_CONFIG['learning_rate'],
    weight_decay=TRAINING_CONFIG['weight_decay'],
    momentum=TRAINING_CONFIG['momentum'],
    loss=TRAINING_CONFIG['loss'],
    scheduler=TRAINING_CONFIG['scheduler']
)

print("\n✅ Model compiled successfully!")

In [None]:
# Compile model with standard or stream-specific optimization (optional but recommended).
# Note: this re-compiles the model, replacing the SGD setup from the previous cell.
print("=" * 60)
print("MODEL COMPILATION")
print("=" * 60)

# Option 1: Standard optimization (same LR/WD for all parameters)
# Use this if you don't know which stream needs different treatment yet

STANDARD_CONFIG = {
    'optimizer': 'adamw',
    'learning_rate': 1e-4,
    'weight_decay': 2e-2,
    'loss': 'cross_entropy',
    'scheduler': 'cosine'
}

# Option 2: Stream-specific optimization (RECOMMENDED for addressing pathway imbalance)
# Use this if monitoring shows one stream is overfitting or not learning

STREAM_SPECIFIC_CONFIG = {
    'optimizer': 'adamw',
    'learning_rate': 1e-4,      # Base LR for shared params
    'weight_decay': 2e-2,        # Base weight decay
    # Stream-specific settings (uncomment and adjust based on monitoring)
    # 'stream1_lr': 5e-4,         # 5x higher LR if RGB not learning
    # 'stream1_weight_decay': 1e-3,  # Lighter regularization if RGB underfitting
    # 'stream2_lr': 5e-5,         # Lower LR if Depth overfitting
    # 'stream2_weight_decay': 5e-2,  # Heavier regularization if Depth overfitting
    'loss': 'cross_entropy',
    'scheduler': 'cosine'
}

# Choose which config to use
USE_STREAM_SPECIFIC = False  # Set to True after first training run if needed
config = STREAM_SPECIFIC_CONFIG if USE_STREAM_SPECIFIC else STANDARD_CONFIG

print(f"Optimization strategy: {'Stream-Specific' if USE_STREAM_SPECIFIC else 'Standard'}")
print(f"\nConfiguration:")
for key, value in config.items():
    if value is not None and not key.startswith('_'):
        print(f"  {key}: {value}")

# Compile
model.compile(**config)

print("\n✅ Model compiled successfully!")

# Show parameter groups if using stream-specific optimization
if USE_STREAM_SPECIFIC and hasattr(model.optimizer, 'param_groups'):
    print(f"\nParameter groups created: {len(model.optimizer.param_groups)}")
    for i, group in enumerate(model.optimizer.param_groups):
        print(f"  Group {i}: LR={group['lr']:.2e}, WD={group['weight_decay']:.2e}")

print("\n" + "=" * 60)

In [None]:
# Import stream monitoring utilities
from src.models.utils import StreamMonitor

print("=" * 60)
print("STREAM MONITORING SETUP")
print("=" * 60)

# Create stream monitor
monitor = StreamMonitor(model)

print(f"\n✅ Stream monitor created!")
print(f"\nMonitoring capabilities:")
print(f"  • Gradient tracking per stream")
print(f"  • Overfitting detection per stream")
print(f"  • Weight evolution tracking")
print(f"  • Automatic hyperparameter recommendations")

# Count stream parameters by examining model structure
stream1_count = 0
stream2_count = 0
shared_count = 0

for name, param in model.named_parameters():
    if 'stream1' in name:
        stream1_count += 1
    elif 'stream2' in name:
        stream2_count += 1
    else:
        shared_count += 1

print(f"\nStream parameter counts:")
print(f"  Stream1 (RGB): {stream1_count} parameter tensors")
print(f"  Stream2 (Depth): {stream2_count} parameter tensors")
print(f"  Shared: {shared_count} parameter tensors")

print("\n" + "=" * 60)

## 10a. Setup Stream Monitoring 🔍

**Stream monitoring helps you:**
- Track which stream (RGB vs Depth) is learning better
- Detect which stream is overfitting more
- Get automatic recommendations for hyperparameter adjustments
- Monitor gradient flow and weight evolution per stream
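
The "which stream is overfitting more" decision used later in the training loop is a simple ratio test: one stream is flagged when its overfitting score exceeds 1.5 times the other's. The sketch below isolates that rule. The `overfitting_gap` score (train accuracy minus validation accuracy, floored at zero) is an assumption made for illustration and is not necessarily how `StreamMonitor` computes its scores; both function names are hypothetical.

```python
def overfitting_gap(train_acc, val_acc):
    """Assumed score: how far training accuracy runs ahead of validation accuracy."""
    return max(0.0, train_acc - val_acc)


def compare_stream_overfitting(stream1_score, stream2_score, ratio=1.5):
    """Apply the 1.5x ratio rule from the training loop to two stream scores."""
    if stream1_score > stream2_score * ratio:
        return "stream1"
    if stream2_score > stream1_score * ratio:
        return "stream2"
    return "balanced"


# Example: RGB stream at 90% train / 60% val, Depth stream at 80% train / 70% val
rgb_score = overfitting_gap(0.90, 0.60)
depth_score = overfitting_gap(0.80, 0.70)
verdict = compare_stream_overfitting(rgb_score, depth_score)
print(f"RGB gap={rgb_score:.2f}, Depth gap={depth_score:.2f}, verdict={verdict}")
```

A "stream1" verdict would point toward the stream-specific settings for RGB (lower learning rate or heavier weight decay), and "stream2" toward the same adjustments for Depth.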

## 10. Test Forward Pass

In [None]:
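# Test forward pass with detailed debugging
# Note: CUDA_LAUNCH_BLOCKING is read when CUDA initializes. Earlier cells have
# already used the GPU, so for synchronous error reporting set this variable in
# the first cell and restart the runtime.
print("Testing forward pass with CUDA_LAUNCH_BLOCKING for better error messages...")

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # Only effective if set before CUDA is initialized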

model.eval()
with torch.no_grad():
    rgb_test, depth_test, labels_test = next(iter(train_loader))
    
    print(f"\nInput validation:")
    print(f"  RGB shape: {rgb_test.shape}, dtype: {rgb_test.dtype}")
    print(f"  RGB min: {rgb_test.min():.4f}, max: {rgb_test.max():.4f}")
    print(f"  RGB has NaN: {torch.isnan(rgb_test).any()}")
    print(f"  RGB has Inf: {torch.isinf(rgb_test).any()}")
    
    print(f"\n  Depth shape: {depth_test.shape}, dtype: {depth_test.dtype}")
    print(f"  Depth min: {depth_test.min():.4f}, max: {depth_test.max():.4f}")
    print(f"  Depth has NaN: {torch.isnan(depth_test).any()}")
    print(f"  Depth has Inf: {torch.isinf(depth_test).any()}")
    
    print(f"\n  Labels shape: {labels_test.shape}, dtype: {labels_test.dtype}")
    print(f"  Labels min: {labels_test.min()}, max: {labels_test.max()}")
    print(f"  Labels unique: {torch.unique(labels_test).tolist()}")
    
    print("\nRunning forward pass...")
    rgb_cuda = rgb_test.to('cuda')
    depth_cuda = depth_test.to('cuda')
    
    try:
        outputs = model(rgb_cuda, depth_cuda)
        print(f"  ✅ Forward pass successful!")
        print(f"  Output shape: {outputs.shape}")
        print(f"  Output min: {outputs.min():.4f}, max: {outputs.max():.4f}")
        
        _, predictions = torch.max(outputs, 1)
        print(f"\nSample predictions: {predictions.cpu().numpy()[:10]}")
        print(f"Ground truth: {labels_test.numpy()[:10]}")
        
    except Exception as e:
        print(f"\n❌ Forward pass failed!")
        print(f"Error: {e}")
        print(f"\nThis is likely a model architecture issue, not a data issue.")
        print(f"Possible causes:")
        print(f"  1. BatchNorm running stats issue")
        print(f"  2. Invalid tensor operations in model")
        print(f"  3. Memory corruption")
        raise

print("\n✅ Forward pass test complete!")

## 11. Setup Checkpoint Directory
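
The training cells that follow read a `checkpoint_dir` variable that no earlier cell defines. A minimal sketch of one way to create it is shown here. The `checkpoints/<run_name>` layout, the timestamped run name, and the `make_checkpoint_dir` helper are all assumptions for illustration, not taken from the repository. The helper prefers Drive when it is mounted and falls back to local disk otherwise.

```python
import os
import tempfile
from datetime import datetime
from pathlib import Path


def make_checkpoint_dir(drive_root="/content/drive/MyDrive",
                        fallback_root="/content",
                        run_name=None):
    """Create and return a checkpoint directory, preferring Drive when mounted."""
    # Assumed naming scheme: one timestamped folder per run
    run_name = run_name or datetime.now().strftime("run_%Y%m%d_%H%M%S")
    root = Path(drive_root) if Path(drive_root).is_dir() else Path(fallback_root)
    path = root / "checkpoints" / run_name
    path.mkdir(parents=True, exist_ok=True)
    return str(path)


# Demo in a temporary directory; the "Drive" path is missing, so the fallback is used
with tempfile.TemporaryDirectory() as tmp:
    demo_dir = make_checkpoint_dir(
        drive_root=os.path.join(tmp, "no_drive"),
        fallback_root=tmp,
        run_name="demo",
    )
    demo_exists = os.path.isdir(demo_dir)
    demo_rel = os.path.relpath(demo_dir, tmp)

print(demo_rel)
```

In the notebook itself, `checkpoint_dir = make_checkpoint_dir()` before the training cell would place checkpoints on Drive so they survive a runtime disconnect.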

In [None]:
import torch
from tqdm import tqdm
import numpy as np

print("=" * 60)
print("TRAINING WITH COMPREHENSIVE STREAM MONITORING")
print("=" * 60)

# Training configuration
TRAIN_CONFIG = {
    'epochs': 100,
    'grad_clip_norm': 5.0,
    'early_stopping': True,
    'patience': 15,
    'min_delta': 0.001,
    'monitor': 'val_accuracy',
    'restore_best_weights': True,
    'save_path': f"{checkpoint_dir}/best_model.pt",
    # Monitoring settings
    'monitor_gradients_every': 1,  # Check gradients every N epochs
    'monitor_overfitting_every': 1,  # Check overfitting every N epochs
    'display_recommendations': True  # Show automatic recommendations
}

print(f"Configuration:")
for key, value in TRAIN_CONFIG.items():
    print(f"  {key}: {value}")

# Initialize monitoring storage
monitoring_history = {
    'stream1_grad_norms': [],
    'stream2_grad_norms': [],
    'stream1_overfitting_scores': [],
    'stream2_overfitting_scores': [],
    'stream1_weight_norms': [],
    'stream2_weight_norms': [],
    'recommendations': []
}

print("\n" + "=" * 60)
print(f"Training will take approximately 2-3 hours on A100")
print(f"Stream monitoring active - you'll see detailed metrics each epoch")
print("=" * 60 + "\n")

# Custom training loop with integrated monitoring
best_val_acc = 0.0
patience_counter = 0

for epoch in range(TRAIN_CONFIG['epochs']):
    print(f"\n{'='*60}")
    print(f"Epoch {epoch + 1}/{TRAIN_CONFIG['epochs']}")
    print(f"{'='*60}")
    
    # ===== TRAINING PHASE =====
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0
    
    # Track gradients on first batch
    first_batch_gradients = None
    
    for batch_idx, (rgb, depth, labels) in enumerate(tqdm(train_loader, desc="Training")):
        rgb, depth, labels = rgb.cuda(), depth.cuda(), labels.cuda()
        
        # Forward
        model.optimizer.zero_grad()
        outputs = model(rgb, depth)
        loss = model.criterion(outputs, labels)
        
        # Backward
        loss.backward()
        
        # Monitor gradients on first batch
        if batch_idx == 0 and (epoch % TRAIN_CONFIG['monitor_gradients_every'] == 0):
            first_batch_gradients = monitor.compute_stream_gradients()
        
        # Gradient clipping
        if TRAIN_CONFIG['grad_clip_norm'] is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), TRAIN_CONFIG['grad_clip_norm'])
        
        model.optimizer.step()
        
        # Track metrics
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        train_total += labels.size(0)
        train_correct += predicted.eq(labels).sum().item()
    
    # Compute training metrics
    train_loss = train_loss / len(train_loader)
    train_acc = train_correct / train_total
    
    # ===== VALIDATION PHASE =====
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    
    with torch.no_grad():
        for rgb, depth, labels in tqdm(val_loader, desc="Validation"):
            rgb, depth, labels = rgb.cuda(), depth.cuda(), labels.cuda()
            outputs = model(rgb, depth)
            loss = model.criterion(outputs, labels)
            
            val_loss += loss.item()
            _, predicted = outputs.max(1)
            val_total += labels.size(0)
            val_correct += predicted.eq(labels).sum().item()
    
    val_loss = val_loss / len(val_loader)
    val_acc = val_correct / val_total
    
    # ===== STREAM MONITORING =====
    print(f"\n📊 Epoch {epoch + 1} Results:")
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%")
    print(f"  Val Loss: {val_loss:.4f}   | Val Acc: {val_acc*100:.2f}%")
    
    # Monitor gradients
    if first_batch_gradients is not None:
        monitoring_history['stream1_grad_norms'].append(first_batch_gradients['stream1_grad_norm'])
        monitoring_history['stream2_grad_norms'].append(first_batch_gradients['stream2_grad_norm'])
        
        print(f"\n🔍 Gradient Monitoring:")
        print(f"  Stream1 (RGB) grad norm: {first_batch_gradients['stream1_grad_norm']:.6f}")
        print(f"  Stream2 (Depth) grad norm: {first_batch_gradients['stream2_grad_norm']:.6f}")
        print(f"  Stream1/Stream2 ratio: {first_batch_gradients['stream1_to_stream2_ratio']:.4f}")
        
        # Check for gradient issues
        if first_batch_gradients['stream1_grad_norm'] < 1e-6:
            print(f"  ⚠️  Stream1 gradients very small - may not be learning!")
        if first_batch_gradients['stream2_grad_norm'] < 1e-6:
            print(f"  ⚠️  Stream2 gradients very small - may not be learning!")
    
    # Monitor overfitting every N epochs
    if epoch % TRAIN_CONFIG['monitor_overfitting_every'] == 0:
        print(f"\n🔍 Overfitting Detection:")
        overfitting_stats = monitor.compute_stream_overfitting_indicators(
            train_loss, val_loss, train_acc, val_acc,
            train_loader, val_loader
        )
        
        monitoring_history['stream1_overfitting_scores'].append(
            overfitting_stats['stream1_overfitting_score']
        )
        monitoring_history['stream2_overfitting_scores'].append(
            overfitting_stats['stream2_overfitting_score']
        )
        
        print(f"  Stream1 (RGB):")
        print(f"    Train acc: {overfitting_stats['stream1_train_acc']*100:.2f}% | Val acc: {overfitting_stats['stream1_val_acc']*100:.2f}%")
        print(f"    Overfitting score: {overfitting_stats['stream1_overfitting_score']:.4f}")
        
        print(f"  Stream2 (Depth):")
        print(f"    Train acc: {overfitting_stats['stream2_train_acc']*100:.2f}% | Val acc: {overfitting_stats['stream2_val_acc']*100:.2f}%")
        print(f"    Overfitting score: {overfitting_stats['stream2_overfitting_score']:.4f}")
        
        # Determine which stream is overfitting more
        if overfitting_stats['stream1_overfitting_score'] > overfitting_stats['stream2_overfitting_score'] * 1.5:
            print(f"  ⚠️  Stream1 (RGB) is overfitting MORE than Stream2")
        elif overfitting_stats['stream2_overfitting_score'] > overfitting_stats['stream1_overfitting_score'] * 1.5:
            print(f"  ⚠️  Stream2 (Depth) is overfitting MORE than Stream1")
        else:
            print(f"  ✅ Streams are relatively balanced")
        
        # Get automatic recommendations
        if TRAIN_CONFIG['display_recommendations']:
            monitor.log_metrics(epoch, {
                **(first_batch_gradients or {}),
                **overfitting_stats
            })
            recommendations = monitor.get_recommendations()
            
            if recommendations:
                print(f"\n💡 Recommendations:")
                for rec in recommendations:
                    print(f"  • {rec}")
                monitoring_history['recommendations'].append({
                    'epoch': epoch + 1,
                    'recommendations': recommendations
                })
    
    # Monitor weight norms
    weight_stats = monitor.compute_stream_weights()
    monitoring_history['stream1_weight_norms'].append(weight_stats['stream1_weight_norm'])
    monitoring_history['stream2_weight_norms'].append(weight_stats['stream2_weight_norm'])
    
    # ===== EARLY STOPPING & CHECKPOINTING =====
    if val_acc > best_val_acc + TRAIN_CONFIG['min_delta']:
        best_val_acc = val_acc
        patience_counter = 0
        print(f"\n💾 New best validation accuracy: {val_acc*100:.2f}% - Saving checkpoint...")
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': model.optimizer.state_dict(),
            'val_acc': val_acc,
            'monitoring_history': monitoring_history
        }, TRAIN_CONFIG['save_path'])
    else:
        patience_counter += 1
        if patience_counter >= TRAIN_CONFIG['patience']:
            print(f"\n⏹️  Early stopping triggered at epoch {epoch + 1}")
            print(f"   Best val accuracy: {best_val_acc*100:.2f}%")
            break
    
    # Update learning rate scheduler
    if hasattr(model, 'scheduler') and model.scheduler is not None:
        model.scheduler.step()

print("\n" + "=" * 60)
print("🎉 TRAINING COMPLETE!")
print("=" * 60)

# Load best weights
print(f"\n📥 Loading best model weights...")
checkpoint = torch.load(TRAIN_CONFIG['save_path'])
model.load_state_dict(checkpoint['model_state_dict'])
print(f"✅ Best model loaded (Epoch {checkpoint['epoch']}, Val Acc: {checkpoint['val_acc']*100:.2f}%)")

# Save monitoring history
import json
monitoring_path = f"{checkpoint_dir}/monitoring_history.json"
with open(monitoring_path, 'w') as f:
    json.dump({
        'stream1_grad_norms': [float(x) for x in monitoring_history['stream1_grad_norms']],
        'stream2_grad_norms': [float(x) for x in monitoring_history['stream2_grad_norms']],
        'stream1_overfitting_scores': [float(x) for x in monitoring_history['stream1_overfitting_scores']],
        'stream2_overfitting_scores': [float(x) for x in monitoring_history['stream2_overfitting_scores']],
        'stream1_weight_norms': [float(x) for x in monitoring_history['stream1_weight_norms']],
        'stream2_weight_norms': [float(x) for x in monitoring_history['stream2_weight_norms']],
        'recommendations': monitoring_history['recommendations']
    }, f, indent=2)
print(f"✅ Monitoring history saved to: {monitoring_path}")

## 12. Pre-Training Diagnostics (Crash Prevention)

**IMPORTANT:** Run this cell BEFORE training to diagnose potential kernel crash issues!

In [None]:
import os
import torch
from tqdm import tqdm

print("=" * 60)
print("PRE-TRAINING DIAGNOSTICS")
print("=" * 60)

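# FIX 1: Use text-mode tqdm (notebook widgets can crash the Colab kernel)
print("\n1. Checking tqdm mode...")
# `from tqdm import tqdm` above is already the plain-text progress bar; the widget
# version lives in `tqdm.notebook`. TQDM_DISABLE controls whether bars are shown
# at all, not how they are rendered.
os.environ['TQDM_DISABLE'] = '0'  # Keep tqdm enabled
print("   ✅ tqdm will use text mode (not notebook widgets)")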

# TEST 2: CUDA initialization
print("\n2. Testing CUDA initialization...")
try:
    test_cuda = torch.randn(100, 100).cuda()
    print(f"   ✅ CUDA works: {test_cuda.device}")
    print(f"   ✅ CUDA memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
    del test_cuda
    torch.cuda.empty_cache()
except Exception as e:
    print(f"   ❌ CUDA initialization failed: {e}")
    raise

# TEST 3: DataLoader batch loading and label range
print("\n3. Testing DataLoader and checking label range...")
try:
    rgb, depth, labels = next(iter(train_loader))
    print(f"   ✅ DataLoader works: {rgb.shape}")
    print(f"   ✅ Labels shape: {labels.shape}")
    print(f"   ✅ Labels min: {labels.min().item()}, max: {labels.max().item()}")
    
    # CRITICAL CHECK: Labels must be in [0, num_classes-1]
    num_classes = DATASET_CONFIG['num_classes']
    if labels.min() < 0 or labels.max() >= num_classes:
        raise ValueError(f"Labels out of range! Expected [0, {num_classes - 1}], got [{labels.min().item()}, {labels.max().item()}]")
    print(f"   ✅ Labels are in valid range [0, {num_classes - 1}]")
    
except Exception as e:
    print(f"   ❌ DataLoader or label check failed: {e}")
    raise

# TEST 4: Model forward pass
print("\n4. Testing model forward pass...")
try:
    model.eval()
    with torch.no_grad():
        rgb, depth, labels = next(iter(train_loader))
        out = model(rgb.cuda(), depth.cuda())
        print(f"   ✅ Forward pass works: {out.shape}")
        del rgb, depth, labels, out
        torch.cuda.empty_cache()
except Exception as e:
    print(f"   ❌ Forward pass failed: {e}")
    raise

# TEST 5: Backward pass and optimizer (CRITICAL - this was failing before)
print("\n5. Testing backward pass and optimizer...")
try:
    model.train()
    rgb, depth, labels = next(iter(train_loader))
    rgb, depth, labels = rgb.cuda(), depth.cuda(), labels.cuda()
    
    # Forward
    outputs = model(rgb, depth)
    loss = model.criterion(outputs, labels)
    
    # Backward (this will fail if labels are out of range)
    model.optimizer.zero_grad()
    loss.backward()
    
    # Optimizer step
    model.optimizer.step()
    
    print(f"   ✅ Backward pass works: loss={loss.item():.4f}")
    print(f"   ✅ Optimizer step works")
    
    del rgb, depth, labels, outputs, loss
    torch.cuda.empty_cache()
except RuntimeError as e:
    if "device-side assert triggered" in str(e):
        print(f"   ❌ CUDA assertion failed - likely label indexing issue!")
        print(f"   ❌ This means labels are out of valid range [0, num_classes-1]")
        print(f"   ❌ Make sure you've pulled the latest code with the label fix!")
    else:
        # Report other runtime errors (e.g. out of memory) before re-raising
        print(f"   ❌ Backward/optimizer failed: {e}")
    raise
except Exception as e:
    print(f"   ❌ Backward/optimizer failed: {e}")
    raise

# TEST 6: tqdm progress bar
print("\n6. Testing tqdm progress bar...")
try:
    for i in tqdm(range(10), desc="Test"):
        pass
    print(f"   ✅ tqdm works")
except Exception as e:
    print(f"   ❌ tqdm failed: {e}")
    raise

# TEST 7: DataLoader iteration
print("\n7. Testing DataLoader multi-batch iteration...")
try:
    batch_count = 0
    for rgb, depth, labels in train_loader:
        batch_count += 1
        if batch_count >= 3:  # Test first 3 batches
            break
    print(f"   ✅ DataLoader iteration works ({batch_count} batches tested)")
except Exception as e:
    print(f"   ❌ DataLoader iteration failed: {e}")
    raise

print("\n" + "=" * 60)
print("✅ ALL DIAGNOSTICS PASSED - Ready to train!")
print("=" * 60)
print("\nIf training still crashes after this, try:")
print("  1. Run with verbose=False in model.fit()")
print("  2. Restart runtime and rerun all cells")
print("  3. Check the crash log for new error messages")
print("=" * 60 + "\n")

In [None]:
print("=" * 60)
print("STARTING TRAINING")
print("=" * 60)

# Training configuration
TRAIN_CONFIG = {
    'epochs': 90,  # Standard for ImageNet-style training
    'grad_clip_norm': 5.0,  # Gradient clipping for stability
    'early_stopping': True,
    'patience': 15,
    'min_delta': 0.001,
    'monitor': 'val_accuracy',
    'restore_best_weights': True,
    'save_path': f"{checkpoint_dir}/best_model.pt"
}

print(f"Configuration:")
for key, value in TRAIN_CONFIG.items():
    print(f"  {key}: {value}")

print("\n" + "=" * 60)
print(f"Training will take approximately 2-3 hours on A100")
print(f"Progress will be shown below...")
print("=" * 60 + "\n")

# Train!
history = model.fit(
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=TRAIN_CONFIG['epochs'],
    verbose=True,
    save_path=TRAIN_CONFIG['save_path'],
    early_stopping=TRAIN_CONFIG['early_stopping'],
    patience=TRAIN_CONFIG['patience'],
    min_delta=TRAIN_CONFIG['min_delta'],
    monitor=TRAIN_CONFIG['monitor'],
    restore_best_weights=TRAIN_CONFIG['restore_best_weights'],
    grad_clip_norm=TRAIN_CONFIG['grad_clip_norm']
)

print("\n" + "=" * 60)
print("🎉 TRAINING COMPLETE!")
print("=" * 60)

## 13. Train the Model 🚀

**Expected time:** ~2-3 hours for 90 epochs on A100

**All optimizations enabled:**
- ✅ Automatic Mixed Precision (up to ~2x faster)
- ✅ Gradient Clipping (stability)
- ✅ Cosine Annealing LR
- ✅ Early Stopping
- ✅ Best Model Checkpointing
- ✅ Local disk I/O (10-20x faster than Drive)
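
Early stopping is controlled by the `patience`, `min_delta`, and `monitor` settings in `TRAIN_CONFIG`. How `model.fit()` implements it is not shown in this notebook, so the block below is only a minimal sketch of the usual logic. It assumes that "improvement" means beating the best value so far by more than `min_delta`. `EarlyStopper` is a hypothetical name used for illustration, not part of MCResNet.

```python
class EarlyStopper:
    """Hypothetical helper illustrating patience/min_delta early stopping."""

    def __init__(self, patience=15, min_delta=0.001, mode="max"):
        self.patience = patience      # epochs to wait without improvement
        self.min_delta = min_delta    # smallest change that counts as improvement
        self.mode = mode              # "max" for accuracy, "min" for loss
        self.best = None
        self.best_epoch = 0
        self.bad_epochs = 0

    def step(self, metric, epoch):
        """Record one epoch's metric; return True when training should stop."""
        if self.best is None:
            improved = True
        elif self.mode == "max":
            improved = metric > self.best + self.min_delta
        else:
            improved = metric < self.best - self.min_delta

        if improved:
            self.best = metric
            self.best_epoch = epoch
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Toy run with made-up accuracies and a short patience of 2 epochs.
stopper = EarlyStopper(patience=2, min_delta=0.001, mode="max")
val_accuracies = [0.40, 0.50, 0.55, 0.5505, 0.549, 0.56]
stopped_at = None
for epoch, acc in enumerate(val_accuracies, start=1):
    if stopper.step(acc, epoch):
        stopped_at = epoch
        break

print(f"Stopped at epoch {stopped_at}, best epoch {stopper.best_epoch}")
```

In the toy run the gain at epoch 4 is smaller than `min_delta`, so it does not reset the counter. This is why a very small `min_delta` combined with a noisy validation metric can keep training running longer than expected.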

In [None]:
import numpy as np  # used by np.mean below; imported here so the cell runs on its own

print("=" * 60)
print("MONITORING INTERPRETATION GUIDE")
print("=" * 60)

# Analyze monitoring results and provide actionable recommendations
print("\n📊 Analysis of Stream Behavior:\n")

# 1. Gradient Analysis
if len(monitoring_history['stream1_grad_norms']) > 0:
    s1_grads = monitoring_history['stream1_grad_norms']
    s2_grads = monitoring_history['stream2_grad_norms']
    
    avg_s1_grad = np.mean(s1_grads)
    avg_s2_grad = np.mean(s2_grads)
    grad_ratio = avg_s1_grad / max(avg_s2_grad, 1e-10)
    
    print("1️⃣ GRADIENT FLOW:")
    print(f"   RGB avg gradient: {avg_s1_grad:.6f}")
    print(f"   Depth avg gradient: {avg_s2_grad:.6f}")
    print(f"   RGB/Depth ratio: {grad_ratio:.2f}")
    
    if grad_ratio < 0.3:
        print(f"\n   ⚠️  RGB gradients much smaller - RGB stream learning slowly")
        print(f"   💡 ACTION: Increase stream1_lr (try 5x base LR)")
    elif grad_ratio > 3.0:
        print(f"\n   ⚠️  Depth gradients much smaller - Depth stream learning slowly")
        print(f"   💡 ACTION: Increase stream2_lr (try 5x base LR)")
    else:
        print(f"\n   ✅ Gradients relatively balanced - good!")

# 2. Overfitting Analysis
if len(monitoring_history['stream1_overfitting_scores']) > 0:
    s1_overfit = monitoring_history['stream1_overfitting_scores'][-1]
    s2_overfit = monitoring_history['stream2_overfitting_scores'][-1]
    
    print(f"\n2️⃣ OVERFITTING DETECTION:")
    print(f"   RGB overfitting score: {s1_overfit:.3f}")
    print(f"   Depth overfitting score: {s2_overfit:.3f}")
    
    if s1_overfit > s2_overfit * 1.5:
        print(f"\n   ⚠️  RGB stream overfitting MORE than Depth")
        print(f"   💡 ACTIONS:")
        print(f"      • Increase stream1_weight_decay (try 5e-2)")
        print(f"      • Or decrease stream1_lr (try 0.5x base LR)")
        print(f"      • Or increase dropout_p in model config")
    elif s2_overfit > s1_overfit * 1.5:
        print(f"\n   ⚠️  Depth stream overfitting MORE than RGB")
        print(f"   💡 ACTIONS:")
        print(f"      • Increase stream2_weight_decay (try 5e-2)")
        print(f"      • Or decrease stream2_lr (try 0.5x base LR)")
        print(f"      • Or increase dropout_p in model config")
    else:
        print(f"\n   ✅ Overfitting relatively balanced")

# 3. Weight Magnitude Analysis
if len(monitoring_history['stream1_weight_norms']) > 0:
    s1_weights = monitoring_history['stream1_weight_norms']
    s2_weights = monitoring_history['stream2_weight_norms']
    
    final_s1_weight = s1_weights[-1]
    final_s2_weight = s2_weights[-1]
    weight_ratio = final_s1_weight / max(final_s2_weight, 1e-10)
    
    print(f"\n3️⃣ WEIGHT MAGNITUDES:")
    print(f"   RGB final weight norm: {final_s1_weight:.4f}")
    print(f"   Depth final weight norm: {final_s2_weight:.4f}")
    print(f"   RGB/Depth ratio: {weight_ratio:.2f}")
    
    if weight_ratio < 0.5:
        print(f"\n   ⚠️  RGB weights much smaller - may indicate underfitting")
        print(f"   💡 This often correlates with low gradients")
    elif weight_ratio > 2.0:
        print(f"\n   ⚠️  Depth weights much smaller - may indicate underfitting")
        print(f"   💡 This often correlates with low gradients")
    else:
        print(f"\n   ✅ Weight magnitudes relatively balanced")

# 4. Generate specific config recommendations
print(f"\n{'='*60}")
print("💡 RECOMMENDED NEXT STEPS:")
print(f"{'='*60}\n")

needs_adjustment = False

if len(monitoring_history['stream1_overfitting_scores']) > 0:
    s1_overfit = monitoring_history['stream1_overfitting_scores'][-1]
    s2_overfit = monitoring_history['stream2_overfitting_scores'][-1]
    
    if s1_overfit > s2_overfit * 1.5 or s2_overfit > s1_overfit * 1.5:
        needs_adjustment = True
        print("🔧 RERUN WITH STREAM-SPECIFIC OPTIMIZATION:\n")
        print("Go back to cell 'MODEL COMPILATION' and:")
        print("  1. Set USE_STREAM_SPECIFIC = True")
        print("  2. Uncomment and adjust these lines in STREAM_SPECIFIC_CONFIG:\n")
        
        if s1_overfit > s2_overfit * 1.5:
            print("     # RGB is overfitting - regularize it more:")
            print("     'stream1_lr': 5e-5,           # Lower LR for RGB")
            print("     'stream1_weight_decay': 5e-2, # Higher WD for RGB")
            print("     'stream2_lr': 2e-4,           # Keep Depth learning")
            print("     'stream2_weight_decay': 1e-3, # Lighter WD for Depth")
        else:
            print("     # Depth is overfitting - regularize it more:")
            print("     'stream1_lr': 2e-4,           # Keep RGB learning")
            print("     'stream1_weight_decay': 1e-3, # Lighter WD for RGB")
            print("     'stream2_lr': 5e-5,           # Lower LR for Depth")
            print("     'stream2_weight_decay': 5e-2, # Higher WD for Depth")
        
        print("\n  3. Rerun training from that cell onwards")

if len(monitoring_history['stream1_grad_norms']) > 0:
    grad_ratio = np.mean(s1_grads) / max(np.mean(s2_grads), 1e-10)
    
    if grad_ratio < 0.3 or grad_ratio > 3.0:
        if not needs_adjustment:
            needs_adjustment = True
            print("🔧 RERUN WITH STREAM-SPECIFIC OPTIMIZATION:\n")
            print("Go back to cell 'MODEL COMPILATION' and:")
            print("  1. Set USE_STREAM_SPECIFIC = True")
            print("  2. Uncomment and adjust these lines in STREAM_SPECIFIC_CONFIG:\n")
        
        if grad_ratio < 0.3:
            print("     # RGB gradients too small - boost RGB learning:")
            print("     'stream1_lr': 5e-4,           # 5x higher LR for RGB")
            print("     'stream1_weight_decay': 1e-3, # Lighter WD for RGB")
        else:
            print("     # Depth gradients too small - boost Depth learning:")
            print("     'stream2_lr': 5e-4,           # 5x higher LR for Depth")
            print("     'stream2_weight_decay': 1e-3, # Lighter WD for Depth")

if not needs_adjustment:
    print("✅ Training appears balanced - no major adjustments needed!\n")
    print("If validation accuracy is still low, consider:")
    print("  • Training for more epochs")
    print("  • Using ResNet50 instead of ResNet18 (more capacity)")
    print("  • Adjusting fusion_type (try 'weighted' or 'gated')")
    print("  • Increasing dropout_p for more regularization")

print("\n" + "=" * 60)

## 14b. Monitoring Interpretation & Decision Guide 💡

**How to use monitoring results to improve your model**
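
The decision rules in this guide come down to two comparisons: an RGB/Depth ratio checked against the 0.3 and 3.0 thresholds, and a 1.5x gap between the per-stream overfitting scores. The sketch below restates those rules as small functions so they can be tried on individual numbers. The function names are hypothetical, and the thresholds are simply the ones used in this notebook.

```python
def diagnose_ratio(rgb_value, depth_value, low=0.3, high=3.0, eps=1e-10):
    """Return (ratio, verdict) for an RGB/Depth ratio such as gradient norms."""
    ratio = rgb_value / max(depth_value, eps)  # eps guards against division by zero
    if ratio < low:
        return ratio, "rgb_low"      # RGB stream is learning slowly
    if ratio > high:
        return ratio, "depth_low"    # Depth stream is learning slowly
    return ratio, "balanced"


def overfitting_gap(s1_score, s2_score, factor=1.5):
    """Return which stream overfits more under the 1.5x rule."""
    if s1_score > s2_score * factor:
        return "rgb"
    if s2_score > s1_score * factor:
        return "depth"
    return "balanced"


# Made-up example values, for illustration only.
ratio, verdict = diagnose_ratio(0.002, 0.010)
print(f"ratio={ratio:.2f} -> {verdict}")
print(overfitting_gap(0.30, 0.10))
```

One caveat about the multiplicative rule: it assumes non-negative scores. If both scores are negative, for example -0.1 and -0.2, then `s1_score > s2_score * 1.5` is true even though neither stream is overfitting, so it may be worth checking the sign of the scores before acting on the recommendation.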

In [None]:
import matplotlib.pyplot as plt
import numpy as np

print("=" * 60)
print("STREAM MONITORING VISUALIZATION")
print("=" * 60)

# Create comprehensive monitoring plots
fig, axes = plt.subplots(2, 3, figsize=(18, 10))

# 1. Gradient Norms Over Time
if len(monitoring_history['stream1_grad_norms']) > 0:
    epochs = range(1, len(monitoring_history['stream1_grad_norms']) + 1)
    axes[0, 0].plot(epochs, monitoring_history['stream1_grad_norms'], 
                    label='Stream1 (RGB)', linewidth=2, color='blue', marker='o', markersize=3)
    axes[0, 0].plot(epochs, monitoring_history['stream2_grad_norms'], 
                    label='Stream2 (Depth)', linewidth=2, color='orange', marker='s', markersize=3)
    axes[0, 0].set_xlabel('Epoch', fontsize=11)
    axes[0, 0].set_ylabel('Gradient Norm', fontsize=11)
    axes[0, 0].set_title('Stream Gradient Magnitudes', fontsize=13, fontweight='bold')
    axes[0, 0].legend(fontsize=10)
    axes[0, 0].grid(True, alpha=0.3)
    axes[0, 0].set_yscale('log')

# 2. Gradient Ratio Over Time
if len(monitoring_history['stream1_grad_norms']) > 0:
    grad_ratios = [s1/max(s2, 1e-10) for s1, s2 in zip(
        monitoring_history['stream1_grad_norms'], 
        monitoring_history['stream2_grad_norms']
    )]
    axes[0, 1].plot(epochs, grad_ratios, linewidth=2, color='green', marker='d', markersize=3)
    axes[0, 1].axhline(y=1.0, color='red', linestyle='--', linewidth=1.5, label='Perfect Balance')
    axes[0, 1].set_xlabel('Epoch', fontsize=11)
    axes[0, 1].set_ylabel('Gradient Ratio (RGB/Depth)', fontsize=11)
    axes[0, 1].set_title('Gradient Balance Between Streams', fontsize=13, fontweight='bold')
    axes[0, 1].legend(fontsize=10)
    axes[0, 1].grid(True, alpha=0.3)
    axes[0, 1].set_yscale('log')

# 3. Overfitting Scores Comparison
if len(monitoring_history['stream1_overfitting_scores']) > 0:
    overfit_epochs = range(1, len(monitoring_history['stream1_overfitting_scores']) + 1)
    axes[0, 2].plot(overfit_epochs, monitoring_history['stream1_overfitting_scores'], 
                    label='Stream1 (RGB)', linewidth=2.5, color='blue', marker='o', markersize=4)
    axes[0, 2].plot(overfit_epochs, monitoring_history['stream2_overfitting_scores'], 
                    label='Stream2 (Depth)', linewidth=2.5, color='orange', marker='s', markersize=4)
    axes[0, 2].axhline(y=0, color='gray', linestyle='-', linewidth=1, alpha=0.5)
    axes[0, 2].set_xlabel('Epoch', fontsize=11)
    axes[0, 2].set_ylabel('Overfitting Score', fontsize=11)
    axes[0, 2].set_title('Stream-Specific Overfitting Detection', fontsize=13, fontweight='bold')
    axes[0, 2].legend(fontsize=10)
    axes[0, 2].grid(True, alpha=0.3)
    
    # Annotate which stream is overfitting more
    final_s1 = monitoring_history['stream1_overfitting_scores'][-1]
    final_s2 = monitoring_history['stream2_overfitting_scores'][-1]
    if final_s1 > final_s2 * 1.5:
        axes[0, 2].text(0.5, 0.95, '⚠️ RGB overfitting more', 
                       transform=axes[0, 2].transAxes, ha='center', va='top',
                       bbox=dict(boxstyle='round', facecolor='yellow', alpha=0.7), fontsize=10)
    elif final_s2 > final_s1 * 1.5:
        axes[0, 2].text(0.5, 0.95, '⚠️ Depth overfitting more', 
                       transform=axes[0, 2].transAxes, ha='center', va='top',
                       bbox=dict(boxstyle='round', facecolor='yellow', alpha=0.7), fontsize=10)

# 4. Weight Norms Over Time
if len(monitoring_history['stream1_weight_norms']) > 0:
    weight_epochs = range(1, len(monitoring_history['stream1_weight_norms']) + 1)
    axes[1, 0].plot(weight_epochs, monitoring_history['stream1_weight_norms'], 
                    label='Stream1 (RGB)', linewidth=2, color='blue', marker='o', markersize=3)
    axes[1, 0].plot(weight_epochs, monitoring_history['stream2_weight_norms'], 
                    label='Stream2 (Depth)', linewidth=2, color='orange', marker='s', markersize=3)
    axes[1, 0].set_xlabel('Epoch', fontsize=11)
    axes[1, 0].set_ylabel('Weight Norm', fontsize=11)
    axes[1, 0].set_title('Stream Weight Magnitudes', fontsize=13, fontweight='bold')
    axes[1, 0].legend(fontsize=10)
    axes[1, 0].grid(True, alpha=0.3)

# 5. Weight Ratio Over Time
if len(monitoring_history['stream1_weight_norms']) > 0:
    weight_ratios = [s1/max(s2, 1e-10) for s1, s2 in zip(
        monitoring_history['stream1_weight_norms'], 
        monitoring_history['stream2_weight_norms']
    )]
    axes[1, 1].plot(weight_epochs, weight_ratios, linewidth=2, color='purple', marker='d', markersize=3)
    axes[1, 1].axhline(y=1.0, color='red', linestyle='--', linewidth=1.5, label='Perfect Balance')
    axes[1, 1].set_xlabel('Epoch', fontsize=11)
    axes[1, 1].set_ylabel('Weight Ratio (RGB/Depth)', fontsize=11)
    axes[1, 1].set_title('Weight Balance Between Streams', fontsize=13, fontweight='bold')
    axes[1, 1].legend(fontsize=10)
    axes[1, 1].grid(True, alpha=0.3)

# 6. Summary Statistics (Text)
axes[1, 2].axis('off')
summary_text = "📊 MONITORING SUMMARY\n\n"

if len(monitoring_history['stream1_grad_norms']) > 0:
    avg_grad_ratio = np.mean([s1/max(s2, 1e-10) for s1, s2 in zip(
        monitoring_history['stream1_grad_norms'], 
        monitoring_history['stream2_grad_norms']
    )])
    summary_text += "Gradient Balance:\n"
    summary_text += f"  RGB/Depth ratio: {avg_grad_ratio:.2f}\n"
    summary_text += f"  {'✅ Balanced' if 0.5 <= avg_grad_ratio <= 2.0 else '⚠️ Imbalanced'}\n\n"

if len(monitoring_history['stream1_overfitting_scores']) > 0:
    final_s1_overfit = monitoring_history['stream1_overfitting_scores'][-1]
    final_s2_overfit = monitoring_history['stream2_overfitting_scores'][-1]
    summary_text += "Final Overfitting:\n"
    summary_text += f"  RGB score: {final_s1_overfit:.3f}\n"
    summary_text += f"  Depth score: {final_s2_overfit:.3f}\n"
    if final_s1_overfit > final_s2_overfit * 1.5:
        summary_text += "  ⚠️ RGB overfitting more\n\n"
    elif final_s2_overfit > final_s1_overfit * 1.5:
        summary_text += "  ⚠️ Depth overfitting more\n\n"
    else:
        summary_text += "  ✅ Relatively balanced\n\n"

if len(monitoring_history['recommendations']) > 0:
    summary_text += "Recommendations Given:\n"
    summary_text += f"  {len(monitoring_history['recommendations'])} times\n\n"
    
    # Show the most recent recommendations (first 3, truncated to 40 characters)
    recent = monitoring_history['recommendations'][-1]
    summary_text += f"Latest (Epoch {recent['epoch']}):\n"
    for rec in recent['recommendations'][:3]:
        summary_text += f"  • {rec[:40]}...\n"

axes[1, 2].text(0.1, 0.9, summary_text, transform=axes[1, 2].transAxes, 
               fontsize=10, verticalalignment='top', family='monospace',
               bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.3))

plt.suptitle('Stream Monitoring Analysis - MCResNet Training', 
             fontsize=16, fontweight='bold', y=0.995)
plt.tight_layout()
plt.savefig(f"{checkpoint_dir}/stream_monitoring_analysis.png", dpi=150, bbox_inches='tight')
plt.show()

print(f"✅ Stream monitoring visualization saved to:")
print(f"   {checkpoint_dir}/stream_monitoring_analysis.png")
print("\n" + "=" * 60)

## 14a. Visualize Stream Monitoring Results 🔍

**Comprehensive analysis of stream-specific behavior during training**
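
This notebook does not show how the entries of `monitoring_history` are computed. A common choice for a per-stream gradient norm is the global L2 norm over all gradients belonging to that stream, and the sketch below illustrates that quantity on plain Python lists. Both the definition and the `stream1.` / `stream2.` parameter-name prefixes are assumptions made for illustration.

```python
import math


def stream_grad_norm(named_grads, prefix):
    """Global L2 norm over every gradient whose parameter name starts with `prefix`."""
    total = 0.0
    for name, grad in named_grads.items():
        if name.startswith(prefix):
            total += sum(g * g for g in grad)  # accumulate squared entries
    return math.sqrt(total)


# Toy gradients keyed by hypothetical parameter names.
toy_grads = {
    "stream1.conv1.weight": [3.0, 4.0],
    "stream2.conv1.weight": [0.6, 0.8],
    "fc.weight": [1.0],  # shared layer, belongs to neither stream
}

rgb_norm = stream_grad_norm(toy_grads, "stream1.")
depth_norm = stream_grad_norm(toy_grads, "stream2.")
print(f"RGB norm: {rgb_norm:.3f}, Depth norm: {depth_norm:.3f}")
```

Because the norm sums over every parameter in a stream, a stream with many more parameters can show a larger norm without learning faster, which is one reason to read the ratio plot alongside the accuracy curves.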

## 13. Evaluate Final Model

In [None]:
print("=" * 60)
print("FINAL EVALUATION")
print("=" * 60)

# Evaluate on validation set
results = model.evaluate(data_loader=val_loader)

print(f"\nFinal Validation Results:")
print(f"  Loss: {results['loss']:.4f}")
print(f"  Accuracy: {results['accuracy']*100:.2f}%")

print(f"\nTraining Summary:")
print(f"  Initial train loss: {history['train_loss'][0]:.4f}")
print(f"  Final train loss: {history['train_loss'][-1]:.4f}")
print(f"  Best val loss: {min(history['val_loss']):.4f}")
print(f"  Initial train acc: {history['train_accuracy'][0]*100:.2f}%")
print(f"  Final train acc: {history['train_accuracy'][-1]*100:.2f}%")
print(f"  Best val acc: {max(history['val_accuracy'])*100:.2f}%")
print(f"  Total epochs: {len(history['train_loss'])}")

if 'early_stopping' in history:
    print(f"\nEarly Stopping Info:")
    print(f"  Stopped early: {history['early_stopping']['stopped_early']}")
    print(f"  Best epoch: {history['early_stopping']['best_epoch']}")
    print(f"  Best {history['early_stopping']['monitor']}: {history['early_stopping']['best_metric']:.4f}")

print("\n" + "=" * 60)

## 14. Plot Training Curves

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Loss curve
axes[0].plot(history['train_loss'], label='Train Loss', linewidth=2)
axes[0].plot(history['val_loss'], label='Val Loss', linewidth=2)
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)

# Accuracy curve
axes[1].plot([acc*100 for acc in history['train_accuracy']], label='Train Acc', linewidth=2)
axes[1].plot([acc*100 for acc in history['val_accuracy']], label='Val Acc', linewidth=2)
axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('Accuracy (%)', fontsize=12)
axes[1].set_title('Training and Validation Accuracy', fontsize=14, fontweight='bold')
axes[1].legend(fontsize=11)
axes[1].grid(True, alpha=0.3)

# Learning rate curve
if len(history['learning_rates']) > 0:
    # Sample learning rates (they're recorded per step, not per epoch)
    sampled_lrs = history['learning_rates'][::max(1, len(history['learning_rates'])//100)]
    axes[2].plot(sampled_lrs, linewidth=2, color='green')
    axes[2].set_xlabel('Training Step (sampled)', fontsize=12)
    axes[2].set_ylabel('Learning Rate', fontsize=12)
    axes[2].set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
    axes[2].grid(True, alpha=0.3)
    axes[2].set_yscale('log')

plt.tight_layout()
plt.savefig(f"{checkpoint_dir}/training_curves.png", dpi=150, bbox_inches='tight')
plt.show()

print(f"✅ Training curves saved to: {checkpoint_dir}/training_curves.png")

## 15. Pathway Analysis (RGB vs Depth Contributions)

In [None]:
print("=" * 60)
print("PATHWAY ANALYSIS")
print("=" * 60)
print("\nAnalyzing RGB and Depth pathway contributions...")
print("This may take a few minutes...\n")

# Analyze pathways
pathway_analysis = model.analyze_pathways(
    data_loader=val_loader,
    num_samples=len(val_loader.dataset)  # Use all validation samples
)

print("\nAccuracy Metrics:")
print(f"  Full model (RGB+Depth): {pathway_analysis['accuracy']['full_model']*100:.2f}%")
print(f"  RGB only: {pathway_analysis['accuracy']['color_only']*100:.2f}%")
print(f"  Depth only: {pathway_analysis['accuracy']['brightness_only']*100:.2f}%")
print(f"\n  RGB contribution: {pathway_analysis['accuracy']['color_contribution']*100:.2f}%")
print(f"  Depth contribution: {pathway_analysis['accuracy']['brightness_contribution']*100:.2f}%")

print("\nLoss Metrics:")
print(f"  Full model: {pathway_analysis['loss']['full_model']:.4f}")
print(f"  RGB only: {pathway_analysis['loss']['color_only']:.4f}")
print(f"  Depth only: {pathway_analysis['loss']['brightness_only']:.4f}")

print("\nFeature Norm Statistics:")
print(f"  RGB mean: {pathway_analysis['feature_norms']['color_mean']:.4f}")
print(f"  RGB std: {pathway_analysis['feature_norms']['color_std']:.4f}")
print(f"  Depth mean: {pathway_analysis['feature_norms']['brightness_mean']:.4f}")
print(f"  Depth std: {pathway_analysis['feature_norms']['brightness_std']:.4f}")
print(f"  RGB/Depth ratio: {pathway_analysis['feature_norms']['color_to_brightness_ratio']:.4f}")

print("\n" + "=" * 60)
print("✅ Pathway analysis complete!")
print("=" * 60)

# Visualize pathway contributions
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Accuracy comparison
pathways = ['Full Model\n(RGB+Depth)', 'RGB Only', 'Depth Only']
accuracies = [
    pathway_analysis['accuracy']['full_model'] * 100,
    pathway_analysis['accuracy']['color_only'] * 100,
    pathway_analysis['accuracy']['brightness_only'] * 100
]
colors = ['green', 'blue', 'orange']

axes[0].bar(pathways, accuracies, color=colors, alpha=0.7)
axes[0].set_ylabel('Accuracy (%)', fontsize=12)
axes[0].set_title('Pathway Accuracy Comparison', fontsize=14, fontweight='bold')
axes[0].grid(True, alpha=0.3, axis='y')
for i, v in enumerate(accuracies):
    axes[0].text(i, v + 1, f'{v:.1f}%', ha='center', fontweight='bold')

# Feature norm comparison
norms = ['RGB Features', 'Depth Features']
norm_values = [
    pathway_analysis['feature_norms']['color_mean'],
    pathway_analysis['feature_norms']['brightness_mean']
]
axes[1].bar(norms, norm_values, color=['blue', 'orange'], alpha=0.7)
axes[1].set_ylabel('Feature Norm (Mean)', fontsize=12)
axes[1].set_title('Feature Magnitude Comparison', fontsize=14, fontweight='bold')
axes[1].grid(True, alpha=0.3, axis='y')
for i, v in enumerate(norm_values):
    axes[1].text(i, v + 0.1, f'{v:.2f}', ha='center', fontweight='bold')

plt.tight_layout()
plt.savefig(f"{checkpoint_dir}/pathway_analysis.png", dpi=150, bbox_inches='tight')
plt.show()

print(f"✅ Pathway analysis plot saved to: {checkpoint_dir}/pathway_analysis.png")

## 16. Save Results & Training History

In [None]:
import json
import torch

print("=" * 60)
print("SAVING RESULTS")
print("=" * 60)

# Save training history as JSON
history_path = f"{checkpoint_dir}/training_history.json"

# Build the payload first so a failure here does not leave an empty file behind
json_history = {
    'train_loss': [float(x) for x in history['train_loss']],
    'val_loss': [float(x) for x in history['val_loss']],
    'train_accuracy': [float(x) for x in history['train_accuracy']],
    'val_accuracy': [float(x) for x in history['val_accuracy']],
    'learning_rates': [float(x) for x in history['learning_rates']],
    'model_config': MODEL_CONFIG,
    'dataset_config': DATASET_CONFIG,
    'training_config': TRAINING_CONFIG,
    'final_results': {
        'val_loss': float(results['loss']),
        'val_accuracy': float(results['accuracy'])
    },
    'pathway_analysis': {
        'full_model_accuracy': float(pathway_analysis['accuracy']['full_model']),
        'rgb_only_accuracy': float(pathway_analysis['accuracy']['color_only']),
        'depth_only_accuracy': float(pathway_analysis['accuracy']['brightness_only'])
    }
}
if 'early_stopping' in history:
    json_history['early_stopping'] = history['early_stopping']

with open(history_path, 'w') as f:
    # default=str is a fallback for values json cannot encode natively (e.g. NumPy scalars)
    json.dump(json_history, f, indent=2, default=str)

print(f"✅ Training history saved: {history_path}")

# Save final model (in addition to best model)
final_model_path = f"{checkpoint_dir}/final_model.pt"
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': model.optimizer.state_dict(),
    'config': MODEL_CONFIG,
    'history': history,
    'val_accuracy': results['accuracy']
}, final_model_path)

print(f"✅ Final model saved: {final_model_path}")

# Save summary report
summary_path = f"{checkpoint_dir}/summary.txt"
with open(summary_path, 'w') as f:
    f.write("=" * 60 + "\n")
    f.write("MCResNet Training Summary - NYU Depth V2\n")
    f.write("=" * 60 + "\n\n")
    f.write(f"Model: MCResNet-{MODEL_CONFIG['architecture'].upper()}\n")
    f.write(f"Dataset: NYU Depth V2 (Scene Classification)\n")
    f.write(f"Training Samples: {len(train_loader.dataset)}\n")
    f.write(f"Validation Samples: {len(val_loader.dataset)}\n")
    f.write(f"Total Parameters: {total_params:,}\n")
    f.write(f"\nTraining Configuration:\n")
    f.write(f"  Epochs: {len(history['train_loss'])}\n")
    f.write(f"  Batch Size: {DATASET_CONFIG['batch_size']}\n")
    f.write(f"  Learning Rate: {TRAINING_CONFIG['learning_rate']}\n")
    f.write(f"  Optimizer: {TRAINING_CONFIG['optimizer']}\n")
    f.write(f"  Scheduler: {TRAINING_CONFIG['scheduler']}\n")
    f.write(f"  AMP: {MODEL_CONFIG['use_amp']}\n")
    f.write(f"  Gradient Clipping: {TRAIN_CONFIG['grad_clip_norm']}\n")
    f.write(f"\nFinal Results:\n")
    f.write(f"  Val Loss: {results['loss']:.4f}\n")
    f.write(f"  Val Accuracy: {results['accuracy']*100:.2f}%\n")
    f.write(f"  Best Val Accuracy: {max(history['val_accuracy'])*100:.2f}%\n")
    f.write(f"\nPathway Analysis:\n")
    f.write(f"  Full Model: {pathway_analysis['accuracy']['full_model']*100:.2f}%\n")
    f.write(f"  RGB Only: {pathway_analysis['accuracy']['color_only']*100:.2f}%\n")
    f.write(f"  Depth Only: {pathway_analysis['accuracy']['brightness_only']*100:.2f}%\n")

print(f"✅ Summary report saved: {summary_path}")

print("\n" + "=" * 60)
print(f"All results saved to: {checkpoint_dir}")
print("=" * 60)

# List saved files
print("\nSaved files:")
!ls -lh "{checkpoint_dir}"

## 17. Summary & Next Steps

### 🎉 Training Complete!

**What we accomplished:**
- ✅ Trained MCResNet on NYU Depth V2 dataset
- ✅ Used A100 GPU with AMP (up to ~2x speedup)
- ✅ Saved all checkpoints to Google Drive
- ✅ Analyzed RGB and Depth pathway contributions
- ✅ Generated training curves and visualizations
- ✅ Comprehensive stream monitoring with overfitting detection

**Results are saved to:** Check the output above for the checkpoint directory path

### 📊 Expected Performance:

For **NYU Depth V2 Scene Classification (13 classes)**:
- **Good:** 60-70% validation accuracy
- **Very Good:** 70-75% validation accuracy  
- **Excellent:** 75-80% validation accuracy

### 🔍 Next Steps:

1. **Review Results:**
   - Check training curves above
   - Review pathway analysis
   - Compare RGB vs Depth contributions
   - Analyze stream monitoring plots

2. **Download Results:**
   - All files are saved to your Google Drive
   - Download checkpoints for local inference

3. **Experiment:**
   - Try ResNet50 for better accuracy (change `architecture` in Model Config)
   - Use stream-specific optimization if monitoring shows imbalance
   - Adjust `fusion_type` (try 'weighted' or 'gated')
   - Train longer if early stopping was triggered

4. **Deploy:**
   - Use the best model for inference
   - Test on new RGB-D images
   - Integrate into your application
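
Once the results are downloaded, `training_history.json` can be inspected without PyTorch. The sketch below reads the `val_accuracy` and `final_results` keys written by the save cell and reports the best epoch. `summarize_history` is a hypothetical helper name, and the demo file uses made-up numbers purely so the example runs on its own.

```python
import json
import os
import tempfile


def summarize_history(path):
    """Load a training_history.json file and return a few headline numbers."""
    with open(path) as f:
        history = json.load(f)
    val_acc = history["val_accuracy"]
    # Index of the highest validation accuracy (first one wins on ties)
    best_idx = max(range(len(val_acc)), key=val_acc.__getitem__)
    return {
        "epochs_run": len(val_acc),
        "best_epoch": best_idx + 1,  # epochs are reported 1-based
        "best_val_accuracy": val_acc[best_idx],
        "final_val_accuracy": history["final_results"]["val_accuracy"],
    }


# Demo with a tiny made-up history file; point the path at your real file instead.
demo = {
    "val_accuracy": [0.41, 0.58, 0.55],
    "final_results": {"val_accuracy": 0.58},
}
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "training_history.json")
    with open(demo_path, "w") as f:
        json.dump(demo, f)
    summary = summarize_history(demo_path)

print(summary)
```

Comparing `best_epoch` with `epochs_run` is a quick way to see how long the run continued after its best validation accuracy.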

---

**Questions or issues?** Check the training summary and pathway analysis above!

In [None]:
# Print final summary
print("=" * 60)
print("FINAL SUMMARY")
print("=" * 60)
print(f"\n✅ Training Complete!")
print(f"\nFinal Validation Accuracy: {results['accuracy']*100:.2f}%")
print(f"Best Validation Accuracy: {max(history['val_accuracy'])*100:.2f}%")
print(f"\nRGB Pathway: {pathway_analysis['accuracy']['color_only']*100:.2f}%")
print(f"Depth Pathway: {pathway_analysis['accuracy']['brightness_only']*100:.2f}%")
print(f"Combined (RGB+Depth): {pathway_analysis['accuracy']['full_model']*100:.2f}%")
print(f"\nTotal Training Epochs: {len(history['train_loss'])}")
print(f"Total Parameters: {total_params:,}")
print(f"\nCheckpoints saved to: {checkpoint_dir}")
print("\n" + "=" * 60)
print("🎉 All done! Check Google Drive for saved models and results.")
print("=" * 60)