In [1]:
import os
import random

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt

import torch
torch.manual_seed(17)
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision.transforms as transforms
import torchvision.datasets as datasets

ModuleNotFoundError: No module named 'matplotlib'

In [16]:
# Hyperparameters / runtime configuration, then CIFAR-10 download and dataloader setup
BATCH_SIZE = 256
LEARNING_RATE = 0.001
NUM_EPOCHS = 10
DEVICE = 'cpu'

# Both splits get the same preprocessing: tensor conversion followed by
# per-channel normalization
train_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Official training split, carved into 40000 training / 10000 validation samples
full_train_dataset = datasets.CIFAR10(
    root='data',
    train=True,
    transform=train_transforms,
    download=True,
)
train_dataset, validation_dataset = torch.utils.data.random_split(
    full_train_dataset,
    [40000, 10000],
)

# Official held-out test split
test_dataset = datasets.CIFAR10(root='data', train=False, transform=test_transforms)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=2,
    shuffle=True,
)
validation_loader = DataLoader(
    dataset=validation_dataset,
    batch_size=BATCH_SIZE,
    num_workers=2,
    shuffle=True,
)
# No batch_size given, so the DataLoader default of one sample per batch applies;
# attack_test relies on this because it calls .item() on per-batch predictions
test_loader = DataLoader(dataset=test_dataset, num_workers=2, shuffle=False)

# Human-readable class names, indexed by integer label
labels = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

Files already downloaded and verified


In [None]:
# Train model
def train(model, weight_decay=0):
    """Train ``model`` for ``NUM_EPOCHS`` epochs and record the loss curves.

    Uses the module-level ``train_loader`` / ``validation_loader``, Adam with
    ``LEARNING_RATE`` and cross-entropy loss. After each epoch the average
    losses and the train/validation accuracies are printed.

    Args:
        model: Network to optimise in place. It is expected to already live on
            ``DEVICE``, since batches are moved there before the forward pass.
        weight_decay: L2 penalty passed to the Adam optimizer.

    Returns:
        Tuple ``(train_losses, validation_losses)``: two lists holding the
        per-sample average loss of each epoch.
    """
    # Define loss function
    criterion = nn.CrossEntropyLoss()
    # Define optimizer
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999), weight_decay=weight_decay)

    # Store losses to plot after training finishes
    train_losses = []
    validation_losses = []

    for epoch in range(1, NUM_EPOCHS+1):
        # Track training + validation loss
        train_loss = 0.0
        validation_loss = 0.0

        # Train the model (get_accuracy below switches to eval mode, so re-enable every epoch)
        model.train()
        for data, target in train_loader:
            # Same device handling as get_accuracy
            data, target = data.to(DEVICE), target.to(DEVICE)
            # Clear gradients
            optimizer.zero_grad()
            # Forward pass - compute predictions by passing input through model
            output = model(data)
            # Calculate loss
            loss = criterion(output, target)
            # Backpropagation: compute gradient of loss w/ respect to model parameters
            loss.backward()
            # Backpropagation: update parameters using loss gradient
            optimizer.step()
            # criterion returns the batch mean; weight by batch size so the
            # epoch average stays correct when the last batch is smaller
            train_loss += loss.item()*data.size(0)

        # Measure loss on the validation set to monitor overfitting
        model.eval()
        # No gradients are needed for evaluation; skipping the autograd graph saves memory and time
        with torch.no_grad():
            for data, target in validation_loader:
                data, target = data.to(DEVICE), target.to(DEVICE)
                # Forward pass - compute predictions by passing input through model
                output = model(data)
                # Calculate loss
                loss = criterion(output, target)
                # Update validation loss
                validation_loss += loss.item()*data.size(0)

        # Calculate average train and validation losses
        train_loss = train_loss/len(train_loader.dataset)
        validation_loss = validation_loss/len(validation_loader.dataset)
        train_losses.append(train_loss)
        validation_losses.append(validation_loss)

        # Display training and validation loss and accuracy every epoch
        train_accuracy = get_accuracy(model, train_loader, DEVICE)
        validation_accuracy = get_accuracy(model, validation_loader, DEVICE)
        print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f} \tTraining Accuracy: {:.6f} \tValidation Accuracy: {:.6f}'.format(
            epoch, train_loss, validation_loss, train_accuracy, validation_accuracy))
    return train_losses, validation_losses

In [None]:

# Function to get accuracy of model
def get_accuracy(model, data_loader, device):
    """Return the classification accuracy of ``model`` over ``data_loader``.

    The model is switched to eval mode (and left there) and evaluated without
    gradient tracking.

    Args:
        model: Callable module mapping a batch of features to per-class scores
            of shape ``(batch, num_classes)``.
        data_loader: Iterable of ``(features, targets)`` batches exposing a
            ``dataset`` attribute whose length is the number of samples.
        device: Device the batches are moved to before the forward pass.

    Returns:
        0-dim float tensor with the accuracy in percent. If the dataset is
        empty the result is ``nan`` (0/0) instead of an exception.
    """
    # Start from a tensor (not the int 0) so .float() below also works when
    # the loader yields no batches
    num_correct = torch.zeros((), dtype=torch.long, device=device)
    model.eval()
    with torch.no_grad():
        for features, targets in data_loader:
            features = features.to(device)
            targets = targets.to(device)

            predictions = model(features)
            # Index of the highest score along the class dimension
            predicted_labels = predictions.max(dim=1).indices

            # Increment correct count by number of correct predictions in batch
            num_correct += (predicted_labels == targets).sum()
        return num_correct.float()/len(data_loader.dataset) * 100

In [None]:
# Train model and save weights to a file

model_vanilla = LeNetVanilla()
train_losses, validation_losses = train(model_vanilla)

MODEL_PATH = "models/"
# torch.save does not create missing parent directories, so make sure the
# target folder exists before writing the checkpoint
os.makedirs(MODEL_PATH, exist_ok=True)
torch.save({'state_dict': model_vanilla.state_dict()}, MODEL_PATH + "LeNet_vanilla.pth")

# Number epochs from 1 so the x-axis matches the epoch numbers printed by train()
epochs = range(1, len(train_losses) + 1)
plt.plot(epochs, train_losses, color='blue', label='Train Loss')
plt.plot(epochs, validation_losses, color='green', label='Validation Loss')
leg = plt.legend(loc='upper center')

plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Performance of Vanilla LeNet model during training")
plt.show()

In [None]:
# Modify image using FGSM attack
def fgsm_perturb(image, epsilon, gradient, clip_min=0, clip_max=1):
    """Apply one Fast Gradient Sign Method step to ``image``.

    Every element is moved by ``epsilon`` in the direction of the sign of
    ``gradient``, and the result is clipped to ``[clip_min, clip_max]``.

    Args:
        image: Input tensor to perturb.
        epsilon: Step size, in the same units as ``image``.
        gradient: Gradient of the loss with respect to ``image`` (same shape).
        clip_min: Lower bound of the valid value range.
        clip_max: Upper bound of the valid value range.

    Returns:
        The perturbed, clipped tensor (same shape as ``image``).

    Note:
        The defaults assume pixel values in [0, 1]. Inputs that went through
        ``transforms.Normalize`` have a different valid range, so callers
        working on normalized data should pass the matching bounds.
        NOTE(review): both bounds should be plain numbers, or both tensors
        broadcastable to ``image``; mixing the two may not be accepted by
        ``torch.clamp`` -- confirm for the installed torch version.
    """
    # Modify image by adjusting all of the pixels
    perturbed_image = image + epsilon*gradient.sign()
    # Clip back into the valid value range
    perturbed_image = torch.clamp(perturbed_image, clip_min, clip_max)
    return perturbed_image

In [None]:
def igsm_perturb(model, image, target, epsilon, alpha, iters):
    """Iterative gradient-sign attack towards the least-likely class.

    Starting from ``image``, repeatedly takes a step of size ``alpha`` that
    lowers the loss of the class the model initially scores lowest, keeping
    the result inside an ``epsilon`` ball around the ORIGINAL image and inside
    the [0, 1] value range. The attack stops early as soon as the model's
    prediction differs from ``target``.

    Args:
        model: Module producing per-class scores of shape ``(1, num_classes)``.
        image: Single-sample input batch. It is neither modified nor flagged
            as requiring gradients; the attack works on a private copy.
        target: True label of the sample (single-element tensor).
        epsilon: Maximum per-element perturbation on the 0-255 pixel scale
            (divided by 255 internally).
        alpha: Step size of each iteration, on the [0, 1] scale.
        iters: Number of iterations; ``0`` selects
            ``int(min(epsilon + 4, 1.25 * epsilon))``.

    Returns:
        Detached adversarial tensor with the same shape as ``image``.

    Note:
        The [0, 1] clipping assumes un-normalized pixel values.
    """
    # Private copies: the caller's tensor (and any gradient already
    # accumulated on it) must not leak into the attack
    original = image.detach().clone()
    adv_image = original.clone()

    # Forward pass, get least likely class (index of the lowest score)
    with torch.no_grad():
        ll_label = torch.min(model(original), 1)[1]

    if iters == 0:
        # Choose the iteration count based on the epsilon value
        iters = int(min(epsilon + 4, 1.25*epsilon))

    # Data is in range [0,1]: scale epsilon down by 255
    epsilon = epsilon/255

    # Bounds of the epsilon ball. They are anchored on the original image so
    # the total perturbation can never exceed epsilon, however many steps run
    lower = torch.clamp(original - epsilon, min=0)
    upper = original + epsilon

    for _ in range(iters):
        # Prepare to use input gradient
        adv_image.requires_grad_(True)

        # Forward pass
        output = model(adv_image)
        init_pred = output.max(1, keepdim=True)[1]

        # Stop as soon as the prediction is wrong
        if init_pred.item() != target.item():
            return adv_image.detach()

        # cross_entropy matches nll_loss when the model already emits
        # log-probabilities and is also correct when it emits raw logits
        loss = F.cross_entropy(output, ll_label)
        model.zero_grad()

        # Backward pass to calculate gradients
        loss.backward()

        # Element-wise sign of the input gradient
        sign_data_grad = adv_image.grad.sign()
        # Step AGAINST the gradient: we want to decrease the least-likely-class loss
        perturbed_image = adv_image.detach() - alpha*sign_data_grad

        # Project back into the epsilon ball and the valid pixel range
        adv_image = torch.clamp(torch.min(torch.max(perturbed_image, lower), upper), max=1)

    return adv_image

In [None]:
def attack_test(model, device, test_loader, epsilon, alpha=0.006, iters=20, attack="fgsm"):
    """Measure model accuracy on adversarially perturbed test samples.

    Each sample the model classifies correctly is attacked with FGSM or IGSM
    and re-classified. Samples that are already misclassified are not attacked
    and count as errors.

    Args:
        model: Classifier under attack.
        device: Device the batches are moved to.
        test_loader: Loader yielding ONE sample per batch (predictions are
            read with ``.item()``).
        epsilon: Attack strength, passed unchanged to the attack function.
        alpha: IGSM step size (ignored for FGSM).
        iters: IGSM iteration count (ignored for FGSM).
        attack: ``"fgsm"`` for FGSM; any other value selects IGSM.

    Returns:
        Tuple ``(accuracy, sample_images)``. ``accuracy`` is the fraction of
        all batches still classified correctly after the attack.
        ``sample_images`` holds up to five
        ``(initial_label, adversarial_label, image_array)`` tuples of
        successful attacks.
    """
    # Keep track of correctly classified examples
    correct = 0
    sample_images = []

    # Go through all test images
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)

        # Preparing to take the gradient w/ respect to the image
        data.requires_grad = True

        # Forward pass
        output = model(data)
        # Get highest score as prediction
        initial_pred = output.max(1, keepdim=True)[1]

        # If the initial prediction was wrong, we don't have to attack this sample
        if initial_pred.item() != target.item():
            continue

        if attack=="fgsm":
            # cross_entropy matches nll_loss when the model already emits
            # log-probabilities and is also correct when it emits raw logits
            loss = F.cross_entropy(output, target)
            model.zero_grad()

            # Calculate gradients w/ respect to loss
            loss.backward()

            # Get data gradient
            data_grad = data.grad.data

            # Call FGSM Attack
            perturbed_data = fgsm_perturb(data, epsilon, data_grad)
        else:
            # IGSM computes its own gradients. Hand it a detached tensor so no
            # gradient from this function is accumulated into its first step
            perturbed_data = igsm_perturb(model, data.detach(), target, epsilon, alpha, iters)

        # Run perturbed image through model, get prediction (no gradients needed)
        with torch.no_grad():
            output = model(perturbed_data)
        final_pred = output.max(1, keepdim=True)[1]

        if final_pred.item() == target.item():
            correct += 1
        else:
            # Save samples to display after predictions
            if len(sample_images) < 5:
                img_ptg = perturbed_data.squeeze().detach().cpu().numpy()
                sample_images.append((initial_pred.item(), final_pred.item(), img_ptg))

    # Calculate final accuracy for this value of epsilon; an empty loader
    # yields 0.0 instead of a ZeroDivisionError
    total = len(test_loader)
    accuracy = correct/float(total) if total else 0.0
    print(f"Epsilon: {epsilon}\tTest Accuracy = {accuracy}")

    # Return accuracy and some images to display
    return accuracy, sample_images

In [None]:
# Report accuracy on the clean (unattacked) test set; gradients are not needed here
with torch.set_grad_enabled(False):
    print('Test accuracy (benign images): %.2f%%' % (get_accuracy(model_vanilla, test_loader, DEVICE)))

In [None]:
# Attack strengths to sweep. FGSM uses epsilon directly as the step size,
# while igsm_perturb divides its epsilon by 255 before use.
epsilons_fgsm = [.05, .12, .3]
epsilons_igsm = [2, 4, 8, 16]

# Per-epsilon results: post-attack accuracy and a few adversarial samples
accuracies_fgsm, accuracies_igsm = [], []
examples_fgsm, examples_igsm = [], []

# FGSM sweep over the vanilla model
print("fgsm testing")
for eps in epsilons_fgsm:
    acc, ex = attack_test(model_vanilla, DEVICE, test_loader, epsilon=eps, attack="fgsm")
    accuracies_fgsm += [acc]
    examples_fgsm += [ex]

# IGSM sweep over the vanilla model (step size 0.006, 20 iterations)
print("igsm testing")
for eps in epsilons_igsm:
    acc, ex = attack_test(model_vanilla, DEVICE, test_loader, epsilon=eps, alpha=0.006, iters=20, attack="igsm")
    accuracies_igsm += [acc]
    examples_igsm += [ex]



In [None]:
# Display some adversarially modified images at different epsilons
def showImages(examples, epsilons, title):
    """Plot a grid of adversarial examples, one row per epsilon.

    Args:
        examples: One list per epsilon of
            ``(original_label, adversarial_label, image)`` tuples, where
            ``image`` is a channel-first array (transposed to channel-last
            for display).
        epsilons: Epsilon values matching the rows of ``examples``.
        title: Figure title.
    """
    n_rows = len(epsilons)
    # Widest row decides the number of columns, so rows with fewer samples
    # stay aligned; at least one column keeps the grid valid
    n_cols = max(1, max((len(row) for row in examples), default=0))

    plt.figure(figsize=(8,10))
    # Iterate over the epsilons that were actually passed in (not a global list)
    for i, (eps, row) in enumerate(zip(epsilons, examples)):
        for j, (original, adv, example) in enumerate(row):
            plt.subplot(n_rows, n_cols, i*n_cols + j + 1)
            plt.xticks([], [])
            plt.yticks([], [])
            # Add epsilon label to the beginning of each row
            if j == 0:
                plt.ylabel("Eps: {}".format(eps), fontsize=14)
            plt.title("{} -> {}".format(labels[original], labels[adv]))
            # Channel-first -> channel-last, as expected by imshow
            example = example.transpose(1, 2, 0)
            plt.imshow(example, cmap="gray")
    plt.tight_layout()
    plt.suptitle(title)
    plt.show()

In [None]:
# Visualise successful attacks against the undefended model, one figure per attack type
showImages(
    examples=examples_fgsm,
    epsilons=epsilons_fgsm,
    title="FGSM attacked images on no-defense model",
)
showImages(
    examples=examples_igsm,
    epsilons=epsilons_igsm,
    title="IGSM attacked images on no-defense model",
)

Fine-tune the model using adversarial training

In [None]:
# Train model, replacing a batch by its FGSM-perturbed version 75% of the time (by default)
def train_adversarial(model, weight_decay=0, epsilon=0.12, adversarial_prob=0.75):
    """Adversarially train ``model`` with FGSM-perturbed batches.

    For every training batch, with probability ``adversarial_prob`` the whole
    batch is replaced by its FGSM perturbation (computed against the current
    model) before the usual optimisation step. Validation uses clean data.
    Uses the module-level loaders, ``LEARNING_RATE``, ``NUM_EPOCHS`` and
    ``DEVICE``.

    Args:
        model: Network to optimise in place. It is expected to already live on
            ``DEVICE``, since batches are moved there before the forward pass.
        weight_decay: L2 penalty passed to the Adam optimizer.
        epsilon: FGSM step size handed to ``fgsm_perturb``.
        adversarial_prob: Probability that a given batch is perturbed.

    Returns:
        Tuple ``(train_losses, validation_losses)``: two lists holding the
        per-sample average loss of each epoch.
    """
    # Define loss function
    criterion = nn.CrossEntropyLoss()
    # Define optimizer
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999), weight_decay=weight_decay)

    # Store losses to plot after training finishes
    train_losses = []
    validation_losses = []

    for epoch in range(1, NUM_EPOCHS+1):
        # Track training + validation loss
        train_loss = 0.0
        validation_loss = 0.0

        # Train the model (get_accuracy below switches to eval mode, so re-enable every epoch)
        model.train()
        for data, target in train_loader:
            data, target = data.to(DEVICE), target.to(DEVICE)

            # Perturb the batch with probability adversarial_prob
            if random.random() < adversarial_prob:
                # Preparing to take the gradient w/ respect to the images
                data.requires_grad = True

                # Use the training criterion for the attack gradient as well,
                # so attack and training agree on what the model outputs mean
                attack_loss = criterion(model(data), target)
                model.zero_grad()

                # Calculate gradients w/ respect to loss
                attack_loss.backward()

                # FGSM step; detach so the training step below does not
                # backpropagate into the attack's graph
                data = fgsm_perturb(data, epsilon, data.grad.data).detach()

            # Clear gradients (also discards parameter gradients left by the attack)
            optimizer.zero_grad()
            # Forward pass
            output = model(data)
            # Calculate loss
            loss = criterion(output, target)
            # Backwards pass
            loss.backward()
            # Update parameters
            optimizer.step()
            # criterion returns the batch mean; weight by batch size so the
            # epoch average stays correct when the last batch is smaller
            train_loss += loss.item()*data.size(0)

        # Measure loss on the clean validation set to monitor overfitting
        model.eval()
        # No gradients are needed for evaluation
        with torch.no_grad():
            for data, target in validation_loader:
                data, target = data.to(DEVICE), target.to(DEVICE)
                # Forward pass
                output = model(data)
                # Calculate loss
                loss = criterion(output, target)
                # Update validation loss
                validation_loss += loss.item()*data.size(0)

        # Calculate average train and validation losses
        train_loss = train_loss/len(train_loader.dataset)
        validation_loss = validation_loss/len(validation_loader.dataset)
        train_losses.append(train_loss)
        validation_losses.append(validation_loss)

        # Display training and validation loss and accuracy every epoch
        train_accuracy = get_accuracy(model, train_loader, DEVICE)
        validation_accuracy = get_accuracy(model, validation_loader, DEVICE)
        print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f} \tTraining Accuracy: {:.6f} \tValidation Accuracy: {:.6f}'.format(
            epoch, train_loss, validation_loss, train_accuracy, validation_accuracy))
    return train_losses, validation_losses