In [32]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import json
import os
import scipy.signal

import spykshrk.realtime.simulator.nspike_data as nspike_data

from spykshrk.franklab.pp_decoder.util import gaussian, normal2D, apply_no_anim_boundary
from spykshrk.franklab.pp_decoder.pp_clusterless import calc_learned_state_trans_mat, calc_simple_trans_mat, \
                                                        calc_uniform_trans_mat, \
                                                        calc_observation_intensity, calc_likelihood, \
                                                        calc_occupancy, calc_prob_no_spike, \
                                                        calc_posterior, plot_decode_2d

from spykshrk.franklab.pp_decoder.decode_error import bin_pos_data, convert_pos_for_notebook, calc_error_table, \
                                                      conv_center_pos, conv_left_pos, conv_right_pos, \
                                                      plot_arms_error
                
# Enable the %%cython cell magic (requires the Cython package to be installed).
%load_ext Cython

# Render matplotlib figures inline in the cell output.
%matplotlib inline

# Notebook-wide pandas display defaults: 4 digits of precision, at most 10 rows shown.
pd.set_option('display.precision', 4,
              'display.max_rows', 10)

# Shorthand for MultiIndex slicing, e.g. frame.loc[idx[:, 'a'], :].
idx = pd.IndexSlice


In [33]:
%%time
# Load the merged rec HDF store based on the run's config file.
# Previous run, kept for reference:
#   '/opt/data36/daliu/realtime/spykshrk/ripple_dec/bond.config.json'
config_file = '/opt/data36/daliu/realtime/spykshrk/dec_60uv/bond.config.json'

# The context manager closes the config file instead of leaking the handle.
with open(config_file, 'r') as config_fp:
    config = json.load(config_fp)

hdf_file = os.path.join(config['files']['output_dir'],
                        '{}.rec_merged.h5'.format(config['files']['prefix']))

# The store is left open (read-only) for the rest of the session.
# rec_3 is loaded as the per-spike decode table; rec_11 as the stim lockout events.
store = pd.HDFStore(hdf_file, mode='r')
spike_decode = store['rec_3']
stim_lockout = store['rec_11']

In [34]:
%%time
# Wide table of lockout timestamps: one row per lockout_num, one column per lockout_state.
# Columns are ordered as state 1 then state 0.
stim_lockout_ranges = (
    stim_lockout
    .pivot(index='lockout_num', columns='lockout_state', values='timestamp')
    .reindex(columns=[1, 0])
)


In [35]:
%%time
# Assign every spike to a decode bin and set up the position binning.

dec_bin_size = 30     # Decode bin size in samples (usually 30kHz)

# Linearized [start, end] extent of each arm: center, left (offset 150), right (offset 300).
arm_coordinates = [[0, 69], [150, 150+102], [300, 300+104]]

# Bin index of each spike, counted from the first spike's timestamp.
# .iloc[0] is the first row by position. A plain [0] is a label lookup, which fails
# (or picks a different row) when the frame's index does not contain label 0 first.
first_spike_timestamp = spike_decode['timestamp'].iloc[0]
dec_bins = np.floor((spike_decode['timestamp'] - first_spike_timestamp)/dec_bin_size).astype('int')
spike_decode['dec_bin'] = dec_bins

# Position binning parameters come from the encoder section of the config.
pos_upper = config['encoder']['position']['upper']
pos_lower = config['encoder']['position']['lower']
pos_num_bins = config['encoder']['position']['bins']
pos_bin_delta = ((pos_upper - pos_lower) / pos_num_bins)

# x_bins: pos_num_bins positions spaced pos_bin_delta apart, starting at 0.
# x_bin_edges: the matching pos_num_bins+1 edges.
x_bins = np.linspace(0, pos_bin_delta*(pos_num_bins-1), pos_num_bins)
x_bin_edges = np.linspace(0, pos_bin_delta*(pos_num_bins), pos_num_bins+1)

# Gaussian kernel centered on the middle bin, used by the np.convolve smoothing below.
# The width argument 3 is presumably the std in position units (TODO confirm against gaussian()).
pos_kernel = gaussian(x_bins, x_bins[len(x_bins) // 2], 3)


In [36]:
# Load the animal's recorded position through the simulator's nspike reader.
animal_info_cfg = config['simulator']['nspike_animal_info']
nspike_anim = nspike_data.AnimalInfo(**animal_info_cfg)

pos = nspike_data.PosMatDataStream(nspike_anim)
pos_data = pos.data


In [37]:
# Reduce the position table to linear-distance columns plus velocity, segment and timestamp.
pos_data_time = pos_data.loc[:, 'time']

# .copy() makes pos_data_linpos an independent frame. Adding columns to a bare .loc slice
# of pos_data triggers SettingWithCopyWarning, and those writes are not guaranteed to stick.
pos_data_linpos = pos_data.loc[:, 'lin_dist_well'].copy()
pos_data_linpos.loc[:, 'lin_vel_center'] = pos_data.loc[:, ('lin_vel', 'well_center')]
pos_data_linpos.loc[:, 'seg_idx'] = pos_data.loc[:, ('seg_idx', 0)]
# Seconds -> sample timestamps (assumes the 30 kHz clock used elsewhere in this notebook).
pos_data_linpos.loc[:, 'timestamps'] = pos_data_time*30000
pos_data_linpos = pos_data_linpos.set_index('timestamps')

# Map real position onto the realtime system's single linear coordinate:
#   segment 1    -> center arm distance
#   segments 2/3 -> left arm distance + 150
#   segments 4/5 -> right arm distance + 300
seg_idx_col = pos_data_linpos['seg_idx']
center_pos_flat = pos_data_linpos.loc[seg_idx_col == 1, 'well_center']
left_pos_flat = pos_data_linpos.loc[seg_idx_col.isin([2, 3]), 'well_left'] + 150
right_pos_flat = pos_data_linpos.loc[seg_idx_col.isin([4, 5]), 'well_right'] + 300

center_pos_flat.name = 'linpos_flat'
left_pos_flat.name = 'linpos_flat'
right_pos_flat.name = 'linpos_flat'

# Merge the three arms back into one time-ordered series.
linpos_flat = pd.concat([center_pos_flat, left_pos_flat, right_pos_flat]).sort_index()


In [38]:
%%time

# --- State transition models ---------------------------------------------------------
# Learned from the linearized trajectory. The simple and uniform matrices are computed
# alongside it; the decode below uses only learned_trans_mat.
learned_trans_mat = calc_learned_state_trans_mat(
    linpos_flat, x_bins, arm_coordinates,
    gauss_smooth_std=3,
    uniform_offset_gain=0.001,
)
transition_mat = calc_simple_trans_mat(x_bins)
uniform_mat = calc_uniform_trans_mat(x_bins)

# --- Observation model -----------------------------------------------------------------
# Per decode bin, build the observation distribution from that bin's spikes.
(dec_bin_times,
 dec_est,
 bin_num_spikes,
 firing_rate) = calc_observation_intensity(
    spike_decode,
    dec_bin_size,
    x_bins,
    pos_kernel,
    arm_coordinates,
)

# Probability of no spike, precomputed from firing rate and position occupancy.
occupancy = calc_occupancy(linpos_flat, x_bin_edges, pos_kernel)
prob_no_spike = calc_prob_no_spike(firing_rate, dec_bin_size, occupancy)

# --- Decode ----------------------------------------------------------------------------
likelihoods = calc_likelihood(dec_est, bin_num_spikes, prob_no_spike, pos_bin_delta)
# Posterior is computed iteratively from the likelihoods and the learned transitions.
posteriors = calc_posterior(likelihoods, learned_trans_mat, pos_num_bins, pos_bin_delta)

In [39]:
# Row 10 of the learned transition matrix, zoomed to the first 50 bins.
plt.figure(figsize=(20, 10))
plt.plot(learned_trans_mat[10])
plt.xlim(0, 50)
plt.show()

In [64]:
# Learned transition matrix with near-zero smoothing (gauss_smooth_std=0.01), drawn as an image.
test_trans_mat = calc_learned_state_trans_mat(
    linpos_flat, x_bins, arm_coordinates,
    gauss_smooth_std=0.01,
    uniform_offset_gain=0.001,
)

plt.figure(figsize=(20, 20))
plt.imshow(test_trans_mat)
plt.colorbar()
plt.show()

In [41]:
# Display the decode bin width (in timestamp samples) used throughout this analysis.
dec_bin_size

In [42]:
%%time

# MAP estimate: the position of the most probable bin in each posterior row.
map_bin_idx = np.argmax(posteriors, axis=1)
dec_est_map = x_bins[map_bin_idx]

dec_est_pos = pd.DataFrame(
    data={'est_pos': dec_est_map},
    index=pd.Index(dec_bin_times, name='timestamp'),
)

# Actual position, presumably resampled onto the same decode bins (see bin_pos_data).
pos_data_bins = bin_pos_data(pos_data_linpos, dec_bin_size)

In [43]:
# Display the linearized [start, end] range of each arm (center, left, right) that is
# passed to the transition, observation, and error-table helpers.
arm_coordinates

In [44]:
# Per-arm decode error tables: estimated vs actual position, split into center/left/right.
# The trailing 0 is passed straight through to calc_error_table (TODO document its meaning).
(center_dec_error,
 left_dec_error,
 right_dec_error) = calc_error_table(pos_data_bins, dec_est_pos, arm_coordinates, 0)

In [45]:
# Plot each arm's decode error over the [2350, 3400] range.
plot_arms_error(center_dec_error,
                left_dec_error,
                right_dec_error,
                [2350, 3400])
plt.show()

In [46]:
# Median and mean absolute decode error per arm.
# A single loop replaces the copy-pasted print pair. The old mix of '{:0.5}' and '{:.5}'
# specs rendered identically (no width given), so the printed text is unchanged.
arm_abs_errors = [center_dec_error['abs_error'],
                  left_dec_error['abs_error'],
                  right_dec_error['abs_error']]

for stat_name, stat_fn in [('median', np.median), ('mean', np.mean)]:
    print('{} error center: {:.5}, left: {:.5}, right: {:.5}'.format(
        stat_name, *[stat_fn(arm_err) for arm_err in arm_abs_errors]))

In [47]:
# Very wide figure per range: the top panel shows raw observations (dec_est), the bottom
# panel shows posteriors. Both get 1-unit x ticks and a shared color limit.
plt_ranges = [[2461 + 810, 2461 + 850]]

for plt_range in plt_ranges:
    plt.figure(figsize=[400, 20])
    for panel_num, decode_grid in enumerate((dec_est, posteriors), start=1):
        plt.subplot(2, 1, panel_num)
        plot_decode_2d(decode_grid, dec_bin_times, stim_lockout_ranges, linpos_flat, plt_range)
        plt.xticks(np.arange(plt_range[0], plt_range[1], 1.0))
        plt.clim(0, 0.1)
plt.show()

In [48]:
# Posterior only, over a single narrow window. The final 1.0 is passed through to plot_decode_2d.
plt_ranges = [[3311, 3312]]

for plt_range in plt_ranges:
    plt.figure(figsize=[20, 10])
    plot_decode_2d(posteriors, dec_bin_times, stim_lockout_ranges, linpos_flat,
                   plt_range, 1.0)

plt.show()

In [49]:
# Indices of decode bins strictly between 3265 s and 3266 s (timestamps assumed 30 kHz samples).
window_mask = (dec_bin_times > 3265 * 30000) & (dec_bin_times < 3266 * 30000)
np.where(window_mask)

In [50]:
# Time of the first decode bin in seconds (assumes 30 kHz timestamps, as elsewhere in this notebook).
dec_bin_times[0]/30000

In [51]:
# Reset the stepping offset. The next cell adds 10 before plotting, so its first run
# starts at offset 0 and each re-run advances by 10 bins.
count = -10

In [52]:
# Advance 10 bins and show the next 10 posteriors (starting at bin 80401 + count) in a 2x5 grid.
count += 10
plt.figure(figsize=[20, 10])
first_bin = 80401 + count
for panel_ii, bin_ii in enumerate(range(first_bin, first_bin + 10)):
    plt.subplot(2, 5, panel_ii + 1)
    plt.plot(posteriors[bin_ii])

plt.show()

In [53]:
# Smoothed observation of a single spike (row label 300002), convolved with pos_kernel.
# The column range follows pos_num_bins, matching the other cells, instead of a hard-coded
# 'x449' that silently drops or breaks bins if the config's bin count changes.
plt.figure(figsize=(20,10))
spike_obs = spike_decode.loc[300002, 'x0':'x{:d}'.format(pos_num_bins - 1)].values
plt.plot(np.convolve(spike_obs, pos_kernel, mode='same'))

plt.show()

In [54]:
# Per-tetrode firing-rate curves, one row each.
# The row count follows the number of tetrodes. The previous fixed 9-row grid raised
# ValueError as soon as a 10th tetrode was present.
tet_fr_items = list(firing_rate.items())

plt.figure(figsize=(20,20))

for ii, (tet_id, tet_fr) in enumerate(tet_fr_items):
    plt.subplot(len(tet_fr_items), 1, ii+1)
    plt.plot(tet_fr)
    plt.ylabel('tet {}'.format(tet_id))

plt.show()

In [55]:
# Per-tetrode probability-of-no-spike curves, one row each.
# The row count follows the number of entries instead of a fixed 9, which raised ValueError
# with more tetrodes. list(...) is used so the count is right whatever mapping type this is.
tet_prob_items = list(prob_no_spike.items())

plt.figure(figsize=(20,20))

for ii, (tet_id, tet_prob_no) in enumerate(tet_prob_items):
    plt.subplot(len(tet_prob_items), 1, ii+1)
    plt.plot(tet_prob_no)
    plt.ylabel('tet {}'.format(tet_id))

plt.show()

In [56]:
# Hand-computed no-spike likelihood: a running product over tetrodes of
# exp(-dt * rate / occupancy), with dt = dec_bin_size samples at 30 kHz.
# Left column: the running product. Right column: each tetrode's rate / occupancy.
tet_firing_rates = list(firing_rate.values())
# One row per tetrode plus a dedicated final row for the smoothed product. The old fixed
# 9-row grid drew the result into subplot 17, on top of the 9th tetrode's panel.
num_rows = len(tet_firing_rates) + 1

likelihood_test = np.ones(len(x_bins))

plt.figure(figsize=(40,20))
for tet_ii, tet_fr in enumerate(tet_firing_rates):
    likelihood_test *= np.exp(-dec_bin_size/30000 * (tet_fr / occupancy))

    plt.subplot(num_rows, 2, tet_ii*2 + 1)
    plt.plot(likelihood_test)
    plt.subplot(num_rows, 2, tet_ii*2 + 2)
    plt.plot(x_bins, tet_fr/occupancy)

# exp(-x) with x >= 0 never yields inf. The real hazard is NaN from 0/0 in unvisited bins,
# which np.convolve would smear across the kernel width, so zero every non-finite value.
likelihood_test[~np.isfinite(likelihood_test)] = 0

likelihood_test = np.convolve(likelihood_test, pos_kernel, mode='same')
plt.subplot(num_rows, 2, num_rows*2 - 1)
plt.plot(x_bins, likelihood_test)

plt.show()

In [57]:

# Inspect decode bin 79953 in two figures.
# Figure 1: each spike's observation plus their product.
# Figure 2: the decoder stages (aggregate observations, likelihood, previous/current posterior).
plt_ii = 79953
num_spikes = bin_num_spikes[plt_ii]
plt_spk_dec = spike_decode.query('dec_bin == {}'.format(plt_ii))
spk_dec_rows = plt_spk_dec.loc[:, 'x0':'x{:d}'.format(pos_num_bins-1)].values

# One panel per spike plus a final panel for the product. The old fixed 16-column grid
# failed past 16 spikes and drew the product into panel 6, over the 6th spike.
num_panels = len(spk_dec_rows) + 1
plt.figure(figsize=[10*num_panels, 5])

spk_dec_single_agg = np.ones(len(x_bins))
for subplt_ii, spk_dec_single in enumerate(spk_dec_rows):
    plt.subplot(1, num_panels, subplt_ii+1)
    plt.plot(x_bins, spk_dec_single)
    # The small offset keeps a zero in one spike's estimate from zeroing the whole product.
    spk_dec_single_agg *= spk_dec_single + 0.001

plt.subplot(1, num_panels, num_panels)
plt.plot(x_bins, spk_dec_single_agg)
plt.title('product of spike observations')

plt.figure(figsize=[30,5])
print("ii {} had {} spikes".format(plt_ii, num_spikes))
plt.subplot(1,4,1)
plt.plot(x_bins, dec_est[plt_ii,:])
plt.title('aggregate observations')
plt.subplot(1,4,2)
plt.plot(x_bins, likelihoods[plt_ii,:])
plt.title('likelihoods')
plt.subplot(1,4,3)
plt.plot(x_bins, posteriors[plt_ii-1,:])
plt.title('last_posterior')
plt.subplot(1,4,4)
plt.plot(x_bins, posteriors[plt_ii,:])
plt.title('posteriors')
plt.show()

In [58]:
# Quick checks around bin 79900 / t = 3260 s.
# Only a cell's last expression is displayed, so the old bare expressions on the first three
# lines were silently discarded. Collecting them in a dict shows every value.
# NOTE(review): the factor 3 is presumably pos_bin_delta (normalization check) -- confirm.
{
    'bin_idx_at_3260s': np.searchsorted(dec_bin_times, 3260*30000),
    'dec_est_79900_sum_x3': np.sum(dec_est[79900]) * 3,
    'all_posterior_sums_nonzero': (np.sum(posteriors, axis=1) * 3).all(),
    'num_spikes_79900': bin_num_spikes[79900],
}

In [59]:
# Posterior distribution for decode bin 79900.
plt.plot(posteriors[79900])
plt.show()


In [60]:
# Position occupancy computed by calc_occupancy from linpos_flat and pos_kernel.
# The x axis is the array index (bin number), assuming occupancy is a plain per-bin array.
plt.plot(occupancy)
plt.show()

In [61]:
# One figure per tetrode: firing-rate curve against the position bins.
for tet_fr in firing_rate.values():
    plt.figure()
    plt.plot(x_bins, tet_fr)

plt.show()

In [62]:
# Inspect decode bin 79900 in two figures.
# Figure 1: each spike's observation.
# Figure 2: the decoder stages (aggregate observations, likelihood, previous/current posterior).
plt_ii = 79900
num_spikes = bin_num_spikes[plt_ii]
plt_spk_dec = spike_decode.query('dec_bin == {}'.format(plt_ii))
# The column range follows pos_num_bins rather than a hard-coded 'x449'.
spk_dec_rows = plt_spk_dec.loc[:, 'x0':'x{:d}'.format(pos_num_bins-1)].values

# The grid is sized to the spikes actually present. The fixed 4-column grid raised ValueError
# past 4 spikes, and a width of 10*0 was rejected for a bin with no spikes.
num_panels = max(len(spk_dec_rows), 1)
plt.figure(figsize=[10*num_panels, 5])

for subplt_ii, spk_dec_single in enumerate(spk_dec_rows):
    plt.subplot(1, num_panels, subplt_ii+1)
    plt.plot(x_bins, spk_dec_single)

plt.figure(figsize=[30,5])
print("ii {} had {} spikes".format(plt_ii, num_spikes))
plt.subplot(1,4,1)
plt.plot(x_bins, dec_est[plt_ii,:])
plt.title('aggregate observations')
plt.subplot(1,4,2)
plt.plot(x_bins, likelihoods[plt_ii,:])
plt.title('likelihoods')
plt.subplot(1,4,3)
plt.plot(x_bins, posteriors[plt_ii-1,:])
plt.title('last_posterior')
plt.subplot(1,4,4)
plt.plot(x_bins, posteriors[plt_ii,:])
plt.title('posteriors')
plt.show()