# CNN Model Training Example

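This notebook trains a simple CNN on CIFAR-10 using the project's framework in `src/`. It loads the data with augmentation and trains with Adam, a cosine learning-rate schedule and early stopping. It then evaluates the model with test accuracy and a confusion matrix.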

In [None]:
# Import necessary modules
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

# Import custom modules
import sys
sys.path.append('..')
from src.models.cnn_model import SimpleCNN, VGGStyleCNN, create_model
from src.training.trainer import Trainer
from src.training.metrics import accuracy, confusion_matrix
from src.utils.helpers import set_seed, get_device, get_optimizer, get_scheduler
from src.utils.visualization import plot_training_history, plot_confusion_matrix

## 1. Setup

In [None]:
# Configuration: every tunable value of this run lives here.
SEED = 42
BATCH_SIZE = 64
EPOCHS = 50
LEARNING_RATE = 0.001
NUM_CLASSES = 10

# Seed before anything stochastic runs, for reproducibility.
# NOTE(review): which RNGs are covered (torch/numpy/random) is decided by
# src.utils.helpers.set_seed -- confirm there.
set_seed(SEED)

# Select the compute device used by the rest of the notebook.
device = get_device()

## 2. Load Data

In [None]:
# Per-channel CIFAR-10 mean/std used to standardize inputs in both pipelines.
mean = (0.4914, 0.4822, 0.4465)
std = (0.2470, 0.2435, 0.2616)

# Training pipeline: random 32x32 crops of a 4-pixel-padded image and random
# horizontal flips, followed by tensor conversion and normalization.
augmentation_steps = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
]
train_transform = transforms.Compose(
    augmentation_steps + [transforms.ToTensor(), transforms.Normalize(mean, std)]
)

# Evaluation pipeline: identical conversion/normalization, no augmentation.
test_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean, std)]
)

# Both CIFAR-10 splits share a root directory (downloaded on first use);
# only the split flag and the transform differ.
dataset_root = '../data'
train_dataset = datasets.CIFAR10(root=dataset_root, train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root=dataset_root, train=False, download=True, transform=test_transform)

# Shuffle only the training split so evaluation order is deterministic.
loader_options = {'batch_size': BATCH_SIZE, 'num_workers': 2}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

## 3. Create Model

In [None]:
# Build the 'simple' architecture for 3-channel 32x32 CIFAR-10 images.
model_config = {
    'model_type': 'simple',
    'num_classes': NUM_CLASSES,
    'input_channels': 3,
    'input_size': (32, 32),
}
model = create_model(**model_config)

# Summarize the model: class name and parameter count (thousands-separated).
print(f"Model: {type(model).__name__}")
print(f"Trainable parameters: {model.count_parameters():,}")

## 4. Training Setup

In [None]:
# Standard cross-entropy loss on the model's class logits.
criterion = nn.CrossEntropyLoss()

# Adam at LEARNING_RATE, paired with the project's 'cosine' learning-rate
# schedule configured for the full EPOCHS budget.
optimizer = get_optimizer(model=model, optimizer_name='adam', learning_rate=LEARNING_RATE)
scheduler = get_scheduler(optimizer=optimizer, scheduler_name='cosine', epochs=EPOCHS)

# Bundle everything the fit loop needs into a single Trainer.
trainer = Trainer(
    model=model,
    device=device,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
)

## 5. Train Model

In [None]:
# Train the model.
#
# Early stopping uses `val_loader` to decide when to stop. Passing the test
# set there would let test data influence model selection, and the
# "Test Accuracy" reported in section 7 would be optimistically biased. So
# hold out a fixed slice of the *training* images for validation and keep the
# test set unseen until final evaluation.
# NOTE(review): this means the model is fit on len(train_dataset) - VAL_SIZE
# images; `train_loader` from section 2 is not used for fitting any more.
VAL_SIZE = 5000
EARLY_STOPPING_PATIENCE = 10

# Fixed-seed permutation so the train/validation split is reproducible.
val_split_generator = torch.Generator().manual_seed(42)
shuffled_indices = torch.randperm(len(train_dataset), generator=val_split_generator).tolist()
val_indices = shuffled_indices[:VAL_SIZE]
fit_indices = shuffled_indices[VAL_SIZE:]

# Second view of the training images with the deterministic test-time
# transform, so validation is not scored on randomly cropped/flipped inputs.
train_dataset_eval = datasets.CIFAR10(
    root='../data',
    train=True,
    download=True,
    transform=test_transform,
)

fit_loader = DataLoader(Subset(train_dataset, fit_indices), batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(Subset(train_dataset_eval, val_indices), batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# NOTE(review): confirm whether Trainer restores the best-validation weights
# when early stopping triggers; if not, `model` holds the last-epoch weights.
history = trainer.train(
    train_loader=fit_loader,
    val_loader=val_loader,
    epochs=EPOCHS,
    early_stopping_patience=EARLY_STOPPING_PATIENCE,
)

## 6. Visualize Results

In [None]:
# Plot the training history returned by trainer.train() in section 5.
# NOTE(review): which metrics appear depends on the keys Trainer records in
# `history` -- see src.utils.visualization.plot_training_history.
plot_training_history(history)

## 7. Evaluate Model

In [None]:
# Evaluate on the test set: switch the model to eval mode (affects layers such
# as dropout/batch-norm, if present) and disable autograd while predicting.
model.eval()
prediction_batches = []
target_batches = []

with torch.no_grad():
    for batch_images, batch_labels in test_loader:
        logits = model(batch_images.to(device))
        # Predicted class = index of the largest logit; kept on CPU so it can
        # be concatenated with the (CPU) labels below.
        prediction_batches.append(logits.argmax(dim=1).cpu())
        target_batches.append(batch_labels)

all_predictions = torch.cat(prediction_batches)
all_targets = torch.cat(target_batches)

# Calculate accuracy
test_acc = accuracy(all_predictions, all_targets)
print(f"Test Accuracy: {test_acc:.2f}%")

In [None]:
# Plot confusion matrix.
# Take the label names from the dataset itself so their order is guaranteed to
# match the integer targets, rather than maintaining a hand-copied list.
class_names = test_dataset.classes
assert len(class_names) == NUM_CLASSES, "class_names must cover every class index"

# Use the configured NUM_CLASSES instead of a hardcoded literal so the matrix
# stays consistent with the model's output size.
cm = confusion_matrix(all_predictions, all_targets, num_classes=NUM_CLASSES)
plot_confusion_matrix(cm, class_names=class_names)