## Imports

In [None]:
import os
# Move up one directory, presumably to the repository root so that the project
# packages imported in the next cell (data, utils, analysis) resolve.
# NOTE(review): the path is relative, so re-running this cell moves up again.
os.chdir("../")
# IPython magics (not plain Python): expose only CUDA device 1 to this process
# and render matplotlib figures inline in the notebook.
%env CUDA_VISIBLE_DEVICES=1
%matplotlib inline

In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from skimage.measure import compare_psnr
import tensorflow as tf
from data.dataset import Dataset
import data.data_selector as ds
import utils.data_processing as dp
import utils.plot_functions as pf
import analysis.analysis_picker as ap

In [None]:
class _analysis_params(object):
  """
  Shared constructor for the analysis parameter sets below.

  Every parameter set carries the same five instance attributes, assigned in
  this order: model_type, model_name, version, save_info,
  overwrite_analysis_log. Subclasses only pick the model-specific values.
  """
  def __init__(self, model_type, model_name, save_info, version="0.0",
    overwrite_analysis_log=False):
    self.model_type = model_type
    self.model_name = model_name
    self.version = version
    self.save_info = save_info
    self.overwrite_analysis_log = overwrite_analysis_log

class lca_512_params(_analysis_params):
  def __init__(self):
    super(lca_512_params, self).__init__("lca", "lca_512_vh",
      "analysis_train_carlini_targeted")

class lca_768_params(_analysis_params):
  def __init__(self):
    super(lca_768_params, self).__init__("lca", "lca_768_vh",
      "analysis_train_carlini_targeted")

class lca_1024_params(_analysis_params):
  def __init__(self):
    super(lca_1024_params, self).__init__("lca", "lca_1024_vh",
      "analysis_train_carlini_targeted")

class ica_params(_analysis_params):
  def __init__(self):
    super(ica_params, self).__init__("ica", "ica_vh",
      "analysis_train_carlini_targeted")

class ae_params(_analysis_params):
  def __init__(self):
    # NOTE(review): loads the *test* carlini analysis, unlike the other models -- confirm intended.
    super(ae_params, self).__init__("ae", "ae_768_mnist",
      "analysis_test_carlini_targeted")

class sae_params(_analysis_params):
  def __init__(self):
    # NOTE(review): loads the kurakin (not carlini) analysis -- confirm intended.
    super(sae_params, self).__init__("sae", "sae_768_vh",
      "analysis_train_kurakin_targeted")

# Models compared in the rest of the notebook; each one's checkpoints live in
# ~/Work/Projects/<model_name>.
params_list = [lca_768_params(), ica_params(), sae_params()]
for params in params_list:
  params.model_dir = "{}/Work/Projects/{}".format(os.path.expanduser("~"), params.model_name)

In [None]:
# Create every analyzer first (one per parameter set, same order as params_list) ...
analyzer_list = [ap.get_analyzer(model_params.model_type) for model_params in params_list]
# ... then configure each one and load the analysis saved under its save_info tag.
for params, analyzer in zip(params_list, analyzer_list):
  analyzer.setup(params)
  analyzer.model.setup(analyzer.model_params)
  analyzer.load_analysis(save_info=params.save_info)
  analyzer.model_name = params.model_name

In [None]:
# Render one grating per orientation for a single neuron of the first analyzer
# and save each one as a PNG under <analysis_out_dir>/vis/orientation_stims/.
# NOTE(review): linspace includes both 0 and pi, so the first and last stimuli
# share an orientation -- confirm this endpoint duplication is intended.
orientations = np.linspace(0, np.pi, 10)

bf_stats = analyzer_list[0].bf_stats
neuron_idx = 4
phase = 0.0
contrast = 0.5
diameter = -1

def grating(neuron_idx, contrast, orientation, phase):
  """Return the grating image built from the bf_stats of neuron `neuron_idx`."""
  return dp.generate_grating(*dp.get_grating_params(bf_stats, neuron_idx,
    orientation=orientation, phase=phase, contrast=contrast, diameter=diameter))

stims = [grating(neuron_idx, contrast, orientation, phase) for orientation in orientations]

stim_dir = analyzer_list[0].analysis_out_dir+"/vis/orientation_stims/"
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
os.makedirs(stim_dir, exist_ok=True)

for idx, stim in enumerate(stims):
  fig, ax = plt.subplots(1)
  ax = pf.clear_axis(ax)
  ax.imshow(stim, cmap="Greys_r")
  fig.savefig(stim_dir+"stim_"+str(idx).zfill(3)+".png")
  # close this specific figure rather than whichever one happens to be current
  plt.close(fig)

In [None]:
def plot_contrast_orientation_tuning(bf_indices, contrasts, orientations, activations,
  figsize=(32,32), contrast_indices=None):
  """
  Generate orientation tuning curves, one subplot per neuron.
  Inputs:
    bf_indices: [list or array] of neuron indices to use
      all indices should be less than activations.shape[0]
    contrasts: [list or array] of contrasts to use
    orientations: [list or array] of orientations to use, in radians
      (converted to degrees for the x axis)
    activations: [np.ndarray] indexed as activations[neuron, contrast, orientation]
    figsize: [tuple] matplotlib figure size
    contrast_indices: [list of int or None] indices into contrasts; one curve is
      drawn per entry, shaded by its contrast (the largest contrast is black).
      None (default) draws only the last contrast.
  Outputs:
    fig: the matplotlib figure (it is also displayed with plt.show())
  """
  # Local import: at notebook level FormatStrFormatter is only imported in a later cell
  from matplotlib.ticker import FormatStrFormatter
  if contrast_indices is None:
    contrast_indices = [-1]
  contrasts = np.asarray(contrasts)
  orientations = np.asarray(orientations)*(180/np.pi) #convert to degrees for plotting
  num_bfs = np.asarray(bf_indices).size
  # Normalize by the largest contrast so the strongest stimulus is drawn in black
  cNorm = matplotlib.colors.Normalize(vmin=0.0, vmax=np.max(contrasts))
  scalarMap = matplotlib.cm.ScalarMappable(norm=cNorm, cmap=plt.get_cmap('Greys'))
  fig = plt.figure(figsize=figsize)
  # NOTE(review): one row more than a square grid needs; kept to preserve the layout
  num_plots_y = np.int32(np.ceil(np.sqrt(num_bfs)))+1
  num_plots_x = np.int32(np.ceil(np.sqrt(num_bfs)))
  gs = gridspec.GridSpec(num_plots_y, num_plots_x, wspace=0.5, hspace=0.7)
  bf_idx = 0
  for plot_id in np.ndindex((num_plots_y, num_plots_x)):
    ax = fig.add_subplot(gs[plot_id])
    if bf_idx >= num_bfs:
      # surplus grid cells are left blank
      pf.clear_axis(ax, spines="none")
      continue
    neuron_activations = activations[bf_indices[bf_idx], ...]
    for co_idx in contrast_indices:
      activity = neuron_activations[co_idx, :]
      color_val = scalarMap.to_rgba(contrasts[co_idx])
      ax.plot(orientations, activity, linewidth=1, color=color_val)
      ax.scatter(orientations, activity, s=4, c=[color_val])
    # y ticks span every drawn curve, not just the last one
    peak_activity = max(np.max(neuron_activations[co_idx, :]) for co_idx in contrast_indices)
    ax.yaxis.set_major_formatter(FormatStrFormatter('%0.2g'))
    ax.set_yticks([0, peak_activity])
    ax.set_xticks([0, 90, 180])
    bf_idx += 1
  plt.show()
  return fig

In [None]:
# Plot and save orientation tuning curves for a random subset of each model's neurons.
num_example_neurons = 12
for analyzer in analyzer_list:
  candidate_indices = analyzer.ot_grating_responses["neuron_indices"]
  # Sample without replacement so every panel shows a different neuron
  analyzer.bf_indices = np.random.choice(candidate_indices,
    min(num_example_neurons, len(candidate_indices)), replace=False)
  ot_fig = plot_contrast_orientation_tuning(analyzer.bf_indices,
    analyzer.ot_grating_responses["contrasts"],
    analyzer.ot_grating_responses["orientations"],
    analyzer.ot_grating_responses["mean_responses"], figsize=(8,8))
  # The stimulus cell only creates vis/ for the first analyzer; ensure it exists for all
  vis_dir = analyzer.analysis_out_dir+"/vis/"
  os.makedirs(vis_dir, exist_ok=True)
  ot_fig.savefig(vis_dir+"orientation_tuning_sm.pdf")

In [None]:
# Quick look at the basis functions of the neurons sampled for the first analyzer.
first_analyzer = analyzer_list[0]
bleh = pf.plot_weights(
  np.stack(first_analyzer.bf_stats["basis_functions"], axis=0)[first_analyzer.bf_indices, ...],
  figsize=(5,8))

In [None]:
def plot_weights(weights, title="", figsize=None, save_filename=None):
  """
  Plot each weight matrix as a greyscale image on a near-square grid.
  Inputs:
    weights: [np.ndarray] of shape [num_outputs, num_input_y, num_input_x]
      The matrices are renormalized with dp.norm_weights before plotting and
      all images share one colour scale (the global min and max).
    title: [str] figure suptitle
    figsize: [tuple or None] matplotlib figure size
    save_filename: [str or None] if given, save the figure there and close it
  Outputs:
    fig: the matplotlib figure, or None when save_filename is given
  """
  weights = dp.norm_weights(weights)
  vmin = np.min(weights)
  vmax = np.max(weights)
  num_plots = weights.shape[0]
  num_plots_x = int(np.ceil(np.sqrt(num_plots)))
  # ceil(n / cols) rows always fits all n plots; floor(sqrt(n)) rows silently dropped
  # some (e.g. n=3, 7, 8). Both agree whenever the floor grid was big enough.
  num_plots_y = (num_plots + num_plots_x - 1) // num_plots_x
  # squeeze=False keeps sub_ax 2-D even for a single row/column, so tuple indexing works
  fig, sub_ax = plt.subplots(num_plots_y, num_plots_x, figsize=figsize, squeeze=False)
  for filter_idx, plot_id in enumerate(np.ndindex((num_plots_y, num_plots_x))):
    ax = sub_ax[plot_id]
    if filter_idx < num_plots:
      ax.imshow(np.squeeze(weights[filter_idx, ...]), vmin=vmin, vmax=vmax, cmap="Greys_r")
    pf.clear_axis(ax)
    ax.set_aspect("equal")
  fig.suptitle(title, y=0.95, x=0.5, fontsize=20)
  if save_filename is not None:
    fig.savefig(save_filename)
    plt.close(fig)
    return None
  plt.show()
  return fig

In [None]:
# Save an image grid of the sampled neurons' basis functions for every model.
for analyzer in analyzer_list:
  all_bfs = np.stack(analyzer.bf_stats["basis_functions"], axis=0)
  weights_fig = plot_weights(all_bfs[analyzer.bf_indices, ...], figsize=(7,5))
  weights_fig.savefig(analyzer.analysis_out_dir+"/vis/orientation_tuning_bfs.png")

In [None]:
def center_curve(tuning_curve):
  """
  Circularly shift a tuning curve so that its peak sits at index len(curve) // 2.

  Values wrap around the ends (np.roll). If several samples share the maximum,
  the first one (np.argmax) is the one that gets centered.
  """
  peak_idx = np.argmax(tuning_curve)
  middle_idx = len(tuning_curve) // 2
  return np.roll(tuning_curve, middle_idx - peak_idx)

def _half_max_crossing(x_a, y_a, x_b, y_b, level):
  """Interpolate the x where the line through (x_a, y_a) and (x_b, y_b) reaches y == level."""
  return ((level - y_a) * (x_b - x_a) / (y_b - y_a)) + x_a

def compute_fwhm(centered_ot_curve, corresponding_angles_deg):
  """
  Full width at half maximum of a centered orientation tuning curve.

  The half-max level is the midpoint between the curve's maximum and minimum.
  Starting at the peak, the curve is walked left and right until it first
  drops to or below that level. Each crossing is located by linear
  interpolation between the two samples that straddle it. No symmetry is
  assumed, so the two sides are reported separately and the full width is
  half_max_right - half_max_left. Angles are in degrees.

  Parameters
  ----------
  centered_ot_curve : ndarray
      1d array of tuning-curve values at orientations relative to the
      *preferred orientation* (see corresponding_angles_deg). The maximum is
      expected near the middle of the array, as produced by center_curve.
  corresponding_angles_deg : ndarray
      Orientation (degrees, relative to preferred) of each entry of
      centered_ot_curve.

  Returns
  -------
  half_max_left : float
      Angle of the crossing to the left of the peak, or -90. if the curve is
      still above half-max at the first sample (width is at least 90 degrees
      on that side).
  half_max_right : float
      Angle of the crossing to the right of the peak, or 90. if the curve is
      still above half-max at the last sample.
  half_max_value : float
      The half-max level itself, mainly useful for plotting.
  """
  curve = centered_ot_curve
  angles = corresponding_angles_deg
  num_samples = len(curve)
  peak_idx = np.argmax(curve)
  half_max_value = (curve[peak_idx] / 2) + (curve[np.argmin(curve)] / 2)

  # Left side: step toward index 0 until a sample is at or below half-max.
  left = peak_idx
  while left >= 0 and curve[left] > half_max_value:
    left -= 1
  if left < 0:
    half_max_left = -90.
  else:
    # curve[left] <= half-max < curve[left + 1]
    half_max_left = _half_max_crossing(angles[left], curve[left],
      angles[left + 1], curve[left + 1], half_max_value)

  # Right side: step toward the end until a sample is at or below half-max.
  right = peak_idx
  while right < num_samples and curve[right] > half_max_value:
    right += 1
  if right >= num_samples:
    half_max_right = 90.
  else:
    # curve[right - 1] > half-max >= curve[right]
    half_max_right = _half_max_crossing(angles[right - 1], curve[right - 1],
      angles[right], curve[right], half_max_value)

  return half_max_left, half_max_right, half_max_value

def compute_circ_var(centered_ot_curve, corresponding_angles_rad):
  """
  Circular variance of an orientation tuning curve, plus the pieces used to plot it.

  From
  DL Ringach, RM Shapley, MJ Hawken (2002) - Orientation Selectivity in Macaque V1:
  Diversity and Laminar Dependence

  Each response r_k is placed on the complex plane at angle 2*theta_k. The
  vectors are summed, and the sum is divided by sum(r_k). Circular variance is
  1 minus the magnitude of that normalized sum. It is a scale-invariant index
  of global selectivity in [0.0, 1.0] for non-negative responses. It is 0.0
  when only one orientation responds, and 1.0 for a flat curve whose
  orientations evenly sample a half circle.

  Parameters
  ----------
  centered_ot_curve : ndarray
      1d array of tuning-curve values at orientations relative to the
      *preferred orientation* (see corresponding_angles_rad), typically with
      the maximum in the middle of the array.
  corresponding_angles_rad : ndarray
      Orientation (radians, relative to preferred) of each entry of
      centered_ot_curve.

  Returns
  -------
  numerator_sum_components : ndarray
      The complex values r * np.exp(1j*2*theta) that are summed in the numerator.
  direction_vector : complex
      Sum of the components divided by the sum of responses. It points in the
      direction of the aggregate tuning. For non-negative responses its
      magnitude is at most 1.0, reached when only one orientation responds.
  circular_variance : float
      1 minus the magnitude of direction_vector.
  """
  # Orientation is pi-periodic, so angles are doubled before wrapping onto the
  # unit circle: orientations pi radians (180 degrees) apart land on the same point.
  # NOTE(review): negative responses or a zero response sum make this ill-defined
  # (division by sum(r); |direction_vector| can exceed 1).
  doubled_angle_phasors = np.exp(1j * 2 * corresponding_angles_rad)
  components = centered_ot_curve * doubled_angle_phasors
  direction = np.sum(components) / np.sum(centered_ot_curve)
  return (components, direction, 1 - np.abs(direction))

def compute_osi(centered_ot_curve):
  """
  Orientation Selectivity Index of a centered tuning curve.

  The coarsest (but most popular) selectivity measure: it only contrasts the
  peak response with the response to the orthogonal orientation.

  Parameters
  ----------
  centered_ot_curve : ndarray
      1d array of tuning-curve values at orientations relative to the
      *preferred orientation*, with the peak in the middle (see center_curve).

  Returns
  -------
  osi : float
      (a_max - a_orth) / (a_max + a_orth), where a_max is the largest value in
      the curve and a_orth is the first value of the curve, taken as the
      response to the orthogonal orientation.
  """
  peak_response = np.max(centered_ot_curve)
  # With the peak at index len // 2, index 0 is half the samples away. If the
  # orientations evenly sample a half circle, that is exactly orthogonal for an
  # even sample count, and half a bin short of orthogonal for an odd one.
  orthogonal_response = centered_ot_curve[0]
  return (peak_response - orthogonal_response) / (peak_response + orthogonal_response)

In [None]:
from matplotlib.ticker import FormatStrFormatter

def plot_circular_variance(cv_data, max_bfs_per_fig=400, title="", save_filename=None):
  """
  Plot how each neuron's circular variance is built, one small panel per neuron.

  Each panel shows the complex components r * exp(1j*2*theta) (green), the
  normalized direction vector (blue arrow) and the circular variance value.

  Inputs:
    cv_data: [list] of (numerator_sum_components, direction_vector, circular_variance)
      tuples, as returned by compute_circ_var
    max_bfs_per_fig: [int] maximum number of panels per figure; must be a perfect square
    title: [str] prefix for every figure's suptitle
    save_filename: [str or None] if given, figure i is saved to
      os.path.join(dirname, str(i).zfill(2)+"_"+basename) and then closed
  Outputs:
    bf_figs: [list] one entry per figure: the figure, or None when it was saved
  """
  import bisect  # not imported at the top of the notebook
  assert np.sqrt(max_bfs_per_fig) % 1 == 0, "Pick a square number for max_bfs_per_fig"
  num_bfs = len(cv_data)
  num_bf_figs = int(np.ceil(float(num_bfs) / max_bfs_per_fig))
  # Number of panels in the square grid used by every figure
  if num_bf_figs > 1:
    bfs_per_fig = int(max_bfs_per_fig)
  else:
    # a single figure only needs the smallest square grid that holds every panel
    squares = [x**2 for x in range(1, int(np.sqrt(max_bfs_per_fig))+1)]
    bfs_per_fig = squares[bisect.bisect_left(squares, num_bfs)]
  plot_sidelength = int(np.sqrt(bfs_per_fig))
  if save_filename is not None:
    save_dir, save_base = os.path.split(save_filename)
  bf_figs = []
  for fig_idx in range(num_bf_figs):
    fig = plt.figure(figsize=(32, 32))
    fig.suptitle(title + ', fig {} of {}'.format(fig_idx+1, num_bf_figs), fontsize=20)
    subplot_grid = gridspec.GridSpec(plot_sidelength, plot_sidelength,
      wspace=0.4, hspace=0.4)
    fig_start = fig_idx * bfs_per_fig
    for panel_idx, (components, direction_vector, circ_var) in enumerate(
      cv_data[fig_start:fig_start + bfs_per_fig]):
      ax = fig.add_subplot(subplot_grid[panel_idx])
      real_part = np.real(components)
      imag_part = np.imag(components)
      ax.plot(real_part, imag_part, c='g', linewidth=0.5)
      ax.scatter(real_part, imag_part, c='g', s=4)
      ax.quiver(np.real(direction_vector), np.imag(direction_vector),
                angles='xy', scale_units='xy', scale=1.0, color='b',
                width=0.01)
      ax.axvline(x=0.0, color='k', linestyle='--', alpha=0.6, linewidth=0.3)
      ax.axhline(y=0.0, color='k', linestyle='--', alpha=0.6, linewidth=0.3)
      ax.yaxis.set_major_formatter(FormatStrFormatter('%0.2g'))
      xaxis_size = max(np.max(real_part), 1.0)
      yaxis_size = max(np.max(imag_part), 1.0)
      ax.set_yticks([-1. * yaxis_size, yaxis_size])
      ax.set_xticks([-1. * xaxis_size, xaxis_size])
      # put the circular variance index in the upper left
      ax.text(0.02, 0.97, 'CV: {:.2f}'.format(circ_var),
              horizontalalignment='left', verticalalignment='top',
              transform=ax.transAxes, color='b', fontsize=10)
    if save_filename is not None:
      # Each name is derived from the caller's path (no compounding across figures);
      # os.path.join supplies the directory separator.
      fig.savefig(os.path.join(save_dir, str(fig_idx).zfill(2)+"_"+save_base))
      plt.close(fig)
      bf_figs.append(None)
    else:
      bf_figs.append(fig)
  if save_filename is None:
    plt.show()
  return bf_figs

In [None]:
# Compute FWHM, circular variance and OSI for every neuron's tuning curve at
# the last contrast; neurons with perfectly flat curves are skipped and recorded.
for analyzer in analyzer_list:
  contrast_idx = -1
  num_orientation_samples = len(analyzer.ot_grating_responses['orientations'])
  # Angles relative to the preferred orientation of a curve centered by center_curve.
  # NOTE(review): assumes the stimulus orientations evenly sample [0, pi) without the
  # endpoint -- confirm against how ot_grating_responses was generated.
  corresponding_angles_deg = (180 * np.arange(num_orientation_samples) / num_orientation_samples) - 90
  corresponding_angles_rad = (np.pi * np.arange(num_orientation_samples) / num_orientation_samples) - (np.pi/2)

  analyzer.metrics_list = {"fwhm":[], "circ_var":[], "osi":[], "skipped_indices":[]}
  metrics = analyzer.metrics_list
  for bf_idx in range(analyzer.bf_stats["num_outputs"]):
    ot_curve = center_curve(analyzer.ot_grating_responses["mean_responses"][bf_idx, contrast_idx, :])
    if np.max(ot_curve) - np.min(ot_curve) == 0:
      # flat curve: no preferred orientation, and compute_fwhm's interpolation would divide by zero
      metrics["skipped_indices"].append(bf_idx)
      continue
    metrics["fwhm"].append(compute_fwhm(ot_curve, corresponding_angles_deg))
    metrics["circ_var"].append(compute_circ_var(ot_curve, corresponding_angles_rad))
    metrics["osi"].append(compute_osi(ot_curve))

In [None]:
# Draw the circular-variance panels for each model and save every figure as a PNG.
for analyzer in analyzer_list:
  figs = plot_circular_variance(analyzer.metrics_list["circ_var"], max_bfs_per_fig=144,
    title="Circular Variance")
  out_prefix = analyzer.analysis_out_dir+"/vis/circular_variance_"
  for fig_num, fig in enumerate(figs):
    fig.savefig(out_prefix+str(fig_num)+".png")

In [None]:
def plot_circular_variance_histogram(variances_list, label_list, num_bins=50, y_max=None,
  figsize=None, save_filename=None):
  """
  Overlay log-count histograms of circular variance, one per model.
  Inputs:
    variances_list: [list of np.ndarray] circular variances, one array per model
    label_list: [list of str] legend label for each entry of variances_list
    num_bins: [int] number of bin edges spanning the pooled min..max range
      (np.linspace(min, max, num_bins), i.e. num_bins-1 bars per model)
    y_max: [float or None] upper y limit; None rounds the tallest bar up to a power of 10
    figsize: [tuple or None] matplotlib figure size
    save_filename: [str or None] if given, save the figure there, close it and return None
  Outputs:
    fig: the matplotlib figure, or None when save_filename is given
  Raises:
    ValueError: if variances_list is empty
  """
  if len(variances_list) == 0:
    raise ValueError("variances_list must contain at least one array of variances")
  variance_min = np.min([np.min(var) for var in variances_list])
  variance_max = np.max([np.max(var) for var in variances_list])
  # One shared set of bins so the overlaid histograms are directly comparable
  bins = np.linspace(variance_min, variance_max, num_bins)
  bin_left, bin_right = bins[:-1], bins[1:]
  bin_centers = bin_left + (bin_right - bin_left)/2
  bar_width = np.diff(bins).min()
  fig, ax = plt.subplots(1, figsize=figsize)
  hist_list = []
  handles = []
  for variances, label in zip(variances_list, label_list):
    hist, _ = np.histogram(np.asarray(variances).flatten(), bins)
    hist_list.append(hist)
    handles.append(ax.bar(bin_centers, hist, width=bar_width, log=True, align="center",
      alpha=0.5, label=label))
  # Minor ticks mark the bin edges; the two major ticks label the ends of the range.
  # (set_xticklabels installs fixed labels, so no major-tick formatter is needed.)
  ax.set_xticks(bin_left, minor=True)
  ax.tick_params("both", labelsize=16)
  ax.set_xlim([variance_min, variance_max])
  ax.set_xticks([variance_min, variance_max])
  ax.set_xticklabels(["More selective", "Less selective"])
  ticks = ax.xaxis.get_major_ticks()
  ticks[0].label1.set_horizontalalignment("left")
  ticks[1].label1.set_horizontalalignment("right")
  if y_max is None:
    # Round up to the nearest power of 10
    y_max = 10**(np.ceil(np.log10(np.max([np.max(hist) for hist in hist_list]))))
  ax.set_ylim([1, y_max])
  ax.set_title("Circular Variance Histogram", fontsize=18)
  ax.set_xlabel("Selectivity", fontsize=18)
  ax.set_ylabel("Log Count", fontsize=18)
  ax.legend(handles, label_list, fontsize=12,
    borderaxespad=0., bbox_to_anchor=[0.98, 0.98], fancybox=True, loc="upper right")
  if save_filename is not None:
    fig.savefig(save_filename)
    plt.close(fig)
    return None
  plt.show()
  return fig

In [None]:
# Legend labels are derived from params_list (same order as analyzer_list) so they
# cannot drift out of sync when the set of models changes.
label_list = [model_params.model_type.upper() for model_params in params_list]
# Third element of each compute_circ_var result is the circular variance itself
circ_var_list = [np.array([val[2] for val in analyzer.metrics_list["circ_var"]])
  for analyzer in analyzer_list]
circ_hist_fig = plot_circular_variance_histogram(circ_var_list, label_list, figsize=(8,8))
for analyzer in analyzer_list:
  circ_hist_fig.savefig(analyzer.analysis_out_dir+"/vis/circular_variance_histogram.png")

In [None]:
# Plot and save cross-orientation (masked) tuning curves.
# NOTE(review): co_bf_indices, co_mask_orientations, co_base_mean_responses,
# co_test_mean_responses and analysis_params are not defined anywhere in the
# visible notebook, so this cell raises NameError as written. `analyzer` is
# whatever the preceding loop left bound (the last entry of analyzer_list).
# analyzer.model_name (set during setup) may be the intended replacement for
# analysis_params.model_name -- confirm.
#cross_fig = pf.plot_masked_orientation_tuning(co_bf_indices, co_mask_orientations, co_base_mean_responses, analyzer.co_grating_responses["test_mean_responses"])
cross_fig = pf.plot_masked_orientation_tuning(co_bf_indices, co_mask_orientations, co_base_mean_responses, co_test_mean_responses)
cross_fig.savefig(analyzer.analysis_out_dir+"/vis/"+analysis_params.model_name+"_cross_orientation_tuning.pdf")

In [None]:
# Plot and save plaid (cross-orientation) contrast tuning curves.
# NOTE(review): co_bf_indices, co_contrasts, co_base_orientations,
# co_mask_orientations, co_test_mean_responses and analysis_params are not
# defined anywhere in the visible notebook, so this cell raises NameError as
# written. `analyzer` is the last entry of analyzer_list, left bound by an
# earlier loop -- confirm which model this cell is meant for.
cross_contrast_fig = pf.plot_plaid_contrast_tuning(co_bf_indices, co_contrasts, co_contrasts, co_base_orientations,
  co_mask_orientations, co_test_mean_responses)
cross_contrast_fig.savefig(analyzer.analysis_out_dir+"/vis/"+analysis_params.model_name+"_cross_contrast_orientation_tuning.pdf")