# Debug VR Trial Alignment

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import os
import sys

import importlib
import holoviews as hv
hv.extension('bokeh')

import pandas as pd
import matplotlib.pyplot as plt


sys.path.insert(0, os.path.abspath(r'C:/Users/mmccann/repos/bonhoeffer/prey_capture/'))
import processing_parameters
import functions_bondjango as bd
import functions_misc as fmisc
import functions_matching as fm
import functions_data_handling as fdh
import functions_tuning as tuning
import functions_plotting as fp
import functions_kinematic as fk
import functions_loaders as fl
from wirefree_experiment import WirefreeExperiment, DataContainer
from functions_wirefree_trigger_fix import get_trial_duration_stats, drop_partial_or_long_trials

# Output directory for saved figures. This is a hard-coded, machine-specific path;
# the commented line below is the alternative that reads it from a `paths` module.
# fig_path = paths.wf_figures_path
fig_path = r"C:\Users\mmccann\Dropbox\bonhoeffer lab\thesis\figures"

## Test on a single experiment

In [None]:
def plot_dff_spikes_trials(exp, save_dir, save=True, plot_spikes=True, plot_trials=True, plot_running=False, **kwargs):
    """Build one strip plot per cell in ``exp.cells_to_match`` and save or display it.

    Each strip shows the cell's ``exp.norm_dff`` trace in black. It can optionally be
    overlaid with the cell's ``exp.norm_spikes`` trace (green), a gray area marking
    samples where ``trial_num > 0``, and a running-speed trace (red).

    Parameters
    ----------
    exp : object
        Must expose ``cells_to_match`` (iterable of column names) and ``norm_dff`` /
        ``norm_spikes``. The latter two are indexed by column name and presumably are
        pandas DataFrames (TODO confirm). Both need a ``time_vector`` column and one
        column per cell. ``norm_dff`` must also contain ``trial_num`` when
        ``plot_trials`` is set, and ``running_speed`` or ``wheel_speed_abs`` when
        ``plot_running`` is set.
    save_dir : str
        Directory used to build ``<basename>.png`` for each cell.
    save : bool
        Passed to ``fp.save_figure`` as ``target='save'`` when True, else ``'screen'``.
    plot_spikes, plot_trials, plot_running : bool
        Toggle the respective overlays. Each enabled overlay appends a suffix
        (``_spikes``, ``_trials``, ``_running``) to the file base name.
    **kwargs
        ``basename_modifier`` (default ``''``), ``fig_width`` (default 7), ``dpi``
        (default 800) and ``fontsize`` (default ``'poster'``) are consumed here.
        All remaining keys are passed to the dF/F ``hv.Curve.opts`` call. They override
        its defaults (``color='black'``, ``height=75``, ``width=1000``); ``height`` and
        ``width`` also set the final size of every curve in the strip.

    Returns
    -------
    list
        The objects returned by ``fp.save_figure``, one per cell, in order.
    """
    basename_modifier = kwargs.pop('basename_modifier', '')
    fig_width = kwargs.pop('fig_width', 7)
    dpi = kwargs.pop('dpi', 800)
    fontsize = kwargs.pop('fontsize', 'poster')

    # Merge defaults with caller options. Passing **kwargs next to explicit keywords
    # would raise "got multiple values for keyword argument" on any override.
    dff_opts = {'color': 'black', 'height': 75, 'width': 1000, **kwargs}
    target = 'save' if save else 'screen'

    dff_plots = []
    for i, cell in enumerate(exp.cells_to_match):
        basename = f"dff_{i}{basename_modifier}"

        # Plot the dff
        out_fig = hv.Curve(exp.norm_dff[['time_vector', cell]]).opts(**dff_opts)

        if plot_spikes:
            spikes_plot = hv.Curve(exp.norm_spikes[['time_vector', cell]]).opts(color='green', alpha=0.5)
            out_fig = hv.Overlay([spikes_plot, out_fig])
            basename += '_spikes'

        if plot_trials:
            # Gray background wherever trial_num > 0
            trials_on = exp.norm_dff['trial_num'] > 0
            time = exp.norm_dff['time_vector']
            trials_plot = hv.Area((time, trials_on)).opts(color='gray', alpha=0.25)
            out_fig = hv.Overlay([trials_plot, out_fig]).opts(
                hv.opts.Area(yaxis=None, xaxis=None, xlabel=None, ylabel=None, show_legend=False))
            basename += '_trials'

        if plot_running:
            # Prefer 'running_speed'; fall back to 'wheel_speed_abs' when it is absent
            speed_col = 'running_speed' if 'running_speed' in exp.norm_dff.columns else 'wheel_speed_abs'
            running_plot = hv.Curve(exp.norm_dff[['time_vector', speed_col]]).opts(color='red', alpha=0.5)
            out_fig = hv.Overlay([out_fig, running_plot])
            basename += '_running'

        # Final options: strip axes/labels from every curve and enforce the strip size
        out_fig = out_fig.opts(hv.opts.Curve(yaxis=None, xaxis=None, xlabel=None, ylabel=None, show_legend=False,
                                             width=dff_opts['width'], height=dff_opts['height']))

        # Save (or render for screen) the figure
        save_path = os.path.join(save_dir, f'{basename}.png')
        out_fig = fp.save_figure(out_fig, save_path=save_path, fig_width=fig_width, dpi=dpi,
                                 fontsize=fontsize, target=target, display_factor=0.2)

        dff_plots.append(out_fig)

    return dff_plots

In [None]:
importlib.reload(processing_parameters)

# Select one VR session by mouse, date slug and rig
search_string = 'mouse:MM_221109_a, slug:01_10_2023, rig:VWheelWF'
parsed_search = fdh.parse_search_string(search_string)

# Raw experiment entries, ordered by rig
exp_query = sorted(bd.query_database('vr_experiment', search_string), key=lambda entry: entry['rig'])

# Preprocessing entries whose slug contains this mouse, ordered by rig
preproc_query = sorted(
    (entry for entry in bd.query_database('analyzed_data', f'{search_string}, analysis_type:preprocessing')
     if parsed_search['mouse'].lower() in entry['slug']),
    key=lambda entry: entry['rig'],
)

# Paths to the preprocessing files (alphabetically sorted)
preproc_paths = np.sort(np.array([
    entry['analysis_path']
    for entry in preproc_query
    if entry['analysis_type'] == 'preprocessing' and parsed_search['mouse'].lower() in entry['slug']
]))

# Calcium file names mirror the preprocessing names with 'preproc' -> 'calciumraw'
calcium_paths = np.array([preproc.replace('preproc', 'calciumraw') for preproc in preproc_paths])

# The rig is the 7th underscore-separated token of each calcium file name
rigs = np.array([os.path.basename(calcium).split('_')[6] for calcium in calcium_paths])
print(calcium_paths)
print(preproc_paths)
print(rigs)

In [None]:
# Build the experiment from the first matching queries and load its preprocessing
exp_free = WirefreeExperiment(exp_info=exp_query[0], preproc_info=preproc_query[0])
exp_free._load_preprocessing()

# dF/F with a quantile (0.25) baseline, then normalized dF/F and spike traces
exp_free.dff = tuning.calculate_dff(exp_free.raw_fluor.copy(), baseline_type='quantile', quantile=0.25)
exp_free.norm_dff = tuning.normalize_responses(exp_free.dff.copy())
exp_free.norm_spikes = tuning.normalize_responses(exp_free.raw_spikes.copy())

# Drop partial or overlong trials from both normalized traces (default thresholds)
for attr in ('norm_dff', 'norm_spikes'):
    setattr(exp_free, attr, drop_partial_or_long_trials(getattr(exp_free, attr)))

# Candidate cells_to_match lists (currently disabled):
# exp_free.cells_to_match = ['cell_0128', 'cell_0089', 'cell_0094', 'cell_0030', 'cell_0012', 'cell_0064', 'cell_0132', 'cell_0020']
# ['cell_0119', 'cell_0069', 'cell_0076', 'cell_0023', 'cell_0011', 'cell_0054', 'cell_0125', 'cell_0020']

In [None]:
# Trial and inter-trial-interval durations for the single experiment
trial_durations, iti_durations, _, _ = get_trial_duration_stats(
    exp_free.norm_dff,
    trial_key='trial_num',
    time_key='time_vector',
)

In [None]:
# Per-trial durations and ITI durations as scatter points. The labels give the
# overlay a legend so the two series can be told apart.
trial_plot = hv.Scatter(trial_durations, label='Trial durations').opts(width=1000, height=500)
iti_plot = hv.Scatter(iti_durations, label='ITI durations').opts(width=1000, height=500)
hv.Overlay([trial_plot, iti_plot])

In [None]:
# Overlay the normalized dF/F of up to 15 randomly chosen cells on a shaded trial mask
cells = [column for column in exp_free.norm_dff.columns if 'cell' in column]
cells_to_plot = np.random.choice(cells, min(15, len(cells)), replace=False)

# Gray area wherever trial_num > 0; shared by every trace
trials_on = exp_free.norm_dff['trial_num'] > 0
time = exp_free.norm_dff['time_vector']
trials_plot = hv.Area((time, trials_on)).opts(color='gray', alpha=0.25)

plot_list = []
for cell in cells_to_plot:
    dff_plot = hv.Curve(exp_free.norm_dff[['time_vector', cell]]).opts(
        color='black', height=100, width=1000,
        yaxis=None, xaxis=None, xlabel=None, ylabel=None,
        show_legend=False,
    )
    dff_plot = hv.Overlay([trials_plot, dff_plot])
    plot_list.append(dff_plot)

# One trace per row
hv.Layout(plot_list).cols(1)

## Test on all experiments

In [None]:
frame_rate = processing_parameters.wf_frame_rate

# Query every experiment in the search list
all_paths, all_queries = fl.query_search_list()
# Mouse IDs parsed from the file names of the first path group only
mice = ['_'.join(os.path.basename(fname).split('_')[7:10]) for fname in all_paths[0]]
print(all_paths)

# Load the preprocessed data; each call yields a group of datasets
data_list = []
path_list = []
for path, queries in zip(all_paths, all_queries):
    data, _, metadata = fl.load_preprocessing(path, queries, latents_flag=False)
    data_list.append(data)
    path_list.append(path)

# Flatten the per-call groups into one list of datasets
data_list = [dataset for group in data_list for dataset in group]

In [None]:
# Get the trial and ITI durations of every experiment (no curation) and plot their
# across-experiment mean +/- std against trial index.
all_trial_durations = []
all_iti_durations = []
for data in data_list:
    trial_durations, iti_durations, _, _ = get_trial_duration_stats(data, trial_key='trial_num', time_key='time_vector', display_info=False)
    # .tolist() works for both pandas Series and numpy arrays
    all_trial_durations.append(trial_durations.tolist())
    all_iti_durations.append(iti_durations.tolist())

# Experiments have different trial counts. Presumably list_lists_to_array pads the
# ragged lists with NaN (hence the nan-aware statistics) -- TODO confirm.
all_trial_durations = fmisc.list_lists_to_array(all_trial_durations)
all_iti_durations = fmisc.list_lists_to_array(all_iti_durations)

dur_mean = np.nanmean(all_trial_durations, axis=0)
dur_std = np.nanstd(all_trial_durations, axis=0)

iti_mean = np.nanmean(all_iti_durations, axis=0)
iti_std = np.nanstd(all_iti_durations, axis=0)

# Open a fresh figure. With the interactive widget backend, plotting without one
# draws onto whatever figure is currently active (e.g. on re-running this cell).
fig = plt.figure()
plt.errorbar(np.arange(len(dur_mean)), dur_mean, yerr=dur_std)
plt.errorbar(np.arange(len(iti_mean)), iti_mean, yerr=iti_std)
plt.xlabel('Trial index')
plt.ylabel('Duration')
plt.legend(['Trial Durations', 'ITI Durations'])

In [None]:
# Same summary as the uncurated one, but each experiment first passes through
# drop_partial_or_long_trials with min_trial_length=4.5 / max_trial_length=5.5.
curated_trial_durations, curated_iti_durations = [], []
for data in data_list:
    data = drop_partial_or_long_trials(data, min_trial_length=4.5, max_trial_length=5.5)
    trial_durations, iti_durations, _, _ = get_trial_duration_stats(
        data, trial_key='trial_num', time_key='time_vector', display_info=False
    )
    curated_trial_durations.append(trial_durations.to_list())
    curated_iti_durations.append(iti_durations.tolist())

curated_trial_durations, curated_iti_durations = (
    fmisc.list_lists_to_array(curated_trial_durations),
    fmisc.list_lists_to_array(curated_iti_durations),
)

# NaN-aware mean and std across experiments, per trial index
dur_mean, dur_std = np.nanmean(curated_trial_durations, axis=0), np.nanstd(curated_trial_durations, axis=0)
iti_mean, iti_std = np.nanmean(curated_iti_durations, axis=0), np.nanstd(curated_iti_durations, axis=0)

fig = plt.figure()
for avg, spread in ((dur_mean, dur_std), (iti_mean, iti_std)):
    plt.errorbar(np.arange(len(avg)), avg, yerr=spread)
plt.legend(['Trial Durations', 'ITI Durations'])

In [None]:
# Kick out experiments with messed up trials.
# An experiment is flagged when trial_stats[0] < min_trial_length or
# trial_stats[1] > max_trial_length (presumably min and max trial duration -- TODO confirm).
# A flagged experiment is still kept when its only offending trial is a single short
# final trial, which is considered acceptable.
min_trial_length = 2.5
max_trial_length = 5.5

all_trial_durations = []
all_iti_durations = []
num_trials = []   # trial count of each KEPT experiment (aligned with the durations above)
bad_paths = []    # indices into data_list of rejected experiments
bad_idxs = []
for i, data in enumerate(data_list):
    trial_durations, iti_durations, trial_stats, _ = get_trial_duration_stats(data, trial_key='trial_num', time_key='time_vector', display_info=False)
    durations = trial_durations.to_numpy()

    keep = True
    if (trial_stats[0] < min_trial_length) or (trial_stats[1] > max_trial_length):
        long_trials = np.flatnonzero(durations > max_trial_length)
        short_trials = np.flatnonzero(durations < min_trial_length)
        # Only a single short *last* trial is tolerated
        keep = (len(long_trials) == 0 and len(short_trials) == 1
                and short_trials[0] == len(durations) - 1)

    if not keep:
        bad_paths.append(i)
        bad_idxs.append(i)
        continue

    num_trials.append(len(trial_durations))
    all_trial_durations.append(trial_durations.tolist())
    all_iti_durations.append(iti_durations.tolist())

all_trial_durations = fmisc.list_lists_to_array(all_trial_durations)
all_iti_durations = fmisc.list_lists_to_array(all_iti_durations)

dur_mean = np.nanmean(all_trial_durations, axis=0)
dur_std = np.nanstd(all_trial_durations, axis=0)

# 1-based trial numbers, broadcast to the shape of the duration array so every
# individual duration gets an x coordinate (assumes rows = experiments, columns = trials)
trials = np.arange(len(dur_mean)) + 1
trial_array = np.ones_like(all_trial_durations) * trials

iti_mean = np.nanmean(all_iti_durations, axis=0)
iti_std = np.nanstd(all_iti_durations, axis=0)


In [None]:
# Kept experiments: mean +/- std per trial (errorbars), individual durations (dots)
# and each experiment's trial count drawn at y=0 (green).
fig = plt.figure()
plt.errorbar(trials, dur_mean, yerr=dur_std, label='Trial Durations')
plt.errorbar(trials, iti_mean, yerr=iti_std, label='ITI Durations')
plt.scatter(trial_array, all_trial_durations, color='black', s=1, alpha=0.5, label='Trials')
plt.scatter(trial_array, all_iti_durations, color='red', s=1, alpha=0.5, label='ITIs')
plt.scatter(num_trials, np.zeros_like(num_trials), color='green', s=1, alpha=0.5, label='Exp. End')
# Labels are attached to the artists explicitly. A bare list passed to plt.legend is
# matched to artists by matplotlib's internal ordering, which is fragile.
plt.legend()

plt.show()

In [None]:
# File names of the rejected experiments.
# path_names flattens all_paths in the same nested order used to flatten data_list, so
# the integer indices in bad_idxs select the matching names. This assumes each path
# yields exactly one dataset -- TODO confirm.
# bad_idxs (plain ints) is used rather than bad_paths. This avoids a 2-D fancy index
# and keeps the cell re-runnable, since bad_paths is overwritten with names below.
path_names = [os.path.basename(path) for group in all_paths for path in group]
bad_paths = np.array(path_names)[bad_idxs]
print(bad_paths)

In [None]:
# Plot the rejected experiments (bad_idxs) with the same per-trial summary
all_trial_durations = []
all_iti_durations = []
num_trials = []

for i in bad_idxs:
    trial_durations, iti_durations, _, _ = get_trial_duration_stats(data_list[i], trial_key='trial_num', time_key='time_vector', display_info=False)
    num_trials.append(len(trial_durations))
    all_trial_durations.append(trial_durations.tolist())
    all_iti_durations.append(iti_durations.tolist())

if not bad_idxs:
    print('No experiments were rejected; nothing to plot.')
else:
    all_trial_durations = fmisc.list_lists_to_array(all_trial_durations)
    all_iti_durations = fmisc.list_lists_to_array(all_iti_durations)

    dur_mean = np.nanmean(all_trial_durations, axis=0)
    dur_std = np.nanstd(all_trial_durations, axis=0)

    iti_mean = np.nanmean(all_iti_durations, axis=0)
    iti_std = np.nanstd(all_iti_durations, axis=0)

    # Trial axis must come from THIS set's dur_mean, not the one left over from the
    # kept-experiment cell (which may have a different number of trials)
    trials = np.arange(len(dur_mean)) + 1
    trial_array = np.ones_like(all_trial_durations) * trials

    fig = plt.figure()
    plt.errorbar(trials, dur_mean, yerr=dur_std, label='Trial Durations')
    plt.errorbar(trials, iti_mean, yerr=iti_std, label='ITI Durations')
    plt.scatter(trial_array, all_trial_durations, color='black', s=1, alpha=0.5, label='Trials')
    plt.scatter(trial_array, all_iti_durations, color='red', s=1, alpha=0.5, label='ITIs')
    plt.scatter(num_trials, np.zeros_like(num_trials), color='green', s=1, alpha=0.5, label='Exp. End')
    plt.legend()

    plt.show()