In [None]:
import os
import sys
sys.path.insert(0, os.path.abspath(r'D:\Code Repos\prey_capture'))
sys.path.insert(0, os.path.abspath(r'D:\Code Repos\prey_capture\mine_pub'))

import panel as pn
import holoviews as hv
from holoviews import opts, dim
hv.extension('bokeh')
from bokeh.resources import INLINE

import paths
import functions_plotting as fp
from functions_plotting import format_figure as ff
import functions_bondjango as bd
import functions_loaders as fl
import processing_parameters

import importlib
import numpy as np
import pandas as pd
import h5py
import scipy.stats as stat
import scipy.signal as signal
import sklearn.decomposition as decomp

In [None]:
# Figure configuration for this notebook

# refresh the project modules so on-disk edits are picked up without a kernel restart
importlib.reload(fp)
importlib.reload(processing_parameters)
# set up the figure theme
fp.set_theme()

# output settings: target document, printing mode and destination folder
target_document = 'paper'
save_mode = True
save_path = os.path.join(paths.figures_path, 'Behavior_correlation')

# behavioral variable names and their label dictionary
variable_list = processing_parameters.variable_list
label_dict = processing_parameters.label_dictionary

In [None]:
# refresh the project modules so on-disk edits are picked up
importlib.reload(processing_parameters)
importlib.reload(fl)

# get the paths from the database using search_list
all_paths, all_queries = fl.query_search_list()

# load the data for every path/query pair; the loader returns three values
# and only the first one is kept
data_list = []
for path, queries in zip(all_paths, all_queries):
    data, _, _ = fl.load_preprocessing(path, queries)
    data_list.append(data)

# Each entry of data_list is itself a collection of trials (it is iterated
# trial by trial in the cells below), so len(data_list) is the number of
# loaded entries, not the number of trials. Report both counts.
print(f'Number of loaded entries: {len(data_list)}')
print(f'Number of trials: {sum(len(entry) for entry in data_list)}')

In [None]:
# Get only the behavior

variable_list = processing_parameters.variable_list
# for all the data, keep only the behavioral variables of interest
data_behavior = []
# count the trials that get dropped, so the exclusion is visible rather than silent
skipped_trials = 0

for el in data_list:
    for el2 in el:
        try:
            data_behavior.append(el2[variable_list])
        except KeyError:
            # the trial lacks at least one of the requested columns; leave it out
            skipped_trials += 1

print(len(data_behavior))
if skipped_trials:
    print(f'Skipped {skipped_trials} trials missing at least one behavioral variable')

In [None]:
# index (within data_behavior) of the trial to display
target_trial = 100
# tick positions (offset by 0.5) paired with the label of each variable
ticks = [(position + 0.5, label_dict[name]) for position, name in enumerate(variable_list)]

# draw the selected trial as an image, transposing the array first
plot = hv.Raster(data_behavior[target_trial].to_numpy().T)
plot.opts(
    yticks=ticks,
    cmap='Viridis',
    tools=['hover'],
    width=1200,
    height=600,
)
plot

In [None]:
%%time
# Average the per-trial Spearman correlation matrices

# one correlation matrix is collected per trial
correlation_list = []
for el in data_behavior:
    # NaN handling is delegated to scipy (nan_policy='omit'); the p-values are not used afterwards
    correlation_matrix, pvalue_matrix = stat.spearmanr(
        el.to_numpy(),
        nan_policy='omit',
    )
    correlation_list += [correlation_matrix]

# element-wise mean across trials, ignoring NaN entries
correlation_matrix = np.nanmean(correlation_list, axis=0)
print(correlation_matrix.shape)

# Behavioral correlation plot 

In [None]:
# Plot the matrix
importlib.reload(fp)
importlib.reload(processing_parameters)
# refresh both the variable list and its labels from the reloaded module, so
# the tick labels cannot fall out of sync with the variables
variable_list = processing_parameters.variable_list
label_dict = processing_parameters.label_dictionary

ticks = [(idx+0.5, label_dict[el]) for idx, el in enumerate(variable_list)]
# Keep the lower triangle (diagonal included) and blank the strict upper
# triangle by position. Masking on `== 0` would also hide correlations in the
# lower triangle that happen to be exactly zero.
plot_matrix = correlation_matrix.copy()
plot_matrix[np.triu_indices_from(plot_matrix, k=1)] = np.nan
raster = hv.Raster(plot_matrix)
# format the plot
raster.opts(width=1200, height=800, yticks=ticks, xticks=ticks, colorbar=True, cmap='RdBu_r', clim=(-1, 1), xrotation=45, ylabel='', xlabel='', tools=['hover'])

# assemble the file name
save_name = os.path.join(save_path, '_'.join((target_document, 'overall_correlation')) + '.png')
# save the figure
fig = fp.save_figure(raster, save_name, fig_width=8, dpi=1200, fontsize='small', target='save')


In [None]:
def calculate_average_correlation(temp_behavior):
    """Run PCA on the input and average the correlations of its projection.

    Parameters
    ----------
    temp_behavior : 2D array-like
        Data matrix handed unchanged to ``PCA.fit_transform`` (scikit-learn
        treats its rows as samples).

    Returns
    -------
    average_triangle : float
        Mean absolute Spearman correlation over every distinct off-diagonal
        pair of the correlation matrix computed from the transposed PCA
        output. NaN coefficients propagate to the result.
    exp_variance : ndarray
        ``explained_variance_ratio_`` of the fitted PCA.
    components : ndarray
        ``components_`` of the fitted PCA.
    """
    # perform PCA
    pca = decomp.PCA()
    pca_output = pca.fit_transform(temp_behavior)
    exp_variance = pca.explained_variance_ratio_
    components = pca.components_

    # calculate correlation matrix (spearmanr correlates the columns of its input)
    corr_matrix = np.asarray(stat.spearmanr(pca_output.T)[0])

    if corr_matrix.ndim < 2:
        # with only two variables scipy returns a single coefficient
        pair_values = np.abs(corr_matrix).reshape(-1)
    else:
        # Select the strict upper triangle by position: the diagonal (always 1)
        # and the mirrored lower-triangle entries must not enter the average,
        # and coefficients that are exactly 0 must not be discarded.
        pair_values = np.abs(corr_matrix[np.triu_indices_from(corr_matrix, k=1)])
    average_triangle = np.mean(pair_values)
    return average_triangle, exp_variance, components

In [None]:
%%time
# run PCA on the data and then correlate

# seed the generator so the shuffle control is reproducible across kernel restarts
shuffle_seed = 0
rng = np.random.default_rng(shuffle_seed)

# allocate the output lists
pca_corr = []
shuffle_corr = []
pca_var = []
shuffle_var = []
pca_comp = []
# define the number of shuffles
shuffle_number = 100
# Run through all the trials from a day
for trial in data_list[0]:

    # get the behavioral variables; after the transpose each row is one variable
    not_cells = variable_list
    try:
        temp_behavior = trial.loc[:, not_cells].to_numpy().T
    except KeyError:
        # the trial lacks at least one of the requested columns; skip it
        continue
    average_triangle, exp_variance, components = calculate_average_correlation(temp_behavior)
    pca_var.append(exp_variance)
    pca_comp.append(components)

    # store
    pca_corr.append(average_triangle)
    # for all the shuffles
    for shuff in range(shuffle_number):
        # Calculate a shuffle version: every row is permuted independently
        # along axis 1. permuted() returns a new array, so temp_behavior is
        # left untouched and no explicit copy is needed.
        temp_shuffle = rng.permuted(temp_behavior, axis=1)
        shuffle_triangle, exp_variance, _ = calculate_average_correlation(temp_shuffle)
        shuffle_var.append(exp_variance)
        shuffle_corr.append(shuffle_triangle)

print(len(pca_corr))

# PCA correlation vs shuffle 

In [None]:
# Box plot comparing the real and shuffled average correlations

# both groups share the same layout: a 'value' column plus a 'group' tag
corr_df = pd.DataFrame(pca_corr, columns=['value']).assign(group='Real')
shuff_df = pd.DataFrame(shuffle_corr, columns=['value']).assign(group='Shuffle')

# stack the two groups row-wise
overall_df = pd.concat([corr_df, shuff_df], axis=0)

boxplot = hv.BoxWhisker(overall_df, kdims=['group'], vdims=['value'])
boxplot.opts(
    xlabel='',
    ylabel='Average correlation',
    box_line_width=1,
    whisker_line_width=1,
    outlier_line_width=1,
)

# build the output file name from the target document and the figure name
file_stem = '_'.join((target_document, 'correlation_shuffle_box'))
save_name = os.path.join(save_path, file_stem + '.png')
# hand the figure to the project helper
fig = fp.save_figure(boxplot, save_name, fig_width=7, dpi=1200, fontsize=target_document, target='screen')


In [None]:
# Explained variance of the real data versus the shuffles

# stack the stored explained variance ratios, one row per PCA run
pca_var_array = np.array(pca_var)
shuffle_var_array = np.array(shuffle_var)

# mean and standard deviation across runs, per component
pca_var_mean, pca_var_error = pca_var_array.mean(axis=0), pca_var_array.std(axis=0)
shuffle_var_mean, shuffle_var_error = shuffle_var_array.mean(axis=0), shuffle_var_array.std(axis=0)
# component index used as the x axis
x = np.arange(pca_var_array.shape[1])

# real data: mean curve with its spread
plot1 = hv.Curve((x, pca_var_mean))
plot1.opts(width=600, height=600, xlabel='PCs', ylabel='Explained var')
plot2 = hv.Spread((x, pca_var_mean, pca_var_error))
plot3 = plot1*plot2

# shuffled data: mean curve with its spread
plot4 = hv.Curve((x, shuffle_var_mean))
plot5 = hv.Spread((x, shuffle_var_mean, shuffle_var_error))
plot6 = plot4*plot5

# overlay both groups and display
plot7 = plot3*plot6

plot7

In [None]:
# plot weight matrices

plot_number = 10
