# Jump to a Figure

- [Figure 3](#[FIGURE-3])
  - [Figure 3A](#🐌-Figure-3A)
  - [Figure 3B](#🐌-Figure-3B)
  - [Figure 3C](#🐌-Figure-3C)
  - [Figure 3D](#🐌-Figure-3D)
- [Figure 4](#[FIGURE-4])
  - [Figure 4C](#🐌-Figure-4C)
  - [Figure 4D](#🐌-Figure-4D)

# Preamble

## Import Packages

In [None]:
import os
import datetime
from tqdm.notebook import tqdm
import numpy as np
import scipy as sp
import quantities as pq
import elephant
import pandas as pd
import statsmodels.api as sm
import neurotic
from neurotic.datasets.data import _detect_spikes, _find_bursts
from modules.utils import BehaviorsDataFrame, DownsampleNeoSignal
from modules import r_stats
from modules.plot_utils import add_scalebar, solve_figure_horizontal_dimensions, solve_figure_vertical_dimensions

pq.markup.config.use_unicode = True  # print unit symbols in unicode, e.g. the mu in "uV"
pq.mN = pq.UnitQuantity('millinewton', pq.N/1e3, symbol='mN')  # millinewton = N/1000, exposed as pq.mN

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.patches as patches
import matplotlib.lines as mlines
from matplotlib.markers import CARETLEFT, CARETRIGHT, CARETUP, CARETDOWN, CARETUPBASE
from matplotlib.ticker import MultipleLocator
import seaborn as sns

from rpy2.robjects.packages import importr
from rpy2.robjects import numpy2ri
# Turn on rpy2's global automatic conversion between NumPy arrays and R objects,
# so NumPy data can be passed straight to R functions (presumably those wrapped in
# modules.r_stats).
# NOTE(review): newer rpy2 releases deprecate the global activate() in favor of
# local conversion contexts — confirm against the installed rpy2 version.
numpy2ri.activate()

In [None]:
import warnings

# don't warn about invalid comparisons to NaN
np.seterr(invalid='ignore')

# np.nanmax raises a warning if all values are NaN and returns NaN, which is the behavior we want
warnings.filterwarnings('ignore', message='All-NaN slice encountered')

# elephant.statistics.instantaneous_rate always complains about negative values
warnings.filterwarnings('ignore', message='Instantaneous firing rate approximation contains '
                                          'negative values, possibly caused due to machine '
                                          'precision errors')

# with matplotlib>=3.1 and seaborn<=0.9.0, deprecation warnings are raised
# whenever tick marks are placed on the right axis but not the left
# - the warning class is importable from the top-level matplotlib package; the
#   matplotlib.cbook alias was deprecated in matplotlib 3.6 and later removed, so
#   only fall back to cbook for old matplotlib versions that lack the top-level name
try:
    from matplotlib import MatplotlibDeprecationWarning
except ImportError:
    from matplotlib.cbook import MatplotlibDeprecationWarning
warnings.simplefilter('ignore', MatplotlibDeprecationWarning)

# don't complain about opening too many figures
plt.rcParams.update({'figure.max_open_warning': 0})

In [None]:
# Select the matplotlib backend: leave exactly one `%matplotlib` line uncommented.
# NOTE(review): the `notebook` (nbagg) backend only works in the classic Jupyter
# Notebook interface; in JupyterLab, `%matplotlib widget` (ipympl) is the usual
# interactive replacement — confirm which frontend this notebook is run in.

# make figures interactive and open in a separate window
# %matplotlib qt

# make figures interactive and inline
%matplotlib notebook

# make figures non-interactive and inline
# %matplotlib inline

## Plot Settings

In [None]:
# directory (relative to the working directory) where exported figures are saved
export_dir = 'neurotic-manuscript-figures'
# exist_ok=True makes this a no-op when the directory already exists and avoids the
# check-then-create race of os.path.exists + os.mkdir; if a *file* with this name
# exists, this fails here instead of later when figures are saved
os.makedirs(export_dir, exist_ok=True)

In [None]:
# general plot settings: "ticks" style, default font scale, Palatino Linotype font
# (add context='poster' to the call to switch to seaborn's poster-sized context)
sns.set(style='ticks', font_scale=1, font='Palatino Linotype')

In [None]:
# show the active color palette as a strip of small swatches on a darkgrid background
with sns.axes_style(style='darkgrid'):
    sns.palplot(pal=sns.color_palette(None), size=0.5)

In [None]:
# plotting color for each unit (and for the force trace)
unit_colors = dict([
    ('B38',      '#EFBF46'),  # yellow
    ('I2',       '#DC5151'),  # red
    ('B8a/b',    'C6'),       # pink
    ('B6/B9',    'C9'),       # light blue
    ('B3/B6/B9', '#5A9BC5'),  # medium blue
    ('B3',       '#4F80BD'),  # dark blue
    ('B4/B5',    '#00A86B'),  # jade green
    ('Force',    'k'),        # black
])

# each force phase borrows the color of a unit; the final 'drop' phase is gray
force_colors = {
    phase: unit_colors[unit_name]
    for phase, unit_name in [
        ('shoulder',     'B38'),
        ('dip',          'I2'),
        ('initial rise', 'B8a/b'),
        ('rise',         'B8a/b'),
        ('plateau',      'B3/B6/B9'),
    ]
}
force_colors['drop'] = 'gray'

In [None]:
# show the chosen unit colors as a strip of small swatches on a darkgrid background
with sns.axes_style(style='darkgrid'):
    sns.palplot(pal=unit_colors.values(), size=0.5)

In [None]:
# list each unit name (left-aligned in a 10-character column) with its color as an upper-case hex code
for unit in unit_colors:
    color = unit_colors[unit]
    print(f'{unit:<10} {mcolors.to_hex(color).upper()}')

In [None]:
# print a link to an online colorblindness simulator preloaded with the unit colors
# ('#' is URL-encoded as '%23'; colors are separated by '-')
print('https://davidmathlogic.com/colorblind/#{}'.format(
    '-'.join(mcolors.to_hex(color).upper().replace('#', '%23') for color in unit_colors.values())))

## Data Parameters

In [None]:
# units (neurons or neuron groups) analyzed, in display order
units = 'B38 I2 B8a/b B6/B9 B3 B4/B5'.split()

In [None]:
# burst detection thresholds per unit as (start, end) firing-rate pairs in Hz
burst_thresholds_default = {
    'B38':   ( 8,  5)*pq.Hz, # based on McManus et al. 2014
    'I2':    (10,  5)*pq.Hz, # same as Cullins et al. 2015a (based on Hurwitz et al. 1996)
    'B8a/b': ( 3,  3)*pq.Hz, # same as Cullins et al. 2015a (based on Morton and Chiel 1993a)
    'B6/B9': (10,  5)*pq.Hz, # based on Lu et al. 2015
    'B3':    ( 8,  2)*pq.Hz, # based on Lu et al. 2015
    'B4/B5': ( 3,  3)*pq.Hz, # same as Cullins et al. 2015a ? (Table 1 says 3 Hz, text says first/last spike; based on Warman and Chiel 1995 ?)
}

# every animal starts from its own (shallow) copy of the defaults, so the
# per-animal exceptions below replace entries without touching the defaults
burst_thresholds_by_animal = {
    animal: burst_thresholds_default.copy()
    for animal in ('JG07', 'JG08', 'JG11', 'JG12', 'JG14')
}

# exceptions
burst_thresholds_by_animal['JG07']['B8a/b'] = ( 2,   2  )*pq.Hz # both thresholds reduced for this animal because B8a/b signal was weak with few spikes
burst_thresholds_by_animal['JG08']['B6/B9'] = (10,   3  )*pq.Hz # end threshold reduced for this animal
burst_thresholds_by_animal['JG11']['B4/B5'] = ( 1.5, 1.5)*pq.Hz # both thresholds reduced for this animal because only one neuron appeared to project
burst_thresholds_by_animal['JG14']['B6/B9'] = ( 4,   2  )*pq.Hz # both thresholds reduced for this animal because B6/B9 always fired slowly

In [None]:
# units of the recorded channels, in channel order:
# I2, RN, BN2 and BN3 are extracellular voltages (uV); the last channel is Force (mN)
channel_units = ['uV'] * 4 + ['mN']

In [None]:
# Analog channel names for each animal, listed in the same order as channel_units
# (I2, RN, BN2, BN3, Force). Names differ between recordings, so each animal has its
# own list; the suffixes presumably encode the recording side (-L) and the BN3
# electrode position (-PROX / -DIST) — confirm against the recording notes.
channel_names_by_animal = {
    'JG07': ['I2-L', 'RN-L', 'BN2-L', 'BN3-L',    'Force'],
    'JG08': ['I2',   'RN',   'BN2',   'BN3',      'Force'],
    'JG11': ['I2',   'RN',   'BN2',   'BN3-PROX', 'Force'],
    'JG12': ['I2',   'RN',   'BN2',   'BN3-DIST', 'Force'],
    'JG14': ['I2',   'RN',   'BN2',   'BN3-PROX', 'Force'],
}

In [None]:
# Signal filters used when loading each animal's data set. The processing loop
# overwrites metadata['filters'] with these before loading, so they replace any
# filters the data set's metadata defines. Each entry names a channel (the I2 channel,
# whose name differs between animals) and a cutoff frequency in Hz; apply_filters
# attaches the Hz unit to 'lowpass'/'highpass' before calling the Butterworth filter.
sig_filters_by_animal = {
    'JG07': [{'channel': 'I2-L', 'lowpass': 100}],
    'JG08': [{'channel': 'I2',   'lowpass': 100}],
    'JG11': [{'channel': 'I2',   'lowpass': 100}],
    'JG12': [{'channel': 'I2',   'lowpass': 100}],
    'JG14': [{'channel': 'I2',   'lowpass': 100}],
}

In [None]:
# Epoch type labels that count as a swallow for each food type. In the processing
# loop, a behavior is selected when its Type is one of these labels.
epoch_types_by_food = {
    'Regular nori': ['Swallow (regular 5-cm nori strip)'],
    'Tape nori':    ['Swallow (tape nori)'],
}

In [None]:
# Feeding bouts to analyze.
# - key: (animal, food, bout_index); food must be a key of epoch_types_by_food
# - value: (data set name passed to metadata.select, [start, end] time window in seconds)
# - only behaviors that both start and end inside the time window are included
# - trailing comments give the number of swallows in each bout and quality notes
feeding_bouts = {
    # (animal, food, bout_index): (data_set_name, time_window)

    ('JG07', 'Regular nori', 0): ('IN VIVO / JG07 / 2018-05-20 / 002', [1496, 1518]), # 4 swallows
    ('JG07', 'Regular nori', 1): ('IN VIVO / JG07 / 2018-05-20 / 002', [1169, 1191]), # 4 swallows
    ('JG07', 'Regular nori', 2): ('IN VIVO / JG07 / 2018-05-20 / 002', [1582, 1615]), # 5 swallows
    ('JG08', 'Regular nori', 0): ('IN VIVO / JG08 / 2018-06-21 / 002', [ 256,  287]), # 4 swallows
    ('JG08', 'Regular nori', 1): ('IN VIVO / JG08 / 2018-06-21 / 002', [ 454,  481]), # 4 swallows
    ('JG11', 'Regular nori', 0): ('IN VIVO / JG11 / 2019-04-03 / 001', [1791, 1819]), # 5 swallows
    ('JG11', 'Regular nori', 1): ('IN VIVO / JG11 / 2019-04-03 / 004', [ 551,  568]), # 3 swallows
    ('JG12', 'Regular nori', 0): ('IN VIVO / JG12 / 2019-05-10 / 002', [ 147,  165]), # 3 swallows
    ('JG12', 'Regular nori', 1): ('IN VIVO / JG12 / 2019-05-10 / 002', [ 229,  245]), # 3 swallows
    ('JG12', 'Regular nori', 2): ('IN VIVO / JG12 / 2019-05-10 / 002', [ 277,  291]), # 3 swallows
    ('JG14', 'Regular nori', 0): ('IN VIVO / JG14 / 2019-07-30 / 001', [1834, 1865]), # 4 swallows
    ('JG14', 'Regular nori', 1): ('IN VIVO / JG14 / 2019-07-30 / 001', [1910, 1943]), # 5 swallows
    ('JG14', 'Regular nori', 2): ('IN VIVO / JG14 / 2019-07-30 / 001', [2052, 2084]), # 5 swallows
    
    ('JG07', 'Tape nori',    0): ('IN VIVO / JG07 / 2018-05-20 / 002', [2718, 2755]), # 5 swallows
    ('JG08', 'Tape nori',    0): ('IN VIVO / JG08 / 2018-06-21 / 002', [ 147,  208]), # 7 swallows, some bucket and head movement
    ('JG08', 'Tape nori',    1): ('IN VIVO / JG08 / 2018-06-21 / 002', [ 664,  701]), # 5 swallows, large bucket movement
    ('JG08', 'Tape nori',    2): ('IN VIVO / JG08 / 2018-06-21 / 002', [1451, 1477]), # 3 swallows, some bucket movement
    ('JG11', 'Tape nori',    0): ('IN VIVO / JG11 / 2019-04-03 / 004', [1227, 1280]), # 5 swallows, inward food movements had very low amplitude
    ('JG12', 'Tape nori',    0): ('IN VIVO / JG12 / 2019-05-10 / 002', [ 436,  465]), # 4 swallows
    ('JG12', 'Tape nori',    1): ('IN VIVO / JG12 / 2019-05-10 / 002', [2901, 2937]), # 5 swallows
    ('JG14', 'Tape nori',    0): ('IN VIVO / JG14 / 2019-07-29 / 004', [ 829,  870]), # 5 swallows
}

In [None]:
# for example figures only -- not used in majority of analysis because of long inter-swallow intervals
#
# These entries are appended to feeding_bouts. Bout indexes 101 and 102 are
# presumably arbitrary values chosen to stay clear of the real bout indexes
# (0, 1, 2, ...). The single-swallow window lies inside the three-swallow window.

exemplary_bout    = ('JG12', 'Tape nori', 101)
exemplary_swallow = ('JG12', 'Tape nori', 102)

feeding_bouts[exemplary_bout]    = ('IN VIVO / JG12 / 2019-05-10 / 002', [2970.7, 2992.0]) # 3 swallows
feeding_bouts[exemplary_swallow] = ('IN VIVO / JG12 / 2019-05-10 / 002', [2977.3, 2984.5]) # 1 swallow

In [None]:
# these unloaded swallows have unreliable inward movement measurement
# and are excluded from some parts of the analysis
# - each entry is (animal, food, bout_index, behavior index within the bout);
#   the last field is presumably the 0-based swallow position after behaviors are
#   sorted by start time — confirm against the processing loop
unreliable_inward_movement = [
    ('JG07', 'Regular nori', 0, 2), # inward movement unusually short (maybe strip tore?)
    ('JG07', 'Regular nori', 0, 3), # finished strip mid-retraction
    ('JG07', 'Regular nori', 1, 1), # inward movement unusually short (maybe strip tore?)
    ('JG07', 'Regular nori', 1, 3), # finished strip mid-retraction
    ('JG07', 'Regular nori', 2, 2), # inward movement unusually short (maybe strip tore?)
    ('JG07', 'Regular nori', 2, 4), # finished strip mid-retraction
    ('JG08', 'Regular nori', 0, 3), # finished strip mid-retraction
    ('JG08', 'Regular nori', 1, 3), # finished strip mid-retraction
    ('JG11', 'Regular nori', 0, 4), # finished strip mid-retraction
    ('JG11', 'Regular nori', 1, 2), # finished strip mid-retraction
    ('JG12', 'Regular nori', 0, 2), # finished strip mid-retraction
    # JG12 Regular nori 0 1 is OK
    # NOTE(review): by its position this note probably refers to the last swallow of
    # bout 1 (i.e. "JG12 Regular nori 1 2 is OK") rather than bout 0 swallow 1 — confirm
    ('JG12', 'Regular nori', 2, 2), # finished strip mid-retraction
    ('JG14', 'Regular nori', 0, 3), # finished strip mid-retraction
    ('JG14', 'Regular nori', 1, 0), # bite-swallow and inward movement unusually short (maybe strip tore?)
    ('JG14', 'Regular nori', 1, 3), # inward movement unusually short (maybe strip tore?)
    ('JG14', 'Regular nori', 1, 4), # finished strip mid-retraction
    ('JG14', 'Regular nori', 2, 1), # not visible
]

# these behaviors at the start of each unloaded bout are also excluded
# because they are actually bite-swallows instead of pure swallows
# (one entry, at behavior index 0, for every 'Regular nori' bout in feeding_bouts)
bite_swallow_behaviors = [
    ('JG07', 'Regular nori', 0, 0),
    ('JG07', 'Regular nori', 1, 0),
    ('JG07', 'Regular nori', 2, 0),
    ('JG08', 'Regular nori', 0, 0),
    ('JG08', 'Regular nori', 1, 0),
    ('JG11', 'Regular nori', 0, 0),
    ('JG11', 'Regular nori', 1, 0),
    ('JG12', 'Regular nori', 0, 0),
    ('JG12', 'Regular nori', 1, 0),
    ('JG12', 'Regular nori', 2, 0),
    ('JG14', 'Regular nori', 0, 0),
    ('JG14', 'Regular nori', 1, 0),
    ('JG14', 'Regular nori', 2, 0),
]
# fold the bite-swallows into the same exclusion list (in place; may contain a
# duplicate, e.g. JG14 bout 1 behavior 0 appears in both lists)
unreliable_inward_movement += bite_swallow_behaviors

## Helper Functions

In [None]:
def query_union(queries):
    """Combine pandas query strings with OR (``|``), ignoring ``None`` entries.

    Each query is parenthesized so the precedence inside it is preserved.
    An empty input (or one containing only ``None``) yields ``''``.
    """
    wrapped = (f'({query})' for query in queries if query is not None)
    return '|'.join(wrapped)

def label2query(label):
    """Build a pandas query selecting one animal's rows for one food.

    ``label`` is expected to start with a 4-character animal ID, then a single
    separator character (presumably a space), then the food name, e.g.
    ``'JG07 Regular nori'``.
    """
    animal = label[:4]
    food = label[5:]
    return '(Animal == "{}") & (Food == "{}")'.format(animal, food)

def contains(series, string):
    """Return a boolean Series that is True where the element contains ``string``."""
    def _has_substring(value):
        return string in value
    return series.map(_has_substring)

In [None]:
def get_sig(blk, channel):
    """Return the first analog signal named ``channel`` in the first segment of ``blk``.

    Raises
    ------
    LookupError
        If no analog signal with that name exists. LookupError subclasses
        Exception, so existing ``except Exception`` handlers still catch it.
    """
    for sig in blk.segments[0].analogsignals:
        if sig.name == channel:
            return sig
    raise LookupError(f'Channel "{channel}" could not be found')

In [None]:
def get_sig_index(blk, channel):
    """Return the position of the first analog signal named ``channel`` in ``blk``'s first segment.

    Raises
    ------
    LookupError
        If no analog signal with that name exists. LookupError subclasses
        Exception, so existing ``except Exception`` handlers still catch it.
    """
    for index, sig in enumerate(blk.segments[0].analogsignals):
        if sig.name == channel:
            return index
    raise LookupError(f'Channel "{channel}" could not be found')

In [None]:
def apply_filters(blk, metadata):
    """Apply the Butterworth filters listed in ``metadata['filters']`` to ``blk``.

    Nearly identical to neurotic's own filtering, except that each signal is
    passed through ``time_slice(None, None)`` first so lazily-loaded proxy
    objects are materialized before filtering.

    Each filter spec is a dict with a ``'channel'`` name and optional
    ``'highpass'`` / ``'lowpass'`` cutoffs in Hz. The filtered signal replaces
    the original in ``blk.segments[0].analogsignals``; ``blk`` is modified in
    place and also returned.
    """
    for spec in metadata['filters']:
        position = get_sig_index(blk, spec['channel'])
        highpass = spec.get('highpass')
        lowpass = spec.get('lowpass')

        # attach units only to cutoffs that were actually given (None/0 pass through as-is)
        highpass = highpass*pq.Hz if highpass else highpass
        lowpass = lowpass*pq.Hz if lowpass else lowpass

        signals = blk.segments[0].analogsignals
        signals[position] = elephant.signal_processing.butter(  # may raise a FutureWarning
            signal=signals[position].time_slice(None, None),
            highpass_freq=highpass,
            lowpass_freq=lowpass,
        )

    return blk

In [None]:
def is_good_burst(burst):
    """Return whether a burst lasts at least 0.5 s and contains more than 2 spikes.

    ``burst`` is a 3-item sequence ``(time, duration, n_spikes)``; the first item
    is ignored. ``duration`` must carry time units because it is compared with
    ``0.5*pq.s``.
    """
    _, burst_duration, spike_count = burst
    long_enough = burst_duration >= 0.5*pq.s
    return long_enough and spike_count > 2

In [None]:
# must first convert args from quantities to simple ndarrays
# (use .rescale('s').magnitude)
def normalize_time(fixed_times, t, extrapolate=True):
    """Convert absolute times to piecewise-linear "normalized" times.

    The landmark ``fixed_times[k]`` maps to normalized time ``k``; times between
    two landmarks are linearly interpolated, e.g. a time halfway between
    ``fixed_times[1]`` and ``fixed_times[2]`` maps to ``1.5``.

    Parameters
    ----------
    fixed_times : array_like
        Landmark times, sorted ascending once NaNs are ignored. A NaN marks a
        missing landmark; any time falling in an interval bordered by a NaN
        normalizes to NaN. Must not be a ``pq.Quantity``.
    t : scalar, sequence, or ndarray
        Time(s) to normalize, in the same units as ``fixed_times``. Must not be
        a ``pq.Quantity``. Scalars and non-array sequences (list, tuple, ...)
        are converted to 1-D arrays. NaN entries normalize to NaN.
    extrapolate : bool, optional
        If True (default), times before the first / after the last landmark are
        normalized by extending the first / last interval. If False, those
        times become NaN.

    Returns
    -------
    ndarray
        Normalized times with the same shape as ``t`` (after conversion).
    """
    assert not isinstance(fixed_times, pq.Quantity), 'fixed_times should not be a quantity (use .magnitude)'
    assert not isinstance(t, pq.Quantity), 't should not be a quantity (use .magnitude)'

    # accept lists etc. for fixed_times (boolean-mask indexing below needs an ndarray)
    fixed_times = np.asarray(fixed_times, dtype=float)
    if not isinstance(t, np.ndarray):
        # any scalar or sequence becomes a 1-D array (previously tuples and other
        # non-list sequences were wrapped whole, producing a spurious 2-D array)
        t = np.atleast_1d(np.asarray(t))

    is_missing = np.isnan(fixed_times)
    assert np.all(np.diff(fixed_times[~is_missing])>=0), f'fixed_times must be sorted: {fixed_times}'

    # find the indexes of the fixed values that are just after each value in t
    # - searchsorted does not work well with NaNs, so only the present landmarks
    #   are searched (infinities are retained)
    reduced_indexes = np.searchsorted(fixed_times[~is_missing], t)

    # translate indexes into the NaN-free array back into positions in fixed_times
    # - reduced index k refers to the k-th present landmark; k equal to the number
    #   of present landmarks means "after all of them" and maps to len(fixed_times)
    present_positions = np.append(np.flatnonzero(~is_missing), len(fixed_times))
    indexes = present_positions[reduced_indexes]

    # increment/decrement any index equal to 0/N, where N=len(fixed_times)
    # - this is needed for values in t that are less/greater than the min/max fixed
    #   time, and for NaNs in t which get assigned an index of N by searchsorted
    # - for values in t less/greater than the min/max fixed time, this
    #   increment/decrement in index will prepare that value to be normalized
    #   using extrapolation based on the first/last interval in fixed_times
    indexes = np.clip(indexes, 1, len(fixed_times)-1)

    # compute the normalized values of t using linear interpolation between the
    # bordering fixed times
    # - values of t earlier than the min fixed time are extrapolated, because
    #   (after-t)/(after-before) is then greater than 1
    # - values of t later than the max fixed time are extrapolated, because
    #   (after-t)/(after-before) is then negative
    # - a NaN bordering fixed time (or a NaN in t) propagates to a NaN result
    before = fixed_times[indexes-1]
    after  = fixed_times[indexes]
    t_normalized = indexes - (after-t)/(after-before)

    if not extrapolate:
        # extrapolation is already done, so here we undo it by setting to NaN
        # the normalized value for any t that is less/greater than the min/max
        # fixed time
        with np.errstate(invalid='ignore'): # don't warn about invalid comparisons to NaN
            t_normalized[t < np.nanmin(fixed_times)] = np.nan
            t_normalized[t > np.nanmax(fixed_times)] = np.nan

    return t_normalized

In [None]:
# must first convert args from quantities to simple ndarrays
# (use .rescale('s').magnitude)
def unnormalize_time(fixed_times, t_normalized, extrapolate=True):
    """Convert piecewise-linear "normalized" times back to absolute times.

    Inverse of ``normalize_time``: normalized time ``k`` maps to the landmark
    ``fixed_times[k]``, and fractional values are linearly interpolated
    between neighboring landmarks.

    Parameters
    ----------
    fixed_times : array_like
        Landmark times, sorted ascending once NaNs are ignored. A NaN marks a
        missing landmark; normalized times in an interval bordered by a NaN map
        to NaN. Must not be a ``pq.Quantity``.
    t_normalized : scalar, sequence, or ndarray
        Normalized time(s). Must not be a ``pq.Quantity``. Scalars and
        non-array sequences are converted to 1-D arrays. NaN entries map to NaN.
    extrapolate : bool, optional
        If True (default), normalized times outside the landmark range are
        mapped by extending the first / last interval. If False, any result
        earlier/later than the min/max landmark becomes NaN.

    Returns
    -------
    ndarray
        Absolute times with the same shape as ``t_normalized`` (after conversion).
    """
    assert not isinstance(fixed_times, pq.Quantity), 'fixed_times should not be a quantity (use .magnitude)'
    assert not isinstance(t_normalized, pq.Quantity), 't_normalized should not be a quantity (use .magnitude)'

    # accept lists etc. for fixed_times (boolean-mask indexing below needs an ndarray)
    fixed_times = np.asarray(fixed_times, dtype=float)
    if not isinstance(t_normalized, np.ndarray):
        # any scalar or sequence becomes a 1-D array (previously tuples were
        # wrapped whole, producing a spurious 2-D array)
        t_normalized = np.atleast_1d(np.asarray(t_normalized))

    assert np.all(np.diff(fixed_times[~np.isnan(fixed_times)])>=0), f'fixed_times must be sorted: {fixed_times}'

    # get the index of the fixed time that comes after each t
    indexes = np.ceil(t_normalized)

    # clip the "after" indexes so that the first or last interval
    # in fixed_times will be used for extrapolation
    indexes = np.clip(indexes, 1, len(fixed_times)-1)

    # get the fixed times before and after each t
    # - NaN inputs have no index; look them up at a harmless dummy index (0) and
    #   overwrite the result with NaN, instead of looping element by element
    has_index = ~np.isnan(indexes)
    safe_indexes = np.where(has_index, indexes, 0).astype(int)
    before = np.where(has_index, fixed_times[safe_indexes-1], np.nan)
    after  = np.where(has_index, fixed_times[safe_indexes],   np.nan)

    # compute the real values of t
    t = after + (t_normalized-indexes)*(after-before)

    if not extrapolate:
        # extrapolation is already done, so here we undo it by setting to NaN
        # the value for any t that is less/greater than the min/max fixed time
        with np.errstate(invalid='ignore'): # don't warn about invalid comparisons to NaN
            t[t < np.nanmin(fixed_times)] = np.nan
            t[t > np.nanmax(fixed_times)] = np.nan

    return t

In [None]:
# these arrays are used as interp_times for resample_sig_in_normalized_time
# depending on the segmentation scheme
# - linspace ensures samples will be taken at regular intervals in normalized time
interp_resolution = 1000
video_seg_interp_times = np.linspace(0, 9, interp_resolution)
force_seg_interp_times = np.linspace(0, 9, interp_resolution)

def resample_sig_in_normalized_time(fixed_times, sig, interp_times):
    """Resample a signal at the given points in normalized time.

    Parameters
    ----------
    fixed_times : pq.Quantity
        Landmark (segmentation) times carrying time units; NaN marks a missing
        landmark. Converted to seconds before use, to match ``sig.times``.
    sig : signal object (presumably a neo.AnalogSignal)
        Provides ``times`` (with time units) and ``magnitude``. ``magnitude``
        is flattened, so a single-channel signal is assumed.
    interp_times : array_like
        Normalized times at which to sample the signal.

    Returns
    -------
    ndarray
        Signal values at ``interp_times``. Entries are NaN where the normalized
        time falls outside the normalizable part of the signal or inside an
        interval bordered by a missing landmark. The whole array is NaN if fewer
        than two samples could be normalized.
    """
    # express landmarks and sample times in the same unit (seconds) before
    # stripping units; previously fixed_times kept whatever unit it carried
    fixed_times_s = fixed_times.rescale('s').magnitude

    # get normalized times and signal values
    # - normalize_time will put into times_normalized a np.nan wherever a time was not normalizable,
    #   i.e., wherever the time occurred adjacent to a missing fixed time (np.nan in fixed_times)
    times = sig.times.rescale('s').magnitude
    times_normalized = normalize_time(fixed_times_s, times)
    y = sig.magnitude.flatten()

    # drop times that could not be normalized due to missing fixed times
    # - interp1d will erroneously interpolate across the gaps created by this deletion,
    #   but we will then replace those interpolated values with np.nan
    normalizable = ~np.isnan(times_normalized)
    times_normalized = times_normalized[normalizable]
    y = y[normalizable]

    # interp1d needs at least two points; with fewer there is nothing to interpolate
    if times_normalized.size < 2:
        return np.full(np.shape(interp_times), np.nan)

    # resample evenly in normalized time
    # - with bounds_error=False and fill_value=np.nan, interp1d will put np.nan in places outside
    #   the min and max of times_normalized
    # - interp1d will interpolate across the regions we deleted, but we want np.nan there,
    #   so we will insert them manually in the next step
    interp_func = sp.interpolate.interp1d(times_normalized, y, kind='linear', bounds_error=False, fill_value=np.nan)
    y_interp = interp_func(interp_times)

    # replace erroneously interpolated values with np.nan
    # - normalized times that map back to NaN lie in intervals bordered by a missing
    #   fixed time, i.e. exactly the regions deleted above
    not_normalizable = np.isnan(unnormalize_time(fixed_times_s, interp_times))
    y_interp[not_normalizable] = np.nan

    return y_interp

In [None]:
def print_column_analysis(column, df=None):
    """Print the overall mean, standard deviation and coefficient of variation of a column.

    Parameters
    ----------
    column : str
        Column name. If it contains a parenthesized part, e.g.
        ``'Burst duration (s)'``, the text inside the last pair of parentheses
        is printed as the unit; otherwise no unit is printed.
    df : pandas.DataFrame, optional
        Table to analyze. Defaults to the notebook-global ``df_all``.

    The standard deviation is pandas' sample standard deviation (ddof=1), and
    CV = std / |mean|.
    """
    data = df_all if df is None else df

    # previously a name without "(" printed the whole column name as its unit
    _, paren, tail = column.rpartition('(')
    unit = tail.split(')')[0] if paren else ''
    unit_suffix = f' {unit}' if unit else ''

    mean = data[column].mean()
    std = data[column].std()
    cv = std/abs(mean)
    print(f'Overall mean +/- std: {mean:.2f} +/- {std:.2f}{unit_suffix}')
    print(f'Overall coefficient of variation CV: {cv:.2f}')

In [None]:
def differences_test(x, y, x_label='GROUP 1', y_label='GROUP 2', measure_label='MEASURE', units='UNITS', alpha=0.05):
    """Test whether paired measurements ``x`` exceed ``y`` and print a report.

    The paired differences ``x - y`` are first checked for normality with a
    Shapiro-Wilk test. If normality is not rejected at ``alpha``, a paired
    one-tailed t-test (H1: the mean of ``x - y`` is positive) is run and
    Cohen's d plus a draft results sentence are printed. Otherwise a two-sided
    Wilcoxon signed-rank test is run.

    Parameters
    ----------
    x, y : array-like of float
        Paired measurements of equal length. ``x - y`` must be computable, so
        pass NumPy arrays (or pandas Series, which pandas aligns by index).
    x_label, y_label : str
        Group names used in the printed report.
    measure_label, units : str
        Name and unit of the measurement, used in the draft results sentence.
    alpha : float
        Significance level for every test, including the normality check.

    Returns
    -------
    str
        ``'*'`` if the test that was run is significant at ``alpha`` (for the
        t-test, only when ``x`` has the larger mean), otherwise ``'(n.s.)'``.

    Raises
    ------
    AssertionError
        If ``x`` and ``y`` have different lengths.
    """

    # descriptive statistics
    x_mean = np.mean(x)
    y_mean = np.mean(y)
    x_std = np.std(x, ddof=1)
    y_std = np.std(y, ddof=1)
    x_n = len(x)
    y_n = len(y)
    assert x_n == y_n, 'expected x and y to be paired but they have unequal length'
    print(f'{x_label}: (M = {x_mean:g}, SD = {x_std:g}, N = {x_n:g})')
    print(f'{y_label}: (M = {y_mean:g}, SD = {y_std:g}, N = {y_n:g})')
    print()

    # Shapiro-Wilk test for normality of differences
    # - equivalent R test: shapiro.test(x-y)
    shapiro_W, shapiro_p = sp.stats.shapiro(x-y)
    shapiro_signif = '*' if shapiro_p < alpha else '(n.s.)'
    print(f'H0: Differences have normal distribution, W = {shapiro_W:g},\tp = {shapiro_p:g} {shapiro_signif}')

    # use the same alpha for the normality gate as for everything else
    # (previously hard-coded to 0.05)
    if shapiro_p >= alpha:
        print('- Because the differences can be assumed to be normal, a paired t-test will be used')

        # paired one-tailed T-test for an increase in means
        # - equivalent R test : t.test(x, y, paired=TRUE, alternative="greater")
        # - sp.stats.ttest_rel returns a two-sided p-value (older SciPy has no
        #   one-tailed option), so convert it: halving is only correct when t lies in
        #   the hypothesized (positive) direction; otherwise the one-tailed p is 1 - p/2
        ttest_t, ttest_p_two_sided = sp.stats.ttest_rel(x, y)
        if ttest_t > 0:
            ttest_p = ttest_p_two_sided/2
        else:
            ttest_p = 1 - ttest_p_two_sided/2
        ttest_signif = '*' if ttest_p < alpha and x_mean > y_mean else '(n.s.)'
        print(f'H0: Difference in means is not positive,  t = {ttest_t:g},\tp = {ttest_p:g} {ttest_signif}')

        # Cohen's d for effect size, using the pooled standard deviation of the two
        # groups (i.e. the unpaired formula)
        # - equivalent R function: library(effsize); cohen.d(x, y, paired=FALSE)
        # - the R function's paired=TRUE variant uses a different formula and gives a
        #   slightly smaller value; it is not implemented here
        pooled_std = np.sqrt(((len(x)-1)*np.var(x, ddof=1) + (len(y)-1)*np.var(y, ddof=1))/(len(x)+len(y)-2))
        cohen_d = (np.mean(x)-np.mean(y))/pooled_std
        print()
        print(f'Effect size: Cohen\'s d = {cohen_d:g}')

        if ttest_p < alpha and x_mean > y_mean:
            print()
            print(f'"A paired-samples one-tailed t-test indicated that {measure_label} for the ' \
                  f'[{x_n}] animals were significantly [greater/longer] for ' \
                  f'{x_label} (M = {x_mean:.2f} {units}, SD = {x_std:.2f} {units}) than for ' \
                  f'{y_label} (M = {y_mean:.2f} {units}, SD = {y_std:.2f} {units}), ' \
                  f't = {ttest_t:.3f}, df = {x_n-1}, p = {ttest_p:.3f}, Cohen\'s d = {cohen_d:.2f}."')
        else:
            print()
            print(f'"A paired-samples one-tailed t-test indicated no significant increase in {measure_label} in the [{x_n}] animals between ' \
                  f'{x_label} (M = {x_mean:.2f} {units}, SD = {x_std:.2f} {units}) and ' \
                  f'{y_label} (M = {y_mean:.2f} {units}, SD = {y_std:.2f} {units}), ' \
                  f't = {ttest_t:.3f}, df = {x_n-1}, p = {ttest_p:.3f}, Cohen\'s d = {cohen_d:.2f}."')

        return ttest_signif

    else:
        print('- Because the differences cannot be assumed to be normal, a Wilcoxon signed rank test will be used')

        # Wilcoxon signed rank test for non-zero difference in locations (medians?)
        # - equivalent R test: wilcox.test(x, y, paired=TRUE, exact=FALSE)
        # - a warning is raised for small sample sizes (N < 10) because SciPy's implementation
        #   always calculates the p-value using a normal approximation of the test statistic
        #   distribution, which is inaccurate for small sample size
        # - use R to get exact p-value, with wilcox.test(x, y, paired=TRUE, exact=TRUE)
        # - upcoming implementation of exact distribution: https://github.com/scipy/scipy/pull/10796
        wilcoxon_W, wilcoxon_p = sp.stats.wilcoxon(x, y, correction=True)
        wilcoxon_signif = '*' if wilcoxon_p < alpha else '(n.s.)'
        print(f'H0: Difference in medians is zero,        W = {wilcoxon_W:g},\tp = {wilcoxon_p:g} {wilcoxon_signif}')

        return wilcoxon_signif

## Import and Process the Data

In [None]:
# skip expensive calculations by loading the results from a file?
# - with load_from_files=False, perform data processing from scratch,
#   which takes several minutes
# - with load_from_files=True, load the final results (dataframes)
#   pickled last time the calculations were performed
# - the next cell reads each pickled dataframe from "<variable name>.zip" in the
#   current working directory; unpickling can execute arbitrary code, so only
#   load files this notebook produced
load_from_files = True

In [None]:
pickled_vars = ['df_all', 'df_exemplary_bout', 'df_exemplary_swallow']
if load_from_files:
    
    for var in pickled_vars:
        exec(f'{var} = pd.read_pickle("{var}.zip")')

    print('calculation results loaded from files')
    
else:
    
    start = datetime.datetime.now()

    # use Neo RawIO lazy loading to load much faster and using less memory
    # - with lazy=True, filtering parameters specified in metadata are ignored
    #     - note: filters are replaced below and applied manually anyway
    # - with lazy=True, loading via time_slice requires neo>=0.8.0
    lazy = True

    # load the metadata containing file paths
    metadata = neurotic.MetadataSelector(file='../../data/metadata.yml')

    # filter epochs for each bout and perform calculations
    df_list = []
    last_data_set_name = None
    pbar = tqdm(total=len(feeding_bouts), unit='feeding bout')
    for (animal, food, bout_index), (data_set_name, time_window) in feeding_bouts.items():

        channel_names = channel_names_by_animal[animal]
        epoch_types = epoch_types_by_food[food]
        burst_thresholds = burst_thresholds_by_animal[animal]

        ###
        ### LOAD DATASET
        ###

        metadata.select(data_set_name)

        if data_set_name is last_data_set_name:
            # skip reloading the data if it's already in memory
            pass
        else:

            # ensure that the right filters are used
            metadata['filters'] = sig_filters_by_animal[animal]

            blk = neurotic.load_dataset(metadata, lazy=lazy)

            if lazy:
                # manually perform filters
                blk = apply_filters(blk, metadata)

        last_data_set_name = data_set_name



        ###
        ### LOCATE BEHAVIOR EPOCHS AND SUBEPOCHS
        ###

        # construct a query for locating behaviors
        behavior_query = f'(Type in {epoch_types}) & ({time_window[0]} <= Start) & (End <= {time_window[1]})'

        # construct queries for locating epochs associated with each behavior
        # - each query should match at most one epoch
        # - dictionary keys are used as prefixes for the names of new columns
        subepoch_queries = {}

        subepoch_queries['B38 activity']            = (f'(Type == "B38 activity") & ' \
                                                       f'(@behavior_start-3 <= End) & (End <= @behavior_start+4)',
                                                       'last') # use last if there are multiple matches
                                                      # must end within a few seconds of behavior start (3 before or 4 after)

        subepoch_queries['I2 protraction activity'] = f'(Type == "I2 protraction activity") & ' \
                                                      f'(@behavior_start-1 <= Start) & (End <= @behavior_end)'
                                                      # must start no earlier than 1 second before behavior and end within it

        subepoch_queries['B8 activity']             = f'(Type == "B8 activity") & ' \
                                                      f'(@behavior_start <= Start) & (End <= @behavior_end)'
                                                      # must be fully contained within behavior

        subepoch_queries['B3/6/9/10 activity']      = f'(Type == "B3/6/9/10 activity") & ' \
                                                      f'(@behavior_start <= Start) & (End <= @behavior_end)'
                                                      # must be fully contained within behavior

        subepoch_queries['B4/B5 activity']          = (f'(Type == "B4/B5 activity") & ' \
                                                       f'(@behavior_start <= Start) & (Start <= @behavior_end)',
                                                       'first') # use first if there are multiple matches
                                                      # must start within behavior
        
        subepoch_queries['Force shoulder end']      = f'(Type == "Force shoulder end") & ' \
                                                      f'(@behavior_start-3 <= End) & (End <= @behavior_start+3)'
                                                      # must end within 3 seconds of behavior start
        
        subepoch_queries['Force rise start']        = f'(Type == "Force rise start") & ' \
                                                      f'(@behavior_start <= Start) & (Start <= @behavior_end)'
                                                      # must start within behavior

        subepoch_queries['Force plateau start']     = f'(Type == "Force plateau start") & ' \
                                                      f'(@behavior_start <= Start) & (Start <= @behavior_end)'
                                                      # must start within behavior

        subepoch_queries['Force plateau end']       = f'(Type == "Force plateau end") & ' \
                                                      f'(@behavior_end-2 <= Start) & (Start <= @behavior_end+2)'
                                                      # must start within 2 seconds of behavior end

        subepoch_queries['Force drop end']          = f'(Type == "Force drop end") & ' \
                                                      f'(@behavior_end-2 <= Start) & (Start <= @behavior_end+2)'
                                                      # must start within 2 seconds of behavior end
        
        subepoch_queries['Inward movement']         = f'(Type == "Inward movement") & ' \
                                                      f'(@behavior_start <= Start) & (Start <= @behavior_end)'
                                                      # must start within behavior

        # construct a table in which each row is a behavior and subepoch data
        # is added as columns, e.g. df['B38 activity start (s)']
        df = BehaviorsDataFrame(blk.segments[0].epochs, behavior_query, subepoch_queries)

        # renumber behaviors assuming all behaviors are from a single contiguous sequence
        df = df.sort_values('Start (s)').reset_index(drop=True).rename_axis('Behavior_index')



        ###
        ### START CALCULATIONS
        ###

        # some columns must have type 'object', which
        # can be accomplished by initializing with None or np.nan
        df['Video segmentation times (s)'] = None
        df['Force segmentation times (s)'] = None
        df['Force, video segmented interpolation (mN)'] = None
        df['Force, force segmented interpolation (mN)'] = None
        for unit in units:
            df[f'{unit} spike train'] = None
            df[f'{unit} firing rate (Hz)'] = None
            df[f'{unit} firing rate, video segmented interpolation (Hz)'] = None
            df[f'{unit} firing rate, force segmented interpolation (Hz)'] = None
            df[f'{unit} all bursts'] = None

            # while we're at it, initialize some other things that might otherwise never be given values
#             df[f'{unit} first burst start (s)'] = np.nan
#             df[f'{unit} first burst end (s)'] = np.nan
#             df[f'{unit} first burst duration (s)'] = 0
#             df[f'{unit} first burst spike count'] = 0
#             df[f'{unit} first burst mean frequency (Hz)'] = np.nan
#             df[f'{unit} last burst start (s)'] = np.nan
#             df[f'{unit} last burst end (s)'] = np.nan
#             df[f'{unit} last burst duration (s)'] = 0
#             df[f'{unit} last burst spike count'] = 0
#             df[f'{unit} last burst mean frequency (Hz)'] = np.nan
            df[f'{unit} burst start (s)'] = np.nan
            df[f'{unit} burst end (s)'] = np.nan
            df[f'{unit} burst duration (s)'] = 0
            df[f'{unit} burst spike count'] = 0
            df[f'{unit} burst mean frequency (Hz)'] = 0
            df[f'{unit} burst start (video seg normalized)'] = np.nan
            df[f'{unit} burst end (video seg normalized)'] = np.nan
            df[f'{unit} burst start (force seg normalized)'] = np.nan
            df[f'{unit} burst end (force seg normalized)'] = np.nan
        df['Inward movement start (video seg normalized)'] = np.nan
        df['Inward movement end (video seg normalized)'] = np.nan
        df['Inward movement start (force seg normalized)'] = np.nan
        df['Inward movement end (force seg normalized)'] = np.nan



        # get smoothed force for entire time window
        channel = 'Force'
        sig = get_sig(blk, channel)
        if lazy:
            sig = sig.time_slice(None, None)
        sig = elephant.signal_processing.butter(sig, lowpass_freq = 5*pq.Hz)
        sig = sig.time_slice(time_window[0]*pq.s, time_window[1]*pq.s)
        sig = sig.rescale('mN')
        force_smoothed_sig = sig



        df['End to next start (s)'] = np.nan
        df['Start to next start (s)'] = np.nan
        previous_i = None

        # iterate over all swallows
        for j, i in enumerate(df.index):

            behavior_start = df.loc[i, 'Start (s)']*pq.s
            behavior_end = df.loc[i, 'End (s)']*pq.s
            inward_movement_start = df.loc[i, 'Inward movement start (s)']*pq.s
            inward_movement_end = df.loc[i, 'Inward movement end (s)']*pq.s

            # calculate interbehavior intervals assuming all behaviors are from a single contiguous sequence
            if previous_i is not None:
                df.loc[previous_i, 'End to next start (s)']   = df.loc[i, 'Start (s)'] - df.loc[previous_i, 'End (s)']
                df.loc[previous_i, 'Start to next start (s)'] = df.loc[i, 'Start (s)'] - df.loc[previous_i, 'Start (s)']
                df.loc[previous_i, 'Inward movement start to next inward movement start (s)'] = \
                    df.loc[i, 'Inward movement start (s)'] - df.loc[previous_i, 'Inward movement start (s)']
                df.loc[previous_i, 'Inward movement end to next inward movement start (s)'] = \
                    df.loc[i, 'Inward movement start (s)'] - df.loc[previous_i, 'Inward movement end (s)']
            previous_i = i



            ###
            ### VIDEO SEGMENTATION
            ###
            
            # inward movement start and end are required
            video_is_segmented = np.all(np.isfinite(np.array([
                inward_movement_start, inward_movement_end])))
            
            if video_is_segmented:
                inward_movement_duration = df.loc[i, 'Inward movement duration (s)']*pq.s
                
                # get the list of fixed times for normalization
                video_segmentation_times = df.at[i, 'Video segmentation times (s)'] = np.array([
                    inward_movement_start-inward_movement_duration*4,
                    inward_movement_start-inward_movement_duration*3,
                    inward_movement_start-inward_movement_duration*2,
                    inward_movement_start-inward_movement_duration*1,
                    inward_movement_start,
                    inward_movement_end,
                    inward_movement_end+inward_movement_duration*1,
                    inward_movement_end+inward_movement_duration*2,
                    inward_movement_end+inward_movement_duration*3,
                    inward_movement_end+inward_movement_duration*4,
                ])*pq.s # 'at', not 'loc', is important for inserting list into cell

            else: # video is not segmented
                video_segmentation_times = df.at[i, 'Video segmentation times (s)'] = np.array([
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                ])*pq.s # 'at', not 'loc', is important for inserting list into cell
            
            
            
            ###
            ### FORCE SEGMENTATION
            ###

            force_shoulder_end  = df.loc[i, 'Force shoulder end start (s)']*pq.s  # start of "Force shoulder end" epoch
            force_rise_start    = df.loc[i, 'Force rise start start (s)']*pq.s    # start of "Force rise start" epoch
            force_plateau_start = df.loc[i, 'Force plateau start start (s)']*pq.s # start of "Force plateau start" epoch
            force_plateau_end   = df.loc[i, 'Force plateau end start (s)']*pq.s   # start of "Force plateau end" epoch
            force_drop_end      = df.loc[i, 'Force drop end start (s)']*pq.s      # start of "Force drop end" epoch

            # force rise start, plateau start and end, and drop end are required
            force_is_segmented = np.all(np.isfinite(np.array([
                force_rise_start, force_plateau_start, force_plateau_end, force_drop_end])))

            if force_is_segmented:
                # get some times for the previous and next swallow
                epochs_force_shoulder_end  = next((ep for ep in blk.segments[0].epochs if ep.name == 'Force shoulder end'), None)
                epochs_force_rise_start    = next((ep for ep in blk.segments[0].epochs if ep.name == 'Force rise start'), None)
                epochs_force_plateau_start = next((ep for ep in blk.segments[0].epochs if ep.name == 'Force plateau start'), None)
                epochs_force_plateau_end   = next((ep for ep in blk.segments[0].epochs if ep.name == 'Force plateau end'), None)
                epochs_force_drop_end      = next((ep for ep in blk.segments[0].epochs if ep.name == 'Force drop end'), None)
                assert epochs_force_shoulder_end  is not None, 'failed to find "Force shoulder end" epochs'
                assert epochs_force_rise_start    is not None, 'failed to find "Force rise start" epochs'
                assert epochs_force_plateau_start is not None, 'failed to find "Force plateau start" epochs'
                assert epochs_force_plateau_end   is not None, 'failed to find "Force plateau end" epochs'
                assert epochs_force_drop_end      is not None, 'failed to find "Force drop end" epochs'

                try:
                    prev_force_plateau_start = df.loc[i, 'Previous force plateau start (s)'] = epochs_force_plateau_start.time_slice(None, force_rise_start)[-1]
                    assert force_rise_start-prev_force_plateau_start < 16*pq.s, f'for swallow {i}, previous force plateau start is too far away'
                except IndexError:
                    prev_force_plateau_start = df.loc[i, 'Previous force plateau start (s)'] = np.nan

                try:
                    prev_force_plateau_end = df.loc[i, 'Previous force plateau end (s)'] = epochs_force_plateau_end.time_slice(None, force_rise_start)[-1]
                    assert force_rise_start-prev_force_plateau_end < 12*pq.s, f'for swallow {i}, previous force plateau end is too far away'
                except IndexError:
                    prev_force_plateau_end = df.loc[i, 'Previous force plateau end (s)'] = np.nan

                try:
                    prev_force_drop_end = df.loc[i, 'Previous force drop end (s)'] = epochs_force_drop_end.time_slice(None, force_rise_start)[-1]
                    assert force_rise_start-prev_force_drop_end < 12*pq.s, f'for swallow {i}, previous force drop end is too far away'
                except IndexError:
                    prev_force_drop_end = df.loc[i, 'Previous force drop end (s)'] = np.nan

                try:
                    next_force_rise_start = df.loc[i, 'Next force rise start (s)'] = epochs_force_rise_start.time_slice(force_drop_end, None)[0]
                    assert next_force_rise_start-force_drop_end < 12*pq.s, f'for swallow {i}, next force rise start is too far away'
                except IndexError:
                    next_force_rise_start = df.loc[i, 'Next force rise start (s)'] = np.nan

                try:
                    next_force_shoulder_end = df.loc[i, 'Next force shoulder end (s)'] = epochs_force_shoulder_end.time_slice(force_drop_end, None)[0]
                    if next_force_shoulder_end > next_force_rise_start:
                        # next swallow did not have a shoulder and we instead grabbed a later shoulder
                        next_force_shoulder_end = df.loc[i, 'Next force shoulder end (s)'] = np.nan
                except IndexError:
                    next_force_shoulder_end = df.loc[i, 'Next force shoulder end (s)'] = np.nan

                # get the list of fixed times for normalization
                force_segmentation_times = df.at[i, 'Force segmentation times (s)'] = np.array([
                    prev_force_plateau_start,
                    prev_force_plateau_end,
                    prev_force_drop_end,
                    force_shoulder_end,
                    force_rise_start,
                    force_plateau_start,
                    force_plateau_end,
                    force_drop_end,
                    next_force_shoulder_end,
                    next_force_rise_start,
                ])*pq.s # 'at', not 'loc', is important for inserting list into cell

            else: # force is not segmented
                force_segmentation_times = df.at[i, 'Force segmentation times (s)'] = np.array([
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                ])*pq.s # 'at', not 'loc', is important for inserting list into cell



            ###
            ### BEHAVIORAL MARKERS
            ###
            
            if video_is_segmented:
                df.loc[i, 'Inward movement start (video seg normalized)'] = \
                    normalize_time(video_segmentation_times.magnitude, float(inward_movement_start))
                df.loc[i, 'Inward movement end (video seg normalized)'] = \
                    normalize_time(video_segmentation_times.magnitude, float(inward_movement_end))
            
            if force_is_segmented:
                df.loc[i, 'Inward movement start (force seg normalized)'] = \
                    normalize_time(force_segmentation_times.magnitude, float(inward_movement_start))
                df.loc[i, 'Inward movement end (force seg normalized)'] = \
                    normalize_time(force_segmentation_times.magnitude, float(inward_movement_end))



            ###
            ### FORCE QUANTIFICATION
            ###

            if force_is_segmented:

                # get smoothed force for whole behavior for remaining force calculations
                sig = force_smoothed_sig
                if np.isfinite(force_shoulder_end):
                    sig = sig.time_slice(force_shoulder_end - 0.01*pq.s, force_drop_end + 0.01*pq.s)
                else:
                    sig = sig.time_slice(force_rise_start - 1*pq.s, force_drop_end + 0.01*pq.s)
                sig = sig.rescale('mN')

                # find force peak, baseline, and the increase
                force_min_time = df.loc[i, 'Force minimum time (s)'] = elephant.spike_train_generation.peak_detection(sig, 999*pq.mN, sign='below')[0]
                force_min = df.loc[i, 'Force minimum (mN)'] = sig[sig.time_index(force_min_time)][0]
                force_peak_time = df.loc[i, 'Force peak time (s)'] = elephant.spike_train_generation.peak_detection(sig, 0*pq.mN)[0]
                force_peak = df.loc[i, 'Force peak (mN)'] = sig[sig.time_index(force_peak_time)][0]
                force_baseline = df.loc[i, 'Force baseline (mN)'] = sig[sig.time_index(force_rise_start)][0]
                force_increase = df.loc[i, 'Force increase (mN)'] = force_peak-force_baseline

                # find force plateau, drop, and shoulder values
                force_plateau_start_value = df.loc[i, 'Force plateau start value (mN)'] = sig[sig.time_index(force_plateau_start)][0]
                force_plateau_end_value = df.loc[i, 'Force plateau end value (mN)'] = sig[sig.time_index(force_plateau_end)][0]
                force_drop_end_value = df.loc[i, 'Force drop end value (mN)'] = sig[sig.time_index(force_drop_end)][0]
                if np.isfinite(force_shoulder_end):
                    force_shoulder_end_value = df.loc[i, 'Force shoulder end value (mN)'] = sig[sig.time_index(force_shoulder_end)][0]
                else:
                    force_shoulder_end_value = np.nan

                # find force rise and plateau durations
                force_rise_duration = df.loc[i, 'Force rise duration (s)'] = force_plateau_start - force_rise_start
                force_plateau_duration = df.loc[i, 'Force plateau duration (s)'] = force_plateau_end - force_plateau_start
                force_rise_plateau_duration = df.loc[i, 'Force rise and plateau duration (s)'] = force_plateau_end - force_rise_start

                # find average slope during rising phase
                force_rise_increase = df.loc[i, 'Force rise increase (mN)'] = force_plateau_start_value - force_baseline
                force_slope = df.loc[i, 'Force slope (mN/s)'] = (force_rise_increase/force_rise_duration).rescale('mN/s')



            ###
            ### FORCE NORMALIZATION
            ###

            if video_is_segmented:
                t_start = video_segmentation_times[0]-0.001*pq.s
                t_stop = video_segmentation_times[-1]+0.001*pq.s
                
                channel = 'Force'
                sig = get_sig(blk, channel)
                sig = sig.time_slice(t_start, t_stop)
                sig = sig.rescale(channel_units[channel_names.index(channel)])

                force_video_seg_interp = df.at[i, 'Force, video segmented interpolation (mN)'] = \
                    resample_sig_in_normalized_time(video_segmentation_times, sig, video_seg_interp_times) # 'at', not 'loc', is important for inserting list into cell
            
            if force_is_segmented:
                t_start = force_segmentation_times[0]-0.001*pq.s
                t_stop = force_segmentation_times[-1]+0.001*pq.s
                
                channel = 'Force'
                sig = get_sig(blk, channel)
                sig = sig.time_slice(t_start, t_stop)
                sig = sig.rescale(channel_units[channel_names.index(channel)])

                force_force_seg_interp = df.at[i, 'Force, force segmented interpolation (mN)'] = \
                    resample_sig_in_normalized_time(force_segmentation_times, sig, force_seg_interp_times) # 'at', not 'loc', is important for inserting list into cell



            ###
            ### FIND SPIKE TRAINS
            ###

            if lazy:
                if metadata['amplitude_discriminators'] is not None:
                    for discriminator in metadata['amplitude_discriminators']:
                        sig = get_sig(blk, discriminator['channel'])
                        if sig is not None:
                            sig = sig.time_slice(behavior_start - 5*pq.s, behavior_end + 5*pq.s)
                            st = _detect_spikes(sig, discriminator, blk.segments[0].epochs)
                            st_epoch_start = df.loc[i, discriminator['epoch']+' start (s)']*pq.s
                            st_epoch_end = df.loc[i, discriminator['epoch']+' end (s)']*pq.s
                            st = st.time_slice(st_epoch_start, st_epoch_end)
                            df.at[i, f'{st.name} spike train'] = st # 'at', not 'loc', is important for inserting list into cell
            else:
                for spiketrain in blk.segments[0].spiketrains:
                    discriminator = next((d for d in metadata['amplitude_discriminators'] if d['name'] == spiketrain.name), None)
                    if discriminator is None:
                        raise Exception(f'For data set "{data_set_name}", discriminator "{spiketrain.name}" could not be found')
                    st_epoch_start = df.loc[i, discriminator['epoch']+' start (s)']*pq.s
                    st_epoch_end = df.loc[i, discriminator['epoch']+' end (s)']*pq.s
                    if np.isfinite(st_epoch_start) and np.isfinite(st_epoch_end):
                        st = spiketrain.time_slice(st_epoch_start, st_epoch_end)
                    else:
                        # this unit's discriminator epoch was not located for this swallow
                        st = None
                    df.at[i, f'{spiketrain.name} spike train'] = st # 'at', not 'loc', is important for inserting list into cell



            ###
            ### QUANTIFY SPIKE TRAINS AND BURSTS
            ###

            for k, unit in enumerate(units):
                st = df.loc[i, f'{unit} spike train']
                if st is not None:

                    # get the neural channel
                    channel = st.annotations['channels'][0]
                    sig = get_sig(blk, channel)

                    # create a continuous smoothed firing rate representation
                    # by convolving the spike train with a kernel
                    # - choice of t_start and t_stop here ensures firing rates are
                    #   recorded as zero far from the burst and can be resampled later
                    if video_is_segmented and force_is_segmented:
                        t_start = min(video_segmentation_times[0], force_segmentation_times[0])-0.001*pq.s
                        t_stop = max(video_segmentation_times[-1], force_segmentation_times[-1])+0.001*pq.s
                    elif video_is_segmented:
                        t_start = video_segmentation_times[0]-0.001*pq.s
                        t_stop = video_segmentation_times[-1]+0.001*pq.s
                    elif force_is_segmented:
                        t_start = force_segmentation_times[0]-0.001*pq.s
                        t_stop = force_segmentation_times[-1]+0.001*pq.s
                    else:
                        # no segmentation available
                        t_start = behavior_start-5*pq.s
                        t_stop = behavior_end+5*pq.s
                    smoothing_kernel = elephant.kernels.GaussianKernel(0.2*pq.s) # 200 ms standard deviation
                    firing_rate = df.at[i, f'{unit} firing rate (Hz)'] = elephant.statistics.instantaneous_rate(
                        spiketrain=st,
                        t_start=t_start,
                        t_stop=t_stop,
                        sampling_period=sig.sampling_period,
                        kernel=smoothing_kernel,
                    ) # 'at', not 'loc', is important for inserting list into cell

                    # normalization
                    if video_is_segmented:
                        firing_rate_video_seg_interp = df.at[i, f'{unit} firing rate, video segmented interpolation (Hz)'] = \
                            resample_sig_in_normalized_time(video_segmentation_times, firing_rate, video_seg_interp_times) # 'at', not 'loc', is important for inserting list into cell
                    if force_is_segmented:
                        firing_rate_force_seg_interp = df.at[i, f'{unit} firing rate, force segmented interpolation (Hz)'] = \
                            resample_sig_in_normalized_time(force_segmentation_times, firing_rate, force_seg_interp_times) # 'at', not 'loc', is important for inserting list into cell

                    if st.size > 0:

#                         # get the signal for the behavior with 10 seconds cushion before and after (for better baseline estimation)
#                         sig = sig.time_slice(behavior_start - 10*pq.s, behavior_end + 10*pq.s)
#                         sig = sig.rescale(channel_units[channel_names.index(channel)])

                        # find every sequence of spikes that qualifies as a burst
                        bursts = df.at[i, f'{unit} all bursts'] = _find_bursts(st, burst_thresholds[unit][0], burst_thresholds[unit][1]) # 'at', not 'loc', is important for inserting list into cell

                        first_burst_start = np.nan
#                         first_burst_end = np.nan
#                         first_burst_spike_count = 0
#                         first_burst_mean_freq = 0*pq.Hz
#                         last_burst_start = np.nan
                        last_burst_end = np.nan
#                         last_burst_spike_count = 0
#                         last_burst_mean_freq = 0*pq.Hz
                        if len(bursts) > 0:

                            for burst in zip(bursts.times, bursts.durations, bursts.array_annotations['spikes']):
                                if is_good_burst(burst):
                                    time, duration, n_spikes = burst
                                    first_burst_start = time
#                                     first_burst_start, first_burst_end = burst['Start (s)'], burst['End (s)']
#                                     first_burst_duration = first_burst_end-first_burst_start
#                                     df.loc[i, f'{unit} first burst start (s)'] = first_burst_start.rescale('s')
#                                     df.loc[i, f'{unit} first burst end (s)'] = first_burst_end.rescale('s')
#                                     first_burst_duration = df.loc[i, f'{unit} first burst duration (s)'] = first_burst_duration.rescale('s')
#                                     first_burst_spike_count = df.loc[i, f'{unit} first burst spike count'] = st.time_slice(first_burst_start, first_burst_end).size
#                                     first_burst_mean_freq = df.loc[i, f'{unit} first burst mean frequency (Hz)'] = ((first_burst_spike_count-1)/first_burst_duration).rescale('Hz')

#                                     # find burst RAUC and mean voltage
#                                     first_burst_rauc = df.loc[i, f'{unit} first burst RAUC (μV·s)'] = elephant.signal_processing.rauc(sig, baseline='mean', t_start=first_burst_start, t_stop=first_burst_end).rescale('uV*s')
#                                     first_burst_mean_rect_voltage = df.loc[i, f'{unit} first burst mean rectified voltage (μV)'] = first_burst_rauc/first_burst_duration

                                    break # quit after finding first good burst

                            for burst in zip(reversed(bursts.times), reversed(bursts.durations), reversed(bursts.array_annotations['spikes'])):
                                if is_good_burst(burst):
                                    time, duration, n_spikes = burst
                                    last_burst_end = time + duration
#                                     last_burst_start, last_burst_end = burst['Start (s)'], burst['End (s)']
#                                     last_burst_duration = last_burst_end-last_burst_start
#                                     df.loc[i, f'{unit} last burst start (s)'] = last_burst_start.rescale('s')
#                                     df.loc[i, f'{unit} last burst end (s)'] = last_burst_end.rescale('s')
#                                     last_burst_duration = df.loc[i, f'{unit} last burst duration (s)'] = last_burst_duration.rescale('s')
#                                     last_burst_spike_count = df.loc[i, f'{unit} last burst spike count'] = st.time_slice(last_burst_start, last_burst_end).size
#                                     last_burst_mean_freq = df.loc[i, f'{unit} last burst mean frequency (Hz)'] = ((last_burst_spike_count-1)/last_burst_duration).rescale('Hz')

#                                     # find burst RAUC and mean voltage
#                                     last_burst_rauc = df.loc[i, f'{unit} last burst RAUC (μV·s)'] = elephant.signal_processing.rauc(sig, baseline='mean', t_start=last_burst_start, t_stop=last_burst_end).rescale('uV*s')
#                                     last_burst_mean_rect_voltage = df.loc[i, f'{unit} last burst mean rectified voltage (μV)'] = last_burst_rauc/last_burst_duration
                                    
                                    break # quit after finding first (actually, last) good burst
                        
                        # merge the first and last good bursts and anything in between
                        if np.isfinite(first_burst_start) and np.isfinite(last_burst_end):
                            st_burst = st.time_slice(first_burst_start, last_burst_end)
                            burst_start = df.loc[i, f'{unit} burst start (s)'] = first_burst_start
                            burst_end = df.loc[i, f'{unit} burst end (s)'] = last_burst_end
                            burst_duration = df.loc[i, f'{unit} burst duration (s)'] = (burst_end-burst_start)
                            burst_spike_count = df.loc[i, f'{unit} burst spike count'] = st_burst.size
                            if burst_spike_count > 1:
                                burst_mean_freq = df.loc[i, f'{unit} burst mean frequency (Hz)'] = ((burst_spike_count-1)/burst_duration).rescale('Hz')
                            if video_is_segmented:
                                df.loc[i, f'{unit} burst start (video seg normalized)'] = normalize_time(video_segmentation_times.magnitude, float(first_burst_start))
                                df.loc[i, f'{unit} burst end (video seg normalized)']   = normalize_time(video_segmentation_times.magnitude, float(last_burst_end))
                            if force_is_segmented:
                                df.loc[i, f'{unit} burst start (force seg normalized)'] = normalize_time(force_segmentation_times.magnitude, float(first_burst_start))
                                df.loc[i, f'{unit} burst end (force seg normalized)']   = normalize_time(force_segmentation_times.magnitude, float(last_burst_end))

            # B3/B6/B9
            df['B3/B6/B9 burst start (s)'] = np.nan
            df['B3/B6/B9 burst end (s)'] = np.nan
            df['B3/B6/B9 burst duration (s)'] = 0
            df['B3/B6/B9 burst spike count'] = 0
            df['B3/B6/B9 burst mean frequency (Hz)'] = 0
            b3b6b9_burst_start = df['B3/B6/B9 burst start (s)'] = df[['B6/B9 burst start (s)', 'B3 burst start (s)']].min(axis=1)
            b3b6b9_burst_end = df['B3/B6/B9 burst end (s)']   = df[['B6/B9 burst end (s)',   'B3 burst end (s)']]  .max(axis=1)
            b3b6b9_burst_duration = df['B3/B6/B9 burst duration (s)'] = b3b6b9_burst_end - b3b6b9_burst_start
            b3b6b9_burst_spike_count = df['B3/B6/B9 burst spike count'] = df['B6/B9 burst spike count'] + df['B3 burst spike count']
            b3b6b9_burst_mean_freq = df['B3/B6/B9 burst mean frequency (Hz)'] = (b3b6b9_burst_spike_count-1)/b3b6b9_burst_duration


        ###
        ### FINISH
        ###

        # index the table on 4 variables so that this dataframe can later be merged with others
        df['Animal'] = animal
        df['Food'] = food
        df['Bout_index'] = bout_index
        df = df.reset_index().set_index(['Animal', 'Food', 'Bout_index', 'Behavior_index'])

        df_list += [df]
        
        pbar.update()
        
    pbar.close()

    df_all = pd.concat(df_list, sort=False).sort_index()

    if exemplary_swallow in feeding_bouts:
        # move exemplar to separate dataframe
        df_exemplary_swallow = df_all.loc[exemplary_swallow].copy()
        df_all = df_all.drop(exemplary_swallow)

    if exemplary_bout in feeding_bouts:
        # move exemplar to separate dataframe
        df_exemplary_bout = df_all.loc[exemplary_bout].copy()
        df_all = df_all.drop(exemplary_bout)

    # save dataframes to files so that calculations can be skipped in the future
    for var in pickled_vars:
        exec(f'{var}.to_pickle("{var}.zip")')

    end = datetime.datetime.now()
    print('elapsed time:', end-start)

## 🤪 Sanity Checks

In [None]:
# Master switch for the sanity-check cell that follows. It reloads each
# feeding bout's data set and plots its signals for visual inspection.
# Set to False to run it; presumably skipped by default because reloading
# every data set is time-consuming.
skip_sanity_checks: bool = True

In [None]:
if not skip_sanity_checks:
    start = datetime.datetime.now()

    # reconstruct the original df_all, before exemplary_swallow and
    # exemplary_bout were removed
    df_list2 = [df_all]
    if exemplary_swallow in feeding_bouts:
        df_exemplary_swallow2 = df_exemplary_swallow.copy()
        df_exemplary_swallow2['Animal'] = exemplary_swallow[0]
        df_exemplary_swallow2['Food'] = exemplary_swallow[1]
        df_exemplary_swallow2['Bout_index'] = exemplary_swallow[2]
        df_exemplary_swallow2 = df_exemplary_swallow2.reset_index().set_index(['Animal', 'Food', 'Bout_index', 'Behavior_index'])
        df_list2 += [df_exemplary_swallow2]
    if exemplary_bout in feeding_bouts:
        df_exemplary_bout2 = df_exemplary_bout.copy()
        df_exemplary_bout2['Animal'] = exemplary_bout[0]
        df_exemplary_bout2['Food'] = exemplary_bout[1]
        df_exemplary_bout2['Bout_index'] = exemplary_bout[2]
        df_exemplary_bout2 = df_exemplary_bout2.reset_index().set_index(['Animal', 'Food', 'Bout_index', 'Behavior_index'])
        df_list2 += [df_exemplary_bout2]
    df_all2 = pd.concat(df_list2, sort=False).sort_index()

    # use Neo RawIO lazy loading to load much faster and using less memory
    # - with lazy=True, filtering parameters specified in metadata are ignored
    #     - note: filters are replaced below and applied manually anyway
    # - with lazy=True, loading via time_slice requires neo>=0.8.0
    lazy = True

    # load the metadata containing file paths
    metadata = neurotic.MetadataSelector(file='../../data/metadata.yml')

    last_data_set_name = None
    pbar = tqdm(total=len(feeding_bouts), unit='figure')
    for (animal, food, bout_index), (data_set_name, time_window) in feeding_bouts.items():

        channel_names = channel_names_by_animal[animal]
        epoch_types = epoch_types_by_food[food]
        burst_thresholds = burst_thresholds_by_animal[animal]

        df = df_all2.loc[animal, food, bout_index]



        ###
        ### LOAD DATASET
        ###

        metadata.select(data_set_name)

        if data_set_name is last_data_set_name:
            # skip reloading the data if it's already in memory
            pass
        else:

            # ensure that the right filters are used
            metadata['filters'] = sig_filters_by_animal[animal]

            blk = neurotic.load_dataset(metadata, lazy=lazy)

            if lazy:
                # manually perform filters
                blk = apply_filters(blk, metadata)

        last_data_set_name = data_set_name



        ###
        ### START FIGURE
        ###

#             figsize = (9.5, 10) # dimensions for notebook
#             figsize = (11, 8.5) # dimensions for printing
        figsize = (16, 9) # dimensions for filling wide screens
        fig, axes = plt.subplots(len(channel_names), 1, sharex=True, figsize=figsize)



        ###
        ### PLOT SIGNALS
        ###

        # plot all channels for entire time window
        for i, channel in enumerate(channel_names):
            plt.sca(axes[i])
            sig = get_sig(blk, channel)
            sig = sig.time_slice(time_window[0]*pq.s, time_window[1]*pq.s)
            sig = sig.rescale(channel_units[i])
            plt.plot(sig.times, sig.magnitude, c='0.8', lw=1, zorder=-1)

            if i == 0:
                plt.title(f'({animal}, {food}, {bout_index}): {data_set_name}')

            plt.ylabel(sig.name + ' (' + sig.units.dimensionality.string + ')')
            axes[i].yaxis.set_label_coords(-0.06, 0.5)

            if i < len(channel_names)-1:
                # remove right, top, and bottom plot borders, and remove x-axis
                sns.despine(ax=plt.gca(), bottom=True)
                plt.gca().xaxis.set_visible(False)
            else:
                # remove right and top plot borders, and set x-label
                sns.despine(ax=plt.gca())
                plt.xlabel('Time (s)')

        # plot smoothed force for entire time window
        channel = 'Force'
        sig = get_sig(blk, channel)
        if lazy:
            sig = sig.time_slice(None, None)
        sig = elephant.signal_processing.butter(sig, lowpass_freq = 5*pq.Hz)
        sig = sig.time_slice(time_window[0]*pq.s, time_window[1]*pq.s)
        sig = sig.rescale('mN')
        plt.plot(sig.times, sig.magnitude, c='k', lw=1, zorder=0)
        force_smoothed_sig = sig



        # iterate over all swallows
        for j, i in enumerate(df.index):

            behavior_start = df.loc[i, 'Start (s)']*pq.s
            behavior_end   = df.loc[i, 'End (s)']*pq.s
            
            ###
            ### MOVEMENTS
            ###
            
            # plot inward food movement
            inward_movement_start = df.loc[i, 'Inward movement start (s)']*pq.s
            inward_movement_end   = df.loc[i, 'Inward movement end (s)']*pq.s
            if np.isfinite(inward_movement_start):
                channel = 'Force'
                ax = axes[channel_names.index(channel)]
                ax.axvspan(
                    inward_movement_start, inward_movement_end,
                    0.99, 1,
                    facecolor='k', edgecolor=None, lw=0)
            
            

            ###
            ### FORCE SEGMENTATION
            ###

            force_shoulder_end  = df.loc[i, 'Force shoulder end start (s)']*pq.s  # start of "Force shoulder end" epoch
            force_rise_start    = df.loc[i, 'Force rise start start (s)']*pq.s    # start of "Force rise start" epoch
            force_plateau_start = df.loc[i, 'Force plateau start start (s)']*pq.s # start of "Force plateau start" epoch
            force_plateau_end   = df.loc[i, 'Force plateau end start (s)']*pq.s   # start of "Force plateau end" epoch
            force_drop_end      = df.loc[i, 'Force drop end start (s)']*pq.s      # start of "Force drop end" epoch

            # force rise start, plateau start and end, and drop end are required
            force_is_segmented = np.all(np.isfinite(np.array([
                force_rise_start, force_plateau_start, force_plateau_end, force_drop_end])))

            force_segmentation_times = df.at[i, 'Force segmentation times (s)']

            if force_is_segmented:

                force_min_time = df.loc[i, 'Force minimum time (s)']
                force_min = df.loc[i, 'Force minimum (mN)']
                force_peak_time = df.loc[i, 'Force peak time (s)']
                force_peak = df.loc[i, 'Force peak (mN)']
                force_baseline = df.loc[i, 'Force baseline (mN)']

                force_plateau_start_value = df.loc[i, 'Force plateau start value (mN)']
                force_plateau_end_value = df.loc[i, 'Force plateau end value (mN)']
                force_drop_end_value = df.loc[i, 'Force drop end value (mN)']
                force_shoulder_end_value = df.loc[i, 'Force shoulder end value (mN)']

                # get smoothed force for whole behavior for remaining force calculations
                sig = force_smoothed_sig
                if np.isfinite(force_shoulder_end):
                    sig = sig.time_slice(force_shoulder_end - 0.01*pq.s, force_drop_end + 0.01*pq.s)
                else:
                    sig = sig.time_slice(force_rise_start - 1*pq.s, force_drop_end + 0.01*pq.s)
                sig = sig.rescale('mN')

                # plot force rise in color
                plt.sca(axes[channel_names.index('Force')])
                sig2 = sig.time_slice(force_rise_start, force_plateau_start)
                plt.plot(sig2.times, sig2.magnitude, c=force_colors['rise'], lw=2, zorder=1)

                # plot force plateau in color
                sig2 = sig.time_slice(force_plateau_start, force_plateau_end)
                plt.plot(sig2.times, sig2.magnitude, c=force_colors['plateau'], lw=2, zorder=1)

                # plot force shoulder in color
                if np.isfinite(force_shoulder_end):
                    sig2 = sig.time_slice(force_drop_end, force_shoulder_end)
                    plt.plot(sig2.times, sig2.magnitude, c=force_colors['shoulder'], lw=2, zorder=1)

                # plot force peak, baseline, and plateau values
                plt.plot([force_peak_time],     [force_peak],                marker=CARETDOWN,  markersize=5, color='k')
#                 plt.plot([force_min_time],      [force_min],                 marker=CARETUP,    markersize=5, color='k')
                plt.plot([force_rise_start],    [force_baseline],            marker=CARETUP,    markersize=5, color='k')
                plt.plot([force_plateau_start], [force_plateau_start_value], marker=CARETRIGHT, markersize=5, color='k')
                plt.plot([force_plateau_end],   [force_plateau_end_value],   marker=CARETLEFT,  markersize=5, color='k')

                # plot segmentation boundaries across all subplots
                for (t, y, c) in [
                        (force_shoulder_end,  force_shoulder_end_value,  force_colors['shoulder']),
                        (force_rise_start,    force_baseline,            force_colors['rise']),
                        (force_plateau_start, force_plateau_start_value, force_colors['plateau']),
                        (force_plateau_end,   force_plateau_end_value,   force_colors['plateau']),
                        (force_drop_end,      force_drop_end_value,      force_colors['drop'])]:
                    if np.isfinite(y):
                        axes[-1].add_artist(patches.ConnectionPatch(
                            xyA=(t, y), xyB=(t, 1),
                            coordsA='data', coordsB=axes[0].get_xaxis_transform(),
                            axesA=axes[-1], axesB=axes[0],
                            color=c, lw=1, ls=':', zorder=-2))



            ###
            ### SPIKES AND BURSTS
            ###

            for k, unit in enumerate(units):
                st = df.loc[i, f'{unit} spike train']
                if st is not None and st.size > 0:

                    # get the signal for the behavior with 10 seconds cushion for spikes outside the behavior duration
                    channel = st.annotations['channels'][0]
                    sig = get_sig(blk, channel)
                    sig = sig.time_slice(behavior_start - 10*pq.s, behavior_end + 10*pq.s)
                    sig = sig.rescale(channel_units[channel_names.index(channel)])

                    # plot spikes
                    plt.sca(axes[channel_names.index(channel)])
                    marker = ['.', 'x'][j%2] # alternate markers between behaviors
                    spike_amplitudes = np.array([sig[sig.time_index(t)] for t in st]) * pq.Quantity(sig.units)
                    plt.scatter(st.times.rescale('s'), spike_amplitudes, marker=marker, c=unit_colors[unit])

                    # plot burst windows
                    discriminator = next((d for d in metadata['amplitude_discriminators'] if d['name'] == st.name), None)
                    if discriminator is None:
                        raise Exception(f'For data set "{data_set_name}", discriminator "{st.name}" could not be found')
                    bottom = pq.Quantity(discriminator['amplitude'][0], discriminator['units']).rescale(sig.units)
                    top = pq.Quantity(discriminator['amplitude'][1], discriminator['units']).rescale(sig.units)
                    height = top-bottom
                    bursts = df.at[i, f'{unit} all bursts']
                    for burst in zip(bursts.times, bursts.durations, bursts.array_annotations['spikes']):
                        time, duration, n_spikes = burst
                        left = time
                        right = time + duration
                        width = right-left
                        linestyle = '-' if is_good_burst(burst) else '--'
                        rect = patches.Rectangle((left, bottom), width, height, ls=linestyle, edgecolor=unit_colors[unit], fill=False)
                        plt.gca().add_patch(rect)

                    # plot markers for edges of bursts
                    burst_start = df.loc[i, f'{unit} burst start (s)']
                    burst_end = df.loc[i, f'{unit} burst end (s)']
                    if top > 0:
                        plt.plot([burst_start], [top],    marker=CARETDOWN, markersize=5, color='k')
                        plt.plot([burst_end],   [top],    marker=CARETDOWN, markersize=5, color='k')
                    else:
                        plt.plot([burst_start], [bottom], marker=CARETUP,   markersize=5, color='k')
                        plt.plot([burst_end],   [bottom], marker=CARETUP,   markersize=5, color='k')



        ###
        ### FINISH FIGURE
        ###

        # optimize plot margins
        plt.subplots_adjust(
            left   = 0.1,
            right  = 0.99,
            top    = 0.96,
            bottom = 0.06,
            hspace = 0.15,
        )

        # export figure
        export_dir2 = os.path.join(export_dir, 'sanity-checks')
        if not os.path.exists(export_dir2):
            os.mkdir(export_dir2)
        plt.gcf().savefig(os.path.join(export_dir2, f'{animal} {food} {bout_index}.png'), dpi=300)
        
        pbar.update()
        
    pbar.close()

    if exemplary_swallow in feeding_bouts:
        # rename output file for exemplar
        animal, food, bout_index = exemplary_swallow
        old_path = os.path.join(export_dir2, f'{animal} {food} {bout_index}.png')
        new_path = os.path.join(export_dir2, 'exemplary_swallow.png')
        if os.path.exists(old_path):
            if os.path.exists(new_path):
                os.remove(new_path)
            os.rename(old_path, new_path)

    if exemplary_bout in feeding_bouts:
        # rename output file for exemplar
        animal, food, bout_index = exemplary_bout
        old_path = os.path.join(export_dir2, f'{animal} {food} {bout_index}.png')
        new_path = os.path.join(export_dir2, 'exemplary_bout.png')
        if os.path.exists(old_path):
            if os.path.exists(new_path):
                os.remove(new_path)
            os.rename(old_path, new_path)


    del df_all2, df_list2

    end = datetime.datetime.now()
    print('elapsed time:', end-start)

In [None]:
# Sanity-check figures for firing rates. For every feeding bout except the
# exemplary (lone) swallow, draw one row per unit plus a bottom force row, in
# three columns: real time, time normalized by force segmentation, and time
# normalized by video segmentation. Each swallow's trace is overlaid, together
# with the across-swallow median (solid black) and quartiles (dashed black).
# Each figure is saved as a PNG under <export_dir>/sanity-checks-firing-rates,
# and the exemplary bout's file is afterwards renamed to exemplary_bout.png.
# NOTE(review): relies on many notebook globals defined elsewhere (units,
# feeding_bouts, df_all, df_exemplary_bout, get_sig, channel_units,
# force_seg_interp_times, video_seg_interp_times, unit_colors, force_colors,
# export_dir, ...).
if not skip_sanity_checks:
    start = datetime.datetime.now()

    # the exemplary swallow is skipped in the loop below, so leave it out of
    # the progress bar total
    if exemplary_swallow in feeding_bouts:
        pbar = tqdm(total=len(feeding_bouts)-1, unit='figure')
    else:
        pbar = tqdm(total=len(feeding_bouts), unit='figure')
    for (animal, food, bout_index), (data_set_name, time_window) in feeding_bouts.items():

        channel_names = channel_names_by_animal[animal]

        # choose the per-swallow table (one row per swallow) for this bout
        if (animal, food, bout_index) == exemplary_swallow:
            # skip the lone swallow
            continue
        elif (animal, food, bout_index) == exemplary_bout:
            df = df_exemplary_bout
        else:
            df = df_all.loc[animal, food, bout_index]

        # load the data
        # NOTE(review): the metadata file is re-read and the data set reloaded
        # (lazily) for every bout, even when consecutive bouts share a data set
        metadata = neurotic.MetadataSelector('../../data/metadata.yml')
        metadata.select(data_set_name)
        blk = neurotic.load_dataset(metadata, lazy=True)

        # grid: one row per unit plus a bottom force row; columns are
        # real time / force-segmented time / video-segmented time
        # figsize = (9.5, 10) # dimensions for notebook
        # figsize = (11, 8.5) # dimensions for printing
        # figsize = (16, 9) # dimensions for filling wide screens
        figsize = (20, 9)
        fig, axes = plt.subplots(len(units)+1, 3, sharex='col', figsize=figsize)

        for k, unit in enumerate(units):
            # get the subplot axes handles
            ax_real_time, ax_force_seg, ax_video_seg = axes[k]

            # set y-axis label
            ax_real_time.set_ylabel(f'{unit} (Hz)')
            ax_real_time.yaxis.set_label_coords(-0.06, 0.5)

            # remove right, top, and bottom plot borders, and remove x-axis
            sns.despine(ax=ax_real_time, bottom=True)
            sns.despine(ax=ax_force_seg, bottom=True)
            sns.despine(ax=ax_video_seg, bottom=True)
            ax_real_time.xaxis.set_visible(False)
            ax_force_seg.xaxis.set_visible(False)
            ax_video_seg.xaxis.set_visible(False)

            # set time plot ranges (real time padded by 5 s on each side of the bout window;
            # the video-segmented column is fixed to the normalized range 2 to 6.5)
            ax_real_time.set_xlim([time_window[0]-5, time_window[1]+5])
            ax_force_seg.set_xlim([force_seg_interp_times.min(), force_seg_interp_times.max()])
#             ax_video_seg.set_xlim([video_seg_interp_times.min(), video_seg_interp_times.max()])
            ax_video_seg.set_xlim([2, 6.5])

            # elevate the Axes for units and remove background colors so that
            # each vertical ConnectionPatch drawn later is visible behind it
            ax_real_time.set_zorder(1)
            ax_force_seg.set_zorder(1)
            ax_video_seg.set_zorder(1)
            ax_real_time.set_facecolor('none')
            ax_force_seg.set_facecolor('none')
            ax_video_seg.set_facecolor('none')

        # remove right and top plot borders from bottom panels, and set x-label
        ax_real_time, ax_force_seg, ax_video_seg = axes[-1]
        sns.despine(ax=ax_real_time)
        sns.despine(ax=ax_force_seg)
        sns.despine(ax=ax_video_seg)
        ax_real_time.set_xlabel('Time (s)')
        ax_force_seg.set_xlabel('Time (normalized using force segmentation)')
        ax_video_seg.set_xlabel('Time (normalized using video segmentation)')

        # set time plot ranges
        ax_real_time.set_xlim([time_window[0]-5, time_window[1]+5])
        ax_force_seg.set_xlim([force_seg_interp_times.min(), force_seg_interp_times.max()])
#         ax_video_seg.set_xlim([video_seg_interp_times.min(), video_seg_interp_times.max()])
        ax_video_seg.set_xlim([2, 6.5])

        # plot force in real time (light grey raw trace over the bout window)
        ax_real_time, ax_force_seg, ax_video_seg = axes[-1]
        channel = 'Force'
        sig = get_sig(blk, channel)
        sig = sig.time_slice(time_window[0]*pq.s, time_window[1]*pq.s)
        sig = sig.rescale(channel_units[channel_names.index(channel)])
        ax_real_time.plot(sig.times, sig.magnitude, c='0.8', lw=1)
        ax_real_time.set_ylabel(sig.name + ' (' + sig.units.dimensionality.string + ')')
        ax_real_time.yaxis.set_label_coords(-0.06, 0.5)

        # accumulators for the summary statistics: for each unit and for Force,
        # a 2-D array with one row appended per swallow and one column per
        # interpolation time point (starts with zero rows)
        all_force_seg_normalized_times_series = {}
        all_video_seg_normalized_times_series = {}
        for unit in units:
            all_force_seg_normalized_times_series[unit] = np.zeros((0, force_seg_interp_times.size))
            all_video_seg_normalized_times_series[unit] = np.zeros((0, video_seg_interp_times.size))
        all_force_seg_normalized_times_series['Force'] = np.zeros((0, force_seg_interp_times.size))
        all_video_seg_normalized_times_series['Force'] = np.zeros((0, video_seg_interp_times.size))
        
        # iterate over swallows (rows of df)
        for j, i in enumerate(df.index):
            
            
            # plot inward food movement as a thin black band along the top
            # 2% (axes fraction) of the real-time force panel
            inward_movement_start = df.loc[i, 'Inward movement start (s)']
            inward_movement_end   = df.loc[i, 'Inward movement end (s)']
            if np.isfinite(inward_movement_start):
                ax_real_time, ax_force_seg, ax_video_seg = axes[-1]
                ax_real_time.axvspan(
                    inward_movement_start, inward_movement_end,
                    0.98, 1,
                    facecolor='k', edgecolor=None, lw=0)
            
            
            for k, unit in enumerate(units):
                ax_real_time, ax_force_seg, ax_video_seg = axes[k]


                # raster plot (ticks drawn at y = -1, below the rate traces)
                st = df.loc[i, f'{unit} spike train']
                if st is not None:
                    ax_real_time.eventplot(positions=st, lineoffsets=-1, colors=unit_colors[unit])


                # plot the firing rates in real time
                firing_rate = df.loc[i, f'{unit} firing rate (Hz)']
                if firing_rate is not None:
                    ax_real_time.plot(firing_rate.times.rescale('s'), firing_rate, c=unit_colors[unit])


                # plot firing rates in normalized time, and stack each swallow's
                # interpolated series for the median/quartile summary below
                firing_rate_force_seg_interp = df.loc[i, f'{unit} firing rate, force segmented interpolation (Hz)']
                if firing_rate_force_seg_interp is not None:
                    all_force_seg_normalized_times_series[unit] = np.concatenate([all_force_seg_normalized_times_series[unit], firing_rate_force_seg_interp[np.newaxis, :]])
                    ax_force_seg.plot(force_seg_interp_times, firing_rate_force_seg_interp, c=unit_colors[unit])
                firing_rate_video_seg_interp = df.loc[i, f'{unit} firing rate, video segmented interpolation (Hz)']
                if firing_rate_video_seg_interp is not None:
                    all_video_seg_normalized_times_series[unit] = np.concatenate([all_video_seg_normalized_times_series[unit], firing_rate_video_seg_interp[np.newaxis, :]])
                    ax_video_seg.plot(video_seg_interp_times, firing_rate_video_seg_interp, c=unit_colors[unit])


            # plot force in normalized time (and stack it for the summary)
            ax_real_time, ax_force_seg, ax_video_seg = axes[-1]
            force_force_seg_interp = df.at[i, 'Force, force segmented interpolation (mN)']
            if force_force_seg_interp is not None:
                all_force_seg_normalized_times_series['Force'] = np.concatenate([all_force_seg_normalized_times_series['Force'], force_force_seg_interp[np.newaxis, :]])
                ax_force_seg.plot(force_seg_interp_times, force_force_seg_interp, c='0.8', lw=1)
            force_video_seg_interp = df.at[i, 'Force, video segmented interpolation (mN)']
            if force_video_seg_interp is not None:
                all_video_seg_normalized_times_series['Force'] = np.concatenate([all_video_seg_normalized_times_series['Force'], force_video_seg_interp[np.newaxis, :]])
                ax_video_seg.plot(video_seg_interp_times, force_video_seg_interp, c='0.8', lw=1)


            # plot force phase boundaries in real time as dotted vertical lines
            # spanning the whole real-time column (bottom panel to top panel)
            force_segmentation_times = df.at[i, 'Force segmentation times (s)'].rescale('s').magnitude
            for m, t in enumerate(force_segmentation_times[3:8]): # 3 = end of shoulder, 8 = end of next shoulder
                if np.isfinite(t):
                    if m == 1: # 1 = start of rise
                        color = force_colors['rise']
                    else:
                        color = '0.75'
                    axes[-1][0].add_artist(patches.ConnectionPatch(
                        xyA=(t, 0), xyB=(t, 1),
                        coordsA=axes[-1][0].get_xaxis_transform(), coordsB=axes[0][0].get_xaxis_transform(),
                        axesA=axes[-1][0], axesB=axes[0][0],
                        color=color, lw=1, ls=':'))


        # plot force phase boundaries in normalized time (boundary m sits at
        # normalized time m)
        # NOTE(review): force_segmentation_times is left over from the last
        # swallow of the loop above (only its length is used); if df has no
        # rows it is undefined for the first bout processed -- confirm df is
        # never empty
        for m in range(len(force_segmentation_times)):
            if m == 4: # 4 = start of rise
                color = force_colors['rise']
            else:
                color = '0.75'
            axes[-1][1].add_artist(patches.ConnectionPatch(
                xyA=(m, 0), xyB=(m, 1),
                coordsA=axes[-1][1].get_xaxis_transform(), coordsB=axes[0][1].get_xaxis_transform(),
                axesA=axes[-1][1], axesB=axes[0][1],
                color=color, lw=1, ls=':'))
        
        # plot video phase boundaries in normalized time
        # NOTE(review): video_segmentation_times is only referenced by the
        # commented-out loop; the active loop uses the fixed boundaries 2..6
        video_segmentation_times = df.at[i, 'Video segmentation times (s)'].rescale('s').magnitude # grab last swallow's as an example
#         for m in range(len(video_segmentation_times)):
        for m in [2, 3, 4, 5, 6]:
            if m == 4: # 4 = start of inward movement
                color = force_colors['rise']
            else:
                color = '0.75'
            axes[-1][2].add_artist(patches.ConnectionPatch(
                xyA=(m, 0), xyB=(m, 1),
                coordsA=axes[-1][2].get_xaxis_transform(), coordsB=axes[0][2].get_xaxis_transform(),
                axesA=axes[-1][2], axesB=axes[0][2],
                color=color, lw=1, ls=':'))


        # plot firing rate distributions across swallows: NaN-ignoring median
        # (solid, drawn on top) and first/third quartiles (dashed)
        for k, unit in enumerate(units):
            ax_real_time, ax_force_seg, ax_video_seg = axes[k]

            firing_rate_median = np.nanmedian(all_force_seg_normalized_times_series[unit], axis=0)
            firing_rate_q1 = np.nanquantile(all_force_seg_normalized_times_series[unit], q=0.25, axis=0)
            firing_rate_q3 = np.nanquantile(all_force_seg_normalized_times_series[unit], q=0.75, axis=0)
            ax_force_seg.plot(force_seg_interp_times, firing_rate_median, c='k', lw=2, zorder=3)
            ax_force_seg.plot(force_seg_interp_times, firing_rate_q1, c='k', lw=2, ls='--')
            ax_force_seg.plot(force_seg_interp_times, firing_rate_q3, c='k', lw=2, ls='--')
            
            firing_rate_median = np.nanmedian(all_video_seg_normalized_times_series[unit], axis=0)
            firing_rate_q1 = np.nanquantile(all_video_seg_normalized_times_series[unit], q=0.25, axis=0)
            firing_rate_q3 = np.nanquantile(all_video_seg_normalized_times_series[unit], q=0.75, axis=0)
            ax_video_seg.plot(video_seg_interp_times, firing_rate_median, c='k', lw=2, zorder=3)
            ax_video_seg.plot(video_seg_interp_times, firing_rate_q1, c='k', lw=2, ls='--')
            ax_video_seg.plot(video_seg_interp_times, firing_rate_q3, c='k', lw=2, ls='--')


        # plot force distribution (same median/quartile summary as above)
        ax_real_time, ax_force_seg, ax_video_seg = axes[-1]

        force_median = np.nanmedian(all_force_seg_normalized_times_series['Force'], axis=0)
        force_q1 = np.nanquantile(all_force_seg_normalized_times_series['Force'], q=0.25, axis=0)
        force_q3 = np.nanquantile(all_force_seg_normalized_times_series['Force'], q=0.75, axis=0)
        ax_force_seg.plot(force_seg_interp_times, force_median, c='k', lw=2, zorder=3)
        ax_force_seg.plot(force_seg_interp_times, force_q1, c='k', lw=2, ls='--')
        ax_force_seg.plot(force_seg_interp_times, force_q3, c='k', lw=2, ls='--')
        
        force_median = np.nanmedian(all_video_seg_normalized_times_series['Force'], axis=0)
        force_q1 = np.nanquantile(all_video_seg_normalized_times_series['Force'], q=0.25, axis=0)
        force_q3 = np.nanquantile(all_video_seg_normalized_times_series['Force'], q=0.75, axis=0)
        ax_video_seg.plot(video_seg_interp_times, force_median, c='k', lw=2, zorder=3)
        ax_video_seg.plot(video_seg_interp_times, force_q1, c='k', lw=2, ls='--')
        ax_video_seg.plot(video_seg_interp_times, force_q3, c='k', lw=2, ls='--')


        plt.suptitle(f'({animal}, {food}, {bout_index}): {data_set_name}')
        # leave the top 3% of the figure for the suptitle
        plt.tight_layout(rect=(0, 0, 1, 0.97))

        # export figure
        # NOTE(review): figures are not closed after saving, so they stay open
        # for the rest of the loop -- confirm this is intended for display
        export_dir3 = os.path.join(export_dir, 'sanity-checks-firing-rates')
        if not os.path.exists(export_dir3):
            os.mkdir(export_dir3)
        plt.gcf().savefig(os.path.join(export_dir3, f'{animal} {food} {bout_index}.png'), dpi=300)
        
        pbar.update()
    
    pbar.close()

    # replace any previous exemplary_bout.png with this run's exemplar figure
    if exemplary_bout in feeding_bouts:
        # rename output file for exemplar
        animal, food, bout_index = exemplary_bout
        old_path = os.path.join(export_dir3, f'{animal} {food} {bout_index}.png')
        new_path = os.path.join(export_dir3, 'exemplary_bout.png')
        if os.path.exists(old_path):
            if os.path.exists(new_path):
                os.remove(new_path)
            os.rename(old_path, new_path)

    end = datetime.datetime.now()
    print('elapsed time:', end-start)

## Plotting Functions

In [None]:
def boxplot_with_points(x, y, hue, data, ax=None, show_points=False, describe=True):
    """Draw a seaborn box plot of column ``y`` grouped by ``x`` (and ``hue``).

    Rows whose ``y`` is missing are dropped before plotting. Whiskers use
    ``whis=999`` so they reach the extremes of the data.

    Parameters
    ----------
    x, y : str
        Column names for the grouping axis and the plotted values.
    hue : str or None
        Optional column for a second level of grouping. Without it, boxes are
        grey and points are black.
    data : pandas.DataFrame
        Table containing the columns.
    ax : matplotlib Axes, optional
        Axes to draw into; the current axes are used when None.
    show_points : bool
        Overlay each observation as a dodged swarm plot.
    describe : bool
        Print ``y``, then a table of N and median of ``y`` per group.
    """
    target = plt.gca() if ax is None else ax
    has_hue = hue is not None

    # grey boxes / black unoutlined points without hue; otherwise let seaborn
    # color by hue and outline the points
    box_color = None if has_hue else '0.75'
    point_color = None if has_hue else 'k'
    point_outline_width = 1 if has_hue else 0

    subset = data.dropna(subset=[y])

    # whiskers span extrema
    sns.boxplot(x=x, y=y, hue=hue, data=subset, ax=target, color=box_color, whis=999)

    if show_points:
        sns.swarmplot(
            x=x, y=y, hue=hue, data=subset, ax=target,
            color=point_color, linewidth=point_outline_width, edgecolor='0.25',
            size=4, dodge=True,
        )

        if has_hue:
            # the legend is assumed to hold each hue level twice (once from
            # the boxes, once from the points); keep only the first half
            handles, labels = target.get_legend_handles_labels()
            keep = len(labels) // 2
            target.legend(handles[:keep], labels[:keep], title=hue)

    target.set_xlabel(None)

    if not describe:
        return

    group_keys = [x, hue] if has_hue else [x]
    print(y)

    def _summarize(values):
        # per-group count (as a string) and median
        return {
            'N': f'{values.count()}',
            'Median': values.median(),
        }

    print(subset.groupby(group_keys)[y].apply(_summarize).unstack())
    print()

In [None]:
default_markers = ['.', '+', 'x', '1', '4', 'o', 'D']
default_colors = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6']


def _padded_limits(values, lim, padding):
    """Return ``[low, high]`` axis limits for *values*, filling ``None`` entries of *lim*.

    The automatic limits are the data extrema widened on each side by
    ``padding`` times the data range. *lim* (None or a 2-sequence) is
    copied, never mutated, so callers may reuse a list or pass a tuple.
    """
    span = np.ptp(values)
    auto = [min(values) - span * padding, max(values) + span * padding]
    if lim is None:
        return auto
    return [auto[k] if lim[k] is None else lim[k] for k in (0, 1)]


def _fit_trend(points):
    """Fit an OLS line ``y ~ const + x`` to a two-column DataFrame ``(x, y)``.

    Returns ``(model_stats, model_x, model_y)``:

    - ``model_stats`` is ``'R$^2$ = ..., p = ..., n = ...'``, where p is the
      p-value of the x coefficient.
    - ``model_x`` and ``model_y`` are 100-point arrays tracing the fitted line
      across the x range, extended by 10% of that range on each side.
    """
    x, y = points.iloc[:, 0], points.iloc[:, 1]
    model = sm.OLS(y, sm.add_constant(x)).fit()
    # params/pvalues are label-indexed pandas Series for pandas input; take
    # them by position explicitly, because positional Series[int] access is
    # deprecated in pandas and will become a label lookup
    params = np.asarray(model.params)
    intercept, slope = params[0], params[1]
    slope_p = np.asarray(model.pvalues)[1]
    model_stats = 'R$^2$ = {:.2f}, p = {:.5f}, n = {}'.format(model.rsquared, slope_p, len(points))
    span = np.ptp(x)
    model_x = np.linspace(x.min() - span * 0.1, x.max() + span * 0.1, 100)
    model_y = intercept + slope * model_x
    return model_stats, model_x, model_y


def scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim, trend=False, trend_separately=False, tooltips=False, padding=0.05, colors=default_colors, markers=default_markers):
    """Scatter two columns of the notebook-global ``df_all`` for labeled subsets.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw into.
    data_subsets : dict
        Maps legend label to a ``DataFrame.query`` string selecting rows of
        ``df_all``. Entries with a None query are not scattered. The j-th
        entry uses ``markers[j]`` and ``colors[j]``. All queries are passed to
        ``query_union`` to select the points used for limits, the overall
        trend, and tooltips.
    xlabel, ylabel : str
        Columns plotted on x and y; also used as the axis labels.
    xlim, ylim : None or 2-sequence
        Axis limits. None, or a None entry, is replaced by the data range
        widened by ``padding`` of that range on each side (only when there
        are complete points). The caller's sequence is never modified.
    trend : bool
        Fit and draw one black OLS line through all points, and print its
        statistics.
    trend_separately : bool
        Fit and draw one OLS line per subset, in that subset's color, and
        print each fit's statistics.
    tooltips : bool
        Show a hover tooltip with each point's index values.
    padding : float
        Fraction of the data range added on each side for automatic limits.
    colors, markers : sequence
        Per-subset colors and markers.
    """
    # draw each labeled subset with its own marker and color
    for j, (label, query) in enumerate(data_subsets.items()):
        if query is not None:
            df = df_all.query(query)
            ax.scatter(df[xlabel], df[ylabel],
                       label=label, marker=markers[j], c=colors[j], clip_on=False)

    # complete (x, y) pairs across all subsets
    all_points = df_all.query(query_union(data_subsets.values()))[[xlabel, ylabel]].dropna()

    if len(all_points) > 0:
        xlim = _padded_limits(all_points.iloc[:, 0], xlim, padding)
        ylim = _padded_limits(all_points.iloc[:, 1], ylim, padding)

    if trend_separately:
        for j, (label, query) in enumerate(data_subsets.items()):
            if query is not None:
                df = df_all.query(query)[[xlabel, ylabel]].dropna()
                model_stats, model_x, model_y = _fit_trend(df)
                print(f'{label}:', model_stats)
                ax.plot(model_x, model_y, color=colors[j])

    if trend:
        model_stats, model_x, model_y = _fit_trend(all_points)
        print('All points:', model_stats)
        ax.plot(model_x, model_y, color='k')

    if tooltips:
        # create a transparent layer containing all points for detecting mouse events
        sc = ax.scatter(all_points.iloc[:, 0], all_points.iloc[:, 1], label=None,
                        alpha=0, # transparent points
                        s=0.001, # small radius of tooltip activation
                       )

        # initialize a hidden empty tooltip
        annot = ax.annotate(
            '', xy=(0, 0),                              # initialize empty at origin
            xytext=(0, 15), textcoords='offset points', # position text above point
            color='k', size=10, ha='center',            # small black centered text
            bbox=dict(fc='w', lw=0, alpha=0.6),         # transparent white background
            arrowprops=dict(shrink=0, headlength=7, headwidth=7, width=0, lw=0, color='k'), # small black arrow
        )
        annot.set_visible(False)

        # prepare tooltip contents: the index values of each point joined by spaces
        tooltip_text = [' '.join([str(x) for x in index]) for index in list(all_points.index)]

        # bind the animation of tooltips to a hover event
        def hover(event):
            if event.inaxes == ax:
                cont, ind = sc.contains(event)
                if cont:
                    point_index = ind['ind'][0]
                    pos = sc.get_offsets()[point_index]
                    annot.xy = pos
                    annot.set_text(tooltip_text[point_index])
                    annot.set_visible(True)
                    ax.figure.canvas.draw_idle()
                else:
                    if annot.get_visible():
                        annot.set_visible(False)
                        ax.figure.canvas.draw_idle()
        ax.figure.canvas.mpl_connect('motion_notify_event', hover)

    ax.set_xlabel(xlabel)
    ax.set_xlim(xlim)
    ax.set_ylabel(ylabel)
    ax.set_ylim(ylim)

In [None]:
def prettyplot_with_scalebars(
    blk,
    t_start,
    t_stop,
    plots,
    
    outfile_basename=None, # base name of output files
    export_only=False,     # if True, will not render in notebook
    formats=('pdf', 'svg', 'png'), # extensions of output files (tuple avoids a shared mutable default)
    dpi=300,               # resolution (applicable only for PNG)
    
    figsize=(14, 7),       # figure size in inches
    linewidth=1,           # thickness of lines in points
    layout_settings=None,  # positioning of plot edges and the space between plots
    
    x_scalebar=1*pq.s,     # size of the time scale bar in seconds
    ylabel_padding=10,     # space between trace labels and plots
    scalebar_padding=1,    # space between scale bars and plots
    scalebar_sep=5,        # space between scale bars and scale labels
    barwidth=2,            # thickness of scale bars
):
    """Plot stacked signal traces from a neo Block, using scale bars instead of axes.

    One subplot is created per entry of ``plots``; all subplots share the
    x-axis. Frames and tick marks are hidden, and each trace instead gets an
    optional vertical scale bar on its right. An optional time scale bar is
    placed below the last subplot.

    Parameters
    ----------
    blk : neo.Block
        Data source. Signals are looked up by name in
        ``blk.segments[0].analogsignals``.
    t_start, t_stop : quantities.Quantity
        Time window to plot.
    plots : list of dict
        One dict per subplot with keys:

        - ``'channel'`` (required): name of the analog signal to plot.
        - ``'units'`` (required): units the signal is rescaled to.
        - ``'ylim'`` (required): y-axis limits in those units.
        - ``'scalebar'``: size of the vertical scale bar in ``'units'``;
          None or missing omits the bar.
        - ``'ylabel'``: trace label (default: the signal name; None hides it).
        - ``'color'``: line color (default ``'k'``).
        - ``'decimation_factor'``: downsampling factor (default 1).
    outfile_basename : str or None
        If given, the figure is saved as ``f'{outfile_basename}.{ext}'`` for
        every ``ext`` in ``formats``.
    export_only : bool
        If True, interactive mode is switched off while the figure is built
        and restored to its previous state afterwards, even on error.
    formats : iterable of str
        File extensions to save.
    dpi : int
        Resolution passed to ``savefig``.
    figsize, linewidth, layout_settings
        Figure size, trace line width, and optional keyword arguments for
        ``Figure.subplots_adjust``. If ``layout_settings`` is None,
        ``tight_layout`` with zero padding is used.
    x_scalebar : quantities.Quantity or None
        Length of the time scale bar; None omits it.
    ylabel_padding, scalebar_padding, scalebar_sep, barwidth
        Spacing and thickness settings for labels and scale bars.

    Returns
    -------
    fig : matplotlib.figure.Figure
    axes : numpy.ndarray of matplotlib.axes.Axes
        Flattened array with one Axes per entry of ``plots``.

    Raises
    ------
    AssertionError
        If a requested channel is not present in the Block.
    """

    # Remember whether interactive mode was on, so it can be restored exactly
    # afterwards, even if plotting or saving raises an exception.
    was_interactive = plt.isinteractive()
    if export_only:
        plt.ioff()

    try:
        fig, axes = plt.subplots(len(plots), 1, sharex=True, figsize=figsize, squeeze=False)
        axes = axes.flatten()

        time_units = None  # time units of the plotted signals, used by the time scale bar
        for ax, p in zip(axes, plots):

            # select and rescale a channel for the subplot
            sig = next((s for s in blk.segments[0].analogsignals if s.name == p['channel']), None)
            assert sig is not None, f"Signal with name {p['channel']} not found"
            sig = sig.time_slice(t_start, t_stop)
            sig = sig.rescale(p['units'])
            time_units = sig.times.units

            # downsample the data
            sig_downsampled = DownsampleNeoSignal(sig, p.get('decimation_factor', 1))

            # specify the x- and y-data for the subplot
            ax.plot(
                sig_downsampled.times,
                sig_downsampled.as_quantity(),
                linewidth=linewidth,
                color=p.get('color', 'k'),
            )

            # hide the box around the subplot
            ax.set_frame_on(False)

            # specify the y-axis label (None hides it)
            ylabel = p.get('ylabel', sig.name)
            if ylabel is not None:
                ax.set_ylabel(ylabel, rotation='horizontal', ha='right', va='center', labelpad=ylabel_padding)

            # specify the plot range
            ax.set_xlim([t_start, t_stop])
            ax.set_ylim(p['ylim'])

            # disable tick marks; scale bars convey the magnitudes instead
            ax.tick_params(
                bottom=False,
                left=False,
                labelbottom=False,
                labelleft=False)

            # add y-axis scale bar to the right of the trace
            scalebar = p.get('scalebar')
            if scalebar is not None:
                add_scalebar(ax,
                    sizey=scalebar,
                    labely=f'{scalebar} {sig.units.dimensionality.string}',

                    loc='center left',
                    bbox_to_anchor=(1, 0.5),

                    borderpad=scalebar_padding,
                    sep=scalebar_sep,
                    barwidth=barwidth,
                )

        # add time scale bar below final plot (x-axis is shared, so the last
        # signal's time units apply to every subplot)
        if x_scalebar is not None:
            add_scalebar(axes[-1],
                sizex=x_scalebar.rescale(time_units).magnitude,
                labelx=f'{x_scalebar.magnitude:g} {x_scalebar.units.dimensionality.string}',

                loc='upper right',
                bbox_to_anchor=(1, 0),

                borderpad=scalebar_padding,
                sep=scalebar_sep,
                barwidth=barwidth,
            )

        # adjust the white space around and between the subplots of this figure
        if layout_settings is None:
            fig.tight_layout(h_pad=0, w_pad=0, pad=0)
        else:
            fig.subplots_adjust(**layout_settings)

        if outfile_basename is not None:
            # specify file metadata (applicable only for PDF); str() guards
            # against blk.file_origin being None
            metadata = dict(
                Subject = (f'Data file: {blk.file_origin}\n'
                           f'Start time: {t_start}\n'
                           f'End time: {t_stop}'),
            )

            # write the figure to files
            for ext in formats:
                fig.savefig(f'{outfile_basename}.{ext}', metadata=metadata, dpi=dpi)

    finally:
        # restore interactive mode only if this call turned it off
        if export_only and was_interactive:
            plt.ion()

    return fig, axes

In [None]:
# def plot_unloaded_vs_loaded(df, y, ax, color='0.25', show_statistics=True, show_signif=True, alpha=0.05):
    
#     df = df.reset_index()
#     df.loc[df['Food'] == 'Regular nori', 'Food'] = 'Unloaded'
#     df.loc[df['Food'] == 'Tape nori',    'Food'] = 'Loaded'
    
#     # plot points
#     sns.stripplot(
#         x='Food', y=y, hue='Animal',
#         data=df.groupby(['Animal', 'Food'])[y].mean().reset_index(),
#         order=['Unloaded', 'Loaded'],
#         jitter=False,
#         palette=[color], # do not desaturate by animal
#         ax=ax,
#         clip_on=False,
#     )
#     ax.legend_.remove()

#     # plot lines connecting points
#     ax.plot([
#         df.query('Food == "Unloaded"').groupby('Animal')[y].mean(),
#         df.query('Food == "Loaded"').groupby('Animal')[y].mean()
#     ], color='0.75', clip_on=False)

#     ax.set_xlim([-0.25, 1.25])
#     ax.set_xlabel(None)
#     sns.despine(ax=ax)

#     if show_statistics:
#         print(df.groupby(['Animal', 'Food'])[y].apply(lambda x: {'Mean': x.mean(), 'Count': x.count()}).unstack([1, 2])[['Unloaded', 'Loaded']])
#         print()
#         signif = differences_test(
#             x=df.query('Food == "Loaded"').groupby('Animal')[y].mean(),
#             y=df.query('Food == "Unloaded"').groupby('Animal')[y].mean(),
#             x_label='loaded swallows',
#             y_label='unloaded swallows',
#             measure_label='mean [' + ' '.join(y.split()[:-1]) + ']',
#             units=y.split()[-1].strip('()'),
#             alpha=alpha,
#         )
        
#         if show_signif and signif == '*':
#             ax.annotate(
#                 '*',
#                 xy=(0.5, 1), xycoords='axes fraction',
#                 xytext=(0, -20), textcoords='offset points',
#                 ha='center', fontsize='xx-large', color='0.5',
#             )

In [None]:
def plot_unloaded_vs_loaded(df, y, ax, color='0.25', show_all=False, show_statistics=True, show_signif=True, bracket_width=1.0, alpha=0.05):
    """Compare a swallow measure between unloaded and loaded swallows, per animal.

    For every animal (numbered 1, 2, ... above the axes), the mean ± standard
    error of the mean (SEM) of column ``y`` is drawn at x = i - 0.2 for
    unloaded ('Regular nori') swallows and at x = i + 0.2 for loaded
    ('Tape nori') swallows, where i is the animal's position. A grey line
    connects the two means of each animal that has both. Each column is
    labeled 'U' or 'L' above the axes.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per swallow. After ``reset_index()`` it must contain the
        columns 'Animal', 'Food' and ``y``.
    y : str
        Column to plot. Expected to end with a parenthesized unit, for example
        'Duration (s)'; the last word is used as the unit in the statistics.
    ax : matplotlib.axes.Axes
        Axes to draw into.
    color : matplotlib color
        Color of the means and SEM bars.
    show_all : bool
        If True, also draw every individual swallow as an open circle.
    show_statistics : bool
        If True, print per-animal means/counts and call ``differences_test``
        (defined elsewhere) on the per-animal means.
    show_signif : bool
        If True and ``differences_test`` returns '*', draw a bracket with an
        asterisk below the axes.
    bracket_width : float
        ``widthB`` of the significance bracket.
    alpha : float
        Significance level passed to ``differences_test``.
    """

    df = df.reset_index()
    df.loc[df['Food'] == 'Regular nori', 'Food'] = 'Unloaded'
    df.loc[df['Food'] == 'Tape nori',    'Food'] = 'Loaded'

    # per-animal means for each condition (Series indexed by animal)
    unloaded_means = df.query('Food == "Unloaded"').groupby('Animal')[y].mean()
    loaded_means = df.query('Food == "Loaded"').groupby('Animal')[y].mean()

    # Give every animal ONE x position shared by both conditions. This keeps an
    # animal's unloaded and loaded columns paired even when it has swallows in
    # only one condition. Enumerating each condition separately would shift
    # every later animal in that condition.
    animals = list(df.groupby('Animal')[y].mean().index)
    position = {animal: i for i, animal in enumerate(animals)}

    for food, offset in {'Unloaded': -0.2, 'Loaded': 0.2}.items():
        for animal, values in df[df['Food'] == food].groupby('Animal')[y]:
            x = position[animal] + offset
            mean = values.mean()
            sem = values.sem()
            lower_sem, upper_sem = mean - sem, mean + sem

            # plot the mean
            ax.scatter(
                [x], mean,
                color=color,
                s=15,
                zorder=2, clip_on=False,
            )

            # plot the standard error of the mean: a vertical bar with two caps
            ax.plot(
                [x]*2, [lower_sem, upper_sem],
                color=color, lw=1,
                zorder=2, clip_on=False,
            )
            for cap in (lower_sem, upper_sem):
                ax.plot(
                    [x-0.1, x+0.1], [cap, cap],
                    color=color, lw=1,
                    zorder=2, clip_on=False,
                )

            if show_all:
                # plot individual swallows as open circles above the means
                ax.scatter(
                    [x]*values.size, values,
                    facecolors='none',
                    edgecolors='k',
                    linewidths=1,
                    s=12,
                    zorder=3,
                )

            # faint vertical guide and a 'U'/'L' label above the column
            ax.axvline(x=x, c='0.9', lw=0.5, zorder=-2)
            ax.annotate(
                food[0],
                xy=(x, 1), xycoords=ax.get_xaxis_transform(),
                ha='center', va='bottom', fontsize='x-small', color='0.5',
            )

    # number the animals above the axes, centered between their two columns
    for i in range(len(animals)):
        ax.annotate(
            i+1,
            xy=(i, 1), xycoords=ax.get_xaxis_transform(),
            xytext=(0, 10), textcoords='offset points',
            ha='center', va='bottom', fontsize='x-small', color='0.5',
        )

    # plot lines connecting the two means of each animal that has both
    for animal in unloaded_means.index.intersection(loaded_means.index):
        i = position[animal]
        ax.plot(
            [i-0.2, i+0.2],
            [unloaded_means[animal], loaded_means[animal]],
            color='0.75',
            zorder=-1,
        )

    ax.tick_params(bottom=False, labelbottom=False)
    ax.set_xlabel(None)
    sns.despine(ax=ax, bottom=True)

    if show_statistics:
        print(df.groupby(['Animal', 'Food'])[y].apply(lambda x: {'Mean': x.mean(), 'Count': x.count()}).unstack([1, 2])[['Unloaded', 'Loaded']])
        print()
        signif = differences_test(
            x=loaded_means,
            y=unloaded_means,
            x_label='loaded swallows',
            y_label='unloaded swallows',
            measure_label='mean [' + ' '.join(y.split()[:-1]) + ']',
            units=y.split()[-1].strip('()'),
            alpha=alpha,
        )

        if show_signif and signif == '*':
            from matplotlib.patches import ArrowStyle
            # bracket spanning the comparison, with an asterisk below the axes
            ax.annotate(
                '*',
                xy=(0.5, -0.01), xycoords='axes fraction',
                xytext=(0, -22), textcoords='offset points',
                arrowprops=dict(arrowstyle=ArrowStyle.BracketB(widthB=bracket_width, lengthB=0.2, angleB=None), color='0.5'),
                ha='center', fontsize='x-large', color='0.5',
            )

---

# Figures

## [FIGURE 3]

In [None]:
# Figure 3 setup: load the exemplary swallow's recording and detect the spikes
# that the Figure 3 panels annotate. Uses names defined elsewhere in the
# notebook (df_exemplary_swallow, feeding_bouts, exemplary_swallow,
# sig_filters_by_animal, apply_filters, get_sig).
df = df_exemplary_swallow
(data_set_name, time_window) = feeding_bouts[exemplary_swallow]
animal, _, _ = exemplary_swallow
behavior = 0  # row label in df whose burst start/end times are boxed in Figure 3D

t_start, t_stop = time_window*pq.s
# a single trace: channel BN2 rescaled to uV, with a 50 uV vertical scale bar
plots = [
    {'channel': 'BN2', 'ylabel': None, 'units': 'uV', 'ylim': [-60, 38], 'scalebar': 50},
]
# parallel lists used later to look up a channel's plotted units by its name
plot_names = [p['channel'] for p in plots]
plot_units = [p['units'] for p in plots]

# formatting shared by the Figure 3 panels; the built-in time scale bar is
# disabled here (Figure 3D adds one manually)
kwargs = dict(
    figsize = (5, 2),
    linewidth = 0.5,
    x_scalebar = None, # 1*pq.s,
)

# load the metadata
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)

# ensure that the right filters are used
metadata['filters'] = sig_filters_by_animal[animal]

# load the data (lazily, as requested via lazy=True)
blk = neurotic.load_dataset(metadata, lazy=True)

# manually perform filters (apply_filters is defined elsewhere; presumably it
# applies metadata['filters'] to the signals in blk — confirm)
blk = apply_filters(blk, metadata)

#################################

# hand-picked activity windows; 'time' is [start, end] in seconds. Spikes are
# kept only inside their discriminator's window, and Figures 3B/3C draw these
# windows as bars colored unit_colors[color] and labeled with 'label'.
epochs = [
    {'name': 'B38 activity',       'label': 'B38 activity',      'color': 'B38',      'time': [t_start.magnitude, 2978.89]},
    {'name': 'B3/6/9/10 activity', 'label': 'B3/B6/B9 activity', 'color': 'B3/B6/B9', 'time': [2979.23, 2983.28]},
]

# amplitude windows (in 'units') used to assign spikes to units for display.
# Per the trailing comments, the thresholds actually used in the analysis
# differ from these display values.
amplitude_discriminators = [
    {'name': 'B38',   'channel': 'BN2', 'epoch': 'B38 activity',       'amplitude': [  6,  15], 'units': 'uV'}, # actual thresholds used: [7, 20]
    {'name': 'B6/B9', 'channel': 'BN2', 'epoch': 'B3/6/9/10 activity', 'amplitude': [-17,  -9], 'units': 'uV'}, # actual thresholds used: [-25, -9]
    {'name': 'B3',    'channel': 'BN2', 'epoch': 'B3/6/9/10 activity', 'amplitude': [ 20,  30], 'units': 'uV'}, # actual thresholds used: [-60, -25]
]

# Detect spikes for each discriminator within [t_start, t_stop], then keep only
# those inside the discriminator's epoch window. Discriminators whose channel
# is not found (get_sig returns None) are skipped silently.
spike_trains = []
for discriminator in amplitude_discriminators:
    sig = get_sig(blk, discriminator['channel'])
    if sig is not None:
        sig = sig.time_slice(t_start, t_stop)
        # neurotic's private spike detector; presumably it thresholds sig using
        # discriminator['amplitude'] — confirm against the neurotic source
        st = _detect_spikes(sig, discriminator, blk.segments[0].epochs)
        # next() has no default, so an unknown epoch name raises StopIteration
        epoch = next((ep for ep in epochs if ep['name'] == discriminator['epoch']))
        st_epoch_start = epoch['time'][0]*pq.s
        st_epoch_end = epoch['time'][1]*pq.s
        st = st.time_slice(st_epoch_start, st_epoch_end)
        spike_trains.append(st)

### 🐌 Figure 3A

In [None]:
# Figure 3A: the BN2 trace of the exemplary swallow, with no annotations
# plot the signal
fig, axes = prettyplot_with_scalebars(blk, t_start, t_stop, plots, **kwargs)
ax = axes[0]

# (time scale bar intentionally disabled for this panel)
# # add time scale bar
# add_scalebar(ax,
#     sizex=1, labelx='1 s',
#     loc='lower right', bbox_to_anchor=(1, 0),
#     borderpad=0.5, sep=5, barwidth=2,
# )

#################################

# remove the padding around the subplot before export
fig.tight_layout(h_pad=0, w_pad=0, pad=0)

# export_dir is defined elsewhere in the notebook
fig.savefig(os.path.join(export_dir, 'figure-3A.png'), dpi=600)

### 🐌 Figure 3B

In [None]:
# Figure 3B: the trace plus the activity windows (colored bars) and the
# amplitude-discriminator thresholds (dashed lines)
# plot the signal
fig, axes = prettyplot_with_scalebars(blk, t_start, t_stop, plots, **kwargs)
ax = axes[0]

# (time scale bar intentionally disabled for this panel)
# # add time scale bar
# add_scalebar(ax,
#     sizex=1, labelx='1 s',
#     loc='lower right', bbox_to_anchor=(1, 0),
#     borderpad=0.5, sep=5, barwidth=2,
# )

#################################

# Add epoch bars. x is in data coordinates (seconds) and y is an axes
# fraction (via ax.get_xaxis_transform()), so each bar spans 0.15-0.18 of the
# axes height regardless of the y-limits.
for d in epochs:
    label = d['label']
    color = d['color']
    left, right = d['time']
    bottom, top = 0.15, 0.18
    width = right-left
    height = top-bottom
    rect = patches.Rectangle((left, bottom), width, height, linewidth=0, facecolor=unit_colors[color], fill=True, clip_on=False, transform=ax.get_xaxis_transform())
    ax.add_patch(rect)
    # label centered just below the bar, in the bar's color
    ax.annotate(label, xy=((right-left)/2+left, bottom-0.02), xycoords=('data', 'axes fraction'), ha='center', va='top', c=unit_colors[color])

#################################

# Add amplitude discriminator thresholds: dashed lines at each unit's lower and
# upper amplitude bound, spanning that unit's epoch window.
for d in amplitude_discriminators:
    unit = d['name']
    epoch = next((ep for ep in epochs if ep['name'] == d['epoch']))
    left, right = epoch['time']
    bottom, top = d['amplitude']
    ax.hlines(y=bottom, xmin=left, xmax=right, color=unit_colors[unit], ls='--')
    ax.hlines(y=top,    xmin=left, xmax=right, color=unit_colors[unit], ls='--')

###############################

# remove the padding around the subplot before export
fig.tight_layout(h_pad=0, w_pad=0, pad=0)

# export_dir is defined elsewhere in the notebook
fig.savefig(os.path.join(export_dir, 'figure-3B.png'), dpi=600)

### 🐌 Figure 3C

In [None]:
# Figure 3C: Figure 3B plus a marker on every detected spike
# plot the signal
fig, axes = prettyplot_with_scalebars(blk, t_start, t_stop, plots, **kwargs)
ax = axes[0]

# (time scale bar intentionally disabled for this panel)
# # add time scale bar
# add_scalebar(ax,
#     sizex=1, labelx='1 s',
#     loc='lower right', bbox_to_anchor=(1, 0),
#     borderpad=0.5, sep=5, barwidth=2,
# )

#################################

# Add epoch bars. x is in data coordinates (seconds) and y is an axes
# fraction (via ax.get_xaxis_transform()), so each bar spans 0.15-0.18 of the
# axes height regardless of the y-limits.
for d in epochs:
    label = d['label']
    color = d['color']
    left, right = d['time']
    bottom, top = 0.15, 0.18
    width = right-left
    height = top-bottom
    rect = patches.Rectangle((left, bottom), width, height, linewidth=0, facecolor=unit_colors[color], fill=True, clip_on=False, transform=ax.get_xaxis_transform())
    ax.add_patch(rect)
    # label centered just below the bar, in the bar's color
    ax.annotate(label, xy=((right-left)/2+left, bottom-0.02), xycoords=('data', 'axes fraction'), ha='center', va='top', c=unit_colors[color])

#################################

# Add amplitude discriminator thresholds: dashed lines at each unit's lower and
# upper amplitude bound, spanning that unit's epoch window.
for d in amplitude_discriminators:
    unit = d['name']
    epoch = next((ep for ep in epochs if ep['name'] == d['epoch']))
    left, right = epoch['time']
    bottom, top = d['amplitude']
    ax.hlines(y=bottom, xmin=left, xmax=right, color=unit_colors[unit], ls='--')
    ax.hlines(y=top,    xmin=left, xmax=right, color=unit_colors[unit], ls='--')

###############################

# Add spike markers: a dot at the signal's value at each detected spike time,
# colored by the spike train's unit name (st.name).
for st in spike_trains:
    # get the neural channel
    channel = st.annotations['channels'][0]
    sig = get_sig(blk, channel)

    # get the signal for the entire bout, in the units used for its subplot
    sig = sig.time_slice(t_start, t_stop)
    sig = sig.rescale(plot_units[plot_names.index(channel)])

    # plot spikes (sample the signal at each spike time, then reattach units)
    spike_amplitudes = np.array([sig[sig.time_index(t)] for t in st]) * pq.Quantity(sig.units)
    ax.scatter(st.times.rescale('s'), spike_amplitudes, marker='.', s=30, c=unit_colors[st.name], zorder=3)

#################################

# remove the padding around the subplot before export
fig.tight_layout(h_pad=0, w_pad=0, pad=0)

# export_dir is defined elsewhere in the notebook
fig.savefig(os.path.join(export_dir, 'figure-3C.png'), dpi=600)

### 🐌 Figure 3D

In [None]:
# Figure 3D: the trace with spike markers, a 1 s time scale bar, and a box
# around each unit's burst
# plot the signal
fig, axes = prettyplot_with_scalebars(blk, t_start, t_stop, plots, **kwargs)
ax = axes[0]

# add time scale bar (1 s, lower right)
add_scalebar(ax,
    sizex=1, labelx='1 s',
    loc='lower right', bbox_to_anchor=(1, 0),
    borderpad=0.5, sep=5, barwidth=2,
)

#################################

# Add spike markers: a dot at the signal's value at each detected spike time,
# colored by the spike train's unit name (st.name).
for st in spike_trains:
    # get the neural channel
    channel = st.annotations['channels'][0]
    sig = get_sig(blk, channel)

    # get the signal for the entire bout, in the units used for its subplot
    sig = sig.time_slice(t_start, t_stop)
    sig = sig.rescale(plot_units[plot_names.index(channel)])

    # plot spikes (sample the signal at each spike time, then reattach units)
    spike_amplitudes = np.array([sig[sig.time_index(t)] for t in st]) * pq.Quantity(sig.units)
    ax.scatter(st.times.rescale('s'), spike_amplitudes, marker='.', s=30, c=unit_colors[st.name], zorder=3)

#################################

# hard-coded vertical extent [bottom, top] of each unit's burst box, in the
# plotted units (uV)
unit_burst_boxes = {
    'B38':   [-12, 12],
    'B6/B9': [-20, 15],
    'B3':    [-43, 32],
}

# Plot burst windows. Burst start/end times (s) come from row `behavior` of df.
# No box is drawn for a unit whose start or end is not finite (presumably NaN
# when that unit had no burst).
for k, unit in enumerate(unit_burst_boxes.keys()):
    left = df.loc[behavior, f'{unit} burst start (s)']
    right = df.loc[behavior, f'{unit} burst end (s)']
    if np.isfinite(left) and np.isfinite(right):
        width = right-left
        bottom, top = unit_burst_boxes[unit]
        height = top-bottom
        rect = patches.Rectangle((left, bottom), width, height, linewidth=2, ls='-', edgecolor=unit_colors[unit], fill=False, zorder=3, clip_on=False)
        ax.add_patch(rect)

# add unit name labels at hand-tuned positions (x in seconds, y in axes fraction)
ax.annotate('B38',   xy=(2978.15, 0.75), xycoords=('data', 'axes fraction'), ha='center', c=unit_colors['B38'])
ax.annotate('B6/B9', xy=(2980.10, 0.79), xycoords=('data', 'axes fraction'), ha='center', c=unit_colors['B6/B9'])
ax.annotate('B3',    xy=(2981.85, 0.87), xycoords=('data', 'axes fraction'), ha='left',   c=unit_colors['B3'])

#################################

# remove the padding around the subplot before export
fig.tight_layout(h_pad=0, w_pad=0, pad=0)

# export_dir is defined elsewhere in the notebook
fig.savefig(os.path.join(export_dir, 'figure-3D.png'), dpi=600)

## [FIGURE 4]

### 🐌 Figure 4C

In [None]:
# Figure 4C setup: load an exemplary unloaded ('Regular nori') swallow of
# animal JG12 and detect its B6/B9 and B3 spikes (B38 is left out here).
# Uses names defined elsewhere in the notebook (df_all, sig_filters_by_animal,
# apply_filters, get_sig).
exemplary_swallow_unloaded = ('JG12', 'Regular nori', 1, 1)
animal, food, bout, behavior = exemplary_swallow_unloaded
# rows of df_all for this bout; `behavior` selects the swallow row used for the burst boxes
df = df_all.loc[(animal, food, bout)]
(data_set_name, time_window) = 'IN VIVO / JG12 / 2019-05-10 / 002', [234.94, 242.14] # t = 237.1, twidth = 7.2 ==> [234.94, 242.14]

t_start, t_stop = time_window*pq.s
# a single trace: channel BN2 rescaled to uV, with a 50 uV vertical scale bar
plots = [
#     {'channel': 'BN2', 'ylabel': None, 'units': 'uV', 'ylim': [-60, 38], 'scalebar': 50},
    {'channel': 'BN2', 'ylabel': None, 'units': 'uV', 'ylim': [-60, 45], 'scalebar': 50},
]
# parallel lists used later to look up a channel's plotted units by its name
plot_names = [p['channel'] for p in plots]
plot_units = [p['units'] for p in plots]

# formatting for the panel; the built-in time scale bar is disabled and one is
# added manually in the plotting cell
kwargs = dict(
    figsize = (5, 2),
    linewidth = 0.5,
    x_scalebar = None, # 1*pq.s,
)

# load the metadata
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)

# ensure that the right filters are used
metadata['filters'] = sig_filters_by_animal[animal]

# load the data (lazily, as requested via lazy=True)
blk = neurotic.load_dataset(metadata, lazy=True)

# manually perform filters (apply_filters is defined elsewhere; presumably it
# applies metadata['filters'] to the signals in blk — confirm)
blk = apply_filters(blk, metadata)

#################################

# hand-picked activity window; 'time' is [start, end] in seconds, used to gate
# spike detection
epochs = [
#     {'name': 'B38 activity',       'label': 'B38',      'color': 'B38',      'time': [t_start.magnitude, 2978.89]},
    {'name': 'B3/6/9/10 activity', 'label': 'B3/B6/B9', 'color': 'B3/B6/B9', 'time': [236.55, 239.53]},
]

# amplitude windows (in 'units') used to assign spikes to units for display.
# Per the trailing comments, the thresholds actually used in the analysis
# differ from these display values.
amplitude_discriminators = [
#     {'name': 'B38',   'channel': 'BN2', 'epoch': 'B38 activity',       'amplitude': [  6,  15], 'units': 'uV'}, # actual thresholds used: [7, 20]
    {'name': 'B6/B9', 'channel': 'BN2', 'epoch': 'B3/6/9/10 activity', 'amplitude': [-17,  -9], 'units': 'uV'}, # actual thresholds used: [-25, -9]
    {'name': 'B3',    'channel': 'BN2', 'epoch': 'B3/6/9/10 activity', 'amplitude': [ 20,  45], 'units': 'uV'}, # actual thresholds used: [-60, -25]
]

# Detect spikes for each discriminator within [t_start, t_stop], then keep only
# those inside the discriminator's epoch window. Discriminators whose channel
# is not found (get_sig returns None) are skipped silently.
spike_trains = []
for discriminator in amplitude_discriminators:
    sig = get_sig(blk, discriminator['channel'])
    if sig is not None:
        sig = sig.time_slice(t_start, t_stop)
        # neurotic's private spike detector; presumably it thresholds sig using
        # discriminator['amplitude'] — confirm against the neurotic source
        st = _detect_spikes(sig, discriminator, blk.segments[0].epochs)
        # next() has no default, so an unknown epoch name raises StopIteration
        epoch = next((ep for ep in epochs if ep['name'] == discriminator['epoch']))
        st_epoch_start = epoch['time'][0]*pq.s
        st_epoch_end = epoch['time'][1]*pq.s
        st = st.time_slice(st_epoch_start, st_epoch_end)
        spike_trains.append(st)

In [None]:
# Figure 4C: the unloaded swallow's trace with spike markers, a 1 s time scale
# bar, and boxes around the B6/B9 and B3 bursts
# plot the signal
fig, axes = prettyplot_with_scalebars(blk, t_start, t_stop, plots, **kwargs)
ax = axes[0]

# add time scale bar (1 s, lower right)
add_scalebar(ax,
    sizex=1, labelx='1 s',
    loc='lower right', bbox_to_anchor=(1, 0),
    borderpad=0.5, sep=5, barwidth=2,
)

#################################

# Add spike markers: a dot at the signal's value at each detected spike time,
# colored by the spike train's unit name (st.name).
for st in spike_trains:
    # get the neural channel
    channel = st.annotations['channels'][0]
    sig = get_sig(blk, channel)

    # get the signal for the entire bout, in the units used for its subplot
    sig = sig.time_slice(t_start, t_stop)
    sig = sig.rescale(plot_units[plot_names.index(channel)])

    # plot spikes (sample the signal at each spike time, then reattach units)
    spike_amplitudes = np.array([sig[sig.time_index(t)] for t in st]) * pq.Quantity(sig.units)
    ax.scatter(st.times.rescale('s'), spike_amplitudes, marker='.', s=20, c=unit_colors[st.name], zorder=3)

#################################

# hard-coded vertical extent [bottom, top] of each unit's burst box, in the
# plotted units (uV)
unit_burst_boxes = {
#     'B38':   [-12, 12],
#     'B6/B9': [-20, 15],
    'B6/B9': [-20, 20],
    'B3':    [-43, 32],
}

# Plot burst windows. Burst start/end times (s) come from row `behavior` of df.
# No box is drawn for a unit whose start or end is not finite (presumably NaN
# when that unit had no burst).
for k, unit in enumerate(unit_burst_boxes.keys()):
    left = df.loc[behavior, f'{unit} burst start (s)']
    right = df.loc[behavior, f'{unit} burst end (s)']
    if np.isfinite(left) and np.isfinite(right):
        width = right-left
        bottom, top = unit_burst_boxes[unit]
        height = top-bottom
        rect = patches.Rectangle((left, bottom), width, height, linewidth=2, ls='-', edgecolor=unit_colors[unit], fill=False, zorder=3, clip_on=False)
        ax.add_patch(rect)

# add unit name labels at hand-tuned positions (x in seconds, y in axes fraction)
# ax.annotate('B38',   xy=(2978.15, 0.75), xycoords=('data', 'axes fraction'), ha='center', c=unit_colors['B38'])
ax.annotate('B6/B9', xy=(237.75, 0.80), xycoords=('data', 'axes fraction'), ha='center', c=unit_colors['B6/B9'])
ax.annotate('B3',    xy=(239.35, 0.90), xycoords=('data', 'axes fraction'), ha='left',   c=unit_colors['B3'])

#################################

# remove the padding around the subplot before export
fig.tight_layout(h_pad=0, w_pad=0, pad=0)

# export_dir is defined elsewhere in the notebook
fig.savefig(os.path.join(export_dir, 'figure-4C.png'), dpi=600)

### 🐌 Figure 4D

In [None]:
# Figure 4D setup: reload the Figure 3 exemplary swallow, this time with the
# Figure 4 y-limits ([-60, 45]) and only B6/B9 and B3 spikes (B38 is left out).
# Uses names defined elsewhere in the notebook (df_exemplary_swallow,
# feeding_bouts, exemplary_swallow, sig_filters_by_animal, apply_filters, get_sig).
df = df_exemplary_swallow
(data_set_name, time_window) = feeding_bouts[exemplary_swallow]
animal, _, _ = exemplary_swallow
behavior = 0  # row label in df whose burst start/end times are boxed in the figure

t_start, t_stop = time_window*pq.s
# a single trace: channel BN2 rescaled to uV, with a 50 uV vertical scale bar
plots = [
#     {'channel': 'BN2', 'ylabel': None, 'units': 'uV', 'ylim': [-60, 38], 'scalebar': 50},
    {'channel': 'BN2', 'ylabel': None, 'units': 'uV', 'ylim': [-60, 45], 'scalebar': 50},
]
# parallel lists used later to look up a channel's plotted units by its name
plot_names = [p['channel'] for p in plots]
plot_units = [p['units'] for p in plots]

# formatting for the panel; the built-in time scale bar is disabled and one is
# added manually in the plotting cell
kwargs = dict(
    figsize = (5, 2),
    linewidth = 0.5,
    x_scalebar = None, # 1*pq.s,
)

# load the metadata
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)

# ensure that the right filters are used
metadata['filters'] = sig_filters_by_animal[animal]

# load the data (lazily, as requested via lazy=True)
blk = neurotic.load_dataset(metadata, lazy=True)

# manually perform filters (apply_filters is defined elsewhere; presumably it
# applies metadata['filters'] to the signals in blk — confirm)
blk = apply_filters(blk, metadata)

#################################

# hand-picked activity window; 'time' is [start, end] in seconds, used to gate
# spike detection
epochs = [
#     {'name': 'B38 activity',       'label': 'B38',      'color': 'B38',      'time': [t_start.magnitude, 2978.89]},
    {'name': 'B3/6/9/10 activity', 'label': 'B3/B6/B9', 'color': 'B3/B6/B9', 'time': [2979.23, 2983.28]},
]

# amplitude windows (in 'units') used to assign spikes to units for display.
# Per the trailing comments, the thresholds actually used in the analysis
# differ from these display values.
amplitude_discriminators = [
#     {'name': 'B38',   'channel': 'BN2', 'epoch': 'B38 activity',       'amplitude': [  6,  15], 'units': 'uV'}, # actual thresholds used: [7, 20]
    {'name': 'B6/B9', 'channel': 'BN2', 'epoch': 'B3/6/9/10 activity', 'amplitude': [-17,  -9], 'units': 'uV'}, # actual thresholds used: [-25, -9]
    {'name': 'B3',    'channel': 'BN2', 'epoch': 'B3/6/9/10 activity', 'amplitude': [ 20,  30], 'units': 'uV'}, # actual thresholds used: [-60, -25]
]

# Detect spikes for each discriminator within [t_start, t_stop], then keep only
# those inside the discriminator's epoch window. Discriminators whose channel
# is not found (get_sig returns None) are skipped silently.
spike_trains = []
for discriminator in amplitude_discriminators:
    sig = get_sig(blk, discriminator['channel'])
    if sig is not None:
        sig = sig.time_slice(t_start, t_stop)
        # neurotic's private spike detector; presumably it thresholds sig using
        # discriminator['amplitude'] — confirm against the neurotic source
        st = _detect_spikes(sig, discriminator, blk.segments[0].epochs)
        # next() has no default, so an unknown epoch name raises StopIteration
        epoch = next((ep for ep in epochs if ep['name'] == discriminator['epoch']))
        st_epoch_start = epoch['time'][0]*pq.s
        st_epoch_end = epoch['time'][1]*pq.s
        st = st.time_slice(st_epoch_start, st_epoch_end)
        spike_trains.append(st)

In [None]:
# Figure 4D: the exemplary swallow's trace with spike markers, a 1 s time
# scale bar, and boxes around the B6/B9 and B3 bursts
# plot the signal
fig, axes = prettyplot_with_scalebars(blk, t_start, t_stop, plots, **kwargs)
ax = axes[0]

# add time scale bar (1 s, lower right)
add_scalebar(ax,
    sizex=1, labelx='1 s',
    loc='lower right', bbox_to_anchor=(1, 0),
    borderpad=0.5, sep=5, barwidth=2,
)

#################################

# Add spike markers: a dot at the signal's value at each detected spike time,
# colored by the spike train's unit name (st.name).
for st in spike_trains:
    # get the neural channel
    channel = st.annotations['channels'][0]
    sig = get_sig(blk, channel)

    # get the signal for the entire bout, in the units used for its subplot
    sig = sig.time_slice(t_start, t_stop)
    sig = sig.rescale(plot_units[plot_names.index(channel)])

    # plot spikes (sample the signal at each spike time, then reattach units)
    spike_amplitudes = np.array([sig[sig.time_index(t)] for t in st]) * pq.Quantity(sig.units)
    ax.scatter(st.times.rescale('s'), spike_amplitudes, marker='.', s=20, c=unit_colors[st.name], zorder=3)

#################################

# hard-coded vertical extent [bottom, top] of each unit's burst box, in the
# plotted units (uV)
unit_burst_boxes = {
#     'B38':   [-12, 12],
#     'B6/B9': [-20, 15],
    'B6/B9': [-20, 20],
    'B3':    [-43, 32],
}

# Plot burst windows. Burst start/end times (s) come from row `behavior` of df.
# No box is drawn for a unit whose start or end is not finite (presumably NaN
# when that unit had no burst).
for k, unit in enumerate(unit_burst_boxes.keys()):
    left = df.loc[behavior, f'{unit} burst start (s)']
    right = df.loc[behavior, f'{unit} burst end (s)']
    if np.isfinite(left) and np.isfinite(right):
        width = right-left
        bottom, top = unit_burst_boxes[unit]
        height = top-bottom
        rect = patches.Rectangle((left, bottom), width, height, linewidth=2, ls='-', edgecolor=unit_colors[unit], fill=False, zorder=3, clip_on=False)
        ax.add_patch(rect)

# add unit name labels at hand-tuned positions (x in seconds, y in axes fraction)
# ax.annotate('B38',   xy=(2978.15, 0.75), xycoords=('data', 'axes fraction'), ha='center', c=unit_colors['B38'])
ax.annotate('B6/B9', xy=(2980.10, 0.79), xycoords=('data', 'axes fraction'), ha='center', c=unit_colors['B6/B9'])
ax.annotate('B3',    xy=(2981.85, 0.82), xycoords=('data', 'axes fraction'), ha='left',   c=unit_colors['B3'])

#################################

# remove the padding around the subplot before export
fig.tight_layout(h_pad=0, w_pad=0, pad=0)

# export_dir is defined elsewhere in the notebook
fig.savefig(os.path.join(export_dir, 'figure-4D.png'), dpi=600)