# Neural Computation Paper Figures

### Imports

In [None]:
import os
os.chdir("../")
%env CUDA_VISIBLE_DEVICES=0
%matplotlib inline

In [None]:
import re
import numpy as np
from functools import reduce as reduce
from skimage.io import imread
from skimage.util import crop
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.ticker as plticker
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.patches import FancyArrowPatch
from matplotlib.lines import Line2D
from mpl_toolkits.mplot3d import proj3d
import matplotlib.font_manager
import matplotlib.cm as cm
import pickle
import tensorflow as tf
from utils.logger import Logger
from data.dataset import Dataset
import data.data_selector as ds
import analysis.analysis_picker as ap
import utils.data_processing as dp
import utils.plot_functions as pf
import utils.neural_comp_funcs as nc

### Parameters

In [None]:
# Figure geometry shared by every plot below.
#   text_width: document text width in points (416.83269 pt = 14.65 cm).
#   fontsize:   base font size; tick labels use fontsize - 2.
#   dpi:        resolution used when creating and saving figures (alternative value: 800).
text_width, fontsize, dpi = 416.83269, 12, 1200

In [None]:
# LaTeX-rendered serif text; titles, labels and legends at `fontsize`, tick labels two points smaller.
font_settings = {"text.usetex": True, "font.family": "serif"}
font_settings.update({rc_key: fontsize for rc_key in
  ("axes.labelsize", "axes.titlesize", "figure.titlesize", "font.size", "legend.fontsize")})
font_settings.update({rc_key: fontsize - 2 for rc_key in ("xtick.labelsize", "ytick.labelsize")})
mpl.rcParams.update(font_settings)

In [None]:
class _ModelAnalysisParams(object):
  """Shared attribute container for the per-model analysis parameter classes below.

  Every subclass takes no constructor arguments and sets the same six instance attributes:
    model_type: key passed to ap.get_analyzer() to choose the analyzer.
    model_name: directory name of the model under ~/Work/Projects/.
    display_name: human-readable model name (presumably used as a figure label).
    version: model/analysis version string.
    save_info: suffix of the saved analysis run, passed to analyzer.load_analysis().
    overwrite_analysis_log: always False for these figure runs.
  """
  def __init__(self, model_type, model_name, display_name, version, save_info):
    self.model_type = model_type
    self.model_name = model_name
    self.display_name = display_name
    self.version = version
    self.save_info = save_info
    self.overwrite_analysis_log = False


# --- Van Hateren natural-image models ---

class lca_512_vh_params(_ModelAnalysisParams):
  def __init__(self):
    super().__init__("lca", "lca_512_vh", "Sparse Coding 512", "0.0", "analysis_train_carlini_targeted")

class lca_768_vh_params(_ModelAnalysisParams):
  def __init__(self):
    super().__init__("lca", "lca_768_vh", "Sparse Coding 768", "0.0", "analysis_train_carlini_targeted")

class lca_1024_vh_params(_ModelAnalysisParams):
  def __init__(self):
    super().__init__("lca", "lca_1024_vh", "Sparse Coding 1024", "0.0", "analysis_train_carlini_targeted")

class lca_2560_vh_params(_ModelAnalysisParams):
  def __init__(self):
    super().__init__("lca", "lca_2560_vh", "Sparse Coding", "0.0", "analysis_train_kurakin_targeted")

class sae_768_vh_params(_ModelAnalysisParams):
  def __init__(self):
    super().__init__("sae", "sae_768_vh", "Sparse Autoencoder", "1.0", "analysis_train_kurakin_targeted")

class rica_768_vh_params(_ModelAnalysisParams):
  def __init__(self):
    super().__init__("rica", "rica_768_vh", "Linear Autoencoder", "0.0", "analysis_train_kurakin_targeted")

class ae_768_vh_params(_ModelAnalysisParams):
  def __init__(self):
    super().__init__("ae", "ae_768_vh", "ReLU", "1.0", "analysis_train_kurakin_targeted")

# --- MNIST models ---

class lca_768_mnist_params(_ModelAnalysisParams):
  def __init__(self):
    super().__init__("lca", "lca_768_mnist", "Sparse Coding 768", "0.0", "analysis_train_kurakin_targeted")

class lca_1536_mnist_params(_ModelAnalysisParams):
  def __init__(self):
    super().__init__("lca", "lca_1536_mnist", "Sparse Coding 1536", "0.0", "analysis_test_carlini_targeted")

class ae_768_mnist_params(_ModelAnalysisParams):
  def __init__(self):
    super().__init__("ae", "ae_768_mnist", "Leaky ReLU", "0.0", "analysis_test_carlini_targeted")

class sae_768_mnist_params(_ModelAnalysisParams):
  def __init__(self):
    super().__init__("sae", "sae_768_mnist", "Sparse Autoencoder", "0.0", "analysis_test_carlini_targeted")

class rica_768_mnist_params(_ModelAnalysisParams):
  def __init__(self):
    super().__init__("rica", "rica_768_mnist", "Linear Autoencoder", "0.0", "analysis_train_kurakin_targeted")

class ae_deep_mnist_params(_ModelAnalysisParams):
  def __init__(self):
    super().__init__("ae", "ae_deep_mnist", "Leaky ReLU", "0.0", "analysis_test_carlini_targeted")

In [None]:
# Number of contour levels handed to nc.plot_group_iso_contours.
num_levels = 10
# Named hex colors: black plus light / medium / dark shades of green, blue and red.
# NOTE(review): "md_green" (#196F3D) is darker than "dk_green" (#27AE60), unlike the blue and
# red ramps; the names look swapped. Values are kept as-is because changing them would
# change the colors of already-published figures.
color_vals = {
  "blk": "#000000",
  "lt_green": "#A9DFBF",
  "md_green": "#196F3D",
  "dk_green": "#27AE60",
  "lt_blue": "#AED6F1",
  "md_blue": "#3498DB",
  "dk_blue": "#21618C",
  "lt_red": "#F5B7B1",
  "md_red": "#E74C3C",
  "dk_red": "#943126",
}

### Iso-contour activations comparison

In [None]:
# The four models compared in the iso-response contour figure.
params_list = [rica_768_vh_params(), ae_768_vh_params(), sae_768_vh_params(), lca_2560_vh_params()]
for params in params_list:
  params.model_dir = f"{os.path.expanduser('~')}/Work/Projects/{params.model_name}"
# Build every analyzer first, then configure each one and load its saved analysis.
analyzer_list = list(map(lambda model_params: ap.get_analyzer(model_params.model_type), params_list))
for analyzer, params in zip(analyzer_list, params_list):
  analyzer.setup(params)
  analyzer.model.setup(analyzer.model_params)
  analyzer.load_analysis(save_info=params.save_info)
  analyzer.model_name = params.model_name

In [None]:
# Load the saved iso-response contour analysis for each model. The files are named
# savefiles/<kind><save_name><save_info>.npz; an empty save_name selects the default run.
save_name = ""


def _load_iso_file(analyzer, stem, run_name, unwrap):
  """Return the "data" entry of one iso-contour savefile for `analyzer`.

  `stem` is the file-name prefix (e.g. "iso_params_"). When `unwrap` is True, the stored
  0-d object array is unwrapped with `.item()` (a dict for the params/vectors/dataset files).
  """
  path = (analyzer.analysis_out_dir + "savefiles/" + stem + run_name
    + analyzer.analysis_params.save_info + ".npz")
  data = np.load(path, allow_pickle=True)["data"]
  return data.item() if unwrap else data


for analyzer in analyzer_list:
  run_params = _load_iso_file(analyzer, "iso_params_", save_name, True)
  min_angle, max_angle = run_params["min_angle"], run_params["max_angle"]
  num_neurons = run_params["num_neurons"]
  use_bf_stats = run_params["use_bf_stats"]
  analyzer.num_comparison_vectors = run_params["num_comparison_vects"]
  # NOTE(review): x_range / y_range / num_images are module-level, so only the LAST
  # analyzer's values survive; the contour cell below uses them for every model.
  x_range, y_range = run_params["x_range"], run_params["y_range"]
  num_images = run_params["num_images"]

  iso_vectors = _load_iso_file(analyzer, "iso_vectors_", save_name, True)
  for vector_key in ("target_neuron_ids", "comparison_neuron_ids", "target_vectors",
      "rand_orth_vectors", "comparison_vectors"):
    setattr(analyzer, vector_key, iso_vectors[vector_key])

  analyzer.comp_activations = _load_iso_file(analyzer, "iso_comp_activations_", save_name, False)
  analyzer.comp_contour_dataset = _load_iso_file(analyzer, "iso_comp_contour_dataset_", save_name, True)
  analyzer.rand_activations = _load_iso_file(analyzer, "iso_rand_activations_", save_name, False)
  analyzer.rand_contour_dataset = _load_iso_file(analyzer, "iso_rand_contour_dataset_", save_name, True)

In [None]:
# For each model: which target neuron (index into target_neuron_ids) and which comparison
# direction for that neuron (index into comparison_neuron_ids[neuron]) to plot.
neuron_indices = [0, 0, 0, 0]
orth_indices = [0, 0, 0, 1]
num_plots_y, num_plots_x = 2, 2
width_fraction = 1.0
show_contours = True

contour_fig, contour_handles = nc.plot_group_iso_contours(analyzer_list, neuron_indices, orth_indices,
  num_levels, x_range, y_range, show_contours, text_width, width_fraction, dpi)
# The same combined figure is written into every model's vis/ directory, named after the
# neuron pair plotted for that model.
for analyzer, neuron_index, orth_index in zip(analyzer_list, neuron_indices, orth_indices):
  for ext in [".eps"]:  # add ".png" for a raster copy
    neuron_str = str(analyzer.target_neuron_ids[neuron_index])
    orth_str = str(analyzer.comparison_neuron_ids[neuron_index][orth_index])
    save_name = "".join([
      analyzer.analysis_out_dir, "/vis/iso_contour_comparison_",
      "" if show_contours else "continuous_",
      "bf0id", neuron_str, "_bf1id", orth_str, "_", analyzer.analysis_params.save_info, ext])
    contour_fig.savefig(save_name, dpi=dpi, transparent=False, bbox_inches="tight", pad_inches=0.01)

### Curvature histogram

In [None]:
params_list = [lca_512_vh_params(), lca_768_vh_params(), lca_2560_vh_params()]
# Run names (file-name infixes) of the two saved contour analyses used in the figure.
iso_save_name = "iso_curvature_xrange1.3_yrange-2.2_"
attn_save_name = "1d_"

for params in params_list:
  params.model_dir = f"{os.path.expanduser('~')}/Work/Projects/{params.model_name}"


def _load_curvature_savefile(analyzer, stem, run_name, unwrap):
  """Return the "data" entry of savefiles/<stem><run_name><save_info>.npz for `analyzer`.

  When `unwrap` is True, the stored 0-d object array is unwrapped with `.item()`.
  """
  path = (analyzer.analysis_out_dir + "savefiles/" + stem + run_name
    + analyzer.analysis_params.save_info + ".npz")
  data = np.load(path, allow_pickle=True)["data"]
  return data.item() if unwrap else data


analyzer_list = [ap.get_analyzer(params.model_type) for params in params_list]
for analyzer, params in zip(analyzer_list, params_list):
  analyzer.setup(params)
  analyzer.model.setup(analyzer.model_params)
  analyzer.load_analysis(save_info=params.save_info)
  analyzer.model_name = params.model_name

  # Iso-response run.
  analyzer.iso_params = _load_curvature_savefile(analyzer, "iso_params_", iso_save_name, True)
  analyzer.iso_comp_activations = _load_curvature_savefile(analyzer, "iso_comp_activations_", iso_save_name, False)
  analyzer.iso_comp_contour_dataset = _load_curvature_savefile(analyzer, "iso_comp_contour_dataset_", iso_save_name, True)
  analyzer.iso_rand_activations = _load_curvature_savefile(analyzer, "iso_rand_activations_", iso_save_name, False)
  analyzer.iso_rand_contour_dataset = _load_curvature_savefile(analyzer, "iso_rand_contour_dataset_", iso_save_name, True)
  analyzer.iso_num_target_neurons = analyzer.iso_params["num_neurons"]
  analyzer.iso_num_comparison_vectors = analyzer.iso_params["num_comparison_vects"]

  # Response-attenuation run (same file layout, different run name).
  analyzer.attn_params = _load_curvature_savefile(analyzer, "iso_params_", attn_save_name, True)
  analyzer.attn_comp_activations = _load_curvature_savefile(analyzer, "iso_comp_activations_", attn_save_name, False)
  analyzer.attn_comp_contour_dataset = _load_curvature_savefile(analyzer, "iso_comp_contour_dataset_", attn_save_name, True)
  analyzer.attn_rand_activations = _load_curvature_savefile(analyzer, "iso_rand_activations_", attn_save_name, False)
  analyzer.attn_rand_contour_dataset = _load_curvature_savefile(analyzer, "iso_rand_contour_dataset_", attn_save_name, True)
  analyzer.attn_num_target_neurons = analyzer.attn_params["num_neurons"]
  analyzer.attn_num_comparison_vectors = analyzer.attn_params["num_comparison_vects"]

# The 3-D mesh panel comes from a separate run of the last model in params_list.
# NOTE(review): the [0, 1, ...] slice picks one entry of the first two axes; their meaning is
# defined by the analysis code, not visible here.
mesh_save_name = "iso_curvature_ryan_"
contour_activity = _load_curvature_savefile(analyzer_list[-1], "iso_comp_activations_", mesh_save_name, False)[0, 1, ...]
comp_contour_dataset = _load_curvature_savefile(analyzer_list[-1], "iso_comp_contour_dataset_", mesh_save_name, True)
x = range(comp_contour_dataset["x_pts"].size)
y = range(comp_contour_dataset["y_pts"].size)
contour_pts = (x, y)

In [None]:
# Fit curvatures to the loaded contours, then compute their histograms (presumably stored on
# each analyzer as *_hist / *_bin_edges, which the next cell reads).
# target_act: contour level between the min (0) and max (1) of the normalized activity.
target_act, num_bins = 0.3, 50
nc.compute_curvature_fits(analyzer_list, target_act)
nc.compute_curvature_hists(analyzer_list, num_bins)

In [None]:
# Histogram panel settings: for each panel, one list for comparison-vector curvatures and one
# for random-vector curvatures, each with one label/color per model.
label_list = [["2x", "4x", "10x"]]*2
color_list = [[color_vals["lt_red"], color_vals["md_red"], color_vals["dk_red"]],
  [color_vals["lt_blue"], color_vals["md_blue"], color_vals["dk_blue"]]]
curve_lims = {
    "x":[min(comp_contour_dataset["x_pts"]), max(comp_contour_dataset["x_pts"])],
    "y":[min(comp_contour_dataset["y_pts"]), max(comp_contour_dataset["y_pts"])]
}

# NOTE(review): bar positions come from the first analyzer's bin edges for all models, which
# assumes every analyzer was histogrammed with the same edges.
iso_title = "Iso-Response"
iso_hist_list = [[analyzer.iso_comp_hist for analyzer in analyzer_list],
  [analyzer.iso_rand_hist for analyzer in analyzer_list]]
plot_bin_lefts, plot_bin_rights = analyzer_list[0].iso_bin_edges[:-1], analyzer_list[0].iso_bin_edges[1:]
# Midpoint of each bin. Without the /2 this would equal the right edges and shift every bar
# by half a bin.
iso_plot_bin_centers = plot_bin_lefts + (plot_bin_rights - plot_bin_lefts) / 2

attn_title = "Response Attenuation"
attn_hist_list = [[analyzer.attn_comp_hist for analyzer in analyzer_list],
  [analyzer.attn_rand_hist for analyzer in analyzer_list]]
plot_bin_lefts, plot_bin_rights = analyzer_list[0].attn_bin_edges[:-1], analyzer_list[0].attn_bin_edges[1:]
attn_plot_bin_centers = plot_bin_lefts + (plot_bin_rights - plot_bin_lefts) / 2

full_hist_list = [iso_hist_list, attn_hist_list]
full_label_list = [label_list,]*2
full_color_list = [color_list,]*2
full_bin_centers = [iso_plot_bin_centers, attn_plot_bin_centers]
full_title = [iso_title, attn_title]
full_xlabel = ["Curvature (Comparison)", "Curvature (Random)"]

In [None]:
# 3-D activity mesh plus curvature histograms.
mesh_color = "#A9A9A9"
contour_angle = 195
# Curvature line label locations; the numbering is roughly [z, x, y] from a 2-D perspective.
# (Earlier values: resp_att_loc = [105, 268, 0.38], iso_resp_loc = [-18, 215, 0.41].)
resp_att_loc = [105, 38, 0.80]
iso_resp_loc = [-18, 20, 0.11]
contour_text_loc = [iso_resp_loc, resp_att_loc]

curvature_fig = nc.plot_curvature_histograms(contour_activity, contour_pts, contour_angle, contour_text_loc,
  full_hist_list, full_label_list, full_color_list, mesh_color, full_bin_centers, full_title, full_xlabel,
  curve_lims, text_width=text_width, width_ratio=1.0, dpi=dpi)

# Save one copy of the figure into each model's vis/ directory.
for analyzer in analyzer_list:
  for ext in (".pdf",):
    save_name = f"{analyzer.analysis_out_dir}/vis/curvatures_and_histograms_{analyzer.analysis_params.save_info}{ext}"
    curvature_fig.savefig(save_name, transparent=False, bbox_inches="tight", pad_inches=0.01, dpi=dpi)

### Orientation Selectivity

In [None]:
# Three 768-neuron model families compared for orientation selectivity.
params_list = [rica_768_vh_params(), sae_768_vh_params(), lca_768_vh_params()]
for params in params_list:
  params.model_dir = f"{os.path.expanduser('~')}/Work/Projects/{params.model_name}"
# Drop the "768" suffix from the sparse coding label for this figure.
params_list[-1].display_name = "Sparse Coding"
analyzer_list = list(map(lambda model_params: ap.get_analyzer(model_params.model_type), params_list))
for analyzer, params in zip(analyzer_list, params_list):
  analyzer.setup(params)
  analyzer.model.setup(analyzer.model_params)
  analyzer.load_analysis(save_info=params.save_info)
  analyzer.model_name = params.model_name

In [None]:
# Orientation-tuning metrics (FWHM, circular variance, OSI) for every neuron at the last
# contrast index (presumably the highest contrast). Neurons whose tuning curve is flat are
# recorded in "skipped_indices" and contribute no metrics.
circ_var_list = []
for analyzer in analyzer_list:
  # 12 neuron indices sampled with replacement (unseeded).
  analyzer.bf_indices = np.random.choice(analyzer.ot_grating_responses["neuron_indices"], 12)
  num_orientation_samples = len(analyzer.ot_grating_responses['orientations'])
  # Evenly spaced orientations over [-90, 90) degrees and the same grid in radians.
  corresponding_angles_deg = (180 * np.arange(num_orientation_samples) / num_orientation_samples) - 90
  corresponding_angles_rad = (np.pi * np.arange(num_orientation_samples) / num_orientation_samples) - (np.pi/2)
  analyzer.metrics_list = dict(fwhm=[], circ_var=[], osi=[], skipped_indices=[])
  contrast_idx = -1
  for bf_idx in range(analyzer.bf_stats["num_outputs"]):
    ot_curve = nc.center_curve(analyzer.ot_grating_responses["mean_responses"][bf_idx, contrast_idx, :])
    if np.max(ot_curve) - np.min(ot_curve) == 0:
      analyzer.metrics_list["skipped_indices"].append(bf_idx)
      continue
    fwhm = nc.compute_fwhm(ot_curve, corresponding_angles_deg)
    analyzer.metrics_list["fwhm"].append(fwhm)
    circ_var = nc.compute_circ_var(ot_curve, corresponding_angles_rad)
    analyzer.metrics_list["circ_var"].append(circ_var)
    osi = nc.compute_osi(ot_curve)
    analyzer.metrics_list["osi"].append(osi)
  # Element 2 of each compute_circ_var result (presumably the circular variance value itself).
  circ_var_list.append(np.array([cv_result[2] for cv_result in analyzer.metrics_list["circ_var"]]))

In [None]:
# Circular-variance histogram comparing the three model families (one color per model).
color_list = [color_vals[shade] for shade in ("md_green", "md_blue", "md_red")]
label_list = ["Linear Autoencoder", "Sparse Autoencoder", "Sparse Coding"]
num_bins = 30
width_ratios = [0.5, 0.25, 0.25]
height_ratios = [0.13, 0.25, 0.25, 0.25]
density = False  # presumably raw counts rather than normalized histograms

circ_var_fig = nc.plot_circ_variance_histogram(analyzer_list, circ_var_list, color_list, label_list, num_bins,
  density, width_ratios, height_ratios, text_width=text_width, width_ratio=1.0, dpi=dpi)

for analyzer in analyzer_list:
  for ext in (".png", ".eps"):
    save_name = f"{analyzer.analysis_out_dir}/vis/circular_variance_combo_{analyzer.analysis_params.save_info}{ext}"
    circ_var_fig.savefig(save_name, transparent=False, bbox_inches="tight", pad_inches=0.01, dpi=dpi)

In [None]:
# Scatter circular variance against weight spatial frequency for each model.
# circ_var_list (previous cell) has no entry for neurons with a flat tuning curve, so those
# neurons are dropped from the spatial frequencies too. This keeps x and y aligned; otherwise
# ax.scatter fails on mismatched sizes whenever any neuron was skipped.
# NOTE(review): assumes bf_stats["spatial_frequencies"] has one entry per neuron in
# range(bf_stats["num_outputs"]), in the order the previous cell iterated.
spatial_frequencies = [
  np.delete(np.asarray(analyzer.bf_stats["spatial_frequencies"]),
    np.asarray(analyzer.metrics_list["skipped_indices"], dtype=int))
  for analyzer in analyzer_list]
# Per-model lists rather than a stacked 2-D array, because models may skip different neurons.
circular_variances = list(circ_var_list)

cv_vs_sf_fig = plt.figure(figsize=nc.set_size(text_width), dpi=dpi)
ax = cv_vs_sf_fig.add_subplot()
for analyzer_idx in range(len(analyzer_list)):
  ax.scatter(spatial_frequencies[analyzer_idx], circular_variances[analyzer_idx],
    s=12, color=color_list[analyzer_idx], label=label_list[analyzer_idx])
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.set_xlabel("Spatial Frequency (Cycles/Patch)")
ax.set_ylabel("Circular Variance")
# Text-only legend: no frame line, labels colored like their series, marker handles hidden.
legend = ax.legend(loc="upper center", framealpha=1.0, ncol=1, borderaxespad=0., borderpad=0.,
  handlelength=0., labelspacing=0.1, bbox_to_anchor=(0.38, 0.98))
legend.get_frame().set_linewidth(0.0)
for text, color in zip(legend.get_texts(), color_list):
  text.set_color(color)
# Legend.legendHandles was renamed legend_handles in Matplotlib 3.7 and removed in 3.9.
legend_handles = legend.legend_handles if hasattr(legend, "legend_handles") else legend.legendHandles
for item in legend_handles:
  item.set_visible(False)
plt.show()

for analyzer in analyzer_list:
  for ext in [".eps"]:
    save_name = (analyzer.analysis_out_dir+"/vis/spatial_freq_vs_circular_variance"
      +"_"+analyzer.analysis_params.save_info+ext)
    cv_vs_sf_fig.savefig(save_name, transparent=False, bbox_inches="tight", pad_inches=0.01, dpi=dpi)

In [None]:
# Sparse coding models with increasing numbers of neurons. (lca_2560_vh_params() with the
# label "2560 Neurons" can be appended to both lists for a fourth size.)
params_list = [lca_512_vh_params(), lca_768_vh_params(), lca_1024_vh_params()]
display_names = ["512 Neurons", "768 Neurons", "1024 Neurons"]
for params, display_name in zip(params_list, display_names):
  params.display_name = display_name
  params.model_dir = f"{os.path.expanduser('~')}/Work/Projects/{params.model_name}"
analyzer_list = list(map(lambda model_params: ap.get_analyzer(model_params.model_type), params_list))
for analyzer, params in zip(analyzer_list, params_list):
  analyzer.setup(params)
  analyzer.model.setup(analyzer.model_params)
  analyzer.load_analysis(save_info=params.save_info)
  analyzer.model_name = params.model_name

In [None]:
# Same orientation-tuning metrics as above, for the sparse coding size sweep: FWHM, circular
# variance and OSI per neuron at the last contrast index (presumably the highest contrast);
# flat tuning curves are recorded in "skipped_indices".
circ_var_list = []
for analyzer in analyzer_list:
  # 12 neuron indices sampled with replacement (unseeded).
  analyzer.bf_indices = np.random.choice(analyzer.ot_grating_responses["neuron_indices"], 12)
  num_orientation_samples = len(analyzer.ot_grating_responses['orientations'])
  # Evenly spaced orientations over [-90, 90) degrees and the same grid in radians.
  corresponding_angles_deg = (180 * np.arange(num_orientation_samples) / num_orientation_samples) - 90
  corresponding_angles_rad = (np.pi * np.arange(num_orientation_samples) / num_orientation_samples) - (np.pi/2)
  analyzer.metrics_list = dict(fwhm=[], circ_var=[], osi=[], skipped_indices=[])
  contrast_idx = -1
  for bf_idx in range(analyzer.bf_stats["num_outputs"]):
    ot_curve = nc.center_curve(analyzer.ot_grating_responses["mean_responses"][bf_idx, contrast_idx, :])
    if np.max(ot_curve) - np.min(ot_curve) == 0:
      analyzer.metrics_list["skipped_indices"].append(bf_idx)
      continue
    fwhm = nc.compute_fwhm(ot_curve, corresponding_angles_deg)
    analyzer.metrics_list["fwhm"].append(fwhm)
    circ_var = nc.compute_circ_var(ot_curve, corresponding_angles_rad)
    analyzer.metrics_list["circ_var"].append(circ_var)
    osi = nc.compute_osi(ot_curve)
    analyzer.metrics_list["osi"].append(osi)
  # Element 2 of each compute_circ_var result (presumably the circular variance value itself).
  circ_var_list.append(np.array([cv_result[2] for cv_result in analyzer.metrics_list["circ_var"]]))

In [None]:
# Circular-variance histogram across dictionary sizes. density=True presumably normalizes
# each histogram, since the models have different numbers of neurons.
# (Append color_vals["blk"] when the 2560-neuron model is included.)
color_list = [color_vals[shade] for shade in ("md_green", "md_blue", "md_red")]
label_list = display_names
num_bins = 30
width_ratios = [0.5, 0.25, 0.25]
height_ratios = [0.13, 0.25, 0.25, 0.25]
density = True

oc_vs_cv_fig = nc.plot_circ_variance_histogram(analyzer_list, circ_var_list, color_list, label_list, num_bins,
  density, width_ratios, height_ratios, text_width=text_width, width_ratio=1.0, dpi=dpi)

for analyzer in analyzer_list:
  for ext in (".png", ".eps"):
    save_name = f"{analyzer.analysis_out_dir}/vis/overcompleteness_vs_circular_variance_{analyzer.analysis_params.save_info}{ext}"
    oc_vs_cv_fig.savefig(save_name, transparent=False, bbox_inches="tight", pad_inches=0.01, dpi=dpi)

## Targeted attacks on MLP and LCA networks

In [None]:
def std_conf(outputs, labels, index):
  """Standard deviation, across samples, of the confidence `outputs[index]` puts on `labels`.

  Each sample's confidence is sum(outputs[index] * labels) over axis 1 (with one-hot labels,
  presumably the probability assigned to the labeled class).
  """
  selected_outputs = outputs[index]
  per_sample_conf = np.sum(selected_outputs * labels, axis=1)
  return np.std(per_sample_conf)

### Functions for finding MSE data
def find_conf_index(analysis, stop_conf):
  """For each image, the first attack step whose target-class confidence reaches `stop_conf`.

  Confidence is sum(adversarial_outputs[0] * target_labels) over the last axis, giving a
  (steps, images) array. Images that never reach `stop_conf` get -1, which callers use as an
  index and therefore select the final step.

  Returns:
    list with one step index per image.
  """
  step_outputs = analysis["adversarial_outputs"][0]
  confs = np.sum(step_outputs * analysis["target_labels"], axis=-1)
  # One array of qualifying step indices per image (columns of confs).
  hits_per_image = [np.flatnonzero(image_confs >= stop_conf) for image_confs in confs.T]
  return [hits[0] if hits.size else -1 for hits in hits_per_image]

### Functions for the adv MSE bar plots
def get_rects(filename_list, metric, stop_conf):
  """Per-file samples of `metric` at each image's stop step, plus their means and stds.

  For every class-adversary analysis file, find each image's stop step (find_conf_index) and
  take analysis[metric][0, stop_step, image] for all images.

  Returns:
    (data, means, stds): lists with one entry per file (an array of per-image values, its
    mean, and its standard deviation).
  """
  data, means, stds = [], [], []
  for filename in filename_list:
    analysis = np.load(filename, allow_pickle=True)["data"].item()
    stop_steps = find_conf_index(analysis, stop_conf)
    metric_values = analysis[metric]
    per_image = metric_values[0, stop_steps, np.arange(metric_values[0].shape[-1])]
    data.append(per_image)
    means.append(np.mean(per_image))
    stds.append(np.std(per_image))
  return data, means, stds

def get_flat_mnist_mse_data(file_list, metric, stop_conf):
  """MNIST metric samples and summary statistics for a flat list of analysis files.

  Thin wrapper around get_rects; returns (data, means, stds) with one entry per file.
  """
  data, means, stds = get_rects(file_list, metric, stop_conf)
  return data, means, stds

def get_results(metric_files, img_files, stop_conf):
  """Original and adversarial examples (and their predicted classes) at each image's stop step.

  `metric_files[i]` and `img_files[i]` are the analysis and image savefiles of the same model.

  Returns:
    input_images: per-model input images.
    input_clf: per-model argmax of the input labels.
    adv_images: per-model adversarial images at each image's stop step.
    adv_clf: per-model argmax of the adversarial outputs at that step.
  """
  img_records = [np.load(path, allow_pickle=True)["data"].item() for path in img_files]
  metric_records = [np.load(path, allow_pickle=True)["data"].item() for path in metric_files]
  stop_steps = [np.array(find_conf_index(record, stop_conf)) for record in metric_records]
  input_images = [record["input_images"] for record in metric_records]
  input_clf = [np.argmax(record["input_labels"], axis=1) for record in metric_records]
  adv_images = []
  for steps, record in zip(stop_steps, img_records):
    adv_images.append(record["adversarial_images"][0, steps, np.arange(len(steps)), :])
  adv_clf = []
  for steps, record in zip(stop_steps, metric_records):
    adv_clf.append(np.argmax(record["adversarial_outputs"][0, steps, np.arange(len(steps))], axis=1))
  return input_images, input_clf, adv_images, adv_clf

def get_mnist_data(metric_files, img_files, stop_conf):
  """Same as get_results, with the input and adversarial images reshaped to (-1, 28, 28).

  Returns:
    (input_images, input_clf, adv_images, adv_clf), each a list with one entry per model.
  """
  input_images, input_clf, adv_images, adv_clf = get_results(metric_files, img_files, stop_conf)
  input_digits = [images.reshape(-1, 28, 28) for images in input_images]
  adv_digits = [images.reshape(-1, 28, 28) for images in adv_images]
  return input_digits, input_clf, adv_digits, adv_clf

def get_cifar_mse_data(saved_info, model_names):
  """Collect each model's "target_adv_mses" from `saved_info` with its mean and std.

  Returns:
    (data, means, stds): lists in the order of `model_names`.
  """
  data = [saved_info[name]["target_adv_mses"] for name in model_names]
  means = [np.mean(mses) for mses in data]
  stds = [np.std(mses) for mses in data]
  return (data, means, stds)

def adjacent_values(vals, q1, q3):
  """Whisker end points for a box/violin plot.

  The whiskers reach 1.5 * IQR beyond the quartiles and are clipped to the data range, so
  they never extend past the most extreme observation.

  `vals` need NOT be sorted. make_violin passes raw per-image MSE arrays, so the data
  range comes from min/max rather than vals[0] / vals[-1].

  Args:
    vals: the data values.
    q1, q3: the first and third quartiles of `vals`.
  Returns:
    (lower_adjacent_value, upper_adjacent_value)
  """
  vals = np.asarray(vals)
  iqr = q3 - q1
  upper_adjacent_value = np.clip(q3 + iqr * 1.5, q3, np.max(vals))
  lower_adjacent_value = np.clip(q1 - iqr * 1.5, np.min(vals), q1)
  return lower_adjacent_value, upper_adjacent_value

def make_violin(ax, group_data, group_means, x_pos, bar_width, color='k', plot_means=True, plot_medians=False):
  """Draw one violin per (data, mean, position) triple on `ax`.

  Each violin is filled with `color` and outlined in black. A thick black bar spans the
  interquartile range and a thin black line spans the whiskers from adjacent_values.
  Optionally, white dots mark the median and/or the supplied mean.

  Returns:
    ax
  """
  for data, means, pos in zip(group_data, group_means, x_pos):
    violin = ax.violinplot(data, [pos], widths=bar_width, showmeans=False, showextrema=False,
      showmedians=False, bw_method="silverman")
    for body in violin['bodies']:
      body.set_facecolor(color)
      body.set_edgecolor('k')
      body.set_alpha(1)
    q1, median, q3 = np.percentile(np.array(data), [25, 50, 75])
    # Shape (2, 1) after the transpose: row 0 holds the lower end, row 1 the upper end.
    whisker_low, whisker_high = np.array([adjacent_values(data, q1, q3)]).T
    if plot_medians:
      ax.scatter(pos, median, marker='o', color='white', s=10, zorder=3, alpha=1)
    if plot_means:
      ax.scatter(pos, means, marker='o', color='white', s=10, zorder=3, alpha=1)
    ax.vlines(pos, q1, q3, color='k', linestyle='-', lw=5, alpha=1)
    ax.vlines(pos, whisker_low, whisker_high, color='k', linestyle='-', lw=1, alpha=1)
  return ax

def convert_cifar_label(label):
  """Map a CIFAR-10 class id (int 0-9, or its string form "0"-"9") to its class name.

  Returns None for anything that is not one of those ids.
  """
  class_names = ("airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck")
  for class_idx, class_name in enumerate(class_names):
    if label == class_idx or label == str(class_idx):
      return class_name
  return None

In [None]:
# Adversarial stop confidence: an attack is read out at the first step whose target-class
# confidence reaches this value (find_conf_index); images that never reach it use the final step.
stop_conf = .95

# path to projects directory
projects_path = os.path.expanduser("~")+"/Work/Projects/"

# kurakin analysis paths (appended to a model directory name)
k_file_path = "/analysis/0.0/savefiles/class_adversary_analysis_test_kurakin_targeted.npz"
k_img_path = "/analysis/0.0/savefiles/class_adversary_images_analysis_test_kurakin_targeted.npz"

# carlini analysis paths (appended to a model directory name; not used in this cell)
c_file_path = "/analysis/0.0/savefiles/class_adversary_analysis_test_carlini_targeted.npz"
c_img_path = "/analysis/0.0/savefiles/class_adversary_images_analysis_test_carlini_targeted.npz"

# model names: LCA-latent classifiers vs. plain MLPs, 768 / 1568 hidden units, 2 / 3 layers
lca_768_2layer = "slp_lca_768_latent_75_steps_mnist"
lca_768_3layer = "mlp_lca_768_latent_75_steps_mnist"
lca_1568_2layer = "slp_lca_1568_latent_75_steps_mnist"
lca_1568_3layer = "mlp_lca_1568_latent_75_steps_mnist"
mlp_768_2layer = "mlp_cosyne_mnist"
mlp_768_3layer = "mlp_3layer_cosyne_mnist"
mlp_1568_2layer = "mlp_1568_mnist"
mlp_1568_3layer = "mlp_1568_3layer_mnist"

# bar chart files/parameters: kurakin analysis files, one [2-layer, 3-layer] pair per model type
lca_1568_files = [projects_path + model_name + k_file_path for model_name in [lca_1568_2layer, lca_1568_3layer]]
lca_768_files = [projects_path + model_name + k_file_path for model_name in [lca_768_2layer, lca_768_3layer]]
mlp_1568_files = [projects_path + model_name + k_file_path for model_name in [mlp_1568_2layer, mlp_1568_3layer]]
mlp_768_files = [projects_path + model_name + k_file_path for model_name in [mlp_768_2layer, mlp_768_3layer]]

xtick_labels = ['2-layer', '3-layer']
file_lists = [mlp_768_files, mlp_1568_files, lca_768_files, lca_1568_files]
# name of the per-step, per-image metric read by get_rects
metric = "input_adv_mses"
colors = ['blue', 'red', 'green', 'orange']
names = ['w/o LCA_768', 'w/o LCA_1568', 'w/ LCA_768', 'w/ LCA_1568']

# adv example image files/parameters: the 768-unit 2-layer and 3-layer models only
mlp_metric_files = [projects_path +  file + k_file_path for file in [mlp_768_2layer, mlp_768_3layer]]
lca_metric_files = [projects_path +  file + k_file_path for file in [lca_768_2layer, lca_768_3layer]]
mlp_img_files = [projects_path +  file + k_img_path for file in [mlp_768_2layer, mlp_768_3layer]]
lca_img_files = [projects_path +  file + k_img_path for file in [lca_768_2layer, lca_768_3layer]]

# NOTE: the "*_target_labels" here are the argmax of the adversarial outputs at the stop step
# (get_mnist_data's adv_clf), i.e. the predicted class; it equals the attack target only when
# the attack succeeded.
mlp_orig_images, mlp_orig_labels, mlp_adv_images, mlp_target_labels = get_mnist_data(mlp_metric_files, mlp_img_files, stop_conf)
# perturbation = original - adversarial, per model
mlp_diff_images = [mlp_orig_images[model_idx] - mlp_adv_images[model_idx]
  for model_idx in range(len(mlp_img_files))]
lca_orig_images, lca_orig_labels, lca_adv_images, lca_target_labels = get_mnist_data(lca_metric_files, lca_img_files, stop_conf)
lca_diff_images = [lca_orig_images[model_idx] - lca_adv_images[model_idx]
  for model_idx in range(len(lca_img_files))]

mnist_mlp_grp = [mlp_orig_images, mlp_adv_images, mlp_diff_images]
mnist_lca_grp = [lca_orig_images, lca_adv_images, lca_diff_images]
mnist_target_labels = [mlp_target_labels, lca_target_labels] # [mlp/lca][which_model][which_image]
mnist_orig_labels = [mlp_orig_labels, lca_orig_labels] # [mlp/lca][which_model][which_image]

# model names
# NOTE(review): these duplicate lca_768_2layer ... mlp_1568_3layer above with identical values.
mnist_lca_768_2layer = "slp_lca_768_latent_75_steps_mnist"
mnist_lca_768_3layer = "mlp_lca_768_latent_75_steps_mnist"
mnist_lca_1568_2layer = "slp_lca_1568_latent_75_steps_mnist"
mnist_lca_1568_3layer = "mlp_lca_1568_latent_75_steps_mnist"
mnist_mlp_768_2layer = "mlp_cosyne_mnist"
mnist_mlp_768_3layer = "mlp_3layer_cosyne_mnist"
mnist_mlp_1568_2layer = "mlp_1568_mnist"
mnist_mlp_1568_3layer = "mlp_1568_3layer_mnist"

# Bar order: (w/ LCA, w/o LCA) pairs for 2L-768, 3L-768, 2L-1568, 3L-1568.
all_mnist_model_names = [mnist_lca_768_2layer, mnist_mlp_768_2layer, mnist_lca_768_3layer, mnist_mlp_768_3layer,
  mnist_lca_1568_2layer, mnist_mlp_1568_2layer, mnist_lca_1568_3layer, mnist_mlp_1568_3layer]
mnist_file_lists = [projects_path + model_name + k_file_path for model_name in all_mnist_model_names]

# CIFAR-10 models in the same (w/ LCA, w/o LCA) pair order; keys into saved_info below.
all_cifar_model_names = [
    "mlp_lca_latent_cifar10_gray_2layer",
    "mlp_cifar10_gray_2layer",
    "mlp_lca_latent_cifar10_gray_3layer",
    "mlp_cifar10_gray_3layer",
    "mlp_lca_latent_cifar10_gray_3136_2layer",
    "mlp_cifar10_gray_3136_2layer",
    "mlp_lca_latent_cifar10_gray_3136_3layer",
    "mlp_cifar10_gray_3136_3layer",
    ]

#Load data: precomputed CIFAR-10 adversarial results, keyed by model name.
# pickle.load can execute arbitrary code; only load this file from a trusted source.
pickle_filename = "vis/CIFAR10_adv_Sheng.pkl"
with open(pickle_filename, "rb") as f:
  saved_info = pickle.load(f)

In [None]:
# Pick a few CIFAR-10 examples available for every model. For each model, collect the original
# image, the adversarial image, their difference, and the target / original class names.
mlp_models = ["mlp_cifar10_gray_2layer", "mlp_cifar10_gray_3layer"]
lca_models = ["mlp_lca_latent_cifar10_gray_2layer", "mlp_lca_latent_cifar10_gray_3layer"]
num_example_images = 5
# Seed for the example selection, so the figure is identical across kernel restarts.
example_img_seed = 0

# Full-image-set indices whose "target_conf_idx" is non-zero for each model. Examples are drawn
# from the intersection so that every model has saved results for them.
indiv_allowable_indices = [np.flatnonzero(saved_info[model]["target_conf_idx"] != 0)
  for model in mlp_models + lca_models]
allowable_indices = reduce(np.intersect1d, indiv_allowable_indices)
img_indices = np.random.default_rng(example_img_seed).choice(
  allowable_indices, num_example_images, replace=False)


def _saved_array_position(target_conf_idx, index):
  """Map `index` in the full image set to its row in a model's saved orig/adv arrays.

  NOTE(review): assumes the saved arrays omit every image whose target_conf_idx is 0, so the
  row is `index` minus the number of such images before it.
  """
  return index - np.count_nonzero(target_conf_idx[:index] == 0)


actual_img_indices = {
  model: [_saved_array_position(saved_info[model]["target_conf_idx"], index) for index in img_indices]
  for model in mlp_models + lca_models}


def _collect_cifar_examples(models):
  """Return (orig_images, adv_images, diff_images, target_labels, orig_labels), one entry per model.

  diff_images are original minus adversarial; labels are CIFAR-10 class names.
  """
  orig_images = [saved_info[model]["orig_img"][actual_img_indices[model], ...] for model in models]
  adv_images = [saved_info[model]["adv_img"][actual_img_indices[model], ...] for model in models]
  diff_images = [orig - adv for orig, adv in zip(orig_images, adv_images)]
  target_labels = [
    [convert_cifar_label(label) for label in saved_info[model]["target_label"][actual_img_indices[model]]]
    for model in models]
  orig_labels = [
    [convert_cifar_label(label) for label in saved_info[model]["orig_label"][actual_img_indices[model]]]
    for model in models]
  return orig_images, adv_images, diff_images, target_labels, orig_labels


# These overwrite the MNIST lists of the same names; the MNIST groups were already captured
# in mnist_mlp_grp / mnist_lca_grp.
mlp_orig_images, mlp_adv_images, mlp_diff_images, mlp_target_labels, mlp_orig_labels = \
  _collect_cifar_examples(mlp_models)
lca_orig_images, lca_adv_images, lca_diff_images, lca_target_labels, lca_orig_labels = \
  _collect_cifar_examples(lca_models)

cifar_mlp_grp = [mlp_orig_images, mlp_adv_images, mlp_diff_images]
cifar_lca_grp = [lca_orig_images, lca_adv_images, lca_diff_images]
cifar_target_labels = [mlp_target_labels, lca_target_labels]  # [mlp/lca][which_model][which_image]
cifar_orig_labels = [mlp_orig_labels, lca_orig_labels]  # [mlp/lca][which_model][which_image]

In [None]:
# Shared layout for the MNIST / CIFAR-10 adversarial MSE bar plots.
labelrotation = 50
bar_width = 0.4
inner_group_names = ["w/ LCA", "w/o LCA"]
# Outer groups: number of layers (L) and hidden units (N), in the order of
# all_mnist_model_names / all_cifar_model_names.
mnist_outer_group_names = ["2L; 768N", "3L; 768N", "2L; 1568N", "3L; 1568N"]
cifar_outer_group_names = ["2L; 1568N", "3L; 1568N", "2L; 3136N", "3L; 3136N"]
# Labels for the example-image rows. The MNIST examples come from the 768-unit 2-layer and
# 3-layer models, and the CIFAR examples from the 2-layer and 3-layer 1568-unit models. So the
# labels are outer groups 0 and 1; groups 0 and 2 would label the 3-layer row "2L".
mnist_img_labels = [mnist_outer_group_names[0], mnist_outer_group_names[1]]
cifar_img_labels = [cifar_outer_group_names[0], cifar_outer_group_names[1]]
COLORS = [
  [1.0, 0.0, 0.0],  # red
  [0.0, 0.0, 1.0],  # blue
]
# bar_groups are organized from left to right, with space between inner lists. Each row
# indexes one (w/ LCA, w/o LCA) pair of models in the data lists below.
bar_groups = np.array([[0, 1], [2, 3], [4, 5], [6, 7]])
num_groups = bar_groups.shape[0]
num_per_group = bar_groups.shape[1]

mnist_data, mnist_means, mnist_stds = get_flat_mnist_mse_data(mnist_file_lists, metric, stop_conf)
cifar_data, cifar_means, cifar_stds = get_cifar_mse_data(saved_info, all_cifar_model_names)

In [None]:
# Every model whose analysis directory a figure may be saved into.
all_model_names = all_mnist_model_names+all_cifar_model_names
# Root directory holding one sub-directory per model.
project_dir = os.path.expanduser("~")+"/Work/Projects/"

In [None]:
def make_boxplot(ax, group_data, group_means, x_pos, bar_width, linewidth=2, color='k', plot_means=True, plot_medians=False):
  """Draw one box-and-whisker glyph per entry of `group_data` on `ax`.

  Parameters
  ----------
  ax : matplotlib Axes
      Axes to draw on; returned for chaining.
  group_data : sequence
      One data collection per box.
  group_means : sequence
      Accepted for interface compatibility and zipped with the other inputs,
      but its values are not used (``ax.boxplot`` computes the mean itself).
  x_pos : sequence of float
      x position of each box.
  bar_width : float
      Width of each box.
  linewidth : float
      Line width for the box, whiskers, caps, median and mean lines.
  color : color spec
      Color of box, whiskers and caps. The median and mean lines are black.
  plot_means : bool
      Whether to draw each box's mean as a solid line.
  plot_medians : bool
      Not honored: the median is always drawn as a dashed line because
      ``medianprops`` is always passed. NOTE(review): kept as-is so existing
      figures do not change; wire it up if a median-free plot is needed.

  Returns
  -------
  ax
  """
  box_style = dict(linestyle='-', linewidth=linewidth, color=color)
  median_style = dict(linestyle='--', linewidth=linewidth, color='k')
  mean_style = dict(linestyle='-', linewidth=linewidth, color='k')
  # zip (including group_means) so the number of boxes matches the original behavior.
  for data, _, pos in zip(group_data, group_means, x_pos):
    ax.boxplot(data, sym='', positions=[pos], widths=bar_width, meanline=True, showmeans=plot_means,
      boxprops=box_style, whiskerprops=box_style, capprops=box_style,
      medianprops=median_style, meanprops=mean_style)
  return ax

def plot_mse(ax, data, means, stds, num_groups, num_per_group, bar_groups, bar_width, COLORS,
             inner_group_names, outer_group_names, title):
  """Grouped box plots of input-to-adversarial distances.

  For each inner-group column ``i_g`` of `bar_groups`, the data sets
  ``data[bar_groups[:, i_g]]`` are drawn as boxes colored ``COLORS[i_g]``, placed at
  ``arange(num_groups) + i_g * bar_width``.

  Parameters
  ----------
  ax : matplotlib Axes
  data, means : sequence
      Indexed by the entries of `bar_groups`.
  stds : sequence
      Accepted for interface compatibility; not used by the box plots.
  num_groups, num_per_group : int
      Number of outer groups (rows of bar_groups) and inner groups (columns).
  bar_groups : np.ndarray
      Integer array of shape (num_groups, num_per_group) indexing into data/means.
  bar_width : float
  COLORS : sequence
      At least two colors; COLORS[i_g] colors inner group i_g and the legend.
  inner_group_names : sequence of str
      Legend labels for the first two inner groups.
  outer_group_names : sequence of str
      x tick labels, one per outer group.
  title : str

  Notes
  -----
  The x tick label rotation is read from the notebook-level ``labelrotation``.

  Returns
  -------
  ax
  """
  linewidth = 1
  for i_g in range(num_per_group):
    column = bar_groups[:, i_g]
    group_data = [data[i] for i in column]
    group_means = [means[i] for i in column]
    x_pos = np.arange(num_groups) + i_g * bar_width
    ax = make_boxplot(ax, group_data, group_means, x_pos, bar_width, linewidth, COLORS[i_g])
  # Proxy artists: boxplot does not create legend entries on its own.
  legend_elements = [Line2D([0], [0], color=COLORS[0], lw=8),
                     Line2D([0], [0], color=COLORS[1], lw=8)]
  legend = ax.legend(legend_elements, inner_group_names, framealpha=1.0)
  legend.get_frame().set_linewidth(0.0)
  # Center each tick between the boxes of its outer group.
  ax.set_xticks([r + (bar_width)/2 for r in range(num_groups)])
  ax.set_xticklabels(outer_group_names)
  ax.set_ylabel("Input to Adversarial MSD")
  ax.set_xlabel('Number of Layers and Neurons')
  ax.tick_params("x", labelrotation=labelrotation)
  ax.spines["top"].set_visible(False)
  ax.spines["right"].set_visible(False)
  ylim = ax.get_ylim()
  ax.set_ylim([0, ylim[1]])
  ax.set_title(title)
  return ax

def plot_adv_robustness(data_list, mean_list, std_list, num_groups, num_per_group, bar_groups, bar_width,
                        colors, inner_group_names, outer_group_names, titles, text_width=200,
                        width_ratio=1.0, dpi=100):
  """Side-by-side MNIST (left) and CIFAR (right) adversarial-distance box plots.

  data_list, mean_list, std_list, outer_group_names and titles are each a
  two-element [mnist, cifar] pair. The left panel's legend is hidden and the right
  panel's y label is cleared, so the two panels share one legend and one label.
  The figure size comes from nc.set_size(text_width, width_ratio, [2, 2]).

  Returns the Figure (after plt.show()).
  """
  mnist_data, cifar_data = data_list
  mnist_means, cifar_means = mean_list
  mnist_stds, cifar_stds = std_list
  mnist_names, cifar_names = outer_group_names
  mnist_title, cifar_title = titles
  n_rows, n_cols = 2, 2
  fig = plt.figure(figsize=nc.set_size(text_width, width_ratio, [n_rows, n_cols]), dpi=dpi)
  grid = plt.GridSpec(n_rows, n_cols, figure=fig, wspace=0.3)
  panels = [
    (mnist_data, mnist_means, mnist_stds, mnist_names, mnist_title),
    (cifar_data, cifar_means, cifar_stds, cifar_names, cifar_title),
  ]
  for col, (p_data, p_means, p_stds, p_names, p_title) in enumerate(panels):
    panel_ax = fig.add_subplot(grid[:, col])
    panel_ax = plot_mse(panel_ax, p_data, p_means, p_stds, num_groups, num_per_group,
      bar_groups, bar_width, colors, inner_group_names, p_names, p_title)
    if col == 0:
      panel_ax.get_legend().set_visible(False)
    else:
      panel_ax.set_ylabel("")
  plt.show()
  return fig

In [None]:
# Pair up the MNIST and CIFAR statistics computed in the config cell above.
data_list = [mnist_data, cifar_data]
mean_list = [mnist_means, cifar_means]
std_list = [mnist_stds, cifar_stds]
outer_group_names = [mnist_outer_group_names, cifar_outer_group_names]
titles = ["MNIST", "Grayscale CIFAR"]

adv_fig = plot_adv_robustness(data_list, mean_list, std_list, num_groups, num_per_group, bar_groups, bar_width,
  COLORS, inner_group_names, outer_group_names, titles, text_width, width_ratio=1.0, dpi=dpi)

# Save as EPS into the vis/ directory of the first two models in all_model_names.
out_list = [project_dir + model_name + "/analysis/0.0/vis/adv_mse_comparison_boxplots" for model_name in all_model_names]
for out_name in out_list[0:2]: #slp_lca_768_latent_75_steps_mnist; mlp_cosyne_mnist
  for ext in [".eps"]:
    save_name = out_name+ext
    adv_fig.savefig(save_name, transparent=False, bbox_inches="tight", pad_inches=0.04, dpi=dpi)

In [None]:
# First example image shown for each dataset in the adversarial-image figures.
# NOTE: a later cell reassigns mnist_start_idx before plotting.
cifar_start_idx = 4
mnist_start_idx = 45#np.random.randint(90)
print("MNIST:", mnist_start_idx, "CIFAR:", cifar_start_idx)

In [None]:
def show_image_with_label(axis, image, clf, vmin=0, vmax=1, cmap="Greys"):
  """Show `image` on `axis` with no spines or ticks, optionally captioned by `clf`.

  Parameters
  ----------
  axis : matplotlib Axes
  image : array-like
      Passed to ``axis.imshow`` with nearest-neighbour interpolation.
  clf : object or None
      If not None, ``str(clf)`` is drawn in a white box near the top-left corner.
  vmin, vmax : float
      Color limits for imshow.
  cmap : str
      Colormap name.

  Returns
  -------
  The AxesImage returned by ``axis.imshow`` (e.g. for attaching a colorbar).
  """
  im = axis.imshow(image, cmap=cmap, interpolation="nearest", vmin=vmin, vmax=vmax)
  for spine in axis.spines.values():
    spine.set_visible(False)
  # tick_params expects booleans. The legacy string "off" is not a documented
  # value and, being a non-empty string, is truthy.
  axis.tick_params(
    axis="both",
    bottom=False,
    top=False,
    left=False,
    right=False)
  axis.set_xticks([])
  axis.set_yticks([])
  if clf is not None:
    # No transform is given, so (.05, .95) is in data (pixel) coordinates,
    # i.e. near the top-left pixel of the image.
    caption_box = dict(facecolor='white', alpha=1)
    axis.text(.05, .95, str(clf), bbox=caption_box,
      verticalalignment='center', color="black")
  return im

def make_grid_subplots(fig, gs, mlp_grp, lca_grp, orig_labels, target_labels, group_names, group_name_loc,
                       orig_y_adj, start_idx=0, num_categories=3, crop_amount=0, hspace=0.5, wspace=-0.4,
                       cmap="Greys_r"):
  """Fill one dataset's panel with unperturbed, adversarial and difference images.

  Row i of `gs` (for i < num_categories) shows image ``start_idx + i``:
    * columns 0-1: the unperturbed image ``mlp_grp[0][0][image_idx]``, captioned
      with ``orig_labels[0][0][image_idx]``;
    * columns 2-3 and 4-5: a 2x2 sub-grid for model index 0 and 1 respectively.
      Sub-grid row 0 holds adversarial images and row 1 difference images.
      Column 0 is the MLP ("w/o LCA") version, column 1 the LCA ("w/ LCA") version.

  Parameters
  ----------
  fig : matplotlib Figure
      Figure the axes are added to.
  gs : GridSpec-like
      Needs at least `num_categories` rows and 6 columns.
  mlp_grp, lca_grp : list
      [orig_images, adv_images, diff_images]; each entry is a per-model list of
      image arrays indexed as ``[model_idx][image_idx, ...]``.
  orig_labels, target_labels : list
      Indexed as ``[family][model_idx][image_idx]`` (e.g. [mlp_labels, lca_labels]).
      Only family 0 (MLP) is read here.
  group_names : list of str
      Sub-grid captions indexed by model index; drawn on the first row only.
  group_name_loc : [x, y]
      Position (data coordinates of the top-left sub-grid image) of the caption.
  orig_y_adj : float
      ``y`` argument for the "Unperturbed" title.
  start_idx, num_categories : int
  crop_amount : int
      Passed to skimage ``crop`` for every image.
  hspace, wspace : float
      Spacing inside each 2x2 sub-grid.
  cmap : str
      Colormap for every image.

  Returns None; axes are added to `fig` as a side effect.
  """
  ax_orig_list = []
  gs_sub0_list = []
  gs_sub1_list = []
  # Lay out all axes/sub-grids first: unperturbed image in cols 0-1, one 2x2 sub-grid per model.
  for i in range(num_categories):
    ax_orig_list.append(fig.add_subplot(gs[i, :2]))
    gs_sub0_list.append(gridspec.GridSpecFromSubplotSpec(2, 2, gs[i, 2:4], hspace=hspace, wspace=wspace))
    gs_sub1_list.append(gridspec.GridSpecFromSubplotSpec(2, 2, gs[i, 4:], hspace=hspace, wspace=wspace))
  for category_idx in range(num_categories):
    image_idx = category_idx + start_idx
    orig_ax = ax_orig_list[category_idx]
    if category_idx == 0:
      orig_ax.set_title("Unperturbed", y=orig_y_adj)
    orig_img = crop(np.squeeze(mlp_grp[0][0][image_idx, ...]), crop_amount)
    orig_im_handle = show_image_with_label(orig_ax, orig_img, orig_labels[0][0][image_idx], cmap=cmap)
    for model_idx, gs_sub in enumerate([gs_sub0_list[category_idx], gs_sub1_list[category_idx]]):
      mlp_adv_img = crop(np.squeeze(mlp_grp[1][model_idx][image_idx, ...]), crop_amount)
      mlp_diff_img = crop(np.squeeze(mlp_grp[2][model_idx][image_idx, ...]), crop_amount)
      lca_adv_img = crop(np.squeeze(lca_grp[1][model_idx][image_idx, ...]), crop_amount)
      lca_diff_img = crop(np.squeeze(lca_grp[2][model_idx][image_idx, ...]), crop_amount)
      # Shared color range for the MLP and LCA difference images of this model.
      diff_vmin = np.min(np.concatenate((mlp_diff_img, lca_diff_img)))
      diff_vmax = np.max(np.concatenate((mlp_diff_img, lca_diff_img)))
      img_list = [[mlp_adv_img, lca_adv_img], [mlp_diff_img, lca_diff_img]]
      for i in range(2): # row [adv, pert]
        for j in range(2): # col [w/o LCA, w/ LCA]
          current_ax = fig.add_subplot(gs_sub[i, j])
          current_target_label = None
          current_image = img_list[i][j]
          if i == 0:
            vmin = 0.0
            vmax = 1.0
            if j == 0: # top left image
              if model_idx == 0:
                current_ax.set_ylabel(r"$s^{*}$")
              # j is 0 here, so this is always the MLP model's target label.
              current_target_label = target_labels[j][model_idx][image_idx]
              if category_idx == 0: # top category only
                x_loc = group_name_loc[0]
                y_loc = group_name_loc[1]
                text_handle = current_ax.text(x_loc, y_loc, group_names[j+model_idx],
                  horizontalalignment='left', verticalalignment='bottom')
          if i == 1:
            # Rounded limits double as the colorbar tick values below.
            vmin = np.round(diff_vmin, 2)
            vmax = np.round(diff_vmax, 2)
            if j == 0 and model_idx == 0:
                current_ax.set_ylabel(r"$e$")
            if j == 0 and category_idx == num_categories-1: # bottom left
              current_ax.set_xlabel("w/o\nLCA")
            elif j == 1 and category_idx == num_categories-1: # bottom right
              current_ax.set_xlabel("w/\nLCA")
          im_handle = show_image_with_label(current_ax, current_image, current_target_label, vmin=vmin, vmax=vmax, cmap=cmap)
          # Colorbar only next to the right-hand (LCA) column.
          if j == 1:
            pf.add_colorbar_to_im(im_handle, aspect=10, ax=current_ax, ticks=[vmin, vmax])

def plot_adv_images(image_groups, labels, mnist_start_idx, cifar_start_idx, text_width=200, width_ratio=1.0, dpi=100):
  """Two-row figure of example adversarial images: MNIST above CIFAR.

  image_groups[k] is an (mlp_grp, lca_grp) pair and labels[k] an
  (orig_labels, target_labels, img_labels) triple, with k = 0 for MNIST and
  k = 1 for CIFAR. One example per dataset is drawn via make_grid_subplots,
  starting at the given start index. CIFAR images use crop_amount=2 and the
  "Greys_r" colormap; MNIST uses no crop and "Greys".

  Returns the Figure (after plt.show()).
  """
  (mnist_mlp, mnist_lca), (cifar_mlp, cifar_lca) = image_groups[0], image_groups[1]
  (mnist_orig, mnist_target, mnist_names), (cifar_orig, cifar_target, cifar_names) = labels[0], labels[1]

  outer_spacing = dict(wspace=0.3, hspace=-0.3)
  panel_spacing = dict(hspace=0.0, wspace=3.3)
  inner_spacing = dict(hspace=-0.7, wspace=0.2)
  orig_y_adj = 1.15
  img_label_loc = [-8.0, -8.0] # [x, y]
  num_categories = 1

  fig = plt.figure(figsize=nc.set_size(text_width, width_ratio, [2, 1]), dpi=dpi)
  gs_base = plt.GridSpec(2, 1, figure=fig, **outer_spacing)

  panels = [
    (mnist_mlp, mnist_lca, mnist_orig, mnist_target, mnist_names, mnist_start_idx, 0, "Greys"),
    (cifar_mlp, cifar_lca, cifar_orig, cifar_target, cifar_names, cifar_start_idx, 2, "Greys_r"),
  ]
  for row, (mlp_grp, lca_grp, orig_lbls, target_lbls, names, start, crop_px, panel_cmap) in enumerate(panels):
    gs_panel = gridspec.GridSpecFromSubplotSpec(num_categories, 6, gs_base[row], **panel_spacing)
    make_grid_subplots(fig, gs_panel, mlp_grp, lca_grp, orig_lbls, target_lbls, names, img_label_loc,
      orig_y_adj, start, num_categories, crop_amount=crop_px, cmap=panel_cmap, **inner_spacing)

  plt.show()
  return fig

In [None]:
# Indices of the example images drawn (these override the values set in an earlier cell).
mnist_start_idx = 46
cifar_start_idx = 4
# [[mlp, lca] image groups] and [[orig, target, caption] labels], MNIST first then CIFAR.
image_groups = [[mnist_mlp_grp, mnist_lca_grp], [cifar_mlp_grp, cifar_lca_grp]]
labels = [[mnist_orig_labels, mnist_target_labels, mnist_img_labels],
  [cifar_orig_labels, cifar_target_labels, cifar_img_labels]]
adv_img_fig = plot_adv_images(image_groups, labels, mnist_start_idx, cifar_start_idx, text_width, width_ratio=1.0, dpi=dpi)

In [None]:
# Save the single-example figure as EPS into the first model's vis/ directory only.
out_list = [project_dir + model_name + "/analysis/0.0/vis/adv_mse_comparison_example_images_single"
  for model_name in all_model_names]
for out_name in [out_list[0]]:
  for ext in [".eps"]:#[".png", ".eps"]:
    save_name = out_name+ext
    adv_img_fig.savefig(save_name, transparent=False, bbox_inches="tight", pad_inches=0.05, dpi=dpi)

# Supplementary Figures/Analysis

### Histogram of the MSE data points

In [None]:
def mean_conf(outputs, labels, index):
  """Average, over rows, of the output mass on the labelled class at one step.

  ``outputs[index] * labels`` is summed along axis 1 and the row sums are averaged.
  With one-hot `labels` (presumably the case — TODO confirm) this is the mean
  confidence assigned to the labelled class.
  """
  masked = outputs[index] * labels
  per_row = np.sum(masked, axis=1)
  return np.mean(per_row)

def find_avg_conf_step(analysis):
  """Mean target-class confidence at every attack step.

  Reads ``analysis["adversarial_outputs"][0]`` (indexed by step) and
  ``analysis["target_labels"]`` and returns a 1-D array holding
  ``mean_conf(outputs, target_labels, step)`` for each step.
  """
  step_outputs = analysis["adversarial_outputs"][0]
  labels = analysis["target_labels"]
  return np.array([mean_conf(step_outputs, labels, step) for step in range(len(step_outputs))])

def plot_average_conf_step(analysis_files, model_names):
  """Plot mean target-class confidence versus attack step for several models.

  Parameters
  ----------
  analysis_files : list of str
      Paths loaded with ``np.load(..., allow_pickle=True)``; the entry "data" holds
      the analysis dict (read with ``.item()``).
  model_names : list of str
      Legend label for each file (zipped with `analysis_files`).

  Side effects: creates a new figure and prints each model's peak mean confidence.
  Returns None.
  """
  fig, ax = plt.subplots()
  for file, name in zip(analysis_files, model_names):
    loaded = np.load(file, allow_pickle=True)
    try:
      analysis = loaded["data"].item()
    finally:
      # For .npz archives np.load returns an open NpzFile; close it promptly.
      close = getattr(loaded, "close", None)
      if close is not None:
        close()
    confs = find_avg_conf_step(analysis)  # computed once, reused for plot and print
    ax.plot(confs, label=name)
    print(np.max(confs))
  ax.legend(loc="lower right", framealpha=1.0)
  ax.set_xlabel("Attack Step")
  ax.set_ylabel("Mean Target-class Confidence")
  ax.set_ylim(0,1.1)

In [None]:
def get_mnist_mse_data(file_lists, metric, stop_conf):
  """Collect get_rects statistics for every group of analysis files.

  Params
  ------
  file_lists: list
      list of list with organization
      inner list: model-type filenames (i.e. lca or mlp)
      outer lists: groups across models (i.e. layer depth)
  metric: str
      the analysis metric to average over
  stop_conf: float
      the desired classifier confidence for stopping the adversarial attack

  Returns
  -------
  (data, means, stds): tuple of lists with one entry per element of file_lists,
      taken from the (data, means, stds) triple returned by get_rects.
  """
  triples = []
  for group_files in file_lists:
    group_data, group_means, group_stds = get_rects(group_files, metric, stop_conf)
    triples.append((group_data, group_means, group_stds))
  data = [triple[0] for triple in triples]
  means = [triple[1] for triple in triples]
  stds = [triple[2] for triple in triples]
  return (data, means, stds)

def multi_model_compare(ax, data, means, stds, colors, names, xtick_labels,
                        xlabel, ylabel, ylim, width, title, fontsize):
  """Grouped bar chart (mean with error bars) with the individual data points overlaid.

  Parameters
  ----------
  ax : matplotlib Axes
  data : list
      data[i] holds model i's points; ``np.array(data[i])`` is scattered against
      x positions tiled to its last dimension (presumably shape
      (num_groups, num_points) — TODO confirm).
  means, stds : list
      Per-model bar heights and ``yerr`` values for ``ax.bar``.
  colors : list
      Per-model ``color`` for ``ax.bar`` (a single color or one per bar).
  names : list
      Per-model legend labels.
  xtick_labels : list of str
      One label per x group (there are ``len(data[0])`` groups).
  xlabel, ylabel, title : str
  ylim : float
      Upper y limit; the lower limit is 0.
  width : float
      Bar width. Consecutive models are offset by ``width + 0.01``.
  fontsize : float

  Returns
  -------
  data, unchanged.
  """
  # organize the data
  # plt.get_cmap works across matplotlib versions; cm.get_cmap was removed in 3.9.
  cmap_gray = plt.get_cmap("gray")
  num_groups = len(data[0]) # number of depths
  num_models = len(data) # number of models being compared
  # create the bar chart
  group_x = np.arange(num_groups)  # the x locations for the depths
  rects = []
  for i in range(num_models):
    # the bars
    x = group_x + i * (width+.01)
    rect = ax.bar(x, means[i], color=colors[i], yerr=stds[i], width=width, alpha=1.0)
    rects.append(rect)
    # the data points, one column of x positions per point, shifted right of the bar center
    bar_data = np.array(data[i])
    x_tiled = np.tile(x+(width/4), (bar_data.shape[-1],1)).T
    ax.scatter(x_tiled, bar_data, color='black', alpha=1.0, s=1, zorder=2)
  # Center each tick under its group of bars.
  ax.set_xticks(group_x + ((num_models-1)*(width+.01))/2)
  ax.set_xticklabels(xtick_labels, fontsize=fontsize)
  ax.set_ylabel(ylabel, fontsize=fontsize)
  ax.set_xlabel(xlabel, fontsize=fontsize)
  ax.spines['right'].set_visible(False)
  ax.spines['top'].set_visible(False)
  ax.xaxis.set_ticks_position('bottom')
  ax.yaxis.set_ticks_position('left')
  ax.yaxis.grid(which="major", color=cmap_gray(.8), linestyle='--', linewidth=1)
  ax.tick_params("both", labelsize=fontsize)
  ax.set_axisbelow(True)
  ax.set_ylim([0, ylim])
  ax.set_title(title, fontsize=fontsize)
  ax.legend([r[0] for r in rects], names, fontsize=fontsize, loc='upper right', framealpha=1.0)
  return data

def adv_mse_comparison_plot(ax, file_lists, metric, stop_conf,
                            colors, names, xtick_labels, xlabel, ylabel, ylim, width, title, fontsize):
  """
  Bar chart that compares a designated metric for each model in file_lists.

  Loads the statistics with get_mnist_mse_data and draws them with
  multi_model_compare.

  Params
  ------
  ax: matplotlib Axes
      axes to draw on
  file_lists: list
      list of list with organization
      inner list: model-type filenames (i.e. lca or mlp)
      outer lists: groups across models (i.e. layer depth)
  metric: str
      the analysis metric to average over
  stop_conf: float
      the desired classifier confidence for stopping the adversarial attack
  colors: list
      per-model bar colors, forwarded to multi_model_compare
  names: list
      names of the models (i.e. lca or mlp), used as legend labels
  xtick_labels: list
      list of the group labels (i.e. layer depths)
  xlabel, ylabel, title: str
      axis labels and title
  ylim: float
      upper y limit (lower limit is 0)
  width: float
      bar width
  fontsize: float
      font size for all text

  Returns
  -------
  The data returned by multi_model_compare (the loaded per-group data).
  """
  loaded_data, loaded_means, loaded_stds = get_mnist_mse_data(file_lists, metric, stop_conf)
  return multi_model_compare(ax=ax, data=loaded_data, means=loaded_means, stds=loaded_stds,
                             colors=colors, names=names, xtick_labels=xtick_labels,
                             xlabel=xlabel, ylabel=ylabel, ylim=ylim, width=width,
                             title=title, fontsize=fontsize)


In [None]:
# Root directory holding one sub-directory per model (same value as project_dir).
path = os.path.expanduser("~")+"/Work/Projects/"

# Pick which attack's analysis files to plot.
file_path = k_file_path # kurakin
# file_path = c_file_path # carlini
lca_files = [path + model_name + file_path for model_name in [lca_768_2layer]]
mlp_files = [path + model_name + file_path for model_name in [mlp_768_2layer]]
# Order must match `names` below: MLP (w/o LCA) first, then LCA.
files = mlp_files + lca_files

names = ['w/o LCA', 'w/ LCA'] 
plot_average_conf_step(files, names)

In [None]:
# Model directory names for the MLP / LISTA / LCA comparison.
mlp = mnist_mlp_768_2layer#"mlp_cosyne_mnist"
lca = mnist_lca_768_2layer#"slp_lca_768_latent_75_steps_mnist"
lista = 'slp_lista_768_5_layers_mnist'
# Kurakin-attack analysis files and image files for each model (note the lca, lista, mlp order).
lca_file, lista_file, mlp_file = (path + model_name + k_file_path for model_name in [lca, lista, mlp])
lca_img_file, lista_img_file, mlp_img_file = (path + model_name + k_img_path for model_name in [lca, lista, mlp])

In [None]:
# A single "model" entry whose three bars are MLP / LISTA / LCA, one color each.
colors = [[color_vals['md_blue'], color_vals['md_green'], color_vals['md_red']]]
xtick_labels = ['MLP', 'w/ LISTA', 'w/ LCA']
file_lists = [[mlp_file, lista_file, lca_file]] #, lca_1568_files]

metric = "input_adv_mses"
names = [None]  # one legend entry; the legend is hidden below

# Stop the attack statistics at 95% target-class confidence.
stop_conf=.95

fig, ax = plt.subplots()
data_points = adv_mse_comparison_plot(ax, file_lists, metric, stop_conf, colors, names,
                                     xtick_labels, "", "Mean Squared Distance", .08, width=.5,
                                     title="Adversarial MNIST at\n95% Confidence", fontsize=fontsize)
ax.get_legend().set_visible(False)
plt.show()

# Save PNG and EPS copies into each of the three models' vis/ directories.
out_list = [path + model_name + "/analysis/0.0/vis/mlp_lista_lca_adv_comparison"
  for model_name in [mnist_mlp_768_2layer, lista, mnist_lca_768_2layer]]
for out_name in out_list:
  for ext in [".png", ".eps"]:
    save_name = out_name+ext
    fig.savefig(save_name, transparent=False, bbox_inches="tight", pad_inches=0.01, dpi=dpi)

In [None]:
# Per-image adversarial MSEs for the LCA 1568 models (first group of the returned data);
# lca_data[0] and lca_data[1] are plotted as the 2-layer and 3-layer models.
# NOTE(review): `lca_1568_files` is not defined in this part of the notebook; confirm it exists.
lca_data = get_mnist_mse_data([lca_1568_files], metric, stop_conf)[0][0]
n_bins = 25

plt.figure()
# Normalized histograms: bars sit at the left bin edges with a fixed width of .003
# (not the actual bin width), and each model uses its own bins.
hist1, bins = np.histogram(lca_data[0], n_bins)
plt.bar(bins[:n_bins],hist1/len(lca_data[0]), width = .003, alpha=.7, label="lca_2layer")

hist2, bins = np.histogram(lca_data[1], n_bins)
plt.bar(bins[:n_bins],hist2/len(lca_data[1]), width = .003, alpha=.7, label="lca_3layer")

plt.xlabel("Adversarial MSE")
plt.ylabel("Frequency of Images")
plt.legend(framealpha=1.0)
plt.show()

### Average target-class confidence per Kurakin attack step

### Average adversarial MSE per Kurakin attack step

In [None]:
def find_mse(analysis):
  """Per-step mean and standard deviation of the input-to-adversarial MSEs.

  ``analysis["input_adv_mses"][0]`` is indexed by attack step, and each entry is a
  collection of per-image MSE values. Returns ``(means, stds)`` as 1-D numpy arrays
  with one value per step.
  """
  per_step = analysis["input_adv_mses"][0]
  steps = range(len(per_step))
  step_means = np.array([np.mean(per_step[step]) for step in steps])
  step_stds = np.array([np.std(per_step[step]) for step in steps])
  return step_means, step_stds

def plot_average_mse_step(analysis_files, colors, axis_titles, model_names, figsize, dpi, fontsize):
  """Two-panel plot of mean adversarial MSE (with a +/- 1 std hatched band) per attack step.

  Parameters
  ----------
  analysis_files : list of list of str
      One inner list per panel (at most 2); each path is loaded with
      ``np.load(..., allow_pickle=True)["data"].item()``.
  colors : list
      ``colors[file_idx] = (line_color, hatch_edge_color)`` for the file_idx-th file of a panel.
  axis_titles : list of str
      Title of each panel.
  model_names : list of str
      Legend label for the file_idx-th file of each panel.
  figsize, dpi :
      Forwarded to ``plt.subplots``.
  fontsize : float

  Only two hatch patterns are defined, so each panel supports at most two files.

  Returns
  -------
  (fig, axes)
  """
  #hatches = ["x", "+"]
  #hatches = ["o", "O"]
  hatches = ["/", "\\"]
  #hatches = ["-", "|"]
  fig, axes = plt.subplots(1, 2, figsize=figsize, dpi=dpi)
  for ax_idx, (sub_analysis_files, title) in enumerate(zip(analysis_files, axis_titles)):
    ax = axes[ax_idx]
    for file_idx, (file, name) in enumerate(zip(sub_analysis_files, model_names)):
      loaded = np.load(file, allow_pickle=True)
      try:
        analysis = loaded["data"].item()
      finally:
        # For .npz archives np.load returns an open NpzFile; close it promptly.
        close = getattr(loaded, "close", None)
        if close is not None:
          close()
      mean_mse, std_mse = find_mse(analysis)
      steps = range(len(mean_mse))
      ax.plot(steps, mean_mse, label=name, lw=3, color=colors[file_idx][0], zorder=1)
      ax.fill_between(steps, mean_mse + std_mse , mean_mse - std_mse,
        edgecolor=colors[file_idx][1], alpha=1.0, zorder=0, facecolor="none", hatch=hatches[file_idx],
        rasterized=False)
    # Panel decorations do not depend on the file loop; set them once per panel.
    ax.set_title(title, fontsize=fontsize)
    ax.set_xlabel("Attack Step", fontsize=fontsize)
    ax.tick_params("both", labelsize=fontsize)
  axes[0].legend(loc="upper left", fontsize=fontsize, framealpha=1.0)
  axes[0].set_ylabel("Adversarial Mean Squared Distance", fontsize=fontsize)
  return fig, axes

In [None]:
# (line color, hatch color) per model: MLP in blue, LCA in red.
colors = [[color_vals['md_blue'], color_vals['lt_blue']], [color_vals['md_red'], color_vals['lt_red']]]
axis_titles = ["Kurakin", "Carlini"]
model_names = ['w/o LCA', 'w/ LCA']

# Kurakin (left panel) and Carlini (right panel) analysis files, MLP first to match model_names.
k_lca_files = [path + model_name + k_file_path for model_name in [mnist_lca_768_2layer]]
k_mlp_files = [path + model_name + k_file_path for model_name in [mnist_mlp_768_2layer]]
k_files = k_mlp_files + k_lca_files
c_lca_files = [path + model_name + c_file_path for model_name in [mnist_lca_768_2layer]]
c_mlp_files = [path + model_name + c_file_path for model_name in [mnist_mlp_768_2layer]]
c_files = c_mlp_files + c_lca_files

# NOTE(review): `figsize` is not defined in this part of the notebook; confirm it is set earlier.
fig, ax = plot_average_mse_step([k_files, c_files], colors, axis_titles, model_names, [figsize[0], figsize[1]/2],
  dpi, fontsize)

# NOTE(review): saves under lca_768_2layer / mlp_768_2layer while the data came from
# mnist_lca_768_2layer / mnist_mlp_768_2layer; confirm these name the same models.
out_list = [path + model_name + "/analysis/0.0/vis/kurakin_carlini_mse_vs_iteration"
  for model_name in [lca_768_2layer, mlp_768_2layer]]
for out_name in out_list:
  for ext in [".png", ".eps"]:
    save_name = out_name+ext
    fig.savefig(save_name, transparent=False, bbox_inches="tight", pad_inches=0.01, dpi=dpi)

## Extra: Differentiable Loss Surface Adversarial Comparison

In [None]:
class params(object):
  """Minimal analysis-parameter container used by get_label_est.

  Only `model_name` varies per instance; the remaining fields are fixed defaults
  (version "0.0", projects under ~/Work/Projects/, save_info "test",
  no log overwrite, no neuron visualization).
  """
  def __init__(self, model_name):
    home_dir = os.path.expanduser("~")
    self.model_name = model_name
    self.version = '0.0'
    self.projects_dir = home_dir + "/Work/Projects/"
    self.save_info = 'test'
    self.overwrite_analysis_log = False
    self.do_neuron_visualization = False
        
def get_label_est(model_name, input_images):
  """Run a saved model forward on `input_images` and return its label estimates.

  Steps: build ``params(model_name)``; read the model log
  ``<projects_dir><model_name>/logfiles/<model_name>_v<version>.log`` to get the
  model type; set up the matching analyzer; restore the checkpoint at
  ``analyzer.analysis_params.cp_loc``; evaluate ``analyzer.model.label_est``.

  Parameters
  ----------
  model_name : str
      Model directory name under ``params.projects_dir``.
  input_images : array
      Passed to ``analyzer.model.get_feed_dict(..., is_test=True)``.

  Returns
  -------
  The evaluated ``label_est`` value. NOTE(review): the calling cells treat it as
  per-image softmax outputs; confirm.
  """
  # Get params, set dirs
  analysis_params = params(model_name) # construct object
  analysis_params.model_dir = analysis_params.projects_dir+analysis_params.model_name

  # Load arguments from the last parameter set recorded in the model's log file
  model_log_file = (analysis_params.model_dir+"/logfiles/"+analysis_params.model_name
    +"_v"+analysis_params.version+".log")
  model_logger = Logger(model_log_file, overwrite=False)
  model_log_text = model_logger.load_file()
  model_params = model_logger.read_params(model_log_text)[-1]
  analysis_params.model_type = model_params.model_type

  # Initialize & setup analyzer
  analyzer = ap.get_analyzer(analysis_params.model_type)
  analyzer.setup(analysis_params)
  analysis_params.data_type = analyzer.model_params.data_type
  analyzer.setup_model(analyzer.model_params)

  # run forward pass
  with tf.Session(graph=analyzer.model.graph) as sess:
    feed_dict = analyzer.model.get_feed_dict(input_images, is_test=True)
    sess.run(analyzer.model.init_op, feed_dict)
    analyzer.model.load_full_model(sess, analyzer.analysis_params.cp_loc)
    label_est = sess.run(analyzer.model.label_est, feed_dict)

  return label_est

### Compare adversarial MSE for a more differentiable loss surface: LCA vs. LISTA

### Check whether LISTA adversarial examples transfer to LCA

In [None]:
lista_stop_conf = .95
# get lista adv images (LISTA Kurakin attack, stopped at lista_stop_conf confidence)
# NOTE(review): get_results is defined outside this part of the notebook.
input_images, input_clf, adv_images, adv_clf = get_results([lista_file], [lista_img_file], lista_stop_conf)
# pass the same LISTA adversarial images through each model
lca_softmax_labels = get_label_est(lca, adv_images[0])
lista_softmax_labels = get_label_est(lista, adv_images[0])
mlp_softmax_labels = get_label_est(mlp, adv_images[0])

In [None]:
# organize the data
lca_confs_lista_adv = [out[cls] for out, cls in zip(lca_softmax_labels, adv_clf[0])]
lca_confs_input = [out[cls] for out, cls in zip(lca_softmax_labels, input_clf[0])]

mlp_confs_lista_adv = [out[cls] for out, cls in zip(mlp_softmax_labels, adv_clf[0])]
mlp_confs_input = [out[cls] for out, cls in zip(mlp_softmax_labels, input_clf[0])]

lista_confs_lista_adv = [out[cls] for out, cls in zip(lista_softmax_labels, adv_clf[0])]
lista_confs_input = [out[cls] for out, cls in zip(lista_softmax_labels, input_clf[0])]

filter_indices = np.where(np.array(lista_confs_lista_adv) > .8)[0]

data = [[lista_confs_input, lista_confs_lista_adv],
        [mlp_confs_input, mlp_confs_lista_adv],
        [lca_confs_input, lca_confs_lista_adv]]

data = [[np.array(confs)[filter_indices] for confs in model_confs] for model_confs in data]
means = [[np.mean(confs) for confs in model_confs] for model_confs in data]
stds = [np.array([[0,0],[np.std(confs) for confs in model_confs]]) for model_confs in data]

In [None]:
# One color / legend name per model, in the same order as `data`.
colors = [color_vals['md_green'], color_vals['md_blue'], color_vals['md_red']]
xtick_labels = ["Original Label", "Adv Target Label"]
xlabel = None  # alternative: "Class Position"
ylabel = "Softmax Confidence"
names = ['w/ LISTA', 'MLP', 'w/ LCA']

fig, ax = plt.subplots(1,1)
multi_model_compare(ax, data, means, stds, colors, names, xtick_labels, xlabel, ylabel, 1, width=.25,
                   title="Transferability of LISTA\nAdversarial Images", fontsize=fontsize);
legend = ax.get_legend()
# Move the legend outside the axes (figure coordinates).
legend.set_bbox_to_anchor([1.28,0.94,0,0], transform=fig.transFigure)
plt.show()

# Build a fresh output list. Appending to the `out_list` left over from an earlier
# cell would also save this figure under that cell's file names, overwriting them.
out_list = [path + lista + "/analysis/0.0/vis/lista_adv_transferability"]
for out_name in out_list:
  for ext in [".png", ".eps"]:
    save_name = out_name+ext
    fig.savefig(save_name, transparent=False, bbox_inches="tight", pad_inches=0.01, dpi=dpi)

### Additional Adversarial Image Examples

In [None]:
def make_grid_subplots_with_fontsize(fig, gs, mlp_grp, lca_grp, orig_labels, target_labels, group_names, group_name_loc,
                       orig_y_adj, start_idx=0, num_categories=3, crop_ammount=0, hspace=0.5, wspace=-0.4,
                       cmap="Greys_r", fontsize=12):
  """Variant of make_grid_subplots with explicit font sizes and $s^*_T$ / $s - s^*_T$ row labels.

  Row i of `gs` (for i < num_categories) shows image ``start_idx + i``:
    * columns 0-1: the unperturbed image ``mlp_grp[0][0][image_idx]``, captioned
      with ``orig_labels[0][0][image_idx]``;
    * columns 2-3 and 4-5: a 2x2 sub-grid for model index 0 and 1 respectively.
      Sub-grid row 0 holds adversarial images and row 1 difference images.
      Column 0 is the MLP ("w/o LCA") version, column 1 the LCA ("w/ LCA") version.

  Parameters
  ----------
  fig : matplotlib Figure
  gs : GridSpec-like
      Needs at least `num_categories` rows and 6 columns.
  mlp_grp, lca_grp : list
      [orig_images, adv_images, diff_images], each a per-model list of image
      arrays indexed as ``[model_idx][image_idx, ...]``.
  orig_labels, target_labels : list
      Indexed as ``[family][model_idx][image_idx]``; only family 0 (MLP) is read.
  group_names : list of str
      Sub-grid captions indexed by model index; drawn on the first row only.
  group_name_loc : [x, y]
      Caption position in data coordinates of the top-left sub-grid image.
  orig_y_adj : float
      ``y`` argument for the "Unperturbed" title.
  start_idx, num_categories : int
  crop_ammount : int
      Passed to skimage ``crop`` for every image (parameter name spelled as-is for
      caller compatibility).
  hspace, wspace : float
      Spacing inside each 2x2 sub-grid.
  cmap : str
      Colormap for every image.
  fontsize : float
      Font size for titles, axis labels and captions. Colorbar tick labels use half of it.

  Returns None; axes are added to `fig` as a side effect.
  """
  ax_orig_list = []
  gs_sub0_list = []
  gs_sub1_list = []
  # Lay out all axes/sub-grids first: unperturbed image in cols 0-1, one 2x2 sub-grid per model.
  for i in range(num_categories):
    ax_orig_list.append(fig.add_subplot(gs[i, :2]))
    gs_sub0_list.append(gridspec.GridSpecFromSubplotSpec(2, 2, gs[i, 2:4], hspace=hspace, wspace=wspace))
    gs_sub1_list.append(gridspec.GridSpecFromSubplotSpec(2, 2, gs[i, 4:], hspace=hspace, wspace=wspace))
  
  for category_idx in range(num_categories):
    image_idx = category_idx + start_idx
    orig_ax = ax_orig_list[category_idx]
    if category_idx == 0:
      orig_ax.set_title("Unperturbed", y=orig_y_adj, fontsize=fontsize)
    orig_img = crop(np.squeeze(mlp_grp[0][0][image_idx, ...]), crop_ammount)
    orig_im_handle = show_image_with_label(orig_ax, orig_img, orig_labels[0][0][image_idx], cmap=cmap)
    for model_idx, gs_sub in enumerate([gs_sub0_list[category_idx], gs_sub1_list[category_idx]]):
      mlp_adv_img = crop(np.squeeze(mlp_grp[1][model_idx][image_idx, ...]), crop_ammount)
      mlp_diff_img = crop(np.squeeze(mlp_grp[2][model_idx][image_idx, ...]), crop_ammount)
      lca_adv_img = crop(np.squeeze(lca_grp[1][model_idx][image_idx, ...]), crop_ammount)
      lca_diff_img = crop(np.squeeze(lca_grp[2][model_idx][image_idx, ...]), crop_ammount)
      # Shared color range for the MLP and LCA difference images of this model.
      diff_vmin = np.min(np.concatenate((mlp_diff_img, lca_diff_img)))
      diff_vmax = np.max(np.concatenate((mlp_diff_img, lca_diff_img)))
      img_list = [[mlp_adv_img, lca_adv_img], [mlp_diff_img, lca_diff_img]]
      for i in range(2): # row [adv, pert]
        for j in range(2): # col [w/o LCA, w/ LCA]
          current_ax = fig.add_subplot(gs_sub[i, j])
          current_target_label = None
          current_image = img_list[i][j]
          if i == 0:
            vmin = 0.0
            vmax = 1.0
            if j == 0: # top left image
              if model_idx == 0:
                current_ax.set_ylabel(r"$s^{*}_{T}$", fontsize=fontsize)
              # j is 0 here, so this is always the MLP model's target label.
              current_target_label = target_labels[j][model_idx][image_idx]
              if category_idx == 0: # top category only
                x_loc = group_name_loc[0]
                y_loc = group_name_loc[1]
                text_handle = current_ax.text(x_loc, y_loc, group_names[j+model_idx], fontsize=fontsize,
                  horizontalalignment='left', verticalalignment='bottom')
          else: # i == 1
            # Rounded limits double as the colorbar tick values below.
            vmin = np.round(diff_vmin, 2)
            vmax = np.round(diff_vmax, 2)
            if j == 0 and model_idx == 0:
                current_ax.set_ylabel(r"$s-s^{*}_{T}$", fontsize=fontsize)
            if j == 0 and category_idx == num_categories-1: # bottom left
              current_ax.set_xlabel("w/o\nLCA", fontsize=fontsize)
            elif j == 1 and category_idx == num_categories-1: # bottom right
              current_ax.set_xlabel("w/\nLCA", fontsize=fontsize)
          im_handle = show_image_with_label(current_ax, current_image, current_target_label, vmin=vmin, vmax=vmax, cmap=cmap)
          # Colorbar only next to the right-hand (LCA) column.
          if j == 1:
            pf.add_colorbar_to_im(im_handle, aspect=10, ax=current_ax, ticks=[vmin, vmax], labelsize=fontsize/2)

def plot_adv_images_with_figsize(image_groups, labels, mnist_start_idx, cifar_start_idx, figsize, fontsize, dpi=100):
  """Full-page grid of example adversarial images: MNIST panel on top, CIFAR below.

  Parameters
  ----------
  image_groups : list
      [[mnist_mlp_grp, mnist_lca_grp], [cifar_mlp_grp, cifar_lca_grp]].
  labels : list
      [[orig_labels, target_labels, img_labels] for MNIST, the same for CIFAR].
  mnist_start_idx, cifar_start_idx : int
      Index of the first example image drawn for each dataset (three per dataset).
  figsize : (width, height)
      The figure is created at half of `width` and the full `height`.
  fontsize : float
      Forwarded to make_grid_subplots_with_fontsize.
  dpi : int

  Returns
  -------
  matplotlib Figure
      Returned so callers can save it.
  """
  mnist_mlp_grp, mnist_lca_grp = image_groups[0]
  cifar_mlp_grp, cifar_lca_grp = image_groups[1]
  mnist_orig_labels, mnist_target_labels, mnist_img_labels = labels[0]
  cifar_orig_labels, cifar_target_labels, cifar_img_labels = labels[1]

  hspace = 0.3
  wspace = 1.8
  sub_hspace = 0.2
  sub_wspace = 0.2
  orig_y_adj = 1.10
  img_label_loc = [-8.0, -8.0] # [x, y]
  fig2 = plt.figure(figsize=[figsize[0]/2, figsize[1]], dpi=dpi)
  gs0 = plt.GridSpec(2, 1, figure=fig2, hspace=0.3)

  num_categories = 3

  gs_mnist = gridspec.GridSpecFromSubplotSpec(num_categories, 6, gs0[0], hspace=hspace, wspace=wspace)
  make_grid_subplots_with_fontsize(fig2, gs_mnist, mnist_mlp_grp, mnist_lca_grp, mnist_orig_labels,
    mnist_target_labels, mnist_img_labels, img_label_loc, orig_y_adj, mnist_start_idx, num_categories,
    hspace=sub_hspace, wspace=sub_wspace, cmap="Greys", fontsize=fontsize)

  gs_cifar = gridspec.GridSpecFromSubplotSpec(num_categories, 6, gs0[1], hspace=hspace, wspace=wspace)
  # Use the caller's cifar_start_idx (previously a hard-coded 0 ignored it).
  make_grid_subplots_with_fontsize(fig2, gs_cifar, cifar_mlp_grp, cifar_lca_grp, cifar_orig_labels,
    cifar_target_labels, cifar_img_labels, img_label_loc, orig_y_adj, cifar_start_idx, num_categories,
    hspace=sub_hspace, wspace=sub_wspace, cmap="Greys_r", fontsize=fontsize)

  plt.show()
  return fig2

In [None]:
# Full-page version (three examples per dataset) of the adversarial-image figure.
# NOTE(review): the save cell below needs plot_adv_images_with_figsize to return its Figure.
full_adv_img_fig = plot_adv_images_with_figsize(image_groups, labels, mnist_start_idx=44, cifar_start_idx=0,
  figsize=(16, 16), fontsize=20, dpi=dpi)

In [None]:
# Save the full-page figure as EPS into the first model's vis/ directory only.
# NOTE(review): requires full_adv_img_fig to be a Figure (not None); verify the plotting call returns one.
out_list = [project_dir + model_name + "/analysis/0.0/vis/adv_mse_comparison_example_images_fullpage"
  for model_name in all_model_names]
for out_name in [out_list[0]]:
  for ext in [".eps"]:#[".png", ".eps"]:
    save_name = out_name+ext
    full_adv_img_fig.savefig(save_name, transparent=False, bbox_inches="tight", pad_inches=0.05, dpi=dpi)