# 01 - Baseline CNN on FashionMNIST

**Goal**: Train a baseline CNN and evaluate clean accuracy before adversarial attacks.

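This notebook trains the baseline model, measures its clean (unattacked) performance, and saves a checkpoint for the adversarial-attack notebooks that follow.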

In [None]:
# Environment setup: make this repo's `src` package importable.
# - On Colab: clone the repo once into /content and work from inside it.
# - Locally: assume the notebook is opened from a subdirectory of the repo root.
import os
import subprocess
import sys

REPO_URL = 'https://github.com/cdm34/adversarial-robustness.git'

if 'google.colab' in sys.modules:
    # Running in Colab
    repo_dir = '/content/adversarial-robustness'
    if not os.path.isdir(repo_dir):
        # check=True surfaces a failed clone right here instead of as a
        # confusing ImportError from `src` in the next cell.
        subprocess.run(['git', 'clone', REPO_URL, repo_dir], check=True)
    # Work from the repo root so relative paths (e.g. checkpoints/) resolve there.
    os.chdir(repo_dir)
else:
    # Running locally: the repo root is the parent of the working directory.
    repo_dir = os.path.abspath('..')

# Only insert once, so re-running this cell does not stack duplicate entries.
if repo_dir not in sys.path:
    sys.path.insert(0, repo_dir)

In [None]:
# Third-party imports used throughout the notebook.
import torch
import matplotlib.pyplot as plt

# Project code from this repo's `src` package (made importable by the setup cell).
from src import (
    # model
    FashionMNISTNet,
    # data pipeline
    DataConfig, get_fashion_mnist_datasets, split_train_val, make_loaders,
    # training
    TrainConfig, fit,
    # evaluation metrics
    accuracy, per_class_accuracy, confidence_stats,
    # runtime utilities
    get_device, set_seed,
    # plotting helpers
    plot_training_curves, save_figure,
    # label index -> class name lookup
    FASHION_MNIST_CLASSES,
)

# Record the PyTorch version in the notebook output for reproducibility.
print(f"PyTorch version: {torch.__version__}")

## Optional: Google Drive mounting (Colab only)

In [None]:
# Mount Google Drive on Colab so checkpoints can outlive the runtime.
# Locally there is no `google.colab` package, so fall back to a repo-relative
# directory instead of crashing a Restart & Run All.
# Relies on `os` and `sys` imported in the setup cell.
if 'google.colab' in sys.modules:
    from google.colab import drive
    drive.mount('/content/drive')
    CHECKPOINT_DIR = '/content/drive/MyDrive/adversarial-robustness-checkpoints'
else:
    CHECKPOINT_DIR = 'checkpoints'

os.makedirs(CHECKPOINT_DIR, exist_ok=True)

## 1. Setup

In [None]:
# Seed handed to the project's set_seed helper before any stochastic step runs.
SEED = 42
set_seed(SEED)

# Resolve the compute device once; later cells reuse `device`.
device = get_device()
print(f"Using device: {device}")

## 2. Load Data

In [None]:
# Raw FashionMNIST train/test datasets and their sizes.
train_ds, test_ds = get_fashion_mnist_datasets()
for split_name, split_ds in (("Training", train_ds), ("Test", test_ds)):
    print(f"{split_name} samples: {len(split_ds)}")

# Hold out a validation subset of the training data, sized and seeded by data_cfg.
data_cfg = DataConfig(batch_size=128, val_ratio=0.1, num_workers=2)
train_subset, val_subset = split_train_val(
    train_ds,
    data_cfg.val_ratio,
    data_cfg.seed,
)

# One loader per split, all built from the same data_cfg.
train_loader, val_loader, test_loader = make_loaders(
    train_subset, val_subset, test_ds, data_cfg, device
)

print(
    f"Train batches: {len(train_loader)}, "
    f"Val batches: {len(val_loader)}, "
    f"Test batches: {len(test_loader)}"
)

In [None]:
# Show the first ten images of one test batch in a 2x5 grid, titled by label.
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
x_sample, y_sample = next(iter(test_loader))

for idx in range(axes.size):
    panel = axes.flat[idx]
    panel.imshow(x_sample[idx].squeeze(), cmap='gray')
    panel.set_title(FASHION_MNIST_CLASSES[y_sample[idx]])
    panel.axis('off')

plt.suptitle('FashionMNIST Samples', fontsize=14)
plt.tight_layout()
plt.show()

## 3. Train Baseline CNN

In [None]:
# Build the baseline network and move it onto the selected device.
model = FashionMNISTNet()
model = model.to(device)
print(model)

# Count every parameter tensor element.
total_params = sum(map(torch.Tensor.numel, model.parameters()))
print(f"\nTotal parameters: {total_params:,}")

In [None]:
# Baseline recipe: Adam, learning rate 1e-3, 10 epochs.
train_cfg = TrainConfig(optimizer='adam', lr=1e-3, epochs=10)

print("Training baseline CNN...")
result = fit(model, train_loader, val_loader, device, train_cfg)

# fit() reports the best validation accuracy seen during training.
best_val_acc = result['best_val_acc']
print(f"\nBest validation accuracy: {best_val_acc:.2f}%")

In [None]:
# Plot the history returned by fit(), save it to disk, then display it.
history = result['history']
fig = plot_training_curves(history, title='Baseline CNN Training')
save_figure(fig, 'baseline_training_curves')
plt.show()

## 4. Evaluate on Test Set

In [None]:
# Accuracy on the unperturbed test split: the reference point for later attacks.
test_acc = accuracy(model, test_loader, device)
print("Clean Test Accuracy: {:.2f}%".format(test_acc))

In [None]:
# Break accuracy down by class to expose which categories the model confuses.
class_results = per_class_accuracy(model, test_loader, device)

print("\nPer-Class Accuracy:")
for class_name, class_acc in zip(FASHION_MNIST_CLASSES, class_results['per_class_accuracy']):
    print(f"  {class_name:12s}: {class_acc:.1f}%")

worst_name = FASHION_MNIST_CLASSES[class_results['worst_class']]
print(f"\nWorst class: {worst_name}")
best_name = FASHION_MNIST_CLASSES[class_results['best_class']]
print(f"Best class:  {best_name}")

In [None]:
# Summary statistics of the model's prediction confidence on the test split.
conf_stats = confidence_stats(model, test_loader, device)

print("\nConfidence Statistics:")
for label, stat_key in (
    ("Mean confidence: ", 'mean_confidence'),
    ("Std confidence:  ", 'std_confidence'),
    ("Min confidence:  ", 'min_confidence'),
    ("Max confidence:  ", 'max_confidence'),
):
    print(f"  {label}{conf_stats[stat_key]:.3f}")

In [None]:
# Per-class accuracy bar chart, with the overall clean test accuracy drawn as a
# dashed reference line; the figure is also written to disk via save_figure.
class_positions = list(range(len(FASHION_MNIST_CLASSES)))

fig, ax = plt.subplots(figsize=(10, 5))
ax.bar(class_positions, class_results['per_class_accuracy'], color='steelblue')
ax.axhline(y=test_acc, color='red', linestyle='--', label=f'Overall: {test_acc:.1f}%')
ax.set_xlabel('Class')
ax.set_ylabel('Accuracy (%)')
ax.set_title('Per-Class Accuracy (Baseline CNN)')
# Pin tick locations before labelling them. Calling set_xticklabels alone on
# the auto-located categorical axis triggers matplotlib's warning that tick
# labels should only be set after set_xticks / with a FixedLocator.
ax.set_xticks(class_positions)
ax.set_xticklabels(FASHION_MNIST_CLASSES, rotation=45, ha='right')
ax.legend()
ax.set_ylim(0, 100)
plt.tight_layout()
save_figure(fig, 'baseline_per_class_accuracy')
plt.show()

## 5. Save Model Checkpoint

In [None]:
# Save the trained weights plus run metadata for use in subsequent notebooks.
# The repo-relative path below stays the canonical location. If the optional
# Drive cell defined CHECKPOINT_DIR, an identical copy is also written there so
# the checkpoint survives a Colab runtime reset (previously CHECKPOINT_DIR was
# created but never used).
# NOTE(review): 'train_config' stores the TrainConfig object itself, so loading
# with torch.load(..., weights_only=True) (the default in recent PyTorch) may
# reject it; confirm how downstream notebooks load this file.
LOCAL_CHECKPOINT_PATH = 'checkpoints/baseline_cnn.pt'

checkpoint = {
    'model_state_dict': model.state_dict(),
    'test_accuracy': test_acc,
    'train_config': train_cfg,
    'history': result['history'],
}

checkpoint_paths = [LOCAL_CHECKPOINT_PATH]
# globals().get keeps this cell runnable when the Drive cell was skipped.
drive_checkpoint_dir = globals().get('CHECKPOINT_DIR')
if drive_checkpoint_dir and (
    os.path.abspath(drive_checkpoint_dir)
    != os.path.abspath(os.path.dirname(LOCAL_CHECKPOINT_PATH))
):
    checkpoint_paths.append(os.path.join(drive_checkpoint_dir, 'baseline_cnn.pt'))

for checkpoint_path in checkpoint_paths:
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    torch.save(checkpoint, checkpoint_path)
    print(f"Model saved to {checkpoint_path}")

## Summary

- Trained baseline CNN on FashionMNIST
- Clean test accuracy: **TODO** (fill in from the Section 4 output after a full Restart & Run All)
- Identified difficult classes (e.g., Shirt vs. T-shirt/top)
- The model is highly confident on test samples (see the confidence statistics in Section 4)

**Next**: Evaluate vulnerability to adversarial attacks (FGSM, PGD)