# Neural Computation Paper Figures

### Imports

In [None]:
import os
# Move up one directory so the project packages imported in the next cell
# (data, analysis, utils) resolve -- presumably the repository root.
# NOTE(review): not idempotent; re-running this cell changes directory again.
os.chdir("../")
# Restrict CUDA to device 0; set before tensorflow is imported in the next cell.
%env CUDA_VISIBLE_DEVICES=0
%matplotlib inline

In [None]:
import bisect
import re
import numpy as np
from skimage import measure
from skimage.io import imread
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.ticker import FormatStrFormatter
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d
import matplotlib.font_manager
import tensorflow as tf
from data.dataset import Dataset
import data.data_selector as ds
import analysis.analysis_picker as ap
import utils.data_processing as dp
import utils.plot_functions as pf
import utils.neural_comp_funcs as nc

### Parameters

In [None]:
# Shared figure settings used by the plotting helpers and savefig calls below.
figsize = (16, 16)  # (width, height) passed as figsize; the curvature figure doubles the width
fontsize = 20  # base font size passed to the nc plotting helpers
dpi = 200  # resolution passed to the plotting helpers and to fig.savefig

In [None]:
class _AnalysisParams(object):
  """
  Parameter container handed to analyzer.setup() and analyzer.load_analysis().
  Each model-specific class below only differs in the values it passes here.
  """
  def __init__(self, model_type, model_name, display_name, version, save_info):
    self.model_type = model_type  # analyzer type passed to ap.get_analyzer
    self.model_name = model_name  # model directory name under ~/Work/Projects
    self.display_name = display_name  # human-readable name, presumably for figure labels
    self.version = version
    self.save_info = save_info  # identifies which saved analysis run to load
    self.overwrite_analysis_log = False

class lca_512_vh_params(_AnalysisParams):
  def __init__(self):
    _AnalysisParams.__init__(self, model_type="lca", model_name="lca_512_vh",
      display_name="Sparse Coding 512", version="0.0", save_info="analysis_train_carlini_targeted")

class lca_768_vh_params(_AnalysisParams):
  def __init__(self):
    _AnalysisParams.__init__(self, model_type="lca", model_name="lca_768_vh",
      display_name="Sparse Coding 768", version="0.0", save_info="analysis_train_carlini_targeted")

class lca_1024_vh_params(_AnalysisParams):
  def __init__(self):
    _AnalysisParams.__init__(self, model_type="lca", model_name="lca_1024_vh",
      display_name="Sparse Coding 1024", version="0.0", save_info="analysis_train_carlini_targeted")

class lca_2560_vh_params(_AnalysisParams):
  def __init__(self):
    _AnalysisParams.__init__(self, model_type="lca", model_name="lca_2560_vh",
      display_name="Sparse Coding", version="0.0", save_info="analysis_train_kurakin_targeted")

class sae_768_vh_params(_AnalysisParams):
  def __init__(self):
    _AnalysisParams.__init__(self, model_type="sae", model_name="sae_768_vh",
      display_name="Sparse Autoencoder", version="1.0", save_info="analysis_train_kurakin_targeted")

class rica_768_vh_params(_AnalysisParams):
  def __init__(self):
    _AnalysisParams.__init__(self, model_type="rica", model_name="rica_768_vh",
      display_name="Linear Autoencoder", version="0.0", save_info="analysis_train_kurakin_targeted")

class ae_768_vh_params(_AnalysisParams):
  def __init__(self):
    _AnalysisParams.__init__(self, model_type="ae", model_name="ae_768_vh",
      display_name="ReLU", version="1.0", save_info="analysis_train_kurakin_targeted")

class lca_768_mnist_params(_AnalysisParams):
  def __init__(self):
    _AnalysisParams.__init__(self, model_type="lca", model_name="lca_768_mnist",
      display_name="Sparse Coding 768", version="0.0", save_info="analysis_train_kurakin_targeted")

class lca_1536_mnist_params(_AnalysisParams):
  def __init__(self):
    _AnalysisParams.__init__(self, model_type="lca", model_name="lca_1536_mnist",
      display_name="Sparse Coding 1536", version="0.0", save_info="analysis_test_carlini_targeted")

class ae_768_mnist_params(_AnalysisParams):
  def __init__(self):
    _AnalysisParams.__init__(self, model_type="ae", model_name="ae_768_mnist",
      display_name="Leaky ReLU", version="0.0", save_info="analysis_test_carlini_targeted")

class sae_768_mnist_params(_AnalysisParams):
  def __init__(self):
    _AnalysisParams.__init__(self, model_type="sae", model_name="sae_768_mnist",
      display_name="Sparse Autoencoder", version="0.0", save_info="analysis_test_carlini_targeted")

class rica_768_mnist_params(_AnalysisParams):
  def __init__(self):
    _AnalysisParams.__init__(self, model_type="rica", model_name="rica_768_mnist",
      display_name="Linear Autoencoder", version="0.0", save_info="analysis_train_kurakin_targeted")

class ae_deep_mnist_params(_AnalysisParams):
  def __init__(self):
    _AnalysisParams.__init__(self, model_type="ae", model_name="ae_deep_mnist",
      display_name="Leaky ReLU", version="0.0", save_info="analysis_test_carlini_targeted")

In [None]:
num_levels = 10  # passed to nc.plot_goup_iso_contours; presumably the number of contour levels
# Light / medium / dark shades per hue, ordered lightest -> darkest within each family.
# (md_green and dk_green were previously swapped: #196F3D is the darker of the two.)
color_vals = {
  "lt_green": "#A9DFBF", "md_green": "#27AE60", "dk_green": "#196F3D",
  "lt_blue": "#AED6F1", "md_blue": "#3498DB", "dk_blue": "#21618C",
  "lt_red": "#F5B7B1", "md_red": "#E74C3C", "dk_red": "#943126"}

### Iso-contour activations comparison

In [None]:
## Build and set up the four models compared in the iso-contour figure
params_list = [rica_768_vh_params(), ae_768_vh_params(), sae_768_vh_params(), lca_2560_vh_params()]
projects_root = os.path.expanduser("~")+"/Work/Projects/"
for params in params_list:
  params.model_dir = projects_root+params.model_name
analyzer_list = list(map(lambda model_params: ap.get_analyzer(model_params.model_type), params_list))
for params, analyzer in zip(params_list, analyzer_list):
  # configure the analyzer and its model, then attach the saved analysis outputs
  analyzer.setup(params)
  analyzer.model.setup(analyzer.model_params)
  analyzer.load_analysis(save_info=params.save_info)
  analyzer.model_name = params.model_name

In [None]:
## Load the saved iso-response outputs (run parameters, vectors, activations, contour datasets).
## allow_pickle=True is required: the entries unpacked with .item() are pickled Python objects,
## which numpy>=1.16.3 refuses to load by default. Only use this on files this project wrote.
## NOTE: the run-parameter globals (x_range, y_range, ...) end up holding the LAST model's values.
save_name = ""
for analyzer in analyzer_list:
  savefile_dir = analyzer.analysis_out_dir+"savefiles/"
  savefile_suffix = save_name+analyzer.analysis_params.save_info+".npz"

  run_params = np.load(savefile_dir+"iso_params_"+savefile_suffix, allow_pickle=True)["data"].item()
  min_angle = run_params["min_angle"]
  max_angle = run_params["max_angle"]
  num_neurons = run_params["num_neurons"]
  use_bf_stats = run_params["use_bf_stats"]
  num_comparison_vectors = run_params["num_comparison_vects"]
  x_range = run_params["x_range"]
  y_range = run_params["y_range"]
  num_images = run_params["num_images"]

  iso_vectors = np.load(savefile_dir+"iso_vectors_"+savefile_suffix, allow_pickle=True)["data"].item()
  analyzer.target_neuron_ids = iso_vectors["target_neuron_ids"]
  analyzer.comparison_neuron_ids = iso_vectors["comparison_neuron_ids"]
  analyzer.target_vectors = iso_vectors["target_vectors"]
  analyzer.rand_orth_vectors = iso_vectors["rand_orth_vectors"]
  analyzer.comparison_vectors = iso_vectors["comparison_vectors"]

  analyzer.comp_activations = np.load(savefile_dir+"iso_comp_activations_"+savefile_suffix,
    allow_pickle=True)["data"]
  analyzer.comp_contour_dataset = np.load(savefile_dir+"iso_comp_contour_dataset_"+savefile_suffix,
    allow_pickle=True)["data"].item()
  analyzer.rand_activations = np.load(savefile_dir+"iso_rand_activations_"+savefile_suffix,
    allow_pickle=True)["data"]
  analyzer.rand_contour_dataset = np.load(savefile_dir+"iso_rand_contour_dataset_"+savefile_suffix,
    allow_pickle=True)["data"].item()

In [None]:
## Draw all models' iso-response contours in one figure, then save a copy into each model's vis dir
neuron_indices = [0, 0, 0, 0]
orth_indices = [0, 0, 0, 1]
num_plots_y = 2
num_plots_x = 2
show_contours = True

fig, contour_handles = nc.plot_goup_iso_contours(analyzer_list, neuron_indices, orth_indices,
  num_levels, x_range, y_range, show_contours, figsize, dpi, fontsize)
for analyzer, neuron_index, orth_index in zip(analyzer_list, neuron_indices, orth_indices):
  # identifiers of the target neuron and the comparison vector used for this model's panel
  neuron_str = str(analyzer.target_neuron_ids[neuron_index])
  orth_str = str(analyzer.comparison_neuron_ids[neuron_index][orth_index])
  contour_tag = "" if show_contours else "continuous_"
  file_stem = (analyzer.analysis_out_dir+"/vis/iso_contour_comparison_"+contour_tag
    +"bf0id"+neuron_str+"_bf1id"+orth_str+"_"+analyzer.analysis_params.save_info)
  for ext in [".png", ".eps"]:
    save_name = file_stem+ext
    fig.savefig(save_name, dpi=dpi, transparent=True, bbox_inches="tight", pad_inches=0.01)

### Curvature histogram

In [None]:
## Load the iso-response ("iso") and response-attenuation ("attn") runs for three LCA models.
## allow_pickle=True is required: entries unpacked with .item() are pickled Python objects,
## which numpy>=1.16.3 refuses to load by default. Only use this on files this project wrote.
params_list = [lca_512_vh_params(), lca_768_vh_params(), lca_2560_vh_params()]
iso_save_name = "iso_curvature_xrange1.3_yrange-2.2_"
attn_save_name = "1d_"

for params in params_list:
  params.model_dir = (os.path.expanduser("~")+"/Work/Projects/"+params.model_name)

analyzer_list = [ap.get_analyzer(params.model_type) for params in params_list]
for analyzer, params in zip(analyzer_list, params_list):
  analyzer.setup(params)
  analyzer.model.setup(analyzer.model_params)
  analyzer.load_analysis(save_info=params.save_info)
  analyzer.model_name = params.model_name
  savefile_dir = analyzer.analysis_out_dir+"savefiles/"

  iso_suffix = iso_save_name+analyzer.analysis_params.save_info+".npz"
  iso_run_params = np.load(savefile_dir+"iso_params_"+iso_suffix, allow_pickle=True)["data"].item()
  analyzer.iso_comp_activations = np.load(savefile_dir+"iso_comp_activations_"+iso_suffix,
    allow_pickle=True)["data"]
  analyzer.iso_comp_contour_dataset = np.load(savefile_dir+"iso_comp_contour_dataset_"+iso_suffix,
    allow_pickle=True)["data"].item()
  analyzer.iso_rand_activations = np.load(savefile_dir+"iso_rand_activations_"+iso_suffix,
    allow_pickle=True)["data"]
  analyzer.iso_rand_contour_dataset = np.load(savefile_dir+"iso_rand_contour_dataset_"+iso_suffix,
    allow_pickle=True)["data"].item()

  analyzer.iso_num_target_neurons = iso_run_params["num_neurons"]
  analyzer.iso_num_comparison_vectors = iso_run_params["num_comparison_vects"]

  attn_suffix = attn_save_name+analyzer.analysis_params.save_info+".npz"
  attn_run_params = np.load(savefile_dir+"iso_params_"+attn_suffix, allow_pickle=True)["data"].item()
  analyzer.attn_comp_activations = np.load(savefile_dir+"iso_comp_activations_"+attn_suffix,
    allow_pickle=True)["data"]
  analyzer.attn_comp_contour_dataset = np.load(savefile_dir+"iso_comp_contour_dataset_"+attn_suffix,
    allow_pickle=True)["data"].item()
  analyzer.attn_rand_activations = np.load(savefile_dir+"iso_rand_activations_"+attn_suffix,
    allow_pickle=True)["data"]
  analyzer.attn_rand_contour_dataset = np.load(savefile_dir+"iso_rand_contour_dataset_"+attn_suffix,
    allow_pickle=True)["data"].item()

  analyzer.attn_num_target_neurons = attn_run_params["num_neurons"]
  analyzer.attn_num_comparison_vectors = attn_run_params["num_comparison_vects"]

## Activity mesh for the contour illustration in the curvature figure. It is taken explicitly from
## the last model in analyzer_list rather than from the loop variable left over above.
## Index [0, 1] presumably selects (target neuron 0, comparison plane 1) -- NOTE(review): confirm.
mesh_analyzer = analyzer_list[-1]
mesh_save_name = "iso_curvature_ryan_"
mesh_suffix = mesh_save_name+mesh_analyzer.analysis_params.save_info+".npz"
contour_activity = np.load(mesh_analyzer.analysis_out_dir+"savefiles/iso_comp_activations_"+mesh_suffix,
  allow_pickle=True)["data"][0, 1, ...]
comp_contour_dataset = np.load(mesh_analyzer.analysis_out_dir+"savefiles/iso_comp_contour_dataset_"+mesh_suffix,
  allow_pickle=True)["data"].item()
x = range(comp_contour_dataset["x_pts"].size)
y = range(comp_contour_dataset["y_pts"].size)
contour_pts = (x,y)

In [None]:
## Curvature estimates for every model:
##  - iso-response: quadratic fit to the target_act iso-contour in each 2-D plane
##  - response attenuation: quadratic fit to activity along a 1-D slice of each plane
target_act = 0.3 # target activity spot between min & max value of normalized activity (btwn 0 and 1)
for analyzer in analyzer_list:
  analyzer.iso_comp_curvatures = []
  analyzer.iso_rand_curvatures = []
  activations_and_curvatures = ((analyzer.iso_comp_activations, analyzer.iso_comp_curvatures),
    (analyzer.iso_rand_activations, analyzer.iso_rand_curvatures))
  for activations, curvatures in activations_and_curvatures:
    (num_neurons, num_planes, num_points_y, num_points_x) = activations.shape
    for neuron_id in range(num_neurons):
      sub_curvatures = []
      for plane_id in range(num_planes):
        # Copy: basic indexing returns a view, and editing it in place would
        # silently overwrite the stored analyzer activations.
        activity = activations[neuron_id, plane_id, ...].copy()
        ## Reflect the second half of the rows onto the first half so the measured
        ## contour is symmetric about the middle row. Using num_y-half_y as the
        ## source start also works when num_y is odd (the middle row is kept).
        num_y, num_x = activity.shape
        half_y = num_y // 2
        activity[:half_y, :] = activity[num_y-half_y:, :][::-1, :]
        ## Fit x = c0 + c1*y + c2*y^2 to the first contour found; c2 is used as the curvature
        contours = measure.find_contours(activity, target_act)[0]
        x_vals = contours[:, 1]
        y_vals = contours[:, 0]
        coeffs = np.polynomial.polynomial.polyfit(y_vals, x_vals, deg=2)
        sub_curvatures.append(coeffs[-1])
      curvatures.append(sub_curvatures)

  comp_x_pts = analyzer.attn_comp_contour_dataset["x_pts"]
  rand_x_pts = analyzer.attn_rand_contour_dataset["x_pts"]
  # identical x grids mean proj_datapoints does not need to be recomputed for each case
  assert np.all(comp_x_pts == rand_x_pts)
  num_x_imgs = len(comp_x_pts)
  x_target = comp_x_pts[num_x_imgs-1] # take the slice at the last x position
  proj_datapoints = analyzer.attn_comp_contour_dataset["proj_datapoints"]
  slice_indices = np.where(proj_datapoints[:, 0] == x_target)[0]
  analyzer.sliced_datapoints = proj_datapoints[slice_indices, :] # slice grid
  slice_y = analyzer.sliced_datapoints[:, 1]

  analyzer.attn_comp_curvatures = []
  analyzer.attn_comp_fits = []
  analyzer.attn_comp_sliced_activity = []
  analyzer.attn_rand_curvatures = []
  analyzer.attn_rand_fits = []
  analyzer.attn_rand_sliced_activity = []
  for neuron_index in range(analyzer.attn_num_target_neurons):
    sub_comp_curvatures = []
    sub_comp_fits = []
    sub_comp_sliced_activity = []
    sub_rand_curvatures = []
    sub_rand_fits = []
    sub_rand_sliced_activity = []
    for orth_index in range(analyzer.attn_num_comparison_vectors):
      comp_activity = analyzer.attn_comp_activations[neuron_index, orth_index, ...].reshape([-1])
      sub_comp_sliced_activity.append(comp_activity[slice_indices])
      # coeff = [c0, c1, c2], where p = c0 + c1x + c2x^2; c2 is used as the curvature
      coeff = np.polynomial.polynomial.polyfit(slice_y, sub_comp_sliced_activity[-1], deg=2)
      sub_comp_curvatures.append(coeff[2])
      sub_comp_fits.append(np.polynomial.polynomial.polyval(slice_y, coeff))

    # at most num_inputs-1 random orthogonal vectors are available
    num_rand_vectors = np.minimum(analyzer.bf_stats["num_inputs"]-1, analyzer.attn_num_comparison_vectors)
    for orth_index in range(num_rand_vectors):
      rand_activity = analyzer.attn_rand_activations[neuron_index, orth_index, ...].reshape([-1])
      sub_rand_sliced_activity.append(rand_activity[slice_indices])
      coeff = np.polynomial.polynomial.polyfit(slice_y, sub_rand_sliced_activity[-1], deg=2)
      sub_rand_curvatures.append(coeff[2])
      sub_rand_fits.append(np.polynomial.polynomial.polyval(slice_y, coeff))

    analyzer.attn_comp_curvatures.append(sub_comp_curvatures)
    analyzer.attn_comp_fits.append(sub_comp_fits)
    analyzer.attn_comp_sliced_activity.append(sub_comp_sliced_activity)
    analyzer.attn_rand_curvatures.append(sub_rand_curvatures)
    analyzer.attn_rand_fits.append(sub_rand_fits)
    analyzer.attn_rand_sliced_activity.append(sub_rand_sliced_activity)

In [None]:
num_bins = 50
def get_bins(all_curvatures, num_bins=50):
  """
  Compute uniform histogram bin edges that span all_curvatures and always
  include a bin centered exactly on zero.

  The bin width is (max - min) / (num_bins - 1); one bin is reserved for the
  zero-centered bin. Bins are added outward from zero in whole-width steps
  until both extremes are covered, so the number of returned edges can differ
  slightly from num_bins + 1.

  Inputs:
    all_curvatures: [list or array] of curvature values (must be non-empty)
    num_bins: [int] target number of bins (should be >= 2)
  Outputs:
    bins: [np.ndarray] of increasing bin edges, suitable for np.histogram
  """
  max_curvature = np.amax(all_curvatures)
  min_curvature = np.amin(all_curvatures)
  bin_width = (max_curvature - min_curvature) / (num_bins-1) # subtract 1 to leave room for the zero bin
  if bin_width == 0:
    # All values are identical: a zero width would never reach the extremes,
    # so fall back to unit-width bins.
    bin_width = 1.0
  # number of whole-bin steps needed on each side of zero to cover the extremes
  num_neg_bins = int(max(0, np.ceil(-min_curvature / bin_width)))
  num_pos_bins = int(max(0, np.ceil(max_curvature / bin_width)))
  # an ndarray (not a list) so the edge arithmetic below is element-wise
  bin_centers = np.arange(-num_neg_bins, num_pos_bins+1) * bin_width
  bin_lefts = bin_centers - (bin_width / 2)
  bin_rights = bin_centers + (bin_width / 2)
  bins = np.append(bin_lefts, bin_rights[-1])
  return bins

## Pool every model's curvatures so all histograms share the same bin edges
iso_all_curvatures = []
for analyzer in analyzer_list:
  # Walk the stored per-neuron lists directly instead of a neuron count taken
  # from a global; comparison and random lists can have different lengths.
  for neuron_curvatures in analyzer.iso_comp_curvatures + analyzer.iso_rand_curvatures:
    iso_all_curvatures += neuron_curvatures
iso_bins = get_bins(iso_all_curvatures, num_bins)

attn_all_curvatures = []
for analyzer in analyzer_list:
  for neuron_index in range(analyzer.attn_num_target_neurons):
    attn_all_curvatures += analyzer.attn_comp_curvatures[neuron_index]
    attn_all_curvatures += analyzer.attn_rand_curvatures[neuron_index]
attn_bins = get_bins(attn_all_curvatures, num_bins)

## Per-model histograms, normalized to sum to 1
for analyzer in analyzer_list:
  flat_comp_curvatures = [item for sub_list in analyzer.iso_comp_curvatures for item in sub_list]
  comp_hist, analyzer.iso_bin_edges = np.histogram(flat_comp_curvatures, iso_bins, density=False)
  analyzer.iso_comp_hist = comp_hist / np.sum(comp_hist)
  flat_rand_curvatures = [item for sub_list in analyzer.iso_rand_curvatures for item in sub_list]
  rand_hist, _ = np.histogram(flat_rand_curvatures, iso_bins, density=False)
  analyzer.iso_rand_hist = rand_hist / np.sum(rand_hist)

  flat_comp_curvatures = [item for sub_list in analyzer.attn_comp_curvatures for item in sub_list]
  comp_hist, analyzer.attn_bin_edges = np.histogram(flat_comp_curvatures, attn_bins, density=False)
  analyzer.attn_comp_hist = comp_hist / np.sum(comp_hist)
  flat_rand_curvatures = [item for sub_list in analyzer.attn_rand_curvatures for item in sub_list]
  rand_hist, _ = np.histogram(flat_rand_curvatures, attn_bins, density=False)
  analyzer.attn_rand_hist = rand_hist / np.sum(rand_hist)

In [None]:
## Inputs for the iso-response histogram panel: [comparison-vector hists, random-vector hists]
iso_hist_list = [[analyzer.iso_comp_hist for analyzer in analyzer_list],
  [analyzer.iso_rand_hist for analyzer in analyzer_list]]
# NOTE(review): labels assume the overcompleteness of the three models loaded above
# (512, 768, 2560 units); confirm "4x" for the 768-unit model.
label_list = [["Comparison vectors, 2x", "Comparison vectors, 4x", "Comparison vectors, 10x"],
  ["Random vectors, 2x", "Random vectors, 4x", "Random vectors, 10x"]]
color_list = [[color_vals["lt_red"], color_vals["md_red"], color_vals["dk_red"]],
  [color_vals["lt_blue"], color_vals["md_blue"], color_vals["dk_blue"]]]

plot_bin_lefts, plot_bin_rights = analyzer_list[0].iso_bin_edges[:-1], analyzer_list[0].iso_bin_edges[1:]
# bin center = left edge + half the bin width
iso_plot_bin_centers = plot_bin_lefts + (plot_bin_rights - plot_bin_lefts) / 2

label_loc = [0.5, 0.3]
iso_title = "Iso-Response"
iso_xlabel = "Curvature of Iso-Response Contours"

## Inputs for the response-attenuation histogram panel
attn_hist_list = [[analyzer.attn_comp_hist for analyzer in analyzer_list],
  [analyzer.attn_rand_hist for analyzer in analyzer_list]]

plot_bin_lefts, plot_bin_rights = analyzer_list[0].attn_bin_edges[:-1], analyzer_list[0].attn_bin_edges[1:]
attn_plot_bin_centers = plot_bin_lefts + (plot_bin_rights - plot_bin_lefts) / 2

label_loc = [0.08, 0.30]
attn_title = "Response Attenuation"
attn_xlabel = "Curvature of Response Attenuation"

In [None]:
## Two-panel curvature figure (iso-response | response attenuation), saved once per model
contour_angle = 210

full_hist_list = [iso_hist_list, attn_hist_list]
full_label_list = [label_list for _ in range(2)]
full_color_list = [color_list for _ in range(2)]
full_bin_centers = [iso_plot_bin_centers, attn_plot_bin_centers]
full_title = [iso_title, attn_title]
full_xlabel = [iso_xlabel, attn_xlabel]

# text placement passed to nc.plot_curvature_histograms -- NOTE(review): confirm coordinate meaning
iso_resp_loc = [0, 180, 0.42]
resp_att_loc = [100, 240, 0.38]
contour_text_loc = [iso_resp_loc, resp_att_loc]

double_width_figsize = (2*figsize[0], figsize[1])
curvature_fig = nc.plot_curvature_histograms(contour_activity, contour_pts, contour_angle, contour_text_loc,
  full_hist_list, full_label_list, full_color_list, full_bin_centers, full_title, full_xlabel,
  figsize=double_width_figsize, dpi=dpi, fontsize=fontsize)

save_exts = [".png"]  # ".eps" output currently disabled
for analyzer in analyzer_list:
  for ext in save_exts:
    save_name = "".join([analyzer.analysis_out_dir, "/vis/curvatures_and_histograms_",
      analyzer.analysis_params.save_info, ext])
    curvature_fig.savefig(save_name, transparent=False, bbox_inches="tight", pad_inches=0.01, dpi=dpi)

### Orientation Selectivity

In [None]:
## Build and set up the three 768-unit models compared in the orientation-selectivity figures
params_list = [rica_768_vh_params(), sae_768_vh_params(), lca_768_vh_params()]
params_list[-1].display_name = "Sparse Coding"  # drop the unit count from the LCA display name
projects_root = os.path.expanduser("~")+"/Work/Projects/"
for params in params_list:
  params.model_dir = projects_root+params.model_name
analyzer_list = [ap.get_analyzer(model_params.model_type) for model_params in params_list]
for params, analyzer in zip(params_list, analyzer_list):
  # configure the analyzer and its model, then attach the saved analysis outputs
  analyzer.setup(params)
  analyzer.model.setup(analyzer.model_params)
  analyzer.load_analysis(save_info=params.save_info)
  analyzer.model_name = params.model_name

In [None]:
def plot_contrast_orientation_tuning(bf_indices, contrasts, orientations, activations, figsize=(32,32),
  plot_all_contrasts=False):
  """
  Generate orientation tuning curves, one subplot per neuron.

  By default only the last contrast index of `activations` is drawn, in the
  color for contrast 1.0. Set plot_all_contrasts=True to draw one curve per
  entry of `contrasts`, each shaded by its contrast value.
  Inputs:
    bf_indices: [list or array] of neuron indices to use
      all indices should be less than activations.shape[0]
    contrasts: [list or array] of contrasts to use (values in [0, 1] for the color scale)
    orientations: [list or array] of orientations to use, in radians
      (converted to degrees for plotting)
    activations: [np.ndarray] indexed as activations[neuron, contrast_index, orientation_index]
    figsize: [tuple] figure size
    plot_all_contrasts: [bool] draw every contrast instead of only the highest one
  Outputs:
    fig: the matplotlib figure (also displayed with plt.show())
  """
  orientations = np.asarray(orientations)*(180/np.pi) #convert to degrees for plotting
  num_bfs = np.asarray(bf_indices).size
  cmap = plt.get_cmap('Greys')
  cNorm = matplotlib.colors.Normalize(vmin=0.0, vmax=1.0)
  scalarMap = matplotlib.cm.ScalarMappable(norm=cNorm, cmap=cmap)
  # (contrast index, contrast value used for the color) pairs to draw in every subplot
  if plot_all_contrasts:
    curve_specs = list(enumerate(contrasts))
  elif len(contrasts) > 0:
    curve_specs = [(-1, 1.0)]
  else:
    curve_specs = []
  fig = plt.figure(figsize=figsize)
  num_plots_y = np.int32(np.ceil(np.sqrt(num_bfs)))+1
  num_plots_x = np.int32(np.ceil(np.sqrt(num_bfs)))
  gs_widths = [1.0,]*num_plots_x
  gs_heights = [1.0,]*num_plots_y
  gs = gridspec.GridSpec(num_plots_y, num_plots_x, wspace=0.5, hspace=0.7,
    width_ratios=gs_widths, height_ratios=gs_heights)
  bf_idx = 0
  for plot_id in np.ndindex((num_plots_y, num_plots_x)):
    ax = fig.add_subplot(gs[plot_id])
    if bf_idx < num_bfs:
      curve_maxes = []
      for co_idx, contrast in curve_specs:
        activity = activations[bf_indices[bf_idx], co_idx, :]
        color_val = scalarMap.to_rgba(contrast)
        ax.plot(orientations, activity, linewidth=1, color=color_val)
        ax.scatter(orientations, activity, s=4, c=[color_val])
        curve_maxes.append(np.max(activity))
      if curve_maxes:
        ax.yaxis.set_major_formatter(FormatStrFormatter('%0.2g'))
        ax.set_yticks([0, max(curve_maxes)])
        ax.set_xticks([0, 90, 180])
      bf_idx += 1
    else:
      ax = pf.clear_axis(ax, spines="none")
  plt.show()
  return fig

In [None]:
def plot_weights(weights, title="", figsize=None, save_filename=None):
  """
  Plot every weight matrix as a greyscale image on a near-square grid of subplots.

  Inputs:
    weights: [np.ndarray] of shape [num_outputs, num_input_y, num_input_x]
      The matrices are renormalized with dp.norm_weights before plotting, and
      all subplots share the same color limits.
    title: [str] figure suptitle
    figsize: [tuple or None] passed to plt.subplots
    save_filename: [str or None] if given, the figure is saved there and closed
  Outputs:
    fig: the matplotlib figure, or None when save_filename is given
  """
  weights = dp.norm_weights(weights)
  vmin = np.min(weights)
  vmax = np.max(weights)
  num_plots = weights.shape[0]
  num_plots_y = max(int(np.floor(np.sqrt(num_plots))), 1)
  # Derive the column count from the row count (integer ceil division) so the grid
  # always fits every matrix; ceil(sqrt(n)) columns can be too few (e.g. n=7 -> 2x3).
  num_plots_x = max(-(-num_plots // num_plots_y), 1)
  # squeeze=False keeps sub_ax 2-D even for a single row/column, so the
  # (y, x) tuple indexing below always works.
  fig, sub_ax = plt.subplots(num_plots_y, num_plots_x, figsize=figsize, squeeze=False)
  filter_total = 0
  for plot_id in np.ndindex((num_plots_y, num_plots_x)):
    if filter_total < num_plots:
      sub_ax[plot_id].imshow(np.squeeze(weights[filter_total, ...]), vmin=vmin, vmax=vmax, cmap="Greys_r")
      filter_total += 1
    pf.clear_axis(sub_ax[plot_id])
    sub_ax[plot_id].set_aspect("equal")
  fig.suptitle(title, y=0.95, x=0.5, fontsize=20)
  if save_filename is not None:
    fig.savefig(save_filename)
    plt.close(fig)
    return None
  plt.show()
  return fig

In [None]:
def center_curve(tuning_curve):
  """
  Circularly shift a tuning curve so its maximum lands at index len(curve) // 2
  """
  peak_index = np.argmax(tuning_curve)
  shift = len(tuning_curve) // 2 - peak_index
  return np.roll(tuning_curve, shift)

In [None]:
def compute_fwhm(centered_ot_curve, corresponding_angles_deg):
  """
  Calculates the full width at half maximum of the tuning curve

  Result is expressed in degrees to make it a little more intuitive. The curve
  is often NOT symmetric about the maximum value so we don't do any fitting and
  we return the FULL width. The half-max level is taken midway between the
  curve's minimum and maximum. If one side never drops to that level before
  the end of the array, that side is reported as -90 (left) or +90 (right)
  degrees, meaning the width is *at least* that large. A completely flat curve
  has no half-max crossing and is reported as (-90, 90, midpoint).

  Parameters
  ----------
  centered_ot_curve : ndarray
      A 1d array of floats giving the value of the ot curve, at an orientation
      relative to the *preferred orientation* which is given by the angles in
      corresponding_angles_deg. This has the maximum orientation in the
      center of the array which is nicer for visualization.
  corresponding_angles_deg : ndarray
      The orientations relative to preferred orientation that correspond to
      the values in centered_ot_curve

  Returns
  -------
  half_max_left : float
      The position of the intercept to the left of the max
  half_max_right : float
      The position of the intercept to the right of the max
  half_max_value : float
      Mainly for plotting purposes, the actual curve value that corresponds
      to the left and right points
  """
  centered_ot_curve = np.asarray(centered_ot_curve)
  corresponding_angles_deg = np.asarray(corresponding_angles_deg)
  max_idx = np.argmax(centered_ot_curve)
  min_idx = np.argmin(centered_ot_curve)
  max_val = centered_ot_curve[max_idx]
  min_val = centered_ot_curve[min_idx]
  midpoint = (max_val / 2) + (min_val / 2)
  if max_val == min_val:
    # Flat curve: the interpolation below would divide by zero (and the right
    # side would wrap around to index -1), so report the maximal width.
    return -90., 90., midpoint
  # find the left hand point
  idx = max_idx
  while centered_ot_curve[idx] > midpoint:
    idx -= 1
    if idx == -1:
      # the width is *at least* 90 degrees
      half_max_left = -90.
      break
  if idx > -1:
    # we'll linearly interpolate between the two straddling points
    # if (x2, y2) is the coordinate of the point below the half-max and
    # (x1, y1) is the point above the half-max, then we can solve for x3, the
    # x-position of the point that corresponds to the half-max on the line
    # that connects (x1, y1) and (x2, y2)
    half_max_left = (((midpoint - centered_ot_curve[idx])
      * (corresponding_angles_deg[idx+1] - corresponding_angles_deg[idx])
      / (centered_ot_curve[idx+1] - centered_ot_curve[idx]))
      + corresponding_angles_deg[idx])
  # find the right hand point
  idx = max_idx
  while centered_ot_curve[idx] > midpoint:
    idx += 1
    if idx == len(centered_ot_curve):
      # the width is *at least* 90
      half_max_right = 90.
      break
  if idx < len(centered_ot_curve):
    # we'll linearly interpolate between the two straddling points again
    half_max_right = (((midpoint - centered_ot_curve[idx-1])
      * (corresponding_angles_deg[idx] - corresponding_angles_deg[idx-1])
      / (centered_ot_curve[idx] - centered_ot_curve[idx-1]))
      + corresponding_angles_deg[idx-1])
  return half_max_left, half_max_right, midpoint

In [None]:
def compute_circ_var(centered_ot_curve, corresponding_angles_rad):
  """
  From
  DL Ringach, RM Shapley, MJ Hawken (2002) - Orientation Selectivity in Macaque V1:
  Diversity and Laminar Dependence

  Computes the circular variance of a tuning curve and returns values for plotting.

  Circular variance is a scale-invariant measure of how 'oriented' a curve is
  in a global sense. Each response is placed on the unit circle at twice its
  orientation angle, the resulting vectors are summed, and the sum is divided
  by the total response. The magnitude of this average vector indicates tuning
  strength; circular variance = 1 - magnitude, which lies in [0.0, 1.0] for
  non-negative responses (0.0 for a delta function, 1.0 for a flat curve).

  Parameters
  ----------
  centered_ot_curve : ndarray
      A 1d array of floats giving the value of the ot curve, at an orientation
      relative to the *preferred orientation* which is given by the angles in
      corresponding_angles_rad. This has the maximum orientation in the
      center of the array which is nicer for visualization.
  corresponding_angles_rad : ndarray
      The orientations relative to preferred orientation that correspond to
      the values in centered_ot_curve

  Returns
  -------
  numerator_sum_components : ndarray
      The complex values r * np.exp(j*2*theta) that are summed in the numerator
  direction_vector : complex64 or complex128
      The vector pointing in the direction of *aggregate* tuning. Its magnitude
      is 1.0 when only one orientation has a nonzero value, so plotting it
      shows how tuned a curve is
  circular_variance : float
      1 minus the magnitude of the direction vector; an index of
      'global selectivity'
  """
  # Doubling the angle wraps orientations that differ by pi radians onto the
  # same point of the complex unit circle. Responses are redundant at pi
  # offsets, so the orientation range [0, pi) covers the circle exactly once.
  weighted_phasors = centered_ot_curve * np.exp(2j * corresponding_angles_rad)
  total_response = np.sum(centered_ot_curve)
  direction_vector = np.sum(weighted_phasors) / total_response
  circular_variance = 1 - np.abs(direction_vector)
  return weighted_phasors, direction_vector, circular_variance

In [None]:
def compute_osi(centered_ot_curve):
  """
  Compute the Orientation Selectivity Index.

  This is the coarsest but most popular measure of selectivity: it only
  compares the maximum response to the response at the orthogonal orientation.

  Parameters
  ----------
  centered_ot_curve : ndarray
      A 1d array of floats giving the value of the ot curve, at an orientation
      relative to the *preferred orientation*

  Returns
  -------
  osi : float
      (a_max - a_orth) / (a_max + a_orth), where a_max is the maximum
      response across orientations (responses are *averages* over phase) and
      a_orth is the response at the orientation orthogonal to the one
      producing a_max.
  """
  peak_response = np.max(centered_ot_curve)
  # The curve is centered on its peak, so the orthogonal orientation is taken
  # to be the first sample (within one bin for an even number of orientations)
  orthogonal_response = centered_ot_curve[0]
  return (peak_response - orthogonal_response) / (peak_response + orthogonal_response)

In [None]:
def plot_circular_variance(cv_data, max_bfs_per_fig=400, title="", save_filename=None, figsize=(32, 32)):
  """
  Plot the circular-variance construction for each neuron on square grids of subplots.

  Each subplot shows the response components on the complex plane (green),
  the aggregate direction vector (blue arrow), and the circular variance value.
  Inputs:
    cv_data: [sequence] one entry per neuron, indexed as [0] complex components,
      [1] complex direction vector, [2] circular variance (the tuple layout
      returned by compute_circ_var)
    max_bfs_per_fig: [int] square number; maximum subplots per figure
    title: [str] figure title; ", fig i of n" is appended
    save_filename: [str or None] if given, each figure is saved with a two-digit
      figure-index prefix on the file name (same directory) and closed
    figsize: [tuple] size of each figure
  Outputs:
    bf_figs: [list] of figures, or of None entries when save_filename is given
  """
  assert np.sqrt(max_bfs_per_fig) % 1 == 0, "Pick a square number for max_bfs_per_fig"
  num_bfs = len(cv_data)
  num_bf_figs = int(np.ceil(num_bfs / max_bfs_per_fig))
  # this determines how many ot curves are arranged in a square grid within
  # any given figure
  if num_bf_figs > 1:
    bfs_per_fig = max_bfs_per_fig
  else:
    # smallest square grid that fits every neuron
    squares = [x**2 for x in range(1, int(np.sqrt(max_bfs_per_fig))+1)]
    bfs_per_fig = squares[bisect.bisect_left(squares, num_bfs)]
  plot_sidelength = int(np.sqrt(bfs_per_fig))
  bf_idx = 0
  bf_figs = []
  for in_bf_fig_idx in range(num_bf_figs):
    fig = plt.figure(figsize=figsize)
    fig.suptitle(title + ', fig {} of {}'.format(
      in_bf_fig_idx+1, num_bf_figs), fontsize=20)
    subplot_grid = gridspec.GridSpec(plot_sidelength, plot_sidelength,
      wspace=0.4, hspace=0.4)
    fig_bf_idx = bf_idx % bfs_per_fig
    while fig_bf_idx < bfs_per_fig and bf_idx < num_bfs:
      components, direction_vector, circ_var = cv_data[bf_idx][0], cv_data[bf_idx][1], cv_data[bf_idx][2]
      ax = fig.add_subplot(subplot_grid[fig_bf_idx])
      ax.plot(np.real(components), np.imag(components), c='g', linewidth=0.5)
      ax.scatter(np.real(components), np.imag(components), c='g', s=4)
      ax.quiver(np.real(direction_vector), np.imag(direction_vector),
                angles='xy', scale_units='xy', scale=1.0, color='b',
                width=0.01)
      ax.axvline(x=0.0, color='k', linestyle='--', alpha=0.6, linewidth=0.3)
      ax.axhline(y=0.0, color='k', linestyle='--', alpha=0.6, linewidth=0.3)
      ax.yaxis.set_major_formatter(FormatStrFormatter('%0.2g'))
      xaxis_size = max(np.max(np.real(components)), 1.0)
      yaxis_size = max(np.max(np.imag(components)), 1.0)
      ax.set_yticks([-1. * yaxis_size, yaxis_size])
      ax.set_xticks([-1. * xaxis_size, xaxis_size])
      # put the circular variance index in the upper left
      ax.text(0.02, 0.97, 'CV: {:.2f}'.format(circ_var),
              horizontalalignment='left', verticalalignment='top',
              transform=ax.transAxes, color='b', fontsize=10)
      fig_bf_idx += 1
      bf_idx += 1
    if save_filename is not None:
      # Build each figure's name from the caller's path (never the previous
      # figure's name) and keep it inside the requested directory.
      save_dir, save_base = os.path.split(save_filename)
      fig_filename = os.path.join(save_dir, str(in_bf_fig_idx).zfill(2)+"_"+save_base)
      fig.savefig(fig_filename)
      plt.close(fig)
      bf_figs.append(None)
    else:
      bf_figs.append(fig)
  if save_filename is None:
    plt.show()
  return bf_figs

In [None]:
def plot_circular_variance_histogram(variances_list, label_list, num_bins=50, y_max=None,
  figsize=None, save_filename=None):
  """Plot overlaid log-scale bar histograms of circular variance, one per model.

  All histograms share the same bins, spanning the global min to the global max of the
  inputs, so bar heights are directly comparable across models.

  Args:
    variances_list: sequence of array-likes of circular variance values, one per model.
      Each entry is flattened before histogramming.
    label_list: legend label for each entry of variances_list (the two are zipped, so
      surplus entries on either side are ignored).
    num_bins: number of bin *edges* passed to np.linspace, i.e. num_bins - 1 bars.
    y_max: upper y-limit. If None, the tallest bar count rounded up to the next power
      of 10 (at least 10, so the log-scaled y-range is never degenerate).
    figsize: forwarded to plt.subplots.
    save_filename: if not None, save the figure to this path and close it.

  Returns:
    The Figure, or None if it was saved to save_filename.

  Raises:
    ValueError: if variances_list is empty.
  """
  variances_list = [np.asarray(variances).ravel() for variances in variances_list]
  if not variances_list:
    raise ValueError("variances_list must contain at least one array of variances")
  variance_min = np.min([np.min(var) for var in variances_list])
  variance_max = np.max([np.max(var) for var in variances_list])
  bins = np.linspace(variance_min, variance_max, num_bins)
  bar_width = np.diff(bins).min()
  # np.histogram returns the explicit bin array unchanged, so derive the geometry once
  bin_left, bin_right = bins[:-1], bins[1:]
  bin_centers = bin_left + (bin_right - bin_left) / 2
  fig, ax = plt.subplots(1, figsize=figsize)
  hist_list = []
  handles = []
  for variances, label in zip(variances_list, label_list):
    hist, _ = np.histogram(variances, bins)
    hist_list.append(hist)
    handles.append(ax.bar(bin_centers, hist, width=bar_width, log=True, align="center",
      alpha=0.5, label=label))
  # Minor ticks mark the bin edges; the only major ticks are the two ends of the range,
  # which get qualitative labels instead of numbers.
  ax.set_xticks(bin_left, minor=True)
  ax.tick_params("both", labelsize=16)
  ax.set_xlim([variance_min, variance_max])
  ax.set_xticks([variance_min, variance_max])
  ax.set_xticklabels(["More selective", "Less selective"])
  # Align the end labels inward so they stay within the axes
  ticks = ax.xaxis.get_major_ticks()
  ticks[0].label1.set_horizontalalignment("left")
  ticks[1].label1.set_horizontalalignment("right")
  if y_max is None:
    # Round up to the nearest power of 10; floor it at 10 because a max count of 1
    # would otherwise give y_max == 1 and an empty [1, 1] log range
    max_count = np.max([np.max(hist) for hist in hist_list])
    y_max = max(10.0, 10**np.ceil(np.log10(max_count)))
  ax.set_ylim([1, y_max])
  ax.set_title("Circular Variance Histogram", fontsize=18)
  ax.set_xlabel("Selectivity", fontsize=18)
  ax.set_ylabel("Log Count", fontsize=18)
  ax.legend(handles, label_list, fontsize=12,
    borderaxespad=0., bbox_to_anchor=[0.98, 0.98], fancybox=True, loc="upper right")
  if save_filename is not None:
    fig.savefig(save_filename)
    plt.close(fig)
    return None
  plt.show()
  return fig

In [None]:
# Orientation-tuning metrics (FWHM, circular variance, OSI) for every neuron of every
# analyzer, plus a random set of example neurons (analyzer.bf_indices) that the combined
# figure cell below plots.
num_example_neurons = 12
# Last entry along the contrast axis of "mean_responses"
# (presumably the highest contrast -- confirm the ordering of "contrasts")
contrast_idx = -1
circ_var_list = []
for analyzer in analyzer_list:
  neuron_indices = analyzer.ot_grating_responses["neuron_indices"]
  # Sample without replacement so no neuron appears twice in the example panels;
  # cap the sample size so replace=False cannot fail on a small candidate pool
  analyzer.bf_indices = np.random.choice(neuron_indices,
    min(num_example_neurons, len(neuron_indices)), replace=False)
  num_orientation_samples = len(analyzer.ot_grating_responses['orientations'])
  # Evenly spaced angles covering [-90, 90) degrees, in degrees and in radians
  corresponding_angles_deg = (180 * np.arange(num_orientation_samples) / num_orientation_samples) - 90
  corresponding_angles_rad = (np.pi * np.arange(num_orientation_samples) / num_orientation_samples) - (np.pi/2)
  # Neurons whose tuning curve is flat are recorded in "skipped_indices" and get no
  # metric entries, so positions in the metric lists are NOT neuron indices.
  analyzer.metrics_list = {"fwhm":[], "circ_var":[], "osi":[], "skipped_indices":[]}
  for bf_idx in range(analyzer.bf_stats["num_outputs"]):
    ot_curve = center_curve(analyzer.ot_grating_responses["mean_responses"][bf_idx, contrast_idx, :])
    if np.max(ot_curve) - np.min(ot_curve) == 0:
      analyzer.metrics_list["skipped_indices"].append(bf_idx)
      continue
    fwhm = compute_fwhm(ot_curve, corresponding_angles_deg)
    analyzer.metrics_list["fwhm"].append(fwhm)
    circ_var = compute_circ_var(ot_curve, corresponding_angles_rad)
    analyzer.metrics_list["circ_var"].append(circ_var)
    osi = compute_osi(ot_curve)
    analyzer.metrics_list["osi"].append(osi)
  # Element 2 of each circular-variance result is the scalar value shown as "CV" in plots
  circ_var_list.append(np.array([val[2] for val in analyzer.metrics_list["circ_var"]]))

In [None]:
# Combined circular-variance figure: per-model circular variance histogram (left column),
# example neuron weights (middle column) and their orientation tuning curves (right column).
color_list = [color_vals["md_green"], color_vals["md_blue"], color_vals["md_red"]]
label_list = ["Linear Autoencoder", "Sparse Autoencoder", "Sparse Coding"]
num_bins = 30
width_ratios = [0.5, 0.25, 0.25]
fig = plt.figure(figsize=figsize, dpi=dpi)
gs0 = gridspec.GridSpec(1, 3, width_ratios=width_ratios)
axes = []

# Left column: histogram spanning the two middle rows of a 4-row sub-grid
height_ratios = [0.13, 0.25, 0.25, 0.25]
gs_hist = gridspec.GridSpecFromSubplotSpec(4, 1, gs0[0], height_ratios=height_ratios)
axes.append(fig.add_subplot(gs_hist[1:3, 0]))
variance_min = 0.0
variance_max = 1.0
bins = np.linspace(variance_min, variance_max, num_bins)
# np.histogram returns the explicit bin array unchanged, so derive the geometry once
bin_left, bin_right = bins[:-1], bins[1:]
bin_centers = bin_left + (bin_right - bin_left)/2
hist_list = []
for variances, label, color in zip(circ_var_list, label_list, color_list):
  hist, _ = np.histogram(variances.flatten(), bins)
  hist_list.append(hist)
  axes[-1].plot(bin_centers, hist, linestyle="-", drawstyle="steps-mid", color=color, label=label)
# Minor ticks mark the bin edges; the only major ticks are the two ends of the range,
# which get qualitative labels instead of numbers.
axes[-1].set_xticks(bin_left, minor=True)
axes[-1].tick_params("both", labelsize=fontsize)
axes[-1].set_xlim([variance_min, variance_max])
axes[-1].set_xticks([variance_min, variance_max])
axes[-1].set_xticklabels(["More\nselective", "Less\nselective"])
ticks = axes[-1].xaxis.get_major_ticks()
ticks[0].label1.set_horizontalalignment("left")
ticks[1].label1.set_horizontalalignment("right")
y_max = np.max([np.max(hist) for hist in hist_list])
axes[-1].set_ylim([0, y_max+1])
axes[-1].set_title("Circular Variance", fontsize=fontsize)
axes[-1].set_ylabel("Count", fontsize=fontsize)
handles, labels = axes[-1].get_legend_handles_labels()
legend = axes[-1].legend(handles=handles, labels=labels, fontsize=fontsize,
  borderaxespad=0., framealpha=0.0, loc="upper right")
legend.get_frame().set_linewidth(0.0)
# Color-code the legend text and hide the line samples
for text, color in zip(legend.get_texts(), color_list):
  text.set_color(color)
# `legendHandles` was renamed `legend_handles` in matplotlib 3.7 and removed in 3.9
legend_handles = getattr(legend, "legend_handles", None)
if legend_handles is None:
  legend_handles = legend.legendHandles
for item in legend_handles:
  item.set_visible(False)

# Middle column: weights of the example neurons, one square sub-grid per analyzer
gs_weights = gridspec.GridSpecFromSubplotSpec(len(analyzer_list), 1, gs0[1], hspace=-0.6)
for gs_idx, analyzer in enumerate(analyzer_list):
  weights = np.stack(analyzer.bf_stats["basis_functions"], axis=0)[analyzer.bf_indices, ...]
  weights = dp.norm_weights(weights)
  vmin = np.min(weights)
  vmax = np.max(weights)
  num_plots = weights.shape[0]
  num_plots_y = int(np.ceil(np.sqrt(num_plots)))
  num_plots_x = int(np.ceil(np.sqrt(num_plots)))
  gs_weights_inner = gridspec.GridSpecFromSubplotSpec(num_plots_y, num_plots_x, gs_weights[gs_idx],
    hspace=-0.85)
  # Grid cells beyond num_plots are left empty
  for bf_idx, plot_id in zip(range(num_plots), np.ndindex((num_plots_y, num_plots_x))):
    axes.append(fig.add_subplot(gs_weights_inner[plot_id]))
    axes[-1].imshow(np.squeeze(weights[bf_idx, ...]), vmin=vmin, vmax=vmax, cmap="Greys_r")
    pf.clear_axis(axes[-1])

# Right column: orientation tuning curves of the same example neurons
gs_tuning = gridspec.GridSpecFromSubplotSpec(len(analyzer_list), 1, gs0[2], hspace=-0.6)
contrast_idx = -1
for analyzer_idx, analyzer in enumerate(analyzer_list):
  orientations = analyzer.ot_grating_responses["orientations"]
  activations = analyzer.ot_grating_responses["mean_responses"]
  # Scale by the peak response among the plotted neurons at the last contrast index
  activations = activations / np.max(activations[analyzer.bf_indices, -1, ...])
  orientations = np.asarray(orientations)*(180/np.pi) #convert to degrees for plotting
  orientations = orientations / np.max(orientations)
  num_plots = len(analyzer.bf_indices)
  num_plots_y = int(np.ceil(np.sqrt(num_plots)))
  num_plots_x = int(np.ceil(np.sqrt(num_plots)))
  gs_tuning_inner = gridspec.GridSpecFromSubplotSpec(num_plots_y, num_plots_x, gs_tuning[analyzer_idx],
      hspace=-0.85)
  for bf_idx, plot_id in zip(range(num_plots), np.ndindex((num_plots_y, num_plots_x))):
    if bf_idx == 0:
      axes.append(fig.add_subplot(gs_tuning_inner[plot_id]))
      ax_orig = axes[-1]
    else:
      # Share limits with this analyzer's first panel so curves are comparable
      axes.append(fig.add_subplot(gs_tuning_inner[plot_id], sharey=ax_orig, sharex=ax_orig))
    activity = activations[analyzer.bf_indices[bf_idx], contrast_idx, :]
    axes[-1].plot(orientations, activity, linewidth=0.5, color='k')
    axes[-1].scatter(orientations, activity, s=0.1, c='k')
    axes[-1].set_aspect('equal', adjustable='box')
    axes[-1].set_yticks([])
    axes[-1].set_xticks([])
    if bf_idx == 0:
      # Model name above the top-left tuning panel of this analyzer
      axes[-1].text(x=0.1, y=1.4, s=analyzer.analysis_params.display_name, horizontalalignment='center',
        verticalalignment='center', transform=axes[-1].transAxes, fontsize=fontsize)

plt.show()

for analyzer in analyzer_list:
  for ext in [".png", ".eps"]:
    save_name = (analyzer.analysis_out_dir+"/vis/circular_variance_combo"
      +"_"+analyzer.analysis_params.save_info+ext)
    fig.savefig(save_name, transparent=False, bbox_inches="tight", pad_inches=0.01, dpi=dpi)

In [None]:
# Density histogram of each neuron's spatial frequency, taken as the radial distance of
# its Fourier center, with one step curve per analyzer on shared bins.
min_rads = []
max_rads = []
for analyzer in analyzer_list:
  analyzer.bf_spatial_freq_rads = [np.sqrt(x**2+y**2) for (y,x) in analyzer.bf_stats["fourier_centers"]]
  min_rads.append(np.min(analyzer.bf_spatial_freq_rads))
  max_rads.append(np.max(analyzer.bf_spatial_freq_rads))

num_bins = 10
min_rad = np.min(min_rads)
max_rad = np.max(max_rads)
bins = np.linspace(min_rad, max_rad, num_bins)
# np.histogram returns the explicit bin array unchanged, so derive the geometry once
bin_left, bin_right = bins[:-1], bins[1:]
bin_centers = bin_left + (bin_right - bin_left)/2
fig, ax = plt.subplots(1, figsize=figsize, dpi=dpi)
hist_max = []
for analyzer in analyzer_list:
  hist, _ = np.histogram(analyzer.bf_spatial_freq_rads, bins, density=True)
  label = analyzer.model.params.model_type.upper()
  ax.plot(bin_centers, hist, alpha=1.0, linestyle="-", drawstyle="steps-mid", label=label)
  hist_max.append(np.max(hist))

# Minor ticks mark the bin edges
ax.set_xticks(bin_left, minor=True)
ax.xaxis.set_major_formatter(FormatStrFormatter("%0.0f"))
ax.tick_params("both", labelsize=16)
ax.set_xlim([min_rad, max_rad])
# Major ticks at the floored quarter points of max_rad (so the integer labels are exact),
# plus max_rad itself; flooring each quarter keeps the spacing even for non-integer max_rad
ax.set_xticks([int(np.floor(quarter * max_rad / 4)) for quarter in range(4)] + [max_rad])
# Keep the usual 0.6 ceiling, but raise it when a density curve would otherwise be clipped
ax.set_ylim([0, max(0.6, 1.05 * np.max(hist_max))])
ax.set_xlabel("Spatial Frequency", fontsize=fontsize)
ax.set_ylabel("Density", fontsize=fontsize)
ax.set_title("Neuron Weight Spatial Frequency Histogram", fontsize=fontsize)
handles, labels = ax.get_legend_handles_labels()
legend = ax.legend(handles, labels, fontsize=fontsize, ncol=3,
  borderaxespad=0., bbox_to_anchor=[0.01, 0.99], fancybox=True, loc="upper left")
# Thicker legend lines for readability
for line in legend.get_lines():
  line.set_linewidth(3)
plt.show()