# Boilerplate

Package installation, loading, and dataloaders. There's also a simple model defined. You can change it to your favourite architecture if you want.

In [1]:
# !pip install tensorboardX

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import time
import matplotlib.pyplot as plt
import random

from torchvision import datasets, transforms
# from tensorboardX import SummaryWriter

# Runtime configuration. With use_cuda = False, `device` resolves to the CPU.
use_cuda = False
device = torch.device("cuda" if use_cuda else "cpu")
batch_size = 64

# When True the model is trained from scratch and its weights are saved;
# when False previously saved weights are loaded instead.
train_fresh = True

# Seed every RNG this notebook imports (Python's `random`, NumPy and PyTorch)
# so that a fresh "Restart & Run All" is reproducible.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)


## Dataloaders
# MNIST train and test splits, stored under 'mnist_data/' (downloaded when
# missing). The only transform is ToTensor; normalization is intentionally
# not applied here.
train_dataset, test_dataset = (
    datasets.MNIST(
        'mnist_data/',
        train=is_train,
        download=True,
        transform=transforms.Compose([transforms.ToTensor()]),
    )
    for is_train in (True, False)
)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader, test_loader = (
    torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=do_shuffle)
    for dataset, do_shuffle in ((train_dataset, True), (test_dataset, False))
)

## Simple NN. You can change this if you want. If you change it, mention the architectural details in your report.
class Net(nn.Module):
    """Fully connected classifier: 784 inputs -> 200 hidden units (ReLU) -> 10 outputs.

    The forward pass returns the raw outputs of the last linear layer; no
    softmax is applied.
    """

    def __init__(self):
        super().__init__()
        # The attribute names `fc` and `fc2` become the state_dict keys, so
        # they must stay stable for saved weights to load.
        self.fc = nn.Linear(in_features=28 * 28, out_features=200)
        self.fc2 = nn.Linear(in_features=200, out_features=10)

    def forward(self, x):
        # Flatten whatever shape comes in to rows of 784 values.
        flat = x.view(-1, self.fc.in_features)
        hidden = torch.relu(self.fc(flat))
        return self.fc2(hidden)

class Normalize(nn.Module):
    """Parameter-free layer that shifts and scales its input by fixed constants.

    NOTE(review): 0.1307 / 0.3081 are presumably the MNIST mean and standard
    deviation - confirm if the dataset changes.
    """

    def forward(self, x):
        centered = x - 0.1307
        return centered / 0.3081

# Normalization is the first "layer" of the network instead of a dataset
# transform. This lets us search for adversarial examples in the space of the
# real image rather than the normalized image.
model = nn.Sequential(Normalize(), Net()).to(device)

# Switch to training mode; as the last expression this also displays the model.
model.train()

Sequential(
  (0): Normalize()
  (1): Net(
    (fc): Linear(in_features=784, out_features=200, bias=True)
    (fc2): Linear(in_features=200, out_features=10, bias=True)
  )
)

# Implement the Attacks

Functions are given a simple useful signature that you can start with. Feel free to extend the signature as you see fit.

You may find it useful to create a 'batched' version of PGD that you can use to create the adversarial attack.

In [2]:
def get_model_prediction(model, x):
    """Return the class index the model predicts for a single input, as a Python int.

    Side effect: the model is switched to eval() mode and left there.
    `x` must produce exactly one output row, because the result is read
    with `.item()`.
    """
    model.eval()
    with torch.no_grad():
        logits = model(x.to(device))
    # Index of the largest output along the class dimension.
    return torch.max(logits, dim=1).indices.item()

# Single untargeted FGSM step; used as the inner step of PGD.
def fgsm(model, x, y, eps):
    """Take one untargeted FGSM step: x + eps * sign(d loss / d x).

    Args:
        model: network mapping inputs to class scores; it is put in eval() mode.
        x: input image tensor. If it does not already require grad, a private
            copy that does is used, so the caller's tensor is left untouched.
        y: ground-truth label for x (Python int or integer tensor). A 1-D
            tensor of labels also works when x is a batch.
        eps: size of the FGSM step.

    Returns:
        The stepped tensor, the same shape as x. It requires grad and has
        `retain_grad()` set, so callers that backpropagate through it can
        still read its `.grad`.
    """
    model.eval()

    x = x.to(device)
    if not x.requires_grad:
        # A gradient w.r.t. the input is needed; work on a copy instead of
        # flipping requires_grad on the caller's tensor.
        x = x.detach().clone().requires_grad_(True)
    target = torch.as_tensor(y, dtype=torch.long, device=device).reshape(-1)

    loss = F.cross_entropy(model(x), target)
    # autograd.grad returns the input gradient directly. Unlike loss.backward()
    # it does not accumulate into x.grad across calls, leaves the model's
    # parameter gradients alone, and only traverses the graph between the loss
    # and x, so retain_graph is not needed.
    (grad,) = torch.autograd.grad(loss, x)

    x_new = x + eps * grad.sign()
    # Keep the gradient available on this non-leaf result for callers that
    # backpropagate through it.
    x_new.retain_grad()

    return x_new

def pgd_untargeted(model, x, y, k, eps, eps_step):
    """Untargeted PGD attack on a single image, stopping early once it is misclassified.

    Args:
        model: network under attack; it is put in eval() mode.
        x: input image with pixel values in [0, 1] (the model normalizes
            internally). The caller's tensor is not modified.
        y: ground-truth label for x (Python int or single-element tensor).
        k: maximum number of FGSM steps.
        eps: radius of the L-infinity ball around x that the result is
            projected into.
        eps_step: step size of one FGSM iteration.

    Returns:
        (x_adv, y_pred): the adversarial image, detached from the autograd
        graph, and the model's prediction for it as an int. With k == 0 the
        unmodified image and the original label are returned.
    """
    model.eval()

    label = int(y)
    x_orig = x.detach().to(device)
    # Loop-invariant bounds: the eps-ball around the original image,
    # intersected with the valid pixel range [0, 1].
    lower = (x_orig - eps).clamp(min=0.0)
    upper = (x_orig + eps).clamp(max=1.0)

    x_adv = x_orig.clone()
    y_pred = label

    for _ in range(k):
        # Start each step from a fresh leaf so gradients never accumulate and
        # the autograd graph does not grow with the number of steps.
        x_adv.requires_grad_(True)
        x_step = fgsm(model, x_adv, label, eps_step).detach()
        x_adv = torch.clamp(x_step, min=lower, max=upper)
        y_pred = get_model_prediction(model, x_adv)
        if y_pred != label:
            break
    return x_adv, y_pred
    

# Implement Adversarial Training

In [3]:
def train_model(model, num_epochs, enable_defense=True, attack='pgd', eps=0.1,
                pgd_steps=10, eps_step=0.01):
    """Train `model` on `train_loader`, with standard or adversarial training.

    Args:
        model: network to train (updated in place).
        num_epochs: number of passes over the training set.
        enable_defense: if True, every batch is replaced by PGD adversarial
            examples crafted against the current model before the update;
            if False, plain training on the clean batch.
        attack: accepted for interface compatibility; currently ignored,
            PGD is always used.
        eps: L-infinity radius passed to the PGD attack.
        pgd_steps: number of PGD iterations per example.
        eps_step: step size of each PGD iteration.

    Prints the average loss over every 100 batches.
    """
    model.train()
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}')
        running_loss = 0.0
        for batch_idx, (inputs, labels) in enumerate(train_loader):
            # Keep the batch on the same device as the model.
            inputs, labels = inputs.to(device), labels.to(device)

            if enable_defense:
                # The attack works on one example at a time. Detach so the
                # training backward pass does not walk the attack's graph.
                inputs = torch.stack([
                    pgd_untargeted(model, inputs[i], labels[i], pgd_steps, eps, eps_step)[0].detach()
                    for i in range(len(inputs))
                ])
                # The attack leaves the model in eval() mode.
                model.train()

            # Clear gradients, including any accumulated while crafting the attack.
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Print every 100th batch in an epoch
            if batch_idx % 100 == 99:
                print(f'Average loss per batch: {running_loss/100}')
                running_loss = 0.0

In [4]:
def print_images(x, y, x_adv, y_adv, label=""):
    """Save a side-by-side figure of an image and its adversarial counterpart.

    Args:
        x: original image tensor (squeezed to 2-D before plotting).
        y: label shown for the original image.
        x_adv: adversarial image tensor (squeezed to 2-D before plotting).
        y_adv: label shown for the adversarial image.
        label: file stem; the figure is written to
            'adversarial_examples/<label>.pdf'.
    """
    import os  # local import so this cell does not depend on the notebook's import cell

    # .cpu() makes the conversion work for tensors on any device.
    x = x.squeeze().detach().cpu().numpy()
    x_adv = x_adv.squeeze().detach().cpu().numpy()
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    axes[0].imshow(x, cmap='gray')
    axes[0].set_title('Original Image', fontsize=20)
    axes[0].axis('off')

    axes[1].imshow(x_adv, cmap='gray')
    axes[1].set_title(f'Adversarial Image: {y} → {y_adv}', fontsize=20)
    axes[1].axis('off')

    fig.tight_layout()
    # Create the output directory on first use instead of failing in savefig.
    os.makedirs('adversarial_examples', exist_ok=True)
    fig.savefig(f'adversarial_examples/{label}.pdf', dpi=300)
    # Close (not just clear) the figure so repeated calls do not pile up open figures.
    plt.close(fig)

# Study Accuracy, Quality, etc.

Compare the various results and report your observations on the submission.

In [5]:
def standard_accuracy(model):
    """Compute, print and return the model's accuracy (in percent) on `test_loader`.

    Side effect: the model is switched to eval() mode and left there.
    """
    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            # Keep the batch on the same device as the model.
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Local name differs from the function name to avoid shadowing it.
    accuracy = 100 * correct/total
    print(f'Standard Accuracy: {accuracy}%')
    return accuracy

In [6]:
## Train the original (undefended) model, or restore previously saved weights
model = nn.Sequential(Normalize(), Net())
model = model.to(device)
model.train()

if train_fresh:
    # Standard training: 20 epochs with the adversarial defense disabled.
    train_model(model, 20, False)
    torch.save(model.state_dict(), 'weights.pt')
else:
    # map_location lets weights saved on another device load onto `device`.
    # torch.load unpickles the file, so only load checkpoints you trust.
    model.load_state_dict(torch.load('weights.pt', map_location=device))

Epoch 0
Average loss per batch: 1.6452730214595794
Average loss per batch: 0.7862982395291328
Average loss per batch: 0.5622169169783592
Average loss per batch: 0.47251234769821165
Average loss per batch: 0.4298209285736084
Average loss per batch: 0.39704652965068815
Average loss per batch: 0.3663578736782074
Average loss per batch: 0.3556960505247116
Average loss per batch: 0.3541528633236885
Epoch 1
Average loss per batch: 0.310719183832407
Average loss per batch: 0.3064771881699562
Average loss per batch: 0.308118626922369
Average loss per batch: 0.31862631112337114
Average loss per batch: 0.297772326990962
Average loss per batch: 0.3073667399585247
Average loss per batch: 0.3072488883137703
Average loss per batch: 0.2783796367049217
Average loss per batch: 0.2853197442740202
Epoch 2
Average loss per batch: 0.2736907389014959
Average loss per batch: 0.2740565609931946
Average loss per batch: 0.26279278345406054
Average loss per batch: 0.26735957220196727
Average loss per batch: 0.26