# Multi-Stream Neural Networks: CIFAR-100 Training

This notebook demonstrates the full pipeline for training multi-stream neural networks on CIFAR-100 data:

🚀 **Features:**
- Automatic GPU detection and optimization
- RGB to RGBL preprocessing with visualizations
- BaseMultiChannelNetwork (Dense) and MultiChannelResNetNetwork (CNN) models
- Dynamic progress bars during training
- Comprehensive evaluation and analysis

**Hardware Requirements:**
- Google Colab with GPU runtime (A100/V100 recommended)
- Enough RAM to hold CIFAR-100 and the processed arrays in memory (this demo converts a 5,000-image training subset and a 1,000-image test subset)

## 1. Setup: Mount Drive and Navigate to Project

Mount Google Drive and navigate to the existing Multi-Stream Neural Networks project directory.

In [1]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
# Navigate to Drive and project directory
import os
os.chdir('/content/drive/MyDrive')

# Navigate to the existing project (assuming it's already cloned)
project_path = '/content/drive/MyDrive/Multi-Stream-Neural-Networks'
if os.path.exists(project_path):
    os.chdir(project_path)
    print(f"✅ Found project at: {project_path}")
else:
    print(f"❌ Project not found at: {project_path}")
    print("💡 Please clone the repository first:")
    print("   !git clone https://github.com/clingergab/Multi-Stream-Neural-Networks.git")

Cloning into 'Multi-Stream-Neural-Networks'...
remote: Enumerating objects: 228, done.[K
remote: Counting objects: 100% (228/228), done.[K
remote: Compressing objects: 100% (193/193), done.[K
remote: Total 228 (delta 33), reused 221 (delta 26), pack-reused 0 (from 0)[K
Receiving objects: 100% (228/228), 258.62 KiB | 2.94 MiB/s, done.
Resolving deltas: 100% (33/33), done.


## 2. Update Repository

Pull the latest changes from the repository to ensure you have the most up-to-date code.

In [None]:
# Update repository with latest changes
print("🔄 Pulling latest changes from repository...")

# Make sure we're in the right directory
os.chdir('/content/drive/MyDrive/Multi-Stream-Neural-Networks')
print(f"📁 Current directory: {os.getcwd()}")

# Pull latest changes
!git pull origin main

# Show latest commit info
print("\n📋 Latest commit:")
!git log --oneline -1

# Check status
print("\n📊 Repository status:")
!git status --short

print("\n✅ Repository update complete!")

## 3. Install Dependencies and Import Libraries

Install compatible PyTorch/NumPy versions and import all required libraries.

In [None]:
# Install compatible PyTorch, NumPy, and dependencies
print("🔧 Installing compatible PyTorch, NumPy, and dependencies...")

# Check current environment
import sys
print(f"Python version: {sys.version}")

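# Pin torch, torchvision, and torchaudio to one matching release and CUDA build;
# mixing builds triggers the CUDA-version RuntimeError shown in the output below.
# If any of these packages was already imported in this session, restart the runtime after installing.
print("📦 Installing packages...")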
!pip install -q numpy==1.24.3
!pip install -q torch==2.1.0+cu118 torchvision==0.16.0+cu118 torchaudio==2.1.0+cu118 --index-url https://download.pytorch.org/whl/cu118
!pip install -q tqdm matplotlib seaborn scikit-learn

print("✅ Installation complete!")

# Import all required libraries
print("\n📦 Importing libraries...")

import os
import warnings
warnings.filterwarnings('ignore')

# Add project to Python path
sys.path.append('/content/drive/MyDrive/Multi-Stream-Neural-Networks')

# Core libraries
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torchvision
import torchvision.transforms as transforms

# Data and visualization
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
import pickle
from typing import Tuple, Dict, List
import time
from pathlib import Path

# Set style for better plots
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Verify installations
print("🔍 Verifying installations...")
print(f"   NumPy: {np.__version__}")
print(f"   PyTorch: {torch.__version__}")
print(f"   Torchvision: {torchvision.__version__}")
print(f"   CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Import project modules
print("\n📁 Importing project modules...")
try:
    from src.models.basic_multi_channel.base_multi_channel_network import BaseMultiChannelNetwork
    from src.models.basic_multi_channel.multi_channel_resnet_network import MultiChannelResNetNetwork
    from src.transforms.rgb_to_rgbl import RGBtoRGBL
    from src.utils.cifar100_loader import get_cifar100_datasets, CIFAR100_FINE_LABELS, SimpleDataset
    print("✅ All imports successful!")
except ImportError as e:
    print(f"⚠️ Import warning: {e}")
    print("   Make sure you've updated the repository in the previous step")

Looking in indexes: https://download.pytorch.org/whl/cu118
INFO: pip is looking at multiple versions of torch to determine which version is compatible with other requirements. This could take a while.
Collecting torch
  Downloading https://download.pytorch.org/whl/cu118/torch-2.7.1%2Bcu118-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (28 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading https://download.pytorch.org/whl/sympy-1.13.3-py3-none-any.whl.metadata (12 kB)
Collecting nvidia-cuda-nvrtc-cu11==11.8.89 (from torch)
  Downloading https://download.pytorch.org/whl/cu118/nvidia_cuda_nvrtc_cu11-11.8.89-py3-none-manylinux1_x86_64.whl (23.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.2/23.2 MB[0m [31m91.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-runtime-cu11==11.8.89 (from torch)
  Downloading https://download.pytorch.org/whl/cu118/nvidia_cuda_runtime_cu11-11.8.89-py3-none-manylinux1_x86_64.whl (875 kB)
[2K     [90m━━━━━━━━━━━━━━━━

RuntimeError: Detected that PyTorch and torchvision were compiled with different CUDA major versions. PyTorch has CUDA Version=11.8 and torchvision has CUDA Version=12.4. Please reinstall the torchvision that matches your PyTorch install.
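The RuntimeError above means the installed torch and torchvision wheels target different CUDA versions (here 11.8 vs 12.4). As a minimal sketch for diagnosing this, the cell below reads the installed version strings from package metadata, without importing torchvision (importing a mismatched torchvision is what raises the error), and compares their `+cuXXX` build tags. The helper names `build_tag`, `builds_match`, and `report_installed_builds` are hypothetical, and the sketch assumes wheels from download.pytorch.org carry such a tag; plain PyPI wheels may not, in which case it reports `None`.

```python
from importlib import metadata
from typing import Optional


def build_tag(version: str) -> Optional[str]:
    """Return the local build tag of a wheel version, e.g. '2.1.0+cu118' -> 'cu118'."""
    _, sep, local = version.partition("+")
    return local if sep else None


def builds_match(torch_version: str, torchvision_version: str) -> Optional[bool]:
    """True/False when both versions carry a build tag; None when either tag is missing."""
    torch_tag, vision_tag = build_tag(torch_version), build_tag(torchvision_version)
    if torch_tag is None or vision_tag is None:
        return None  # e.g. plain PyPI wheels: the CUDA build is not encoded in the string
    return torch_tag == vision_tag


def report_installed_builds() -> None:
    """Print installed torch/torchvision versions using metadata only (no imports)."""
    try:
        torch_v = metadata.version("torch")
        vision_v = metadata.version("torchvision")
    except metadata.PackageNotFoundError as exc:
        print(f"Package not installed: {exc}")
        return
    print(f"torch {torch_v} | torchvision {vision_v} | builds match: {builds_match(torch_v, vision_v)}")
```

Calling `report_installed_builds()` after the install step should report `True` for the pinned pair above. A matching tag only covers the CUDA build; the release pairing (torch 2.1.0 with torchvision 0.16.0, as pinned) matters as well.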

## 4. Load CIFAR-100 Dataset

Load the CIFAR-100 dataset and verify its structure.

In [None]:
# Import our CIFAR-100 data loading utilities
print("📁 Setting up CIFAR-100 dataset loading...")

# Import the CIFAR-100 loader utilities
try:
    from src.utils.cifar100_loader import get_cifar100_datasets, CIFAR100_FINE_LABELS
    print("✅ CIFAR-100 loader utilities imported successfully")
except ImportError as e:
    print(f"❌ Failed to import CIFAR-100 utilities: {e}")
    print("💡 Make sure you're in the correct directory and have run git pull")
    raise

# Check if data folder exists
data_path = Path("data/cifar-100")
if data_path.exists():
    print(f"✅ Data folder found at: {data_path}")
else:
    print(f"📁 Creating data structure at: {data_path}")
    data_path.mkdir(parents=True, exist_ok=True)

# Load the datasets using our utility
print("📊 Loading CIFAR-100 datasets...")
train_dataset, test_dataset, cifar100_fine_labels = get_cifar100_datasets()

# Get raw data for backward compatibility with existing code
train_data = train_dataset.data
train_labels = train_dataset.labels
test_data = test_dataset.data
test_labels = test_dataset.labels

print(f"\n📊 Dataset Info:")
print(f"   Training samples: {len(train_data)}")
print(f"   Test samples: {len(test_data)}")
print(f"   Image shape: {train_data[0].shape}")
print(f"   Number of classes: {len(cifar100_fine_labels)}")
print(f"   Label format: Single integer (fine labels 0-99)")
print(f"   Data range: [{train_data.min():.3f}, {train_data.max():.3f}]")

print("✅ CIFAR-100 datasets ready for processing!")
print("💡 Loaded directly from the CIFAR-100 pickle files; torchvision's dataset folder layout is not required.")
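The loader in `src/utils/cifar100_loader.py` is not shown here. As a rough sketch of what reading the pickle files directly involves, assuming the python-version archive from the official CIFAR site (split files `train` and `test` plus a `meta` file), the functions below decode one split and the fine class names. The function names, the scaling to `[0, 1]` floats, and any file paths are assumptions, not necessarily what the repository does.

```python
import pickle
from pathlib import Path
from typing import List, Tuple

import numpy as np


def load_cifar100_split(path: Path) -> Tuple[np.ndarray, np.ndarray]:
    """Read one split file ('train' or 'test') of the python-version CIFAR-100 archive."""
    with open(path, "rb") as f:
        batch = pickle.load(f, encoding="bytes")  # dict keys are bytes, e.g. b"data"
    # Each row holds 3072 uint8 values: 1024 red, 1024 green, 1024 blue,
    # each plane a row-major 32x32 image, so reshape to [N, 3, 32, 32].
    images = batch[b"data"].reshape(-1, 3, 32, 32).astype(np.float32) / 255.0
    labels = np.asarray(batch[b"fine_labels"], dtype=np.int64)  # integers 0-99
    return images, labels


def load_fine_label_names(meta_path: Path) -> List[str]:
    """Read the 100 fine class names from the archive's 'meta' file."""
    with open(meta_path, "rb") as f:
        meta = pickle.load(f, encoding="bytes")
    return [name.decode("utf-8") for name in meta[b"fine_label_names"]]
```

For the full training split (a hypothetical path such as `cifar-100-python/train`), this would return arrays shaped `[50000, 3, 32, 32]` and `[50000]`.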

## 5. Process Data: RGB to Multi-Stream Format

Convert RGB images to both RGB and brightness (luminance) channels to create multi-stream data.
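The repository's `RGBtoRGBL` transform is not shown, so its exact brightness formula is unknown. The sketch below assumes the ITU-R BT.601 luma weights (0.299 R + 0.587 G + 0.114 B), appends that plane as a fourth channel, and then splits the result the same way the processing cell does (`[:3]` for RGB, `[3:4]` for brightness). It works on NumPy arrays rather than tensors, and the name `rgb_to_rgbl_bt601` is hypothetical.

```python
import numpy as np

# Assumption: BT.601 luma weights; the repository's RGBtoRGBL may use different coefficients.
BT601_WEIGHTS = np.array([0.299, 0.587, 0.114], dtype=np.float32)


def rgb_to_rgbl_bt601(image) -> np.ndarray:
    """Append a luminance channel to a [3, H, W] RGB image, returning [4, H, W]."""
    rgb = np.asarray(image, dtype=np.float32)
    if rgb.ndim != 3 or rgb.shape[0] != 3:
        raise ValueError(f"expected a [3, H, W] image, got shape {rgb.shape}")
    # Weighted sum over the channel axis: [3] x [3, H, W] -> [H, W]
    luminance = np.tensordot(BT601_WEIGHTS, rgb, axes=([0], [0]))
    return np.concatenate([rgb, luminance[np.newaxis]], axis=0)


# Split into the two streams exactly as the processing cell does
rgbl = rgb_to_rgbl_bt601(np.random.rand(3, 32, 32))
rgb_stream, brightness_stream = rgbl[:3], rgbl[3:4]  # [3, 32, 32] and [1, 32, 32]
```

Because the weights sum to 1, a uniform gray pixel keeps its value in the brightness channel, which makes a quick sanity check for whichever formula the real transform uses.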

In [None]:
# RGB to RGBL Data Processing with caching support
print("🔄 Starting RGB to RGBL data processing...")

# Import the transformation (should work if previous cells succeeded)
try:
    from src.transforms.rgb_to_rgbl import RGBtoRGBL
    print("✅ RGBtoRGBL transform imported successfully")
except ImportError as e:
    print(f"❌ Failed to import RGBtoRGBL: {e}")
    print("💡 Make sure you're in the correct directory and have run git pull")
    raise

def convert_dataset_to_multi_stream(dataset, max_samples=None, cache_file=None, force_reprocess=False):
    """
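    Convert a CIFAR-100 dataset to multi-stream format (RGB + Brightness).
    The final result can be cached. Checkpoints are written every 1000 samples
    so partial work survives an interruption, but they are not resumed automatically.

    Args:
        dataset: CIFAR-100 dataset returning (image_tensor, integer_label) pairs
        max_samples: Maximum number of samples to process (for faster testing)
        cache_file: Path of the pickle cache; checkpoints go to f"{cache_file}.checkpoint"
            and results of an interrupted run to f"{cache_file}.partial"
        force_reprocess: Ignore an existing cache and reprocess. The cache does not
            record max_samples, so set this after changing max_samples.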

    Returns:
        rgb_data: RGB channel data [N, 3, 32, 32]
        brightness_data: Brightness channel data [N, 1, 32, 32]
        labels: Class labels [N]
    """
    # Check for cached results first
    if cache_file and Path(cache_file).exists() and not force_reprocess:
        print(f"✅ Loading cached data from {cache_file}")
        try:
            with open(cache_file, 'rb') as f:
                cached_data = pickle.load(f)
            print(f"   Loaded {len(cached_data['labels'])} samples from cache")
            return cached_data['rgb_data'], cached_data['brightness_data'], cached_data['labels']
        except Exception as e:
            print(f"⚠️ Failed to load cache: {e}. Reprocessing...")

    print(f"🔄 Converting dataset to multi-stream format...")

    # Initialize RGB to RGBL transform
    rgb_to_rgbl = RGBtoRGBL()
    print("✅ RGB to RGBL transform initialized")

    # Determine number of samples to process
    num_samples = len(dataset) if max_samples is None else min(max_samples, len(dataset))
    print(f"   Processing {num_samples} samples...")

    # Initialize arrays
    rgb_data = []
    brightness_data = []
    labels = []

    # Process samples with progress bar and checkpoint saving
    try:
        with tqdm(range(num_samples), desc="Processing images") as pbar:
            for i in pbar:
                try:
                    # Get image and label - CIFAR-100 returns (image, single_integer_label)
                    image, label = dataset[i]
                    
                    # Convert label to int if it's a tensor
                    if hasattr(label, 'item'):
                        label = label.item()
                    
                    # Convert to RGBL
                    rgbl_image = rgb_to_rgbl(image)

                    # Split RGB and brightness channels
                    rgb_channels = rgbl_image[:3]  # First 3 channels (RGB)
                    brightness_channel = rgbl_image[3:4]  # Last channel (Brightness)

                    rgb_data.append(rgb_channels)
                    brightness_data.append(brightness_channel)
                    labels.append(label)

                    # Update progress bar
                    pbar.set_postfix({'processed': len(labels)})

                    # Save checkpoint every 1000 samples
                    if cache_file and (i + 1) % 1000 == 0:
                        checkpoint_data = {
                            'rgb_data': torch.stack(rgb_data).numpy(),
                            'brightness_data': torch.stack(brightness_data).numpy(),
                            'labels': np.array(labels),
                            'processed': i + 1
                        }
                        with open(f"{cache_file}.checkpoint", 'wb') as f:
                            pickle.dump(checkpoint_data, f)
                        print(f"   Checkpoint saved at sample {i + 1}")

                except Exception as e:
                    print(f"⚠️ Error processing sample {i}: {e}")
                    continue

    except KeyboardInterrupt:
        print(f"\n⚠️ Processing interrupted at sample {len(labels)}")
        if len(labels) > 0:
            print("💾 Saving partial results...")
            # Return what we have so far
            rgb_data = torch.stack(rgb_data).numpy()
            brightness_data = torch.stack(brightness_data).numpy()
            labels = np.array(labels)
            
            if cache_file:
                partial_data = {
                    'rgb_data': rgb_data,
                    'brightness_data': brightness_data,
                    'labels': labels,
                    'processed': len(labels)
                }
                with open(f"{cache_file}.partial", 'wb') as f:
                    pickle.dump(partial_data, f)
                print(f"   Partial results saved to {cache_file}.partial")
            
            return rgb_data, brightness_data, labels
        else:
            raise

    # Convert to numpy arrays
    if len(rgb_data) == 0:
        raise ValueError("No data was processed successfully")
        
    rgb_data = torch.stack(rgb_data).numpy()
    brightness_data = torch.stack(brightness_data).numpy()
    labels = np.array(labels)

    # Save to cache if specified
    if cache_file:
        final_data = {
            'rgb_data': rgb_data,
            'brightness_data': brightness_data,
            'labels': labels
        }
        with open(cache_file, 'wb') as f:
            pickle.dump(final_data, f)
        print(f"💾 Results cached to {cache_file}")

    print(f"✅ Conversion complete!")
    print(f"   RGB data shape: {rgb_data.shape}")
    print(f"   Brightness data shape: {brightness_data.shape}")
    print(f"   Labels shape: {labels.shape}")

    return rgb_data, brightness_data, labels

# Process training data
print("\n🚀 Processing training data...")
train_rgb, train_brightness, train_labels = convert_dataset_to_multi_stream(
    train_dataset, 
    max_samples=5000,  # Reduce for faster demo
    cache_file="train_processed.pkl"
)

# Process test data
print("\n🧪 Processing test data...")
test_rgb, test_brightness, test_labels = convert_dataset_to_multi_stream(
    test_dataset, 
    max_samples=1000,  # Reduce for faster demo
    cache_file="test_processed.pkl"
)

print(f"\n📊 Final Dataset Shapes:")
print(f"   Training RGB: {train_rgb.shape}")
print(f"   Training Brightness: {train_brightness.shape}")
print(f"   Training Labels: {train_labels.shape}")
print(f"   Test RGB: {test_rgb.shape}")
print(f"   Test Brightness: {test_brightness.shape}")
print(f"   Test Labels: {test_labels.shape}")

print("\n💡 Note: Processed data is cached. To reprocess (e.g. after changing max_samples), pass force_reprocess=True or delete the .pkl files.")
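Two caveats of the caching above: the cache file name does not encode `max_samples`, so changing the sample limit silently reloads the old arrays, and a pickle interrupted mid-write leaves a truncated file. A minimal sketch addressing both, with hypothetical helpers `cache_path_for` and `atomic_pickle_dump`: the file name encodes the sample limit, and each write goes to a temporary file that is then renamed over the target with `os.replace`. The rename is atomic on a local POSIX filesystem; whether the Google Drive mount in Colab gives the same guarantee is an open assumption.

```python
import os
import pickle
import tempfile
from pathlib import Path
from typing import Any, Optional


def cache_path_for(split: str, max_samples: Optional[int], cache_dir: str = ".") -> Path:
    """Hypothetical helper: encode the sample limit in the cache file name."""
    suffix = "all" if max_samples is None else str(max_samples)
    return Path(cache_dir) / f"{split}_processed_{suffix}.pkl"


def atomic_pickle_dump(obj: Any, path: Path) -> None:
    """Pickle to a temp file in the target directory, then rename it over the target."""
    path = Path(path)
    fd, tmp_name = tempfile.mkstemp(dir=str(path.parent), suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(obj, f)
        os.replace(tmp_name, path)  # readers see either the old file or the complete new one
    except BaseException:
        # Remove the temp file if pickling or renaming failed (including KeyboardInterrupt)
        if os.path.exists(tmp_name):
            os.remove(tmp_name)
        raise
```

You could pass `cache_file=str(cache_path_for("train", 5000))` to `convert_dataset_to_multi_stream`; its internal `pickle.dump` calls would still need to be swapped for `atomic_pickle_dump` to get the atomic behavior.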

In [None]:
# Verify CIFAR-100 data structure
print("🔍 Verifying CIFAR-100 data structure...")

# Check a few samples
sample = train_dataset[0]
image, target = sample
print(f"✅ Sample structure: (image_tensor, integer_label)")
print(f"   Image shape: {image.shape}")
print(f"   Label: {target} ({cifar100_fine_labels[target]})")
print(f"   Label type: {type(target)}")

print(f"\n📊 Dataset summary:")
print(f"   Training samples: {len(train_dataset)}")
print(f"   Test samples: {len(test_dataset)}")
print(f"   Classes: 100 (fine labels 0-99)")
print("✅ Ready for multi-stream processing!")

### 🔄 Validation and Recovery from Interruptions

The first cell below validates the processed arrays. If processing was interrupted, the second cell checks for and loads partial results:

In [None]:
# Data Validation: Verify processed data structure
print("🔍 Validating processed data structure...")

def validate_processed_data(rgb_data, brightness_data, labels, dataset_name="dataset"):
    """Validate the structure and contents of processed data."""
    
    print(f"\n📊 {dataset_name} Validation:")
    
    # Check shapes
    print(f"   RGB data shape: {rgb_data.shape}")
    print(f"   Brightness data shape: {brightness_data.shape}")
    print(f"   Labels shape: {labels.shape}")
    
    # Check data types
    print(f"   RGB data type: {rgb_data.dtype}")
    print(f"   Brightness data type: {brightness_data.dtype}")
    print(f"   Labels data type: {labels.dtype}")
    
    # Check value ranges
    print(f"   RGB range: [{rgb_data.min():.3f}, {rgb_data.max():.3f}]")
    print(f"   Brightness range: [{brightness_data.min():.3f}, {brightness_data.max():.3f}]")
    print(f"   Label range: [{labels.min()}, {labels.max()}]")
    
    # Check for NaN or infinite values
    rgb_issues = np.isnan(rgb_data).sum() + np.isinf(rgb_data).sum()
    brightness_issues = np.isnan(brightness_data).sum() + np.isinf(brightness_data).sum()
    label_issues = np.isnan(labels).sum() + np.isinf(labels).sum()
    
    print(f"   RGB issues (NaN/Inf): {rgb_issues}")
    print(f"   Brightness issues (NaN/Inf): {brightness_issues}")
    print(f"   Label issues (NaN/Inf): {label_issues}")
    
    # Verify dimensions match
    n_samples = rgb_data.shape[0]
    assert brightness_data.shape[0] == n_samples, f"Brightness samples mismatch: {brightness_data.shape[0]} != {n_samples}"
    assert labels.shape[0] == n_samples, f"Label samples mismatch: {labels.shape[0]} != {n_samples}"
    
    # Verify channel dimensions
    assert rgb_data.shape[1] == 3, f"RGB should have 3 channels, got {rgb_data.shape[1]}"
    assert brightness_data.shape[1] == 1, f"Brightness should have 1 channel, got {brightness_data.shape[1]}"
    
    # Verify image dimensions (32x32 for CIFAR-100)
    assert rgb_data.shape[2:] == (32, 32), f"RGB image size should be 32x32, got {rgb_data.shape[2:]}"
    assert brightness_data.shape[2:] == (32, 32), f"Brightness image size should be 32x32, got {brightness_data.shape[2:]}"
    
    # Verify label range (0-99 for CIFAR-100)
    assert labels.min() >= 0 and labels.max() <= 99, f"Labels should be in range [0, 99], got [{labels.min()}, {labels.max()}]"
    
    print(f"   ✅ {dataset_name} validation passed!")
    
    return True

# Only validate data that has actually been processed
validated_any = False
try:
    if all(name in globals() for name in ('train_rgb', 'train_brightness', 'train_labels')):
        validate_processed_data(train_rgb, train_brightness, train_labels, "Training")
        validated_any = True

    if all(name in globals() for name in ('test_rgb', 'test_brightness', 'test_labels')):
        validate_processed_data(test_rgb, test_brightness, test_labels, "Test")
        validated_any = True

    if validated_any:
        print("\n✅ All data validation checks passed!")
    else:
        print("ℹ️ No processed data found. Run the data processing cells first.")

except Exception as e:
    # AssertionError from a failed check, or an error raised by malformed arrays
    print(f"❌ Data validation failed: {e}")
    print("💡 This might indicate an issue with the data processing step")

In [None]:
# Recovery: Check for and load partial processing results
def check_and_load_partial_results():
    """Check for partial processing results and load them."""
    
    partial_files = {
        'train': 'train_processed.pkl.partial',
        'test': 'test_processed.pkl.partial',
        'train_checkpoint': 'train_processed.pkl.checkpoint',
        'test_checkpoint': 'test_processed.pkl.checkpoint'
    }
    
    results = {}
    
    for name, filepath in partial_files.items():
        if Path(filepath).exists():
            try:
                with open(filepath, 'rb') as f:
                    data = pickle.load(f)
                results[name] = data
                print(f"✅ Found {name} partial results: {data['processed']} samples")
            except Exception as e:
                print(f"❌ Failed to load {name}: {e}")
    
    return results

# Check for any partial results
print("🔍 Checking for partial processing results...")
partial_results = check_and_load_partial_results()

if partial_results:
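    print("\n💡 Recovery options:")
    print("1. Use the partial results for faster testing (uncomment the lines below)")
    print("2. Delete the .partial/.checkpoint files and reprocess from scratch")
    print("   (convert_dataset_to_multi_stream does not resume from checkpoints automatically)")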
    
    # Example: Load partial training data if available
    if 'train' in partial_results:
        print(f"\n📊 Partial training data available:")
        data = partial_results['train']
        print(f"   RGB shape: {data['rgb_data'].shape}")
        print(f"   Brightness shape: {data['brightness_data'].shape}")
        print(f"   Labels shape: {data['labels'].shape}")
        
        # Uncomment to use partial data:
        # train_rgb = data['rgb_data']
        # train_brightness = data['brightness_data'] 
        # train_labels = data['labels']
else:
    print("ℹ️ No partial results found. Processing will start from the beginning.")
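`convert_dataset_to_multi_stream` writes checkpoints but never resumes from them. A hedged sketch of resuming, assuming a checkpoint dict with the keys written above (`rgb_data`, `brightness_data`, `labels`, `processed`) and a per-image conversion callable such as an `RGBtoRGBL()` instance: it continues from dataset index `processed` and concatenates the old and new arrays. `resume_from_checkpoint` is a hypothetical name, and unlike the original loop it does not skip samples that fail to convert.

```python
from typing import Any, Callable, Sequence, Tuple

import numpy as np


def resume_from_checkpoint(
    checkpoint: dict,
    dataset: Sequence,
    num_samples: int,
    convert_fn: Callable[[Any], Any],
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Continue processing from checkpoint['processed'] and merge with saved arrays."""
    start = int(checkpoint["processed"])  # dataset index where the previous run stopped
    rgb_chunks = [checkpoint["rgb_data"]]
    brightness_chunks = [checkpoint["brightness_data"]]
    labels = list(checkpoint["labels"])

    new_rgb, new_brightness = [], []
    for i in range(start, num_samples):
        image, label = dataset[i]
        rgbl = np.asarray(convert_fn(image))  # [4, H, W]; works for CPU tensors too
        new_rgb.append(rgbl[:3])
        new_brightness.append(rgbl[3:4])
        labels.append(int(label))

    if new_rgb:
        rgb_chunks.append(np.stack(new_rgb))
        brightness_chunks.append(np.stack(new_brightness))
    return (
        np.concatenate(rgb_chunks),
        np.concatenate(brightness_chunks),
        np.asarray(labels, dtype=np.int64),
    )
```

For example, `resume_from_checkpoint(partial_results['train_checkpoint'], train_dataset, 5000, RGBtoRGBL())` would continue the training run shown earlier, provided the checkpoint exists.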

## 6. Visualize Sample Images: RGB and Brightness Side by Side

Display sample images showing the original RGB and extracted brightness channels to understand the multi-stream transformation.
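For RGB float input, `imshow` expects values in `[0, 1]` and clips anything outside that range (single-channel images drawn with a colormap are auto-scaled instead). The loader's output range is printed earlier but not fixed in this section, so, as a hedged sketch for the case where the data turns out to be standardized, the hypothetical helper `chw_to_display` rescales an image only when it falls outside `[0, 1]` and drops a singleton channel for grayscale display.

```python
import numpy as np


def chw_to_display(image: np.ndarray) -> np.ndarray:
    """Convert a [C, H, W] image to [H, W, C] (or [H, W] for C == 1) with values in [0, 1]."""
    img = np.asarray(image, dtype=np.float32)
    lo, hi = float(img.min()), float(img.max())
    if lo < 0.0 or hi > 1.0:
        # Min-max rescale only when values fall outside imshow's float range
        img = (img - lo) / (hi - lo) if hi > lo else np.zeros_like(img)
    img = np.transpose(img, (1, 2, 0))  # CHW -> HWC for matplotlib
    return img[..., 0] if img.shape[-1] == 1 else img
```

Data already in `[0, 1]` passes through unchanged apart from the transpose, so the helper is safe to use either way.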

In [None]:
def visualize_rgb_brightness_samples(rgb_data, brightness_data, labels, num_samples=5):
    """
    Visualize RGB and brightness images side by side.

    Args:
        rgb_data: RGB image data [N, 3, H, W]
        brightness_data: Brightness image data [N, 1, H, W]
        labels: Image labels
        num_samples: Number of samples to visualize
    """
    # squeeze=False keeps `axes` 2-D even when num_samples == 1
    fig, axes = plt.subplots(num_samples, 2, figsize=(8, 2.5 * num_samples), squeeze=False)
    fig.suptitle('RGB vs Brightness Channel Comparison', fontsize=16, fontweight='bold')

    for i in range(num_samples):
        # Get RGB image (convert from CHW to HWC for matplotlib)
        rgb_img = np.transpose(rgb_data[i], (1, 2, 0))

        # Get brightness image (squeeze channel dimension)
        brightness_img = brightness_data[i, 0]  # Remove channel dimension

        # Get class name
        class_name = cifar100_fine_labels[labels[i]]

        # Plot RGB image
        axes[i, 0].imshow(rgb_img)
        axes[i, 0].set_title(f'RGB - {class_name}', fontweight='bold')
        axes[i, 0].axis('off')

        # Plot brightness image
        axes[i, 1].imshow(brightness_img, cmap='gray')
        axes[i, 1].set_title(f'Brightness - {class_name}', fontweight='bold')
        axes[i, 1].axis('off')

    plt.tight_layout()
    plt.show()

# Visualize sample images
print("🖼️ Sample RGB vs Brightness Images:")
visualize_rgb_brightness_samples(train_rgb, train_brightness, train_labels, num_samples=5)

# Show data statistics
def show_data_statistics(rgb_data, brightness_data, labels):
    """Show basic statistics about the data."""
    print(f"\n📊 Data Statistics:")
    print(f"   RGB data range: [{rgb_data.min():.3f}, {rgb_data.max():.3f}]")
    print(f"   Brightness data range: [{brightness_data.min():.3f}, {brightness_data.max():.3f}]")
    print(f"   Number of unique classes: {len(np.unique(labels))}")

    # Class distribution
    unique_labels, counts = np.unique(labels, return_counts=True)
    print(f"   Samples per class: {counts.min()} - {counts.max()}")
    print(f"   Average samples per class: {counts.mean():.1f}")

show_data_statistics(train_rgb, train_brightness, train_labels)

## 7. Additional Data Visualizations

Explore the data with helpful visualizations including class distribution and pixel intensity analysis.

In [None]:
# Class distribution visualization
def plot_class_distribution(labels, title="Class Distribution"):
    """Plot the distribution of classes in the dataset."""
    plt.figure(figsize=(12, 6))
    unique_labels, counts = np.unique(labels, return_counts=True)

    plt.bar(unique_labels, counts, alpha=0.7, color='skyblue', edgecolor='navy')
    plt.title(title, fontsize=14, fontweight='bold')
    plt.xlabel('Class ID')
    plt.ylabel('Number of Samples')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

# Pixel intensity histograms
def plot_intensity_histograms(rgb_data, brightness_data):
    """Plot histograms of pixel intensities for RGB and brightness channels."""
    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    fig.suptitle('Pixel Intensity Distributions', fontsize=16, fontweight='bold')

    # RGB histograms
    colors = ['red', 'green', 'blue']
    for i, color in enumerate(colors):
        axes[0, 0].hist(rgb_data[:, i].flatten(), bins=50, alpha=0.6,
                       color=color, label=f'{color.upper()} channel')
    axes[0, 0].set_title('RGB Channel Intensities')
    axes[0, 0].set_xlabel('Pixel Value')
    axes[0, 0].set_ylabel('Frequency')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    # Brightness histogram
    axes[0, 1].hist(brightness_data.flatten(), bins=50, alpha=0.7,
                   color='gray', edgecolor='black')
    axes[0, 1].set_title('Brightness Channel Intensities')
    axes[0, 1].set_xlabel('Pixel Value')
    axes[0, 1].set_ylabel('Frequency')
    axes[0, 1].grid(True, alpha=0.3)

    # Mean pixel values per channel
    rgb_means = np.mean(rgb_data, axis=(0, 2, 3))
    brightness_mean = np.mean(brightness_data)

    channel_names = ['Red', 'Green', 'Blue', 'Brightness']
    channel_means = [rgb_means[0], rgb_means[1], rgb_means[2], brightness_mean]

    axes[1, 0].bar(channel_names, channel_means,
                  color=['red', 'green', 'blue', 'gray'], alpha=0.7)
    axes[1, 0].set_title('Mean Pixel Values by Channel')
    axes[1, 0].set_ylabel('Mean Pixel Value')
    axes[1, 0].grid(True, alpha=0.3)

    # Fourth panel is unused
    axes[1, 1].axis('off')

    plt.tight_layout()
    plt.show()

# Sample grid of images
def plot_sample_grid(rgb_data, labels, grid_size=(4, 8)):
    """Plot a grid of sample images."""
    fig, axes = plt.subplots(grid_size[0], grid_size[1], figsize=(16, 8), squeeze=False)
    fig.suptitle('Sample Images from CIFAR-100 Dataset', fontsize=16, fontweight='bold')

    for i in range(grid_size[0]):
        for j in range(grid_size[1]):
            idx = i * grid_size[1] + j
            axes[i, j].axis('off')  # hide ticks, including on unused cells
            if idx < len(rgb_data):
                img = np.transpose(rgb_data[idx], (1, 2, 0))
                class_name = cifar100_fine_labels[labels[idx]]

                axes[i, j].imshow(img)
                axes[i, j].set_title(class_name, fontsize=8)

    plt.tight_layout()
    plt.show()

# Generate visualizations
print("📊 Generating additional visualizations...")

# Class distribution
plot_class_distribution(train_labels, "Training Set Class Distribution")

# Intensity histograms
plot_intensity_histograms(train_rgb[:1000], train_brightness[:1000])  # Sample for speed

# Sample grid
plot_sample_grid(train_rgb, train_labels)
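Because the processing cell keeps only the first 5,000 training images, the per-class counts in the distribution plot above are not guaranteed to be equal. If a balanced subset matters, a minimal sketch is to draw the same number of indices from each class before conversion. `stratified_indices` is a hypothetical helper, and it assumes the labels are available as an integer array up front (for instance `train_dataset.labels` from the loading cell).

```python
from typing import Sequence

import numpy as np


def stratified_indices(labels: Sequence[int], per_class: int, seed: int = 0) -> np.ndarray:
    """Pick up to `per_class` random indices from each class, returned in shuffled order."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    chosen = []
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)  # positions of this class
        take = min(per_class, idx.size)  # small classes contribute everything they have
        chosen.append(rng.choice(idx, size=take, replace=False))
    out = np.concatenate(chosen)
    rng.shuffle(out)  # avoid class-sorted order in the subset
    return out
```

CIFAR-100 has 500 training images per class, so `per_class=50` would yield a balanced 5,000-image subset; the resulting indices would then drive which samples the conversion step processes.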

## 8. Create Multi-Stream Neural Network Models

Create both the BaseMultiChannelNetwork (dense) and MultiChannelResNetNetwork (CNN) models for comparison.
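The internals of `BaseMultiChannelNetwork` and `MultiChannelResNetNetwork` are not shown in this section. To illustrate the general two-stream idea only, here is a hypothetical late-fusion dense model in PyTorch: one branch per stream (RGB `[N, 3, 32, 32]`, brightness `[N, 1, 32, 32]`), with the branch features concatenated before a 100-way classifier. The class name, hidden size, and fusion by concatenation are assumptions, not the repository's design.

```python
import torch
import torch.nn as nn


class TwoStreamMLP(nn.Module):
    """Hypothetical late-fusion dense model: one MLP branch per input stream."""

    def __init__(self, rgb_dim: int = 3 * 32 * 32, brightness_dim: int = 32 * 32,
                 hidden: int = 256, num_classes: int = 100):
        super().__init__()
        self.rgb_branch = nn.Sequential(nn.Flatten(), nn.Linear(rgb_dim, hidden), nn.ReLU())
        self.brightness_branch = nn.Sequential(nn.Flatten(), nn.Linear(brightness_dim, hidden), nn.ReLU())
        # Classifier sees both streams' features side by side
        self.classifier = nn.Linear(2 * hidden, num_classes)

    def forward(self, rgb: torch.Tensor, brightness: torch.Tensor) -> torch.Tensor:
        fused = torch.cat([self.rgb_branch(rgb), self.brightness_branch(brightness)], dim=1)
        return self.classifier(fused)  # raw logits, suitable for nn.CrossEntropyLoss


model = TwoStreamMLP()
logits = model(torch.randn(8, 3, 32, 32), torch.randn(8, 1, 32, 32))
print(logits.shape)  # torch.Size([8, 100])
```

Keeping the streams separate until the classifier lets each branch learn features specific to color or luminance; where and how the actual repository models fuse their streams should be checked in their source.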

---

## 🔧 Troubleshooting Appendix

**Only run these cells if you encounter specific issues:**

In [None]:
# CPU-only fallback (if CUDA issues persist)
print("🔄 Installing CPU-only PyTorch for compatibility...")

!pip uninstall -y torch torchvision torchaudio
!pip install -q numpy==1.24.3
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
!pip install -q tqdm matplotlib seaborn scikit-learn

# If torch was already imported in this session, restart the runtime before
# running the check below: `import torch` returns the module already loaded
# in memory, not the newly installed build.
import torch
import numpy as np
print(f"✅ NumPy: {np.__version__}")
print(f"✅ PyTorch: {torch.__version__}")
print(f"⚠️ CUDA available: {torch.cuda.is_available()}")
print("ℹ️ Note: Using CPU-only version. Training will be slower but avoids CUDA build mismatches.")

In [None]:
# Force NumPy version fix (if NumPy 2.x conflicts occur)
print("🔧 Forcing NumPy 1.x for compatibility...")

import numpy as np
if np.__version__.startswith('2.'):
    print(f"⚠️ NumPy 2.x detected: {np.__version__}")
    !pip install -q "numpy<2.0" --force-reinstall

    # NumPy is a compiled extension and cannot be safely reloaded in a running
    # process (importlib.reload keeps the old binary and version string).
    print("🔁 NumPy 1.x installed. Restart the runtime, then re-run this cell to confirm the version.")
else:
    print(f"✅ NumPy version OK: {np.__version__}")

# Test PyTorch integration
import torch
test_array = np.array([1, 2, 3], dtype=np.float32)
test_tensor = torch.from_numpy(test_array)
print("✅ NumPy-PyTorch integration working!")