In [None]:
import os
import sys
sys.path.insert(0, os.path.abspath(r'D:\Code Repos\prey_capture'))
sys.path.insert(0, os.path.abspath(r'D:\Code Repos\prey_capture\mine_pub'))

import panel as pn
import holoviews as hv
from holoviews import opts, dim
hv.extension('bokeh')
from bokeh.resources import INLINE

import paths
import importlib
import functions_plotting as fp
from functions_plotting import format_figure as ff
import functions_bondjango as bd
import functions_loaders as fl
import processing_parameters
import numpy as np
import pandas as pd
import h5py
import scipy.stats as stat
import scipy.signal as signal
import datetime
import umap
import sklearn.decomposition as decomp
import sklearn.preprocessing as preproc
import sklearn.cross_decomposition as xdecomp
import sklearn.cluster as cluster

from mine_pub.mine import Mine

import PSID
from PSID.evaluation import evalPrediction

In [None]:
# figure configuration and analysis labels

# refresh the project modules in case they were edited since import
importlib.reload(fp)
importlib.reload(processing_parameters)
# apply the plotting theme from the (reloaded) plotting helpers
fp.set_theme()

# output settings: destination folder for figures, printing-mode flag and target document
save_path = os.path.join(paths.figures_path, 'Low_dim')
save_mode = True
target_document = 'paper'

# display labels and the behavioral variables analyzed below
label_dict = processing_parameters.label_dictionary
variable_list = processing_parameters.variable_list

In [None]:
# load the preprocessed data for every query in the search list
importlib.reload(processing_parameters)
importlib.reload(fl)

# get the paths from the database using search_list
all_paths, all_queries = fl.query_search_list()

# load the data; later cells iterate each entry (data_list[0]) as a list of per-trial dataframes
data_list = []
for path, queries in zip(all_paths, all_queries):
    data, _, _ = fl.load_preprocessing(path, queries)
    data_list.append(data)

# len(data_list) counts search queries, not trials; the trials are the elements of each entry
print(f'Number of datasets: {len(data_list)}')
print(f'Number of trials: {sum(len(el) for el in data_list)}')


In [None]:
def cvpca_variance(train_trajectories, test_trajectories):
    """Cross-validated covariance of each component between two projections.

    For every component (column) this computes the sample covariance, over time
    (rows), between that component's time course in the train projection and in
    the test projection. Row t of both arrays is treated as a pair, so the two
    projections must be sampled at matching time points.

    Parameters
    ----------
    train_trajectories : array-like, shape (time_points, components)
        Projection of the first data set onto the components.
    test_trajectories : array-like, shape (time_points, components)
        Projection of the second data set onto the same components; must have
        exactly the same shape as ``train_trajectories``.

    Returns
    -------
    numpy.ndarray, shape (components, 1)
        Covariance per component, normalized by ``time_points - 1``.

    Raises
    ------
    ValueError
        If the inputs are not 2-D, differ in shape, or have fewer than two
        time points.
    """
    train = np.asarray(train_trajectories, dtype=float)
    test = np.asarray(test_trajectories, dtype=float)
    if train.ndim != 2 or test.ndim != 2:
        raise ValueError('train and test trajectories must be 2-D (time_points, components)')
    if train.shape != test.shape:
        raise ValueError(
            f'train and test trajectories must have the same shape, got {train.shape} and {test.shape}')
    time_points = test.shape[0]
    if time_points < 2:
        raise ValueError('at least two time points are needed to compute a covariance')

    # center each component (column) over time, then take the sample covariance per column
    test_centered = test - test.mean(axis=0)
    train_centered = train - train.mean(axis=0)
    covariances = np.sum(test_centered * train_centered, axis=0) / (time_points - 1)
    # keep the (components, 1) column shape callers index with [:, 0]
    return covariances[:, np.newaxis]

In [None]:
# Run cvPCA on the data for one session (day + mouse)

# define a day and mouse
# other sessions examined: 2021-04-02/DG_210202_a, 2020-08-19/DG_200701_a, 2020-08-06/DG_200617_b
target_day = '2021-05-03'
target_mouse = 'DG_210202_a'
# whether to z-score the calcium before PCA; applied to train and test alike so both are
# projected consistently (the full-session projection of the sorted cvPCs uses raw calcium)
zscore_calcium = False

# get the relevant trials
target_trials = [el for el in data_list[0] if (target_day in el.loc[0, 'datetime']) & (target_mouse in el.loc[0, 'mouse'])]
if len(target_trials) < 2:
    raise ValueError(f'Need at least 2 trials for {target_mouse} on {target_day}, found {len(target_trials)}')

# alternate trials between the train (even positions) and test (odd positions) sets
train_trials = pd.concat(target_trials[::2], axis=0)
test_trials = pd.concat(target_trials[1::2], axis=0)

# get the calcium
cells = [el for el in train_trials.columns if 'cell' in el]
not_cells = [el for el in train_trials.columns if 'cell' not in el]
train_calcium = train_trials.loc[:, cells].to_numpy()
test_calcium = test_trials.loc[:, cells].to_numpy()

# generate the scaler for zscoring from the train data only
scaler = preproc.StandardScaler().fit(train_calcium)
if zscore_calcium:
    train_calcium = scaler.transform(train_calcium)
    test_calcium = scaler.transform(test_calcium)

# fit the PCA on the train data and project both sets onto the same components
pca = decomp.PCA(whiten=True)
train_components = pca.fit_transform(train_calcium)
test_components = pca.transform(test_calcium)

# cvPCA pairs row t of the train projection with row t of the test projection.
# NOTE(review): that pairing is only meaningful for time-aligned repeats; here the sets are
# different trials, so both are cut to their common length just to make it well-defined.
# Revisit the train/test split before interpreting these values.
common_length = min(train_components.shape[0], test_components.shape[0])
covariances = cvpca_variance(train_components[:common_length], test_components[:common_length])

print(covariances.shape)
# TODO: determine the explained variance from the covariances and store it


    
    

In [None]:
# scatter the per-component cross-validated covariances in ascending order
hv.extension('bokeh')
ascending_covariances = np.sort(covariances[:, 0])
x = np.arange(covariances.shape[0])
hv.Scatter((x, ascending_covariances))

In [None]:
# project every trial of the session onto the PCA components, ordered by cross-validated covariance

# component indices in ascending covariance order, then reversed to descending
idx_sort = np.argsort(covariances, axis=0).flatten()
descending_order = idx_sort[::-1]

# stack the session's trials and project the (unscaled) calcium
full_data = pd.concat(target_trials, axis=0)
full_calcium = full_data.loc[:, cells].to_numpy()
full_components = pca.transform(full_calcium)

# reorder the component columns from largest to smallest covariance
sorted_comps = full_components[:, descending_order]


In [None]:
# plot the calcium activity as a raster, min-max normalizing each cell (column) to [0, 1]
hv.extension('bokeh')
plot_calcium = full_calcium.copy()
calcium_min = plot_calcium.min(axis=0)
calcium_range = plot_calcium.max(axis=0) - calcium_min
# a constant cell would otherwise divide by zero (NaN/inf); leave it at 0 instead
calcium_range[calcium_range == 0] = 1
normalized_calcium = (plot_calcium - calcium_min) / calcium_range

raster = hv.Raster(normalized_calcium.T)
raster.opts(width=1000, height=800, cmap='Viridis', tools=['hover'])

raster

In [None]:
# plot the first three sorted cvPCs as a raster, min-max normalizing each to [0, 1]
hv.extension('bokeh')

# only the first three components are shown, so normalize just those
top_comps = sorted_comps[:, :3]
comps_min = top_comps.min(axis=0)
comps_range = top_comps.max(axis=0) - comps_min
# a constant component would otherwise divide by zero (NaN/inf); leave it at 0 instead
comps_range[comps_range == 0] = 1
normalized_comps = (top_comps - comps_min) / comps_range
raster = hv.Raster(normalized_comps.T)
raster.opts(width=1000, height=800, cmap='Viridis')

raster

In [None]:
# overlay the normalized sorted cvPCs as line plots, one curve per component
hv.extension('bokeh')
x = np.arange(normalized_comps.shape[0])
plot_list = [hv.Curve((x, component_trace)).opts(width=1000) for component_trace in normalized_comps.T]

hv.Overlay(plot_list)
    

In [None]:
# plot the actual trajectories: the first three normalized cvPCs in 3D, colored by a behavioral variable
hv.extension('plotly')

target_parameter = 'cricket_0_mouse_distance'

# define the target dimensions
tar_dim = np.array([0, 1, 2])

trial = normalized_comps.copy()

# to_numpy() may return a view of full_data, so copy before filling NaNs in place
parameter = full_data[target_parameter].to_numpy(dtype=float, copy=True)
parameter[np.isnan(parameter)] = 0
# min-max normalize; a constant parameter maps to 0 instead of dividing by zero
parameter_min = parameter.min()
parameter_range = parameter.max() - parameter_min
parameter = (parameter - parameter_min) / (parameter_range if parameter_range > 0 else 1)

plot = hv.Scatter3D((trial[:, tar_dim[0]], trial[:, tar_dim[1]], trial[:, tar_dim[2]]))
plot.opts(height=800, width=800, color=parameter, colorbar=True)
plot


In [None]:
%%time
# run PSID on the data

# define the test/train percentage
test_perc = 0.3

# get the unique dates and mice
unique_dates_mice = np.unique([(el.loc[0, 'datetime'][:10], el.loc[0, 'mouse']) for el in data_list[0]], axis=0)

# allocate a list to store the psid objects
psid_list = []

# for all the pairs
for pair in unique_dates_mice:
    # get the relevant trials
    # target_trials = [el for el in data_list[0] if (target_day in el.loc[0, 'datetime']) & (target_mouse in el.loc[0, 'mouse'])]
    target_trials = [el for el in data_list[0] if (pair[0] in el.loc[0, 'datetime']) & (pair[1] in el.loc[0, 'mouse'])]
#     target_trials = data_list[0]

    # define the target behaviors
    target_behavior = variable_list

    # allocate memory for the training and test sets
    ca_train = []
    ca_test = []
    beh_train = []
    beh_test = []
    # for all the trials
    for trial in target_trials:

        # get the available columns
        labels = list(trial.columns)
        cells = [el for el in labels if 'cell' in el]
        # get the cell data
        calcium_data = np.array(trial[cells].copy())
        # get rid of the super small values
        calcium_data[np.isnan(calcium_data)] = 0

        try:
            # get the parameter
            beh_data = trial[target_behavior].to_numpy()

            # smooth the parameter
            beh_data = signal.medfilt(beh_data, (21, 1))
        except KeyError:
            continue

        # skip if empty
        if (calcium_data.shape[0] == 0) | (calcium_data.shape[1] < 3):
            continue

    #     downsamp = 1
    #     # bin the data
    #     if downsamp > 1:
    #         beh_data = ss.decimate(beh_data, downsamp, axis=0)
    #         calcium_data = ss.decimate(calcium_data, downsamp, axis=0)

        # get the threshold index
        threshold_idx = int(calcium_data.shape[0]*(test_perc))
        # split the data
        ca_trial_train = calcium_data[threshold_idx:, :] 
        ca_trial_test = calcium_data[:threshold_idx, :] 
        beh_trial_train = beh_data[threshold_idx:, :]
        beh_trial_test = beh_data[:threshold_idx, :] 

        # store the data
        ca_train.append(ca_trial_train)
        ca_test.append(ca_trial_test)
        beh_train.append(beh_trial_train)
        beh_test.append(beh_trial_test)    

    # skip if empty arrays
    if len(ca_train) == 0:
        continue
    # scale the data
    # ca_scaler = preprocessing.StandardScaler().fit(np.concatenate(ca_train))
    # beh_scaler = preprocessing.StandardScaler().fit(np.concatenate(beh_train))

    # ca_train = [ca_scaler.transform(el) for el in ca_train]
    # ca_test = [ca_scaler.transform(el) for el in ca_test]
    # beh_train = [beh_scaler.transform(el) for el in beh_train]
    # beh_test = [beh_scaler.transform(el) for el in beh_test]

    # scale each trial separately
    ca_scaler_list = [preproc.StandardScaler().fit(el) for el in ca_train]
    beh_scaler_list = [preproc.StandardScaler().fit(el) for el in beh_train]

    ca_train = [ca_scaler_list[idx].transform(el) for idx, el in enumerate(ca_train)]
    ca_test = [ca_scaler_list[idx].transform(el) for idx, el in enumerate(ca_test)]
    beh_train = [beh_scaler_list[idx].transform(el) for idx, el in enumerate(beh_train)]
    beh_test = [beh_scaler_list[idx].transform(el) for idx, el in enumerate(beh_test)]


    # train the PSID model
    idSys = PSID.PSID(ca_train, beh_train, nx=20, n1=10, i=20)
    # idSys = PSID.PSID(ca_train, beh_train, nx=1, n1=1, i=20) # for cricket distance
    # idSys = PSID.PSID(ca_train, beh_train, nx=20, n1=10, i=35)
    
    # store the element
    psid_list.append([pair, idSys, ca_scaler_list])

    # allocate memory for the predictions
    beh_pred = []
    ca_pred = []
    latent_pred = []
    # predict each trial
    for trial in ca_test:
        beh_p, ca_p, latent_p = idSys.predict(trial)
        beh_pred.append(beh_p)
        ca_pred.append(ca_p)
        latent_pred.append(latent_p)

    combo_beh_test = np.vstack(beh_test)
    combo_beh_pred = np.vstack(beh_pred)

    combo_ca_test = np.vstack(ca_test)
    combo_ca_pred = np.vstack(ca_pred)

    R2TrialBased_beh = evalPrediction(combo_beh_test, combo_beh_pred, 'CC')
    R2TrialBased_ca = evalPrediction(combo_ca_test, combo_ca_pred, 'CC')

    print('Number of cells that have a larger than 0 CC:', np.sum(R2TrialBased_ca != 0))
    print('Mean Ca CC:', np.nanmean(R2TrialBased_ca))
    print('CC of behavior:', R2TrialBased_beh)

In [None]:
# show the Cz matrix of the first (up to) 10 fitted PSID models side by side
print(psid_list[0][1].Cz.shape)
plot_list = [hv.Raster(entry[1].Cz) for entry in psid_list[:10]]
layout = hv.Layout(plot_list)
layout

In [None]:
%%time
# run the model on all experiments

# allocate memory for the predictions
final_beh = []
final_ca = []
final_latent = []
final_pairs = []
final_scaled_beh = []
# for all the pairs
for pair in unique_dates_mice:
    # get the trials
    target_trials = [el for el in data_list[0] if (pair[0] in el.loc[0, 'datetime']) & (pair[1] in el.loc[0, 'mouse'])]
    tag_vector = [False if (el[0][0] == pair[0]) & (el[0][1] == pair[1]) else True for el in psid_list]
    # see if the pair was calculated
    if all(tag_vector):
        continue
        
    # get the index of the corresponding psid element
    idx = np.argwhere(~np.array(tag_vector))[0][0]
    
    # get the corresponding psid element
    idSys = psid_list[idx][1]
    scalers = psid_list[idx][2]
    # predict each trial
    for trial_idx, trial in enumerate(target_trials):

        # get the available columns
        labels = list(trial.columns)
        cells = [el for el in labels if 'cell' in el]
        
        # get the cell data
        calcium_data = np.array(trial[cells].copy())
        # skip if empty
        if (calcium_data.shape[0] == 0) | (calcium_data.shape[1] < 3):
            continue
        
        # scale the data
        # get rid of the super small values
        calcium_data[np.isnan(calcium_data)] = 0
        calcium_data = scalers[trial_idx].transform(calcium_data)
        
        # predict and store
        beh_p, ca_p, latent_p = idSys.predict(calcium_data)
        final_beh.append(beh_p)
        final_ca.append(ca_p)
        final_latent.append(latent_p)
        final_pairs.append(pair)
        # get the behavior
        final_scaled_beh.append(trial[variable_list].to_numpy())

In [None]:
# overlay every PSID latent of all predicted trials (stacked row-wise) as line plots

hv.extension('bokeh')
plot_data = np.vstack(final_latent)
x = np.arange(plot_data.shape[0])
plot_list = [hv.Curve((x, latent_trace)).opts(width=1000) for latent_trace in plot_data.T]

hv.Overlay(plot_list)

In [None]:
def hook_fun(plot, element):
    """HoloViews plotly hook that applies the Viridis colorscale to every trace's markers.

    Parameters
    ----------
    plot : HoloViews plotly plot
        Its ``state`` is used as the figure dict, with the traces under ``'data'``.
    element : HoloViews element
        The element being rendered (unused).

    Notes
    -----
    The previous version set ``layout['colorway'] = 'Viridis'``. Plotly's ``colorway`` is the
    list of default trace colors, not a colorscale name, so it could not map marker values
    through Viridis.
    """
    figure = plot.state
    for trace in figure['data']:
        # the marker colorscale maps numeric `color` values (e.g. of a Scatter3D) to colors
        trace.setdefault('marker', {})['colorscale'] = 'Viridis'

In [None]:
# plot the PSID dynamics reconstruction
# NOTE(review): this whole cell is commented out. If revived, note that it iterates over
# `final_latent` (trials of every session) but indexes `target_trials`, which after the
# prediction cell only holds the trials of the last (date, mouse) pair it processed, so
# parameters would be paired with the wrong latents. Also, `plot_list[idx]` goes out of
# step with `idx` once any trial hits the `continue` branch.
# hv.extension('plotly')
    
# # opts.defaults(opts.Scatter3D(cmap='viridis'))
# # define the target parameter
# target_parameter = 'cricket_0_delta_heading'
# # print(normal_data[0].columns[:40])

# # define the target dimensions
# tar_dim = np.array([0, 1, 2]) + 0

# # allocate memory for the output list
# plot_list = []

# target_trial = [0, 8]
# # for idx, trial in enumerate(final_latent[target_trial[0]:target_trial[1]+1]):
# for idx, trial in enumerate(final_latent):
    
#     # get a parameter
#     try:
#         parameter = target_trials[idx][target_parameter]
#         parameter = (parameter-parameter.min())/(parameter.max()-parameter.min())
# #         parameter = np.log(parameter)
    
#         plot_list.append(hv.Scatter3D((trial[:, tar_dim[0]], trial[:, tar_dim[1]], trial[:, tar_dim[2]])))
#         plot_list[idx].opts(size=2, height=800, width=800, color=parameter, colorbar=True)
# #     plot_list[idx].opts(hv.opts.Scatter3D(cmap='viridis'))
#     except KeyError:
# #             parameter = np.zeros(trial.shape[0])
#         continue

# ov = hv.Overlay(plot_list)
# ov
# hv.help(hv.Scatter3D)

In [None]:
# Spearman-correlate each trial's behavioral variables with its PSID latents, then average over trials

correlation_list = []
pvalue_list = []
for behavior, latents in zip(final_scaled_beh, final_latent):
    # with two 2-D inputs spearmanr correlates all of their columns jointly, so the
    # result is square over behaviors followed by latents
    rho, pval = stat.spearmanr(behavior, latents, nan_policy='omit')
    correlation_list.append(rho)
    pvalue_list.append(pval)

# trial-averaged matrices, ignoring NaNs
correlation_matrix = np.nanmean(np.array(correlation_list), axis=0)
# NOTE(review): a mean of p-values is not a valid combined test; treat it as descriptive only
pvalue_matrix = np.nanmean(np.array(pvalue_list), axis=0)

In [None]:
# plot the trial-averaged behavior (rows) by latent (columns) correlations
hv.extension('bokeh')

n_variables = len(variable_list)
# tick positions at the cell centers
yticks = [(position + 0.5, label_dict[name]) for position, name in enumerate(variable_list)]
xticks = [(position + 0.5, ' '.join(('latent', str(position)))) for position in range(latents.shape[1])]

# swap in pvalue_matrix here to inspect the averaged p-values instead
raster = hv.Raster(correlation_matrix[:n_variables, n_variables:])
raster.opts(width=900, height=800, colorbar=True, cmap='RdBu', tools=['hover'], xticks=xticks, yticks=yticks, xrotation=45, clim=(-1, 1))

In [None]:
# plot the session-averaged behavior-by-latent correlation matrix of every (date, mouse) pair

# allocate the plot list
plot_list = []
# get the correlations as an array, one matrix per predicted trial (np.array already copies)
correlation_array = np.array(correlation_list)
# for all the pairs
for pair in unique_dates_mice:

    # find the indexes corresponding to this pair
    pair_idx = [idx for idx, el in enumerate(final_pairs) if (el[0] == pair[0]) & (el[1] == pair[1])]
    # pairs without predicted trials would only yield an all-NaN panel (mean of an empty slice)
    if len(pair_idx) == 0:
        continue
    # average over the pair's trials and keep the behavior (rows) x latent (columns) block
    current_matrix = np.nanmean(correlation_array[pair_idx], axis=0)[:len(variable_list), len(variable_list):]
    # create the plot
    raster = hv.Raster(current_matrix)
    title = ' '.join((pair[0], pair[1]))
    raster.opts(width=400, height=400, tools=['hover'], cmap='RdBu', clim=(-1, 1), xticks=xticks, yticks=yticks, xrotation=45, xlabel='', ylabel='', title=title)

    # store
    plot_list.append(raster)

overall = hv.Layout(plot_list).cols(4)
overall

In [None]:
# per trial, summarize how strongly each behavioral variable is captured by the latents:
# the largest absolute correlation of that variable with any latent

correlation_array = np.array(correlation_list)

# collect one dataframe per trial for the box plot
latents_boxplot = []
# for all the pairs
for pair in unique_dates_mice:

    # find the indexes corresponding to this pair
    pair_idx = [idx for idx, el in enumerate(final_pairs) if (el[0] == pair[0]) & (el[1] == pair[1])]
    # for the single trial matrices
    for trial_idx in pair_idx:
        # behavior (rows) x latent (columns) block -> max |correlation| per behavior
        current_matrix = np.nanmax(np.abs(correlation_array[trial_idx][:len(variable_list), len(variable_list):]), axis=1)
        # build the frame column-wise so the correlations stay numeric
        # (stacking them with the names via np.vstack turned them into strings)
        latents_boxplot.append(pd.DataFrame({'Feature': variable_list, 'Average_corr': current_matrix}))

In [None]:
%%time
# generate shuffles for the average correlation calculation

# define the number of shuffles
number_shuffles = 100
# allocate memory for the shuffles
shuffle_list = []

# for all the shuffles
for shuff in np.arange(number_shuffles):
    # allocate a list for these shuffles
    current_shuffle = []
    # for all the plots
    for pair in unique_dates_mice:

        # find the indexes corresponding to this pair
        pair_idx = [idx for idx, el in enumerate(final_pairs) if (el[0] == pair[0]) & (el[1] == pair[1])]
#         # get the current matrix
#         current_matrix = np.nanmean(correlation_array[pair_idx], axis=0)[:len(variable_list), len(variable_list):]
        # for the single trial matrices
        for trial_idx in pair_idx:
            current_matrix = correlation_array[trial_idx][:len(variable_list), len(variable_list):]
            # shuffle the rows of the matrix
            row_idx = np.random.choice(np.arange(len(variable_list)), size=len(variable_list), replace=False)
            current_matrix = np.nanmax(np.abs(current_matrix[row_idx, :]), axis=1)
            # create a dataframe
            current_shuffle.append(pd.DataFrame(np.vstack([variable_list, current_matrix]).T, columns=['Feature', 'Average_corr']))
    # store
    shuffle_list.append(pd.concat(current_shuffle, axis=0).reset_index(drop=True))
# combine into a final dataframe
shuffle_list = pd.concat(shuffle_list, axis=0).reset_index(drop=True)

    
    

In [None]:
# box plot of the per-trial max correlation of each variable with the latents, real vs. shuffled

latents_df = pd.concat(latents_boxplot, axis=0).reset_index(drop=True)
latents_df['Average_corr'] = latents_df['Average_corr'].astype(float)
shuffle_list['Average_corr'] = shuffle_list['Average_corr'].astype(float)

# label and stack both distributions: overlaying two BoxWhisker elements on the same
# categories draws the shuffle boxes exactly on top of the real ones
combined_df = pd.concat(
    [latents_df.assign(Data='Data'), shuffle_list.assign(Data='Shuffle')], axis=0
).reset_index(drop=True)

whisker = hv.BoxWhisker(combined_df, ['Feature', 'Data'], ['Average_corr'])
whisker.opts(width=800, height=800, xrotation=45, ylabel='Average Correlation', xlabel='', box_line_width=1, whisker_line_width=1, tools=['hover'], box_fill_color=dim('Data').str(), cmap='Set1')
whisker

In [None]:
# align cross day latents using CCA
# (this cell only gathers the per-trial latents and recording days of the target mouse)

# define the target mouse
target_mouse = 'DG_210202_a'

# positions in final_pairs of the target mouse's trials
pair_idx = [idx for idx, (_, mouse) in enumerate(final_pairs) if mouse == target_mouse]

target_latents = [final_latent[idx] for idx in pair_idx]
target_days = [final_pairs[idx][0] for idx in pair_idx]


    
