In [6]:
# Environment Setup
import os
import sys
from pathlib import Path

# Set up project root path: the nearest directory -- the current working
# directory itself or one of its ancestors -- that contains a `src/` package.
# If no such directory exists, fall back to the parent of cwd (the layout when
# this notebook is run from `notebooks/`).
search_start = Path.cwd()
project_root = next(
    (
        candidate
        for candidate in (search_start, *search_start.parents)
        if (candidate / "src").exists()
    ),
    search_start.parent,
)

print(f"Project root: {project_root}")
print(f"Current working directory: {os.getcwd()}")

# Put the project root first on sys.path so `src.*` imports resolve.
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))
    print(f"✅ Added {project_root} to sys.path")

# CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous, so device-side
# errors are reported at the call that caused them -- at a real throughput
# cost. It is read when CUDA initialises, so it must be set before torch
# touches the GPU (torch is first imported in the next cell).
# setdefault keeps a value already exported in the shell (e.g. "0" for
# full-speed runs) instead of silently overwriting it.
os.environ.setdefault('CUDA_LAUNCH_BLOCKING', '1')
print("✅ Environment setup complete!")

Project root: /Users/gclinger/Documents/projects/Multi-Stream-Neural-Networks
Current working directory: /Users/gclinger/Documents/projects/Multi-Stream-Neural-Networks/notebooks
✅ Environment setup complete!


In [None]:
# Import Libraries
print("📦 Importing libraries...")

# Core PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, TensorDataset

# Machine learning utilities
from sklearn.model_selection import train_test_split

# Visualization and analysis
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image

# Progress tracking
from tqdm import tqdm
import time
import json
from pathlib import Path

# Project imports
try:
    from src.data_utils.dataset_utils import load_cifar100_data, CIFAR100_FINE_LABELS
    from src.data_utils.rgb_to_rgbl import RGBtoRGBL
    from src.models.basic_multi_channel.base_multi_channel_network import base_multi_channel_large
    from src.models.basic_multi_channel.multi_channel_resnet_network import multi_channel_resnet50
    print("✅ All project modules imported successfully")
except ImportError as e:
    print(f"❌ Error importing project modules: {e}")
    print("⚠️  Please ensure you're running from the correct directory")

# Pick the compute device, preferring a CUDA GPU, then Apple's Metal (MPS)
# backend, and finally the CPU. The chosen banner is printed once below.
if torch.cuda.is_available():
    device = torch.device("cuda")
    device_banner = f"🚀 Using CUDA: {torch.cuda.get_device_name(0)}"
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    device_banner = "🚀 Using Apple Metal Performance Shaders (MPS)"
else:
    device = torch.device("cpu")
    device_banner = "💻 Using CPU"
print(device_banner)

# Summary of the runtime environment.
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {device}")
print("✅ Library imports complete!")

📦 Importing libraries...
✅ All project modules imported successfully
🚀 Using Apple Metal Performance Shaders (MPS)
PyTorch version: 2.7.1
Device: mps
✅ Library imports complete!


In [9]:
from torch.utils.data import Dataset, DataLoader, TensorDataset
from src.data_utils import load_cifar100_data
from src.models2.common.model_helpers import create_dataloader_from_tensors
from src.data_utils import RGBtoRGBL
from sklearn.model_selection import train_test_split
from src.models2.core.resnet import resnet50

# `device` is already selected (CUDA > MPS > CPU) in the library-import cell,
# which this cell depends on anyway (it is the only cell that imports `torch`).
# Re-running the same detection here duplicated that logic and could drift out
# of sync with it, so the existing selection is reused.
print(f"Device: {device}")

batch_size = 32
SEED = 42  # one seed shared by the train/val split and torch's global RNG

# Seed torch's global RNG so DataLoader shuffling (and, presumably, the model's
# weight initialisation in a later step) repeat across Restart & Run All. The
# sklearn split below is seeded separately via `random_state`.
torch.manual_seed(SEED)

train_data, train_labels, test_data, test_labels = load_cifar100_data(
    data_dir="../data/cifar-100",
    normalize=True,  # scale pixel values into the [0, 1] range
)

# Hold out 10% of the official training set for validation. Stratifying on the
# labels preserves each class's share in both splits; an unstratified random
# split leaves per-class validation counts uneven across the 100 classes.
# (`stratify` is evaluated before `train_labels` is rebound by the unpacking.)
train_color, val_color, train_labels, val_labels = train_test_split(
    train_data,
    train_labels,
    test_size=0.1,
    random_state=SEED,
    stratify=train_labels,
)
assert len(train_color) + len(val_color) == len(train_data), "split lost samples"

rgb_to_rgbl = RGBtoRGBL()
# Derive the single-channel brightness stream for every split.
train_brightness = rgb_to_rgbl.get_brightness(train_color)
val_brightness = rgb_to_rgbl.get_brightness(val_color)
test_brightness = rgb_to_rgbl.get_brightness(test_data)

print(f"Training samples: {len(train_color)}")
print(f"Validation samples: {len(val_color)}")
print(f"Test samples: {len(test_data)}")
print(f"Number of classes: {len(torch.unique(train_labels))}")
print(f"Data shape - RGB: {train_color.shape}, Brightness: {train_brightness.shape}")
print(f"Labels shape: {train_labels.shape}")
print(f"Data range - RGB: [{train_color.min():.3f}, {train_color.max():.3f}], Brightness: [{train_brightness.min():.3f}, {train_brightness.max():.3f}]")


# Create DataLoaders for ResNet50 training (RGB only)
print("Creating DataLoaders for ResNet50...")

# Only the color stream feeds this standard-ResNet baseline; the brightness
# tensors computed earlier are not part of these datasets.
train_dataset = TensorDataset(train_color, train_labels)
val_dataset = TensorDataset(val_color, val_labels)

NUM_WORKERS = 2
# Without persistent_workers, each loader starts fresh worker processes every
# time it is iterated, i.e. once per epoch for both loaders. Under the "spawn"
# start method (the macOS default) every worker is a new interpreter that must
# re-import torch, so keeping the workers alive avoids paying that per epoch.
# persistent_workers is only valid when num_workers > 0.
loader_options = dict(
    num_workers=NUM_WORKERS,
    pin_memory=False,  # pinned host memory only speeds up copies to CUDA
    persistent_workers=NUM_WORKERS > 0,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    **loader_options,
)

# Larger evaluation batches -- presumably validation runs without gradients,
# so the extra activation memory is affordable (TODO confirm against fit()).
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size * 2,
    shuffle=False,
    **loader_options,
)

print(f"Train loader: {len(train_loader)} batches")
print(f"Val loader: {len(val_loader)} batches")
print("DataLoaders created successfully!")


# Create and train the RGB-only ResNet50 baseline
EPOCHS = 20
EARLY_STOPPING_PATIENCE = 5  # epochs without val-loss improvement before stopping

print("Creating ResNet50 model...")
resnet50_baseline = resnet50(num_classes=100, device=str(device))

# Optimizer, loss and learning-rate schedule. The schedule is cosine
# annealing starting from the base learning rate (the LRs logged during
# training -- 0.000994, 0.000976, 0.000946 -- follow a cosine curve over
# EPOCHS); it is not a step schedule.
print("Compiling model with optimized settings...")
resnet50_baseline.compile(
    optimizer='adamw',
    loss='cross_entropy',
    learning_rate=0.001,  # base LR for AdamW
    weight_decay=1e-4,    # weight decay passed to the optimizer
    scheduler='cosine',   # cosine-annealed LR across epochs
)

print("Starting training...")
# With early stopping on, the progress log tracks the best validation loss
# (`best=`) and counts `patience` epochs without improvement; the log also
# reports restoring the best weights when training ends.
history = resnet50_baseline.fit(
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=EPOCHS,
    early_stopping=True,
    patience=EARLY_STOPPING_PATIENCE,
    verbose=True,
)

# NOTE: the "Final ..." figures are the last history entries, whereas the model
# object may now hold the restored best-epoch weights.
print("Training completed!")
print(f"Best validation accuracy: {max(history['val_accuracy']):.4f}")
print(f"Final train accuracy: {history['train_accuracy'][-1]:.4f}")
print(f"Final validation accuracy: {history['val_accuracy'][-1]:.4f}")

🚀 Using Apple Metal Performance Shaders (MPS)
📁 Loading CIFAR-100 from: ../data/cifar-100
✅ Loaded CIFAR-100 (torch format):
   Training: torch.Size([50000, 3, 32, 32]), labels: 50000
   Test: torch.Size([10000, 3, 32, 32]), labels: 10000
Training samples: 45000
Validation samples: 5000
Test samples: 10000
Number of classes: 100
Data shape - RGB: torch.Size([45000, 3, 32, 32]), Brightness: torch.Size([45000, 1, 32, 32])
Labels shape: torch.Size([45000])
Data range - RGB: [0.000, 1.000], Brightness: [0.000, 1.000]
Creating DataLoaders for ResNet50...
Train loader: 1407 batches
Val loader: 79 batches
DataLoaders created successfully!
Creating ResNet50 model...
✅ Loaded CIFAR-100 (torch format):
   Training: torch.Size([50000, 3, 32, 32]), labels: 50000
   Test: torch.Size([10000, 3, 32, 32]), labels: 10000
Training samples: 45000
Validation samples: 5000
Test samples: 10000
Number of classes: 100
Data shape - RGB: torch.Size([45000, 3, 32, 32]), Brightness: torch.Size([45000, 1, 32, 32])

Epoch 1/20: 100%|██████████| 1486/1486 [01:28<00:00, 16.79it/s, train_loss=4.7008, train_acc=0.0620, val_loss=5.8967, val_acc=0.0444, best=5.8967, lr=0.000994]
Epoch 1/20: 100%|██████████| 1486/1486 [01:28<00:00, 16.79it/s, train_loss=4.7008, train_acc=0.0620, val_loss=5.8967, val_acc=0.0444, best=5.8967, lr=0.000994]
Epoch 2/20: 100%|██████████| 1486/1486 [01:26<00:00, 17.10it/s, train_loss=4.1332, train_acc=0.0898, val_loss=6.6487, val_acc=0.1128, patience=1/5, lr=0.000976]
Epoch 2/20: 100%|██████████| 1486/1486 [01:26<00:00, 17.10it/s, train_loss=4.1332, train_acc=0.0898, val_loss=6.6487, val_acc=0.1128, patience=1/5, lr=0.000976]
Epoch 3/20: 100%|██████████| 1486/1486 [01:25<00:00, 17.31it/s, train_loss=3.6703, train_acc=0.1419, val_loss=3.8438, val_acc=0.1728, best=3.8438, lr=0.000946]
Epoch 3/20: 100%|██████████| 1486/1486 [01:25<00:00, 17.31it/s, train_loss=3.6703, train_acc=0.1419, val_loss=3.8438, val_acc=0.1728, best=3.8438, lr=0.000946]
Epoch 4/20: 100%|██████████| 1486/1486

🔄 Restored best model weights
Training completed!
Best validation accuracy: 0.3934
Final train accuracy: 0.6406
Final validation accuracy: 0.3934



