In [7]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import json
import os
from functools import partial

from spykshrk.realtime.simulator import nspike_data

from spykshrk.franklab.pp_decoder.util import gaussian, normal2D, apply_no_anim_boundary, simplify_pos_pandas
from spykshrk.franklab.pp_decoder.pp_clusterless import OfflinePPDecoder
from spykshrk.franklab.pp_decoder.data_containers import EncodeSettings, DecodeSettings, SpikeObservation, \
                                                         LinearPosition, StimLockout, Posteriors
from spykshrk.franklab.pp_decoder.visualization import DecodeVisualizer

from spykshrk.franklab.pp_decoder.decode_error import LinearDecodeError, \
                                                      plot_arms_error

# Notebook-wide pandas display settings.
# Optional alternatives (disabled):
#   pd.set_option('float_format', '{:,.2f}'.format)
#   pd.set_option('display.width', 180)
for _opt_name, _opt_value in (('display.precision', 4),
                              ('display.max_rows', 10)):
    pd.set_option(_opt_name, _opt_value)

# Shorthand for MultiIndex slicing, e.g. frame.loc[idx[a:b, :], :]
idx = pd.IndexSlice

In [8]:
%%time
# Load merged rec HDF store based on config

config_file = '/opt/data36/daliu/realtime/spykshrk/dec_60uv_300samp/bond.config.json'
#config_file = '/home/daliu/Src/spykshrk_realtime/config/bond_single.json'

config = json.load(open(config_file, 'r'))
day = config['simulator']['nspike_animal_info']['days'][0]
epoch = config['simulator']['nspike_animal_info']['epochs'][0]

# Main hdf5 data source file name
hdf_file = os.path.join(config['files']['output_dir'],
                        '{}.rec_merged.h5'.format(config['files']['prefix']))

# Extract just encode and decode settings from config
encode_settings = EncodeSettings(config)
decode_settings = DecodeSettings(config)

# Open data file
store = pd.HDFStore(hdf_file, mode='r')

# Encapsulate Spike Observation panda table in container
observ_obj = SpikeObservation.create_default(store['rec_3'], day=day, epoch=epoch)

# Grab stimulation lockout times
stim_lockout = StimLockout.create_default(store['rec_11'])


# Grab animal linearized real position
nspike_anim = nspike_data.AnimalInfo(**config['simulator']['nspike_animal_info'])
pos = nspike_data.PosMatDataStream(nspike_anim)
pos_data = pos.data

# Encapsulate linear position
lin_obj = LinearPosition.from_nspike_posmat(pos_data, encode_settings)

In [14]:
%%time
# Run PP decoding algorithm

time_bin_size = 300

decoder = OfflinePPDecoder(lin_obj=lin_obj, observ_obj=observ_obj,
                           encode_settings=encode_settings, decode_settings=decode_settings, 
                           which_trans_mat='learned')

posteriors = decoder.run_decoder()

In [15]:
# Point estimate of decoded position: for each row of the posterior
# distribution take the column label with the highest probability, and parse
# the integer bin index after its first character (e.g. a label like 'x12' -> 12).
dec_est_pos = (posteriors.get_distribution_view()
               .idxmax(axis=1)
               .apply(lambda x: int(x[1:]))
               .to_frame('est_pos'))

# True linear position resampled with the same time_bin_size as set above.
resamp_lin_obj = lin_obj.get_resampled(time_bin_size).get_pd_no_multiindex()

dec_error = LinearDecodeError()

# Per-arm error tables (center, left, right).
# NOTE(review): meaning of the trailing `2` argument is not visible here —
# confirm against LinearDecodeError.calc_error_table.
center_dec_error, left_dec_error, right_dec_error = dec_error.calc_error_table(
    resamp_lin_obj, dec_est_pos, encode_settings.arm_coordinates, 2)

# Summarize abs_error per arm; '{:.5}' prints 5 significant digits.
# Output format is unchanged: "<stat> error center: X, left: Y, right: Z".
arm_dec_errors = [('center', center_dec_error),
                  ('left', left_dec_error),
                  ('right', right_dec_error)]
for stat_name, stat_func in [('median', np.median), ('mean', np.mean)]:
    arm_summaries = ', '.join('{}: {:.5}'.format(arm_name, stat_func(arm_error['abs_error']))
                              for arm_name, arm_error in arm_dec_errors)
    print('{} error {}'.format(stat_name, arm_summaries))

In [20]:
# Plot per-arm decoding error for each [start, end] range listed below, then
# render the figure.
# NOTE(review): figsize=(400, 10) is 400 inches wide — presumably intentional
# for a very long x-axis; confirm, since rendering such a figure is expensive.
plt_ranges = [
    [2350, 3400],
    [4560, 5550],
]
plt.figure(figsize=(400, 10))
for plt_range in plt_ranges:
    plot_arms_error(center_dec_error, left_dec_error, right_dec_error, plt_range)
plt.show()