In [None]:
import matplotlib.pylab as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import glob
import numpy as np
import imageio
import itertools
import foolbox as fb
import foolbox.ext.native as fbn
import time
import torchvision
from cifar10_models import return_model

## Model Selection

In [None]:
# Architecture key: handed to `return_model` and interpolated into the
# checkpoint file names under ./cifar_resnet/ further down.
model = "resnet"

## Setup Adversarial Attack

In [None]:
# initialisation attack from existing validation images
import eagerpy as ep
import numpy as np

class PrecomputedSamplesAttack:
    """Initialisation helper for boundary-type attacks backed by a pool of real samples.

    Samples are added with :meth:`feed`, which stores each sample next to the
    class the wrapped model predicts for it. Calling the instance then returns,
    for every requested label, a randomly drawn stored sample whose recorded
    prediction differs from that label.
    """

    def __init__(self, model):
        """Create an empty sample pool.

        Args:
            model: object exposing ``forward(inputs)``; its output is reduced
                with ``argmax(1)`` to obtain one predicted class per sample.
        """
        self.model = model
        self.samples = []  # stored inputs, index-aligned with self.labels
        self.labels = []   # predicted class (as int) of each stored input

    def feed(self, inputs):
        """Add every element of ``inputs`` to the pool with its predicted class.

        Args:
            inputs: batch accepted by ``self.model.forward``; it is indexed
                along its first axis to store the individual samples.
        """
        response = self.model.forward(inputs).argmax(1)

        for k in range(len(inputs)):
            self.labels.append(int(response[k]))
            self.samples.append(inputs[k])

    def __call__(self, inputs, labels):
        """Return one stored sample per label whose prediction differs from it.

        Args:
            inputs: batch whose shape/type the result mirrors (its values are
                not read; the result starts as ``zeros_like(inputs)``).
            labels: one class label per entry of ``inputs``.

        Returns:
            The native tensor underlying the eagerpy result, where row ``k``
            holds a stored sample with recorded prediction ``!= labels[k]``.

        Raises:
            ValueError: if the pool holds no sample whose recorded prediction
                differs from one of the requested labels (including an empty
                pool), instead of drawing forever.
        """
        inputs = ep.astensor(inputs)
        labels = ep.astensor(labels)
        x = ep.zeros_like(inputs)

        # Distinct predictions available in the pool; used to prove that the
        # rejection sampling below can terminate.
        available = set(self.labels)

        for k in range(len(labels)):
            # Convert once per label rather than once per rejected draw.
            target = int(labels[k].numpy())
            if not (available - {target}):
                raise ValueError(
                    "no stored sample has a predicted label different from {}; "
                    "feed() more samples first".format(target)
                )
            while True:
                idx = np.random.randint(len(self.labels))
                if target != self.labels[idx]:
                    x.tensor[k] = self.samples[idx]
                    break

        return x.tensor


In [None]:
# NOTE(review): `data_dir` is not referenced by the visible cells (the CIFAR-10
# loader below uses root='./data') — confirm which location is intended.
data_dir = '../data'
# Wrap both branches so `device` is always a torch.device, never a bare str.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# ---- experiment configuration -------------------------------------------
data = dict()

metrics = ['Linf']
repetitions, step_scale, num_batches = 1, 1, 16

# Settings for a boundary-type attack; not referenced by `attacks` below.
BB_static = dict(init_attack='init_attack', steps=int(step_scale * 10000))
BB_lr = [dict(lr=rate) for rate in (1e-2, 1e-1)]

# Each entry is (attack class, static kwargs, list of per-run kwargs).
# NOTE(review): despite the `L2_` prefix this list is registered under 'Linf'.
L2_attacks = [
    (
        fbn.attacks.PGD,
        dict(num_steps=int(1000 * step_scale)),
        [dict(epsilon=8.0 / 255.0)],
    ),
]

attacks = dict(Linf=L2_attacks)


# ---- data ----------------------------------------------------------------
test_batch_size = 100

transform_test = transforms.Compose([transforms.ToTensor()])

# Although named `testset`, this is the CIFAR-10 *training* split (train=True).
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_test,
)
# shuffle=False keeps the batch order identical on every pass over the loader.
testloader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=test_batch_size,
    shuffle=False,
    num_workers=2,
)

## Generate adversarial examples for 100 examples in the training set, at each saved checkpoint through training up to the highest-performing epoch. If an example does not have an adversarial example, a black image is saved and its adversarial category is `None`.

In [None]:
# Build the network once; the weights of each checkpoint are loaded into it below.
resnet = return_model(model)
resnet = torch.nn.DataParallel(resnet)
resnet = resnet.to(device)
resnet.eval()

# The 'epoch' entry of this checkpoint is the last epoch visited by the loop.
# map_location lets checkpoints load onto `device` wherever they were saved.
check = torch.load("./cifar_resnet/ckpt_{}.pth".format(model), map_location=device)
epoch_max = check['epoch']

for i in range(0, epoch_max + 1):
    # NOTE(review): j presumably is the within-epoch step at which the
    # checkpoint was written — TODO confirm against the training script.
    for j in range(1, 392, 30):
        names = './cifar_resnet/ckpt_{}_{}_{}.pth'.format(model, i, j)
        saved = torch.load(names, map_location=device)
        resnet.load_state_dict(saved)

        # init foolbox models (both wrap the module inside DataParallel)
        fbn_model = fbn.models.PyTorchModel(resnet.module, bounds=[0.0, 1.0], device=device)
        fb_model = fb.models.PyTorchModel(resnet.module, bounds=[0.0, 1.0], device=device, num_classes=10)

        for metric in metrics:
            for attack_spec in attacks[metric]:
                # A spec is (Attack, static_kwargs, dynamic_kwargs[, init_kwargs]).
                if len(attack_spec) == 3:
                    Attack, static_kwargs, dynamic_kwargs = attack_spec
                    init_kwargs = {}
                else:
                    Attack, static_kwargs, dynamic_kwargs, init_kwargs = attack_spec
                # Work on a copy so storing the init attack below does not
                # mutate the shared configuration held in `attacks`.
                static_kwargs = dict(static_kwargs)

                native = 'foolbox.ext.native' in Attack.__module__
                # init_kwargs are only forwarded to non-native attacks.
                attack = Attack(fbn_model) if native else Attack(fb_model, **init_kwargs)

                name = str(attack.__class__).split('.')[-1].split("'")[0]
                if metric == 'L2':
                    # Not used further in this cell.
                    bbattack = fbn.attacks.L2BrendelBethgeAttack(fbn_model)

                # Deliberately NOT called `model`: rebinding that name would
                # corrupt the checkpoint paths formatted from it on the next
                # (i, j) iteration.
                if native:
                    attack_model = fbn_model
                    print("Native")
                else:
                    attack_model = fb_model

                # create init attack if necessary
                if 'init_attack' in static_kwargs:
                    init_attack = PrecomputedSamplesAttack(fbn_model)
                    for inputs, labels in testloader:
                        init_attack.feed(inputs.to(device))
                    static_kwargs['init_attack'] = init_attack

                # perform attack with different arguments
                img = []
                for kwarg in dynamic_kwargs:
                    kwargs = {**kwarg, **static_kwargs}
                    print(kwarg)
                    images = []
                    for b, batch in enumerate(testloader):
                        # Only the first batch (test_batch_size examples) is attacked.
                        if b == 1:
                            break
                        batch_start = time.time()
                        inputs, labels = batch
                        inputs = inputs.to(device)
                        labels = labels.to(device)
                        if not native:
                            inputs = inputs.data.cpu().numpy()
                            labels = labels.data.cpu().numpy()
                        adversarials = attack(
                            inputs,
                            labels,
                            **kwargs
                        )

                        # An example counts only if the clean input was classified
                        # correctly AND the adversarial one is misclassified.
                        out = attack_model.forward(adversarials)
                        is_adv = out.argmax(1) != labels
                        out_x = attack_model.forward(inputs)
                        is_cor = out_x.argmax(1) == labels

                        if native:
                            # Move everything to NumPy for per-example bookkeeping.
                            out = out.data.cpu().numpy()
                            adversarials = adversarials.data.cpu().numpy()
                            inputs = inputs.data.cpu().numpy()
                            labels = labels.data.cpu().numpy()
                            is_adv = is_adv.cpu().numpy()
                            is_cor = is_cor.cpu().numpy()
                        # Computed for both paths so the non-native branch does
                        # not hit an undefined `output`.
                        output = out.argmax(1)

                        for k in range(len(inputs)):
                            x0 = inputs[k]
                            if is_adv[k] and is_cor[k]:
                                images.append([adversarials[k], x0, labels[k], output[k]])
                            else:
                                # No adversarial found: black image, category None.
                                images.append([np.zeros((3, 32, 32)), x0, labels[k], None])
                        if b % 10 == 0:
                            print(i, j, b, 'Time: ', time.time() - batch_start)
                    img.append(images)

                # Entries mix arrays, scalars and None, so build the object array
                # explicitly (implicit ragged conversion is rejected by newer NumPy).
                # NOTE: the file name depends only on the checkpoint, so with more
                # than one attack/metric each save overwrites the previous one.
                np.save(
                    "./cifar_resnet/{}_linf_adversarial_100.npy".format(names.split('/')[-1].split('.')[0]),
                    np.array(img, dtype=object),
                )

        # delete models to free memory (once per checkpoint, after every
        # metric/attack has finished using them)
        del fb_model
        del fbn_model