# Journal of Vision Paper Figures

### Imports

In [None]:
import os
import sys
# Repository root: three directory levels above the notebook's working directory.
# Presumably added so the project packages imported in the next cell resolve.
_cwd = os.getcwd()
root_path = os.path.dirname(os.path.dirname(os.path.dirname(_cwd)))
# Appending places the root last on the module search path.
if root_path not in sys.path:
    sys.path.append(root_path)
#%env CUDA_VISIBLE_DEVICES=0
%matplotlib inline

In [None]:
import re

from functools import reduce as reduce
import numpy as np
from skimage.io import imread
from skimage.util import crop
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.ticker as plticker
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import matplotlib.ticker as ticker
from matplotlib.patches import FancyArrowPatch
from matplotlib.lines import Line2D
from mpl_toolkits.mplot3d import proj3d
import matplotlib.font_manager
import matplotlib.cm as cm
import pickle
import pandas as pd
import tensorflow as tf
import proplot as plot

import response_contour_analysis.utils.histogram_analysis as ha
import response_contour_analysis.utils.dataset_generation as dg
import response_contour_analysis.utils.plotting as resp_pf

from DeepSparseCoding.tf1x.utils.logger import Logger
from DeepSparseCoding.tf1x.data.dataset import Dataset
import DeepSparseCoding.tf1x.data.data_selector as ds
import DeepSparseCoding.tf1x.analysis.analysis_picker as ap
import DeepSparseCoding.tf1x.utils.data_processing as dp
import DeepSparseCoding.tf1x.utils.plot_functions as pf
import DeepSparseCoding.tf1x.utils.jov_funcs as nc

# Parameters

In [None]:
"""
textwidth in pt: 540.60236pt
textwidth in cm: 18.9973cm
textwidth in in: 7.48178in
"""
# Page width of the journal template (LaTeX \textwidth), in points.
# Values recorded earlier for a narrower layout: 416.83269pt (about 14.65cm; 14.705cm also noted).
text_width = 540.60236
# The same width in centimetres.
text_width_cm = 18.9973
# Base font size (pt) for all figure text.
fontsize = 10
# Resolution passed to every savefig call.
dpi = 300
# Export formats; '.eps' and '.png' can be appended to also write those.
file_extensions = ['.pdf']

In [None]:
# LaTeX-rendered Computer Modern serif text. Labels, titles, legends and body
# text use the base size; tick labels are 2pt smaller, figure titles 2pt larger.
font_settings = {
    "text.usetex": True,
    "font.family": 'serif',
    "font.serif": 'Computer Modern Roman',
    "figure.titlesize": fontsize + 2,
    **dict.fromkeys(("axes.labelsize", "axes.titlesize", "font.size", "legend.fontsize"), fontsize),
    **dict.fromkeys(("xtick.labelsize", "ytick.labelsize"), fontsize - 2),
}
mpl.rcParams.update(font_settings)

In [None]:
class _AnalysisParams(object):
    """Settings read by ``load_analysis`` and ``analyzer.setup`` for one trained model.

    Each ``*_params`` class below is a preset of these values. Instances carry,
    in this order: model_type, model_name, display_name, version, save_info and
    overwrite_analysis_log (always False).
    """

    def __init__(self, model_type, model_name, display_name, version, save_info):
        self.model_type = model_type  # analyzer family, passed to ap.get_analyzer
        self.model_name = model_name  # model directory name under ~/Work/Projects/
        self.display_name = display_name  # human-readable name (presumably for figure labels)
        self.version = version  # model version string
        self.save_info = save_info  # suffix selecting which saved analysis files to load
        self.overwrite_analysis_log = False


class lca_512_vh_params(_AnalysisParams):
    def __init__(self):
        super().__init__(model_type="lca", model_name="lca_512_vh", display_name="Sparse Coding 512",
            version="0.0", save_info="analysis_train_kurakin_targeted")


class lca_768_vh_params(_AnalysisParams):
    def __init__(self):
        super().__init__(model_type="lca", model_name="lca_768_vh", display_name="Sparse Coding",
            version="0.0", save_info="analysis_train_kurakin_targeted")


class lca_1024_vh_params(_AnalysisParams):
    def __init__(self):
        super().__init__(model_type="lca", model_name="lca_1024_vh", display_name="Sparse Coding 1024",
            version="0.0", save_info="analysis_train_kurakin_targeted")


class lca_2560_vh_params(_AnalysisParams):
    def __init__(self):
        super().__init__(model_type="lca", model_name="lca_2560_vh", display_name="Sparse Coding",
            version="0.0", save_info="analysis_train_kurakin_targeted")


class sae_768_vh_params(_AnalysisParams):
    def __init__(self):
        super().__init__(model_type="sae", model_name="sae_768_vh", display_name="Sparse\nAutoencoder",
            version="1.0", save_info="analysis_train_kurakin_targeted")


class rica_768_vh_params(_AnalysisParams):
    def __init__(self):
        super().__init__(model_type="rica", model_name="rica_768_vh", display_name="Linear\nAutoencoder",
            version="0.0", save_info="analysis_train_kurakin_targeted")


class ae_768_vh_params(_AnalysisParams):
    def __init__(self):
        super().__init__(model_type="ae", model_name="ae_768_vh", display_name="ReLU\nAutoencoder",
            version="1.0", save_info="analysis_train_kurakin_targeted")


class lca_768_mnist_params(_AnalysisParams):
    def __init__(self):
        super().__init__(model_type="lca", model_name="lca_768_mnist", display_name="Sparse Coding 768",
            version="0.0", save_info="analysis_train_kurakin_targeted")


class lca_1536_mnist_params(_AnalysisParams):
    def __init__(self):
        super().__init__(model_type="lca", model_name="lca_1536_mnist", display_name="Sparse Coding 1536",
            version="0.0", save_info="analysis_test_carlini_targeted")


class ae_768_mnist_params(_AnalysisParams):
    def __init__(self):
        super().__init__(model_type="ae", model_name="ae_768_mnist", display_name="Leaky ReLU",
            version="0.0", save_info="analysis_test_carlini_targeted")


class sae_768_mnist_params(_AnalysisParams):
    def __init__(self):
        super().__init__(model_type="sae", model_name="sae_768_mnist", display_name="Sparse Autoencoder",
            version="0.0", save_info="analysis_test_carlini_targeted")


class rica_768_mnist_params(_AnalysisParams):
    def __init__(self):
        super().__init__(model_type="rica", model_name="rica_768_mnist", display_name="Linear Autoencoder",
            version="0.0", save_info="analysis_train_kurakin_targeted")


class ae_deep_mnist_params(_AnalysisParams):
    def __init__(self):
        super().__init__(model_type="ae", model_name="ae_deep_mnist", display_name="Leaky ReLU",
            version="0.0", save_info="analysis_test_carlini_targeted")

In [None]:
# Passed to the iso-contour plotting helpers (presumably the number of contour levels).
num_levels = 10

# Named palette. Within each hue the shades get darker from lt_ to md_ to dk_.
# md_green/dk_green were previously swapped: #196F3D is the darker green.
color_vals = {
    "blk": "#000000",
    "lt_green": "#A9DFBF",
    "md_green": "#27AE60",
    "dk_green": "#196F3D",
    "lt_blue": "#AED6F1",
    "md_blue": "#3498DB",
    "dk_blue": "#21618C",
    "lt_red": "#F5B7B1",
    "md_red": "#E74C3C",
    "dk_red": "#943126",
}

In [None]:
def load_analysis(params):
    """Build an analyzer for ``params`` and load its saved analysis.

    Sets ``params.model_dir`` to ``~/Work/Projects/<model_name>`` as a side
    effect, then runs the analyzer and model setup and loads the analysis
    selected by ``params.save_info``.

    Args:
        params: one of the ``*_params`` presets (model_type, model_name, save_info, ...).
    Returns:
        The loaded analyzer, with ``model_name`` copied from ``params``.
    """
    home_dir = os.path.expanduser("~")
    params.model_dir = (home_dir + "/Work/Projects/" + params.model_name)
    loaded = ap.get_analyzer(params.model_type)
    loaded.setup(params)
    loaded.model.setup(loaded.model_params)
    loaded.load_analysis(save_info=params.save_info)
    loaded.model_name = params.model_name
    return loaded

In [None]:
def add_analyzer_keys(analyzer):
    """Add current key names to analysis dicts saved with legacy key names.

    Older analysis runs stored some entries under keys that were later renamed.
    For every analysis dict attached to ``analyzer`` that still has a legacy
    key, the value is copied under its current name. The legacy key is kept,
    and each dict is updated in place.

    Args:
        analyzer: object whose ``*_params`` / ``*_contour_dataset`` attributes,
            when present, are dicts.
    Returns:
        The same ``analyzer`` object.
    """
    # (attribute names to inspect, {legacy key: current key})
    renames = (
        (('iso_params', 'attn_params'),
         {'num_comparisons': 'num_comparison_vects'}),
        # 'rand_comp_contour_dataset' is kept from the original list for compatibility.
        (('comp_contour_dataset', 'rand_contour_dataset',
          'iso_comp_contour_dataset', 'iso_rand_contour_dataset',
          'attn_comp_contour_dataset', 'attn_rand_contour_dataset',
          'rand_comp_contour_dataset'),
         {'proj_target_neuron': 'proj_target_vect',
          'proj_comparison_neuron': 'proj_comparison_vect'}),
    )
    for attr_names, key_map in renames:
        for attr_name in attr_names:
            # Update the dict that actually holds the legacy key. The old code always
            # wrote into analyzer.comp_contour_dataset, which failed or hit the wrong dict.
            container = getattr(analyzer, attr_name, None)
            if container is None:
                continue
            for legacy_key, current_key in key_map.items():
                if legacy_key in container.keys():
                    container[current_key] = container[legacy_key]
    return analyzer

# Iso-contour activations comparison

In [None]:
# One entry per model in params_list; only the LCA model uses the rescaled savefiles.
save_names = ['', '', '', 'rescaled_']
params_list = [rica_768_vh_params(), ae_768_vh_params(), sae_768_vh_params(), lca_768_vh_params()]


def _load_savefile_data(analyzer, file_prefix, save_name):
    """Return the ``'data'`` entry of ``savefiles/<file_prefix><save_name><save_info>.npz``."""
    file_path = (analyzer.analysis_out_dir + 'savefiles/' + file_prefix + save_name
        + analyzer.analysis_params.save_info + '.npz')
    return np.load(file_path, allow_pickle=True)['data']


analyzer_list = [load_analysis(params) for params in params_list]
for analyzer, save_name in zip(analyzer_list, save_names):
    analyzer.iso_params = _load_savefile_data(analyzer, 'iso_params_', save_name).item()
    # Plot extents: after the loop these hold the last model's ranges.
    x_range = analyzer.iso_params['x_range']
    y_range = analyzer.iso_params['y_range']

    iso_vectors = _load_savefile_data(analyzer, 'iso_vectors_', save_name).item()
    for vector_key in ('target_neuron_ids', 'comparison_neuron_ids', 'target_vectors',
                       'rand_orth_vectors', 'comparison_vectors'):
        setattr(analyzer, vector_key, iso_vectors[vector_key])

    analyzer.comp_activations = _load_savefile_data(analyzer, 'iso_comp_activations_', save_name)
    analyzer.comp_contour_dataset = _load_savefile_data(analyzer, 'iso_comp_contour_dataset_', save_name).item()
    analyzer.rand_activations = _load_savefile_data(analyzer, 'iso_rand_activations_', save_name)
    analyzer.rand_contour_dataset = _load_savefile_data(analyzer, 'iso_rand_contour_dataset_', save_name).item()

    analyzer = add_analyzer_keys(analyzer)

In [None]:
# Normalized activity level (0..1, between the min and max response) at which
# iso-response contours are extracted.
target_act = 0.5
lca_activations = analyzer_list[-1].comp_activations
curvatures, fits, contours = ha.iso_response_curvature_poly_fits(lca_activations, target_act=target_act)
# Per target neuron: the comparison with the largest curvature, and that curvature.
max_comp_indices = [np.argmax(neuron_curvatures) for neuron_curvatures in curvatures]
max_vals = [neuron_curvatures[comp_idx] for neuron_curvatures, comp_idx in zip(curvatures, max_comp_indices)]
# The overall most curved (target, comparison) pair.
max_target_id = np.argmax(max_vals)
max_comparison_id = max_comp_indices[max_target_id]

In [None]:
# Angles (degrees) between each LCA target vector and each of its comparison vectors.
# Per target, also record the index of the closest of its first num_comparison_vects comparisons.
lca_analyzer = analyzer_list[-1]
all_angles = []
target_min_idx = []
for target_neuron_id, raw_target_vect in enumerate(lca_analyzer.target_vectors):
    target_vect = dg.normalize_vector(raw_target_vect).reshape((-1, 1))
    neuron_angles = []
    for comparison_vect in lca_analyzer.comparison_vectors[target_neuron_id]:
        comparison_vect = dg.normalize_vector(comparison_vect).reshape((-1, 1))
        angle = dg.angle_between_vectors(target_vect, comparison_vect) * (180 / np.pi)
        neuron_angles.append(angle.item())
    all_angles.append(neuron_angles)
    target_min_idx.append(np.argmin(neuron_angles[:lca_analyzer.iso_params['num_comparison_vects']]))
# Keeps the trailing empty list the accumulator has always ended with; the zip below ignores it.
all_angles.append([])
min_target = [target_angles[min_id] for target_angles, min_id in zip(all_angles, target_min_idx)]
min_target_id = np.argmin(min_target)
min_comparison_id = target_min_idx[min_target_id]

# Panel selection: the first three models show pair (0, 0). The LCA panel shows the
# most curved pair; min_target_id / min_comparison_id give the smallest-angle pair instead.
# Previously noted candidate target ids (values presumably curvatures):
# 8(.039), 17(.028) 23(.037), 25(.036), 41(.033), 48(.035), 49(0.039)
neuron_indices = [0, 0, 0, max_target_id]
orth_indices = [0, 0, 0, max_comparison_id]
num_plots_y = num_plots_x = 2
width_fraction = 1.0
show_contours = True

# Curvature of the selected LCA pair only; the other panels get no curvature annotation.
lca_activations = lca_analyzer.comp_activations[neuron_indices[-1], orth_indices[-1], ...][None, None, ...]
curvatures, fits, contours = ha.iso_response_curvature_poly_fits(lca_activations, target_act=target_act)
curvature = [None, None, None, curvatures[0][0]]

contour_fig, contour_handles = nc.plot_group_iso_contours(
    analyzer_list, neuron_indices, orth_indices, num_levels, x_range, y_range,
    show_contours, curvature, text_width, width_fraction, dpi)

# The same combined figure is written into every model's vis/ directory.
for analyzer, neuron_index, orth_index, save_suffix in zip(analyzer_list, neuron_indices, orth_indices, save_names):
    neuron_str = str(analyzer.target_neuron_ids[neuron_index])
    orth_str = str(analyzer.comparison_neuron_ids[neuron_index][orth_index])
    file_stem = analyzer.analysis_out_dir + "/vis/iso_contour_comparison_"
    if not show_contours:
        file_stem += "continuous_"
    file_stem += "bf0id" + neuron_str + "_bf1id" + orth_str + "_" + save_suffix + analyzer.analysis_params.save_info
    for ext in file_extensions:
        save_name = file_stem + ext
        contour_fig.savefig(save_name, dpi=dpi, transparent=False, bbox_inches="tight", pad_inches=0.01)

# Curvature histogram

In [None]:
# Grid of iso-response contours for the 2560-unit LCA model.
width_fraction = 1.0
show_contours = True
num_levels = 10
num_x = num_y = 5
iso_save_name = 'rescaled_closecomp_'
params = lca_2560_vh_params()

analyzer = load_analysis(params)

group_iso_path = (analyzer.analysis_out_dir + 'savefiles/group_iso_vectors_'
    + iso_save_name + analyzer.analysis_params.save_info + '.npz')
cont_analysis = np.load(group_iso_path, allow_pickle=True)['data'].item()
curvatures = cont_analysis['curvatures']
target_act = cont_analysis['target_act']

contour_fig, contour_handles = nc.plot_iso_contour_set(
    cont_analysis, curvatures, num_levels, num_x, num_y, show_contours,
    text_width, width_fraction, dpi)

file_stem = analyzer.analysis_out_dir + "/vis/scaled_iso_contours_set_" + ("" if show_contours else "continuous_")
for ext in file_extensions:
    save_name = file_stem + analyzer.analysis_params.save_info + ext
    contour_fig.savefig(save_name, dpi=dpi, transparent=False, bbox_inches="tight", pad_inches=0.01)

In [None]:
# Load the 512/1024/2560-unit LCA models plus their iso-response ("iso_*") and
# response-attenuation ("attn_*") contour analyses.
params_list = [lca_512_vh_params(), lca_1024_vh_params(), lca_2560_vh_params()]
iso_save_name = 'rescaled_randomcomp_'  # alternative: "iso_curvature_xrange1.3_yrange-2.2_"
attn_save_name = 'rescaled_randomcomp_'  # alternative: '1d_'


def _load_savefile_data(analyzer, file_prefix, save_name):
    """Return the ``'data'`` entry of ``savefiles/<file_prefix><save_name><save_info>.npz``."""
    file_path = (analyzer.analysis_out_dir + 'savefiles/' + file_prefix + save_name
        + analyzer.analysis_params.save_info + '.npz')
    return np.load(file_path, allow_pickle=True)['data']


for params in params_list:
    params.model_dir = (os.path.expanduser("~")+"/Work/Projects/"+params.model_name)

analyzer_list = [ap.get_analyzer(params.model_type) for params in params_list]
for analyzer, params in zip(analyzer_list, params_list):
    analyzer.setup(params)
    analyzer.model.setup(analyzer.model_params)
    analyzer.load_analysis(save_info=params.save_info)
    analyzer.model_name = params.model_name

    # Both analyses come from the iso_* savefiles and differ only by save name.
    # They are stored on the analyzer under the "iso_" and "attn_" attribute prefixes.
    for attr_prefix, data_save_name in (("iso", iso_save_name), ("attn", attn_save_name)):
        setattr(analyzer, attr_prefix + "_params",
            _load_savefile_data(analyzer, "iso_params_", data_save_name).item())
        analyzer = add_analyzer_keys(analyzer)
        for pair_kind in ("comp", "rand"):
            setattr(analyzer, attr_prefix + "_" + pair_kind + "_activations",
                _load_savefile_data(analyzer, "iso_" + pair_kind + "_activations_", data_save_name))
            setattr(analyzer, attr_prefix + "_" + pair_kind + "_contour_dataset",
                _load_savefile_data(analyzer, "iso_" + pair_kind + "_contour_dataset_", data_save_name).item())
        loaded_params = getattr(analyzer, attr_prefix + "_params")
        setattr(analyzer, attr_prefix + "_num_target_neurons", loaded_params["num_neurons"])
        setattr(analyzer, attr_prefix + "_num_comparison_vectors", loaded_params["num_comparison_vects"])

    analyzer = add_analyzer_keys(analyzer)

# Example activity mesh (target 0, comparison 1) from the last model in analyzer_list.
mesh_save_name = "iso_curvature_ryan_"
contour_activity = _load_savefile_data(analyzer, "iso_comp_activations_", mesh_save_name)[0, 1, ...]
comp_contour_dataset = _load_savefile_data(analyzer, "iso_comp_contour_dataset_", mesh_save_name).item()
x = range(comp_contour_dataset["x_pts"].size)
y = range(comp_contour_dataset["y_pts"].size)
contour_pts = (x, y)

In [None]:
# Activity level for the curvature fits, and histogram bin count. The nc helpers
# presumably attach fits and histograms (e.g. iso_comp_hist, iso_bin_edges) to each analyzer.
target_act = 0.5
num_bins = 50
nc.compute_curvature_fits(analyzer_list, target_act)
nc.compute_curvature_hists(analyzer_list, num_bins)

In [None]:
# Legend labels / colors per histogram row (presumably row 0 = comparison
# vectors, row 1 = random vectors), one entry per model in analyzer_list.
label_list = [["2x", "4x", "10x"]]*2
color_list = [[color_vals["lt_red"], color_vals["md_red"], color_vals["dk_red"]],
  [color_vals["lt_blue"], color_vals["md_blue"], color_vals["dk_blue"]]]
curve_lims = {
    "x":[min(comp_contour_dataset["x_pts"]), max(comp_contour_dataset["x_pts"])],
    "y":[min(comp_contour_dataset["y_pts"]), max(comp_contour_dataset["y_pts"])]
}

iso_title = "Iso-Response"
iso_hist_list = [[analyzer.iso_comp_hist for analyzer in analyzer_list],
  [analyzer.iso_rand_hist for analyzer in analyzer_list]]
plot_bin_lefts, plot_bin_rights = analyzer_list[0].iso_bin_edges[:-1], analyzer_list[0].iso_bin_edges[1:]
# Bin midpoints. The half-width is required: left + (right - left) is just the right edge.
iso_plot_bin_centers = plot_bin_lefts + (plot_bin_rights - plot_bin_lefts) / 2

attn_title = "Response Attenuation"
attn_hist_list = [[analyzer.attn_comp_hist for analyzer in analyzer_list],
  [analyzer.attn_rand_hist for analyzer in analyzer_list]]
plot_bin_lefts, plot_bin_rights = analyzer_list[0].attn_bin_edges[:-1], analyzer_list[0].attn_bin_edges[1:]
attn_plot_bin_centers = plot_bin_lefts + (plot_bin_rights - plot_bin_lefts) / 2

# Outer index: [iso-response, response attenuation].
full_hist_list = [iso_hist_list, attn_hist_list]
full_label_list = [label_list,]*2
full_color_list = [color_list,]*2
full_bin_centers = [iso_plot_bin_centers, attn_plot_bin_centers]
full_title = [iso_title, attn_title]
full_xlabel = ["Curvature (Comparison)", "Curvature (Random)"]

In [None]:
# Settings for the 3-D example mesh drawn beside the curvature histograms.
scatter = False
mesh_color = "#A6A6A6"
contour_angle = 200  # previously 195
view_elevation = 25  # previously 30
# Positions of the curvature line labels, roughly [z, x, y]: z is up-down,
# x is into the page, y is right-left.
iso_resp_loc = [-27, 20, 0.11]  # previously [-18, 20, 0.11]
resp_att_loc = [108, 38, 0.85]  # previously [105, 38, 0.8]
activity_loc = [-27, 150, 1.5]
contour_text_loc = [iso_resp_loc, resp_att_loc, activity_loc]

curvature_log_fig = nc.plot_curvature_histograms(
    contour_activity, contour_pts, contour_angle, view_elevation, contour_text_loc,
    full_hist_list, full_label_list, full_color_list, mesh_color, full_bin_centers,
    full_title, full_xlabel, curve_lims, scatter,
    log=True, text_width=text_width, width_ratio=0.75, dpi=dpi)

# The same figure is written into every model's vis/ directory.
for analyzer in analyzer_list:
    file_stem = "".join([analyzer.analysis_out_dir, "/vis/", iso_save_name,
        "curvatures_and_histograms_logy", "_", analyzer.analysis_params.save_info])
    for ext in file_extensions:
        save_name = file_stem + ext
        curvature_log_fig.savefig(save_name, transparent=False, bbox_inches="tight", pad_inches=0.05, dpi=dpi)

In [None]:
# Same figure as the log-scale version, but with a linear y axis.
curvature_lin_fig = nc.plot_curvature_histograms(
    contour_activity, contour_pts, contour_angle, view_elevation, contour_text_loc,
    full_hist_list, full_label_list, full_color_list, mesh_color, full_bin_centers,
    full_title, full_xlabel, curve_lims, scatter,
    log=False, text_width=text_width, width_ratio=0.75, dpi=dpi)

for analyzer in analyzer_list:
    file_stem = "".join([analyzer.analysis_out_dir, "/vis/", iso_save_name,
        "curvatures_and_histograms_liny", "_", analyzer.analysis_params.save_info])
    for ext in file_extensions:
        save_name = file_stem + ext
        curvature_lin_fig.savefig(save_name, transparent=False, bbox_inches="tight", pad_inches=0.05, dpi=dpi)

# Orientation Selectivity

In [None]:
# The three 768-unit models compared in the orientation-selectivity figures.
params_list = [rica_768_vh_params(), sae_768_vh_params(), lca_768_vh_params()]
params_list[-1].display_name = "Sparse Coding"
projects_dir = os.path.expanduser("~") + "/Work/Projects/"
for params in params_list:
    params.model_dir = projects_dir + params.model_name
analyzer_list = [ap.get_analyzer(params.model_type) for params in params_list]
for analyzer, params in zip(analyzer_list, params_list):
    analyzer.setup(params)
    analyzer.model.setup(analyzer.model_params)
    analyzer.load_analysis(save_info=params.save_info)
    analyzer.model_name = params.model_name

In [None]:
# Orientation-tuning metrics for every neuron, using the last contrast index.
circ_var_list = []
for analyzer in analyzer_list:
    # NOTE(review): np.random.choice samples with replacement and is unseeded, so
    # these 12 example neurons can repeat and differ between runs.
    analyzer.bf_indices = np.random.choice(analyzer.ot_grating_responses["neuron_indices"], 12)
    num_orientation_samples = len(analyzer.ot_grating_responses['orientations'])
    # Evenly spaced orientations over [-90, 90) degrees and [-pi/2, pi/2) radians.
    corresponding_angles_deg = (180 * np.arange(num_orientation_samples) / num_orientation_samples) - 90
    corresponding_angles_rad = (np.pi * np.arange(num_orientation_samples) / num_orientation_samples) - (np.pi/2)
    metrics = analyzer.metrics_list = {"fwhm": [], "circ_var": [], "osi": [], "nonvarying_indices": []}
    contrast_idx = -1
    for bf_idx in range(analyzer.bf_stats["num_outputs"]):
        ot_curve = nc.center_curve(analyzer.ot_grating_responses["mean_responses"][bf_idx, contrast_idx, :])
        if np.max(ot_curve) - np.min(ot_curve) == 0:
            # Flat tuning curve: record it and use a circular variance of 1.0,
            # so circ_var keeps one entry per neuron.
            metrics["nonvarying_indices"].append(bf_idx)
            metrics["circ_var"].append([None, None, 1.0])
            continue
        metrics["fwhm"].append(nc.compute_fwhm(ot_curve, corresponding_angles_deg))
        metrics["circ_var"].append(nc.compute_circ_var(ot_curve, corresponding_angles_rad))
        metrics["osi"].append(nc.compute_osi(ot_curve))
    # Element 2 of each circ_var entry is presumably the scalar circular variance.
    circ_var_list.append(np.array([entry[2] for entry in metrics["circ_var"]]))

In [None]:
# Circular-variance histogram for the three 768-unit models (raw counts).
color_list = [color_vals[shade] for shade in ("md_green", "md_blue", "md_red")]
label_list = ["Linear Autoencoder", "Sparse Autoencoder", "Sparse Coding"]
num_bins = 30
width_ratios = [0.5, 0.25, 0.25]
height_ratios = [0.13, 0.25, 0.25, 0.25]
density = False

circ_var_fig = nc.plot_circ_variance_histogram(
    analyzer_list, circ_var_list, color_list, label_list, num_bins, density,
    width_ratios, height_ratios, text_width=text_width, width_ratio=0.75, dpi=dpi)

# The same figure is written into every model's vis/ directory.
for analyzer in analyzer_list:
    file_stem = analyzer.analysis_out_dir + "/vis/circular_variance_combo_" + analyzer.analysis_params.save_info
    for ext in file_extensions:
        save_name = file_stem + ext
        circ_var_fig.savefig(save_name, transparent=False, bbox_inches="tight", pad_inches=0.01, dpi=dpi)

In [None]:
# Per-neuron spatial frequency vs. circular variance, one scatter series per model.
# Rows of both stacked arrays index analyzer_list; this requires equal neuron counts.
spatial_frequencies = np.stack([np.array(analyzer.bf_stats["spatial_frequencies"]) for analyzer in analyzer_list], axis=0)
circular_variances = np.stack(circ_var_list, axis=0)

cv_vs_sf_fig = plt.figure(figsize=nc.set_size(text_width, fraction=0.75), dpi=dpi)
ax = cv_vs_sf_fig.add_subplot()
for analyzer_idx in range(len(analyzer_list)):
  ax.scatter(spatial_frequencies[analyzer_idx, :], circular_variances[analyzer_idx, :],
    s=12, color=color_list[analyzer_idx], label=label_list[analyzer_idx])
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.set_xlabel("Spatial Frequency (Cycles/Patch)")
ax.set_ylabel("Circular Variance")
# Text-only legend: entries are colored to match their series and markers are hidden.
legend = ax.legend(loc="upper center", framealpha=1.0, ncol=1, borderaxespad=0., borderpad=0.,
  handlelength=0., labelspacing=0.1, bbox_to_anchor=(0.38, 0.98))
legend.get_frame().set_linewidth(0.0)
for text, color in zip(legend.get_texts(), color_list):
  text.set_color(color)
# Matplotlib 3.7 renamed Legend.legendHandles to legend_handles and 3.9 removed the old name.
legend_handles = legend.legend_handles if hasattr(legend, "legend_handles") else legend.legendHandles
for item in legend_handles:
  item.set_visible(False)
plt.show()

# The same figure is written into every model's vis/ directory.
for analyzer in analyzer_list:
  for ext in file_extensions:
    save_name = (analyzer.analysis_out_dir+"/vis/spatial_freq_vs_circular_variance"
      +"_"+analyzer.analysis_params.save_info+ext)
    cv_vs_sf_fig.savefig(save_name, transparent=False, bbox_inches="tight", pad_inches=0.01, dpi=dpi)

In [None]:
# LCA models of increasing size, relabeled by neuron count for the figure legend.
display_names = ["512 Neurons", "1024 Neurons", "2560 Neurons"]
params_list = [lca_512_vh_params(), lca_1024_vh_params(), lca_2560_vh_params()]
projects_dir = os.path.expanduser("~") + "/Work/Projects/"
for params, display_name in zip(params_list, display_names):
    params.display_name = display_name
    params.model_dir = projects_dir + params.model_name
analyzer_list = [ap.get_analyzer(params.model_type) for params in params_list]
for analyzer, params in zip(analyzer_list, params_list):
    analyzer.setup(params)
    analyzer.model.setup(analyzer.model_params)
    analyzer.load_analysis(save_info=params.save_info)
    analyzer.model_name = params.model_name

In [None]:
# Orientation-tuning metrics for every neuron, using the last contrast index.
# Unlike the 768-unit comparison, flat tuning curves are skipped entirely.
circ_var_list = []
for analyzer in analyzer_list:
    # NOTE(review): np.random.choice samples with replacement and is unseeded, so
    # these 12 example neurons can repeat and differ between runs.
    analyzer.bf_indices = np.random.choice(analyzer.ot_grating_responses["neuron_indices"], 12)
    num_orientation_samples = len(analyzer.ot_grating_responses['orientations'])
    # Evenly spaced orientations over [-90, 90) degrees and [-pi/2, pi/2) radians.
    corresponding_angles_deg = (180 * np.arange(num_orientation_samples) / num_orientation_samples) - 90
    corresponding_angles_rad = (np.pi * np.arange(num_orientation_samples) / num_orientation_samples) - (np.pi/2)
    metrics = analyzer.metrics_list = {"fwhm": [], "circ_var": [], "osi": [], "skipped_indices": []}
    contrast_idx = -1
    for bf_idx in range(analyzer.bf_stats["num_outputs"]):
        ot_curve = nc.center_curve(analyzer.ot_grating_responses["mean_responses"][bf_idx, contrast_idx, :])
        if np.max(ot_curve) - np.min(ot_curve) == 0:
            metrics["skipped_indices"].append(bf_idx)
            continue
        metrics["fwhm"].append(nc.compute_fwhm(ot_curve, corresponding_angles_deg))
        metrics["circ_var"].append(nc.compute_circ_var(ot_curve, corresponding_angles_rad))
        metrics["osi"].append(nc.compute_osi(ot_curve))
    # Element 2 of each circ_var entry is presumably the scalar circular variance.
    circ_var_list.append(np.array([entry[2] for entry in metrics["circ_var"]]))

In [None]:
# Circular-variance histogram across LCA sizes (normalized densities).
# color_vals["blk"] can be appended for a fourth series.
color_list = [color_vals[shade] for shade in ("md_green", "md_blue", "md_red")]
label_list = display_names
num_bins = 30
width_ratios = [0.5, 0.25, 0.25]
height_ratios = [0.13, 0.25, 0.25, 0.25]
density = True

oc_vs_cv_fig = nc.plot_circ_variance_histogram(
    analyzer_list, circ_var_list, color_list, label_list, num_bins, density,
    width_ratios, height_ratios, text_width=text_width, width_ratio=0.75, dpi=dpi)

# The same figure is written into every model's vis/ directory.
for analyzer in analyzer_list:
    file_stem = analyzer.analysis_out_dir + "/vis/overcompleteness_vs_circular_variance_" + analyzer.analysis_params.save_info
    for ext in file_extensions:
        save_name = file_stem + ext
        oc_vs_cv_fig.savefig(save_name, transparent=False, bbox_inches="tight", pad_inches=0.05, dpi=dpi)

# Natural scene selectivity

In [None]:
# Natural-image selectivity results for the 512/1024/2560-unit LCA models.
params_list = [lca_512_vh_params(), lca_1024_vh_params(), lca_2560_vh_params()]
# Derive names from the presets so the two lists cannot drift apart. Previously
# params_list[1] was lca_768_vh_params, and the loop relabeled it as lca_1024_vh.
model_names = [params.model_name for params in params_list]
model_types = ['LCA', 'LCA', 'LCA']
model_labels = ['2x', '4x', '10x']
analyzer_list = []
for model_type, model_name, model_label, analysis_params in zip(model_types, model_names, model_labels, params_list):
    analysis_params.projects_dir = os.path.expanduser("~")+"/Work/Projects/"
    analysis_params.model_name = model_name
    analysis_params.version = '0.0'
    analysis_params.model_dir = analysis_params.projects_dir+analysis_params.model_name
    model_log_file = (analysis_params.model_dir+"/logfiles/"+analysis_params.model_name
      +"_v"+analysis_params.version+".log")
    analysis_params.model_type = model_type
    analyzer = ap.get_analyzer(analysis_params.model_type)
    # Selectivity results live under a separate save_info from the adversarial analyses.
    analysis_params.save_info = "analysis_selectivity"
    analyzer.setup(analysis_params)
    analyzer.model_label = model_label
    analyzer.model_type = model_type
    analyzer.nat_selectivity = np.load(analyzer.analysis_out_dir+'savefiles/natural_image_selectivity.npz',
        allow_pickle=True)['data'].item()
    analyzer_list.append(analyzer)

In [None]:
def closest_val_in_array(num, arr):
    """Return the element of ``arr`` that is nearest to ``num``.

    Ties are resolved in favour of the element that appears first, and
    repeated values in ``arr`` are allowed.

    Args:
        num: Target value.
        arr: Non-empty sequence (or 1-D array) of numbers.

    Returns:
        The element of ``arr`` minimizing ``abs(num - element)``.

    Raises:
        ValueError: If ``arr`` is empty.
    """
    if len(arr) == 0:
        raise ValueError("closest_val_in_array() arg 'arr' is empty")
    # min() returns the first minimal element, i.e. the same winner as a
    # left-to-right scan with a strict '<' comparison.
    return min(arr, key=lambda val: abs(num - val))

In [None]:
# Per-model (2, N) arrays of interesting-image counts: row 0 is the sparse-coding
# ('nl') model and row 1 the linear ('l') model (the boxplot below labels the
# transposed columns 'Sparse Coding', 'Linear'). N is presumably the number of
# neurons — TODO confirm.
num_interesting_vals = [
    np.array([analyzer.nat_selectivity['num_interesting_img_nl'],
    analyzer.nat_selectivity['num_interesting_img_l']])
    for analyzer in analyzer_list]

# (num_models, 2) medians of the raw per-model counts, columns [sparse coding, linear].
num_interesting_medians = np.stack(
    [np.array([np.median(np.array(analyzer.nat_selectivity['num_interesting_img_nl'])),
    np.median(np.array(analyzer.nat_selectivity['num_interesting_img_l']))])
    for analyzer in analyzer_list], axis=0)

# (num_models, 2) precomputed means, columns [sparse coding, linear].
num_interesting_means = np.stack(
    [np.array([analyzer.nat_selectivity['num_interesting_img_nl_mean'],
    analyzer.nat_selectivity['num_interesting_img_l_mean']])
    for analyzer in analyzer_list], axis=0)

# (num_models, 2) precomputed standard deviations, columns [sparse coding, linear].
num_interesting_stds = np.stack(
    [np.array([analyzer.nat_selectivity['num_interesting_img_nl_std'],
    analyzer.nat_selectivity['num_interesting_img_l_std']])
    for analyzer in analyzer_list], axis=0)

# proplot subplot layout: boxplot panels 1-3 on the top row, histogram panels 4-6 below.
array = [
    [1, 2, 3],
    [4, 5, 6],
]

scale = 1
# rc overrides for the proplot figure. Only numeric font sizes are scaled:
# rcParams['font.family'] is a list of family names, so multiplying it by an
# int > 1 would repeat the entries instead of scaling anything.
rc_kwargs = {
    'fontsize':scale*matplotlib.rcParams['font.size'],
    'fontfamily':matplotlib.rcParams['font.family'],
    'legend.fontsize': scale*matplotlib.rcParams['font.size'],
    'text.labelsize': scale*matplotlib.rcParams['font.size']
}
figsize = nc.set_size(text_width, fraction=1.00)
with plot.rc.context(**rc_kwargs):
    interesting_imgs_fig, axs = plot.subplots(array, sharey=False, sharex=False, aspect=3.0, figsize=figsize)
    # Top row: one boxplot panel per model comparing sparse-coding vs. linear
    # counts. All three panels use the same y-range (largest count overall).
    box_ymax = np.max([np.max(val) for val in num_interesting_vals])
    for ovc_idx, overcompleteness in enumerate(num_interesting_vals):
        ax = axs[ovc_idx]
        df = pd.DataFrame(
            overcompleteness.T,
            columns=pd.Index(['Sparse Coding', 'Linear'])
        )
        box_parts = ax.boxplot(
            df,
            notch=True,
            fill=False,
            whis=(5, 95),
            marker='*',
            markersize=1.0,
            lw=1.2
        )
        colors = ['md_red', 'md_green']
        for pc_idx, box in enumerate(box_parts['boxes']):
            box.set_color(color_vals[colors[pc_idx]])
        ax.format(
            ylocator=50,
            ylim=[0, box_ymax],
            title=analyzer_list[ovc_idx].nat_selectivity['oc_label'],
            ylabel='Average number of\ninteresting images',
            xtickminor=False,
            xgrid=False
        )

    # Bottom row: per-model histograms of the mean image-to-weight angle.
    for idx, analyzer in enumerate(analyzer_list):
        ax = axs[idx+3]
        angle_min = 0.0
        angle_max = 90.0
        nbins=20
        bins = np.linspace(angle_min, angle_max, nbins)
        # Only strictly positive mean angles are histogrammed.
        lin_data = [mean for mean in analyzer.nat_selectivity['lin_means'] if mean>0]
        non_lin_data = [mean for mean in analyzer.nat_selectivity['lca_means'] if mean>0]
        color_list = [color_vals['md_green'], color_vals['md_red']]
        label_list = ['Linear Autoencoder', 'Sparse Coding']
        handles = []
        hist_max_list = []
        for angles, label, color in zip([lin_data, non_lin_data], label_list, color_list):
            # density=False: y values are raw counts per bin.
            hist, bin_edges = np.histogram(np.array(angles).flatten(), bins, density=False)
            hist_max_list.append(hist.max())
            bin_left, bin_right = bin_edges[:-1], bin_edges[1:]
            bin_centers = bin_left + (bin_right - bin_left)/2
            handles.append(ax.plot(bin_centers, hist, linestyle='-', drawstyle='steps-mid', color=color, label=label))
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.set_xticks(bin_left, minor=True)
        ax.xaxis.set_major_formatter(plticker.FormatStrFormatter('%0.0f'))
        ax.set_xticks([angle_min, angle_max//2, angle_max])
        # Round the tallest bin UP to a multiple of 50 (at least 50) so the
        # y-limit never cuts off data; rounding to the *closest* multiple can
        # land below the maximum (e.g. 120 -> 100, giving ylim 110 < 120).
        max_val = int(max(hist_max_list))
        new_max = max(50, int(np.ceil(max_val / 50.0)) * 50)
        new_mid = new_max//2
        ax.set_ylim([0, new_max+0.1*new_max])
        ax.set_yticks([0, new_mid, new_max])
    hist_ax_idx = 3
    axs[hist_ax_idx].format(ylabel='Total number of\ninteresting images')
    axs[hist_ax_idx:].format(
        suptitle='Sparse Coding Increases Neuron Selectivity for Natural Signals',
        xlabel='Mean image-to-weight angle',
        xlim=[0, 90],
        ygrid=False
    )
plot.show()

In [None]:
# Save the boxplot/histogram selectivity figure into every analyzer's vis/ directory.
# NOTE(review): matplotlib only applies pad_inches when bbox_inches='tight'; confirm
# that omitting bbox_inches here is intentional.
for analyzer in analyzer_list:
    for ext in file_extensions:
        save_name = (analyzer.analysis_out_dir+'/vis/natural_img_selectivity_box_'
            +analyzer.analysis_params.save_info+ext)
        interesting_imgs_fig.savefig(save_name, transparent=False, pad_inches=0.005, dpi=dpi)

In [None]:
# Shared settings and summary statistics for the bar-chart version of the
# natural-image selectivity figure.
nbins=20

color_list = [color_vals['md_green'], color_vals['md_red']]
label_list = ['Linear Autoencoder', 'Sparse Coding']

# Per-model arrays of counts: row 0 sparse coding ('nl'), row 1 linear ('l').
num_interesting_vals = [
    np.array([analyzer.nat_selectivity['num_interesting_img_nl'],
    analyzer.nat_selectivity['num_interesting_img_l']])
    for analyzer in analyzer_list]

# (num_models, 2) medians of the raw per-model counts, columns [sparse coding, linear].
# Computed from the raw arrays: the '*_mean' entries hold means, not medians.
num_interesting_medians = np.stack(
    [np.array([np.median(np.array(analyzer.nat_selectivity['num_interesting_img_nl'])),
    np.median(np.array(analyzer.nat_selectivity['num_interesting_img_l']))])
    for analyzer in analyzer_list], axis=0)

# (num_models, 2) precomputed means, columns [sparse coding, linear].
num_interesting_means = np.stack(
    [np.array([analyzer.nat_selectivity['num_interesting_img_nl_mean'],
    analyzer.nat_selectivity['num_interesting_img_l_mean']])
    for analyzer in analyzer_list], axis=0)

# (num_models, 2) precomputed standard deviations, columns [sparse coding, linear].
num_interesting_stds = np.stack(
    [np.array([analyzer.nat_selectivity['num_interesting_img_nl_std'],
    analyzer.nat_selectivity['num_interesting_img_l_std']])
    for analyzer in analyzer_list], axis=0)

# Bar-chart data: one row per model (blank index labels), columns [LCA, Linear].
df = pd.DataFrame(
    num_interesting_means,
    index=pd.Index(['',]*3, name=''),
    columns=['LCA', 'Linear']
)

# proplot layout: panel 1 spans the top row; histogram panels 2-4 fill the bottom row.
array = [
    [1, 1, 1],
    [2, 3, 4],
]

scale = 1
# rc overrides for the proplot figure. Only numeric font sizes are scaled:
# rcParams['font.family'] is a list of family names, so multiplying it by an
# int > 1 would repeat the entries instead of scaling anything.
rc_kwargs = {
    'fontsize':scale*matplotlib.rcParams['font.size'],
    'fontfamily':matplotlib.rcParams['font.family'],
    'legend.fontsize': scale*matplotlib.rcParams['font.size'],
    'text.labelsize': scale*matplotlib.rcParams['font.size']
}
figsize = nc.set_size(text_width, fraction=1.00)
with plot.rc.context(**rc_kwargs):
    interesting_imgs_fig, axs = plot.subplots(array, sharey=False, aspect=3.0, figsize=figsize)
    # Top panel: grouped bars of the mean counts per model, with +/-1 std error bars.
    ax = axs[0]
    obj = ax.bar(
        df,
        width=0.6,
        cycle=[color_vals['md_red'], color_vals['md_green']],
        edgecolor='black'
    )
    # Error-bar x positions: each bar's left edge (xy[0]) shifted by half the
    # gap between the first LCA bar and the first linear bar.
    half_bar_width = np.abs(obj[1].patches[0].xy[0] - obj[0].patches[0].xy[0])/2
    lca_bar_locs = [patch.xy[0]+half_bar_width for patch in obj[0].patches]
    lin_bar_locs = [patch.xy[0]+half_bar_width for patch in obj[1].patches]
    ax.errorbar(lca_bar_locs, num_interesting_means[:,0] , yerr=num_interesting_stds[:,0], color='k', fmt='.')
    ax.errorbar(lin_bar_locs, num_interesting_means[:,1] , yerr=num_interesting_stds[:,1], color='k', fmt='.')
    ax.yaxis.set_major_locator(plt.MaxNLocator(6))

    ax.legend(obj, frameon=False, loc='ur', bbox_to_anchor=[1,1.02])
    ax.format(
        xlocator=1,
        xminorlocator=0.5,
        ytickminor=False,
        ylabel='Average number of\ninteresting images',
        xgrid=False
    )
    # Bottom row: per-model histograms of the mean image-to-weight angle.
    for idx, analyzer in enumerate(analyzer_list):
        ax = axs[idx+1]
        angle_min = 0.0
        angle_max = 90.0
        bins = np.linspace(angle_min, angle_max, nbins)
        # Only strictly positive mean angles are histogrammed.
        lin_data = [mean for mean in analyzer.nat_selectivity['lin_means'] if mean>0]
        non_lin_data = [mean for mean in analyzer.nat_selectivity['lca_means'] if mean>0]
        # Reset per panel so each histogram's y-range depends only on its own
        # model, not on whichever panels happened to be drawn before it.
        hist_max_list = []
        for angles, label, color in zip([lin_data, non_lin_data], label_list, color_list):
            # density=False: y values are raw counts per bin.
            hist, bin_edges = np.histogram(np.array(angles).flatten(), bins, density=False)
            hist_max_list.append(hist.max())
            bin_left, bin_right = bin_edges[:-1], bin_edges[1:]
            bin_centers = bin_left + (bin_right - bin_left)/2
            ax.plot(bin_centers, hist, linestyle='-', drawstyle='steps-mid', color=color, label=label)
        oc = analyzer.nat_selectivity['oc_label']
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.set_xticks(bin_left, minor=True)
        ax.xaxis.set_major_formatter(plticker.FormatStrFormatter('%0.0f'))
        ax.set_xticks([angle_min, angle_max//2, angle_max])
        y_max = max(hist_max_list)
        ax.set_ylim([0, y_max+0.1*y_max])
        ax.set_yticks([0, y_max//2, int(y_max)])
        ax.format(title=f'{oc}\n')
    axs[1].format(ylabel='Total number of\ninteresting images')
    axs[1:].format(
        suptitle='Sparse Coding Increases Neuron Selectivity for Natural Signals',
        xlabel='Mean image-to-weight angle',
        xlim=[0, 90],
        ygrid=False
    )
plot.show()

In [None]:
# Save the bar/histogram selectivity figure into every analyzer's vis/ directory.
# NOTE(review): matplotlib only applies pad_inches when bbox_inches='tight'; confirm
# that omitting bbox_inches here is intentional.
for analyzer in analyzer_list:
    for ext in file_extensions:
        save_name = (analyzer.analysis_out_dir+'/vis/natural_img_selectivity_bar_'
            +analyzer.analysis_params.save_info+ext)
        interesting_imgs_fig.savefig(save_name, transparent=False, pad_inches=0.005, dpi=dpi)

# Confidence attacks on MLP & LCA network

In [None]:
def get_adv_indices(softmax_conf, all_kept_indices, confidence_threshold, num_images, labels):
  """Find images whose best wrong-class confidence exceeds a threshold.

  The true-label entry of each row is zeroed on a private copy (the caller's
  array is left untouched), and the remaining row maximum is compared with
  ``confidence_threshold``.

  Args:
    softmax_conf: (num_images, num_classes) array of class confidences.
    all_kept_indices: iterable of image indices already counted as successful;
      they are excluded from ``keep_indices``.
    confidence_threshold: an image qualifies once its highest non-true-label
      confidence is strictly greater than this value.
    num_images: number of rows in ``softmax_conf``.
    labels: (num_images,) dense integer true labels.

  Returns:
    keep_indices: int32 array of newly qualifying image indices, ascending.
    confidence_indices: (num_images,) highest non-true-label confidence per
      image (a confidence value, despite the name).
    adversarial_labels: (num_images,) class index of that highest confidence.
  """
  conf = np.array(softmax_conf, copy=True)
  conf[np.arange(num_images, dtype=np.int32), labels] = 0 # zero confidence at true label
  confidence_indices = np.max(conf, axis=-1) # highest non-true label confidence
  adversarial_labels = np.argmax(conf, axis=-1) # index of highest non-true label
  # flatnonzero (instead of squeeze + nonzero) also works when num_images == 1.
  all_above_thresh = np.flatnonzero(confidence_indices > confidence_threshold)
  already_kept = set(all_kept_indices)  # built once, O(1) membership tests
  keep_indices = np.array(
    [adv_index for adv_index in all_above_thresh if adv_index not in already_kept],
    dtype=np.int32)
  return keep_indices, confidence_indices, adversarial_labels

def find_untargeted_conf_index(analysis): # for untargeted attacks
  """Record, per image, the first attack step at which the attack succeeds.

  Steps ``1..analysis['num_steps']`` of ``analysis['adversarial_outputs'][0]``
  are scanned in order (step 0 holds the original outputs). An image counts as
  successfully attacked at the first step where get_adv_indices reports it.

  Args:
    analysis: dict with keys 'input_labels' (one-hot labels, converted with
      dp.one_hot_to_dense), 'num_steps', 'adversarial_outputs' (indexed
      [0, step, ...]), 'confidence_threshold' and 'adversarial_input_adv_mses'
      (indexed [0, step, image]).

  Returns:
    dict whose values are single-element lists:
      'adversarial_time_step': int32 success step per image (-1 if never).
      'adversarial_confidence': confidence at success, or at the last step for failures.
      'failed_indices': never-successful image indices as a column array of
        shape (num_failed, 1), or an empty array when every image succeeded.
      'success_indices': successful image indices (unordered list).
      'adversarial_labels': wrong label reaching that confidence.
      'mean_squared_distances': input-to-adversarial MSE at success (or last step).
      'num_failed': number of images never attacked successfully.

  Raises:
    ValueError: if ``analysis['num_steps'] < 1`` (no attack step to evaluate).
  """
  if analysis['num_steps'] < 1:
    raise ValueError("analysis['num_steps'] must be >= 1 to evaluate any attack step")
  labels = dp.one_hot_to_dense(analysis['input_labels'].astype(np.int32))
  # All per-image buffers are sized by the number of labelled images.
  num_images = labels.shape[0]
  store_time_step = -1*np.ones(num_images, dtype=np.int32)
  store_labels = np.zeros(num_images, dtype=np.int32)
  store_confidence = np.zeros(num_images, dtype=np.float32)
  store_mses = np.zeros(num_images, dtype=np.float32)
  all_kept_indices = []
  for adv_step in range(1, analysis['num_steps']+1): # first one is original
    keep_indices, confidence_indices, adversarial_labels = get_adv_indices(
      analysis['adversarial_outputs'][0, adv_step, ...],
      all_kept_indices,
      analysis['confidence_threshold'],
      num_images,
      labels)
    if keep_indices.size > 0:
      all_kept_indices.extend(keep_indices)
      store_time_step[keep_indices] = adv_step
      store_confidence[keep_indices] = confidence_indices[keep_indices]
      store_mses[keep_indices] = analysis['adversarial_input_adv_mses'][0, adv_step, keep_indices]
      store_labels[keep_indices] = adversarial_labels[keep_indices]
  kept_set = set(all_kept_indices)
  batch_indices = np.arange(num_images, dtype=np.int32)[:,None]
  # Column vector of shape (num_failed, 1), the layout this output has always had.
  failed_indices = np.array([val for val in batch_indices if val[0] not in kept_set])
  if len(failed_indices) > 0:
    # Images that never succeeded report values from the final attack step.
    store_confidence[failed_indices] = confidence_indices[failed_indices]
    store_labels[failed_indices] = adversarial_labels[failed_indices]
    store_mses[failed_indices] = analysis['adversarial_input_adv_mses'][0, -1, failed_indices]
  output = {}
  output['adversarial_time_step'] = [store_time_step]
  output['adversarial_confidence'] = [store_confidence]
  output['failed_indices'] = [failed_indices]
  output['success_indices'] = [list(kept_set)]
  output['adversarial_labels'] = [store_labels]
  output['mean_squared_distances'] = [store_mses]
  output['num_failed'] = [num_images - len(kept_set)]
  return output

def std_conf(outputs, labels, index):
  """Standard deviation, across examples, of the label-weighted score of ``outputs[index]``.

  ``outputs[index]`` is multiplied elementwise by ``labels`` and summed along
  axis 1; with one-hot ``labels`` that picks out each row's true-class score.
  """
  weighted = outputs[index] * labels
  true_class_conf = np.sum(weighted, axis=1)
  return np.std(true_class_conf)

def get_cifar_mse_data(saved_info, model_names):
  """Gather each model's target-to-adversarial MSEs together with their mean and std.

  Returns:
    (data, means, stds): three lists aligned with ``model_names``. ``data`` holds
    the ``saved_info[name]["target_adv_mses"]`` objects themselves (no copies).
  """
  # The inner generator is lazy, so lookups and statistics still happen model by model.
  per_model = [
    (mses, np.mean(mses), np.std(mses))
    for mses in (saved_info[name]["target_adv_mses"] for name in model_names)
  ]
  data = [row[0] for row in per_model]
  means = [row[1] for row in per_model]
  stds = [row[2] for row in per_model]
  return (data, means, stds)

def adjacent_values(vals, q1, q3):
    """Return whisker ends at quartile -/+ 1.5*IQR, clipped to the data range.

    Args:
        vals: 1-D data; it does not need to be sorted (the data range is taken
            from its min/max, which equals the first/last element when sorted).
        q1: first quartile of ``vals``.
        q3: third quartile of ``vals``.

    Returns:
        (lower_adjacent_value, upper_adjacent_value).
    """
    vals = np.asarray(vals)
    data_min, data_max = np.min(vals), np.max(vals)
    iqr = q3 - q1
    upper_adjacent_value = np.clip(q3 + iqr * 1.5, q3, data_max)
    lower_adjacent_value = np.clip(q1 - iqr * 1.5, data_min, q1)
    return lower_adjacent_value, upper_adjacent_value

def make_violin(ax, group_data, group_means, x_pos, bar_width, color='k', plot_means=True, plot_medians=False):
  """Draw one violin per dataset with an interquartile bar and 1.5*IQR whiskers.

  Args:
    ax: matplotlib Axes to draw into.
    group_data: sequence of datasets, one violin each.
    group_means: per-dataset value drawn as a white dot when ``plot_means`` is True.
    x_pos: x position of each violin.
    bar_width: violin width.
    color: violin face color.
    plot_means: draw ``group_means`` as white dots.
    plot_medians: draw each dataset's median as a white dot.

  Returns:
    ``ax``.
  """
  for data, means, pos in zip(group_data, group_means, x_pos):
    parts = ax.violinplot(data, [pos], widths=bar_width,
      showmeans=False, showextrema=False, showmedians=False, bw_method="silverman")
    for pc in parts['bodies']:
      pc.set_facecolor(color)
      pc.set_edgecolor('k')
      pc.set_alpha(1)
    # Pass sorted data: adjacent_values may take the data range from the
    # first/last elements, which are only the min/max when sorted.
    sorted_data = np.sort(np.asarray(data), axis=None)
    quartile1, medians, quartile3 = np.percentile(sorted_data, [25, 50, 75])
    whiskers = np.array([adjacent_values(sorted_data, quartile1, quartile3)])
    whiskers_min, whiskers_max = whiskers[:, 0], whiskers[:, 1]
    if plot_medians:
      ax.scatter(pos, medians, marker='o', color='white', s=10, zorder=3, alpha=1)
    if plot_means:
      ax.scatter(pos, means, marker='o', color='white', s=10, zorder=3, alpha=1)
    ax.vlines(pos, quartile1, quartile3, color='k', linestyle='-', lw=5, alpha=1)
    ax.vlines(pos, whiskers_min, whiskers_max, color='k', linestyle='-', lw=1, alpha=1)
  return ax


In [None]:
def get_mnist(file):
    """Load one model's MNIST confidence-attack results from an ``.npz`` file.

    The file must store a pickled dict under the ``'data'`` key (hence
    ``allow_pickle=True``; only use with trusted, locally produced files).

    Returns:
        (input_images, input_labels, adv_images, adv_labels, adv_mse,
        adv_mse_mean, adv_mse_std), where images are reshaped to
        (-1, 28, 28), input labels are converted from one-hot via argmax,
        and the adversarial entries come from the first stored attack ([0]).
    """
    # Use the NpzFile as a context manager so its file handle is closed.
    with np.load(file, allow_pickle=True) as npz:
        attack_results = npz['data'].item()
    input_images = attack_results['input_images'].reshape(-1, 28, 28)
    input_labels = np.argmax(attack_results['input_labels'], axis=1)
    adv_images = attack_results['conf_adversarial_images'][0].reshape(-1, 28, 28)
    adv_labels = attack_results['conf_adversarial_labels'][0]
    adv_mse = attack_results['conf_mean_squared_distances'][0]
    adv_mse_means = np.mean(adv_mse)
    adv_mse_stds = np.std(adv_mse)
    return input_images, input_labels, adv_images, adv_labels, adv_mse, adv_mse_means, adv_mse_stds

def get_cifar(results_dict, key):
    """Unpack one model's (grayscale) CIFAR attack results from a dict of per-model results.

    Returns:
        (input_images, input_labels, adv_images, adv_labels, adv_mse,
        adv_mse_mean, adv_mse_std), with images reshaped to (-1, 32, 32).
    """
    record = results_dict[key]
    originals = record['orig_img'].reshape(-1, 32, 32)
    original_labels = record['orig_label']
    adversarials = record['adv_img'].reshape(-1, 32, 32)
    target_labels = record['target_label']
    mses = record['target_adv_mses']
    return (
        originals,
        original_labels,
        adversarials,
        target_labels,
        mses,
        np.mean(mses),
        np.std(mses),
    )

def get_adv_data(results_files):
    """Load MNIST and CIFAR attack results into parallel nested lists.

    ``results_files[0]`` is the MNIST nested list of ``.npz`` paths indexed
    [mlp/lca][neurons][2L/3L]; its nesting also drives the CIFAR loop, where each
    entry is fetched from ``results_files[1]`` (a dict) using the module-level
    ``cifar_keys`` with the same [mlp/lca][neurons][2L/3L] indexing.

    Returns:
        A 7-tuple (input_images, input_labels, adv_images, adv_labels, adv_mse,
        adv_mse_means, adv_mse_stds); each element is indexed
        [mnist/cifar][mlp/lca][neurons][2L/3L].
    """
    data_types = ['mnist', 'cifar']
    # One nested container per returned quantity, filled in lockstep.
    collected = [[] for _ in range(7)]
    for data_idx, data_type in enumerate(data_types):
        for container in collected:
            container.append([])
        for model_idx, model_files in enumerate(results_files[0]):
            for container in collected:
                container[-1].append([])
            for neurons_idx, neuron_files in enumerate(model_files):
                for container in collected:
                    container[-1][-1].append([])
                for layers_idx, mnist_file in enumerate(neuron_files):
                    if data_type == 'mnist':
                        outputs = get_mnist(mnist_file)
                    elif data_type == 'cifar':
                        outputs = get_cifar(results_files[1], cifar_keys[model_idx][neurons_idx][layers_idx])
                    for container, value in zip(collected, outputs):
                        container[data_idx][model_idx][neurons_idx].append(value)
    return tuple(collected)

In [None]:
def make_boxplot(ax, group_data, group_means, x_pos, bar_width, linewidth=2, color='k', plot_means=True, plot_medians=False):
  """Draw one boxplot per dataset in ``group_data`` at the matching ``x_pos``.

  Boxes, whiskers and caps use ``color``. Whiskers span the 5th to 95th
  percentiles, and outliers are not drawn (``sym=''``). The median is drawn as a
  dashed black line. The mean is drawn as a solid black line when ``plot_means``
  is True.

  Args:
    ax: matplotlib Axes to draw into.
    group_data: sequence of datasets, one box each.
    group_means: not drawn (matplotlib computes the mean itself). It is zipped
      with ``group_data`` and ``x_pos``, so only its length matters.
    x_pos: x position of each box.
    bar_width: box width.
    linewidth: line width for boxes, whiskers, caps, medians and means.
    color: color of boxes, whiskers and caps.
    plot_means: whether to draw the mean line.
    plot_medians: currently unused; the median line is always drawn.

  Returns:
    ``ax``.
  """
  boxprops = dict(linestyle='-', linewidth=linewidth, color=color)
  whiskerprops = boxprops
  capprops = boxprops
  medianprops = dict(linestyle='--', linewidth=linewidth, color='k')
  meanprops = dict(linestyle='-', linewidth=linewidth, color='k')
  for data, _mean, pos in zip(group_data, group_means, x_pos): # indexing 4 neurons;layers comparisons
    ax.boxplot(data, sym='', positions=[pos], whis=(5, 95), widths=bar_width, meanline=True,
      showmeans=plot_means, boxprops=boxprops, whiskerprops=whiskerprops, capprops=capprops,
      medianprops=medianprops, meanprops=meanprops)
  return ax

def plot_data_mse(ax, mses, means, stds, num_groups, num_per_group, bar_groups, bar_width, COLORS,
             inner_group_names, outer_group_names, title, labelrotation=50):
  """Draw grouped boxplots of input-to-adversarial mean squared distances.

  Args:
    ax: matplotlib Axes to draw into.
    mses: nested list indexed [model][neurons][layers]. The first axis is
      reversed before plotting, so with [mlp, lca] input, group 0 is the LCA model.
    means: same nesting as ``mses``; forwarded to make_boxplot.
    stds: same nesting as ``mses``. It is accepted for interface compatibility
      but not drawn.
    num_groups: number of outer groups (x tick positions).
    num_per_group: number of boxes per outer group (one per model).
    bar_groups: unused; kept for interface compatibility.
    bar_width: box width and offset between boxes within a group.
    COLORS: one color per inner group.
    inner_group_names: legend labels, one per inner group.
    outer_group_names: x tick labels, one per outer group.
    title: axes title.
    labelrotation: x tick label rotation in degrees. It is explicit rather
      than read from a notebook global; the default (50) matches the notebook setting.

  Returns:
    ``ax``.
  """
  linewidth = 1
  for i_g in range(num_per_group):
    # Flatten [neurons][layers] into one list of boxes for this inner group.
    group_data = [item for sublist in mses[::-1][i_g] for item in sublist]
    group_means = [item for sublist in means[::-1][i_g] for item in sublist]
    x_pos = np.arange(num_groups) + i_g * bar_width
    ax = make_boxplot(ax, group_data, group_means, x_pos, bar_width, linewidth, COLORS[i_g])
  # One thick legend swatch per inner group.
  legend_elements = [Line2D([0], [0], color=COLORS[i_g], lw=8) for i_g in range(num_per_group)]
  legend = ax.legend(legend_elements, inner_group_names, framealpha=1.0)
  legend.get_frame().set_linewidth(0.0)
  # Centre each tick between the two boxes of its group.
  ax.set_xticks([r + (bar_width)/2 for r in range(num_groups)])
  ax.set_xticklabels(outer_group_names)
  ax.set_ylabel("Input to Adversarial MSD")
  ax.set_xlabel('Number of Layers and Neurons')
  ax.tick_params("x", labelrotation=labelrotation)
  ax.spines["top"].set_visible(False)
  ax.spines["right"].set_visible(False)
  ylim = ax.get_ylim()
  ax.set_ylim([0, ylim[1]])
  ax.set_title(title)
  return ax

def plot_adv_robustness(mse_list, mean_list, std_list, num_groups, num_per_group, bar_groups, bar_width,
                        colors, inner_group_names, outer_group_names, titles, text_width=200,
                        width_ratio=1.0, dpi=100):
  """Draw side-by-side MNIST / CIFAR boxplots of input-to-adversarial mean squared distances.

  Args:
    mse_list, mean_list, std_list: [mnist, cifar] pairs. Each entry uses the
      nested layout that plot_data_mse expects.
    num_groups, num_per_group, bar_groups, bar_width, colors, inner_group_names:
      forwarded to plot_data_mse for both panels.
    outer_group_names: [mnist_names, cifar_names] x tick labels.
    titles: [mnist_title, cifar_title].
    text_width, width_ratio: forwarded to nc.set_size for the figure size.
    dpi: figure resolution.

  Returns:
    The matplotlib Figure, already shown with plt.show(). Only the CIFAR panel
    keeps its legend.
  """
  mnist_mse, cifar_mse = mse_list
  mnist_means, cifar_means = mean_list
  mnist_stds, cifar_stds = std_list
  mnist_outer_group_names, cifar_outer_group_names = outer_group_names
  mnist_title, cifar_title = titles
  num_y_plots = 2
  num_x_plots = 2
  fig = plt.figure(figsize=nc.set_size(text_width, width_ratio, [num_y_plots, num_x_plots]), dpi=dpi)
  gs_base = plt.GridSpec(num_y_plots, num_x_plots, figure=fig, wspace=0.3)
  ax_mnist_mse = fig.add_subplot(gs_base[:, 0])
  ax_mnist_mse = plot_data_mse(ax_mnist_mse, mnist_mse, mnist_means, mnist_stds, num_groups, num_per_group,
    bar_groups, bar_width, colors, inner_group_names, mnist_outer_group_names, mnist_title)
  mnist_legend = ax_mnist_mse.get_legend()
  mnist_legend.set_visible(False)
  ax_cifar_mse = fig.add_subplot(gs_base[:, 1])
  ax_cifar_mse = plot_data_mse(ax_cifar_mse, cifar_mse, cifar_means, cifar_stds, num_groups, num_per_group,
    bar_groups, bar_width, colors, inner_group_names, cifar_outer_group_names, cifar_title)
  ax_cifar_mse.set_ylabel("")
  # matplotlib 3.5 renamed grid()'s `b` keyword to `visible`, and the old name was
  # later removed. Passing the flag positionally works on old and new versions.
  ax_mnist_mse.grid(False, which='both', axis='both')
  ax_cifar_mse.grid(False, which='both', axis='both')
  plt.show()
  return fig

In [None]:
# Locations of the saved adversarial-attack results for the MNIST and CIFAR
# MLP / LCA models compared below.
# adv stop confidence
#stop_conf = 0.90#.95

# path to projects directory
projects_dir = root_path+'/Projects/'
analysis_dir = '/analysis/0.0/'

# kurakin analysis path
# Only the k_file_path* files are read in this cell (via mnist_files); the image
# paths and the Carlini paths below are defined but not used here.
#k_file_path = analysis_dir+'savefiles/class_adversary_analysis_test_kurakin_targeted.npz'
#k_img_path = analysis_dir+'savefiles/class_adversary_images_analysis_test_kurakin_targeted.npz'
k_file_path = analysis_dir+'savefiles/class_adversary_analysis_test_temp_kurakin_targeted.npz'
k_img_path = analysis_dir+'savefiles/class_adversary_images_analysis_test_temp_kurakin_targeted.npz'

k_file_path2 = analysis_dir+'savefiles/class_adversary_analysis_test_temp2_kurakin_targeted.npz'
k_img_path2 = analysis_dir+'savefiles/class_adversary_images_analysis_test_temp2_kurakin_targeted.npz'

# carlini analysis path
c_file_path = analysis_dir+'savefiles/class_adversary_analysis_test_carlini_targeted.npz'
c_img_path = analysis_dir+'savefiles/class_adversary_images_analysis_test_carlini_targeted.npz'

# model names - note mnist 768 2layer was retrained and so image/label indices will not match up
mnist_lca_768_2layer = 'slp_lca_768_latent_mnist'#'slp_lca_768_latent_75_steps_mnist'
mnist_lca_768_3layer = 'mlp_lca_768_latent_75_steps_mnist'
mnist_lca_1568_2layer = 'slp_lca_1568_latent_mnist'#'slp_lca_1568_latent_75_steps_mnist'
mnist_lca_1568_3layer = 'mlp_lca_1568_latent_75_steps_mnist'
mnist_mlp_768_2layer = 'mlp_768_mnist'#'mlp_cosyne_mnist'
mnist_mlp_768_3layer = 'mlp_3layer_cosyne_mnist'
mnist_mlp_1568_2layer = 'mlp_1568_mnist'
mnist_mlp_1568_3layer = 'mlp_1568_3layer_mnist'

cifar_lca_1568_2layer = 'mlp_lca_latent_cifar10_gray_2layer'
cifar_lca_1568_3layer = 'mlp_lca_latent_cifar10_gray_3layer'
cifar_lca_3136_2layer = 'mlp_lca_latent_cifar10_gray_3136_2layer'
cifar_lca_3136_3layer = 'mlp_lca_latent_cifar10_gray_3136_3layer'
cifar_mlp_1568_2layer = 'mlp_cifar10_gray_2layer'
cifar_mlp_1568_3layer = 'mlp_cifar10_gray_3layer'
cifar_mlp_3136_2layer = 'mlp_cifar10_gray_3136_2layer'
cifar_mlp_3136_3layer = 'mlp_cifar10_gray_3136_3layer'

# Figures from this section are written to the 2-layer 768-neuron MNIST MLP model's vis/ directory.
output_dir = projects_dir+mnist_mlp_768_2layer+analysis_dir+'vis/'

#[mlp/lca][768/1568][2L/3L]
# NOTE(review): the 1568-neuron 2-layer LCA model reads the 'temp2' results file
# (k_file_path2) while every other model reads 'temp'; confirm this is intended.
mnist_files = [
    [ # mlp
        [ # 768
            projects_dir+mnist_mlp_768_2layer+k_file_path,
            projects_dir+mnist_mlp_768_3layer+k_file_path
        ], [ # 1568
            projects_dir+mnist_mlp_1568_2layer+k_file_path,
            projects_dir+mnist_mlp_1568_3layer+k_file_path
        ]
    ], [ # lca
        [ # 768
            projects_dir+mnist_lca_768_2layer+k_file_path,
            projects_dir+mnist_lca_768_3layer+k_file_path
        ], [ # 1568
            projects_dir+mnist_lca_1568_2layer+k_file_path2,
            projects_dir+mnist_lca_1568_3layer+k_file_path
        ]
        
    ]
]

#[mlp/lca][1568/3136][2L/3L]
# Model names used as keys into the unpickled CIFAR results dict.
cifar_keys = [
    [ # mlp
        [ # 1568
            cifar_mlp_1568_2layer, cifar_mlp_1568_3layer
        ], [ # 3136
            cifar_mlp_3136_2layer, cifar_mlp_3136_3layer
        ]
    ], [ # lca
        [ # 1568
            cifar_lca_1568_2layer, cifar_lca_1568_3layer
        ], [ # 3136
            cifar_lca_3136_2layer, cifar_lca_3136_3layer
        ]
        
    ]
]

#Load data
# pickle.load can execute arbitrary code: only load this trusted, locally generated file.
pickle_filename = (root_path+'/DeepSparseCoding/tf1x/vis/'
    +'vis_class_adversarial_analysis.pkl')#CIFAR10_adv_Sheng.pkl'
with open(pickle_filename, 'rb') as f:
    cifar_saved_info = pickle.load(f)
    
# file_lists[0] (MNIST) is the nested list of .npz paths indexed [mlp/lca][768/1568][2L/3L];
# file_lists[1] (CIFAR) is the unpickled dict, looked up with the model names in cifar_keys
# ([mlp/lca][1568/3136][2L/3L]).
file_lists = [mnist_files, cifar_saved_info]

In [None]:
#[mnist/cifar][mlp/lca][768/1568][2L/3L]
# Every returned value is a nested list indexed [mnist/cifar][mlp/lca][neurons][2L/3L];
# the neuron sizes are 768/1568 for MNIST and 1568/3136 for CIFAR (see cifar_keys).
orig_images, orig_labels, adv_images, target_labels, mses, mse_means, mse_stds = get_adv_data(file_lists)

In [None]:
# Settings for the MNIST / CIFAR adversarial MSD boxplot comparison.
labelrotation = 50  # x tick label rotation, degrees
bar_width = 0.4
# Legend names in plotting order: plot_data_mse reverses the [mlp/lca] axis, so
# index 0 is the LCA model.
inner_group_names = ["w/ LCA", "w/o LCA"]
# x tick labels follow the flattened [neurons][layers] order of the results.
mnist_outer_group_names = ["2L; 768N", "3L; 768N", "2L; 1568N", "3L; 1568N"]
cifar_outer_group_names = ["2L; 1568N", "3L; 1568N", "2L; 3136N", "3L; 3136N"]
COLORS = [
  [1.0, 0.0, 0.0], #"r"
  [0.0, 0.0, 1.0] #"b"
]
# bar_groups are organized from left to right, with space between inner lists
bar_groups = np.array([[0, 1], [2, 3], [4, 5], [6, 7]])
num_groups = bar_groups.shape[0]
num_per_group = bar_groups.shape[1]

outer_group_names = [mnist_outer_group_names, cifar_outer_group_names]
titles = ["MNIST", "Grayscale CIFAR"]

adv_fig = plot_adv_robustness(mses, mse_means, mse_stds, num_groups, num_per_group, bar_groups, bar_width,
    COLORS, inner_group_names, outer_group_names, titles, text_width, width_ratio=1.0, dpi=dpi)

# Save one copy per file extension into output_dir.
#for analyzer in analyzer_list:
for ext in file_extensions:
    save_name = (output_dir+'/adv_mse_comparison_boxplots'+ext)
    adv_fig.savefig(save_name, transparent=False, bbox_inches='tight', pad_inches=0.05, dpi=dpi)

In [None]:
def convert_cifar_label_set(labels):
    """Map every label in ``labels`` to its CIFAR-10 class name (None if unrecognized)."""
    return list(map(convert_cifar_label, labels))

def convert_cifar_label(label):
    """Return the CIFAR-10 class name for ``label``.

    ``label`` may be the class index as a number (anything equal to 0-9) or as
    its decimal string ("0"-"9"). Any other value yields None.
    """
    class_names = ("airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck")
    for class_idx, class_name in enumerate(class_names):
        if label == class_idx or label == str(class_idx):
            return class_name
    return None

def show_image_with_label(axis, image, clf, vmin=0, vmax=1, cmap="Greys"):
  """Draw ``image`` on ``axis`` with no frame or ticks, optionally annotated.

  Args:
    axis: matplotlib Axes to draw into.
    image: 2-D array passed to ``imshow``.
    clf: label printed in a white box near the top-left corner; nothing is
      drawn when None.
    vmin, vmax: color limits for ``imshow``.
    cmap: colormap name.

  Returns:
    The image handle returned by ``imshow`` (e.g. for attaching a colorbar).
  """
  im = axis.imshow(image, cmap=cmap, interpolation="nearest", vmin=vmin, vmax=vmax)
  for spine in axis.spines.values():
    spine.set_visible(False)
  # Use real booleans. matplotlib deprecated the "on"/"off" strings in 2.2, and
  # later versions treat any non-empty string as truthy, i.e. ticks *on*.
  axis.tick_params(
    axis="both",
    bottom=False,
    top=False,
    left=False,
    right=False)
  axis.set_xticks([])
  axis.set_yticks([])
  props = dict(facecolor='white', alpha=1)
  if clf is not None:
    # (.05, .95) are data coordinates, i.e. inside the top-left pixel for
    # imshow's default (upper) origin.
    axis.text(.05, .95, str(clf), bbox=props,
      verticalalignment='center', color = "black")
  return im

def make_grid_subplots(fig, gs, mlp_grp, lca_grp,
                       labels,
                       #orig_labels, target_labels, group_names,
                       group_name_loc,
                       orig_y_adj, start_idx=0, num_categories=3, crop_amount=0, hspace=0.5, wspace=-0.4,
                       cmap="Greys_r"):
    """
    Fill ``gs`` with one row per example image comparing MLP (w/o LCA) and LCA attacks.

    In each row, columns 0:2 hold the unperturbed image (from ``mlp_grp[0][0]``).
    Columns 2:4 and 4: each hold a 2x2 sub-grid for model_idx 0 and 1. In a
    sub-grid, row 0 holds the adversarial images (``*_grp[1][model_idx]``) and
    row 1 the perturbation images (``*_grp[2][model_idx]``). Column 0 is the MLP
    (w/o LCA) and column 1 the LCA model. ``gs`` presumably has
    ``num_categories`` rows and 6 columns — TODO confirm for other callers.

    mlp_grp, lca_grp: indexed [orig/adv/diff][model_idx][image_idx, ...].
    labels [mlp/lca][orig/target/plot][2L768N/2L1568N]
    group_name_loc: (x, y) data-coordinate position of each model's group name,
        drawn above the first row's top-left sub-grid image.
    orig_y_adj: ``y`` passed to the "Unperturbed" title.
    start_idx: image index shown in the first row (row i shows start_idx + i).
    crop_amount: pixels cropped from every border of each image (skimage ``crop``).
    hspace, wspace: spacing of the inner 2x2 sub-grids.
    cmap: colormap for all images.
    """
    ax_orig_list = []
    gs_sub0_list = []
    gs_sub1_list = []
    for i in range(num_categories):
        ax_orig_list.append(fig.add_subplot(gs[i, :2]))
        gs_sub0_list.append(gridspec.GridSpecFromSubplotSpec(2, 2, gs[i, 2:4], hspace=hspace, wspace=wspace))
        gs_sub1_list.append(gridspec.GridSpecFromSubplotSpec(2, 2, gs[i, 4:], hspace=hspace, wspace=wspace))
    for category_idx in range(num_categories):
        image_idx = category_idx + start_idx
        orig_ax = ax_orig_list[category_idx]
        if category_idx == 0:
            orig_ax.set_title("Unperturbed", y=orig_y_adj)
        # The unperturbed image is taken from the MLP group and annotated with its original label.
        orig_img = crop(np.squeeze(mlp_grp[0][0][image_idx, ...]), crop_amount)
        orig_im_handle = show_image_with_label(orig_ax, orig_img, labels[0][0][0][image_idx], cmap=cmap)
        for model_idx, gs_sub in enumerate([gs_sub0_list[category_idx], gs_sub1_list[category_idx]]):
            mlp_adv_img = crop(np.squeeze(mlp_grp[1][model_idx][image_idx, ...]), crop_amount)
            mlp_diff_img = crop(np.squeeze(mlp_grp[2][model_idx][image_idx, ...]), crop_amount)
            lca_adv_img = crop(np.squeeze(lca_grp[1][model_idx][image_idx, ...]), crop_amount)
            lca_diff_img = crop(np.squeeze(lca_grp[2][model_idx][image_idx, ...]), crop_amount)
            # Both perturbation images share one color range so they are directly comparable.
            diff_vmin = np.min(np.concatenate((mlp_diff_img, lca_diff_img)))
            diff_vmax = np.max(np.concatenate((mlp_diff_img, lca_diff_img)))
            img_list = [[mlp_adv_img, lca_adv_img], [mlp_diff_img, lca_diff_img]]
            for i in range(2): # row [adv, pert]
                for j in range(2): # col [w/o LCA, w/ LCA]
                    current_ax = fig.add_subplot(gs_sub[i, j])
                    current_target_label = None
                    current_image = img_list[i][j]
                    if i == 0:
                        # Adversarial images use a fixed [0, 1] color range.
                        vmin = 0.0
                        vmax = 1.0
                        if j == 0: # top left image
                            if model_idx == 0:
                                current_ax.set_ylabel(r"$s^{*}$")
                            # Only the MLP adversarial image is annotated with the target label (labels[0]).
                            current_target_label = labels[j][1][model_idx][image_idx]
                            if category_idx == 0: # top category only
                                x_loc = group_name_loc[0]
                                y_loc = group_name_loc[1]
                                text_handle = current_ax.text(x_loc, y_loc, labels[j][2][model_idx],
                                    horizontalalignment='left', verticalalignment='bottom')
                    if i == 1:
                        # Perturbation color range rounded to 2 decimals (also used as colorbar ticks).
                        vmin = np.round(diff_vmin, 2)
                        vmax = np.round(diff_vmax, 2)
                        if j == 0 and model_idx == 0:
                            current_ax.set_ylabel(r"$e$")
                        if j == 0 and category_idx == num_categories-1: # bottom left
                            current_ax.set_xlabel("w/o\nLCA")
                        elif j == 1 and category_idx == num_categories-1: # bottom right
                            current_ax.set_xlabel("w/\nLCA")
                    im_handle = show_image_with_label(current_ax, current_image, current_target_label, vmin=vmin, vmax=vmax, cmap=cmap)
                    # Colorbar only on the right-hand (LCA) column, ticked at the color-range ends.
                    if j == 1:
                      pf.add_colorbar_to_im(im_handle, aspect=10, ax=current_ax, ticks=[vmin, vmax])

def plot_adv_images(image_groups, labels, mnist_start_idx, cifar_start_idx, text_width=200, width_ratio=1.0, dpi=100):
    """
    Draw one MNIST row and one CIFAR row of adversarial examples (w/o LCA vs. w/ LCA).

    Params
    ------
    image_groups: list
        [mnist/cifar][mlp/lca][orig/adv/diff][model] nested image arrays
    labels: list
        labels[0] / labels[1] are forwarded to make_grid_subplots for MNIST / CIFAR
    mnist_start_idx, cifar_start_idx: int
        index of the first image shown for each dataset
    text_width: float
        text width in points, forwarded to nc.set_size
    width_ratio: float
        fraction of text_width used by the figure
    dpi: int
        figure resolution

    Returns
    -------
    fig: the created matplotlib figure
    """
    (mnist_mlp_grp, mnist_lca_grp), (cifar_mlp_grp, cifar_lca_grp) = image_groups[0], image_groups[1]
    base_spacing = dict(wspace=0.3, hspace=-0.3) # between the MNIST and CIFAR rows
    row_spacing = dict(hspace=0.0, wspace=3.3) # between the panels of one dataset row
    sub_spacing = dict(hspace=-0.7, wspace=0.2) # inside each 2x2 adversarial sub-grid
    orig_y_adj = 1.15
    img_label_loc = [-8.0, -8.0] # [x, y]
    num_rows, num_cols = 2, 1
    fig = plt.figure(figsize=nc.set_size(text_width, width_ratio, [num_rows, num_cols]), dpi=dpi)
    gs_base = plt.GridSpec(num_rows, num_cols, figure=fig, **base_spacing)
    num_categories = 1

    dataset_rows = [
        (mnist_mlp_grp, mnist_lca_grp, mnist_start_idx, {"cmap": "Greys"}),
        (cifar_mlp_grp, cifar_lca_grp, cifar_start_idx, {"crop_amount": 2, "cmap": "Greys_r"}),
    ]
    for row_idx, (mlp_grp, lca_grp, start_idx, extra_kwargs) in enumerate(dataset_rows):
        gs_row = gridspec.GridSpecFromSubplotSpec(num_categories, 6, gs_base[row_idx], **row_spacing)
        make_grid_subplots(fig, gs_row, mlp_grp, lca_grp,
            labels[row_idx], img_label_loc, orig_y_adj, start_idx,
            num_categories, **sub_spacing, **extra_kwargs)

    plt.show()
    return fig

In [None]:
mnist_start_idx = 37
cifar_start_idx = 16

mnist_img_labels = [mnist_outer_group_names[0], mnist_outer_group_names[2]] # only looking at 2L models
cifar_img_labels = [cifar_outer_group_names[0], cifar_outer_group_names[2]]

# index constants for the nested result containers
mnist, cifar = 0, 1
mlp, lca = 0, 1
small, large = 0, 1
shallow, deep = 0, 1

# Regroup the images
# from [orig/adv][mnist/cifar][mlp/lca][768/1568][2L/3L]
# to   [mnist/cifar][mlp/lca][orig/adv/diff][2L768N/2L1568N]
# (only the shallow models are used; the last level runs over the two layer widths)
image_groups = [
    [
        [
            [orig_images[dset][model][size][shallow] for size in (small, large)],
            [adv_images[dset][model][size][shallow] for size in (small, large)],
            [orig_images[dset][model][size][shallow] - adv_images[dset][model][size][shallow]
                for size in (small, large)],
        ]
        for model in (mlp, lca)
    ]
    for dset in (mnist, cifar)
]

# Same layout for the labels: [mnist/cifar][mlp/lca][orig/adv, img_labels];
# CIFAR label sets are converted with convert_cifar_label_set, MNIST ones are used as-is
label_groups = [
    [
        [
            [to_label_set(orig_labels[dset][model][size][shallow]) for size in (small, large)],
            [to_label_set(target_labels[dset][model][size][shallow]) for size in (small, large)],
            img_labels,
        ]
        for model in (mlp, lca)
    ]
    for dset, to_label_set, img_labels in (
        (mnist, lambda label_set: label_set, mnist_img_labels),
        (cifar, convert_cifar_label_set, cifar_img_labels),
    )
]

adv_img_fig = plot_adv_images(image_groups, label_groups, mnist_start_idx, cifar_start_idx, text_width,
    width_ratio=1.0, dpi=dpi)

In [None]:
# Save the adversarial example image figure once per requested file extension.
for ext in file_extensions:
    save_name = (output_dir+'/adv_mse_comparison_example_images'+ext)
    adv_img_fig.savefig(save_name, transparent=False, bbox_inches="tight", pad_inches=0.05, dpi=dpi)

# Supplementary Figures/Analysis

### Histogram the MSE datapoints

In [None]:
def plot_average_conf_step(analysis_files, model_names):
    """
    Plot, per attack step, the image-averaged maximum classifier output for each model.

    Step 0 is dropped. The maximum is taken over all classes (not only the target class).

    Params
    ------
    analysis_files: list of str
        analysis .npz files holding a pickled dict under "data" with an "adversarial_outputs" entry
    model_names: list of str
        legend label for each file

    Returns
    -------
    fig: the created matplotlib figure
    """
    fig, ax = plt.subplots()
    for analysis_file, model_name in zip(analysis_files, model_names):
        results = np.load(analysis_file, allow_pickle=True)["data"].item()
        outputs = np.squeeze(results['adversarial_outputs'])
        step_conf = np.max(outputs, axis=-1).mean(axis=-1)[1:] # average over the last remaining axis
        ax.plot(step_conf, label=model_name)
        print(f'Maximum confidence: {step_conf.max()}')
    ax.legend(loc="lower right", framealpha=1.0)
    ax.set_xlabel("Attack Step")
    ax.set_ylabel("Mean Target-class Confidence")
    ax.set_ylim(0, 1.01)
    return fig

In [None]:
# Mean classifier confidence per attack step for the MNIST models.
# mnist_files is indexed [mlp/lca][768/1568][2L/3L]
lca_files = [mnist_files[1][0][0]]
mlp_files = [mnist_files[0][0][0]]
files = mlp_files + lca_files
names = ['w/o LCA', 'w/ LCA']

conf_fig = plot_average_conf_step(files, names)
for ext in file_extensions:
    # Use the same output_dir + '/' layout as the example-image figure, but a distinct filename
    # so that figure is not overwritten.
    save_name = (output_dir+'/adv_conf_per_attack_step'+ext)
    conf_fig.savefig(save_name, transparent=False, bbox_inches="tight", pad_inches=0.01, dpi=dpi)

In [None]:
def find_targeted_conf_index(analysis, stop_conf): # for targeted attacks
    """
    For every image, find the first attack step at which a non-true class reaches stop_conf.

    Params
    ------
    analysis: dict
        results dict with "adversarial_outputs" (assumed to squeeze to (steps, images, classes))
        and one-hot "input_labels" (assumed to squeeze to (images, classes))
    stop_conf: float
        confidence threshold, on the same scale as the classifier outputs

    Returns
    -------
    stop_indices: list of int
        one entry per image: the first step >= 1 where the largest non-true-class output is
        >= stop_conf, or -1 if that never happens (indexing with -1 selects the final step).
        Step 0 is never returned.
    """
    outputs = np.squeeze(analysis["adversarial_outputs"])
    inv_true_labels = 1 - np.squeeze(analysis['input_labels'])[None, ...] # add time step dimension
    # Out-of-place product: an in-place `*=` on the squeezed view would overwrite the caller's data.
    confs = np.max(outputs * inv_true_labels, axis=-1) # true class zeroed out; (steps, images)
    crossed = confs[1:, :] >= stop_conf # step 0 is excluded from the search
    stop_indices = []
    for image_idx in range(confs.shape[1]):
        crossing_steps = np.nonzero(crossed[:, image_idx])[0]
        # +1 converts back from the step-0-excluded view to absolute step indices
        stop_indices.append(int(crossing_steps[0]) + 1 if crossing_steps.size > 0 else -1)
    return stop_indices

def get_rects(filename_list, metric, stop_conf):
    """
    Collect metric values at the stopping indices returned by find_targeted_conf_index.

    For every file and every index `stop_step` returned by find_targeted_conf_index, the
    slice metrics[metric][0, stop_step, :] (over the last axis) is stored together with its
    mean and standard deviation.

    Params
    ------
    filename_list: list of str
        analysis .npz files holding a pickled dict under "data"
    metric: str
        key of the analysis array to read
    stop_conf: float
        confidence threshold forwarded to find_targeted_conf_index

    Returns
    -------
    data, means, stds: flat lists with one entry per (file, stop index) pair
    """
    data, means, stds = [], [], []
    for filename in filename_list:
        metrics = np.load(filename, allow_pickle=True)["data"].item()
        for stop_step in find_targeted_conf_index(metrics, stop_conf):
            metric_vals = metrics[metric]
            selected = metric_vals[0, stop_step, np.arange(metric_vals.shape[-1])]
            data.append(selected)
            means.append(np.mean(selected))
            stds.append(np.std(selected))
    return data, means, stds

def get_mses(file_list, stop_conf=0.90, max_images=100):
    """
    Collect per-image adversarial mean squared distances at the step where each attack stops.

    Two file layouts are supported:
      * "input_adv_mses" + "adversarial_outputs": for each image, the stopping step is the first
        step whose largest classifier output exceeds stop_conf. When "input_labels" is present,
        the true class is excluded, so the unperturbed prediction does not count as a success.
        Images that never exceed the threshold use the final step.
      * "conf_mean_squared_distances" + "conf_adversarial_time_step": the stored per-image
        stopping step is used directly.

    Params
    ------
    file_list: list of str
        analysis .npz files holding a pickled dict under "data"
    stop_conf: float
        confidence threshold on the classifier-output scale (softmax outputs in [0, 1])
    max_images: int
        only the first max_images images of each file are kept

    Returns
    -------
    data, means, stds: lists with one entry per file (per-image distances, their mean and std)

    Raises
    ------
    KeyError: if a file contains neither supported layout
    """
    data = []; means = []; stds = []
    for file in file_list:
        results = np.load(file, allow_pickle=True)['data'].item()
        if 'input_adv_mses' in results:
            mse_results = np.squeeze(results['input_adv_mses']) # assumed (steps, images)
            conf_results = np.squeeze(results['adversarial_outputs']) # assumed (steps, images, classes)
            if 'input_labels' in results:
                # zero the true class so that a confident original prediction is not a "success"
                conf_results = conf_results * (1 - np.squeeze(results['input_labels'])[None, ...])
            num_steps, num_images = mse_results.shape
            max_confs = np.max(conf_results, axis=-1)
            mse = np.empty(num_images)
            for image in range(num_images):
                crossed = np.nonzero(max_confs[:, image] > stop_conf)[0]
                # keep each image aligned with its own stop time; fall back to the last step
                stop_time = crossed[0] if crossed.size > 0 else num_steps - 1
                mse[image] = mse_results[stop_time, image]
        elif 'conf_mean_squared_distances' in results:
            stop_time = np.asarray(
                np.squeeze(results['conf_adversarial_time_step'][0])[:max_images], dtype=int)
            all_mses = np.squeeze(results['conf_mean_squared_distances'])
            # pair every image with its own stopping step (not an outer product of steps x images)
            mse = all_mses[stop_time, np.arange(stop_time.size)]
        else:
            raise KeyError(f'{file} has neither "input_adv_mses" nor "conf_mean_squared_distances"')
        mse = mse[:max_images]
        data.append(mse)
        means.append(np.mean(mse))
        stds.append(np.std(mse))
    return data, means, stds

def get_mnist_mse_data(file_lists, metric, stop_conf):
    """
    Gather adversarial mean squared distances for several groups of model files.

    Params
    ------
    file_lists: list
        list of list with organization
        inner list: model-type filenames (i.e. lca or mlp)
        outer lists: groups across models (i.e. layer depth)
    metric: str
        the analysis metric to average over (currently not forwarded: get_mses is
        called with the file list only)
    stop_conf: float
        the desired classifier confidence for stopping the adversarial attack (currently
        not forwarded either)

    Returns
    -------
    (data, means, stds): tuple of lists, each indexed [group][file]
    """
    per_group = [get_mses(file_list) for file_list in file_lists]
    data = [group_data for group_data, _, _ in per_group]
    means = [group_means for _, group_means, _ in per_group]
    stds = [group_stds for _, _, group_stds in per_group]
    return (data, means, stds)

def multi_model_compare(ax, data, means, stds, colors, names, xtick_labels,
                        xlabel, ylabel, ylim, width, title, fontsize):
    """
    Grouped bar chart with the individual data points drawn on top of each bar.

    Params
    ------
    ax: matplotlib axis to draw on
    data: list
        data[i][j] is the sequence of points for bar series i at x position j; all data[i][j]
        are stacked with np.array, so they are assumed to have equal length
    means: list
        means[i] holds the bar heights of series i (one per x position)
    stds: list
        stds[i] is passed to ``ax.bar`` as ``yerr`` for series i
    colors: list
        colors[i] is the color (or per-bar list of colors) of series i
    names: list
        legend entry per series
    xtick_labels: list
        one label per x position
    xlabel, ylabel, title: str
    ylim: float
        upper y limit (the lower limit is 0)
    width: float
        bar width
    fontsize: float

    Returns
    -------
    data: the input ``data``, unchanged
    """
    # organize the data
    cmap_gray = plt.get_cmap("gray") # matplotlib.cm.get_cmap was removed in matplotlib 3.9
    num_series = len(data) # N: number of bar series (one bar per x position each)
    num_positions = len(data[0]) # M: number of x positions / tick labels
    bar_step = width + .01 # horizontal offset between consecutive series
    # create the bar chart
    ind = np.arange(num_positions) # the x locations of the first series
    rects = []
    for i in range(num_series):
        # the bars
        x = ind + i * bar_step
        rect = ax.bar(x, means[i], color=colors[i], yerr=stds[i], width=width, alpha=1.0)
        rects.append(rect)
        # the data points, drawn slightly right of each bar's center
        bar_data = np.array(data[i])
        x_tiled = np.tile(x+(width/4), (bar_data.shape[-1],1)).T
        ax.scatter(x_tiled, bar_data, color='black', alpha=1.0, s=1, zorder=2)
    # center each tick under its group of num_series bars (the span depends on N, not M)
    ax.set_xticks(ind + ((num_series-1)*bar_step)/2)
    ax.set_xticklabels(xtick_labels, fontsize=fontsize)
    ax.set_ylabel(ylabel, fontsize=fontsize)
    ax.set_xlabel(xlabel, fontsize=fontsize)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')
    ax.yaxis.grid(which="major", color=cmap_gray(.8), linestyle='--', linewidth=1)
    ax.tick_params("both", labelsize=fontsize)
    ax.set_axisbelow(True)
    ax.set_ylim([0, ylim])
    ax.set_title(title, fontsize=fontsize)
    ax.legend([r[0] for r in rects], names, fontsize=fontsize, loc='upper right', framealpha=1.0)
    return data

def adv_mse_comparison_plot(ax, file_lists, metric, stop_conf,
                            colors, names, xtick_labels, xlabel, ylabel, ylim, width, title, fontsize):
    """
    Load adversarial distances for every model in file_lists and draw them as a grouped bar chart.

    Params
    ------
    ax: matplotlib axis
        axis to draw on
    file_lists: list
        list of list with organization
        inner list: model-type filenames (i.e. lca or mlp)
        outer lists: groups across models (i.e. layer depth)
    metric: str
        the analysis metric to average over
    stop_conf: float
        the desired classifier confidence for stopping the adversarial attack
    colors: list
        matplotlib color codes, one entry per outer group of file_lists
    names: list
        legend names (one per outer group)
    xtick_labels: list
        one label per file in an inner list
    xlabel, ylabel, title: str
        axis labels and title
    ylim: float
        upper y limit
    width: float
        bar width
    fontsize: float
        font size for labels, ticks, title and legend

    Returns
    -------
    The per-image data returned by multi_model_compare.
    """
    mse_data = get_mnist_mse_data(file_lists, metric, stop_conf)
    data, means, stds = mse_data
    return multi_model_compare(ax, data, means, stds, colors, names, xtick_labels,
        xlabel=xlabel, ylabel=ylabel, ylim=ylim, width=width, title=title, fontsize=fontsize)


In [None]:
# Analysis files for the MLP / LISTA / LCA comparison on MNIST.
# mnist_files is indexed [mlp/lca][768/1568][2L/3L]; [x][0][0] selects the 768-unit 2-layer entry.
mlp_file = mnist_files[0][0][0]#mnist_mlp_768_2layer#"mlp_cosyne_mnist"
lca_file = mnist_files[1][0][0]#mnist_lca_768_2layer#"slp_lca_768_latent_75_steps_mnist"
lista = 'slp_lista_768_5_layers_mnist' # LISTA model directory name inside projects_dir
# Kurakin targeted-attack analysis file, relative to the model directory
lista_k_file_path = '/analysis/0.0/savefiles/class_adversary_analysis_test_kurakin_targeted.npz'

lista_file = projects_dir + lista + lista_k_file_path
# NOTE(review): lista_img_file (used by the transferability cells) is not built anywhere visible;
# the line below is commented out and relies on `path` and `k_img_path`, which are not defined here.
#lca_img_file, lista_img_file, mlp_img_file = (path + model_name + k_img_path for model_name in [lca, lista, mlp])

In [None]:
# One bar series (outer list) with three bars: MLP, w/ LISTA, w/ LCA.
colors = [[color_vals['md_blue'], color_vals['md_green'], color_vals['md_red']]]
xtick_labels = ['MLP', 'w/ LISTA', 'w/ LCA']
file_lists = [[mlp_file, lista_file, lca_file]] #, lca_1568_files]

metric = "input_adv_mses"
names = [None] # single series; the legend is hidden below

stop_conf=.90

fig, ax = plt.subplots()
data_points = adv_mse_comparison_plot(ax, file_lists, metric, stop_conf, colors, names,
                                     xtick_labels, "", "Mean Squared Distance", .08, width=.5,
                                     title="Adversarial MNIST at 90% Confidence", fontsize=fontsize)
ax.get_legend().set_visible(False)
plt.show()

# NOTE(review): the disabled save code below refers to `path`, which is not defined in this view.
#out_list = [path + model_name + "/analysis/0.0/vis/mlp_lista_lca_adv_comparison"
#  for model_name in [mnist_mlp_768_2layer, lista, mnist_lca_768_2layer]]
#for out_name in out_list:
#  for ext in [".png", ".eps"]:
#    save_name = out_name+ext
#    fig.savefig(save_name, transparent=False, bbox_inches="tight", pad_inches=0.01, dpi=dpi)

In [None]:
# Histogram of per-image adversarial distances for two LCA models.
# lca_768_files is defined outside this view; presumably [2-layer file, 3-layer file] given the
# legend labels below - TODO confirm. [0][0] picks `data` and then the single file group.
lca_data = get_mnist_mse_data([lca_768_files], metric, stop_conf)[0][0]
n_bins = 25

plt.figure()
# Frequencies are normalised by the number of images.
# NOTE(review): plt.bar centres each bar on the left bin edge (default align='center') with a fixed
# width of .003; align='edge' with width=np.diff(bins) would match the histogram bins exactly.
hist1, bins = np.histogram(lca_data[0], n_bins)
plt.bar(bins[:n_bins],hist1/len(lca_data[0]), width = .003, alpha=.7, label="lca_2layer")

hist2, bins = np.histogram(lca_data[1], n_bins)
plt.bar(bins[:n_bins],hist2/len(lca_data[1]), width = .003, alpha=.7, label="lca_3layer")

plt.xlabel("Adversarial MSE")
plt.ylabel("Frequency of Images")
plt.legend(framealpha=1.0)
plt.show()

### Average target-class confidence per kurakin attack step

### Average adversarial MSE per carlini attack step

In [None]:
def plot_average_mse_step(analysis_files, recons, confs, colors, title, model_names, bar_width, hatches, figsize, dpi):
    """
    One figure row per attack condition with three panels:
    adversarial confidence vs. attack step, adversarial mean squared distance vs. attack step,
    and a box plot of the per-image distances at the step where the mean confidence first
    exceeds 90%.

    Params
    ------
    analysis_files: list
        analysis_files[condition] is a list of analysis .npz paths, one per model
    recons: list
        per-condition value shown as ``c`` in the box-plot title
    confs: list
        per-condition value shown as kappa in the box-plot title
    colors: list
        colors[file_idx] = [line color, hatch/edge color] for each model
    title: str
        figure suptitle
    model_names: list
        label per model (curves and text under each box)
    bar_width: float
        box width, also used to offset the box positions
    hatches: list
        hatch pattern of each model's +/- std band
    figsize: sequence
        figure size in inches
    dpi: int
        figure resolution

    Returns
    -------
    fig, axes: the figure and a flat list of all axes (three per condition)

    Notes
    -----
    Only two models are drawn in the box plot because float_colors has two entries.
    """
    fig = plt.figure(figsize=figsize, dpi=dpi)
    num_conditions = len(analysis_files)
    # each condition row spans all columns (gs_top[condition, :])
    gs_top = gridspec.GridSpec(num_conditions, num_conditions)
    axes = []
    for condition, (condition_analysis_files, recon, conf) in enumerate(zip(analysis_files, recons, confs)):
        #gs0 = gridspec.GridSpec(1, 2, wspace=0.2, width_ratios = [2, 1])#, hspace=0.3)
        gs0 = gridspec.GridSpecFromSubplotSpec(1, 2, gs_top[condition, :],
            wspace=0.2, width_ratios = [2, 1])#, hspace=0.3)
        left_gs = gridspec.GridSpecFromSubplotSpec(1, 2, gs0[0], wspace=1.3) # two line plots
        right_gs = gridspec.GridSpecFromSubplotSpec(1, 1, gs0[1], wspace=0.9) # box plot
        group_data = []
        group_means = []
        handles = []
        # NOTE: `key` is not used; x_ax_idx 0 plots confidence and x_ax_idx 1 plots distances.
        for x_ax_idx, key in enumerate(['input_adv_mses', 'adversarial_outputs']):
            axes.append(fig.add_subplot(left_gs[x_ax_idx]))
            for file_idx, (file, name) in enumerate(zip(condition_analysis_files, model_names)):
                analysis = np.load(file, allow_pickle=True)["data"].item()
                # maximum classifier output over classes, in percent
                adv_conf = 100*np.max(np.squeeze(analysis['adversarial_outputs']), axis=-1)
                if x_ax_idx == 0:
                    if condition == 0:
                        axes[-1].set_ylabel('Adversarial\nConfidence')
                    axes[-1].axhline(90.0, color='black', linestyle='dashed', linewidth=1) 
                    axes[-1].set_ylim([0, 100.1])
                    # statistics over the last axis, step 0 dropped
                    mean_vals = np.mean(adv_conf, axis=-1)[1:]
                    std_vals = np.std(adv_conf, axis=-1)[1:]
                else:
                    if condition == 0:
                        axes[-1].set_ylabel('Adversarial Mean\nSquared Distance')
                    adv_mse = np.squeeze(analysis['input_adv_mses'])
                    axes[-1].yaxis.set_major_formatter(ticker.FormatStrFormatter('%.3f'))
                    # first step (ignoring steps 0-2) where the mean confidence exceeds 90%;
                    # np.min raises ValueError if no such step exists
                    thresh_indices = np.argwhere(np.mean(adv_conf, axis=-1)>90)
                    first_adv_cross = np.min(thresh_indices[thresh_indices>2]) # first couple are original label
                    # NOTE(review): the curves are plotted at x = range(len(mean_vals)) with step 0
                    # dropped, so x = k is step k+1, while first_adv_cross is an unshifted step index;
                    # this line therefore sits one step right of the crossing.
                    axes[-1].axvline(first_adv_cross, color=colors[file_idx][0], linestyle='dashed', linewidth=1)
                    mean_vals = np.mean(adv_mse, axis=-1)[1:]
                    std_vals = np.std(adv_mse, axis=-1)[1:]
                    group_data.append(adv_mse[first_adv_cross, :])
                    # NOTE(review): mean_vals is shifted by one step relative to adv_mse (see above);
                    # group_means is not used for plotting below.
                    group_means.append(mean_vals[first_adv_cross])
                    max_val = 0.03#np.max(mean_vals)+std_vals[np.argmax(mean_vals)]
                    axes[-1].set_ylim([0, max_val])
                axes[-1].plot(range(len(mean_vals)), mean_vals, label=name,
                    lw=2, color=colors[file_idx][0], zorder=1)
                # hatched +/- one std band
                axes[-1].fill_between(range(len(mean_vals)), mean_vals + std_vals , mean_vals - std_vals,
                    edgecolor=colors[file_idx][1], alpha=1.0, zorder=0, facecolor="none",
                    hatch=hatches[file_idx], rasterized=False)
                if condition == num_conditions-1:
                    axes[-1].set_xlabel('Attack Step')
                axes[-1].grid(False)
        # box plot of the per-image distances at first_adv_cross
        axes.append(fig.add_subplot(right_gs[0]))
        x_pos = np.arange(2) + 2 * bar_width
        linewidth = 1
        medianprops = dict(linestyle='--', linewidth=linewidth, color='k')
        meanprops = dict(linestyle='-', linewidth=linewidth, color='k')
        float_colors = [[52/255, 152/255, 219/255], [231/255, 76/255, 60/255]] # blue, red
        axes[-1].set_title(f'c={recon}, '+r'$\kappa$'+f'={conf}')
        for data, means, pos, color, name in zip(group_data, group_means, x_pos, float_colors, model_names):
            boxprops = dict(linestyle='-', linewidth=linewidth, color=color)
            whiskerprops = boxprops
            capprops = boxprops
            # whiskers at the 5th/95th percentiles, outliers hidden (sym='')
            handles.append(axes[-1].boxplot(data, sym='', positions=[pos],
                whis=(5, 95), widths=bar_width, meanline=True, showmeans=True, boxprops=boxprops,
                whiskerprops=whiskerprops, capprops=capprops, medianprops=medianprops,
                meanprops=meanprops
            ))
            # max_val comes from the distance panel above so both panels share a y range
            axes[-1].set_ylim([0, max_val])
            axes[-1].set_yticklabels('')
            axes[-1].get_xaxis().set_ticks([])
            axes[-1].grid(False)
            # model name written inside the axis at a fixed y value
            axes[-1].text(pos, 0.0025, name, horizontalalignment='center', verticalalignment='center')
    fig.subplots_adjust(top=0.8)
    fig.suptitle(title, y=0.98)
    return fig, axes

In [None]:
# Carlini attack comparison: one figure row per (recon, conf) condition for the 768-unit 2-layer
# MNIST models without and with LCA.
colors = [[color_vals['md_blue'], color_vals['lt_blue']], [color_vals['md_red'], color_vals['lt_red']]]
model_names = ['w/o LCA', 'w/ LCA']
hatches = ['///', '\\\\\\']

#carlini_title = 'Networks with an LCA layer require larger\nperturbations for equal confidence with the Carlini attack'
#carlini_title = 'Networks with an LCA layer are more robust than without'
carlini_title = ''

all_recons = []
all_confs = []
all_files = []
for recon in ['0.5', '1.0']:
    for conf in ['0.0', '10.0']:
        # the conf=10.0 files use an extra '_' and a two-decimal temperature in their names
        if conf == '10.0':
            extra_str = '_'
            temp = '1.00'
        else:
            extra_str = ''
            temp = '1.0'
        c_file_path = (f'{analysis_dir}savefiles/class_adversary_analysis_test'+
            f'{extra_str}temp{temp}_conf{conf}_recon{recon}_carlini_targeted.npz')
        c_mlp_files = [projects_dir + model_name + c_file_path for model_name in [mnist_mlp_768_2layer]]
        # the LCA model's files use temperature 0.65
        temp = '0.65'
        c_file_path = (f'{analysis_dir}savefiles/class_adversary_analysis_test'+
            f'{extra_str}temp{temp}_conf{conf}_recon{recon}_carlini_targeted.npz')
        c_lca_files = [projects_dir + model_name + c_file_path for model_name in [mnist_lca_768_2layer]]
        c_files = c_mlp_files + c_lca_files
        all_recons.append(recon)
        all_confs.append(conf)
        all_files.append(c_files)

# NOTE(review): bar_width is not defined in this view; it must come from an earlier cell.
figsize = nc.set_size(text_width, fraction=1.0, subplot=[2*2, 3])
fig, ax = plot_average_mse_step(all_files, all_recons, all_confs, colors, carlini_title,
    model_names, bar_width, hatches, figsize, dpi)

# save the figure into the vis directory of both models
out_list = [projects_dir + model_name + '/analysis/0.0/vis/carlini_mse_vs_iteration_k0.0-10.0_conditions'
    for model_name in [mnist_lca_768_2layer, mnist_mlp_768_2layer]]
for out_name in out_list:
    for ext in file_extensions:
        save_name = out_name+ext
        fig.savefig(save_name, transparent=False, bbox_inches="tight", pad_inches=0.01, dpi=dpi)

## Extra: Differentiable Loss Surface Adversarial Comparison

In [None]:
class params(object):
    """Bare-bones analysis parameter container consumed by get_label_est."""

    def __init__(self, model_name):
        home_dir = os.path.expanduser("~")
        self.model_name = model_name
        self.version = '0.0'
        self.projects_dir = home_dir + "/Work/Projects/"
        self.save_info = 'test'
        # analysis flags: keep the existing log and skip neuron visualization
        self.overwrite_analysis_log = False
        self.do_neuron_visualization = False
        
def get_label_est(model_name, input_images):
    """
    Restore a saved model and return its label estimates for input_images.

    Params
    ------
    model_name: str
        model directory name inside params(model_name).projects_dir
    input_images: array
        images handed to analyzer.model.get_feed_dict (expected format depends on the model -
        TODO confirm per model type)

    Returns
    -------
    label_est: the evaluated value of analyzer.model.label_est
    """
    # Get params, set dirs
    analysis_params = params(model_name) # construct object
    analysis_params.model_dir = analysis_params.projects_dir+analysis_params.model_name

    # Load the model parameters: the last entry parsed from the model's log file
    model_log_file = (analysis_params.model_dir+"/logfiles/"+analysis_params.model_name
      +"_v"+analysis_params.version+".log")
    model_logger = Logger(model_log_file, overwrite=False)
    model_log_text = model_logger.load_file()
    model_params = model_logger.read_params(model_log_text)[-1]
    analysis_params.model_type = model_params.model_type

    # Initialize & setup analyzer
    analyzer = ap.get_analyzer(analysis_params.model_type)
    analyzer.setup(analysis_params)
    analysis_params.data_type = analyzer.model_params.data_type
    analyzer.setup_model(analyzer.model_params)

    # run forward pass with the checkpoint at analyzer.analysis_params.cp_loc restored
    with tf.compat.v1.Session(graph=analyzer.model.graph) as sess:
        feed_dict = analyzer.model.get_feed_dict(input_images, is_test=True)
        sess.run(analyzer.model.init_op, feed_dict)
        analyzer.model.load_full_model(sess, analyzer.analysis_params.cp_loc)
        label_est = sess.run(analyzer.model.label_est, feed_dict)

    return label_est

### Compare adv. MSE for a more differentiable loss surface: LCA vs LISTA
 
### See if LISTA adv. examples translate to LCA

In [None]:
def get_mnist_data(metric_files, img_files, stop_conf):
    """
    Load MNIST inputs and the adversarial images/classifications at each image's stopping step.

    Params
    ------
    metric_files: list of str
        analysis .npz files with "input_images", "input_labels" and "adversarial_outputs"
    img_files: list of str
        analysis .npz files with "adversarial_images", paired element-wise with metric_files
    stop_conf: float
        confidence threshold forwarded to find_targeted_conf_index

    Returns
    -------
    input_images, input_clf, adv_images, adv_clf: lists with one entry per file pair;
    images are reshaped to (-1, 28, 28) and classifications are argmax class indices
    """
    imgs = [np.load(img_file, allow_pickle=True)["data"].item() for img_file in img_files]
    metrics = [np.load(metric_file, allow_pickle=True)["data"].item() for metric_file in metric_files]
    indices = [np.array(find_targeted_conf_index(result, stop_conf)) for result in metrics]
    input_images = [result["input_images"].reshape(-1, 28, 28) for result in metrics]
    input_clf = [np.argmax(result["input_labels"], axis=1) for result in metrics]
    # pick, for every image, the adversarial image / output at that image's stopping step
    adv_images = [
        img_result["adversarial_images"][0, steps, np.arange(len(steps)), :].reshape(-1, 28, 28)
        for steps, img_result in zip(indices, imgs)
    ]
    adv_clf = [
        np.argmax(result["adversarial_outputs"][0, steps, np.arange(len(steps))], axis=1)
        for steps, result in zip(indices, metrics)
    ]
    return input_images, input_clf, adv_images, adv_clf

In [None]:
lista_stop_conf = .95
# get lista adv images
# NOTE(review): lista_img_file is not defined in the visible notebook (its construction is commented
# out next to lista_file); define it before running this cell.
#input_images, input_clf, adv_images, adv_clf = get_results([lista_file], [lista_img_file], lista_stop_conf)
input_images, input_clf, adv_images, adv_clf = get_mnist_data([lista_file], [lista_img_file], lista_stop_conf)
# Pass the LISTA adversarial images through each model.
# `lca` and `mlp` are integer index constants (1 and 0) defined for the example-image cell, not
# model names, so the model directory names are passed explicitly.
lca_softmax_labels = get_label_est(mnist_lca_768_2layer, adv_images[0])
lista_softmax_labels = get_label_est(lista, adv_images[0])
mlp_softmax_labels = get_label_est(mnist_mlp_768_2layer, adv_images[0])

In [None]:
# organize the data
# For each model: its output for the LISTA adversarial target class (adv_clf, the argmax of LISTA's
# adversarial outputs) and for the original input class (input_clf), per image.
lca_confs_lista_adv = [out[cls] for out, cls in zip(lca_softmax_labels, adv_clf[0])]
lca_confs_input = [out[cls] for out, cls in zip(lca_softmax_labels, input_clf[0])]

mlp_confs_lista_adv = [out[cls] for out, cls in zip(mlp_softmax_labels, adv_clf[0])]
mlp_confs_input = [out[cls] for out, cls in zip(mlp_softmax_labels, input_clf[0])]

lista_confs_lista_adv = [out[cls] for out, cls in zip(lista_softmax_labels, adv_clf[0])]
lista_confs_input = [out[cls] for out, cls in zip(lista_softmax_labels, input_clf[0])]

# keep only images on which LISTA itself assigns > .8 to the adversarial target class
filter_indices = np.where(np.array(lista_confs_lista_adv) > .8)[0]

# data[model][original/adversarial target] -> per-image confidences
data = [[lista_confs_input, lista_confs_lista_adv],
        [mlp_confs_input, mlp_confs_lista_adv],
        [lca_confs_input, lca_confs_lista_adv]]

data = [[np.array(confs)[filter_indices] for confs in model_confs] for model_confs in data]
means = [[np.mean(confs) for confs in model_confs] for model_confs in data]
# (2, M) yerr arrays: zero lower error, std as upper error (matplotlib's separate -/+ convention)
stds = [np.array([[0,0],[np.std(confs) for confs in model_confs]]) for model_confs in data]

In [None]:
# Bar chart of how confidently each model assigns the original and the LISTA adversarial target labels.
colors = [color_vals['md_green'], color_vals['md_blue'], color_vals['md_red']]
xtick_labels = ["Original Label", "Adv Target Label"]
xlabel = None#"Class Position"
ylabel = "Softmax Confidence"
names = ['w/ LISTA', 'MLP', 'w/ LCA']

fig, ax = plt.subplots(1,1)
multi_model_compare(ax, data, means, stds, colors, names, xtick_labels, xlabel, ylabel, 1, width=.25,
                   title="Transferability of LISTA\nAdversarial Images", fontsize=fontsize);
legend = ax.get_legend()
legend.set_bbox_to_anchor([1.28,0.94,0,0], transform=fig.transFigure)
plt.show()

# Start a new list: extending the Carlini cell's out_list would overwrite those figures with this one.
# projects_dir is the same root used to build lista_file.
out_list = [projects_dir + lista + "/analysis/0.0/vis/lista_adv_transferability"]
for out_name in out_list:
    for ext in file_extensions:
        save_name = out_name+ext
        fig.savefig(save_name, transparent=False, bbox_inches="tight", pad_inches=0.01, dpi=dpi)

### Additional Adv Image Examples

In [None]:
def make_grid_subplots_with_fontsize(fig, gs, mlp_grp, lca_grp, orig_labels, target_labels, group_names, group_name_loc,
                       orig_y_adj, start_idx=0, num_categories=3, crop_ammount=0, hspace=0.5, wspace=-0.4,
                       cmap="Greys_r", fontsize=12):
  """
  Fill `num_categories` rows of `gs` with an unperturbed image and two 2x2 adversarial sub-grids.

  Each row uses columns 0-1 of `gs` for the unperturbed image, columns 2-3 for model index 0 and
  columns 4 onward for model index 1. Every 2x2 sub-grid shows, for that model index, the adversarial
  images (top row) and their differences (bottom row) without LCA (left) and with LCA (right).

  Params
  ------
  fig: matplotlib figure the axes are added to
  gs: grid spec with at least num_categories rows and 6 columns
  mlp_grp, lca_grp: list
      [orig/adv/diff][model_idx] image arrays indexed by image along the first axis
  orig_labels: list
      orig_labels[0][0][image_idx] is shown with the unperturbed image
  target_labels: list
      target_labels[0][model_idx][image_idx] is shown with the top-left (w/o LCA) adversarial image
  group_names: list
      group_names[model_idx] is written above the first row of each sub-grid
  group_name_loc: [x, y] position of that text in data coordinates
  orig_y_adj: y position of the "Unperturbed" title
  start_idx: index of the image shown in the first row
  num_categories: number of rows
  crop_ammount: amount passed to skimage.util.crop for every image
  hspace, wspace: spacing inside each 2x2 sub-grid
  cmap: colormap for every image
  fontsize: currently unused (all uses are commented out)
  """
  ax_orig_list = []
  gs_sub0_list = []
  gs_sub1_list = []
  for i in range(num_categories):
    ax_orig_list.append(fig.add_subplot(gs[i, :2]))
    gs_sub0_list.append(gridspec.GridSpecFromSubplotSpec(2, 2, gs[i, 2:4], hspace=hspace, wspace=wspace))
    gs_sub1_list.append(gridspec.GridSpecFromSubplotSpec(2, 2, gs[i, 4:], hspace=hspace, wspace=wspace))
  
  for category_idx in range(num_categories):
    image_idx = category_idx + start_idx
    orig_ax = ax_orig_list[category_idx]
    if category_idx == 0:
      orig_ax.set_title("Unperturbed", y=orig_y_adj)#, fontsize=fontsize)
    # the unperturbed image is taken from the first model of the w/o-LCA group
    orig_img = crop(np.squeeze(mlp_grp[0][0][image_idx, ...]), crop_ammount)
    orig_im_handle = show_image_with_label(orig_ax, orig_img, orig_labels[0][0][image_idx], cmap=cmap)
    for model_idx, gs_sub in enumerate([gs_sub0_list[category_idx], gs_sub1_list[category_idx]]):
      mlp_adv_img = crop(np.squeeze(mlp_grp[1][model_idx][image_idx, ...]), crop_ammount)
      mlp_diff_img = crop(np.squeeze(mlp_grp[2][model_idx][image_idx, ...]), crop_ammount)
      lca_adv_img = crop(np.squeeze(lca_grp[1][model_idx][image_idx, ...]), crop_ammount)
      lca_diff_img = crop(np.squeeze(lca_grp[2][model_idx][image_idx, ...]), crop_ammount)
      # shared color range for both difference images of this sub-grid
      diff_vmin = np.min(np.concatenate((mlp_diff_img, lca_diff_img)))
      diff_vmax = np.max(np.concatenate((mlp_diff_img, lca_diff_img)))
      img_list = [[mlp_adv_img, lca_adv_img], [mlp_diff_img, lca_diff_img]]
      for i in range(2): # row [adv, pert]
        for j in range(2): # col [w/o LCA, w/ LCA]
          current_ax = fig.add_subplot(gs_sub[i, j])
          current_target_label = None
          current_image = img_list[i][j]
          if i == 0:
            # adversarial images are displayed on a fixed [0, 1] range
            vmin = 0.0
            vmax = 1.0
            if j == 0: # top left image
              if model_idx == 0:
                current_ax.set_ylabel(r"$s^{*}_{T}$")#, fontsize=fontsize)
              current_target_label = target_labels[j][model_idx][image_idx]
              if category_idx == 0: # top category only
                x_loc = group_name_loc[0]
                y_loc = group_name_loc[1]
                # j is 0 here, so this is group_names[model_idx]
                text_handle = current_ax.text(x_loc, y_loc, group_names[j+model_idx],#, fontsize=fontsize,
                  horizontalalignment='left', verticalalignment='bottom')
          else: # i == 1
            # rounded to 2 decimals, which are also the colorbar tick values
            vmin = np.round(diff_vmin, 2)
            vmax = np.round(diff_vmax, 2)
            if j == 0 and model_idx == 0:
                current_ax.set_ylabel(r"$s-s^{*}_{T}$")#, fontsize=fontsize)
            if j == 0 and category_idx == num_categories-1: # bottom left
              current_ax.set_xlabel("w/o\nLCA")#, fontsize=fontsize)
            elif j == 1 and category_idx == num_categories-1: # bottom right
              current_ax.set_xlabel("w/\nLCA")#, fontsize=fontsize)
          im_handle = show_image_with_label(current_ax, current_image, current_target_label, vmin=vmin, vmax=vmax, cmap=cmap)
          if j == 1:
            # one colorbar per sub-grid row, attached to the right-hand (w/ LCA) image
            pf.add_colorbar_to_im(im_handle, aspect=10, ax=current_ax, ticks=[vmin, vmax])#, labelsize=fontsize/2)

def _split_label_group(label_group):
    """Return (orig_labels, target_labels, img_labels) for one dataset's label group.

    Accepts the flat form [orig_labels, target_labels, img_labels], or the per-model form
    [mlp_labels, lca_labels] with each entry [orig_pair, target_pair, img_labels]. The per-model
    form is converted to the flat form using the w/o-LCA (mlp) entry.
    """
    if len(label_group) == 3:
        return tuple(label_group)
    mlp_labels = label_group[0]
    # wrap so that orig_labels[0][0] / target_labels[0][model_idx] index as in the flat form
    return [mlp_labels[0]], [mlp_labels[1]], mlp_labels[2]

def plot_adv_images_with_figsize(image_groups, labels, mnist_start_idx, cifar_start_idx, figsize, dpi=100):
    """
    Plot three rows of MNIST and three rows of CIFAR adversarial examples (w/o vs. w/ LCA).

    Params
    ------
    image_groups: list
        [mnist/cifar][mlp/lca][orig/adv/diff][model] image arrays
    labels: list
        labels[0] / labels[1] for MNIST / CIFAR, each either [orig_labels, target_labels, img_labels]
        or [mlp_labels, lca_labels] with each entry [orig_pair, target_pair, img_labels]
    mnist_start_idx, cifar_start_idx: int
        index of the first image shown for each dataset
    figsize: sequence of two floats
        the figure is figsize[0]/2 wide and figsize[1] tall
    dpi: int
        figure resolution

    Returns
    -------
    fig: the created matplotlib figure
    """
    mnist_mlp_grp, mnist_lca_grp = image_groups[0]
    cifar_mlp_grp, cifar_lca_grp = image_groups[1]
    mnist_orig_labels, mnist_target_labels, mnist_img_labels = _split_label_group(labels[0])
    cifar_orig_labels, cifar_target_labels, cifar_img_labels = _split_label_group(labels[1])

    hspace = 0.3
    wspace = 1.8
    sub_hspace = 0.2
    sub_wspace = 0.2
    orig_y_adj = 1.10
    img_label_loc = [-8.0, -8.0] # [x, y]
    fig = plt.figure(figsize=[figsize[0]/2, figsize[1]], dpi=dpi)
    gs0 = plt.GridSpec(2, 1, figure=fig, hspace=0.3)

    num_categories=3

    gs_mnist = gridspec.GridSpecFromSubplotSpec(num_categories, 6, gs0[0], hspace=hspace, wspace=wspace)
    make_grid_subplots_with_fontsize(fig, gs_mnist, mnist_mlp_grp, mnist_lca_grp, mnist_orig_labels,
        mnist_target_labels, mnist_img_labels, img_label_loc, orig_y_adj, mnist_start_idx, num_categories,
        hspace=sub_hspace, wspace=sub_wspace, cmap="Greys")

    # cifar_start_idx was previously ignored (hard-coded 0)
    gs_cifar = gridspec.GridSpecFromSubplotSpec(num_categories, 6, gs0[1], hspace=hspace, wspace=wspace)
    make_grid_subplots_with_fontsize(fig, gs_cifar, cifar_mlp_grp, cifar_lca_grp, cifar_orig_labels,
        cifar_target_labels, cifar_img_labels, img_label_loc, orig_y_adj, cifar_start_idx, num_categories,
        hspace=sub_hspace, wspace=sub_wspace, cmap="Greys_r")

    plt.show()
    return fig

In [None]:
# Larger adversarial-example figure: MNIST starting at image 44, CIFAR at image 0.
figsize = nc.set_size(text_width, fraction=1.0, subplot=[16, 16])
full_adv_img_fig = plot_adv_images_with_figsize(
    image_groups,
    label_groups,
    mnist_start_idx=44,
    cifar_start_idx=0,
    figsize=figsize,
    dpi=dpi,
)