# Resisting Adversarial Attacks by kWTA Activation
*ICLR Reproducibility Challenge*

**CS4803/7643 Spring 2020 Final Project**

By: Baran Usluel and Ilya Golod

## Our Plan

*Will delete this cell later. For now, linking some resources and our guiding plan (copied from proposal spreadsheet):*

**Paper:** https://arxiv.org/abs/1905.10510

**ICLR Submission Review:** https://openreview.net/forum?id=Skgvy64tvr

**Paper's Github:** https://github.com/a554b554/kWTA-Activation

**Project Summary:**

Implement the k-WTA activation function as described in *Enhancing Adversarial Defense by k-Winners-Take-All* (ICLR 2020). Reproduce the paper's empirical results (test accuracy and adversarial robustness) across the network architectures, training methods and adversarial attacks it reports. Possibly also test in settings the paper does not explore.

**Follow-up:**

The k-WTA activation function will be implemented in PyTorch and swapped into pretrained models. The architectures tested will include those in the paper (ResNet, DenseNet and Wide ResNet) and possibly other relevant models (SqueezeNet, AlexNet, VGG and so on).

The white-box attack model will be examined, since it is the main focus of the original paper. Specifically, the attacks considered are: vanilla gradient ascent (as already implemented in PS2), projected gradient descent (PGD), DeepFool, Carlini-Wagner, the Momentum Iterative Method, and possibly other state-of-the-art gradient-based adversarial attacks available in the Foolbox library.

Since we are using pretrained models, we will only consider attacks on regularly trained models and not explore adversarial training. Note that the authors claim similar improvements with k-WTA across various training methods.

Finally, we will be using the CIFAR10 dataset for the image classification task.

## Setup

In [2]:
# If using colab:
# Mounts google drive folder so we can save/load files.
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [3]:
# Note to team members:
# If you get a 'No such file' error when you run this, go to your Google Drive,
# find the CS4803 Project folder under Shared With Me, right-click and select
# Add Shortcut To Drive. This will make the path accessible.
DATA_DIRECTORY = "gdrive/My Drive/CS4803 Project/"
import os
print(os.listdir(DATA_DIRECTORY))

['models', 'data', 'kWTA Activation.ipynb']


In [0]:
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets
import torchvision.transforms as T

import matplotlib
import matplotlib.pyplot as plt

#!pip install foolbox
import foolbox
#import eagerpy as ep

from PIL import Image

import copy

In [5]:
# For massive speed-up, ensure GPU is selected from Runtime -> Change runtime type.
# Using hardware acceleration:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Assuming that we are on a CUDA machine, this should print a CUDA device:
print(device)

cuda:0


## Helper Functions

In [0]:
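def imshow(img):
    # Our transforms normalize with mean=0, std=1 (see Load Dataset), which leaves
    # pixels in [0, 1], so no unnormalization is needed. (The `img / 2 + 0.5` from
    # the PyTorch CIFAR10 tutorial only inverts mean=std=0.5.)
    npimg = img.detach().cpu().numpy()
    # Clip because difference images can fall outside [0, 1]; negative values show as black
    plt.imshow(np.clip(np.transpose(npimg, (1, 2, 0)), 0, 1))
    plt.show()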

## k-WTA Activation Function
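For reference, this is the activation as we read the paper's definition, where $\mathbf{y} \in \mathbb{R}^N$ is the flattened vector of one sample's activations and $\gamma$ is the sparsity ratio (`sr` in the code below). One small difference: the implementation keeps every entry greater than or equal to the $k$-th largest value, so slightly more than $k$ entries survive if there are exact ties.

```latex
\phi_k(\mathbf{y})_j =
\begin{cases}
  y_j, & \text{if } y_j \text{ is among the } k \text{ largest entries of } \mathbf{y},\\
  0,   & \text{otherwise,}
\end{cases}
\qquad k = \lfloor \gamma N \rfloor .
```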

In [0]:
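class kWTA(nn.Module):
    """k-Winners-Take-All activation.

    For each sample, keeps the k largest activations (across all channels and
    spatial positions) and sets the rest to zero, with k = int(sr * N), where
    N is the number of activations per sample and sr is the sparsity ratio.
    """
    def __init__(self, sr):
        super(kWTA, self).__init__()
        self.sr = sr

    # Modified version of the paper's forward implementation
    def forward(self, x):
        # Custom code to work with any input shape: flatten each sample
        tmpx = x.view(x.shape[0], -1)
        size = tmpx.shape[1]
        k = max(1, int(self.sr * size))  # guard against k == 0 for tiny layers
        # Directly taken from the paper's implementation:
        # the threshold is the k-th largest activation of each sample
        topval = tmpx.topk(k, dim=1)[0][:,-1]
        topval = topval.repeat(tmpx.shape[1], 1).permute(1,0).view_as(x)
        # 1 where x >= threshold (ties at the threshold are all kept), else 0
        comp = (x>=topval).to(x)
        return comp*x

    # TODO: Is there a more efficient way of computing this?

    # # An alternate implementation (scatter-based mask; keeps exactly k
    # # activations per sample and returns a new tensor instead of
    # # modifying x in place):
    # def forward(self, x):
    #     tmpx = x.view(x.shape[0], -1)
    #     size = tmpx.shape[1]
    #     k = int(self.sr * size)
    #     top_inds = tmpx.topk(k, dim=1)[1]
    #     mask = torch.zeros_like(tmpx, dtype=torch.bool)
    #     mask.scatter_(1, top_inds, True)
    #     return (tmpx * mask).view_as(x)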
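On the TODO above: one possible speed-up (a sketch, not benchmarked) is to keep each sample's k-th largest value with shape (batch, 1) and let broadcasting compare it against every activation, instead of materializing a full (batch, N) copy with `repeat(...).permute(...)`. `kWTABroadcast` is a hypothetical name; it uses `reshape` so non-contiguous inputs also work, and its semantics (ties at the threshold are kept) match the class above.

```python
import torch
import torch.nn as nn


class kWTABroadcast(nn.Module):
    """Hypothetical broadcasting variant of kWTA: per sample, keep activations
    >= the k-th largest value and zero the rest."""

    def __init__(self, sr):
        super().__init__()
        self.sr = sr

    def forward(self, x):
        flat = x.reshape(x.shape[0], -1)          # (batch, N); copies only if needed
        k = max(1, int(self.sr * flat.shape[1]))  # winners per sample
        kth = flat.topk(k, dim=1).values[:, -1:]  # (batch, 1) threshold per sample
        mask = (flat >= kth).to(x.dtype)          # (batch, 1) broadcasts against (batch, N)
        return (flat * mask).view_as(x)


x = torch.randn(4, 3, 8, 8)
y = kWTABroadcast(0.2)(x)
print((y != 0).sum(dim=(1, 2, 3)))  # expect int(0.2 * 192) = 38 per sample, barring exact ties
```

The mask is built from a comparison, so it carries no gradient: as with the original class, gradients flow only through the winning activations.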

In [0]:
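# Checking that the kWTA forward pass is correct:
# each sample of `a` has 5*5 = 25 activations, so sr=0.2 should keep k=5 of them.
kwta = kWTA(0.2)
a = torch.rand(2,5,5)
print(a)
b = kwta(a)  # calling the module runs forward()
print(b)
print((b != 0).sum(dim=(1, 2)))  # expected: tensor([5, 5])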
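The cell above relies on eyeballing the printed tensors. A plain NumPy reference implementation gives an independent oracle to compare against, for example with `np.allclose(b.numpy(), kwta_reference(a.numpy(), 0.2))`. `kwta_reference` is a hypothetical helper written for this check; like the class above, it keeps every entry that is at least the k-th largest value of its sample.

```python
import numpy as np


def kwta_reference(x, sr):
    """NumPy k-WTA used as a test oracle (hypothetical helper).
    x: array of shape (batch, ...); per sample, keeps entries >= the k-th
    largest value, with k = int(sr * entries_per_sample)."""
    flat = x.reshape(x.shape[0], -1)
    k = max(1, int(sr * flat.shape[1]))
    # np.sort is ascending, so the k-th largest value sits at index -k
    kth = np.sort(flat, axis=1)[:, -k][:, None]
    return np.where(flat >= kth, flat, 0.0).reshape(x.shape)


x = np.arange(50, dtype=float).reshape(2, 5, 5)
out = kwta_reference(x, 0.2)
print(np.count_nonzero(out.reshape(2, -1), axis=1))  # [5 5]
print(out[0].ravel()[-5:])                           # [20. 21. 22. 23. 24.]
```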

In [0]:
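# Recursively replaces (in place) every submodule of type `old_activation`
# in `model` with a new kWTA(sr) module.
# Note: only activations that are nn.Module children are replaced; activations
# applied functionally inside a model's forward() (e.g. F.relu) are left as is.
def activation_to_kwta(model, old_activation, sr=0.2):
    for child_name, child in model.named_children():
        if isinstance(child, old_activation):
            setattr(model, child_name, kWTA(sr))
        else:
            activation_to_kwta(child, old_activation, sr)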
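After calling `activation_to_kwta`, it can be worth confirming that no `nn.ReLU` modules remain. Below is a sketch with two hypothetical helpers: `count_modules`, and `replace_modules`, which uses the same recursion as `activation_to_kwta` but takes a factory (here `nn.Identity` stands in for `kWTA` so the example runs on its own). Caveat: a module count only sees `nn.Module` activations. torchvision's DenseNet, for example, also applies `F.relu` inside its `forward`, which neither this check nor `activation_to_kwta` touches.

```python
import torch.nn as nn


def count_modules(model, module_type):
    # Count submodules (including nested ones) that are instances of module_type
    return sum(isinstance(m, module_type) for m in model.modules())


def replace_modules(model, old_type, make_new):
    # Same recursion as activation_to_kwta, but the replacement comes from a factory
    for name, child in model.named_children():
        if isinstance(child, old_type):
            setattr(model, name, make_new())
        else:
            replace_modules(child, old_type, make_new)


# Toy network with a nested block, standing in for a torchvision model
toy = nn.Sequential(nn.Conv2d(3, 4, 3), nn.ReLU(),
                    nn.Sequential(nn.Conv2d(4, 4, 3), nn.ReLU()))
print(count_modules(toy, nn.ReLU))      # 2
replace_modules(toy, nn.ReLU, nn.Identity)
print(count_modules(toy, nn.ReLU))      # 0
print(count_modules(toy, nn.Identity))  # 2
```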

## Load Dataset

In [10]:
# NOTE on normalization values:
# T.Normalize computes (x - mean) / std per channel. The paper uses mean=0, std=1,
# so that's what we use here (for now). This is a no-op that keeps pixels in [0, 1],
# which is also what the Foolbox wrapper's bounds=(0,1) assumes.
# The PyTorch CIFAR10 tutorial instead uses mean=std=0.5, see https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
# And some sources claim other specific values:
#   https://github.com/kuangliu/pytorch-cifar/issues/19
#   https://github.com/kuangliu/pytorch-cifar/blob/master/main.py
MEAN = 0
STD = 1
INPUT_SIZE = 224
# Same transforms as the paper, except we also resize to 224x224, the input
# size the torchvision models were pretrained on (ImageNet)
transform_train = T.Compose(
    [T.RandomCrop(32, padding=4),
     T.RandomHorizontalFlip(),
     T.Resize(INPUT_SIZE),
     T.ToTensor(),
     T.Normalize((MEAN,MEAN,MEAN), (STD,STD,STD))])
transform_test = T.Compose(
    [T.Resize(INPUT_SIZE),
     T.ToTensor(),
     T.Normalize((MEAN,MEAN,MEAN), (STD,STD,STD))])

trainset = datasets.CIFAR10(root=DATA_DIRECTORY+'data', train=True, download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=16, shuffle=True, num_workers=2)

testset = datasets.CIFAR10(root=DATA_DIRECTORY+'data', train=False, download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=16, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified
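To make the normalization choice concrete: `T.Normalize(mean, std)` computes (x − mean) / std per channel, so mean=0, std=1 leaves `ToTensor()` output unchanged in [0, 1], while mean=std=0.5 maps it to [−1, 1), and `img / 2 + 0.5` is exactly the inverse of that second choice only. A NumPy sketch with hypothetical `normalize`/`unnormalize` helpers:

```python
import numpy as np


def normalize(img, mean, std):
    # Per-channel (x - mean) / std on a (C, H, W) array, as T.Normalize does
    mean = np.asarray(mean, dtype=float)[:, None, None]
    std = np.asarray(std, dtype=float)[:, None, None]
    return (img - mean) / std


def unnormalize(img, mean, std):
    # Inverse of normalize: x * std + mean
    mean = np.asarray(mean, dtype=float)[:, None, None]
    std = np.asarray(std, dtype=float)[:, None, None]
    return img * std + mean


rng = np.random.default_rng(0)
img = rng.random((3, 4, 4))  # pixel values in [0, 1), as after T.ToTensor()

same = normalize(img, (0, 0, 0), (1, 1, 1))        # this notebook: identity
half = normalize(img, (0.5,) * 3, (0.5,) * 3)      # tutorial: values in [-1, 1)
restored = unnormalize(half, (0.5,) * 3, (0.5,) * 3)  # equals half / 2 + 0.5
```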


## Load Saved Models

### Pretrained Models

In [0]:
# Change pretrained model download directory, so it doesn't
# download every time the runtime restarts
os.environ['TORCH_HOME'] = DATA_DIRECTORY + 'models/pretrained'

# We have four types of models stored here:
# - pretrained: Default models from pytorch, trained on ImageNet
# - relu: Fine-tuned models for CIFAR10
# - kwta_0_1: Models using kwta activation with sparsity=0.1 for CIFAR10
# - kwta_0_2: Models using kwta activation with sparsity=0.2 for CIFAR10
models = {'pretrained': {}, 'relu': {}, 'kwta_0_1': {}, 'kwta_0_2': {}}

# Download and load pretrained models (trained for ImageNet dataset)
models['pretrained']['resnet'] = torchvision.models.resnet18(pretrained=True)
models['pretrained']['densenet'] = torchvision.models.densenet121(pretrained=True)
models['pretrained']['wide_resnet'] = torchvision.models.wide_resnet50_2(pretrained=True)
models['pretrained']['vgg'] = torchvision.models.vgg11(pretrained=True)
models['pretrained']['alexnet'] = torchvision.models.alexnet(pretrained=True)
models['pretrained']['squeezenet'] = torchvision.models.squeezenet1_1(pretrained=True)

### Our Models

In [0]:
############################################################
# These are the models that we have trained and saved.     #
# Keep this list updated, along with EXPERIMENTAL RESULTS: #
############################################################

# Reference on finetuning models: https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html

# TODO: Models could probably use at least another epoch of training.

# TODO: Vanilla attacks should be rerun; they used trainloader instead of testloader.

## AlexNet ReLU
# Based on pretrained AlexNet.
# Trained for 2 epochs, test accuracy: 84.5%
# Vanilla attack on 10 minibatches:
#   Attacks Succeeded: 1426 / 1440 = 99.03 %
#   Robustness Accuracy: 4 / 1440 = 0.28 %
# Foolbox attacks on 20 minibatches:
#   PGD, Robustness Accuracy: 0 %
#   DeepFool, Robustness Accuracy: 0 %
models['relu']['alexnet'] = copy.deepcopy(models['pretrained']['alexnet'])
# NOTE: assigning out_features only changes the attribute; the layer keeps its
# 1000 ImageNet outputs (weight shape 1000x4096). The checkpoint loads without a
# size-mismatch error, so it appears to have been fine-tuned this way, with
# CIFAR10 labels 0-9 mapped to the first 10 logits.
models['relu']['alexnet'].classifier[-1].out_features = 10
models['relu']['alexnet'].load_state_dict(torch.load(
        DATA_DIRECTORY + 'models/relu/AlexNet.pth'))

## AlexNet kWTA 0.2
# Based on the ReLU AlexNet trained for 2 epochs.
# Trained for 1 epoch, test accuracy: 85.1%
# Vanilla attack on 10 minibatches:
#   Attacks Succeeded: 1417 / 1440 = 98.40 %
#   Robustness Accuracy: 13 / 1440 = 0.90 %
# Foolbox attacks on 20 minibatches:
#   PGD, Robustness Accuracy: 0 %
#   DeepFool, Robustness Accuracy: 0 %
models['kwta_0_2']['alexnet'] = copy.deepcopy(models['relu']['alexnet'])
activation_to_kwta(models['kwta_0_2']['alexnet'], nn.ReLU, sr=0.2)
models['kwta_0_2']['alexnet'].load_state_dict(torch.load(
        DATA_DIRECTORY + 'models/kwta_0_2/AlexNet.pth'))

## AlexNet kWTA 0.1
# Based on the kWTA 0.2 AlexNet trained for 1 epoch.
# Trained for 1 epoch, test accuracy: 87.0%
# Vanilla attack on 10 minibatches:
#   Attacks Succeeded: 696 / 1440 = 48.33 %
#   Robustness Accuracy: 231 / 1440 = 16.04 %
# Foolbox attacks on 20 minibatches:
#   PGD, Robustness Accuracy: 0 %
#   DeepFool, Robustness Accuracy: 2.8125 %
models['kwta_0_1']['alexnet'] = copy.deepcopy(models['kwta_0_2']['alexnet'])
activation_to_kwta(models['kwta_0_1']['alexnet'], kWTA, sr=0.1)
models['kwta_0_1']['alexnet'].load_state_dict(torch.load(
        DATA_DIRECTORY + 'models/kwta_0_1/AlexNet.pth'))

## ResNet ReLU
# Based on pretrained ResNet.
# Trained for 1 epoch, test accuracy: 90.8%
# Vanilla attack on 10 minibatches:
#   Attacks Succeeded: 1440 / 1440 = 100.00 %
#   Robustness Accuracy: 0 / 1440 = 0.00 %
# Foolbox attacks on 20 minibatches:
#   PGD, Robustness Accuracy: 0 %
#   DeepFool, Robustness Accuracy: 0 %
models['relu']['resnet'] = copy.deepcopy(models['pretrained']['resnet'])
# NOTE: as with AlexNet, this only changes the attribute; fc keeps its 1000
# outputs (weight shape 1000x512), which the saved checkpoint appears to match.
models['relu']['resnet'].fc.out_features = 10
models['relu']['resnet'].load_state_dict(torch.load(
        DATA_DIRECTORY + 'models/relu/ResNet18.pth'))

## ResNet kWTA 0.2
# Based on the ReLU ResNet trained for 1 epoch.
# Trained for 1 epoch, test accuracy: 86.3%
# Vanilla attack on 10 minibatches:
#   Attacks Succeeded: 721 / 1440 = 50.07 %
#   Robustness Accuracy: 65 / 1440 = 4.51 %
# Foolbox attacks on 20 minibatches:
#   PGD, Robustness Accuracy: 0 %
#   DeepFool, Robustness Accuracy: 43.75 %
models['kwta_0_2']['resnet'] = copy.deepcopy(models['relu']['resnet'])
activation_to_kwta(models['kwta_0_2']['resnet'], nn.ReLU, sr=0.2)
models['kwta_0_2']['resnet'].load_state_dict(torch.load(
        DATA_DIRECTORY + 'models/kwta_0_2/ResNet18.pth'))

## ResNet kWTA 0.1
# Based on the kWTA 0.2 ResNet trained for 1 epoch.
# Trained for 1 epoch, test accuracy: 76.4%
# Vanilla attack on 10 minibatches:
#   Attacks Succeeded: 380 / 1440 = 26.39 %
#   Robustness Accuracy: 65 / 1440 = 4.51 %
# Foolbox attacks on 20 minibatches:
#   PGD, Robustness Accuracy: 0.3125 %
#   DeepFool, Robustness Accuracy: 43.75 %
models['kwta_0_1']['resnet'] = copy.deepcopy(models['kwta_0_2']['resnet'])
activation_to_kwta(models['kwta_0_1']['resnet'], kWTA, sr=0.1)
models['kwta_0_1']['resnet'].load_state_dict(torch.load(
        DATA_DIRECTORY + 'models/kwta_0_1/ResNet18.pth'))


# TODO: kWTA seems to be slower than ReLU, training taking longer. Benchmark this.
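The `.out_features = 10` assignments above change only an attribute: `nn.Linear` allocates its weight when it is constructed, so the layer still produces 1000 logits, and `print(model)` will even report `out_features=10`. A small check with toy sizes (not AlexNet's real 4096 inputs):

```python
import torch
import torch.nn as nn

head = nn.Linear(8, 1000)
head.out_features = 10                 # changes only the attribute
out = head(torch.zeros(2, 8))
print(out.shape)                       # torch.Size([2, 1000])
print(head.weight.shape)               # torch.Size([1000, 8])

# Constructing a new layer is what actually changes the output size
new_head = nn.Linear(head.in_features, 10)
print(new_head(torch.zeros(2, 8)).shape)  # torch.Size([2, 10])
```

For our models the real replacement would be `classifier[-1] = nn.Linear(4096, 10)` (AlexNet) or `fc = nn.Linear(512, 10)` (ResNet18). Since the saved checkpoints appear to contain 1000-way heads, that change would require retraining rather than reloading them.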

## Training Models

In [0]:
# Trains the model in-place, and saves after every epoch to save_path.
# Only trains 1 epoch by default
def train(model, save_path, epochs=1):
    model = model.to(device) # use CUDA
    model.train()

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    for epoch in range(epochs):
        running_loss = 0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data[0].to(device), data[1].to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward
            outputs = model(inputs)
            # backward
            loss = criterion(outputs, labels)
            loss.backward()
            # had to add clipping to fix exploding gradients:
            nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()

            # print statistics
            running_loss += loss.item() / 200
            if i % 200 == 199:    # print every 200 mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss))
                running_loss = 0.0

        # save checkpoint after every epoch
        torch.save(model.state_dict(), save_path)

    print('Finished Training')
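The training loop clips with `nn.utils.clip_grad_norm_(model.parameters(), 0.5)`. As we understand its documented behavior (default L2 norm), the norm is taken over all parameter gradients together and every gradient is rescaled by the same factor, so the update keeps its direction and only its length is capped at 0.5. A NumPy sketch of that rule (`clip_by_global_norm` is a hypothetical name; PyTorch's version also supports other norm types):

```python
import numpy as np


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    # Combined L2 norm of all gradient arrays, as if concatenated into one vector
    total_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    # Scale factor is capped at 1, so small gradients are left unchanged
    scale = min(1.0, max_norm / (total_norm + eps))
    return [g * scale for g in grads], total_norm


grads = [np.array([3.0, 0.0]), np.array([[0.0, 4.0]])]  # combined norm = 5
clipped, total = clip_by_global_norm(grads, max_norm=0.5)
print(total)       # 5.0
print(clipped[0])  # approximately [0.3, 0.0]: same direction, 10x smaller
```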
    

In [0]:
#################################################
# WARNING:                                      #
# This will overwrite existing saved model!     #
#################################################
#train(models['kwta_0_1']['resnet'], save_path=DATA_DIRECTORY+'models/kwta_0_1/ResNet18.pth')

## Test Model Accuracy

In [0]:
# Test the model
def test(net):
    net = net.to(device)
    net.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the 10000 test images: %.1f %%' % (
        100 * correct / total))

In [15]:
test(models['kwta_0_1']['resnet'])

Accuracy of the network on the 10000 test images: 76.4 %


## Attacks

### Vanilla Gradient Ascent
Code adapted from the Fooling Images problem in PS2.

In [0]:
def make_fooling_image(X, target_y, model, max_iter=100, debug=True):
    """
    Generate a fooling image that is close to X, but that the model classifies
    as target_y.
    Inputs:
    - X: Input image; Tensor of shape (1, 3, 224, 224)
    - target_y: An integer class index; for our CIFAR10 models, in the range [0, 10)
    - model: A pretrained CNN
    - max_iter: Maximum number of gradient-ascent steps
    - debug: If True, print the target score at every iteration
    Returns:
    - X_fooling: An image that is close to X and, if the attack succeeded within
    max_iter steps, is classified as target_y by the model.
    - it: Index of the last iteration that was run.
    """
    model.eval()

    # Initialize our fooling image to the input image and track gradients w.r.t. it.
    X_fooling_var = X.clone().detach().requires_grad_(True)

    learning_rate = 10 # fixed learning rate
    
    for it in range(max_iter):
    ##############################################################################
    # Generate a fooling image X_fooling that the model will classify as #
    # the class target_y. You should perform gradient ascent on the score of the #
    # target class, stopping when the model is fooled. #
    # When computing an update step, first normalize the gradient: #
    # dX = learning_rate * g / ||g||_2 #
    ##############################################################################
        scores = model(X_fooling_var) # only one image
        target_score = scores[:, target_y]
        if debug:
            print("Iteration: %d, Target Score: %.3f" % (it, target_score.item()))
        if scores.argmax() == target_y:
            break
        target_score.backward()
        image_grad = X_fooling_var.grad.data
        dX = learning_rate * image_grad / image_grad.norm()
        X_fooling_var.data += dX # gradient *ascent*, so adding not subtracting dX
        # .grad accumulates across backward() calls, so reset it before the next step
        X_fooling_var.grad.zero_()

    X_fooling = X_fooling_var.data

    return X_fooling, it
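To see the update rule in isolation, here is a NumPy sketch on a toy linear classifier `scores = W @ x` (`fool_linear` and the identity weight matrix are hypothetical stand-ins for the CNN). The gradient of the target score is simply `W[target]`, and because it is normalized, every step moves the image by exactly `lr` in L2 norm regardless of the gradient's magnitude.

```python
import numpy as np


def fool_linear(x, W, target, lr=0.5, max_iter=100):
    """Normalized gradient ascent on scores[target] for scores = W @ x."""
    x = x.copy()
    for it in range(max_iter):
        scores = W @ x
        if scores.argmax() == target:
            break
        g = W[target]                    # d scores[target] / dx for a linear model
        x += lr * g / np.linalg.norm(g)  # gradient *ascent*: step of L2 length lr
    return x, it


W = np.eye(3)                    # scores are just x itself
x0 = np.array([1.0, 0.0, 0.0])   # currently classified as class 0
x_adv, it = fool_linear(x0, W, target=2)
print(it, x_adv)                 # it == 3, x_adv == [1.0, 0.0, 1.5]
```

Here the gradient is recomputed from scratch each iteration; in PyTorch, `.grad` accumulates across `backward()` calls unless it is reset between steps.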

In [0]:
def vis_fooling_img(X_orig, y_orig, X_fooling, target_y, class_names):
    print(class_names[y_orig])
    imshow(X_orig)

    print(class_names[target_y])
    imshow(X_fooling)
    
    print('Difference')
    imshow(X_fooling - X_orig)
    
    print('Magnified difference (10x)')
    imshow(10*(X_fooling - X_orig))

### Run attacks

In [0]:
# Run an adversarial attack on the model.
# Data is loaded in batches of 16 images; argument `num_minibatches`
# specifies how many batches to iterate over.
# Argument `attack_type` takes on values: `vanilla`, `pgd`, `deepfool`
def attack(model_in, attack_type='vanilla', debug=False, num_minibatches=10):
    # copying model so we can modify it
    import copy
    model = copy.deepcopy(model_in)
    # transfer to GPU for CUDA
    model = model.to(device)
    # put into evaluation mode
    model.eval()
    # Not going to train the model, so don't compute gradients w.r.t. parameters.
    # Using this instead of `with torch.no_grad()` because we still want gradients w.r.t. inputs.
    for param in model.parameters():
        param.requires_grad = False
    
    if debug:
        print(model)

    # variables to keep track of vanilla attack's stats
    num_robust = 0 # count how many times the model's prediction was still correct after the attack
    num_fooled = 0 # count how many times the attack succeeded with the target class
    total = 0

    # variables to keep track of foolbox attack's stats
    robust_acc_sum = 0

    # Adversarial attack loop:
    for j, data in enumerate(testloader, 0):
        if j >= num_minibatches:
            break
        images, labels = data[0].to(device), data[1].to(device)

        if attack_type == 'vanilla':
            # TODO: Try to parallelize batch processing instead of one at a time
            for i in range(len(images)): # for each input image to attack
                for target_idx in range(10): # for each target label
                    # skip if attack target is already correct label
                    if target_idx == labels[i]:
                        continue
                    # attempt to make a fooling image with vanilla gradient ascent attack
                    X_fooling, num_iter = make_fooling_image(images[i].unsqueeze(0),
                                                             target_idx, model, max_iter=20, debug=debug)
                    # evaluate fooling image with the model
                    scores = model(X_fooling)
                    is_fooled = scores.data.max(1)[1][0] == target_idx
                    is_robust = scores.data.max(1)[1][0] == labels[i]
                    if debug:
                        if is_fooled:
                            print('Fooled model, iterations =', num_iter)
                        else:
                            print('Failed to fool model!')
                        X_fooling = X_fooling.cpu()
                        # Visualize fooling image and original image differences
                        #vis_fooling_img(images[i].cpu(), labels[i], X_fooling.cpu().squeeze(), target_idx, classes)
                    num_fooled += is_fooled
                    num_robust += is_robust
                    total += 1

            # Print stats after every minibatch
            print('[Minibatch: %d] Fooled: %d, Robust: %d, Total: %d' % (j+1, num_fooled, num_robust, total))

        else:
            # bounds=(0,1) matches our Normalize(mean=0, std=1), which keeps pixels in [0, 1]
            fmodel = foolbox.PyTorchModel(model, bounds=(0,1))
            # The paper uses the l_inf metric for all attacks.
            # Parameters follow the paper's Appendix D.1.
            if attack_type == 'pgd':
                attack_fn = foolbox.attacks.LinfPGD(steps=40, random_start=True, abs_stepsize=0.003)
            elif attack_type == 'deepfool':
                attack_fn = foolbox.attacks.LinfDeepFoolAttack(steps=20, candidates=10)
            else:
                raise ValueError("Unknown attack_type: %s" % attack_type)
            #epsilons = [0.0, 0.001, 0.01, 0.03, 0.1, 0.3, 0.5, 1.0]
            epsilons = [0.031] # value used in the paper
            _, _, success = attack_fn(fmodel, images, labels, epsilons=epsilons)

            robust_accuracy = 1 - success.double().mean(axis=-1)
            robust_acc_sum += robust_accuracy
            # Print stats after every minibatch
            print('[Minibatch: %d] Accuracy: %f %%' % (j+1, 100*robust_accuracy.item()))
    
    if attack_type == 'vanilla':
        print("Attacks Succeeded:", num_fooled.item(), "/", total, "=", 100 * num_fooled.item() / total, "%")
        print("Robustness Accuracy: ", num_robust.item(), "/", total, "=", 100 * num_robust.item() / total, "%")
    else:
        print("Robustness Accuracy: ", 100 * robust_acc_sum.item() / num_minibatches, "%")
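The `pgd` branch above treats Foolbox's `LinfPGD` as a black box. As a reading aid, here is a minimal NumPy sketch of the L∞ PGD update (signed gradient step on the attacker's loss, projection back onto the ε-ball around the clean input, clipping to the valid pixel range) with the same ε = 0.031, step size 0.003 and 40 steps. `linf_pgd` and `grad_fn` are hypothetical names, the toy linear "loss" stands in for the network's cross-entropy, and this is not Foolbox's implementation.

```python
import numpy as np


def linf_pgd(x, grad_fn, eps=0.031, step=0.003, steps=40, random_start=True, rng=None):
    """Sketch of L-inf PGD: ascend the loss, staying within eps of x and in [0, 1].
    grad_fn(x) returns the gradient of the attacker's loss w.r.t. x."""
    rng = np.random.default_rng() if rng is None else rng
    x_adv = x.copy()
    if random_start:
        x_adv = np.clip(x_adv + rng.uniform(-eps, eps, size=x.shape), 0.0, 1.0)
    for _ in range(steps):
        x_adv = x_adv + step * np.sign(grad_fn(x_adv))  # signed gradient ascent step
        x_adv = np.clip(x_adv, x - eps, x + eps)        # project onto the eps-ball
        x_adv = np.clip(x_adv, 0.0, 1.0)                # stay a valid image
    return x_adv


# Toy loss L(x) = w . x, whose gradient is w everywhere
w = np.array([1.0, -2.0, 0.5, 0.0])
x = np.array([0.5, 0.5, 0.99, 0.2])
x_adv = linf_pgd(x, lambda z: w, random_start=False)
print(x_adv)  # approximately [0.531, 0.469, 1.0, 0.2]
```

Since 40 × 0.003 = 0.12 exceeds ε, every coordinate with a nonzero gradient ends on the boundary of the ε-ball (or of [0, 1]), which is the typical end state of PGD with these settings.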

In [92]:
attack(models['relu']['alexnet'], attack_type='deepfool', num_minibatches=20, debug=False)

[Minibatch: 1] Accuracy: 0.000000 %
[Minibatch: 2] Accuracy: 0.000000 %
[Minibatch: 3] Accuracy: 0.000000 %
[Minibatch: 4] Accuracy: 0.000000 %
[Minibatch: 5] Accuracy: 0.000000 %
[Minibatch: 6] Accuracy: 0.000000 %
[Minibatch: 7] Accuracy: 0.000000 %
[Minibatch: 8] Accuracy: 0.000000 %
[Minibatch: 9] Accuracy: 0.000000 %
[Minibatch: 10] Accuracy: 0.000000 %
[Minibatch: 11] Accuracy: 0.000000 %
[Minibatch: 12] Accuracy: 0.000000 %
[Minibatch: 13] Accuracy: 0.000000 %
[Minibatch: 14] Accuracy: 0.000000 %
[Minibatch: 15] Accuracy: 0.000000 %
[Minibatch: 16] Accuracy: 0.000000 %
[Minibatch: 17] Accuracy: 0.000000 %
[Minibatch: 18] Accuracy: 0.000000 %
[Minibatch: 19] Accuracy: 0.000000 %
[Minibatch: 20] Accuracy: 0.000000 %
Robustness Accuracy:  0.0 %
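For reading the numbers recorded in this notebook: the vanilla attack runs one targeted attack per (test image, wrong class) pair, and the Foolbox attacks average per-minibatch accuracies over 20 full minibatches of 16 images, so every Foolbox robust accuracy is a whole number of images out of 320. For example, the recorded 0.3125 %, 2.8125 % and 43.75 % correspond to 1, 9 and 140 robust images:

```latex
\begin{aligned}
N_{\text{vanilla}} &= \underbrace{10}_{\text{minibatches}} \times \underbrace{16}_{\text{images}} \times \underbrace{9}_{\text{wrong target classes}} = 1440,\\
N_{\text{Foolbox}} &= 20 \times 16 = 320, \qquad \tfrac{1}{320} = 0.3125\,\%,\\
\tfrac{9}{320} &= 2.8125\,\%, \qquad \tfrac{140}{320} = 43.75\,\%.
\end{aligned}
```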
