In [227]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import json
import os

import spykshrk.realtime.simulator.nspike_data as nspike_data

# Notebook-wide pandas display setting: truncate printed frames to 10 rows.
pd.options.display.max_rows = 10

# Shorthand for MultiIndex slicing, e.g. frame.loc[idx[a, b], :].
idx = pd.IndexSlice

In [2]:
# Load merged rec HDF store based on config

# NOTE(review): machine-specific absolute path; adjust when running elsewhere.
config_file = '/home/daliu/Src/spykshrk_realtime/config/bond_full.json'
# Use a context manager so the config file handle is closed instead of leaking.
with open(config_file, 'r') as config_fh:
    config = json.load(config_fh)

hdf_file = os.path.join(config['files']['output_dir'],
                        '{}.rec_merged.h5'.format(config['files']['prefix']))

# Opened read-only and intentionally left open: later cells read tables from `store`.
store = pd.HDFStore(hdf_file, mode='r')

In [3]:
# Get table with decode for each spike and assign each spike a decode-bin index

dec_bin_size = 6000     # Decode bin size in samples (200 ms at the usual 30 kHz rate)

spike_decode = store['rec_3']

# Bin edges are aligned to multiples of dec_bin_size: bin 0 starts at
# floor(first_timestamp / dec_bin_size) * dec_bin_size and bin k spans
# [start + k*dec_bin_size, start + (k+1)*dec_bin_size). Flooring the absolute
# timestamps *before* subtracting keeps each spike in the bin whose aligned
# start time is floor(ts0 / dec_bin_size) * dec_bin_size + k * dec_bin_size.
# .iloc[0] is positional, so this does not depend on the stored index labels.
spike_bin_abs = np.floor(spike_decode['timestamp'] / dec_bin_size)
dec_bins = (spike_bin_abs - spike_bin_abs.iloc[0]).astype('int')
spike_decode['dec_bin'] = dec_bins

In [5]:
# Combine per-spike likelihoods within each decode bin into a normalized
# posterior estimate of location (one row of dec_est per bin).

# Per-spike likelihood columns, one per position bin.
pos_bin_cols = spike_decode.loc[:, 'x0':'x149'].columns

dec_bin_ids = np.unique(dec_bins)
# Bins that contain no spikes keep an all-zero row.
dec_est = np.zeros([dec_bin_ids[-1] + 1, len(pos_bin_cols)])

# Start time (in samples) of each decode bin, aligned to multiples of dec_bin_size.
start_bin_time = np.floor(spike_decode['timestamp'].iloc[0] / dec_bin_size) * dec_bin_size
dec_bin_times = start_bin_time + dec_bin_size * np.arange(len(dec_est))

for bin_id, spikes_in_bin in spike_decode.groupby('dec_bin'):
    # Product of the per-spike likelihoods (+0.0001 so a single zero cannot veto
    # a position), computed as a sum of logs: the raw product of many values < 1
    # underflows to 0, and normalizing that would give NaN.
    # NOTE(review): assumes likelihood values are non-negative.
    log_post = np.log(spikes_in_bin[pos_bin_cols].values + 0.0001).sum(axis=0)
    post = np.exp(log_post - log_post.max())
    dec_est[bin_id, :] = post / post.sum()

In [169]:
# Calculate MAP of decoder posterior and scale to real position units

pos_cfg = config['encoder']['position']
pos_bin_width = (pos_cfg['upper'] - pos_cfg['lower']) / pos_cfg['bins']

# position = lower + bin_index * bin_width (the multiplication binds before `lower` is added).
dec_est_map = np.argmax(dec_est, axis=1) * pos_bin_width + pos_cfg['lower']
# Bins with no usable posterior (all-zero or NaN row) would otherwise report
# bin 0 from argmax; mark them as having no estimate.
dec_est_map = np.where(dec_est.sum(axis=1) > 0, dec_est_map, np.nan)

dec_est_pos = pd.DataFrame({'est_pos': dec_est_map}, index=pd.Index(data=dec_bin_times, name='timestamp'))

In [171]:
# Get real position

nspike_anim = nspike_data.AnimalInfo(**config['simulator']['nspike_animal_info'])
pos = nspike_data.PosMatDataStream(nspike_anim)
pos_data = pos.data

# Row-index key selecting the recording analysed here.
# NOTE(review): presumably (day, epoch) — confirm against PosMatDataStream.
pos_day_epoch = (4, 1)
# Factor converting pos_data 'time' into the timestamp units used by the decoder.
# NOTE(review): assumes 'time' is in seconds and timestamps are 30 kHz samples — confirm.
pos_time_to_timestamp = 30000

# Transform position into simpler table with only linear position
pos_data_time = pos_data.loc[pos_day_epoch, 'time']

# .copy() so adding the timestamp column does not write through to (or warn about)
# a selection of pos_data.
pos_data_linpos = pos_data.loc[pos_day_epoch, 'lin_dist_well'].copy()
pos_data_linpos['timestamps'] = pos_data_time * pos_time_to_timestamp
pos_data_linpos = pos_data_linpos.set_index('timestamps')

In [225]:
# Break trajectory into arms and attach real position


def attach_real_pos(arm_dec_est, linpos, well_col):
    """Return a copy of ``arm_dec_est`` with an added ``real_pos`` column.

    Parameters
    ----------
    arm_dec_est : DataFrame
        Decoded estimates indexed by decode-bin timestamp.
    linpos : DataFrame
        Linear position indexed by timestamp. The index must be monotonic,
        as required by ``reindex(method='bfill')``.
    well_col : str
        Column of ``linpos`` reported as ``real_pos`` (e.g. ``'well_center'``).

    Each decode-bin timestamp takes the next available position sample
    (back-fill); bins after the last position sample get NaN.
    """
    linpos_at_bins = linpos.reindex(arm_dec_est.index, method='bfill')
    # reindex preserves arm_dec_est's row order, so assign positionally.
    return arm_dec_est.assign(real_pos=linpos_at_bins[well_col].values)


# Decoded-position ranges per arm: center (.., 150), left [150, 300), right [300, ..).
# Half-open boundaries so estimates landing exactly on 150 or 300 are not dropped.
# NaN estimates (bins without a decode) match no arm.
arm_lower_bound, arm_upper_bound = 150, 300
est_pos = dec_est_pos['est_pos']

center_dec_est = attach_real_pos(dec_est_pos[est_pos < arm_lower_bound],
                                 pos_data_linpos, 'well_center')
left_dec_est = attach_real_pos(dec_est_pos[(est_pos >= arm_lower_bound) & (est_pos < arm_upper_bound)],
                               pos_data_linpos, 'well_left')
right_dec_est = attach_real_pos(dec_est_pos[est_pos >= arm_upper_bound],
                                pos_data_linpos, 'well_right')


In [228]:
# Inspect the center-arm table (est_pos and real_pos per decode-bin timestamp);
# the rendered output is truncated according to pandas' display.max_rows option.
center_dec_est

In [233]:

# Decoded (MAP) vs. actual linear position on the center arm over time.
fig, ax = plt.subplots(figsize=[40, 10])

# NOTE(review): assumes timestamps are 30 kHz sample counts (the same factor used
# to build the position timestamps) — confirm.
center_time_s = center_dec_est.index / 30000

ax.plot(center_time_s, center_dec_est['est_pos'], '*', label='decoded (MAP)')
ax.plot(center_time_s, center_dec_est['real_pos'], '*', label='actual')
ax.set(xlabel='time (s)', ylabel='linear position (center arm)',
       title='Center arm: decoded vs. actual position')
ax.legend()

plt.show()