# Jump to a Figure

- [Figure 1](#[FIGURE-1])
  - [Figure 1A](#🐌-Figure-1A)
  - [Figures 1B & 1C](#🐌-Figures-1B-&-1C)
- [Figure 2](#[FIGURE-2])
  - [Figure 2A](#🐌-Figure-2A)
  - [Figure 2B](#🐌-Figure-2B)
  - [Figure 2C](#🐌-Figure-2C)
  - [Figure 2D](#🐌-Figure-2D)
  - [Figure 2E](#🐌-Figure-2E)
- [Figure 3](#[FIGURE-3])
  - [Figure 3A](#🐌-Figure-3A)
  - [Figure 3B](#🐌-Figure-3B)
  - [Figure 3C](#🐌-Figure-3C)
  - [Figure 3D](#🐌-Figure-3D)
  - [Figure 3E](#🐌-Figure-3E)
  - [Figure 3C (alt)](#🐌-Figure-3C-(alt))
  - [Figure 3E (alt)](#🐌-Figure-3E-(alt))
- [Figure 4](#[FIGURE-4])
  - [Figure 4 Statistics](#🐌-Figure-4-Statistics)
  - [Figure 4A](#🐌-Figure-4A)
  - [Figure 4B](#🐌-Figure-4B)
  - [Figure 4C](#🐌-Figure-4C)
  - [Figure 4D](#🐌-Figure-4D)
  - [Figure 4E](#🐌-Figure-4E)
  - [Figure 4F](#🐌-Figure-4F)
  - [Other frequencies (not plotted in manuscript)](#🐌-Other-frequencies-(not-plotted-in-manuscript))
- [Figure 5](#[FIGURE-5])
  - [Figure 5A](#🐌-Figure-5A)
  - [Figure 5B](#🐌-Figure-5B)
  - [Figure 5C](#🐌-Figure-5C)
  - [Figure 5D](#🐌-Figure-5D)
  - [Figure 5E](#🐌-Figure-5E)
  - [Figure 5F](#🐌-Figure-5F)
- [Statistical Tables](#[STATISTICAL-TABLES])
  - [Table 2](#🐌-Table-2)
  - [Table 3](#🐌-Table-3)
  - [Table 4](#🐌-Table-4)

# Preamble

## Import Packages

In [None]:
import os
import datetime
from tqdm.notebook import tqdm
import numpy as np
import scipy as sp
import quantities as pq
import elephant
import pandas as pd
import statsmodels.api as sm
import neurotic
from neurotic.datasets.data import _detect_spikes, _find_bursts
from utils import BehaviorsDataFrame, DownsampleNeoSignal

# print quantities using unicode symbols (e.g. µ for micro)
pq.markup.config.use_unicode = True
# millinewton unit, used for the force channel (see channel_units)
pq.mN = pq.UnitQuantity('millinewton', pq.N/1e3, symbol='mN')

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.patches as patches
import matplotlib.lines as mlines
from matplotlib.markers import CARETLEFT, CARETRIGHT, CARETUP, CARETDOWN, CARETUPBASE
from matplotlib.ticker import MultipleLocator
import seaborn as sns

from rpy2.robjects.packages import importr
from rpy2.robjects import numpy2ri
numpy2ri.activate()

In [None]:
import warnings

# don't warn about invalid comparisons to NaN
np.seterr(invalid='ignore')

# np.nanmax raises a warning if all values are NaN and returns NaN, which is the behavior we want
warnings.filterwarnings('ignore', message='All-NaN slice encountered')

# elephant.statistics.instantaneous_rate always complains about negative values
warnings.filterwarnings('ignore', message='Instantaneous firing rate approximation contains '
                                          'negative values, possibly caused due to machine '
                                          'precision errors')

# with matplotlib>=3.1 and seaborn<=0.9.0, deprecation warnings are raised
# whenever tick marks are placed on the right axis but not the left
# - newer matplotlib releases deprecate/remove matplotlib.cbook.MatplotlibDeprecationWarning,
#   while the top-level matplotlib package exports the same class
try:
    from matplotlib import MatplotlibDeprecationWarning
except ImportError:
    # fall back to the historical location for older matplotlib releases
    from matplotlib.cbook import MatplotlibDeprecationWarning
warnings.simplefilter('ignore', MatplotlibDeprecationWarning)

# don't complain about opening too many figures
plt.rcParams.update({'figure.max_open_warning': 0})

In [None]:
# choose how figures are displayed: leave exactly one of the %matplotlib
# lines below uncommented

# make figures interactive and open in a separate window
# %matplotlib qt

# make figures interactive and inline
# - NOTE(review): the "notebook" backend is tied to the classic Jupyter
#   Notebook front end; confirm it is available in the environment used
%matplotlib notebook

# make figures non-interactive and inline
# %matplotlib inline

## Plot Settings

In [None]:
# directory into which the manuscript figures are exported
# - makedirs(exist_ok=True) is idempotent when the cell is re-run and avoids
#   the check-then-create race of os.path.exists() followed by os.mkdir()
export_dir = 'manuscript-figures'
os.makedirs(export_dir, exist_ok=True)

In [None]:
# general plot settings: seaborn "ticks" style, default font scale, Palatino font
# - a context='poster' keyword may be added to this call to switch seaborn's plotting context
sns.set(style='ticks', font_scale=1, font='Palatino Linotype')

In [None]:
# show the active seaborn color palette as a strip of swatches on a darkgrid background
with sns.axes_style('darkgrid'):
    sns.palplot(sns.color_palette(palette=None), size=0.5)

In [None]:
# color used for each neural unit (and the force trace) in the figures
# - insertion order is preserved; it is the order in which the palette is displayed
unit_colors = dict([
    ('B38',      '#EFBF46'),  # yellow
    ('I2',       '#DC5151'),  # red
    ('B8a/b',    'C6'),       # pink
    ('B6/B9',    'C9'),       # light blue
    ('B3/B6/B9', '#5A9BC5'),  # medium blue
    ('B3',       '#4F80BD'),  # dark blue
    ('B4/B5',    '#00A86B'),  # jade green
    ('Force',    'k'),        # black
])

# colors for the named phases of the force trace, reusing unit colors
# (the 'drop' phase has no corresponding unit and is drawn in gray)
force_colors = {
    phase: unit_colors[unit]
    for phase, unit in [
        ('shoulder',     'B38'),
        ('dip',          'I2'),
        ('initial rise', 'B8a/b'),
        ('rise',         'B8a/b'),
        ('plateau',      'B3/B6/B9'),
    ]
}
force_colors['drop'] = 'gray'

In [None]:
# show the chosen unit colors, in dictionary order, as swatches on a darkgrid background
with sns.axes_style('darkgrid'):
    sns.palplot(list(unit_colors.values()), size=0.5)

In [None]:
# list every unit's color as an upper-case hex code, with unit names left-aligned in 10 columns
for unit, color in unit_colors.items():
    print('{:<10} {}'.format(unit, mcolors.to_hex(color).upper()))

In [None]:
# print a link to an online colorblindness simulator preloaded with the unit colors
# - colors are separated by '-' and each '#' is URL-encoded as '%23'
print('https://davidmathlogic.com/colorblind/#' + '-'.join(
    mcolors.to_hex(c).upper().replace('#', '%23') for c in unit_colors.values()
))

## Data Parameters

In [None]:
# neural units analyzed in this notebook
units = ['B38', 'I2', 'B8a/b', 'B6/B9', 'B3', 'B4/B5']

In [None]:
# default burst detection thresholds for each unit, as a pair of frequencies
# - the second value is the end threshold (see the JG08 exception below)
burst_thresholds_default = {
    'B38':   (8, 5) * pq.Hz,   # based on McManus et al. 2014
    'I2':    (10, 5) * pq.Hz,  # same as Cullins et al. 2015a (based on Hurwitz et al. 1996)
    'B8a/b': (3, 3) * pq.Hz,   # same as Cullins et al. 2015a (based on Morton and Chiel 1993a)
    'B6/B9': (10, 5) * pq.Hz,  # based on Lu et al. 2015
    'B3':    (8, 2) * pq.Hz,   # based on Lu et al. 2015
    'B4/B5': (3, 3) * pq.Hz,   # same as Cullins et al. 2015a ? (Table 1 says 3 Hz, text says first/last spike; based on Warman and Chiel 1995 ?)
}

# each animal starts from its own (shallow) copy of the defaults, so the
# per-animal exceptions below do not change the defaults or other animals
burst_thresholds_by_animal = {
    animal: burst_thresholds_default.copy()
    for animal in ('JG07', 'JG08', 'JG11', 'JG12', 'JG14')
}

# per-animal exceptions
burst_thresholds_by_animal['JG07']['B8a/b'] = (2, 2) * pq.Hz      # both thresholds reduced for this animal because B8a/b signal was weak with few spikes
burst_thresholds_by_animal['JG08']['B6/B9'] = (10, 3) * pq.Hz     # end threshold reduced for this animal
burst_thresholds_by_animal['JG11']['B4/B5'] = (1.5, 1.5) * pq.Hz  # both thresholds reduced for this animal because only one neuron appeared to project
burst_thresholds_by_animal['JG14']['B6/B9'] = (4, 2) * pq.Hz      # both thresholds reduced for this animal because B6/B9 always fired slowly

In [None]:
# units of the recorded channels, in the same order as the names in channel_names_by_animal:
# I2, RN, BN2 and BN3 are in microvolts, Force is in millinewtons
channel_units = ['uV'] * 4 + ['mN']

In [None]:
# names of the recorded channels for each animal; every recording ends with the
# 'Force' channel, and the nerve/muscle channel labels differ between animals
channel_names_by_animal = {
    animal: [*names, 'Force']
    for animal, names in [
        ('JG07', ['I2-L', 'RN-L', 'BN2-L', 'BN3-L']),
        ('JG08', ['I2',   'RN',   'BN2',   'BN3']),
        ('JG11', ['I2',   'RN',   'BN2',   'BN3-PROX']),
        ('JG12', ['I2',   'RN',   'BN2',   'BN3-DIST']),
        ('JG14', ['I2',   'RN',   'BN2',   'BN3-PROX']),
    ]
}

In [None]:
# signal filters for each animal: a 100 Hz low-pass on that animal's I2 channel
# (cutoffs are converted to Hz by apply_filters)
sig_filters_by_animal = {
    animal: [{'channel': i2_channel, 'lowpass': 100}]
    for animal, i2_channel in [
        ('JG07', 'I2-L'),
        ('JG08', 'I2'),
        ('JG11', 'I2'),
        ('JG12', 'I2'),
        ('JG14', 'I2'),
    ]
}

In [None]:
# epoch labels that identify swallows, for each food type
epoch_types_by_food = {
    food: [epoch_label]
    for food, epoch_label in [
        ('Regular nori', 'Swallow (regular 5-cm nori strip)'),
        ('Tape nori',    'Swallow (tape nori)'),
    ]
}

In [None]:
feeding_bouts = {
    # (animal, food, bout_index): (data_set_name, time_window)
    # - data_set_name is the metadata entry passed to metadata.select() when loading
    # - time_window is [start, stop]; presumably in seconds -- TODO confirm
    # - the data-processing loop iterates over this dict, so insertion order is
    #   the order in which bouts are processed

    ('JG07', 'Regular nori', 0): ('IN VIVO / JG07 / 2018-05-20 / 002', [1496, 1518]), # 4 swallows
    ('JG07', 'Regular nori', 1): ('IN VIVO / JG07 / 2018-05-20 / 002', [1169, 1191]), # 4 swallows
    ('JG07', 'Regular nori', 2): ('IN VIVO / JG07 / 2018-05-20 / 002', [1582, 1615]), # 5 swallows
    ('JG08', 'Regular nori', 0): ('IN VIVO / JG08 / 2018-06-21 / 002', [ 256,  287]), # 4 swallows
    ('JG08', 'Regular nori', 1): ('IN VIVO / JG08 / 2018-06-21 / 002', [ 454,  481]), # 4 swallows
    ('JG11', 'Regular nori', 0): ('IN VIVO / JG11 / 2019-04-03 / 001', [1791, 1819]), # 5 swallows
    ('JG11', 'Regular nori', 1): ('IN VIVO / JG11 / 2019-04-03 / 004', [ 551,  568]), # 3 swallows
    ('JG12', 'Regular nori', 0): ('IN VIVO / JG12 / 2019-05-10 / 002', [ 147,  165]), # 3 swallows
    ('JG12', 'Regular nori', 1): ('IN VIVO / JG12 / 2019-05-10 / 002', [ 229,  245]), # 3 swallows
    ('JG12', 'Regular nori', 2): ('IN VIVO / JG12 / 2019-05-10 / 002', [ 277,  291]), # 3 swallows
    ('JG14', 'Regular nori', 0): ('IN VIVO / JG14 / 2019-07-30 / 001', [1834, 1865]), # 4 swallows
    ('JG14', 'Regular nori', 1): ('IN VIVO / JG14 / 2019-07-30 / 001', [1910, 1943]), # 5 swallows
    ('JG14', 'Regular nori', 2): ('IN VIVO / JG14 / 2019-07-30 / 001', [2052, 2084]), # 5 swallows
    
    ('JG07', 'Tape nori',    0): ('IN VIVO / JG07 / 2018-05-20 / 002', [2718, 2755]), # 5 swallows
    ('JG08', 'Tape nori',    0): ('IN VIVO / JG08 / 2018-06-21 / 002', [ 147,  208]), # 7 swallows, some bucket and head movement
    ('JG08', 'Tape nori',    1): ('IN VIVO / JG08 / 2018-06-21 / 002', [ 664,  701]), # 5 swallows, large bucket movement
    ('JG08', 'Tape nori',    2): ('IN VIVO / JG08 / 2018-06-21 / 002', [1451, 1477]), # 3 swallows, some bucket movement
    ('JG11', 'Tape nori',    0): ('IN VIVO / JG11 / 2019-04-03 / 004', [1227, 1280]), # 5 swallows, inward food movements had very low amplitude
    ('JG12', 'Tape nori',    0): ('IN VIVO / JG12 / 2019-05-10 / 002', [ 436,  465]), # 4 swallows
    ('JG12', 'Tape nori',    1): ('IN VIVO / JG12 / 2019-05-10 / 002', [2901, 2937]), # 5 swallows
    ('JG14', 'Tape nori',    0): ('IN VIVO / JG14 / 2019-07-29 / 004', [ 829,  870]), # 5 swallows
}

In [None]:
# for example figures only -- not used in majority of analysis because of long inter-swallow intervals
# - bout indexes 101 and 102 do not collide with the 0-based indexes of the
#   bouts already registered in feeding_bouts
exemplary_bout = ('JG12', 'Tape nori', 101)
exemplary_swallow = ('JG12', 'Tape nori', 102)

feeding_bouts.update({
    exemplary_bout:    ('IN VIVO / JG12 / 2019-05-10 / 002', [2970.7, 2992.0]),  # 3 swallows
    exemplary_swallow: ('IN VIVO / JG12 / 2019-05-10 / 002', [2977.3, 2984.5]),  # 1 swallow
})

In [None]:
# these unloaded swallows have unreliable inward movement measurement
# and are excluded from some parts of the analysis
# - each entry is (animal, food, bout_index, behavior_index), where
#   behavior_index counts behaviors within the bout starting at 0
unreliable_inward_movement = [
    ('JG07', 'Regular nori', 0, 2), # inward movement unusually short (maybe strip tore?)
    ('JG07', 'Regular nori', 0, 3), # finished strip mid-retraction
    ('JG07', 'Regular nori', 1, 1), # inward movement unusually short (maybe strip tore?)
    ('JG07', 'Regular nori', 1, 3), # finished strip mid-retraction
    ('JG07', 'Regular nori', 2, 2), # inward movement unusually short (maybe strip tore?)
    ('JG07', 'Regular nori', 2, 4), # finished strip mid-retraction
    ('JG08', 'Regular nori', 0, 3), # finished strip mid-retraction
    ('JG08', 'Regular nori', 1, 3), # finished strip mid-retraction
    ('JG11', 'Regular nori', 0, 4), # finished strip mid-retraction
    ('JG11', 'Regular nori', 1, 2), # finished strip mid-retraction
    ('JG12', 'Regular nori', 0, 2), # finished strip mid-retraction
    # JG12 Regular nori 0 1 is OK
    ('JG12', 'Regular nori', 2, 2), # finished strip mid-retraction
    ('JG14', 'Regular nori', 0, 3), # finished strip mid-retraction
    ('JG14', 'Regular nori', 1, 0), # bite-swallow and inward movement unusually short (maybe strip tore?)
    ('JG14', 'Regular nori', 1, 3), # inward movement unusually short (maybe strip tore?)
    ('JG14', 'Regular nori', 1, 4), # finished strip mid-retraction
    ('JG14', 'Regular nori', 2, 1), # not visible
]

# these behaviors at the start of each unloaded bout are also excluded
# because they are actually bite-swallows instead of pure swallows
# - there is one entry (behavior index 0) per 'Regular nori' bout in feeding_bouts
bite_swallow_behaviors = [
    ('JG07', 'Regular nori', 0, 0),
    ('JG07', 'Regular nori', 1, 0),
    ('JG07', 'Regular nori', 2, 0),
    ('JG08', 'Regular nori', 0, 0),
    ('JG08', 'Regular nori', 1, 0),
    ('JG11', 'Regular nori', 0, 0),
    ('JG11', 'Regular nori', 1, 0),
    ('JG12', 'Regular nori', 0, 0),
    ('JG12', 'Regular nori', 1, 0),
    ('JG12', 'Regular nori', 2, 0),
    ('JG14', 'Regular nori', 0, 0),
    ('JG14', 'Regular nori', 1, 0),
    ('JG14', 'Regular nori', 2, 0),
]
# extends the list in place; ('JG14', 'Regular nori', 1, 0) ends up listed
# twice, which is harmless for membership tests but would matter for counting
unreliable_inward_movement += bite_swallow_behaviors

## Helper Functions

In [None]:
def query_union(queries):
    """Combine pandas query strings into a single query that matches any of them.

    Each query that is not None is wrapped in parentheses, and the pieces are
    joined with '|'. None entries are skipped, so an input with no usable
    queries yields the empty string.
    """
    kept = (query for query in queries if query is not None)
    return '|'.join(map('({})'.format, kept))

def label2query(label):
    """Build a pandas query selecting one animal and food from a combined label.

    The first four characters of *label* are taken as the animal ID. The fifth
    character (a separator) is skipped, and the remainder is the food name.
    """
    animal = label[:4]
    food = label[5:]
    return '(Animal == "{}") & (Food == "{}")'.format(animal, food)

def contains(series, string):
    """Return a boolean Series flagging the elements of *series* that contain *string*.

    Each element is tested with the ``in`` operator, and the index and name of
    *series* are kept in the result.
    """
    def _has_string(element):
        return string in element

    return series.map(_has_string)

In [None]:
def get_sig(blk, channel):
    """Return the first analog signal named *channel* in the first segment of *blk*.

    Raises a plain Exception if no signal has that name.
    """
    for candidate in blk.segments[0].analogsignals:
        if candidate.name == channel:
            return candidate
    raise Exception(f'Channel "{channel}" could not be found')

In [None]:
def get_sig_index(blk, channel):
    """Return the position of the first analog signal named *channel* in the first segment of *blk*.

    Raises a plain Exception if no signal has that name.
    """
    signals = blk.segments[0].analogsignals
    for position, candidate in enumerate(signals):
        if candidate.name == channel:
            return position
    raise Exception(f'Channel "{channel}" could not be found')

In [None]:
def apply_filters(blk, metadata):
    """Apply the Butterworth filters listed in ``metadata['filters']`` to *blk*.

    Nearly identical to neurotic's implementation, except that each signal is
    passed through ``time_slice(None, None)`` first, which ensures lazily
    loaded proxy objects are loaded before filtering.

    Each filter spec is a dict with a ``'channel'`` name and optional
    ``'highpass'`` / ``'lowpass'`` cutoffs. Truthy cutoffs are multiplied by
    ``pq.Hz``; missing or falsy cutoffs are passed through unchanged.

    The filtered signal replaces the original in the block's first segment.
    The same block object is returned.
    """
    segment = blk.segments[0]
    for spec in metadata['filters']:
        index = get_sig_index(blk, spec['channel'])
        highpass = spec.get('highpass', None)
        lowpass = spec.get('lowpass', None)
        highpass_freq = highpass * pq.Hz if highpass else highpass
        lowpass_freq = lowpass * pq.Hz if lowpass else lowpass
        segment.analogsignals[index] = elephant.signal_processing.butter(  # may raise a FutureWarning
            signal=segment.analogsignals[index].time_slice(None, None),
            highpass_freq=highpass_freq,
            lowpass_freq=lowpass_freq,
        )
    return blk

In [None]:
def is_good_burst(burst):
    """Decide whether a ``(time, duration, n_spikes)`` burst is kept.

    A burst is kept when it lasts at least 0.5 s and contains more than two
    spikes. The duration is compared against ``0.5*pq.s``, so it is expected
    to be a time quantity.
    """
    _, burst_duration, spike_count = burst
    return burst_duration >= 0.5*pq.s and spike_count > 2

In [None]:
# must first convert args from quantities to simple ndarrays
# (use .rescale('s').magnitude)
def normalize_time(fixed_times, t, extrapolate=True):
    """Map times onto a piecewise-linear "normalized" time axis.

    The i-th entry of ``fixed_times`` maps to normalized time i, and times
    between two consecutive fixed times are linearly interpolated. For
    example, a time halfway between fixed_times[2] and fixed_times[3] maps
    to 2.5.

    Parameters
    ----------
    fixed_times : ndarray
        Sorted landmark times (a plain ndarray, not a quantity). NaN entries
        mark missing landmarks. Any time whose bordering interval involves a
        NaN landmark normalizes to NaN.
    t : ndarray, list or scalar
        Time(s) to normalize (plain numbers, not a quantity). A scalar is
        wrapped into a length-1 array. NaN times normalize to NaN.
    extrapolate : bool
        If True, times before the first / after the last fixed time are
        extrapolated using the first / last interval. If False, those times
        become NaN.

    Returns
    -------
    ndarray
        One normalized time per element of ``t``.
    """
    if not isinstance(t, np.ndarray):
        t = np.array(t) if type(t) is list else np.array([t])

    assert not isinstance(fixed_times, pq.Quantity), f'fixed_times should not be a quantity (use .magnitude)'
    assert not isinstance(t, pq.Quantity), f't should not be a quantity (use .magnitude)'
    is_missing = np.isnan(fixed_times)
    assert np.all(np.diff(fixed_times[~is_missing])>=0), f'fixed_times must be sorted: {fixed_times}'

    # searchsorted misbehaves with NaNs, so search only the usable landmarks
    # (infinities are kept), then translate each position k in that compressed
    # array back to the k-th usable landmark in fixed_times
    # - the appended len(fixed_times) handles t beyond the last usable landmark
    #   and NaN values of t, which searchsorted places at the end
    usable_positions = np.append(np.flatnonzero(~is_missing), len(fixed_times))
    indexes = usable_positions[np.searchsorted(fixed_times[~is_missing], t)]

    # keep a landmark on each side: index 0 becomes 1 and index N becomes N-1,
    # so out-of-range times are normalized against the first/last interval
    indexes = np.clip(indexes, 1, len(fixed_times)-1)

    # linear interpolation between the bordering landmarks; outside the range
    # of fixed_times the same formula extrapolates (the fraction exceeds 1
    # before the first landmark and goes negative after the last one)
    previous_landmark = fixed_times[indexes-1]
    next_landmark = fixed_times[indexes]
    t_normalized = indexes - (next_landmark-t)/(next_landmark-previous_landmark)

    if not extrapolate:
        # undo the extrapolation: times outside [min, max] of fixed_times become NaN
        with np.errstate(invalid='ignore'): # don't warn about invalid comparisons to NaN
            outside = (t < np.nanmin(fixed_times)) | (t > np.nanmax(fixed_times))
        t_normalized[outside] = np.nan

    return t_normalized

In [None]:
# must first convert args from quantities to simple ndarrays
# (use .rescale('s').magnitude)
def unnormalize_time(fixed_times, t_normalized, extrapolate=True):
    """Inverse of normalize_time: map normalized times back to real times.

    Normalized time i corresponds to fixed_times[i]. Fractional values are
    linearly interpolated between neighbouring landmarks.

    Parameters
    ----------
    fixed_times : ndarray
        Sorted landmark times (a plain ndarray, not a quantity). NaN entries
        mark missing landmarks. A result that depends on a NaN landmark is NaN.
    t_normalized : ndarray, list or scalar
        Normalized time(s). A scalar is wrapped into a length-1 array. NaN
        inputs give NaN.
    extrapolate : bool
        If True, values outside [0, len(fixed_times)-1] are extrapolated using
        the first / last interval. If False, results outside the range of
        fixed_times become NaN.

    Returns
    -------
    ndarray
        One real time per element of ``t_normalized``.
    """
    if not isinstance(t_normalized, np.ndarray):
        t_normalized = np.array(t_normalized) if type(t_normalized) is list else np.array([t_normalized])

    assert not isinstance(fixed_times, pq.Quantity), f'fixed_times should not be a quantity (use .magnitude)'
    assert not isinstance(t_normalized, pq.Quantity), f't_normalized should not be a quantity (use .magnitude)'
    assert np.all(np.diff(fixed_times[~np.isnan(fixed_times)])>=0), f'fixed_times must be sorted: {fixed_times}'

    # index of the landmark at or after each normalized time, clipped so the
    # first/last interval is used for extrapolation (NaN inputs stay NaN)
    indexes = np.clip(np.ceil(t_normalized), 1, len(fixed_times)-1)

    # bordering landmarks; entries whose index is NaN stay NaN
    before = np.full(indexes.shape, np.nan)
    after = np.full(indexes.shape, np.nan)
    known = ~np.isnan(indexes)
    known_indexes = indexes[known].astype(int)
    before[known] = fixed_times[known_indexes-1]
    after[known] = fixed_times[known_indexes]

    # walk back from the "after" landmark by the fractional distance
    t = after + (t_normalized-indexes)*(after-before)

    if not extrapolate:
        # undo the extrapolation: results outside [min, max] of fixed_times become NaN
        with np.errstate(invalid='ignore'): # don't warn about invalid comparisons to NaN
            outside = (t < np.nanmin(fixed_times)) | (t > np.nanmax(fixed_times))
        t[outside] = np.nan

    return t

In [None]:
# these arrays are used as interp_times for resample_sig_in_normalized_time
# depending on the segmentation scheme
# - linspace ensures samples will be taken at regular intervals in normalized time
interp_resolution = 1000
video_seg_interp_times = np.linspace(0, 9, interp_resolution)
force_seg_interp_times = np.linspace(0, 9, interp_resolution)

def resample_sig_in_normalized_time(fixed_times, sig, interp_times):
    """Resample *sig* at evenly spaced points in normalized time.

    Each sample time of *sig* is converted to normalized time with
    ``normalize_time``, using *fixed_times* as landmarks. The signal values
    are then linearly interpolated at *interp_times*.

    Parameters
    ----------
    fixed_times : quantity array
        Landmark times defining the normalized axis (NaN marks a missing
        landmark). They are rescaled to seconds, the same time base used for
        the signal's sample times.
    sig : signal object with ``times`` and ``magnitude``
        Assumed to be a Neo AnalogSignal (or similar) -- confirm with callers.
    interp_times : ndarray
        Normalized times at which to sample (e.g. video_seg_interp_times).

    Returns
    -------
    ndarray
        One value per entry of *interp_times*. Values are NaN outside the
        normalized range covered by the signal, and wherever the normalized
        time falls next to a missing landmark.
    """
    # put landmarks and sample times on the same time base (seconds)
    # - previously fixed_times.magnitude was used without rescaling, which
    #   silently misaligned the two whenever fixed_times was not stored in seconds
    fixed_times_s = fixed_times.rescale('s').magnitude

    # get normalized times and signal values
    # - normalize_time puts a np.nan in times_normalized wherever a time was not
    #   normalizable, i.e., wherever it is adjacent to a missing fixed time
    times = sig.times.rescale('s').magnitude
    times_normalized = normalize_time(fixed_times_s, times)
    y = sig.magnitude.flatten()

    # drop times that could not be normalized due to missing fixed times
    # - interp1d will erroneously interpolate across the gaps created by this
    #   deletion, but those interpolated values are replaced with np.nan below
    where_normalizable = np.where(~np.isnan(times_normalized))[0]
    times_normalized = times_normalized[where_normalizable]
    y = y[where_normalizable]

    # resample evenly in normalized time
    # - with bounds_error=False and fill_value=np.nan, interp1d puts np.nan
    #   outside the min and max of times_normalized
    interp_func = sp.interpolate.interp1d(times_normalized, y, kind='linear', bounds_error=False, fill_value=np.nan)
    y_interp = interp_func(interp_times)

    # replace erroneously interpolated values with np.nan: the points whose
    # normalized time maps back to no real time (missing landmark) become NaN
    where_not_normalizable = np.where(np.isnan(unnormalize_time(fixed_times_s, interp_times)))[0]
    y_interp[where_not_normalizable] = np.nan

    return y_interp

In [None]:
def print_column_analysis(column, df=None):
    """Print the mean, standard deviation and coefficient of variation of a column.

    Parameters
    ----------
    column : str
        Column name. A trailing "(unit)" in the name, e.g. 'Duration (s)',
        is used as the unit label in the printout. A name without
        parentheses prints no unit.
    df : pandas.DataFrame, optional
        Table to analyze. Defaults to the notebook-global ``df_all``, which is
        looked up at call time.

    The coefficient of variation is std / |mean|, with the sample (ddof=1)
    standard deviation that pandas uses by default.
    """
    if df is None:
        df = df_all

    # previously, a name without "(...)" printed the whole column name as its unit
    if '(' in column and ')' in column:
        unit = column.split('(')[-1].split(')')[0]
    else:
        unit = ''
    unit_suffix = f' {unit}' if unit else ''

    mean = df[column].mean()
    std = df[column].std()
    cv = std/abs(mean)
    print(f'Overall mean +/- std: {mean:.2f} +/- {std:.2f}{unit_suffix}')
    print(f'Overall coefficient of variation CV: {cv:.2f}')

In [None]:
def differences_test(x, y, x_label='GROUP 1', y_label='GROUP 2', measure_label='MEASURE', units='UNITS', alpha=0.05):
    """Test whether paired measurements *x* are greater than *y*, and print a report.

    The paired differences are first checked for normality with a
    Shapiro-Wilk test. If normality is not rejected at level *alpha*, a
    paired one-tailed t-test is used and Cohen's d is reported. Otherwise a
    paired one-tailed Wilcoxon signed-rank test is used. The tests are run in
    R through the r_* helper functions.

    Parameters
    ----------
    x, y : pandas.Series (presumably -- ``.values`` is used)
        Paired samples of equal length.
    x_label, y_label, measure_label, units : str
        Labels used only in the printed report.
    alpha : float
        Significance level for every test, including the normality check.

    Returns
    -------
    str
        '*' if the one-tailed test is significant, otherwise '(n.s.)'.

    Raises
    ------
    AssertionError
        If *x* and *y* have different lengths.
    """

    # descriptive statistics
    x_mean = np.mean(x)
    y_mean = np.mean(y)
    x_std = np.std(x, ddof=1)
    y_std = np.std(y, ddof=1)
    x_n = len(x)
    y_n = len(y)
    assert x_n == y_n, 'expected x and y to be paired but they have unequal length'
    print(f'{x_label}: (M = {x_mean:g}, SD = {x_std:g}, N = {x_n:g})')
    print(f'{y_label}: (M = {y_mean:g}, SD = {y_std:g}, N = {y_n:g})')
    print()

    # Shapiro-Wilk test for normality of differences
    # - equivalent R test: shapiro.test(x-y)
    shapiro_result = r_shapiro_test(x.values-y.values)
    shapiro_W, shapiro_p = shapiro_result['W'], shapiro_result['p']
    shapiro_signif = '*' if shapiro_p < alpha else '(n.s.)'
    print(f'H0: Differences have normal distribution, W = {shapiro_W:g},\tp = {shapiro_p:g} {shapiro_signif}')

    # use the same alpha as the significance marker printed above; this was
    # previously hard-coded to 0.05, so a non-default alpha could print
    # "(n.s.)" for normality and still switch to the Wilcoxon test
    if shapiro_p >= alpha:
        print('- Because the differences can be assumed to be normal, a paired t-test will be used')

        # paired one-tailed T-test for an increase in means
        # - equivalent R test : t.test(x, y, paired=TRUE, alternative="greater")
        ttest_result = r_t_test(x.values, y.values, paired=True, alternative='greater')
        ttest_t, ttest_p = ttest_result['t'], ttest_result['p']
        ttest_signif = '*' if ttest_p < alpha and x_mean > y_mean else '(n.s.)'
        print(f'H0: Difference in means is not positive,  t = {ttest_t:g},\tp = {ttest_p:g} {ttest_signif}')

        # Cohen's d for effect size
        # - equivalent R function: library(effsize); cohen.d(x, y)
        # - although the R function offers a paired=TRUE case, I'm not using it here
        #   because I don't understand the paper cited in the documentation. It appears
        #   to be a variation on a different effect size measure, so it's unclear that
        #   it should be referred to as "Cohen's d". When paired=TRUE, the result is
        #   generally a little smaller.
        cohen_d = r_effect_size(x.values, y.values)['estimate']
        print()
        print(f'Effect size: Cohen\'s d = {cohen_d:g}')

        if ttest_p < alpha and x_mean > y_mean:
            print()
            print(f'"A paired-samples one-tailed t-test indicated that {measure_label} for the '
                  f'[{x_n}] animals were significantly [greater/longer] for '
                  f'{x_label} (M = {x_mean:.2f} {units}, SD = {x_std:.2f} {units}) than for '
                  f'{y_label} (M = {y_mean:.2f} {units}, SD = {y_std:.2f} {units}), '
                  f't = {ttest_t:.3f}, df = {x_n-1}, p = {ttest_p:.3f}, Cohen\'s d = {cohen_d:.2f}."')
        else:
            print()
            print(f'"A paired-samples one-tailed t-test indicated no significant increase in {measure_label} in the [{x_n}] animals between '
                  f'{x_label} (M = {x_mean:.2f} {units}, SD = {x_std:.2f} {units}) and '
                  f'{y_label} (M = {y_mean:.2f} {units}, SD = {y_std:.2f} {units}), '
                  f't = {ttest_t:.3f}, df = {x_n-1}, p = {ttest_p:.3f}, Cohen\'s d = {cohen_d:.2f}."')

        return ttest_signif

    else:
        print('- Because the differences cannot be assumed to be normal, a Wilcoxon signed rank test will be used')

        # paired Wilcoxon signed rank test for increase in locations (medians?)
        # - equivalent R test: wilcox.test(x, y, paired=TRUE, alternative="greater")
        wilcoxon_result = r_wilcoxon_test(x.values, y.values, paired=True, alternative='greater')
        wilcoxon_W, wilcoxon_p = wilcoxon_result['W'], wilcoxon_result['p']
        wilcoxon_signif = '*' if wilcoxon_p < alpha else '(n.s.)'
        print(f'H0: Difference in medians is not positive, W = {wilcoxon_W:g},\tp = {wilcoxon_p:g} {wilcoxon_signif}')

        return wilcoxon_signif

In [None]:
def r_hotelling_T2_test(*args, **kwargs):
    '''Run Hotelling's T-squared test with the R package rrcov (``T2.test``, via rpy2).

    All arguments are forwarded unchanged. Returns a dict with the T2 and F
    statistics, the numerator and denominator degrees of freedom, and the p-value.
    '''
    htest = importr('rrcov').T2_test(*args, **kwargs)
    statistic = htest.rx2('statistic')
    parameter = htest.rx2('parameter')
    return {
        'T2': statistic[0],
        'F': statistic[1],
        'df_num': parameter[0],
        'df_den': parameter[1],
        'p': htest.rx2('p.value')[0],
    }

In [None]:
def r_wilcoxon_test(*args, **kwargs):
    '''Run R's ``wilcox.test`` from the stats package (via rpy2).

    With paired=True this is the Wilcoxon signed-rank test. All arguments are
    forwarded unchanged. Returns a dict with the test statistic 'W' and the
    p-value 'p'.
    '''
    htest = importr('stats').wilcox_test(*args, **kwargs)
    return dict(
        W=htest.rx2('statistic')[0],
        p=htest.rx2('p.value')[0],
    )

In [None]:
def r_shapiro_test(*args, **kwargs):
    '''Run R's Shapiro-Wilk normality test (``shapiro.test`` from stats, via rpy2).

    All arguments are forwarded unchanged. Returns a dict with the test
    statistic 'W' and the p-value 'p'.
    '''
    htest = importr('stats').shapiro_test(*args, **kwargs)
    return dict(
        W=htest.rx2('statistic')[0],
        p=htest.rx2('p.value')[0],
    )

In [None]:
def r_t_test(*args, **kwargs):
    '''Run R's Student's t-test (``t.test`` from stats, via rpy2).

    All arguments are forwarded unchanged. Returns a dict with the t
    statistic 't', the degrees of freedom 'df', and the p-value 'p'.
    '''
    htest = importr('stats').t_test(*args, **kwargs)
    return dict(
        t=htest.rx2('statistic')[0],
        df=htest.rx2('parameter')[0],
        p=htest.rx2('p.value')[0],
    )

In [None]:
def r_effect_size(*args, **kwargs):
    '''Compute a standardized effect size with the R package effsize.

    All positional and keyword arguments are forwarded unchanged to
    effsize's ``cohen.d`` (exposed by rpy2's ``importr`` as ``cohen_d``),
    which can report either Cohen's d or Hedges' g depending on the options
    passed.

    Returns a dict with:
      - 'method':   element 0 of the R result's ``method`` component, i.e.
                    the label of which effect size was computed
      - 'estimate': element 0 of the ``estimate`` component
    '''
    r_effsize = importr('effsize')
    result = r_effsize.cohen_d(*args, **kwargs)

    # The original version had a second, unreachable `return result` after
    # this statement; it has been removed as dead code.
    return {
        'method': result.rx2('method')[0],
        'estimate': result.rx2('estimate')[0],
    }

## Import and Process the Data

In [None]:
# Toggle for skipping the expensive data processing in the next cell:
# - False: rebuild every dataframe from the raw recordings, which takes
#   several minutes
# - True: read back the final dataframes that were pickled the last time
#   the full processing was run
load_from_files: bool = False

In [None]:
pickled_vars = ['df_all', 'df_exemplary_bout', 'df_exemplary_swallow']
if load_from_files:
    
    for var in pickled_vars:
        exec(f'{var} = pd.read_pickle("{var}.zip")')

    print('calculation results loaded from files')
    
else:
    
    start = datetime.datetime.now()

    # use Neo RawIO lazy loading to load much faster and using less memory
    # - with lazy=True, filtering parameters specified in metadata are ignored
    #     - note: filters are replaced below and applied manually anyway
    # - with lazy=True, loading via time_slice requires neo>=0.8.0
    lazy = True

    # load the metadata containing file paths
    metadata = neurotic.MetadataSelector(file='../../data/metadata.yml')

    # filter epochs for each bout and perform calculations
    df_list = []
    last_data_set_name = None
    pbar = tqdm(total=len(feeding_bouts), unit='feeding bout')
    for (animal, food, bout_index), (data_set_name, time_window) in feeding_bouts.items():

        channel_names = channel_names_by_animal[animal]
        epoch_types = epoch_types_by_food[food]
        burst_thresholds = burst_thresholds_by_animal[animal]

        ###
        ### LOAD DATASET
        ###

        metadata.select(data_set_name)

        if data_set_name is last_data_set_name:
            # skip reloading the data if it's already in memory
            pass
        else:

            # ensure that the right filters are used
            metadata['filters'] = sig_filters_by_animal[animal]

            blk = neurotic.load_dataset(metadata, lazy=lazy)

            if lazy:
                # manually perform filters
                blk = apply_filters(blk, metadata)

        last_data_set_name = data_set_name



        ###
        ### LOCATE BEHAVIOR EPOCHS AND SUBEPOCHS
        ###

        # construct a query for locating behaviors
        behavior_query = f'(Type in {epoch_types}) & ({time_window[0]} <= Start) & (End <= {time_window[1]})'

        # construct queries for locating epochs associated with each behavior
        # - each query should match at most one epoch
        # - dictionary keys are used as prefixes for the names of new columns
        subepoch_queries = {}

        subepoch_queries['B38 activity']            = (f'(Type == "B38 activity") & ' \
                                                       f'(@behavior_start-3 <= End) & (End <= @behavior_start+4)',
                                                       'last') # use last if there are multiple matches
                                                      # must end within a few seconds of behavior start (3 before or 4 after)

        subepoch_queries['I2 protraction activity'] = f'(Type == "I2 protraction activity") & ' \
                                                      f'(@behavior_start-1 <= Start) & (End <= @behavior_end)'
                                                      # must start no earlier than 1 second before behavior and end within it

        subepoch_queries['B8 activity']             = f'(Type == "B8 activity") & ' \
                                                      f'(@behavior_start <= Start) & (End <= @behavior_end)'
                                                      # must be fully contained within behavior

        subepoch_queries['B3/6/9/10 activity']      = f'(Type == "B3/6/9/10 activity") & ' \
                                                      f'(@behavior_start <= Start) & (End <= @behavior_end)'
                                                      # must be fully contained within behavior

        subepoch_queries['B4/B5 activity']          = (f'(Type == "B4/B5 activity") & ' \
                                                       f'(@behavior_start <= Start) & (Start <= @behavior_end)',
                                                       'first') # use first if there are multiple matches
                                                      # must start within behavior
        
        subepoch_queries['Force shoulder end']      = f'(Type == "Force shoulder end") & ' \
                                                      f'(@behavior_start-3 <= End) & (End <= @behavior_start+3)'
                                                      # must end within 3 seconds of behavior start
        
        subepoch_queries['Force rise start']        = f'(Type == "Force rise start") & ' \
                                                      f'(@behavior_start <= Start) & (Start <= @behavior_end)'
                                                      # must start within behavior

        subepoch_queries['Force plateau start']     = f'(Type == "Force plateau start") & ' \
                                                      f'(@behavior_start <= Start) & (Start <= @behavior_end)'
                                                      # must start within behavior

        subepoch_queries['Force plateau end']       = f'(Type == "Force plateau end") & ' \
                                                      f'(@behavior_end-2 <= Start) & (Start <= @behavior_end+2)'
                                                      # must start within 2 seconds of behavior end

        subepoch_queries['Force drop end']          = f'(Type == "Force drop end") & ' \
                                                      f'(@behavior_end-2 <= Start) & (Start <= @behavior_end+2)'
                                                      # must start within 2 seconds of behavior end
        
        subepoch_queries['Inward movement']         = f'(Type == "Inward movement") & ' \
                                                      f'(@behavior_start <= Start) & (Start <= @behavior_end)'
                                                      # must start within behavior

        # construct a table in which each row is a behavior and subepoch data
        # is added as columns, e.g. df['B38 activity start (s)']
        df = BehaviorsDataFrame(blk.segments[0].epochs, behavior_query, subepoch_queries)

        # renumber behaviors assuming all behaviors are from a single contiguous sequence
        df = df.sort_values('Start (s)').reset_index(drop=True).rename_axis('Behavior_index')



        ###
        ### START CALCULATIONS
        ###

        # some columns must have type 'object', which
        # can be accomplished by initializing with None or np.nan
        df['Video segmentation times (s)'] = None
        df['Force segmentation times (s)'] = None
        df['Force, video segmented interpolation (mN)'] = None
        df['Force, force segmented interpolation (mN)'] = None
        for unit in units:
            df[f'{unit} spike train'] = None
            df[f'{unit} firing rate (Hz)'] = None
            df[f'{unit} firing rate, video segmented interpolation (Hz)'] = None
            df[f'{unit} firing rate, force segmented interpolation (Hz)'] = None
            df[f'{unit} all bursts'] = None

            # while we're at it, initialize some other things that might otherwise never be given values
#             df[f'{unit} first burst start (s)'] = np.nan
#             df[f'{unit} first burst end (s)'] = np.nan
#             df[f'{unit} first burst duration (s)'] = 0
#             df[f'{unit} first burst spike count'] = 0
#             df[f'{unit} first burst mean frequency (Hz)'] = np.nan
#             df[f'{unit} last burst start (s)'] = np.nan
#             df[f'{unit} last burst end (s)'] = np.nan
#             df[f'{unit} last burst duration (s)'] = 0
#             df[f'{unit} last burst spike count'] = 0
#             df[f'{unit} last burst mean frequency (Hz)'] = np.nan
            df[f'{unit} burst start (s)'] = np.nan
            df[f'{unit} burst end (s)'] = np.nan
            df[f'{unit} burst duration (s)'] = 0
            df[f'{unit} burst spike count'] = 0
            df[f'{unit} burst mean frequency (Hz)'] = 0
            df[f'{unit} burst start (video seg normalized)'] = np.nan
            df[f'{unit} burst end (video seg normalized)'] = np.nan
            df[f'{unit} burst start (force seg normalized)'] = np.nan
            df[f'{unit} burst end (force seg normalized)'] = np.nan
        df['Inward movement start (video seg normalized)'] = np.nan
        df['Inward movement end (video seg normalized)'] = np.nan
        df['Inward movement start (force seg normalized)'] = np.nan
        df['Inward movement end (force seg normalized)'] = np.nan



        # get smoothed force for entire time window
        channel = 'Force'
        sig = get_sig(blk, channel)
        if lazy:
            sig = sig.time_slice(None, None)
        sig = elephant.signal_processing.butter(sig, lowpass_freq = 5*pq.Hz)
        sig = sig.time_slice(time_window[0]*pq.s, time_window[1]*pq.s)
        sig = sig.rescale('mN')
        force_smoothed_sig = sig



        df['End to next start (s)'] = np.nan
        df['Start to next start (s)'] = np.nan
        previous_i = None

        # iterate over all swallows
        for j, i in enumerate(df.index):

            behavior_start = df.loc[i, 'Start (s)']*pq.s
            behavior_end = df.loc[i, 'End (s)']*pq.s
            inward_movement_start = df.loc[i, 'Inward movement start (s)']*pq.s
            inward_movement_end = df.loc[i, 'Inward movement end (s)']*pq.s

            # calculate interbehavior intervals assuming all behaviors are from a single contiguous sequence
            if previous_i is not None:
                df.loc[previous_i, 'End to next start (s)']   = df.loc[i, 'Start (s)'] - df.loc[previous_i, 'End (s)']
                df.loc[previous_i, 'Start to next start (s)'] = df.loc[i, 'Start (s)'] - df.loc[previous_i, 'Start (s)']
                df.loc[previous_i, 'Inward movement start to next inward movement start (s)'] = \
                    df.loc[i, 'Inward movement start (s)'] - df.loc[previous_i, 'Inward movement start (s)']
                df.loc[previous_i, 'Inward movement end to next inward movement start (s)'] = \
                    df.loc[i, 'Inward movement start (s)'] - df.loc[previous_i, 'Inward movement end (s)']
            previous_i = i



            ###
            ### VIDEO SEGMENTATION
            ###
            
            # inward movement start and end are required
            video_is_segmented = np.all(np.isfinite(np.array([
                inward_movement_start, inward_movement_end])))
            
            if video_is_segmented:
                inward_movement_duration = df.loc[i, 'Inward movement duration (s)']*pq.s
                
                # get the list of fixed times for normalization
                video_segmentation_times = df.at[i, 'Video segmentation times (s)'] = np.array([
                    inward_movement_start-inward_movement_duration*4,
                    inward_movement_start-inward_movement_duration*3,
                    inward_movement_start-inward_movement_duration*2,
                    inward_movement_start-inward_movement_duration*1,
                    inward_movement_start,
                    inward_movement_end,
                    inward_movement_end+inward_movement_duration*1,
                    inward_movement_end+inward_movement_duration*2,
                    inward_movement_end+inward_movement_duration*3,
                    inward_movement_end+inward_movement_duration*4,
                ])*pq.s # 'at', not 'loc', is important for inserting list into cell

            else: # video is not segmented
                video_segmentation_times = df.at[i, 'Video segmentation times (s)'] = np.array([
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                ])*pq.s # 'at', not 'loc', is important for inserting list into cell
            
            
            
            ###
            ### FORCE SEGMENTATION
            ###

            force_shoulder_end  = df.loc[i, 'Force shoulder end start (s)']*pq.s  # start of "Force shoulder end" epoch
            force_rise_start    = df.loc[i, 'Force rise start start (s)']*pq.s    # start of "Force rise start" epoch
            force_plateau_start = df.loc[i, 'Force plateau start start (s)']*pq.s # start of "Force plateau start" epoch
            force_plateau_end   = df.loc[i, 'Force plateau end start (s)']*pq.s   # start of "Force plateau end" epoch
            force_drop_end      = df.loc[i, 'Force drop end start (s)']*pq.s      # start of "Force drop end" epoch

            # force rise start, plateau start and end, and drop end are required
            force_is_segmented = np.all(np.isfinite(np.array([
                force_rise_start, force_plateau_start, force_plateau_end, force_drop_end])))

            if force_is_segmented:
                # get some times for the previous and next swallow
                epochs_force_shoulder_end  = next((ep for ep in blk.segments[0].epochs if ep.name == 'Force shoulder end'), None)
                epochs_force_rise_start    = next((ep for ep in blk.segments[0].epochs if ep.name == 'Force rise start'), None)
                epochs_force_plateau_start = next((ep for ep in blk.segments[0].epochs if ep.name == 'Force plateau start'), None)
                epochs_force_plateau_end   = next((ep for ep in blk.segments[0].epochs if ep.name == 'Force plateau end'), None)
                epochs_force_drop_end      = next((ep for ep in blk.segments[0].epochs if ep.name == 'Force drop end'), None)
                assert epochs_force_shoulder_end  is not None, 'failed to find "Force shoulder end" epochs'
                assert epochs_force_rise_start    is not None, 'failed to find "Force rise start" epochs'
                assert epochs_force_plateau_start is not None, 'failed to find "Force plateau start" epochs'
                assert epochs_force_plateau_end   is not None, 'failed to find "Force plateau end" epochs'
                assert epochs_force_drop_end      is not None, 'failed to find "Force drop end" epochs'

                try:
                    prev_force_plateau_start = df.loc[i, 'Previous force plateau start (s)'] = epochs_force_plateau_start.time_slice(None, force_rise_start)[-1]
                    assert force_rise_start-prev_force_plateau_start < 16*pq.s, f'for swallow {i}, previous force plateau start is too far away'
                except IndexError:
                    prev_force_plateau_start = df.loc[i, 'Previous force plateau start (s)'] = np.nan

                try:
                    prev_force_plateau_end = df.loc[i, 'Previous force plateau end (s)'] = epochs_force_plateau_end.time_slice(None, force_rise_start)[-1]
                    assert force_rise_start-prev_force_plateau_end < 12*pq.s, f'for swallow {i}, previous force plateau end is too far away'
                except IndexError:
                    prev_force_plateau_end = df.loc[i, 'Previous force plateau end (s)'] = np.nan

                try:
                    prev_force_drop_end = df.loc[i, 'Previous force drop end (s)'] = epochs_force_drop_end.time_slice(None, force_rise_start)[-1]
                    assert force_rise_start-prev_force_drop_end < 12*pq.s, f'for swallow {i}, previous force drop end is too far away'
                except IndexError:
                    prev_force_drop_end = df.loc[i, 'Previous force drop end (s)'] = np.nan

                try:
                    next_force_rise_start = df.loc[i, 'Next force rise start (s)'] = epochs_force_rise_start.time_slice(force_drop_end, None)[0]
                    assert next_force_rise_start-force_drop_end < 12*pq.s, f'for swallow {i}, next force rise start is too far away'
                except IndexError:
                    next_force_rise_start = df.loc[i, 'Next force rise start (s)'] = np.nan

                try:
                    next_force_shoulder_end = df.loc[i, 'Next force shoulder end (s)'] = epochs_force_shoulder_end.time_slice(force_drop_end, None)[0]
                    if next_force_shoulder_end > next_force_rise_start:
                        # next swallow did not have a shoulder and we instead grabbed a later shoulder
                        next_force_shoulder_end = df.loc[i, 'Next force shoulder end (s)'] = np.nan
                except IndexError:
                    next_force_shoulder_end = df.loc[i, 'Next force shoulder end (s)'] = np.nan

                # get the list of fixed times for normalization
                force_segmentation_times = df.at[i, 'Force segmentation times (s)'] = np.array([
                    prev_force_plateau_start,
                    prev_force_plateau_end,
                    prev_force_drop_end,
                    force_shoulder_end,
                    force_rise_start,
                    force_plateau_start,
                    force_plateau_end,
                    force_drop_end,
                    next_force_shoulder_end,
                    next_force_rise_start,
                ])*pq.s # 'at', not 'loc', is important for inserting list into cell

            else: # force is not segmented
                force_segmentation_times = df.at[i, 'Force segmentation times (s)'] = np.array([
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                    np.nan,
                ])*pq.s # 'at', not 'loc', is important for inserting list into cell



            ###
            ### BEHAVIORAL MARKERS
            ###
            
            if video_is_segmented:
                df.loc[i, 'Inward movement start (video seg normalized)'] = \
                    normalize_time(video_segmentation_times.magnitude, float(inward_movement_start))
                df.loc[i, 'Inward movement end (video seg normalized)'] = \
                    normalize_time(video_segmentation_times.magnitude, float(inward_movement_end))
            
            if force_is_segmented:
                df.loc[i, 'Inward movement start (force seg normalized)'] = \
                    normalize_time(force_segmentation_times.magnitude, float(inward_movement_start))
                df.loc[i, 'Inward movement end (force seg normalized)'] = \
                    normalize_time(force_segmentation_times.magnitude, float(inward_movement_end))



            ###
            ### FORCE QUANTIFICATION
            ###

            if force_is_segmented:

                # get smoothed force for whole behavior for remaining force calculations
                sig = force_smoothed_sig
                if np.isfinite(force_shoulder_end):
                    sig = sig.time_slice(force_shoulder_end - 0.01*pq.s, force_drop_end + 0.01*pq.s)
                else:
                    sig = sig.time_slice(force_rise_start - 1*pq.s, force_drop_end + 0.01*pq.s)
                sig = sig.rescale('mN')

                # find force peak, baseline, and the increase
                force_min_time = df.loc[i, 'Force minimum time (s)'] = elephant.spike_train_generation.peak_detection(sig, 999*pq.mN, sign='below')[0]
                force_min = df.loc[i, 'Force minimum (mN)'] = sig[sig.time_index(force_min_time)][0]
                force_peak_time = df.loc[i, 'Force peak time (s)'] = elephant.spike_train_generation.peak_detection(sig, 0*pq.mN)[0]
                force_peak = df.loc[i, 'Force peak (mN)'] = sig[sig.time_index(force_peak_time)][0]
                force_baseline = df.loc[i, 'Force baseline (mN)'] = sig[sig.time_index(force_rise_start)][0]
                force_increase = df.loc[i, 'Force increase (mN)'] = force_peak-force_baseline

                # find force plateau, drop, and shoulder values
                force_plateau_start_value = df.loc[i, 'Force plateau start value (mN)'] = sig[sig.time_index(force_plateau_start)][0]
                force_plateau_end_value = df.loc[i, 'Force plateau end value (mN)'] = sig[sig.time_index(force_plateau_end)][0]
                force_drop_end_value = df.loc[i, 'Force drop end value (mN)'] = sig[sig.time_index(force_drop_end)][0]
                if np.isfinite(force_shoulder_end):
                    force_shoulder_end_value = df.loc[i, 'Force shoulder end value (mN)'] = sig[sig.time_index(force_shoulder_end)][0]
                else:
                    force_shoulder_end_value = np.nan

                # find force rise and plateau durations
                force_rise_duration = df.loc[i, 'Force rise duration (s)'] = force_plateau_start - force_rise_start
                force_plateau_duration = df.loc[i, 'Force plateau duration (s)'] = force_plateau_end - force_plateau_start
                force_rise_plateau_duration = df.loc[i, 'Force rise and plateau duration (s)'] = force_plateau_end - force_rise_start

                # find average slope during rising phase
                force_rise_increase = df.loc[i, 'Force rise increase (mN)'] = force_plateau_start_value - force_baseline
                force_slope = df.loc[i, 'Force slope (mN/s)'] = (force_rise_increase/force_rise_duration).rescale('mN/s')



            ###
            ### FORCE NORMALIZATION
            ###

            if video_is_segmented:
                t_start = video_segmentation_times[0]-0.001*pq.s
                t_stop = video_segmentation_times[-1]+0.001*pq.s
                
                channel = 'Force'
                sig = get_sig(blk, channel)
                sig = sig.time_slice(t_start, t_stop)
                sig = sig.rescale(channel_units[channel_names.index(channel)])

                force_video_seg_interp = df.at[i, 'Force, video segmented interpolation (mN)'] = \
                    resample_sig_in_normalized_time(video_segmentation_times, sig, video_seg_interp_times) # 'at', not 'loc', is important for inserting list into cell
            
            if force_is_segmented:
                t_start = force_segmentation_times[0]-0.001*pq.s
                t_stop = force_segmentation_times[-1]+0.001*pq.s
                
                channel = 'Force'
                sig = get_sig(blk, channel)
                sig = sig.time_slice(t_start, t_stop)
                sig = sig.rescale(channel_units[channel_names.index(channel)])

                force_force_seg_interp = df.at[i, 'Force, force segmented interpolation (mN)'] = \
                    resample_sig_in_normalized_time(force_segmentation_times, sig, force_seg_interp_times) # 'at', not 'loc', is important for inserting list into cell



            ###
            ### FIND SPIKE TRAINS
            ###

            if lazy:
                if metadata['amplitude_discriminators'] is not None:
                    for discriminator in metadata['amplitude_discriminators']:
                        sig = get_sig(blk, discriminator['channel'])
                        if sig is not None:
                            sig = sig.time_slice(behavior_start - 5*pq.s, behavior_end + 5*pq.s)
                            st = _detect_spikes(sig, discriminator, blk.segments[0].epochs)
                            st_epoch_start = df.loc[i, discriminator['epoch']+' start (s)']*pq.s
                            st_epoch_end = df.loc[i, discriminator['epoch']+' end (s)']*pq.s
                            st = st.time_slice(st_epoch_start, st_epoch_end)
                            df.at[i, f'{st.name} spike train'] = st # 'at', not 'loc', is important for inserting list into cell
            else:
                for spiketrain in blk.segments[0].spiketrains:
                    discriminator = next((d for d in metadata['amplitude_discriminators'] if d['name'] == spiketrain.name), None)
                    if discriminator is None:
                        raise Exception(f'For data set "{data_set_name}", discriminator "{spiketrain.name}" could not be found')
                    st_epoch_start = df.loc[i, discriminator['epoch']+' start (s)']*pq.s
                    st_epoch_end = df.loc[i, discriminator['epoch']+' end (s)']*pq.s
                    if np.isfinite(st_epoch_start) and np.isfinite(st_epoch_end):
                        st = spiketrain.time_slice(st_epoch_start, st_epoch_end)
                    else:
                        # this unit's discriminator epoch was not located for this swallow
                        st = None
                    df.at[i, f'{spiketrain.name} spike train'] = st # 'at', not 'loc', is important for inserting list into cell



            ###
            ### QUANTIFY SPIKE TRAINS AND BURSTS
            ###

            for k, unit in enumerate(units):
                st = df.loc[i, f'{unit} spike train']
                if st is not None:

                    # get the neural channel
                    channel = st.annotations['channels'][0]
                    sig = get_sig(blk, channel)

                    # create a continuous smoothed firing rate representation
                    # by convolving the spike train with a kernel
                    # - choice of t_start and t_stop here ensures firing rates are
                    #   recorded as zero far from the burst and can be resampled later
                    if video_is_segmented and force_is_segmented:
                        t_start = min(video_segmentation_times[0], force_segmentation_times[0])-0.001*pq.s
                        t_stop = max(video_segmentation_times[-1], force_segmentation_times[-1])+0.001*pq.s
                    elif video_is_segmented:
                        t_start = video_segmentation_times[0]-0.001*pq.s
                        t_stop = video_segmentation_times[-1]+0.001*pq.s
                    elif force_is_segmented:
                        t_start = force_segmentation_times[0]-0.001*pq.s
                        t_stop = force_segmentation_times[-1]+0.001*pq.s
                    else:
                        # no segmentation available
                        t_start = behavior_start-5*pq.s
                        t_stop = behavior_end+5*pq.s
                    smoothing_kernel = elephant.kernels.GaussianKernel(0.2*pq.s) # 200 ms standard deviation
                    firing_rate = df.at[i, f'{unit} firing rate (Hz)'] = elephant.statistics.instantaneous_rate(
                        spiketrain=st,
                        t_start=t_start,
                        t_stop=t_stop,
                        sampling_period=sig.sampling_period,
                        kernel=smoothing_kernel,
                    ) # 'at', not 'loc', is important for inserting list into cell

                    # normalization
                    if video_is_segmented:
                        firing_rate_video_seg_interp = df.at[i, f'{unit} firing rate, video segmented interpolation (Hz)'] = \
                            resample_sig_in_normalized_time(video_segmentation_times, firing_rate, video_seg_interp_times) # 'at', not 'loc', is important for inserting list into cell
                    if force_is_segmented:
                        firing_rate_force_seg_interp = df.at[i, f'{unit} firing rate, force segmented interpolation (Hz)'] = \
                            resample_sig_in_normalized_time(force_segmentation_times, firing_rate, force_seg_interp_times) # 'at', not 'loc', is important for inserting list into cell

                    if st.size > 0:

#                         # get the signal for the behavior with 10 seconds cushion before and after (for better baseline estimation)
#                         sig = sig.time_slice(behavior_start - 10*pq.s, behavior_end + 10*pq.s)
#                         sig = sig.rescale(channel_units[channel_names.index(channel)])

                        # find every sequence of spikes that qualifies as a burst
                        bursts = df.at[i, f'{unit} all bursts'] = _find_bursts(st, burst_thresholds[unit][0], burst_thresholds[unit][1]) # 'at', not 'loc', is important for inserting list into cell

                        first_burst_start = np.nan
#                         first_burst_end = np.nan
#                         first_burst_spike_count = 0
#                         first_burst_mean_freq = 0*pq.Hz
#                         last_burst_start = np.nan
                        last_burst_end = np.nan
#                         last_burst_spike_count = 0
#                         last_burst_mean_freq = 0*pq.Hz
                        if len(bursts) > 0:

                            for burst in zip(bursts.times, bursts.durations, bursts.array_annotations['spikes']):
                                if is_good_burst(burst):
                                    time, duration, n_spikes = burst
                                    first_burst_start = time
#                                     first_burst_start, first_burst_end = burst['Start (s)'], burst['End (s)']
#                                     first_burst_duration = first_burst_end-first_burst_start
#                                     df.loc[i, f'{unit} first burst start (s)'] = first_burst_start.rescale('s')
#                                     df.loc[i, f'{unit} first burst end (s)'] = first_burst_end.rescale('s')
#                                     first_burst_duration = df.loc[i, f'{unit} first burst duration (s)'] = first_burst_duration.rescale('s')
#                                     first_burst_spike_count = df.loc[i, f'{unit} first burst spike count'] = st.time_slice(first_burst_start, first_burst_end).size
#                                     first_burst_mean_freq = df.loc[i, f'{unit} first burst mean frequency (Hz)'] = ((first_burst_spike_count-1)/first_burst_duration).rescale('Hz')

#                                     # find burst RAUC and mean voltage
#                                     first_burst_rauc = df.loc[i, f'{unit} first burst RAUC (μV·s)'] = elephant.signal_processing.rauc(sig, baseline='mean', t_start=first_burst_start, t_stop=first_burst_end).rescale('uV*s')
#                                     first_burst_mean_rect_voltage = df.loc[i, f'{unit} first burst mean rectified voltage (μV)'] = first_burst_rauc/first_burst_duration

                                    break # quit after finding first good burst

                            for burst in zip(reversed(bursts.times), reversed(bursts.durations), reversed(bursts.array_annotations['spikes'])):
                                if is_good_burst(burst):
                                    time, duration, n_spikes = burst
                                    last_burst_end = time + duration
#                                     last_burst_start, last_burst_end = burst['Start (s)'], burst['End (s)']
#                                     last_burst_duration = last_burst_end-last_burst_start
#                                     df.loc[i, f'{unit} last burst start (s)'] = last_burst_start.rescale('s')
#                                     df.loc[i, f'{unit} last burst end (s)'] = last_burst_end.rescale('s')
#                                     last_burst_duration = df.loc[i, f'{unit} last burst duration (s)'] = last_burst_duration.rescale('s')
#                                     last_burst_spike_count = df.loc[i, f'{unit} last burst spike count'] = st.time_slice(last_burst_start, last_burst_end).size
#                                     last_burst_mean_freq = df.loc[i, f'{unit} last burst mean frequency (Hz)'] = ((last_burst_spike_count-1)/last_burst_duration).rescale('Hz')

#                                     # find burst RAUC and mean voltage
#                                     last_burst_rauc = df.loc[i, f'{unit} last burst RAUC (μV·s)'] = elephant.signal_processing.rauc(sig, baseline='mean', t_start=last_burst_start, t_stop=last_burst_end).rescale('uV*s')
#                                     last_burst_mean_rect_voltage = df.loc[i, f'{unit} last burst mean rectified voltage (μV)'] = last_burst_rauc/last_burst_duration
                                    
                                    break # quit after finding first (actually, last) good burst
                        
                        # merge the first and last good bursts and anything in between
                        if np.isfinite(first_burst_start) and np.isfinite(last_burst_end):
                            st_burst = st.time_slice(first_burst_start, last_burst_end)
                            burst_start = df.loc[i, f'{unit} burst start (s)'] = first_burst_start
                            burst_end = df.loc[i, f'{unit} burst end (s)'] = last_burst_end
                            burst_duration = df.loc[i, f'{unit} burst duration (s)'] = (burst_end-burst_start)
                            burst_spike_count = df.loc[i, f'{unit} burst spike count'] = st_burst.size
                            if burst_spike_count > 1:
                                burst_mean_freq = df.loc[i, f'{unit} burst mean frequency (Hz)'] = ((burst_spike_count-1)/burst_duration).rescale('Hz')
                            if video_is_segmented:
                                df.loc[i, f'{unit} burst start (video seg normalized)'] = normalize_time(video_segmentation_times.magnitude, float(first_burst_start))
                                df.loc[i, f'{unit} burst end (video seg normalized)']   = normalize_time(video_segmentation_times.magnitude, float(last_burst_end))
                            if force_is_segmented:
                                df.loc[i, f'{unit} burst start (force seg normalized)'] = normalize_time(force_segmentation_times.magnitude, float(first_burst_start))
                                df.loc[i, f'{unit} burst end (force seg normalized)']   = normalize_time(force_segmentation_times.magnitude, float(last_burst_end))

            # B3/B6/B9
            df['B3/B6/B9 burst start (s)'] = np.nan
            df['B3/B6/B9 burst end (s)'] = np.nan
            df['B3/B6/B9 burst duration (s)'] = 0
            df['B3/B6/B9 burst spike count'] = 0
            df['B3/B6/B9 burst mean frequency (Hz)'] = 0
            b3b6b9_burst_start = df['B3/B6/B9 burst start (s)'] = df[['B6/B9 burst start (s)', 'B3 burst start (s)']].min(axis=1)
            b3b6b9_burst_end = df['B3/B6/B9 burst end (s)']   = df[['B6/B9 burst end (s)',   'B3 burst end (s)']]  .max(axis=1)
            b3b6b9_burst_duration = df['B3/B6/B9 burst duration (s)'] = b3b6b9_burst_end - b3b6b9_burst_start
            b3b6b9_burst_spike_count = df['B3/B6/B9 burst spike count'] = df['B6/B9 burst spike count'] + df['B3 burst spike count']
            b3b6b9_burst_mean_freq = df['B3/B6/B9 burst mean frequency (Hz)'] = (b3b6b9_burst_spike_count-1)/b3b6b9_burst_duration


        ###
        ### FINISH
        ###

        # index the table on 4 variables so that this dataframe can later be merged with others
        df['Animal'] = animal
        df['Food'] = food
        df['Bout_index'] = bout_index
        df = df.reset_index().set_index(['Animal', 'Food', 'Bout_index', 'Behavior_index'])

        df_list += [df]
        
        pbar.update()
        
    pbar.close()

    df_all = pd.concat(df_list, sort=False).sort_index()

    if exemplary_swallow in feeding_bouts:
        # move exemplar to separate dataframe
        df_exemplary_swallow = df_all.loc[exemplary_swallow].copy()
        df_all = df_all.drop(exemplary_swallow)

    if exemplary_bout in feeding_bouts:
        # move exemplar to separate dataframe
        df_exemplary_bout = df_all.loc[exemplary_bout].copy()
        df_all = df_all.drop(exemplary_bout)

    # save dataframes to files so that calculations can be skipped in the future
    for var in pickled_vars:
        exec(f'{var}.to_pickle("{var}.zip")')

    end = datetime.datetime.now()
    print('elapsed time:', end-start)

## 🤪 Sanity Checks

In [None]:
# Set to False to regenerate the per-bout diagnostic figures in the cells
# below; they reload each data set and write one PNG per bout, so they are slow.
skip_sanity_checks: bool = True

In [None]:
if not skip_sanity_checks:
    # SANITY CHECK 1: for every feeding bout, plot all channels over the bout's
    # time window, overlay the force segmentation, detected spikes, burst
    # windows and burst edges stored in the dataframe, and save one PNG per bout
    # to <export_dir>/sanity-checks/.
    start = datetime.datetime.now()

    # reconstruct the original df_all, before exemplary_swallow and
    # exemplary_bout were removed
    df_list2 = [df_all]
    if exemplary_swallow in feeding_bouts:
        df_exemplary_swallow2 = df_exemplary_swallow.copy()
        df_exemplary_swallow2['Animal'] = exemplary_swallow[0]
        df_exemplary_swallow2['Food'] = exemplary_swallow[1]
        df_exemplary_swallow2['Bout_index'] = exemplary_swallow[2]
        df_exemplary_swallow2 = df_exemplary_swallow2.reset_index().set_index(['Animal', 'Food', 'Bout_index', 'Behavior_index'])
        df_list2 += [df_exemplary_swallow2]
    if exemplary_bout in feeding_bouts:
        df_exemplary_bout2 = df_exemplary_bout.copy()
        df_exemplary_bout2['Animal'] = exemplary_bout[0]
        df_exemplary_bout2['Food'] = exemplary_bout[1]
        df_exemplary_bout2['Bout_index'] = exemplary_bout[2]
        df_exemplary_bout2 = df_exemplary_bout2.reset_index().set_index(['Animal', 'Food', 'Bout_index', 'Behavior_index'])
        df_list2 += [df_exemplary_bout2]
    df_all2 = pd.concat(df_list2, sort=False).sort_index()

    # use Neo RawIO lazy loading to load much faster and using less memory
    # - with lazy=True, filtering parameters specified in metadata are ignored
    #     - note: filters are replaced below and applied manually anyway
    # - with lazy=True, loading via time_slice requires neo>=0.8.0
    lazy = True

    # load the metadata containing file paths
    metadata = neurotic.MetadataSelector(file='../../data/metadata.yml')

    # create the output directory once, before the loop (os.makedirs also
    # creates any missing parent directories)
    export_dir2 = os.path.join(export_dir, 'sanity-checks')
    os.makedirs(export_dir2, exist_ok=True)

    last_data_set_name = None
    pbar = tqdm(total=len(feeding_bouts), unit='figure')
    for (animal, food, bout_index), (data_set_name, time_window) in feeding_bouts.items():

        channel_names = channel_names_by_animal[animal]
        epoch_types = epoch_types_by_food[food]
        burst_thresholds = burst_thresholds_by_animal[animal]

        # all behaviors (rows indexed by Behavior_index) belonging to this bout
        df = df_all2.loc[animal, food, bout_index]



        ###
        ### LOAD DATASET
        ###

        metadata.select(data_set_name)

        # compare names by value: `is` tests object identity, which is not a
        # reliable test of string equality
        if data_set_name == last_data_set_name:
            # skip reloading the data if it's already in memory
            pass
        else:

            # ensure that the right filters are used
            metadata['filters'] = sig_filters_by_animal[animal]

            blk = neurotic.load_dataset(metadata, lazy=lazy)

            if lazy:
                # manually perform filters
                blk = apply_filters(blk, metadata)

        last_data_set_name = data_set_name



        ###
        ### START FIGURE
        ###

        # figsize = (9.5, 10) # dimensions for notebook
        # figsize = (11, 8.5) # dimensions for printing
        figsize = (16, 9) # dimensions for filling wide screens
        fig, axes = plt.subplots(len(channel_names), 1, sharex=True, figsize=figsize)



        ###
        ### PLOT SIGNALS
        ###

        # plot all channels for entire time window
        for i, channel in enumerate(channel_names):
            plt.sca(axes[i])
            sig = get_sig(blk, channel)
            sig = sig.time_slice(time_window[0]*pq.s, time_window[1]*pq.s)
            sig = sig.rescale(channel_units[i])
            plt.plot(sig.times, sig.magnitude, c='0.8', lw=1, zorder=-1)

            if i == 0:
                plt.title(f'({animal}, {food}, {bout_index}): {data_set_name}')

            plt.ylabel(sig.name + ' (' + sig.units.dimensionality.string + ')')
            axes[i].yaxis.set_label_coords(-0.06, 0.5)

            if i < len(channel_names)-1:
                # remove right, top, and bottom plot borders, and remove x-axis
                sns.despine(ax=plt.gca(), bottom=True)
                plt.gca().xaxis.set_visible(False)
            else:
                # remove right and top plot borders, and set x-label
                sns.despine(ax=plt.gca())
                plt.xlabel('Time (s)')

        # plot smoothed force for entire time window
        channel = 'Force'
        sig = get_sig(blk, channel)
        if lazy:
            # slice without bounds, presumably to materialize the whole lazy
            # signal so it can be low-pass filtered below
            sig = sig.time_slice(None, None)
        sig = elephant.signal_processing.butter(sig, lowpass_freq = 5*pq.Hz)
        sig = sig.time_slice(time_window[0]*pq.s, time_window[1]*pq.s)
        sig = sig.rescale('mN')
        plt.plot(sig.times, sig.magnitude, c='k', lw=1, zorder=0)
        force_smoothed_sig = sig



        # iterate over all swallows
        for j, i in enumerate(df.index):

            behavior_start = df.loc[i, 'Start (s)']*pq.s
            behavior_end   = df.loc[i, 'End (s)']*pq.s

            ###
            ### MOVEMENTS
            ###

            # plot inward food movement as a thin bar at the top of the force axes
            inward_movement_start = df.loc[i, 'Inward movement start (s)']*pq.s
            inward_movement_end   = df.loc[i, 'Inward movement end (s)']*pq.s
            if np.isfinite(inward_movement_start):
                channel = 'Force'
                ax = axes[channel_names.index(channel)]
                ax.axvspan(
                    inward_movement_start, inward_movement_end,
                    0.99, 1,
                    facecolor='k', edgecolor=None, lw=0)



            ###
            ### FORCE SEGMENTATION
            ###

            force_shoulder_end  = df.loc[i, 'Force shoulder end start (s)']*pq.s  # start of "Force shoulder end" epoch
            force_rise_start    = df.loc[i, 'Force rise start start (s)']*pq.s    # start of "Force rise start" epoch
            force_plateau_start = df.loc[i, 'Force plateau start start (s)']*pq.s # start of "Force plateau start" epoch
            force_plateau_end   = df.loc[i, 'Force plateau end start (s)']*pq.s   # start of "Force plateau end" epoch
            force_drop_end      = df.loc[i, 'Force drop end start (s)']*pq.s      # start of "Force drop end" epoch

            # force rise start, plateau start and end, and drop end are required
            force_is_segmented = np.all(np.isfinite(np.array([
                force_rise_start, force_plateau_start, force_plateau_end, force_drop_end])))

            force_segmentation_times = df.at[i, 'Force segmentation times (s)']

            if force_is_segmented:

                force_min_time = df.loc[i, 'Force minimum time (s)']
                force_min = df.loc[i, 'Force minimum (mN)']
                force_peak_time = df.loc[i, 'Force peak time (s)']
                force_peak = df.loc[i, 'Force peak (mN)']
                force_baseline = df.loc[i, 'Force baseline (mN)']

                force_plateau_start_value = df.loc[i, 'Force plateau start value (mN)']
                force_plateau_end_value = df.loc[i, 'Force plateau end value (mN)']
                force_drop_end_value = df.loc[i, 'Force drop end value (mN)']
                force_shoulder_end_value = df.loc[i, 'Force shoulder end value (mN)']

                # get smoothed force for whole behavior for remaining force calculations
                sig = force_smoothed_sig
                if np.isfinite(force_shoulder_end):
                    sig = sig.time_slice(force_shoulder_end - 0.01*pq.s, force_drop_end + 0.01*pq.s)
                else:
                    sig = sig.time_slice(force_rise_start - 1*pq.s, force_drop_end + 0.01*pq.s)
                sig = sig.rescale('mN')

                # plot force rise in color
                plt.sca(axes[channel_names.index('Force')])
                sig2 = sig.time_slice(force_rise_start, force_plateau_start)
                plt.plot(sig2.times, sig2.magnitude, c=force_colors['rise'], lw=2, zorder=1)

                # plot force plateau in color
                sig2 = sig.time_slice(force_plateau_start, force_plateau_end)
                plt.plot(sig2.times, sig2.magnitude, c=force_colors['plateau'], lw=2, zorder=1)

                # plot force shoulder in color
                # NOTE(review): sig starts at force_shoulder_end - 0.01 s and
                # ends after force_drop_end, so this slice appears to run
                # backwards (start after stop) and may plot nothing -- confirm
                # the intended shoulder interval
                if np.isfinite(force_shoulder_end):
                    sig2 = sig.time_slice(force_drop_end, force_shoulder_end)
                    plt.plot(sig2.times, sig2.magnitude, c=force_colors['shoulder'], lw=2, zorder=1)

                # plot force peak, baseline, and plateau values
                plt.plot([force_peak_time],     [force_peak],                marker=CARETDOWN,  markersize=5, color='k')
                # plt.plot([force_min_time],      [force_min],                 marker=CARETUP,    markersize=5, color='k')
                plt.plot([force_rise_start],    [force_baseline],            marker=CARETUP,    markersize=5, color='k')
                plt.plot([force_plateau_start], [force_plateau_start_value], marker=CARETRIGHT, markersize=5, color='k')
                plt.plot([force_plateau_end],   [force_plateau_end_value],   marker=CARETLEFT,  markersize=5, color='k')

                # plot segmentation boundaries across all subplots: a dotted line
                # from the force trace in the bottom axes up to the top of the
                # first axes
                for (t, y, c) in [
                        (force_shoulder_end,  force_shoulder_end_value,  force_colors['shoulder']),
                        (force_rise_start,    force_baseline,            force_colors['rise']),
                        (force_plateau_start, force_plateau_start_value, force_colors['plateau']),
                        (force_plateau_end,   force_plateau_end_value,   force_colors['plateau']),
                        (force_drop_end,      force_drop_end_value,      force_colors['drop'])]:
                    if np.isfinite(y):
                        axes[-1].add_artist(patches.ConnectionPatch(
                            xyA=(t, y), xyB=(t, 1),
                            coordsA='data', coordsB=axes[0].get_xaxis_transform(),
                            axesA=axes[-1], axesB=axes[0],
                            color=c, lw=1, ls=':', zorder=-2))



            ###
            ### SPIKES AND BURSTS
            ###

            for k, unit in enumerate(units):
                st = df.loc[i, f'{unit} spike train']
                if st is not None and st.size > 0:

                    # get the signal for the behavior with 10 seconds cushion for spikes outside the behavior duration
                    channel = st.annotations['channels'][0]
                    sig = get_sig(blk, channel)
                    sig = sig.time_slice(behavior_start - 10*pq.s, behavior_end + 10*pq.s)
                    sig = sig.rescale(channel_units[channel_names.index(channel)])

                    # plot spikes
                    plt.sca(axes[channel_names.index(channel)])
                    marker = ['.', 'x'][j%2] # alternate markers between behaviors
                    spike_amplitudes = np.array([sig[sig.time_index(t)] for t in st]) * pq.Quantity(sig.units)
                    plt.scatter(st.times.rescale('s'), spike_amplitudes, marker=marker, c=unit_colors[unit])

                    # plot burst windows: one rectangle per burst spanning the
                    # amplitude discriminator's window; dashed if not a "good" burst
                    discriminator = next((d for d in metadata['amplitude_discriminators'] if d['name'] == st.name), None)
                    if discriminator is None:
                        raise Exception(f'For data set "{data_set_name}", discriminator "{st.name}" could not be found')
                    bottom = pq.Quantity(discriminator['amplitude'][0], discriminator['units']).rescale(sig.units)
                    top = pq.Quantity(discriminator['amplitude'][1], discriminator['units']).rescale(sig.units)
                    height = top-bottom
                    bursts = df.at[i, f'{unit} all bursts']
                    for burst in zip(bursts.times, bursts.durations, bursts.array_annotations['spikes']):
                        time, duration, n_spikes = burst
                        left = time
                        right = time + duration
                        width = right-left
                        linestyle = '-' if is_good_burst(burst) else '--'
                        rect = patches.Rectangle((left, bottom), width, height, ls=linestyle, edgecolor=unit_colors[unit], fill=False)
                        plt.gca().add_patch(rect)

                    # plot markers for edges of bursts
                    burst_start = df.loc[i, f'{unit} burst start (s)']
                    burst_end = df.loc[i, f'{unit} burst end (s)']
                    if top > 0:
                        plt.plot([burst_start], [top],    marker=CARETDOWN, markersize=5, color='k')
                        plt.plot([burst_end],   [top],    marker=CARETDOWN, markersize=5, color='k')
                    else:
                        plt.plot([burst_start], [bottom], marker=CARETUP,   markersize=5, color='k')
                        plt.plot([burst_end],   [bottom], marker=CARETUP,   markersize=5, color='k')



        ###
        ### FINISH FIGURE
        ###

        # optimize plot margins
        plt.subplots_adjust(
            left   = 0.1,
            right  = 0.99,
            top    = 0.96,
            bottom = 0.06,
            hspace = 0.15,
        )

        # export figure
        fig.savefig(os.path.join(export_dir2, f'{animal} {food} {bout_index}.png'), dpi=300)

        # close the figure so that one open figure per bout does not
        # accumulate in memory
        plt.close(fig)

        pbar.update()

    pbar.close()

    if exemplary_swallow in feeding_bouts:
        # rename output file for exemplar
        animal, food, bout_index = exemplary_swallow
        old_path = os.path.join(export_dir2, f'{animal} {food} {bout_index}.png')
        new_path = os.path.join(export_dir2, 'exemplary_swallow.png')
        if os.path.exists(old_path):
            if os.path.exists(new_path):
                os.remove(new_path)
            os.rename(old_path, new_path)

    if exemplary_bout in feeding_bouts:
        # rename output file for exemplar
        animal, food, bout_index = exemplary_bout
        old_path = os.path.join(export_dir2, f'{animal} {food} {bout_index}.png')
        new_path = os.path.join(export_dir2, 'exemplary_bout.png')
        if os.path.exists(old_path):
            if os.path.exists(new_path):
                os.remove(new_path)
            os.rename(old_path, new_path)


    del df_all2, df_list2

    end = datetime.datetime.now()
    print('elapsed time:', end-start)

In [None]:
if not skip_sanity_checks:
    # SANITY CHECK 2: for each feeding bout (except the lone exemplary swallow),
    # plot each unit's spike raster and firing rate in real time and in
    # force-/video-segmentation-normalized time, overlay the median and
    # quartile curves across the bout's swallows, and save one PNG per bout to
    # <export_dir>/sanity-checks-firing-rates/.
    start = datetime.datetime.now()

    # load the metadata containing file paths once, rather than re-reading the
    # file for every bout
    metadata = neurotic.MetadataSelector('../../data/metadata.yml')

    # create the output directory once, before the loop (os.makedirs also
    # creates any missing parent directories)
    export_dir3 = os.path.join(export_dir, 'sanity-checks-firing-rates')
    os.makedirs(export_dir3, exist_ok=True)

    if exemplary_swallow in feeding_bouts:
        pbar = tqdm(total=len(feeding_bouts)-1, unit='figure')
    else:
        pbar = tqdm(total=len(feeding_bouts), unit='figure')
    last_data_set_name = None
    for (animal, food, bout_index), (data_set_name, time_window) in feeding_bouts.items():

        channel_names = channel_names_by_animal[animal]

        if (animal, food, bout_index) == exemplary_swallow:
            # skip the lone swallow
            continue
        elif (animal, food, bout_index) == exemplary_bout:
            df = df_exemplary_bout
        else:
            df = df_all.loc[animal, food, bout_index]

        # load the data, skipping the reload when consecutive bouts share a data set
        metadata.select(data_set_name)
        if data_set_name != last_data_set_name:
            blk = neurotic.load_dataset(metadata, lazy=True)
        last_data_set_name = data_set_name

        # figsize = (9.5, 10) # dimensions for notebook
        # figsize = (11, 8.5) # dimensions for printing
        # figsize = (16, 9) # dimensions for filling wide screens
        figsize = (20, 9)
        # one row per unit plus a bottom row for force; columns are real time,
        # force-segmented normalized time, and video-segmented normalized time
        fig, axes = plt.subplots(len(units)+1, 3, sharex='col', figsize=figsize)

        for k, unit in enumerate(units):
            # get the subplot axes handles
            ax_real_time, ax_force_seg, ax_video_seg = axes[k]

            # set y-axis label
            ax_real_time.set_ylabel(f'{unit} (Hz)')
            ax_real_time.yaxis.set_label_coords(-0.06, 0.5)

            # remove right, top, and bottom plot borders, and remove x-axis
            sns.despine(ax=ax_real_time, bottom=True)
            sns.despine(ax=ax_force_seg, bottom=True)
            sns.despine(ax=ax_video_seg, bottom=True)
            ax_real_time.xaxis.set_visible(False)
            ax_force_seg.xaxis.set_visible(False)
            ax_video_seg.xaxis.set_visible(False)

            # set time plot ranges
            ax_real_time.set_xlim([time_window[0]-5, time_window[1]+5])
            ax_force_seg.set_xlim([force_seg_interp_times.min(), force_seg_interp_times.max()])
            # ax_video_seg.set_xlim([video_seg_interp_times.min(), video_seg_interp_times.max()])
            ax_video_seg.set_xlim([2, 6.5])

            # elevate the Axes for units and remove background colors so that
            # each vertical ConnectionPatch drawn later is visible behind it
            ax_real_time.set_zorder(1)
            ax_force_seg.set_zorder(1)
            ax_video_seg.set_zorder(1)
            ax_real_time.set_facecolor('none')
            ax_force_seg.set_facecolor('none')
            ax_video_seg.set_facecolor('none')

        # remove right and top plot borders from bottom panels, and set x-label
        ax_real_time, ax_force_seg, ax_video_seg = axes[-1]
        sns.despine(ax=ax_real_time)
        sns.despine(ax=ax_force_seg)
        sns.despine(ax=ax_video_seg)
        ax_real_time.set_xlabel('Time (s)')
        ax_force_seg.set_xlabel('Time (normalized using force segmentation)')
        ax_video_seg.set_xlabel('Time (normalized using video segmentation)')

        # set time plot ranges
        ax_real_time.set_xlim([time_window[0]-5, time_window[1]+5])
        ax_force_seg.set_xlim([force_seg_interp_times.min(), force_seg_interp_times.max()])
        # ax_video_seg.set_xlim([video_seg_interp_times.min(), video_seg_interp_times.max()])
        ax_video_seg.set_xlim([2, 6.5])

        # plot force in real time (bottom row axes are already unpacked above)
        channel = 'Force'
        sig = get_sig(blk, channel)
        sig = sig.time_slice(time_window[0]*pq.s, time_window[1]*pq.s)
        sig = sig.rescale(channel_units[channel_names.index(channel)])
        ax_real_time.plot(sig.times, sig.magnitude, c='0.8', lw=1)
        ax_real_time.set_ylabel(sig.name + ' (' + sig.units.dimensionality.string + ')')
        ax_real_time.yaxis.set_label_coords(-0.06, 0.5)

        # per-unit (and force) stacks of interpolated time series, one row per
        # swallow, used below for the median/quartile curves
        all_force_seg_normalized_times_series = {}
        all_video_seg_normalized_times_series = {}
        for unit in units:
            all_force_seg_normalized_times_series[unit] = np.zeros((0, force_seg_interp_times.size))
            all_video_seg_normalized_times_series[unit] = np.zeros((0, video_seg_interp_times.size))
        all_force_seg_normalized_times_series['Force'] = np.zeros((0, force_seg_interp_times.size))
        all_video_seg_normalized_times_series['Force'] = np.zeros((0, video_seg_interp_times.size))

        for j, i in enumerate(df.index):


            # plot inward food movement
            inward_movement_start = df.loc[i, 'Inward movement start (s)']
            inward_movement_end   = df.loc[i, 'Inward movement end (s)']
            if np.isfinite(inward_movement_start):
                ax_real_time, ax_force_seg, ax_video_seg = axes[-1]
                ax_real_time.axvspan(
                    inward_movement_start, inward_movement_end,
                    0.98, 1,
                    facecolor='k', edgecolor=None, lw=0)


            for k, unit in enumerate(units):
                ax_real_time, ax_force_seg, ax_video_seg = axes[k]


                # raster plot
                st = df.loc[i, f'{unit} spike train']
                if st is not None:
                    ax_real_time.eventplot(positions=st, lineoffsets=-1, colors=unit_colors[unit])


                # plot the firing rates in real time
                firing_rate = df.loc[i, f'{unit} firing rate (Hz)']
                if firing_rate is not None:
                    ax_real_time.plot(firing_rate.times.rescale('s'), firing_rate, c=unit_colors[unit])


                # plot firing rates in normalized time
                firing_rate_force_seg_interp = df.loc[i, f'{unit} firing rate, force segmented interpolation (Hz)']
                if firing_rate_force_seg_interp is not None:
                    all_force_seg_normalized_times_series[unit] = np.concatenate([all_force_seg_normalized_times_series[unit], firing_rate_force_seg_interp[np.newaxis, :]])
                    ax_force_seg.plot(force_seg_interp_times, firing_rate_force_seg_interp, c=unit_colors[unit])
                firing_rate_video_seg_interp = df.loc[i, f'{unit} firing rate, video segmented interpolation (Hz)']
                if firing_rate_video_seg_interp is not None:
                    all_video_seg_normalized_times_series[unit] = np.concatenate([all_video_seg_normalized_times_series[unit], firing_rate_video_seg_interp[np.newaxis, :]])
                    ax_video_seg.plot(video_seg_interp_times, firing_rate_video_seg_interp, c=unit_colors[unit])


            # plot force in normalized time
            ax_real_time, ax_force_seg, ax_video_seg = axes[-1]
            force_force_seg_interp = df.at[i, 'Force, force segmented interpolation (mN)']
            if force_force_seg_interp is not None:
                all_force_seg_normalized_times_series['Force'] = np.concatenate([all_force_seg_normalized_times_series['Force'], force_force_seg_interp[np.newaxis, :]])
                ax_force_seg.plot(force_seg_interp_times, force_force_seg_interp, c='0.8', lw=1)
            force_video_seg_interp = df.at[i, 'Force, video segmented interpolation (mN)']
            if force_video_seg_interp is not None:
                all_video_seg_normalized_times_series['Force'] = np.concatenate([all_video_seg_normalized_times_series['Force'], force_video_seg_interp[np.newaxis, :]])
                ax_video_seg.plot(video_seg_interp_times, force_video_seg_interp, c='0.8', lw=1)


            # plot force phase boundaries in real time
            force_segmentation_times = df.at[i, 'Force segmentation times (s)'].rescale('s').magnitude
            for m, t in enumerate(force_segmentation_times[3:8]): # 3 = end of shoulder, 8 = end of next shoulder
                if np.isfinite(t):
                    if m == 1: # 1 = start of rise
                        color = force_colors['rise']
                    else:
                        color = '0.75'
                    axes[-1][0].add_artist(patches.ConnectionPatch(
                        xyA=(t, 0), xyB=(t, 1),
                        coordsA=axes[-1][0].get_xaxis_transform(), coordsB=axes[0][0].get_xaxis_transform(),
                        axesA=axes[-1][0], axesB=axes[0][0],
                        color=color, lw=1, ls=':'))


        # plot force phase boundaries in normalized time
        # (uses force_segmentation_times from the last swallow of the loop above,
        # only for the number of boundaries)
        for m in range(len(force_segmentation_times)):
            if m == 4: # 4 = start of rise
                color = force_colors['rise']
            else:
                color = '0.75'
            axes[-1][1].add_artist(patches.ConnectionPatch(
                xyA=(m, 0), xyB=(m, 1),
                coordsA=axes[-1][1].get_xaxis_transform(), coordsB=axes[0][1].get_xaxis_transform(),
                axesA=axes[-1][1], axesB=axes[0][1],
                color=color, lw=1, ls=':'))

        # plot video phase boundaries in normalized time
        video_segmentation_times = df.at[i, 'Video segmentation times (s)'].rescale('s').magnitude # grab last swallow's as an example
        # for m in range(len(video_segmentation_times)):
        for m in [2, 3, 4, 5, 6]:
            if m == 4: # 4 = start of inward movement
                color = force_colors['rise']
            else:
                color = '0.75'
            axes[-1][2].add_artist(patches.ConnectionPatch(
                xyA=(m, 0), xyB=(m, 1),
                coordsA=axes[-1][2].get_xaxis_transform(), coordsB=axes[0][2].get_xaxis_transform(),
                axesA=axes[-1][2], axesB=axes[0][2],
                color=color, lw=1, ls=':'))


        # plot firing rate distributions (median solid, quartiles dashed)
        for k, unit in enumerate(units):
            ax_real_time, ax_force_seg, ax_video_seg = axes[k]

            firing_rate_median = np.nanmedian(all_force_seg_normalized_times_series[unit], axis=0)
            firing_rate_q1 = np.nanquantile(all_force_seg_normalized_times_series[unit], q=0.25, axis=0)
            firing_rate_q3 = np.nanquantile(all_force_seg_normalized_times_series[unit], q=0.75, axis=0)
            ax_force_seg.plot(force_seg_interp_times, firing_rate_median, c='k', lw=2, zorder=3)
            ax_force_seg.plot(force_seg_interp_times, firing_rate_q1, c='k', lw=2, ls='--')
            ax_force_seg.plot(force_seg_interp_times, firing_rate_q3, c='k', lw=2, ls='--')

            firing_rate_median = np.nanmedian(all_video_seg_normalized_times_series[unit], axis=0)
            firing_rate_q1 = np.nanquantile(all_video_seg_normalized_times_series[unit], q=0.25, axis=0)
            firing_rate_q3 = np.nanquantile(all_video_seg_normalized_times_series[unit], q=0.75, axis=0)
            ax_video_seg.plot(video_seg_interp_times, firing_rate_median, c='k', lw=2, zorder=3)
            ax_video_seg.plot(video_seg_interp_times, firing_rate_q1, c='k', lw=2, ls='--')
            ax_video_seg.plot(video_seg_interp_times, firing_rate_q3, c='k', lw=2, ls='--')


        # plot force distribution (median solid, quartiles dashed)
        ax_real_time, ax_force_seg, ax_video_seg = axes[-1]

        force_median = np.nanmedian(all_force_seg_normalized_times_series['Force'], axis=0)
        force_q1 = np.nanquantile(all_force_seg_normalized_times_series['Force'], q=0.25, axis=0)
        force_q3 = np.nanquantile(all_force_seg_normalized_times_series['Force'], q=0.75, axis=0)
        ax_force_seg.plot(force_seg_interp_times, force_median, c='k', lw=2, zorder=3)
        ax_force_seg.plot(force_seg_interp_times, force_q1, c='k', lw=2, ls='--')
        ax_force_seg.plot(force_seg_interp_times, force_q3, c='k', lw=2, ls='--')

        force_median = np.nanmedian(all_video_seg_normalized_times_series['Force'], axis=0)
        force_q1 = np.nanquantile(all_video_seg_normalized_times_series['Force'], q=0.25, axis=0)
        force_q3 = np.nanquantile(all_video_seg_normalized_times_series['Force'], q=0.75, axis=0)
        ax_video_seg.plot(video_seg_interp_times, force_median, c='k', lw=2, zorder=3)
        ax_video_seg.plot(video_seg_interp_times, force_q1, c='k', lw=2, ls='--')
        ax_video_seg.plot(video_seg_interp_times, force_q3, c='k', lw=2, ls='--')


        plt.suptitle(f'({animal}, {food}, {bout_index}): {data_set_name}')
        plt.tight_layout(rect=(0, 0, 1, 0.97))

        # export figure
        fig.savefig(os.path.join(export_dir3, f'{animal} {food} {bout_index}.png'), dpi=300)

        # close the figure so that one open figure per bout does not
        # accumulate in memory
        plt.close(fig)

        pbar.update()

    pbar.close()

    if exemplary_bout in feeding_bouts:
        # rename output file for exemplar
        animal, food, bout_index = exemplary_bout
        old_path = os.path.join(export_dir3, f'{animal} {food} {bout_index}.png')
        new_path = os.path.join(export_dir3, 'exemplary_bout.png')
        if os.path.exists(old_path):
            if os.path.exists(new_path):
                os.remove(new_path)
            os.rename(old_path, new_path)

    end = datetime.datetime.now()
    print('elapsed time:', end-start)

## Plotting Functions

In [None]:
def boxplot_with_points(x, y, hue, data, ax=None, show_points=False, describe=True):
    """Draw a seaborn box plot of column *y* grouped by *x* (and optionally *hue*).

    Rows where *y* is NaN are dropped first. Whiskers always span the full
    range of the data (minimum to maximum).

    Parameters
    ----------
    x, y : str
        Column names in *data* for the grouping variable and the plotted value.
    hue : str or None
        Optional second grouping column; None draws one grey box per x-group.
    data : pandas.DataFrame
    ax : matplotlib Axes, optional
        Defaults to the current axes.
    show_points : bool
        Overlay a swarm plot of the individual observations.
    describe : bool
        Print the count and median of *y* for every group.
    """
    if ax is None:
        ax = plt.gca()
    boxcolor = '0.75' if hue is None else None
    pointcolor = 'k' if hue is None else None
    edgecolor = '0.25'
    linewidth = 0 if hue is None else 1
    size = 4

    data = data.dropna(subset=[y])

    # Percentile whiskers (0th..100th) span the extrema unconditionally. The former
    # IQR multiple (whis=999) collapses onto the box when the IQR is 0, turning every
    # other point into a flier.
    sns.boxplot(x=x, y=y, hue=hue, data=data, ax=ax, color=boxcolor, whis=(0, 100))

    if show_points:
        sns.swarmplot(x=x, y=y, hue=hue, data=data, ax=ax, color=pointcolor, linewidth=linewidth, edgecolor=edgecolor, size=size, dodge=True)

        if hue is not None:
            # box plot and swarm plot both add legend entries; keep only the first half
            # to avoid duplicate legend entries
            handles, labels = ax.get_legend_handles_labels()
            n = len(labels) // 2
            ax.legend(handles[:n], labels[:n], title=hue)

    ax.set_xlabel(None)

    if describe:
        by = [x] if hue is None else [x, hue]
        print(y)
        print(data.groupby(by)[y].apply(lambda values: {
            'N': f'{values.count()}',
            'Median': values.median(),
        }).unstack())
        print()

In [None]:
default_markers = ['.', '+', 'x', '1', '4', 'o', 'D']
default_colors = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6']


def _padded_limits(values, lim, padding):
    """Return a new [low, high] axis range for *values*, padded by *padding* x their span.

    Entries of *lim* that are not None are kept as given. *lim* itself is never
    modified, so callers may pass a tuple or reuse the same list across calls.
    """
    span = np.ptp(values)
    auto = [values.min() - span * padding, values.max() + span * padding]
    if lim is None:
        return auto
    return [auto[k] if lim[k] is None else lim[k] for k in (0, 1)]


def _plot_ols_trend(ax, points, color, label):
    """Fit an OLS line y ~ 1 + x to a two-column frame (x, y), print its stats, and draw it.

    The drawn line extends 10% of the x span beyond the data on each side.
    """
    x, y = points.iloc[:, 0], points.iloc[:, 1]
    model = sm.OLS(y, sm.add_constant(x)).fit()
    # params/pvalues are labelled by column name, so index them positionally via .iloc
    # (plain [0]/[1] relies on pandas' deprecated positional fallback)
    intercept, slope = model.params.iloc[0], model.params.iloc[1]
    model_stats = 'R$^2$ = {:.2f}, p = {:.5f}, n = {}'.format(model.rsquared, model.pvalues.iloc[1], len(points))
    print(f'{label}:', model_stats)
    margin = np.ptp(x) * 0.1
    model_x = np.linspace(x.min() - margin, x.max() + margin, 100)
    ax.plot(model_x, intercept + slope * model_x, color=color)


def scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim, trend=False, trend_separately=False, tooltips=False, padding=0.05, colors=default_colors, markers=default_markers):
    """Scatter-plot column *ylabel* against *xlabel* for several query-defined subsets of ``df_all``.

    Parameters
    ----------
    ax : matplotlib Axes
    data_subsets : dict
        Maps legend label -> ``df_all.query`` string. A None query draws nothing
        for that entry, but its color/marker slot is still consumed.
    xlabel, ylabel : str
        Column names plotted on each axis; also used as the axis labels.
    xlim, ylim : sequence of two (float or None), or None
        Axis ranges. None (or a None entry) is replaced by the data extent padded
        by *padding* x the data span. The caller's sequence is not modified.
    trend : bool
        Fit and draw one OLS line (black) through the union of all subsets.
    trend_separately : bool
        Fit and draw one OLS line per subset, in that subset's color.
    tooltips : bool
        Show a point's row index when the mouse hovers over it.
    padding : float
        Fractional padding used for automatic axis limits.
    colors, markers : sequences
        Per-subset colors and markers, indexed by subset position.

    Relies on the notebook globals ``df_all`` and ``query_union``.
    """
    for j, (label, query) in enumerate(data_subsets.items()):
        if query is not None:
            df = df_all.query(query)
            ax.scatter(df[xlabel], df[ylabel],
                       label=label, marker=markers[j], c=colors[j], clip_on=False)

    all_points = df_all.query(query_union(data_subsets.values()))[[xlabel, ylabel]].dropna()

    if len(all_points) > 0:
        # build fresh lists so a caller-supplied (possibly shared, possibly tuple) range is untouched
        xlim = _padded_limits(all_points.iloc[:, 0], xlim, padding)
        ylim = _padded_limits(all_points.iloc[:, 1], ylim, padding)

    if trend_separately:
        for j, (label, query) in enumerate(data_subsets.items()):
            if query is not None:
                _plot_ols_trend(ax, df_all.query(query)[[xlabel, ylabel]].dropna(), colors[j], label)

    if trend:
        _plot_ols_trend(ax, all_points, 'k', 'All points')

    if tooltips:
        # create a transparent layer containing all points for detecting mouse events
        sc = ax.scatter(all_points.iloc[:, 0], all_points.iloc[:, 1], label=None,
                        alpha=0, # transparent points
                        s=0.001, # small radius of tooltip activation
                       )

        # initialize a hidden empty tooltip
        annot = ax.annotate(
            '', xy=(0, 0),                              # initialize empty at origin
            xytext=(0, 15), textcoords='offset points', # position text above point
            color='k', size=10, ha='center',            # small black centered text
            bbox=dict(fc='w', lw=0, alpha=0.6),         # transparent white background
            arrowprops=dict(shrink=0, headlength=7, headwidth=7, width=0, lw=0, color='k'), # small black arrow
        )
        annot.set_visible(False)

        # prepare tooltip contents
        # NOTE(review): joins the elements of each index entry, so this assumes df_all
        # has a MultiIndex (tuple entries) — confirm
        tooltip_text = [' '.join([str(x) for x in index]) for index in list(all_points.index)]

        # bind the animation of tooltips to a hover event
        def hover(event):
            if event.inaxes == ax:
                cont, ind = sc.contains(event)
                if cont:
                    point_index = ind['ind'][0]
                    pos = sc.get_offsets()[point_index]
                    annot.xy = pos
                    annot.set_text(tooltip_text[point_index])
                    annot.set_visible(True)
                    ax.figure.canvas.draw_idle()
                else:
                    if annot.get_visible():
                        annot.set_visible(False)
                        ax.figure.canvas.draw_idle()
        ax.figure.canvas.mpl_connect('motion_notify_event', hover)

    ax.set_xlabel(xlabel)
    ax.set_xlim(xlim)
    ax.set_ylabel(ylabel)
    ax.set_ylim(ylim)

In [None]:
from matplotlib.offsetbox import AnchoredOffsetbox
class AnchoredScaleBar(AnchoredOffsetbox):
    def __init__(self, transform, sizex=0, sizey=0, labelx=None, labely=None, loc=4,
                 pad=0.1, borderpad=0.1, sep=2, prop=None, barcolor="black", barwidth=None,
                 **kwargs):
        """
        Draw a horizontal and/or vertical scale bar sized in the data coordinates
        of the given transform.

        - transform : the coordinate frame (typically axes.transData)
        - sizex, sizey : length of the x, y bar in data units; 0 to omit.
          The x bar extends leftward from the anchor point (negative width).
        - labelx, labely : labels for the x, y bars; None to omit. The x label is
          centered below the bars; the y label is placed to the right of them.
        - loc : position in containing axes
        - pad, borderpad : padding, in fraction of the legend font size (or prop)
        - sep : separation between labels and bars in points
        - prop : font properties, passed to AnchoredOffsetbox
        - barcolor : edge color of the bars
        - barwidth : line width of the bars (None for the matplotlib default)
        - **kwargs : additional arguments passed to base class constructor
        """
        from matplotlib.patches import Rectangle
        from matplotlib.offsetbox import AuxTransformBox, VPacker, HPacker, TextArea

        bars = AuxTransformBox(transform)
        if sizex:
            # negative width: the bar is drawn leftward from the origin
            bars.add_artist(Rectangle((0, 0), -sizex, 0, ec=barcolor, lw=barwidth, fc="none"))
        if sizey:
            bars.add_artist(Rectangle((0, 0), 0, sizey, ec=barcolor, lw=barwidth, fc="none"))

        if sizex and labelx:
            try:
                self.xlabel = TextArea(labelx, minimumdescent=False)
            except TypeError:
                # `minimumdescent` was deprecated in matplotlib 3.4 and removed in 3.6
                self.xlabel = TextArea(labelx)
            bars = VPacker(children=[bars, self.xlabel], align="center", pad=0, sep=sep)
        if sizey and labely:
            self.ylabel = TextArea(labely)
            bars = HPacker(children=[bars, self.ylabel], align="center", pad=0, sep=sep)

        super().__init__(loc, pad=pad, borderpad=borderpad,
                         child=bars, prop=prop, frameon=False, **kwargs)

In [None]:
def _plot_scalebar_trace(ax, blk, p, t_start, t_stop, linewidth, ylabel_padding,
                         scalebar_padding, scalebar_sep, barwidth):
    """Draw the channel described by plot spec *p* on *ax* as a frameless trace.

    Returns the time-sliced, rescaled (not downsampled) signal.
    """
    # select and rescale a channel for the subplot
    sig = next((s for s in blk.segments[0].analogsignals if s.name == p['channel']), None)
    assert sig is not None, f"Signal with name {p['channel']} not found"
    sig = sig.time_slice(t_start, t_stop)
    sig = sig.rescale(p['units'])

    # downsample only the copy that gets drawn
    sig_downsampled = DownsampleNeoSignal(sig, p.get('decimation_factor', 1))
    ax.plot(
        sig_downsampled.times,
        sig_downsampled.as_quantity(),
        linewidth=linewidth,
        color=p.get('color', 'k'),
    )

    # hide the box around the subplot
    ax.set_frame_on(False)

    # horizontal y-axis label to the left of the trace (None suppresses it)
    ylabel = p.get('ylabel', sig.name)
    if ylabel is not None:
        ax.set_ylabel(ylabel, rotation='horizontal', ha='right', va='center', labelpad=ylabel_padding)

    ax.set_xlim([t_start, t_stop])
    ax.set_ylim(p['ylim'])

    # disable tick marks and tick labels
    ax.tick_params(
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False)

    # y-axis scale bar just right of the trace; a missing or None 'scalebar' omits it
    if p.get('scalebar') is not None:
        ax.add_artist(AnchoredScaleBar(
            ax.transData,
            sizey=p['scalebar'],
            labely=f'{p["scalebar"]} {sig.units.dimensionality.string}',

            loc='center left',
            bbox_to_anchor=(1, 0.5),
            bbox_transform=ax.transAxes,

            pad=0,
            borderpad=scalebar_padding,
            sep=scalebar_sep,
            barwidth=barwidth,
        ))

    return sig


def prettyplot_with_scalebars(
    blk,
    t_start,
    t_stop,
    plots,

    outfile_basename=None, # base name of output files
    export_only=False,     # if True, will not render in notebook
    formats=('pdf', 'svg', 'png'), # extensions of output files
    dpi=300,               # resolution (applicable only for PNG)

    figsize=(14, 7),       # figure size in inches
    linewidth=1,           # thickness of lines in points
    layout_settings=None,  # positioning of plot edges and the space between plots

    x_scalebar=1*pq.s,     # size of the time scale bar in seconds
    ylabel_padding=10,     # space between trace labels and plots
    scalebar_padding=1,    # space between scale bars and plots
    scalebar_sep=5,        # space between scale bars and scale labels
    barwidth=2,            # thickness of scale bars
):
    """Plot stacked, frameless signal traces from a neo Block with scale bars.

    Each entry of *plots* is a dict describing one subplot: 'channel' (signal
    name), 'units', 'ylim', and optionally 'scalebar' (y bar length in those
    units; None or missing to omit), 'ylabel', 'color' and 'decimation_factor'.
    A time scale bar of length *x_scalebar* is added below the last subplot.

    If *outfile_basename* is given, the figure is saved once per extension in
    *formats*. Returns ``(fig, axes)`` with *axes* a flat array, one per plot.
    """
    # Remember the interactive state so that export_only restores exactly what
    # was there before, even if plotting raises.
    was_interactive = plt.isinteractive()
    if export_only:
        plt.ioff()

    try:
        fig, axes = plt.subplots(len(plots), 1, sharex=True, figsize=figsize, squeeze=False)
        axes = axes.flatten()

        for ax, p in zip(axes, plots):
            sig = _plot_scalebar_trace(ax, blk, p, t_start, t_stop, linewidth, ylabel_padding,
                                       scalebar_padding, scalebar_sep, barwidth)

        # add time scale bar below final plot, sized in the signals' time units
        if x_scalebar is not None:
            axes[-1].add_artist(AnchoredScaleBar(
                axes[-1].transData,
                sizex=x_scalebar.rescale(sig.times.units).magnitude,
                labelx=f'{x_scalebar.magnitude:g} {x_scalebar.units.dimensionality.string}',

                loc='upper right',
                bbox_to_anchor=(1, 0),
                bbox_transform=axes[-1].transAxes,

                pad=0,
                borderpad=scalebar_padding,
                sep=scalebar_sep,
                barwidth=barwidth,
            ))

        # adjust the white space around and between the subplots
        if layout_settings is None:
            fig.tight_layout(h_pad=0, w_pad=0, pad=0)
        else:
            fig.subplots_adjust(**layout_settings)

        if outfile_basename is not None:
            # specify file metadata (applicable only for PDF); str() tolerates a
            # missing file_origin
            metadata = dict(
                Subject='\n'.join([
                    'Data file: ' + str(blk.file_origin),
                    'Start time: ' + str(t_start),
                    'End time: ' + str(t_stop),
                ]),
            )

            # write the figure to files
            for ext in formats:
                fig.savefig(f'{outfile_basename}.{ext}', metadata=metadata, dpi=dpi)
    finally:
        if export_only and was_interactive:
            plt.ion()

    return fig, axes

In [None]:
# def plot_unloaded_vs_loaded(df, y, ax, color='0.25', show_statistics=True, show_signif=True, alpha=0.05):
    
#     df = df.reset_index()
#     df.loc[df['Food'] == 'Regular nori', 'Food'] = 'Unloaded'
#     df.loc[df['Food'] == 'Tape nori',    'Food'] = 'Loaded'
    
#     # plot points
#     sns.stripplot(
#         x='Food', y=y, hue='Animal',
#         data=df.groupby(['Animal', 'Food'])[y].mean().reset_index(),
#         order=['Unloaded', 'Loaded'],
#         jitter=False,
#         palette=[color], # do not desaturate by animal
#         ax=ax,
#         clip_on=False,
#     )
#     ax.legend_.remove()

#     # plot lines connecting points
#     ax.plot([
#         df.query('Food == "Unloaded"').groupby('Animal')[y].mean(),
#         df.query('Food == "Loaded"').groupby('Animal')[y].mean()
#     ], color='0.75', clip_on=False)

#     ax.set_xlim([-0.25, 1.25])
#     ax.set_xlabel(None)
#     sns.despine(ax=ax)

#     if show_statistics:
#         print(df.groupby(['Animal', 'Food'])[y].apply(lambda x: {'Mean': x.mean(), 'Count': x.count()}).unstack([1, 2])[['Unloaded', 'Loaded']])
#         print()
#         signif = differences_test(
#             x=df.query('Food == "Loaded"').groupby('Animal')[y].mean(),
#             y=df.query('Food == "Unloaded"').groupby('Animal')[y].mean(),
#             x_label='loaded swallows',
#             y_label='unloaded swallows',
#             measure_label='mean [' + ' '.join(y.split()[:-1]) + ']',
#             units=y.split()[-1].strip('()'),
#             alpha=alpha,
#         )
        
#         if show_signif and signif == '*':
#             ax.annotate(
#                 '*',
#                 xy=(0.5, 1), xycoords='axes fraction',
#                 xytext=(0, -20), textcoords='offset points',
#                 ha='center', fontsize='xx-large', color='0.5',
#             )

In [None]:
def plot_unloaded_vs_loaded(df, y, ax, color='0.25', show_all=False, show_statistics=True, show_signif=True, bracket_width=1.0, alpha=0.05):
    """Compare per-animal means of *y* between unloaded and loaded swallows.

    Rows with Food 'Regular nori' are relabelled 'Unloaded' and 'Tape nori'
    'Loaded'. Each animal gets a fixed x position (its index in sorted order).
    Its unloaded and loaded mean +/- SEM are drawn at that position -/+ 0.2,
    joined by a grey line. The food initial and the 1-based animal number are
    written below the axis.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide 'Animal' and 'Food' (as columns or index levels) and column *y*.
    y : str
        Column to plot. The statistics printout assumes it ends with a
        parenthesized unit, e.g. 'Inward movement duration (s)'.
    ax : matplotlib Axes
    color : matplotlib color for the means and SEM bars.
    show_all : bool
        Also draw every individual observation as an open circle.
    show_statistics : bool
        Print per-animal summaries and run ``differences_test`` on the per-animal means.
    show_signif : bool
        Mark a significant result ('*' returned by ``differences_test``) with a bracket.
    bracket_width : float
        ``widthB`` of the significance bracket.
    alpha : float
        Significance level passed to ``differences_test``.
    """
    from matplotlib.patches import ArrowStyle

    df = df.reset_index()
    df.loc[df['Food'] == 'Regular nori', 'Food'] = 'Unloaded'
    df.loc[df['Food'] == 'Tape nori',    'Food'] = 'Loaded'

    food_offsets = {'Unloaded': -0.2, 'Loaded': 0.2}

    # One x position per animal, shared by both foods (sorted, like groupby keys).
    # Enumerating each food's groups separately would shift every later animal out
    # of alignment whenever an animal lacks one of the food types.
    compared = df[df['Food'].isin(list(food_offsets))]
    animals = sorted(compared['Animal'].dropna().unique())

    for food, offset in food_offsets.items():
        groups = dict(list(compared[compared['Food'] == food].groupby('Animal')[y]))
        for i, animal in enumerate(animals):
            if animal not in groups:
                continue
            values = groups[animal]
            x = i + offset
            mean, sem = values.mean(), values.sem()

            # plot the mean
            ax.scatter([x], mean, color=color, s=15, zorder=2, clip_on=False)

            # plot the standard error of the mean: vertical bar plus two caps
            lower_sem, upper_sem = mean - sem, mean + sem
            for seg_x, seg_y in (
                ([x, x], [lower_sem, upper_sem]),
                ([x - 0.1, x + 0.1], [lower_sem, lower_sem]),
                ([x - 0.1, x + 0.1], [upper_sem, upper_sem]),
            ):
                ax.plot(seg_x, seg_y, color=color, lw=1, zorder=2, clip_on=False)

            if show_all:
                # plot individual swallows as open circles
                ax.scatter(
                    [x]*values.size, values,
                    facecolors='none',
                    edgecolors='k',
                    linewidths=1,
                    s=12,
                    zorder=3,
                )

            # faint guide line and the food initial ('U' / 'L') below the axis
            ax.axvline(x=x, c='0.9', lw=0.5, zorder=-2)
            ax.annotate(
                food[0],
                xy=(x, -0.02), xycoords=ax.get_xaxis_transform(),
                ha='center', va='top', fontsize='x-small', color='0.5',
            )

    # 1-based animal numbers below the food initials
    for i in range(len(animals)):
        ax.annotate(
            i+1,
            xy=(i, -0.02), xycoords=ax.get_xaxis_transform(),
            xytext=(0, -10), textcoords='offset points',
            ha='center', va='top', fontsize='x-small', color='0.5',
        )

    # plot lines connecting each animal's unloaded and loaded means, aligned by
    # animal (a missing mean is NaN, so no line is drawn for that animal)
    unloaded_means = compared[compared['Food'] == 'Unloaded'].groupby('Animal')[y].mean().reindex(animals)
    loaded_means = compared[compared['Food'] == 'Loaded'].groupby('Animal')[y].mean().reindex(animals)
    for i, (y_unloaded, y_loaded) in enumerate(zip(unloaded_means, loaded_means)):
        ax.plot(
            [i + food_offsets['Unloaded'], i + food_offsets['Loaded']],
            [y_unloaded, y_loaded],
            color='0.75',
            zorder=-1,
        )

    ax.tick_params(bottom=False, labelbottom=False)
    ax.set_xlabel(None)
    sns.despine(ax=ax, bottom=True)

    if show_statistics:
        print(df.groupby(['Animal', 'Food'])[y].apply(lambda x: {'Mean': x.mean(), 'SEM': x.sem(), 'STD': x.std(), 'Count': x.count()}).unstack([1, 2])[['Unloaded', 'Loaded']])
        print()
        signif = differences_test(
            x=df.query('Food == "Loaded"').groupby('Animal')[y].mean(),
            y=df.query('Food == "Unloaded"').groupby('Animal')[y].mean(),
            x_label='loaded swallows',
            y_label='unloaded swallows',
            measure_label='mean [' + ' '.join(y.split()[:-1]) + ']',
            units=y.split()[-1].strip('()'),
            alpha=alpha,
        )

        if show_signif and signif == '*':
            ax.annotate(
                '*',
                xy=(0.5, 1.03), xycoords='axes fraction',
                xytext=(0, 0), textcoords='offset points',
                arrowprops=dict(arrowstyle=ArrowStyle.BracketB(widthB=bracket_width, lengthB=0.2, angleB=None), color='0.5'),
                ha='center', fontsize='x-large', color='0.5',
            )

In [None]:
def solve_figure_vertical_dimensions(nrows, subplot_height_in_inches, top_margin_in_inches, bottom_margin_in_inches, hspace):
    """Size a figure so that each of *nrows* subplots is exactly *subplot_height_in_inches* tall.

    *hspace* is the gap between neighbouring rows as a fraction of the subplot
    height. Returns ``(fig_height_in_inches, top_fraction, bottom_fraction)``.
    The fractions are suitable for ``subplots_adjust(top=..., bottom=...)``.
    """
    stacked = nrows * subplot_height_in_inches
    gaps = (nrows - 1) * (hspace * subplot_height_in_inches)
    fig_height_in_inches = stacked + gaps + top_margin_in_inches + bottom_margin_in_inches
    return (
        fig_height_in_inches,
        1 - top_margin_in_inches / fig_height_in_inches,
        bottom_margin_in_inches / fig_height_in_inches,
    )

In [None]:
def solve_figure_horizontal_dimensions(ncols, subplot_width_in_inches, left_margin_in_inches, right_margin_in_inches, wspace):
    """Size a figure so that each of *ncols* subplots is exactly *subplot_width_in_inches* wide.

    *wspace* is the gap between neighbouring columns as a fraction of the
    subplot width. Returns ``(fig_width_in_inches, left_fraction, right_fraction)``.
    The fractions are suitable for ``subplots_adjust(left=..., right=...)``.
    """
    content = ncols * subplot_width_in_inches
    gaps = (ncols - 1) * (wspace * subplot_width_in_inches)
    fig_width_in_inches = content + gaps + left_margin_in_inches + right_margin_in_inches
    left_fraction = left_margin_in_inches / fig_width_in_inches
    right_fraction = 1 - right_margin_in_inches / fig_width_in_inches
    return fig_width_in_inches, left_fraction, right_fraction

---

# Figures

## [FIGURE 1]

### 🐌 Figure 1A

In [None]:
# Figure 1A: stacked I2, RN, BN2, BN3 and force traces for the exemplary swallow.
# Annotated with:
#   - sorted spikes and burst boxes
#   - force phase boundaries (arabic numerals) and phases (roman numerals)
#   - protraction / retraction / inward-movement boxes
#   - markers for the video frames of Figures 1B and 1C
# Written to figure-1A.png in export_dir.
# Relies on notebook globals defined elsewhere: df_exemplary_swallow, feeding_bouts,
# exemplary_swallow, sig_filters_by_animal, apply_filters, units, unit_colors, get_sig.
start = datetime.datetime.now()

df = df_exemplary_swallow
# recording name and time window (converted to seconds below) of the exemplary swallow
(data_set_name, time_window) = feeding_bouts[exemplary_swallow]
animal, _, _ = exemplary_swallow

t_start, t_stop = time_window*pq.s
# one subplot per channel: display units, y-axis range, and length of the y scale bar
# (in those units); 'ylabel' overrides the displayed channel name; decimation is disabled
plots = [
    {'channel': 'I2',       'units': 'uV', 'ylim': [ -35,  35], 'scalebar': 50}, #, 'decimation_factor': 100},
    {'channel': 'RN',       'units': 'uV', 'ylim': [ -25,  25], 'scalebar': 25}, #, 'decimation_factor': 100},
    {'channel': 'BN2',      'units': 'uV', 'ylim': [ -45,  45], 'scalebar': 50}, #, 'decimation_factor': 100},
    {'channel': 'BN3-DIST', 'units': 'uV', 'ylim': [ -70,  70], 'scalebar': 50, 'ylabel': 'BN3'}, #, 'decimation_factor': 100},
    {'channel': 'Force',    'units': 'mN', 'ylim': [-100, 400], 'scalebar': 200}, #, 'decimation_factor': 100},
]
# channel name / display unit lookups used to find a unit's subplot below
plot_names = [p['channel'] for p in plots]
plot_units = [p['units'] for p in plots]

kwargs = dict(
    figsize = (5, 6.5),
    linewidth = 0.5,
    x_scalebar = 1*pq.s,
)

# load the metadata
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)

# ensure that the right filters are used
metadata['filters'] = sig_filters_by_animal[animal]

# load the data
blk = neurotic.load_dataset(metadata, lazy=True)

# manually perform filters
blk = apply_filters(blk, metadata)

# plot the signal
fig, axes = prettyplot_with_scalebars(blk, t_start, t_stop, plots, **kwargs)

# vertical extent (bottom, top), in the channel's display units, of the box
# drawn around each unit's burst
unit_burst_boxes = {
    'B38':   [-12, 12],
    'I2':    [-30, 25],
    'B8a/b': [-20, 12],
    'B6/B9': [-20, 15],
    'B3':    [-45, 35],
    'B4/B5': [-60, 55],
}

# for every row and every unit: mark its spikes on the trace of the channel it was
# sorted from, and outline its burst window
for j, i in enumerate(df.index):
    for k, unit in enumerate(units):
        st = df.loc[i, f'{unit} spike train']
        if st is not None and st.size > 0:

            # get the neural channel (first channel listed in the spike train's annotations)
            channel = st.annotations['channels'][0]
            sig = get_sig(blk, channel)

            # get the signal for the entire bout
            sig = sig.time_slice(t_start, t_stop)
            sig = sig.rescale(plot_units[plot_names.index(channel)])

            # plot spikes at the signal's value at each spike time
            ax = axes[plot_names.index(channel)]
            spike_amplitudes = np.array([sig[sig.time_index(t)] for t in st]) * pq.Quantity(sig.units)
            ax.scatter(st.times.rescale('s'), spike_amplitudes, marker='.', s=20, c=unit_colors[unit], zorder=3)

            # plot burst windows (skipped unless both burst start and end are finite)
            left = df.loc[i, f'{unit} burst start (s)']
            right = df.loc[i, f'{unit} burst end (s)']
            if np.isfinite(left) and np.isfinite(right):
                width = right-left
                bottom, top = unit_burst_boxes[unit]
                height = top-bottom
                rect = patches.Rectangle((left, bottom), width, height, linewidth=2, ls='-', edgecolor=unit_colors[unit], fill=False, zorder=3, clip_on=False)
                ax.add_patch(rect)

# NOTE(review): the loop above addresses rows via df.index, but everything below reads
# the row labelled 0; this assumes df has a row with index label 0 — confirm.

# plot a zero baseline for force
force_zero = -11.5 # mN, average before animal began swallowing, not zero because of DC offset
axes[-1].axhline(force_zero, color='0.75', lw=1, ls='--', zorder=-1)

# plot force phase boundaries: entries 2..7 of the force segmentation times, labelled
# below as boundaries 1-5 and then 1 again (presumably the start of the next cycle — confirm)
times = df.loc[0, 'Force segmentation times (s)'][2:2+6]
for t in times:
    axes[-1].axvline(x=t, ymin=0.15, lw=1, ls=':', c='gray', zorder=-1)

# add arabic numerals for force phase boundaries
axes[-1].annotate('1', xy=(times[0], -0.05), xycoords=('data', 'axes fraction'), ha='center', va='bottom')
axes[-1].annotate('2', xy=(times[1], -0.05), xycoords=('data', 'axes fraction'), ha='center', va='bottom')
axes[-1].annotate('3', xy=(times[2], -0.05), xycoords=('data', 'axes fraction'), ha='center', va='bottom')
axes[-1].annotate('4', xy=(times[3], -0.05), xycoords=('data', 'axes fraction'), ha='center', va='bottom')
axes[-1].annotate('5', xy=(times[4], -0.05), xycoords=('data', 'axes fraction'), ha='center', va='bottom')
axes[-1].annotate('1', xy=(times[5], -0.05), xycoords=('data', 'axes fraction'), ha='center', va='bottom')

# add roman numerals for force phases, centered between consecutive boundaries
axes[-1].annotate('I',   xy=(times[0:2].mean(), 1), xycoords=('data', 'axes fraction'), ha='center', va='top')
axes[-1].annotate('II',  xy=(times[1:3].mean(), 1), xycoords=('data', 'axes fraction'), ha='center', va='top')
axes[-1].annotate('III', xy=(times[2:4].mean(), 1), xycoords=('data', 'axes fraction'), ha='center', va='top')
axes[-1].annotate('IV',  xy=(times[3:5].mean(), 1), xycoords=('data', 'axes fraction'), ha='center', va='top')
axes[-1].annotate('V',   xy=(times[4:6].mean(), 1), xycoords=('data', 'axes fraction'), ha='center', va='top')

# add unit name labels (x positions in seconds hand-placed for this swallow)
axes[2].annotate('B38',   xy=(2978.15, 0.70), xycoords=('data', 'axes fraction'), ha='center', c=unit_colors['B38'])
axes[0].annotate('I2',    xy=(2979.40, 0.75), xycoords=('data', 'axes fraction'), ha='left',   c=unit_colors['I2'])
axes[1].annotate('B8a/b', xy=(2981.10, 0.80), xycoords=('data', 'axes fraction'), ha='center', c=unit_colors['B8a/b'])
axes[2].annotate('B6/B9', xy=(2980.10, 0.73), xycoords=('data', 'axes fraction'), ha='center', c=unit_colors['B6/B9'])
axes[2].annotate('B3',    xy=(2981.85, 0.79), xycoords=('data', 'axes fraction'), ha='left',   c=unit_colors['B3'])
axes[3].annotate('B4/B5', xy=(2979.10, 0.80), xycoords=('data', 'axes fraction'), ha='right',  c=unit_colors['B4/B5'])

# add protraction box (spans the I2 burst; drawn above the top trace in axes-fraction y)
left, right = df.loc[0, ['I2 burst start (s)', 'I2 burst end (s)']]
bottom, top = (1, 1.15)
width = right-left
height = top-bottom
rect = patches.Rectangle((left, bottom), width, height, linewidth=2, ls='-', edgecolor='k', fill=False, zorder=3, clip_on=False, transform=axes[0].get_xaxis_transform())
axes[0].add_patch(rect)
axes[0].annotate('Prot.', xy=((left+right)/2, 1.065), xycoords=('data', 'axes fraction'), ha='center', va='center', c='k', fontsize='small')

# add retraction box (from the end of the I2 burst to the end of the behavior)
left, right = df.loc[0, ['I2 burst end (s)', 'End (s)']] # behavior ends with end of B43 burst
bottom, top = (1, 1.15)
width = right-left
height = top-bottom
rect = patches.Rectangle((left, bottom), width, height, linewidth=2, ls='-', edgecolor='k', facecolor='k', fill=True, zorder=3, clip_on=False, transform=axes[0].get_xaxis_transform())
axes[0].add_patch(rect)
axes[0].annotate('Retraction', xy=((left+right)/2, 1.065), xycoords=('data', 'axes fraction'), ha='center', va='center', c='w', fontsize='small')

# add inward movement box (a row above the protraction/retraction boxes)
left, right = df.loc[0, ['Inward movement start (s)', 'Inward movement end (s)']]
bottom, top = (1.20, 1.35)
width = right-left
height = top-bottom
rect = patches.Rectangle((left, bottom), width, height, linewidth=2, ls='-', edgecolor='k', facecolor='0.75', fill=True, zorder=3, clip_on=False, transform=axes[0].get_xaxis_transform())
axes[0].add_patch(rect)
axes[0].annotate('Inward', xy=((left+right)/2, 1.260), xycoords=('data', 'axes fraction'), ha='center', va='center', c='k', fontsize='small')

# add markers for video frame times (s) of the frames shown in Figures 1B and 1C
video_times = {
    'B': 2980.0,
    'C': 2982.5,
}
for label, video_time in video_times.items():
    axes[0].plot([video_time], [1.39], marker=CARETDOWN, markersize=8, color='k', transform=axes[0].get_xaxis_transform(), clip_on=False)
    axes[0].annotate(label, xy=(video_time, 1.46), xycoords=('data', 'axes fraction'), ha='center', va='bottom')

# re-run the tight layout now that labels/boxes outside the axes have been added
fig.tight_layout(h_pad=0, w_pad=0, pad=0)

fig.savefig(os.path.join(export_dir, 'figure-1A.png'), dpi=600)

end = datetime.datetime.now()
print('render time:', end-start)

### 🐌 Figures 1B & 1C

Video frames (see the code for Figure 1A for the times at which these video frames were taken)

## [FIGURE 2]

In [None]:
# shared plot settings
inches_per_second = 0.05
subplot_height_in_inches = 0.6
top_margin_in_inches = 0.12
bottom_margin_in_inches = 0.39
left_margin_in_inches = 0.72
right_margin_in_inches = 0.8

### 🐌 Figure 2A

In [None]:
# Figure 2A: neural traces (I2, RN, BN2, BN3) for 223.4-261.4 s of JG12 / 2019-05-10 / 002.
# Inward movements are drawn as black bars above the top trace, with bite ('B') and
# swallow ('S') markers. The panel width scales with duration (inches_per_second).
# Written to figure-2A.png in export_dir.
start = datetime.datetime.now()

data_set_name = 'IN VIVO / JG12 / 2019-05-10 / 002'
animal = 'JG12'
t_start, t_stop = [223.4, 261.4] * pq.s
plots = [
    {'channel': 'I2',       'units': 'uV', 'ylim': [ -60,  60], 'scalebar': 50}, #, 'decimation_factor': 400},
    {'channel': 'RN',       'units': 'uV', 'ylim': [ -25,  25], 'scalebar': 25}, #, 'decimation_factor': 400},
    {'channel': 'BN2',      'units': 'uV', 'ylim': [ -45,  45], 'scalebar': 50}, #, 'decimation_factor': 400},
    {'channel': 'BN3-DIST', 'units': 'uV', 'ylim': [ -60,  60], 'scalebar': 50, 'ylabel': 'BN3'}, #, 'decimation_factor': 400},
]

# figure size and margin fractions so each trace has the shared height and a width
# proportional to the plotted duration
fig_height_in_inches, top_fraction, bottom_fraction = solve_figure_vertical_dimensions(
    len(plots), subplot_height_in_inches, top_margin_in_inches, bottom_margin_in_inches, 0)
fig_width_in_inches, left_fraction, right_fraction = solve_figure_horizontal_dimensions(
    1, inches_per_second*(t_stop-t_start).magnitude, left_margin_in_inches, right_margin_in_inches, 0)

# NOTE: formats/dpi have no effect here because no outfile_basename is passed
# (prettyplot_with_scalebars only saves when one is given); the figure is written
# by the explicit fig.savefig below.
kwargs = dict(
    formats = ['png'],
    dpi = 600,
    figsize = (fig_width_in_inches, fig_height_in_inches),
    linewidth = 0.5,
    x_scalebar = 10*pq.s,
    layout_settings = dict(
        left   = left_fraction,
        right  = right_fraction,
        bottom = bottom_fraction,
        top    = top_fraction,
        hspace = 0,
    ),
)

# load the metadata
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)

# ensure that the right filters are used
metadata['filters'] = sig_filters_by_animal[animal]

# load the data
blk = neurotic.load_dataset(metadata, lazy=True)

# manually perform filters
blk = apply_filters(blk, metadata)

# plot the data
fig, axes = prettyplot_with_scalebars(blk, t_start, t_stop, plots, **kwargs)

# plot the inward movements as thin black bars just above the top trace
# (y in axes fraction; clip_on=False so they render outside the axes)
y_inward = 1.03
ep = next((ep for ep in blk.segments[0].epochs if ep.name=='Inward movement'), None)
assert ep is not None
ep = ep.time_slice(t_start, t_stop)
for t, dur in zip(ep.times, ep.durations):
    left, right = t, t + dur
    bottom, top = y_inward-0.025, y_inward+0.025
    width = right-left
    height = top-bottom
    rect = patches.Rectangle((left, bottom), width, height, linewidth=0, facecolor='k', fill=True, clip_on=False, transform=axes[0].get_xaxis_transform())
    axes[0].add_patch(rect)
# axes[0].annotate(
#     'Inward',
#     xy=(0, y_inward), xycoords='axes fraction',
#     xytext=(-10, 0), textcoords='offset points',
#     ha='right', va='center',
# )

# add bite/swallow markers (time in s, label), on the same row as the inward bars
behavior_markers = [
    (226.4, 'B'),
    (246.3, 'S'),
    (250.6, 'S'),
    (256.2, 'S'),
#     (263.1, 'B'),
]
for t, label in behavior_markers:
    axes[0].annotate(
        label,
        xy=(t, y_inward), xycoords=axes[0].get_xaxis_transform(),
        ha='center', va='center', fontsize='small',
    )

fig.savefig(os.path.join(export_dir, 'figure-2A.png'), dpi=600)

end = datetime.datetime.now()
print('render time:', end-start)

### 🐌 Figure 2B

In [None]:
# Figure 2B: neural traces plus force for 2875.3-3039.3 s of JG12 / 2019-05-10 / 002.
# Inward movements are drawn as black bars above the top trace, with a bite marker
# and a dashed force baseline. The panel width scales with duration (inches_per_second).
# Written to figure-2B.png in export_dir.
start = datetime.datetime.now()

data_set_name = 'IN VIVO / JG12 / 2019-05-10 / 002'
animal = 'JG12'
t_start, t_stop = [2875.3, 3039.3] * pq.s
plots = [
    {'channel': 'I2',       'units': 'uV', 'ylim': [ -60,  60], 'scalebar': 50}, #, 'decimation_factor': 400},
    {'channel': 'RN',       'units': 'uV', 'ylim': [ -25,  25], 'scalebar': 25}, #, 'decimation_factor': 400},
    {'channel': 'BN2',      'units': 'uV', 'ylim': [ -45,  45], 'scalebar': 50}, #, 'decimation_factor': 400},
    {'channel': 'BN3-DIST', 'units': 'uV', 'ylim': [ -60,  60], 'scalebar': 50, 'ylabel': 'BN3'}, #, 'decimation_factor': 400},
    {'channel': 'Force',    'units': 'mN', 'ylim': [ -50, 450], 'scalebar': 300}, #, 'decimation_factor': 100},
]

# figure size and margin fractions so each trace has the shared height and a width
# proportional to the plotted duration
fig_height_in_inches, top_fraction, bottom_fraction = solve_figure_vertical_dimensions(
    len(plots), subplot_height_in_inches, top_margin_in_inches, bottom_margin_in_inches, 0)
fig_width_in_inches, left_fraction, right_fraction = solve_figure_horizontal_dimensions(
    1, inches_per_second*(t_stop-t_start).magnitude, left_margin_in_inches, right_margin_in_inches, 0)

# NOTE: formats/dpi have no effect here because no outfile_basename is passed
# (prettyplot_with_scalebars only saves when one is given); the figure is written
# by the explicit fig.savefig below.
kwargs = dict(
    formats = ['png'],
    dpi = 600,
    figsize = (fig_width_in_inches, fig_height_in_inches),
    linewidth = 0.5,
    x_scalebar = 10*pq.s,
    layout_settings = dict(
        left   = left_fraction,
        right  = right_fraction,
        bottom = bottom_fraction,
        top    = top_fraction,
        hspace = 0,
    ),
)

# load the metadata
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)

# ensure that the right filters are used
metadata['filters'] = sig_filters_by_animal[animal]

# load the data
blk = neurotic.load_dataset(metadata, lazy=True)

# manually perform filters
blk = apply_filters(blk, metadata)

# plot the data
fig, axes = prettyplot_with_scalebars(blk, t_start, t_stop, plots, **kwargs)

# plot the inward movements as thin black bars just above the top trace
# (y in axes fraction; clip_on=False so they render outside the axes)
y_inward = 1.03
ep = next((ep for ep in blk.segments[0].epochs if ep.name=='Inward movement'), None)
assert ep is not None
ep = ep.time_slice(t_start, t_stop)
for t, dur in zip(ep.times, ep.durations):
    left, right = t, t + dur
    bottom, top = y_inward-0.025, y_inward+0.025
    width = right-left
    height = top-bottom
    rect = patches.Rectangle((left, bottom), width, height, linewidth=0, facecolor='k', fill=True, clip_on=False, transform=axes[0].get_xaxis_transform())
    axes[0].add_patch(rect)
# axes[0].annotate(
#     'Inward',
#     xy=(0, y_inward), xycoords='axes fraction',
#     xytext=(-10, 0), textcoords='offset points',
#     ha='right', va='center',
# )

# add bite/swallow markers (time in s, label), on the same row as the inward bars
behavior_markers = [
    (2877.5, 'B'),
]
for t, label in behavior_markers:
    axes[0].annotate(
        label,
        xy=(t, y_inward), xycoords=axes[0].get_xaxis_transform(),
        ha='center', va='center', fontsize='small',
    )

# plot a zero baseline for force
ax = axes[-1]
force_zero = -11.5 # mN, average before animal began swallowing, not zero because of DC offset
ax.axhline(force_zero, color='0.5', lw=1, ls='--', zorder=-1)

fig.savefig(os.path.join(export_dir, 'figure-2B.png'), dpi=600)

end = datetime.datetime.now()
print('render time:', end-start)

### 🐌 Figure 2C

In [None]:
# Figure 2C: time from the start of one inward movement to the start of the next,
# unloaded vs loaded, excluding rows listed in unreliable_inward_movement
y = 'Inward movement start to next inward movement start (s)'

fig, ax = plt.subplots(1, 1, figsize=(2.25, 3))
plot_unloaded_vs_loaded(df_all.drop(unreliable_inward_movement), y, ax, bracket_width=2.9)
ax.set(ylim=[0, 10], ylabel='Start of one inward\nmovement to the next (s)')
fig.subplots_adjust(left=0.32, right=0.99, top=0.91, bottom=0.10)
fig.savefig(os.path.join(export_dir, 'figure-2C.png'), dpi=300)

# optional per-animal box plot of the same measure:
# fig, ax = plt.subplots(1, 1, figsize=(4, 5))
# boxplot_with_points('Animal', y, 'Food', df_all.drop(unreliable_inward_movement).reset_index(), ax, show_points=True)
# ax.set_ylim([0, None])
# plt.tight_layout()

### 🐌 Figure 2D

In [None]:
# Figure 2D: duration of each inward movement, unloaded vs loaded,
# excluding rows listed in unreliable_inward_movement
y = 'Inward movement duration (s)'

fig, ax = plt.subplots(1, 1, figsize=(2.25, 3))
plot_unloaded_vs_loaded(df_all.drop(unreliable_inward_movement), y, ax, bracket_width=2.9)
ax.set(ylim=[0, 10], ylabel='Duration of\ninward movement (s)')
fig.subplots_adjust(left=0.32, right=0.99, top=0.91, bottom=0.10)
fig.savefig(os.path.join(export_dir, 'figure-2D.png'), dpi=300)

# optional per-animal box plot of the same measure:
# fig, ax = plt.subplots(1, 1, figsize=(4, 5))
# boxplot_with_points('Animal', y, 'Food', df_all.drop(unreliable_inward_movement).reset_index(), ax, show_points=True)
# ax.set_ylim([0, None])
# plt.tight_layout()

### 🐌 Figure 2E

In [None]:
# Figure 2E: time from the end of one inward movement to the start of the next,
# unloaded vs loaded, excluding rows listed in unreliable_inward_movement
y = 'Inward movement end to next inward movement start (s)'

fig, ax = plt.subplots(1, 1, figsize=(2.25, 3))
plot_unloaded_vs_loaded(df_all.drop(unreliable_inward_movement), y, ax, bracket_width=2.9)
ax.set(ylim=[0, 10], ylabel='Time between\ninward movements (s)')
fig.subplots_adjust(left=0.32, right=0.99, top=0.91, bottom=0.10)
fig.savefig(os.path.join(export_dir, 'figure-2E.png'), dpi=300)

# optional per-animal box plot of the same measure:
# fig, ax = plt.subplots(1, 1, figsize=(4, 5))
# boxplot_with_points('Animal', y, 'Food', df_all.drop(unreliable_inward_movement).reset_index(), ax, show_points=True)
# ax.set_ylim([0, None])
# plt.tight_layout()

---

## [FIGURE 3]

In [None]:
# label for each of the nine video segmentation phases; only the central phase
# (index 4, the inward movement itself) is named, the padding phases are blank
video_seg_phase_labels = [''] * 4 + ['Inward movement'] + [''] * 4

def get_median_video_seg_phase_boundaries(df, show_summary=False):
    """Return median video-segmentation phase boundaries on a common time axis.

    Each entry of the 'Video segmentation times (s)' column is a 10-element
    array of times:

    - 0-3: inward_movement_start minus 4, 3, 2, 1 inward-movement durations
    - 4:   inward_movement_start
    - 5:   inward_movement_end
    - 6-9: inward_movement_end plus 1, 2, 3, 4 inward-movement durations

    Only the start/end pair (indices 4 and 5) is used. The NaN-ignoring median
    inward-movement duration ``d`` is reused as the length of all nine phases,
    so the result is the cumulative sum ``[0, d, 2d, ..., 9d]``. Index 0 is the
    reference point (zero by definition); the other indices follow the layout
    above.

    Parameters
    ----------
    df : DataFrame (or mapping)
        Must provide the 'Video segmentation times (s)' column. Its entries are
        stacked and ``.magnitude`` is taken, so they are expected to be
        unit-carrying arrays (e.g. ``quantities`` objects).
    show_summary : bool
        If True, draw a boxplot of the measured phase durations in a new figure
        and print each phase's median, mean and number of finite values.

    Returns
    -------
    numpy.ndarray
        The ten median phase boundaries, in seconds relative to index 0.
    """
    all_times = np.stack(df['Video segmentation times (s)']).magnitude
    start_and_end = all_times[:, 4:6]

    # one row per measured phase (only the inward movement), one column per swallow
    measured_durations = np.diff(start_and_end).T
    median_durations = np.nanmedian(measured_durations, axis=1)

    # the single median inward-movement duration is repeated for all nine phases
    phase_durations = np.tile(median_durations, 9)

    boundaries = np.concatenate([[0], phase_durations]).cumsum()

    if show_summary:
        summary_labels = video_seg_phase_labels[4:5]

        # plot the distributions of phase durations
        plt.figure(figsize=(4,3))
        plt.boxplot(
            [row[np.isfinite(row)] for row in measured_durations],
            labels=summary_labels,
            showmeans=True,
        )
        plt.ylim([0, None])
        plt.ylabel('Phase duration (s)')
        plt.title('Video segmentation')
        plt.tight_layout()

        # print summaries of the phase durations
        for durations, label in zip(measured_durations, summary_labels):
            label = label.replace('\n', ' ')
            print(f'{label}:\tmedian {np.nanmedian(durations):g}, mean {np.nanmean(durations):g} (n={durations[np.isfinite(durations)].size})')

    return boundaries

In [None]:
# Short labels for the nine force-segmentation phases (index 0-8) whose
# boundaries come from get_median_force_seg_phase_boundaries:
#   0: previous force maintenance       (IV, left blank)
#   1: previous major force drop        (V)
#   2: partial force maintenance        (I)
#   3: force dip                        (II)
#   4: force rise                       (III)
#   5: force maintenance                (IV)
#   6: major force drop                 (V)
#   7: next partial force maintenance   (I, left blank)
#   8: next force dip                   (II, left blank)
force_seg_phase_labels = ['', 'V', 'I', 'II', 'III', 'IV', 'V', '', '']

def get_median_force_seg_phase_boundaries(df, show_summary=False):
    """Return median force-segmentation phase boundaries on a common time axis.

    Each entry of the 'Force segmentation times (s)' column is a 10-element
    array of times:

    - 0: prev_force_plateau_start
    - 1: prev_force_plateau_end
    - 2: prev_force_drop_end
    - 3: force_shoulder_end
    - 4: force_rise_start
    - 5: force_plateau_start
    - 6: force_plateau_end
    - 7: force_drop_end
    - 8: next_force_shoulder_end
    - 9: next_force_rise_start

    The five phases between indices 2 and 7 are measured per swallow:
    partial force maintenance, force dip, force rise, force maintenance and
    major force drop. Their NaN-ignoring medians are then extended to nine
    phases. The last two medians are prepended, standing for the previous
    swallow's maintenance and drop. The first two are appended, standing for
    the next swallow's partial maintenance and dip. The cumulative sum of the
    nine durations, starting at 0, gives ten boundaries laid out like the
    indices above. Index 0 is the reference point (zero by definition).

    Parameters
    ----------
    df : DataFrame (or mapping)
        Must provide the 'Force segmentation times (s)' column. Its entries are
        stacked and ``.magnitude`` is taken, so they are expected to be
        unit-carrying arrays (e.g. ``quantities`` objects).
    show_summary : bool
        If True, draw a boxplot of the five measured phase durations in a new
        figure and print each phase's median, mean and number of finite values.

    Returns
    -------
    numpy.ndarray
        The ten median phase boundaries, in seconds relative to index 0.
    """
    all_times = np.stack(df['Force segmentation times (s)']).magnitude
    drop_to_drop = all_times[:, 2:8]

    # rows: the five measured phases; columns: swallows
    measured_durations = np.diff(drop_to_drop).T
    median_durations = np.nanmedian(measured_durations, axis=1)

    # wrap-pad by two on each side -> [d3, d4, d0, d1, d2, d3, d4, d0, d1]
    phase_durations = np.pad(median_durations, 2, mode='wrap')

    boundaries = np.concatenate([[0], phase_durations]).cumsum()

    if show_summary:
        summary_labels = force_seg_phase_labels[2:7]

        # plot the distributions of phase durations
        plt.figure(figsize=(4,3))
        plt.boxplot(
            [row[np.isfinite(row)] for row in measured_durations],
            labels=summary_labels,
            showmeans=True,
        )
        plt.ylim([0, None])
        plt.ylabel('Phase duration (s)')
        plt.title('Force segmentation')
        plt.tight_layout()

        # print summaries of the phase durations
        for durations, label in zip(measured_durations, summary_labels):
            label = label.replace('\n', ' ')
            print(f'{label}:\tmedian {np.nanmedian(durations):g}, mean {np.nanmean(durations):g} (n={durations[np.isfinite(durations)].size})')

    return boundaries

In [None]:
###
### MEDIAN PHASE BOUNDARIES
###
# Each call below uses show_summary=True, so it also opens a boxplot of the
# phase durations and prints summary statistics. The plt.title call after it
# replaces that boxplot's generic title.

# loaded swallows, force segmentation
median_loaded_force_seg_phase_boundaries = get_median_force_seg_phase_boundaries(
    df_all.query('Food == "Tape nori"'), show_summary=True)
plt.title('Force segmentation of loaded swallows')

# unloaded swallows, video segmentation (unreliable inward movements excluded)
median_unloaded_video_seg_phase_boundaries = get_median_video_seg_phase_boundaries(
    df_all.query('Food == "Regular nori"').drop(unreliable_inward_movement), show_summary=True)
plt.title('Video segmentation of unloaded swallows')

# loaded swallows, video segmentation
median_loaded_video_seg_phase_boundaries = get_median_video_seg_phase_boundaries(
    df_all.query('Food == "Tape nori"'), show_summary=True)
plt.title('Video segmentation of loaded swallows')

###
### SHARED FIGURE SETTINGS
###

fig_width = 7                 # inches
fig_left_margin = 0.10        # figure fraction (used as subplots_adjust left=)
fig_right_margin = 0.90       # figure fraction (used as subplots_adjust right=)
t_width = 9.0                 # sec, width of the plotted time window
time_before_force_rise = 4.1  # sec, window start before boundary index 4

# Each time window starts `time_before_force_rise` seconds before boundary
# index 4 and spans `t_width` seconds. Index 4 is force_rise_start for force
# segmentation and inward_movement_start for video segmentation.
xlim_loaded_force_seg, xlim_unloaded_video_seg, xlim_loaded_video_seg = (
    [
        boundaries[4] - time_before_force_rise,
        boundaries[4] - time_before_force_rise + t_width,
    ]
    for boundaries in (
        median_loaded_force_seg_phase_boundaries,
        median_unloaded_video_seg_phase_boundaries,
        median_loaded_video_seg_phase_boundaries,
    )
)

In [None]:
# Column 3 of the force segmentation times is force_shoulder_end; count the
# loaded (tape nori) swallows for which it is NaN.
shoulder_end_times = np.stack(
    df_all.query('Food == "Tape nori"')['Force segmentation times (s)'].values
).magnitude[:, 3]
print(f'{np.isnan(shoulder_end_times).sum()} out of {len(shoulder_end_times)} tape nori swallows have missing shoulders')

### 🐌 Figure 3A

Biomechanics schematic (this panel is not generated by code in this notebook).

### 🐌 Figure 3B

In [None]:
# Figure 3B: burst timing of each unit, plus the inward seaweed movement, for
# unloaded swallows. Times are placed on the median video-segmentation time
# axis via unnormalize_time (defined elsewhere; presumably it maps normalized
# times onto the given fixed boundary times). Boxes span the median start to
# the median end; whiskers mark the 25%-75% quantiles of start and end.
fig = plt.figure(figsize=(fig_width, 2.5))
ax = plt.gca()

# specify the data subset to use (unloaded swallows with reliable inward movement)
df = df_all.query('Food == "Regular nori"').drop(unreliable_inward_movement)

# print number of swallows per animal
print(df.groupby('Animal')['B6/B9 burst end (video seg normalized)'].count())
print()

# use a time plot range appropriate for the unnormalized burst timing
xlim = xlim_unloaded_video_seg

# specify boundaries where vertical lines will be drawn, and the labels between
# them: one label per phase, i.e. one fewer label than boundaries. Matplotlib
# raises ValueError if the number of tick labels differs from the number of
# ticks passed to set_xticks below.
phase_boundaries = median_unloaded_video_seg_phase_boundaries[4:6]
phase_labels = video_seg_phase_labels[4:5]

# specify times used for unnormalization
unnormalization_fixed_times = median_unloaded_video_seg_phase_boundaries

###
### UNIT BOXES
###

for i, unit in enumerate(units):
    t0_data    = df[f'{unit} burst start (video seg normalized)'].dropna()
    t0_data[:] = unnormalize_time(unnormalization_fixed_times, t0_data.values)
    t0_median  = t0_data.median()
    t0_q1      = t0_data.quantile(0.25)
    t0_q3      = t0_data.quantile(0.75)
    
    t1_data    = df[f'{unit} burst end (video seg normalized)'].dropna()
    t1_data[:] = unnormalize_time(unnormalization_fixed_times, t1_data.values)
    t1_median  = t1_data.median()
    t1_q1      = t1_data.quantile(0.25)
    t1_q3      = t1_data.quantile(0.75)
    
    # print summaries of the unit timing
    print(f'{unit}:\t[{t0_median:.2f} (n={t0_data.size}), {t1_median:.2f} (n={t1_data.size})], duration: {t1_median-t0_median:.2f}')
    
    if unit == 'B3':
        # skip plotting B3 since it rarely burst for unloaded swallows
        continue
    
    # plot boxes using medians
    height = 0.8
    lw = 1
    rect = patches.Rectangle((t0_median, i-height/2), t1_median-t0_median, height, facecolor=unit_colors[unit], edgecolor='k', lw=lw, clip_on=False)
    ax.add_patch(rect)
    
    # plot quartiles (25% and 75% quantiles)
    ax.add_line(mlines.Line2D([t0_q1, t0_q1], [i-0.2, i+0.2], color='k', lw=lw, clip_on=False))
    ax.add_line(mlines.Line2D([t0_q3, t0_q3], [i-0.2, i+0.2], color='k', lw=lw, clip_on=False))
    ax.add_line(mlines.Line2D([t1_q1, t1_q1], [i-0.2, i+0.2], color='k', lw=lw, clip_on=False))
    ax.add_line(mlines.Line2D([t1_q3, t1_q3], [i-0.2, i+0.2], color='k', lw=lw, clip_on=False))
    ax.add_line(mlines.Line2D([t0_q1, t0_q3], [i, i], color='k', lw=lw, clip_on=False))
    ax.add_line(mlines.Line2D([t1_q1, t1_q3], [i, i], color='k', lw=lw, clip_on=False))

###
### INWARD SEAWEED MOVEMENT
###

t0_data    = df['Inward movement start (video seg normalized)'].dropna()
t0_data[:] = unnormalize_time(unnormalization_fixed_times, t0_data.values)
t0_median  = t0_data.median()
t0_q1      = t0_data.quantile(0.25)
t0_q3      = t0_data.quantile(0.75)

t1_data    = df['Inward movement end (video seg normalized)'].dropna()
t1_data[:] = unnormalize_time(unnormalization_fixed_times, t1_data.values)
t1_median  = t1_data.median()
t1_q1      = t1_data.quantile(0.25)
t1_q3      = t1_data.quantile(0.75)

# print summaries of the unit timing
print(f'Inward:\t[{t0_median:.2f} (n={t0_data.size}), {t1_median:.2f} (n={t1_data.size})], duration: {t1_median-t0_median:.2f}')

# plot box using medians (the 'Inward' row sits below the last unit row)
i = len(units)
height = 0.8
lw = 1
rect = patches.Rectangle((t0_median, i-height/2), t1_median-t0_median, height, facecolor='0.75', edgecolor='k', lw=lw, clip_on=False)
ax.add_patch(rect)

# # plot quartiles (25% and 75% quantiles)
# ax.add_line(mlines.Line2D([t0_q1, t0_q1], [i-0.2, i+0.2], color='k', lw=lw, clip_on=False))
# ax.add_line(mlines.Line2D([t0_q3, t0_q3], [i-0.2, i+0.2], color='k', lw=lw, clip_on=False))
# ax.add_line(mlines.Line2D([t1_q1, t1_q1], [i-0.2, i+0.2], color='k', lw=lw, clip_on=False))
# ax.add_line(mlines.Line2D([t1_q3, t1_q3], [i-0.2, i+0.2], color='k', lw=lw, clip_on=False))
# ax.add_line(mlines.Line2D([t0_q1, t0_q3], [i, i], color='k', lw=lw, clip_on=False))
# ax.add_line(mlines.Line2D([t1_q1, t1_q3], [i, i], color='k', lw=lw, clip_on=False))

###
### LABELS AND ANNOTATIONS
###

# plot vertical lines
for t in phase_boundaries:
    ax.axvline(x=t, ls=':', lw=1, c='gray', zorder=-1)

# add phase labels on bottom edge, centered between consecutive boundaries
ax.tick_params(bottom=False) # disable tick marks
ax.set_xticks(phase_boundaries[:-1]+np.diff(phase_boundaries)/2)
ax.set_xticklabels(phase_labels, fontsize='small')

# add box labels on left edge
ax.tick_params(left=False) # disable tick marks
ax.set_yticks(range(len(units)+1))
ax.set_yticklabels(units + ['Inward'], fontsize='medium')

# add time scale bar
ax.add_artist(AnchoredScaleBar(
    ax.transData, sizex=1, labelx='1 s',
    loc='lower right', bbox_to_anchor=(1, 0), bbox_transform=ax.transAxes,
    pad=0, borderpad=0.5, sep=5, barwidth=2,
))

###
### FINISH
###

# set plot ranges; rows 0..len(units) (units plus 'Inward') run top to bottom
ax.set_xlim(xlim)
ax.set_ylim(len(units) + 0.75, -0.75)

# remove box around figure
sns.despine(bottom=True, left=True)

plt.subplots_adjust(
    left   = fig_left_margin,
    right  = fig_right_margin,
    bottom = 0.12,
    top    = 1.00,
)

fig.savefig(os.path.join(export_dir, 'figure-3B.png'), dpi=600)

### 🐌 Figure 3C

In [None]:
# Figure 3C: burst timing of each unit, plus the inward seaweed movement, for
# loaded (tape nori) swallows. Times are placed on the median
# force-segmentation time axis via unnormalize_time (defined elsewhere;
# presumably it maps normalized times onto the given fixed boundary times).
# Boxes span the median start to the median end; whiskers mark the 25%-75%
# quantiles of start and end.
fig = plt.figure(figsize=(fig_width, 2.5))
ax = plt.gca()

# specify the data subset to use (loaded swallows)
df = df_all.query('Food == "Tape nori"')

# print number of swallows per animal
print(df.groupby('Animal')['B6/B9 burst end (force seg normalized)'].count())
print()

# use a time plot range appropriate for the unnormalized burst timing
xlim = xlim_loaded_force_seg

# specify boundaries where vertical lines will be drawn, and the labels between
# them: one label per phase, i.e. one fewer label than boundaries. Matplotlib
# raises ValueError if the number of tick labels differs from the number of
# ticks passed to set_xticks below.
phase_boundaries = median_loaded_force_seg_phase_boundaries[1:8]
phase_labels = force_seg_phase_labels[1:7]

# specify times used for unnormalization
unnormalization_fixed_times = median_loaded_force_seg_phase_boundaries

###
### UNIT BOXES
###

for i, unit in enumerate(units):
    t0_data    = df[f'{unit} burst start (force seg normalized)'].dropna()
    t0_data[:] = unnormalize_time(unnormalization_fixed_times, t0_data.values)
    t0_median  = t0_data.median()
    t0_q1      = t0_data.quantile(0.25)
    t0_q3      = t0_data.quantile(0.75)
    
    t1_data    = df[f'{unit} burst end (force seg normalized)'].dropna()
    t1_data[:] = unnormalize_time(unnormalization_fixed_times, t1_data.values)
    t1_median  = t1_data.median()
    t1_q1      = t1_data.quantile(0.25)
    t1_q3      = t1_data.quantile(0.75)
    
    # print summaries of the unit timing
    print(f'{unit}:\t[{t0_median:.2f} (n={t0_data.size}), {t1_median:.2f} (n={t1_data.size})], duration: {t1_median-t0_median:.2f}')
    
    # plot boxes using medians
    height = 0.8
    lw = 1
    rect = patches.Rectangle((t0_median, i-height/2), t1_median-t0_median, height, facecolor=unit_colors[unit], edgecolor='k', lw=lw, clip_on=False)
    ax.add_patch(rect)
    
    # plot quartiles (25% and 75% quantiles)
    ax.add_line(mlines.Line2D([t0_q1, t0_q1], [i-0.2, i+0.2], color='k', lw=lw, clip_on=False))
    ax.add_line(mlines.Line2D([t0_q3, t0_q3], [i-0.2, i+0.2], color='k', lw=lw, clip_on=False))
    ax.add_line(mlines.Line2D([t1_q1, t1_q1], [i-0.2, i+0.2], color='k', lw=lw, clip_on=False))
    ax.add_line(mlines.Line2D([t1_q3, t1_q3], [i-0.2, i+0.2], color='k', lw=lw, clip_on=False))
    ax.add_line(mlines.Line2D([t0_q1, t0_q3], [i, i], color='k', lw=lw, clip_on=False))
    ax.add_line(mlines.Line2D([t1_q1, t1_q3], [i, i], color='k', lw=lw, clip_on=False))

###
### INWARD SEAWEED MOVEMENT
###

t0_data    = df['Inward movement start (force seg normalized)'].dropna()
t0_data[:] = unnormalize_time(unnormalization_fixed_times, t0_data.values)
t0_median  = t0_data.median()
t0_q1      = t0_data.quantile(0.25)
t0_q3      = t0_data.quantile(0.75)

t1_data    = df['Inward movement end (force seg normalized)'].dropna()
t1_data[:] = unnormalize_time(unnormalization_fixed_times, t1_data.values)
t1_median  = t1_data.median()
t1_q1      = t1_data.quantile(0.25)
t1_q3      = t1_data.quantile(0.75)

# print summaries of the unit timing
print(f'Inward:\t[{t0_median:.2f} (n={t0_data.size}), {t1_median:.2f} (n={t1_data.size})], duration: {t1_median-t0_median:.2f}')

# plot box using medians (the 'Inward' row sits below the last unit row)
i = len(units)
height = 0.8
lw = 1
rect = patches.Rectangle((t0_median, i-height/2), t1_median-t0_median, height, facecolor='0.75', edgecolor='k', lw=lw, clip_on=False)
ax.add_patch(rect)

# plot quartiles (25% and 75% quantiles)
ax.add_line(mlines.Line2D([t0_q1, t0_q1], [i-0.2, i+0.2], color='k', lw=lw, clip_on=False))
ax.add_line(mlines.Line2D([t0_q3, t0_q3], [i-0.2, i+0.2], color='k', lw=lw, clip_on=False))
ax.add_line(mlines.Line2D([t1_q1, t1_q1], [i-0.2, i+0.2], color='k', lw=lw, clip_on=False))
ax.add_line(mlines.Line2D([t1_q3, t1_q3], [i-0.2, i+0.2], color='k', lw=lw, clip_on=False))
ax.add_line(mlines.Line2D([t0_q1, t0_q3], [i, i], color='k', lw=lw, clip_on=False))
ax.add_line(mlines.Line2D([t1_q1, t1_q3], [i, i], color='k', lw=lw, clip_on=False))

###
### LABELS AND ANNOTATIONS
###

# plot vertical lines
for t in phase_boundaries:
    ax.axvline(x=t, ls=':', lw=1, c='gray', zorder=-1)

# add phase labels on bottom edge, centered between consecutive boundaries
ax.tick_params(bottom=False) # disable tick marks
ax.set_xticks(phase_boundaries[:-1]+np.diff(phase_boundaries)/2)
ax.set_xticklabels(phase_labels, fontsize='small')

# add box labels on left edge
ax.tick_params(left=False) # disable tick marks
ax.set_yticks(range(len(units)+1))
ax.set_yticklabels(units + ['Inward'], fontsize='medium')

# add time scale bar
ax.add_artist(AnchoredScaleBar(
    ax.transData, sizex=1, labelx='1 s',
    loc='lower right', bbox_to_anchor=(1, 0), bbox_transform=ax.transAxes,
    pad=0, borderpad=0.5, sep=5, barwidth=2,
))

###
### FINISH
###

# set plot ranges; rows 0..len(units) (units plus 'Inward') run top to bottom
ax.set_xlim(xlim)
ax.set_ylim(len(units) + 0.75, -0.75)

# remove box around figure
sns.despine(bottom=True, left=True)

plt.subplots_adjust(
    left   = fig_left_margin,
    right  = fig_right_margin,
    bottom = 0.12,
    top    = 1.00,
)

fig.savefig(os.path.join(export_dir, 'figure-3C.png'), dpi=600)

### 🐌 Figure 3D

In [None]:
# Figure 3D: median (solid) and 25%/75% quantile (dashed) firing-rate traces of
# each unit for unloaded swallows, one panel per unit. The traces are placed on
# the median video-segmentation time axis. An 'Inward' box above the first
# panel marks the median inward seaweed movement.
nrows = len(units)
subplot_height_in_inches = 412/600
top_margin_in_inches = 300/600
bottom_margin_in_inches = 180/600
hspace = 0.2

fig_height_in_inches, top_fraction, bottom_fraction = solve_figure_vertical_dimensions(
    nrows, subplot_height_in_inches, top_margin_in_inches, bottom_margin_in_inches, hspace)

fig, axes = plt.subplots(nrows, 1, sharex='col', figsize=(fig_width, fig_height_in_inches))

# specify the data subset to use (unloaded swallows with reliable inward movement)
df = df_all.query('Food == "Regular nori"').drop(unreliable_inward_movement)

# print number of swallows per animal
print(df.groupby('Animal')['B6/B9 firing rate, video segmented interpolation (Hz)'].count())
print()

# use a time plot range appropriate for the unnormalized burst timing
xlim = xlim_unloaded_video_seg

# specify boundaries where vertical lines will be drawn, and the labels between
# them: one label per phase, i.e. one fewer label than boundaries. Matplotlib
# raises ValueError if the number of tick labels differs from the number of
# ticks passed to set_xticks below.
phase_boundaries = median_unloaded_video_seg_phase_boundaries[4:6]
phase_labels = video_seg_phase_labels[4:5]

# specify times used for unnormalization
unnormalization_fixed_times = median_unloaded_video_seg_phase_boundaries

# use these time values for plotting interpolated functions in unnormalized time
interp_times_unnormalized = unnormalize_time(unnormalization_fixed_times, video_seg_interp_times)

# specify dataframe columns to plot
df_columns = {unit: f'{unit} firing rate, video segmented interpolation (Hz)' for unit in units}
# df_columns['Force'] = 'Force, video segmented interpolation (mN)'

###
### UNIT FREQUENCIES AND FORCE
###

for i, (label, column) in enumerate(df_columns.items()):
    ax = axes[i]

    # find the median and quartiles across swallows (NaNs ignored)
    data   = np.stack(df[column])
    median = np.nanmedian(data, axis=0)
    q1     = np.nanquantile(data, q=0.25, axis=0)
    q3     = np.nanquantile(data, q=0.75, axis=0)

    # plot the median and quartiles (median last so it's on top)
    ax.plot(interp_times_unnormalized, q1,     c=unit_colors[label], lw=1, ls='--', zorder=2)
    ax.plot(interp_times_unnormalized, q3,     c=unit_colors[label], lw=1, ls='--', zorder=2)
    ax.plot(interp_times_unnormalized, median, c=unit_colors[label], lw=2, zorder=2)

###
### INWARD SEAWEED MOVEMENT
###

t0_data    = df['Inward movement start (video seg normalized)'].dropna()
t0_data[:] = unnormalize_time(unnormalization_fixed_times, t0_data.values)
t0_median  = t0_data.median()
t0_q1      = t0_data.quantile(0.25)
t0_q3      = t0_data.quantile(0.75)

t1_data    = df['Inward movement end (video seg normalized)'].dropna()
t1_data[:] = unnormalize_time(unnormalization_fixed_times, t1_data.values)
t1_median  = t1_data.median()
t1_q1      = t1_data.quantile(0.25)
t1_q3      = t1_data.quantile(0.75)

ax = axes[0]
y_inward = 1.3 # axes fractions above first panel

# plot label
ax.annotate(
    'Inward',
    xy=(0, y_inward), xycoords='axes fraction',
    xytext=(-10, 0), textcoords='offset points',
    ha='right', va='center',
)

# plot box using medians (x in data coords, y in axes fraction)
height = 0.2
lw = 1
rect = patches.Rectangle((t0_median, y_inward-height/2), t1_median-t0_median, height, transform=ax.get_xaxis_transform(), facecolor='0.75', edgecolor='k', lw=lw, clip_on=False)
ax.add_patch(rect)

# # plot quartiles (25% and 75% quantiles)
# whisker_height = 0.1
# ax.add_line(mlines.Line2D([t0_q1, t0_q1], [y_inward-whisker_height/2, y_inward+whisker_height/2], transform=ax.get_xaxis_transform(), color='k', lw=lw, clip_on=False))
# ax.add_line(mlines.Line2D([t0_q3, t0_q3], [y_inward-whisker_height/2, y_inward+whisker_height/2], transform=ax.get_xaxis_transform(), color='k', lw=lw, clip_on=False))
# ax.add_line(mlines.Line2D([t1_q1, t1_q1], [y_inward-whisker_height/2, y_inward+whisker_height/2], transform=ax.get_xaxis_transform(), color='k', lw=lw, clip_on=False))
# ax.add_line(mlines.Line2D([t1_q3, t1_q3], [y_inward-whisker_height/2, y_inward+whisker_height/2], transform=ax.get_xaxis_transform(), color='k', lw=lw, clip_on=False))
# ax.add_line(mlines.Line2D([t0_q1, t0_q3], [y_inward, y_inward], transform=ax.get_xaxis_transform(), color='k', lw=lw, clip_on=False))
# ax.add_line(mlines.Line2D([t1_q1, t1_q3], [y_inward, y_inward], transform=ax.get_xaxis_transform(), color='k', lw=lw, clip_on=False))

###
### LABELS AND ANNOTATIONS
###

ax = axes[-1]

# plot vertical lines on the bottom axes, extended upward (clip_on=False).
# ymax is in bottom-axes fraction. It spans the nrows panels and the nrows-1
# gaps between them (hspace is in units of axes height), plus a further
# (y_inward-1)*2 above the first panel, which reaches past the 'Inward' box.
ymax = nrows + (nrows-1)*hspace + (y_inward-1)*2
for t in phase_boundaries:
    ax.axvline(x=t, ymax=ymax, ls=':', lw=1, c='gray', zorder=-1, clip_on=False) # axvline's dotted line style is better than ConnectionPath's
for i in range(len(df_columns)-1):
    axes[i].set_zorder(1) # elevate axes so vertical lines are behind it
    axes[i].set_facecolor('none') # remove background so vertical line is visible

# plot horizontal grid lines
for i in range(len(df_columns)):
    axes[i].grid(axis='y', clip_on=False)

# add phase labels on bottom edge, centered between consecutive boundaries
for i in range(len(df_columns)):
    axes[i].tick_params(bottom=False) # disable tick marks
ax.set_xticks(phase_boundaries[:-1]+np.diff(phase_boundaries)/2)
ax.set_xticklabels(phase_labels, fontsize='small')

# add trace labels on left edge
for i, label in enumerate(df_columns):
    axes[i].set_ylabel(label, rotation='horizontal', ha='right', va='center', labelpad=10, fontsize='medium')

# add yticks on right edge
yticks = {
    'B38':   [0,  10,  20],
    'I2':    [0,  10,  20],
    'B8a/b': [0,  20,  40],
    'B6/B9': [0,  25,  50],
    'B3':    [0,   5,  10],
    'B4/B5': [0,  10,  20],
    'Force': [0, 150, 300],
}
for i, (label, column) in enumerate(df_columns.items()):
    axes[i].tick_params(right=True, labelright=True) # enable tick marks and labels
    axes[i].set_ylim([min(yticks[label]), max(yticks[label])])
    axes[i].set_yticks(yticks[label])
    yticklabels = [f'{y:g}' for y in axes[i].get_yticks()]
    yunits = column.split()[-1].strip('()') # e.g. 'Hz' from '... (Hz)'
    yticklabels[-1] += f' {yunits}' # append units to max label
    axes[i].set_yticklabels(yticklabels)

# add time scale bar
ax.add_artist(AnchoredScaleBar(
    ax.transData, sizex=1, labelx='1 s',
    loc='upper right', bbox_to_anchor=(1, 1), bbox_transform=ax.transAxes,
    pad=0, borderpad=0.5, sep=5, barwidth=2,
))

###
### FINISH
###

# set time range
ax.set_xlim(xlim)

# remove left and top plot borders
sns.despine(fig=fig, left=True, right=False)

plt.subplots_adjust(
    left   = fig_left_margin,
    right  = fig_right_margin,
    bottom = bottom_fraction,
    top    = top_fraction,
    hspace = hspace,
)

fig.savefig(os.path.join(export_dir, 'figure-3D.png'), dpi=600)

### 🐌 Figure 3E

In [None]:
# Figure 3E: median (solid) and 25%/75% quantile (dashed) traces for loaded
# (tape nori) swallows, one panel per unit firing rate plus one for force. The
# traces are placed on the median force-segmentation time axis. An 'Inward'
# box with quartile whiskers above the first panel marks the inward seaweed
# movement.
nrows = len(units)+1
subplot_height_in_inches = 412/600
top_margin_in_inches = 300/600
bottom_margin_in_inches = 180/600
hspace = 0.2

fig_height_in_inches, top_fraction, bottom_fraction = solve_figure_vertical_dimensions(
    nrows, subplot_height_in_inches, top_margin_in_inches, bottom_margin_in_inches, hspace)

fig, axes = plt.subplots(nrows, 1, sharex='col', figsize=(fig_width, fig_height_in_inches))

# specify the data subset to use (loaded swallows)
df = df_all.query('Food == "Tape nori"')

# print number of swallows per animal
print(df.groupby('Animal')['B6/B9 firing rate, force segmented interpolation (Hz)'].count())
print()

# use a time plot range appropriate for the unnormalized burst timing
xlim = xlim_loaded_force_seg

# specify boundaries where vertical lines will be drawn, and the labels between
# them: one label per phase, i.e. one fewer label than boundaries. Matplotlib
# raises ValueError if the number of tick labels differs from the number of
# ticks passed to set_xticks below.
phase_boundaries = median_loaded_force_seg_phase_boundaries[1:8]
phase_labels = force_seg_phase_labels[1:7]

# specify times used for unnormalization
unnormalization_fixed_times = median_loaded_force_seg_phase_boundaries

# use these time values for plotting interpolated functions in unnormalized time
interp_times_unnormalized = unnormalize_time(unnormalization_fixed_times, force_seg_interp_times)

# specify dataframe columns to plot
df_columns = {unit: f'{unit} firing rate, force segmented interpolation (Hz)' for unit in units}
df_columns['Force'] = 'Force, force segmented interpolation (mN)'

###
### UNIT FREQUENCIES AND FORCE
###

for i, (label, column) in enumerate(df_columns.items()):
    ax = axes[i]

    # find the median and quartiles across swallows (NaNs ignored)
    data   = np.stack(df[column])
    median = np.nanmedian(data, axis=0)
    q1     = np.nanquantile(data, q=0.25, axis=0)
    q3     = np.nanquantile(data, q=0.75, axis=0)

    # plot the median and quartiles (median last so it's on top)
    ax.plot(interp_times_unnormalized, q1,     c=unit_colors[label], lw=1, ls='--', zorder=2)
    ax.plot(interp_times_unnormalized, q3,     c=unit_colors[label], lw=1, ls='--', zorder=2)
    ax.plot(interp_times_unnormalized, median, c=unit_colors[label], lw=2, zorder=2)

###
### INWARD SEAWEED MOVEMENT
###

t0_data    = df['Inward movement start (force seg normalized)'].dropna()
t0_data[:] = unnormalize_time(unnormalization_fixed_times, t0_data.values)
t0_median  = t0_data.median()
t0_q1      = t0_data.quantile(0.25)
t0_q3      = t0_data.quantile(0.75)

t1_data    = df['Inward movement end (force seg normalized)'].dropna()
t1_data[:] = unnormalize_time(unnormalization_fixed_times, t1_data.values)
t1_median  = t1_data.median()
t1_q1      = t1_data.quantile(0.25)
t1_q3      = t1_data.quantile(0.75)

ax = axes[0]
y_inward = 1.3 # axes fractions above first panel

# plot label
ax.annotate(
    'Inward',
    xy=(0, y_inward), xycoords='axes fraction',
    xytext=(-10, 0), textcoords='offset points',
    ha='right', va='center',
)

# plot box using medians (x in data coords, y in axes fraction)
height = 0.2
lw = 1
rect = patches.Rectangle((t0_median, y_inward-height/2), t1_median-t0_median, height, transform=ax.get_xaxis_transform(), facecolor='0.75', edgecolor='k', lw=lw, clip_on=False)
ax.add_patch(rect)

# plot quartiles (25% and 75% quantiles)
whisker_height = 0.1
ax.add_line(mlines.Line2D([t0_q1, t0_q1], [y_inward-whisker_height/2, y_inward+whisker_height/2], transform=ax.get_xaxis_transform(), color='k', lw=lw, clip_on=False))
ax.add_line(mlines.Line2D([t0_q3, t0_q3], [y_inward-whisker_height/2, y_inward+whisker_height/2], transform=ax.get_xaxis_transform(), color='k', lw=lw, clip_on=False))
ax.add_line(mlines.Line2D([t1_q1, t1_q1], [y_inward-whisker_height/2, y_inward+whisker_height/2], transform=ax.get_xaxis_transform(), color='k', lw=lw, clip_on=False))
ax.add_line(mlines.Line2D([t1_q3, t1_q3], [y_inward-whisker_height/2, y_inward+whisker_height/2], transform=ax.get_xaxis_transform(), color='k', lw=lw, clip_on=False))
ax.add_line(mlines.Line2D([t0_q1, t0_q3], [y_inward, y_inward], transform=ax.get_xaxis_transform(), color='k', lw=lw, clip_on=False))
ax.add_line(mlines.Line2D([t1_q1, t1_q3], [y_inward, y_inward], transform=ax.get_xaxis_transform(), color='k', lw=lw, clip_on=False))

###
### LABELS AND ANNOTATIONS
###

ax = axes[-1]

# plot vertical lines on the bottom axes, extended upward (clip_on=False).
# ymax is in bottom-axes fraction. It spans the nrows panels and the nrows-1
# gaps between them (hspace is in units of axes height), plus a further
# (y_inward-1)*2 above the first panel, which reaches past the 'Inward' box.
ymax = nrows + (nrows-1)*hspace + (y_inward-1)*2
for t in phase_boundaries:
    ax.axvline(x=t, ymax=ymax, ls=':', lw=1, c='gray', zorder=-1, clip_on=False) # axvline's dotted line style is better than ConnectionPath's
for i in range(len(df_columns)-1):
    axes[i].set_zorder(1) # elevate axes so vertical lines are behind it
    axes[i].set_facecolor('none') # remove background so vertical line is visible

# plot horizontal grid lines
for i in range(len(df_columns)):
    axes[i].grid(axis='y', clip_on=False)

# add phase labels on bottom edge, centered between consecutive boundaries
for i in range(len(df_columns)):
    axes[i].tick_params(bottom=False) # disable tick marks
ax.set_xticks(phase_boundaries[:-1]+np.diff(phase_boundaries)/2)
ax.set_xticklabels(phase_labels, fontsize='small')

# add trace labels on left edge
for i, label in enumerate(df_columns):
    axes[i].set_ylabel(label, rotation='horizontal', ha='right', va='center', labelpad=10, fontsize='medium')

# add yticks on right edge
yticks = {
    'B38':   [0,  10,  20],
    'I2':    [0,  10,  20],
    'B8a/b': [0,  20,  40],
    'B6/B9': [0,  25,  50],
    'B3':    [0,   5,  10],
    'B4/B5': [0,  10,  20],
    'Force': [0, 150, 300],
}
for i, (label, column) in enumerate(df_columns.items()):
    axes[i].tick_params(right=True, labelright=True) # enable tick marks and labels
    axes[i].set_ylim([min(yticks[label]), max(yticks[label])])
    axes[i].set_yticks(yticks[label])
    yticklabels = [f'{y:g}' for y in axes[i].get_yticks()]
    yunits = column.split()[-1].strip('()') # e.g. 'Hz' or 'mN' from '... (Hz)'
    yticklabels[-1] += f' {yunits}' # append units to max label
    axes[i].set_yticklabels(yticklabels)

# add time scale bar
ax.add_artist(AnchoredScaleBar(
    ax.transData, sizex=1, labelx='1 s',
    loc='upper right', bbox_to_anchor=(1, 1), bbox_transform=ax.transAxes,
    pad=0, borderpad=0.5, sep=5, barwidth=2,
))

###
### FINISH
###

# set time range
ax.set_xlim(xlim)

# remove left and top plot borders
sns.despine(fig=fig, left=True, right=False)

plt.subplots_adjust(
    left   = fig_left_margin,
    right  = fig_right_margin,
    bottom = bottom_fraction,
    top    = top_fraction,
    hspace = hspace,
)

fig.savefig(os.path.join(export_dir, 'figure-3E.png'), dpi=600)

### 🐌 Figure 3C (alt)

In [None]:
fig = plt.figure(figsize=(fig_width, 2.5))
ax = plt.gca()

# Figure 3C (alt): for loaded (tape nori) swallows, one row per unit showing a box
# from the median burst start to the median burst end (with 25%/75% quantile
# whiskers at each end), plus a final 'Inward' row for inward seaweed movement.
# All times are mapped from video-segmented normalized time back to seconds.

# specify the data subset to use
df = df_all.query('Food == "Tape nori"')

# print number of swallows per animal
print(df.groupby('Animal')['B6/B9 burst end (video seg normalized)'].count())
print()

# use a time plot range appropriate for the unnormalized burst timing
xlim = xlim_loaded_video_seg

# specify boundaries where vertical lines will be drawn, and the labels between them
phase_boundaries = median_loaded_video_seg_phase_boundaries[4:6]
phase_labels = video_seg_phase_labels[4:6]

# specify times used for unnormalization
unnormalization_fixed_times = median_loaded_video_seg_phase_boundaries


def _unnormalized_timing(values, fixed_times):
    """Drop NaNs from *values*, map them back to seconds, and summarize them.

    Parameters
    ----------
    values : pandas.Series
        Event times in video-segmented normalized time.
    fixed_times : array-like
        Phase-boundary times passed to ``unnormalize_time``.

    Returns
    -------
    tuple
        ``(data, median, q1, q3)``: the unnormalized non-NaN values, their
        median, and their 25% and 75% quantiles.
    """
    data = values.dropna()
    data[:] = unnormalize_time(fixed_times, data.values)
    return data, data.median(), data.quantile(0.25), data.quantile(0.75)


def _add_quartile_whiskers(ax, row, start_quartiles, end_quartiles, lw):
    """Draw quartile ticks and q1-q3 connecting lines at both ends of a box in *row*."""
    # short vertical ticks at each quartile (start q1, start q3, end q1, end q3)
    for t in (*start_quartiles, *end_quartiles):
        ax.add_line(mlines.Line2D([t, t], [row-0.2, row+0.2], color='k', lw=lw, clip_on=False))
    # horizontal line joining q1 to q3 at the start, then at the end
    for q_lo, q_hi in (start_quartiles, end_quartiles):
        ax.add_line(mlines.Line2D([q_lo, q_hi], [row, row], color='k', lw=lw, clip_on=False))


###
### UNIT BOXES
###

for i, unit in enumerate(units):
    t0_data, t0_median, t0_q1, t0_q3 = _unnormalized_timing(
        df[f'{unit} burst start (video seg normalized)'], unnormalization_fixed_times)
    t1_data, t1_median, t1_q1, t1_q3 = _unnormalized_timing(
        df[f'{unit} burst end (video seg normalized)'], unnormalization_fixed_times)

    # print summaries of the unit timing
    print(f'{unit}:\t[{t0_median:.2f} (n={t0_data.size}), {t1_median:.2f} (n={t1_data.size})], duration: {t1_median-t0_median:.2f}')

    # plot boxes using medians
    height = 0.8
    lw = 1
    rect = patches.Rectangle((t0_median, i-height/2), t1_median-t0_median, height, facecolor=unit_colors[unit], edgecolor='k', lw=lw, clip_on=False)
    ax.add_patch(rect)

    # plot quartiles (25% and 75% quantiles)
    _add_quartile_whiskers(ax, i, (t0_q1, t0_q3), (t1_q1, t1_q3), lw)

###
### INWARD SEAWEED MOVEMENT
###

t0_data, t0_median, t0_q1, t0_q3 = _unnormalized_timing(
    df['Inward movement start (video seg normalized)'], unnormalization_fixed_times)
t1_data, t1_median, t1_q1, t1_q3 = _unnormalized_timing(
    df['Inward movement end (video seg normalized)'], unnormalization_fixed_times)

# print summaries of the unit timing
print(f'Inward:\t[{t0_median:.2f} (n={t0_data.size}), {t1_median:.2f} (n={t1_data.size})], duration: {t1_median-t0_median:.2f}')

# plot box using medians (the Inward row sits just below the last unit row)
i = len(units)
height = 0.8
lw = 1
rect = patches.Rectangle((t0_median, i-height/2), t1_median-t0_median, height, facecolor='0.75', edgecolor='k', lw=lw, clip_on=False)
ax.add_patch(rect)

# quartile whiskers are not drawn for inward movement; uncomment to show them
# _add_quartile_whiskers(ax, i, (t0_q1, t0_q3), (t1_q1, t1_q3), lw)

###
### LABELS AND ANNOTATIONS
###

# plot vertical lines
for t in phase_boundaries:
    ax.axvline(x=t, ls=':', lw=1, c='gray', zorder=-1)

# add phase labels on bottom edge (centered between consecutive boundaries)
ax.tick_params(bottom=False) # disable tick marks
ax.set_xticks(phase_boundaries[:-1]+np.diff(phase_boundaries)/2)
ax.set_xticklabels(phase_labels, fontsize='small')

# add box labels on left edge
ax.tick_params(left=False) # disable tick marks
ax.set_yticks(range(len(units)+1))
ax.set_yticklabels(units + ['Inward'], fontsize='medium')

# add time scale bar
ax.add_artist(AnchoredScaleBar(
    ax.transData, sizex=1, labelx='1 s',
    loc='lower right', bbox_to_anchor=(1, 0), bbox_transform=ax.transAxes,
    pad=0, borderpad=0.5, sep=5, barwidth=2,
))

###
### FINISH
###

# set plot ranges; rows are 0..len(units) (units then Inward) with a 0.75 margin,
# and the y-axis is inverted so the first unit is on top
ax.set_xlim(xlim)
ax.set_ylim(len(units) + 0.75, -0.75)

# remove box around figure
sns.despine(bottom=True, left=True)

plt.subplots_adjust(
    left   = fig_left_margin,
    right  = fig_right_margin,
    bottom = 0.12,
    top    = 1.00,
)

fig.savefig(os.path.join(export_dir, 'figure-3C-video-seg.png'), dpi=600)

### 🐌 Figure 3E (alt)

In [None]:
# Figure 3E (alt): for loaded (tape nori) swallows, one panel per unit firing rate
# plus one force panel, each showing the across-swallow median (solid) and
# 25%/75% quantiles (dashed) against unnormalized time. The median inward seaweed
# movement is drawn as a box above the top panel.

# figure geometry: one row per entry of df_columns below (all units + 'Force')
nrows = len(units)+1
subplot_height_in_inches = 412/600
top_margin_in_inches = 300/600
bottom_margin_in_inches = 180/600
hspace = 0.2

# NOTE(review): solve_figure_vertical_dimensions is defined elsewhere; presumably it
# returns the total figure height and the top/bottom fractions for subplots_adjust
# that give each panel subplot_height_in_inches — confirm
fig_height_in_inches, top_fraction, bottom_fraction = solve_figure_vertical_dimensions(
    nrows, subplot_height_in_inches, top_margin_in_inches, bottom_margin_in_inches, hspace)

fig, axes = plt.subplots(nrows, 1, sharex='col', figsize=(fig_width, fig_height_in_inches))

# specify the data subset to use
df = df_all.query('Food == "Tape nori"')

# print number of swallows per animal
print(df.groupby('Animal')['B6/B9 firing rate, video segmented interpolation (Hz)'].count())
print()

# use a time plot range appropriate for the unnormalized burst timing
xlim = xlim_loaded_video_seg

# specify boundaries where vertical lines will be drawn, and the labels between them
phase_boundaries = median_loaded_video_seg_phase_boundaries[4:6]
phase_labels = video_seg_phase_labels[4:6]

# specify times used for unnormalization
unnormalization_fixed_times = median_loaded_video_seg_phase_boundaries

# use these time values for plotting interpolated functions in unnormalized time
interp_times_unnormalized = unnormalize_time(unnormalization_fixed_times, video_seg_interp_times)

# specify dataframe columns to plot
# (panel label -> column; insertion order defines the top-to-bottom panel order)
df_columns = {unit: f'{unit} firing rate, video segmented interpolation (Hz)' for unit in units}
df_columns['Force'] = 'Force, video segmented interpolation (mN)'

###
### UNIT FREQUENCIES AND FORCE
###

for i, (label, column) in enumerate(df_columns.items()):
    ax = axes[i]

    # find the median and quartiles
    # assumes each cell of df[column] is a 1-D array sampled at video_seg_interp_times,
    # so data is (n_swallows, n_times) and statistics are taken across swallows — confirm
    data   = np.stack(df[column])
    median = np.nanmedian(data, axis=0)
    q1     = np.nanquantile(data, q=0.25, axis=0)
    q3     = np.nanquantile(data, q=0.75, axis=0)

    # plot the median and quartiles (median last so it's on top)
    ax.plot(interp_times_unnormalized, q1,     c=unit_colors[label], lw=1, ls='--', zorder=2)
    ax.plot(interp_times_unnormalized, q3,     c=unit_colors[label], lw=1, ls='--', zorder=2)
    ax.plot(interp_times_unnormalized, median, c=unit_colors[label], lw=2, zorder=2)

###
### INWARD SEAWEED MOVEMENT
###

# start/end times of inward movement, mapped back to seconds; the q1/q3 values are
# only used by the commented-out whisker code below
t0_data    = df['Inward movement start (video seg normalized)'].dropna()
t0_data[:] = unnormalize_time(unnormalization_fixed_times, t0_data.values)
t0_median  = t0_data.median()
t0_q1      = t0_data.quantile(0.25)
t0_q3      = t0_data.quantile(0.75)

t1_data    = df['Inward movement end (video seg normalized)'].dropna()
t1_data[:] = unnormalize_time(unnormalization_fixed_times, t1_data.values)
t1_median  = t1_data.median()
t1_q1      = t1_data.quantile(0.25)
t1_q3      = t1_data.quantile(0.75)

ax = axes[0]
y_inward = 1.3 # axes fractions above first panel

# plot label
ax.annotate(
    'Inward',
    xy=(0, y_inward), xycoords='axes fraction',
    xytext=(-10, 0), textcoords='offset points',
    ha='right', va='center',
)

# plot box using medians
# get_xaxis_transform(): x in data coordinates (seconds), y in axes fraction of axes[0]
height = 0.2
lw = 1
rect = patches.Rectangle((t0_median, y_inward-height/2), t1_median-t0_median, height, transform=ax.get_xaxis_transform(), facecolor='0.75', edgecolor='k', lw=lw, clip_on=False)
ax.add_patch(rect)

# # plot quartiles (25% and 75% quantiles)
# whisker_height = 0.1
# ax.add_line(mlines.Line2D([t0_q1, t0_q1], [y_inward-whisker_height/2, y_inward+whisker_height/2], transform=ax.get_xaxis_transform(), color='k', lw=lw, clip_on=False))
# ax.add_line(mlines.Line2D([t0_q3, t0_q3], [y_inward-whisker_height/2, y_inward+whisker_height/2], transform=ax.get_xaxis_transform(), color='k', lw=lw, clip_on=False))
# ax.add_line(mlines.Line2D([t1_q1, t1_q1], [y_inward-whisker_height/2, y_inward+whisker_height/2], transform=ax.get_xaxis_transform(), color='k', lw=lw, clip_on=False))
# ax.add_line(mlines.Line2D([t1_q3, t1_q3], [y_inward-whisker_height/2, y_inward+whisker_height/2], transform=ax.get_xaxis_transform(), color='k', lw=lw, clip_on=False))
# ax.add_line(mlines.Line2D([t0_q1, t0_q3], [y_inward, y_inward], transform=ax.get_xaxis_transform(), color='k', lw=lw, clip_on=False))
# ax.add_line(mlines.Line2D([t1_q1, t1_q3], [y_inward, y_inward], transform=ax.get_xaxis_transform(), color='k', lw=lw, clip_on=False))

###
### LABELS AND ANNOTATIONS
###

# from here on, `ax` is the bottom panel (x tick labels, scale bar, xlim)
ax = axes[-1]

# plot vertical lines
# The lines belong to the bottom panel but extend upward through every panel.
# ymax is in axes-fraction units of that panel: nrows + (nrows-1)*hspace spans all
# equal-height panels plus the gaps between them (hspace is a fraction of the average
# axes height), and (y_inward-1)*2 extends the line above the top panel by twice the
# Inward box's offset. The value does not depend on t.
for t in phase_boundaries:
    ymax = nrows + (nrows-1)*hspace + (y_inward-1)*2
    ax.axvline(x=t, ymax=ymax, ls=':', lw=1, c='gray', zorder=-1, clip_on=False) # axvline's dotted line style is better than ConnectionPath's
for i in range(len(df_columns)-1):
    axes[i].set_zorder(1) # elevate axes so vertical lines are behind it
    axes[i].set_facecolor('none') # remove background so vertical line is visible

# plot horizontal grid lines
for i in range(len(df_columns)):
    axes[i].grid(axis='y', clip_on=False)

# add phase labels on bottom edge (centered between consecutive boundaries)
for i in range(len(df_columns)):
    axes[i].tick_params(bottom=False) # disable tick marks
ax.set_xticks(phase_boundaries[:-1]+np.diff(phase_boundaries)/2)
ax.set_xticklabels(phase_labels, fontsize='small')

# add trace labels on left edge
for i, label in enumerate(df_columns):
    axes[i].set_ylabel(label, rotation='horizontal', ha='right', va='center', labelpad=10, fontsize='medium')

# add yticks on right edge
# fixed tick values per panel; each panel's y-range is the min/max of its list
yticks = {
    'B38':   [0,  10,  20],
    'I2':    [0,  10,  20],
    'B8a/b': [0,  20,  40],
    'B6/B9': [0,  25,  50],
    'B3':    [0,   5,  10],
    'B4/B5': [0,  10,  20],
    'Force': [0, 150, 300],
}
for i, (label, column) in enumerate(df_columns.items()):
    axes[i].tick_params(right=True, labelright=True) # enable tick marks and labels
    axes[i].set_ylim([min(yticks[label]), max(yticks[label])])
    axes[i].set_yticks(yticks[label])
    yticklabels = [f'{y:g}' for y in axes[i].get_yticks()]
    # 'Hz' or 'mN': last word of the column name with its parentheses stripped
    yunits = column.split()[-1].strip('()')
    yticklabels[-1] += f' {yunits}' # append units to max label
    axes[i].set_yticklabels(yticklabels)

# add time scale bar
ax.add_artist(AnchoredScaleBar(
    ax.transData, sizex=1, labelx='1 s',
    loc='upper right', bbox_to_anchor=(1, 1), bbox_transform=ax.transAxes,
    pad=0, borderpad=0.5, sep=5, barwidth=2,
))

###
### FINISH
###

# set time range (axes share x, so this applies to every panel)
ax.set_xlim(xlim)

# remove left and top plot borders
sns.despine(fig=fig, left=True, right=False)

plt.subplots_adjust(
    left   = fig_left_margin,
    right  = fig_right_margin,
    bottom = bottom_fraction,
    top    = top_fraction,
    hspace = hspace,
)

fig.savefig(os.path.join(export_dir, 'figure-3E-video-seg.png'), dpi=600)

## [FIGURE 4]

### 🐌 Figure 4 Statistics

#### Burst duration

##### Protraction-phase Hotelling's T-squared

In [None]:
# Paired Hotelling's T-squared test on per-animal mean protraction-phase burst
# durations (B38, I2): unloaded (regular nori) vs. loaded (tape nori) swallows.
columns = ['B38 burst duration (s)', 'I2 burst duration (s)']

df = df_all.drop(bite_swallow_behaviors)
x = df.query('Food == "Regular nori"')[columns].groupby('Animal').mean()
y = df.query('Food == "Tape nori"')[columns].groupby('Animal').mean()
# pair conditions by animal label rather than by row position, so an animal missing
# from one condition cannot silently shift the pairing (only shared animals are kept)
x, y = x.align(y, join='inner', axis=0)
x, y = x.values, y.values
result = r_hotelling_T2_test(x-y)
print(f"T^2 = {result['T2']:.3f}, F = {result['F']:.3f}, df_num = {result['df_num']}, df_den = {result['df_den']}, p = {result['p']:.4f}")

"A paired-samples Hotelling’s T-squared test indicated no significant difference in durations of mean protraction-phase motor activity (mean B38 and mean I2 burst durations) in the five animals between loaded and unloaded swallows (T^2 = XXX, F = XXX, df_num = 2, df_den = 3, p = XXX)."

##### Retraction-phase Hotelling's T-squared

In [None]:
# Paired Hotelling's T-squared test on per-animal mean retraction-phase burst
# durations (B8a/b, B3/B6/B9): unloaded (regular nori) vs. loaded (tape nori).
columns = ['B8a/b burst duration (s)', 'B3/B6/B9 burst duration (s)']

df = df_all.drop(bite_swallow_behaviors)
x = df.query('Food == "Regular nori"')[columns].groupby('Animal').mean()
y = df.query('Food == "Tape nori"')[columns].groupby('Animal').mean()
# pair conditions by animal label rather than by row position, so an animal missing
# from one condition cannot silently shift the pairing (only shared animals are kept)
x, y = x.align(y, join='inner', axis=0)
x, y = x.values, y.values
result = r_hotelling_T2_test(x-y)
print(f"T^2 = {result['T2']:.3f}, F = {result['F']:.3f}, df_num = {result['df_num']}, df_den = {result['df_den']}, p = {result['p']:.4f}")

"In contrast, a paired-samples Hotelling’s T-squared test indicated that durations of mean retraction-phase motor activity (mean B8a/b and mean B3/B6/B9 burst durations) in the five animals were significantly different between loaded and unloaded swallows (T^2 = XXX, F = XXX, df_num = 2, df_den = 3, p = XXX)."

##### All t-tests

In [None]:
# Paired one-tailed t-tests (loaded > unloaded) and effect sizes on per-animal mean
# burst durations, printed as a tab-separated table.
df = df_all.drop(bite_swallow_behaviors)

columns = [
#     'B38 burst duration (s)',
#     'I2 burst duration (s)',
    'B8a/b burst duration (s)',
    'B3/B6/B9 burst duration (s)',
    'B4/B5 burst duration (s)',
]

print('NOTE: This table assumes by using t-tests that all Shapiro-Wilk tests conducted below were not significant!')
print()
print('\t\tUnloaded\t\t\tLoaded\t\t\t\tPaired t-test\tEffect size')
print('Unit\t\tMean\tSEM\tSTD\tCount\tMean\tSEM\tSTD\tCount\tt\tp\td')
for column in columns:
    # per-animal means: x = loaded (tape nori), y = unloaded (regular nori)
    x = df.query('Food == "Tape nori"')[column].groupby('Animal').mean()
    y = df.query('Food == "Regular nori"')[column].groupby('Animal').mean()
    # pair by animal label, not position, before the paired test (shared animals only)
    x, y = x.align(y, join='inner')

    ttest_result = r_t_test(x.values, y.values, paired=True, alternative='greater')
    ttest_t, ttest_p = ttest_result['t'], ttest_result['p']

    cohen_d = r_effect_size(x.values, y.values)['estimate']

    print(f'{column.split()[0].ljust(8)}\t{y.mean():.2f}\t{y.sem():.2f}\t{y.std():.2f}\t{y.count()}\t{x.mean():.2f}\t{x.sem():.2f}\t{x.std():.2f}\t{x.count()}\t{ttest_t:.3f}\t{ttest_p:.3f}\t{cohen_d:.2f}')

"Post hoc paired-samples one-tailed t-tests indicated that both mean B8a/b burst duration (t = XXX, df = 4, p = XXX, Cohen's d = XXX) and mean B3/B6/B9 burst duration (t = XXX, df = 4, p = XXX, Cohen's d = XXX) were significantly longer for loaded swallows."

"Separately, a paired-samples one-tailed t-test indicated that the mean burst duration of multi-action neurons B4/B5 was also significantly longer for loaded swallows (t = XXX, df = 4, p = XXX, Cohen's d = XXX)."

#### Burst mean frequency

##### Protraction-phase Hotelling's T-squared

In [None]:
# Paired Hotelling's T-squared test on per-animal mean protraction-phase burst
# firing frequencies (B38, I2): unloaded (regular nori) vs. loaded (tape nori).
columns = ['B38 burst mean frequency (Hz)', 'I2 burst mean frequency (Hz)']

df = df_all.drop(bite_swallow_behaviors)
x = df.query('Food == "Regular nori"')[columns].groupby('Animal').mean()
y = df.query('Food == "Tape nori"')[columns].groupby('Animal').mean()
# pair conditions by animal label rather than by row position, so an animal missing
# from one condition cannot silently shift the pairing (only shared animals are kept)
x, y = x.align(y, join='inner', axis=0)
x, y = x.values, y.values
result = r_hotelling_T2_test(x-y)
print(f"T^2 = {result['T2']:.3f}, F = {result['F']:.3f}, df_num = {result['df_num']}, df_den = {result['df_den']}, p = {result['p']:.4f}")

"When a similar analysis was applied to the mean firing rates of the protraction-phase motor activity (mean B38 and mean I2 burst firing frequencies), no significant difference was found between loaded and unloaded swallows (T^2 = XXX, F = XXX, df_num = 2, df_den = 3, p = XXX)."

##### Retraction-phase Hotelling's T-squared

In [None]:
# Paired Hotelling's T-squared test on per-animal mean retraction-phase burst
# firing frequencies (B8a/b, B3/B6/B9): unloaded (regular nori) vs. loaded (tape nori).
columns = ['B8a/b burst mean frequency (Hz)', 'B3/B6/B9 burst mean frequency (Hz)']

df = df_all.drop(bite_swallow_behaviors)
x = df.query('Food == "Regular nori"')[columns].groupby('Animal').mean()
y = df.query('Food == "Tape nori"')[columns].groupby('Animal').mean()
# pair conditions by animal label rather than by row position, so an animal missing
# from one condition cannot silently shift the pairing (only shared animals are kept)
x, y = x.align(y, join='inner', axis=0)
x, y = x.values, y.values
result = r_hotelling_T2_test(x-y)
print(f"T^2 = {result['T2']:.3f}, F = {result['F']:.3f}, df_num = {result['df_num']}, df_den = {result['df_den']}, p = {result['p']:.4f}")

"In contrast, the mean firing rates of the retraction-phase motor activity (mean B8a/b and mean B3/B6/B9 burst firing frequencies) were significantly different between loaded and unloaded swallows (T^2 = XXX, F = XXX, df_num = 2, df_den = 3, p = XXX)."

##### Post hoc t-tests

In [None]:
# Post hoc paired one-tailed t-tests (loaded > unloaded) and effect sizes on
# per-animal mean burst firing frequencies, printed as a tab-separated table.
df = df_all.drop(bite_swallow_behaviors)

columns = [
#     'B38 burst mean frequency (Hz)',
#     'I2 burst mean frequency (Hz)',
    'B8a/b burst mean frequency (Hz)',
    'B3/B6/B9 burst mean frequency (Hz)',
#     'B4/B5 burst mean frequency (Hz)',
]

print('NOTE: This table assumes by using t-tests that all Shapiro-Wilk tests conducted below were not significant!')
print()
print('\t\tUnloaded\t\t\tLoaded\t\t\t\tPaired t-test\tEffect size')
print('Unit\t\tMean\tSEM\tSTD\tCount\tMean\tSEM\tSTD\tCount\tt\tp\td')
for column in columns:
    # per-animal means: x = loaded (tape nori), y = unloaded (regular nori)
    x = df.query('Food == "Tape nori"')[column].groupby('Animal').mean()
    y = df.query('Food == "Regular nori"')[column].groupby('Animal').mean()
    # pair by animal label, not position, before the paired test (shared animals only)
    x, y = x.align(y, join='inner')

    ttest_result = r_t_test(x.values, y.values, paired=True, alternative='greater')
    ttest_t, ttest_p = ttest_result['t'], ttest_result['p']

    cohen_d = r_effect_size(x.values, y.values)['estimate']

    print(f'{column.split()[0].ljust(8)}\t{y.mean():.2f}\t{y.sem():.2f}\t{y.std():.2f}\t{y.count()}\t{x.mean():.2f}\t{x.sem():.2f}\t{x.std():.2f}\t{x.count()}\t{ttest_t:.3f}\t{ttest_p:.3f}\t{cohen_d:.2f}')

"Post hoc paired-samples one-tailed t-tests indicated that mean B3/B6/B9 firing rate (t = XXX, df = 4, p = XXX, Cohen's d = XXX) but not B8a/b firing rate (t = XXX, df = 4, p = XXX, Cohen's d = XXX) was significantly greater for loaded swallows."

#### B4/B5 Wilcoxon signed-rank test

In [None]:
# B4/B5 mean firing rate: Shapiro-Wilk normality check on the paired differences,
# then a paired one-tailed Wilcoxon signed-rank test (loaded > unloaded).
column = 'B4/B5 burst mean frequency (Hz)'

df = df_all.drop(bite_swallow_behaviors)
df = df.reset_index()
# per-animal means: x = loaded (tape nori), y = unloaded (regular nori)
x = df.query('Food == "Tape nori"').groupby('Animal')[column].mean()
y = df.query('Food == "Regular nori"').groupby('Animal')[column].mean()
# pair by animal label, not position, before the paired tests (shared animals only)
x, y = x.align(y, join='inner')
x, y = x.values, y.values

result = r_shapiro_test(x-y)
print(f"Shapiro-Wilk: W = {result['W']:g}, p = {result['p']:.4f}")

result = r_wilcoxon_test(x, y, paired=True, alternative='greater')
print(f"Wilcoxon: W = {result['W']:g}, p = {result['p']:.4f}")

"Finally, a separate paired-samples one-tailed Wilcoxon signed-rank test (conducted because the normality assumption of the t-test was not satisfied) indicated that the mean firing rate of multi-action neurons B4/B5 was not significantly greater for loaded swallows (W = XXX, p = XXX)."

### 🐌 Figure 4A

In [None]:
# Figure 4A: B38 burst duration, unloaded vs. loaded, excluding bite_swallow_behaviors
y = 'B38 burst duration (s)'
color = unit_colors[y.split()[0]]  # unit name = first word of the column name

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(2.25, 2.5))
plot_unloaded_vs_loaded(
    df_all.drop(bite_swallow_behaviors),
    y,
    ax,
    color,
    bracket_width=3.2,
)
ax.set_ylim(bottom=0, top=5)
ax.set_ylabel(y)
ax.yaxis.set_label_coords(-0.2, 0.5)  # y-label at a fixed axes-fraction position
plt.subplots_adjust(
    left=0.25, right=0.99,
    top=0.89, bottom=0.11,
)
fig.savefig(os.path.join(export_dir, 'figure-4A.png'), dpi=300)

# # optional: same measure broken out per animal
# fig, ax = plt.subplots(1, 1, figsize=(4, 4))
# boxplot_with_points('Animal', y, 'Food', df_all.drop(bite_swallow_behaviors).reset_index(), ax, show_points=True)
# ax.set_ylim([0, None])
# plt.tight_layout()

### 🐌 Figure 4B

In [None]:
# Figure 4B: I2 burst duration, unloaded vs. loaded, excluding bite_swallow_behaviors
y = 'I2 burst duration (s)'
color = unit_colors[y.split()[0]]  # unit name = first word of the column name

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(2.25, 2.5))
plot_unloaded_vs_loaded(
    df_all.drop(bite_swallow_behaviors),
    y,
    ax,
    color,
    bracket_width=3.2,
)
ax.set_ylim(bottom=0, top=5)
ax.set_ylabel(y)
ax.yaxis.set_label_coords(-0.2, 0.5)  # y-label at a fixed axes-fraction position
plt.subplots_adjust(
    left=0.25, right=0.99,
    top=0.89, bottom=0.11,
)
fig.savefig(os.path.join(export_dir, 'figure-4B.png'), dpi=300)

# # optional: same measure broken out per animal
# fig, ax = plt.subplots(1, 1, figsize=(4, 4))
# boxplot_with_points('Animal', y, 'Food', df_all.drop(bite_swallow_behaviors).reset_index(), ax, show_points=True)
# ax.set_ylim([0, None])
# plt.tight_layout()

### 🐌 Figure 4C

In [None]:
# Figure 4C: B8a/b burst duration, unloaded vs. loaded, excluding bite_swallow_behaviors
y = 'B8a/b burst duration (s)'
color = unit_colors[y.split()[0]]  # unit name = first word of the column name

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(2.25, 2.5))
plot_unloaded_vs_loaded(
    df_all.drop(bite_swallow_behaviors),
    y,
    ax,
    color,
    bracket_width=3.2,
)
ax.set_ylim(bottom=0, top=5)
ax.set_ylabel(y)
ax.yaxis.set_label_coords(-0.2, 0.5)  # y-label at a fixed axes-fraction position
plt.subplots_adjust(
    left=0.25, right=0.99,
    top=0.89, bottom=0.11,
)
fig.savefig(os.path.join(export_dir, 'figure-4C.png'), dpi=300)

# # optional: same measure broken out per animal
# fig, ax = plt.subplots(1, 1, figsize=(4, 4))
# boxplot_with_points('Animal', y, 'Food', df_all.drop(bite_swallow_behaviors).reset_index(), ax, show_points=True)
# ax.set_ylim([0, None])
# plt.tight_layout()

### 🐌 Figure 4D

In [None]:
# Figure 4D: B3/B6/B9 burst duration, unloaded vs. loaded, excluding bite_swallow_behaviors
y = 'B3/B6/B9 burst duration (s)'
color = unit_colors[y.split()[0]]  # unit name = first word of the column name

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(2.25, 2.5))
plot_unloaded_vs_loaded(
    df_all.drop(bite_swallow_behaviors),
    y,
    ax,
    color,
    bracket_width=3.2,
)
ax.set_ylim(bottom=0, top=5)
ax.set_ylabel(y)
ax.yaxis.set_label_coords(-0.2, 0.5)  # y-label at a fixed axes-fraction position
plt.subplots_adjust(
    left=0.25, right=0.99,
    top=0.89, bottom=0.11,
)
fig.savefig(os.path.join(export_dir, 'figure-4D.png'), dpi=300)

# # optional: same measure broken out per animal
# fig, ax = plt.subplots(1, 1, figsize=(4, 4))
# boxplot_with_points('Animal', y, 'Food', df_all.drop(bite_swallow_behaviors).reset_index(), ax, show_points=True)
# ax.set_ylim([0, None])
# plt.tight_layout()

### 🐌 Figure 4E

In [None]:
# Figure 4E: B4/B5 burst duration, unloaded vs. loaded, excluding bite_swallow_behaviors
y = 'B4/B5 burst duration (s)'
color = unit_colors[y.split()[0]]  # unit name = first word of the column name

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(2.25, 2.5))
plot_unloaded_vs_loaded(
    df_all.drop(bite_swallow_behaviors),
    y,
    ax,
    color,
    bracket_width=3.2,
)
ax.set_ylim(bottom=0, top=5)
ax.set_ylabel(y)
ax.yaxis.set_label_coords(-0.2, 0.5)  # y-label at a fixed axes-fraction position
plt.subplots_adjust(
    left=0.25, right=0.99,
    top=0.89, bottom=0.11,
)
fig.savefig(os.path.join(export_dir, 'figure-4E.png'), dpi=300)

# # optional: same measure broken out per animal
# fig, ax = plt.subplots(1, 1, figsize=(4, 4))
# boxplot_with_points('Animal', y, 'Food', df_all.drop(bite_swallow_behaviors).reset_index(), ax, show_points=True)
# ax.set_ylim([0, None])
# plt.tight_layout()

### 🐌 Figure 4F

In [None]:
# Figure 4F: B3/B6/B9 burst mean frequency, unloaded vs. loaded, excluding bite_swallow_behaviors
y = 'B3/B6/B9 burst mean frequency (Hz)'
color = unit_colors[y.split()[0]]  # unit name = first word of the column name

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(2.25, 2.5))
plot_unloaded_vs_loaded(
    df_all.drop(bite_swallow_behaviors),
    y,
    ax,
    color,
    bracket_width=3.2,
)
ax.set_ylim(bottom=0, top=50)
ax.set_ylabel('B3/B6/B9 firing rate (Hz)')  # shorter label than the column name
ax.yaxis.set_label_coords(-0.2, 0.5)  # y-label at a fixed axes-fraction position
plt.subplots_adjust(
    left=0.25, right=0.99,
    top=0.89, bottom=0.11,
)
fig.savefig(os.path.join(export_dir, 'figure-4F.png'), dpi=300)

# # optional: same measure broken out per animal
# fig, ax = plt.subplots(1, 1, figsize=(4, 4))
# boxplot_with_points('Animal', y, 'Food', df_all.drop(bite_swallow_behaviors).reset_index(), ax, show_points=True)
# ax.set_ylim([0, None])
# plt.tight_layout()

### 🐌 Other frequencies (not plotted in manuscript)

In [None]:
# B38 burst mean frequency, unloaded vs. loaded (not used in the manuscript)
y = 'B38 burst mean frequency (Hz)'
color = unit_colors[y.split()[0]]  # unit name = first word of the column name

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(2.25, 2.5))
plot_unloaded_vs_loaded(
    df_all.drop(bite_swallow_behaviors),
    y,
    ax,
    color,
    bracket_width=3.2,
)
ax.set_ylim(bottom=0, top=50)
ax.set_ylabel('B38 firing rate (Hz)')
ax.yaxis.set_label_coords(-0.2, 0.5)  # y-label at a fixed axes-fraction position
plt.subplots_adjust(
    left=0.25, right=0.99,
    top=0.89, bottom=0.11,
)
# saving is disabled; the filename is distinct so enabling it won't overwrite figure-4F.png
# fig.savefig(os.path.join(export_dir, 'figure-4-other-B38-frequency.png'), dpi=300)

# # optional: same measure broken out per animal
# fig, ax = plt.subplots(1, 1, figsize=(4, 4))
# boxplot_with_points('Animal', y, 'Food', df_all.drop(bite_swallow_behaviors).reset_index(), ax, show_points=True)
# ax.set_ylim([0, None])
# plt.tight_layout()

In [None]:
# I2 burst mean frequency, unloaded vs. loaded (not used in the manuscript)
y = 'I2 burst mean frequency (Hz)'
color = unit_colors[y.split()[0]]  # unit name = first word of the column name

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(2.25, 2.5))
plot_unloaded_vs_loaded(
    df_all.drop(bite_swallow_behaviors),
    y,
    ax,
    color,
    bracket_width=3.2,
)
ax.set_ylim(bottom=0, top=50)
ax.set_ylabel('I2 firing rate (Hz)')
ax.yaxis.set_label_coords(-0.2, 0.5)  # y-label at a fixed axes-fraction position
plt.subplots_adjust(
    left=0.25, right=0.99,
    top=0.89, bottom=0.11,
)
# saving is disabled; the filename is distinct so enabling it won't overwrite figure-4F.png
# fig.savefig(os.path.join(export_dir, 'figure-4-other-I2-frequency.png'), dpi=300)

# # optional: same measure broken out per animal
# fig, ax = plt.subplots(1, 1, figsize=(4, 4))
# boxplot_with_points('Animal', y, 'Food', df_all.drop(bite_swallow_behaviors).reset_index(), ax, show_points=True)
# ax.set_ylim([0, None])
# plt.tight_layout()

In [None]:
# B8a/b burst mean frequency, unloaded vs. loaded (not used in the manuscript)
y = 'B8a/b burst mean frequency (Hz)'
color = unit_colors[y.split()[0]]  # unit name = first word of the column name

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(2.25, 2.5))
plot_unloaded_vs_loaded(
    df_all.drop(bite_swallow_behaviors),
    y,
    ax,
    color,
    bracket_width=3.2,
)
ax.set_ylim(bottom=0, top=50)
ax.set_ylabel('B8a/b firing rate (Hz)')
ax.yaxis.set_label_coords(-0.2, 0.5)  # y-label at a fixed axes-fraction position
plt.subplots_adjust(
    left=0.25, right=0.99,
    top=0.89, bottom=0.11,
)
# saving is disabled; the filename is distinct so enabling it won't overwrite figure-4F.png
# fig.savefig(os.path.join(export_dir, 'figure-4-other-B8ab-frequency.png'), dpi=300)

# # optional: same measure broken out per animal
# fig, ax = plt.subplots(1, 1, figsize=(4, 4))
# boxplot_with_points('Animal', y, 'Food', df_all.drop(bite_swallow_behaviors).reset_index(), ax, show_points=True)
# ax.set_ylim([0, None])
# plt.tight_layout()

In [None]:
# B4/B5 burst mean frequency, unloaded vs. loaded (not used in the manuscript)
y = 'B4/B5 burst mean frequency (Hz)'
color = unit_colors[y.split()[0]]  # unit name = first word of the column name

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(2.25, 2.5))
plot_unloaded_vs_loaded(
    df_all.drop(bite_swallow_behaviors),
    y,
    ax,
    color,
    bracket_width=3.2,
)
ax.set_ylim(bottom=0, top=50)
ax.set_ylabel('B4/B5 firing rate (Hz)')
ax.yaxis.set_label_coords(-0.2, 0.5)  # y-label at a fixed axes-fraction position
plt.subplots_adjust(
    left=0.25, right=0.99,
    top=0.89, bottom=0.11,
)
# saving is disabled; the filename is distinct so enabling it won't overwrite figure-4F.png
# fig.savefig(os.path.join(export_dir, 'figure-4-other-B4B5-frequency.png'), dpi=300)

# # optional: same measure broken out per animal
# fig, ax = plt.subplots(1, 1, figsize=(4, 4))
# boxplot_with_points('Animal', y, 'Food', df_all.drop(bite_swallow_behaviors).reset_index(), ax, show_points=True)
# ax.set_ylim([0, None])
# plt.tight_layout()

## [FIGURE 5]

In [None]:
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(5, 5))

# one subset per animal, each restricted to loaded (tape nori) swallows
data_subsets = {
    label: label2query(label)
    for label in (
        'JG07 Tape nori',
        'JG08 Tape nori',
        'JG11 Tape nori',
        'JG12 Tape nori',
        'JG14 Tape nori',
    )
}

# axis columns and ranges; the y-axis is relabelled with ylabel_alt after plotting
xlabel = 'B3/B6/B9 burst duration (s)'
xlim = [0, 5]
ylabel = 'Force plateau duration (s)'
ylabel_alt = 'Force maintenance duration (s)'
ylim = [0, 5]

scatter2d(
    ax, data_subsets, xlabel, xlim, ylabel, ylim,
    trend=True, trend_separately=True, tooltips=False,
)
plt.plot([0, 999], [0, 999], ls=':', c='gray', zorder=0)  # dotted y = x reference line, behind the data
plt.ylabel(ylabel_alt)
ax.legend()
sns.despine(ax=ax, trim=True)
plt.tight_layout()

fig.savefig(os.path.join(export_dir, 'figure-5-all-animals.png'), dpi=300)

### 🐌 Figure 5A

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(3, 3))

# Figure 5A: pooled loaded (tape nori) swallows from all animals, drawn in one style

data_subsets = [
    'JG07 Tape nori',
    'JG08 Tape nori',
    'JG11 Tape nori',
    'JG12 Tape nori',
    'JG14 Tape nori',
]
data_subsets = {label:label2query(label) for label in data_subsets}

xlabel, xlim = 'B3/B6/B9 burst duration (s)', [0, 5]
ylabel, ylabel_alt, ylim = 'Force plateau duration (s)', 'Force maintenance duration (s)', [0, 5]

# one black-dot style per subset; sized from data_subsets so adding or removing an
# animal above keeps colors/markers the same length as the subsets
scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim, trend=True, tooltips=False,
          colors=['k']*len(data_subsets), markers=['.']*len(data_subsets))
plt.plot([0, 999], [0, 999], ls=':', c='gray', zorder=0)  # dotted y = x reference line, behind the data
plt.ylabel(ylabel_alt)
sns.despine(ax=ax, trim=True)
plt.tight_layout()

fig.savefig(os.path.join(export_dir, 'figure-5A.png'), dpi=300)

### 🐌 Figure 5B

In [None]:
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(3, 3))

# single subset: loaded (tape nori) swallows from animal JG07
data_subsets = {'JG07 Tape nori': label2query('JG07 Tape nori')}

# axis columns and ranges; the y-axis is relabelled with ylabel_alt after plotting
xlabel = 'B3/B6/B9 burst duration (s)'
xlim = [0, 5]
ylabel = 'Force plateau duration (s)'
ylabel_alt = 'Force maintenance duration (s)'
ylim = [0, 5]

scatter2d(
    ax, data_subsets, xlabel, xlim, ylabel, ylim,
    trend_separately=True, tooltips=False, colors=['k'],
)
plt.plot([0, 999], [0, 999], ls=':', c='gray', zorder=0)  # dotted y = x reference line, behind the data
plt.ylabel(ylabel_alt)
sns.despine(ax=ax, trim=True)
plt.tight_layout()

fig.savefig(os.path.join(export_dir, 'figure-5B.png'), dpi=300)

### 🐌 Figure 5C

In [None]:
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(3, 3))

# single subset: loaded (tape nori) swallows from animal JG08
data_subsets = {'JG08 Tape nori': label2query('JG08 Tape nori')}

# axis columns and ranges; the y-axis is relabelled with ylabel_alt after plotting
xlabel = 'B3/B6/B9 burst duration (s)'
xlim = [0, 5]
ylabel = 'Force plateau duration (s)'
ylabel_alt = 'Force maintenance duration (s)'
ylim = [0, 5]

scatter2d(
    ax, data_subsets, xlabel, xlim, ylabel, ylim,
    trend_separately=True, tooltips=False, colors=['k'],
)
plt.plot([0, 999], [0, 999], ls=':', c='gray', zorder=0)  # dotted y = x reference line, behind the data
plt.ylabel(ylabel_alt)
sns.despine(ax=ax, trim=True)
plt.tight_layout()

fig.savefig(os.path.join(export_dir, 'figure-5C.png'), dpi=300)

### 🐌 Figure 5D

In [None]:
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(3, 3))

# single subset: loaded (tape nori) swallows from animal JG11
data_subsets = {'JG11 Tape nori': label2query('JG11 Tape nori')}

# axis columns and ranges; the y-axis is relabelled with ylabel_alt after plotting
xlabel = 'B3/B6/B9 burst duration (s)'
xlim = [0, 5]
ylabel = 'Force plateau duration (s)'
ylabel_alt = 'Force maintenance duration (s)'
ylim = [0, 5]

scatter2d(
    ax, data_subsets, xlabel, xlim, ylabel, ylim,
    trend_separately=True, tooltips=False, colors=['k'],
)
plt.plot([0, 999], [0, 999], ls=':', c='gray', zorder=0)  # dotted y = x reference line, behind the data
plt.ylabel(ylabel_alt)
sns.despine(ax=ax, trim=True)
plt.tight_layout()

fig.savefig(os.path.join(export_dir, 'figure-5D.png'), dpi=300)

### 🐌 Figure 5E

In [None]:
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(3, 3))

# single subset: loaded (tape nori) swallows from animal JG12
data_subsets = {'JG12 Tape nori': label2query('JG12 Tape nori')}

# axis columns and ranges; the y-axis is relabelled with ylabel_alt after plotting
xlabel = 'B3/B6/B9 burst duration (s)'
xlim = [0, 5]
ylabel = 'Force plateau duration (s)'
ylabel_alt = 'Force maintenance duration (s)'
ylim = [0, 5]

scatter2d(
    ax, data_subsets, xlabel, xlim, ylabel, ylim,
    trend_separately=True, tooltips=False, colors=['k'],
)
plt.plot([0, 999], [0, 999], ls=':', c='gray', zorder=0)  # dotted y = x reference line, behind the data
plt.ylabel(ylabel_alt)
sns.despine(ax=ax, trim=True)
plt.tight_layout()

fig.savefig(os.path.join(export_dir, 'figure-5E.png'), dpi=300)

### 🐌 Figure 5F

In [None]:
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(3, 3))

# single subset: loaded (tape nori) swallows from animal JG14
data_subsets = {'JG14 Tape nori': label2query('JG14 Tape nori')}

# axis columns and ranges; the y-axis is relabelled with ylabel_alt after plotting
xlabel = 'B3/B6/B9 burst duration (s)'
xlim = [0, 5]
ylabel = 'Force plateau duration (s)'
ylabel_alt = 'Force maintenance duration (s)'
ylim = [0, 5]

scatter2d(
    ax, data_subsets, xlabel, xlim, ylabel, ylim,
    trend_separately=True, tooltips=False, colors=['k'],
)
plt.plot([0, 999], [0, 999], ls=':', c='gray', zorder=0)  # dotted y = x reference line, behind the data
plt.ylabel(ylabel_alt)
sns.despine(ax=ax, trim=True)
plt.tight_layout()

fig.savefig(os.path.join(export_dir, 'figure-5F.png'), dpi=300)

## [STATISTICAL TABLES]

In [None]:
def stats_table(figure_map, df):
    """Print paired Loaded-vs-Unloaded statistics for each entry of `figure_map`.

    Parameters
    ----------
    figure_map : dict
        Maps a label (printed in each section header) to either a list of
        column names or a single column name of `df`.

        - list   -> omnibus Hotelling's T-squared test on the per-animal
                    (Loaded - Unloaded) differences of all listed columns.
        - single -> post-hoc summary for that column: per-animal
                    Mean ± SEM (N) table, difference of means, proportional
                    change, Shapiro-Wilk test on the differences, then a
                    one-sided (alternative='greater') paired t-test if
                    Shapiro-Wilk is n.s. or a paired Wilcoxon signed-rank
                    test otherwise, and Cohen's d.
    df : pandas.DataFrame
        Data in which 'Animal' and 'Food' are available as columns after
        ``reset_index()``. 'Food' values 'Regular nori' / 'Tape nori' are
        relabeled 'Unloaded' / 'Loaded', and animals JG07, JG08, JG11, JG12,
        JG14 are relabeled 1-5. Relabeling happens on the ``reset_index()``
        result, not on the caller's frame.

    Returns
    -------
    None; all results are printed.

    Notes
    -----
    Each animal contributes one value per condition (the mean over its rows),
    so the tests are paired across animals. Animals without data under both
    conditions are excluded from the paired computations (a note is printed),
    so the arrays handed to the R helpers stay matched animal-by-animal.
    """
    df = df.reset_index()
    df.loc[df['Food'] == 'Regular nori', 'Food'] = 'Unloaded'
    df.loc[df['Food'] == 'Tape nori',    'Food'] = 'Loaded'
    df.loc[df['Animal'] == 'JG07', 'Animal'] = 1
    df.loc[df['Animal'] == 'JG08', 'Animal'] = 2
    df.loc[df['Animal'] == 'JG11', 'Animal'] = 3
    df.loc[df['Animal'] == 'JG12', 'Animal'] = 4
    df.loc[df['Animal'] == 'JG14', 'Animal'] = 5

    for fig, columns in figure_map.items():

        # one row per animal: mean over that animal's rows in each condition
        x = df.query('Food == "Loaded"').groupby('Animal')[columns].mean()
        y = df.query('Food == "Unloaded"').groupby('Animal')[columns].mean()

        print('-----------------------------------------')
        print(f'FIGURE {fig}: {columns}')
        print()

        # The R helpers below receive bare ``.values`` arrays, which carry no
        # index, so pair the animals explicitly: keep only animals present in
        # both conditions, in the same order. A no-op when both sides already
        # contain the same animals.
        unpaired = x.index.symmetric_difference(y.index)
        if len(unpaired) > 0:
            print(f'NOTE: excluding animals without data in both conditions: {list(unpaired)}')
            print()
        x, y = x.align(y, join='inner')

        if isinstance(columns, list):
            # OMNIBUS TEST

            hotelling_result = r_hotelling_T2_test(x.values-y.values)
            print(f"Hotelling's T-squared, T^2 = {hotelling_result['T2']:.3f}, F({hotelling_result['df_num']},{hotelling_result['df_den']}) = {hotelling_result['F']:.3f}, p = {hotelling_result['p']:.3f} {'(n.s.)' if hotelling_result['p'] > 0.05 else '(sig.)'}")

        else:
            # POST-HOC TEST

            # per-animal, per-condition descriptive table (all rows of df, paired or not)
            print(df.groupby(['Animal', 'Food'])[columns].apply(lambda vals: {
                'Mean ± SEM (N)': f'{vals.mean():.2f} ± {vals.sem():.2f} ({vals.count()})'.ljust(18),
            }).unstack([1, 2])[['Unloaded', 'Loaded']])
            print()

            diff = x - y  # per-animal Loaded minus Unloaded
            print(f"Difference of means, Mean ± SEM (N): {diff.mean():.2f} ± {diff.sem():.2f} ({diff.count()})")
            # relative to Unloaded; animals with an Unloaded mean of 0 give ±inf and are dropped
            percent_change = (diff/y).replace([-np.inf, np.inf], np.nan).dropna()
            print(f"Proportional change, Mean ± SEM (N): {100*percent_change.mean():.0f}% ± {100*percent_change.sem():.0f}% ({percent_change.count()})")
            print()

            # normality of the paired differences decides which paired test to use
            shapiro_result = r_shapiro_test(x.values-y.values)
            if shapiro_result['p'] > 0.05:
                print(f"Shapiro-Wilk, W = {shapiro_result['W']:.2f}, p = {shapiro_result['p']:.2f} (n.s.)")
            else:
                print(f"Shapiro-Wilk, W = {shapiro_result['W']:.2f}, p = {shapiro_result['p']:.3f} (sig.)")

            if shapiro_result['p'] > 0.05:
                ttest_result = r_t_test(x.values, y.values, paired=True, alternative='greater')
                print(f"Paired t-test, t({ttest_result['df']:g}) = {ttest_result['t']:.3f}, p = {ttest_result['p']:.3f} {'(n.s.)' if ttest_result['p'] > 0.05 else '(sig.)'}")
            else:
                wilcoxon_result = r_wilcoxon_test(x.values, y.values, paired=True, alternative='greater')
                print(f"Paired Wilcoxon signed rank, W = {wilcoxon_result['W']:g}, p = {wilcoxon_result['p']:.2f} {'(n.s.)' if wilcoxon_result['p'] > 0.05 else '(sig.)'}")

            cohen_d = r_effect_size(x.values, y.values)['estimate']
            print(f"Cohen's d = {cohen_d:.2f}")

        print()

### 🐌 Table 2

In [None]:
# Table 2: paired statistics for the inward-movement timing measures of
# Figures 2C-2E, leaving out the rows listed in `unreliable_inward_movement`.
figure_map = dict(zip(
    ['2C', '2D', '2E'],
    [
        'Inward movement start to next inward movement start (s)',  # 2C
        'Inward movement duration (s)',                             # 2D
        'Inward movement end to next inward movement start (s)',    # 2E
    ],
))

stats_table(figure_map, df_all.drop(unreliable_inward_movement))

### 🐌 Table 3

In [None]:
# Table 3: burst-duration statistics for Figure 4. List-valued entries get an
# omnibus (Hotelling's T-squared) test; single columns get post-hoc tests.
# Rows listed in `bite_swallow_behaviors` are excluded.
figure_map = dict(zip(
    ['4A+4B', '4A', '4B', '4C+4D', '4C', '4D', '4E'],
    [
        ['B38 burst duration (s)', 'I2 burst duration (s)'],         # 4A+4B (omnibus)
        'B38 burst duration (s)',                                     # 4A
        'I2 burst duration (s)',                                      # 4B
        ['B8a/b burst duration (s)', 'B3/B6/B9 burst duration (s)'],  # 4C+4D (omnibus)
        'B8a/b burst duration (s)',                                   # 4C
        'B3/B6/B9 burst duration (s)',                                # 4D
        'B4/B5 burst duration (s)',                                   # 4E
    ],
))

stats_table(figure_map, df_all.drop(bite_swallow_behaviors))

### 🐌 Table 4

In [None]:
# Table 4: burst mean-frequency statistics (only the B3/B6/B9 panel, 4F, is
# plotted; 'X*' labels are unplotted measures). List-valued entries get an
# omnibus test. Rows listed in `bite_swallow_behaviors` are excluded.
figure_map = dict(zip(
    ['X1+X2', 'X1', 'X2', 'X3+4F', 'X3', '4F', 'X4'],
    [
        ['B38 burst mean frequency (Hz)', 'I2 burst mean frequency (Hz)'],         # X1+X2 (omnibus)
        'B38 burst mean frequency (Hz)',                                            # X1
        'I2 burst mean frequency (Hz)',                                             # X2
        ['B8a/b burst mean frequency (Hz)', 'B3/B6/B9 burst mean frequency (Hz)'],  # X3+4F (omnibus)
        'B8a/b burst mean frequency (Hz)',                                          # X3
        'B3/B6/B9 burst mean frequency (Hz)',                                       # 4F
        'B4/B5 burst mean frequency (Hz)',                                          # X4
    ],
))

stats_table(figure_map, df_all.drop(bite_swallow_behaviors))

---

## Random old figures

In [None]:
# Force-rise start time minus I2 burst end time; negative values mean the
# force-rise timestamp comes before the I2 burst-end timestamp.
df_all['Delay from I2 end to force rise start (s)'] = (
    df_all['Force rise start start (s)'].sub(df_all['I2 burst end (s)'])
)

In [None]:
# Per-animal box plot of the I2-end -> force-rise-start delay.
ax = plt.figure(figsize=(5, 5)).gca()

df = df_all.reset_index()
column = 'Delay from I2 end to force rise start (s)'
# earlier alternative: 'Delay from I2 end to force min (s)'

# whis=999 stretches the whiskers to the min and max; fliersize=0 hides outlier markers
sns.boxplot(data=df, x=column, y='Animal', whis=999, fliersize=0, ax=ax)
# disabled variants:
#   per-animal swarm overlay: sns.swarmplot(y='Animal', x=column, data=df, color='0.25')
#   pooled across animals:    sns.boxplot(x=column, data=df, fliersize=0, whis=999)
#                             sns.swarmplot(x=column, data=df, color='0.25')

ax.axvline(x=0, ls=':', c='gray', zorder=-1)  # zero-delay reference line
ax.figure.tight_layout()

# ax.figure.savefig(os.path.join(export_dir, 'figure-X.png'), dpi=300)

In [None]:
# Force-rise start time minus the earlier of the B6/B9 and B3 burst starts
# (DataFrame.min skips NaN by default, so a missing burst falls back to the other).
df_all['Delay from B3/B6/B9 start to force rise start (s)'] = (
    df_all['Force rise start start (s)'].sub(
        df_all[['B6/B9 burst start (s)', 'B3 burst start (s)']].min(axis='columns')
    )
)

In [None]:
# Per-animal box plot of the B3/B6/B9-start -> force-rise-start delay, with
# every row of df_all overlaid as a swarm.
ax = plt.figure(figsize=(5, 5)).gca()

df = df_all.reset_index()
column = 'Delay from B3/B6/B9 start to force rise start (s)'
# earlier alternative: 'Delay from B3/B6/B9 start to force start (s)'

# whis=999 stretches the whiskers to the min and max; fliersize=0 hides outlier markers
sns.boxplot(data=df, x=column, y='Animal', whis=999, fliersize=0, ax=ax)
sns.swarmplot(data=df, x=column, y='Animal', color='0.25', ax=ax)
# disabled pooled-across-animals variant:
#   sns.boxplot(x=column, data=df, fliersize=0, whis=999)
#   sns.swarmplot(x=column, data=df, color='0.25')

ax.axvline(x=0, ls=':', c='gray', zorder=-1)  # zero-delay reference line
ax.figure.tight_layout()

# ax.figure.savefig(os.path.join(export_dir, 'figure-X.png'), dpi=300)

In [None]:
# Force-rise start time minus B8a/b burst start time.
df_all['Delay from B8a/b start to force rise start (s)'] = (
    df_all['Force rise start start (s)'].sub(df_all['B8a/b burst start (s)'])
)

In [None]:
# Per-animal box plot of the B8a/b-start -> force-rise-start delay, with every
# row of df_all overlaid as a swarm.
ax = plt.figure(figsize=(5, 5)).gca()

df = df_all.reset_index()
column = 'Delay from B8a/b start to force rise start (s)'
# earlier alternative: 'Delay from B8a/b start to force start (s)'

# whis=999 stretches the whiskers to the min and max; fliersize=0 hides outlier markers
sns.boxplot(data=df, x=column, y='Animal', whis=999, fliersize=0, ax=ax)
sns.swarmplot(data=df, x=column, y='Animal', color='0.25', ax=ax)
# disabled pooled-across-animals variant:
#   sns.boxplot(x=column, data=df, fliersize=0, whis=999)
#   sns.swarmplot(x=column, data=df, color='0.25')

ax.axvline(x=0, ls=':', c='gray', zorder=-1)  # zero-delay reference line
ax.figure.tight_layout()

# ax.figure.savefig(os.path.join(export_dir, 'figure-X.png'), dpi=300)

In [None]:
# B8a/b burst duration vs. force rise-and-plateau duration for every animal's
# loaded ('Tape nori') data, with per-subset trends and a dotted y = x line.
ax = plt.figure(figsize=(5, 5)).gca()

data_subsets = {
    label: label2query(label)
    for label in [
        'JG07 Tape nori',
        'JG08 Tape nori',
        'JG11 Tape nori',
        'JG12 Tape nori',
        'JG14 Tape nori',
    ]
}

xlabel, xlim = 'B8a/b burst duration (s)', [0, 8]  # earlier alternative: 'B8 activity duration (s)'
ylabel, ylim = 'Force rise and plateau duration (s)', [0, 8]

scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim, trend=True, trend_separately=True, tooltips=False)
ax.plot([0, 999], [0, 999], ls=':', c='gray', zorder=0)  # unity line, behind the data
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
ax.figure.tight_layout()

In [None]:
# Force-plateau start time minus the earlier of the B6/B9 and B3 burst starts
# (DataFrame.min skips NaN by default).
df_all['Delay from B3/B6/B9 start to force plateau start (s)'] = (
    df_all['Force plateau start start (s)'].sub(
        df_all[['B6/B9 burst start (s)', 'B3 burst start (s)']].min(axis='columns')
    )
)

In [None]:
# Per-animal box plot of the B3/B6/B9-start -> force-plateau-start delay.
ax = plt.figure(figsize=(5, 5)).gca()

df = df_all.reset_index()
column = 'Delay from B3/B6/B9 start to force plateau start (s)'
# earlier alternative: 'Delay from B3/B6/B9 start to force 50%-height (s)'

# whis=999 stretches the whiskers to the min and max; fliersize=0 hides outlier markers
sns.boxplot(data=df, x=column, y='Animal', whis=999, fliersize=0, ax=ax)
# disabled variants:
#   per-animal swarm overlay: sns.swarmplot(y='Animal', x=column, data=df, color='0.25')
#   pooled across animals:    sns.boxplot(x=column, data=df, fliersize=0, whis=999)
#                             sns.swarmplot(x=column, data=df, color='0.25')

ax.axvline(x=0, ls=':', c='gray', zorder=-1)  # zero-delay reference line
ax.figure.tight_layout()

# ax.figure.savefig(os.path.join(export_dir, 'figure-X.png'), dpi=300)

In [None]:
# Force-plateau end time minus the earlier of:
#   (a) the B8a/b burst end, and
#   (b) the later of the B6/B9 and B3 burst ends.
df_all['Delay from either B8a/b or B3/B6/B9 end to force plateau end (s)'] = (
    df_all['Force plateau end start (s)'].sub(
        pd.concat(
            [
                df_all['B8a/b burst end (s)'],
                df_all[['B6/B9 burst end (s)', 'B3 burst end (s)']].max(axis='columns'),
            ],
            axis='columns',
        ).min(axis='columns')
    )
)

In [None]:
# Per-animal box plot of the burst-end -> force-plateau-end delay, with every
# row of df_all overlaid as a swarm.
ax = plt.figure(figsize=(5, 5)).gca()

df = df_all.reset_index()
column = 'Delay from either B8a/b or B3/B6/B9 end to force plateau end (s)'
# earlier alternative: 'Delay from either B8a/b or B3/B6/B9 end to force 80%-height end (s)'

# whis=999 stretches the whiskers to the min and max; fliersize=0 hides outlier markers
sns.boxplot(data=df, x=column, y='Animal', whis=999, fliersize=0, ax=ax)
sns.swarmplot(data=df, x=column, y='Animal', color='0.25', ax=ax)
# disabled pooled-across-animals variant:
#   sns.boxplot(x=column, data=df, fliersize=0, whis=999)
#   sns.swarmplot(x=column, data=df, color='0.25')

ax.axvline(x=0, ls=':', c='gray', zorder=-1)  # zero-delay reference line
ax.figure.tight_layout()

# ax.figure.savefig(os.path.join(export_dir, 'figure-X.png'), dpi=300)

In [None]:
# B6/B9 burst duration vs. force rise-and-plateau duration for every animal's
# loaded ('Tape nori') data, with per-subset trends and a dotted y = x line.
ax = plt.figure(figsize=(5, 5)).gca()

data_subsets = {
    label: label2query(label)
    for label in [
        'JG07 Tape nori',
        'JG08 Tape nori',
        'JG11 Tape nori',
        'JG12 Tape nori',
        'JG14 Tape nori',
    ]
}

xlabel, xlim = 'B6/B9 burst duration (s)', [0, 8]  # earlier alternative: 'B3/6/9/10 activity duration (s)'
ylabel, ylim = 'Force rise and plateau duration (s)', [0, 8]

scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim, trend=True, trend_separately=True, tooltips=False)
ax.plot([0, 999], [0, 999], ls=':', c='gray', zorder=0)  # unity line, behind the data
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
ax.figure.tight_layout()

In [None]:
# B3 burst start/end vs. force plateau start/end: time differences in seconds
# (plateau event time minus B3 event time).
df_all['Delay from B3 start to force plateau start (s)'] = (
    df_all['Force plateau start start (s)'].sub(df_all['B3 burst start (s)'])
)

df_all['Delay from B3 end to force plateau end (s)'] = (
    df_all['Force plateau end start (s)'].sub(df_all['B3 burst end (s)'])
)

In [None]:
# Side-by-side box plots (hue = Animal) of the B3-start -> plateau-start and
# B3-end -> plateau-end delays, after reshaping the two delay columns into
# long format with a 'When?' column.
ax = plt.figure(figsize=(5, 5)).gca()

# earlier versions used the 'Delay from B3 start/end to force 80%-height
# start/end (s)' columns instead of the plateau ones
df = df_all.reset_index().rename(columns={
    'Delay from B3 start to force plateau start (s)': 'Start of burst',
    'Delay from B3 end to force plateau end (s)':     'End of burst',
})
df = df.melt(
    id_vars=['Animal', 'Food', 'Bout_index', 'Behavior_index'],
    value_vars=['Start of burst', 'End of burst'],
    var_name='When?',
    value_name='Delay from B3 to plateau (s)',
)

# whis=999 stretches the whiskers to the min and max; fliersize=0 hides outlier markers
sns.boxplot(data=df, x='Delay from B3 to plateau (s)', y='When?', hue='Animal', whis=999, fliersize=0, ax=ax)
# disabled swarm overlay:
#   sns.swarmplot(hue='Animal', x='Delay from B3 to plateau (s)', y='When?', data=df)

ax.axvline(x=0, ls=':', c='gray', zorder=-1)  # zero-delay reference line
ax.figure.tight_layout()

In [None]:
# Force-shoulder end time minus B38 burst end time.
df_all['Delay from B38 end to force shoulder end (s)'] = (
    df_all['Force shoulder end start (s)'].sub(df_all['B38 burst end (s)'])
)

In [None]:
# Per-animal box plot of the B38-end -> force-shoulder-end delay, with every
# row of df_all overlaid as a swarm.
ax = plt.figure(figsize=(5, 5)).gca()

df = df_all.reset_index()
column = 'Delay from B38 end to force shoulder end (s)'
# earlier alternative: 'Delay from B38 end to shoulder end (s)'

# whis=999 stretches the whiskers to the min and max; fliersize=0 hides outlier markers
sns.boxplot(data=df, x=column, y='Animal', whis=999, fliersize=0, ax=ax)
sns.swarmplot(data=df, x=column, y='Animal', color='0.25', ax=ax)
# disabled pooled-across-animals variant:
#   sns.boxplot(x=column, data=df, fliersize=0, whis=999)
#   sns.swarmplot(x=column, data=df, color='0.25')

ax.axvline(x=0, ls=':', c='gray', zorder=-1)  # zero-delay reference line
ax.figure.tight_layout()

# ax.figure.savefig(os.path.join(export_dir, 'figure-X.png'), dpi=300)

In [None]:
# Force rise duration vs. force rise increase for every animal's loaded
# ('Tape nori') data, with per-subset trends and interactive tooltips enabled.
ax = plt.figure(figsize=(5, 5)).gca()

data_subsets = {
    label: label2query(label)
    for label in [
        'JG07 Tape nori',
        'JG08 Tape nori',
        'JG11 Tape nori',
        'JG12 Tape nori',
        'JG14 Tape nori',
    ]
}

xlabel, xlim = 'Force rise duration (s)', [0, None]
ylabel, ylim = 'Force rise increase (mN)', [0, None]

scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim, trend=True, trend_separately=True, tooltips=True)
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
ax.figure.tight_layout()

## Model stuff

In [None]:
# Crude, exploratory firing-rate model of swallow force: the B6/B9 and B3
# spike trains of one swallow are smoothed with causal alpha kernels, combined
# into a muscle activation u, mapped to force through the geometric model
# described below, and plotted over the measured (10 Hz low-pass) force.
from utils import CausalAlphaKernel

# index key into df_all; its first three entries also key feeding_bouts
swallow_id = ('JG07', 'Tape nori', 0, 0)

(data_set_name, time_window) = feeding_bouts[swallow_id[:3]]

# load the data
metadata = neurotic.MetadataSelector('../../data/metadata.yml')
metadata.select(data_set_name)
blk = neurotic.load_dataset(metadata, lazy=True)
sig = get_sig(blk, 'Force')
sig = sig.time_slice(time_window[0]*pq.s, time_window[1]*pq.s)  # time_window is in seconds
sig = sig.rescale('mN')
sig = elephant.signal_processing.butter(sig, lowpass_freq = 10*pq.Hz)  # 10 Hz low-pass filter

st_b6b9 = df_all.loc[swallow_id]['B6/B9 spike train']
st_b3 = df_all.loc[swallow_id]['B3 spike train']

# hand-tuned model parameters (synaptic weights; kernel time constants in s)
weight_b6b9, tau_b6b9 = 1.75, 1
weight_b3,   tau_b3   = 1.75, 0.2
model_scale           = 100    # force (mN) per unit of grasper position x
model_baseline        = 85     # force (mN) when x = 0
u_to_y_constant       = 0.005  # slope c in y = 1 - c*u

rate_b6b9 = elephant.statistics.instantaneous_rate(
    spiketrain=st_b6b9,
    sampling_period=0.0002*pq.s,
    kernel=CausalAlphaKernel(tau_b6b9*pq.s),
)

rate_b3 = elephant.statistics.instantaneous_rate(
    spiketrain=st_b3,
    sampling_period=0.0002*pq.s,
    kernel=CausalAlphaKernel(tau_b3*pq.s),
)

# derivation and assumptions of this *very* crude and simple model:
# - force (which is approximately isometric) is a linear function
#   of grasper position, x
# - the grasper is always spherical, and therefore its position is
#   related to the I1/I3 torus contact altitude, y, by the equation
#   for a circle, x^2 + y^2 = r^2, where r is the grasper radius
#   (here r=1 with arbitrary units, so x = sqrt(1 - y^2))
# - the contact altitude y is approximately equivalent to the major
#   radius of the torus (i.e., the minor radius is negligible)
# - the major radius of the I1/I3 torus decreases from its max radius,
#   ymax (here ymax=1 with arbitrary units), to its min radius (here
#   set to 0) as muscle activation, u, increases
#     - this may be modeled as an asymptotical approach from ymax to 0,
#       e.g., y = ymax*exp(-c*u), or as a piecewise linear function,
#       e.g., y = ymax-c*u with floor 0 and ceiling ymax
# - the muscle activation u is a weighted sum of the synaptic potentials
#   generated by the relevant motor neurons, modeled as alpha functions
#   with fixed time constants and size (i.e., changes in size due to
#   changing driving force as the muscle depolarizes are ignored)
u = rate_total = rate_b6b9 * weight_b6b9 + rate_b3 * weight_b3
u = np.clip(u.magnitude.flatten(), 0, None) # replace with 0 any negative values (caused by numerical imprecision)
# y = np.exp(-u_to_y_constant*u)  # asymptotic alternative
y = np.clip(1-u_to_y_constant*u, 0, 1)  # piecewise-linear version, ymax = 1
x = np.sqrt(1-y**2)
model_force = x * model_scale + model_baseline


plt.figure(figsize=(8,4))
plt.plot(rate_total.times.rescale('s'), rate_total, label='Sum of currents')
plt.plot(rate_total.times.rescale('s'), model_force, label='Force model')
plt.plot(sig.times.rescale('s'), sig.magnitude, color='0.75', zorder=-1, label='Experimental force')
plt.xlim([df_all.loc[swallow_id]['Start (s)'], df_all.loc[swallow_id]['End (s)']])

plt.title(f'Scale: {model_scale} | Baseline: {model_baseline} | B6/B9: ({weight_b6b9}, {tau_b6b9}) | B3: ({weight_b3}, {tau_b3})')
plt.xlabel('Time (s)')
plt.ylabel('Force (mN)')
plt.legend()
plt.tight_layout()

# one PNG per parameter set; makedirs also creates export_dir if it is missing
export_dir4 = os.path.join(export_dir, 'firing-rate-models')
os.makedirs(export_dir4, exist_ok=True)
plt.gcf().savefig(os.path.join(export_dir4, f'S {model_scale} BL {model_baseline} B6B9 {weight_b6b9} {tau_b6b9} B3 {weight_b3} {tau_b3}.png'), dpi=300)

## Monty Python

In [None]:
# Joke figure ("Seaweed Velocity of Unladen Swallows"): for each unloaded
# ('Regular nori') strip, i.e. each (Animal, Bout_index), divide a hard-coded
# 5 cm strip length by the total inward-movement time of that bout.
# NOTE(review): unlike Table 2, rows in `unreliable_inward_movement` are not
# dropped here -- confirm whether that is intended.
#
# font obtained from: https://www.fontspace.com/mentor-type/goudy-medieval
# after installing the font, had to delete the cached fontlist JSON files in C:\Users\<name>\.matplotlib

unladen = df_all.query('Food == "Regular nori"')
movement_time_per_strip = unladen['Inward movement duration (s)'].groupby(level=['Animal', 'Bout_index']).sum()
del unladen
velocity = (5*pq.cm)/(movement_time_per_strip.values*pq.s)

plt.figure(figsize=(4.5,5))
# NOTE(review): sns.set changes matplotlib's global rcParams (style, font scale,
# font family), so the medieval font persists for any plot drawn afterwards.
sns.set(style = 'ticks', font_scale=1.8, font='GoudyMedieval')
sns.boxplot(velocity, orient='v', color='#71a45a')
sns.swarmplot(velocity, orient='v', color='k', size=6)
plt.gca().set(
    ylim=(0, None),
    title='Seaweed Velocity of\nUnladen Swallows',
    ylabel='Mean Inward Velocity (cm/s)',
)
plt.gca().tick_params(bottom=False)  # no x-axis tick marks
plt.tight_layout()
plt.gcf().savefig(os.path.join(export_dir, 'seaweed-velocity-of-unladen-swallows.png'), dpi=600)

In [None]:
np.size(velocity)  # number of strips (velocity values)

In [None]:
velocity.mean(axis=None)  # mean over all strips, in cm/s

In [None]:
velocity.std(ddof=0)  # population standard deviation (ddof=0 is the default), in cm/s