In [38]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib
import matplotlib.pyplot as plt
import json
import os
import scipy.signal

from spykshrk.realtime.decoder_process import PointProcessDecoder

from spykshrk.realtime.simulator import nspike_data

from spykshrk.franklab.pp_decoder.util import gaussian, normal2D, apply_no_anim_boundary, simplify_pos_pandas
from spykshrk.franklab.pp_decoder.pp_clusterless import OfflinePPDecoder
from spykshrk.franklab.pp_decoder.data_containers import EncodeSettings, DecodeSettings, SpikeObservation, \
                                                         LinearPosition, StimLockout, Posteriors
from spykshrk.franklab.pp_decoder.visualization import DecodeVisualizer

from spykshrk.franklab.pp_decoder.decode_error import LinearDecodeError

    
# Compact pandas display: 4-digit float precision, at most 10 rows x 15 columns.
pd.options.display.precision = 4
pd.options.display.max_rows = 10
pd.options.display.max_columns = 15

 
# Shorthand for MultiIndex slicing, e.g. frame.loc[idx[...], :]
# NOTE(review): not used by any visible cell — remove if nothing else needs it.
idx = pd.IndexSlice
# Large default font, presumably so labels stay legible on the very wide figures below.
matplotlib.rcParams['font.size'] = 28


In [2]:
# Work from the spykshrk_realtime checkout. Override the location with the
# SPYKSHRK_SRC_DIR environment variable instead of editing this hard-coded path.
# Plain os.chdir does not depend on IPython's automagic being enabled.
# NOTE(review): on Restart & Run All the imports cell above runs BEFORE this one;
# if spykshrk is only importable from this checkout, move this cell first — confirm.
SPYKSHRK_SRC_DIR = os.environ.get('SPYKSHRK_SRC_DIR', '/home/daliu/Src/spykshrk_realtime/')
os.chdir(SPYKSHRK_SRC_DIR)

In [26]:
# Load the run config and the merged realtime-recording HDF store, then wrap each
# table the analysis needs in its spykshrk container.

config_file = '/opt/data36/daliu/realtime/spykshrk/dec_60uv_300samp/bond.config.json'
# Context manager so the config file handle is closed once parsed
# (json.load(open(...)) leaked the handle).
with open(config_file, 'r') as config_fp:
    config = json.load(config_fp)

animal_info = config['simulator']['nspike_animal_info']
day = animal_info['days'][0]
epoch = animal_info['epochs'][0]
time_bin_size = config['pp_decoder']['bin_size']

# Main hdf5 data source file name
hdf_file = os.path.join(config['files']['output_dir'],
                        '{}.rec_merged.h5'.format(config['files']['prefix']))

# Extract just encode and decode settings from config
encode_settings = EncodeSettings(config)
decode_settings = DecodeSettings(config)

# Open data file read-only.
# NOTE(review): the store stays open for the rest of the session; call
# store.close() once no later cell needs it.
store = pd.HDFStore(hdf_file, mode='r')

# Encapsulate Spike Observation panda table in container
observ_obj = SpikeObservation.create_default(store['rec_3'], day=day, epoch=epoch)

# Posteriors produced by the realtime decoder during the recording
realtime_posteriors = Posteriors.from_realtime(store['rec_4'], day=day, epoch=epoch,
                                               encode_settings=encode_settings)

# Grab stimulation lockout times
stim_lockout = StimLockout.create_default(store['rec_11'])

# Grab animal linearized real position
nspike_anim = nspike_data.AnimalInfo(**animal_info)
pos = nspike_data.PosMatDataStream(nspike_anim)
pos_data = pos.data

# Encapsulate linear position
lin_obj = LinearPosition.from_nspike_posmat(pos_data, encode_settings)

In [4]:
%%time
# Run spykshrk.realtime version of point process decoding offline over the
# recorded spike observations, one decode time bin at a time.

pos_cfg = config['encoder']['position']
num_pos_bins = pos_cfg['bins']
# Per-spike position columns x000 .. x{num_pos_bins-1}; also the posterior columns.
pos_col_names = ['x{:03d}'.format(pos_ind) for pos_ind in range(num_pos_bins)]

# Create and setup online point process decoder
pp_decoder = PointProcessDecoder(pos_range=[pos_cfg['lower'], pos_cfg['upper']],
                                 pos_bins=num_pos_bins,
                                 time_bin_size=time_bin_size,
                                 arm_coor=pos_cfg['arm_pos'],
                                 uniform_gain=config['pp_decoder']['trans_mat_uniform_gain'])

pp_decoder.select_ntrodes(config['simulator']['nspike_animal_info']['tetrodes'])

# Assign each spike a decode bin ('dec_bin' index, 'dec_bin_start' timestamp)
observ_obj.update_observations_bins(time_bin_size)

# Largest decode-bin index: the posterior array has one row per index 0..max
num_time_bins = observ_obj['dec_bin'].max()

# Group by bin
groups = observ_obj.groupby('dec_bin')

# Start at -1 so any spike-free bins before the first spiking bin (including
# bin 0) are also decoded and time-stamped, keeping timestamps row-aligned.
last_bin_id = -1
bin_timestamps = []
posterior_array = np.zeros([num_time_bins + 1, num_pos_bins])

for bin_id, spikes_in_bin in groups:
    bin_start = spikes_in_bin['dec_bin_start'].iloc[0]

    # Advance the decoder through spike-free bins between the previous spiking
    # bin and this one. Their timestamps must be appended BEFORE this bin's, or
    # bin_timestamps ends up out of order relative to posterior_array rows.
    # Assumes time_bin_size is in the same units as 'dec_bin_start'.
    for bin_no_spk_id in range(last_bin_id + 1, bin_id):
        bin_timestamps.append(bin_start - (bin_id - bin_no_spk_id) * time_bin_size)
        posterior_array[bin_no_spk_id, :] = pp_decoder.increment_no_spike_bin()

    bin_timestamps.append(bin_start)

    # Add every spike in this bin to the decoder
    for ntrode_id, dec in zip(spikes_in_bin.loc[:, 'ntrode_id'].values,
                              spikes_in_bin.loc[:, pos_col_names[0]:pos_col_names[-1]].values):
        pp_decoder.add_observation(ntrode_id, dec)

    posterior_array[bin_id, :] = pp_decoder.increment_bin()
    last_bin_id = bin_id

# NOTE(review): /30000 presumably converts 30 kHz sample timestamps to seconds — confirm.
spykshrk_posteriors = Posteriors.from_numpy(posterior_array, day=day, epoch=epoch,
                                            timestamps=np.array(bin_timestamps),
                                            times=np.array(bin_timestamps) / 30000,
                                            columns=pos_col_names)

In [5]:
# Inspect the offline spykshrk posteriors computed by the decoding cell above.
spykshrk_posteriors

In [6]:
## Plot posteriors: offline spykshrk decode (top) vs. realtime decode (bottom),
## each overlaid with the actual linear position and stimulation-lockout periods.

# Time windows to plot, presumably in seconds (the posteriors' 'times' axis) — confirm.
plt_ranges = [[2461 + 250, 2461 + 400]]

# Lockout markers are drawn 10 units above arm_coordinates[2][1]
# (presumably the far end of the last arm — confirm).
stim_lockout_height = encode_settings.arm_coordinates[2][1] + 10

for plt_range in plt_ranges:
    plt.figure(figsize=[200, 20])
    for panel, panel_posteriors in enumerate([spykshrk_posteriors, realtime_posteriors], start=1):
        plt.subplot(2, 1, panel)
        DecodeVisualizer.plot_decode_image(panel_posteriors, plt_range, encode_settings)
        DecodeVisualizer.plot_linear_pos(lin_obj, plt_range)
        DecodeVisualizer.plot_stim_lockout(stim_lockout, plt_range, stim_lockout_height)
        # Restrict EACH panel to the window (previously only the bottom panel was limited).
        plt.xlim(plt_range)
    plt.show()

In [104]:
# Decoding error of the REALTIME posteriors against the animal's real linear position.

# Maximum-a-posteriori estimate per time bin: idxmax yields the winning column
# label ('xNNN'); stripping the leading 'x' gives the integer position-bin index.
realtime_dist = realtime_posteriors.get_distribution_view()
map_col_labels = realtime_dist.idxmax(axis=1)
dec_est_pos = map_col_labels.apply(lambda col_label: int(col_label[1:])).to_frame(name='est_pos')

# Real linear position resampled onto the decoder's time bins
resamp_lin_obj = lin_obj.get_resampled(time_bin_size).get_pd_no_multiindex()

dec_error = LinearDecodeError()
# NOTE(review): the meaning of the trailing argument (2) is not visible here — confirm
# against LinearDecodeError.calc_error_table. Estimates are position-bin indices; this
# assumes bins map 1:1 onto the linear-position units — TODO confirm.
center_dec_error, left_dec_error, right_dec_error = dec_error.calc_error_table(
    resamp_lin_obj, dec_est_pos, encode_settings.arm_coordinates, 2)

arm_abs_errors = [arm_table['abs_error']
                  for arm_table in (center_dec_error, left_dec_error, right_dec_error)]
print('median error center: {:0.5}, left: {:0.5}, right: {:.5}'.format(*map(np.median, arm_abs_errors)))
print('mean error center: {:0.5}, left: {:0.5}, right: {:.5}'.format(*map(np.mean, arm_abs_errors)))

In [31]:
# Decoding error for all three arms over a long window, x ticks every 10 units.
# NOTE(review): plot_arms_error is neither defined nor imported in this notebook,
# so this cell raises NameError on a fresh kernel — TODO restore its definition.
plt_ranges = [[2350, 3400]]
tick_step = 10

plt.figure(figsize=[400, 10])
for plt_range in plt_ranges:
    range_start, range_stop = plt_range
    plot_arms_error(center_dec_error, left_dec_error, right_dec_error, plt_range)
    plt.xticks(np.arange(range_start, range_stop, tick_step))

plt.show()

In [102]:
# Zoom on one center-to-right-well run, one x tick per unit.
# NOTE(review): relies on plot_arms_error, which is not defined in this notebook.
plt_range = [3317, 3330]
range_start, range_stop = plt_range

plt.figure(figsize=[40, 10])
plot_arms_error(center_dec_error, left_dec_error, right_dec_error, plt_range)
plt.xticks(np.arange(range_start, range_stop))
plt.title('Center to right well trajectory, >2 cm/s in 10 ms bins, error bars show decoding location/error',
          fontdict={'fontweight': 'bold'})

plt.show()

In [101]:
# Zoom on one center-to-left-well run, one x tick per unit.
# NOTE(review): relies on plot_arms_error, which is not defined in this notebook.
plt_range = [3241, 3257]
range_start, range_stop = plt_range

plt.figure(figsize=[40, 10])
plot_arms_error(center_dec_error, left_dec_error, right_dec_error, plt_range)
plt.xticks(np.arange(range_start, range_stop))
plt.title('Center to left well trajectory, >2 cm/s in 10 ms bins, error bars show decoding location/error',
          fontdict={'fontweight': 'bold'})
plt.show()

In [52]:
# Pool the per-arm error tables into one frame (each table's index is kept as-is).
all_error = pd.concat((center_dec_error, left_dec_error, right_dec_error))

In [94]:
# Histogram of absolute decoding error pooled over all arms.
abs_all_error = np.abs(all_error['error'])

# 1-cm bins whose edges span the whole plotted range [0, 200]. The previous
# range(200) edges stopped at 199, silently dropping the last visible bin.
# Errors above 200 are not drawn, but they still count in the mean/median below.
error_bins = np.arange(0, 201)

fig, ax = plt.subplots(figsize=(20, 10))
ax.hist(abs_all_error, bins=error_bins)
ax.text(0.8, 0.6, "Mean error: {:.01f} cm\nMedian error: {:.01f} cm".format(np.mean(abs_all_error),
                                                                            np.median(abs_all_error)),
        transform=ax.transAxes, horizontalalignment='right', bbox={'facecolor': 'white', 'pad': 20})
ax.set_xlabel("Decode error (cm)")
ax.set_ylabel("Number of bins")
ax.set_xlim([0, 200])
ax.set_title('Decoding error with 10 ms bins and >2 cm/s', fontdict={'fontweight': 'bold'})
plt.show()

In [68]:
 "Mean error: {:.02f}\nMedian error: {:.02f}".format(np.mean(abs_all_error), 
                                                                    np.median(abs_all_error))