In [None]:
import os
import sys
# Make the project packages importable. The root defaults to the original author's checkout;
# set the PREY_CAPTURE_ROOT environment variable to use a checkout elsewhere.
project_root = os.environ.get('PREY_CAPTURE_ROOT', r'D:\Code Repos\prey_capture')
sys.path.insert(0, os.path.abspath(project_root))
sys.path.insert(0, os.path.abspath(os.path.join(project_root, 'mine_pub')))

import panel as pn
import holoviews as hv
from holoviews import opts, dim
hv.extension('bokeh')
from bokeh.resources import INLINE

import paths
import functions_plotting as fp
from functions_plotting import format_figure as ff
import functions_bondjango as bd
import functions_loaders as fl
import snakemake_scripts.tc_calculate as tc
import processing_parameters

import importlib
import numpy as np
import pandas as pd
import h5py
import scipy.stats as stat
import scipy.signal as signal
import sklearn.decomposition as decomp

In [None]:
# Figure and analysis configuration. The helper modules are reloaded first so that edits
# to them are picked up without restarting the kernel.
for reloaded_module in (fp, processing_parameters):
    importlib.reload(reloaded_module)

# output settings: figure folder, whether figures are saved, and the document they target
save_path = os.path.join(paths.figures_path, 'Behavior_representation')
save_mode = True
target_document = 'paper'

# apply the plotting theme from the freshly reloaded plotting module
fp.set_theme()

# display labels and the list of behavioral variables, both taken from processing_parameters
label_dict, variable_list = processing_parameters.label_dictionary, processing_parameters.variable_list

In [None]:
# Query the database for the target files and load their preprocessed data.
importlib.reload(processing_parameters)
importlib.reload(fl)

# get the paths from the database using search_list
all_paths, all_queries = fl.query_search_list()

data_list = []
for path, queries in zip(all_paths, all_queries):
    # only the first output of load_preprocessing is used in this notebook
    data, _, _ = fl.load_preprocessing(path, queries)
    data_list.append(data)

# data_list has one entry per search query, and each entry is iterated trial by trial
# further down (data_list[0]). Report both levels instead of labelling the query count "trials".
print(f'Number of queries: {len(data_list)}')
print(f'Number of trials per query: {[len(el) for el in data_list]}')

In [None]:
# Choose the single mouse/session whose trials are analysed below.

# every (mouse, date) pair present in the first query's trials; the date is the first
# 10 characters of the 'datetime' field
unique_dates_mice = np.unique(
    [(trial.loc[0, 'mouse'], trial.loc[0, 'datetime'][:10]) for trial in data_list[0]],
    axis=0,
)

target_mouse = 'DG_210202_a'
target_date = '2021-04-30'

# keep the trials whose mouse and datetime *contain* the targets (substring match, not equality)
target_trials = [
    trial
    for trial in data_list[0]
    if target_mouse in trial.loc[0, 'mouse'] and target_date in trial.loc[0, 'datetime']
]
print(len(target_trials))

In [None]:
# Keep only the behavioral variables of interest for each target trial.
# A trial missing any of these columns is skipped, which makes data_behavior shorter than
# target_trials. The combine cell below pairs the two lists by position, so the skipped
# indices are reported here rather than dropped silently.

variable_list = processing_parameters.variable_list
behavior_columns = variable_list

data_behavior = []
skipped_behavior_trials = []
for trial_idx, trial in enumerate(target_trials):
    try:
        data_behavior.append(trial[variable_list].reset_index(drop=True))
    except KeyError:
        # at least one of the behavioral columns is absent from this trial
        skipped_behavior_trials.append(trial_idx)

if skipped_behavior_trials:
    print(f'Skipped trials missing behavioral variables: {skipped_behavior_trials}')
print(len(data_behavior))

In [None]:
# Heatmap of one trial's behavioral variables: one row per variable, time along x.
target_trial = 1
ticks = [(row + 0.5, label_dict[name]) for row, name in enumerate(variable_list)]

trial_matrix = data_behavior[target_trial].to_numpy().T
plot = hv.Raster(trial_matrix).opts(width=1200, height=600, cmap='Viridis', tools=['hover'], yticks=ticks)
plot

In [None]:
# Apply PCA to the behavior of each trial, then project every trial onto one shared basis
# obtained by averaging the per-trial components.

# fit one PCA per trial; components_ is (n_components, n_features) in scikit-learn
matrix_list = []
for trial in data_behavior:
    pca = decomp.PCA()
    pca.fit(trial)
    matrix_list.append(pca.components_)
assert matrix_list, 'no trials with behavioral data to decompose'

# PCA component signs are arbitrary. Flip each trial's components to agree with the first
# trial's before averaging, otherwise opposite-signed copies of the same axis cancel out.
# NOTE(review): component *order* can still differ between trials; this only fixes signs.
reference_components = matrix_list[0]
aligned_components = []
for components in matrix_list:
    signs = np.sign(np.sum(components * reference_components, axis=1))
    signs[signs == 0] = 1
    aligned_components.append(components * signs[:, np.newaxis])
# average the transformation matrix
average_transform = np.mean(aligned_components, axis=0)

# Project each trial onto the averaged components after centering it on its own mean.
# The previous code refit a PCA and overwrote components_ to get the same thing for a
# non-whitened PCA; doing it explicitly avoids the refit and editing fitted sklearn state.
pca_columns = pd.Index(['pc' + str(el) for el in np.arange(average_transform.shape[0])])
pca_behavior = []
for trial in data_behavior:
    trial_values = trial.to_numpy(dtype=float)
    projected = (trial_values - trial_values.mean(axis=0)) @ average_transform.T
    pca_behavior.append(pd.DataFrame(projected, columns=pca_columns))
    

In [None]:
# One heatmap per trial of its PCA component matrix (rows = components, columns = variables).
plot_list = [hv.Raster(component_matrix).opts(cmap='Viridis') for component_matrix in matrix_list]

layout = hv.Layout(plot_list).cols(4)
layout

In [None]:
# Collect the latent variables (every column whose name contains 'latent') of each target trial.
# Trials without latents are skipped and reported, because the combine cell below pairs
# vame_behavior with the other per-trial lists by position.
vame_behavior = []
skipped_vame_trials = []
for trial_idx, trial in enumerate(target_trials):
    latent_columns = [column for column in trial.columns if 'latent' in column]
    if not latent_columns:
        skipped_vame_trials.append(trial_idx)
        continue
    vame_behavior.append(trial[latent_columns].reset_index(drop=True))

if skipped_vame_trials:
    print(f'Skipped trials without latent columns: {skipped_vame_trials}')

# Take the column names from a trial that actually had latents. Using the last trial inspected
# gave [] whenever that trial had none, and raised NameError when there were no trials.
vame_columns = list(vame_behavior[-1].columns) if vame_behavior else []
    

In [None]:
# Combine, per trial, the behavioral variables, their PCA projection, the VAME latents and
# the calcium traces (columns whose name contains 'cell').
# The four sources are paired purely by list position, and pd.concat(axis=1) pads rows with
# NaN when lengths differ. Check alignment first so a skipped trial fails loudly instead of
# silently mixing data from different trials.
assert len(data_behavior) == len(pca_behavior) == len(vame_behavior) == len(target_trials), (
    f'Trial lists are misaligned: behavior={len(data_behavior)}, pca={len(pca_behavior)}, '
    f'vame={len(vame_behavior)}, trials={len(target_trials)}'
)

combined_data = []
for idx, trial in enumerate(data_behavior):
    # get the calcium information
    original_trial = target_trials[idx]
    cells = [el for el in original_trial.columns if 'cell' in el]
    calcium = original_trial[cells].reset_index(drop=True)

    parts = [trial, pca_behavior[idx], vame_behavior[idx], calcium]
    row_counts = {part.shape[0] for part in parts}
    assert len(row_counts) == 1, f'Trial {idx}: components have different lengths {sorted(row_counts)}'
    combined_data.append(pd.concat(parts, axis=1))

# column names of each decomposition, in the order behavior, PCA, VAME
overall_columns = [behavior_columns, pca_columns, vame_columns]

In [None]:
%%time
# calculate the TCs in all three decompositions
importlib.reload(tc)

# get the number of bins
bin_num = processing_parameters.bin_number
# define the pairs to quantify
variable_names = np.concatenate(overall_columns)
# variable_names = ['pc0', 'pc1']

# clip the calcium traces
clipped_data = tc.clip_calcium(combined_data)

# parse the features (bin number is for spatial bins in this one)
features, calcium = tc.parse_features(clipped_data, variable_names, bin_number=10)

# concatenate all the trials
features = pd.concat(features)
calcium = np.concatenate(calcium)

# get the number of cells
cell_num = calcium.shape[1]

# get the TCs and their responsivity
tcs_half, tcs_full, tcs_resp, tc_count, tc_bins = tc.extract_tcs_responsivity(features, calcium, variable_names,
                                                                              cell_num, percentile=95,
                                                                              bin_number=bin_num)


In [None]:
# Compare the tuning quality across the three decompositions.
# For each decomposition, summarise every feature by the fraction of cells passing the quality
# test and the mean quality index (overall_fractions). Also keep the per-cell quality indices
# for the correlation plots (overall_points).

overall_fractions = []
overall_points = []
for feature_list in overall_columns:
    # Dedicated name: this loop used to reassign the global variable_list, clobbering the
    # behavioral variable list that the raster tick labels read.
    fraction_records = []
    quality_points = []
    for target_feature in feature_list:
        # assumes tcs_resp[feature] holds 4 values per cell, the last two being the quality
        # index and quality test (only those two are used) - TODO confirm against tc_calculate
        analysis_df = pd.DataFrame(tcs_resp[target_feature], columns=['none1', 'none2', 'Qual_index', 'Qual_test'])
        # a cell passes when its quality test value is positive
        passed = (analysis_df['Qual_test'] > 0).to_numpy()

        fraction_records.append({
            'Feature': label_dict.get(target_feature, target_feature),
            'Pass_fraction': passed.mean(),
            'Qual_index': analysis_df['Qual_index'].mean(),
        })
        quality_points.append(analysis_df['Qual_index'].to_numpy())

    # Building from records keeps the numeric columns as floats (transposing a concat of
    # mixed-type Series produced object columns) and handles an empty feature list.
    overall_fractions.append(pd.DataFrame(fraction_records, columns=['Feature', 'Pass_fraction', 'Qual_index']))
    if quality_points:
        overall_points.append(pd.DataFrame(np.column_stack(quality_points), columns=feature_list))
    else:
        overall_points.append(pd.DataFrame(columns=feature_list))
    

In [None]:
overall_fractions[0]

In [None]:
# Fraction of cells passing the quality test for each feature, one panel per decomposition
# (overall_fractions is ordered behavior, PCA, VAME).
method_names = ['Behavior', 'PCA', 'VAME']
plot_list = []
for idx, method_name in enumerate(method_names):
    plot = hv.Scatter(overall_fractions[idx], kdims='Feature', vdims='Pass_fraction')
    plot.opts(width=400, height=400, xrotation=45, ylim=(0, 1), title=method_name)
    plot_list.append(plot)
layout = hv.Layout(plot_list).cols(3).opts(shared_axes=False)
layout

In [None]:
# Mean quality index for each feature, one panel per decomposition
# (overall_fractions is ordered behavior, PCA, VAME).
# ylim is fixed to +/-0.1, so values outside that range fall outside the view.
method_names = ['Behavior', 'PCA', 'VAME']
plot_list = []
for idx, method_name in enumerate(method_names):
    plot = hv.Scatter(overall_fractions[idx], kdims='Feature', vdims='Qual_index')
    plot.opts(width=400, height=400, xrotation=45, ylim=(-0.1, 0.1), title=method_name)
    plot_list.append(plot)
layout = hv.Layout(plot_list).cols(3).opts(shared_axes=False)
layout

In [None]:
# Plot, for each decomposition, the feature-by-feature Spearman correlation of the per-cell
# quality indices (overall_points is ordered behavior, PCA, VAME).
method_names = ['Behavior', 'PCA', 'VAME']
plot_list = []
for idx, method_name in enumerate(method_names):
    quality_frame = overall_points[idx]
    # Work on an explicit copy: filling NaNs in place on to_numpy() could write into
    # overall_points, or fail on a read-only view under pandas copy-on-write.
    quality_values = quality_frame.to_numpy(dtype=float, copy=True)
    # NOTE(review): NaNs are zero-filled here, so nan_policy='omit' below never takes effect.
    quality_values[np.isnan(quality_values)] = 0

    correlation, _ = stat.spearmanr(quality_values, nan_policy='omit')
    # with exactly two features spearmanr returns a scalar rho; expand it to the 2x2 matrix
    if np.ndim(correlation) == 0:
        correlation = np.array([[1.0, correlation], [correlation, 1.0]])

    ticks = [(pos + 0.5, label_dict.get(name, name)) for pos, name in enumerate(quality_frame.columns)]
    plot = hv.Raster(correlation)
    plot.opts(width=400, height=600, xrotation=45, tools=['hover'], xticks=ticks, ylabel='', xlabel='',
              title=method_name)
    plot_list.append(plot)
layout = hv.Layout(plot_list).cols(3).opts(shared_axes=False)
layout

In [None]:
# TODO: apply the encoding model to all three decompositions (cell not implemented yet)


In [None]:
# TODO: run the classification on all three decompositions (cell not implemented yet)