In [None]:
import os
import sys

# Climb out of the repository: keep taking the parent directory while the path
# string still contains 'DeepSparseCoding', then expose that parent on sys.path
# so `import DeepSparseCoding.…` resolves as a package.
ROOT_DIR = os.getcwd()
while ROOT_DIR.find('DeepSparseCoding') != -1:
    ROOT_DIR = os.path.dirname(ROOT_DIR)
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)
    
import pickle
import numpy as np
import proplot as plot
import torch
import torch.nn as nn
import torchvision.models as models

from DeepSparseCoding.utils.file_utils import Logger
import DeepSparseCoding.utils.loaders as loaders
import DeepSparseCoding.utils.run_utils as run_utils
import DeepSparseCoding.utils.dataset_utils as dataset_utils
import DeepSparseCoding.utils.run_utils as ru
import DeepSparseCoding.utils.plot_functions as pf
import DeepSparseCoding.utils.data_processing as dp

import eagerpy as ep
from foolbox import PyTorchModel, accuracy, samples
import foolbox.attacks as fa

In [None]:
def create_mnist_fb(checkpoint_path=None, device=None) -> PyTorchModel:
    """Build the reference MNIST CNN, restore its weights, and wrap it for foolbox.

    Args:
        checkpoint_path: state_dict file to load. Defaults to
            ROOT_DIR/DeepSparseCoding/mnist_cnn.pth.
        device: forwarded to foolbox's PyTorchModel, which places the network on
            that device; None keeps foolbox's default choice (same as before).

    Returns:
        A foolbox PyTorchModel with input bounds (0, 1) that applies mean/std
        preprocessing before the network.
    """
    model = nn.Sequential(
        nn.Conv2d(1, 32, 3),
        nn.ReLU(),
        nn.Conv2d(32, 64, 3),
        nn.ReLU(),
        nn.MaxPool2d(2),
        nn.Dropout2d(0.25),
        nn.Flatten(),  # type: ignore
        # 9216 = 64 * 12 * 12, i.e. this assumes 28x28 single-channel inputs.
        nn.Linear(9216, 128),
        nn.ReLU(),
        # Plain Dropout: the input here is a flat (batch, 128) tensor, for which
        # channel-wise Dropout2d is the wrong op (recent PyTorch warns about it).
        # Dropout layers hold no parameters, so state_dict keys are unchanged.
        nn.Dropout(0.5),
        nn.Linear(128, 10),
    )
    if checkpoint_path is None:
        checkpoint_path = os.path.join(ROOT_DIR, 'DeepSparseCoding', 'mnist_cnn.pth')
    # Map tensors to the CPU so a checkpoint saved from a GPU run also loads on a
    # CPU-only machine; device placement is left to PyTorchModel below.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))  # type: ignore
    model.eval()
    # NOTE(review): presumably the normalization the checkpoint was trained with
    # (the usual MNIST mean/std) — confirm against the training script.
    preprocessing = dict(mean=0.1307, std=0.3081)
    return PyTorchModel(model, bounds=(0, 1), device=device, preprocessing=preprocessing)

In [None]:
def create_mnist_dsc(log_file, cp_file):
    """Rebuild a DeepSparseCoding model from its log + checkpoint and wrap it for foolbox.

    Args:
        log_file: training log of the model; the last params entry read from it is used.
        cp_file: checkpoint file to restore (stored as params.cp_latest_filename).

    Returns:
        (fmodel, test_loader, batch_size, device):
            fmodel: foolbox PyTorchModel with input bounds (0, 1), placed on `device`.
            test_loader: test-set loader built with the settings below.
            batch_size: model.params.batch_size.
            device: model.params.device.
    """
    logger = Logger(log_file, overwrite=False)
    params = logger.read_params(logger.load_file())[-1]
    params.cp_latest_filename = cp_file
    # Keep raw pixels scaled to [0, 1] (no standardization) to match bounds=(0, 1) below.
    params.standardize_data = False
    params.rescale_data_to_one = True
    _, _, test_loader, data_params = dataset_utils.load_dataset(params)
    for key, value in data_params.items():
        setattr(params, key, value)
    model = loaders.load_model(params.model_type)
    model.setup(params, logger)
    model.params.analysis_out_dir = os.path.join(
        model.params.model_out_dir, 'analysis', model.params.version)
    model.params.analysis_save_dir = os.path.join(model.params.analysis_out_dir, 'savefiles')
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(model.params.analysis_save_dir, exist_ok=True)
    model.to(params.device)
    model.load_checkpoint()
    # Pin foolbox to the same device reported to the caller; with device=None,
    # PyTorchModel would move the network to its own default device instead.
    fmodel = PyTorchModel(model.eval(), bounds=(0, 1), device=model.params.device)
    return fmodel, test_loader, model.params.batch_size, model.params.device

In [None]:
# Training logs for the two DeepSparseCoding models under attack:
# index 0 is the plain MLP, index 1 the LCA + MLP model.
log_files = [
    os.path.join(ROOT_DIR, 'Torch_projects', 'mlp_768_mnist', 'logfiles', 'mlp_768_mnist_v0.log'),
    os.path.join(ROOT_DIR, 'Torch_projects', 'lca_768_mlp_mnist', 'logfiles', 'lca_768_mlp_mnist_v0.log'),
]

# Matching checkpoints, in the same order as `log_files`.
cp_latest_filenames = [
    os.path.join(ROOT_DIR, 'Torch_projects', 'mlp_768_mnist', 'checkpoints',
                 'mlp_768_mnist_latest_checkpoint_v0.pt'),
    os.path.join(ROOT_DIR, 'Torch_projects', 'lca_768_mlp_mnist', 'checkpoints',
                 'lca_768_mlp_mnist_latest_checkpoint_v0.pt'),
]

In [None]:
# Cache switches for the attack-sweep cell below.
save_results = True
load_results = True

# Constructor kwargs for each foolbox attack. Both use an absolute step size, and
# stepsize * steps (0.8 for LinfPGD, about 6 for L2PGD) exceeds the largest
# epsilon in the sweep each attack is paired with.
attack_params = {
    'LinfPGD': {'abs_stepsize': 1 / 500, 'steps': 400},
    'L2PGD': {'abs_stepsize': 0.01, 'steps': 600},
}

# Allowed L-infinity perturbation sizes, in input units.
linf_epsilons = [
    0.0, 0.03, 0.06, 0.09, 0.1, 0.13,
    0.16, 0.19, 0.2, 0.23, 0.26, 0.3,
]

# The L2 sweep reuses the same grid scaled by 10.
l2_epsilons = [eps * 10 for eps in linf_epsilons]

# (attack, epsilons) pairs: each attack is swept over the budgets of its own norm.
# Further foolbox attacks (e.g. fa.FGSM(), fa.LinfDeepFoolAttack()) can be added
# by appending to both lists below.
attacks = list(zip(
    [fa.LinfPGD(**attack_params['LinfPGD']), fa.L2PGD(**attack_params['L2PGD'])],
    [linf_epsilons, l2_epsilons],
))

# Attack results are cached here as one dict per model (see `out_dict` below).
results_path = 'adv_attack_results.npz'
# Display names in the same order as `attack_results`. Defined on both the load
# and the compute path because the plotting cell needs them either way.
model_types = ['CNN', 'MLP', 'LCA+SLP']

if load_results and os.path.exists(results_path):
    # np.savez stored the list of result dicts as an object array, which can only
    # be read back with allow_pickle=True.
    # NOTE(review): unpickling is unsafe for untrusted files; only load results
    # written by this notebook.
    with np.load(results_path, allow_pickle=True) as results_file:
        attack_results = results_file['data']
else:
    if load_results:
        print(f'{results_path} not found, recomputing the attack results')
    fmodel_cnn = create_mnist_fb()
    fmodel_mlp, test_loader, batch_size, device = create_mnist_dsc(log_files[0], cp_latest_filenames[0])
    fmodel_lca = create_mnist_dsc(log_files[1], cp_latest_filenames[1])[0]
    fmodel_list = [fmodel_cnn, fmodel_mlp, fmodel_lca]

    # The success array has one epsilon axis shared by all attacks, so every
    # attack must sweep the same number of epsilons.
    num_epsilons = len(attacks[0][1])
    assert all(len(eps_list) == num_epsilons for _, eps_list in attacks), \
        'all attacks must use the same number of epsilons'
    attack_types = [type(attack).__name__ for attack, _ in attacks]
    epsilon_lists = [eps_list for _, eps_list in attacks]

    num_batches = 5  # len(test_loader.dataset) // batch_size for the full test set

    attack_results = []
    for model_index, (fmodel, model_type) in enumerate(zip(fmodel_list, model_types)):
        attack_success = np.zeros(
            (len(attacks), num_epsilons, num_batches, batch_size), dtype=bool)
        for batch_index, (data, target) in enumerate(test_loader):
            if batch_index >= num_batches:
                break  # don't iterate (and load) the rest of the test set
            data = data.to(device)
            target = target.to(device)
            images, labels = ep.astensors(data, target)
            if model_type == 'CNN':
                # The CNN takes image tensors with a single channel axis at dim 1.
                images = images.squeeze().expand_dims(axis=1)
            else:
                # The DSC models take flattened 784-element inputs.
                images = images.reshape((batch_size, 784))
            print('\n', '~' * 79)
            print(f'Model type: {model_type} [{model_index+1} out of {len(model_types)}]')
            print(f'Batch {batch_index+1} out of {num_batches}')
            print(f'accuracy {accuracy(fmodel, images, labels)}')
            for attack_index, (attack, epsilons) in enumerate(attacks):
                advs, inputs, success = attack(fmodel, images, labels, epsilons=epsilons)
                assert success.shape == (len(epsilons), len(images))
                success_ = success.numpy()
                assert success_.dtype == bool
                attack_success[attack_index, :, batch_index, :] = success_
                print('\n', attack)
                # Robust accuracy per epsilon; round after subtracting so the
                # printed values are not polluted by float error.
                print('  ', (1.0 - success_.mean(axis=-1)).round(2))
            # Worst case per sample: counted as broken if any attack succeeded.
            robust_accuracy = 1.0 - attack_success[:, :, batch_index, :].max(axis=0).mean(axis=-1)
            print('\n', 'worst case (best attack per-sample)')
            print('  ', robust_accuracy.round(2))
        attack_success = attack_success.reshape(
            (len(attacks), num_epsilons, num_batches * batch_size))
        out_dict = {
            'adversarial_analysis': attack_success,
            'attack_types': attack_types,
            'epsilons': epsilon_lists,
            'attack_params': attack_params}
        attack_results.append(out_dict)
        if save_results:
            # Re-saved after each model so finished models survive an interrupted run.
            np.savez(results_path, data=attack_results)

In [None]:
plot_abs = False

if plot_abs:
    # Optional Linf accuracy-vs-distortion curves from an analysis-by-synthesis
    # checkout: a dict mapping model name -> {'x': ..., 'y': ...}.
    abs_filename = os.path.join(
        ROOT_DIR, 'analysis-by-synthesis', 'figures', 'Linf_accuracy_distortion_curves.pickle')
    # NOTE(review): pickle.load can execute arbitrary code; only use a trusted file.
    with open(abs_filename, 'rb') as abs_file:
        abs_linf_pgd_accuracies = pickle.load(abs_file)

# One panel per attack; each panel gets one accuracy-vs-epsilon curve per model.
fig, axes = plot.subplots(ncols=len(attacks), nrows=1)
handles = []  # legend entries, collected from the first panel only
for model_idx, (results_dict, model_type) in enumerate(zip(attack_results, model_types)):
    for attack_idx in range(len(attacks)):
        ax = axes[attack_idx]
        x_vals = results_dict['epsilons'][attack_idx]
        # Accuracy (%) is the share of samples on which the attack failed.
        y_vals = 100 * (1.0 - results_dict['adversarial_analysis'][attack_idx, ...].mean(axis=-1))
        handle = ax.plot(x_vals, y_vals, label=model_type)
        first_panel = attack_idx == 0
        if first_panel:
            handles.extend(handle)
        # The ABS reference curves are Linf results: draw them once, on the first panel.
        if plot_abs and first_panel and model_idx == 0:
            for abs_model_type, abs_model_accuracy in abs_linf_pgd_accuracies.items():
                if abs_model_type in ['Binary CNN', 'Nearest Neighbor', 'Binary ABS', 'CNN']:
                    continue
                handle = ax.plot(
                    abs_model_accuracy['x'], abs_model_accuracy['y'], label=abs_model_type)
                handles.extend(handle)
        ax.format(
            title=results_dict['attack_types'][attack_idx],
            xlabel='Maximum perturbation size',
            xlim=[0.0, np.max(x_vals)],
            ylim=[0, 100])
axes.format(ylabel='Model accuracy')
fig.legend(handles, ncols=1, frame=False, label='Model type', loc='right')
plot.show()