# FGSM Adversarial Attack Interactive Notebook
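This notebook demonstrates the Fast Gradient Sign Method (FGSM) adversarial attack on an MNIST classifier, with interactive visualizations, saved output images, and a GIF showing the effect of increasing perturbation strength.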

In [None]:
# Cell 1: Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import save_image
import imageio
import os
from ipywidgets import interact, FloatSlider
from PIL import Image
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Cell 2: Load MNIST dataset
# ToTensor only (no normalization): pixels become floats in [0, 1], the
# range that the FGSM clamp in Cell 4 assumes.
transform = transforms.Compose([
    transforms.ToTensor(),
])
testset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
# batch_size=1: each draw from the loader is a single example to attack.
testloader = torch.utils.data.DataLoader(testset, shuffle=True, batch_size=1)
classes = [*range(10)]  # digit labels 0-9

In [None]:
# Cell 3: Define a small linear classifier (one fully-connected layer, not a CNN)
class Net(nn.Module):
    """Map a flattened 28x28 image to 10 class logits with one linear layer.

    There are no convolutions; this is a multinomial logistic-regression
    model. The attribute name ``fc`` determines the saved ``state_dict``
    keys (``fc.weight`` / ``fc.bias``), so renaming it would break loading
    previously saved checkpoints.
    """

    def __init__(self):
        super().__init__()
        in_features = 28 * 28
        self.fc = nn.Linear(in_features, 10)

    def forward(self, x):
        # Flatten every image to a 784-vector; -1 lets PyTorch infer the batch size.
        flat = x.view(-1, 28 * 28)
        return self.fc(flat)

model = Net().to(device)
model_path = "./models/mnist_model.pth"
# Derive the directory from model_path so the two cannot drift apart;
# "or '.'" covers a bare filename, where dirname() returns "".
os.makedirs(os.path.dirname(model_path) or ".", exist_ok=True)

# Train once and cache the weights; subsequent runs just reload the checkpoint.
if not os.path.exists(model_path):
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.CrossEntropyLoss()
    trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
    num_epochs = 1
    for epoch in range(num_epochs):
        for data, target in trainloader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()
    torch.save(model.state_dict(), model_path)

# The checkpoint is a plain state_dict, so weights_only=True is sufficient and
# avoids unrestricted unpickling of whatever file sits at model_path.
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
model.eval()

In [None]:
# Cell 4: FGSM Attack Function
def fgsm_attack(image, epsilon, data_grad, *, clip_min=0.0, clip_max=1.0):
    """Return an FGSM adversarial example for *image*.

    Computes ``clamp(image + epsilon * sign(data_grad), clip_min, clip_max)``.

    Args:
        image: Input tensor that ``data_grad`` was taken with respect to.
        epsilon: Step size of the perturbation (per-element L-infinity budget).
        data_grad: Gradient of the loss w.r.t. ``image``; only its sign is used.
        clip_min: Lower bound of the valid input range (default 0.0).
        clip_max: Upper bound of the valid input range (default 1.0). The
            defaults suit the unnormalized ``ToTensor()`` inputs used here.

    Returns:
        A tensor with the same shape as ``image``, clamped to
        ``[clip_min, clip_max]``.

    Raises:
        ValueError: If ``clip_min > clip_max``.
    """
    if clip_min > clip_max:
        raise ValueError(f"clip_min ({clip_min}) must not exceed clip_max ({clip_max})")
    sign_data_grad = data_grad.sign()
    perturbed_image = image + epsilon * sign_data_grad
    # Keep the adversarial example inside the valid input range.
    return torch.clamp(perturbed_image, clip_min, clip_max)

In [None]:
# Cell 5: Interactive FGSM Attack
# Draw test examples until the model classifies one correctly. Attacking an
# already-misclassified input demonstrates nothing, and the GIF cell below
# needs `data_grad`, which is only computed once such an example is found.
dataiter = iter(testloader)
for image, label in dataiter:
    image, label = image.to(device), label.to(device)
    # Track gradients w.r.t. the input pixels, which is what FGSM perturbs.
    image.requires_grad = True
    output = model(image)
    init_pred = output.max(1, keepdim=True)[1]
    if init_pred.item() == label.item():
        break
    print("Model initially misclassified this example; trying the next one.")
else:
    raise RuntimeError("Model misclassified every test example; nothing to attack.")

loss = nn.CrossEntropyLoss()(output, label)
model.zero_grad()
loss.backward()
# Detached copy of d(loss)/d(image); FGSM only uses its sign.
data_grad = image.grad.detach()


@interact(epsilon=FloatSlider(min=0, max=0.5, step=0.01, value=0.1))
def update(epsilon):
    """Plot the original, adversarial and |perturbation| images for *epsilon*.

    Titles show the model's predictions, so it is visible whether the attack
    flipped the label. The adversarial image is also written to
    ``fgsm_output.png`` each time this runs.
    """
    # Pure inference: no autograd graph is needed for the perturbed pass.
    with torch.no_grad():
        perturbed_data = fgsm_attack(image, epsilon, data_grad)
        final_pred = model(perturbed_data).max(1, keepdim=True)[1]
        perturbation = perturbed_data - image

    _, axs = plt.subplots(1, 3, figsize=(10, 3))
    axs[0].imshow(image.detach().squeeze().cpu().numpy(), cmap="gray")
    axs[0].set_title(f'Original (pred: {init_pred.item()})')
    axs[1].imshow(perturbed_data.squeeze().cpu().numpy(), cmap="gray")
    axs[1].set_title(f'Adversarial (pred: {final_pred.item()})')
    axs[2].imshow(perturbation.squeeze().cpu().abs().numpy(), cmap='hot')
    axs[2].set_title('Perturbation')
    for ax in axs:
        ax.axis('off')
    plt.tight_layout()
    plt.show()
    save_image(perturbed_data, "fgsm_output.png")

In [None]:
# Cell 6: Create a GIF over increasing epsilon values
# Needs `image` and `data_grad` from the interactive FGSM cell. That cell may
# leave `data_grad` undefined, so fail with a clear message instead of NameError.
if "data_grad" not in globals():
    raise RuntimeError(
        "data_grad is not defined; re-run the interactive FGSM cell on a "
        "correctly classified example first."
    )

images = []
epsilons = np.linspace(0, 0.5, num=20)
# Rendering frames needs no autograd graph.
with torch.no_grad():
    for eps in epsilons:
        adv = fgsm_attack(image, float(eps), data_grad)
        # Map [0, 1] -> [0, 255], rounding to nearest like torchvision's
        # save_image; a bare astype(uint8) would truncate toward zero.
        adv_img = np.round(adv.squeeze().cpu().numpy() * 255).astype(np.uint8)
        images.append(adv_img)

# NOTE(review): the unit of `duration` depends on the installed imageio
# version/backend (seconds for the legacy GIF writer, milliseconds for the
# Pillow-based writer in newer releases). Confirm the GIF plays at the
# intended speed.
imageio.mimsave('fgsm_progression.gif', images, duration=0.2)
print("[INFO] GIF saved as fgsm_progression.gif")