In [None]:
import os
import sys

# Strip trailing path components until the path no longer mentions
# 'DeepSparseCoding'; the resulting directory is added to sys.path so the
# `DeepSparseCoding.*` imports below can be resolved.
ROOT_DIR = os.getcwd()
while 'DeepSparseCoding' in ROOT_DIR:
    ROOT_DIR = os.path.dirname(ROOT_DIR)
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)
    
import pickle
import numpy as np
import pandas as pd
import proplot as plot
import torch
import torch.nn as nn
import torchvision.models as models

from DeepSparseCoding.utils.file_utils import Logger
import DeepSparseCoding.utils.loaders as loaders
import DeepSparseCoding.utils.run_utils as run_utils
import DeepSparseCoding.utils.dataset_utils as dataset_utils
import DeepSparseCoding.utils.run_utils as ru
import DeepSparseCoding.utils.plot_functions as pf
import DeepSparseCoding.utils.data_processing as dp

import eagerpy as ep
from foolbox import PyTorchModel, accuracy, samples
import foolbox.attacks as fa

In [None]:
def create_mnist_fb() -> PyTorchModel:
    """
    Build the MNIST CNN, load its saved weights and wrap it for foolbox.

    The weights are read from `<ROOT_DIR>/DeepSparseCoding/mnist_cnn.pth`.
    The last layer is a plain `nn.Linear`, so the network emits unnormalized
    class scores (no softmax layer is appended).

    Outputs:
        fmodel: [PyTorchModel] the network in eval mode, declared with input
            bounds (0, 1) and without any foolbox preprocessing
    """
    model = nn.Sequential(
        nn.Conv2d(1, 32, 3),
        nn.ReLU(),
        nn.Conv2d(32, 64, 3),
        nn.ReLU(),
        nn.MaxPool2d(2),
        nn.Dropout2d(0.25),
        nn.Flatten(),  # type: ignore
        nn.Linear(9216, 128),
        nn.ReLU(),
        # Activations are flat after nn.Flatten/nn.Linear, so element-wise
        # dropout is the matching layer. Dropout layers hold no parameters,
        # so the state_dict keys loaded below are unaffected.
        nn.Dropout(0.5),
        nn.Linear(128, 10),
    )
    path = os.path.join(ROOT_DIR, 'DeepSparseCoding', 'mnist_cnn.pth')
    # map_location keeps the load working on CPU-only machines even when the
    # checkpoint was written from a GPU.
    state_dict = torch.load(path, map_location='cpu')  # type: ignore
    model.load_state_dict(state_dict)
    model.eval()
    # Input normalization (mean=0.1307, std=0.3081) is left disabled.
    fmodel = PyTorchModel(model, bounds=(0, 1))
    return fmodel

In [None]:
def create_mnist_dsc(log_file, cp_file):
    """
    Rebuild a trained DeepSparseCoding model from its log file and checkpoint.

    Inputs:
        log_file: path of the training log; the last parameter set read from
            it is used to configure the model
        cp_file: checkpoint path, stored in params.cp_latest_filename before
            the checkpoint is loaded
    Outputs:
        fmodel: [PyTorchModel] foolbox wrapper around the model in eval mode,
            declared with input bounds (0, 1)
        model: the DeepSparseCoding model returned by loaders.load_model
        test_loader: third loader returned by dataset_utils.load_dataset
        batch_size: model.params.batch_size
        device: model.params.device
    Side effects:
        creates <model_out_dir>/analysis/<version>/savefiles if it is missing
    """
    logger = Logger(log_file, overwrite=False)
    log_text = logger.load_file()
    params = logger.read_params(log_text)[-1]
    params.cp_latest_filename = cp_file
    # Override the logged data options: no standardization, rescaling enabled
    # (presumably so inputs fall inside the (0, 1) bounds declared below).
    params.standardize_data = False
    params.rescale_data_to_one = True
    # Only the test split is needed here; the train/val loaders are discarded.
    _, _, test_loader, data_params = dataset_utils.load_dataset(params)
    for key, value in data_params.items():
        setattr(params, key, value)
    model = loaders.load_model(params.model_type)
    model.setup(params, logger)
    model.params.analysis_out_dir = os.path.join(
        model.params.model_out_dir, 'analysis', model.params.version)
    model.params.analysis_save_dir = os.path.join(model.params.analysis_out_dir, 'savefiles')
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs
    os.makedirs(model.params.analysis_save_dir, exist_ok=True)
    model.to(params.device)
    model.load_checkpoint()
    fmodel = PyTorchModel(model.eval(), bounds=(0, 1))
    return fmodel, model, test_loader, model.params.batch_size, model.params.device

In [None]:
# Index 0 is the MLP model and index 1 the LCA model in both lists, so a log
# file and its checkpoint share the same position.
log_files = [
    os.path.join(ROOT_DIR, 'Torch_projects', 'mlp_768_mnist', 'logfiles', 'mlp_768_mnist_v0.log'),
    os.path.join(ROOT_DIR, 'Torch_projects', 'lca_768_mlp_mnist', 'logfiles', 'lca_768_mlp_mnist_v0.log'),
]

cp_latest_filenames = [
    os.path.join(
        ROOT_DIR, 'Torch_projects', 'mlp_768_mnist', 'checkpoints',
        'mlp_768_mnist_latest_checkpoint_v0.pt'),
    os.path.join(
        ROOT_DIR, 'Torch_projects', 'lca_768_mlp_mnist', 'checkpoints',
        'lca_768_mlp_mnist_latest_checkpoint_v0.pt'),
]

In [None]:
# The MLP call also supplies the test loader, batch size and device that the
# later cells use for every model.
fmodel_mlp, dmodel_mlp, test_loader, batch_size, device = create_mnist_dsc(
    log_files[0], cp_latest_filenames[0])
fmodel_cnn = create_mnist_fb()
fmodel_lca, dmodel_lca = create_mnist_dsc(log_files[1], cp_latest_filenames[1])[:2]

# Maps a short model name to its foolbox wrapper; the same name is attached to
# the wrapper as `.type` so code holding only the wrapper can identify it.
model_list = {'MLP': fmodel_mlp, 'CNN': fmodel_cnn, 'LCA': fmodel_lca}
for model_type, fmodel in model_list.items():
    fmodel.type = model_type

In [None]:
def plot_weights(weights, title="", figsize=None):
    """
    Show every weight matrix in its own panel, all on one shared gray scale.

    Inputs:
        weights: [np.ndarray] of shape [num_outputs, num_input_y, num_input_x]
        title: [str] figure title
        figsize: forwarded unchanged to the subplots call
    Outputs:
        fig: the figure that was drawn

    The weights are plotted as given (no renormalization); the color range of
    every panel is the global min/max of `weights`.
    """
    vmin = np.min(weights)
    vmax = np.max(weights)
    num_plots = weights.shape[0]
    # (ceil(sqrt(n)) + 1) rows by floor(sqrt(n)) columns always has >= n panels
    num_plots_y = int(np.ceil(np.sqrt(num_plots))+1)
    num_plots_x = int(np.floor(np.sqrt(num_plots)))
    # Use the plotting names this notebook actually imports (`plot` is proplot,
    # `pf` is the project's plot_functions); the bare `plt` / `clear_axis`
    # names were never defined here and raised NameError.
    fig, sub_ax = plot.subplots(nrows=num_plots_y, ncols=num_plots_x, figsize=figsize)
    for filter_index, axis in enumerate(sub_ax):
        # Panels beyond the last filter stay empty but are still cleaned up.
        if filter_index < num_plots:
            axis.imshow(
                np.squeeze(weights[filter_index, ...]), vmin=vmin, vmax=vmax, cmap="Greys_r")
        pf.clear_axis(axis)
        axis.set_aspect("equal")
    fig.suptitle(title, y=0.95, x=0.5, fontsize=20)
    plot.show()
    return fig

In [None]:
def pad_data(data, pad_values=1):
    """
    Tile a stack of filters into one square mosaic for visualization.
    Inputs:
        data: [np.ndarray] with at least 3 dimensions; axis 0 indexes the
            filters, axes 1 and 2 are the filter rows and columns
        pad_values: [int] value used for the 1-element border around each
            filter and for the filler tiles that complete the square grid
    Outputs:
        mosaic of shape (n * (rows + 2), n * (cols + 2)) + data.shape[3:],
        where n = ceil(sqrt(number of filters))
    """
    num_filters = data.shape[0]
    grid_side = int(np.ceil(np.sqrt(num_filters)))
    # Axis 0: filler tiles to reach grid_side**2; axes 1-2: one-element border.
    pad_widths = [(0, grid_side ** 2 - num_filters), (1, 1), (1, 1)]
    # Any further axes (e.g. channels) are left unpadded.
    pad_widths.extend([(0, 0)] * (data.ndim - 3))
    tiles = np.pad(data, pad_widths, mode="constant", constant_values=pad_values)
    # (grid_row, grid_col, tile_row, tile_col, ...) ->
    # (grid_row, tile_row, grid_col, tile_col, ...)
    grid = tiles.reshape((grid_side, grid_side) + tiles.shape[1:])
    axis_order = (0, 2, 1, 3) + tuple(range(4, grid.ndim))
    grid = grid.transpose(axis_order)
    mosaic_shape = (grid_side * grid.shape[1], grid_side * grid.shape[3]) + grid.shape[4:]
    return grid.reshape(mosaic_shape)

def normalize_data_with_max(data):
    """
    Scale data so that its largest absolute value becomes 1.
    Inputs:
        data: [np.ndarray] data to be normalized
    Outputs:
        norm_data: [np.ndarray] data divided by max(abs(data)); the input
            object itself is returned when that maximum is not positive
        data_max: [float] the maximum absolute value that was divided out
    """
    data_max = np.max(np.abs(data))
    if not data_max > 0:
        # Nothing to divide by: hand the input back untouched.
        return data, data_max
    return data / data_max, data_max

In [None]:
# First-layer MLP weights, one row per neuron, shown as a mosaic of square
# patches (the pixel count is assumed to be a perfect square).
weights = dmodel_mlp.fc0_w.cpu().detach().numpy()
num_neurons, num_pixels = weights.shape
weights = weights.reshape(
    (num_neurons,) + (int(np.sqrt(num_pixels)),) * 2 + (1,))
weights = pad_data(normalize_data_with_max(weights)[0])

fig, ax = plot.subplots(figsize=(10, 10))
ax = pf.clear_axis(ax)
axis_image = ax.imshow(np.squeeze(weights), cmap='greys_r', interpolation="nearest")
plot.show()

In [None]:
# LCA dictionary, transposed so that each row is one element, shown as a
# mosaic of square patches (the pixel count is assumed to be a perfect square).
weights = dmodel_lca.lca.w.cpu().detach().numpy().T
num_plots, num_pixels = weights.shape
weights = weights.reshape(
    (num_plots,) + (int(np.sqrt(num_pixels)),) * 2 + (1,))
weights = pad_data(normalize_data_with_max(weights)[0])

fig, ax = plot.subplots(figsize=(10, 10))
ax = pf.clear_axis(ax)
axis_image = ax.imshow(np.squeeze(weights), cmap='greys_r', interpolation="nearest")
plot.show()

In [None]:
# save_results: write freshly computed attack results to .npz files.
# load_results: read results from those files instead of running the attacks.
save_results = True
load_results = True

# Number of test batches attacked per model
# (full test set: len(test_loader.dataset) // batch_size).
num_batches = 10

# Keyword arguments for each foolbox attack constructor.
attack_params = {
    'LinfPGD': {'random_start': False, 'abs_stepsize': 0.002, 'steps': 500},
    'L2PGD': {'random_start': False, 'abs_stepsize': 0.45, 'steps': 2000},
}

# Allowed perturbation sizes for the Linf attack.
linf_epsilons = [
    0.0, 0.03, 0.06, 0.09,
    0.1, 0.13, 0.16, 0.19,
    0.2, 0.23, 0.26, 0.29,
    0.3, 0.33, 0.36, 0.39,
    0.4,
]

# The L2 attack uses the same grid scaled by 10.
l2_epsilons = [10 * eps for eps in linf_epsilons]

# (attack, epsilons) pairs evaluated for every model. Other attacks that were
# considered but are not run: fa.FGSM, fa.LinfBasicIterativeAttack,
# fa.LinfAdditiveUniformNoiseAttack, fa.LinfDeepFoolAttack.
attacks = [
    (fa.LinfPGD(**attack_params['LinfPGD']), linf_epsilons),
    (fa.L2PGD(**attack_params['L2PGD']), l2_epsilons),
]

In [None]:
# Produces `attack_results`: one dict per model, in model_list order, either
# read back from adv_attack_results_<model>.npz or computed by running every
# configured attack on the first `num_batches` test batches.
attack_results = []
if load_results:
    for model_type in model_list:
        # The results dict is stored as a pickled object array, hence
        # allow_pickle; only load files this notebook wrote itself.
        with np.load(f'adv_attack_results_{model_type}.npz', allow_pickle=True) as cached:
            attack_results.append(cached['data'].item())
else:
    # All attacks share one success array, so each must use the same number
    # of epsilons.
    num_epsilons = len(linf_epsilons)
    assert all(len(attack_epsilons) == num_epsilons for _, attack_epsilons in attacks)
    for model_index, (model_type, fmodel) in enumerate(model_list.items()):
        # attack_success[a, e, b, i] is True when attack a at epsilon e fooled
        # the model on image i of batch b. (np.bool no longer exists in
        # current NumPy releases; the builtin bool is the same dtype.)
        attack_success = np.zeros(
            (len(attacks), num_epsilons, num_batches, batch_size), dtype=bool)
        for batch_index, (data, target) in enumerate(test_loader):
            if batch_index >= num_batches:
                break  # do not keep loading batches that will not be attacked
            data = data.to(device)
            target = target.to(device)
            images, labels = ep.astensors(data, target)
            if model_type == 'CNN':
                # the CNN takes (batch, 1, height, width) images
                images = images.squeeze().expand_dims(axis=1)
            else:
                # the other models take flattened 784-pixel vectors
                images = images.reshape((batch_size, 784))
            print('\n', '~' * 79)
            print(f'Model type: {model_type} [{model_index+1} out of {len(model_list)}]')
            print(f'Batch {batch_index+1} out of {num_batches}')
            print(f'accuracy {accuracy(fmodel, images, labels)}')
            for attack_index, (attack, epsilons) in enumerate(attacks):
                advs, _, success = attack(fmodel, images, labels, epsilons=epsilons)
                assert success.shape == (len(epsilons), len(images))
                success_ = success.numpy()
                assert success_.dtype == np.bool_
                attack_success[attack_index, :, batch_index, :] = success_
                print('\n', attack)
                print('  ', 1.0 - success_.mean(axis=-1).round(2))
        # Merge the batch and image axes into a single per-image axis.
        attack_success = attack_success.reshape(
            (len(attacks), num_epsilons, num_batches * batch_size))
        attack_types = [type(attack).__name__ for attack, _ in attacks]
        epsilon_list = [attack_epsilons for _, attack_epsilons in attacks]
        out_dict = {
            'num_batches':num_batches,
            'batch_size':batch_size,
            'adversarial_analysis':attack_success,
            'attack_types':attack_types,
            'epsilons':epsilon_list,
            'attack_params':attack_params}
        attack_results.append(out_dict)
        if save_results:
            np.savez(f'adv_attack_results_{model_type}.npz', data=out_dict)

In [None]:
# When True, reference curves from the analysis-by-synthesis pickle are drawn
# on the first panel next to this notebook's models.
plot_abs = False

if plot_abs:
    abs_filename = os.path.join(
        ROOT_DIR, 'analysis-by-synthesis', 'figures', 'Linf_accuracy_distortion_curves.pickle')
    # NOTE(review): pickle.load executes arbitrary code; only point this at a
    # trusted file.
    with open(abs_filename, 'rb') as file:
        abs_linf_pgd_accuracies = pickle.load(file)

# One panel per attack, one line per model.
fig, axes = plot.subplots(ncols=len(attacks), nrows=1)
handles = []
for model_idx, (results_dict, model_type) in enumerate(zip(attack_results, model_list.keys())):
    for attack_idx in range(len(attacks)):
        panel = axes[attack_idx]
        score = results_dict['adversarial_analysis'][attack_idx, ...]
        # Accuracy under attack = share of images the attack failed on.
        attack_accuracy = 1.0 - score.mean(axis=-1)
        y_vals = 100*attack_accuracy
        x_vals = results_dict['epsilons'][attack_idx]
        handle = panel.plot(x_vals, y_vals, label=model_type)
        # Legend entries are collected from the first panel only.
        if attack_idx == 0:
            handles.extend(handle)
        if plot_abs and model_idx == 0 and attack_idx == 0:
            for abs_model_type, abs_model_accuracy in abs_linf_pgd_accuracies.items():
                if abs_model_type in ('Binary CNN', 'Nearest Neighbor', 'Binary ABS', 'CNN'):
                    continue
                handle = panel.plot(
                    abs_model_accuracy['x'], abs_model_accuracy['y'], label=abs_model_type)
                handles.extend(handle)
        panel.format(title=results_dict['attack_types'][attack_idx])
        panel.format(
            xlabel='Maximum perturbation size',
            xlim=[0.0, np.max(x_vals)],
            ylim=[0, 100])
axes.format(ylabel='Model accuracy')
fig.legend(handles, ncols=1, frame=False, label='Model type', loc='right')
plot.show()

In [None]:
from typing import Union, Any, Optional, Callable, Tuple

from foolbox.types import Bounds
from foolbox.attacks.base import T
from foolbox.models.base import Model
from foolbox.criteria import Misclassification
from foolbox.attacks.base import raise_if_kwargs
from foolbox.attacks.base import get_criterion
from foolbox.attacks.projected_gradient_descent import LinfProjectedGradientDescentAttack

class LinfProjectedGradientDescentAttackWithStopping(LinfProjectedGradientDescentAttack):
    """
    Linf PGD that records when each image first becomes confidently adversarial.

    The attack is untargeted: after every gradient step the softmax of the
    model output is taken, the true-label entry is zeroed, and the largest
    remaining probability is the image's "adversarial confidence". The first
    time that value exceeds `confidence_threshold`, the perturbed image, the
    step index and the confidence are stored for that image. Iteration stops
    when every image has been stored or after `steps` gradient steps.

    Each call to `run` writes its bookkeeping to
    `confidence_attack_<model.type>.npz` in the working directory.
    """
    def __init__(
        self,
        *,
        rel_stepsize: float = 0.025,
        abs_stepsize: Optional[float] = None,
        steps: int = 50,
        random_start: bool = True,
        confidence_threshold: float = 0.9,
    ):
        """
        Inputs:
            rel_stepsize: step size as a fraction of epsilon; used only when
                abs_stepsize is None
            abs_stepsize: absolute step size; takes precedence when given
            steps: maximum number of gradient steps
            random_start: start from self.get_random_start(x0, epsilon)
                instead of the clean input
            confidence_threshold: adversarial confidence an image must exceed
                to be recorded as successful
        """
        super().__init__(
            rel_stepsize=rel_stepsize,
            abs_stepsize=abs_stepsize,
            steps=steps,
            random_start=random_start,
        )
        self.confidence_threshold = confidence_threshold

    def normalize(
        self, gradients: ep.Tensor, *, x: ep.Tensor, bounds: Bounds
    ) -> ep.Tensor:
        """Return the element-wise sign of the gradients."""
        return gradients.sign()

    def run(
        self,
        model: Model,
        inputs: T,
        criterion: Union[Misclassification, T],
        *,
        epsilon: float,
        **kwargs: Any,
    ) -> T:
        """
        Run the attack on one batch and save the per-image bookkeeping.

        Inputs:
            model: foolbox model; its `type` attribute, when present, names
                the output file
            inputs: batch of clean inputs
            criterion: Misclassification criterion or the true labels
            epsilon: radius passed to self.project
        Outputs:
            the perturbed batch after the last executed step, in the same
            tensor type as `inputs`
        Raises:
            ValueError: if the criterion is not a Misclassification
        """
        raise_if_kwargs(kwargs)
        x0, restore_type = ep.astensor_(inputs)
        criterion_ = get_criterion(criterion)
        del inputs, criterion, kwargs

        if not isinstance(criterion_, Misclassification):
            raise ValueError("unsupported criterion")

        labels = criterion_.labels
        loss_fn = self.get_loss_fn(model, labels)

        if self.abs_stepsize is None:
            stepsize = self.rel_stepsize * epsilon
        else:
            stepsize = self.abs_stepsize

        if self.random_start:
            x = self.get_random_start(x0, epsilon)
            x = ep.clip(x, *model.bounds)
        else:
            x = x0

        confidence_threshold = self.confidence_threshold
        # Fall back to the class name so models without a `.type` tag still work.
        model_name = getattr(model, 'type', type(model).__name__)
        batch_size = x.shape[0]
        batch_indices = np.arange(batch_size)
        labels_np = labels.numpy()

        store_x = np.zeros_like(x0.numpy())
        store_time_step = -1*np.ones(batch_size, dtype=np.int32)  # -1 = never reached the threshold
        store_confidence = np.zeros(batch_size, dtype=np.float32)
        adversarial_confidence = np.zeros(batch_size, dtype=np.float32)
        # Boolean mask of images already recorded; replaces rebuilding a set
        # of indices on every membership test.
        is_stored = np.zeros(batch_size, dtype=bool)
        time_step = 0
        # Bounding the loop on time_step runs exactly `steps` iterations at
        # most and also terminates for steps <= 1.
        while time_step < self.steps and not is_stored.all():
            _, gradients = self.value_and_grad(loss_fn, x)
            gradients = self.normalize(gradients=gradients, x=x, bounds=model.bounds)
            x = x + stepsize * gradients
            x = self.project(x, x0, epsilon)
            x = ep.clip(x, *model.bounds)

            # Untargeted confidence: highest probability among wrong labels.
            # (A targeted variant would read the probability at the target
            # label instead.)
            adversarial_outputs = ep.softmax(model(x)).numpy().copy()
            adversarial_outputs[batch_indices, labels_np] = 0
            adversarial_confidence = adversarial_outputs.max(axis=1)

            # Record only the first crossing of the threshold for each image.
            keep_mask = (adversarial_confidence > confidence_threshold) & ~is_stored
            if keep_mask.any():
                store_x[keep_mask, ...] = x.numpy()[keep_mask, ...]
                store_time_step[keep_mask] = time_step
                store_confidence[keep_mask] = adversarial_confidence[keep_mask]
                is_stored |= keep_mask
            time_step += 1

        remaining_indices = np.flatnonzero(~is_stored).astype(np.int32)
        num_failed = len(remaining_indices)
        if num_failed > 0:
            print(f'Max steps = {self.steps} reached for model {model_name}, {num_failed} images did not achieve adversarial confidence threshold of {confidence_threshold}')
            # Images that never crossed the threshold keep their final state.
            store_confidence[remaining_indices] = adversarial_confidence[remaining_indices]
            store_x[remaining_indices, ...] = x.numpy()[remaining_indices, ...]
        reduc_dim = tuple(range(1, len(x.shape)))
        msd = np.mean(np.square(store_x - x0.numpy()), axis=reduc_dim)
        output_dict = {
            'adversarial_images':store_x,
            'adversarial_time_step':store_time_step,
            'adversarial_confidence':store_confidence,
            'confidence_threshold':confidence_threshold,
            'failed_indices':remaining_indices,
            'mean_squared_distances':msd,
            'epsilon':epsilon,
            'image_bounds':model.bounds,
            'max_steps':self.steps,
            'num_failed':num_failed
        }
        np.savez(f'confidence_attack_{model_name}.npz', data=output_dict)
        return restore_type(x)

In [None]:
# Settings for the confidence-threshold attack; `steps` is the maximum number
# of gradient steps.
attack_params = {
    'LinfPGD': {'random_start': False, 'abs_stepsize': 0.005, 'steps': 500},
}
epsilons = [1.0]

attack = LinfProjectedGradientDescentAttackWithStopping(**attack_params['LinfPGD'])

# Attack the first test batch with all three models.
test_images, test_labels = next(iter(test_loader))
test_images, test_labels = ep.astensors(test_images.to(device), test_labels.to(device))

# The CNN takes (batch, 1, height, width) images.
test_images = test_images.squeeze().expand_dims(axis=1)
cnn_advs, _, cnn_success = attack(fmodel_cnn, test_images, test_labels, epsilons=epsilons)

# The MLP and LCA models take flattened 784-pixel vectors.
test_images = test_images.squeeze().reshape((batch_size, 784))
mlp_advs, _, mlp_success = attack(fmodel_mlp, test_images, test_labels, epsilons=epsilons)
lca_advs, _, lca_success = attack(fmodel_lca, test_images, test_labels, epsilons=epsilons)

In [None]:
# Read back the per-model bookkeeping saved as confidence_attack_<model>.npz.
adv_results = {}
for model_type in model_list.keys():
    # The dict is stored as a pickled object array, hence allow_pickle; only
    # load files this notebook wrote itself.
    with np.load(f'confidence_attack_{model_type}.npz', allow_pickle=True) as saved:
        adv_results[model_type] = saved['data'].item()

# An image counts as a success when a time step >= 0 was recorded for it.
# flatnonzero always returns a 1-D index array, so a single success does not
# collapse to a 0-d array.
lca_success_indices = np.flatnonzero(adv_results['LCA']['adversarial_time_step'] >= 0)
mlp_success_indices = np.flatnonzero(adv_results['MLP']['adversarial_time_step'] >= 0)
# Images on which at least one of the two models was attacked successfully.
all_success_indices = np.union1d(lca_success_indices, mlp_success_indices)
adv_results_list = [adv_results['LCA']['mean_squared_distances'][all_success_indices],
    adv_results['MLP']['mean_squared_distances'][all_success_indices]]
# Shape (num_images, 2): one column per model. No squeeze, so a single
# selected image still yields a two-column table.
all_results = np.stack(adv_results_list, axis=-1)

In [None]:
# Column labels follow the column order of all_results: LCA first, then MLP.
names = ['LCA 2L;768N', 'MLP 2L;768N']

data = pd.DataFrame(all_results, columns=pd.Index(names, name='Model'))

# Box plot of the per-image mean squared distance, one box per model.
fig, axs = plot.subplots(ncols=1, axwidth=2.5)
axs.format(grid=False, suptitle='Mean Squared Distances')
ax = axs[0]
obj1 = ax.boxplot(
    data,
    lw=0.7,
    marker='.',
    fillcolor='gray5',
    medianlw=1,
    mediancolor='k',
    meancolor='k',
    meanlw=1,
)
ax.format(title='L inf', titleloc='uc')