# Import Packages

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import quantities as pq
import elephant
import statsmodels.api as sm
import neurotic
from neurotic.gui.config import _neo_epoch_to_dataframe
from utils import CausalAlphaKernel

# Print units with Unicode symbols (e.g. μ for micro) in quantities output.
pq.markup.config.use_unicode = True

# Millinewton unit, so signals can be converted with .rescale('mN').
pq.mN = pq.UnitQuantity(
    'millinewton',
    pq.N / 1e3,
    symbol='mN',
)

# IPython Magics

In [None]:
# Choose the matplotlib backend: only the one uncommented magic below takes effect.

# make figures interactive and open in a separate window
# %matplotlib qt

# make figures interactive and inline
# NOTE(review): the 'notebook' (nbagg) backend targets the classic Notebook UI; under
# JupyterLab / Notebook 7 '%matplotlib widget' (ipympl) is the usual replacement — confirm
# which frontend this notebook is run in.
%matplotlib notebook

# make figures non-interactive and inline
# %matplotlib inline

# Data Parameters

In [None]:
# Data sets to analyze; each name is a key of the neurotic metadata file.
data_sets = [
    'IN VIVO / JG08 / 2018-06-21 / 002',
    'IN VIVO / JG08 / 2018-06-24 / 001',
    'IN VIVO / JG12 / 2019-05-10 / 002',
]

# The metadata file records where each data set's files live.
metadata = neurotic.MetadataSelector(file='../../data/metadata.yml')

# data[name] accumulates everything known about one data set; later cells add keys.
data = {}
for data_set_name in data_sets:
    metadata.select(data_set_name)  # selected_metadata now refers to this data set
    data[data_set_name] = {'metadata': metadata.selected_metadata}

In [None]:
# Select which swallow sequences to use: for each data set, a list of [start, end]
# windows in seconds of original chart time. The processing cell keeps only epochs
# that lie entirely inside one of these windows. Commented-out entries are
# alternative selections that were tried.
time_windows_by_data_set = {
    'IN VIVO / JG08 / 2018-06-21 / 002': [
        # [-np.inf, np.inf], # keep everything
        [659, 726.1], # tension maximized and no perturbation
        # [666.95, 726.1], # tension maximized and no perturbation, and extra long large hump excluded
        # [666.95, 705], # sequence of 5 very stereotyped swallows
    ],
    'IN VIVO / JG08 / 2018-06-24 / 001': [
        # [-np.inf, np.inf], # keep everything
        [2244.7, 2259.9], [2269.5, 2355.95], # tension maximized and no perturbation
        # [2244.7, 2259.9], [2269.5, 2290.2], [2307, 2355.95], # tension maximized and no perturbation, and extra long large hump excluded

        # [3932, 3990],
        # [2232, 2356], [2743, 2940], [3095, 3140], [3385, 3425], [3570, 3594], [3923, 3990]
    ],
    'IN VIVO / JG12 / 2019-05-10 / 002': [
        # [-np.inf, np.inf], # keep everything
        # [430, 525],
        [430, 580],
        # [2890, 2946],#[2890, 3085],
    ],
}

for window_set_name, windows in time_windows_by_data_set.items():
    data[window_set_name]['time_windows_to_keep'] = windows

# Import and Process the Data

In [None]:
# Load every data set and derive the per-epoch table used by all figures below.
# Adds to each data[name]: force/derivative/nerve signals, spike trains, sampling
# period, and 'epochs_df' (one row per kept 'force' epoch with peak and RAUC columns).
for data_set_name, d in data.items():

    # read in the data
    blk = neurotic.load_dataset(d['metadata'])
#     signalNameToIndex = {sig.name:i for i, sig in enumerate(blk.segments[0].analogsignals)}
    # map signal name -> index, stripping '-L' and '-PROX' from names so the lookups below
    # can use the bare names ('Force', 'I2', 'RN', 'BN2', 'BN3')
    # NOTE(review): presumably the suffixes differ between recordings; confirm
    signalNameToIndex = {sig.name.replace('-L','').replace('-PROX',''):i for i, sig in enumerate(blk.segments[0].analogsignals)}

    # grab the force vs time data and rescale to mN
    d['force_sig'] = blk.segments[0].analogsignals[signalNameToIndex['Force']].rescale('mN')

    # apply a super-low-pass filter to force signal
    # (0.5 Hz low-pass; only used below for the 'smoothed ... max' columns)
    d['smoothed_force_sig'] = elephant.signal_processing.butter(  # may raise a FutureWarning
        signal = d['force_sig'],
        lowpass_freq = 0.5*pq.Hz,
    )

    # calculate the derivative of the force vs time data and smooth it
    # (2 Hz low-pass on the derivative; plotted in Figure 2)
    d['dforce/dt'] = elephant.signal_processing.butter(  # may raise a FutureWarning
        signal = elephant.signal_processing.derivative(d['force_sig']),
        lowpass_freq = 2*pq.Hz,
    ).rescale('mN/s')

    # grab the voltage vs time data and rescale to uV
    d['i2_sig']  = blk.segments[0].analogsignals[signalNameToIndex['I2']].rescale('uV')
    d['rn_sig']  = blk.segments[0].analogsignals[signalNameToIndex['RN']].rescale('uV')
    d['bn2_sig'] = blk.segments[0].analogsignals[signalNameToIndex['BN2']].rescale('uV')
    d['bn3_sig'] = blk.segments[0].analogsignals[signalNameToIndex['BN3']].rescale('uV')

    # grab the spike trains, keyed by their names (later cells look up units such as 'B3')
    spike_trains = {}
    for st in blk.segments[0].spiketrains:
        spike_trains[st.name] = st
    d['spike_trains'] = spike_trains

    # grab the sampling period
    # NOTE(review): taken from the first analog signal only; assumes all signals share it.
    # It is later used as the sampling period of the firing-rate models.
    d['sampling_period'] = blk.segments[0].analogsignals[0].sampling_period

    # keep only epochs that are entirely inside the time windows:
    # for each window t = [lo, hi] build a mask lo <= start & end <= hi, then keep
    # rows where any window's mask is True
    # NOTE(review): with an empty time_windows_to_keep list, np.any(..., axis=0) yields a
    # scalar rather than a row mask, so this indexing would likely fail
    epochs_df = _neo_epoch_to_dataframe(blk.segments[0].epochs)
    epochs_df = epochs_df[np.any(list(map(lambda t: (t[0] <= epochs_df['Start (s)']) & (epochs_df['End (s)'] <= t[1]), d['time_windows_to_keep'])), axis=0)]

    # copy middle times (end of large hump and start of small hump) into 'force' epochs:
    # a 'large hump' epoch lying inside a 'force' epoch (1e-7 s tolerance for float
    # round-off) sets that force epoch's 'Middle (s)'; force epochs without one get NaN
    # there, and the column is absent entirely if no large hump matched at all
    # (handled below via epoch.get('Middle (s)', np.nan))
    for i, epoch in epochs_df[epochs_df['Type'] == 'force'].iterrows():
        for j, subepoch in epochs_df[epochs_df['Type'] == 'large hump'].iterrows():
            if subepoch['Start (s)'] >= epoch['Start (s)']-1e-7 and subepoch['End (s)'] <= epoch['End (s)']+1e-7:
                epochs_df.loc[i, 'Middle (s)'] = subepoch['End (s)']

    # drop all but 'force' rows
    epochs_df = epochs_df[epochs_df['Type'] == 'force']

    # find max forces in each epoch: whole epoch, large hump (start -> middle) and
    # small hump (middle -> end), raw and smoothed; NaN when no middle time is known.
    # builtin max() walks the rows of the 2-D magnitude array and [0] takes the single
    # channel (assumes a one-channel force signal)
    # NOTE(review): .magnitude.max() would avoid the Python-level loop over samples
    for i, epoch in epochs_df.iterrows():
        epochs_df.loc[i,                'max'] = max(         d['force_sig'].time_slice(epoch['Start (s)'] *pq.s, epoch['End (s)']   *pq.s).magnitude)[0]
        epochs_df.loc[i,          'large max'] = max(         d['force_sig'].time_slice(epoch['Start (s)'] *pq.s, epoch['Middle (s)']*pq.s).magnitude)[0] if not np.isnan(epoch.get('Middle (s)', np.nan)) else np.nan
        epochs_df.loc[i,          'small max'] = max(         d['force_sig'].time_slice(epoch['Middle (s)']*pq.s, epoch['End (s)']   *pq.s).magnitude)[0] if not np.isnan(epoch.get('Middle (s)', np.nan)) else np.nan
        epochs_df.loc[i, 'smoothed large max'] = max(d['smoothed_force_sig'].time_slice(epoch['Start (s)'] *pq.s, epoch['Middle (s)']*pq.s).magnitude)[0] if not np.isnan(epoch.get('Middle (s)', np.nan)) else np.nan
        epochs_df.loc[i, 'smoothed small max'] = max(d['smoothed_force_sig'].time_slice(epoch['Middle (s)']*pq.s, epoch['End (s)']   *pq.s).magnitude)[0] if not np.isnan(epoch.get('Middle (s)', np.nan)) else np.nan

    # find rectified area under the curve (RAUC) in each epoch.
    # Force RAUCs (whole epoch and per hump) use elephant's default baseline; the nerve
    # recordings pass baseline='mean' so the signal mean is removed before rectifying.
    for i, epoch in epochs_df.iterrows():
        epochs_df.loc[i,            'force RAUC'] = elephant.signal_processing.rauc(d['force_sig'].time_slice(epoch['Start (s)'] *pq.s, epoch['End (s)']   *pq.s))                   .rescale('mN*s')
        epochs_df.loc[i, 'large hump force RAUC'] = elephant.signal_processing.rauc(d['force_sig'].time_slice(epoch['Start (s)'] *pq.s, epoch['Middle (s)']*pq.s))                   .rescale('mN*s') if not np.isnan(epoch.get('Middle (s)', np.nan)) else np.nan
        epochs_df.loc[i, 'small hump force RAUC'] = elephant.signal_processing.rauc(d['force_sig'].time_slice(epoch['Middle (s)']*pq.s, epoch['End (s)']   *pq.s))                   .rescale('mN*s') if not np.isnan(epoch.get('Middle (s)', np.nan)) else np.nan
        epochs_df.loc[i,               'I2 RAUC'] = elephant.signal_processing.rauc(d['i2_sig']   .time_slice(epoch['Start (s)'] *pq.s, epoch['End (s)']   *pq.s), baseline = 'mean').rescale('uV*s')
        epochs_df.loc[i,               'RN RAUC'] = elephant.signal_processing.rauc(d['rn_sig']   .time_slice(epoch['Start (s)'] *pq.s, epoch['End (s)']   *pq.s), baseline = 'mean').rescale('uV*s')
        epochs_df.loc[i,              'BN2 RAUC'] = elephant.signal_processing.rauc(d['bn2_sig']  .time_slice(epoch['Start (s)'] *pq.s, epoch['End (s)']   *pq.s), baseline = 'mean').rescale('uV*s')
        epochs_df.loc[i,              'BN3 RAUC'] = elephant.signal_processing.rauc(d['bn3_sig']  .time_slice(epoch['Start (s)'] *pq.s, epoch['End (s)']   *pq.s), baseline = 'mean').rescale('uV*s')

    # colors: one value in [0, 1] per epoch, in row order, used to pick each epoch's
    # color from the colormap in the plotting cells
    epochs_df = epochs_df.assign(colormap_arg = np.linspace(0, 1, len(epochs_df)))

    d['epochs_df'] = epochs_df

# Plots

In [None]:
# Colormap used to tint each epoch by its 'colormap_arg' (0 -> 1 in row order).
# Other maps tried: plt.cm.brg, plt.cm.RdBu
cm = plt.cm.cool

# Global seaborn look for every figure below (add context='poster' for slides).
sns.set(style='ticks', font_scale=1, font='Palatino Linotype')

##### Figure 1: Plot forces across real time

In [None]:
# Figure 1 is currently disabled (the whole cell is commented out). When enabled it
# plots each kept epoch's force trace against original chart time, one subplot per
# data set, colored by the epoch's 'colormap_arg'.
# plt.figure(1, figsize=(9,3))
# for i, data_set_name in enumerate(data_sets):
#     d = data[data_set_name]
#     plt.subplot(1, len(data), i+1)
#     plt.title('{}\nt = {}'.format(data_set_name, d['time_windows_to_keep']))
#     plt.ylabel('Force (mN)')
#     plt.xlabel('Original chart time (s)')
#     for j, epoch in d['epochs_df'].iterrows():
#         epoch_force_sig = d['force_sig'].time_slice(epoch['Start (s)']*pq.s, epoch['End (s)']*pq.s)
#         plt.plot(epoch_force_sig.times, epoch_force_sig.as_array(), color=cm(epoch['colormap_arg']))
#     sns.despine(ax=plt.gca(), offset=10, trim=True) # offset axes from plot
# plt.tight_layout()

##### Figure 2: Plot forces, spike trains, and firing rate models

In [None]:
# Figure 2: one column per data set. Rows are force, d(force)/dt, then one raster +
# rate-model row per spike train (row count sized for the data set with the most units).
n_plot_cols = len(data)
n_plot_rows = 2 + max(len(d['spike_trains']) for k,d in data.items())
plt.figure(2, figsize=(9,2*n_plot_rows))

for i, data_set_name in enumerate(data_sets):
    d = data[data_set_name]
    # x range: first kept epoch's start to last kept epoch's end
    t_min = min(d['epochs_df']['Start (s)'])
    t_max = max(d['epochs_df']['End (s)'])

    # === FORCE ===
    ax = plt.subplot(n_plot_rows, n_plot_cols, i+1)
    plt.title('{}\nt = {}'.format(data_set_name, d['time_windows_to_keep']))
    plt.ylabel('Force (mN)')
    plt.xlim(t_min, t_max)
    plt.ylim([-10, 400])
    
#     sns.despine(ax=plt.gca(), offset=10, trim=True, bottom=True)
#     plt.gca().xaxis.set_visible(False)

    # draw each kept epoch in its own color; annotate the whole-epoch force RAUC at the
    # epoch's peak and, when a middle time exists, the RAUC of each hump
    for j, epoch in d['epochs_df'].iterrows():
        epoch_force_sig = d['force_sig'].time_slice(epoch['Start (s)']*pq.s, epoch['End (s)']*pq.s)
        plt.plot(epoch_force_sig.times, epoch_force_sig.as_array(), color=cm(epoch['colormap_arg']))
        plt.text(np.mean([epoch['Start (s)'], epoch['End (s)']]), epoch['max'], '{:.0f}'.format(epoch['force RAUC']), fontsize=8, ha='center')
        if not np.isnan(epoch.get('Middle (s)', np.nan)):
            # NOTE(review): both hump labels are placed at 'smoothed small max'. The
            # 'smoothed large max' column is computed but not used in the cells shown;
            # confirm whether the large-hump label was meant to sit at that height.
            plt.text(np.mean([epoch['Start (s)'],  epoch['Middle (s)']]), epoch['smoothed small max'], '{:.0f}'.format(epoch['large hump force RAUC']), fontsize=8, ha='center')
            plt.text(np.mean([epoch['Middle (s)'], epoch['End (s)']]),    epoch['smoothed small max'], '{:.0f}'.format(epoch['small hump force RAUC']), fontsize=8, ha='center')

    # === D(FORCE)/DT ===
    plt.subplot(n_plot_rows, n_plot_cols, (1)*n_plot_cols+i+1, sharex=ax)
    plt.axhline(0, color='gray', linewidth=0.5)
    dfdt = d['dforce/dt'].time_slice(t_min*pq.s, t_max*pq.s)
    plt.plot(dfdt.times, dfdt.as_array())
    plt.ylabel('d(Force)/dt (mN/s)')
    plt.ylim([-400, 400])
    
#     sns.despine(ax=plt.gca(), offset=10, trim=True, bottom=True)
#     plt.gca().xaxis.set_visible(False)

    # === RASTER PLOTS + RATE MODELS ===
    spike_labels = d['spike_trains'].keys()

    for j, spike_label in enumerate(spike_labels):
        st = d['spike_trains'][spike_label]
        st = st.time_slice(
            t_min*pq.s - 5*pq.s,
            t_max*pq.s + 5*pq.s
        ) # drop spikes outside plot range except for a few sec margin so beginning and final firing rates are accurate

        plt.subplot(n_plot_rows, n_plot_cols, (2+j)*n_plot_cols+i+1, sharex=ax)
        plt.ylabel(spike_label + '\n(rate model)')
        plt.xlim(t_min, t_max)
        plt.ylim([-2, 40])
        
        # only the bottom row gets an x-axis label
        if j == len(spike_labels)-1:
#             sns.despine(ax=plt.gca(), offset=10, trim=True)
            plt.xlabel('Time (s)')
#         else:
#             sns.despine(ax=plt.gca(), offset=10, trim=True, bottom=True)
#             plt.gca().xaxis.set_visible(False)

        # raster plot (spike ticks drawn at y = -1, below the rate curve)
        plt.eventplot(positions=st, lineoffsets=-1, colors='red')

        # spike train convolution: one rate curve per kernel listed below
        kernels = [
#             CausalAlphaKernel(0.03*np.sqrt(2)*pq.s), # match my old poster's synapse model
            CausalAlphaKernel(0.2*pq.s),
#             elephant.kernels.AlphaKernel(0.03*np.sqrt(2)*pq.s),
#             elephant.kernels.AlphaKernel(0.2*pq.s),
#             elephant.kernels.EpanechnikovLikeKernel(0.2*pq.s),
#             elephant.kernels.ExponentialKernel(0.2*pq.s),
#             elephant.kernels.GaussianKernel(0.2*pq.s),
#             elephant.kernels.LaplacianKernel(0.2*pq.s),
#             elephant.kernels.RectangularKernel(0.2*pq.s),
#             elephant.kernels.TriangularKernel(0.2*pq.s)
        ]
        for kernel in kernels:
            rate = elephant.statistics.instantaneous_rate(
                spiketrain=st, sampling_period=d['sampling_period'], kernel=kernel)
            plt.plot(rate.times.rescale('s'), rate)

        # instantaneous firing frequency step plot
#         if st.size > 0:
# #             plt.plot(st[:-1], 1/elephant.statistics.isi(st), drawstyle='steps-post')
#             times = st.times.rescale('s')
#             times = np.concatenate([[t_min], times, [t_max]])*pq.s
#             iff = 1/elephant.statistics.isi(st)
#             iff = np.concatenate([[0], iff.rescale('1/s'), [0, 0]])/pq.s
#             plt.plot(times, iff, drawstyle='steps-post')

        # annotate each epoch with this unit's spike count inside it (text at y = 20)
        for k, epoch in d['epochs_df'].iterrows():
            plt.text(np.mean([epoch['Start (s)'], epoch['End (s)']]), 20, st.time_slice(epoch['Start (s)'], epoch['End (s)']).size, fontsize=8, ha='center')

plt.tight_layout()

##### Model fit (work in progress) — start

In [None]:
# Exploratory rate model of force for one data set:
#   model(t) = baseline + sum over units of weight * (spike train convolved with a
#              CausalAlphaKernel whose time constant is the unit's rate constant, in s)
# The model (with the parameters chosen below) is plotted against the force trace.
data_set_name = 'IN VIVO / JG08 / 2018-06-21 / 002'
# data_set_name = 'IN VIVO / JG08 / 2018-06-24 / 001'
# data_set_name = 'IN VIVO / JG12 / 2019-05-10 / 002'
d = data[data_set_name]

t_min = min(d['epochs_df']['Start (s)'])
t_max = max(d['epochs_df']['End (s)'])

force_sig = d['force_sig'].time_slice(t_min*pq.s, t_max*pq.s)

# need to make sure t_min, t_max are aligned with the sample times
# of force_sig, since they will determine the rate models' times
t_min = force_sig.t_start.rescale('s').magnitude
t_max = force_sig.t_stop.rescale('s').magnitude


# initial guesses, flattened below to [baseline, w_1, r_1, w_2, r_2, ...] in dict order
baseline = 0
spike_rate_models = {
#     'I2':    {'weight': -0.002,    'rate_constant': 1},
#     'B8a/b': {'weight': 0.05,    'rate_constant': 1},
    'B3':    {'weight': 0.05, 'rate_constant': 1},
    'B6/B9': {'weight': 0.05, 'rate_constant': 0.5},
    'B38':   {'weight': 0.025, 'rate_constant': 1},
}
params = [baseline]
for unit, p in spike_rate_models.items():
    params += [p['weight'], p['rate_constant']]
params = np.array(params)

unit_names = list(spike_rate_models.keys())
# unit_names = ['B3', 'B6/B9', 'B38']

# Parameter vectors recorded from earlier optimizer runs (uncomment one to plot it):

# 0 iters, error = 5438.537110230563
# params = np.array([0, 0.05, 1, 0.05, 0.5, 0.025, 1])

# 1 iters, error = 5436.993468396257
# params = np.array([3.80738436e-08, 5.00000362e-02, 1.00000377e+00, 5.00000362e-02, 5.00003788e-01, 2.49999990e-02, 9.99999962e-01])

# 2 iters, error = 5282.6250908482625
# params = np.array([2.46302317e-04, 5.00897420e-02, 9.99644748e-01, 5.12355946e-02, 4.99630339e-01, 2.48409028e-02, 1.00002662e+00])

# 10 iters, error = 2999.8449262352165
# params = np.array([0.18306482, 0.00436199, 0.99686335, 0.04440692, 0.45378227, 0.        , 1.01851402])

# 33 iters (terminated), error = 2802.0291040955985
# params = np.array([0.18100302, 0.05385427, 0.97966784, 0.03796378, 0.39556   , 0.        , 1.01637518])

# unpack the flat vector back into per-unit dicts; pairs are matched to unit_names by
# position, so unit_names must follow the same order as params
baseline, params = params[0], params[1:]
spike_rate_models = {}
for n, (w, r) in zip(unit_names, params.reshape(-1, 2)):
    spike_rate_models[n] = {'weight': w, 'rate_constant': r}

# # 1 iter without mean shift
# spike_rate_models = {
#     'B3':    {'weight': 0.39999997, 'rate_constant': 0.99999993},
#     'B6/B9': {'weight': 0.40000004, 'rate_constant': 0.50000684},
#     'B38':   {'weight': 0.19999999, 'rate_constant': 0.99999993},
# }

# # 3 iters without mean shift
# spike_rate_models = {
#     'B3':    {'weight': 0.39905256, 'rate_constant': 1.23447876},
#     'B6/B9': {'weight': 0.40142116, 'rate_constant': 0.49882414},
#     'B38':   {'weight': 0.19952626, 'rate_constant': 0.99763389},
# }

# # 5 iters without mean shift
# spike_rate_models = {
#     'B3':    {'weight': 0.39850763, 'rate_constant': 1.4001333},
#     'B6/B9': {'weight': 0.40437206, 'rate_constant': 0.43688527},
#     'B38':   {'weight': 0.19504436, 'rate_constant': 1.00480115},
# }

# 7 iters without mean shift
# spike_rate_models = {
#     'B3':    {'weight': 0.39850765, 'rate_constant': 1.40015991},
#     'B6/B9': {'weight': 0.40437244, 'rate_constant': 0.43688048},
#     'B38':   {'weight': 0.19504362, 'rate_constant': 1.00480165},
# }

# 10 iters without mean shift
# spike_rate_models = {
#     'B3':    {'weight': 0.44591693, 'rate_constant': 7.14593357},
#     'B6/B9': {'weight': 0.33406253, 'rate_constant': 0.46480023},
#     'B38':   {'weight': 0.24457259, 'rate_constant': 6.77407843},
# }

# attach each unit's cropped spike train and its convolved rate to its dict.
# NOTE(review): this loop rebinds the global `params` (the parameter array above) to each
# unit's dict; later cells reassign `params` before using it as an array again.
for name, params in spike_rate_models.items():
    
    # get the spike train
    st = d['spike_trains'][name]
    st = st.time_slice(
        t_min*pq.s - 5*pq.s,
        t_max*pq.s,
    ) # drop spikes outside plot range except for a few sec before so beginning firing rate is accurate
    params['spike_train'] = st
    
    # convolve the spike train with the kernel
    params['rate'] = elephant.statistics.instantaneous_rate(
        spiketrain=st,
        sampling_period=d['sampling_period'],
#         t_start=st.t_start, # default
        kernel=CausalAlphaKernel(params['rate_constant']*pq.s),
    )


force_ylim = [-10, 500]
rate_ylim = [-2, 2]

# rows: force, one per unit, then the combined model
n_plot_rows = len(spike_rate_models) + 2
plt.figure(99, figsize=(9,2*n_plot_rows))

# === FORCE ===
ax = plt.subplot(n_plot_rows, 1, 1)
plt.title('{}\nt = {}'.format(data_set_name, d['time_windows_to_keep']))
plt.ylabel('Force (mN)')
plt.xlim(t_min, t_max)
plt.ylim(force_ylim)

plt.plot(force_sig.times, force_sig.as_array(), color='C1')

# === RASTER PLOTS + RATE MODELS ===
for j, (name, params) in enumerate(spike_rate_models.items()):

    plt.subplot(n_plot_rows, 1, j+2, sharex=ax)
    plt.ylabel(f"{name}\nweight: {np.round(params['weight'],3)}\nrate const: {np.round(params['rate_constant'],3)} sec")
    plt.xlim(t_min, t_max)
    plt.ylim(rate_ylim)

    # raster plot
    plt.eventplot(positions=params['spike_train'], lineoffsets=-1, colors='red')

    # spike train convolution, scaled by the unit's weight
    plt.plot(params['rate'].times.rescale('s'), params['rate'] * params['weight'])

    # spike count of this unit in each epoch, written at y = 1
    for k, epoch in d['epochs_df'].iterrows():
        label = params['spike_train'].time_slice(epoch['Start (s)'], epoch['End (s)']).size
        plt.text(np.mean([epoch['Start (s)'], epoch['End (s)']]), 1, label, fontsize=8, ha='center')

# === COMBINED MODEL ===

plt.subplot(n_plot_rows, 1, len(spike_rate_models)+2, sharex=ax)
plt.ylabel(f'Sum of rates\nvs norm force\nbaseline: {np.round(baseline,3)}')
plt.xlabel('Time (s)')
plt.xlim(t_min, t_max)
plt.ylim([-0.1, 1.1])

# assert all times are the same so that simple addition of rates can be done
for j, (name, params) in enumerate(spike_rate_models.items()):
    assert np.all(params['rate'].times == spike_rate_models[list(spike_rate_models.keys())[0]]['rate'].times)

# weighted sum of the unit rates
rate_sum = None
for name, params in spike_rate_models.items():
    if rate_sum is None:
        rate_sum = params['rate'] * params['weight']
    else:
        rate_sum += params['rate'] * params['weight']

plt.plot(rate_sum.times.rescale('s'), rate_sum.as_array()+baseline)#/rate_sum.max())
# plt.plot(rate_sum.times.rescale('s'), rate_sum.as_array()/rate_sum.max().magnitude+baseline)

# force normalized to its peak over [t_min, t_max], for comparison with the model
force_sig = d['force_sig'].time_slice(t_min*pq.s, t_max*pq.s)
plt.plot(force_sig.times, force_sig.as_array()/force_sig.max(), color='C1')

plt.tight_layout()

In [None]:
# Debug helper (disabled): prints every spike time of each modeled unit as
# CSV-style rows of the form "unit",time (time rounded to 4 decimals).
# for key in spike_rate_models:
#     for t in spike_rate_models[key]['spike_train'].times:
#         print(f'"{key}",{np.round(t.magnitude,4)}')

In [None]:
def get_time_range(data_set_name):
    """Return the (t_min, t_max) analysis window of a data set, in seconds.

    The window runs from the first kept epoch's start to the last kept epoch's
    end, then is snapped onto the force signal's sample grid: the returned values
    are the t_start / t_stop magnitudes of the force signal sliced to that window.
    Rate models built over this window therefore share the force signal's times.
    """
    dataset = data[data_set_name]
    epochs = dataset['epochs_df']
    first_start = min(epochs['Start (s)'])
    last_end = max(epochs['End (s)'])

    aligned = dataset['force_sig'].time_slice(first_start*pq.s, last_end*pq.s)
    return (
        aligned.t_start.rescale('s').magnitude,
        aligned.t_stop.rescale('s').magnitude,
    )


def get_spike_trains(data_set_name, t_min, t_max, unit_names, margin_s=5):
    """Return the requested spike trains of one data set, cropped around a window.

    Parameters
    ----------
    data_set_name : str
        Key into the global ``data`` dict.
    t_min, t_max : float
        Analysis window in seconds (unitless magnitudes, e.g. from ``get_time_range``).
    unit_names : iterable of str
        Keys of ``data[data_set_name]['spike_trains']``, in the order wanted.
    margin_s : float, optional
        Seconds of extra spikes kept on each side of the window (default 5), so
        firing rates near the window edges include the effect of nearby spikes.

    Returns
    -------
    list
        ``time_slice`` of each requested spike train over
        ``[t_min - margin_s, t_max + margin_s]``, in ``unit_names`` order.
    """
    # Look the data set up by its name. The previous version read the global ``d``
    # instead, so its result depended on whichever data set a prior cell had bound
    # to ``d`` rather than on ``data_set_name``.
    trains_by_name = data[data_set_name]['spike_trains']

    spike_trains = []
    for name in unit_names:
        st = trains_by_name[name].time_slice(
            t_min*pq.s - margin_s*pq.s,
            t_max*pq.s + margin_s*pq.s,
        )
        spike_trains.append(st)

    return spike_trains

def get_model_time_series(t_min, t_max, sampling_period, spike_trains, params):
    """Compute the weighted-sum firing-rate model over ``[t_min, t_max]``.

    Each spike train is convolved with a ``CausalAlphaKernel`` whose time constant
    is that unit's rate constant (seconds), scaled by the unit's weight, and the
    results are summed; ``baseline`` is added at the end.

    Parameters
    ----------
    t_min, t_max : float
        Output window in seconds (unitless magnitudes).
    sampling_period : quantity
        Passed to ``elephant.statistics.instantaneous_rate``.
    spike_trains : sequence of spike trains
        One per unit, in the same order as the (weight, rate constant) pairs.
    params : array-like of float
        Flat vector ``[baseline, w_1, tau_1, w_2, tau_2, ...]``: the baseline is
        the FIRST element, followed by one (weight, rate constant) pair per
        spike train.

    Returns
    -------
    numpy.ndarray
        Model samples within ``[t_min, t_max]`` plus ``baseline`` (a plain array,
        without a time axis).

    Raises
    ------
    ValueError
        If ``spike_trains`` is empty, or ``params`` does not hold exactly one
        baseline plus one (weight, rate constant) pair per spike train.
    """
    params = np.asarray(params, dtype=float)
    n_units = len(spike_trains)
    if n_units == 0:
        raise ValueError('at least one spike train is required to build the model')
    expected = 1 + 2*n_units
    if params.ndim != 1 or params.size != expected:
        # zip() below would otherwise silently drop or ignore mismatched entries
        raise ValueError(
            f'expected {expected} parameters (baseline plus a weight and rate '
            f'constant per spike train), got shape {params.shape}')

    baseline, pairs = params[0], params[1:].reshape(-1, 2)

    rate_sum = None
    for st, (weight, rate_constant) in zip(spike_trains, pairs):
        rate_model = elephant.statistics.instantaneous_rate(
            spiketrain=st,
            sampling_period=sampling_period,
            t_start=st.t_start, # default
            kernel=CausalAlphaKernel(rate_constant*pq.s),
        )
        if rate_sum is None:
            rate_sum = rate_model * weight
        else:
            rate_sum += rate_model * weight

    # crop the margins the spike trains carry (see get_spike_trains) off the output
    rate_sum = rate_sum.time_slice(t_min*pq.s, t_max*pq.s)
    return rate_sum.as_array() + baseline

def get_sum_of_square_residuals(data_set_name, rate_sum):
    """Sum of squared differences between a model and the peak-normalized force.

    The force signal of ``data_set_name`` is sliced to the window from
    ``get_time_range``, converted to a plain array and divided by its maximum;
    ``rate_sum`` must be an array of matching shape (e.g. from
    ``get_model_time_series`` over the same window). Returns a Python float.
    """
    window_start, window_stop = get_time_range(data_set_name)
    force = (
        data[data_set_name]['force_sig']
        .time_slice(window_start*pq.s, window_stop*pq.s)
        .as_array()
    )
    normalized_force = force / force.max()
    residuals = rate_sum - normalized_force
    return float((residuals**2).sum())

def do_it_efficiently(params, t_min, t_max, sampling_period, spike_trains, force_sig_normalized):
    """Least-squares objective for fitting the rate model to normalized force.

    ``params`` comes first so ``scipy.optimize.minimize`` can vary it while the
    remaining arguments are passed through ``args``.

    Parameters
    ----------
    params : array-like of float
        ``[baseline, w_1, tau_1, w_2, tau_2, ...]``: the baseline is the FIRST
        element, then one (weight, rate constant) pair per spike train.
    t_min, t_max : float
        Window in seconds over which the model is compared.
    sampling_period : quantity
        Sampling period of the rate models.
    spike_trains : sequence of spike trains
        One per (weight, rate constant) pair.
    force_sig_normalized : numpy.ndarray
        Force over the same window, already divided by its maximum.

    Returns
    -------
    float
        Sum of squared residuals between the model and ``force_sig_normalized``.
    """
    # Delegate to get_model_time_series rather than keeping a second copy of the
    # model-building loop, so the fitted objective cannot drift from the model that
    # is evaluated and plotted elsewhere.
    rate_sum = get_model_time_series(t_min, t_max, sampling_period, spike_trains, params)
    return float(((rate_sum - force_sig_normalized)**2).sum())

In [None]:
# Evaluate the model's sum-of-squares error for one fixed parameter vector.
data_set_name = 'IN VIVO / JG08 / 2018-06-21 / 002'
unit_names = ['B3', 'B6/B9', 'B38']
# [baseline, then (weight, rate constant in s) for each unit in unit_names order]
params = np.array([0, 0.05, 1, 0.05, 0.5, 0.025, 1])

# `d` is bound before get_spike_trains() is called in case that helper reads the global
# `d` rather than data[data_set_name]
d = data[data_set_name]
t_min, t_max = get_time_range(data_set_name)
spike_trains = get_spike_trains(data_set_name, t_min, t_max, unit_names)
rate_sum = get_model_time_series(t_min, t_max, d['sampling_period'], spike_trains, params)
error = get_sum_of_square_residuals(data_set_name, rate_sum)

In [None]:
# sum of squared residuals from the previous cell
error

In [None]:
# Same error as above, computed through the single-call objective do_it_efficiently();
# the force is normalized to its peak here, exactly as get_sum_of_square_residuals() does.
data_set_name = 'IN VIVO / JG08 / 2018-06-21 / 002'
unit_names = ['B3', 'B6/B9', 'B38']
params = np.array([0, 0.05, 1, 0.05, 0.5, 0.025, 1])

# `d` is bound before get_spike_trains() in case that helper reads the global `d`
d = data[data_set_name]
t_min, t_max = get_time_range(data_set_name)
force_sig = d['force_sig'].time_slice(t_min*pq.s, t_max*pq.s)
spike_trains = get_spike_trains(data_set_name, t_min, t_max, unit_names)

# # THIS IS DIFFERENT FROM WHAT I PLOTTED ABOVE WITH TRUE ZERO
# force_sig = (force_sig-force_sig.mean())/force_sig.std()

force_sig = force_sig.as_array()
force_sig = force_sig/force_sig.max()

error = do_it_efficiently(params, t_min, t_max, d['sampling_period'], spike_trains, force_sig)

In [None]:
# objective value from the previous cell (should match the error computed earlier)
error

In [None]:
# Fit the rate-model parameters to the peak-normalized force by bounded minimization
# of the sum of squared residuals (do_it_efficiently).

# Import the submodule explicitly: on older SciPy releases a bare `import scipy` does
# not load `scipy.optimize`. This also binds the name `scipy`.
import scipy.optimize

data_set_name = 'IN VIVO / JG08 / 2018-06-21 / 002'
unit_names = ['B3', 'B6/B9', 'B38']

# initial guess: [baseline, then (weight, rate constant in s) per unit in unit_names order]
# params = np.array([0.4, 1, 0.4, 0.5, 0.2, 1])
params = np.array([0, 0.05, 1, 0.05, 0.5, 0.025, 1])

# Bounds built per unit so they stay aligned with unit_names if units are added/removed:
# baseline in [0, 1]; each weight in [0, 1]; each rate constant in [1e-3, 100] s.
lb = np.array([0] + [0, 1e-3]*len(unit_names))
ub = np.array([1] + [1, 100]*len(unit_names))
n_params = 1 + 2*len(unit_names)
assert len(params) == n_params, f'expected {n_params} initial parameters for {len(unit_names)} units, got {len(params)}'

# `d` is bound before get_spike_trains() in case that helper reads the global `d`
d = data[data_set_name]
t_min, t_max = get_time_range(data_set_name)
force_sig = d['force_sig'].time_slice(t_min*pq.s, t_max*pq.s)
spike_trains = get_spike_trains(data_set_name, t_min, t_max, unit_names)

# # THIS IS DIFFERENT FROM WHAT I PLOTTED ABOVE WITH TRUE ZERO
# force_sig = (force_sig-force_sig.mean())/force_sig.std()

force_sig = force_sig.as_array()
force_sig = force_sig/force_sig.max()

result = scipy.optimize.minimize(
    do_it_efficiently,
    params,
    (t_min, t_max, d['sampling_period'], spike_trains, force_sig),
    bounds=scipy.optimize.Bounds(lb, ub),
    options={'maxiter': 100},
)

In [None]:
# full optimizer result; the fitted parameter vector is result.x
result

In [None]:
# NOTE(review): this and the following print cells are identical, so a fresh run prints
# the same [initial params, fitted params] pair in each. The trailing comments record
# which fit each cell was originally executed after (historical, not reproduced by Run All).
print(np.array([params, result.x])) # 100 iters with baseline, err: 

In [None]:
print(np.array([params, result.x])) # 10 iters with baseline, err: 

In [None]:
print(np.array([params, result.x])) # 2 iters with baseline, err: 

In [None]:
print(np.array([params, result.x])) # 1 iter with baseline, err: 

In [None]:
print(np.array([params, result.x])) # 10 iters without mean shift, err: 3572

In [None]:
print(np.array([params, result.x])) # 7 iters without mean shift, err: 5385

In [None]:
print(np.array([params, result.x])) # 5 iters without mean shift, err: 5385

In [None]:
print(np.array([params, result.x])) # 3 iters without mean shift, err: 5430

In [None]:
print(np.array([params, result.x])) # 1 iter without mean shift

In [None]:
print(np.array([params, result.x])) # 10 iters with mean shift

In [None]:
print(np.array([params, result.x])) # 5 iters with mean shift

In [None]:
# NOTE(review): duplicate of the `result` cells above
result

In [None]:
result

In [None]:
# Plot the most recently computed model time series.
# Under Restart & Run All, `rate_sum` was last assigned by get_model_time_series(),
# which returns a plain ndarray with no `.times`. In that case rebuild the time axis
# from the force signal over the same sample-aligned window [t_min, t_max]; that
# window is what the model was cropped to, so the lengths match. If `rate_sum` is
# still the AnalogSignal from the exploratory model cell, use its own times.
if hasattr(rate_sum, 'times'):
    rate_sum_times = rate_sum.times.rescale('s')
else:
    rate_sum_times = d['force_sig'].time_slice(t_min*pq.s, t_max*pq.s).times.rescale('s')

plt.figure()
plt.plot(rate_sum_times, rate_sum)
plt.xlabel('Time (s)')
plt.ylabel('Model (sum of weighted rates + baseline)')

##### Model fit (work in progress) — end

##### Figure 3: Plot number of spikes vs force RAUC

In [None]:
# Figure 3: per data set, regress a force RAUC against each unit's spike count per epoch.
plt.figure(3, figsize=(9,6))
for i, data_set_name in enumerate(data_sets):
    d = data[data_set_name]
    plt.subplot(1, len(data), i+1)
    
    # which force integral to use as the dependent variable
#     rauc_label = 'small hump force RAUC'
    rauc_label = 'large hump force RAUC'
#     rauc_label = 'force RAUC'
    y = d['epochs_df'][rauc_label]

    spike_labels = d['spike_trains'].keys()
#     spike_labels = [
#     #     'I2',
#         'B8a/b',
#         'B3 (50-100 uV)',
#     #     '? (45-50 uV)',
#         'B6/B9 ? (26-45 uV)',
#         'B38 ? (17-26 uV)',
#     #     '? (15-17 uV)',
#         'B4/B5',
#     ]

    for j, spike_label in enumerate(spike_labels):
        # number of this unit's spikes inside each kept epoch
        x = []
        for k, epoch in d['epochs_df'].iterrows():
            st = d['spike_trains'][spike_label].time_slice(epoch['Start (s)'], epoch['End (s)'])
            x.append(st.size)
        
        # - has_constant='add' keeps an intercept column even when a unit fires the same
        #   number of spikes in every epoch, so the slope term below always exists
        #   (with the default 'skip' the design would have one column and [1] would fail).
        # - missing='drop' ignores epochs whose hump RAUC is NaN (no 'large hump' found),
        #   matching the points regplot draws; otherwise every statistic becomes NaN.
        model = sm.OLS(y, sm.add_constant(x, has_constant='add'), missing='drop').fit()
        # y is a Series, so pvalues may be a Series labelled by column name; take the
        # slope term ([1]) positionally instead of via pandas' deprecated int fallback
        slope_pvalue = np.asarray(model.pvalues)[1]
        legend_label = '{}, R$^2$ = {:.2f}, p = {:.3f}'.format(spike_label, model.rsquared, slope_pvalue)
        
        # Attach the label to this unit's regplot directly. Passing a list of labels to
        # plt.legend() pairs them with the axes' artists in order, and each regplot adds
        # both a scatter and a line, so on newer matplotlib the second label would land
        # on the first unit's regression line.
        sns.regplot(x=x, y=y, ci=None, truncate=True, label=legend_label)

    plt.title('{}\nt = {}'.format(data_set_name, d['time_windows_to_keep']))
    plt.xlabel('Number of spikes in swallow motor pattern')
#     plt.ylabel('Integrated force (mN·s)')
    plt.ylabel(rauc_label + ' (mN·s)')
    plt.legend(fontsize = 9)
    sns.despine(ax=plt.gca(), offset=10, trim=True) # offset axes from plot
plt.tight_layout()

##### Figure 4: Plot BN2 RAUC vs force RAUC

In [None]:
# Figure 4: per data set, linear regression of force RAUC on BN2 RAUC across kept epochs.
plt.figure(4, figsize=(9,6))
for i, data_set_name in enumerate(data_sets):
    d = data[data_set_name]
    plt.subplot(1, len(data), i+1)
    
    # one point per kept 'force' epoch
    x = d['epochs_df']['BN2 RAUC']
    y = d['epochs_df']['force RAUC']
    
    plt.xlim([0, 25])
    plt.ylim([0, 1200])
    
    model = sm.OLS(y, sm.add_constant(x)).fit()
    # pvalues is a Series labelled by column name here; take the slope term ([1])
    # positionally instead of relying on pandas' deprecated Series[int] fallback
    slope_pvalue = np.asarray(model.pvalues)[1]
    legend_label = 'R$^2$ = {:.2f}, p = {:.3f}, n = {}'.format(model.rsquared, slope_pvalue, len(x))
    
    # ci=None draws no confidence band; ci=False is not None, so seaborn would still
    # bootstrap the fit and draw a zero-width band. The label is attached explicitly.
    sns.regplot(x=x, y=y, ci=None, label=legend_label)
    
    plt.title('{}\nt = {}'.format(data_set_name, d['time_windows_to_keep']))
    plt.xlabel('BN2 RAUC (integrated rectified voltage on buccal nerve 2) (μV·s)')
    plt.ylabel('Force RAUC (integrated force) (mN·s)')
    plt.legend()  # pass fontsize=9 for a smaller legend
    sns.despine(ax=plt.gca(), offset=10, trim=True) # offset axes from plot
plt.tight_layout()

##### Figure 5: Plot BN3 RAUC vs force RAUC

In [None]:
# Figure 5: per data set, linear regression of force RAUC on BN3 RAUC across kept epochs.
plt.figure(5, figsize=(9,6))
for i, data_set_name in enumerate(data_sets):
    d = data[data_set_name]
    plt.subplot(1, len(data), i+1)
    
    # one point per kept 'force' epoch
    x = d['epochs_df']['BN3 RAUC']
    y = d['epochs_df']['force RAUC']
    
    plt.xlim([0, 50])
    plt.ylim([0, 1200])
    
    model = sm.OLS(y, sm.add_constant(x)).fit()
    # pvalues is a Series labelled by column name here; take the slope term ([1])
    # positionally instead of relying on pandas' deprecated Series[int] fallback
    slope_pvalue = np.asarray(model.pvalues)[1]
    legend_label = 'R$^2$ = {:.2f}, p = {:.3f}, n = {}'.format(model.rsquared, slope_pvalue, len(x))
    
    # ci=None draws no confidence band; ci=False is not None, so seaborn would still
    # bootstrap the fit and draw a zero-width band. The label is attached explicitly.
    sns.regplot(x=x, y=y, ci=None, label=legend_label)
    
    plt.title('{}\nt = {}'.format(data_set_name, d['time_windows_to_keep']))
    plt.xlabel('BN3 RAUC (integrated rectified voltage on buccal nerve 3) (μV·s)')
    plt.ylabel('Force RAUC (integrated force) (mN·s)')
    plt.legend()  # pass fontsize=9 for a smaller legend
    sns.despine(ax=plt.gca(), offset=10, trim=True) # offset axes from plot
plt.tight_layout()

##### Figure 6: Plot RN RAUC vs force RAUC

In [None]:
# Figure 6: per data set, linear regression of force RAUC on RN RAUC across kept epochs.
plt.figure(6, figsize=(9,6))
for i, data_set_name in enumerate(data_sets):
    d = data[data_set_name]
    plt.subplot(1, len(data), i+1)
    
    # one point per kept 'force' epoch
    x = d['epochs_df']['RN RAUC']
    y = d['epochs_df']['force RAUC']
    
    plt.xlim([0, 25])
    plt.ylim([0, 1200])
    
    model = sm.OLS(y, sm.add_constant(x)).fit()
    # pvalues is a Series labelled by column name here; take the slope term ([1])
    # positionally instead of relying on pandas' deprecated Series[int] fallback
    slope_pvalue = np.asarray(model.pvalues)[1]
    legend_label = 'R$^2$ = {:.2f}, p = {:.3f}, n = {}'.format(model.rsquared, slope_pvalue, len(x))
    
    # ci=None draws no confidence band; ci=False is not None, so seaborn would still
    # bootstrap the fit and draw a zero-width band. The label is attached explicitly.
    sns.regplot(x=x, y=y, ci=None, label=legend_label)
    
    plt.title('{}\nt = {}'.format(data_set_name, d['time_windows_to_keep']))
    plt.xlabel('RN RAUC (integrated rectified voltage on radular nerve) (μV·s)')
    plt.ylabel('Force RAUC (integrated force) (mN·s)')
    plt.legend()  # pass fontsize=9 for a smaller legend
    sns.despine(ax=plt.gca(), offset=10, trim=True) # offset axes from plot
plt.tight_layout()

##### Figure 7: Plot BN2 RAUC vs RN RAUC

In [None]:
# Figure 7: per data set, linear regression of RN RAUC on BN2 RAUC across kept epochs.
plt.figure(7, figsize=(9,6))
for i, data_set_name in enumerate(data_sets):
    d = data[data_set_name]
    plt.subplot(1, len(data), i+1)
    
    # one point per kept 'force' epoch
    x = d['epochs_df']['BN2 RAUC']
    y = d['epochs_df']['RN RAUC']
    
    plt.xlim([0, 25])
    plt.ylim([0, 25])
    
    model = sm.OLS(y, sm.add_constant(x)).fit()
    # pvalues is a Series labelled by column name here; take the slope term ([1])
    # positionally instead of relying on pandas' deprecated Series[int] fallback
    slope_pvalue = np.asarray(model.pvalues)[1]
    legend_label = 'R$^2$ = {:.2f}, p = {:.3f}, n = {}'.format(model.rsquared, slope_pvalue, len(x))
    
    # ci=None draws no confidence band; ci=False is not None, so seaborn would still
    # bootstrap the fit and draw a zero-width band. The label is attached explicitly.
    sns.regplot(x=x, y=y, ci=None, label=legend_label)
    
    plt.title('{}\nt = {}'.format(data_set_name, d['time_windows_to_keep']))
    plt.xlabel('BN2 RAUC (integrated rectified voltage on buccal nerve 2) (μV·s)')
    plt.ylabel('RN RAUC (integrated rectified voltage on radular nerve) (μV·s)')
    plt.legend()  # pass fontsize=9 for a smaller legend
    sns.despine(ax=plt.gca(), offset=10, trim=True) # offset axes from plot
plt.tight_layout()

##### Figure 8: Plot BN2 RAUC vs BN3 RAUC

In [None]:
# Figure 8: per data set, linear regression of BN3 RAUC on BN2 RAUC across kept epochs.
plt.figure(8, figsize=(9,6))
for i, data_set_name in enumerate(data_sets):
    d = data[data_set_name]
    plt.subplot(1, len(data), i+1)
    
    # one point per kept 'force' epoch
    x = d['epochs_df']['BN2 RAUC']
    y = d['epochs_df']['BN3 RAUC']
    
    plt.xlim([0, 25])
    plt.ylim([0, 50])
    
    model = sm.OLS(y, sm.add_constant(x)).fit()
    # pvalues is a Series labelled by column name here; take the slope term ([1])
    # positionally instead of relying on pandas' deprecated Series[int] fallback
    slope_pvalue = np.asarray(model.pvalues)[1]
    legend_label = 'R$^2$ = {:.2f}, p = {:.3f}, n = {}'.format(model.rsquared, slope_pvalue, len(x))
    
    # ci=None draws no confidence band; ci=False is not None, so seaborn would still
    # bootstrap the fit and draw a zero-width band. The label is attached explicitly.
    sns.regplot(x=x, y=y, ci=None, label=legend_label)
    
    plt.title('{}\nt = {}'.format(data_set_name, d['time_windows_to_keep']))
    plt.xlabel('BN2 RAUC (integrated rectified voltage on buccal nerve 2) (μV·s)')
    plt.ylabel('BN3 RAUC (integrated rectified voltage on buccal nerve 3) (μV·s)')
    plt.legend()  # pass fontsize=9 for a smaller legend
    sns.despine(ax=plt.gca(), offset=10, trim=True) # offset axes from plot
plt.tight_layout()

##### Figure 9: Plot RN RAUC vs BN3 RAUC

In [None]:
# Figure 9: per data set, linear regression of BN3 RAUC on RN RAUC across kept epochs.
plt.figure(9, figsize=(9,6))
for i, data_set_name in enumerate(data_sets):
    d = data[data_set_name]
    plt.subplot(1, len(data), i+1)
    
    # one point per kept 'force' epoch
    x = d['epochs_df']['RN RAUC']
    y = d['epochs_df']['BN3 RAUC']
    
    plt.xlim([0, 25])
    plt.ylim([0, 50])
    
    model = sm.OLS(y, sm.add_constant(x)).fit()
    # pvalues is a Series labelled by column name here; take the slope term ([1])
    # positionally instead of relying on pandas' deprecated Series[int] fallback
    slope_pvalue = np.asarray(model.pvalues)[1]
    legend_label = 'R$^2$ = {:.2f}, p = {:.3f}, n = {}'.format(model.rsquared, slope_pvalue, len(x))
    
    # ci=None draws no confidence band; ci=False is not None, so seaborn would still
    # bootstrap the fit and draw a zero-width band. The label is attached explicitly.
    sns.regplot(x=x, y=y, ci=None, label=legend_label)
    
    plt.title('{}\nt = {}'.format(data_set_name, d['time_windows_to_keep']))
    plt.xlabel('RN RAUC (integrated rectified voltage on radular nerve) (μV·s)')
    plt.ylabel('BN3 RAUC (integrated rectified voltage on buccal nerve 3) (μV·s)')
    plt.legend()  # pass fontsize=9 for a smaller legend
    sns.despine(ax=plt.gca(), offset=10, trim=True) # offset axes from plot
plt.tight_layout()