In [None]:
import os
import sys
from typing import Union, Any, Optional, Callable, Tuple

ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd()))
if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
    
import pickle
import numpy as np
import pandas as pd
import proplot as plot
import torch
import torch.nn as nn
import torchvision.models as models

from DeepSparseCoding.utils.file_utils import Logger
import DeepSparseCoding.utils.loaders as loaders
import DeepSparseCoding.utils.run_utils as run_utils
import DeepSparseCoding.utils.dataset_utils as dataset_utils
import DeepSparseCoding.utils.run_utils as ru
import DeepSparseCoding.utils.plot_functions as pf
import DeepSparseCoding.utils.data_processing as dp

import eagerpy as ep
from foolbox import PyTorchModel, accuracy, samples
import foolbox.attacks as fa
from foolbox.types import Bounds
from foolbox.attacks.base import T
from foolbox.models.base import Model
from foolbox.criteria import Misclassification
from foolbox.attacks.base import raise_if_kwargs
from foolbox.attacks.base import get_criterion
from foolbox.attacks.projected_gradient_descent import LinfProjectedGradientDescentAttack

In [None]:
def create_mnist_fb() -> PyTorchModel:
    """Build the reference MNIST CNN and wrap it as a foolbox model.

    The network weights are read from ``<ROOT_DIR>/DeepSparseCoding/mnist_cnn.pth``.
    No mean/std preprocessing is attached to the wrapper, so inputs are
    expected to already lie in the ``(0, 1)`` bounds passed to foolbox.

    Returns:
        A ``PyTorchModel`` wrapping the CNN in evaluation mode.
    """
    model = nn.Sequential(
        nn.Conv2d(1, 32, 3),
        nn.ReLU(),
        nn.Conv2d(32, 64, 3),
        nn.ReLU(),
        nn.MaxPool2d(2),
        nn.Dropout2d(0.25),
        nn.Flatten(),  # type: ignore
        # 9216 = 64 channels * 12 * 12, i.e. what a 1x28x28 input yields after
        # two unpadded 3x3 convolutions and a 2x2 max-pool.
        nn.Linear(9216, 128),
        nn.ReLU(),
        nn.Dropout2d(0.5),
        nn.Linear(128, 10),
    )
    path = os.path.join(ROOT_DIR, 'DeepSparseCoding', 'mnist_cnn.pth')
    # Deserialize onto the CPU so a checkpoint that was saved from a GPU can
    # still be read on a CPU-only machine; load_state_dict copies the values
    # into the (CPU) module parameters either way.
    state_dict = torch.load(path, map_location='cpu')  # type: ignore
    model.load_state_dict(state_dict)
    model.eval()  # disable dropout for evaluation / attacks
    fmodel = PyTorchModel(model, bounds=(0, 1))
    return fmodel

In [None]:
def create_mnist_dsc(log_file, cp_file):
    """Restore a DeepSparseCoding model from its log/checkpoint and wrap it for foolbox.

    Args:
        log_file: Path to the training log; the last parameter set recorded in
            it is used to rebuild the model.
        cp_file: Path to the checkpoint, stored on the params as
            ``cp_latest_filename`` before ``model.load_checkpoint()`` is called.

    Returns:
        A tuple ``(fmodel, model, test_loader, batch_size, device)`` where
        ``fmodel`` is the foolbox ``PyTorchModel`` wrapper (with the underlying
        model also attached as ``fmodel.orig_model``), ``model`` is the restored
        model, ``test_loader`` is the test data loader, and ``batch_size`` /
        ``device`` are read from ``model.params``.
    """
    logger = Logger(log_file, overwrite=False)
    log_text = logger.load_file()
    params = logger.read_params(log_text)[-1]  # most recent entry in the log
    params.cp_latest_filename = cp_file
    # Keep the data in [0, 1] so it matches the foolbox bounds used below.
    params.standardize_data = False
    params.rescale_data_to_one = True
    _, _, test_loader, data_params = dataset_utils.load_dataset(params)
    for key, value in data_params.items():
        setattr(params, key, value)
    model = loaders.load_model(params.model_type)
    model.setup(params, logger)
    model.params.analysis_out_dir = os.path.join(
        model.params.model_out_dir, 'analysis', model.params.version)
    model.params.analysis_save_dir = os.path.join(model.params.analysis_out_dir, 'savefiles')
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(model.params.analysis_save_dir, exist_ok=True)
    model.to(params.device)
    model.load_checkpoint()
    fmodel = PyTorchModel(model.eval(), bounds=(0, 1))
    fmodel.orig_model = model
    return fmodel, model, test_loader, model.params.batch_size, model.params.device

In [None]:
# Training logs and latest checkpoints of the DeepSparseCoding runs to attack.
# Both lists follow the same order: MLP first, then LCA.
log_files = [
    os.path.join(ROOT_DIR, 'Torch_projects', run_name, 'logfiles', f'{run_name}_v0.log')
    for run_name in ('mlp_768_mnist', 'lca_768_mlp_mnist')
]

cp_latest_filenames = [
    os.path.join(
        ROOT_DIR, 'Torch_projects', run_name, 'checkpoints',
        f'{run_name}_latest_checkpoint_v0.pt')
    for run_name in ('mlp_768_mnist', 'lca_768_mlp_mnist')
]

In [None]:
# Instantiate the three classifiers under comparison. The data loader, batch
# size and device returned for the MLP are reused for every model below.
fmodel_mlp, dmodel_mlp, test_loader, batch_size, device = create_mnist_dsc(
    log_file=log_files[0], cp_file=cp_latest_filenames[0])
fmodel_mlp.type = 'MLP'

fmodel_cnn = create_mnist_fb()
fmodel_cnn.type = 'CNN'

# Only the foolbox wrapper and the underlying model are kept for the LCA run.
fmodel_lca, dmodel_lca = create_mnist_dsc(
    log_file=log_files[1], cp_file=cp_latest_filenames[1])[:2]
fmodel_lca.type = 'LCA'

# Maps the short model label to its foolbox wrapper (insertion order: MLP, CNN, LCA).
model_list = {
    wrapped.type: wrapped
    for wrapped in (fmodel_mlp, fmodel_cnn, fmodel_lca)
}

In [None]:
# Sanity check: run a single flattened test image (kept as a batch of one)
# through the MLP wrapper and display the shape of its output.
model_list['MLP'](
    next(iter(test_loader))[0]
    .reshape(batch_size, 784)
    .to(device)[:1]
).shape

In [None]:
# When True, previously saved attack results are loaded instead of recomputed.
load_results = False

# Number of test batches to attack per model
# (len(test_loader.dataset) // batch_size would cover the whole test set).
num_batches = 10

# Constructor keyword arguments for each foolbox attack, keyed by attack name.
# 'DDN' - best attack in terms of simplicity & effectiveness - Lukas
attack_params = dict(
    LinfPGD=dict(random_start=False, abs_stepsize=0.002, steps=500),
    L2PGD=dict(random_start=False, abs_stepsize=0.45, steps=2000),
)

# Allowed perturbation sizes for the Linf attack.
linf_epsilons = [
    0.0,
    0.03, 0.06, 0.09, 0.1,
    0.13, 0.16, 0.19, 0.2,
    0.23, 0.26, 0.29, 0.3,
    0.33, 0.36, 0.39, 0.4,
]

# The L2 attack uses the same grid scaled by a factor of ten.
l2_epsilons = [10 * linf_eps for linf_eps in linf_epsilons]

# Each entry pairs an attack instance with the epsilons it is evaluated at.
attacks = [
    (getattr(fa, attack_name)(**attack_params[attack_name]), attack_epsilons)
    for attack_name, attack_epsilons in (
        ('LinfPGD', linf_epsilons),
        ('L2PGD', l2_epsilons),
    )
]

In [None]:
# Run (or reload) the adversarial attacks for every model in model_list.
# attack_results ends up with one dict per model, in model_list order.
if load_results:
    # NOTE: allow_pickle=True unpickles arbitrary Python objects; only load
    # result files that were produced by this notebook.
    attack_results = []
    for model_type in model_list:
        attack_results.append(
            np.load(f'adv_attack_results_{model_type}.npz', allow_pickle=True)['data'].item())
else:
    # The success array has a single epsilon axis shared by all attacks, so
    # every attack must be evaluated at the same number of epsilons.
    num_epsilons = len(linf_epsilons)
    assert all(len(epsilons) == num_epsilons for _, epsilons in attacks), \
        'all attacks must use the same number of epsilons'
    attack_results = []
    for model_index, (model_type, fmodel) in enumerate(model_list.items()):
        # attack_success[a, e, b, i] is True when attack a at epsilon e fooled
        # the model on sample i of batch b. The builtin bool replaces the
        # np.bool alias, which no longer exists in current NumPy releases.
        attack_success = np.zeros(
            (len(attacks), num_epsilons, num_batches, batch_size), dtype=bool)
        for batch_index, (data, target) in enumerate(test_loader):
            if batch_index >= num_batches:
                # Stop instead of iterating over the remaining, unused batches.
                break
            data = data.to(device)
            target = target.to(device)
            images, labels = ep.astensors(*(data, target))
            if model_type == 'CNN':
                # The CNN expects a channel axis at position 1.
                images = images.squeeze().expand_dims(axis=1)
            else:
                # The other models consume flattened images.
                images = images.reshape((batch_size, 784))
            print('\n', '~' * 79)
            print(f'Model type: {model_type} [{model_index+1} out of {len(model_list)}]')
            print(f'Batch {batch_index+1} out of {num_batches}')
            print(f'accuracy {accuracy(fmodel, images, labels)}')
            for attack_index, (attack, epsilons) in enumerate(attacks):
                advs, _, success = attack(fmodel, images, labels, epsilons=epsilons)
                assert success.shape == (len(epsilons), len(images))
                success_ = success.numpy()
                assert success_.dtype == bool
                attack_success[attack_index, :, batch_index, :] = success_
                print('\n', attack)
                # Per-epsilon accuracy under this attack for the current batch.
                print('  ', 1.0 - success_.mean(axis=-1).round(2))
            # Worst case over attacks (a sample counts as fooled if any attack succeeded).
            robust_accuracy = 1.0 - attack_success[:, :, batch_index, :].max(axis=0).mean(axis=-1)
        # Merge the batch and sample axes; use the allocation-time epsilon count
        # rather than a variable leaked from the inner attack loop.
        attack_success = attack_success.reshape(
            (len(attacks), num_epsilons, num_batches * batch_size))
        attack_types = []
        epsilon_list = []
        for attack, epsilons in attacks:
            attack_types.append(type(attack).__name__)
            epsilon_list.append(epsilons)
        out_dict = {
            'num_batches':num_batches,
            'batch_size':batch_size,
            'adversarial_analysis':attack_success,
            'attack_types':attack_types,
            'epsilons':epsilon_list,
            'attack_params':attack_params}
        attack_results.append(out_dict)
        np.savez(f'adv_attack_results_{model_type}.npz', data=out_dict)

In [None]:
# Plot accuracy versus maximum perturbation size, one panel per attack.
# When plot_abs is set, reference curves read from the analysis-by-synthesis
# figures directory are overlaid (dashed) on the first panel.
plot_abs = True

if plot_abs:
    abs_filename = os.path.join(
        ROOT_DIR, 'analysis-by-synthesis', 'figures', 'Linf_accuracy_distortion_curves.pickle')
    # NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
    with open(abs_filename, 'rb') as file:
        abs_linf_pgd_accuracies = pickle.load(file)

fig, axes = plot.subplots(ncols=len(attacks), nrows=1)
handles = []
for model_idx, (results_dict, model_type) in enumerate(zip(attack_results, model_list)):
    for attack_idx in range(len(attacks)):
        # Fraction of samples that were NOT fooled, per epsilon, in percent.
        score = results_dict['adversarial_analysis'][attack_idx, ...]
        attack_accuracy = 1.0 - score.mean(axis=-1)
        y_vals = 100*attack_accuracy
        x_vals = results_dict['epsilons'][attack_idx]
        handle = axes[attack_idx].plot(x_vals, y_vals, label=model_type)
        # Legend entries are collected from the first panel only.
        if attack_idx == 0:
            handles.extend(handle)
        # Reference curves are drawn once, on the first panel.
        if plot_abs and model_idx == 0 and attack_idx == 0:
            for abs_model_type, abs_model_accuracy in abs_linf_pgd_accuracies.items():
                if abs_model_type in ('Binary CNN', 'Nearest Neighbor', 'Binary ABS', 'CNN'):
                    continue
                handle = axes[attack_idx].plot(
                    abs_model_accuracy['x'], abs_model_accuracy['y'], '--', label=abs_model_type)
                handles.extend(handle)
        axes[attack_idx].format(
            title=results_dict['attack_types'][attack_idx],
            xlabel='Maximum perturbation size',
            xlim=[0.0, np.max(x_vals)],
            ylim=[0, 100])
axes.format(suptitle='Standard Robustness Evaluation', ylabel='Model accuracy')
fig.legend(handles, ncols=1, frame=False, label='Model type', loc='right')
plot.show()

In [None]:
# Clean (unattacked) accuracy and softmax confidences over the whole test set.
test_accuracies = []   # one accuracy value per model, in model_list order
test_confidences = []  # one (num_samples, num_classes) array per model
for model_type, fmodel in model_list.items():
    batch_confidences = []
    batch_accuracies = []
    batch_lengths = []
    for batch_index, (data, target) in enumerate(test_loader):
        data = data.to(device)
        target = target.to(device)
        images, labels = ep.astensors(*(data, target))
        if model_type == 'CNN':
            # The CNN expects a channel axis at position 1.
            images = images.squeeze().expand_dims(axis=1)
        else:
            # Infer the batch axis so a smaller final batch is handled too.
            images = images.reshape((-1, 784))
        logits = fmodel(images)
        # Apply torch's softmax to the underlying torch tensor rather than to
        # the eagerpy wrapper returned by the foolbox model.
        confidence = torch.nn.functional.softmax(logits.raw, dim=-1)
        batch_confidences.append(confidence.detach().cpu().numpy())
        batch_accuracies.append(accuracy(fmodel, images, labels))
        batch_lengths.append(len(images))
    # Concatenate along the sample axis instead of stacking and reshaping, so
    # the configuration globals num_batches / batch_size used by the attack
    # cell are not overwritten here.
    all_confidences = np.concatenate(batch_confidences, axis=0)
    all_accuracies = np.asarray(batch_accuracies)
    num_classes = all_confidences.shape[-1]
    test_confidences.append(all_confidences)
    # Weight each batch by its size; equals the plain mean for equal-sized batches.
    test_accuracies.append(np.average(all_accuracies, weights=batch_lengths))