# Visualize comparisons between real and virtual prey capture attempts
This notebook compares purely virtual prey capture (VPrey rig) against a real-cricket baseline recorded in the VR rig in the light; there are no mixed VR + real cricket experiments here.

In [1]:
import os
import sys
# make the directory two levels up importable (where `paths` and the `functions_*` modules
# are expected to live); os.path.join keeps this portable, whereas the raw string r'..\..'
# is only a valid relative path on Windows
sys.path.insert(0, os.path.join('..', '..'))
import paths

import panel as pn
import holoviews as hv
from holoviews import opts, dim
hv.extension('bokeh')
from bokeh.resources import INLINE

import functions_bondjango as bd
import functions_kinematic as fk
import functions_plotting as fp
import functions_misc as fm
import functions_data_handling as fd
import pandas as pd
import numpy as np
import h5py

from scipy.stats import sem
import sklearn.decomposition as decomp
import umap
import sklearn.mixture as mix


line_width = 5

In [2]:
# Define a data loading function

def load_dataset(search_string, label=None, exclusion=None):
    """Load one aggregated analysis file and return it together with a label.

    When the notebook runs under Snakemake the file is ``snakemake.input[0]``; otherwise
    the ``analyzed_data`` table is queried with ``search_string`` and the first matching
    entry is used, optionally skipping entries whose path contains ``exclusion``.

    Parameters
    ----------
    search_string : str
        Database query, e.g. ``'result:succ, lighting:normal, rig:VR, analysis_type:aggEnc'``.
    label : str, optional
        Name for the data set. When omitted it is built from the ``rig``, ``lighting``,
        ``result`` and ``notes`` fields parsed from ``search_string``.
    exclusion : str, optional
        Substring; database entries whose ``analysis_path`` contains it are skipped.

    Returns
    -------
    tuple
        ``(data, label)``, where ``data`` is the output of ``fd.aggregate_loader``.

    Raises
    ------
    ValueError
        If the query yields no usable entry (none at all, or all contain ``exclusion``).
    """
    # the date is only known for database entries, not for Snakemake inputs
    data_date = None
    try:
        data_path = snakemake.input[0]
    except NameError:
        # query the database for data to plot
        data_all = bd.query_database('analyzed_data', search_string)

        # drop entries whose path contains the exclusion substring, then take the first one left
        if exclusion is None:
            candidates = list(data_all)
        else:
            candidates = [ds for ds in data_all if exclusion not in ds['analysis_path']]
        if not candidates:
            detail = '' if exclusion is None else f" without '{exclusion}' in its path"
            raise ValueError(f"No analyzed_data entry matches '{search_string}'{detail}")

        data_path = candidates[0]['analysis_path']
        data_date = candidates[0]['date']
    print(data_path)
    if data_date is not None:
        print(data_date)

    # assemble a label for this data set
    if label is None:
        d = fd.parse_search_string(search_string)
        label = '_'.join([d['rig'], d['lighting'], d['result'], d['notes']])
    print('data label: ' + label + '\n')

    # load the data
    return fd.aggregate_loader(data_path), label

In [3]:
# base name shared by every figure file this notebook saves
save_name = '_'.join(['VPrey', 'crickets_0', 'vrcrickets_1'])

## Encounter analysis
This analysis estimates the number of encounters per trial and characterizes the types of encounters.

### Load the encounter data

In [13]:
### Load new data
# (search string, label, path substring to exclude) for every condition compared below
encounter_datasets = [
    # real prey capture in the light - this is a baseline comparison
    ('result:succ, lighting:normal, rig:VR, analysis_type:aggEnc',
     'VR_light_succ', 'obstacle'),
    # VR prey capture with VR blackCr
    ('result:test, lighting:normal, rig:VPrey, analysis_type:aggEnc, notes:blackCr_rewarded_crickets_0_vrcrickets_1',
     'VPrey_blackCr_lightBG', None),
    # VR prey capture with VR whiteCr_blackBG
    ('result:test, lighting:normal, rig:VPrey, analysis_type:aggEnc, notes:whiteCr_blackBG_rewarded_crickets_0_vrcrickets_1',
     'VPrey_whiteCr_blackBG', None),
    # VR prey capture with VR blackCr_grayBG
    ('result:test, lighting:normal, rig:VPrey, analysis_type:aggEnc, notes:blackCr_grayBG_rewarded_crickets_0_vrcrickets_1',
     'VPrey_blackCr_grayBG', None),
    # VR prey capture with VR whiteCr_grayBG
    ('result:test, lighting:normal, rig:VPrey, analysis_type:aggEnc, notes:whiteCr_grayBG_rewarded_crickets_0_vrcrickets_1',
     'VPrey_whiteCr_grayBG', None),
]

# container holding one loaded data set per label, in the order listed above
data_dict = {}
for search_string, dataset_label, path_exclusion in encounter_datasets:
    ds, label = load_dataset(search_string, label=dataset_label, exclusion=path_exclusion)
    data_dict[label] = ds

# Get rid of doubled data set
del ds

J:\Drago Guggiana Nilo\Prey_capture\AnalyzedData\preprocessing_succ_VR_normal_ALL_ALL_ALL_ALL_2020-06-23T00-00-00_ALL_aggEnc.hdf5
2020-12-01T10:27:14.316934Z
data label: VR_light_succ

J:\Drago Guggiana Nilo\Prey_capture\AnalyzedData\preprocessing_test_VPrey_normal_ALL_blackCr_rewarded_crickets_0_vrcrickets_1_ALL_ALL_2020-06-07T00-00-00_ALL_aggEnc.hdf5
2020-11-30T15:09:40.993967Z
data label: VPrey_blackCr_lightBG

J:\Drago Guggiana Nilo\Prey_capture\AnalyzedData\preprocessing_test_VPrey_normal_ALL_whiteCr_blackBG_rewarded_crickets_0_vrcrickets_1_ALL_ALL_2020-06-07T00-00-00_ALL_aggEnc.hdf5
2020-11-30T15:10:17.262546Z
data label: VPrey_whiteCr_blackBG

J:\Drago Guggiana Nilo\Prey_capture\AnalyzedData\preprocessing_test_VPrey_normal_ALL_blackCr_grayBG_rewarded_crickets_0_vrcrickets_1_ALL_ALL_2020-06-07T00-00-00_ALL_aggEnc.hdf5
2020-11-30T15:10:06.514888Z
data label: VPrey_blackCr_grayBG

J:\Drago Guggiana Nilo\Prey_capture\AnalyzedData\preprocessing_test_VPrey_normal_ALL_whiteCr_grayBG_re

### Encounter averages

In [44]:
plot_dict = {}

# encounter variables for real cricket
encounter_angle_variables = ['mouse_heading', 'cricket_0_heading', 'cricket_0_delta_heading']
encounter_nonangle_variables = ['cricket_0_mouse_distance', 'mouse_speed', 'mouse_acceleration', 'cricket_0_speed', 'cricket_0_acceleration']

# encounter variables for VR cricket(s)
encounter_angle_variables_VR = ['mouse_heading', 'vrcricket_0_heading', 'vrcricket_0_delta_heading']
encounter_nonangle_variables_VR = ['vrcricket_0_mouse_distance', 'mouse_speed', 'mouse_acceleration', 'vrcricket_0_speed', 'vrcricket_0_acceleration']

for name in data_dict:
    print(name)
    data = data_dict[name]
    plot_container = {}

    # Use the real-cricket columns when they are all present; VPrey data sets fall back to the
    # VR-cricket columns. Any other data set with missing columns is an error: continuing would
    # silently plot the previous data set's parameters under this name.
    real_columns = encounter_angle_variables + encounter_nonangle_variables
    if set(real_columns).issubset(data.columns):
        enc_ang_var = encounter_angle_variables
        enc_nang_var = encounter_nonangle_variables
    elif 'VPrey' in name:
        enc_ang_var = encounter_angle_variables_VR
        enc_nang_var = encounter_nonangle_variables_VR
    else:
        missing = sorted(set(real_columns) - set(data.columns))
        raise KeyError(f'{name} is missing encounter columns: {missing}')

    # explicit copies, so adding the event_index column does not act on a slice of `data`
    angled_params = data[enc_ang_var].copy()
    angled_params['event_index'] = data.index

    nonangled_params = data[enc_nang_var].copy()
    nonangled_params['event_index'] = data.index

    # angle columns: fk circular mean / circular std divided by sqrt(n); other columns: mean / SEM.
    # Both are grouped by the row index of `data`.
    angled_average = fk.wrap(angled_params.groupby('event_index').agg(lambda x: 180 + fk.circmean_deg(x)))
    angled_std = pd.DataFrame(fk.unwrap(angled_params.groupby('event_index').agg(lambda x: fk.circstd_deg(x)/np.sqrt(x.shape[0]))), columns=enc_ang_var)

    nonangled_average = nonangled_params.groupby('event_index').mean()
    nonangled_std = pd.DataFrame(nonangled_params.groupby('event_index').sem(), columns=enc_nang_var)

    encounter_average = pd.concat((angled_average, nonangled_average), axis=1)
    encounter_sem = pd.concat((angled_std, nonangled_std), axis=1)

    # plot the results
    # define the variables to plot from
    encounter_variables = enc_ang_var + enc_nang_var
    # get the trials
    trial_list = data['trial_id'].unique()
    # get the time vector
    time_vector = data.loc[(data['event_id'] == 0) & (data['trial_id'] == trial_list[0]), 'time_vector'].to_numpy()

    # one mean curve with an error band per variable
    for variable in encounter_variables:
        x = np.arange(encounter_average[variable].size)
        y = encounter_average[variable].to_numpy()
        yerr = encounter_sem[variable].to_numpy()

        plot_container[variable] = hv.Spread((list(x), list(y), list(yerr)), label=name).opts(title=variable) * hv.Curve((list(x), list(y))).opts(color='black')

    plot_dict[('variables', name)] = hv.GridSpace(plot_container, kdims=['variable'])

encounters = hv.GridSpace(plot_dict, kdims=['variables', 'dataset']).opts(plot_size=300)
full_panel = pn.panel(encounters, center=True, widget_location='top')
full_panel


VR_light_succ
VPrey_blackCr_lightBG
VPrey_whiteCr_blackBG
VPrey_blackCr_grayBG
VPrey_whiteCr_grayBG


### Number of encounters per trial

In [6]:
# Plot the number of encounters per trial for each condition

# allocate a list for the plots
plot_list = []
# mean number of encounters per data set
means = []
# (data set name, mean, SEM) tuples, in the format hv.ErrorBars accepts
enc_sem = []

# Plots by trial type
for name in data_dict:
    data = data_dict[name]

    # number of encounters per trial = highest event_id + 1 (assumes event ids count from 0
    # within a trial). max() does not depend on the rows being ordered by event_id, unlike
    # taking the last event_id of each trial.
    encounters = (data.groupby('trial_id')['event_id'].max() + 1).to_numpy()
    means.append(encounters.mean())
    enc_sem.append((name, encounters.mean(), sem(encounters)))

    # plot the results: one bar per trial plus a red line at the mean
    enc_plot = hv.Bars((np.arange(encounters.shape[0]), encounters)).opts(title=name, xlabel='trial', ylabel='# encounters') * \
        hv.HLine(encounters.mean()).opts(color='red', line_width=1)
    plot_list.append(enc_plot)

encounters_panel = hv.Layout(plot_list).opts(shared_axes=True)
save_path = os.path.join(paths.figures_path, '_'.join([save_name, 'encounters']))
# hv.save(encounters_panel, save_path, fmt='png')

# display the image
encounters_panel

In [14]:
# Summary figure: mean number of encounters per data set, with SEM error bars
condition_names = list(data_dict.keys())
mean_bars = hv.Bars((condition_names, means)).opts(
    title='Mean # Encounters',
    ylabel='# encounters',
    ylim=(0, 8),
    xrotation=45,
)
# error bars come from the (name, mean, sem) tuples collected for each data set
enc_means = hv.ErrorBars(enc_sem) * mean_bars

save_path = os.path.join(paths.figures_path, '_'.join([save_name, 'encounter_means']))
hv.save(enc_means, save_path, fmt='png')

# display the image
enc_means

### PCA of encounter types

In [37]:
# define the target parameter and PCA
target_parameter = 'cricket_0_mouse_distance'
# HACK Find a way to fix: only traces with exactly this many samples are kept, so they stack into a 2D array
trace_length = 594

# container for plots
plot_list = []

# container for PCA fit
pca_transforms = []

# container for the cleaned target data (one 2D array of traces per data set)
target = []

for name in data_dict:
    data = data_dict[name]

    # VR-cricket data sets store the quantity under a 'vr'-prefixed column. Resolve the column per
    # data set instead of rewriting target_parameter, which would leak into later data sets and cells.
    if target_parameter in data.columns:
        param_name = target_parameter
    else:
        param_name = 'vr' + target_parameter

    # assemble the array with the parameters of choice: one list of samples per (trial, encounter)
    target_data = data[[param_name, 'event_id', 'trial_id']].groupby(['trial_id', 'event_id']).agg(list).to_numpy()
    target_data = np.array([el for sublist in target_data for el in sublist if len(el) == trace_length], dtype=float)

    # Remove nans and infs. By default nan_to_num maps +/-inf to the largest finite float, which would
    # dominate the PCA, so ask for zeros explicitly and keep the returned array.
    target_data = np.nan_to_num(target_data, nan=0.0, posinf=0.0, neginf=0.0)
    target.append(target_data)

    # PCA the data before clustering
    pca = decomp.PCA()
    transformed_data = pca.fit_transform(target_data)
    pca_transforms.append(transformed_data)

    exp_var = hv.Curve(pca.explained_variance_ratio_).opts(xlabel='PCs', ylabel='explained variance', title=name)
    plot_list.append(exp_var)

hv.Layout(plot_list).cols(len(data_dict.keys()))

### Gaussian mixture model of clusters

In [38]:
# Cluster the data using GMMs
plot_list = []
clusters = []

# candidate numbers of mixture components (the one with the lowest BIC is kept)
component_vector = [2, 3, 4, 5, 10, 15]
# number of leading principal components used for clustering
n_pcs = 7
# clusters with fewer traces than this are discarded (their traces are labelled NaN)
min_cluster_size = 5
# fixed seed so the clustering is reproducible between runs
gmm_seed = 0

for transformed_data, name in zip(pca_transforms, data_dict):

    # cluster on the first n_pcs principal components
    pcs = transformed_data[:, :n_pcs]

    # fit one GMM per candidate component number and score each with the Bayesian information criterion
    fitted_gmms = []
    gmms = []  # BIC of each fitted model, in component_vector order
    for comp in component_vector:
        gmm = mix.GaussianMixture(n_components=comp, covariance_type='diag', n_init=50, random_state=gmm_seed)
        gmm.fit(pcs)
        fitted_gmms.append(gmm)
        gmms.append(gmm.bic(pcs))

    # select the minimum-BIC model and label the traces with that same fitted model, so the
    # clusters used are the ones that were scored (a fresh refit can converge to a different solution)
    best_model = int(np.argmin(gmms))
    n_components = component_vector[best_model]
    cluster_idx = fitted_gmms[best_model].predict(pcs).astype(float)

    # discard small clusters (float labels so they can be marked NaN)
    for clu in np.unique(cluster_idx):
        if np.sum(cluster_idx == clu) < min_cluster_size:
            cluster_idx[cluster_idx == clu] = np.nan
    clusters.append(cluster_idx)

    # plot the BIC against the number of components
    BIC = hv.Curve((component_vector, gmms)).opts(title=name, xlabel='# components', ylabel='BIC')
    plot_list.append(BIC)

hv.Layout(plot_list).cols(len(data_dict.keys()))


In [39]:
# plot the clusters
plot_list = []

for target_data, cluster_idx, name in zip(target, clusters, data_dict):

    # cluster ids present in this data set (NaN marks traces from discarded small clusters).
    # Derived per data set: the global n_components only holds the last data set's value.
    cluster_ids = np.unique(cluster_idx[~np.isnan(cluster_idx)])

    # mean trace and SEM for each cluster
    cluster_data = np.array([np.mean(target_data[cluster_idx == el, :], axis=0) for el in cluster_ids])
    cluster_std = np.array([np.std(target_data[cluster_idx == el, :], axis=0) / np.sqrt(np.sum(cluster_idx == el))
                            for el in cluster_ids])

    # plot the results: one mean curve and one SEM band per cluster, labelled by cluster id
    y_label = target_parameter.replace('_', ' ') + ' (px)'
    curves = [hv.Curve(mean_trace, label=str(int(clu)), kdims=['Time (s)'], vdims=[y_label])
              for clu, mean_trace in zip(cluster_ids, cluster_data)]
    spreads = [hv.Spread((np.arange(mean_trace.shape[0]), mean_trace, err_trace))
               for mean_trace, err_trace in zip(cluster_data, cluster_std)]

    # relabel returns a new object, so keep its result before styling it
    cluster_plot = hv.Overlay(curves + spreads).relabel('Clusters')
    cluster_plot.opts({'Curve': dict(color=hv.Palette('Category20')),
                       'Spread': dict(color=hv.Palette('Category20'))})

    cluster_plot.opts(title=name)

    # For publication-ready image
    cluster_plot.opts(
        opts.Curve(
                    width=fp.pix(10.7), height=fp.pix(5),
                    toolbar=None, hooks=[fp.margin],
                    fontsize=fp.font_sizes['small'],
                    line_width=12, xticks=3, yticks=3
                    ),
        opts.Overlay(legend_position='right', text_font='Arial')
        )

    plot_list.append(cluster_plot)

cluster_trace_panel = hv.Layout(plot_list).opts(shared_axes=False).cols(1)
# assemble the save path
save_path = os.path.join(paths.figures_path, '_'.join([save_name, target_parameter, 'cluster']))
hv.save(cluster_trace_panel, save_path, fmt='png')

# display the image
cluster_trace_panel

In [15]:
# plot the clusters as an image
plot_list = []
target_parameter = 'cricket_0_mouse_distance'
# target_parameter = 'mouse_speed'
# HACK: only traces with exactly this many samples are kept, so they stack into a 2D array
trace_length = 594

keys = list(data_dict.keys())

for name in data_dict:
    data = data_dict[name]

    # VR-cricket data sets use a 'vr'-prefixed column; resolve it per data set rather than rewriting
    # target_parameter (which would carry over to later data sets and into the save path)
    if target_parameter in data.columns:
        param_name = target_parameter
    else:
        param_name = 'vr' + target_parameter

    # group the single traces: one list of samples per (trial, encounter)
    target_data = data[[param_name, 'event_id', 'trial_id']].groupby(['trial_id', 'event_id']).agg(list).to_numpy()
    target_data = np.array([el for sublist in target_data for el in sublist if len(el) == trace_length], dtype=float)

    # Remove nans and infs (by default nan_to_num would map +/-inf to the largest finite float)
    target_data = np.nan_to_num(target_data, nan=0.0, posinf=0.0, neginf=0.0)

    # Clip values above the 99th percentile to get rid of massive outliers
    upper = np.percentile(target_data, 99)
    target_data[target_data > upper] = upper

    # normalize by the maximum (skipped for an all-zero array to avoid dividing by zero)
    max_value = np.max(target_data)
    if max_value != 0:
        target_data /= max_value

    # plot all traces
    [sorted_traces, _, _] = fp.sort_traces(target_data)

    image = hv.Image(sorted_traces, ['Time', 'Encounter'],
                     [param_name.replace('_', ' ')],
                     bounds=[0, 0, target_data.shape[1], target_data.shape[0]]
                     ).opts(title=name)

    image.opts(
            toolbar=None,
            hooks=[fp.margin],
            colorbar=True, cmap='viridis',
            colorbar_opts={'major_label_text_align': 'left'}
            )

    # dashed red marker at x = 300 samples
    vline = hv.VLine(300).opts(color='red', line_width=1, line_dash='dashed')
    image = (image * vline)

    plot_list.append(image)


sorted_cluster_heatmap_panel = hv.Layout(plot_list).opts(shared_axes=False).cols(len(data_dict.keys()))

# assemble the save path
save_path = os.path.join(paths.figures_path, '_'.join([save_name, target_parameter, 'heatmap']))
hv.save(sorted_cluster_heatmap_panel, save_path, fmt='png')

# display the image
sorted_cluster_heatmap_panel

### UMAP Embedding

In [44]:
# UMAP embedding of each data set's PCA-transformed traces, coloured by GMM cluster
plot_list = []

for transformed_data, cluster_idx, name in zip(pca_transforms, clusters, data_dict):

    # embed the data via UMAP
    reducer = umap.UMAP(min_dist=0.25, n_neighbors=15)
    embedded_data = reducer.fit_transform(transformed_data)

    # columns: Dim 1, Dim 2, cluster index (NaN for traces in discarded clusters)
    umap_data = np.concatenate((embedded_data, np.expand_dims(cluster_idx, axis=1)), axis=1)

    # NOTE: to highlight each trial's last encounter, build the flags from data_dict[name] with the
    # same trace-length filter used for the PCA input, so they line up with the rows of embedded_data

    umap_plot = hv.Scatter(umap_data, vdims=['Dim 2', 'cluster'], kdims=['Dim 1'])

    # For publication-ready image
    umap_plot.opts(color='cluster', colorbar=True, cmap='Category10', size=20, title=name)
    umap_plot.opts(
        opts.Scatter(
            width=fp.pix(5.7),
            height=fp.pix(7.8),
            toolbar=None,
            hooks=[fp.margin],
            fontsize=fp.font_sizes['small'],
            xticks=3,
            yticks=3
            )
        )

    plot_list.append(umap_plot)

umap_panel = hv.Layout(plot_list).opts(shared_axes=True).cols(len(data_dict))

# assemble the save path
save_path = os.path.join(paths.figures_path, '_'.join([save_name, 'umap']))
hv.save(umap_panel, save_path, fmt='png')

# display the image
umap_panel

## Binned time analysis

In [4]:
### Load new data
# (search string, label, path substring to exclude) for every binned-time data set;
# the commented-out entries are conditions that are currently left out
binned_datasets = [
    # real prey capture in the light - this is a baseline comparison
    ('result:succ, lighting:normal, rig:VR, analysis_type:aggBin',
     'VR_light_succ', 'obstacle'),
    # VR prey capture with VR blackCr
    ('result:test, lighting:normal, rig:VPrey, analysis_type:aggBin, notes:blackCr_rewarded_crickets_0_vrcrickets_1',
     'VPrey_blackCr_lightBG', None),
    # # VR prey capture with VR whiteCr_blackBG
    # ('result:test, lighting:normal, rig:VPrey, analysis_type:aggBin, notes:whiteCr_blackBG_rewarded_crickets_0_vrcrickets_1',
    #  'VPrey_whiteCr_blackBG', None),
    # # VR prey capture with VR blackCr_grayBG
    # ('result:test, lighting:normal, rig:VPrey, analysis_type:aggBin, notes:blackCr_grayBG_rewarded_crickets_0_vrcrickets_1',
    #  'VPrey_blackCr_grayBG', None),
    # # VR prey capture with VR whiteCr_grayBG
    # ('result:test, lighting:normal, rig:VPrey, analysis_type:aggBin, notes:whiteCr_grayBG_rewarded_crickets_0_vrcrickets_1',
    #  'VPrey_whiteCr_grayBG', None),
]

# container holding one loaded data set per label, in the order listed above
data_dict = {}
for search_string, dataset_label, path_exclusion in binned_datasets:
    ds, label = load_dataset(search_string, label=dataset_label, exclusion=path_exclusion)
    data_dict[label] = ds

# Get rid of doubled data set
del ds

J:\Drago Guggiana Nilo\Prey_capture\AnalyzedData\preprocessing_succ_VR_normal_ALL_ALL_ALL_ALL_2020-06-23T00-00-00_ALL_aggBin.hdf5
2020-12-01T10:27:00.303800Z
data label: VR_light_succ

J:\Drago Guggiana Nilo\Prey_capture\AnalyzedData\preprocessing_test_VPrey_normal_ALL_blackCr_rewarded_crickets_0_vrcrickets_1_ALL_ALL_ALL_ALL_aggBin.hdf5
2020-12-09T13:16:24.655912Z
data label: VPrey_blackCr_lightBG



### Heatmaps of binned kinematics per trial

In [17]:
# Plot example encounter traces sorted

# define the target parameter
# target_parameters = ['mouse_speed', 'vrcricket_0_mouse_distance', 'vrcricket_0_speed', 'cricket_0_mouse_distance', 'cricket_0_speed']
target_parameters = ['mouse_speed', 'cricket_0_mouse_distance', 'cricket_0_speed']


# allocate a list for the plots
plot_list = []

keys = list(data_dict.keys())
for name in data_dict:

    data = data_dict[name]
    # trial ordering from the first parameter's clustering; reused for the other parameters
    cluster_idx = None

    # for all the parameters
    for target_param in target_parameters:

        # load the parameter (VR-cricket data sets use 'vr'-prefixed column names)
        try:
            parameter = data[[target_param, 'trial_id']].copy()
        except KeyError:
            target_param = 'vr' + target_param
            parameter = data[[target_param, 'trial_id']].copy()

        # group the single traces: one row of time bins per trial
        grouped_parameter = parameter.groupby(['trial_id']).agg(list)
        grouped_parameter = np.array([el for el in grouped_parameter[target_param]])

        # Remove nans and infs (by default nan_to_num would map +/-inf to the largest finite float)
        grouped_parameter = np.nan_to_num(grouped_parameter, nan=0.0, posinf=0.0, neginf=0.0)

        # Clip values above the 99th percentile to get rid of massive outliers
        upper = np.percentile(grouped_parameter, 99)
        grouped_parameter[grouped_parameter > upper] = upper

        # get the clustering for first parameter, and preserve that sorting for all other target parameters tested
        # NOTE: this rebinds the global `clusters` (the GMM cluster list of the encounter analysis)
        if cluster_idx is None:
            [sorted_traces, cluster_idx, clusters] = fp.sort_traces(grouped_parameter, nclusters=min((10, len(grouped_parameter))))
        else:
            sorted_traces = grouped_parameter[cluster_idx, :]

        # plot all traces
        hmap = hv.Image(sorted_traces, ['Binned Time', 'Trial #'],
                        [target_param.replace('_', ' ')],
                        bounds=[0, 0, grouped_parameter.shape[1], grouped_parameter.shape[0]],
                        group=name,
                        label=target_param)

        hmap.opts(
                width=fp.pix(1.5),
                height=fp.pix(1.5),
                toolbar=None,
                xticks=3, yticks=3,
                colorbar=True, cmap='viridis',
                colorbar_opts={'major_label_text_align': 'left'}
                )

        plot_list.append(hmap)

heatmaps = hv.Layout(plot_list).cols(len(target_parameters)).opts(shared_axes=False)
save_path = os.path.join(paths.figures_path, '_'.join([save_name, 'binned_kinematics']))
hv.save(heatmaps, save_path, fmt='png')
heatmaps

In [5]:
from scipy.stats import lognorm

In [30]:
# Plot log-x histograms of kinematic parameters, one panel per (data set, parameter).
# Each histogram pools every sample of every trial in a data set and is normalized to sum to 1.
# (A per-cluster overlay was tried here and dropped on 23.09.2020 because it made the display bad.)

# define the target parameters; data sets lacking the plain column use the 'vr'-prefixed one
target_parameters = ['mouse_speed', 'cricket_0_speed', 'cricket_0_mouse_distance']

# Log-spaced bin edges from 10**lower to 10**upper (cm or cm/s after the unit conversion below),
# shared by all panels so histograms of the same variable are directly comparable.
# Values outside this range, including the zeros that replace NaN/inf, are not counted.
lower = -2
upper = 2.5
bin_edges = np.logspace(lower, upper, 50)


def _load_trial_matrix(data, param):
    """Return ``(column_name, traces)`` for ``param`` from an aggregated data frame.

    ``traces`` is a (trials x samples) float array, one row per ``trial_id`` (sorted),
    converted from m or m/s to cm or cm/s. NaN and +/-inf are replaced with 0.
    If ``param`` is not a column of ``data``, the ``'vr' + param`` column is used instead.
    Raises KeyError if neither column (or 'trial_id') exists.
    """
    column = param if param in data.columns else 'vr' + param
    per_trial = data[[column, 'trial_id']].groupby(['trial_id'])[column].agg(list)
    # NOTE(review): assumes every trial has the same number of samples (binned data)
    traces = np.array(per_trial.tolist(), dtype=float)
    # Give the replacement values explicitly: by default nan_to_num maps inf to the
    # largest finite float (which then overflows back to inf after scaling), not to 0.
    traces = np.nan_to_num(traces, nan=0.0, posinf=0.0, neginf=0.0)
    # all target parameters are stored in m/s or m; convert to cm/s or cm
    return column, traces * 100


def _fill_style(data_label):
    """Return Histogram fill options chosen from the cricket color in ``data_label``.

    'whiteCr' is checked first because labels such as 'VPrey_whiteCr_blackBG' also
    contain 'black' (from the background token). Returns {} when no color matches.
    """
    if 'whiteCr' in data_label:
        return {'fill_color': 'white', 'line_color': 'black'}
    if 'black' in data_label:
        return {'fill_color': 'gray'}
    if 'white' in data_label:
        return {'fill_color': 'white', 'line_color': 'black'}
    return {}


# allocate a list for the plots
plot_list = []

for name, data in data_dict.items():
    for target_param in target_parameters:
        column, traces = _load_trial_matrix(data, target_param)

        # fraction of in-range samples per bin (all zeros if nothing falls in range)
        freq, edges = np.histogram(traces, bin_edges)
        total = freq.sum()
        freq = freq / total if total > 0 else freq.astype(float)

        hist = hv.Histogram((edges, freq), group=': '.join((name, column)))

        # x axis label in the converted units
        if 'speed' in column:
            hist.opts(xlabel='cm/s')
        elif 'distance' in column:
            hist.opts(xlabel='cm')

        # color depends on the stimulus of the data set
        fill = _fill_style(name)
        if fill:
            hist.opts(**fill)

        # For publication-ready image. Decade ticks 10^-2 ... 10^2 lie inside the bin range
        # (10e-2 would be 0.1, not 0.01).
        hist.opts(
            opts.Histogram(
                logx=True,
                ylim=(0, 0.21),
                yticks=[0, 0.05, 0.1, 0.15, 0.2],
                width=fp.pix(7.8),
                height=fp.pix(7.8),
                toolbar=None,
                hooks=[fp.margin],
                fontsize=fp.font_sizes['small'],
                xticks=[1e-2, 1e-1, 1e0, 1e1, 1e2],
                padding=0.01,
            )
        )

        plot_list.append(hist)

param_hists = hv.Layout(plot_list).opts(shared_axes=False, toolbar=None).cols(len(target_parameters))
save_path = os.path.join(paths.figures_path, '_'.join([save_name, 'histogram_kinematics']))
hv.save(param_hists, save_path, fmt='png')
param_hists

### Full aggregate analysis

In [16]:
### Load new data
# Data sets to compare, as (label, database search string, exclusion).
# `exclusion` makes load_dataset skip database entries whose analysis path contains
# that substring (None = take the first match).
# The first entry (rig:VR) is real prey capture in the light and is the baseline comparison;
# all other entries come from the VPrey rig, named after the stimulus in their notes tag.
dataset_queries = [
    ('VR_light_succ',
     'result:succ, lighting:normal, rig:VR, analysis_type:aggFull',
     'obstacle'),
    ('VPrey_blackCr_lightBG',
     'result:test, lighting:normal, rig:VPrey, analysis_type:aggFull, notes:blackCr_rewarded_crickets_0_vrcrickets_1',
     None),
    ('VPrey_blackCr_lightBG_multi',
     'result:succ, lighting:normal, rig:VPrey, analysis_type:aggFull, notes:blackCr_crickets_1',
     None),
    ('VPrey_whiteCr_blackBG',
     'result:test, lighting:normal, rig:VPrey, analysis_type:aggFull, notes:whiteCr_blackBG_rewarded_crickets_0_vrcrickets_1',
     None),
    ('VPrey_whiteCr_blackBG_multi',
     'result:succ, lighting:normal, rig:VPrey, analysis_type:aggFull, notes:whiteCr_blackBG_crickets_1',
     None),
    ('VPrey_blackCr_grayBG',
     'result:test, lighting:normal, rig:VPrey, analysis_type:aggFull, notes:blackCr_grayBG_rewarded_crickets_0_vrcrickets_1',
     None),
    ('VPrey_whiteCr_grayBG',
     'result:test, lighting:normal, rig:VPrey, analysis_type:aggFull, notes:whiteCr_grayBG_rewarded_crickets_0_vrcrickets_1',
     None),
    ('VPrey_original',
     'result:test, lighting:normal, rig:VPrey, analysis_type:aggFull, gt_date:2020-06-23T00-00-00, lt_date:2020-07-06T00-00-00, notes:rewarded_crickets_0_vrcrickets_1',
     None),
    ('VPrey_original_multi',
     'result:succ, lighting:normal, rig:VPrey, analysis_type:aggFull, gt_date:2020-06-23T00-00-00, lt_date:2020-07-06T00-00-00, notes:crickets_1_vrcrickets_1',
     None),
]

# create container for holding multiple data sets, keyed by label
data_dict = {}
for query_label, search_string, exclusion in dataset_queries:
    ds, label = load_dataset(search_string, label=query_label, exclusion=exclusion)
    data_dict[label] = ds

# Get rid of the extra reference to the last loaded data set (data_dict still holds it)
del ds

J:\Drago Guggiana Nilo\Prey_capture\AnalyzedData\preprocessing_succ_VR_normal_ALL_ALL_ALL_ALL_2020-06-23T00-00-00_ALL_aggFull.hdf5
2020-12-01T10:27:07.480542Z
data label: VR_light_succ

J:\Drago Guggiana Nilo\Prey_capture\AnalyzedData\preprocessing_test_VPrey_normal_ALL_blackCr_rewarded_crickets_0_vrcrickets_1_ALL_ALL_ALL_ALL_aggFull.hdf5
2020-12-09T13:19:25.121906Z
data label: VPrey_blackCr_lightBG

J:\Drago Guggiana Nilo\Prey_capture\AnalyzedData\preprocessing_succ_VPrey_normal_ALL_blackCr_crickets_1_ALL_ALL_ALL_ALL_aggFull.hdf5
2020-12-09T13:22:13.595409Z
data label: VPrey_blackCr_lightBG_multi

J:\Drago Guggiana Nilo\Prey_capture\AnalyzedData\preprocessing_test_VPrey_normal_ALL_whiteCr_blackBG_rewarded_crickets_0_vrcrickets_1_ALL_ALL_ALL_ALL_aggFull.hdf5
2020-12-09T13:20:04.106311Z
data label: VPrey_whiteCr_blackBG

J:\Drago Guggiana Nilo\Prey_capture\AnalyzedData\preprocessing_succ_VPrey_normal_ALL_whiteCr_blackBG_crickets_1_ALL_ALL_ALL_ALL_aggFull.hdf5
2020-12-09T13:22:28.034211Z

In [17]:
# For every loaded data set, print its label and then the maximum of its trial_id column
for name, data in data_dict.items():
    max_trial_id = data[['trial_id']].max()
    print(name, max_trial_id, sep='\n')

VR_light_succ
trial_id    54
dtype: int32
VPrey_blackCr_lightBG
trial_id    105
dtype: int32
VPrey_blackCr_lightBG_multi
trial_id    77
dtype: int32
VPrey_whiteCr_blackBG
trial_id    100
dtype: int32
VPrey_whiteCr_blackBG_multi
trial_id    58
dtype: int32
VPrey_blackCr_grayBG
trial_id    28
dtype: int32
VPrey_whiteCr_grayBG
trial_id    31
dtype: int32
VPrey_original
trial_id    31
dtype: int32
VPrey_original_multi
trial_id    27
dtype: int32


In [24]:
# Plot per-trial durations (last time_vector value of each trial) as bars, one panel
# per data set, and collect the mean and SEM of the durations for the summary cell below

plot_list = []
means = []
dur_sem = []

for name, data in data_dict.items():
    # time stamps of each trial, one list per trial_id (sorted by trial_id)
    trial_times = data[['time_vector', 'trial_id']].groupby(['trial_id']).agg(list)['time_vector']
    duration = np.array([samples[-1] for samples in trial_times])

    mean_duration = duration.mean()
    means.append(mean_duration)
    dur_sem.append((name, mean_duration, sem(duration)))

    plot_list.append(
        hv.Bars(duration).opts(title=name, xlabel='trial', ylabel='duration')
    )

hv.Layout(plot_list)

In [25]:
# Summary figure: mean trial duration per data set with SEM error bars, saved as PNG

duration_bars = hv.Bars((list(data_dict), means)).opts(
    title='Mean Trial Duration',
    ylabel='Duration (s)',
    ylim=(0, 100),
    xrotation=45,
)
enc_means = hv.ErrorBars(dur_sem) * duration_bars
save_path = os.path.join(paths.figures_path, f'{save_name}_duration_means')
hv.save(enc_means, save_path, fmt='png')

# display the image
enc_means