# Adversarial Attacks with ResNet-34

This notebook performs two tasks:
1. Evaluate clean top-1/top-5 accuracy on a 100-class subset of ImageNet-1K
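2. Apply FGSM (ε=0.02, applied to the normalized input tensors) to generate adversarial examples, re-evaluate top-1/top-5 accuracy on them, and visualize images whose prediction flips from correct to incorrect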


## Imports and Preprocessing

In [None]:
import json
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms, datasets, models
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Per-channel mean / standard deviation handed to Normalize below
# (three entries each, one per image channel).
mean_norms = np.array((0.485, 0.456, 0.406))
std_norms = np.array((0.229, 0.224, 0.225))

# Preprocessing pipeline: convert the image to a tensor first, then
# standardize each channel with the statistics above.
plain_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=mean_norms, std=std_norms),
    ]
)


## Task 1: Clean Accuracy Evaluation

In [None]:
# Load dataset. The loader yields (image batch, ImageFolder class index batch);
# shuffle=False keeps the iteration order identical to the dataset order.
dataset = datasets.ImageFolder(
    root="./TestDataSet/TestDataSet",
    transform=plain_transforms
)
loader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=4)

# Load mapping: entry i of labels_list.json is a string of the form
# "<integer>:<rest>"; the integer before the first ":" is the class index the
# model output is compared against for ImageFolder class index i.
# NOTE(review): this assumes the JSON list order matches ImageFolder's class
# ordering -- confirm against the dataset layout.
with open("./TestDataSet/TestDataSet/labels_list.json", encoding="utf-8") as f:
    entries = json.load(f)
idx_to_true = {
    i: int(entry.split(":", 1)[0]) for i, entry in enumerate(entries)
}

# Load model (pretrained weights) and switch to inference behaviour.
model = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)
model.eval().to(device)

# Evaluate: count samples whose true index is the top prediction (top-1)
# or appears anywhere among the five highest logits (top-5).
top1 = top5 = total = 0
with torch.no_grad():
    for imgs, labels in tqdm(loader, desc="Clean Eval"):
        imgs = imgs.to(device)
        logits = model(imgs)
        _, p5 = logits.topk(5, dim=1)
        # Translate ImageFolder indices to the indices used by the model;
        # .tolist() converts the whole batch to Python ints in one call.
        true = torch.tensor(
            [idx_to_true[label] for label in labels.tolist()],
            device=p5.device,
        )
        top1 += (p5[:, 0] == true).sum().item()
        top5 += (p5 == true.unsqueeze(1)).any(dim=1).sum().item()
        total += labels.size(0)

print(f"Top-1 accuracy: {top1/total*100:.2f}%")
print(f"Top-5 accuracy: {top5/total*100:.2f}%")


## Task 2: FGSM Attack & Adversarial Evaluation

In [None]:
# FGSM helper: the raw pixel range [0, 1] expressed in normalized space.
# The per-channel statistics are reshaped to (C, 1, 1) so they broadcast
# across the height and width of an image tensor. They keep the dtype of the
# NumPy arrays they are built from.
cn = torch.tensor(mean_norms, device=device).reshape(-1, 1, 1)
cs = torch.tensor(std_norms, device=device).reshape(-1, 1, 1)
# A raw value p normalizes to (p - mean) / std; evaluate at p = 0 and p = 1.
min_val, max_val = ((raw - cn) / cs for raw in (0, 1))

def fgsm(image, eps, grad):
    """Take one FGSM step and clip the result to the valid pixel range.

    Args:
        image: Normalized input tensor to perturb.
        eps: Step size, applied directly in the (normalized) space of `image`.
        grad: Gradient of the loss with respect to `image`; only its sign
            is used.

    Returns:
        ``image + eps * sign(grad)`` clipped element-wise to the module-level
        bounds ``[min_val, max_val]``, with the same dtype and device as
        `image`.
    """
    # The module-level bounds are derived from float64 NumPy arrays. Clipping
    # against them directly would promote a float32 image to float64, so the
    # bounds are cast to the image's dtype/device first.
    lower = min_val.to(device=image.device, dtype=image.dtype)
    upper = max_val.to(device=image.device, dtype=image.dtype)
    stepped = image + eps * grad.sign()
    return torch.max(torch.min(stepped, upper), lower)

# Single-image loader: the attack takes the loss gradient one image at a time.
# shuffle=False keeps adv_images aligned with the dataset order.
si_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

# Run FGSM
epsilon = 0.02
# Parallel lists, one entry per image: the adversarial tensor (on CPU), its
# ImageFolder class index, and the model's argmax before / after the attack.
adv_images, adv_labels, orig_preds, adv_preds = [], [], [], []

for img, lab in tqdm(si_loader, desc="FGSM"):
    # Read the class index while the label is still on the CPU.
    class_idx = int(lab)
    img = img.to(device).detach().requires_grad_(True)
    out = model(img)
    true_idx = torch.tensor([idx_to_true[class_idx]], device=device)
    loss = F.cross_entropy(out, true_idx)
    # Only the gradient with respect to the input is needed; asking autograd
    # for exactly that avoids accumulating gradients on the model parameters.
    (grad,) = torch.autograd.grad(loss, img)
    with torch.no_grad():
        # No graph is needed for the perturbation or the second forward pass.
        adv = fgsm(img, epsilon, grad).to(torch.float32)
        adv_pred = model(adv).argmax(1).item()
    adv_images.append(adv.squeeze(0).cpu())
    adv_labels.append(class_idx)
    orig_preds.append(out.argmax(1).item())
    adv_preds.append(adv_pred)

# Build adversarial set: stack the per-image tensors into one tensor and pair
# each with its ImageFolder class index.
adv_tensor = torch.stack(adv_images)
lab_tensor = torch.tensor(adv_labels)
adv_set = TensorDataset(adv_tensor, lab_tensor)
adv_loader = DataLoader(adv_set, batch_size=32, shuffle=False)

# Evaluate adversarial set with the same top-1 / top-5 counting as before.
top1, top5, total = 0, 0, 0
with torch.no_grad():
    for batch, class_ids in adv_loader:
        scores = model(batch.to(device))
        # Indices of the five largest logits per sample, best first.
        ranked = scores.topk(5, dim=1).indices
        targets = torch.tensor(
            [idx_to_true[int(c)] for c in class_ids],
            device=ranked.device,
        )
        top1 += ranked[:, 0].eq(targets).sum().item()
        top5 += ranked.eq(targets[:, None]).any(dim=1).sum().item()
        total += len(class_ids)
print(f"Adversarial top-1: {top1/total*100:.2f}%")
print(f"Adversarial top-5: {top5/total*100:.2f}%")


### Visualize Misclassified Examples

In [None]:
# Find 3 flips: images the model classified correctly before the attack and
# incorrectly after it.
idx_map = {v: k for k, v in dataset.class_to_idx.items()}
picked = []
for i, (orig_pred, adv_pred, class_idx) in enumerate(
    zip(orig_preds, adv_preds, adv_labels)
):
    target = idx_to_true[class_idx]
    if orig_pred == target and adv_pred != target:
        picked.append(i)
    if len(picked) >= 3:
        break

# Un-normalizer: Normalize computes (x - mean) / std, so using mean = -m/s and
# std = 1/s yields x * s + m, i.e. the inverse of the preprocessing step.
inv_norm = transforms.Normalize(
    mean=(-mean_norms/std_norms).tolist(),
    std=(1/std_norms).tolist()
)

fig, axes = plt.subplots(2, 3, figsize=(12, 6))
# Hide every axis up front so columns stay blank when fewer than 3 flips exist.
for ax in axes.flat:
    ax.axis('off')
for col, i in enumerate(picked):
    # adv_images was filled in dataset order (the single-image loader is not
    # shuffled), so dataset[i] is the clean counterpart of adv_images[i].
    clean_img, _ = dataset[i]
    # Row 0 shows the clean image, row 1 the adversarial one.
    for row, img in enumerate([inv_norm(clean_img), inv_norm(adv_images[i])]):
        ax = axes[row, col]
        ax.imshow(img.permute(1, 2, 0).clamp(0, 1).numpy())
        ax.set_title(('orig' if row==0 else 'adv')+f"\ntrue={idx_map[adv_labels[i]]}")
plt.tight_layout(); plt.show()
