# Import Packages

In [None]:
################################################################################
# NUMPY
# conda install numpy

import numpy as np

################################################################################
# MATPLOTLIB
# conda install matplotlib

import matplotlib.pyplot as plt

################################################################################
# SEABORN
# conda install seaborn

import seaborn as sns

################################################################################
# PYLTTB - Time Series Downsampling Using Largest-Triangle-Three-Buckets
# pip install pylttb

from pylttb import lttb

################################################################################
# NEO
# pip install git+https://github.com/NeuralEnsemble/python-neo.git
# - AxoGraph support requires axographio to be installed: pip install axographio

# import neo

################################################################################
# QUANTITIES
# conda install quantities

import quantities as pq
# quantities configuration: unicode unit symbols in printed output, and a
# custom millinewton unit registered on the `pq` namespace
# NOTE(review): the force signal is later rescaled with the string 'mN';
# presumably that lookup relies on this unit having been defined — confirm
pq.markup.config.use_unicode = True  # allow symbols like mu for micro in output
pq.mN = pq.UnitQuantity('millinewton', pq.N/1e3, symbol = 'mN');  # define millinewton

################################################################################
# ELEPHANT
# pip install elephant

import elephant

################################################################################
# PANDAS
# conda install pandas

import pandas as pd

################################################################################
# STATSMODELS
# conda install statsmodels

import statsmodels.api as sm

################################################################################
# SPM1D - One-Dimensional Statistical Parametric Mapping
# pip install spm1d

# import spm1d

################################################################################
# EPHYVIEWER
# pip install git+https://github.com/jpgill86/ephyviewer.git@experimental
# - requires PyAV: conda install -c conda-forge av

# import ephyviewer

################################################################################
# ParseMetadata
# - requires ipywidgets: conda install ipywidgets
# - requires yaml:       conda install pyyaml

from ParseMetadata import LoadMetadata

################################################################################
# ImportData

from ImportData import LoadAndPrepareData

################################################################################
# NeoUtilities

from NeoUtilities import NeoAnalogSignalDerivative, NeoAnalogSignalRAUC, CausalAlphaKernel

################################################################################
# NeoToEphyviewerBridge

# from NeoToEphyviewerBridge import NeoSegmentToEphyviewerSources#, PlotExampleWithEphyviewer

# IPython Magics

In [None]:
# choose how matplotlib figures are displayed; exactly one of the magics below
# should be active
# NOTE(review): the `notebook` backend targets the classic Jupyter Notebook UI;
# in JupyterLab / Notebook 7 `%matplotlib widget` (ipympl) is the usual
# replacement — confirm against the environment this is run in

# make figures interactive and open in a separate window
# %matplotlib qt

# make figures interactive and inline
%matplotlib notebook

# make figures non-interactive and inline
# %matplotlib inline

# Data Parameters

In [None]:
# data sets to analyze (comment lines in or out to change the selection)
data_sets = [
    '2018-06-21_IN-VIVO_JG-08 002',
#     '2018-06-24_IN-VIVO_JG-08 001',
]

# metadata (including file paths) for every known data set
all_metadata = LoadMetadata()

# one dictionary per selected data set, seeded with its metadata;
# later cells add further entries to each of these dictionaries
data = {
    name: {'metadata': all_metadata[name]}
    for name in data_sets
}

In [None]:
# select which swallow sequences to use
#
# 'time_windows_to_keep' is a list of [start, end] pairs in the same time base
# as the epoch times (seconds); a 'force' epoch is kept only if it lies
# entirely inside at least one window.
# NOTE: every data set listed in `data_sets` needs this key, otherwise the
# epoch-processing cell fails with a KeyError.

data['2018-06-21_IN-VIVO_JG-08 002']['time_windows_to_keep'] = [
    [-np.inf, np.inf], # keep everything
#     [659, 726.1], # tension maximized and no perturbation
#     [666.95, 726.1], # tension maximized and no perturbation, and extra long large hump excluded
]

# data['2018-06-24_IN-VIVO_JG-08 001']['time_windows_to_keep'] = [
# #     [-np.inf, np.inf], # keep everything
#     [2244.7, 2259.9], [2269.5, 2355.95], # tension maximized and no perturbation
# #     [2244.7, 2259.9], [2269.5, 2290.2], [2307, 2355.95], # tension maximized and no perturbation, and extra long large hump excluded
# ]

# Import the Data

In [None]:
for data_set_name, d in data.items():

    # read in the data; only the first segment of the block is used below
    blk = LoadAndPrepareData(d['metadata'])
    seg = blk.segments[0]

    # grab the force vs time data and rescale to mN
    # (if several signals are named 'Force', the last one is used)
    force_sigs = [sig for sig in seg.analogsignals if sig.name == 'Force']
    if not force_sigs:
        # explicit check rather than `assert`: asserts are skipped under
        # `python -O` and would not say which signals were actually present
        raise ValueError(
            "data set {!r} has no analog signal named 'Force' (found: {})".format(
                data_set_name, [sig.name for sig in seg.analogsignals]))
    d['force_sig'] = force_sigs[-1].rescale('mN')

    # grab the spike trains, keyed by name (a later duplicate name overwrites an earlier one)
    d['spike_trains'] = {st.name: st for st in seg.spiketrains}

    # grab the sampling period (taken from the first analog signal)
    d['sampling_period'] = seg.analogsignals[0].sampling_period

#     display(blk)

In [None]:
# tolerance (s) used when testing whether a 'large hump' epoch lies inside a
# 'force' epoch; absorbs rounding in the times stored in the epoch file
EPOCH_CONTAINMENT_TOL = 1e-7


def _first_channel_max(sig, t_start, t_stop):
    """Return the maximum of the first channel of `sig` between t_start and t_stop (s).

    np.max is vectorized, unlike Python's built-in max over the sample rows.
    """
    return np.max(sig.time_slice(t_start*pq.s, t_stop*pq.s).magnitude[:, 0])


for data_set_name, d in data.items():

    # grab the output force timing data
    epochs_df = pd.read_csv(d['metadata']['epoch_encoder_file'])

    # compute end times
    epochs_df = epochs_df.assign(end=epochs_df['time'] + epochs_df['duration'])

    large_humps = epochs_df[epochs_df['label'] == 'large hump']

    # keep only 'force' epochs that are entirely inside at least one time window;
    # filtering first avoids measuring epochs that are discarded anyway, and an
    # empty window list simply keeps nothing
    force_df = epochs_df[epochs_df['label'] == 'force']
    in_any_window = np.zeros(len(force_df), dtype=bool)
    for t_start, t_stop in d['time_windows_to_keep']:
        in_any_window |= ((t_start <= force_df['time']) & (force_df['end'] <= t_stop)).values
    force_df = force_df[in_any_window]

    # copy middle times (end of large hump and start of small hump) into 'force'
    # epochs; if several large humps qualify, the last one in file order wins
    middles = []
    for _, epoch in force_df.iterrows():
        contained = large_humps[
            (large_humps['time'] >= epoch['time'] - EPOCH_CONTAINMENT_TOL) &
            (large_humps['end'] <= epoch['end'] + EPOCH_CONTAINMENT_TOL)
        ]
        middles.append(contained['end'].iloc[-1] if len(contained) else np.nan)
    force_df = force_df.assign(middle=middles)

    # without a middle time the two humps cannot be separated; fail with a clear
    # message instead of a cryptic error from time_slice on a NaN time
    unmatched = force_df.loc[force_df['middle'].isnull(), 'time']
    if len(unmatched):
        raise ValueError(
            "{}: no 'large hump' epoch found inside the 'force' epoch(s) "
            "starting at {} s".format(data_set_name, list(unmatched)))

    # find max forces in each epoch
    force_sig = d['force_sig']
    force_df = force_df.assign(**{
        'large max': [_first_channel_max(force_sig, e['time'],   e['middle']) for _, e in force_df.iterrows()],
        'small max': [_first_channel_max(force_sig, e['middle'], e['end'])    for _, e in force_df.iterrows()],
    })

    # colors: evenly spaced colormap positions over the kept epochs
    force_df = force_df.assign(colormap_arg=np.linspace(0, 1, len(force_df)))

    d['epochs_df'] = force_df

#     print(data_set_name)
#     display(force_df)

# Plots

In [None]:
# color map used to color each force epoch; it is called with the epoch's
# 'colormap_arg' value (evenly spaced from 0 to 1 over the kept epochs)
cm = plt.cm.cool
# cm = plt.cm.brg
# cm = plt.cm.RdBu

##### Figure 1: Plot forces across real time

In [None]:
# plt.figure(1, figsize=(9,4))
# for i, data_set_name in enumerate(data_sets):
#     d = data[data_set_name]
#     plt.subplot(1, len(data), i+1)
#     plt.title(data_set_name)
#     plt.ylabel('Force (mN)')
#     plt.xlabel('Original chart time (s)')
#     for j, epoch in d['epochs_df'].iterrows():
#         epoch_force_sig = d['force_sig'].time_slice(epoch['time']*pq.s, epoch['end']*pq.s)
#         plt.plot(epoch_force_sig.times, epoch_force_sig.as_array(), color=cm(epoch['colormap_arg']))
# plt.tight_layout()

##### Figure 2: Plot forces, spike trains, and firing rate models

In [None]:
# spike_labels = d['spike_trains'].keys()
spike_labels = [
#     'I2',
#     'B8a/b',
#     'B3 (50-100 uV)',
#     '? (45-50 uV)',
#     'B6/B9 ? (26-45 uV)',
    'B38 ? (17-26 uV)',
#     '? (15-17 uV)',
#     'B4/B5',
]

# kernels used to turn each spike train into a continuous firing-rate model;
# they do not depend on the data set or unit, so they are built once
kernels = [
    CausalAlphaKernel(0.03*np.sqrt(2)*pq.s), # match my old poster's synapse model
#     CausalAlphaKernel(0.2*pq.s),
#     elephant.kernels.AlphaKernel(0.03*np.sqrt(2)*pq.s),
#     elephant.kernels.AlphaKernel(0.2*pq.s),
#     elephant.kernels.EpanechnikovLikeKernel(0.2*pq.s),
#     elephant.kernels.ExponentialKernel(0.2*pq.s),
    elephant.kernels.GaussianKernel(0.2*pq.s),
#     elephant.kernels.LaplacianKernel(0.2*pq.s),
#     elephant.kernels.RectangularKernel(0.2*pq.s),
#     elephant.kernels.TriangularKernel(0.2*pq.s)
]

# spikes this far outside the plotted range are kept so that the beginning
# and final firing rates are accurate
spike_margin = 5*pq.s

# grid layout: one column per data set; row 0 holds the force trace and
# row j+1 the spike train spike_labels[j]. Subplot indices are row-major,
# so row r of column i is index r*n_cols + i + 1.
n_rows, n_cols = len(spike_labels)+1, len(data_sets)

# clear=True so that re-running this cell redraws figure 2 instead of
# plotting on top of its previous contents
fig = plt.figure(2, figsize=(9,2+2*len(spike_labels)), clear=True)
for i, data_set_name in enumerate(data_sets):
    d = data[data_set_name]
    epochs_df = d['epochs_df']

    # === FORCE ===
    force_ax = fig.add_subplot(n_rows, n_cols, i+1)
    force_ax.set_title(data_set_name)
    force_ax.set_ylabel('Force (mN)')
    force_ax.set_xlabel('Time (s)')

    if epochs_df.empty:
        # no epochs survived the time-window filter; leave this column blank
        continue

    t_min, t_max = epochs_df['time'].min(), epochs_df['end'].max()
    force_ax.set_xlim(t_min, t_max)

    for _, epoch in epochs_df.iterrows():
        epoch_force_sig = d['force_sig'].time_slice(epoch['time']*pq.s, epoch['end']*pq.s)
        force_ax.plot(epoch_force_sig.times.rescale('s'), epoch_force_sig.as_array(), color=cm(epoch['colormap_arg']))

    # === RASTER PLOTS + RATE MODELS ===
    for j, spike_label in enumerate(spike_labels):
        # drop spikes outside plot range except for a few sec margin so beginning and final firing rates are accurate
        st = d['spike_trains'][spike_label].time_slice(
            t_min*pq.s - spike_margin,
            t_max*pq.s + spike_margin,
        )

        raster_ax = fig.add_subplot(n_rows, n_cols, (j+1)*n_cols + i + 1, sharex=force_ax)
        raster_ax.set_ylabel(spike_label + '\n(rate model)')
        raster_ax.set_xlabel('Time (s)')
        raster_ax.set_xlim(t_min, t_max)
        raster_ax.set_ylim([-2, 40])

        # raster plot (spike times expressed in seconds to match the x axis)
        raster_ax.eventplot(positions=st.rescale('s').magnitude, lineoffsets=-1, colors='red')

        # spike train convolution
        for kernel in kernels:
            # the spike train is passed positionally: newer elephant releases
            # name this parameter `spiketrains` instead of `spiketrain`
            rate = elephant.statistics.instantaneous_rate(
                st, sampling_period=d['sampling_period'], kernel=kernel)
            raster_ax.plot(rate.times.rescale('s'), rate)

        # instantaneous firing frequency step plot
#         raster_ax.plot(st[:-1], 1/elephant.statistics.isi(st), drawstyle='steps-post')

fig.tight_layout()