# TA2 - CNNs on FashionMNIST (Resumable + Traceable Version)

This notebook is the **resumable** version of the training run. It uses the
- same model architecture
- same Optuna tuning setup
- same training loop

as the non-resumable version, and adds checkpointing, trial traceability, and resumable Optuna storage.

Use this version when long runs may be interrupted (cloud shutdown, runtime limits, etc.).


## Why a Single Validation Split (not K-fold CV)

FashionMNIST already provides:
- 60,000 training images
- 10,000 test images

For CNN tuning with Optuna, K-fold cross-validation multiplies training cost by *K* and is usually unnecessary on a dataset this large. We use a **single stratified validation split** (90% train / 10% val), which is rigorous and much more compute-efficient.


In [None]:
import os
import json
import time
import random
import hashlib
import copy
from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report

import optuna


In [None]:
# Seed reused for RNG seeding, the train/val split and the Optuna sampler.
SEED = 42
# Root folder handed to torchvision for the FashionMNIST files.
DATA_ROOT = Path('session_2/data')
# Display names for the ten classes, indexed by class id.
CLASS_NAMES = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
NUM_CLASSES = len(CLASS_NAMES)
# Fraction of the training set held out as the validation split.
VAL_RATIO = 0.10

# Epoch budgets: per Optuna trial, and for the final training run.
MAX_EPOCHS = 15
FINAL_MAX_EPOCHS = 20
# Consecutive epochs without validation improvement tolerated before stopping.
EARLY_STOPPING_PATIENCE = 4
# Trial count passed to study.optimize.
N_TRIALS = 20
OPTUNA_TIMEOUT = None  # seconds (None = no timeout)

# DataLoader settings: at most 4 workers; pinned memory only when CUDA is available.
NUM_WORKERS = min(4, os.cpu_count() or 1)
PIN_MEMORY = torch.cuda.is_available()
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Mixed precision (autocast + GradScaler) is used only on CUDA.
AMP_ENABLED = DEVICE.type == 'cuda'

print(f'Device: {DEVICE}')
print(f'AMP enabled: {AMP_ENABLED}')


In [None]:
# Label for this notebook variant; printed below for reference.
RUN_MODE = 'resumable_with_checkpoints'
# When False, trial and final training run without a checkpoint directory.
ENABLE_CHECKPOINTING = True

# All resumable state (split indices, Optuna DB, checkpoints) lives under this folder.
ARTIFACT_DIR = Path('session_2/artifacts/ta2_cnns_resumable')
ARTIFACT_DIR.mkdir(parents=True, exist_ok=True)
# Cached train/val indices so every session uses the same split.
SPLIT_FILE = ARTIFACT_DIR / 'train_val_split.npz'

# Optuna study identity and its SQLite storage URL.
STUDY_NAME = 'ta2_cnns_resumable'
STUDY_DB = ARTIFACT_DIR / 'optuna_fashion_mnist.db'
STUDY_STORAGE = f'sqlite:///{STUDY_DB}'

print('Run mode:', RUN_MODE)
print('Checkpointing:', ENABLE_CHECKPOINTING)
print('Artifacts directory:', ARTIFACT_DIR.resolve())
print('Optuna DB:', STUDY_DB.resolve())


## Resume Notes (for this notebook)

If execution stops unexpectedly:
1. Re-open this notebook.
2. Re-run from the top with the same `ARTIFACT_DIR` and `STUDY_NAME`.
3. Optuna will continue from the SQLite study.
4. Trial/final training checkpoints in `session_2/artifacts/ta2_cnns_resumable/` let training resume with minimal loss of progress.


In [None]:
def seed_everything(seed: int = 42) -> None:
    """Seed the Python, NumPy and PyTorch RNGs and configure cuDNN autotuning.

    Args:
        seed: Value passed to every random number generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    if not torch.backends.cudnn.is_available():
        return
    # cuDNN benchmark mode is switched on only when running on a CUDA device.
    torch.backends.cudnn.benchmark = DEVICE.type == 'cuda'


seed_everything(SEED)


## Data Loading and Raw Data Exploration

We first inspect the raw grayscale images, then build dataloaders with:
- augmentation on the train split only
- normalization on train/val/test
- stratified split to keep class balance


In [None]:
# Load both official splits with transform=None (downloading if needed) for raw-data inspection.
raw_train = datasets.FashionMNIST(root=DATA_ROOT, train=True, download=True, transform=None)
raw_test = datasets.FashionMNIST(root=DATA_ROOT, train=False, download=True, transform=None)

print(f'Train size: {len(raw_train):,}')
print(f'Test size: {len(raw_test):,}')
# The first element of a sample is the image; its .size is printed as the image shape.
print(f'Image shape: {raw_train[0][0].size} (grayscale)')


In [None]:
def plot_random_raw_samples(dataset, class_names, n: int = 16, seed: int = 42) -> None:
    """Plot ``n`` randomly chosen samples of ``dataset`` in a near-square grid.

    Args:
        dataset: Indexable dataset yielding ``(image, label)`` pairs.
        class_names: Sequence mapping a label id to its display name.
        n: Number of samples to draw; clamped to ``len(dataset)``.
        seed: Seed for the NumPy generator that picks the sample indices.
    """
    # Sampling without replacement cannot draw more items than the dataset holds.
    n = min(n, len(dataset))
    if n <= 0:
        return

    rng = np.random.default_rng(seed)
    indices = rng.choice(len(dataset), size=n, replace=False)

    side = int(np.ceil(np.sqrt(n)))
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid (n == 1),
    # so the flatten() below always works.
    _, axes = plt.subplots(side, side, figsize=(10, 10), squeeze=False)
    axes = axes.flatten()

    # Hide every cell first; cells beyond n stay blank.
    for ax in axes:
        ax.axis('off')

    for ax, idx in zip(axes, indices):
        image, label = dataset[int(idx)]
        ax.imshow(np.array(image), cmap='gray')
        ax.set_title(class_names[label], fontsize=9)

    plt.suptitle('Random raw samples from FashionMNIST', fontsize=14)
    plt.tight_layout()
    plt.show()

plot_random_raw_samples(raw_train, CLASS_NAMES, n=16, seed=SEED)


In [None]:
def plot_one_per_class(dataset, class_names, n_cols: int = 5) -> None:
    """Plot the first sample found for each class, one panel per class.

    Args:
        dataset: Iterable and indexable dataset yielding ``(image, label)`` pairs.
        class_names: Sequence mapping a label id to its display name.
        n_cols: Number of panels per row; the row count is derived from
            ``len(class_names)`` (5 columns gives the 2x5 grid for ten classes).
    """
    n_classes = len(class_names)
    if n_classes == 0:
        return

    # Scan until every class has been seen once.
    first_indices: dict[int, int] = {}
    for idx, (_, label) in enumerate(dataset):
        label = int(label)
        if label not in first_indices:
            first_indices[label] = idx
        if len(first_indices) == n_classes:
            break

    n_cols = max(1, min(n_cols, n_classes))
    n_rows = -(-n_classes // n_cols)  # ceiling division
    # squeeze=False keeps `axes` 2-D for any grid size so flatten() is always valid.
    _, axes = plt.subplots(n_rows, n_cols, figsize=(14, 3 * n_rows), squeeze=False)
    axes = axes.flatten()

    for ax in axes:
        ax.axis('off')

    for class_id, class_name in enumerate(class_names):
        axes[class_id].set_title(f'{class_id}: {class_name}')
        if class_id not in first_indices:
            # No sample of this class in the dataset: leave the panel empty.
            continue
        image, _ = dataset[first_indices[class_id]]
        axes[class_id].imshow(np.array(image), cmap='gray')

    plt.suptitle('At least one raw example from each class', fontsize=14)
    plt.tight_layout()
    plt.show()

plot_one_per_class(raw_train, CLASS_NAMES)


In [None]:
# Training pipeline: random flip and padded random crop, then tensor conversion and normalization.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomCrop(28, padding=2),
    transforms.ToTensor(),
    transforms.Normalize((0.2860,), (0.3530,)),
])

# Evaluation pipeline: no augmentation, same normalization constants as training.
eval_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.2860,), (0.3530,)),
])


def get_or_create_split_indices(
    targets: np.ndarray,
    val_ratio: float,
    seed: int,
    split_file: Path | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    if split_file is not None and split_file.exists():
        data = np.load(split_file)
        return data['train_idx'], data['val_idx']

    indices = np.arange(len(targets))
    train_idx, val_idx = train_test_split(
        indices,
        test_size=val_ratio,
        random_state=seed,
        stratify=targets,
        shuffle=True,
    )
    train_idx = np.sort(train_idx)
    val_idx = np.sort(val_idx)

    if split_file is not None:
        split_file.parent.mkdir(parents=True, exist_ok=True)
        np.savez(split_file, train_idx=train_idx, val_idx=val_idx)

    return train_idx, val_idx


# Labels of the full training set drive the stratified split.
train_targets = np.array(raw_train.targets)
train_idx, val_idx = get_or_create_split_indices(
    targets=train_targets,
    val_ratio=VAL_RATIO,
    seed=SEED,
    split_file=SPLIT_FILE,
)

# Train and validation both read the official training split, but with different
# transforms: augmentation for training, plain normalization for validation.
train_dataset = datasets.FashionMNIST(root=DATA_ROOT, train=True, download=False, transform=train_transform)
val_dataset = datasets.FashionMNIST(root=DATA_ROOT, train=True, download=False, transform=eval_transform)
test_dataset = datasets.FashionMNIST(root=DATA_ROOT, train=False, download=False, transform=eval_transform)

# Restrict each view of the training data to its own index set.
train_subset = Subset(train_dataset, train_idx)
val_subset = Subset(val_dataset, val_idx)

print(f'Train subset size: {len(train_subset):,}')
print(f'Validation subset size: {len(val_subset):,}')
print(f'Test size: {len(test_dataset):,}')


In [None]:
# Loaders are memoized per batch size so repeated requests reuse the same pair.
loader_cache: dict[int, tuple[DataLoader, DataLoader]] = {}


def make_train_val_loaders(batch_size: int) -> tuple[DataLoader, DataLoader]:
    """Return the (train, validation) loaders for ``batch_size``, building them once.

    Args:
        batch_size: Batch size used for both loaders; also the cache key.

    Returns:
        A ``(train_loader, val_loader)`` tuple. Only the training loader shuffles.
    """
    try:
        return loader_cache[batch_size]
    except KeyError:
        pass

    # Settings shared by both loaders.
    shared_kwargs = dict(
        batch_size=batch_size,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        persistent_workers=NUM_WORKERS > 0,
    )
    pair = (
        DataLoader(train_subset, shuffle=True, **shared_kwargs),
        DataLoader(val_subset, shuffle=False, **shared_kwargs),
    )
    loader_cache[batch_size] = pair
    return pair


def make_test_loader(batch_size: int) -> DataLoader:
    """Build a non-shuffling loader over the test dataset.

    Args:
        batch_size: Number of samples per batch.

    Returns:
        A new ``DataLoader`` over ``test_dataset`` (not cached).
    """
    keep_workers_alive = NUM_WORKERS > 0
    loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        persistent_workers=keep_workers_alive,
    )
    return loader

# Smoke test: pull one training batch and print its shapes and dtype.
# This also leaves the batch-size-128 loaders in loader_cache.
example_loader, _ = make_train_val_loaders(batch_size=128)
xb, yb = next(iter(example_loader))
print('Batch X shape:', tuple(xb.shape))
print('Batch y shape:', tuple(yb.shape))
print('Batch dtype:', xb.dtype)


## CNN Architecture (Modern but Easy to Read)

The model uses a compact **MBConv-style** design inspired by EfficientNet/MobileNet ideas:
- depthwise separable convolutions (efficient)
- squeeze-and-excitation blocks (channel attention)
- residual connections when shapes match

This is stronger than a basic CNN while still being understandable for first CNN practice.


In [None]:
class SqueezeExcite(nn.Module):
    """Channel gate: global-average-pool, two 1x1 convs, sigmoid, then rescale the input.

    Args:
        channels: Number of input (and output) channels.
        reduction: Divisor for the bottleneck width, which never drops below 8.
    """

    def __init__(self, channels: int, reduction: int = 4):
        super().__init__()
        bottleneck = channels // reduction
        if bottleneck < 8:
            bottleneck = 8

        self.pool = nn.AdaptiveAvgPool2d(1)
        squeeze = nn.Conv2d(channels, bottleneck, kernel_size=1)
        excite = nn.Conv2d(bottleneck, channels, kernel_size=1)
        self.net = nn.Sequential(squeeze, nn.SiLU(inplace=True), excite, nn.Sigmoid())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Multiply ``x`` by its per-channel gate computed from the pooled features."""
        return x * self.net(self.pool(x))


class MBConvBlock(nn.Module):
    """Inverted-residual block: optional 1x1 expansion, 3x3 depthwise conv, SE gate, 1x1 projection.

    A skip connection is added only when ``stride == 1`` and the input and
    output channel counts are equal.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the projection.
        expand_ratio: Width multiplier for the hidden layers; 1 skips the expansion conv.
        stride: Stride of the depthwise convolution.
        drop_rate: ``Dropout2d`` probability applied after the projection (0 disables it).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        expand_ratio: int = 2,
        stride: int = 1,
        drop_rate: float = 0.0,
    ):
        super().__init__()
        self.use_residual = (stride == 1) and (in_channels == out_channels)

        has_expansion = expand_ratio != 1
        width = int(in_channels * expand_ratio) if has_expansion else in_channels

        if has_expansion:
            self.expand = nn.Sequential(
                nn.Conv2d(in_channels, width, kernel_size=1, bias=False),
                nn.BatchNorm2d(width),
                nn.SiLU(inplace=True),
            )
        else:
            self.expand = nn.Identity()

        # groups == channels makes this a depthwise convolution.
        depthwise_conv = nn.Conv2d(
            width,
            width,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=width,
            bias=False,
        )
        self.depthwise = nn.Sequential(
            depthwise_conv,
            nn.BatchNorm2d(width),
            nn.SiLU(inplace=True),
        )
        self.se = SqueezeExcite(width, reduction=4)
        self.project = nn.Sequential(
            nn.Conv2d(width, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        if drop_rate > 0:
            self.dropout = nn.Dropout2d(drop_rate)
        else:
            self.dropout = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the block and add the input back when the skip connection is enabled."""
        out = self.expand(x)
        out = self.depthwise(out)
        out = self.se(out)
        out = self.project(out)
        out = self.dropout(out)
        if not self.use_residual:
            return out
        return out + x


class FashionEfficientCNN(nn.Module):
    """Compact MBConv classifier for single-channel images.

    Layout: a 3x3 conv stem (32 channels), five MBConv stages, then a 1x1 conv
    head with global average pooling, dropout and a linear classifier.

    Args:
        num_classes: Size of the output logits vector.
        dropout: Dropout probability before the final linear layer.
    """

    def __init__(self, num_classes: int = 10, dropout: float = 0.25):
        super().__init__()
        stem_width = 32
        head_width = 192
        # (in_channels, out_channels, stride) for each MBConv stage, in order.
        stage_specs = [
            (32, 32, 1),
            (32, 48, 2),
            (48, 64, 1),
            (64, 96, 2),
            (96, 128, 1),
        ]

        self.stem = nn.Sequential(
            nn.Conv2d(1, stem_width, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(stem_width),
            nn.SiLU(inplace=True),
        )
        self.features = nn.Sequential(
            *[
                MBConvBlock(c_in, c_out, expand_ratio=2, stride=step)
                for c_in, c_out, step in stage_specs
            ]
        )
        self.head = nn.Sequential(
            nn.Conv2d(stage_specs[-1][1], head_width, kernel_size=1, bias=False),
            nn.BatchNorm2d(head_width),
            nn.SiLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(p=dropout),
            nn.Linear(head_width, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of images."""
        return self.head(self.features(self.stem(x)))


def count_trainable_parameters(model: nn.Module) -> int:
    """Return the total element count of all parameters that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


# Build a throwaway model just to report its size, then drop the reference.
tmp_model = FashionEfficientCNN(num_classes=NUM_CLASSES, dropout=0.25).to(DEVICE)
print(f'Trainable parameters: {count_trainable_parameters(tmp_model):,}')
del tmp_model


In [None]:
def compute_epoch_metrics(y_true: list[int], y_pred: list[int]) -> dict[str, float]:
    """Compute accuracy and macro-averaged F1 for one epoch of predictions.

    Args:
        y_true: Ground-truth class ids.
        y_pred: Predicted class ids, aligned with ``y_true``.

    Returns:
        Dict with keys ``'accuracy'`` and ``'macro_f1'``.
    """
    accuracy = accuracy_score(y_true, y_pred)
    macro_f1 = f1_score(y_true, y_pred, average='macro')
    return dict(accuracy=accuracy, macro_f1=macro_f1)


def train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    scaler: torch.cuda.amp.GradScaler,
) -> dict[str, float]:
    """Train ``model`` for one pass over ``loader`` and return epoch metrics.

    Args:
        model: Network to optimize; switched to training mode.
        loader: Batches of ``(images, labels)``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss function; assumed to return the mean loss of the batch.
        scaler: Gradient scaler used on the mixed-precision path.

    Returns:
        Dict with ``'accuracy'``, ``'macro_f1'`` and ``'loss'``. The loss is the
        per-sample average over the epoch (NaN if the loader yields nothing).
    """
    model.train()
    loss_sum = 0.0
    samples_seen = 0
    all_true, all_pred = [], []

    for images, labels in loader:
        images = images.to(DEVICE, non_blocking=True)
        labels = labels.to(DEVICE, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        if AMP_ENABLED:
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                logits = model(images)
                loss = criterion(logits, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

        # Weight each batch mean by its size so a smaller final batch does not
        # count as much as a full one in the epoch average.
        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        samples_seen += batch_size
        all_true.extend(labels.detach().cpu().tolist())
        all_pred.extend(torch.argmax(logits.detach(), dim=1).cpu().tolist())

    metrics = compute_epoch_metrics(all_true, all_pred)
    metrics['loss'] = loss_sum / samples_seen if samples_seen else float('nan')
    return metrics


@torch.no_grad()
def evaluate_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
) -> tuple[dict[str, float], list[int], list[int]]:
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Batches of ``(images, labels)``.
        criterion: Loss function; assumed to return the mean loss of the batch.

    Returns:
        ``(metrics, y_true, y_pred)`` where ``metrics`` holds ``'accuracy'``,
        ``'macro_f1'`` and ``'loss'``. The loss is the per-sample average over
        the loader (NaN if the loader yields nothing).
    """
    model.eval()
    loss_sum = 0.0
    samples_seen = 0
    all_true, all_pred = [], []

    for images, labels in loader:
        images = images.to(DEVICE, non_blocking=True)
        labels = labels.to(DEVICE, non_blocking=True)

        if AMP_ENABLED:
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                logits = model(images)
                loss = criterion(logits, labels)
        else:
            logits = model(images)
            loss = criterion(logits, labels)

        # Weight each batch mean by its size so a smaller final batch does not
        # skew the reported dataset-level loss.
        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        samples_seen += batch_size
        all_true.extend(labels.detach().cpu().tolist())
        all_pred.extend(torch.argmax(logits.detach(), dim=1).cpu().tolist())

    metrics = compute_epoch_metrics(all_true, all_pred)
    metrics['loss'] = loss_sum / samples_seen if samples_seen else float('nan')
    return metrics, all_true, all_pred


def params_to_trial_id(params: dict[str, Any]) -> str:
    """Derive a 12-character hex id from ``params``.

    The dict is serialized to JSON with sorted keys, so the id does not depend
    on key insertion order.
    """
    canonical = json.dumps(params, sort_keys=True).encode('utf-8')
    digest = hashlib.sha1(canonical).hexdigest()
    return digest[:12]


def save_json(path: Path, payload: dict[str, Any]) -> None:
    """Write ``payload`` to ``path`` as indented JSON, replacing the file atomically.

    The text is written to a sibling ``<name>.tmp`` file and then renamed over
    ``path``, so an interruption mid-write cannot leave a truncated JSON file
    that a later resume would fail to parse.

    Args:
        path: Destination file; missing parent directories are created.
        payload: JSON-serializable mapping.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    text = json.dumps(payload, indent=2)
    tmp_path = path.with_name(path.name + '.tmp')
    tmp_path.write_text(text)
    tmp_path.replace(path)


def _atomic_torch_save(payload: dict[str, Any], path: Path) -> None:
    """Serialize ``payload`` with ``torch.save`` to a temp file, then rename it over ``path``."""
    tmp_path = path.with_name(path.name + '.tmp')
    torch.save(payload, tmp_path)
    tmp_path.replace(path)


def _discard_stale_checkpoints(checkpoint_dir: Path, params: dict[str, Any]) -> None:
    """Delete resume artifacts in ``checkpoint_dir`` that were produced with other params."""
    params_path = checkpoint_dir / 'params.json'
    if not params_path.exists():
        return
    try:
        stored_params = json.loads(params_path.read_text())
    except json.JSONDecodeError:
        # Unreadable metadata: keep the existing files rather than guess.
        return
    # Round-trip through JSON so both sides use the same value representation.
    if stored_params == json.loads(json.dumps(params)):
        return
    print(f'Checkpoint params changed; discarding stale artifacts in {checkpoint_dir}')
    for name in ('last.pt', 'best.pt', 'history.csv', 'status.json'):
        (checkpoint_dir / name).unlink(missing_ok=True)


def run_training(
    params: dict[str, Any],
    train_loader: DataLoader,
    val_loader: DataLoader,
    max_epochs: int,
    checkpoint_dir: Path | None = None,
    trial: optuna.trial.Trial | None = None,
    run_name: str = 'run',
) -> dict[str, Any]:
    """Train a ``FashionEfficientCNN`` with early stopping and optional resume.

    With a ``checkpoint_dir`` the function writes ``params.json``, ``last.pt``
    (full training state), ``best.pt`` (best model weights), ``history.csv`` and
    ``status.json`` after every epoch. On a later call it:

    * returns the stored result immediately if ``status.json`` says the run
      finished and ``best.pt`` exists;
    * otherwise continues from ``last.pt`` if present.

    Artifacts written for different hyperparameters (a reused directory whose
    ``params.json`` no longer matches ``params``) are deleted first so a run is
    never resumed from, or answered with, a model trained under other settings.

    Args:
        params: Hyperparameters; reads ``dropout``, ``label_smoothing``, ``lr``
            and ``weight_decay``.
        train_loader: Loader for the training split.
        val_loader: Loader for the validation split.
        max_epochs: Upper bound on epochs; also the cosine schedule length.
        checkpoint_dir: Directory for resume artifacts; ``None`` disables them.
        trial: Optuna trial that receives per-epoch validation Macro-F1 and may
            request pruning.
        run_name: Label for the run; currently unused.

    Returns:
        Dict with ``'model'`` (best weights loaded), ``'history_df'``,
        ``'best_val_macro_f1'`` and ``'epochs_ran'``.

    Raises:
        optuna.exceptions.TrialPruned: If ``trial`` reports that it should be pruned.
    """
    model = FashionEfficientCNN(num_classes=NUM_CLASSES, dropout=params['dropout']).to(DEVICE)
    criterion = nn.CrossEntropyLoss(label_smoothing=params['label_smoothing'])
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=params['lr'],
        weight_decay=params['weight_decay'],
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=max_epochs,
        eta_min=1e-6,
    )
    scaler = torch.cuda.amp.GradScaler(enabled=AMP_ENABLED)

    start_epoch = 0
    best_val_macro_f1 = -1.0
    best_state_dict = None
    epochs_without_improvement = 0
    history_rows: list[dict[str, float]] = []
    resumed = False

    last_ckpt_path = None
    best_ckpt_path = None
    history_csv_path = None
    status_json_path = None

    if checkpoint_dir is not None:
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        # Must run before params.json is overwritten below.
        _discard_stale_checkpoints(checkpoint_dir, params)
        save_json(checkpoint_dir / 'params.json', params)

        last_ckpt_path = checkpoint_dir / 'last.pt'
        best_ckpt_path = checkpoint_dir / 'best.pt'
        history_csv_path = checkpoint_dir / 'history.csv'
        status_json_path = checkpoint_dir / 'status.json'

        if status_json_path.exists():
            try:
                status_data = json.loads(status_json_path.read_text())
            except json.JSONDecodeError:
                # A truncated status file is treated as "not finished".
                status_data = {}
            if status_data.get('finished') and best_ckpt_path.exists():
                best_payload = torch.load(best_ckpt_path, map_location=DEVICE)
                model.load_state_dict(best_payload['model_state_dict'])
                history_df = pd.read_csv(history_csv_path) if history_csv_path.exists() else pd.DataFrame()
                return {
                    'model': model,
                    'history_df': history_df,
                    'best_val_macro_f1': status_data['best_val_macro_f1'],
                    'epochs_ran': int(status_data['epochs_ran']),
                }

        if last_ckpt_path.exists():
            payload = torch.load(last_ckpt_path, map_location=DEVICE)
            model.load_state_dict(payload['model_state_dict'])
            optimizer.load_state_dict(payload['optimizer_state_dict'])
            scheduler.load_state_dict(payload['scheduler_state_dict'])
            if payload.get('scaler_state_dict') is not None:
                scaler.load_state_dict(payload['scaler_state_dict'])
            start_epoch = int(payload['epoch']) + 1
            best_val_macro_f1 = float(payload['best_val_macro_f1'])
            epochs_without_improvement = int(payload['epochs_without_improvement'])
            history_rows = payload.get('history_rows', [])
            resumed = True

    # If the interrupted run had already used up its patience (stopped between the
    # last checkpoint and the "finished" status), do not train any further.
    # When start_epoch >= max_epochs the range is simply empty.
    if resumed and epochs_without_improvement >= EARLY_STOPPING_PATIENCE:
        epoch_range = range(0)
    else:
        epoch_range = range(start_epoch, max_epochs)

    for epoch in epoch_range:
        train_metrics = train_one_epoch(model, train_loader, optimizer, criterion, scaler)
        val_metrics, _, _ = evaluate_one_epoch(model, val_loader, criterion)
        scheduler.step()

        # Cast to built-in floats so the checkpoint payload holds no NumPy scalars.
        row = {
            'epoch': epoch,
            'train_loss': float(train_metrics['loss']),
            'train_accuracy': float(train_metrics['accuracy']),
            'train_macro_f1': float(train_metrics['macro_f1']),
            'val_loss': float(val_metrics['loss']),
            'val_accuracy': float(val_metrics['accuracy']),
            'val_macro_f1': float(val_metrics['macro_f1']),
            # Read after scheduler.step(): this is the rate for the next epoch.
            'lr': float(optimizer.param_groups[0]['lr']),
        }
        history_rows.append(row)

        val_score = float(val_metrics['macro_f1'])
        improved = val_score > (best_val_macro_f1 + 1e-4)
        if improved:
            best_val_macro_f1 = val_score
            best_state_dict = copy.deepcopy(model.state_dict())
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if checkpoint_dir is not None:
            # Write best.pt before last.pt: last.pt records the best score, so the
            # matching weights must already be on disk if a crash happens in between.
            if improved:
                _atomic_torch_save({'model_state_dict': model.state_dict()}, best_ckpt_path)

            payload = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict() if AMP_ENABLED else None,
                'best_val_macro_f1': best_val_macro_f1,
                'epochs_without_improvement': epochs_without_improvement,
                'history_rows': history_rows,
            }
            _atomic_torch_save(payload, last_ckpt_path)

            pd.DataFrame(history_rows).to_csv(history_csv_path, index=False)
            save_json(
                status_json_path,
                {
                    'finished': False,
                    'best_val_macro_f1': best_val_macro_f1,
                    'epochs_ran': len(history_rows),
                },
            )

        if trial is not None:
            trial.report(val_score, step=epoch)
            if trial.should_prune():
                raise optuna.exceptions.TrialPruned(f'Pruned at epoch {epoch}')

        if epochs_without_improvement >= EARLY_STOPPING_PATIENCE:
            break

    # Restore the best weights: from memory if this session improved them,
    # otherwise from best.pt written by an earlier session.
    if best_state_dict is not None:
        model.load_state_dict(best_state_dict)
    elif best_ckpt_path is not None and best_ckpt_path.exists():
        best_payload = torch.load(best_ckpt_path, map_location=DEVICE)
        model.load_state_dict(best_payload['model_state_dict'])

    history_df = pd.DataFrame(history_rows)

    if checkpoint_dir is not None and status_json_path is not None:
        save_json(
            status_json_path,
            {
                'finished': True,
                'best_val_macro_f1': float(best_val_macro_f1),
                'epochs_ran': int(len(history_rows)),
            },
        )

    return {
        'model': model,
        'history_df': history_df,
        'best_val_macro_f1': float(best_val_macro_f1),
        'epochs_ran': int(len(history_rows)),
    }


## Hyperparameter Tuning with Optuna

We tune learning and regularization hyperparameters:
- learning rate
- weight decay
- dropout
- label smoothing
- batch size

Optimization target: **validation Macro-F1** (good class-balanced metric).


In [None]:
def suggest_params(trial: optuna.trial.Trial) -> dict[str, Any]:
    """Sample one hyperparameter configuration from ``trial``.

    Returns:
        Dict with ``lr`` and ``weight_decay`` (log-uniform floats), ``dropout``
        and ``label_smoothing`` (uniform floats) and ``batch_size`` (categorical).
    """
    # Suggestions are requested in the same order as the returned keys.
    lr = trial.suggest_float('lr', 1e-4, 5e-3, log=True)
    weight_decay = trial.suggest_float('weight_decay', 1e-6, 5e-3, log=True)
    dropout = trial.suggest_float('dropout', 0.10, 0.45)
    label_smoothing = trial.suggest_float('label_smoothing', 0.0, 0.12)
    batch_size = trial.suggest_categorical('batch_size', [64, 128, 256])

    return {
        'lr': lr,
        'weight_decay': weight_decay,
        'dropout': dropout,
        'label_smoothing': label_smoothing,
        'batch_size': batch_size,
    }


def objective(trial: optuna.trial.Trial) -> float:
    """Optuna objective: train with sampled hyperparameters, return best validation Macro-F1.

    The checkpoint directory is derived from a hash of the sampled parameters,
    so a trial that gets the same parameters again reuses the same directory.

    Args:
        trial: Optuna trial used for sampling, reporting and pruning.

    Returns:
        Best validation Macro-F1 reached during training.
    """
    params = suggest_params(trial)
    train_loader, val_loader = make_train_val_loaders(batch_size=params['batch_size'])

    trial_id = params_to_trial_id(params)
    checkpoint_dir = ARTIFACT_DIR / 'trials' / trial_id if ENABLE_CHECKPOINTING else None

    # Set traceability attributes before training: run_training raises TrialPruned
    # for pruned trials, which would otherwise end up without them.
    trial.set_user_attr('trial_id', trial_id)
    if checkpoint_dir is not None:
        trial.set_user_attr('checkpoint_dir', str(checkpoint_dir))

    result = run_training(
        params=params,
        train_loader=train_loader,
        val_loader=val_loader,
        max_epochs=MAX_EPOCHS,
        checkpoint_dir=checkpoint_dir,
        trial=trial,
        run_name=f'trial_{trial.number:03d}',
    )

    trial.set_user_attr('epochs_ran', result['epochs_ran'])

    return result['best_val_macro_f1']


In [None]:
sampler = optuna.samplers.TPESampler(seed=SEED)
pruner = optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=3)
# SQLite-backed storage with heartbeats; failed trials are re-queued up to two times.
storage = optuna.storages.RDBStorage(
    url=STUDY_STORAGE,
    heartbeat_interval=60,
    grace_period=180,
    failed_trial_callback=optuna.storages.RetryFailedTrialCallback(max_retry=2),
)

# load_if_exists=True re-attaches to the stored study when the notebook is re-run.
study = optuna.create_study(
    study_name=STUDY_NAME,
    storage=storage,
    load_if_exists=True,
    direction='maximize',
    sampler=sampler,
    pruner=pruner,
)

# N_TRIALS is the total budget for the study. Only run what is still missing, so
# re-running the notebook after an interruption does not add another N_TRIALS.
finished_states = (optuna.trial.TrialState.COMPLETE, optuna.trial.TrialState.PRUNED)
n_finished = len(study.get_trials(deepcopy=False, states=finished_states))
n_remaining = max(0, N_TRIALS - n_finished)

start_time = time.perf_counter()
if n_remaining > 0:
    study.optimize(objective, n_trials=n_remaining, timeout=OPTUNA_TIMEOUT, show_progress_bar=True)
search_minutes = (time.perf_counter() - start_time) / 60

print(f'Finished/loaded {len(study.trials)} trials in {search_minutes:.2f} minutes this session.')
print('Best validation Macro-F1:', round(study.best_value, 4))
print('Best params:', study.best_params)


In [None]:
# Tabulate all trials and show the ten with the highest objective value.
# `display` is not imported in this file; presumably the Jupyter/IPython built-in.
trials_df = study.trials_dataframe()
display(trials_df.sort_values('value', ascending=False).head(10))

plt.figure(figsize=(8, 4))
# Only completed trials are plotted.
completed = trials_df[trials_df['state'] == 'COMPLETE']
plt.plot(completed['number'], completed['value'], marker='o')
plt.title('Validation Macro-F1 by completed trial')
plt.xlabel('Trial number')
plt.ylabel('Macro-F1')
plt.grid(True, alpha=0.3)
plt.show()


## Final Training and Out-of-Sample Test Metrics

After tuning, we train one final model using the best hyperparameters and evaluate on the held-out test set.

Reported metrics:
- test loss
- test accuracy
- test Macro-F1
- per-class classification report
- confusion matrix


In [None]:
# Retrain with the best trial's hyperparameters, using the longer final epoch budget.
best_params = study.best_trial.params.copy()
train_loader, val_loader = make_train_val_loaders(batch_size=best_params['batch_size'])

# The final run checkpoints into a fixed 'final_model' folder (not keyed by params).
final_checkpoint_dir = None
if ENABLE_CHECKPOINTING:
    final_checkpoint_dir = ARTIFACT_DIR / 'final_model'

start_time = time.perf_counter()
# trial=None: no Optuna reporting or pruning during the final run.
final_result = run_training(
    params=best_params,
    train_loader=train_loader,
    val_loader=val_loader,
    max_epochs=FINAL_MAX_EPOCHS,
    checkpoint_dir=final_checkpoint_dir,
    trial=None,
    run_name='final_model',
)
final_minutes = (time.perf_counter() - start_time) / 60

best_model = final_result['model']
print(f"Final training finished in {final_minutes:.2f} minutes")
print('Best validation Macro-F1:', round(final_result['best_val_macro_f1'], 4))
print('Epochs used:', final_result['epochs_ran'])


In [None]:
# Plain cross-entropy (no label smoothing) for the reported test loss.
criterion_eval = nn.CrossEntropyLoss()
test_loader = make_test_loader(batch_size=best_params['batch_size'])
test_metrics, y_true_test, y_pred_test = evaluate_one_epoch(best_model, test_loader, criterion_eval)

print('Test metrics (out-of-sample):')
print(f"  Loss:      {test_metrics['loss']:.4f}")
print(f"  Accuracy:  {test_metrics['accuracy']:.4f}")
print(f"  Macro-F1:  {test_metrics['macro_f1']:.4f}")

# The leading newline is written as an escape; a literal line break inside a
# single-quoted string is a syntax error.
print('\nPer-class report:')
print(classification_report(y_true_test, y_pred_test, target_names=CLASS_NAMES, digits=4))


In [None]:
# Confusion matrix over the test predictions, passed as (true, predicted).
cm = confusion_matrix(y_true_test, y_pred_test)

plt.figure(figsize=(10, 8))
# fmt='d' renders each cell as an integer count.
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=CLASS_NAMES,
    yticklabels=CLASS_NAMES,
)
plt.title('Test confusion matrix')
plt.xlabel('Predicted label')
plt.ylabel('True label')
plt.tight_layout()
plt.show()


In [None]:
# Persist the study outcome and test metrics next to the other run artifacts.
if ENABLE_CHECKPOINTING:
    summary = {
        'study_name': STUDY_NAME,
        'best_trial_number': int(study.best_trial.number),
        'best_validation_macro_f1': float(study.best_value),
        'best_params': study.best_params,
        'test_metrics': {
            'loss': float(test_metrics['loss']),
            'accuracy': float(test_metrics['accuracy']),
            'macro_f1': float(test_metrics['macro_f1']),
        },
    }
    summary_path = ARTIFACT_DIR / 'final_summary.json'
    # Use the shared JSON helper, like every other artifact written by this notebook.
    save_json(summary_path, summary)
    print('Saved summary to', summary_path.resolve())
