# Optimized+augmented PyTorch RGB CNN

In this notebook, we train the best CNN architecture found by the Optuna hyperparameter search in notebook 04 on the CIFAR-10 dataset, adding image augmentation to improve generalization.

## Notebook set-up

### Imports

In [None]:
# Standard library imports
from pathlib import Path

# Third party imports
import matplotlib.pyplot as plt
import numpy as np
import optuna
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

# Package imports
from cifar10_tools.pytorch.evaluation import evaluate_model
from cifar10_tools.pytorch.plotting import (
    plot_sample_images, plot_learning_curves, 
    plot_confusion_matrix, plot_class_probability_distributions,
    plot_evaluation_curves
)
from cifar10_tools.pytorch.training import train_model

# Keep Optuna quiet: only warnings and errors are reported
optuna.logging.set_verbosity(optuna.logging.WARNING)

# Seed the global torch and numpy RNGs up front so random draws
# (weight init, shuffling, augmentation) start from a fixed state
torch.manual_seed(315)
np.random.seed(315)

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f'Using device: {device}')

### Hyperparameters

In [None]:
# NOTE(review): 10,000 images per batch is very large for CIFAR-10 -- only a
# few optimizer steps per epoch on the 80% training split. Confirm this is
# intentional and matches the batch size the Optuna learning rate was tuned with.
batch_size = 10000
epochs = 100        # extended schedule to make use of augmentation
print_every = 20    # report training progress every `print_every` epochs

# CIFAR-10 label names, indexed by integer class id
class_names = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]
num_classes = len(class_names)

In [None]:
# Load the best hyperparameters found by the Optuna study from notebook 04
storage_path = Path('../data/pytorch/cnn_optimization.db')

# Fail fast with a clear message. Pointing SQLAlchemy at a missing SQLite file
# silently creates an empty database there, and load_study then fails with a
# less obvious "study does not exist" error.
if not storage_path.exists():
    raise FileNotFoundError(
        f'Optuna study database not found at {storage_path.resolve()}; '
        'run the Optuna optimization notebook first.'
    )

storage_url = f'sqlite:///{storage_path}'

study = optuna.load_study(
    study_name='cnn_optimization',
    storage=storage_url
)

best_params = study.best_trial.params

print('Loaded best hyperparameters from Optuna study:')

for key, value in best_params.items():
    print(f'  {key}: {value}')

# NOTE(review): assumes the study objective is validation accuracy in percent
print(f'\nBest validation accuracy from optimization: {study.best_trial.value:.2f}%')


## 1. Load and preprocess CIFAR-10 data with augmentation

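CIFAR-10 contains 32x32 color images (3 channels) across 10 classes. We keep all three RGB channels and apply random augmentation to the training split only. Validation and test images use a deterministic transform, so their metrics are not affected by random perturbations.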

### 1.1. Define transforms

In [None]:
# Shared final steps: ToTensor yields values in [0, 1]; normalizing every
# channel with mean 0.5 and std 0.5 rescales them to [-1, 1].
to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

# Training pipeline: random geometric and photometric augmentation,
# re-drawn every time an image is fetched, then tensor conversion.
train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    ]
    + to_normalized_tensor
)

# Evaluation pipeline for validation/test: deterministic, no augmentation
eval_transform = transforms.Compose([*to_normalized_tensor])

print('Training augmentations:')
for description in (
    'Random horizontal flip (p=0.5)',
    'Random rotation (±15°)',
    'Random translation (±10%)',
    'Color jitter (brightness, contrast, saturation)',
):
    print(f'  - {description}')

### 1.2. Load datasets

In [None]:
# Local cache directory for the CIFAR-10 download
data_dir = Path('../data/pytorch/cifar10')
data_dir.mkdir(parents=True, exist_ok=True)

# The CIFAR-10 training files are opened twice: once with the augmenting
# transform (source of the train subset) and once with the deterministic
# transform (source of the validation subset), so validation is never augmented.
train_dataset_full = datasets.CIFAR10(
    root=data_dir, train=True, download=True, transform=train_transform
)
# download=False is safe here: the call above already fetched the files
val_dataset_base = datasets.CIFAR10(
    root=data_dir, train=True, download=False, transform=eval_transform
)
test_dataset = datasets.CIFAR10(
    root=data_dir, train=False, download=True, transform=eval_transform
)

# Seeded 80/20 split of indices, so the same images land in train/val every run
n_total = len(train_dataset_full)
n_train = int(0.8 * n_total)
n_val = n_total - n_train

generator = torch.Generator().manual_seed(315)
train_indices, val_indices = random_split(
    range(n_total), [n_train, n_val], generator=generator
)

# Same underlying images, different transforms per subset
train_dataset = torch.utils.data.Subset(train_dataset_full, train_indices.indices)
val_dataset = torch.utils.data.Subset(val_dataset_base, val_indices.indices)

print(f'Training samples: {len(train_dataset)} (with augmentation)')
print(f'Validation samples: {len(val_dataset)} (no augmentation)')
print(f'Test samples: {len(test_dataset)} (no augmentation)')
# Indexing runs the (random) training transform; only the shape is reported
print(f'Image shape: {train_dataset_full[0][0].shape}')
print(f'Number of classes: {len(class_names)}')

### 1.3. Visualize sample images (with augmentation)

In [None]:
# Plot a 2x5 grid of training images. train_dataset applies the random
# training transform on every access, so these are augmented views (and
# fetching them draws from the global torch RNG). Which 10 images appear is
# decided by plot_sample_images -- presumably the first 10; confirm there.
fig, axes = plot_sample_images(train_dataset, class_names, nrows=2, ncols=5)
# y=1.02 (figure coordinates) lifts the title just above the subplot grid
fig.suptitle('Augmented Training Images', y=1.02)
plt.show()

### 1.4. Create `DataLoader()` objects

In [None]:
# DataLoaders keep the data on the CPU; each batch is moved to `device`
# inside the training/evaluation loops.
# Pinned (page-locked) host memory only speeds up copies to a CUDA device, so
# request it only when training on the GPU (on CPU it is wasted work and
# PyTorch warns about it).
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=2,
    pin_memory=device.type == 'cuda',
)

# Shuffle only the training data; val/test order stays fixed so predictions
# collected in loader order line up with their labels.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

print(f'Training batches: {len(train_loader)}')
print(f'Validation batches: {len(val_loader)}')
print(f'Test batches: {len(test_loader)}')

## 2. Build optimized CNN using best hyperparameters

We create a CNN using the best hyperparameters found during Optuna optimization, then train it with data augmentation.

### 2.1. Define model builder

In [None]:
def create_cnn(
    n_conv_blocks: int,
    initial_filters: int,
    fc_units_1: int,
    fc_units_2: int,
    dropout_rate: float,
    use_batch_norm: bool,
    num_classes: int = 10,
    in_channels: int = 3,
    image_size: int = 32,
) -> nn.Sequential:
    '''Create a CNN with configurable architecture for image classification.

    Each convolutional block is two 3x3 convolutions (padding 1, each
    optionally followed by batch norm) with ReLU, then 2x2 max pooling and
    dropout. The filter count starts at ``initial_filters`` and doubles every
    block while the spatial size halves. A three-layer fully connected head
    with ReLU and dropout produces the class logits.

    Args:
        n_conv_blocks: Number of convolutional blocks (>= 1; 1-4 for 32x32 input
            in the Optuna search)
        initial_filters: Number of filters in first conv layer (doubles each block)
        fc_units_1: Number of units in first fully connected layer
        fc_units_2: Number of units in second fully connected layer
        dropout_rate: Dropout probability
        use_batch_norm: Whether to use batch normalization
        num_classes: Number of output logits (default 10, as for CIFAR-10)
        in_channels: Number of input image channels (default 3 for RGB)
        image_size: Height/width of the square input images (default 32)

    Returns:
        nn.Sequential model mapping (N, in_channels, image_size, image_size)
        inputs to (N, num_classes) logits.

    Raises:
        ValueError: If n_conv_blocks < 1, or if so many blocks are requested
            that pooling would shrink the feature map to zero size.
    '''
    if n_conv_blocks < 1:
        raise ValueError(f'n_conv_blocks must be >= 1, got {n_conv_blocks}')

    # Each block's 2x2 max pool floor-halves the spatial size
    final_size = image_size // (2 ** n_conv_blocks)
    if final_size < 1:
        raise ValueError(
            f'{n_conv_blocks} conv blocks would reduce a {image_size}x{image_size} '
            'input to an empty feature map'
        )

    layers = []
    channels = in_channels

    for block_idx in range(n_conv_blocks):
        out_channels = initial_filters * (2 ** block_idx)

        # Two 3x3 convs (padding=1 preserves spatial size), each optionally
        # batch-normalized before its ReLU
        for conv_in in (channels, out_channels):
            layers.append(nn.Conv2d(conv_in, out_channels, kernel_size=3, padding=1))

            if use_batch_norm:
                layers.append(nn.BatchNorm2d(out_channels))

            layers.append(nn.ReLU())

        # Downsample by 2, then regularize
        layers.append(nn.MaxPool2d(2, 2))
        layers.append(nn.Dropout(dropout_rate))

        channels = out_channels

    # After the loop `channels` is the last block's filter count
    flattened_size = channels * final_size * final_size

    # Classifier (3 fully connected layers)
    layers.extend([
        nn.Flatten(),
        nn.Linear(flattened_size, fc_units_1),
        nn.ReLU(),
        nn.Dropout(dropout_rate),
        nn.Linear(fc_units_1, fc_units_2),
        nn.ReLU(),
        nn.Dropout(dropout_rate),
        nn.Linear(fc_units_2, num_classes),
    ])

    return nn.Sequential(*layers)

### 2.2. Create model with best hyperparameters

In [None]:
# Architecture hyperparameters, looked up by name in the Optuna best trial
architecture_keys = (
    'n_conv_blocks', 'initial_filters', 'fc_units_1',
    'fc_units_2', 'dropout_rate', 'use_batch_norm',
)
model = create_cnn(**{key: best_params[key] for key in architecture_keys}).to(device)

# Count parameters the optimizer will update. PyTorch layers require grad by
# default, so for this model the count should equal the total.
trainable_params = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)

print(model)
print(f'\nTotal parameters: {trainable_params:,}')

### 2.3. Define loss function and optimizer

In [None]:
# The model's last layer is a plain Linear, so the loss consumes raw logits
criterion = nn.CrossEntropyLoss()

optimizer_name = best_params['optimizer']

# Rebuild the optimizer chosen by Optuna. Unknown names raise instead of
# silently falling back to RMSprop.
if optimizer_name == 'Adam':
    optimizer = optim.Adam(model.parameters(), lr=best_params['learning_rate'])

elif optimizer_name == 'SGD':
    # Use momentum 0.9 when the trial did not record sgd_momentum
    optimizer = optim.SGD(
        model.parameters(),
        lr=best_params['learning_rate'],
        momentum=best_params.get('sgd_momentum', 0.9)
    )

elif optimizer_name == 'RMSprop':
    optimizer = optim.RMSprop(model.parameters(), lr=best_params['learning_rate'])

else:
    raise ValueError(
        f"Unsupported optimizer {optimizer_name!r} in Optuna best_params; "
        "expected 'Adam', 'SGD' or 'RMSprop'."
    )

print(f"Optimizer: {optimizer_name}")

### 2.4. Train model

In [None]:
%%time

history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    epochs=epochs,
    print_every=print_every,
    device=device  # Enable per-batch GPU transfer
)

### 2.5. Learning curves

In [None]:
# Plot the training history returned by train_model (presumably per-epoch
# train/validation loss and accuracy -- see plot_learning_curves)
fig, axes = plot_learning_curves(history)
plt.show()

## 3. Evaluate model on test set

### 3.1. Calculate test accuracy

In [None]:
# Evaluate on the test set, moving each batch to `device` as it is consumed.
# Predictions and labels are collected in loader order (test_loader does not
# shuffle), so they stay aligned for the per-class metrics and plots below.
model.eval()  # dropout off; batch norm uses its running statistics
predictions = []
true_labels = []
test_correct = 0
test_total = 0

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        # Predicted class = index of the largest logit. argmax replaces the
        # legacy torch.max(outputs.data, 1); .data is unnecessary under no_grad.
        predicted = model(images).argmax(dim=1)

        test_total += labels.size(0)
        test_correct += (predicted == labels).sum().item()

        predictions.extend(predicted.cpu().numpy())
        true_labels.extend(labels.cpu().numpy())

test_accuracy = 100 * test_correct / test_total
print(f'Test accuracy: {test_accuracy:.2f}%')

### 3.2. Per-class accuracy

In [None]:
# Per-class accuracy: share of test images of each true class that were
# predicted correctly
class_correct = {name: 0 for name in class_names}
class_total = {name: 0 for name in class_names}

for pred, true in zip(predictions, true_labels):
    class_name = class_names[true]
    class_total[class_name] += 1

    if pred == true:
        class_correct[class_name] += 1

print('Per-class accuracy:')
print('-' * 30)

for name in class_names:
    # A class with no test samples would otherwise raise ZeroDivisionError
    if class_total[name] == 0:
        print(f'{name:12s}: n/a (no test samples)')
        continue

    acc = 100 * class_correct[name] / class_total[name]
    print(f'{name:12s}: {acc:.2f}%')

### 3.3. Confusion matrix

In [None]:
# Confusion matrix of true vs. predicted test labels (both in test_loader order)
fig, ax = plot_confusion_matrix(true_labels, predictions, class_names)
plt.show()

### 3.4. Predicted class probability distributions

In [None]:
# Softmax class probabilities for every test image. test_loader does not
# shuffle, so row i corresponds to true_labels[i] from the accuracy pass.
model.eval()

with torch.no_grad():
    batch_probs = [
        torch.softmax(model(images.to(device, non_blocking=True)), dim=1).cpu().numpy()
        for images, _ in test_loader
    ]

# Stack per-batch arrays into one (n_test_samples, num_classes) array
all_probs = np.concatenate(batch_probs, axis=0)

# Plot probability distributions
fig, axes = plot_class_probability_distributions(all_probs, class_names)
plt.show()

### 3.5. Evaluation curves

In [None]:
# Two-panel evaluation curves from the true labels and softmax probabilities.
# The exact curve types are defined by plot_evaluation_curves (presumably
# ROC and precision-recall) -- confirm in cifar10_tools.
fig, (ax1, ax2) = plot_evaluation_curves(true_labels, all_probs, class_names)
plt.show()

## 4. Save model

In [None]:
# Create models directory if it doesn't exist
models_dir = Path('../models/pytorch')
models_dir.mkdir(parents=True, exist_ok=True)

model_path = models_dir / 'augmented_cnn.pth'

# Save a checkpoint dict. The Optuna hyperparameters are included because a
# state dict can only be loaded into a network rebuilt with the same
# architecture, e.g. create_cnn(**...) with these values.
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'hyperparameters': dict(best_params),
    'test_accuracy': test_accuracy,
    'history': history
}, model_path)

print(f'Model saved to: {model_path}')
print(f'Test accuracy: {test_accuracy:.2f}%')