In [None]:
#Implementation of a maxout-like ConvNet with FGSM adversarial training on MNIST
#Requirements: torch, torchvision, numpy, matplotlib, pandas

import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import pandas as pd
from datetime import datetime


#Basic config

#seed every RNG used below (python, numpy, torch) with the same value
seed = 1234
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print("Using device:", device)

#Hyperparameters
batch_size = 128
test_batch_size = 1000
epochs = 18
lr = 0.01
momentum = 0.9
weight_decay = 1e-4

#FGSM settings
epsilon = 0.25            #perturbation size; 0.25 is the value commonly used for MNIST on a [0,1] pixel scale
alpha_mix = 0.5           #weight of the clean loss when mixing clean and adversarial losses

#everything the experiments write (weights, CSVs, figures) goes here
output_dir = "outputs_mnist_maxout"
os.makedirs(output_dir, exist_ok=True)

#loader settings: pinned memory and worker processes are only used when a GPU is present
pin_mem = use_cuda
num_workers = 2 if use_cuda else 0


#Data
transform = transforms.Compose([transforms.ToTensor()])  #values in [0,1]
train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)


#options shared by both loaders
loader_kwargs = dict(num_workers=num_workers, pin_memory=pin_mem)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, **loader_kwargs)


#Maxout helpers
class MaxoutConv2d(nn.Module):
    """
    2-D convolution followed by a maxout over groups of `pieces` channels.

    The inner convolution produces out_channels * pieces feature maps; every run of
    `pieces` consecutive maps is reduced with an element-wise max, so the layer
    returns out_channels maps.

    Args:
        in_channels: number of channels of the input.
        out_channels: number of channels returned after the max.
        kernel_size, stride, padding: forwarded unchanged to nn.Conv2d.
        pieces: number of convolution outputs competing for each output channel (>= 1).

    Raises:
        ValueError: if pieces is smaller than 1.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, pieces=2):
        super().__init__()
        #reject this early: pieces == 0 would otherwise only fail inside forward() with a division by zero
        if pieces < 1:
            raise ValueError(f"pieces must be >= 1, got {pieces}")
        self.pieces = pieces
        #produce out_channels * pieces channels then do max over pieces
        self.conv = nn.Conv2d(in_channels, out_channels * pieces, kernel_size, stride=stride, padding=padding)

    def forward(self, x):
        """
        Apply the convolution and take the max over each group of `pieces` channels.

        Accepts whatever nn.Conv2d accepts: batched (B, C, H, W) or unbatched (C, H, W)
        input. The output has the same leading dims with out_channels channels.
        """
        out = self.conv(x)  #shape (..., out*pieces, H, W)
        #index from the right so an optional batch dim is carried through untouched
        *lead, channels, height, width = out.shape
        #reshape to (..., out, pieces, H, W) then max over the pieces axis
        out = out.view(*lead, channels // self.pieces, self.pieces, height, width)
        return out.max(dim=-3).values

class MaxoutLinear(nn.Module):
    """
    Linear layer followed by a maxout over groups of `pieces` features.

    The inner nn.Linear produces out_features * pieces values; every run of `pieces`
    consecutive values is reduced with a max, so the layer returns out_features values.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        pieces: number of linear outputs competing for each output feature (>= 1).

    Raises:
        ValueError: if pieces is smaller than 1.
    """
    def __init__(self, in_features, out_features, pieces=2):
        super().__init__()
        #reject this early: pieces == 0 would otherwise only fail inside forward() with a division by zero
        if pieces < 1:
            raise ValueError(f"pieces must be >= 1, got {pieces}")
        self.pieces = pieces
        self.lin = nn.Linear(in_features, out_features * pieces)

    def forward(self, x):
        """
        Apply the linear map and take the max over each group of `pieces` features.

        Like nn.Linear, any number of leading dimensions is accepted:
        (..., in_features) -> (..., out_features).
        """
        out = self.lin(x)  #shape (..., out*pieces)
        #only the last axis is split, so leading dims (batch, sequence, ...) pass through unchanged
        grouped = out.view(*out.shape[:-1], out.shape[-1] // self.pieces, self.pieces)
        return grouped.max(dim=-1).values


#Model: small Maxout-like conv net (student style)
class MaxoutConvNet(nn.Module):
    """
    Small maxout-style classifier for 1x28x28 inputs that returns 10 logits per image.

    Layout: two (maxout conv -> ReLU -> 2x2 max-pool) stages, then a maxout
    fully-connected layer with ReLU and dropout, then a plain linear output layer.
    """
    def __init__(self):
        super().__init__()
        #stage 1: 1 -> 64 channels, pooling halves 28 -> 14
        self.conv1 = MaxoutConv2d(1, 64, kernel_size=5, padding=2, pieces=2)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        #stage 2: 64 -> 128 channels, pooling halves 14 -> 7
        self.conv2 = MaxoutConv2d(64, 128, kernel_size=5, padding=2, pieces=2)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        #classifier head on the flattened 128x7x7 feature maps
        self.fc1 = MaxoutLinear(128 * 7 * 7, 1024, pieces=2)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(1024, 10)  #raw logits; softmax is applied by the loss / by the caller

    def forward(self, x):
        """Return the (B, 10) logits for a batch of images."""
        #a ReLU is kept after every maxout layer (the original author's choice, for stability)
        features = self.pool1(F.relu(self.conv1(x)))
        features = self.pool2(F.relu(self.conv2(features)))
        #collapse channels and spatial dims into one vector per sample
        flat = features.view(features.size(0), -1)
        hidden = self.dropout(F.relu(self.fc1(flat)))
        return self.fc2(hidden)

#FGSM attack (keeps gradients only where needed)
def fgsm_attack(model, loss_fn, images, labels, eps):
    """
    Generate FGSM adversarial examples for a batch.

    Computes clamp(images + eps * sign(d loss / d images), 0, 1), where the loss is
    loss_fn(model(images), labels).

    Args:
        model: network being attacked; it is run in whatever train/eval mode it is already in.
        loss_fn: callable (outputs, labels) -> scalar loss.
        images: input batch; it is not modified. A copy is moved to the device of
            `labels` before the forward pass.
        labels: targets passed to loss_fn.
        eps: step size of the perturbation.

    Returns:
        A detached tensor shaped like `images` with values clamped to [0, 1].

    Gradient mode is enabled internally, so this also works when called under
    torch.no_grad(). Only the gradient w.r.t. the input is computed: the .grad
    buffers of the model parameters are neither cleared nor filled.
    """
    with torch.enable_grad():
        #fresh leaf tensor; requires_grad_ comes last so the copy on the target device is the leaf
        images_adv = images.detach().clone().to(labels.device).requires_grad_(True)
        loss = loss_fn(model(images_adv), labels)
        #differentiate w.r.t. the input only instead of backpropagating into every parameter
        (data_grad,) = torch.autograd.grad(loss, images_adv)
    perturbed = images_adv.detach() + eps * data_grad.sign()
    #keep the result inside the valid pixel range
    return torch.clamp(perturbed, 0.0, 1.0).detach()

#Training + evaluation
def train_epoch(model, loader, optimizer, loss_fn, adv_training=False, eps=0.25, alpha=0.5):
    """
    Train `model` for one pass over `loader` and return the mean loss per sample.

    Args:
        model: network to optimise; it is switched to train mode.
        loader: iterable of (inputs, labels) batches; each batch is moved to the module-level `device`.
        optimizer: optimizer stepped once per batch.
        loss_fn: callable (outputs, labels) -> batch-mean loss.
        adv_training: if True, each step minimises
            alpha * loss(clean) + (1 - alpha) * loss(FGSM), with the FGSM batch
            generated on the fly from the current weights.
        eps: FGSM step size (only used when adv_training is True).
        alpha: weight of the clean loss in the mixed objective.

    Returns:
        Sum of (batch loss * batch size) divided by the number of samples actually
        processed; 0.0 if the loader yields no batches.
    """
    model.train()
    running_loss = 0.0
    seen = 0
    for xb, yb in loader:
        xb = xb.to(device)
        yb = yb.to(device)

        #adversarial inputs are generated from the current weights, before gradients are reset for the step
        xb_adv = fgsm_attack(model, loss_fn, xb, yb, eps) if adv_training else None

        optimizer.zero_grad()
        loss = loss_fn(model(xb), yb)
        if adv_training:
            loss = alpha * loss + (1.0 - alpha) * loss_fn(model(xb_adv), yb)
        loss.backward()
        optimizer.step()

        n = xb.size(0)
        running_loss += loss.item() * n
        seen += n

    #normalise by what was actually seen: len(loader.dataset) is wrong with drop_last or a sampler
    return running_loss / seen if seen else 0.0

def test_model(model, loader, loss_fn, eps=None):
    """
    Evaluate `model` on clean data (eps=None) or on FGSM adversarial data (eps given).

    Args:
        model: network to evaluate; it is switched to eval mode.
        loader: iterable of (inputs, labels) batches; each batch is moved to the module-level `device`.
        loss_fn: callable (outputs, labels) -> batch-mean loss; also used to build the FGSM batch.
        eps: FGSM step size, or None to evaluate the unmodified inputs.

    Returns:
        (avg_loss, acc, avg_conf_mistake):
        avg_loss is the mean loss per sample, acc the accuracy in percent, and
        avg_conf_mistake the mean softmax probability the model assigned to its
        (wrong) prediction over all misclassified samples, or None if there were none.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.eval()
    total = 0
    correct = 0
    total_loss = 0.0
    conf_sum = 0.0
    n_mistakes = 0

    for xb, yb in loader:
        xb = xb.to(device)
        yb = yb.to(device)
        if eps is not None:
            #the attack needs input gradients, so make sure grad mode is on while generating it
            with torch.enable_grad():
                xb_eval = fgsm_attack(model, loss_fn, xb, yb, eps)
        else:
            xb_eval = xb

        #metrics themselves need no autograd graph
        with torch.no_grad():
            outputs = model(xb_eval)
            n = xb_eval.size(0)
            total_loss += loss_fn(outputs, yb).item() * n
            preds = outputs.argmax(dim=1)
            wrong = preds != yb
            batch_mistakes = int(wrong.sum().item())
            correct += n - batch_mistakes
            total += n

            #confidence in the predicted class for misclassified samples, computed for the
            #whole batch at once instead of one device-to-host transfer per mistake
            if batch_mistakes:
                probs = F.softmax(outputs, dim=1)
                pred_conf = probs.gather(1, preds.unsqueeze(1)).squeeze(1)
                conf_sum += pred_conf[wrong].double().sum().item()
                n_mistakes += batch_mistakes

    if total == 0:
        raise ValueError("test_model: loader produced no samples")

    acc = 100.0 * correct / total
    avg_loss = total_loss / total
    avg_conf_mistake = conf_sum / n_mistakes if n_mistakes else None
    return avg_loss, acc, avg_conf_mistake


#Experiment runner
def run_experiment(adv_train=False, epochs=10, prefix="baseline"):
    """
    Train a fresh MaxoutConvNet, evaluate it after every epoch, and save the results.

    Uses the module-level loaders and hyperparameters (lr, momentum, weight_decay,
    epsilon, alpha_mix, device, output_dir).

    Args:
        adv_train: if True, train on the mixed clean + FGSM objective.
        epochs: number of training epochs.
        prefix: tag used in the log lines and in the names of the saved files.

    Returns:
        (model, history), where history maps "train_loss", "test_acc_clean",
        "test_acc_adv" and "adv_conf" to one value per epoch.

    Side effects: writes {prefix}_model.pth, {prefix}_history.csv and
    {prefix}_acc.png into output_dir and prints one line per epoch.
    """
    model = MaxoutConvNet().to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    loss_fn = nn.CrossEntropyLoss()

    history = {"train_loss": [], "test_acc_clean": [], "test_acc_adv": [], "adv_conf": []}

    for ep in range(1, epochs+1):
        t0 = datetime.now()
        train_loss = train_epoch(model, train_loader, optimizer, loss_fn, adv_training=adv_train, eps=epsilon, alpha=alpha_mix)
        _, acc_clean, _ = test_model(model, test_loader, loss_fn, eps=None)
        _, acc_adv, adv_conf = test_model(model, test_loader, loss_fn, eps=epsilon)

        history["train_loss"].append(train_loss)
        history["test_acc_clean"].append(acc_clean)
        history["test_acc_adv"].append(acc_adv)
        history["adv_conf"].append(adv_conf)

        #total_seconds() covers the whole duration; timedelta.seconds is only the
        #seconds component and wraps around after one day
        elapsed = int((datetime.now() - t0).total_seconds())
        print(f"[{prefix}] Epoch {ep}/{epochs} | train_loss={train_loss:.4f} | clean_acc={acc_clean:.2f}% | adv_acc={acc_adv:.2f}% | time={elapsed}s")

    #save model & history
    torch.save(model.state_dict(), os.path.join(output_dir, f"{prefix}_model.pth"))
    pd.DataFrame(history).to_csv(os.path.join(output_dir, f"{prefix}_history.csv"), index=False)

    #save accuracy plot 300 DPI
    epoch_axis = range(1, epochs+1)
    plt.figure(figsize=(6,4))
    plt.plot(epoch_axis, history["test_acc_clean"], label="Clean acc")
    plt.plot(epoch_axis, history["test_acc_adv"], label=f"FGSM eps={epsilon} acc")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy (%)")
    plt.title(f"{prefix} Accuracies")
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f"{prefix}_acc.png"), dpi=300)
    plt.close()

    return model, history


#Main
def main():
    """
    Run the full comparison: train a baseline and an adversarially trained model,
    write a final clean/FGSM summary table, and save a figure of clean vs FGSM samples.

    All files are written into the module-level output_dir.
    """
    print("Starting baseline training (no adversarial training)...")
    baseline_model, _ = run_experiment(adv_train=False, epochs=epochs, prefix="baseline")

    print("\nStarting adversarial training (mixing clean + FGSM)...")
    adv_model, _ = run_experiment(adv_train=True, epochs=epochs, prefix="advtrain")

    #final eval summary: for each model, clean first and then FGSM
    loss_fn = nn.CrossEntropyLoss()
    results = [
        test_model(net, test_loader, loss_fn, eps=eval_eps)
        for net in (baseline_model, adv_model)
        for eval_eps in (None, epsilon)
    ]
    losses, accuracies, confidences = zip(*results)

    fgsm_name = f"FGSM eps={epsilon}"
    summary = {
        "Model": ["Baseline", "Baseline", "AdvTrain", "AdvTrain"],
        "Eval": ["Clean", fgsm_name, "Clean", fgsm_name],
        "Loss": list(losses),
        "Accuracy": list(accuracies),
        "Avg_conf_on_mistakes": list(confidences)
    }
    df = pd.DataFrame(summary)
    df.to_csv(os.path.join(output_dir, "final_summary.csv"), index=False)
    print("\nFinal summary:")
    print(df)

    #sample figure: first 8 test images, clean vs FGSM crafted against the baseline model
    n_show = 8
    sample_x, sample_y = next(iter(test_loader))
    sample_x = sample_x[:n_show].to(device)
    sample_y = sample_y[:n_show].to(device)

    with torch.no_grad():
        clean_preds = baseline_model(sample_x).argmax(dim=1)
    with torch.enable_grad():
        adv_sample = fgsm_attack(baseline_model, nn.CrossEntropyLoss(), sample_x, sample_y, epsilon)
    with torch.no_grad():
        adv_preds = baseline_model(adv_sample).argmax(dim=1)

    fig, axes = plt.subplots(2,8, figsize=(12,3))
    #row 0 shows the clean images, row 1 the adversarial ones; titles are the predicted digits
    for row, (batch, preds) in enumerate(((sample_x, clean_preds), (adv_sample, adv_preds))):
        for col in range(n_show):
            ax = axes[row, col]
            ax.imshow(batch[col].cpu().squeeze(), cmap="gray")
            ax.axis("off")
            ax.set_title(str(int(preds[col].item())))
    fig.suptitle("Top: clean preds | Bottom: FGSM preds")
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "sample_clean_vs_adv.png"), dpi=300)
    plt.close(fig)

    print("Saved outputs to", output_dir)

if __name__ == "__main__":
    main()


Using device: cuda


100%|██████████| 9.91M/9.91M [00:00<00:00, 18.4MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 505kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.63MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.95MB/s]

Starting baseline training (no adversarial training)...





[baseline] Epoch 1/18 | train_loss=0.3308 | clean_acc=97.64% | adv_acc=7.35% | time=17s
[baseline] Epoch 2/18 | train_loss=0.0754 | clean_acc=98.55% | adv_acc=15.79% | time=16s
[baseline] Epoch 3/18 | train_loss=0.0541 | clean_acc=98.88% | adv_acc=14.84% | time=16s
[baseline] Epoch 4/18 | train_loss=0.0429 | clean_acc=98.97% | adv_acc=15.83% | time=16s
[baseline] Epoch 5/18 | train_loss=0.0358 | clean_acc=98.96% | adv_acc=14.51% | time=17s
[baseline] Epoch 6/18 | train_loss=0.0306 | clean_acc=99.12% | adv_acc=16.89% | time=17s
[baseline] Epoch 7/18 | train_loss=0.0271 | clean_acc=99.30% | adv_acc=16.41% | time=16s
[baseline] Epoch 8/18 | train_loss=0.0242 | clean_acc=99.18% | adv_acc=22.66% | time=16s
[baseline] Epoch 9/18 | train_loss=0.0215 | clean_acc=99.32% | adv_acc=13.78% | time=17s
[baseline] Epoch 10/18 | train_loss=0.0199 | clean_acc=99.20% | adv_acc=15.73% | time=16s
[baseline] Epoch 11/18 | train_loss=0.0173 | clean_acc=99.18% | adv_acc=23.40% | time=16s
[baseline] Epoch 12/