In [None]:
# train and evaluate the cnn model
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from typing import Tuple, Dict
import matplotlib.pyplot as plt

def get_cifar100_data(
    *,
    root: str = "./data",
    batch_size: int = 128,
    num_workers: int = 2,
) -> Tuple[DataLoader, DataLoader]:
    """
    Load and prepare CIFAR-100 dataset

    Calling this function with no arguments uses the same settings as
    before: data under "./data", batches of 128 and two loader workers.

    Args:
        root: Directory passed to torchvision as the dataset root; the
            dataset is requested with download=True
        batch_size: Batch size used for both the train and test loaders
        num_workers: Number of worker processes used by each loader

    Returns:
        Tuple containing train and test dataloaders
    """
    # One Normalize instance (per-channel mean, per-channel std) shared by
    # both pipelines so train and test inputs can never be normalized with
    # diverging statistics.
    normalize = transforms.Normalize(
        (0.5071, 0.4867, 0.4408),
        (0.2675, 0.2565, 0.2761),
    )

    # Training pipeline: random crop (with 4px padding) and horizontal flip
    # augmentation, then tensor conversion and normalization
    transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    # Test pipeline: no augmentation, only tensor conversion + normalization
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # Load datasets
    trainset = torchvision.datasets.CIFAR100(
        root=root,
        train=True,
        download=True,
        transform=transform,
    )

    testset = torchvision.datasets.CIFAR100(
        root=root,
        train=False,
        download=True,
        transform=test_transform,
    )

    # Create data loaders; only the training loader is shuffled
    trainloader = DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )

    testloader = DataLoader(
        testset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )

    return trainloader, testloader

def train_epoch(
    model: nn.Module,
    trainloader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    device: torch.device
) -> float:
    """
    Run one optimisation pass over the training data

    The model is switched to training mode, then every batch yielded by
    the loader is moved to ``device``, pushed through the model, and used
    for a single optimizer step.

    Args:
        model: CNN model to train
        trainloader: Training data loader yielding (inputs, targets) pairs
        criterion: Loss function applied to (model output, targets)
        optimizer: Optimizer stepped once per batch
        device: Device the batches are moved to before the forward pass

    Returns:
        Sum of the per-batch loss values divided by the number of batches
        reported by ``len(trainloader)``
    """
    model.train()
    loss_total = 0.0

    for batch in trainloader:
        features, labels = batch
        features = features.to(device)
        labels = labels.to(device)

        # Clear gradients left over from the previous step before the
        # new backward pass accumulates into them
        optimizer.zero_grad()
        batch_loss = criterion(model(features), labels)
        batch_loss.backward()
        optimizer.step()

        # .item() extracts a plain Python float from the loss tensor
        loss_total += batch_loss.item()

    num_batches = len(trainloader)
    return loss_total / num_batches



# Hyperparameters and device selection (GPU when CUDA is available)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
num_epochs = 200
learning_rate = 0.1
momentum = 0.9
weight_decay = 5e-4

# Model, loss, optimizer and learning-rate schedule.
# NOTE: CNN is not defined in this section; it must already be in scope.
model = CNN(num_classes=100)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay
)
# Cosine annealing whose period (T_max) spans the whole training run
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

# Data loaders for the train and test splits
trainloader, testloader = get_cifar100_data()

# Per-epoch history: one train loss and one test-metrics entry per epoch
train_losses = []
test_metrics = []

print("Starting training...")
for epoch in range(num_epochs):
    # One pass over the training data
    train_loss = train_epoch(model, trainloader, criterion, optimizer, device)
    train_losses.append(train_loss)

    # Test-set evaluation after every epoch.
    # NOTE: evaluate is not defined in this section; its result is indexed
    # with the keys "loss" and "accuracy" below.
    metrics = evaluate(model, testloader, device)
    test_metrics.append(metrics)

    # Advance the learning-rate schedule once per epoch
    scheduler.step()

    # Only report progress on every tenth epoch
    if (epoch + 1) % 10 != 0:
        continue
    print("\n".join([
        f"Epoch [{epoch+1}/{num_epochs}]",
        f"Train Loss: {train_loss:.4f}",
        f"Test Loss: {metrics['loss']:.4f}",
        f"Test Accuracy: {metrics['accuracy']:.2f}%",
        "-" * 50,
    ]))

# Pull the individual curves out of the recorded test metrics
test_loss_curve = [entry["loss"] for entry in test_metrics]
test_accuracy_curve = [entry["accuracy"] for entry in test_metrics]

# Left panel: train vs. test loss; right panel: test accuracy
plt.figure(figsize=(10, 5))

plt.subplot(1, 2, 1)
plt.plot(train_losses, label="Train Loss")
plt.plot(test_loss_curve, label="Test Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(test_accuracy_curve, label="Test Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy (%)")
plt.legend()

plt.show()



