# Multi-Stream Neural Networks: CIFAR-100 Training

This notebook demonstrates the full pipeline for training multi-stream neural networks on CIFAR-100 data:

🚀 **Features:**
- Automatic GPU detection and optimization
- RGB to RGBL preprocessing with visualizations
- BaseMultiChannelNetwork (Dense) and MultiChannelResNetNetwork (CNN) models
- Dynamic progress bars during training
- Comprehensive evaluation and analysis

**Hardware Requirements:**
- Google Colab with GPU runtime (A100/V100 recommended)
- Sufficient memory for CIFAR-100 dataset processing

## 1. Clone Repository and Set Up Working Directory

First, we'll clone the Multi-Stream Neural Networks repository to our Google Drive and set up the working directory.

In [1]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [4]:
# Navigate to Drive and clone repository
import os
os.chdir('/content/drive/MyDrive')

# Clone the repository (replace with your actual repository URL)
!git clone https://github.com/clingergab/Multi-Stream-Neural-Networks.git

Cloning into 'Multi-Stream-Neural-Networks'...
remote: Enumerating objects: 228, done.[K
remote: Counting objects: 100% (228/228), done.[K
remote: Compressing objects: 100% (193/193), done.[K
remote: Total 228 (delta 33), reused 221 (delta 26), pack-reused 0 (from 0)[K
Receiving objects: 100% (228/228), 258.62 KiB | 2.94 MiB/s, done.
Resolving deltas: 100% (33/33), done.


In [5]:


# Change to the project directory
os.chdir('/content/drive/MyDrive/Multi-Stream-Neural-Networks')

# Verify we're in the right directory
print("Current working directory:", os.getcwd())
print("\nDirectory contents:")
!ls -la

Current working directory: /content/drive/MyDrive/Multi-Stream-Neural-Networks

Directory contents:
total 169
drwx------ 5 root root  4096 Jun 26 02:28 archive
-rw------- 1 root root  3785 Jun 26 02:28 cleanup_comprehensive.py
-rw------- 1 root root  5595 Jun 26 02:28 cleanup_empty_files.py
drwx------ 5 root root  4096 Jun 26 02:28 configs
-rw------- 1 root root 94768 Jun 26 02:28 DESIGN.md
drwx------ 2 root root  4096 Jun 26 02:28 docs
drwx------ 2 root root  4096 Jun 26 02:28 examples
drwx------ 3 root root  4096 Jun 26 02:28 experiments
drwx------ 8 root root  4096 Jun 26 02:28 .git
-rw------- 1 root root   587 Jun 26 02:28 .gitignore
-rw------- 1 root root  1084 Jun 26 02:28 LICENSE
drwx------ 2 root root  4096 Jun 26 02:28 notebooks
-rw------- 1 root root  4477 Jun 26 02:28 README.md
-rw------- 1 root root   106 Jun 26 02:28 requirements.txt
-rw------- 1 root root  4892 Jun 26 02:28 safe_cleanup_empty_files.py
drwx------ 3 root root  4096 Jun 26 02:28 scripts
-rw------- 1 root roo

## 2. Install and Import Required Libraries

Install any missing dependencies and import all necessary libraries for the multi-stream neural network training.

In [6]:
# Install required packages
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install tqdm matplotlib seaborn scikit-learn

# System and utility imports
import sys
import os
import warnings
warnings.filterwarnings('ignore')

# Add project to Python path
sys.path.append('/content/drive/MyDrive/Multi-Stream-Neural-Networks')

# Core libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torchvision
import torchvision.transforms as transforms

# Data and visualization
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
import pickle
from typing import Tuple, Dict, List

# Set style for better plots
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

print("📦 All libraries imported successfully!")
print(f"🔥 PyTorch version: {torch.__version__}")
print(f"🚀 CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

Looking in indexes: https://download.pytorch.org/whl/cu118
INFO: pip is looking at multiple versions of torch to determine which version is compatible with other requirements. This could take a while.
Collecting torch
  Downloading https://download.pytorch.org/whl/cu118/torch-2.7.1%2Bcu118-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (28 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading https://download.pytorch.org/whl/sympy-1.13.3-py3-none-any.whl.metadata (12 kB)
Collecting nvidia-cuda-nvrtc-cu11==11.8.89 (from torch)
  Downloading https://download.pytorch.org/whl/cu118/nvidia_cuda_nvrtc_cu11-11.8.89-py3-none-manylinux1_x86_64.whl (23.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.2/23.2 MB[0m [31m91.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-runtime-cu11==11.8.89 (from torch)
  Downloading https://download.pytorch.org/whl/cu118/nvidia_cuda_runtime_cu11-11.8.89-py3-none-manylinux1_x86_64.whl (875 kB)
[2K     [90m━━━━━━━━━━━━━━━━

RuntimeError: Detected that PyTorch and torchvision were compiled with different CUDA major versions. PyTorch has CUDA Version=11.8 and torchvision has CUDA Version=12.4. Please reinstall the torchvision that matches your PyTorch install.

## 3. Load CIFAR-100 Dataset

Load the CIFAR-100 dataset from the data folder. We assume the data folder structure matches the repository structure.

In [None]:
# Import our data loading utilities
from src.utils.colab_utils import load_cifar10  # We'll adapt this for CIFAR-100

# Check if data folder exists
data_path = "data/cifar-100"
if os.path.exists(data_path):
    print(f"✅ Data folder found at: {data_path}")
else:
    print(f"❌ Data folder not found. Creating data structure...")
    os.makedirs(data_path, exist_ok=True)
    print("📁 Please manually upload CIFAR-100 data to the data folder")

# Define CIFAR-100 classes for reference
cifar100_fine_labels = [
    'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
    'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
    'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
    'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
    'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
    'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
    'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
    'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
    'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
    'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose',
    'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
    'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
    'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout',
    'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman',
    'worm'
]

# Load CIFAR-100 data using torchvision (fallback if data folder is empty)
def load_cifar100_data():
    # Transform to convert PIL images to tensors
    transform = transforms.Compose([
        transforms.ToTensor()
    ])

    try:
        # Try loading from local data folder first
        train_dataset = torchvision.datasets.CIFAR100(
            root='./data', train=True, download=False, transform=transform
        )
        test_dataset = torchvision.datasets.CIFAR100(
            root='./data', train=False, download=False, transform=transform
        )
        print("✅ Loaded CIFAR-100 from local data folder")
    except:
        # Download if not available locally
        print("⬇️ Downloading CIFAR-100 dataset...")
        train_dataset = torchvision.datasets.CIFAR100(
            root='./data', train=True, download=True, transform=transform
        )
        test_dataset = torchvision.datasets.CIFAR100(
            root='./data', train=False, download=True, transform=transform
        )
        print("✅ CIFAR-100 dataset downloaded and loaded")

    return train_dataset, test_dataset

# Load the datasets
train_dataset, test_dataset = load_cifar100_data()

print(f"📊 Dataset Info:")
print(f"   Training samples: {len(train_dataset)}")
print(f"   Test samples: {len(test_dataset)}")
print(f"   Number of classes: 100")
print(f"   Image size: 32x32x3")

## 4. Preprocess Data: RGB to RGBL Transformation

Apply preprocessing to convert RGB images to both RGB and brightness (luminance) channels. This creates our multi-stream data.

In [None]:
# Import our RGB to RGBL transformation
from src.transforms.rgb_to_rgbl import RGBtoRGBL

def convert_dataset_to_multi_stream(dataset, max_samples=None):
    """
    Convert a CIFAR-100 dataset to multi-stream format (RGB + Brightness).

    Args:
        dataset: CIFAR-100 dataset
        max_samples: Maximum number of samples to process (for faster testing)

    Returns:
        rgb_data: RGB channel data [N, 3, 32, 32]
        brightness_data: Brightness channel data [N, 1, 32, 32]
        labels: Class labels [N]
    """
    print(f"🔄 Converting dataset to multi-stream format...")

    # Initialize RGB to RGBL transform
    rgb_to_rgbl = RGBtoRGBL()

    # Determine number of samples to process
    num_samples = len(dataset) if max_samples is None else min(max_samples, len(dataset))

    # Initialize arrays
    rgb_data = []
    brightness_data = []
    labels = []

    # Process samples with progress bar
    for i in tqdm(range(num_samples), desc="Processing images"):
        image, label = dataset[i]

        # Convert to RGBL
        rgbl_image = rgb_to_rgbl(image)

        # Split RGB and brightness channels
        rgb_channels = rgbl_image[:3]  # First 3 channels (RGB)
        brightness_channel = rgbl_image[3:4]  # Last channel (Brightness)

        rgb_data.append(rgb_channels)
        brightness_data.append(brightness_channel)
        labels.append(label)

    # Convert to numpy arrays
    rgb_data = torch.stack(rgb_data).numpy()
    brightness_data = torch.stack(brightness_data).numpy()
    labels = np.array(labels)

    print(f"✅ Conversion complete!")
    print(f"   RGB data shape: {rgb_data.shape}")
    print(f"   Brightness data shape: {brightness_data.shape}")
    print(f"   Labels shape: {labels.shape}")

    return rgb_data, brightness_data, labels

# Convert training data (use subset for faster processing in demo)
print("🚀 Processing training data...")
train_rgb, train_brightness, train_labels = convert_dataset_to_multi_stream(
    train_dataset, max_samples=5000  # Reduce for faster demo
)

# Convert test data (use subset for faster processing in demo)
print("\n🧪 Processing test data...")
test_rgb, test_brightness, test_labels = convert_dataset_to_multi_stream(
    test_dataset, max_samples=1000  # Reduce for faster demo
)

print(f"\n📊 Final Dataset Shapes:")
print(f"   Training RGB: {train_rgb.shape}")
print(f"   Training Brightness: {train_brightness.shape}")
print(f"   Training Labels: {train_labels.shape}")
print(f"   Test RGB: {test_rgb.shape}")
print(f"   Test Brightness: {test_brightness.shape}")
print(f"   Test Labels: {test_labels.shape}")

## 5. Visualize Sample Images: RGB and Brightness Side by Side

Display sample images showing the original RGB and extracted brightness channels to understand the multi-stream transformation.

In [None]:
def visualize_rgb_brightness_samples(rgb_data, brightness_data, labels, num_samples=5):
    """
    Visualize RGB and brightness images side by side.

    Args:
        rgb_data: RGB image data [N, 3, H, W]
        brightness_data: Brightness image data [N, 1, H, W]
        labels: Image labels
        num_samples: Number of samples to visualize
    """
    fig, axes = plt.subplots(num_samples, 2, figsize=(8, 2.5 * num_samples))
    fig.suptitle('RGB vs Brightness Channel Comparison', fontsize=16, fontweight='bold')

    for i in range(num_samples):
        # Get RGB image (convert from CHW to HWC for matplotlib)
        rgb_img = np.transpose(rgb_data[i], (1, 2, 0))

        # Get brightness image (squeeze channel dimension)
        brightness_img = brightness_data[i, 0]  # Remove channel dimension

        # Get class name
        class_name = cifar100_fine_labels[labels[i]]

        # Plot RGB image
        axes[i, 0].imshow(rgb_img)
        axes[i, 0].set_title(f'RGB - {class_name}', fontweight='bold')
        axes[i, 0].axis('off')

        # Plot brightness image
        axes[i, 1].imshow(brightness_img, cmap='gray')
        axes[i, 1].set_title(f'Brightness - {class_name}', fontweight='bold')
        axes[i, 1].axis('off')

    plt.tight_layout()
    plt.show()

# Visualize sample images
print("🖼️ Sample RGB vs Brightness Images:")
visualize_rgb_brightness_samples(train_rgb, train_brightness, train_labels, num_samples=5)

# Show data statistics
def show_data_statistics(rgb_data, brightness_data, labels):
    """Show basic statistics about the data."""
    print(f"\n📊 Data Statistics:")
    print(f"   RGB data range: [{rgb_data.min():.3f}, {rgb_data.max():.3f}]")
    print(f"   Brightness data range: [{brightness_data.min():.3f}, {brightness_data.max():.3f}]")
    print(f"   Number of unique classes: {len(np.unique(labels))}")

    # Class distribution
    unique_labels, counts = np.unique(labels, return_counts=True)
    print(f"   Samples per class: {counts.min()} - {counts.max()}")
    print(f"   Average samples per class: {counts.mean():.1f}")

show_data_statistics(train_rgb, train_brightness, train_labels)

## 6. Additional Data Visualizations

Let's explore the data with helpful visualizations including class distribution and pixel intensity analysis.

In [None]:
# Class distribution visualization
def plot_class_distribution(labels, title="Class Distribution"):
    """Plot the distribution of classes in the dataset."""
    plt.figure(figsize=(12, 6))
    unique_labels, counts = np.unique(labels, return_counts=True)

    plt.bar(unique_labels, counts, alpha=0.7, color='skyblue', edgecolor='navy')
    plt.title(title, fontsize=14, fontweight='bold')
    plt.xlabel('Class ID')
    plt.ylabel('Number of Samples')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

# Pixel intensity histograms
def plot_intensity_histograms(rgb_data, brightness_data):
    """Plot histograms of pixel intensities for RGB and brightness channels."""
    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    fig.suptitle('Pixel Intensity Distributions', fontsize=16, fontweight='bold')

    # RGB histograms
    colors = ['red', 'green', 'blue']
    for i, color in enumerate(colors):
        axes[0, 0].hist(rgb_data[:, i].flatten(), bins=50, alpha=0.6,
                       color=color, label=f'{color.upper()} channel')
    axes[0, 0].set_title('RGB Channel Intensities')
    axes[0, 0].set_xlabel('Pixel Value')
    axes[0, 0].set_ylabel('Frequency')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    # Brightness histogram
    axes[0, 1].hist(brightness_data.flatten(), bins=50, alpha=0.7,
                   color='gray', edgecolor='black')
    axes[0, 1].set_title('Brightness Channel Intensities')
    axes[0, 1].set_xlabel('Pixel Value')
    axes[0, 1].set_ylabel('Frequency')
    axes[0, 1].grid(True, alpha=0.3)

    # Mean pixel values per channel
    rgb_means = np.mean(rgb_data, axis=(0, 2, 3))
    brightness_mean = np.mean(brightness_data)

    channel_names = ['Red', 'Green', 'Blue', 'Brightness']
    channel_means = [rgb_means[0], rgb_means[1], rgb_means[2], brightness_mean]

    axes[1, 0].bar(channel_names, channel_means,
                  color=['red', 'green', 'blue', 'gray'], alpha=0.7)
    axes[1, 0].set_title('Mean Pixel Values by Channel')
    axes[1, 0].set_ylabel('Mean Pixel Value')
    axes[1, 0].grid(True, alpha=0.3)

    # Sample grid
    axes[1, 1].axis('off')

    plt.tight_layout()
    plt.show()

# Sample grid of images
def plot_sample_grid(rgb_data, labels, grid_size=(4, 8)):
    """Plot a grid of sample images."""
    fig, axes = plt.subplots(grid_size[0], grid_size[1], figsize=(16, 8))
    fig.suptitle('Sample Images from CIFAR-100 Dataset', fontsize=16, fontweight='bold')

    for i in range(grid_size[0]):
        for j in range(grid_size[1]):
            idx = i * grid_size[1] + j
            if idx < len(rgb_data):
                img = np.transpose(rgb_data[idx], (1, 2, 0))
                class_name = cifar100_fine_labels[labels[idx]]

                axes[i, j].imshow(img)
                axes[i, j].set_title(class_name, fontsize=8)
                axes[i, j].axis('off')

    plt.tight_layout()
    plt.show()

# Generate visualizations
print("📊 Generating additional visualizations...")

# Class distribution
plot_class_distribution(train_labels, "Training Set Class Distribution")

# Intensity histograms
plot_intensity_histograms(train_rgb[:1000], train_brightness[:1000])  # Sample for speed

# Sample grid
plot_sample_grid(train_rgb, train_labels)

## 7. Create Multi-Stream Neural Network Models

Now we'll create both the BaseMultiChannelNetwork (dense) and MultiChannelResNetNetwork (CNN) models for comparison.