In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import json
import os

import spykshrk.realtime.simulator.nspike_data as nspike_data

# Display configuration: keep rendered DataFrames short in cell output.
pd.options.display.max_rows = 10

# Shorthand for MultiIndex slicing with .loc, e.g. df.loc[idx[a, :], :]
idx = pd.IndexSlice

In [2]:
# Load merged rec HDF store based on config.
# NOTE(review): absolute, user-specific path; adjust for your checkout.

config_file = '/home/daliu/Src/spykshrk_realtime/config/bond_full.json'
# Use a context manager so the config file handle is closed after parsing.
with open(config_file, 'r') as config_fh:
    config = json.load(config_fh)

hdf_file = os.path.join(config['files']['output_dir'],
                        '{}.rec_merged.h5'.format(config['files']['prefix']))

# Opened read-only and intentionally left open: later cells read tables from it.
store = pd.HDFStore(hdf_file, mode='r')

In [3]:
# Get table with decode for each spike and generate decode bin mask.
# Bins are aligned to absolute multiples of dec_bin_size: bin 0 starts at
# floor(first_timestamp / dec_bin_size) * dec_bin_size, and bin i spans
# [bin0_start + i*dec_bin_size, bin0_start + (i+1)*dec_bin_size). This matches the
# bin start times computed from the same floor() expression when labeling bins.
# Assumes timestamps are sorted ascending (the first row is the earliest) — TODO confirm.

dec_bin_size = 6000     # Decode bin size in samples (usually 30kHz)

spike_decode = store['rec_3']
first_abs_bin = np.floor(spike_decode['timestamp'].iloc[0] / dec_bin_size)
dec_bins = (np.floor(spike_decode['timestamp'] / dec_bin_size) - first_abs_bin).astype('int')
spike_decode['dec_bin'] = dec_bins

In [4]:
# Generate a normalized posterior estimate of location for each decode bin.
# Each spike row carries one value per position bin in columns x0..x149. A decode
# bin's estimate is the product of its spikes' rows, each floored by POSTERIOR_FLOOR
# so a single zero cannot veto a position, normalized to sum to 1.
# The product is computed as a sum of logs: multiplying many values < 1 underflows
# to 0.0, and normalizing an all-zero row would give NaN.

POSTERIOR_FLOOR = 0.000001

spike_posteriors = spike_decode.loc[:, 'x0':'x149']
n_pos_bins = spike_posteriors.shape[1]

dec_bin_ids = np.unique(dec_bins)
n_dec_bins = dec_bin_ids[-1] + 1
# Rows for decode bins that contain no spikes are left all-zero.
dec_est = np.zeros([n_dec_bins, n_pos_bins])

# Start time (in samples) of each decode bin, aligned to multiples of dec_bin_size.
start_bin_time = np.floor(spike_decode['timestamp'].iloc[0] / dec_bin_size) * dec_bin_size
dec_bin_times = start_bin_time + dec_bin_size * np.arange(n_dec_bins)

# One grouped pass over the table instead of re-filtering it for every bin.
log_lik_per_bin = (np.log(spike_posteriors + POSTERIOR_FLOOR)
                   .groupby(spike_decode['dec_bin'].values)
                   .sum())
log_lik = log_lik_per_bin.values
# Subtract each row's max before exponentiating so the largest entry is exp(0) = 1.
lik = np.exp(log_lik - log_lik.max(axis=1, keepdims=True))
dec_est[log_lik_per_bin.index.values, :] = lik / lik.sum(axis=1, keepdims=True)

In [5]:
# Calculate the MAP of the decoder posterior and scale it to real position units.
# Position bin index i is scaled to lower + i * (upper - lower) / bins.

pos_config = config['encoder']['position']
pos_bin_width = (pos_config['upper'] - pos_config['lower']) / pos_config['bins']

# Decode bins with no spikes (or a non-finite posterior) have no positive mass in
# dec_est; argmax would silently report bin 0 for them. Mark them NaN and leave
# them out of dec_est_pos.
has_spikes = dec_est.sum(axis=1) > 0
dec_est_map = np.where(has_spikes,
                       np.argmax(dec_est, axis=1) * pos_bin_width + pos_config['lower'],
                       np.nan)

dec_est_pos = pd.DataFrame({'est_pos': dec_est_map[has_spikes]},
                           index=pd.Index(data=np.asarray(dec_bin_times)[has_spikes], name='timestamp'))

In [6]:
# Get real position

nspike_anim = nspike_data.AnimalInfo(**config['simulator']['nspike_animal_info'])
pos = nspike_data.PosMatDataStream(nspike_anim)
pos_data = pos.data

# Transform position into a simpler table with only linear position for one
# recording, indexed by timestamp in samples.
day_epoch = (4, 1)          # row key into pos_data — presumably (day, epoch); confirm
samples_per_sec = 30000     # timestamps are in samples (dec_bin_size note: usually 30kHz)

pos_data_time = pos_data.loc[day_epoch, 'time']

# .copy(): the columns added below go into an independent frame rather than a
# possible view of pos_data (avoids SettingWithCopyWarning / lost writes).
pos_data_linpos = pos_data.loc[day_epoch, 'lin_dist_well'].copy()
# Take seg_idx from the same day_epoch rows so it lines up with the other columns.
pos_data_linpos['seg_idx'] = pos_data['seg_idx', 0].loc[day_epoch].values
pos_data_linpos['timestamps'] = pos_data_time*samples_per_sec
pos_data_linpos = pos_data_linpos.set_index('timestamps')

In [75]:
# Separate estimated and real position into separate arms of track, then convert position for both
# to be "well centric", distance measured from the well the real position is closest to.

def conv_center_pos(pos):
    """Convert a linearized position to distance from the center well.

    The input axis is split at 150, 300 and 450, and each range gets its own
    linear mapping.

    Args:
        pos: linearized position (scalar).

    Returns:
        The center-well-centric distance, or NaN when pos is NaN or >= 450.
    """
    if pos < 150:
        return pos
    elif pos < 300:
        return 251-pos+68
    elif pos < 450:
        # NOTE(review): the [150, 300) branch adds 68 while this one adds 65, so
        # center->right-well distance (169) disagrees with conv_right_pos's
        # right->center distance (172). Confirm which offset is intended.
        return 404-pos+65
    # Out-of-range or NaN input: return NaN explicitly instead of falling off as None.
    return np.nan
def conv_left_pos(pos):
    """Convert a linearized position to distance from the left well.

    The input axis is split at 150, 300 and 450, and each range gets its own
    linear mapping.

    Args:
        pos: linearized position (scalar).

    Returns:
        The left-well-centric distance, or NaN when pos is NaN or >= 450.
    """
    if pos < 150:
        return 68-pos+101
    elif pos < 300:
        return pos-150
    elif pos < 450:
        return 404-pos+101
    # Out-of-range or NaN input: return NaN explicitly instead of falling off as None.
    return np.nan
def conv_right_pos(pos):
    """Convert a linearized position to distance from the right well.

    The input axis is split at 150, 300 and 450, and each range gets its own
    linear mapping.

    Args:
        pos: linearized position (scalar).

    Returns:
        The right-well-centric distance, or NaN when pos is NaN or >= 450.
    """
    if pos < 150:
        return 68-pos+104
    elif pos < 300:
        # Distance to the shared point at 251 plus the 104 offset to the right
        # well, mirroring conv_left_pos's 404-pos+101 branch. Subtracting 104 here
        # produced negative distances.
        return 251-pos+104
    elif pos < 450:
        return pos-300
    # Out-of-range or NaN input: return NaN explicitly instead of falling off as None.
    return np.nan

# Reindex and join real position (linpos) to the decode estimated position table.
# method='bfill' takes, for each decode-bin timestamp, the next available position sample.
linpos_reindexed = pos_data_linpos.reindex(dec_est_pos.index, method='bfill')
dec_est_and_linpos = dec_est_pos.join(linpos_reindexed)


def _arm_decode_error(merged, conv_pos, real_col):
    """Build the well-centric estimate / real / error table for one track arm.

    Args:
        merged: rows of dec_est_and_linpos belonging to this arm (needs 'est_pos'
            and real_col columns).
        conv_pos: maps a linearized estimated position to distance from this arm's well.
        real_col: column of merged holding the real distance from this arm's well.

    Returns:
        DataFrame indexed like merged with columns est_pos, real_pos,
        error (= real_pos - est_pos), plt_error_up (positive part of error) and
        plt_error_down (magnitude of the negative part). The last two are for
        drawing one-sided error bars from the estimate toward the real position.
    """
    # astype(float): any None returned by conv_pos becomes NaN.
    est_pos = merged['est_pos'].map(conv_pos).astype(float)
    real_pos = merged[real_col]
    error = real_pos - est_pos
    return pd.DataFrame({'est_pos': est_pos,
                         'real_pos': real_pos,
                         'error': error,
                         'plt_error_up': error.clip(lower=0),
                         'plt_error_down': error.clip(upper=0).abs()})


# Separate out each arm's rows by the real position's track segment.
center_dec_est_merged = dec_est_and_linpos[dec_est_and_linpos['seg_idx'].isin([1])]
left_dec_est_merged = dec_est_and_linpos[dec_est_and_linpos['seg_idx'].isin([2, 3])]
right_dec_est_merged = dec_est_and_linpos[dec_est_and_linpos['seg_idx'].isin([4, 5])]

# Apply the "closest well centric" transform to each arm's data and compute errors.
center_dec_est = _arm_decode_error(center_dec_est_merged, conv_center_pos, 'well_center')
left_dec_est = _arm_decode_error(left_dec_est_merged, conv_left_pos, 'well_left')
right_dec_est = _arm_decode_error(right_dec_est_merged, conv_right_pos, 'well_right')

In [76]:
# Spot-check the center-arm row for one decode-bin timestamp (in samples).
# NOTE(review): hard-coded timestamp; .loc raises KeyError if that bin is not in center_dec_est.
center_dec_est.loc[82116000.0]

In [79]:

# Display the right-arm estimate/error table (row count truncated by display.max_rows).
right_dec_est

In [78]:
# Summarize signed decode error (real - est) per arm.
# Series.median()/Series.mean() both skip NaN (e.g. bins with no real position after
# reindexing). np.median on a Series would propagate NaN, while np.mean dispatches to
# Series.mean and skips it, so the two statistics would cover different rows.
# NOTE(review): signed error measures bias; mean absolute error may be what is wanted.
arm_errors = [center_dec_est['error'], left_dec_est['error'], right_dec_est['error']]

print('median error center: {:.5}, left: {:.5}, right: {:.5}'.format(*(err.median() for err in arm_errors)))
print('mean error center: {:.5}, left: {:.5}, right: {:.5}'.format(*(err.mean() for err in arm_errors)))

In [81]:

# Plot each arm's well-centric estimated position over time. One-sided error bars
# reach from the estimate to the real position (error = real - est).

SHOW_REAL_POS = False   # set True to also mark the center arm's real position
SAMPLES_PER_SEC = 30000  # timestamps are in samples (usually 30kHz)

fig, ax = plt.subplots(figsize=[200, 10])
for arm_name, arm_est in [('center', center_dec_est),
                          ('left', left_dec_est),
                          ('right', right_dec_est)]:
    ax.errorbar(x=arm_est.index/SAMPLES_PER_SEC, y=arm_est['est_pos'],
                yerr=[arm_est['plt_error_down'], arm_est['plt_error_up']],
                fmt='*', label=arm_name)
if SHOW_REAL_POS:
    ax.plot(center_dec_est.index/SAMPLES_PER_SEC, center_dec_est['real_pos'], '*')

ax.set(xlabel='time (s)',
       ylabel="estimated distance from arm's well",
       title='Decoded position per arm (well-centric) with error to real position')
ax.legend()
plt.show()