# Boilerplate

Package installation, loading, and dataloaders. There's also a simple model defined. You can change it to your favourite architecture if you want.

In [None]:
# !pip install tensorboardX

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import time
import matplotlib.pyplot as plt
import random

from torchvision import datasets, transforms
# from tensorboardX import SummaryWriter

# Device selection: set use_cuda = True to run on a GPU.
use_cuda = False
device = torch.device("cuda") if use_cuda else torch.device("cpu")
batch_size = 64

# If False, previously saved weights are loaded instead of retraining.
train_fresh = False

# Fixed seeds for reproducibility.
np.random.seed(42)
torch.manual_seed(42)


## Dataloaders
# Images are only converted to tensors here (pixels in [0, 1]); the
# normalization is applied inside the model itself.
train_dataset = datasets.MNIST(
    'mnist_data/',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)
test_dataset = datasets.MNIST(
    'mnist_data/',
    train=False,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False,
)

## Simple NN. You can change this if you want. If you change it, mention the architectural details in your report.
class Net(nn.Module):
    """Fully connected MNIST classifier: 784 -> 50 -> 50 -> 50 -> 10.

    The three hidden layers use ReLU; the output layer returns raw logits.
    The attribute names ``fc1``..``fc4`` end up in the saved state dict and
    are looked up by name by the interval analysis, so they must not change.
    """

    def __init__(self):
        super().__init__()
        hidden = 50
        # Layers are created in this order so parameter initialization
        # consumes the RNG exactly as before.
        self.fc1 = nn.Linear(28 * 28, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, hidden)
        self.fc4 = nn.Linear(hidden, 10)

    def forward(self, x):
        # Flatten every image into a 784-dim row vector.
        h = x.view((-1, 28 * 28))
        for layer in (self.fc1, self.fc2, self.fc3):
            h = F.relu(layer(h))
        return self.fc4(h)

class Normalize(nn.Module):
    """Shift and scale raw pixels by fixed constants.

    The constants are presumably the usual MNIST mean/std values; they are
    applied elementwise as ``(x - MEAN) / STD``.
    """

    MEAN = 0.1307
    STD = 0.3081

    def forward(self, x):
        return (x - self.MEAN) / self.STD

# Prepend data normalization as the first "layer" of the network.
# The attacks can then search for adversarial examples in the space of
# real (raw) images rather than normalized ones.
model = nn.Sequential(Normalize(), Net()).to(device)
model.train()

# Implement the Attacks

Functions are given a simple useful signature that you can start with. Feel free to extend the signature as you see fit.

You may find it useful to create a 'batched' version of PGD that you can use to create the adversarial attack.

In [21]:
def get_model_prediction(model, x):
    """Return the class index the model predicts for ``x`` as a Python int.

    Switches ``model`` to eval mode (and leaves it there), moves ``x`` to
    ``device`` and runs the forward pass without tracking gradients. Because
    the result is extracted with ``.item()``, ``x`` must yield exactly one
    prediction.
    """
    model.eval()
    with torch.no_grad():
        logits = model(x.to(device))
        predicted_class = torch.max(logits, dim=1).indices
    return predicted_class.item()

def fgsm(model, x, y, eps):
    """Take one untargeted FGSM step from ``x`` and return the perturbed input.

    Moves ``x`` by ``eps * sign(d loss / d x)``, i.e. in the direction that
    increases the cross-entropy loss on the true label ``y``. Intended as the
    inner step of PGD.

    Args:
        model: classifier; it is switched to eval mode.
        x: input the model accepts (e.g. a single image); never modified.
        y: ground-truth label(s) for ``x``: a Python int, a 0-d tensor, or a
            1-d tensor of labels when ``x`` is a batch.
        eps: size of the FGSM step.

    Returns:
        The perturbed input on ``device``. It is detached from ``x``'s
        autograd graph but has ``requires_grad=True`` so it can be fed
        straight into another gradient step.
    """
    model.eval()

    loss_fn = nn.CrossEntropyLoss()
    # Work on a fresh leaf copy: the caller's tensor is never mutated, and its
    # autograd state (leaf / non-leaf / requires_grad) does not matter.
    x_adv = x.detach().to(device).requires_grad_(True)
    # Labels must be on the same device as the logits; reshape(-1) turns a
    # scalar label into a batch of one.
    target = torch.as_tensor(y, device=device).reshape(-1)

    loss = loss_fn(model(x_adv), target)
    # autograd.grad computes only d(loss)/d(x_adv): nothing is accumulated
    # into the model parameters' .grad and no graph has to be retained.
    (grad,) = torch.autograd.grad(loss, x_adv)

    x_new = x_adv.detach() + eps * grad.sign()
    return x_new.requires_grad_(True)

def pgd_untargeted(model, x, y, k, eps, eps_step):
    """Untargeted L-infinity PGD attack built from repeated FGSM steps.

    Starting from ``x``, takes up to ``k`` FGSM steps of size ``eps_step``.
    After each step the candidate is projected onto the L-inf ball of radius
    ``eps`` around ``x``, intersected with the valid pixel range [0, 1]. The
    attack stops as soon as the model's prediction differs from ``y``.

    Args:
        model: classifier; it is left in eval mode.
        x: a single clean input with pixels in [0, 1]; not modified.
        y: ground-truth label for ``x``.
        k: maximum number of FGSM steps.
        eps: radius of the projection ball in raw [0, 1] pixel units. The
            model normalizes internally, so no conversion is needed here.
        eps_step: step size of each FGSM iteration.

    Returns:
        ``(x_adv, y_pred)``: the final adversarial candidate (detached, on
        ``device``) and the model's prediction for it. If ``k == 0`` the
        unmodified input and ``y`` are returned.
    """
    model.eval()

    # Detached copy of the clean image serves as the projection centre; the
    # caller's tensor is never touched.
    x_orig = x.detach().to(device)
    # Projection box = L-inf ball around x_orig intersected with [0, 1].
    lower = (x_orig - eps).clamp(0.0, 1.0)
    upper = (x_orig + eps).clamp(0.0, 1.0)

    x_adv = x_orig.clone()
    y_pred = y
    for _ in range(k):
        # Hand fgsm a fresh leaf that requires grad; detaching after every
        # step keeps the autograd graph from growing across iterations.
        x_adv = fgsm(model, x_adv.requires_grad_(True), y, eps_step)
        x_adv = torch.clamp(x_adv.detach(), min=lower, max=upper)
        y_pred = get_model_prediction(model, x_adv)
        if y_pred != y:
            break
    return x_adv.detach(), y_pred
    

# Implement Adversarial Training

In [22]:
def train_model(model, num_epochs, enable_defense=True, attack='pgd', eps=0.1,
                pgd_steps=10, pgd_step_size=0.01):
    """Train ``model`` on MNIST, optionally with PGD adversarial training.

    Args:
        model: network to train in place.
        num_epochs: number of passes over ``train_loader``.
        enable_defense: if True, each batch is replaced by PGD adversarial
            examples before the update (adversarial training); otherwise the
            clean batch is used (standard training).
        attack: attack used for adversarial training; only ``'pgd'`` is
            implemented.
        eps: L-inf radius of the PGD attack, in raw [0, 1] pixel units.
        pgd_steps: number of PGD iterations per example.
        pgd_step_size: FGSM step size inside PGD.

    Raises:
        ValueError: if ``enable_defense`` is set and ``attack`` is not
            ``'pgd'``.
    """
    if enable_defense and attack != 'pgd':
        raise ValueError(f"Unsupported attack: {attack!r} (only 'pgd' is implemented)")

    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}')
        running_loss = 0.0
        for batch_idx, (inputs, labels) in enumerate(train_loader):
            # Keep the data on the same device as the model.
            inputs, labels = inputs.to(device), labels.to(device)
            if enable_defense:
                # pgd_untargeted attacks one image at a time. The adversarial
                # batch is treated as fixed training data, so it is detached
                # from the attack's autograd graph.
                inputs = torch.stack([
                    pgd_untargeted(model, image, label, pgd_steps, eps, pgd_step_size)[0]
                    for image, label in zip(inputs, labels)
                ]).detach()
            # The attack switches the model to eval mode; re-enter training mode.
            model.train()
            outputs = model(inputs)
            optimizer.zero_grad()
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Report the mean loss over every block of 100 batches.
            if batch_idx % 100 == 99:
                print(f'Average loss per batch: {running_loss/100}')
                running_loss = 0.0

In [23]:
def print_images(x, y, x_adv, y_adv, label=""):
    """Save a side-by-side plot of a clean image and its adversarial version.

    The figure is written to ``adversarial_examples/{label}.pdf``; the
    directory is created if it does not exist.

    Args:
        x: clean image tensor; expected to squeeze to a 2-D grayscale image
            (e.g. 1x28x28) — assumption based on the ``imshow`` call.
        y: label shown for the clean image.
        x_adv: adversarial image tensor, same shape as ``x``.
        y_adv: label shown for the adversarial image.
        label: base file name of the saved PDF.
    """
    import os

    # .cpu() so GPU tensors can also be converted to NumPy.
    x = x.squeeze().detach().cpu().numpy()
    x_adv = x_adv.squeeze().detach().cpu().numpy()
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    axes[0].imshow(x, cmap='gray')
    axes[0].set_title('Original Image', fontsize=20)
    axes[0].axis('off')

    axes[1].imshow(x_adv, cmap='gray')
    axes[1].set_title(f'Adversarial Image: {y} → {y_adv}', fontsize=20)
    axes[1].axis('off')

    fig.tight_layout()
    os.makedirs('adversarial_examples', exist_ok=True)
    fig.savefig(f'adversarial_examples/{label}.pdf', dpi=300)
    # Close (not just clear) the figure: plt.subplots opens a new figure on
    # every call and pyplot keeps every open figure alive.
    plt.close(fig)

# Study Accuracy, Quality, etc.

Compare the various results and report your observations in the submission.

In [24]:
def standard_accuracy(model):
    """Compute, print and return the clean accuracy (%) on ``test_loader``.

    Leaves ``model`` in eval mode.
    """
    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            # Evaluate on the same device as the model.
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Named ``accuracy`` so the local does not shadow this function's name.
    accuracy = 100 * correct / total
    print(f'Standard Accuracy: {accuracy}%')
    return accuracy

In [None]:
## train the original model (standard, non-adversarial training)
model = nn.Sequential(Normalize(), Net())
model = model.to(device)
model.train()

if train_fresh:
    train_model(model, 20, enable_defense=False)
    torch.save(model.state_dict(), 'weights.pt')
else:
    # map_location lets weights saved on a GPU be loaded on a CPU-only
    # machine (and vice versa).
    model.load_state_dict(torch.load('weights.pt', map_location=device))

In [26]:
class Interval:
    """A closed interval [l, u] for interval (box) analysis.

    ``l`` and ``u`` are either Python numbers or tensors of the same shape.
    In the tensor case every element is its own interval, so one
    ``Interval`` can represent the bounds of a whole layer.
    """

    def __init__(self, lower, upper):
        self.l = lower
        self.u = upper
        if isinstance(lower, torch.Tensor):
            assert isinstance(upper, torch.Tensor)
            assert torch.all(torch.le(lower, upper)), f'Invalid Interval: lower bound should be less than or equal to the upper bound'
        else:
            assert lower <= upper, f'Invalid Interval: lower bound should be less than or equal to the upper bound'

    def __add__(self, other):
        # [a, b] + [c, d] = [a + c, b + d]
        if not isinstance(other, Interval):
            # Let Python raise TypeError instead of silently yielding None.
            return NotImplemented
        return Interval(self.l + other.l, self.u + other.u)

    def __sub__(self, other):
        # [a, b] - [c, d] = [a - d, b - c]
        if not isinstance(other, Interval):
            return NotImplemented
        return Interval(self.l - other.u, self.u - other.l)

    def constMul(self, c):
        """Multiply by a scalar constant ``c``; a negative ``c`` swaps the bounds."""
        if c >= 0:
            return Interval(self.l * c, self.u * c)
        return Interval(self.u * c, self.l * c)

    def __neg__(self):
        return Interval(-self.u, -self.l)

    def relu(self):
        """Apply ReLU to both bounds (ReLU is monotone, so this is exact)."""
        if isinstance(self.l, torch.Tensor):
            # Builtin max() cannot compare a multi-element tensor with 0.
            return Interval(self.l.clamp(min=0), self.u.clamp(min=0))
        return Interval(max(0, self.l), max(0, self.u))

    def __iadd__(self, other):
        if not isinstance(other, Interval):
            return NotImplemented
        # Rebind instead of mutating: tensor ``+=`` works in place and would
        # corrupt any other Interval (or this one, when l is u) that shares
        # the same tensor object.
        self.l = self.l + other.l
        self.u = self.u + other.u
        return self

    def __gt__(self, other):
        # True only if every point of self exceeds every point of other.
        if not isinstance(other, Interval):
            return NotImplemented
        return self.l > other.u

    def __str__(self):
        return f'[{self.l}, {self.u}]'

    def __repr__(self):
        return self.__str__()

In [27]:
# Returns the model weights and biases in a dictionary
def get_model_parameters(model):
    """Group a model's parameters by layer.

    Returns a dict mapping each layer name (the parameter name up to its last
    dot, e.g. ``'1.fc1'``) to a dict ``{param_type: tensor}`` such as
    ``{'weight': ..., 'bias': ...}``. The tensors are ``param.data``: they
    carry no autograd history and share storage with the model.
    """
    layers = {}
    for full_name, param in model.named_parameters():
        layer, kind = full_name.rsplit('.', 1)
        layers.setdefault(layer, {})[kind] = param.data
    return layers

def get_next_layer_interval(layers, k, current_layer):
    """Propagate an interval through the fully connected layer ``fc{k}``.

    Reads the weight and bias of ``'1.fc{k}'`` from ``layers`` (as built by
    ``get_model_parameters``) and computes elementwise bounds on
    ``W @ x + b`` for every ``x`` inside ``current_layer``. For ``k != 4``
    the ReLU that follows the layer is applied to both bounds; layer 4 is the
    last layer and its bounds are returned without ReLU.

    Args:
        layers: dict of layer name -> {'weight': tensor, 'bias': tensor}.
        k: 1-based index of the layer to propagate through.
        current_layer: Interval whose bounds match the layer's input size
            (1-D tensors as used by the caller in this notebook).

    Returns:
        Interval with the layer's output bounds.
    """
    name = f'1.fc{k}'
    assert name in layers
    W = layers[name]['weight']
    b = layers[name]['bias']
    assert len(current_layer.l) == W.shape[1]

    # Split W into positive and negative parts (W = W_pos + W_neg).
    W_pos = W.clamp(min=0)
    W_neg = W.clamp(max=0)
    lo, hi = current_layer.l, current_layer.u

    # Lower bound: positive weights take the lower input bound and negative
    # weights the upper one. Upper bound: the other way round.
    out_lo = W_pos @ lo + W_neg @ hi + b
    out_hi = W_pos @ hi + W_neg @ lo + b

    is_last_layer = k == 4
    if not is_last_layer:
        out_lo, out_hi = out_lo.clamp(min=0), out_hi.clamp(min=0)
    return Interval(out_lo, out_hi)
                

def test_interval_analysis(model, eps, pgd_steps=10, pgd_step_size=0.005):
    """Certify L-inf robustness on the test set with interval analysis.

    For every test image, the L-inf ball of radius ``eps`` (in raw [0, 1]
    pixel units, intersected with the valid pixel range) is pushed through
    the network with interval arithmetic. An input is *provably robust* when
    the lower bound of the true-class logit exceeds the upper bound of every
    other logit. For inputs that cannot be certified, a PGD attack is run;
    if it changes the prediction, the input counts as *non-robust*.

    Args:
        model: an ``nn.Sequential(Normalize(), Net())`` model (layer names
            ``'1.fc1'``..``'1.fc4'`` are assumed by the propagation).
        eps: perturbation radius in raw pixel units (same units as PGD).
        pgd_steps: number of PGD steps for uncertified inputs.
        pgd_step_size: PGD step size.

    Returns:
        ``(robust_accuracy, non_robust_percentage)``, both as a percentage of
        the test set.
    """
    model.eval()
    layers = get_model_parameters(model)
    # The model's own normalization layer, reused so the bounds are
    # normalized exactly like real inputs (no duplicated constants).
    normalize = model[0]
    robust_inputs = 0
    non_robust_inputs = 0
    print(f'Testing for epsilon: {eps}')
    for image, label in test_dataset:
        x = image.view(-1).to(device)
        # eps is defined in raw pixel space, so perturb BEFORE normalizing.
        # (Adding +-eps to the normalized image would only certify a ball of
        # radius eps * 0.3081 in pixel space.) Normalization is an increasing
        # affine map, so normalizing the two bounds yields the exact box.
        lower = normalize((x - eps).clamp(0.0, 1.0))
        upper = normalize((x + eps).clamp(0.0, 1.0))
        current_layer = Interval(lower, upper)
        for k in [1, 2, 3, 4]:
            # Propagate the interval through each layer.
            current_layer = get_next_layer_interval(layers, k, current_layer)

        # Robust iff the true logit's lower bound beats every other logit's
        # upper bound.
        others_upper = torch.cat([current_layer.u[:label], current_layer.u[label + 1:]])
        robust = bool((current_layer.l[label] > others_upper).all())
        if robust:
            robust_inputs += 1
            assert get_model_prediction(model, image) == label, f'Robustness should imply correct prediction'
        else:
            _, y_adv = pgd_untargeted(model, image, label, pgd_steps, eps, pgd_step_size)
            if y_adv != label:
                non_robust_inputs += 1
    robust_accuracy = 100 * robust_inputs/len(test_dataset)
    non_robust_percentage = 100 * non_robust_inputs/len(test_dataset)
    print(f'\tNetwork is robust for {robust_inputs} inputs out of {len(test_dataset)}, {non_robust_inputs} images are not robust')
    print(f'\tRobust Accuracy: {robust_accuracy}%, Non-Robust Percentage: {non_robust_percentage}%')
    return robust_accuracy, non_robust_percentage

def plot_accuracy(robust_accuracies_std, robust_accuraceies_AT, x_points, x_label, filename='default', y_label='Provable Robustness (%)'):
    """Plot a standard-training curve against an adversarial-training curve.

    Both series are drawn over ``x_points`` on pyplot's current figure, which
    is saved to ``{filename}.pdf`` and then cleared for reuse. The parameter
    ``robust_accuraceies_AT`` keeps its original spelling so keyword callers
    continue to work.
    """
    series = (
        (robust_accuracies_std, 'Standard Training', 'o'),
        (robust_accuraceies_AT, 'Adversarial Training', 'x'),
    )
    for values, legend_name, marker in series:
        plt.plot(x_points, values, label=legend_name, marker=marker)

    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.legend()
    plt.tight_layout()
    plt.savefig(f'{filename}.pdf', dpi=300)
    plt.clf()

In [None]:
# Clean (standard) accuracy of the normally trained model.
standard_accuracy(model)

# Sweep 1: fine epsilon grid 0.001, 0.002, ..., 0.010.
eps_values = [i * 0.001 for i in range(1, 11)]
robust_accuracies_std, non_robust_std = [], []
for eps in eps_values:
    r, nr = test_interval_analysis(model, eps)
    robust_accuracies_std += [r]
    non_robust_std += [nr]

# Sweep 2: coarse epsilon grid 0.01, 0.02, ..., 0.10.
eps_values = [i * 0.01 for i in range(1, 11)]
robust_accuracies_std_, non_robust_std_ = [], []
for eps in eps_values:
    r, nr = test_interval_analysis(model, eps)
    robust_accuracies_std_ += [r]
    non_robust_std_ += [nr]

In [None]:
## PGD based adversarial training
model = nn.Sequential(Normalize(), Net())
# Place the model on the configured device, like the standard model;
# training and evaluation feed it data on ``device``.
model = model.to(device)
eps = 0.1

if train_fresh:
    train_model(model, 20, True, 'pgd', eps)
    torch.save(model.state_dict(), 'weights_AT.pt')
else:
    # map_location lets weights saved on a GPU be loaded on a CPU-only
    # machine (and vice versa).
    model.load_state_dict(torch.load('weights_AT.pt', map_location=device))

# Compute the standard accuracy
standard_accuracy(model)

# Robust accuracy for different epsilons (fine grid 0.001 ... 0.010)
robust_accuracies_AT = []
non_robust_AT = []
eps_values = [i * 0.001 for i in range(1, 11)]
for eps in eps_values:
    r, nr = test_interval_analysis(model, eps)
    robust_accuracies_AT.append(r)
    non_robust_AT.append(nr)
plot_accuracy(robust_accuracies_std, robust_accuracies_AT, eps_values, 'Epsilon', 'eps_trend_0.001')
plot_accuracy(non_robust_std, non_robust_AT, eps_values, 'Epsilon', 'eps_trend_nr_0.001', y_label='Non Robust Images (%)')

# Coarse grid 0.01 ... 0.10
robust_accuracies_AT_ = []
non_robust_AT_ = []
eps_values = [i * 0.01 for i in range(1, 11)]
for eps in eps_values:
    r, nr = test_interval_analysis(model, eps)
    robust_accuracies_AT_.append(r)
    non_robust_AT_.append(nr)
plot_accuracy(robust_accuracies_std_, robust_accuracies_AT_, eps_values, 'Epsilon', 'eps_trend_0.01')
plot_accuracy(non_robust_std_, non_robust_AT_, eps_values, 'Epsilon', 'eps_trend_nr_0.01', y_label='Non Robust Images (%)')
