# Import Packages

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import quantities as pq
import neo
from neurotic._elephant_tools import CausalAlphaKernel, instantaneous_rate

# Print unit symbols using Unicode characters (e.g. µ instead of "u" for micro).
pq.markup.config.use_unicode = True

# Millinewton unit (1 mN = 1 N / 1000), so signals can be rescaled with 'mN'.
pq.mN = pq.UnitQuantity('millinewton', pq.N/1e3, symbol='mN')

In [None]:
# Select how matplotlib figures are displayed: keep exactly one of the three
# magics below uncommented.

# make figures interactive and open in a separate window
# %matplotlib qt

# make figures interactive and inline
# NOTE(review): the 'notebook' (nbagg) backend is reported not to work in
# JupyterLab or Notebook 7; there, `%matplotlib widget` (requires the ipympl
# package) is the usual replacement — confirm for the environment in use.
%matplotlib notebook

# make figures non-interactive and inline
# %matplotlib inline

In [None]:
# Plot colour for every unit (I2 muscle and motor neurons), the measured force,
# and the model. Hex strings are RGB colours; the bare numeric strings are
# matplotlib grayscale levels.
colors = dict([
    ('B38',      '#EFBF46'),  # yellow
    ('I2',       '#DC5151'),  # red
    ('B8a/b',    '#DA8BC3'),  # pink
    ('B6/B9',    '#64B5CD'),  # light blue
    ('B3/B6/B9', '#5A9BC5'),  # medium blue
    ('B3',       '#4F80BD'),  # dark blue
    ('B4/B5',    '#00A86B'),  # jade green
    ('Force',    '0.7'),      # light gray
    ('Model',    '0.2'),      # dark gray
])

# Load Data

In [None]:
# Location of the data file to analyse (relative to the notebook).
directory = 'spikes-firing-rates-and-forces'

# Available recordings — uncomment exactly one:
# filename = 'JG07 Tape nori 0.mat'
# filename = 'JG08 Tape nori 0.mat'
filename = 'JG08 Tape nori 1.mat'
# filename = 'JG08 Tape nori 1 superset.mat'  # this file is missing spikes for several swallows
# filename = 'JG08 Tape nori 2.mat'
# filename = 'JG11 Tape nori 0.mat'
# filename = 'JG12 Tape nori 0.mat'
# filename = 'JG12 Tape nori 1.mat'
# filename = 'JG14 Tape nori 0.mat'

# File name without directory or extension, used as the figure title.
# os.path.splitext strips only the final extension and, unlike splitting on
# '.', leaves names without an extension intact instead of returning ''.
file_basename = os.path.splitext(os.path.basename(filename))[0]

In [None]:
# Read the MATLAB data file with neo and index its contents by name.
# Only the first segment is used (presumably the file holds a single one —
# TODO confirm); it carries the analog signals (e.g. 'Force') and the spike
# trains of the recorded units.
reader = neo.io.NeoMatlabIO(os.path.join(directory, filename))
blk = reader.read_block()
seg = blk.segments[0]
sigs = dict((signal.name, signal) for signal in seg.analogsignals)
spiketrains = dict((train.name, train) for train in seg.spiketrains)

# Plot Empirical Force

In [None]:
# Plot the swallowing force recorded by the force transducer, in millinewtons
# against time in seconds, titled with the data file's name.
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(
    sigs['Force'].times.rescale('s'),
    sigs['Force'].rescale('mN'),
    c=colors['Force'],
)
ax.set(xlabel='Time (s)', ylabel='Force (mN)', title=file_basename)
plt.tight_layout()

# Model Parameters

In [None]:
# Model definition:
#   model force = sum over included units of (weight * firing rate) + offset
#
# How to adjust it:
# - comment/uncomment an entry of firing_rate_params to exclude/include that
#   unit (the I2 muscle or one of the motor neurons)
# - weights may be positive or negative
# - rate_constant (seconds) sets how quickly a unit's effect builds and decays
# - further down, the model is plotted against the empirical force, with both
#   normalized by their peak values

offset = 0

# Alternative parameter set, kept for reference:
# firing_rate_params = {
#     # 'I2':    dict(weight=-0.002, rate_constant=1),
#     # 'B8a/b': dict(weight=0.05,   rate_constant=1),
#     'B3':    dict(weight=0.05,  rate_constant=1),
#     'B6/B9': dict(weight=0.05,  rate_constant=0.5),
#     'B38':   dict(weight=0.025, rate_constant=1),
#     # 'B4/B5': dict(weight=0.05,   rate_constant=1),
# }

firing_rate_params = {
    # 'I2':    dict(weight=-0.02, rate_constant=1),
    # 'B8a/b': dict(weight=0.05,  rate_constant=1),
    'B3':    dict(weight=0.05, rate_constant=1),
    'B6/B9': dict(weight=0.1,  rate_constant=0.5),
    'B38':   dict(weight=0.05, rate_constant=1),
    # 'B4/B5': dict(weight=0.05,  rate_constant=1),
}

# Generate Firing Rate Model

In [None]:
# Build the firing-rate model of the force:
#   model = sum over selected units of (weight * firing rate) + offset
# Each unit's spike train is convolved with a causal alpha kernel to obtain a
# firing rate (later plotted in Hz). `rate_constant` (seconds) is passed to
# CausalAlphaKernel as its time parameter, so it sets how quickly the rate
# builds and decays around each spike.

sampling_period = 0.0002*pq.s  # 5 kHz, same as data acquisition rate

# The model is a sum, so at least one unit is required; fail with a clear
# message rather than an opaque TypeError from adding to None.
if not firing_rate_params:
    raise ValueError(
        'firing_rate_params is empty: uncomment at least one unit to build the model'
    )

firing_rates = {}
for name, params in firing_rate_params.items():
    weight = params['weight']
    rate_constant = params['rate_constant']

    # convolve the spike train with the kernel, then scale by the unit's weight
    rate = instantaneous_rate(
        spiketrain=spiketrains[name],
        sampling_period=sampling_period,
        kernel=CausalAlphaKernel(rate_constant*pq.s),
    )
    rate *= weight
    # multi-line label, used as the y-axis label when plotting
    rate.name = f'{name}\nweight: {weight}\nrate const: {rate_constant} sec'
    firing_rates[name] = rate

# Sum the scaled rates and add the offset. The first rate is copied so the
# in-place additions below do not alter that unit's own trace. 'Model' is
# inserted last so it is plotted after the individual units.
unit_names = list(firing_rate_params)
model_rate = firing_rates[unit_names[0]].copy()
for name in unit_names[1:]:
    model_rate += firing_rates[name]
model_rate += offset*pq.Hz
model_rate.name = f'Model = Sum of\nScaled Rates + {offset}'
firing_rates['Model'] = model_rate

# Plot Model

In [None]:
# One panel per entry of firing_rates (each unit, then the model): the spike
# raster (if the entry is a recorded unit) and its scaled firing rate. A final
# panel overlays the model and the measured force, each normalized by its peak.
fig, axes = plt.subplots(len(firing_rates)+1, 1, sharex=True, figsize=(8,2*len(firing_rates)))
for ax, name in zip(axes, firing_rates):
    if name in spiketrains:
        # Rescale spike times to seconds so the raster lines up with the rate
        # traces (plotted in seconds) regardless of the units stored in the file.
        spike_times = spiketrains[name].rescale('s').magnitude
        ax.eventplot(positions=spike_times, lineoffsets=-1, colors=colors[name])
    ax.plot(firing_rates[name].times.rescale('s'), firing_rates[name].rescale('Hz'), c=colors[name])
    ax.set_ylabel(firing_rates[name].name)
    # fixed y range: raster sits at -1; scaled rates above 3 Hz are clipped
    ax.set_ylim(-2, 3)

# plot force and the model, both normalized by their peaks
axes[-1].plot(sigs['Force'].times.rescale('s'), sigs['Force']/sigs['Force'].max(), c=colors['Force'])
axes[-1].plot(firing_rates['Model'].times.rescale('s'), firing_rates['Model']/firing_rates['Model'].max(), c=colors['Model'])
axes[-1].set_ylabel('Model vs. Force\n(both normalized)')

axes[-1].set_xlabel('Time (s)')
axes[0].set_title(file_basename)
plt.tight_layout()

# Plot Model for Grant

In [None]:
# Time windows (seconds) used to crop each trace for this figure.
# These values were chosen for 'JG08 Tape nori 1'; other files need their own.
time_slices = {
    'I2':    [670.7, 680.83],
    'B8a/b': [673.5, 679.59],
    'B3':    [675.645, 680.83],
    'B6/B9': [674.25, 680.83],
    'B38':   [670.7, 680.83],
    'Model': [672.26, 680.2],
    'Force': [672.26, 680.2],
}


def _time_slice(signal, key):
    """Return the part of a neo signal or spike train inside time_slices[key].

    The window bounds are in seconds. The result is only read for plotting, so
    no defensive copy of the input is made.
    """
    t_start, t_stop = time_slices[key]
    return signal.time_slice(t_start*pq.s, t_stop*pq.s)


n_units = len(firing_rate_params)

# Layout: for each unit a tall rate panel above a short raster panel, then one
# large panel comparing the normalized model with the normalized force.
fig, axes = plt.subplots(
    2*n_units + 1, 1, sharex=True,
    figsize=(6, n_units*(16/17) + 1*(20/17)),
    gridspec_kw={'height_ratios': [3, 1]*n_units + [5]},
)
for i, name in enumerate(firing_rate_params):
    fr = firing_rates[name]
    st = spiketrains[name]
    if name in time_slices:
        fr = _time_slice(fr, name)
        st = _time_slice(st, name)

    # scaled firing rate, labelled with the unit name left of the panel
    ax = axes[2*i]
    ax.plot(fr.times.rescale('s'), fr.rescale('Hz'), c=colors[name])
    ax.annotate(name, xy=(0, 0.5), xycoords='axes fraction',
        ha='right', va='center', fontsize='large', color=colors[name], fontfamily='Serif',
    )
    # ax.set_ylim(0, 2.2)  # optional: common y range for all units
    ax.axis('off')

    # spike raster; times rescaled to seconds to match the rate traces
    ax = axes[2*i+1]
    ax.eventplot(positions=st.rescale('s').magnitude, lineoffsets=-1, colors=colors[name])
    ax.axis('off')

# plot force and the model, both normalized by their peaks
force = _time_slice(sigs['Force'], 'Force')
model = _time_slice(firing_rates['Model'], 'Model')
axes[-1].plot(force.times.rescale('s'), force/force.max(), c=colors['Force'])
axes[-1].plot(model.times.rescale('s'), model/model.max(), c=colors['Model'])
axes[-1].annotate('Model\nvs.', xy=(-0.04, 0.6), xycoords='axes fraction',
    ha='center', va='center', fontsize='large', color=colors['Model'], fontfamily='Serif',
)
axes[-1].annotate('Force', xy=(-0.04, 0.35), xycoords='axes fraction',
    ha='center', va='center', fontsize='large', color=colors['Force'], fontfamily='Serif',
)
axes[-1].axis('off')

# pad is passed by keyword: recent Matplotlib versions accept tight_layout's
# arguments by keyword only, so the positional form tight_layout(0) fails.
plt.tight_layout(pad=0)