In [15]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import json
import os
import scipy.signal

from spykshrk.realtime.decoder_process import PointProcessDecoder

import spykshrk.realtime.simulator.nspike_data as nspike_data

from spykshrk.franklab.pp_decoder.util import gaussian, normal2D, apply_no_anim_boundary
from spykshrk.franklab.pp_decoder.pp_clusterless import calc_learned_state_trans_mat, calc_simple_trans_mat, \
                                                        calc_uniform_trans_mat, \
                                                        calc_observation_intensity, calc_likelihood, \
                                                        calc_occupancy, calc_prob_no_spike, \
                                                        calc_posterior, plot_decode_2d

from spykshrk.franklab.pp_decoder.decode_error import bin_pos_data, convert_pos_for_notebook, calc_error_table, \
                                                      plot_arms_error

# Notebook-wide pandas display settings.
pd.options.display.precision = 4
pd.options.display.max_rows = 10


# Shorthand for building MultiIndex slices inside .loc[] calls.
idx = pd.IndexSlice


In [16]:
%%time
# Load merged rec HDF store based on config

config_file = '/opt/data36/daliu/realtime/spykshrk/dec_60uv_300samp/bond.config.json'
config = json.load(open(config_file, 'r'))

hdf_file = os.path.join(config['files']['output_dir'],
                        '{}.rec_merged.h5'.format(config['files']['prefix']))

store = pd.HDFStore(hdf_file, mode='r')
spike_decode = store['rec_3']
stim_lockout = store['rec_11']

In [17]:
%%time
stim_lockout_ranges = stim_lockout.pivot(index='lockout_num',columns='lockout_state', values='timestamp')
stim_lockout_ranges = stim_lockout_ranges.reindex(columns=[1,0])


In [18]:
%%time
# Get table with decode for each spike and generate decode bin mask

dec_bin_size = 300     # Decode bin size in samples (usually 30kHz)

arm_coordinates = [[0, 69], [150, 150+102], [300, 300+104]]

dec_bins = np.floor((spike_decode['timestamp'] - spike_decode['timestamp'][0])/dec_bin_size).astype('int')
spike_decode['dec_bin'] = dec_bins


pos_upper = config['encoder']['position']['upper']
pos_lower = config['encoder']['position']['lower']
pos_num_bins = config['encoder']['position']['bins']
pos_bin_delta = ((pos_upper - pos_lower) / pos_num_bins)

x_bins = np.linspace(0, pos_bin_delta*(pos_num_bins-1), pos_num_bins)
x_bin_edges = np.linspace(0, pos_bin_delta*(pos_num_bins), pos_num_bins+1)

pos_kernel = gaussian(x_bins, x_bins[int(len(x_bins)/2)], 3)
 

In [19]:
# Real (tracked) animal position, read through the nspike simulator's data stream
# using the animal description stored in the run config.
animal_info_kwargs = config['simulator']['nspike_animal_info']
nspike_anim = nspike_data.AnimalInfo(**animal_info_kwargs)

pos = nspike_data.PosMatDataStream(nspike_anim)
pos_data = pos.data


In [20]:
# Transform position into a simpler table holding only linear-position columns, indexed by timestamp.
# assumes pos_data has MultiIndex columns with top-level keys 'time', 'lin_dist_well',
# 'lin_vel', 'seg_idx' (as accessed below) — TODO confirm.
pos_data_time = pos_data.loc[:, 'time']

# Explicit .copy(): the columns added below must go onto an independent frame, not onto a
# slice of pos_data (avoids chained assignment / SettingWithCopyWarning).
pos_data_linpos = pos_data.loc[:, 'lin_dist_well'].copy()
pos_data_linpos.loc[:, 'lin_vel_center'] = pos_data.loc[:, ('lin_vel', 'well_center')]
pos_data_linpos.loc[:, 'seg_idx'] = pos_data.loc[:, ('seg_idx', 0)]
# Time * 30000 -> sample-based timestamps (assumes 'time' is in seconds and a 30 kHz clock).
pos_data_linpos.loc[:, 'timestamps'] = pos_data_time*30000
pos_data_linpos = pos_data_linpos.set_index('timestamps')

# Convert real pos to the realtime system's single linear coordinate: each arm's
# distance-from-well is offset by that arm's start (center +0, left +150, right +300),
# matching the starts in arm_coordinates.
linpos_seg = pos_data_linpos['seg_idx']
center_pos_flat = pos_data_linpos.loc[linpos_seg == 1, 'well_center'].rename('linpos_flat')
left_pos_flat = (pos_data_linpos.loc[linpos_seg.isin([2, 3]), 'well_left'] + 150).rename('linpos_flat')
right_pos_flat = (pos_data_linpos.loc[linpos_seg.isin([4, 5]), 'well_right'] + 300).rename('linpos_flat')

linpos_flat = pd.concat([center_pos_flat, left_pos_flat, right_pos_flat]).sort_index()

In [21]:
# Inspect the spykshrk.realtime decoder's transition matrix.
# The previous call passed `arange`, which is not defined anywhere in this notebook
# (NameError on a fresh kernel); use the same arguments as the transition-matrix plot cell.
PointProcessDecoder._create_transition_matrix(1, 150)

In [22]:
# Display the flattened linear-position series (output truncated to display.max_rows rows).
linpos_flat

In [23]:
%%time

# Offline (pp_clusterless) decode pipeline: transition matrices -> per-bin observation
# intensity -> probability of no spike -> likelihoods -> posteriors.

# Calculate transition matrix
# NOTE(review): only learned_trans_mat is passed to calc_posterior below; transition_mat and
# uniform_transition_mat are computed but not used in any visible cell.
learned_trans_mat = calc_learned_state_trans_mat(linpos_flat, x_bins, arm_coordinates, gauss_smooth_std=3, 
                                                 uniform_offset_gain=0.001)
transition_mat = calc_simple_trans_mat(x_bins)
uniform_transition_mat = calc_uniform_trans_mat(x_bins)
# Loop through each bin and generate the observation distribution from spikes in bin
dec_bin_times, dec_est, bin_num_spikes, firing_rate  = calc_observation_intensity(spike_decode, 
                                                                                  dec_bin_size,
                                                                                  x_bins, 
                                                                                  pos_kernel,
                                                                                  arm_coordinates)

#Precompute prob of no spike from firing rate
occupancy = calc_occupancy(linpos_flat, x_bin_edges, pos_kernel)
prob_no_spike = calc_prob_no_spike(firing_rate, dec_bin_size, occupancy)

# Compute the likelihood of each bin
likelihoods = calc_likelihood(dec_est, bin_num_spikes, prob_no_spike, pos_bin_delta)

# Iteratively calculate posterior
posteriors = calc_posterior(likelihoods, learned_trans_mat, pos_num_bins, pos_bin_delta)

In [24]:
%%time
# Run spykshrk.realtime version of point process decoding

pp_decoder = PointProcessDecoder(pos_range=[config['encoder']['position']['lower'],
                                            config['encoder']['position']['upper']],
                                 pos_bins=config['encoder']['position']['bins'],
                                 time_bin_size=config['pp_decoder']['bin_size'])

pp_decoder.select_ntrodes(config['simulator']['nspike_animal_info']['tetrodes'])

num_time_bins = spike_decode.loc[:,'dec_bin'].max()

groups = spike_decode.groupby('dec_bin')

last_bin_id = 0

spykshrk_posteriors = np.zeros([num_time_bins+1, pos_num_bins])

for bin_id, spikes_in_bin in groups:
    if last_bin_id <= bin_id - 1:
        for bin_no_spk_id in range(last_bin_id + 1, bin_id):
            posterior = pp_decoder.increment_no_spike_bin()
            spykshrk_posteriors[bin_no_spk_id, :] = posterior
        
    for ntrode_id, dec in zip(spikes_in_bin.loc[:, 'ntrode_id'].values, 
                   spikes_in_bin.loc[:, 'x0': 'x{:d}'.format(pos_num_bins-1)].values):
        pp_decoder.add_observation(ntrode_id, dec)
        
    posterior = pp_decoder.increment_bin()
    spykshrk_posteriors[bin_id, :] = posterior
    last_bin_id = bin_id
    

In [60]:
# Compare the two decoders over time windows: top panel = spykshrk.realtime posteriors,
# bottom panel = offline pp_clusterless posteriors.
# Window units presumably match dec_bin_times — confirm.
# Earlier window of interest, kept for reference: [[2461 + 800, 2461 + 900]]
# (it was previously assigned and immediately overwritten, i.e. never plotted).
plt_ranges = [[2690, 2700]]
for plt_range in plt_ranges:
    # NOTE(review): 400 x 20 inch figure — very wide, presumably to resolve individual
    # decode bins; confirm this is intended.
    plt.figure(figsize=[400, 20])
    plt.subplot(2, 1, 1)
    plot_decode_2d(spykshrk_posteriors, dec_bin_times, stim_lockout_ranges, linpos_flat, plt_range, 1.0)
    plt.subplot(2, 1, 2)
    plot_decode_2d(posteriors, dec_bin_times, stim_lockout_ranges, linpos_flat, plt_range, 1.0)


plt.show()

In [26]:
# MAP position estimate per decode bin from the offline posteriors, indexed by bin time.
map_bin_idx = np.argmax(posteriors, axis=1)
dec_est_map = x_bins[map_bin_idx]
bin_time_index = pd.Index(data=dec_bin_times, name='timestamp')
dec_est_pos = pd.DataFrame({'est_pos': dec_est_map}, index=bin_time_index)

# Real position binned at the decode bin size, to compare against the estimates.
pos_data_bins = bin_pos_data(pos_data_linpos, dec_bin_size)

# Per-arm error tables; the meaning of the final argument (2) is defined by calc_error_table.
(center_dec_error,
 left_dec_error,
 right_dec_error) = calc_error_table(pos_data_bins, dec_est_pos, arm_coordinates, 2)

In [27]:
# Median and mean absolute decode error per arm for the offline decoder.
offline_abs_errors = [arm_err['abs_error']
                      for arm_err in (center_dec_error, left_dec_error, right_dec_error)]

offline_medians = [np.median(err) for err in offline_abs_errors]
print('median error center: {:0.5}, left: {:0.5}, right: {:.5}'.format(*offline_medians))

offline_means = [np.mean(err) for err in offline_abs_errors]
print('mean error center: {:0.5}, left: {:0.5}, right: {:.5}'.format(*offline_means))

In [63]:
# Bins whose spykshrk MAP estimate lies beyond the end of the right arm
# (arm_coordinates[-1][1] == 300 + 104 == 404).
# Boolean-mask selection replaces `Series.nonzero()`, which was removed in pandas 1.0.
# NOTE(review): spyk_dec_est_pos is created in a later cell (In [61]); run it first on a fresh kernel.
spyk_dec_est_pos[spyk_dec_est_pos['est_pos'] > arm_coordinates[-1][1]]

In [47]:
# Show the spykshrk.realtime decoder's transition matrix as an image.
spykshrk_trans_mat = PointProcessDecoder._create_transition_matrix(1, 150)
fig, ax = plt.subplots()
ax.imshow(spykshrk_trans_mat)
plt.show()

In [54]:
# Bins whose spykshrk MAP estimate lies strictly between 256 and 390.
# NOTE(review): 256 matches the left-arm end in the spyk error-table bounds, but 390 lies
# inside the right arm there ([290, 407]) — confirm the intended upper bound.
spyk_est = spyk_dec_est_pos['est_pos']
spyk_dec_est_pos[(spyk_est > 256) & (spyk_est < 390)]

In [59]:
# Scratch check: convert a raw timestamp (in samples) to seconds, assuming the same
# 30 kHz factor used to build the position `timestamps` column.
80736600.0 / 30000

In [61]:
# MAP position estimate per bin from the spykshrk.realtime posteriors, indexed by the
# offline decoder's bin times.
# assumes spykshrk_posteriors has exactly one row per entry of dec_bin_times — TODO confirm.
spyk_dec_est_map = x_bins[spykshrk_posteriors.argmax(axis=1)]
spyk_dec_est_pos = pd.DataFrame(
    {'est_pos': spyk_dec_est_map},
    index=pd.Index(data=dec_bin_times, name='timestamp'),
)

pos_data_bins = bin_pos_data(pos_data_linpos, dec_bin_size)

# NOTE(review): these arm bounds and the final argument (0) differ from the arm_coordinates
# and 2 used for the offline error table, so the two sets of error statistics are not
# computed identically — confirm this is intended.
spyk_arm_bounds = [[0, 70],
                   [149, 256],
                   [290, 407]]
(spyk_center_dec_error,
 spyk_left_dec_error,
 spyk_right_dec_error) = calc_error_table(pos_data_bins, spyk_dec_est_pos, spyk_arm_bounds, 0)

In [62]:
# Median and mean absolute decode error per arm for the spykshrk.realtime decoder.
spyk_abs_errors = [arm_err['abs_error']
                   for arm_err in (spyk_center_dec_error, spyk_left_dec_error, spyk_right_dec_error)]

spyk_medians = [np.median(err) for err in spyk_abs_errors]
print('median error center: {:0.5}, left: {:0.5}, right: {:.5}'.format(*spyk_medians))

spyk_means = [np.mean(err) for err in spyk_abs_errors]
print('mean error center: {:0.5}, left: {:0.5}, right: {:.5}'.format(*spyk_means))

In [None]:
# Per-arm decode error of the offline decoder over a fixed window
# (units presumably match dec_bin_times — confirm).
error_plot_window = [2350, 3400]
plot_arms_error(center_dec_error, left_dec_error, right_dec_error, error_plot_window)
plt.show()