# CIFAR-10 Image Classification with ResNet

This notebook demonstrates:
1. Loading and visualizing the CIFAR-10 dataset
2. Data preprocessing and augmentation
3. Model training and evaluation
4. Visualizing model predictions and performance

In [None]:
import sys
sys.path.append('..')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm

from models.resnet import ResNet18
from utils.data_loader import get_data_loaders, CLASSES
from utils.transforms import get_train_transforms, get_test_transforms, inverse_normalize
from utils.metrics import MetricsTracker

## 1. Data Loading and Visualization

In [None]:
# Build the CIFAR-10 train / validation / test DataLoaders from ../data.
train_loader, val_loader, test_loader = get_data_loaders(
    data_dir='../data', batch_size=64, num_workers=2,
)

# Grab a single training batch to sanity-check tensor shapes.
images, labels = next(iter(train_loader))
print(f"Batch shape: {images.shape}\nLabels shape: {labels.shape}")

In [None]:
def show_images(images, labels, num_images=8):
    """Display a grid of images with their labels"""
    plt.figure(figsize=(15, 8))
    for i in range(num_images):
        plt.subplot(2, num_images//2, i+1)
        img = inverse_normalize(images[i])
        img = img.permute(1, 2, 0).numpy()
        img = np.clip(img, 0, 1)
        plt.imshow(img)
        plt.title(CLASSES[labels[i]])
        plt.axis('off')
    plt.tight_layout()
    plt.show()

# Preview a sample of the loaded batch together with its ground-truth labels.
show_images(images=images, labels=labels)

## 2. Data Augmentation Visualization

In [None]:
def show_augmentations(image, num_augments=5):
    """Plot one image next to ``num_augments`` augmented copies of it.

    The first panel is ``image`` after ``inverse_normalize``. Each following
    panel applies ``get_train_transforms()`` (presumably random) to the
    de-normalized image, then de-normalizes the result again for display.

    NOTE(review): ``image`` is taken from the training loader, so it may
    already be augmented once. The transform is also fed a tensor here; if
    ``get_train_transforms`` contains a PIL-only step (e.g. ``ToTensor``)
    this call will fail. Confirm against utils/transforms.py.
    """
    def to_display(tensor_img):
        # Channel-first normalized tensor -> channel-last array clipped to [0, 1].
        return np.clip(inverse_normalize(tensor_img).permute(1, 2, 0).numpy(), 0, 1)

    transform = get_train_transforms()
    panels = num_augments + 1
    plt.figure(figsize=(15, 3))

    # Left-most panel: the untouched input.
    plt.subplot(1, panels, 1)
    plt.imshow(to_display(image))
    plt.title('Original')
    plt.axis('off')

    # Remaining panels: one transform application each.
    for k in range(1, panels):
        plt.subplot(1, panels, k + 1)
        augmented = transform(inverse_normalize(image))
        plt.imshow(to_display(augmented))
        plt.title(f'Augmented {k}')
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Show augmentations for a single image
show_augmentations(images[0])

## 3. Model Training

In [None]:
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Number of training epochs. The cosine schedule's period is derived from it
# so the two cannot drift apart (the scheduler is stepped once per epoch).
num_epochs = 10

# Model, loss, optimizer and learning-rate schedule.
model = ResNet18().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)

# Running metrics for each split; train_epoch / validate reset them every epoch.
train_metrics = MetricsTracker()
val_metrics = MetricsTracker()

# Per-epoch curves appended to by the training loop and plotted afterwards.
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': []
}

In [None]:
def train_epoch(model, train_loader, optimizer, criterion, device, *, metrics=None):
    """Run one training epoch and return the tracker's summary metrics.

    Args:
        model: network to train; switched to training mode first.
        train_loader: iterable of ``(inputs, labels)`` batches.
        optimizer: optimizer stepped once per batch.
        criterion: loss function called as ``criterion(outputs, labels)``.
        device: device each batch is moved to before the forward pass.
        metrics: tracker exposing ``reset`` / ``update`` / ``get_metrics``.
            Defaults to the notebook-level ``train_metrics`` so existing
            calls behave as before.

    Returns:
        The value of ``metrics.get_metrics()``; the training loop reads its
        ``'loss'`` and ``'accuracy'`` entries.
    """
    if metrics is None:
        metrics = train_metrics
    model.train()
    metrics.reset()

    for inputs, labels in tqdm(train_loader, desc="Training", leave=False):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        metrics.update(outputs, labels, loss)

    return metrics.get_metrics()

def validate(model, val_loader, criterion, device, *, metrics=None):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Args:
        model: network to evaluate; switched to eval mode first.
        val_loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss function called as ``criterion(outputs, labels)``.
        device: device each batch is moved to before the forward pass.
        metrics: tracker exposing ``reset`` / ``update`` / ``get_metrics``.
            Defaults to the notebook-level ``val_metrics``. Pass a separate
            tracker to evaluate another split without overwriting the
            validation statistics.

    Returns:
        The value of ``metrics.get_metrics()``; the training loop reads its
        ``'loss'`` and ``'accuracy'`` entries.
    """
    if metrics is None:
        metrics = val_metrics
    model.eval()
    metrics.reset()

    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc="Validation", leave=False):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            metrics.update(outputs, labels, loss)

    return metrics.get_metrics()

In [None]:
import os

# Training loop.
# Run as many epochs as the cosine schedule spans (it is stepped once per
# epoch below), so one full half-cosine period covers the whole run.
num_epochs = scheduler.T_max
checkpoint_path = '../checkpoints/best_model.pth'
# torch.save fails when the target directory does not exist; create it up front.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
# Start at -inf rather than 0.0 so the first epoch always writes a checkpoint;
# the evaluation cell loads this file unconditionally.
best_val_acc = float('-inf')

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    # Train
    train_results = train_epoch(model, train_loader, optimizer, criterion, device)
    history['train_loss'].append(train_results['loss'])
    history['train_acc'].append(train_results['accuracy'])

    # Validate
    val_results = validate(model, val_loader, criterion, device)
    history['val_loss'].append(val_results['loss'])
    history['val_acc'].append(val_results['accuracy'])

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

    print(f"Train Loss: {train_results['loss']:.4f}, Accuracy: {train_results['accuracy']:.4f}")
    print(f"Val Loss: {val_results['loss']:.4f}, Accuracy: {val_results['accuracy']:.4f}")

    # Keep only the weights with the best validation accuracy seen so far.
    if val_results['accuracy'] > best_val_acc:
        best_val_acc = val_results['accuracy']
        torch.save({
            'epoch': epoch,  # zero-based index of the epoch that produced these weights
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # float() stores a plain Python number even if the tracker returns a
            # tensor or NumPy scalar (TODO confirm what get_metrics() returns).
            'val_acc': float(best_val_acc),
        }, checkpoint_path)

## 4. Training Visualization

In [None]:
# Plot per-epoch training history: loss on the left, accuracy on the right.
# Epochs are numbered from 1 so the x-axis matches the "Epoch k/N" training log.
plt.figure(figsize=(12, 4))

for panel, (key, name) in enumerate([('loss', 'Loss'), ('acc', 'Accuracy')], start=1):
    plt.subplot(1, 2, panel)
    for split, split_label in (('train', 'Train'), ('val', 'Validation')):
        series = history[f'{split}_{key}']
        # Each series gets its own x range, so an interrupted epoch can't misalign the curves.
        plt.plot(range(1, len(series) + 1), series, label=split_label)
    plt.title(name)
    plt.xlabel('Epoch')
    plt.ylabel(name)
    plt.legend()

plt.tight_layout()
plt.show()

## 5. Model Evaluation

In [None]:
# Rebuild the architecture and restore the best weights saved during training.
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU machine can still be loaded on a CPU-only one.
best_model = ResNet18().to(device)
checkpoint = torch.load('../checkpoints/best_model.pth', map_location=device)
best_model.load_state_dict(checkpoint['model_state_dict'])

# Evaluate on the held-out test set with a dedicated tracker.
test_metrics = MetricsTracker()
best_model.eval()

with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc="Testing"):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = best_model(inputs)
        loss = criterion(outputs, labels)
        test_metrics.update(outputs, labels, loss)

test_results = test_metrics.get_metrics()
print(f"\nTest Loss: {test_results['loss']:.4f}, Accuracy: {test_results['accuracy']:.4f}")

# Confusion matrix and per-class report, presumably built from what
# test_metrics.update() accumulated above.
test_metrics.plot_confusion_matrix()
test_metrics.print_classification_report()

## 6. Model Predictions Visualization

In [None]:
def show_predictions(model, loader, num_images=8, device=None):
    """Plot the model's predictions for the first batch of ``loader``.

    Each panel shows a de-normalized image titled with its predicted and true
    class; the title is green for a correct prediction and red otherwise.

    Args:
        model: classifier whose outputs are reduced with ``torch.max(..., 1)``
            to get predicted class indices.
        loader: iterable yielding ``(images, labels)`` batches; only the first
            batch is used.
        num_images: maximum number of images to display; capped at the batch size.
        device: device to run inference on. Defaults to the device of the
            model's first parameter (CPU if the model has none).
    """
    if device is None:
        # Run where the model's weights live instead of relying on a global.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    images, labels = next(iter(loader))

    with torch.no_grad():
        outputs = model(images.to(device))
        _, preds = torch.max(outputs, 1)
    # Compare and plot on the CPU.
    preds = preds.cpu()
    labels = labels.cpu()

    count = min(num_images, len(images))
    if count <= 0:
        return
    # Two rows (one for a single image), enough columns for odd counts.
    rows = 2 if count > 1 else 1
    cols = (count + rows - 1) // rows

    plt.figure(figsize=(15, 8))
    for i in range(count):
        plt.subplot(rows, cols, i + 1)
        img = inverse_normalize(images[i].cpu())
        # imshow expects channel-last values in [0, 1]; clip de-normalization overshoot.
        img = np.clip(img.permute(1, 2, 0).numpy(), 0, 1)
        plt.imshow(img)

        pred, true = int(preds[i]), int(labels[i])
        color = 'green' if pred == true else 'red'
        plt.title(f'Pred: {CLASSES[pred]}\nTrue: {CLASSES[true]}',
                  color=color)
        plt.axis('off')
    plt.tight_layout()
    plt.show()


# Show predictions on test set
show_predictions(best_model, test_loader)