In [None]:
import os
import sys

ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd()))
if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)

import numpy as np
import pandas as pd
import proplot as plot
from scipy.stats import pearsonr as linear_correlation
from scipy.stats import wilcoxon

from DeepSparseCoding.tf1x.utils.logger import Logger as Logger
import DeepSparseCoding.tf1x.analysis.analysis_picker as ap
import DeepSparseCoding.tf1x.data.data_selector as ds
import DeepSparseCoding.tf1x.utils.data_processing as dp
import DeepSparseCoding.utils.plot_functions as pf

# Seed shared by params.rand_seed and by the notebook's own RandomState.
rand_seed = 123
rand_state = np.random.RandomState(seed=rand_seed)

In [None]:
def bin_conf_acc_prop(softmaxes, labels, bin_boundaries):
    """Per-bin confidence, accuracy and sample proportion for a reliability diagram / ECE.

    Each sample is assigned to the bin ``(lower, upper]`` that contains its
    confidence, i.e. its maximum softmax probability. A prediction counts as
    correct when the argmax class equals the true class.

    Args:
        softmaxes: (num_samples, num_classes) array of class probabilities.
        labels: true labels, either one-hot (num_samples, num_classes) or dense
            integer class indices (num_samples,).
        bin_boundaries: increasing 1-D sequence of bin edges, e.g.
            ``np.linspace(0, 1, n_bins + 1)``.

    Returns:
        bin_confidence: mean confidence of each *non-empty* bin.
        bin_accuracy: fraction of correct predictions in each *non-empty* bin.
        bin_prop: fraction of all samples in *every* bin (empty bins give 0).
            This list is longer than the other two whenever some bin is empty.
            Filter it with ``prop > 0`` before zipping it with them.
    """
    softmaxes = np.asarray(softmaxes)
    labels = np.asarray(labels)
    bin_boundaries = np.asarray(bin_boundaries)

    confidences = softmaxes.max(axis=1)
    predictions = softmaxes.argmax(axis=1)
    if labels.ndim > 1 and labels.shape[1] > 1:
        true_classes = labels.argmax(axis=1)  # one-hot -> class index
    else:
        true_classes = labels.reshape(-1)
    # One boolean per sample. Comparing one-hot vectors element-wise and averaging
    # would credit a wrong prediction with (num_classes - 2) / num_classes accuracy.
    accuracies = predictions == true_classes

    bin_confidence, bin_accuracy, bin_prop = [], [], []
    for bin_lower, bin_upper in zip(bin_boundaries[:-1], bin_boundaries[1:]):
        in_bin = (confidences > bin_lower) & (confidences <= bin_upper)
        prop_in_bin = in_bin.astype(np.float32).mean()
        bin_prop.append(prop_in_bin)
        if prop_in_bin > 0:
            bin_accuracy.append(accuracies[in_bin].astype(np.float32).mean())
            bin_confidence.append(confidences[in_bin].mean())
    return bin_confidence, bin_accuracy, bin_prop

### Load DeepSparseCoding analyzer & data

In [None]:
class params:
    """Analysis configuration handed to the DeepSparseCoding analyzers and data loader."""

    def __init__(self):
        # Execution / bookkeeping
        self.device = "/gpu:0"
        self.analysis_dataset = "test"
        self.save_info = "analysis_" + self.analysis_dataset
        self.overwrite_analysis_log = False

        # Analysis switches: only the class-adversary analysis is enabled.
        self.do_class_adversaries = True
        for disabled_analysis in (
                "do_run_analysis", "do_evals", "do_basis_analysis",
                "do_inference", "do_atas", "do_recon_adversaries",
                "do_neuron_visualization", "do_full_recon",
                "do_orientation_analysis", "do_group_recons"):
            setattr(self, disabled_analysis, False)

        # Adversarial params
        # One softmax temperature per analyzer, looked up by analyzer position.
        self.temperature = [0.5, 0.68]
        self.adversarial_attack_method = "kurakin_targeted"
        self.adversarial_step_size = 0.005  # learning rate for the attack optimizer
        self.adversarial_num_steps = 500  # number of attack iterations
        self.confidence_threshold = 0.9
        self.adversarial_max_change = None  # max perturbation size (epsilon); None = unbounded
        self.carlini_change_variable = False  # use the change-of-variable trick from Carlini et al.
        self.adv_optimizer = "sgd"  # attack optimizer
        self.adversarial_target_method = "random"  # unused for untargeted attacks; TODO support "specified"
        self.adversarial_clip = True  # clip the final perturbed image
        self.adversarial_clip_range = [0.0, 1.0]  # allowed range of image values
        self.adversarial_save_int = 1  # interval at which adversarial examples are saved to the npz file
        self.eval_batch_size = 50  # batch size for computing adversarial examples
        self.adversarial_input_id = None  # which adversarial images to use; None = all
        # For the "specified" target method (class attacks only): list or numpy
        # array of size [adv_batch_size].
        self.adversarial_target_labels = None

        # Data params
        self.data_dir = os.path.join(ROOT_DIR, 'Datasets')
        self.data_type = 'mnist'
        self.vectorize_data = True
        self.rescale_data = True
        self.batch_size = 100
        self.rand_seed = rand_seed

In [None]:
# Alternative model pairs previously analyzed with this notebook:
#model_names = ['mlp_cosyne_mnist', 'slp_lca_768_latent_75_steps_mnist']
#model_names = ['mlp_768_mnist', 'slp_lca_768_latent_mnist']
model_names = ['mlp_1568_mnist', 'slp_lca_1568_latent_mnist']
model_types = ['MLP', 'LCA']  # display label for each entry of model_names
assert model_names and len(model_names) == len(model_types), \
    'model_names and model_types must be non-empty and the same length'

# Directory holding one sub-directory per trained model. Set the
# DSC_PROJECTS_DIR environment variable to override it without editing the notebook.
projects_dir = os.environ.get(
    'DSC_PROJECTS_DIR', os.path.join(os.path.expanduser('~'), 'Work', 'Projects', ''))

analyzers = []
for model_type, model_name in zip(model_types, model_names):
    # Build a fresh params object per model. analyzer.setup() receives it, and a
    # single shared instance would let the next iteration overwrite model_name /
    # model_dir / save_info for any analyzer that keeps a reference to it.
    # After the loop, analysis_params holds the last model's params (later cells use it).
    analysis_params = params()
    analysis_params.projects_dir = projects_dir
    analysis_params.model_name = model_name
    analysis_params.version = '0.0'
    analysis_params.model_dir = os.path.join(projects_dir, model_name)
    model_log_file = os.path.join(
        analysis_params.model_dir, 'logfiles',
        f'{model_name}_v{analysis_params.version}.log')
    model_logger = Logger(model_log_file, overwrite=False)
    model_log_text = model_logger.load_file()
    # Take the last entry returned by read_params for this training log.
    model_params = model_logger.read_params(model_log_text)[-1]
    analysis_params.model_type = model_params.model_type
    analyzer = ap.get_analyzer(analysis_params.model_type)
    # NOTE: yields e.g. 'analysis_test_test_...'. The format is kept as-is because
    # it presumably identifies previously saved analysis outputs.
    analysis_params.save_info = (
        f'analysis_test_{analysis_params.analysis_dataset}'
        f'_linf_{analysis_params.adversarial_max_change}'
        f'_ss_{analysis_params.adversarial_step_size}'
        f'_ns_{analysis_params.adversarial_num_steps}'
        f'_ct_{analysis_params.confidence_threshold}'
        '_confidence_attack'
    )
    analyzer.setup(analysis_params)
    analyzer.model_type = model_type
    analyzer.confidence_threshold = analysis_params.confidence_threshold
    analyzers.append(analyzer)

In [None]:
# Load the dataset described by analysis_params, then apply the first analyzer's
# model preprocessing and reshaping. Both analyzers are evaluated on the result.
reference_model = analyzers[0].model
dsc_data = reference_model.reshape_dataset(
    reference_model.preprocess_dataset(ds.get_data(analysis_params), analysis_params),
    analysis_params,
)

In [None]:
# Configure each analyzer's model for the loaded data shape and its own softmax
# temperature (analysis_params.temperature is indexed by analyzer position).
assert len(analysis_params.temperature) >= len(analyzers), \
    'analysis_params.temperature needs one entry per analyzer'
for analyzer_idx, analyzer in enumerate(analyzers):
    analyzer.model_params.data_shape = list(dsc_data['test'].shape[1:])
    analyzer.model_params.temperature = analysis_params.temperature[analyzer_idx]
    analyzer.setup_model(analyzer.model_params)

# First unshuffled batch, used later for the adversarial attack. The counters are
# then reset (presumably so later batching starts from the beginning again).
dsc_image_batch, dsc_label_batch, _ = dsc_data['test'].next_batch(
    analysis_params.batch_size, shuffle_data=False)
dsc_data['test'].reset_counters()

# Whole test set, flattened to (num_images, num_pixels) without hard-coding the image size.
dsc_all_images = dsc_data['test'].images
dsc_all_images = dsc_all_images.reshape((dsc_all_images.shape[0], -1))
dsc_all_labels = dsc_data['test'].labels

### Compare DeepSparseCoding confidence and accuracy

In [None]:
# Cache each analyzer's test-set logits (from get_logits_with_temp) and softmax
# outputs (from get_label_est) as attributes. The batch size comes from the
# analysis config instead of a repeated magic number.
eval_batch_size = analysis_params.eval_batch_size
for analyzer in analyzers:
    analyzer.logits = np.squeeze(analyzer.compute_activations(
        dsc_all_images, batch_size=eval_batch_size,
        activation_operation=analyzer.model.get_logits_with_temp))
    analyzer.softmaxes = np.squeeze(analyzer.compute_activations(
        dsc_all_images, batch_size=eval_batch_size,
        activation_operation=analyzer.model.get_label_est))

In [None]:
n_bins = 75
bin_boundaries = np.linspace(0, 1, n_bins + 1)

fig, axs = plot.subplots(ncols=2)
for ax, model, codebase in zip(axs, analyzers, ['DSC', 'DSC']):
    confidence, accuracy, props = bin_conf_acc_prop(
        model.softmaxes,
        dsc_all_labels,
        bin_boundaries
    )
    # confidence/accuracy have one entry per non-empty bin, but props has one per
    # bin. Keep only the non-empty proportions so the three lists line up bin-for-bin.
    nonempty_props = [prop for prop in props if prop > 0]
    # Expected calibration error in percent: proportion-weighted |confidence - accuracy|.
    ece = 100 * sum(
        abs(float(avg_confidence_in_bin) - float(accuracy_in_bin)) * float(prop_in_bin)
        for avg_confidence_in_bin, accuracy_in_bin, prop_in_bin
        in zip(confidence, accuracy, nonempty_props)
    )
    ax.scatter(confidence, accuracy, s=[prop * 500 for prop in nonempty_props], color='k')
    ax.plot([0, 1], [0, 1], 'k--', linewidth=0.1)
    ax.format(title=f'{codebase}_{model.model_type}\nECE = {round(ece, 4)}%\nTemperature={model.model_params.temperature}')
axs.format(
    suptitle='Reliability of classifier confidence on test set',
    xlabel='Confidence',
    ylabel='Accuracy',
    xlim=[0, 1],
    ylim=[0, 1]
)
plot.show()

In [None]:
# Stack per-model outputs along a new leading "model" axis:
# dsc_*_forward[k] holds analyzers[k]'s output.
dsc_logit_forward = np.stack([analyzer.logits for analyzer in analyzers], axis=0)
dsc_softmax_forward = np.stack([analyzer.softmaxes for analyzer in analyzers], axis=0)

In [None]:
# Pick one of the first batch_size test images with the seeded notebook RNG
# (not the global np.random) so the figure is reproducible on Restart & Run All.
img_idx = rand_state.randint(analysis_params.batch_size)
num_classes = dsc_logit_forward.shape[-1]
fig, axs = plot.subplots(
    [[1, 2], [1, 3]],
    ref=1, axwidth=1.8, span=False
)
im = axs[0].imshow(dsc_all_images[img_idx, ...].reshape(28, 28), cmap='greys_r')
axs[0].format(title=f'DSC dataset digit class {dp.one_hot_to_dense(dsc_all_labels)[img_idx]}')
axs[0].colorbar(im)
pf.clear_axis(axs[0])
axs[0].set_aspect('equal')  # must be called; assigning to it would replace the method
axs[1].bar(np.arange(num_classes), np.squeeze(dsc_logit_forward[0, img_idx, :]))
axs[1].format(title=f'DSC_{analyzers[0].model_type}')
axs[2].bar(np.arange(num_classes), np.squeeze(dsc_logit_forward[1, img_idx, :]))
axs[2].format(title=f'DSC_{analyzers[1].model_type}')
axs.format(
    suptitle='Logit outputs for a single image',
    xtickminor=False,
    xticks=1,
)
plot.show()

In [None]:
# Softmax outputs for the same image as the logit figure. x positions are explicit
# class indices, matching the logit cell instead of relying on an implicit x.
num_classes = dsc_softmax_forward.shape[-1]
fig, axs = plot.subplots(ncols=2, nrows=1)
axs[0].bar(np.arange(num_classes), np.squeeze(dsc_softmax_forward[0, img_idx, :]))
axs[0].format(title=f'DSC_{analyzers[0].model_type}')
axs[1].bar(np.arange(num_classes), np.squeeze(dsc_softmax_forward[1, img_idx, :]))
axs[1].format(title=f'DSC_{analyzers[1].model_type}')
axs.format(suptitle='Softmax confidence for a single image', xtickminor=False, xticks=1, ylim=[0, 1])
plot.show()

In [None]:
# Distribution of every softmax output (all images x all classes), one column per model.
names = [f'DSC_{analyzer.model_type}' for analyzer in analyzers]
softmax_data = pd.DataFrame(
    dsc_softmax_forward.reshape(len(analyzers), -1).T,
    columns=pd.Index(names, name='Model')
)
fig, ax = plot.subplots(ncols=1, axwidth=2.5, share=0)
ax.format(
    grid=False,
    suptitle='Softmax confidence for the test set'
)
ax.boxplot(
    softmax_data, linewidth=0.7, marker='.', fillcolor='gray5',
    medianlw=1, mediancolor='k', meancolor='k', meanlw=1
)
ax.format(yscale='log', yformatter='sci')
plot.show()

In [None]:
num_bins = 100
fig, axs = plot.subplots(ncols=2, nrows=1)
for ax, model, atk_type in zip(axs, analyzers, ['DSC', 'DSC']):
    # Max across categories, per image: the confidence of the *predicted* class.
    max_confidence = np.max(model.softmaxes, axis=1)
    conf_lim = [0, 1]
    # num_bins bins need num_bins + 1 edges.
    bins = np.linspace(conf_lim[0], conf_lim[1], num_bins + 1)
    count, bin_edges = np.histogram(max_confidence, bins)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    ax.bar(bin_centers, count, color='k')
    mean_confidence = np.mean(max_confidence)
    # Draw the line at the exact mean instead of snapping it to the nearest bin edge.
    ax.axvline(mean_confidence, lw=1, ls='--', color='r')
    ax.format(
        title=f'{atk_type}_{model.model_type}\nMean confidence = {mean_confidence:.3f}',
        yscale='log',
        xlim=conf_lim
    )
axs.format(
    suptitle='Max softmax confidence (predicted label) on the clean test set',
    ylabel='Count',
    xlabel='Confidence'
)
plot.show()

### Run DeepSparseCoding adversarial attack

In [None]:
def get_adv_indices(softmax_conf, all_kept_indices, confidence_threshold, num_images, labels):
    """Find images whose most confident *non-true* class exceeds a threshold.

    Args:
        softmax_conf: (num_images, num_classes) softmax outputs. Not modified;
            a copy is edited.
        all_kept_indices: iterable of image indices already kept. These are
            excluded from the result.
        confidence_threshold: strict lower bound on the highest non-true-label
            confidence.
        num_images: number of leading rows whose true-label entry is zeroed.
        labels: dense integer true labels, one per image. The fancy indexing
            below requires integer class indices, not one-hot vectors.

    Returns:
        keep_indices: int32 array of newly found image indices, in increasing order.
        confidence_indices: highest non-true-label confidence per image. Despite
            the name, these are confidences, not indices.
        adversarial_labels: class index of that highest non-true-label confidence.
    """
    # Copy so the caller's softmax array is not silently altered.
    masked_conf = np.array(softmax_conf, copy=True)
    masked_conf[np.arange(num_images, dtype=np.int32), labels] = 0  # zero confidence at true label
    confidence_indices = np.max(masked_conf, axis=-1)  # highest non-true label confidence
    adversarial_labels = np.argmax(masked_conf, axis=-1)  # index of highest non-true label
    # flatnonzero works for a single image too (squeeze would produce a 0-d array).
    all_above_thresh = np.flatnonzero(confidence_indices > confidence_threshold)
    already_kept = set(all_kept_indices)  # build once rather than once per candidate
    keep_indices = np.array(
        [adv_index for adv_index in all_above_thresh if adv_index not in already_kept],
        dtype=np.int32)
    return keep_indices, confidence_indices, adversarial_labels

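### TODO: run the attack on the full test set

The attack below currently uses only the first batch (`dsc_image_batch`, `dsc_label_batch`); running it on `dsc_all_images` / `dsc_all_labels` does not work yet.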

In [None]:
# The attack currently runs on the first analysis batch only. Using the full
# test set (dsc_all_images / dsc_all_labels) is still a TODO.
data = dsc_image_batch
labels = dsc_label_batch

for analyzer in analyzers:
    adv_params = analyzer.analysis_params
    analyzer.class_adversary_analysis(
        data,
        labels,
        batch_size=adv_params.eval_batch_size,
        input_id=adv_params.adversarial_input_id,
        target_method=adv_params.adversarial_target_method,
        target_labels=adv_params.adversarial_target_labels,
        save_info=adv_params.save_info,
    )

### Compare DeepSparseCoding adversarial attack results across models

In [None]:
# Record each model's clean accuracy and report it next to its adversarial accuracy.
for analyzer in analyzers:
    clean_accuracy = analyzer.adversarial_clean_accuracy.item()
    analyzer.accuracy = clean_accuracy
    print(f'DSC {analyzer.model_type} clean accuracy = {clean_accuracy} and adv accuracy = {analyzer.adversarial_adv_accuracy}')

In [None]:
def stars(p):
    """Map a p-value to its significance-star annotation ('n.s.' when p >= 0.05)."""
    for threshold, annotation in ((0.0001, '****'), (0.001, '***'), (0.01, '**'), (0.05, '*')):
        if p < threshold:
            return annotation
    return 'n.s.'
    
# NOTE(review): these labels say 768N, but model_names currently selects the
# *_1568_* models. Confirm the labels match the loaded checkpoints.
names = ['MLP 2L;768N','LCA 2L;768N']

# Keep only image indices present in every analyzer's success_indices so the
# models' distances can be compared image-for-image.
dsc_all_success_indices = np.intersect1d(*[analyzer.success_indices for analyzer in analyzers])
dsc_adv_results_list = [analyzer.mean_squared_distances[0][dsc_all_success_indices] for analyzer in analyzers]
# Use reshape rather than squeeze. Squeeze collapses a single shared success to 1-D,
# which breaks the [:, model] indexing below.
dsc_all_results = np.stack(dsc_adv_results_list, axis=-1).reshape(-1, len(analyzers))
dsc_dataframe = pd.DataFrame(
    dsc_all_results,
    columns=pd.Index(names, name='Model')
)

# Paired, two-sided Wilcoxon signed-rank test: do the two models' MSDs differ on
# the same images? A correlation p-value would only test whether the distances
# co-vary, which is not what the significance stars over the boxes claim.
first_msd, second_msd = dsc_all_results[:, 0], dsc_all_results[:, 1]
if np.any(first_msd != second_msd):
    dsc_p_value = wilcoxon(first_msd, second_msd).pvalue
else:
    # All differences are zero, so there is nothing to test (wilcoxon may raise here).
    dsc_p_value = 1.0

fig, axs = plot.subplots(ncols=1, axwidth=2.5, share=0)
axs.format(grid=False, suptitle='L infinity Attack Mean Squared Distances')

ax = axs[0]
obj2 = ax.boxplot(
    dsc_dataframe, linewidth=0.7, marker='.', fillcolor='gray5',
    medianlw=1, mediancolor='k', meancolor='k', meanlw=1
)
ax_y_max = max(ax.get_ylim())
ax.text(0.5, ax_y_max-0.1*(ax_y_max), stars(dsc_p_value),
       horizontalalignment='center',
       verticalalignment='center',
       fontsize=14)
ax.format(title='Deep Sparse Coding')#, ylim=[0, 0.03])

### Adversarial images at the minimum, mean, and maximum MSD

In [None]:
# For each model, show the adversarial image with the smallest, closest-to-mean
# and largest mean squared distance among the images every model attacked successfully.
fig, axs = plot.subplots(nrows=3, ncols=len(analyzers))
pf.clear_axes(axs)
top_level = zip(analyzers, dsc_adv_results_list, ['DSC', 'DSC'])
for model_idx, (model, adv_results, atk_type) in enumerate(top_level):
    if atk_type == 'DSC':
        adv_imgs = model.conf_adversarial_images[0]
        adv_labels = model.conf_adversarial_labels[0]
    else:
        adv_imgs = model.conf_adversarial_images
        adv_labels = model.conf_adversarial_labels
    # adv_results is ordered by position within dsc_all_success_indices. Map those
    # positions back to image indices before indexing the image/label arrays.
    # NOTE(review): assumes conf_adversarial_images/labels[0] are indexed by image,
    # like mean_squared_distances[0] — confirm against the analyzer.
    adv_results = np.ravel(adv_results)
    subset_positions = [
        adv_results.argmin(),
        np.abs(adv_results - adv_results.mean()).argmin(),
        adv_results.argmax(),
    ]
    for row_idx, subset_pos in enumerate(subset_positions):
        image_idx = dsc_all_success_indices[subset_pos]
        img = adv_imgs[image_idx, ...].reshape(28, 28).astype(np.float32)
        h = axs[row_idx, model_idx].imshow(img, cmap='grays')
        axs[row_idx, model_idx].colorbar(h, loc='r', ticks=1)
        # Label of the image actually shown (not of the row number).
        axs[row_idx, model_idx].format(title=f'{atk_type}_{model.model_type} adversarial label = {adv_labels[image_idx]}')
axs.format(llabels=['Min MSD', 'Mean MSD', 'Max MSD'])
plot.show()