In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import json
import os
from functools import partial

import spykshrk.realtime.simulator.nspike_data as nspike_data

from spykshrk.franklab.pp_decoder.util import gaussian
from spykshrk.franklab.pp_decoder.pp_clusterless import calc_observation_intensity
from spykshrk.franklab.pp_decoder.decode_error import bin_pos_data, calc_error_table, \
                                                      conv_center_pos, conv_left_pos, conv_right_pos, \
                                                      plot_arms_error

# pandas display configuration for this notebook.
# Alternatives kept for reference (currently disabled):
#pd.set_option('float_format', '{:,.2f}'.format)
#pd.set_option('display.width', 180)
DISPLAY_OPTIONS = {
    'display.precision': 4,
    'display.max_rows': 10,
}
for option_name, option_value in DISPLAY_OPTIONS.items():
    pd.set_option(option_name, option_value)

# Shorthand for building multi-level slices.
idx = pd.IndexSlice

In [2]:
# Load merged rec HDF store based on config

# NOTE(review): absolute, machine-specific path; the notebook only runs where
# this file exists.
config_file = '/opt/data36/daliu/realtime/spykshrk/dec_60uv_300samp/bond.config.json'

# Read the JSON config inside a context manager so the file handle is closed
# as soon as parsing finishes (a bare open() passed to json.load is never
# closed explicitly).
with open(config_file, 'r') as config_fh:
    config = json.load(config_fh)

# The merged record file lives in the configured output directory and is named
# '<prefix>.rec_merged.h5'.
hdf_file = os.path.join(config['files']['output_dir'],
                        '{}.rec_merged.h5'.format(config['files']['prefix']))

# Opened read-only. `store` is left open so further keys can still be read
# from it; call store.close() once it is no longer needed.
store = pd.HDFStore(hdf_file, mode='r')
spike_decode = store['rec_3']

In [3]:
# Get table with decode for each spike and generate decode bin mask

dec_bin_size = 6000     # Decode bin size in samples (usually 30kHz)

# One [start, end] pair per arm, in the order center, left, right.
# Center end is 69
# Left end is 102
# Right end is 104
# The left and right arms are offset by 150 and 300 respectively.
arm_coordinates = [[0, 69], [150, 150+102], [300, 300+104]]

# Decode-bin number of every spike, counted from the first row's timestamp.
# .iloc[0] picks that row by position; plain [0] is a label lookup and only
# works when the index loaded from the HDF store happens to contain a unique
# label 0 on the first row.
first_timestamp = spike_decode['timestamp'].iloc[0]
dec_bins = np.floor((spike_decode['timestamp'] - first_timestamp)/dec_bin_size).astype('int')
dec_bin_ids = np.unique(dec_bins)

# NOTE: adds a 'dec_bin' column to spike_decode in place.
spike_decode['dec_bin'] = dec_bins

# Width of one position bin: the configured position range divided by the
# configured number of bins.
pos_num_bins = config['encoder']['position']['bins']
pos_bin_delta = ((config['encoder']['position']['upper'] - config['encoder']['position']['lower']) /
                 pos_num_bins)

# pos_num_bins evenly spaced values starting at 0, spaced pos_bin_delta apart.
x_bins = np.linspace(0, pos_bin_delta*(pos_num_bins-1), pos_num_bins)
# Kernel evaluated on x_bins and centred on the middle bin.
# NOTE(review): the trailing 3 is presumably the kernel width (std) in x_bins
# units; confirm against spykshrk's gaussian().
pos_kernel = gaussian(x_bins, x_bins[len(x_bins)//2], 3)


In [4]:
# Loop through each bin and generate the observation distribution from the
# spikes in that bin, giving a normalized posterior estimate of location per bin.
(dec_bin_times,
 dec_est,
 bin_num_spikes,
 firing_rate) = calc_observation_intensity(
    spike_decode, dec_bin_size, x_bins, pos_kernel, arm_coordinates)

In [5]:
# Calculate MAP of decoder posterior and scale to real position units

# Index of the largest entry along axis 1 of dec_est, then the x_bins value
# found at that index.
map_bin_idx = np.argmax(dec_est, axis=1)
dec_est_map = x_bins[map_bin_idx]

# One estimated position per decode bin, indexed by the bin's timestamp.
dec_est_pos = pd.DataFrame(data={'est_pos': dec_est_map},
                           index=pd.Index(dec_bin_times, name='timestamp'))

In [6]:
# Get real position

# AnimalInfo is built from the keyword arguments stored under
# config['simulator']['nspike_animal_info'].
nspike_anim = nspike_data.AnimalInfo(**config['simulator']['nspike_animal_info'])
# PosMatDataStream is constructed from that animal description; its `.data`
# attribute is the position table consumed by the cells below.
pos = nspike_data.PosMatDataStream(nspike_anim)
pos_data = pos.data


In [68]:
# Transform position into simpler table with only linear position

def convert_pos_for_notebook(pos_data, sampling_rate=30000):
    """Flatten the multi-level position table into a table indexed by timestamp.

    Parameters
    ----------
    pos_data : pandas.DataFrame
        Position table with two-level columns. The columns read here are
        everything under 'lin_dist_well', plus ('lin_vel', 'well_center'),
        ('seg_idx', 0) and 'time'.
    sampling_rate : number, optional
        Factor the 'time' column is multiplied by to obtain the 'timestamps'
        index. Defaults to 30000 (presumably seconds -> samples at 30 kHz;
        confirm against the data source).

    Returns
    -------
    pandas.DataFrame
        The 'lin_dist_well' sub-columns plus 'lin_vel_center' and 'seg_idx',
        indexed by 'timestamps'. `pos_data` itself is not modified.
    """
    # Work on an explicit copy: adding columns to a .loc slice of the input is
    # chained assignment and triggers pandas' SettingWithCopyWarning.
    pos_data_notebook = pos_data.loc[:, 'lin_dist_well'].copy()

    pos_data_notebook.loc[:, 'lin_vel_center'] = pos_data.loc[:, ('lin_vel', 'well_center')]
    pos_data_notebook.loc[:, 'seg_idx'] = pos_data.loc[:, ('seg_idx', 0)]
    pos_data_notebook.loc[:, 'timestamps'] = pos_data.loc[:, 'time'] * sampling_rate

    return pos_data_notebook.set_index('timestamps')
  
# Flattened position table for the loaded session, indexed by 'timestamps'.
pos_data_notebook = convert_pos_for_notebook(pos_data)

In [78]:
# Bin position in the same way as spike dec results

# dec_bin_size is the same bin width (in samples) used for the spike decode bins.
pos_data_bins = bin_pos_data(pos_data_notebook, dec_bin_size)

In [60]:
# Inspect the binned position table (row count is capped by the
# display.max_rows option set at the top of the notebook).
pos_data_bins

In [58]:
# Same expression as the previous cell (duplicate display of pos_data_bins);
# candidate for removal when the notebook is cleaned up.
pos_data_bins

In [13]:
# Show the [start, end] pair of each arm, listed in the order center, left, right.
arm_coordinates

In [9]:
# Separate estimated and real position into separate arms of track, then convert position for both
# to be "well centric", distance measured from the well the real position is closest to.

# Returns one error table per arm (center, left, right); the cells below read
# each table's 'abs_error' column.
# NOTE(review): the trailing 0 is passed positionally; its meaning is defined
# by calc_error_table, so confirm against that function's signature.
center_dec_error, left_dec_error, right_dec_error = calc_error_table(pos_data_bins, dec_est_pos,
                                                                     arm_coordinates, 0)

In [15]:
# Summarise the absolute decode error of each arm; tables are ordered
# center, left, right to match the placeholders in the messages below.
error_tables = (center_dec_error, left_dec_error, right_dec_error)

median_abs_errors = [np.median(table['abs_error']) for table in error_tables]
print('median error center: {:0.5}, left: {:0.5}, right: {:.5}'.format(*median_abs_errors))

mean_abs_errors = [np.mean(table['abs_error']) for table in error_tables]
print('mean error center: {:0.5}, left: {:0.5}, right: {:.5}'.format(*mean_abs_errors))

In [57]:
# Scratch: try resampling the binned position onto a regular 30000-sample grid.

# Label every position sample with a bin id, using bins 30 index units wide
# measured from the first sample.
# NOTE(review): this width (30) differs from the 30000 step used for the grid
# below; confirm which one is intended.
# NOTE: adds a 'bin' column to pos_data_notebook in place.
pos_bin_ids = np.floor((pos_data_notebook.index - pos_data_notebook.index[0])/30).astype('int')
pos_data_notebook['bin'] = pos_bin_ids
pos_bin_ids_unique = np.unique(pos_bin_ids)

# Regular grid from the first index value up to (and possibly including) the
# last, in steps of 30000.
# An earlier np.linspace(...) assignment to the same name was removed: it
# passed a float as `num`, which raises TypeError on current NumPy, and its
# result was overwritten by this line before ever being used.
pos_data_new_times = np.arange(pos_data_notebook.index[0], pos_data_notebook.index[-1]+1, 30000)

# For each grid time, take the row of pos_data_bins whose index is nearest.
pos_data_bins.reindex(pos_data_new_times, method='nearest')


In [65]:
# Positions (as returned by np.nonzero) of the index values that are exact
# multiples of 30000.
np.nonzero(pos_data_notebook.index % 30000 == 0)

In [77]:
# Scratch arithmetic on the first index value.
# NOTE(review): the offset adds 30000 but takes the remainder modulo 30. If the
# goal is the next multiple of 30000, the modulus should probably be 30000 too;
# confirm the intent before relying on this value.
pos_data_notebook.index[0] + (30000 - pos_data_notebook.index[0] % 30) 

In [47]:
# Third entry of the regular time grid (pos_data_new_times).
pos_data_new_times[2]

In [31]:
# Last index value of pos_data_notebook, i.e. its final 'timestamps' entry.
pos_data_notebook.index[-1]

In [14]:
# Draw one plot_arms_error figure per [start, end] range over the three
# per-arm error tables.
# NOTE(review): the units of the range values are not visible in this notebook
# (presumably seconds); confirm against plot_arms_error.
plt_ranges = [[2350, 3400], [4560, 5550]]
for plt_range in plt_ranges:
    plot_arms_error(center_dec_error, left_dec_error, right_dec_error, plt_range)

    
plt.show()