In [1]:
import os
import sys

import pickle
import numpy as np
import scipy
from scipy import optimize
from skimage import measure
import numpy.polynomial.polynomial as poly

import matplotlib as mpl
import matplotlib.pyplot as plt
import proplot as plot

# IPython shell escape: `!pwd` captures the shell output as a list of lines,
# which is why the first element is used below.
current_path = !pwd
# Put the directory above the notebook's working directory on sys.path
# (presumably where the `utils` package imported below lives -- TODO confirm).
parent_path = os.path.dirname(current_path[0])
if parent_path not in sys.path: sys.path.append(parent_path)
# Drop the last component of parent_path to get the root directory; the
# santi_iso_response directory is expected directly inside it.
# NOTE(review): splitting on '/' assumes POSIX-style paths.
root_path = "/".join(parent_path.split('/')[:-1])
if root_path not in sys.path: sys.path.append(root_path)
santi_path = root_path+"/santi_iso_response"
# Directory from which tuning_predictions.npy and meis.pkl are loaded later on.
santi_etc_path = os.path.join(santi_path, 'etc')

import utils.model_handling as model_funcs
import utils.dataset_generation as iso_data
import utils.histogram_analysis as hist_funcs
import utils.plotting as plot_funcs

from santi_iso_response.iso_response import utils as santi_utils

ModuleNotFoundError: No module named 'santi_iso_response'

In [None]:
# Resolution passed as `dpi=` to fig.savefig when figures are written to disk.
dpi = 300
# Extensions each saved figure is written with (one file per entry);
# the commented-out entries are currently disabled.
file_extensions = ['.pdf']#, '.eps', '.png']

def get_rc_args(fontsize):
    """
    Switch matplotlib to LaTeX-rendered serif text at one uniform size and
    return the matching keyword overrides for ``plot.rc.context``.

    Side effect: ``mpl.rcParams`` is updated globally, so the new settings
    stay active after this function returns.

    Parameters:
        fontsize: size applied to axis labels, titles, tick labels, the
            legend and the base font.
    Returns:
        dict with the keys 'fontsize', 'fontfamily', 'legend.fontsize' and
        'text.labelsize', filled from the freshly updated rcParams.
    """
    # Every rcParam that should share the same size.
    sized_keys = (
        "axes.labelsize",
        "axes.titlesize",
        "figure.titlesize",
        "font.size",
        "legend.fontsize",
        "xtick.labelsize",
        "ytick.labelsize",
    )
    font_settings = {
        "text.usetex": True,
        "font.family": 'serif',
        "font.serif": 'Computer Modern Roman',
    }
    for rc_key in sized_keys:
        font_settings[rc_key] = fontsize
    mpl.rcParams.update(font_settings)

    # Read the values back from rcParams so the returned dict reflects what
    # matplotlib actually stored.
    stored_size = mpl.rcParams['font.size']
    return {
        'fontsize': stored_size,
        'fontfamily': mpl.rcParams['font.family'],
        'legend.fontsize': stored_size,
        'text.labelsize': stored_size,
    }
# Font size used for the figures that follow; get_rc_args also writes it into
# the global mpl.rcParams and returns the overrides used with plot.rc.context.
fontsize = 10
rc_kwargs = get_rc_args(fontsize)

In [None]:
# Locations of the saved analysis outputs: every file lives in the
# santi_analysis directory under root_path and shares the same prefix.
results_path = f'{root_path}/santi_analysis/'
save_prefix = 'santi'
params_file = f'{results_path}{save_prefix}_params.npz'
mei_curvature_file = f'{results_path}{save_prefix}_meis.npz'
mes_curvature_file = f'{results_path}{save_prefix}_stim.npz'
rand_curvature_file = f'{results_path}{save_prefix}_rand.npz'

In [None]:
# The params dictionary is stored under the 'data' key of the .npz archive;
# .item() unwraps it into a plain Python object (allow_pickle is required for that).
analysis_params = np.load(params_file, allow_pickle=True)['data'].item()
label_list = [f"Neuron {neuron_id}" for neuron_id in analysis_params["target_neuron_ids"]]
# One "key<TAB><TAB>value" line per parameter.
print(*(f'{key}\t\t{val}' for key, val in analysis_params.items()), sep='\n')
# Override the stored window bounds; the trimming cell further down reads this entry.
analysis_params['iso_window_bounds'] = ((-1, 1), (-1, 1))

In [None]:
# Load the MEI curvature results (a dictionary stored under the 'data' key)
# and list the entries it provides, one key per line.
mei_curvatures = np.load(mei_curvature_file, allow_pickle=True)['data'].item()
print('\n'.join(mei_curvatures.keys()))

In [None]:
# Choose which entry of mei_curvatures is analysed below; swap the commented
# line to switch to the other curvature type.
curvature_type = 'iso_curvatures'
#curvature_type = 'attn_curvatures'

# One collection of curvature values per target neuron: the cells below use
# len(curvatures) as the neuron count and reduce each entry with mean/median.
curvatures = mei_curvatures[curvature_type]

In [None]:
# Per-neuron curvature histograms on a grid of small panels, ordered by each
# neuron's median curvature (row-major, smallest median first).
num_bins = 50
num_neurons = len(curvatures)

curvature_means = [np.mean(curvature) for curvature in curvatures]
curvature_medians = [np.median(curvature) for curvature in curvatures]
median_argsort = np.argsort(curvature_medians)

# Near-square grid: floor(sqrt(N)) columns and exactly as many rows as needed.
# Integer ceil-division avoids the empty extra row that `N // cols + 1`
# produced whenever N was a non-square multiple of the column count.
num_neurons_edge = max(1, int(np.sqrt(num_neurons)))
ncols = num_neurons_edge
nrows = (num_neurons + ncols - 1) // ncols

# Symmetric x-range around zero, taken from the pooled 1st/98th percentiles so
# that a few extreme curvatures do not stretch every panel.
all_curvatures = [val for vals in curvatures for val in  vals]
lower_percentile = 1
upper_percentile = 98
lower_curvature = np.percentile(all_curvatures, lower_percentile)
upper_curvature = np.percentile(all_curvatures, upper_percentile)
max_curvature = np.max([np.abs(lower_curvature), np.abs(upper_curvature)])
print('max curvature = ', np.max(all_curvatures))
print('min curvature = ',np.min(all_curvatures))
print(f'{lower_percentile} percentile = {lower_curvature}')
print(f'{upper_percentile} percentile = {upper_curvature}')
plot_min = -max_curvature
plot_max = max_curvature

bins = hist_funcs.get_bins(num_bins, plot_min, plot_max)
# `bins` is treated as bin edges (assumed to be a NumPy array, as the slicing
# arithmetic requires). The x position of each bar is the midpoint of its bin;
# the earlier `lefts + (rights - lefts)` evaluated to the right edges and
# shifted every histogram by half a bin.
bin_lefts, bin_rights = bins[:-1], bins[1:]
bin_centers = bin_lefts + (bin_rights - bin_lefts) / 2

axwidth = 0.5
fig, axes = plot.subplots(
    nrows=nrows,
    ncols=ncols,
    sharex=True,
    sharey=True,
    axwidth=axwidth)

simple_cell_ids = [1, 32, 34, 47, 53]
complex_cell_ids = [2, 4, 19, 24, 29, 33, 38]
# Offsets (figure coordinates) of the green separator relative to the panel.
x_offset = -0.08
y_offset = -0.02
plot_idx = 0
for row_idx in range(nrows):
    for col_idx in range(ncols):
        if plot_idx < num_neurons:
            unsort_idx = median_argsort[plot_idx]
            unit_id = analysis_params['target_neuron_ids'][unsort_idx]
            # Spine colour marks the hand-labelled cell class: red = simple,
            # blue = complex, no spines = unlabelled.
            if unit_id in simple_cell_ids:
                spine_style = 'r'
            elif unit_id in complex_cell_ids:
                spine_style = 'b'
            else:
                spine_style = 'none'
            ax = plot_funcs.clear_axis(axes[row_idx, col_idx], spines=spine_style)
            hist, bin_edges = hist_funcs.get_relative_hist(curvatures[unsort_idx], bins)
            ax.plot(bin_centers, hist, linestyle='-', drawstyle='steps-mid',
                color='k', linewidth=0.5, label=label_list[unsort_idx], zorder=1)
            ax.fill_between(bin_centers, 0, hist, color='k', zorder=2)
            # Dashed red reference line at zero curvature.
            ax.plot([0, 0], [0, np.max(hist)], color='r', linestyle='--', linewidth=0.5, zorder=3)
            ax.format(title=f'U:{unit_id} C:{np.round(curvature_medians[unsort_idx], 2):0.2f}')
            if plot_idx >= 1:
                prev_unsort_idx = median_argsort[plot_idx-1]
                prev_curv_med = np.round(curvature_medians[prev_unsort_idx], 2)
                curr_curv_med = np.round(curvature_medians[unsort_idx], 2)
                # Green separator where the (rounded) medians leave the negative
                # range or enter the positive range. The two conditions used to
                # be separate `if`s that drew the same line twice when both held.
                leaves_negative = prev_curv_med < 0.00 and curr_curv_med >= 0.00
                enters_positive = prev_curv_med <= 0.0 and curr_curv_med > 0.0
                if leaves_negative or enters_positive:
                    pos = ax.get_position().get_points() # get the axis position in the form [[x0, y0], [x1, y1]]
                    xpos = [pos[1][0]+x_offset,]*2
                    ypos = [pos[0][1]+y_offset, pos[1][1]+y_offset]
                    ax.plot(xpos, ypos, color='g', lw=2, transform=fig.transFigure, clip_on=False)
        else:
            # Unused trailing panels are blanked.
            ax = plot_funcs.clear_axis(axes[row_idx, col_idx])
        plot_idx += 1
axes.format(
    xlim=[plot_min, plot_max],
    suptitle='Neuron iso-response curvature histograms, sorted by median value'
    #suptitle='Neuron response attenuation curvature histograms, sorted by median value'
)
plot.show()

#for extension in file_extensions:
#    save_name = analysis_params['output_directory']+'iso_curvature_histograms'+extension
#    #save_name = analysis_params['output_directory']+'attn_curvature_histograms'+extension
#    fig.savefig(save_name, transparent=True, bbox_inches="tight", pad_inches=0.05, dpi=dpi)

In [None]:
# Hand-picked example units: some simple cells and some complex cells.
simple_cell_ids = [1, 32, 34, 47, 53]
complex_cell_ids = [2, 4, 19, 24, 29, 33, 38]

# Median curvature (rounded to 3 decimals) of every analysed neuron that
# belongs to each group, in the order the neurons appear in target_neuron_ids.
simple_medians = [
    np.round(curvature_medians[list_index], 3)
    for list_index, neuron_id in enumerate(analysis_params['target_neuron_ids'])
    if neuron_id in simple_cell_ids
]
complex_medians = [
    np.round(curvature_medians[list_index], 3)
    for list_index, neuron_id in enumerate(analysis_params['target_neuron_ids'])
    if neuron_id in complex_cell_ids
]
print(f'simple median curvatures = {simple_medians}\ncomplex median curvatures = {complex_medians}')

In [None]:
def argmed(array):
    """
    Index of the element closest to the median of `array`.

    The median itself need not be an element (e.g. for even-length input), so
    the nearest element is picked; ties resolve to the first occurrence. For
    multi-dimensional input the index refers to the flattened array.
    """
    values = np.asarray(array)
    distance_to_median = np.abs(values - np.median(values))
    return np.argmin(distance_to_median)

def add_arrow(ax, vect, xrange, yx_offset=[1,1], linestyle='-', label='', text_color='k'):
    """
    Draw a black arrow from the origin to `vect` on `ax` and place a bold
    text label next to its tip.

    Parameters:
        ax: axis object providing `arrow` and `text`
        vect: indexable whose first two entries expose `.item()`; entry 0 is
            used as the x component and entry 1 as the y component
        xrange: extent of the x axis; one tenth of it is the unit by which
            the label is shifted away from the arrow tip
        yx_offset: [y, x] multipliers applied to that label shift
        linestyle: line style of the arrow
        label: text drawn near the tip (empty string by default)
        text_color: colour of the label text
    Returns:
        None
    """
    tip_x = vect[0].item()
    tip_y = vect[1].item()
    arrow_style = dict(
        width=0.0,
        head_width=0.15,
        head_length=0.15,
        fc='k',
        ec='k',
        linestyle=linestyle,
        linewidth=1,
    )
    ax.arrow(0, 0, tip_x, tip_y, **arrow_style)
    # Shift the label by tenths of the x-range; note the offset is given as (y, x).
    label_shift = xrange/10
    label_x = tip_x+(label_shift*yx_offset[1])
    label_y = tip_y+(label_shift*yx_offset[0])
    ax.text(
        label_x,
        label_y,
        label,
        weight='bold',
        color=text_color,
        horizontalalignment='center',
        verticalalignment='center'
    )

def plot_fits(ax, activity, contours, fits, yx_pts, yx_range, proj_vects=None, num_levels=10, title=''):
    """
    Draw one response plane: filled activity contours, optional projection
    arrows, the zero axes, and the extracted contour points with their fit.

    Parameters:
        ax: proplot axis to draw on (its `format` method is used)
        activity: 2-D array of responses, laid out to match
            np.meshgrid(x_pts, y_pts)
        contours: pair (first, second) of coordinate sequences scattered in
            red; the first is passed as the scatter's x argument
        fits: pair of coordinate sequences for the fitted curve, scattered as
            black stars in the same order
        yx_pts: (y_pts, x_pts) coordinates of the plane's grid
        yx_range: (y_range, x_range); each is used as the end points of the
            corresponding zero-axis line
        proj_vects: optional sequence of three vectors (target, comparison,
            orthogonal) drawn as arrows from the origin; the comparison arrow
            is dashed
        num_levels: number of contour levels between min and max activity
        title: axis title
    Returns:
        The contour set returned by `ax.contourf` (usable for a colorbar).
    """
    vmin = np.min(activity)
    vmax = np.max(activity)
    cmap = plt.get_cmap('cividis')
    # Plot contours. yx_pts is (y, x) while meshgrid expects (x, y), hence the reversal.
    x_mesh, y_mesh = np.meshgrid(*yx_pts[::-1])
    levels = np.linspace(vmin, vmax, num_levels)
    contsf = ax.contourf(x_mesh, y_mesh, activity,
        levels=levels, vmin=vmin, vmax=vmax, alpha=1.0, antialiased=True, cmap=cmap)
    if proj_vects is not None:
        # Add arrows: target (solid), comparison (dashed), orthogonal (solid).
        xrange = max(yx_range[1]) - min(yx_range[1])
        for vect_idx, arrow_linestyle in enumerate(('-', '--', '-')):
            add_arrow(ax, proj_vects[vect_idx], xrange, linestyle=arrow_linestyle)
    # Add axis grid
    ax.set_aspect('equal')
    ax.plot(yx_range[1], [0,0], color='k', linewidth=0.5)
    ax.plot([0,0], yx_range[0], color='k', linewidth=0.5)
    # Extracted contour points (red) and their fit (black stars)
    ax.scatter(contours[0], contours[1], s=4, color='r')
    ax.scatter(fits[0], fits[1], s=3, marker='*', color='k')
    ax.format(
        ylim=[np.min(yx_pts[0]), np.max(yx_pts[0])],
        xlim=[np.min(yx_pts[1]), np.max(yx_pts[1])],
        title=title
    )
    return contsf

In [None]:
# Response planes and the coordinates of their sample grid.
activations = mei_curvatures['response_images'].copy()
num_target, num_planes, num_y, num_x = activations.shape
y_pts = mei_curvatures['contour_dataset']['y_pts'].copy()
x_pts = mei_curvatures['contour_dataset']['x_pts'].copy()
y_range, x_range = max(y_pts) - min(y_pts), max(x_pts) - min(x_pts)

# Restrict the planes to the iso-response window before fitting.
bounds = analysis_params['iso_window_bounds']
y_bounds, x_bounds = bounds
y_bound_range, x_bound_range = max(y_bounds) - min(y_bounds), max(x_bounds) - min(x_bounds)
# NOTE(review): the kept slice runs from trim*N to 3*trim*N, which is centered
# on the plane only when the bounds span half of the full range (e.g. a
# (-1, 1) window inside a (-2, 2) plane) -- confirm before changing the bounds.
y_trim, x_trim = 0.5 * y_bound_range / y_range, 0.5 * x_bound_range / x_range
start_y, end_y = int(np.floor(y_trim*num_y)), int(np.ceil(3*y_trim*num_y))
start_x, end_x = int(np.floor(x_trim*num_x)), int(np.ceil(3*x_trim*num_x))
new_num_y, new_num_x = end_y - start_y, end_x - start_x
y_pts_trim, x_pts_trim = y_pts[start_y:end_y], x_pts[start_x:end_x]
new_y_range = max(y_pts_trim) - min(y_pts_trim)
new_x_range = max(x_pts_trim) - min(x_pts_trim)
# Coordinate extent covered per sample of the trimmed window, per axis.
y_scale_factor, x_scale_factor = new_y_range / new_num_y, new_x_range / new_num_x

# Contours at the target activity level, with polynomial fits and curvatures.
trimmed_activations = activations[:, :, start_y:end_y, start_x:end_x]
iso_curvatures, iso_fits, iso_contours = hist_funcs.iso_response_curvature_poly_fits(
    trimmed_activations,
    analysis_params['target_activity'],
    [y_scale_factor, x_scale_factor]
)

In [None]:
# One figure per batch of neurons: each row shows a neuron's response planes
# for the comparisons with minimum, median and maximum curvature.
num_targets = activations.shape[0]
num_neurons_per_batch = 20
# Integer ceil-division so a final, partially filled batch is still plotted;
# floor division silently skipped the last `num_targets % 20` neurons.
num_batches = (num_targets + num_neurons_per_batch - 1) // num_neurons_per_batch

# Grid coordinates and axis ranges are the same for every panel.
contour_dataset = mei_curvatures['contour_dataset']
yx_pts = (
    contour_dataset['y_pts'],
    contour_dataset['x_pts']
)
yx_range = (
    analysis_params['y_range'],
    analysis_params['x_range']
)

for batch_id in range(num_batches):
    start_id = batch_id * num_neurons_per_batch
    batch_neuron_ids = range(start_id, min(start_id + num_neurons_per_batch, num_targets))
    with plot.rc.context(**rc_kwargs):
        fig, axes = plot.subplots(
            nrows=len(batch_neuron_ids),
            ncols=3,
            sharex=True,
            sharey=True,
            wspace=0.05,
            hspace=0.05,
            axwidth=0.8
        )

    for ax_id, target_neuron_id in enumerate(batch_neuron_ids):
        neuron_curvatures = curvatures[target_neuron_id]
        # Comparison planes with the smallest, median and largest curvature.
        index_list = [
            np.argmin(neuron_curvatures),
            argmed(neuron_curvatures),
            np.argmax(neuron_curvatures)
        ]

        for ax, comp_neuron_id in zip(axes[ax_id, :], index_list):
            ax = plot_funcs.clear_axis(ax)
            activity = np.squeeze(
                mei_curvatures['response_images'][target_neuron_id, comp_neuron_id, ...]
            ).copy()
            proj_vects = (
                contour_dataset['proj_target_vect'][target_neuron_id][comp_neuron_id],
                contour_dataset['proj_comparison_vect'][target_neuron_id][comp_neuron_id],
                contour_dataset['proj_orth_vect'][target_neuron_id][comp_neuron_id],
            )
            contsf = plot_fits(
                ax,
                activity,
                iso_contours[target_neuron_id][comp_neuron_id],
                iso_fits[target_neuron_id][comp_neuron_id],
                yx_pts,
                yx_range,
                proj_vects,
                num_levels=10,
                title=''#f'C={np.round(curvatures[target_neuron_id][comp_neuron_id], 2)}'
            )
        # Row label, rotated and placed left of the first panel (data coordinates).
        unit_label = analysis_params['target_neuron_ids'][target_neuron_id]
        axes[ax_id, 0].text(
            -2.5,
            0,
            f'unit {unit_label}',
            #weight='bold',
            color='k',
            rotation=90,
            horizontalalignment='center',
            verticalalignment='center'
        )
    # Colorbar is built from the last contour set drawn in this batch.
    cb = axes[0, -1].colorbar(contsf, loc='r', ticks=0.25, label='')#'activation')
    for extension in file_extensions:
        save_name = analysis_params['output_directory']+f'iso_curvature_planes_all_b{batch_id}'+extension
        fig.savefig(save_name, transparent=True, bbox_inches="tight", pad_inches=0.05, dpi=dpi)

In [None]:
# Re-extract iso-response contours at a fixed level of 0.9 (only the contour
# points are kept); these select the images shown along each trajectory.
_, _, plot_iso_contours = hist_funcs.iso_response_curvature_poly_fits(
    activations[:, :, start_y:end_y, start_x:end_x],
    0.9,
    [y_scale_factor, x_scale_factor]
)

num_targets = activations.shape[0]
contour_dataset = mei_curvatures['contour_dataset']
grid_y_pts = contour_dataset['y_pts']
grid_x_pts = contour_dataset['x_pts']

# Each outer list has one entry per target neuron; each entry holds three
# items ordered [min-curvature, median-curvature, max-curvature] comparison.
iso_contour_data = []
iso_scatter_points = []
att_contour_data = []
att_x_indices = []
for target_neuron_id in range(num_targets):
    min_idx = np.argmin(curvatures[target_neuron_id])
    med_idx = argmed(curvatures[target_neuron_id])
    max_idx = np.argmax(curvatures[target_neuron_id])
    index_list = [min_idx, med_idx, max_idx]
    sub_iso_data_list = []
    sub_iso_scatter_list = []
    sub_att_data_list = []
    sub_att_x_index_list = []
    for comp_neuron_id in index_list:
        activity = mei_curvatures['response_images'][target_neuron_id][comp_neuron_id].copy()
        num_y, num_x = activity.shape
        #  iso-response images
        # Contour coordinates get their own names so the notebook-level grid
        # arrays `x_pts` / `y_pts` are no longer overwritten by this cell.
        contour_x_pts, contour_y_pts = plot_iso_contours[target_neuron_id][comp_neuron_id]
        num_iso_contour_pts = len(contour_x_pts)
        # Snap every contour point to the nearest grid sample: (row, column) indices.
        scatter_points = np.zeros((num_iso_contour_pts, 2), dtype=np.int32)
        for point_idx, (x_contour_loc, y_contour_loc) in enumerate(zip(contour_x_pts, contour_y_pts)):
            scatter_points[point_idx, 0] = int(np.abs(grid_y_pts - y_contour_loc).argmin())
            scatter_points[point_idx, 1] = int(np.abs(grid_x_pts - x_contour_loc).argmin())
        dataset = iso_data.inject_data(
            contour_dataset['proj_matrix'][target_neuron_id][comp_neuron_id],
            contour_dataset['proj_datapoints'],
            analysis_params['image_scale'])
        num_datapoints, num_y_pixels, num_x_pixels, num_channels = dataset.shape
        # Arrange the flat list of images as a (num_y, num_x) grid of images;
        # the reshape assumes a single channel.
        dataset = np.squeeze(dataset.reshape(num_y, num_x, num_y_pixels, num_x_pixels, 1))
        indiv_iso_contour_data = np.zeros((num_iso_contour_pts, num_y_pixels, num_x_pixels))
        for contour_data_index in range(num_iso_contour_pts):
            y_loc = scatter_points[contour_data_index, 0]
            x_loc = scatter_points[contour_data_index, 1]
            indiv_iso_contour_data[contour_data_index, ...] = dataset[y_loc, x_loc, ...]
        sub_iso_data_list.append(indiv_iso_contour_data)
        sub_iso_scatter_list.append(scatter_points)
        #  response attenuation images: one grid column, all rows
        x_target_index = 3*num_x//4+1 # Approx MEI location, assuming window is -2, 2 and mei is at 1
        indiv_att_contour_data = dataset[:, x_target_index, :, :]
        sub_att_data_list.append(indiv_att_contour_data)
        sub_att_x_index_list.append(x_target_index)
    iso_contour_data.append(sub_iso_data_list)
    iso_scatter_points.append(sub_iso_scatter_list)
    att_contour_data.append(sub_att_data_list)
    att_x_indices.append(sub_att_x_index_list)

# Preview for one neuron: images evenly spaced along the iso-response contour
# of the min / median / max curvature comparison (one row each).
target_neuron_id = 0
num_images = 11
fig, axes = plot.subplots(ncols=num_images, nrows=3, axwidth=0.5)
for row_idx in range(3): # min, med, max
    num_contour_pts = iso_contour_data[target_neuron_id][row_idx].shape[0]
    image_indices = np.round(np.linspace(0, num_contour_pts - 1, num_images)).astype(int)
    for col_idx, img_idx in enumerate(image_indices):
        ax = plot_funcs.clear_axis(axes[row_idx, col_idx])
        ax.imshow(iso_contour_data[target_neuron_id][row_idx][img_idx, ...], cmap='greys_r', vmin=-3, vmax=3)
plot.show()

# Preview of the attenuation trajectory for the max-curvature comparison.
# The slot is named explicitly; the point count was previously read through
# `row_idx` left over from the loop above.
target_neuron_id = 0
num_images = 11
max_curvature_slot = 2
att_images = att_contour_data[target_neuron_id][max_curvature_slot]
fig, axes = plot.subplots(ncols=num_images, nrows=1, axwidth=0.5)
num_contour_pts = att_images.shape[0]
image_indices = np.round(np.linspace(0, num_contour_pts - 1, num_images)).astype(int)
for col_idx, img_idx in enumerate(image_indices):
    ax = plot_funcs.clear_axis(axes[col_idx])
    ax.imshow(att_images[img_idx, ...], cmap='greys_r', vmin=-3, vmax=3)
plot.show()

In [None]:
# Sanity check: overlay the grid samples chosen for the iso-response (blue)
# and attenuation (green) trajectories on the response plane they came from.
target_neuron_id = 0
# iso_scatter_points / att_x_indices hold three entries per neuron, ordered
# [min, median, max curvature]; they are NOT indexed by comparison plane.
curvature_slot = 0
slot_to_plane = [
    np.argmin(curvatures[target_neuron_id]),
    argmed(curvatures[target_neuron_id]),
    np.argmax(curvatures[target_neuron_id])
]
# Show the comparison plane that actually belongs to the chosen slot; the
# plane used to be indexed with the slot number itself.
comp_neuron_id = slot_to_plane[curvature_slot]
scatter_points = iso_scatter_points[target_neuron_id][curvature_slot]
x_target = att_x_indices[target_neuron_id][curvature_slot]
response_image = mei_curvatures['response_images'][target_neuron_id, comp_neuron_id, ...]
# Take the plot extent from the image instead of a hard-coded 29.
num_rows, num_cols = response_image.shape
fig, ax = plot.subplots()
ax.imshow(response_image)
# scatter_points columns are (row, column), i.e. (y, x).
ax.scatter(scatter_points[:,1], scatter_points[:,0], s=5, c='b')
ax.scatter([x_target,]*num_rows, np.arange(num_rows), s=5, c='g')
ax.format(xlim=[0, num_cols - 1], ylim=[0, num_rows - 1], gridminor=True)
plot.show()

In [None]:
# Predicted responses, one column per unit. Rows are treated as blocks of
# num_orients * num_shifts entries, one block per frequency.
tuning_predictions = np.load(santi_etc_path+'/tuning_predictions.npy')
num_shifts, num_orients = 16, 16
# NOTE: this rebinds num_neurons, which an earlier cell set to len(curvatures).
num_neurons = tuning_predictions.shape[1]
divider = num_orients * num_shifts
max_predictions = []
for unit in range(num_neurons):
    unit_predictions = tuning_predictions[:, unit]
    # Block (frequency) that contains this unit's strongest predicted response.
    max_freq = np.argmax(unit_predictions) // divider #best frequency
    freq_start = max_freq * divider
    best_block = unit_predictions[freq_start:freq_start + divider]
    # NOTE(review): the reshape labels axis 0 as orientation and axis 1 as
    # shift; the plotting cells reduce over axis 0 for the orientation curve.
    # Confirm the stored ordering.
    max_predictions.append(best_block.reshape([num_orients, num_shifts]))
# Axis values in degrees: orientations cover [0, 180), shifts cover [0, 360).
orients = np.linspace(0, np.pi, num_orients + 1)[:-1]/np.pi*180
shifts  = np.linspace(0, 2 * np.pi, num_shifts + 1)[:-1]/np.pi*180

In [None]:
# Orientation and phase tuning curves for one example unit.
unit = 28
unit_tuning = max_predictions[unit]
fig, axes = plot.subplots(nrows=1, ncols=2, sharex=False)

# (axis, x values, axis reduced by the max, x ticks, title, x label)
panel_specs = (
    (axes[0], orients, 0, [0,90,180],
     f'Orientation selectivity for unit {unit}', 'Orientation (deg)'),
    (axes[1], shifts, 1, [0, 180, 360],
     f'Phase selectivity for unit {unit}', 'Phase (deg)'),
)
for panel_ax, x_values, reduce_axis, tick_locs, panel_title, x_label in panel_specs:
    panel_ax.plot(x_values, np.max(unit_tuning, axis=reduce_axis))
    # Common y-range: zero up to the unit's maximum prediction plus one.
    panel_ax.set_ylim([0, np.max(unit_tuning)+1])
    panel_ax.set_xticks(tick_locs)
    panel_ax.format(
        title = panel_title,
        xlabel=x_label
    )

axes.format(ylabel='unit activation')

plot.show()

In [None]:
# Setup for the summary figure: one row per selected unit, ordered by median
# curvature, with trajectory images, response planes, a curvature histogram
# and the two tuning curves.
num_vis_neurons = 9

num_target_neurons = len(analysis_params['target_neuron_ids'])

# Alternative ways of choosing the displayed neurons, kept for reference:
#neuron_indices = median_argsort[np.round(np.linspace(0, median_argsort.size - 1, num_vis_neurons)).astype(int)]

#rand_indices = np.random.choice(np.arange(num_target_neurons), num_vis_neurons, replace=False).astype(int)
#neuron_indices = median_argsort[rand_indices]

#target_units = [86, 136, 4, 138, 34, 100, 33, 32, 89]
#target_units = [0, 4, 16, 25, 51, 59, 91, 102, 153, 103]
#target_units = [1, 28, 134, 74, 107, 115, 125, 87]
target_units = [1, 107, 102, 115, 125, 134, 153, 24]
num_vis_neurons = len(target_units)
target_neuron_ids = np.array(analysis_params['target_neuron_ids'])
# Position of each requested unit within the analysed neuron list
# (.item() raises unless the unit occurs exactly once).
unsorted_neuron_indices = [np.argwhere(target_neuron_ids == unit).item() for unit in target_units]
# Same neurons, re-ordered by ascending median curvature.
neuron_indices = [index for index in median_argsort if index in unsorted_neuron_indices]
unit_indices = [target_neuron_ids[index] for index in neuron_indices]

num_examples_per_trajectory = 3
# min trajectory + max trajectory images, then 3 planes, histogram, 2 tuning plots
num_columns = 2*num_examples_per_trajectory + 6
yx_pts = (
    mei_curvatures['contour_dataset']['y_pts'],
    mei_curvatures['contour_dataset']['x_pts']
)
yx_range = (
    analysis_params['y_range'],
    analysis_params['x_range']
)

# Histogram range from the displayed neurons only: symmetric around zero and
# clipped at the pooled 1st / 98th percentiles.
all_curvatures = [curvature for index in neuron_indices for curvature in curvatures[index]]
lower_percentile = 1
upper_percentile = 98
lower_curvature = np.percentile(all_curvatures, lower_percentile)
upper_curvature = np.percentile(all_curvatures, upper_percentile)
max_curvature = np.max([np.abs(lower_curvature), np.abs(upper_curvature)])
plot_min, plot_max = -max_curvature, max_curvature
bins = hist_funcs.get_bins(num_bins, plot_min, plot_max)

# Column spacing: small gaps inside a group of panels, big gaps between groups.
sm_gap = 0.16
bg_gap = 3*sm_gap#0.3
trajectory_gaps = [sm_gap,]*(num_examples_per_trajectory-1)
wspaces = trajectory_gaps+[bg_gap]+trajectory_gaps+[bg_gap]+[sm_gap,]*3+[bg_gap, sm_gap]
# Larger font for this figure; get_rc_args also updates mpl.rcParams globally.
fontsize = 16
rc_kwargs = get_rc_args(fontsize)
with plot.rc.context(**rc_kwargs):
    fig, axes = plot.subplots(
        nrows=num_vis_neurons,
        ncols=num_columns,
        wspace=wspaces,
        hspace=sm_gap,
        axwidth=0.9,
        sharex=False,
        sharey=False
    )

# Fill the summary figure row by row. Columns per row: iso-response trajectory
# images (min-curvature comparison), attenuation trajectory images
# (max-curvature comparison), three response planes (min / median / max
# curvature), the curvature histogram, and orientation and phase tuning.

# `bins` is treated as bin edges (assumed to be a NumPy array); each bar sits
# at the midpoint of its bin. The earlier `lefts + (rights - lefts)` evaluated
# to the right edges and shifted the histograms by half a bin.
bin_lefts, bin_rights = bins[:-1], bins[1:]
bin_centers = bin_lefts + (bin_rights - bin_lefts) / 2

def _show_trajectory_image(ax, plot_image):
    """Show one stimulus image with a dashed red crosshair near its center."""
    ax.imshow(plot_image, cmap='greys_r', vmin=-3, vmax=3)
    # Geometry comes from the image itself rather than from whatever image
    # happened to be drawn last.
    num_y, num_x = plot_image.shape
    center_y = num_y//2-1
    center_x = num_x//2-1
    ax.plot([0, num_x], [center_y, center_y], color='r', linestyle='--', linewidth=0.3)
    ax.plot([center_x, center_x], [0, num_y], color='r', linestyle='--', linewidth=0.3)

for row_idx in range(num_vis_neurons):
    target_neuron_id = neuron_indices[row_idx]
    unit_id = unit_indices[row_idx]
    col_idx = 0
    #axes[row_idx, col_idx].format(title=f'target neuron {target_neuron_id} unit {unit_id}')
    # min curvature iso-trajectory images, evenly spaced along the contour
    min_contour_images = iso_contour_data[target_neuron_id][0] # 0 index is min
    num_contour_pts = min_contour_images.shape[0]
    min_image_indices = np.round(np.linspace(0, num_contour_pts - 1, num_examples_per_trajectory)).astype(int)
    for min_image_index in min_image_indices:
        ax = plot_funcs.clear_axis(axes[row_idx, col_idx])
        _show_trajectory_image(ax, min_contour_images[min_image_index, ...])
        col_idx += 1

    # max curvature att-trajectory images, sampled from the second half of the trajectory
    max_contour_images = att_contour_data[target_neuron_id][2] # 2 index is max
    num_contour_pts = max_contour_images.shape[0]
    max_image_indices = np.round(np.linspace(num_contour_pts//2, num_contour_pts - 1, num_examples_per_trajectory)).astype(int)
    for max_image_index in max_image_indices:
        ax = plot_funcs.clear_axis(axes[row_idx, col_idx])
        _show_trajectory_image(ax, max_contour_images[max_image_index, ...])
        col_idx += 1

    # curvatures: response planes for the min / median / max curvature comparisons
    min_idx = np.argmin(curvatures[target_neuron_id])
    med_idx = argmed(curvatures[target_neuron_id])
    max_idx = np.argmax(curvatures[target_neuron_id])
    index_list = [min_idx, med_idx, max_idx]
    for comp_neuron_id in index_list:
        ax = plot_funcs.clear_axis(axes[row_idx, col_idx])
        proj_vects = (
            mei_curvatures['contour_dataset']['proj_target_vect'][target_neuron_id][comp_neuron_id],
            mei_curvatures['contour_dataset']['proj_comparison_vect'][target_neuron_id][comp_neuron_id],
            mei_curvatures['contour_dataset']['proj_orth_vect'][target_neuron_id][comp_neuron_id],
        )
        activity = mei_curvatures['response_images'][target_neuron_id, comp_neuron_id, ...]
        contsf = plot_fits(
            ax,
            activity,
            iso_contours[target_neuron_id][comp_neuron_id],
            iso_fits[target_neuron_id][comp_neuron_id],
            yx_pts,
            yx_range,
            proj_vects,
            num_levels=10,
            title=''
        )
        col_idx += 1

    # histogram of this neuron's curvatures, with a dashed red line at zero
    ax = plot_funcs.clear_axis(axes[row_idx, col_idx], spines='k')
    hist, bin_edges = hist_funcs.get_relative_hist(curvatures[target_neuron_id], bins)
    ax.plot(bin_centers, hist, linestyle='-', drawstyle='steps-mid',
        color='k', linewidth=0.5, zorder=1)
    ax.fill_between(bin_centers, 0, hist, color='k', zorder=2)
    ax.plot([0, 0], [0, np.max(hist)], color='r', linestyle='--', linewidth=2, zorder=3)
    ax.format(xlim=[plot_min, plot_max])
    col_idx += 1

    # tuning
    # orientation
    # NOTE(review): max_predictions is indexed with the unit id here; confirm
    # that matches how tuning_predictions columns are ordered.
    ax = plot_funcs.clear_axis(axes[row_idx, col_idx], spines='k')
    ax.plot(orients, np.max(max_predictions[unit_id], axis=0), linewidth=2, color='k')
    ax.format(
        ylim = [0, np.max(max_predictions[unit_id])+1],
        xlim = [0, 180]
    )
    col_idx += 1

    # phase
    ax = plot_funcs.clear_axis(axes[row_idx, col_idx], spines='k')
    ax.plot(shifts, np.max(max_predictions[unit_id], axis=1), linewidth=2, color='k')
    ax.format(
        ylim = [0, np.max(max_predictions[unit_id])+1],
        xlim = [0, 360]
    )
    
# Only the bottom row gets x tick labels, and only for the last three columns.
# Each entry: (column, tick locations, extra tick_params keyword arguments).
bottom_row_ticks = (
    (-3, [np.round(plot_min, 1), 0, np.round(plot_max, 1)], {}),  # last curvature plot
    (-2, [0, 90, 180], {'rotation': -45}),                        # last orientation plot
    (-1, [0, 180, 360], {'rotation': -45}),                       # last phase plot
)
for tick_column, ticks, extra_tick_params in bottom_row_ticks:
    bottom_ax = axes[-1, tick_column]
    bottom_ax.get_xaxis().set_visible(True)
    bottom_ax.tick_params(axis='both', bottom=True, **extra_tick_params)
    bottom_ax.format(xticks=ticks, xticklabels=[f'{tick}' for tick in ticks])

plot.show()
# Write the figure once per configured file extension.
for extension in file_extensions:
    save_name = analysis_params['output_directory']+'curvatures_and_trajectory_images'+extension
    fig.savefig(save_name, transparent=True, bbox_inches="tight", pad_inches=0.05, dpi=dpi)

In [None]:
# Scatter each neuron's median curvature against its model performance (FEV).
# NOTE(review): pickle.load can execute arbitrary code; only open meis.pkl
# from a trusted source.
meis_path = os.path.join(santi_etc_path, 'meis.pkl')
with open(meis_path, 'rb') as mei_file:
    meis = pickle.load(mei_file)
images = meis['images']
# NOTE(review): this rebinds `activations`, which earlier cells use for the
# response planes; re-run the trimming cell before re-running those cells.
activations = meis['activations']
performance = meis['performance']
# Keep only the analysed neurons, in the same order as `curvatures`.
performance = [performance[idx] for idx in analysis_params['target_neuron_ids']]
curvature_medians = [np.median(curvature) for curvature in curvatures]
fig, ax = plot.subplots()
ax.scatter(curvature_medians, performance, s=3, c='k')
# plot_min / plot_max come from whichever cell set them last; points outside
# that range are not visible.
ax.format(xlim=[plot_min, plot_max], xlabel='Median Curvature', ylabel='FEV')
plot.show()