In [None]:
import torch
import os
# Enable the cuDNN autotuner; it can improve throughput when input sizes stay fixed.
torch.backends.cudnn.benchmark = True

# DataLoader parallelism: use every logical core except one, but never fewer
# than one worker. os.cpu_count() returns None when the core count cannot be
# determined, so fall back to a single core instead of failing on `None - 1`.
num_workers = max(1, (os.cpu_count() or 1) - 1)
# Value later handed to the DataLoader factory as its `prefetch_factor` argument.
prefetch_factor = 2
print(f"Using {num_workers} num_workers and prefetch_factor={prefetch_factor}")

In [None]:
import torch
# Measure peak GPU memory for a single forward pass on one mini-batch.
# Autograd stays enabled, but no backward pass is run in this cell.
# Requires `train_loader`, `mc_resnet50`, `num_classes` and `device` to exist
# in the kernel already (they are created by the training cell further down).
torch.cuda.reset_peak_memory_stats()

# Keep the labels together with the inputs so that any later loss computation
# uses targets that belong to this very batch.
rgb_batch, bright_batch, labels = next(iter(train_loader))

# The model is built after the counter reset, so whatever it allocates on the
# GPU is included in the reported peak.
resnet50_mc_streaming = mc_resnet50(num_classes=num_classes, device=str(device), use_amp=True, groups=2)
with torch.cuda.device(device):
    _ = resnet50_mc_streaming(rgb_batch.to(device), bright_batch.to(device))

# This figure covers the forward pass only.
print(f"Peak GPU usage: {torch.cuda.max_memory_allocated()/1024**3:.1f} GB")

In [None]:
# Measure peak GPU memory for a full training step: forward pass, backward
# pass and optimizer update.
# Relies on `resnet50_mc_streaming`, `rgb_batch`, `bright_batch`, `labels` and
# `device` already being defined in the kernel.
import torch.optim as optim

# NOTE(review): this step really updates the model weights with lr=0.1; the
# cell is a memory probe, not a training step whose result should be kept.
optimizer = optim.AdamW(resnet50_mc_streaming.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()

# Drop gradients left over from an earlier backward pass (e.g. a previous run
# of this cell). Otherwise they would be summed into the new gradients and
# would already occupy memory before the measurement starts.
optimizer.zero_grad(set_to_none=True)

# Reset the counter, then run forward + backward + step.
torch.cuda.reset_peak_memory_stats()
rgb, bright, labels = rgb_batch.to(device), bright_batch.to(device), labels.to(device)
outputs = resnet50_mc_streaming(rgb, bright)
loss = loss_fn(outputs, labels)
loss.backward()
optimizer.step()
print(f"Peak GPU usage (forward+backward): {torch.cuda.max_memory_allocated()/1024**3:.1f} GB")

Based on the peak-memory measurement (~32 GB at a batch size of 128), there is not enough headroom to double the batch size to 256 without risking an out-of-memory (OOM) error.

To train with an *effective* batch size of 256:
- Keep the DataLoader at `batch_size=128` and pass `gradient_accumulation_steps=2` to `fit()`, so that gradients from two batches are accumulated before each optimizer step (2 × 128 = 256).
- Alternatively, test intermediate batch sizes incrementally (e.g., 160, 192) and re-measure peak memory before going higher.

In [None]:
# mc_resnet50 with streaming dual-channel data

import traceback

# Banner for the MC-ResNet50 run on StreamingDualChannelDataset ImageNet data:
# the title line followed by a 70-character rule.
print("🚀 TESTING MC-RESNET50 WITH STREAMING DUAL-CHANNEL IMAGENET DATA", "=" * 70, sep="\n")

# Ask the CUDA caching allocator to release its unused cached memory, then
# report how much is still allocated (only when a CUDA device is present).
print("🧹 Clearing GPU cache...")
torch.cuda.empty_cache()
if torch.cuda.is_available():
    print(
        f"GPU memory before training: {torch.cuda.memory_allocated() / 1024**3:.2f} GB"
    )

from src.data_utils.streaming_dual_channel_dataset import (
    StreamingDualChannelDataset,
    create_imagenet_dual_channel_train_val_dataloaders,
    create_imagenet_dual_channel_test_dataloader,
    create_default_imagenet_transforms
)
from src.models2.multi_channel.mc_resnet import mc_resnet50

# Choose the compute device in order of preference: CUDA, then Apple MPS, then CPU.
# The MPS check is only evaluated when CUDA is unavailable.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# Announce which backend was selected.
if device.type == "cuda":
    print(f"🚀 Using CUDA: {torch.cuda.get_device_name(0)}")
elif device.type == "mps":
    print("🚀 Using Apple Metal Performance Shaders (MPS)")
else:
    print("💻 Using CPU")

# --- Dataset locations (relative to the notebook's working directory) ---
# Append further folders to TRAIN_FOLDERS if the training set is split
# across several directories.
TRAIN_FOLDERS = ["data/ImageNet-1K/train_images_0"]
VAL_FOLDER = "data/ImageNet-1K/val_images"
TEST_FOLDER = "data/ImageNet-1K/test_images"
TRUTH_FILE = "data/ImageNet-1K/ILSVRC2013_devkit/data/ILSVRC2013_clsloc_validation_ground_truth.txt"

# --- Run settings ---
batch_size = 128  # noted by the author as the largest batch size currently possible
image_size = (224, 224)
num_epochs = 1  # kept small for demonstration purposes

# Echo the configuration, one entry per line (TEST_FOLDER is defined but not printed).
print(
    f"\n📂 Dataset Configuration:",
    f"Training folders: {TRAIN_FOLDERS}",
    f"Validation folder: {VAL_FOLDER}",
    f"Truth file: {TRUTH_FILE}",
    f"Batch size: {batch_size}",
    f"Image size: {image_size}",
    f"Training epochs: {num_epochs}",
    sep="\n",
)

# End-to-end smoke test of the streaming dual-channel pipeline:
#   1. build train/val DataLoaders,
#   2. determine the number of classes,
#   3. create and compile an MC-ResNet50 model,
#   4. train for `num_epochs` epochs,
#   5. evaluate on the validation loader.
# Everything runs inside one try block; failures are printed (with a traceback
# for unexpected errors) instead of being re-raised, so the statements after
# the try/except still execute.
# Besides the configuration defined in this cell, the block reads `num_workers`
# and `prefetch_factor`, which are set in the DataLoader-settings cell.
print(f"\n📊 Creating Streaming Dual-Channel DataLoaders...")
try:
    train_loader, val_loader = create_imagenet_dual_channel_train_val_dataloaders(
        train_folders=TRAIN_FOLDERS,
        val_folder=VAL_FOLDER,
        truth_file=TRUTH_FILE,
        # Custom transforms are currently disabled; no transform arguments are passed.
        # train_transform=train_transform,
        # val_transform=val_transform,
        batch_size=batch_size,
        val_batch_size=batch_size,  # validation uses the same batch size as training
        image_size=image_size,
        num_workers=num_workers,  # computed from the CPU count in another cell; lower it if the notebook becomes unstable
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=prefetch_factor
    )

    print(f"✅ Train loader: {len(train_loader)} batches")
    print(f"✅ Val loader: {len(val_loader)} batches")
    print("✅ DataLoaders created successfully!")

    # Take the class count from the dataset's `class_to_idx` mapping when it
    # exists and is non-empty; otherwise fall back to 1000 classes.
    if hasattr(train_loader.dataset, 'class_to_idx') and train_loader.dataset.class_to_idx:
        num_classes = len(train_loader.dataset.class_to_idx)
        print(f"✅ Number of classes detected: {num_classes}")
    else:
        num_classes = 1000  # Default ImageNet classes
        print(f"⚠️  Using default ImageNet classes: {num_classes}")

    # Create the MC-ResNet50 model on the selected device with AMP enabled.
    # NOTE(review): the memory-probe cell builds the model with `groups=2`,
    # this call does not pass `groups` — confirm which configuration is intended.
    print(f"\n🏗️  Creating MC-ResNet50 model...")
    resnet50_mc_streaming = mc_resnet50(num_classes=num_classes, device=str(device), use_amp=True)

    # Configure optimizer, loss and learning-rate scheduler through the model's own compile().
    print(f"⚙️  Compiling model with optimized settings...")
    resnet50_mc_streaming.compile(
        optimizer='adamw',
        loss='cross_entropy',
        # NOTE(review): 0.1 is unusually high for AdamW; presumably compile()
        # treats it as the peak rate of the one-cycle schedule — confirm.
        learning_rate=0.1,
        weight_decay=1e-5,      # NOTE(review): described by the author as the standard ImageNet weight decay — confirm the intended value
        scheduler='onecycle',
    )

    print(f"\n🎯 Starting training...")
    print(f"Training with {len(train_loader)} train batches and {len(val_loader)} val batches")

    # Clear GPU memory before training
    print("🧹 Clearing GPU cache before training...")
    torch.cuda.empty_cache()

    # Optional: report currently allocated GPU memory (CUDA only)
    if torch.cuda.is_available():
        print(f"GPU memory before training: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")


    # Train the model. With gradient_accumulation_steps=2 the effective batch is
    # presumably 2 x batch_size, as the markdown note above states — confirm
    # against fit().
    history_mc_streaming = resnet50_mc_streaming.fit(
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=num_epochs,
        batch_size=batch_size,
        early_stopping=False,
        verbose=True,
        gradient_accumulation_steps=2
        )

    # Summarise the run from the history returned by fit(); it is indexed here
    # by 'train_accuracy' and 'val_accuracy', each holding a sequence of values.
    print(f"\n🎉 Training completed!")
    print(f"Best validation accuracy: {max(history_mc_streaming['val_accuracy']):.4f}")
    print(f"Final train accuracy: {history_mc_streaming['train_accuracy'][-1]:.4f}")
    print(f"Final validation accuracy: {history_mc_streaming['val_accuracy'][-1]:.4f}")

    # Final evaluation on the validation loader (no test loader is created in
    # this cell); the result is read via its 'loss' and 'accuracy' keys.
    print(f"\n📊 Final evaluation...")
    evaluate_mc_streaming = resnet50_mc_streaming.evaluate(val_loader)
    print(f"Validation loss: {evaluate_mc_streaming['loss']:.4f}")
    print(f"Validation accuracy: {evaluate_mc_streaming['accuracy']:.4f}")


    print(f"\n✅ StreamingDualChannelDataset test completed successfully!")
    print(f"🎊 The model trained on ImageNet data using on-demand loading!")

# Missing dataset files: print setup instructions instead of a traceback.
except FileNotFoundError as e:
    print(f"❌ Dataset not found: {e}")
    print(f"\n💡 To run this test, you need to:")
    print(f"1. Download ImageNet dataset")
    print(f"2. Update the paths above to point to your ImageNet data:")
    print(f"   - TRAIN_FOLDERS: path(s) to training images")
    print(f"   - VAL_FOLDER: path to validation images")
    print(f"   - TRUTH_FILE: path to validation ground truth file")
    print(f"3. Ensure the data is in the expected ImageNet format")

# Any other failure: print the error and the full traceback. The exception is
# not re-raised, so names assigned inside the try block may be left undefined.
except Exception as e:
    print(f"❌ Error during training: {e}")
    print(f"This might be due to missing data or configuration issues.")
    print(f"Please check the dataset paths and ensure ImageNet data is available.")
    traceback.print_exc()

# Closing banner: a blank line, a 70-character rule, then the completion message.
print(f"\n" + "=" * 70, f"🏁 StreamingDualChannelDataset Demo Complete!", sep="\n")
# NOTE(review): the path below looks like a Google Drive mount in Colab,
# presumably an alternative dataset location — confirm before relying on it.
#   /content/drive/MyDrive/Multi-Stream-Neural-Networks/data/ImageNet-1K


In [None]:
# Fetch one mini-batch from the ImageNet training loader via a fresh iterator
# and report the shape of each of its three components.
(rgb_batch, bright_batch, labels) = next(iter(train_loader))
print(
    f"RGB batch shape: {rgb_batch.shape}\nBrightness batch shape: {bright_batch.shape}\nLabels shape: {labels.shape}"
)

In [None]:
# Shape check on one training mini-batch: RGB input, brightness input, labels.
# NOTE(review): this cell contains the same code as the cell directly above it;
# running it rebinds `rgb_batch`, `bright_batch` and `labels` from a new iterator.
rgb_batch, bright_batch, labels = (
    next(iter(train_loader))
)
print(
    f"RGB batch shape: {rgb_batch.shape}\nBrightness batch shape: {bright_batch.shape}\nLabels shape: {labels.shape}"
)