# MCResNet Training on SUN RGB-D - Google Colab

**Complete end-to-end training pipeline for Google Colab with A100 GPU**

---

## 📋 Checklist Before Running:

- [ ] **Enable A100 GPU:** Runtime → Change runtime type → Hardware accelerator: GPU → GPU type: A100
- [ ] **Mount Google Drive:** Your code and dataset will be stored on Drive
- [ ] **Upload dataset to Drive:** `MyDrive/datasets/sunrgbd_15.tar.gz` (compressed, preprocessed 15-category dataset)
- [ ] **Expected Runtime:** ~2-3 hours for training

---

## 🎯 What This Notebook Does:

1. ✅ Verify A100 GPU is available
2. ✅ Mount Google Drive
3. ✅ Clone your repository to local disk (fast I/O)
4. ✅ Copy SUN RGB-D dataset to local disk (10-20x faster than Drive)
5. ✅ Install dependencies
6. ✅ Train MCResNet with all optimizations
7. ✅ Save checkpoints to Drive (persistent storage)
8. ✅ Generate training curves and analysis

---

**Let's get started!** 🚀

## 1. Environment Setup & GPU Verification

In [None]:
# Check GPU availability and specs
import torch
import subprocess

print("=" * 60)
print("GPU VERIFICATION")
print("=" * 60)

# Check PyTorch and CUDA
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
    
    # Check if it's A100
    gpu_name = torch.cuda.get_device_name(0)
    if 'A100' in gpu_name:
        print("\n✅ A100 GPU detected - PERFECT for training!")
    elif 'V100' in gpu_name:
        print("\n✅ V100 GPU detected - Good for training (slower than A100)")
    elif 'T4' in gpu_name:
        print("\n⚠️  T4 GPU detected - Will be slower, consider upgrading to A100")
    else:
        print(f"\n⚠️  GPU: {gpu_name} - Consider using A100 for best performance")
else:
    print("\n❌ NO GPU DETECTED!")
    print("Please enable GPU: Runtime → Change runtime type → Hardware accelerator: GPU")
    raise RuntimeError("GPU is required for training")

print("\n" + "=" * 60)

In [None]:
# Detailed GPU info
!nvidia-smi

## 2. Mount Google Drive

In [None]:
from google.colab import drive
import os
from pathlib import Path

# Mount Google Drive
drive.mount('/content/drive')

print("\n✅ Google Drive mounted successfully!")
print(f"\nDrive contents:")
!ls -la /content/drive/MyDrive/ | head -20

## 3. Clone Repository to Local Disk (Fast I/O)

**Important:** We clone to `/content/` (local SSD) instead of Drive for 10-20x faster I/O

**Default:** Clone from GitHub (recommended - always gets latest code)

In [None]:
import os
from pathlib import Path

# Configuration
PROJECT_NAME = "Multi-Stream-Neural-Networks"
GITHUB_REPO = "https://github.com/clingergab/Multi-Stream-Neural-Networks.git"  # UPDATE THIS
LOCAL_REPO_PATH = f"/content/{PROJECT_NAME}"  # Local copy for fast I/O

print("=" * 60)
print("REPOSITORY SETUP")
print("=" * 60)

# Ensure we're in a valid directory
os.chdir('/content')
print(f"Starting in: {os.getcwd()}")

# Check if repo already exists (same session, rerunning cell)
if Path(LOCAL_REPO_PATH).exists() and Path(f"{LOCAL_REPO_PATH}/.git").exists():
    print(f"\n📁 Repo already exists: {LOCAL_REPO_PATH}")
    print(f"🔄 Pulling latest changes...")
    
    os.chdir(LOCAL_REPO_PATH)
    !git pull
    print("✅ Repo updated")

# Clone from GitHub (first run)
else:
    # Remove old incomplete copy if exists
    if Path(LOCAL_REPO_PATH).exists():
        print(f"\n🗑️  Removing incomplete repo copy...")
        !rm -rf {LOCAL_REPO_PATH}
    
    print(f"\n🔄 Cloning from GitHub...")
    print(f"   Repo: {GITHUB_REPO}")
    print(f"   Destination: {LOCAL_REPO_PATH}")
    
    !git clone {GITHUB_REPO} {LOCAL_REPO_PATH}
    
    # Verify clone succeeded
    if not Path(LOCAL_REPO_PATH).exists():
        raise RuntimeError(f"Failed to clone repository to {LOCAL_REPO_PATH}")
    
    print("✅ Repo cloned successfully")
    os.chdir(LOCAL_REPO_PATH)

# Verify repo structure
print(f"\n📂 Repository structure:")
!ls -la {LOCAL_REPO_PATH}

print(f"\n✅ Working directory: {os.getcwd()}")

## 4. Install Dependencies

In [None]:
# Install required packages
print("Installing dependencies...")

!pip install -q h5py tqdm matplotlib seaborn

# Verify installations
import h5py
import tqdm
import matplotlib
import seaborn

print("✅ All dependencies installed!")
print(f"   h5py: {h5py.__version__}")
print(f"   matplotlib: {matplotlib.__version__}")

## 5. Copy SUN RGB-D Dataset to Local Disk

**Performance Note:** Local disk I/O is ~10-20x faster than Drive!

**Dataset:** SUN RGB-D 15-category preprocessed dataset (~3.5 GB)
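
The copy-then-extract pattern used in the next cell can also be sketched in pure Python. This is a minimal illustration, not the notebook's method: the helper names `stage_archive` and `count_pngs` and the toy archive are hypothetical, and the real cell uses `rsync` and `tar` instead.

```python
import shutil
import tarfile
import tempfile
from pathlib import Path


def stage_archive(archive: Path, dest_root: Path) -> None:
    """Copy a .tar.gz to dest_root, extract it there, then delete the copy."""
    dest_root.mkdir(parents=True, exist_ok=True)
    local_copy = dest_root / archive.name
    shutil.copy2(archive, local_copy)          # one large sequential transfer
    with tarfile.open(local_copy, "r:gz") as tar:
        tar.extractall(dest_root)              # many small writes, all on local disk
    local_copy.unlink()                        # reclaim the space used by the archive


def count_pngs(split_dir: Path) -> int:
    """Count PNG files directly inside split_dir."""
    return len(list(split_dir.glob("*.png")))


# Demo on a throwaway archive (a stand-in for the real dataset).
with tempfile.TemporaryDirectory() as tmp_name:
    tmp = Path(tmp_name)
    rgb_dir = tmp / "drive" / "toy_dataset" / "train" / "rgb"
    rgb_dir.mkdir(parents=True)
    for i in range(3):
        (rgb_dir / f"{i}.png").write_bytes(b"placeholder")
    archive = tmp / "drive" / "toy_dataset.tar.gz"
    with tarfile.open(archive, "w:gz") as tar:
        tar.add(tmp / "drive" / "toy_dataset", arcname="toy_dataset")

    stage_archive(archive, tmp / "local")
    n_train = count_pngs(tmp / "local" / "toy_dataset" / "train" / "rgb")
    archive_removed = not (tmp / "local" / "toy_dataset.tar.gz").exists()
    print(n_train, archive_removed)
```

Moving one archive and unpacking it locally avoids a separate Drive round trip for each of the thousands of image files.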

In [None]:
from pathlib import Path
import os

# Paths
DRIVE_DATASET_TAR = "/content/drive/MyDrive/datasets/sunrgbd_15.tar.gz"  # Compressed file
LOCAL_DATASET_PATH = "/content/data/sunrgbd_15"  # Extracted location

print("=" * 60)
print("SUN RGB-D 15-CATEGORY DATASET SETUP")
print("=" * 60)

# Check if already on local disk
if Path(LOCAL_DATASET_PATH).exists():
    print(f"✅ Dataset already on local disk: {LOCAL_DATASET_PATH}")
    
    # Verify structure
    train_rgb_count = len(list(Path(f"{LOCAL_DATASET_PATH}/train/rgb").glob("*.png")))
    val_rgb_count = len(list(Path(f"{LOCAL_DATASET_PATH}/val/rgb").glob("*.png")))
    print(f"   Train samples: {train_rgb_count}")
    print(f"   Val samples: {val_rgb_count}")

# Copy and extract from Drive
elif Path(DRIVE_DATASET_TAR).exists():
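    print(f"📁 Found compressed dataset on Drive: {DRIVE_DATASET_TAR}")
    archive_gb = Path(DRIVE_DATASET_TAR).stat().st_size / 1024**3  # report the actual archive size
    print(f"📥 Copying {archive_gb:.1f} GB compressed file to local disk...")
    print(f"   ⏱️  This takes ~3-5 minutes (much faster than 20k individual files!)")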
    
    # Create parent directory
    !mkdir -p /content/data
    
    # Copy compressed file with progress
    print(f"\nCopying compressed archive...")
    !rsync -ah --info=progress2 {DRIVE_DATASET_TAR} /content/data/sunrgbd_15.tar.gz
    
    # Extract to local disk (suppress macOS metadata warnings)
    print(f"\n📦 Extracting dataset to local disk...")
    !tar -xzf /content/data/sunrgbd_15.tar.gz -C /content/data/ 2>&1 | grep -v "Ignoring unknown extended header"
    
    # Remove tar file to save space
    !rm /content/data/sunrgbd_15.tar.gz
    
    print(f"\n✅ Dataset extracted to local disk")
    
    # Verify extraction
    train_rgb_count = len(list(Path(f"{LOCAL_DATASET_PATH}/train/rgb").glob("*.png")))
    val_rgb_count = len(list(Path(f"{LOCAL_DATASET_PATH}/val/rgb").glob("*.png")))
    print(f"   Train samples: {train_rgb_count}")
    print(f"   Val samples: {val_rgb_count}")

else:
    print(f"❌ Compressed dataset not found on Drive!")
    print(f"   Expected location: {DRIVE_DATASET_TAR}")
    print(f"\n📋 To fix this:")
    print(f"   1. Run: COPYFILE_DISABLE=1 tar -czf sunrgbd_15.tar.gz sunrgbd_15/")
    print(f"   2. Upload sunrgbd_15.tar.gz to Google Drive")
    print(f"   3. Place it at: {DRIVE_DATASET_TAR}")
    raise FileNotFoundError(f"Compressed dataset not found at {DRIVE_DATASET_TAR}")

print("\n" + "=" * 60)
print(f"Dataset ready at: {LOCAL_DATASET_PATH}")
print("=" * 60)

## 6. Setup Python Path & Import MCResNet

In [None]:
import sys
import os

# Remove cached modules (including the top-level `src` package) so code edits are picked up
modules_to_reload = [k for k in sys.modules if k == 'src' or k.startswith('src.')]
for module in modules_to_reload:
    del sys.modules[module]
    
# Add project to Python path (reuse the path set in the repository setup cell)
project_root = LOCAL_REPO_PATH
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Verify project structure
print("Project structure:")
!ls -la {project_root}/src/models/

# Import MCResNet and SUN RGB-D dataloader
print("\nImporting MCResNet and dataloaders...")
from src.models.multi_channel.mc_resnet import mc_resnet18, mc_resnet50
from src.data_utils.sunrgbd_dataset import get_sunrgbd_dataloaders

print("✅ MCResNet and dataloaders imported successfully!")

## 7. Load SUN RGB-D Dataset

In [None]:
# Verify dataset structure
from pathlib import Path

print("=" * 60)
print("DATASET STRUCTURE VERIFICATION")
print("=" * 60)

dataset_root = Path(LOCAL_DATASET_PATH)

print("\nDirectory structure:")
print(f"  {dataset_root}/")
print(f"    train/")
print(f"      rgb/ - {len(list((dataset_root / 'train' / 'rgb').glob('*.png')))} images")
print(f"      depth/ - {len(list((dataset_root / 'train' / 'depth').glob('*.png')))} images")
print(f"      labels.txt")
print(f"    val/")
print(f"      rgb/ - {len(list((dataset_root / 'val' / 'rgb').glob('*.png')))} images")
print(f"      depth/ - {len(list((dataset_root / 'val' / 'depth').glob('*.png')))} images")
print(f"      labels.txt")
print(f"    class_names.txt")
print(f"    dataset_info.txt")

# Read class names
with open(dataset_root / 'class_names.txt', 'r') as f:
    class_names = [line.strip() for line in f]

print(f"\nClasses ({len(class_names)}):")
for i, name in enumerate(class_names):
    print(f"  {i}: {name}")

print("\n" + "=" * 60)

In [None]:
print("=" * 60)
print("LOADING SUN RGB-D 15-CATEGORY DATASET")
print("=" * 60)

# Dataset configuration
DATASET_CONFIG = {
    'data_root': LOCAL_DATASET_PATH,
    'batch_size': 96,  # Good balance for A100
    'num_workers': 4,
    'target_size': (416, 544),  
    'num_classes': 15   # SUN RGB-D merged to 15 categories (labels 0-14)
}

print(f"Configuration:")
for key, value in DATASET_CONFIG.items():
    print(f"  {key}: {value}")

print(f"\nLoading dataset from: {DATASET_CONFIG['data_root']}")

# Create dataloaders
train_loader, val_loader = get_sunrgbd_dataloaders(
    data_root=DATASET_CONFIG['data_root'],
    batch_size=DATASET_CONFIG['batch_size'],
    num_workers=DATASET_CONFIG['num_workers'],
    target_size=DATASET_CONFIG['target_size']
)

print(f"\n✅ Dataset loaded successfully!")
print(f"\nDataset Statistics:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")
print(f"  Train samples: {len(train_loader.dataset)}")
print(f"  Val samples: {len(val_loader.dataset)}")
print(f"  Batch size: {DATASET_CONFIG['batch_size']}")

# Test loading a batch
print(f"\nTesting batch loading...")
rgb_batch, depth_batch, label_batch = next(iter(train_loader))
print(f"  RGB shape: {rgb_batch.shape}")
print(f"  Depth shape: {depth_batch.shape}")
print(f"  Labels shape: {label_batch.shape}")
print(f"  Labels min: {label_batch.min().item()}, max: {label_batch.max().item()}")

print("\n" + "=" * 60)

## 8. Visualize Sample Data

Shows RGB images, depth maps, and scene labels from the dataset

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Visualize some samples
fig, axes = plt.subplots(2, 4, figsize=(14, 7))

for i in range(4):
    rgb = rgb_batch[i].cpu()
    depth = depth_batch[i].cpu()
    label = label_batch[i].item()
    
    # Denormalize for visualization
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    rgb_vis = rgb * std + mean
    rgb_vis = torch.clamp(rgb_vis, 0, 1)
    
    # Plot RGB
    axes[0, i].imshow(rgb_vis.permute(1, 2, 0))
    axes[0, i].set_title(f"RGB - Class {label}", fontsize=10)
    axes[0, i].axis('off')
    
    # Plot Depth
    axes[1, i].imshow(depth.squeeze(), cmap='viridis')
    axes[1, i].set_title(f"Depth - Class {label}", fontsize=10)
    axes[1, i].axis('off')

plt.suptitle('SUN RGB-D Sample Data (RGB + Depth)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("✅ Sample visualization complete!")

## 9. Create & Compile MCResNet Model

In [None]:
print("=" * 60)
print("MODEL CREATION")
print("=" * 60)

# Model configuration
MODEL_CONFIG = {
    'architecture': 'resnet18',  # or 'resnet50' for better accuracy
    'num_classes': 15,  # SUN RGB-D has 15 merged categories (labels 0-14)
    'stream1_channels': 3,  # RGB
    'stream2_channels': 1,  # Depth
    'fusion_type': 'concat',  # 'concat', 'weighted', or 'gated'
    'dropout_p': 0.5,  # Dropout for regularization
    'device': 'cuda',
    'use_amp': True  # Automatic Mixed Precision (2x faster on A100)
}

print(f"Configuration:")
for key, value in MODEL_CONFIG.items():
    print(f"  {key}: {value}")

# Create model
print(f"\nCreating MCResNet-{MODEL_CONFIG['architecture'].upper()}...")

if MODEL_CONFIG['architecture'] == 'resnet18':
    model = mc_resnet18(
        num_classes=MODEL_CONFIG['num_classes'],
        stream1_input_channels=MODEL_CONFIG['stream1_channels'],
        stream2_input_channels=MODEL_CONFIG['stream2_channels'],
        fusion_type=MODEL_CONFIG['fusion_type'],
        dropout_p=MODEL_CONFIG['dropout_p'],
        device=MODEL_CONFIG['device'],
        use_amp=MODEL_CONFIG['use_amp']
    )
elif MODEL_CONFIG['architecture'] == 'resnet50':
    model = mc_resnet50(
        num_classes=MODEL_CONFIG['num_classes'],
        stream1_input_channels=MODEL_CONFIG['stream1_channels'],
        stream2_input_channels=MODEL_CONFIG['stream2_channels'],
        fusion_type=MODEL_CONFIG['fusion_type'],
        dropout_p=MODEL_CONFIG['dropout_p'],
        device=MODEL_CONFIG['device'],
        use_amp=MODEL_CONFIG['use_amp']
    )

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
fusion_params = sum(p.numel() for p in model.fusion.parameters())

print(f"\n✅ Model created successfully!")
print(f"\nModel Statistics:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Fusion parameters: {fusion_params:,}")
print(f"  Model size: {total_params * 4 / 1024**2:.2f} MB (FP32)")
print(f"  Fusion strategy: {model.fusion_strategy}")
print(f"  Device: {MODEL_CONFIG['device']}")
print(f"  AMP enabled: {MODEL_CONFIG['use_amp']}")

print("\n" + "=" * 60)

## 9b. Model Compilation Options

**The cell below compiles the model with stream-specific learning rates and weight decay for the RGB and depth streams.**
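
To make the idea of stream-specific settings concrete, here is a minimal sketch of how parameters might be routed into optimizer groups. It assumes parameter names contain `stream1` or `stream2`; that naming scheme, the helper names, and the sample parameter names are all hypothetical, and the repository's `compile()` may group parameters differently.

```python
from collections import defaultdict

# Per-group settings mirroring the values used in this notebook.
GROUP_SETTINGS = {
    "stream1": {"lr": 3e-5, "weight_decay": 5e-4},  # RGB stream
    "stream2": {"lr": 1e-4, "weight_decay": 1e-4},  # Depth stream
    "shared": {"lr": 7e-5, "weight_decay": 2e-4},   # fusion + classifier
}


def assign_group(param_name: str) -> str:
    """Route a parameter to a group by substring match (an assumed naming scheme)."""
    for stream in ("stream1", "stream2"):
        if stream in param_name:
            return stream
    return "shared"


def build_param_groups(param_names):
    """Return optimizer-style dicts with 'name', 'params', 'lr', and 'weight_decay'."""
    buckets = defaultdict(list)
    for name in param_names:
        buckets[assign_group(name)].append(name)
    return [
        {"name": group, "params": buckets[group], **GROUP_SETTINGS[group]}
        for group in GROUP_SETTINGS
        if buckets[group]
    ]


# Made-up parameter names, for illustration only.
names = ["conv1.stream1_weight", "conv1.stream2_weight", "fusion.weight", "fc.bias"]
groups = build_param_groups(names)
for g in groups:
    print(f"{g['name']}: lr={g['lr']:.0e}, wd={g['weight_decay']:.0e}, n={len(g['params'])}")
```

In real PyTorch code the `params` lists would hold tensors rather than names, and a list of such dicts can be passed to an optimizer such as `torch.optim.AdamW`, which applies each group's `lr` and `weight_decay` separately.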

In [None]:
# Compile model with stream-specific optimization
print("=" * 60)
print("MODEL COMPILATION")
print("=" * 60)

# Stream-specific configuration for optimal RGB/Depth balance
STREAM_SPECIFIC_CONFIG = {
    'optimizer': 'adamw',
    'learning_rate': 7e-5,           # Base LR for shared params (fusion, classifier)
    'weight_decay': 2e-4,             # Base weight decay

    # Stream-specific settings (adjusted based on research):
    'stream1_lr': 3e-5,               # RGB stream: lower LR (more regularization)
    'stream1_weight_decay': 5e-4,     # RGB stream: higher WD (prevent overfitting)
    'stream2_lr': 1e-4,               # Depth stream: higher LR (needs more learning)
    'stream2_weight_decay': 1e-4,     # Depth stream: lighter WD (less regularization)

    'loss': 'cross_entropy',
    'scheduler': 'cosine'
}

print(f"Configuration:")
for key, value in STREAM_SPECIFIC_CONFIG.items():
    if value is not None:
        print(f"  {key}: {value}")

# Compile
model.compile(**STREAM_SPECIFIC_CONFIG)

print("\n✅ Model compiled successfully!")

# Show parameter groups
if hasattr(model.optimizer, 'param_groups'):
    print(f"\nParameter groups created: {len(model.optimizer.param_groups)}")
    for i, group in enumerate(model.optimizer.param_groups):
        num_params = sum(p.numel() for p in group['params'])
        print(f"  Group {i}: LR={group['lr']:.2e}, WD={group['weight_decay']:.2e}, Params={num_params:,}")

print("\n" + "=" * 60)

## 10. Test Forward Pass
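
A frequent cause of opaque CUDA "device-side assert" errors in classification is a label outside the range `[0, num_classes)`. The sketch below shows that check in plain Python; the helper name `find_bad_labels` is hypothetical and not part of the repository.

```python
def find_bad_labels(labels, num_classes):
    """Return (index, label) pairs whose label falls outside [0, num_classes)."""
    return [(i, y) for i, y in enumerate(labels) if not 0 <= y < num_classes]


ok_batch = [0, 3, 14, 7]
bad_batch = [0, 15, -1, 2]

ok_result = find_bad_labels(ok_batch, num_classes=15)
bad_result = find_bad_labels(bad_batch, num_classes=15)
print(ok_result)
print(bad_result)
```

With a tensor batch, the same check could be applied to `label_batch.tolist()` before the data reaches the GPU.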

In [None]:
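# Test forward pass with detailed debugging
print("Testing forward pass with CUDA_LAUNCH_BLOCKING for better error messages...")

import os
# Synchronous CUDA kernel launches give more precise error locations.
# Note: the variable is read when CUDA initializes, so it generally has to be set
# before the first CUDA call in the session (e.g. at the top of the notebook) to take effect.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'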

model.eval()
with torch.no_grad():
    rgb_test, depth_test, labels_test = next(iter(train_loader))
    
    print(f"\nInput validation:")
    print(f"  RGB shape: {rgb_test.shape}, dtype: {rgb_test.dtype}")
    print(f"  RGB min: {rgb_test.min():.4f}, max: {rgb_test.max():.4f}")
    print(f"  RGB has NaN: {torch.isnan(rgb_test).any()}")
    print(f"  RGB has Inf: {torch.isinf(rgb_test).any()}")
    
    print(f"\n  Depth shape: {depth_test.shape}, dtype: {depth_test.dtype}")
    print(f"  Depth min: {depth_test.min():.4f}, max: {depth_test.max():.4f}")
    print(f"  Depth has NaN: {torch.isnan(depth_test).any()}")
    print(f"  Depth has Inf: {torch.isinf(depth_test).any()}")
    
    print(f"\n  Labels shape: {labels_test.shape}, dtype: {labels_test.dtype}")
    print(f"  Labels min: {labels_test.min()}, max: {labels_test.max()}")
    print(f"  Labels unique: {torch.unique(labels_test).tolist()}")
    
    print("\nRunning forward pass...")
    rgb_cuda = rgb_test.to('cuda')
    depth_cuda = depth_test.to('cuda')
    
    try:
        outputs = model(rgb_cuda, depth_cuda)
        print(f"  ✅ Forward pass successful!")
        print(f"  Output shape: {outputs.shape}")
        print(f"  Output min: {outputs.min():.4f}, max: {outputs.max():.4f}")
        
        _, predictions = torch.max(outputs, 1)
        print(f"\nSample predictions: {predictions.cpu().numpy()[:10]}")
        print(f"Ground truth: {labels_test.numpy()[:10]}")
        
    except Exception as e:
        print(f"\n❌ Forward pass failed!")
        print(f"Error: {e}")
        print(f"\nThis is likely a model architecture issue, not a data issue.")
        print(f"Possible causes:")
        print(f"  1. BatchNorm running stats issue")
        print(f"  2. Invalid tensor operations in model")
        print(f"  3. Memory corruption")
        raise

print("\n✅ Forward pass test complete!")

## 11. Setup Checkpoint Directory

In [None]:
import os
from datetime import datetime
from pathlib import Path

print("=" * 60)
print("CHECKPOINT DIRECTORY SETUP")
print("=" * 60)

# Create checkpoint directory on Google Drive (persistent storage)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
checkpoint_dir = f"/content/drive/MyDrive/mcresnet_checkpoints/run_{timestamp}"

# Create directory
Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)

print(f"\n✅ Checkpoint directory created:")
print(f"   {checkpoint_dir}")
print(f"\nAll training artifacts will be saved here:")
print(f"  • Best model weights")
print(f"  • Training history")
print(f"  • Monitoring metrics")
print(f"  • Visualizations")

print("\n" + "=" * 60)

## 12. Train the Model 🚀

**Expected time:** ~2-3 hours for 90 epochs on A100

**All optimizations enabled:**
- ✅ Automatic Mixed Precision (2x faster)
- ✅ Gradient Clipping (stability)
- ✅ Cosine Annealing LR
- ✅ Early Stopping
- ✅ Best Model Checkpointing
- ✅ Local disk I/O (10-20x faster than Drive)
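
Two of the mechanisms listed above, cosine annealing and early stopping, can be sketched in a few lines of plain Python. This is an illustrative sketch under stated assumptions, not the repository's implementation: `cosine_lr` and `EarlyStopper` are hypothetical names, and the real `fit()` may step the schedule per batch or per epoch and may track the metric differently.

```python
import math


def cosine_lr(step: int, total_steps: int, base_lr: float, min_lr: float = 0.0) -> float:
    """Cosine annealing from base_lr (step 0) down to min_lr (step == total_steps)."""
    progress = min(max(step / total_steps, 0.0), 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


class EarlyStopper:
    """Stop when a maximised metric fails to improve by min_delta for `patience` epochs."""

    def __init__(self, patience: int = 15, min_delta: float = 0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.best_epoch = -1
        self.bad_epochs = 0

    def step(self, epoch: int, metric: float) -> bool:
        """Record one epoch's metric; return True when training should stop."""
        if metric > self.best + self.min_delta:
            self.best, self.best_epoch, self.bad_epochs = metric, epoch, 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Toy run: accuracy plateaus after epoch 2; patience is shortened to 3 for the demo.
val_acc = [0.40, 0.48, 0.52, 0.52, 0.5205, 0.51, 0.53]
stopper = EarlyStopper(patience=3, min_delta=0.001)
stopped_at = None
for epoch, acc in enumerate(val_acc):
    if stopper.step(epoch, acc):
        stopped_at = epoch
        break

lr_start = cosine_lr(0, 90, base_lr=7e-5)
lr_mid = cosine_lr(45, 90, base_lr=7e-5)
lr_end = cosine_lr(90, 90, base_lr=7e-5)
print(stopped_at, stopper.best_epoch, lr_start, lr_mid, lr_end)
```

In the toy run the gain of 0.0005 at epoch 4 is below `min_delta`, so it does not reset the patience counter, and the later 0.53 is never seen because training has already stopped.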

In [None]:
print("=" * 60)
print("TRAINING WITH STREAM MONITORING")
print("=" * 60)

# Training configuration
TRAIN_CONFIG = {
    'epochs': 90,
    'grad_clip_norm': 5.0,
    'early_stopping': True,
    'patience': 15,
    'min_delta': 0.001,
    'monitor': 'val_accuracy',
    'restore_best_weights': True,
    'save_path': f"{checkpoint_dir}/best_model.pt",
    'stream_monitoring': True  # Built-in stream monitoring!
}

print(f"Configuration:")
for key, value in TRAIN_CONFIG.items():
    print(f"  {key}: {value}")

print("\n" + "=" * 60)
print(f"Training will take approximately 2-3 hours on A100")
print(f"Stream monitoring active - detailed per-stream metrics shown each epoch")
print("=" * 60 + "\n")

# Train using built-in fit() method with stream monitoring
history = model.fit(
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=TRAIN_CONFIG['epochs'],
    verbose=True,
    save_path=TRAIN_CONFIG['save_path'],
    early_stopping=TRAIN_CONFIG['early_stopping'],
    patience=TRAIN_CONFIG['patience'],
    min_delta=TRAIN_CONFIG['min_delta'],
    monitor=TRAIN_CONFIG['monitor'],
    restore_best_weights=TRAIN_CONFIG['restore_best_weights'],
    grad_clip_norm=TRAIN_CONFIG['grad_clip_norm'],
    stream_monitoring=TRAIN_CONFIG['stream_monitoring']  # Enable built-in monitoring
)

print("\n" + "=" * 60)
print("🎉 TRAINING COMPLETE!")
print("=" * 60)

## 13. Evaluate Final Model

In [None]:
print("=" * 60)
print("FINAL EVALUATION")
print("=" * 60)

# Evaluate on validation set
results = model.evaluate(data_loader=val_loader)

print(f"\nFinal Validation Results:")
print(f"  Loss: {results['loss']:.4f}")
print(f"  Accuracy: {results['accuracy']*100:.2f}%")

print(f"\nTraining Summary:")
print(f"  Initial train loss: {history['train_loss'][0]:.4f}")
print(f"  Final train loss: {history['train_loss'][-1]:.4f}")
print(f"  Best val loss: {min(history['val_loss']):.4f}")
print(f"  Initial train acc: {history['train_accuracy'][0]*100:.2f}%")
print(f"  Final train acc: {history['train_accuracy'][-1]*100:.2f}%")
print(f"  Best val acc: {max(history['val_accuracy'])*100:.2f}%")
print(f"  Total epochs: {len(history['train_loss'])}")

if 'early_stopping' in history:
    print(f"\nEarly Stopping Info:")
    print(f"  Stopped early: {history['early_stopping']['stopped_early']}")
    print(f"  Best epoch: {history['early_stopping']['best_epoch']}")
    print(f"  Best {history['early_stopping']['monitor']}: {history['early_stopping']['best_metric']:.4f}")

print("\n" + "=" * 60)

## 14. Plot Training Curves
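
The learning-rate panel in the next cell thins the recorded history with a stride of `max(1, n // 100)`. A small sketch of that arithmetic follows; the helper name `subsample` is illustrative only.

```python
def subsample(values, max_points=100):
    """Keep every k-th value, where k = max(1, len(values) // max_points)."""
    stride = max(1, len(values) // max_points)
    return values[::stride]


long_history = list(range(250))   # stand-in for 250 recorded learning rates
short_history = list(range(40))   # fewer than max_points entries

sampled_long = subsample(long_history)
sampled_short = subsample(short_history)
print(len(sampled_long), len(sampled_short))
```

Because the stride uses floor division, the result can hold somewhat more than `max_points` values (125 for the 250-entry list here), and short histories pass through unchanged.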

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Loss curve
axes[0].plot(history['train_loss'], label='Train Loss', linewidth=2)
axes[0].plot(history['val_loss'], label='Val Loss', linewidth=2)
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)

# Accuracy curve
axes[1].plot([acc*100 for acc in history['train_accuracy']], label='Train Acc', linewidth=2)
axes[1].plot([acc*100 for acc in history['val_accuracy']], label='Val Acc', linewidth=2)
axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('Accuracy (%)', fontsize=12)
axes[1].set_title('Training and Validation Accuracy', fontsize=14, fontweight='bold')
axes[1].legend(fontsize=11)
axes[1].grid(True, alpha=0.3)

# Learning rate curve
if len(history['learning_rates']) > 0:
    # Sample learning rates (they're recorded per step, not per epoch)
    sampled_lrs = history['learning_rates'][::max(1, len(history['learning_rates'])//100)]
    axes[2].plot(sampled_lrs, linewidth=2, color='green')
    axes[2].set_xlabel('Training Step (sampled)', fontsize=12)
    axes[2].set_ylabel('Learning Rate', fontsize=12)
    axes[2].set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
    axes[2].grid(True, alpha=0.3)
    axes[2].set_yscale('log')

plt.tight_layout()
plt.savefig(f"{checkpoint_dir}/training_curves.png", dpi=150, bbox_inches='tight')
plt.show()

print(f"✅ Training curves saved to: {checkpoint_dir}/training_curves.png")

## 15. Pathway Analysis (RGB vs Depth Contributions)
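
One plausible way to turn ablation results into importance shares is to normalise the accuracy drops measured when each stream is removed. The sketch below is an assumption made for illustration; `normalise_drops` is a hypothetical helper, and `calculate_stream_contributions()` in the repository may compute its scores differently.

```python
def normalise_drops(drop_without_rgb: float, drop_without_depth: float):
    """Convert two ablation accuracy drops into importance shares that sum to 1."""
    rgb = max(drop_without_rgb, 0.0)      # a negative drop means removal helped; treat as 0
    depth = max(drop_without_depth, 0.0)
    total = rgb + depth
    if total == 0.0:
        return 0.5, 0.5                   # no measurable reliance on either stream
    return rgb / total, depth / total


# Example: accuracy falls by 30 points without RGB and by 10 points without depth.
rgb_share, depth_share = normalise_drops(0.30, 0.10)
flat_shares = normalise_drops(0.0, 0.0)
print(f"RGB: {rgb_share:.2f}, Depth: {depth_share:.2f}")
```

Under this scheme a share above 0.6 for one stream would correspond to the "relies heavily" interpretation printed by the analysis cell.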

In [None]:
print("=" * 60)
print("PATHWAY ANALYSIS")
print("=" * 60)
print("\nAnalyzing RGB and Depth pathway contributions...")
print("This may take a few minutes...\n")

# Analyze pathways
pathway_analysis = model.analyze_pathways(
    data_loader=val_loader,
    num_samples=len(val_loader.dataset)  # Use all validation samples
)

print("\nAccuracy Metrics:")
print(f"  Full model (RGB+Depth): {pathway_analysis['accuracy']['full_model']*100:.2f}%")
print(f"  RGB only: {pathway_analysis['accuracy']['color_only']*100:.2f}%")
print(f"  Depth only: {pathway_analysis['accuracy']['brightness_only']*100:.2f}%")
print(f"\n  RGB contribution: {pathway_analysis['accuracy']['color_contribution']*100:.2f}%")
print(f"  Depth contribution: {pathway_analysis['accuracy']['brightness_contribution']*100:.2f}%")

print("\nLoss Metrics:")
print(f"  Full model: {pathway_analysis['loss']['full_model']:.4f}")
print(f"  RGB only: {pathway_analysis['loss']['color_only']:.4f}")
print(f"  Depth only: {pathway_analysis['loss']['brightness_only']:.4f}")

print("\nFeature Norm Statistics:")
print(f"  RGB mean: {pathway_analysis['feature_norms']['color_mean']:.4f}")
print(f"  RGB std: {pathway_analysis['feature_norms']['color_std']:.4f}")
print(f"  Depth mean: {pathway_analysis['feature_norms']['brightness_mean']:.4f}")
print(f"  Depth std: {pathway_analysis['feature_norms']['brightness_std']:.4f}")
print(f"  RGB/Depth ratio: {pathway_analysis['feature_norms']['color_to_brightness_ratio']:.4f}")

print("\n" + "=" * 60)
print("STREAM CONTRIBUTION ANALYSIS")
print("=" * 60)
print("\nCalculating how much the fusion relies on each stream...")
print("(Measures performance drop when each stream is removed)\n")

# Calculate stream contributions - shows how much fusion relies on each stream
stream_contributions = model.calculate_stream_contributions(
    data_loader=val_loader,
    batch_size=DATASET_CONFIG['batch_size']
)

print("Stream Contribution to Final Predictions:")
print(f"  RGB importance: {stream_contributions['color_importance']*100:.1f}%")
print(f"  Depth importance: {stream_contributions['brightness_importance']*100:.1f}%")

print("\nPerformance Drop Analysis:")
print(f"  Without RGB: {stream_contributions['performance_drops']['without_color']*100:.2f}% accuracy drop")
print(f"  Without Depth: {stream_contributions['performance_drops']['without_brightness']*100:.2f}% accuracy drop")

print("\nInterpretation:")
if stream_contributions['color_importance'] > 0.6:
    print("  → Fusion relies heavily on RGB stream")
elif stream_contributions['brightness_importance'] > 0.6:
    print("  → Fusion relies heavily on Depth stream")
else:
    print("  → Fusion uses both streams fairly equally")

print("\n" + "=" * 60)
print("✅ Pathway analysis complete!")
print("=" * 60)

# Visualize pathway contributions
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# Accuracy comparison
pathways = ['Full Model\n(RGB+Depth)', 'RGB Only', 'Depth Only']
accuracies = [
    pathway_analysis['accuracy']['full_model'] * 100,
    pathway_analysis['accuracy']['color_only'] * 100,
    pathway_analysis['accuracy']['brightness_only'] * 100
]
colors = ['green', 'blue', 'orange']

axes[0].bar(pathways, accuracies, color=colors, alpha=0.7)
axes[0].set_ylabel('Accuracy (%)', fontsize=12)
axes[0].set_title('Pathway Accuracy Comparison', fontsize=14, fontweight='bold')
axes[0].grid(True, alpha=0.3, axis='y')
for i, v in enumerate(accuracies):
    axes[0].text(i, v + 1, f'{v:.1f}%', ha='center', fontweight='bold')

# Feature norm comparison
norms = ['RGB Features', 'Depth Features']
norm_values = [
    pathway_analysis['feature_norms']['color_mean'],
    pathway_analysis['feature_norms']['brightness_mean']
]
axes[1].bar(norms, norm_values, color=['blue', 'orange'], alpha=0.7)
axes[1].set_ylabel('Feature Norm (Mean)', fontsize=12)
axes[1].set_title('Feature Magnitude Comparison', fontsize=14, fontweight='bold')
axes[1].grid(True, alpha=0.3, axis='y')
for i, v in enumerate(norm_values):
    axes[1].text(i, v + 0.1, f'{v:.2f}', ha='center', fontweight='bold')

# Stream importance (contribution to predictions)
streams = ['RGB Stream', 'Depth Stream']
importance_values = [
    stream_contributions['color_importance'] * 100,
    stream_contributions['brightness_importance'] * 100
]
axes[2].bar(streams, importance_values, color=['blue', 'orange'], alpha=0.7)
axes[2].set_ylabel('Importance (%)', fontsize=12)
axes[2].set_title('Stream Contribution to Predictions', fontsize=14, fontweight='bold')
axes[2].grid(True, alpha=0.3, axis='y')
for i, v in enumerate(importance_values):
    axes[2].text(i, v + 1, f'{v:.1f}%', ha='center', fontweight='bold')

plt.tight_layout()
plt.savefig(f"{checkpoint_dir}/pathway_analysis.png", dpi=150, bbox_inches='tight')
plt.show()

print(f"✅ Pathway analysis plot saved to: {checkpoint_dir}/pathway_analysis.png")

## 16. Save Results & Training History
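
The history file written in the next cell embeds the configuration dictionaries, so it helps to know how JSON treats their values. The sketch below uses a made-up config (the `data_root` path object is hypothetical; the notebook stores paths as plain strings) to show that tuples come back as lists and that `default=str` can serve as a fallback for types JSON does not support.

```python
import json
from pathlib import PurePosixPath

config = {
    "batch_size": 96,
    "target_size": (416, 544),                            # tuple: JSON has no tuple type
    "data_root": PurePosixPath("/content/data/sunrgbd_15"),  # not JSON-serializable as-is
}

# default=str is called only for objects json cannot encode natively.
text = json.dumps(config, default=str, indent=2)
restored = json.loads(text)

print(restored["target_size"], restored["data_root"])
```

Code that reloads the file and needs a tuple again (for example for `target_size`) has to convert it back explicitly with `tuple(...)`.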

In [None]:
import json
import torch

print("=" * 60)
print("SAVING RESULTS")
print("=" * 60)

# Save training history as JSON
history_path = f"{checkpoint_dir}/training_history.json"
with open(history_path, 'w') as f:
    json_history = {
        'train_loss': [float(x) for x in history['train_loss']],
        'val_loss': [float(x) for x in history['val_loss']],
        'train_accuracy': [float(x) for x in history['train_accuracy']],
        'val_accuracy': [float(x) for x in history['val_accuracy']],
        'learning_rates': [float(x) for x in history['learning_rates']],
        'model_config': MODEL_CONFIG,
        'dataset_config': DATASET_CONFIG,
        'stream_specific_config': STREAM_SPECIFIC_CONFIG,
        'training_config': TRAIN_CONFIG,
        'scheduler_kwargs': history.get('scheduler_kwargs', {}),  # Scheduler-specific parameters
        'final_results': {
            'val_loss': float(results['loss']),
            'val_accuracy': float(results['accuracy'])
        },
        'pathway_analysis': {
            'full_model_accuracy': float(pathway_analysis['accuracy']['full_model']),
            'rgb_only_accuracy': float(pathway_analysis['accuracy']['color_only']),
            'depth_only_accuracy': float(pathway_analysis['accuracy']['brightness_only'])
        }
    }
    if 'early_stopping' in history:
        json_history['early_stopping'] = history['early_stopping']
    
    json.dump(json_history, f, indent=2)

print(f"✅ Training history saved: {history_path}")

# Save final model (in addition to best model)
final_model_path = f"{checkpoint_dir}/final_model.pt"
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': model.optimizer.state_dict(),
    'config': MODEL_CONFIG,
    'stream_specific_config': STREAM_SPECIFIC_CONFIG,
    'scheduler_kwargs': history.get('scheduler_kwargs', {}),  # Scheduler-specific parameters
    'history': history,
    'val_accuracy': results['accuracy']
}, final_model_path)

print(f"✅ Final model saved: {final_model_path}")

# Save summary report
summary_path = f"{checkpoint_dir}/summary.txt"
with open(summary_path, 'w') as f:
    f.write("=" * 60 + "\n")
    f.write("MCResNet Training Summary - SUN RGB-D\n")
    f.write("=" * 60 + "\n\n")
    
    # Model Configuration
    f.write("Model Configuration:\n")
    f.write(f"  Architecture: MCResNet-{MODEL_CONFIG['architecture'].upper()}\n")
    f.write(f"  Num Classes: {MODEL_CONFIG['num_classes']}\n")
    f.write(f"  Stream1 Channels: {MODEL_CONFIG['stream1_channels']} (RGB)\n")
    f.write(f"  Stream2 Channels: {MODEL_CONFIG['stream2_channels']} (Depth)\n")
    f.write(f"  Fusion Type: {MODEL_CONFIG['fusion_type']}\n")
    f.write(f"  Dropout: {MODEL_CONFIG['dropout_p']}\n")
    f.write(f"  Device: {MODEL_CONFIG['device']}\n")
    f.write(f"  AMP Enabled: {MODEL_CONFIG['use_amp']}\n")
    f.write(f"  Total Parameters: {total_params:,}\n")
    f.write(f"  Trainable Parameters: {trainable_params:,}\n")
    f.write(f"  Fusion Parameters: {fusion_params:,}\n")
    
    # Dataset Configuration
    f.write(f"\nDataset Configuration:\n")
    f.write(f"  Dataset: SUN RGB-D 15-category (Scene Classification)\n")
    f.write(f"  Training Samples: {len(train_loader.dataset)}\n")
    f.write(f"  Validation Samples: {len(val_loader.dataset)}\n")
    f.write(f"  Batch Size: {DATASET_CONFIG['batch_size']}\n")
    f.write(f"  Num Workers: {DATASET_CONFIG['num_workers']}\n")
    f.write(f"  Input Size: {DATASET_CONFIG['target_size']}\n")
    
    # Optimization Configuration
    f.write(f"\nOptimization Configuration:\n")
    f.write(f"  Optimizer: {STREAM_SPECIFIC_CONFIG['optimizer']}\n")
    f.write(f"  Loss Function: {STREAM_SPECIFIC_CONFIG['loss']}\n")
    
    # Label smoothing if present
    if 'label_smoothing' in STREAM_SPECIFIC_CONFIG and STREAM_SPECIFIC_CONFIG['label_smoothing'] > 0:
        f.write(f"  Label Smoothing: {STREAM_SPECIFIC_CONFIG['label_smoothing']}\n")
    
    f.write(f"  Scheduler: {STREAM_SPECIFIC_CONFIG['scheduler']}\n")
    
    # Scheduler-specific parameters
    scheduler_kwargs = history.get('scheduler_kwargs', {})
    if scheduler_kwargs:
        f.write(f"  Scheduler Parameters:\n")
        for key, value in scheduler_kwargs.items():
            f.write(f"    {key}: {value}\n")
    
    f.write(f"  Base LR: {STREAM_SPECIFIC_CONFIG['learning_rate']}\n")
    f.write(f"  Base Weight Decay: {STREAM_SPECIFIC_CONFIG['weight_decay']}\n")
    f.write(f"  Gradient Clipping: {TRAIN_CONFIG['grad_clip_norm']}\n")
    
    # Stream-Specific Settings
    f.write("\nStream-Specific Settings:\n")
    f.write("  Stream1 (RGB):\n")
    f.write(f"    Learning Rate: {STREAM_SPECIFIC_CONFIG['stream1_lr']}\n")
    f.write(f"    Weight Decay: {STREAM_SPECIFIC_CONFIG['stream1_weight_decay']}\n")
    f.write("  Stream2 (Depth):\n")
    f.write(f"    Learning Rate: {STREAM_SPECIFIC_CONFIG['stream2_lr']}\n")
    f.write(f"    Weight Decay: {STREAM_SPECIFIC_CONFIG['stream2_weight_decay']}\n")
    
    # Training Configuration
    f.write("\nTraining Configuration:\n")
    f.write(f"  Total Epochs: {len(history['train_loss'])}\n")
    f.write(f"  Stream Monitoring: {TRAIN_CONFIG['stream_monitoring']}\n")
    f.write(f"  Early Stopping: {TRAIN_CONFIG['early_stopping']}\n")
    if TRAIN_CONFIG['early_stopping']:
        f.write(f"    Monitor: {TRAIN_CONFIG['monitor']}\n")
        f.write(f"    Patience: {TRAIN_CONFIG['patience']}\n")
        f.write(f"    Min Delta: {TRAIN_CONFIG['min_delta']}\n")
        f.write(f"    Restore Best Weights: {TRAIN_CONFIG['restore_best_weights']}\n")
    
    # Results
    f.write("\nFinal Results:\n")
    f.write(f"  Val Loss: {results['loss']:.4f}\n")
    f.write(f"  Val Accuracy: {results['accuracy']*100:.2f}%\n")
    f.write(f"  Best Val Accuracy: {max(history['val_accuracy'])*100:.2f}%\n")
    f.write(f"  Initial Train Loss: {history['train_loss'][0]:.4f}\n")
    f.write(f"  Final Train Loss: {history['train_loss'][-1]:.4f}\n")
    f.write(f"  Best Val Loss: {min(history['val_loss']):.4f}\n")
    
    # Pathway Analysis
    f.write("\nPathway Analysis:\n")
    f.write(f"  Full Model (RGB+Depth): {pathway_analysis['accuracy']['full_model']*100:.2f}%\n")
    f.write(f"  RGB Only: {pathway_analysis['accuracy']['color_only']*100:.2f}%\n")
    f.write(f"  Depth Only: {pathway_analysis['accuracy']['brightness_only']*100:.2f}%\n")
    f.write(f"  RGB Contribution: {pathway_analysis['accuracy']['color_contribution']*100:.2f}%\n")
    f.write(f"  Depth Contribution: {pathway_analysis['accuracy']['brightness_contribution']*100:.2f}%\n")

print(f"✅ Summary report saved: {summary_path}")

print("\n" + "=" * 60)
print(f"All results saved to: {checkpoint_dir}")
print("=" * 60)

# List saved files
print("\nSaved files:")
!ls -lh "{checkpoint_dir}"

## 17. Summary & Next Steps

### 🎉 Training Complete!

**What we accomplished:**
- ✅ Trained MCResNet on SUN RGB-D dataset (15 categories)
- ✅ Used an A100 GPU with AMP (automatic mixed precision, roughly 2x faster than full precision)
- ✅ Saved all checkpoints to Google Drive
- ✅ Analyzed RGB and Depth pathway contributions
- ✅ Generated training curves and visualizations
- ✅ Monitored both streams throughout training, including overfitting detection

**Results are saved to:** the checkpoint directory printed in the output above (`checkpoint_dir`)
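
As a rough aid for reading the pathway analysis, the sketch below compares the fused model against its stronger single stream. It assumes the `pathway_analysis['accuracy']` keys used in this notebook's summary report (`full_model`, `color_only`, `brightness_only`). The helper name `fusion_gain` and the numbers in the toy dictionary are hypothetical and are not results from this run.

```python
def fusion_gain(pathway_analysis: dict) -> dict:
    """Compare the fused model with its best single stream (hypothetical helper)."""
    acc = pathway_analysis["accuracy"]
    rgb, depth = acc["color_only"], acc["brightness_only"]
    best_single = max(rgb, depth)
    return {
        "stronger_stream": "RGB" if rgb >= depth else "Depth",
        # Positive gain means fusion adds accuracy beyond the best single stream.
        "gain_over_best_stream": acc["full_model"] - best_single,
    }


# Illustrative values only, chosen for easy arithmetic.
toy_analysis = {
    "accuracy": {"full_model": 0.75, "color_only": 0.5, "brightness_only": 0.25}
}
gain = fusion_gain(toy_analysis)
print(gain["stronger_stream"], f"{gain['gain_over_best_stream'] * 100:.1f} points")
```

A gain near zero would suggest that one stream dominates the prediction, which is worth checking against the stream monitoring plots.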

### 📊 Expected Performance:

For **SUN RGB-D Scene Classification (15 categories, 10,335 images)**:
- **Good:** 65-75% validation accuracy
- **Very Good:** 75-80% validation accuracy
- **Excellent:** 80-85% validation accuracy
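
The tiers above can be expressed as a small lookup, for example to label a run automatically. This is a minimal sketch: `rate_accuracy` is a hypothetical helper, and two choices are assumptions rather than something the notebook defines. A value on a shared boundary (75% or 80%) is assigned to the higher tier, and anything above 85% is still reported as "Excellent".

```python
def rate_accuracy(val_accuracy: float) -> str:
    """Map a validation accuracy in [0, 1] to the tiers listed above."""
    pct = val_accuracy * 100
    if pct >= 80:
        return "Excellent"  # assumption: no upper cap at 85%
    if pct >= 75:
        return "Very Good"  # assumption: boundaries go to the higher tier
    if pct >= 65:
        return "Good"
    return "Below expected range"


print(rate_accuracy(0.78))
```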

**Why these targets are higher than for NYU Depth V2:**
- 6.9x more training samples (8,041 vs 1,159)
- 22.6x better class balance (class imbalance ratio of 8.5x vs 192x)
- Higher quality, more diverse dataset
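
The two ratios quoted above follow directly from the figures in parentheses. A short check, using only those numbers (the variable names are illustrative):

```python
# Training-set sizes and class imbalance ratios as quoted above.
sun_train, nyu_train = 8_041, 1_159
sun_imbalance, nyu_imbalance = 8.5, 192.0

sample_ratio = sun_train / nyu_train           # how many times more training samples
balance_ratio = nyu_imbalance / sun_imbalance  # how many times lower the imbalance is

print(f"{sample_ratio:.1f}x more training samples")  # 6.9x more training samples
print(f"{balance_ratio:.1f}x better class balance")  # 22.6x better class balance
```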

### 🔍 Next Steps:

1. **Review Results:**
   - Check training curves above
   - Review pathway analysis
   - Compare RGB vs Depth contributions
   - Analyze stream monitoring plots

2. **Download Results:**
   - All files are saved to your Google Drive
   - Download checkpoints for local inference

3. **Experiment:**
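   - Try ResNet50 for potentially better accuracy (change `architecture` in Model Config)
   - Use stream-specific optimization if stream monitoring shows an imbalance between RGB and Depth
   - Adjust `fusion_type` (try `'weighted'` or `'gated'`)
   - Train longer if validation accuracy was still improving at the last epoch; if early stopping triggered too soon, increase `patience`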

4. **Deploy:**
   - Use the best model for inference
   - Test on new RGB-D images
   - Integrate into your application
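
When reviewing results and deciding whether to train longer, it can help to locate the best epoch in `history`. The sketch below assumes `history` holds per-epoch lists under `train_loss` and `val_accuracy`, as this notebook's summary report uses them. `summarize_history` and its `max_epochs` argument are hypothetical names: pass whatever epoch limit you configured for training. The toy values are illustrative only.

```python
def summarize_history(history: dict, max_epochs: int) -> dict:
    """Report the best epoch and whether training ended before max_epochs."""
    val_acc = history["val_accuracy"]
    # Index of the highest validation accuracy (first occurrence on ties).
    best_idx = max(range(len(val_acc)), key=val_acc.__getitem__)
    epochs_run = len(history["train_loss"])
    return {
        "best_epoch": best_idx + 1,  # 1-based for readability
        "best_val_accuracy": val_acc[best_idx],
        "epochs_run": epochs_run,
        "stopped_early": epochs_run < max_epochs,
        "epochs_since_best": epochs_run - (best_idx + 1),
    }


# Toy history, not real training output.
toy_history = {
    "train_loss": [2.1, 1.6, 1.3, 1.1, 1.0],
    "val_accuracy": [0.40, 0.55, 0.62, 0.61, 0.60],
}
summary = summarize_history(toy_history, max_epochs=10)
print(summary)
```

If `epochs_since_best` is 0 and `stopped_early` is false, validation accuracy peaked on the last epoch and a longer run may still help. A large `epochs_since_best` suggests the model had already plateaued.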

---

**Questions or issues?** Check the training summary and pathway analysis above!

In [None]:
# Print final summary
print("=" * 60)
print("FINAL SUMMARY")
print("=" * 60)
print("\n✅ Training Complete!")
print(f"\nFinal Validation Accuracy: {results['accuracy']*100:.2f}%")
print(f"Best Validation Accuracy: {max(history['val_accuracy'])*100:.2f}%")
print(f"\nRGB Pathway: {pathway_analysis['accuracy']['color_only']*100:.2f}%")
print(f"Depth Pathway: {pathway_analysis['accuracy']['brightness_only']*100:.2f}%")
print(f"Combined (RGB+Depth): {pathway_analysis['accuracy']['full_model']*100:.2f}%")
print(f"\nTotal Training Epochs: {len(history['train_loss'])}")
print(f"Total Parameters: {total_params:,}")
print(f"\nCheckpoints saved to: {checkpoint_dir}")
print("\n" + "=" * 60)
print("🎉 All done! Check Google Drive for saved models and results.")
print("=" * 60)