In [None]:
import os
import sys

import pickle
import numpy as np
import scipy
from scipy import optimize
from skimage import measure
import numpy.polynomial.polynomial as poly

import matplotlib as mpl
import matplotlib.pyplot as plt
import proplot as plot

# Make the project packages importable: add the notebook's parent directory and the
# directory above it (the repository root) to sys.path. os.getcwd() replaces the
# IPython-only `!pwd` shell magic, and os.path replaces manual '/' splitting, so the
# cell also runs outside IPython and on non-POSIX paths.
current_path = [os.getcwd()]  # one-element list, same shape as the former `!pwd` output
parent_path = os.path.dirname(current_path[0])
if parent_path not in sys.path:
    sys.path.append(parent_path)
root_path = os.path.dirname(parent_path)
if root_path not in sys.path:
    sys.path.append(root_path)
santi_path = os.path.join(root_path, "santi_iso_response")
santi_etc_path = os.path.join(santi_path, 'etc')

import utils.model_handling as model_funcs
import utils.dataset_generation as iso_data
import utils.histogram_analysis as hist_funcs
import utils.plotting as plot_funcs

from santi_iso_response.iso_response import utils as santi_utils

In [None]:
# Figure geometry, typography and export settings shared by every figure below.
width_fraction = 1.0
text_width = 540.60236  # pt
fontsize = 10
dpi = 300
file_extensions = ['.pdf']  # other formats that can be added: '.eps', '.png'
font_settings = {
    "text.usetex": True,
    "font.family": 'serif',
    "font.serif": 'Computer Modern Roman',
    "font.size": fontsize,
    "figure.titlesize": fontsize + 2,
    # body-sized text
    **{key: fontsize for key in ("axes.labelsize", "axes.titlesize", "legend.fontsize")},
    # tick labels are two points smaller
    **{key: fontsize - 2 for key in ("xtick.labelsize", "ytick.labelsize")},
}
mpl.rcParams.update(font_settings)
# Keyword overrides later passed to proplot's rc.context, mirroring the settings above.
rc_kwargs = dict.fromkeys(
    ('fontsize', 'legend.fontsize', 'text.labelsize'),
    mpl.rcParams['font.size'],
)
rc_kwargs['fontfamily'] = mpl.rcParams['font.family']

In [None]:
# Saved analysis outputs all live in <root>/iso_analysis/ and share a common prefix.
results_path = f'{root_path}/iso_analysis/'
save_prefix = 'santi'
params_file, mei_curvature_file, mes_curvature_file, rand_curvature_file = (
    f'{results_path}{save_prefix}_{suffix}.npz'
    for suffix in ('meis_params', 'meis', 'stim', 'rand')
)

In [None]:
# Load the analysis parameters (a pickled dict stored under 'data').
# The context manager closes the underlying .npz file handle, which np.load otherwise leaves open.
# NOTE: allow_pickle=True unpickles data, which can execute code -- only load trusted files.
with np.load(params_file, allow_pickle=True) as params_archive:
    analysis_params = params_archive['data'].item()
label_list = ["Neuron "+str(neuron_id) for neuron_id in analysis_params["target_neuron_ids"]]
print('\n'.join(f'{key}\t\t{val}' for key, val in analysis_params.items()))

In [None]:
# Load the precomputed MEI curvature results (a pickled dict under 'data') and list its keys.
# The context manager closes the .npz file handle; allow_pickle unpickles, so only load trusted files.
with np.load(mei_curvature_file, allow_pickle=True) as curvature_archive:
    mei_curvatures = curvature_archive['data'].item()
print('\n'.join(mei_curvatures.keys()))

In [None]:
num_bins = 50

curvature_type = 'iso_curvatures'
#curvature_type = 'attn_curvatures'
# Title and output file name follow the selected curvature type automatically.
curvature_titles = {
    'iso_curvatures': 'Neuron iso-response curvature histograms, sorted by median value',
    'attn_curvatures': 'Neuron response attenuation curvature histograms, sorted by median value',
}
curvature_prefix = curvature_type.split('_')[0]  # 'iso' or 'attn'

num_neurons = len(mei_curvatures[curvature_type])

curvature_means = [np.mean(curvature) for curvature in mei_curvatures[curvature_type]]
curvature_medians = [np.median(curvature) for curvature in mei_curvatures[curvature_type]]
median_argsort = np.argsort(curvature_medians)

# Near-square grid: floor(sqrt(N)) columns and just enough rows to hold every neuron.
# (The former `N // edge + 1` added an empty row whenever N was a multiple of the edge.)
num_neurons_edge = int(np.sqrt(num_neurons))
ncols = num_neurons_edge
nrows = -(-num_neurons // ncols)  # ceiling division

test_bin_min = -0.13
test_bin_max = 0.13
#test_bin_min = -0.25
#test_bin_max = 0.25

bins = hist_funcs.get_bins(num_bins, test_bin_min, test_bin_max)
# Midpoint of every bin; used with drawstyle='steps-mid' so each step is centred on its bin.
# (Previously `lefts + (rights - lefts)`, i.e. the right edges, shifting every histogram by half a bin.)
bin_centers = bins[:-1] + (bins[1:] - bins[:-1]) / 2

axwidth = 0.5
fig, axes = plot.subplots(
    nrows=nrows,
    ncols=ncols,
    sharex=True,
    sharey=True,
    axwidth=axwidth)

# Hand-labelled neurons; their spines argument to clear_axis is 'r' (simple) or 'b' (complex).
simple_cell_ids = [1, 32, 34, 47, 53]
complex_cell_ids = [2, 4, 19, 24, 29, 33, 38]
# Figure-coordinate offsets of the green zero-crossing marker.
x_offset = -0.08
y_offset = -0.02
for row_idx in range(nrows):
    for col_idx in range(ncols):
        plot_idx = row_idx * ncols + col_idx  # row-major position in median-sorted order
        if plot_idx >= num_neurons:
            ax = plot_funcs.clear_axis(axes[row_idx, col_idx])  # unused grid cell
            continue
        unsort_idx = median_argsort[plot_idx]
        neuron_id = analysis_params['target_neuron_ids'][unsort_idx]
        if neuron_id in simple_cell_ids:
            spines = 'r'
        elif neuron_id in complex_cell_ids:
            spines = 'b'
        else:
            spines = 'none'
        ax = plot_funcs.clear_axis(axes[row_idx, col_idx], spines=spines)
        hist, _ = hist_funcs.get_relative_hist(mei_curvatures[curvature_type][unsort_idx], bins)
        ax.plot(bin_centers, hist, linestyle='-', drawstyle='steps-mid',
            color='k', linewidth=0.5, label=label_list[unsort_idx], zorder=1)
        ax.fill_between(bin_centers, 0, hist, color='k', zorder=2)
        ax.plot([0, 0], [0, np.max(hist)], color='r', linestyle='--', linewidth=0.5, zorder=3)
        curr_curv_med = np.round(curvature_medians[unsort_idx], 2)
        ax.format(title=f'n{neuron_id}_c{curr_curv_med:0.2f}')
        if plot_idx >= 1:
            prev_curv_med = np.round(curvature_medians[median_argsort[plot_idx - 1]], 2)
            # Mark the panel where the sorted (rounded) medians step from negative to >= 0, or
            # from <= 0 to positive. A direct negative-to-positive step satisfies both tests,
            # which previously drew the same line twice; now it is drawn once.
            if (prev_curv_med < 0.0 <= curr_curv_med) or (prev_curv_med <= 0.0 < curr_curv_med):
                # NOTE(review): get_position() is read before proplot finalises its layout,
                # so the marker may be misplaced if the layout changes at draw time.
                pos = ax.get_position().get_points()  # [[x0, y0], [x1, y1]] in figure coords
                xpos = [pos[1][0] + x_offset] * 2
                ypos = [pos[0][1] + y_offset, pos[1][1] + y_offset]
                ax.plot(xpos, ypos, color='g', lw=2, transform=fig.transFigure, clip_on=False)
axes.format(
    xlim=[test_bin_min, test_bin_max],
    suptitle=curvature_titles[curvature_type]
)
plot.show()

for extension in file_extensions:
    save_name = analysis_params['output_directory']+f'{curvature_prefix}_curvature_histograms'+extension
    fig.savefig(save_name, transparent=True, bbox_inches="tight", pad_inches=0.05, dpi=dpi)

In [None]:
# Median curvature (rounded to 3 decimals) of the hand-labelled simple and complex cells,
# listed in the order the neurons appear in analysis_params['target_neuron_ids'].
simple_cell_ids = [1, 32, 34, 47, 53]
complex_cell_ids = [2, 4, 19, 24, 29, 33, 38]
simple_medians = [
    np.round(curvature_medians[position], 3)
    for position, cell_id in enumerate(analysis_params['target_neuron_ids'])
    if cell_id in simple_cell_ids
]
complex_medians = [
    np.round(curvature_medians[position], 3)
    for position, cell_id in enumerate(analysis_params['target_neuron_ids'])
    if cell_id in complex_cell_ids
]
print(f'simple median curvatures = {simple_medians}\ncomplex median curvatures = {complex_medians}')

In [None]:
target_neuron_id = 24  # other neurons tried: 32, 33, 49, 105

# Load the stored MEI images and the model.
# NOTE: pickle.load can execute arbitrary code -- only open trusted files.
with open(os.path.join(santi_etc_path, 'meis.pkl'), 'rb') as mei_pickle:
    mei = pickle.load(mei_pickle)
neuron_mei = list(mei['images'])
model = santi_utils.load_model()

# Comparison vectors for the target neuron and the contour dataset built from them.
iso_vectors = iso_data.compute_comp_vectors(
    neuron_mei, [target_neuron_id], analysis_params['min_angle'], len(neuron_mei),
    comp_method='closest')
contour_dataset, _ = iso_data.get_contour_dataset(
    iso_vectors[1], iso_vectors[2],
    analysis_params['x_range'], analysis_params['y_range'],
    analysis_params['num_images'], analysis_params['image_scale'],
    return_datapoints=False)

# For each projection matrix: build datapoints with inject_data, get the target neuron's
# normalized activations, and keep the first iso / attenuation curvature estimate.
iso_curvatures_list, attn_curvatures_list = [], []
for proj_matrix in contour_dataset['proj_matrix'][0]:
    datapoints = iso_data.inject_data(
        proj_matrix, contour_dataset['proj_datapoints'], analysis_params['image_scale'])
    activations = model_funcs.get_normalized_activations(
        model, [target_neuron_id], [[datapoints]], santi_utils.get_activations_cell)
    iso_curvatures, attn_curvatures = hist_funcs.compute_curvature_poly_fits(
        activations, contour_dataset, analysis_params['target_activity'],
        measure_upper_right=True)
    iso_curvatures_list.append(iso_curvatures[0][0])
    attn_curvatures_list.append(attn_curvatures[0][0])

In [None]:
def argmed(array):
    """Return the index of the element closest to the median of ``array``.

    For even-length input the median is the mean of the two middle values; ties
    resolve to the lowest index (``argmin`` semantics).
    """
    values = np.array(array)
    distance_to_median = np.abs(values - np.median(values))
    return np.argmin(distance_to_median)
# Pick the planes with the smallest, median and largest iso-response curvature.
min_idx = np.argmin(iso_curvatures_list)
med_idx = argmed(iso_curvatures_list)
max_idx = np.argmax(iso_curvatures_list)
index_list = [min_idx, med_idx, max_idx]
curvatures_list = [iso_curvatures_list[idx] for idx in index_list]

# Recompute the target neuron's activations on each selected plane for plotting.
activations_list = []
for comp_idx in index_list:
    datapoints = iso_data.inject_data(
        contour_dataset['proj_matrix'][0][comp_idx],
        contour_dataset['proj_datapoints'],
        analysis_params['image_scale'])
    activations = model_funcs.get_normalized_activations(
        model, [target_neuron_id], [[datapoints]], santi_utils.get_activations_cell)
    activations_list.append(np.squeeze(activations))

In [None]:
def add_arrow(ax, vect, xrange, x_y_offset, label, text_color='k'):
    """Draw a black arrow from the origin to ``vect`` and place a bold label near its tip.

    Args:
        ax: axis to draw on (its ``arrow`` and ``text`` methods are used).
        vect: 2-element vector; each entry may be a Python scalar or a size-1 numpy array.
        xrange: width of the plotted x-range; label offsets are in units of ``xrange / 10``.
        x_y_offset: (x, y) label offset from the tip, in tenths of ``xrange`` (both axes).
        label: text placed near the arrow tip.
        text_color: colour of the label text.

    Returns:
        The text artist created for the label.
    """
    arrow_linewidth = 1
    arrow_head_size = 0.15  # used for both head length and head width
    # np.asarray(...).item() accepts plain floats as well as numpy scalars / size-1 arrays.
    tip_x = np.asarray(vect[0]).item()
    tip_y = np.asarray(vect[1]).item()
    ax.arrow(0, 0, tip_x, tip_y,
        width=0.0, head_width=arrow_head_size, head_length=arrow_head_size,
        fc='k', ec='k', linestyle='-', linewidth=arrow_linewidth)
    label_shift = xrange / 10  # both coordinates are shifted in units of the x-range
    return ax.text(
        tip_x + label_shift * x_y_offset[0],
        tip_y + label_shift * x_y_offset[1],
        label,
        weight='bold',
        color=text_color,
        horizontalalignment='center',
        verticalalignment='center'
    )

In [None]:
num_levels = 10
num_panels = len(activations_list)  # min / median / max curvature planes
# NOTE(review): assumes set_size's `subplot` argument is [nrows, ncols]; the original
# passed [1, 2] although three panels are drawn.
figsize = plot_funcs.set_size(text_width, fraction=width_fraction, subplot=[1, num_panels])

# One colour scale for every panel so the single shared colorbar is valid for all of them
# (previously each panel had its own vmin/vmax but only the last panel's scale was shown).
vmin = min(np.min(panel_activations) for panel_activations in activations_list)
vmax = max(np.max(panel_activations) for panel_activations in activations_list)
levels = np.linspace(vmin, vmax, num_levels)
cmap = plt.get_cmap('cividis')
x_mesh, y_mesh = np.meshgrid(contour_dataset['x_pts'], contour_dataset['y_pts'])
xrange = max(analysis_params['x_range']) - min(analysis_params['x_range'])

# (label, contour_dataset key, label offset in tenths of xrange) for each arrow, in draw order.
arrow_specs = [
    (r'$\Phi_{k}$', 'proj_target_vect', [0.6 / width_fraction, -1.2 / width_fraction]),
    (r'$\Phi_{j}$', 'proj_comparison_vect', [0.9 / width_fraction, 0.3 / width_fraction]),
    (r'$\nu$', 'proj_orth_vect', [-0.56 / width_fraction, 0.3 / width_fraction]),
]

# Draw inside the rc context so the font overrides also apply to titles and labels.
with plot.rc.context(**rc_kwargs):
    fig, axes = plot.subplots(nrows=1, ncols=num_panels, figsize=figsize)
    for ax, activations, curvature, comp_idx in zip(axes, activations_list, curvatures_list, index_list):
        ax = plot_funcs.clear_axis(ax)
        contsf = ax.contourf(x_mesh, y_mesh, activations,
            levels=levels, vmin=vmin, vmax=vmax, alpha=1.0, antialiased=True, cmap=cmap)
        for label, vect_key, label_offset in arrow_specs:
            add_arrow(ax, contour_dataset[vect_key][0][comp_idx], xrange, label_offset, label)
        # Axis cross through the origin
        ax.set_aspect('equal')
        ax.plot(analysis_params['x_range'], [0, 0], color='k', linewidth=1/2)
        ax.plot([0, 0], analysis_params['y_range'], color='k', linewidth=1/2)
        ax.format(title=f'curvature = {np.round(curvature, 2)}')
    cb = axes[-1].colorbar(contsf, loc='r', label='activation', ticks=0.25)

for extension in file_extensions:
    save_name = analysis_params['output_directory']+f'iso_curvature_planes_n{target_neuron_id}'+extension
    fig.savefig(save_name, transparent=True, bbox_inches="tight", pad_inches=0.05, dpi=dpi)

In [None]:
target_act = 0.55
activations = activations_list[1]  # plane selected by index_list[1] (median curvature)

num_y, num_x = activations.shape
x_pts = contour_dataset['x_pts']
y_pts = contour_dataset['y_pts']
# Rows of `activations` follow y_pts and columns follow x_pts, as in the meshgrid used for plotting.
assert len(x_pts) == num_x and len(y_pts) == num_y, 'activation grid does not match x_pts/y_pts'

# find_contours returns one (row, col) array in fractional pixel units per disconnected
# iso-line; the first one is analysed.
found_contours = measure.find_contours(activations, target_act)
if not found_contours:
    raise ValueError(
        f'No contour at activation {target_act}; choose a level inside '
        f'[{np.min(activations):.3f}, {np.max(activations):.3f}]')
contours = found_contours[0]
# Map fractional pixel indices onto the plotting coordinates by interpolating along the
# sample grid (assumes x_pts / y_pts are increasing). The former `idx * range/num - range/2`
# mapping assumed a zero-centred grid with spacing range/num, which is off by up to one
# sample spacing when the grid includes both endpoints (e.g. np.linspace).
x_vals = np.interp(contours[:, 1], np.arange(num_x), x_pts)
y_vals = np.interp(contours[:, 0], np.arange(num_y), y_pts)

# Quadratic fit of x as a function of y (coefficients lowest order first). The y**2
# coefficient is used as the curvature summary; the geometric curvature at the vertex of
# x = a + b*y + c*y**2 is 2*c.
coeffs = poly.polyfit(y_vals, x_vals, deg=2)
poly_curvature = coeffs[-1]

In [None]:
class ComputeCurvature:
    """Least-squares circle fit to 2-D points; the curvature is 1 / fitted radius.

    The fit minimises, with scipy.optimize.leastsq and an analytic Jacobian, the spread of
    the points' distances from a candidate centre (apparently adapted from the SciPy
    Cookbook "least squares circle" recipe). After ``fit`` the centre is in
    ``xc``/``yc`` and the radius in ``r``.
    """

    def __init__(self):
        """Initialise the fit state; values are placeholders until ``fit`` is called."""
        self.xc = 0  # x-coordinate of the fitted circle centre
        self.yc = 0  # y-coordinate of the fitted circle centre
        self.r = 0   # radius of the fitted circle
        self.xx = np.array([])  # x-coordinates of the data points
        self.yy = np.array([])  # y-coordinates of the data points

    def calc_r(self, xc, yc):
        """Return the distance of every data point from the centre (xc, yc)."""
        return np.sqrt((self.xx - xc)**2 + (self.yy - yc)**2)

    def f(self, c):
        """Residuals for leastsq: each point's distance from c=(xc, yc) minus the mean distance."""
        ri = self.calc_r(*c)
        return ri - ri.mean()

    def df(self, c):
        """Jacobian of ``f`` with respect to (xc, yc), shape (2, n_points).

        Rows index the parameters, matching ``col_deriv=True`` in ``fit``.
        """
        xc, yc = c
        ri = self.calc_r(xc, yc)
        df_dc = np.empty((len(c), self.xx.size))
        df_dc[0] = (xc - self.xx) / ri  # d ri / d xc
        df_dc[1] = (yc - self.yy) / ri  # d ri / d yc
        # f subtracts the mean distance, so its derivative subtracts the mean derivative.
        return df_dc - df_dc.mean(axis=1)[:, np.newaxis]

    def fit(self, xx, yy):
        """Fit a circle to the points (xx, yy) and return its curvature 1 / r.

        Args:
            xx, yy: equal-shape sequences or arrays of point coordinates.

        Returns:
            The curvature (inverse radius) of the fitted circle.

        Raises:
            ValueError: if the shapes differ or fewer than 3 points are given
                (a circle is not determined by fewer points).
        """
        # Accept lists as well as arrays; the residual code relies on array arithmetic.
        self.xx = np.asarray(xx, dtype=float)
        self.yy = np.asarray(yy, dtype=float)
        if self.xx.shape != self.yy.shape:
            raise ValueError('xx and yy must have the same shape')
        if self.xx.size < 3:
            raise ValueError('at least 3 points are needed to fit a circle')
        # Start from the centroid of the points.
        center_estimate = np.r_[self.xx.mean(), self.yy.mean()]
        center = optimize.leastsq(self.f, center_estimate, Dfun=self.df, col_deriv=True)[0]

        self.xc, self.yc = center
        self.r = self.calc_r(*center).mean()
        return 1 / self.r  # curvature


# Fit a circle to the extracted iso-response contour, for comparison with the polynomial fit.
comp_curv = ComputeCurvature()
x, y = x_vals, y_vals
curvature = comp_curv.fit(x, y)

# Sample the fitted circle at 180 angles around the full circumference (plotted in the next cell).
theta_fit = np.linspace(-np.pi, np.pi, 180)
x_fit = comp_curv.xc + comp_curv.r * np.cos(theta_fit)
y_fit = comp_curv.yc + comp_curv.r * np.sin(theta_fit)

In [None]:

# Overlay the extracted contour (red), the quadratic fit (green) and the circle fit (black)
# on the analysed activation plane.
fig, ax = plot.subplots()
vmin = np.min(activations)
vmax = np.max(activations)
cmap = plt.get_cmap('cividis')
# Contour levels come from this plane; the global `levels` from the earlier figure
# belonged to a different plane and colour range.
levels = np.linspace(vmin, vmax, num_levels)
x_mesh, y_mesh = np.meshgrid(x_pts, y_pts)
contsf = ax.contourf(x_mesh, y_mesh, activations,
    levels=levels, vmin=vmin, vmax=vmax, alpha=1.0, antialiased=True, cmap=cmap)
ax.plot(x_vals, y_vals, linewidth=2, color='r')
# The quadratic models x as a function of y, so evaluate it on y_vals and plot (x(y), y).
ffit = poly.Polynomial(coeffs)
ax.plot(ffit(y_vals), y_vals, linewidth=2, linestyle='--', color='g')
ax.plot(x_fit, y_fit, 'k--', label='fit', lw=2)
ax.format(
    ylim = analysis_params['y_range'],
    xlim = analysis_params['x_range'],
    title = f'circ = {curvature:.3e}\npoly = {poly_curvature:.3e}')
plot.show()