# CIFAR-10 Image Classification - Modern Implementation

This notebook walks through an improved, modular implementation of CIFAR-10 image classification in PyTorch.

## Key Improvements
- **Modular Design**: Data loading, model definition, training, and visualization live in separate modules under `src/`
- **Better Architecture**: CNN regularized with batch normalization and dropout
- **Configuration Management**: Each experiment is described by a single `ExperimentConfig` object
- **Advanced Training**: Learning-rate scheduling, early stopping, and best-model checkpointing
- **Comprehensive Evaluation**: Per-class accuracy, confusion matrix, and prediction visualizations

In [None]:
# Import necessary libraries
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np

# Add project root to path (assumes the notebook is run from the notebooks/ directory)
project_root = Path().resolve().parent
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

# Import our custom modules
from src.data.dataset import CIFAR10DataModule
from src.models.cifar10_cnn import get_model
from src.utils.trainer import ModelTrainer
from src.utils.visualization import (
    show_sample_images, plot_training_history, 
    plot_confusion_matrix, visualize_predictions
)
from configs.config import ExperimentConfig, get_quick_config

print("All imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"MPS available: {torch.backends.mps.is_available() if hasattr(torch.backends, 'mps') else False}")

## 1. Configuration Setup

We'll use the modern configuration system to set up our experiment.

In [None]:
# Create experiment configuration
config = get_quick_config()  # Quick config for notebook demonstration
config.name = "notebook_demo"
config.description = "Demonstration of improved CIFAR-10 implementation"
config.training.epochs = 10  # Reduced for demonstration
config.data.batch_size = 64
config.model.model_name = "improved"  # Use the improved architecture

print(f"Experiment: {config.name}")
print(f"Description: {config.description}")
print(f"Device: {config.system.device}")
print(f"Model: {config.model.model_name}")
print(f"Epochs: {config.training.epochs}")
print(f"Batch size: {config.data.batch_size}")
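
The configuration object above exposes nested sections such as `training` and `data`. The project's `configs/config.py` is not shown in this notebook, so the following is only a sketch of how such a nested configuration could be built with standard-library dataclasses. The class names and every default value are hypothetical; only the attribute names mirror the ones used above.

```python
from dataclasses import dataclass, field, asdict


@dataclass
class TrainingConfig:
    epochs: int = 5               # hypothetical default
    learning_rate: float = 1e-3   # hypothetical default
    weight_decay: float = 1e-4    # hypothetical default


@dataclass
class DataConfig:
    data_dir: str = "./data"      # hypothetical default
    batch_size: int = 128         # hypothetical default
    val_split: float = 0.1        # hypothetical default


@dataclass
class SketchExperimentConfig:
    """Hypothetical stand-in for the project's ExperimentConfig."""
    name: str = "quick"
    description: str = ""
    # default_factory gives every instance its own nested section objects
    training: TrainingConfig = field(default_factory=TrainingConfig)
    data: DataConfig = field(default_factory=DataConfig)

    def to_dict(self) -> dict:
        # asdict recurses into nested dataclasses
        return asdict(self)


sketch = SketchExperimentConfig()
sketch.name = "notebook_demo"
sketch.training.epochs = 10
sketch.data.batch_size = 64
print(sketch.to_dict())
```

Using `default_factory` for the nested sections means that overriding `sketch.training.epochs` does not leak into other config instances.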

## 2. Data Loading and Visualization

Our improved data module handles all preprocessing and augmentation automatically.

In [None]:
# Create data module
data_module = CIFAR10DataModule(
    data_dir=config.data.data_dir,
    batch_size=config.data.batch_size,
    val_split=config.data.val_split,
    num_workers=2,  # Reduced for notebook
    pin_memory=False  # Disabled for notebook compatibility
)

# Get data loaders
train_loader, val_loader, test_loader = data_module.get_dataloaders()

print(f"Training samples: {len(data_module.train_dataset)}")
print(f"Validation samples: {len(data_module.val_dataset)}")
print(f"Test samples: {len(data_module.test_dataset)}")
print(f"\nCIFAR-10 classes: {data_module.classes}")
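
`val_split` controls how much of the training data is held out for validation. How `CIFAR10DataModule` performs the split internally is not shown here; one common approach is a seeded random split of the index range, sketched below. The helper `split_indices`, the seed, and the 0.1 fraction are assumptions for illustration.

```python
import random


def split_indices(num_samples: int, val_split: float, seed: int = 42):
    """Shuffle indices with a fixed seed and hold out a validation fraction."""
    if not 0.0 <= val_split < 1.0:
        raise ValueError("val_split must be in [0, 1)")
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)      # local RNG, global state untouched
    num_val = round(num_samples * val_split)
    return indices[num_val:], indices[:num_val]


# CIFAR-10 ships with 50,000 training images.
train_idx, val_idx = split_indices(50_000, val_split=0.1)
print(len(train_idx), len(val_idx))
```

Fixing the seed keeps the validation set identical across runs, so validation accuracies from different experiments remain comparable.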

In [None]:
# Visualize sample images from training set
show_sample_images(
    train_loader, 
    num_images=8, 
    title="Sample Training Images (with augmentation)"
)

In [None]:
# Visualize sample images from test set (no augmentation)
show_sample_images(
    test_loader, 
    num_images=8, 
    title="Sample Test Images (no augmentation)"
)

## 3. Model Architecture

Let's examine our improved CNN architecture.

In [None]:
# Create the improved model
device = torch.device(config.system.device)
model = get_model(
    model_name=config.model.model_name,
    num_classes=config.model.num_classes,
    dropout_rate=config.model.dropout_rate
)

print(f"Model architecture: {config.model.model_name}")
print("\nModel summary:")
print(model)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
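
The source of `get_model` is not included in this notebook. To give a sense of what a CNN with batch normalization and dropout looks like for 32x32 RGB inputs, here is a deliberately small sketch. `SketchCNN` is hypothetical and is not the project's "improved" architecture; layer widths and the dropout rate are arbitrary choices.

```python
import torch
import torch.nn as nn


class SketchCNN(nn.Module):
    """Hypothetical small CNN; not the project's actual 'improved' model."""

    def __init__(self, num_classes: int = 10, dropout_rate: float = 0.5):
        super().__init__()
        self.features = nn.Sequential(
            # bias is redundant when a BatchNorm layer follows the convolution
            nn.Conv2d(3, 32, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),                      # 32x32 -> 16x16
            nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),                      # 16x16 -> 8x8
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(dropout_rate),
            nn.Linear(64 * 8 * 8, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features(x))


sketch_model = SketchCNN()
logits = sketch_model(torch.randn(2, 3, 32, 32))   # batch of 2 random images
num_params = sum(p.numel() for p in sketch_model.parameters())
print(logits.shape, num_params)
```

The parameter count uses the same `sum(p.numel() ...)` idiom as the cell above, so the two numbers can be compared directly.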

In [None]:
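# Test model with a sample batch
model = model.to(device)
model.eval()
sample_batch, sample_labels = next(iter(test_loader))
sample_batch = sample_batch.to(device)  # keep inputs on the same device as the model
print(f"Input shape: {sample_batch.shape}")

with torch.no_grad():
    output = model(sample_batch)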
    print(f"Output shape: {output.shape}")
    print(f"Output range: [{output.min():.3f}, {output.max():.3f}]")
    
    # Apply softmax to get probabilities
    probabilities = torch.softmax(output, dim=1)
    print(f"Probability range: [{probabilities.min():.3f}, {probabilities.max():.3f}]")
    print(f"Probability sum (should be ~1.0): {probabilities[0].sum():.3f}")
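
The softmax step above turns raw logits into probabilities that sum to 1. As a minimal illustration of what `torch.softmax` computes for a single row, here is a pure-Python version; the helper name `softmax` and the example logits are made up for this sketch.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    shift = max(logits)                       # subtracting the max avoids overflow
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax([2.0, 1.0, 0.1])
print([round(p, 3) for p in probs], sum(probs))
```

Subtracting the maximum does not change the result, because the common factor cancels in the ratio, but it keeps `math.exp` from overflowing on large logits.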

## 4. Training Setup

The training setup uses the Adam optimizer with weight decay, a cosine-annealing learning-rate schedule, and cross-entropy loss.

In [None]:
# Create optimizer and scheduler
optimizer = optim.Adam(
    model.parameters(),
    lr=config.training.learning_rate,
    weight_decay=config.training.weight_decay
)

scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=config.training.epochs
)

# Create trainer
trainer = ModelTrainer(
    model=model,
    device=device,
    criterion=nn.CrossEntropyLoss(),
    optimizer=optimizer,
    scheduler=scheduler
)

print(f"Optimizer: {type(optimizer).__name__}")
print(f"Learning rate: {config.training.learning_rate}")
print(f"Scheduler: {type(scheduler).__name__}")
print(f"Device: {device}")
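
With `T_max` set to the number of epochs, `CosineAnnealingLR` decays the learning rate from its initial value toward `eta_min` (0 by default) along half a cosine wave. The sketch below evaluates the closed-form expression given in the PyTorch documentation; the base learning rate of `1e-3` is an assumed value, since `config.training.learning_rate` is not shown here.

```python
import math


def cosine_annealing_lr(base_lr, epoch, t_max, eta_min=0.0):
    """Closed-form cosine annealing: base_lr at epoch 0, eta_min at epoch t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))


base_lr = 1e-3  # assumed value for illustration
schedule = [cosine_annealing_lr(base_lr, epoch, t_max=10) for epoch in range(11)]
for epoch, lr in enumerate(schedule):
    print(f"epoch {epoch:2d}: lr = {lr:.6f}")
```

The rate drops slowly at first, fastest around the midpoint, and slowly again near the end, which is why the schedule pairs well with a short, fixed epoch budget like the one used in this notebook.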

## 5. Model Training

Train the model with best-checkpoint saving and early stopping (patience of 5 epochs) enabled.

In [None]:
# Train the model
print("Starting training...")
print("=" * 50)

history = trainer.train(
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=config.training.epochs,
    save_best=True,
    early_stopping_patience=5
)

print("\nTraining completed!")
print(f"Best validation accuracy: {trainer.best_val_acc:.2f}%")
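
The `early_stopping_patience` argument suggests that training halts once validation accuracy stops improving for a number of consecutive epochs. `ModelTrainer`'s internals are not shown in this notebook, so the following is a generic sketch of that logic. The `EarlyStopping` class and the accuracy values are hypothetical.

```python
class EarlyStopping:
    """Stop when the monitored metric has not improved for `patience` epochs."""

    def __init__(self, patience: int = 5):
        self.patience = patience
        self.best = float("-inf")
        self.epochs_without_improvement = 0

    def step(self, val_acc: float) -> bool:
        """Record one epoch; return True if training should stop."""
        if val_acc > self.best:
            self.best = val_acc
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


stopper = EarlyStopping(patience=2)
made_up_accuracies = [60.0, 65.0, 64.0, 64.5, 70.0]  # illustrative values only
stopped_at = None
for epoch, acc in enumerate(made_up_accuracies, start=1):
    if stopper.step(acc):
        stopped_at = epoch
        break
print(stopped_at, stopper.best)
```

In this toy run the loop stops before reaching the final value of 70.0, which illustrates the trade-off: a small patience saves compute but can cut off a late improvement.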

## 6. Training Visualization

Visualize the training progress with our improved plotting utilities.

In [None]:
# Plot training history
plot_training_history(
    train_losses=history['train_losses'],
    val_losses=history['val_losses'],
    val_accuracies=history['val_accuracies'],
    learning_rates=history.get('learning_rates')
)

# Print training summary
print("\nTraining Summary:")
print(f"Final training loss: {history['train_losses'][-1]:.4f}")
print(f"Final validation loss: {history['val_losses'][-1]:.4f}")
print(f"Final validation accuracy: {history['val_accuracies'][-1]:.2f}%")
print(f"Best validation accuracy: {max(history['val_accuracies']):.2f}%")

## 7. Model Evaluation

Evaluate the trained model on the held-out test set: overall loss and accuracy, per-class accuracy, a confusion matrix, and sample predictions.

In [None]:
# Evaluate on test set
print("Evaluating on test set...")
test_results = trainer.evaluate(test_loader)

print("\nTest Results:")
print(f"Test Loss: {test_results['test_loss']:.4f}")
print(f"Test Accuracy: {test_results['test_accuracy']:.2f}%")

# Print per-class accuracies
print("\nPer-class accuracies:")
for i, class_name in enumerate(data_module.classes):
    acc = test_results['class_accuracies'][i]
    print(f"  {class_name}: {acc:.2f}%")
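
Per-class accuracy is the fraction of samples of each true class that the model labels correctly. How `trainer.evaluate` computes `class_accuracies` is not shown here; a minimal pure-Python sketch of the calculation follows. The helper `per_class_accuracy` and the toy labels are made up for illustration.

```python
from collections import Counter


def per_class_accuracy(targets, predictions, num_classes):
    """Return accuracy in percent for each class; NaN if a class never occurs."""
    totals = Counter(targets)
    correct = Counter(t for t, p in zip(targets, predictions) if t == p)
    return [
        100.0 * correct[c] / totals[c] if totals[c] else float("nan")
        for c in range(num_classes)
    ]


toy_targets = [0, 0, 1, 1, 1, 2]       # illustrative labels only
toy_predictions = [0, 1, 1, 1, 0, 2]
accs = per_class_accuracy(toy_targets, toy_predictions, num_classes=3)
print([round(a, 2) for a in accs])
```

Because each class is normalized by its own sample count, these values are unaffected by class imbalance, unlike the overall accuracy.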

In [None]:
# Generate predictions for confusion matrix
model.eval()
all_predictions = []
all_targets = []

with torch.no_grad():
    for inputs, targets in test_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)
        
        all_predictions.extend(predicted.cpu().numpy())
        all_targets.extend(targets.cpu().numpy())

all_predictions = np.array(all_predictions)
all_targets = np.array(all_targets)

print(f"Generated predictions for {len(all_predictions)} samples")

In [None]:
# Plot confusion matrix
plot_confusion_matrix(
    y_true=all_targets,
    y_pred=all_predictions,
    class_names=data_module.classes,
    title="Confusion Matrix - Improved CIFAR-10 CNN"
)

In [None]:
# Visualize model predictions
visualize_predictions(
    model=model,
    dataloader=test_loader,
    device=device,
    num_samples=8,
    class_names=data_module.classes,
    title="Model Predictions on Test Images"
)

## 8. Performance Analysis

Let's analyze the model's performance and compare it with the original implementation.

In [None]:
# Performance comparison
original_acc = 72.2  # test accuracy reported for the original implementation
current_acc = test_results['test_accuracy']

print("Performance Comparison:")
print("=" * 40)
print("Original Implementation:")
print(f"  - Test Accuracy: {original_acc:.1f}%")
print("  - Architecture: Simple CNN")
print("  - Training: Basic setup")
print()
print("Improved Implementation:")
print(f"  - Test Accuracy: {current_acc:.2f}%")
print(f"  - Architecture: {config.model.model_name.title()} CNN")
print("  - Training: Professional pipeline")
print("  - Data Augmentation: Advanced")
print("  - Regularization: Dropout + BatchNorm")

# Calculate improvement
improvement = current_acc - original_acc

print(f"\nImprovement: {improvement:+.2f} percentage points")

if improvement > 0:
    print("✅ The improved implementation performs better!")
elif improvement > -2:
    print("📊 Similar performance with better code quality")
else:
    print("⚠️ Performance may need tuning (try more epochs)")

## 9. Model Insights

Analyze what the model learned and identify areas for improvement.

In [None]:
# Find best and worst performing classes
class_accs = [test_results['class_accuracies'][i] for i in range(len(data_module.classes))]
best_class_idx = np.argmax(class_accs)
worst_class_idx = np.argmin(class_accs)

print("Model Performance Analysis:")
print("=" * 30)
print(f"Best performing class: {data_module.classes[best_class_idx]} ({class_accs[best_class_idx]:.2f}%)")
print(f"Worst performing class: {data_module.classes[worst_class_idx]} ({class_accs[worst_class_idx]:.2f}%)")
print(f"Performance gap: {class_accs[best_class_idx] - class_accs[worst_class_idx]:.2f} percentage points")

# Calculate confusion between similar classes
from sklearn.metrics import confusion_matrix
cm = confusion_matrix(all_targets, all_predictions)
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

# Find most confused pairs (row i = true class, column j = predicted class)
print("\nMost confused class pairs:")
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        if i != j and cm_normalized[i, j] > 0.15:  # More than 15% confusion
            print(f"  {data_module.classes[i]} → {data_module.classes[j]}: {cm_normalized[i, j]*100:.1f}%")
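
A fixed 15% threshold can print nothing for a well-trained model or a long list for a weak one. An alternative is to rank the off-diagonal entries of the row-normalized confusion matrix and report the top few. The sketch below does this in pure Python; the helper `top_confused_pairs`, the toy matrix, and the class names are made up for illustration.

```python
def top_confused_pairs(cm, class_names, k=3):
    """Return the k largest off-diagonal rates of a row-normalized confusion matrix."""
    pairs = []
    for i, row in enumerate(cm):
        row_total = sum(row)
        if row_total == 0:          # class absent from the targets; skip to avoid 0/0
            continue
        for j, count in enumerate(row):
            if i != j and count > 0:
                pairs.append((count / row_total, class_names[i], class_names[j]))
    return sorted(pairs, reverse=True)[:k]


# Rows are true classes, columns are predicted classes (toy counts).
toy_cm = [
    [8, 2, 0],
    [1, 6, 3],
    [0, 0, 0],
]
ranked = top_confused_pairs(toy_cm, ["cat", "dog", "deer"])
for rate, true_name, pred_name in ranked:
    print(f"  {true_name} → {pred_name}: {rate * 100:.1f}%")
```

The explicit zero-row check also covers the case where a class has no samples, which would otherwise produce a division by zero during normalization.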

## 10. Saving and Loading Models

Demonstrate how to save and load the trained model for future use.

In [None]:
# Save the trained model
checkpoint_path = project_root / "notebooks" / "demo_model.pth"

trainer.save_checkpoint(
    filepath=str(checkpoint_path),
    epoch=config.training.epochs,
    config=config.to_dict(),
    test_results=test_results
)

print(f"Model saved to: {checkpoint_path}")

# Demonstrate loading
print("\nDemonstrating model loading...")
new_model = get_model(
    model_name=config.model.model_name,
    num_classes=config.model.num_classes,
    dropout_rate=config.model.dropout_rate
)

new_trainer = ModelTrainer(model=new_model, device=device)
loaded_checkpoint = new_trainer.load_checkpoint(str(checkpoint_path))

print(f"Model loaded successfully from epoch {loaded_checkpoint.get('epoch', 'unknown')}")
print(f"Loaded best validation accuracy: {new_trainer.best_val_acc:.2f}%")
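
`save_checkpoint` and `load_checkpoint` are project helpers whose source is not shown here. A common pattern they may resemble is saving a dictionary that bundles the model's `state_dict` with training metadata. The sketch below uses a tiny `nn.Linear` as a stand-in model; the dictionary keys and the file name are assumptions, not the project's actual checkpoint format.

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn

net = nn.Linear(4, 2)  # stand-in for the trained model

checkpoint = {
    "epoch": 10,                              # hypothetical key names
    "model_state_dict": net.state_dict(),
    "best_val_acc": 0.0,                      # placeholder value
}

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "sketch_checkpoint.pth"
    torch.save(checkpoint, path)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine
    loaded = torch.load(path, map_location="cpu")

restored = nn.Linear(4, 2)                    # must match the saved architecture
restored.load_state_dict(loaded["model_state_dict"])
print(loaded["epoch"], torch.equal(restored.weight, net.weight))
```

Saving the `state_dict` rather than the whole model object keeps the file independent of the class's import path, at the cost of having to rebuild the architecture before loading.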

## Summary

This notebook walked through an improved CIFAR-10 implementation with:

### ✅ **Code Quality Improvements**
- **Modular Design**: Clean separation of concerns
- **Configuration Management**: Flexible experiment setup
- **Professional Training Pipeline**: Modern PyTorch practices
- **Comprehensive Evaluation**: Detailed analysis tools

### 🚀 **Technical Improvements**
- **Better Architecture**: CNN regularized with batch normalization and dropout
- **Advanced Data Augmentation**: Robust preprocessing
- **Smart Training**: Learning rate scheduling, early stopping
- **Proper Normalization**: CIFAR-10 specific statistics

### 📊 **Portfolio Readiness**
- **Documentation**: Comprehensive README and comments
- **Reproducibility**: Seed management and configuration
- **Extensibility**: Easy to add new models and experiments
- **Professional Structure**: Industry-standard organization

### 🎯 **Next Steps for Portfolio**
1. **Experiment with different architectures** (ResNet, EfficientNet)
2. **Add transfer learning** capabilities
3. **Implement hyperparameter optimization**
4. **Create a web demo** with Streamlit or Gradio
5. **Add model deployment** scripts

This implementation applies modern machine learning engineering practices and documents the progress made since the original Udacity project.