In [16]:
# Make the working source tree importable without installing spykshrk through
# setuptools: resolve the source root relative to this notebook's directory and
# make sure it appears exactly once, at the end of python's module search path.
import sys, os
# Number of directory levels between this notebook and the source root
root_depth = 2
# IPython exposes its directory history as `_dh`; entry 0 is used here as the
# notebook directory. When `_dh` is not defined (e.g. the notebook was exported
# and is being run as a plain script) fall back to the current working directory
# instead of failing with a KeyError.
notebook_dir = globals().get('_dh', [os.getcwd()])[0]
root_path = os.path.abspath(os.path.join(notebook_dir, '../'*root_depth))
# Add to python's path: drop every existing occurrence first (in place, so other
# references to sys.path stay valid) so re-running the cell never accumulates
# duplicates, then append the root path once.
sys.path[:] = [entry for entry in sys.path if entry != root_path]
sys.path.append(root_path)
# Alternatively set root path as current working directory
#os.chdir(root_path)

In [17]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import json
import os
from functools import partial
import multiprocessing as mp
import pickle
import functools
import holoviews as hv

from spykshrk.realtime.simulator import nspike_data

from spykshrk.franklab.pp_decoder.util import gaussian, normal2D, apply_no_anim_boundary, simplify_pos_pandas
from spykshrk.franklab.pp_decoder.pp_clusterless import OfflinePPEncoder, OfflinePPDecoder
from spykshrk.franklab.data_containers import EncodeSettings, DecodeSettings, SpikeObservation, \
                                              LinearPosition, StimLockout, Posteriors
from spykshrk.franklab.pp_decoder.visualization import DecodeVisualizer, DecodeErrorVisualizer

from spykshrk.franklab.pp_decoder.decode_error import LinearDecodeError
# Pandas display configuration for notebook output: float precision and the
# maximum number of rows / columns rendered before truncation.
# Disabled alternatives kept for reference:
#pd.set_option('float_format', '{:,.2f}'.format)
#pd.set_option('display.width', 120)
for option_name, option_value in (('display.precision', 4),
                                  ('display.max_rows', 10),
                                  ('display.max_columns', 15)):
    pd.set_option(option_name, option_value)
# Load both HoloViews plotting backends, one call each and in the same order
# as before (matplotlib first, bokeh last).
for backend_name in ('matplotlib', 'bokeh'):
    hv.extension(backend_name)
# Shorthand used for slicing multi-level columns in .loc[] lookups
idx = pd.IndexSlice

In [18]:
from holoviews import Store
from bokeh.models.arrow_heads import TeeHead
# Register 'upper_head' and 'lower_head' as additional style options for
# ErrorBars elements on the bokeh backend.
Store.add_style_opts(
    hv.ErrorBars,
    ['upper_head', 'lower_head'],
    backend='bokeh',
)

In [21]:
%%time
# Load merged rec HDF store based on config

config_file = '/opt/data36/daliu/realtime/spykshrk/dec_60uv_300samp/bond.config.json'
#config_file = '/home/daliu/Src/spykshrk_realtime/config/bond_single.json'
#config_file = '/opt/data36/daliu/realtime/spykshrk/test/test.config.json'


config = json.load(open(config_file, 'r'))
day = config['simulator']['nspike_animal_info']['days'][0]
epoch = config['simulator']['nspike_animal_info']['epochs'][0]

# Main hdf5 data source file name
hdf_file = os.path.join(config['files']['output_dir'],
                        '{}.rec_merged.h5'.format(config['files']['prefix']))

# Extract just encode and decode settings from config
encode_settings = EncodeSettings(config)
decode_settings = DecodeSettings(config)

# Open data file
store = pd.HDFStore(hdf_file, mode='r')

# Encapsulate Spike Observation panda table in container
observ_obj = SpikeObservation.from_realtime(store['rec_3'], day=day, epoch=epoch, enc_settings=encode_settings)

# Grab stimulation lockout times
stim_lockout = StimLockout.from_realtime(store['rec_11'], enc_settings=encode_settings)


# Grab animal linearized real position
nspike_anim = nspike_data.AnimalInfo(**config['simulator']['nspike_animal_info'])
pos = nspike_data.PosMatDataStream(nspike_anim)
pos_data = pos.data

# Encapsulate linear position
lin_obj = LinearPosition.from_nspike_posmat(pos_data, encode_settings)

In [22]:
# The animal's linearized position (nspike_anim, pos, pos_data, lin_obj) is
# already loaded by the config/data loading cell; repeating that load here was a
# copy-paste duplicate, so this cell only derives the single-axis view.
# NOTE(review): presumably this maps the per-arm linear position onto one
# continuous axis — confirm against LinearPosition.get_mapped_single_axis.
linflat_obj = lin_obj.get_mapped_single_axis()

In [26]:
#%%prun -r
# Run PP decoding algorithm

# Width of each decoding time bin; the same value is reused later when
# resampling the linear position.
# NOTE(review): units are not visible here — confirm against OfflinePPDecoder.
time_bin_size = 300

# Uniform transition matrix derived from the encoder settings
uniform_trans_mat = OfflinePPEncoder.calc_uniform_trans_mat(encode_settings)

decoder = OfflinePPDecoder(
    observ_obj=observ_obj,
    trans_mat=uniform_trans_mat,
    encode_settings=encode_settings,
    decode_settings=decode_settings,
    time_bin_size=time_bin_size,
)

posteriors = decoder.run_decoder()

In [24]:
# Estimated position per time bin: take the column label holding the largest
# posterior value in each row and parse the integer that follows the label's
# first character.
posterior_dist = posteriors.get_distribution_view()
peak_labels = posterior_dist.idxmax(axis=1)
dec_est_pos = peak_labels.apply(lambda label: int(label[1:])).to_frame()
dec_est_pos.columns = ['est_pos']

# Real linear position resampled with the decoder's time bin size
resamp_lin_obj = lin_obj.get_resampled(time_bin_size).get_pd_no_multiindex()

# Build the error table from real vs. estimated position.
# NOTE(review): the meaning of the trailing argument 2 is not visible here —
# confirm against LinearDecodeError.calc_error_table.
error_calc = LinearDecodeError()
dec_error = error_calc.calc_error_table(resamp_lin_obj, dec_est_pos,
                                        encode_settings.arm_coordinates, 2)

# Summary statistics over every second-level 'abs_error' column
abs_error = dec_error.loc[:, idx[:, 'abs_error']]
print("Median:")
print(abs_error.median())
print("Mean:")
print(abs_error.mean())

In [25]:
%%opts ErrorBars {+framewise} [height=500 width=1000] (line_color=Cycle(values=['#FF0099', '#99FF00', '#5555FF']) line_width=1 upper_head=TeeHead(size=0) lower_head=TeeHead(size=0))
%%opts Points {+framewise} [height=500 width=1000] (color=Cycle(values=['#FF0099', '#99FF00', '#5555FF']))
%%output holomap='scrubber'

# NOTE: keep the cell magics above as the very first lines of the cell; do not
# put comments or code before or between them.
# The ErrorBars options rely on TeeHead and on the 'upper_head'/'lower_head'
# style options, so the cell that imports TeeHead and registers those options
# must have been run first.

# Disabled warning filter, kept for reference:
#warnings.filterwarnings(action='')

# Wrap the decode error table in its visualizer
dec_viz = DecodeErrorVisualizer(dec_error)

# NOTE(review): the meaning of the arguments 20 and 50 is not visible here —
# confirm against DecodeErrorVisualizer.plot_arms_error_dmap and consider
# naming them.
dmap = dec_viz.plot_arms_error_dmap(20,50)

# Bare last expression so the notebook renders the returned object
dmap