# MCResNet Training on NYU Depth V2 - Google Colab

**Complete end-to-end training pipeline for Google Colab with A100 GPU**

---

## 📋 Checklist Before Running:

- [ ] **Enable A100 GPU:** Runtime → Change runtime type → Hardware accelerator: GPU → GPU type: A100
- [ ] **Mount Google Drive:** Your code and dataset will be stored on Drive
- [ ] **Upload dataset to Drive:** `MyDrive/datasets/nyu_depth_v2_labeled.mat` (or we'll download it)
- [ ] **Expected Runtime:** ~2-3 hours for 90 epochs
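
As a rough sanity check on the runtime estimate above, the per-epoch time budget follows from simple division. The sketch below is illustrative only; `epoch_budget_seconds` is a hypothetical helper and not part of the project code.

```python
def epoch_budget_seconds(total_hours: float, epochs: int) -> float:
    """Average number of seconds available per epoch for a given total runtime."""
    if epochs <= 0:
        raise ValueError("epochs must be positive")
    return total_hours * 3600 / epochs


# The checklist quotes roughly 2-3 hours for 90 epochs.
low = epoch_budget_seconds(2, 90)
high = epoch_budget_seconds(3, 90)
print(f"Per-epoch budget: {low:.0f}-{high:.0f} s")
```

If the first few epochs run far outside this window, the estimate probably does not match your setup, and data loading is one plausible place to look.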

---

## 🎯 What This Notebook Does:

1. ✅ Verify A100 GPU is available
2. ✅ Mount Google Drive
3. ✅ Clone your repository to local disk (fast I/O)
4. ✅ Download/copy NYU Depth V2 to local disk (10-20x faster than Drive)
5. ✅ Install dependencies
6. ✅ Train MCResNet with all optimizations
7. ✅ Save checkpoints to Drive (persistent storage)
8. ✅ Generate training curves and analysis

---

**Let's get started!** 🚀

## 1. Environment Setup & GPU Verification

In [None]:
# Check GPU availability and specs
import torch
import subprocess

print("=" * 60)
print("GPU VERIFICATION")
print("=" * 60)

# Check PyTorch and CUDA
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
    
    # Check if it's A100
    gpu_name = torch.cuda.get_device_name(0)
    if 'A100' in gpu_name:
        print("\n✅ A100 GPU detected - PERFECT for training!")
    elif 'V100' in gpu_name:
        print("\n✅ V100 GPU detected - Good for training (slower than A100)")
    elif 'T4' in gpu_name:
        print("\n⚠️  T4 GPU detected - Will be slower, consider upgrading to A100")
    else:
        print(f"\n⚠️  GPU: {gpu_name} - Consider using A100 for best performance")
else:
    print("\n❌ NO GPU DETECTED!")
    print("Please enable GPU: Runtime → Change runtime type → Hardware accelerator: GPU")
    raise RuntimeError("GPU is required for training")

print("\n" + "=" * 60)

In [None]:
# Detailed GPU info
!nvidia-smi

## 2. Mount Google Drive

In [None]:
from google.colab import drive
import os
from pathlib import Path

# Mount Google Drive
drive.mount('/content/drive')

print("\n✅ Google Drive mounted successfully!")
print(f"\nDrive contents:")
!ls -la /content/drive/MyDrive/ | head -20

## 3. Clone Repository to Local Disk (Fast I/O)

**Important:** The repository is copied (or cloned) to `/content/` (local SSD) rather than used directly from Drive, for 10-20x faster I/O.
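
To check the I/O difference on your own session rather than relying on the quoted figure, one option is to time a sequential read of the same file from both locations. This is a minimal sketch; `read_throughput_mb_s` is a hypothetical helper, and the numbers it reports depend on file size and OS caching.

```python
import time
from pathlib import Path


def read_throughput_mb_s(path, chunk_size=8 * 1024 * 1024):
    """Read a file sequentially and return the observed throughput in MB/s."""
    size = Path(path).stat().st_size
    start = time.perf_counter()
    with open(path, "rb") as f:
        # Read in fixed-size chunks until read() returns an empty bytes object
        while f.read(chunk_size):
            pass
    elapsed = time.perf_counter() - start
    # Guard against a zero elapsed time on very small files
    return (size / 1024**2) / max(elapsed, 1e-9)
```

Call it once on a file under `/content/drive/MyDrive/` and once on a copy under `/content/`. A second read of the same file may be served from cache, so treat the first measurement as the more representative one.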

In [None]:
import os
from pathlib import Path

# Configuration
PROJECT_NAME = "Multi-Stream-Neural-Networks"
DRIVE_REPO_PATH = f"/content/drive/MyDrive/{PROJECT_NAME}"  # Your repo on Drive
LOCAL_REPO_PATH = f"/content/{PROJECT_NAME}"  # Local copy for fast I/O

print("=" * 60)
print("REPOSITORY SETUP")
print("=" * 60)

# Option 1: Copy from Drive (if repo is on Drive)
if Path(DRIVE_REPO_PATH).exists():
    print(f"📁 Found repo on Drive: {DRIVE_REPO_PATH}")
    print(f"📥 Copying to local disk for fast I/O...")
    
    # Remove old local copy if exists
    if Path(LOCAL_REPO_PATH).exists():
        !rm -rf {LOCAL_REPO_PATH}
    
    # Copy from Drive to local disk
    !cp -r {DRIVE_REPO_PATH} {LOCAL_REPO_PATH}
    print("✅ Repo copied to local disk")

# Option 2: Clone from GitHub (fallback when the repo is not on Drive)
else:
    print("📁 Repo not found on Drive")
    print("🔄 Cloning from GitHub instead...")
    
    # UPDATE THIS with your GitHub repo URL
    GITHUB_REPO = "https://github.com/YOUR_USERNAME/Multi-Stream-Neural-Networks.git"
    
    # Remove old local copy if exists
    if Path(LOCAL_REPO_PATH).exists():
        !rm -rf {LOCAL_REPO_PATH}
    
    # Clone from GitHub
    !git clone {GITHUB_REPO} {LOCAL_REPO_PATH}
    print("✅ Repo cloned from GitHub")

# Verify repo structure
print(f"\n📂 Repository structure:")
!ls -la {LOCAL_REPO_PATH}

# Change to repo directory
os.chdir(LOCAL_REPO_PATH)
print(f"\n✅ Working directory: {os.getcwd()}")

## 4. Install Dependencies

In [None]:
# Install required packages
print("Installing dependencies...")

!pip install -q h5py tqdm matplotlib seaborn

# Verify installations
import h5py
import tqdm
import matplotlib
import seaborn

print("✅ All dependencies installed!")
print(f"   h5py: {h5py.__version__}")
print(f"   matplotlib: {matplotlib.__version__}")

## 5. Download NYU Depth V2 Dataset to Local Disk

**Performance Note:** Local disk I/O is ~10-20x faster than Drive!

**Options:**
- **Option A:** Copy from Drive if you already have it there
- **Option B:** Download directly to local disk (~2.8 GB, takes 2-3 min)
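
Whichever route is used, a wrong file (for example an HTML error page saved under the `.mat` name) would otherwise only surface when the dataloader opens it. The sketch below is a cheap pre-check that assumes the file is HDF5-based, as the `h5_file_path` argument used later in this notebook suggests. It looks for the HDF5 format signature, which the HDF5 specification places at byte offset 0 or at 512, 1024, 2048, and so on. `looks_like_hdf5` is a hypothetical helper name.

```python
HDF5_SIGNATURE = b"\x89HDF\r\n\x1a\n"


def looks_like_hdf5(path, max_offset=1 << 20):
    """Return True if an HDF5 signature is found at offset 0, 512, 1024, ..."""
    with open(path, "rb") as f:
        offset = 0
        while offset <= max_offset:
            f.seek(offset)
            # Reading past the end of the file simply returns fewer bytes
            if f.read(len(HDF5_SIGNATURE)) == HDF5_SIGNATURE:
                return True
            offset = 512 if offset == 0 else offset * 2
    return False
```

This only confirms the file type. It does not detect a truncated download, so comparing the file size against the expected ~2.8 GB is still worthwhile.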

In [None]:
import urllib.request
from pathlib import Path
from tqdm import tqdm

# Paths
DRIVE_DATASET_PATH = "/content/drive/MyDrive/datasets/nyu_depth_v2_labeled.mat"
LOCAL_DATASET_PATH = "/content/nyu_depth_v2_labeled.mat"  # Local disk (FAST)
DATASET_URL = "http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/nyu_depth_v2_labeled.mat"

print("=" * 60)
print("NYU DEPTH V2 DATASET SETUP")
print("=" * 60)

# Check if already on local disk
if Path(LOCAL_DATASET_PATH).exists():
    print(f"✅ Dataset already on local disk: {LOCAL_DATASET_PATH}")
    print(f"   Size: {Path(LOCAL_DATASET_PATH).stat().st_size / 1024**3:.2f} GB")

# Option A: Copy from Drive (if available)
elif Path(DRIVE_DATASET_PATH).exists():
    print(f"📁 Found dataset on Drive: {DRIVE_DATASET_PATH}")
    print(f"📥 Copying to local disk for 10-20x faster training...")
    print(f"   This takes ~2-3 minutes but saves 60+ minutes during training!")
    
    !cp {DRIVE_DATASET_PATH} {LOCAL_DATASET_PATH}
    
    print(f"✅ Dataset copied to local disk")
    print(f"   Size: {Path(LOCAL_DATASET_PATH).stat().st_size / 1024**3:.2f} GB")

# Option B: Download from internet
else:
    print(f"📥 Downloading NYU Depth V2 dataset (~2.8 GB)...")
    print(f"   URL: {DATASET_URL}")
    print(f"   Destination: {LOCAL_DATASET_PATH}")
    print(f"   This will take ~2-3 minutes")
    
    # Download with progress bar
    class DownloadProgressBar(tqdm):
        def update_to(self, b=1, bsize=1, tsize=None):
            if tsize is not None:
                self.total = tsize
            self.update(b * bsize - self.n)
    
    with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc='NYU Depth V2') as t:
        urllib.request.urlretrieve(DATASET_URL, LOCAL_DATASET_PATH, reporthook=t.update_to)
    
    print(f"\n✅ Download complete!")
    print(f"   Size: {Path(LOCAL_DATASET_PATH).stat().st_size / 1024**3:.2f} GB")
    
    # Optionally save to Drive for future use
    save_to_drive = input("\nSave dataset to Drive for future sessions? (y/N): ")
    if save_to_drive.lower() == 'y':
        print("Saving to Drive...")
        !mkdir -p /content/drive/MyDrive/datasets/
        !cp {LOCAL_DATASET_PATH} {DRIVE_DATASET_PATH}
        print("✅ Saved to Drive")

print("\n" + "=" * 60)
print(f"Dataset ready at: {LOCAL_DATASET_PATH}")
print("=" * 60)

## 6. Setup Python Path & Import MCResNet

In [None]:
import sys
import os

# Add project to Python path (reuse the local repo path defined in step 3)
project_root = LOCAL_REPO_PATH
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Verify project structure
print("Project structure:")
!ls -la {project_root}/src/models/

# Import MCResNet
print("\nImporting MCResNet...")
from src.models.multi_channel.mc_resnet import mc_resnet18, mc_resnet50
from src.data_utils.nyu_depth_dataset import create_nyu_dataloaders

print("✅ MCResNet imported successfully!")

## 7. Load NYU Depth V2 Dataset

In [None]:
print("=" * 60)
print("LOADING NYU DEPTH V2 DATASET")
print("=" * 60)

# Dataset configuration
DATASET_CONFIG = {
    'dataset_path': LOCAL_DATASET_PATH,  # Local copy prepared in step 5
    'batch_size': 128,  # A100 can handle this with AMP
    'num_workers': 4,   # Colab has limited CPU cores
    'target_size': (224, 224),
    'num_classes': 13   # Scene classification
}

print(f"Configuration:")
for key, value in DATASET_CONFIG.items():
    print(f"  {key}: {value}")

print(f"\nLoading dataset from: {DATASET_CONFIG['dataset_path']}")

# Create dataloaders
train_loader, val_loader = create_nyu_dataloaders(
    h5_file_path=DATASET_CONFIG['dataset_path'],
    batch_size=DATASET_CONFIG['batch_size'],
    num_workers=DATASET_CONFIG['num_workers'],
    target_size=DATASET_CONFIG['target_size'],
    num_classes=DATASET_CONFIG['num_classes']
)

print(f"\n✅ Dataset loaded successfully!")
print(f"\nDataset Statistics:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")
print(f"  Train samples: {len(train_loader.dataset)}")
print(f"  Val samples: {len(val_loader.dataset)}")
print(f"  Batch size: {DATASET_CONFIG['batch_size']}")

# Test loading a batch
print(f"\nTesting batch loading...")
rgb_batch, depth_batch, label_batch = next(iter(train_loader))
print(f"  RGB shape: {rgb_batch.shape}")
print(f"  Depth shape: {depth_batch.shape}")
print(f"  Labels shape: {label_batch.shape}")
print(f"  Labels range: {label_batch.min().item()} - {label_batch.max().item()}")

print("\n" + "=" * 60)

## 8. Visualize Sample Data

In [None]:
import matplotlib.pyplot as plt
import torch  # Needed below for denormalization (torch.tensor, torch.clamp)

# Visualize some samples
fig, axes = plt.subplots(2, 4, figsize=(14, 7))

for i in range(4):
    rgb = rgb_batch[i].cpu()
    depth = depth_batch[i].cpu()
    label = label_batch[i].item()
    
    # Denormalize for visualization
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    rgb_vis = rgb * std + mean
    rgb_vis = torch.clamp(rgb_vis, 0, 1)
    
    # Plot RGB
    axes[0, i].imshow(rgb_vis.permute(1, 2, 0))
    axes[0, i].set_title(f"RGB - Class {label}", fontsize=10)
    axes[0, i].axis('off')
    
    # Plot Depth
    axes[1, i].imshow(depth.squeeze(), cmap='viridis')
    axes[1, i].set_title(f"Depth - Class {label}", fontsize=10)
    axes[1, i].axis('off')

plt.suptitle('NYU Depth V2 Sample Data (RGB + Depth)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("✅ Sample visualization complete!")

## 9. Create & Compile MCResNet Model

In [None]:
print("=" * 60)
print("MODEL CREATION")
print("=" * 60)

# Model configuration
MODEL_CONFIG = {
    'architecture': 'resnet18',  # or 'resnet50' for better accuracy
    'num_classes': 13,
    'stream1_channels': 3,  # RGB
    'stream2_channels': 1,  # Depth
    'device': 'cuda',
    'use_amp': True  # Automatic Mixed Precision (2x faster on A100)
}

print(f"Configuration:")
for key, value in MODEL_CONFIG.items():
    print(f"  {key}: {value}")

# Create model
print(f"\nCreating MCResNet-{MODEL_CONFIG['architecture'].upper()}...")

if MODEL_CONFIG['architecture'] == 'resnet18':
    model = mc_resnet18(
        num_classes=MODEL_CONFIG['num_classes'],
        stream1_channels=MODEL_CONFIG['stream1_channels'],
        stream2_channels=MODEL_CONFIG['stream2_channels'],
        device=MODEL_CONFIG['device'],
        use_amp=MODEL_CONFIG['use_amp']
    )
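elif MODEL_CONFIG['architecture'] == 'resnet50':
    model = mc_resnet50(
        num_classes=MODEL_CONFIG['num_classes'],
        stream1_channels=MODEL_CONFIG['stream1_channels'],
        stream2_channels=MODEL_CONFIG['stream2_channels'],
        device=MODEL_CONFIG['device'],
        use_amp=MODEL_CONFIG['use_amp']
    )
else:
    # Fail early here instead of hitting a NameError on `model` below
    raise ValueError(f"Unsupported architecture: {MODEL_CONFIG['architecture']!r}")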

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\n✅ Model created successfully!")
print(f"\nModel Statistics:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Model size: {total_params * 4 / 1024**2:.2f} MB (FP32)")
print(f"  Device: {MODEL_CONFIG['device']}")
print(f"  AMP enabled: {MODEL_CONFIG['use_amp']}")

print("\n" + "=" * 60)

In [None]:
# Training configuration
TRAINING_CONFIG = {
    'optimizer': 'sgd',
    'learning_rate': 0.1,  # Standard for ImageNet-style training
    'weight_decay': 1e-4,
    'momentum': 0.9,
    'loss': 'cross_entropy',
    'scheduler': 'cosine'
}

print("Compiling model...")
print(f"\nTraining configuration:")
for key, value in TRAINING_CONFIG.items():
    print(f"  {key}: {value}")

# Compile
model.compile(
    optimizer=TRAINING_CONFIG['optimizer'],
    learning_rate=TRAINING_CONFIG['learning_rate'],
    weight_decay=TRAINING_CONFIG['weight_decay'],
    momentum=TRAINING_CONFIG['momentum'],
    loss=TRAINING_CONFIG['loss'],
    scheduler=TRAINING_CONFIG['scheduler']
)

print("\n✅ Model compiled successfully!")

## 10. Test Forward Pass

In [None]:
# Test forward pass
print("Testing forward pass...")

model.eval()
with torch.no_grad():
    rgb_test, depth_test, labels_test = next(iter(train_loader))
    outputs = model(rgb_test.to(MODEL_CONFIG['device']), depth_test.to(MODEL_CONFIG['device']))
    
    print(f"\nInput shapes:")
    print(f"  RGB: {rgb_test.shape}")
    print(f"  Depth: {depth_test.shape}")
    print(f"  Labels: {labels_test.shape}")
    
    print(f"\nOutput shape: {outputs.shape}")
    print(f"Expected: ({DATASET_CONFIG['batch_size']}, {DATASET_CONFIG['num_classes']})")
    
    _, predictions = torch.max(outputs, 1)
    print(f"\nSample predictions: {predictions.cpu().numpy()[:10]}")
    print(f"Ground truth: {labels_test.numpy()[:10]}")

print("\n✅ Forward pass successful!")

## 11. Setup Checkpoint Directory

In [None]:
from pathlib import Path
from datetime import datetime

# Create checkpoint directory on Drive (persistent storage)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
checkpoint_dir = f"/content/drive/MyDrive/MCResNet_checkpoints/nyu_depth_v2_{timestamp}"
Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)

print(f"Checkpoint directory: {checkpoint_dir}")
print(f"\nCheckpoints will be saved to Google Drive for persistence")
print("✅ Directory created")

## 12. Train the Model 🚀

**Expected time:** ~2-3 hours for 90 epochs on A100

**All optimizations enabled:**
- ✅ Automatic Mixed Precision (up to ~2x faster, depending on the model)
- ✅ Gradient Clipping (stability)
- ✅ Cosine Annealing LR
- ✅ Early Stopping
- ✅ Best Model Checkpointing
- ✅ Local disk I/O (10-20x faster than Drive)
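
To make the early-stopping item concrete, here is a minimal sketch of generic patience-based stopping on a metric that should increase, such as validation accuracy. It is not taken from the project's `fit` implementation, which may differ in detail. `EarlyStopper` is a hypothetical class whose `patience` and `min_delta` arguments mirror the settings of the same name in the training cell.

```python
class EarlyStopper:
    """Track a metric to maximize and signal when it stops improving."""

    def __init__(self, patience=15, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.best_epoch = -1
        self.bad_epochs = 0

    def step(self, epoch, value):
        """Record one epoch's metric; return True when training should stop."""
        if value > self.best + self.min_delta:
            # Improvement larger than min_delta: remember it and reset the counter
            self.best = value
            self.best_epoch = epoch
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Toy run with a short patience so the stop is easy to follow
stopper = EarlyStopper(patience=2, min_delta=0.001)
val_accuracies = [0.50, 0.60, 0.6005, 0.599, 0.61]
for epoch, acc in enumerate(val_accuracies):
    if stopper.step(epoch, acc):
        print(f"Stop at epoch {epoch}; best was epoch {stopper.best_epoch}")
        break
```

In the toy run, the gain from 0.60 to 0.6005 is smaller than `min_delta`, so it counts as no improvement and the loop stops before reaching the later, better value. That is the trade-off a short `patience` makes.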

In [None]:
print("=" * 60)
print("STARTING TRAINING")
print("=" * 60)

# Training configuration
TRAIN_CONFIG = {
    'epochs': 90,  # Standard for ImageNet-style training
    'grad_clip_norm': 5.0,  # Gradient clipping for stability
    'early_stopping': True,
    'patience': 15,
    'min_delta': 0.001,
    'monitor': 'val_accuracy',
    'restore_best_weights': True,
    'save_path': f"{checkpoint_dir}/best_model.pt"
}

print(f"Configuration:")
for key, value in TRAIN_CONFIG.items():
    print(f"  {key}: {value}")

print("\n" + "=" * 60)
print(f"Training will take approximately 2-3 hours on A100")
print(f"Progress will be shown below...")
print("=" * 60 + "\n")

# Train!
history = model.fit(
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=TRAIN_CONFIG['epochs'],
    verbose=True,
    save_path=TRAIN_CONFIG['save_path'],
    early_stopping=TRAIN_CONFIG['early_stopping'],
    patience=TRAIN_CONFIG['patience'],
    min_delta=TRAIN_CONFIG['min_delta'],
    monitor=TRAIN_CONFIG['monitor'],
    restore_best_weights=TRAIN_CONFIG['restore_best_weights'],
    grad_clip_norm=TRAIN_CONFIG['grad_clip_norm']
)

print("\n" + "=" * 60)
print("🎉 TRAINING COMPLETE!")
print("=" * 60)

## 13. Evaluate Final Model

In [None]:
print("=" * 60)
print("FINAL EVALUATION")
print("=" * 60)

# Evaluate on validation set
results = model.evaluate(data_loader=val_loader)

print(f"\nFinal Validation Results:")
print(f"  Loss: {results['loss']:.4f}")
print(f"  Accuracy: {results['accuracy']*100:.2f}%")

print(f"\nTraining Summary:")
print(f"  Initial train loss: {history['train_loss'][0]:.4f}")
print(f"  Final train loss: {history['train_loss'][-1]:.4f}")
print(f"  Best val loss: {min(history['val_loss']):.4f}")
print(f"  Initial train acc: {history['train_accuracy'][0]*100:.2f}%")
print(f"  Final train acc: {history['train_accuracy'][-1]*100:.2f}%")
print(f"  Best val acc: {max(history['val_accuracy'])*100:.2f}%")
print(f"  Total epochs: {len(history['train_loss'])}")

if 'early_stopping' in history:
    print(f"\nEarly Stopping Info:")
    print(f"  Stopped early: {history['early_stopping']['stopped_early']}")
    print(f"  Best epoch: {history['early_stopping']['best_epoch']}")
    print(f"  Best {history['early_stopping']['monitor']}: {history['early_stopping']['best_metric']:.4f}")

print("\n" + "=" * 60)

## 14. Plot Training Curves

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Loss curve
axes[0].plot(history['train_loss'], label='Train Loss', linewidth=2)
axes[0].plot(history['val_loss'], label='Val Loss', linewidth=2)
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)

# Accuracy curve
axes[1].plot([acc*100 for acc in history['train_accuracy']], label='Train Acc', linewidth=2)
axes[1].plot([acc*100 for acc in history['val_accuracy']], label='Val Acc', linewidth=2)
axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('Accuracy (%)', fontsize=12)
axes[1].set_title('Training and Validation Accuracy', fontsize=14, fontweight='bold')
axes[1].legend(fontsize=11)
axes[1].grid(True, alpha=0.3)

# Learning rate curve
if len(history['learning_rates']) > 0:
    # Sample learning rates (they're recorded per step, not per epoch)
    sampled_lrs = history['learning_rates'][::max(1, len(history['learning_rates'])//100)]
    axes[2].plot(sampled_lrs, linewidth=2, color='green')
    axes[2].set_xlabel('Training Step (sampled)', fontsize=12)
    axes[2].set_ylabel('Learning Rate', fontsize=12)
    axes[2].set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
    axes[2].grid(True, alpha=0.3)
    axes[2].set_yscale('log')

plt.tight_layout()
plt.savefig(f"{checkpoint_dir}/training_curves.png", dpi=150, bbox_inches='tight')
plt.show()

print(f"✅ Training curves saved to: {checkpoint_dir}/training_curves.png")

## 15. Pathway Analysis (RGB vs Depth Contributions)

In [None]:
print("=" * 60)
print("PATHWAY ANALYSIS")
print("=" * 60)
print("\nAnalyzing RGB and Depth pathway contributions...")
print("This may take a few minutes...\n")

# Analyze pathways
pathway_analysis = model.analyze_pathways(
    data_loader=val_loader,
    num_samples=len(val_loader.dataset)  # Use all validation samples
)

print("\nAccuracy Metrics:")
print(f"  Full model (RGB+Depth): {pathway_analysis['accuracy']['full_model']*100:.2f}%")
print(f"  RGB only: {pathway_analysis['accuracy']['color_only']*100:.2f}%")
print(f"  Depth only: {pathway_analysis['accuracy']['brightness_only']*100:.2f}%")
print(f"\n  RGB contribution: {pathway_analysis['accuracy']['color_contribution']*100:.2f}%")
print(f"  Depth contribution: {pathway_analysis['accuracy']['brightness_contribution']*100:.2f}%")

print("\nLoss Metrics:")
print(f"  Full model: {pathway_analysis['loss']['full_model']:.4f}")
print(f"  RGB only: {pathway_analysis['loss']['color_only']:.4f}")
print(f"  Depth only: {pathway_analysis['loss']['brightness_only']:.4f}")

print("\nFeature Norm Statistics:")
print(f"  RGB mean: {pathway_analysis['feature_norms']['color_mean']:.4f}")
print(f"  RGB std: {pathway_analysis['feature_norms']['color_std']:.4f}")
print(f"  Depth mean: {pathway_analysis['feature_norms']['brightness_mean']:.4f}")
print(f"  Depth std: {pathway_analysis['feature_norms']['brightness_std']:.4f}")
print(f"  RGB/Depth ratio: {pathway_analysis['feature_norms']['color_to_brightness_ratio']:.4f}")

print("\n" + "=" * 60)
print("✅ Pathway analysis complete!")
print("=" * 60)

# Visualize pathway contributions
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Accuracy comparison
pathways = ['Full Model\n(RGB+Depth)', 'RGB Only', 'Depth Only']
accuracies = [
    pathway_analysis['accuracy']['full_model'] * 100,
    pathway_analysis['accuracy']['color_only'] * 100,
    pathway_analysis['accuracy']['brightness_only'] * 100
]
colors = ['green', 'blue', 'orange']

axes[0].bar(pathways, accuracies, color=colors, alpha=0.7)
axes[0].set_ylabel('Accuracy (%)', fontsize=12)
axes[0].set_title('Pathway Accuracy Comparison', fontsize=14, fontweight='bold')
axes[0].grid(True, alpha=0.3, axis='y')
for i, v in enumerate(accuracies):
    axes[0].text(i, v + 1, f'{v:.1f}%', ha='center', fontweight='bold')

# Feature norm comparison
norms = ['RGB Features', 'Depth Features']
norm_values = [
    pathway_analysis['feature_norms']['color_mean'],
    pathway_analysis['feature_norms']['brightness_mean']
]
axes[1].bar(norms, norm_values, color=['blue', 'orange'], alpha=0.7)
axes[1].set_ylabel('Feature Norm (Mean)', fontsize=12)
axes[1].set_title('Feature Magnitude Comparison', fontsize=14, fontweight='bold')
axes[1].grid(True, alpha=0.3, axis='y')
for i, v in enumerate(norm_values):
    axes[1].text(i, v + 0.1, f'{v:.2f}', ha='center', fontweight='bold')

plt.tight_layout()
plt.savefig(f"{checkpoint_dir}/pathway_analysis.png", dpi=150, bbox_inches='tight')
plt.show()

print(f"✅ Pathway analysis plot saved to: {checkpoint_dir}/pathway_analysis.png")

## 16. Save Results & Training History

In [None]:
import json
import torch

print("=" * 60)
print("SAVING RESULTS")
print("=" * 60)

# Save training history as JSON
history_path = f"{checkpoint_dir}/training_history.json"
with open(history_path, 'w') as f:
    json_history = {
        'train_loss': [float(x) for x in history['train_loss']],
        'val_loss': [float(x) for x in history['val_loss']],
        'train_accuracy': [float(x) for x in history['train_accuracy']],
        'val_accuracy': [float(x) for x in history['val_accuracy']],
        'learning_rates': [float(x) for x in history['learning_rates']],
        'model_config': MODEL_CONFIG,
        'dataset_config': DATASET_CONFIG,
        'training_config': TRAINING_CONFIG,
        'final_results': {
            'val_loss': float(results['loss']),
            'val_accuracy': float(results['accuracy'])
        },
        'pathway_analysis': {
            'full_model_accuracy': float(pathway_analysis['accuracy']['full_model']),
            'rgb_only_accuracy': float(pathway_analysis['accuracy']['color_only']),
            'depth_only_accuracy': float(pathway_analysis['accuracy']['brightness_only'])
        }
    }
    if 'early_stopping' in history:
        json_history['early_stopping'] = history['early_stopping']
    
    json.dump(json_history, f, indent=2)

print(f"✅ Training history saved: {history_path}")

# Save final model (in addition to best model)
final_model_path = f"{checkpoint_dir}/final_model.pt"
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': model.optimizer.state_dict(),
    'config': MODEL_CONFIG,
    'history': history,
    'val_accuracy': results['accuracy']
}, final_model_path)

print(f"✅ Final model saved: {final_model_path}")

# Save summary report
summary_path = f"{checkpoint_dir}/summary.txt"
with open(summary_path, 'w') as f:
    f.write("=" * 60 + "\n")
    f.write("MCResNet Training Summary - NYU Depth V2\n")
    f.write("=" * 60 + "\n\n")
    f.write(f"Model: MCResNet-{MODEL_CONFIG['architecture'].upper()}\n")
    f.write(f"Dataset: NYU Depth V2 (Scene Classification)\n")
    f.write(f"Training Samples: {len(train_loader.dataset)}\n")
    f.write(f"Validation Samples: {len(val_loader.dataset)}\n")
    f.write(f"Total Parameters: {total_params:,}\n")
    f.write(f"\nTraining Configuration:\n")
    f.write(f"  Epochs: {len(history['train_loss'])}\n")
    f.write(f"  Batch Size: {DATASET_CONFIG['batch_size']}\n")
    f.write(f"  Learning Rate: {TRAINING_CONFIG['learning_rate']}\n")
    f.write(f"  Optimizer: {TRAINING_CONFIG['optimizer']}\n")
    f.write(f"  Scheduler: {TRAINING_CONFIG['scheduler']}\n")
    f.write(f"  AMP: {MODEL_CONFIG['use_amp']}\n")
    f.write(f"  Gradient Clipping: {TRAIN_CONFIG['grad_clip_norm']}\n")
    f.write(f"\nFinal Results:\n")
    f.write(f"  Val Loss: {results['loss']:.4f}\n")
    f.write(f"  Val Accuracy: {results['accuracy']*100:.2f}%\n")
    f.write(f"  Best Val Accuracy: {max(history['val_accuracy'])*100:.2f}%\n")
    f.write(f"\nPathway Analysis:\n")
    f.write(f"  Full Model: {pathway_analysis['accuracy']['full_model']*100:.2f}%\n")
    f.write(f"  RGB Only: {pathway_analysis['accuracy']['color_only']*100:.2f}%\n")
    f.write(f"  Depth Only: {pathway_analysis['accuracy']['brightness_only']*100:.2f}%\n")

print(f"✅ Summary report saved: {summary_path}")

print("\n" + "=" * 60)
print(f"All results saved to: {checkpoint_dir}")
print("=" * 60)

# List saved files
print("\nSaved files:")
!ls -lh {checkpoint_dir}

## 17. Summary & Next Steps

### 🎉 Training Complete!

**What we accomplished:**
- ✅ Trained MCResNet on NYU Depth V2 dataset
- ✅ Used A100 GPU with AMP (2x speedup)
- ✅ Saved all checkpoints to Google Drive
- ✅ Analyzed RGB and Depth pathway contributions
- ✅ Generated training curves and visualizations

**Results location:** the checkpoint directory path is printed in the output above.

### 📊 Expected Performance:

For **NYU Depth V2 Scene Classification (13 classes)**:
- **Good:** 65-75% validation accuracy
- **Very Good:** 75-80% validation accuracy
- **Excellent:** 80-85% validation accuracy

### 🔍 Next Steps:

1. **Review Results:**
   - Check training curves above
   - Review pathway analysis
   - Compare RGB vs Depth contributions

2. **Download Results:**
   - All files are saved to your Google Drive
   - Download checkpoints for local inference

3. **Experiment:**
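   - Try ResNet50, which may improve accuracy (set `MODEL_CONFIG['architecture']` to `'resnet50'`)
   - Adjust hyperparameters (learning rate, batch size)
   - If early stopping triggered while validation accuracy was still improving, raise `patience` in `TRAIN_CONFIG` and train longer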

4. **Deploy:**
   - Use the best model for inference
   - Test on new RGB-D images
   - Integrate into your application
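
For the review step, the saved `training_history.json` can be summarized without rerunning anything. The sketch below relies only on the keys written by the save cell in section 16 (`val_accuracy` and the `pathway_analysis` entries). `summarize_history` is a hypothetical helper, and the "fusion gain" values are simply accuracy differences, not a measure defined by the project.

```python
import json


def summarize_history(path):
    """Summarize a training_history.json file written by the save step."""
    with open(path) as f:
        h = json.load(f)
    val_acc = h["val_accuracy"]
    # Index of the epoch with the highest validation accuracy
    best_epoch = max(range(len(val_acc)), key=val_acc.__getitem__)
    pa = h["pathway_analysis"]
    full = pa["full_model_accuracy"]
    return {
        "best_epoch": best_epoch,  # zero-based index into the history lists
        "best_val_accuracy": val_acc[best_epoch],
        "fusion_gain_over_rgb": full - pa["rgb_only_accuracy"],
        "fusion_gain_over_depth": full - pa["depth_only_accuracy"],
    }


# Example, using the directory created in section 11:
# print(summarize_history(f"{checkpoint_dir}/training_history.json"))
```

A small or negative gain over one stream suggests the combined model is leaning mostly on that stream, which is worth keeping in mind when reading the pathway analysis.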

---

**Questions or issues?** Check the training summary and pathway analysis above!

In [None]:
# Print final summary
print("=" * 60)
print("FINAL SUMMARY")
print("=" * 60)
print(f"\n✅ Training Complete!")
print(f"\nFinal Validation Accuracy: {results['accuracy']*100:.2f}%")
print(f"Best Validation Accuracy: {max(history['val_accuracy'])*100:.2f}%")
print(f"\nRGB Pathway: {pathway_analysis['accuracy']['color_only']*100:.2f}%")
print(f"Depth Pathway: {pathway_analysis['accuracy']['brightness_only']*100:.2f}%")
print(f"Combined (RGB+Depth): {pathway_analysis['accuracy']['full_model']*100:.2f}%")
print(f"\nTotal Training Epochs: {len(history['train_loss'])}")
print(f"Total Parameters: {total_params:,}")
print(f"\nCheckpoints saved to: {checkpoint_dir}")
print("\n" + "=" * 60)
print("🎉 All done! Check Google Drive for saved models and results.")
print("=" * 60)