In [None]:
import matplotlib.pylab as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import glob
import numpy as np
import imageio
import itertools
import foolbox as fb
import foolbox.ext.native as fbn
import time

In [None]:
from mnist_models import return_model

In [None]:
# initialisation attack from existing validation images
import eagerpy as ep
import numpy as np

class PrecomputedSamplesAttack:
    """Helper attack that picks initialisation points for boundary-type attacks
       from a given data set. It stores samples together with the labels the given
       model predicts for them, and later returns, for each requested label, a
       randomly chosen stored sample whose predicted label differs from it.
    """

    def __init__(self, model):
        """Keep a reference to `model` and start with an empty sample store."""
        self.model = model
        self.samples = []
        self.labels = []

    def feed(self, inputs):
        """Run `inputs` through the model and store every sample with its predicted label.

        `inputs` must be indexable along its first axis and accepted by
        `self.model.forward`; the prediction is the argmax over axis 1 of the output.
        """
        response = self.model.forward(inputs).argmax(1)

        for k in range(len(inputs)):
            self.labels.append(int(response[k]))
            self.samples.append(inputs[k])

    def __call__(self, inputs, labels):
        """Return one stored sample per entry of `labels` whose predicted label differs.

        The result has the same shape as `inputs`; entry k is drawn uniformly at
        random from all stored samples whose predicted label is not `labels[k]`.

        Raises:
            ValueError: if no samples were fed yet, or if every stored sample is
                predicted as `labels[k]` for some k (the original rejection loop
                would never terminate in that case).
        """
        if not self.labels:
            raise ValueError("no samples stored; call feed() before using the attack")

        inputs = ep.astensor(inputs)
        labels = ep.astensor(labels)
        x = ep.zeros_like(inputs)

        stored_labels = np.asarray(self.labels)
        # Candidate indices are computed once per distinct label in the batch.
        candidates_by_label = {}

        for k in range(len(labels)):
            label = int(labels[k].numpy())
            if label not in candidates_by_label:
                candidates_by_label[label] = np.flatnonzero(stored_labels != label)
            candidates = candidates_by_label[label]
            if candidates.size == 0:
                raise ValueError(
                    "no stored sample is classified differently from label {}".format(label)
                )
            # Uniform choice among valid candidates: same distribution as
            # rejection sampling, but guaranteed to terminate.
            idx = int(candidates[np.random.randint(candidates.size)])
            x.tensor[k] = self.samples[idx]

        return x.tensor


In [None]:
# Data directory (the same path the MNIST loader in this notebook reads from).
data_dir = '../data'
# Always a torch.device (previously a plain str on CPU-only machines), so code
# that inspects device.type / device.index behaves the same on CPU and GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# --- experiment configuration -------------------------------------------------
data = {}

metrics = ['Linf']
repetitions = 1
step_scale = 1
num_batches = 16

# Settings for Brendel-Bethge style attacks. The 'init_attack' string is a
# placeholder: the attack loop replaces it with a PrecomputedSamplesAttack
# whenever the key is present in an attack's static kwargs.
BB_static = {'init_attack' : 'init_attack', 'steps' : int(step_scale * 10000)}
BB_lr = [{'lr' : 1e-3}, {'lr' : 1e-2}]

# Each entry is (attack class, static kwargs, list of per-run kwargs).
# NOTE(review): despite the name, this list is registered under the 'Linf' metric below.
# NOTE(review): epsilon is 72/255 scaled by the upper input bound 2.82148653035
# (= (1 - 0.1307) / 0.3081), not by 1 / 0.3081 -- confirm this is the intended budget.
L2_attacks = [(fbn.attacks.PGD,{'num_steps' : int(1000 * step_scale)}, [{'epsilon' : (72/255) * 2.82148653035}])]
attacks = {'Linf' : L2_attacks}


# --- data ---------------------------------------------------------------------

test_batch_size = 100
# init data loader
# train=True: the images come from the MNIST *training* split, even though the
# variable is called data_testing. The root is `data_dir` (defined in the
# previous cell) instead of a second hard-coded copy of the same path.
data_testing = datasets.MNIST(data_dir, train=True, transform=transforms.Compose([
               transforms.ToTensor(),
               transforms.Normalize((0.1307,), (0.3081,))]), download=True)



# shuffle=False keeps the batch order fixed, so batch 0 always holds the same images.
data_loader = torch.utils.data.DataLoader(data_testing,
    batch_size=test_batch_size, shuffle=False)


## Generate adversarial attacks for 100 examples in the training set. Attacks are generated for the checkpoints saved throughout training, up to epoch 1 only, because later epochs make no big difference (more epochs can be added). If an example does not have an adversarial example, a black image is saved and its adversarial category is None.

In [None]:
# Attack every saved checkpoint of the chosen model and store, per checkpoint,
# the adversarial examples found for the first batch of the data loader.
model_val = 'full'
resnet = return_model(model_val)
resnet = resnet.to(device)
resnet.eval();

for i in range(1):
    for j in range(1,938):
        names = './mnist_cnn/mnist_{}_10_{}_{}.pt'.format(model_val,j, i+1)
        # map_location remaps the stored tensors onto `device`, so a checkpoint
        # written on a GPU machine still loads on a CPU-only one.
        saved = torch.load(names, map_location=device)
        resnet.load_state_dict(saved)

        # init foolbox models (bounds are the normalised MNIST pixel range)
        fbn_model = fbn.models.PyTorchModel(resnet, bounds=[-0.42421291788, 2.82148653035], device=device)
        fb_model = fb.models.PyTorchModel(resnet, bounds=[-0.42421291788, 2.82148653035], device=device, num_classes=10)

        for metric in metrics:
            for attack_spec in attacks[metric]:
                iteration = 0
                # initiate attack: a spec is (Attack, static_kwargs, dynamic_kwargs)
                # with an optional fourth element of constructor kwargs
                Attack, static_kwargs, dynamic_kwargs = attack_spec[:3]
                init_kwargs = attack_spec[3] if len(attack_spec) > 3 else {}
                native = 'foolbox.ext.native' in Attack.__module__
                # constructor kwargs are only forwarded to non-native attacks
                attack = Attack(fbn_model) if native else Attack(fb_model, **init_kwargs)

                name = type(attack).__name__
                if metric == 'L2':
                    bbattack = fbn.attacks.L2BrendelBethgeAttack(fbn_model)

                if native:
                    model = fbn_model
                    print("Native")
                else:
                    model = fb_model

                # create init attack if necessary
                if 'init_attack' in static_kwargs.keys():
                    init_attack = PrecomputedSamplesAttack(fbn_model)

                    for batch in data_loader:
                        inputs, labels = batch
                        inputs = inputs.to(device)
                        labels = labels.to(device)

                        init_attack.feed(inputs)

                    static_kwargs['init_attack'] = init_attack

                # perform attack with different arguments
                img = []
                for kwarg in dynamic_kwargs:
                    kwargs = {**kwarg, **static_kwargs}
                    print(kwarg)
                    images = []
                    for b, batch in enumerate(data_loader):
                        # only the first batch is attacked
                        if b == 1:
                            break

                        check = time.time()
                        inputs, labels = batch
                        inputs = inputs.to(device)
                        labels = labels.to(device)
                        if not native:
                            # non-native foolbox models work on numpy arrays
                            inputs = inputs.detach().cpu().numpy()
                            labels = labels.detach().cpu().numpy()
                        adversarials = attack(
                            inputs,
                            labels,
                            **kwargs
                        )

                        out = model.forward(adversarials)
                        out_x = model.forward(inputs)

                        if native:
                            # bring everything to numpy so both paths share the code below
                            out = out.detach().cpu().numpy()
                            out_x = out_x.detach().cpu().numpy()
                            adversarials = adversarials.detach().cpu().numpy()
                            inputs = inputs.detach().cpu().numpy()
                            labels = labels.detach().cpu().numpy()

                        # predicted class of each adversarial; defined for native
                        # AND non-native attacks (it used to be native-only)
                        output = out.argmax(1)
                        # check if adversarial, and if the clean input was classified correctly
                        is_adv = output != labels
                        is_cor = out_x.argmax(1) == labels

                        for k in range(len(inputs)):
                            if is_adv[k] and is_cor[k]:
                                x0 = inputs[k]
                                x = adversarials[k]

                                images.append([x, x0,labels[k],output[k]])
                            else:
                                # no adversarial found (or clean input misclassified):
                                # store black images and no adversarial class
                                x0 = np.zeros((1,28,28))
                                x = np.zeros((1,28,28))
                                images.append([x,x0,labels[k],None])
                        if b % 10 == 0:
                            print(i, j, b, 'Time: ', time.time() - check)

                    # one entry per kwarg setting, appended after the batch loop so the
                    # list is not added once per batch if more batches are enabled
                    img.append(images)


                # Entries mix image arrays with scalar labels / None, so the result must be
                # an explicit object array (recent NumPy refuses to infer that from a ragged
                # list). Loading the file requires np.load(..., allow_pickle=True).
                np.save("./mnist_cnn/{}_linf_adversarial.npy".format(names.split('/')[-1].split('.')[0]),
                        np.array(img, dtype=object))

        # Drop the foolbox wrappers only after all metrics and attacks for this
        # checkpoint are done; deleting them inside the attack loop would break any
        # second attack or metric.
        del fb_model
        del fbn_model
        #del resnet