- [Figure 1](#[FIGURE-1])
  - [Figure 1A](#🐌-Figure-1A)
- [Figure 2](#[FIGURE-2])
- [Figure 3](#[FIGURE-3])
  - [Figure 3A](#🐌-Figure-3A)
  - [Figure 3B](#🐌-Figure-3B)
  - [Figure 3C](#🐌-Figure-3C)
  - [Figure 3D](#🐌-Figure-3D)
  - [Figure 3E](#🐌-Figure-3E)
  - [Figure 3F](#🐌-Figure-3F)
  - [Figure 3G](#🐌-Figure-3G)
  - [Figure 3H](#🐌-Figure-3H)
  - [Figure 3I ?](#🐌-Figure-3I-?)
- [Figure 4](#[FIGURE-4])
- [Figure 5](#[FIGURE-5])
  - [Figure 5A](#🐌-Figure-5A)
  - [Figure 5B](#🐌-Figure-5B)
- [Figure 6](#[FIGURE-6])

# Preamble

## Import Packages

In [None]:
import os
import datetime
import numpy as np
from scipy import interpolate
import quantities as pq
import elephant
import pandas as pd
import statsmodels.api as sm
import neurotic
from neurotic.datasets.data import _detect_spikes
from utils import BehaviorsDataFrame, CausalAlphaKernel, DownsampleNeoSignal

pq.markup.config.use_unicode = True  # allow symbols like mu for micro in output
# register a millinewton unit (symbol 'mN') for force signals; force is later rescaled with .rescale('mN')
pq.mN = pq.UnitQuantity('millinewton', pq.N/1e3, symbol = 'mN');  # define millinewton

import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.lines as mlines
from matplotlib.ticker import MultipleLocator
import seaborn as sns

In [None]:
import warnings

# np.nanmax raises a warning if all values are NaN and returns NaN, which is the behavior we want
warnings.filterwarnings('ignore', message='All-NaN slice encountered')

# elephant.statistics.instantaneous_rate always complains about negative values
warnings.filterwarnings('ignore', message='Instantaneous firing rate approximation contains '
                                          'negative values, possibly caused due to machine '
                                          'precision errors')

# with matplotlib>=3.1 and seaborn<=0.9.0, deprecation warnings are raised
# whenever tick marks are placed on the right axis but not the left
# - the warning class is exported at the top level of matplotlib; importing it from
#   matplotlib.cbook is deprecated in newer releases, so use cbook only as a fallback
try:
    from matplotlib import MatplotlibDeprecationWarning
except ImportError:
    from matplotlib.cbook import MatplotlibDeprecationWarning
warnings.simplefilter('ignore', MatplotlibDeprecationWarning)

# don't complain about opening too many figures
plt.rcParams.update({'figure.max_open_warning': 0})

In [None]:
# make figures interactive and open in a separate window
# %matplotlib qt

# make figures interactive and inline
%matplotlib notebook

# make figures non-interactive and inline
# %matplotlib inline

## Plot Settings

In [None]:
export_dir = 'manuscript-figures'
# create the output directory for exported figures; exist_ok=True makes this
# idempotent without a separate existence check (no check-then-create race)
os.makedirs(export_dir, exist_ok=True)

In [None]:
# general plot settings
# general plot settings: seaborn 'ticks' axes style, unscaled fonts, and the
# Palatino Linotype font (context='poster' can be added to enlarge everything)
sns.set(style='ticks', font_scale=1, font='Palatino Linotype')

In [None]:
# display current color palette
# show the active color palette as a row of small swatches
with sns.axes_style("darkgrid"):
    current_palette = sns.color_palette(None)
    sns.palplot(current_palette, size=0.5)

In [None]:
# marker symbols, and matplotlib color-cycle references 'C0' through 'C6'
markers = ['.', '+', 'x', '1', '4', 'o', 'D']
colors = [f'C{n}' for n in range(7)]

# plotting color for each neural unit
unit_colors = dict([
    ('I2 spikes', 'C9'), # light blue
    ('I2',        'C9'), # light blue
    ('B8a/b',     'C6'), # pink
    ('B3',        'C3'), # red
    ('B6/B9',     'C2'), # green
    ('B38',       'C1'), # orange
    ('B4/B5',     'C0'), # dark blue
])

# plotting color for each phase of the force trace, borrowed from the unit
# colors listed alongside each phase
force_colors = dict([
    ('dip',          unit_colors['I2 spikes']),
    ('initial rise', unit_colors['B8a/b']),
    ('rise',         unit_colors['B8a/b']),
    ('plateau',      unit_colors['B6/B9']),
    ('drop',         'gray'),
    ('shoulder',     unit_colors['B38']),
])

## Data Parameters

In [None]:
def _burst_thresholds(overrides=None):
    '''
    Build a fresh dict of burst-detection thresholds, (start, end) frequencies, per unit.

    Entries in overrides replace the standard value for that unit (keeping the
    standard key order); every call creates new Quantity objects, so no two
    bouts share threshold objects.
    '''
    thresholds = {
        'I2 spikes': (10,  5)*pq.Hz, # same as Cullins et al. 2015a (based on Hurwitz et al. 1996)
        'B8a/b':     ( 3,  3)*pq.Hz, # same as Cullins et al. 2015a (based on Morton and Chiel 1993a)
        'B3':        ( 8,  2)*pq.Hz, # based on Lu et al. 2015
        'B6/B9':     (10,  5)*pq.Hz, # based on Lu et al. 2015
        'B38':       ( 8,  5)*pq.Hz, # based on McManus et al. 2014
        'B4/B5':     ( 3,  3)*pq.Hz, # same as Cullins et al. 2015a ? (Table 1 says 3 Hz, text says first/last spike; based on Warman and Chiel 1995 ?)
    }
    if overrides:
        thresholds.update(overrides)
    return thresholds


feeding_bouts = {
    # (animal, food, bout_index): (
    #     data_set_name,
    #     channel_names,
    #     time_window,
    #     epoch_types_to_keep,
    #     burst_thresholds,
    # )

    ('JG07', 'Tape nori', 0): (
        'IN VIVO / JG07 / 2018-05-20 / 002',
        ['I2-L', 'RN-L', 'BN2-L', 'BN3-L', 'Force'],
        [2718, 2755], # 5 swallows
        ['Swallow (tape nori)'],
        _burst_thresholds(),
    ),

    ('JG08', 'Tape nori', 0): (
        'IN VIVO / JG08 / 2018-06-21 / 002',
        ['I2', 'RN', 'BN2', 'BN3', 'Force'],
        [148, 208], # 7 swallows, some bucket and head movement
        ['Swallow (tape nori)'],
        # B6/B9 based on Lu et al. 2015 (end threshold reduced for this animal)
        _burst_thresholds({'B6/B9': (10,  3)*pq.Hz}),
    ),

    ('JG08', 'Tape nori', 1): (
        'IN VIVO / JG08 / 2018-06-21 / 002',
        ['I2', 'RN', 'BN2', 'BN3', 'Force'],
        [664, 701], # 5 swallows, large bucket movement
        ['Swallow (tape nori)'],
        # B6/B9 based on Lu et al. 2015 (end threshold reduced for this animal)
        _burst_thresholds({'B6/B9': (10,  3)*pq.Hz}),
    ),

    ('JG08', 'Tape nori', 2): (
        'IN VIVO / JG08 / 2018-06-21 / 002',
        ['I2', 'RN', 'BN2', 'BN3', 'Force'],
        [1452, 1477], # 3 swallows, some bucket movement
        ['Swallow (tape nori)'],
        # B6/B9 based on Lu et al. 2015 (end threshold reduced for this animal)
        _burst_thresholds({'B6/B9': (10,  3)*pq.Hz}),
    ),

    ('JG11', 'Tape nori', 0): (
        'IN VIVO / JG11 / 2019-04-03 / 004',
        ['I2', 'RN', 'BN2', 'BN3-PROX', 'Force'],
        [1233, 1280], # 5 swallows
        ['Swallow (tape nori)'],
        # B4/B5 threshold reduced for this animal because only one neuron appeared to project
        _burst_thresholds({'B4/B5': (1.5, 1.5)*pq.Hz}),
    ),

    ('JG12', 'Tape nori', 0): (
        'IN VIVO / JG12 / 2019-05-10 / 002',
        ['I2', 'RN', 'BN2', 'BN3-DIST', 'Force'],
        [437, 465], # 4 swallows
        ['Swallow (tape nori)'],
        _burst_thresholds(),
    ),

    ('JG12', 'Tape nori', 1): (
        'IN VIVO / JG12 / 2019-05-10 / 002',
        ['I2', 'RN', 'BN2', 'BN3-DIST', 'Force'],
        [2901, 2937], # 5 swallows
        ['Swallow (tape nori)'],
        _burst_thresholds(),
    ),

    ('JG14', 'Tape nori', 0): (
        'IN VIVO / JG14 / 2019-07-29 / 004',
        ['I2', 'RN', 'BN2', 'BN3-PROX', 'Force'],
        [831, 870], # 5 swallows
        ['Swallow (tape nori)'],
        _burst_thresholds(),
    ),

    # for example figures only -- not used in majority of analysis because of long inter-swallow intervals
    ('JG12', 'Tape nori', 101): (
        'IN VIVO / JG12 / 2019-05-10 / 002',
        ['I2', 'RN', 'BN2', 'BN3-DIST', 'Force'],
        [2944.5, 3010], # 6+1 swallows (last one included for plotting its I2 burst)
        ['Swallow (tape nori)'],
        _burst_thresholds(),
    ),

    # for example figures only -- not used in majority of analysis because of long inter-swallow intervals
    ('JG12', 'Tape nori', 102): (
        'IN VIVO / JG12 / 2019-05-10 / 002',
        ['I2', 'RN', 'BN2', 'BN3-DIST', 'Force'],
        [2999.3, 3010], # 1 swallow
        ['Swallow (tape nori)'],
        _burst_thresholds(),
    ),
}


# keys into feeding_bouts for the bouts reserved for example figures
exemplary_bout = ('JG12', 'Tape nori', 101)
exemplary_swallow = ('JG12', 'Tape nori', 102)

# time range (s) of exemplary_bout to plot: the first 3 swallows
exemplary_bout_plot_range = [2944.5, 2969.5]

## Helper Functions

In [None]:
def query_union(queries):
    '''Join query strings with "|" (logical OR), parenthesizing each one and skipping None entries.'''
    wrapped = (f'({query})' for query in queries if query is not None)
    return '|'.join(wrapped)

def label2query(label):
    '''
    Build a DataFrame.query string selecting rows whose Animal and Food match label.

    The first 4 characters of label are taken as the animal and everything after
    the 5th character as the food, e.g. 'JG07 Tape nori' -> ('JG07', 'Tape nori').
    NOTE(review): assumes 4-character animal IDs followed by one separator
    character -- confirm against callers.
    '''
    animal = label[:4]
    food = label[5:]
    return f'(Animal == "{animal}") & (Food == "{food}")'

def contains(series, string):
    '''Return a Series of booleans: for each item, whether `string in item` (a substring test for str items).'''
    def _has(item):
        return string in item
    return series.map(_has)

In [None]:
def update_filter(metadata, new_filter):
    '''
    Put new_filter into metadata['filters'], modifying metadata in place.

    The first existing filter with the same 'channel' is replaced; if none
    matches, new_filter is appended.
    '''
    filters = metadata['filters']
    for position, existing in enumerate(filters):
        if existing['channel'] == new_filter['channel']:
            filters[position] = new_filter
            break
    else:
        filters.append(new_filter)

In [None]:
def get_sig(blk, channel):
    '''
    Return the first analog signal named channel in the first segment of blk.

    Raises
    ------
    LookupError
        If no analog signal in blk.segments[0] has that name. LookupError is a
        subclass of Exception, so existing `except Exception` handlers still work,
        while callers can now catch the missing-channel case specifically.
    '''
    for sig in blk.segments[0].analogsignals:
        if sig.name == channel:
            return sig
    raise LookupError(f'Channel "{channel}" could not be found')

In [None]:
def _finite_reduce(x, y, reducer):
    '''Apply reducer to x and y if both are finite; else return whichever is finite, or NaN if neither is.'''
    x_ok = np.isfinite(x)
    y_ok = np.isfinite(y)
    if x_ok and y_ok:
        return reducer(x, y)
    if x_ok:
        return x
    if y_ok:
        return y
    return np.nan

def finite_min(x, y):
    '''
    Smaller of x and y, ignoring any non-finite argument (NaN if both are non-finite).

    Workaround for Quantities warning about comparison to NaN
    '''
    return _finite_reduce(x, y, min)

def finite_max(x, y):
    '''
    Larger of x and y, ignoring any non-finite argument (NaN if both are non-finite).

    Workaround for Quantities warning about comparison to NaN
    '''
    return _finite_reduce(x, y, max)

In [None]:
def find_bursts(st, burst_thresholds):
    '''
    Find every sequence of spikes that qualifies as a burst.

    Instantaneous firing frequency (IFF) is 1/ISI for each pair of consecutive
    spikes. A burst starts at the first spike of the first interval whose IFF
    exceeds the start threshold, and ends at the spike that precedes the next
    interval whose IFF drops below the end threshold (or at the final spike of
    the train if that never happens). Scanning then resumes after that burst.

    Parameters
    ----------
    st : spike train accepted by elephant.statistics.isi
    burst_thresholds : (start_freq, end_freq) frequency quantities

    Returns
    -------
    list of dict
        One dict per burst with keys 'Start (s)', 'End (s)', 'Duration (s)'
        and 'Number of spikes'.
    '''
    iff = 1/elephant.statistics.isi(st).rescale('s')
    start_freq, end_freq = burst_thresholds

    # interval indexes above the start threshold / below the end threshold;
    # these do not change during the scan, so find them once
    rising = np.where(iff > start_freq)[0]
    falling = np.where(iff < end_freq)[0]

    bursts = []
    resume_after = -1
    while True:
        candidates = rising[rising > resume_after]
        if candidates.size == 0:
            break
        first = candidates[0]

        later_falls = falling[falling > first]
        if later_falls.size > 0:
            last = later_falls[0] # first time after start that iff drops below end threshold
            spike_count = last - first + 1
        else:
            last = -1 # end of spike train (include all spikes after start)
            spike_count = st.size - first

        bursts.append({
            'Start (s)': st[first].rescale('s'),
            'End (s)': st[last].rescale('s'),
            'Duration (s)': (st[last] - st[first]).rescale('s'),
            'Number of spikes': spike_count,
        })

        if last == -1:
            break
        resume_after = last

    return bursts

In [None]:
def is_good_burst(burst):
    '''Accept a burst (a dict from find_bursts) lasting at least 0.5 s and containing more than 2 spikes.'''
    long_enough = burst['Duration (s)'] >= 0.5*pq.s
    return long_enough and burst['Number of spikes'] > 2

In [None]:
# this function runs much faster if args are first converted
# from quantities to simple ndarrays (use .rescale('s').magnitude)
def normalize_time(fixed_times, t):
    '''
    Convert real times t into normalized (fractional-index) time.

    A time lying between fixed_times[i] and fixed_times[i+1] (both finite) maps
    to i + (t - fixed_times[i]) / (fixed_times[i+1] - fixed_times[i]).

    Parameters
    ----------
    fixed_times : np.ndarray
        Sorted reference times; may contain NaN for missing reference points.
    t : np.ndarray, list, or scalar
        Time(s) to normalize; non-NaN values must be sorted ascending.

    Returns
    -------
    np.ndarray
        Normalized times. NaN where t is NaN, lies outside the range of the
        non-NaN fixed times, or falls in an interval bordered by a NaN.
        A time equal to an interior fixed time is assigned to the earlier interval.
    '''
    if not isinstance(t, np.ndarray):
        if type(t) is list:
            t = np.array(t)
        else:
            t = np.array([t])
    
    assert np.all(np.diff(fixed_times[~np.isnan(fixed_times)])>=0), f'fixed_times must be sorted: {fixed_times}'
    assert np.all(np.diff(t[~np.isnan(t)])>=0), f't must be sorted: {t}'
    
    t_min = fixed_times[~np.isnan(fixed_times)].min()
    t_max = fixed_times[~np.isnan(fixed_times)].max()
    
    result = []
    # because t is sorted, the search for each ti can resume at the interval
    # where the previous ti was found
    last_found_i = 0
    for ti in t:
        
        found = False
        
        if np.isnan(ti) or ti < t_min or t_max < ti:
#             print(f'time {ti:.3f} was out of bounds for normalization')
            result.append(np.nan)
            continue

        # use a manual O(n) loop instead of using np.searchsorted's much faster binary
        # search O(log(n)) algorithm because NaNs inside fixed_times can fool searchsorted
        for i in range(last_found_i, len(fixed_times)-1):
            before = fixed_times[i]
            after = fixed_times[i+1]
            if np.isfinite(before) and np.isfinite(after):
                if before <= ti <= after:
                    found = True
                    last_found_i = i
                    result.append((ti-before)/(after-before) + i)
                    break

        # if no fully defined interval contained ti, then there must be a NaN bordering where t would go
        if not found:
#             print(f'time {ti:.3f} would fall next to an undefined boundary')
            result.append(np.nan)
            continue
    
    return np.array(result)

In [None]:
# this function runs much faster if args are first converted
# from quantities to simple ndarrays (use .rescale('s').magnitude)
def unnormalize_time(fixed_times, t_normalized):
    '''
    Convert normalized (fractional-index) times back into real times.

    This is the inverse of normalize_time: with i = floor(ti), the value ti maps
    to fixed_times[i] + (ti - i) * (fixed_times[i+1] - fixed_times[i]).

    Parameters
    ----------
    fixed_times : np.ndarray
        Sorted reference times; may contain NaN for missing reference points.
    t_normalized : np.ndarray, list, or scalar
        Normalized time(s) to convert.

    Returns
    -------
    np.ndarray
        Real times. NaN where ti is NaN, lies outside the index range of the
        non-NaN fixed times, or falls in an interval bordered by a NaN.
    '''
    if not isinstance(t_normalized, np.ndarray):
        if type(t_normalized) is list:
            t_normalized = np.array(t_normalized)
        else:
            t_normalized = np.array([t_normalized])
    
    assert np.all(np.diff(fixed_times[~np.isnan(fixed_times)])>=0), f'fixed_times must be sorted: {fixed_times}'
    
    defined_indexes = np.where(~np.isnan(fixed_times))[0]
    t_normalized_min = defined_indexes.min()
    t_normalized_max = defined_indexes.max()
    
    result = []
    for ti in t_normalized:
        
        if np.isnan(ti) or ti < t_normalized_min or t_normalized_max < ti:
    #         print(f'normalized time {ti:.3f} was out of bounds for un-normalization')
            result.append(np.nan)
            continue

        i = int(np.floor(ti))
        if i >= t_normalized_max:
            # ti sits exactly on the last defined index, which has no following
            # interval; use the final interval instead (ti - i == 1), so the result
            # is exactly fixed_times[t_normalized_max]. The previous workaround
            # subtracted 1e-6 from every ti merely *close* to the end, which shifted
            # those results slightly.
            i = t_normalized_max - 1

        before = fixed_times[i]
        after = fixed_times[i+1]
        if np.isfinite(before) and np.isfinite(after):
            result.append((ti - i)*(after-before) + before)
        else:
            # there is a NaN bordering where t would go
#             print(f'normalized time {ti:.3f} would fall next to an undefined boundary')
            result.append(np.nan)
    
    return np.array(result)

In [None]:
# times_interp is the array of normalized times at which samples will be taken
# - samples will be taken at regular intervals in normalized time
times_interp = np.linspace(0, 9, 4000) # [0, 9] is the range of normalized times

def resample_sig_in_normalized_time(fixed_times, sig, times_interp=times_interp):
    '''
    Resample a signal at evenly spaced normalized times.

    Parameters
    ----------
    fixed_times : quantity array
        Reference times used for normalization (see normalize_time); may contain NaN.
    sig : signal object with .times and .magnitude
        The signal values are flattened, so presumably a single channel -- TODO confirm.
    times_interp : np.ndarray
        Normalized times at which to sample the signal.

    Returns
    -------
    np.ndarray
        Signal values at times_interp; NaN outside the signal's normalizable
        span and wherever the normalized time falls next to a missing fixed time.
        All NaN if fewer than two samples of sig can be normalized.
    '''
    # get normalized times and signal values
    # - normalize_time will put into times_normalized a np.nan wherever a time was not normalizable,
    #   i.e., wherever the time occurred adjacent to a missing fixed time (np.nan in fixed_times)
    times = sig.times.rescale('s').magnitude
    times_normalized = normalize_time(fixed_times.magnitude, times)
    y = sig.magnitude.flatten()

    # drop times that could not be normalized due to missing fixed times
    # - interp1d will erroneously interpolate across the gaps created by this deletion,
    #   but we will then replace those interpolated values with np.nan
    normalizable = ~np.isnan(times_normalized)
    times_normalized = times_normalized[normalizable]
    y = y[normalizable]

    # interp1d needs at least two points; with fewer, the signal cannot be
    # placed on the normalized time axis, so report it as entirely missing
    if times_normalized.size < 2:
        return np.full(np.shape(times_interp), np.nan)

    # resample evenly in normalized time
    # - fill_value will only put np.nan in places outside the min and max of times_normalized
    # - interp1d will interpolate across the regions we deleted, but we want np.nan there,
    #   so we will insert them manually in the next step
    interp_func = interpolate.interp1d(times_normalized, y, kind='linear', bounds_error=False, fill_value=np.nan)
    y_interp = interp_func(times_interp)

    # replace erroneously interpolated values with np.nan
    # - now the points in y_interp which would correspond to times that could not be normalized
    #   have been set to np.nan
    not_normalizable = np.isnan(unnormalize_time(fixed_times.magnitude, times_interp))
    y_interp[not_normalizable] = np.nan

    return y_interp

In [None]:
def print_column_analysis(column, df=None):
    '''
    Print the overall mean, standard deviation, and coefficient of variation of a column.

    Parameters
    ----------
    column : str
        Column name. If it contains a parenthesized unit, e.g. 'Duration (s)',
        the unit from the last "(...)" group is printed after the mean and std;
        columns without one (e.g. spike counts) are printed without a unit.
    df : pandas.DataFrame, optional
        Table to analyze; defaults to the global df_all.
    '''
    if df is None:
        df = df_all
    _, paren, tail = column.rpartition('(')
    unit = tail.split(')')[0] if paren else ''
    unit_suffix = f' {unit}' if unit else ''
    mean = df[column].mean()
    std = df[column].std()
    cv = std/abs(mean)
    print(f'Overall mean +/- std: {mean:.2f} +/- {std:.2f}{unit_suffix}')
    print(f'Overall coefficient of variation CV: {cv:.2f}')

In [None]:
# def plot_vertical_lines_with_delay(axes, t, delay, force_y, color, clip_on=False):

#     # plot vertical line in force plot at time t
#     axes[-1].add_artist(patches.ConnectionPatch(
#         xyA=(t, force_y), xyB=(t, 1),
#         coordsA='data', coordsB=axes[-1].get_xaxis_transform(),
#         axesA=axes[-1], axesB=axes[-1],
#         color=color, lw=1, ls=':', clip_on=clip_on))
    
#     # plot vertical line through all neural plots at time t-delay
#     axes[-1].add_artist(patches.ConnectionPatch(
#         xyA=(t-delay, 0), xyB=(t-delay, 1),
#         coordsA=axes[-2].get_xaxis_transform(), coordsB=axes[0].get_xaxis_transform(),
#         axesA=axes[-2], axesB=axes[0],
#         color=color, lw=1, ls=':', clip_on=clip_on))
    
#     # connect the two lines
#     axes[-1].add_artist(patches.ConnectionPatch(
#         xyA=(t-delay, 0), xyB=(t, 1),
#         coordsA=axes[-2].get_xaxis_transform(), coordsB=axes[-1].get_xaxis_transform(),
#         axesA=axes[-2], axesB=axes[-1],
#         color=color, lw=1, ls=':', clip_on=clip_on))

In [None]:
def lighten_color(color, amount=0.5):
    '''
    Lightens the given color by multiplying (1-luminosity) by the given amount.
    Input can be matplotlib color string, hex string, or RGB tuple.
    Returns an (r, g, b) tuple; amount=1 leaves the lightness unchanged and
    amount=0 gives white.

    Examples:
    >> lighten_color('g', 0.3)
    >> lighten_color('#F034A3', 0.6)
    >> lighten_color((.3,.55,.1), 0.5)
    
    https://stackoverflow.com/a/49601444/3314376
    '''
    import matplotlib.colors as mc
    import colorsys
    try:
        # translate names found in matplotlib's named-color table to their hex codes
        c = mc.cnames[color]
    except (KeyError, TypeError):
        # not in the table (KeyError), or an unhashable spec such as a list (TypeError):
        # pass it through unchanged for mc.to_rgb to interpret
        c = color
    hue, lightness, saturation = colorsys.rgb_to_hls(*mc.to_rgb(c))
    return colorsys.hls_to_rgb(hue, 1 - amount * (1 - lightness), saturation)

## Import and Process the Data

### Main dataframe used for most figures

In [None]:
start = datetime.datetime.now()  # wall-clock time at which processing began

# use Neo RawIO lazy loading to load much faster and using less memory
# - with lazy=True, filtering parameters specified in metadata are ignored
# - with lazy=True, loading via time_slice requires neo>=0.8.0.dev
# - IMPORTANT: force and I2 filters affect smoothness and possibly threshold crossings and spike detection
lazy = False

# load the metadata containing file paths
metadata_path = '../../data/metadata.yml'
metadata = neurotic.MetadataSelector(file=metadata_path)

# state for the per-bout loop: collected dataframes, and the name of the most
# recently loaded data set (so consecutive bouts from one data set reuse it)
df_list, last_data_set_name = [], None
for (animal, food, bout_index), (data_set_name, channel_names, time_window, epoch_types_to_keep, burst_thresholds) in feeding_bouts.items():
    
    ###
    ### LOCATE BEHAVIOR EPOCHS AND SUBEPOCHS
    ###
    
    metadata.select(data_set_name)
    if data_set_name is last_data_set_name:
        # skip reloading the data if it's already in memory
        pass
    else:
        blk = neurotic.load_dataset(metadata, lazy=lazy)
    last_data_set_name = data_set_name
    
    # construct a query for locating behaviors
    behavior_query = f'(Type in {epoch_types_to_keep}) & ({time_window[0]} <= Start) & (End <= {time_window[1]})'
    
    # construct queries for locating epochs associated with each behavior
    # - each query should match at most one epoch
    # - dictionary keys are used as prefixes for the names of new columns
    subepoch_queries = {}
    
    subepoch_queries['I2 protraction activity'] = f'(Type == "I2 protraction activity") & ' \
                                                  f'(@behavior_start-1 <= Start) & (End <= @behavior_end)'
                                                  # must start no earlier than 1 second before behavior and end within it
    
    subepoch_queries['B8 activity']             = f'(Type == "B8 activity") & ' \
                                                  f'(@behavior_start <= Start) & (End <= @behavior_end)'
                                                  # must be fully contained within behavior
    
    subepoch_queries['B3/6/9/10 activity']      = f'(Type == "B3/6/9/10 activity") & ' \
                                                  f'(@behavior_start <= Start) & (End <= @behavior_end)'
                                                  # must be fully contained within behavior

    subepoch_queries['B38 activity']            = f'(Type == "B38 activity") & ' \
                                                  f'(@behavior_end-2 <= Start) & (Start <= @behavior_end+2)'
                                                  # must start within 2 seconds of behavior end

    subepoch_queries['B4/B5 activity']          = f'(Type == "B4/B5 activity") & ' \
                                                  f'(@behavior_start <= Start) & (Start <= @behavior_end)'
                                                  # must start within behavior

    subepoch_queries['Force rise start']        = f'(Type == "Force rise start") & ' \
                                                  f'(@behavior_start <= Start) & (Start <= @behavior_end)'
                                                  # must start within behavior
    
    subepoch_queries['Force plateau start']     = f'(Type == "Force plateau start") & ' \
                                                  f'(@behavior_start <= Start) & (Start <= @behavior_end)'
                                                  # must start within behavior
    
    subepoch_queries['Force plateau end']       = f'(Type == "Force plateau end") & ' \
                                                  f'(@behavior_end-2 <= Start) & (Start <= @behavior_end+2)'
                                                  # must start within 2 seconds of behavior end
    
    subepoch_queries['Force drop end']          = f'(Type == "Force drop end") & ' \
                                                  f'(@behavior_end-2 <= Start) & (Start <= @behavior_end+2)'
                                                  # must start within 2 seconds of behavior end
    
    subepoch_queries['Force shoulder end']      = f'(Type == "Force shoulder end") & ' \
                                                  f'(@behavior_end <= Start) & (Start <= @behavior_end+4)'
                                                  # must start less than 4 seconds after behavior end
    
    # construct a table in which each row is a behavior and subepoch data
    # is added as columns, e.g. df['B38 activity start (s)']
    df = BehaviorsDataFrame(blk.segments[0].epochs, behavior_query, subepoch_queries)

    # renumber behaviors assuming all behaviors are from a single contiguous sequence
    df = df.sort_values('Start (s)').reset_index(drop=True).rename_axis('Behavior_index')

    
    
    ###
    ### START CALCULATIONS
    ###

    # some columns must have type 'object', which
    # can be accomplished by initializing with None or np.nan
    df['Normalization fixed times (s)'] = None
    df['Force, normalized time interpolation (mN)'] = None
    units = [
        'I2 spikes',
        'B8a/b',
        'B3',
        'B6/B9',
        'B38',
        'B4/B5',
    ]
    for unit in units:
        df[unit+' spike train'] = None
        df[unit+' firing rate (Hz)'] = None
        df[unit+' firing rate, normalized time interpolation (Hz)'] = None
        df[unit+' inter-spike intervals (s)'] = None
        df[unit+' all bursts (s)'] = None

        # while we're at it, initialize some other things that might otherwise never be given values
        df[unit+' first burst start (s)'] = np.nan
        df[unit+' first burst end (s)'] = np.nan
        df[unit+' first burst duration (s)'] = 0
        df[unit+' first burst spike count'] = 0
        df[unit+' first burst mean frequency (Hz)'] = np.nan
        df[unit+' last burst start (s)'] = np.nan
        df[unit+' last burst end (s)'] = np.nan
        df[unit+' last burst duration (s)'] = 0
        df[unit+' last burst spike count'] = 0
        df[unit+' last burst mean frequency (Hz)'] = np.nan


    ### SANITY CHECK: plot all channels for entire time window
#     figsize = (9.5, 10) # dimensions for notebook
#     figsize = (11, 8.5) # dimensions for printing
    figsize = (16, 9) # dimensions for wide screens
    fig, axes = plt.subplots(len(channel_names), 1, sharex=True, figsize=figsize)
    channel_units = ['uV', 'uV', 'uV', 'uV', 'mN']
    for i, channel in enumerate(channel_names):
        plt.sca(axes[i])
        sig = get_sig(blk, channel)
        sig = sig.time_slice(time_window[0]*pq.s, time_window[1]*pq.s)
        sig = sig.rescale(channel_units[i])
        plt.plot(sig.times, sig.magnitude, c='0.8', lw=1, zorder=-1)
        
        if i == 0:
            plt.title(f'({animal}, {food}, {bout_index}): {data_set_name}')
            
        plt.ylabel(sig.name + ' (' + sig.units.dimensionality.string + ')')
        axes[i].yaxis.set_label_coords(-0.06, 0.5)
        
        if i < len(channel_names)-1:
            # remove right, top, and bottom plot borders, and remove x-axis
            sns.despine(ax=plt.gca(), bottom=True)
            plt.gca().xaxis.set_visible(False)
        else:
            # remove right and top plot borders, and set x-label
            sns.despine(ax=plt.gca())
            plt.xlabel('Time (s)')
                

    
    ### SANITY CHECK: plot smoothed force for entire time window
    channel = 'Force'
    sig = get_sig(blk, channel)
    if lazy:
        sig = sig.time_slice(None, None)
    sig = elephant.signal_processing.butter(sig, lowpass_freq = 5*pq.Hz)
    sig = sig.time_slice(time_window[0]*pq.s, time_window[1]*pq.s)
    sig = sig.rescale('mN')
    plt.plot(sig.times, sig.magnitude, c='k', lw=1, zorder=0)
    force_smoothed_sig = sig
    
    
    
    # iterate over all swallows
    for j, i in enumerate(df.index):
        
        behavior_start = df.loc[i, 'Start (s)']*pq.s
        behavior_end = df.loc[i, 'End (s)']*pq.s
        
        ###
        ### FORCE
        ###
        
        # quantify force in each behavior
        
        force_rise_start = df.loc[i, 'Force rise start start (s)']*pq.s # start of "Force rise start" epoch
        force_plateau_start = df.loc[i, 'Force plateau start start (s)']*pq.s # start of "Force plateau start" epoch
        force_plateau_end = df.loc[i, 'Force plateau end start (s)']*pq.s # start of "Force plateau end" epoch
        force_drop_end = df.loc[i, 'Force drop end start (s)']*pq.s # start of "Force drop end" epoch
        force_shoulder_end = df.loc[i, 'Force shoulder end start (s)']*pq.s # start of "Force shoulder end" epoch
        
        # get the drop time for the previous swallow and the rise time for the next swallow
        epochs_force_rise_start = next((ep for ep in blk.segments[0].epochs if ep.name == 'Force rise start'), None)
        epochs_force_plateau_start = next((ep for ep in blk.segments[0].epochs if ep.name == 'Force plateau start'), None)
        epochs_force_plateau_end = next((ep for ep in blk.segments[0].epochs if ep.name == 'Force plateau end'), None)
        epochs_force_drop_end = next((ep for ep in blk.segments[0].epochs if ep.name == 'Force drop end'), None)
        epochs_force_shoulder_end = next((ep for ep in blk.segments[0].epochs if ep.name == 'Force shoulder end'), None)
        assert epochs_force_rise_start is not None, 'failed to find "Force rise start" epochs'
        assert epochs_force_plateau_start is not None, 'failed to find "Force plateau start" epochs'
        assert epochs_force_plateau_end is not None, 'failed to find "Force plateau end" epochs'
        assert epochs_force_drop_end is not None, 'failed to find "Force drop end" epochs'
        assert epochs_force_shoulder_end is not None, 'failed to find "Force shoulder end" epochs'
        
        try:
            prev_force_plateau_start = df.loc[i, 'Previous force plateau start (s)'] = epochs_force_plateau_start.time_slice(None, force_rise_start)[-1]
            assert force_rise_start-prev_force_plateau_start < 16*pq.s, f'for swallow {i}, previous force plateau start is too far away'
        except IndexError:
            prev_force_plateau_start = df.loc[i, 'Previous force plateau start (s)'] = np.nan
            
        try:
            prev_force_plateau_end = df.loc[i, 'Previous force plateau end (s)'] = epochs_force_plateau_end.time_slice(None, force_rise_start)[-1]
            assert force_rise_start-prev_force_plateau_end < 12*pq.s, f'for swallow {i}, previous force plateau end is too far away'
        except IndexError:
            prev_force_plateau_end = df.loc[i, 'Previous force plateau end (s)'] = np.nan
        
        try:
            prev_force_drop_end = df.loc[i, 'Previous force drop end (s)'] = epochs_force_drop_end.time_slice(None, force_rise_start)[-1]
            assert force_rise_start-prev_force_drop_end < 12*pq.s, f'for swallow {i}, previous force drop end is too far away'
        except IndexError:
            prev_force_drop_end = df.loc[i, 'Previous force drop end (s)'] = np.nan
        
        try:
            prev_force_shoulder_end = df.loc[i, 'Previous force shoulder end (s)'] = epochs_force_shoulder_end.time_slice(None, force_rise_start)[-1]
            if prev_force_shoulder_end < prev_force_drop_end:
                # previous swallow did not have a shoulder and we instead grabbed an earlier shoulder
                prev_force_shoulder_end = df.loc[i, 'Previous force shoulder end (s)'] = np.nan
        except IndexError:
            prev_force_shoulder_end = df.loc[i, 'Previous force shoulder end (s)'] = np.nan
        
        try:
            next_force_rise_start = df.loc[i, 'Next force rise start (s)'] = epochs_force_rise_start.time_slice(force_drop_end, None)[0]
            assert next_force_rise_start-force_drop_end < 12*pq.s, f'for swallow {i}, next force rise start is too far away'
        except IndexError:
            next_force_rise_start = df.loc[i, 'Next force rise start (s)'] = np.nan
        
        # get the list of fixed times for normalization
        normalization_fixed_times = df.at[i, 'Normalization fixed times (s)'] = np.array([
            prev_force_plateau_start,
            prev_force_plateau_end,
            prev_force_drop_end,
            prev_force_shoulder_end,
            force_rise_start,
            force_plateau_start,
            force_plateau_end,
            force_drop_end,
            force_shoulder_end,
            next_force_rise_start,
        ])*pq.s # 'at', not 'loc', is important for inserting list into cell

        # get smoothed force for whole behavior for remaining force calculations
        sig = force_smoothed_sig
        if np.isfinite(force_shoulder_end):
            sig = sig.time_slice(force_rise_start - 1*pq.s, force_shoulder_end + 0.01*pq.s)
        else:
            sig = sig.time_slice(force_rise_start - 1*pq.s, force_drop_end + 0.01*pq.s)
        sig = sig.rescale('mN')

        # find force peak, baseline, and the increase
        force_min_time = df.loc[i, 'Force minimum time (s)'] = elephant.spike_train_generation.peak_detection(sig, 999*pq.mN, sign='below')[0]
        force_min = df.loc[i, 'Force minimum (mN)'] = sig[sig.time_index(force_min_time)][0]
        force_peak_time = df.loc[i, 'Force peak time (s)'] = elephant.spike_train_generation.peak_detection(sig, 0*pq.mN)[0]
        force_peak = df.loc[i, 'Force peak (mN)'] = sig[sig.time_index(force_peak_time)][0]
        force_baseline = df.loc[i, 'Force baseline (mN)'] = sig[sig.time_index(force_rise_start)][0]
        force_increase = df.loc[i, 'Force increase (mN)'] = force_peak-force_baseline
        
        # find force plateau, drop, and shoulder values
        force_plateau_start_value = df.loc[i, 'Force plateau start value (mN)'] = sig[sig.time_index(force_plateau_start)][0]
        force_plateau_end_value = df.loc[i, 'Force plateau end value (mN)'] = sig[sig.time_index(force_plateau_end)][0]
        force_drop_end_value = df.loc[i, 'Force drop end value (mN)'] = sig[sig.time_index(force_drop_end)][0]
        if np.isfinite(force_shoulder_end):
            force_shoulder_end_value = df.loc[i, 'Force shoulder end value (mN)'] = sig[sig.time_index(force_shoulder_end)][0]
        else:
            force_shoulder_end_value = np.nan

        # find force rise and plateau durations
        force_rise_duration = df.loc[i, 'Force rise duration (s)'] = force_plateau_start - force_rise_start
        force_plateau_duration = df.loc[i, 'Force plateau duration (s)'] = force_plateau_end - force_plateau_start
        force_rise_plateau_duration = df.loc[i, 'Force rise and plateau duration (s)'] = force_plateau_end - force_rise_start
        
        # find average slope during rising phase
        force_rise_increase = df.loc[i, 'Force rise increase (mN)'] = force_plateau_start_value - force_baseline
        force_slope = df.loc[i, 'Force slope (mN/s)'] = (force_rise_increase/force_rise_duration).rescale('mN/s')

        
        ### SANITY CHECK: plot force rise
        plt.sca(axes[channel_names.index('Force')])
        sig2 = sig.time_slice(force_rise_start, force_plateau_start)
        plt.plot(sig2.times, sig2.magnitude, c=force_colors['rise'], lw=2, zorder=1)
        
        ### SANITY CHECK: plot force plateau
        sig2 = sig.time_slice(force_plateau_start, force_plateau_end)
        plt.plot(sig2.times, sig2.magnitude, c=force_colors['plateau'], lw=2, zorder=1)
        
        ### SANITY CHECK: plot force shoulder
        if np.isfinite(force_shoulder_end):
            sig2 = sig.time_slice(force_drop_end, force_shoulder_end)
            plt.plot(sig2.times, sig2.magnitude, c=force_colors['shoulder'], lw=2, zorder=1)

        ### SANITY CHECK: plot force peak, baseline, and plateau values
        plt.plot([force_peak_time],     [force_peak],                marker=7, markersize=5, color='k')
#         plt.plot([force_min_time],      [force_min],                 marker=6, markersize=5, color='k')
        plt.plot([force_rise_start],    [force_baseline],            marker=6, markersize=5, color='k')
        plt.plot([force_plateau_start], [force_plateau_start_value], marker=5, markersize=5, color='k')
        plt.plot([force_plateau_end],   [force_plateau_end_value],   marker=4, markersize=5, color='k')
        
        
        
        ###
        ### FORCE NORMALIZATION
        ###
        
        channel = 'Force'
        sig = get_sig(blk, channel)
        sig = sig.time_slice(prev_force_plateau_start+0.001*pq.s, next_force_rise_start-0.001*pq.s)
        sig = sig.rescale(channel_units[channel_names.index(channel)])
        
        force_interp = df.at[i, 'Force, normalized time interpolation (mN)'] = \
            resample_sig_in_normalized_time(normalization_fixed_times, sig) # 'at', not 'loc', is important for inserting list into cell

        
        
        ###
        ### FIND SPIKE TRAINS
        ###
        
        if lazy:
            if metadata['amplitude_discriminators'] is not None:
                for discriminator in metadata['amplitude_discriminators']:
                    sig = get_sig(blk, discriminator['channel'])
                    if sig is not None:
                        sig = sig.time_slice(behavior_start - 5*pq.s, behavior_end + 5*pq.s)
                        st = _detect_spikes(sig, discriminator, blk.segments[0].epochs)
                        st_epoch_start = df.loc[i, discriminator['epoch']+' start (s)']*pq.s
                        st_epoch_end = df.loc[i, discriminator['epoch']+' end (s)']*pq.s
                        st = st.time_slice(st_epoch_start, st_epoch_end)
                        df.at[i, st.name+' spike train'] = st # 'at', not 'loc', is important for inserting list into cell
        else:
            for spiketrain in blk.segments[0].spiketrains:
                discriminator = next((d for d in metadata['amplitude_discriminators'] if d['name'] == spiketrain.name), None)
                if discriminator is None:
                    raise Exception(f'For data set "{data_set_name}", discriminator "{spiketrain.name}" could not be found')
                st_epoch_start = df.loc[i, discriminator['epoch']+' start (s)']*pq.s
                st_epoch_end = df.loc[i, discriminator['epoch']+' end (s)']*pq.s
                st = spiketrain.time_slice(st_epoch_start, st_epoch_end)
                df.at[i, st.name+' spike train'] = st # 'at', not 'loc', is important for inserting list into cell
    
    
            
        ###
        ### QUANTIFY SPIKE TRAINS AND BURSTS
        ###
        
        for k, unit in enumerate(units):
            st = df.loc[i, unit+' spike train']
            if st is not None:
                df.loc[i, unit+' spike count'] = st.size
                
                # get the neural channel
                channel = st.annotations['channels'][0]
                sig = get_sig(blk, channel)
                    
                # create a continuous smoothed firing rate representation
                # by convolving the spike train with a kernel
                smoothing_kernel = elephant.kernels.GaussianKernel(0.2*pq.s) # 200 ms standard deviation
#                 smoothing_kernel = elephant.kernels.RectangularKernel(0.2*pq.s / (2*np.sqrt(3))) # 200 ms width, 2*sqrt(3) undoes elephant's scaling
                firing_rate = df.at[i, unit+' firing rate (Hz)'] = elephant.statistics.instantaneous_rate(
                    spiketrain=st,
                    t_start=prev_force_plateau_start+0.001*pq.s, # choice of t_start and t_stop here ensures firing rates are recorded as zero far from the burst
                    t_stop=next_force_rise_start-0.001*pq.s,
                    sampling_period=sig.sampling_period,
                    kernel=smoothing_kernel,
                ) # 'at', not 'loc', is important for inserting list into cell
                
                # normalization
                firing_rate_interp = df.at[i, unit+' firing rate, normalized time interpolation (Hz)'] = \
                    resample_sig_in_normalized_time(normalization_fixed_times, firing_rate) # 'at', not 'loc', is important for inserting list into cell
                
                if st.size > 0:

                    # get the signal for the behavior with 10 seconds cushion before and after (for better baseline estimation)
                    sig = sig.time_slice(behavior_start - 10*pq.s, behavior_end + 10*pq.s)
                    sig = sig.rescale(channel_units[channel_names.index(channel)])
                    
                    # find every sequence of spikes that qualifies as a burst
                    bursts = df.at[i, unit+' all bursts (s)'] = find_bursts(st, burst_thresholds[unit]) # 'at', not 'loc', is important for inserting list into cell
                    
                    first_burst_start = np.nan
                    first_burst_end = np.nan
                    first_burst_spike_count = 0
                    first_burst_mean_freq = 0*pq.Hz
                    last_burst_start = np.nan
                    last_burst_end = np.nan
                    last_burst_spike_count = 0
                    last_burst_mean_freq = 0*pq.Hz
                    if len(bursts) > 0:

                        for burst in bursts:
                            if is_good_burst(burst):
                                first_burst_start, first_burst_end = burst['Start (s)'], burst['End (s)']
                                first_burst_duration = first_burst_end-first_burst_start
                                df.loc[i, unit+' first burst start (s)'] = first_burst_start.rescale('s')
                                df.loc[i, unit+' first burst end (s)'] = first_burst_end.rescale('s')
                                first_burst_duration = df.loc[i, unit+' first burst duration (s)'] = first_burst_duration.rescale('s')
                                first_burst_spike_count = df.loc[i, unit+' first burst spike count'] = st.time_slice(first_burst_start, first_burst_end).size
                                first_burst_mean_freq = df.loc[i, unit+' first burst mean frequency (Hz)'] = ((first_burst_spike_count-1)/first_burst_duration).rescale('Hz')
                                
                                # find burst RAUC and mean voltage
                                first_burst_rauc = df.loc[i, unit+' first burst RAUC (μV·s)'] = elephant.signal_processing.rauc(sig, baseline='mean', t_start=first_burst_start, t_stop=first_burst_end).rescale('uV*s')
                                first_burst_mean_rect_voltage = df.loc[i, unit+' first burst mean rectified voltage (μV)'] = first_burst_rauc/first_burst_duration

                                break # quit after finding first good burst

                        for burst in reversed(bursts):
                            if is_good_burst(burst):
                                last_burst_start, last_burst_end = burst['Start (s)'], burst['End (s)']
                                last_burst_duration = last_burst_end-last_burst_start
                                df.loc[i, unit+' last burst start (s)'] = last_burst_start.rescale('s')
                                df.loc[i, unit+' last burst end (s)'] = last_burst_end.rescale('s')
                                last_burst_duration = df.loc[i, unit+' last burst duration (s)'] = last_burst_duration.rescale('s')
                                last_burst_spike_count = df.loc[i, unit+' last burst spike count'] = st.time_slice(last_burst_start, last_burst_end).size
                                last_burst_mean_freq = df.loc[i, unit+' last burst mean frequency (Hz)'] = ((last_burst_spike_count-1)/last_burst_duration).rescale('Hz')

                                # find burst RAUC and mean voltage
                                last_burst_rauc = df.loc[i, unit+' last burst RAUC (μV·s)'] = elephant.signal_processing.rauc(sig, baseline='mean', t_start=last_burst_start, t_stop=last_burst_end).rescale('uV*s')
                                last_burst_mean_rect_voltage = df.loc[i, unit+' last burst mean rectified voltage (μV)'] = last_burst_rauc/last_burst_duration
    
                                break # quit after finding first (actually, last) good burst

                                    
                    ### SANITY CHECK: plot spikes
                    plt.sca(axes[channel_names.index(channel)])
                    marker = ['.', 'x'][j%2] # alternate markers between behaviors
                    
                    spike_amplitudes = np.array([sig[sig.time_index(t)] for t in st]) * pq.Quantity(sig.units)
                    plt.scatter(st.times.rescale('s'), spike_amplitudes, marker=marker, c=unit_colors[unit])

                    ### SANITY CHECK: plot burst windows
                    discriminator = next((d for d in metadata['amplitude_discriminators'] if d['name'] == st.name), None)
                    if discriminator is None:
                        raise Exception(f'For data set "{data_set_name}", discriminator "{st.name}" could not be found')
                    bottom = pq.Quantity(discriminator['amplitude'][0], discriminator['units']).rescale(sig.units)
                    top = pq.Quantity(discriminator['amplitude'][1], discriminator['units']).rescale(sig.units)
                    height = top-bottom
                    for burst in bursts:
                        left = burst['Start (s)']
                        right = burst['End (s)']
                        width = right-left
                        linestyle = '-' if is_good_burst(burst) else '--'
                        rect = patches.Rectangle((left, bottom), width, height, ls=linestyle, edgecolor=unit_colors[unit], fill=False)
                        plt.gca().add_patch(rect)

                    ### SANITY CHECK: plot markers for edges of bursts
                    if top > 0:
                        plt.plot([first_burst_start], [top], marker=7, markersize=5, color='k')
                        plt.plot([last_burst_end], [top], marker=7, markersize=5, color='k')
                    else:
                        plt.plot([first_burst_start], [bottom], marker=6, markersize=5, color='k')
                        plt.plot([last_burst_end], [bottom], marker=6, markersize=5, color='k')
        
        
        
        ###
        ### TIMING DELAYS
        ###
        
        i2_burst_start      = df.loc[i, 'I2 spikes first burst start (s)']*pq.s
        i2_burst_end        = df.loc[i, 'I2 spikes last burst end (s)']*pq.s
        i2_burst_duration   = df.loc[i, 'I2 spikes all bursts duration (s)'] = i2_burst_end - i2_burst_start
        b8_burst_start      = df.loc[i, 'B8a/b first burst start (s)']*pq.s
        b8_burst_end        = df.loc[i, 'B8a/b last burst end (s)']*pq.s
        b8_burst_duration   = df.loc[i, 'B8a/b all bursts duration (s)'] = b8_burst_end - b8_burst_start
        b6b9_burst_start    = df.loc[i, 'B6/B9 first burst start (s)']*pq.s
        b6b9_burst_end      = df.loc[i, 'B6/B9 last burst end (s)']*pq.s
        b6b9_burst_duration = df.loc[i, 'B6/B9 all bursts duration (s)'] = b6b9_burst_end - b6b9_burst_start
        b3_burst_start      = df.loc[i, 'B3 first burst start (s)']*pq.s
        b3_burst_end        = df.loc[i, 'B3 last burst end (s)']*pq.s
        b3_burst_duration   = df.loc[i, 'B3 all bursts duration (s)'] = b3_burst_end - b3_burst_start
        b38_burst_start     = df.loc[i, 'B38 first burst start (s)']*pq.s
        b38_burst_end       = df.loc[i, 'B38 last burst end (s)']*pq.s
        b38_burst_duration  = df.loc[i, 'B38 all bursts duration (s)'] = b38_burst_end - b38_burst_start
        
        df.loc[i, 'Next I2 spikes first burst start (s)'] = np.nan # will be set on next iteration
        df.loc[i, 'Next I2 spikes last burst end (s)'] = np.nan # will be set on next iteration
        df.loc[i, 'Next I2 spikes all bursts duration (s)'] = np.nan # will be set on next iteration
        if j != 0:
            df.loc[df.index[j-1], 'Next I2 spikes first burst start (s)'] = i2_burst_start
            df.loc[df.index[j-1], 'Next I2 spikes last burst end (s)'] = i2_burst_end
            df.loc[df.index[j-1], 'Next I2 spikes all bursts duration (s)'] = i2_burst_end - i2_burst_start
        
        # consider B3/B6/B9 bursting if either B3 or B6/B9 is bursting
        b3b6b9_burst_start    = df.loc[i, 'B3/B6/B9 burst start (s)']    = finite_min(b6b9_burst_start, b3_burst_start)
        b3b6b9_burst_end      = df.loc[i, 'B3/B6/B9 burst end (s)']      = finite_max(b6b9_burst_end,   b3_burst_end)
        b3b6b9_burst_duration = df.loc[i, 'B3/B6/B9 burst duration (s)'] = b3b6b9_burst_end - b3b6b9_burst_start
        
        # consider bursting only if B8a/b and B3/B6/B9 are both bursting
        b8_or_b3b6b9_burst_end = df.loc[i, 'B8a/b and B3/B6/B9 conjunction end (s)'] = \
                                           finite_min(b8_burst_end, b3b6b9_burst_end)
        
        # delays from neural to force
        i2_force_rise_start_delay        = df.loc[i, 'Delay from I2 end to force rise start (s)'] = \
                                                     force_rise_start - i2_burst_end

        b8_force_rise_start_delay        = df.loc[i, 'Delay from B8a/b start to force rise start (s)'] = \
                                                     force_rise_start - b8_burst_start
        b8_force_plateau_start_delay     = df.loc[i, 'Delay from B8a/b start to force plateau start (s)'] = \
                                                     force_plateau_start - b8_burst_start
        b8_force_plateau_end_delay       = df.loc[i, 'Delay from B8a/b end to force plateau end (s)'] = \
                                                     force_plateau_end - b8_burst_end
        
        b6b9_force_rise_start_delay      = df.loc[i, 'Delay from B6/B9 start to force rise start (s)'] = \
                                                     force_rise_start - b6b9_burst_start
        b6b9_force_plateau_start_delay   = df.loc[i, 'Delay from B6/B9 start to force plateau start (s)'] = \
                                                     force_plateau_start - b6b9_burst_start
        b6b9_force_plateau_end_delay     = df.loc[i, 'Delay from B6/B9 end to force plateau end (s)'] = \
                                                     force_plateau_end - b6b9_burst_end

        b3b6b9_force_rise_start_delay    = df.loc[i, 'Delay from B3/B6/B9 start to force rise start (s)'] = \
                                                     force_rise_start - b3b6b9_burst_start
        b3b6b9_force_plateau_start_delay = df.loc[i, 'Delay from B3/B6/B9 start to force plateau start (s)'] = \
                                                     force_plateau_start - b3b6b9_burst_start
        b3b6b9_force_plateau_end_delay   = df.loc[i, 'Delay from B3/B6/B9 end to force plateau end (s)'] = \
                                                     force_plateau_end - b3b6b9_burst_end
        
        b3_force_plateau_start_delay     = df.loc[i, 'Delay from B3 start to force plateau start (s)'] = \
                                                     force_plateau_start - b3_burst_start
        b3_force_plateau_end_delay       = df.loc[i, 'Delay from B3 end to force plateau end (s)'] = \
                                                     force_plateau_end - b3_burst_end
        b8_or_b3b6b9_force_plateau_end_delay = \
                                           df.loc[i, 'Delay from either B8a/b or B3/B6/B9 end to force plateau end (s)'] = \
                                                     force_plateau_end - b8_or_b3b6b9_burst_end
        
        b38_force_shoulder_end_delay     = df.loc[i, 'Delay from B38 end to force shoulder end (s)'] = \
                                                     force_shoulder_end - b38_burst_end

        
        
        ###
        ### B8 ACTIVITY BEFORE B3/B6/B9
        ###
        
        st = df.loc[i, 'B8a/b spike train']
        b8_preb3b6b9_burst_duration    = df.loc[i, 'B8a/b pre-B3/B6/B9 burst duration (s)'] = \
                                                   b3b6b9_burst_start - b8_burst_start
        b8_preb3b6b9_burst_spike_count = df.loc[i, 'B8a/b pre-B3/B6/B9 burst spike count'] = \
                                                   st.time_slice(b8_burst_start, b3b6b9_burst_start).size
        b8_preb3b6b9_burst_mean_freq   = df.loc[i, 'B8a/b pre-B3/B6/B9 burst mean frequency (Hz)'] = \
                                                   ((b8_preb3b6b9_burst_spike_count-1)/b8_preb3b6b9_burst_duration).rescale('Hz')
        
        # get the neural channel
        channel = st.annotations['channels'][0]
        sig = get_sig(blk, channel)
        
        # get the signal for the behavior with 5 seconds cushion before and after (for better baseline estimation)
        sig = sig.time_slice(behavior_start - 5*pq.s, behavior_end + 5*pq.s)
        sig = sig.rescale('uV')

        # find RAUC and mean voltage for B8a/b before B3/B6/B9 start in each behavior
        if np.isfinite(b8_burst_start) and np.isfinite(b3b6b9_burst_start):
            b8_preb3b6b9_rauc = df.loc[i, 'B8a/b pre-B3/B6/B9 burst RAUC (μV·s)'] = elephant.signal_processing.rauc(sig, baseline='mean', t_start=b8_burst_start, t_stop=b3b6b9_burst_start).rescale('uV*s')
            b8_preb3b6b9_mean_rect_voltage = df.loc[i, 'B8a/b pre-B3/B6/B9 burst mean rectified voltage (μV)'] = b8_preb3b6b9_rauc/b8_preb3b6b9_burst_duration
        else:
            print(f'Missing either B8a/b burst and/or B3/B6/B9 burst in data set "{data_set_name}" for swallow spanning times ({behavior_start}, {behavior_end})')

        # get the peak smoothed frequency
        firing_rate = elephant.statistics.instantaneous_rate(
            spiketrain=st,
            t_start=st.t_start,
            sampling_period=sig.sampling_period,
            kernel=smoothing_kernel,
        )
        b8_preb3b6b9_burst_peak_smoothed_freq = df.loc[i, 'B8a/b pre-B3/B6/B9 burst peak smoothed frequency (Hz)'] = \
                                                          firing_rate.time_slice(b8_burst_start, b3b6b9_burst_start).max().rescale('Hz')

            
        # get force during rise and plateau
        sig = get_sig(blk, 'Force')
        sig = sig.time_slice(force_rise_start, force_plateau_end)
        sig = sig.rescale('mN')
        
        # get force at end of B8-only burst, offset by delay
        force_b8_only_rise_end = df.loc[i, 'Force delayed B8-only rise end (s)'] = force_rise_start + b8_preb3b6b9_burst_duration
        force_b8_only_rise_height = df.loc[i, 'Force at delayed B8-only rise end (mN)'] = sig[sig.time_index(force_b8_only_rise_end)][0]

        # find average slope during initial rising phase (before B3/B6/B9 begin, offset by delay)
        force_initial_increase = df.loc[i, 'Force initial increase (mN)'] = (force_b8_only_rise_height-force_baseline).rescale('mN')
        force_initial_slope = df.loc[i, 'Force initial slope (mN/s)'] = (force_initial_increase/b8_preb3b6b9_burst_duration).rescale('mN/s')
        
        
        
        ### SANITY CHECK: plot important times across all subplots, with delays set by I2 end and force min
#         if j == 0:
#             # use first swallow's delay from peak protraction to force start for all units in all swallows
#             muscle_delay = i2_force_rise_start_delay
# #             muscle_delay = b8_force_rise_start_delay
#             axes[-1].text(
#                 force_rise_start, 1.05, f"{muscle_delay.rescale('ms'):.0f} ms delay",
#                 horizontalalignment='right', verticalalignment='center', transform=axes[-1].get_xaxis_transform(),
#                 fontsize=8)
        
        for (t, y, c) in [
                (force_rise_start,    force_baseline,            force_colors['rise']),
                (force_plateau_start, force_plateau_start_value, force_colors['plateau']),
                (force_plateau_end,   force_plateau_end_value,   force_colors['plateau']),
                (force_drop_end,      force_drop_end_value,      force_colors['drop']),
                (force_shoulder_end,  force_shoulder_end_value,  force_colors['shoulder'])]:
            if np.isfinite(y):
#                 plot_vertical_lines_with_delay(axes, t, muscle_delay, y, c)
                axes[-1].add_artist(patches.ConnectionPatch(
                    xyA=(t, y), xyB=(t, 1),
                    coordsA='data', coordsB=axes[0].get_xaxis_transform(),
                    axesA=axes[-1], axesB=axes[0],
                    color=c, lw=1, ls=':', zorder=-2))

            
    
    # perform the following after having gone through all behaviors once
    for j, i in enumerate(df.index):
        
        ###
        ### NORMALIZED TIMES
        ###
        
        normalization_fixed_times = df.loc[i, 'Normalization fixed times (s)']
        normalization_fixed_times = normalization_fixed_times.magnitude
        
        i2_burst_start   = df.loc[i, 'I2 spikes first burst start (s)']
        i2_burst_end     = df.loc[i, 'I2 spikes last burst end (s)']
        b8_burst_start   = df.loc[i, 'B8a/b first burst start (s)']
        b8_burst_end     = df.loc[i, 'B8a/b last burst end (s)']
        b6b9_burst_start = df.loc[i, 'B6/B9 first burst start (s)']
        b6b9_burst_end   = df.loc[i, 'B6/B9 last burst end (s)']
        b3_burst_start   = df.loc[i, 'B3 first burst start (s)']
        b3_burst_end     = df.loc[i, 'B3 last burst end (s)']
        b38_burst_start  = df.loc[i, 'B38 first burst start (s)']
        b38_burst_end    = df.loc[i, 'B38 last burst end (s)']
        b4b5_burst_start = df.loc[i, 'B4/B5 first burst start (s)']
        b4b5_burst_end   = df.loc[i, 'B4/B5 last burst end (s)']
        next_i2_burst_start = df.loc[i, 'Next I2 spikes first burst start (s)']
        next_i2_burst_end   = df.loc[i, 'Next I2 spikes last burst end (s)']

        i2_burst_start_normalized      = df.loc[i, 'I2 first burst start (normalized)']      = normalize_time(normalization_fixed_times,
                                                   i2_burst_start)
        i2_burst_end_normalized        = df.loc[i, 'I2 last burst end (normalized)']         = normalize_time(normalization_fixed_times,
                                                   i2_burst_end)
        
        b8_burst_start_normalized      = df.loc[i, 'B8a/b first burst start (normalized)']   = normalize_time(normalization_fixed_times,
                                                   b8_burst_start)
        b8_burst_end_normalized        = df.loc[i, 'B8a/b last burst end (normalized)']      = normalize_time(normalization_fixed_times,
                                                   b8_burst_end)
        
        b6b9_burst_start_normalized    = df.loc[i, 'B6/B9 first burst start (normalized)']   = normalize_time(normalization_fixed_times,
                                                   b6b9_burst_start)
        b6b9_burst_end_normalized      = df.loc[i, 'B6/B9 last burst end (normalized)']      = normalize_time(normalization_fixed_times,
                                                   b6b9_burst_end)
        
        b3_burst_start_normalized      = df.loc[i, 'B3 first burst start (normalized)']      = normalize_time(normalization_fixed_times,
                                                   b3_burst_start)
        b3_burst_end_normalized        = df.loc[i, 'B3 last burst end (normalized)']         = normalize_time(normalization_fixed_times,
                                                   b3_burst_end)

        b38_burst_start_normalized     = df.loc[i, 'B38 first burst start (normalized)']     = normalize_time(normalization_fixed_times,
                                                   b38_burst_start)
        b38_burst_end_normalized       = df.loc[i, 'B38 last burst end (normalized)']        = normalize_time(normalization_fixed_times,
                                                   b38_burst_end)
        
        b4b5_burst_start_normalized    = df.loc[i, 'B4/B5 first burst start (normalized)']   = normalize_time(normalization_fixed_times,
                                                   b4b5_burst_start)
        b4b5_burst_end_normalized      = df.loc[i, 'B4/B5 last burst end (normalized)']      = normalize_time(normalization_fixed_times,
                                                   b4b5_burst_end)
        
        next_i2_burst_start_normalized = df.loc[i, 'Next I2 first burst start (normalized)'] = normalize_time(normalization_fixed_times,
                                                   next_i2_burst_start)
        next_i2_burst_end_normalized   = df.loc[i, 'Next I2 last burst end (normalized)']    = normalize_time(normalization_fixed_times,
                                                   next_i2_burst_end)

        
        
    ###
    ### FINISH
    ###
    
    # optimize plot margins
    plt.subplots_adjust(
        left   = 0.1,
        right  = 0.99,
        top    = 0.96,
        bottom = 0.06,
        hspace = 0.15,
    )
    
    # export figure
    export_dir2 = os.path.join(export_dir, 'sanity-checks')
    if not os.path.exists(export_dir2):
        os.mkdir(export_dir2)
    plt.gcf().savefig(os.path.join(export_dir2, f'{animal} {food} {bout_index}.png'), dpi=300)
    
    # index the table on 4 variables so that this dataframe can later be merged with others
    df['Animal'] = animal
    df['Food'] = food
    df['Bout_index'] = bout_index
    df = df.reset_index().set_index(['Animal', 'Food', 'Bout_index', 'Behavior_index'])
    
    df_list += [df]
    
df_all = pd.concat(df_list, sort=False).sort_index()

# move exemplary behaviors to separate dataframes
df_exemplary_swallow = df_all.loc[exemplary_swallow].copy()
df_exemplary_bout = df_all.loc[exemplary_bout].copy()
df_all = df_all.drop(exemplary_swallow)
df_all = df_all.drop(exemplary_bout)

# rename output files for exemplars
# os.replace overwrites an existing destination file in a single call (on both
# POSIX and Windows), so no separate exists/remove step is needed
for exemplar, new_name in [(exemplary_swallow, 'exemplary_swallow.png'),
                           (exemplary_bout, 'exemplary_bout.png')]:
    old_path = os.path.join(export_dir2, f'{exemplar[0]} {exemplar[1]} {exemplar[2]}.png')
    new_path = os.path.join(export_dir2, new_name)
    if os.path.exists(old_path):
        os.replace(old_path, new_path)

end = datetime.datetime.now()
print('elapsed time:', end-start)

In [None]:
# Sanity-check figures of firing rates and force, one figure per feeding bout
# (the exemplary lone swallow is skipped). Each figure has one row per unit plus
# a bottom force row. The left column is real time (spike rasters and firing
# rates); the right column is normalized time (each behavior's interpolated
# trace, with the across-behavior median in black and the 25th/75th percentiles
# in gray). Figures are saved to <export_dir>/sanity-checks-firing-rates/, and
# the exemplary bout's file is renamed to exemplary_bout.png at the end.
start = datetime.datetime.now()

for (animal, food, bout_index), (data_set_name, channel_names, time_window, epoch_types_to_keep, burst_thresholds) in feeding_bouts.items():

    if (animal, food, bout_index) == exemplary_swallow:
        # skip the lone swallow
        continue
    elif (animal, food, bout_index) == exemplary_bout:
        df = df_exemplary_bout
    else:
        # rows of this bout, selected by the first three levels of df_all's
        # (Animal, Food, Bout_index, Behavior_index) index
        df = df_all.loc[animal, food, bout_index]

    # load the data
    metadata = neurotic.MetadataSelector('../../data/metadata.yml')
    metadata.select(data_set_name)
    blk = neurotic.load_dataset(metadata, lazy=True)

    units = [
        'I2 spikes',
        'B8a/b',
        'B3',
        'B6/B9',
        'B38',
        'B4/B5',
    ]

    # figsize = (9.5, 10) # dimensions for notebook
    # figsize = (11, 8.5) # dimensions for printing
    figsize = (16, 9) # dimensions for wide screens
    # one row per unit plus a force row; columns are real time (left) and
    # normalized time (right), each column sharing its x-axis
    # NOTE(review): a new figure is created on every iteration and none is closed
    # in this cell, so many figures may stay open when there are many bouts
    fig, axes = plt.subplots(len(units)+1, 2, sharex='col', figsize=figsize)

    for k, unit in enumerate(units):
        # get the subplot axes handles
        ax_left, ax_right = axes[k]

        # set y-axis label
        ax_left.set_ylabel(unit+' (Hz)')
        ax_left.yaxis.set_label_coords(-0.06, 0.5)

        # remove right, top, and bottom plot borders, and remove x-axis
        sns.despine(ax=ax_left, bottom=True)
        sns.despine(ax=ax_right, bottom=True)
        ax_left.xaxis.set_visible(False)
        ax_right.xaxis.set_visible(False)

    # remove right and top plot borders from bottom panel, and set x-label
    ax_left, ax_right = axes[-1]
    sns.despine(ax=ax_left)
    sns.despine(ax=ax_right)
    ax_left.set_xlabel('Time (s)')
    ax_right.set_xlabel('Time (normalized)')

    # plot force in real time
    ax_left, ax_right = axes[-1]
    channel = 'Force'
    sig = get_sig(blk, channel)
    sig = sig.time_slice(time_window[0]*pq.s, time_window[1]*pq.s)
    # NOTE(review): channel_units is not assigned in this cell; it is assumed to be a
    # module-level list parallel to channel_names from an earlier cell — confirm
    sig = sig.rescale(channel_units[channel_names.index(channel)])
    ax_left.plot(sig.times, sig.magnitude, c='0.8', lw=1)
    ax_left.set_ylabel(sig.name + ' (' + sig.units.dimensionality.string + ')')
    ax_left.yaxis.set_label_coords(-0.06, 0.5)
    
    # per-unit (and force) stacks of normalized-time traces: each starts with zero
    # rows of times_interp.size columns, and one row per behavior is appended below
    all_normalized_times_series = {}
    for unit in units:
        all_normalized_times_series[unit] = np.zeros((0, times_interp.size))
    all_normalized_times_series['Force'] = np.zeros((0, times_interp.size))

    for j, i in enumerate(df.index):
        for k, unit in enumerate(units):
            ax_left, ax_right = axes[k]

            
            # raster plot
            st = df.loc[i, unit+' spike train']
            ax_left.eventplot(positions=st, lineoffsets=-1, colors=unit_colors[unit])

            
            # plot the firing rates in real time
            firing_rate = df.loc[i, unit+' firing rate (Hz)']
            ax_left.plot(firing_rate.times.rescale('s'), firing_rate, c=unit_colors[unit])

            
            # plot firing rates in normalized time
            firing_rate_interp = df.loc[i, unit+' firing rate, normalized time interpolation (Hz)']
            all_normalized_times_series[unit] = np.concatenate([all_normalized_times_series[unit], firing_rate_interp[np.newaxis, :]])
            ax_right.plot(times_interp, firing_rate_interp, c=lighten_color(unit_colors[unit], amount=0.7))

            
        # plot force in normalized time
        ax_left, ax_right = axes[-1]
        force_interp = df.at[i, 'Force, normalized time interpolation (mN)']
        all_normalized_times_series['Force'] = np.concatenate([all_normalized_times_series['Force'], force_interp[np.newaxis, :]])
        ax_right.plot(times_interp, force_interp, c='0.8', lw=1)

        
        # plot force phase boundaries in real time
        # fixed times (in s) used for the time normalization; entries 3..7 are drawn as
        # dotted lines from the bottom row up through the top row, the first of them
        # (index 3) in the force-rise color and the rest in gray
        normalization_fixed_times = df.at[i, 'Normalization fixed times (s)'].rescale('s').magnitude
        for m, t in enumerate(normalization_fixed_times[3:8]):
            if m == 0:
                color = force_colors['rise']
            else:
                color = '0.75'
#             axes[-1][0].axvline(x=t, lw=1, ls=':', c=color, zorder=-1)
            axes[-1][0].add_artist(patches.ConnectionPatch(
                xyA=(t, 0), xyB=(t, 1),
                coordsA=axes[-1][0].get_xaxis_transform(), coordsB=axes[0][0].get_xaxis_transform(),
                axesA=axes[-1][0], axesB=axes[0][0],
                color=color, lw=1, ls=':'))

            
    # plot force phase boundaries in normalized time
    # in normalized time the m-th fixed time sits at x = m
    # NOTE(review): normalization_fixed_times is whatever the last behavior of the loop
    # above left behind; this assumes every behavior has the same number of fixed
    # times and that df is not empty
    for m in range(len(normalization_fixed_times)):
        if m == 3:
            color = force_colors['rise']
        else:
            color = '0.75'
#         axes[-1][1].axvline(x=m, lw=1, ls=':', c=color, zorder=-1)
        axes[-1][1].add_artist(patches.ConnectionPatch(
            xyA=(m, 0), xyB=(m, 1),
            coordsA=axes[-1][1].get_xaxis_transform(), coordsB=axes[0][1].get_xaxis_transform(),
            axesA=axes[-1][1], axesB=axes[0][1],
            color=color, lw=1, ls=':'))
  
        
    # plot firing rate distributions
    # NaN-aware median and quartiles across behaviors at each normalized time point
    for k, unit in enumerate(units):
        ax_left, ax_right = axes[k]

        firing_rate_median = np.nanmedian(all_normalized_times_series[unit], axis=0)
        firing_rate_q1 = np.nanquantile(all_normalized_times_series[unit], q=0.25, axis=0)
        firing_rate_q3 = np.nanquantile(all_normalized_times_series[unit], q=0.75, axis=0)
        ax_right.plot(times_interp, firing_rate_median, c='k', lw=2, zorder=3)
        ax_right.plot(times_interp, firing_rate_q1, c='gray', lw=1)
        ax_right.plot(times_interp, firing_rate_q3, c='gray', lw=1)

    
    # plot force distribution
    ax_left, ax_right = axes[-1]

    force_median = np.nanmedian(all_normalized_times_series['Force'], axis=0)
    force_q1 = np.nanquantile(all_normalized_times_series['Force'], q=0.25, axis=0)
    force_q3 = np.nanquantile(all_normalized_times_series['Force'], q=0.75, axis=0)
    ax_right.plot(times_interp, force_median, c='k', lw=2, zorder=3)
    ax_right.plot(times_interp, force_q1, c='gray', lw=1)
    ax_right.plot(times_interp, force_q3, c='gray', lw=1)


    plt.suptitle(f'({animal}, {food}, {bout_index}): {data_set_name}')
    # leave the top 3% of the figure for the suptitle
    plt.tight_layout(rect=(0, 0, 1, 0.97))

    # export figure
    export_dir3 = os.path.join(export_dir, 'sanity-checks-firing-rates')
    if not os.path.exists(export_dir3):
        os.mkdir(export_dir3)
    plt.gcf().savefig(os.path.join(export_dir3, f'{animal} {food} {bout_index}.png'), dpi=300)

# rename output file for exemplar
# export_dir3 is assigned inside the loop above, so this relies on at least one
# bout having been plotted; an existing exemplary_bout.png is replaced
old_path = os.path.join(export_dir3, f'{exemplary_bout[0]} {exemplary_bout[1]} {exemplary_bout[2]}.png')
new_path = os.path.join(export_dir3, 'exemplary_bout.png')
if os.path.exists(old_path):
    if os.path.exists(new_path):
        os.remove(new_path)
    os.rename(old_path, new_path)

end = datetime.datetime.now()
print('elapsed time:', end-start)

### Special dataframe used for Fig 3C only

In [None]:
feeding_bouts_multiple_foods = {
    # (animal, food, bout_index): (data_set_name, time_window, epoch_types_to_keep)
        
    ('JG12', 'Regular nori', 0): ('IN VIVO / JG12 / 2019-05-10 / 002', [ 147,  165], ['Swallow (regular 5-cm nori strip)']),
    ('JG12', 'Regular nori', 1): ('IN VIVO / JG12 / 2019-05-10 / 002', [ 229,  245], ['Swallow (regular 5-cm nori strip)']),
    ('JG12', 'Regular nori', 2): ('IN VIVO / JG12 / 2019-05-10 / 002', [ 277,  291], ['Swallow (regular 5-cm nori strip)']),

    ('JG12', 'Tape nori',  103): ('IN VIVO / JG12 / 2019-05-10 / 002', [2890, 2941], ['Swallow (tape nori)']), # early swallows only
}

# filter epochs for each feeding condition and perform calculations
# (`metadata` is the neurotic.MetadataSelector created in an earlier cell)
df_list = []
last_data_set_name = None
for (animal, food, bout_index), (data_set_name, time_window, epoch_types_to_keep) in feeding_bouts_multiple_foods.items():

    metadata.select(data_set_name)
    # skip reloading the data if it's already in memory; compare by value (==),
    # since two equal strings are not guaranteed to be the same object
    if data_set_name != last_data_set_name:
        blk = neurotic.load_dataset(metadata, lazy=True)
    last_data_set_name = data_set_name

    # construct a query for locating behaviors
    behavior_query = f'(Type in {epoch_types_to_keep}) & ({time_window[0]} <= Start) & (End <= {time_window[1]})'

    # construct queries for locating epochs associated with each behavior
    # - each query should match at most one epoch
    # - dictionary keys are used as prefixes for the names of new columns
    subepoch_queries = {}

    # construct a table in which each row is a behavior and subepoch data
    # is added as columns, e.g. df['Force start (s)']
    df = BehaviorsDataFrame(blk.segments[0].epochs, behavior_query, subepoch_queries)

    # order behaviors chronologically and renumber them, assuming all behaviors
    # are from a single contiguous sequence
    df = df.sort_values('Start (s)').reset_index(drop=True).rename_axis('Behavior_index')

    # interbehavior interval: time from the end of each behavior to the start of
    # the next one (NaN for the last); computed after sorting so that "next" is
    # the chronologically next behavior
    df['Interval after (s)'] = df['Start (s)'].shift(-1) - df['End (s)']

    # index the table on 4 variables
    df['Animal'] = animal
    df['Food'] = food
    df['Bout_index'] = bout_index
    df = df.reset_index().set_index(['Animal', 'Food', 'Bout_index', 'Behavior_index'])

    df_list += [df]

df_durations_intervals = pd.concat(df_list, sort=False).sort_index()

## Plotting Functions

In [None]:
def _resolve_axis_limits(values, lim, padding):
    """Return ``lim`` with missing (None) bounds filled in from the data extent.

    The automatic bounds are the data minimum and maximum, each widened by
    ``padding`` times the data range. ``lim`` may be None (both bounds
    automatic) or a 2-sequence whose None entries are replaced. A new list is
    returned, so the caller's object is never modified.
    """
    span = np.ptp(values)
    auto = [min(values) - span * padding, max(values) + span * padding]
    if lim is None:
        return auto
    lim = list(lim)  # copy: never mutate the caller's list
    for side in (0, 1):
        if lim[side] is None:
            lim[side] = auto[side]
    return lim


def _plot_ols_trend(ax, points, xlim, color, label):
    """Fit y ~ x by ordinary least squares, print the fit statistics, and draw
    the fitted line across ``xlim``.

    ``points`` is a two-column DataFrame (x in column 0, y in column 1) without
    missing values. The printed line is ``'<label>: R$^2$ = ..., p = ..., n = ...'``.
    """
    model = sm.OLS(points.iloc[:, 1], sm.add_constant(points.iloc[:, 0])).fit()
    # with pandas inputs, params/pvalues are Series labelled ('const', <x column>);
    # convert to arrays so that [0]/[1] are explicitly positional instead of relying
    # on pandas' deprecated integer-key fallback
    params = np.asarray(model.params)
    pvalues = np.asarray(model.pvalues)
    model_stats = 'R$^2$ = {:.2f}, p = {:.3f}, n = {}'.format(model.rsquared, pvalues[1], len(points))
    print(label + ':', model_stats)
    model_x = np.linspace(xlim[0], xlim[1], 100)
    model_y = params[0] + params[1] * model_x
    ax.plot(model_x, model_y, color=color)


def scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim, trend=False, trend_separately=False, tooltips=False, padding=0.05):
    """Scatter-plot two columns of the global ``df_all`` for several named subsets.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    data_subsets : dict mapping a legend label to a ``DataFrame.query`` string
        (a None query skips that subset). The j-th entry is drawn with the
        module-level ``markers[j]`` and ``colors[j]``.
    xlabel, ylabel : names of the ``df_all`` columns used for x and y; also used
        as the axis labels.
    xlim, ylim : axis ranges; None, or a 2-sequence whose None entries are filled
        from the data (padded by ``padding``). The caller's objects are not modified.
    trend : if True, fit and draw one gray OLS line through all points of all subsets.
    trend_separately : if True, fit and draw one OLS line per subset, in its color.
    tooltips : if True, show the index of the point under the mouse on hover.
    padding : fraction of the data range added on each side for automatic limits.
    """
    for j, (label, query) in enumerate(data_subsets.items()):
        if query is not None:
            df = df_all.query(query)
            ax.scatter(df[xlabel], df[ylabel],
                       label=label, marker=markers[j], c=colors[j], clip_on=False)

    all_points = df_all.query(query_union(data_subsets.values()))[[xlabel, ylabel]].dropna()

    if len(all_points) > 0:
        xlim = _resolve_axis_limits(all_points.iloc[:, 0], xlim, padding)
        ylim = _resolve_axis_limits(all_points.iloc[:, 1], ylim, padding)

    if trend_separately:
        for j, (label, query) in enumerate(data_subsets.items()):
            if query is not None:
                subset = df_all.query(query)[[xlabel, ylabel]].dropna()
                _plot_ols_trend(ax, subset, xlim, colors[j], label)

    if trend:
        _plot_ols_trend(ax, all_points, xlim, 'gray', 'All points')

    if tooltips:
        # create a transparent layer containing all points for detecting mouse events
        sc = ax.scatter(all_points.iloc[:, 0], all_points.iloc[:, 1], label=None,
                        alpha=0, # transparent points
                        s=0.001, # small radius of tooltip activation
                       )

        # initialize a hidden empty tooltip
        annot = ax.annotate(
            '', xy=(0, 0),                              # initialize empty at origin
            xytext=(0, 15), textcoords='offset points', # position text above point
            color='k', size=10, ha='center',            # small black centered text
            bbox=dict(fc='w', lw=0, alpha=0.6),         # transparent white background
            arrowprops=dict(shrink=0, headlength=7, headwidth=7, width=0, lw=0, color='k'), # small black arrow
        )
        annot.set_visible(False)

        # prepare tooltip contents: the row's index values joined by spaces
        tooltip_text = [' '.join([str(x) for x in index]) for index in list(all_points.index)]

        # bind the animation of tooltips to a hover event
        def hover(event):
            if event.inaxes == ax:
                cont, ind = sc.contains(event)
                if cont:
                    point_index = ind['ind'][0]
                    pos = sc.get_offsets()[point_index]
                    annot.xy = pos
                    annot.set_text(tooltip_text[point_index])
                    annot.set_visible(True)
                    ax.figure.canvas.draw_idle()
                else:
                    if annot.get_visible():
                        annot.set_visible(False)
                        ax.figure.canvas.draw_idle()
        ax.figure.canvas.mpl_connect('motion_notify_event', hover)

    ax.set_xlabel(xlabel)
    ax.set_xlim(xlim)
    ax.set_ylabel(ylabel)
    ax.set_ylim(ylim)

In [None]:
from matplotlib.offsetbox import AnchoredOffsetbox
class AnchoredScaleBar(AnchoredOffsetbox):
    def __init__(self, transform, sizex=0, sizey=0, labelx=None, labely=None, loc=4,
                 pad=0.1, borderpad=0.1, sep=2, prop=None, barcolor="black", barwidth=None, 
                 **kwargs):
        """
        Draw a horizontal and/or vertical scale bar sized in the data coordinates
        of the given axes, with optional text labels.
        - transform : the coordinate frame (typically axes.transData)
        - sizex, sizey : length of the x and y bars, in data units; 0 to omit.
          The x bar extends in the -x direction from the box origin, the y bar in +y.
        - labelx, labely : labels for the x and y bars; None to omit. The x label
          is placed below the bar(s) and the y label to their right, center-aligned.
        - loc : position in containing axes
        - pad, borderpad : padding, in fraction of the legend font size (or prop)
        - sep : separation between labels and bars in points
        - prop : font properties passed to the base class
        - barcolor : edge color of the bars
        - barwidth : line width of the bars (None for the Matplotlib default)
        - **kwargs : additional arguments passed to base class constructor
        """
        from matplotlib.patches import Rectangle
        from matplotlib.offsetbox import AuxTransformBox, VPacker, HPacker, TextArea

        bars = AuxTransformBox(transform)
        if sizex:
            # negative width: the bar is drawn to the left of the origin
            bars.add_artist(Rectangle((0, 0), -sizex, 0, ec=barcolor, lw=barwidth, fc="none"))
        if sizey:
            bars.add_artist(Rectangle((0, 0), 0, sizey, ec=barcolor, lw=barwidth, fc="none"))

        if sizex and labelx:
            try:
                self.xlabel = TextArea(labelx, minimumdescent=False)
            except TypeError:
                # `minimumdescent` was deprecated in Matplotlib 3.3 and later removed;
                # newer versions reject it, so fall back to the plain constructor
                self.xlabel = TextArea(labelx)
            bars = VPacker(children=[bars, self.xlabel], align="center", pad=0, sep=sep)
        if sizey and labely:
            self.ylabel = TextArea(labely)
            bars = HPacker(children=[bars, self.ylabel], align="center", pad=0, sep=sep)

        super().__init__(loc, pad=pad, borderpad=borderpad,
                         child=bars, prop=prop, frameon=False, **kwargs)

In [None]:
def prettyplot_with_scalebars(
    blk,
    t_start,
    t_stop,
    plots,
    
    outfile_basename=None, # base name of output files
    export_only=False,     # if True, will not render in notebook
    formats=('pdf', 'svg', 'png'), # extensions of output files (any iterable of str)
    dpi=300,               # resolution (applicable only for PNG)
    
    figsize=(14, 7),       # figure size in inches
    linewidth=1,           # thickness of lines in points
    layout_settings=None,  # positioning of plot edges and the space between plots
    
    x_scalebar=1*pq.s,     # size of the time scale bar in seconds
    ylabel_padding=10,     # space between trace labels and plots
    scalebar_padding=1,    # space between scale bars and plots
    scalebar_sep=5,        # space between scale bars and scale labels
    barwidth=2,            # thickness of scale bars
):
    """
    Plot vertically stacked signal traces from a Neo block, with scale bars
    in place of axes.

    Each entry of `plots` describes one subplot, from top to bottom:
    - 'channel' (required): name of the analog signal in blk.segments[0]
    - 'units' (required): units the signal is rescaled to before plotting
    - 'ylim' (required): y-axis range of the subplot
    - 'scalebar' (optional): length of the vertical scale bar in 'units';
      None or absent to omit it
    - 'ylabel' (optional): trace label; defaults to the signal name, None to omit
    - 'color' (optional): trace color, default 'k'
    - 'decimation_factor' (optional): passed to DownsampleNeoSignal, default 1

    Traces are restricted to [t_start, t_stop]. Unless `x_scalebar` is None, a
    horizontal time scale bar is drawn below the last subplot. If
    `outfile_basename` is given, the figure is saved once per extension in
    `formats`. If `export_only` is True, interactive mode is switched off while
    plotting and restored afterwards, even if an error occurs.

    Returns (fig, axes), where axes is a flat array with one Axes per plot.
    Raises ValueError if a requested channel is not in the block.
    """
    # remember the interactive state so it can be restored even if plotting fails
    was_interactive = plt.isinteractive()
    if export_only:
        plt.ioff()

    try:
        fig, axes = plt.subplots(len(plots), 1, sharex=True, figsize=figsize, squeeze=False)
        axes = axes.flatten()

        for ax, p in zip(axes, plots):

            # select and rescale a channel for the subplot
            sig = next((s for s in blk.segments[0].analogsignals if s.name == p['channel']), None)
            if sig is None:
                raise ValueError(f"Signal with name {p['channel']} not found")
            sig = sig.time_slice(t_start, t_stop)
            sig = sig.rescale(p['units'])
            time_units = sig.times.units

            # downsample the data
            sig_downsampled = DownsampleNeoSignal(sig, p.get('decimation_factor', 1))

            # specify the x- and y-data for the subplot
            ax.plot(
                sig_downsampled.times,
                sig_downsampled.as_quantity(),
                linewidth=linewidth,
                color=p.get('color', 'k'),
            )

            # hide the box around the subplot
            ax.set_frame_on(False)

            # specify the y-axis label
            ylabel = p.get('ylabel', sig.name)
            if ylabel is not None:
                ax.set_ylabel(ylabel, rotation='horizontal', ha='right', va='center', labelpad=ylabel_padding)

            # specify the plot range
            ax.set_xlim([t_start, t_stop])
            ax.set_ylim(p['ylim'])

            # disable tick marks
            ax.tick_params(
                bottom=False,
                left=False,
                labelbottom=False,
                labelleft=False)

            # add y-axis scale bar just outside the right edge of the subplot
            scalebar = p.get('scalebar')
            if scalebar is not None:
                ax.add_artist(AnchoredScaleBar(
                    ax.transData,
                    sizey=scalebar,
                    labely=f'{scalebar} {sig.units.dimensionality.string}',

                    loc='center left',
                    bbox_to_anchor=(1, 0.5),
                    bbox_transform=ax.transAxes,

                    pad=0,
                    borderpad=scalebar_padding,
                    sep=scalebar_sep,
                    barwidth=barwidth,
                ))

        # add time scale bar below final plot, sized in the signals' time units
        if x_scalebar is not None:
            axes[-1].add_artist(AnchoredScaleBar(
                axes[-1].transData,
                sizex=x_scalebar.rescale(time_units).magnitude,
                labelx=f'{x_scalebar.magnitude:g} {x_scalebar.units.dimensionality.string}',

                loc='upper right',
                bbox_to_anchor=(1, 0),
                bbox_transform=axes[-1].transAxes,

                pad=0,
                borderpad=scalebar_padding,
                sep=scalebar_sep,
                barwidth=barwidth,
            ))

        # adjust the white space around and between the subplots
        if layout_settings is None:
            fig.tight_layout(h_pad=0, w_pad=0, pad=0)
        else:
            plt.subplots_adjust(**layout_settings)

        if outfile_basename is not None:
            # specify file metadata (applicable only for PDF)
            metadata = dict(
                Subject = 'Data file: '  + str(blk.file_origin) + '\n' +
                          'Start time: ' + str(t_start)         + '\n' +
                          'End time: '   + str(t_stop),
            )

            # write the figure to files
            for ext in formats:
                fig.savefig(outfile_basename+'.'+ext, metadata=metadata, dpi=dpi)

        return fig, axes

    finally:
        if export_only and was_interactive:
            plt.ion()

---

# Figures

## [FIGURE 1]

### 🐌 Figure 1A

In [None]:
# Figure 1A: stacked raw traces (I2, RN, BN2, BN3, force) for the exemplary bout,
# with a dotted vertical line at the time of a video frame; saved as figure-1A.png.
start = datetime.datetime.now()

(data_set_name, channel_names, time_window, epoch_types_to_keep, burst_thresholds) = feeding_bouts[exemplary_bout]

t_start, t_stop = exemplary_bout_plot_range*pq.s
# one dict per stacked trace (top to bottom): channel, display units, y range,
# vertical scale bar length in those units, and optional label / decimation factor
plots = [
    {'channel': 'I2',       'units': 'uV', 'ylim': [ -60,  60], 'scalebar': 50}, #, 'decimation_factor': 100},
    {'channel': 'RN',       'units': 'uV', 'ylim': [ -25,  25], 'scalebar': 25}, #, 'decimation_factor': 100},
    {'channel': 'BN2',      'units': 'uV', 'ylim': [ -45,  45], 'scalebar': 50}, #, 'decimation_factor': 100},
    {'channel': 'BN3-DIST', 'units': 'uV', 'ylim': [ -60,  60], 'scalebar': 50, 'ylabel': 'BN3'}, #, 'decimation_factor': 100},
    {'channel': 'Force',    'units': 'mN', 'ylim': [ -20, 375], 'scalebar': 300, 'decimation_factor': 100},
]
# rebinds the module-level channel_names (just unpacked from feeding_bouts) and
# channel_units to the plotted channels
channel_names = [p['channel'] for p in plots]
channel_units = [p['units'] for p in plots]

# NOTE: prettyplot_with_scalebars only writes files itself when outfile_basename is
# passed, so `formats` and `dpi` below have no effect here; the figure is saved
# explicitly with fig.savefig at the end of this cell
kwargs = dict(
    formats = ['png'],
    dpi = 600,
    figsize = (7, 5),
    linewidth = 0.5,
    x_scalebar = 5*pq.s,
)

# load the data
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)
blk = neurotic.load_dataset(metadata) # not lazy so that I2 filter is applied

# plot the data
fig, axes = prettyplot_with_scalebars(blk, t_start, t_stop, plots, **kwargs)

# plot a zero baseline for force
# ax = axes[channel_names.index('Force')]
# force_zero = -11.5 # mN, average before animal began swallowing, not zero because of DC offset
# ax.axhline(force_zero, color='0.5', lw=1, ls='--', zorder=-1)

# plot a vertical line marking time of video frame
# the line runs from the bottom of the last subplot to the top of the first one;
# the bottom axes is lowered in z-order so the line is drawn behind the traces
video_time = 2955 # sec
axes[-1].set_zorder(-1)
axes[-1].add_artist(patches.ConnectionPatch(
    xyA=(video_time, 0), xyB=(video_time, 1),
    coordsA=axes[-1].get_xaxis_transform(), coordsB=axes[0].get_xaxis_transform(),
    axesA=axes[-1], axesB=axes[0],
    color='gray', lw=1, ls=':'))

fig.savefig(os.path.join(export_dir, 'figure-1A.png'), dpi=600)

end = datetime.datetime.now()
print('render time:', end-start)

## [FIGURE 2]

Biomechanics of swallowing

## [FIGURE 3]

### 🐌 Figure 3A

NOTE: The plot output by this code spans much more time than will appear in the final figure, so it must be cropped manually. It is rendered with the same time span as Fig 3B (164 s) so that the two panels share an identical time scale.

In [None]:
# Figure 3A: stacked raw traces during regular nori strip swallows, rendered over a
# 164 s window (the same width as Fig 3B) and saved uncropped as
# figure-3A-needs-cropped.png.
start = datetime.datetime.now()

data_set_name = 'IN VIVO / JG12 / 2019-05-10 / 002'
t_start, t_stop = [228.8, 392.8] * pq.s # t=278, twidth=164, first few are regular nori strip swallows
# one dict per stacked trace (top to bottom): channel, display units, y range,
# vertical scale bar length in those units, and optional label / decimation factor
plots = [
    {'channel': 'I2',       'units': 'uV', 'ylim': [ -60,  60], 'scalebar': 50}, #, 'decimation_factor': 400},
    {'channel': 'RN',       'units': 'uV', 'ylim': [ -25,  25], 'scalebar': 25}, #, 'decimation_factor': 400},
    {'channel': 'BN2',      'units': 'uV', 'ylim': [ -45,  45], 'scalebar': 50}, #, 'decimation_factor': 400},
    {'channel': 'BN3-DIST', 'units': 'uV', 'ylim': [ -60,  60], 'scalebar': 50, 'ylabel': 'BN3'}, #, 'decimation_factor': 400},
    {'channel': 'Force',    'units': 'mN', 'ylim': [ -50, 450], 'scalebar': 300, 'decimation_factor': 100},
]
# rebinds the module-level channel_names / channel_units to the plotted channels
channel_names = [p['channel'] for p in plots]
channel_units = [p['units'] for p in plots]

# NOTE: `formats` and `dpi` only matter when outfile_basename is passed to
# prettyplot_with_scalebars; here the figure is saved explicitly below
kwargs = dict(
    formats = ['png'],
    dpi = 600,
    figsize = (14, 5),
    linewidth = 0.5,
    x_scalebar = 10*pq.s,
)

# load the data
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)
blk = neurotic.load_dataset(metadata) # not lazy so that I2 filter is applied

# plot the data
fig, axes = prettyplot_with_scalebars(blk, t_start, t_stop, plots, **kwargs)

fig.savefig(os.path.join(export_dir, 'figure-3A-needs-cropped.png'), dpi=600)

end = datetime.datetime.now()
print('render time:', end-start)

### 🐌 Figure 3B

In [None]:
# Figure 3B: stacked raw traces over 164 s of tape nori swallows, with a dashed
# force baseline and a gray band marking the exemplary bout's plot range; saved as
# figure-3B.png.
start = datetime.datetime.now()

data_set_name = 'IN VIVO / JG12 / 2019-05-10 / 002'
t_start, t_stop = [2879, 3043] * pq.s # t=2928.2, twidth=164, 19 tape nori swallows
# one dict per stacked trace (top to bottom): channel, display units, y range,
# vertical scale bar length in those units, and optional label / decimation factor
plots = [
    {'channel': 'I2',       'units': 'uV', 'ylim': [ -60,  60], 'scalebar': 50}, #, 'decimation_factor': 400},
    {'channel': 'RN',       'units': 'uV', 'ylim': [ -25,  25], 'scalebar': 25}, #, 'decimation_factor': 400},
    {'channel': 'BN2',      'units': 'uV', 'ylim': [ -45,  45], 'scalebar': 50}, #, 'decimation_factor': 400},
    {'channel': 'BN3-DIST', 'units': 'uV', 'ylim': [ -60,  60], 'scalebar': 50, 'ylabel': 'BN3'}, #, 'decimation_factor': 400},
    {'channel': 'Force',    'units': 'mN', 'ylim': [ -50, 450], 'scalebar': 300, 'decimation_factor': 100},
]
# rebinds the module-level channel_names / channel_units to the plotted channels
channel_names = [p['channel'] for p in plots]
channel_units = [p['units'] for p in plots]

# NOTE: `formats` and `dpi` only matter when outfile_basename is passed to
# prettyplot_with_scalebars; here the figure is saved explicitly below
kwargs = dict(
    formats = ['png'],
    dpi = 600,
    figsize = (14, 5),
    linewidth = 0.5,
    x_scalebar = 10*pq.s,
)

# load the data
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)
blk = neurotic.load_dataset(metadata) # not lazy so that I2 filter is applied

# plot the data
fig, axes = prettyplot_with_scalebars(blk, t_start, t_stop, plots, **kwargs)

# plot a zero baseline for force
ax = axes[channel_names.index('Force')]
force_zero = -11.5 # mN, average before animal began swallowing, not zero because of DC offset
ax.axhline(force_zero, color='0.5', lw=1, ls='--', zorder=-1)

# plot a gray rectangle highlighting sequence expanded in other parts of the figure
# the span is drawn on the bottom axes with a very large ymax (999 axes heights) and
# clipping off, so it extends upward behind every subplot
axes[-1].set_zorder(-1) # needed for trick below
axes[-1].axvspan(
    exemplary_bout_plot_range[0], exemplary_bout_plot_range[1],
    0, 999, clip_on=False, # trick for getting to span entire figure
    facecolor='0.8', edgecolor=None, lw=0, zorder=-2)

fig.savefig(os.path.join(export_dir, 'figure-3B.png'), dpi=600)

end = datetime.datetime.now()
print('render time:', end-start)

### 🐌 Figure 3C

In [None]:
# plt.figure(figsize=(5, 5))
# ax = plt.gca()

# df = df_durations_intervals.reset_index()
# df = df.rename(columns={
#     'Duration (s)':       'Swallow duration',
#     'Interval after (s)': 'Inter-swallow interval'})
# df = pd.melt(df,
#              id_vars=['Animal', 'Food', 'Bout_index', 'Behavior_index'],
#              value_vars=['Swallow duration', 'Inter-swallow interval'],
#              var_name='',
#              value_name='Duration/interval (s)')
# sns.boxplot(hue='Food', y='Duration/interval (s)', x='', data=df, fliersize=0, whis=999) # whis=999 ensures whiskers go to min and max
# # sns.swarmplot(hue='Food', y='Duration/interval (s)', x='', data=df)

# # plot zero line
# plt.axhline(y=0, ls=':', c='gray', zorder=-1)

# plt.tight_layout()

# plt.gcf().savefig(os.path.join(export_dir, 'figure-3C.png'), dpi=300)

In [None]:
# Figure 3C: box plots of swallow durations and inter-swallow intervals per food,
# with whiskers covering the full data range; saved as figure-3C.png.
fig, axes = plt.subplots(1, 2, figsize=(8, 4))

df = df_durations_intervals.reset_index()
df = df.rename(columns={
    'Duration (s)':       'Swallow duration (s)',
    'Interval after (s)': 'Inter-swallow interval (s)'})

# whis=(0, 100) puts the whiskers at the 0th and 100th percentiles, i.e. exactly at
# the data min and max. A large IQR multiple (the former whis=999) fails to do so
# when the interquartile range is zero, and outliers would then be invisible
# because fliersize=0 hides them.
sns.boxplot(x='Food', y='Swallow duration (s)', data=df, ax=axes[0], color='0.75', fliersize=0, whis=(0, 100))
# sns.swarmplot(x='Food', y='Swallow duration (s)', data=df, ax=axes[0], color='0.25')
axes[0].set_ylim([0, None])
axes[0].set_xlabel(None)

sns.boxplot(x='Food', y='Inter-swallow interval (s)', data=df, ax=axes[1], color='0.75', fliersize=0, whis=(0, 100))
# sns.swarmplot(x='Food', y='Inter-swallow interval (s)', data=df, ax=axes[1], color='0.25')
axes[1].set_xlabel(None)
axes[1].axhline(y=0, ls=':', c='gray', zorder=-1)

plt.tight_layout(w_pad=2)

plt.gcf().savefig(os.path.join(export_dir, 'figure-3C.png'), dpi=300)

### 🐌 Figure 3D

In [None]:
# Figure 3D: I2 and force traces for the exemplary bout, with force markers and
# shaded I2 burst windows; saved as figure-3D.png.
start = datetime.datetime.now()

df = df_exemplary_bout
(data_set_name, channel_names, time_window, epoch_types_to_keep, burst_thresholds) = feeding_bouts[exemplary_bout]

t_start, t_stop = exemplary_bout_plot_range*pq.s
plots = [
    {'channel': 'I2',       'units': 'uV', 'ylim': [ -60,  60], 'scalebar': 50}, #, 'decimation_factor': 100},
#     {'channel': 'RN',       'units': 'uV', 'ylim': [ -25,  25], 'scalebar': 25}, #, 'decimation_factor': 100},
#     {'channel': 'BN2',      'units': 'uV', 'ylim': [ -45,  45], 'scalebar': 50}, #, 'decimation_factor': 100},
#     {'channel': 'BN3-DIST', 'units': 'uV', 'ylim': [ -60,  60], 'scalebar': 50, 'ylabel': 'BN3'}, #, 'decimation_factor': 100},
    {'channel': 'Force',    'units': 'mN', 'ylim': [ -20, 375], 'scalebar': 300, 'decimation_factor': 100},
]
channel_names = [p['channel'] for p in plots]
channel_units = [p['units'] for p in plots]

kwargs = dict(
#     figsize = (9, 3),
    figsize = (6, 3),
    linewidth = 0.5,
    x_scalebar = 5*pq.s,
)

# load the data
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)
blk = neurotic.load_dataset(metadata) # not lazy so that I2 filter is applied

# plot the signal
fig, axes = prettyplot_with_scalebars(blk, t_start, t_stop, plots, **kwargs)

units = [
    'I2 spikes',
#     'B8a/b',
#     'B6/B9',
#     'B3',
#     'B38',
]

# # plot a zero baseline for force
# ax = axes[channel_names.index('Force')]
# force_zero = -11.5 # mN, average before animal began swallowing, not zero because of DC offset
# ax.axhline(force_zero, color='0.75', lw=1, ls='--', zorder=-1)

# plotted time range as plain floats in seconds, used to clip the burst windows
t_start_s = float(t_start.rescale('s').magnitude)
t_stop_s = float(t_stop.rescale('s').magnitude)

for j, i in enumerate(df.index):

    # plot force markers (downward triangle at the shoulder end, upward at the minimum)
    force_shoulder_end = df.loc[i, 'Force shoulder end start (s)']*pq.s
    force_shoulder_end_value = df.loc[i, 'Force shoulder end value (mN)']*pq.mN
    force_min_time = df.loc[i, 'Force minimum time (s)']*pq.s
    force_min = df.loc[i, 'Force minimum (mN)']*pq.mN
    ax = axes[channel_names.index('Force')]
    ax.plot([force_shoulder_end], [force_shoulder_end_value], marker=7, markersize=8, color='k')
    ax.plot([force_min_time], [force_min],                    marker=6, markersize=8, color='k')

    for k, unit in enumerate(units):

        # plot burst windows, clipped to the plotted time range: the spans need
        # clip_on=False to extend over every subplot, so any part of a window lying
        # outside [t_start, t_stop] would otherwise be drawn into the figure margins
        burst_start = df.loc[i, unit+' first burst start (s)']
        burst_end = df.loc[i, unit+' last burst end (s)']
        if np.isfinite(burst_start) and np.isfinite(burst_end):
            span_start = max(burst_start, t_start_s)
            span_end = min(burst_end, t_stop_s)
            if span_start < span_end:
                axes[-1].set_zorder(-1) # needed for trick below
                axes[-1].axvspan(
                    span_start, span_end,
                    0, 999, clip_on=False, # trick for getting to span entire figure
                    facecolor=lighten_color(unit_colors[unit], amount=0.7), edgecolor=None, lw=0)

fig.savefig(os.path.join(export_dir, 'figure-3D.png'), dpi=600)

end = datetime.datetime.now()
print('render time:', end-start)

### 🐌 Figure 3E

In [None]:
# Figure 3E (setup): RN and force traces for the exemplary bout; the markers and
# burst windows are added by the loop that follows.
start = datetime.datetime.now()

df = df_exemplary_bout
(data_set_name, channel_names, time_window, epoch_types_to_keep, burst_thresholds) = feeding_bouts[exemplary_bout]

t_start, t_stop = exemplary_bout_plot_range*pq.s
# only the RN and Force traces are active in this panel
plots = [
#     {'channel': 'I2',       'units': 'uV', 'ylim': [ -60,  60], 'scalebar': 50}, #, 'decimation_factor': 100},
    {'channel': 'RN',       'units': 'uV', 'ylim': [ -25,  25], 'scalebar': 25}, #, 'decimation_factor': 100},
#     {'channel': 'BN2',      'units': 'uV', 'ylim': [ -45,  45], 'scalebar': 50}, #, 'decimation_factor': 100},
#     {'channel': 'BN3-DIST', 'units': 'uV', 'ylim': [ -60,  60], 'scalebar': 50, 'ylabel': 'BN3'}, #, 'decimation_factor': 100},
    {'channel': 'Force',    'units': 'mN', 'ylim': [ -20, 375], 'scalebar': 300, 'decimation_factor': 100},
]
# rebinds the module-level channel_names / channel_units to the plotted channels
channel_names = [p['channel'] for p in plots]
channel_units = [p['units'] for p in plots]

kwargs = dict(
#     figsize = (9, 3),
    figsize = (6, 3),
    linewidth = 0.5,
    x_scalebar = 5*pq.s,
)

# load the data
# (loaded with lazy=True; the I2 entry in `plots` above is commented out)
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)
blk = neurotic.load_dataset(metadata, lazy=True)

# plot the signal
fig, axes = prettyplot_with_scalebars(blk, t_start, t_stop, plots, **kwargs)

# units whose burst windows are shaded in this panel
units = [
#     'I2 spikes',
    'B8a/b',
#     'B6/B9',
#     'B3',
#     'B38',
]

# # plot a zero baseline for force
# ax = axes[channel_names.index('Force')]
# force_zero = -11.5 # mN, average before animal began swallowing, not zero because of DC offset
# ax.axhline(force_zero, color='0.75', lw=1, ls='--', zorder=-1)

for j, i in enumerate(df.index):
    
    # plot force markers
    force_rise_start = df.loc[i, 'Force rise start start (s)']*pq.s
    force_baseline = df.loc[i, 'Force baseline (mN)']*pq.mN
    force_plateau_end = df.loc[i, 'Force plateau end start (s)']*pq.s
    force_plateau_end_value = df.loc[i, 'Force plateau end value (mN)']*pq.mN
    ax = axes[channel_names.index('Force')]
    ax.plot([force_rise_start],  [force_baseline],          marker=6, markersize=8, color='k')
    ax.plot([force_plateau_end], [force_plateau_end_value], marker=4, markersize=8, color='k')
    
    for k, unit in enumerate(units):
                
        # plot burst windows
        burst_start = df.loc[i, unit+' first burst start (s)']
        if np.isfinite(burst_start) and burst_start < t_stop:
            burst_end = df.loc[i, unit+' last burst end (s)']
            burst_end = min(burst_end, t_stop) # hack to prevent unclipped rect entering margin
            axes[-1].set_zorder(-1) # needed for trick below
            axes[-1].axvspan(
                burst_start, burst_end,
                0, 999, clip_on=False, # trick for getting to span entire figure
                facecolor=lighten_color(unit_colors[unit], amount=0.7), edgecolor=None, lw=0)

fig.savefig(os.path.join(export_dir, 'figure-3E.png'), dpi=600)

end = datetime.datetime.now()
print('render time:', end-start)

### 🐌 Figure 3F

In [None]:
start = datetime.datetime.now()

# configuration of the exemplary feeding bout
df = df_exemplary_bout
data_set_name, channel_names, time_window, epoch_types_to_keep, burst_thresholds = feeding_bouts[exemplary_bout]
t_start, t_stop = exemplary_bout_plot_range*pq.s

# traces to draw, top to bottom
plots = [
    {'channel': 'BN2',   'units': 'uV', 'ylim': [-45,  45], 'scalebar': 50},
    {'channel': 'Force', 'units': 'mN', 'ylim': [-20, 375], 'scalebar': 300, 'decimation_factor': 100},
]
channel_names = [spec['channel'] for spec in plots]
channel_units = [spec['units'] for spec in plots]

kwargs = {
    'figsize': (6, 3),
    'linewidth': 0.5,
    'x_scalebar': 5*pq.s,
}

# load the recording lazily and draw the traces
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)
blk = neurotic.load_dataset(metadata, lazy=True)
fig, axes = prettyplot_with_scalebars(blk, t_start, t_stop, plots, **kwargs)

# unit whose burst window is shaded across every panel
units = ['B6/B9']

# force landmarks as (time column, value column, matplotlib marker code)
force_landmarks = [
    ('Force rise start start (s)',  'Force baseline (mN)',          6),  # caret up
    ('Force plateau end start (s)', 'Force plateau end value (mN)', 4),  # caret left
]

ax = axes[channel_names.index('Force')]
for i in df.index:
    for time_col, value_col, marker_code in force_landmarks:
        ax.plot([df.loc[i, time_col]*pq.s], [df.loc[i, value_col]*pq.mN],
                marker=marker_code, markersize=8, color='k')

    for unit in units:
        burst_start = df.loc[i, unit+' first burst start (s)']
        if not (np.isfinite(burst_start) and burst_start < t_stop):
            continue  # no burst, or it begins after the plotted range
        # clamp the end so the unclipped span cannot enter the right margin
        burst_end = min(df.loc[i, unit+' last burst end (s)'], t_stop)
        # lower the bottom axes and give the span ymax=999 (axes fraction)
        # with clipping off, so the shading reaches up through every panel
        axes[-1].set_zorder(-1)
        axes[-1].axvspan(
            burst_start, burst_end, 0, 999, clip_on=False,
            facecolor=lighten_color(unit_colors[unit], amount=0.7), edgecolor=None, lw=0)

fig.savefig(os.path.join(export_dir, 'figure-3F.png'), dpi=600)

end = datetime.datetime.now()
print('render time:', end-start)

### 🐌 Figure 3G

In [None]:
start = datetime.datetime.now()

# configuration of the exemplary feeding bout
df = df_exemplary_bout
data_set_name, channel_names, time_window, epoch_types_to_keep, burst_thresholds = feeding_bouts[exemplary_bout]
t_start, t_stop = exemplary_bout_plot_range*pq.s

# traces to draw, top to bottom
plots = [
    {'channel': 'BN2',   'units': 'uV', 'ylim': [-45,  45], 'scalebar': 50},
    {'channel': 'Force', 'units': 'mN', 'ylim': [-20, 375], 'scalebar': 300, 'decimation_factor': 100},
]
channel_names = [spec['channel'] for spec in plots]
channel_units = [spec['units'] for spec in plots]

kwargs = {
    'figsize': (6, 3),
    'linewidth': 0.5,
    'x_scalebar': 5*pq.s,
}

# load the recording lazily and draw the traces
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)
blk = neurotic.load_dataset(metadata, lazy=True)
fig, axes = prettyplot_with_scalebars(blk, t_start, t_stop, plots, **kwargs)

# unit whose burst window is shaded across every panel
units = ['B3']

# force landmarks as (time column, value column, matplotlib marker code)
force_landmarks = [
    ('Force plateau start start (s)', 'Force plateau start value (mN)', 5),  # caret right
    ('Force plateau end start (s)',   'Force plateau end value (mN)',   4),  # caret left
]

ax = axes[channel_names.index('Force')]
for i in df.index:
    for time_col, value_col, marker_code in force_landmarks:
        ax.plot([df.loc[i, time_col]*pq.s], [df.loc[i, value_col]*pq.mN],
                marker=marker_code, markersize=8, color='k')

    for unit in units:
        burst_start = df.loc[i, unit+' first burst start (s)']
        if not (np.isfinite(burst_start) and burst_start < t_stop):
            continue  # no burst, or it begins after the plotted range
        # clamp the end so the unclipped span cannot enter the right margin
        burst_end = min(df.loc[i, unit+' last burst end (s)'], t_stop)
        # lower the bottom axes and give the span ymax=999 (axes fraction)
        # with clipping off, so the shading reaches up through every panel
        axes[-1].set_zorder(-1)
        axes[-1].axvspan(
            burst_start, burst_end, 0, 999, clip_on=False,
            facecolor=lighten_color(unit_colors[unit], amount=0.7), edgecolor=None, lw=0)

fig.savefig(os.path.join(export_dir, 'figure-3G.png'), dpi=600)

end = datetime.datetime.now()
print('render time:', end-start)

### 🐌 Figure 3H

In [None]:
start = datetime.datetime.now()

# configuration of the exemplary feeding bout
df = df_exemplary_bout
data_set_name, channel_names, time_window, epoch_types_to_keep, burst_thresholds = feeding_bouts[exemplary_bout]
t_start, t_stop = exemplary_bout_plot_range*pq.s

# traces to draw, top to bottom
plots = [
    {'channel': 'BN2',   'units': 'uV', 'ylim': [-45,  45], 'scalebar': 50},
    {'channel': 'Force', 'units': 'mN', 'ylim': [-20, 375], 'scalebar': 300, 'decimation_factor': 100},
]
channel_names = [spec['channel'] for spec in plots]
channel_units = [spec['units'] for spec in plots]

kwargs = {
    'figsize': (6, 3),
    'linewidth': 0.5,
    'x_scalebar': 5*pq.s,
}

# load the recording lazily and draw the traces
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)
blk = neurotic.load_dataset(metadata, lazy=True)
fig, axes = prettyplot_with_scalebars(blk, t_start, t_stop, plots, **kwargs)

# unit whose burst window is shaded across every panel
units = ['B38']

# force landmarks as (time column, value column, matplotlib marker code)
force_landmarks = [
    ('Force shoulder end start (s)', 'Force shoulder end value (mN)', 7),  # caret down
]

ax = axes[channel_names.index('Force')]
for i in df.index:
    for time_col, value_col, marker_code in force_landmarks:
        ax.plot([df.loc[i, time_col]*pq.s], [df.loc[i, value_col]*pq.mN],
                marker=marker_code, markersize=8, color='k')

    for unit in units:
        burst_start = df.loc[i, unit+' first burst start (s)']
        if not (np.isfinite(burst_start) and burst_start < t_stop):
            continue  # no burst, or it begins after the plotted range
        # clamp the end so the unclipped span cannot enter the right margin
        burst_end = min(df.loc[i, unit+' last burst end (s)'], t_stop)
        # lower the bottom axes and give the span ymax=999 (axes fraction)
        # with clipping off, so the shading reaches up through every panel
        axes[-1].set_zorder(-1)
        axes[-1].axvspan(
            burst_start, burst_end, 0, 999, clip_on=False,
            facecolor=lighten_color(unit_colors[unit], amount=0.7), edgecolor=None, lw=0)

fig.savefig(os.path.join(export_dir, 'figure-3H.png'), dpi=600)

end = datetime.datetime.now()
print('render time:', end-start)

### 🐌 Figure 3I ?

In [None]:
start = datetime.datetime.now()

# configuration of the exemplary feeding bout
df = df_exemplary_bout
data_set_name, channel_names, time_window, epoch_types_to_keep, burst_thresholds = feeding_bouts[exemplary_bout]
t_start, t_stop = exemplary_bout_plot_range*pq.s

# traces to draw, top to bottom (BN3-DIST is labeled simply "BN3")
plots = [
    {'channel': 'BN3-DIST', 'units': 'uV', 'ylim': [-60,  60], 'scalebar': 50, 'ylabel': 'BN3'},
    {'channel': 'Force',    'units': 'mN', 'ylim': [-20, 375], 'scalebar': 300, 'decimation_factor': 100},
]
channel_names = [spec['channel'] for spec in plots]
channel_units = [spec['units'] for spec in plots]

kwargs = {
    'figsize': (6, 3),
    'linewidth': 0.5,
    'x_scalebar': 5*pq.s,
}

# load the recording lazily and draw the traces
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)
blk = neurotic.load_dataset(metadata, lazy=True)
fig, axes = prettyplot_with_scalebars(blk, t_start, t_stop, plots, **kwargs)

# unit whose burst window is shaded across every panel
units = ['B4/B5']

# force landmarks as (time column, value column, matplotlib marker code)
force_landmarks = [
    ('Force rise start start (s)', 'Force baseline (mN)', 6),  # caret up
]

ax = axes[channel_names.index('Force')]
for i in df.index:
    for time_col, value_col, marker_code in force_landmarks:
        ax.plot([df.loc[i, time_col]*pq.s], [df.loc[i, value_col]*pq.mN],
                marker=marker_code, markersize=8, color='k')

    for unit in units:
        burst_start = df.loc[i, unit+' first burst start (s)']
        if not (np.isfinite(burst_start) and burst_start < t_stop):
            continue  # no burst, or it begins after the plotted range
        # clamp the end so the unclipped span cannot enter the right margin
        burst_end = min(df.loc[i, unit+' last burst end (s)'], t_stop)
        # lower the bottom axes and give the span ymax=999 (axes fraction)
        # with clipping off, so the shading reaches up through every panel
        axes[-1].set_zorder(-1)
        axes[-1].axvspan(
            burst_start, burst_end, 0, 999, clip_on=False,
            facecolor=lighten_color(unit_colors[unit], amount=0.7), edgecolor=None, lw=0)

fig.savefig(os.path.join(export_dir, 'figure-3I.png'), dpi=600)

end = datetime.datetime.now()
print('render time:', end-start)

## [FIGURE 4]

In [None]:
start = datetime.datetime.now()

df = df_exemplary_swallow
(data_set_name, channel_names, time_window, epoch_types_to_keep, burst_thresholds) = feeding_bouts[exemplary_swallow]

# plot from 0.4 s before the I2 burst starts until 0.4 s after the next force rise starts
t_start = (df.loc[0, 'I2 spikes first burst start (s)'] - 0.4)*pq.s
t_stop = (df.loc[0, 'Next force rise start (s)'] + 0.4)*pq.s
plots = [
    {'channel': 'I2',       'units': 'uV', 'ylim': [ -60,  60], 'scalebar': 50}, #, 'decimation_factor': 100},
    {'channel': 'RN',       'units': 'uV', 'ylim': [ -25,  25], 'scalebar': 25}, #, 'decimation_factor': 100},
    {'channel': 'BN2',      'units': 'uV', 'ylim': [ -45,  45], 'scalebar': 50}, #, 'decimation_factor': 100},
    {'channel': 'BN3-DIST', 'units': 'uV', 'ylim': [ -60,  60], 'scalebar': 50, 'ylabel': 'BN3'}, #, 'decimation_factor': 100},
    {'channel': 'Force',    'units': 'mN', 'ylim': [ -50, 260], 'scalebar': 100, 'decimation_factor': 100},
]
channel_names = [p['channel'] for p in plots]
channel_units = [p['units'] for p in plots]

kwargs = dict(
    figsize = (7, 8),
    linewidth = 0.5,
    x_scalebar = 1*pq.s,
)

# load the metadata
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)

# add/update the force filter
new_force_filter = {'channel': 'Force', 'lowpass': 10}
update_filter(metadata, new_force_filter)

# load the data
blk = neurotic.load_dataset(metadata) # not lazy so that I2 and force filters are applied

# plot the signal
fig, axes = prettyplot_with_scalebars(blk, t_start, t_stop, plots, **kwargs)

units = [
    'I2 spikes',
    'B8a/b',
    'B6/B9',
    'B3',
    'B38',
    'B4/B5',
]

# vertical extent [bottom, top] (in each channel's plotted units) of the
# rectangles drawn around each unit's bursts
unit_burst_boxes = {
    'I2 spikes': [-35, 45],
    'B8a/b':     [-20, 12],
    'B6/B9':     [-15, 12],
    'B3':        [-45, 35],
    'B38':       [-12, 12],
    'B4/B5':     [-55, 45],
}

ax = axes[channel_names.index('Force')]

# plot force phase boundaries (every fixed normalization time from index 3
# on, except the last one)
times = df.loc[0, 'Normalization fixed times (s)'][3:]
for t in times[:-1]:
    ax.axvline(x=t, lw=1, ls=':', c='gray', zorder=-1, clip_on=False)

# add roman numerals for force phases, 5 points above the bottom of the force axes
ax.annotate('I',   xy=(times[0],          5), xycoords=('data', 'axes points'), ha='right',  # final drop
                                              xytext=(-10, 0), textcoords='offset points')
ax.annotate('II',  xy=(times[0:2].mean(), 5), xycoords=('data', 'axes points'), ha='center') # rise
ax.annotate('III', xy=(times[1:3].mean(), 5), xycoords=('data', 'axes points'), ha='center') # maintenance
ax.annotate('IV',  xy=(times[2:4].mean(), 5), xycoords=('data', 'axes points'), ha='center') # major drop
ax.annotate('V',   xy=(times[3:5].mean(), 5), xycoords=('data', 'axes points'), ha='center') # partial maintenance
ax.annotate('I',   xy=(times[4:6].mean(), 5), xycoords=('data', 'axes points'), ha='center') # final drop

for i in df.index:
    for unit in units:
        st = df.loc[i, unit+' spike train']
        if st is None or st.size == 0:
            continue

        # get the neural channel
        channel = st.annotations['channels'][0]
        sig = get_sig(blk, channel)

        # restrict the signal to the plotted window
        sig = sig.time_slice(t_start, t_stop)
        sig = sig.rescale(channel_units[channel_names.index(channel)])

        # plot spikes at the signal's amplitude. Sample indices are relative
        # to the sliced signal, so spikes outside the window are dropped:
        # before t_start they would get negative indices that silently wrap
        # around to samples at the end of the window, and after t_stop they
        # would raise IndexError.
        ax = axes[channel_names.index(channel)]
        spike_indices = np.array([sig.time_index(spike_time) for spike_time in st], dtype=int)
        in_window = (spike_indices >= 0) & (spike_indices < len(sig))
        spike_amplitudes = np.array([sig[idx] for idx in spike_indices[in_window]]) * pq.Quantity(sig.units)
        ax.scatter(st.times[in_window].rescale('s'), spike_amplitudes, marker='.', s=20, c=unit_colors[unit], zorder=3)

        # outline every burst that passes is_good_burst
        bursts = df.at[i, unit+' all bursts (s)']
        bottom, top = unit_burst_boxes[unit]
        height = top-bottom
        for burst in bursts:
            if is_good_burst(burst):
                left = burst['Start (s)']
                right = burst['End (s)']
                width = right-left
                rect = patches.Rectangle((left, bottom), width, height, linewidth=2, ls='-', edgecolor=unit_colors[unit], fill=False, zorder=3, clip_on=False)
                ax.add_patch(rect)

fig.savefig(os.path.join(export_dir, 'figure-4.png'), dpi=600)

end = datetime.datetime.now()
print('render time:', end-start)

## [FIGURE 5]

In [None]:
# one row per swallow: the fixed normalization times from force rise start onward
t = np.array([fixed_times.magnitude[4:] for fixed_times in df_all['Normalization fixed times (s)']])

# successive differences give each phase's duration; transpose so each row is one phase
all_phase_durations = np.diff(t).T

# per-phase median across swallows, ignoring missing (NaN) values
median_phase_durations = np.nanmedian(all_phase_durations, axis=1)

# prepend the last 4 phases so the plots also show the end of a "previous" swallow
median_phase_durations = np.concatenate([median_phase_durations[-4:], median_phase_durations])

# running total of the durations, starting at 0, gives the phase boundary times
median_phase_boundaries = np.cumsum(np.insert(median_phase_durations, 0, 0))

# one label per phase, matching median_phase_durations (first 4 = previous swallow)
phase_labels = [
    'Previous force\nmaintenance',
    'Previous major\nforce drop',
    'Previous\npartial force\nmaintenance',
    'Previous final\nforce drop',
    'Force rise',
    'Force\nmaintenance',
    'Major\nforce drop',
    'Partial force\nmaintenance',
    'Final\nforce drop',
]

In [None]:
plt.figure()

# one box per phase of the current swallow, built from its finite durations only
finite_durations = [durations[np.isfinite(durations)] for durations in all_phase_durations]
plt.boxplot(finite_durations, labels=phase_labels[4:], showmeans=True)

plt.ylabel('Duration (s)');

In [None]:
# print the median, mean and sample size of each current-swallow phase duration
for durations, label in zip(all_phase_durations, phase_labels[4:]):
    single_line_label = label.replace('\n', ' ')
    n_finite = durations[np.isfinite(durations)].size
    print(f'{single_line_label}:\tmedian {np.nanmedian(durations):g}, mean {np.nanmean(durations):g} (n={n_finite})')

### 🐌 Figure 5A

In [None]:
def _plot_burst_interval_box(ax, unit, row, cycle_shift=0, height=0.8, lw=1):
    """Draw one unit's median burst interval, with quartile markers, on row `row`.

    Every swallow's first-burst-start and last-burst-end times are read from
    df_all in normalized time, shifted by `cycle_shift` normalized units, and
    converted to seconds with unnormalize_time(median_phase_boundaries, ...).
    A filled box spans from the median start to the median end. Short
    vertical ticks mark the 25% and 75% quantiles of each edge, and a
    horizontal line through the row center joins each pair.
    """
    def unnormalized(edge):
        data = df_all[f'{unit} {edge} (normalized)'].dropna()
        data[:] = unnormalize_time(median_phase_boundaries, data.values + cycle_shift)
        return data

    t0_data = unnormalized('first burst start')
    t0_median = t0_data.median()
    t0_q1     = t0_data.quantile(0.25)
    t0_q3     = t0_data.quantile(0.75)

    t1_data = unnormalized('last burst end')
    t1_median = t1_data.median()
    t1_q1     = t1_data.quantile(0.25)
    t1_q3     = t1_data.quantile(0.75)

    # box from the median burst start to the median burst end
    ax.add_patch(patches.Rectangle(
        (t0_median, row-height/2), t1_median-t0_median, height,
        facecolor=unit_colors[unit], edgecolor='k', lw=lw, clip_on=False))

    # quartile ticks (25% and 75% quantiles) for both edges, then the lines joining them
    for x in (t0_q1, t0_q3, t1_q1, t1_q3):
        ax.add_line(mlines.Line2D([x, x], [row-0.2, row+0.2], color='k', lw=lw, clip_on=False))
    for q1, q3 in ((t0_q1, t0_q3), (t1_q1, t1_q3)):
        ax.add_line(mlines.Line2D([q1, q3], [row, row], color='k', lw=lw, clip_on=False))


plt.figure(figsize=(9, 5))
ax = plt.gca()

units = ['I2', 'B8a/b', 'B6/B9', 'B3', 'B38', 'B4/B5']

for i, unit in enumerate(units):
    _plot_burst_interval_box(ax, unit, i)
    if unit == 'I2':
        # plot I2 again shifted one cycle (5 phases) later
        _plot_burst_interval_box(ax, unit, i, cycle_shift=5)

###
### LABELS AND ANNOTATIONS
###

# drop "Previous force maintenance" because I2 distribution starts after it
trimmed_median_phase_boundaries = median_phase_boundaries[1:]
trimmed_phase_labels = phase_labels[1:]

for x in trimmed_median_phase_boundaries:
    plt.axvline(x=x, ls=':', c='gray', zorder=-1, clip_on=False)

plt.xlabel(None)
plt.ylabel(None)

sns.despine(bottom=True, left=True)
plt.tick_params(bottom=False, left=False) # disable tick marks
# one x tick label per phase, centered between its boundaries
plt.xticks(trimmed_median_phase_boundaries[:-1]+np.diff(trimmed_median_phase_boundaries)/2, trimmed_phase_labels)
plt.yticks(range(len(units)), units)

plt.xlim(trimmed_median_phase_boundaries[[0, -1]])
plt.ylim(5.9, -0.5) # inverted so the first unit is on top; extra room at the bottom for the scale bar

ax.tick_params(
    axis='x',
    labelsize='small',
)

ax.add_artist(AnchoredScaleBar(
    ax.transData,
    sizex=1,
    labelx='1 s',

    loc='lower left',
    bbox_to_anchor=(0, 0),
    bbox_transform=ax.transAxes,

    pad=0,
    borderpad=1,
    sep=5,
    barwidth=2,
))

plt.tight_layout()

plt.gcf().savefig(os.path.join(export_dir, 'figure-5A.png'), dpi=300)

In [None]:
# print each unit's median burst start and end (in unnormalized seconds),
# the number of swallows contributing to each, and the median-based duration
for unit in ['I2', 'B8a/b', 'B6/B9', 'B3', 'B38', 'B4/B5']:
    stats = []
    for edge in ('first burst start', 'last burst end'):
        series = df_all[f'{unit} {edge} (normalized)'].dropna()
        series[:] = unnormalize_time(median_phase_boundaries, series.values)
        stats.append((series.median(), series.size))
    (t0_median, t0_N), (t1_median, t1_N) = stats

    print(f'{unit}:\t[{t0_median:.2f} (n={t0_N}), {t1_median:.2f} (n={t1_N})], duration: {t1_median-t0_median:.2f}')

### 🐌 Figure 5B

In [None]:
times_interp_unnormalized = unnormalize_time(median_phase_boundaries, times_interp)

# number of interpolation samples in one swallow cycle (normalized time 4 =
# start of rise, 9 = end of final drop). The interval is half-open: when the
# samples are uniformly spaced with points at both 4 and 9, counting both
# endpoints gives one more than the cycle length, and the duplicated bursts
# below would be shifted by one sample too many.
# NOTE(review): assumes times_interp is uniformly spaced -- confirm where it is defined.
n = np.count_nonzero((4 <= times_interp) & (times_interp < 9))


def _overlay_adjacent_cycles(trace, shift):
    """Return the element-wise max of `trace` and copies shifted by +/- `shift` samples.

    This represents repetitive swallowing by duplicating each burst one cycle
    forward and one cycle backward. Vacated samples are filled with zeros, and
    NaNs are ignored by the max. A shift of 0 returns the trace unchanged; a
    shift at least as long as the trace only contributes zeros.
    """
    size = len(trace)
    shift = min(shift, size)
    forward = np.zeros(size)   # trace shifted later by one cycle
    backward = np.zeros(size)  # trace shifted earlier by one cycle
    if shift < size:
        forward[shift:] = trace[:size - shift]
        backward[:size - shift] = trace[shift:]
    return np.nanmax([trace, forward, backward], axis=0)


def _add_units_to_top_ytick(ax, suffix):
    """Add a y grid and append `suffix` (e.g. ' Hz') to the top y tick label."""
    ax.grid(axis='y', clip_on=False)
    ax.set_yticks(ax.get_yticks()) # trick to prevent tight_layout from changing ticks
    yticklabels = [f'{y:g}' for y in ax.get_yticks()]
    yticklabels[-1] += suffix
    ax.set_yticklabels(yticklabels)


units = [
    'I2 spikes',
    'B8a/b',
    'B6/B9',
    'B3',
    'B38',
    'B4/B5',
]

figsize = (9, 8)
fig, axes = plt.subplots(len(units)+1, 1, sharex='col', figsize=figsize)

###
### UNITS
###

for i, unit in enumerate(units):
    ax = axes[i]

    # elevate the Axes for units and remove background colors so that
    # each vertical ConnectionPatch drawn later is visible behind it
    ax.set_zorder(1)
    ax.set_facecolor('none')

    # find the firing rate median and quartiles across swallows
    firing_rate_data = np.array(list(df_all[unit+' firing rate, normalized time interpolation (Hz)']))
    firing_rate_median = np.nanmedian(firing_rate_data, axis=0)
    firing_rate_q1 = np.nanquantile(firing_rate_data, q=0.25, axis=0)
    firing_rate_q3 = np.nanquantile(firing_rate_data, q=0.75, axis=0)

    # represent repetitive swallowing by duplicating the bursts one cycle forward and backward
    firing_rate_median = _overlay_adjacent_cycles(firing_rate_median, n)
    firing_rate_q1 = _overlay_adjacent_cycles(firing_rate_q1, n)
    firing_rate_q3 = _overlay_adjacent_cycles(firing_rate_q3, n)

    # plot the firing rate median and quartiles (median last so it's on top)
    ax.plot(times_interp_unnormalized, firing_rate_q1,     c=lighten_color(unit_colors[unit], amount=0.7), lw=1, zorder=2)
    ax.plot(times_interp_unnormalized, firing_rate_q3,     c=lighten_color(unit_colors[unit], amount=0.7), lw=1, zorder=2)
    ax.plot(times_interp_unnormalized, firing_rate_median, c=unit_colors[unit], lw=2, zorder=2)

    ax.set_ylim([0, None])
    ylabel = 'I2' if unit == 'I2 spikes' else unit
    ax.set_ylabel(ylabel, rotation='horizontal', ha='right', va='center', labelpad=10)

    # remove left and top plot borders, keeping the right spine for the y ticks
    sns.despine(ax=ax, left=True, right=False)

    # tick marks only on the right; no x labels except on the bottom (force) panel
    ax.tick_params(
        bottom=False,
        left=False,
        right=True,
        labelbottom=False,
        labelleft=False)

###
### FORCE
###

ax = axes[-1]

force_data = np.array(list(df_all['Force, normalized time interpolation (mN)']))

# find the force median and quartiles
force_median = np.nanmedian(force_data, axis=0)
force_q1 = np.nanquantile(force_data, q=0.25, axis=0)
force_q3 = np.nanquantile(force_data, q=0.75, axis=0)

# plot the force median and quartiles (median last so it's on top)
ax.plot(times_interp_unnormalized, force_q1,     c=lighten_color('k', amount=0.7), lw=1, zorder=2)
ax.plot(times_interp_unnormalized, force_q3,     c=lighten_color('k', amount=0.7), lw=1, zorder=2)
ax.plot(times_interp_unnormalized, force_median, c='k', lw=2, zorder=2)

ax.set_ylim([0, None])
ax.set_ylabel('Force', rotation='horizontal', ha='right', va='center', labelpad=10)

# drop "Previous force maintenance" because I2 distribution starts after it
trimmed_median_phase_boundaries = median_phase_boundaries[1:]
trimmed_phase_labels = phase_labels[1:]

# remove left and top plot borders from bottom panel and set the x range
sns.despine(ax=ax, left=True, right=False)
ax.set_xlim(trimmed_median_phase_boundaries[[0, -1]])
# tick marks only on the right
ax.tick_params(
    bottom=False,
    left=False,
    right=True,
    labelleft=False)
ax.tick_params(
    axis='x',
    labelsize='small',
)

# one x tick label per phase, centered between its boundaries
ax.set_xticks(trimmed_median_phase_boundaries[:-1]+np.diff(trimmed_median_phase_boundaries)/2)
ax.set_xticklabels(trimmed_phase_labels)

# plot force phase boundaries as dotted lines running from the bottom of the
# force panel to the top of the first unit panel
for t in trimmed_median_phase_boundaries:
    ax.add_artist(patches.ConnectionPatch(
        xyA=(t, 0), xyB=(t, 1),
        coordsA=ax.get_xaxis_transform(), coordsB=axes[0].get_xaxis_transform(),
        axesA=ax, axesB=axes[0],
        color='gray', lw=1, ls=':', zorder=0))

# add time scale bar
ax.add_artist(AnchoredScaleBar(
    ax.transData,
    sizex=1,
    labelx='1 s',

    loc='upper right',
    bbox_to_anchor=(0.98, 1),
    bbox_transform=ax.transAxes,

    pad=0,
    borderpad=1,
    sep=5,
    barwidth=2,
))

fig.tight_layout(h_pad=0, pad=0) # first tight_layout removes excess margins and sets reasonable ylims

for ax in axes[:-1]:
    _add_units_to_top_ytick(ax, ' Hz')
ax = axes[-1]
_add_units_to_top_ytick(ax, ' mN')

fig.tight_layout(h_pad=0, pad=0) # second tight_layout makes room for units added to y tick labels

fig.savefig(os.path.join(export_dir, 'figure-5B.png'), dpi=600)

## [FIGURE 6]

In [None]:
# Figure 6: per-animal scatter of B6/B9 burst duration against force plateau
# (displayed as "maintenance") duration, with one trend line per animal and a
# dotted y = x reference line. Saved as figure-6.png.
ax = plt.figure(figsize=(5, 5)).gca()

# one data subset per animal, keyed by its "<animal> Tape nori" label
data_subsets = {
    label: label2query(label)
    for label in (f'{animal} Tape nori' for animal in ('JG07', 'JG08', 'JG11', 'JG12', 'JG14'))
}

# x measure; alternatives tried:
#   'B3 all bursts duration (s)', [0, 5]
#   'B3/B6/B9 burst duration (s)', [0, 5]
xlabel = 'B6/B9 all bursts duration (s)'
xlim = [0, 6]
ylabel = 'Force plateau duration (s)'
ylim = [0, 6]
# display name for the y axis, applied after plotting
ylabel_alt = 'Force maintenance duration (s)'

scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim,
          trend=True, trend_separately=True, tooltips=False)

# dotted unity line, long enough to cover the axis limits
plt.plot([0, 999], [0, 999], ls=':', c='gray', zorder=0)
plt.ylabel(ylabel_alt)
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

plt.gcf().savefig(os.path.join(export_dir, 'figure-6.png'), dpi=300)

---

## Older exploratory figures

In [None]:
# Per-animal box plot of the delay from I2 end to force rise start;
# the dotted vertical line marks zero delay.
ax = plt.figure(figsize=(5, 5)).gca()

# alternative measure: 'Delay from I2 end to force min (s)'
column = 'Delay from I2 end to force rise start (s)'
df = df_all.reset_index()

# whis=999 ensures whiskers go to min and max, so no fliers are drawn
sns.boxplot(data=df, x=column, y='Animal', whis=999, fliersize=0)
# sns.swarmplot(data=df, x=column, y='Animal', color='0.25')
# (pooled over animals: drop y='Animal' from the calls above)

plt.axvline(x=0, ls=':', c='gray', zorder=-1)
plt.tight_layout()

# plt.gcf().savefig(os.path.join(export_dir, 'figure-X.png'), dpi=300)

In [None]:
# Per-animal box plot (with individual swallows overlaid) of the delay from
# B3/B6/B9 start to force rise start; the dotted vertical line marks zero delay.
ax = plt.figure(figsize=(5, 5)).gca()

# alternative measure: 'Delay from B3/B6/B9 start to force start (s)'
column = 'Delay from B3/B6/B9 start to force rise start (s)'
df = df_all.reset_index()

# whis=999 ensures whiskers go to min and max, so no fliers are drawn
sns.boxplot(data=df, x=column, y='Animal', whis=999, fliersize=0)
sns.swarmplot(data=df, x=column, y='Animal', color='0.25')
# (pooled over animals: drop y='Animal' from the calls above)

plt.axvline(x=0, ls=':', c='gray', zorder=-1)
plt.tight_layout()

# plt.gcf().savefig(os.path.join(export_dir, 'figure-X.png'), dpi=300)

In [None]:
# Per-animal box plot (with individual swallows overlaid) of the delay from
# B8a/b start to force rise start; the dotted vertical line marks zero delay.
ax = plt.figure(figsize=(5, 5)).gca()

# alternative measure: 'Delay from B8a/b start to force start (s)'
column = 'Delay from B8a/b start to force rise start (s)'
df = df_all.reset_index()

# whis=999 ensures whiskers go to min and max, so no fliers are drawn
sns.boxplot(data=df, x=column, y='Animal', whis=999, fliersize=0)
sns.swarmplot(data=df, x=column, y='Animal', color='0.25')
# (pooled over animals: drop y='Animal' from the calls above)

plt.axvline(x=0, ls=':', c='gray', zorder=-1)
plt.tight_layout()

# plt.gcf().savefig(os.path.join(export_dir, 'figure-X.png'), dpi=300)

In [None]:
# Per-animal scatter of B8a/b pre-B3/B6/B9 burst rectified voltage against
# the initial force slope, with one trend line per animal.
ax = plt.figure(figsize=(5, 5)).gca()

# one data subset per animal, keyed by its "<animal> Tape nori" label
data_subsets = {
    label: label2query(label)
    for label in (f'{animal} Tape nori' for animal in ('JG07', 'JG08', 'JG11', 'JG12', 'JG14'))
}

# x measure; alternatives tried:
#   'B8 activity mean rectified voltage (μV)'
#   'B8a/b first burst mean rectified voltage (μV)'
#   'B8a/b pre-B3/B6/B9 burst mean frequency (Hz)'
xlabel = 'B8a/b pre-B3/B6/B9 burst mean rectified voltage (μV)'
xlim = [0, None]
# y measure; alternatives tried: 'Force slope (mN/s)', 'Force initial increase (mN)'
ylabel = 'Force initial slope (mN/s)'
ylim = [0, None]

scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim,
          trend=True, trend_separately=True, tooltips=True)
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

# plt.gcf().savefig(os.path.join(export_dir, 'figure-X.png'), dpi=300)

In [None]:
# Per-animal scatter of B8a/b pre-B3/B6/B9 burst mean frequency against
# the initial force slope, with one trend line per animal.
ax = plt.figure(figsize=(5, 5)).gca()

# one data subset per animal, keyed by its "<animal> Tape nori" label
data_subsets = {
    label: label2query(label)
    for label in (f'{animal} Tape nori' for animal in ('JG07', 'JG08', 'JG11', 'JG12', 'JG14'))
}

# x measure; alternatives tried:
#   'B8 activity mean rectified voltage (μV)'
#   'B8a/b first burst mean rectified voltage (μV)'
#   'B8a/b pre-B3/B6/B9 burst mean rectified voltage (μV)'
xlabel = 'B8a/b pre-B3/B6/B9 burst mean frequency (Hz)'
xlim = [0, None]
# y measure; alternatives tried: 'Force slope (mN/s)', 'Force initial increase (mN)'
ylabel = 'Force initial slope (mN/s)'
ylim = [0, None]

scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim,
          trend=True, trend_separately=True, tooltips=True)
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

In [None]:
# Per-animal scatter of B8a/b pre-B3/B6/B9 burst rectified voltage against
# the initial force increase, with one trend line per animal.
ax = plt.figure(figsize=(5, 5)).gca()

# one data subset per animal, keyed by its "<animal> Tape nori" label
data_subsets = {
    label: label2query(label)
    for label in (f'{animal} Tape nori' for animal in ('JG07', 'JG08', 'JG11', 'JG12', 'JG14'))
}

# x measure; alternatives tried:
#   'B8 activity mean rectified voltage (μV)'
#   'B8a/b first burst mean rectified voltage (μV)'
#   'B8a/b pre-B3/B6/B9 burst mean frequency (Hz)'
xlabel = 'B8a/b pre-B3/B6/B9 burst mean rectified voltage (μV)'
xlim = [0, None]
# y measure; alternatives tried: 'Force slope (mN/s)', 'Force initial slope (mN/s)'
ylabel = 'Force initial increase (mN)'
ylim = [0, None]

scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim,
          trend=True, trend_separately=True, tooltips=True)
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

In [None]:
# Per-animal scatter of B8a/b pre-B3/B6/B9 burst mean frequency against
# the force increase, with one trend line per animal.
ax = plt.figure(figsize=(5, 5)).gca()

# one data subset per animal, keyed by its "<animal> Tape nori" label
data_subsets = {
    label: label2query(label)
    for label in (f'{animal} Tape nori' for animal in ('JG07', 'JG08', 'JG11', 'JG12', 'JG14'))
}

# x measure; alternatives tried:
#   'B8 activity mean rectified voltage (μV)'
#   'B8a/b first burst mean rectified voltage (μV)'
#   'B8a/b pre-B3/B6/B9 burst mean rectified voltage (μV)'
xlabel = 'B8a/b pre-B3/B6/B9 burst mean frequency (Hz)'
xlim = [0, None]
# y measure; alternatives tried: 'Force slope (mN/s)', 'Force initial slope (mN/s)',
# 'Force initial increase (mN)'
ylabel = 'Force increase (mN)'
ylim = [0, None]

scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim,
          trend=True, trend_separately=True, tooltips=True)
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

In [None]:
# Per-animal scatter of B8a/b first burst duration against force rise-and-
# plateau duration, with one trend line per animal and a dotted y = x line.
ax = plt.figure(figsize=(5, 5)).gca()

# one data subset per animal, keyed by its "<animal> Tape nori" label
data_subsets = {
    label: label2query(label)
    for label in (f'{animal} Tape nori' for animal in ('JG07', 'JG08', 'JG11', 'JG12', 'JG14'))
}

# x measure; alternative tried: 'B8 activity duration (s)', [0, 8]
xlabel = 'B8a/b first burst duration (s)'
xlim = [0, 8]
ylabel = 'Force rise and plateau duration (s)'
ylim = [0, 8]

scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim,
          trend=True, trend_separately=True, tooltips=False)

# dotted unity line, long enough to cover the axis limits
plt.plot([0, 999], [0, 999], ls=':', c='gray', zorder=0)
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

In [None]:
# Per-animal box plot of the delay from B3/B6/B9 start to force plateau start;
# the dotted vertical line marks zero delay.
ax = plt.figure(figsize=(5, 5)).gca()

# alternative measure: 'Delay from B3/B6/B9 start to force 50%-height (s)'
column = 'Delay from B3/B6/B9 start to force plateau start (s)'
df = df_all.reset_index()

# whis=999 ensures whiskers go to min and max, so no fliers are drawn
sns.boxplot(data=df, x=column, y='Animal', whis=999, fliersize=0)
# sns.swarmplot(data=df, x=column, y='Animal', color='0.25')
# (pooled over animals: drop y='Animal' from the calls above)

plt.axvline(x=0, ls=':', c='gray', zorder=-1)
plt.tight_layout()

# plt.gcf().savefig(os.path.join(export_dir, 'figure-X.png'), dpi=300)

In [None]:
# Per-animal box plot (with individual swallows overlaid) of the delay from the
# end of either B8a/b or B3/B6/B9 to force plateau end; the dotted vertical line
# marks zero delay.
ax = plt.figure(figsize=(5, 5)).gca()

# alternative measure: 'Delay from either B8a/b or B3/B6/B9 end to force 80%-height end (s)'
column = 'Delay from either B8a/b or B3/B6/B9 end to force plateau end (s)'
df = df_all.reset_index()

# whis=999 ensures whiskers go to min and max, so no fliers are drawn
sns.boxplot(data=df, x=column, y='Animal', whis=999, fliersize=0)
sns.swarmplot(data=df, x=column, y='Animal', color='0.25')
# (pooled over animals: drop y='Animal' from the calls above)

plt.axvline(x=0, ls=':', c='gray', zorder=-1)
plt.tight_layout()

# plt.gcf().savefig(os.path.join(export_dir, 'figure-X.png'), dpi=300)

In [None]:
# Per-animal scatter of total B6/B9 burst duration against force rise-and-
# plateau duration, with one trend line per animal and a dotted y = x line.
ax = plt.figure(figsize=(5, 5)).gca()

# one data subset per animal, keyed by its "<animal> Tape nori" label
data_subsets = {
    label: label2query(label)
    for label in (f'{animal} Tape nori' for animal in ('JG07', 'JG08', 'JG11', 'JG12', 'JG14'))
}

# x measure; alternatives tried (both with [0, 8]):
#   'B3/6/9/10 activity duration (s)'
#   'B6/B9 first burst duration (s)'
xlabel = 'B6/B9 all bursts duration (s)'
xlim = [0, 8]
ylabel = 'Force rise and plateau duration (s)'
ylim = [0, 8]

scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim,
          trend=True, trend_separately=True, tooltips=False)

# dotted unity line, long enough to cover the axis limits
plt.plot([0, 999], [0, 999], ls=':', c='gray', zorder=0)
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

In [None]:
# Box plots of the B3-to-force-plateau delays, one row per measurement point
# ("Start of burst" vs. "End of burst") and one box per animal within each row;
# the dotted vertical line marks zero delay.
ax = plt.figure(figsize=(5, 5)).gca()

# Reshape to long form: one row per (swallow, measurement point).
# An earlier version used the force 80%-height start/end columns instead of the
# plateau start/end columns.
df = (
    df_all
    .reset_index()
    .rename(columns={
        'Delay from B3 start to force plateau start (s)': 'Start of burst',
        'Delay from B3 end to force plateau end (s)':     'End of burst',
    })
    .melt(
        id_vars=['Animal', 'Food', 'Bout_index', 'Behavior_index'],
        value_vars=['Start of burst', 'End of burst'],
        var_name='When?',
        value_name='Delay from B3 to plateau (s)',
    )
)

# whis=999 ensures whiskers go to min and max, so no fliers are drawn
sns.boxplot(data=df, x='Delay from B3 to plateau (s)', y='When?', hue='Animal', whis=999, fliersize=0)
# sns.swarmplot(data=df, x='Delay from B3 to plateau (s)', y='When?', hue='Animal')

plt.axvline(x=0, ls=':', c='gray', zorder=-1)
plt.tight_layout()

In [None]:
# Per-animal scatter of B6/B9 first burst rectified voltage against the force
# increase, with one trend line per animal.
ax = plt.figure(figsize=(5, 5)).gca()

# one data subset per animal, keyed by its "<animal> Tape nori" label
data_subsets = {
    label: label2query(label)
    for label in (f'{animal} Tape nori' for animal in ('JG07', 'JG08', 'JG11', 'JG12', 'JG14'))
}

# x measure; alternative tried: 'B3/6/9/10 activity mean rectified voltage (μV)'
xlabel = 'B6/B9 first burst mean rectified voltage (μV)'
xlim = [0, None]
# y measure; alternatives tried: 'Force slope (mN/s)', 'Force 80%-height (mN)'
ylabel = 'Force increase (mN)'
ylim = [0, None]

scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim,
          trend=True, trend_separately=True, tooltips=True)
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

In [None]:
# Per-animal scatter of B6/B9 first burst mean frequency against the force
# value at plateau start, with one trend line per animal.
ax = plt.figure(figsize=(5, 5)).gca()

# one data subset per animal, keyed by its "<animal> Tape nori" label
data_subsets = {
    label: label2query(label)
    for label in (f'{animal} Tape nori' for animal in ('JG07', 'JG08', 'JG11', 'JG12', 'JG14'))
}

# x measure; alternative tried: 'B3/6/9/10 activity mean rectified voltage (μV)'
xlabel = 'B6/B9 first burst mean frequency (Hz)'
xlim = [0, None]
# y measure; alternatives tried: 'Force slope (mN/s)', 'Force increase (mN)',
# 'Force rise increase (mN)'
ylabel = 'Force plateau start value (mN)'
ylim = [0, None]

scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim,
          trend=True, trend_separately=True, tooltips=True)
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

In [None]:
# Per-animal box plot (with individual swallows overlaid) of the delay from B38
# end to force shoulder end; the dotted vertical line marks zero delay.
ax = plt.figure(figsize=(5, 5)).gca()

# alternative measure: 'Delay from B38 end to shoulder end (s)'
column = 'Delay from B38 end to force shoulder end (s)'
df = df_all.reset_index()

# whis=999 ensures whiskers go to min and max, so no fliers are drawn
sns.boxplot(data=df, x=column, y='Animal', whis=999, fliersize=0)
sns.swarmplot(data=df, x=column, y='Animal', color='0.25')
# (pooled over animals: drop y='Animal' from the calls above)

plt.axvline(x=0, ls=':', c='gray', zorder=-1)
plt.tight_layout()

# plt.gcf().savefig(os.path.join(export_dir, 'figure-X.png'), dpi=300)

In [None]:
# Per-animal scatter of B8a/b pre-B3/B6/B9 peak smoothed frequency against the
# force slope, with one trend line per animal.
ax = plt.figure(figsize=(5, 5)).gca()

# one data subset per animal, keyed by its "<animal> Tape nori" label
data_subsets = {
    label: label2query(label)
    for label in (f'{animal} Tape nori' for animal in ('JG07', 'JG08', 'JG11', 'JG12', 'JG14'))
}

xlabel = 'B8a/b pre-B3/B6/B9 burst peak smoothed frequency (Hz)'
xlim = [0, None]
ylabel = 'Force slope (mN/s)'
ylim = [0, None]

scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim,
          trend=True, trend_separately=True, tooltips=True)
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

In [None]:
# Per-animal scatter of force rise duration against force rise increase,
# with one trend line per animal.
ax = plt.figure(figsize=(5, 5)).gca()

# one data subset per animal, keyed by its "<animal> Tape nori" label
data_subsets = {
    label: label2query(label)
    for label in (f'{animal} Tape nori' for animal in ('JG07', 'JG08', 'JG11', 'JG12', 'JG14'))
}

xlabel = 'Force rise duration (s)'
xlim = [0, None]
ylabel = 'Force rise increase (mN)'
ylim = [0, None]

scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim,
          trend=True, trend_separately=True, tooltips=True)
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

In [None]:
# Render the exemplary feeding bout:
#   - raw traces with scale bars, plus a dashed force baseline
#   - every sorted spike drawn as a dot on its recording channel
#   - an outlined box around every good burst of each unit
# Prints the total render time at the end.
start = datetime.datetime.now()

df = df_exemplary_bout
(data_set_name, channel_names, time_window, epoch_types_to_keep, burst_thresholds) = feeding_bouts[exemplary_bout]

t_start, t_stop = time_window*pq.s
plots = [
    {'channel': 'I2',       'units': 'uV', 'ylim': [ -60,  60], 'scalebar': 50}, #, 'decimation_factor': 100},
    {'channel': 'RN',       'units': 'uV', 'ylim': [ -25,  25], 'scalebar': 25}, #, 'decimation_factor': 100},
    {'channel': 'BN2',      'units': 'uV', 'ylim': [ -45,  45], 'scalebar': 50}, #, 'decimation_factor': 100},
    {'channel': 'BN3-DIST', 'units': 'uV', 'ylim': [ -60,  60], 'scalebar': 50, 'ylabel': 'BN3'}, #, 'decimation_factor': 100},
    {'channel': 'Force',    'units': 'mN', 'ylim': [ -50, 450], 'scalebar': 300, 'decimation_factor': 100},
]
channel_names = [p['channel'] for p in plots]
channel_units = [p['units'] for p in plots]

kwargs = dict(
    figsize = (12, 6),
    linewidth = 0.5,
    x_scalebar = 10*pq.s,
)

# load the data
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)
blk = neurotic.load_dataset(metadata) # not lazy so that I2 filter is applied

# plot the signal
fig, axes = prettyplot_with_scalebars(blk, t_start, t_stop, plots, **kwargs)

units = [
    'I2 spikes',
    'B8a/b',
    'B6/B9',
    'B3',
    'B38',
]
# vertical extent [bottom, top] of each unit's burst box, in the data units of
# the axis for the channel the unit is recorded on
unit_burst_boxes = {
    'I2 spikes': [-35, 30],
    'B8a/b':     [-20, 12],
    'B6/B9':     [-15, 12],
    'B3':        [-45, 35],
    'B38':       [-12, 12],
}

# plot a zero baseline for force
ax = axes[channel_names.index('Force')]
force_zero = -11.5 # mN, average before animal began swallowing, not zero because of DC offset
ax.axhline(force_zero, color='0.75', lw=1, ls='--', zorder=-1)

# Bout-length signal per channel, sliced and rescaled once and then reused for
# every swallow and unit recorded on that channel (previously recomputed for each
# (swallow, unit) pair).
# NOTE(review): assumes get_sig returns the same signal on every call for a given channel.
bout_sigs = {}

for i in df.index:
    for unit in units:
        st = df.loc[i, unit+' spike train']
        if st is None or st.size == 0:
            continue

        # get the neural channel and its signal for the entire bout
        channel = st.annotations['channels'][0]
        if channel not in bout_sigs:
            sig = get_sig(blk, channel)
            sig = sig.time_slice(t_start, t_stop)
            bout_sigs[channel] = sig.rescale(channel_units[channel_names.index(channel)])
        sig = bout_sigs[channel]

        # plot spikes as dots at the signal's value at each spike time
        ax = axes[channel_names.index(channel)]
        spike_amplitudes = np.array([sig[sig.time_index(t)] for t in st]) * pq.Quantity(sig.units)
        ax.scatter(st.times.rescale('s'), spike_amplitudes, marker='.', s=4, c=unit_colors[unit], zorder=3)

        # outline each good burst with an unfilled rectangle in the unit's color
        bursts = df.at[i, unit+' all bursts (s)']
        bottom, top = unit_burst_boxes[unit]
        height = top-bottom
        for burst in bursts:
            if is_good_burst(burst):
                left = burst['Start (s)']
                right = burst['End (s)']
                width = right-left
                rect = patches.Rectangle((left, bottom), width, height, linewidth=2, ls='-', edgecolor=unit_colors[unit], fill=False, zorder=3, clip_on=False)
                ax.add_patch(rect)

# fig.savefig(os.path.join(export_dir, 'figure-exemplary-bout-export.png'), dpi=600)

end = datetime.datetime.now()
print('render time:', end-start)

## Firing-rate model exploration

In [None]:
# Exploratory firing-rate model of swallow force for a single swallow.
#   1. Smooth the B6/B9 and B3 spike trains with causal alpha kernels.
#   2. Sum the smoothed rates with per-neuron weights.
#   3. Pass the sum through a clipped linear map y in [0, 1], then x = sqrt(1 - y^2).
#   4. Scale and offset x.
# The model output is overlaid on the 10 Hz low-pass-filtered force trace, and one
# PNG is saved per parameter set.
swallow_id = ('JG07', 'Tape nori', 0, 0)

(data_set_name, channel_names, time_window, _, _) = feeding_bouts[swallow_id[:3]]

# load the data and extract the bout's force trace, rescaled to mN and low-pass filtered
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)
blk = neurotic.load_dataset(metadata, lazy=True)
sig = get_sig(blk, 'Force')
sig = sig.time_slice(time_window[0]*pq.s, time_window[1]*pq.s)
sig = sig.rescale('mN')
sig = elephant.signal_processing.butter(sig, lowpass_freq = 10*pq.Hz)

st_b6b9 = df_all.loc[swallow_id]['B6/B9 spike train']
st_b3 = df_all.loc[swallow_id]['B3 spike train']

# model parameters: per-neuron (weight, kernel time constant in s), plus output scaling
weight_b6b9, tau_b6b9 = 1.75, 1
weight_b3,   tau_b3   = 1.75, 0.2
model_scale           = 100
model_baseline        = 85
u_to_y_constant       = 0.005

rate_b6b9 = elephant.statistics.instantaneous_rate(
    spiketrain=st_b6b9,
    sampling_period=0.0002*pq.s,
    kernel=CausalAlphaKernel(tau_b6b9*pq.s),
)

rate_b3 = elephant.statistics.instantaneous_rate(
    spiketrain=st_b3,
    sampling_period=0.0002*pq.s,
    kernel=CausalAlphaKernel(tau_b3*pq.s),
)

rate_total = rate_b6b9 * weight_b6b9 + rate_b3 * weight_b3
# y = np.exp(-u_to_y_constant*rate_total.magnitude)
# y decreases linearly with the weighted drive and is clipped into [0, 1]
# (equivalent to the former element-wise max with 0 followed by min with 1)
y = np.clip(1 - u_to_y_constant * rate_total.magnitude.flatten(), 0, 1)
x = np.sqrt(1-y**2)
x = x * model_scale + model_baseline


plt.figure(figsize=(8,4))
plt.plot(rate_total.times.rescale('s'), rate_total)
plt.plot(rate_total.times.rescale('s'), x)
plt.plot(sig.times.rescale('s'), sig.magnitude, color='0.75', zorder=-1)
plt.xlim([df_all.loc[swallow_id]['Start (s)'], df_all.loc[swallow_id]['End (s)']])

plt.title(f'Scale: {model_scale} | Baseline: {model_baseline} | B6/B9: ({weight_b6b9}, {tau_b6b9}) | B3: ({weight_b3}, {tau_b3})')

export_dir4 = os.path.join(export_dir, 'firing-rate-models')
# savefig does not create missing directories, so make sure the subdirectory exists
os.makedirs(export_dir4, exist_ok=True)
plt.gcf().savefig(os.path.join(export_dir4, f'S {model_scale} BL {model_baseline} B6B9 {weight_b6b9} {tau_b6b9} B3 {weight_b3} {tau_b3}.png'), dpi=300)