In [None]:
%load_ext autoreload
%autoreload 2

# Imports
import numpy as np
import pandas as pd
import xarray as xr
import panel as pn
import holoviews as hv
import datashader as dshade
import ipywidgets as widgets
from ipywidgets import interact

import os
import sys
import random
import importlib
import datetime
import warnings
import math
import cmath
import pycircstat as circ
warnings.filterwarnings('ignore')

from itertools import zip_longest
from scipy.stats import sem, norm, binned_statistic, percentileofscore, ttest_1samp, ttest_ind
from scipy.optimize import least_squares

from bokeh.resources import INLINE
from bokeh.io import export_svgs, export_png
from holoviews import opts, dim
from holoviews.operation import histogram
from holoviews.operation.datashader import datashade, shade
hv.extension('bokeh')
# hv.extension('matplotlib')
import matplotlib.pyplot as plt
%matplotlib widget

sys.path.insert(0, os.path.abspath(r'C:/Users/mmccann/repos/bonhoeffer/prey_capture/'))
%aimport paths
import processing_parameters
import functions_bondjango as bd
import functions_plotting as fp
import functions_data_handling as fdh
import functions_kinematic as fk
import functions_tuning as tuning
from functions_misc import list_lists_to_array
from wirefree_experiment import WirefreeExperiment, DataContainer

importlib.reload(fp)
# Apply the project plotting theme, then override the font settings used by matplotlib
fp.set_theme()

plt.rcParams.update({"font.family": "Arial", "font.size": 15})
cm = 1 / 2.54  # inches per centimeter: multiply a length in cm by this for matplotlib's figsize

In [None]:
def spike_raster(data, cells=None):
    """Render activity as a holoviews heatmap with time on x and cell index on y.

    Args:
        data: DataFrame with a ``time_vector`` column plus one column per cell.
        cells: columns to plot; defaults to every column whose name contains 'cell'.

    Returns:
        hv.Image (width 600) with kdims ('Time (s)', 'Cells') and vdim 'Activity (a.u.)'.
    """
    selected = cells if cells is not None else [col for col in data.columns if 'cell' in col]

    # Transpose so rows are cells and columns are time samples
    activity_matrix = data.loc[:, selected].values.T
    raster = hv.Image(
        (data.time_vector, np.arange(len(selected)), activity_matrix),
        kdims=['Time (s)', 'Cells'],
        vdims=['Activity (a.u.)'],
    )
    raster.opts(width=600)  # a colormap could be added here, e.g. cmap='Purples'
    return raster

def trace_raster(data, cells=None, ds_factor=1):
    """Stack per-cell traces vertically as an overlay of holoviews curves.

    Trace ``i`` is shifted up by ``i * max_std``, where ``max_std`` is the largest
    per-cell standard deviation, so traces do not overlap much.

    Args:
        data: DataFrame with a ``time_vector`` column plus one column per cell.
        cells: columns to plot; defaults to every column whose name contains 'cell'.
        ds_factor: keep every ``ds_factor``-th sample for plotting (1 = no downsampling).
            This argument was previously accepted but ignored.

    Returns:
        hv.NdOverlay of hv.Curve objects (height 500, width 800).

    Raises:
        ValueError: if ``ds_factor`` is smaller than 1.
    """
    if cells is None:
        cells = [el for el in data.columns if 'cell' in el]
    ds_factor = int(ds_factor)
    if ds_factor < 1:
        raise ValueError(f"ds_factor must be >= 1, got {ds_factor}")

    trace = data.loc[:, cells]
    # Offset computed on the full-resolution data so spacing does not depend on ds_factor
    max_std = trace.std().max()

    time = data.time_vector.values[::ds_factor]
    sampled = trace.values[::ds_factor]
    lines = {i: hv.Curve((time, sampled[:, i] + i * max_std)) for i in range(len(cells))}
    # NOTE(review): the overlay key is the cell index, so the 'Time (s)' key label looks like a mislabel
    lineoverlay = hv.NdOverlay(lines, kdims=['Time (s)']).opts(height=500, width=800)
    return lineoverlay

def plot_fixed_ethogram(experiment, num_cells=15, seed=None, **kwargs):
    """Plot a stacked ethogram: trial marker, kinematics, then activity of randomly sampled cells.

    Rows (top to bottom):
        1. trial indicator (1 where ``trial_num > 0``)
        2. absolute wheel speed
        3. pupil diameter (``fk.jump_killer`` + ``np.unwrap``, normalized, 5-sample moving average)
        4. mean-subtracted pupil center x (black) and y (red)
        5+. one row per sampled cell: normalized dF/F (black) and normalized spikes (green)

    Args:
        experiment: object exposing ``kinematics``, ``cells``, ``norm_dff`` and ``norm_spikes``.
        num_cells: number of distinct cells to sample (capped at the number available).
        seed: optional seed for reproducible cell selection; ``None`` uses NumPy's global RNG.
        **kwargs: forwarded to ``plt.subplots`` (e.g. ``figsize``).

    Returns:
        The matplotlib Figure.
    """
    kinem = experiment.kinematics
    time = kinem.time_vector

    # Sample distinct cells. Without replace=False the same cell could be drawn twice.
    num_cells = min(num_cells, len(experiment.cells))
    rng = np.random if seed is None else np.random.default_rng(seed)
    cells = rng.choice(experiment.cells, num_cells, replace=False)

    fig, axes = plt.subplots(nrows=num_cells + 4, ncols=1, sharex=True, **kwargs)
    kine_labels = ['trials', 'run speed\n [m/s]', 'pupil dia.\n [px]', 'pupil pos.\n [px]']

    # Trial indicator: 1 while trial_num > 0, 0 otherwise
    trial_tick = np.where(kinem.trial_num > 0, 1, 0)
    axes[0].plot(time, trial_tick, 'k')

    axes[1].plot(time, np.abs(kinem.wheel_speed), 'k')
    axes[1].set_ylim((0, 1))

    smoothed_diameter = tuning.normalize(np.unwrap(fk.jump_killer(kinem.pupil_diameter, 2), 10))
    axes[2].plot(time, tuning.moving_average(smoothed_diameter, 5), 'k')
    axes[2].set_ylim((0, 1))

    # Pupil position, mean-subtracted so x and y share a zero baseline
    smoothed_x = np.unwrap(fk.jump_killer(kinem.fit_pupil_center_x, 3), 15)
    smoothed_x -= np.mean(smoothed_x)
    smoothed_y = np.unwrap(fk.jump_killer(kinem.fit_pupil_center_y, 3), 15)
    smoothed_y -= np.mean(smoothed_y)
    axes[3].plot(time, smoothed_x, 'k', time, smoothed_y, 'r')

    # Neural data from the experiment that was passed in (this previously read the global `exp`)
    for i, cell in enumerate(cells):
        ax = axes[i + 4]
        ax.plot(time, experiment.norm_dff[cell], 'k', alpha=0.7)
        ax.plot(time, experiment.norm_spikes[cell], 'g', alpha=0.5)
        ax.set_ylim((0, 1))

    # Row labels, then strip the top/right spines and all but the bottom x-axis
    axes_labels = kine_labels + list(cells)
    for i, a in enumerate(axes):
        label = axes_labels[i]
        if i >= len(kine_labels):
            label = label.replace("_", " ")
        a.set_ylabel(label.capitalize(), rotation=0, va='center', wrap=True, labelpad=25)

        a.spines["top"].set_visible(False)
        a.spines["right"].set_visible(False)
        if i < len(axes) - 1:
            a.spines["bottom"].set_visible(False)
            a.xaxis.set_tick_params(bottom=False)

    axes[-1].set_xlabel("Time (s)")
    fig.align_ylabels()

    return fig

In [None]:
importlib.reload(processing_parameters)

# Database query: project search string restricted to preprocessed VWheelWF sessions
search_string = processing_parameters.search_string + r", rig:VWheelWF, analysis_type:preprocessing"

# Get the matching preprocessing entries from the database
file_path, paths_all, parsed_query, date_list, animal_list = fdh.fetch_preprocessing(search_string)

# Keep only the entries belonging to the mouse named in the query
animal_idxs = [i for i, d in enumerate(animal_list) if d == parsed_query['mouse'].lower()]
good_entries = [file_path[index] for index in animal_idxs]
input_path = [paths_all[index] for index in animal_idxs]

print(input_path)
if not input_path:
    # Fail with a clear message instead of an IndexError on input_path[0]
    raise ValueError(f"No preprocessing entries found for query: {search_string}")

# Load the first matching session and derive dF/F plus normalized activity tables
exp = WirefreeExperiment(input_path[0], use_xarray=False)
exp.dff = tuning.calculate_dff(exp.raw_fluor)
exp.norm_spikes = tuning.normalize_responses(exp.raw_spikes)
exp.norm_fluor = tuning.normalize_responses(exp.raw_fluor)
exp.norm_dff = tuning.normalize_responses(exp.dff)

# Derive head angles from the mouse rotation columns when they are not already present.
# NOTE(review): range=(-180, 180) suggests degrees, but discont=2*np.pi is a radian-sized
# threshold. Confirm which units fk.smooth_trace expects.
if 'head_pitch' not in exp.kinematics.columns:
    pitch = -fk.wrap_negative(exp.kinematics.mouse_xrot_m.values)  # sign flipped for pitch only
    exp.kinematics['head_pitch'] = fk.smooth_trace(pitch, range=(-180, 180), kernel_size=10, discont=2*np.pi)

    yaw = fk.wrap_negative(exp.kinematics.mouse_zrot_m.values)
    exp.kinematics['head_yaw'] = fk.smooth_trace(yaw, range=(-180, 180), kernel_size=10, discont=2*np.pi)

    roll = fk.wrap_negative(exp.kinematics.mouse_yrot_m.values)
    exp.kinematics['head_roll'] = fk.smooth_trace(roll, range=(-180, 180), kernel_size=10, discont=2*np.pi)

In [None]:
def get_trial_duration_stats(df, trial_key, time_key):
    """Summarize how long trials last.

    Rows with ``df[trial_key] <= 0`` are ignored. Each remaining trial's duration is its
    last ``time_key`` value minus its first. The min, max and mean durations are printed
    and returned.

    Args:
        df: DataFrame containing ``trial_key`` and ``time_key`` columns.
        trial_key: column holding the trial number.
        time_key: column holding the timestamps.

    Returns:
        np.ndarray ``[min, max, mean]`` of the per-trial durations.
    """
    in_trial = df.loc[df[trial_key] > 0]
    durations = in_trial.groupby(trial_key).apply(
        lambda trial: trial[time_key].to_list()[-1] - trial[time_key].to_list()[0]
    )
    stats = (durations.min(), durations.max(), durations.mean())
    print(*stats)
    return np.array(stats)

In [None]:
# [min, max, mean] trial duration of the normalized spike table (also printed by the helper)
trial_duration_stats = get_trial_duration_stats(exp.norm_spikes, 'trial_num', 'time_vector')

In [None]:
%matplotlib widget
fig = plot_fixed_ethogram(exp, num_cells=10, figsize=(15 * cm, 15 * cm))
fig.tight_layout()

# NOTE: absolute, user-specific output folder; adjust before saving on another machine
poster_dir = r'C:\Users\mmccann\Dropbox\bonhoeffer lab\SFN 2023\poster'
save_path = os.path.join(poster_dir, 'fixed_ethogram.png')
# fp.save_figure(fig, fig_width=8, save_path=save_path, fontsize='poster', target='both')
# fig.savefig(save_path, dpi=600)

## Filter by running speed
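Check the number of surviving trials (how many are left after filtering?). Is there running modulation? Bin activity by running speed. How well does speed explain each cell's activity on a given trial? If the correlation coefficients (ccs) are high, that is a problem: running rather than the stimulus may be driving the responses.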

Split the tuning-curve trials into 80% train / 20% test. Calculate the fit on the training trials and compute R² on the test trials. This measures response reliability (goodness of fit), as opposed to the robustness to noise estimated by bootstrapping. **TODO:** explain clearly why these two measures differ.

In [None]:
# Get an idea of the cutoff to be considered still
exp.kinematics['wheel_speed_abs'] = np.abs(exp.kinematics['wheel_speed'])
speed_cutoff = np.percentile(exp.kinematics.wheel_speed_abs, 80)
exp.kinematics['is_running'] = exp.kinematics.wheel_speed_abs >= speed_cutoff

In [None]:
# Visual check of the running threshold: absolute wheel speed with the cutoff drawn in red
fig, ax = plt.subplots()
ax.plot(exp.kinematics.wheel_speed_abs)
# axhline spans the whole x-range. hlines(0, len(...)) assumed the index runs 0..N-1.
ax.axhline(speed_cutoff, color='r')
ax.set(xlabel='Sample', ylabel='|Wheel speed| (m/s)', title='Running speed cutoff')
plt.show()

In [None]:
# Trials whose mean absolute wheel speed stays below the running cutoff count as "still".
# NOTE(review): trial_num 0 (frames outside trials, judging by the `trial_num > 0` checks
# elsewhere) is also kept when its mean speed is low. Confirm that this is intended.
trial_mean_speed = exp.kinematics.groupby('trial_num')['wheel_speed_abs'].transform(lambda speeds: speeds.mean())
still_trials = exp.kinematics.loc[trial_mean_speed < speed_cutoff, 'trial_num'].unique()
exp.norm_spikes_still = exp.norm_spikes.loc[exp.norm_spikes.trial_num.isin(still_trials)]
still_spikes = exp.norm_spikes_still

In [None]:
%%time
datasets = ['norm_spikes_still']  # other options: 'raw_spikes', 'norm_spikes'
# NOTE(review): `calculate_visual_tuning` is not imported or defined anywhere visible in this
# notebook, so on a fresh kernel this cell raises NameError unless it comes from elsewhere. Confirm its source.
for dataset in datasets:
    activity = getattr(exp, dataset)
    for tuning_type in ('direction_wrapped', 'orientation'):
        tuning_label = tuning_type.split('_')[0]  # 'direction' or 'orientation'
        props = calculate_visual_tuning(activity, tuning_type, bootstrap_shuffles=1)
        # Stored as e.g. exp.norm_spikes_still_direction_props
        setattr(exp, f'{dataset}_{tuning_label}_props', props)

In [None]:
%matplotlib inline
# Interactive browser over cells / datasets for the tuning curves.
# NOTE(review): `plot_visual_tuning_curves` is not imported or defined anywhere visible in this
# notebook. Confirm where it comes from.
interact(
    plot_visual_tuning_curves,
    experiment=widgets.fixed(exp),
    cell=exp.cells,
    data_type=datasets,
    error=['std', 'sem'],
    polar=[True, False],
    norm=widgets.fixed(True),
    axes=widgets.fixed(None),
)

In [None]:
# Number of distinct presentations (trials) per orientation.
# nunique() ignores NaN trial labels; unique().size counted NaN as an extra presentation.
exp.dff.groupby('orientation')['trial_num'].nunique()

# Cell responses by orientation

In [None]:
# Raw fluorescence by orientation: one panel per (cell, orientation) showing every trial (blue)
# and the trial mean (black), truncated to the first 100 samples.
ds = exp.raw_fluor.dropna()  # NaN-free copy; dropna(inplace=True) used to mutate exp.raw_fluor
cells = [col for col in ds if 'cell' in col]
# Per orientation, stack each cell's trials into a (trials x samples) array
trials_per_ori = (
    ds.groupby(['orientation', 'trial_num']).agg(list)
    .groupby(['orientation'])[cells].agg(list)
    .applymap(list_lists_to_array)
    .reset_index()
)
# Drop the -1000 orientation sentinel (excluded throughout this notebook)
trials_per_ori = trials_per_ori[trials_per_ori.orientation != -1000].set_index('orientation')
trials_averages = trials_per_ori[cells].applymap(np.nanmean, axis=0)

plot_list = []

for i, cell in enumerate(cells):
    for orientation, resps, mean in zip(trials_per_ori.index, trials_per_ori[cell], trials_averages[cell]):
        # x-axis built from the sliced mean, so traces shorter than 100 samples still plot
        mean_window = mean[:100]
        trials_list = [hv.Curve(trial[:100]).opts(color='blue', alpha=0.25) for trial in resps]
        mean_curve = hv.Curve((np.arange(len(mean_window)), mean_window)).opts(
            color='k', xticks=[(0, 0), (50, 2.5)], xlabel='', ylabel='', fontsize={'xticks': 10, 'yticks': 10})
        if i == 0:
            # Only the top row of the grid is titled with its orientation
            mean_curve.opts(title=f"{orientation:.1f}")
        plot_list.append(hv.Overlay(trials_list + [mean_curve]).opts(width=120, height=150))

raw_fluor_layout = hv.Layout(plot_list).cols(12)
raw_fluor_layout

In [None]:
# Normalized dF/F by orientation: one panel per (cell, orientation) showing every trial (blue)
# and the trial mean (black), truncated to the first 100 samples.
ds = exp.norm_dff.dropna()  # NaN-free copy; dropna(inplace=True) used to mutate exp.norm_dff
cells = [col for col in ds if 'cell' in col]
# Per orientation, stack each cell's trials into a (trials x samples) array
trials_per_ori = (
    ds.groupby(['orientation', 'trial_num']).agg(list)
    .groupby(['orientation'])[cells].agg(list)
    .applymap(list_lists_to_array)
    .reset_index()
)
# Drop the -1000 orientation sentinel (excluded throughout this notebook)
trials_per_ori = trials_per_ori[trials_per_ori.orientation != -1000].set_index('orientation')
trials_averages = trials_per_ori[cells].applymap(np.nanmean, axis=0)

plot_list = []

for i, cell in enumerate(cells):
    for orientation, resps, mean in zip(trials_per_ori.index, trials_per_ori[cell], trials_averages[cell]):
        # x-axis built from the sliced mean, so traces shorter than 100 samples still plot
        mean_window = mean[:100]
        trials_list = [hv.Curve(trial[:100]).opts(color='blue', alpha=0.25) for trial in resps]
        mean_curve = hv.Curve((np.arange(len(mean_window)), mean_window)).opts(
            color='k', xticks=[(0, 0), (50, 2.5)], xlabel='', ylabel='', fontsize={'xticks': 10, 'yticks': 10})
        if i == 0:
            # Only the top row of the grid is titled with its orientation
            mean_curve.opts(title=f"{orientation:.1f}")
        plot_list.append(hv.Overlay(trials_list + [mean_curve]).opts(width=120, height=150))

norm_dff_layout = hv.Layout(plot_list).cols(12)
norm_dff_layout

In [None]:
# dF/F by orientation: one panel per (cell, orientation) showing every trial (blue)
# and the trial mean (black), truncated to the first 100 samples.
ds = exp.dff.dropna()  # NaN-free copy; dropna(inplace=True) used to mutate exp.dff
cells = [col for col in ds if 'cell' in col]
# Per orientation, stack each cell's trials into a (trials x samples) array
trials_per_ori = (
    ds.groupby(['orientation', 'trial_num']).agg(list)
    .groupby(['orientation'])[cells].agg(list)
    .applymap(list_lists_to_array)
    .reset_index()
)
# Drop the -1000 orientation sentinel (excluded throughout this notebook)
trials_per_ori = trials_per_ori[trials_per_ori.orientation != -1000].set_index('orientation')
trials_averages = trials_per_ori[cells].applymap(np.nanmean, axis=0)

plot_list = []

for i, cell in enumerate(cells):
    for orientation, resps, mean in zip(trials_per_ori.index, trials_per_ori[cell], trials_averages[cell]):
        # x-axis built from the sliced mean, so traces shorter than 100 samples still plot
        mean_window = mean[:100]
        trials_list = [hv.Curve(trial[:100]).opts(color='blue', alpha=0.25) for trial in resps]
        mean_curve = hv.Curve((np.arange(len(mean_window)), mean_window)).opts(
            color='k', xticks=[(0, 0), (50, 2.5)], xlabel='', ylabel='', fontsize={'xticks': 10, 'yticks': 10})
        if i == 0:
            # Only the top row of the grid is titled with its orientation
            mean_curve.opts(title=f"{orientation:.1f}")
        plot_list.append(hv.Overlay(trials_list + [mean_curve]).opts(width=120, height=150))

dff_layout = hv.Layout(plot_list).cols(12)
dff_layout

In [None]:
# Normalized spikes by orientation: one panel per (cell, orientation) showing every trial (blue)
# and the trial mean (black), truncated to the first 100 samples.
ds = exp.norm_spikes.dropna()  # NaN-free copy; dropna(inplace=True) used to mutate exp.norm_spikes
cells = [col for col in ds if 'cell' in col]
# Per orientation, stack each cell's trials into a (trials x samples) array
trials_per_ori = (
    ds.groupby(['orientation', 'trial_num']).agg(list)
    .groupby(['orientation'])[cells].agg(list)
    .applymap(list_lists_to_array)
    .reset_index()
)
# Drop the -1000 orientation sentinel (excluded throughout this notebook)
trials_per_ori = trials_per_ori[trials_per_ori.orientation != -1000].set_index('orientation')
trials_averages = trials_per_ori[cells].applymap(np.nanmean, axis=0)

plot_list = []

for i, cell in enumerate(cells):
    for orientation, resps, mean in zip(trials_per_ori.index, trials_per_ori[cell], trials_averages[cell]):
        # x-axis built from the sliced mean, so traces shorter than 100 samples still plot
        mean_window = mean[:100]
        trials_list = [hv.Curve(trial[:100]).opts(color='blue', alpha=0.25) for trial in resps]
        mean_curve = hv.Curve((np.arange(len(mean_window)), mean_window)).opts(
            color='k', xticks=[(0, 0), (50, 2.5)], xlabel='', ylabel='', fontsize={'xticks': 10, 'yticks': 10})
        if i == 0:
            # Only the top row of the grid is titled with its orientation
            mean_curve.opts(title=f"{orientation:.1f}")
        plot_list.append(hv.Overlay(trials_list + [mean_curve]).opts(width=120, height=150))

norm_spikes_layout = hv.Layout(plot_list).cols(12)
norm_spikes_layout

In [None]:
# Raw spikes by orientation: one panel per (cell, orientation) showing every trial (blue)
# and the trial mean (black), truncated to the first 100 samples.
ds = exp.raw_spikes.dropna()  # NaN-free copy; dropna(inplace=True) used to mutate exp.raw_spikes
cells = [col for col in ds if 'cell' in col]
# Per orientation, stack each cell's trials into a (trials x samples) array
trials_per_ori = (
    ds.groupby(['orientation', 'trial_num']).agg(list)
    .groupby(['orientation'])[cells].agg(list)
    .applymap(list_lists_to_array)
    .reset_index()
)
# Drop the -1000 orientation sentinel (excluded throughout this notebook)
trials_per_ori = trials_per_ori[trials_per_ori.orientation != -1000].set_index('orientation')
trials_averages = trials_per_ori[cells].applymap(np.nanmean, axis=0)

plot_list = []

for i, cell in enumerate(cells):
    for orientation, resps, mean in zip(trials_per_ori.index, trials_per_ori[cell], trials_averages[cell]):
        # x-axis built from the sliced mean, so traces shorter than 100 samples still plot
        mean_window = mean[:100]
        trials_list = [hv.Curve(trial[:100]).opts(color='blue', alpha=0.25) for trial in resps]
        mean_curve = hv.Curve((np.arange(len(mean_window)), mean_window)).opts(
            color='k', xticks=[(0, 0), (50, 2.5)], xlabel='', ylabel='', fontsize={'xticks': 10, 'yticks': 10})
        if i == 0:
            # Only the top row of the grid is titled with its orientation
            mean_curve.opts(title=f"{orientation:.1f}")
        plot_list.append(hv.Overlay(trials_list + [mean_curve]).opts(width=120, height=150))

raw_spikes_layout = hv.Layout(plot_list).cols(12)
raw_spikes_layout

# Cell responses by visual stimulus

In [None]:
# Normalized dF/F by visual stimulus, aggregated per whole trial (trial_num only):
# one panel per (cell, orientation) with every trial (blue) and the mean (black), first 100 samples.
ds = exp.norm_dff.dropna()  # NaN-free copy; dropna(inplace=True) used to mutate exp.norm_dff
cells = [col for col in ds if 'cell' in col]
# Aggregate each trial once: cell traces become per-trial lists, and the trial gets a single
# orientation. Aggregating 'orientation' with `list` (as before) produced an unhashable column,
# so the following groupby raised TypeError. -1000 is the smallest orientation value, so 'max'
# picks the stimulus orientation even if the trial also contains -1000 frames.
# NOTE(review): assumes each trial shows a single stimulus orientation. Confirm.
per_trial = ds.groupby('trial_num').agg({**{cell: list for cell in cells}, 'orientation': 'max'})
trials_per_ori = (
    per_trial.groupby('orientation')[cells].agg(list)
    .applymap(list_lists_to_array)
)
# Drop the -1000 orientation sentinel (excluded throughout this notebook)
trials_per_ori = trials_per_ori[trials_per_ori.index != -1000]
trials_averages = trials_per_ori[cells].applymap(np.nanmean, axis=0)

plot_list = []

for i, cell in enumerate(cells):
    for orientation, resps, mean in zip(trials_per_ori.index, trials_per_ori[cell], trials_averages[cell]):
        # x-axis built from the sliced mean, so traces shorter than 100 samples still plot
        mean_window = mean[:100]
        trials_list = [hv.Curve(trial[:100]).opts(color='blue', alpha=0.25) for trial in resps]
        mean_curve = hv.Curve((np.arange(len(mean_window)), mean_window)).opts(
            color='k', xticks=[(0, 0), (50, 2.5)], xlabel='', ylabel='', fontsize={'xticks': 10, 'yticks': 10})
        if i == 0:
            # Only the top row of the grid is titled with its orientation
            mean_curve.opts(title=f"{orientation:.1f}")
        plot_list.append(hv.Overlay(trials_list + [mean_curve]).opts(width=120, height=150))

norm_dff_layout = hv.Layout(plot_list).cols(12)
norm_dff_layout

# Visual stimulus triggered responses

In [None]:
# Stimulus-triggered traces: every trial (trial_num > 0) padded by 2 s on each side,
# individual trials in blue, the mean in red.
ds = exp.norm_dff.dropna()  # NaN-free copy; dropna(inplace=True) used to mutate exp.norm_dff

frame_rate = processing_parameters.wf_frame_rate    # fps
pad = int(round(2 * frame_rate))  # 2 s in frames; int so the window arithmetic stays integral
cells = [col for col in ds if 'cell' in col]
trial_idx_frames = ds[ds.trial_num > 0.0].groupby(['trial_num']).apply(lambda x: [x.index[0], x.index[-1]])
trial_idx_frames = np.array(trial_idx_frames.to_list(), dtype=int).reshape(-1, 2)
if len(trial_idx_frames) == 0:
    raise ValueError("No trials with trial_num > 0 found")

trial_idx_frames[:, 0] -= pad
trial_idx_frames[:, 1] += pad
# Clamp every window (not only the first/last) to the recorded index range
trial_idx_frames = np.clip(trial_idx_frames, 0, ds.index[-1])

traces = []
for i, (start, stop) in enumerate(trial_idx_frames):
    # The window bounds are index labels. After dropna, labels no longer equal row positions,
    # so select by label rather than with iloc.
    ds_slice = ds[(ds.index >= start) & (ds.index < stop)].copy()
    ds_slice['trial'] = i
    traces.append(ds_slice)

traces = pd.concat(traces, axis=0).reset_index(drop=True)
traces = traces.groupby('trial').agg(list)

# Tick/marker positions derived from the frame rate (40/90/140 frames at 20 fps).
# The second vline marks onset + 5 s (presumably the stimulus offset).
onset = pad
stim_end = onset + int(round(5 * frame_rate))
xticks = [(0, -2), (onset, 0), (onset + int(round(2.5 * frame_rate)), 2.5), (stim_end, 5)]

plot_list = []

for cell in cells:
    resps = list_lists_to_array(traces[cell].tolist(), prepend=True)
    mean = np.nanmean(resps, axis=0)

    trials_list = [hv.Curve(trial).opts(color='blue', alpha=0.15) for trial in resps]
    mean_list = [hv.Curve((np.arange(mean.shape[0]), mean)).opts(
        color='red', alpha=0.5, xticks=xticks, xlabel='', ylabel='', title=cell,
        fontsize={'xticks': 10, 'yticks': 10})]
    vlines = [hv.VLine(onset).opts(color='k', line_width=1), hv.VLine(stim_end).opts(color='k', line_width=1)]
    plot_list.append(hv.Overlay(trials_list + mean_list + vlines).opts(width=150, height=150))

vis_trig_avg = hv.Layout(plot_list).cols(10)
vis_trig_avg

# Running onset triggered responses

In [None]:
def consecutive(data, stepsize=1):
    """Split a 1-D array into runs whose neighbouring gaps are at most ``stepsize``.

    A new run starts wherever ``data[k+1] - data[k] > stepsize``.

    Args:
        data: 1-D array, e.g. sorted frame indices.
        stepsize: largest gap allowed between neighbours inside one run.

    Returns:
        list of sub-arrays (one per run), in the original order.
    """
    gaps = np.diff(data)
    break_points = np.nonzero(gaps > stepsize)[0] + 1  # index of the first element of each new run
    return np.split(data, break_points)

In [None]:
# Running-onset-triggered traces: windows from 2 s before to 5 s after each running-bout onset.
ds = exp.norm_dff.dropna()  # NaN-free copy; dropna(inplace=True) used to mutate exp.norm_dff
cells = [col for col in ds if 'cell' in col]  # defined here instead of inherited from an earlier cell

frame_rate = processing_parameters.wf_frame_rate  # fps
running_frames = exp.kinematics[exp.kinematics.is_running > 0].index.to_numpy()
# Group running frames into bouts: gaps of up to `min_duration` frames are bridged, and bouts
# with fewer than `min_duration` running frames are discarded.
min_duration = frame_rate * 1
bouts = [bout for bout in consecutive(running_frames, stepsize=min_duration) if len(bout) >= min_duration]
if not bouts:
    raise ValueError("No running bouts of at least 1 s found")

# Integer frame offsets (frame_rate may be a float)
pre_frames = int(round(2 * frame_rate))
post_frames = int(round(5 * frame_rate))
bouts = np.array([[bout[0] - pre_frames, bout[0] + post_frames] for bout in bouts], dtype=int)
bouts = np.clip(bouts, 0, ds.index[-1])

traces = []
for i, (start, stop) in enumerate(bouts):
    # Select by index label, since iloc misaligns once dropna has removed rows.
    # NOTE(review): assumes exp.kinematics and exp.norm_dff share the same frame index labels.
    ds_slice = ds[(ds.index >= start) & (ds.index < stop)].copy()
    ds_slice['trial'] = i
    traces.append(ds_slice)

traces = pd.concat(traces, axis=0).reset_index(drop=True)
traces = traces.groupby('trial').agg(list)

# Tick/marker positions derived from the frame rate (40/90/140 frames at 20 fps)
onset = pre_frames
window_end = pre_frames + post_frames
xticks = [(0, -2), (onset, 0), (onset + int(round(2.5 * frame_rate)), 2.5), (window_end, 5)]

plot_list = []

for cell in cells:
    resps = list_lists_to_array(traces[cell].tolist(), prepend=True)
    mean = np.nanmean(resps, axis=0)

    trials_list = [hv.Curve(trial).opts(color='blue', alpha=0.15) for trial in resps]
    mean_list = [hv.Curve((np.arange(mean.shape[0]), mean)).opts(
        color='red', xticks=xticks, xlabel='', ylabel='', title=cell,
        fontsize={'xticks': 10, 'yticks': 10})]
    vlines = [hv.VLine(onset).opts(color='k', line_width=2), hv.VLine(window_end).opts(color='k', line_width=2)]
    plot_list.append(hv.Overlay(trials_list + mean_list + vlines).opts(width=120, height=150))

running_trig_avg = hv.Layout(plot_list).cols(10)
running_trig_avg

In [None]:
# Mean activity per trial over the 5 s following each stimulus onset, grouped by wrapped direction.
ds = exp.norm_dff.dropna()  # NaN-free copy; dropna(inplace=True) used to mutate exp.norm_dff

frame_rate = processing_parameters.wf_frame_rate    # fps
stim_frames = int(round(5 * frame_rate))  # 5 s window in frames (int so index arithmetic works)
cells = [col for col in ds if 'cell' in col]
# Window per trial: [first stimulus frame, first stimulus frame + 5 s), using index labels
trial_idx_frames = ds[ds.direction > -1000].groupby(['trial_num']).apply(lambda x: [x.index[0], x.index[0] + stim_frames])
trial_idx_frames = np.array(trial_idx_frames.to_list(), dtype=int).reshape(-1, 2)
trial_idx_frames = np.clip(trial_idx_frames, 0, ds.index[-1])

traces = []
for i, (start, stop) in enumerate(trial_idx_frames):
    # Select by index label: after dropna, labels no longer equal row positions
    ds_slice = ds[(ds.index >= start) & (ds.index < stop)].copy()
    ds_slice['trial_num'] = i + 1  # renumber the windows consecutively from 1
    traces.append(ds_slice)

traces = pd.concat(traces, axis=0).reset_index(drop=True)
mean_trial_activity = traces.groupby(['direction_wrapped', 'trial_num']).agg(np.nanmean).reset_index()
# Drop frames labelled with the -1000 direction sentinel and order by trial
mean_trial_activity = mean_trial_activity[mean_trial_activity.direction_wrapped != -1000].sort_values('trial_num')

In [None]:
# Wrapped directions remaining after dropping the -1000 sentinel
mean_trial_activity['direction_wrapped'].unique()

# Check trials

In [None]:
# NaN-free working copy of the normalized dF/F. dropna(inplace=True) used to mutate exp.norm_dff,
# which desynchronized it from exp.kinematics for later plots.
ds = exp.norm_dff.dropna()
cells = [col for col in ds if 'cell' in col]

In [None]:
%matplotlib inline
# Trial timing from the time vector: first and last timestamp of every trial (trial_num > 0)
ds_drop = ds[ds.trial_num > 0].copy()
by_trial = ds_drop.groupby(['trial_num'])
starts = by_trial.apply(lambda x: x['time_vector'].to_list()[0])
stops = by_trial.apply(lambda x: x['time_vector'].to_list()[-1])
trial_duration = stops - starts
start_stop_trial = pd.Series([[t_start, t_stop] for t_start, t_stop in zip(starts, stops)], index=starts.index)
# Number of trials lasting at most 5, out of the total number of trials
print(np.sum(trial_duration <= 5), len(trial_duration))

In [None]:
# Per-trial duration: last minus first time_vector value of each trial (trial_num > 0)
trial_duration

In [None]:
# Nominal stimulus window per trial: the first stimulus frame (direction > -1000) and the frame 5 s later.
# Use the configured frame rate like the cells above. A hard-coded 20 silently overwrote the global frame_rate.
frame_rate = processing_parameters.wf_frame_rate    # fps
stim_frames = int(round(5 * frame_rate))  # 5 s in frames, int so it can index a position
cells = [col for col in ds if 'cell' in col]

stim_times = ds[ds.direction > -1000].groupby(['trial_num'])['time_vector']
true_trial_starts = stim_times.apply(lambda t: t.iloc[0])
# Trials with fewer than stim_frames + 1 stimulus frames give NaN instead of raising IndexError
true_trial_stops = stim_times.apply(lambda t: t.iloc[stim_frames] if len(t) > stim_frames else np.nan)
true_trial_duration = true_trial_stops - true_trial_starts
true_trial_times = pd.Series(
    [[t_start, t_stop] for t_start, t_stop in zip(true_trial_starts, true_trial_stops)],
    index=true_trial_starts.index,
)

In [None]:
# Intervals between consecutive nominal stimulus-window ends
np.diff(true_trial_stops)

In [None]:
# Intervals between consecutive stimulus onsets
np.diff(true_trial_starts)

In [None]:
# Per-trial nominal stimulus-window duration (window end minus onset, in time_vector units)
true_trial_duration