# Import Packages

In [None]:
# add the directory containing modules to the path
import sys
sys.path.append('../modules')

In [None]:
################################################################################
# NUMPY
# conda install numpy

import numpy as np

################################################################################
# SCIPY
# conda install scipy

# import scipy as sp

################################################################################
# MATPLOTLIB
# conda install matplotlib

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# from matplotlib.ticker import MultipleLocator

################################################################################
# SEABORN
# conda install seaborn

import seaborn as sns

################################################################################
# NEO
# pip install "neo>=0.7.1"   (quote the requirement so the shell does not treat >= as a redirect)
# - AxoGraph support requires axographio to be installed: pip install axographio

# import neo

################################################################################
# QUANTITIES
# conda install quantities

import quantities as pq
pq.markup.config.use_unicode = True  # allow symbols like mu for micro in output
pq.mN = pq.UnitQuantity('millinewton', pq.N/1e3, symbol = 'mN');  # define millinewton

################################################################################
# ELEPHANT
# pip install "elephant>=0.6.2"   (quote the requirement so the shell does not treat >= as a redirect)

import elephant

################################################################################
# PANDAS
# conda install pandas

# import pandas as pd

################################################################################
# STATSMODELS
# conda install statsmodels

import statsmodels.api as sm

################################################################################
# SPM1D - One-Dimensional Statistical Parametric Mapping
# pip install spm1d

# import spm1d

################################################################################
# EPHYVIEWER
# pip install git+https://github.com/jpgill86/ephyviewer.git@experimental
# - requires PyAV: conda install -c conda-forge av

# import ephyviewer

################################################################################
# ParseMetadata
# - requires ipywidgets: conda install ipywidgets
# - requires yaml:       conda install pyyaml

from ParseMetadata import LoadMetadata

################################################################################
# ImportData

from ImportData import LoadAndPrepareData

################################################################################
# NeoUtilities
# - requires pylttb: pip install pylttb

from NeoUtilities import NeoEpochToDataFrame#, CausalAlphaKernel, DownsampleNeoSignal

################################################################################
# EphyviewerConfigurator

# from EphyviewerConfigurator import EphyviewerConfigurator

################################################################################
# NeoToEphyviewerBridge

# from NeoToEphyviewerBridge import NeoSegmentToEphyviewerSources, PlotExampleWithEphyviewer

# IPython Magics

In [None]:
# make figures interactive and open in a separate window
# %matplotlib qt

# make figures interactive and inline
%matplotlib notebook

# make figures non-interactive and inline
# %matplotlib inline

# Data Parameters

In [None]:
# Each entry describes one feeding condition as a 5-tuple. The data-processing
# loop unpacks the fields positionally, in this order:
#   (feeding_condition, data_set_name, channels_to_keep,
#    time_windows_to_keep, epoch_types_to_keep)
# - feeding_condition:    label used as the key into `data` and in plot_layout
# - data_set_name:        key into the metadata loaded with LoadMetadata
# - channels_to_keep:     names of the signals to quantify for this condition
# - time_windows_to_keep: list of [start, end] windows in seconds; an epoch is
#                         kept only if it lies entirely inside one of them
# - epoch_types_to_keep:  epoch 'Type' values that count as behaviors
feeding_condition_parameters = [
    ('JG05 Regular nori', 'IN VIVO / JG05 / 2018-03-05 / 001', ['I2',       'BN2', 'BN3'], [[ 550,  594]], ['Swallow (regular 5-cm nori strip)']),
    ('JG05 Tape nori',    'IN VIVO / JG05 / 2018-03-05 / 001', ['I2',       'BN2', 'BN3'], [[ 745,  827]], ['Swallow (tape nori)']),
    ('JG07 Regular nori', 'IN VIVO / JG07 / 2018-05-20 / 002', ['I2', 'RN', 'BN2', 'BN3'], [[1496, 1527]], ['Swallow (regular 5-cm nori strip)']),
    ('JG07 Tape nori',    'IN VIVO / JG07 / 2018-05-20 / 002', ['I2', 'RN', 'BN2', 'BN3'], [[1036, 1093]], ['Swallow (tape nori)']),
    ('JG08 Fresh food',   'IN VIVO / JG08 / 2018-06-25 / 001', ['I2', 'RN', 'BN2', 'BN3'], [[1170, 1590]], ['Swallow (fresh food)']),
    ('JG08 Regular nori', 'IN VIVO / JG08 / 2018-06-21 / 001', ['I2', 'RN', 'BN2', 'BN3'], [[2462, 2532]], ['Swallow (regular 5-cm nori strip)']),
    ('JG08 Tape nori',    'IN VIVO / JG08 / 2018-06-21 / 002', ['I2', 'RN', 'BN2', 'BN3'], [[ 134,  205]], ['Swallow (tape nori)']),
    ('JG08 Tubing',       'IN VIVO / JG08 / 2018-06-25 / 001', ['I2', 'RN', 'BN2', 'BN3'], [[4026, 4146]], ['Swallow (tubing)', 'No movement (tubing)', 'Reposition (tubing)', 'Rejection (tubing)']),
    ('JG08 Two-ply nori', 'IN VIVO / JG08 / 2018-06-25 / 001', ['I2', 'RN', 'BN2', 'BN3'], [[3256, 3493]], ['Swallow (two-ply nori)']),
]

# Import and Process the Data

In [None]:
# load the metadata containing file paths
all_metadata = LoadMetadata(file='../metadata.yml', local_data_root='../../data')

# suffixes of the column names that hold voltage measurements; the backslash is
# doubled so that Python does not see the invalid escape sequence '\c' (the
# resulting column names are exactly the same as before)
rauc_suffix = ' RAUC (μV$\\cdot$s)'
mean_voltage_suffix = ' mean rectified voltage (μV)'

# a burst may start or end this far outside its behavior and still be matched to it
timing_tolerance = 0.5 # seconds

# the force signal is examined over the BN2 burst window shifted later by this lag
slow_muscle_lag = 0.2 # seconds

# length of the window, starting at the lagged BN2 burst start, over which the
# initial-rise force slope is averaged
force_rise_time = 1.5 # seconds

# per-channel burst columns (each is prefixed with '<channel> burst')
burst_column_suffixes = [
    ' start (s)',
    ' end (s)',
    ' duration (s)',
    rauc_suffix,
    mean_voltage_suffix,
]

# force columns, in the order in which they are appended to the table
force_columns = [
    'Force start (s)',
    'Force end (s)',
    'Force duration (s)',
    'Force RAUC (mN*s)',
    'Force mean (mN)',
    'Force peak (mN)',
    'Force RAUC following BN2 burst (mN*s)',
    'Force mean following BN2 burst (mN)',
    'Force peak following BN2 burst (mN)',
    'Force increase following BN2 burst (mN)',
    'Force slope mean following BN2 burst (mN/s)',
    'Force slope peak following BN2 burst (mN/s)',
    'Force slope mean during initial rise (mN/s)',
    'Force large hump start (s)',
    'Force large hump end (s)',
    'Force large hump duration (s)',
    'Force large hump duration following BN2 burst (s)',
]

# filter epochs for each feeding condition and perform calculations
data = {}
for feeding_condition, data_set_name, channels_to_keep, time_windows_to_keep, epoch_types_to_keep in feeding_condition_parameters:
    data[feeding_condition] = {}
    data[feeding_condition]['data_set_name'] = data_set_name
    data[feeding_condition]['channels_to_keep'] = channels_to_keep
    data[feeding_condition]['time_windows_to_keep'] = time_windows_to_keep
    data[feeding_condition]['epoch_types_to_keep'] = epoch_types_to_keep

    blk = LoadAndPrepareData(all_metadata[data_set_name])
    segment = blk.segments[0]

    # convert the epochs to a table once per data set; the behavior table and
    # every burst/force/large-hump lookup below are filtered views of it
    # NOTE(review): this assumes NeoEpochToDataFrame returns the same table on
    # every call for the same epochs (previously it was re-run for every lookup)
    all_epochs = NeoEpochToDataFrame(segment.epochs)

    # look up the analog signals once per data set; channel names are matched
    # after removing '-L' from the signal name
    channel_signals = {
        channel: next((s for s in segment.analogsignals if s.name.replace('-L','')==channel), None)
        for channel in channels_to_keep
    }
    force_signal = next((s for s in segment.analogsignals if s.name=='Force'), None)

    # filter by time window (boolean indexing returns a new table, so all_epochs is not modified below)
    df = all_epochs
    df = df[np.any([(t[0] <= df['Start (s)']) & (df['End (s)'] <= t[1]) for t in time_windows_to_keep], axis=0)]

    # filter by epoch type
    df = df[np.any([df['Type'] == epoch_type for epoch_type in epoch_types_to_keep], axis=0)]

    # renumber behaviors assuming all behaviors are from a single contiguous sequence
    df = df.reset_index(drop=True)

    # calculate interbehavior interval assuming all behaviors are from a single contiguous sequence;
    # the first behavior has no interval before it and the last has none after it (both stay NaN)
    df.insert(column='Interval before (s)', loc=3, value=np.nan)
    df.insert(column='Interval after (s)',  loc=4, value=np.nan)
    for previous_i, i in zip(df.index[:-1], df.index[1:]):
        interval = df.loc[i, 'Start (s)'] - df.loc[previous_i, 'End (s)']
        df.loc[i,          'Interval before (s)'] = interval
        df.loc[previous_i, 'Interval after (s)']  = interval

    # find rectified area under the curve (RAUC) for each channel in each epoch
    for i in df.index:
        behavior_start = df.loc[i, 'Start (s)']
        behavior_end = df.loc[i, 'End (s)']
        behavior_duration = df.loc[i, 'Duration (s)']
        for j in channels_to_keep:
            sig = channel_signals[j]
            assert sig is not None, 'No analog signal found for channel "{}" in data set "{}"'.format(j, data_set_name)
            sig = sig.time_slice(behavior_start*pq.s, behavior_end*pq.s)
            rauc = elephant.signal_processing.rauc(sig, baseline='mean').rescale('uV*s')
            df.loc[i, j+rauc_suffix] = rauc
            df.loc[i, j+mean_voltage_suffix] = rauc/(behavior_duration * pq.s)

    # add defaults for new columns
    for j in channels_to_keep:
        burst_type = j+' burst'
        for suffix in burst_column_suffixes:
            df.insert(column=burst_type+suffix, loc=len(df.columns), value=np.nan)
    for column in force_columns:
        df.insert(column=column, loc=len(df.columns), value=np.nan)

    # for each behavior
    for i in df.index:
        behavior_start = df.loc[i, 'Start (s)']
        behavior_end = df.loc[i, 'End (s)']

        # for each nerve/muscle signal
        for j in channels_to_keep:
            burst_type = j+' burst'

            # find a burst epoch on the nerve/muscle that is at least mostly
            # contained within this behavior, allowing for a small discrepancy
            # in start and end time if the burst overextends in either direction
            df2 = all_epochs[
                (behavior_start-timing_tolerance <= all_epochs['Start (s)']) &
                (all_epochs['End (s)'] <= behavior_end+timing_tolerance) &
                (all_epochs['Type'] == burst_type)
            ]
            assert len(df2) < 2, 'More than one epoch found with type \"{}\" for the behavior spanning [{}, {}]'.format(burst_type, behavior_start, behavior_end)

            # quantify the burst if it exists
            if len(df2) == 1:
                burst_epoch = df2.iloc[0]
                burst_start = burst_epoch['Start (s)']
                burst_end = burst_epoch['End (s)']
                burst_duration = burst_epoch['Duration (s)']

                df.loc[i, burst_type+' start (s)'] = burst_start
                df.loc[i, burst_type+' end (s)'] = burst_end
                df.loc[i, burst_type+' duration (s)'] = burst_duration

                # the baseline (mean) is taken over the behavior widened by the
                # tolerance, while the RAUC itself is integrated over the burst only
                sig = channel_signals[j]
                assert sig is not None, 'No analog signal found for channel "{}" in data set "{}"'.format(j, data_set_name)
                sig = sig.time_slice((behavior_start-timing_tolerance)*pq.s, (behavior_end+timing_tolerance)*pq.s)
                rauc = elephant.signal_processing.rauc(sig, baseline='mean', t_start=burst_start*pq.s, t_stop=burst_end*pq.s).rescale('uV*s')
                df.loc[i, burst_type+rauc_suffix] = rauc
                df.loc[i, burst_type+mean_voltage_suffix] = rauc/(burst_duration * pq.s)

        # find a force epoch that begins within this behavior
        df3 = all_epochs[(behavior_start <= all_epochs['Start (s)']) & (all_epochs['Start (s)'] <= behavior_end) & (all_epochs['Type'] == 'force')]
        assert len(df3) < 2, 'More than one force epoch begins within the behavior spanning [{}, {}]'.format(behavior_start, behavior_end)

        # find a large hump epoch that begins within this behavior
        df4 = all_epochs[(behavior_start <= all_epochs['Start (s)']) & (all_epochs['Start (s)'] <= behavior_end) & (all_epochs['Type'] == 'large hump')]
        assert len(df4) < 2, 'More than one large hump epoch begins within the behavior spanning [{}, {}]'.format(behavior_start, behavior_end)

        # quantify the force if it exists
        if len(df3) == 1:
            force_epoch = df3.iloc[0]
            force_start = force_epoch['Start (s)']
            force_end = force_epoch['End (s)']
            force_duration = force_epoch['Duration (s)']

            df.loc[i, 'Force start (s)'] = force_start
            df.loc[i, 'Force end (s)'] = force_end
            df.loc[i, 'Force duration (s)'] = force_duration

            assert force_signal is not None, 'No analog signal named "Force" found in data set "{}"'.format(data_set_name)
            sig = force_signal.time_slice(force_start*pq.s, force_end*pq.s)
            rauc = elephant.signal_processing.rauc(sig).rescale('mN*s')
            df.loc[i, 'Force RAUC (mN*s)'] = rauc
            df.loc[i, 'Force mean (mN)'] = sig.rescale('mN').mean()
            df.loc[i, 'Force peak (mN)'] = sig.rescale('mN').max()

            # if a BN2 burst also exists, quantify the force following the burst
            # (the BN2 burst columns only exist when BN2 is one of the kept channels)
            if 'BN2' in channels_to_keep and np.isfinite(df.loc[i, 'BN2 burst start (s)']) and np.isfinite(df.loc[i, 'BN2 burst end (s)']):
                bn2_burst_start = df.loc[i, 'BN2 burst start (s)']
                bn2_burst_end = df.loc[i, 'BN2 burst end (s)']

                # examine the force over the BN2 burst window shifted later by the muscle lag
                sig = force_signal.time_slice((bn2_burst_start+slow_muscle_lag)*pq.s, (bn2_burst_end+slow_muscle_lag)*pq.s)
                rauc = elephant.signal_processing.rauc(sig).rescale('mN*s')
                slope = elephant.signal_processing.derivative(sig).rescale('mN/s')
                df.loc[i, 'Force RAUC following BN2 burst (mN*s)'] = rauc
                df.loc[i, 'Force mean following BN2 burst (mN)'] = sig.rescale('mN').mean()
                df.loc[i, 'Force peak following BN2 burst (mN)'] = sig.rescale('mN').max()
                df.loc[i, 'Force increase following BN2 burst (mN)'] = sig.rescale('mN').max()-sig.rescale('mN')[0]
                df.loc[i, 'Force slope mean following BN2 burst (mN/s)'] = slope.mean()
                df.loc[i, 'Force slope peak following BN2 burst (mN/s)'] = slope.max()
                # clamp the initial-rise window to the extent of the slope signal
                df.loc[i, 'Force slope mean during initial rise (mN/s)'] = slope.time_slice(
                    max(slope.t_start, (bn2_burst_start+slow_muscle_lag)*pq.s),
                    min(slope.t_stop,  (bn2_burst_start+slow_muscle_lag+force_rise_time)*pq.s)
                ).mean()

                # quantify the large hump if it exists (this is only reached when
                # both a force epoch and a BN2 burst were found for the behavior)
                if len(df4) == 1:
                    large_hump_epoch = df4.iloc[0]
                    large_hump_start = large_hump_epoch['Start (s)']
                    large_hump_end = large_hump_epoch['End (s)']
                    large_hump_duration = large_hump_epoch['Duration (s)']

                    df.loc[i, 'Force large hump start (s)'] = large_hump_start
                    df.loc[i, 'Force large hump end (s)'] = large_hump_end
                    df.loc[i, 'Force large hump duration (s)'] = large_hump_duration

                    df.loc[i, 'Force large hump duration following BN2 burst (s)'] = large_hump_end-(bn2_burst_start+slow_muscle_lag)

    data[feeding_condition]['annotations'] = df

In [None]:
# check for errors: for each kind of sub-epoch, report per feeding condition
# whether every behavior has one or only some of them do
for subepoch_type in ['Force', 'Force large hump', 'I2 burst', 'RN burst', 'BN2 burst', 'BN3 burst']:
    print(f'-- {subepoch_type} --')
    start_column = f'{subepoch_type} start (s)'
    for feeding_condition in data:
        df = data[feeding_condition]['annotations']

        # burst columns exist only for the channels kept for this condition;
        # skip a missing column explicitly instead of hiding every possible
        # error behind a bare except
        if start_column not in df:
            continue

        # a finite start time means the sub-epoch was found for that behavior
        has_subepoch = np.isfinite(df[start_column])
        if np.all(has_subepoch):
            print(f'{feeding_condition}: all behaviors have {subepoch_type}')
        elif np.any(has_subepoch):
            print(f'{feeding_condition}: some behaviors have {subepoch_type} **BUT NOT ALL**')
#         else:
#             print(f'{feeding_condition}: no behaviors have {subepoch_type}')
    print()

# Plots

In [None]:
# color map (alternatives tried: plt.cm.brg, plt.cm.RdBu)
cm = plt.cm.cool

# seaborn theme shared by all figures below
# (context='poster' was tried previously and is left disabled)
sns.set(style='ticks', font_scale=1, font='Palatino Linotype')

In [None]:
# axis limits shared across figures, each given as [lower, upper]
duration_range     = [0, 25]    # used for 'Duration (s)'
interval_range     = [-2, 27]   # used for 'Interval after (s)'
rauc_range         = [0, 90]    # used for the RAUC columns
voltage_mean_range = [0, 12]    # used for the mean rectified voltage columns
force_peak_range   = [0, 250]
force_mean_range   = [0, 100]
hist_y_range       = [0, 25]    # y-limits of the (currently disabled) histograms

# marker and color for the k-th series of a scatter plot
markers = ['.', '+', 'x', '1', '4', 'o', 'D']
colors = ['C{}'.format(k) for k in range(7)]

In [None]:
# ROWS: ANIMALS, COLS: FOODS
# each cell names the feeding condition drawn in that subplot position;
# None leaves the position empty
plot_layout = np.array([
    ['JG08 Fresh food', 'JG08 Regular nori', 'JG08 Tape nori', 'JG08 Two-ply nori', 'JG08 Tubing'],
    [None,              'JG07 Regular nori', 'JG07 Tape nori', None,                None         ],
    [None,              'JG05 Regular nori', 'JG05 Tape nori', None,                None         ],
])

# grid dimensions, with a sanity check that every row has the same length
n_rows = len(plot_layout)
n_cols = len(plot_layout[0])
assert all(len(row) == n_cols for row in plot_layout), 'plot_layout needs to be rectangular (not ragged)'

# fig size in inches
figsize = (15,8)

In [None]:
# # ROWS: FOODS, COLS: ANIMALS
# plot_layout = np.array([
#     ['JG08 Fresh food', 'JG08 Regular nori', 'JG08 Tape nori', 'JG08 Two-ply nori', 'JG08 Tubing'],
#     [None,              'JG07 Regular nori', 'JG07 Tape nori', None,                None         ],
#     [None,              'JG05 Regular nori', 'JG05 Tape nori', None,                None         ],
# ]).T

# n_rows = len(plot_layout)
# n_cols = len(plot_layout[0])
# for row in plot_layout:
#     assert len(row) == n_cols, 'plot_layout needs to be rectangular (not ragged)'

# # fig size in inches
# figsize = (8,12)

In [None]:
# # ALL ONE ROW
# plot_layout = np.array([
#     ['JG08 Fresh food', 'JG08 Regular nori', 'JG08 Tape nori', 'JG08 Two-ply nori', 'JG08 Tubing', 'JG07 Regular nori', 'JG07 Tape nori', 'JG05 Regular nori', 'JG05 Tape nori'],
# ])

# n_rows = len(plot_layout)
# n_cols = len(plot_layout[0])
# for row in plot_layout:
#     assert len(row) == n_cols, 'plot_layout needs to be rectangular (not ragged)'

# # fig size in inches
# figsize = (30,3)

## Sequence of behavior durations

In [None]:
# behavior duration against position in the sequence, one subplot per feeding condition
ylabel = 'Duration (s)'
ylim = duration_range

plt.figure(figsize=figsize)

# walk the layout in row-major order so that position+1 is the subplot number
for position, feeding_condition in enumerate(plot_layout.flat):
    if feeding_condition is None:
        continue  # empty position in the layout
    ax = plt.subplot(n_rows, n_cols, position+1)
    df = data[feeding_condition]['annotations']
    if ylabel not in df:
        continue  # leave the axes blank when the column is missing
    df[ylabel].plot(ax=ax, marker='.')
    ax.set_ylim(ylim)
    ax.set_title(feeding_condition)
    ax.set_xlabel('Behavior index')
    ax.set_ylabel(ylabel)
plt.tight_layout()

## Distribution of behavior durations

In [None]:
# plt.figure(figsize=figsize)

# xlabel = 'Duration (s)'
# bins = np.arange(duration_range[0], duration_range[1])

# for i in range(n_rows):
#     for j in range(n_cols):
#         feeding_condition = plot_layout[i, j]
#         if feeding_condition is not None:
#             plt.subplot(n_rows, n_cols, i*n_cols+j+1)
#             df = data[feeding_condition]['annotations']
            
#             n_drop_from_beginning = 0
#             df[xlabel].drop(df.index[:n_drop_from_beginning]).hist(bins=bins)
            
#             n_drop_from_beginning = 5
#             df[xlabel].drop(df.index[:n_drop_from_beginning]).hist(bins=bins)
            
# #             plt.xticks(bins)
#             plt.ylim(hist_y_range)
#             plt.title(feeding_condition)
#             plt.xlabel(xlabel)
#             plt.ylabel('Frequency')
# #             plt.legend(['Full sequence', 'First 5 dropped'])
# plt.tight_layout()

## Sequence of interbehavior intervals

In [None]:
# interval following each behavior against position in the sequence,
# one subplot per feeding condition
ylabel = 'Interval after (s)'
ylim = interval_range

plt.figure(figsize=figsize)

# walk the layout in row-major order so that position+1 is the subplot number
for position, feeding_condition in enumerate(plot_layout.flat):
    if feeding_condition is None:
        continue  # empty position in the layout
    ax = plt.subplot(n_rows, n_cols, position+1)
    df = data[feeding_condition]['annotations']
    if ylabel not in df:
        continue  # leave the axes blank when the column is missing
    df[ylabel].plot(ax=ax, marker='.')
    ax.set_ylim(ylim)
    ax.set_title(feeding_condition)
    ax.set_xlabel('Behavior index')
    ax.set_ylabel(ylabel)
plt.tight_layout()

## Distribution of interbehavior intervals

In [None]:
# plt.figure(figsize=figsize)

# xlabel = 'Interval after (s)'
# bins = np.arange(interval_range[0], interval_range[1])

# for i in range(n_rows):
#     for j in range(n_cols):
#         feeding_condition = plot_layout[i, j]
#         if feeding_condition is not None:
#             plt.subplot(n_rows, n_cols, i*n_cols+j+1)
#             df = data[feeding_condition]['annotations']
            
#             n_drop_from_beginning = 0
#             df[xlabel].drop(df.index[:n_drop_from_beginning]).hist(bins=bins)
            
#             n_drop_from_beginning = 5
#             df[xlabel].drop(df.index[:n_drop_from_beginning]).hist(bins=bins)
            
# #             plt.xticks(bins)
#             plt.ylim(hist_y_range)
#             plt.title(feeding_condition)
#             plt.xlabel(xlabel)
#             plt.ylabel('Frequency')
# #             plt.legend(['Full sequence', 'First 5 dropped'])
# plt.tight_layout()

## Sequences of RAUCs

In [None]:
# plt.figure(figsize=figsize)

# ylabel = 'I2 RAUC (μV$\cdot$s)'
# ylim = rauc_range

# for i in range(n_rows):
#     for j in range(n_cols):
#         feeding_condition = plot_layout[i, j]
#         if feeding_condition is not None:
#             plt.subplot(n_rows, n_cols, i*n_cols+j+1)
#             df = data[feeding_condition]['annotations']
#             if ylabel in df:
#                 df[ylabel].plot(marker='.')
#                 plt.ylim(ylim)
#                 plt.title(feeding_condition)
#                 plt.xlabel('Behavior index')
#                 plt.ylabel(ylabel)
# plt.tight_layout()

In [None]:
# plt.figure(figsize=figsize)

# ylabel = 'RN RAUC (μV$\cdot$s)'
# ylim = rauc_range

# for i in range(n_rows):
#     for j in range(n_cols):
#         feeding_condition = plot_layout[i, j]
#         if feeding_condition is not None:
#             plt.subplot(n_rows, n_cols, i*n_cols+j+1)
#             df = data[feeding_condition]['annotations']
#             if ylabel in df:
#                 df[ylabel].plot(marker='.')
#                 plt.ylim(ylim)
#                 plt.title(feeding_condition)
#                 plt.xlabel('Behavior index')
#                 plt.ylabel(ylabel)
# plt.tight_layout()

In [None]:
# plt.figure(figsize=figsize)

# ylabel = 'BN2 RAUC (μV$\cdot$s)'
# ylim = rauc_range

# for i in range(n_rows):
#     for j in range(n_cols):
#         feeding_condition = plot_layout[i, j]
#         if feeding_condition is not None:
#             plt.subplot(n_rows, n_cols, i*n_cols+j+1)
#             df = data[feeding_condition]['annotations']
#             if ylabel in df:
#                 df[ylabel].plot(marker='.')
#                 plt.ylim(ylim)
#                 plt.title(feeding_condition)
#                 plt.xlabel('Behavior index')
#                 plt.ylabel(ylabel)
# plt.tight_layout()

In [None]:
# plt.figure(figsize=figsize)

# ylabel = 'BN3 RAUC (μV$\cdot$s)'
# ylim = rauc_range

# for i in range(n_rows):
#     for j in range(n_cols):
#         feeding_condition = plot_layout[i, j]
#         if feeding_condition is not None:
#             plt.subplot(n_rows, n_cols, i*n_cols+j+1)
#             df = data[feeding_condition]['annotations']
#             if ylabel in df:
#                 df[ylabel].plot(marker='.')
#                 plt.ylim(ylim)
#                 plt.title(feeding_condition)
#                 plt.xlabel('Behavior index')
#                 plt.ylabel(ylabel)
# plt.tight_layout()

In [None]:
# behavior duration and per-channel RAUC against position in the sequence,
# overlaid in one subplot per feeding condition
plt.figure(figsize=figsize)

# columns to overlay; the backslash is doubled so that Python does not see the
# invalid escape sequence '\c' (the column names themselves are unchanged)
ylabels = [
    'Duration (s)',
    'I2 RAUC (μV$\\cdot$s)',
    'RN RAUC (μV$\\cdot$s)',
    'BN2 RAUC (μV$\\cdot$s)',
    'BN3 RAUC (μV$\\cdot$s)'
]

# common y-limits wide enough for both durations and RAUCs
ylim = [np.min([duration_range, rauc_range]), np.max([duration_range, rauc_range])]

# walk the layout in row-major order so that position+1 is the subplot number
for position, feeding_condition in enumerate(plot_layout.flat):
    if feeding_condition is None:
        continue  # empty position in the layout
    ax = plt.subplot(n_rows, n_cols, position+1)
    df = data[feeding_condition]['annotations']

    # draw only the columns this condition has (e.g. a channel may not be kept)
    present = [name for name in ylabels if name in df]
    for ylabel in present:
        df[ylabel].plot(ax=ax, marker='.')

    # decorate once, after all traces are drawn; axes with no traces stay bare
    if present:
        ax.set_ylim(ylim)
        ax.set_title(feeding_condition)
        ax.set_xlabel('Behavior index')
        ax.set_ylabel('Duration or RAUC')
#     ax.legend()
plt.tight_layout()

## Sequences of mean rectified voltages

In [None]:
# behavior duration and per-channel mean rectified voltage against position in
# the sequence, overlaid in one subplot per feeding condition
plt.figure(figsize=figsize)

# columns to overlay in every subplot
ylabels = [
    'Duration (s)',
    'I2 mean rectified voltage (μV)',
    'RN mean rectified voltage (μV)',
    'BN2 mean rectified voltage (μV)',
    'BN3 mean rectified voltage (μV)',
]

# common y-limits wide enough for both durations and mean voltages
ylim = [np.min([duration_range, voltage_mean_range]), np.max([duration_range, voltage_mean_range])]

# walk the layout in row-major order so that position+1 is the subplot number
for position, feeding_condition in enumerate(plot_layout.flat):
    if feeding_condition is None:
        continue  # empty position in the layout
    ax = plt.subplot(n_rows, n_cols, position+1)
    df = data[feeding_condition]['annotations']
    for ylabel in ylabels:
        if ylabel not in df:
            continue  # this condition does not have the column
        df[ylabel].plot(ax=ax, marker='.')
        ax.set_ylim(ylim)
        ax.set_title(feeding_condition)
        ax.set_xlabel('Behavior index')
        ax.set_ylabel('Duration or mean rectified voltage')
#     ax.legend()
plt.tight_layout()

## Scatter plots of behavior durations and mean rectified voltages

In [None]:
def scatter2d(ax, feeding_conditions, xlabel, xlim, ylabel, ylim, trend=False):
    """Scatter one annotation column against another, one series per feeding condition.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on.
    feeding_conditions : sequence of str or None
        Keys into the notebook-level ``data`` dict. ``None`` entries are
        skipped but still occupy a marker/color slot, so the remaining series
        keep the same appearance.
    xlabel, ylabel : str
        Column names in each condition's ``'annotations'`` table; also used as
        the axis labels.
    xlim, ylim : [lower, upper]
        Axis limits. ``xlim`` also sets the extent of the trend line.
    trend : bool, optional
        If True, fit an ordinary least squares line to the pooled points of all
        conditions (rows with non-finite values are dropped) and draw it with
        R^2, p and n in its label.
    """
    import warnings

    all_points = np.empty(shape=(0,2))
    for j, feeding_condition in enumerate(feeding_conditions):
        if feeding_condition is not None:
            df = data[feeding_condition]['annotations']
            all_points = np.concatenate([all_points, df[[xlabel, ylabel]]])
            # cycle through the style lists so that more conditions than
            # available markers/colors do not raise an IndexError
            ax.scatter(df[xlabel], df[ylabel],
                       label=feeding_condition,
                       marker=markers[j % len(markers)],
                       c=colors[j % len(colors)])

    if trend:
        all_points = all_points[np.isfinite(all_points).all(axis=1)]
        if len(all_points) < 2:
            # a line cannot be fitted; keep the scatter and skip the trend
            warnings.warn('scatter2d: fewer than 2 finite points, trend line skipped')
        else:
            model = sm.OLS(all_points[:,1], sm.add_constant(all_points[:,0])).fit()
            # params[0] is the intercept and params[1] the slope; pvalues[1] is the slope's p-value
            model_stats = 'R$^2$ = {:.2f}, p = {:.3f}, n = {}'.format(model.rsquared, model.pvalues[1], len(all_points))
            model_x = np.linspace(xlim[0], xlim[1], 100)
            # the trend line takes the color slot after the last condition
            ax.plot(model_x, model_x*model.params[1] + model.params[0],
                    label=model_stats,
                    color=colors[len(feeding_conditions) % len(colors)])

    ax.set_xlabel(xlabel)
    ax.set_xlim(xlim)
    ax.set_ylabel(ylabel)#.replace('rectified ',''))
    ax.set_ylim(ylim)

In [None]:
def scatter3d(ax, feeding_conditions, xlabel, xlim, ylabel, ylim, zlabel, zlim, trend=False, remove_outliers=False):
    """3D scatter of three annotation columns, one series per feeding condition.

    Relies on the notebook-level names ``data`` (indexed as
    ``data[condition]['annotations']``), ``markers`` and ``colors``.

    Parameters
    ----------
    ax
        3D axes object to draw on.
    feeding_conditions
        Sequence of keys into ``data``. ``None`` entries are skipped but still
        occupy a slot, so a condition's marker and color depend on its
        position in this sequence.
    xlabel, ylabel, zlabel
        Column names in the annotations table; also used as the axis labels.
    xlim, ylim, zlim
        Axis limits.
    trend
        If True, fit an ordinary least squares plane (z on x and y) to the
        pooled finite points, print R^2, the overall model p-value and the
        number of points, and draw the fitted plane over the extent of the data.
    remove_outliers
        If True, apply the hard-coded time-window filter described below.
    """

    # Hard-coded exclusion applied only when remove_outliers is True: for the
    # listed condition, only rows lying entirely inside one of these
    # [start, end] windows (seconds) are kept. Per the original author's note
    # this drops an initial swallow with a long duration.
    time_windows_to_keep = {'JG08 Tape nori': [[148, 205]]}

    per_condition_points = []
    for j, feeding_condition in enumerate(feeding_conditions):
        if feeding_condition is None:
            continue
        df = data[feeding_condition]['annotations']
        if remove_outliers and feeding_condition in time_windows_to_keep:
            inside_any_window = np.zeros(len(df), dtype=bool)
            for t_start, t_end in time_windows_to_keep[feeding_condition]:
                inside_any_window |= ((t_start <= df['Start (s)']) & (df['End (s)'] <= t_end)).values
            df = df[inside_any_window]
        per_condition_points.append(df[[xlabel, ylabel, zlabel]])
        # Pass the single series color via ``color=`` rather than ``c=``:
        # ``c`` is also the per-point data argument, so an RGB(A) tuple given
        # there is ambiguous (and matplotlib warns about it).
        ax.scatter(df[xlabel], df[ylabel], df[zlabel],
                   label=feeding_condition, marker=markers[j], color=colors[j])

    # Pool the (x, y, z) triples of every plotted condition into one N x 3 array.
    all_points = np.concatenate([np.empty(shape=(0, 3))] + per_condition_points)

    if trend:
        # Rows containing NaN or inf cannot take part in the regression.
        all_points = all_points[np.isfinite(all_points).all(axis=1)]
        model = sm.OLS(all_points[:, 2], sm.add_constant(all_points[:, :2])).fit()
        # With two predictors, pvalues[1] would describe the x coefficient
        # alone; the F-test p-value is the significance of the whole model,
        # which is what belongs next to R^2.
        model_stats = 'R$^2$ = {:.2f}, p = {:.3f}, n = {}'.format(model.rsquared, model.f_pvalue, len(all_points))
        print(model_stats)
        # Evaluate the fitted plane on a 20 x 20 grid spanning the observed x and y.
        grid_x = np.linspace(all_points[:, 0].min(), all_points[:, 0].max(), 20)
        grid_y = np.linspace(all_points[:, 1].min(), all_points[:, 1].max(), 20)
        model_x, model_y = np.meshgrid(grid_x, grid_y)
        model_z = model.params[0] + model.params[1] * model_x + model.params[2] * model_y
        ax.plot_surface(model_x, model_y, model_z, cmap=plt.cm.RdBu_r, alpha=0.6, linewidth=0)

    ax.set_xlabel(xlabel)
    ax.set_xlim(xlim)
    ax.set_ylabel(ylabel)
    ax.set_ylim(ylim)
    ax.set_zlabel(zlabel)
    ax.set_zlim(zlim)

In [None]:
plt.figure(figsize=(5, 12))

# Conditions to overlay in every panel. A None entry draws nothing but keeps
# its slot, so the remaining conditions retain their position-based marker
# and color. Alternative selections are kept commented out for quick switching.
# feeding_conditions = ['JG08 Regular nori', 'JG08 Tape nori', 'JG08 Two-ply nori', 'JG08 Fresh food', 'JG08 Tubing']
feeding_conditions = ['JG08 Regular nori', 'JG08 Tape nori', 'JG08 Two-ply nori', 'JG08 Fresh food', None]
# feeding_conditions = ['JG08 Regular nori', 'JG08 Tape nori', 'JG08 Two-ply nori', None, None]
# feeding_conditions = [None, None, None, 'JG08 Fresh food', None]

# One panel per channel; a None channel would leave its panel slot empty.
channels = ['I2', 'RN', 'BN2', 'BN3']

# Alternative: whole-behavior duration vs. mean rectified voltage.
# xlabels = ['Duration (s)']*4
# xlims = [duration_range]*4
# ylabels = [c+' mean rectified voltage (μV)' if c is not None else None for c in channels]
# ylims = [voltage_mean_range]*4

# Per-channel burst duration (x) against burst mean rectified voltage (y).
xlabels = [None if c is None else c+' burst duration (s)' for c in channels]
xlims = [[0,10], [0,20], [0,15], [0,15]]
ylabels = [None if c is None else c+' burst mean rectified voltage (μV)' for c in channels]
ylims = [[0,6], [0,5], [0,10], [0,12]]

n_panels = len(ylabels)
for i, panel_ylabel in enumerate(ylabels):
    if panel_ylabel is None:
        continue
    ax = plt.subplot(n_panels, 1, i+1)
    scatter2d(ax, feeding_conditions, xlabels[i], xlims[i], panel_ylabel, ylims[i])
    ax.legend()
plt.tight_layout()

In [None]:
# Create the 3D axes explicitly: plt.gca() stopped accepting keyword arguments
# such as projection= (deprecated in Matplotlib 3.4, an error from 3.6).
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111, projection='3d')

# Conditions to overlay; None entries are skipped but keep their slot.
# Alternative selections are kept commented out for quick switching.
# feeding_conditions = ['JG08 Regular nori', 'JG08 Tape nori', 'JG08 Two-ply nori', 'JG08 Fresh food', 'JG08 Tubing']
feeding_conditions = ['JG08 Regular nori', 'JG08 Tape nori', 'JG08 Two-ply nori', 'JG08 Fresh food', None]
# feeding_conditions = ['JG08 Regular nori', 'JG08 Tape nori', 'JG08 Two-ply nori', None, None]
# feeding_conditions = [None, None, None, 'JG08 Fresh food', None]

# Alternative axes: whole-behavior duration vs. BN2/BN3 mean rectified voltage.
# xlabel, xlim = 'Duration (s)', duration_range
# ylabel, ylim = 'BN2 mean rectified voltage (μV)', voltage_mean_range
# zlabel, zlim = 'BN3 mean rectified voltage (μV)', voltage_mean_range

# Burst mean rectified voltage on three channels, all on the same range.
xlabel, xlim = 'RN burst mean rectified voltage (μV)', voltage_mean_range
ylabel, ylim = 'BN2 burst mean rectified voltage (μV)', voltage_mean_range
zlabel, zlim = 'BN3 burst mean rectified voltage (μV)', voltage_mean_range

scatter3d(ax, feeding_conditions, xlabel, xlim, ylabel, ylim, zlabel, zlim)
ax.legend()
plt.tight_layout()

In [None]:
plt.figure(figsize=(5, 12))

# The two conditions to compare in every panel.
feeding_conditions = ['JG07 Regular nori', 'JG07 Tape nori']

# One panel per channel; a None channel would leave its panel slot empty.
channels = ['I2', 'RN', 'BN2', 'BN3']

# Every panel shares the same x column and the same axis ranges; only the
# channel named in the y column changes.
xlabels = ['Duration (s)']*4
xlims = [duration_range]*4
ylabels = [None if c is None else c+' mean rectified voltage (μV)' for c in channels]
ylims = [voltage_mean_range]*4

n_panels = len(ylabels)
for i, panel_ylabel in enumerate(ylabels):
    if panel_ylabel is None:
        continue
    ax = plt.subplot(n_panels, 1, i+1)
    scatter2d(ax, feeding_conditions, xlabels[i], xlims[i], panel_ylabel, ylims[i])
    ax.legend()
plt.tight_layout()

In [None]:
# Create the 3D axes explicitly: plt.gca() stopped accepting keyword arguments
# such as projection= (deprecated in Matplotlib 3.4, an error from 3.6).
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111, projection='3d')

# The two conditions to compare.
feeding_conditions = ['JG07 Regular nori', 'JG07 Tape nori']

# Behavior duration against BN2 and BN3 mean rectified voltage.
xlabel, xlim = 'Duration (s)', duration_range
ylabel, ylim = 'BN2 mean rectified voltage (μV)', voltage_mean_range
zlabel, zlim = 'BN3 mean rectified voltage (μV)', voltage_mean_range

scatter3d(ax, feeding_conditions, xlabel, xlim, ylabel, ylim, zlabel, zlim)
ax.legend()
plt.tight_layout()

In [None]:
plt.figure(figsize=(5, 12))

# The two conditions to compare in every panel.
feeding_conditions = ['JG05 Regular nori', 'JG05 Tape nori']

# One panel slot per channel. The second slot is None, so no panel is drawn
# there and the remaining channels keep their usual positions in the column.
channels = ['I2', None, 'BN2', 'BN3']

# Every panel shares the same x column and the same axis ranges; only the
# channel named in the y column changes.
xlabels = ['Duration (s)']*4
xlims = [duration_range]*4
ylabels = [None if c is None else c+' mean rectified voltage (μV)' for c in channels]
ylims = [voltage_mean_range]*4

n_panels = len(ylabels)
for i, panel_ylabel in enumerate(ylabels):
    if panel_ylabel is None:
        continue
    ax = plt.subplot(n_panels, 1, i+1)
    scatter2d(ax, feeding_conditions, xlabels[i], xlims[i], panel_ylabel, ylims[i])
    ax.legend()
plt.tight_layout()

In [None]:
# Create the 3D axes explicitly: plt.gca() stopped accepting keyword arguments
# such as projection= (deprecated in Matplotlib 3.4, an error from 3.6).
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111, projection='3d')

# The two conditions to compare.
feeding_conditions = ['JG05 Regular nori', 'JG05 Tape nori']

# Behavior duration against BN2 and BN3 mean rectified voltage.
xlabel, xlim = 'Duration (s)', duration_range
ylabel, ylim = 'BN2 mean rectified voltage (μV)', voltage_mean_range
zlabel, zlim = 'BN3 mean rectified voltage (μV)', voltage_mean_range

scatter3d(ax, feeding_conditions, xlabel, xlim, ylabel, ylim, zlabel, zlim)
ax.legend()
plt.tight_layout()

## Force plots

In [None]:
# Create the 3D axes explicitly: plt.gca() stopped accepting keyword arguments
# such as projection= (deprecated in Matplotlib 3.4, an error from 3.6).
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111, projection='3d')

# The two tape-nori conditions to compare.
feeding_conditions = ['JG07 Tape nori', 'JG08 Tape nori']

# Behavior duration against peak and mean force.
xlabel, xlim = 'Duration (s)', duration_range
ylabel, ylim = 'Force peak (mN)', force_peak_range
zlabel, zlim = 'Force mean (mN)', force_mean_range

# Alternative axes: duration and BN2 voltage against mean force.
# xlabel, xlim = 'Duration (s)', duration_range
# ylabel, ylim = 'BN2 mean rectified voltage (μV)', voltage_mean_range
# zlabel, zlim = 'Force mean (mN)', force_mean_range

# Alternative axes: duration and BN2 voltage against peak force.
# xlabel, xlim = 'Duration (s)', duration_range
# ylabel, ylim = 'BN2 mean rectified voltage (μV)', voltage_mean_range
# zlabel, zlim = 'Force peak (mN)', force_peak_range

scatter3d(ax, feeding_conditions, xlabel, xlim, ylabel, ylim, zlabel, zlim)
ax.legend()
plt.tight_layout()

In [None]:
# Create the 3D axes explicitly: plt.gca() stopped accepting keyword arguments
# such as projection= (deprecated in Matplotlib 3.4, an error from 3.6).
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111, projection='3d')

# The two tape-nori conditions to compare.
feeding_conditions = ['JG07 Tape nori', 'JG08 Tape nori']

# BN2 burst duration and mean rectified voltage against a force measure.
# Exactly one zlabel line should be active; the others are alternatives.
xlabel, xlim = 'BN2 burst duration (s)', [0,8]
ylabel, ylim = 'BN2 burst mean rectified voltage (μV)', [0,6]
# zlabel, zlim = 'Force mean (mN)', force_mean_range
# zlabel, zlim = 'Force peak (mN)', force_peak_range
# zlabel, zlim = 'Force mean following BN2 burst (mN)', [0,150]
# zlabel, zlim = 'Force peak following BN2 burst (mN)', [0,250]
zlabel, zlim = 'Force increase following BN2 burst (mN)', [0,200]
# zlabel, zlim = 'Force slope mean following BN2 burst (mN/s)', [0,50]
# zlabel, zlim = 'Force slope peak following BN2 burst (mN/s)', [0,700]

scatter3d(ax, feeding_conditions, xlabel, xlim, ylabel, ylim, zlabel, zlim)
ax.legend()
plt.tight_layout()

In [None]:
# Create the 3D axes explicitly: plt.gca() stopped accepting keyword arguments
# such as projection= (deprecated in Matplotlib 3.4, an error from 3.6).
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111, projection='3d')

feeding_conditions = ['JG07 Tape nori', 'JG08 Tape nori']
# feeding_conditions = ['JG08 Tape nori']

# Exactly one xlabel, one ylabel and one zlabel line should be active; the
# commented lines are alternatives kept for quick switching. Trailing
# bracketed values are alternative axis limits.
xlabel, xlim = 'BN2 burst duration (s)', [0,4]#[0,8]
ylabel, ylim = 'BN2 burst mean rectified voltage (μV)', [1.5,4.5]#[0,6]
# ylabel, ylim = 'BN3 burst mean rectified voltage (μV)', [0,10]
# ylabel, ylim = 'RN burst mean rectified voltage (μV)', [1,3]
# ylabel, ylim = 'RN burst duration (s)', [0,10]
# ylabel, ylim = 'BN2 burst mean rectified voltage (μV)', [1.5,4.5]#[0,6]
# zlabel, zlim = 'Force duration (s)', [0,10]
# zlabel, zlim = 'Force large hump duration following BN2 burst (s)', [3,7]
# zlabel, zlim = 'Force mean (mN)', force_mean_range
# zlabel, zlim = 'Force peak (mN)', force_peak_range
# zlabel, zlim = 'Force mean following BN2 burst (mN)', [0,150]
# zlabel, zlim = 'Force peak following BN2 burst (mN)', [0,250]
# zlabel, zlim = 'Force increase following BN2 burst (mN)', [0,200]
# zlabel, zlim = 'Force slope mean following BN2 burst (mN/s)', [0,50]
# zlabel, zlim = 'Force slope peak following BN2 burst (mN/s)', [0,700]
zlabel, zlim = 'Force slope mean during initial rise (mN/s)', [0,120]
# zlabel, zlim = 'RN burst mean rectified voltage (μV)', [1,3]

# Fit and draw the regression plane, with the hard-coded outlier filter enabled.
scatter3d(ax, feeding_conditions, xlabel, xlim, ylabel, ylim, zlabel, zlim, trend=True, remove_outliers=True)
ax.legend()
plt.tight_layout()

In [None]:
# Create the 3D axes explicitly: plt.gca() stopped accepting keyword arguments
# such as projection= (deprecated in Matplotlib 3.4, an error from 3.6).
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111, projection='3d')

# feeding_conditions = ['JG07 Tape nori', 'JG08 Tape nori']
feeding_conditions = ['JG08 Tape nori']

# Exactly one xlabel, one ylabel and one zlabel line should be active; the
# commented lines are alternatives kept for quick switching. Trailing
# bracketed values are alternative axis limits.
xlabel, xlim = 'BN2 burst duration (s)', [2,4]#[0,4]#[0,8]
ylabel, ylim = 'BN2 burst mean rectified voltage (μV)', [1.5,4.5]#[0,6]
# ylabel, ylim = 'BN3 burst mean rectified voltage (μV)', [0,10]
# ylabel, ylim = 'RN burst mean rectified voltage (μV)', [1,3]
# ylabel, ylim = 'RN burst duration (s)', [0,10]
# ylabel, ylim = 'BN2 burst mean rectified voltage (μV)', [1.5,4.5]#[0,6]
# zlabel, zlim = 'Force duration (s)', [0,10]
zlabel, zlim = 'Force large hump duration following BN2 burst (s)', [3,7]
# zlabel, zlim = 'Force mean (mN)', force_mean_range
# zlabel, zlim = 'Force peak (mN)', force_peak_range
# zlabel, zlim = 'Force mean following BN2 burst (mN)', [0,150]
# zlabel, zlim = 'Force peak following BN2 burst (mN)', [0,250]
# zlabel, zlim = 'Force increase following BN2 burst (mN)', [0,200]
# zlabel, zlim = 'Force slope mean following BN2 burst (mN/s)', [0,50]
# zlabel, zlim = 'Force slope peak following BN2 burst (mN/s)', [0,700]
# zlabel, zlim = 'Force slope mean during initial rise (mN/s)', [0,120]
# zlabel, zlim = 'RN burst mean rectified voltage (μV)', [1,3]

# Fit and draw the regression plane, with the hard-coded outlier filter enabled.
scatter3d(ax, feeding_conditions, xlabel, xlim, ylabel, ylim, zlabel, zlim, trend=True, remove_outliers=True)
ax.legend()
plt.tight_layout()

In [None]:
# New square figure; fetching its current axes creates the single panel.
ax = plt.figure(figsize=(5, 5)).gca()

# The two tape-nori conditions to compare.
feeding_conditions = ['JG07 Tape nori', 'JG08 Tape nori']

# Exactly one xlabel and one ylabel line should be active; the commented
# lines are alternatives kept for quick switching.
# xlabel, xlim = 'BN2 burst duration (s)', [0,8]
xlabel = 'BN2 burst mean rectified voltage (μV)'
xlim = [1.5,4.5]  # alternative: [0,6]
# ylabel, ylim = 'Force mean (mN)', force_mean_range
# ylabel, ylim = 'Force peak (mN)', force_peak_range
# ylabel, ylim = 'Force mean following BN2 burst (mN)', [0,150]
# ylabel, ylim = 'Force peak following BN2 burst (mN)', [0,250]
# ylabel, ylim = 'Force increase following BN2 burst (mN)', [0,200]
# ylabel, ylim = 'Force slope mean following BN2 burst (mN/s)', [0,50]
# ylabel, ylim = 'Force slope peak following BN2 burst (mN/s)', [0,700]
ylabel = 'Force slope mean during initial rise (mN/s)'
ylim = [0,120]

# Scatter both conditions and overlay the pooled regression line.
scatter2d(ax, feeding_conditions, xlabel, xlim, ylabel, ylim, trend=True)
ax.legend()
# Offset the remaining spines from the data and trim them to the tick range.
sns.despine(ax=ax, offset=20, trim=True)
plt.tight_layout()

## ShowCASE 2019

In [None]:
# Poster figure: BN2 burst duration against BN2 burst mean rectified voltage.
with sns.plotting_context('poster'):

    # New square figure; fetching its current axes creates the single panel.
    ax = plt.figure(figsize=(7, 7)).gca()

    # Variant including fresh food. None entries are skipped but keep their
    # slot, so conditions retain their position-based marker and color.
    feeding_conditions = ['JG08 Regular nori', 'JG08 Tape nori', 'JG08 Two-ply nori', 'JG08 Fresh food', None]
    outfile = 'JG08-BN2-dur-vs-mean-with-fresh-food.pdf'

    # Variant without fresh food:
#     feeding_conditions = ['JG08 Regular nori', 'JG08 Tape nori', 'JG08 Two-ply nori', None, None]
#     outfile = 'JG08-BN2-dur-vs-mean-without-fresh-food.pdf'

    xlabel = 'BN2 burst mean rectified voltage (μV)'
    xlim = [0,10]
    ylabel = 'BN2 burst duration (s)'
    ylim = [0,6]

    scatter2d(ax, feeding_conditions, xlabel, xlim, ylabel, ylim)
    ax.legend()
    # Offset the remaining spines from the data and trim them to the tick range.
    sns.despine(ax=ax, offset=20, trim=True)
    plt.tight_layout()

    # Saving to outfile is currently disabled:
#     plt.gcf().savefig(outfile)

In [None]:
# Poster figure: force increase following a BN2 burst against BN2 burst
# mean rectified voltage, with a pooled regression line.
with sns.plotting_context('poster'):

    # New square figure; fetching its current axes creates the single panel.
    ax = plt.figure(figsize=(7, 7)).gca()

    # The two tape-nori conditions to compare.
    feeding_conditions = ['JG07 Tape nori', 'JG08 Tape nori']
    outfile = 'JG08-JG07-BN2-vs-force-on-tape-nori.pdf'

    xlabel = 'BN2 burst mean rectified voltage (μV)'
    xlim = [1.5,4.5]  # alternative: [0,6]
    ylabel = 'Force increase following BN2 burst (mN)'
    ylim = [0,200]

    scatter2d(ax, feeding_conditions, xlabel, xlim, ylabel, ylim, trend=True)
    ax.legend()
    # Offset the remaining spines from the data and trim them to the tick range.
    sns.despine(ax=ax, offset=20, trim=True)
    plt.tight_layout()

    # Saving to outfile is currently disabled:
#     plt.gcf().savefig(outfile)