# Import Packages

In [None]:
# add the directory containing modules to the path
import sys
sys.path.append('../modules')

In [None]:
################################################################################
# NUMPY
# conda install numpy

import numpy as np

################################################################################
# SCIPY
# conda install scipy

# import scipy as sp

################################################################################
# MATPLOTLIB
# conda install matplotlib

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# from matplotlib.ticker import MultipleLocator

################################################################################
# SEABORN
# conda install seaborn

import seaborn as sns

################################################################################
# NEO
# pip install neo>=0.7.1
# - AxoGraph support requires axographio to be installed: pip install axographio

# import neo

################################################################################
# QUANTITIES
# conda install quantities

import quantities as pq
pq.markup.config.use_unicode = True  # allow symbols like mu for micro in output
pq.mN = pq.UnitQuantity('millinewton', pq.N/1e3, symbol = 'mN');  # define millinewton

################################################################################
# ELEPHANT
# pip install elephant>=0.6.2

import elephant

################################################################################
# PANDAS
# conda install pandas

import pandas as pd

################################################################################
# STATSMODELS
# conda install statsmodels

import statsmodels.api as sm

################################################################################
# SPM1D - One-Dimensional Statistical Parametric Mapping
# pip install spm1d

# import spm1d

################################################################################
# EPHYVIEWER
# pip install git+https://github.com/jpgill86/ephyviewer.git@experimental
# - requires PyAV: conda install -c conda-forge av

# import ephyviewer

################################################################################
# ParseMetadata
# - requires ipywidgets: conda install ipywidgets
# - requires yaml:       conda install pyyaml

from ParseMetadata import LoadMetadata

################################################################################
# ImportData

from ImportData import LoadAndPrepareData

################################################################################
# NeoUtilities
# - requires pylttb: pip install pylttb

from NeoUtilities import BehaviorsDataFrame#, NeoEpochToDataFrame, CausalAlphaKernel, DownsampleNeoSignal

################################################################################
# EphyviewerConfigurator

# from EphyviewerConfigurator import EphyviewerConfigurator

################################################################################
# NeoToEphyviewerBridge

# from NeoToEphyviewerBridge import NeoSegmentToEphyviewerSources, PlotExampleWithEphyviewer

# IPython Magics

In [None]:
# make figures interactive and open in a separate window
# %matplotlib qt

# make figures interactive and inline
%matplotlib notebook

# make figures non-interactive and inline
# %matplotlib inline

# Data Parameters

In [None]:
# Feeding bouts to analyze, one entry per bout.
# Key:   (animal, food, bout_index) -- becomes part of the df_all index built below;
#        bout_index distinguishes several bouts of the same animal/food pair.
# Value: (data_set_name, channels_to_keep, time_window, epoch_types_to_keep)
#   data_set_name       -- key into the metadata returned by LoadMetadata
#   channels_to_keep    -- channels quantified for each behavior (the JG05 entries have no 'RN')
#   time_window         -- [start, end]; only behaviors with start <= Start and End <= end are kept
#                          (presumably seconds, matching the 'Start (s)' columns -- confirm)
#   epoch_types_to_keep -- epoch Type labels that count as behaviors for this bout
# NOTE(review): the processing loop only reuses loaded data for consecutive entries
# sharing a data set; the 'JG08 / 2018-06-25 / 001' entries are not adjacent, so that
# data set is loaded more than once.
feeding_bouts = {
    # (animal, food, bout_index): (data_set_name, channels_to_keep, time_window, epoch_types_to_keep)
    ('JG05', 'Regular nori', 0): ('IN VIVO / JG05 / 2018-03-05 / 001', ['I2',       'BN2', 'BN3'], [ 550,  594], ['Swallow (regular 5-cm nori strip)']),
    ('JG05', 'Tape nori',    0): ('IN VIVO / JG05 / 2018-03-05 / 001', ['I2',       'BN2', 'BN3'], [ 745,  827], ['Swallow (tape nori)']),
    ('JG07', 'Regular nori', 0): ('IN VIVO / JG07 / 2018-05-20 / 002', ['I2', 'RN', 'BN2', 'BN3'], [1496, 1527], ['Swallow (regular 5-cm nori strip)']),
    ('JG07', 'Tape nori',    0): ('IN VIVO / JG07 / 2018-05-20 / 002', ['I2', 'RN', 'BN2', 'BN3'], [1036, 1093], ['Swallow (tape nori)']),
    ('JG08', 'Fresh food',   0): ('IN VIVO / JG08 / 2018-06-25 / 001', ['I2', 'RN', 'BN2', 'BN3'], [1170, 1590], ['Swallow (fresh food)']),
    ('JG08', 'Regular nori', 0): ('IN VIVO / JG08 / 2018-06-21 / 001', ['I2', 'RN', 'BN2', 'BN3'], [2462, 2532], ['Swallow (regular 5-cm nori strip)']),
    ('JG08', 'Tape nori',    0): ('IN VIVO / JG08 / 2018-06-21 / 002', ['I2', 'RN', 'BN2', 'BN3'], [ 134,  205], ['Swallow (tape nori)']),
    ('JG08', 'Tape nori',    1): ('IN VIVO / JG08 / 2018-06-21 / 002', ['I2', 'RN', 'BN2', 'BN3'], [ 648,  724], ['Swallow (tape nori)']),
    ('JG08', 'Tape nori',    2): ('IN VIVO / JG08 / 2018-06-21 / 002', ['I2', 'RN', 'BN2', 'BN3'], [1436, 1474], ['Swallow (tape nori)']),
    ('JG08', 'Tubing',       0): ('IN VIVO / JG08 / 2018-06-25 / 001', ['I2', 'RN', 'BN2', 'BN3'], [4026, 4146], ['Swallow (tubing)', 'No movement (tubing)', 'Reposition (tubing)', 'Rejection (tubing)']),
    ('JG08', 'Two-ply nori', 0): ('IN VIVO / JG08 / 2018-06-25 / 001', ['I2', 'RN', 'BN2', 'BN3'], [3256, 3493], ['Swallow (two-ply nori)']),
}

# Import and Process the Data

In [None]:
# load the metadata containing file paths
all_metadata = LoadMetadata(file='../metadata.yml', local_data_root='../../data')

# timing parameters shared by every bout
burst_timing_tolerance = 0.5 # seconds; how far a burst may overextend its behavior in either direction
slow_muscle_lag = 0.2 # seconds; fixed lag after a BN2 burst to account for slow muscles
force_rise_time = 1.5 # seconds; fixed window (after the lag) during which force rises at a steady rate


def find_signal(blk, data_set_name, channel, remove_text=None):
    """Return the first analog signal in the first segment of `blk` named `channel`.

    If `remove_text` is given, every occurrence of it is deleted from each signal
    name before comparing (e.g. '-L' lets a signal named 'BN2-L' match 'BN2').

    Raises LookupError (a subclass of Exception) if no signal matches.
    """
    for sig in blk.segments[0].analogsignals:
        name = sig.name if remove_text is None else sig.name.replace(remove_text, '')
        if name == channel:
            return sig
    raise LookupError(f'For data set "{data_set_name}", channel "{channel}" could not be found')


# filter epochs for each feeding condition and perform calculations
df_list = []
last_data_set_name = None
for (animal, food, bout_index), (data_set_name, channels_to_keep, time_window, epoch_types_to_keep) in feeding_bouts.items():

    ###
    ### LOCATE BEHAVIOR EPOCHS AND SUBEPOCHS
    ###

    # reload only when the data set changes; strings must be compared by value
    # (==) because identity (is) is not guaranteed for equal strings
    if data_set_name != last_data_set_name:
        blk = LoadAndPrepareData(all_metadata[data_set_name])
        last_data_set_name = data_set_name

    # construct a query for locating behaviors
    behavior_query = f'(Type in {epoch_types_to_keep}) & ({time_window[0]} <= Start) & (End <= {time_window[1]})'

    # construct queries for locating epochs associated with each behavior
    # - each query should match at most one epoch
    # - dictionary keys are used as prefixes for the names of new columns
    subepoch_queries = {}

    # look for a force epoch that begins within each behavior
    subepoch_queries['Force'] = '(Type == "force") & (@behavior_start <= Start) & (Start <= @behavior_end)'

    # look for a large hump epoch that begins within each behavior
    subepoch_queries['Force large hump'] = '(Type == "large hump") & (@behavior_start <= Start) & (Start <= @behavior_end)'

    # for each channel, look for a burst epoch that is at least mostly
    # contained within the behavior, allowing for a small discrepancy
    # in start and end time if the burst overextends in either direction
    for channel in channels_to_keep:
        subepoch_queries[channel+' burst'] = f'(Type == "{channel} burst") & (@behavior_start-{burst_timing_tolerance} <= Start) & (End <= @behavior_end+{burst_timing_tolerance})'

    # construct a table in which each row is a behavior and subepoch data
    # is added as columns, e.g. df['Force start (s)']
    df = BehaviorsDataFrame(blk.segments[0].epochs, behavior_query, subepoch_queries)

    # order behaviors chronologically before any neighbor-based calculation,
    # so interbehavior intervals are computed between truly adjacent behaviors
    df = df.sort_values('Start (s)')

    ###
    ### PERFORM CALCULATIONS
    ###

    # add defaults for new columns (NaN remains wherever a quantity cannot be computed)
    new_columns = ['Interval before (s)', 'Interval after (s)']
    for channel in channels_to_keep:
        new_columns += [
            channel+' RAUC (μV·s)',
            channel+' mean rectified voltage (μV)',
            channel+' burst RAUC (μV·s)',
            channel+' burst mean rectified voltage (μV)',
        ]
    new_columns += [
        'Force RAUC (mN·s)',
        'Force mean (mN)',
        'Force peak (mN)',
        'Force RAUC following BN2 burst (mN·s)',
        'Force mean following BN2 burst (mN)',
        'Force peak following BN2 burst (mN)',
        'Force increase following BN2 burst (mN)',
        'Force slope mean following BN2 burst (mN/s)',
        'Force slope peak following BN2 burst (mN/s)',
        'Force slope mean during initial rise (mN/s)',
        'Force large hump duration following BN2 burst (s)',
    ]
    for column in new_columns:
        df[column] = np.nan

    # calculate interbehavior intervals assuming all behaviors are from a single
    # contiguous sequence; the first behavior has no interval before it and the
    # last has no interval after it (both stay NaN)
    interval_before = df['Start (s)'] - df['End (s)'].shift(1)
    df['Interval before (s)'] = interval_before
    df['Interval after (s)'] = interval_before.shift(-1)

    # for each channel in each behavior, find the RAUC and mean rectified voltage
    # over the whole behavior and, when a burst was found, over the burst alone
    for i in df.index:
        behavior_start = df.loc[i, 'Start (s)']
        behavior_end = df.loc[i, 'End (s)']
        behavior_duration = df.loc[i, 'Duration (s)']
        for channel in channels_to_keep:
            full_sig = find_signal(blk, data_set_name, channel, remove_text='-L')

            sig = full_sig.time_slice(behavior_start*pq.s, behavior_end*pq.s)
            rauc = elephant.signal_processing.rauc(sig, baseline='mean').rescale('uV*s')
            df.loc[i, channel+' RAUC (μV·s)'] = rauc
            df.loc[i, channel+' mean rectified voltage (μV)'] = rauc/(behavior_duration*pq.s)

            burst_start = df.loc[i, channel+' burst start (s)']
            if np.isfinite(burst_start):
                burst_end = df.loc[i, channel+' burst end (s)']
                burst_duration = df.loc[i, channel+' burst duration (s)']
                # widen the slice by the same tolerance used to find the burst,
                # so the burst lies inside the sliced signal
                sig = full_sig.time_slice((behavior_start-burst_timing_tolerance)*pq.s, (behavior_end+burst_timing_tolerance)*pq.s)
                rauc = elephant.signal_processing.rauc(sig, baseline='mean', t_start=burst_start*pq.s, t_stop=burst_end*pq.s).rescale('uV*s')
                df.loc[i, channel+' burst RAUC (μV·s)'] = rauc
                df.loc[i, channel+' burst mean rectified voltage (μV)'] = rauc/(burst_duration * pq.s)

    # quantify force
    for i in df.index:
        force_start = df.loc[i, 'Force start (s)']
        force_end = df.loc[i, 'Force end (s)']
        if not np.isfinite(force_start):
            continue

        force_sig = find_signal(blk, data_set_name, 'Force')

        sig = force_sig.time_slice(force_start*pq.s, force_end*pq.s)
        sig_mN = sig.rescale('mN')
        rauc = elephant.signal_processing.rauc(sig).rescale('mN*s')
        df.loc[i, 'Force RAUC (mN·s)'] = rauc
        df.loc[i, 'Force mean (mN)'] = sig_mN.mean()
        df.loc[i, 'Force peak (mN)'] = sig_mN.max()

        # if a BN2 burst also exists, quantify the force following the burst
        # using a fixed lag time to account for slow muscles and a fixed
        # window during which force rises at a steady rate
        bn2_burst_start = df.loc[i, 'BN2 burst start (s)']
        if not np.isfinite(bn2_burst_start):
            continue
        bn2_burst_end = df.loc[i, 'BN2 burst end (s)']
        response_start = bn2_burst_start + slow_muscle_lag

        sig = force_sig.time_slice(response_start*pq.s, (bn2_burst_end+slow_muscle_lag)*pq.s)
        sig_mN = sig.rescale('mN')
        rauc = elephant.signal_processing.rauc(sig).rescale('mN*s')
        slope = elephant.signal_processing.derivative(sig).rescale('mN/s')
        df.loc[i, 'Force RAUC following BN2 burst (mN·s)'] = rauc
        df.loc[i, 'Force mean following BN2 burst (mN)'] = sig_mN.mean()
        df.loc[i, 'Force peak following BN2 burst (mN)'] = sig_mN.max()
        df.loc[i, 'Force increase following BN2 burst (mN)'] = sig_mN.max()-sig_mN[0]
        df.loc[i, 'Force slope mean following BN2 burst (mN/s)'] = slope.mean()
        df.loc[i, 'Force slope peak following BN2 burst (mN/s)'] = slope.max()
        # clamp the initial-rise window to the available derivative samples
        df.loc[i, 'Force slope mean during initial rise (mN/s)'] = slope.time_slice(
            max(slope.t_start, response_start*pq.s),
            min(slope.t_stop,  (response_start+force_rise_time)*pq.s)
        ).mean()

        large_hump_end = df.loc[i, 'Force large hump end (s)']
        if np.isfinite(large_hump_end):
            df.loc[i, 'Force large hump duration following BN2 burst (s)'] = large_hump_end-response_start

    # renumber behaviors (already in chronological order) assuming all
    # behaviors are from a single contiguous sequence
    df = df.reset_index(drop=True).rename_axis('Behavior_index')

    # index the table on 4 variables
    df['Animal'] = animal
    df['Food'] = food
    df['Bout_index'] = bout_index
    df = df.reset_index().set_index(['Animal', 'Food', 'Bout_index', 'Behavior_index'])

    df_list.append(df)

df_all = pd.concat(df_list, sort=False).sort_index()

In [None]:
# check for missing data: for each subepoch type, report every animal/food group in
# which all, or only some, of the behaviors have a matching subepoch (groups with
# none are not reported)
for subepoch_type in ['Force', 'Force large hump', 'I2 burst', 'RN burst', 'BN2 burst', 'BN3 burst']:
    print(f'-- {subepoch_type} --')
    start_column = f'{subepoch_type} start (s)'
    # a subepoch type with no column at all has nothing to report; checking
    # explicitly (rather than catching every exception) keeps real errors visible
    if start_column in df_all.columns:
        for (animal, food), df in df_all.groupby(['Animal', 'Food']):
            label = ' '.join([animal, food])
            has_subepoch = np.isfinite(df[start_column])
            if has_subepoch.all():
                print(f'{label}: all behaviors have {subepoch_type}')
            elif has_subepoch.any():
                print(f'{label}: some behaviors have {subepoch_type} **BUT NOT ALL**')
    print()

In [None]:
def query_union(queries):
    """Combine query strings into one DataFrame.query string matching any of them.

    Each non-None query is wrapped in parentheses and the pieces are joined with
    '|'; None entries are skipped. An iterable with no usable queries gives ''.
    """
    pieces = []
    for q in queries:
        if q is None:
            continue
        pieces.append('({})'.format(q))
    return '|'.join(pieces)

def label2query(label):
    """Convert a panel label such as 'JG08 Tape nori' into a DataFrame.query string.

    The label is split at its first space: the part before it is the animal ID and
    the remainder is the food name. Splitting (rather than slicing fixed positions)
    supports animal IDs of any length.

    Raises ValueError if the label contains no space.
    """
    animal, food = label.split(' ', 1)
    query = f'(Animal == "{animal}") & (Food == "{food}")'
    return query

def contains(series, string):
    """Elementwise substring test, e.g. for use in DataFrame.query as @contains(Type, "tubing").

    Applies `string in element` to every element via `.map`, so the result has
    the same index as `series` and holds one boolean per element.
    """
    def _has_substring(element):
        return string in element

    return series.map(_has_substring)

In [None]:
# example_queries = [
# #     'Food == "Tape nori"',
#     '(Animal == "JG05") & (Behavior_index < 2)',
# #     '@contains(Food, "nori")',
# #     'not @contains(Food, "nori")',
#     '@contains(Type, "Rejection")',
# #     '(@contains(Type, "Swallow")) & (@contains(Type, "tubing"))',
# ]

# for q in example_queries:
#     display(df_all.query(q))

# display(df_all.query(query_union(example_queries)))

# Plots

In [None]:
# color map (alternatives that were tried: plt.cm.brg, plt.cm.RdBu)
# NOTE(review): `cm` is not referenced by the plotting cells shown; confirm it is still needed
cm = plt.cm.cool

# seaborn theme settings; context='poster' is an alternative for larger text
# NOTE(review): the font must be installed, otherwise matplotlib falls back with a warning
seaborn_theme = {
    'style': 'ticks',
    'font_scale': 1,
    'font': 'Palatino Linotype',
}
sns.set(**seaborn_theme)

In [None]:
# axis limits shared by the plotting cells
duration_range = [0, 25]        # behavior durations (s)
interval_range = [-2, 27]       # interbehavior intervals (s)
rauc_range = [0, 90]            # channel RAUCs (μV·s)
voltage_mean_range = [0, 12]    # mean rectified voltages (μV)
force_peak_range = [0, 250]     # force peaks (mN)
force_mean_range = [0, 100]     # force means (mN)
hist_y_range = [0, 25]          # histogram counts

# one marker and one color per data subset in the scatter plots;
# colors are matplotlib's default color-cycle entries 'C0', 'C1', ...
markers = ['.', '+', 'x', '1', '4', 'o', 'D']
colors = [f'C{k}' for k in range(len(markers))]

In [None]:
# ROWS: ANIMALS, COLS: FOODS
# Each string is a panel label that label2query splits into animal and food;
# None marks an empty panel (no such animal/food bout is defined in feeding_bouts).
plot_layout = np.array([
    ['JG08 Fresh food', 'JG08 Regular nori', 'JG08 Tape nori', 'JG08 Two-ply nori', 'JG08 Tubing'],
    [None,              'JG07 Regular nori', 'JG07 Tape nori', None,                None         ],
    [None,              'JG05 Regular nori', 'JG05 Tape nori', None,                None         ],
])

# the grid plotting cells index plot_layout[i, j], so it must be a 2-D array
# NOTE(review): recent NumPy versions raise on ragged nested lists before this check is reached
assert plot_layout.ndim == 2, 'plot_layout needs to be rectangular (not ragged)'
n_rows, n_cols = plot_layout.shape

# fig size in inches
figsize = (15,8)

In [None]:
# # ROWS: FOODS, COLS: ANIMALS
# plot_layout = np.array([
#     ['JG08 Fresh food', 'JG08 Regular nori', 'JG08 Tape nori', 'JG08 Two-ply nori', 'JG08 Tubing'],
#     [None,              'JG07 Regular nori', 'JG07 Tape nori', None,                None         ],
#     [None,              'JG05 Regular nori', 'JG05 Tape nori', None,                None         ],
# ]).T

# assert plot_layout.ndim == 2, 'plot_layout needs to be rectangular (not ragged)'
# n_rows, n_cols = plot_layout.shape

# # fig size in inches
# figsize = (8,12)

In [None]:
# # ALL ONE ROW
# plot_layout = np.array([
#     ['JG08 Fresh food', 'JG08 Regular nori', 'JG08 Tape nori', 'JG08 Two-ply nori', 'JG08 Tubing', 'JG07 Regular nori', 'JG07 Tape nori', 'JG05 Regular nori', 'JG05 Tape nori'],
# ])

# assert plot_layout.ndim == 2, 'plot_layout needs to be rectangular (not ragged)'
# n_rows, n_cols = plot_layout.shape

# # fig size in inches
# figsize = (30,3)

In [None]:
def sequencePlot(ax, label, query, ylabel, ylim):
    """Plot column `ylabel` against behavior index for each bout selected by `query`.

    Rows come from the global `df_all`. One dotted line is drawn per Bout_index,
    labelled 'Sequence <bout_index>'. The panel is titled `label` and its y-axis
    limits are set to `ylim`.
    """
    selected = df_all.query(query)
    for bout, bout_rows in selected.groupby('Bout_index'):
        flat = bout_rows.reset_index()  # expose Behavior_index as a column
        ax.plot(flat['Behavior_index'], flat[ylabel], label=f'Sequence {bout}', marker='.')
    ax.set_ylim(ylim)
    ax.set_title(label)
    ax.set_xlabel('Behavior index')
    ax.set_ylabel(ylabel)

In [None]:
def freqPlot(ax, label, query, xlabel, bins, n_drop_from_beginning=5):
    """Histogram of column `xlabel` for the rows of the global `df_all` selected by `query`.

    When `n_drop_from_beginning` is positive, a second histogram is overlaid that
    keeps only behaviors with Behavior_index >= n_drop_from_beginning. The y-axis
    limits come from the global `hist_y_range`.
    """
    selected = df_all.query(query)
    ax.hist(selected[xlabel], label='Full sequence', bins=bins)
    if n_drop_from_beginning > 0:
        later_behaviors = selected.query(f'Behavior_index >= {n_drop_from_beginning}')
        ax.hist(later_behaviors[xlabel], label=f'First {n_drop_from_beginning} dropped', bins=bins)
    ax.set_ylim(hist_y_range)
    ax.set_title(label)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Frequency')

In [None]:
def scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim, trend=False, tooltips=False):
    """Scatter-plot two columns of the global `df_all`, one marker/color per data subset.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    data_subsets : dict mapping legend label -> DataFrame.query string (None entries are skipped).
    xlabel, ylabel : df_all column names plotted on x and y; also used as axis labels.
    xlim, ylim : axis limits.
    trend : if True, fit an ordinary least-squares line to all selected points that have
        both coordinates and draw it across `xlim`, with R², the slope's p-value and n
        in its legend label.
    tooltips : if True, show the df_all index of the point under the mouse pointer
        (requires an interactive backend).
    """
    for j, (label, query) in enumerate(data_subsets.items()):
        if query is not None:
            df = df_all.query(query)
            # cycle styles so more subsets than markers/colors do not raise IndexError
            ax.scatter(df[xlabel], df[ylabel],
                       label=label, marker=markers[j % len(markers)], c=colors[j % len(colors)])

    if trend or tooltips:
        # every selected point that has both coordinates
        all_points = df_all.query(query_union(data_subsets.values()))[[xlabel, ylabel]].dropna()

    if trend:
        model = sm.OLS(all_points.iloc[:,1], sm.add_constant(all_points.iloc[:,0])).fit()
        # params/pvalues are labelled Series ('const', then the x column); index them
        # positionally with .iloc because integer-key positional fallback on labelled
        # Series is deprecated in pandas
        intercept, slope = model.params.iloc[0], model.params.iloc[1]
        model_stats = 'R$^2$ = {:.2f}, p = {:.3f}, n = {}'.format(model.rsquared, model.pvalues.iloc[1], len(all_points))
        model_x = np.linspace(xlim[0], xlim[1], 100)
        model_y = intercept + slope * model_x
        ax.plot(model_x, model_y, label=model_stats, color=colors[len(data_subsets) % len(colors)])

    if tooltips:
        # create a transparent layer containing all points for detecting mouse events
        sc = ax.scatter(all_points.iloc[:, 0], all_points.iloc[:, 1], label=None,
                        alpha=0, # transparent points
                        s=0.001, # small radius of tooltip activation
                       )

        # initialize a hidden empty tooltip
        annot = ax.annotate(
            '', xy=(0, 0),                              # initialize empty at origin
            xytext=(0, 15), textcoords='offset points', # position text above point
            color='k', size=10, ha='center',            # small black centered text
            bbox=dict(fc='w', lw=0, alpha=0.6),         # transparent white background
            arrowprops=dict(shrink=0, headlength=7, headwidth=7, width=0, lw=0, color='k'), # small black arrow
        )
        annot.set_visible(False)

        # prepare tooltip contents (one string per row of all_points, in the same order as sc)
        tooltip_text = [' '.join([str(x) for x in index]) for index in list(all_points.index)]

        # bind the animation of tooltips to a hover event
        def hover(event):
            if event.inaxes == ax:
                cont, ind = sc.contains(event)
                if cont:
                    point_index = ind['ind'][0]
                    pos = sc.get_offsets()[point_index]
                    annot.xy = pos
                    annot.set_text(tooltip_text[point_index])
                    annot.set_visible(True)
                    ax.figure.canvas.draw_idle()
                else:
                    if annot.get_visible():
                        annot.set_visible(False)
                        ax.figure.canvas.draw_idle()
        ax.figure.canvas.mpl_connect('motion_notify_event', hover)

    ax.set_xlabel(xlabel)
    ax.set_xlim(xlim)
    ax.set_ylabel(ylabel)
    ax.set_ylim(ylim)

In [None]:
def scatter3d(ax, data_subsets, xlabel, xlim, ylabel, ylim, zlabel, zlim, trend=False, tooltips=False):
    """3-D scatter plot of three columns of the global `df_all`, one marker/color per subset.

    Parameters
    ----------
    ax : a 3-D matplotlib Axes (created with projection='3d').
    data_subsets : dict mapping legend label -> DataFrame.query string (None entries are skipped).
    xlabel, ylabel, zlabel : df_all column names for each axis; also used as axis labels.
    xlim, ylim, zlim : axis limits.
    trend : if True, fit an ordinary least-squares plane z ~ x + y to all selected points
        with all three coordinates, print its statistics and draw the plane.
    tooltips : not implemented for 3-D plots; passing True raises NotImplementedError.
    """
    for j, (label, query) in enumerate(data_subsets.items()):
        if query is not None:
            df = df_all.query(query)
            # cycle styles so more subsets than markers/colors do not raise IndexError
            ax.scatter(df[xlabel], df[ylabel], df[zlabel],
                       label=label, marker=markers[j % len(markers)], c=colors[j % len(colors)])

    if trend or tooltips:
        # every selected point that has all three coordinates
        all_points = df_all.query(query_union(data_subsets.values()))[[xlabel, ylabel, zlabel]].dropna()

    if trend:
        model = sm.OLS(all_points.iloc[:,2], sm.add_constant(all_points.iloc[:,:2])).fit()
        # .iloc for positional access into the labelled params/pvalues Series;
        # NOTE: the reported p is that of the x coefficient only, not the whole model
        model_stats = 'R$^2$ = {:.2f}, p = {:.3f}, n = {}'.format(model.rsquared, model.pvalues.iloc[1], len(all_points))
        print(model_stats)
        x_points, y_points = all_points.iloc[:, 0], all_points.iloc[:, 1]
        model_x, model_y = np.meshgrid(np.linspace(x_points.min(), x_points.max(), 20),
                                       np.linspace(y_points.min(), y_points.max(), 20))
        model_z = model.params.iloc[0] + model.params.iloc[1] * model_x + model.params.iloc[2] * model_y
        ax.plot_surface(model_x, model_y, model_z, cmap=plt.cm.RdBu_r, alpha=0.6, linewidth=0)

    if tooltips:
        raise NotImplementedError('Tooltips coming soon to a 3D plot near you...')

    ax.set_xlabel(xlabel)
    ax.set_xlim(xlim)
    ax.set_ylabel(ylabel)
    ax.set_ylim(ylim)
    ax.set_zlabel(zlabel)
    ax.set_zlim(zlim)

## Sequence of behavior durations

In [None]:
fig = plt.figure(figsize=figsize)

ylabel = 'Duration (s)'
ylim = duration_range

# one panel per labelled cell of plot_layout; None cells stay empty
for (i, j), label in np.ndenumerate(plot_layout):
    if label is None:
        continue
    ax = fig.add_subplot(n_rows, n_cols, i*n_cols + j + 1)
    sequencePlot(ax, label, label2query(label), ylabel, ylim)
    ax.legend()
fig.tight_layout()

## Distribution of behavior durations

In [None]:
# plt.figure(figsize=figsize)

# xlabel = 'Duration (s)'
# bins = np.arange(duration_range[0], duration_range[1])

# for i in range(n_rows):
#     for j in range(n_cols):
#         label = plot_layout[i, j]
#         if label is not None:
#             ax = plt.subplot(n_rows, n_cols, i*n_cols+j+1)
#             freqPlot(ax, label, label2query(label), xlabel, bins)
# #             ax.legend()
# plt.tight_layout()

## Sequence of interbehavior intervals

In [None]:
fig = plt.figure(figsize=figsize)

ylabel = 'Interval after (s)'
ylim = interval_range

# one panel per labelled cell of plot_layout; None cells stay empty
for (i, j), label in np.ndenumerate(plot_layout):
    if label is None:
        continue
    ax = fig.add_subplot(n_rows, n_cols, i*n_cols + j + 1)
    sequencePlot(ax, label, label2query(label), ylabel, ylim)
    ax.legend()
fig.tight_layout()

## Distribution of interbehavior intervals

In [None]:
# plt.figure(figsize=figsize)

# xlabel = 'Interval after (s)'
# bins = np.arange(interval_range[0], interval_range[1])

# for i in range(n_rows):
#     for j in range(n_cols):
#         label = plot_layout[i, j]
#         if label is not None:
#             ax = plt.subplot(n_rows, n_cols, i*n_cols+j+1)
#             freqPlot(ax, label, label2query(label), xlabel, bins)
# #             ax.legend()
# plt.tight_layout()

## Sequences of RAUCs (rectified areas under the curve)

In [None]:
# plt.figure(figsize=figsize)

# ylabel = 'I2 RAUC (μV·s)'
# ylim = rauc_range

# for i in range(n_rows):
#     for j in range(n_cols):
#         label = plot_layout[i, j]
#         if label is not None:
#             ax = plt.subplot(n_rows, n_cols, i*n_cols+j+1)
#             sequencePlot(ax, label, label2query(label), ylabel, ylim)
#             ax.legend()
# plt.tight_layout()

In [None]:
# plt.figure(figsize=figsize)

# ylabel = 'RN RAUC (μV·s)'
# ylim = rauc_range

# for i in range(n_rows):
#     for j in range(n_cols):
#         label = plot_layout[i, j]
#         if label is not None:
#             ax = plt.subplot(n_rows, n_cols, i*n_cols+j+1)
#             sequencePlot(ax, label, label2query(label), ylabel, ylim)
#             ax.legend()
# plt.tight_layout()

In [None]:
fig = plt.figure(figsize=figsize)

ylabel = 'BN2 RAUC (μV·s)'
ylim = rauc_range

# one panel per labelled cell of plot_layout; None cells stay empty
for (i, j), label in np.ndenumerate(plot_layout):
    if label is None:
        continue
    ax = fig.add_subplot(n_rows, n_cols, i*n_cols + j + 1)
    sequencePlot(ax, label, label2query(label), ylabel, ylim)
    ax.legend()
fig.tight_layout()

In [None]:
# plt.figure(figsize=figsize)

# ylabel = 'BN3 RAUC (μV·s)'
# ylim = rauc_range

# for i in range(n_rows):
#     for j in range(n_cols):
#         label = plot_layout[i, j]
#         if label is not None:
#             ax = plt.subplot(n_rows, n_cols, i*n_cols+j+1)
#             sequencePlot(ax, label, label2query(label), ylabel, ylim)
#             ax.legend()
# plt.tight_layout()

In [None]:
# plt.figure(figsize=figsize)

# ylabels = [
#     'Duration (s)',
#     'I2 RAUC (μV·s)',
#     'RN RAUC (μV·s)',
#     'BN2 RAUC (μV·s)',
#     'BN3 RAUC (μV·s)'
# ]
# ylim = [np.min([duration_range, rauc_range]), np.max([duration_range, rauc_range])]

# for i in range(n_rows):
#     for j in range(n_cols):
#         label = plot_layout[i, j]
#         if label is not None:
#             ax = plt.subplot(n_rows, n_cols, i*n_cols+j+1)
#             for ylabel in ylabels:
#                 sequencePlot(ax, label, label2query(label), ylabel, ylim)
#             ax.set_ylabel('Duration or RAUC')
# #             ax.legend()
# plt.tight_layout()

## Sequences of mean rectified voltages

In [None]:
# plt.figure(figsize=figsize)

# ylabel = 'I2 mean rectified voltage (μV)'
# ylim = voltage_mean_range

# for i in range(n_rows):
#     for j in range(n_cols):
#         label = plot_layout[i, j]
#         if label is not None:
#             ax = plt.subplot(n_rows, n_cols, i*n_cols+j+1)
#             sequencePlot(ax, label, label2query(label), ylabel, ylim)
#             ax.legend()
# plt.tight_layout()

In [None]:
# plt.figure(figsize=figsize)

# ylabel = 'RN mean rectified voltage (μV)'
# ylim = voltage_mean_range

# for i in range(n_rows):
#     for j in range(n_cols):
#         label = plot_layout[i, j]
#         if label is not None:
#             ax = plt.subplot(n_rows, n_cols, i*n_cols+j+1)
#             sequencePlot(ax, label, label2query(label), ylabel, ylim)
#             ax.legend()
# plt.tight_layout()

In [None]:
fig = plt.figure(figsize=figsize)

ylabel = 'BN2 mean rectified voltage (μV)'
ylim = voltage_mean_range

# one panel per labelled cell of plot_layout; None cells stay empty
for (i, j), label in np.ndenumerate(plot_layout):
    if label is None:
        continue
    ax = fig.add_subplot(n_rows, n_cols, i*n_cols + j + 1)
    sequencePlot(ax, label, label2query(label), ylabel, ylim)
    ax.legend()
fig.tight_layout()

In [None]:
# plt.figure(figsize=figsize)

# ylabel = 'BN3 mean rectified voltage (μV)'
# ylim = voltage_mean_range

# for i in range(n_rows):
#     for j in range(n_cols):
#         label = plot_layout[i, j]
#         if label is not None:
#             ax = plt.subplot(n_rows, n_cols, i*n_cols+j+1)
#             sequencePlot(ax, label, label2query(label), ylabel, ylim)
#             ax.legend()
# plt.tight_layout()

In [None]:
# plt.figure(figsize=figsize)

# ylabels = [
#     'Duration (s)',
#     'I2 mean rectified voltage (μV)',
#     'RN mean rectified voltage (μV)',
#     'BN2 mean rectified voltage (μV)',
#     'BN3 mean rectified voltage (μV)',
# ]
# ylim = [np.min([duration_range, voltage_mean_range]), np.max([duration_range, voltage_mean_range])]

# for i in range(n_rows):
#     for j in range(n_cols):
#         label = plot_layout[i, j]
#         if label is not None:
#             ax = plt.subplot(n_rows, n_cols, i*n_cols+j+1)
#             for ylabel in ylabels:
#                 sequencePlot(ax, label, label2query(label), ylabel, ylim)
#             ax.set_ylabel('Duration or mean voltage')
# #             ax.legend()
# plt.tight_layout()

## Scatter plots of durations and mean rectified voltages (whole behaviors and bursts)

In [None]:
fig = plt.figure(figsize=(5, 12))

# 'JG08 Tubing' is deliberately left out of these panels
subset_labels = ['JG08 Regular nori', 'JG08 Tape nori', 'JG08 Two-ply nori', 'JG08 Fresh food']
data_subsets = {label: label2query(label) for label in subset_labels}

channels = ['I2', 'RN', 'BN2', 'BN3']

# alternative: whole-behavior duration vs. whole-behavior mean rectified voltage
# xlabels = ['Duration (s)']*4
# xlims = [duration_range]*4
# ylabels = [c+' mean rectified voltage (μV)' if c is not None else None for c in channels]
# ylims = [voltage_mean_range]*4

# burst duration (x) against burst mean rectified voltage (y), one panel per channel
xlabels = [None if c is None else f'{c} burst duration (s)' for c in channels]
xlims = [[0, 10], [0, 20], [0, 15], [0, 15]]
ylabels = [None if c is None else f'{c} burst mean rectified voltage (μV)' for c in channels]
ylims = [[0, 6], [0, 5], [0, 10], [0, 12]]

n_panels = len(ylabels)
for i, (x_col, x_lim, y_col, y_lim) in enumerate(zip(xlabels, xlims, ylabels, ylims)):
    if y_col is None:
        continue
    ax = fig.add_subplot(n_panels, 1, i + 1)
    scatter2d(ax, data_subsets, x_col, x_lim, y_col, y_lim, trend=False, tooltips=False)
    ax.legend()
fig.tight_layout()

In [None]:
# plt.figure(figsize=(5, 5))
# ax = plt.gca()

# data_subsets = ['JG08 Regular nori', 'JG08 Tape nori', 'JG08 Two-ply nori', 'JG08 Fresh food']
# # data_subsets = ['JG08 Regular nori', 'JG08 Tape nori', 'JG08 Two-ply nori']
# data_subsets = {label:label2query(label) for label in data_subsets}

# xlabel, xlim = 'BN2 burst mean rectified voltage (μV)', [0,10]
# ylabel, ylim = 'BN2 burst duration (s)', [0,7]

# scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim, trend=False, tooltips=False)
# ax.legend()
# sns.despine(ax=ax, offset=20, trim=True)
# plt.tight_layout()

In [None]:
fig = plt.figure(figsize=(5, 5))
# plt.gca(projection='3d') fails on current Matplotlib (gca() stopped accepting keyword
# arguments in 3.6); add_subplot accepts the '3d' projection from mpl_toolkits.mplot3d,
# which is imported at the top of the notebook
ax = fig.add_subplot(111, projection='3d')

data_subsets = ['JG08 Regular nori', 'JG08 Tape nori', 'JG08 Two-ply nori', 'JG08 Fresh food']#, 'JG08 Tubing']
data_subsets = {label:label2query(label) for label in data_subsets}

# alternative axes: behavior duration and whole-behavior mean rectified voltages
# xlabel, xlim = 'Duration (s)', duration_range
# ylabel, ylim = 'BN2 mean rectified voltage (μV)', voltage_mean_range
# zlabel, zlim = 'BN3 mean rectified voltage (μV)', voltage_mean_range

xlabel, xlim = 'RN burst mean rectified voltage (μV)', voltage_mean_range
ylabel, ylim = 'BN2 burst mean rectified voltage (μV)', voltage_mean_range
zlabel, zlim = 'BN3 burst mean rectified voltage (μV)', voltage_mean_range

scatter3d(ax, data_subsets, xlabel, xlim, ylabel, ylim, zlabel, zlim, trend=False, tooltips=False)
ax.legend()
plt.tight_layout()

In [None]:
fig = plt.figure(figsize=(5, 12))

subset_labels = ['JG07 Regular nori', 'JG07 Tape nori']
data_subsets = {label: label2query(label) for label in subset_labels}

channels = ['I2', 'RN', 'BN2', 'BN3']

# behavior duration (x) against each channel's whole-behavior mean rectified voltage (y)
xlabels = ['Duration (s)'] * len(channels)
xlims = [duration_range] * len(channels)
ylabels = [None if c is None else f'{c} mean rectified voltage (μV)' for c in channels]
ylims = [voltage_mean_range] * len(channels)

n_panels = len(ylabels)
for i, (x_col, x_lim, y_col, y_lim) in enumerate(zip(xlabels, xlims, ylabels, ylims)):
    if y_col is None:
        continue
    ax = fig.add_subplot(n_panels, 1, i + 1)
    scatter2d(ax, data_subsets, x_col, x_lim, y_col, y_lim, trend=False, tooltips=False)
    ax.legend()
fig.tight_layout()

In [None]:
# 3D scatter of Duration vs. BN2 and BN3 mean rectified voltage for the JG07 subsets.
# plt.gca(projection='3d') was deprecated in Matplotlib 3.4 and removed in 3.6,
# so the 3D axes are created explicitly on a new figure instead.
ax = plt.figure(figsize=(5, 5)).add_subplot(projection='3d')

data_subsets = ['JG07 Regular nori', 'JG07 Tape nori']
# Map each subset label to its selection query (label2query is defined elsewhere in the notebook).
data_subsets = {label:label2query(label) for label in data_subsets}

xlabel, xlim = 'Duration (s)', duration_range
ylabel, ylim = 'BN2 mean rectified voltage (μV)', voltage_mean_range
zlabel, zlim = 'BN3 mean rectified voltage (μV)', voltage_mean_range

scatter3d(ax, data_subsets, xlabel, xlim, ylabel, ylim, zlabel, zlim, trend=False, tooltips=False)
ax.legend()
plt.tight_layout()

In [None]:
# Stacked 2D scatters, one row per channel: Duration (x) vs. that channel's
# mean rectified voltage (y), for the JG05 regular- vs. tape-nori subsets.
plt.figure(figsize=(5, 12))

data_subsets = {label: label2query(label) for label in ['JG05 Regular nori', 'JG05 Tape nori']}

# Subplot row order follows this list; the None slot (no RN panel here) leaves row 2 blank.
channels = ['I2', None, 'BN2', 'BN3']

xlabels = ['Duration (s)'] * len(channels)
xlims = [duration_range] * len(channels)
ylabels = [None if c is None else c + ' mean rectified voltage (μV)' for c in channels]
ylims = [voltage_mean_range] * len(channels)

for row, row_ylabel in enumerate(ylabels):
    if row_ylabel is None:
        continue  # no channel for this slot: skip the panel
    ax = plt.subplot(len(ylabels), 1, row + 1)
    scatter2d(ax, data_subsets, xlabels[row], xlims[row], row_ylabel, ylims[row],
              trend=False, tooltips=False)
    ax.legend()
plt.tight_layout()

In [None]:
# 3D scatter of Duration vs. BN2 and BN3 mean rectified voltage for the JG05 subsets.
# plt.gca(projection='3d') was deprecated in Matplotlib 3.4 and removed in 3.6,
# so the 3D axes are created explicitly on a new figure instead.
ax = plt.figure(figsize=(5, 5)).add_subplot(projection='3d')

data_subsets = ['JG05 Regular nori', 'JG05 Tape nori']
# Map each subset label to its selection query (label2query is defined elsewhere in the notebook).
data_subsets = {label:label2query(label) for label in data_subsets}

xlabel, xlim = 'Duration (s)', duration_range
ylabel, ylim = 'BN2 mean rectified voltage (μV)', voltage_mean_range
zlabel, zlim = 'BN3 mean rectified voltage (μV)', voltage_mean_range

scatter3d(ax, data_subsets, xlabel, xlim, ylabel, ylim, zlabel, zlim, trend=False, tooltips=False)
ax.legend()
plt.tight_layout()

## Force plots

In [None]:
# 3D scatter of Duration, Force peak and Force mean for the JG07 and JG08 tape-nori subsets.
# plt.gca(projection='3d') was deprecated in Matplotlib 3.4 and removed in 3.6,
# so the 3D axes are created explicitly on a new figure instead.
ax = plt.figure(figsize=(5, 5)).add_subplot(projection='3d')

data_subsets = ['JG07 Tape nori', 'JG08 Tape nori']
# Map each subset label to its selection query (label2query is defined elsewhere in the notebook).
data_subsets = {label:label2query(label) for label in data_subsets}

# Fixed limits override the shared force ranges noted after each '#'.
xlabel, xlim = 'Duration (s)', duration_range
ylabel, ylim = 'Force peak (mN)', [0,500]#force_peak_range
zlabel, zlim = 'Force mean (mN)', [0,200]#force_mean_range

# Alternative axes (disabled):
# xlabel, xlim = 'Duration (s)', duration_range
# ylabel, ylim = 'BN2 mean rectified voltage (μV)', voltage_mean_range
# zlabel, zlim = 'Force mean (mN)', force_mean_range

# xlabel, xlim = 'Duration (s)', duration_range
# ylabel, ylim = 'BN2 mean rectified voltage (μV)', voltage_mean_range
# zlabel, zlim = 'Force peak (mN)', force_peak_range

scatter3d(ax, data_subsets, xlabel, xlim, ylabel, ylim, zlabel, zlim, trend=False, tooltips=False)
ax.legend()
plt.tight_layout()

In [None]:
# 3D scatter of BN2 burst duration and voltage vs. a force measure following the
# BN2 burst, for the JG07 and JG08 tape-nori subsets.
# plt.gca(projection='3d') was deprecated in Matplotlib 3.4 and removed in 3.6,
# so the 3D axes are created explicitly on a new figure instead.
ax = plt.figure(figsize=(5, 5)).add_subplot(projection='3d')

data_subsets = ['JG07 Tape nori', 'JG08 Tape nori']
# Map each subset label to its selection query (label2query is defined elsewhere in the notebook).
data_subsets = {label:label2query(label) for label in data_subsets}

xlabel, xlim = 'BN2 burst duration (s)', [0,8]
ylabel, ylim = 'BN2 burst mean rectified voltage (μV)', [0,6]
# z-axis choice: exactly one of the following lines is active.
# zlabel, zlim = 'Force mean (mN)', force_mean_range
# zlabel, zlim = 'Force peak (mN)', force_peak_range
# zlabel, zlim = 'Force mean following BN2 burst (mN)', [0,150]
# zlabel, zlim = 'Force peak following BN2 burst (mN)', [0,250]
zlabel, zlim = 'Force increase following BN2 burst (mN)', [0,300]
# zlabel, zlim = 'Force slope mean following BN2 burst (mN/s)', [0,50]
# zlabel, zlim = 'Force slope peak following BN2 burst (mN/s)', [0,700]

scatter3d(ax, data_subsets, xlabel, xlim, ylabel, ylim, zlabel, zlim, trend=False, tooltips=False)
ax.legend()
plt.tight_layout()

In [None]:
# 3D scatter (with trend) of BN2 burst duration and voltage vs. the force slope
# during the initial rise, for the JG07 and JG08 tape-nori subsets.
# plt.gca(projection='3d') was deprecated in Matplotlib 3.4 and removed in 3.6,
# so the 3D axes are created explicitly on a new figure instead.
ax = plt.figure(figsize=(5, 5)).add_subplot(projection='3d')

data_subsets = ['JG07 Tape nori', 'JG08 Tape nori']
# Map each subset label to its selection query (label2query is defined elsewhere in the notebook).
data_subsets = {label:label2query(label) for label in data_subsets}
# label2query yields a query string, so an extra condition is appended to drop the
# JG08 point with Bout_index == 0 and Behavior_index == 0.
data_subsets['JG08 Tape nori'] += ' & not ((Bout_index == 0) & (Behavior_index == 0))' # remove outlier

xlabel, xlim = 'BN2 burst duration (s)', [0,4]#[0,8]
ylabel, ylim = 'BN2 burst mean rectified voltage (μV)', [1.5,4.5]#[0,6]
# Alternative y axes (disabled):
# ylabel, ylim = 'BN3 burst mean rectified voltage (μV)', [0,10]
# ylabel, ylim = 'RN burst mean rectified voltage (μV)', [1,3]
# ylabel, ylim = 'RN burst duration (s)', [0,10]
# ylabel, ylim = 'BN2 burst mean rectified voltage (μV)', [1.5,4.5]#[0,6]
# z-axis choice: exactly one of the following lines is active.
# zlabel, zlim = 'Force duration (s)', [0,10]
# zlabel, zlim = 'Force large hump duration following BN2 burst (s)', [3,7]
# zlabel, zlim = 'Force mean (mN)', force_mean_range
# zlabel, zlim = 'Force peak (mN)', force_peak_range
# zlabel, zlim = 'Force mean following BN2 burst (mN)', [0,150]
# zlabel, zlim = 'Force peak following BN2 burst (mN)', [0,250]
# zlabel, zlim = 'Force increase following BN2 burst (mN)', [0,300]
# zlabel, zlim = 'Force slope mean following BN2 burst (mN/s)', [0,50]
# zlabel, zlim = 'Force slope peak following BN2 burst (mN/s)', [0,700]
zlabel, zlim = 'Force slope mean during initial rise (mN/s)', [0,120]
# zlabel, zlim = 'RN burst mean rectified voltage (μV)', [1,3]

scatter3d(ax, data_subsets, xlabel, xlim, ylabel, ylim, zlabel, zlim, trend=True, tooltips=False)
ax.legend()
plt.tight_layout()

In [None]:
# 3D scatter (with trend) of BN2 burst duration and voltage vs. the duration of the
# large force hump following the BN2 burst, for JG08 tape nori only.
# plt.gca(projection='3d') was deprecated in Matplotlib 3.4 and removed in 3.6,
# so the 3D axes are created explicitly on a new figure instead.
ax = plt.figure(figsize=(5, 5)).add_subplot(projection='3d')

# data_subsets = ['JG07 Tape nori', 'JG08 Tape nori']
data_subsets = ['JG08 Tape nori']
# Map each subset label to its selection query (label2query is defined elsewhere in the notebook).
data_subsets = {label:label2query(label) for label in data_subsets}
# label2query yields a query string, so an extra condition is appended to drop the
# JG08 point with Bout_index == 0 and Behavior_index == 0.
data_subsets['JG08 Tape nori'] += ' & not ((Bout_index == 0) & (Behavior_index == 0))' # remove outlier

xlabel, xlim = 'BN2 burst duration (s)', [2,4]#[0,4]#[0,8]
ylabel, ylim = 'BN2 burst mean rectified voltage (μV)', [1.5,4.5]#[0,6]
# Alternative y axes (disabled):
# ylabel, ylim = 'BN3 burst mean rectified voltage (μV)', [0,10]
# ylabel, ylim = 'RN burst mean rectified voltage (μV)', [1,3]
# ylabel, ylim = 'RN burst duration (s)', [0,10]
# ylabel, ylim = 'BN2 burst mean rectified voltage (μV)', [1.5,4.5]#[0,6]
# z-axis choice: exactly one of the following lines is active.
# zlabel, zlim = 'Force duration (s)', [0,10]
zlabel, zlim = 'Force large hump duration following BN2 burst (s)', [3,7]
# zlabel, zlim = 'Force mean (mN)', force_mean_range
# zlabel, zlim = 'Force peak (mN)', force_peak_range
# zlabel, zlim = 'Force mean following BN2 burst (mN)', [0,150]
# zlabel, zlim = 'Force peak following BN2 burst (mN)', [0,250]
# zlabel, zlim = 'Force increase following BN2 burst (mN)', [0,200]
# zlabel, zlim = 'Force slope mean following BN2 burst (mN/s)', [0,50]
# zlabel, zlim = 'Force slope peak following BN2 burst (mN/s)', [0,700]
# zlabel, zlim = 'Force slope mean during initial rise (mN/s)', [0,120]
# zlabel, zlim = 'RN burst mean rectified voltage (μV)', [1,3]

scatter3d(ax, data_subsets, xlabel, xlim, ylabel, ylim, zlabel, zlim, trend=True, tooltips=False)
ax.legend()
plt.tight_layout()

In [None]:
# 2D scatter (with trend) of BN2 burst mean rectified voltage vs. a force measure
# following the BN2 burst, for the JG07 and JG08 tape-nori subsets.
ax = plt.figure(figsize=(5, 5)).add_subplot()

data_subsets = {label: label2query(label) for label in ['JG07 Tape nori', 'JG08 Tape nori']}

# x-axis choice (the duration variant is disabled):
# xlabel, xlim = 'BN2 burst duration (s)', [0,8]
xlabel, xlim = 'BN2 burst mean rectified voltage (μV)', [1.5,5]#[0,6]

# y-axis choice: exactly one of the following lines is active.
# ylabel, ylim = 'Force mean (mN)', force_mean_range
# ylabel, ylim = 'Force peak (mN)', force_peak_range
# ylabel, ylim = 'Force mean following BN2 burst (mN)', [0,150]
# ylabel, ylim = 'Force peak following BN2 burst (mN)', [0,250]
ylabel, ylim = 'Force increase following BN2 burst (mN)', [0,300]
# ylabel, ylim = 'Force slope mean following BN2 burst (mN/s)', [0,50]
# ylabel, ylim = 'Force slope peak following BN2 burst (mN/s)', [0,700]
# ylabel, ylim = 'Force slope mean during initial rise (mN/s)', [0,120]

scatter2d(ax, data_subsets, xlabel, xlim, ylabel, ylim,
          trend=True, tooltips=False)
ax.legend()
# Offset the spines and trim them to the tick range.
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()