# Jump to a Figure

- [Figure 1](#[FIGURE-1])
  - [Figure 1A](#🐌-Figure-1A)
  - [Figures 1B & 1C](#🐌-Figures-1B-&-1C)
- [Figure 2](#[FIGURE-2])
  - [Figure 2A](#🐌-Figure-2A)
  - [Figure 2B](#🐌-Figure-2B)
  - [Figure 2C](#🐌-Figure-2C)
  - [Figure 2D](#🐌-Figure-2D)
  - [Figure 2E](#🐌-Figure-2E)
- [Figure 3](#[FIGURE-3])
  - [Figure 3A](#🐌-Figure-3A)
  - [Figure 3B](#🐌-Figure-3B)
  - [Figure 3C](#🐌-Figure-3C)
  - [Figure 3D](#🐌-Figure-3D)
  - [Figure 3E](#🐌-Figure-3E)
  - [Figure 3F](#🐌-Figure-3F)
  - [Figure 3G](#🐌-Figure-3G)
- [Figure 4](#[FIGURE-4])
  - [Figure 4A](#🐌-Figure-4A)
  - [Figure 4B](#🐌-Figure-4B)
- [Figure 5](#[FIGURE-5])
  - [Figure 5A](#🐌-Figure-5A)
  - [Figure 5B](#🐌-Figure-5B)

# Preamble

## Import Packages

In [None]:
import os
import datetime
import pickle
import numpy as np
import scipy as sp
import quantities as pq
import elephant
import pandas as pd
import statsmodels.api as sm
import neurotic
from neurotic.datasets.data import _detect_spikes
from utils import BehaviorsDataFrame, DownsampleNeoSignal

pq.markup.config.use_unicode = True  # allow symbols like mu for micro in output
pq.mN = pq.UnitQuantity('millinewton', pq.N/1e3, symbol = 'mN');  # define millinewton

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.patches as patches
import matplotlib.lines as mlines
from matplotlib.markers import CARETLEFT, CARETRIGHT, CARETUP, CARETDOWN, CARETUPBASE
from matplotlib.ticker import MultipleLocator
import seaborn as sns

In [None]:
import warnings

# np.nanmax raises a warning if all values are NaN and returns NaN, which is the behavior we want
warnings.filterwarnings('ignore', message='All-NaN slice encountered')

# elephant.statistics.instantaneous_rate always complains about negative values
warnings.filterwarnings('ignore', message='Instantaneous firing rate approximation contains '
                                          'negative values, possibly caused due to machine '
                                          'precision errors')

# with matplotlib>=3.1 and seaborn<=0.9.0, deprecation warnings are raised
# whenever tick marks are placed on the right axis but not the left
# - newer matplotlib releases expose MatplotlibDeprecationWarning at the top
#   level and deprecate/remove the matplotlib.cbook alias, so try the
#   top-level name first and fall back to cbook for older releases
try:
    from matplotlib import MatplotlibDeprecationWarning
except ImportError:
    from matplotlib.cbook import MatplotlibDeprecationWarning
warnings.simplefilter('ignore', MatplotlibDeprecationWarning)

# don't complain about opening too many figures
plt.rcParams.update({'figure.max_open_warning': 0})

In [None]:
# make figures interactive and open in a separate window
# %matplotlib qt

# make figures interactive and inline
%matplotlib notebook

# make figures non-interactive and inline
# %matplotlib inline

## Plot Settings

In [None]:
export_dir = 'manuscript-figures'
# create the output directory for exported figures
# - exist_ok avoids the check-then-create race of os.path.exists + os.mkdir
#   and also creates any missing parent directories
os.makedirs(export_dir, exist_ok=True)

In [None]:
# general plot settings
# - seaborn 'ticks' style, unscaled fonts, Palatino Linotype typeface
# - uncomment the context line to switch to seaborn's larger 'poster' plotting context
sns.set(
#     context = 'poster',
    style = 'ticks',
    font_scale = 1,
    font = 'Palatino Linotype',
)

In [None]:
# display current color palette
# - unit colors are referred to by 'CN' ids (e.g. 'C9'); the digits after
#   'C' index into this same palette, sns.color_palette(None)
with sns.axes_style('darkgrid'):
    sns.palplot(sns.color_palette(None), size=0.5)

In [None]:
# plotting color for each unit, given as matplotlib color-cycle ids ('C0'-'C9')
# - 'I2 spikes' and 'I2' are two names for the same color
unit_colors = {
    'I2 spikes': 'C9', # light blue
    'I2':        'C9', # light blue
    'B8a/b':     'C6', # pink
    'B6/B9':     'C2', # green
    'B3':        'C3', # red
    'B38':       'C1', # orange
    'B4/B5':     'C0', # dark blue
}
# colors for labeled phases of the force trace
# - most phases reuse a unit's color from unit_colors; 'drop' is plain gray
force_colors = {
    'dip':          unit_colors['I2 spikes'],
    'initial rise': unit_colors['B8a/b'],
    'rise':         unit_colors['B8a/b'],
    'plateau':      unit_colors['B6/B9'],
    'drop':         'gray',
    'shoulder':     unit_colors['B38'],
}

In [None]:
# print the hex code of each unit's plotting color
for unit, color_id in unit_colors.items():
    # 'I2 spikes' uses the same color as 'I2', so list it only once
    if unit == 'I2 spikes':
        continue
    # color ids look like 'C9'; the number after 'C' indexes the current palette
    color_index = int(color_id[1:])
    color_tuple = sns.color_palette(None)[color_index]
    color_hex = mcolors.to_hex(color_tuple)
    print(f'{unit}\t{color_hex}')

## Data Parameters

In [None]:
# burst detection thresholds for each unit, as (start, end) instantaneous firing
# frequencies: find_bursts starts a burst when the frequency rises above the
# first value and ends it when the frequency drops below the second
burst_thresholds_default = {
    'I2 spikes': (10,  5)*pq.Hz, # same as Cullins et al. 2015a (based on Hurwitz et al. 1996)
    'B8a/b':     ( 3,  3)*pq.Hz, # same as Cullins et al. 2015a (based on Morton and Chiel 1993a)
    'B6/B9':     (10,  5)*pq.Hz, # based on Lu et al. 2015
    'B3':        ( 8,  2)*pq.Hz, # based on Lu et al. 2015
    'B38':       ( 8,  5)*pq.Hz, # based on McManus et al. 2014
    'B4/B5':     ( 3,  3)*pq.Hz, # same as Cullins et al. 2015a ? (Table 1 says 3 Hz, text says first/last spike; based on Warman and Chiel 1995 ?)
}

# per-animal thresholds start as (shallow) copies of the defaults, so the
# exceptions below replace entries for one animal without changing the
# defaults or the other animals
burst_thresholds_by_animal = {
    'JG07': burst_thresholds_default.copy(),
    'JG08': burst_thresholds_default.copy(),
    'JG11': burst_thresholds_default.copy(),
    'JG12': burst_thresholds_default.copy(),
    'JG14': burst_thresholds_default.copy(),
}

# exceptions
burst_thresholds_by_animal['JG08']['B6/B9'] = (10,   3  )*pq.Hz # end threshold reduced for this animal
burst_thresholds_by_animal['JG11']['B4/B5'] = ( 1.5, 1.5)*pq.Hz # both thresholds reduced for this animal because only one neuron appeared to project
burst_thresholds_by_animal['JG14']['B6/B9'] = ( 4,   2  )*pq.Hz # both thresholds reduced for this animal because B6/B9 always fired slowly

In [None]:
# units of the recorded channels: four channels in uV, then one in mN
# (the millinewton unit defined in the import cell)
# NOTE(review): presumably in the same order as each list in
# channel_names_by_animal (last entry 'Force') -- confirm where this is used
channel_units = ['uV', 'uV', 'uV', 'uV', 'mN']

In [None]:
# names of the recorded channels for each animal; the naming differs between
# animals (e.g. 'I2-L' vs 'I2'), so the list is looked up per animal when
# processing each feeding bout
channel_names_by_animal = {
    'JG07': ['I2-L', 'RN-L', 'BN2-L', 'BN3-L',    'Force'],
    'JG08': ['I2',   'RN',   'BN2',   'BN3',      'Force'],
    'JG11': ['I2',   'RN',   'BN2',   'BN3-PROX', 'Force'],
    'JG12': ['I2',   'RN',   'BN2',   'BN3-DIST', 'Force'],
    'JG14': ['I2',   'RN',   'BN2',   'BN3-PROX', 'Force'],
}

In [None]:
# signal filters for each animal, stored into the metadata and applied by
# apply_filters after a data set is loaded
# - each entry names a channel and its cutoff(s) in Hz; here only the I2
#   channel is filtered, with a 100 Hz low-pass
sig_filters_by_animal = {
    'JG07': [{'channel': 'I2-L', 'lowpass': 100}],
    'JG08': [{'channel': 'I2',   'lowpass': 100}],
    'JG11': [{'channel': 'I2',   'lowpass': 100}],
    'JG12': [{'channel': 'I2',   'lowpass': 100}],
    'JG14': [{'channel': 'I2',   'lowpass': 100}],
}

In [None]:
# epoch Type labels that count as a swallow for each food; used to build the
# behavior query that selects swallows within each feeding bout
epoch_types_by_food = {
    'Regular nori': ['Swallow (regular 5-cm nori strip)'],
    'Tape nori':    ['Swallow (tape nori)'],
}

In [None]:
# feeding bouts to analyze
# - keys are (animal, food, bout_index); bout_index distinguishes several bouts
#   for the same animal and food
# - values are (data_set_name, time_window): data_set_name selects a data set
#   from the metadata file, and time_window is [start, end] in seconds; only
#   behaviors lying entirely within the window are included
feeding_bouts = {
    # (animal, food, bout_index): (data_set_name, time_window)

    ('JG07', 'Regular nori', 0): ('IN VIVO / JG07 / 2018-05-20 / 002', [1496, 1518]), # 4 swallows, last 2 inward food movement epochs not representative of retraction (strip broke, then finished strip mid-retraction)
    ('JG08', 'Regular nori', 0): ('IN VIVO / JG08 / 2018-06-21 / 002', [ 256,  287]), # 4 swallows
    ('JG08', 'Regular nori', 1): ('IN VIVO / JG08 / 2018-06-21 / 002', [ 454,  481]), # 4 swallows
    ('JG11', 'Regular nori', 0): ('IN VIVO / JG11 / 2019-04-03 / 001', [1791, 1819]), # 5 swallows
    ('JG11', 'Regular nori', 1): ('IN VIVO / JG11 / 2019-04-03 / 004', [ 551,  568]), # 3 swallows
    ('JG12', 'Regular nori', 0): ('IN VIVO / JG12 / 2019-05-10 / 002', [ 147,  165]), # 3 swallows, last 1 inward food movement epoch not representative of retraction (finished strip mid-retraction)
    ('JG12', 'Regular nori', 1): ('IN VIVO / JG12 / 2019-05-10 / 002', [ 229,  245]), # 3 swallows
    ('JG12', 'Regular nori', 2): ('IN VIVO / JG12 / 2019-05-10 / 002', [ 277,  291]), # 3 swallows, last 1 inward food movement epoch not representative of retraction (finished strip mid-retraction)
    ('JG14', 'Regular nori', 0): ('IN VIVO / JG14 / 2019-07-30 / 001', [1834, 1865]), # 4 swallows
    ('JG14', 'Regular nori', 1): ('IN VIVO / JG14 / 2019-07-30 / 001', [1910, 1943]), # 5 swallows
    ('JG14', 'Regular nori', 2): ('IN VIVO / JG14 / 2019-07-30 / 001', [2052, 2084]), # 5 swallows
    
    ('JG07', 'Tape nori',    0): ('IN VIVO / JG07 / 2018-05-20 / 002', [2718, 2755]), # 5 swallows
    ('JG08', 'Tape nori',    0): ('IN VIVO / JG08 / 2018-06-21 / 002', [ 147,  208]), # 7 swallows, some bucket and head movement
    ('JG08', 'Tape nori',    1): ('IN VIVO / JG08 / 2018-06-21 / 002', [ 664,  701]), # 5 swallows, large bucket movement
    ('JG08', 'Tape nori',    2): ('IN VIVO / JG08 / 2018-06-21 / 002', [1451, 1477]), # 3 swallows, some bucket movement
    ('JG11', 'Tape nori',    0): ('IN VIVO / JG11 / 2019-04-03 / 004', [1227, 1280]), # 5 swallows
    ('JG12', 'Tape nori',    0): ('IN VIVO / JG12 / 2019-05-10 / 002', [ 436,  465]), # 4 swallows
    ('JG12', 'Tape nori',    1): ('IN VIVO / JG12 / 2019-05-10 / 002', [2901, 2937]), # 5 swallows
    ('JG14', 'Tape nori',    0): ('IN VIVO / JG14 / 2019-07-29 / 004', [ 829,  870]), # 5 swallows
}

In [None]:
# for example figures only -- not used in majority of analysis because of long inter-swallow intervals
# - bout indices 101 and 102 keep these keys distinct from the regular bout
#   indices (0-2) used in feeding_bouts
# - the exemplary swallow's time window lies inside the exemplary bout's window
# NOTE(review): the processing loop iterates over every entry of feeding_bouts,
# so these example windows are processed along with the regular bouts;
# downstream code presumably separates them (e.g. df_exemplary_bout) -- confirm

exemplary_bout    = ('JG12', 'Tape nori', 101)
exemplary_swallow = ('JG12', 'Tape nori', 102)

feeding_bouts[exemplary_bout]    = ('IN VIVO / JG12 / 2019-05-10 / 002', [2970.7, 2992.0]) # 3 swallows
feeding_bouts[exemplary_swallow] = ('IN VIVO / JG12 / 2019-05-10 / 002', [2977.3, 2984.5]) # 1 swallow

## Helper Functions

In [None]:
def query_union(queries):
    '''Combine pandas query strings with OR, skipping any that are None.

    Each remaining query is wrapped in parentheses before joining with '|'.
    An empty string is returned when no queries remain.
    '''
    wrapped = (f'({q})' for q in queries if q is not None)
    return '|'.join(wrapped)

def label2query(label):
    '''Build a pandas query selecting rows for the animal and food named in label.

    The label is expected to hold a 4-character animal id, one separator
    character, then the food name (e.g. "JG12 Tape nori").
    '''
    animal = label[:4]
    food = label[5:]
    return f'(Animal == "{animal}") & (Food == "{food}")'

def contains(series, string):
    '''Return a boolean Series: True where `string` is a substring of the element.'''
    def _has_substring(element):
        return string in element
    return series.map(_has_substring)

In [None]:
def get_sig(blk, channel):
    '''Return the first analog signal named `channel` in the block's first segment.

    Raises Exception if no signal in that segment has that name.
    '''
    for candidate in blk.segments[0].analogsignals:
        if candidate.name == channel:
            return candidate
    raise Exception(f'Channel "{channel}" could not be found')

In [None]:
def get_sig_index(blk, channel):
    '''Return the position of the first analog signal named `channel` in the block's first segment.

    Raises Exception if no signal in that segment has that name.
    '''
    for position, candidate in enumerate(blk.segments[0].analogsignals):
        if candidate.name == channel:
            return position
    raise Exception(f'Channel "{channel}" could not be found')

In [None]:
def apply_filters(blk, metadata):
    '''
    Filter signals according to metadata['filters'] and return blk.

    Each entry of metadata['filters'] names a 'channel' and optionally a
    'highpass' and/or 'lowpass' cutoff given as a plain number in Hz. The
    matching analog signal in the block's first segment is replaced by the
    output of elephant.signal_processing.butter.

    This is nearly identical to neurotic's implementation, except that each
    signal is first passed through time_slice(None, None), which ensures lazily
    loaded proxies are actually loaded before filtering.
    '''
    for spec in metadata['filters']:
        position = get_sig_index(blk, spec['channel'])

        # attach Hz units to any cutoff that is given; missing (None) or falsy
        # cutoffs are passed through unchanged
        highpass = spec.get('highpass', None)
        lowpass = spec.get('lowpass', None)
        highpass_freq = highpass*pq.Hz if highpass else highpass
        lowpass_freq = lowpass*pq.Hz if lowpass else lowpass

        loaded_sig = blk.segments[0].analogsignals[position].time_slice(None, None)
        blk.segments[0].analogsignals[position] = elephant.signal_processing.butter(  # may raise a FutureWarning
            signal = loaded_sig,
            highpass_freq = highpass_freq,
            lowpass_freq = lowpass_freq,
        )

    return blk

In [None]:
# def finite_min(x, y):
#     '''Workaround for Quantities warning about comparison to NaN'''
#     if np.isfinite(x) and np.isfinite(y):
#         return min(x, y)
#     elif np.isfinite(x):
#         return x
#     elif np.isfinite(y):
#         return y
#     else:
#         assert x.units == y.units
#         return np.nan * x.units

# def finite_max(x, y):
#     '''Workaround for Quantities warning about comparison to NaN'''
#     if np.isfinite(x) and np.isfinite(y):
#         return max(x, y)
#     elif np.isfinite(x):
#         return x
#     elif np.isfinite(y):
#         return y
#     else:
#         assert x.units == y.units
#         return np.nan * x.units

In [None]:
def find_bursts(st, burst_thresholds):
    '''
    Find every sequence of spikes that qualifies as a burst.

    The instantaneous firing frequency is computed for each inter-spike
    interval. A burst starts at the first spike whose following interval has a
    frequency above the start threshold. It ends at the first later spike whose
    following interval has a frequency below the end threshold, or at the last
    spike of the train if the frequency never drops that low. Scanning for the
    next burst resumes after the end of the previous one.

    Parameters
    ----------
    st : spike train
        Spike times as a Quantity array (presumably a neo.SpikeTrain -- confirm).
    burst_thresholds : (start_freq, end_freq)
        Frequency thresholds (Quantities, e.g. in Hz) for starting and ending a burst.

    Returns
    -------
    list of dict
        One dict per burst, with keys 'Start (s)', 'End (s)', 'Duration (s)'
        and 'Number of spikes'.
    '''
    isi = elephant.statistics.isi(st).rescale('s')
    iff = 1/isi  # iff[k] is the frequency of the interval from spike k to spike k+1

    start_freq, end_freq = burst_thresholds

    # interval indexes that cross each threshold; these never change while
    # scanning, so compute them once rather than on every loop iteration
    start_candidates = np.where(iff > start_freq)[0]
    end_candidates = np.where(iff < end_freq)[0]

    bursts = []
    scan_index = -1
    while True:
        # first time after the previous burst that iff rises above the start threshold
        later_starts = start_candidates[start_candidates > scan_index]
        if later_starts.size == 0:
            break
        start_index = later_starts[0]

        # first time after the start that iff drops below the end threshold
        later_ends = end_candidates[end_candidates > start_index]
        if later_ends.size > 0:
            end_index = later_ends[0]
            n_spikes = end_index-start_index+1
        else:
            end_index = -1 # end of spike train (include all spikes after start)
            n_spikes = st.size-start_index

        burst = {
            'Start (s)': st[start_index].rescale('s'),
            'End (s)': st[end_index].rescale('s'),
            'Duration (s)': (st[end_index] - st[start_index]).rescale('s'),
            'Number of spikes': n_spikes,
        }
        bursts.append(burst)

        if end_index == -1:
            break
        scan_index = end_index

    return bursts

In [None]:
def is_good_burst(burst):
    '''Return whether a burst lasts at least 0.5 s and contains more than 2 spikes.

    Like the `and` expression it replaces, the spike count is only examined
    when the duration requirement is met.
    '''
    long_enough = burst['Duration (s)'] >= 0.5*pq.s
    if not long_enough:
        return long_enough
    return burst['Number of spikes'] > 2

In [None]:
# this function runs much faster if args are first converted
# from quantities to simple ndarrays (use .rescale('s').magnitude)
def normalize_time(fixed_times, t):
    '''
    Map times onto a piecewise-linear normalized time axis.

    Normalized time i corresponds to fixed_times[i]. A time between
    fixed_times[i] and fixed_times[i+1] maps linearly onto [i, i+1].

    Parameters
    ----------
    fixed_times : array-like
        Reference times in ascending order; np.nan marks a missing reference.
    t : scalar or array-like
        Times to normalize, in ascending order (np.nan allowed).

    Returns
    -------
    numpy.ndarray
        One normalized time per element of t. The value is np.nan where the
        time is NaN, lies outside the range of the finite fixed times, or falls
        in an interval bordered by a missing fixed time.

    Raises
    ------
    AssertionError
        If the finite values of fixed_times or of t are not sorted.
    '''
    # accept lists, tuples, scalars and 0-d arrays as well as 1-d arrays;
    # asanyarray/atleast_1d leave ndarray subclasses untouched
    fixed_times = np.asanyarray(fixed_times)
    t = np.atleast_1d(t)

    finite_fixed_times = fixed_times[~np.isnan(fixed_times)]
    assert np.all(np.diff(finite_fixed_times)>=0), f'fixed_times must be sorted: {fixed_times}'
    assert np.all(np.diff(t[~np.isnan(t)])>=0), f't must be sorted: {t}'

    if finite_fixed_times.size == 0:
        # with no finite reference times, no time can be normalized
        return np.full(t.shape, np.nan)

    t_min = finite_fixed_times.min()
    t_max = finite_fixed_times.max()

    result = []
    last_found_i = 0
    for ti in t:
        if np.isnan(ti) or ti < t_min or t_max < ti:
            # out of bounds for normalization
            result.append(np.nan)
            continue

        # use a manual O(n) scan instead of np.searchsorted's faster binary
        # search because NaNs inside fixed_times can fool searchsorted; since t
        # is sorted, each scan resumes at the interval where the last time was found
        for i in range(last_found_i, len(fixed_times)-1):
            before = fixed_times[i]
            after = fixed_times[i+1]
            if np.isfinite(before) and np.isfinite(after) and before <= ti <= after:
                last_found_i = i
                result.append((ti-before)/(after-before) + i)
                break
        else:
            # no interval with finite boundaries contains ti, so it falls next
            # to an undefined (NaN) fixed time
            result.append(np.nan)

    return np.array(result)

In [None]:
# this function runs much faster if args are first converted
# from quantities to simple ndarrays (use .rescale('s').magnitude)
def unnormalize_time(fixed_times, t_normalized):
    '''
    Map normalized times back onto the original time axis (inverse of normalize_time).

    A normalized time i + f (integer i, 0 <= f <= 1) maps to
    fixed_times[i] + f*(fixed_times[i+1] - fixed_times[i]).

    Parameters
    ----------
    fixed_times : array-like
        Reference times in ascending order; np.nan marks a missing reference.
    t_normalized : scalar or array-like
        Normalized times to convert (any order; np.nan allowed).

    Returns
    -------
    numpy.ndarray
        One time per element of t_normalized. The value is np.nan where the
        normalized time is NaN, lies outside the indexes of the first and last
        finite fixed times, or falls in an interval bordered by a missing
        fixed time.

    Raises
    ------
    AssertionError
        If the finite values of fixed_times are not sorted.
    '''
    # accept lists, tuples, scalars and 0-d arrays as well as 1-d arrays;
    # asanyarray/atleast_1d leave ndarray subclasses untouched
    fixed_times = np.asanyarray(fixed_times)
    t_normalized = np.atleast_1d(t_normalized)

    assert np.all(np.diff(fixed_times[~np.isnan(fixed_times)])>=0), f'fixed_times must be sorted: {fixed_times}'

    finite_indexes = np.where(~np.isnan(fixed_times))[0]
    if finite_indexes.size == 0:
        # with no finite reference times, nothing can be un-normalized
        return np.full(t_normalized.shape, np.nan)
    t_normalized_min = finite_indexes.min()
    t_normalized_max = finite_indexes.max()

    result = []
    for ti in t_normalized:
        if np.isnan(ti) or ti < t_normalized_min or t_normalized_max < ti:
            # out of bounds for un-normalization
            result.append(np.nan)
            continue

        # index i of the interval [i, i+1] containing ti; ti == t_normalized_max
        # is the right end of the last interval, so step back one interval
        # instead of reading fixed_times past it (this replaces nudging ti down
        # by 1e-6, which made results slightly inaccurate near the upper bound)
        i = int(np.floor(ti))
        if i == t_normalized_max:
            i -= 1
        if i < 0:
            # only one finite fixed time (at index 0): there is no interval to
            # interpolate within, and fixed_times[-1] must not wrap around
            result.append(np.nan)
            continue

        before = fixed_times[i]
        after = fixed_times[i+1]
        if np.isfinite(before) and np.isfinite(after):
            result.append((ti - i)*(after-before) + before)
        else:
            # there is a NaN bordering where t would go
            result.append(np.nan)

    return np.array(result)

In [None]:
# times_interp is the array of normalized times at which samples will be taken
# - samples will be taken at regular intervals in normalized time
times_interp = np.linspace(0, 9, 4000) # [0, 9] is the range of normalized times

def resample_sig_in_normalized_time(fixed_times, sig, times_interp=times_interp):
    '''
    Resample a signal at evenly spaced points in normalized time.

    Parameters
    ----------
    fixed_times : Quantity array of times
        Reference times defining the normalized time axis (see normalize_time);
        np.nan marks a missing reference time. Converted to seconds here.
    sig : signal object with .times (Quantity) and .magnitude
        Presumably a neo.AnalogSignal -- confirm against callers.
    times_interp : ndarray, optional
        Normalized times at which to sample the signal.

    Returns
    -------
    ndarray, same length as times_interp
        Linearly interpolated signal values. The value is np.nan outside the
        range of the signal's normalizable samples and wherever the normalized
        time falls next to a missing fixed time.
    '''
    # express fixed times and sample times in the same unit (seconds) before
    # normalizing; taking the raw magnitude of fixed_times would silently
    # mis-normalize if they were stored in a different time unit
    fixed_times_s = fixed_times.rescale('s').magnitude

    # get normalized times and signal values
    # - normalize_time puts np.nan wherever a time is not normalizable, e.g.
    #   wherever it falls next to a missing fixed time (np.nan in fixed_times)
    times = sig.times.rescale('s').magnitude
    times_normalized = normalize_time(fixed_times_s, times)
    y = sig.magnitude.flatten()

    # drop times that could not be normalized
    # - interp1d will erroneously interpolate across the gaps created by this
    #   deletion; those values are replaced with np.nan below
    normalizable = ~np.isnan(times_normalized)
    times_normalized = times_normalized[normalizable]
    y = y[normalizable]

    # resample evenly in normalized time
    # - fill_value puts np.nan only outside the min and max of times_normalized
    interp_func = sp.interpolate.interp1d(times_normalized, y, kind='linear', bounds_error=False, fill_value=np.nan)
    y_interp = interp_func(times_interp)

    # replace erroneously interpolated values with np.nan: these are the
    # normalized times that cannot be mapped back to real times
    not_normalizable = np.isnan(unnormalize_time(fixed_times_s, times_interp))
    y_interp[not_normalizable] = np.nan

    return y_interp

In [None]:
def print_column_analysis(column, df=None):
    '''
    Print the mean, standard deviation and coefficient of variation of a column.

    Parameters
    ----------
    column : str
        Column name. The unit is taken from the text inside the last pair of
        parentheses (e.g. 'Duration (s)' -> 's'). Columns without parentheses
        are printed without a unit.
    df : pandas.DataFrame, optional
        Table to analyze; defaults to the notebook-global df_all.
    '''
    if df is None:
        df = df_all

    # previously a column without '(' printed its entire name as the unit
    _, paren, after_paren = column.rpartition('(')
    unit = after_paren.split(')')[0] if paren else ''
    unit_suffix = f' {unit}' if unit else ''

    mean = df[column].mean()
    std = df[column].std()
    cv = std/abs(mean)
    print(f'Overall mean +/- std: {mean:.2f} +/- {std:.2f}{unit_suffix}')
    print(f'Overall coefficient of variation CV: {cv:.2f}')

In [None]:
def differences_test(x, y):
    '''
    Test whether paired samples x and y differ, printing the results.

    The Shapiro-Wilk test is applied to the paired differences x - y. If
    normality is not rejected (p >= 0.05), a paired t-test is used; otherwise
    a Wilcoxon signed rank test is used.

    Parameters
    ----------
    x, y : array-like of float, same length
        Paired measurements. The i-th element of x is paired with the i-th
        element of y by position; pandas index labels are ignored.
    '''
    # pair by position for all three tests: subtracting two pandas Series
    # aligns them by index label (not position), which disagrees with the
    # positional pairing used by ttest_rel and wilcoxon and can produce NaNs
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    # Shapiro-Wilk test for normality of differences
    # - equivalent R test: shapiro.test(x-y)
    shapiro_W, shapiro_p = sp.stats.shapiro(x-y)
    print(f'H0: Differences have normal distribution, W = {shapiro_W:g},\tp = {shapiro_p:g}')

    if shapiro_p >= 0.05:
        print('- Because the differences can be assumed to be normal, a paired t-test will be used')

        # paired T-test for non-zero difference in means
        # - equivalent R test: t.test(x, y, paired=TRUE)
        ttest_t, ttest_p = sp.stats.ttest_rel(x, y)
        print(f'H0: Difference in means is zero,          t = {ttest_t:g},\tp = {ttest_p:g}')

    else:
        print('- Because the differences cannot be assumed to be normal, a Wilcoxon signed rank test will be used')

        # Wilcoxon signed rank test for non-zero difference in locations (medians?)
        # - equivalent R test: wilcox.test(x, y, paired=TRUE, exact=FALSE)
        # - a warning is raised for small sample sizes (N < 10) because SciPy's implementation
        #   always calculates the p-value using a normal approximation of the test statistic
        #   distribution, which is inaccurate for small sample size
        # - use R to get exact p-value, with wilcox.test(x, y, paired=TRUE, exact=TRUE)
        # - upcoming implementation of exact distribution: https://github.com/scipy/scipy/pull/10796
        wilcoxon_W, wilcoxon_p = sp.stats.wilcoxon(x, y, correction=True)
        print(f'H0: Difference in medians is zero,        W = {wilcoxon_W:g},\tp = {wilcoxon_p:g}')

In [None]:
def lighten_color(color, amount=0.5):
    '''
    Lightens the given color by multiplying (1-luminosity) by the given amount.
    Input can be matplotlib color string, hex string, or RGB tuple.

    amount=1 keeps the original lightness; amount=0 gives white.

    Examples:
    >> lighten_color('g', 0.3)
    >> lighten_color('#F034A3', 0.6)
    >> lighten_color((.3,.55,.1), 0.5)
    
    https://stackoverflow.com/a/49601444/3314376
    '''
    import matplotlib.colors as mc
    import colorsys
    # mc.to_rgb already resolves named colors, so the former mc.cnames lookup
    # (wrapped in a bare except that could hide unrelated errors) is unnecessary
    hue, lightness, saturation = colorsys.rgb_to_hls(*mc.to_rgb(color))
    return colorsys.hls_to_rgb(hue, 1 - amount * (1 - lightness), saturation)

## Import and Process the Data

In [None]:
# skip expensive calculations by loading the results from a file?
# - with load_from_files=False, perform data processing from scratch,
#   which takes several minutes
# - with load_from_files=True, load the final results (dataframes)
#   pickled last time the calculations were performed
# - the results are read back with pickle.load, which can execute arbitrary
#   code, so only load pickle files produced by a trusted previous run
load_from_files = False

In [None]:
pickled_vars = ['df_all', 'df_exemplary_bout', 'df_exemplary_swallow']
if load_from_files:
    
    # TODO: why does unpickling generate this warning?
    #     RuntimeWarning: invalid value encountered in greater
    #         return self.magnitude > other
    for var in pickled_vars:
        filename = f'{var}.pickle'
        with open(filename, 'rb') as f:
            exec(f'{var} = pickle.load(f)')

    print('calculation results loaded from files')
    
else:
    
    start = datetime.datetime.now()

    # use Neo RawIO lazy loading to load much faster and using less memory
    # - with lazy=True, filtering parameters specified in metadata are ignored
    #     - note: filters are replaced below and applied manually anyway
    # - with lazy=True, loading via time_slice requires neo>=0.8.0
    lazy = True

    # load the metadata containing file paths
    metadata = neurotic.MetadataSelector(file='../../data/metadata.yml')

    # filter epochs for each bout and perform calculations
    df_list = []
    last_data_set_name = None
    for (animal, food, bout_index), (data_set_name, time_window) in feeding_bouts.items():

        channel_names = channel_names_by_animal[animal]
        epoch_types = epoch_types_by_food[food]
        burst_thresholds = burst_thresholds_by_animal[animal]

        ###
        ### LOAD DATASET
        ###

        metadata.select(data_set_name)

        if data_set_name is last_data_set_name:
            # skip reloading the data if it's already in memory
            pass
        else:

            # ensure that the right filters are used
            metadata['filters'] = sig_filters_by_animal[animal]

            blk = neurotic.load_dataset(metadata, lazy=lazy)

            if lazy:
                # manually perform filters
                blk = apply_filters(blk, metadata)

        last_data_set_name = data_set_name



        ###
        ### LOCATE BEHAVIOR EPOCHS AND SUBEPOCHS
        ###

        # construct a query for locating behaviors
        behavior_query = f'(Type in {epoch_types}) & ({time_window[0]} <= Start) & (End <= {time_window[1]})'

        # construct queries for locating epochs associated with each behavior
        # - each query should match at most one epoch
        # - dictionary keys are used as prefixes for the names of new columns
        subepoch_queries = {}

        subepoch_queries['I2 protraction activity'] = f'(Type == "I2 protraction activity") & ' \
                                                      f'(@behavior_start-1 <= Start) & (End <= @behavior_end)'
                                                      # must start no earlier than 1 second before behavior and end within it

        subepoch_queries['B8 activity']             = f'(Type == "B8 activity") & ' \
                                                      f'(@behavior_start <= Start) & (End <= @behavior_end)'
                                                      # must be fully contained within behavior

        subepoch_queries['B3/6/9/10 activity']      = f'(Type == "B3/6/9/10 activity") & ' \
                                                      f'(@behavior_start <= Start) & (End <= @behavior_end)'
                                                      # must be fully contained within behavior

        subepoch_queries['B38 activity']            = (f'(Type == "B38 activity") & ' \
                                                       f'(@behavior_start-3 <= End) & (End <= @behavior_start+4)',
                                                       'last') # use last if there are multiple matches
                                                      # must end within a few seconds of behavior start (3 before or 4 after)

        subepoch_queries['B4/B5 activity']          = (f'(Type == "B4/B5 activity") & ' \
                                                       f'(@behavior_start <= Start) & (Start <= @behavior_end)',
                                                       'first') # use first if there are multiple matches
                                                      # must start within behavior
        
        subepoch_queries['Force shoulder end']      = f'(Type == "Force shoulder end") & ' \
                                                      f'(@behavior_start-3 <= End) & (End <= @behavior_start+3)'
                                                      # must end within 3 seconds of behavior start
        
        subepoch_queries['Force rise start']        = f'(Type == "Force rise start") & ' \
                                                      f'(@behavior_start <= Start) & (Start <= @behavior_end)'
                                                      # must start within behavior

        subepoch_queries['Force plateau start']     = f'(Type == "Force plateau start") & ' \
                                                      f'(@behavior_start <= Start) & (Start <= @behavior_end)'
                                                      # must start within behavior

        subepoch_queries['Force plateau end']       = f'(Type == "Force plateau end") & ' \
                                                      f'(@behavior_end-2 <= Start) & (Start <= @behavior_end+2)'
                                                      # must start within 2 seconds of behavior end

        subepoch_queries['Force drop end']          = f'(Type == "Force drop end") & ' \
                                                      f'(@behavior_end-2 <= Start) & (Start <= @behavior_end+2)'
                                                      # must start within 2 seconds of behavior end
        
        subepoch_queries['Inward movement']         = f'(Type == "Inward movement") & ' \
                                                      f'(@behavior_start <= Start) & (Start <= @behavior_end)'
                                                      # must start within behavior

        # construct a table in which each row is a behavior and subepoch data
        # is added as columns, e.g. df['B38 activity start (s)']
        df = BehaviorsDataFrame(blk.segments[0].epochs, behavior_query, subepoch_queries)

        # renumber behaviors assuming all behaviors are from a single contiguous sequence
        df = df.sort_values('Start (s)').reset_index(drop=True).rename_axis('Behavior_index')



        ###
        ### START CALCULATIONS
        ###

        # some columns must have type 'object', which
        # can be accomplished by initializing with None or np.nan
        df['Normalization fixed times (s)'] = None
        df['Force, normalized time interpolation (mN)'] = None
        units = [
            'I2 spikes',
            'B8a/b',
            'B3',
            'B6/B9',
            'B38',
            'B4/B5',
        ]
        for unit in units:
            df[unit+' spike train'] = None
            df[unit+' firing rate (Hz)'] = None
            df[unit+' firing rate, normalized time interpolation (Hz)'] = None
            df[unit+' inter-spike intervals (s)'] = None
            df[unit+' all bursts (s)'] = None

            # while we're at it, initialize some other things that might otherwise never be given values
            df[unit+' first burst start (s)'] = np.nan
            df[unit+' first burst end (s)'] = np.nan
            df[unit+' first burst duration (s)'] = 0
            df[unit+' first burst spike count'] = 0
            df[unit+' first burst mean frequency (Hz)'] = np.nan
            df[unit+' last burst start (s)'] = np.nan
            df[unit+' last burst end (s)'] = np.nan
            df[unit+' last burst duration (s)'] = 0
            df[unit+' last burst spike count'] = 0
            df[unit+' last burst mean frequency (Hz)'] = np.nan



        # get smoothed force for entire time window
        channel = 'Force'
        sig = get_sig(blk, channel)
        if lazy:
            sig = sig.time_slice(None, None)
        sig = elephant.signal_processing.butter(sig, lowpass_freq = 5*pq.Hz)
        sig = sig.time_slice(time_window[0]*pq.s, time_window[1]*pq.s)
        sig = sig.rescale('mN')
        force_smoothed_sig = sig



        df['End to next start (s)'] = np.nan
        df['Start to next start (s)'] = np.nan
        previous_i = None

        # iterate over all swallows
        for j, i in enumerate(df.index):

            behavior_start = df.loc[i, 'Start (s)']*pq.s
            behavior_end = df.loc[i, 'End (s)']*pq.s

            # calculate interbehavior intervals assuming all behaviors are from a single contiguous sequence
            if previous_i is not None:
                df.loc[previous_i, 'End to next start (s)']   = df.loc[i, 'Start (s)'] - df.loc[previous_i, 'End (s)']
                df.loc[previous_i, 'Start to next start (s)'] = df.loc[i, 'Start (s)'] - df.loc[previous_i, 'Start (s)']
            previous_i = i



            ###
            ### FORCE SEGMENTATION
            ###

            force_shoulder_end  = df.loc[i, 'Force shoulder end start (s)']*pq.s  # start of "Force shoulder end" epoch
            force_rise_start    = df.loc[i, 'Force rise start start (s)']*pq.s    # start of "Force rise start" epoch
            force_plateau_start = df.loc[i, 'Force plateau start start (s)']*pq.s # start of "Force plateau start" epoch
            force_plateau_end   = df.loc[i, 'Force plateau end start (s)']*pq.s   # start of "Force plateau end" epoch
            force_drop_end      = df.loc[i, 'Force drop end start (s)']*pq.s      # start of "Force drop end" epoch

            # force rise start, plateau start and end, and drop end are required
            force_is_segmented = np.all(np.isfinite(np.array([
                force_rise_start, force_plateau_start, force_plateau_end, force_drop_end])))

            # Build the 10 reference times used to map this behavior onto normalized time.
            # Order: previous plateau start, previous plateau end, previous drop end,
            # shoulder end, rise start, plateau start, plateau end, drop end,
            # next shoulder end, next rise start. Unknown entries are NaN.
            if force_is_segmented:
                # get some times for the previous and next swallow
                # (look up the recording-wide epoch objects by name; all five must exist)
                epochs_force_shoulder_end  = next((ep for ep in blk.segments[0].epochs if ep.name == 'Force shoulder end'), None)
                epochs_force_rise_start    = next((ep for ep in blk.segments[0].epochs if ep.name == 'Force rise start'), None)
                epochs_force_plateau_start = next((ep for ep in blk.segments[0].epochs if ep.name == 'Force plateau start'), None)
                epochs_force_plateau_end   = next((ep for ep in blk.segments[0].epochs if ep.name == 'Force plateau end'), None)
                epochs_force_drop_end      = next((ep for ep in blk.segments[0].epochs if ep.name == 'Force drop end'), None)
                assert epochs_force_shoulder_end  is not None, 'failed to find "Force shoulder end" epochs'
                assert epochs_force_rise_start    is not None, 'failed to find "Force rise start" epochs'
                assert epochs_force_plateau_start is not None, 'failed to find "Force plateau start" epochs'
                assert epochs_force_plateau_end   is not None, 'failed to find "Force plateau end" epochs'
                assert epochs_force_drop_end      is not None, 'failed to find "Force drop end" epochs'

                # Each lookup below takes the nearest epoch time before this swallow's
                # rise start ([-1] of the slice ending there) or after its drop end
                # ([0] of the slice starting there). An empty slice raises IndexError and
                # the time is recorded as NaN. The asserts are NOT caught by
                # `except IndexError`, so a neighbor that is too far away aborts the run.
                try:
                    prev_force_plateau_start = df.loc[i, 'Previous force plateau start (s)'] = epochs_force_plateau_start.time_slice(None, force_rise_start)[-1]
                    assert force_rise_start-prev_force_plateau_start < 16*pq.s, f'for swallow {i}, previous force plateau start is too far away'
                except IndexError:
                    prev_force_plateau_start = df.loc[i, 'Previous force plateau start (s)'] = np.nan

                try:
                    prev_force_plateau_end = df.loc[i, 'Previous force plateau end (s)'] = epochs_force_plateau_end.time_slice(None, force_rise_start)[-1]
                    assert force_rise_start-prev_force_plateau_end < 12*pq.s, f'for swallow {i}, previous force plateau end is too far away'
                except IndexError:
                    prev_force_plateau_end = df.loc[i, 'Previous force plateau end (s)'] = np.nan

                try:
                    prev_force_drop_end = df.loc[i, 'Previous force drop end (s)'] = epochs_force_drop_end.time_slice(None, force_rise_start)[-1]
                    assert force_rise_start-prev_force_drop_end < 12*pq.s, f'for swallow {i}, previous force drop end is too far away'
                except IndexError:
                    prev_force_drop_end = df.loc[i, 'Previous force drop end (s)'] = np.nan

                try:
                    next_force_rise_start = df.loc[i, 'Next force rise start (s)'] = epochs_force_rise_start.time_slice(force_drop_end, None)[0]
                    assert next_force_rise_start-force_drop_end < 12*pq.s, f'for swallow {i}, next force rise start is too far away'
                except IndexError:
                    next_force_rise_start = df.loc[i, 'Next force rise start (s)'] = np.nan

                try:
                    next_force_shoulder_end = df.loc[i, 'Next force shoulder end (s)'] = epochs_force_shoulder_end.time_slice(force_drop_end, None)[0]
                    # NOTE(review): if next_force_rise_start is NaN, this comparison is
                    # False and the shoulder is kept even if it belongs to a later swallow
                    if next_force_shoulder_end > next_force_rise_start:
                        # next swallow did not have a shoulder and we instead grabbed a later shoulder
                        next_force_shoulder_end = df.loc[i, 'Next force shoulder end (s)'] = np.nan
                except IndexError:
                    next_force_shoulder_end = df.loc[i, 'Next force shoulder end (s)'] = np.nan

                # get the list of fixed times for normalization
                normalization_fixed_times = df.at[i, 'Normalization fixed times (s)'] = np.array([
                    prev_force_plateau_start,
                    prev_force_plateau_end,
                    prev_force_drop_end,
                    force_shoulder_end,
                    force_rise_start,
                    force_plateau_start,
                    force_plateau_end,
                    force_drop_end,
                    next_force_shoulder_end,
                    next_force_rise_start,
                ])*pq.s # 'at', not 'loc', is important for inserting list into cell

            else: # force is not segmented
                # all ten fixed times are unknown; keep the same 10-element layout
                normalization_fixed_times = df.at[i, 'Normalization fixed times (s)'] = np.array([
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                ])*pq.s # 'at', not 'loc', is important for inserting list into cell



            ###
            ### FORCE QUANTIFICATION
            ###

            # Quantify force for a segmented behavior: extrema, baseline, values at each
            # segmentation time, phase durations, rise slope, then a resampling of the raw
            # force trace onto normalized time.
            if force_is_segmented:

                # get smoothed force for whole behavior for remaining force calculations
                # (window: 10 ms before shoulder end, or 1 s before rise start when there is
                # no shoulder, through 10 ms after drop end)
                sig = force_smoothed_sig
                if np.isfinite(force_shoulder_end):
                    sig = sig.time_slice(force_shoulder_end - 0.01*pq.s, force_drop_end + 0.01*pq.s)
                else:
                    sig = sig.time_slice(force_rise_start - 1*pq.s, force_drop_end + 0.01*pq.s)
                sig = sig.rescale('mN')

                # find force peak, baseline, and the increase
                # (presumably the whole window lies below 999 mN, so the 'below' detection
                # returns the window's minimum, and the first peak above 0 mN is the force
                # peak — confirm against elephant's peak_detection semantics)
                # sig[time_index][0] reads the first (presumably only) channel at that sample
                force_min_time = df.loc[i, 'Force minimum time (s)'] = elephant.spike_train_generation.peak_detection(sig, 999*pq.mN, sign='below')[0]
                force_min = df.loc[i, 'Force minimum (mN)'] = sig[sig.time_index(force_min_time)][0]
                force_peak_time = df.loc[i, 'Force peak time (s)'] = elephant.spike_train_generation.peak_detection(sig, 0*pq.mN)[0]
                force_peak = df.loc[i, 'Force peak (mN)'] = sig[sig.time_index(force_peak_time)][0]
                # baseline is the force at the moment the rise starts
                force_baseline = df.loc[i, 'Force baseline (mN)'] = sig[sig.time_index(force_rise_start)][0]
                force_increase = df.loc[i, 'Force increase (mN)'] = force_peak-force_baseline

                # find force plateau, drop, and shoulder values
                force_plateau_start_value = df.loc[i, 'Force plateau start value (mN)'] = sig[sig.time_index(force_plateau_start)][0]
                force_plateau_end_value = df.loc[i, 'Force plateau end value (mN)'] = sig[sig.time_index(force_plateau_end)][0]
                force_drop_end_value = df.loc[i, 'Force drop end value (mN)'] = sig[sig.time_index(force_drop_end)][0]
                if np.isfinite(force_shoulder_end):
                    force_shoulder_end_value = df.loc[i, 'Force shoulder end value (mN)'] = sig[sig.time_index(force_shoulder_end)][0]
                else:
                    force_shoulder_end_value = np.nan

                # find force rise and plateau durations
                force_rise_duration = df.loc[i, 'Force rise duration (s)'] = force_plateau_start - force_rise_start
                force_plateau_duration = df.loc[i, 'Force plateau duration (s)'] = force_plateau_end - force_plateau_start
                force_rise_plateau_duration = df.loc[i, 'Force rise and plateau duration (s)'] = force_plateau_end - force_rise_start

                # find average slope during rising phase
                force_rise_increase = df.loc[i, 'Force rise increase (mN)'] = force_plateau_start_value - force_baseline
                force_slope = df.loc[i, 'Force slope (mN/s)'] = (force_rise_increase/force_rise_duration).rescale('mN/s')



                ###
                ### FORCE NORMALIZATION
                ###

                # resample the unsmoothed force trace, from just after the previous plateau
                # start to just before the next rise start, onto normalized time
                # NOTE(review): prev_force_plateau_start / next_force_rise_start are set to a
                # bare np.nan when no neighboring epoch was found; adding a Quantity offset
                # to that may raise a units error in `quantities` — confirm this cannot occur
                channel = 'Force'
                sig = get_sig(blk, channel)
                sig = sig.time_slice(prev_force_plateau_start+0.001*pq.s, next_force_rise_start-0.001*pq.s)
                sig = sig.rescale(channel_units[channel_names.index(channel)])

                force_interp = df.at[i, 'Force, normalized time interpolation (mN)'] = \
                    resample_sig_in_normalized_time(normalization_fixed_times, sig) # 'at', not 'loc', is important for inserting list into cell



            ###
            ### FIND SPIKE TRAINS
            ###

            # Store one spike train per unit for this behavior, restricted to the unit's
            # discriminator epoch. Both branches store None when that epoch was not
            # located for this behavior, so the quantification below (which checks
            # `st is not None`) skips the unit instead of recording zero spikes.
            if lazy:
                # lazy loading: spikes are not pre-detected, so detect them here from
                # each discriminator's channel
                if metadata['amplitude_discriminators'] is not None:
                    for discriminator in metadata['amplitude_discriminators']:
                        sig = get_sig(blk, discriminator['channel'])
                        if sig is not None:
                            st_epoch_start = df.loc[i, discriminator['epoch']+' start (s)']*pq.s
                            st_epoch_end = df.loc[i, discriminator['epoch']+' end (s)']*pq.s
                            if np.isfinite(st_epoch_start) and np.isfinite(st_epoch_end):
                                # detect on a window padded by 5 s around the behavior, then trim to the epoch
                                sig = sig.time_slice(behavior_start - 5*pq.s, behavior_end + 5*pq.s)
                                st = _detect_spikes(sig, discriminator, blk.segments[0].epochs)
                                st = st.time_slice(st_epoch_start, st_epoch_end)
                                df.at[i, st.name+' spike train'] = st # 'at', not 'loc', is important for inserting list into cell
                            else:
                                # this unit's discriminator epoch was not located for this swallow;
                                # match the non-lazy branch rather than slicing with NaN bounds,
                                # which would yield an empty train and be counted as zero spikes
                                st = None
                                df.at[i, discriminator['name']+' spike train'] = st # 'at', not 'loc', is important for inserting list into cell
            else:
                # spike trains were already detected at load time; match each to its discriminator
                for spiketrain in blk.segments[0].spiketrains:
                    discriminator = next((d for d in metadata['amplitude_discriminators'] if d['name'] == spiketrain.name), None)
                    if discriminator is None:
                        raise Exception(f'For data set "{data_set_name}", discriminator "{spiketrain.name}" could not be found')
                    st_epoch_start = df.loc[i, discriminator['epoch']+' start (s)']*pq.s
                    st_epoch_end = df.loc[i, discriminator['epoch']+' end (s)']*pq.s
                    if np.isfinite(st_epoch_start) and np.isfinite(st_epoch_end):
                        st = spiketrain.time_slice(st_epoch_start, st_epoch_end)
                    else:
                        # this unit's discriminator epoch was not located for this swallow
                        st = None
                    df.at[i, spiketrain.name+' spike train'] = st # 'at', not 'loc', is important for inserting list into cell



            ###
            ### QUANTIFY SPIKE TRAINS AND BURSTS
            ###

            # For each unit with a spike train for this behavior, record its spike count,
            # a kernel-smoothed firing rate (and, when force is segmented, that rate
            # resampled onto normalized time), and statistics of its first and last
            # bursts accepted by is_good_burst.
            for k, unit in enumerate(units):
                st = df.loc[i, unit+' spike train']
                # None means this unit's discriminator epoch was not located for this behavior
                if st is not None:
                    df.loc[i, unit+' spike count'] = st.size

                    # get the neural channel
                    channel = st.annotations['channels'][0]
                    sig = get_sig(blk, channel)

                    # create a continuous smoothed firing rate representation
                    # by convolving the spike train with a kernel
                    smoothing_kernel = elephant.kernels.GaussianKernel(0.2*pq.s) # 200 ms standard deviation
#                     smoothing_kernel = elephant.kernels.RectangularKernel(0.2*pq.s / (2*np.sqrt(3))) # 200 ms width, 2*sqrt(3) undoes elephant's scaling
                    if force_is_segmented:
                        # choice of t_start and t_stop here ensures firing rates are recorded as zero far from the burst
                        # NOTE(review): prev_force_plateau_start / next_force_rise_start are set to a
                        # bare np.nan when no neighboring epoch was found; adding a Quantity offset
                        # to that may raise a units error in `quantities` — confirm this cannot occur
                        t_start = prev_force_plateau_start+0.001*pq.s
                        t_stop = next_force_rise_start-0.001*pq.s
                    else:
                        # force segmentation not available
                        t_start = behavior_start-5*pq.s
                        t_stop = behavior_end+5*pq.s
                    # the rate is sampled at the neural channel's sampling period
                    firing_rate = df.at[i, unit+' firing rate (Hz)'] = elephant.statistics.instantaneous_rate(
                        spiketrain=st,
                        t_start=t_start,
                        t_stop=t_stop,
                        sampling_period=sig.sampling_period,
                        kernel=smoothing_kernel,
                    ) # 'at', not 'loc', is important for inserting list into cell

                    # normalization
                    if force_is_segmented:
                        firing_rate_interp = df.at[i, unit+' firing rate, normalized time interpolation (Hz)'] = \
                            resample_sig_in_normalized_time(normalization_fixed_times, firing_rate) # 'at', not 'loc', is important for inserting list into cell

                    if st.size > 0:

                        # get the signal for the behavior with 10 seconds cushion before and after (for better baseline estimation)
                        sig = sig.time_slice(behavior_start - 10*pq.s, behavior_end + 10*pq.s)
                        sig = sig.rescale(channel_units[channel_names.index(channel)])

                        # find every sequence of spikes that qualifies as a burst
                        bursts = df.at[i, unit+' all bursts (s)'] = find_bursts(st, burst_thresholds[unit]) # 'at', not 'loc', is important for inserting list into cell

                        # defaults for when no good burst is found; these are not written
                        # to df, only the values computed in the loops below are
                        first_burst_start = np.nan
                        first_burst_end = np.nan
                        first_burst_spike_count = 0
                        first_burst_mean_freq = 0*pq.Hz
                        last_burst_start = np.nan
                        last_burst_end = np.nan
                        last_burst_spike_count = 0
                        last_burst_mean_freq = 0*pq.Hz
                        if len(bursts) > 0:

                            # scan forward for the first burst accepted by is_good_burst
                            for burst in bursts:
                                if is_good_burst(burst):
                                    first_burst_start, first_burst_end = burst['Start (s)'], burst['End (s)']
                                    first_burst_duration = first_burst_end-first_burst_start
                                    df.loc[i, unit+' first burst start (s)'] = first_burst_start.rescale('s')
                                    df.loc[i, unit+' first burst end (s)'] = first_burst_end.rescale('s')
                                    first_burst_duration = df.loc[i, unit+' first burst duration (s)'] = first_burst_duration.rescale('s')
                                    first_burst_spike_count = df.loc[i, unit+' first burst spike count'] = st.time_slice(first_burst_start, first_burst_end).size
                                    # mean frequency = (spike count - 1) inter-spike intervals over the burst duration
                                    first_burst_mean_freq = df.loc[i, unit+' first burst mean frequency (Hz)'] = ((first_burst_spike_count-1)/first_burst_duration).rescale('Hz')

                                    # find burst RAUC and mean voltage
                                    first_burst_rauc = df.loc[i, unit+' first burst RAUC (μV·s)'] = elephant.signal_processing.rauc(sig, baseline='mean', t_start=first_burst_start, t_stop=first_burst_end).rescale('uV*s')
                                    first_burst_mean_rect_voltage = df.loc[i, unit+' first burst mean rectified voltage (μV)'] = first_burst_rauc/first_burst_duration

                                    # normalization
                                    if force_is_segmented:
                                        df.loc[i, unit+' first burst start (normalized)'] = normalize_time(normalization_fixed_times.magnitude, float(first_burst_start.rescale('s')))

                                    break # quit after finding first good burst

                            # scan backward for the last burst accepted by is_good_burst
                            for burst in reversed(bursts):
                                if is_good_burst(burst):
                                    last_burst_start, last_burst_end = burst['Start (s)'], burst['End (s)']
                                    last_burst_duration = last_burst_end-last_burst_start
                                    df.loc[i, unit+' last burst start (s)'] = last_burst_start.rescale('s')
                                    df.loc[i, unit+' last burst end (s)'] = last_burst_end.rescale('s')
                                    last_burst_duration = df.loc[i, unit+' last burst duration (s)'] = last_burst_duration.rescale('s')
                                    last_burst_spike_count = df.loc[i, unit+' last burst spike count'] = st.time_slice(last_burst_start, last_burst_end).size
                                    # mean frequency = (spike count - 1) inter-spike intervals over the burst duration
                                    last_burst_mean_freq = df.loc[i, unit+' last burst mean frequency (Hz)'] = ((last_burst_spike_count-1)/last_burst_duration).rescale('Hz')

                                    # find burst RAUC and mean voltage
                                    last_burst_rauc = df.loc[i, unit+' last burst RAUC (μV·s)'] = elephant.signal_processing.rauc(sig, baseline='mean', t_start=last_burst_start, t_stop=last_burst_end).rescale('uV*s')
                                    last_burst_mean_rect_voltage = df.loc[i, unit+' last burst mean rectified voltage (μV)'] = last_burst_rauc/last_burst_duration

                                    # normalization
                                    if force_is_segmented:
                                        df.loc[i, unit+' last burst end (normalized)'] = normalize_time(normalization_fixed_times.magnitude, float(last_burst_end.rescale('s')))

                                    break # quit after finding first (actually, last) good burst



#             ###
#             ### TIMING DELAYS
#             ###

#             i2_burst_start      = df.loc[i, 'I2 spikes first burst start (s)']*pq.s
#             i2_burst_end        = df.loc[i, 'I2 spikes last burst end (s)']*pq.s
#             i2_burst_duration   = df.loc[i, 'I2 spikes all bursts duration (s)'] = i2_burst_end - i2_burst_start
#             b8_burst_start      = df.loc[i, 'B8a/b first burst start (s)']*pq.s
#             b8_burst_end        = df.loc[i, 'B8a/b last burst end (s)']*pq.s
#             b8_burst_duration   = df.loc[i, 'B8a/b all bursts duration (s)'] = b8_burst_end - b8_burst_start
#             b6b9_burst_start    = df.loc[i, 'B6/B9 first burst start (s)']*pq.s
#             b6b9_burst_end      = df.loc[i, 'B6/B9 last burst end (s)']*pq.s
#             b6b9_burst_duration = df.loc[i, 'B6/B9 all bursts duration (s)'] = b6b9_burst_end - b6b9_burst_start
#             b3_burst_start      = df.loc[i, 'B3 first burst start (s)']*pq.s
#             b3_burst_end        = df.loc[i, 'B3 last burst end (s)']*pq.s
#             b3_burst_duration   = df.loc[i, 'B3 all bursts duration (s)'] = b3_burst_end - b3_burst_start
#             b38_burst_start     = df.loc[i, 'B38 first burst start (s)']*pq.s
#             b38_burst_end       = df.loc[i, 'B38 last burst end (s)']*pq.s
#             b38_burst_duration  = df.loc[i, 'B38 all bursts duration (s)'] = b38_burst_end - b38_burst_start

#             df.loc[i, 'Next I2 spikes first burst start (s)'] = np.nan # will be set on next iteration
#             df.loc[i, 'Next I2 spikes last burst end (s)'] = np.nan # will be set on next iteration
#             df.loc[i, 'Next I2 spikes all bursts duration (s)'] = np.nan # will be set on next iteration
#             if j != 0:
#                 df.loc[df.index[j-1], 'Next I2 spikes first burst start (s)'] = i2_burst_start
#                 df.loc[df.index[j-1], 'Next I2 spikes last burst end (s)'] = i2_burst_end
#                 df.loc[df.index[j-1], 'Next I2 spikes all bursts duration (s)'] = i2_burst_end - i2_burst_start

#             # consider B3/B6/B9 bursting if either B3 or B6/B9 is bursting
#             b3b6b9_burst_start    = df.loc[i, 'B3/B6/B9 burst start (s)']    = finite_min(b6b9_burst_start, b3_burst_start)
#             b3b6b9_burst_end      = df.loc[i, 'B3/B6/B9 burst end (s)']      = finite_max(b6b9_burst_end,   b3_burst_end)
#             b3b6b9_burst_duration = df.loc[i, 'B3/B6/B9 burst duration (s)'] = b3b6b9_burst_end - b3b6b9_burst_start

#             # consider bursting only if B8a/b and B3/B6/B9 are both bursting
#             b8_or_b3b6b9_burst_end = df.loc[i, 'B8a/b and B3/B6/B9 conjunction end (s)'] = \
#                                                finite_min(b8_burst_end, b3b6b9_burst_end)

#             # delays from neural to force
#             i2_force_rise_start_delay        = df.loc[i, 'Delay from I2 end to force rise start (s)'] = \
#                                                          force_rise_start - i2_burst_end

#             b8_force_rise_start_delay        = df.loc[i, 'Delay from B8a/b start to force rise start (s)'] = \
#                                                          force_rise_start - b8_burst_start
#             b8_force_plateau_start_delay     = df.loc[i, 'Delay from B8a/b start to force plateau start (s)'] = \
#                                                          force_plateau_start - b8_burst_start
#             b8_force_plateau_end_delay       = df.loc[i, 'Delay from B8a/b end to force plateau end (s)'] = \
#                                                          force_plateau_end - b8_burst_end

#             b6b9_force_rise_start_delay      = df.loc[i, 'Delay from B6/B9 start to force rise start (s)'] = \
#                                                          force_rise_start - b6b9_burst_start
#             b6b9_force_plateau_start_delay   = df.loc[i, 'Delay from B6/B9 start to force plateau start (s)'] = \
#                                                          force_plateau_start - b6b9_burst_start
#             b6b9_force_plateau_end_delay     = df.loc[i, 'Delay from B6/B9 end to force plateau end (s)'] = \
#                                                          force_plateau_end - b6b9_burst_end

#             b3b6b9_force_rise_start_delay    = df.loc[i, 'Delay from B3/B6/B9 start to force rise start (s)'] = \
#                                                          force_rise_start - b3b6b9_burst_start
#             b3b6b9_force_plateau_start_delay = df.loc[i, 'Delay from B3/B6/B9 start to force plateau start (s)'] = \
#                                                          force_plateau_start - b3b6b9_burst_start
#             b3b6b9_force_plateau_end_delay   = df.loc[i, 'Delay from B3/B6/B9 end to force plateau end (s)'] = \
#                                                          force_plateau_end - b3b6b9_burst_end

#             b3_force_plateau_start_delay     = df.loc[i, 'Delay from B3 start to force plateau start (s)'] = \
#                                                          force_plateau_start - b3_burst_start
#             b3_force_plateau_end_delay       = df.loc[i, 'Delay from B3 end to force plateau end (s)'] = \
#                                                          force_plateau_end - b3_burst_end
#             b8_or_b3b6b9_force_plateau_end_delay = \
#                                                df.loc[i, 'Delay from either B8a/b or B3/B6/B9 end to force plateau end (s)'] = \
#                                                          force_plateau_end - b8_or_b3b6b9_burst_end

#             b38_force_shoulder_end_delay     = df.loc[i, 'Delay from B38 end to force shoulder end (s)'] = \
#                                                          force_shoulder_end - b38_burst_end



#             ###
#             ### B8 ACTIVITY BEFORE B3/B6/B9
#             ###

#             st = df.loc[i, 'B8a/b spike train']
#             if st is not None:
#                 b8_preb3b6b9_burst_duration    = df.loc[i, 'B8a/b pre-B3/B6/B9 burst duration (s)'] = \
#                                                            b3b6b9_burst_start - b8_burst_start
#                 b8_preb3b6b9_burst_spike_count = df.loc[i, 'B8a/b pre-B3/B6/B9 burst spike count'] = \
#                                                            st.time_slice(b8_burst_start, b3b6b9_burst_start).size
#                 b8_preb3b6b9_burst_mean_freq   = df.loc[i, 'B8a/b pre-B3/B6/B9 burst mean frequency (Hz)'] = \
#                                                            ((b8_preb3b6b9_burst_spike_count-1)/b8_preb3b6b9_burst_duration).rescale('Hz')

#                 # get the neural channel
#                 channel = st.annotations['channels'][0]
#                 sig = get_sig(blk, channel)

#                 # get the signal for the behavior with 5 seconds cushion before and after (for better baseline estimation)
#                 sig = sig.time_slice(behavior_start - 5*pq.s, behavior_end + 5*pq.s)
#                 sig = sig.rescale('uV')

#                 # for B8a/b before B3/B6/B9 start...
#                 if np.isfinite(b8_burst_start) and np.isfinite(b3b6b9_burst_start):
#                     # find RAUC and mean voltage
#                     b8_preb3b6b9_rauc = df.loc[i, 'B8a/b pre-B3/B6/B9 burst RAUC (μV·s)'] = elephant.signal_processing.rauc(sig, baseline='mean', t_start=b8_burst_start, t_stop=b3b6b9_burst_start).rescale('uV*s')
#                     b8_preb3b6b9_mean_rect_voltage = df.loc[i, 'B8a/b pre-B3/B6/B9 burst mean rectified voltage (μV)'] = b8_preb3b6b9_rauc/b8_preb3b6b9_burst_duration

#                     # get the peak smoothed frequency
#                     firing_rate = elephant.statistics.instantaneous_rate(
#                         spiketrain=st,
#                         t_start=st.t_start,
#                         sampling_period=sig.sampling_period,
#                         kernel=smoothing_kernel,
#                     )
#                     b8_preb3b6b9_burst_peak_smoothed_freq = df.loc[i, 'B8a/b pre-B3/B6/B9 burst peak smoothed frequency (Hz)'] = \
#                                                                       firing_rate.time_slice(b8_burst_start, b3b6b9_burst_start).max().rescale('Hz')


#                     if force_is_segmented:
#                         # get force during rise and plateau
#                         sig = get_sig(blk, 'Force')
#                         sig = sig.time_slice(force_rise_start, force_plateau_end)
#                         sig = sig.rescale('mN')

#                         # get force at end of B8-only burst, offset by delay
#                         force_b8_only_rise_end = df.loc[i, 'Force delayed B8-only rise end (s)'] = force_rise_start + b8_preb3b6b9_burst_duration
#                         force_b8_only_rise_height = df.loc[i, 'Force at delayed B8-only rise end (mN)'] = sig[sig.time_index(force_b8_only_rise_end)][0]

#                         # find average slope during initial rising phase (before B3/B6/B9 begin, offset by delay)
#                         force_initial_increase = df.loc[i, 'Force initial increase (mN)'] = (force_b8_only_rise_height-force_baseline).rescale('mN')
#                         force_initial_slope = df.loc[i, 'Force initial slope (mN/s)'] = (force_initial_increase/b8_preb3b6b9_burst_duration).rescale('mN/s')



#         # perform the following after having gone through all behaviors once
#         for j, i in enumerate(df.index):

#             ###
#             ### NORMALIZED TIMES
#             ###

#             normalization_fixed_times = df.loc[i, 'Normalization fixed times (s)']
#             normalization_fixed_times = normalization_fixed_times.magnitude
#             force_rise_start, force_plateau_start, force_plateau_end, force_drop_end = normalization_fixed_times[[4, 5, 6, 7]]
#             force_is_segmented = np.all(np.isfinite(np.array([force_rise_start, force_plateau_start, force_plateau_end, force_drop_end])))

#             if force_is_segmented:

#                 i2_burst_start   = df.loc[i, 'I2 spikes first burst start (s)']
#                 i2_burst_end     = df.loc[i, 'I2 spikes last burst end (s)']
#                 b8_burst_start   = df.loc[i, 'B8a/b first burst start (s)']
#                 b8_burst_end     = df.loc[i, 'B8a/b last burst end (s)']
#                 b6b9_burst_start = df.loc[i, 'B6/B9 first burst start (s)']
#                 b6b9_burst_end   = df.loc[i, 'B6/B9 last burst end (s)']
#                 b3_burst_start   = df.loc[i, 'B3 first burst start (s)']
#                 b3_burst_end     = df.loc[i, 'B3 last burst end (s)']
#                 b38_burst_start  = df.loc[i, 'B38 first burst start (s)']
#                 b38_burst_end    = df.loc[i, 'B38 last burst end (s)']
#                 b4b5_burst_start = df.loc[i, 'B4/B5 first burst start (s)']
#                 b4b5_burst_end   = df.loc[i, 'B4/B5 last burst end (s)']
# #                 next_i2_burst_start = df.loc[i, 'Next I2 spikes first burst start (s)']
# #                 next_i2_burst_end   = df.loc[i, 'Next I2 spikes last burst end (s)']

#                 i2_burst_start_normalized      = df.loc[i, 'I2 first burst start (normalized)']      = normalize_time(normalization_fixed_times,
#                                                            i2_burst_start)
#                 i2_burst_end_normalized        = df.loc[i, 'I2 last burst end (normalized)']         = normalize_time(normalization_fixed_times,
#                                                            i2_burst_end)

#                 b8_burst_start_normalized      = df.loc[i, 'B8a/b first burst start (normalized)']   = normalize_time(normalization_fixed_times,
#                                                            b8_burst_start)
#                 b8_burst_end_normalized        = df.loc[i, 'B8a/b last burst end (normalized)']      = normalize_time(normalization_fixed_times,
#                                                            b8_burst_end)

#                 b6b9_burst_start_normalized    = df.loc[i, 'B6/B9 first burst start (normalized)']   = normalize_time(normalization_fixed_times,
#                                                            b6b9_burst_start)
#                 b6b9_burst_end_normalized      = df.loc[i, 'B6/B9 last burst end (normalized)']      = normalize_time(normalization_fixed_times,
#                                                            b6b9_burst_end)

#                 b3_burst_start_normalized      = df.loc[i, 'B3 first burst start (normalized)']      = normalize_time(normalization_fixed_times,
#                                                            b3_burst_start)
#                 b3_burst_end_normalized        = df.loc[i, 'B3 last burst end (normalized)']         = normalize_time(normalization_fixed_times,
#                                                            b3_burst_end)

#                 b38_burst_start_normalized     = df.loc[i, 'B38 first burst start (normalized)']     = normalize_time(normalization_fixed_times,
#                                                            b38_burst_start)
#                 b38_burst_end_normalized       = df.loc[i, 'B38 last burst end (normalized)']        = normalize_time(normalization_fixed_times,
#                                                            b38_burst_end)

#                 b4b5_burst_start_normalized    = df.loc[i, 'B4/B5 first burst start (normalized)']   = normalize_time(normalization_fixed_times,
#                                                            b4b5_burst_start)
#                 b4b5_burst_end_normalized      = df.loc[i, 'B4/B5 last burst end (normalized)']      = normalize_time(normalization_fixed_times,
#                                                            b4b5_burst_end)

# #                 next_i2_burst_start_normalized = df.loc[i, 'Next I2 first burst start (normalized)'] = normalize_time(normalization_fixed_times,
# #                                                            next_i2_burst_start)
# #                 next_i2_burst_end_normalized   = df.loc[i, 'Next I2 last burst end (normalized)']    = normalize_time(normalization_fixed_times,
# #                                                            next_i2_burst_end)



        ###
        ### FINISH
        ###

        # index the table on 4 variables so that this dataframe can later be merged with others
        # (reset_index moves the current index back into a column; presumably that index
        # is named 'Behavior_index', which becomes the last level of the new MultiIndex)
        df['Animal'] = animal
        df['Food'] = food
        df['Bout_index'] = bout_index
        df = df.reset_index().set_index(['Animal', 'Food', 'Bout_index', 'Behavior_index'])

        # collect this bout's table; all bouts are concatenated after the loop
        df_list += [df]

    # combine every bout's table into one, sorted on the (Animal, Food, Bout_index,
    # Behavior_index) MultiIndex
    df_all = pd.concat(df_list, sort=False).sort_index()

    if exemplary_swallow in feeding_bouts:
        # move exemplar to separate dataframe
        # (.loc with the 3-level bout key drops those levels, leaving Behavior_index)
        df_exemplary_swallow = df_all.loc[exemplary_swallow].copy()
        df_all = df_all.drop(exemplary_swallow)

    if exemplary_bout in feeding_bouts:
        # move exemplar to separate dataframe
        df_exemplary_bout = df_all.loc[exemplary_bout].copy()
        df_all = df_all.drop(exemplary_bout)

    # save dataframes to files so that calculations can be skipped in the future
    # (each entry of pickled_vars names a notebook-level variable; look it up in the
    # module namespace instead of building and exec'ing Python source)
    for var in pickled_vars:
        filename = f'{var}.pickle'
        with open(filename, 'wb') as f:
            pickle.dump(globals()[var], f)

    end = datetime.datetime.now()
    print('elapsed time:', end-start)

## 🤪 Sanity Checks

In [None]:
# set to False to run the sanity-check cell below, which reloads each feeding bout's
# data set and plots its signals
skip_sanity_checks = True

In [None]:
if not skip_sanity_checks:
    start = datetime.datetime.now()

    # reconstruct the original df_all, before exemplary_swallow and
    # exemplary_bout were removed
    # (the exemplar tables are indexed only by behavior, so the Animal, Food, and
    # Bout_index levels are restored from the exemplar key tuples before concatenating)
    df_list2 = [df_all]
    if exemplary_swallow in feeding_bouts:
        df_exemplary_swallow2 = df_exemplary_swallow.copy()
        df_exemplary_swallow2['Animal'] = exemplary_swallow[0]
        df_exemplary_swallow2['Food'] = exemplary_swallow[1]
        df_exemplary_swallow2['Bout_index'] = exemplary_swallow[2]
        df_exemplary_swallow2 = df_exemplary_swallow2.reset_index().set_index(['Animal', 'Food', 'Bout_index', 'Behavior_index'])
        df_list2 += [df_exemplary_swallow2]
    if exemplary_bout in feeding_bouts:
        df_exemplary_bout2 = df_exemplary_bout.copy()
        df_exemplary_bout2['Animal'] = exemplary_bout[0]
        df_exemplary_bout2['Food'] = exemplary_bout[1]
        df_exemplary_bout2['Bout_index'] = exemplary_bout[2]
        df_exemplary_bout2 = df_exemplary_bout2.reset_index().set_index(['Animal', 'Food', 'Bout_index', 'Behavior_index'])
        df_list2 += [df_exemplary_bout2]
    df_all2 = pd.concat(df_list2, sort=False).sort_index()

    # use Neo RawIO lazy loading to load much faster and using less memory
    # - with lazy=True, filtering parameters specified in metadata are ignored
    #     - note: filters are replaced below and applied manually anyway
    # - with lazy=True, loading via time_slice requires neo>=0.8.0
    lazy = True

    # load the metadata containing file paths
    metadata = neurotic.MetadataSelector(file='../../data/metadata.yml')

    last_data_set_name = None
    for (animal, food, bout_index), (data_set_name, time_window) in feeding_bouts.items():

        channel_names = channel_names_by_animal[animal]
        epoch_types = epoch_types_by_food[food]
        burst_thresholds = burst_thresholds_by_animal[animal]

        df = df_all2.loc[animal, food, bout_index]



        ###
        ### LOAD DATASET
        ###

        metadata.select(data_set_name)

        if data_set_name is last_data_set_name:
            # skip reloading the data if it's already in memory
            pass
        else:

            # ensure that the right filters are used
            metadata['filters'] = sig_filters_by_animal[animal]

            blk = neurotic.load_dataset(metadata, lazy=lazy)

            if lazy:
                # manually perform filters
                blk = apply_filters(blk, metadata)

        last_data_set_name = data_set_name



        ###
        ### START FIGURE
        ###

#             figsize = (9.5, 10) # dimensions for notebook
#             figsize = (11, 8.5) # dimensions for printing
        figsize = (16, 9) # dimensions for filling wide screens
        fig, axes = plt.subplots(len(channel_names), 1, sharex=True, figsize=figsize)



        ###
        ### PLOT SIGNALS
        ###

        # plot all channels for entire time window
        for i, channel in enumerate(channel_names):
            plt.sca(axes[i])
            sig = get_sig(blk, channel)
            sig = sig.time_slice(time_window[0]*pq.s, time_window[1]*pq.s)
            sig = sig.rescale(channel_units[i])
            plt.plot(sig.times, sig.magnitude, c='0.8', lw=1, zorder=-1)

            if i == 0:
                plt.title(f'({animal}, {food}, {bout_index}): {data_set_name}')

            plt.ylabel(sig.name + ' (' + sig.units.dimensionality.string + ')')
            axes[i].yaxis.set_label_coords(-0.06, 0.5)

            if i < len(channel_names)-1:
                # remove right, top, and bottom plot borders, and remove x-axis
                sns.despine(ax=plt.gca(), bottom=True)
                plt.gca().xaxis.set_visible(False)
            else:
                # remove right and top plot borders, and set x-label
                sns.despine(ax=plt.gca())
                plt.xlabel('Time (s)')

        # plot smoothed force for entire time window
        channel = 'Force'
        sig = get_sig(blk, channel)
        if lazy:
            sig = sig.time_slice(None, None)
        sig = elephant.signal_processing.butter(sig, lowpass_freq = 5*pq.Hz)
        sig = sig.time_slice(time_window[0]*pq.s, time_window[1]*pq.s)
        sig = sig.rescale('mN')
        plt.plot(sig.times, sig.magnitude, c='k', lw=1, zorder=0)
        force_smoothed_sig = sig



        # iterate over all swallows
        for j, i in enumerate(df.index):

            behavior_start = df.loc[i, 'Start (s)']*pq.s
            behavior_end   = df.loc[i, 'End (s)']*pq.s
            
            ###
            ### MOVEMENTS
            ###
            
            # plot inward food movement
            inward_movement_start = df.loc[i, 'Inward movement start (s)']*pq.s
            inward_movement_end   = df.loc[i, 'Inward movement end (s)']*pq.s
            if np.isfinite(inward_movement_start):
                channel = 'Force'
                ax = axes[channel_names.index(channel)]
                ax.axvspan(
                    inward_movement_start, inward_movement_end,
                    0.99, 1,
                    facecolor='k', edgecolor=None, lw=0)
            
            

            ###
            ### FORCE SEGMENTATION
            ###

            force_shoulder_end  = df.loc[i, 'Force shoulder end start (s)']*pq.s  # start of "Force shoulder end" epoch
            force_rise_start    = df.loc[i, 'Force rise start start (s)']*pq.s    # start of "Force rise start" epoch
            force_plateau_start = df.loc[i, 'Force plateau start start (s)']*pq.s # start of "Force plateau start" epoch
            force_plateau_end   = df.loc[i, 'Force plateau end start (s)']*pq.s   # start of "Force plateau end" epoch
            force_drop_end      = df.loc[i, 'Force drop end start (s)']*pq.s      # start of "Force drop end" epoch

            # force rise start, plateau start and end, and drop end are required
            force_is_segmented = np.all(np.isfinite(np.array([
                force_rise_start, force_plateau_start, force_plateau_end, force_drop_end])))

            normalization_fixed_times = df.at[i, 'Normalization fixed times (s)']

            if force_is_segmented:

                force_min_time = df.loc[i, 'Force minimum time (s)']
                force_min = df.loc[i, 'Force minimum (mN)']
                force_peak_time = df.loc[i, 'Force peak time (s)']
                force_peak = df.loc[i, 'Force peak (mN)']
                force_baseline = df.loc[i, 'Force baseline (mN)']

                force_plateau_start_value = df.loc[i, 'Force plateau start value (mN)']
                force_plateau_end_value = df.loc[i, 'Force plateau end value (mN)']
                force_drop_end_value = df.loc[i, 'Force drop end value (mN)']
                force_shoulder_end_value = df.loc[i, 'Force shoulder end value (mN)']

                # get smoothed force for whole behavior for remaining force calculations
                sig = force_smoothed_sig
                if np.isfinite(force_shoulder_end):
                    sig = sig.time_slice(force_shoulder_end - 0.01*pq.s, force_drop_end + 0.01*pq.s)
                else:
                    sig = sig.time_slice(force_rise_start - 1*pq.s, force_drop_end + 0.01*pq.s)
                sig = sig.rescale('mN')

                # plot force rise in color
                plt.sca(axes[channel_names.index('Force')])
                sig2 = sig.time_slice(force_rise_start, force_plateau_start)
                plt.plot(sig2.times, sig2.magnitude, c=force_colors['rise'], lw=2, zorder=1)

                # plot force plateau in color
                sig2 = sig.time_slice(force_plateau_start, force_plateau_end)
                plt.plot(sig2.times, sig2.magnitude, c=force_colors['plateau'], lw=2, zorder=1)

                # plot force shoulder in color
                if np.isfinite(force_shoulder_end):
                    sig2 = sig.time_slice(force_drop_end, force_shoulder_end)
                    plt.plot(sig2.times, sig2.magnitude, c=force_colors['shoulder'], lw=2, zorder=1)

                # plot force peak, baseline, and plateau values
                plt.plot([force_peak_time],     [force_peak],                marker=CARETDOWN,  markersize=5, color='k')
#                 plt.plot([force_min_time],      [force_min],                 marker=CARETUP,    markersize=5, color='k')
                plt.plot([force_rise_start],    [force_baseline],            marker=CARETUP,    markersize=5, color='k')
                plt.plot([force_plateau_start], [force_plateau_start_value], marker=CARETRIGHT, markersize=5, color='k')
                plt.plot([force_plateau_end],   [force_plateau_end_value],   marker=CARETLEFT,  markersize=5, color='k')

                # plot segmentation boundaries across all subplots
                for (t, y, c) in [
                        (force_shoulder_end,  force_shoulder_end_value,  force_colors['shoulder']),
                        (force_rise_start,    force_baseline,            force_colors['rise']),
                        (force_plateau_start, force_plateau_start_value, force_colors['plateau']),
                        (force_plateau_end,   force_plateau_end_value,   force_colors['plateau']),
                        (force_drop_end,      force_drop_end_value,      force_colors['drop'])]:
                    if np.isfinite(y):
                        axes[-1].add_artist(patches.ConnectionPatch(
                            xyA=(t, y), xyB=(t, 1),
                            coordsA='data', coordsB=axes[0].get_xaxis_transform(),
                            axesA=axes[-1], axesB=axes[0],
                            color=c, lw=1, ls=':', zorder=-2))



            ###
            ### SPIKES AND BURSTS
            ###

            units = [
                'I2 spikes',
                'B8a/b',
                'B3',
                'B6/B9',
                'B38',
                'B4/B5',
            ]
            for k, unit in enumerate(units):
                st = df.loc[i, unit+' spike train']
                if st is not None and st.size > 0:

                    # get the signal for the behavior with 10 seconds cushion for spikes outside the behavior duration
                    channel = st.annotations['channels'][0]
                    sig = get_sig(blk, channel)
                    sig = sig.time_slice(behavior_start - 10*pq.s, behavior_end + 10*pq.s)
                    sig = sig.rescale(channel_units[channel_names.index(channel)])

                    # get every sequence of spikes that qualifies as a burst
                    bursts = df.at[i, unit+' all bursts (s)']

                    # find the first and last good bursts
                    first_burst_start = np.nan
                    first_burst_end = np.nan
                    last_burst_start = np.nan
                    last_burst_end = np.nan
                    if len(bursts) > 0:

                        for burst in bursts:
                            if is_good_burst(burst):
                                first_burst_start, first_burst_end = burst['Start (s)'], burst['End (s)']
                                break # quit after finding first good burst

                        for burst in reversed(bursts):
                            if is_good_burst(burst):
                                last_burst_start, last_burst_end = burst['Start (s)'], burst['End (s)']
                                break # quit after finding first (actually, last) good burst

                    # plot spikes
                    plt.sca(axes[channel_names.index(channel)])
                    marker = ['.', 'x'][j%2] # alternate markers between behaviors
                    spike_amplitudes = np.array([sig[sig.time_index(t)] for t in st]) * pq.Quantity(sig.units)
                    plt.scatter(st.times.rescale('s'), spike_amplitudes, marker=marker, c=unit_colors[unit])

                    # plot burst windows
                    discriminator = next((d for d in metadata['amplitude_discriminators'] if d['name'] == st.name), None)
                    if discriminator is None:
                        raise Exception(f'For data set "{data_set_name}", discriminator "{st.name}" could not be found')
                    bottom = pq.Quantity(discriminator['amplitude'][0], discriminator['units']).rescale(sig.units)
                    top = pq.Quantity(discriminator['amplitude'][1], discriminator['units']).rescale(sig.units)
                    height = top-bottom
                    for burst in bursts:
                        left = burst['Start (s)']
                        right = burst['End (s)']
                        width = right-left
                        linestyle = '-' if is_good_burst(burst) else '--'
                        rect = patches.Rectangle((left, bottom), width, height, ls=linestyle, edgecolor=unit_colors[unit], fill=False)
                        plt.gca().add_patch(rect)

                    # plot markers for edges of bursts
                    if top > 0:
                        plt.plot([first_burst_start], [top],    marker=CARETDOWN, markersize=5, color='k')
                        plt.plot([last_burst_end],    [top],    marker=CARETDOWN, markersize=5, color='k')
                    else:
                        plt.plot([first_burst_start], [bottom], marker=CARETUP,   markersize=5, color='k')
                        plt.plot([last_burst_end],    [bottom], marker=CARETUP,   markersize=5, color='k')



        ###
        ### FINISH FIGURE
        ###

        # optimize plot margins
        plt.subplots_adjust(
            left   = 0.1,
            right  = 0.99,
            top    = 0.96,
            bottom = 0.06,
            hspace = 0.15,
        )

        # export figure
        export_dir2 = os.path.join(export_dir, 'sanity-checks')
        if not os.path.exists(export_dir2):
            os.mkdir(export_dir2)
        plt.gcf().savefig(os.path.join(export_dir2, f'{animal} {food} {bout_index}.png'), dpi=300)

    if exemplary_swallow in feeding_bouts:
        # rename output file for exemplar
        animal, food, bout_index = exemplary_swallow
        old_path = os.path.join(export_dir2, f'{animal} {food} {bout_index}.png')
        new_path = os.path.join(export_dir2, 'exemplary_swallow.png')
        if os.path.exists(old_path):
            if os.path.exists(new_path):
                os.remove(new_path)
            os.rename(old_path, new_path)

    if exemplary_bout in feeding_bouts:
        # rename output file for exemplar
        animal, food, bout_index = exemplary_bout
        old_path = os.path.join(export_dir2, f'{animal} {food} {bout_index}.png')
        new_path = os.path.join(export_dir2, 'exemplary_bout.png')
        if os.path.exists(old_path):
            if os.path.exists(new_path):
                os.remove(new_path)
            os.rename(old_path, new_path)


    del df_all2, df_list2

    end = datetime.datetime.now()
    print('elapsed time:', end-start)

In [None]:
if not skip_sanity_checks:
    # Sanity check: for each feeding bout (except the lone exemplary swallow),
    # plot spike rasters and firing rates in real time (left column) and in
    # normalized time (right column), with force in the bottom row, overlay the
    # median and quartiles of the normalized-time traces, and export one PNG
    # per bout to <export_dir>/sanity-checks-firing-rates.
    start = datetime.datetime.now()

    # units plotted in the rows above the force row
    units = [
        'I2 spikes',
        'B8a/b',
        'B3',
        'B6/B9',
        'B38',
        'B4/B5',
    ]

    # load the metadata containing file paths once; each bout selects its data set
    metadata = neurotic.MetadataSelector('../../data/metadata.yml')

    # every figure produced by this cell is exported here
    export_dir3 = os.path.join(export_dir, 'sanity-checks-firing-rates')
    os.makedirs(export_dir3, exist_ok=True)

    for (animal, food, bout_index), (data_set_name, time_window) in feeding_bouts.items():

        channel_names = channel_names_by_animal[animal]

        if (animal, food, bout_index) == exemplary_swallow:
            # skip the lone swallow
            continue
        elif (animal, food, bout_index) == exemplary_bout:
            df = df_exemplary_bout
        else:
            df = df_all.loc[animal, food, bout_index]

        # load the data
        metadata.select(data_set_name)
        blk = neurotic.load_dataset(metadata, lazy=True)

        # figsize = (9.5, 10) # dimensions for notebook
        # figsize = (11, 8.5) # dimensions for printing
        figsize = (16, 9) # dimensions for filling wide screens
        # one row per unit plus a bottom row for force; the left column is real
        # time and the right column is normalized time
        fig, axes = plt.subplots(len(units)+1, 2, sharex='col', figsize=figsize)

        for unit, (ax_left, ax_right) in zip(units, axes):
            # set y-axis label
            ax_left.set_ylabel(unit+' (Hz)')
            ax_left.yaxis.set_label_coords(-0.06, 0.5)

            # remove right, top, and bottom plot borders, and remove x-axis
            sns.despine(ax=ax_left, bottom=True)
            sns.despine(ax=ax_right, bottom=True)
            ax_left.xaxis.set_visible(False)
            ax_right.xaxis.set_visible(False)

            # set normalized time plot range
            ax_right.set_xlim([0, 9])

            # elevate the Axes for units and remove background colors so that
            # each vertical ConnectionPatch drawn later is visible behind it
            ax_left.set_zorder(1)
            ax_right.set_zorder(1)
            ax_left.set_facecolor('none')
            ax_right.set_facecolor('none')

        # remove right and top plot borders from bottom panel, and set x-label
        ax_left, ax_right = axes[-1]
        sns.despine(ax=ax_left)
        sns.despine(ax=ax_right)
        ax_left.set_xlabel('Time (s)')
        ax_right.set_xlabel('Time (normalized)')

        # set normalized time plot range
        ax_right.set_xlim([0, 9])

        # plot force in real time
        channel = 'Force'
        sig = get_sig(blk, channel)
        sig = sig.time_slice(time_window[0]*pq.s, time_window[1]*pq.s)
        sig = sig.rescale(channel_units[channel_names.index(channel)])
        ax_left.plot(sig.times, sig.magnitude, c='0.8', lw=1)
        ax_left.set_ylabel(sig.name + ' (' + sig.units.dimensionality.string + ')')
        ax_left.yaxis.set_label_coords(-0.06, 0.5)

        # normalized-time traces of every behavior, stacked one row per behavior,
        # for the distribution plots drawn after the behavior loop
        all_normalized_times_series = {}
        for unit in units:
            all_normalized_times_series[unit] = np.zeros((0, times_interp.size))
        all_normalized_times_series['Force'] = np.zeros((0, times_interp.size))

        # number of normalization fixed times, taken from this bout's behaviors
        # only (stays 0, drawing no boundaries, if the bout has no behaviors)
        n_fixed_times = 0

        for i in df.index:

            # plot inward food movement as a thin bar along the top of the force subplot
            inward_movement_start = df.loc[i, 'Inward movement start (s)']
            inward_movement_end   = df.loc[i, 'Inward movement end (s)']
            if np.isfinite(inward_movement_start):
                axes[-1][0].axvspan(
                    inward_movement_start, inward_movement_end,
                    0.98, 1,
                    facecolor='k', edgecolor=None, lw=0)

            for unit, (ax_left, ax_right) in zip(units, axes):

                # raster plot
                st = df.loc[i, unit+' spike train']
                if st is not None:
                    ax_left.eventplot(positions=st, lineoffsets=-1, colors=unit_colors[unit])

                # plot the firing rates in real time
                firing_rate = df.loc[i, unit+' firing rate (Hz)']
                if firing_rate is not None:
                    ax_left.plot(firing_rate.times.rescale('s'), firing_rate, c=unit_colors[unit])

                # plot firing rates in normalized time
                firing_rate_interp = df.loc[i, unit+' firing rate, normalized time interpolation (Hz)']
                if firing_rate_interp is not None:
                    all_normalized_times_series[unit] = np.concatenate([all_normalized_times_series[unit], firing_rate_interp[np.newaxis, :]])
                    ax_right.plot(times_interp, firing_rate_interp, c=lighten_color(unit_colors[unit], amount=0.7))

            # plot force in normalized time
            force_interp = df.at[i, 'Force, normalized time interpolation (mN)']
            if force_interp is not None:
                all_normalized_times_series['Force'] = np.concatenate([all_normalized_times_series['Force'], force_interp[np.newaxis, :]])
                axes[-1][1].plot(times_interp, force_interp, c='0.8', lw=1)

            # plot force phase boundaries in real time
            normalization_fixed_times = df.at[i, 'Normalization fixed times (s)'].rescale('s').magnitude
            n_fixed_times = len(normalization_fixed_times)
            # indices 3..7: from 3 = end of shoulder up to, but not including,
            # 8 = end of next shoulder
            for m, t in enumerate(normalization_fixed_times[3:8]):
                if np.isfinite(t):
                    if m == 1: # 1 = start of rise (index 4)
                        color = force_colors['rise']
                    else:
                        color = '0.75'
                    axes[-1][0].add_artist(patches.ConnectionPatch(
                        xyA=(t, 0), xyB=(t, 1),
                        coordsA=axes[-1][0].get_xaxis_transform(), coordsB=axes[0][0].get_xaxis_transform(),
                        axesA=axes[-1][0], axesB=axes[0][0],
                        color=color, lw=1, ls=':'))


        # plot force phase boundaries in normalized time (fixed time m maps to x = m)
        for m in range(n_fixed_times):
            if m == 4: # 4 = start of rise
                color = force_colors['rise']
            else:
                color = '0.75'
            axes[-1][1].add_artist(patches.ConnectionPatch(
                xyA=(m, 0), xyB=(m, 1),
                coordsA=axes[-1][1].get_xaxis_transform(), coordsB=axes[0][1].get_xaxis_transform(),
                axesA=axes[-1][1], axesB=axes[0][1],
                color=color, lw=1, ls=':'))


        # plot firing rate distributions (median in black, quartiles in gray)
        for unit, ax_right in zip(units, axes[:, 1]):
            firing_rate_median = np.nanmedian(all_normalized_times_series[unit], axis=0)
            firing_rate_q1 = np.nanquantile(all_normalized_times_series[unit], q=0.25, axis=0)
            firing_rate_q3 = np.nanquantile(all_normalized_times_series[unit], q=0.75, axis=0)
            ax_right.plot(times_interp, firing_rate_median, c='k', lw=2, zorder=3)
            ax_right.plot(times_interp, firing_rate_q1, c='gray', lw=1)
            ax_right.plot(times_interp, firing_rate_q3, c='gray', lw=1)


        # plot force distribution (median in black, quartiles in gray)
        ax_left, ax_right = axes[-1]

        force_median = np.nanmedian(all_normalized_times_series['Force'], axis=0)
        force_q1 = np.nanquantile(all_normalized_times_series['Force'], q=0.25, axis=0)
        force_q3 = np.nanquantile(all_normalized_times_series['Force'], q=0.75, axis=0)
        ax_right.plot(times_interp, force_median, c='k', lw=2, zorder=3)
        ax_right.plot(times_interp, force_q1, c='gray', lw=1)
        ax_right.plot(times_interp, force_q3, c='gray', lw=1)


        plt.suptitle(f'({animal}, {food}, {bout_index}): {data_set_name}')
        plt.tight_layout(rect=(0, 0, 1, 0.97))

        # export figure
        plt.gcf().savefig(os.path.join(export_dir3, f'{animal} {food} {bout_index}.png'), dpi=300)

    if exemplary_bout in feeding_bouts:
        # rename output file for exemplar; os.replace overwrites any file left
        # over from a previous run
        animal, food, bout_index = exemplary_bout
        old_path = os.path.join(export_dir3, f'{animal} {food} {bout_index}.png')
        new_path = os.path.join(export_dir3, 'exemplary_bout.png')
        if os.path.exists(old_path):
            os.replace(old_path, new_path)

    end = datetime.datetime.now()
    print('elapsed time:', end-start)

## Plotting Functions

In [None]:
def boxplot_with_points(x, y, hue, data, ax=None, show_points=False, describe=True):
    """
    Draw a box plot of column `y` grouped by `x` (and optionally `hue`), with
    whiskers spanning the full range of the data, optionally overlaying the
    individual observations as a swarm plot.

    Parameters
    ----------
    x, y : str
        Column names in `data` used for the categories and the plotted values.
    hue : str or None
        Optional column name used to split each category; None for no split.
    data : pandas.DataFrame
        Source data; rows whose `y` is NaN are dropped before plotting.
    ax : matplotlib Axes, optional
        Axes to draw on; defaults to the current Axes.
    show_points : bool
        If True, overlay every observation with a swarm plot.
    describe : bool
        If True, print the count and median of `y` for each group.
    """
    if ax is None:
        ax = plt.gca()

    # without hue: gray boxes and black unoutlined points; with hue: let
    # seaborn choose the colors and outline the points so they stand out
    boxcolor = '0.75' if hue is None else None
    pointcolor = 'k' if hue is None else None
    edgecolor = '0.25'
    linewidth = 0 if hue is None else 1
    size = 4

    data = data.dropna(subset=[y])

    # whiskers span the extrema: the (0, 100) percentile pair always reaches the
    # minimum and maximum, whereas a large IQR multiple (e.g. 999) does not when
    # the interquartile range is zero
    sns.boxplot(x=x, y=y, hue=hue, data=data, ax=ax, color=boxcolor, whis=(0, 100))

    if show_points:
        sns.swarmplot(x=x, y=y, hue=hue, data=data, ax=ax, color=pointcolor, linewidth=linewidth, edgecolor=edgecolor, size=size, dodge=True)

        if hue is not None:
            # the box plot and the swarm plot each add one legend entry per hue
            # level; keep only the first half to avoid duplicates
            handles, labels = ax.get_legend_handles_labels()
            n = int(len(labels)/2)
            ax.legend(handles[:n], labels[:n], title=hue)

    ax.set_xlabel(None)

    if describe:
        by = [x] if hue is None else [x, hue]
        print(y)
        print(data.groupby(by)[y].apply(lambda values: {
            'N': f'{values.count()}',
            'Median': values.median(),
        }).unstack())
        print()

In [None]:
default_markers = ['.', '+', 'x', '1', '4', 'o', 'D']
default_colors = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6']


def _padded_limits(values, lim, padding):
    """Return axis limits for `values`, filling every None entry of `lim` (or
    both entries, if `lim` is None) with the data extreme pushed outward by
    `padding` times the data range. `lim` itself is never modified."""
    data_range = np.ptp(values)
    auto_lim = [min(values) - data_range * padding, max(values) + data_range * padding]
    if lim is None:
        return auto_lim
    # copy so that a caller's list is left untouched and tuples are accepted
    lim = list(lim)
    if lim[0] is None:
        lim[0] = auto_lim[0]
    if lim[1] is None:
        lim[1] = auto_lim[1]
    return lim


def scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim, trend=False, trend_separately=False, tooltips=False, padding=0.05, colors=default_colors, markers=default_markers):
    """
    Scatter two columns of the global `df_all` against each other, one series
    per entry of `data_subsets`, with optional OLS trend lines and hover
    tooltips.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    data_subsets : dict
        Maps a legend label to a `df_all.query` string; None skips the entry.
    xlabel, ylabel : str
        Column names plotted on the x and y axes (also used as axis labels).
    xlim, ylim : sequence of two numbers/None, or None
        Axis limits; None entries are filled from the padded data range.
        The caller's sequence is not modified.
    trend : bool
        Fit and draw one OLS line through all plotted points.
    trend_separately : bool
        Fit and draw one OLS line per subset, in the subset's color.
    tooltips : bool
        Show each point's index on mouse hover.
    padding : float
        Fraction of the data range added on each side for automatic limits.
    colors, markers : list
        Per-subset colors and markers, indexed by subset order.
    """

    for j, (label, query) in enumerate(data_subsets.items()):
        if query is not None:
            df = df_all.query(query)
            ax.scatter(df[xlabel], df[ylabel],
                       label=label, marker=markers[j], c=colors[j], clip_on=False)

    all_points = df_all.query(query_union(data_subsets.values()))[[xlabel, ylabel]].dropna()

    if len(all_points) > 0:
        xlim = _padded_limits(all_points.iloc[:, 0], xlim, padding)
        ylim = _padded_limits(all_points.iloc[:, 1], ylim, padding)

    if trend_separately:
        for j, (label, query) in enumerate(data_subsets.items()):
            if query is not None:
                df = df_all.query(query)[[xlabel, ylabel]].dropna()
                model = sm.OLS(df.iloc[:,1], sm.add_constant(df.iloc[:,0])).fit()
                # params/pvalues are label-indexed Series (const, slope): index by position with .iloc
                model_stats = 'R$^2$ = {:.2f}, p = {:.3f}, n = {}'.format(model.rsquared, model.pvalues.iloc[1], len(df))
                print(label+':', model_stats)
                model_x = np.linspace(xlim[0], xlim[1], 100)
                model_y = model.params.iloc[0] + model.params.iloc[1] * model_x
                ax.plot(model_x, model_y, color=colors[j])

    if trend:
        model = sm.OLS(all_points.iloc[:,1], sm.add_constant(all_points.iloc[:,0])).fit()
        model_stats = 'R$^2$ = {:.2f}, p = {:.3f}, n = {}'.format(model.rsquared, model.pvalues.iloc[1], len(all_points))
        print('All points:', model_stats)
        model_x = np.linspace(xlim[0], xlim[1], 100)
        model_y = model.params.iloc[0] + model.params.iloc[1] * model_x
        ax.plot(model_x, model_y, color='gray')

    if tooltips:
        # create a transparent layer containing all points for detecting mouse events
        sc = ax.scatter(all_points.iloc[:, 0], all_points.iloc[:, 1], label=None,
                        alpha=0, # transparent points
                        s=0.001, # small radius of tooltip activation
                       )

        # initialize a hidden empty tooltip
        annot = ax.annotate(
            '', xy=(0, 0),                              # initialize empty at origin
            xytext=(0, 15), textcoords='offset points', # position text above point
            color='k', size=10, ha='center',            # small black centered text
            bbox=dict(fc='w', lw=0, alpha=0.6),         # transparent white background
            arrowprops=dict(shrink=0, headlength=7, headwidth=7, width=0, lw=0, color='k'), # small black arrow
        )
        annot.set_visible(False)

        # prepare tooltip contents: the dataframe index of each point, space-separated
        tooltip_text = [' '.join([str(x) for x in index]) for index in list(all_points.index)]

        # bind the animation of tooltips to a hover event
        def hover(event):
            if event.inaxes == ax:
                cont, ind = sc.contains(event)
                if cont:
                    point_index = ind['ind'][0]
                    pos = sc.get_offsets()[point_index]
                    annot.xy = pos
                    annot.set_text(tooltip_text[point_index])
                    annot.set_visible(True)
                    ax.figure.canvas.draw_idle()
                else:
                    if annot.get_visible():
                        annot.set_visible(False)
                        ax.figure.canvas.draw_idle()
        ax.figure.canvas.mpl_connect('motion_notify_event', hover)

    ax.set_xlabel(xlabel)
    ax.set_xlim(xlim)
    ax.set_ylabel(ylabel)
    ax.set_ylim(ylim)

In [None]:
from matplotlib.offsetbox import AnchoredOffsetbox
class AnchoredScaleBar(AnchoredOffsetbox):
    """
    Horizontal and/or vertical scale bar anchored inside an Axes.

    Bar lengths are given in the coordinates of `transform` (typically
    ``ax.transData``). The x bar extends in the -x direction from the anchor
    point, with its label stacked beneath it; the y bar extends in the +y
    direction, with its label placed to its right.
    """

    def __init__(self, transform, sizex=0, sizey=0, labelx=None, labely=None, loc=4,
                 pad=0.1, borderpad=0.1, sep=2, prop=None, barcolor="black", barwidth=None,
                 **kwargs):
        """
        Draw a horizontal and/or vertical bar with the size in data coordinates
        of the given axes, with optional labels.
        - transform : the coordinate frame (typically axes.transData)
        - sizex, sizey : length of the x and y bars, in data units; 0 to omit
        - labelx, labely : labels for the x and y bars; None to omit
        - loc : position in containing axes
        - pad, borderpad : padding, in fraction of the legend font size (or prop)
        - sep : separation between labels and bars in points
        - prop : font properties, passed to the base class
        - barcolor : color of the bars
        - barwidth : line width of the bars (None for the matplotlib default)
        - **kwargs : additional arguments passed to base class constructor
        """
        from matplotlib.patches import Rectangle
        from matplotlib.offsetbox import AuxTransformBox, VPacker, HPacker, TextArea

        bars = AuxTransformBox(transform)
        if sizex:
            # zero-height rectangle (i.e. a horizontal line) drawn leftward from the anchor
            bars.add_artist(Rectangle((0, 0), -sizex, 0, ec=barcolor, lw=barwidth, fc="none"))
        if sizey:
            # zero-width rectangle (i.e. a vertical line) drawn upward from the anchor
            bars.add_artist(Rectangle((0, 0), 0, sizey, ec=barcolor, lw=barwidth, fc="none"))

        if sizex and labelx:
            try:
                self.xlabel = TextArea(labelx, minimumdescent=False)
            except TypeError:
                # newer matplotlib releases removed the minimumdescent argument
                self.xlabel = TextArea(labelx)
            # stack the x label beneath the bars
            bars = VPacker(children=[bars, self.xlabel], align="center", pad=0, sep=sep)
        if sizey and labely:
            self.ylabel = TextArea(labely)
            # place the y label to the right of the bars
            bars = HPacker(children=[bars, self.ylabel], align="center", pad=0, sep=sep)

        super().__init__(loc, pad=pad, borderpad=borderpad,
                         child=bars, prop=prop, frameon=False, **kwargs)

In [None]:
def prettyplot_with_scalebars(
    blk,
    t_start,
    t_stop,
    plots,
    
    outfile_basename=None, # base name of output files
    export_only=False,     # if True, will not render in notebook
    formats=('pdf', 'svg', 'png'), # extensions of output files
    dpi=300,               # resolution (applicable only for PNG)
    
    figsize=(14, 7),       # figure size in inches
    linewidth=1,           # thickness of lines in points
    layout_settings=None,  # positioning of plot edges and the space between plots
    
    x_scalebar=1*pq.s,     # size of the time scale bar in seconds
    ylabel_padding=10,     # space between trace labels and plots
    scalebar_padding=1,    # space between scale bars and plots
    scalebar_sep=5,        # space between scale bars and scale labels
    barwidth=2,            # thickness of scale bars
):
    """
    Plot selected signals of a Neo block as stacked traces with scale bars.

    One subplot is created per entry of ``plots``, all sharing the time axis.
    Subplot frames, ticks and tick labels are hidden. A vertical scale bar is
    drawn to the right of each trace that requests one, and a horizontal time
    scale bar is drawn below the last trace.

    Parameters
    ----------
    blk : Neo block; signals are looked up by name among the analog signals
        of its first segment.
    t_start, t_stop : time window to plot (quantities with time units).
    plots : sequence of dicts, one per subplot, with keys
        'channel' (required)  name of the analog signal to plot
        'units' (required)    units the signal is rescaled to
        'ylim' (required)     y-axis limits, in those units
        'scalebar'            size of the vertical scale bar; None or absent
                              for no bar
        'decimation_factor'   passed to DownsampleNeoSignal (default 1)
        'color'               trace color (default 'k')
        'ylabel'              trace label (default: the signal name);
                              None to omit
    outfile_basename : if given, the figure is saved once per entry of
        ``formats`` as ``outfile_basename + '.' + ext``.
    export_only : if True, interactive mode is turned off while plotting and
        the previous interactive state is restored afterwards.
    layout_settings : if given, passed to ``fig.subplots_adjust``; otherwise
        ``fig.tight_layout`` is used with zero padding.

    Returns
    -------
    fig, axes : the figure and a flat array of its subplot axes.

    Raises
    ------
    AssertionError : if a requested channel is not found in the block.
    """
    # remember the interactive state so it can be restored exactly
    was_interactive = plt.isinteractive()
    if export_only:
        plt.ioff()

    try:
        fig, axes = plt.subplots(len(plots), 1, sharex=True, figsize=figsize, squeeze=False)
        axes = axes.flatten()

        for ax, p in zip(axes, plots):

            # select and rescale a channel for the subplot
            sig = next((sig for sig in blk.segments[0].analogsignals if sig.name == p['channel']), None)
            assert sig is not None, f"Signal with name {p['channel']} not found"
            sig = sig.time_slice(t_start, t_stop)
            sig = sig.rescale(p['units'])

            # downsample the data
            sig_downsampled = DownsampleNeoSignal(sig, p.get('decimation_factor', 1))

            # specify the x- and y-data for the subplot
            ax.plot(
                sig_downsampled.times,
                sig_downsampled.as_quantity(),
                linewidth=linewidth,
                color=p.get('color', 'k'),
            )

            # hide the box around the subplot
            ax.set_frame_on(False)

            # specify the y-axis label
            ylabel = p.get('ylabel', sig.name)
            if ylabel is not None:
                ax.set_ylabel(ylabel, rotation='horizontal', ha='right', va='center', labelpad=ylabel_padding)

            # specify the plot range
            ax.set_xlim([t_start, t_stop])
            ax.set_ylim(p['ylim'])

            # disable tick marks
            ax.tick_params(
                bottom=False,
                left=False,
                labelbottom=False,
                labelleft=False)

            # add y-axis scale bar, anchored just right of the subplot
            scalebar = p.get('scalebar')
            if scalebar is not None:
                ax.add_artist(AnchoredScaleBar(
                    ax.transData,
                    sizey=scalebar,
                    labely=f'{scalebar} {sig.units.dimensionality.string}',

                    loc='center left',
                    bbox_to_anchor=(1, 0.5),
                    bbox_transform=ax.transAxes,

                    pad=0,
                    borderpad=scalebar_padding,
                    sep=scalebar_sep,
                    barwidth=barwidth,
                ))

        # add time scale bar below final plot (its size is converted to the
        # time units of the last plotted signal, which the x data uses)
        if x_scalebar is not None:
            axes[-1].add_artist(AnchoredScaleBar(
                axes[-1].transData,
                sizex=x_scalebar.rescale(sig.times.units).magnitude,
                labelx=f'{x_scalebar.magnitude:g} {x_scalebar.units.dimensionality.string}',

                loc='upper right',
                bbox_to_anchor=(1, 0),
                bbox_transform=axes[-1].transAxes,

                pad=0,
                borderpad=scalebar_padding,
                sep=scalebar_sep,
                barwidth=barwidth,
            ))

        # adjust the white space around and between the subplots
        if layout_settings is None:
            fig.tight_layout(h_pad=0, w_pad=0, pad=0)
        else:
            fig.subplots_adjust(**layout_settings)

        if outfile_basename is not None:
            # specify file metadata (applicable only for PDF); str() guards
            # against a block without a file origin (None)
            metadata = dict(
                Subject = 'Data file: '  + str(blk.file_origin) + '\n' +
                          'Start time: ' + str(t_start)         + '\n' +
                          'End time: '   + str(t_stop),
            )

            # write the figure to files
            for ext in formats:
                fig.savefig(outfile_basename+'.'+ext, metadata=metadata, dpi=dpi)

    finally:
        # restore interactive mode even if plotting or saving failed, and only
        # if it was on before this call
        if export_only and was_interactive:
            plt.ion()

    return fig, axes

---

# Figures

## [FIGURE 1]

### 🐌 Figure 1A

In [None]:
# Figure 1A: exemplary swallow showing all recorded channels and force, with
# detected spikes, one burst box per unit, force phase boundaries, video frame
# markers, and protraction/retraction bars above the top trace.
start = datetime.datetime.now()

df = df_exemplary_swallow
(data_set_name, time_window) = feeding_bouts[exemplary_swallow]
animal, _, _ = exemplary_swallow

t_start, t_stop = time_window*pq.s
plots = [
    {'channel': 'I2',       'units': 'uV', 'ylim': [ -35,  35], 'scalebar': 50}, #, 'decimation_factor': 100},
    {'channel': 'RN',       'units': 'uV', 'ylim': [ -25,  25], 'scalebar': 25}, #, 'decimation_factor': 100},
    {'channel': 'BN2',      'units': 'uV', 'ylim': [ -45,  45], 'scalebar': 50}, #, 'decimation_factor': 100},
    {'channel': 'BN3-DIST', 'units': 'uV', 'ylim': [ -70,  70], 'scalebar': 50, 'ylabel': 'BN3'}, #, 'decimation_factor': 100},
    {'channel': 'Force',    'units': 'mN', 'ylim': [-100, 400], 'scalebar': 200}, #, 'decimation_factor': 100},
]
plot_names = [p['channel'] for p in plots]
plot_units = [p['units'] for p in plots]

kwargs = dict(
    figsize = (7, 9),
    linewidth = 0.5,
    x_scalebar = 1*pq.s,
)

# load the metadata
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)

# ensure that the right filters are used
metadata['filters'] = sig_filters_by_animal[animal]

# load the data
blk = neurotic.load_dataset(metadata, lazy=True)

# manually perform filters
blk = apply_filters(blk, metadata)

# plot the signal
fig, axes = prettyplot_with_scalebars(blk, t_start, t_stop, plots, **kwargs)

units = [
    'I2 spikes',
    'B8a/b',
    'B6/B9',
    'B3',
    'B38',
    'B4/B5',
]
# vertical extent (bottom, top) of each unit's burst box, in the data units of
# the subplot showing that unit's channel
unit_burst_boxes = {
    'I2 spikes': [-30, 25],
    'B8a/b':     [-20, 12],
    'B6/B9':     [-20, 15],
    'B3':        [-45, 35],
    'B38':       [-12, 12],
    'B4/B5':     [-60, 55],
}

for j, i in enumerate(df.index):
    for k, unit in enumerate(units):
        st = df.loc[i, unit+' spike train']
        if st is not None and st.size > 0:

            # get the neural channel
            channel = st.annotations['channels'][0]
            sig = get_sig(blk, channel)

            # get the signal for the entire bout
            sig = sig.time_slice(t_start, t_stop)
            sig = sig.rescale(plot_units[plot_names.index(channel)])

            # plot spikes at the signal value found at each spike time
            ax = axes[plot_names.index(channel)]
            spike_amplitudes = np.array([sig[sig.time_index(t)] for t in st]) * pq.Quantity(sig.units)
            ax.scatter(st.times.rescale('s'), spike_amplitudes, marker='.', s=20, c=unit_colors[unit], zorder=3)

            # plot burst windows: a single box spanning from the earliest
            # start to the latest end among the unit's good bursts
            bursts = df.at[i, unit+' all bursts (s)']
            bottom, top = unit_burst_boxes[unit]
            height = top-bottom
            left = right = np.nan
            for burst in bursts:
                if is_good_burst(burst):
                    # the left edge must be found by comparing START times;
                    # comparing end times gives a wrong left edge whenever
                    # the bursts are not in chronological order
                    if np.isnan(left) or burst['Start (s)'] < left:
                        left = burst['Start (s)']
                    if np.isnan(right) or burst['End (s)'] > right:
                        right = burst['End (s)']
            if np.isfinite(left) and np.isfinite(right):
                width = right-left
                rect = patches.Rectangle((left, bottom), width, height, linewidth=2, ls='-', edgecolor=unit_colors[unit], fill=False, zorder=3, clip_on=False)
                ax.add_patch(rect)

# plot a zero baseline for force
force_zero = -11.5 # mN, average before animal began swallowing, not zero because of DC offset
axes[-1].axhline(force_zero, color='0.75', lw=1, ls='--', zorder=-1)

# plot force phase boundaries (six boundary times delimiting phases I-V)
times = df.loc[0, 'Normalization fixed times (s)'][2:2+6]
for t in times:
    axes[-1].axvline(x=t, ymin=0.15, lw=1, ls=':', c='gray', zorder=-1, clip_on=False)
#     axes[-1].add_artist(patches.ConnectionPatch(
#         xyA=(t, 0), xyB=(t, 1),
#         coordsA=axes[-1].get_xaxis_transform(), coordsB=axes[0].get_xaxis_transform(),
#         axesA=axes[-1], axesB=axes[0],
#         color='gray', lw=1, ls=':', zorder=-1))

# add arabic numerals for force phase boundaries
# (the sixth boundary is labeled '1' again - presumably the start of the next cycle)
axes[-1].annotate('1', xy=(times[0], 0), xycoords=('data', 'axes fraction'), ha='center', va='bottom')
axes[-1].annotate('2', xy=(times[1], 0), xycoords=('data', 'axes fraction'), ha='center', va='bottom')
axes[-1].annotate('3', xy=(times[2], 0), xycoords=('data', 'axes fraction'), ha='center', va='bottom')
axes[-1].annotate('4', xy=(times[3], 0), xycoords=('data', 'axes fraction'), ha='center', va='bottom')
axes[-1].annotate('5', xy=(times[4], 0), xycoords=('data', 'axes fraction'), ha='center', va='bottom')
axes[-1].annotate('1', xy=(times[5], 0), xycoords=('data', 'axes fraction'), ha='center', va='bottom')

# add roman numerals for force phases, centered between consecutive boundaries
axes[-1].annotate('I',   xy=(times[0:2].mean(), 1), xycoords=('data', 'axes fraction'), ha='center', va='top')
axes[-1].annotate('II',  xy=(times[1:3].mean(), 1), xycoords=('data', 'axes fraction'), ha='center', va='top')
axes[-1].annotate('III', xy=(times[2:4].mean(), 1), xycoords=('data', 'axes fraction'), ha='center', va='top')
axes[-1].annotate('IV',  xy=(times[3:5].mean(), 1), xycoords=('data', 'axes fraction'), ha='center', va='top')
axes[-1].annotate('V',   xy=(times[4:6].mean(), 1), xycoords=('data', 'axes fraction'), ha='center', va='top')

# add unit name labels (positions hand-tuned for this swallow)
axes[0].annotate('I2',    xy=(2979.40, 0.78), xycoords=('data', 'axes fraction'), ha='left',   c=unit_colors['I2 spikes'])
axes[1].annotate('B8a/b', xy=(2981.10, 0.80), xycoords=('data', 'axes fraction'), ha='center', c=unit_colors['B8a/b'])
axes[2].annotate('B6/B9', xy=(2980.04, 0.73), xycoords=('data', 'axes fraction'), ha='left',   c=unit_colors['B6/B9'])
axes[2].annotate('B3',    xy=(2981.85, 0.82), xycoords=('data', 'axes fraction'), ha='left',   c=unit_colors['B3'])
axes[2].annotate('B38',   xy=(2978.15, 0.70), xycoords=('data', 'axes fraction'), ha='center', c=unit_colors['B38'])
axes[3].annotate('B4/B5', xy=(2979.10, 0.82), xycoords=('data', 'axes fraction'), ha='right',  c=unit_colors['B4/B5'])

# add markers for video frame times (frames shown in Figures 1B and 1C)
video_times = {
    'B': 2979.8,
    'C': 2982.5,
}
for label, video_time in video_times.items():
    axes[-1].plot([video_time], [-0.01], marker=CARETUP, markersize=8, color='k', transform=axes[-1].get_xaxis_transform(), clip_on=False)
    axes[-1].annotate(label, xy=(video_time, -0.1), xycoords=('data', 'axes fraction'), ha='center', va='top')

# add protraction box (open rectangle just above the top trace)
left, right = df.loc[0, ['I2 spikes first burst start (s)', 'I2 spikes last burst end (s)']]
bottom, top = (1, 1.15)
width = right-left
height = top-bottom
rect = patches.Rectangle((left, bottom), width, height, linewidth=2, ls='-', edgecolor='k', fill=False, zorder=3, clip_on=False, transform=axes[0].get_xaxis_transform())
axes[0].add_patch(rect)

# add retraction box (filled rectangle just above the top trace)
left, right = df.loc[0, ['I2 spikes last burst end (s)', 'End (s)']] # behavior ends with end of B43 burst
bottom, top = (1, 1.15)
width = right-left
height = top-bottom
rect = patches.Rectangle((left, bottom), width, height, linewidth=2, ls='-', edgecolor='k', facecolor='k', fill=True, zorder=3, clip_on=False, transform=axes[0].get_xaxis_transform())
axes[0].add_patch(rect)

fig.tight_layout(h_pad=0, w_pad=0, pad=0)

fig.savefig(os.path.join(export_dir, 'figure-1A.png'), dpi=600)

end = datetime.datetime.now()
print('render time:', end-start)

### 🐌 Figures 1B & 1C

Video frames (see the code for Figure 1A for the times of these video frames, which are marked B and C there).

## [FIGURE 2]

### 🐌 Figure 2A

NOTE: The plot produced by this code spans much more time than will appear in the final figure, so it must be cropped manually. It is rendered with the same time span as Fig 2B so that the two time scales are identical. Signals after the crop time (`t_crop`) are set to zero.

In [None]:
# Figure 2A: a feeding bout on regular nori (one bite followed by swallows),
# rendered over the same time span as Figure 2B (167.7 s) so that the time
# scales match; the output image is cropped by hand at t_crop.
start = datetime.datetime.now()

data_set_name = 'IN VIVO / JG12 / 2019-05-10 / 002'
animal = 'JG12'
t_start, t_stop = [223.4, 391.1] * pq.s # t=273.71, twidth=167.7, 1 bite + a few regular nori strip swallows
t_crop = 264.4 * pq.s  # signals are zeroed after this time (manual crop point)
plots = [
    {'channel': 'I2',       'units': 'uV', 'ylim': [ -60,  60], 'scalebar': 50}, #, 'decimation_factor': 400},
    {'channel': 'RN',       'units': 'uV', 'ylim': [ -25,  25], 'scalebar': 25}, #, 'decimation_factor': 400},
    {'channel': 'BN2',      'units': 'uV', 'ylim': [ -45,  45], 'scalebar': 50}, #, 'decimation_factor': 400},
    {'channel': 'BN3-DIST', 'units': 'uV', 'ylim': [ -60,  60], 'scalebar': 50, 'ylabel': 'BN3'}, #, 'decimation_factor': 400},
    {'channel': 'Force',    'units': 'mN', 'ylim': [ -50, 450], 'scalebar': 300}, #, 'decimation_factor': 100},
]

# NOTE(review): prettyplot_with_scalebars only uses 'formats' and 'dpi' when
# outfile_basename is passed, which it is not here; the image is written by
# the fig.savefig call further below instead
kwargs = dict(
    formats = ['png'],
    dpi = 600,
    figsize = (14, 5),
    linewidth = 0.5,
    x_scalebar = 10*pq.s,
)

# load the metadata
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)

# ensure that the right filters are used
metadata['filters'] = sig_filters_by_animal[animal]

# load the data
blk = neurotic.load_dataset(metadata, lazy=True)

# manually perform filters
blk = apply_filters(blk, metadata)

# zero the signals after the time where the figure should be manually cropped
# (each signal in the block is replaced by its [t_start, t_stop] slice with the
# samples from t_crop onward set to zero)
for i, sig in enumerate(blk.segments[0].analogsignals):
    sig = sig.time_slice(t_start, t_stop) # load needed if lazy
    sig[sig.time_index(t_crop):] = 0*sig.units
    blk.segments[0].analogsignals[i] = sig

# plot the data
fig, axes = prettyplot_with_scalebars(blk, t_start, t_stop, plots, **kwargs)

fig.savefig(os.path.join(export_dir, 'figure-2A-needs-cropped.png'), dpi=600)

end = datetime.datetime.now()
print('render time:', end-start)

### 🐌 Figure 2B

In [None]:
# Figure 2B: a feeding bout on tape nori (one bite followed by swallows) with a
# dashed force zero baseline and a gray band marking the exemplary swallow
# that is expanded in other figures.
start = datetime.datetime.now()

data_set_name = 'IN VIVO / JG12 / 2019-05-10 / 002'
animal = 'JG12'
t_start, t_stop = [2875.3, 3043] * pq.s # t=2925.61, twidth=167.7, 1 bite + 19 tape nori swallows
plots = [
    {'channel': 'I2',       'units': 'uV', 'ylim': [ -60,  60], 'scalebar': 50}, #, 'decimation_factor': 400},
    {'channel': 'RN',       'units': 'uV', 'ylim': [ -25,  25], 'scalebar': 25}, #, 'decimation_factor': 400},
    {'channel': 'BN2',      'units': 'uV', 'ylim': [ -45,  45], 'scalebar': 50}, #, 'decimation_factor': 400},
    {'channel': 'BN3-DIST', 'units': 'uV', 'ylim': [ -60,  60], 'scalebar': 50, 'ylabel': 'BN3'}, #, 'decimation_factor': 400},
    {'channel': 'Force',    'units': 'mN', 'ylim': [ -50, 450], 'scalebar': 300}, #, 'decimation_factor': 100},
]

# NOTE(review): prettyplot_with_scalebars only uses 'formats' and 'dpi' when
# outfile_basename is passed, which it is not here; the image is written by
# the fig.savefig call further below instead
kwargs = dict(
    formats = ['png'],
    dpi = 600,
    figsize = (14, 5),
    linewidth = 0.5,
    x_scalebar = 10*pq.s,
)

# load the metadata
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)

# ensure that the right filters are used
metadata['filters'] = sig_filters_by_animal[animal]

# load the data
blk = neurotic.load_dataset(metadata, lazy=True)

# manually perform filters
blk = apply_filters(blk, metadata)

# plot the data
fig, axes = prettyplot_with_scalebars(blk, t_start, t_stop, plots, **kwargs)

# plot a zero baseline for force
ax = axes[-1]
force_zero = -11.5 # mN, average before animal began swallowing, not zero because of DC offset
ax.axhline(force_zero, color='0.5', lw=1, ls='--', zorder=-1)

# plot a gray rectangle highlighting portion of the record expanded in other figures
exemplar_plot_range = feeding_bouts[exemplary_swallow][1]
# lowering the last axes' zorder makes it draw before the other axes, so the
# band (drawn in the last axes) stays behind the traces above it
axes[-1].set_zorder(-1) # needed for trick below
# ymin/ymax are axes fractions: ymax=999 with clip_on=False stretches the band
# far above the last axes, covering every trace
axes[-1].axvspan(
    exemplar_plot_range[0], exemplar_plot_range[1],
    0, 999, clip_on=False, # trick for getting to span entire figure
    facecolor='0.8', edgecolor=None, lw=0, zorder=-2)

fig.savefig(os.path.join(export_dir, 'figure-2B.png'), dpi=600)

end = datetime.datetime.now()
print('render time:', end-start)

### 🐌 Figure 2C

In [None]:
# Figure 2C: per-animal mean time from the start of one motor pattern to the
# start of the next, unloaded (regular nori) vs. loaded (tape nori), with a
# gray line joining each animal's two means.
fig, ax = plt.subplots(1, 1, figsize=(4, 5))

df = df_all.reset_index()
df.loc[df['Food'] == 'Regular nori', 'Food'] = 'Unloaded'
df.loc[df['Food'] == 'Tape nori',    'Food'] = 'Loaded'

x = 'Food'
y = 'Start to next start (s)'
hue = 'Animal'

# per-animal, per-condition means (computed once and reused below)
animal_food_means = df.groupby(['Animal', 'Food'])[y].mean()

sns.stripplot(
    x=x, y=y, hue=hue,
    data=animal_food_means.reset_index(),
    order=['Unloaded', 'Loaded'],
    jitter=False,
    color='0.25',
    ax=ax,
)
if ax.get_legend() is not None:
    ax.get_legend().remove()

# one row per animal with columns in plotting order, so each line always joins
# the SAME animal's two means; an animal missing a condition gets NaN there
# (no line drawn) instead of shifting the pairing of the other animals
paired_means = animal_food_means.unstack('Food').reindex(columns=['Unloaded', 'Loaded'])
ax.plot(paired_means.to_numpy().T, color='0.5')

ax.set_xlim([-0.5, 1.5])
ax.set_ylim([0, 10])
ax.set_xlabel(None)
ax.set_ylabel('Time from one motor pattern to the next (s)')
fig.tight_layout()

fig.savefig(os.path.join(export_dir, 'figure-2C.png'), dpi=300)

animal_food_means

In [None]:
# Statistics for Figure 2C: compare per-animal mean start-to-next-start times
# on tape nori (loaded) vs. regular nori (unloaded).
# NOTE(review): differences_test is defined elsewhere in this notebook.
y = 'Start to next start (s)'
y_tape_nori = df_all.query('Food == "Tape nori"').groupby('Animal')[y].mean()
y_reg_nori = df_all.query('Food == "Regular nori"').groupby('Animal')[y].mean()

differences_test(y_tape_nori, y_reg_nori)

In [None]:
# fig, ax = plt.subplots(1, 1, figsize=(4, 4))

# df = df_all.reset_index()
# df.loc[df['Food'] == 'Regular nori', 'Food'] = 'Unloaded'
# df.loc[df['Food'] == 'Tape nori',    'Food'] = 'Loaded'

# x = 'Food'
# y = 'Start to next start (s)'
# hue = 'Animal'
# show_points = True

# boxplot_with_points(x, y, hue, df, ax, show_points)
# ax.set_ylim([0, 10])
# ax.set_ylabel('Time from one motor pattern to the next (s)')
# plt.tight_layout()

### 🐌 Figure 2D

In [None]:
# Figure 2D: per-animal mean motor pattern duration, unloaded (regular nori)
# vs. loaded (tape nori), with a gray line joining each animal's two means.
fig, ax = plt.subplots(1, 1, figsize=(4, 5))

df = df_all.reset_index()
df.loc[df['Food'] == 'Regular nori', 'Food'] = 'Unloaded'
df.loc[df['Food'] == 'Tape nori',    'Food'] = 'Loaded'

x = 'Food'
y = 'Duration (s)'
hue = 'Animal'

# per-animal, per-condition means (computed once and reused below)
animal_food_means = df.groupby(['Animal', 'Food'])[y].mean()

sns.stripplot(
    x=x, y=y, hue=hue,
    data=animal_food_means.reset_index(),
    order=['Unloaded', 'Loaded'],
    jitter=False,
    color='0.25',
    ax=ax,
)
if ax.get_legend() is not None:
    ax.get_legend().remove()

# one row per animal with columns in plotting order, so each line always joins
# the SAME animal's two means; an animal missing a condition gets NaN there
# (no line drawn) instead of shifting the pairing of the other animals
paired_means = animal_food_means.unstack('Food').reindex(columns=['Unloaded', 'Loaded'])
ax.plot(paired_means.to_numpy().T, color='0.5')

ax.set_xlim([-0.5, 1.5])
ax.set_ylim([0, 10])
ax.set_xlabel(None)
ax.set_ylabel('Motor pattern duration (s)')
fig.tight_layout()

fig.savefig(os.path.join(export_dir, 'figure-2D.png'), dpi=300)

animal_food_means

In [None]:
# Statistics for Figure 2D: compare per-animal mean motor pattern durations
# on tape nori (loaded) vs. regular nori (unloaded).
# NOTE(review): differences_test is defined elsewhere in this notebook.
y = 'Duration (s)'
y_tape_nori = df_all.query('Food == "Tape nori"').groupby('Animal')[y].mean()
y_reg_nori = df_all.query('Food == "Regular nori"').groupby('Animal')[y].mean()

differences_test(y_tape_nori, y_reg_nori)

In [None]:
# fig, ax = plt.subplots(1, 1, figsize=(4, 4))

# df = df_all.reset_index()
# df.loc[df['Food'] == 'Regular nori', 'Food'] = 'Unloaded'
# df.loc[df['Food'] == 'Tape nori',    'Food'] = 'Loaded'

# x = 'Food'
# y = 'Duration (s)'
# hue = 'Animal'
# show_points = True

# boxplot_with_points(x, y, hue, df, ax, show_points)
# ax.set_ylim([0, 10])
# ax.set_ylabel('Motor pattern duration (s)')
# plt.tight_layout()

### 🐌 Figure 2E

In [None]:
# Figure 2E: per-animal mean time from the end of one motor pattern to the
# start of the next (negative values mean overlap), unloaded (regular nori)
# vs. loaded (tape nori), with a gray line joining each animal's two means.
fig, ax = plt.subplots(1, 1, figsize=(4, 5))

df = df_all.reset_index()
df.loc[df['Food'] == 'Regular nori', 'Food'] = 'Unloaded'
df.loc[df['Food'] == 'Tape nori',    'Food'] = 'Loaded'

x = 'Food'
y = 'End to next start (s)'
hue = 'Animal'

# per-animal, per-condition means (computed once and reused below)
animal_food_means = df.groupby(['Animal', 'Food'])[y].mean()

sns.stripplot(
    x=x, y=y, hue=hue,
    data=animal_food_means.reset_index(),
    order=['Unloaded', 'Loaded'],
    jitter=False,
    color='0.25',
    ax=ax,
)
if ax.get_legend() is not None:
    ax.get_legend().remove()

# one row per animal with columns in plotting order, so each line always joins
# the SAME animal's two means; an animal missing a condition gets NaN there
# (no line drawn) instead of shifting the pairing of the other animals
paired_means = animal_food_means.unstack('Food').reindex(columns=['Unloaded', 'Loaded'])
ax.plot(paired_means.to_numpy().T, color='0.5')

# dotted reference line at zero gap between motor patterns
ax.axhline(y=0, ls=':', c='gray', zorder=-1)
ax.set_xlim([-0.5, 1.5])
ax.set_ylim([-1, 2])
ax.set_xlabel(None)
ax.set_ylabel('Time between motor patterns (s)')
fig.tight_layout()

fig.savefig(os.path.join(export_dir, 'figure-2E.png'), dpi=300)

animal_food_means

In [None]:
# Statistics for Figure 2E: compare per-animal mean end-to-next-start times
# on tape nori (loaded) vs. regular nori (unloaded).
# NOTE(review): differences_test is defined elsewhere in this notebook.
y = 'End to next start (s)'
y_tape_nori = df_all.query('Food == "Tape nori"').groupby('Animal')[y].mean()
y_reg_nori = df_all.query('Food == "Regular nori"').groupby('Animal')[y].mean()

differences_test(y_tape_nori, y_reg_nori)

In [None]:
# fig, ax = plt.subplots(1, 1, figsize=(4, 4))

# df = df_all.reset_index()
# df.loc[df['Food'] == 'Regular nori', 'Food'] = 'Unloaded'
# df.loc[df['Food'] == 'Tape nori',    'Food'] = 'Loaded'

# x = 'Food'
# y = 'End to next start (s)'
# hue = 'Animal'
# show_points = True

# boxplot_with_points(x, y, hue, df, ax, show_points)
# ax.axhline(y=0, ls=':', c='gray', zorder=-1)
# ax.set_ylabel('Time between motor patterns (s)')
# plt.tight_layout()

---

## [FIGURE 3]

### 🐌 Figure 3A

Biomechanics schematic

### 🐌 Figure 3B

In [None]:
# Figure 3B: exemplary swallow showing the I2 trace and force, with the I2
# burst window shaded across the whole figure and force markers at the force
# shoulder end (down caret) and the force minimum (up caret).
start = datetime.datetime.now()

df = df_exemplary_swallow
(data_set_name, time_window) = feeding_bouts[exemplary_swallow]
animal, _, _ = exemplary_swallow

t_start, t_stop = time_window*pq.s
plots = [
    {'channel': 'I2',       'units': 'uV', 'ylim': [ -50,  50], 'scalebar': 50}, #, 'decimation_factor': 100},
#     {'channel': 'RN',       'units': 'uV', 'ylim': [ -25,  25], 'scalebar': 25}, #, 'decimation_factor': 100},
#     {'channel': 'BN2',      'units': 'uV', 'ylim': [ -45,  45], 'scalebar': 50}, #, 'decimation_factor': 100},
#     {'channel': 'BN3-DIST', 'units': 'uV', 'ylim': [ -60,  60], 'scalebar': 50, 'ylabel': 'BN3'}, #, 'decimation_factor': 100},
    {'channel': 'Force',    'units': 'mN', 'ylim': [ -20, 375], 'scalebar': 300}, #, 'decimation_factor': 100},
]

kwargs = dict(
    figsize = (12, 1), #1.5),
    linewidth = 0.5,
    x_scalebar = None, #5*pq.s,
)

# load the metadata
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)

# ensure that the right filters are used
metadata['filters'] = sig_filters_by_animal[animal]

# load the data
blk = neurotic.load_dataset(metadata, lazy=True)

# manually perform filters
blk = apply_filters(blk, metadata)

# plot the signal
fig, axes = prettyplot_with_scalebars(blk, t_start, t_stop, plots, **kwargs)

units = [
    'I2 spikes',
#     'B8a/b',
#     'B6/B9',
#     'B3',
#     'B38',
]

# # plot a zero baseline for force
# ax = axes[-1]
# force_zero = -11.5 # mN, average before animal began swallowing, not zero because of DC offset
# ax.axhline(force_zero, color='0.75', lw=1, ls='--', zorder=-1)

for j, i in enumerate(df.index):
    
    # plot force markers
    force_shoulder_end = df.loc[i, 'Force shoulder end start (s)']*pq.s
    force_shoulder_end_value = df.loc[i, 'Force shoulder end value (mN)']*pq.mN
    force_min_time = df.loc[i, 'Force minimum time (s)']*pq.s
    force_min = df.loc[i, 'Force minimum (mN)']*pq.mN
    ax = axes[-1]
    ax.plot([force_shoulder_end], [force_shoulder_end_value], marker=CARETDOWN, markersize=8, color='k')
    ax.plot([force_min_time],     [force_min],                marker=CARETUP,   markersize=8, color='k')
    
    for k, unit in enumerate(units):
                
        # plot burst windows
        burst_start = df.loc[i, unit+' first burst start (s)']
        if np.isfinite(burst_start) and burst_start < t_stop:
            burst_end = df.loc[i, unit+' last burst end (s)']
            # the span is drawn with clip_on=False, so clip its right edge to
            # the plotted window (same hack as in the Figure 3C-3E cells)
            burst_end = min(burst_end, t_stop) # hack to prevent unclipped rect entering margin
            # lowering the last axes' zorder keeps the span (drawn in it)
            # behind the traces of the axes above
            axes[-1].set_zorder(-1) # needed for trick below
            # ymax=999 (axes fraction) with clip_on=False stretches the span
            # upward across all traces
            axes[-1].axvspan(
                burst_start, burst_end,
                0, 999, clip_on=False, # trick for getting to span entire figure
                facecolor=lighten_color(unit_colors[unit], amount=0.7), edgecolor=None, lw=0)

fig.savefig(os.path.join(export_dir, 'figure-3B.png'), dpi=600)

end = datetime.datetime.now()
print('render time:', end-start)

### 🐌 Figure 3C

In [None]:
# Figure 3C: exemplary swallow showing the BN2 trace and force, with the B38
# burst window shaded across the whole figure and a marker (down caret) at the
# force shoulder end.
start = datetime.datetime.now()

df = df_exemplary_swallow
(data_set_name, time_window) = feeding_bouts[exemplary_swallow]
animal, _, _ = exemplary_swallow

t_start, t_stop = time_window*pq.s
plots = [
#     {'channel': 'I2',       'units': 'uV', 'ylim': [ -50,  50], 'scalebar': 50}, #, 'decimation_factor': 100},
#     {'channel': 'RN',       'units': 'uV', 'ylim': [ -25,  25], 'scalebar': 25}, #, 'decimation_factor': 100},
    {'channel': 'BN2',      'units': 'uV', 'ylim': [ -45,  45], 'scalebar': 50}, #, 'decimation_factor': 100},
#     {'channel': 'BN3-DIST', 'units': 'uV', 'ylim': [ -60,  60], 'scalebar': 50, 'ylabel': 'BN3'}, #, 'decimation_factor': 100},
    {'channel': 'Force',    'units': 'mN', 'ylim': [ -20, 375], 'scalebar': 300}, #, 'decimation_factor': 100},
]

kwargs = dict(
    figsize = (12, 1), #1.5),
    linewidth = 0.5,
    x_scalebar = None, #5*pq.s,
)

# load the metadata
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)

# ensure that the right filters are used
metadata['filters'] = sig_filters_by_animal[animal]

# load the data
blk = neurotic.load_dataset(metadata, lazy=True)

# manually perform filters
blk = apply_filters(blk, metadata)

# plot the signal
fig, axes = prettyplot_with_scalebars(blk, t_start, t_stop, plots, **kwargs)

units = [
#     'I2 spikes',
#     'B8a/b',
#     'B6/B9',
#     'B3',
    'B38',
]

# # plot a zero baseline for force
# ax = axes[-1]
# force_zero = -11.5 # mN, average before animal began swallowing, not zero because of DC offset
# ax.axhline(force_zero, color='0.75', lw=1, ls='--', zorder=-1)

for j, i in enumerate(df.index):
    
    # plot force markers
    force_shoulder_end = df.loc[i, 'Force shoulder end start (s)']*pq.s
    force_shoulder_end_value = df.loc[i, 'Force shoulder end value (mN)']*pq.mN
    ax = axes[-1]
    ax.plot([force_shoulder_end], [force_shoulder_end_value], marker=CARETDOWN, markersize=8, color='k')
    
    for k, unit in enumerate(units):
                
        # plot burst windows (only bursts that start before the end of the
        # plotted window; the right edge is clipped to t_stop)
        burst_start = df.loc[i, unit+' first burst start (s)']
        if np.isfinite(burst_start) and burst_start < t_stop:
            burst_end = df.loc[i, unit+' last burst end (s)']
            burst_end = min(burst_end, t_stop) # hack to prevent unclipped rect entering margin
            # lowering the last axes' zorder keeps the span (drawn in it)
            # behind the traces of the axes above
            axes[-1].set_zorder(-1) # needed for trick below
            # ymax=999 (axes fraction) with clip_on=False stretches the span
            # upward across all traces
            axes[-1].axvspan(
                burst_start, burst_end,
                0, 999, clip_on=False, # trick for getting to span entire figure
                facecolor=lighten_color(unit_colors[unit], amount=0.7), edgecolor=None, lw=0)

fig.savefig(os.path.join(export_dir, 'figure-3C.png'), dpi=600)

end = datetime.datetime.now()
print('render time:', end-start)

### 🐌 Figure 3D

In [None]:
# Figure 3D: exemplary swallow showing the RN trace and force, with the B8a/b
# burst window shaded across the whole figure and force markers at the force
# rise start (up caret, drawn at the force baseline) and the force plateau end
# (left caret).
start = datetime.datetime.now()

df = df_exemplary_swallow
(data_set_name, time_window) = feeding_bouts[exemplary_swallow]
animal, _, _ = exemplary_swallow

t_start, t_stop = time_window*pq.s
plots = [
#     {'channel': 'I2',       'units': 'uV', 'ylim': [ -50,  50], 'scalebar': 50}, #, 'decimation_factor': 100},
    {'channel': 'RN',       'units': 'uV', 'ylim': [ -25,  25], 'scalebar': 25}, #, 'decimation_factor': 100},
#     {'channel': 'BN2',      'units': 'uV', 'ylim': [ -45,  45], 'scalebar': 50}, #, 'decimation_factor': 100},
#     {'channel': 'BN3-DIST', 'units': 'uV', 'ylim': [ -60,  60], 'scalebar': 50, 'ylabel': 'BN3'}, #, 'decimation_factor': 100},
    {'channel': 'Force',    'units': 'mN', 'ylim': [ -20, 375], 'scalebar': 300}, #, 'decimation_factor': 100},
]

kwargs = dict(
    figsize = (12, 1), #1.5),
    linewidth = 0.5,
    x_scalebar = None, #5*pq.s,
)

# load the metadata
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)

# ensure that the right filters are used
metadata['filters'] = sig_filters_by_animal[animal]

# load the data
blk = neurotic.load_dataset(metadata, lazy=True)

# manually perform filters
blk = apply_filters(blk, metadata)

# plot the signal
fig, axes = prettyplot_with_scalebars(blk, t_start, t_stop, plots, **kwargs)

units = [
#     'I2 spikes',
    'B8a/b',
#     'B6/B9',
#     'B3',
#     'B38',
]

# # plot a zero baseline for force
# ax = axes[-1]
# force_zero = -11.5 # mN, average before animal began swallowing, not zero because of DC offset
# ax.axhline(force_zero, color='0.75', lw=1, ls='--', zorder=-1)

for j, i in enumerate(df.index):
    
    # plot force markers
    force_rise_start = df.loc[i, 'Force rise start start (s)']*pq.s
    force_baseline = df.loc[i, 'Force baseline (mN)']*pq.mN
    force_plateau_end = df.loc[i, 'Force plateau end start (s)']*pq.s
    force_plateau_end_value = df.loc[i, 'Force plateau end value (mN)']*pq.mN
    ax = axes[-1]
    ax.plot([force_rise_start],  [force_baseline],          marker=CARETUP,   markersize=8, color='k')
    ax.plot([force_plateau_end], [force_plateau_end_value], marker=CARETLEFT, markersize=8, color='k')
    
    for k, unit in enumerate(units):
                
        # plot burst windows (only bursts that start before the end of the
        # plotted window; the right edge is clipped to t_stop)
        burst_start = df.loc[i, unit+' first burst start (s)']
        if np.isfinite(burst_start) and burst_start < t_stop:
            burst_end = df.loc[i, unit+' last burst end (s)']
            burst_end = min(burst_end, t_stop) # hack to prevent unclipped rect entering margin
            # lowering the last axes' zorder keeps the span (drawn in it)
            # behind the traces of the axes above
            axes[-1].set_zorder(-1) # needed for trick below
            # ymax=999 (axes fraction) with clip_on=False stretches the span
            # upward across all traces
            axes[-1].axvspan(
                burst_start, burst_end,
                0, 999, clip_on=False, # trick for getting to span entire figure
                facecolor=lighten_color(unit_colors[unit], amount=0.7), edgecolor=None, lw=0)

fig.savefig(os.path.join(export_dir, 'figure-3D.png'), dpi=600)

end = datetime.datetime.now()
print('render time:', end-start)

### 🐌 Figure 3E

In [None]:
# Figure 3E: exemplary swallow showing the BN3 (BN3-DIST channel) trace and
# force, with the B4/B5 burst window shaded across the whole figure and a
# marker (up caret, drawn at the force baseline) at the force rise start.
start = datetime.datetime.now()

df = df_exemplary_swallow
(data_set_name, time_window) = feeding_bouts[exemplary_swallow]
animal, _, _ = exemplary_swallow

t_start, t_stop = time_window*pq.s
plots = [
#     {'channel': 'I2',       'units': 'uV', 'ylim': [ -50,  50], 'scalebar': 50}, #, 'decimation_factor': 100},
#     {'channel': 'RN',       'units': 'uV', 'ylim': [ -25,  25], 'scalebar': 25}, #, 'decimation_factor': 100},
#     {'channel': 'BN2',      'units': 'uV', 'ylim': [ -45,  45], 'scalebar': 50}, #, 'decimation_factor': 100},
    {'channel': 'BN3-DIST', 'units': 'uV', 'ylim': [ -60,  60], 'scalebar': 50, 'ylabel': 'BN3'}, #, 'decimation_factor': 100},
    {'channel': 'Force',    'units': 'mN', 'ylim': [ -20, 375], 'scalebar': 300}, #, 'decimation_factor': 100},
]

kwargs = dict(
    figsize = (12, 1), #1.5),
    linewidth = 0.5,
    x_scalebar = None, #5*pq.s,
)

# load the metadata
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)

# ensure that the right filters are used
metadata['filters'] = sig_filters_by_animal[animal]

# load the data
blk = neurotic.load_dataset(metadata, lazy=True)

# manually perform filters
blk = apply_filters(blk, metadata)

# plot the signal
fig, axes = prettyplot_with_scalebars(blk, t_start, t_stop, plots, **kwargs)

units = [
#     'I2 spikes',
#     'B8a/b',
#     'B6/B9',
#     'B3',
#     'B38',
    'B4/B5',
]

# # plot a zero baseline for force
# ax = axes[-1]
# force_zero = -11.5 # mN, average before animal began swallowing, not zero because of DC offset
# ax.axhline(force_zero, color='0.75', lw=1, ls='--', zorder=-1)

for j, i in enumerate(df.index):
    
    # plot force markers
    force_rise_start = df.loc[i, 'Force rise start start (s)']*pq.s
    force_baseline = df.loc[i, 'Force baseline (mN)']*pq.mN
    ax = axes[-1]
    ax.plot([force_rise_start], [force_baseline], marker=CARETUP, markersize=8, color='k')
    
    for k, unit in enumerate(units):
                
        # plot burst windows (only bursts that start before the end of the
        # plotted window; the right edge is clipped to t_stop)
        burst_start = df.loc[i, unit+' first burst start (s)']
        if np.isfinite(burst_start) and burst_start < t_stop:
            burst_end = df.loc[i, unit+' last burst end (s)']
            burst_end = min(burst_end, t_stop) # hack to prevent unclipped rect entering margin
            # lowering the last axes' zorder keeps the span (drawn in it)
            # behind the traces of the axes above
            axes[-1].set_zorder(-1) # needed for trick below
            # ymax=999 (axes fraction) with clip_on=False stretches the span
            # upward across all traces
            axes[-1].axvspan(
                burst_start, burst_end,
                0, 999, clip_on=False, # trick for getting to span entire figure
                facecolor=lighten_color(unit_colors[unit], amount=0.7), edgecolor=None, lw=0)

fig.savefig(os.path.join(export_dir, 'figure-3E.png'), dpi=600)

end = datetime.datetime.now()
print('render time:', end-start)

### 🐌 Figure 3F

In [None]:
start = datetime.datetime.now()

# Figure 3F: exemplary swallow showing BN2 and force; B6/B9 burst windows are
# shaded across the full figure height, and carets mark the force rise start
# (up caret at the force baseline) and the force plateau end (left caret).

df = df_exemplary_swallow
(data_set_name, time_window) = feeding_bouts[exemplary_swallow]
animal, _, _ = exemplary_swallow

t_start, t_stop = time_window*pq.s
plots = [
#     {'channel': 'I2',       'units': 'uV', 'ylim': [ -50,  50], 'scalebar': 50}, #, 'decimation_factor': 100},
#     {'channel': 'RN',       'units': 'uV', 'ylim': [ -25,  25], 'scalebar': 25}, #, 'decimation_factor': 100},
    {'channel': 'BN2',      'units': 'uV', 'ylim': [ -45,  45], 'scalebar': 50}, #, 'decimation_factor': 100},
#     {'channel': 'BN3-DIST', 'units': 'uV', 'ylim': [ -60,  60], 'scalebar': 50, 'ylabel': 'BN3'}, #, 'decimation_factor': 100},
    {'channel': 'Force',    'units': 'mN', 'ylim': [ -20, 375], 'scalebar': 300}, #, 'decimation_factor': 100},
]

kwargs = dict(
    figsize = (12, 1), #1.5),
    linewidth = 0.5,
    x_scalebar = None, #5*pq.s,
)

# load the metadata
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)

# ensure that the right filters are used
metadata['filters'] = sig_filters_by_animal[animal]

# load the data
blk = neurotic.load_dataset(metadata, lazy=True)

# manually perform filters
blk = apply_filters(blk, metadata)

# plot the signal
fig, axes = prettyplot_with_scalebars(blk, t_start, t_stop, plots, **kwargs)

units = [
#     'I2 spikes',
#     'B8a/b',
    'B6/B9',
#     'B3',
#     'B38',
]

# plotted window limits as plain floats (seconds), used to clip burst spans below
t_start_s = float(t_start.rescale(pq.s).magnitude)
t_stop_s = float(t_stop.rescale(pq.s).magnitude)

# # plot a zero baseline for force
# ax = axes[-1]
# force_zero = -11.5 # mN, average before animal began swallowing, not zero because of DC offset
# ax.axhline(force_zero, color='0.75', lw=1, ls='--', zorder=-1)

for i in df.index:

    # plot force markers on the force (bottom) panel
    force_rise_start = df.loc[i, 'Force rise start start (s)']*pq.s
    force_baseline = df.loc[i, 'Force baseline (mN)']*pq.mN
    force_plateau_end = df.loc[i, 'Force plateau end start (s)']*pq.s
    force_plateau_end_value = df.loc[i, 'Force plateau end value (mN)']*pq.mN
    ax = axes[-1]
    ax.plot([force_rise_start],  [force_baseline],          marker=CARETUP,   markersize=8, color='k')
    ax.plot([force_plateau_end], [force_plateau_end_value], marker=CARETLEFT, markersize=8, color='k')

    for unit in units:

        # plot burst windows
        burst_start = df.loc[i, unit+' first burst start (s)']
        burst_end = df.loc[i, unit+' last burst end (s)']
        if not (np.isfinite(burst_start) and np.isfinite(burst_end)):
            continue  # unit did not burst in this swallow

        # clip BOTH ends to the plotted window: the span is drawn with
        # clip_on=False, so any part outside [t_start, t_stop] would otherwise
        # spill into the left or right figure margin
        span_start = max(burst_start, t_start_s)
        span_end = min(burst_end, t_stop_s)
        if span_start >= span_end:
            continue  # burst lies entirely outside the plotted window

        # push the bottom Axes behind the other panels so the tall span below is
        # drawn beneath their traces (presumably their backgrounds are
        # transparent -- confirm in prettyplot_with_scalebars)
        axes[-1].set_zorder(-1)
        axes[-1].axvspan(
            span_start, span_end,
            0, 999, clip_on=False, # ymax far above 1 (axes fraction) so the span covers the whole figure height
            facecolor=lighten_color(unit_colors[unit], amount=0.7), edgecolor=None, lw=0)

fig.savefig(os.path.join(export_dir, 'figure-3F.png'), dpi=600)

end = datetime.datetime.now()
print('render time:', end-start)

### 🐌 Figure 3G

In [None]:
start = datetime.datetime.now()

# Figure 3G: exemplary swallow showing BN2 and force; B3 burst windows are
# shaded across the full figure height, and carets mark the force plateau
# start (right caret) and force plateau end (left caret).

df = df_exemplary_swallow
(data_set_name, time_window) = feeding_bouts[exemplary_swallow]
animal, _, _ = exemplary_swallow

t_start, t_stop = time_window*pq.s
plots = [
#     {'channel': 'I2',       'units': 'uV', 'ylim': [ -50,  50], 'scalebar': 50}, #, 'decimation_factor': 100},
#     {'channel': 'RN',       'units': 'uV', 'ylim': [ -25,  25], 'scalebar': 25}, #, 'decimation_factor': 100},
    {'channel': 'BN2',      'units': 'uV', 'ylim': [ -45,  45], 'scalebar': 50}, #, 'decimation_factor': 100},
#     {'channel': 'BN3-DIST', 'units': 'uV', 'ylim': [ -60,  60], 'scalebar': 50, 'ylabel': 'BN3'}, #, 'decimation_factor': 100},
    {'channel': 'Force',    'units': 'mN', 'ylim': [ -20, 375], 'scalebar': 300}, #, 'decimation_factor': 100},
]

kwargs = dict(
    figsize = (12, 1.5), #1),
    linewidth = 0.5,
    x_scalebar = 5*pq.s, #None,
)

# load the metadata
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)

# ensure that the right filters are used
metadata['filters'] = sig_filters_by_animal[animal]

# load the data
blk = neurotic.load_dataset(metadata, lazy=True)

# manually perform filters
blk = apply_filters(blk, metadata)

# plot the signal
fig, axes = prettyplot_with_scalebars(blk, t_start, t_stop, plots, **kwargs)

units = [
#     'I2 spikes',
#     'B8a/b',
#     'B6/B9',
    'B3',
#     'B38',
]

# plotted window limits as plain floats (seconds), used to clip burst spans below
t_start_s = float(t_start.rescale(pq.s).magnitude)
t_stop_s = float(t_stop.rescale(pq.s).magnitude)

# # plot a zero baseline for force
# ax = axes[-1]
# force_zero = -11.5 # mN, average before animal began swallowing, not zero because of DC offset
# ax.axhline(force_zero, color='0.75', lw=1, ls='--', zorder=-1)

for i in df.index:

    # plot force markers on the force (bottom) panel
    force_plateau_start = df.loc[i, 'Force plateau start start (s)']*pq.s
    force_plateau_start_value = df.loc[i, 'Force plateau start value (mN)']*pq.mN
    force_plateau_end = df.loc[i, 'Force plateau end start (s)']*pq.s
    force_plateau_end_value = df.loc[i, 'Force plateau end value (mN)']*pq.mN
    ax = axes[-1]
    ax.plot([force_plateau_start], [force_plateau_start_value], marker=CARETRIGHT, markersize=8, color='k')
    ax.plot([force_plateau_end],   [force_plateau_end_value],   marker=CARETLEFT,  markersize=8, color='k')

    for unit in units:

        # plot burst windows
        burst_start = df.loc[i, unit+' first burst start (s)']
        burst_end = df.loc[i, unit+' last burst end (s)']
        if not (np.isfinite(burst_start) and np.isfinite(burst_end)):
            continue  # unit did not burst in this swallow

        # clip BOTH ends to the plotted window: the span is drawn with
        # clip_on=False, so any part outside [t_start, t_stop] would otherwise
        # spill into the left or right figure margin
        span_start = max(burst_start, t_start_s)
        span_end = min(burst_end, t_stop_s)
        if span_start >= span_end:
            continue  # burst lies entirely outside the plotted window

        # push the bottom Axes behind the other panels so the tall span below is
        # drawn beneath their traces (presumably their backgrounds are
        # transparent -- confirm in prettyplot_with_scalebars)
        axes[-1].set_zorder(-1)
        axes[-1].axvspan(
            span_start, span_end,
            0, 999, clip_on=False, # ymax far above 1 (axes fraction) so the span covers the whole figure height
            facecolor=lighten_color(unit_colors[unit], amount=0.7), edgecolor=None, lw=0)

fig.savefig(os.path.join(export_dir, 'figure-3G.png'), dpi=600)

end = datetime.datetime.now()
print('render time:', end-start)

## [FIGURE 4]

In [None]:
# normalization_fixed_times_labels = [
#     'prev_force_plateau_start',
#     'prev_force_plateau_end',
#     'prev_force_drop_end',
#     'force_shoulder_end',
#     'force_rise_start',
#     'force_plateau_start',
#     'force_plateau_end',
#     'force_drop_end',
#     'next_force_shoulder_end',
#     'next_force_rise_start'
# ]

In [None]:
# Shared setup for Figure 4: compute the median duration of each force phase
# across tape nori swallows, and from those build a common "median phase" time
# axis (boundaries in seconds) plus phase labels used by Figures 4A and 4B.

# use only tape nori swallows
df = df_all.query('Food == "Tape nori"')

# get all normalization times from previous force drop end to current drop end
# (slice 2:8 keeps 6 event times per swallow; the labels for these positions are
# presumably those in the commented-out list in the cell above -- confirm)
# t has shape (n_swallows, 6)
t = np.array([times.magnitude[2:8] for times in df['Normalization fixed times (s)']])

# get all phase durations
# differences between consecutive event times -> (n_swallows, 5), transposed to
# (5, n_swallows) so each row holds one phase's durations across swallows
all_phase_durations = np.diff(t).T

# find median phase durations
# - 0: partial force maintenance
# - 1: force dip
# - 2: force rise
# - 3: force maintenance
# - 4: major force drop
# nanmedian ignores swallows where a phase boundary is missing (NaN)
median_phase_durations = np.nanmedian(all_phase_durations, axis=1)

# copy phases to represent "previous" and "next" swallows
# - 0: previous force maintenance
# - 1: previous major force drop
# - 2: partial force maintenance
# - 3: force dip
# - 4: force rise
# - 5: force maintenance
# - 6: major force drop
# - 7: next partial force maintenance
# - 8: next force dip
# (overwrites the 5-element array above with this 9-element version)
median_phase_durations = np.concatenate([
    median_phase_durations[-2:],
    median_phase_durations,
    median_phase_durations[:2]
])

# convert durations into boundary timings
# 10 boundaries (seconds), starting at 0, for the 9 phases above
median_phase_boundaries = np.concatenate([[0], median_phase_durations]).cumsum()

# phase_labels = [
#     'IV\nPrevious force\nmaintenance',
#     'V\nPrevious major\nforce drop',
#     'I\nPartial force\nmaintenance',
#     'II\nForce dip',
#     'III\nForce rise',
#     'IV\nForce\nmaintenance',
#     'V\nMajor\nforce drop',
#     'I\nNext partial\nforce maintenance',
#     'II\nNext\nforce dip',
# ]
# one label per phase (9 entries, aligned with median_phase_durations);
# empty strings hide labels for phases at the plot edges
phase_labels = [
    '',
    'V',
    'I',
    'II',
    'III',
    'IV',
    'V',
    '',
    '',
]

# shared settings for Figures 4A and 4B
figure_4_xlim = [1.2, 10.3]  # seconds on the median-phase time axis
figure_4_unit_fontsize = 12

In [None]:
# Box plot of each phase's duration across tape nori swallows (means shown too),
# labeled I-V; NaN durations are dropped per phase before plotting.
plt.figure()

# NOTE(review): matplotlib 3.9 renamed boxplot's `labels` argument to
# `tick_labels`; update if a DeprecationWarning appears.
plt.boxplot(
    [a[np.isfinite(a)] for a in list(all_phase_durations)],
    labels=phase_labels[2:7],
    showmeans=True,
)

plt.ylabel('Duration (s)')
plt.tight_layout()

In [None]:
# Print the median, mean, and number of finite samples of each phase duration
# (phases I-V). Loop names are distinct from the notebook-global `t` (the
# normalization-times array) so running this cell no longer overwrites it.
for durations, label in zip(all_phase_durations, phase_labels[2:7]):
    label = label.replace('\n', ' ')  # flatten multi-line labels (as in the commented-out variant) onto one line
    print(f'{label}:\tmedian {np.nanmedian(durations):g}, mean {np.nanmean(durations):g} (n={durations[np.isfinite(durations)].size})')

### 🐌 Figure 4A

In [None]:
# Figure 4A: for each unit, a box spanning the median first-burst start to the
# median last-burst end (tape nori swallows), with quartile ranges of the start
# and end drawn as bracketed lines, on the median-phase time axis.

# plt.figure(figsize=(9, 5))
plt.figure(figsize=(7, 4))
ax = plt.gca()

units = ['I2 spikes', 'B8a/b', 'B6/B9', 'B3', 'B38', 'B4/B5']

# use only tape nori swallows
df = df_all.query('Food == "Tape nori"')

for i, unit in enumerate(units):
    # burst start times, mapped from normalized time onto the median-phase axis
    t0_data   = df[f'{unit} first burst start (normalized)'].dropna()
    t0_data[:] = unnormalize_time(median_phase_boundaries, t0_data.values)
    t0_median = t0_data.median()
    t0_q1     = t0_data.quantile(0.25)
    t0_q3     = t0_data.quantile(0.75)

    # burst end times, mapped the same way
    t1_data   = df[f'{unit} last burst end (normalized)'].dropna()
    t1_data[:] = unnormalize_time(median_phase_boundaries, t1_data.values)
    t1_median = t1_data.median()
    t1_q1     = t1_data.quantile(0.25)
    t1_q3     = t1_data.quantile(0.75)

    # plot boxes using medians (row i, centered on y = i)
    height = 0.8
    lw = 1
    rect = patches.Rectangle((t0_median, i-height/2), t1_median-t0_median, height, facecolor=unit_colors[unit], edgecolor='k', lw=lw, clip_on=False)
    ax.add_patch(rect)

    # plot quartiles (25% and 75% quantiles): short vertical ticks at each
    # quartile joined by a horizontal line, for both the start and the end
    ax.add_line(mlines.Line2D([t0_q1, t0_q1], [i-0.2, i+0.2], color='k', lw=lw, clip_on=False))
    ax.add_line(mlines.Line2D([t0_q3, t0_q3], [i-0.2, i+0.2], color='k', lw=lw, clip_on=False))
    ax.add_line(mlines.Line2D([t1_q1, t1_q1], [i-0.2, i+0.2], color='k', lw=lw, clip_on=False))
    ax.add_line(mlines.Line2D([t1_q3, t1_q3], [i-0.2, i+0.2], color='k', lw=lw, clip_on=False))
    ax.add_line(mlines.Line2D([t0_q1, t0_q3], [i, i], color='k', lw=lw, clip_on=False))
    ax.add_line(mlines.Line2D([t1_q1, t1_q3], [i, i], color='k', lw=lw, clip_on=False))

###
### LABELS AND ANNOTATIONS
###

# drop some phases
trimmed_median_phase_boundaries = median_phase_boundaries[1:8]
trimmed_phase_labels = phase_labels[1:8]

# dotted vertical lines at each phase boundary
for x in trimmed_median_phase_boundaries:
    plt.axvline(x=x, ls=':', lw=1, c='gray', zorder=-1, clip_on=False)

plt.xlabel(None)
plt.ylabel(None)

sns.despine(bottom=True, left=True)
plt.tick_params(bottom=False, left=False) # disable tick marks
# phase labels are centered between consecutive boundaries
plt.xticks(trimmed_median_phase_boundaries[:-1]+np.diff(trimmed_median_phase_boundaries)/2, trimmed_phase_labels)
# y tick labels: show 'I2 spikes' as 'I2'. Build a NEW list -- the previous
# `units2 = units; units2[0] = 'I2'` aliased `units` and renamed the unit in it too.
units2 = ['I2' if unit == 'I2 spikes' else unit for unit in units]
plt.yticks(range(len(units)), units2, fontsize=figure_4_unit_fontsize)

plt.xlim(figure_4_xlim)
plt.ylim(5.9, -0.5)  # inverted so the first unit is at the top

ax.tick_params(
    axis='x',
    labelsize='small',
)

ax.add_artist(AnchoredScaleBar(
    ax.transData,
    sizex=1,
    labelx='1 s',

    loc='lower right',
    bbox_to_anchor=(1, 0),
    bbox_transform=ax.transAxes,

    pad=0,
    borderpad=1,
    sep=5,
    barwidth=2,
))

# plt.tight_layout(h_pad=0, pad=0)
plt.subplots_adjust(
    left   = 0.08,
    right  = 0.90,
    bottom = 0.10,
    top    = 0.99,
    wspace = 0.2,
    hspace = 0.2,
)

plt.gcf().savefig(os.path.join(export_dir, 'figure-4A.png'), dpi=600)

In [None]:
# Print, for each unit, the median first-burst start and median last-burst end on
# the median-phase time axis (with sample counts), plus their difference.
# Note: "duration" is the difference of the two medians, not the median of
# per-swallow burst durations.

# use only tape nori swallows
df = df_all.query('Food == "Tape nori"')

for i, unit in enumerate(['I2 spikes', 'B8a/b', 'B6/B9', 'B3', 'B38', 'B4/B5']):
    t0_data   = df[f'{unit} first burst start (normalized)'].dropna()
    t0_data[:] = unnormalize_time(median_phase_boundaries, t0_data.values)
    t0_median = t0_data.median()
    t0_N      = t0_data.size
    
    t1_data   = df[f'{unit} last burst end (normalized)'].dropna()
    t1_data[:] = unnormalize_time(median_phase_boundaries, t1_data.values)
    t1_median = t1_data.median()
    t1_N      = t1_data.size

    print(f'{unit}:\t[{t0_median:.2f} (n={t0_N}), {t1_median:.2f} (n={t1_N})], duration: {t1_median-t0_median:.2f}')

### 🐌 Figure 4B

In [None]:
# Figure 4B: one panel per unit plus a bottom force panel. Each panel shows the
# median (thick) and 25%/75% quantiles (thin, lightened) across tape nori
# swallows of the firing rate / force sampled on the normalized-time grid
# `times_interp`, plotted against the median-phase time axis.

# map the normalized-time grid onto the median-phase time axis (seconds)
times_interp_unnormalized = unnormalize_time(median_phase_boundaries, times_interp)

# find the number of data points in one cycle
# NOTE(review): `n` appears to be used only by the commented-out tile_array
# helper below; it looks unused by the live code in this cell.
n = np.where((4 <= times_interp) & (times_interp <= 9))[0].size # 4 = start of rise, 9 = start of next rise

units = [
    'I2 spikes',
    'B8a/b',
    'B6/B9',
    'B3',
    'B38',
    'B4/B5',
]

# figsize = (9, 8)
figsize = (7, 6)
# figsize = (9.5, 10) # dimensions for notebook
# figsize = (11, 8.5) # dimensions for printing
# figsize = (16, 9) # dimensions for filling wide screens
# one row per unit plus a final row for force, all sharing the x axis
fig, axes = plt.subplots(len(units)+1, 1, sharex='col', figsize=figsize)

###
### UNITS
###

# use only tape nori swallows
df = df_all.query('Food == "Tape nori"')

for i, unit in enumerate(units):
    ax = axes[i]
    
    # elevate the Axes for units and remove background colors so that
    # each vertical ConnectionPatch drawn later is visible behind it
    ax.set_zorder(1)
    ax.set_facecolor('none')
    
    # find the firing rate median and quartiles
    # stacks one interpolated rate array per swallow into a 2-D array
    # (assumes every swallow's array has the same length -- TODO confirm)
    firing_rate_data = np.array(list(df[unit+' firing rate, normalized time interpolation (Hz)']))
    firing_rate_median = np.nanmedian(firing_rate_data, axis=0)
    firing_rate_q1 = np.nanquantile(firing_rate_data, q=0.25, axis=0)
    firing_rate_q3 = np.nanquantile(firing_rate_data, q=0.75, axis=0)

    # plot the firing rate median and quartiles (median last so it's on top)
    ax.plot(times_interp_unnormalized, firing_rate_q1,     c=lighten_color(unit_colors[unit], amount=0.7), lw=1, zorder=2)
    ax.plot(times_interp_unnormalized, firing_rate_q3,     c=lighten_color(unit_colors[unit], amount=0.7), lw=1, zorder=2)
    ax.plot(times_interp_unnormalized, firing_rate_median, c=unit_colors[unit], lw=2, zorder=2)

    
    # provisional limits; replaced by the fixed yticks-based limits further below
    ax.set_ylim([0, None])
    if unit == 'I2 spikes':
        ax.set_ylabel('I2', rotation='horizontal', ha='right', va='center', labelpad=10, fontsize=figure_4_unit_fontsize)
    else:
        ax.set_ylabel(unit, rotation='horizontal', ha='right', va='center', labelpad=10, fontsize=figure_4_unit_fontsize)
#     ax.yaxis.set_label_coords(-0.06, 0.5)

# remove right and top plot borders, and remove x-axis
#     sns.despine(ax=ax)#, bottom=True)
    # keep only the right spine (y values are read from right-hand ticks)
    sns.despine(ax=ax, left=True, right=False)#, trim=True)
#     ax.xaxis.set_visible(False)

#     ax.tick_params(bottom=False) # disable tick marks
    # disable tick marks
    ax.tick_params(
        bottom=False,
        left=False,
        right=True,
        labelbottom=False,
        labelleft=False)

#     ax.set_yticks([0, 10, 20])
#     ax.set_yticklabels([0, 10, '20 Hz'])
#     ax.set_yticks(ax.get_yticks()) # trick to prevent tight_layout from changing ticks
#     yticklabels = [f'{y:g}' for y in ax.get_yticks()]
#     yticklabels[-1] += ' Hz'
#     ax.set_yticklabels(yticklabels)
    


#     # add freq scale bar
#     ax.add_artist(AnchoredScaleBar(
#         ax.transData,
#         sizey=10,
#         labely='10 Hz',

#         loc='center left',
#         bbox_to_anchor=(1, 0.5),
#         bbox_transform=ax.transAxes,

#         pad=0,
#         borderpad=1,
#         sep=5,
#         barwidth=2,
#     ))

###
### FORCE
###

ax = axes[-1]

# one interpolated force array per swallow, stacked into 2-D (same assumption as above)
force_data = np.array(list(df['Force, normalized time interpolation (mN)']))

# def tile_array(arr, i_start, i_stop, axis=1):
#     arr = arr.copy()
#     n = i_stop - i_start
    
#     ind_before, ind_after = [slice(None)]*arr.ndim, [slice(None)]*arr.ndim
#     ind_before[axis] = slice(None, i_start)
#     ind_after[axis] = slice(i_stop+1, None)
    
#     # replace values outside of range with NaN
#     arr[tuple(ind_before)] = np.nan
#     arr[tuple(ind_after)] = np.nan
    
#     # construct versions of arr shifted to the left and right
#     arr_nan = np.full_like(arr, np.nan)
#     arr_nan = arr_nan
#     arr_nan = np.full((arr.shape[0], n), np.nan)
#     left = np.concatenate((arr[:, n:], arr_nan), axis=1)
#     right = np.concatenate((arr_nan, arr[:, :-n]), axis=1)
    
#     # merge arrays
#     arr = np.nanmax([
#         arr,
#         left,
#         right
#     ], axis=0) # can axis be generalized?
    
#     return arr

# # replace the "previous" swallow force data with a repeat of the "current" swallow
# force_data = tile_array(force_data, i_start, i_stop)

# find the force median and quartiles
force_median = np.nanmedian(force_data, axis=0)
force_q1 = np.nanquantile(force_data, q=0.25, axis=0)
force_q3 = np.nanquantile(force_data, q=0.75, axis=0)

# plot the force median and quartiles (median last so it's on top)
ax.plot(times_interp_unnormalized, force_q1,     c=lighten_color('k', amount=0.7), lw=1, zorder=2)
ax.plot(times_interp_unnormalized, force_q3,     c=lighten_color('k', amount=0.7), lw=1, zorder=2)
ax.plot(times_interp_unnormalized, force_median, c='k', lw=2, zorder=2)

ax.set_ylim([0, None])
ax.set_ylabel('Force', rotation='horizontal', ha='right', va='center', labelpad=10, fontsize=figure_4_unit_fontsize)
# ax.yaxis.set_label_coords(-0.06, 0.5)

# drop some phases
trimmed_median_phase_boundaries = median_phase_boundaries[1:8]
trimmed_phase_labels = phase_labels[1:8]

# remove right and top plot borders from bottom panel, and set x-label
sns.despine(ax=ax, left=True, right=False)#, trim=True)
ax.set_xlim(figure_4_xlim)
# ax.tick_params(bottom=False) # disable tick marks
# disable tick marks
ax.tick_params(
    bottom=False,
    left=False,
    right=True,
#     labelbottom=False,
    labelleft=False)
ax.tick_params(
    axis='x',
    labelsize='small',
)

# phase labels are centered between consecutive boundaries
ax.set_xticks(trimmed_median_phase_boundaries[:-1]+np.diff(trimmed_median_phase_boundaries)/2)
ax.set_xticklabels(trimmed_phase_labels)

# ax.set_yticks([0, 100, 200])
# ax.set_yticklabels([0, 100, '200 mN'])
# ax.set_yticks(ax.get_yticks()) # trick to prevent tight_layout from changing ticks
# yticklabels = [f'{y:g}' for y in ax.get_yticks()]
# yticklabels[-1] += ' mN'
# ax.set_yticklabels(yticklabels)
       
# plot force phase boundaries in normalized time
# each boundary is one dotted line from the bottom of the force panel (y=0 in
# its x-axis transform) to the top of the first unit panel (y=1), so it spans
# every panel; unit panels were raised and made transparent above for this
for t in trimmed_median_phase_boundaries:
    ax.add_artist(patches.ConnectionPatch(
        xyA=(t, 0), xyB=(t, 1),
        coordsA=ax.get_xaxis_transform(), coordsB=axes[0].get_xaxis_transform(),
        axesA=ax, axesB=axes[0],
        color='gray', lw=1, ls=':', zorder=0))

# add time scale bar
ax.add_artist(AnchoredScaleBar(
    ax.transData,
    sizex=1,
    labelx='1 s',

    loc='upper right',
    bbox_to_anchor=(1, 1),
    bbox_transform=ax.transAxes,

    pad=0,
    borderpad=1,
    sep=5,
    barwidth=2,
))

# fig.tight_layout(h_pad=0, pad=0) # first tight_layout removes excess margins and sets reasonable ylims
# fig.tight_layout(pad=0)

# ylims = [
#     [0, 20], # I2
#     [0, 40], # B8a/b
#     [0, 50], # B6/B9
#     [0, 10], # B3
#     [0, 15], # B38
#     [0, 20], # B4/B5
#     [0, 300], # Force
# ]
# for i, ax in enumerate(axes):
#     ax.set_ylim(ylims[i])

# fixed y ticks per panel (same order as axes: units, then force); each panel's
# y limits are set to span exactly its first..last tick
yticks = [
    [0,  10,  20], # I2
    [0,  20,  40], # B8a/b
    [0,  25,  50], # B6/B9
    [0,   5,  10], # B3
    [0,  10,  20], # B38
    [0,  10,  20], # B4/B5
    [0, 150, 300], # Force
]
for i, ax in enumerate(axes):
    ax.set_ylim([yticks[i][0], yticks[i][-1]])

# unit panels: horizontal grid lines at the ticks, and ' Hz' appended to the top tick label
for i, unit in enumerate(units):
    ax = axes[i]
    ax.grid(axis='y', clip_on=False)
#     ax.set_yticks(ax.get_yticks()) # trick to prevent tight_layout from changing ticks
    ax.set_yticks(yticks[i])
    yticklabels = [f'{y:g}' for y in ax.get_yticks()]
    yticklabels[-1] += ' Hz'
    ax.set_yticklabels(yticklabels)
# force panel: same treatment with ' mN'
ax = axes[-1]
ax.grid(axis='y', clip_on=False)
# ax.set_yticks(ax.get_yticks()) # trick to prevent tight_layout from changing ticks
ax.set_yticks(yticks[-1])
yticklabels = [f'{y:g}' for y in ax.get_yticks()]
yticklabels[-1] += ' mN'
ax.set_yticklabels(yticklabels)

# fig.tight_layout(h_pad=0, pad=0) # second tight_layout makes room for units added to y tick labels
# fig.tight_layout(pad=0)

plt.subplots_adjust(
    left   = 0.08,
    right  = 0.90,
    bottom = 0.05,
    top    = 0.99,
    wspace = 0.2,
    hspace = 0.2,
)

fig.savefig(os.path.join(export_dir, 'figure-4B.png'), dpi=600)

## [FIGURE 5]

### 🐌 Figure 5A

In [None]:
# Figure 5A: per-animal mean B3/B6/B9 burst duration for unloaded (regular nori)
# vs loaded (tape nori) swallows, one point per animal per condition, with a
# gray line joining each animal's two means.
fig, ax = plt.subplots(1, 1, figsize=(4, 4))

y = 'B3/B6/B9 burst duration (s)'

# This column is (re)computed in the Figure 5B section further down the
# notebook; compute it here if it is still missing so this cell also works
# when the notebook is run top to bottom. Existing values are left untouched.
if y not in df_all.columns:
    df_all[y] = \
        df_all[['B6/B9 last burst end (s)',    'B3 last burst end (s)']]   .max(axis=1) - \
        df_all[['B6/B9 first burst start (s)', 'B3 first burst start (s)']].min(axis=1)

df = df_all.reset_index()
df.loc[df['Food'] == 'Regular nori', 'Food'] = 'Unloaded'
df.loc[df['Food'] == 'Tape nori',    'Food'] = 'Loaded'

x = 'Food'
hue = 'Animal'

# per-animal mean for each food condition, indexed by (Animal, Food)
animal_means = df.groupby(['Animal', 'Food'])[y].mean()

sns.stripplot(
    x=x, y=y, hue=hue,
    data=animal_means.reset_index(),
    order=['Unloaded', 'Loaded'],
    jitter=False,
    color='0.25',
)
ax.legend_.remove()

# Join each animal's two means. Pivoting on Animal keeps the Unloaded/Loaded
# values of the same animal paired; the previous version stacked two separately
# grouped Series, which misaligns every later line if an animal lacks one
# condition. Here a missing condition is NaN, so that animal's line is omitted.
paired_means = animal_means.unstack('Food').reindex(columns=['Unloaded', 'Loaded'])
ax.plot(paired_means.to_numpy().T, color='0.5')  # rows -> x = 0 (Unloaded), 1 (Loaded); one line per animal

ax.set_xlim([-0.5, 1.5])
ax.set_ylim([0, 4])
ax.set_xlabel(None)
plt.tight_layout()

plt.gcf().savefig(os.path.join(export_dir, 'figure-5A.png'), dpi=300)

# display the per-animal means as the cell output
animal_means

In [None]:
# Compare per-animal mean B3/B6/B9 burst duration between tape nori (loaded) and
# regular nori (unloaded) swallows. Each input is a Series of means indexed by
# Animal.
# NOTE(review): differences_test is defined elsewhere; presumably a paired test
# across animals -- confirm it aligns the two Series by their Animal index.
y = 'B3/B6/B9 burst duration (s)'
y_tape_nori = df_all.query('Food == "Tape nori"').groupby('Animal')[y].mean()
y_reg_nori = df_all.query('Food == "Regular nori"').groupby('Animal')[y].mean()

differences_test(y_tape_nori, y_reg_nori)

In [None]:
# fig, ax = plt.subplots(1, 1, figsize=(4, 4))

# df = df_all.reset_index()
# df.loc[df['Food'] == 'Regular nori', 'Food'] = 'Unloaded'
# df.loc[df['Food'] == 'Tape nori',    'Food'] = 'Loaded'

# x = 'Food'
# y = 'B3/B6/B9 burst duration (s)'
# hue = 'Animal'
# show_points = True

# boxplot_with_points(x, y, hue, df, ax, show_points)
# ax.set_ylim([0, None])
# plt.tight_layout()

### 🐌 Figure 5B

In [None]:
# Combined B3/B6/B9 burst duration: from the earliest first-burst start of either
# B6/B9 or B3 to the latest last-burst end of either. fmax/fmin skip NaN, so a
# swallow where only one of the two units bursts still gets a duration.
df_all['B3/B6/B9 burst duration (s)'] = (
    np.fmax(df_all['B6/B9 last burst end (s)'], df_all['B3 last burst end (s)'])
    - np.fmin(df_all['B6/B9 first burst start (s)'], df_all['B3 first burst start (s)'])
)

In [None]:
# Figure 5B: force plateau (labeled "maintenance") duration vs combined
# B3/B6/B9 burst duration for tape nori swallows from all five animals, with a
# dotted y = x reference line.
plt.figure(figsize=(5, 5))
ax = plt.gca()

data_subsets = [
    'JG07 Tape nori',
    'JG08 Tape nori',
    'JG11 Tape nori',
    'JG12 Tape nori',
    'JG14 Tape nori',
]
# map each label to the query produced by label2query (defined elsewhere)
data_subsets = {label:label2query(label) for label in data_subsets}

# xlabel, xlim = 'B3 all bursts duration (s)', [0, 6]
# xlabel, xlim = 'B6/B9 all bursts duration (s)', [0, 6]
xlabel, xlim = 'B3/B6/B9 burst duration (s)', [0, 6]
# ylabel selects the data column; ylabel_alt is the axis label shown instead
ylabel, ylabel_alt, ylim = 'Force plateau duration (s)', 'Force maintenance duration (s)', [0, 6]

scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim, trend=True, trend_separately=True, tooltips=False)
# dotted identity line; 999 just extends it well past the 0-6 s limits
# (assumes scatter2d fixes the axis limits so autoscaling does not follow it -- confirm)
plt.plot([0, 999], [0, 999], ls=':', c='gray', zorder=0)
plt.ylabel(ylabel_alt)
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

plt.gcf().savefig(os.path.join(export_dir, 'figure-5B.png'), dpi=300)

In [None]:
# Figure 5B variant for animal JG07 only: force plateau (labeled "maintenance")
# duration vs combined B3/B6/B9 burst duration, in black, with a y = x line.
plt.figure(figsize=(5, 5))
ax = plt.gca()

data_subsets = [
    'JG07 Tape nori',
#     'JG08 Tape nori',
#     'JG11 Tape nori',
#     'JG12 Tape nori',
#     'JG14 Tape nori',
]
data_subsets = {label:label2query(label) for label in data_subsets}

# xlabel, xlim = 'B3 all bursts duration (s)', [0, 6]
# xlabel, xlim = 'B6/B9 all bursts duration (s)', [0, 6]
xlabel, xlim = 'B3/B6/B9 burst duration (s)', [0, 6]
# ylabel selects the data column; ylabel_alt is the axis label shown instead
ylabel, ylabel_alt, ylim = 'Force plateau duration (s)', 'Force maintenance duration (s)', [0, 6]

scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim, trend_separately=True, tooltips=False, colors=['k'])
plt.plot([0, 999], [0, 999], ls=':', c='gray', zorder=0)  # dotted y = x reference line
plt.ylabel(ylabel_alt)
# ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

plt.gcf().savefig(os.path.join(export_dir, 'figure-5B-JG07.png'), dpi=300)

In [None]:
# Figure 5B variant for animal JG08 only: force plateau (labeled "maintenance")
# duration vs combined B3/B6/B9 burst duration, in black, with a y = x line.
plt.figure(figsize=(5, 5))
ax = plt.gca()

data_subsets = [
#     'JG07 Tape nori',
    'JG08 Tape nori',
#     'JG11 Tape nori',
#     'JG12 Tape nori',
#     'JG14 Tape nori',
]
data_subsets = {label:label2query(label) for label in data_subsets}

# xlabel, xlim = 'B3 all bursts duration (s)', [0, 6]
# xlabel, xlim = 'B6/B9 all bursts duration (s)', [0, 6]
xlabel, xlim = 'B3/B6/B9 burst duration (s)', [0, 6]
# ylabel selects the data column; ylabel_alt is the axis label shown instead
ylabel, ylabel_alt, ylim = 'Force plateau duration (s)', 'Force maintenance duration (s)', [0, 6]

scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim, trend_separately=True, tooltips=False, colors=['k'])
plt.plot([0, 999], [0, 999], ls=':', c='gray', zorder=0)  # dotted y = x reference line
plt.ylabel(ylabel_alt)
# ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

plt.gcf().savefig(os.path.join(export_dir, 'figure-5B-JG08.png'), dpi=300)

In [None]:
# Figure 5B variant for animal JG11 only: force plateau (labeled "maintenance")
# duration vs combined B3/B6/B9 burst duration, in black, with a y = x line.
plt.figure(figsize=(5, 5))
ax = plt.gca()

data_subsets = [
#     'JG07 Tape nori',
#     'JG08 Tape nori',
    'JG11 Tape nori',
#     'JG12 Tape nori',
#     'JG14 Tape nori',
]
data_subsets = {label:label2query(label) for label in data_subsets}

# xlabel, xlim = 'B3 all bursts duration (s)', [0, 6]
# xlabel, xlim = 'B6/B9 all bursts duration (s)', [0, 6]
xlabel, xlim = 'B3/B6/B9 burst duration (s)', [0, 6]
# ylabel selects the data column; ylabel_alt is the axis label shown instead
ylabel, ylabel_alt, ylim = 'Force plateau duration (s)', 'Force maintenance duration (s)', [0, 6]

scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim, trend_separately=True, tooltips=False, colors=['k'])
plt.plot([0, 999], [0, 999], ls=':', c='gray', zorder=0)  # dotted y = x reference line
plt.ylabel(ylabel_alt)
# ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

plt.gcf().savefig(os.path.join(export_dir, 'figure-5B-JG11.png'), dpi=300)

In [None]:
# Figure 5B variant for animal JG12 only: force plateau (labeled "maintenance")
# duration vs combined B3/B6/B9 burst duration, in black, with a y = x line.
plt.figure(figsize=(5, 5))
ax = plt.gca()

data_subsets = [
#     'JG07 Tape nori',
#     'JG08 Tape nori',
#     'JG11 Tape nori',
    'JG12 Tape nori',
#     'JG14 Tape nori',
]
data_subsets = {label:label2query(label) for label in data_subsets}

# xlabel, xlim = 'B3 all bursts duration (s)', [0, 6]
# xlabel, xlim = 'B6/B9 all bursts duration (s)', [0, 6]
xlabel, xlim = 'B3/B6/B9 burst duration (s)', [0, 6]
# ylabel selects the data column; ylabel_alt is the axis label shown instead
ylabel, ylabel_alt, ylim = 'Force plateau duration (s)', 'Force maintenance duration (s)', [0, 6]

scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim, trend_separately=True, tooltips=False, colors=['k'])
plt.plot([0, 999], [0, 999], ls=':', c='gray', zorder=0)  # dotted y = x reference line
plt.ylabel(ylabel_alt)
# ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

plt.gcf().savefig(os.path.join(export_dir, 'figure-5B-JG12.png'), dpi=300)

In [None]:
# Figure 5B variant for animal JG14 only: force plateau (labeled "maintenance")
# duration vs combined B3/B6/B9 burst duration, in black, with a y = x line.
plt.figure(figsize=(5, 5))
ax = plt.gca()

data_subsets = [
#     'JG07 Tape nori',
#     'JG08 Tape nori',
#     'JG11 Tape nori',
#     'JG12 Tape nori',
    'JG14 Tape nori',
]
data_subsets = {label:label2query(label) for label in data_subsets}

# xlabel, xlim = 'B3 all bursts duration (s)', [0, 6]
# xlabel, xlim = 'B6/B9 all bursts duration (s)', [0, 6]
xlabel, xlim = 'B3/B6/B9 burst duration (s)', [0, 6]
# ylabel selects the data column; ylabel_alt is the axis label shown instead
ylabel, ylabel_alt, ylim = 'Force plateau duration (s)', 'Force maintenance duration (s)', [0, 6]

scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim, trend_separately=True, tooltips=False, colors=['k'])
plt.plot([0, 999], [0, 999], ls=':', c='gray', zorder=0)  # dotted y = x reference line
plt.ylabel(ylabel_alt)
# ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

plt.gcf().savefig(os.path.join(export_dir, 'figure-5B-JG14.png'), dpi=300)

---

## Random old figures

In [None]:
# Delay from the end of the last I2 burst to the start of the force rise;
# negative values mean the force rise began before I2 stopped bursting.
force_rise_onset = df_all['Force rise start start (s)']
df_all['Delay from I2 end to force rise start (s)'] = force_rise_onset.sub(
    df_all['I2 spikes last burst end (s)']
)

In [None]:
# Per-animal box plot of the delay from I2 burst end to force rise start, with
# a dotted zero line.
plt.figure(figsize=(5, 5))
ax = plt.gca()

df = df_all.reset_index()
# column = 'Delay from I2 end to force min (s)'
column = 'Delay from I2 end to force rise start (s)'

sns.boxplot(y='Animal', x=column, data=df, fliersize=0, whis=999) # whis=999 ensures whiskers go to min and max
# sns.swarmplot(y='Animal', x=column, data=df, color='0.25')

# sns.boxplot(x=column, data=df, fliersize=0, whis=999) # whis=999 ensures whiskers go to min and max
# sns.swarmplot(x=column, data=df, color='0.25')

# plot zero line
plt.axvline(x=0, ls=':', c='gray', zorder=-1)

plt.tight_layout()

# plt.gcf().savefig(os.path.join(export_dir, 'figure-X.png'), dpi=300)

In [None]:
# Delay from the earliest first-burst start of B6/B9 or B3 to the start of the
# force rise; fmin skips NaN so either unit alone defines the onset.
earliest_b3_b6_b9_start = np.fmin(
    df_all['B6/B9 first burst start (s)'],
    df_all['B3 first burst start (s)'],
)
df_all['Delay from B3/B6/B9 start to force rise start (s)'] = (
    df_all['Force rise start start (s)'] - earliest_b3_b6_b9_start
)

In [None]:
# Per-animal box plot, with individual swallows overlaid as a swarm, of the delay
# from B3/B6/B9 onset to force rise start, with a dotted zero line.
plt.figure(figsize=(5, 5))
ax = plt.gca()

df = df_all.reset_index()
# column = 'Delay from B3/B6/B9 start to force start (s)'
column = 'Delay from B3/B6/B9 start to force rise start (s)'

# fliersize=0 hides outlier markers; the swarm below already shows every point
sns.boxplot(y='Animal', x=column, data=df, fliersize=0, whis=999) # whis=999 ensures whiskers go to min and max
sns.swarmplot(y='Animal', x=column, data=df, color='0.25')

# sns.boxplot(x=column, data=df, fliersize=0, whis=999) # whis=999 ensures whiskers go to min and max
# sns.swarmplot(x=column, data=df, color='0.25')

# plot zero line
plt.axvline(x=0, ls=':', c='gray', zorder=-1)

plt.tight_layout()

# plt.gcf().savefig(os.path.join(export_dir, 'figure-X.png'), dpi=300)

In [None]:
# Delay from the first B8a/b burst start to the start of the force rise.
b8_onset = df_all['B8a/b first burst start (s)']
df_all['Delay from B8a/b start to force rise start (s)'] = (
    df_all['Force rise start start (s)'].sub(b8_onset)
)

In [None]:
# Per-animal box plot, with individual swallows overlaid as a swarm, of the delay
# from B8a/b onset to force rise start, with a dotted zero line.
plt.figure(figsize=(5, 5))
ax = plt.gca()

df = df_all.reset_index()
# column = 'Delay from B8a/b start to force start (s)'
column = 'Delay from B8a/b start to force rise start (s)'

# fliersize=0 hides outlier markers; the swarm below already shows every point
sns.boxplot(y='Animal', x=column, data=df, fliersize=0, whis=999) # whis=999 ensures whiskers go to min and max
sns.swarmplot(y='Animal', x=column, data=df, color='0.25')

# sns.boxplot(x=column, data=df, fliersize=0, whis=999) # whis=999 ensures whiskers go to min and max
# sns.swarmplot(x=column, data=df, color='0.25')

# plot zero line
plt.axvline(x=0, ls=':', c='gray', zorder=-1)

plt.tight_layout()

# plt.gcf().savefig(os.path.join(export_dir, 'figure-X.png'), dpi=300)

In [None]:
# plt.figure(figsize=(5, 5))
# ax = plt.gca()

# data_subsets = [
#     'JG07 Tape nori',
#     'JG08 Tape nori',
#     'JG11 Tape nori',
#     'JG12 Tape nori',
#     'JG14 Tape nori',
# ]
# data_subsets = {label:label2query(label) for label in data_subsets}

# # xlabel, xlim = 'B8 activity mean rectified voltage (μV)', [0, None]
# # xlabel, xlim = 'B8a/b first burst mean rectified voltage (μV)', [0, None]
# xlabel, xlim = 'B8a/b pre-B3/B6/B9 burst mean rectified voltage (μV)', [0, None]
# # xlabel, xlim = 'B8a/b pre-B3/B6/B9 burst mean frequency (Hz)', [0, None]
# # ylabel, ylim = 'Force slope (mN/s)', [0, None]
# ylabel, ylim = 'Force initial slope (mN/s)', [0, None]
# # ylabel, ylim = 'Force initial increase (mN)', [0, None]

# scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim, trend=True, trend_separately=True, tooltips=True)
# ax.legend()
# sns.despine(ax=ax, offset=20, trim=True)
# plt.tight_layout()

# # plt.gcf().savefig(os.path.join(export_dir, 'figure-X.png'), dpi=300)

In [None]:
# plt.figure(figsize=(5, 5))
# ax = plt.gca()

# data_subsets = [
#     'JG07 Tape nori',
#     'JG08 Tape nori',
#     'JG11 Tape nori',
#     'JG12 Tape nori',
#     'JG14 Tape nori',
# ]
# data_subsets = {label:label2query(label) for label in data_subsets}

# # xlabel, xlim = 'B8 activity mean rectified voltage (μV)', [0, None]
# # xlabel, xlim = 'B8a/b first burst mean rectified voltage (μV)', [0, None]
# # xlabel, xlim = 'B8a/b pre-B3/B6/B9 burst mean rectified voltage (μV)', [0, None]
# xlabel, xlim = 'B8a/b pre-B3/B6/B9 burst mean frequency (Hz)', [0, None]
# # ylabel, ylim = 'Force slope (mN/s)', [0, None]
# ylabel, ylim = 'Force initial slope (mN/s)', [0, None]
# # ylabel, ylim = 'Force initial increase (mN)', [0, None]

# scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim, trend=True, trend_separately=True, tooltips=True)
# ax.legend()
# sns.despine(ax=ax, offset=20, trim=True)
# plt.tight_layout()

In [None]:
# plt.figure(figsize=(5, 5))
# ax = plt.gca()

# data_subsets = [
#     'JG07 Tape nori',
#     'JG08 Tape nori',
#     'JG11 Tape nori',
#     'JG12 Tape nori',
#     'JG14 Tape nori',
# ]
# data_subsets = {label:label2query(label) for label in data_subsets}

# # xlabel, xlim = 'B8 activity mean rectified voltage (μV)', [0, None]
# # xlabel, xlim = 'B8a/b first burst mean rectified voltage (μV)', [0, None]
# xlabel, xlim = 'B8a/b pre-B3/B6/B9 burst mean rectified voltage (μV)', [0, None]
# # xlabel, xlim = 'B8a/b pre-B3/B6/B9 burst mean frequency (Hz)', [0, None]
# # ylabel, ylim = 'Force slope (mN/s)', [0, None]
# # ylabel, ylim = 'Force initial slope (mN/s)', [0, None]
# ylabel, ylim = 'Force initial increase (mN)', [0, None]

# scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim, trend=True, trend_separately=True, tooltips=True)
# ax.legend()
# sns.despine(ax=ax, offset=20, trim=True)
# plt.tight_layout()

In [None]:
# plt.figure(figsize=(5, 5))
# ax = plt.gca()

# data_subsets = [
#     'JG07 Tape nori',
#     'JG08 Tape nori',
#     'JG11 Tape nori',
#     'JG12 Tape nori',
#     'JG14 Tape nori',
# ]
# data_subsets = {label:label2query(label) for label in data_subsets}

# # xlabel, xlim = 'B8 activity mean rectified voltage (μV)', [0, None]
# # xlabel, xlim = 'B8a/b first burst mean rectified voltage (μV)', [0, None]
# # xlabel, xlim = 'B8a/b pre-B3/B6/B9 burst mean rectified voltage (μV)', [0, None]
# xlabel, xlim = 'B8a/b pre-B3/B6/B9 burst mean frequency (Hz)', [0, None]
# # ylabel, ylim = 'Force slope (mN/s)', [0, None]
# ylabel, ylim = 'Force increase (mN)', [0, None]
# # ylabel, ylim = 'Force initial slope (mN/s)', [0, None]
# # ylabel, ylim = 'Force initial increase (mN)', [0, None]

# scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim, trend=True, trend_separately=True, tooltips=True)
# ax.legend()
# sns.despine(ax=ax, offset=20, trim=True)
# plt.tight_layout()

In [None]:
# Per-animal scatter of the B8a/b first-burst duration against the duration of
# the force rise and plateau. scatter2d is defined earlier in the notebook;
# trend/trend_separately presumably add per-subset trend lines (TODO confirm).
plt.figure(figsize=(5, 5))
ax = plt.gca()

# keys are '<animal> Tape nori' labels, values are the matching query from label2query
data_subsets = {
    label: label2query(label)
    for label in (f'{animal} Tape nori' for animal in ('JG07', 'JG08', 'JG11', 'JG12', 'JG14'))
}

# alternative x-axis: 'B8 activity duration (s)', [0, 8]
xlabel, xlim = 'B8a/b first burst duration (s)', [0, 8]
ylabel, ylim = 'Force rise and plateau duration (s)', [0, 8]

scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim,
          trend=True, trend_separately=True, tooltips=False)

# dotted y = x reference line; endpoints at 999 so it runs past the [0, 8] limits above
plt.plot([0, 999], [0, 999], ls=':', c='gray', zorder=0)

ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

In [None]:
# Delay from the earliest first-burst start among B6/B9 and B3 to the start of
# the force plateau. DataFrame.min(axis=1) skips NaN by default, so a missing
# burst in one unit falls back to the other unit's start.
df_all['Delay from B3/B6/B9 start to force plateau start (s)'] = (
    df_all['Force plateau start start (s)'].sub(
        df_all[['B6/B9 first burst start (s)', 'B3 first burst start (s)']].min(axis=1)
    )
)

In [None]:
# Per-animal distribution of the delay from B3/B6/B9 onset to the start of the
# force plateau; the dotted vertical line marks zero delay.
plt.figure(figsize=(5, 5))
ax = plt.gca()

df = df_all.reset_index()
# column = 'Delay from B3/B6/B9 start to force 50%-height (s)'
column = 'Delay from B3/B6/B9 start to force plateau start (s)'

# whis=(0, 100) places the whiskers at the 0th and 100th percentiles, i.e. exactly
# at each group's min and max. (A large scalar such as whis=999 does not guarantee
# this: if a group's IQR is 0 the whiskers collapse and the remaining points
# become fliers, which fliersize=0 would hide.)
sns.boxplot(y='Animal', x=column, data=df, fliersize=0, whis=(0, 100))
# sns.swarmplot(y='Animal', x=column, data=df, color='0.25')

# sns.boxplot(x=column, data=df, fliersize=0, whis=(0, 100))
# sns.swarmplot(x=column, data=df, color='0.25')

# plot zero line
plt.axvline(x=0, ls=':', c='gray', zorder=-1)

plt.tight_layout()

# plt.gcf().savefig(os.path.join(export_dir, 'figure-X.png'), dpi=300)

In [None]:
# Delay to the end of the force plateau, measured from whichever comes first:
#   - the end of the last B8a/b burst, or
#   - the later of the last B6/B9 and last B3 burst ends.
# pandas min/max skip NaN by default, so a missing candidate falls back to the other.
df_all['Delay from either B8a/b or B3/B6/B9 end to force plateau end (s)'] = (
    df_all['Force plateau end start (s)'].sub(
        pd.concat(
            [
                df_all['B8a/b last burst end (s)'],
                df_all[['B6/B9 last burst end (s)', 'B3 last burst end (s)']].max(axis=1),
            ],
            axis=1,
        ).min(axis=1)
    )
)

In [None]:
# Per-animal distribution (box plot plus individual points) of the delay from
# the end of motor activity to the end of the force plateau; the dotted vertical
# line marks zero delay.
plt.figure(figsize=(5, 5))
ax = plt.gca()

df = df_all.reset_index()
# column = 'Delay from either B8a/b or B3/B6/B9 end to force 80%-height end (s)'
column = 'Delay from either B8a/b or B3/B6/B9 end to force plateau end (s)'

# whis=(0, 100) places the whiskers exactly at each group's min and max. A large
# scalar such as whis=999 fails to do so when a group's IQR is 0.
# fliers are hidden (fliersize=0) because the swarm plot shows every point anyway.
sns.boxplot(y='Animal', x=column, data=df, fliersize=0, whis=(0, 100))
sns.swarmplot(y='Animal', x=column, data=df, color='0.25')

# sns.boxplot(x=column, data=df, fliersize=0, whis=(0, 100))
# sns.swarmplot(x=column, data=df, color='0.25')

# plot zero line
plt.axvline(x=0, ls=':', c='gray', zorder=-1)

plt.tight_layout()

# plt.gcf().savefig(os.path.join(export_dir, 'figure-X.png'), dpi=300)

In [None]:
# Span of B6/B9 activity: from the start of its first burst to the end of its last.
df_all['B6/B9 all bursts duration (s)'] = (
    df_all['B6/B9 last burst end (s)'].sub(df_all['B6/B9 first burst start (s)'])
)

In [None]:
# Per-animal scatter of the total B6/B9 burst span against the duration of the
# force rise and plateau. scatter2d is defined earlier in the notebook.
plt.figure(figsize=(5, 5))
ax = plt.gca()

# keys are '<animal> Tape nori' labels, values are the matching query from label2query
data_subsets = {
    label: label2query(label)
    for label in (f'{animal} Tape nori' for animal in ('JG07', 'JG08', 'JG11', 'JG12', 'JG14'))
}

# alternative x-axes:
#   'B3/6/9/10 activity duration (s)', [0, 8]
#   'B6/B9 first burst duration (s)', [0, 8]
xlabel, xlim = 'B6/B9 all bursts duration (s)', [0, 8]
ylabel, ylim = 'Force rise and plateau duration (s)', [0, 8]

scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim,
          trend=True, trend_separately=True, tooltips=False)

# dotted y = x reference line; endpoints at 999 so it runs past the [0, 8] limits above
plt.plot([0, 999], [0, 999], ls=':', c='gray', zorder=0)

ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

In [None]:
# B3-only delays: first-burst start -> force plateau start ...
df_all['Delay from B3 start to force plateau start (s)'] = (
    df_all['Force plateau start start (s)'].sub(df_all['B3 first burst start (s)'])
)

# ... and last-burst end -> force plateau end
df_all['Delay from B3 end to force plateau end (s)'] = (
    df_all['Force plateau end start (s)'].sub(df_all['B3 last burst end (s)'])
)

In [None]:
# Box plots, colored by animal, of the B3-to-force-plateau delay measured at the
# start of the burst and at the end of the burst; dotted line marks zero delay.
plt.figure(figsize=(5, 5))
ax = plt.gca()

df = df_all.reset_index()
df = df.rename(columns={
#     'Delay from B3 start to force 80%-height start (s)': 'Start of burst',
#     'Delay from B3 end to force 80%-height end (s)':     'End of burst'})
    'Delay from B3 start to force plateau start (s)': 'Start of burst',
    'Delay from B3 end to force plateau end (s)':     'End of burst'})
# long format: each row of df becomes one row per value column, tagged in 'When?'
df = pd.melt(df,
             id_vars=['Animal', 'Food', 'Bout_index', 'Behavior_index'],
             value_vars=['Start of burst', 'End of burst'],
             var_name='When?',
             value_name='Delay from B3 to plateau (s)')
# whis=(0, 100) places the whiskers exactly at each group's min and max. A large
# scalar such as whis=999 fails to do so when a group's IQR is 0, and the
# resulting fliers would be hidden by fliersize=0.
sns.boxplot(hue='Animal', x='Delay from B3 to plateau (s)', y='When?', data=df, fliersize=0, whis=(0, 100))
# sns.swarmplot(hue='Animal', x='Delay from B3 to plateau (s)', y='When?', data=df)

# plot zero line
plt.axvline(x=0, ls=':', c='gray', zorder=-1)

plt.tight_layout()

In [None]:
# plt.figure(figsize=(5, 5))
# ax = plt.gca()

# data_subsets = [
#     'JG07 Tape nori',
#     'JG08 Tape nori',
#     'JG11 Tape nori',
#     'JG12 Tape nori',
#     'JG14 Tape nori',
# ]
# data_subsets = {label:label2query(label) for label in data_subsets}

# # xlabel, xlim = 'B3/6/9/10 activity mean rectified voltage (μV)', [0, None]
# xlabel, xlim = 'B6/B9 first burst mean rectified voltage (μV)', [0, None]
# # ylabel, ylim = 'Force slope (mN/s)', [0, None]
# # ylabel, ylim = 'Force 80%-height (mN)', [0, None]
# ylabel, ylim = 'Force increase (mN)', [0, None]

# scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim, trend=True, trend_separately=True, tooltips=True)
# ax.legend()
# sns.despine(ax=ax, offset=20, trim=True)
# plt.tight_layout()

In [None]:
# plt.figure(figsize=(5, 5))
# ax = plt.gca()

# data_subsets = [
#     'JG07 Tape nori',
#     'JG08 Tape nori',
#     'JG11 Tape nori',
#     'JG12 Tape nori',
#     'JG14 Tape nori',
# ]
# data_subsets = {label:label2query(label) for label in data_subsets}

# # xlabel, xlim = 'B3/6/9/10 activity mean rectified voltage (μV)', [0, None]
# xlabel, xlim = 'B6/B9 first burst mean frequency (Hz)', [0, None]
# # ylabel, ylim = 'Force slope (mN/s)', [0, None]
# ylabel, ylim = 'Force plateau start value (mN)', [0, None]
# # ylabel, ylim = 'Force increase (mN)', [0, None]
# # ylabel, ylim = 'Force rise increase (mN)', [0, None]

# scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim, trend=True, trend_separately=True, tooltips=True)
# ax.legend()
# sns.despine(ax=ax, offset=20, trim=True)
# plt.tight_layout()

In [None]:
# Delay from the end of the last B38 burst to the end of the force shoulder.
df_all['Delay from B38 end to force shoulder end (s)'] = (
    df_all['Force shoulder end start (s)'].sub(df_all['B38 last burst end (s)'])
)

In [None]:
# Per-animal distribution (box plot plus individual points) of the delay from
# the end of B38 activity to the end of the force shoulder; the dotted vertical
# line marks zero delay.
plt.figure(figsize=(5, 5))
ax = plt.gca()

df = df_all.reset_index()
# column = 'Delay from B38 end to shoulder end (s)'
column = 'Delay from B38 end to force shoulder end (s)'

# whis=(0, 100) places the whiskers exactly at each group's min and max. A large
# scalar such as whis=999 fails to do so when a group's IQR is 0.
# fliers are hidden (fliersize=0) because the swarm plot shows every point anyway.
sns.boxplot(y='Animal', x=column, data=df, fliersize=0, whis=(0, 100))
sns.swarmplot(y='Animal', x=column, data=df, color='0.25')

# sns.boxplot(x=column, data=df, fliersize=0, whis=(0, 100))
# sns.swarmplot(x=column, data=df, color='0.25')

# plot zero line
plt.axvline(x=0, ls=':', c='gray', zorder=-1)

plt.tight_layout()

# plt.gcf().savefig(os.path.join(export_dir, 'figure-X.png'), dpi=300)

In [None]:
# plt.figure(figsize=(5, 5))
# ax = plt.gca()

# data_subsets = [
#     'JG07 Tape nori',
#     'JG08 Tape nori',
#     'JG11 Tape nori',
#     'JG12 Tape nori',
#     'JG14 Tape nori',
# ]
# data_subsets = {label:label2query(label) for label in data_subsets}

# xlabel, xlim = 'B8a/b pre-B3/B6/B9 burst peak smoothed frequency (Hz)', [0, None]
# ylabel, ylim = 'Force slope (mN/s)', [0, None]

# scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim, trend=True, trend_separately=True, tooltips=True)
# ax.legend()
# sns.despine(ax=ax, offset=20, trim=True)
# plt.tight_layout()

In [None]:
# Per-animal scatter of how long the force rise lasts against how much the
# force increases during it. scatter2d is defined earlier in the notebook.
plt.figure(figsize=(5, 5))
ax = plt.gca()

# keys are '<animal> Tape nori' labels, values are the matching query from label2query
data_subsets = {
    label: label2query(label)
    for label in (f'{animal} Tape nori' for animal in ('JG07', 'JG08', 'JG11', 'JG12', 'JG14'))
}

xlabel, xlim = 'Force rise duration (s)', [0, None]
ylabel, ylim = 'Force rise increase (mN)', [0, None]

scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim,
          trend=True, trend_separately=True, tooltips=True)

ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

In [None]:
# start = datetime.datetime.now()

# df = df_exemplary_bout
# (data_set_name, time_window) = feeding_bouts[exemplary_bout]

# # t_start, t_stop = time_window*pq.s
# t_start, t_stop = [2944.5, 3010]*pq.s # 7 swallows

# plots = [
#     {'channel': 'I2',       'units': 'uV', 'ylim': [ -60,  60], 'scalebar': 50}, #, 'decimation_factor': 100},
#     {'channel': 'RN',       'units': 'uV', 'ylim': [ -25,  25], 'scalebar': 25}, #, 'decimation_factor': 100},
#     {'channel': 'BN2',      'units': 'uV', 'ylim': [ -45,  45], 'scalebar': 50}, #, 'decimation_factor': 100},
#     {'channel': 'BN3-DIST', 'units': 'uV', 'ylim': [ -60,  60], 'scalebar': 50, 'ylabel': 'BN3'}, #, 'decimation_factor': 100},
#     {'channel': 'Force',    'units': 'mN', 'ylim': [ -50, 450], 'scalebar': 300, 'decimation_factor': 100},
# ]
# plot_names = [p['channel'] for p in plots]
# plot_units = [p['units'] for p in plots]

# kwargs = dict(
#     figsize = (12, 6),
#     linewidth = 0.5,
#     x_scalebar = 10*pq.s,
# )

# # load the data
# metadata = neurotic.MetadataSelector('../../data/metadata.yml')
# metadata.select(data_set_name)
# blk = neurotic.load_dataset(metadata) # not lazy so that I2 filter is applied

# # plot the signal
# fig, axes = prettyplot_with_scalebars(blk, t_start, t_stop, plots, **kwargs)

# units = [
#     'I2 spikes',
#     'B8a/b',
#     'B6/B9',
#     'B3',
#     'B38',
# ]
# unit_burst_boxes = {
#     'I2 spikes': [-35, 30],
#     'B8a/b':     [-20, 12],
#     'B6/B9':     [-15, 12],
#     'B3':        [-45, 35],
#     'B38':       [-12, 12],
# }

# # plot a zero baseline for force
# ax = axes[-1]
# force_zero = -11.5 # mN, average before animal began swallowing, not zero because of DC offset
# ax.axhline(force_zero, color='0.75', lw=1, ls='--', zorder=-1)

# for j, i in enumerate(df.index):
#     for k, unit in enumerate(units):
#         st = df.loc[i, unit+' spike train']
#         if st is not None and st.size > 0:
                
#             # get the neural channel
#             channel = st.annotations['channels'][0]
#             sig = get_sig(blk, channel)

#             # get the signal for the entire bout
#             sig = sig.time_slice(t_start, t_stop)
#             sig = sig.rescale(plot_units[plot_names.index(channel)])

#             # plot spikes
#             ax = axes[plot_names.index(channel)]
#             spike_amplitudes = np.array([sig[sig.time_index(t)] for t in st]) * pq.Quantity(sig.units)
#             ax.scatter(st.times.rescale('s'), spike_amplitudes, marker='.', s=4, c=unit_colors[unit], zorder=3)

#             # plot burst windows
#             bursts = df.at[i, unit+' all bursts (s)']
#             bottom, top = unit_burst_boxes[unit]
#             height = top-bottom
#             for burst in bursts:
#                 if is_good_burst(burst):
#                     left = burst['Start (s)']
#                     right = burst['End (s)']
#                     width = right-left
#                     rect = patches.Rectangle((left, bottom), width, height, linewidth=2, ls='-', edgecolor=unit_colors[unit], fill=False, zorder=3, clip_on=False)
#                     ax.add_patch(rect)

# # fig.savefig(os.path.join(export_dir, 'figure-exemplary-bout-export.png'), dpi=600)

# end = datetime.datetime.now()
# print('render time:', end-start)

## Firing-rate model of force

In [None]:
from utils import CausalAlphaKernel

# Compare a crude firing-rate model of force against the recorded force for a
# single swallow. The model parameters below are hard-coded and are written into
# both the plot title and the exported file name.

# presumably (Animal, Food, Bout_index, Behavior_index), matching df_all's index
# levels; the first three elements select the feeding bout -- TODO confirm
swallow_id = ('JG07', 'Tape nori', 0, 0)

(data_set_name, time_window) = feeding_bouts[swallow_id[:3]]

# load the data
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)
blk = neurotic.load_dataset(metadata, lazy=True)
sig = get_sig(blk, 'Force')
sig = sig.time_slice(time_window[0]*pq.s, time_window[1]*pq.s)
sig = sig.rescale('mN')
sig = elephant.signal_processing.butter(sig, lowpass_freq = 10*pq.Hz)  # 10 Hz low-pass

st_b6b9 = df_all.loc[swallow_id]['B6/B9 spike train']
st_b3 = df_all.loc[swallow_id]['B3 spike train']

# (synaptic weight, alpha-kernel time constant in s) per motor neuron group
weight_b6b9, tau_b6b9 = 1.75, 1
weight_b3,   tau_b3   = 1.75, 0.2
model_scale           = 100    # mN per unit of grasper position x
model_baseline        = 85     # mN added to the scaled model output
u_to_y_constant       = 0.005  # slope c in y = ymax - c*u (see derivation below)

# firing rates estimated with a causal alpha kernel, sampled every 0.2 ms
rate_b6b9 = elephant.statistics.instantaneous_rate(
    spiketrain=st_b6b9,
    sampling_period=0.0002*pq.s,
    kernel=CausalAlphaKernel(tau_b6b9*pq.s),
)

rate_b3 = elephant.statistics.instantaneous_rate(
    spiketrain=st_b3,
    sampling_period=0.0002*pq.s,
    kernel=CausalAlphaKernel(tau_b3*pq.s),
)

# derivation and assumptions of this *very* crude and simple model:
# - force (which is approximately isometric) is a linear function
#   of grasper position, x
# - the grasper is always spherical, and therefore its position is
#   related to the I1/I3 torus contact altitude, y, by the equation
#   for a circle, x^2 + y^2 = r^2, where r is the grasper radius
#   (here r=1 with arbitrary units, so x = sqrt(1 - y^2))
# - the contact altitude y is approximately equivalent to the major
#   radius of the torus (i.e., the minor radius is negligible)
# - the major radius of the I1/I3 torus decreases from its max radius,
#   ymax (here ymax=1 with arbitrary units), to its min radius (here
#   set to 0) as muscle activation, u, increases
#     - this may be modeled as an asymptotical approach from ymax to 0,
#       e.g., y = ymax*exp(-c*u), or as a piecewise linear function,
#       e.g., y = ymax-c*u with floor 0 and ceiling ymax
# - the muscle activation u is a weighted sum of the synaptic potentials
#   generated by the relevant motor neurons, modeled as alpha functions
#   with fixed time constants and size (i.e., changes in size due to
#   changing driving force as the muscle depolarizes are ignored)
u = rate_total = rate_b6b9 * weight_b6b9 + rate_b3 * weight_b3
u = np.clip(u.magnitude.flatten(), 0, None) # replace with 0 any negative values (caused by numerical imprecision)
# y = np.exp(-u_to_y_constant*u)
y = np.clip(1-u_to_y_constant*u, 0, 1)  # piecewise linear variant, ymax = 1
x = np.sqrt(1-y**2)                     # circle of radius r = 1
model_force = x * model_scale + model_baseline


# weighted firing rate, model force, and recorded force (gray, behind) over the swallow
plt.figure(figsize=(8,4))
plt.plot(rate_total.times.rescale('s'), rate_total)
plt.plot(rate_total.times.rescale('s'), model_force)
plt.plot(sig.times.rescale('s'), sig.magnitude, color='0.75', zorder=-1)
plt.xlim([df_all.loc[swallow_id]['Start (s)'], df_all.loc[swallow_id]['End (s)']])

plt.title(f'Scale: {model_scale} | Baseline: {model_baseline} | B6/B9: ({weight_b6b9}, {tau_b6b9}) | B3: ({weight_b3}, {tau_b3})')

export_dir4 = os.path.join(export_dir, 'firing-rate-models')
# makedirs also creates export_dir if it is missing; exist_ok avoids an error on reruns
os.makedirs(export_dir4, exist_ok=True)
plt.gcf().savefig(os.path.join(export_dir4, f'S {model_scale} BL {model_baseline} B6B9 {weight_b6b9} {tau_b6b9} B3 {weight_b3} {tau_b3}.png'), dpi=300)