\>\>\> CLICK [HERE](#Figures) TO JUMP DOWN TO MANUSCRIPT FIGURES <<<

# Preamble

## Import Packages

In [None]:
import datetime
import numpy as np
import quantities as pq
import elephant
import pandas as pd
import statsmodels.api as sm
import neurotic
from neurotic.datasets.data import _detect_spikes
from utils import BehaviorsDataFrame, DownsampleNeoSignal

pq.markup.config.use_unicode = True  # allow symbols like mu for micro in output
pq.mN = pq.UnitQuantity('millinewton', pq.N/1e3, symbol = 'mN');  # define millinewton

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.ticker import MultipleLocator
import seaborn as sns

plt.rcParams.update({'figure.max_open_warning': 0})

## IPython Magics

In [None]:
# make figures interactive and open in a separate window
# %matplotlib qt

# make figures interactive and inline
%matplotlib notebook

# make figures non-interactive and inline
# %matplotlib inline

## Plot Settings

In [None]:
# General plot settings shared by every figure: tick-style axes, default font
# scale, Palatino text. (Add context='poster' to the call below for larger,
# presentation-sized plot elements.)
sns.set(style='ticks', font_scale=1, font='Palatino Linotype')

In [None]:
# One marker glyph and one cycle color ('C0'..'C6') per series, indexed in parallel.
markers = list('.+x14oD')
colors = [f'C{k}' for k in range(7)]

# Plot color for each neural unit.
unit_colors = {
    'I2 spikes': 'C9',  # cyan
    'B8a/b':     'C6',  # pink
    'B3':        'C3',  # red
    'B6/B9':     'C2',  # green
    'B38':       'C1',  # orange
}

# Each force phase is drawn in the color of the unit paired with it here.
force_colors = {
    phase: unit_colors[unit]
    for phase, unit in (
        ('dip', 'I2 spikes'),
        ('initial rise', 'B8a/b'),
        ('rise', 'B6/B9'),
        ('plateau', 'B3'),
        ('shoulder', 'B38'),
    )
}

## Data Parameters

In [None]:
def _burst_thresholds(b6b9_end_freq=5):
    '''Return a fresh dict of (start, end) burst-detection frequencies per unit.

    Every bout uses the same thresholds except the B6/B9 end frequency, which
    can be overridden (it was reduced from 5 Hz to 3 Hz for animal JG08).
    A new dict of new Quantity arrays is built on every call, so bouts never
    share threshold objects.
    '''
    return {
        'I2 spikes': (10,  5)*pq.Hz, # same as Cullins et al. 2015a (based on Hurwitz et al. 1996)
        'B8a/b':     ( 3,  3)*pq.Hz, # same as Cullins et al. 2015a (based on Morton and Chiel 1993a)
        'B3':        ( 8,  2)*pq.Hz, # based on Lu et al. 2015
        'B6/B9':     (10, b6b9_end_freq)*pq.Hz, # based on Lu et al. 2015 (end threshold may be reduced per animal)
        'B38':       ( 8,  5)*pq.Hz, # based on McManus et al. 2014
    }


# Maps (animal, food, bout_index) to a 5-tuple:
#     (data_set_name, channels_to_keep, time_window [s], epoch_types_to_keep, burst_thresholds)
feeding_bouts = {
    ('JG07', 'Tape nori', 0): (
        'IN VIVO / JG07 / 2018-05-20 / 002',
        ['I2-L', 'RN-L', 'BN2-L', 'BN3-L', 'Force'],
        [2718, 2755], # 5 swallows
        ['Swallow (tape nori)'],
        _burst_thresholds(),
    ),

    # JG08: B6/B9 end threshold reduced from 5 Hz to 3 Hz for this animal
    ('JG08', 'Tape nori', 0): (
        'IN VIVO / JG08 / 2018-06-21 / 002',
        ['I2', 'RN', 'BN2', 'BN3', 'Force'],
        [148, 208], # 7 swallows, some bucket and head movement
        ['Swallow (tape nori)'],
        _burst_thresholds(b6b9_end_freq=3),
    ),

    ('JG08', 'Tape nori', 1): (
        'IN VIVO / JG08 / 2018-06-21 / 002',
        ['I2', 'RN', 'BN2', 'BN3', 'Force'],
        [664, 701], # 5 swallows, large bucket movement
        ['Swallow (tape nori)'],
        _burst_thresholds(b6b9_end_freq=3),
    ),

    ('JG08', 'Tape nori', 2): (
        'IN VIVO / JG08 / 2018-06-21 / 002',
        ['I2', 'RN', 'BN2', 'BN3', 'Force'],
        [1452, 1477], # 3 swallows, some bucket movement
        ['Swallow (tape nori)'],
        _burst_thresholds(b6b9_end_freq=3),
    ),

    ('JG11', 'Tape nori', 0): (
        'IN VIVO / JG11 / 2019-04-03 / 004',
        ['I2', 'RN', 'BN2', 'BN3-PROX', 'Force'],
        [1233, 1280], # 5 swallows
        ['Swallow (tape nori)'],
        _burst_thresholds(),
    ),

    ('JG12', 'Tape nori', 0): (
        'IN VIVO / JG12 / 2019-05-10 / 002',
        ['I2', 'RN', 'BN2', 'BN3-DIST', 'Force'],
        [437, 465], # 4 swallows
        ['Swallow (tape nori)'],
        _burst_thresholds(),
    ),

    ('JG12', 'Tape nori', 1): (
        'IN VIVO / JG12 / 2019-05-10 / 002',
        ['I2', 'RN', 'BN2', 'BN3-DIST', 'Force'],
        [2901, 2937], # 5 swallows
        ['Swallow (tape nori)'],
        _burst_thresholds(),
    ),

    ('JG14', 'Tape nori', 0): (
        'IN VIVO / JG14 / 2019-07-29 / 004',
        ['I2', 'RN', 'BN2', 'BN3-PROX', 'Force'],
        [831, 870], # 5 swallows
        ['Swallow (tape nori)'],
        _burst_thresholds(),
    ),
}

## Helper Functions

In [None]:
def query_union(queries):
    '''Combine query strings with "|", parenthesizing each one and skipping None entries.'''
    wrapped = (f'({q})' for q in queries if q is not None)
    return '|'.join(wrapped)

def label2query(label):
    '''Build a query string matching one animal/food label.

    The label is split positionally: the first 4 characters are the animal ID,
    character 5 is skipped (presumably a separator such as a space), and the
    rest is the food name, e.g. "JG07 Tape nori".
    '''
    animal_id, food_name = label[:4], label[5:]
    return f'(Animal == "{animal_id}") & (Food == "{food_name}")'

def contains(series, string):
    '''Return an element-wise mask: True where `string` is contained in the element (via `in`).'''
    def _has_substring(value):
        return string in value
    return series.map(_has_substring)

In [None]:
def update_filter(metadata, new_filter):
    '''Insert or replace a filter spec in metadata['filters'], matched by its 'channel'.

    The first existing filter with the same channel is replaced in place;
    if there is none, new_filter is appended. Mutates `metadata`; returns None.
    '''
    filters = metadata['filters']
    for position, existing in enumerate(filters):
        if existing['channel'] == new_filter['channel']:
            filters[position] = new_filter
            break
    else:
        filters.append(new_filter)

In [None]:
def get_sig(blk, channel):
    '''Return the first analog signal in blk's first segment whose name equals `channel`.

    Raises a plain Exception if no signal in that segment has that name.
    '''
    for candidate in blk.segments[0].analogsignals:
        if candidate.name == channel:
            return candidate
    raise Exception(f'Channel "{channel}" could not be found')

In [None]:
def finite_min(x, y):
    '''Return the smaller of x and y, ignoring non-finite values.

    Workaround for Quantities warning about comparison to NaN: NaN/inf never
    reaches min(). If neither argument is finite, returns np.nan.
    '''
    x_finite = np.isfinite(x)
    y_finite = np.isfinite(y)
    if x_finite and y_finite:
        return min(x, y)
    if x_finite:
        return x
    if y_finite:
        return y
    return np.nan

def finite_max(x, y):
    '''Return the larger of x and y, ignoring non-finite values.

    Workaround for Quantities warning about comparison to NaN: NaN/inf never
    reaches max(). If neither argument is finite, returns np.nan.
    '''
    x_finite = np.isfinite(x)
    y_finite = np.isfinite(y)
    if x_finite and y_finite:
        return max(x, y)
    if x_finite:
        return x
    if y_finite:
        return y
    return np.nan

In [None]:
def find_bursts(st, burst_thresholds):
    '''Find every sequence of spikes that qualifies as a burst.

    Instantaneous firing frequency (IFF) is 1/ISI for each pair of consecutive
    spikes, so IFF entry k belongs to the interval that starts at spike k.
    - A burst starts at the first spike k (after the previous burst's end)
      whose IFF is strictly above the start threshold.
    - It ends at the first later spike whose IFF is strictly below the end
      threshold. If no such spike exists, it ends at the last spike of the
      train, and scanning stops.

    Parameters
    ----------
    st : spike train (indexed and rescaled as a Quantity; presumably a neo SpikeTrain)
    burst_thresholds : (start_freq, end_freq) pair of frequency Quantities

    Returns
    -------
    list of dicts with keys 'Start (s)', 'End (s)', 'Duration (s)' and
    'Number of spikes', in chronological order.
    '''
    iff = 1/elephant.statistics.isi(st).rescale('s')
    start_freq, end_freq = burst_thresholds

    # candidate indexes never change during the scan, so compute them once
    above_start = np.where(iff > start_freq)[0]
    below_end = np.where(iff < end_freq)[0]

    bursts = []
    previous_end = -1
    while previous_end < iff.size:
        next_starts = above_start[above_start > previous_end]
        if next_starts.size == 0:
            break
        first = next_starts[0] # first time that iff rises above start threshold

        next_ends = below_end[below_end > first]
        # first time after start that iff drops below end threshold,
        # or -1 meaning "end of spike train (include all spikes after start)"
        last = next_ends[0] if next_ends.size > 0 else -1

        bursts.append({
            'Start (s)': st[first].rescale('s'),
            'End (s)': st[last].rescale('s'),
            'Duration (s)': (st[last] - st[first]).rescale('s'),
            'Number of spikes': last-first+1 if last > 0 else st.size-first,
        })

        if last == -1:
            break
        previous_end = last

    return bursts

In [None]:
def is_good_burst(burst):
    '''A burst qualifies if it lasts at least 0.5 s and contains more than 2 spikes.'''
    long_enough = burst['Duration (s)'] >= 0.5*pq.s
    return long_enough and burst['Number of spikes'] > 2

In [None]:
def print_column_analysis(column, df=None):
    '''Print the mean, standard deviation and coefficient of variation of one column.

    Parameters
    ----------
    column : str
        Column name. If it ends in a parenthesized unit, e.g. "Force peak (mN)",
        that unit is appended to the printed mean/std. Columns without one
        (e.g. spike counts) are printed without a unit.
    df : DataFrame, optional
        Table to read from. Defaults to the notebook-global ``df_all``, which
        is not defined in the visible cells (presumably built later in the
        notebook), so the default raises NameError until it exists.

    The std is pandas' default (sample std, ddof=1). CV = std/|mean|.
    '''
    if df is None:
        df = df_all

    # take the text inside the last "(...)" as the unit; no "(" means no unit
    _, has_paren, tail = column.rpartition('(')
    unit = tail.split(')')[0] if has_paren else ''
    unit_suffix = f' {unit}' if unit else ''

    values = df[column]
    mean = values.mean()
    std = values.std()
    cv = std/abs(mean)
    print(f'Overall mean +/- std: {mean:.2f} +/- {std:.2f}{unit_suffix}')
    print(f'Overall coefficient of variation CV: {cv:.2f}')

In [None]:
def plot_vertical_lines_with_delay(axes, t, delay, force_y, color):
    '''Draw three dotted connector lines linking a force time t to a neural time t-delay.

    Per the layout this notebook uses, axes[-1] is the force plot and axes[0]
    through axes[-2] are the neural plots stacked above it. Drawn lines:
      1. in the force plot at x=t, from y=force_y (data coords) to the top of the axes;
      2. at x=t-delay, from the bottom of axes[-2] up to the top of axes[0];
      3. a connector from the bottom of line 2 to the top of line 1.
    All patches are added as artists of axes[-1].
    '''
    force_ax = axes[-1]
    lowest_neural_ax = axes[-2]
    top_ax = axes[0]
    t_shifted = t-delay

    def _dotted(xy_a, xy_b, coords_a, coords_b, ax_a, ax_b):
        # every connector shares the same style and is owned by the force axes
        force_ax.add_artist(patches.ConnectionPatch(
            xyA=xy_a, xyB=xy_b,
            coordsA=coords_a, coordsB=coords_b,
            axesA=ax_a, axesB=ax_b,
            color=color, lw=1, ls=':'))

    # vertical line in force plot at time t
    _dotted((t, force_y), (t, 1),
            'data', force_ax.get_xaxis_transform(),
            force_ax, force_ax)

    # vertical line through all neural plots at time t-delay
    _dotted((t_shifted, 0), (t_shifted, 1),
            lowest_neural_ax.get_xaxis_transform(), top_ax.get_xaxis_transform(),
            lowest_neural_ax, top_ax)

    # connect the two lines
    _dotted((t_shifted, 0), (t, 1),
            lowest_neural_ax.get_xaxis_transform(), force_ax.get_xaxis_transform(),
            lowest_neural_ax, force_ax)

In [None]:
def lighten_color(color, amount=0.5):
    '''
    Lighten the given color by multiplying (1 - lightness) by the given amount.

    The HLS lightness L becomes 1 - amount*(1 - L): amount=1 leaves the color
    unchanged and amount=0 gives white. Input can be a matplotlib color string,
    hex string, or RGB tuple. Returns an (r, g, b) tuple of floats.

    Examples:
    >> lighten_color('g', 0.3)
    >> lighten_color('#F034A3', 0.6)
    >> lighten_color((.3,.55,.1), 0.5)

    Adapted from https://stackoverflow.com/a/49601444/3314376
    '''
    import matplotlib.colors as mc
    import colorsys
    try:
        c = mc.cnames[color]
    except (KeyError, TypeError):
        # not a named color (KeyError) or an unhashable sequence such as a
        # list of RGB values (TypeError): let to_rgb interpret it directly
        c = color
    hue, lightness, saturation = colorsys.rgb_to_hls(*mc.to_rgb(c))
    return colorsys.hls_to_rgb(hue, 1 - amount * (1 - lightness), saturation)

## Import and Process the Data

### Main dataframe used for most figures

In [None]:
start = datetime.datetime.now()

# use Neo RawIO lazy loading to load much faster and using less memory
# - with lazy=True, filtering parameters specified in metadata are ignored
# - with lazy=True, loading via time_slice requires neo>=0.8.0.dev
# - IMPORTANT: force and I2 filters affect smoothness and possibly threshold crossings and spike detection
lazy = False

# load the metadata containing file paths
metadata = neurotic.MetadataSelector(file='../../data/metadata.yml')

# filter epochs for each bout and perform calculations
df_list = []
last_data_set_name = None
for (animal, food, bout_index), (data_set_name, channels_to_keep, time_window, epoch_types_to_keep, burst_thresholds) in feeding_bouts.items():
    
    ###
    ### LOCATE BEHAVIOR EPOCHS AND SUBEPOCHS
    ###
    
    metadata.select(data_set_name)
    if data_set_name is last_data_set_name:
        # skip reloading the data if it's already in memory
        pass
    else:
        blk = neurotic.load_dataset(metadata, lazy=lazy)
    last_data_set_name = data_set_name
    
    # construct a query for locating behaviors
    behavior_query = f'(Type in {epoch_types_to_keep}) & ({time_window[0]} <= Start) & (End <= {time_window[1]})'
    
    # construct queries for locating epochs associated with each behavior
    # - each query should match at most one epoch
    # - dictionary keys are used as prefixes for the names of new columns
    subepoch_queries = {}
    
    subepoch_queries['I2 protraction activity'] = f'(Type == "I2 protraction activity") & ' \
                                                  f'(@behavior_start-1 <= Start) & (End <= @behavior_end)' # must start no earlier than 1 second before behavior and end within it
    
    subepoch_queries['B8 activity']             = f'(Type == "B8 activity") & ' \
                                                  f'(@behavior_start <= Start) & (End <= @behavior_end)' # must be fully contained within behavior
    
    subepoch_queries['B3/6/9/10 activity']      = f'(Type == "B3/6/9/10 activity") & ' \
                                                  f'(@behavior_start <= Start) & (End <= @behavior_end)' # must be fully contained within behavior

    subepoch_queries['B38 activity']            = f'(Type == "B38 activity") & ' \
                                                  f'(@behavior_end-2 <= Start) & (Start <= @behavior_end+2)' # must start within 2 seconds of behavior end

    subepoch_queries['Force dip']               = f'(Type == "Force dip") & ' \
                                                  f'(@behavior_start <= Start) & (End <= @behavior_end)' # must be fully contained within behavior
    
    subepoch_queries['Force shoulder']          = f'(Type == "Force shoulder") & ' \
                                                  f'(@behavior_end-2 <= Start) & (Start <= @behavior_end+2)' # must start within 2 seconds of behavior end
    
    # construct a table in which each row is a behavior and subepoch data
    # is added as columns, e.g. df['Force start (s)']
    df = BehaviorsDataFrame(blk.segments[0].epochs, behavior_query, subepoch_queries)

    
    
    ###
    ### PERFORM CALCULATIONS
    ###

    # spike train columns must have type 'object', which
    # can be accomplished by initializing with None
    units = [
        'I2 spikes',
        'B8a/b',
        'B3',
        'B6/B9',
        'B38',
    ]
    for unit in units:
        df[unit+' spike train'] = None
        df[unit+' inter-spike intervals (s)'] = None
        df[unit+' all bursts (s)'] = None

        # while we're at it, initialize some other things that might otherwise never be given values
        df[unit+' first burst start (s)'] = np.nan
        df[unit+' first burst end (s)'] = np.nan
        df[unit+' first burst duration (s)'] = 0
        df[unit+' first burst spike count'] = 0
        df[unit+' first burst mean frequency (Hz)'] = np.nan
        df[unit+' last burst start (s)'] = np.nan
        df[unit+' last burst end (s)'] = np.nan
        df[unit+' last burst duration (s)'] = 0
        df[unit+' last burst spike count'] = 0
        df[unit+' last burst mean frequency (Hz)'] = np.nan


    # sanity check: plot all channels for entire time window
    fig, axes = plt.subplots(len(channels_to_keep), 1, sharex=True, figsize=(9.5, 10)) # dimensions for notebook
#     fig, axes = plt.subplots(len(channels_to_keep), 1, sharex=True, figsize=(11, 8.5)) # dimensions for printing
    channel_units = ['uV', 'uV', 'uV', 'uV', 'mN']
    for i, channel in enumerate(channels_to_keep):
        plt.sca(axes[i])
        sig = get_sig(blk, channel)
        sig = sig.time_slice(time_window[0]*pq.s, time_window[1]*pq.s)
        sig = sig.rescale(channel_units[i])
        plt.plot(sig.times, sig.magnitude, c='0.8', lw=1, zorder=-1)
        
        if i == 0:
            plt.title(f'({animal}, {food}, {bout_index}): {data_set_name}')
            
        plt.ylabel(sig.name + ' (' + sig.units.dimensionality.string + ')')
        axes[i].yaxis.set_label_coords(-0.06, 0.5)
        
        if i < len(channels_to_keep)-1:
            # remove right, top, and bottom plot borders, and remove x-axis
            sns.despine(ax=plt.gca(), bottom=True)
            plt.gca().xaxis.set_visible(False)
        else:
            # remove right and top plot borders, and set x-label
            sns.despine(ax=plt.gca())
            plt.xlabel('Time (s)')
                

    
    # sanity check: plot smoothed force for entire time window
    channel = 'Force'
    sig = get_sig(blk, channel)
    if lazy:
        sig = sig.time_slice(None, None)
    sig = elephant.signal_processing.butter(sig, lowpass_freq = 5*pq.Hz)
    sig = sig.time_slice(time_window[0]*pq.s, time_window[1]*pq.s)
    sig = sig.rescale('mN')
    plt.plot(sig.times, sig.magnitude, c='k', lw=1, zorder=0)
    force_smoothed_sig = sig
    
    
    
    # iterate over all swallows
    for j, i in enumerate(df.index):
        
        behavior_start = df.loc[i, 'Start (s)']*pq.s
        behavior_end = df.loc[i, 'End (s)']*pq.s
        
        ###
        ### FORCE
        ###
        
        # quantify force in each behavior
        
        force_dip_start = df.loc[i, 'Force dip start (s)']*pq.s
        force_rise_start = df.loc[i, 'Force rise start (s)'] = df.loc[i, 'Force dip end (s)']*pq.s
        force_shoulder_start = df.loc[i, 'Force shoulder start (s)']*pq.s
        force_shoulder_end = df.loc[i, 'Force shoulder end (s)']*pq.s
        
        # get dip min using smoothed force
        sig = force_smoothed_sig
        sig = sig.time_slice(force_dip_start, force_rise_start + 0.5*pq.s)
        force_min_time = df.loc[i, 'Force min time (s)'] = elephant.spike_train_generation.peak_detection(sig, 1000*pq.mN, sign='below')[0]
        force_min = df.loc[i, 'Force min (mN)'] = sig[sig.time_index(force_min_time)][0]

        # get smoothed force for whole behavior for remaining force calculations
        sig = force_smoothed_sig
        if np.isfinite(force_shoulder_end):
            sig = sig.time_slice(force_dip_start - 0.01*pq.s, force_shoulder_end + 0.01*pq.s)
        else:
            sig = sig.time_slice(force_dip_start - 0.01*pq.s, behavior_end + 0.01*pq.s)
        sig = sig.rescale('mN')

        # find force peak, force baseline, and the force at 50% and 80% of peak height relative to baseline
        force_peak_time = elephant.spike_train_generation.peak_detection(sig, 0*pq.mN)[0]
        force_peak = df.loc[i, 'Force peak (mN)'] = sig[sig.time_index(force_peak_time)][0]
        force_baseline = df.loc[i, 'Force baseline (mN)'] = sig[sig.time_index(force_rise_start)][0]
        force_increase = df.loc[i, 'Force increase (mN)'] = force_peak-force_baseline
        force_50percent = df.loc[i, 'Force 50%-height (mN)'] = force_baseline + 0.5*force_increase
        force_80percent = df.loc[i, 'Force 80%-height (mN)'] = force_baseline + 0.8*force_increase

        # find time when force first rises above the 50%-height threshold
        crossings = elephant.spike_train_generation.threshold_detection(sig, force_50percent, 'above')
        force_50percent_start = df.loc[i, 'Force 50%-height start (s)'] = crossings[np.where(crossings > force_rise_start)[0][0]].rescale('s')
        
        # find time when force first rises above the 80%-height threshold
        crossings = elephant.spike_train_generation.threshold_detection(sig, force_80percent, 'above')
        force_80percent_start = df.loc[i, 'Force 80%-height start (s)'] = crossings[np.where(crossings > force_rise_start)[0][0]].rescale('s')
        
        # find time when force drops below the 80%-height threshold after peaking
        crossings = elephant.spike_train_generation.threshold_detection(sig, force_80percent, 'below')
        force_80percent_end = df.loc[i, 'Force 80%-height end (s)'] = crossings[np.where(crossings > force_peak_time)[0][0]].rescale('s')
        
        # find time when force drops below the 50%-height threshold after peaking
        crossings = elephant.spike_train_generation.threshold_detection(sig, force_50percent, 'below')
        force_50percent_end = df.loc[i, 'Force 50%-height end (s)'] = crossings[np.where(crossings > force_peak_time)[0][0]].rescale('s')
        
        # find force rise and plateau durations
        force_rise_plateau_duration = df.loc[i, 'Force rise and plateau duration (s)'] = force_80percent_end - force_rise_start
        force_80percent_duration = df.loc[i, 'Force 80%-height duration (s)'] = force_80percent_end - force_80percent_start
        
        # find average slope during rising phase
        force_slope = df.loc[i, 'Force slope (mN/s)'] = ((force_80percent-force_baseline)/(force_80percent_start-force_rise_start)).rescale('mN/s')

        
        # sanity check: plot 50%-height and 80%-height force thresholds
        plt.sca(axes[channels_to_keep.index('Force')])
        plt.plot([force_50percent_start, force_50percent_end], [force_50percent, force_50percent], c='gray', lw=1, ls=':')
        plt.plot([force_80percent_start, force_80percent_end], [force_80percent, force_80percent], c='gray', lw=1, ls=':')
        
        # sanity check: plot force dip
        sig2 = sig.time_slice(force_dip_start, force_rise_start)
        plt.plot(sig2.times, sig2.magnitude, c=force_colors['dip'], lw=2, zorder=1)
        
        # sanity check: plot force rise
        sig2 = sig.time_slice(force_rise_start, force_80percent_start)
        plt.plot(sig2.times, sig2.magnitude, c=force_colors['rise'], lw=2, zorder=1)
        
        # sanity check: plot force plateau
        sig2 = sig.time_slice(force_80percent_start, force_80percent_end)
        plt.plot(sig2.times, sig2.magnitude, c=force_colors['plateau'], lw=2, zorder=1)
        
        # sanity check: plot force shoulder
        if np.isfinite(force_shoulder_start):
            sig2 = sig.time_slice(force_shoulder_start, force_shoulder_end)
            plt.plot(sig2.times, sig2.magnitude, c=force_colors['shoulder'], lw=2, zorder=1)

        # sanity check: plot force min, baseline, peak, and threshold crossings
        plt.plot([force_min_time],        [force_min],       marker=6, markersize=5, color='k')
        plt.plot([force_peak_time],       [force_peak],      marker=7, markersize=5, color='k')
#         plt.plot([force_rise_start],      [force_baseline],  marker=6, markersize=5, color='k')
#         plt.plot([force_50percent_start], [force_50percent], marker=5, markersize=5, color='k')
#         plt.plot([force_50percent_end],   [force_50percent], marker=4, markersize=5, color='k')
#         plt.plot([force_80percent_start], [force_80percent], marker=5, markersize=5, color='k')
#         plt.plot([force_80percent_end],   [force_80percent], marker=4, markersize=5, color='k')

        
        
        ###
        ### FIND SPIKE TRAINS
        ###
        
        if lazy:
            if metadata['amplitude_discriminators'] is not None:
                for discriminator in metadata['amplitude_discriminators']:
                    sig = get_sig(blk, discriminator['channel'])
                    if sig is not None:
                        sig = sig.time_slice(behavior_start - 5*pq.s, behavior_end + 5*pq.s)
                        st = _detect_spikes(sig, discriminator, blk.segments[0].epochs)
                        st_epoch_start = df.loc[i, discriminator['epoch']+' start (s)']*pq.s
                        st_epoch_end = df.loc[i, discriminator['epoch']+' end (s)']*pq.s
                        st = st.time_slice(st_epoch_start, st_epoch_end)
                        df.at[i, st.name+' spike train'] = st # 'at', not 'loc', is important for inserting list into cell
        else:
            for spiketrain in blk.segments[0].spiketrains:
                discriminator = next((d for d in metadata['amplitude_discriminators'] if d['name'] == spiketrain.name), None)
                if discriminator is None:
                    raise Exception(f'For data set "{data_set_name}", discriminator "{spiketrain.name}" could not be found')
                st_epoch_start = df.loc[i, discriminator['epoch']+' start (s)']*pq.s
                st_epoch_end = df.loc[i, discriminator['epoch']+' end (s)']*pq.s
                st = spiketrain.time_slice(st_epoch_start, st_epoch_end)
                df.at[i, st.name+' spike train'] = st # 'at', not 'loc', is important for inserting list into cell
    
    
            
        ###
        ### QUANTIFY SPIKE TRAINS AND BURSTS
        ###
        
        for k, unit in enumerate(units):
            st = df.loc[i, unit+' spike train']
            if st is not None:
                df.loc[i, unit+' spike count'] = st.size
                if st.size > 0:
                    
                    # get the neural channel
                    channel = st.annotations['channels'][0]
                    sig = get_sig(blk, channel)

                    # get the signal for the behavior with 10 seconds cushion before and after (for better baseline estimation)
                    sig = sig.time_slice(behavior_start - 10*pq.s, behavior_end + 10*pq.s)
                    sig = sig.rescale(channel_units[channels_to_keep.index(channel)])
                    
                    # find every sequence of spikes that qualifies as a burst
                    bursts = df.at[i, unit+' all bursts (s)'] = find_bursts(st, burst_thresholds[unit]) # 'at', not 'loc', is important for inserting list into cell
                    
                    first_burst_start = np.nan
                    first_burst_end = np.nan
                    first_burst_spike_count = 0
                    first_burst_mean_freq = 0*pq.Hz
                    last_burst_start = np.nan
                    last_burst_end = np.nan
                    last_burst_spike_count = 0
                    last_burst_mean_freq = 0*pq.Hz
                    if len(bursts) > 0:

                        for burst in bursts:
                            if is_good_burst(burst):
                                first_burst_start, first_burst_end = burst['Start (s)'], burst['End (s)']
                                first_burst_duration = first_burst_end-first_burst_start
                                df.loc[i, unit+' first burst start (s)'] = first_burst_start.rescale('s')
                                df.loc[i, unit+' first burst end (s)'] = first_burst_end.rescale('s')
                                first_burst_duration = df.loc[i, unit+' first burst duration (s)'] = first_burst_duration.rescale('s')
                                first_burst_spike_count = df.loc[i, unit+' first burst spike count'] = st.time_slice(first_burst_start, first_burst_end).size
                                first_burst_mean_freq = df.loc[i, unit+' first burst mean frequency (Hz)'] = ((first_burst_spike_count-1)/first_burst_duration).rescale('Hz')
                                
                                # find burst RAUC and mean voltage
                                first_burst_rauc = df.loc[i, unit+' first burst RAUC (μV·s)'] = elephant.signal_processing.rauc(sig, baseline='mean', t_start=first_burst_start, t_stop=first_burst_end).rescale('uV*s')
                                first_burst_mean_rect_voltage = df.loc[i, unit+' first burst mean rectified voltage (μV)'] = first_burst_rauc/first_burst_duration

                                break # quit after finding first good burst

                        for burst in reversed(bursts):
                            if is_good_burst(burst):
                                last_burst_start, last_burst_end = burst['Start (s)'], burst['End (s)']
                                last_burst_duration = last_burst_end-last_burst_start
                                df.loc[i, unit+' last burst start (s)'] = last_burst_start.rescale('s')
                                df.loc[i, unit+' last burst end (s)'] = last_burst_end.rescale('s')
                                last_burst_duration = df.loc[i, unit+' last burst duration (s)'] = last_burst_duration.rescale('s')
                                last_burst_spike_count = df.loc[i, unit+' last burst spike count'] = st.time_slice(last_burst_start, last_burst_end).size
                                last_burst_mean_freq = df.loc[i, unit+' last burst mean frequency (Hz)'] = ((last_burst_spike_count-1)/last_burst_duration).rescale('Hz')

                                # find burst RAUC and mean voltage
                                last_burst_rauc = df.loc[i, unit+' last burst RAUC (μV·s)'] = elephant.signal_processing.rauc(sig, baseline='mean', t_start=last_burst_start, t_stop=last_burst_end).rescale('uV*s')
                                last_burst_mean_rect_voltage = df.loc[i, unit+' last burst mean rectified voltage (μV)'] = last_burst_rauc/last_burst_duration
    
                                break # quit after finding first (actually, last) good burst

                                    
                    # sanity check: plot spikes
                    plt.sca(axes[channels_to_keep.index(channel)])
                    marker = ['.', 'x'][j%2] # alternate markers between behaviors
                    
                    spike_amplitudes = np.array([sig[sig.time_index(t)] for t in st]) * pq.Quantity(sig.units)
                    plt.scatter(st.times.rescale('s'), spike_amplitudes, marker=marker, c=unit_colors[unit])

                    # sanity check: plot burst windows
                    discriminator = next((d for d in metadata['amplitude_discriminators'] if d['name'] == st.name), None)
                    if discriminator is None:
                        raise Exception(f'For data set "{data_set_name}", discriminator "{st.name}" could not be found')
                    bottom = pq.Quantity(discriminator['amplitude'][0], discriminator['units']).rescale(sig.units)
                    top = pq.Quantity(discriminator['amplitude'][1], discriminator['units']).rescale(sig.units)
                    height = top-bottom
                    for burst in bursts:
                        left = burst['Start (s)']
                        right = burst['End (s)']
                        width = right-left
                        linestyle = '-' if is_good_burst(burst) else '--'
                        rect = patches.Rectangle((left, bottom), width, height, ls=linestyle, edgecolor=unit_colors[unit], fill=False)
                        plt.gca().add_patch(rect)

                    # sanity check: plot markers for edges of bursts
                    if top > 0:
                        plt.plot([first_burst_start], [top], marker=7, markersize=5, color='k')#unit_colors[unit])
                        plt.plot([last_burst_end], [top], marker=7, markersize=5, color='k')#unit_colors[unit])
                    else:
                        plt.plot([first_burst_start], [bottom], marker=6, markersize=5, color='k')#unit_colors[unit])
                        plt.plot([last_burst_end], [bottom], marker=6, markersize=5, color='k')#unit_colors[unit])



        ###
        ### INSTANTANEOUS FIRING FREQUENCIES
        ###

        # sanity check: overlay the instantaneous firing frequency (IFF) of each
        # listed unit as a step trace on that unit's channel subplot
        for unit in ['B6/B9']:
            st = df.loc[i, unit+' spike train']
            if st is not None:
                if st.size > 0:
                    channel = st.annotations['channels'][0]
                    plt.sca(axes[channels_to_keep.index(channel)])

                    # pad with the behavior bounds (times) and zeros (frequencies) so
                    # that, drawn with steps-post, the trace sits at zero frequency
                    # before the first spike and after the last spike
                    times = st.times.rescale('s')
                    times = np.concatenate([[behavior_start], times, [behavior_end]])*pq.s
                    iff = 1/elephant.statistics.isi(st)
                    iff = np.concatenate([[0], iff.rescale('1/s'), [0, 0]])/pq.s

                    # use this unit's own amplitude discriminator for the offset; the
                    # `discriminator` variable left over from the spike-quantification
                    # loop belongs to whichever unit was processed last (e.g. B38),
                    # which made the offset depend on other units' activity
                    discriminator = next((d for d in metadata['amplitude_discriminators'] if d['name'] == st.name), None)
                    if discriminator is None:
                        raise Exception(f'For data set "{data_set_name}", discriminator "{st.name}" could not be found')

                    # arbitrary rescaling to fit in plot: compress the frequency axis
                    # and shift the trace below the discriminator window
                    shift = -6 * np.abs(discriminator['amplitude']).max()
                    iff = iff.magnitude/5+shift

                    plt.plot(times, iff, drawstyle='steps-post', c=lighten_color(unit_colors[unit], amount=0.7), zorder=0)
        
        
        
        ###
        ### TIMING DELAYS
        ###

        # first-burst start and last-burst end of each unit for this behavior, as
        # quantities in seconds (presumably NaN when the unit had no good burst — TODO confirm)
        i2_burst_start   = df.loc[i, 'I2 spikes first burst start (s)']*pq.s
        i2_burst_end     = df.loc[i, 'I2 spikes last burst end (s)']*pq.s
        b8_burst_start   = df.loc[i, 'B8a/b first burst start (s)']*pq.s
        b8_burst_end     = df.loc[i, 'B8a/b last burst end (s)']*pq.s
        b6b9_burst_start = df.loc[i, 'B6/B9 first burst start (s)']*pq.s
        b6b9_burst_end   = df.loc[i, 'B6/B9 last burst end (s)']*pq.s
        b3_burst_start   = df.loc[i, 'B3 first burst start (s)']*pq.s
        b3_burst_end     = df.loc[i, 'B3 last burst end (s)']*pq.s
        b38_burst_start  = df.loc[i, 'B38 first burst start (s)']*pq.s
        b38_burst_end    = df.loc[i, 'B38 last burst end (s)']*pq.s

        # consider B3/B6/B9 bursting if either B3 or B6/B9 is bursting
        # (finite_min/finite_max are presumably NaN-ignoring min/max helpers — TODO confirm)
        b3b6b9_burst_start    = df.loc[i, 'B3/B6/B9 burst start (s)']    = finite_min(b6b9_burst_start, b3_burst_start)
        b3b6b9_burst_end      = df.loc[i, 'B3/B6/B9 burst end (s)']      = finite_max(b6b9_burst_end,   b3_burst_end)
        b3b6b9_burst_duration = df.loc[i, 'B3/B6/B9 burst duration (s)'] = b3b6b9_burst_end - b3b6b9_burst_start

        # consider bursting only if B8a/b and B3/B6/B9 are both bursting
        # NOTE(review): if finite_min ignores NaN, this is the earlier end of whichever
        #   bursts exist rather than a strict conjunction; the delay column below is
        #   labelled "either B8a/b or B3/B6/B9" — confirm which is intended
        b8_or_b3b6b9_burst_end = df.loc[i, 'B8a/b and B3/B6/B9 conjunction end (s)'] = \
                                           finite_min(b8_burst_end, b3b6b9_burst_end)

        # delays from neural to force
        # - each delay is (force event time) - (neural event time), so a positive
        #   value means the force feature follows the neural one
        # - every value is both stored in df and kept in a local variable
        i2_force_min_delay             = df.loc[i, 'Delay from I2 end to force min (s)'] = \
                                                   force_min_time - i2_burst_end

        b8_force_start_delay           = df.loc[i, 'Delay from B8a/b start to force start (s)'] = \
                                                   force_rise_start - b8_burst_start
        b8_force_end_delay             = df.loc[i, 'Delay from B8a/b end to force end (s)'] = \
                                                   force_80percent_end - b8_burst_end

        b6b9_force_start_delay         = df.loc[i, 'Delay from B6/B9 start to force start (s)'] = \
                                                   force_rise_start - b6b9_burst_start
        b6b9_force_end_delay           = df.loc[i, 'Delay from B6/B9 end to force end (s)'] = \
                                                   force_80percent_end - b6b9_burst_end

        b3b6b9_force_start_delay       = df.loc[i, 'Delay from B3/B6/B9 start to force start (s)'] = \
                                                   force_rise_start - b3b6b9_burst_start
        b3b6b9_force_end_delay         = df.loc[i, 'Delay from B3/B6/B9 end to force end (s)'] = \
                                                   force_80percent_end - b3b6b9_burst_end
        b3b6b9_force_50percent_delay   = df.loc[i, 'Delay from B3/B6/B9 start to force 50%-height (s)'] = \
                                                   force_50percent_start - b3b6b9_burst_start
        b3b6b9_force_80percent_delay   = df.loc[i, 'Delay from B3/B6/B9 start to force 80%-height (s)'] = \
                                                   force_80percent_start - b3b6b9_burst_start

        b3_force_80percent_start_delay = df.loc[i, 'Delay from B3 start to force 80%-height start (s)'] = \
                                                   force_80percent_start - b3_burst_start
        b3_force_80percent_end_delay   = df.loc[i, 'Delay from B3 end to force 80%-height end (s)'] = \
                                                   force_80percent_end - b3_burst_end
        b8_or_b3b6b9_force_80percent_end_delay = \
                                         df.loc[i, 'Delay from either B8a/b or B3/B6/B9 end to force 80%-height end (s)'] = \
                                                   force_80percent_end - b8_or_b3b6b9_burst_end

        # B38 is compared with the force "shoulder" rather than the main rise/plateau
        b38_shoulder_start_delay       = df.loc[i, 'Delay from B38 start to shoulder start (s)'] = \
                                                   force_shoulder_start - b38_burst_start
        b38_shoulder_end_delay         = df.loc[i, 'Delay from B38 end to shoulder end (s)'] = \
                                                   force_shoulder_end - b38_burst_end

    

        ###
        ### MISC
        ###

        # quantify the portion of the B8a/b burst that precedes B3/B6/B9 recruitment
        st = df.loc[i, 'B8a/b spike train']
        b8_preb3b6b9_burst_duration    = df.loc[i, 'B8a/b pre-B3/B6/B9 burst duration (s)'] = \
                                                   b3b6b9_burst_start - b8_burst_start

        # every measure below needs both burst start times. When either is missing
        # (NaN), report it and skip these measures for this swallow. Otherwise the
        # force lookup at sig.time_index(NaN) further down cannot produce a valid
        # index and the whole run aborts.
        if np.isfinite(b8_burst_start) and np.isfinite(b3b6b9_burst_start):
            b8_preb3b6b9_burst_spike_count = df.loc[i, 'B8a/b pre-B3/B6/B9 burst spike count'] = \
                                                       st.time_slice(b8_burst_start, b3b6b9_burst_start).size
            # mean frequency counts inter-spike intervals, hence spike count minus one
            b8_preb3b6b9_burst_mean_freq   = df.loc[i, 'B8a/b pre-B3/B6/B9 burst mean frequency (Hz)'] = \
                                                       ((b8_preb3b6b9_burst_spike_count-1)/b8_preb3b6b9_burst_duration).rescale('Hz')

            # get the neural channel
            channel = st.annotations['channels'][0]
            sig = get_sig(blk, channel)

            # get the signal for the behavior with 5 seconds cushion before and after (for better baseline estimation)
            sig = sig.time_slice(behavior_start - 5*pq.s, behavior_end + 5*pq.s)
            sig = sig.rescale('uV')

            # find RAUC and mean rectified voltage for B8a/b before B3/B6/B9 start
            b8_preb3b6b9_rauc = df.loc[i, 'B8a/b pre-B3/B6/B9 burst RAUC (μV·s)'] = elephant.signal_processing.rauc(sig, baseline='mean', t_start=b8_burst_start, t_stop=b3b6b9_burst_start).rescale('uV*s')
            b8_preb3b6b9_mean_rect_voltage = df.loc[i, 'B8a/b pre-B3/B6/B9 burst mean rectified voltage (μV)'] = b8_preb3b6b9_rauc/b8_preb3b6b9_burst_duration

            # get force during rise and plateau
            sig = get_sig(blk, 'Force')
            sig = sig.time_slice(force_rise_start, force_80percent_end)
            sig = sig.rescale('mN')

            # force at the end of the B8-only period: the force rise start is
            # offset by the duration of the B8-only burst
            force_b8_only_rise_end = df.loc[i, 'Force delayed B8-only rise end (s)'] = force_rise_start + b8_preb3b6b9_burst_duration
            force_b8_only_rise_height = df.loc[i, 'Force at delayed B8-only rise end (mN)'] = sig[sig.time_index(force_b8_only_rise_end)][0]

            # find average slope during initial rising phase (before B3/B6/B9 begin, offset by delay)
            force_initial_increase = df.loc[i, 'Force initial increase (mN)'] = (force_b8_only_rise_height-force_baseline).rescale('mN')
            force_initial_slope = df.loc[i, 'Force initial slope (mN/s)'] = (force_initial_increase/b8_preb3b6b9_burst_duration).rescale('mN/s')
        else:
            print(f'Missing either B8a/b burst and/or B3/B6/B9 burst in data set "{data_set_name}" for swallow spanning times ({behavior_start}, {behavior_end})')

        
        
        # sanity check: plot force initial rise (disabled)
#         sig = force_smoothed_sig
#         sig = sig.time_slice(force_rise_start, force_b8_only_rise_end)
#         sig = sig.rescale('mN')
#         plt.sca(axes[channels_to_keep.index('Force')])
#         plt.plot(sig.times, sig.magnitude, c=force_colors['initial rise'], lw=2, zorder=1)



        # sanity check: plot important times across all subplots, with delays set by I2 end and force min
        # - muscle_delay is computed only for the first swallow of the bout (j == 0)
        #   and reused for every later swallow; its value is annotated once, in
        #   ms, above the last subplot at the first force-min time
        # NOTE(review): if the first swallow has no I2 burst, i2_force_min_delay is
        #   presumably NaN, and so would be the delay used for the whole bout
        # - plot_vertical_lines_with_delay is defined elsewhere; presumably it draws
        #   a line at the force event and at the event minus muscle_delay on the
        #   neural subplots — TODO confirm
        if j == 0:
            muscle_delay = i2_force_min_delay
            axes[-1].text(
                force_min_time, 1.05, f"{muscle_delay.rescale('ms'):.0f} ms delay",
                horizontalalignment='left', verticalalignment='center', transform=axes[-1].get_xaxis_transform(),
                fontsize=8)
        plot_vertical_lines_with_delay(axes, force_min_time, muscle_delay, force_min, force_colors['dip'])
        plot_vertical_lines_with_delay(axes, force_80percent_start, muscle_delay, force_80percent, force_colors['plateau'])
        plot_vertical_lines_with_delay(axes, force_80percent_end, muscle_delay, force_80percent, force_colors['plateau'])
        # the shoulder marker is drawn only when a shoulder end time exists; its
        # height is the smoothed force at that time
        if np.isfinite(force_shoulder_end):
            sig = force_smoothed_sig
            plot_vertical_lines_with_delay(axes, force_shoulder_end, muscle_delay, sig[sig.time_index(force_shoulder_end)][0], force_colors['shoulder'])


        
    # optimize plot margins
    plt.subplots_adjust(
        left   = 0.1,
        right  = 0.99,
        top    = 0.96,
        bottom = 0.06,
        hspace = 0.15,
    )

    # export figure, one PNG per (animal, food, bout_index)
    # NOTE(review): savefig does not create directories; assumes sanity-checks/ already exists
    plt.gcf().savefig(f'sanity-checks/{animal} {food} {bout_index}.png', dpi=300)

    # renumber behaviors assuming all behaviors are from a single contiguous sequence
    # (sort by start time, replace the index with 0..n-1, and name it 'Behavior_index')
    df = df.sort_values('Start (s)').reset_index(drop=True).rename_axis('Behavior_index')

    # index the table on 4 variables
    # (reset_index turns 'Behavior_index' back into a column so it can join the MultiIndex)
    df['Animal'] = animal
    df['Food'] = food
    df['Bout_index'] = bout_index
    df = df.reset_index().set_index(['Animal', 'Food', 'Bout_index', 'Behavior_index'])

    # collect this bout's table; all bouts are concatenated after the loop
    df_list += [df]
    
# stack the per-bout tables into a single table ordered by its 4-level index
df_all = pd.concat(df_list, sort=False)
df_all = df_all.sort_index()

# report the wall-clock time taken by this cell
end = datetime.datetime.now()
print('elapsed time:', end - start)

### Plots with "keeper" markers only

In [None]:
# start a wall-clock timer for this cell
start = datetime.datetime.now()

# use Neo RawIO lazy loading to load much faster and using less memory
# - with lazy=True, filtering parameters specified in metadata are ignored
# - with lazy=True, loading via time_slice requires neo>=0.8.0.dev
# - IMPORTANT: force and I2 filters affect smoothness and possibly threshold crossings and spike detection
# - with lazy=False (as here), the spike trains already stored in the loaded
#   block are used instead of re-detecting spikes per behavior
lazy = False

# load the metadata containing file paths
metadata = neurotic.MetadataSelector(file='../../data/metadata.yml')

# filter epochs for each bout and perform calculations
# - df_list collects one table per bout
# - last_data_set_name lets consecutive bouts from the same data set reuse the loaded block
df_list = []
last_data_set_name = None
for (animal, food, bout_index), (data_set_name, channels_to_keep, time_window, epoch_types_to_keep, burst_thresholds) in feeding_bouts.items():
    
    ###
    ### LOCATE BEHAVIOR EPOCHS AND SUBEPOCHS
    ###

    metadata.select(data_set_name)
    # reload only when this bout comes from a different data set than the previous
    # one; compare by value (!=), not identity (is), because two equal strings are
    # not guaranteed to be the same object
    if data_set_name != last_data_set_name:
        blk = neurotic.load_dataset(metadata, lazy=lazy)
    last_data_set_name = data_set_name
    
    # behaviors of interest: epochs of the kept types lying entirely inside this bout's time window
    behavior_query = f'(Type in {epoch_types_to_keep}) & ({time_window[0]} <= Start) & (End <= {time_window[1]})'

    # one query per subepoch type, used to find the epoch that belongs to each behavior
    # - each query should match at most one epoch
    # - the dictionary keys become prefixes of the new column names
    # - @behavior_start / @behavior_end refer to the bounds of the behavior being
    #   processed (presumably bound inside BehaviorsDataFrame — TODO confirm)
    subepoch_queries = {
        # must start no earlier than 1 second before behavior and end within it
        'I2 protraction activity': '(Type == "I2 protraction activity") & (@behavior_start-1 <= Start) & (End <= @behavior_end)',
        # must be fully contained within behavior
        'B8 activity': '(Type == "B8 activity") & (@behavior_start <= Start) & (End <= @behavior_end)',
        # must be fully contained within behavior
        'B3/6/9/10 activity': '(Type == "B3/6/9/10 activity") & (@behavior_start <= Start) & (End <= @behavior_end)',
        # must start within 2 seconds of behavior end
        'B38 activity': '(Type == "B38 activity") & (@behavior_end-2 <= Start) & (Start <= @behavior_end+2)',
        # must be fully contained within behavior
        'Force dip': '(Type == "Force dip") & (@behavior_start <= Start) & (End <= @behavior_end)',
        # must start within 2 seconds of behavior end
        'Force shoulder': '(Type == "Force shoulder") & (@behavior_end-2 <= Start) & (Start <= @behavior_end+2)',
    }

    # table with one row per behavior; subepoch data are added as columns,
    # e.g. df['Force dip start (s)']
    df = BehaviorsDataFrame(blk.segments[0].epochs, behavior_query, subepoch_queries)

    
    
    ###
    ### PERFORM CALCULATIONS
    ###

    # the units whose spike trains are analyzed
    units = [
        'I2 spikes',
        'B8a/b',
        'B3',
        'B6/B9',
        'B38',
    ]

    # per-unit column suffixes with their starting values, in column order:
    # - spike train, ISI, and burst-list columns hold Python objects, so they must
    #   have dtype 'object', which initializing with None achieves
    # - burst statistics get a starting value because some of them may never be
    #   assigned for a given behavior
    unit_column_defaults = (
        (' spike train', None),
        (' inter-spike intervals (s)', None),
        (' all bursts (s)', None),
        (' first burst start (s)', np.nan),
        (' first burst end (s)', np.nan),
        (' first burst duration (s)', 0),
        (' first burst spike count', 0),
        (' first burst mean frequency (Hz)', np.nan),
        (' last burst start (s)', np.nan),
        (' last burst end (s)', np.nan),
        (' last burst duration (s)', 0),
        (' last burst spike count', 0),
        (' last burst mean frequency (Hz)', np.nan),
    )
    for unit in units:
        for column_suffix, initial_value in unit_column_defaults:
            df[unit + column_suffix] = initial_value


    # sanity check: plot all channels for entire time window
    fig, axes = plt.subplots(len(channels_to_keep), 1, sharex=True, figsize=(9.5, 10)) # dimensions for notebook
#     fig, axes = plt.subplots(len(channels_to_keep), 1, sharex=True, figsize=(11, 8.5)) # dimensions for printing
    channel_units = ['uV', 'uV', 'uV', 'uV', 'mN']
    for i, channel in enumerate(channels_to_keep):
        plt.sca(axes[i])
        sig = get_sig(blk, channel)
        sig = sig.time_slice(time_window[0]*pq.s, time_window[1]*pq.s)
        sig = sig.rescale(channel_units[i])
        plt.plot(sig.times, sig.magnitude, c='0.8', lw=1, zorder=-1)
        
        if i == 0:
            plt.title(f'({animal}, {food}, {bout_index}): {data_set_name}')
            
        plt.ylabel(sig.name + ' (' + sig.units.dimensionality.string + ')')
        axes[i].yaxis.set_label_coords(-0.06, 0.5)
        
        if i < len(channels_to_keep)-1:
            # remove right, top, and bottom plot borders, and remove x-axis
            sns.despine(ax=plt.gca(), bottom=True)
            plt.gca().xaxis.set_visible(False)
        else:
            # remove right and top plot borders, and set x-label
            sns.despine(ax=plt.gca())
            plt.xlabel('Time (s)')
                

    
    # sanity check: plot smoothed force for entire time window
    # - the signal returned by get_sig is low-pass filtered (Butterworth, 5 Hz
    #   cutoff) before it is cropped to the bout's time window
    # - the result is kept as force_smoothed_sig and reused by the force
    #   measurements below
    channel = 'Force'
    sig = get_sig(blk, channel)
    if lazy:
        # presumably time_slice(None, None) loads the whole lazily-loaded signal into memory — TODO confirm
        sig = sig.time_slice(None, None)
    sig = elephant.signal_processing.butter(sig, lowpass_freq = 5*pq.Hz)
    sig = sig.time_slice(time_window[0]*pq.s, time_window[1]*pq.s)
    sig = sig.rescale('mN')
    # drawn in black on the current axes (the last subplot after the loop above;
    # assumes that is the Force subplot), above the gray raw trace
    plt.plot(sig.times, sig.magnitude, c='k', lw=1, zorder=0)
    force_smoothed_sig = sig
    
    
    
    # iterate over all swallows
    for j, i in enumerate(df.index):
        
        # bounds of this behavior, as quantities in seconds
        behavior_start = df.loc[i, 'Start (s)']*pq.s
        behavior_end = df.loc[i, 'End (s)']*pq.s

        ###
        ### FORCE
        ###

        # quantify force in each behavior

        # the force rise is taken to start where the force dip subepoch ends;
        # shoulder times are presumably NaN when no shoulder epoch matched
        # (the isfinite checks below anticipate this)
        force_dip_start = df.loc[i, 'Force dip start (s)']*pq.s
        force_rise_start = df.loc[i, 'Force rise start (s)'] = df.loc[i, 'Force dip end (s)']*pq.s
        force_shoulder_start = df.loc[i, 'Force shoulder start (s)']*pq.s
        force_shoulder_end = df.loc[i, 'Force shoulder end (s)']*pq.s

        # get dip min using smoothed force
        # - search window: dip start to 0.5 s after the rise start
        # - the 1000 mN 'below' threshold presumably exceeds any force in the window,
        #   so the first detected minimum is the dip minimum — TODO confirm
        sig = force_smoothed_sig
        sig = sig.time_slice(force_dip_start, force_rise_start + 0.5*pq.s)
        force_min_time = df.loc[i, 'Force min time (s)'] = elephant.spike_train_generation.peak_detection(sig, 1000*pq.mN, sign='below')[0]
        force_min = df.loc[i, 'Force min (mN)'] = sig[sig.time_index(force_min_time)][0]

        # get smoothed force for whole behavior for remaining force calculations
        # (from just before the dip to just after the shoulder, or to just after the
        # behavior end when there is no shoulder)
        sig = force_smoothed_sig
        if np.isfinite(force_shoulder_end):
            sig = sig.time_slice(force_dip_start - 0.01*pq.s, force_shoulder_end + 0.01*pq.s)
        else:
            sig = sig.time_slice(force_dip_start - 0.01*pq.s, behavior_end + 0.01*pq.s)
        sig = sig.rescale('mN')

        # find force peak, force baseline, and the force at 50% and 80% of peak height relative to baseline
        # - the peak is the first peak detected above 0 mN
        #   (NOTE(review): not necessarily the global max if the force crosses 0 mN more than once)
        # - the baseline is the force at the rise start
        force_peak_time = elephant.spike_train_generation.peak_detection(sig, 0*pq.mN)[0]
        force_peak = df.loc[i, 'Force peak (mN)'] = sig[sig.time_index(force_peak_time)][0]
        force_baseline = df.loc[i, 'Force baseline (mN)'] = sig[sig.time_index(force_rise_start)][0]
        force_increase = df.loc[i, 'Force increase (mN)'] = force_peak-force_baseline
        force_50percent = df.loc[i, 'Force 50%-height (mN)'] = force_baseline + 0.5*force_increase
        force_80percent = df.loc[i, 'Force 80%-height (mN)'] = force_baseline + 0.8*force_increase

        # the threshold-crossing lookups below take the first crossing after the
        # reference time; `[0][0]` raises IndexError if there is no such crossing

        # find time when force first rises above the 50%-height threshold
        crossings = elephant.spike_train_generation.threshold_detection(sig, force_50percent, 'above')
        force_50percent_start = df.loc[i, 'Force 50%-height start (s)'] = crossings[np.where(crossings > force_rise_start)[0][0]].rescale('s')

        # find time when force first rises above the 80%-height threshold
        crossings = elephant.spike_train_generation.threshold_detection(sig, force_80percent, 'above')
        force_80percent_start = df.loc[i, 'Force 80%-height start (s)'] = crossings[np.where(crossings > force_rise_start)[0][0]].rescale('s')

        # find time when force drops below the 80%-height threshold after peaking
        crossings = elephant.spike_train_generation.threshold_detection(sig, force_80percent, 'below')
        force_80percent_end = df.loc[i, 'Force 80%-height end (s)'] = crossings[np.where(crossings > force_peak_time)[0][0]].rescale('s')

        # find time when force drops below the 50%-height threshold after peaking
        crossings = elephant.spike_train_generation.threshold_detection(sig, force_50percent, 'below')
        force_50percent_end = df.loc[i, 'Force 50%-height end (s)'] = crossings[np.where(crossings > force_peak_time)[0][0]].rescale('s')

        # find force rise and plateau durations
        force_rise_plateau_duration = df.loc[i, 'Force rise and plateau duration (s)'] = force_80percent_end - force_rise_start
        force_80percent_duration = df.loc[i, 'Force 80%-height duration (s)'] = force_80percent_end - force_80percent_start

        # find average slope during rising phase (baseline to 80% height, over rise start to 80%-height start)
        force_slope = df.loc[i, 'Force slope (mN/s)'] = ((force_80percent-force_baseline)/(force_80percent_start-force_rise_start)).rescale('mN/s')

        
        # sanity check: plot 50%-height and 80%-height force thresholds
        # (only selecting the Force subplot is active; the individual force
        # sanity-check plots below are currently commented out)
        plt.sca(axes[channels_to_keep.index('Force')])
#         plt.plot([force_50percent_start, force_50percent_end], [force_50percent, force_50percent], c='gray', lw=1, ls=':')
#         plt.plot([force_80percent_start, force_80percent_end], [force_80percent, force_80percent], c='gray', lw=1, ls=':')

        # sanity check: plot force dip
#         sig2 = sig.time_slice(force_dip_start, force_rise_start)
#         plt.plot(sig2.times, sig2.magnitude, c=force_colors['dip'], lw=2, zorder=1)

        # sanity check: plot force rise
#         sig2 = sig.time_slice(force_rise_start, force_80percent_start)
#         plt.plot(sig2.times, sig2.magnitude, c=force_colors['rise'], lw=2, zorder=1)

        # sanity check: plot force plateau
#         sig2 = sig.time_slice(force_80percent_start, force_80percent_end)
#         plt.plot(sig2.times, sig2.magnitude, c=force_colors['plateau'], lw=2, zorder=1)

        # sanity check: plot force shoulder
#         if np.isfinite(force_shoulder_start):
#             sig2 = sig.time_slice(force_shoulder_start, force_shoulder_end)
#             plt.plot(sig2.times, sig2.magnitude, c=force_colors['shoulder'], lw=2, zorder=1)

        # sanity check: plot force min, baseline, peak, and threshold crossings
#         plt.plot([force_min_time],        [force_min],       marker=6, markersize=5, color='k')
#         plt.plot([force_peak_time],       [force_peak],      marker=7, markersize=5, color='k')
#         plt.plot([force_rise_start],      [force_baseline],  marker=6, markersize=5, color='k')
#         plt.plot([force_50percent_start], [force_50percent], marker=5, markersize=5, color='k')
#         plt.plot([force_50percent_end],   [force_50percent], marker=4, markersize=5, color='k')
#         plt.plot([force_80percent_start], [force_80percent], marker=5, markersize=5, color='k')
#         plt.plot([force_80percent_end],   [force_80percent], marker=4, markersize=5, color='k')

        
        
        ###
        ### FIND SPIKE TRAINS
        ###

        # store one spike train per amplitude discriminator in df for this behavior,
        # restricted to the subepoch named by the discriminator's 'epoch' entry
        # (presumably one of the subepoch_queries keys, whose start/end columns
        # BehaviorsDataFrame added — TODO confirm)
        if lazy:
            # lazy loading: detect spikes here, from the signal within 5 s of the behavior
            if metadata['amplitude_discriminators'] is not None:
                for discriminator in metadata['amplitude_discriminators']:
                    sig = get_sig(blk, discriminator['channel'])
                    if sig is not None:
                        sig = sig.time_slice(behavior_start - 5*pq.s, behavior_end + 5*pq.s)
                        st = _detect_spikes(sig, discriminator, blk.segments[0].epochs)
                        st_epoch_start = df.loc[i, discriminator['epoch']+' start (s)']*pq.s
                        st_epoch_end = df.loc[i, discriminator['epoch']+' end (s)']*pq.s
                        st = st.time_slice(st_epoch_start, st_epoch_end)
                        df.at[i, st.name+' spike train'] = st # 'at', not 'loc', is important for inserting list into cell
        else:
            # eager loading: reuse the spike trains already in the loaded block, matching
            # each one to its discriminator by name to find the subepoch that bounds it
            # NOTE(review): unlike the lazy branch, this does not guard against
            #   metadata['amplitude_discriminators'] being None
            for spiketrain in blk.segments[0].spiketrains:
                discriminator = next((d for d in metadata['amplitude_discriminators'] if d['name'] == spiketrain.name), None)
                if discriminator is None:
                    raise Exception(f'For data set "{data_set_name}", discriminator "{spiketrain.name}" could not be found')
                st_epoch_start = df.loc[i, discriminator['epoch']+' start (s)']*pq.s
                st_epoch_end = df.loc[i, discriminator['epoch']+' end (s)']*pq.s
                st = spiketrain.time_slice(st_epoch_start, st_epoch_end)
                df.at[i, st.name+' spike train'] = st # 'at', not 'loc', is important for inserting list into cell
    
    
            
        ###
        ### QUANTIFY SPIKE TRAINS AND BURSTS
        ###

        # for each unit: spike count, all bursts, statistics of the first and last
        # good bursts, and sanity-check plots of spikes and burst windows
        for k, unit in enumerate(units):
            st = df.loc[i, unit+' spike train']
            if st is not None:
                df.loc[i, unit+' spike count'] = st.size
                if st.size > 0:

                    # get the neural channel
                    channel = st.annotations['channels'][0]
                    sig = get_sig(blk, channel)

                    # get the signal for the behavior with 10 seconds cushion before and after (for better baseline estimation)
                    sig = sig.time_slice(behavior_start - 10*pq.s, behavior_end + 10*pq.s)
                    sig = sig.rescale(channel_units[channels_to_keep.index(channel)])

                    # find every sequence of spikes that qualifies as a burst
                    bursts = df.at[i, unit+' all bursts (s)'] = find_bursts(st, burst_thresholds[unit]) # 'at', not 'loc', is important for inserting list into cell

                    # local defaults for when no good burst is found; first_burst_start
                    # and last_burst_end are used by the marker plotting below
                    first_burst_start = np.nan
                    first_burst_end = np.nan
                    first_burst_spike_count = 0
                    first_burst_mean_freq = 0*pq.Hz
                    last_burst_start = np.nan
                    last_burst_end = np.nan
                    last_burst_spike_count = 0
                    last_burst_mean_freq = 0*pq.Hz
                    if len(bursts) > 0:

                        # first good burst (scanning forward in time)
                        for burst in bursts:
                            if is_good_burst(burst):
                                first_burst_start, first_burst_end = burst['Start (s)'], burst['End (s)']
                                first_burst_duration = first_burst_end-first_burst_start
                                df.loc[i, unit+' first burst start (s)'] = first_burst_start.rescale('s')
                                df.loc[i, unit+' first burst end (s)'] = first_burst_end.rescale('s')
                                first_burst_duration = df.loc[i, unit+' first burst duration (s)'] = first_burst_duration.rescale('s')
                                first_burst_spike_count = df.loc[i, unit+' first burst spike count'] = st.time_slice(first_burst_start, first_burst_end).size
                                # mean frequency counts inter-spike intervals, hence spike count minus one
                                first_burst_mean_freq = df.loc[i, unit+' first burst mean frequency (Hz)'] = ((first_burst_spike_count-1)/first_burst_duration).rescale('Hz')

                                # find burst RAUC and mean voltage
                                first_burst_rauc = df.loc[i, unit+' first burst RAUC (μV·s)'] = elephant.signal_processing.rauc(sig, baseline='mean', t_start=first_burst_start, t_stop=first_burst_end).rescale('uV*s')
                                first_burst_mean_rect_voltage = df.loc[i, unit+' first burst mean rectified voltage (μV)'] = first_burst_rauc/first_burst_duration

                                break # quit after finding first good burst

                        # last good burst (scanning backward in time)
                        for burst in reversed(bursts):
                            if is_good_burst(burst):
                                last_burst_start, last_burst_end = burst['Start (s)'], burst['End (s)']
                                last_burst_duration = last_burst_end-last_burst_start
                                df.loc[i, unit+' last burst start (s)'] = last_burst_start.rescale('s')
                                df.loc[i, unit+' last burst end (s)'] = last_burst_end.rescale('s')
                                last_burst_duration = df.loc[i, unit+' last burst duration (s)'] = last_burst_duration.rescale('s')
                                last_burst_spike_count = df.loc[i, unit+' last burst spike count'] = st.time_slice(last_burst_start, last_burst_end).size
                                last_burst_mean_freq = df.loc[i, unit+' last burst mean frequency (Hz)'] = ((last_burst_spike_count-1)/last_burst_duration).rescale('Hz')

                                # find burst RAUC and mean voltage
                                last_burst_rauc = df.loc[i, unit+' last burst RAUC (μV·s)'] = elephant.signal_processing.rauc(sig, baseline='mean', t_start=last_burst_start, t_stop=last_burst_end).rescale('uV*s')
                                last_burst_mean_rect_voltage = df.loc[i, unit+' last burst mean rectified voltage (μV)'] = last_burst_rauc/last_burst_duration

                                break # quit after finding first (actually, last) good burst


                    # sanity check: plot spikes (signal value at each spike time)
                    plt.sca(axes[channels_to_keep.index(channel)])
                    marker = ['.', 'x'][j%2] # alternate markers between behaviors

                    spike_amplitudes = np.array([sig[sig.time_index(t)] for t in st]) * pq.Quantity(sig.units)
                    plt.scatter(st.times.rescale('s'), spike_amplitudes, marker=marker, c=unit_colors[unit])

                    # sanity check: plot burst windows as rectangles spanning the
                    # discriminator's amplitude window (solid = good burst, dashed = not)
                    discriminator = next((d for d in metadata['amplitude_discriminators'] if d['name'] == st.name), None)
                    if discriminator is None:
                        raise Exception(f'For data set "{data_set_name}", discriminator "{st.name}" could not be found')
                    bottom = pq.Quantity(discriminator['amplitude'][0], discriminator['units']).rescale(sig.units)
                    top = pq.Quantity(discriminator['amplitude'][1], discriminator['units']).rescale(sig.units)
                    height = top-bottom
                    for burst in bursts:
                        left = burst['Start (s)']
                        right = burst['End (s)']
                        width = right-left
                        linestyle = '-' if is_good_burst(burst) else '--'
                        rect = patches.Rectangle((left, bottom), width, height, ls=linestyle, edgecolor=unit_colors[unit], fill=False)
                        plt.gca().add_patch(rect)

                    # sanity check: plot markers for edges of bursts (first good burst start,
                    # last good burst end) on the top edge of the window when it is
                    # positive, otherwise on its bottom edge
                    if top > 0:
                        plt.plot([first_burst_start], [top], marker=7, markersize=5, color='k')#unit_colors[unit])
                        plt.plot([last_burst_end], [top], marker=7, markersize=5, color='k')#unit_colors[unit])
                    else:
                        plt.plot([first_burst_start], [bottom], marker=6, markersize=5, color='k')#unit_colors[unit])
                        plt.plot([last_burst_end], [bottom], marker=6, markersize=5, color='k')#unit_colors[unit])



        ###
        ### INSTANTANEOUS FIRING FREQUENCIES
        ###

        # sanity check: overlay the instantaneous firing frequency (IFF) of each
        # listed unit as a step trace on that unit's channel subplot
        for unit in ['B6/B9']:
            st = df.loc[i, unit+' spike train']
            if st is not None:
                if st.size > 0:
                    channel = st.annotations['channels'][0]
                    plt.sca(axes[channels_to_keep.index(channel)])

                    # pad with the behavior bounds (times) and zeros (frequencies) so
                    # that, drawn with steps-post, the trace sits at zero frequency
                    # before the first spike and after the last spike
                    times = st.times.rescale('s')
                    times = np.concatenate([[behavior_start], times, [behavior_end]])*pq.s
                    iff = 1/elephant.statistics.isi(st)
                    iff = np.concatenate([[0], iff.rescale('1/s'), [0, 0]])/pq.s

                    # use this unit's own amplitude discriminator for the offset; the
                    # `discriminator` variable left over from the spike-quantification
                    # loop belongs to whichever unit was processed last (e.g. B38),
                    # which made the offset depend on other units' activity
                    discriminator = next((d for d in metadata['amplitude_discriminators'] if d['name'] == st.name), None)
                    if discriminator is None:
                        raise Exception(f'For data set "{data_set_name}", discriminator "{st.name}" could not be found')

                    # arbitrary rescaling to fit in plot: compress the frequency axis
                    # and shift the trace below the discriminator window
                    shift = -6 * np.abs(discriminator['amplitude']).max()
                    iff = iff.magnitude/5+shift

                    plt.plot(times, iff, drawstyle='steps-post', c=lighten_color(unit_colors[unit], amount=0.7), zorder=0)
        
        
        
        ###
        ### TIMING DELAYS
        ###

        # first-burst start and last-burst end of each unit for this behavior, as
        # quantities in seconds; NaN when the unit had no good burst (the columns
        # are pre-initialized to NaN and only filled when a good burst is found)
        i2_burst_start   = df.loc[i, 'I2 spikes first burst start (s)']*pq.s
        i2_burst_end     = df.loc[i, 'I2 spikes last burst end (s)']*pq.s
        b8_burst_start   = df.loc[i, 'B8a/b first burst start (s)']*pq.s
        b8_burst_end     = df.loc[i, 'B8a/b last burst end (s)']*pq.s
        b6b9_burst_start = df.loc[i, 'B6/B9 first burst start (s)']*pq.s
        b6b9_burst_end   = df.loc[i, 'B6/B9 last burst end (s)']*pq.s
        b3_burst_start   = df.loc[i, 'B3 first burst start (s)']*pq.s
        b3_burst_end     = df.loc[i, 'B3 last burst end (s)']*pq.s
        b38_burst_start  = df.loc[i, 'B38 first burst start (s)']*pq.s
        b38_burst_end    = df.loc[i, 'B38 last burst end (s)']*pq.s

        # consider B3/B6/B9 bursting if either B3 or B6/B9 is bursting
        # (finite_min/finite_max are presumably NaN-ignoring min/max helpers — TODO confirm)
        b3b6b9_burst_start    = df.loc[i, 'B3/B6/B9 burst start (s)']    = finite_min(b6b9_burst_start, b3_burst_start)
        b3b6b9_burst_end      = df.loc[i, 'B3/B6/B9 burst end (s)']      = finite_max(b6b9_burst_end,   b3_burst_end)
        b3b6b9_burst_duration = df.loc[i, 'B3/B6/B9 burst duration (s)'] = b3b6b9_burst_end - b3b6b9_burst_start

        # consider bursting only if B8a/b and B3/B6/B9 are both bursting
        # NOTE(review): if finite_min ignores NaN, this is the earlier end of whichever
        #   bursts exist rather than a strict conjunction — confirm intent
        b8_or_b3b6b9_burst_end = df.loc[i, 'B8a/b and B3/B6/B9 conjunction end (s)'] = \
                                           finite_min(b8_burst_end, b3b6b9_burst_end)

        # delays from neural to force
        # - each delay is (force event time) - (neural event time), so a positive
        #   value means the force feature follows the neural one
        i2_force_min_delay             = df.loc[i, 'Delay from I2 end to force min (s)'] = \
                                                   force_min_time - i2_burst_end

        b8_force_start_delay           = df.loc[i, 'Delay from B8a/b start to force start (s)'] = \
                                                   force_rise_start - b8_burst_start
        b8_force_end_delay             = df.loc[i, 'Delay from B8a/b end to force end (s)'] = \
                                                   force_80percent_end - b8_burst_end
        
        b6b9_force_start_delay         = df.loc[i, 'Delay from B6/B9 start to force start (s)'] = \
                                                   force_rise_start - b6b9_burst_start
        b6b9_force_end_delay           = df.loc[i, 'Delay from B6/B9 end to force end (s)'] = \
                                                   force_80percent_end - b6b9_burst_end

        b3b6b9_force_start_delay       = df.loc[i, 'Delay from B3/B6/B9 start to force start (s)'] = \
                                                   force_rise_start - b3b6b9_burst_start
        b3b6b9_force_end_delay         = df.loc[i, 'Delay from B3/B6/B9 end to force end (s)'] = \
                                                   force_80percent_end - b3b6b9_burst_end
        b3b6b9_force_50percent_delay   = df.loc[i, 'Delay from B3/B6/B9 start to force 50%-height (s)'] = \
                                                   force_50percent_start - b3b6b9_burst_start
        b3b6b9_force_80percent_delay   = df.loc[i, 'Delay from B3/B6/B9 start to force 80%-height (s)'] = \
                                                   force_80percent_start - b3b6b9_burst_start
        
        b3_force_80percent_start_delay = df.loc[i, 'Delay from B3 start to force 80%-height start (s)'] = \
                                                   force_80percent_start - b3_burst_start
        b3_force_80percent_end_delay   = df.loc[i, 'Delay from B3 end to force 80%-height end (s)'] = \
                                                   force_80percent_end - b3_burst_end
        b8_or_b3b6b9_force_80percent_end_delay = \
                                         df.loc[i, 'Delay from either B8a/b or B3/B6/B9 end to force 80%-height end (s)'] = \
                                                   force_80percent_end - b8_or_b3b6b9_burst_end
        
        b38_shoulder_start_delay       = df.loc[i, 'Delay from B38 start to shoulder start (s)'] = \
                                                   force_shoulder_start - b38_burst_start
        b38_shoulder_end_delay         = df.loc[i, 'Delay from B38 end to shoulder end (s)'] = \
                                                   force_shoulder_end - b38_burst_end

    

        ###
        ### MISC
        ###
        
        st = df.loc[i, 'B8a/b spike train']
        b8_preb3b6b9_burst_duration    = df.loc[i, 'B8a/b pre-B3/B6/B9 burst duration (s)'] = \
                                                   b3b6b9_burst_start - b8_burst_start
        b8_preb3b6b9_burst_spike_count = df.loc[i, 'B8a/b pre-B3/B6/B9 burst spike count'] = \
                                                   st.time_slice(b8_burst_start, b3b6b9_burst_start).size
        b8_preb3b6b9_burst_mean_freq   = df.loc[i, 'B8a/b pre-B3/B6/B9 burst mean frequency (Hz)'] = \
                                                   ((b8_preb3b6b9_burst_spike_count-1)/b8_preb3b6b9_burst_duration).rescale('Hz')
        
        # get the neural channel
        channel = st.annotations['channels'][0]
        sig = get_sig(blk, channel)
        
        # get the signal for the behavior with 5 seconds cushion before and after (for better baseline estimation)
        sig = sig.time_slice(behavior_start - 5*pq.s, behavior_end + 5*pq.s)
        sig = sig.rescale('uV')

        # find RAUC and mean voltage for B8a/b before B3/B6/B9 start in each behavior
        if np.isfinite(b8_burst_start) and np.isfinite(b3b6b9_burst_start):
            b8_preb3b6b9_rauc = df.loc[i, 'B8a/b pre-B3/B6/B9 burst RAUC (μV·s)'] = elephant.signal_processing.rauc(sig, baseline='mean', t_start=b8_burst_start, t_stop=b3b6b9_burst_start).rescale('uV*s')
            b8_preb3b6b9_mean_rect_voltage = df.loc[i, 'B8a/b pre-B3/B6/B9 burst mean rectified voltage (μV)'] = b8_preb3b6b9_rauc/b8_preb3b6b9_burst_duration
        else:
            print(f'Missing either B8a/b burst and/or B3/B6/B9 burst in data set "{data_set_name}" for swallow spanning times ({behavior_start}, {behavior_end})')

            
        # get force during rise and plateau
        sig = get_sig(blk, 'Force')
        sig = sig.time_slice(force_rise_start, force_80percent_end)
        sig = sig.rescale('mN')
        
        # get force at end of B8-only burst, offset by delay
        force_b8_only_rise_end = df.loc[i, 'Force delayed B8-only rise end (s)'] = force_rise_start + b8_preb3b6b9_burst_duration
        force_b8_only_rise_height = df.loc[i, 'Force at delayed B8-only rise end (mN)'] = sig[sig.time_index(force_b8_only_rise_end)][0]

        # find average slope during initial rising phase (before B3/B6/B9 begin, offset by delay)
        force_initial_increase = df.loc[i, 'Force initial increase (mN)'] = (force_b8_only_rise_height-force_baseline).rescale('mN')
        force_initial_slope = df.loc[i, 'Force initial slope (mN/s)'] = (force_initial_increase/b8_preb3b6b9_burst_duration).rescale('mN/s')

        
        
        # sanity check: plot force initial rise
#         sig = force_smoothed_sig
#         sig = sig.time_slice(force_rise_start, force_b8_only_rise_end)
#         sig = sig.rescale('mN')
#         plt.sca(axes[channels_to_keep.index('Force')])
#         plt.plot(sig.times, sig.magnitude, c=force_colors['initial rise'], lw=2, zorder=1)
        
        
        
#         # sanity check: plot important times across all subplots, with delays set by I2 end and force min
#         if j == 0:
#             muscle_delay = i2_force_min_delay
#             axes[-1].text(
#                 force_min_time, 1.05, f"{muscle_delay.rescale('ms'):.0f} ms delay",
#                 horizontalalignment='left', verticalalignment='center', transform=axes[-1].get_xaxis_transform(),
#                 fontsize=8)
#         plot_vertical_lines_with_delay(axes, force_min_time, muscle_delay, force_min, force_colors['dip'])
#         plot_vertical_lines_with_delay(axes, force_80percent_start, muscle_delay, force_80percent, force_colors['plateau'])
#         plot_vertical_lines_with_delay(axes, force_80percent_end, muscle_delay, force_80percent, force_colors['plateau'])
#         if np.isfinite(force_shoulder_end):
#             sig = force_smoothed_sig
#             plot_vertical_lines_with_delay(axes, force_shoulder_end, muscle_delay, sig[sig.time_index(force_shoulder_end)][0], force_colors['shoulder'])
    
    
    
    # sanity check: plot vertical lines at "keepers"
    sig = force_smoothed_sig
    ep_poi = next((ep for ep in blk.segments[0].epochs if ep.name == 'Keeper'), None)
    if ep_poi is not None:
        ep_poi = ep_poi.time_slice(time_window[0]*pq.s, time_window[1]*pq.s)
        if ep_poi.size > 0:
            first_force_min_time = ep_poi.times[0]
            first_i2_burst_end = df.loc[df.index[0], 'I2 spikes first burst end (s)']*pq.s
            muscle_delay = first_force_min_time - first_i2_burst_end

            axes[-1].text(
                first_force_min_time, 1.05, f"{muscle_delay.rescale('ms'):.0f} ms delay",
                horizontalalignment='right', verticalalignment='center', transform=axes[-1].get_xaxis_transform(),
                fontsize=8)
            for t in ep_poi.times:
                plot_vertical_lines_with_delay(axes, t, muscle_delay, sig[sig.time_index(t)][0], 'gray')


        
    # optimize plot margins
    plt.subplots_adjust(
        left   = 0.1,
        right  = 0.99,
        top    = 0.96,
        bottom = 0.06,
        hspace = 0.15,
    )
    
    # export figure
    plt.gcf().savefig(f'sanity-checks/{animal} {food} {bout_index}.png', dpi=300)

    # renumber behaviors assuming all behaviors are from a single contiguous sequence
    df = df.sort_values('Start (s)').reset_index(drop=True).rename_axis('Behavior_index')
    
    # index the table on 4 variables
    df['Animal'] = animal
    df['Food'] = food
    df['Bout_index'] = bout_index
    df = df.reset_index().set_index(['Animal', 'Food', 'Bout_index', 'Behavior_index'])
    
    df_list += [df]
    
# df_all = pd.concat(df_list, sort=False).sort_index()

end = datetime.datetime.now()
print('elapsed time:', end-start)

### Special dataframe used for Fig 2C only

In [None]:
feeding_bouts = {
    # (animal, food, bout_index): (data_set_name, time_window, epoch_types_to_keep)
        
    ('JG12', 'Regular nori', 0): ('IN VIVO / JG12 / 2019-05-10 / 002', [ 147,  165], ['Swallow (regular 5-cm nori strip)']),
    ('JG12', 'Regular nori', 1): ('IN VIVO / JG12 / 2019-05-10 / 002', [ 229,  245], ['Swallow (regular 5-cm nori strip)']),
    ('JG12', 'Regular nori', 2): ('IN VIVO / JG12 / 2019-05-10 / 002', [ 277,  291], ['Swallow (regular 5-cm nori strip)']),

    ('JG12', 'Tape nori',   99): ('IN VIVO / JG12 / 2019-05-10 / 002', [2890, 3080], ['Swallow (tape nori)']),
}

# Build one table of behavior durations and inter-behavior intervals per bout,
# then concatenate them into df_durations_intervals (used only for Fig 2C).
df_list = []
last_data_set_name = None
for (animal, food, bout_index), (data_set_name, time_window, epoch_types_to_keep) in feeding_bouts.items():
    
    metadata.select(data_set_name)

    # Reload the data only when the data set changes. Compare with != (value
    # equality): `is` tests object identity, which equal strings need not share.
    if data_set_name != last_data_set_name:
        blk = neurotic.load_dataset(metadata, lazy=True)
    last_data_set_name = data_set_name
    
    # construct a query for locating behaviors
    behavior_query = f'(Type in {epoch_types_to_keep}) & ({time_window[0]} <= Start) & (End <= {time_window[1]})'
    
    # construct queries for locating epochs associated with each behavior
    # - each query should match at most one epoch
    # - dictionary keys are used as prefixes for the names of new columns
    subepoch_queries = {}
    
    # construct a table in which each row is a behavior and subepoch data
    # is added as columns, e.g. df['Force start (s)']
    df = BehaviorsDataFrame(blk.segments[0].epochs, behavior_query, subepoch_queries)

    # Sort chronologically BEFORE computing intervals so that "the next
    # behavior" really is the next one in time, then renumber behaviors
    # assuming all behaviors are from a single contiguous sequence.
    df = df.sort_values('Start (s)').reset_index(drop=True).rename_axis('Behavior_index')

    # interbehavior interval = next behavior's start - this behavior's end;
    # the last behavior of the bout has no successor and gets NaN
    df['Interval after (s)'] = df['Start (s)'].shift(-1) - df['End (s)']
    
    # index the table on 4 variables
    df['Animal'] = animal
    df['Food'] = food
    df['Bout_index'] = bout_index
    df = df.reset_index().set_index(['Animal', 'Food', 'Bout_index', 'Behavior_index'])
    
    df_list += [df]
    
df_durations_intervals = pd.concat(df_list, sort=False).sort_index()

## Plotting Functions

In [None]:
def scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim, trend=False, trend_separately=False, tooltips=False, padding=0.05):
    
    for j, (label, query) in enumerate(data_subsets.items()):
        if query is not None:
            df = df_all.query(query)
            ax.scatter(df[xlabel], df[ylabel],
                       label=label, marker=markers[j], c=colors[j], clip_on=False)
            
    all_points = df_all.query(query_union(data_subsets.values()))[[xlabel, ylabel]].dropna()
    
    if len(all_points) > 0:

        xrange = np.ptp(all_points.iloc[:, 0])
        xmin = min(all_points.iloc[:, 0]) - xrange * padding
        xmax = max(all_points.iloc[:, 0]) + xrange * padding
        if xlim is None:
            xlim = [xmin, xmax]
        if xlim[0] is None:
            xlim[0] = xmin
        if xlim[1] is None:
            xlim[1] = xmax

        yrange = np.ptp(all_points.iloc[:, 1])
        ymin = min(all_points.iloc[:, 1]) - yrange * padding
        ymax = max(all_points.iloc[:, 1]) + yrange * padding
        if ylim is None:
            ylim = [ymin, ymax]
        if ylim[0] is None:
            ylim[0] = ymin
        if ylim[1] is None:
            ylim[1] = ymax
    
    if trend_separately:
        for j, (label, query) in enumerate(data_subsets.items()):
            if query is not None:
                df = df_all.query(query)[[xlabel, ylabel]].dropna()
                model = sm.OLS(df.iloc[:,1], sm.add_constant(df.iloc[:,0])).fit()
                model_stats = 'R$^2$ = {:.2f}, p = {:.3f}, n = {}'.format(model.rsquared, model.pvalues[1], len(df))
                print(label+':', model_stats)
                model_x = np.linspace(xlim[0], xlim[1], 100)
                model_y = model.params[0] + model.params[1] * model_x
                ax.plot(model_x, model_y, color=colors[j])#, label=model_stats)
                
    if trend:
        model = sm.OLS(all_points.iloc[:,1], sm.add_constant(all_points.iloc[:,0])).fit()
        model_stats = 'R$^2$ = {:.2f}, p = {:.3f}, n = {}'.format(model.rsquared, model.pvalues[1], len(all_points))
        print('All points:', model_stats)
        model_x = np.linspace(xlim[0], xlim[1], 100)
        model_y = model.params[0] + model.params[1] * model_x
        ax.plot(model_x, model_y, color='gray')#, label=model_stats)
    
    if tooltips:
        # create a transparent layer containing all points for detecting mouse events
        sc = ax.scatter(all_points.iloc[:, 0], all_points.iloc[:, 1], label=None,
                        alpha=0, # transparent points
                        s=0.001, # small radius of tooltip activation
                       )
        
        # initialize a hidden empty tooltip
        annot = ax.annotate(
            '', xy=(0, 0),                              # initialize empty at origin
            xytext=(0, 15), textcoords='offset points', # position text above point
            color='k', size=10, ha='center',            # small black centered text
            bbox=dict(fc='w', lw=0, alpha=0.6),         # transparent white background
            arrowprops=dict(shrink=0, headlength=7, headwidth=7, width=0, lw=0, color='k'), # small black arrow
        )
        annot.set_visible(False)
        
        # prepare tooltip contents
        tooltip_text = [' '.join([str(x) for x in index]) for index in list(all_points.index)]
        
        # bind the animation of tooltips to a hover event
        def hover(event):
            if event.inaxes == ax:
                cont, ind = sc.contains(event)
                if cont:
                    point_index = ind['ind'][0]
                    pos = sc.get_offsets()[point_index]
                    annot.xy = pos
                    annot.set_text(tooltip_text[point_index])
                    annot.set_visible(True)
                    ax.figure.canvas.draw_idle()
                else:
                    if annot.get_visible():
                        annot.set_visible(False)
                        ax.figure.canvas.draw_idle()
        ax.figure.canvas.mpl_connect('motion_notify_event', hover)
        
    ax.set_xlabel(xlabel)
    ax.set_xlim(xlim)
    ax.set_ylabel(ylabel)#.replace('rectified ',''))
    ax.set_ylim(ylim)

In [None]:
def prettyplot(
    blk,
    t_start,
    t_stop,
    plots,
    outfile_basename=None, # base name of output files
    export_only=False,     # if True, will not render in notebook
    formats=('pdf', 'svg', 'png'), # extensions of output files
    dpi=300,               # resolution (applicable only for PNG)
    figsize=(14, 7),       # figure size in inches
    linewidth=1,           # thickness of lines in points
    majorticks=5,          # spacing of labeled x-axis ticks in seconds
    minorticks=1,          # spacing of unlabeled x-axis ticks in seconds
    ylabel_offset=-0.06,   # horizontal positioning of y-axis labels
    layout_settings=None,  # positioning of plot edges and the space between plots
):
    """Plot stacked, time-aligned signal traces from ``blk`` and optionally save them.

    One subplot is drawn per entry of ``plots``, top to bottom, all sharing
    the x-axis of the first. Only the bottom subplot keeps a labeled x-axis.

    Parameters
    ----------
    blk : neo-style Block
        Signals are looked up by name in ``blk.segments[0].analogsignals``.
    t_start, t_stop : time quantities
        Window to plot (passed to ``time_slice`` and used as the x-limits).
    plots : list of dict
        One dict per subplot. Required keys: ``'channel'`` (signal name),
        ``'units'`` (rescale target), ``'ylim'``. Optional keys:
        ``'decimation_factor'`` (default 1), ``'color'`` (default ``'k'``),
        ``'ylabel'`` (default ``"<signal name> (<units>)"``).
    outfile_basename : str or None
        If given, the figure is saved as ``<basename>.<ext>`` for every
        extension in ``formats``.
    export_only : bool
        If True, pyplot interactive mode is switched off while drawing so the
        figure is not rendered in the notebook. The previous interactive state
        is restored afterwards, even if an error occurs.
    formats : iterable of str
        File extensions to save when ``outfile_basename`` is given.
    layout_settings : dict or None
        Keyword arguments for ``subplots_adjust``; if None, ``tight_layout``
        is used instead.

    Raises
    ------
    ValueError
        If a requested channel name is not found in ``blk``.
    """
    
    was_interactive = plt.isinteractive()
    if export_only:
        plt.ioff()

    try:
        fig = plt.figure(figsize=figsize)

        num_subplots = len(plots)
        for i, p in enumerate(plots):

            # switch to the appropriate subplot; later subplots share the first one's x-axis
            if i == 0:
                ax = plt.subplot(num_subplots, 1, i+1)
            else:
                plt.subplot(num_subplots, 1, i+1, sharex=ax)

            # select and rescale a channel for the subplot
            sig = next((sig for sig in blk.segments[0].analogsignals if sig.name == p['channel']), None)
            if sig is None:
                raise ValueError(f"Signal with name {p['channel']} not found")
            sig = sig.time_slice(t_start, t_stop)
            sig = sig.rescale(p['units'])

            # downsample the data
            sig_downsampled = DownsampleNeoSignal(sig, p.get('decimation_factor', 1))

            # specify the x- and y-data for the subplot
            plt.plot(
                sig_downsampled.times,
                sig_downsampled.as_quantity(),
                linewidth=linewidth,
                color=p.get('color', 'k'),
            )

            # specify the y-axis label
            plt.ylabel(p.get('ylabel', sig.name+' ('+sig.units.dimensionality.string+')'))

            # position the y-axis label so that all subplot y-axis labels are aligned
            plt.gca().yaxis.set_label_coords(ylabel_offset, 0.5)

            # specify the plot range
            plt.xlim([t_start, t_stop])
            plt.ylim(p['ylim'])

            if i == num_subplots-1:
                # turn on minor (frequent and unlabeled) ticks for the bottom x-axis
                plt.gca().xaxis.set_minor_locator(MultipleLocator(minorticks))

                # turn on major (infrequent and labeled) ticks for the bottom x-axis
                plt.gca().xaxis.set_major_locator(MultipleLocator(majorticks))

                # specify the bottom x-axis label
                plt.xlabel('Time ('+sig.times.units.dimensionality.string+')')

                # offset axes from plot
                sns.despine(ax=plt.gca(), offset=10)
            else:
                # offset axes and remove x-axis
                sns.despine(ax=plt.gca(), offset=10, trim=True, bottom=True)
                plt.gca().xaxis.set_visible(False)

        # adjust the white space around and between the subplots
        if layout_settings is None:
            fig.tight_layout()
        else:
            fig.subplots_adjust(**layout_settings)

        if outfile_basename is not None:
            # specify file metadata (applicable only for PDF); str() guards
            # against a missing (None) file_origin
            pdf_metadata = dict(
                Subject = 'Data file: '  + str(blk.file_origin) + '\n' +
                          'Start time: ' + str(t_start)         + '\n' +
                          'End time: '   + str(t_stop),
            )

            # write the figure to files
            for ext in formats:
                fig.savefig(outfile_basename+'.'+ext, metadata=pdf_metadata, dpi=dpi)

    finally:
        # restore the caller's interactive mode even if drawing or saving failed
        if export_only and was_interactive:
            plt.ion()

---

# Figures

## [FIGURE 1] Biomechanics schematic ❌

Schematic illustration of biomechanics of _Aplysia_ swallowing
- Synthesis of Cullins et al. 2015a, Fig. 6, and McManus et al. 2014, Fig. 10.
- Show grasper protraction/retraction, grasper closing/opening, anterior jaws closing

---

## [FIGURE 2] Motor pattern and force examples

### Fig 2A ✅

Short sequence of swallows on regular nori, 4 channels

- TODO: Annotations? (e.g., "Bite/swallow", "Swallow", "Swallow", "Bite")

The plot output by this code is much longer than it will be in the final figure and must be cropped manually. It is rendered with the same amount of time showing as Fig 2B so that time scales are identical.

Labels are intentionally overlapping since they will be manually replaced in Inkscape for a different style after rendering.

In [None]:
# Fig 2A: render a short multichannel exemplar (I2, RN, BN2, BN3, Force) to a
# PNG. The plotted span (164 s) equals Fig 2B's so both panels share a time scale.
start = datetime.datetime.now()

outfile_basename = 'figure-2A-export'
data_set_name = 'IN VIVO / JG12 / 2019-05-10 / 002'
t_start, t_stop = [228.8, 392.8] * pq.s # t=278, twidth=164, first few are regular nori strip swallows
# one subplot spec per channel: signal name, display units, y-limits; only the
# Force trace is downsampled (DownsampleNeoSignal factor 100) before plotting
plots = [
    {'channel': 'I2',       'units': 'uV', 'ylim': [ -60,  60]}, #, 'decimation_factor': 400},
    {'channel': 'RN',       'units': 'uV', 'ylim': [ -25,  25]}, #, 'decimation_factor': 400},
    {'channel': 'BN2',      'units': 'uV', 'ylim': [ -45,  45]}, #, 'decimation_factor': 400},
    {'channel': 'BN3-DIST', 'units': 'uV', 'ylim': [ -60,  60], 'ylabel': 'BN3 (μV)'}, #, 'decimation_factor': 400},
    {'channel': 'Force',    'units': 'mN', 'ylim': [ -50, 450], 'decimation_factor': 100},
]

# styling/export options forwarded to prettyplot
kwargs = dict(
    formats = ['png'],
    dpi = 600,
    figsize = (14, 5),
    majorticks = 10,
    minorticks = 5,
    linewidth = 0.5,
    ylabel_offset = -0.01,
    layout_settings = dict(
        left   = 0.04,
        right  = 0.99,
        top    = 0.97,
        bottom = 0.08,
        hspace = -0.1,  # negative spacing lets adjacent subplots overlap slightly
    ),
)

# load the data
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)
blk = neurotic.load_dataset(metadata)

# plot the data
# with sns.plotting_context('poster', font_scale=0.5):
with sns.plotting_context('notebook', font_scale=1):
    prettyplot(blk, t_start, t_stop, plots, outfile_basename, **kwargs)

end = datetime.datetime.now()
print('render time:', end-start)

---

### Fig 2B ✅

Long sequence of swallows on tape nori exemplar

- TODO: Mark which swallows are used in later analysis?

Labels are intentionally overlapping since they will be manually replaced in Inkscape for a different style after rendering.

In [None]:
# Fig 2B: render a long multichannel exemplar of tape-nori swallows to a PNG,
# using the same channels, limits, and 164 s span as Fig 2A.
start = datetime.datetime.now()

outfile_basename = 'figure-2B-export'
data_set_name = 'IN VIVO / JG12 / 2019-05-10 / 002'
t_start, t_stop = [2879, 3043] * pq.s # t=2928.2, twidth=164, 19 tape nori swallows
# one subplot spec per channel: signal name, display units, y-limits; only the
# Force trace is downsampled (DownsampleNeoSignal factor 100) before plotting
plots = [
    {'channel': 'I2',       'units': 'uV', 'ylim': [ -60,  60]}, #, 'decimation_factor': 400},
    {'channel': 'RN',       'units': 'uV', 'ylim': [ -25,  25]}, #, 'decimation_factor': 400},
    {'channel': 'BN2',      'units': 'uV', 'ylim': [ -45,  45]}, #, 'decimation_factor': 400},
    {'channel': 'BN3-DIST', 'units': 'uV', 'ylim': [ -60,  60], 'ylabel': 'BN3 (μV)'}, #, 'decimation_factor': 400},
    {'channel': 'Force',    'units': 'mN', 'ylim': [ -50, 450], 'decimation_factor': 100},
]

# styling/export options forwarded to prettyplot (identical to Fig 2A)
kwargs = dict(
    formats = ['png'],
    dpi = 600,
    figsize = (14, 5),
    majorticks = 10,
    minorticks = 5,
    linewidth = 0.5,
    ylabel_offset = -0.01,
    layout_settings = dict(
        left   = 0.04,
        right  = 0.99,
        top    = 0.97,
        bottom = 0.08,
        hspace = -0.1,  # negative spacing lets adjacent subplots overlap slightly
    ),
)

# load the data
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)
blk = neurotic.load_dataset(metadata)

# plot the data
# with sns.plotting_context('poster', font_scale=0.5):
with sns.plotting_context('notebook', font_scale=1):
    prettyplot(blk, t_start, t_stop, plots, outfile_basename, **kwargs)

end = datetime.datetime.now()
print('render time:', end-start)

---

### Fig 2C ✅

Plot swallow duration and inter-swallow interval differences between tape nori and reg nori

In [None]:
# Fig 2C: box plots of swallow duration and inter-swallow interval, by food type
plt.figure(figsize=(5, 5))
ax = plt.gca()

df = df_durations_intervals.reset_index()
df = df.rename(columns={
    'Duration (s)':       'Swallow duration',
    'Interval after (s)': 'Inter-swallow interval'})
# long format: one row per (behavior, measure); the measure name is stored in a
# column named '' so that the x-axis label is blank
df = pd.melt(df,
             id_vars=['Animal', 'Food', 'Bout_index', 'Behavior_index'],
             value_vars=['Swallow duration', 'Inter-swallow interval'],
             var_name='',
             value_name='Duration/interval (s)')
# whis=(0, 100) places the whiskers at the 0th and 100th percentiles, i.e. the
# data min and max. A large IQR multiple (e.g. whis=999) does not guarantee
# this: if the IQR is 0 the whiskers collapse, and with fliersize=0 the points
# beyond them would be invisible.
sns.boxplot(hue='Food', y='Duration/interval (s)', x='', data=df, fliersize=0, whis=(0, 100))
# sns.swarmplot(hue='Food', y='Duration/interval (s)', x='', data=df)

plt.tight_layout()

---

### Fig 2D ✅

Exemplar force plot for one swallow with labeled features (rise, plateau, shoulder, etc.)

Labels are intentionally overlapping since they will be manually replaced in Inkscape for a different style after rendering.

In [None]:
# Fig 2D: render the force trace of a single exemplar swallow to a PNG, after
# applying an additional low-pass filter to the Force channel.
start = datetime.datetime.now()

outfile_basename = 'figure-2D-export'
data_set_name = 'IN VIVO / JG12 / 2019-05-10 / 002'
t_start, t_stop = [3001.3, 3009.75] * pq.s # t=3003.835, twidth=8.45
plots = [
    {'channel': 'Force',    'units': 'mN', 'ylim': [   0, 250], 'decimation_factor': 100},
]

# styling/export options forwarded to prettyplot
kwargs = dict(
    formats = ['png'],
    dpi = 600,
    figsize = (7, 5),
    majorticks = 1,
    minorticks = 1,
    linewidth = 2,
    ylabel_offset = -0.04,
    layout_settings = dict(
        left   = 0.1,
        right  = 0.99,
        top    = 0.97,
        bottom = 0.08,
        hspace = -0.1,
    ),
)

# load the metadata
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)

# add/update the force filter
# NOTE(review): update_filter is defined elsewhere in the notebook; 'lowpass': 10
# is presumably a 10 Hz cutoff — confirm against the filter definition
new_force_filter = {'channel': 'Force', 'lowpass': 10}
update_filter(metadata, new_force_filter)

# load the data (after the filter update so the filter is applied on load)
blk = neurotic.load_dataset(metadata)

# plot the data
# with sns.plotting_context('poster', font_scale=0.5):
with sns.plotting_context('notebook', font_scale=1):
    prettyplot(blk, t_start, t_stop, plots, outfile_basename, **kwargs)
    
end = datetime.datetime.now()
print('render time:', end-start)

---

## [FIGURE 3] First phase of swallowing: peak of protraction and min force

### Fig 3A ✅

Example motor patterns, lines connecting I2 muscle activity ending and force minimum

* TODO: Place event markers programmatically so they are completely accurate

Labels are intentionally overlapping since they will be manually replaced in Inkscape for a different style after rendering.

In [None]:
# Fig 3A: render I2 and Force for a few tape-nori swallows to a PNG (other
# nerve channels are commented out of the plot list).
start = datetime.datetime.now()

outfile_basename = 'figure-3A-export'
data_set_name = 'IN VIVO / JG12 / 2019-05-10 / 002'
t_start, t_stop = [2944.5, 3002.5] * pq.s # t=2961.9, twidth=58, 6 tape nori swallows
plots = [
    {'channel': 'I2',       'units': 'uV', 'ylim': [ -35,  35]}, #, 'decimation_factor': 400},
#     {'channel': 'RN',       'units': 'uV', 'ylim': [ -25,  25]}, #, 'decimation_factor': 400},
#     {'channel': 'BN2',      'units': 'uV', 'ylim': [ -45,  45]}, #, 'decimation_factor': 400},
#     {'channel': 'BN3-DIST', 'units': 'uV', 'ylim': [ -60,  60], 'ylabel': 'BN3 (μV)'}, #, 'decimation_factor': 400},
    {'channel': 'Force',    'units': 'mN', 'ylim': [ -50, 450], 'decimation_factor': 100},
]

# styling/export options forwarded to prettyplot
kwargs = dict(
    formats = ['png'],
    dpi = 600,
    figsize = (12, 4),
    majorticks = 10,
    minorticks = 5,
    linewidth = 0.5,
    ylabel_offset = -0.01,
    layout_settings = dict(
        left   = 0.04,
        right  = 0.99,
        top    = 0.97,
        bottom = 0.08,
        hspace = 0,
    ),
)

# load the data
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)
blk = neurotic.load_dataset(metadata)

# plot the data
# with sns.plotting_context('poster', font_scale=0.5):
with sns.plotting_context('notebook', font_scale=1):
    prettyplot(blk, t_start, t_stop, plots, outfile_basename, **kwargs)

end = datetime.datetime.now()
print('render time:', end-start)

---

### Fig 3B ✅

Plot delay between I2 muscle activity ending and force drop

- **The Good**: Almost all points are to the right, consistent with peak protraction causing force drop

- **The Bad**: Within-animal variability is high, and some extreme delays are too long

- **Verdict**: Very good for showing that the I2 protraction idea is plausible. Data from different animals can probably be lumped together without issue.

In [None]:
# Fig 3B: per-animal distribution of the delay from I2 burst end to force minimum
plt.figure(figsize=(5, 5))
ax = plt.gca()

df = df_all.reset_index()
column = 'Delay from I2 end to force min (s)'

# whis=(0, 100) places the whiskers at the 0th/100th percentiles (data min and
# max); a large IQR multiple such as 999 fails when the IQR is 0, and with
# fliersize=0 points beyond the whiskers would be hidden
sns.boxplot(y='Animal', x=column, data=df, fliersize=0, whis=(0, 100))
# sns.swarmplot(y='Animal', x=column, data=df, color='0.25')

# sns.boxplot(x=column, data=df, fliersize=0, whis=(0, 100))
# sns.swarmplot(x=column, data=df, color='0.25')

# zero-delay reference line
plt.axvline(x=0, ls=':', c='gray', zorder=-1)

plt.tight_layout()

In [None]:
# summarize the I2-end-to-force-min delays (print_column_analysis is defined elsewhere in the notebook)
column = 'Delay from I2 end to force min (s)'
print_column_analysis(column)

In [None]:
# points that aren't where we expected
column = 'Delay from I2 end to force min (s)'
df_all[df_all[column] < 0][[column]]

---

## [FIGURE 4] Second phase of swallowing: grasper closing and initial force rise

### Fig 4A ❌

Example motor patterns, lines connecting B8a/b motor neuron activity starting and force beginning to rise

---

### Fig 4B ❗

Plot the delay between the start of B3/B6/B9 motor neuron activity and the start of the force rise. Takeaway: force begins to rise first, evidence that these motor neurons do not initiate the rise

- **The Good**: In most swallows, force starts before B3/B6/B9

- **The Bad**: However, this is not true in about 1/3 of swallows. Variability is high, so this timing relationship is not reliable.

- This suggests that although there are swallows where B3/B6/B9 might have initiated the force rise, there are even more swallows where it couldn't have because force rose first.

- **Verdict**: Because the plot isn't compelling, instead of showing it we could just cite the statistic ("In X% of swallows, B3/B6/B9 bursts started after force generation had already begun") to motivate the B8a/b analysis

In [None]:
# Fig 4B: per-animal distribution of the delay from B3/B6/B9 burst start to force rise start
plt.figure(figsize=(5, 5))
ax = plt.gca()

df = df_all.reset_index()
column = 'Delay from B3/B6/B9 start to force start (s)'

# whis=(0, 100) places the whiskers at the 0th/100th percentiles (data min and
# max); a large IQR multiple such as 999 fails when the IQR is 0
sns.boxplot(y='Animal', x=column, data=df, fliersize=0, whis=(0, 100))
sns.swarmplot(y='Animal', x=column, data=df, color='0.25')

# sns.boxplot(x=column, data=df, fliersize=0, whis=(0, 100))
# sns.swarmplot(x=column, data=df, color='0.25')

# zero-delay reference line
plt.axvline(x=0, ls=':', c='gray', zorder=-1)

plt.tight_layout()

In [None]:
# summarize the B3/B6/B9-start-to-force-start delays (print_column_analysis is defined elsewhere in the notebook)
column = 'Delay from B3/B6/B9 start to force start (s)'
print_column_analysis(column)

In [None]:
# points that aren't where we expected
column = 'Delay from B3/B6/B9 start to force start (s)'
df_all[df_all[column] > 0][[column]]

---

### Fig 4C ✅

Plot delay between B8a/b motor neuron activity starting and force beginning to rise. Takeaway: B8a/b have the right timing for being the cause of initial force rise (by closing the grasper)

- **The Good**: B8a/b always precedes force start. With the exception of JG11, variability is low. Means are similar enough that we might be able to propose a "typical" delay.

- **The Bad**: JG11's variability

- **Verdict**: Very supportive of the argument. Check JG11's outliers for errors. This figure is clean enough that swallows from multiple animals could probably be lumped together without issue.

In [None]:
# Fig 4C: per-animal delay from B8a/b burst start to force start. Positive
# delays mean B8a/b activity began before force started rising.
plt.figure(figsize=(5, 5))
ax = plt.gca()

df = df_all.reset_index()
column = 'Delay from B8a/b start to force start (s)'

# whis=(0, 100) puts the whiskers at the 0th/100th percentiles, i.e. exactly at
# the min and max. A large IQR multiple (e.g. whis=999) does not guarantee this:
# when the IQR is 0 the whiskers collapse onto the box. Fliers are hidden
# because the swarm plot already draws every swallow.
sns.boxplot(y='Animal', x=column, data=df, fliersize=0, whis=(0, 100), ax=ax)
sns.swarmplot(y='Animal', x=column, data=df, color='0.25', ax=ax)

# pooled version (all animals in one box):
# sns.boxplot(x=column, data=df, fliersize=0, whis=(0, 100), ax=ax)
# sns.swarmplot(x=column, data=df, color='0.25', ax=ax)

plt.axvline(x=0, ls=':', c='gray', zorder=-1)  # zero-delay reference

plt.tight_layout()

In [None]:
# Run print_column_analysis (defined elsewhere in the notebook) on the Fig 4C delay column.
column = 'Delay from B8a/b start to force start (s)'
print_column_analysis(column)

In [None]:
# Swallows whose delay has the unexpected sign: a negative value here means
# force started rising before B8a/b activity began.
column = 'Delay from B8a/b start to force start (s)'
df_all.loc[df_all[column] < 0, [column]]

---

### Fig 4D ❌

Plot mean rectified voltage on radula nerve (B8a/b burst) preceding the start of B3/B6/B9 activity against initial force slope or delta force. Takeaway: Show that intensity of B8a/b activity correlates with intensity of force rise

- **The Good**: 4 of 5 animals have positive slopes

- **The Bad**: Highly variable. No significant correlations.

- The boundaries of the "B8a/b pre-B3/B6/B9" period are highly sensitive to several threshold parameters, so it may not be possible to locate it reliably.

- Measurements based on mean rectified voltage probably should not be lumped together from different animals without normalization.

- **Verdict**: 

In [None]:
# Fig 4D: B8a/b pre-B3/B6/B9 burst intensity vs. initial force slope.
ax = plt.figure(figsize=(5, 5)).gca()

# label2query (defined elsewhere) turns each '<animal> <food>' label into a
# query; the resulting dict is keyed by label.
data_subsets = {
    label: label2query(label)
    for label in (
        'JG07 Tape nori',
        'JG08 Tape nori',
        'JG11 Tape nori',
        'JG12 Tape nori',
        'JG14 Tape nori',
    )
}

# measurements to compare; uncomment an alternative line to swap it in
# xlabel, xlim = 'B8 activity mean rectified voltage (μV)', [0, None]
# xlabel, xlim = 'B8a/b first burst mean rectified voltage (μV)', [0, None]
xlabel, xlim = 'B8a/b pre-B3/B6/B9 burst mean rectified voltage (μV)', [0, None]
# xlabel, xlim = 'B8a/b pre-B3/B6/B9 burst mean frequency (Hz)', [0, None]
# ylabel, ylim = 'Force slope (mN/s)', [0, None]
ylabel, ylim = 'Force initial slope (mN/s)', [0, None]
# ylabel, ylim = 'Force initial increase (mN)', [0, None]

scatter2d(
    ax, data_subsets,
    xlabel, xlim,
    ylabel, ylim,
    trend=True, trend_separately=True, tooltips=True,
)
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

---

### Fig 4 Alternatives ❓

In [None]:
# Fig 4 alternative: B8a/b pre-B3/B6/B9 burst frequency vs. initial force slope.
ax = plt.figure(figsize=(5, 5)).gca()

# label2query (defined elsewhere) turns each '<animal> <food>' label into a
# query; the resulting dict is keyed by label.
data_subsets = {
    label: label2query(label)
    for label in (
        'JG07 Tape nori',
        'JG08 Tape nori',
        'JG11 Tape nori',
        'JG12 Tape nori',
        'JG14 Tape nori',
    )
}

# measurements to compare; uncomment an alternative line to swap it in
# xlabel, xlim = 'B8 activity mean rectified voltage (μV)', [0, None]
# xlabel, xlim = 'B8a/b first burst mean rectified voltage (μV)', [0, None]
# xlabel, xlim = 'B8a/b pre-B3/B6/B9 burst mean rectified voltage (μV)', [0, None]
xlabel, xlim = 'B8a/b pre-B3/B6/B9 burst mean frequency (Hz)', [0, None]
# ylabel, ylim = 'Force slope (mN/s)', [0, None]
ylabel, ylim = 'Force initial slope (mN/s)', [0, None]
# ylabel, ylim = 'Force initial increase (mN)', [0, None]

scatter2d(
    ax, data_subsets,
    xlabel, xlim,
    ylabel, ylim,
    trend=True, trend_separately=True, tooltips=True,
)
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

In [None]:
# Fig 4 alternative: B8a/b pre-B3/B6/B9 burst intensity vs. initial force increase.
ax = plt.figure(figsize=(5, 5)).gca()

# label2query (defined elsewhere) turns each '<animal> <food>' label into a
# query; the resulting dict is keyed by label.
data_subsets = {
    label: label2query(label)
    for label in (
        'JG07 Tape nori',
        'JG08 Tape nori',
        'JG11 Tape nori',
        'JG12 Tape nori',
        'JG14 Tape nori',
    )
}

# measurements to compare; uncomment an alternative line to swap it in
# xlabel, xlim = 'B8 activity mean rectified voltage (μV)', [0, None]
# xlabel, xlim = 'B8a/b first burst mean rectified voltage (μV)', [0, None]
xlabel, xlim = 'B8a/b pre-B3/B6/B9 burst mean rectified voltage (μV)', [0, None]
# xlabel, xlim = 'B8a/b pre-B3/B6/B9 burst mean frequency (Hz)', [0, None]
# ylabel, ylim = 'Force slope (mN/s)', [0, None]
# ylabel, ylim = 'Force initial slope (mN/s)', [0, None]
ylabel, ylim = 'Force initial increase (mN)', [0, None]

scatter2d(
    ax, data_subsets,
    xlabel, xlim,
    ylabel, ylim,
    trend=True, trend_separately=True, tooltips=True,
)
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

In [None]:
# Fig 4 alternative: B8a/b pre-B3/B6/B9 burst frequency vs. initial force increase.
ax = plt.figure(figsize=(5, 5)).gca()

# label2query (defined elsewhere) turns each '<animal> <food>' label into a
# query; the resulting dict is keyed by label.
data_subsets = {
    label: label2query(label)
    for label in (
        'JG07 Tape nori',
        'JG08 Tape nori',
        'JG11 Tape nori',
        'JG12 Tape nori',
        'JG14 Tape nori',
    )
}

# measurements to compare; uncomment an alternative line to swap it in
# xlabel, xlim = 'B8 activity mean rectified voltage (μV)', [0, None]
# xlabel, xlim = 'B8a/b first burst mean rectified voltage (μV)', [0, None]
# xlabel, xlim = 'B8a/b pre-B3/B6/B9 burst mean rectified voltage (μV)', [0, None]
xlabel, xlim = 'B8a/b pre-B3/B6/B9 burst mean frequency (Hz)', [0, None]
# ylabel, ylim = 'Force slope (mN/s)', [0, None]
# ylabel, ylim = 'Force initial slope (mN/s)', [0, None]
ylabel, ylim = 'Force initial increase (mN)', [0, None]

scatter2d(
    ax, data_subsets,
    xlabel, xlim,
    ylabel, ylim,
    trend=True, trend_separately=True, tooltips=True,
)
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

In [None]:
# Fig 4 alternative: B8a/b first burst duration vs. force rise-and-plateau duration.
ax = plt.figure(figsize=(5, 5)).gca()

# label2query (defined elsewhere) turns each '<animal> <food>' label into a
# query; the resulting dict is keyed by label.
data_subsets = {
    label: label2query(label)
    for label in (
        'JG07 Tape nori',
        'JG08 Tape nori',
        'JG11 Tape nori',
        'JG12 Tape nori',
        'JG14 Tape nori',
    )
}

# xlabel, xlim = 'B8 activity duration (s)', [0, 8]
xlabel, xlim = 'B8a/b first burst duration (s)', [0, 8]
ylabel, ylim = 'Force rise and plateau duration (s)', [0, 8]

scatter2d(
    ax, data_subsets,
    xlabel, xlim,
    ylabel, ylim,
    trend=True, trend_separately=True, tooltips=False,
)
# dotted y = x reference line (equal durations)
plt.plot((0, 999), (0, 999), ls=':', c='gray', zorder=0)
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

---

## [FIGURE 5] Third phase of swallowing: maintaining high force

### Fig 5A ✅

Example motor patterns, lines connecting B3/B6/B9 motor neuron activity to force plateau

* TODO: Place event markers programmatically

Labels are intentionally overlapping since they will be manually replaced in Inkscape for a different style after rendering.

In [None]:
# Render and export the Fig 5A example motor pattern (BN2 + Force) from JG12.
start = datetime.datetime.now()  # wall-clock start, used only for the render-time printout below

outfile_basename = 'figure-5A-export'
data_set_name = 'IN VIVO / JG12 / 2019-05-10 / 002'
t_start, t_stop = [2944.5, 3002.5] * pq.s # t=2961.9, twidth=58, 6 tape nori swallows
# One dict per plotted trace; commented-out entries are channels left out of
# this panel (uncomment one to add it back).
plots = [
#     {'channel': 'I2',       'units': 'uV', 'ylim': [ -35,  35]}, #, 'decimation_factor': 400},
#     {'channel': 'RN',       'units': 'uV', 'ylim': [ -25,  25]}, #, 'decimation_factor': 400},
    {'channel': 'BN2',      'units': 'uV', 'ylim': [ -45,  45]}, #, 'decimation_factor': 400},
#     {'channel': 'BN3-DIST', 'units': 'uV', 'ylim': [ -60,  60], 'ylabel': 'BN3 (μV)'}, #, 'decimation_factor': 400},
    {'channel': 'Force',    'units': 'mN', 'ylim': [ -50, 450], 'decimation_factor': 100},
]

# Rendering options passed through to prettyplot (defined elsewhere in the notebook).
kwargs = dict(
    formats = ['png'],
    dpi = 600,
    figsize = (12, 4),
    majorticks = 10,
    minorticks = 5,
    linewidth = 0.5,
    ylabel_offset = -0.01,
    layout_settings = dict(
        left   = 0.04,
        right  = 0.99,
        top    = 0.97,
        bottom = 0.08,
        hspace = 0,
    ),
)

# load the data
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)
blk = neurotic.load_dataset(metadata)

# plot the data
# with sns.plotting_context('poster', font_scale=0.5):
with sns.plotting_context('notebook', font_scale=1):
    prettyplot(blk, t_start, t_stop, plots, outfile_basename, **kwargs)

end = datetime.datetime.now()
print('render time:', end-start)  # elapsed wall-clock time as a timedelta

---

### Fig 5B ✅

Plot delay between B3/B6/B9 motor neuron activity starting and force reaching 50% max height. Takeaway: B3/B6/B9 have the right timing for keeping force high

- **The Good**: All delays are positive.

- **The Bad**: Delays are highly variable.

- Perhaps 50%-height isn't the right comparison point. Maybe B3/B6/B9 have effects earlier at 20%-height, or maybe later.

- **Verdict**: Not bad, but perhaps could be better. Delay might be more consistent with a better comparison point.

In [None]:
# Fig 5B: per-animal delay from B3/B6/B9 burst start to force reaching 50% of
# its max height. Positive delays mean B3/B6/B9 activity began first.
plt.figure(figsize=(5, 5))
ax = plt.gca()

df = df_all.reset_index()
column = 'Delay from B3/B6/B9 start to force 50%-height (s)'

# whis=(0, 100) puts the whiskers at the 0th/100th percentiles, i.e. exactly at
# the min and max. A large IQR multiple (e.g. whis=999) collapses onto the box
# when the IQR is 0. That matters here: fliers are hidden and the swarm overlay
# is disabled, so the whiskers are the only indication of the extremes.
sns.boxplot(y='Animal', x=column, data=df, fliersize=0, whis=(0, 100), ax=ax)
# sns.swarmplot(y='Animal', x=column, data=df, color='0.25', ax=ax)

# pooled version (all animals in one box):
# sns.boxplot(x=column, data=df, fliersize=0, whis=(0, 100), ax=ax)
# sns.swarmplot(x=column, data=df, color='0.25', ax=ax)

plt.axvline(x=0, ls=':', c='gray', zorder=-1)  # zero-delay reference

plt.tight_layout()

In [None]:
# Run print_column_analysis (defined elsewhere in the notebook) on the Fig 5B delay column.
column = 'Delay from B3/B6/B9 start to force 50%-height (s)'
print_column_analysis(column)

In [None]:
# Swallows whose delay has the unexpected sign: a negative value here means
# force reached 50% of its max height before B3/B6/B9 activity began.
column = 'Delay from B3/B6/B9 start to force 50%-height (s)'
df_all.loc[df_all[column] < 0, [column]]

---

### Fig 5C ✅

Plot B3/B6/B9 burst duration against time force is above 80% max height

- **The Good**: All same positive trends, most are significant.

- **The Bad**: Would be more satisfying if the points lay on the diagonal.

- **Verdict**: Good result. Might be improved by choosing a lower plateau threshold with longer duration, such as 75%.

In [None]:
# Fig 5C: B3/B6/B9 burst duration vs. time force stays above 80% of max height.
ax = plt.figure(figsize=(5, 5)).gca()

# label2query (defined elsewhere) turns each '<animal> <food>' label into a
# query; the resulting dict is keyed by label.
data_subsets = {
    label: label2query(label)
    for label in (
        'JG07 Tape nori',
        'JG08 Tape nori',
        'JG11 Tape nori',
        'JG12 Tape nori',
        'JG14 Tape nori',
    )
}

# xlabel, xlim = 'B3 first burst duration (s)', [0, 5]
# xlabel, xlim = 'B6/B9 first burst duration (s)', [0, 5]
xlabel, xlim = 'B3/B6/B9 burst duration (s)', [0, 5]
ylabel, ylim = 'Force 80%-height duration (s)', [0, 5]

scatter2d(
    ax, data_subsets,
    xlabel, xlim,
    ylabel, ylim,
    trend=True, trend_separately=True, tooltips=False,
)
# dotted y = x reference line (equal durations)
plt.plot((0, 999), (0, 999), ls=':', c='gray', zorder=0)
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

---

### Fig 5D ❌

Plot delay between EITHER B8a/b OR B3/B6/B9 activity ending (whichever comes first) and end of force plateau. Takeaway: Show that simultaneous B8a/b and B3/B6/B9 activity are needed to sustain force, i.e., force drops when either grasper opens or retraction stops

- **The Good**: There are more points on the right (positive delays) than on the left (negative delays)

- **The Bad**: Very large variability, can't really determine sign of the timing delay

- The end of the "simultaneous B8a/b and B3/B6/B9" period is sensitive to several threshold parameters, so it may not be possible to locate it reliably.

- Perhaps the hard rule of using 80%-height for the end of the plateau is not generally applicable.

- **Verdict**: Not usable in current form.

In [None]:
# Fig 5D: per-animal delay from the end of B8a/b or B3/B6/B9 activity
# (whichever ends first) to the end of the force 80%-height plateau.
plt.figure(figsize=(5, 5))
ax = plt.gca()

df = df_all.reset_index()
column = 'Delay from either B8a/b or B3/B6/B9 end to force 80%-height end (s)'

# whis=(0, 100) puts the whiskers at the 0th/100th percentiles, i.e. exactly at
# the min and max. A large IQR multiple (e.g. whis=999) does not guarantee this:
# when the IQR is 0 the whiskers collapse onto the box. Fliers are hidden
# because the swarm plot already draws every swallow.
sns.boxplot(y='Animal', x=column, data=df, fliersize=0, whis=(0, 100), ax=ax)
sns.swarmplot(y='Animal', x=column, data=df, color='0.25', ax=ax)

# pooled version (all animals in one box):
# sns.boxplot(x=column, data=df, fliersize=0, whis=(0, 100), ax=ax)
# sns.swarmplot(x=column, data=df, color='0.25', ax=ax)

plt.axvline(x=0, ls=':', c='gray', zorder=-1)  # zero-delay reference

plt.tight_layout()

In [None]:
# Run print_column_analysis (defined elsewhere in the notebook) on the Fig 5D delay column.
column = 'Delay from either B8a/b or B3/B6/B9 end to force 80%-height end (s)'
print_column_analysis(column)

In [None]:
# Swallows whose delay has the unexpected sign: a negative value here means the
# force 80%-height plateau ended before B8a/b or B3/B6/B9 activity ended.
column = 'Delay from either B8a/b or B3/B6/B9 end to force 80%-height end (s)'
df_all.loc[df_all[column] < 0, [column]]

---

### Fig 5 Alternatives ❓

__Other ideas:__

* Plot mean voltage on BN2 during B3/B6/B9 activity vs. slope or relative rise in force?

In [None]:
# Fig 5 alternative: B6/B9 first burst duration vs. force rise-and-plateau duration.
ax = plt.figure(figsize=(5, 5)).gca()

# label2query (defined elsewhere) turns each '<animal> <food>' label into a
# query; the resulting dict is keyed by label.
data_subsets = {
    label: label2query(label)
    for label in (
        'JG07 Tape nori',
        'JG08 Tape nori',
        'JG11 Tape nori',
        'JG12 Tape nori',
        'JG14 Tape nori',
    )
}

# xlabel, xlim = 'B3/6/9/10 activity duration (s)', [0, 8]
xlabel, xlim = 'B6/B9 first burst duration (s)', [0, 8]
ylabel, ylim = 'Force rise and plateau duration (s)', [0, 8]

scatter2d(
    ax, data_subsets,
    xlabel, xlim,
    ylabel, ylim,
    trend=True, trend_separately=True, tooltips=False,
)
# dotted y = x reference line (equal durations)
plt.plot((0, 999), (0, 999), ls=':', c='gray', zorder=0)
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

In [None]:
# Fig 5 alternative: delays from B3 burst start/end to force 80%-height plateau
# start/end, one row per burst boundary with a box per animal.
plt.figure(figsize=(5, 5))
ax = plt.gca()

df = df_all.reset_index()
df = df.rename(columns={
    'Delay from B3 start to force 80%-height start (s)': 'Start of burst',
    'Delay from B3 end to force 80%-height end (s)':     'End of burst'})
# Long format: one row per (swallow, burst boundary) so 'When?' can be a plot axis.
# The id_vars are assumed to be df_all's index levels (restored by reset_index).
df = pd.melt(df,
             id_vars=['Animal', 'Food', 'Bout_index', 'Behavior_index'],
             value_vars=['Start of burst', 'End of burst'],
             var_name='When?',
             value_name='Delay from B3 to plateau (s)')
# whis=(0, 100) puts the whiskers at the 0th/100th percentiles, i.e. exactly at
# the min and max. A large IQR multiple (e.g. whis=999) collapses onto the box
# when the IQR is 0, and with fliers hidden the extremes would then be invisible.
sns.boxplot(hue='Animal', x='Delay from B3 to plateau (s)', y='When?', data=df, fliersize=0, whis=(0, 100), ax=ax)
# sns.swarmplot(hue='Animal', x='Delay from B3 to plateau (s)', y='When?', data=df, ax=ax)

plt.axvline(x=0, ls=':', c='gray', zorder=-1)  # zero-delay reference

plt.tight_layout()

In [None]:
# Fig 5 alternative: B6/B9 first burst intensity vs. force increase.
ax = plt.figure(figsize=(5, 5)).gca()

# label2query (defined elsewhere) turns each '<animal> <food>' label into a
# query; the resulting dict is keyed by label.
data_subsets = {
    label: label2query(label)
    for label in (
        'JG07 Tape nori',
        'JG08 Tape nori',
        'JG11 Tape nori',
        'JG12 Tape nori',
        'JG14 Tape nori',
    )
}

# measurements to compare; uncomment an alternative line to swap it in
# xlabel, xlim = 'B3/6/9/10 activity mean rectified voltage (μV)', [0, None]
xlabel, xlim = 'B6/B9 first burst mean rectified voltage (μV)', [0, None]
# ylabel, ylim = 'Force slope (mN/s)', [0, None]
# ylabel, ylim = 'Force 80%-height (mN)', [0, None]
ylabel, ylim = 'Force increase (mN)', [0, None]

scatter2d(
    ax, data_subsets,
    xlabel, xlim,
    ylabel, ylim,
    trend=True, trend_separately=True, tooltips=True,
)
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

In [None]:
# Fig 5 alternative: B6/B9 first burst frequency vs. force increase.
ax = plt.figure(figsize=(5, 5)).gca()

# label2query (defined elsewhere) turns each '<animal> <food>' label into a
# query; the resulting dict is keyed by label.
data_subsets = {
    label: label2query(label)
    for label in (
        'JG07 Tape nori',
        'JG08 Tape nori',
        'JG11 Tape nori',
        'JG12 Tape nori',
        'JG14 Tape nori',
    )
}

# measurements to compare; uncomment an alternative line to swap it in
# xlabel, xlim = 'B3/6/9/10 activity mean rectified voltage (μV)', [0, None]
xlabel, xlim = 'B6/B9 first burst mean frequency (Hz)', [0, None]
# ylabel, ylim = 'Force slope (mN/s)', [0, None]
# ylabel, ylim = 'Force 80%-height (mN)', [0, None]
ylabel, ylim = 'Force increase (mN)', [0, None]

scatter2d(
    ax, data_subsets,
    xlabel, xlim,
    ylabel, ylim,
    trend=True, trend_separately=True, tooltips=True,
)
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

---

## [FIGURE 6] Fourth phase of swallowing: force shoulder during initial protraction

### Fig 6A ❌

Example motor patterns, lines connecting B38 motor neuron activity and force shoulder

---

### Fig 6B ❌

Plot delay between B38 motor neuron activity ending and force shoulder ending

- **The Good**: 

- **The Bad**: Terrible variability. Poorly defined shoulders are certainly contributing to this.

- **Verdict**: This figure needs to use only the best hand-selected swallows with well-defined shoulders AND well-defined B38 burst

In [None]:
# Fig 6B: per-animal delay from B38 burst end to force shoulder end.
plt.figure(figsize=(5, 5))
ax = plt.gca()

df = df_all.reset_index()
column = 'Delay from B38 end to shoulder end (s)'

# whis=(0, 100) puts the whiskers at the 0th/100th percentiles, i.e. exactly at
# the min and max. A large IQR multiple (e.g. whis=999) does not guarantee this:
# when the IQR is 0 the whiskers collapse onto the box. Fliers are hidden
# because the swarm plot already draws every swallow.
sns.boxplot(y='Animal', x=column, data=df, fliersize=0, whis=(0, 100), ax=ax)
sns.swarmplot(y='Animal', x=column, data=df, color='0.25', ax=ax)

# pooled version (all animals in one box):
# sns.boxplot(x=column, data=df, fliersize=0, whis=(0, 100), ax=ax)
# sns.swarmplot(x=column, data=df, color='0.25', ax=ax)

plt.axvline(x=0, ls=':', c='gray', zorder=-1)  # zero-delay reference

plt.tight_layout()

In [None]:
# Run print_column_analysis (defined elsewhere in the notebook) on the Fig 6B delay column.
column = 'Delay from B38 end to shoulder end (s)'
print_column_analysis(column)

In [None]:
# Swallows whose delay has the unexpected sign: a negative value here means the
# force shoulder ended before B38 activity ended.
column = 'Delay from B38 end to shoulder end (s)'
df_all.loc[df_all[column] < 0, [column]]

---

### Fig 6 Alternatives ❓

In [None]:
# Fig 6 alternative: B38 last burst duration vs. force shoulder duration.
ax = plt.figure(figsize=(5, 5)).gca()

# label2query (defined elsewhere) turns each '<animal> <food>' label into a
# query; the resulting dict is keyed by label.
data_subsets = {
    label: label2query(label)
    for label in (
        'JG07 Tape nori',
        'JG08 Tape nori',
        'JG11 Tape nori',
        'JG12 Tape nori',
        'JG14 Tape nori',
    )
}

# xlabel, xlim = 'B38 activity duration (s)', [0, 6]
xlabel, xlim = 'B38 last burst duration (s)', [0, 6]
ylabel, ylim = 'Force shoulder duration (s)', [0, 6]

scatter2d(
    ax, data_subsets,
    xlabel, xlim,
    ylabel, ylim,
    trend=True, trend_separately=True, tooltips=False,
)
# dotted y = x reference line (equal durations)
plt.plot((0, 999), (0, 999), ls=':', c='gray', zorder=0)
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

In [None]:
# Fig 6 alternative: B38 last burst frequency vs. force shoulder duration.
ax = plt.figure(figsize=(5, 5)).gca()

# label2query (defined elsewhere) turns each '<animal> <food>' label into a
# query; the resulting dict is keyed by label.
data_subsets = {
    label: label2query(label)
    for label in (
        'JG07 Tape nori',
        'JG08 Tape nori',
        'JG11 Tape nori',
        'JG12 Tape nori',
        'JG14 Tape nori',
    )
}

xlabel, xlim = 'B38 last burst mean frequency (Hz)', [0, None]
ylabel, ylim = 'Force shoulder duration (s)', [0, None]

scatter2d(
    ax, data_subsets,
    xlabel, xlim,
    ylabel, ylim,
    trend=True, trend_separately=True, tooltips=False,
)
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

---

## [FIGURE 7] Summary ❌

Firing rate model (time permitting), or schematic summary of all phases (boxes for motor neuron activity, idealized force, etc.)