In [1]:
# Make the in-tree sources importable without installing spykshrk through setuptools:
# the directory `root_depth` levels above this notebook is placed on python's path.
import sys, os
root_depth = 2
# `_dh` is looked up in the notebook's globals; its first entry is used as the
# directory the relative root path is resolved against.
notebook_dir = globals()['_dh'][0]
root_path = os.path.abspath(os.path.join(notebook_dir, '../'*root_depth))
# Strip every existing occurrence first so that re-running this cell never
# leaves duplicate entries, then add the root exactly once at the end.
while root_path in sys.path:
    sys.path.remove(root_path)
sys.path.append(root_path)
# Alternatively set root path as current working directory
#os.chdir(root_path)

In [2]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib
import matplotlib.pyplot as plt
import json
import os
import scipy.signal
import holoviews as hv

import warnings

from spykshrk.realtime.decoder_process import PointProcessDecoder

from spykshrk.realtime.simulator import nspike_data

from spykshrk.franklab.pp_decoder.util import gaussian, normal2D, apply_no_anim_boundary, simplify_pos_pandas
from spykshrk.franklab.pp_decoder.pp_clusterless import OfflinePPDecoder
from spykshrk.franklab.data_containers import EncodeSettings, DecodeSettings, SpikeObservation, \
                                              LinearPosition, StimLockout, Posteriors, FlatLinearPosition
from spykshrk.franklab.pp_decoder.visualization import DecodeVisualizer, DecodeErrorVisualizer

from spykshrk.franklab.pp_decoder.decode_error import LinearDecodeError
    
# Notebook-wide display configuration for pandas, matplotlib and HoloViews.
#pd.set_option('float_format', '{:,.2f}'.format)
pd.set_option('display.precision', 4)
pd.set_option('display.max_rows', 10)
pd.set_option('display.max_columns', 15)
#pd.set_option('display.width', 180)

# Shorthand used by later cells for slicing MultiIndex rows/columns
idx = pd.IndexSlice
matplotlib.rcParams.update({'font.size': 28})

# Load both plotting backends with a single call. The first backend listed is
# the active one, which keeps bokeh as the default exactly as the previous
# pair of consecutive hv.extension() calls left it.
hv.extension('bokeh', 'matplotlib')

In [3]:
from holoviews import Store
# TeeHead is not referenced in this cell; it is used inside the ErrorBars
# `%%opts` line of the decode-error plotting cell further down.
from bokeh.models.arrow_heads import TeeHead
# Register 'upper_head' and 'lower_head' as style options of hv.ErrorBars for
# the bokeh backend so they can be set through `%%opts`.
Store.add_style_opts(hv.ErrorBars, ['upper_head', 'lower_head'], backend='bokeh')

In [4]:
# Load merged rec HDF store based on config

# NOTE(review): absolute, machine-specific path -- anyone running this notebook
# elsewhere has to edit it; consider moving it into a configurable data directory.
config_file = '/opt/data36/daliu/realtime/spykshrk/dec_60uv_300samp/bond.config.json'
# Read the JSON config inside a context manager so the file handle is closed
# as soon as parsing finishes instead of being left to the garbage collector.
with open(config_file, 'r') as config_fh:
    config = json.load(config_fh)

# Only the first configured day and epoch are analysed
day = config['simulator']['nspike_animal_info']['days'][0]
epoch = config['simulator']['nspike_animal_info']['epochs'][0]
time_bin_size = config['pp_decoder']['bin_size']

# Main hdf5 data source file name
hdf_file = os.path.join(config['files']['output_dir'],
                        '{}.rec_merged.h5'.format(config['files']['prefix']))

# Extract just encode and decode settings from config
encode_settings = EncodeSettings(config)
decode_settings = DecodeSettings(config)

# Open data file (read-only).
# NOTE(review): the store is never closed in the visible cells; it is left open
# here in case other cells still read from it -- confirm and close when done.
store = pd.HDFStore(hdf_file, mode='r')

# Encapsulate Spike Observation panda table in container
observ_obj = SpikeObservation.from_realtime(store['rec_3'], day=day, epoch=epoch, enc_settings=encode_settings)

# Posteriors built from the 'rec_4' table of the store
realtime_posteriors = Posteriors.from_realtime(store['rec_4'], day=day, epoch=epoch,
                                               enc_settings=encode_settings)

# Grab stimulation lockout times
stim_lockout = StimLockout.from_realtime(store['rec_11'], enc_settings=encode_settings)


# Grab animal linearized real position
nspike_anim = nspike_data.AnimalInfo(**config['simulator']['nspike_animal_info'])
pos = nspike_data.PosMatDataStream(nspike_anim)
pos_data = pos.data

# Encapsulate linear position
lin_obj = LinearPosition.from_nspike_posmat(pos_data, encode_settings)

In [8]:
#%%time
# Run spykshrk.realtime version of point process decoding

# Assign every spike observation to a decode bin of the configured size.
# This replaces an in-place call with a hard-coded 300 followed by a second
# call with time_bin_size whose return value was discarded; a single in-place
# call with time_bin_size keeps the binning consistent with the step used for
# the empty-bin timestamps below.
observ_obj.update_observations_bins(time_bin_size, inplace=True)

# Create and setup online point process decoder
pp_decoder = PointProcessDecoder(pos_range=[config['encoder']['position']['lower'],
                                            config['encoder']['position']['upper']],
                                 pos_bins=config['encoder']['position']['bins'],
                                 time_bin_size=config['pp_decoder']['bin_size'],
                                 arm_coor=config['encoder']['position']['arm_pos'],
                                 uniform_gain=config['pp_decoder']['trans_mat_uniform_gain'])

pp_decoder.select_ntrodes(config['simulator']['nspike_animal_info']['tetrodes'])

# Labels of the first and last per-position columns of the observation table;
# loop-invariant, so computed once.
pos_bin_count = config['encoder']['position']['bins']
first_pos_col = 'x000'
last_pos_col = 'x{:03d}'.format(pos_bin_count - 1)

num_time_bins = observ_obj['dec_bin'].max()

# Group by bin
groups = observ_obj.groupby('dec_bin')

last_bin_id = 0
bin_timestamps = []
# One posterior row per decode bin id in [0, num_time_bins]
spykshrk_posteriors = np.zeros([num_time_bins+1, pos_bin_count])

for bin_id, spikes_in_bin in groups:
    bin_start = spikes_in_bin['dec_bin_start'].iloc[0]

    # Advance the decoder through bins that contain no spikes. Their
    # timestamps are recorded BEFORE the current bin's timestamp so that
    # bin_timestamps[i] always belongs to posterior row i; previously the
    # current bin's start was appended first and the empty bins were stamped
    # relative to it, misaligning timestamps and rows after every gap.
    # NOTE(review): assumes 'dec_bin_start' advances by time_bin_size per bin,
    # the same fixed-step assumption the previous code made -- confirm.
    for bin_no_spk_id in range(last_bin_id + 1, bin_id):
        bin_timestamps.append(bin_start - (bin_id - bin_no_spk_id) * time_bin_size)
        post = pp_decoder.increment_no_spike_bin()
        spykshrk_posteriors[bin_no_spk_id, :] = post

    bin_timestamps.append(bin_start)

    # Feed every spike of this bin (electrode group id + per-position values)
    # to the decoder
    for elec_grp_id, dec in zip(spikes_in_bin.loc[:, 'elec_grp_id'].values,
                                spikes_in_bin.loc[:, first_pos_col:last_pos_col].values):
        pp_decoder.add_observation(elec_grp_id, dec)

    # Close the bin and store its posterior
    post = pp_decoder.increment_bin()
    spykshrk_posteriors[bin_id, :] = post
    last_bin_id = bin_id

# NOTE(review): 30000 is presumably the sampling rate in Hz used to convert
# timestamps to seconds; encode_settings.sampling_rate is used elsewhere in
# this notebook -- confirm whether the two should be the same value.
spykshrk_posteriors = Posteriors.from_numpy(spykshrk_posteriors, day=day, epoch=epoch,
                                            timestamps=np.array(bin_timestamps),
                                            times=np.array(bin_timestamps)/30000, columns=encode_settings.pos_col_names,
                                            enc_settings=encode_settings)

In [9]:
%%output backend='bokeh' size=400 holomap='scrubber'
%%opts RGB { +framewise} [height=100 width=250 colorbar=True]
%%opts Points {+framewise} [height=100 width=250] (marker='o' size=4 alpha=0.5)


## Plot posteriors
# Shows the posteriors decoded in this notebook (spykshrk_posteriors) above the
# posteriors loaded from the HDF store (realtime_posteriors), both with lin_obj
# as the position passed to the visualizer.

# NOTE(review): plt_ranges is not used anywhere else in this cell -- confirm
# whether it is still needed.
plt_ranges = [[2461 + 250, 2461 + 400]]

spyk_dec_viz = DecodeVisualizer(spykshrk_posteriors, linpos=lin_obj, 
                           enc_settings=encode_settings)

realtime_dec_viz = DecodeVisualizer(realtime_posteriors, linpos=lin_obj, 
                           enc_settings=encode_settings)

# Each plot gets its own RangeXY stream instance.
# NOTE(review): plt_range=10 / slide=10 presumably set the visible window and
# its step size -- confirm units against DecodeVisualizer.plot_all_dynamic.
plt1 = spyk_dec_viz.plot_all_dynamic(stream=hv.streams.RangeXY(), plt_range=10, slide=10)
plt2 = realtime_dec_viz.plot_all_dynamic(stream=hv.streams.RangeXY(), plt_range=10, slide=10)

# Stack the two plots in a single column; last expression so it is displayed
(plt1 + plt2).cols(1)

In [24]:
# Differences between consecutive values of the 'time' index level, with a 0
# inserted at the front so the result has one entry per row of dec_est_pos.
# NOTE(review): in the visible notebook dec_est_pos is first assigned in the
# cell marked In [26] below, so this cell would raise NameError on a fresh
# Restart & Run All -- confirm whether this scratch cell is still needed.
np.insert(np.diff(dec_est_pos.index.get_level_values('time')), 0, 0)

In [25]:
# Row-to-row change of 'est_pos' divided by the row-to-row change of the 'time'
# index level (the first divisor is the inserted 0).
# NOTE(review): in the visible notebook dec_est_pos is first assigned in the
# cell marked In [26] below, which also stores this same expression as
# 'linvel_flat' -- this cell fails on a fresh top-to-bottom run; confirm
# whether it is still needed.
dec_est_pos['est_pos'].diff()/np.insert(np.diff(dec_est_pos.index.get_level_values('time')), 0, 0)

In [26]:
# Point estimate per time bin: take the column label holding the largest value
# of each posterior row, drop its first character and parse the rest as an int.
dec_est_pos = (spykshrk_posteriors.get_distribution_view()
               .idxmax(axis=1)
               .apply(lambda col_label: int(col_label[1:]))
               .to_frame())
dec_est_pos.columns = ['est_pos']

# Row-to-row change in the estimate divided by the row-to-row change of the
# 'time' index level; a 0 is inserted as the first time step so the divisor
# has one entry per row.
bin_times = dec_est_pos.index.get_level_values('time')
bin_time_steps = np.insert(np.diff(bin_times), 0, 0)
dec_est_pos['linvel_flat'] = dec_est_pos['est_pos'].diff()/bin_time_steps

# Wrap the estimate table in the project's flat linear position container
dec_est_pos = FlatLinearPosition.create_default(
    dec_est_pos,
    sampling_rate=encode_settings.sampling_rate,
    arm_coord=encode_settings.arm_coordinates,
    parent=spykshrk_posteriors,
)

# Measured position resampled with the decoder's bin size, without the multi-index
resamp_lin_obj = lin_obj.get_resampled(time_bin_size).get_pd_no_multiindex()

error_obj = LinearDecodeError()

# NOTE(review): the trailing 2 is presumably the velocity threshold matching
# the ">2 cm/s" in the histogram title of a later cell -- confirm against
# LinearDecodeError.calc_error_table.
error_table = error_obj.calc_error_table(
    resamp_lin_obj, dec_est_pos, encode_settings.arm_coordinates, 2
)

# Summary of every 'abs_error' column: median first, then mean
abs_error_columns = error_table.loc[:, idx[:, 'abs_error']]
print(abs_error_columns.median())
print(abs_error_columns.mean())

In [27]:
%%opts Points {+framewise} [height=500 width=1000] (color=Cycle(values=['#FF0099', '#99FF00', '#5555FF']))
%%opts ErrorBars {+framewise} (line_color=Cycle(values=['#FF0099', '#99FF00', '#5555FF']) alpha=0.5 line_width=1 upper_head=TeeHead(size=0) lower_head=TeeHead(size=0))
%%output holomap='scrubber'
#warnings.filterwarnings(action='')

# Interactive view of the decode error table computed in the previous cell.
# The ErrorBars options above use TeeHead and the 'upper_head'/'lower_head'
# style options, both set up in the cell that calls Store.add_style_opts.
dec_viz = DecodeErrorVisualizer(error_table)

# NOTE(review): the meaning of the positional arguments 20 and 50 is not
# visible here -- confirm against DecodeErrorVisualizer.plot_arms_error_dmap.
dmap = dec_viz.plot_arms_error_dmap(20,50)

# Last expression so the map is displayed
dmap

In [28]:
# Collapse the per-arm 'abs_error' columns into one series: start from the
# 'center' column and fill its missing entries from 'left', then from 'right'.
abs_error_table = error_table.loc[:, idx[:, 'abs_error']]
abs_error_comb = abs_error_table[('center', 'abs_error')]
for fallback_arm in ('left', 'right'):
    abs_error_comb = abs_error_comb.combine_first(abs_error_table[(fallback_arm, 'abs_error')])

In [29]:
# Histogram of the combined absolute decode error, annotated with its mean and median.
fig, ax = plt.subplots(figsize=(20,10))
# Drop missing entries once, up front. np.median() on data containing NaN
# returns NaN while the mean of a pandas Series skips NaN, so the two numbers
# in the annotation could previously be computed over different samples.
abs_all_error = np.abs(abs_error_comb).dropna()
# Bin edges at every integer from 0 to 199
ax.hist(abs_all_error, range(200))
ax.text(0.8, 0.6,  "Mean error: {:.01f} cm\nMedian error: {:.01f} cm".format(abs_all_error.mean(),
                                                                             abs_all_error.median()),
        transform=ax.transAxes, horizontalalignment='right', bbox={'facecolor': 'white', 'pad':20})
# Use the Axes created above rather than pyplot's implicit current axes
ax.set_xlabel("Decode error (cm)")
ax.set_ylabel("Number of bins")
ax.set_xlim([0,200])
# NOTE(review): the bin width and speed threshold in the title are hard-coded
# text -- confirm they match time_bin_size and the threshold used for the error table.
ax.set_title('Decoding error with 10 ms bins and >2 cm/s', fontdict={'fontweight':'bold'})
plt.show()