# Calculate behavioral binned time averages

In [None]:
# imports
import os
import sys
sys.path.insert(0, os.path.abspath(r'D:\Code Repos\prey_capture'))

import panel as pn
import holoviews as hv
from holoviews import opts, dim
hv.extension('bokeh')
from bokeh.resources import INLINE

import paths
import functions_bondjango as bd
import pandas as pd
import numpy as np
import functions_plotting as fp
import functions_data_handling as fd
from scipy.stats import sem
import sklearn.decomposition as decomp
import umap
import sklearn.mixture as mix
from scipy.stats import sem

# prefix shared by every figure file saved from this notebook
# (the plotting cells append '_<suffix>.png' to it)
save_name = 'acrossTrials'
# line width setting; not referenced by any of the cells visible here
line_width = 5

In [None]:
# load the data
# get the data path: use the first snakemake input when a `snakemake` object
# exists, otherwise look the file up in the database
try:
    data_path = snakemake.input[0]
except NameError:
    # `snakemake` is not defined in this session
    # define the search string
    search_string = 'result:succ, lighting:normal, rig:miniscope, =analysis_type:aggBin'
    # query the database for data to plot
    data_all = bd.query_database('analyzed_data', search_string)
    # fail with an explicit message rather than an opaque "index out of range"
    # when nothing matches; the exception type (IndexError) is the same one the
    # bare data_all[0] lookup would raise
    if len(data_all) == 0:
        raise IndexError('no analyzed_data entries match: ' + search_string) from None
    # only the first match is used
    data_path = data_all[0]['analysis_path']
print(data_path)

# load the aggregated data from the resolved path
data = fd.aggregate_loader(data_path)

In [None]:
# Plot example traces

# parameters to display; one heat map is produced per entry
target_parameters = ['mouse_cricket_distance', 'cricket_speed','mouse_speed']
# collects the heat maps for the layout shown at the end of the cell
plot_list = []

for target in target_parameters:
    # keep only this parameter and the trial identifier
    parameter = data[[target, 'trial_id']].copy()
    # gather the samples of each trial into a list, then stack the trials as rows
    grouped_parameter = parameter.groupby(['trial_id']).agg(list)
    grouped_parameter = np.array(list(grouped_parameter[target]))

    # reorder the traces with the project helper; its other two outputs are
    # not needed here
    sorted_traces, _, _ = fp.sort_traces(grouped_parameter)

    # image extent: columns along x, rows (trials) along y
    n_trials = grouped_parameter.shape[0]
    n_bins = grouped_parameter.shape[1]
    image = hv.Image(
        sorted_traces,
        kdims=['Binned Time', 'Trial #'],
        vdims=[target.replace('_', ' ')],
        bounds=[0, 0, n_bins, n_trials],
    )
    image.opts(
        width=fp.pix(5.8),
        height=fp.pix(5.8),
        toolbar=None,
        hooks=[fp.margin],
        fontsize=fp.font_sizes['small'],
        xticks=3,
        yticks=3,
        colorbar=True,
        cmap='viridis',
        colorbar_opts={'major_label_text_align': 'left'},
    )

    # write the figure to <figures_path>/<save_name>_<target>.png
    save_path = os.path.join(paths.figures_path, f'{save_name}_{target}.png')
    hv.save(image, save_path)

    plot_list.append(image)

# NOTE: a per-parameter mean +/- SEM summary (hv.Curve * hv.Spread) was left
# disabled in this cell and is not produced.

# display all heat maps together
hv.Layout(plot_list)
    
    
    


In [None]:
# Plot the duration of the trial for each trial

# load the parameter
parameter = data[['mouse_cricket_distance','trial_id']].copy()
# group the single traces: one list of samples per trial, stacked as rows
grouped_parameter = parameter.groupby(['trial_id']).agg(list)
grouped_parameter = np.array(list(grouped_parameter['mouse_cricket_distance']))

# only the cluster assignment returned by the helper is used in this cell
_, _, cluster_idx = fp.sort_traces(grouped_parameter)

# collect the time vector of every trial
times = data[['time_vector','trial_id']].copy()
times = times.groupby(['trial_id']).agg(list)
# the duration of a trial is taken as the last entry of its time vector
# (assumes each time vector starts at 0 - TODO confirm)
duration_list = np.array([trial[-1] for trial in times['time_vector']])

# one row per cluster ID in 0..max: column 0 = mean duration, column 1 = SEM
n_clusters = np.max(cluster_idx)+1
print(n_clusters)
duration_averages = np.zeros((n_clusters, 2))
# visit each cluster once; iterating over cluster_idx itself would repeat the
# same computation for every trial in the cluster
for clu in np.unique(cluster_idx):
    cluster_durations = duration_list[cluster_idx == clu]
    duration_averages[clu, 0] = np.mean(cluster_durations)
    duration_averages[clu, 1] = sem(cluster_durations)

# plot the results
clu_vector = np.arange(n_clusters)
errorbar = hv.ErrorBars((clu_vector, duration_averages[:, 0], duration_averages[:, 1])) * \
    hv.Bars((clu_vector, duration_averages[:, 0]))
errorbar
    

In [None]:
# define the target parameter and PCA
target_parameter = 'mouse_cricket_distance'

# keep only the target parameter and the trial identifier
parameter = data[[target_parameter,'trial_id']].copy()
# one list of samples per trial, stacked so that every row is a trial
per_trial = parameter.groupby(['trial_id']).agg(list)
target_data = np.array(per_trial[target_parameter].tolist())

# PCA the data before clustering (all components are kept at this stage)
pca = decomp.PCA()
transformed_data = pca.fit_transform(target_data)

# cumulative explained variance, normalised by the total, to help choose how
# many components to keep
variance_ratio = pca.explained_variance_ratio_
curve = hv.Curve(np.cumsum(variance_ratio) / np.sum(variance_ratio))
curve.opts(tools=['hover'])
# define the number of PCs to use
# NOTE(review): later cells slice transformed_data[:, :pc_number+1], i.e. they
# keep pc_number + 1 columns - confirm this is intended
pc_number = 7

curve

In [None]:
# Cluster the data

# candidate numbers of mixture components
# component_vector = [2, 3, 4, 5, 10, 20, 30]
component_vector = [2, 3, 4, 5, 10, 20]
# clustering features: the leading columns of the PCA projection
# (the slice keeps pc_number + 1 columns)
pc_features = transformed_data[:, :pc_number+1]

# BIC of every candidate; kept under the name `gmms` because a later cell
# plots it against component_vector
gmms = []
# the fitted candidates, so the selected one can be reused without refitting
fitted_models = []
# NOTE(review): no random_state is passed, so results can differ between runs
for comp in component_vector:
    candidate = mix.GaussianMixture(n_components=comp, covariance_type='diag', n_init=50)
    candidate.fit(pc_features)
    fitted_models.append(candidate)
    gmms.append(candidate.bic(pc_features))

# select the number of components with the minimum BIC
best_position = int(np.argmin(gmms))
n_components = np.array(component_vector)[best_position]
# reuse the model that actually achieved that BIC: refitting would start from
# new random initialisations and could return a different solution
gmm = fitted_models[best_position]
# predict the cluster indexes
cluster_idx = gmm.predict(pc_features)

In [None]:
# discard small clusters
# NaN is used to mark discarded traces, which requires a float array
cluster_idx = cluster_idx.astype(float)
# cluster IDs together with the number of traces assigned to each
clu_unique, trace_counts = np.unique(cluster_idx, return_counts=True)
# clusters holding fewer than 5 traces are eliminated
small_clusters = clu_unique[trace_counts < 5]
cluster_idx[np.isin(cluster_idx, small_clusters)] = np.nan

# plot the BIC against the candidate number of components
hv.Curve((component_vector, gmms))

In [None]:
# UMAP

# features for the embedding: the same leading columns of the PCA projection
# that the clustering cell uses (pc_number + 1 columns)
umap_features = transformed_data[:, :pc_number+1]
# embed the data via UMAP
reducer = umap.UMAP(n_neighbors=10, min_dist=0.5)
embedded_data = reducer.fit_transform(umap_features)
# alternative left disabled: embed the raw traces (target_data) instead

In [None]:
# use AHC on the data
# NOTE: this reassigns cluster_idx, replacing the labels produced by earlier cells

# load the parameter
parameter = data[['mouse_cricket_distance','trial_id']].copy()
# group the single traces: one list of samples per trial, stacked as rows
grouped_parameter = parameter.groupby(['trial_id']).agg(list)
grouped_parameter = np.array(list(grouped_parameter['mouse_cricket_distance']))

# cluster labels from the project helper (presumably hierarchical clustering,
# per the cell title - verify in functions_plotting); one call is enough, since
# the input does not change and a repeat would only overwrite cluster_idx
# alternative left disabled: cluster the PCA projection
# (transformed_data[:, :pc_number+1]) instead of the raw traces
_, _, cluster_idx = fp.sort_traces(grouped_parameter)

In [None]:
# Plot the embedding

# one row per trial: the embedding coordinates followed by the cluster label
# and the trial duration (four columns, matching the dimensions named below)
umap_data = np.column_stack((embedded_data, cluster_idx, duration_list))

# scatter of the embedding, coloured by cluster and sized by duration
umap_plot = hv.Scatter(umap_data, kdims=['Dim 1'], vdims=['Dim 2', 'cluster', 'duration'])
umap_plot.opts(color='cluster', colorbar=True, cmap='Category10', size='duration')

# figure geometry and fonts
umap_plot.opts(
    opts.Scatter(
        width=fp.pix(6),
        height=fp.pix(6.14),
        toolbar=None,
        hooks=[fp.margin],
        fontsize=fp.font_sizes['small'],
        xticks=3,
        yticks=3,
    ),
    opts.Overlay(legend_position='right', text_font='Arial'),
)

# write the figure to <figures_path>/<save_name>_umap.png
save_path = os.path.join(paths.figures_path, f'{save_name}_umap.png')
hv.save(umap_plot, save_path)

umap_plot