In [None]:
import os
import sys

# Make the directory two levels above the current working directory importable,
# presumably so the DeepSparseCoding package imported below resolves.
# The membership check keeps re-running this cell from adding duplicate entries.
parent_path = os.path.dirname(os.path.dirname(os.getcwd()))
if parent_path not in sys.path:
    sys.path.append(parent_path)
    
import numpy as np
import proplot as plot
import torch

from DeepSparseCoding.utils.file_utils import Logger
import DeepSparseCoding.utils.loaders as loaders
import DeepSparseCoding.utils.run_utils as run_utils
import DeepSparseCoding.utils.dataset_utils as dataset_utils
import DeepSparseCoding.utils.run_utils as ru
import DeepSparseCoding.utils.plot_functions as pf

import eagerpy as ep
from foolbox import PyTorchModel, accuracy, samples
import foolbox.attacks as fa

In [None]:
# Root directory (inside the user's home) under which the project files live.
home_dir = os.path.expanduser("~")
workspace_dir = home_dir + "/Work/"

# Keyword arguments for the L-infinity PGD attack: a relative step size of 1e-3
# paired with 1/1e-3 iterations.
linfPGD_params = dict(rel_stepsize=1e-3, steps=int(1/1e-3))

# Attacks to run, in the order in which their results are stored and plotted.
attacks = [
    fa.FGSM(),
    fa.LinfPGD(**linfPGD_params),
    fa.LinfBasicIterativeAttack(),
    fa.LinfAdditiveUniformNoiseAttack(),
    fa.LinfDeepFoolAttack(),
]

# Allowed perturbation sizes, in increasing order; every attack is evaluated at
# each of these values.
epsilons = [
    0.0, 0.0005, 0.001, 0.0015, 0.002, 0.003, 0.005,
    0.01, 0.02, 0.03,
    0.1, 0.3, 0.5, 1.0,
]

In [None]:
# Evaluate every model against every attack on one shared batch of training data.
#
# Log files and checkpoints are both located under workspace_dir so that the
# cell does not depend on one particular user's absolute home directory.
model_names = ["mlp_768_mnist", "lca_768_mlp_mnist"]
project_dir = os.path.join(workspace_dir, "Torch_projects", "short")
log_files = [
    os.path.join(project_dir, name, "logfiles", name + "_v0.log")
    for name in model_names]
cp_latest_filenames = [
    os.path.join(project_dir, name, "checkpoints", name + "_latest_checkpoint_v0.pt")
    for name in model_names]

num_models = len(log_files)

# `params` is first assigned inside the loop below, so passing it to
# load_dataset here would raise NameError on a fresh kernel. Read the dataset
# parameters from the first model's log instead.
# NOTE(review): assumes all compared models were trained on the same dataset.
dataset_logger = Logger(log_files[0], overwrite=False)
dataset_params = dataset_logger.read_params(dataset_logger.load_file())[-1]
train_loader, val_loader, test_loader, dataset_params = dataset_utils.load_dataset(dataset_params)

# Keep the loader's batch under its own names so that every model starts from
# it rather than from the previous model's preprocessed output.
raw_data, raw_target = next(iter(train_loader))

# Maps model_type -> boolean array of shape (num attacks, num epsilons, num images).
attack_success_dict = {}
for model_index in range(num_models):
    # Restore the most recent parameter set from the log and point it at the checkpoint.
    logger = Logger(log_files[model_index], overwrite=False)
    log_text = logger.load_file()
    params = logger.read_params(log_text)[-1]
    params.cp_latest_filename = cp_latest_filenames[model_index]

    # Rebuild the model, load its weights and wrap it for foolbox.
    model = loaders.load_model(params.model_type, params.lib_root_dir)
    model.setup(params, logger)
    model.to(params.device)
    model.load_checkpoint()
    fmodel = PyTorchModel(model.eval(), bounds=(0, 1))

    # Per-model preprocessing of the shared batch.
    data = model.preprocess_data(raw_data.to(model.params.device))
    target = raw_target.to(model.params.device)

    images, labels = ep.astensors(data, target)

    print("\n")
    print("~" * 79)
    print("Model type:", model.params.model_type)
    print("Model index: ", model_index)
    print("accuracy")
    print(accuracy(fmodel, images, labels))
    print("")

    # np.bool_ is used because the plain np.bool alias was deprecated in
    # NumPy 1.20 and removed in 1.24.
    attack_success = np.zeros((len(attacks), len(epsilons), len(images)), dtype=np.bool_)
    for i, attack in enumerate(attacks):
        _, _, success = attack(fmodel, images, labels, epsilons=epsilons)
        assert success.shape == (len(epsilons), len(images))
        success_ = success.numpy()
        assert success_.dtype == np.bool_
        attack_success[i] = success_
        print("\n", attack)
        # Accuracy under this attack for each epsilon.
        print("  ", 1.0 - success_.mean(axis=-1).round(2))

    attack_success_dict[model.params.model_type] = attack_success
    # An image counts as broken if any attack succeeded on it at that epsilon.
    robust_accuracy = 1.0 - attack_success.max(axis=0).mean(axis=-1)
    print("")
    print("-" * 79)
    print("")
    print("worst case (best attack per-sample)")
    print("  ", robust_accuracy.round(2))

In [None]:
# Draw one panel per attack, each showing model accuracy as a function of the
# perturbation size, with one line per evaluated model.
num_attacks, num_epsilons, num_images = attack_success.shape
fig, axes = plot.subplots(ncols=num_attacks, nrows=1, share=0)
handles = []
for attack_index in range(num_attacks):
    panel = axes[attack_index]
    for model_type, model_attack_success in attack_success_dict.items():
        # Mean over images of the per-image success flags, for each epsilon.
        success_rate = model_attack_success[attack_index, ...].mean(axis=-1)
        attack_accuracy = 1.0 - success_rate
        handle = panel.plot(epsilons, attack_accuracy, label=model_type)
        # Legend entries are collected from the first panel only.
        if attack_index == 0:
            handles.extend(handle)
    # Title each panel with the attack's class name.
    attack_type = type(attacks[attack_index]).__name__
    panel.format(title=attack_type)
axes[0].format(ylabel='Model accuracy')
axes.format(xlabel='Maximum perturbation size', ylim=[0, 1.0])
fig.legend(handles, ncols=1, frame=False, label='Model type', loc='right')
plot.show()