# FGSM Adversarial Attack (Interactive)
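This notebook demonstrates the Fast Gradient Sign Method (FGSM) attack on a simple CNN trained on MNIST. A slider controls the perturbation size ε, so you can watch the model's prediction change as the attack gets stronger.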

In [None]:
!pip install -q ipywidgets torchvision matplotlib

In [None]:
import functools
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from ipywidgets import interact, FloatSlider, IntSlider
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

class Net(nn.Module):
    """Two-stage convolutional classifier producing 10 logits per image.

    Each stage is a 5x5 convolution -> 2x2 max-pool -> ReLU; the pooled
    feature map is flattened to 320 values and fed to one linear layer.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are kept as-is: they are the state_dict keys of
        # the checkpoint that gets saved and reloaded.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc = nn.Linear(320, 10)

    def forward(self, x):
        """Map an input batch to raw (unnormalized) class scores."""
        for conv in (self.conv1, self.conv2):
            x = F.relu(F.max_pool2d(conv(x), 2))
        # 320 = 20 channels * 4 * 4 spatial positions for a 28x28 input.
        return self.fc(x.view(-1, 320))

model = Net().to(device)
model_path = "../models/mnist_model.pth"
if not os.path.exists(model_path):
    # No cached weights yet: train a quick model once and cache it on disk.
    print("[INFO] Training model...")
    # ToTensor() scales pixels to [0, 1]; no Normalize step is applied.
    transform = transforms.ToTensor()
    train_set = MNIST(root="../data", train=True, download=True, transform=transform)
    train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    num_epochs = 1
    model.train()
    for _ in range(num_epochs):
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = F.cross_entropy(model(images), labels)
            loss.backward()
            optimizer.step()
    # torch.save does not create missing parent directories.
    model_dir = os.path.dirname(model_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    torch.save(model.state_dict(), model_path)
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

In [None]:
def fgsm_attack(image, epsilon, data_grad, clip_min=0.0, clip_max=1.0):
    """Build a Fast Gradient Sign Method (FGSM) adversarial example.

    Computes ``image + epsilon * sign(data_grad)`` and clamps the result to
    ``[clip_min, clip_max]`` so it remains a valid input.

    Args:
        image: Input tensor to perturb.
        epsilon: Step size applied to every element along the gradient sign.
        data_grad: Gradient of the loss with respect to ``image``.
        clip_min: Lower bound of the valid value range (default 0.0, which
            matches ``transforms.ToTensor()`` output).
        clip_max: Upper bound of the valid value range (default 1.0).

    Returns:
        The clamped perturbed tensor. It is not detached, so it stays in
        the autograd graph when ``image`` requires grad.
    """
    perturbed_image = image + epsilon * data_grad.sign()
    return torch.clamp(perturbed_image, clip_min, clip_max)

In [None]:
@functools.lru_cache(maxsize=1)
def _mnist_test_set():
    """Load the MNIST test split once and reuse it across slider updates."""
    transform = transforms.ToTensor()
    return MNIST(root="../data", train=False, download=True, transform=transform)


def display_adversarial(epsilon):
    """Attack the first MNIST test image with FGSM and plot the result.

    Shows the original image with the model's prediction, the adversarial
    image with the model's prediction, and the per-pixel perturbation.

    Args:
        epsilon: FGSM step size passed to ``fgsm_attack``.
    """
    image, label = _mnist_test_set()[0]
    # Fresh leaf tensor on the target device that records input gradients;
    # clone() keeps the cached dataset sample untouched.
    image = image.unsqueeze(0).clone().to(device).requires_grad_(True)

    output = model(image)
    init_pred = output.argmax(dim=1).item()

    loss = F.cross_entropy(output, torch.tensor([label], device=device))
    model.zero_grad()
    loss.backward()
    perturbed_image = fgsm_attack(image, epsilon, image.grad)

    # No gradients are needed from here on. Detaching is required because
    # .numpy() refuses tensors that require grad.
    with torch.no_grad():
        perturbed_image = perturbed_image.detach()
        final_pred = model(perturbed_image).argmax(dim=1).item()
        original = image.detach().squeeze().cpu()
        adversarial = perturbed_image.squeeze().cpu()
        diff = (adversarial - original).numpy()

    plt.figure(figsize=(10, 3))
    plt.subplot(1, 3, 1)
    plt.title(f"Original\nPred: {init_pred}")
    plt.imshow(original, cmap="gray")
    plt.subplot(1, 3, 2)
    plt.title(f"Adversarial\nPred: {final_pred}")
    plt.imshow(adversarial, cmap="gray")
    plt.subplot(1, 3, 3)
    plt.title("Perturbation\nIntensity")
    plt.imshow(diff, cmap="hot")
    plt.colorbar()
    plt.tight_layout()
    plt.show()

In [None]:
# Re-render the attack whenever the epsilon slider moves.
epsilon_slider = FloatSlider(min=0.0, max=0.5, step=0.01, value=0.1)
interact(display_adversarial, epsilon=epsilon_slider)