In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import transforms

from robustness.attacks.fast_gradient import FastGradient
from robustness.attacks.projected_gradient_descent import ProjectedGradientDescent

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Pick the fastest available backend: CUDA GPU, then Apple-silicon MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

device

# Define the CNN

In [None]:
class CNN(nn.Module):
    """Small convolutional classifier for single-channel images over 10 classes.

    Layer sizes are fixed so that a 28x28 input (MNIST) is reduced to a
    64x3x3 feature map before the fully connected layers:
    28 -conv5-> 24 -conv5-> 20 -pool2-> 10 -conv5-> 6 -pool2-> 3.
    ``forward`` returns raw (unnormalised) logits of shape (batch, 10).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=5)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=5)
        self.fc1 = nn.Linear(3*3*64, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        """Map a (batch, 1, H, W) image tensor to (batch, 10) class logits."""
        x = F.relu(self.conv1(x))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = F.relu(F.max_pool2d(self.conv3(x), 2))
        # Flatten per sample (keeping dim 0) instead of view(-1, 3*3*64): with a
        # mis-sized input the latter silently re-splits the batch into a
        # different number of "samples"; here fc1 raises a shape error instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# Define training and evaluation functions

In [None]:
def train(model, train_loader, optimizer, device, epochs):
    """Train ``model`` with cross-entropy loss for ``epochs`` passes over ``train_loader``.

    The metrics of every epoch are printed as soon as the epoch finishes.

    Args:
        model: network mapping an input batch to class logits.
        train_loader: DataLoader yielding ``(inputs, labels)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to before the forward pass.
        epochs: number of full passes over ``train_loader``.

    Returns:
        Metrics of the last epoch: ``'loss'`` (mean per-sample cross-entropy),
        ``'accuracy'`` and ``'average_confidence'`` (mean top-1 softmax
        probability). An empty dict when ``epochs`` is 0.
    """
    # Set model to training mode
    model.train()

    num_samples = len(train_loader.dataset)
    # Defined up front so epochs == 0 returns {} instead of raising UnboundLocalError.
    metrics = {}

    for _ in range(epochs):
        running_loss = 0.
        correct = 0
        # Plain float accumulator: collecting per-sample tensors sliced from
        # `outputs` would keep every batch's autograd graph alive until the
        # end of the epoch.
        running_confidence = 0.

        # Loop over each batch from the training set
        for batch in train_loader:
            # Copy data to device if needed, then unpack
            inputs, labels = tuple(t.to(device) for t in batch)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()

            # Metrics are bookkeeping only; keep them out of the autograd graph.
            with torch.no_grad():
                # cross_entropy is a batch mean, so re-weight by batch size.
                running_loss += loss.item() * len(inputs)
                correct += (outputs.argmax(dim=1) == labels).sum().item()
                confidence, _ = outputs.softmax(dim=1).max(dim=1)
                running_confidence += confidence.sum().item()

        metrics = {'loss': running_loss / num_samples,
                   'accuracy': correct / num_samples,
                   'average_confidence': running_confidence / num_samples}

        print(metrics)

    return metrics

def evaluate(model, data_loader, device, attack=None):
    """Measure loss, accuracy and confidence of ``model`` on ``data_loader``.

    Args:
        model: network mapping an input batch to class logits; put in eval mode.
        data_loader: DataLoader yielding ``(inputs, labels)`` batches.
        device: device each batch is moved to.
        attack: optional object whose ``generate(inputs, labels)`` returns a
            pair whose first element replaces ``inputs`` before the forward
            pass (the second element is ignored).

    Returns:
        ``{'loss', 'accuracy', 'average_confidence'}`` averaged over all
        samples of ``data_loader.dataset``.
    """
    # Set model to evaluation mode
    model.eval()

    num_samples = len(data_loader.dataset)
    correct = 0
    running_loss = 0.
    # Plain float accumulator instead of a list of per-sample tensors.
    running_confidence = 0.

    # Loop over each batch from the validation set
    for batch in data_loader:
        # Copy data to device if needed, then unpack
        inputs, labels = tuple(t.to(device) for t in batch)

        if attack is not None:
            # Deliberately outside no_grad: the attack presumably needs
            # gradients w.r.t. the inputs.
            inputs, _ = attack.generate(inputs, labels)

        # Forward pass and metric bookkeeping, no graph needed
        with torch.no_grad():
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, labels)

            # cross_entropy is a batch mean, so re-weight by batch size.
            running_loss += loss.item() * len(inputs)
            correct += (outputs.argmax(dim=1) == labels).sum().item()

            confidence, _ = outputs.softmax(dim=1).max(dim=1)
            running_confidence += confidence.sum().item()

    metrics = {'loss': running_loss / num_samples,
               'accuracy': correct / num_samples,
               'average_confidence': running_confidence / num_samples}

    return metrics

# Prepare data

In [None]:
def prepare_mnist(batch_size, root='../data', num_workers=2):
    """Build MNIST train/test DataLoaders, downloading the dataset into ``root`` if needed.

    Images only go through ``transforms.ToTensor()`` (no normalisation), so
    pixel values stay on ToTensor's [0, 1] scale. This is the scale the attack
    epsilons later in the notebook are expressed in.

    Args:
        batch_size: samples per batch for both loaders.
        root: directory where torchvision stores/looks for MNIST.
        num_workers: DataLoader worker processes for both loaders.

    Returns:
        ``(train_loader, test_loader)``. Only the training loader is shuffled.
    """
    transform = transforms.Compose([transforms.ToTensor()])

    train_set = torchvision.datasets.MNIST(root=root, train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True,
                                               num_workers=num_workers)

    test_set = torchvision.datasets.MNIST(root=root, train=False, download=True, transform=transform)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False,
                                              num_workers=num_workers)

    return train_loader, test_loader

# Train model

In [None]:
# Seed torch's global RNG so weight initialisation and the training loader's
# shuffle order repeat across kernel restarts (GPU/MPS kernels may still be
# nondeterministic).
SEED = 0
torch.manual_seed(SEED)

train_loader, test_loader = prepare_mnist(batch_size=128)

model = CNN().to(device)

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

In [None]:
# Ten epochs of SGD; the returned dict holds the metrics of the final epoch.
train(model=model, train_loader=train_loader, optimizer=optimizer, device=device, epochs=10)

In [None]:
# torch.save does not create missing parent directories, so make sure
# ../models exists (it may not on a fresh checkout).
os.makedirs("../models", exist_ok=True)
# Saves the whole pickled nn.Module, so loading requires the CNN class to be defined.
torch.save(model, "../models/adversarial_attacks_mnist.pt")

# Evaluate model

In [None]:
# The checkpoint is a whole pickled nn.Module written by the cell above, not a state_dict:
# - weights_only=False is needed on PyTorch >= 2.6, where torch.load defaults to
#   weights_only=True and refuses arbitrary pickled objects. Unpickling can run
#   arbitrary code, so only load checkpoints you produced yourself.
# - map_location lets a checkpoint saved on one device (e.g. MPS) load on a
#   machine where that device is unavailable.
model = torch.load("../models/adversarial_attacks_mnist.pt", map_location=device, weights_only=False).to(device)

In [None]:
# Baseline: clean (unattacked) test-set metrics.
evaluate(model=model, data_loader=test_loader, device=device)

# Adversarial attacks

## FGSM

In [None]:
# Perturbation budgets: 0 (no attack) followed by 0.1 .. 0.5 in steps of 0.1.
epsilons = [0] + [step / 10 for step in range(1, 6)]

In [None]:
# Metrics on the FGSM-perturbed test set for every epsilon
accuracy = []
average_confidence = []

for epsilon in epsilons:
    metrics = evaluate(model, test_loader, device, attack=FastGradient(model, epsilon=epsilon))

    accuracy.append(metrics['accuracy'])
    average_confidence.append(metrics['average_confidence'])

fig, ax = plt.subplots(figsize=(3, 3))
ax.plot(epsilons, accuracy, "*-", label='Accuracy')
ax.plot(epsilons, average_confidence, "*-", label='Average confidence')
# np.arange excludes its stop value, which dropped the 1.0 y-tick and the 0.5
# x-tick; tick 0..1 inclusively and mark every tested epsilon instead.
ax.set_yticks(np.linspace(0, 1, 11))
ax.set_xticks(epsilons)
ax.set(title="FGSM", xlabel="Epsilon", ylabel="Metric")
ax.legend()
plt.show()

### Visualisation

In [None]:
# Adversarial versions of the first few *test* images. test_loader is not
# shuffled, so the same held-out images appear on every run.
# next(iter(...)) replaces dataiter.next(), which newer PyTorch DataLoader
# iterators no longer provide.
num_images = 5
inputs, labels = next(iter(test_loader))
inputs, labels = inputs[:num_images].to(device), labels[:num_images].to(device)

# One row per epsilon, one column per image.
fig, axes = plt.subplots(len(epsilons), num_images, figsize=(5, 6), squeeze=False)

for row_axes, epsilon in zip(axes, epsilons):

    adversarial_inputs, predicted_labels = FastGradient(model, epsilon=epsilon).generate(inputs, labels)

    for j, (ax, image, true_label, predicted_label) in enumerate(
            zip(row_axes, adversarial_inputs, labels, predicted_labels)):

        ax.set_xticks([])
        ax.set_yticks([])

        if j == 0:
            ax.set_ylabel(f"Epsilon: {epsilon}")

        # A fixed [0, 1] grey scale keeps panels comparable; without vmin/vmax
        # each image is min-max rescaled on its own, masking perturbation size.
        ax.imshow(image.detach().cpu().numpy().transpose(1, 2, 0), cmap='gray', vmin=0, vmax=1)
        ax.set_title(f"{true_label} → {predicted_label}")

fig.suptitle("FGSM adversarial examples")
fig.tight_layout()
plt.show()

## PGD

In [None]:
# Same budget grid as the FGSM sweep so the two curves are directly comparable.
epsilons = [0, *(tenths / 10 for tenths in range(1, 6))]

In [None]:
# Metrics on the PGD-perturbed test set for every epsilon.
# NOTE(review): alpha=2 is far larger than every epsilon (<= 0.5) on [0, 1]
# pixels. If alpha is an absolute step size, each step jumps straight to the
# epsilon-ball boundary. Confirm ProjectedGradientDescent's alpha convention
# (a value like 2/255 or ~0.01 is more usual).
accuracy = []
average_confidence = []

for epsilon in epsilons:
    metrics = evaluate(model, test_loader, device, attack=ProjectedGradientDescent(model, epsilon=epsilon, alpha=2, steps=7))

    accuracy.append(metrics['accuracy'])
    average_confidence.append(metrics['average_confidence'])

fig, ax = plt.subplots(figsize=(3, 3))
ax.plot(epsilons, accuracy, "*-", label='Accuracy')
ax.plot(epsilons, average_confidence, "*-", label='Average confidence')
# np.arange excludes its stop value, which dropped the 1.0 y-tick and the 0.5
# x-tick; tick 0..1 inclusively and mark every tested epsilon instead.
ax.set_yticks(np.linspace(0, 1, 11))
ax.set_xticks(epsilons)
ax.set(title="PGD", xlabel="Epsilon", ylabel="Metric")
ax.legend()
plt.show()

### Visualisation

In [None]:
# Adversarial versions of the first few *test* images. test_loader is not
# shuffled, so the same held-out images appear on every run.
# next(iter(...)) replaces dataiter.next(), which newer PyTorch DataLoader
# iterators no longer provide.
num_images = 5
inputs, labels = next(iter(test_loader))
inputs, labels = inputs[:num_images].to(device), labels[:num_images].to(device)

# One row per epsilon, one column per image.
fig, axes = plt.subplots(len(epsilons), num_images, figsize=(5, 6), squeeze=False)

for row_axes, epsilon in zip(axes, epsilons):

    # NOTE(review): see the PGD sweep cell regarding alpha=2 vs epsilon <= 0.5.
    adversarial_inputs, predicted_labels = ProjectedGradientDescent(model, epsilon=epsilon, alpha=2, steps=7).generate(inputs, labels)

    for j, (ax, image, true_label, predicted_label) in enumerate(
            zip(row_axes, adversarial_inputs, labels, predicted_labels)):

        ax.set_xticks([])
        ax.set_yticks([])

        if j == 0:
            ax.set_ylabel(f"Epsilon: {epsilon}")

        # A fixed [0, 1] grey scale keeps panels comparable; without vmin/vmax
        # each image is min-max rescaled on its own, masking perturbation size.
        ax.imshow(image.detach().cpu().numpy().transpose(1, 2, 0), cmap='gray', vmin=0, vmax=1)
        ax.set_title(f"{true_label} → {predicted_label}")

fig.suptitle("PGD adversarial examples")
fig.tight_layout()
plt.show()