In [None]:
import os
import sys
from typing import Union, Any, Optional, Callable, Tuple

# Put the directory two levels above the working directory on sys.path so that the
# `DeepSparseCoding.*` imports below resolve (assumes the notebook is run from inside the
# DeepSparseCoding checkout — TODO confirm for your layout). ROOT_DIR is also used later
# to build model log/checkpoint paths.
ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd()))
if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)

import numpy as np
import pandas as pd
import proplot as plot
import eagerpy as ep
import torch
from scipy.stats import pearsonr as linear_correlation
from torch import nn, optim
from torch.nn import functional as F

# Legacy tf1x DeepSparseCoding code: logger, analyzers, Dataset, data processing helpers.
from DeepSparseCoding.tf1x.utils.logger import Logger as tfLogger
import DeepSparseCoding.tf1x.analysis.analysis_picker as ap
from DeepSparseCoding.tf1x.data.dataset import Dataset
import DeepSparseCoding.tf1x.utils.data_processing as tfdp

# PyTorch DeepSparseCoding utilities.
from DeepSparseCoding.utils.file_utils import Logger
import DeepSparseCoding.utils.dataset_utils as dataset_utils
import DeepSparseCoding.utils.loaders as loaders
import DeepSparseCoding.utils.plot_functions as pf

# Foolbox pieces used to wrap the PyTorch models and to subclass the L-inf PGD attack.
import foolbox
from foolbox import PyTorchModel
from foolbox.attacks.projected_gradient_descent import LinfProjectedGradientDescentAttack
from foolbox.types import Bounds
from foolbox.models.base import Model
from foolbox.attacks.base import T
from foolbox.criteria import Misclassification
from foolbox.attacks.base import raise_if_kwargs
from foolbox.attacks.base import get_criterion

# Seeded RNG; it is also handed to the DSC Dataset object built further down.
rand_state = np.random.RandomState(123)

### Load PyTorch Foolbox models & data

In [None]:
class ModelWithTemperature(nn.Module):
    """
    A thin decorator, which wraps a model with temperature scaling
    model (nn.Module):
        A classification neural network
        NB: Output of the neural network should be the classification logits,
            NOT the softmax (or log softmax)!
    init_temp (float):
        Initial value of the learnable scalar temperature that divides the logits.
    The wrapped model's `params` attribute is re-exposed as `self.params`.
    """
    def __init__(self, model, init_temp):
        super(ModelWithTemperature, self).__init__()
        self.model = model
        self.params = model.params
        self.temperature = nn.Parameter(torch.ones(1) * init_temp)

    def forward(self, input_tensor):
        """Run the wrapped model and return its temperature-scaled logits."""
        logits = self.model.forward(input_tensor)
        return self.temperature_scale(logits)

    def temperature_scale(self, logits):
        """
        Perform temperature scaling on logits.

        The temperature has shape (1,), so broadcasting divides every element by the same
        scalar for logits of any rank, e.g. (batch, classes) or a single (classes,) vector.
        The previous unsqueeze/expand to (logits.size(0), logits.size(1)) only worked for
        2-D logits.
        """
        return logits / self.temperature

def bin_conf_acc_prop(softmaxes, labels, bin_boundaries):
    """
    Bin predictions by their confidence (max class probability).

    softmaxes: torch tensor of class probabilities with classes along dim 1.
    labels: torch tensor of integer class labels, one per row of softmaxes.
    bin_boundaries: ascending bin edges; every element must support .item()
        (torch tensor or numpy array). Bins are half-open intervals (lower, upper].

    Returns (bin_confidence, bin_accuracy, bin_prop):
        bin_prop has one entry for EVERY bin (fraction of samples in it, 0 if empty);
        bin_confidence / bin_accuracy (mean confidence / mean accuracy of the bin) have
        entries ONLY for non-empty bins, so they are shorter than bin_prop whenever a bin
        is empty. Drop the zero entries of bin_prop before zipping the three lists.
    """
    confidences, predictions = softmaxes.max(dim=1)
    is_correct = predictions.eq(labels).float()
    bin_confidence, bin_accuracy, bin_prop = [], [], []
    for lower, upper in zip(bin_boundaries[:-1], bin_boundaries[1:]):
        in_bin = confidences.gt(lower.item()) & confidences.le(upper.item())
        proportion = in_bin.float().mean()
        bin_prop.append(proportion)
        if proportion.item() > 0:
            bin_accuracy.append(is_correct[in_bin].mean())
            bin_confidence.append(confidences[in_bin].mean())
    return bin_confidence, bin_accuracy, bin_prop

class _ECELoss(nn.Module):
    """
    Calculates the Expected Calibration Error of a model.
    (This isn't necessary for temperature scaling, just a cool metric).
    The input to this loss is the logits of a model, NOT the softmax scores.
    This divides the confidence outputs into equally-sized interval bins.
    In each bin, we compute the confidence gap:
    bin_gap = | avg_confidence_in_bin - accuracy_in_bin |
    We then return a weighted average of the gaps, based on the number
    of samples in each bin
    See: Naeini, Mahdi Pakdaman, Gregory F. Cooper, and Milos Hauskrecht.
    "Obtaining Well Calibrated Probabilities Using Bayesian Binning." AAAI.
    2015.
    """
    def __init__(self, n_bins=15):
        """
        n_bins (int): number of confidence interval bins
        """
        super(_ECELoss, self).__init__()
        self.bin_boundaries = torch.linspace(0, 1, n_bins + 1)

    def forward(self, logits, labels):
        """
        logits: pre-softmax scores with classes along dim 1.
        labels: integer class labels, one per row of logits.
        Returns a 1-element tensor on logits.device holding the ECE as a fraction (not %).
        """
        softmaxes = F.softmax(logits, dim=1)
        ece = torch.zeros(1, device=logits.device)
        bin_confidence, bin_accuracy, bin_prop = bin_conf_acc_prop(
            softmaxes, labels, self.bin_boundaries)
        # bin_conf_acc_prop returns a proportion for EVERY bin but confidence/accuracy only
        # for non-empty bins. Zipping the raw lists paired each gap with the wrong bin's
        # weight whenever a bin was empty, so keep only the non-empty proportions.
        nonempty_prop = [prop for prop in bin_prop if prop.item() > 0]
        for avg_confidence_in_bin, accuracy_in_bin, prop_in_bin in zip(
                bin_confidence, bin_accuracy, nonempty_prop):
            ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
        return ece

def set_temperature(model, valid_loader, device):
    """
    Tune the temperature of the model (using the validation set).
    We're going to set it to optimize NLL.
    model (ModelWithTemperature): wrapper whose `.model` returns raw logits and whose
        `temperature` parameter is optimized in place.
    valid_loader (DataLoader): validation set loader yielding (input, label) batches;
        inputs are flattened to (batch, -1) before the forward pass.
    device: device the logits and labels are moved to.
    Returns the same model object with its temperature updated.
    """
    nll_criterion = nn.CrossEntropyLoss().to(device)
    ece_criterion = _ECELoss().to(device)
    # First: collect the RAW logits and labels for the validation set. Calling the wrapper
    # itself would already divide by the current temperature, and the optimisation below
    # would then apply the temperature a second time.
    logits_list = []
    labels_list = []
    with torch.no_grad():
        for input_tensor, label_tensor in valid_loader:
            input_tensor = input_tensor.reshape((input_tensor.shape[0], -1)).to(device)
            logits_list.append(model.model(input_tensor))
            labels_list.append(label_tensor)
        logits = torch.cat(logits_list).to(device)
        labels = torch.cat(labels_list).to(device)
        # Calculate NLL and ECE before temperature scaling
        before_temperature_nll = nll_criterion(logits, labels).item()
        before_temperature_ece = ece_criterion(logits, labels).item()
    print('Before temperature - NLL: %.3f, ECE: %.3f' % (before_temperature_nll, before_temperature_ece))
    # Next: optimize the temperature w.r.t. NLL
    optimizer = optim.LBFGS([model.temperature], lr=0.01, max_iter=50)

    def nll_closure():
        # LBFGS evaluates the closure several times per step; clear the gradient each time
        # so it is not accumulated across evaluations.
        optimizer.zero_grad()
        loss = nll_criterion(model.temperature_scale(logits), labels)
        loss.backward()
        return loss
    optimizer.step(nll_closure)
    # Calculate NLL and ECE after temperature scaling
    with torch.no_grad():
        scaled_logits = model.temperature_scale(logits)
        after_temperature_nll = nll_criterion(scaled_logits, labels).item()
        after_temperature_ece = ece_criterion(scaled_logits, labels).item()
    print('Optimal temperature: %.3f' % model.temperature.item())
    print('After temperature - NLL: %.3f, ECE: %.3f' % (after_temperature_nll, after_temperature_ece))
    return model

def create_mnist_dsc(log_file, cp_file, calibrate=False, init_temp=1.0, batch_size=50):
    """
    Build a Foolbox-wrapped PyTorch DeepSparseCoding model from a training log + checkpoint.

    log_file: path to the model's training log; its last params entry is used.
    cp_file: checkpoint path, stored as params.cp_latest_filename before load_checkpoint().
    calibrate: if True, wrap the model in ModelWithTemperature with temperature init_temp.
    init_temp: temperature used when calibrate is True.
    batch_size: data loader batch size (default 50, the previously hard-coded value).

    Returns (fmodel, model, test_loader, batch_size, device); the params object is also
    attached to fmodel as fmodel.params.
    """
    logger = Logger(log_file, overwrite=False)
    log_text = logger.load_file()
    params = logger.read_params(log_text)[-1]
    params.cp_latest_filename = cp_file
    # Unshuffled, unstandardized data rescaled to one — presumably to match the Foolbox
    # bounds=(0, 1) used below.
    params.standardize_data = False
    params.rescale_data_to_one = True
    params.shuffle_data = False
    params.batch_size = batch_size
    train_loader, val_loader, test_loader, data_params = dataset_utils.load_dataset(params)
    for key, value in data_params.items():
        setattr(params, key, value)
    model = loaders.load_model(params.model_type)
    model.setup(params, logger)
    model.params.analysis_out_dir = os.path.join(
        model.params.model_out_dir, 'analysis', model.params.version)
    model.params.analysis_save_dir = os.path.join(model.params.analysis_out_dir, 'savefiles')
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(model.params.analysis_save_dir, exist_ok=True)
    model.load_checkpoint()
    if calibrate:
        # Temperature stays fixed at init_temp. To tune it instead, call
        # set_temperature(model, <held-out loader>, params.device) here.
        model = ModelWithTemperature(model, init_temp)
    model.to(params.device)
    fmodel = PyTorchModel(model.eval(), bounds=(0, 1))
    fmodel.params = params
    return fmodel, model, test_loader, model.params.batch_size, model.params.device

In [None]:
# Notebook switches: whether the Foolbox attack should cover every test batch, and whether
# both PyTorch models are wrapped with the fixed temperatures below.
run_full_test_set = True
calibrate = True
fb_mlp_temp = 1.69
fb_lca_temp = 1.50

# Training logs and latest checkpoints of the PyTorch MLP and LCA models, in that order.
log_files = [
    os.path.join(ROOT_DIR, 'Torch_projects', 'mlp_768_mnist', 'logfiles', 'mlp_768_mnist_v0.log'),
    os.path.join(ROOT_DIR, 'Torch_projects', 'lca_768_mlp_mnist', 'logfiles', 'lca_768_mlp_mnist_v0.log'),
]
cp_latest_filenames = [
    os.path.join(ROOT_DIR, 'Torch_projects', 'mlp_768_mnist', 'checkpoints', 'mlp_768_mnist_latest_checkpoint_v0.pt'),
    os.path.join(ROOT_DIR, 'Torch_projects', 'lca_768_mlp_mnist', 'checkpoints', 'lca_768_mlp_mnist_latest_checkpoint_v0.pt'),
]

# The MLP call also supplies the shared test loader, batch size and device.
(fmodel_mlp, dsc_model_mlp, test_loader, batch_size, device) = create_mnist_dsc(
    log_files[0], cp_latest_filenames[0], calibrate=calibrate, init_temp=fb_mlp_temp)
fmodel_mlp.model_type = 'MLP_calibrated' if calibrate else 'MLP'
print(fmodel_mlp.model_type)

fmodel_lca, dsc_model_lca = create_mnist_dsc(
    log_files[1], cp_latest_filenames[1], calibrate=calibrate, init_temp=fb_lca_temp)[:2]
fmodel_lca.model_type = 'LCA_calibrated' if calibrate else 'LCA'
print(fmodel_lca.model_type)

fmodels = [fmodel_mlp, fmodel_lca]

# First test batch, flattened to 784-pixel vectors, for the side-by-side comparisons.
fb_image_batch, fb_label_batch = next(iter(test_loader))
fb_image_batch = fb_image_batch.reshape((batch_size, 784))

### Load DeepSparseCoding analyzer & data

In [None]:
class params(object):
    """Analysis settings handed to the tf1x DeepSparseCoding analyzers' setup()."""
    def __init__(self):
        # General run settings
        self.device = "/gpu:0"
        self.analysis_dataset = "test"
        self.save_info = "analysis_temp_" + self.analysis_dataset
        self.overwrite_analysis_log = False

        # Analysis switches: only the class-adversary analysis is enabled
        self.do_class_adversaries = True
        for disabled_analysis in (
                "do_run_analysis", "do_evals", "do_basis_analysis", "do_inference",
                "do_atas", "do_recon_adversaries", "do_neuron_visualization",
                "do_full_recon", "do_orientation_analysis", "do_group_recons"):
            setattr(self, disabled_analysis, False)

        # Adversarial attack settings
        self.adversarial_attack_method = "kurakin_targeted"
        self.adversarial_step_size = 0.005  # learning rate for the attack optimizer
        self.adversarial_num_steps = 500  # number of attack iterations
        self.confidence_threshold = 0.9
        self.adversarial_max_change = 1.0  # maximum size of the adversarial perturbation (epsilon)
        self.carlini_change_variable = False  # whether to use the change-of-variable trick from Carlini et al.
        self.adv_optimizer = "sgd"  # attack optimizer
        self.adversarial_target_method = "random"  # not used for untargeted attacks; TODO: support "specified"
        self.adversarial_clip = True  # whether to clip the final perturbed image
        self.adversarial_clip_range = [0.0, 1.0]  # allowed range of image values
        #self.carlini_recon_mult = 0.1#list(np.arange(.5, 1, .1))
        self.adversarial_save_int = 1  # interval at which adversarial examples are saved to the npz file
        self.eval_batch_size = 50  # batch size for computing adversarial examples
        self.adversarial_input_id = None  # which images to attack; None means all
        # Targets for the "specified" target_method (class attacks only); must be a list
        # or numpy array of size [adv_batch_size].
        self.adversarial_target_labels = None

In [None]:
#model_names = ['mlp_cosyne_mnist', 'slp_lca_768_latent_75_steps_mnist']
#model_names = ['mlp_768_mnist', 'slp_lca_768_latent_mnist']
model_names = ['mlp_1568_mnist', 'slp_lca_1568_latent_mnist']
model_types = ['MLP', 'LCA']
analyzers = []
for model_type, model_name in zip(model_types, model_names):
    # Fresh params object per model. analyzer.setup() receives this object and presumably
    # keeps a reference to it (TODO confirm in the tf1x base analyzer). Reusing and mutating
    # one shared instance would leave every analyzer's analysis_params describing the LAST model.
    analysis_params = params()
    analysis_params.projects_dir = os.path.expanduser("~")+"/Work/Projects/"
    analysis_params.model_name = model_name
    analysis_params.version = '0.0'
    analysis_params.model_dir = analysis_params.projects_dir+analysis_params.model_name
    model_log_file = (analysis_params.model_dir+"/logfiles/"+analysis_params.model_name
      +"_v"+analysis_params.version+".log")
    model_logger = tfLogger(model_log_file, overwrite=False)
    model_log_text = model_logger.load_file()
    model_params = model_logger.read_params(model_log_text)[-1]
    analysis_params.model_type = model_params.model_type
    analyzer = ap.get_analyzer(analysis_params.model_type)
    analysis_params.save_info = "analysis_tmp_test_" + analysis_params.analysis_dataset
    analysis_params.save_info += (
        "_linf_"+str(analysis_params.adversarial_max_change)
        +"_ss_"+str(analysis_params.adversarial_step_size)
        +"_ns_"+str(analysis_params.adversarial_num_steps)
        +"_pgd_targeted"
    )
    analyzer.setup(analysis_params)
    analyzer.model_type = model_type
    analyzer.confidence_threshold = analysis_params.confidence_threshold
    analyzers.append(analyzer)

# Wrap the PyTorch loader's MNIST test images (divided by 255) in a DSC Dataset so both
# codebases see the same inputs.
mnist_data = test_loader.dataset.data.numpy().astype(np.float32)
mnist_data /= 255
dsc_data = {
    'test':Dataset(
        np.expand_dims(mnist_data, axis=-1),
        tfdp.dense_to_one_hot(test_loader.dataset.targets.numpy(), 10),
        None,
        rand_state
    )
}
# Take the model AND its params from the same analyzer. Previously the loop variable
# `analyzer` leaked in here, pointing at the last analyzer while the model came from the first.
dsc_data = analyzers[0].model.reshape_dataset(dsc_data, analyzers[0].model_params)
for analyzer in analyzers:
    analyzer.model_params.data_shape = list(dsc_data["test"].shape[1:])
    analyzer.setup_model(analyzer.model_params)
dsc_image_batch, dsc_label_batch, _ = dsc_data['test'].next_batch(batch_size, shuffle_data=False)
dsc_data['test'].reset_counters()

dsc_all_images = dsc_data['test'].images
dsc_all_images = dsc_all_images.reshape((dsc_all_images.shape[0], 784))
dsc_all_labels = dsc_data['test'].labels

### Compare DeepSparseCoding & Foolbox data

In [None]:
# Pick one image of the first test batch with the notebook's seeded RNG, so this plot and
# the single-image plots below are reproducible across runs.
img_idx = rand_state.randint(batch_size)
fb_img = fb_image_batch.numpy()[img_idx, ...].reshape(28, 28)
dsc_img = dsc_image_batch[img_idx, ...].reshape(28, 28)
fig, axs = plot.subplots(ncols=3)
panels = [
    (fb_img, f'PyTorch loader digit class {fb_label_batch[img_idx]}'),
    (dsc_img, f'DSC dataset digit class {tfdp.one_hot_to_dense(dsc_label_batch)[img_idx]}'),
    (np.abs(dsc_img - fb_img), 'Difference image'),
]
for ax, (img, title) in zip(axs, panels):
    im = ax.imshow(img, cmap='greys_r')
    ax.format(title=title)
    ax.colorbar(im)
pf.clear_axes(axs)
plot.show()

### Compare DeepSparseCoding & Foolbox confidence and accuracy

In [None]:
# Clean test-set forward pass. Each model ends up with .logits and .softmaxes as
# (num_test_images, num_classes) numpy arrays in test-loader order.
for fmodel in fmodels:
    batch_logits = []
    for input_tensor, label_tensor in test_loader:
        input_tensor = input_tensor.reshape((input_tensor.shape[0], 784))
        input_tensor = ep.astensor(input_tensor.to(fmodel.params.device))
        batch_logits.append(fmodel(input_tensor))
    # concatenate (rather than stack + reshape) also tolerates a smaller final batch.
    all_logits = ep.concatenate(batch_logits, axis=0)
    fmodel.num_batches = len(batch_logits)
    fmodel.batch_size = batch_logits[0].shape[0]
    fmodel.num_classes = all_logits.shape[-1]
    fmodel.logits = all_logits.numpy()
    # eagerpy softmax instead of passing eagerpy tensors to torch's F.softmax.
    fmodel.softmaxes = ep.softmax(all_logits, axis=-1).numpy()

for analyzer in analyzers:
    print(analyzer.analysis_params.model_name)
    analyzer.logits = np.squeeze(analyzer.compute_activations(dsc_all_images, batch_size=50, activation_operation=analyzer.model.get_logits))
    dsc_data['test'].reset_counters()
    analyzer.softmaxes = np.squeeze(analyzer.compute_activations(dsc_all_images, batch_size=50, activation_operation=analyzer.model.get_label_est))
    dsc_data['test'].reset_counters()

In [None]:
n_bins = 50
bin_boundaries = np.linspace(0, 1, n_bins + 1)

# Reliability diagram (per-bin accuracy vs. confidence) and ECE for every model.
models = fmodels + analyzers
codebases = ['FB'] * len(fmodels) + ['DSC'] * len(analyzers)
fig, axs = plot.subplots(ncols=len(models))
for ax, model, codebase in zip(axs, models, codebases):
    confidence, accuracy, props = bin_conf_acc_prop(
        torch.from_numpy(model.softmaxes).to(device),
        test_loader.dataset.targets.to(device),
        bin_boundaries
    )
    confidence = [conf.cpu().numpy() for conf in confidence]
    accuracy = [acc.cpu().numpy() for acc in accuracy]
    # bin_conf_acc_prop returns a proportion for EVERY bin but confidence/accuracy only for
    # non-empty bins. Keep the non-empty proportions so the three lists line up; zipping
    # the raw lists gave each gap the wrong bin's weight and miscomputed the ECE.
    props = [prop.cpu().numpy() for prop in props if prop.item() > 0]
    ece = 100 * sum(
        abs(conf.item() - acc.item()) * prop.item()
        for conf, acc, prop in zip(confidence, accuracy, props))
    ax.scatter(confidence, accuracy, s=[prop*500 for prop in props], color='k')
    ax.plot([0,1], [0,1], 'k--', linewidth=0.1)
    ax.format(title=f'{codebase}_{model.model_type}\nECE = {round(ece, 3)}%')
axs.format(
    suptitle='Reliability of classifier confidence on test set',
    xlabel='Confidence',
    ylabel='Accuracy',
    xlim=[0, 1],
    ylim=[0, 1]
)
plot.show()

In [None]:
# Per-codebase outputs stacked as (num_models, num_images, num_classes), matching the
# [model, img_idx, :] indexing of the single-image plots.
fb_logit_forward = np.stack([fmodel.logits for fmodel in fmodels], axis=0)
fb_softmax_forward = np.stack([fmodel.softmaxes for fmodel in fmodels], axis=0)

dsc_logit_forward = np.stack([analyzer.logits for analyzer in analyzers], axis=0)
dsc_softmax_forward = np.stack([analyzer.softmaxes for analyzer in analyzers], axis=0)

# One column per model (FB models first, then DSC analyzers). The rows are every class
# probability of every image, flattened. Model counts come from the lists instead of a
# hard-coded 2.
all_softmax_results = np.concatenate(
    (fb_softmax_forward.reshape(len(fmodels), -1),
     dsc_softmax_forward.reshape(len(analyzers), -1)),
    axis=0
).T

In [None]:
# Logits of image img_idx: Foolbox models on the top row, DSC analyzers on the bottom row.
fig, axs = plot.subplots(ncols=2, nrows=2)
for row_idx, (codebase, row_models, row_logits) in enumerate(
        (('FB', fmodels, fb_logit_forward), ('DSC', analyzers, dsc_logit_forward))):
    for col_idx in (0, 1):
        axs[row_idx, col_idx].bar(np.arange(10), np.squeeze(row_logits[col_idx, img_idx, :]))
        axs[row_idx, col_idx].format(title=f'{codebase}_{row_models[col_idx].model_type}')
axs.format(suptitle='Logit outputs for a single image', xtickminor=False, xticks=1)
plot.show()

In [None]:
# Softmax probabilities of image img_idx, laid out like the logit plot above. Class
# indices are passed explicitly as x positions, as in the logits cell (plain matplotlib
# bar() requires both x and height).
fig, axs = plot.subplots(ncols=2, nrows=2)
for row_idx, (codebase, row_models, row_softmax) in enumerate(
        (('FB', fmodels, fb_softmax_forward), ('DSC', analyzers, dsc_softmax_forward))):
    for col_idx in (0, 1):
        axs[row_idx, col_idx].bar(
            np.arange(row_softmax.shape[-1]),
            np.squeeze(row_softmax[col_idx, img_idx, :]))
        axs[row_idx, col_idx].format(title=f'{codebase}_{row_models[col_idx].model_type}')
axs.format(suptitle='Softmax confidence for a single image', xtickminor=False, xticks=1, ylim=[0, 1])
plot.show()

In [None]:
# Column labels come from each model's model_type (same order as all_softmax_results), so
# calibrated runs are labelled as such instead of using hard-coded names.
names = ([f'FB_{fmodel.model_type}' for fmodel in fmodels]
         + [f'DSC_{analyzer.model_type}' for analyzer in analyzers])
softmax_dataframe = pd.DataFrame(
    all_softmax_results,
    columns=pd.Index(names, name='Model')
)
fig, ax = plot.subplots(ncols=1, axwidth=2.5, share=0)
ax.format(
    grid=False,
    suptitle='Softmax confidence for the test set'
)
ax.boxplot(
    softmax_dataframe, linewidth=0.7, marker='.', fillcolor='gray5',
    medianlw=1, mediancolor='k', meancolor='k', meanlw=1
)
ax.format(yscale='log', yformatter='sci')

In [None]:
num_bins = 100
conf_lim = [0, 1]
# num_bins + 1 edges give num_bins equal-width bins (linspace(..., num_bins) gave one fewer).
bins = np.linspace(conf_lim[0], conf_lim[1], num_bins + 1)
models = fmodels + analyzers
codebases = ['FB'] * len(fmodels) + ['DSC'] * len(analyzers)
fig, axs = plot.subplots(ncols=2, nrows=2)
for ax, model, codebase in zip(axs, models, codebases):
    max_confidence = np.max(model.softmaxes, axis=1) # max is across categories, per image
    count, bin_edges = np.histogram(max_confidence, bins)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    ax.bar(bin_centers, count, color='k')
    # Mark and report the actual mean. Previously the bin edge nearest to the mean was
    # drawn and labelled as the mean.
    mean_confidence = np.mean(max_confidence)
    ax.axvline(mean_confidence, lw=1, ls='--', color='r')
    ax.format(
        title=f'{codebase}_{model.model_type}\nMean confidence = {mean_confidence:.4f}',
        yscale='log',
        xlim=conf_lim
    )
axs.format(
    suptitle='Max softmax confidence on the clean test set',
    ylabel='Count',
    xlabel='Confidence'
)

### Run DeepSparseCoding adversarial attack

In [None]:
def get_adv_indices(softmax_conf, all_kept_indices, confidence_threshold, num_images, labels):
    """
    Find images whose most confident WRONG class exceeds a confidence threshold.

    softmax_conf: (num_images, num_classes) array of class probabilities; not modified.
    all_kept_indices: indices already accepted as adversarial; they are excluded.
    confidence_threshold: strict lower bound on the highest non-true-class confidence.
    num_images: number of rows in softmax_conf.
    labels: (num_images,) integer true labels.

    Returns (keep_indices, confidence_indices, adversarial_labels):
        keep_indices: ascending int32 array of newly qualifying row indices.
        confidence_indices: per-image highest non-true-class confidence (a value, not an
            index, despite the name).
        adversarial_labels: per-image class achieving that confidence.
    """
    # Work on a copy: callers may pass views into stored attack outputs, and zeroing the
    # true label in place silently corrupted them.
    wrong_class_conf = np.array(softmax_conf, copy=True)
    wrong_class_conf[np.arange(num_images, dtype=np.int32), labels] = 0 # zero confidence at true label
    confidence_indices = np.max(wrong_class_conf, axis=-1) # highest non-true label confidence
    adversarial_labels = np.argmax(wrong_class_conf, axis=-1) # index of highest non-true label
    # flatnonzero also works for a single image, where squeeze + nonzero produced a 0-d array.
    all_above_thresh = np.flatnonzero(confidence_indices > confidence_threshold)
    already_kept = set(all_kept_indices)  # built once rather than once per candidate
    keep_indices = np.array(
        [idx for idx in all_above_thresh if idx not in already_kept], dtype=np.int32)
    return keep_indices, confidence_indices, adversarial_labels

In [None]:
# The DSC attack always uses the first test batch; run_full_test_set is not consulted here.
# To attack the whole set, use dsc_all_images / dsc_all_labels instead.
data, labels = dsc_image_batch, dsc_label_batch

for analyzer in analyzers:
    attack_config = analyzer.analysis_params
    analyzer.class_adversary_analysis(
        data,
        labels,
        batch_size=attack_config.eval_batch_size,
        input_id=attack_config.adversarial_input_id,
        target_method=attack_config.adversarial_target_method,
        target_labels=attack_config.adversarial_target_labels,
        save_info=attack_config.save_info,
    )

In [None]:
# Turn each DSC attack trajectory into per-image results: keep the first attack step whose
# highest wrong-class confidence exceeds the analyzer's threshold. The dense labels get a
# new name so re-running this cell does not convert `labels` a second time.
dense_labels = tfdp.one_hot_to_dense(labels.astype(np.int32))
num_images = data.shape[0]
for analyzer in analyzers:
    store_data = np.zeros_like(data)
    store_time_step = -1*np.ones(num_images, dtype=np.int32)
    store_labels = np.zeros(num_images, dtype=np.int32)
    store_confidence = np.zeros(num_images, dtype=np.float32)
    store_mses = np.zeros(num_images, dtype=np.float32)
    all_kept_indices = []
    for adv_step in range(1, analyzer.analysis_params.adversarial_num_steps+1): # first one is original
        keep_indices, confidence_indices, adversarial_labels = get_adv_indices(
            # Copy so the helper can never modify the analyzer's stored attack outputs.
            analyzer.adversarial_outputs[0, adv_step, ...].copy(),
            all_kept_indices,
            analyzer.confidence_threshold,
            num_images,
            dense_labels)
        if keep_indices.size > 0:
            all_kept_indices.extend(keep_indices)
            store_data[keep_indices, ...] = analyzer.adversarial_images[0, adv_step, keep_indices, ...]
            store_time_step[keep_indices] = adv_step
            store_confidence[keep_indices] = confidence_indices[keep_indices]
            store_mses[keep_indices] = analyzer.adversarial_input_adv_mses[0, adv_step, keep_indices]
            store_labels[keep_indices] = adversarial_labels[keep_indices]
    # 1-D int32 indices of images that never crossed the threshold. Set lookup replaces a
    # scan of the kept list for every image.
    kept_set = set(all_kept_indices)
    failed_indices = np.array(
        [idx for idx in range(num_images) if idx not in kept_set], dtype=np.int32)
    if failed_indices.size > 0:
        # Failed images report the values from the final attack step.
        store_confidence[failed_indices] = confidence_indices[failed_indices]
        store_labels[failed_indices] = adversarial_labels[failed_indices]
        store_data[failed_indices, ...] = data[failed_indices, ...]
        store_mses[failed_indices] = analyzer.adversarial_input_adv_mses[0, -1, failed_indices]
    # NOTE(review): the raw attack arrays are replaced below, so this cell can only be
    # re-run after re-running the attack cell.
    analyzer.adversarial_images = [store_data]
    analyzer.adversarial_time_step = [store_time_step]
    analyzer.adversarial_confidence = [store_confidence]
    analyzer.failed_indices = [failed_indices]
    analyzer.success_indices = [list(kept_set)]
    analyzer.adversarial_labels = [store_labels]
    analyzer.mean_squared_distances = [store_mses]
    analyzer.num_failed = [num_images - len(kept_set)]
    print(f'model {analyzer.model_type} had {analyzer.num_failed[0]} failed indices')

### Run Foolbox adversarial attack

In [None]:
class LinfProjectedGradientDescentAttackWithStopping(LinfProjectedGradientDescentAttack):
    """
    L-infinity PGD that records, per image, the first confidently adversarial iterate.

    Stepping continues until every image's highest wrong-class softmax has exceeded
    model.confidence_threshold, or until `steps` gradient steps have been taken.

    `run` expects the Foolbox model to carry attributes set by this notebook:
    confidence_threshold, model_type, and the result lists adversarial_images,
    adversarial_time_step, adversarial_labels, adversarial_confidence, success_indices,
    failed_indices, mean_squared_distances and num_failed. Each call appends one entry to
    each list. Recorded indices are positions within the batch passed to the attack.
    """
    def __init__(
        self,
        *,
        rel_stepsize: float = 0.025,
        abs_stepsize: Optional[float] = None,
        steps: int = 50,
        random_start: bool = True,
    ):
        super().__init__(
            rel_stepsize=rel_stepsize,
            abs_stepsize=abs_stepsize,
            steps=steps,
            random_start=random_start,
        )

    def normalize(
        self, gradients: ep.Tensor, *, x: ep.Tensor, bounds: Bounds
    ) -> ep.Tensor:
        """Use the sign of the gradient as the step direction."""
        return gradients.sign()

    def run(
        self,
        model: Model,
        inputs: T,
        criterion: Union[Misclassification, T],
        *,
        epsilon: float,
        **kwargs: Any,
    ) -> T:
        """
        Attack one batch, recording per-image results on `model`; returns the final iterate.

        Stored time steps count the gradient updates applied (the first update is 1),
        following the DSC convention where step 0 holds the original images.
        """
        raise_if_kwargs(kwargs)
        x0, restore_type = ep.astensor_(inputs)
        criterion_ = get_criterion(criterion)
        del inputs, criterion, kwargs

        if not isinstance(criterion_, Misclassification):
            raise ValueError("unsupported criterion")

        labels = criterion_.labels
        labels_np = labels.numpy()  # loop invariant, so converted once
        loss_fn = self.get_loss_fn(model, labels)

        if self.abs_stepsize is None:
            stepsize = self.rel_stepsize * epsilon
        else:
            stepsize = self.abs_stepsize

        orig_x = x0.numpy().copy()
        num_images = orig_x.shape[0]

        if self.random_start:
            x = self.get_random_start(x0, epsilon)
            x = ep.clip(x, *model.bounds)
        else:
            x = x0
        # Host-side buffers, allocated from the numpy copy rather than the (maybe GPU) tensor.
        store_x = np.zeros_like(orig_x, dtype=np.float32)
        store_time_step = -1*np.ones(num_images, dtype=np.int32)
        store_labels = np.zeros(num_images, dtype=np.int32)
        store_confidence = np.zeros(num_images, dtype=np.float32)
        # Defaults for the failure bookkeeping below if no step is taken (steps <= 0).
        confidence_indices = np.zeros(num_images, dtype=np.float32)
        adversarial_labels = np.zeros(num_images, dtype=np.int64)
        all_kept_indices = []
        time_step = 0
        # At most `steps` updates. The former `time_step == self.steps-1` check stopped
        # after steps-1 updates, one fewer than the DSC attack's budget.
        while len(set(all_kept_indices)) < num_images and time_step < self.steps:
            _, gradients = self.value_and_grad(loss_fn, x)
            gradients = self.normalize(gradients=gradients, x=x, bounds=model.bounds)
            x = x + stepsize * gradients
            x = self.project(x, x0, epsilon)
            x = ep.clip(x, *model.bounds)
            time_step += 1
            keep_indices, confidence_indices, adversarial_labels = get_adv_indices(
                ep.softmax(model(x)).numpy().copy(),
                all_kept_indices,
                model.confidence_threshold,
                num_images,
                labels_np)
            if keep_indices.size > 0:
                all_kept_indices.extend(keep_indices)
                store_x[keep_indices, ...] = x.numpy()[keep_indices, ...]
                store_labels[keep_indices] = adversarial_labels[keep_indices]
                store_time_step[keep_indices] = time_step
                store_confidence[keep_indices] = confidence_indices[keep_indices]
        kept_set = set(all_kept_indices)
        failed_indices = np.array(
            [idx for idx in range(num_images) if idx not in kept_set], dtype=np.int32)
        if failed_indices.size > 0:
            print(f'Max steps = {self.steps} reached for model {model.model_type}, {failed_indices.size} images did not achieve adversarial confidence threshold of {model.confidence_threshold}')
            # Failed images keep the final iterate and the last step's best wrong class,
            # mirroring the DSC post-processing.
            store_confidence[failed_indices] = confidence_indices[failed_indices]
            store_labels[failed_indices] = adversarial_labels[failed_indices]
            # .numpy() brings x to host memory; assigning the eagerpy tensor directly into
            # a numpy array fails for CUDA tensors.
            store_x[failed_indices, ...] = x.numpy()[failed_indices, ...]
        reduc_dim = tuple(range(1, orig_x.ndim))
        msd = np.mean((store_x - orig_x)**2, axis=reduc_dim)
        model.adversarial_images.append(store_x)
        model.adversarial_time_step.append(store_time_step)
        model.adversarial_labels.append(store_labels)
        model.adversarial_confidence.append(store_confidence)
        model.success_indices.append(np.array(all_kept_indices, dtype=np.int32))
        model.failed_indices.append(failed_indices)
        model.mean_squared_distances.append(msd)
        model.num_failed.append(len(failed_indices))
        return restore_type(x)

In [None]:
attack_params = {
    'LinfPGD': {
        'random_start':False,
        'abs_stepsize':analysis_params.adversarial_step_size,
        'steps':analysis_params.adversarial_num_steps # maximum number of steps
    }
}
epsilons = [analysis_params.adversarial_max_change]
# run() appends one entry per epsilon to each result list; the per-batch bookkeeping below
# assumes exactly one entry per batch.
assert len(epsilons) == 1, 'batch bookkeeping below assumes a single epsilon'
attack = LinfProjectedGradientDescentAttackWithStopping(**attack_params['LinfPGD'])

# Per-image results: one array per batch, concatenated into test-set order at the end.
per_image_attrs = [
    'adversarial_images', 'adversarial_labels', 'adversarial_time_step',
    'adversarial_confidence', 'mean_squared_distances']
# Index results: run() records batch-local positions; they are shifted to test-set indices.
index_attrs = ['success_indices', 'failed_indices']

for fmodel in fmodels:
    fmodel.confidence_threshold = analysis_params.confidence_threshold
    for attr in per_image_attrs + index_attrs + ['num_failed', 'success']:
        setattr(fmodel, attr, [])
    batch_offset = 0
    for batch_idx, (input_tensor, label_tensor) in enumerate(test_loader):
        if not run_full_test_set and batch_idx >= 1:
            break  # previously `pass`, which attacked every batch regardless of the flag
        num_in_batch = input_tensor.shape[0]
        input_tensor = input_tensor.reshape((num_in_batch, 784))
        input_tensor = input_tensor.to(fmodel.params.device)
        label_tensor = label_tensor.to(fmodel.params.device)
        input_tensor, label_tensor = ep.astensors(input_tensor, label_tensor)
        advs, _, success = attack(
            fmodel,
            input_tensor,
            label_tensor,
            epsilons=epsilons
        )
        fmodel.success.append(success.numpy())
        for attr in index_attrs:
            batch_indices = getattr(fmodel, attr)
            batch_indices[-1] = np.asarray(batch_indices[-1]).astype(np.int64).ravel() + batch_offset
        batch_offset += num_in_batch

    fmodel.num_batches = len(fmodel.success)
    # concatenate (rather than stack + reshape) tolerates a smaller final batch. For the
    # index lists it also handles batches with different numbers of successes/failures,
    # where np.stack raised.
    for attr in per_image_attrs + index_attrs:
        setattr(fmodel, attr, np.concatenate(getattr(fmodel, attr), axis=0))
    fmodel.num_failed = np.asarray(fmodel.num_failed)
    fmodel.success = np.stack(fmodel.success, axis=0)
    print(f'model {fmodel.model_type} had {fmodel.num_failed.sum()} failed indices')

### Compare DeepSparseCoding & Foolbox adversarial attacks

In [None]:
# Clean vs. adversarial accuracy on the first test batch.
for analyzer in analyzers:
    clean_acc = analyzer.adversarial_clean_accuracy.item()
    analyzer.accuracy = clean_acc
    print(f'DSC {analyzer.model_type} clean accuracy = {clean_acc} and adv accuracy = {analyzer.adversarial_adv_accuracy}')

for fmodel in fmodels:
    fmodel.accuracy = foolbox.accuracy(fmodel, fb_image_batch.to(device), fb_label_batch.to(device))
    # success[0] holds the first batch's attack success flags.
    first_batch_success = fmodel.success[0].mean(axis=-1).round(2)
    print(f'FB {fmodel.model_type} clean accuracy = {fmodel.accuracy} and adv accuracy = {1.0 - first_batch_success}')

In [None]:
def stars(p):
    """Map a p-value to an asterisk significance label ('****' .. '*', else 'n.s.')."""
    for threshold, label in ((0.0001, '****'), (0.001, '***'), (0.01, '**'), (0.05, '*')):
        if p < threshold:
            return label
    return 'n.s.'

# NOTE(review): these labels say 768 units, but the DSC analyzers were loaded from the
# *_1568_* models; confirm the labels used for the DSC panel.
names = ['MLP 2L;768N','LCA 2L;768N']


def _common_indices(index_collections):
    """Sorted unique indices present in every collection; works for any number of models.

    np.intersect1d(*collections) only works for two: a third positional argument is
    interpreted as `assume_unique`.
    """
    common = np.unique(np.asarray(index_collections[0]).ravel())
    for indices in index_collections[1:]:
        common = np.intersect1d(common, np.asarray(indices).ravel())
    return common


fb_all_success_indices = _common_indices([fmodel.success_indices for fmodel in fmodels]).astype(np.int32)
fb_adv_results_list = [np.array(fmodel.mean_squared_distances)[fb_all_success_indices] for fmodel in fmodels]
# reshape instead of squeeze keeps a 2-D (num_images, num_models) table even when only a
# single image was attacked successfully by every model.
fb_all_results = np.stack(fb_adv_results_list, axis=-1).reshape(-1, len(fb_adv_results_list))
fb_dataframe = pd.DataFrame(
    fb_all_results,
    columns=pd.Index(names, name='Model')
)

dsc_all_success_indices = _common_indices([analyzer.success_indices for analyzer in analyzers]).astype(np.int32)
dsc_adv_results_list = [analyzer.mean_squared_distances[0][dsc_all_success_indices] for analyzer in analyzers]
dsc_all_results = np.stack(dsc_adv_results_list, axis=-1).reshape(-1, len(dsc_adv_results_list))
dsc_dataframe = pd.DataFrame(
    dsc_all_results,
    columns=pd.Index(names, name='Model')
)

# NOTE(review): this is the p-value of a Pearson correlation between the two models' MSDs,
# not a test that their MSDs differ, so the stars mark correlation significance.
dsc_p_value = linear_correlation(dsc_all_results[:,0], dsc_all_results[:,1])[1]
fb_p_value = linear_correlation(fb_all_results[:,0], fb_all_results[:,1])[1]

fig, axs = plot.subplots(ncols=2, axwidth=2.5, share=0)
axs.format(grid=False, suptitle='L infinity Attack Mean Squared Distances')
panels = [
    (fb_dataframe, fb_p_value, 'Foolbox'),
    (dsc_dataframe, dsc_p_value, 'Deep Sparse Coding'),
]
for ax, (dataframe, p_value, title) in zip(axs, panels):
    ax.boxplot(
        dataframe, linewidth=0.7, marker='.', fillcolor='gray5',
        medianlw=1, mediancolor='k', meancolor='k', meanlw=1
    )
    ax_y_max = max(ax.get_ylim())
    # Same annotation size on both panels (previously 14 vs. 24).
    ax.text(0.5, ax_y_max-0.1*(ax_y_max), stars(p_value),
           horizontalalignment='center',
           verticalalignment='center',
           fontsize=14)
    ax.format(title=title)

## Adversarial attack images

In [None]:
models = fmodels + analyzers
codebases = ['FB'] * len(fmodels) + ['DSC'] * len(analyzers)
# Each MSD array covers only the images that every model of its codebase attacked
# successfully. These index arrays map a position in that array back to an image index.
success_index_sets = ([fb_all_success_indices] * len(fmodels)
                      + [dsc_all_success_indices] * len(analyzers))
fig, axs = plot.subplots(nrows=3, ncols=len(models))
pf.clear_axes(axs)
top_level = zip(models, fb_adv_results_list + dsc_adv_results_list, success_index_sets, codebases)
for model_idx, (model, adv_results, success_indices, atk_type) in enumerate(top_level):
    if atk_type == 'DSC':
        adv_imgs = model.adversarial_images[0]
        adv_labels = model.adversarial_labels[0]
    else:
        adv_imgs = model.adversarial_images
        adv_labels = model.adversarial_labels
    adv_results = np.asarray(adv_results).ravel()
    # Positions (within adv_results) of the min, closest-to-mean and max MSD.
    picks = [np.abs(adv_results - target).argmin()
             for target in (adv_results.min(), adv_results.mean(), adv_results.max())]
    for row_idx, result_idx in enumerate(picks):
        image_idx = success_indices[result_idx]
        img = adv_imgs[image_idx, ...].reshape(28, 28).astype(np.float32)
        h = axs[row_idx, model_idx].imshow(img, cmap='grays')
        axs[row_idx, model_idx].colorbar(h, loc='r', ticks=1)
        axs[row_idx, model_idx].format(title=f'{atk_type}_{model.model_type} adversarial label = {adv_labels[image_idx]}')
axs.format(llabels=['Min MSD', 'Mean MSD', 'Max MSD'])
plot.show()