## Setup and imports

In [None]:
import os
os.chdir("../")
%env CUDA_VISIBLE_DEVICES=0
%matplotlib inline

In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from skimage.measure import compare_psnr
import tensorflow as tf
from data.dataset import Dataset
import data.data_selector as ds
import utils.data_processing as dp
import utils.plot_functions as pf
import analysis.analysis_picker as ap

In [None]:
class params(object):
  """Settings describing which trained model to analyze; passed to analyzer.setup() below."""
  def __init__(self):
    self.model_type = "lca_subspace"
    # Other model names used with this notebook: "lca_512_vh", "lca_768_vh", "lca_1024_vh"
    self.model_name = "lca_subspace_vh"
    self.version = "0.0"
    self.save_info = "analysis_train_carlini_targeted"
    self.overwrite_analysis_log = False

# Paths derived from the settings above
analysis_params = params()
analysis_params.project_dir = "{}/Work/Projects/".format(os.path.expanduser("~"))
analysis_params.model_dir = "{}{}".format(analysis_params.project_dir, analysis_params.model_name)

In [None]:
# Build the analyzer for this model type, configure it from analysis_params, then load the
# analysis outputs previously saved under analysis_params.save_info (order matters: setup first)
analyzer = ap.get_analyzer(analysis_params.model_type)
analyzer.setup(analysis_params)
analyzer.load_analysis(save_info=analysis_params.save_info)

In [None]:
def normalize_to_unit_range(array):
  """
  Linearly rescale an array so its minimum maps to 0 and its maximum maps to 1
  Args:
    array: [np.ndarray] values to rescale
  Returns:
    [np.ndarray] float32 array of the same shape; all zeros if the array is constant
      (the plain min-max formula would divide by zero and return NaNs)
  """
  array_min = np.min(array)
  value_range = np.max(array) - array_min
  if value_range == 0:
    return np.zeros(np.shape(array), dtype=np.float32)
  return ((array - array_min) / value_range).astype(np.float32)

# Rescale both images to [0, 1] so PSNR can be computed with data_range=1
# NOTE(review): each image is rescaled by its own min/max, so the PSNR ignores any global
#   offset/gain mismatch between input and reconstruction — confirm this is intended
normed_image = normalize_to_unit_range(analyzer.full_image)
normed_recon = normalize_to_unit_range(analyzer.full_recon)

fig, ax = plt.subplots(1, 2, figsize=(12,12))
ax[0] = pf.clear_axis(ax[0])
ax[0].imshow(np.squeeze(normed_image), cmap="Greys_r")
ax[0].set_title("Input Image", fontsize=16)
ax[1] = pf.clear_axis(ax[1])
ax[1].imshow(np.squeeze(normed_recon), cmap="Greys_r")
percent_active = "{:.2f}".format(analyzer.recon_frac_act*100)
psnr = "{:.2f}".format(compare_psnr(normed_image, normed_recon, data_range=1))
ax[1].set_title("Reconstruction\n"+percent_active+" percent active"+"\n"+"PSNR = "+psnr, fontsize=16)
plt.show()
fig.savefig(analyzer.analysis_out_dir+"/vis/"+analysis_params.model_name+"_image_recon.png", transparent=True,
  bbox_inches="tight")

In [None]:
# Pairs of (run_stats key, plot label) for the training curves
stat_key_labels = [("a_fraction_active", "activity"), ("recon_loss", "recon loss"),
  ("sparse_loss", "sparse loss"), ("total_loss", "total loss")]
keys = [key for key, _ in stat_key_labels]
labels = [label for _, label in stat_key_labels]
# start_index=100 presumably skips the earliest logged entries — see pf.plot_stats
stats_fig = pf.plot_stats(analyzer.run_stats, keys=keys, labels=labels, start_index=100, figsize=(10,10))
stats_fig.savefig("".join([analyzer.analysis_out_dir, "/vis/", analysis_params.model_name, "_train_stats.png"]))

In [None]:
# Activity triggered averages computed on image data, one tile per neuron (transposed for plotting)
atas_title = "Activity triggered averages on image data"
atas_fig = pf.plot_data_tiled(analyzer.atas.T, normalize=False, title=atas_title)
atas_fig.savefig("{}/vis/{}_img_atas.png".format(analyzer.analysis_out_dir, analysis_params.model_name))

In [None]:
num_noise_images = analyzer.num_noise_images
# The noise analysis is optional; only plot it if the analyzer produced it
if hasattr(analyzer, "noise_activity"):
  noise_activity = analyzer.noise_activity
  noise_atas = analyzer.noise_atas
  noise_atcs = analyzer.noise_atcs
  noise_atas_fig = pf.plot_data_tiled(noise_atas.T, normalize=False,
    title="Activity triggered averages on standard normal noise data")
  noise_atas_fig.savefig(analyzer.analysis_out_dir+"/vis/"+analysis_params.model_name+"_noise_atas.png")
  neuron_idx = 0 # which neuron's activity triggered covariance to decompose
  # eigh returns eigenvalues in ascending order and eigenvectors as columns of evecs
  evals, evecs = np.linalg.eigh(noise_atcs[neuron_idx,...])
  top_indices = np.argsort(evals)[::-1] # descending eigenvalue order
  # Derive the patch shape from the covariance size instead of hard-coding 256 -> 16x16
  num_pixels = evecs.shape[0]
  patch_edge = int(np.sqrt(num_pixels))
  assert patch_edge * patch_edge == num_pixels, (
    "noise_atcs must be pixel covariances of square patches")
  fig = pf.plot_weights(evecs.T.reshape(num_pixels, patch_edge, patch_edge)[top_indices,:,:])
  # Order eigenvalues the same way as the eigenvectors plotted above
  fig2 = pf.plot_eigenvalues(evals[top_indices], ylim=[np.min(evals), np.max(evals)])

In [None]:
# Reshape the LCA weight matrix into one square patch per neuron and plot them as a grid
patch_edge = analyzer.bf_stats["patch_edge_size"]
weight_shape = [analyzer.bf_stats["num_outputs"], patch_edge, patch_edge]
lca_weights = analyzer.evals["lca/weights/w:0"]
dict_fig = pf.plot_weights(lca_weights.T.reshape(weight_shape), title="Weights", figsize=(24,24))
dict_out_path = "{}/vis/{}_dict.png".format(analyzer.analysis_out_dir, analysis_params.model_name)
dict_fig.savefig(dict_out_path, transparent=True, bbox_inches="tight")

In [None]:
# Summary of basis-function location and spatial-frequency centers from bf_stats
loc_freq_out_path = analyzer.analysis_out_dir + "/vis/fig_location_frequency_centers.png"
fig = pf.plot_loc_freq_summary(analyzer.bf_stats, figsize=(12, 4), fontsize=16)
fig.savefig(loc_freq_out_path)

In [None]:
def plot_inference_traces(data, activation_threshold, img_idx=None, act_indicator_threshold=None, num_plot_neurons=None):
  """
  Plot model neurons' inference traces over time for a single image, one subplot per neuron
  Args:
    data: [dict] with each trace, with keys [b, u, a, ga, images] and optionally [fb]
      Dictionary is created by analyze_lca.evaluate_inference()
      Traces b, u, a, ga (and fb) are indexed [image, time_step, neuron]
      Each entry of images must hold a square number of pixels
    activation_threshold: [float] value of the sparse multiplier, lambda; drawn as a dotted line
    img_idx: [int] which image in data["images"] to run analysis on; drawn at random if None
    act_indicator_threshold: [float] sets the threshold for when a neuron is marked as having
      recently changed state. The recent window is every time step >= int(num_time_steps * act_indicator_threshold)
      Neurons with a > 0 at the last step that had a == 0 for every step before the window are
      "recently active" (dotted magenta border); neurons with a <= 0 at the last step that had
      a > 0 at some step inside the window are "recently inactive" (dotted black border)
      This input must be strictly between 0.0 and 1.0
    num_plot_neurons: [int] number of neurons to plot. If None, then plot all neurons
  Returns:
    fig: the matplotlib figure; the bottom row holds the input image and the legend
  Notes:
    Sets the global matplotlib rc "text.usetex" to True (requires a LaTeX install); it stays set
  """
  if act_indicator_threshold is not None:
    # Checked up front so an invalid value fails even when no neuron ends active
    assert 0.0 < act_indicator_threshold < 1.0, (
      "act_indicator_threshold must be between 0.0 and 1.0")
  plt.rc('text', usetex=True)
  (num_images, num_time_steps, num_neurons) = data["b"].shape
  if num_plot_neurons is None:
    num_plot_neurons = num_neurons
  sqrt_nn = int(np.sqrt(num_plot_neurons))
  # Draw only the requested neurons. They always fit in the (sqrt_nn+1)**2 cells above the
  # bottom row, so traces never land under the image/legend row
  num_traces = min(num_plot_neurons, num_neurons)
  thresh_index = (None if act_indicator_threshold is None
    else int(num_time_steps * act_indicator_threshold))
  if img_idx is None:
    img_idx = np.random.choice(num_images)
  global_max_val = float(np.max(np.abs([data["b"][img_idx,...],
    data["u"][img_idx,...], data["ga"][img_idx,...], data["a"][img_idx,...],
    np.ones_like(data["b"][img_idx,...])*activation_threshold])))
  fig, sub_axes = plt.subplots(sqrt_nn+2, sqrt_nn+1, figsize=(20, 20))
  fig.subplots_adjust(hspace=0.20, wspace=0.20)
  # The figure layout is fixed from here on, so invert the figure transform once
  inv_fig_transform = fig.transFigure.inverted()
  t = np.arange(num_time_steps)
  threshold_trace = np.full(num_time_steps, activation_threshold, dtype=float)
  zero_trace = np.zeros(num_time_steps)
  for (axis_idx, axis) in enumerate(fig.axes): # one axis per neuron
    if axis_idx >= num_traces:
      pf.clear_axis(axis)
      continue
    b = data["b"][img_idx, :, axis_idx]
    u = data["u"][img_idx, :, axis_idx]
    ga = data["ga"][img_idx, :, axis_idx]
    a = data["a"][img_idx, :, axis_idx]
    axis.plot(t, b, linewidth=0.25, color="g", label="b")
    axis.plot(t, u, linewidth=0.25, color="b", label="u")
    axis.plot(t, ga, linewidth=0.25, color="r", label="Ga")
    axis.plot(t, threshold_trace, linewidth=0.25, color="k",
      linestyle=":", dashes=(1,1), label=r"$\lambda$")
    axis.plot(t, a, linewidth=0.25, color="darkorange", label="a")
    axis.plot(t, zero_trace, linewidth=0.25, color="k", linestyle="-", label="zero")
    if "fb" in data:
      fb = data["fb"][img_idx, :, axis_idx]
      axis.plot(t, fb, linewidth=0.25, color="darkgreen", label="fb")
    max_val = np.max(np.abs([b, ga, u, a]))
    # Guard against all-zero traces with a zero threshold, which would divide by zero
    scale_ratio = max_val / global_max_val if global_max_val > 0 else 0.0
    _draw_scale_bar(fig, axis, scale_ratio, inv_fig_transform)
    _style_activity_border(axis, a, thresh_index)
  num_pixels = np.size(data["images"][img_idx])
  image_edge = int(np.sqrt(num_pixels))
  image = data["images"][img_idx,...].reshape(image_edge, image_edge)
  sub_axes[sqrt_nn+1, 0].imshow(image, cmap="Greys", interpolation="nearest")
  for plot_col in range(sqrt_nn):
    pf.clear_axis(sub_axes[sqrt_nn+1, plot_col])
  fig.suptitle("LCA Activity", y=0.9, fontsize=20)
  handles, labels = sub_axes[0,0].get_legend_handles_labels()
  legend = sub_axes[sqrt_nn+1, 1].legend(handles, labels, fontsize=12, ncol=3,
    borderaxespad=0., bbox_to_anchor=[0, 0], fancybox=True, loc="upper left")
  for line in legend.get_lines():
    line.set_linewidth(3)
  plt.show()
  return fig


def _draw_scale_bar(fig, axis, scale_ratio, inv_fig_transform, x_offset=0.002):
  """Draw a grey vertical bar just left of `axis` whose height is `scale_ratio` times the axis height."""
  axis_height = axis.get_window_extent().transformed(inv_fig_transform).height
  line_length = axis_height * scale_ratio
  axis_origin = inv_fig_transform.transform(axis.transAxes.transform([0,0]))
  x_pos = axis_origin[0] - x_offset
  scale_bar = matplotlib.lines.Line2D((x_pos, x_pos), (axis_origin[1], axis_origin[1] + line_length),
    transform=fig.transFigure, color="0.3")
  fig.lines.append(scale_bar)


def _style_activity_border(axis, activity, thresh_index):
  """Color `axis` spines by the neuron's final activity; dot them if it changed state at/after thresh_index."""
  if activity[-1] > 0:
    pf.clear_axis(axis, spines="magenta")
    # Recently active: silent for every step before the recent window
    changed_recently = thresh_index is not None and np.all(activity[:thresh_index] == 0)
  else:
    pf.clear_axis(axis, spines="black")
    # Recently inactive: active at some step inside the recent window
    changed_recently = thresh_index is not None and np.any(activity[thresh_index:] > 0)
  if changed_recently:
    for ax_loc in ["top", "bottom", "left", "right"]:
      axis.spines[ax_loc].set_linestyle((1, (1, 3))) # (offset, (on, off)) dash pattern

In [None]:
# Fraction of inference after which a change of activity state counts as "recent"
act_indicator_threshold = 0.80
inf_trace_fig = plot_inference_traces(analyzer.inference_stats, analyzer.model_schedule[0]["sparse_mult"],
  act_indicator_threshold=act_indicator_threshold)
# savefig has no `pad` keyword (it is ignored or rejected depending on the matplotlib version);
# the padding used with bbox_inches="tight" is `pad_inches`
inf_trace_fig.savefig(analyzer.analysis_out_dir+"/vis/"+analysis_params.model_name
  +"_inference_traces_dot_thresh-"+str(act_indicator_threshold)+"_"+analysis_params.save_info+".pdf",
  transparent=True, bbox_inches="tight", pad_inches=0.1)

In [None]:
inf_stats_fig = pf.plot_inference_stats(analyzer.inference_stats, title="Loss During Inference")
# savefig has no `pad` keyword; the padding used with bbox_inches="tight" is `pad_inches`
inf_stats_fig.savefig(analyzer.analysis_out_dir+"/vis/"+analysis_params.model_name
  +"_inference_loss_50_"+analysis_params.save_info+".png",
  transparent=True, bbox_inches="tight", pad_inches=0.1)

In [None]:
# Contrast/orientation tuning curves from the saved grating responses
ot_responses = analyzer.ot_grating_responses
ot_fig = pf.plot_contrast_orientation_tuning(
  *[ot_responses[key] for key in ("neuron_indices", "contrasts", "orientations", "mean_responses")])
ot_fig.savefig("{}/vis/{}_orientation_tuning.pdf".format(analyzer.analysis_out_dir, analysis_params.model_name))

In [None]:
# Read the cross-orientation (masking) results from the analyzer instead of relying on co_*
# variables that no cell in this notebook defines (they only existed in a deleted cell).
# NOTE(review): only "test_mean_responses" is confirmed by earlier code; the other key names
#   ("neuron_indices", "mask_orientations", "base_mean_responses") are assumed — verify against the analyzer
co_responses = analyzer.co_grating_responses
cross_fig = pf.plot_masked_orientation_tuning(co_responses["neuron_indices"],
  co_responses["mask_orientations"], co_responses["base_mean_responses"],
  co_responses["test_mean_responses"])
cross_fig.savefig(analyzer.analysis_out_dir+"/vis/"+analysis_params.model_name+"_cross_orientation_tuning.pdf")

In [None]:
# Read the cross-orientation results from the analyzer instead of relying on co_* variables
# that no cell in this notebook defines (they only existed in a deleted cell).
# NOTE(review): key names other than "test_mean_responses" are assumed — verify against the analyzer
co_responses = analyzer.co_grating_responses
# The same contrast list is passed for both the base and the mask contrasts, as before
cross_contrast_fig = pf.plot_plaid_contrast_tuning(co_responses["neuron_indices"],
  co_responses["contrasts"], co_responses["contrasts"], co_responses["base_orientations"],
  co_responses["mask_orientations"], co_responses["test_mean_responses"])
cross_contrast_fig.savefig(analyzer.analysis_out_dir+"/vis/"+analysis_params.model_name+"_cross_contrast_orientation_tuning.pdf")

In [None]:
#grating = lambda bf_idx,orientation,phase,contrast:dp.generate_grating(
#  *dp.get_grating_params(bf_stats=analyzer.bf_stats, bf_idx=bf_idx, orientation=orientation,
#  phase=phase, contrast=contrast, diameter=-1)).reshape(16,16)
#
#bf_idx = 29
#bf = analyzer.evals["weights/phi:0"].T[co_bf_indices[bf_idx],:].reshape(16,16)
#base_stim = grating(co_bf_indices[bf_idx], co_base_orientations[bf_idx], co_phases[0], 0.5)
#mask_stim = grating(co_bf_indices[bf_idx], orthogonal_orientations[bf_idx], co_phases[5], 0.5)
#test_stim = base_stim + mask_stim
#
#all_min = np.min(np.stack([base_stim, mask_stim, test_stim]))
#all_max = np.max(np.stack([base_stim, mask_stim, test_stim]))
#
#fig, axes = plt.subplots(4)
#axes[0] = pf.clear_axis(axes[0])
#axes[1] = pf.clear_axis(axes[1])
#axes[2] = pf.clear_axis(axes[2])
#axes[3] = pf.clear_axis(axes[3])
#axes[0].imshow(bf, cmap="Greys_r")
#axes[1].imshow(base_stim, cmap="Greys_r", vmin=all_min, vmax=all_max)
#axes[2].imshow(mask_stim, cmap="Greys_r", vmin=all_min, vmax=all_max)
#axes[3].imshow(test_stim, cmap="Greys_r", vmin=all_min, vmax=all_max)
#plt.show()
#fig.savefig("/home/dpaiton/tmp_figs/"+analysis_params.model_name+"_ex_cross_stim.png")

In [None]:
#constructed_bfs = np.zeros_like(analyzer.evals["weights/phi:0"].T)
#for bf_idx in range(constructed_bfs.shape[0]):
#  params = dp.get_grating_params(analyzer.bf_stats, bf_idx)
#  grating = dp.generate_grating(*params)
#  constructed_bfs[bf_idx,...] = grating.reshape(256)
#fig = pf.plot_data_tiled(constructed_bfs)

### Compute iso-response contrast curves

In [None]:
# Stimulus sweep for the iso-response contrast analysis on three example neurons
iso_orientations = np.linspace(0.0, np.pi, 16)
iso_phases = np.linspace(-np.pi, np.pi, 12)
outputs = analyzer.iso_response_contrasts(
  analyzer.bf_stats,
  base_contrast=0.5,
  contrast_resolution=0.01,
  closeness=0.01,
  num_alt_orientations=4,
  orientations=iso_orientations,
  phases=iso_phases,
  neuron_indices=[52,53,54],
  diameter=-1,
  scale=analyzer.analysis_params.input_scale)

In [None]:
# Display the iso-response parameters of the first analyzed entry
iso_response_parameters = outputs["iso_response_parameters"]
iso_response_parameters[0]