In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
sys.path.insert(0, os.path.abspath(r'C:/Users/mmccann/repos/bonhoeffer/prey_capture/'))
sys.path.insert(0, os.path.abspath(r'C:/Users/mmccann/repos/mine_pub'))

import panel as pn
import holoviews as hv
from holoviews import opts, dim
hv.extension('bokeh')
from bokeh.resources import INLINE

import paths
import processing_parameters
import functions_loaders as fl
import functions_plotting as fp
import functions_kinematic as fk
from functions_tuning import normalize, calculate_dff

import numpy as np
import pandas as pd
import importlib
import h5py
import sklearn.preprocessing as preproc
import joblib as jb
from umap.umap_ import UMAP
from rastermap import Rastermap

from mine import Mine, MineData

In [None]:
def calculate_extra_angles(ds):
    """Add wrapped direction and orientation columns to ``ds`` (modified in place).

    Columns written:
      * ``direction_wrapped``: ``direction`` passed through ``fk.wrap`` (default bound).
      * ``direction_rel_ground``: ``direction_wrapped`` plus ``head_roll`` when that
        column exists, otherwise a plain copy of ``direction_wrapped``.
      * ``orientation`` / ``orientation_rel_ground``: the two columns above passed
        through ``fk.wrap(bound=180)``; only written when ``ds`` has no
        ``orientation`` column yet.

    Rows whose value is not greater than -1000 are left untouched by every step.

    :param ds: DataFrame with a ``direction`` column and optionally ``head_roll``
    :return: the same DataFrame object, with the extra columns
    """
    # rows that take part in the wrapping; everything else keeps its original value
    valid_rows = ds['direction'] > -1000

    # wrapped direction
    ds['direction_wrapped'] = ds['direction'].copy()
    ds.loc[valid_rows, 'direction_wrapped'] = ds.loc[valid_rows, 'direction_wrapped'].apply(fk.wrap)

    # direction relative to the ground plane: add the head roll when it is available
    ds['direction_rel_ground'] = ds['direction_wrapped'].copy()
    if 'head_roll' in ds.columns:
        ds.loc[valid_rows, 'direction_rel_ground'] = (
            ds.loc[valid_rows, 'direction_rel_ground'] + ds.loc[valid_rows, 'head_roll']
        )

    # an existing orientation column is kept as it is
    if 'orientation' in ds.columns:
        return ds

    # orientation: start from the direction columns ...
    for source_column, target_column in (('direction_wrapped', 'orientation'),
                                         ('direction_rel_ground', 'orientation_rel_ground')):
        ds[target_column] = ds[source_column].copy()
    # ... and wrap both with a bound of 180
    valid_rows = ds['orientation'] > -1000
    for target_column in ('orientation', 'orientation_rel_ground'):
        ds.loc[valid_rows, target_column] = ds.loc[valid_rows, target_column].apply(fk.wrap, bound=180)

    return ds

In [None]:
### figure and analysis configuration
# pick up edits made to the project modules since the kernel started
importlib.reload(fp)
importlib.reload(processing_parameters)
# set up the figure theme
fp.set_theme()

# analysis switches: calcium activity type ('spikes' or 'fluor') and whether to drop the ITI
ca_type, drop_ITI = 'fluor', False

# output switches: printing mode and target document
save_mode, target_document = True, 'paper'
# folder that receives the figures
save_path = os.path.join(paths.figures_path, 'MINE_vis')

# labels and predictors: free-behavior variables followed by the visual ones
# (alternative: processing_parameters.label_dictionary / processing_parameters.variable_list)
label_dict = processing_parameters.wf_label_dictionary
variable_list = processing_parameters.variable_list_free + processing_parameters.variable_list_visual

In [None]:
# Load the preprocessed data for every search query
importlib.reload(processing_parameters)
importlib.reload(fl)

# get the paths from the database using search_list
all_paths, all_queries = fl.query_search_list()
# mouse name of every file in the first query: fields 7-9 of the underscore-separated file name
mice = []
for file_path in all_paths[0]:
    name_fields = os.path.basename(file_path).split('_')
    mice.append('_'.join(name_fields[7:10]))
print(all_paths)

# load the data, keeping only the first output of the loader
data_list = []
for path, queries in zip(all_paths, all_queries):
    data, _, _  = fl.load_preprocessing(path, queries, latents_flag=False)
    data_list.append(data)

# per-trial clean-up of the first query
needs_extra_angles = 'direction_wrapped' in variable_list
for trial_idx, (trial, mouse) in enumerate(zip(data_list[0], mice)):
    trial.loc[:, 'mouse'] = mouse

    # keep only the activity columns of the selected calcium type
    wrong_type_cells = [el for el in trial.columns if ('cell' in el) and (ca_type not in el)]
    trial.drop(columns=wrong_type_cells, inplace=True)

    # if using fluorescence data, calculate dF/F
    if ca_type == 'fluor':
        trial = calculate_dff(trial, baseline_type='iti', inplace=True)

    # add orientation and direction/orientation relative to ground when requested and missing
    if needs_extra_angles and ('direction_wrapped' not in trial.columns):
        trial = calculate_extra_angles(trial)

    # drop the ITI (frames with trial_num == 0) and renumber the rows
    if drop_ITI:
        trial.drop(index=trial[trial['trial_num'] == 0].index, inplace=True)
        trial.reset_index(drop=True, inplace=True)

    data_list[0][trial_idx] = trial

In [None]:
# frame rate of the recording; the frame-count settings below are multiples of it
frame_rate = processing_parameters.wf_frame_rate
# define the MINE parameters
# these entries are handed to the Mine constructor by the fitting cell, so keep
# their names and their order in step with its signature
MINE_params = dict(
    train_fraction=2.0/3,
    model_history=frame_rate*10,
    corr_cut=0.2,
    compute_taylor=True,
    # complexity=True,
    return_jacobians=True,
    taylor_look_ahead=frame_rate*5,
    taylor_pred_every=frame_rate*5,
)

In [None]:
# Fit MINE for every mouse/date session and write one HDF5 result file per session.
# Requires data_list, variable_list and MINE_params from the cells above.

# directory that receives the fitted models (kept separate from save_path, which
# is the figure folder used by the plotting cells)
mine_output_dir = r'Z:\test_mine_wf\fluor'

# get the unique dates and mice
unique_dates_mice = np.unique([(el.loc[0, 'mouse'], el.loc[0, 'datetime'][:10]) for el in data_list[0]], axis=0)

# allocate lists to store the MINE outputs
mine_list = []
mouse_date_list = []
calcium_list = []

predictor_columns = variable_list

# define the columns to exclude
exclude_columns = ['mouse', 'datetime', 'motifs']

# for all the pairs
for idx, pair in enumerate(unique_dates_mice):

    print(f'Current mouse: {pair[0]}, current date: {pair[1]}, current index: {idx}')
    # get the relevant trials
    target_trials = [el for el in data_list[0] if (pair[0] in el.loc[0, 'mouse']) & (pair[1] in el.loc[0, 'datetime'])]

    # concatenate and drop every frame that has a NaN in any column
    all_trials = pd.concat(target_trials, axis=0)
    all_trials.dropna(inplace=True)

    # check that all predictors are present
    if not set(variable_list).issubset(all_trials.columns):
        print(f'Mouse {pair[0]} on day {pair[1]} does not have all predictors')
        continue

    # get the calcium and predictors (no NaN is left after the dropna above)
    cell_columns = [el for el in all_trials.columns if ('cell' in el)]
    calcium = all_trials[cell_columns].to_numpy()
    predictors = all_trials[predictor_columns].to_numpy()

    if calcium.shape[0] < 600:
        print(f'Mouse {pair[0]} and date {pair[1]} have less than 600 timepoints so skip')
        continue
    print(f'Timepoints: {calcium.shape[0]}, Cells: {calcium.shape[1]}, Predictors: {predictors.shape[1]}')

    # create the cell ids: (cell index, mouse, date), stored as strings
    cell_ids = np.array([[el, pair[0], pair[1]] for el in np.arange(calcium.shape[1])])

    # z-score with statistics taken from the training frames only, so that the
    # test frames do not leak into the scaling
    train_frames = int(MINE_params['train_fraction']*calcium.shape[0])
    calcium_scaler = preproc.StandardScaler().fit(calcium[:train_frames, :])
    calcium = calcium_scaler.transform(calcium)
    predictors_scaler = preproc.StandardScaler().fit(predictors[:train_frames, :])
    predictors = predictors_scaler.transform(predictors)

    # create the MINE element; keyword unpacking ties every setting to its
    # constructor argument by name instead of by position in the dict
    miner = Mine(**MINE_params)
    # run Mine (it takes predictors and responses with time along the second axis)
    mine_data = miner.analyze_data(predictors.T, calcium.T)

    # save the fit together with the cell ids
    mine_file = os.path.join(mine_output_dir, f'{pair[0]}_{pair[1]}_mine.hdf5')
    with h5py.File(mine_file, 'w') as f:
        mine_data.save_to_hdf5(f, overwrite=True)
        f.create_dataset('cell_ids', data=cell_ids.astype('S'))
    mine_list.append(mine_data)
    mouse_date_list.append(pair)

In [None]:

# # get the unique dates and mice
# unique_dates_mice = np.unique([(el.loc[0, 'mouse'], el.loc[0, 'datetime'][:10]) for el in data_list[0]], axis=0)

# # split the data into day/mouse packages
# trial_list = []
# # for all the pairs
# for pair in unique_dates_mice:
#     trials = [el for el in data_list[0] if (pair[0] in el.loc[0, 'mouse']) & (pair[1] in el.loc[0, 'datetime'])]
#     trial_list.append([trials, pair[0], pair[1]])
# # allocate a list to store the psid objects
# # mine_list = []
# # mouse_date_list = []

# predictor_columns = variable_list

# # define the columns to exclude
# exclude_columns = ['mouse', 'datetime', 'motifs']
# with jb.Parallel(n_jobs=-1) as parallel:
    
#     nn_outputs = parallel(jb.delayed(run_mine)(el[0], el[1], el[2]) for el in trial_list)
#         for idx, el in enumerate(nn_outputs):
#             performance_list.append([l1, l2, batch, el[0][1], idx, False])
    

In [None]:
# load the mine data
# define the base directory
base_path = r'Z:\test_mine_wf\fluor'
# define the list of fields, in the order in which they are passed to MineData
field_list = ['correlations_trained', 'correlations_test', 'taylor_scores', 'taylor_true_change', 'taylor_full_prediction', 'taylor_by_predictor', 'model_lin_approx_scores', 'me_scores', 'jacobians'] # 'nl_probs',

# get the result files in the directory; anything that is not an HDF5 file is
# ignored, and sorting keeps the session order the same on every run
files = sorted(el for el in os.listdir(base_path) if el.endswith('.hdf5'))
# allocate memory for the mine data
mine_list = []
df_list = []
mouse_date_list = []
# for all the files
for file in files:

    # get the mouse and day from the file name: <mouse (3 fields)>_<day>_mine.hdf5
    name = file.split('_')
    mouse = '_'.join(name[0:3])
    day = name[3]
    filepath = os.path.join(base_path, file)
    # load the file
    with h5py.File(filepath, mode='r') as f:
        current_mine = [np.array(f[el]) for el in field_list]

    # MineData is given one more positional argument than there are loaded fields;
    # nothing is read from the file for it (presumably the hessians - TODO confirm)
    current_mine.append(None)
    # create the mine object
    current_mine = MineData(*current_mine)

    # store on a list
    mine_list.append(current_mine)
    mouse_date_list.append((mouse, day))

# the predictors are the configured variable list
predictor_columns = variable_list

In [None]:
# names of the fields stored on the last MineData object that was created
vars(current_mine).keys()

In [None]:
# print([el.taylor_scores.shape for el in mine_list])
# print(files)

In [None]:
# Build one table row per cell: test correlation, Taylor scores, model scores, mouse, date

# number of predictors, and number of Taylor terms (predictors plus all unordered pairs)
predictor_number = len(predictor_columns)
scores_number = ((predictor_number**2) - predictor_number)/2 +predictor_number

# rows of the table
mine_cells = []

# for all date/mice
for current_mine, (mouse, date) in zip(mine_list, mouse_date_list):

    # for all the cells
    for cell in range(len(current_mine.correlations_trained)):

        # first entry of the last axis of the Taylor scores, padded with NaN up to the expected length
        taylor_scores = current_mine.taylor_scores[cell, :, 0].flatten()
        if taylor_scores.shape[0] < scores_number:
            padding = np.full(int(scores_number - taylor_scores.shape[0]), np.nan)
            taylor_scores = np.concatenate([taylor_scores, padding], axis=0)

        temp_list = [
            current_mine.correlations_test[cell],
            *taylor_scores,
            current_mine.model_lin_approx_scores[cell],
            current_mine.mean_exp_scores[cell],
            mouse,
            date,
        ]
        # mouse and date are strings, so every entry of the row is stored as a string
        mine_cells.append(np.array(temp_list))

# generate the column names: every predictor paired with each predictor that follows it
interaction_names = [
    f'{name1}_{name2}'
    for position, name1 in enumerate(predictor_columns)
    for name2 in predictor_columns[position + 1:]
]
taylor_columns = predictor_columns + interaction_names
columns = ['correlation_test', *taylor_columns, 'model_lin_approx_scores', 'mean_exp_score', 'mouse', 'date']
# turn the output into a dataframe
mine_cells = pd.DataFrame(mine_cells, columns=columns)
        

In [None]:
# Per-session fit statistics and the pooled model scores of the fit cells

# allocate the output
mine_stats = []
model_lin_approx_scores_distribution = []
mean_exp_score_distribution = []

for (mouse, date), data in mine_cells.groupby(['mouse', 'date']):
    total_cells = data.shape[0]
    # a cell counts as fit when its test correlation is a number above the MINE
    # correlation cutoff (stored under 'corr_cut' in MINE_params)
    test_correlation = data['correlation_test'].to_numpy().astype(float)
    selection_vector = ~np.isnan(test_correlation) & (test_correlation > MINE_params['corr_cut'])
    fit_cells = data.iloc[selection_vector, :]
    count_cells = fit_cells.shape[0]
    fraction_fit = count_cells/total_cells

    mine_stats.append([mouse, date, count_cells, fraction_fit])

    # the table stores its values as strings, hence the conversion to float
    model_lin_approx_scores_distribution.extend(fit_cells['model_lin_approx_scores'].to_numpy().astype(float))
    mean_exp_score_distribution.extend(fit_cells['mean_exp_score'].to_numpy().astype(float))

mine_stats = pd.DataFrame(mine_stats, columns=['mouse', 'date', 'count', 'fraction'])

# Number of fitted cells 

In [None]:
# Number of fit cells per session: how many sessions share each count
counts = mine_stats['count'].to_numpy().astype(int)
location, freq  = np.unique(counts, return_counts=True)
# one tick per integer below the largest observed count
ticks = [(int(el), el) for el in np.arange(0, location.max(), 1)]
plot = hv.Scatter((location, freq)).opts(
    width=600, xrotation=45, xlabel='Cells fit', ylabel='Sessions', xticks=ticks, size=5)

# save the figure as <target_document>_Cells_fit.png in the figure folder
save_name = os.path.join(save_path, f'{target_document}_Cells_fit.png')
fig = fp.save_figure(plot, save_name, fig_width=15, dpi=1200, fontsize=target_document, target='screen')

# Fraction of fitted cells 

In [None]:
# plot the distribution of the percentage of cells fit per session

freq, bin_edges = np.histogram(mine_stats['fraction'].to_numpy().astype(float)*100, bins=20)
# np.histogram returns one more edge than counts; plot every count at the centre
# of its bin so that x and y have the same length
location = (bin_edges[:-1] + bin_edges[1:])/2
ticks = [(int(el), int(el)) for el in np.arange(0, np.max(bin_edges), 1)]
plot = hv.Scatter((location, freq))
plot.opts(xrotation=45, xlabel='Percentage of cells fit', ylabel='Sessions', width=600, xticks=ticks, size=5)

# assemble the file name
save_name = os.path.join(save_path, '_'.join((target_document, 'Cells_perc')) + '.png')
# save the figure
fig = fp.save_figure(plot, save_name, fig_width=15, dpi=1200, fontsize=target_document, target='screen')

# Nonlinear prob vs second order 

In [None]:
# Scatter of the two model scores of every fit cell
# x: model_lin_approx_scores, y: mean_exp_score (both pooled over sessions)
plot = hv.Scatter((model_lin_approx_scores_distribution, mean_exp_score_distribution)).opts(
    xrotation=45,
    xlabel='R2 1st order model fit',
    ylabel='R2 2nd order model fit',
    width=400,
    xlim=(0,1),
    ylim=(0,1),
)

# save the figure as <target_document>_Nonlinear_scatter.png in the figure folder
save_name = os.path.join(save_path, f'{target_document}_Nonlinear_scatter.png')
fig = fp.save_figure(plot, save_name, fig_width=15, dpi=1200, fontsize='poster', target='screen')


# Linear Probability

In [None]:
# histogram of the linear-approximation scores of the fit cells

freq, bin_edges = np.histogram(model_lin_approx_scores_distribution)
# np.histogram returns one more edge than counts; label every bar with the
# centre of its bin so that x and y have the same length
location = (bin_edges[:-1] + bin_edges[1:])/2
plot = hv.Bars((location, freq))
plot.opts(xrotation=45, ylabel='Cells', xlabel='Linear probability', width=400, xformatter='%.2f')

# assemble the file name
save_name = os.path.join(save_path, '_'.join((target_document, 'Linear_prob')) + '.png')
# save the figure
fig = fp.save_figure(plot, save_name, fig_width=10, dpi=1200, fontsize='paper', target='screen')

# Nonlinear probability 

In [None]:
# histogram of the mean-expansion scores of the fit cells

freq, bin_edges = np.histogram(mean_exp_score_distribution)
# np.histogram returns one more edge than counts; label every bar with the
# centre of its bin so that x and y have the same length
location = (bin_edges[:-1] + bin_edges[1:])/2
plot = hv.Bars((location, freq))
plot.opts(xrotation=45, ylabel='Cells', xlabel='Nonlinear probability', width=400, xformatter='%.2f')

# assemble the file name
save_name = os.path.join(save_path, '_'.join((target_document, 'Nonlinear_prob')) + '.png')
# save the figure
fig = fp.save_figure(plot, save_name, fig_width=10, dpi=1200, fontsize='paper', target='screen')

# Percentage of nonlinear weights

In [None]:
# Percentage of cells with a non-zero weight for every pairwise interaction term

# interaction weights as floats, without the cells that have a NaN in any term;
# weights below 0.05 are set to zero
interaction_weights = mine_cells.loc[:, interaction_names].copy().astype(float).dropna()
interaction_weights = interaction_weights.mask(interaction_weights < 0.05, 0)

# fill the predictor x predictor matrix wherever the pair exists as an interaction column
known_interactions = set(interaction_names)
interaction_matrix = np.zeros((len(variable_list), len(variable_list)))
for idx0, feature0 in enumerate(variable_list):
    for idx1, feature1 in enumerate(variable_list):
        interaction_feature = f'{feature0}_{feature1}'
        if interaction_feature in known_interactions:
            nonzero_cells = (interaction_weights[interaction_feature] > 0).sum()
            interaction_matrix[idx0, idx1] = 100*nonzero_cells/len(interaction_weights)

# empty entries are shown as missing
interaction_matrix[interaction_matrix==0] = np.nan
ticks = [(idx+0.5, label_dict[el]) for idx, el in enumerate(variable_list)]
plot = hv.Raster(interaction_matrix).opts(
    width=900, height=600, tools=['hover'], cmap='Viridis', xticks=ticks, xrotation=45,
    yticks=ticks, xlabel='', ylabel='', colorbar=True)

# save the figure as <target_document>_Interaction_fraction.png in the figure folder
save_name = os.path.join(save_path, f'{target_document}_Interaction_fraction.png')
fig = fp.save_figure(plot, save_name, fig_width=10, dpi=1200, fontsize='paper', target='screen')

In [None]:
# number of cells with exactly one interaction weight above 0.05
print((interaction_weights.gt(0.05).sum(axis=1) == 1).sum())

In [None]:
# Collect the first-order Taylor scores (tunings) of all fit cells and scale them

# clipping percentile (not applied in this cell)
perc = 99

tunings = []
# get the tunings
for (mouse, date), data in mine_cells.groupby(['mouse', 'date']):
    # a cell counts as fit when its test correlation is a number above the MINE
    # correlation cutoff (stored under 'corr_cut' in MINE_params)
    test_correlation = data['correlation_test'].to_numpy().astype(float)
    selection_vector = ~np.isnan(test_correlation) & (test_correlation > MINE_params['corr_cut'])
    if selection_vector.sum() == 0:
        continue
    current_tunings = data[predictor_columns].iloc[selection_vector, :].to_numpy().astype(float)
    # scores below 0.05 are set to zero
    current_tunings[current_tunings<0.05] = 0

    tunings.extend(current_tunings)
tunings = np.array(tunings)
tunings[np.isnan(tunings)] = 0
# keep the unscaled values for plotting; the scaled version is used for the embedding
raw_tunings = tunings.copy()
tunings = preproc.StandardScaler().fit_transform(tunings)

# Tunings 

In [None]:
# (fit cells, predictors)
print(np.shape(raw_tunings))

In [None]:
# Raster of the unscaled tunings, with the cells ordered by Rastermap

ticks = [(idx+0.5, label_dict[el]) for idx, el in enumerate(predictor_columns)]
# copy of the raw tunings with everything below 0.05 set to zero
plot_matrix = np.where(raw_tunings < 0.05, 0, raw_tunings)
# sort the cells (rows) with Rastermap
model = Rastermap(n_clusters =2, n_PCs=200)
model.fit(plot_matrix)
plot_matrix = plot_matrix[model.isort, :]
plot = hv.Raster(plot_matrix).opts(
    width=1000, height=600, cmap='RdBu_r', tools=['hover'], clim=(-1, 1), xticks=ticks,
    xrotation=45, xlabel='', ylabel='Cells', colorbar=True)

# save the figure as <target_document>_MINE_tunings.png in the figure folder
save_name = os.path.join(save_path, f'{target_document}_MINE_tunings.png')
fig = fp.save_figure(plot, save_name, fig_width=10, dpi=1200, fontsize='paper', target='screen')

# UMAP 

In [None]:
# perform umap on the fit cell tuning
# a fixed random_state makes the embedding (and the figure built from it) the same on every run
reducer1 = UMAP(min_dist=0.1, n_neighbors=20, random_state=0)
embedded_data1 = reducer1.fit_transform(tunings)

In [None]:
# Colour values for the UMAP plot: the scaled tuning to one predictor, clipped at both ends
target_field = 'head_yaw'
perc = 99

# column(s) of the tuning matrix that belong to the target predictor
label_idx = [idx for idx, el in enumerate(predictor_columns) if target_field == el]
# indexing with a list copies the data, so the clipping below leaves tunings untouched
raw_labels = tunings[:, label_idx]

# clip the top, then the bottom (the lower bound is computed after the top was clipped)
upper_bound = np.percentile(raw_labels, perc)
raw_labels[raw_labels > upper_bound] = upper_bound
lower_bound = np.percentile(raw_labels, 100-perc)
raw_labels[raw_labels < lower_bound] = lower_bound

# embedding coordinates followed by the colour value
plot_data = np.hstack([embedded_data1, raw_labels.reshape(-1, 1)])

In [None]:
# UMAP embedding of the fit cells, coloured by the clipped tuning to the target predictor
umap_plot = hv.Scatter(plot_data, kdims=['Dim 1'], vdims=['Dim 2', 'Parameter'])
umap_plot.opts(
    colorbar=True, color='Parameter', cmap='Spectral_r', tools=['hover'], alpha=1,
    width=1200, height=1000, size=5)

# save the figure as <target_document>_MINE_UMAP.png in the figure folder
save_name = os.path.join(save_path, f'{target_document}_MINE_UMAP.png')
fig = fp.save_figure(umap_plot, save_name, fig_width=15, dpi=1200, fontsize=target_document, target='screen')

In [None]:
# plot the importances per neuron

# use the first stored session as the example
mine_data = mine_list[0]

# report the sizes of the session that is actually plotted (one test correlation per cell)
print(f'Number of cells: {len(mine_data.correlations_test)}, Number of predictors: {len(predictor_columns)}')
# list every field of the MINE result, with its shape when it has one
for el in mine_data.__dict__.keys():
    try:
        print(el, getattr(mine_data, el).shape)
    except AttributeError:
        print(el)

# taylor scores is predictors + n choose 2 predictors, i.e. the interaction terms
# first dimension of taylor_true_change, taylor_full_prediction and taylor_by_predictor is the fitted cells
# the second dimension in the Taylor metrics comes from (total time - past - future_prediction)/frame_step , see TaylorDecomp

# second dimension in jacobians is predictors x timepoints set for calculation

ticks = [(idx+0.5, el) for idx, el in enumerate(predictor_columns)]

# first-order Taylor scores only: cells x predictors
plot = hv.Raster(np.squeeze(mine_data.taylor_scores[:, :len(predictor_columns), 0]))
plot.opts(width=1200, height=800, xticks=ticks, xrotation=90, ylabel='Cells', xlabel='', tools=['hover'])


# Single cell interactions 

In [None]:
# Taylor scores of one cell as a predictor x predictor matrix:
# first-order terms on the diagonal, interaction terms above it

predictor_idx = list(range(len(variable_list)))
# example session and cell
mine_data = mine_list[-1]
target_cell = 0
print(mine_data.taylor_scores.shape)

# split the scores of the cell into first-order and interaction terms
n_predictors = len(predictor_columns)
first_order = mine_data.taylor_scores[target_cell, :n_predictors, 0]
current_taylor = mine_data.taylor_scores[target_cell, n_predictors:, 0]

# interactions fill the upper triangle row by row, first-order terms the diagonal
plot_matrix = np.zeros((n_predictors, n_predictors))
plot_matrix[np.triu_indices(n_predictors, k=1)] = current_taylor
plot_matrix[np.diag_indices(n_predictors)] = first_order
# zeros (including the unused lower triangle) are shown as missing
plot_matrix[plot_matrix==0] = np.nan

# reorder rows and columns to follow variable_list
plot_matrix = plot_matrix[np.ix_(predictor_idx, predictor_idx)]
yticks = [(idx+0.5, label_dict[el]) for idx, el in enumerate(variable_list)]
# colour range symmetric around zero
largest_score = np.nanmax(plot_matrix)
clim = (-largest_score, largest_score)
plot = hv.Raster(plot_matrix)
plot.opts(
    width=900, height=600, cmap='Spectral_r', yticks=yticks, ylabel='', xticks=yticks,
    xlabel='', xrotation=45, tools=['hover'], colorbar=True, clim=clim)

# save the figure as <target_document>_Jacobian_example.png in the figure folder
save_name = os.path.join(save_path, f'{target_document}_Jacobian_example.png')
fig = fp.save_figure(plot, save_name, fig_width=10, dpi=1200, fontsize='paper', target='screen')


In [None]:
# plot the jacobian of a particular cell, one vertically shifted trace per predictor
# (uses mine_data as selected in an earlier cell)

# define the target cell
target_cell = 0
# define the shift factor
shift_factor = 0.1
# number of frames of model history (stored under 'model_history' in MINE_params);
# cast to int because it is used as an array dimension
number_window = int(MINE_params['model_history'])

# get the cells data
current_cell = mine_data.jacobians[target_cell, :]
# unfold the flat jacobian into predictors x history frames
current_cell = current_cell.reshape([1, number_window, predictor_number], order='F').transpose([2, 1, 0]).reshape([predictor_number, -1], order='F')
# allocate the plot list
plot_list = []
tick_list = []

# only the predictors that are part of variable_list are drawn
predictor_idx = [idx for idx, el in enumerate(predictor_columns) if el in variable_list]
counter = 0

# label the time axis every 10 frames, counted backwards from the end of the window
xticks = [(number_window-int(el), str(el)) for el in np.arange(0, number_window, 10, dtype=int)]
print(xticks)
# for all the predictors
for idx, predictor in enumerate(current_cell):
    if idx not in predictor_idx:
        continue
    plot = hv.Curve(predictor + counter*shift_factor)
    plot.opts(width=500, color='red')
    plot_list.append(plot)
    # label the trace with the predictor it shows (looked up by its own position,
    # not by the running counter)
    tick_list.append(((counter)*shift_factor, label_dict[predictor_columns[idx]]))
    counter += 1

overlay = hv.Overlay(plot_list)
overlay.opts(height=800, yticks=tick_list, xrotation=45, xlabel='Time', ylabel='', xticks=xticks)

# assemble the file name
save_name = os.path.join(save_path, '_'.join((target_document, 'Trace_example')) + '.png')
# save the figure
fig = fp.save_figure(overlay, save_name, fig_width=10, dpi=1200, fontsize=target_document, target='screen')

# Single Cell fits

In [None]:
# Collect, for every well-fit cell, its calcium trace and the MINE Taylor prediction
ca_traces = []
mine_pred = []

for current_mine, (mouse, date) in zip(mine_list, mouse_date_list):
    # rebuild the session data exactly as it was assembled for the fit
    target_trials = [el for el in data_list[0] if (mouse in el.loc[0, 'mouse']) & (date in el.loc[0, 'datetime'])]
    all_trials = pd.concat(target_trials, axis=0)
    all_trials.dropna(inplace=True)

    # cells with a test correlation above the MINE cutoff (stored under 'corr_cut'
    # in MINE_params) and with both model scores not below -1
    test_correlation = np.array(current_mine.correlations_test).astype(float)
    lin_approx_scores = np.array(current_mine.model_lin_approx_scores).astype(float)
    mean_exp_scores = np.array(current_mine.mean_exp_scores).astype(float)
    selection_vector = ~np.isnan(test_correlation) & \
                        (test_correlation > MINE_params['corr_cut']) & \
                        (lin_approx_scores >= -1) & \
                        (mean_exp_scores >= -1)

    fit_cells = np.argwhere(selection_vector).flatten()
    # data columns that belong to the selected cells, in column order
    fit_tags = ['cell_{:04d}'.format(cell) for cell in fit_cells]
    fit_cols = [el for el in all_trials.columns for tag in fit_tags if (tag in el)]

    # NOTE(review): the prediction is looked up by the running position among the
    # selected cells; confirm that this matches the first axis of
    # taylor_full_prediction when the two score conditions remove a cell
    for idx, (cell, col) in enumerate(zip(fit_cells, fit_cols)):
        ca_traces.append(all_trials[col].to_numpy().flatten())
        mine_pred.append(current_mine.taylor_full_prediction[idx, :])

In [None]:
# overlay the MINE prediction and the calcium trace of one fit cell

# define the target cell
target_cell = 0

# get the cells data
current_cell_pred = mine_pred[target_cell]
# sample the calcium at the frames that have a prediction: start after the model
# history ('model_history') and step by the prediction interval ('taylor_pred_every');
# cast to int because slice bounds must be integers
history_frames = int(MINE_params['model_history'])
prediction_step = int(MINE_params['taylor_pred_every'])
current_cell_calcium = ca_traces[target_cell][history_frames::prediction_step]

current_cell_pred_norm = normalize(current_cell_pred)
current_cell_pred_norm -= np.mean(current_cell_pred_norm)
current_cell_calcium_norm = normalize(current_cell_calcium)

# the curves get their own names so that the calcium array of the fitting cell is not overwritten
pred_curve = hv.Curve((np.arange(len(current_cell_pred_norm)), current_cell_pred_norm)).opts(alpha=0.75)
calcium_curve = hv.Curve((np.arange(len(current_cell_calcium_norm)), current_cell_calcium_norm)).opts(alpha=0.75)
overlay = pred_curve * calcium_curve
overlay.opts(width=600, height=400)

In [None]:
# table of the well-fit cells, best test correlation first
# selection: test correlation above the MINE cutoff (stored under 'corr_cut' in
# MINE_params) and both model scores not below -1
test_correlation = np.array(mine_cells.correlation_test).astype(float)
a = mine_cells.loc[~np.isnan(test_correlation) & \
                    (test_correlation > MINE_params['corr_cut']) & \
                    (np.array(mine_cells.model_lin_approx_scores).astype(float) >= -1) & \
                    (np.array(mine_cells.mean_exp_score).astype(float) >= -1), :].reset_index()
a = a.sort_values('correlation_test', ascending=False)[['correlation_test', 'model_lin_approx_scores', 'mean_exp_score']]
a