# MCResNet Training on SUN RGB-D - Google Colab

**Complete end-to-end training pipeline for Google Colab with GPU**

---

## 📋 Checklist Before Running:

- [ ] **Enable GPU:** Runtime → Change runtime type → Hardware accelerator: GPU
- [ ] **Mount Google Drive:** The preprocessed dataset is read from Drive and checkpoints are saved back to it (the code itself is cloned from GitHub)
- [ ] **Upload the preprocessed SUN RGB-D dataset to Drive:** `MyDrive/datasets/sunrgbd_15/` (preprocess it locally first)
- [ ] **Budget enough time:** expect ~3-4 hours for 30 epochs

---

## 🎯 What This Notebook Does:

1. ✅ Verify GPU is available
2. ✅ Mount Google Drive
3. ✅ Clone your repository to local disk (fast I/O)
4. ✅ Copy SUN RGB-D dataset to local disk (10-20x faster than Drive)
5. ✅ Install dependencies
6. ✅ Train MCResNet on 15 scene categories
7. ✅ Save checkpoints to Drive (persistent storage)
8. ✅ Plot training curves and print a final summary

---

**Dataset:** SUN RGB-D, 15 scene categories (10,059 samples: 8,041 train / 2,018 val; ~8.5x class imbalance)

**Let's get started!** 🚀

## 1. Environment Setup & GPU Verification

In [None]:
# Check GPU availability and specs
import torch

banner = "=" * 60
print(banner)
print("GPU VERIFICATION")
print(banner)

# Report the PyTorch build and whether it can see a CUDA device.
print(f"PyTorch version: {torch.__version__}")
has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")

# Recognized GPU families, checked in order; the first substring match wins.
KNOWN_GPUS = (
    ('A100', "\n✅ A100 GPU detected - PERFECT for training!"),
    ('V100', "\n✅ V100 GPU detected - Good for training!"),
    ('T4', "\n✅ T4 GPU detected - Will work fine!"),
)

if not has_cuda:
    print("\n⚠️  NO GPU DETECTED - Training will be slow!")
    print("Please enable GPU: Runtime → Change runtime type → Hardware accelerator: GPU")
else:
    gpu_name = torch.cuda.get_device_name(0)
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU Device: {gpu_name}")
    print(f"GPU Memory: {total_gb:.2f} GB")

    # Fall back to a generic line for any GPU not in the table above.
    verdict = next(
        (message for tag, message in KNOWN_GPUS if tag in gpu_name),
        f"\n✅ GPU: {gpu_name}",
    )
    print(verdict)

print("\n" + banner)

In [None]:
# Detailed GPU info: full report from the NVIDIA `nvidia-smi` command-line tool.
# (`!` is IPython/Colab shell-escape syntax, not plain Python.)
!nvidia-smi

## 2. Mount Google Drive

In [None]:
from google.colab import drive
from pathlib import Path

# Attach Google Drive under /content/drive.
drive.mount('/content/drive')

print("\n✅ Google Drive mounted successfully!")
print("\nChecking for dataset...")

# Where the preprocessed dataset is expected to live on Drive.
dataset_path = Path('/content/drive/MyDrive/datasets/sunrgbd_15')

if not dataset_path.exists():
    # Explain where to upload the data and what layout is expected.
    missing_help = (
        f"❌ Dataset NOT found at: {dataset_path}",
        "\nPlease upload the preprocessed dataset to:",
        "   Google Drive → My Drive → datasets → sunrgbd_15/",
        "\nExpected structure:",
        "   sunrgbd_15/",
        "     train/rgb/       (8,041 images)",
        "     train/depth/     (8,041 images)",
        "     train/labels.txt",
        "     val/rgb/         (2,018 images)",
        "     val/depth/       (2,018 images)",
        "     val/labels.txt",
        "     class_names.txt",
        "     dataset_info.txt",
    )
    for help_line in missing_help:
        print(help_line)
else:
    print("✅ Dataset found on Drive!")
    print(f"   Path: {dataset_path}")

## 3. Clone Repository to Local Disk (Fast I/O)

In [None]:
import os
import shutil
import subprocess
from pathlib import Path

# Configuration
PROJECT_NAME = "Multi-Stream-Neural-Networks"
GITHUB_REPO = "https://github.com/clingergab/Multi-Stream-Neural-Networks.git"  # UPDATE THIS
LOCAL_REPO_PATH = f"/content/{PROJECT_NAME}"


def _run_git(*args):
    """Run ``git <args>``, echo its output, and raise RuntimeError on failure.

    A notebook shell escape (``!git ...``) ignores the exit status, so a failed
    pull/clone would still be followed by a success message. Running git via
    subprocess lets the cell stop as soon as git reports an error.
    """
    result = subprocess.run(["git", *args], capture_output=True, text=True)
    if result.stdout:
        print(result.stdout, end="")
    if result.stderr:
        # git writes progress/info (e.g. "Cloning into ...") to stderr too.
        print(result.stderr, end="")
    if result.returncode != 0:
        raise RuntimeError(
            f"git {' '.join(args)} failed with exit code {result.returncode}"
        )


print("=" * 60)
print("REPOSITORY SETUP")
print("=" * 60)

os.chdir('/content')

repo_path = Path(LOCAL_REPO_PATH)

# A `.git` directory inside the path implies the path itself exists.
if (repo_path / ".git").exists():
    print(f"\n📁 Repo already exists: {LOCAL_REPO_PATH}")
    print("🔄 Pulling latest changes...")
    os.chdir(LOCAL_REPO_PATH)
    _run_git("pull")
    print("✅ Repo updated")
else:
    # Remove an old incomplete copy so `git clone` gets an empty target.
    if repo_path.exists():
        shutil.rmtree(repo_path)

    print("\n🔄 Cloning from GitHub...")
    _run_git("clone", GITHUB_REPO, LOCAL_REPO_PATH)
    os.chdir(LOCAL_REPO_PATH)
    print("✅ Repo cloned successfully")

print(f"\n✅ Working directory: {os.getcwd()}")

## 4. Copy Dataset to Local Disk (CRITICAL for Speed!)

**Performance:** Local disk I/O is ~10-20x faster than Drive!

In [None]:
from pathlib import Path
import shutil

# Paths
DRIVE_DATASET_PATH = "/content/drive/MyDrive/datasets/sunrgbd_15"
LOCAL_DATASET_PATH = "/content/data/sunrgbd_15"  # Local disk (FAST)

# Expected number of RGB images per split in the preprocessed dataset.
EXPECTED_TRAIN_SAMPLES = 8041
EXPECTED_VAL_SAMPLES = 2018


def _rgb_counts_match(root):
    """Print the train/val RGB image counts under *root*; return True if both match.

    Only ``{split}/rgb/*.png`` files are counted; depth images and label files
    are not checked.
    """
    root = Path(root)
    train_rgb = len(list(root.glob("train/rgb/*.png")))
    val_rgb = len(list(root.glob("val/rgb/*.png")))
    print(f"   Train samples: {train_rgb}")
    print(f"   Val samples: {val_rgb}")
    return train_rgb == EXPECTED_TRAIN_SAMPLES and val_rgb == EXPECTED_VAL_SAMPLES


print("=" * 80)
print("SUN RGB-D DATASET SETUP")
print("=" * 80)

local_path = Path(LOCAL_DATASET_PATH)

# Reuse an existing local copy only if it is complete; otherwise discard it.
if local_path.exists():
    print(f"\n✅ Dataset already on local disk: {LOCAL_DATASET_PATH}")
    if _rgb_counts_match(local_path):
        print("   ✅ Dataset complete!")
    else:
        print("   ⚠ Dataset incomplete, will re-copy from Drive")
        shutil.rmtree(local_path)

# Copy from Drive to local disk
if not local_path.exists():
    if not Path(DRIVE_DATASET_PATH).exists():
        raise FileNotFoundError(f"Dataset not found on Drive: {DRIVE_DATASET_PATH}")

    print(f"\n📁 Found dataset on Drive: {DRIVE_DATASET_PATH}")
    print("📥 Copying to local disk for 10-20x faster training...")
    print("   This takes ~2-3 minutes but saves 60+ minutes during training!")
    print("   Dataset size: ~4.3 GB")

    local_path.parent.mkdir(parents=True, exist_ok=True)

    # Copy straight to LOCAL_DATASET_PATH rather than a hard-coded parent
    # directory, so the destination always matches what the loaders read.
    # copytree raises if any file fails to copy; an interrupted copy leaves a
    # partial directory that the completeness check above catches next run.
    shutil.copytree(DRIVE_DATASET_PATH, local_path)

    print("\n✅ Dataset copied to local disk")

    if _rgb_counts_match(local_path):
        print("   ✅ All samples verified!")
    else:
        # Previously a mismatch here was silently ignored.
        print(
            f"   ⚠ Sample counts differ from expected "
            f"({EXPECTED_TRAIN_SAMPLES} train / {EXPECTED_VAL_SAMPLES} val) - "
            f"check the dataset on Drive"
        )

print("\n" + "=" * 80)
print(f"✅ Dataset ready at: {LOCAL_DATASET_PATH}")
print("=" * 80)

## 5. Install Dependencies

In [None]:
# Install required packages
# (`!pip` is IPython/Colab shell-escape syntax; `-q` quiets pip's output.)
# tqdm is used by the training-loop cell and matplotlib by the plotting cells.
# NOTE(review): seaborn is not imported by any cell in this notebook - confirm
# it is still needed.
# NOTE(review): the success message below prints regardless of pip's exit status.
print("Installing dependencies...")

!pip install -q tqdm matplotlib seaborn

print("✅ All dependencies installed!")

## 6. Setup Python Path & Import Modules

In [None]:
import sys
import os

# Make the cloned repository's top-level `src` package importable.
project_root = '/content/Multi-Stream-Neural-Networks'
if project_root not in sys.path:
    # Prepend so the repo takes precedence over anything else on the path.
    sys.path[:0] = [project_root]

print("Importing modules...")
from src.models.multi_channel.mc_resnet import mc_resnet18, mc_resnet50
from src.data_utils.sunrgbd_dataset import get_sunrgbd_dataloaders

print("✅ Modules imported successfully!")

## 7. Load SUN RGB-D Dataset

In [None]:
rule = "=" * 60
print(rule)
print("LOADING SUN RGB-D DATASET")
print(rule)

# Dataset configuration
DATASET_CONFIG = {
    'data_root': '/content/data/sunrgbd_15',
    'batch_size': 64,  # Adjust based on GPU memory
    'num_workers': 2,
    'target_size': (224, 224),
    'num_classes': 15  # SUN RGB-D 15 merged categories
}

print("\nConfiguration:")
for option, setting in DATASET_CONFIG.items():
    print(f"  {option}: {setting}")

# Build the loaders; `num_classes` is informational and not passed here.
loader_kwargs = {
    name: DATASET_CONFIG[name]
    for name in ('data_root', 'batch_size', 'num_workers', 'target_size')
}
train_loader, val_loader = get_sunrgbd_dataloaders(**loader_kwargs)

print("\n✅ Dataset loaded successfully!")
print("\nDataset Statistics:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")
print(f"  Train samples: {len(train_loader.dataset)}")
print(f"  Val samples: {len(val_loader.dataset)}")

# Pull a single batch to sanity-check tensor shapes and the label range.
first_batch = next(iter(train_loader))
rgb_batch, depth_batch, label_batch = first_batch

print("\nBatch shapes:")
print(f"  RGB: {rgb_batch.shape}")
print(f"  Depth: {depth_batch.shape}")
print(f"  Labels: {label_batch.shape}")
print(f"  Label range: [{label_batch.min()}, {label_batch.max()}]")

print("\n" + rule)

## 8. Visualize Sample Data

In [None]:
import matplotlib.pyplot as plt
import torch

# Get class names
CLASS_NAMES = train_loader.dataset.CLASS_NAMES

# Per-channel statistics used to undo RGB normalization for display; built
# once instead of on every loop iteration.
# NOTE(review): these are the standard ImageNet mean/std - this assumes the
# SUN RGB-D loader normalizes RGB with the same values; confirm in
# src/data_utils/sunrgbd_dataset.py.
RGB_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
RGB_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# Show up to 4 samples; a fixed range(4) would raise IndexError on a smaller batch.
n_show = min(4, len(rgb_batch))
# squeeze=False keeps `axes` 2-D (rows x cols) even when n_show == 1.
fig, axes = plt.subplots(2, n_show, figsize=(3.5 * n_show, 7), squeeze=False)

for i in range(n_show):
    rgb = rgb_batch[i].cpu()
    depth = depth_batch[i].cpu()
    label = label_batch[i].item()
    class_name = CLASS_NAMES[label]

    # Denormalize RGB and clip to the displayable [0, 1] range
    rgb_vis = torch.clamp(rgb * RGB_STD + RGB_MEAN, 0, 1)

    # Plot RGB (CHW -> HWC for imshow)
    axes[0, i].imshow(rgb_vis.permute(1, 2, 0))
    axes[0, i].set_title(f"RGB - {class_name}", fontsize=10)
    axes[0, i].axis('off')

    # Plot Depth (presumably single-channel; squeeze drops size-1 dims for imshow)
    axes[1, i].imshow(depth.squeeze(), cmap='viridis')
    axes[1, i].set_title(f"Depth - {class_name}", fontsize=10)
    axes[1, i].axis('off')

plt.suptitle('SUN RGB-D Sample Data (RGB + Depth)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("✅ Sample visualization complete!")

## 9. Create MCResNet Model

In [None]:
print("=" * 60)
print("MODEL CREATION")
print("=" * 60)

# Model configuration
MODEL_CONFIG = {
    'architecture': 'resnet18',  # or 'resnet50' for better accuracy
    'num_classes': 15,
    'pretrained': False,  # Set True to use ImageNet pretrained weights
}

print("\nConfiguration:")
for key, value in MODEL_CONFIG.items():
    print(f"  {key}: {value}")

# Create model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Supported architecture names -> constructors imported in the setup cell.
MODEL_BUILDERS = {
    'resnet18': mc_resnet18,
    'resnet50': mc_resnet50,
}

architecture = MODEL_CONFIG['architecture']
if architecture not in MODEL_BUILDERS:
    # Without this guard an unknown name leaves `model` undefined, or - when
    # the cell is re-run - silently reuses the model from a previous run.
    raise ValueError(
        f"Unknown architecture {architecture!r}; "
        f"expected one of {sorted(MODEL_BUILDERS)}"
    )

model = MODEL_BUILDERS[architecture](
    num_classes=MODEL_CONFIG['num_classes'],
    pretrained=MODEL_CONFIG['pretrained'],
).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("\n✅ Model created successfully!")
print("\nModel Statistics:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Device: {device}")

print("\n" + "=" * 60)

## 10. Setup Training Configuration

In [None]:
import os
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from datetime import datetime

print("=" * 60)
print("TRAINING CONFIGURATION")
print("=" * 60)

# Training configuration
TRAIN_CONFIG = {
    'epochs': 30,
    'base_lr': 0.001,
    'weight_decay': 1e-4,
    'scheduler': 'cosine',

    # Stream-specific optimization (optional)
    'use_stream_specific': False,  # Set True if one stream is overfitting
    'stream1_lr_mult': 1.0,   # RGB stream LR multiplier
    'stream2_lr_mult': 1.5,   # Depth stream LR multiplier (boost if needed)
    'stream1_wd_mult': 1.0,   # RGB weight decay multiplier
    'stream2_wd_mult': 0.5,   # Depth weight decay multiplier
}

print("\nConfiguration:")
for key, value in TRAIN_CONFIG.items():
    print(f"  {key}: {value}")

# Create optimizer
if TRAIN_CONFIG['use_stream_specific']:
    print("\n✓ Using stream-specific optimization")

    # Partition parameters by name: names containing 'stream1' / 'stream2'
    # get that stream's multipliers; everything else uses the base settings.
    stream1_params, stream2_params, other_params = [], [], []
    for name, param in model.named_parameters():
        if 'stream1' in name:
            stream1_params.append(param)
        elif 'stream2' in name:
            stream2_params.append(param)
        else:
            other_params.append(param)

    # The split relies purely on parameter names; if nothing matches, the
    # per-stream multipliers silently have no effect, so make that visible.
    if not stream1_params or not stream2_params:
        print("  ⚠ No parameters matched 'stream1' and/or 'stream2' - "
              "per-stream multipliers will have no effect")

    stream1_lr = TRAIN_CONFIG['base_lr'] * TRAIN_CONFIG['stream1_lr_mult']
    stream2_lr = TRAIN_CONFIG['base_lr'] * TRAIN_CONFIG['stream2_lr_mult']

    optimizer = optim.Adam([
        {
            'params': stream1_params,
            'lr': stream1_lr,
            'weight_decay': TRAIN_CONFIG['weight_decay'] * TRAIN_CONFIG['stream1_wd_mult'],
        },
        {
            'params': stream2_params,
            'lr': stream2_lr,
            'weight_decay': TRAIN_CONFIG['weight_decay'] * TRAIN_CONFIG['stream2_wd_mult'],
        },
        {
            'params': other_params,
            'lr': TRAIN_CONFIG['base_lr'],
            'weight_decay': TRAIN_CONFIG['weight_decay'],
        },
    ])

    print(f"  Stream1 (RGB) LR: {stream1_lr:.6f}")
    print(f"  Stream2 (Depth) LR: {stream2_lr:.6f}")
    print(f"  Param tensors - stream1: {len(stream1_params)}, "
          f"stream2: {len(stream2_params)}, other: {len(other_params)}")
else:
    print("\n✓ Using standard optimization")
    optimizer = optim.Adam(
        model.parameters(),
        lr=TRAIN_CONFIG['base_lr'],
        weight_decay=TRAIN_CONFIG['weight_decay']
    )

# Loss and scheduler. T_max is the epoch count, so the scheduler is meant to
# be stepped once per epoch.
criterion = nn.CrossEntropyLoss()
scheduler = CosineAnnealingLR(optimizer, T_max=TRAIN_CONFIG['epochs'])

# Checkpoint directory
drive_root = "/content/drive/MyDrive"
if not os.path.isdir(drive_root):
    # makedirs below would happily create this path on the runtime's local
    # disk, so checkpoints would not be persisted to Drive.
    print(f"⚠ {drive_root} not found - is Google Drive mounted? "
          f"Checkpoints will be written to local disk instead.")

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
checkpoint_dir = f"/content/drive/MyDrive/mcresnet_checkpoints/sunrgbd_{timestamp}"
# Create from Python so a failure raises now, rather than a shell `mkdir`
# error being printed and the first torch.save failing after an epoch.
os.makedirs(checkpoint_dir, exist_ok=True)

print("\n✅ Training setup complete!")
print(f"\nCheckpoint directory: {checkpoint_dir}")
print("\n" + "=" * 60)

## 11. Training Loop

In [None]:
from tqdm import tqdm
import json


def _run_epoch(model, loader, criterion, device, desc, optimizer=None):
    """Run one pass over *loader* and return ``(mean per-sample loss, accuracy)``.

    If *optimizer* is given, the model is put in train mode and a gradient
    step is taken after every batch. Otherwise the model is put in eval mode
    and gradients are disabled (validation).

    Each batch is unpacked as ``(rgb, depth, labels)`` and the model is called
    as ``model(rgb, depth)``.

    Raises:
        ValueError: if *loader* yields no samples.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()

    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.set_grad_enabled(training):
        for rgb, depth, labels in tqdm(loader, desc=desc):
            rgb, depth, labels = rgb.to(device), depth.to(device), labels.to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(rgb, depth)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            # Assumes `criterion` returns a per-batch mean (true for the
            # default nn.CrossEntropyLoss()); weighting by batch size stops a
            # short final batch from skewing the epoch average.
            loss_sum += loss.item() * batch_size
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError(f"{desc}: loader yielded no samples")
    return loss_sum / total, correct / total


def _save_history(history, path):
    """Write the metrics history dict to *path* as indented JSON."""
    with open(path, 'w') as f:
        json.dump(history, f, indent=2)


print("=" * 60)
print("STARTING TRAINING")
print("=" * 60)

best_val_acc = 0.0
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': [],
}
history_path = f"{checkpoint_dir}/training_history.json"

for epoch in range(TRAIN_CONFIG['epochs']):
    print(f"\n{'='*60}")
    print(f"Epoch {epoch + 1}/{TRAIN_CONFIG['epochs']}")
    print(f"{'='*60}")

    train_loss, train_acc = _run_epoch(
        model, train_loader, criterion, device, "Training", optimizer=optimizer
    )
    val_loss, val_acc = _run_epoch(
        model, val_loader, criterion, device, "Validation"
    )

    # One scheduler step per epoch (T_max was set to the epoch count).
    scheduler.step()

    # Record history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    # Print results
    print(f"\nEpoch {epoch + 1} Results:")
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%")
    print(f"  Val Loss: {val_loss:.4f}   | Val Acc: {val_acc*100:.2f}%")

    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        print(f"\n💾 New best validation accuracy: {val_acc*100:.2f}% - Saving...")
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Scheduler state and training config are stored too so training
            # can be resumed from this checkpoint; readers may ignore them.
            'scheduler_state_dict': scheduler.state_dict(),
            'val_acc': val_acc,
            'val_loss': val_loss,
            'config': MODEL_CONFIG,
            'train_config': TRAIN_CONFIG,
        }, f"{checkpoint_dir}/best_model.pth")

    # Persist history every epoch so a runtime disconnect doesn't lose it.
    _save_history(history, history_path)

print("\n" + "=" * 60)
print("🎉 TRAINING COMPLETE!")
print(f"Best validation accuracy: {best_val_acc*100:.2f}%")
print("=" * 60)

# Final write; also produces the file if the loop ran zero epochs.
_save_history(history, history_path)

print(f"\n✅ Results saved to: {checkpoint_dir}")

## 12. Plot Training Curves

In [None]:
import matplotlib.pyplot as plt

# The training log numbers epochs from 1, so plot against 1..N instead of
# matplotlib's default 0-based sample index.
epoch_axis = range(1, len(history['train_loss']) + 1)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
loss_ax, acc_ax = axes

# Loss curve
loss_ax.plot(epoch_axis, history['train_loss'], label='Train Loss', linewidth=2)
loss_ax.plot(epoch_axis, history['val_loss'], label='Val Loss', linewidth=2)
loss_ax.set_xlabel('Epoch', fontsize=12)
loss_ax.set_ylabel('Loss', fontsize=12)
loss_ax.set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
loss_ax.legend(fontsize=11)
loss_ax.grid(True, alpha=0.3)

# Accuracy curve (history stores fractions; plotted as percentages)
acc_ax.plot(epoch_axis, [acc * 100 for acc in history['train_acc']], label='Train Acc', linewidth=2)
acc_ax.plot(epoch_axis, [acc * 100 for acc in history['val_acc']], label='Val Acc', linewidth=2)
acc_ax.set_xlabel('Epoch', fontsize=12)
acc_ax.set_ylabel('Accuracy (%)', fontsize=12)
acc_ax.set_title('Training and Validation Accuracy', fontsize=14, fontweight='bold')
acc_ax.legend(fontsize=11)
acc_ax.grid(True, alpha=0.3)

plt.tight_layout()
# Save before show(): the figure may be cleared once it has been displayed.
plt.savefig(f"{checkpoint_dir}/training_curves.png", dpi=150, bbox_inches='tight')
plt.show()

print(f"✅ Training curves saved to: {checkpoint_dir}/training_curves.png")

## 13. Final Summary

In [None]:
divider = "=" * 60
print(divider)
print("FINAL SUMMARY")
print(divider)

print("\n✅ Training Complete!")

# --- Dataset ---
print("\nDataset: SUN RGB-D (15 categories)")
print(f"  Train samples: {len(train_loader.dataset)}")
print(f"  Val samples: {len(val_loader.dataset)}")

# --- Model ---
arch_label = MODEL_CONFIG['architecture'].upper()
print(f"\nModel: MCResNet-{arch_label}")
print(f"  Parameters: {total_params:,}")

# --- Training outcome (accuracies are stored as fractions) ---
print("\nTraining Results:")
print(f"  Epochs trained: {len(history['train_loss'])}")
print(f"  Best validation accuracy: {100 * best_val_acc:.2f}%")
print(f"  Final train accuracy: {100 * history['train_acc'][-1]:.2f}%")
print(f"  Final val accuracy: {100 * history['val_acc'][-1]:.2f}%")

# --- Output location ---
print("\nCheckpoints saved to:")
print(f"  {checkpoint_dir}")

print("\n" + divider)
print("🎉 All done! Check Google Drive for saved models.")
print(divider)