# Import Packages

In [None]:
import numpy as np
import quantities as pq
import elephant
import pandas as pd
import statsmodels.api as sm
import neurotic
from neurotic.datasets.data import _detect_spikes
from utils import BehaviorsDataFrame

# let quantities render unit symbols with unicode characters (e.g. mu for micro)
pq.markup.config.use_unicode = True

# register a millinewton unit (one thousandth of a newton) as pq.mN
pq.mN = pq.UnitQuantity('millinewton', pq.N / 1e3, symbol='mN')

import matplotlib.pyplot as plt
import matplotlib.patches as patches
import seaborn as sns

# IPython Magics

In [None]:
# make figures interactive and open in a separate window
# %matplotlib qt

# make figures interactive and inline
%matplotlib notebook

# make figures non-interactive and inline
# %matplotlib inline

# Plot Settings

In [None]:
# color map (alternatives tried here: plt.cm.brg, plt.cm.RdBu)
cm = plt.cm.cool

# seaborn theme: 'ticks' style, unscaled fonts, Palatino Linotype typeface
# (a context='poster' argument can optionally be added)
sns.set(style='ticks', font_scale=1, font='Palatino Linotype')

In [None]:
# scatter styling, indexed by the position of each data subset in scatter2d
markers = list('.+x14oD')
colors = [f'C{n}' for n in range(7)]

# Data Parameters

In [None]:
# Feeding bouts to analyze.
#
# key:   (animal, food, bout_index)
# value: 4-tuple of
#   data_set_name       - name handed to metadata.select
#   channels_to_keep    - names of the signals to plot and analyze
#   time_window         - [start, end] of the bout, in seconds
#   epoch_types_to_keep - epoch types that count as behaviors
#
# The list-valued fields stay lists: epoch_types_to_keep is interpolated
# verbatim into a query string by the processing cell.
feeding_bouts = {
    ('JG08', 'Tape nori', 0): (
        'IN VIVO / JG08 / 2018-06-21 / 002',
        ['I2', 'RN', 'BN2', 'BN3', 'Force'],
        [148, 208],  # 7 swallows, some bucket and head movement
        ['Swallow (tape nori)'],
    ),
    ('JG12', 'Tape nori', 0): (
        'IN VIVO / JG12 / 2019-05-10 / 002',
        ['I2', 'RN', 'BN2', 'BN3-DIST', 'Force'],
        [437, 465],  # 4 swallows
        ['Swallow (tape nori)'],
    ),
    ('JG12', 'Tape nori', 1): (
        'IN VIVO / JG12 / 2019-05-10 / 002',
        ['I2', 'RN', 'BN2', 'BN3-DIST', 'Force'],
        [2901, 2937],  # 5 swallows
        ['Swallow (tape nori)'],
    ),
}

# Import and Process the Data

In [None]:
def query_union(queries):
    """Combine query strings into one query that matches any of them.

    Entries that are None are skipped. Every remaining query is wrapped in
    parentheses and the pieces are joined with '|'. Returns an empty string
    when there is nothing to join.
    """
    wrapped = []
    for query in queries:
        if query is None:
            continue
        wrapped.append('({})'.format(query))
    return '|'.join(wrapped)

def label2query(label):
    """Build a query string selecting one animal/food combination.

    Parameters
    ----------
    label : str
        Text of the form '<animal> <food>', e.g. 'JG08 Tape nori'. Everything
        before the first space is the animal; everything after it is the food
        (which may itself contain spaces).

    Returns
    -------
    str
        A query of the form '(Animal == "<animal>") & (Food == "<food>")'.
    """
    # split on the first space instead of assuming a 4-character animal name
    animal, _, food = label.partition(' ')
    return f'(Animal == "{animal}") & (Food == "{food}")'

def contains(series, string):
    """Map every element of `series` to whether `string` occurs in it.

    Returns the result of `series.map`, i.e. one boolean per element.
    """
    def has_string(element):
        return string in element

    return series.map(has_string)

In [None]:
# For every bout in feeding_bouts: load the data set, build a table with one
# row per behavior, quantify force and neural activity for each behavior, draw
# a sanity-check figure, and finally concatenate all per-bout tables into df_all.

# use Neo RawIO lazy loading to load much faster and using less memory
# - with lazy=True, filtering parameters specified in metadata are ignored
# - with lazy=True, loading via time_slice requires neo>=0.8.0.dev
lazy = True # IMPORTANT: force and I2 filters affect smoothness and possibly threshold crossings and spike detection

# load the metadata containing file paths
metadata = neurotic.MetadataSelector(file='../../data/metadata.yml')

# filter epochs for each bout and perform calculations
df_list = []
last_data_set_name = None
for (animal, food, bout_index), (data_set_name, channels_to_keep, time_window, epoch_types_to_keep) in feeding_bouts.items():

    ###
    ### LOCATE BEHAVIOR EPOCHS AND SUBEPOCHS
    ###

    metadata.select(data_set_name)
    if data_set_name != last_data_set_name:
        # only (re)load when the data set changes; consecutive bouts from the
        # same data set reuse the block already in memory.
        # Names are compared by value: an identity test ('is') on strings only
        # works when the interpreter happens to share the string objects.
        blk = neurotic.load_dataset(metadata, lazy=lazy)
    last_data_set_name = data_set_name

    # amplitude discriminators of the selected data set; a missing entry
    # (None) is treated as "no discriminators" everywhere below
    discriminators = metadata['amplitude_discriminators'] or []

    # construct a query for locating behaviors
    behavior_query = f'(Type in {epoch_types_to_keep}) & ({time_window[0]} <= Start) & (End <= {time_window[1]})'

    # construct queries for locating epochs associated with each behavior
    # - each query should match at most one epoch
    # - dictionary keys are used as prefixes for the names of new columns
    # - @behavior_start and @behavior_end are presumably bound per behavior
    #   inside BehaviorsDataFrame (utils); that code is not visible here
    subepoch_queries = {
        # must start no earlier than 0.5 seconds before behavior and end within it
        'I2 protraction activity':
            '(Type == "I2 protraction activity") & '
            '(@behavior_start-0.5 <= Start) & (End <= @behavior_end)',

        # must be fully contained within behavior
        'B8 activity':
            '(Type == "B8 activity") & '
            '(@behavior_start <= Start) & (End <= @behavior_end)',

        # must be fully contained within behavior
        'B3/6/9/10 activity':
            '(Type == "B3/6/9/10 activity") & '
            '(@behavior_start <= Start) & (End <= @behavior_end)',

        # must start within 2 seconds of behavior end
        'B38 activity':
            '(Type == "B38 activity") & '
            '(@behavior_end-2 <= Start) & (Start <= @behavior_end+2)',

        # must be fully contained within behavior
        'Force dip':
            '(Type == "Force dip") & '
            '(@behavior_start <= Start) & (End <= @behavior_end)',

        # must start within 2 seconds of behavior end
        'Force shoulder':
            '(Type == "Force shoulder") & '
            '(@behavior_end-2 <= Start) & (Start <= @behavior_end+2)',
    }

    # construct a table in which each row is a behavior and subepoch data
    # is added as columns, e.g. df['Force dip start (s)']
    df = BehaviorsDataFrame(blk.segments[0].epochs, behavior_query, subepoch_queries)

    ###
    ### PERFORM CALCULATIONS
    ###

    # sanity check: plot all channels for entire time window
    plt.figure(figsize=(8,10))
    axes = []
    # units are matched to channels_to_keep by position, so this assumes
    # exactly four voltage channels followed by the force channel
    channel_units = ['uV', 'uV', 'uV', 'uV', 'mN']
    for i, channel in enumerate(channels_to_keep):
        if i == 0:
            ax = plt.subplot(len(channels_to_keep), 1, i+1)
        else:
            # all later subplots share the time axis of the first one
            ax = plt.subplot(len(channels_to_keep), 1, i+1, sharex=axes[0])
        axes += [ax]
        sig = next((sig for sig in blk.segments[0].analogsignals if sig.name==channel), None)
        if sig is None:
            raise Exception(f'For data set "{data_set_name}", channel "{channel}" could not be found')
        sig = sig.time_slice(time_window[0]*pq.s, time_window[1]*pq.s)
        sig = sig.rescale(channel_units[i])
        plt.plot(sig.times, sig.magnitude, c='0.8', lw=1, zorder=-1)
        if i == 0:
            plt.title(f'({animal}, {food}, {bout_index}): {data_set_name}')
        plt.xlabel('Time (s)')
        plt.ylabel(sig.name + ' (' + sig.units.dimensionality.string + ')')



    # quantify force in each behavior
    channel = 'Force'
    for i in df.index:

        behavior_start = df.loc[i, 'Start (s)']*pq.s
        behavior_end = df.loc[i, 'End (s)']*pq.s
        force_dip_start = df.loc[i, 'Force dip start (s)']*pq.s
        # the force rise is defined to begin where the force dip ends
        force_rise_start = df.loc[i, 'Force rise start (s)'] = df.loc[i, 'Force dip end (s)']*pq.s

        # get the force channel
        sig = next((sig for sig in blk.segments[0].analogsignals if sig.name==channel), None)
        if sig is None:
            raise Exception(f'For data set "{data_set_name}", channel "{channel}" could not be found')

        # get force from a little before start of dip to 3 seconds after behavior
        sig = sig.time_slice(force_dip_start - 0.01*pq.s, behavior_end + 3*pq.s)
        sig = sig.rescale('mN')

        # find force peak, force baseline, and the force at 80% of peak height relative to baseline
        # - the peak is the first event returned by peak_detection with a 0 mN threshold
        # - the baseline is the force at the start of the rise
        force_peak_time = elephant.spike_train_generation.peak_detection(sig, 0*pq.mN)[0]
        force_peak = df.loc[i, 'Force peak (mN)'] = sig[sig.time_index(force_peak_time)][0]
        force_baseline = df.loc[i, 'Force baseline (mN)'] = sig[sig.time_index(force_rise_start)][0]
        force_80percent = df.loc[i, 'Force 80%-height (mN)'] = force_baseline + 0.8*(force_peak-force_baseline)

        # find time when force first rises above the 80%-height threshold
        crossings = elephant.spike_train_generation.threshold_detection(sig, force_80percent, 'above')
        force_high_start = df.loc[i, 'Force high start (s)'] = crossings[np.where(crossings > force_rise_start)[0][0]].rescale('s')

        # find time when force drops below the 80%-height threshold after peaking
        crossings = elephant.spike_train_generation.threshold_detection(sig, force_80percent, 'below')
        force_rise_end = df.loc[i, 'Force rise end (s)'] = crossings[np.where(crossings > force_peak_time)[0][0]].rescale('s')

        # find force rise duration and force plateau duration
        force_rise_duration = df.loc[i, 'Force rise duration (s)'] = force_rise_end - force_rise_start
        force_high_duration = df.loc[i, 'Force high duration (s)'] = force_rise_end - force_high_start

        # find average slope during initial rising phase
        force_slope = df.loc[i, 'Force slope (mN/s)'] = ((force_80percent-force_baseline)/(force_high_start-force_rise_start)).rescale('mN/s')

        # sanity check: plot 80%-height force threshold
        plt.sca(axes[channels_to_keep.index(channel)])
        plt.plot([force_rise_start, force_rise_end], [force_80percent, force_80percent], c='gray', lw=1, ls=':')

        # sanity check: plot force dip
        sig3 = sig.time_slice(force_dip_start, force_rise_start)
        plt.plot(sig3.times, sig3.magnitude, c='#6666ff', lw=2)

        # sanity check: plot force rise
        sig2 = sig.time_slice(force_rise_start, force_rise_end)
        plt.plot(sig2.times, sig2.magnitude, c='#dd5500', lw=2)

        # sanity check: plot force shoulder (skipped when no shoulder was found)
        force_shoulder_start = df.loc[i, 'Force shoulder start (s)']*pq.s
        force_shoulder_end = df.loc[i, 'Force shoulder end (s)']*pq.s
        if np.isfinite(force_shoulder_start):
            sig4 = sig.time_slice(force_shoulder_start, force_shoulder_end)
            plt.plot(sig4.times, sig4.magnitude, c='#00dd00', lw=2)

        # sanity check: plot force baseline, peak, and threshold crossings
        plt.plot([force_rise_start], [force_baseline],  marker=6, markersize=5, color='k')
        plt.plot([force_peak_time],  [force_peak],      marker=7, markersize=5, color='k')
        plt.plot([force_high_start], [force_80percent], marker=5, markersize=5, color='k')
        plt.plot([force_rise_end],   [force_80percent], marker=4, markersize=5, color='k')



    # spike train columns must have type object, which
    # can be accomplished by initializing with None
    units = [
        'I2 spikes',
        'B8a/b',
        'B3',
        'B6/B9',
        'B38',
    ]
    for unit in units:
        df[unit+' spike train'] = None
        df[unit+' inter-spike intervals'] = None



    # find spike trains in each behavior
    for i in df.index:
        behavior_start = df.loc[i, 'Start (s)']*pq.s
        behavior_end = df.loc[i, 'End (s)']*pq.s
        if lazy:
            # spike trains are not available from the lazily loaded block, so
            # detect spikes here with each amplitude discriminator
            for discriminator in discriminators:
                sig = next((sig for sig in blk.segments[0].analogsignals if sig.name == discriminator['channel']), None)
                if sig is not None:
                    sig = sig.time_slice(behavior_start - 5*pq.s, behavior_end + 5*pq.s)
                    st = _detect_spikes(sig, discriminator, blk.segments[0].epochs)
                    # keep only the spikes inside the discriminator's epoch for this behavior
                    st_epoch_start = df.loc[i, discriminator['epoch']+' start (s)']*pq.s
                    st_epoch_end = df.loc[i, discriminator['epoch']+' end (s)']*pq.s
                    st = st.time_slice(st_epoch_start, st_epoch_end)
                    df.at[i, st.name+' spike train'] = st # 'at', not 'loc', is important for inserting list into cell
        else:
            for spiketrain in blk.segments[0].spiketrains:
                discriminator = next((d for d in discriminators if d['name'] == spiketrain.name), None)
                if discriminator is None:
                    raise Exception(f'For data set "{data_set_name}", discriminator "{spiketrain.name}" could not be found')
                st_epoch_start = df.loc[i, discriminator['epoch']+' start (s)']*pq.s
                st_epoch_end = df.loc[i, discriminator['epoch']+' end (s)']*pq.s
                st = spiketrain.time_slice(st_epoch_start, st_epoch_end)
                df.at[i, st.name+' spike train'] = st # 'at', not 'loc', is important for inserting list into cell



    # quantify spike trains in each behavior
    for i in df.index:
        behavior_start = df.loc[i, 'Start (s)']*pq.s
        behavior_end = df.loc[i, 'End (s)']*pq.s
        for unit in units:
            st = df.loc[i, unit+' spike train']
            if st is not None:
                df.at[i, unit+' inter-spike intervals'] = elephant.statistics.isi(st) # 'at', not 'loc', is important for inserting list into cell
                df.loc[i, unit+' spike count'] = len(st)
                if len(st) >= 2:
                    # a burst spans from the first to the last spike
                    burst_duration = np.ptp(st)
                    df.loc[i, unit+' burst start (s)'] = st.times[0]
                    df.loc[i, unit+' burst end (s)'] = st.times[-1]
                    df.loc[i, unit+' burst duration (s)'] = burst_duration
                    df.loc[i, unit+' burst mean frequency (Hz)'] = ((len(st)-1)/burst_duration).rescale('Hz')
                else:
                    # fewer than 2 spikes: no burst start/end is recorded
                    df.loc[i, unit+' burst duration (s)'] = 0
                    df.loc[i, unit+' burst mean frequency (Hz)'] = 0



    # sanity check: plot spikes
    for j, i in enumerate(df.index):
        behavior_start = df.loc[i, 'Start (s)']*pq.s
        behavior_end = df.loc[i, 'End (s)']*pq.s
        for k, unit in enumerate(units):
            marker = ['.', 'x'][j%2] # alternate markers between behaviors
            color = f'C{k%10}' # cycle through colors with units

            st = df.loc[i, unit+' spike train']
            # cells for units without a detected spike train are still None
            if st is not None and st.size > 0:
                channel = st.annotations['channels'][0]
                plt.sca(axes[channels_to_keep.index(channel)])
                sig = next((sig for sig in blk.segments[0].analogsignals if sig.name == channel), None)
                if sig is None:
                    raise Exception(f'For data set "{data_set_name}", channel "{channel}" could not be found')
                sig = sig.time_slice(behavior_start - 5*pq.s, behavior_end + 5*pq.s)
                sig = sig.rescale(channel_units[channels_to_keep.index(channel)])
                # mark each spike at the signal value sampled at the spike time
                spike_amplitudes = np.concatenate([sig[sig.time_index(t)] for t in st] + [[]]) * pq.Quantity(sig.units)
                plt.scatter(st.times.rescale('s'), spike_amplitudes, marker=marker, c=color)

                discriminator = next((d for d in discriminators if d['name'] == st.name), None)
                if discriminator is None:
                    raise Exception(f'For data set "{data_set_name}", discriminator "{st.name}" could not be found')
                # outline the burst (time extent) and the discriminator's amplitude
                # window; burst start/end only exist for trains with >= 2 spikes
                if len(st) >= 2:
                    left = df.loc[i, unit+' burst start (s)']*pq.s
                    right = df.loc[i, unit+' burst end (s)']*pq.s
                    width = right-left
                    bottom = pq.Quantity(discriminator['amplitude'][0], discriminator['units']).rescale(sig.units)
                    top = pq.Quantity(discriminator['amplitude'][1], discriminator['units']).rescale(sig.units)
                    height = top-bottom
                    rect = patches.Rectangle((left, bottom), width, height, edgecolor=color, fill=False)
                    plt.gca().add_patch(rect)



    # find RAUC and mean voltage for each burst in each behavior
    # - each subepoch type is measured on the channel at a fixed position in
    #   channels_to_keep
    unit_to_channel_mapping = {
        'I2 protraction activity': channels_to_keep[0],
        'B8 activity': channels_to_keep[1],
        'B3/6/9/10 activity': channels_to_keep[2],
        'B38 activity': channels_to_keep[2],
    }
    for i in df.index:

        behavior_start = df.loc[i, 'Start (s)']*pq.s
        behavior_end = df.loc[i, 'End (s)']*pq.s

        for unit, channel in unit_to_channel_mapping.items():
            burst_start = df.loc[i, unit+' start (s)']*pq.s
            burst_end = df.loc[i, unit+' end (s)']*pq.s
            burst_duration = df.loc[i, unit+' duration (s)']*pq.s

            # get the neural channel
            sig = next((sig for sig in blk.segments[0].analogsignals if sig.name==channel), None)
            if sig is None:
                raise Exception(f'For data set "{data_set_name}", channel "{channel}" could not be found')

            # get the signal for the behavior with 5 seconds cushion before and after (for better baseline estimation)
            sig = sig.time_slice(behavior_start - 5*pq.s, behavior_end + 5*pq.s)
            sig = sig.rescale('uV')

            # find burst RAUC and mean voltage (RAUC divided by burst duration)
            rauc = df.loc[i, unit+' RAUC (μV·s)'] = elephant.signal_processing.rauc(sig, baseline='mean', t_start=burst_start, t_stop=burst_end).rescale('uV*s')
            mean_rect_voltage = df.loc[i, unit+' mean rectified voltage (μV)'] = rauc/burst_duration



    # find timing delays between neural and force events
    # - every delay is (force event time) - (neural event time), so a positive
    #   value means the force event came after the neural event
    for i in df.index:
        i2_burst_start = df.loc[i, 'I2 protraction activity start (s)']*pq.s
        i2_burst_end = df.loc[i, 'I2 protraction activity end (s)']*pq.s
        b8_burst_start = df.loc[i, 'B8 activity start (s)']*pq.s
        b8_burst_end = df.loc[i, 'B8 activity end (s)']*pq.s
        bn2_burst_start = df.loc[i, 'B3/6/9/10 activity start (s)']*pq.s
        bn2_burst_end = df.loc[i, 'B3/6/9/10 activity end (s)']*pq.s
        b38_burst_start = df.loc[i, 'B38 activity start (s)']*pq.s
        b38_burst_end = df.loc[i, 'B38 activity end (s)']*pq.s
        force_dip_start = df.loc[i, 'Force dip start (s)']*pq.s
        force_rise_start = df.loc[i, 'Force rise start (s)']*pq.s
        force_rise_end = df.loc[i, 'Force rise end (s)']*pq.s
        force_shoulder_start = df.loc[i, 'Force shoulder start (s)']*pq.s
        force_shoulder_end = df.loc[i, 'Force shoulder end (s)']*pq.s

        i2_force_dip_delay = df.loc[i, 'Delay from I2 end to force dip (s)'] = force_dip_start - i2_burst_end

        b8_force_start_delay = df.loc[i, 'Delay from B8 start to force start (s)'] = force_rise_start - b8_burst_start
        b8_force_end_delay = df.loc[i, 'Delay from B8 end to force end (s)'] = force_rise_end - b8_burst_end

        bn2_force_start_delay = df.loc[i, 'Delay from B3/6/9/10 start to force start (s)'] = force_rise_start - bn2_burst_start
        bn2_force_end_delay = df.loc[i, 'Delay from B3/6/9/10 end to force end (s)'] = force_rise_end - bn2_burst_end

        b38_shoulder_start_delay = df.loc[i, 'Delay from B38 start to shoulder start (s)'] = force_shoulder_start - b38_burst_start
        b38_shoulder_end_delay = df.loc[i, 'Delay from B38 end to shoulder end (s)'] = force_shoulder_end - b38_burst_end



    plt.tight_layout()



    # renumber behaviors assuming all behaviors are from a single contiguous sequence
    df = df.sort_values('Start (s)').reset_index(drop=True).rename_axis('Behavior_index')

    # index the table on 4 variables
    df['Animal'] = animal
    df['Food'] = food
    df['Bout_index'] = bout_index
    df = df.reset_index().set_index(['Animal', 'Food', 'Bout_index', 'Behavior_index'])

    df_list += [df]

# one table for all bouts, indexed by (Animal, Food, Bout_index, Behavior_index)
df_all = pd.concat(df_list, sort=False).sort_index()

# Plots

In [None]:
def scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim, trend=False, tooltips=False, padding=0.05):
    """Scatter-plot two columns of df_all on `ax`, one series per data subset.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on.
    data_subsets : dict
        Maps a legend label to a query string for df_all; entries whose query
        is None are skipped.
    xlabel, ylabel : str
        Names of the df_all columns plotted on x and y; also used as the axis
        labels.
    xlim, ylim : None or [low, high]
        Axis limits. None, or a None entry, is replaced by the data extent
        enlarged by `padding` whenever there is at least one data point. The
        sequences passed in are never modified.
    trend : bool
        If True, fit an ordinary least squares line through all plotted
        points, print its statistics, and draw it (skipped when there are no
        points).
    tooltips : bool
        If True, show the index of a point in a tooltip while the mouse hovers
        over it.
    padding : float
        Fraction of the data range added on each side for automatic limits.
    """

    # work on copies so that filling in None entries below cannot change the
    # caller's lists (and so that tuples are accepted too)
    xlim = None if xlim is None else list(xlim)
    ylim = None if ylim is None else list(ylim)

    for j, (label, query) in enumerate(data_subsets.items()):
        if query is not None:
            df = df_all.query(query)
            # cycle through the style lists when there are more subsets than styles
            ax.scatter(df[xlabel], df[ylabel],
                       label=label, marker=markers[j % len(markers)],
                       c=colors[j % len(colors)], clip_on=False)

    # every row matched by any subset that has both coordinates
    all_points = df_all.query(query_union(data_subsets.values()))[[xlabel, ylabel]].dropna()

    if len(all_points) > 0:

        xrange = np.ptp(all_points.iloc[:, 0])
        xmin = min(all_points.iloc[:, 0]) - xrange * padding
        xmax = max(all_points.iloc[:, 0]) + xrange * padding
        if xlim is None:
            xlim = [xmin, xmax]
        if xlim[0] is None:
            xlim[0] = xmin
        if xlim[1] is None:
            xlim[1] = xmax

        yrange = np.ptp(all_points.iloc[:, 1])
        ymin = min(all_points.iloc[:, 1]) - yrange * padding
        ymax = max(all_points.iloc[:, 1]) + yrange * padding
        if ylim is None:
            ylim = [ymin, ymax]
        if ylim[0] is None:
            ylim[0] = ymin
        if ylim[1] is None:
            ylim[1] = ymax

    # a regression needs data; without points there is nothing to fit and the
    # x limits may still be unresolved
    if trend and len(all_points) > 0:
        model = sm.OLS(all_points.iloc[:,1], sm.add_constant(all_points.iloc[:,0])).fit()
        # the fitted parameters are label-indexed, so positions are read with
        # .iloc: 0 is the added constant, 1 is the slope
        model_stats = 'R$^2$ = {:.2f}, p = {:.3f}, n = {}'.format(model.rsquared, model.pvalues.iloc[1], len(all_points))
        print(model_stats)
        model_x = np.linspace(xlim[0], xlim[1], 100)
        model_y = model.params.iloc[0] + model.params.iloc[1] * model_x
        ax.plot(model_x, model_y, label=model_stats, color='gray')

    if tooltips:
        # create a transparent layer containing all points for detecting mouse events
        sc = ax.scatter(all_points.iloc[:, 0], all_points.iloc[:, 1], label=None,
                        alpha=0, # transparent points
                        s=0.001, # small radius of tooltip activation
                       )

        # initialize a hidden empty tooltip
        annot = ax.annotate(
            '', xy=(0, 0),                              # initialize empty at origin
            xytext=(0, 15), textcoords='offset points', # position text above point
            color='k', size=10, ha='center',            # small black centered text
            bbox=dict(fc='w', lw=0, alpha=0.6),         # transparent white background
            arrowprops=dict(shrink=0, headlength=7, headwidth=7, width=0, lw=0, color='k'), # small black arrow
        )
        annot.set_visible(False)

        # prepare tooltip contents: the index entries of each point, space-separated
        tooltip_text = [' '.join([str(x) for x in index]) for index in list(all_points.index)]

        # bind the animation of tooltips to a hover event
        def hover(event):
            if event.inaxes == ax:
                cont, ind = sc.contains(event)
                if cont:
                    # show the tooltip at the first point under the cursor
                    point_index = ind['ind'][0]
                    pos = sc.get_offsets()[point_index]
                    annot.xy = pos
                    annot.set_text(tooltip_text[point_index])
                    annot.set_visible(True)
                    ax.figure.canvas.draw_idle()
                elif annot.get_visible():
                    # the cursor left the point: hide the tooltip again
                    annot.set_visible(False)
                    ax.figure.canvas.draw_idle()
        ax.figure.canvas.mpl_connect('motion_notify_event', hover)

    ax.set_xlabel(xlabel)
    ax.set_xlim(xlim)
    ax.set_ylabel(ylabel)
    ax.set_ylim(ylim)

## Figure 3: B3/6/9/10 and the rise, or B8 and the rise

__TODO:__

* Plot mean voltage on BN2 during B3/6/9/10 activity vs. slope or relative rise in force?
* Use firing frequency thresholds to crop neural activity window ?
* Example panel illustrating measurements, with example data point marked in scatter plot

In [None]:
# B3/6/9/10 activity duration vs. force rise duration, one series per animal
ax = plt.figure(figsize=(5, 5)).gca()

data_subsets = {
    label: label2query(label)
    for label in ('JG08 Tape nori', 'JG12 Tape nori')
}

xlabel = 'B3/6/9/10 activity duration (s)'
ylabel = 'Force rise duration (s)'
xlim = [0, 8]
ylim = [0, 8]

scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim, trend=False, tooltips=False)

# dotted identity line (y = x) behind the data for reference
plt.plot([0, 999], [0, 999], linestyle=':', color='gray', zorder=0)
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

In [None]:
# B8 activity duration vs. force rise duration, one series per animal
ax = plt.figure(figsize=(5, 5)).gca()

data_subsets = {
    label: label2query(label)
    for label in ('JG08 Tape nori', 'JG12 Tape nori')
}

xlabel = 'B8 activity duration (s)'
ylabel = 'Force rise duration (s)'
xlim = [0, 8]
ylim = [0, 8]

scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim, trend=False, tooltips=False)

# dotted identity line (y = x) behind the data for reference
plt.plot([0, 999], [0, 999], linestyle=':', color='gray', zorder=0)
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

In [None]:
# distribution of delays from B3/6/9/10 burst start/end to force start/end, per animal
ax = plt.figure(figsize=(5, 5)).gca()

# long format: one row per (behavior, start-or-end delay)
df = (
    df_all.reset_index()
    .rename(columns={
        'Delay from B3/6/9/10 start to force start (s)': 'Start of burst',
        'Delay from B3/6/9/10 end to force end (s)': 'End of burst',
    })
    .melt(
        id_vars=['Animal', 'Food', 'Bout_index', 'Behavior_index'],
        value_vars=['Start of burst', 'End of burst'],
        var_name='When?',
        value_name='Delay from B3/6/9/10 to force (s)',
    )
)

sns.violinplot(data=df, x='Delay from B3/6/9/10 to force (s)', y='Animal', hue='When?', inner='points')

# dotted reference line at zero delay
plt.axvline(0, linestyle=':', color='gray', zorder=-1)

plt.tight_layout()

In [None]:
# distribution of delays from B8 burst start/end to force start/end, per animal
ax = plt.figure(figsize=(5, 5)).gca()

# long format: one row per (behavior, start-or-end delay)
df = (
    df_all.reset_index()
    .rename(columns={
        'Delay from B8 start to force start (s)': 'Start of burst',
        'Delay from B8 end to force end (s)': 'End of burst',
    })
    .melt(
        id_vars=['Animal', 'Food', 'Bout_index', 'Behavior_index'],
        value_vars=['Start of burst', 'End of burst'],
        var_name='When?',
        value_name='Delay from B8 to force (s)',
    )
)

sns.violinplot(data=df, x='Delay from B8 to force (s)', y='Animal', hue='When?', inner='points')

# dotted reference line at zero delay
plt.axvline(0, linestyle=':', color='gray', zorder=-1)

plt.tight_layout()

__TODO__:

* The section of BN2 that is compared to force slope should be calculated backwards from the threshold crossing, especially to exclude neural activity responsible for the plateau, after the slope has leveled out

In [None]:
# B3/6/9/10 mean rectified voltage vs. force slope, with hover tooltips
ax = plt.figure(figsize=(5, 5)).gca()

data_subsets = {
    label: label2query(label)
    for label in ('JG08 Tape nori', 'JG12 Tape nori')
}

# lower limits pinned at 0; upper limits (None) are derived from the data
xlabel = 'B3/6/9/10 activity mean rectified voltage (μV)'
ylabel = 'Force slope (mN/s)'
xlim = [0, None]
ylim = [0, None]

scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim, trend=False, tooltips=True)
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

In [None]:
# B8 mean rectified voltage vs. force slope, with hover tooltips
ax = plt.figure(figsize=(5, 5)).gca()

data_subsets = {
    label: label2query(label)
    for label in ('JG08 Tape nori', 'JG12 Tape nori')
}

# lower limits pinned at 0; upper limits (None) are derived from the data
xlabel = 'B8 activity mean rectified voltage (μV)'
ylabel = 'Force slope (mN/s)'
xlim = [0, None]
ylim = [0, None]

scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim, trend=False, tooltips=True)
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

## Figure 4: Peak protraction and the dip

In [None]:
# distribution of delays from the end of I2 activity to the force dip, per animal
ax = plt.figure(figsize=(5, 5)).gca()

# flatten the 4-level index into ordinary columns so 'Animal' can be plotted
df = df_all.reset_index()

sns.violinplot(data=df, x='Delay from I2 end to force dip (s)', y='Animal', inner='points')

# dotted reference line at zero delay
plt.axvline(0, linestyle=':', color='gray', zorder=-1)

plt.tight_layout()

## Figure 6: B3 and the plateau

In [None]:
# B3 burst duration vs. duration of high force, one series per animal
ax = plt.figure(figsize=(5, 5)).gca()

data_subsets = {
    label: label2query(label)
    for label in ('JG08 Tape nori', 'JG12 Tape nori')
}

xlabel = 'B3 burst duration (s)'
ylabel = 'Force high duration (s)'
xlim = [0, 5]
ylim = [0, 5]

scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim, trend=False, tooltips=False)

# dotted identity line (y = x) behind the data for reference
plt.plot([0, 999], [0, 999], linestyle=':', color='gray', zorder=0)
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

## Figure 7: B38 and the shoulder

In [None]:
# B38 activity duration vs. force shoulder duration, one series per animal
ax = plt.figure(figsize=(5, 5)).gca()

data_subsets = {
    label: label2query(label)
    for label in ('JG08 Tape nori', 'JG12 Tape nori')
}

xlabel = 'B38 activity duration (s)'
ylabel = 'Force shoulder duration (s)'
xlim = [0, 6]
ylim = [0, 6]

scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim, trend=False, tooltips=False)

# dotted identity line (y = x) behind the data for reference
plt.plot([0, 999], [0, 999], linestyle=':', color='gray', zorder=0)
ax.legend()
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

In [None]:
# distribution of delays from B38 burst start/end to shoulder start/end, per animal
ax = plt.figure(figsize=(5, 5)).gca()

# long format: one row per (behavior, start-or-end delay)
df = (
    df_all.reset_index()
    .rename(columns={
        'Delay from B38 start to shoulder start (s)': 'Start of burst',
        'Delay from B38 end to shoulder end (s)': 'End of burst',
    })
    .melt(
        id_vars=['Animal', 'Food', 'Bout_index', 'Behavior_index'],
        value_vars=['Start of burst', 'End of burst'],
        var_name='When?',
        value_name='Delay from B38 to shoulder (s)',
    )
)

sns.violinplot(data=df, x='Delay from B38 to shoulder (s)', y='Animal', hue='When?', inner='points')

# dotted reference line at zero delay
plt.axvline(0, linestyle=':', color='gray', zorder=-1)

plt.tight_layout()