# LINet Training on SUN RGB-D - Google Colab

**Complete end-to-end training pipeline for Linear Integration ResNet (LINet) on Google Colab with A100 GPU**

---

## 📋 Checklist Before Running:

- [ ] **Enable A100 GPU:** Runtime → Change runtime type → Hardware accelerator: GPU → GPU type: A100
- [ ] **Mount Google Drive:** Your code and dataset will be stored on Drive
- [ ] **Upload dataset to Drive:** `MyDrive/datasets/sunrgbd_15.tar.gz` (compressed archive of the preprocessed 15-category dataset; see Section 5)
- [ ] **Expected Runtime:** ~2-3 hours for training

---

## 🎯 What This Notebook Does:

1. ✅ Verify A100 GPU is available
2. ✅ Mount Google Drive
3. ✅ Clone your repository to local disk (fast I/O)
4. ✅ Copy SUN RGB-D dataset to local disk (10-20x faster than Drive)
5. ✅ Install dependencies
6. ✅ Train LINet (3-stream Linear Integration ResNet) with all optimizations
7. ✅ Save checkpoints to Drive (persistent storage)
8. ✅ Generate training curves and analysis

---

## 🧠 About LINet:

**LINet** (Linear Integration Network) is a 3-stream neural network architecture where:
- **Stream1** processes RGB images
- **Stream2** processes Depth maps
- **Integrated Stream** combines both streams using learned linear integration weights

Unlike traditional fusion methods, LINet performs integration **at the neuron level** through 5 weight matrices per convolution:
- `stream1_weight` (full kernel for RGB)
- `stream2_weight` (full kernel for Depth)
- `integrated_weight` (1×1 channel-wise for integrated features)
- `integration_from_stream1` (1×1 integration from RGB)
- `integration_from_stream2` (1×1 integration from Depth)

This lets the network learn how strongly to integrate each stream at every layer, rather than fusing at a single fixed point!
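To make the five weight matrices concrete, here is a minimal PyTorch sketch of one such layer. It is **not** the repo's `li_net` implementation. The class name `LinearIntegrationConv2d` is hypothetical, and the wiring is an assumption: the integrated output is a 1×1 mix of the two new stream outputs plus a 1×1 projection of the incoming integrated features. The sketch covers an inner layer where both streams already have `in_channels` channels. The stem, where RGB has 3 channels and depth has 1, would need separate input widths.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearIntegrationConv2d(nn.Module):
    """Hypothetical sketch of one LINet-style conv; NOT the repo's implementation."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.stride, self.padding = stride, padding
        k = kernel_size
        # Full kxk kernels, one per modality
        self.stream1_weight = nn.Parameter(torch.randn(out_channels, in_channels, k, k) * 0.01)
        self.stream2_weight = nn.Parameter(torch.randn(out_channels, in_channels, k, k) * 0.01)
        # 1x1 kernels: incoming integrated features, and stream -> integrated paths
        self.integrated_weight = nn.Parameter(torch.randn(out_channels, in_channels, 1, 1) * 0.01)
        self.integration_from_stream1 = nn.Parameter(torch.randn(out_channels, out_channels, 1, 1) * 0.01)
        self.integration_from_stream2 = nn.Parameter(torch.randn(out_channels, out_channels, 1, 1) * 0.01)

    def forward(self, x1, x2, x_int=None):
        s1 = F.conv2d(x1, self.stream1_weight, stride=self.stride, padding=self.padding)
        s2 = F.conv2d(x2, self.stream2_weight, stride=self.stride, padding=self.padding)
        # Integrated output: learned linear mix of the two new stream outputs...
        integrated = F.conv2d(s1, self.integration_from_stream1) + F.conv2d(s2, self.integration_from_stream2)
        # ...plus a strided 1x1 projection of incoming integrated features (absent at the first layer).
        # With the default kernel_size=3, padding=1 this matches the streams' output spatial size.
        if x_int is not None:
            integrated = integrated + F.conv2d(x_int, self.integrated_weight, stride=self.stride)
        return s1, s2, integrated
```

Because every integration path is a plain 1×1 convolution, the whole layer remains linear in its inputs before any activation. Depth-to-RGB balance is therefore learned per channel, per layer.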

---

**Let's get started!** 🚀

## 1. Environment Setup & GPU Verification

In [None]:
# Check GPU availability and specs
import torch
import subprocess

print("=" * 60)
print("GPU VERIFICATION")
print("=" * 60)

# Check PyTorch and CUDA
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
    
    # Check if it's A100
    gpu_name = torch.cuda.get_device_name(0)
    if 'A100' in gpu_name:
        print("\n✅ A100 GPU detected - PERFECT for training!")
    elif 'V100' in gpu_name:
        print("\n✅ V100 GPU detected - Good for training (slower than A100)")
    elif 'T4' in gpu_name:
        print("\n⚠️  T4 GPU detected - Will be slower, consider upgrading to A100")
    else:
        print(f"\n⚠️  GPU: {gpu_name} - Consider using A100 for best performance")
else:
    print("\n❌ NO GPU DETECTED!")
    print("Please enable GPU: Runtime → Change runtime type → Hardware accelerator: GPU")
    raise RuntimeError("GPU is required for training")

print("\n" + "=" * 60)

In [None]:
# Detailed GPU info
!nvidia-smi

## 2. Mount Google Drive

In [None]:
from google.colab import drive
import os
from pathlib import Path

# Mount Google Drive
drive.mount('/content/drive')

print("\n✅ Google Drive mounted successfully!")
print(f"\nDrive contents:")
!ls -la /content/drive/MyDrive/ | head -20

## 3. Clone Repository to Local Disk (Fast I/O)

**Important:** We clone to `/content/` (local SSD) instead of Drive for 10-20x faster I/O

**Default:** Clone from GitHub (recommended - always gets latest code)

In [None]:
import os
from pathlib import Path

# Configuration
PROJECT_NAME = "Multi-Stream-Neural-Networks"
GITHUB_REPO = "https://github.com/clingergab/Multi-Stream-Neural-Networks.git"  # UPDATE THIS
LOCAL_REPO_PATH = f"/content/{PROJECT_NAME}"  # Local copy for fast I/O

print("=" * 60)
print("REPOSITORY SETUP")
print("=" * 60)

# Ensure we're in a valid directory
os.chdir('/content')
print(f"Starting in: {os.getcwd()}")

# Check if repo already exists (same session, rerunning cell)
if Path(LOCAL_REPO_PATH).exists() and Path(f"{LOCAL_REPO_PATH}/.git").exists():
    print(f"\n📁 Repo already exists: {LOCAL_REPO_PATH}")
    print(f"🔄 Pulling latest changes...")
    
    os.chdir(LOCAL_REPO_PATH)
    !git pull
    print("✅ Repo updated")

# Clone from GitHub (first run)
else:
    # Remove old incomplete copy if exists
    if Path(LOCAL_REPO_PATH).exists():
        print(f"\n🗑️  Removing incomplete repo copy...")
        !rm -rf {LOCAL_REPO_PATH}
    
    print(f"\n🔄 Cloning from GitHub...")
    print(f"   Repo: {GITHUB_REPO}")
    print(f"   Destination: {LOCAL_REPO_PATH}")
    
    !git clone {GITHUB_REPO} {LOCAL_REPO_PATH}
    
    # Verify clone succeeded
    if not Path(LOCAL_REPO_PATH).exists():
        raise RuntimeError(f"Failed to clone repository to {LOCAL_REPO_PATH}")
    
    print("✅ Repo cloned successfully")
    os.chdir(LOCAL_REPO_PATH)

# Verify repo structure
print(f"\n📂 Repository structure:")
!ls -la {LOCAL_REPO_PATH}

print(f"\n✅ Working directory: {os.getcwd()}")

## 4. Install Dependencies

In [None]:
# Install required packages
print("Installing dependencies...")

!pip install -q h5py tqdm matplotlib seaborn

# Verify installations
import h5py
import tqdm
import matplotlib
import seaborn

print("✅ All dependencies installed!")
print(f"   h5py: {h5py.__version__}")
print(f"   matplotlib: {matplotlib.__version__}")

## 5. Copy SUN RGB-D Dataset to Local Disk

**Performance Note:** Local disk I/O is ~10-20x faster than Drive!

**Dataset:** SUN RGB-D 15-category preprocessed dataset (~3.5 GB)

In [None]:
from pathlib import Path
import os

# Paths
DRIVE_DATASET_TAR = "/content/drive/MyDrive/datasets/sunrgbd_15.tar.gz"  # Compressed file
LOCAL_DATASET_PATH = "/content/data/sunrgbd_15"  # Extracted location

print("=" * 60)
print("SUN RGB-D 15-CATEGORY DATASET SETUP")
print("=" * 60)

# Check if already on local disk
if Path(LOCAL_DATASET_PATH).exists():
    print(f"✅ Dataset already on local disk: {LOCAL_DATASET_PATH}")
    
    # Verify structure
    train_rgb_count = len(list(Path(f"{LOCAL_DATASET_PATH}/train/rgb").glob("*.png")))
    val_rgb_count = len(list(Path(f"{LOCAL_DATASET_PATH}/val/rgb").glob("*.png")))
    print(f"   Train samples: {train_rgb_count}")
    print(f"   Val samples: {val_rgb_count}")

# Copy and extract from Drive
elif Path(DRIVE_DATASET_TAR).exists():
    print(f"📁 Found compressed dataset on Drive: {DRIVE_DATASET_TAR}")
    print(f"📥 Copying compressed archive to local disk...")
    print(f"   ⏱️  This takes ~3-5 minutes (much faster than 20k individual files!)")
    
    # Create parent directory
    !mkdir -p /content/data
    
    # Copy compressed file with progress
    print(f"\nCopying compressed archive...")
    !rsync -ah --info=progress2 {DRIVE_DATASET_TAR} /content/data/sunrgbd_15.tar.gz
    
    # Extract to local disk (suppress macOS metadata warnings)
    print(f"\n📦 Extracting dataset to local disk...")
    !tar -xzf /content/data/sunrgbd_15.tar.gz -C /content/data/ 2>&1 | grep -v "Ignoring unknown extended header"
    
    # Remove tar file to save space
    !rm /content/data/sunrgbd_15.tar.gz
    
    print(f"\n✅ Dataset extracted to local disk")
    
    # Verify extraction
    train_rgb_count = len(list(Path(f"{LOCAL_DATASET_PATH}/train/rgb").glob("*.png")))
    val_rgb_count = len(list(Path(f"{LOCAL_DATASET_PATH}/val/rgb").glob("*.png")))
    print(f"   Train samples: {train_rgb_count}")
    print(f"   Val samples: {val_rgb_count}")

else:
    print(f"❌ Compressed dataset not found on Drive!")
    print(f"   Expected location: {DRIVE_DATASET_TAR}")
    print(f"\n📋 To fix this:")
    print(f"   1. Run: COPYFILE_DISABLE=1 tar -czf sunrgbd_15.tar.gz sunrgbd_15/")
    print(f"   2. Upload sunrgbd_15.tar.gz to Google Drive")
    print(f"   3. Place it at: {DRIVE_DATASET_TAR}")
    raise FileNotFoundError(f"Compressed dataset not found at {DRIVE_DATASET_TAR}")

print("\n" + "=" * 60)
print(f"Dataset ready at: {LOCAL_DATASET_PATH}")
print("=" * 60)
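The cell above only counts RGB files. A slightly stronger sanity check confirms that every RGB image has a matching depth map and a label. The sketch below assumes that matching pairs share a filename stem in `rgb/` and `depth/`, and that `labels.txt` has one label per non-empty line. Both are assumptions about the preprocessing, and `check_rgb_depth_pairs` is a hypothetical helper, not part of the repo.

```python
from pathlib import Path


def check_rgb_depth_pairs(split_dir):
    """Hypothetical sanity check: compare file stems in <split>/rgb and <split>/depth."""
    split_dir = Path(split_dir)
    rgb = {p.stem for p in (split_dir / "rgb").glob("*.png")}
    depth = {p.stem for p in (split_dir / "depth").glob("*.png")}

    labels_file = split_dir / "labels.txt"
    num_labels = None
    if labels_file.exists():
        # Assumes one label per non-empty line
        num_labels = sum(1 for line in labels_file.read_text().splitlines() if line.strip())

    return {
        "num_rgb": len(rgb),
        "num_depth": len(depth),
        "num_labels": num_labels,
        "missing_depth": sorted(rgb - depth),  # RGB images with no depth map
        "missing_rgb": sorted(depth - rgb),    # depth maps with no RGB image
    }


# Usage in this notebook:
# for split in ("train", "val"):
#     print(split, check_rgb_depth_pairs(Path(LOCAL_DATASET_PATH) / split))
```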

## 6. Setup Python Path & Import LINet

In [None]:
import sys
import os

# Drop cached project modules so re-running this cell picks up code changes (e.g. after git pull)
modules_to_reload = [k for k in sys.modules.keys() if k == 'src' or k.startswith('src.')]
for module in modules_to_reload:
    del sys.modules[module]
    
# Add project to Python path
project_root = '/content/Multi-Stream-Neural-Networks'
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Verify project structure
print("Project structure:")
!ls -la {project_root}/src/models/

# Import LINet and SUN RGB-D dataloader
print("\nImporting LINet and dataloaders...")
from src.models.linear_integration.li_net import li_resnet18, li_resnet50
from src.data_utils.sunrgbd_dataset import get_sunrgbd_dataloaders

print("✅ LINet and dataloaders imported successfully!")

## 7. Load SUN RGB-D Dataset

In [None]:
# Verify dataset structure
from pathlib import Path

print("=" * 60)
print("DATASET STRUCTURE VERIFICATION")
print("=" * 60)

dataset_root = Path(LOCAL_DATASET_PATH)

print("\nDirectory structure:")
print(f"  {dataset_root}/")
print(f"    train/")
print(f"      rgb/ - {len(list((dataset_root / 'train' / 'rgb').glob('*.png')))} images")
print(f"      depth/ - {len(list((dataset_root / 'train' / 'depth').glob('*.png')))} images")
print(f"      labels.txt")
print(f"    val/")
print(f"      rgb/ - {len(list((dataset_root / 'val' / 'rgb').glob('*.png')))} images")
print(f"      depth/ - {len(list((dataset_root / 'val' / 'depth').glob('*.png')))} images")
print(f"      labels.txt")
print(f"    class_names.txt")
print(f"    dataset_info.txt")

# Read class names
with open(dataset_root / 'class_names.txt', 'r') as f:
    class_names = [line.strip() for line in f]

print(f"\nClasses ({len(class_names)}):")
for i, name in enumerate(class_names):
    print(f"  {i}: {name}")

print("\n" + "=" * 60)

In [None]:
print("=" * 60)
print("LOADING SUN RGB-D 15-CATEGORY DATASET")
print("=" * 60)

# Dataset configuration
DATASET_CONFIG = {
    'data_root': LOCAL_DATASET_PATH,
    'batch_size': 128,  # Good balance for A100
    'num_workers': 4,
    'target_size': (416, 544),  
    'num_classes': 15   # SUN RGB-D merged to 15 categories (labels 0-14)
}

print(f"Configuration:")
for key, value in DATASET_CONFIG.items():
    print(f"  {key}: {value}")

print(f"\nLoading dataset from: {DATASET_CONFIG['data_root']}")

# Create dataloaders
train_loader, val_loader = get_sunrgbd_dataloaders(
    data_root=DATASET_CONFIG['data_root'],
    batch_size=DATASET_CONFIG['batch_size'],
    num_workers=DATASET_CONFIG['num_workers'],
    target_size=DATASET_CONFIG['target_size']
)

print(f"\n✅ Dataset loaded successfully!")
print(f"\nDataset Statistics:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")
print(f"  Train samples: {len(train_loader.dataset)}")
print(f"  Val samples: {len(val_loader.dataset)}")
print(f"  Batch size: {DATASET_CONFIG['batch_size']}")

# Test loading a batch
print(f"\nTesting batch loading...")
rgb_batch, depth_batch, label_batch = next(iter(train_loader))
print(f"  RGB shape: {rgb_batch.shape}")
print(f"  Depth shape: {depth_batch.shape}")
print(f"  Labels shape: {label_batch.shape}")
print(f"  Labels min: {label_batch.min().item()}, max: {label_batch.max().item()}")

print("\n" + "=" * 60)
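Both sides of `target_size` (416 and 544) are multiples of 32. Suppose LINet keeps the standard ResNet stem (7×7 stride-2 conv, then 3×3 stride-2 max-pool) and stride-2 downsampling in stages 2-4. The notebook does not confirm this. Under that assumption, the final feature map before global pooling is 13×17. The hypothetical helper below just applies the standard conv output-size formula at each downsampling step.

```python
def conv_out(n, kernel, stride, padding):
    # Standard conv/pool output size: floor((n + 2*padding - kernel) / stride) + 1
    return (n + 2 * padding - kernel) // stride + 1


def resnet_feature_size(hw):
    """Final feature-map size for a standard ResNet stem + 4 stages (assumed layout)."""
    out = []
    for n in hw:
        n = conv_out(n, 7, 2, 3)    # conv1: 7x7, stride 2, padding 3
        n = conv_out(n, 3, 2, 1)    # maxpool: 3x3, stride 2, padding 1
        for _ in range(3):          # layer2, layer3, layer4 each halve the resolution
            n = conv_out(n, 3, 2, 1)
        out.append(n)
    return tuple(out)


print(resnet_feature_size(DATASET_CONFIG['target_size']) if 'DATASET_CONFIG' in globals()
      else resnet_feature_size((416, 544)))
```

For comparison, the familiar 224×224 ImageNet input yields a 7×7 map under the same assumptions.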

## 8. Visualize Sample Data

Shows RGB images, depth maps, and scene labels for the first four samples of the batch loaded above.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Visualize some samples
fig, axes = plt.subplots(2, 4, figsize=(14, 7))

for i in range(4):
    rgb = rgb_batch[i].cpu()
    depth = depth_batch[i].cpu()
    label = label_batch[i].item()
    
    # Denormalize for visualization
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    rgb_vis = rgb * std + mean
    rgb_vis = torch.clamp(rgb_vis, 0, 1)
    
    # Plot RGB
    axes[0, i].imshow(rgb_vis.permute(1, 2, 0))
    axes[0, i].set_title(f"RGB - {class_names[label]}", fontsize=10)
    axes[0, i].axis('off')
    
    # Plot Depth
    axes[1, i].imshow(depth.squeeze(), cmap='viridis')
    axes[1, i].set_title(f"Depth - {class_names[label]}", fontsize=10)
    axes[1, i].axis('off')

plt.suptitle('SUN RGB-D Sample Data (RGB + Depth)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("✅ Sample visualization complete!")

## 9. Create & Compile LINet Model

In [None]:
print("=" * 60)
print("MODEL CREATION")
print("=" * 60)

# Model configuration
MODEL_CONFIG = {
    'architecture': 'resnet18',  # or 'resnet50' (larger backbone, slower to train)
    'num_classes': 15,  # SUN RGB-D has 15 merged categories (labels 0-14)
    'stream1_channels': 3,  # RGB
    'stream2_channels': 1,  # Depth
    'dropout_p': 0.5,  # Dropout for regularization
    'device': 'cuda',
    'use_amp': True  # Automatic Mixed Precision (often up to ~2x faster on A100)
}

print(f"Configuration:")
for key, value in MODEL_CONFIG.items():
    print(f"  {key}: {value}")

# Create model
print(f"\nCreating LINet-{MODEL_CONFIG['architecture'].upper()} (3-stream Linear Integration ResNet)...")

if MODEL_CONFIG['architecture'] == 'resnet18':
    model = li_resnet18(
        num_classes=MODEL_CONFIG['num_classes'],
        stream1_input_channels=MODEL_CONFIG['stream1_channels'],
        stream2_input_channels=MODEL_CONFIG['stream2_channels'],
        dropout_p=MODEL_CONFIG['dropout_p'],
        device=MODEL_CONFIG['device'],
        use_amp=MODEL_CONFIG['use_amp']
    )
elif MODEL_CONFIG['architecture'] == 'resnet50':
    model = li_resnet50(
        num_classes=MODEL_CONFIG['num_classes'],
        stream1_input_channels=MODEL_CONFIG['stream1_channels'],
        stream2_input_channels=MODEL_CONFIG['stream2_channels'],
        dropout_p=MODEL_CONFIG['dropout_p'],
        device=MODEL_CONFIG['device'],
        use_amp=MODEL_CONFIG['use_amp']
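    )
else:
    raise ValueError(f"Unknown architecture: {MODEL_CONFIG['architecture']!r} (expected 'resnet18' or 'resnet50')")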

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

# Count integration-specific parameters
integration_params = 0
for name, param in model.named_parameters():
    if 'integration' in name or 'integrated_weight' in name:
        integration_params += param.numel()

print(f"\n✅ Model created successfully!")
print(f"\nModel Statistics:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Integration parameters: {integration_params:,}")
print(f"  Model size: {total_params * 4 / 1024**2:.2f} MB (FP32)")
print(f"  Architecture: 3-stream Linear Integration (LINet)")
print(f"  Device: {MODEL_CONFIG['device']}")
print(f"  AMP enabled: {MODEL_CONFIG['use_amp']}")

print("\n" + "=" * 60)
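The "Integration parameters" figure above counts every parameter whose name contains `integration` or `integrated_weight`. The sketch below estimates where those parameters come from in a single layer. It assumes, without confirmation from the repo, that the two stream kernels are `out×in×k×k`, that `integrated_weight` is a 1×1 `out×in` kernel, and that each `integration_from_stream*` is a 1×1 `out×out` kernel. Biases are ignored. `li_conv_param_count` is a hypothetical helper.

```python
def li_conv_param_count(in_ch, out_ch, kernel_size=3):
    """Hypothetical per-layer parameter breakdown for one linear-integration conv (no biases)."""
    k2 = kernel_size * kernel_size
    return {
        "stream1_weight": out_ch * in_ch * k2,        # full kxk kernel on RGB features
        "stream2_weight": out_ch * in_ch * k2,        # full kxk kernel on depth features
        "integrated_weight": out_ch * in_ch,          # 1x1 on incoming integrated features
        "integration_from_stream1": out_ch * out_ch,  # 1x1 RGB -> integrated
        "integration_from_stream2": out_ch * out_ch,  # 1x1 depth -> integrated
    }


counts = li_conv_param_count(64, 64)
total = sum(counts.values())
integration = sum(v for name, v in counts.items() if "integrat" in name)
print(f"total={total:,}, integration={integration:,} ({integration / total:.1%})")
```

Under these assumptions, a 64→64 3×3 layer has 86,016 weights, of which 12,288 (about 14%) are integration weights. The 1×1 integration paths are cheap compared with the two full kernels.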

## 9b. Model Compilation (Keras-Style API with Warmup)

**Create optimizer and scheduler as objects, then pass to compile()**

**NEW:** Learning rate warmup support! The scheduler linearly increases the learning rate from `warmup_start_factor` × the target LR (10% here) up to the target LR over the first `warmup_epochs` epochs, which helps stabilize early training.
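The warmup-then-cosine shape can be written in closed form. The sketch below is one common formulation, not necessarily what `setup_scheduler` computes. It assumes the cosine phase spans only the epochs left after warmup, whereas the repo may pass `t_max` to the cosine scheduler directly. It also works per epoch, whereas the repo may step per batch. The same factor would apply to each parameter group's own base LR. `warmup_cosine_lr` is a hypothetical helper.

```python
import math


def warmup_cosine_lr(epoch, base_lr, total_epochs=80, warmup_epochs=5,
                     warmup_start_factor=0.1, eta_min=1e-6):
    """LR at a (possibly fractional) epoch: linear warmup, then cosine annealing."""
    if epoch < warmup_epochs:
        # Linear ramp from warmup_start_factor * base_lr (epoch 0) to base_lr (epoch warmup_epochs)
        factor = warmup_start_factor + (1.0 - warmup_start_factor) * epoch / warmup_epochs
        return base_lr * factor
    # Cosine annealing from base_lr down to eta_min over the remaining epochs
    t = epoch - warmup_epochs
    t_max = total_epochs - warmup_epochs
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max)) / 2


# Each parameter group follows the same shape relative to its own base LR
for name, lr in [("stream1", 3e-5), ("stream2", 1e-4), ("shared", 7e-5)]:
    print(name, [f"{warmup_cosine_lr(e, lr):.1e}" for e in (0, 5, 40, 80)])
```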

In [None]:
# Compile model with stream-specific optimization
print("=" * 60)
print("MODEL COMPILATION")
print("=" * 60)

# Import optimizer and scheduler utilities
from src.training.optimizers import create_stream_optimizer
from src.training.schedulers import setup_scheduler

# Stream-specific configuration for optimal RGB/Depth balance
STREAM_SPECIFIC_CONFIG = {
    # Stream-specific learning rates:
    'stream1_lr': 3e-5,               # RGB stream: lower LR (smaller updates)
    'stream2_lr': 1e-4,               # Depth stream: higher LR (needs more learning)
    'shared_lr': 7e-5,                # Shared params: base LR

    # Stream-specific weight decay:
    'stream1_weight_decay': 5e-4,     # RGB stream: higher WD (prevent overfitting)
    'stream2_weight_decay': 1e-4,     # Depth stream: lighter WD (less regularization)
    'shared_weight_decay': 2e-4,      # Shared params: base WD
}

# Scheduler configuration (with warmup support!)
SCHEDULER_CONFIG = {
    'scheduler_type': 'cosine',
    't_max': 80,  # Keep in sync with TRAIN_CONFIG['epochs'] (not updated automatically)
    'eta_min': 1e-6,
    'warmup_epochs': 5,  # Warmup: linearly increase LR for first 5 epochs
    'warmup_start_factor': 0.1  # Start at 10% of target LR during warmup
}

print(f"Stream-Specific Configuration:")
for key, value in STREAM_SPECIFIC_CONFIG.items():
    print(f"  {key}: {value}")

print(f"\nScheduler Configuration (with warmup):")
for key, value in SCHEDULER_CONFIG.items():
    print(f"  {key}: {value}")

# Step 1: Create optimizer with stream-specific learning rates
print("\n[Step 1] Creating stream-specific optimizer...")
optimizer = create_stream_optimizer(
    model,
    optimizer_type='adamw',
    stream1_lr=STREAM_SPECIFIC_CONFIG['stream1_lr'],
    stream2_lr=STREAM_SPECIFIC_CONFIG['stream2_lr'],
    shared_lr=STREAM_SPECIFIC_CONFIG['shared_lr'],
    stream1_weight_decay=STREAM_SPECIFIC_CONFIG['stream1_weight_decay'],
    stream2_weight_decay=STREAM_SPECIFIC_CONFIG['stream2_weight_decay'],
    shared_weight_decay=STREAM_SPECIFIC_CONFIG['shared_weight_decay']
)

print(f"✅ Optimizer created: {optimizer.__class__.__name__}")
print(f"   Parameter groups: {len(optimizer.param_groups)}")
for i, group in enumerate(optimizer.param_groups):
    num_params = sum(p.numel() for p in group['params'])
    print(f"   Group {i+1}: lr={group['lr']:.2e}, wd={group['weight_decay']:.2e}, params={num_params:,}")

# Step 2: Create scheduler with warmup support
print("\n[Step 2] Creating learning rate scheduler with warmup...")
scheduler = setup_scheduler(
    optimizer,
    scheduler_type=SCHEDULER_CONFIG['scheduler_type'],
    epochs=80,  # Keep in sync with TRAIN_CONFIG['epochs'] (defined in the training section below)
    train_loader_len=len(train_loader),
    t_max=SCHEDULER_CONFIG['t_max'],
    eta_min=SCHEDULER_CONFIG['eta_min'],
    warmup_epochs=SCHEDULER_CONFIG['warmup_epochs'],
    warmup_start_factor=SCHEDULER_CONFIG['warmup_start_factor']
)

print(f"✅ Scheduler created: {scheduler.__class__.__name__}")
print(f"   Warmup: {SCHEDULER_CONFIG['warmup_epochs']} epochs (LR: {SCHEDULER_CONFIG['warmup_start_factor']*100:.0f}% → 100%)")
print(f"   Main scheduler: {SCHEDULER_CONFIG['scheduler_type']} annealing")

# Step 3: Compile model with optimizer and scheduler objects (Keras-style!)
print("\n[Step 3] Compiling model with optimizer and scheduler objects...")
model.compile(
    optimizer=optimizer,
    scheduler=scheduler,
    loss='cross_entropy',
    label_smoothing=0.1
)

print("\n✅ Model compiled successfully (Keras-style)!")
print("\n💡 Learning rate warmup enabled!")
print(f"   First {SCHEDULER_CONFIG['warmup_epochs']} epochs: LR increases from {SCHEDULER_CONFIG['warmup_start_factor']*100:.0f}% to 100%")
print(f"   Remaining epochs: {SCHEDULER_CONFIG['scheduler_type']} annealing from the target LR down to eta_min={SCHEDULER_CONFIG['eta_min']:.0e}")
print("\n" + "=" * 60)
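`create_stream_optimizer` produces one parameter group per stream. A minimal sketch of how such groups can be built for plain `torch.optim.AdamW` follows. The name-matching rule is an assumption, not the repo's logic: names containing `integrat` go to the shared group, then `stream1`/`stream2` by substring, and everything else is shared. Without the first check, weights like `integration_from_stream1` would be treated as RGB parameters. `build_stream_param_groups` is a hypothetical helper.

```python
def build_stream_param_groups(model, stream1_lr, stream2_lr, shared_lr,
                              stream1_wd, stream2_wd, shared_wd):
    """Split parameters into stream1 / stream2 / shared groups by name (hypothetical rule)."""
    buckets = {"stream1": [], "stream2": [], "shared": []}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if "integrat" in name:
            # Integration weights mention a stream in their name but are treated as shared here
            buckets["shared"].append(param)
        elif "stream1" in name:
            buckets["stream1"].append(param)
        elif "stream2" in name:
            buckets["stream2"].append(param)
        else:
            buckets["shared"].append(param)  # e.g. the classifier head

    groups = [
        {"params": buckets["stream1"], "lr": stream1_lr, "weight_decay": stream1_wd},
        {"params": buckets["stream2"], "lr": stream2_lr, "weight_decay": stream2_wd},
        {"params": buckets["shared"], "lr": shared_lr, "weight_decay": shared_wd},
    ]
    # Skip empty groups so the optimizer only sees groups that actually hold parameters
    return [g for g in groups if g["params"]]
```

Usage: `torch.optim.AdamW(build_stream_param_groups(model, 3e-5, 1e-4, 7e-5, 5e-4, 1e-4, 2e-4))`. Per-group `lr` and `weight_decay` entries override the optimizer defaults, which is what lets each stream train at its own pace.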

## 10. Test Forward Pass

In [None]:
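# Test forward pass with detailed input validation
print("Testing forward pass...")

# NOTE: CUDA_LAUNCH_BLOCKING only takes effect if it is set BEFORE CUDA is initialized.
# CUDA was already initialized by the GPU check in Section 1, so setting it here has no effect.
# For synchronous CUDA error messages, restart the runtime and run this in the very first cell:
# import os; os.environ['CUDA_LAUNCH_BLOCKING'] = '1'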

model.eval()
with torch.no_grad():
    rgb_test, depth_test, labels_test = next(iter(train_loader))
    
    print(f"\nInput validation:")
    print(f"  RGB shape: {rgb_test.shape}, dtype: {rgb_test.dtype}")
    print(f"  RGB min: {rgb_test.min():.4f}, max: {rgb_test.max():.4f}")
    print(f"  RGB has NaN: {torch.isnan(rgb_test).any()}")
    print(f"  RGB has Inf: {torch.isinf(rgb_test).any()}")
    
    print(f"\n  Depth shape: {depth_test.shape}, dtype: {depth_test.dtype}")
    print(f"  Depth min: {depth_test.min():.4f}, max: {depth_test.max():.4f}")
    print(f"  Depth has NaN: {torch.isnan(depth_test).any()}")
    print(f"  Depth has Inf: {torch.isinf(depth_test).any()}")
    
    print(f"\n  Labels shape: {labels_test.shape}, dtype: {labels_test.dtype}")
    print(f"  Labels min: {labels_test.min()}, max: {labels_test.max()}")
    print(f"  Labels unique: {torch.unique(labels_test).tolist()}")
    
    print("\nRunning forward pass...")
    rgb_cuda = rgb_test.to('cuda')
    depth_cuda = depth_test.to('cuda')
    
    try:
        outputs = model(rgb_cuda, depth_cuda)
        print(f"  ✅ Forward pass successful!")
        print(f"  Output shape: {outputs.shape}")
        print(f"  Output min: {outputs.min():.4f}, max: {outputs.max():.4f}")
        
        _, predictions = torch.max(outputs, 1)
        print(f"\nSample predictions: {predictions.cpu().numpy()[:10]}")
        print(f"Ground truth: {labels_test.numpy()[:10]}")
        
    except Exception as e:
        print(f"\n❌ Forward pass failed!")
        print(f"Error: {e}")
        print("\nIf the input checks above show no NaN/Inf and labels in 0-14, the cause is more likely in the model than in the data.")
        print(f"Possible causes:")
        print(f"  1. BatchNorm running stats issue")
        print(f"  2. Invalid tensor operations in model")
        print("  3. GPU out-of-memory (try a smaller batch size)")
        raise

print("\n✅ Forward pass test complete!")

## 11. Setup Checkpoint Directory

In [None]:
import os
from datetime import datetime
from pathlib import Path

print("=" * 60)
print("CHECKPOINT DIRECTORY SETUP")
print("=" * 60)

# Create checkpoint directory on Google Drive (persistent storage)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
checkpoint_dir = f"/content/drive/MyDrive/linet_checkpoints/run_{timestamp}"

# Create directory
Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)

print(f"\n✅ Checkpoint directory created:")
print(f"   {checkpoint_dir}")
print(f"\nAll training artifacts will be saved here:")
print(f"  • Best model weights")
print(f"  • Training history")
print(f"  • Monitoring metrics")
print(f"  • Visualizations")

print("\n" + "=" * 60)

## 12. Train the Model 🚀

**Expected time:** ~2-3 hours for 80 epochs on A100 (less if early stopping triggers)

**All optimizations enabled:**
- ✅ Automatic Mixed Precision (often up to ~2x faster)
- ✅ Gradient Clipping (stability)
- ✅ Linear LR Warmup + Cosine Annealing
- ✅ Early Stopping
- ✅ Best Model Checkpointing
- ✅ Local disk I/O (10-20x faster than Drive)

In [None]:
import warnings

# Suppress PyTorch SequentialLR deprecation warning (internal PyTorch issue, not our code)
warnings.filterwarnings(
    'ignore',
    message='The epoch parameter in `scheduler.step\\(\\)` was not necessary',
    category=UserWarning
)

print("=" * 60)
print("TRAINING WITH STREAM MONITORING")
print("=" * 60)

# Training configuration
TRAIN_CONFIG = {
    'epochs': 80,
    'grad_clip_norm': 5.0,
    'early_stopping': True,
    'patience': 12,
    'min_delta': 0.001,
    'monitor': 'val_accuracy',
    'restore_best_weights': True,
    'save_path': f"{checkpoint_dir}/best_model.pt",
    'stream_monitoring': True,  # Built-in stream monitoring!
}

print(f"Configuration:")
for key, value in TRAIN_CONFIG.items():
    print(f"  {key}: {value}")

print("\n" + "=" * 60)
print(f"Training will take approximately 2-3 hours on A100")
print(f"Stream monitoring active - detailed per-stream metrics shown each epoch")
print("=" * 60 + "\n")

# Train using built-in fit() method with stream monitoring
# NOTE: No scheduler_kwargs needed! Scheduler was already created and passed to compile()
history = model.fit(
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=TRAIN_CONFIG['epochs'],
    verbose=True,
    save_path=TRAIN_CONFIG['save_path'],
    early_stopping=TRAIN_CONFIG['early_stopping'],
    patience=TRAIN_CONFIG['patience'],
    min_delta=TRAIN_CONFIG['min_delta'],
    monitor=TRAIN_CONFIG['monitor'],
    restore_best_weights=TRAIN_CONFIG['restore_best_weights'],
    grad_clip_norm=TRAIN_CONFIG['grad_clip_norm'],
    stream_monitoring=TRAIN_CONFIG['stream_monitoring']  # Enable built-in monitoring
)

print("\n" + "=" * 60)
print("🎉 TRAINING COMPLETE!")
print("=" * 60)
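`TRAIN_CONFIG` sets `patience=12` and `min_delta=0.001` on `val_accuracy`. Under the usual meaning of these settings, training stops after 12 consecutive epochs in which validation accuracy fails to beat the best value by more than 0.001. The sketch below illustrates that rule. It is a hypothetical tracker, not the internals of `fit()`, which may differ in details such as whether the comparison is strict.

```python
class EarlyStopping:
    """Minimal early-stopping tracker (hypothetical; not the repo's fit() internals)."""

    def __init__(self, patience=12, min_delta=0.001, mode="max"):
        self.patience = patience    # epochs to wait without improvement
        self.min_delta = min_delta  # smallest gain that counts as an improvement
        self.mode = mode            # "max" for accuracy, "min" for loss
        self.best = None
        self.best_epoch = None
        self.wait = 0

    def step(self, metric, epoch):
        """Record this epoch's metric; return True when training should stop."""
        if self.best is None:
            improved = True
        elif self.mode == "max":
            improved = metric > self.best + self.min_delta
        else:
            improved = metric < self.best - self.min_delta
        if improved:
            self.best, self.best_epoch, self.wait = metric, epoch, 0
        else:
            self.wait += 1
        return self.wait >= self.patience
```

With `restore_best_weights=True`, the weights saved at `best_epoch` are the ones reloaded after stopping. This is why the final evaluation below reflects the best epoch rather than the last one.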

## 13. Evaluate Final Model

In [None]:
print("=" * 60)
print("FINAL EVALUATION")
print("=" * 60)

# Evaluate on validation set
results = model.evaluate(data_loader=val_loader)

print(f"\nFinal Validation Results:")
print(f"  Loss: {results['loss']:.4f}")
print(f"  Overall Accuracy: {results['accuracy']*100:.2f}%")

# Show stream-specific accuracies if available
if 'stream1_accuracy' in results:
    print(f"\nStream-Specific Performance:")
    print(f"  Stream1 (RGB) Accuracy: {results['stream1_accuracy']*100:.2f}%")
    print(f"  Stream2 (Depth) Accuracy: {results['stream2_accuracy']*100:.2f}%")

print(f"\nTraining Summary:")
print(f"  Initial train loss: {history['train_loss'][0]:.4f}")
print(f"  Final train loss: {history['train_loss'][-1]:.4f}")
print(f"  Best val loss: {min(history['val_loss']):.4f}")
print(f"  Initial train acc: {history['train_accuracy'][0]*100:.2f}%")
print(f"  Final train acc: {history['train_accuracy'][-1]*100:.2f}%")
print(f"  Best val acc: {max(history['val_accuracy'])*100:.2f}%")
print(f"  Total epochs: {len(history['train_loss'])}")

if 'early_stopping' in history:
    print(f"\nEarly Stopping Info:")
    print(f"  Stopped early: {history['early_stopping']['stopped_early']}")
    print(f"  Best epoch: {history['early_stopping']['best_epoch']}")
    print(f"  Best {history['early_stopping']['monitor']}: {history['early_stopping']['best_metric']:.4f}")

print("\n" + "=" * 60)
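Section 11 lists the training history among the artifacts kept on Drive. If `fit()` does not already write it (check the repo), a minimal sketch for persisting the `history` dict as JSON follows. `save_history` is a hypothetical helper. `default=float` converts scalar values such as NumPy floats or 0-d tensors that `json` cannot serialize natively.

```python
import json
from pathlib import Path


def save_history(history, path):
    """Write a training-history dict to JSON (sketch; fit() may already do this)."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w") as f:
        # default=float handles numpy/tensor scalars; non-scalar objects would still raise
        json.dump(history, f, indent=2, default=float)
    return path


# Usage in this notebook:
# save_history(history, f"{checkpoint_dir}/history.json")
```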

## 14. Plot Training Curves

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Loss curve
axes[0].plot(history['train_loss'], label='Train Loss', linewidth=2)
axes[0].plot(history['val_loss'], label='Val Loss', linewidth=2)
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)

# Accuracy curve with stream-specific curves
# Note: train_accuracy and val_accuracy are the full model (integrated stream) accuracies
axes[1].plot([acc*100 for acc in history['train_accuracy']], label='Full Model Train', linewidth=2, color='green')
axes[1].plot([acc*100 for acc in history['val_accuracy']], label='Full Model Val', linewidth=2, color='darkgreen')

# Add stream-specific curves if available (stream1 and stream2 only)
if len(history.get('stream1_train_acc', [])) > 0:
    axes[1].plot([acc*100 for acc in history['stream1_train_acc']], 
                label='Stream1 (RGB) Train', linewidth=1, alpha=0.6, linestyle='--', color='skyblue')
    axes[1].plot([acc*100 for acc in history['stream1_val_acc']], 
                label='Stream1 (RGB) Val', linewidth=1, alpha=0.6, linestyle='--', color='blue')
    axes[1].plot([acc*100 for acc in history['stream2_train_acc']], 
                label='Stream2 (Depth) Train', linewidth=1, alpha=0.6, linestyle='--', color='lightcoral')
    axes[1].plot([acc*100 for acc in history['stream2_val_acc']], 
                label='Stream2 (Depth) Val', linewidth=1, alpha=0.6, linestyle='--', color='red')

axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('Accuracy (%)', fontsize=12)
axes[1].set_title('Training and Validation Accuracy\n(Full Model = Integrated Stream)', fontsize=14, fontweight='bold')
axes[1].legend(fontsize=9, loc='lower right')
axes[1].grid(True, alpha=0.3)

# Learning rate curve with stream-specific LRs
if len(history['learning_rates']) > 0:
    # Base LR is recorded per step (not per epoch): subsample to ~100 points and map
    # step indices onto a fractional-epoch axis so it lines up with the per-epoch stream LRs
    step_lrs = history['learning_rates']
    num_epochs = len(history['train_loss'])
    step_every = max(1, len(step_lrs) // 100)
    sampled_lrs = step_lrs[::step_every]
    sampled_epochs = [i * step_every * num_epochs / len(step_lrs) for i in range(len(sampled_lrs))]
    axes[2].plot(sampled_epochs, sampled_lrs, linewidth=2, color='green', label='Base LR')
    
    # Add stream-specific LRs if available (recorded per epoch)
    if len(history.get('stream1_lr', [])) > 0:
        axes[2].plot(history['stream1_lr'], linewidth=1, alpha=0.7, linestyle='--', 
                    color='blue', label='Stream1 (RGB) LR')
        axes[2].plot(history['stream2_lr'], linewidth=1, alpha=0.7, linestyle='--', 
                    color='red', label='Stream2 (Depth) LR')
    
    axes[2].set_xlabel('Epoch', fontsize=12)
    axes[2].set_ylabel('Learning Rate', fontsize=12)
    axes[2].set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
    axes[2].legend(fontsize=9, loc='upper right')
    axes[2].grid(True, alpha=0.3)
    # axes[2].set_yscale('log')  # Removed: Linear scale shows scheduler shape better

plt.tight_layout()
plt.savefig(f"{checkpoint_dir}/training_curves.png", dpi=150, bbox_inches='tight')
plt.show()

print(f"✅ Training curves saved to: {checkpoint_dir}/training_curves.png")

## 15. Pathway Analysis (Stream1/Stream2/Integrated Contributions)
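The integration-weight part of the cell below calls `calculate_stream_contributions_to_integration()`, whose exact formula is not shown in this notebook. One plausible reading is sketched here as a hypothetical helper: sum the L2 norms of every `integration_from_stream1` tensor and every `integration_from_stream2` tensor across layers, then normalize to shares. The repo's method may weight layers differently. Either way, weight norms are a structural proxy and not a measure of how much the predictions depend on each stream.

```python
import torch


def integration_weight_shares(model):
    """Hypothetical: each stream's share of the summed integration-weight L2 norms."""
    norms = {"stream1": 0.0, "stream2": 0.0}
    with torch.no_grad():
        for name, param in model.named_parameters():
            if "integration_from_stream1" in name:
                norms["stream1"] += param.norm().item()
            elif "integration_from_stream2" in name:
                norms["stream2"] += param.norm().item()
    total = norms["stream1"] + norms["stream2"]
    if total == 0:
        raise ValueError("No integration_from_stream* parameters found in model")
    return {stream: value / total for stream, value in norms.items()}
```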

In [None]:
print("=" * 60)
print("PATHWAY ANALYSIS (3-STREAM LINET)")
print("=" * 60)
print("\nAnalyzing Stream1, Stream2, and Integrated pathway contributions...")
print("This may take a few minutes...\n")

# Analyze pathways (LINet has 3 streams: stream1, stream2, integrated)
pathway_analysis = model.analyze_pathways(
    data_loader=val_loader
)

print("\nAccuracy Metrics:")
print(f"  Full model (Integrated): {pathway_analysis['accuracy']['full_model']*100:.2f}%")
print(f"  Stream1 only (RGB): {pathway_analysis['accuracy']['stream1_only']*100:.2f}%")
print(f"  Stream2 only (Depth): {pathway_analysis['accuracy']['stream2_only']*100:.2f}%")
print(f"  Integrated only: {pathway_analysis['accuracy']['integrated_only']*100:.2f}%")
print(f"\n  Stream1 contribution: {pathway_analysis['accuracy']['stream1_contribution']*100:.2f}%")
print(f"  Stream2 contribution: {pathway_analysis['accuracy']['stream2_contribution']*100:.2f}%")
print(f"  Integrated contribution: {pathway_analysis['accuracy']['integrated_contribution']*100:.2f}%")

print("\nLoss Metrics:")
print(f"  Full model: {pathway_analysis['loss']['full_model']:.4f}")
print(f"  Stream1 only: {pathway_analysis['loss']['stream1_only']:.4f}")
print(f"  Stream2 only: {pathway_analysis['loss']['stream2_only']:.4f}")
print(f"  Integrated only: {pathway_analysis['loss']['integrated_only']:.4f}")

print("\nFeature Norm Statistics:")
print(f"  Stream1 mean: {pathway_analysis['feature_norms']['stream1_mean']:.4f}")
print(f"  Stream1 std: {pathway_analysis['feature_norms']['stream1_std']:.4f}")
print(f"  Stream2 mean: {pathway_analysis['feature_norms']['stream2_mean']:.4f}")
print(f"  Stream2 std: {pathway_analysis['feature_norms']['stream2_std']:.4f}")
print(f"  Integrated mean: {pathway_analysis['feature_norms']['integrated_mean']:.4f}")
print(f"  Integrated std: {pathway_analysis['feature_norms']['integrated_std']:.4f}")
print(f"  Stream1/Stream2 ratio: {pathway_analysis['feature_norms']['stream1_to_stream2_ratio']:.4f}")

print("\n" + "=" * 60)
print("INTEGRATION WEIGHT ANALYSIS")
print("=" * 60)
print("\nAnalyzing integration weight magnitudes...")
print("(Compares integration weight magnitudes: a rough structural proxy for how much each stream feeds the integrated pathway)\n")

# Calculate stream contributions to integration from integration weight norms
integration_contributions = model.calculate_stream_contributions_to_integration()

print("Integration Weight Contributions:")
print(f"  Stream1 (RGB): {integration_contributions['interpretation']['stream1_percentage']}")
print(f"  Stream2 (Depth): {integration_contributions['interpretation']['stream2_percentage']}")

print("\nRaw Integration Weight Norms:")
print(f"  Stream1 integration weights: {integration_contributions['raw_norms']['stream1_integration_weights']:.4f}")
print(f"  Stream2 integration weights: {integration_contributions['raw_norms']['stream2_integration_weights']:.4f}")
print(f"  Total: {integration_contributions['raw_norms']['total']:.4f}")

print("\nInterpretation:")
s1_contrib = integration_contributions['stream1_contribution']
s2_contrib = integration_contributions['stream2_contribution']
if s1_contrib > 0.55:
    print("  → Architecture favors Stream1 (RGB) - larger integration weights")
elif s2_contrib > 0.55:
    print("  → Architecture favors Stream2 (Depth) - larger integration weights")
else:
    print("  → Architecture uses both streams fairly equally")

print(f"\nNote: {integration_contributions['note']}")

print("\n" + "=" * 60)
print("✅ Pathway analysis complete!")
print("=" * 60)

# Visualize pathway contributions
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# Accuracy comparison (3 pathways for LINet)
pathways = ['Full Model\n(Integrated)', 'Stream1\nOnly', 'Stream2\nOnly']
accuracies = [
    pathway_analysis['accuracy']['full_model'] * 100,
    pathway_analysis['accuracy']['stream1_only'] * 100,
    pathway_analysis['accuracy']['stream2_only'] * 100
]
colors = ['green', 'blue', 'orange']

axes[0].bar(pathways, accuracies, color=colors, alpha=0.7)
axes[0].set_ylabel('Accuracy (%)', fontsize=12)
axes[0].set_title('Pathway Accuracy Comparison', fontsize=14, fontweight='bold')
axes[0].grid(True, alpha=0.3, axis='y')
for i, v in enumerate(accuracies):
    axes[0].text(i, v + 1, f'{v:.1f}%', ha='center', fontweight='bold')

# Feature norm comparison (3 streams)
norms = ['Stream1\nFeatures', 'Stream2\nFeatures', 'Integrated\nFeatures']
norm_values = [
    pathway_analysis['feature_norms']['stream1_mean'],
    pathway_analysis['feature_norms']['stream2_mean'],
    pathway_analysis['feature_norms']['integrated_mean']
]
axes[1].bar(norms, norm_values, color=['blue', 'orange', 'purple'], alpha=0.7)
axes[1].set_ylabel('Feature Norm (Mean)', fontsize=12)
axes[1].set_title('Runtime Feature Magnitude', fontsize=14, fontweight='bold')
axes[1].grid(True, alpha=0.3, axis='y')
for i, v in enumerate(norm_values):
    axes[1].text(i, v + 0.1, f'{v:.2f}', ha='center', fontweight='bold')

# Integration weight contributions (2 input streams only - NOT integrated!)
streams = ['Stream1\n(RGB)', 'Stream2\n(Depth)']
integration_values = [
    integration_contributions['stream1_contribution'] * 100,
    integration_contributions['stream2_contribution'] * 100
]
axes[2].bar(streams, integration_values, color=['blue', 'orange'], alpha=0.7)
axes[2].set_ylabel('Contribution (%)', fontsize=12)
axes[2].set_title('Integration Weight Contribution', fontsize=14, fontweight='bold')
axes[2].grid(True, alpha=0.3, axis='y')
for i, v in enumerate(integration_values):
    axes[2].text(i, v + 1, f'{v:.1f}%', ha='center', fontweight='bold')

plt.tight_layout()
plt.savefig(f"{checkpoint_dir}/pathway_analysis.png", dpi=150, bbox_inches='tight')
plt.show()

print(f"✅ Pathway analysis plot saved to: {checkpoint_dir}/pathway_analysis.png")
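`calculate_stream_contributions_to_integration()` is part of the LINet codebase, and its internals are not shown in this notebook. As a rough mental model only, a norm-based contribution can be computed as each stream's share of the total `integration_from_*` weight norm. The sketch below assumes Frobenius (L2) norms summed over layers, which may differ from the actual implementation. The helper names and the toy weights are hypothetical. The 0.55 threshold mirrors the interpretation rule in the cell above.

```python
import math

def frobenius_norm(matrix):
    """L2 (Frobenius) norm of a weight matrix given as nested lists."""
    return math.sqrt(sum(w * w for row in matrix for w in row))

def stream_contributions(layers):
    """Share of the total integration-weight norm per input stream.

    `layers` is a list of dicts holding hypothetical 1x1 integration matrices
    under 'from_stream1' and 'from_stream2'. Norms are summed across layers.
    """
    s1 = sum(frobenius_norm(layer["from_stream1"]) for layer in layers)
    s2 = sum(frobenius_norm(layer["from_stream2"]) for layer in layers)
    total = s1 + s2
    if total == 0:
        return 0.5, 0.5  # no signal either way: treat as balanced
    return s1 / total, s2 / total

def interpret(s1, s2, threshold=0.55):
    """Same rule as the notebook cell: a share above 55% means that stream is favored."""
    if s1 > threshold:
        return "favors Stream1 (RGB)"
    if s2 > threshold:
        return "favors Stream2 (Depth)"
    return "roughly balanced"

# Toy example: two layers with made-up weights
layers = [
    {"from_stream1": [[0.6, 0.0], [0.0, 0.8]], "from_stream2": [[0.3, 0.0], [0.0, 0.4]]},
    {"from_stream1": [[1.0]], "from_stream2": [[0.5]]},
]
s1, s2 = stream_contributions(layers)
print(f"Stream1: {s1:.1%}, Stream2: {s2:.1%} -> {interpret(s1, s2)}")
```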

## 16. Save Results & Training History

In [None]:
import json
import torch

print("=" * 60)
print("SAVING RESULTS")
print("=" * 60)

# Save training history as JSON
history_path = f"{checkpoint_dir}/training_history.json"
with open(history_path, 'w') as f:
    json_history = {
        'train_loss': [float(x) for x in history['train_loss']],
        'val_loss': [float(x) for x in history['val_loss']],
        'train_accuracy': [float(x) for x in history['train_accuracy']],
        'val_accuracy': [float(x) for x in history['val_accuracy']],
        'learning_rates': [float(x) for x in history['learning_rates']],
        'model_config': MODEL_CONFIG,
        'dataset_config': DATASET_CONFIG,
        'stream_specific_config': STREAM_SPECIFIC_CONFIG,
        'scheduler_config': SCHEDULER_CONFIG,
        'training_config': TRAIN_CONFIG,
        'final_results': {
            'val_loss': float(results['loss']),
            'val_accuracy': float(results['accuracy'])
        },
        'pathway_analysis': {
            'full_model_accuracy': float(pathway_analysis['accuracy']['full_model']),
            'stream1_only_accuracy': float(pathway_analysis['accuracy']['stream1_only']),
            'stream2_only_accuracy': float(pathway_analysis['accuracy']['stream2_only']),
            'integrated_only_accuracy': float(pathway_analysis['accuracy']['integrated_only'])
        }
    }
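    if 'early_stopping' in history:
        json_history['early_stopping'] = history['early_stopping']

    # default=str converts any non-JSON values (e.g. a torch.device, if one is stored in a config) to strings
    json.dump(json_history, f, indent=2, default=str)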

print(f"✅ Training history saved: {history_path}")

# Save final model (in addition to best model)
final_model_path = f"{checkpoint_dir}/final_model.pt"
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': model.optimizer.state_dict(),
    'scheduler_state_dict': model.scheduler.state_dict() if model.scheduler else None,
    'config': MODEL_CONFIG,
    'stream_specific_config': STREAM_SPECIFIC_CONFIG,
    'scheduler_config': SCHEDULER_CONFIG,
    'history': history,
    'val_accuracy': results['accuracy']
}, final_model_path)

print(f"✅ Final model saved: {final_model_path}")

# Save summary report
summary_path = f"{checkpoint_dir}/summary.txt"
with open(summary_path, 'w') as f:
    f.write("=" * 60 + "\n")
    f.write("LINet Training Summary - SUN RGB-D\n")
    f.write("=" * 60 + "\n\n")
    
    # Model Configuration
    f.write("Model Configuration:\n")
    f.write(f"  Architecture: LINet-{MODEL_CONFIG['architecture'].upper()} (3-stream Linear Integration)\n")
    f.write(f"  Num Classes: {MODEL_CONFIG['num_classes']}\n")
    f.write(f"  Stream1 Channels: {MODEL_CONFIG['stream1_channels']} (RGB)\n")
    f.write(f"  Stream2 Channels: {MODEL_CONFIG['stream2_channels']} (Depth)\n")
    f.write(f"  Dropout: {MODEL_CONFIG['dropout_p']}\n")
    f.write(f"  Device: {MODEL_CONFIG['device']}\n")
    f.write(f"  AMP Enabled: {MODEL_CONFIG['use_amp']}\n")
    f.write(f"  Total Parameters: {total_params:,}\n")
    f.write(f"  Trainable Parameters: {trainable_params:,}\n")
    f.write(f"  Integration Parameters: {integration_params:,}\n")
    
    # Dataset Configuration
    f.write(f"\nDataset Configuration:\n")
    f.write(f"  Dataset: SUN RGB-D 15-category (Scene Classification)\n")
    f.write(f"  Training Samples: {len(train_loader.dataset)}\n")
    f.write(f"  Validation Samples: {len(val_loader.dataset)}\n")
    f.write(f"  Batch Size: {DATASET_CONFIG['batch_size']}\n")
    f.write(f"  Num Workers: {DATASET_CONFIG['num_workers']}\n")
    f.write(f"  Input Size: {DATASET_CONFIG['target_size']}\n")
    
    # Optimization Configuration
    f.write(f"\nOptimization Configuration (Keras-Style API):\n")
    f.write(f"  Optimizer: AdamW (stream-specific)\n")
    f.write(f"  Loss Function: cross_entropy\n")
    f.write(f"  Label Smoothing: 0.1\n")
    f.write(f"  Scheduler: {SCHEDULER_CONFIG['scheduler_type']}\n")
    f.write(f"  Scheduler t_max: {SCHEDULER_CONFIG['t_max']}\n")
    f.write(f"  Scheduler eta_min: {SCHEDULER_CONFIG['eta_min']}\n")
    f.write(f"  Gradient Clipping: {TRAIN_CONFIG['grad_clip_norm']}\n")
    
    # Stream-Specific Settings
    f.write(f"\nStream-Specific Settings:\n")
    f.write(f"  Stream1 (RGB):\n")
    f.write(f"    Learning Rate: {STREAM_SPECIFIC_CONFIG['stream1_lr']}\n")
    f.write(f"    Weight Decay: {STREAM_SPECIFIC_CONFIG['stream1_weight_decay']}\n")
    f.write(f"  Stream2 (Depth):\n")
    f.write(f"    Learning Rate: {STREAM_SPECIFIC_CONFIG['stream2_lr']}\n")
    f.write(f"    Weight Decay: {STREAM_SPECIFIC_CONFIG['stream2_weight_decay']}\n")
    f.write(f"  Shared (Fusion/Classifier):\n")
    f.write(f"    Learning Rate: {STREAM_SPECIFIC_CONFIG['shared_lr']}\n")
    f.write(f"    Weight Decay: {STREAM_SPECIFIC_CONFIG['shared_weight_decay']}\n")
    
    # Training Configuration
    f.write(f"\nTraining Configuration:\n")
    f.write(f"  Total Epochs: {len(history['train_loss'])}\n")
    f.write(f"  Stream Monitoring: {TRAIN_CONFIG['stream_monitoring']}\n")
    f.write(f"  Early Stopping: {TRAIN_CONFIG['early_stopping']}\n")
    if TRAIN_CONFIG['early_stopping']:
        f.write(f"    Monitor: {TRAIN_CONFIG['monitor']}\n")
        f.write(f"    Patience: {TRAIN_CONFIG['patience']}\n")
        f.write(f"    Min Delta: {TRAIN_CONFIG['min_delta']}\n")
        f.write(f"    Restore Best Weights: {TRAIN_CONFIG['restore_best_weights']}\n")
    
    # Results
    f.write(f"\nFinal Results:\n")
    f.write(f"  Val Loss: {results['loss']:.4f}\n")
    f.write(f"  Val Accuracy: {results['accuracy']*100:.2f}%\n")
    f.write(f"  Best Val Accuracy: {max(history['val_accuracy'])*100:.2f}%\n")
    f.write(f"  Initial Train Loss: {history['train_loss'][0]:.4f}\n")
    f.write(f"  Final Train Loss: {history['train_loss'][-1]:.4f}\n")
    f.write(f"  Best Val Loss: {min(history['val_loss']):.4f}\n")
    
    # Pathway Analysis (3 streams for LINet)
    f.write(f"\nPathway Analysis (3-stream LINet):\n")
    f.write(f"  Full Model (Integrated): {pathway_analysis['accuracy']['full_model']*100:.2f}%\n")
    f.write(f"  Stream1 Only (RGB): {pathway_analysis['accuracy']['stream1_only']*100:.2f}%\n")
    f.write(f"  Stream2 Only (Depth): {pathway_analysis['accuracy']['stream2_only']*100:.2f}%\n")
    f.write(f"  Integrated Only: {pathway_analysis['accuracy']['integrated_only']*100:.2f}%\n")
    f.write(f"  Stream1 Contribution: {pathway_analysis['accuracy']['stream1_contribution']*100:.2f}%\n")
    f.write(f"  Stream2 Contribution: {pathway_analysis['accuracy']['stream2_contribution']*100:.2f}%\n")
    f.write(f"  Integrated Contribution: {pathway_analysis['accuracy']['integrated_contribution']*100:.2f}%\n")

print(f"✅ Summary report saved: {summary_path}")

print("\n" + "=" * 60)
print(f"All results saved to: {checkpoint_dir}")
print("=" * 60)

# List saved files
print("\nSaved files:")
!ls -lh {checkpoint_dir}
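You can read the saved `training_history.json` back with the standard library to revisit a run without rerunning the notebook. This minimal sketch assumes the keys written by the cell above (`train_loss`, `val_loss`, `val_accuracy`). The helper `summarize_history` is hypothetical. To keep the snippet self-contained, the demo writes a small made-up history to a temporary file. In practice, you would pass `f"{checkpoint_dir}/training_history.json"` instead.

```python
import json
import os
import tempfile

def summarize_history(path):
    """Load a saved training_history.json and report the best epoch (1-indexed)."""
    with open(path) as f:
        hist = json.load(f)
    val_acc = hist["val_accuracy"]
    best_idx = max(range(len(val_acc)), key=val_acc.__getitem__)  # first epoch with the top accuracy
    return {
        "epochs": len(hist["train_loss"]),
        "best_epoch": best_idx + 1,
        "best_val_accuracy": val_acc[best_idx],
        "val_loss_at_best": hist["val_loss"][best_idx],
    }

# Demo with a made-up 4-epoch history
demo = {
    "train_loss": [2.1, 1.5, 1.2, 1.0],
    "val_loss": [2.0, 1.6, 1.4, 1.5],
    "val_accuracy": [0.40, 0.55, 0.62, 0.60],
}
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "training_history.json")
    with open(demo_path, "w") as f:
        json.dump(demo, f)
    summary = summarize_history(demo_path)

print(summary)
```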

## 17. Summary & Next Steps

### 🎉 Training Complete!

**What we accomplished:**
- ✅ Trained LINet (3-stream Linear Integration ResNet) on SUN RGB-D dataset (15 categories)
- ✅ Used the **Keras-style API** with explicit optimizer and scheduler creation
- ✅ Trained on an A100 GPU with automatic mixed precision (AMP), typically up to ~2x faster than FP32
- ✅ Saved all checkpoints to Google Drive
- ✅ Analyzed Stream1, Stream2, and Integrated pathway contributions
- ✅ Generated training curves and visualizations
- ✅ Comprehensive stream monitoring with overfitting detection

**Where results are saved:** see the checkpoint directory path printed in the output above.

### 📊 Expected Performance:

For **SUN RGB-D Scene Classification (15 categories, 10,335 images)**:
- **Good:** 65-75% validation accuracy
- **Very Good:** 75-80% validation accuracy
- **Excellent:** 80-85% validation accuracy

**Expected to outperform NYU Depth V2 because of:**
- 6.9x more training samples (8,041 vs 1,159)
- 22.6x better class balance (class imbalance ratio of 8.5x vs 192x)
- Higher-quality, more diverse data
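Both ratios quoted here are simple quotients. The sketch below reproduces them and shows one common definition of a class imbalance ratio: the largest class count divided by the smallest. Treating that as the definition behind "8.5x vs 192x" is an assumption. The per-class counts in the toy example are made up.

```python
def imbalance_ratio(class_counts):
    """Largest class size divided by smallest class size (1.0 = perfectly balanced)."""
    counts = list(class_counts.values())
    return max(counts) / min(counts)

# Sample-count ratio: SUN RGB-D 15-category training samples vs NYU Depth V2
sample_ratio = 8041 / 1159
# "Better balance" factor: NYU imbalance (192x) divided by SUN RGB-D imbalance (8.5x)
balance_gain = 192 / 8.5

print(f"{sample_ratio:.1f}x more training samples")  # 6.9x
print(f"{balance_gain:.1f}x better class balance")   # 22.6x

# Hypothetical per-class counts, purely for illustration
toy_counts = {"class_a": 850, "class_b": 400, "class_c": 100}
print(f"toy imbalance ratio: {imbalance_ratio(toy_counts):.1f}x")  # 8.5x
```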

### 🔄 New Keras-Style API Used:

This notebook uses the **refactored Keras-style API**:

```python
# 1. Create optimizer with stream-specific LRs
optimizer = create_stream_optimizer(
    model, stream1_lr=3e-5, stream2_lr=1e-4, shared_lr=7e-5, ...
)

# 2. Create scheduler
scheduler = setup_scheduler(optimizer, 'cosine', epochs=80, ...)

# 3. Compile with optimizer/scheduler objects (not name strings)
model.compile(optimizer=optimizer, scheduler=scheduler, loss='cross_entropy')

# 4. Train (no scheduler_kwargs needed!)
model.fit(train_loader, val_loader, epochs=80, stream_monitoring=True)
```
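Step 2 requests a `'cosine'` schedule, and the summary report records `t_max` and `eta_min`. This sketch assumes `setup_scheduler(..., 'cosine', ...)` wraps standard cosine annealing, which is not confirmed from the project code. Under that assumption, each parameter group's learning rate follows the closed form documented for PyTorch's `CosineAnnealingLR`: `eta_t = eta_min + (base_lr - eta_min) * (1 + cos(pi * t / t_max)) / 2`. The `eta_min=1e-7` below is an illustrative value, not the notebook's setting.

```python
import math

def cosine_lr(epoch, base_lr, t_max, eta_min=0.0):
    """Closed-form cosine annealing: base_lr at epoch 0, eta_min at epoch t_max."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2

# Each parameter group anneals from its own base LR (values from the example above)
base_lrs = {"stream1": 3e-5, "stream2": 1e-4, "shared": 7e-5}
t_max, eta_min = 80, 1e-7  # eta_min is an assumed illustrative value
for epoch in (0, 40, 80):
    lrs = {name: cosine_lr(epoch, lr, t_max, eta_min) for name, lr in base_lrs.items()}
    print(epoch, {name: f"{lr:.2e}" for name, lr in lrs.items()})
```

Because every group shares the same schedule shape, the ratios between the stream-specific learning rates stay roughly constant during training. Only near the end does `eta_min` dominate.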

**Benefits:**
- ✅ Explicit optimizer/scheduler creation
- ✅ Clear separation: compile() = config, fit() = execution
- ✅ Easy to customize and experiment
- ✅ Stream-specific learning rates still fully supported
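Stream-specific learning rates are usually implemented with optimizer parameter groups, where each group carries its own `lr` and `weight_decay`. The minimal sketch below shows the grouping step under a hypothetical naming rule. Names containing `integrat` (which covers `integrated_weight` and `integration_from_stream*`) go to the shared group, and so does anything without a stream tag. Otherwise, `stream1` or `stream2` in the name selects the group. The real `create_stream_optimizer` may partition parameters differently, and `group_parameters` is a hypothetical helper.

```python
def group_parameters(named_params, stream1_lr, stream2_lr, shared_lr,
                     stream1_wd=0.0, stream2_wd=0.0, shared_wd=0.0):
    """Split (name, param) pairs into stream1 / stream2 / shared parameter groups."""
    groups = {
        "stream1": {"params": [], "lr": stream1_lr, "weight_decay": stream1_wd},
        "stream2": {"params": [], "lr": stream2_lr, "weight_decay": stream2_wd},
        "shared": {"params": [], "lr": shared_lr, "weight_decay": shared_wd},
    }
    for name, param in named_params:
        if "integrat" in name:  # integrated_weight, integration_from_stream1/2 (assumed shared)
            key = "shared"
        elif "stream1" in name:
            key = "stream1"
        elif "stream2" in name:
            key = "stream2"
        else:  # e.g. classifier head
            key = "shared"
        groups[key]["params"].append(param)
    # Keep only groups that actually received parameters
    return [g for g in groups.values() if g["params"]]

names = ["conv1.stream1_weight", "conv1.stream2_weight", "conv1.integrated_weight",
         "conv1.integration_from_stream1", "conv1.integration_from_stream2", "fc.weight"]
# The names double as stand-ins for tensors in this demo
param_groups = group_parameters(((n, n) for n in names),
                                stream1_lr=3e-5, stream2_lr=1e-4, shared_lr=7e-5)
for g in param_groups:
    print(g["lr"], g["params"])
```

In PyTorch, `model.named_parameters()` yields `(name, tensor)` pairs, and optimizers such as `torch.optim.AdamW` accept a list of dicts with per-group `lr` and `weight_decay`. That means a list of this shape could be passed in directly.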

### 🧠 LINet Architecture Highlights:

**3-Stream Linear Integration:**
- Stream1 processes RGB
- Stream2 processes Depth
- Integrated stream combines both through learned linear weights at every layer

**5 Weight Matrices per LIConv2d:**
- `stream1_weight` (full kernel for RGB)
- `stream2_weight` (full kernel for Depth)
- `integrated_weight` (1×1 channel-wise)
- `integration_from_stream1` (1×1)
- `integration_from_stream2` (1×1)

This allows **neuron-level integration** rather than late fusion!
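To make the five weights concrete, here is a toy forward pass for the integrated stream at a single pixel. It assumes, without confirmation against the `LIConv2d` source, that the Stream1 and Stream2 convolution outputs at that pixel are already computed. It also assumes that the integrated output is the sum of three 1×1 linear maps:

- `integrated_weight` applied to the previous integrated features
- `integration_from_stream1` applied to the new Stream1 output
- `integration_from_stream2` applied to the new Stream2 output

A 1×1 convolution is a per-pixel matrix-vector product over channels, so plain lists are enough to show the arithmetic. Bias, normalization, and activation are omitted. `integrated_weight` is kept diagonal to reflect its "channel-wise" description.

```python
def matvec(matrix, vector):
    """A 1x1 convolution at one pixel is a matrix-vector product over channels."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]

def integrate_pixel(integrated_in, stream1_out, stream2_out,
                    integrated_weight, from_stream1, from_stream2):
    """Hypothetical integrated-stream output at one spatial location.

    integrated_in: previous integrated features
    stream1_out / stream2_out: this layer's Stream1 / Stream2 outputs
    """
    a = matvec(integrated_weight, integrated_in)
    b = matvec(from_stream1, stream1_out)
    c = matvec(from_stream2, stream2_out)
    return [x + y + z for x, y, z in zip(a, b, c)]

# Toy 2-channel example
integrated_in = [1.0, 2.0]
stream1_out = [0.5, 0.0]   # e.g. an RGB response
stream2_out = [0.0, 4.0]   # e.g. a depth response
integrated_weight = [[0.5, 0.0], [0.0, 0.5]]  # channel-wise: diagonal only
from_stream1 = [[2.0, 0.0], [0.0, 2.0]]
from_stream2 = [[0.25, 0.0], [0.0, 0.25]]

out = integrate_pixel(integrated_in, stream1_out, stream2_out,
                      integrated_weight, from_stream1, from_stream2)
print(out)  # [1.5, 2.0]
```

Each output channel mixes information from all three sources inside the same layer. That per-layer mixing is what distinguishes this design from concatenating features once at the end (late fusion).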

### 🔍 Next Steps:

1. **Review Results:**
   - Check training curves above
   - Review pathway analysis (3 streams)
   - Compare Stream1/Stream2/Integrated contributions
   - Analyze stream monitoring plots

2. **Download Results:**
   - All files are saved to your Google Drive
   - Download checkpoints for local inference

3. **Experiment:**
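   - Try ResNet50, which may improve accuracy (change `architecture` in Model Config)
   - Adjust stream-specific learning rates if monitoring shows imbalance
   - Train for more epochs if validation accuracy was still improving at the end; if early stopping triggered too soon, increase `patience`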
   - Analyze integration weights to understand learned strategies

4. **Deploy:**
   - Use the best model for inference
   - Test on new RGB-D images
   - Integrate into your application
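The summary report records the early-stopping settings (`monitor`, `patience`, `min_delta`, `restore_best_weights`), and these matter when deciding whether to train longer. The minimal sketch below shows the usual patience rule for a metric where higher is better, such as validation accuracy. The class `EarlyStopper` is hypothetical. The project's `fit()` may differ in details such as how `min_delta` is compared or whether loss is monitored instead.

```python
class EarlyStopper:
    """Stop when the monitored metric fails to improve by min_delta for `patience` epochs."""

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.best_epoch = None
        self.wait = 0

    def step(self, epoch, value):
        """Record one epoch's metric; return True if training should stop."""
        if value > self.best + self.min_delta:
            self.best, self.best_epoch, self.wait = value, epoch, 0
        else:
            self.wait += 1
        return self.wait >= self.patience

# Made-up validation accuracies: meaningful improvement stalls after epoch 3
val_acc = [0.50, 0.58, 0.63, 0.631, 0.62, 0.629]
stopper = EarlyStopper(patience=3, min_delta=0.005)
stopped_at = None
for epoch, acc in enumerate(val_acc, start=1):
    if stopper.step(epoch, acc):
        stopped_at = epoch
        break
print(f"Stopped at epoch {stopped_at}; best epoch {stopper.best_epoch} ({stopper.best:.3f})")
```

With `restore_best_weights` enabled, a trainer would reload the weights from `best_epoch`. If runs stop this way long before the epoch budget, raising `patience` lets training continue longer. Raising the epoch count alone would not.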

---

**Questions or issues?** Check the training summary and pathway analysis above!

In [None]:
# Print final summary
print("=" * 60)
print("FINAL SUMMARY - LINET")
print("=" * 60)
print(f"\n✅ Training Complete!")
print(f"\nFinal Validation Accuracy: {results['accuracy']*100:.2f}%")
print(f"Best Validation Accuracy: {max(history['val_accuracy'])*100:.2f}%")
print(f"\nStream1 Pathway (RGB): {pathway_analysis['accuracy']['stream1_only']*100:.2f}%")
print(f"Stream2 Pathway (Depth): {pathway_analysis['accuracy']['stream2_only']*100:.2f}%")
print(f"Integrated Pathway: {pathway_analysis['accuracy']['integrated_only']*100:.2f}%")
print(f"Combined (Full Model): {pathway_analysis['accuracy']['full_model']*100:.2f}%")
print(f"\nTotal Training Epochs: {len(history['train_loss'])}")
print(f"Total Parameters: {total_params:,}")
print(f"Integration Parameters: {integration_params:,}")
print(f"\nCheckpoints saved to: {checkpoint_dir}")
print("\n" + "=" * 60)
print("🎉 All done! Check Google Drive for saved models and results.")
print("=" * 60)