# Import Packages

In [None]:
################################################################################
# NUMPY
# conda install numpy

import numpy as np

################################################################################
# MATPLOTLIB
# conda install matplotlib

import matplotlib.pyplot as plt

################################################################################
# NEO
# pip install git+https://github.com/NeuralEnsemble/python-neo.git
# - AxoGraph support requires axographio to be installed: pip install axographio

# import neo

################################################################################
# QUANTITIES
# conda install quantities

import quantities as pq
# quantities display settings and custom units
pq.markup.config.use_unicode = True  # allow symbols like mu for micro in output
# define millinewton (1 mN = 1e-3 N) as an attribute of the quantities module so
# that signals can be rescaled with .rescale('mN')
pq.mN = pq.UnitQuantity('millinewton', pq.N/1e3, symbol = 'mN');  # define millinewton

################################################################################
# ELEPHANT
# pip install elephant

import elephant

################################################################################
# PANDAS
# conda install pandas

import pandas as pd

################################################################################
# SPM1D - One-Dimensional Statistical Parametric Mapping
# pip install spm1d

import spm1d

################################################################################
# EPHYVIEWER
# pip install git+https://github.com/jpgill86/ephyviewer.git@experimental
# - requires PyAV: conda install -c conda-forge av

# import ephyviewer

################################################################################
# ParseMetadata
# - requires ipywidgets: conda install ipywidgets
# - requires yaml:       conda install pyyaml

from ParseMetadata import LoadMetadata

################################################################################
# ImportData

from ImportData import LoadAndPrepareData

################################################################################
# NeoUtilities

# from NeoUtilities import NeoAnalogSignalRAUC

################################################################################
# NeoToEphyviewerBridge

# from NeoToEphyviewerBridge import NeoSegmentToEphyviewerSources#, PlotExampleWithEphyviewer

# IPython Magics

In [None]:
# make figures interactive and open in a separate window
# %matplotlib qt

# make figures interactive and inline
# NOTE(review): the `notebook` (nbagg) backend only works in the classic
# Jupyter Notebook; in JupyterLab / Notebook 7 use `%matplotlib widget`
# (requires ipympl) instead -- confirm which front end this is run in
%matplotlib notebook

# make figures non-interactive and inline
# %matplotlib inline

# Data Parameters

In [None]:
# data sets to analyze, identified by their keys in the metadata
data_sets = [
    '2018-06-21_IN-VIVO_JG-08 002',
    '2018-06-24_IN-VIVO_JG-08 001',
]

# load the metadata containing file paths
all_metadata = LoadMetadata()

# one dictionary per data set, seeded with its metadata; later cells add
# signals, epochs and resampled curves to these entries
data = {data_set_name: {'metadata': all_metadata[data_set_name]} for data_set_name in data_sets}
data_set_name = data_sets[-1] if data_sets else None

In [None]:
# select which swallow sequences to use: for each data set, a list of
# [start, stop] windows in original chart time (s); only 'force' epochs lying
# entirely inside one of these windows are kept for analysis
time_windows_to_keep = {
    '2018-06-21_IN-VIVO_JG-08 002': [
        # [-np.inf, np.inf], # keep everything
        [659, 726.1], # tension maximized and no perturbation
        # [666.95, 726.1], # tension maximized and no perturbation, and extra long large hump excluded
    ],
    '2018-06-24_IN-VIVO_JG-08 001': [
        # [-np.inf, np.inf], # keep everything
        [2244.7, 2259.9], [2269.5, 2355.95], # tension maximized and no perturbation
        # [2244.7, 2259.9], [2269.5, 2290.2], [2307, 2355.95], # tension maximized and no perturbation, and extra long large hump excluded
    ],
}
for name, windows in time_windows_to_keep.items():
    data[name]['time_windows_to_keep'] = windows

# Import the Data

In [None]:
# cutoff frequency of the super-low-pass filter applied to the force signal
FORCE_LOWPASS_FREQ = 0.5*pq.Hz

for data_set_name, d in data.items():

    # read in the data
    blk = LoadAndPrepareData(d['metadata'])
    signals = blk.segments[0].analogsignals

    # locate the force signal by name; raise an explicit error (assert
    # statements can be disabled) that names the data set and what is available
    force_sigs = [sig for sig in signals if sig.name == 'Force']
    if not force_sigs:
        raise ValueError(
            "no 'Force' signal in data set {!r}; available signals: {}".format(
                data_set_name, [sig.name for sig in signals]))

    # grab the force vs time data and rescale to mN
    # (if several signals are named 'Force', the last one is used)
    d['force_sig'] = force_sigs[-1].rescale('mN')

    # apply a super-low-pass filter to force signal
    d['smoothed_force_sig'] = elephant.signal_processing.butter(  # may raise a FutureWarning
        signal=d['force_sig'],
        lowpass_freq=FORCE_LOWPASS_FREQ,
    )

#     display(blk)

In [None]:
# Epoch boundaries are compared with this tolerance (s) so that floating-point
# round-off in the epoch encoder CSV does not break the nesting test below.
EPOCH_NESTING_TOL = 1e-7


def _find_middle(force_epoch, large_humps):
    """Return the end time (s) of the 'large hump' epoch nested in `force_epoch`.

    This is the boundary between the large and the small hump of one 'force'
    epoch. If several 'large hump' epochs are nested, the last one in file
    order is used.

    Raises:
        ValueError: if no 'large hump' epoch lies inside `force_epoch`.
    """
    nested = large_humps[
        (large_humps['time'] >= force_epoch['time'] - EPOCH_NESTING_TOL)
        & (large_humps['end'] <= force_epoch['end'] + EPOCH_NESTING_TOL)
    ]
    if nested.empty:
        raise ValueError(
            "no 'large hump' epoch found inside the 'force' epoch "
            "starting at {} s".format(force_epoch['time']))
    return nested['end'].iloc[-1]


def _max_force(sig, t_start, t_stop):
    """Return the peak magnitude (units stripped) of the first channel of
    `sig` between `t_start` and `t_stop`, both given in seconds."""
    return np.max(sig.time_slice(t_start*pq.s, t_stop*pq.s).magnitude[:, 0])


for data_set_name, d in data.items():

    # grab the output force timing data and compute end times
    epochs_df = pd.read_csv(d['metadata']['epoch_encoder_file'])
    epochs_df = epochs_df.assign(end=epochs_df['time'] + epochs_df['duration'])
    large_humps = epochs_df[epochs_df['label'] == 'large hump']

    # drop all but 'force' rows
    epochs_df = epochs_df[epochs_df['label'] == 'force']

    # keep only epochs that are entirely inside the time windows; filtering
    # first means epochs outside the windows are never sliced below
    in_window = np.zeros(len(epochs_df), dtype=bool)
    for t_min, t_max in d['time_windows_to_keep']:
        in_window |= ((t_min <= epochs_df['time']) & (epochs_df['end'] <= t_max)).values
    epochs_df = epochs_df[in_window]

    # copy middle times (end of large hump and start of small hump) into 'force' epochs
    epochs_df = epochs_df.assign(middle=[
        _find_middle(epoch, large_humps) for _, epoch in epochs_df.iterrows()])

    # find max forces in each epoch, in both the raw and the smoothed signal
    rows = [epoch for _, epoch in epochs_df.iterrows()]
    epochs_df = epochs_df.assign(**{
        'large max':          [_max_force(d['force_sig'],          e['time'],   e['middle']) for e in rows],
        'small max':          [_max_force(d['force_sig'],          e['middle'], e['end'])    for e in rows],
        'smoothed large max': [_max_force(d['smoothed_force_sig'], e['time'],   e['middle']) for e in rows],
        'smoothed small max': [_max_force(d['smoothed_force_sig'], e['middle'], e['end'])    for e in rows],
    })

    # colors: evenly spaced colormap positions from 0 to 1 in epoch order
    epochs_df = epochs_df.assign(colormap_arg=np.linspace(0, 1, len(epochs_df)))

    d['epochs_df'] = epochs_df

#     print(data_set_name)
#     display(epochs_df)

# Plots

In [None]:
# fig size in inches: (width, height)
figsize = 9, 3

# color map used to color each epoch by its colormap_arg;
# alternatives: plt.cm.brg, plt.cm.RdBu
cm = plt.cm.cool

##### Figure 1: Plot forces across real time

In [None]:
# clear=True: re-running this cell redraws the figure instead of stacking a
# duplicate copy of every trace onto the existing axes
plt.figure(1, figsize=figsize, clear=True)
for i, data_set_name in enumerate(data_sets):
    d = data[data_set_name]
    plt.subplot(1, len(data), i+1)
    plt.title(data_set_name)
    plt.ylabel('Force (mN)')
    plt.xlabel('Original chart time (s)')
    for j, epoch in d['epochs_df'].iterrows():
        epoch_force_sig = d['force_sig'].time_slice(epoch['time']*pq.s, epoch['end']*pq.s)
        plt.plot(epoch_force_sig.times, epoch_force_sig.as_array(), color=cm(epoch['colormap_arg']))
plt.tight_layout()

##### Figure 2: Plot forces with aligned start times

In [None]:
# clear=True: re-running this cell redraws the figure instead of stacking a
# duplicate copy of every trace onto the existing axes
plt.figure(2, figsize=figsize, clear=True)
for i, data_set_name in enumerate(data_sets):
    d = data[data_set_name]
    plt.subplot(1, len(data), i+1)
    plt.title(data_set_name)
    plt.ylabel('Force (mN)')
    plt.xlabel('Zeroed time (s)')
    for j, epoch in d['epochs_df'].iterrows():
        # shift each epoch so it starts at t = 0 (the slice is a new object,
        # so d['force_sig'] keeps its original t_start)
        epoch_force_sig = d['force_sig'].time_slice(epoch['time']*pq.s, epoch['end']*pq.s)
        epoch_force_sig.t_start = 0*pq.s
        plt.plot(epoch_force_sig.times, epoch_force_sig.as_array(), color=cm(epoch['colormap_arg']))
plt.tight_layout()

##### Figure 3: Plot forces with uniformly normalized time

In [None]:
# clear=True: re-running this cell redraws the figure instead of stacking a
# duplicate copy of every trace onto the existing axes
plt.figure(3, figsize=figsize, clear=True)
for i, data_set_name in enumerate(data_sets):
    d = data[data_set_name]
    plt.subplot(1, len(data), i+1)
    plt.title(data_set_name)
    plt.ylabel('Force (mN)')
    plt.xlabel('Uniformly normalized time')
    for j, epoch in d['epochs_df'].iterrows():
        # re-space the N samples of each epoch at 1/N so the epoch spans [0, 1)
        epoch_force_sig = d['force_sig'].time_slice(epoch['time']*pq.s, epoch['end']*pq.s)
        epoch_force_sig.t_start = 0*pq.s
        epoch_force_sig.sampling_period = 1/len(epoch_force_sig.times)*pq.s
        plt.plot(epoch_force_sig.times, epoch_force_sig.as_array(), color=cm(epoch['colormap_arg']))
plt.tight_layout()

##### Figure 4: Plot normalized forces with uniformly normalized time

In [None]:
# clear=True: re-running this cell redraws the figure instead of stacking a
# duplicate copy of every trace onto the existing axes
plt.figure(4, figsize=figsize, clear=True)
for i, data_set_name in enumerate(data_sets):
    d = data[data_set_name]
    plt.subplot(1, len(data), i+1)
    plt.title(data_set_name)
    plt.ylabel('Normalized force')
    plt.xlabel('Uniformly normalized time')
    for j, epoch in d['epochs_df'].iterrows():
        # force divided by the epoch's smoothed large-hump peak; the N samples
        # are re-spaced at 1/N so the epoch spans [0, 1)
        epoch_force_sig = d['force_sig'].time_slice(epoch['time']*pq.s, epoch['end']*pq.s) / epoch['smoothed large max']
        epoch_force_sig.t_start = 0*pq.s
        epoch_force_sig.sampling_period = 1/len(epoch_force_sig.times)*pq.s
        plt.plot(epoch_force_sig.times, epoch_force_sig.as_array(), color=cm(epoch['colormap_arg']))
plt.tight_layout()

##### Figure 5: Plot forces with time normalized separately for large and small humps

In [None]:
# clear=True: re-running this cell redraws the figure instead of stacking a
# duplicate copy of every trace onto the existing axes
plt.figure(5, figsize=figsize, clear=True)
for i, data_set_name in enumerate(data_sets):
    d = data[data_set_name]
    plt.subplot(1, len(data), i+1)
    plt.title(data_set_name)
    plt.ylabel('Force (mN)')
    plt.xlabel('Phase-dependent normalized time')
    for j, epoch in d['epochs_df'].iterrows():
        # large hump re-spaced onto [0, 1) and small hump onto [1, 2)
        epoch_large_force_sig = d['force_sig'].time_slice(epoch['time']*pq.s, epoch['middle']*pq.s)
        epoch_small_force_sig = d['force_sig'].time_slice(epoch['middle']*pq.s, epoch['end']*pq.s)
        epoch_large_force_sig.t_start = 0*pq.s
        epoch_small_force_sig.t_start = 1*pq.s
        epoch_large_force_sig.sampling_period = 1/len(epoch_large_force_sig.times)*pq.s
        epoch_small_force_sig.sampling_period = 1/len(epoch_small_force_sig.times)*pq.s
        plt.plot(
            np.concatenate([epoch_large_force_sig.times, epoch_small_force_sig.times]),
            np.concatenate([epoch_large_force_sig.as_array(), epoch_small_force_sig.as_array()]),
            color=cm(epoch['colormap_arg']),
        )
plt.tight_layout()

##### Figure 6: Plot normalized forces with time normalized separately for large and small humps

In [None]:
# clear=True: re-running this cell redraws the figure instead of stacking a
# duplicate copy of every trace onto the existing axes
plt.figure(6, figsize=figsize, clear=True)
for i, data_set_name in enumerate(data_sets):
    d = data[data_set_name]
    plt.subplot(1, len(data), i+1)
    plt.title(data_set_name)
    plt.ylabel('Normalized force')
    plt.xlabel('Phase-dependent normalized time')
    for j, epoch in d['epochs_df'].iterrows():
        # both humps are divided by the epoch's smoothed large-hump peak; the
        # large hump is re-spaced onto [0, 1) and the small hump onto [1, 2)
        epoch_large_force_sig = d['force_sig'].time_slice(epoch['time']*pq.s, epoch['middle']*pq.s) / epoch['smoothed large max']
        epoch_small_force_sig = d['force_sig'].time_slice(epoch['middle']*pq.s, epoch['end']*pq.s) / epoch['smoothed large max']
        epoch_large_force_sig.t_start = 0*pq.s
        epoch_small_force_sig.t_start = 1*pq.s
        epoch_large_force_sig.sampling_period = 1/len(epoch_large_force_sig.times)*pq.s
        epoch_small_force_sig.sampling_period = 1/len(epoch_small_force_sig.times)*pq.s
        plt.plot(
            np.concatenate([epoch_large_force_sig.times, epoch_small_force_sig.times]),
            np.concatenate([epoch_large_force_sig.as_array(), epoch_small_force_sig.as_array()]),
            color=cm(epoch['colormap_arg']),
        )
plt.tight_layout()

##### Figure 7: Reproduce Fig. 5 using interpolated data

In [None]:
# number of points each hump (large and small) is resampled to
n_samples = 1000

# normalized time axis shared by all resampled curves: the large hump spans
# 0-1 and the small hump 1-2, each half including both of its endpoints.
# NOTE(review): assumes spm1d.util.interp resamples a curve from its first to
# its last point (its default Q=101 corresponds to 0-100 % in 1 % steps)
times = np.concatenate([np.linspace(0, 1, n_samples), np.linspace(1, 2, n_samples)])

# interpolate and resample the signals at regular intervals in preparation for averaging
for data_set_name, d in data.items():

    # raw force samples (mN) of the large and small hump of every epoch;
    # the curves generally differ in length
    epoch_large_force_values = []
    epoch_small_force_values = []
    for _, epoch in d['epochs_df'].iterrows():
        epoch_large_force_sig = d['force_sig'].time_slice(epoch['time']*pq.s, epoch['middle']*pq.s)
        epoch_small_force_sig = d['force_sig'].time_slice(epoch['middle']*pq.s, epoch['end']*pq.s)
        epoch_large_force_values.append(epoch_large_force_sig.as_array().flatten())
        epoch_small_force_values.append(epoch_small_force_sig.as_array().flatten())

    # linear interpolation and resampling to n_samples data points. Curves are
    # resampled one at a time so the ragged list is never turned into a single
    # array (NumPy >= 1.24 raises ValueError for ragged input); the reshape
    # keeps a (0, n_samples) shape when there are no epochs
    epoch_large_force_values = np.reshape(
        [spm1d.util.interp(y, Q=n_samples) for y in epoch_large_force_values], (-1, n_samples))
    epoch_small_force_values = np.reshape(
        [spm1d.util.interp(y, Q=n_samples) for y in epoch_small_force_values], (-1, n_samples))

    # combine large and small into unified time series, one row per epoch
    epoch_all_force_values = np.concatenate([epoch_large_force_values, epoch_small_force_values], axis=1)
    assert epoch_all_force_values.shape == (len(d['epochs_df']), 2*n_samples)

    d['resampled_epoch_all_force_values'] = epoch_all_force_values
    d['resampled_times'] = times

In [None]:
# clear=True: re-running this cell redraws the figure instead of stacking a
# duplicate copy of every trace onto the existing axes
plt.figure(7, figsize=figsize, clear=True)
for i, data_set_name in enumerate(data_sets):
    d = data[data_set_name]
    plt.subplot(1, len(data), i+1)
    plt.title(data_set_name)
    # the resampled curves are raw force in mN (not normalized), as in Fig. 5
    plt.ylabel('Force (mN)')
    plt.xlabel('Phase-dependent normalized time')
    # rows of the resampled array are in the same order as the rows of epochs_df
    for epoch, colormap_arg in zip(d['resampled_epoch_all_force_values'], d['epochs_df']['colormap_arg']):
        plt.plot(d['resampled_times'], epoch, color=cm(colormap_arg))
plt.tight_layout()

##### Figure 8: Plot mean and standard deviation cloud

In [None]:
# clear=True: re-running this cell redraws the figure instead of stacking a
# duplicate copy of every trace onto the existing axes
plt.figure(8, figsize=figsize, clear=True)
for i, data_set_name in enumerate(data_sets):
    d = data[data_set_name]
    plt.subplot(1, len(data), i+1)
    plt.title(data_set_name)
    # the resampled curves are raw force in mN (not normalized)
    plt.ylabel('Force (mN)')
    # NOTE(review): x is not passed to plot_mean_sd, so the x axis presumably
    # shows sample index (0 to 2*n_samples - 1) rather than normalized time 0-2
    plt.xlabel('Phase-dependent normalized time')
    spm1d.plot.plot_mean_sd(d['resampled_epoch_all_force_values'])#, x=d['resampled_times'])

#     # verify that spm1d.plot.plot_mean_sd does what I think it does
#     mean = np.mean(d['resampled_epoch_all_force_values'], axis=0)
#     std = np.std(d['resampled_epoch_all_force_values'], axis=0)
#     plt.plot(mean, 'w:')
#     plt.plot(mean+std, 'r:')
#     plt.plot(mean-std, 'b:')

plt.tight_layout()

##### Figure 9: Overlay mean and standard deviation cloud

In [None]:
# half-width figure overlaying all data sets on a single set of axes;
# clear=True: re-running this cell redraws instead of stacking duplicates
plt.figure(9, figsize=tuple(np.array(figsize)*(0.5,1)), clear=True)
# the resampled curves are raw force in mN (not normalized)
plt.ylabel('Force (mN)')
plt.xlabel('Phase-dependent normalized time')
for i, data_set_name in enumerate(data_sets):
    d = data[data_set_name]
    # give each data set its own color and legend entry so the overlaid
    # clouds can be told apart; autoset_ylim=False stops each call from
    # resetting the y limits to its own data, which could clip earlier clouds
    color = cm(i / max(len(data_sets) - 1, 1))
    spm1d.plot.plot_mean_sd(
        d['resampled_epoch_all_force_values'],
        linecolor=color,
        facecolor=color,
        label=data_set_name,
        autoset_ylim=False,
    )  # x=d['resampled_times'] not passed

#     # verify that spm1d.plot.plot_mean_sd does what I think it does
#     mean = np.mean(d['resampled_epoch_all_force_values'], axis=0)
#     std = np.std(d['resampled_epoch_all_force_values'], axis=0)
#     plt.plot(mean, 'w:')
#     plt.plot(mean+std, 'r:')
#     plt.plot(mean-std, 'b:')

plt.legend()
plt.tight_layout()