# Import Packages

In [None]:
# add the directory containing the `analysis` package to the import path
# (the path is relative to the notebook's working directory); the membership
# check keeps repeated runs of this cell from piling up duplicate entries
import sys
if '../..' not in sys.path:
    sys.path.append('../..')

In [None]:
################################################################################
# NUMPY
# conda install numpy

import numpy as np

################################################################################
# SCIPY
# conda install scipy

# import scipy as sp

################################################################################
# MATPLOTLIB
# conda install matplotlib

import matplotlib.pyplot as plt
# from mpl_toolkits.mplot3d import Axes3D
from matplotlib.ticker import MultipleLocator

################################################################################
# SEABORN
# conda install seaborn

import seaborn as sns

################################################################################
# NEO
# pip install neo>=0.7.1
# - AxoGraph support requires axographio to be installed: pip install axographio

# import neo

################################################################################
# QUANTITIES
# conda install quantities

import quantities as pq
pq.markup.config.use_unicode = True  # allow symbols like mu for micro in output
pq.mN = pq.UnitQuantity('millinewton', pq.N/1e3, symbol = 'mN');  # define millinewton

################################################################################
# ELEPHANT
# pip install elephant>=0.6.2

import elephant

################################################################################
# PANDAS
# conda install pandas

# import pandas as pd

################################################################################
# STATSMODELS
# conda install statsmodels

# import statsmodels.api as sm

################################################################################
# SPM1D - One-Dimensional Statistical Parametric Mapping
# pip install spm1d

# import spm1d

################################################################################
# EPHYVIEWER
# pip install git+https://github.com/jpgill86/ephyviewer.git@experimental
# - requires PyAV: conda install -c conda-forge av

# import ephyviewer

################################################################################
# ParseMetadata
# - requires ipywidgets: conda install ipywidgets
# - requires yaml:       conda install pyyaml

from analysis.modules.ParseMetadata import LoadMetadata

################################################################################
# ImportData

from analysis.modules.ImportData import LoadAndPrepareData

################################################################################
# NeoUtilities
# - requires pylttb: pip install pylttb

from analysis.modules.NeoUtilities import DownsampleNeoSignal#, BehaviorsDataFrame, NeoEpochToDataFrame, CausalAlphaKernel

################################################################################
# EphyviewerConfigurator

# from analysis.modules.EphyviewerConfigurator import EphyviewerConfigurator

# IPython Magics

In [None]:
# select how matplotlib displays figures; keep exactly one of the following
# magics uncommented

# make figures interactive and open in a separate window
# %matplotlib qt

# make figures interactive and inline
# NOTE(review): the `notebook` backend targets the classic Jupyter Notebook and
# may not render in JupyterLab / Notebook 7 -- there `%matplotlib widget`
# (requires ipympl) is the usual replacement; confirm for your environment
%matplotlib notebook

# make figures non-interactive and inline
# %matplotlib inline

# Data Parameters

In [None]:
# # Jade's colors
# i2_color  = '#0072BD' # MATLAB blue
# rn_color  = '#D95319' # MATLAB orange
# bn2_color = '#EDB120' # MATLAB yellow
# bn3_color = '#7E2F8E' # MATLAB purple
# force_color = '#666666' # dark gray

In [None]:
# BRAIN Initiative Grant 2019 color scheme, one color per plotted channel
i2_color, rn_color, bn2_color, bn3_color, force_color = (
    '#FFAF14',  # I2:    orange
    '#00CC64',  # RN:    green
    '#4A72FF',  # BN2:   blue
    '#7E2F8E',  # BN3:   purple
    '#666666',  # Force: dark gray
)

In [None]:
# Figure definitions. Each entry gives the data set to load, the time window to
# plot (in seconds), and the channel plot settings. Choose the figure to produce
# by setting `figure_name` below; selecting with a single variable (rather than
# by running one of several cells that overwrite the same globals) keeps
# Restart & Run All deterministic.

def _channel_plots(i2_ylim, rn_ylim, bn2_ylim, bn3_ylim):
    """Return the list of per-channel plot settings used by every figure.

    Only the y-axis limits of the four uV channels differ between figures;
    units, decimation factors, colors, and the Force limits are shared.
    """
    return [
        {'channel': 'I2',    'units': 'uV', 'ylim': i2_ylim,    'decimation_factor':  10, 'color': i2_color},
        {'channel': 'RN',    'units': 'uV', 'ylim': rn_ylim,    'decimation_factor':  10, 'color': rn_color},
        {'channel': 'BN2',   'units': 'uV', 'ylim': bn2_ylim,   'decimation_factor':  10, 'color': bn2_color},
        {'channel': 'BN3',   'units': 'uV', 'ylim': bn3_ylim,   'decimation_factor':  10, 'color': bn3_color},
        {'channel': 'Force', 'units': 'mN', 'ylim': [-10, 300], 'decimation_factor': 100, 'color': force_color},
    ]

figure_configs = {
    # FRESH FOOD
    'fresh-food': {
        'data_set_name': 'IN VIVO / JG08 / 2018-06-25 / 001',
        'time_window': [1334, 1364],  # 30 sec
        'plots': _channel_plots([-20, 20], [-150, 150], [-180, 180], [-120, 120]),
    },
    # TWO-PLY NORI
    'two-ply-nori': {
        'data_set_name': 'IN VIVO / JG08 / 2018-06-25 / 001',
        'time_window': [3262, 3292],  # 30 sec
        'plots': _channel_plots([-20, 20], [-150, 150], [-180, 180], [-120, 120]),
    },
    # TAPE NORI
    'tape-nori': {
        'data_set_name': 'IN VIVO / JG08 / 2018-06-21 / 002',
        'time_window': [175, 205],  # 30 sec
        'plots': _channel_plots([-30, 30], [-60, 60], [-120, 120], [-150, 150]),
    },
    # REGULAR NORI
    'regular-nori': {
        'data_set_name': 'IN VIVO / JG08 / 2018-06-21 / 001',
        'time_window': [2468, 2498],  # 30 sec
        'plots': _channel_plots([-30, 30], [-60, 60], [-120, 120], [-150, 150]),
    },
}

# the figure to produce; must be one of the keys of figure_configs
figure_name = 'regular-nori'

# expose the selected configuration under the names used by the rest of the notebook
_config = figure_configs[figure_name]
outfile_basename = 'traces-' + figure_name
data_set_name = _config['data_set_name']
t_start, t_stop = _config['time_window'] * pq.s
plots = _config['plots']

# Import and Process the Data

In [None]:
# read the metadata file, which holds the file paths of the data sets
metadata_file = '../../data/metadata.yml'
all_metadata = LoadMetadata(file=metadata_file)

# pick the entry for the chosen data set, then load and prepare its data
metadata = all_metadata[data_set_name]
blk = LoadAndPrepareData(metadata)

# Plot

In [None]:
# global seaborn theme for the figures below
# (a larger look is available by also passing context='poster')
sns.set(
    font='Palatino Linotype',
    font_scale=1,
    style='ticks',
)

# width of the plotted trace lines, in points
linewidth = 1

# spacing of the time-axis tick marks, in seconds:
#   - major ticks carry labels
#   - minor ticks are unlabeled
majorticks, minorticks = 5, 1

# horizontal position (axes coordinates) shared by every y-axis label so that
# the labels line up; adjust it if the width of the tick labels changes
ylabel_offset = -0.06

# optional manual layout (alternative to tight_layout): positions of the plot
# edges and the vertical space between subplots, as fractions of the figure
# layout_settings = dict(
#     left   = 0.045,
#     right  = 0.99,
#     top    = 0.97,
#     bottom = 0.1,
#     hspace = 0.2,
# )



def prettyplot(blk, t_start, t_stop, plots, outfile=None):
    """Plot channels of a block as vertically stacked traces and optionally save them.

    Parameters
    ----------
    blk
        Block whose first segment's ``analogsignals`` contain the channels to
        plot, looked up by their ``name``. ``blk.file_origin`` is recorded in
        the metadata of the saved files.
    t_start, t_stop
        Quantities bounding the time window to plot (e.g. ``1334 * pq.s``).
    plots
        List of dicts, one per subplot from top to bottom, with keys:
          - ``'channel'`` (required): name of the signal to plot
          - ``'units'`` (required): units the signal is rescaled to
          - ``'ylim'`` (optional): y-axis limits ``[low, high]``
          - ``'decimation_factor'`` (optional, default 1): factor passed to
            ``DownsampleNeoSignal``
          - ``'color'`` (optional, default ``'k'``): trace color
          - ``'ylabel'`` (optional): y-axis label; defaults to
            ``"<signal name> (<units>)"``
    outfile
        Base file name without extension. When not None, the figure is written
        to ``<outfile>.pdf``, ``<outfile>.svg`` and ``<outfile>.png``; when
        None, nothing is saved.

    Raises
    ------
    KeyError
        If a requested channel is not among the segment's analog signals.

    Notes
    -----
    Line width, tick spacing, and y-label position are read from the
    module-level ``linewidth``, ``majorticks``, ``minorticks`` and
    ``ylabel_offset``.
    """
    # index the analog signals of the first segment by name
    # (if two signals share a name, the later one wins)
    signals_by_name = {sig.name: sig for sig in blk.segments[0].analogsignals}

    with sns.plotting_context('notebook', font_scale=1):

        fig = plt.figure(figsize=(14, 7))

        num_subplots = len(plots)
        top_ax = None
        for i, p in enumerate(plots):
            is_bottom = i == num_subplots - 1

            # add this subplot below the previous ones; every subplot shares the
            # x-axis of the top subplot so the traces stay aligned in time
            ax = fig.add_subplot(num_subplots, 1, i + 1, sharex=top_ax)
            if top_ax is None:
                top_ax = ax

            # select the requested channel, failing with a descriptive message
            channel = p['channel']
            if channel not in signals_by_name:
                raise KeyError(
                    'channel {!r} not found; available signals: {}'.format(
                        channel, ', '.join(sorted(map(str, signals_by_name)))))
            sig = signals_by_name[channel]

            # crop to the time window and convert to the requested units
            sig = sig.time_slice(t_start, t_stop)
            sig = sig.rescale(p['units'])

            # downsample the data before plotting
            sig_downsampled = DownsampleNeoSignal(sig, p.get('decimation_factor', 1))

            ax.plot(
                sig_downsampled.times,
                sig_downsampled.as_quantity(),
                linewidth=linewidth,
                color=p.get('color', 'k'),
            )

            # y-axis label; the default is only built when no custom label is given
            if 'ylabel' in p:
                ylabel = p['ylabel']
            else:
                ylabel = '{} ({})'.format(sig.name, sig.units.dimensionality.string)
            ax.set_ylabel(ylabel)

            # put every subplot's y-axis label at the same horizontal position
            ax.yaxis.set_label_coords(ylabel_offset, 0.5)

            # plot range
            ax.set_xlim([t_start, t_stop])
            if 'ylim' in p:
                ax.set_ylim(p['ylim'])

            if is_bottom:
                # minor (frequent, unlabeled) and major (infrequent, labeled)
                # ticks for the bottom x-axis
                ax.xaxis.set_minor_locator(MultipleLocator(minorticks))
                ax.xaxis.set_major_locator(MultipleLocator(majorticks))

                ax.set_xlabel('Time (' + sig.times.units.dimensionality.string + ')')

                # offset the axes from the plot
                sns.despine(ax=ax, offset=10)
            else:
                # offset the axes and hide the x-axis of all but the bottom subplot
                sns.despine(ax=ax, offset=10, trim=True, bottom=True)
                ax.xaxis.set_visible(False)

        # adjust the white space around and between the subplots
        fig.tight_layout()

        if outfile is not None:
            # describe the source data and time window in the file metadata;
            # str() guards against a missing (None) file_origin
            subject = ('Data file: '  + str(blk.file_origin) + '\n' +
                       'Start time: ' + str(t_start)         + '\n' +
                       'End time: '   + str(t_stop))

            # per-format metadata: PDF and PNG accept a 'Subject' key, but
            # matplotlib's SVG writer accepts only Dublin Core keys (it raises
            # on 'Subject' in recent versions), so SVG gets 'Description'
            metadata_by_format = [
                ('pdf', {'Subject': subject}),
                ('svg', {'Description': subject}),
                ('png', {'Subject': subject}),
            ]

            # resolution in dots per inch (determines the PNG's pixel size)
            dpi = 300

            for ext, file_metadata in metadata_by_format:
                fig.savefig(outfile + '.' + ext, metadata=file_metadata, dpi=dpi)

__IMPORTANT:__ The vector graphics outputs (SVG, PDF) of these traces contain a very large number of points, which tends to bring most programs, including poster printer software, to a grinding halt. For most applications, especially poster printing, use the high-resolution PNG instead. If you prefer vector fonts, build a composite figure in Inkscape or Illustrator that combines the rasterized traces from the PNG with the vector labels from the SVG.

In [None]:
# draw the selected figure and write it to disk under the configured base name
prettyplot(blk, t_start, t_stop, plots, outfile=outfile_basename)