In [1]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import json
import os
import scipy.signal

from spykshrk.realtime.decoder_process import PointProcessDecoder

from spykshrk.realtime.simulator import nspike_data

from spykshrk.franklab.pp_decoder.pp_clusterless import plot_decode_2d

from spykshrk.franklab.pp_decoder.util import gaussian, normal2D, apply_no_anim_boundary, simplify_pos_pandas

from spykshrk.franklab.pp_decoder.decode_error import bin_pos_data, calc_error_table, \
                                                      plot_arms_error

from spykshrk.franklab.pp_decoder.data_containers import LinearPositionContainer, SpikeObservation
    
# pandas display settings for this notebook.
# Alternatives that were tried and left disabled:
#   pd.set_option('float_format', '{:,.2f}'.format)
#   pd.set_option('display.width', 180)
pd.set_option('display.precision', 4,
              'display.max_rows', 10)

# Shorthand for MultiIndex slicing, e.g. df.loc[idx[:, 'a'], :]
idx = pd.IndexSlice


In [2]:
from spykshrk.realtime.simulator import nspike_data

In [3]:
%%time
# Load merged rec HDF store based on config

config_file = '/opt/data36/daliu/realtime/spykshrk/dec_60uv_300samp/bond.config.json'
config = json.load(open(config_file, 'r'))

hdf_file = os.path.join(config['files']['output_dir'],
                        '{}.rec_merged.h5'.format(config['files']['prefix']))

store = pd.HDFStore(hdf_file, mode='r')
spike_obs = SpikeObservation(store['rec_3'])
decoder_df = store['rec_4']
stim_lockout = store['rec_11']

In [4]:
# Preview the recorded decoder output loaded above. (`spike_decode` is only
# created by the binning cell further down, so referencing it here raised a
# NameError on a fresh kernel.)
decoder_df.head()

In [None]:
%%time
stim_lockout_ranges = stim_lockout.pivot(index='lockout_num',columns='lockout_state', values='timestamp')
stim_lockout_ranges = stim_lockout_ranges.reindex(columns=[1,0])


In [None]:
%%time
# Get table with decode for each spike and generate decode bin mask

dec_bin_size = 300     # Decode bin size in samples (usually 30kHz)

arm_coordinates = [[0, 69], [150, 150+102], [300, 300+104]]

spike_decode = spike_obs.get_observations_binned(dec_bin_size)

pos_upper = config['encoder']['position']['upper']
pos_lower = config['encoder']['position']['lower']
pos_num_bins = config['encoder']['position']['bins']
pos_bin_delta = ((pos_upper - pos_lower) / pos_num_bins)

x_bins = np.linspace(0, pos_bin_delta*(pos_num_bins-1), pos_num_bins)
x_bin_edges = np.linspace(0, pos_bin_delta*(pos_num_bins), pos_num_bins+1)

pos_kernel = gaussian(x_bins, x_bins[int(len(x_bins)/2)], 3)


In [None]:
# ---- Real position ---------------------------------------------------------
# Build the animal descriptor from the simulator config, stream its recorded
# position, and map it onto a single linear axis for plotting / error tables.
nspike_anim = nspike_data.AnimalInfo(
    **config['simulator']['nspike_animal_info']
)
pos = nspike_data.PosMatDataStream(nspike_anim)
pos_data = pos.data

lin_obj = LinearPositionContainer(pos_data)
linpos_flat = lin_obj.get_mapped_single_axis()


In [None]:
%%time
# Run spykshrk.realtime version of point process decoding

pp_decoder = PointProcessDecoder(pos_range=[config['encoder']['position']['lower'],
                                            config['encoder']['position']['upper']],
                                 pos_bins=config['encoder']['position']['bins'],
                                 time_bin_size=config['pp_decoder']['bin_size'],
                                 arm_coor=config['pp_decoder']['arm_pos'],
                                 uniform_gain=config['pp_decoder']['trans_mat_uniform_gain'])

pp_decoder.select_ntrodes(config['simulator']['nspike_animal_info']['tetrodes'])

num_time_bins = spike_decode.loc[:,'dec_bin'].max()

groups = spike_decode.groupby('dec_bin')

last_bin_id = 0

spykshrk_posteriors = np.zeros([num_time_bins+1, pos_num_bins])

dec_bin_times = []

for bin_id, spikes_in_bin in groups:
    if last_bin_id <= bin_id - 1:
        for bin_no_spk_id in range(last_bin_id + 1, bin_id):
            # No spike in bin
            
            dec_bin_times.append(start_timestamp + bin_no_spk_id * dec_bin_size)
            post = pp_decoder.increment_no_spike_bin()
            spykshrk_posteriors[bin_no_spk_id, :] = post
        
    for ntrode_id, dec in zip(spikes_in_bin.loc[:, 'ntrode_id'].values, 
                              spikes_in_bin.loc[:, 'x0': 'x{:d}'.format(pos_num_bins-1)].values):
        pp_decoder.add_observation(ntrode_id, dec)
        
    post = pp_decoder.increment_bin()
    dec_bin_times.append(spikes_in_bin['dec_bin_start'].values[0])
    spykshrk_posteriors[bin_id, :] = post
    last_bin_id = bin_id
    

In [None]:

# Time windows to plot, as [start, end] pairs handed to plot_decode_2d
# (units are whatever plot_decode_2d expects — TODO confirm).
# Full window of interest: [[2461, 3404]].
plt_ranges = [[2461 + 700, 2461 + 900]]

# Posterior columns x0..x{N-1}, derived from the configured number of position
# bins rather than a hardcoded 'x449'.
last_pos_col = 'x{:d}'.format(pos_num_bins - 1)

# Loop-invariant inputs for the two panels.
recorded_posteriors = decoder_df.loc[:, 'x0':last_pos_col].values
recorded_real_pos = decoder_df.set_index('real_pos_time')['real_pos']
replay_bin_times = np.array(dec_bin_times)
real_linpos = linpos_flat.reset_index(level=['day', 'epoch'], drop=True)

for plt_range in plt_ranges:
    plt.figure(figsize=(400, 20))

    # Top panel: posterior recorded in the merged rec store (decoder_df).
    plt.subplot(2, 1, 1)
    plot_decode_2d(decoder_df['timestamp'],
                   recorded_posteriors,
                   stim_lockout_ranges,
                   recorded_real_pos, plt_range, 1.0)

    # Bottom panel: posterior from the PointProcessDecoder replay.
    plt.subplot(2, 1, 2)
    plot_decode_2d(replay_bin_times,
                   spykshrk_posteriors,
                   stim_lockout_ranges,
                   real_linpos,
                   plt_range, 1.0)

plt.show()

In [None]:
dec_est_map = x_bins[np.argmax(decoder_df.loc[:,'x0':'x449'].values, axis=1)]
dec_est_pos = pd.DataFrame({'est_pos': dec_est_map}, index=pd.Index(data=decoder_df['timestamp'],
                                                                    name='timestamp'))

pos_data_linpos = lin_obj.get_pd_no_multiindex()

pos_data_bins = bin_pos_data(pos_data_linpos, dec_bin_size)

center_dec_error, left_dec_error, right_dec_error = calc_error_table(pos_data_bins, dec_est_pos,
                                                                     arm_coordinates, 2)

print('median error center: {:0.5}, left: {:0.5}, right: {:.5}'.format(np.median(center_dec_error['abs_error']),
                                                                       np.median(left_dec_error['abs_error']),
                                                                       np.median(right_dec_error['abs_error'])))

print('mean error center: {:0.5}, left: {:0.5}, right: {:.5}'.format(np.mean(center_dec_error['abs_error']),
                                                                       np.mean(left_dec_error['abs_error']),
                                                                       np.mean(right_dec_error['abs_error'])))

In [None]:
dec_est_map = x_bins[np.argmax(spykshrk_posteriors, axis=1)]
dec_est_pos = pd.DataFrame({'est_pos': dec_est_map}, index=pd.Index(data=np.array(dec_bin_times),
                                                                    name='timestamp'))

pos_data_bins = bin_pos_data(pos_data_linpos, dec_bin_size)

center_dec_error, left_dec_error, right_dec_error = calc_error_table(pos_data_bins, dec_est_pos,
                                                                     arm_coordinates, 2)

print('median error center: {:0.5}, left: {:0.5}, right: {:.5}'.format(np.median(center_dec_error['abs_error']),
                                                                       np.median(left_dec_error['abs_error']),
                                                                       np.median(right_dec_error['abs_error'])))

print('mean error center: {:0.5}, left: {:0.5}, right: {:.5}'.format(np.mean(center_dec_error['abs_error']),
                                                                       np.mean(left_dec_error['abs_error']),
                                                                       np.mean(right_dec_error['abs_error'])))