In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import json
import os
from functools import partial

import spykshrk.realtime.simulator.nspike_data as nspike_data

from spykshrk.franklab.pp_decoder.util import gaussian
from spykshrk.franklab.pp_decoder.pp_clusterless import calc_observation_intensity
from spykshrk.franklab.pp_decoder.decode_error import bin_pos_data, calc_error_table, \
                                                      conv_center_pos, conv_left_pos, conv_right_pos

# Notebook-wide pandas display settings: 4 digits after the decimal point,
# and long frames truncated to at most 10 displayed rows.
for _opt_name, _opt_value in (('display.precision', 4), ('display.max_rows', 10)):
    pd.set_option(_opt_name, _opt_value)

# Shorthand for MultiIndex slicing, e.g. frame.loc[idx[:, 'a'], :]
idx = pd.IndexSlice

In [2]:
# Load merged rec HDF store based on config.
# NOTE(review): absolute, machine-specific path — adjust for your environment.
config_file = '/opt/data36/daliu/realtime/spykshrk/dec_60uv_300samp/bond.config.json'
with open(config_file, 'r') as config_fh:   # close the config file as soon as it is parsed
    config = json.load(config_fh)

hdf_file = os.path.join(config['files']['output_dir'],
                        '{}.rec_merged.h5'.format(config['files']['prefix']))

# Read-only store; it is not closed in this cell — call store.close() when finished with it.
store = pd.HDFStore(hdf_file, mode='r')
spike_decode = store['rec_3']

In [3]:
# Get table with decode for each spike and generate decode bin mask

dec_bin_size = 6000     # Decode bin size in samples (usually 30kHz)

# Linearized track layout: [start, end] of each arm.
# Center end is 69
# Left end is 102
# Right end is 104
arm_coordinates = [[0, 69], [150, 150+102], [300, 300+104]]

# Assign each spike to a decode bin counted from the first spike's timestamp.
# .iloc[0] is positional: a plain ['timestamp'][0] is a *label* lookup that raises
# KeyError (or picks the wrong row) when the table's index is not a 0-based RangeIndex.
first_timestamp = spike_decode['timestamp'].iloc[0]
dec_bins = np.floor((spike_decode['timestamp'] - first_timestamp) / dec_bin_size).astype('int')
dec_bin_ids = np.unique(dec_bins)

spike_decode['dec_bin'] = dec_bins

# Position discretisation taken from the encoder config.
pos_config = config['encoder']['position']
pos_num_bins = pos_config['bins']
pos_bin_delta = (pos_config['upper'] - pos_config['lower']) / pos_num_bins

# Position bin values: pos_num_bins points spaced pos_bin_delta apart, starting at 0.
# NOTE(review): the config's 'lower' bound is not added here, so this assumes lower == 0.
x_bins = np.linspace(0, pos_bin_delta*(pos_num_bins-1), pos_num_bins)
# Gaussian kernel centred on the middle bin (third argument presumably the std. dev. — confirm).
pos_kernel = gaussian(x_bins, x_bins[len(x_bins) // 2], 3)


In [4]:
# Loop through each bin and generate the observation distribution from the spikes in it.
# The four outputs are unpacked, in order, as: decode-bin times, the per-bin estimate
# (one row per decode bin, one column per x_bins entry — that is how the MAP cell reads it),
# spike count per bin, and firing rate. Meanings beyond that are inferred from the names;
# see calc_observation_intensity for the authoritative contract.
(dec_bin_times,
 dec_est,
 bin_num_spikes,
 firing_rate) = calc_observation_intensity(spike_decode, dec_bin_size, x_bins, pos_kernel, arm_coordinates)

In [5]:
# MAP estimate per decode bin: the x_bins position whose column in dec_est is largest,
# expressed in the same position units as x_bins.
dec_est_map = x_bins[np.argmax(dec_est, axis=1)]

# One-column table ('est_pos') indexed by the decode bin timestamp.
dec_est_pos = pd.Series(dec_est_map,
                        index=pd.Index(data=dec_bin_times, name='timestamp'),
                        name='est_pos').to_frame()

In [6]:
# Load the animal's real (recorded) position through the nspike data reader,
# configured from the simulator section of the config.
animal_info_kwargs = config['simulator']['nspike_animal_info']
nspike_anim = nspike_data.AnimalInfo(**animal_info_kwargs)
pos = nspike_data.PosMatDataStream(nspike_anim)
pos_data = pos.data   # raw position table, flattened by convert_pos_for_notebook


In [7]:
# Transform position into simpler table with only linear position

def convert_pos_for_notebook(pos_data, sample_rate=30000):
    """Flatten the nspike position table into a linear-position table keyed by timestamp.

    Parameters
    ----------
    pos_data : pandas.DataFrame
        Position table with MultiIndex columns. It must contain ``'time'``, the
        ``'lin_dist_well'`` column group, ``('lin_vel', 'well_center')`` and
        ``('seg_idx', 0)``. It is not modified.
    sample_rate : int or float, optional
        Multiplier turning ``'time'`` into sample timestamps (default 30000,
        presumably seconds -> 30 kHz samples).

    Returns
    -------
    pandas.DataFrame
        The ``'lin_dist_well'`` sub-columns plus ``'lin_vel_center'`` and
        ``'seg_idx'``, indexed by ``'timestamps'`` (= time * sample_rate).
    """
    pos_data_time = pos_data.loc[:, 'time']

    # Explicit copy: the column-group selection can share data with pos_data, and adding
    # columns to it without a copy triggers pandas' SettingWithCopyWarning.
    pos_data_notebook = pos_data.loc[:, 'lin_dist_well'].copy()
    pos_data_notebook['lin_vel_center'] = pos_data.loc[:, ('lin_vel', 'well_center')]
    pos_data_notebook['seg_idx'] = pos_data.loc[:, ('seg_idx', 0)]
    pos_data_notebook['timestamps'] = pos_data_time * sample_rate

    return pos_data_notebook.set_index('timestamps')

# Flattened linear-position table (well distances, center velocity, segment id) indexed by timestamp.
pos_data_notebook = convert_pos_for_notebook(pos_data)

In [8]:
# Bin position in the same way as spike dec results: bin_pos_data receives the same
# dec_bin_size (in samples) that was used to bin the spikes.

pos_data_bins = bin_pos_data(pos_data_notebook, dec_bin_size)

In [9]:
# Inspect the per-bin MAP position estimates (display truncated by display.max_rows).
dec_est_pos

In [10]:
# Separate estimated and real position into separate arms of track, then convert position for both
# to be "well centric", distance measured from the well the real position is closest to.

# Align binned real position onto the decode-bin timestamps: each decode timestamp takes the
# next available position row (backfill), and those columns are attached to the estimates.
linpos_reindexed = pos_data_bins.reindex(index=dec_est_pos.index, method='bfill')
dec_est_and_linpos = dec_est_pos.join(linpos_reindexed, how='left')

# Per-arm error tables (center, left, right).
# NOTE(review): the trailing 0 is passed straight to calc_error_table — document its meaning.
center_dec_est, left_dec_est, right_dec_est = calc_error_table(dec_est_and_linpos, arm_coordinates, 0)

In [11]:
# Summarise the absolute decode error per arm: one line for the median, one for the mean.
arm_error_tables = (center_dec_est, left_dec_est, right_dec_est)
summary_specs = (
    ('median error center: {:0.5}, left: {:0.5}, right: {:.5}', np.median),
    ('mean error center: {:0.5}, left: {:0.5}, right: {:.5}', np.mean),
)
for line_template, reduce_fn in summary_specs:
    print(line_template.format(*[reduce_fn(tbl['abs_error']) for tbl in arm_error_tables]))

In [12]:
# Real position of each track arm with the decode error drawn as error bars,
# one wide figure per time window (window bounds in seconds).
plt_ranges = [[2350, 3400], [4560, 5550]]
arm_dec_tables = [('center', center_dec_est), ('left', left_dec_est), ('right', right_dec_est)]
for plt_range in plt_ranges:
    plt.figure(figsize=[400, 10])   # 400 x 10 in: very wide figure for the long time axis
    for arm_name, arm_est in arm_dec_tables:
        # Index holds sample timestamps; /30000 converts to seconds (assumes 30 kHz sampling).
        arm_time_s = arm_est.index / 30000
        in_range = (arm_time_s >= plt_range[0]) & (arm_time_s <= plt_range[1])
        # NOTE(review): matplotlib draws a 2-row yerr as [below, above], so 'plt_error_up'
        # is rendered *below* each point — confirm this matches calc_error_table's convention.
        plt.errorbar(x=arm_time_s[in_range],
                     y=arm_est['real_pos'][in_range],
                     yerr=[arm_est['plt_error_up'][in_range],
                           arm_est['plt_error_down'][in_range]],
                     fmt='*', label=arm_name)
    plt.xlabel('time (s)')
    plt.ylabel('real position')
    plt.title('Decode error per arm, {}-{} s'.format(*plt_range))
    plt.legend()

plt.show()

In [13]:
# Real position of each track arm with the decode error drawn as error bars,
# one wide figure per time window (window bounds in seconds).
plt_ranges = [[4560, 5000], [5000, 5550]]
arm_dec_tables = [('center', center_dec_est), ('left', left_dec_est), ('right', right_dec_est)]
for plt_range in plt_ranges:
    plt.figure(figsize=[400, 10])   # 400 x 10 in: very wide figure for the long time axis
    for arm_name, arm_est in arm_dec_tables:
        # Index holds sample timestamps; /30000 converts to seconds (assumes 30 kHz sampling).
        arm_time_s = arm_est.index / 30000
        in_range = (arm_time_s >= plt_range[0]) & (arm_time_s <= plt_range[1])
        # NOTE(review): matplotlib draws a 2-row yerr as [below, above], so 'plt_error_up'
        # is rendered *below* each point — confirm this matches calc_error_table's convention.
        plt.errorbar(x=arm_time_s[in_range],
                     y=arm_est['real_pos'][in_range],
                     yerr=[arm_est['plt_error_up'][in_range],
                           arm_est['plt_error_down'][in_range]],
                     fmt='*', label=arm_name)
    plt.xlabel('time (s)')
    plt.ylabel('real position')
    plt.title('Decode error per arm, {}-{} s'.format(*plt_range))
    plt.legend()

plt.show()

In [14]:


# Build one "flattened" linear-position trace: each arm's distance-from-well column is
# shifted by that arm's start in arm_coordinates (0, 150, 300) so the arms occupy disjoint
# ranges. Segment ids -> arm: 1 = center, 2/3 = left, 4/5 = right; rows with any other
# seg_idx are dropped.
seg_ids = pos_data_notebook['seg_idx']
center_pos_flat = (pos_data_notebook.loc[seg_ids == 1, 'well_center']
                   + arm_coordinates[0][0]).rename('linpos_flat')
left_pos_flat = (pos_data_notebook.loc[seg_ids.isin([2, 3]), 'well_left']
                 + arm_coordinates[1][0]).rename('linpos_flat')
right_pos_flat = (pos_data_notebook.loc[seg_ids.isin([4, 5]), 'well_right']
                  + arm_coordinates[2][0]).rename('linpos_flat')

linpos_flat = pd.concat([center_pos_flat, left_pos_flat, right_pos_flat]).sort_index()
linpos_flat

In [15]:
# Heat map of the per-bin decode estimate over time, with the flattened real position overlaid.
plt_range = [2461, 3400]   # seconds
# Decode-bin times are sample timestamps; *30000 converts the window (assumes 30 kHz sampling).
bin_mask = (dec_bin_times > plt_range[0] * 30000) & (dec_bin_times < plt_range[1] * 30000)

# Each row of the transposed estimate is one x_bins entry, pos_bin_delta tall, so the y axis
# spans pos_bin_delta * pos_num_bins. This used to be hard-coded as 450 and silently
# mis-scaled the image for any other encoder position range.
pos_axis_top = pos_bin_delta * pos_num_bins

plt.figure(figsize=[400, 15])
plt.imshow(dec_est[bin_mask].transpose(),
           extent=[plt_range[0], plt_range[1], 0, pos_axis_top],
           origin='lower', aspect='auto', cmap='hot', zorder=0)

linpos_index_s = linpos_flat.index / 30000
index_mask = (linpos_index_s > plt_range[0]) & (linpos_index_s < plt_range[1])

plt.plot(linpos_index_s[index_mask],
         linpos_flat.values[index_mask], 'c.', zorder=1, markersize=5)

# Horizontal reference lines on the flattened position axis.
# NOTE(review): these do not coincide with arm_coordinates ([0, 69], [150, 252], [300, 404]);
# confirm which boundaries they are meant to mark.
for ref_pos in (74, 148, 256, 298, 407):
    plt.plot(plt_range, [ref_pos, ref_pos], '--', color='gray')

plt.xlabel('time (s)')
plt.ylabel('flattened linear position')
plt.colorbar()
plt.show()


In [16]:
# Decode bin times converted to seconds (assumes 30 kHz sample timestamps).
dec_bin_times/30000