# Import Packages

In [None]:
# Make the project's helper modules (ParseMetadata, ImportData, NeoUtilities, ...)
# importable. The path is relative, so it resolves against the kernel's working
# directory. The membership test keeps the cell idempotent: re-running it does not
# pile up duplicate entries in sys.path.
import sys
if 'modules' not in sys.path:
    sys.path.append('modules')

In [None]:
################################################################################
# NUMPY
# conda install numpy

import numpy as np

################################################################################
# MATPLOTLIB
# conda install matplotlib

import matplotlib.pyplot as plt
# from mpl_toolkits.mplot3d import Axes3D

################################################################################
# SEABORN
# conda install seaborn

import seaborn as sns

################################################################################
# PYLTTB - Time Series Downsampling Using Largest-Triangle-Three-Buckets
# pip install pylttb

# from pylttb import lttb

################################################################################
# NEO
# pip install "neo>=0.7.1"
# - AxoGraph support requires axographio to be installed: pip install axographio

# import neo

################################################################################
# QUANTITIES
# conda install quantities

import quantities as pq
# allow symbols like mu for micro when quantities are printed
pq.markup.config.use_unicode = True

# define a millinewton unit; the force signal is later rescaled with .rescale('mN')
pq.mN = pq.UnitQuantity('millinewton', pq.N / 1e3, symbol='mN')

################################################################################
# ELEPHANT
# pip install git+https://github.com/NeuralEnsemble/elephant.git@master

import elephant

################################################################################
# PANDAS
# conda install pandas

# import pandas as pd

################################################################################
# STATSMODELS
# conda install statsmodels

import statsmodels.api as sm

################################################################################
# SPM1D - One-Dimensional Statistical Parametric Mapping
# pip install spm1d

# import spm1d

################################################################################
# EPHYVIEWER
# pip install git+https://github.com/jpgill86/ephyviewer.git@experimental
# - requires PyAV: conda install -c conda-forge av

# import ephyviewer

################################################################################
# ParseMetadata
# - requires ipywidgets: conda install ipywidgets
# - requires yaml:       conda install pyyaml

from ParseMetadata import LoadMetadata

################################################################################
# ImportData

from ImportData import LoadAndPrepareData

################################################################################
# NeoUtilities

from NeoUtilities import CausalAlphaKernel

################################################################################
# EphyviewerConfigurator

# from EphyviewerConfigurator import EphyviewerConfigurator

################################################################################
# NeoToEphyviewerBridge

# from NeoToEphyviewerBridge import NeoSegmentToEphyviewerSources#, PlotExampleWithEphyviewer

# IPython Magics

In [None]:
# Select the matplotlib backend: leave exactly ONE of the magics below uncommented.

# make figures interactive and open in a separate window
# %matplotlib qt

# make figures interactive and inline
%matplotlib notebook

# make figures non-interactive and inline
# %matplotlib inline

# Data Parameters

In [None]:
# Names of the data sets to analyze; each must be a key of the loaded metadata.
data_sets = [
    'IN VIVO / JG08 / 2018-06-21 / 002',
    'IN VIVO / JG08 / 2018-06-24 / 001',
]

# Load the metadata (which contains the file paths) for every known data set.
all_metadata = LoadMetadata(local_data_root='../data')

# Start one dictionary per data set, seeded with its metadata. Later cells add
# the selected time windows, signals, spike trains and per-epoch statistics.
data = {}
for data_set_name in data_sets:
    data[data_set_name] = {'metadata': all_metadata[data_set_name]}

In [None]:
# Select which swallow sequences to use.
#
# Each window is a [start, end] pair in seconds. The processing cell keeps only
# epochs that lie entirely inside at least one window of their data set.
# Commented-out entries are alternative selections.

data['IN VIVO / JG08 / 2018-06-21 / 002']['time_windows_to_keep'] = [
    # [-np.inf, np.inf],  # keep everything
    [659, 726.1],  # tension maximized and no perturbation
    # [666.95, 726.1],  # tension maximized and no perturbation, and extra long large hump excluded
    # [666.95, 705],  # sequence of 5 very stereotyped swallows
]

data['IN VIVO / JG08 / 2018-06-24 / 001']['time_windows_to_keep'] = [
    # [-np.inf, np.inf],  # keep everything
    # [2244.7, 2259.9], [2269.5, 2355.95],  # tension maximized and no perturbation
    # [2244.7, 2259.9], [2269.5, 2290.2], [2307, 2355.95],  # tension maximized and no perturbation, and extra long large hump excluded

    # [3932, 3990],
    [2232, 2356],
    [2743, 2940],
    [3095, 3140],
    [3385, 3425],
    [3570, 3594],
    [3923, 3990],
]

# Import and Process the Data

In [None]:
def _epoch_slice(sig, t_start, t_stop):
    """Return the part of a neo signal between t_start and t_stop.

    t_start and t_stop are plain numbers interpreted as seconds.
    """
    return sig.time_slice(t_start*pq.s, t_stop*pq.s)


def _epoch_peak(sig, t_start, t_stop):
    """Return the largest sample of sig between t_start and t_stop (seconds).

    Uses the vectorized ndarray.max() rather than the builtin max(), which would
    iterate over the samples one at a time in Python.
    NOTE(review): assumes single-channel signals (the code this replaces took
    element [0] of the maximal row) -- confirm.
    """
    return _epoch_slice(sig, t_start, t_stop).magnitude.max()


def _epoch_rauc(sig, t_start, t_stop, units, **rauc_kwargs):
    """Return the rectified area under the curve (RAUC) of sig between t_start
    and t_stop (seconds), rescaled to `units`.

    Extra keyword arguments (e.g. baseline='mean') are forwarded to
    elephant.signal_processing.rauc.
    """
    return elephant.signal_processing.rauc(
        _epoch_slice(sig, t_start, t_stop), **rauc_kwargs).rescale(units)


# For every data set: load the recording, derive smoothed force signals, collect
# nerve signals and spike trains, and build d['epochs_df'], a table with one row
# per 'force' epoch inside the selected time windows plus per-epoch statistics.
for data_set_name, d in data.items():

    # read in the data
    blk, _, epoch_encoder_dataframe, _ = LoadAndPrepareData(d['metadata'])
    seg = blk.segments[0]
    signalNameToIndex = {sig.name:i for i, sig in enumerate(seg.analogsignals)}

    # grab the force vs time data and rescale to mN
    d['force_sig'] = seg.analogsignals[signalNameToIndex['Force']].rescale('mN')

    # apply a super-low-pass filter to force signal
    d['smoothed_force_sig'] = elephant.signal_processing.butter(  # may raise a FutureWarning
        signal = d['force_sig'],
        lowpass_freq = 0.5*pq.Hz,
    )

    # calculate the derivative of the force vs time data and smooth it
    d['dforce/dt'] = elephant.signal_processing.butter(  # may raise a FutureWarning
        signal = elephant.signal_processing.derivative(d['force_sig']),
        lowpass_freq = 2*pq.Hz,
    ).rescale('mN/s')

    # grab the voltage vs time data and rescale to uV
    for key, signal_name in [('i2_sig', 'I2'), ('rn_sig', 'RN'), ('bn2_sig', 'BN2'), ('bn3_sig', 'BN3')]:
        d[key] = seg.analogsignals[signalNameToIndex[signal_name]].rescale('uV')

    # grab the spike trains, keyed by name
    spike_trains = {st.name: st for st in seg.spiketrains}
    d['spike_trains'] = spike_trains

    # grab the sampling period
    d['sampling_period'] = seg.analogsignals[0].sampling_period

    # keep only epochs that are entirely inside at least one of the time windows;
    # .copy() makes the result an independent frame so the .loc assignments below
    # cannot trigger pandas' SettingWithCopyWarning
    starts = epoch_encoder_dataframe['Start (s)']
    ends = epoch_encoder_dataframe['End (s)']
    inside_any_window = np.any(
        [(w[0] <= starts) & (ends <= w[1]) for w in d['time_windows_to_keep']],
        axis=0)
    epochs_df = epoch_encoder_dataframe[inside_any_window].copy()

    # copy middle times (end of large hump and start of small hump) into 'force' epochs;
    # a large hump belongs to a force epoch when it lies inside it, with a small
    # tolerance for floating-point differences in the epoch boundaries
    large_humps = epochs_df[epochs_df['Label'] == 'large hump']
    for i, epoch in epochs_df[epochs_df['Label'] == 'force'].iterrows():
        for _, subepoch in large_humps.iterrows():
            if subepoch['Start (s)'] >= epoch['Start (s)']-1e-7 and subepoch['End (s)'] <= epoch['End (s)']+1e-7:
                epochs_df.loc[i, 'Middle (s)'] = subepoch['End (s)']

    # drop all but 'force' rows
    epochs_df = epochs_df[epochs_df['Label'] == 'force'].copy()

    # per-epoch statistics
    for i, epoch in epochs_df.iterrows():
        t_start = epoch['Start (s)']
        t_stop = epoch['End (s)']
        # 'Middle (s)' is missing (no large hump found at all) or NaN (none found
        # for this epoch) when the epoch cannot be split into large/small humps
        t_middle = epoch.get('Middle (s)', np.nan)
        has_middle = not np.isnan(t_middle)

        # find max forces in each epoch
        epochs_df.loc[i,                'max'] = _epoch_peak(d['force_sig'], t_start, t_stop)
        epochs_df.loc[i,          'large max'] = _epoch_peak(d['force_sig'],          t_start,  t_middle) if has_middle else np.nan
        epochs_df.loc[i,          'small max'] = _epoch_peak(d['force_sig'],          t_middle, t_stop)   if has_middle else np.nan
        epochs_df.loc[i, 'smoothed large max'] = _epoch_peak(d['smoothed_force_sig'], t_start,  t_middle) if has_middle else np.nan
        epochs_df.loc[i, 'smoothed small max'] = _epoch_peak(d['smoothed_force_sig'], t_middle, t_stop)   if has_middle else np.nan

        # find rectified area under the curve (RAUC) in each epoch
        epochs_df.loc[i,            'force RAUC'] = _epoch_rauc(d['force_sig'], t_start, t_stop, 'mN*s')
        epochs_df.loc[i, 'large hump force RAUC'] = _epoch_rauc(d['force_sig'], t_start,  t_middle, 'mN*s') if has_middle else np.nan
        epochs_df.loc[i, 'small hump force RAUC'] = _epoch_rauc(d['force_sig'], t_middle, t_stop,   'mN*s') if has_middle else np.nan
        epochs_df.loc[i,               'I2 RAUC'] = _epoch_rauc(d['i2_sig'],  t_start, t_stop, 'uV*s', baseline = 'mean')
        epochs_df.loc[i,               'RN RAUC'] = _epoch_rauc(d['rn_sig'],  t_start, t_stop, 'uV*s', baseline = 'mean')
        epochs_df.loc[i,              'BN2 RAUC'] = _epoch_rauc(d['bn2_sig'], t_start, t_stop, 'uV*s', baseline = 'mean')
        epochs_df.loc[i,              'BN3 RAUC'] = _epoch_rauc(d['bn3_sig'], t_start, t_stop, 'uV*s', baseline = 'mean')

    # colors: spread the epochs evenly over the colormap's [0, 1] range
    epochs_df = epochs_df.assign(colormap_arg = np.linspace(0, 1, len(epochs_df)))

    d['epochs_df'] = epochs_df

# Plots

In [None]:
# Global seaborn styling for the figures below
# (context='poster' is an alternative for presentations).
sns.set(style='ticks', font_scale=1, font='Palatino Linotype')

# Color map used to color successive epochs
# (alternatives: plt.cm.brg, plt.cm.RdBu).
cm = plt.cm.cool

##### Figure 1: Plot forces across real time

In [None]:
# plt.figure(1, figsize=(9,3))
# for i, data_set_name in enumerate(data_sets):
#     d = data[data_set_name]
#     plt.subplot(1, len(data), i+1)
#     plt.title('{}\nt = {}'.format(data_set_name, d['time_windows_to_keep']))
#     plt.ylabel('Force (mN)')
#     plt.xlabel('Original chart time (s)')
#     for j, epoch in d['epochs_df'].iterrows():
#         epoch_force_sig = d['force_sig'].time_slice(epoch['Start (s)']*pq.s, epoch['End (s)']*pq.s)
#         plt.plot(epoch_force_sig.times, epoch_force_sig.as_array(), color=cm(epoch['colormap_arg']))
#     sns.despine(ax=plt.gca(), offset=10, trim=True) # offset axes from plot
# plt.tight_layout()

##### Figure 2: Plot forces, spike trains, and firing rate models

In [None]:
# Figure 2 layout: one column per data set. Row 1 shows the force during each kept
# epoch, row 2 the smoothed force derivative, and every further row one spike train
# drawn as a raster together with its kernel-smoothed firing rate.
n_plot_cols = len(data)
n_plot_rows = 2 + max(len(d['spike_trains']) for d in data.values())
plt.figure(2, figsize=(9,2*n_plot_rows))

for i, data_set_name in enumerate(data_sets):
    d = data[data_set_name]
    t_min = min(d['epochs_df']['Start (s)'])
    t_max = max(d['epochs_df']['End (s)'])

    # === FORCE ===
    ax = plt.subplot(n_plot_rows, n_plot_cols, i+1)
    plt.title('{}\nt = {}'.format(data_set_name, d['time_windows_to_keep']))
    plt.ylabel('Force (mN)')
    plt.xlim(t_min, t_max)
    plt.ylim([-10, 400])
    
#     sns.despine(ax=plt.gca(), offset=10, trim=True, bottom=True)
#     plt.gca().xaxis.set_visible(False)

    for j, epoch in d['epochs_df'].iterrows():
        epoch_force_sig = d['force_sig'].time_slice(epoch['Start (s)']*pq.s, epoch['End (s)']*pq.s)
        plt.plot(epoch_force_sig.times, epoch_force_sig.as_array(), color=cm(epoch['colormap_arg']))

        # whole-epoch force RAUC, written at the height of the epoch's peak force
        plt.text(np.mean([epoch['Start (s)'], epoch['End (s)']]), epoch['max'], '{:.0f}'.format(epoch['force RAUC']), fontsize=8, ha='center')

        if not np.isnan(epoch.get('Middle (s)', np.nan)):
            # per-hump RAUC, each written at the height of its own hump's smoothed
            # peak (the large-hump label previously reused 'smoothed small max')
            plt.text(np.mean([epoch['Start (s)'],  epoch['Middle (s)']]), epoch['smoothed large max'], '{:.0f}'.format(epoch['large hump force RAUC']), fontsize=8, ha='center')
            plt.text(np.mean([epoch['Middle (s)'], epoch['End (s)']]),    epoch['smoothed small max'], '{:.0f}'.format(epoch['small hump force RAUC']), fontsize=8, ha='center')

    # === D(FORCE)/DT ===
    plt.subplot(n_plot_rows, n_plot_cols, (1)*n_plot_cols+i+1, sharex=ax)
    plt.axhline(0, color='gray', linewidth=0.5)
    dfdt = d['dforce/dt'].time_slice(t_min*pq.s, t_max*pq.s)
    plt.plot(dfdt.times, dfdt.as_array())
    plt.ylabel('d(Force)/dt (mN/s)')
    plt.ylim([-400, 400])
    
#     sns.despine(ax=plt.gca(), offset=10, trim=True, bottom=True)
#     plt.gca().xaxis.set_visible(False)

    # === RASTER PLOTS + RATE MODELS ===
    spike_labels = d['spike_trains'].keys()
#     spike_labels = [
#     #     'I2',
#     #     'B8a/b',
#         'B3 (50-100 uV)',
#     #     '? (45-50 uV)',
#     #     'B6/B9 ? (26-45 uV)',
#         'B38 ? (17-26 uV)',
#     #     '? (15-17 uV)',
#     #     'B4/B5',
#     ]

    for j, spike_label in enumerate(spike_labels):
        st = d['spike_trains'][spike_label]
        st = st.time_slice(
            t_min*pq.s - 5*pq.s,
            t_max*pq.s + 5*pq.s
        ) # drop spikes outside plot range except for a few sec margin so beginning and final firing rates are accurate

        plt.subplot(n_plot_rows, n_plot_cols, (2+j)*n_plot_cols+i+1, sharex=ax)
        plt.ylabel(spike_label + '\n(rate model)')
        plt.xlim(t_min, t_max)
        plt.ylim([-2, 40])
        
        # only the bottom row gets an x-axis label
        if j == len(spike_labels)-1:
#             sns.despine(ax=plt.gca(), offset=10, trim=True)
            plt.xlabel('Time (s)')
#         else:
#             sns.despine(ax=plt.gca(), offset=10, trim=True, bottom=True)
#             plt.gca().xaxis.set_visible(False)

        # raster plot, drawn just below the zero line of the rate axis
        plt.eventplot(positions=st, lineoffsets=-1, colors='red')

        # spike train convolution
        kernels = [
#             CausalAlphaKernel(0.03*np.sqrt(2)*pq.s), # match my old poster's synapse model
            CausalAlphaKernel(0.2*pq.s),
#             elephant.kernels.AlphaKernel(0.03*np.sqrt(2)*pq.s),
#             elephant.kernels.AlphaKernel(0.2*pq.s),
#             elephant.kernels.EpanechnikovLikeKernel(0.2*pq.s),
#             elephant.kernels.ExponentialKernel(0.2*pq.s),
#             elephant.kernels.GaussianKernel(0.2*pq.s),
#             elephant.kernels.LaplacianKernel(0.2*pq.s),
#             elephant.kernels.RectangularKernel(0.2*pq.s),
#             elephant.kernels.TriangularKernel(0.2*pq.s)
        ]
        for kernel in kernels:
            rate = elephant.statistics.instantaneous_rate(
                spiketrain=st, sampling_period=d['sampling_period'], kernel=kernel)
            plt.plot(rate.times.rescale('s'), rate)

        # instantaneous firing frequency step plot
#         plt.plot(st[:-1], 1/elephant.statistics.isi(st), drawstyle='steps-post')

        # annotate each epoch with its spike count; the bounds carry explicit units
        # (like every other time_slice call here) so the count does not silently
        # depend on the units the spike train happens to be stored in
        for _, epoch in d['epochs_df'].iterrows():
            n_spikes = st.time_slice(epoch['Start (s)']*pq.s, epoch['End (s)']*pq.s).size
            plt.text(np.mean([epoch['Start (s)'], epoch['End (s)']]), 20, str(n_spikes), fontsize=8, ha='center')

plt.tight_layout()

##### Figure 3: Plot number of spikes vs force RAUC

In [None]:
# Figure 3: for each data set, regress the chosen force RAUC on the number of
# spikes each unit fired during the swallow (one point per epoch in epochs_df),
# with one regression line and one legend entry per spike train.
plt.figure(3, figsize=(9,6))
for i, data_set_name in enumerate(data_sets):
    d = data[data_set_name]
    plt.subplot(1, len(data), i+1)
    
#     rauc_label = 'small hump force RAUC'
    rauc_label = 'large hump force RAUC'
#     rauc_label = 'force RAUC'
    y = d['epochs_df'][rauc_label]

    spike_labels = d['spike_trains'].keys()
#     spike_labels = [
#     #     'I2',
#         'B8a/b',
#         'B3 (50-100 uV)',
#     #     '? (45-50 uV)',
#         'B6/B9 ? (26-45 uV)',
#         'B38 ? (17-26 uV)',
#     #     '? (15-17 uV)',
#         'B4/B5',
#     ]

    legend_text = []
    for j, spike_label in enumerate(spike_labels):
        # spike count per epoch; the bounds carry explicit units (like the other
        # time_slice calls in this notebook) so the count does not silently depend
        # on the units the spike train is stored in
        x = [
            d['spike_trains'][spike_label].time_slice(epoch['Start (s)']*pq.s, epoch['End (s)']*pq.s).size
            for _, epoch in d['epochs_df'].iterrows()
        ]
        
        model = sm.OLS(y, sm.add_constant(x)).fit()
        # element 1 is the slope's p-value (element 0 belongs to the constant);
        # list() gives positional access whether pvalues is an array or a
        # label-indexed pandas Series, where integer [] lookup is deprecated
        slope_p_value = list(model.pvalues)[1]
        legend_text.append('{}, R$^2$ = {:.2f}, p = {:.3f}'.format(spike_label, model.rsquared, slope_p_value))
        
#         plt.scatter(x, y)
#         line_plot_x = np.linspace(min(x),max(x),100)
#         plt.plot(line_plot_x, line_plot_x*model.params[1] + model.params[0])
        sns.regplot(x=x, y=y, ci=None, truncate=True)

    plt.title('{}\nt = {}'.format(data_set_name, d['time_windows_to_keep']))
    plt.xlabel('Number of spikes in swallow motor pattern')
#     plt.ylabel(r'Integrated force (mN$\cdot$s)')
    # raw string: '\c' is not a valid escape sequence in a normal string literal
    plt.ylabel(rauc_label + r' (mN$\cdot$s)')
    plt.legend(legend_text, fontsize = 9)
    sns.despine(ax=plt.gca(), offset=10, trim=True) # offset axes from plot
plt.tight_layout()

##### Figure 4: Plot BN2 RAUC vs force RAUC

In [None]:
# Figure 4: for each data set, regress force RAUC on BN2 RAUC (one point per
# epoch in epochs_df) and report the ordinary-least-squares fit in the legend.
plt.figure(4, figsize=(9,6))
for i, data_set_name in enumerate(data_sets):
    d = data[data_set_name]
    plt.subplot(1, len(data), i+1)
    
    x = d['epochs_df']['BN2 RAUC']
    y = d['epochs_df']['force RAUC']
    
    # axes limits are fixed before the regression is drawn
    plt.xlim([0, 25])
    plt.ylim([0, 1200])
    
    model = sm.OLS(y, sm.add_constant(x)).fit()
    # element 1 is the slope's p-value (element 0 belongs to the constant);
    # list() gives positional access because integer [] lookup on a
    # label-indexed pandas Series is deprecated
    slope_p_value = list(model.pvalues)[1]
    legend_text = ['R$^2$ = {:.2f}, p = {:.3f}, n = {}'.format(model.rsquared, slope_p_value, len(x))]
    
#     plt.scatter(x, y)
#     line_plot_x = np.linspace(plt.gca().get_xlim()[0],plt.gca().get_xlim()[1],100)
#     plt.plot(line_plot_x, line_plot_x*model.params[1] + model.params[0])
    # ci=None (not False) is what disables the bootstrapped confidence band
    sns.regplot(x=x, y=y, ci=None)
    
    plt.title('{}\nt = {}'.format(data_set_name, d['time_windows_to_keep']))
    # raw strings: '\m' and '\c' are not valid escape sequences in normal literals
    plt.xlabel(r'BN2 RAUC (integrated rectified voltage on buccal nerve 2) ($\mu$V$\cdot$s)')
    plt.ylabel(r'Force RAUC (integrated force) (mN$\cdot$s)')
    plt.legend(legend_text)#, fontsize = 9)
    sns.despine(ax=plt.gca(), offset=10, trim=True) # offset axes from plot
plt.tight_layout()

##### Figure 5: Plot BN3 RAUC vs force RAUC

In [None]:
# Figure 5: for each data set, regress force RAUC on BN3 RAUC (one point per
# epoch in epochs_df) and report the ordinary-least-squares fit in the legend.
plt.figure(5, figsize=(9,6))
for i, data_set_name in enumerate(data_sets):
    d = data[data_set_name]
    plt.subplot(1, len(data), i+1)
    
    x = d['epochs_df']['BN3 RAUC']
    y = d['epochs_df']['force RAUC']
    
    # axes limits are fixed before the regression is drawn
    plt.xlim([0, 50])
    plt.ylim([0, 1200])
    
    model = sm.OLS(y, sm.add_constant(x)).fit()
    # element 1 is the slope's p-value (element 0 belongs to the constant);
    # list() gives positional access because integer [] lookup on a
    # label-indexed pandas Series is deprecated
    slope_p_value = list(model.pvalues)[1]
    legend_text = ['R$^2$ = {:.2f}, p = {:.3f}, n = {}'.format(model.rsquared, slope_p_value, len(x))]
    
#     plt.scatter(x, y)
#     line_plot_x = np.linspace(plt.gca().get_xlim()[0],plt.gca().get_xlim()[1],100)
#     plt.plot(line_plot_x, line_plot_x*model.params[1] + model.params[0])
    # ci=None (not False) is what disables the bootstrapped confidence band
    sns.regplot(x=x, y=y, ci=None)
    
    plt.title('{}\nt = {}'.format(data_set_name, d['time_windows_to_keep']))
    # raw strings: '\m' and '\c' are not valid escape sequences in normal literals
    plt.xlabel(r'BN3 RAUC (integrated rectified voltage on buccal nerve 3) ($\mu$V$\cdot$s)')
    plt.ylabel(r'Force RAUC (integrated force) (mN$\cdot$s)')
    plt.legend(legend_text)#, fontsize = 9)
    sns.despine(ax=plt.gca(), offset=10, trim=True) # offset axes from plot
plt.tight_layout()

##### Figure 6: Plot RN RAUC vs force RAUC

In [None]:
# Figure 6: for each data set, regress force RAUC on RN RAUC (one point per
# epoch in epochs_df) and report the ordinary-least-squares fit in the legend.
plt.figure(6, figsize=(9,6))
for i, data_set_name in enumerate(data_sets):
    d = data[data_set_name]
    plt.subplot(1, len(data), i+1)
    
    x = d['epochs_df']['RN RAUC']
    y = d['epochs_df']['force RAUC']
    
    # axes limits are fixed before the regression is drawn
    plt.xlim([0, 25])
    plt.ylim([0, 1200])
    
    model = sm.OLS(y, sm.add_constant(x)).fit()
    # element 1 is the slope's p-value (element 0 belongs to the constant);
    # list() gives positional access because integer [] lookup on a
    # label-indexed pandas Series is deprecated
    slope_p_value = list(model.pvalues)[1]
    legend_text = ['R$^2$ = {:.2f}, p = {:.3f}, n = {}'.format(model.rsquared, slope_p_value, len(x))]
    
#     plt.scatter(x, y)
#     line_plot_x = np.linspace(plt.gca().get_xlim()[0],plt.gca().get_xlim()[1],100)
#     plt.plot(line_plot_x, line_plot_x*model.params[1] + model.params[0])
    # ci=None (not False) is what disables the bootstrapped confidence band
    sns.regplot(x=x, y=y, ci=None)
    
    plt.title('{}\nt = {}'.format(data_set_name, d['time_windows_to_keep']))
    # raw strings: '\m' and '\c' are not valid escape sequences in normal literals
    plt.xlabel(r'RN RAUC (integrated rectified voltage on radular nerve) ($\mu$V$\cdot$s)')
    plt.ylabel(r'Force RAUC (integrated force) (mN$\cdot$s)')
    plt.legend(legend_text)#, fontsize = 9)
    sns.despine(ax=plt.gca(), offset=10, trim=True) # offset axes from plot
plt.tight_layout()

##### Figure 7: Plot BN2 RAUC vs RN RAUC

In [None]:
# Figure 7: for each data set, regress RN RAUC on BN2 RAUC (one point per
# epoch in epochs_df) and report the ordinary-least-squares fit in the legend.
plt.figure(7, figsize=(9,6))
for i, data_set_name in enumerate(data_sets):
    d = data[data_set_name]
    plt.subplot(1, len(data), i+1)
    
    x = d['epochs_df']['BN2 RAUC']
    y = d['epochs_df']['RN RAUC']
    
    # axes limits are fixed before the regression is drawn
    plt.xlim([0, 25])
    plt.ylim([0, 25])
    
    model = sm.OLS(y, sm.add_constant(x)).fit()
    # element 1 is the slope's p-value (element 0 belongs to the constant);
    # list() gives positional access because integer [] lookup on a
    # label-indexed pandas Series is deprecated
    slope_p_value = list(model.pvalues)[1]
    legend_text = ['R$^2$ = {:.2f}, p = {:.3f}, n = {}'.format(model.rsquared, slope_p_value, len(x))]
    
#     plt.scatter(x, y)
#     line_plot_x = np.linspace(plt.gca().get_xlim()[0],plt.gca().get_xlim()[1],100)
#     plt.plot(line_plot_x, line_plot_x*model.params[1] + model.params[0])
    # ci=None (not False) is what disables the bootstrapped confidence band
    sns.regplot(x=x, y=y, ci=None)
    
    plt.title('{}\nt = {}'.format(data_set_name, d['time_windows_to_keep']))
    # raw strings: '\m' and '\c' are not valid escape sequences in normal literals
    plt.xlabel(r'BN2 RAUC (integrated rectified voltage on buccal nerve 2) ($\mu$V$\cdot$s)')
    plt.ylabel(r'RN RAUC (integrated rectified voltage on radular nerve) ($\mu$V$\cdot$s)')
    plt.legend(legend_text)#, fontsize = 9)
    sns.despine(ax=plt.gca(), offset=10, trim=True) # offset axes from plot
plt.tight_layout()

##### Figure 8: Plot BN2 RAUC vs BN3 RAUC

In [None]:
# Figure 8: for each data set, regress BN3 RAUC on BN2 RAUC (one point per
# epoch in epochs_df) and report the ordinary-least-squares fit in the legend.
plt.figure(8, figsize=(9,6))
for i, data_set_name in enumerate(data_sets):
    d = data[data_set_name]
    plt.subplot(1, len(data), i+1)
    
    x = d['epochs_df']['BN2 RAUC']
    y = d['epochs_df']['BN3 RAUC']
    
    # axes limits are fixed before the regression is drawn
    plt.xlim([0, 25])
    plt.ylim([0, 50])
    
    model = sm.OLS(y, sm.add_constant(x)).fit()
    # element 1 is the slope's p-value (element 0 belongs to the constant);
    # list() gives positional access because integer [] lookup on a
    # label-indexed pandas Series is deprecated
    slope_p_value = list(model.pvalues)[1]
    legend_text = ['R$^2$ = {:.2f}, p = {:.3f}, n = {}'.format(model.rsquared, slope_p_value, len(x))]
    
#     plt.scatter(x, y)
#     line_plot_x = np.linspace(plt.gca().get_xlim()[0],plt.gca().get_xlim()[1],100)
#     plt.plot(line_plot_x, line_plot_x*model.params[1] + model.params[0])
    # ci=None (not False) is what disables the bootstrapped confidence band
    sns.regplot(x=x, y=y, ci=None)
    
    plt.title('{}\nt = {}'.format(data_set_name, d['time_windows_to_keep']))
    # raw strings: '\m' and '\c' are not valid escape sequences in normal literals
    plt.xlabel(r'BN2 RAUC (integrated rectified voltage on buccal nerve 2) ($\mu$V$\cdot$s)')
    plt.ylabel(r'BN3 RAUC (integrated rectified voltage on buccal nerve 3) ($\mu$V$\cdot$s)')
    plt.legend(legend_text)#, fontsize = 9)
    sns.despine(ax=plt.gca(), offset=10, trim=True) # offset axes from plot
plt.tight_layout()

##### Figure 9: Plot RN RAUC vs BN3 RAUC

In [None]:
# Figure 9: for each data set, regress BN3 RAUC on RN RAUC (one point per
# epoch in epochs_df) and report the ordinary-least-squares fit in the legend.
plt.figure(9, figsize=(9,6))
for i, data_set_name in enumerate(data_sets):
    d = data[data_set_name]
    plt.subplot(1, len(data), i+1)
    
    x = d['epochs_df']['RN RAUC']
    y = d['epochs_df']['BN3 RAUC']
    
    # axes limits are fixed before the regression is drawn
    plt.xlim([0, 25])
    plt.ylim([0, 50])
    
    model = sm.OLS(y, sm.add_constant(x)).fit()
    # element 1 is the slope's p-value (element 0 belongs to the constant);
    # list() gives positional access because integer [] lookup on a
    # label-indexed pandas Series is deprecated
    slope_p_value = list(model.pvalues)[1]
    legend_text = ['R$^2$ = {:.2f}, p = {:.3f}, n = {}'.format(model.rsquared, slope_p_value, len(x))]
    
#     plt.scatter(x, y)
#     line_plot_x = np.linspace(plt.gca().get_xlim()[0],plt.gca().get_xlim()[1],100)
#     plt.plot(line_plot_x, line_plot_x*model.params[1] + model.params[0])
    # ci=None (not False) is what disables the bootstrapped confidence band
    sns.regplot(x=x, y=y, ci=None)
    
    plt.title('{}\nt = {}'.format(data_set_name, d['time_windows_to_keep']))
    # raw strings: '\m' and '\c' are not valid escape sequences in normal literals
    plt.xlabel(r'RN RAUC (integrated rectified voltage on radular nerve) ($\mu$V$\cdot$s)')
    plt.ylabel(r'BN3 RAUC (integrated rectified voltage on buccal nerve 3) ($\mu$V$\cdot$s)')
    plt.legend(legend_text)#, fontsize = 9)
    sns.despine(ax=plt.gca(), offset=10, trim=True) # offset axes from plot
plt.tight_layout()