# General imports/paths

## Import

In [1]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import h5py
from scipy.sparse import csr_matrix
import pickle
from shapely.geometry import Point
from shapely.geometry import LineString
import seaborn as sns

## Paths

In [48]:
data_dir = 'C:/Users/cortiz/Projects/inscopix_project_cajal/data'

# Recording days, each given relative to data_dir. Their order is the session
# (row) order used when indexing the registration matrix, i.e. mat_reg[i, ...]
# belongs to list_animals[i].
list_animals = [
    '20211204/mouse_01_CA1',
    '20211206/mouse_01_CA1',
    '20211207/mouse_01_CA1',
]

path_registration_file = os.path.join(data_dir, 'cellRegistered_M1_CA1.mat')

## Parameters

In [3]:
# Number of spatial bins per maze arm; it is also the number of colour levels.
n_bins = 5
# Marker area passed as `s` to plt.scatter when drawing the arm bins.
marker_size = 400
# n_bins colours from the reversed 'Spectral' palette, one RGB row per level.
color_palette = np.array(sns.color_palette('Spectral_r', n_colors=n_bins))

# Loading registration

In [4]:
# The .mat registration file is read through h5py, i.e. it is stored in the
# HDF5-based MATLAB format. Further down, rows of mat_reg are used as sessions
# and columns as registered cells (mat_reg[i, cell_ID]).
with h5py.File(path_registration_file, 'r') as reg_file:
    reg_struct = reg_file['cell_registered_struct']
    mat_reg = np.array(reg_struct['cell_to_index_map'])

# Functions

In [88]:
def loadAll(animal):
    """Load one recording day's pickled data into module-level globals.

    Sets the globals ``cur_dir``, ``path_table``, ``path_mat_spikes``,
    ``path_pickle``, ``meta_spikes``, ``df_dlc``, ``mat_spikes``,
    ``df_frames``, ``dict_meta`` and ``reference_points``. It then left-joins
    the ``closest_arm``, ``projection`` and ``environment`` columns of
    ``df_dlc`` onto ``meta_spikes``, matching ``meta_spikes['behavioral_frame']``
    against the index of ``df_dlc``.

    Parameters
    ----------
    animal : str
        Path of the recording day relative to ``data_dir``
        (e.g. ``'20211204/mouse_01_CA1'``).
    """
    global meta_spikes, df_dlc, mat_spikes, df_frames, dict_meta, reference_points, path_pickle, path_table, path_mat_spikes, cur_dir

    cur_dir = os.path.join(data_dir, animal)
    path_table = os.path.join(cur_dir, 'behavior', 'table_cut.csv')
    path_mat_spikes = os.path.join(cur_dir, 'spikes', 'finalSpikesMat.mat')
    path_pickle = os.path.join(cur_dir, 'pickle')

    def _unpickle(file_name):
        # NOTE: pickle can execute arbitrary code; only load trusted local files.
        with open(os.path.join(path_pickle, file_name), 'rb') as handle:
            return pickle.load(handle)

    meta_spikes = _unpickle('meta_spikes.pck')
    df_dlc = _unpickle('df_dlc_projections.pck')
    mat_spikes = _unpickle('mat_spikes.pck')
    df_frames = _unpickle('df_frames.pck')
    dict_meta = _unpickle('dict_meta.pck')
    reference_points = _unpickle('reference_points.pck')

    # Attach the arm / projection / environment of the behavioural frame that
    # each row of meta_spikes points to.
    meta_spikes = meta_spikes.join(
        df_dlc[['closest_arm', 'projection', 'environment']],
        on='behavioral_frame',
    )
    
    
def plotArm(dict_arm):
    """Draw one environment's maze as colour-coded spatial bins.

    Each arm ('middle', 'left', 'right') is drawn as ``n_bins`` evenly spaced
    points running from the centre ``(x_c, y_c)`` to the arm's end point, both
    taken from ``reference_points[dict_arm['mode']]``.

    - Points 1..n_bins-1 of an arm are coloured by ``dict_arm[arm][1:]``.
    - The shared centre point is coloured by the mean of the arms' first bins.
    - Colours come from ``color_palette`` with equal-width levels spanning 0 to
      the largest finite value in ``dict_arm``.
    - Non-finite values, e.g. the NaN produced by 0/0 for bins without samples,
      are drawn in light grey and are ignored for the scale and the centre mean.

    Parameters
    ----------
    dict_arm : dict
        ``{'mode': env, 'middle': values, 'left': values, 'right': values}``
        with ``n_bins`` values per arm, as returned by ``getEventRate``.

    Draws on the current matplotlib axes and returns None.
    """
    c = reference_points[dict_arm['mode']]
    arm_suffix = {'middle': 'm', 'left': 'l', 'right': 'r'}
    missing_grey = 0.85  # value for every colour channel of undefined bins

    # Scale runs from 0 to the largest finite value. The appended 0 keeps
    # max_val >= 0 even if every bin is NaN.
    all_values = np.concatenate(
        [np.ravel(np.asarray(d, dtype=float)) for k, d in dict_arm.items() if k != 'mode']
        + [np.zeros(1)]
    )
    max_val = all_values[np.isfinite(all_values)].max()
    # The tiny epsilon puts max_val itself inside the last right-open bin.
    bins_colorpalette = np.linspace(0, max_val + 0.00000001, num=n_bins + 1)
    palette = np.asarray(color_palette, dtype=float)

    def _colors_for(values):
        # Map values to palette rows. np.digitize is 1-based for in-range
        # values, hence the -1, shared by the arm points and the centre point.
        values = np.atleast_1d(np.asarray(values, dtype=float))
        codes = np.clip(np.digitize(values, bins_colorpalette) - 1, 0, len(palette) - 1)
        colors = palette[codes]
        colors[~np.isfinite(values)] = missing_grey
        return colors

    center_samples = []
    for arm in ['middle', 'left', 'right']:
        suffix = arm_suffix[arm]
        x_vect = np.linspace(c['x_c'], c['x_' + suffix], num=n_bins)
        y_vect = np.linspace(c['y_c'], c['y_' + suffix], num=n_bins)
        arm_values = dict_arm[arm]
        center_samples.append(arm_values[0])
        # Point 0 of every arm is the shared centre, drawn once below.
        plt.scatter(x_vect[1:], y_vect[1:], c=_colors_for(arm_values[1:]), s=marker_size)

    finite_center = [v for v in center_samples if np.isfinite(v)]
    center_value = np.mean(finite_center) if finite_center else np.nan
    plt.scatter(c['x_c'], c['y_c'], color=_colors_for(center_value)[0], s=marker_size)
    plt.gca().set_aspect('equal')

    plt.axis('off')
    plt.gca().invert_yaxis()
    
    
def getEventRate(env, cell = None, norm_by_time = True):
    """Bin cell activity along each maze arm for one environment.

    For each arm ('middle', 'left', 'right'):

    - Select the rows of ``meta_spikes`` with ``environment == env`` and
      ``closest_arm == arm``.
    - Assign each row to a spatial bin of ``hist_edges`` by its
      ``projection``.
    - For each bin, sum the number of selected ``mat_spikes`` entries (row x
      cell) that are > 0.

    Rows whose projection lies outside ``[hist_edges[0], hist_edges[-1])``
    (or is NaN) are ignored. They are not wrapped into the last bin.

    Parameters
    ----------
    env : str
        Environment to select, e.g. ``'t_maze'`` or ``'y_maze'``.
    cell : int, sequence of int or None
        Column(s) of ``mat_spikes`` to count. ``None`` (default) counts every
        column.
    norm_by_time : bool
        If True (default), divide each bin by the number of rows that fell into
        it. Bins with no rows are NaN.

    Returns
    -------
    dict
        ``{'mode': env, 'middle': arr, 'left': arr, 'right': arr}``. Each
        ``arr`` is a float array of length ``len(hist_edges) - 1``.
    """
    n_hist_bins = len(hist_edges) - 1
    columns = slice(None) if cell is None else cell
    sub_spikes = meta_spikes.query('environment == @env')
    dict_arm = {'mode': env}

    for arm in ['middle', 'left', 'right']:

        sub_arm = sub_spikes.query('closest_arm == @arm')
        bin_idx = np.digitize(sub_arm['projection'].to_numpy(), hist_edges) - 1
        # Both the activity sums and the sample counts use this same
        # in-range assignment, so they always agree bin by bin.
        in_range = (bin_idx >= 0) & (bin_idx < n_hist_bins)
        bin_idx = bin_idx[in_range]
        # NOTE(review): meta_spikes index labels are used as row positions in
        # mat_spikes (as in the original per-row loop) -- confirm they stay aligned.
        rows = np.asarray(sub_arm.index, dtype=np.intp)[in_range]

        active = mat_spikes[rows][:, columns] > 0
        if active.ndim > 1:  # several cells: count active cells per row
            active = active.sum(axis=1)
        active = np.asarray(active, dtype=float).ravel()

        summed_activity_binned = np.bincount(bin_idx, weights=active, minlength=n_hist_bins)
        if norm_by_time:
            sample_counts = np.bincount(bin_idx, minlength=n_hist_bins)
            summed_activity_binned = np.divide(
                summed_activity_binned,
                sample_counts,
                out=np.full(n_hist_bins, np.nan),
                where=sample_counts > 0,
            )
        dict_arm[arm] = summed_activity_binned

    return dict_arm

# Finding most active cells

# Heatmap analysis

In [6]:
# n_bins equal-width bins over the arm projection (presumably normalised to
# [0, 1]). The upper edge sits slightly above 1.0 because np.digitize uses
# right-open bins, so a projection of exactly 1.0 still lands in the last bin.
hist_edges = np.linspace(start=0, stop=1.0001, num=n_bins + 1)

## Overview plots

In [35]:
# Rank the registered cells (columns of mat_reg) by how many rows of
# mat_spikes (presumably imaging frames) they are active (> 0) in on the first
# recording day. Cells absent on that day keep -1 and therefore sort last.
loadAll(list_animals[0])
first_day_ids = mat_reg[0]
num_events = np.full(mat_reg.shape[1], -1.0)

# 0 means "not found in this session"; other entries are 1-based column
# indices into that session's mat_spikes.
registered = first_day_ids != 0
first_day_cols = (first_day_ids[registered] - 1).astype(int)
active_rows = (mat_spikes[:, first_day_cols] > 0).sum(axis=0)
num_events[registered] = np.asarray(active_rows).ravel()

sorted_cells = np.argsort(-num_events)

In [90]:
# For each environment, save one figure per top-ranked cell (ranked on day 1).
# Top row: per-bin event rate drawn on the maze, one panel per session.
# Bottom row: the trajectory, with the frames where the cell is active in red.
n_top_cells = 50
n_sessions = len(list_animals)

for env in ['t_maze', 'y_maze']:

    path_plots = os.path.join(data_dir, '..', 'plots', 'cell_across_days', 'CA1', env)
    os.makedirs(path_plots, exist_ok=True)

    # Files are numbered by rank (cursor), not by cell ID; the ID is in the title.
    for cursor, cell_ID in enumerate(sorted_cells[:n_top_cells]):

        plt.figure(figsize = (14,8))
        plt.suptitle('Cell: ' + str(cell_ID))

        for i, animal in enumerate(list_animals):
            plt.subplot(2, n_sessions, i+1)
            registered_id = mat_reg[i, cell_ID]
            if registered_id == 0:
                # The cell was not found in this session. int(0 - 1) == -1
                # would silently plot the LAST cell, so leave the panels empty.
                plt.title('Session ' + str(i+1) + ' (not registered)')
                plt.axis('off')
                continue
            cell_daily_mat = int(registered_id - 1)  # 1-based -> 0-based column
            # NOTE(review): this reloads every pickle for every cell; caching
            # the per-session data would make this loop much faster.
            loadAll(animal)
            plotArm(getEventRate(env, cell_daily_mat))
            plt.title('Session ' + str(i+1))
            # Pad the axis limits by 15% on each side (works for the inverted y axis too).
            a, b = plt.xlim()
            plt.xlim(a - (b-a)*0.15, b + (b-a)*0.15)
            a, b = plt.ylim()
            plt.ylim(a - (b-a)*0.15, b + (b-a)*0.15)

            plt.subplot(2, n_sessions, i+1+n_sessions)
            sub_meta = meta_spikes.query('environment == @env')
            # Optional per-session (x, y) shift stored in dict_meta['offset'].
            offsets = dict_meta['offset'] if 'offset' in dict_meta else {}
            for sess in pd.unique(sub_meta['session']):
                if sess in offsets:
                    xoff = offsets[sess]['x']
                    yoff = offsets[sess]['y']
                else:
                    xoff = 0
                    yoff = 0
                cur_meta = sub_meta.query('session == @sess')
                plt.plot(cur_meta['x']+xoff, cur_meta['y']+yoff, c = 'darkgray')

                spikes = np.where(mat_spikes[cur_meta.index, cell_daily_mat] > 0)[0]
                plt.scatter(cur_meta.iloc[spikes]['x']+xoff, cur_meta.iloc[spikes]['y']+yoff, zorder = 10, c = 'r')

            plt.axis('off')
            plt.gca().invert_yaxis()
            plt.gca().set_aspect('equal')
        plt.savefig(os.path.join(path_plots, 'cell_' + str(cursor).zfill(4) + '.png'))
        plt.close()