In [1]:
import os
import time
import copy
import numpy as np
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
import scipy.ndimage # Added for mask resizing

In [9]:
# =========================
# Config
# =========================
class Config:
    data_root = "data/Sports"  # expects data_root/train and data_root/val
    image_size = 64            # 32 or 64
    batch_size = 64
    num_workers = 4

    num_classes = 10
    num_epochs = 10
    lr = 1e-3
    weight_decay = 1e-4

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    out_dir = "outputs_problem_a"
    model_a_name = "simple_cnn"
    model_b_name = "small_resnet"

    run_name_a = "run_model_A_simple_cnn"
    run_name_b = "run_model_B_small_resnet"


cfg = Config()
os.makedirs(cfg.out_dir, exist_ok=True)

In [3]:
# =========================
# Data loaders
# =========================
def get_dataloaders(cfg: Config):
    train_transform = transforms.Compose([
        transforms.Resize((cfg.image_size, cfg.image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    val_transform = transforms.Compose([
        transforms.Resize((cfg.image_size, cfg.image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    train_dir = os.path.join(cfg.data_root, "train")
    val_dir = os.path.join(cfg.data_root, "valid") # Changed 'val' to 'valid'

    train_dataset = datasets.ImageFolder(train_dir, transform=train_transform)
    val_dataset = datasets.ImageFolder(val_dir, transform=val_transform)

    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=True
    )

    class_names = train_dataset.classes
    return train_loader, val_loader, class_names

In [4]:
# =========================
# Model A - Simple CNN
# =========================
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.MaxPool2d(2),  # 64 -> 32

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),

            nn.MaxPool2d(2),  # 32 -> 16

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),

            nn.MaxPool2d(2),  # 16 -> 8
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 8 * 8, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

In [5]:
# =========================
# Model B - Small ResNet style
# =========================
class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3,
            stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3,
            stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels,
                    kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        if self.downsample is not None:
            identity = self.downsample(identity)

        out = out + identity
        out = self.relu(out)
        return out


class SmallResNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
        )

        self.layer1 = BasicBlock(32, 64, stride=2)   # 64 -> 32
        self.layer2 = BasicBlock(64, 128, stride=2)  # 32 -> 16
        self.layer3 = BasicBlock(128, 256, stride=2) # 16 -> 8

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


In [6]:
# =========================
# Utility
# =========================
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# =========================
# Train and eval loops
# =========================
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    start_time = time.time()

    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        _, preds = outputs.max(1)
        correct += preds.eq(targets).sum().item()
        total += targets.size(0)

    epoch_time = time.time() - start_time
    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc, epoch_time


def evaluate(model, dataloader, criterion, device, num_classes, class_names=None):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    per_class_correct = np.zeros(num_classes, dtype=np.int64)
    per_class_total = np.zeros(num_classes, dtype=np.int64)

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            running_loss += loss.item() * inputs.size(0)
            _, preds = outputs.max(1)

            correct += preds.eq(targets).sum().item()
            total += targets.size(0)

            for t, p in zip(targets, preds):
                per_class_total[t.item()] += 1
                if t.item() == p.item():
                    per_class_correct[t.item()] += 1

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    per_class_acc = per_class_correct / np.maximum(per_class_total, 1)

    if class_names is not None:
        print("Per class accuracy:")
        for idx, name in enumerate(class_names):
            print(f"{name}: {per_class_acc[idx] * 100:.2f}% "
                  f"({per_class_correct[idx]}/{per_class_total[idx]})")

    return epoch_loss, epoch_acc, per_class_acc


# =========================
# Saliency maps
# =========================
def compute_saliency_map(model, image_tensor, label, device):
    model.eval()
    image = image_tensor.unsqueeze(0).to(device)
    image.requires_grad_()

    output = model(image)
    target_score = output[0, label]
    model.zero_grad()
    target_score.backward()

    saliency = image.grad.data.abs().max(dim=1)[0]  # shape (1, H, W)
    saliency = saliency.squeeze(0).cpu().numpy() # Now shape (H, W)
    saliency = (saliency - saliency.min()) / (saliency.max() - saliency.min() + 1e-8)
    return saliency


# =========================
# Grad CAM
# =========================
class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer

        self.activations = None
        self.gradients = None

        self.hook_a = target_layer.register_forward_hook(self.save_activation)
        self.hook_g = target_layer.register_full_backward_hook(self.save_gradient)

    def save_activation(self, module, input, output):
        self.activations = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()

    def __call__(self, input_tensor, target_class=None):
        self.model.eval()
        input_tensor = input_tensor.unsqueeze(0)
        input_tensor = input_tensor.to(next(self.model.parameters()).device)
        input_tensor.requires_grad_()

        output = self.model(input_tensor)
        if target_class is None:
            target_class = output.argmax(dim=1).item()

        loss = output[0, target_class]
        self.model.zero_grad()
        loss.backward()

        gradients = self.gradients  # shape (1, C, H, W)
        activations = self.activations  # shape (1, C, H, W)

        # Corrected: average gradients over spatial dimensions (H, W) for each channel and batch item
        weights = gradients.mean(dim=(2, 3), keepdim=True)  # shape (1, C, 1, 1)

        # Corrected: multiply weights with activations and sum over channel dimension, then squeeze batch dim
        cam = (weights * activations).sum(dim=1)  # shape (1, H, W)
        cam = cam.squeeze(0)  # shape (H, W)

        cam = cam.cpu().numpy()
        cam = np.maximum(cam, 0)
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        return cam

    def close(self):
        self.hook_a.remove()
        self.hook_g.remove()


def show_and_save_heatmap(image_tensor, mask, out_path, alpha=0.4):
    # image_tensor is normalized. Need to unnormalize for visualization.
    mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
    std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

    img = image_tensor.cpu().numpy()
    img = std * img + mean
    img = np.clip(img, 0, 1)
    img = np.transpose(img, (1, 2, 0)) # Shape (H, W, C)

    # Resize mask to the same dimensions as the image if needed
    img_h, img_w, _ = img.shape
    mask_h, mask_w = mask.shape

    if mask_h != img_h or mask_w != img_w:
        zoom_factor_h = img_h / mask_h
        zoom_factor_w = img_w / mask_w
        # Use order=1 for bilinear interpolation
        mask = scipy.ndimage.zoom(mask, (zoom_factor_h, zoom_factor_w), order=1)

    heatmap = plt.cm.jet(mask)[..., :3]
    overlay = alpha * heatmap + (1 - alpha) * img
    overlay = np.clip(overlay, 0, 1)

    plt.figure(figsize=(4, 4))
    plt.axis("off")
    plt.imshow(overlay)
    plt.tight_layout()
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    plt.savefig(out_path, bbox_inches="tight", pad_inches=0)
    plt.close()


# =========================
# Training wrapper
# =========================
def train_and_eval_model(model, model_name, run_name, cfg, train_loader, val_loader, class_names):
    device = cfg.device
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

    writer = SummaryWriter(log_dir=os.path.join(cfg.out_dir, run_name))

    best_val_acc = 0.0
    best_model_wts = copy.deepcopy(model.state_dict())

    num_params = count_parameters(model)
    print(f"Model {model_name} param count: {num_params}")

    for epoch in range(cfg.num_epochs):
        print(f"Epoch {epoch + 1}/{cfg.num_epochs}")

        train_loss, train_acc, train_time = train_one_epoch(
            model, train_loader, criterion, optimizer, device
        )
        val_loss, val_acc, _ = evaluate(
            model, val_loader, criterion, device,
            cfg.num_classes, class_names=None
        )

        print(f"Train loss {train_loss:.4f} acc {train_acc:.4f} "
              f"time {train_time:.2f}s")
        print(f"Val   loss {val_loss:.4f} acc {val_acc:.4f}")

        writer.add_scalar("Loss/train", train_loss, epoch)
        writer.add_scalar("Loss/val", val_loss, epoch)
        writer.add_scalar("Acc/train", train_acc, epoch)
        writer.add_scalar("Acc/val", val_acc, epoch)
        writer.add_scalar("Time/train_epoch_sec", train_time, epoch)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_wts = copy.deepcopy(model.state_dict())
            checkpoint_path = os.path.join(
                cfg.out_dir, f"{model_name}-best-original.pt"
            )
            torch.save(best_model_wts, checkpoint_path)
            print(f"Saved best checkpoint to {checkpoint_path}")

    writer.close()

    model.load_state_dict(best_model_wts)
    final_ckpt_path = os.path.join(
        cfg.out_dir, f"{model_name}-final-original.pt"
    )
    torch.save(model.state_dict(), final_ckpt_path)
    print(f"Saved final model to {final_ckpt_path}")

    criterion = nn.CrossEntropyLoss()
    val_loss, val_acc, per_class_acc = evaluate(
        model, val_loader, criterion, device,
        cfg.num_classes, class_names=class_names
    )

    print(f"Best val accuracy for {model_name}: {best_val_acc:.4f}")
    print(f"Final val accuracy for {model_name}: {val_acc:.4f}")

    return model, per_class_acc


# =========================
# Run interpretability
# =========================
def run_interpretability(model, cfg, class_names, val_loader, model_name):
    device = cfg.device
    model.eval()
    model.to(device)

    # choose last conv layer as target for Grad CAM
    if isinstance(model, SimpleCNN):
        target_layer = model.features[-3]  # last Conv in features
    elif isinstance(model, SmallResNet):
        target_layer = model.layer3  # last residual block
    else:
        raise ValueError("Unknown model type for Grad CAM")

    grad_cam = GradCAM(model, target_layer)

    out_dir = Path(cfg.out_dir) / f"{model_name}_interpretability"
    out_dir.mkdir(parents=True, exist_ok=True)

    images_done = 0
    max_images = 10  # you can change

    # Removed with torch.no_grad(): block here
    for inputs, targets in val_loader:
        for i in range(inputs.size(0)):
            img = inputs[i]
            label = targets[i].item()
            label_name = class_names[label]

            # Saliency uses grad, so need to re run with grad
            sal_map = compute_saliency_map(model, img, label, device)

            # Grad CAM
            cam = grad_cam(img, target_class=label)

            # save saliency
            sal_out_path = out_dir / f"saliency_{images_done}_{label_name}.png"
            plt.figure(figsize=(4, 4))
            plt.axis("off")
            plt.imshow(sal_map, cmap="hot")
            plt.tight_layout()
            plt.savefig(sal_out_path, bbox_inches="tight", pad_inches=0)
            plt.close()

            # save Grad CAM overlay
            cam_out_path = out_dir / f"gradcam_{images_done}_{label_name}.png"
            show_and_save_heatmap(img, cam, str(cam_out_path))

            images_done += 1
            if images_done >= max_images:
                grad_cam.close()
                print(f"Saved {images_done} interpretability images for {model_name}")
                return

    grad_cam.close()
    print(f"Saved {images_done} interpretability images for {model_name}")

In [10]:

# =========================
# Main
# =========================
def main():
    train_loader, val_loader, class_names = get_dataloaders(cfg)

    print("Training Model A - SimpleCNN")
    model_a = SimpleCNN(num_classes=cfg.num_classes)
    model_a, per_class_acc_a = train_and_eval_model(
        model_a,
        cfg.model_a_name,
        cfg.run_name_a,
        cfg,
        train_loader,
        val_loader,
        class_names
    )
    run_interpretability(model_a, cfg, class_names, val_loader, cfg.model_a_name)

    print("\nTraining Model B - SmallResNet")
    model_b = SmallResNet(num_classes=cfg.num_classes)
    model_b, per_class_acc_b = train_and_eval_model(
        model_b,
        cfg.model_b_name,
        cfg.run_name_b,
        cfg,
        train_loader,
        val_loader,
        class_names
    )
    run_interpretability(model_b, cfg, class_names, val_loader, cfg.model_b_name)

    print("Done")


if __name__ == "__main__":
    main()

Training Model A - SimpleCNN
Model simple_cnn param count: 4586506
Epoch 1/10




Train loss 3.1326 acc 0.2524 time 9.52s
Val   loss 1.6791 acc 0.3600
Saved best checkpoint to outputs_problem_a\simple_cnn-best-original.pt
Epoch 2/10
Train loss 1.8161 acc 0.3333 time 8.41s
Val   loss 1.4393 acc 0.4600
Saved best checkpoint to outputs_problem_a\simple_cnn-best-original.pt
Epoch 3/10
Train loss 1.6625 acc 0.4018 time 7.98s
Val   loss 1.3412 acc 0.4600
Epoch 4/10
Train loss 1.5934 acc 0.4275 time 8.02s
Val   loss 1.4398 acc 0.5200
Saved best checkpoint to outputs_problem_a\simple_cnn-best-original.pt
Epoch 5/10
Train loss 1.4985 acc 0.4614 time 8.14s
Val   loss 1.1254 acc 0.6200
Saved best checkpoint to outputs_problem_a\simple_cnn-best-original.pt
Epoch 6/10
Train loss 1.4944 acc 0.4520 time 8.13s
Val   loss 1.2493 acc 0.5800
Epoch 7/10
Train loss 1.4060 acc 0.5028 time 8.13s
Val   loss 1.1639 acc 0.5400
Epoch 8/10
Train loss 1.4019 acc 0.4947 time 7.80s
Val   loss 1.1382 acc 0.5800
Epoch 9/10
Train loss 1.3245 acc 0.5185 time 8.36s
Val   loss 1.0635 acc 0.6200
Epoch 1

In [11]:
# part B

import os
import time
import copy
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import scipy.ndimage # Added for mask resizing


# =========================
# Config
# =========================
class ConfigB:
    data_root = "data/Sports"  # same as Part A, changed to capital 'S'
    image_size = 64
    batch_size = 32
    num_workers = 4
    num_classes = 10

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # model checkpoint paths from Part A
    ckpt_dir = "outputs_problem_a"
    simple_ckpt = "simple_cnn-best-original.pt"
    resnet_ckpt = "small_resnet-best-original.pt"

    out_dir = "outputs_problem_b"

    # attack settings
    eps_fgsm = 0.03          # epsilon in normalized space
    eps_pgd = 0.03
    pgd_alpha = 0.007
    pgd_steps = 20

    # target label index for "basketball"
    basketball_class_name = "basketball"


cfg = ConfigB()
os.makedirs(cfg.out_dir, exist_ok=True)


# =========================
# Dataset and loader
# =========================
def get_val_loader(cfg: ConfigB):
    val_transform = transforms.Compose([
        transforms.Resize((cfg.image_size, cfg.image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    val_dir = os.path.join(cfg.data_root, "valid") # Changed 'val' to 'valid'
    val_dataset = datasets.ImageFolder(val_dir, transform=val_transform)
    val_loader = DataLoader(
        val_dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=True
    )
    return val_loader, val_dataset.classes


# =========================
# Models copied from Part A
# =========================
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),

            nn.MaxPool2d(2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),

            nn.MaxPool2d(2),
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 8 * 8, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x


class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3,
            stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3,
            stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels,
                    kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        if self.downsample is not None:
            identity = self.downsample(identity)

        out = out + identity
        out = self.relu(out)
        return out


class SmallResNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
        )

        self.layer1 = BasicBlock(32, 64, stride=2)
        self.layer2 = BasicBlock(64, 128, stride=2)
        self.layer3 = BasicBlock(128, 256, stride=2)

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


def load_models(cfg: ConfigB, class_names):
    device = cfg.device

    model_a = SimpleCNN(num_classes=cfg.num_classes).to(device)
    model_b = SmallResNet(num_classes=cfg.num_classes).to(device)

    ckpt_a = os.path.join(cfg.ckpt_dir, cfg.simple_ckpt)
    ckpt_b = os.path.join(cfg.ckpt_dir, cfg.resnet_ckpt)

    print(f"Loading {ckpt_a}")
    model_a.load_state_dict(torch.load(ckpt_a, map_location=device))
    print(f"Loading {ckpt_b}")
    model_b.load_state_dict(torch.load(ckpt_b, map_location=device))

    # find index for basketball
    if cfg.basketball_class_name in class_names:
        basketball_idx = class_names.index(cfg.basketball_class_name)
    else:
        raise ValueError(
            f"basketball class not found. classes are {class_names}"
        )

    return model_a, model_b, basketball_idx


# =========================
# Interpretability helpers
# =========================
def compute_saliency_map(model, image_tensor, label, device):
    model.eval()
    image = image_tensor.unsqueeze(0).to(device)
    image.requires_grad_()

    output = model(image)
    target_score = output[0, label]
    model.zero_grad()
    target_score.backward()

    saliency = image.grad.data.abs().max(dim=1)[0]
    saliency = saliency.squeeze(0).cpu().numpy()
    saliency = (saliency - saliency.min()) / (saliency.max() - saliency.min() + 1e-8)
    return saliency


class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.activations = None
        self.gradients = None
        self.hook_a = target_layer.register_forward_hook(self.save_activation)
        self.hook_g = target_layer.register_full_backward_hook(self.save_gradient)

    def save_activation(self, module, input, output):
        self.activations = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()

    def __call__(self, input_tensor, target_class=None):
        self.model.eval()
        input_tensor = input_tensor.unsqueeze(0)
        input_tensor = input_tensor.to(next(self.model.parameters()).device)
        input_tensor.requires_grad_()

        output = self.model(input_tensor)
        if target_class is None:
            target_class = output.argmax(dim=1).item()

        loss = output[0, target_class]
        self.model.zero_grad()
        loss.backward()

        gradients = self.gradients
        activations = self.activations

        weights = gradients.mean(dim=(2, 3), keepdim=True)  # Corrected: average over spatial dimensions (H, W)
        cam = (weights * activations).sum(dim=1)
        cam = cam.squeeze(0)  # Add squeeze to remove batch dimension

        cam = cam.cpu().numpy()
        cam = np.maximum(cam, 0)
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        return cam

    def close(self):
        self.hook_a.remove()
        self.hook_g.remove()


def show_and_save_overlay(image_tensor, heatmap, out_path, alpha=0.4):
    mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
    std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

    img = image_tensor.cpu().numpy()
    img = std * img + mean
    img = np.clip(img, 0, 1)
    img = np.transpose(img, (1, 2, 0)) # Shape (H, W, C)

    # Resize heatmap to the same dimensions as the image
    img_h, img_w, _ = img.shape
    mask_h, mask_w = heatmap.shape

    if mask_h != img_h or mask_w != img_w:
        zoom_factor_h = img_h / mask_h
        zoom_factor_w = img_w / mask_w
        # Use order=1 for bilinear interpolation
        heatmap = scipy.ndimage.zoom(heatmap, (zoom_factor_h, zoom_factor_w), order=1)

    heat = plt.cm.jet(heatmap)[..., :3]
    overlay = alpha * heat + (1 - alpha) * img
    overlay = np.clip(overlay, 0, 1)

    plt.figure(figsize=(4, 4))
    plt.axis("off")
    plt.imshow(overlay)
    plt.tight_layout()
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    plt.savefig(out_path, bbox_inches="tight", pad_inches=0)
    plt.close()


def save_clean_and_adv_maps(model, img_clean, img_adv, label_true, label_adv,
                            class_names, out_dir, prefix):
    device = cfg.device
    model.to(device)
    label_true_name = class_names[label_true]
    label_adv_name = class_names[label_adv]

    # choose last conv layer
    if isinstance(model, SimpleCNN):
        target_layer = model.features[-3]
    elif isinstance(model, SmallResNet):
        target_layer = model.layer3
    else:
        raise ValueError("Unknown model type for GradCAM")

    grad_cam = GradCAM(model, target_layer)

    # saliency clean
    sal_clean = compute_saliency_map(model, img_clean, label_true, device)
    sal_adv = compute_saliency_map(model, img_adv, label_adv, device)

    # grad cam clean and adv
    cam_clean = grad_cam(img_clean, target_class=label_true)
    cam_adv = grad_cam(img_adv, target_class=label_adv)

    # save
    base = Path(out_dir)
    base.mkdir(parents=True, exist_ok=True)

    # raw saliency
    for name, sal in [("clean", sal_clean), ("adv", sal_adv)]:
        p = base / f"{prefix}_saliency_{name}_{label_true_name}_to_{label_adv_name}.png"
        plt.figure(figsize=(4, 4))
        plt.axis("off")
        plt.imshow(sal, cmap="hot")
        plt.tight_layout()
        plt.savefig(p, bbox_inches="tight", pad_inches=0)
        plt.close()

    # overlays
    show_and_save_overlay(
        img_clean, cam_clean,
        str(base / f"{prefix}_gradcam_clean_{label_true_name}.png")
    )
    show_and_save_overlay(
        img_adv, cam_adv,
        str(base / f"{prefix}_gradcam_adv_{label_adv_name}.png")
    )

    grad_cam.close()


# =========================
# Attack functions
# =========================
def clamp_tensor(x, mean, std):
    # x is normalized. we clamp in image space then renormalize
    mean = torch.tensor(mean, device=x.device).view(1, 3, 1, 1)
    std = torch.tensor(std, device=x.device).view(1, 3, 1, 1)
    img = x * std + mean
    img = torch.clamp(img, 0.0, 1.0)
    x_norm = (img - mean) / std
    return x_norm


def fgsm_attack(model, x, y, eps, targeted=False, target_labels=None):
    model.eval()
    x_adv = x.clone().detach().requires_grad_(True)
    criterion = nn.CrossEntropyLoss()

    outputs = model(x_adv)
    if targeted:
        assert target_labels is not None
        loss = criterion(outputs, target_labels)
        grad_sign = torch.sign(torch.autograd.grad(loss, x_adv)[0])
        x_adv = x_adv - eps * grad_sign
    else:
        loss = criterion(outputs, y)
        grad_sign = torch.sign(torch.autograd.grad(loss, x_adv)[0])
        x_adv = x_adv + eps * grad_sign

    x_adv = x_adv.detach()
    x_adv = clamp_tensor(
        x_adv,
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
    return x_adv


def pgd_attack(model, x, y, eps, alpha, steps, targeted=False, target_labels=None):
    model.eval()
    x_orig = x.clone().detach()
    x_adv = x_orig + 0.001 * torch.randn_like(x_orig)
    x_adv.requires_grad_(True)
    criterion = nn.CrossEntropyLoss()

    for _ in range(steps):
        outputs = model(x_adv)
        if targeted:
            assert target_labels is not None
            loss = criterion(outputs, target_labels)
            grad = torch.autograd.grad(loss, x_adv)[0]
            x_adv = x_adv - alpha * torch.sign(grad)
        else:
            loss = criterion(outputs, y)
            grad = torch.autograd.grad(loss, x_adv)[0]
            x_adv = x_adv + alpha * torch.sign(grad)

        # project back to epsilon ball
        eta = torch.clamp(x_adv - x_orig, min=-eps, max=eps)
        x_adv = x_orig + eta
        x_adv = clamp_tensor(
            x_adv,
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
        x_adv = x_adv.detach()
        x_adv.requires_grad_(True)

    return x_adv.detach()


# =========================
# Generate and evaluate adversarials
# =========================
def generate_adversarial_set(
    model,
    val_loader,
    attack_type,
    eps,
    targeted,
    target_class_idx,
    num_samples,
    alpha=None,
    steps=None
):
    device = cfg.device
    model.to(device)
    model.eval()

    collected = []

    for inputs, targets in val_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        # only keep correctly classified to start
        with torch.no_grad():
            logits_clean = model(inputs)
            preds_clean = logits_clean.argmax(dim=1)
        mask_correct = preds_clean.eq(targets)
        if mask_correct.sum().item() == 0:
            continue

        inputs_c = inputs[mask_correct]
        targets_c = targets[mask_correct]
        logits_clean = logits_clean[mask_correct]
        preds_clean = preds_clean[mask_correct]

        if targeted:
            target_labels = torch.full_like(
                targets_c, target_class_idx, device=device
            )
        else:
            target_labels = None

        if attack_type == "fgsm":
            adv = fgsm_attack(
                model, inputs_c, targets_c, eps,
                targeted=targeted, target_labels=target_labels
            )
        elif attack_type == "pgd":
            adv = pgd_attack(
                model, inputs_c, targets_c, eps,
                alpha=alpha, steps=steps,
                targeted=targeted, target_labels=target_labels
            )
        else:
            raise ValueError("Unknown attack type")

        with torch.no_grad():
            logits_adv = model(adv)
            preds_adv = logits_adv.argmax(dim=1)

        for i in range(inputs_c.size(0)):
            x0 = inputs_c[i].detach().cpu()
            xa = adv[i].detach().cpu()
            y_true = int(targets_c[i].item())
            y_clean = int(preds_clean[i].item())
            y_adv = int(preds_adv[i].item())

            lc = logits_clean[i].detach().cpu().numpy()
            la = logits_adv[i].detach().cpu().numpy()

            diff = (xa - x0).view(-1)
            l2 = torch.norm(diff, p=2).item()
            linf = torch.norm(diff, p=float("inf")).item()

            if targeted:
                success = (y_adv == target_class_idx)
            else:
                success = (y_adv != y_true)

            collected.append({
                "x_clean": x0,
                "x_adv": xa,
                "y_true": y_true,
                "y_pred_clean": y_clean,
                "y_pred_adv": y_adv,
                "logits_clean": lc,
                "logits_adv": la,
                "l2": l2,
                "linf": linf,
                "success": success,
            })

            if len(collected) >= num_samples:
                return collected

    return collected


def save_example_images(examples, class_names, out_dir, prefix):
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
    std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

    for idx, ex in enumerate(examples):
        x0 = ex["x_clean"].numpy()
        xa = ex["x_adv"].numpy()

        def denorm(x):
            img = std * x + mean
            img = np.clip(img, 0, 1)
            img = np.transpose(img, (1, 2, 0))
            return img

        img0 = denorm(x0)
        imga = denorm(xa)

        fig, axs = plt.subplots(1, 2, figsize=(6, 3))
        axs[0].imshow(img0)
        axs[0].axis("off")
        axs[0].set_title(
            f"Clean {class_names[ex['y_true']]} / pred {class_names[ex['y_pred_clean']]}"
        )
        axs[1].imshow(imga)
        axs[1].axis("off")
        axs[1].set_title(
            f"Adv pred {class_names[ex['y_pred_adv']]}"
        )
        plt.tight_layout()
        fig.savefig(out_dir / f"{prefix}_pair_{idx}.png", bbox_inches="tight", pad_inches=0)
        plt.close()


def transferability_test(examples, model_other, class_names, attack_name, targeted, target_idx):
    device = cfg.device
    model_other.to(device)
    model_other.eval()

    total = 0
    success = 0

    for ex in examples:
        xa = ex["x_adv"].unsqueeze(0).to(device)
        y_true = ex["y_true"]
        with torch.no_grad():
            logits = model_other(xa)
            pred = logits.argmax(dim=1).item()

        if targeted:
            suc = (pred == target_idx)
        else:
            suc = (pred != y_true)

        success += int(suc)
        total += 1

    rate = success / max(total, 1)
    t_type = "targeted" if targeted else "untargeted"
    print(f"Transfer {attack_name} {t_type} success rate on other model {rate:.3f}")
    return rate


# =========================
# Main experiment
# =========================
def main():
    val_loader, class_names = get_val_loader(cfg)
    model_a, model_b, basketball_idx = load_models(cfg, class_names)

    # attacks to run
    attacks = [
        ("fgsm", False),
        ("fgsm", True),
        ("pgd", False),
        ("pgd", True),
    ]

    for attack_name, targeted in attacks:
        print("\n====================================")
        print(f"Attack {attack_name}, targeted={targeted}")
        print("Generating on Model A (SimpleCNN)")

        adv_a = generate_adversarial_set(
            model_a,
            val_loader,
            attack_type=attack_name,
            eps=cfg.eps_fgsm if attack_name == "fgsm" else cfg.eps_pgd,
            targeted=targeted,
            target_class_idx=basketball_idx,
            num_samples=10,
            alpha=cfg.pgd_alpha if attack_name == "pgd" else None,
            steps=cfg.pgd_steps if attack_name == "pgd" else None,
        )

        success_rate_a = sum(int(e["success"]) for e in adv_a) / max(len(adv_a), 1)
        atype = "targeted" if targeted else "untargeted"
        print(f"Model A {attack_name} {atype} success {success_rate_a:.3f} "
              f"on {len(adv_a)} examples")

        # save images
        prefix_a = f"modelA_{attack_name}_{'targeted' if targeted else 'untargeted'}"
        save_example_images(
            adv_a,
            class_names,
            os.path.join(cfg.out_dir, "examples"),
            prefix_a
        )

        # interpretability for first few
        inter_dir = os.path.join(cfg.out_dir, "interpretability")
        for i, ex in enumerate(adv_a[:3]):
            save_clean_and_adv_maps(
                model_a,
                ex["x_clean"],
                ex["x_adv"],
                ex["y_true"],
                ex["y_pred_adv"],
                class_names,
                inter_dir,
                prefix=f"{prefix_a}_ex{i}"
            )

        # transferability from model A to B
        transferability_test(
            adv_a, model_b, class_names, attack_name, targeted, basketball_idx
        )

        # now generate on Model B
        print("Generating on Model B (SmallResNet)")

        adv_b = generate_adversarial_set(
            model_b,
            val_loader,
            attack_type=attack_name,
            eps=cfg.eps_fgsm if attack_name == "fgsm" else cfg.eps_pgd,
            targeted=targeted,
            target_class_idx=basketball_idx,
            num_samples=10,
            alpha=cfg.pgd_alpha if attack_name == "pgd" else None,
            steps=cfg.pgd_steps if attack_name == "pgd" else None,
        )

        success_rate_b = sum(int(e["success"]) for e in adv_b) / max(len(adv_b), 1)
        atype = "targeted" if targeted else "untargeted"
        print(f"Model B {attack_name} {atype} success {success_rate_b:.3f} "
              f"on {len(adv_b)} examples")

        prefix_b = f"modelB_{attack_name}_{'targeted' if targeted else 'untargeted'}"
        save_example_images(
            adv_b,
            class_names,
            os.path.join(cfg.out_dir, "examples"),
            prefix_b
        )

        for i, ex in enumerate(adv_b[:3]):
            save_clean_and_adv_maps(
                model_b,
                ex["x_clean"],
                ex["x_adv"],
                ex["y_true"],
                ex["y_pred_adv"],
                class_names,
                inter_dir,
                prefix=f"{prefix_b}_ex{i}"
            )

        # transferability from model B to A
        transferability_test(
            adv_b, model_a, class_names, attack_name, targeted, basketball_idx
        )

    print("\nDone Part B")


if __name__ == "__main__":
    main()

Loading outputs_problem_a\simple_cnn-best-original.pt
Loading outputs_problem_a\small_resnet-best-original.pt

Attack fgsm, targeted=False
Generating on Model A (SimpleCNN)
Model A fgsm untargeted success 0.300 on 10 examples
Transfer fgsm untargeted success rate on other model 0.300
Generating on Model B (SmallResNet)
Model B fgsm untargeted success 0.500 on 10 examples
Transfer fgsm untargeted success rate on other model 0.200

Attack fgsm, targeted=True
Generating on Model A (SimpleCNN)
Model A fgsm targeted success 0.300 on 10 examples
Transfer fgsm targeted success rate on other model 0.200
Generating on Model B (SmallResNet)
Model B fgsm targeted success 0.300 on 10 examples
Transfer fgsm targeted success rate on other model 0.300

Attack pgd, targeted=False
Generating on Model A (SimpleCNN)
Model A pgd untargeted success 0.500 on 10 examples
Transfer pgd untargeted success rate on other model 0.300
Generating on Model B (SmallResNet)
Model B pgd untargeted success 0.500 on 10 ex

In [12]:
#part C 
import os
import time
from pathlib import Path
import copy
import numpy as np

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt


# =========================
# Config
# =========================
class ConfigC:
    data_root = "data/Sports"
    image_size = 64
    batch_size = 64
    num_workers = 4
    num_classes = 10

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # choose which model to prune: "simple" or "resnet"
    model_type = "simple"  # change to "resnet" if you prefer

    ckpt_dir = "outputs_problem_a"
    simple_ckpt = "simple_cnn-best-original.pt"
    resnet_ckpt = "small_resnet-best-original.pt"

    out_dir = "outputs_problem_c"

    # pruning levels
    sparsity_levels = [0.2, 0.5, 0.8]

    # fine tune
    finetune_epochs = 5
    finetune_lr = 1e-4
    finetune_weight_decay = 1e-5

    # adversarial settings (reuse style from Part B)
    eps_fgsm = 0.03
    basketball_class_name = "basketball"
    num_adv_eval = 20  # number of samples for robustness check


cfg = ConfigC()
os.makedirs(cfg.out_dir, exist_ok=True)


# =========================
# Dataset
# =========================
def get_val_loader(cfg: ConfigC):
    transform = transforms.Compose([
        transforms.Resize((cfg.image_size, cfg.image_size)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
    ])
    val_dir = os.path.join(cfg.data_root, "valid")
    val_ds = datasets.ImageFolder(val_dir, transform=transform)
    val_loader = DataLoader(
        val_ds,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=True
    )
    return val_loader, val_ds.classes


# =========================
# Models (same as Part A)
# =========================
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            nn.MaxPool2d(2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 8 * 8, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x


class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3,
            stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3,
            stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels,
                    kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            identity = self.downsample(identity)
        out = out + identity
        out = self.relu(out)
        return out


class SmallResNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )
        self.layer1 = BasicBlock(32, 64, stride=2)
        self.layer2 = BasicBlock(64, 128, stride=2)
        self.layer3 = BasicBlock(128, 256, stride=2)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


def load_baseline_model(cfg: ConfigC, class_names):
    device = cfg.device
    if cfg.model_type == "simple":
        model = SimpleCNN(num_classes=cfg.num_classes)
        ckpt_path = os.path.join(cfg.ckpt_dir, cfg.simple_ckpt)
    elif cfg.model_type == "resnet":
        model = SmallResNet(num_classes=cfg.num_classes)
        ckpt_path = os.path.join(cfg.ckpt_dir, cfg.resnet_ckpt)
    else:
        raise ValueError("model_type must be 'simple' or 'resnet'")

    print(f"Loading baseline from {ckpt_path}")
    state_dict = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)

    if cfg.basketball_class_name in class_names:
        basketball_idx = class_names.index(cfg.basketball_class_name)
    else:
        raise ValueError(
            f"basketball class not found in {class_names}"
        )
    return model, basketball_idx


# =========================
# Evaluation and metrics
# =========================
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def evaluate(model, loader, device):
    model.eval()
    crit = nn.CrossEntropyLoss()
    total_loss = 0.0
    total_correct = 0
    total = 0

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            logits = model(x)
            loss = crit(logits, y)
            total_loss += loss.item() * x.size(0)
            preds = logits.argmax(dim=1)
            total_correct += preds.eq(y).sum().item()
            total += y.size(0)

    avg_loss = total_loss / total
    acc = total_correct / total
    return avg_loss, acc


def model_sparsity(model):
    total = 0
    zeros = 0
    for module in model.modules():
        if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
            if hasattr(module, 'weight') and isinstance(module.weight, torch.Tensor):
                total += module.weight.numel()
                zeros += (module.weight == 0).sum().item()
            if hasattr(module, 'bias') and isinstance(module.bias, torch.Tensor):
                total += module.bias.numel()
                zeros += (module.bias == 0).sum().item()
    if total == 0: # Avoid division by zero if no prunable layers are found
        return 0.0
    return zeros / total


def save_model_and_get_size(model, path):
    torch.save(model.state_dict(), path)
    size_mb = os.path.getsize(path) / (1024 * 1024)
    return size_mb


def measure_latency(model, device, batch_size, num_runs=100, warmup=10):
    model.eval()
    model.to(device)
    x = torch.randn(batch_size, 3, cfg.image_size, cfg.image_size, device=device)

    # warm up
    with torch.no_grad():
        for _ in range(warmup):
            _ = model(x)

    times = []
    with torch.no_grad():
        for _ in range(num_runs):
            start = time.perf_counter()
            _ = model(x)
            if device.type == "cuda":
                torch.cuda.synchronize()
            end = time.perf_counter()
            times.append((end - start) * 1000.0)  # ms

    times = np.array(times)
    return times.mean(), times.std()


# =========================
# Pruning
# =========================
def prune_model_unstructured(model, amount):
    # prune conv and linear layers
    for module in model.modules():
        if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
            prune.l1_unstructured(module, name="weight", amount=amount)
            prune.remove(module, "weight")
    return model


def finetune(model, train_loader, val_loader, cfg: ConfigC):
    device = cfg.device
    model.train() # Set model to training mode
    model.to(device)
    crit = nn.CrossEntropyLoss()
    opt = torch.optim.Adam(
        model.parameters(), lr=cfg.finetune_lr,
        weight_decay=cfg.finetune_weight_decay
    )

    best_state = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(cfg.finetune_epochs):
        model.train()
        total_loss = 0.0
        total = 0
        total_correct = 0

        for x, y in train_loader:
            x = x.to(device)
            y = y.to(device)

            opt.zero_grad()
            logits = model(x)
            loss = crit(logits, y)
            loss.backward()
            opt.step()

            total_loss += loss.item() * x.size(0)
            preds = logits.argmax(dim=1)
            total_correct += preds.eq(y).sum().item()
            total += y.size(0)

        train_loss = total_loss / total
        train_acc = total_correct / total

        val_loss, val_acc = evaluate(model, val_loader, device)
        print(f"Finetune epoch {epoch+1} "
              f"train_loss {train_loss:.4f} train_acc {train_acc:.4f} "
              f"val_loss {val_loss:.4f} val_acc {val_acc:.4f}")

        if val_acc > best_acc:
            best_acc = val_acc
            best_state = copy.deepcopy(model.state_dict())

    model.load_state_dict(best_state)
    return model


# =========================
# Adversarial attacks (simple FGSM)
# =========================
def clamp_tensor(x):
    mean = torch.tensor([0.485, 0.456, 0.406], device=x.device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=x.device).view(1, 3, 1, 1)
    img = x * std + mean
    img = torch.clamp(img, 0.0, 1.0)
    x_norm = (img - mean) / std
    return x_norm


def fgsm_attack(model, x, y, eps, targeted=False, target_labels=None):
    device = cfg.device
    model.eval()
    x_adv = x.clone().detach().requires_grad_(True)
    crit = nn.CrossEntropyLoss()

    logits = model(x_adv)
    if targeted:
        assert target_labels is not None
        loss = crit(logits, target_labels)
        grad_sign = torch.sign(torch.autograd.grad(loss, x_adv)[0])
        x_adv = x_adv - eps * grad_sign
    else:
        loss = crit(logits, y)
        grad_sign = torch.sign(torch.autograd.grad(loss, x_adv)[0])
        x_adv = x_adv + eps * grad_sign

    x_adv = x_adv.detach()
    x_adv = clamp_tensor(x_adv)
    return x_adv


def build_adv_eval_set(model, val_loader, basketball_idx, num_samples):
    device = cfg.device
    model.eval()
    model.to(device)

    collected = []

    for x, y in val_loader:
        x = x.to(device)
        y = y.to(device)

        with torch.no_grad():
            logits = model(x)
            preds = logits.argmax(dim=1)

        mask = preds.eq(y)
        if mask.sum().item() == 0:
            continue

        x_good = x[mask]
        y_good = y[mask]

        for i in range(x_good.size(0)):
            collected.append((x_good[i].detach().cpu(), int(y_good[i].item())))
            if len(collected) >= num_samples:
                return collected

    return collected


def eval_adv_success(model, examples, basketball_idx, eps):
    device = cfg.device
    model.eval()
    model.to(device)

    untargeted_total = 0
    untargeted_success = 0

    targeted_total = 0
    targeted_success = 0

    for img_cpu, y_true in examples:
        x = img_cpu.unsqueeze(0).to(device)
        y = torch.tensor([y_true], device=device)

        with torch.no_grad():
            logits = model(x)
            pred_clean = logits.argmax(dim=1).item()
        if pred_clean != y_true:
            continue  # ensure clean correct

        # untargeted
        x_unt = fgsm_attack(model, x, y, eps, targeted=False)
        with torch.no_grad():
            logits_unt = model(x_unt)
            pred_unt = logits_unt.argmax(dim=1).item()
        untargeted_total += 1
        if pred_unt != y_true:
            untargeted_success += 1

        # targeted to basketball
        target_label = torch.tensor([basketball_idx], device=device)
        x_tar = fgsm_attack(model, x, y, eps, targeted=True, target_labels=target_label)
        with torch.no_grad():
            logits_tar = model(x_tar)
            pred_tar = logits_tar.argmax(dim=1).item()
        targeted_total += 1
        if pred_tar == basketball_idx:
            targeted_success += 1

    unt_rate = untargeted_success / max(untargeted_total, 1)
    tar_rate = targeted_success / max(targeted_total, 1)
    return unt_rate, tar_rate, untargeted_total, targeted_total


# =========================
# Main
# =========================
def main():
    # train loader for finetune uses same transform as val
    val_loader, class_names = get_val_loader(cfg)

    # reuse val loader as both train and val for finetuning
    train_loader = val_loader

    baseline_model, basketball_idx = load_baseline_model(cfg, class_names)

    device = cfg.device
    baseline_model.to(device)

    # baseline metrics
    base_loss, base_acc = evaluate(baseline_model, val_loader, device)
    base_params = count_parameters(baseline_model)
    base_path = os.path.join(cfg.out_dir, f"{cfg.model_type}_baseline.pt")
    base_size = save_model_and_get_size(baseline_model, base_path)
    base_sparsity = model_sparsity(baseline_model)
    lat1_mean, lat1_std = measure_latency(baseline_model, device, batch_size=1)
    lat16_mean, lat16_std = measure_latency(baseline_model, device, batch_size=16)

    print("\nBaseline model summary")
    print(f"Val loss {base_loss:.4f} acc {base_acc:.4f}")
    print(f"Params {base_params}")
    print(f"File size {base_size:.2f} MB")
    print(f"Sparsity {base_sparsity:.4f}")
    print(f"Latency bs=1 {lat1_mean:.3f} +- {lat1_std:.3f} ms")
    print(f"Latency bs=16 {lat16_mean:.3f} +- {lat16_std:.3f} ms")

    # build fixed adversarial eval set off baseline
    adv_examples_base = build_adv_eval_set(
        baseline_model, val_loader, basketball_idx, cfg.num_adv_eval
    )
    print(f"\nCollected {len(adv_examples_base)} examples for adversarial eval")

    base_unt_rate, base_tar_rate, b_u_n, b_t_n = eval_adv_success(
        baseline_model, adv_examples_base, basketball_idx, cfg.eps_fgsm
    )
    print(f"Baseline adv untargeted success {base_unt_rate:.3f} "
          f"on {b_u_n} examples")
    print(f"Baseline adv targeted success {base_tar_rate:.3f} "
          f"on {b_t_n} examples")

    # record results
    results = []
    results.append({
        "sparsity": 0.0,
        "acc_pre": base_acc,
        "acc_post": base_acc,
        "params": base_params,
        "size_mb": base_size,
        "lat1_mean": lat1_mean,
        "lat1_std": lat1_std,
        "lat16_mean": lat16_mean,
        "lat16_std": lat16_std,
        "adv_unt": base_unt_rate,
        "adv_tar": base_tar_rate,
    })

    # pruning experiments
    for s in cfg.sparsity_levels:
        print("\n==============================")
        print(f"Pruning sparsity {s}")
        model_s = copy.deepcopy(baseline_model)

        # pre finetune eval
        loss_before, acc_before = evaluate(model_s, val_loader, device)

        # apply pruning
        model_s = prune_model_unstructured(model_s, amount=s)
        sparsity_actual = model_sparsity(model_s)
        print(f"Actual sparsity {sparsity_actual:.4f}")

        # eval after prune but before finetune
        loss_pruned, acc_pruned = evaluate(model_s, val_loader, device)
        print(f"After prune before finetune val_loss {loss_pruned:.4f} "
              f"val_acc {acc_pruned:.4f}")

        # finetune
        model_s = finetune(model_s, train_loader, val_loader, cfg)
        loss_after, acc_after = evaluate(model_s, val_loader, device)
        print(f"After finetune val_loss {loss_after:.4f} "
              f"val_acc {acc_after:.4f}")

        # param count and size
        params_s = count_parameters(model_s)
        path_s = os.path.join(
            cfg.out_dir,
            f"{cfg.model_type}_pruned_{int(s*100)}.pt"
        )
        size_s = save_model_and_get_size(model_s, path_s)

        # latency
        lat1_m, lat1_s = measure_latency(model_s, device, batch_size=1)
        lat16_m, lat16_s = measure_latency(model_s, device, batch_size=16)

        # adv robustness
        adv_unt, adv_tar, n_u, n_t = eval_adv_success(
            model_s, adv_examples_base, basketball_idx, cfg.eps_fgsm
        )
        print(f"Pruned {int(s*100)}% adv untargeted success {adv_unt:.3f} "
              f"on {n_u} examples")
        print(f"Pruned {int(s*100)}% adv targeted success {adv_tar:.3f} "
              f"on {n_t} examples")

        results.append({
            "sparsity": sparsity_actual,
            "acc_pre": acc_pruned,
            "acc_post": acc_after,
            "params": params_s,
            "size_mb": size_s,
            "lat1_mean": lat1_m,
            "lat1_std": lat1_s,
            "lat16_mean": lat16_m,
            "lat16_std": lat16_s,
            "adv_unt": adv_unt,
            "adv_tar": adv_tar,
        })

    # summary table to csv
    import csv
    csv_path = os.path.join(cfg.out_dir, f"{cfg.model_type}_pruning_summary.csv")
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow([
            "sparsity", "acc_pre", "acc_post", "params", "size_mb",
            "lat1_mean", "lat1_std", "lat16_mean", "lat16_std",
            "adv_unt", "adv_tar"
        ])
        for r in results:
            writer.writerow([
                r["sparsity"], r["acc_pre"], r["acc_post"], r["params"],
                r["size_mb"], r["lat1_mean"], r["lat1_std"],
                r["lat16_mean"], r["lat16_std"], r["adv_unt"], r["adv_tar"]
            ])

    # accuracy vs sparsity plot
    sparsities = [r["sparsity"] for r in results]
    acc_post = [r["acc_post"] for r in results]
    plt.figure()
    plt.plot(sparsities, acc_post, marker="o")
    plt.xlabel("Sparsity")
    plt.ylabel("Val accuracy (post finetune)")
    plt.title("Accuracy vs Sparsity")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(os.path.join(cfg.out_dir, f"{cfg.model_type}_acc_vs_sparsity.png"))
    plt.close()

    # latency vs sparsity (bs=1)
    lat1 = [r["lat1_mean"] for r in results]
    plt.figure()
    plt.plot(sparsities, lat1, marker="o")
    plt.xlabel("Sparsity")
    plt.ylabel("Latency (ms) batch size 1")
    plt.title("Latency vs Sparsity")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(os.path.join(cfg.out_dir, f"{cfg.model_type}_lat_vs_sparsity_bs1.png"))
    plt.close()

    print(f"\nSaved summary CSV and plots to {cfg.out_dir}")


if __name__ == "__main__":
    main()

Loading baseline from outputs_problem_a\simple_cnn-best-original.pt

Baseline model summary
Val loss 1.0283 acc 0.6600
Params 4586506
File size 17.51 MB
Sparsity 0.0000
Latency bs=1 2.055 +- 0.470 ms
Latency bs=16 18.815 +- 1.519 ms

Collected 20 examples for adversarial eval
Baseline adv untargeted success 0.300 on 20 examples
Baseline adv targeted success 0.200 on 20 examples

Pruning sparsity 0.2
Actual sparsity 0.2000
After prune before finetune val_loss 1.1536 val_acc 0.6200
Finetune epoch 1 train_loss 1.4474 train_acc 0.4600 val_loss 1.0889 val_acc 0.6800
Finetune epoch 2 train_loss 1.3486 train_acc 0.4800 val_loss 1.0446 val_acc 0.6800
Finetune epoch 3 train_loss 1.3980 train_acc 0.4400 val_loss 1.0050 val_acc 0.7000
Finetune epoch 4 train_loss 1.2131 train_acc 0.6000 val_loss 0.9725 val_acc 0.7200
Finetune epoch 5 train_loss 1.2208 train_acc 0.5000 val_loss 0.9466 val_acc 0.7400
After finetune val_loss 0.9466 val_acc 0.7400
Pruned 20% adv untargeted success 0.300 on 20 examples