# Cluster and visualize the aggregated results in high D

In [1]:
# imports
import os
import sys
sys.path.insert(0, os.path.abspath(r'D:\Code Repos\prey_capture'))

import panel as pn
import holoviews as hv
from holoviews import opts, dim
hv.extension('bokeh')
from bokeh.resources import INLINE

import paths
import functions_bondjango as bd
import pandas as pd
import numpy as np
import sklearn.mixture as mix
import sklearn.decomposition as decomp
import functions_plotting as fp
import functions_data_handling as fd
import functions_plotting as fp
import umap
# define the name to be used for the saved figures
save_name = 'clusters_mouse'
line_width = 5

In [2]:
# load the data
# get the data paths
try:
    data_path = snakemake.input[0]
except NameError:
    # define the search string
    search_string = 'result:succ,lighting:normal,rig:vr,analysis_type:aggEnc'
    # query the database for data to plot
    data_all = bd.query_database('analyzed_data', search_string)
    data_path = data_all[0]['analysis_path']
print(data_path)

# load the data
data = fd.aggregate_loader(data_path)

J:\Drago Guggiana Nilo\Prey_capture\AnalyzedData\preprocessing_succ_vr_normal_ALL_ALL_ALL_aggEnc.hdf5


In [61]:
# define the target parameter and PCA
# target_parameter = 'mouse_cricket_distance'
target_parameter = 'cricket_0_mouse_distance'

# assemble the array with the parameters of choice
if 'aggEnc' in search_string:
    target_data = data.loc[:, [target_parameter] + ['event_id', 'trial_id']].groupby(
        ['trial_id', 'event_id']).agg(list).to_numpy()
else:
    target_data = data.loc[:, [target_parameter] + ['trial_id']].groupby(
        ['trial_id']).agg(list).to_numpy()
# # HACK REMOVE
# target_data = np.array([el for sublist in target_data for el in sublist if len(el) == 594])
[print(el) for el in target_data]
# PCA the data before clustering
pca = decomp.PCA()
transformed_data = pca.fit_transform(target_data)
print(transformed_data.shape)
# fp.plot_2d([[pca.explained_variance_ratio_]])

hv.Curve(pca.explained_variance_ratio_)

[list([629.7877541538695, 634.9098194076807, 640.6941280040919, 645.1971341850999, 650.9918345942481, 656.1473926714959, 661.9522960939793, 667.1168323957855, 672.2852878584506, 678.1779329783134, 683.4719063915344, 688.7697662443967, 695.397086183902, 700.7028690641781, 705.3483097149204, 711.3248961369407, 716.6409878169095, 721.9011040196617, 727.7828185218095, 733.0142360181063, 738.248874889693, 744.1412126116163, 749.38215279452, 754.6258056517095, 759.8724243387998, 762.8179993524865, 763.3321143423872, 763.9104880728485, 764.4245879408099, 765.0029618953675, 765.4528032790455, 766.0311774436432, 766.6095364044373, 767.2697610242246, 768.0176435862793, 768.7655683772954, 769.6069768803234, 770.3549438509668, 771.196421949317, 771.8509147207177, 772.6924147820145, 773.5673509379923, 774.6390547933163, 775.7129838849623, 776.6695067463851, 777.6277678418533, 778.7079367108395, 779.5495927908759, 780.5131150209135, 780.8722326299512, 780.7502291663163, 780.7238343646203, 780.814826

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [52]:
# Cluster the data

# define the vector of components
component_vector = [2, 3, 4, 5, 10, 20, 30]
# allocate memory for the results
gmms = []
# for all the component numbers
for comp in component_vector:
    # # define the number of components
    # n_components = 10
    gmm = mix.GaussianMixture(n_components=comp, covariance_type='diag', n_init=50)
    gmm.fit(transformed_data[:, :7])
    gmms.append(gmm.bic(transformed_data[:, :7]))

# select the minimum bic number of components
n_components = np.array(component_vector)[np.argmin(gmms)]
# predict the cluster indexes
gmm = mix.GaussianMixture(n_components=n_components, covariance_type='diag', n_init=50)
cluster_idx = gmm.fit_predict(transformed_data[:, :7])

In [53]:
# discard singletons
# turn cluster_idx in a float
cluster_idx = cluster_idx.astype(float)
# get the IDs
clu_unique = np.unique(cluster_idx)
for clu in clu_unique:
    # get the number of traces in the cluster
    number_traces = sum(cluster_idx==clu)
    # if it's less than 5, eliminate the cluster
    if number_traces < 5:
        cluster_idx[cluster_idx==clu] = np.nan
    
# plot the BIC
hv.Curve((component_vector, gmms))

In [54]:
# plot the clusters
# add the cluster indexes to the dataframe
cluster_data = np.array([np.mean(target_data[cluster_idx == el, :], axis=0) for el in np.arange(n_components)])
cluster_std = np.array([np.std(target_data[cluster_idx == el, :], axis=0)/np.sqrt(np.sum(cluster_idx == el))
                        for el in np.arange(n_components)])
# plot the results
# fp.plot_2d([cluster_data])
cluster_plot = hv.Overlay([hv.Curve(el, label=str(idx), kdims=['Time (s)'], 
                                    vdims=[target_parameter.replace('_', ' ')+' (px)'])
                           for idx, el in enumerate(cluster_data)] + 
                            [hv.Spread((np.arange(el.shape[0]),el,cluster_std[idx, :])) 
                             for idx, el in enumerate(cluster_data)])

cluster_plot.relabel('Clusters').opts({'Curve': dict(color=hv.Palette('Category20')), 
                                         'Spread': dict(color=hv.Palette('Category20'))})

cluster_plot.opts(opts.Curve(width=fp.pix(10.7), height=fp.pix(5), toolbar=None, 
                        hooks=[fp.margin], fontsize=fp.font_sizes['small'], line_width=12, xticks=3, yticks=3),
            opts.Overlay(legend_position='right', text_font='Arial'))

# assemble the save path
save_path = os.path.join(paths.figures_path,save_name+'_cluster.png')
hv.save(cluster_plot,save_path)

cluster_plot
# print(cluster_plot)

# hv.Curve(cluster_data)

In [7]:
hv.Image(target_data[cluster_idx == 6, :], kdims=['Time','Encounters'])

In [41]:
# plot the clusters as an image

# define the target parameter
# target = 'mouse_cricket_distance'
target = 'vrcricket_0_mouse_distance'

# # load the parameter
# parameter = data[[target,'trial_id']].copy()
# group the single traces
grouped_parameter = data.loc[:, [target] + ['event_id', 'trial_id']].groupby(
    ['trial_id', 'event_id']).agg(list).to_numpy()

# print(grouped_parameter)
# grouped_parameter = np.array([el for el in grouped_parameter[target]])
grouped_parameter = np.array([el for sublist in grouped_parameter for el in sublist if len(el) == 594])


# plot all traces

[sorted_traces,_,_] = fp.sort_traces(grouped_parameter)

image = hv.Image(sorted_traces, ['Time','Trial #'], 
                 [target.replace('_', ' ')], bounds=[0, 0, grouped_parameter.shape[0], grouped_parameter.shape[1]])
image.opts(width=fp.pix(5.8), height=fp.pix(5.8), toolbar=None, 
                    hooks=[fp.margin], fontsize=fp.font_sizes['small'], 
                    xticks=3, yticks=3,
                    colorbar=True, cmap='viridis', 
           colorbar_opts={'major_label_text_align': 'left'})


# assemble the save path
save_path = os.path.join(paths.figures_path,save_name+'_'+target+'.png')
hv.save(image,save_path)


image

In [21]:
# UMAP

# embed the data via UMAP
reducer = umap.UMAP(min_dist=0.5, n_neighbors=10)
embedded_data = reducer.fit_transform(transformed_data)

In [40]:
# Plot the embedding

# use the cluster indexes
umap_data = np.concatenate((embedded_data,np.expand_dims(cluster_idx, axis=1)),axis=1)

# # use the trial ID
# # group the single traces
# grouped_parameter = data.loc[:, ['event_id', 'trial_id']].groupby(
#     ['trial_id']).agg(list)
# temp_parameter = []
# counter = 0
# for idx, el in enumerate(grouped_parameter['event_id']):
#     # get the event ids
#     event_ids = np.unique(el)
#     temp_parameter.append(idx*np.ones(event_ids.shape[0]))

# grouped_parameter = np.concatenate(temp_parameter, axis=0)
# umap_data = np.concatenate((embedded_data,np.expand_dims(grouped_parameter, axis=1)),axis=1)

# highlight the last encounter of every group
# allocate a list for that 
winner_list = []
grouped_parameter = data.loc[:, ['event_id', 'trial_id']].groupby(
    ['trial_id']).agg(list)

# for all the trials
for idx, el in enumerate(grouped_parameter['event_id']):
    # get the event ids
    encounter_list = np.zeros(np.unique(el).shape[0])
    encounter_list[-1] = 1
    winner_list.append(encounter_list)

grouped_parameter = np.concatenate(winner_list, axis=0)


umap_plot = hv.Scatter(umap_data, vdims=['Dim 2','cluster'], kdims=['Dim 1'])
umap_plot.opts(color='cluster', colorbar=True, cmap='Category10', size=20)

umap_plot.opts(opts.Scatter(width=fp.pix(5.7), height=fp.pix(7.8), toolbar=None, 
                        hooks=[fp.margin], fontsize=fp.font_sizes['small'], xticks=3, yticks=3))
#             opts.Overlay(legend_position='right', text_font='Arial'))

# winner_data = embedded_data[grouped_parameter==1]

# winner_plot = hv.Scatter(winner_data, vdims=['Dim 2'], kdims=['Dim 1'])
# winner_plot.opts(width=fp.pix(5.7), height=fp.pix(7.8), toolbar=None, 
#                         hooks=[fp.margin], fontsize=fp.font_sizes['small'], xticks=3, yticks=3, color='black', size=20)
# umap_overlay = umap_plot*winner_plot

# assemble the save path
save_path = os.path.join(paths.figures_path,save_name+'_umap.png')
hv.save(umap_plot,save_path)

umap_plot
