# Visualize differences between prey capture in box and in VR arena

In [1]:
import os
import sys
sys.path.insert(0, r'..\..')
import paths

import panel as pn
import holoviews as hv
from holoviews import opts, dim
hv.extension('bokeh')
from bokeh.resources import INLINE

import functions_bondjango as bd
import functions_kinematic as fk
import functions_plotting as fp
import functions_misc as fm
import functions_data_handling as fd
import pandas as pd
import numpy as np
import h5py

from scipy.stats import sem
import sklearn.decomposition as decomp
import umap
import sklearn.mixture as mix
from scipy.stats import sem


# line width constant for plots
# NOTE(review): not referenced by the visible cells (they pass line_width=1 / line_width=12 directly)
line_width = 5

In [2]:
# define the name to be used for the saved figures
# (joined with a suffix to build every file name passed to hv.save)
save_name = 'VPrey_VRArena_box'

In [8]:
# Define a data loading function

def load_dataset(search_string, label=None, exclusion=None):
    """Find one aggregated data file, load it and return it with a label.

    When executed by snakemake, the first snakemake input file is loaded.
    Otherwise the ``analyzed_data`` table is queried with ``search_string``
    and the first returned entry is used; if ``exclusion`` is given, entries
    whose ``analysis_path`` contains it are skipped.

    Parameters
    ----------
    search_string : str
        Database query, e.g. ``'result:succ, lighting:dark, rig:VR, analysis_type:aggEnc'``.
    label : str, optional
        Name for the data set. If omitted it is built from the ``rig``,
        ``lighting``, ``result`` and ``notes`` fields of ``search_string``.
    exclusion : str, optional
        Substring of ``analysis_path`` that disqualifies a database entry.

    Returns
    -------
    tuple
        ``(data, label)``, where ``data`` is the output of ``fd.aggregate_loader``.

    Raises
    ------
    ValueError
        If the database query yields no usable entry.
    """
    # the date is only available when the path comes from the database
    data_date = None
    try:
        data_path = snakemake.input[0]
    except NameError:
        # not running under snakemake: query the database for data to plot
        data_all = bd.query_database('analyzed_data', search_string)

        # first entry not matching the exclusion string (simply the first entry if there is none)
        selected = next((ds for ds in data_all
                         if exclusion is None or exclusion not in ds['analysis_path']),
                        None)
        if selected is None:
            # fail with a clear message instead of an unbound-variable error further down
            raise ValueError(f'no analyzed data found for "{search_string}"'
                             f' (exclusion: {exclusion})') from None
        data_path = selected['analysis_path']
        data_date = selected['date']

    print(data_path)
    if data_date is not None:
        print(data_date)

    # assemble a label for this data set
    if label is None:
        d = fd.parse_search_string(search_string)
        label = '_'.join([d['rig'], d['lighting'], d['result'], d['notes']])
    print('data label: ' + label + '\n')

    # load the data
    return fd.aggregate_loader(data_path), label

## Encounter analysis
This section of the analysis relies on the aggregated encounters. Load this data first.

In [9]:
# create container for holding multiple data sets
data_dict = {}

# (search string, label, path substring to exclude) for every condition:
#   - real prey capture in the VR arena in the light (obstacle data sets excluded)
#   - real prey capture in the VR arena in the dark, successes and failures
#   - real prey capture in the small box
for search_string, label, exclusion in (
        ('result:succ, lighting:normal, rig:VR, analysis_type:aggEnc', "VR_light_succ", 'obstacle'),
        ('result:succ, lighting:dark, rig:VR, analysis_type:aggEnc', "VR_dark_succ", None),
        ('result:fail, lighting:dark, rig:VR, analysis_type:aggEnc', "VR_dark_fail", None),
        ('result:succ, lighting:normal, rig:miniscope, analysis_type:aggEnc', "Box_light_succ", None),
):
    ds, label = load_dataset(search_string, label=label, exclusion=exclusion)
    data_dict[label] = ds

# drop the loop helpers; ds only duplicates the last entry of data_dict
del ds, exclusion

J:\Drago Guggiana Nilo\Prey_capture\AnalyzedData\preprocessing_succ_VR_normal_ALL_ALL_ALL_ALL_2020-06-23T00-00-00_ALL_aggEnc.hdf5
2020-12-01T10:27:14.316934Z
data label: VR_light_succ

J:\Drago Guggiana Nilo\Prey_capture\AnalyzedData\preprocessing_succ_VR_dark_ALL_ALL_ALL_ALL_2020-06-23T00-00-00_ALL_aggEnc.hdf5
2020-12-01T08:28:11.959957Z
data label: VR_dark_succ

J:\Drago Guggiana Nilo\Prey_capture\AnalyzedData\preprocessing_fail_VR_dark_ALL_ALL_ALL_ALL_2020-06-23T00-00-00_ALL_aggEnc.hdf5
2020-12-01T08:27:55.869412Z
data label: VR_dark_fail

J:\Drago Guggiana Nilo\Prey_capture\AnalyzedData\preprocessing_succ_miniscope_normal_ALL_ALL_ALL_ALL_ALL_ALL_aggEnc.hdf5
2020-08-19T08:02:08.261545Z
data label: Box_light_succ



In [10]:
# Plot the number of encounters per trial for each condition

# one bar plot per condition, plus the per-condition mean and (name, mean, sem) summaries
plot_list, means, enc_sem = [], [], []

for name, data in data_dict.items():

    # gather the event ids belonging to each trial
    parameter = data[['event_id', 'trial_id']].copy()
    grouped_parameter = parameter.groupby(['trial_id']).agg(list)

    # encounters per trial: the trial's last event id, plus one
    encounters = 1 + np.array([events[-1] for events in grouped_parameter['event_id']])
    means.append(encounters.mean())
    enc_sem.append((name, encounters.mean(), sem(encounters)))

    # bars for every trial with a red line at the condition mean
    trial_bars = hv.Bars((np.arange(len(encounters)), encounters)).opts(
        title=name, xlabel='trial', ylabel='# encounters')
    mean_line = hv.HLine(encounters.mean()).opts(color='red', line_width=1)
    enc_plot = trial_bars * mean_line
    plot_list.append(enc_plot)

# shared axes make the conditions directly comparable
encounters_panel = hv.Layout(plot_list).opts(shared_axes=True)
save_path = os.path.join(paths.figures_path, f'{save_name}_encounters')
hv.save(encounters_panel, save_path, fmt='png')

# display the image
encounters_panel

In [11]:
# Separately plot the mean and s.e.m. of the encounter counts of every condition

# mean bars (one per condition) overlaid with their error bars
enc_means = hv.ErrorBars(enc_sem) * hv.Bars((list(data_dict), means)).opts(
    title='Mean # Encounters', ylabel='# encounters', ylim=(0, 15), xrotation=45)
save_path = os.path.join(paths.figures_path, f'{save_name}_encounter_means')
hv.save(enc_means, save_path, fmt='png')

# display the image
enc_means

#### PCA of Encounter types

In [5]:
# define the target parameter and run a PCA on its encounter traces, per condition
target_parameter = 'cricket_0_mouse_distance'

# expected trace length (samples) per rig; encounters of any other length are dropped
# so that all traces of a condition can be stacked into one 2D array
# HACK REMOVE: these lengths are hard-coded for the current aggregated data sets
vr_trace_length = 594
box_trace_length = 74

# container for plots (explained variance per condition)
plot_list = []

# container for PCA fit (transformed traces per condition)
pca_transforms = []

# container for target data (length-filtered raw traces per condition)
target = []

for name in data_dict:
    data = data_dict[name]

    # one list of target values per (trial, encounter)
    target_data = data[[target_parameter] + ['event_id', 'trial_id']].groupby(['trial_id', 'event_id']).agg(list).to_numpy()

    # NOTE: the previous test `('VR' or 'VPrey') in name` only ever checked 'VR'
    is_vr = ('VR' in name) or ('VPrey' in name)
    trace_length = vr_trace_length if is_vr else box_trace_length
    target_data = np.array([el for sublist in target_data for el in sublist if len(el) == trace_length])
    target.append(target_data)

    # PCA the data before clustering
    pca = decomp.PCA()
    transformed_data = pca.fit_transform(target_data)
    pca_transforms.append(transformed_data)

    # plot the fraction of variance explained by each component
    exp_var = hv.Curve(pca.explained_variance_ratio_).opts(xlabel='PCs', ylabel='explained variance', title=name)
    plot_list.append(exp_var)

hv.Layout(plot_list)

#### Gaussian Mixture Model of clusters in data

In [7]:
# Cluster the PCA-transformed traces with Gaussian mixture models, choosing the
# number of components by the Bayesian information criterion (BIC)

# candidate numbers of mixture components (the same for every condition)
component_vector = [2, 3, 4, 5, 10, 15]
# number of leading principal components used for clustering
n_pcs = 7
# clusters with fewer traces than this are discarded (their traces become NaN)
min_cluster_size = 5

plot_list = []
clusters = []

for transformed_data, name in zip(pca_transforms, data_dict.keys()):

    # clustering features: the first n_pcs principal components
    features = transformed_data[:, :n_pcs]

    # fit one model per candidate number of components and score it by BIC
    models = []
    bics = []
    for comp in component_vector:
        gmm = mix.GaussianMixture(n_components=comp, covariance_type='diag', n_init=50)
        gmm.fit(features)
        models.append(gmm)
        bics.append(gmm.bic(features))

    # use the minimum-BIC model itself for the assignments; refitting a new,
    # differently initialized model would not reproduce the model that was selected
    best = int(np.argmin(bics))
    n_components = component_vector[best]
    gmm = models[best]
    cluster_idx = gmm.predict(features)

    # discard small clusters: mark their traces as unassigned (NaN requires a float array)
    cluster_idx = cluster_idx.astype(float)
    cluster_ids, cluster_sizes = np.unique(cluster_idx, return_counts=True)
    small = np.isin(cluster_idx, cluster_ids[cluster_sizes < min_cluster_size])
    cluster_idx[small] = np.nan
    clusters.append(cluster_idx)

    # plot the BIC of every candidate number of components
    BIC = hv.Curve((component_vector, bics)).opts(title=name, xlabel='# clusters', ylabel='BIC')
    plot_list.append(BIC)

hv.Layout(plot_list).opts(shared_axes=False)

In [16]:
# plot the mean (± s.e.m.) trace of every cluster, per condition
plot_list = []

for target_data, cluster_idx, name in zip(target, clusters, data_dict):

    # clusters that survived the size filter (discarded traces are NaN); taken per
    # condition, since each condition was fit with its own number of components
    cluster_ids = np.unique(cluster_idx[~np.isnan(cluster_idx)])

    # mean trace and standard error of the mean of every cluster
    cluster_data = np.array([np.mean(target_data[cluster_idx == el, :], axis=0) for el in cluster_ids])
    cluster_std = np.array([np.std(target_data[cluster_idx == el, :], axis=0)/np.sqrt(np.sum(cluster_idx == el))
                            for el in cluster_ids])

    # plot the results: one labelled curve plus its error band per cluster
    value_label = target_parameter.replace('_', ' ') + ' (px)'
    curves = [hv.Curve(el, label=str(int(clu)), kdims=['Time (s)'], vdims=[value_label])
              for clu, el in zip(cluster_ids, cluster_data)]
    spreads = [hv.Spread((np.arange(el.shape[0]), el, cluster_std[idx, :]))
               for idx, el in enumerate(cluster_data)]
    cluster_plot = hv.Overlay(curves + spreads)

    # relabel() returns a copy, so keep it; otherwise the colour options are lost
    cluster_plot = cluster_plot.relabel('Clusters').opts({'Curve': dict(color=hv.Palette('Category20')),
                                                          'Spread': dict(color=hv.Palette('Category20'))})

    cluster_plot.opts(title=name)

    # For publication-ready image
    cluster_plot.opts(
        opts.Curve(
                    width=fp.pix(10.7), height=fp.pix(5), 
                    toolbar=None, hooks=[fp.margin], 
                    fontsize=fp.font_sizes['small'], 
                    line_width=12, xticks=3, yticks=3
                    ),
        opts.Overlay(legend_position='right', text_font='Arial')
        )

    plot_list.append(cluster_plot)

cluster_trace_panel = hv.Layout(plot_list).opts(shared_axes=False)
# assemble the save path
save_path = os.path.join(paths.figures_path, '_'.join([save_name, target_parameter, 'cluster']))
hv.save(cluster_trace_panel, save_path, fmt='png')

# display the image
cluster_trace_panel

In [18]:
# one list of target values per (trial, encounter) pair, as an object array
grouped_parameter = (
    data[[target_parameter, 'event_id', 'trial_id']]
    .groupby(['trial_id', 'event_id'])
    .agg(list)
    .to_numpy()
)

In [33]:
# exploratory check: 95th percentile of the grouped target values
# NOTE(review): grouped_parameter is the object array built by .agg(list).to_numpy();
# confirm np.percentile treats its list entries as intended
np.percentile(grouped_parameter, 95)

126.79739059663683

In [8]:
# define the target parameter to show as sorted heatmaps
target_parameter = 'cricket_0_mouse_distance'

# plot the traces of every encounter as an image, per condition
plot_list = []
for name in data_dict:
    data = data_dict[name]

    # group the single traces: one list of values per (trial, encounter)
    grouped_parameter = data.loc[:, [target_parameter] + ['event_id', 'trial_id']].groupby(['trial_id', 'event_id']).agg(list).to_numpy()

    if 'VR' in name:
        # keep encounters of the expected length so they stack into a 2D array
        grouped_parameter = np.array([el for sublist in grouped_parameter for el in sublist if len(el) == 594])
        grouped_parameter *= 100   # Convert to cm
    elif ('Box' in name) and ('distance' in target_parameter):       
        grouped_parameter = np.array([el for sublist in grouped_parameter for el in sublist]) 

    # plot all traces
    [sorted_traces,_,_] = fp.sort_traces(grouped_parameter)

    # bounds come from this condition's own traces (previously the stale global
    # `target_data` from the PCA cell was used, i.e. the last condition's shape)
    image = hv.Image(sorted_traces, ['Time','Trial #'], 
                    [target_parameter.replace('_', ' ')], 
                    bounds=[0, 0, grouped_parameter.shape[1], grouped_parameter.shape[0]]
                    ).opts(title=name)

    # For publication-ready image                
    # image.opts(
    #         width=fp.pix(5.8), 
    #         height=fp.pix(5.8), 
    #         toolbar=None, 
    #         hooks=[fp.margin], 
    #         fontsize=fp.font_sizes['small'], 
    #         xticks=3, yticks=3, 
    #         colorbar=True, cmap='viridis', 
    #         colorbar_opts={'major_label_text_align': 'left'}
    #         )

    image.opts(
            toolbar=None, 
            hooks=[fp.margin], 
            colorbar=True, cmap='viridis', 
            colorbar_opts={'major_label_text_align': 'left'}
            )

    plot_list.append(image)


sorted_cluster_heatmap_panel = hv.Layout(plot_list).opts(shared_axes=False)

# assemble the save path
save_path = os.path.join(paths.figures_path,'_'.join([save_name, target_parameter]))
hv.save(sorted_cluster_heatmap_panel, save_path, fmt='png')

# display the image
sorted_cluster_heatmap_panel

#### UMAP embedding

In [39]:
# UMAP embedding of the PCA-transformed encounter traces, coloured by GMM cluster
plot_list = []

for transformed_data, cluster_idx, name in zip(pca_transforms, clusters, data_dict):

    # embed the data via UMAP
    reducer = umap.UMAP(min_dist=0.5, n_neighbors=10)
    embedded_data = reducer.fit_transform(transformed_data)

    #--- Plot the embedding ---#

    # one row per trace: (Dim 1, Dim 2, cluster index); discarded clusters are NaN
    umap_data = np.column_stack((embedded_data, cluster_idx))

    # NOTE: per-trial labels (trial id, last encounter of each trial) used to be computed
    # here from data_dict but were never plotted. They also cannot be aligned with these
    # rows, because the PCA step drops encounters whose trace length is not the expected one.

    umap_plot = hv.Scatter(umap_data, vdims=['Dim 2','cluster'], kdims=['Dim 1'])
    umap_plot.opts(color='cluster', colorbar=True, cmap='Category10', size=5, title=name)

    # # For publication-ready image      
    # umap_plot.opts(color='cluster', colorbar=True, cmap='Category10', size=20)          
    # umap_plot.opts(
    #     opts.Scatter(
    #         width=fp.pix(5.7), 
    #         height=fp.pix(7.8), 
    #         toolbar=None, 
    #         hooks=[fp.margin], 
    #         fontsize=fp.font_sizes['small'], 
    #         xticks=3, 
    #         yticks=3
    #         )
    #     )

    umap_plot.opts(
        opts.Scatter(
            toolbar=None, 
            fontsize=fp.font_sizes['small'], 
            xticks=3, 
            yticks=3
            )
        )

    plot_list.append(umap_plot)

umap_panel = hv.Layout(plot_list)

# assemble the save path
save_path = os.path.join(paths.figures_path,'_'.join([save_name, 'umap']))
# hv.save(umap_panel, save_path, fmt='png')

# display the image
umap_panel

## Binned time analysis
Next, we use the binned-time aggregates to get an overview of the kinematics of the scene.

In [9]:
# create container for holding multiple data sets
data_dict = {}

# (search string, label, path substring to exclude) for every condition:
#   - real prey capture in the VR arena in the light (obstacle data sets excluded)
#   - real prey capture in the VR arena in the dark, successes and failures
#   - real prey capture in the small box
for search_string, label, exclusion in (
        ('result:succ, lighting:normal, rig:VR, analysis_type:aggBin, notes:crickets', "VR_light_succ", 'obstacle'),
        ('result:succ, lighting:dark, rig:VR, analysis_type:aggBin', "VR_dark_succ", None),
        ('result:fail, lighting:dark, rig:VR, analysis_type:aggBin', "VR_dark_fail", None),
        ('result:succ, lighting:normal, rig:miniscope, analysis_type:aggBin', "Box_light_succ", None),
):
    ds, label = load_dataset(search_string, label=label, exclusion=exclusion)
    data_dict[label] = ds

# drop the loop helpers; ds only duplicates the last entry of data_dict
del ds, exclusion

J:\Drago Guggiana Nilo\Prey_capture\AnalyzedData\preprocessing_succ_VR_normal_ALL_crickets_1_vrcrickets_0_ALL_ALL_2020-06-23T00-00-00_ALL_aggBin.hdf5
data label: VR_light_succ

J:\Drago Guggiana Nilo\Prey_capture\AnalyzedData\preprocessing_succ_VR_dark_ALL_crickets_1_vrcrickets_0_ALL_ALL_2020-06-23T00-00-00_ALL_aggBin.hdf5
data label: VR_dark_succ

J:\Drago Guggiana Nilo\Prey_capture\AnalyzedData\preprocessing_fail_VR_dark_ALL_crickets_1_vrcrickets_0_ALL_ALL_2020-06-23T00-00-00_ALL_aggBin.hdf5
data label: VR_dark_fail

J:\Drago Guggiana Nilo\Prey_capture\AnalyzedData\preprocessing_succ_miniscope_normal_ALL_ALL_ALL_ALL_ALL_ALL_aggBin.hdf5
data label: Box_light_succ



In [10]:
# Plot the binned traces of every trial as heatmaps, rows sorted by clustering

# define the target parameters
target_parameters = ['mouse_speed', 'cricket_0_mouse_distance', 'cricket_0_speed']


# allocate a list for the plots
plot_list = []

for name in data_dict:

    data = data_dict[name]
    # trial order obtained from the first parameter, reused for the others
    cluster_idx = None

    # for all the parameters
    for target_param in target_parameters:

        # load the parameter
        try:
            parameter = data[[target_param, 'trial_id']].copy()
        except KeyError:
            # data set without real crickets: leave an empty slot in the grid
            if ('cricket' in target_param) and ('vr' not in target_param):
                hmap = hv.Empty()
                plot_list.append(hmap)
                continue
            # any other missing column is an error; falling through would silently
            # re-plot the previous parameter (or fail on an undefined name)
            raise

        # group the single traces: one row per trial
        grouped_parameter = parameter.groupby(['trial_id']).agg(list)
        grouped_parameter = np.array([el for el in grouped_parameter[target_param]])
        # zero out infinite values of either sign
        grouped_parameter[np.isinf(grouped_parameter)] = 0

        # get the clustering for first parameter, and preserve that sorting for all other target parameters tested
        # (the third output is not stored as `clusters`, which would overwrite the GMM clusters)
        if cluster_idx is None:
            [sorted_traces, cluster_idx, trace_clusters] = fp.sort_traces(grouped_parameter, nclusters=min((10, len(grouped_parameter))))
        else:
            sorted_traces = grouped_parameter[cluster_idx, :]

        # plot all traces
        hmap = hv.Image(sorted_traces, ['Binned Time','Trial #'],
                        [target_param.replace('_', ' ')], 
                        bounds=[0, 0, grouped_parameter.shape[1], grouped_parameter.shape[0]],
                        group=name, 
                        label=target_param)

        # # For publication-ready image                
        # hmap.opts(
        #         width=fp.pix(5.8), 
        #         height=fp.pix(5.8), 
        #         toolbar=None, 
        #         hooks=[fp.margin], 
        #         fontsize=fp.font_sizes['small'], 
        #         xticks=3, yticks=3, 
        #         colorbar=True, cmap='viridis', 
        #         colorbar_opts={'major_label_text_align': 'left'}
        #         )

        hmap.opts(
                width=fp.pix(1.5), 
                height=fp.pix(1.5), 
                toolbar=None, 
                xticks=3, yticks=3, 
                colorbar=True, cmap='viridis', 
                colorbar_opts={'major_label_text_align': 'left'}
                )

        plot_list.append(hmap)

# one row per condition, one column per parameter
heatmaps = hv.Layout(plot_list).cols(len(target_parameters))
save_path = os.path.join(paths.figures_path, '_'.join([save_name, 'binned_kinematics']))
hv.save(heatmaps, save_path, fmt='png')
heatmaps

In [11]:
# Plot normalized histograms (log-spaced bins) of the binned kinematic parameters

# define the target parameters
target_parameters = ['mouse_speed', 'cricket_0_speed', 'cricket_0_mouse_distance']

# fixed log10 bin range so that all plots of the same variable share the same bins
# HACK: values outside this range (including zeros) are not counted; a way to show
# zero values in the speed plots is still needed
lower = -2
upper = 2.5
bin_edges = np.logspace(lower, upper, 50)

# allocate a list for the plots
plot_list = []

for name in data_dict:

    data = data_dict[name]

    # for all the parameters
    for target_param in target_parameters:

        # load the parameter
        try:
            parameter = data[[target_param, 'trial_id']].copy()
        except KeyError:
            # handle cases where you have VR only and no real crickets
            if ('cricket' in target_param) and ('vr' not in target_param):
                hmap = hv.Empty()
                plot_list.append(hmap)
                continue
            # handle cases when you have real only and VR only comparisons
            elif ('vr' in target_param) and ('VPrey' not in name):
                hmap = hv.Empty()
                plot_list.append(hmap)
                continue
            # any other missing column is an error; falling through would silently
            # histogram the previous parameter again
            raise

        # group the single traces
        grouped_parameter = parameter.groupby(['trial_id']).agg(list)
        grouped_parameter = np.array([el for el in grouped_parameter[target_param]])
        # zero out infinite values of either sign
        grouped_parameter[np.isinf(grouped_parameter)] = 0

        # VR data are assumed to be in m/s or m; convert to cm/s or cm
        if 'VR' in name:
            grouped_parameter *= 100

        # fraction of the in-range values falling in each bin (all zeros if none is in range)
        freq, edges = np.histogram(grouped_parameter, bin_edges, density=False)
        total = np.sum(freq)
        freq = freq / total if total > 0 else freq.astype(float)

        # plot histogram
        hist = hv.Histogram((edges, freq), 
                group=': '.join((name, target_param)), 
                ).opts(logx=True)

        # Add x axis labels
        if 'speed' in target_param:
            hist.opts(xlabel='cm/s')
        elif 'distance' in target_param:
            hist.opts(xlabel='cm')

        plot_list.append(hist)


param_hists = hv.Layout(plot_list).opts(shared_axes=True).cols(len(target_parameters))
save_path = os.path.join(paths.figures_path, '_'.join([save_name, 'histogram_kinematics']))
hv.save(param_hists, save_path, fmt='png')
param_hists

## Full aggregate analysis
Here we look at the full trace of each experiment. This is useful for getting the trial duration, and it also gives a different view of the kinematics.

In [46]:
# create container for holding multiple data sets
data_dict = {}

# (search string, label, path substring to exclude) for every condition:
#   - real prey capture in the VR arena in the light (obstacle data sets excluded)
#   - real prey capture in the VR arena in the dark, successes and failures
#   - real prey capture in the small box
for search_string, label, exclusion in (
        ('result:succ, lighting:normal, rig:VR, analysis_type:aggFull, notes:crickets', "VR_light_succ", 'obstacle'),
        ('result:succ, lighting:dark, rig:VR, analysis_type:aggFull', "VR_dark_succ", None),
        ('result:fail, lighting:dark, rig:VR, analysis_type:aggFull', "VR_dark_fail", None),
        ('result:succ, lighting:normal, rig:miniscope, analysis_type:aggFull', "Box_light_succ", None),
):
    ds, label = load_dataset(search_string, label=label, exclusion=exclusion)
    data_dict[label] = ds

# drop the loop helpers; ds only duplicates the last entry of data_dict
del ds, exclusion

J:\Drago Guggiana Nilo\Prey_capture\AnalyzedData\preprocessing_succ_VR_normal_ALL_crickets_1_vrcrickets_0_ALL_ALL_2020-06-23T00-00-00_ALL_aggFull.hdf5
data label: VR_light_succ

J:\Drago Guggiana Nilo\Prey_capture\AnalyzedData\preprocessing_succ_VR_dark_ALL_crickets_1_vrcrickets_0_ALL_ALL_2020-06-23T00-00-00_ALL_aggFull.hdf5
data label: VR_dark_succ

J:\Drago Guggiana Nilo\Prey_capture\AnalyzedData\preprocessing_fail_VR_dark_ALL_crickets_1_vrcrickets_0_ALL_ALL_2020-06-23T00-00-00_ALL_aggFull.hdf5
data label: VR_dark_fail

J:\Drago Guggiana Nilo\Prey_capture\AnalyzedData\preprocessing_succ_miniscope_normal_ALL_ALL_ALL_ALL_ALL_ALL_aggFull.hdf5
data label: Box_light_succ



In [53]:
# exploratory: trial ids present in `data` and the number of rows belonging to each
np.unique(data['trial_id'], return_counts=True)

(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
        34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
        51, 52, 53, 54]),
 array([4534, 4575, 4242, 1593, 4026, 6914, 4459, 7041, 4878, 3474, 4257,
        6771, 4700, 1312, 2826, 1778, 2674, 1325, 1666, 2547, 3378, 4208,
        2953, 5024, 3142, 3015, 2334, 2154, 2999, 3417, 1331, 3818, 2775,
        1445, 4263, 2408, 3110, 2432, 4595, 3912, 1635, 2722, 2361, 1978,
        1412, 1667, 1592, 1919, 3644, 2861, 3578, 3287, 3450, 1468, 2687],
       dtype=int64))

In [None]:
# Get a sense of how many trials there are per data set
for name in data_dict:
    data = data_dict[name]
    # trials are identified by trial_id (the unit grouped on throughout this notebook)
    print(f"{name}: {data['trial_id'].nunique()} trials")
    

In [47]:
# Plot the duration of every trial, per condition

# allocate a list for the plots
plot_list = []
means = []
dur_sem = []

for name in data_dict:

    data = data_dict[name]

    # time stamps of each trial
    times = data[['time_vector', 'trial_id']].copy()
    times = times.groupby(['trial_id']).agg(list)
    # duration = last minus first time stamp, so it does not rely on the time vector starting at 0
    duration = np.array([trial[-1] - trial[0] for trial in times['time_vector']])

    means.append(duration.mean())
    dur_sem.append((name, duration.mean(), sem(duration)))

    # plot the results: one bar per trial
    duration_plot = hv.Bars(duration).opts(title=name, xlabel='trial', ylabel='duration')
    plot_list.append(duration_plot)

hv.Layout(plot_list)

In [51]:
# Separately plot mean + sem of trial duration

# Plot of means, overlaid with their error bars
trial_means = hv.Bars((list(data_dict.keys()), means)).opts(title='Mean Trial Duration', ylabel='Duration (s)', ylim=(0,300), xrotation=45) 
trial_means = hv.ErrorBars(dur_sem) * trial_means
save_path = os.path.join(paths.figures_path, '_'.join([save_name, 'duration_means']))
hv.save(trial_means, save_path, fmt='png')

# display the image (this previously displayed enc_means, the encounter plot, by mistake)
trial_means