In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import json
import os

import spykshrk.realtime.simulator.nspike_data as nspike_data

# Notebook display settings: floats shown with 4 decimal places, and printed
# frames truncated to at most 10 rows.
pd.set_option('display.precision', 4,
              'display.max_rows', 10)

# Shorthand for MultiIndex slicing, e.g. frame.loc[idx[:, 'a'], :].
idx = pd.IndexSlice

In [2]:
# Load merged rec HDF store based on config

config_file = '/opt/data36/daliu/realtime/spykshrk/bond_param/01/bond.config.json'
# Parse inside a context manager so the config file handle is closed afterwards
# (json.load(open(...)) left it open until garbage collection).
with open(config_file, 'r') as config_fp:
    config = json.load(config_fp)

hdf_file = os.path.join(config['files']['output_dir'],
                        '{}.rec_merged.h5'.format(config['files']['prefix']))

# Read-only store kept as a notebook-level handle; call store.close() when finished with it.
store = pd.HDFStore(hdf_file, mode='r')
spike_decode = store['rec_3']

In [3]:
# Get table with decode for each spike and generate decode bin mask

dec_bin_size = 6000     # Decode bin size in samples (usually 30kHz)

# Assign each spike a decode bin. Bin edges sit on multiples of dec_bin_size and bin 0 is
# the grid bin containing the first spike, so bin k starts at
#   floor(t0 / dec_bin_size) * dec_bin_size + k * dec_bin_size,
# which is the start time used to label decode bins later on. Binning relative to t0
# itself would shift every bin by t0 % dec_bin_size away from that label.
# .iloc[0] takes the first row by position, independent of the store's index labels.
# NOTE(review): assumes spike_decode is sorted by timestamp (first row earliest) -- confirm.
spike_grid_bin = np.floor(spike_decode['timestamp'] / dec_bin_size)
dec_bins = (spike_grid_bin - spike_grid_bin.iloc[0]).astype('int')
dec_bin_ids = np.unique(dec_bins)

spike_decode['dec_bin'] = dec_bins

pos_num_bins = config['encoder']['position']['bins']
# Width of one position bin, in the encoder's position units.
pos_bin_delta = ((config['encoder']['position']['upper'] - config['encoder']['position']['lower']) /
                 pos_num_bins)



In [19]:
# Inspect the decode-bin index assigned to each spike (row-aligned with spike_decode).
dec_bins

In [22]:
# Loop through each bin and generate normalized posterior estimate of location

def decode_from_spikes(dec_bin_ids, spike_decode, dec_bin_size, pos_bin_delta, pos_num_bins):
    """Combine per-spike position likelihoods into a normalized posterior per decode bin.

    For each bin id in ``dec_bin_ids``, the likelihood rows (columns ``x0`` ..
    ``x{pos_num_bins-1}``) of the spikes whose ``dec_bin`` equals that id are
    multiplied together, starting from a vector of ones. The running product is
    renormalized after every spike so that ``sum * pos_bin_delta == 1``, which also
    keeps it from underflowing when a bin holds many spikes.

    Parameters
    ----------
    dec_bin_ids : sorted sequence of int
        Decode bin ids to evaluate (e.g. ``np.unique(spike_decode['dec_bin'])``).
    spike_decode : pandas.DataFrame
        One row per spike with a ``dec_bin`` column and contiguous likelihood
        columns ``x0`` .. ``x{pos_num_bins-1}``.
    dec_bin_size : int
        Decode bin width in samples. Not needed for the computation; kept so
        existing calls keep working.
    pos_bin_delta : float
        Width of one position bin.
    pos_num_bins : int
        Number of position bins (likelihood columns).

    Returns
    -------
    numpy.ndarray of shape (dec_bin_ids[-1] + 1, pos_num_bins)
        Row ``b`` holds the posterior for bin ``b``. Rows for ids missing from
        ``dec_bin_ids`` stay all-zero; an id listed but with no spikes gets a row
        of ones. An empty ``dec_bin_ids`` yields a (0, pos_num_bins) array.
    """
    if len(dec_bin_ids) == 0:
        return np.zeros([0, pos_num_bins])

    dec_est = np.zeros([dec_bin_ids[-1] + 1, pos_num_bins])
    first_col, last_col = 'x0', 'x{:d}'.format(pos_num_bins - 1)

    # Group once rather than re-scanning the whole table with a boolean mask for
    # every bin (O(n_spikes * n_bins)). groupby keeps row order within each group,
    # so spikes are multiplied in their original order.
    likelihoods_by_bin = {
        bin_id: group.loc[:, first_col:last_col].values
        for bin_id, group in spike_decode.groupby('dec_bin')
    }

    for bin_id in dec_bin_ids:
        dec_in_bin = np.ones(pos_num_bins)
        for dec in likelihoods_by_bin.get(bin_id, ()):
            dec_in_bin *= dec
            dec_in_bin = dec_in_bin / (np.sum(dec_in_bin) * pos_bin_delta)
        dec_est[bin_id, :] = dec_in_bin
    return dec_est

dec_est = decode_from_spikes(dec_bin_ids=dec_bin_ids,
                             spike_decode=spike_decode,
                             dec_bin_size=dec_bin_size,
                             pos_bin_delta=pos_bin_delta,
                             pos_num_bins=pos_num_bins)

In [26]:
# Calculate MAP of decoder posterior and scale to real position units

# Start time (in samples) of each dec_est row: rows are decode bins laid out every
# dec_bin_size samples from the grid point at or before the first spike. Defined at
# notebook level because this cell and the posterior heatmap below index with it.
dec_start_bin_time = np.floor(spike_decode['timestamp'].iloc[0] / dec_bin_size) * dec_bin_size
dec_bin_times = dec_start_bin_time + dec_bin_size * np.arange(len(dec_est))

# argmax picks the most probable position bin; scale by the bin width and offset by the
# encoder's lower bound to express it in position units.
dec_est_map = (np.argmax(dec_est, axis=1) * pos_bin_delta
               + config['encoder']['position']['lower'])

dec_est_pos = pd.DataFrame({'est_pos': dec_est_map},
                           index=pd.Index(data=dec_bin_times, name='timestamp'))

In [6]:
# Get real position

# Build the nspike AnimalInfo from the config's simulator section, then read its
# position stream's data table.
animal_info_kwargs = config['simulator']['nspike_animal_info']
nspike_anim = nspike_data.AnimalInfo(**animal_info_kwargs)
pos = nspike_data.PosMatDataStream(nspike_anim)
pos_data = pos.data


In [27]:
# Transform position into simpler table with only linear position

def convert_pos_for_notebook(pos_data, sample_rate=30000):
    """Reduce the position table to the columns this notebook uses, indexed by sample time.

    Parameters
    ----------
    pos_data : pandas.DataFrame
        Position table with MultiIndex columns. Reads the 'lin_dist_well' column
        group, ('lin_vel', 'well_center'), ('seg_idx', 0) and 'time'.
    sample_rate : float, default 30000
        Factor converting 'time' into the 'timestamps' index (presumably seconds to
        30 kHz samples, matching the spike timestamps -- confirm).

    Returns
    -------
    pandas.DataFrame
        The 'lin_dist_well' columns plus 'lin_vel_center' and 'seg_idx', indexed by
        'timestamps' = time * sample_rate. ``pos_data`` itself is left unmodified.
    """
    # .copy(): the 'lin_dist_well' sub-frame is a slice of pos_data; adding columns to
    # it directly is chained assignment and raises SettingWithCopyWarning.
    pos_data_notebook = pos_data.loc[:, 'lin_dist_well'].copy()
    pos_data_notebook['lin_vel_center'] = pos_data.loc[:, ('lin_vel', 'well_center')]
    pos_data_notebook['seg_idx'] = pos_data.loc[:, ('seg_idx', 0)]
    pos_data_notebook['timestamps'] = pos_data.loc[:, 'time'] * sample_rate
    return pos_data_notebook.set_index('timestamps')

# Linear-position table indexed by sample timestamp.
pos_data_notebook = convert_pos_for_notebook(pos_data=pos_data)

In [8]:
# Bin position in the same way as spike dec results

# Bin edges sit on multiples of dec_bin_size (the same grid the decode bins use) and bin 0
# is the grid bin containing the first position sample. Bin k therefore starts exactly at
# start_bin_time + k * dec_bin_size, matching the pos_bin_times labels below.
pos_grid_bin = np.floor(pos_data_notebook.index / dec_bin_size)
pos_bin_ids = (pos_grid_bin - pos_grid_bin[0]).astype('int')
pos_data_notebook['bin'] = pos_bin_ids
pos_bin_ids_unique = np.unique(pos_bin_ids)

start_bin_time = np.floor(pos_data_notebook.index[0] / dec_bin_size) * dec_bin_size

pos_bin_times = (pos_bin_ids_unique * dec_bin_size + start_bin_time)

# Mean of every column within each bin. One groupby replaces the per-bin mask +
# DataFrame.append loop, which was quadratic; DataFrame.append was removed in pandas 2.0.
# groupby sorts the bin ids, so rows line up with pos_bin_times.
pos_data_bins = pos_data_notebook.groupby('bin').mean()
# Keep the (mean) bin id as the last column, as the per-bin Series.mean() rows did.
pos_data_bins['bin'] = pos_data_bins.index.astype('float')
pos_data_bins.index = pos_bin_times

In [17]:
# Inspect the per-bin mean position table (one row per position bin, indexed by pos_bin_times).
pos_data_bins

In [9]:
# Separate estimated and real position into separate arms of track, then convert position for both
# to be "well centric", distance measured from the well the real position is closest to.

# Center end is ~68
# Left end is ~101
# Right end is ~104
def conv_center_pos(pos, center_end=68, left_end=101, right_end=104, arm_width=150):
    """Convert a flattened linear position into distance from the center well.

    On the flattened axis the center arm occupies [0, arm_width), the left arm
    [arm_width, 2*arm_width) and the right arm [2*arm_width, 3*arm_width). Within
    each arm, position is measured from that arm's well, and ``*_end`` is the arm's
    far-end coordinate. Positions on another arm are measured back to that arm's
    end and then along the center arm (``center_end``).

    Returns NaN for positions >= 3 * arm_width and for NaN input.
    """
    if pos < arm_width:
        return pos
    if pos < 2 * arm_width:
        return (arm_width + left_end) - pos + center_end
    if pos < 3 * arm_width:
        # Previously "+65": inconsistent with the ~68 center-arm length used by every
        # other conversion, so center_end is used here as well.
        return (2 * arm_width + right_end) - pos + center_end
    return np.nan
def conv_left_pos(pos, center_end=68, left_end=101, right_end=104, arm_width=150):
    """Convert a flattened linear position into distance from the left well.

    Flattened layout: center arm on [0, arm_width), left arm on
    [arm_width, 2*arm_width), right arm on [2*arm_width, 3*arm_width), each
    measured from its own well, with ``*_end`` the arm's far-end coordinate.
    Positions on another arm are measured back to that arm's end and then along
    the left arm (``left_end``). The defaults reproduce the track constants used
    in this notebook.

    Returns NaN for positions >= 3 * arm_width and for NaN input.
    """
    if pos < arm_width:
        return center_end - pos + left_end
    if pos < 2 * arm_width:
        return pos - arm_width
    if pos < 3 * arm_width:
        return (2 * arm_width + right_end) - pos + left_end
    return np.nan
def conv_right_pos(pos, center_end=68, left_end=101, right_end=104, arm_width=150):
    """Convert a flattened linear position into distance from the right well.

    Flattened layout: center arm on [0, arm_width), left arm on
    [arm_width, 2*arm_width), right arm on [2*arm_width, 3*arm_width), each
    measured from its own well, with ``*_end`` the arm's far-end coordinate.
    Positions on another arm are measured back to that arm's end and then along
    the right arm (``right_end``). The defaults reproduce the track constants used
    in this notebook.

    Returns NaN for positions >= 3 * arm_width and for NaN input.
    """
    if pos < arm_width:
        return center_end - pos + right_end
    if pos < 2 * arm_width:
        return (arm_width + left_end) - pos + right_end
    if pos < 3 * arm_width:
        return pos - 2 * arm_width
    return np.nan

# Reindex and join real position (linpos) to the decode estimated position table.
# method='bfill' gives each decode bin the first position bin whose time label is at or
# after the decode bin's time label.
linpos_reindexed = pos_data_bins.reindex(dec_est_pos.index, method='bfill')
dec_est_and_linpos = dec_est_pos.join(linpos_reindexed)

# Select rows only when velocity meets criterion. With the threshold at 0 every row is
# kept except those whose velocity is NaN (NaN fails the >= comparison).
min_abs_lin_vel = 0
dec_est_and_linpos = dec_est_and_linpos[np.abs(dec_est_and_linpos['lin_vel_center']) >= min_abs_lin_vel]

# Separate out each arm's position: seg_idx 1 is the center arm, 2-3 the left arm and
# 4-5 the right arm.
# NOTE(review): seg_idx here comes from the per-bin means in pos_data_bins, so a bin that
# spans a segment change can hold a non-integer value and then matches no arm.
seg_idx_binned = dec_est_and_linpos['seg_idx']
center_dec_est_merged = dec_est_and_linpos[seg_idx_binned == 1]
left_dec_est_merged = dec_est_and_linpos[seg_idx_binned.isin([2, 3])]
right_dec_est_merged = dec_est_and_linpos[seg_idx_binned.isin([4, 5])]


def _well_centric_errors(arm_merged, conv_fn, real_col):
    """Build one arm's well-centric estimate/real table with decode-error columns.

    arm_merged : rows of the joined decode/position table belonging to this arm.
    conv_fn    : maps a flattened est_pos to distance from this arm's well.
    real_col   : column holding the real distance from that same well.

    Output columns:
      est_pos, real_pos  -- estimated and real distance from the well
      error              -- real_pos - est_pos
      abs_error          -- |error|
      plt_error_up       -- error where the estimate is below the real position, else 0
      plt_error_down     -- |error| where the estimate is above the real position, else 0
    matplotlib's errorbar reads yerr=[plt_error_up, plt_error_down] as [below, above],
    so each bar reaches from the real position to the estimate.
    """
    arm = pd.DataFrame({'est_pos': arm_merged['est_pos'].map(conv_fn),
                        'real_pos': arm_merged[real_col]})
    error = arm['real_pos'] - arm['est_pos']
    arm['error'] = error
    arm['abs_error'] = error.abs()
    arm['plt_error_up'] = error.clip(lower=0)
    arm['plt_error_down'] = error.clip(upper=0).abs()
    return arm


# Apply the "closest well centric" transform to each arm's data, then calculate the error
# in estimated position and the errors used to draw one-sided error bars.
center_dec_est = _well_centric_errors(center_dec_est_merged, conv_center_pos, 'well_center')
left_dec_est = _well_centric_errors(left_dec_est_merged, conv_left_pos, 'well_left')
right_dec_est = _well_centric_errors(right_dec_est_merged, conv_right_pos, 'well_right')

In [10]:
# Per-arm decode accuracy: median, then mean, of the absolute position error.
arm_abs_errors = (center_dec_est['abs_error'],
                  left_dec_est['abs_error'],
                  right_dec_est['abs_error'])
summary_specs = (
    ('median error center: {:0.5}, left: {:0.5}, right: {:.5}', np.median),
    ('mean error center: {:0.5}, left: {:0.5}, right: {:.5}', np.mean),
)
for template, reduce_fn in summary_specs:
    print(template.format(*[reduce_fn(errs) for errs in arm_abs_errors]))

In [11]:
def plot_decode_errors(arms, plt_ranges, figsize=(400, 10), sample_rate=30000):
    """Plot real position with error bars reaching to the decoded estimate, one figure per window.

    arms        : mapping of arm label -> well-centric table with 'real_pos',
                  'plt_error_up' and 'plt_error_down' columns, indexed by timestamp
                  in samples.
    plt_ranges  : iterable of [start, stop] time windows in seconds (inclusive).
    figsize     : matplotlib figure size in inches (very wide by default so
                  individual decode bins stay distinguishable).
    sample_rate : samples per second, used to convert the timestamp index to seconds.

    Returns the list of created figures.
    """
    figures = []
    for t_start, t_stop in plt_ranges:
        fig, ax = plt.subplots(figsize=figsize)
        for label, arm in arms.items():
            t_sec = arm.index / sample_rate
            in_range = (t_sec >= t_start) & (t_sec <= t_stop)
            # yerr rows are [below, above]; see the plt_error_* column definitions.
            ax.errorbar(x=t_sec[in_range],
                        y=arm['real_pos'][in_range],
                        yerr=[arm['plt_error_up'][in_range],
                              arm['plt_error_down'][in_range]],
                        fmt='*', label=label)
        ax.set(xlabel='time (s)', ylabel='distance from closest well',
               title='Real position with decode error, {}-{} s'.format(t_start, t_stop))
        ax.legend()
        figures.append(fig)
    return figures


plt_ranges = [[2350, 3400], [4560, 5550]]
plot_decode_errors({'center': center_dec_est, 'left': left_dec_est, 'right': right_dec_est},
                   plt_ranges)
plt.show()

In [12]:
# Zoomed views: the second recording window split into two halves.
plt_ranges = [[4560, 5000], [5000, 5550]]
for plt_range in plt_ranges:
    t_lo, t_hi = plt_range
    plt.figure(figsize=[400,10])
    # Same arm order as before (center, left, right) so the colour cycle is unchanged.
    for arm_est in (center_dec_est, left_dec_est, right_dec_est):
        arm_sec = arm_est.index / 30000
        in_window = (arm_sec >= t_lo) & (arm_sec <= t_hi)
        plt.errorbar(x=arm_est.index[in_window] / 30000,
                     y=arm_est['real_pos'][in_window],
                     yerr=[arm_est['plt_error_up'][in_window],
                           arm_est['plt_error_down'][in_window]],
                     fmt='*')

plt.show()

In [13]:


# Lay the three arms' well distances side by side on one flattened axis:
# center arm unchanged, left arm shifted by 150 and right arm shifted by 300.
pos_seg_idx = pos_data_notebook['seg_idx']

center_pos_flat = pos_data_notebook.loc[pos_seg_idx == 1, 'well_center'].rename('linpos_flat')
left_pos_flat = (pos_data_notebook.loc[pos_seg_idx.isin([2, 3]), 'well_left'] + 150).rename('linpos_flat')
right_pos_flat = (pos_data_notebook.loc[pos_seg_idx.isin([4, 5]), 'well_right'] + 300).rename('linpos_flat')

linpos_flat = pd.concat([center_pos_flat, left_pos_flat, right_pos_flat]).sort_index()
linpos_flat

In [18]:
# Posterior heatmap for one time window (seconds) with the real flattened position on top.
plt_range = [5300, 5450]
# The position axis spans the encoder's configured range: each of the pos_num_bins
# posterior rows is pos_bin_delta tall, so the image covers [lower, upper] instead of
# a fixed 0-450.
pos_lower = config['encoder']['position']['lower']
pos_upper = config['encoder']['position']['upper']

plt.figure(figsize=[400,15])
dec_in_range = (dec_bin_times > plt_range[0]*30000) & (dec_bin_times < plt_range[1]*30000)
plt.imshow(dec_est[dec_in_range].transpose(),
           extent=[plt_range[0], plt_range[1], pos_lower, pos_upper],
           origin='lower', aspect='auto', cmap='hot', zorder=0)

linpos_index_s = linpos_flat.index / 30000
index_mask = (linpos_index_s > plt_range[0]) & (linpos_index_s < plt_range[1])

plt.plot(linpos_index_s[index_mask],
         linpos_flat.values[index_mask], 'c.', zorder=1, markersize=5)

# Dashed gray guide lines. NOTE(review): presumably approximate arm-end / arm-offset
# positions on the flattened axis (cf. the ~68/101/104 arm ends and 150/300 offsets)
# -- confirm.
for guide_pos in (74, 148, 256, 298, 407):
    plt.plot(plt_range, [guide_pos, guide_pos], '--', color='gray')

plt.xlabel('time (s)')
plt.ylabel('flattened linear position')
plt.colorbar(label='posterior density')
plt.show()


In [15]:
# Time (s) and flattened real position of the samples overlaid on the heatmap window.
# Shown as one table: previously only the last bare expression was displayed, so the
# time values were silently discarded.
pd.DataFrame({'time_s': linpos_index_s[index_mask],
              'linpos_flat': linpos_flat.values[index_mask]})