# 04 - Adversarial Training

**Goal**: Train models on adversarial examples to improve robustness.

**Research Questions**:
- How much robustness does adversarial training provide?
- What is the clean accuracy trade-off?
- Does PGD adversarial training yield more robust models than FGSM adversarial training?

In [None]:
# Colab setup: make the project's `src` package importable.
# On Colab, clone the repository into /content and put the clone on sys.path;
# elsewhere, add the parent directory of the current working directory instead.
import sys
import os

if 'google.colab' in sys.modules:
    %cd /content
    # `2>/dev/null || true` hides git's error output and ignores its exit
    # status (e.g. when the clone directory already exists on a re-run).
    !git clone https://github.com/cdm34/adversarial-robustness.git 2>/dev/null || true
    %cd adversarial-robustness
    sys.path.insert(0, '/content/adversarial-robustness')
else:
    # NOTE(review): assumes the notebook runs from a subdirectory of the repo
    # root (e.g. notebooks/) — confirm.
    sys.path.insert(0, os.path.abspath('..'))

In [None]:
# Third-party and project imports. `src` is resolved through the sys.path
# entry added by the setup cell.
import torch
import matplotlib.pyplot as plt
import numpy as np

from src import (
    FashionMNISTNet,
    DataConfig, get_fashion_mnist_datasets, split_train_val, make_loaders,
    AttackConfig, fgsm, pgd_linf,
    AdvTrainConfig, adversarial_train_fgsm, adversarial_train_pgd, mixed_adversarial_training,
    accuracy, confidence_stats,
    get_device, set_seed,
    plot_training_curves, plot_robustness_comparison, plot_epsilon_vs_accuracy, save_figure,
    FASHION_MNIST_CLASSES,
)

print(f"PyTorch version: {torch.__version__}")

## 1. Setup

In [None]:
# Reproducibility first, then pick the compute device.
set_seed(42)
device = get_device()
print(f"Using device: {device}")

# Fashion-MNIST: hold out a validation split from the training set, then
# build the train/val/test loaders.
train_ds, test_ds = get_fashion_mnist_datasets()
data_cfg = DataConfig(batch_size=128, val_ratio=0.1)
train_subset, val_subset = split_train_val(
    train_ds,
    data_cfg.val_ratio,
    data_cfg.seed,
)
loaders = make_loaders(train_subset, val_subset, test_ds, data_cfg, device)
train_loader, val_loader, test_loader = loaders

## 2. Load Baseline Model

In [None]:
from src import TrainConfig, fit

# Reuse a saved baseline CNN when one exists; otherwise train one from scratch.
# The freshly trained weights are not written to disk by this cell.
baseline_model = FashionMNISTNet().to(device)

if not os.path.exists('checkpoints/baseline_cnn.pt'):
    print("Training baseline model...")
    fit(baseline_model, train_loader, val_loader, device, TrainConfig(epochs=10))
else:
    checkpoint = torch.load('checkpoints/baseline_cnn.pt', map_location=device)
    baseline_model.load_state_dict(checkpoint['model_state_dict'])
    print("Loaded baseline model")

baseline_clean = accuracy(baseline_model, test_loader, device)
print(f"Baseline clean accuracy: {baseline_clean:.2f}%")

## 3. FGSM Adversarial Training

In [None]:
# Section 3: adversarial training on FGSM examples at ε=0.1.
# `adv_cfg` is also reused by the PGD and mixed-training cells.
fgsm_adv_model = FashionMNISTNet().to(device)

adv_cfg = AdvTrainConfig(epochs=10, lr=1e-3, attack_eps=0.1)

if not os.path.exists('checkpoints/fgsm_adv_trained.pt'):
    print("Training with FGSM adversarial examples (ε=0.1)...")
    print("This may take a few minutes...")
    fgsm_result = adversarial_train_fgsm(
        fgsm_adv_model, train_loader, val_loader, device, adv_cfg
    )

    # Persist the weights so later runs can skip training.
    os.makedirs('checkpoints', exist_ok=True)
    torch.save(
        {'model_state_dict': fgsm_adv_model.state_dict()},
        'checkpoints/fgsm_adv_trained.pt',
    )
    print(f"Best val acc: {fgsm_result['best_val_acc']:.2f}%")

    # Training curves are only available when the model was trained here.
    fig = plot_training_curves(fgsm_result['history'], title='FGSM Adversarial Training')
    save_figure(fig, 'fgsm_adv_training_curves')
    plt.show()
else:
    checkpoint = torch.load('checkpoints/fgsm_adv_trained.pt', map_location=device)
    fgsm_adv_model.load_state_dict(checkpoint['model_state_dict'])
    print("Loaded FGSM adversarially trained model")

## 4. PGD Adversarial Training

In [None]:
# Section 4: adversarial training on PGD examples (a stronger attack than FGSM).
# Reuses `adv_cfg` (ε=0.1) from the FGSM cell, so that cell must run first.
pgd_adv_model = FashionMNISTNet().to(device)

if os.path.exists('checkpoints/pgd_adv_trained.pt'):
    checkpoint = torch.load('checkpoints/pgd_adv_trained.pt', map_location=device)
    pgd_adv_model.load_state_dict(checkpoint['model_state_dict'])
    print("Loaded PGD adversarially trained model")
else:
    print("Training with PGD adversarial examples (ε=0.1)...")
    print("This will take longer than FGSM training...")
    pgd_result = adversarial_train_pgd(pgd_adv_model, train_loader, val_loader, device, adv_cfg)

    # Create the checkpoint directory here too, so this cell does not rely on
    # an earlier cell having created it.
    os.makedirs('checkpoints', exist_ok=True)
    torch.save({'model_state_dict': pgd_adv_model.state_dict()}, 'checkpoints/pgd_adv_trained.pt')
    print(f"Best val acc: {pgd_result['best_val_acc']:.2f}%")

    # Training curves are only available when the model was trained here.
    fig = plot_training_curves(pgd_result['history'], title='PGD Adversarial Training')
    save_figure(fig, 'pgd_adv_training_curves')
    plt.show()

## 5. Mixed Adversarial Training

In [None]:
# Section 5: mixed training on 50% clean / 50% adversarial examples
# (mix_ratio=0.5), intended to keep more clean accuracy.
# Reuses `adv_cfg` from the FGSM cell, so that cell must run first.
mixed_model = FashionMNISTNet().to(device)

if os.path.exists('checkpoints/mixed_adv_trained.pt'):
    checkpoint = torch.load('checkpoints/mixed_adv_trained.pt', map_location=device)
    mixed_model.load_state_dict(checkpoint['model_state_dict'])
    print("Loaded mixed adversarially trained model")
else:
    print("Training with mixed adversarial examples (50/50)...")
    mixed_result = mixed_adversarial_training(
        mixed_model, train_loader, val_loader, device, adv_cfg, mix_ratio=0.5
    )

    # Create the checkpoint directory here too, so this cell does not rely on
    # an earlier cell having created it.
    os.makedirs('checkpoints', exist_ok=True)
    torch.save({'model_state_dict': mixed_model.state_dict()}, 'checkpoints/mixed_adv_trained.pt')
    print(f"Best val acc: {mixed_result['best_val_acc']:.2f}%")

## 6. Evaluate All Models

In [None]:
def evaluate_robust(model, loader, eps, device, pgd_steps=10, pgd_step_size=None):
    """Evaluate robustness against FGSM and PGD.

    For every batch in ``loader``, craft FGSM and L∞-PGD adversarial examples
    and count how many the model still classifies correctly.

    Args:
        model: Classifier whose output is reduced with ``argmax(1)`` to get
            predictions (presumably logits of shape (batch, classes) — confirm).
        loader: Iterable of ``(inputs, labels)`` batches.
        eps: Perturbation budget used by both attacks.
        device: Device each batch is moved to before attacking.
        pgd_steps: Number of PGD iterations (default 10).
        pgd_step_size: PGD step size; ``None`` (default) means ``eps / 4``.

    Returns:
        Tuple ``(fgsm_acc, pgd_acc)`` of robust accuracies in percent.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    if pgd_step_size is None:
        pgd_step_size = eps / 4
    fgsm_cfg = AttackConfig(eps=eps)
    pgd_cfg = AttackConfig(eps=eps, steps=pgd_steps, step_size=pgd_step_size)

    fgsm_correct = 0
    pgd_correct = 0
    total = 0

    # Attack and score in eval mode so dropout / batch-norm layers (if any)
    # behave deterministically. Restore the caller's mode afterwards.
    was_training = model.training
    model.eval()
    try:
        for x, y in loader:
            x, y = x.to(device), y.to(device)

            # The attacks use input gradients, so they run outside no_grad.
            x_fgsm = fgsm(model, x, y, fgsm_cfg)
            x_pgd = pgd_linf(model, x, y, pgd_cfg)

            with torch.no_grad():
                fgsm_correct += (model(x_fgsm).argmax(1) == y).sum().item()
                pgd_correct += (model(x_pgd).argmax(1) == y).sum().item()
            total += y.size(0)
    finally:
        model.train(was_training)

    if total == 0:
        raise ValueError("evaluate_robust: loader yielded no samples")
    return 100.0 * fgsm_correct / total, 100.0 * pgd_correct / total

In [None]:
# Section 6: clean, FGSM-robust and PGD-robust test accuracy for every model,
# all at a single perturbation budget.
eps = 0.1

models = {
    'Baseline': baseline_model,
    'FGSM Adv Train': fgsm_adv_model,
    'PGD Adv Train': pgd_adv_model,
    'Mixed (50/50)': mixed_model,
}

rule_wide = "=" * 70
rule_thin = "-" * 70

results = {}
print(f"\nEvaluation (ε={eps}):")
print(rule_wide)
print(f"{'Model':<18} {'Clean Acc':>12} {'FGSM Robust':>12} {'PGD Robust':>12}")
print(rule_thin)

for name, model in models.items():
    row = {'clean': accuracy(model, test_loader, device)}
    row['fgsm'], row['pgd'] = evaluate_robust(model, test_loader, eps, device)
    results[name] = row
    print(f"{name:<18} {row['clean']:>11.2f}% {row['fgsm']:>11.2f}% {row['pgd']:>11.2f}%")

print(rule_wide)

In [None]:
# Grouped bar chart: one group per model, one bar per evaluation setting.
model_names = list(results.keys())
x = np.arange(len(model_names))
width = 0.25

fig, ax = plt.subplots(figsize=(12, 6))

# (results key, legend label, colour, horizontal offset within the group)
bar_specs = [
    ('clean', 'Clean', 'steelblue', -width),
    ('fgsm', 'FGSM', 'coral', 0.0),
    ('pgd', 'PGD', 'seagreen', width),
]
bar_groups = []
for metric, legend_label, bar_color, offset in bar_specs:
    heights = [results[n][metric] for n in model_names]
    bar_groups.append(
        ax.bar(x + offset, heights, width, label=legend_label, color=bar_color)
    )

ax.set_xlabel('Model', fontsize=12)
ax.set_ylabel('Accuracy (%)', fontsize=12)
ax.set_title(f'Adversarial Training Comparison (ε={eps})', fontsize=14)
ax.set_xticks(x)
ax.set_xticklabels(model_names, rotation=15, ha='right')
ax.legend(loc='best')
ax.set_ylim(0, 100)
ax.grid(True, alpha=0.3, axis='y')

# Print each bar's value just above its top edge.
for group in bar_groups:
    for rect in group:
        top = rect.get_height()
        centre = rect.get_x() + rect.get_width() / 2
        ax.annotate(f'{top:.1f}%', xy=(centre, top),
                    xytext=(0, 3), textcoords='offset points',
                    ha='center', va='bottom', fontsize=8)

plt.tight_layout()
save_figure(fig, 'adversarial_training_comparison')
plt.show()

## 7. Robustness vs. Attack Strength

In [None]:
# Section 7: sweep ε and record FGSM/PGD accuracy for the baseline and
# PGD-trained models.
# The loop variable is deliberately not named `eps`. The global `eps` (0.1)
# from Section 6 is read by the bar-chart title, and overwriting it here would
# mislabel that chart if its cell were re-run after this one.
epsilons = [0.0, 0.01, 0.03, 0.05, 0.1, 0.15, 0.2]

pgd_trained_fgsm = []
pgd_trained_pgd = []
baseline_fgsm = []
baseline_pgd = []

print("Computing robustness curves...")
for sweep_eps in epsilons:
    if sweep_eps == 0:
        # With ε=0 no perturbation is possible, so both "attack" accuracies
        # equal clean accuracy; skip running the attacks.
        base_clean = accuracy(baseline_model, test_loader, device)
        pgd_clean = accuracy(pgd_adv_model, test_loader, device)
        bf = bp = base_clean
        pf = pp = pgd_clean
    else:
        bf, bp = evaluate_robust(baseline_model, test_loader, sweep_eps, device)
        pf, pp = evaluate_robust(pgd_adv_model, test_loader, sweep_eps, device)
    baseline_fgsm.append(bf)
    baseline_pgd.append(bp)
    pgd_trained_fgsm.append(pf)
    pgd_trained_pgd.append(pp)
    print(f"  ε={sweep_eps:.2f} done")

print("Done!")

In [None]:
# Accuracy-vs-ε curves, one panel per attack, baseline vs. PGD-trained.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

panels = [
    (ax1, 'FGSM Attack', baseline_fgsm, pgd_trained_fgsm),
    (ax2, 'PGD Attack', baseline_pgd, pgd_trained_pgd),
]
for panel, panel_title, base_curve, adv_curve in panels:
    panel.plot(epsilons, base_curve, 'b-o', label='Baseline', markersize=6)
    panel.plot(epsilons, adv_curve, 'g-s', label='PGD Adv Trained', markersize=6)
    panel.set_xlabel('Perturbation (ε)', fontsize=12)
    panel.set_ylabel('Accuracy (%)', fontsize=12)
    panel.set_title(panel_title, fontsize=14)
    panel.legend(loc='best')
    panel.grid(True, alpha=0.3)
    panel.set_ylim(0, 100)

fig.suptitle('Robustness: Baseline vs. PGD Adversarial Training', fontsize=14)
plt.tight_layout()
save_figure(fig, 'robustness_curves_comparison')
plt.show()

## 8. Trade-off Analysis

In [None]:
# Section 8: clean vs. PGD-robust accuracy for each model (Section 6 results).
fig, ax = plt.subplots(figsize=(8, 6))

for name, accs in results.items():
    point = (accs['clean'], accs['pgd'])
    ax.scatter(*point, s=150, label=name)
    ax.annotate(name, point, textcoords='offset points', xytext=(5, 5), fontsize=9)

ax.set_xlabel('Clean Accuracy (%)', fontsize=12)
ax.set_ylabel('PGD Robust Accuracy (%)', fontsize=12)
ax.set_title('Clean vs. Robust Accuracy Trade-off', fontsize=14)
ax.grid(True, alpha=0.3)

# The ideal model would sit at the top-right corner (100% clean, 100% robust).
# That point lies exactly on the axes boundary, so disable clipping to keep
# the whole star marker visible.
ax.plot([100], [100], 'r*', markersize=20, label='Ideal', clip_on=False)

ax.set_xlim(0, 100)
ax.set_ylim(0, 100)
# Draw the legend so the labelled artists, in particular 'Ideal', are identified.
ax.legend(loc='lower left')

plt.tight_layout()
save_figure(fig, 'clean_vs_robust_tradeoff')
plt.show()

## Summary

**Key Findings**:

1. **Adversarial training significantly improves robustness**
   - Typical PGD robust accuracy at ε=0.1 (approximate; see the Section 6 table for the measured values): baseline ~20% vs. PGD-trained ~50–60%+

2. **PGD training is stronger than FGSM training**
   - PGD-trained models are more robust than FGSM-trained models against both FGSM and PGD attacks

3. **Clean accuracy trade-off exists**
   - Adversarially trained models lose ~5–10 percentage points of clean accuracy
   - Mixed training (50% clean / 50% adversarial examples) helps reduce this gap

4. **No free lunch in adversarial robustness**
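   - More robust models require substantially more training-time computation (PGD training is slower than FGSM training)
   - Robustness obtained at the training ε often does not generalize to larger ε (see the curves in Section 7)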

**AI Safety Implications**:
- Standard models are highly vulnerable
- Adversarial training is essential for deployable robust models
- Trade-offs must be carefully considered for safety-critical applications