# Preamble

## Import Packages

In [None]:
import os
import datetime
from tqdm.notebook import tqdm
import numpy as np
import scipy as sp
import quantities as pq
import elephant
import pandas as pd
import statsmodels.api as sm
import neurotic
from neurotic.datasets.data import _detect_spikes, _find_bursts
from modules.utils import BehaviorsDataFrame, DownsampleNeoSignal
from modules import r_stats
from modules.plot_utils import add_scalebar, solve_figure_horizontal_dimensions, solve_figure_vertical_dimensions, plot_signals_with_scalebars

pq.markup.config.use_unicode = True  # allow symbols like mu for micro in output
# register a millinewton unit so that signals can be rescaled with .rescale('mN')
# (the force signal is rescaled this way when the data are processed)
pq.mN = pq.UnitQuantity('millinewton', pq.N/1e3, symbol = 'mN');  # define millinewton

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.patches as patches
import matplotlib.lines as mlines
from matplotlib.markers import CARETLEFT, CARETRIGHT, CARETUP, CARETDOWN, CARETUPBASE
import seaborn as sns

In [None]:
import warnings

# don't warn about invalid comparisons to NaN
np.seterr(invalid='ignore')

# np.nanmax raises a warning if all values are NaN and returns NaN, which is the behavior we want
warnings.filterwarnings('ignore', message='All-NaN slice encountered')

# elephant.statistics.instantaneous_rate always complains about negative values
warnings.filterwarnings('ignore', message='Instantaneous firing rate approximation contains '
                                          'negative values, possibly caused due to machine '
                                          'precision errors')

# with matplotlib>=3.1 and seaborn<=0.9.0, deprecation warnings are raised
# whenever tick marks are placed on the right axis but not the left
# - newer matplotlib exposes the warning class at the package top level and has
#   removed it from matplotlib.cbook, so try the top-level location first and
#   fall back to cbook for older releases
try:
    from matplotlib import MatplotlibDeprecationWarning
except ImportError:
    from matplotlib.cbook import MatplotlibDeprecationWarning
warnings.simplefilter('ignore', MatplotlibDeprecationWarning)

# don't complain about opening too many figures
plt.rcParams.update({'figure.max_open_warning': 0})

In [None]:
# make figures interactive and inline
# - NOTE(review): the "notebook" (nbagg) backend targets the classic Jupyter Notebook
#   interface; under JupyterLab, `%matplotlib widget` (requires ipympl) is the usual
#   equivalent -- confirm which frontend this notebook is meant to run in
%matplotlib notebook

## Plot Settings

In [None]:
# output directory for files exported by later cells (created relative to the working directory)
# - makedirs with exist_ok=True is idempotent across re-runs and avoids the
#   check-then-create race; it raises if the path exists but is not a directory,
#   which would otherwise only surface later when exporting
export_dir = 'motor-pattern-exemplars'
os.makedirs(export_dir, exist_ok=True)

In [None]:
# general plot settings: seaborn "ticks" style at the default font scale, with
# Palatino Linotype listed first and a generic serif family after it
sns.set(style='ticks', font_scale=1, font=['Palatino Linotype', 'serif'])

In [None]:
# color assigned to each unit (plus the force trace) in figures
# - insertion order is the order in which the palette is displayed by the cells
#   that iterate over unit_colors
unit_colors = {
    'B38':       '#EFBF46', # yellow
    'I2':        '#DC5151', # red
    'B8a/b':     '#DA8BC3', # pink
    'B6/B9':     '#64B5CD', # light blue
    'B3/B6/B9':  '#5A9BC5', # medium blue
    'B3':        '#4F80BD', # dark blue
    'B4/B5':     '#00A86B', # jade green
    'Force':     '#1A1A1A', # very dark gray
}
# colors for phases of the force trace; most reuse a unit color from above,
# while 'drop' uses a neutral gray
force_colors = {
    'shoulder':     unit_colors['B38'],
    'dip':          unit_colors['I2'],
    'initial rise': unit_colors['B8a/b'],
    'rise':         unit_colors['B8a/b'],
    'plateau':      unit_colors['B3/B6/B9'],
    'drop':         '#808080', # medium gray
}

In [None]:
# display the selected unit colors as a row of swatches, each labeled in bold
# white text with its unit name and its hex code (without the leading '#')
sns.palplot(unit_colors.values())
with plt.rc_context({'font.weight': 'bold'}):  # bold applies only to these annotations
    for i, (unit, color) in enumerate(unit_colors.items()):
        # label i is placed at data coordinates (i, 0), i.e. on the i-th swatch
        plt.gca().annotate(f'{unit}\n{color[1:]}', xy=(i, 0), ha='center', va='center', color='w')
plt.axis('off')
plt.subplots_adjust(0, 0, 1, 1)  # let the swatches fill the entire figure

In [None]:
# list every unit with its color as an uppercase hex code, names padded to a 10-character column
for unit, color in unit_colors.items():
    print(f'{unit:<10} {mcolors.to_hex(color).upper()}')

In [None]:
# print a link to a colorblindness simulator preloaded with the unit colors
# ('#' is URL-encoded as '%23' and the colors are separated by '-')
print('https://davidmathlogic.com/colorblind/#' + '-'.join(
    mcolors.to_hex(c).upper().replace('#', '%23') for c in unit_colors.values()
))

## Data Parameters

In [None]:
# neurotic metadata file listing the data sets; relative path, resolved against the working directory
metadata_file = '../../data/metadata.yml'

In [None]:
# units whose activity is analyzed; each name also has an entry in unit_colors and is
# used to build per-unit column names such as '<unit> spike train'
units = [
    'B38',
    'I2',
    'B8a/b',
    'B6/B9',
    'B3',
    'B4/B5',
]

In [None]:
# physical units of each recorded channel, listed in the same order as the
# channel-name lists in channel_names_by_animal
channel_units = [
    'uV', # I2
    'uV', # RN
    'uV', # BN2
    'uV', # BN3
    'mN', # Force
]

In [None]:
# channel names for each animal, ordered as I2, RN, BN2, BN3, Force (matching channel_units)
# - only JG12 is currently enabled; the other animals are kept commented out for reference
# - every animal that appears in feeding_bouts needs an entry here, since the
#   processing loop looks up channel_names_by_animal[animal] for each bout
channel_names_by_animal = {
#     'JG07': ['I2-L', 'RN-L', 'BN2-L', 'BN3-L',    'Force'],
#     'JG08': ['I2',   'RN',   'BN2',   'BN3',      'Force'],
#     'JG11': ['I2',   'RN',   'BN2',   'BN3-PROX', 'Force'],
    'JG12': ['I2',   'RN',   'BN2',   'BN3-DIST', 'Force'],
#     'JG14': ['I2',   'RN',   'BN2',   'BN3-PROX', 'Force'],
}

In [None]:
# epoch types that identify each behavior; these strings are matched exactly against
# the epochs' 'Type' values in the behavior query (Type in {epoch_types})
epoch_types_by_food = {
    'Bite':      ['Bite (regular 5-cm nori strip)'],
    'Swallow':   ['Swallow (tape nori)'],
    'Rejection': ['Rejection (tubing)'],
}

In [None]:
feeding_bouts = {
    # (animal, food, bout_index): (data_set_name, time_window)
    # - data_set_name is selected from the metadata file
    # - time_window is [start, end] in seconds; only behaviors that start and end
    #   inside it are analyzed

    ('JG12', 'Bite',      0): ('IN VIVO / JG12 / 2019-05-10 / 001', [2170.5, 2174.2]), # 1 bite
    ('JG12', 'Swallow',   0): ('IN VIVO / JG12 / 2019-05-10 / 002', [2977.3, 2984.5]), # 1 swallow
    ('JG12', 'Rejection', 0): ('IN VIVO / JG12 / 2019-05-10 / 002', [1429.3, 1441.3]), # 1 rejection
}

## Helper Functions

In [None]:
def query_union(queries):
    """Join query strings with '|' (logical OR), parenthesizing each and skipping None entries."""
    kept = (query for query in queries if query is not None)
    return '|'.join(f'({query})' for query in kept)

def label2query(label):
    """Build a DataFrame query selecting rows for the animal and food encoded in ``label``.

    The label is split by position: the first four characters are the animal ID and
    everything after the fifth character (a separator) is the food name.
    """
    animal = label[:4]
    food = label[5:]
    return f'(Animal == "{animal}") & (Food == "{food}")'

def contains(series, string):
    """Return a Series of booleans telling whether ``string`` occurs in each element of ``series``."""
    def _has_substring(element):
        return string in element
    return series.map(_has_substring)

In [None]:
def get_sig(blk, channel):
    """Return the first analog signal named ``channel`` in the block's first segment.

    Raises
    ------
    Exception
        If no analog signal in ``blk.segments[0]`` has that name.
    """
    for candidate in blk.segments[0].analogsignals:
        if candidate.name == channel:
            return candidate
    raise Exception(f'Channel "{channel}" could not be found')

In [None]:
def get_sig_index(blk, channel):
    """Return the position of the first analog signal named ``channel`` in the block's first segment.

    Raises
    ------
    Exception
        If no analog signal in ``blk.segments[0]`` has that name.
    """
    for position, candidate in enumerate(blk.segments[0].analogsignals):
        if candidate.name == channel:
            return position
    raise Exception(f'Channel "{channel}" could not be found')

In [None]:
def apply_filters(blk, metadata):
    """Apply the Butterworth filters listed in ``metadata['filters']`` to ``blk``'s signals.

    Nearly identical to neurotic's implementation, except that ``time_slice`` is
    used so that lazily loaded (proxy) signals are actually loaded before filtering.

    Parameters
    ----------
    blk : neo Block
        Signals are looked up by name in ``blk.segments[0].analogsignals`` and
        replaced in that list by their filtered versions.
    metadata : mapping
        ``metadata['filters']`` is an iterable of dicts with a ``'channel'`` key and
        optional ``'highpass'`` / ``'lowpass'`` cutoffs (numbers, interpreted as Hz).
        It may also be None or empty, meaning there is nothing to filter.

    Returns
    -------
    The same ``blk`` object, modified in place.

    Raises
    ------
    Exception
        If a filter names a channel that does not exist (raised by get_sig_index).
    """
    # a data set without filters can have filters set to None; treat it as "no filters"
    # instead of failing with "'NoneType' object is not iterable"
    for sig_filter in metadata['filters'] or []:
        index = get_sig_index(blk, sig_filter['channel'])
        high = sig_filter.get('highpass', None)
        low  = sig_filter.get('lowpass',  None)
        # attach units to any cutoff that was given
        if high:
            high *= pq.Hz
        if low:
            low  *= pq.Hz
        signals = blk.segments[0].analogsignals
        signals[index] = elephant.signal_processing.butter(  # may raise a FutureWarning
            signal = signals[index].time_slice(None, None),  # time_slice ensures proxies are loaded
            highpass_freq = high,
            lowpass_freq  = low,
        )

    return blk

In [None]:
def is_good_burst(burst):
    """Decide whether a burst is worth keeping.

    ``burst`` is unpacked as ``(time, duration, n_spikes)``; presumably ``duration``
    is a quantities value with time units (it is compared against ``0.5*pq.s``).
    A burst qualifies when it lasts at least half a second and has more than two spikes.
    """
    _, duration, n_spikes = burst
    long_enough = duration >= 0.5*pq.s
    return long_enough and n_spikes > 2

In [None]:
# must first convert args from quantities to simple ndarrays
# (use .rescale('s').magnitude)
def normalize_time(fixed_times, t, extrapolate=True):
    """Map times onto a piecewise-linear "normalized" time axis defined by fixed times.

    The k-th entry of ``fixed_times`` maps to normalized time ``k``; times between
    two consecutive fixed times are interpolated linearly, so a time halfway between
    ``fixed_times[2]`` and ``fixed_times[3]`` maps to 2.5.

    Parameters
    ----------
    fixed_times : numpy.ndarray
        Plain (non-quantity) array of fixed times, non-decreasing once NaNs are
        ignored. NaN entries mark missing fixed times; any ``t`` whose bordering
        interval involves a NaN normalizes to NaN.
    t : numpy.ndarray, list, or scalar
        Plain (non-quantity) time(s) to normalize. A list is converted with
        ``np.array``; any other non-ndarray value is wrapped as a one-element array.
    extrapolate : bool
        If True, times before the first / after the last fixed time are normalized
        by extending the first / last interval linearly. If False, those times
        normalize to NaN.

    Returns
    -------
    numpy.ndarray
        Normalized times with the same shape as the (converted) ``t``. NaN in ``t``
        gives NaN.

    Raises
    ------
    AssertionError
        If either argument is a quantity or ``fixed_times`` is not sorted.

    NOTE(review): searchsorted uses side='left', so a ``t`` exactly equal to
    ``fixed_times[k]`` is normalized within the interval ``[k-1, k]``; if
    ``fixed_times[k-1]`` is NaN the result is NaN rather than ``k``. Confirm this
    is acceptable for callers that normalize a fixed time itself.
    """
    
    # promote lists and scalars to arrays so the vectorized logic below applies
    if not isinstance(t, np.ndarray):
        if type(t) is list:
            t = np.array(t)
        else:
            t = np.array([t])
    
    assert not isinstance(fixed_times, pq.Quantity), f'fixed_times should not be a quantity (use .magnitude)'
    assert not isinstance(t, pq.Quantity), f't should not be a quantity (use .magnitude)'
    assert np.all(np.diff(fixed_times[~np.isnan(fixed_times)])>=0), f'fixed_times must be sorted: {fixed_times}'
    
    # create a copy in which NaNs are removed
    # - this is done because searchsorted does not work well with NaNs
    # - infinities are retained here
    fixed_times_without_nans = fixed_times[np.where(~np.isnan(fixed_times))[0]]
    
    # find the indexes of the fixed values that are just after each value in t
    indexes = np.searchsorted(fixed_times_without_nans, t)
    
    # adjust indexes to account for the NaNs that were removed
    # - nan_indexes is ascending, so each pass shifts every index at or beyond
    #   that NaN position by one, mapping positions in the NaN-free copy back to
    #   positions in the original fixed_times
    nan_indexes = np.where(np.isnan(fixed_times))[0]
    for nan_index in nan_indexes:
        indexes[indexes >= nan_index] += 1
    
    # increment/decrement any index equal to 0/N, where N=len(fixed_times)
    # - this is needed for values in t that are less/greater than the min/max fixed
    #   time, and for NaNs in t which get assigned an index of N by searchsorted
    # - for values in t less/greater than the min/max fixed time, this
    #   increment/decrement in index will prepare that value to be normalized
    #   using extrapolation based on the first/last interval in fixed_times
    indexes = np.clip(indexes, 1, len(fixed_times)-1)
    
    # compute the normalized values of t using linear interpolation between the
    # bordering fixed times
    # - normalization of values in t that are less than the min fixed time is
    #   accomplished by extrapolation, as the fraction (after-t)/(after-before)
    #   will be greater than 1 since the t value is in fact earlier than the
    #   "before" fixed time
    # - normalization of values in t that are greater than the max fixed time is
    #   accomplished by extrapolation, as the fraction (after-t)/(after-before)
    #   will be less than 0 since the t value is in fact later than the "after"
    #   fixed time
    # - a NaN bordering fixed time propagates NaN through this arithmetic
    before = fixed_times[indexes-1]
    after  = fixed_times[indexes]
    t_normalized = indexes - (after-t)/(after-before)
    
    if not extrapolate:
        # extrapolation is already done, so here we undo it by setting to NaN
        # the normalized value for any t that is less/greater than the min/max
        # fixed time
        with np.errstate(invalid='ignore'): # don't warn about invalid comparisons to NaN
            t_normalized[t < np.nanmin(fixed_times)] = np.nan
            t_normalized[t > np.nanmax(fixed_times)] = np.nan
    
    return t_normalized

In [None]:
def unnormalize_time(fixed_times, t_normalized, extrapolate=True):
    """Inverse of normalize_time: map normalized times back onto real times.

    Normalized time ``k`` maps to ``fixed_times[k]``; fractional values are placed
    linearly within the interval ``[fixed_times[ceil-1], fixed_times[ceil]]``.
    Both arguments must be plain arrays rather than quantities (convert first with
    ``.rescale('s').magnitude``).

    Parameters
    ----------
    fixed_times : numpy.ndarray
        Non-decreasing (ignoring NaNs) fixed times; NaN marks a missing fixed time.
    t_normalized : numpy.ndarray, list, or scalar
        Normalized time(s). A list is converted with ``np.array``; any other
        non-ndarray value is wrapped as a one-element array.
    extrapolate : bool
        If False, results earlier than the smallest or later than the largest
        fixed time are replaced with NaN.

    Returns
    -------
    numpy.ndarray
        Real times; NaN where ``t_normalized`` is NaN or the bordering interval
        includes a missing fixed time.
    """
    # promote lists and scalars to arrays so the element-wise logic below applies
    if not isinstance(t_normalized, np.ndarray):
        t_normalized = np.array(t_normalized if type(t_normalized) is list else [t_normalized])

    assert not isinstance(fixed_times, pq.Quantity), f'fixed_times should not be a quantity (use .magnitude)'
    assert not isinstance(t_normalized, pq.Quantity), f't_normalized should not be a quantity (use .magnitude)'
    assert np.all(np.diff(fixed_times[~np.isnan(fixed_times)])>=0), f'fixed_times must be sorted: {fixed_times}'

    # index of the fixed time at or after each value; clipping into [1, N-1] means
    # out-of-range values are extrapolated from the first or last interval
    upper = np.clip(np.ceil(t_normalized), 1, len(fixed_times) - 1)

    def fixed_time_at(offset):
        # fixed time at upper+offset for every element, NaN where the index is NaN
        return np.array([np.nan if np.isnan(k) else fixed_times[int(k) + offset] for k in upper])

    lower_time = fixed_time_at(-1)
    upper_time = fixed_time_at(0)

    # step back from the upper fixed time by the fractional distance into the interval
    t = upper_time + (t_normalized - upper) * (upper_time - lower_time)

    if not extrapolate:
        # discard extrapolated results that fall outside the span of the fixed times
        with np.errstate(invalid='ignore'):  # don't warn about invalid comparisons to NaN
            t[t < np.nanmin(fixed_times)] = np.nan
            t[t > np.nanmax(fixed_times)] = np.nan

    return t

In [None]:
# regular grids in normalized time, used as interp_times for
# resample_sig_in_normalized_time depending on the segmentation scheme
# - both segmentation schemes use 10 fixed times, so normalized time runs from 0 to 9
# - linspace ensures samples will be taken at regular intervals in normalized time
interp_resolution = 1000
video_seg_interp_times = np.linspace(0, 9, interp_resolution)
force_seg_interp_times = video_seg_interp_times.copy()

def resample_sig_in_normalized_time(fixed_times, sig, interp_times):
    """Sample a signal at the given points in normalized time.

    Parameters
    ----------
    fixed_times : quantities array
        Fixed times defining the normalization (see normalize_time); its
        ``.magnitude`` is used. NOTE(review): it is not rescaled, so it is presumably
        already in seconds to match ``sig.times.rescale('s')`` -- confirm.
    sig : signal object
        Needs ``.times`` (a quantity array convertible to seconds) and ``.magnitude``.
        The magnitude is flattened, so a single-channel signal is assumed.
    interp_times : array-like
        Normalized times at which to sample the signal.

    Returns
    -------
    numpy.ndarray
        Signal values at ``interp_times``. Values are NaN outside the span of
        normalizable samples and inside intervals bordered by a missing (NaN) fixed
        time. The whole result is NaN when fewer than two samples could be normalized.
    """
    # import the subpackage explicitly instead of relying on some other library
    # having already loaded scipy.interpolate (`import scipy` alone does not
    # guarantee this on older SciPy releases)
    from scipy.interpolate import interp1d

    # get normalized times and signal values
    # - normalize_time will put into times_normalized a np.nan wherever a time was not normalizable,
    #   i.e., wherever the time occurred adjacent to a missing fixed time (np.nan in fixed_times)
    times = sig.times.rescale('s').magnitude
    times_normalized = normalize_time(fixed_times.magnitude, times)
    y = sig.magnitude.flatten()

    # drop times that could not be normalized due to missing fixed times
    # - interp1d will erroneously interpolate across the gaps created by this deletion,
    #   but we will then replace those interpolated values with np.nan
    where_normalizable = np.where(~np.isnan(times_normalized))[0]
    times_normalized = times_normalized[where_normalizable]
    y = y[where_normalizable]

    # interp1d needs at least two samples; none survive when, for example, every
    # fixed time is NaN (an unsegmented behavior), so there is nothing to interpolate
    if times_normalized.size < 2:
        return np.full(np.shape(interp_times), np.nan)

    # resample evenly in normalized time
    # - with bounds_error=False and fill_value=np.nan, interp1d will put np.nan in places outside
    #   the min and max of times_normalized
    # - interp1d will interpolate across the regions we deleted, but we want np.nan there,
    #   so we will insert them manually in the next step
    interp_func = interp1d(times_normalized, y, kind='linear', bounds_error=False, fill_value=np.nan)
    y_interp = interp_func(interp_times)

    # replace erroneously interpolated values with np.nan
    # - a normalized time that maps back to NaN lies in an interval bordered by a
    #   missing fixed time, so the signal has no defined value there
    where_not_normalizable = np.where(np.isnan(unnormalize_time(fixed_times.magnitude, interp_times)))[0]
    y_interp[where_not_normalizable] = np.nan

    return y_interp

In [None]:
def print_column_analysis(column, df=None):
    """Print the overall mean, standard deviation, and coefficient of variation of a column.

    Parameters
    ----------
    column : str
        Name of the column to summarize. If it contains a parenthesized unit, e.g.
        ``'End to next start (s)'``, the text inside the last ``(...)`` is printed as
        the unit; a column without parentheses is printed without a unit.
    df : pandas.DataFrame, optional
        Table containing ``column``. Defaults to the notebook-global ``df_all`` so
        that existing calls ``print_column_analysis(column)`` behave as before.
    """
    if df is None:
        # NOTE(review): df_all is a notebook global that must exist when this runs
        df = df_all

    # only a trailing "(...)" is a unit; don't echo the whole column name as the unit
    unit = column.rsplit('(', 1)[-1].split(')')[0] if '(' in column else ''
    unit_suffix = f' {unit}' if unit else ''

    values = df[column]
    mean = values.mean()
    std = values.std()  # pandas default: sample standard deviation (ddof=1)
    cv = std/abs(mean)  # divide by |mean| so the CV is non-negative
    print(f'Overall mean +/- std: {mean:.2f} +/- {std:.2f}{unit_suffix}')
    print(f'Overall coefficient of variation CV: {cv:.2f}')

In [None]:
def differences_test(x, y, x_label='GROUP 1', y_label='GROUP 2', measure_label='MEASURE', units='UNITS', alpha=0.05):
    """Test whether paired measurements ``x`` are greater than ``y`` and print a report.

    The paired differences ``x - y`` are first checked for normality with a
    Shapiro-Wilk test. If normality is not rejected at level ``alpha``, a paired
    one-tailed t-test (alternative: x greater than y) is run, Cohen's d is reported,
    and a results sentence is printed. Otherwise, a paired one-tailed Wilcoxon signed
    rank test is used. The tests are computed by the project's ``r_stats`` helpers;
    the comments below name the equivalent R calls.

    Parameters
    ----------
    x, y : pandas.Series
        Paired samples of equal length (``.values`` is used, so Series-like objects
        are expected).
    x_label, y_label, measure_label, units : str
        Labels used only in the printed report.
    alpha : float
        Significance level used for every test, including the normality check that
        chooses between the t-test and the Wilcoxon test.

    Returns
    -------
    str
        ``'*'`` if the chosen one-tailed test is significant, otherwise ``'(n.s.)'``.

    Raises
    ------
    AssertionError
        If ``x`` and ``y`` have different lengths.
    """

    # descriptive statistics
    x_mean = np.mean(x)
    y_mean = np.mean(y)
    x_std = np.std(x, ddof=1)
    y_std = np.std(y, ddof=1)
    x_n = len(x)
    y_n = len(y)
    assert x_n == y_n, 'expected x and y to be paired but they have unequal length'
    print(f'{x_label}: (M = {x_mean:g}, SD = {x_std:g}, N = {x_n:g})')
    print(f'{y_label}: (M = {y_mean:g}, SD = {y_std:g}, N = {y_n:g})')
    print()

    # Shapiro-Wilk test for normality of differences
    # - equivalent R test: shapiro.test(x-y)
    shapiro_result = r_stats.shapiro_test(x.values-y.values)
    shapiro_W, shapiro_p = shapiro_result['W'], shapiro_result['p']
    shapiro_signif = '*' if shapiro_p < alpha else '(n.s.)'
    print(f'H0: Differences have normal distribution, W = {shapiro_W:g},\tp = {shapiro_p:g} {shapiro_signif}')

    # choose the test with the same alpha used for the normality verdict printed
    # above, so the selected test always agrees with the reported significance
    if shapiro_p >= alpha:
        print('- Because the differences can be assumed to be normal, a paired t-test will be used')

        # paired one-tailed T-test for an increase in means
        # - equivalent R test : t.test(x, y, paired=TRUE, alternative="greater")
        ttest_result = r_stats.t_test(x.values, y.values, paired=True, alternative='greater')
        ttest_t, ttest_p = ttest_result['t'], ttest_result['p']
        ttest_signif = '*' if ttest_p < alpha and x_mean > y_mean else '(n.s.)'
        print(f'H0: Difference in means is not positive,  t = {ttest_t:g},\tp = {ttest_p:g} {ttest_signif}')

        # Cohen's d for effect size
        # - equivalent R function: library(effsize); cohen.d(x, y)
        # - although the R function offers a paired=TRUE case, I'm not using it here
        #   because I don't understand the paper cited in the documentation. It appears
        #   to be a variation on a different effect size measure, so it's unclear that
        #   it should be referred to as "Cohen's d". When paired=TRUE, the result is
        #   generally a little smaller.
        cohen_d = r_stats.cohen_d(x.values, y.values)['estimate']
        print()
        print(f'Effect size: Cohen\'s d = {cohen_d:g}')
        
        if ttest_p < alpha and x_mean > y_mean:
            print()
            print(f'"A paired-samples one-tailed t-test indicated that {measure_label} for the ' \
                  f'[{x_n}] animals were significantly [greater/longer] for ' \
                  f'{x_label} (M = {x_mean:.2f} {units}, SD = {x_std:.2f} {units}) than for ' \
                  f'{y_label} (M = {y_mean:.2f} {units}, SD = {y_std:.2f} {units}), ' \
                  f't = {ttest_t:.3f}, df = {x_n-1}, p = {ttest_p:.3f}, Cohen\'s d = {cohen_d:.2f}."')
        else:
            print()
            print(f'"A paired-samples one-tailed t-test indicated no significant increase in {measure_label} in the [{x_n}] animals between ' \
                  f'{x_label} (M = {x_mean:.2f} {units}, SD = {x_std:.2f} {units}) and ' \
                  f'{y_label} (M = {y_mean:.2f} {units}, SD = {y_std:.2f} {units}), ' \
                  f't = {ttest_t:.3f}, df = {x_n-1}, p = {ttest_p:.3f}, Cohen\'s d = {cohen_d:.2f}."')
        
        return ttest_signif

    else:
        print('- Because the differences cannot be assumed to be normal, a Wilcoxon signed rank test will be used')

        # paired Wilcoxon signed rank test for increase in locations (medians?)
        # - equivalent R test: wilcox.test(x, y, paired=TRUE, alternative="greater")
        wilcoxon_result = r_stats.wilcox_test(x.values, y.values, paired=True, alternative='greater')
        wilcoxon_W, wilcoxon_p = wilcoxon_result['W'], wilcoxon_result['p']
        wilcoxon_signif = '*' if wilcoxon_p < alpha else '(n.s.)'
        print(f'H0: Difference in medians is not positive, W = {wilcoxon_W:g},\tp = {wilcoxon_p:g} {wilcoxon_signif}')
        
        return wilcoxon_signif

---

# Crunch the Numbers

## Download the Data

In [None]:
# download each file that is not already stored locally
# - several feeding bouts can come from the same data set, so visit each data set
#   only once (dict.fromkeys de-duplicates while keeping the original order)
metadata = neurotic.MetadataSelector(file=metadata_file)
for data_set_name in dict.fromkeys(name for name, _ in feeding_bouts.values()):
    metadata.select(data_set_name)
    metadata.download_all_data_files()

## Import and Process the Data

In [None]:
# record the wall-clock time at which processing begins
# - NOTE(review): presumably used later to report elapsed time; the use is not visible here
start = datetime.datetime.now()

# use Neo RawIO lazy loading to load much faster and use less memory
# - with lazy=True, filtering parameters specified in metadata are ignored
#     - note: filters are replaced below and applied manually anyway
# - with lazy=True, loading via time_slice requires neo>=0.8.0
lazy = True

# load the metadata containing file paths
metadata = neurotic.MetadataSelector(file=metadata_file)

# filter epochs for each bout and perform calculations
df_list = []  # NOTE(review): presumably collects one table per bout later in the loop
# name of the data set currently loaded into `blk`, so the loop can skip reloading it
# - NOTE(review): the loop compares names with `is`, which tests object identity;
#   `==` is presumably intended for comparing data set name strings
last_data_set_name = None
pbar = tqdm(total=len(feeding_bouts), unit='feeding bout')  # progress bar with one step per feeding bout
for (animal, food, bout_index), (data_set_name, time_window) in feeding_bouts.items():

    channel_names = channel_names_by_animal[animal]
    epoch_types = epoch_types_by_food[food]

    ###
    ### LOAD DATASET
    ###

    metadata.select(data_set_name)

    if data_set_name is last_data_set_name:
        # skip reloading the data if it's already in memory
        pass
    else:
        blk = neurotic.load_dataset(metadata, lazy=lazy)
        if lazy:
            # manually perform filters
            blk = apply_filters(blk, metadata)

    last_data_set_name = data_set_name



    ###
    ### LOCATE BEHAVIOR EPOCHS AND SUBEPOCHS
    ###

    # construct a query for locating behaviors
    behavior_query = f'(Type in {epoch_types}) & ({time_window[0]} <= Start) & (End <= {time_window[1]})'

    # construct queries for locating epochs associated with each behavior
    # - each query should match at most one epoch
    # - dictionary keys are used as prefixes for the names of new columns
    subepoch_queries = {}

    subepoch_queries['B38 activity']            = (f'(Type == "B38 activity") & ' \
                                                   f'(@behavior_start-3 <= End) & (End <= @behavior_start+4)',
                                                   'last') # use last if there are multiple matches
                                                  # must end within a few seconds of behavior start (3 before or 4 after)

    subepoch_queries['I2 protraction activity'] = f'(Type == "I2 protraction activity") & ' \
                                                  f'(@behavior_start-1 <= Start) & (End <= @behavior_end)'
                                                  # must start no earlier than 1 second before behavior and end within it

    subepoch_queries['B8 activity']             = f'(Type == "B8 activity") & ' \
                                                  f'(@behavior_start <= Start) & (End <= @behavior_end)'
                                                  # must be fully contained within behavior

    subepoch_queries['B3/6/9/10 activity']      = f'(Type == "B3/6/9/10 activity") & ' \
                                                  f'(@behavior_start <= Start) & (End <= @behavior_end)'
                                                  # must be fully contained within behavior

    subepoch_queries['B4/B5 activity']          = (f'(Type == "B4/B5 activity") & ' \
                                                   f'(@behavior_start <= Start) & (Start <= @behavior_end)',
                                                   'first') # use first if there are multiple matches
                                                  # must start within behavior

    subepoch_queries['Force shoulder end']      = f'(Type == "Force shoulder end") & ' \
                                                  f'(@behavior_start-3 <= End) & (End <= @behavior_start+3)'
                                                  # must end within 3 seconds of behavior start

    subepoch_queries['Force rise start']        = f'(Type == "Force rise start") & ' \
                                                  f'(@behavior_start <= Start) & (Start <= @behavior_end)'
                                                  # must start within behavior

    subepoch_queries['Force plateau start']     = f'(Type == "Force plateau start") & ' \
                                                  f'(@behavior_start <= Start) & (Start <= @behavior_end)'
                                                  # must start within behavior

    subepoch_queries['Force plateau end']       = f'(Type == "Force plateau end") & ' \
                                                  f'(@behavior_end-2 <= Start) & (Start <= @behavior_end+2)'
                                                  # must start within 2 seconds of behavior end

    subepoch_queries['Force drop end']          = f'(Type == "Force drop end") & ' \
                                                  f'(@behavior_end-2 <= Start) & (Start <= @behavior_end+2)'
                                                  # must start within 2 seconds of behavior end

    subepoch_queries['Inward movement']         = f'(Type == "Inward movement") & ' \
                                                  f'(@behavior_start <= Start) & (Start <= @behavior_end)'
                                                  # must start within behavior

    subepoch_queries['Outward movement']        = f'(Type == "Outward movement") & ' \
                                                  f'(@behavior_start <= Start) & (Start <= @behavior_end)'
                                                  # must start within behavior

    # construct a table in which each row is a behavior and subepoch data
    # is added as columns, e.g. df['B38 activity start (s)']
    df = BehaviorsDataFrame(blk.segments[0].epochs, behavior_query, subepoch_queries)

    # renumber behaviors assuming all behaviors are from a single contiguous sequence
    df = df.sort_values('Start (s)').reset_index(drop=True).rename_axis('Behavior_index')



    ###
    ### START CALCULATIONS
    ###

    # some columns must have type 'object', which
    # can be accomplished by initializing with None or np.nan
    df['Video segmentation times (s)'] = None
    df['Force segmentation times (s)'] = None
    df['Force, video segmented interpolation (mN)'] = None
    df['Force, force segmented interpolation (mN)'] = None
    for unit in units:
        df[f'{unit} spike train'] = None
        df[f'{unit} firing rate (Hz)'] = None
        df[f'{unit} firing rate, video segmented interpolation (Hz)'] = None
        df[f'{unit} firing rate, force segmented interpolation (Hz)'] = None
        df[f'{unit} all bursts'] = None

        # while we're at it, initialize some other things that might otherwise never be given values
        df[f'{unit} burst start (s)'] = np.nan
        df[f'{unit} burst end (s)'] = np.nan
        df[f'{unit} burst duration (s)'] = 0
        df[f'{unit} burst spike count'] = 0
        df[f'{unit} burst mean frequency (Hz)'] = 0
        df[f'{unit} burst start (video seg normalized)'] = np.nan
        df[f'{unit} burst end (video seg normalized)'] = np.nan
        df[f'{unit} burst start (force seg normalized)'] = np.nan
        df[f'{unit} burst end (force seg normalized)'] = np.nan
    df['Inward movement start (video seg normalized)'] = np.nan
    df['Inward movement end (video seg normalized)'] = np.nan
    df['Inward movement start (force seg normalized)'] = np.nan
    df['Inward movement end (force seg normalized)'] = np.nan



    # get smoothed force for entire time window
    channel = 'Force'
    sig = get_sig(blk, channel)
    if lazy:
        sig = sig.time_slice(None, None)
    sig = elephant.signal_processing.butter(sig, lowpass_freq = 5*pq.Hz)
    sig = sig.time_slice(time_window[0]*pq.s, time_window[1]*pq.s)
    sig = sig.rescale('mN')
    force_smoothed_sig = sig



    df['End to next start (s)'] = np.nan
    df['Start to next start (s)'] = np.nan
    previous_i = None

    # iterate over all swallows
    for j, i in enumerate(df.index):

        behavior_start = df.loc[i, 'Start (s)']*pq.s
        behavior_end = df.loc[i, 'End (s)']*pq.s
        inward_movement_start = df.loc[i, 'Inward movement start (s)']*pq.s
        inward_movement_end = df.loc[i, 'Inward movement end (s)']*pq.s

        # calculate interbehavior intervals assuming all behaviors are from a single contiguous sequence
        if previous_i is not None:
            df.loc[previous_i, 'End to next start (s)']   = df.loc[i, 'Start (s)'] - df.loc[previous_i, 'End (s)']
            df.loc[previous_i, 'Start to next start (s)'] = df.loc[i, 'Start (s)'] - df.loc[previous_i, 'Start (s)']
            df.loc[previous_i, 'Inward movement start to next inward movement start (s)'] = \
                df.loc[i, 'Inward movement start (s)'] - df.loc[previous_i, 'Inward movement start (s)']
            df.loc[previous_i, 'Inward movement end to next inward movement start (s)'] = \
                df.loc[i, 'Inward movement start (s)'] - df.loc[previous_i, 'Inward movement end (s)']
        previous_i = i



        ###
        ### VIDEO SEGMENTATION
        ###

        # inward movement start and end are required
        video_is_segmented = np.all(np.isfinite(np.array([
            inward_movement_start, inward_movement_end])))

        if video_is_segmented:
            inward_movement_duration = df.loc[i, 'Inward movement duration (s)']*pq.s

            # get the list of fixed times for normalization
            video_segmentation_times = df.at[i, 'Video segmentation times (s)'] = np.array([
                inward_movement_start-inward_movement_duration*4,
                inward_movement_start-inward_movement_duration*3,
                inward_movement_start-inward_movement_duration*2,
                inward_movement_start-inward_movement_duration*1,
                inward_movement_start,
                inward_movement_end,
                inward_movement_end+inward_movement_duration*1,
                inward_movement_end+inward_movement_duration*2,
                inward_movement_end+inward_movement_duration*3,
                inward_movement_end+inward_movement_duration*4,
            ])*pq.s # 'at', not 'loc', is important for inserting list into cell

        else: # video is not segmented
            video_segmentation_times = df.at[i, 'Video segmentation times (s)'] = np.array([
                np.nan,
                np.nan,
                np.nan,
                np.nan,
                np.nan,
                np.nan,
                np.nan,
                np.nan,
                np.nan,
                np.nan,
            ])*pq.s # 'at', not 'loc', is important for inserting list into cell



        ###
        ### FORCE SEGMENTATION
        ###

        force_shoulder_end  = df.loc[i, 'Force shoulder end start (s)']*pq.s  # start of "Force shoulder end" epoch
        force_rise_start    = df.loc[i, 'Force rise start start (s)']*pq.s    # start of "Force rise start" epoch
        force_plateau_start = df.loc[i, 'Force plateau start start (s)']*pq.s # start of "Force plateau start" epoch
        force_plateau_end   = df.loc[i, 'Force plateau end start (s)']*pq.s   # start of "Force plateau end" epoch
        force_drop_end      = df.loc[i, 'Force drop end start (s)']*pq.s      # start of "Force drop end" epoch

        # force rise start, plateau start and end, and drop end are required
        force_is_segmented = np.all(np.isfinite(np.array([
            force_rise_start, force_plateau_start, force_plateau_end, force_drop_end])))

        if force_is_segmented:
            # get some times for the previous and next swallow
            epochs_force_shoulder_end  = next((ep for ep in blk.segments[0].epochs if ep.name == 'Force shoulder end'), None)
            epochs_force_rise_start    = next((ep for ep in blk.segments[0].epochs if ep.name == 'Force rise start'), None)
            epochs_force_plateau_start = next((ep for ep in blk.segments[0].epochs if ep.name == 'Force plateau start'), None)
            epochs_force_plateau_end   = next((ep for ep in blk.segments[0].epochs if ep.name == 'Force plateau end'), None)
            epochs_force_drop_end      = next((ep for ep in blk.segments[0].epochs if ep.name == 'Force drop end'), None)
            assert epochs_force_shoulder_end  is not None, 'failed to find "Force shoulder end" epochs'
            assert epochs_force_rise_start    is not None, 'failed to find "Force rise start" epochs'
            assert epochs_force_plateau_start is not None, 'failed to find "Force plateau start" epochs'
            assert epochs_force_plateau_end   is not None, 'failed to find "Force plateau end" epochs'
            assert epochs_force_drop_end      is not None, 'failed to find "Force drop end" epochs'

            try:
                prev_force_plateau_start = df.loc[i, 'Previous force plateau start (s)'] = epochs_force_plateau_start.time_slice(None, force_rise_start)[-1]
                assert force_rise_start-prev_force_plateau_start < 16*pq.s, f'for swallow {i}, previous force plateau start is too far away'
            except IndexError:
                prev_force_plateau_start = df.loc[i, 'Previous force plateau start (s)'] = np.nan

            try:
                prev_force_plateau_end = df.loc[i, 'Previous force plateau end (s)'] = epochs_force_plateau_end.time_slice(None, force_rise_start)[-1]
                assert force_rise_start-prev_force_plateau_end < 12*pq.s, f'for swallow {i}, previous force plateau end is too far away'
            except IndexError:
                prev_force_plateau_end = df.loc[i, 'Previous force plateau end (s)'] = np.nan

            try:
                prev_force_drop_end = df.loc[i, 'Previous force drop end (s)'] = epochs_force_drop_end.time_slice(None, force_rise_start)[-1]
                assert force_rise_start-prev_force_drop_end < 12*pq.s, f'for swallow {i}, previous force drop end is too far away'
            except IndexError:
                prev_force_drop_end = df.loc[i, 'Previous force drop end (s)'] = np.nan

            try:
                next_force_rise_start = df.loc[i, 'Next force rise start (s)'] = epochs_force_rise_start.time_slice(force_drop_end, None)[0]
                assert next_force_rise_start-force_drop_end < 12*pq.s, f'for swallow {i}, next force rise start is too far away'
            except IndexError:
                next_force_rise_start = df.loc[i, 'Next force rise start (s)'] = np.nan

            try:
                next_force_shoulder_end = df.loc[i, 'Next force shoulder end (s)'] = epochs_force_shoulder_end.time_slice(force_drop_end, None)[0]
                if next_force_shoulder_end > next_force_rise_start:
                    # next swallow did not have a shoulder and we instead grabbed a later shoulder
                    next_force_shoulder_end = df.loc[i, 'Next force shoulder end (s)'] = np.nan
            except IndexError:
                next_force_shoulder_end = df.loc[i, 'Next force shoulder end (s)'] = np.nan

            # get the list of fixed times for normalization
            force_segmentation_times = df.at[i, 'Force segmentation times (s)'] = np.array([
                prev_force_plateau_start,
                prev_force_plateau_end,
                prev_force_drop_end,
                force_shoulder_end,
                force_rise_start,
                force_plateau_start,
                force_plateau_end,
                force_drop_end,
                next_force_shoulder_end,
                next_force_rise_start,
            ])*pq.s # 'at', not 'loc', is important for inserting list into cell

        else: # force is not segmented
            force_segmentation_times = df.at[i, 'Force segmentation times (s)'] = np.array([
                np.nan,
                np.nan,
                np.nan,
                np.nan,
                np.nan,
                np.nan,
                np.nan,
                np.nan,
                np.nan,
                np.nan,
            ])*pq.s # 'at', not 'loc', is important for inserting list into cell



        ###
        ### BEHAVIORAL MARKERS
        ###

        if video_is_segmented:
            df.loc[i, 'Inward movement start (video seg normalized)'] = \
                normalize_time(video_segmentation_times.magnitude, float(inward_movement_start))
            df.loc[i, 'Inward movement end (video seg normalized)'] = \
                normalize_time(video_segmentation_times.magnitude, float(inward_movement_end))

        if force_is_segmented:
            df.loc[i, 'Inward movement start (force seg normalized)'] = \
                normalize_time(force_segmentation_times.magnitude, float(inward_movement_start))
            df.loc[i, 'Inward movement end (force seg normalized)'] = \
                normalize_time(force_segmentation_times.magnitude, float(inward_movement_end))



        ###
        ### FORCE QUANTIFICATION
        ###

        if force_is_segmented:

            # get smoothed force for whole behavior for remaining force calculations
            sig = force_smoothed_sig
            if np.isfinite(force_shoulder_end):
                sig = sig.time_slice(force_shoulder_end - 0.01*pq.s, force_drop_end + 0.01*pq.s)
            else:
                sig = sig.time_slice(force_rise_start - 1*pq.s, force_drop_end + 0.01*pq.s)
            sig = sig.rescale('mN')

            # find force peak, baseline, and the increase
            force_min_time = df.loc[i, 'Force minimum time (s)'] = elephant.spike_train_generation.peak_detection(sig, 999*pq.mN, sign='below')[0]
            force_min = df.loc[i, 'Force minimum (mN)'] = sig[sig.time_index(force_min_time)][0]
            force_peak_time = df.loc[i, 'Force peak time (s)'] = elephant.spike_train_generation.peak_detection(sig, 0*pq.mN)[0]
            force_peak = df.loc[i, 'Force peak (mN)'] = sig[sig.time_index(force_peak_time)][0]
            force_baseline = df.loc[i, 'Force baseline (mN)'] = sig[sig.time_index(force_rise_start)][0]
            force_increase = df.loc[i, 'Force increase (mN)'] = force_peak-force_baseline

            # find force plateau, drop, and shoulder values
            force_plateau_start_value = df.loc[i, 'Force plateau start value (mN)'] = sig[sig.time_index(force_plateau_start)][0]
            force_plateau_end_value = df.loc[i, 'Force plateau end value (mN)'] = sig[sig.time_index(force_plateau_end)][0]
            force_drop_end_value = df.loc[i, 'Force drop end value (mN)'] = sig[sig.time_index(force_drop_end)][0]
            if np.isfinite(force_shoulder_end):
                force_shoulder_end_value = df.loc[i, 'Force shoulder end value (mN)'] = sig[sig.time_index(force_shoulder_end)][0]
            else:
                force_shoulder_end_value = np.nan

            # find force rise and plateau durations
            force_rise_duration = df.loc[i, 'Force rise duration (s)'] = force_plateau_start - force_rise_start
            force_plateau_duration = df.loc[i, 'Force plateau duration (s)'] = force_plateau_end - force_plateau_start
            force_rise_plateau_duration = df.loc[i, 'Force rise and plateau duration (s)'] = force_plateau_end - force_rise_start

            # find average slope during rising phase
            force_rise_increase = df.loc[i, 'Force rise increase (mN)'] = force_plateau_start_value - force_baseline
            force_slope = df.loc[i, 'Force slope (mN/s)'] = (force_rise_increase/force_rise_duration).rescale('mN/s')



        ###
        ### FORCE NORMALIZATION
        ###

        if force_is_segmented:
            t_start = force_segmentation_times[0]-0.001*pq.s
            t_stop = force_segmentation_times[-1]+0.001*pq.s

            channel = 'Force'
            sig = get_sig(blk, channel)
            sig = sig.time_slice(t_start, t_stop)
            sig = sig.rescale(channel_units[channel_names.index(channel)])

            force_force_seg_interp = df.at[i, 'Force, force segmented interpolation (mN)'] = \
                resample_sig_in_normalized_time(force_segmentation_times, sig, force_seg_interp_times) # 'at', not 'loc', is important for inserting list into cell



        ###
        ### FIND SPIKE TRAINS
        ###

        if lazy:
            if metadata['amplitude_discriminators'] is not None:
                for discriminator in metadata['amplitude_discriminators']:
                    sig = get_sig(blk, discriminator['channel'])
                    if sig is not None:
                        sig = sig.time_slice(max(sig.t_start, behavior_start - 5*pq.s), min(sig.t_stop, behavior_end + 5*pq.s))
                        st = _detect_spikes(sig, discriminator, blk.segments[0].epochs)
                        st_epoch_start = df.loc[i, discriminator['epoch']+' start (s)']*pq.s
                        st_epoch_end = df.loc[i, discriminator['epoch']+' end (s)']*pq.s
                        st = st.time_slice(st_epoch_start, st_epoch_end)
                        df.at[i, f'{st.name} spike train'] = st # 'at', not 'loc', is important for inserting list into cell
        else:
            for spiketrain in blk.segments[0].spiketrains:
                discriminator = next((d for d in metadata['amplitude_discriminators'] if d['name'] == spiketrain.name), None)
                if discriminator is None:
                    raise Exception(f'For data set "{data_set_name}", discriminator "{spiketrain.name}" could not be found')
                st_epoch_start = df.loc[i, discriminator['epoch']+' start (s)']*pq.s
                st_epoch_end = df.loc[i, discriminator['epoch']+' end (s)']*pq.s
                if np.isfinite(st_epoch_start) and np.isfinite(st_epoch_end):
                    st = spiketrain.time_slice(st_epoch_start, st_epoch_end)
                else:
                    # this unit's discriminator epoch was not located for this swallow
                    st = None
                df.at[i, f'{spiketrain.name} spike train'] = st # 'at', not 'loc', is important for inserting list into cell



        ###
        ### QUANTIFY SPIKE TRAINS AND BURSTS
        ###

        for k, unit in enumerate(units):
            st = df.loc[i, f'{unit} spike train']
            if st is not None:

                # get the neural channel
                channel = st.annotations['channels'][0]
                sig = get_sig(blk, channel)

                # create a continuous smoothed firing rate representation
                # by convolving the spike train with a kernel
                # - choice of t_start and t_stop here ensures firing rates are
                #   recorded as zero far from the burst and can be resampled later
                if video_is_segmented and force_is_segmented:
                    t_start = min(video_segmentation_times[0], force_segmentation_times[0])-0.001*pq.s
                    t_stop = max(video_segmentation_times[-1], force_segmentation_times[-1])+0.001*pq.s
                elif video_is_segmented:
                    t_start = video_segmentation_times[0]-0.001*pq.s
                    t_stop = video_segmentation_times[-1]+0.001*pq.s
                elif force_is_segmented:
                    t_start = force_segmentation_times[0]-0.001*pq.s
                    t_stop = force_segmentation_times[-1]+0.001*pq.s
                else:
                    # no segmentation available
                    t_start = behavior_start-5*pq.s
                    t_stop = behavior_end+5*pq.s
                smoothing_kernel = elephant.kernels.GaussianKernel(0.2*pq.s) # 200 ms standard deviation
                firing_rate = df.at[i, f'{unit} firing rate (Hz)'] = elephant.statistics.instantaneous_rate(
                    spiketrain=st,
                    t_start=t_start,
                    t_stop=t_stop,
                    sampling_period=sig.sampling_period,
                    kernel=smoothing_kernel,
                ) # 'at', not 'loc', is important for inserting list into cell

                # normalization
                if video_is_segmented:
                    firing_rate_video_seg_interp = df.at[i, f'{unit} firing rate, video segmented interpolation (Hz)'] = \
                        resample_sig_in_normalized_time(video_segmentation_times, firing_rate, video_seg_interp_times) # 'at', not 'loc', is important for inserting list into cell
                if force_is_segmented:
                    firing_rate_force_seg_interp = df.at[i, f'{unit} firing rate, force segmented interpolation (Hz)'] = \
                        resample_sig_in_normalized_time(force_segmentation_times, firing_rate, force_seg_interp_times) # 'at', not 'loc', is important for inserting list into cell

                if st.size > 0:

                    # find every sequence of spikes that qualifies as a burst
                    burst_thresholds = {d['spiketrain']: d['thresholds']*pq.Hz for d in metadata['burst_detectors']}
                    bursts = df.at[i, f'{unit} all bursts'] = _find_bursts(st, burst_thresholds[unit][0], burst_thresholds[unit][1]) # 'at', not 'loc', is important for inserting list into cell

                    first_burst_start = np.nan
                    last_burst_end = np.nan
                    if len(bursts) > 0:

                        for burst in zip(bursts.times, bursts.durations, bursts.array_annotations['spikes']):
                            if is_good_burst(burst):
                                time, duration, n_spikes = burst
                                first_burst_start = time
                                break # quit after finding first good burst

                        for burst in zip(reversed(bursts.times), reversed(bursts.durations), reversed(bursts.array_annotations['spikes'])):
                            if is_good_burst(burst):
                                time, duration, n_spikes = burst
                                last_burst_end = time + duration
                                break # quit after finding first (actually, last) good burst

                    # merge the first and last good bursts and anything in between
                    if np.isfinite(first_burst_start) and np.isfinite(last_burst_end):
                        st_burst = st.time_slice(first_burst_start, last_burst_end)
                        burst_start = df.loc[i, f'{unit} burst start (s)'] = first_burst_start
                        burst_end = df.loc[i, f'{unit} burst end (s)'] = last_burst_end
                        burst_duration = df.loc[i, f'{unit} burst duration (s)'] = (burst_end-burst_start)
                        burst_spike_count = df.loc[i, f'{unit} burst spike count'] = st_burst.size
                        if burst_spike_count > 1:
                            burst_mean_freq = df.loc[i, f'{unit} burst mean frequency (Hz)'] = ((burst_spike_count-1)/burst_duration).rescale('Hz')
                        if video_is_segmented:
                            df.loc[i, f'{unit} burst start (video seg normalized)'] = normalize_time(video_segmentation_times.magnitude, float(first_burst_start))
                            df.loc[i, f'{unit} burst end (video seg normalized)']   = normalize_time(video_segmentation_times.magnitude, float(last_burst_end))
                        if force_is_segmented:
                            df.loc[i, f'{unit} burst start (force seg normalized)'] = normalize_time(force_segmentation_times.magnitude, float(first_burst_start))
                            df.loc[i, f'{unit} burst end (force seg normalized)']   = normalize_time(force_segmentation_times.magnitude, float(last_burst_end))

        # B3/B6/B9
        df['B3/B6/B9 burst start (s)'] = np.nan
        df['B3/B6/B9 burst end (s)'] = np.nan
        df['B3/B6/B9 burst duration (s)'] = 0
        df['B3/B6/B9 burst spike count'] = 0
        df['B3/B6/B9 burst mean frequency (Hz)'] = 0
        b3b6b9_burst_start = df['B3/B6/B9 burst start (s)'] = df[['B6/B9 burst start (s)', 'B3 burst start (s)']].min(axis=1)
        b3b6b9_burst_end = df['B3/B6/B9 burst end (s)']   = df[['B6/B9 burst end (s)',   'B3 burst end (s)']]  .max(axis=1)
        b3b6b9_burst_duration = df['B3/B6/B9 burst duration (s)'] = b3b6b9_burst_end - b3b6b9_burst_start
        b3b6b9_burst_spike_count = df['B3/B6/B9 burst spike count'] = df['B6/B9 burst spike count'] + df['B3 burst spike count']
        b3b6b9_burst_mean_freq = df['B3/B6/B9 burst mean frequency (Hz)'] = (b3b6b9_burst_spike_count-1)/b3b6b9_burst_duration


    ###
    ### FINISH
    ###

    # index the table on 4 variables so that this dataframe can later be merged with others
    df['Animal'] = animal
    df['Food'] = food
    df['Bout_index'] = bout_index
    df = df.reset_index().set_index(['Animal', 'Food', 'Bout_index', 'Behavior_index'])

    df_list += [df]

    pbar.update()

pbar.close()

# stack the per-bout tables into a single table and sort it by its index
# - sort=False keeps the column order instead of sorting non-aligned columns
df_all = pd.concat(df_list, sort=False).sort_index()

end = datetime.datetime.now()
elapsed = end - start
print('elapsed time:', elapsed)

## 🤪 Sanity Checks

In [None]:
# set to False to draw and export one sanity-check figure per feeding bout in the next cell
skip_sanity_checks = True

In [None]:
if not skip_sanity_checks:
    start = datetime.datetime.now()

    # use Neo RawIO lazy loading to load much faster and using less memory
    # - with lazy=True, filtering parameters specified in metadata are ignored
    #     - note: filters are replaced below and applied manually anyway
    # - with lazy=True, loading via time_slice requires neo>=0.8.0
    lazy = True

    # load the metadata containing file paths
    metadata = neurotic.MetadataSelector(file=metadata_file)

    # every figure is exported into this directory; create it (and any missing
    # parents) once up front instead of checking for it after every figure
    export_dir2 = os.path.join(export_dir, 'sanity-checks')
    os.makedirs(export_dir2, exist_ok=True)

    last_data_set_name = None
    pbar = tqdm(total=len(feeding_bouts), unit='figure')
    for (animal, food, bout_index), (data_set_name, time_window) in feeding_bouts.items():

        channel_names = channel_names_by_animal[animal]
        epoch_types = epoch_types_by_food[food]

        # the rows of the results table belonging to this bout
        df = df_all.loc[animal, food, bout_index]



        ###
        ### LOAD DATASET
        ###

        metadata.select(data_set_name)

        # skip reloading the data if it's already in memory
        # - compare names with ==; `is` tests object identity, which can be False
        #   for equal strings and would force an unnecessary reload
        if data_set_name != last_data_set_name:
            blk = neurotic.load_dataset(metadata, lazy=lazy)
            if lazy:
                # manually perform filters
                blk = apply_filters(blk, metadata)

        last_data_set_name = data_set_name



        ###
        ### START FIGURE
        ###

#         figsize = (9.5, 10) # dimensions for notebook
#         figsize = (11, 8.5) # dimensions for printing
        figsize = (16, 9) # dimensions for filling wide screens
        fig, axes = plt.subplots(len(channel_names), 1, sharex=True, figsize=figsize)

        # all force annotations go on the axes of the 'Force' channel; fall back to
        # the bottom axes if 'Force' is not one of this animal's listed channels
        if 'Force' in channel_names:
            force_ax = axes[channel_names.index('Force')]
        else:
            force_ax = axes[-1]



        ###
        ### PLOT SIGNALS
        ###

        # plot all channels for entire time window
        for i, channel in enumerate(channel_names):
            ax = axes[i]
            sig = get_sig(blk, channel)
            sig = sig.time_slice(time_window[0]*pq.s, time_window[1]*pq.s)
            sig = sig.rescale(channel_units[i])
            ax.plot(sig.times, sig.magnitude, c='0.8', lw=1, zorder=-1)

            if i == 0:
                ax.set_title(f'({animal}, {food}, {bout_index}): {data_set_name}')

            ax.set_ylabel(sig.name + ' (' + sig.units.dimensionality.string + ')')
            ax.yaxis.set_label_coords(-0.06, 0.5)

            if i < len(channel_names)-1:
                # remove right, top, and bottom plot borders, and remove x-axis
                sns.despine(ax=ax, bottom=True)
                ax.xaxis.set_visible(False)
            else:
                # remove right and top plot borders, and set x-label
                sns.despine(ax=ax)
                ax.set_xlabel('Time (s)')

        # plot smoothed force for entire time window
        # - the whole force signal is loaded and low-pass filtered before it is cut
        #   down to the bout's time window
        channel = 'Force'
        sig = get_sig(blk, channel)
        if lazy:
            sig = sig.time_slice(None, None)
        sig = elephant.signal_processing.butter(sig, lowpass_freq = 5*pq.Hz)
        sig = sig.time_slice(time_window[0]*pq.s, time_window[1]*pq.s)
        sig = sig.rescale('mN')
        force_ax.plot(sig.times, sig.magnitude, c='k', lw=1, zorder=0)
        force_smoothed_sig = sig



        # iterate over all swallows
        for j, i in enumerate(df.index):

            behavior_start = df.loc[i, 'Start (s)']*pq.s
            behavior_end   = df.loc[i, 'End (s)']*pq.s

            ###
            ### MOVEMENTS
            ###

            # plot inward food movement as a black bar along the top of the force axes
            inward_movement_start = df.loc[i, 'Inward movement start (s)']*pq.s
            inward_movement_end   = df.loc[i, 'Inward movement end (s)']*pq.s
            if np.isfinite(inward_movement_start):
                force_ax.axvspan(
                    inward_movement_start, inward_movement_end,
                    0.99, 1,
                    facecolor='k', edgecolor=None, lw=0)

            # plot outward food movement as a gray bar along the top of the force axes
            outward_movement_start = df.loc[i, 'Outward movement start (s)']*pq.s
            outward_movement_end   = df.loc[i, 'Outward movement end (s)']*pq.s
            if np.isfinite(outward_movement_start):
                force_ax.axvspan(
                    outward_movement_start, outward_movement_end,
                    0.99, 1,
                    facecolor='#666666', edgecolor=None, lw=0)



            ###
            ### FORCE SEGMENTATION
            ###

            prev_force_drop_end = df.loc[i, 'Previous force drop end (s)']*pq.s   # start of previous "Force drop end" epoch
            force_shoulder_end  = df.loc[i, 'Force shoulder end start (s)']*pq.s  # start of "Force shoulder end" epoch
            force_rise_start    = df.loc[i, 'Force rise start start (s)']*pq.s    # start of "Force rise start" epoch
            force_plateau_start = df.loc[i, 'Force plateau start start (s)']*pq.s # start of "Force plateau start" epoch
            force_plateau_end   = df.loc[i, 'Force plateau end start (s)']*pq.s   # start of "Force plateau end" epoch
            force_drop_end      = df.loc[i, 'Force drop end start (s)']*pq.s      # start of "Force drop end" epoch

            # force rise start, plateau start and end, and drop end are required
            force_is_segmented = np.all(np.isfinite(np.array([
                force_rise_start, force_plateau_start, force_plateau_end, force_drop_end])))

            if force_is_segmented:

                force_min_time = df.loc[i, 'Force minimum time (s)']
                force_min = df.loc[i, 'Force minimum (mN)']
                force_peak_time = df.loc[i, 'Force peak time (s)']
                force_peak = df.loc[i, 'Force peak (mN)']
                force_baseline = df.loc[i, 'Force baseline (mN)']

                force_plateau_start_value = df.loc[i, 'Force plateau start value (mN)']
                force_plateau_end_value = df.loc[i, 'Force plateau end value (mN)']
                force_drop_end_value = df.loc[i, 'Force drop end value (mN)']
                force_shoulder_end_value = df.loc[i, 'Force shoulder end value (mN)']

                # get smoothed force for the whole behavior
                # - start at the previous drop end when it is known; otherwise start
                #   shortly before the shoulder end, or 1 s before the rise start
                # - clamp to the plotted window, since AnalogSignal.time_slice raises
                #   for times outside the signal
                if np.isfinite(prev_force_drop_end):
                    sig_start = prev_force_drop_end - 0.01*pq.s
                elif np.isfinite(force_shoulder_end):
                    sig_start = force_shoulder_end - 0.01*pq.s
                else:
                    sig_start = force_rise_start - 1*pq.s
                sig = force_smoothed_sig.time_slice(
                    max(force_smoothed_sig.t_start, sig_start),
                    min(force_smoothed_sig.t_stop, force_drop_end + 0.01*pq.s))
                sig = sig.rescale('mN')

                # plot force shoulder in color
                # - the shoulder is drawn from the previous drop end, so it is skipped
                #   when that boundary is unknown
                if np.isfinite(force_shoulder_end) and np.isfinite(prev_force_drop_end):
                    sig2 = sig.time_slice(max(sig.t_start, prev_force_drop_end), force_shoulder_end)
                    force_ax.plot(sig2.times, sig2.magnitude, c=force_colors['shoulder'], lw=2, zorder=1)

                # plot force rise in color
                sig2 = sig.time_slice(force_rise_start, force_plateau_start)
                force_ax.plot(sig2.times, sig2.magnitude, c=force_colors['rise'], lw=2, zorder=1)

                # plot force plateau in color
                sig2 = sig.time_slice(force_plateau_start, force_plateau_end)
                force_ax.plot(sig2.times, sig2.magnitude, c=force_colors['plateau'], lw=2, zorder=1)

                # plot force peak, baseline, and plateau values
                force_ax.plot([force_peak_time],     [force_peak],                marker=CARETDOWN,  markersize=5, color='k')
#                 force_ax.plot([force_min_time],      [force_min],                 marker=CARETUP,    markersize=5, color='k')
                force_ax.plot([force_rise_start],    [force_baseline],            marker=CARETUP,    markersize=5, color='k')
                force_ax.plot([force_plateau_start], [force_plateau_start_value], marker=CARETRIGHT, markersize=5, color='k')
                force_ax.plot([force_plateau_end],   [force_plateau_end_value],   marker=CARETLEFT,  markersize=5, color='k')

                # plot segmentation boundaries as dotted lines from each force value
                # down to the bottom of the force axes
                for (t, y, c) in [
                        (force_shoulder_end,  force_shoulder_end_value,  force_colors['shoulder']),
                        (force_rise_start,    force_baseline,            force_colors['rise']),
                        (force_plateau_start, force_plateau_start_value, force_colors['plateau']),
                        (force_plateau_end,   force_plateau_end_value,   force_colors['plateau']),
                        (force_drop_end,      force_drop_end_value,      force_colors['drop'])]:
                    if np.isfinite(y):
                        force_ax.add_artist(patches.ConnectionPatch(
                            xyA=(t, y), xyB=(t, 0),
                            coordsA='data', coordsB=force_ax.get_xaxis_transform(),
                            axesA=force_ax, axesB=force_ax,
                            color='0.8', lw=1, ls=':', zorder=-2))



            ###
            ### SPIKES AND BURSTS
            ###

            for unit in units:
                st = df.loc[i, f'{unit} spike train']
                if st is not None and st.size > 0:

                    # get the signal for the behavior with 10 seconds cushion for spikes outside the behavior duration
                    channel = st.annotations['channels'][0]
                    sig = get_sig(blk, channel)
                    sig = sig.time_slice(max(sig.t_start, behavior_start - 10*pq.s), min(sig.t_stop, behavior_end + 10*pq.s))
                    sig = sig.rescale(channel_units[channel_names.index(channel)])
                    unit_ax = axes[channel_names.index(channel)]

                    # plot spikes
                    marker = ['.', 'x'][j%2] # alternate markers between behaviors
                    spike_amplitudes = np.array([sig[sig.time_index(t)] for t in st]) * pq.Quantity(sig.units)
                    unit_ax.scatter(st.times.rescale('s'), spike_amplitudes, marker=marker, c=unit_colors[unit])

                    # plot burst windows, spanning the discriminator's amplitude range
                    # (solid for bursts passing is_good_burst, dashed otherwise)
                    discriminator = next((d for d in metadata['amplitude_discriminators'] if d['name'] == st.name), None)
                    if discriminator is None:
                        raise Exception(f'For data set "{data_set_name}", discriminator "{st.name}" could not be found')
                    bottom = pq.Quantity(discriminator['amplitude'][0], discriminator['units']).rescale(sig.units)
                    top = pq.Quantity(discriminator['amplitude'][1], discriminator['units']).rescale(sig.units)
                    height = top-bottom
                    bursts = df.at[i, f'{unit} all bursts']
                    for burst in zip(bursts.times, bursts.durations, bursts.array_annotations['spikes']):
                        time, duration, n_spikes = burst
                        left = time
                        right = time + duration
                        width = right-left
                        linestyle = '-' if is_good_burst(burst) else '--'
                        rect = patches.Rectangle((left, bottom), width, height, ls=linestyle, edgecolor=unit_colors[unit], fill=False)
                        unit_ax.add_patch(rect)

                    # plot markers for edges of bursts
                    burst_start = df.loc[i, f'{unit} burst start (s)']
                    burst_end = df.loc[i, f'{unit} burst end (s)']
                    if top > 0:
                        unit_ax.plot([burst_start], [top],    marker=CARETDOWN, markersize=5, color='k')
                        unit_ax.plot([burst_end],   [top],    marker=CARETDOWN, markersize=5, color='k')
                    else:
                        unit_ax.plot([burst_start], [bottom], marker=CARETUP,   markersize=5, color='k')
                        unit_ax.plot([burst_end],   [bottom], marker=CARETUP,   markersize=5, color='k')



        ###
        ### FINISH FIGURE
        ###

        # optimize plot margins
        fig.subplots_adjust(
            left   = 0.1,
            right  = 0.99,
            top    = 0.96,
            bottom = 0.06,
            hspace = 0.15,
        )

        # export figure
        fig.savefig(os.path.join(export_dir2, f'{animal} {food} {bout_index}.png'), dpi=300)

        pbar.update()

    pbar.close()

    end = datetime.datetime.now()
    print('elapsed time:', end-start)

---

# Figures

In [None]:
# # shared plot settings
# inches_per_second = 0.5
# subplot_height_in_inches = 1.1
# top_margin_in_inches = 0.50
# bottom_margin_in_inches = 0.39
# left_margin_in_inches = 0.58
# right_margin_in_inches = 0.8

# shared plot settings (all sizes in inches)
# - inches_per_second is the figure width allotted to each second of plotted time
inches_per_second        = 0.75
subplot_height_in_inches = 0.5
top_margin_in_inches, bottom_margin_in_inches = 0.50, 0.39
left_margin_in_inches, right_margin_in_inches = 0.58, 0.8

# the boolean_fig_* figures are made as tall as the 4-trace plots
boolean_fig_traces_matched = 4
boolean_fig_width = 5
boolean_fig_height = (subplot_height_in_inches*boolean_fig_traces_matched
                      + top_margin_in_inches
                      + bottom_margin_in_inches)

## Swallow

In [None]:
start = datetime.datetime.now()

# the feeding bout shown in this figure: (animal, food, bout index)
bout_key = ('JG12', 'Swallow', 0)
(data_set_name, time_window) = feeding_bouts[bout_key]
animal = bout_key[0]

# swallows of the plotted bout only, with the index levels moved into columns and
# a fresh 0..n-1 row index
# - select on all of (Animal, Food, Bout_index): filtering on Food alone keeps the
#   swallows of every bout, whose spikes and bursts would then be overlaid on this
#   recording, and row 0 would be whichever swallow sorts first in df_all
df = df_all.xs(bout_key, level=['Animal', 'Food', 'Bout_index'], drop_level=False).reset_index()

# the bout's time window, and the traces to draw for it
t_start, t_stop = time_window*pq.s
plots = [
    {'channel': 'I2',       'units': 'uV', 'ylim': [ -35,  35], 'scalebar': 50}, #, 'decimation_factor': 100},
    {'channel': 'RN',       'units': 'uV', 'ylim': [ -25,  25], 'scalebar': 25}, #, 'decimation_factor': 100},
    {'channel': 'BN2',      'units': 'uV', 'ylim': [ -45,  45], 'scalebar': 50}, #, 'decimation_factor': 100},
    {'channel': 'BN3-DIST', 'units': 'uV', 'ylim': [ -70,  70], 'scalebar': 50, 'ylabel': 'BN3'}, #, 'decimation_factor': 100},
#     {'channel': 'Force',    'units': 'mN', 'ylim': [-100, 400], 'scalebar': 200}, #, 'decimation_factor': 100},
    {'channel': 'Force',    'units': 'mN', 'ylim': [-20, 400], 'scalebar': 200}, #, 'decimation_factor': 100},
]
plot_names = [p['channel'] for p in plots]
plot_units = [p['units'] for p in plots]

# convert the inch-based layout settings into a figure size and margin fractions;
# the figure width scales with the duration of the bout
fig_height_in_inches, bottom_fraction, top_fraction = solve_figure_vertical_dimensions(
    len(plots), subplot_height_in_inches, bottom_margin_in_inches, top_margin_in_inches, 0)
fig_width_in_inches, left_fraction, right_fraction = solve_figure_horizontal_dimensions(
    1, inches_per_second*(t_stop-t_start).magnitude, left_margin_in_inches, right_margin_in_inches, 0)

kwargs = dict(
    formats = ['png'],
    dpi = 600,
    figsize = (fig_width_in_inches, fig_height_in_inches),
    linewidth = 0.5,
    x_scalebar = 1*pq.s,
    layout_settings = dict(
        left   = left_fraction,
        right  = right_fraction,
        bottom = bottom_fraction,
        top    = top_fraction,
        hspace = 0,
    ),
)

# load the metadata
metadata = neurotic.MetadataSelector(metadata_file)
metadata.select(data_set_name)

# load the data
blk = neurotic.load_dataset(metadata, lazy=True)

# manually perform filters
blk = apply_filters(blk, metadata)

# plot the signals
fig, axes = plot_signals_with_scalebars(blk, t_start, t_stop, plots, **kwargs)

# vertical extent [bottom, top] of each unit's burst box, in its trace's plot units
unit_burst_boxes = {
    'B38':   [-12, 12],
    'I2':    [-30, 25],
    'B8a/b': [-20, 12],
    'B6/B9': [-20, 15],
    'B3':    [-45, 35],
    'B4/B5': [-60, 55],
}

# each neural channel sliced to the entire bout and rescaled to its plot units,
# computed once per channel rather than once per swallow and unit
bout_sigs_by_channel = {}

# the plotted window as plain seconds, for comparison with burst times stored in df
t_start_s = float(t_start.rescale('s'))
t_stop_s = float(t_stop.rescale('s'))

for i in df.index:
    for unit in units:
        st = df.loc[i, f'{unit} spike train']
        if st is not None and st.size > 0:

            # get the neural channel, sliced to the entire bout
            channel = st.annotations['channels'][0]
            if channel not in bout_sigs_by_channel:
                sig = get_sig(blk, channel)
                sig = sig.time_slice(t_start, t_stop)
                bout_sigs_by_channel[channel] = sig.rescale(plot_units[plot_names.index(channel)])
            sig = bout_sigs_by_channel[channel]
            ax = axes[plot_names.index(channel)]

            # plot spikes
            # - only spikes inside the plotted window can be looked up in the sliced
            #   signal; others would index past its ends (negative indices wrap)
            spike_times = st.times.rescale('s')
            spike_times = spike_times[(spike_times >= t_start) & (spike_times < t_stop)]
            spike_amplitudes = np.array([sig[sig.time_index(t)] for t in spike_times]) * pq.Quantity(sig.units)
            ax.scatter(spike_times, spike_amplitudes, marker='.', s=20, c=unit_colors[unit], zorder=3)

            # plot burst windows
            # - boxes are drawn unclipped, so skip any burst that does not overlap
            #   the plotted window (it would otherwise appear outside the axes)
            left = df.loc[i, f'{unit} burst start (s)']
            right = df.loc[i, f'{unit} burst end (s)']
            if (np.isfinite(left) and np.isfinite(right)
                    and float(right) > t_start_s and float(left) < t_stop_s):
                width = right-left
                bottom, top = unit_burst_boxes[unit]
                height = top-bottom
                rect = patches.Rectangle((left, bottom), width, height, linewidth=2, ls='-', edgecolor=unit_colors[unit], fill=False, zorder=3, clip_on=False)
                ax.add_patch(rect)

# Dashed reference line on the force trace. The resting level is not 0 mN because of a DC
# offset, so the line is drawn at the average force before the animal began swallowing.
force_zero = -11.5  # mN
axes[-1].axhline(force_zero, color='0.75', lw=1, ls='--', zorder=-1)

# Force-phase markers (currently disabled):
# times = df.loc[0, 'Force segmentation times (s)'][2:2+6]
# for t in times:
#     axes[-1].axvline(x=t, ymin=0.15, lw=1, ls=':', c='gray', zorder=-1)
# for t, numeral in zip(times, ['1', '2', '3', '4', '5', '1']):  # arabic numerals at phase boundaries
#     axes[-1].annotate(numeral, xy=(t, -0.05), xycoords=('data', 'axes fraction'), ha='center', va='bottom')
# for n, numeral in enumerate(['I', 'II', 'III', 'IV', 'V']):  # roman numerals centered on each phase
#     axes[-1].annotate(numeral, xy=(times[n:n+2].mean(), 1), xycoords=('data', 'axes fraction'), ha='center', va='top')

# Unit name labels, placed by hand for this bout:
# (subplot index, unit, time in s, height as a fraction of the axes, horizontal alignment)
swallow_unit_labels = [
    (2, 'B38',   2978.15, 0.70, 'center'),
    (0, 'I2',    2979.40, 0.75, 'left'),
    (1, 'B8a/b', 2981.10, 0.80, 'center'),
    (2, 'B6/B9', 2980.10, 0.73, 'center'),
    (2, 'B3',    2981.85, 0.79, 'left'),
    (3, 'B4/B5', 2979.10, 0.80, 'right'),
]
for plot_idx, unit_name, x_pos, y_frac, halign in swallow_unit_labels:
    axes[plot_idx].annotate(unit_name, xy=(x_pos, y_frac), xycoords=('data', 'axes fraction'), ha=halign, c=unit_colors[unit_name])

# Behavior bars above the top subplot: x is in data coordinates (s) and y in axes fractions
# (values above 1 lie outside the axes, hence clip_on=False).
phase_bars = [
    # (text, (start, end) in s, (bottom, top) in axes fraction, Rectangle fill settings, text color)
    ('Prot.', tuple(df.loc[0, ['I2 burst start (s)', 'I2 burst end (s)']]),
     (1.2, 1.5), dict(fill=False), 'k'),
    # 'End (s)': the behavior ends with the end of the B43 burst
    ('Retraction', tuple(df.loc[0, ['I2 burst end (s)', 'End (s)']]),
     (1.2, 1.5), dict(facecolor='k', fill=True), 'w'),
    ('Inward', tuple(df.loc[0, ['Inward movement start (s)', 'Inward movement end (s)']]),
     (1.6, 1.9), dict(facecolor='0.75', fill=True), 'k'),
]
for bar_text, (left, right), (bottom, top), fill_style, text_color in phase_bars:
    width = right-left
    height = top-bottom
    rect = patches.Rectangle((left, bottom), width, height, linewidth=2, ls='-', edgecolor='k', zorder=3, clip_on=False, transform=axes[0].get_xaxis_transform(), **fill_style)
    axes[0].add_patch(rect)
    # the text sits slightly below the middle of its 0.3-tall bar
    axes[0].annotate(bar_text, xy=((left+right)/2, bottom+0.12), xycoords=('data', 'axes fraction'), ha='center', va='center', c=text_color, fontsize='small')

fig.savefig(os.path.join(export_dir, 'swallow.png'), dpi=600)

end = datetime.datetime.now()
print('render time:', end-start)

In [None]:
# Boolean (on/off) summary of the swallow: one row per unit, high while that unit is
# bursting and low otherwise, over the same bout window as the signal figure.
df = df_all.query('Food == "Swallow"').reset_index()
bout_key = ('JG12', 'Swallow', 0)
(data_set_name, time_window) = feeding_bouts[bout_key]
animal, _, _ = bout_key
t_start, t_stop = time_window*pq.s

# plain floats (s) for the step coordinates, so Quantity and float values are not mixed
t_start_s = t_start.rescale('s').magnitude.item()
t_stop_s = t_stop.rescale('s').magnitude.item()

# squeeze=False keeps `axes` a 1-D array of Axes even if only one unit is plotted
fig, axes = plt.subplots(len(units), 1, sharex=True, squeeze=False, figsize=(boolean_fig_width, boolean_fig_height))
axes = axes[:, 0]

for ax, unit in zip(axes, units):
    color = unit_colors[unit]

    for i in df.index:
        burst_start, burst_end = df.loc[i, [f'{unit} burst start (s)', f'{unit} burst end (s)']]

        # A burst needs both edges (as in the signal figures). Clip it to the plotted
        # window so the x coordinates stay increasing and the trace stays in the window.
        bursting = np.isfinite(burst_start) and np.isfinite(burst_end)
        if bursting:
            burst_on, burst_off = max(burst_start, t_start_s), min(burst_end, t_stop_s)
            bursting = burst_on < burst_off

        if bursting:
            # step()'s default where='pre' holds y[n] over (x[n-1], x[n]], so the
            # trace is high exactly between burst_on and burst_off
            x = np.array([t_start_s, burst_on, burst_off, t_stop_s])
            y = np.array([0, 0, 1, 0])
        else:
            x = np.array([t_start_s, t_stop_s])
            y = np.array([0, 0])

        # pass the color by keyword: a positional format string only accepts color strings
        ax.step(x, y, color=color)

    ax.set_ylim([-0.1, 1.1])
    ax.set_ylabel(unit, c=color, rotation=0, ha='right', va='center')

    # hide the box around the subplot
    ax.set_frame_on(False)

    # disable tick marks and tick labels
    ax.tick_params(
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False)

# 1 s time scale bar anchored at the bottom-right corner of the last row
add_scalebar(axes[-1],
    sizex=1,
    labelx='1 s',

    loc='upper right',
    bbox_to_anchor=(1, 0),

    borderpad=1,
    sep=5,
    barwidth=2,
)

fig.tight_layout(h_pad=0.5, w_pad=0, pad=0)

fig.savefig(os.path.join(export_dir, 'swallow-boolean.png'), dpi=600)

## Bite

In [None]:
start = datetime.datetime.now()  # timer for the render-time report at the end of this cell

# the first bite bout recorded from animal JG12
df = df_all.query('Food == "Bite"').reset_index()
bout_key = ('JG12', 'Bite', 0)
(data_set_name, time_window) = feeding_bouts[bout_key]
animal, _, _ = bout_key

t_start, t_stop = time_window*pq.s

# One subplot per recorded channel, top to bottom; y-limits and scale bars are in the listed
# units. Decimation ('decimation_factor': 100) is left disabled for these traces.
plots = [
    {'channel': 'I2',       'units': 'uV', 'ylim': [ -80,  80], 'scalebar': 50},
    {'channel': 'RN',       'units': 'uV', 'ylim': [ -30,  30], 'scalebar': 25},
    {'channel': 'BN2',      'units': 'uV', 'ylim': [ -45,  45], 'scalebar': 50},
    {'channel': 'BN3-DIST', 'units': 'uV', 'ylim': [ -70,  70], 'scalebar': 50, 'ylabel': 'BN3'},
]
plot_names, plot_units = ([spec[key] for spec in plots] for key in ('channel', 'units'))

# figure size and margin fractions from the per-subplot height and the time scale (inches per second)
fig_height_in_inches, bottom_fraction, top_fraction = solve_figure_vertical_dimensions(
    len(plots), subplot_height_in_inches, bottom_margin_in_inches, top_margin_in_inches, 0)
fig_width_in_inches, left_fraction, right_fraction = solve_figure_horizontal_dimensions(
    1, inches_per_second*(t_stop-t_start).magnitude, left_margin_in_inches, right_margin_in_inches, 0)

kwargs = {
    'formats': ['png'],
    'dpi': 600,
    'figsize': (fig_width_in_inches, fig_height_in_inches),
    'linewidth': 0.5,
    'x_scalebar': 1*pq.s,
    'layout_settings': {
        'left':   left_fraction,
        'right':  right_fraction,
        'bottom': bottom_fraction,
        'top':    top_fraction,
        'hspace': 0,
    },
}

# select the data set's metadata, lazily load its data, then apply the filters manually
metadata = neurotic.MetadataSelector(metadata_file)
metadata.select(data_set_name)
blk = apply_filters(neurotic.load_dataset(metadata, lazy=True), metadata)

# draw the signals
fig, axes = plot_signals_with_scalebars(blk, t_start, t_stop, plots, **kwargs)

# vertical extent (in the subplot's units) of the outline drawn around each unit's burst
unit_burst_boxes = {
    'B38':   [-12, 12],
    'I2':    [-70, 72],
    'B8a/b': [-28, 19],
    'B6/B9': [-20, 15],
    'B3':    [-45, 35],
    'B4/B5': [-60, 55],
}

# mark each unit's spikes on the trace it was sorted from, and outline its burst
for i in df.index:
    for unit in units:
        st = df.loc[i, f'{unit} spike train']
        if st is None or st.size == 0:
            continue

        # the neural channel for this unit, and the subplot that shows it
        channel = st.annotations['channels'][0]
        plot_index = plot_names.index(channel)
        ax = axes[plot_index]

        # the signal over the plotted window, in the subplot's units
        sig = get_sig(blk, channel).time_slice(t_start, t_stop).rescale(plot_units[plot_index])

        # Only mark spikes inside the plotted window. A spike outside it would map to a
        # negative index into the sliced signal (reading from the far end of the trace) or
        # to an index past its end (IndexError). The clip guards spikes that round onto the
        # window edges (time_index presumably rounds to the nearest sample).
        st_window = st.time_slice(t_start, t_stop)
        if st_window.size > 0:
            last_index = len(sig) - 1
            spike_amplitudes = np.array([sig[int(np.clip(sig.time_index(t), 0, last_index))] for t in st_window]) * pq.Quantity(sig.units)
            ax.scatter(st_window.times.rescale('s'), spike_amplitudes, marker='.', s=20, c=unit_colors[unit], zorder=3)

        # outline the burst window, if the unit burst during this bout
        left = df.loc[i, f'{unit} burst start (s)']
        right = df.loc[i, f'{unit} burst end (s)']
        if np.isfinite(left) and np.isfinite(right):
            bottom, top = unit_burst_boxes[unit]
            width = right-left
            height = top-bottom
            rect = patches.Rectangle((left, bottom), width, height, linewidth=2, ls='-', edgecolor=unit_colors[unit], fill=False, zorder=3, clip_on=False)
            ax.add_patch(rect)

# Unit name labels, placed by hand for this bout (B38 and B3 are not labeled here):
# (subplot index, unit, time in s, height as a fraction of the axes, horizontal alignment)
bite_unit_labels = [
    (0, 'I2',    2172.15, 0.75, 'left'),
    (1, 'B8a/b', 2172.20, 0.85, 'center'),
    (2, 'B6/B9', 2173.00, 0.20, 'right'),
    (3, 'B4/B5', 2171.80, 0.80, 'right'),
]
for plot_idx, unit_name, x_pos, y_frac, halign in bite_unit_labels:
    axes[plot_idx].annotate(unit_name, xy=(x_pos, y_frac), xycoords=('data', 'axes fraction'), ha=halign, c=unit_colors[unit_name])

# Behavior bars above the top subplot: x is in data coordinates (s) and y in axes fractions
# (values above 1 lie outside the axes, hence clip_on=False). No inward-movement bar is
# drawn for the bite.
phase_bars = [
    # (text, (start, end) in s, (bottom, top) in axes fraction, Rectangle fill settings, text color)
    ('Protraction', tuple(df.loc[0, ['I2 burst start (s)', 'I2 burst end (s)']]),
     (1.2, 1.5), dict(fill=False), 'k'),
    # 'End (s)': the behavior ends with the end of the B43 burst
    ('Retraction', tuple(df.loc[0, ['I2 burst end (s)', 'End (s)']]),
     (1.2, 1.5), dict(facecolor='k', fill=True), 'w'),
]
for bar_text, (left, right), (bottom, top), fill_style, text_color in phase_bars:
    width = right-left
    height = top-bottom
    rect = patches.Rectangle((left, bottom), width, height, linewidth=2, ls='-', edgecolor='k', zorder=3, clip_on=False, transform=axes[0].get_xaxis_transform(), **fill_style)
    axes[0].add_patch(rect)
    # the text sits slightly below the middle of its 0.3-tall bar
    axes[0].annotate(bar_text, xy=((left+right)/2, bottom+0.12), xycoords=('data', 'axes fraction'), ha='center', va='center', c=text_color, fontsize='small')

fig.savefig(os.path.join(export_dir, 'bite.png'), dpi=600)

end = datetime.datetime.now()
print('render time:', end-start)

In [None]:
# Boolean (on/off) summary of the bite: one row per unit, high while that unit is
# bursting and low otherwise, over the same bout window as the signal figure.
df = df_all.query('Food == "Bite"').reset_index()
bout_key = ('JG12', 'Bite', 0)
(data_set_name, time_window) = feeding_bouts[bout_key]
animal, _, _ = bout_key
t_start, t_stop = time_window*pq.s

# plain floats (s) for the step coordinates, so Quantity and float values are not mixed
t_start_s = t_start.rescale('s').magnitude.item()
t_stop_s = t_stop.rescale('s').magnitude.item()

# squeeze=False keeps `axes` a 1-D array of Axes even if only one unit is plotted
fig, axes = plt.subplots(len(units), 1, sharex=True, squeeze=False, figsize=(boolean_fig_width, boolean_fig_height))
axes = axes[:, 0]

for ax, unit in zip(axes, units):
    color = unit_colors[unit]

    for i in df.index:
        burst_start, burst_end = df.loc[i, [f'{unit} burst start (s)', f'{unit} burst end (s)']]

        # A burst needs both edges (as in the signal figures). Clip it to the plotted
        # window so the x coordinates stay increasing and the trace stays in the window.
        bursting = np.isfinite(burst_start) and np.isfinite(burst_end)
        if bursting:
            burst_on, burst_off = max(burst_start, t_start_s), min(burst_end, t_stop_s)
            bursting = burst_on < burst_off

        if bursting:
            # step()'s default where='pre' holds y[n] over (x[n-1], x[n]], so the
            # trace is high exactly between burst_on and burst_off
            x = np.array([t_start_s, burst_on, burst_off, t_stop_s])
            y = np.array([0, 0, 1, 0])
        else:
            x = np.array([t_start_s, t_stop_s])
            y = np.array([0, 0])

        # pass the color by keyword: a positional format string only accepts color strings
        ax.step(x, y, color=color)

    ax.set_ylim([-0.1, 1.1])
    ax.set_ylabel(unit, c=color, rotation=0, ha='right', va='center')

    # hide the box around the subplot
    ax.set_frame_on(False)

    # disable tick marks and tick labels
    ax.tick_params(
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False)

# 1 s time scale bar anchored at the bottom-right corner of the last row
add_scalebar(axes[-1],
    sizex=1,
    labelx='1 s',

    loc='upper right',
    bbox_to_anchor=(1, 0),

    borderpad=1,
    sep=5,
    barwidth=2,
)

fig.tight_layout(h_pad=0.5, w_pad=0, pad=0)

fig.savefig(os.path.join(export_dir, 'bite-boolean.png'), dpi=600)

## Rejection

In [None]:
start = datetime.datetime.now()  # timer for the render-time report at the end of this cell

# the first rejection bout recorded from animal JG12
df = df_all.query('Food == "Rejection"').reset_index()
bout_key = ('JG12', 'Rejection', 0)
(data_set_name, time_window) = feeding_bouts[bout_key]
animal, _, _ = bout_key

# plot a hand-picked sub-window rather than the full bout (time_window*pq.s)
t_start, t_stop = [1429.3, 1438.6]*pq.s

# One subplot per recorded channel, top to bottom; y-limits and scale bars are in the listed
# units. Decimation ('decimation_factor': 100) is left disabled for these traces.
plots = [
    {'channel': 'I2',       'units': 'uV', 'ylim': [ -50,  50], 'scalebar': 50},
    {'channel': 'RN',       'units': 'uV', 'ylim': [ -25,  25], 'scalebar': 25},
    {'channel': 'BN2',      'units': 'uV', 'ylim': [ -45,  35], 'scalebar': 50},
    {'channel': 'BN3-DIST', 'units': 'uV', 'ylim': [ -70,  70], 'scalebar': 50, 'ylabel': 'BN3'},
]
plot_names, plot_units = ([spec[key] for spec in plots] for key in ('channel', 'units'))

# figure size and margin fractions from the per-subplot height and the time scale (inches per second)
fig_height_in_inches, bottom_fraction, top_fraction = solve_figure_vertical_dimensions(
    len(plots), subplot_height_in_inches, bottom_margin_in_inches, top_margin_in_inches, 0)
fig_width_in_inches, left_fraction, right_fraction = solve_figure_horizontal_dimensions(
    1, inches_per_second*(t_stop-t_start).magnitude, left_margin_in_inches, right_margin_in_inches, 0)

kwargs = {
    'formats': ['png'],
    'dpi': 600,
    'figsize': (fig_width_in_inches, fig_height_in_inches),
    'linewidth': 0.5,
    'x_scalebar': 1*pq.s,
    'layout_settings': {
        'left':   left_fraction,
        'right':  right_fraction,
        'bottom': bottom_fraction,
        'top':    top_fraction,
        'hspace': 0,
    },
}

# select the data set's metadata, lazily load its data, then apply the filters manually
metadata = neurotic.MetadataSelector(metadata_file)
metadata.select(data_set_name)
blk = apply_filters(neurotic.load_dataset(metadata, lazy=True), metadata)

# draw the signals
fig, axes = plot_signals_with_scalebars(blk, t_start, t_stop, plots, **kwargs)

# vertical extent (in the subplot's units) of the outline drawn around each unit's burst
unit_burst_boxes = {
    'B38':   [-12, 12],
    'I2':    [-30, 38],
    'B8a/b': [-23, 15],
    'B6/B9': [-20, 15],
    'B3':    [-40, 29],
    'B4/B5': [-60, 55],
}

# mark each unit's spikes on the trace it was sorted from, and outline its burst
for i in df.index:
    for unit in units:
        st = df.loc[i, f'{unit} spike train']
        if st is None or st.size == 0:
            continue

        # the neural channel for this unit, and the subplot that shows it
        channel = st.annotations['channels'][0]
        plot_index = plot_names.index(channel)
        ax = axes[plot_index]

        # the signal over the plotted window, in the subplot's units
        sig = get_sig(blk, channel).time_slice(t_start, t_stop).rescale(plot_units[plot_index])

        # Only mark spikes inside the plotted window. This figure plots a hand-picked sub-window
        # of the bout, so spikes can fall outside it. Such a spike would map to a negative index
        # into the sliced signal (reading from the far end of the trace) or to an index past its
        # end (IndexError). The clip guards spikes that round onto the window edges (time_index
        # presumably rounds to the nearest sample).
        st_window = st.time_slice(t_start, t_stop)
        if st_window.size > 0:
            last_index = len(sig) - 1
            spike_amplitudes = np.array([sig[int(np.clip(sig.time_index(t), 0, last_index))] for t in st_window]) * pq.Quantity(sig.units)
            ax.scatter(st_window.times.rescale('s'), spike_amplitudes, marker='.', s=20, c=unit_colors[unit], zorder=3)

        # outline the burst window, if the unit burst during this bout
        left = df.loc[i, f'{unit} burst start (s)']
        right = df.loc[i, f'{unit} burst end (s)']
        if np.isfinite(left) and np.isfinite(right):
            bottom, top = unit_burst_boxes[unit]
            width = right-left
            height = top-bottom
            rect = patches.Rectangle((left, bottom), width, height, linewidth=2, ls='-', edgecolor=unit_colors[unit], fill=False, zorder=3, clip_on=False)
            ax.add_patch(rect)

# Unit name labels, placed by hand for this window:
# (subplot index, unit, time in s, height as a fraction of the axes, horizontal alignment)
rejection_unit_labels = [
    (2, 'B38',   1430.20, 0.70, 'center'),
    (0, 'I2',    1434.35, 0.75, 'left'),
    (1, 'B8a/b', 1431.60, 0.85, 'center'),
    (2, 'B6/B9', 1436.20, 0.73, 'right'),
    (2, 'B3',    1437.80, 0.79, 'left'),
    (3, 'B4/B5', 1432.40, 0.80, 'right'),
]
for plot_idx, unit_name, x_pos, y_frac, halign in rejection_unit_labels:
    axes[plot_idx].annotate(unit_name, xy=(x_pos, y_frac), xycoords=('data', 'axes fraction'), ha=halign, c=unit_colors[unit_name])

# Behavior bars above the top subplot: x is in data coordinates (s) and y in axes fractions
# (values above 1 lie outside the axes, hence clip_on=False). No inward-movement bar is
# drawn for the rejection; an outward-movement bar is drawn instead.
phase_bars = [
    # (text, (start, end) in s, (bottom, top) in axes fraction, Rectangle fill settings, text color)
    ('Protraction', tuple(df.loc[0, ['I2 burst start (s)', 'I2 burst end (s)']]),
     (1.2, 1.5), dict(fill=False), 'k'),
    # retraction runs from the end of the I2 burst to the right edge of the plotted window
    # (t_stop), not to the bout's 'End (s)'
    ('Retraction', (df.loc[0, 'I2 burst end (s)'], t_stop.magnitude),
     (1.2, 1.5), dict(facecolor='k', fill=True), 'w'),
    ('Outward', tuple(df.loc[0, ['Outward movement start (s)', 'Outward movement end (s)']]),
     (1.6, 1.9), dict(facecolor='0.75', fill=True), 'k'),
]
for bar_text, (left, right), (bottom, top), fill_style, text_color in phase_bars:
    width = right-left
    height = top-bottom
    rect = patches.Rectangle((left, bottom), width, height, linewidth=2, ls='-', edgecolor='k', zorder=3, clip_on=False, transform=axes[0].get_xaxis_transform(), **fill_style)
    axes[0].add_patch(rect)
    # the text sits slightly below the middle of its 0.3-tall bar
    axes[0].annotate(bar_text, xy=((left+right)/2, bottom+0.12), xycoords=('data', 'axes fraction'), ha='center', va='center', c=text_color, fontsize='small')

fig.savefig(os.path.join(export_dir, 'rejection.png'), dpi=600)

end = datetime.datetime.now()
print('render time:', end-start)

In [None]:
# Boolean (on/off) summary of the rejection: one row per unit, high while that unit is
# bursting and low otherwise, over the same hand-picked window as the signal figure.
df = df_all.query('Food == "Rejection"').reset_index()
bout_key = ('JG12', 'Rejection', 0)
(data_set_name, time_window) = feeding_bouts[bout_key]
animal, _, _ = bout_key
# a hand-picked sub-window rather than the full bout (time_window*pq.s)
t_start, t_stop = [1429.3, 1438.6]*pq.s

# plain floats (s) for the step coordinates, so Quantity and float values are not mixed
t_start_s = t_start.rescale('s').magnitude.item()
t_stop_s = t_stop.rescale('s').magnitude.item()

# squeeze=False keeps `axes` a 1-D array of Axes even if only one unit is plotted
fig, axes = plt.subplots(len(units), 1, sharex=True, squeeze=False, figsize=(boolean_fig_width, boolean_fig_height))
axes = axes[:, 0]

for ax, unit in zip(axes, units):
    color = unit_colors[unit]

    for i in df.index:
        burst_start, burst_end = df.loc[i, [f'{unit} burst start (s)', f'{unit} burst end (s)']]

        # A burst needs both edges (as in the signal figures). Because the window is a
        # sub-window of the bout, clip the burst to it so the x coordinates stay increasing
        # and the trace does not extend the shared x-axis past the window.
        bursting = np.isfinite(burst_start) and np.isfinite(burst_end)
        if bursting:
            burst_on, burst_off = max(burst_start, t_start_s), min(burst_end, t_stop_s)
            bursting = burst_on < burst_off

        if bursting:
            # step()'s default where='pre' holds y[n] over (x[n-1], x[n]], so the
            # trace is high exactly between burst_on and burst_off
            x = np.array([t_start_s, burst_on, burst_off, t_stop_s])
            y = np.array([0, 0, 1, 0])
        else:
            x = np.array([t_start_s, t_stop_s])
            y = np.array([0, 0])

        # pass the color by keyword: a positional format string only accepts color strings
        ax.step(x, y, color=color)

    ax.set_ylim([-0.1, 1.1])
    ax.set_ylabel(unit, c=color, rotation=0, ha='right', va='center')

    # hide the box around the subplot
    ax.set_frame_on(False)

    # disable tick marks and tick labels
    ax.tick_params(
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False)

# 1 s time scale bar anchored at the bottom-right corner of the last row
add_scalebar(axes[-1],
    sizex=1,
    labelx='1 s',

    loc='upper right',
    bbox_to_anchor=(1, 0),

    borderpad=1,
    sep=5,
    barwidth=2,
)

fig.tight_layout(h_pad=0.5, w_pad=0, pad=0)

fig.savefig(os.path.join(export_dir, 'rejection-boolean.png'), dpi=600)