In [1]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import json
import os
import scipy.signal

import spykshrk.realtime.simulator.nspike_data as nspike_data

#pd.set_option('float_format', '{:,.2f}'.format)
# Render floats with 4 digits of precision in displayed frames.
pd.set_option('display.precision', 4)
# Truncate displayed frames to 10 rows so bare-expression cells stay short.
pd.set_option('display.max_rows', 10)
#pd.set_option('display.width', 180)


# Shorthand for pandas' multi-level slicing helper; not referenced by the
# cells visible in this notebook export.
idx = pd.IndexSlice

def gaussian(x, mu, sig):
    """Unnormalized Gaussian bump evaluated at ``x``.

    Computes ``exp(-(x - mu)**2 / (2 * sig**2))``. The value at ``x == mu``
    is 1; no ``1 / (sig * sqrt(2 * pi))`` factor is applied. The arguments are
    handed straight to numpy, so scalars and broadcastable arrays both work.
    """
    squared_offset = np.power(x - mu, 2.)
    two_variance = 2 * np.power(sig, 2.)
    return np.exp(-squared_offset / two_variance)

def normal2D(x, y, sig):
    """Unnormalized isotropic 2-D Gaussian centred on the origin.

    Computes ``exp(-(x**2 + y**2) / (2 * sig**2))``. The value at
    ``(0, 0)`` is 1; callers that need a unit-sum kernel divide by the sum
    themselves. ``x`` and ``y`` may be scalars or broadcastable arrays.
    """
    radius_sq = np.power(x, 2.) + np.power(y, 2.)
    two_variance = 2 * np.power(sig, 2.)
    return np.exp(-radius_sq / two_variance)


def apply_no_anim_boundary(x_bins, bounds, image):
    """Zero the parts of ``image`` that fall inside no-animal position ranges.

    Parameters
    ----------
    x_bins : 1-D sorted array
        Position value of each bin; searched with ``np.searchsorted``.
    bounds : flat sequence of positions
        Read as consecutive ``(start, end)`` pairs. Any number of pairs is
        accepted (the notebook passes three).
    image : numpy.ndarray
        1-D or 2-D array indexed by position bin. It is modified **in place**.
        For a 2-D array both the rows and the columns of each range are
        zeroed. Arrays of any other dimensionality are returned untouched.

    Returns
    -------
    numpy.ndarray
        The same ``image`` object that was passed in.

    Raises
    ------
    ValueError
        If ``bounds`` has an odd number of entries (raised by the reshape).
    """
    # side='right' maps each position to the index just past the last bin
    # whose value is <= that position.
    boundary_ind = np.searchsorted(x_bins, bounds, side='right')
    # -1 lets the number of (start, end) pairs follow the input length
    # instead of being fixed at three.
    boundary_ind = np.reshape(boundary_ind, [-1, 2])

    # Distinct loop names so the `bounds` argument is not shadowed.
    for start_ind, stop_ind in boundary_ind:
        if image.ndim == 1:
            image[start_ind:stop_ind] = 0
        elif image.ndim == 2:
            image[start_ind:stop_ind, :] = 0
            image[:, start_ind:stop_ind] = 0
    return image

In [2]:
# 21-point triangular window: 0 at both ends, rising to its peak in the middle.
t = 1 - np.abs(np.linspace(-1, 1, 21))
# Outer product of the window with itself gives a 21x21 separable kernel,
# rescaled to unit sum.
# NOTE(review): the transition-matrix cell later rebinds `kernel` to a
# Gaussian before this one is used in the visible cells; confirm it is needed.
kernel = t[:, np.newaxis] * t[np.newaxis, :]
kernel = kernel / kernel.sum()

In [3]:
%%time
# Load merged rec HDF store based on config

# NOTE(review): absolute, machine-specific path; consider moving it to a
# configurable location so the notebook runs elsewhere.
config_file = '/opt/data36/daliu/realtime/spykshrk/ripple_dec/bond.config.json'
# The context manager closes the config file handle as soon as it is parsed
# (the previous `json.load(open(...))` left it open).
with open(config_file, 'r') as config_fp:
    config = json.load(config_fp)

# Merged record file name: <output_dir>/<prefix>.rec_merged.h5
hdf_file = os.path.join(config['files']['output_dir'],
                        '{}.rec_merged.h5'.format(config['files']['prefix']))

# Opened read-only. The handle stays bound to `store` after the two tables
# below are read.
store = pd.HDFStore(hdf_file, mode='r')
spike_decode = store['rec_3']
stim_lockout = store['rec_11']

In [4]:
%%time
# One row per lockout_num, one column per lockout_state, cells hold the
# matching timestamp. The columns are then ordered state 1 first, state 0
# second (reindex fills a column with NaN if that state is absent).
stim_lockout_ranges = (
    stim_lockout
    .pivot(index='lockout_num', columns='lockout_state', values='timestamp')
    .reindex(columns=[1, 0])
)


In [5]:
%%time
# Get table with decode for each spike and generate decode bin mask

dec_bin_size = 300     # Decode bin size in samples (usually 30kHz)

# Flat list of (start, end) position pairs; passed as `bounds` to
# apply_no_anim_boundary by later cells.
x_no_anim_bounds = [69, 150, 150+102, 300, 300+104, 450]

# Decode-bin number of every spike, counted from the timestamp stored under
# index label 0.
# NOTE(review): `[0]` is a label lookup, so this assumes the frame's index
# contains 0 and that it marks the first spike -- confirm.
spike_timestamps = spike_decode['timestamp']
elapsed_samples = spike_timestamps - spike_timestamps[0]
dec_bins = np.floor(elapsed_samples / dec_bin_size).astype('int')
# Adds a column to the frame loaded from the store (in-place mutation).
spike_decode['dec_bin'] = dec_bins


# Position grid parameters come from the encoder section of the config.
position_cfg = config['encoder']['position']
pos_upper = position_cfg['upper']
pos_lower = position_cfg['lower']
pos_num_bins = position_cfg['bins']
pos_bin_delta = ((pos_upper - pos_lower) / pos_num_bins)

# Left edge of every bin, and the full set of bin edges. Both grids start at
# 0 rather than at pos_lower.
x_bins = np.linspace(0, pos_bin_delta*(pos_num_bins-1), num=pos_num_bins)
x_bin_edges = np.linspace(0, pos_bin_delta*(pos_num_bins), num=pos_num_bins+1)

# 1-D smoothing kernel: Gaussian with sig=3 centred on the middle bin.
center_bin = len(x_bins) // 2
pos_kernel = gaussian(x_bins, x_bins[center_bin], 3)


In [6]:
# Get real position

# The animal descriptor is built from the keyword arguments stored under
# config['simulator']['nspike_animal_info'].
nspike_anim = nspike_data.AnimalInfo(**config['simulator']['nspike_animal_info'])
# Position stream for that animal; its `.data` attribute is what the rest of
# the notebook works with.
pos = nspike_data.PosMatDataStream(nspike_anim)
# Presumably a DataFrame with multi-level columns: later cells index it with
# tuples such as ('lin_vel', 'well_center') -- verify against nspike_data.
pos_data = pos.data


In [7]:
# Transform position into simpler table with only linear position
pos_data_time = pos_data.loc[:, 'time']

# Take an explicit copy of the 'lin_dist_well' selection. The column
# assignments below then write to an independent frame instead of to a
# selection of `pos_data` (avoids pandas' chained-assignment ambiguity and
# guarantees `pos_data` itself is left untouched).
pos_data_linpos = pos_data.loc[:,'lin_dist_well'].copy()
pos_data_linpos.loc[:, 'lin_vel_center'] = pos_data.loc[:,('lin_vel', 'well_center')]
pos_data_linpos.loc[:, 'seg_idx'] = pos_data.loc[:,('seg_idx', 0)]
# Time scaled by 30000 -- presumably seconds to 30 kHz sample counts, matching
# the spike timestamps; confirm the unit of the 'time' column.
pos_data_linpos.loc[:,'timestamps'] = pos_data_time*30000
pos_data_linpos = pos_data_linpos.set_index('timestamps')

In [8]:
# Convert real pos to realtime system linear map (single linear coordinate)

# Row selectors by segment index: 1 -> centre, 2/3 -> left, 4/5 -> right.
seg_idx = pos_data_linpos['seg_idx']
on_center = seg_idx == 1
on_left = (seg_idx == 2) | (seg_idx == 3)
on_right = (seg_idx == 4) | (seg_idx == 5)

# Each group uses its own distance column; left and right values are shifted
# by 150 and 300 so the three groups occupy separate stretches of one axis.
center_pos_flat = pos_data_linpos.loc[on_center, 'well_center'].rename('linpos_flat')
left_pos_flat = (pos_data_linpos.loc[on_left, 'well_left'] + 150).rename('linpos_flat')
right_pos_flat = (pos_data_linpos.loc[on_right, 'well_right'] + 300).rename('linpos_flat')

# Stitch the three pieces together and restore index (timestamp) order.
linpos_flat = pd.concat([center_pos_flat, left_pos_flat, right_pos_flat]).sort_index()


In [88]:
# Calculate State Transition Matrix

def _normalize_columns(mat):
    """Divide every column by its sum; entries of all-zero columns (NaN after 0/0) become 0."""
    mat = mat / mat.sum(axis=0)[None, :]
    mat[np.isnan(mat)] = 0
    return mat


# Smoothing kernel for learned pos transition matrix:
# 41x41 grid of integer offsets, sig=1, rescaled to unit sum.
xv, yv = np.meshgrid(np.arange(-20,21), np.arange(-20,21))
kernel = normal2D(xv, yv, 1)
kernel = kernel / kernel.sum()

# Bin index of every position sample (index of the last bin value <= sample).
linpos_state = linpos_flat
linpos_ind = np.searchsorted(x_bins, linpos_state, side='right') - 1

# Create learned pos transition matrix: count consecutive-sample pairs, with
# the earlier sample's bin as the row and the later sample's bin as the column.
# np.add.at accumulates repeated index pairs, like the explicit += 1 loop.
# NOTE(review): counts are stored [from, to] and normalised over axis 0, while
# the posterior cell multiplies the prior across columns and sums axis 1,
# i.e. uses the matrix as [to, from]. Confirm the orientation is intended.
learned_trans_mat = np.zeros([pos_num_bins, pos_num_bins])
np.add.at(learned_trans_mat, (linpos_ind[:-1], linpos_ind[1:]), 1)


# normalize
learned_trans_mat = _normalize_columns(learned_trans_mat)

# smooth
learned_trans_mat = sp.signal.convolve2d(learned_trans_mat, kernel, mode='same')
learned_trans_mat = apply_no_anim_boundary(x_bins, x_no_anim_bounds, learned_trans_mat)

# uniform offset
uniform_gain = 0.01
uniform_dist = np.ones(learned_trans_mat.shape)

# no-animal boundary
uniform_dist = apply_no_anim_boundary(x_bins, x_no_anim_bounds, uniform_dist)

#normalize uniform offset
uniform_dist = _normalize_columns(uniform_dist)

# apply uniform offset
learned_trans_mat = learned_trans_mat * (1 - uniform_gain) + uniform_dist * uniform_gain

# renormalize
learned_trans_mat = _normalize_columns(learned_trans_mat)

plt.figure(figsize=[20,20])
plt.imshow(learned_trans_mat, cmap='hot')
plt.colorbar()

plt.show()

In [10]:
%%time
# Loop through each bin and generate the observation distribution from spikes in bin

dec_bin_ids = np.unique(dec_bins)
# One row per decode bin, from bin 0 up to the largest bin id that occurs.
dec_est = np.zeros([dec_bin_ids[-1]+1, pos_num_bins])

# Start time of every decode bin, aligned down to a multiple of dec_bin_size.
start_bin_time = np.floor(spike_decode['timestamp'][0] / dec_bin_size) * dec_bin_size
dec_bin_times = np.arange(start_bin_time, start_bin_time + dec_bin_size * len(dec_est), dec_bin_size)

# initialize conditional intensity function
firing_rate = {ntrode_id: np.zeros(pos_num_bins) for ntrode_id in spike_decode['ntrode_id'].unique()}

groups = spike_decode.groupby('dec_bin')
# Bins with no spikes never show up in `groups`; they keep a count of 0 and
# an all-zero row in dec_est.
bin_num_spikes = [0] * len(dec_est)

# Label of the last per-position decode column ('x0' is the first).
last_pos_col = 'x{:d}'.format(pos_num_bins-1)

for bin_id, spikes_in_bin in groups:
    dec_in_bin = np.ones(pos_num_bins)

    bin_num_spikes[bin_id] = len(spikes_in_bin)

    # Count spikes for occupancy firing rate (conditional intensity function).
    # The loop variable is `spike_pos`, not `pos`, so the position stream
    # bound to `pos` by the "Get real position" cell is not overwritten.
    for ntrode_id, spike_pos in spikes_in_bin.loc[:, ['ntrode_id', 'position']].values:
        firing_rate[ntrode_id][np.searchsorted(x_bins, spike_pos, side='right') - 1] += 1

    # Multiply the smoothed per-spike decodes together, renormalising after
    # every spike so the running product integrates to 1 over position.
    for dec in spikes_in_bin.loc[:, 'x0':last_pos_col].values:
        smooth_dec = np.convolve(dec, pos_kernel, mode='same')
        dec_in_bin = dec_in_bin * smooth_dec
        dec_in_bin = dec_in_bin / (np.sum(dec_in_bin) * pos_bin_delta)

    dec_est[bin_id, :] = dec_in_bin


# Smooth and normalize firing rate (conditional intensity function)
for fr_key in firing_rate.keys():
    firing_rate[fr_key] = np.convolve(firing_rate[fr_key], pos_kernel, mode='same')

    firing_rate[fr_key] = apply_no_anim_boundary(x_bins, x_no_anim_bounds, firing_rate[fr_key])

    firing_rate[fr_key] = firing_rate[fr_key] / (firing_rate[fr_key].sum() * pos_bin_delta)

In [64]:
#Precompute prob of no spike from firing rate

# Occupancy density over the position bins. `density=True` replaces the
# `normed=True` keyword, which current NumPy releases no longer accept; for
# the equal-width bins used here the two give the same values.
occupancy, occ_bin_edges = np.histogram(linpos_flat, bins=x_bin_edges, density=True)

occupancy = np.convolve(occupancy, pos_kernel, mode='same')

# Per-ntrode no-spike term for one decode bin: exp(-bin duration * rate / occupancy),
# with the bin duration taken as dec_bin_size / 30000.
# NOTE(review): bins where occupancy is 0 produce inf/NaN here; confirm that
# downstream cells are expected to cope with that.
prob_no_spike = {}
for tet_id, tet_fr in firing_rate.items():
    prob_no_spike[tet_id] = np.exp(-dec_bin_size/30000 * tet_fr / occupancy)


In [65]:
# Compute the likelihood of each bin

likelihoods = np.ones(dec_est.shape)

for dec_ind, (num_spikes, dec_est_bin) in enumerate(zip(bin_num_spikes, dec_est)):
    # View onto this bin's row; the in-place operations below write straight
    # into `likelihoods`.
    bin_likelihood = likelihoods[dec_ind]

    # Bins with spikes start from the aggregated spike decode; empty bins
    # start from the all-ones row.
    if num_spikes > 0:
        bin_likelihood[:] = dec_est_bin

    # Every bin is scaled by the no-spike term of each ntrode, one at a time.
    for prob_no in prob_no_spike.values():
        bin_likelihood *= prob_no

    # Normalize
    bin_likelihood /= bin_likelihood.sum() * pos_bin_delta
    

In [82]:
%%time
# Compute artificial gaussian state transition matrix

# Flat starting prior for the posterior recursion.
last_posterior = np.ones(pos_num_bins)

# Setup transition matrix: row i is a Gaussian (sig=3) over x_bins centred on
# x_bins[i]. Broadcasting evaluates all rows in one call.
transition_mat = gaussian(x_bins[None, :], x_bins[:, None], 3)

# uniform offset
uniform_gain = 0.01
uniform_dist = np.ones(transition_mat.shape)

# no-animal boundary
# Uses the shared x_no_anim_bounds list rather than repeating its literal
# values, so the ranges cannot drift apart between cells.
boundary_ind = np.searchsorted(x_bins, x_no_anim_bounds, side='right')
boundary_ind = np.reshape(boundary_ind, [-1, 2])

# Boundary masking is currently disabled for this artificial matrix:
#for bounds in boundary_ind:
#    uniform_dist[bounds[0]:bounds[1], :] = 0
#    uniform_dist[:, bounds[0]:bounds[1]] = 0
#    transition_mat[bounds[0]:bounds[1], :] = 0
#    transition_mat[:, bounds[0]:bounds[1]] = 0

# normalize transition matrix (each column sums to 1)
transition_mat = transition_mat/( transition_mat.sum(axis=0)[None,:])

#normalize uniform offset
uniform_dist = uniform_dist/( uniform_dist.sum(axis=0)[None,:])

# apply uniform offset
transition_mat = transition_mat * (1 - uniform_gain) + uniform_dist * uniform_gain


plt.figure(figsize=[10,10])
plt.imshow(transition_mat, cmap='hot')
plt.colorbar()
plt.show()

In [67]:
# Iteratively calculate posterior

posteriors = np.zeros(dec_est.shape)

# Start from a flat prior every time this cell runs. Without this the cell
# relied on `last_posterior` left behind by another cell, and a re-run would
# silently continue from the final posterior of the previous run.
last_posterior = np.ones(pos_num_bins)

for like_ii, like in enumerate(likelihoods):
    # Prediction step: weight column j of the transition matrix by the prior
    # mass in bin j and sum across columns; then multiply by the likelihood.
    posteriors[like_ii, :] = like * (learned_trans_mat * last_posterior).sum(axis=1)
    # Normalise so the posterior integrates to 1 over position.
    posteriors[like_ii, :] = posteriors[like_ii, :] / (posteriors[like_ii, :].sum() * pos_bin_delta)
    last_posterior = posteriors[like_ii, :]

In [68]:
def plot_decode_2d(dec_est, dec_bin_times, stim_lockout_ranges, linpos_flat, plt_range):
    """Draw a decode heat map with position and stim-lockout overlays.

    Everything is drawn on the current matplotlib axes; nothing is returned.

    Parameters
    ----------
    dec_est : 2-D array, one row per decode bin and one column per position bin.
    dec_bin_times : array of bin times in samples (compared against
        ``plt_range * 30000``), same length as ``dec_est``.
    stim_lockout_ranges : table of lockout timestamps in samples, indexed with
        column labels ``1`` and ``0`` (presumably a DataFrame -- it is
        accessed through ``[...]`` and ``.values``).
    linpos_flat : Series of linear position whose index holds sample timestamps.
    plt_range : ``[start, stop]`` of the plotted window, in seconds.
    """
    win_start = plt_range[0]
    win_stop = plt_range[1]

    # Lockout rows overlapping the window, converted from samples to seconds.
    lockout_sec = stim_lockout_ranges/30000
    overlaps_window = (lockout_sec[1] > win_start) & (lockout_sec[0] < win_stop)
    lockout_in_window = lockout_sec[overlaps_window]

    # Heat map of the decode bins strictly inside the window (position on y).
    bin_in_window = (dec_bin_times > win_start*30000) & (dec_bin_times < win_stop*30000)
    plt.imshow(dec_est[bin_in_window].transpose(),
               extent=[win_start, win_stop, 0, 450], origin='lower', aspect='auto', cmap='hot', zorder=0)

    plt.colorbar()

    # Plot linear position
    linpos_time_sec = linpos_flat.index / 30000
    pos_in_window = (linpos_time_sec > win_start) & (linpos_time_sec < win_stop)

    plt.plot(linpos_time_sec[pos_in_window],
             linpos_flat.values[pos_in_window], 'c.', zorder=1, markersize=5)

    # One horizontal marker per lockout at y=440, plus a shaded span over it.
    marker_height = np.tile([[440], [440]], [1, len(lockout_in_window)])
    plt.plot(lockout_in_window.values.transpose(), marker_height, 'c-*' )

    for lockout_span in lockout_in_window.values:
        plt.axvspan(lockout_span[0], lockout_span[1], facecolor='#AAAAAA', alpha=0.3)

    # Dashed reference lines at fixed positions across the whole window.
    for ref_pos in (74, 148, 256, 298, 407):
        plt.plot(plt_range, [ref_pos, ref_pos], '--', color='gray')


In [84]:
# Time windows to plot, as [start, stop] in seconds.
plt_ranges = [[2930, 3050]]

for plt_range in plt_ranges:

    plt.figure(figsize=[400,20])
    # Top panel: raw per-bin decode; bottom panel: filtered posterior.
    for panel_num, estimate in enumerate((dec_est, posteriors), start=1):
        plt.subplot(2,1,panel_num)
        plot_decode_2d(estimate, dec_bin_times, stim_lockout_ranges, linpos_flat, plt_range)


plt.show()

In [70]:
# Exploratory check of the no-spike term: running product over ntrodes of
# exp(-bin duration * rate / occupancy), plotted after each ntrode.
likelihood_test = np.ones(len(x_bins))

plt.figure( figsize=(40,20))
# Left column: running product so far; right column: this ntrode's rate / occupancy.
# NOTE(review): the 9x2 grid with panel 17 reserved for the final curve only
# fits up to 8 entries in firing_rate -- confirm the ntrode count.
for tet_ii, tet_fr in enumerate(firing_rate.values()):
    likelihood_test *= np.exp(-dec_bin_size/30000 * (tet_fr / occupancy))
    #for bound in boundary_ind:
    #    likelihood_test[bound[0]:bound[1]] = 0
    
    #likelihood_test = likelihood_test / (likelihood_test.sum() * pos_bin_delta)
    
    
    plt.subplot(9,2,tet_ii*2 + 1)
    plt.plot( likelihood_test)
    plt.subplot(9,2, tet_ii*2+2)
    plt.plot(x_bins, tet_fr/occupancy)

# Replace infinite entries with 0 before smoothing (NaN entries are left as is).
likelihood_test[np.isinf(likelihood_test)] = 0

# Smooth the final product with the 1-D position kernel and show it in panel 17.
likelihood_test = np.convolve(likelihood_test, pos_kernel, mode='same')
plt.subplot(9,2, 17)
plt.plot(x_bins, likelihood_test)
    
plt.show()

In [71]:

# Inspect a single decode bin: per-spike decodes, then the four stages of the filter.
plt_ii = 79953
num_spikes = bin_num_spikes[plt_ii]
# max(..., 1) keeps the figure width positive when the bin has no spikes.
plt.figure(figsize=[10*max(num_spikes, 1),5])
plt_spk_dec = spike_decode.query('dec_bin == {}'.format(plt_ii))

# One panel per spike plus a final panel for their product. The grid keeps the
# original 6 columns and only widens when the bin holds more than 5 spikes,
# so the aggregate panel never collides with (or overflows past) a spike panel.
num_panels = max(6, num_spikes + 1)

spk_dec_single_agg = np.ones(len(x_bins))
for subplt_ii, spk_dec_single in enumerate(plt_spk_dec.loc[:, 'x0':'x{:d}'.format(pos_num_bins-1)].values):
    plt.subplot(1,num_panels,subplt_ii+1)
    plt.plot(x_bins, spk_dec_single)
    # Small offset keeps a zero in one spike's decode from wiping out the product.
    spk_dec_single_agg *= spk_dec_single + 0.001

plt.subplot(1,num_panels,num_panels)
plt.plot(x_bins, spk_dec_single_agg)

plt.figure(figsize=[30,5])
print("ii {} had {} spikes".format(plt_ii, num_spikes))
plt.subplot(1,4,1)
plt.plot(x_bins, dec_est[plt_ii,:])
plt.title('aggregate observations')
plt.subplot(1,4,2)
plt.plot(x_bins, likelihoods[plt_ii,:])
plt.title('likelihoods')
plt.subplot(1,4,3)
# Posterior of the preceding bin, i.e. the prior that fed this bin's update.
plt.plot(x_bins, posteriors[plt_ii-1,:])
plt.title('last_posterior')
plt.subplot(1,4,4)
plt.plot(x_bins, posteriors[plt_ii,:])
plt.title('posteriors')
plt.show()

In [72]:
# Scratch checks. Only the last expression of a cell is rendered, so the
# first three results are computed and discarded.
np.searchsorted(dec_bin_times, 3260*30000)
np.sum(dec_est[79900]) * 3
# NOTE(review): `.all()` only tests that every scaled row sum is non-zero,
# not that it equals 1. If this is meant as a normalisation check, something
# like np.allclose(..., 1) is probably intended. The literal 3 presumably
# stands for pos_bin_delta -- confirm.
(np.sum(posteriors, axis=1) * 3).all()

bin_num_spikes[79900]

In [73]:
# Posterior over position bins for decode bin 79900 (x axis is the bin index).
#plt.plot(dec_est[79900])
plt.plot(posteriors[79900,:])

plt.show()


In [74]:
# Transition counts restricted to samples where lin_vel_center exceeds 5.
# NOTE(review): the `_2cm` name does not match the threshold of 5 -- confirm.
# NOTE(review): this rebinds `linpos_ind`, which the transition-matrix cell
# also defines from the unfiltered positions.
linpos_2cm = linpos_flat[pos_data_linpos['lin_vel_center'] > 5]
linpos_ind = np.searchsorted(x_bins, linpos_2cm, side='right') - 1
print(len(linpos_ind[1:]))

# Seeded generator so the sampled pairs, and therefore the figure, are the
# same on every run.
bootstrap_lag = 10          # samples between the two positions of a pair
bootstrap_draws = 1000000
bootstrap_rng = np.random.RandomState(0)
pos_ind_bootstrap = bootstrap_rng.randint(0, len(linpos_2cm)-bootstrap_lag, bootstrap_draws)

# Count (bin at t, bin at t + lag) pairs for every sampled start index.
# np.add.at accumulates repeated index pairs, replacing the per-draw Python loop.
learned_trans_mat_fast = np.zeros([pos_num_bins, pos_num_bins])
np.add.at(learned_trans_mat_fast,
          (linpos_ind[pos_ind_bootstrap], linpos_ind[pos_ind_bootstrap + bootstrap_lag]),
          1)

plt.figure(figsize=[10,10])
plt.imshow(learned_trans_mat_fast, cmap='hot')
plt.colorbar()
plt.show()

In [75]:
# (start, end) bin-index pairs, one row per range.
boundary_ind_mat = np.reshape(boundary_ind, [3,2])
# Flat list of every bin index from start through end (end included) of each range.
# NOTE(review): the list is not used by any cell visible here.
select_valid_bound = [
    bin_ind
    for range_start, range_stop in boundary_ind_mat
    for bin_ind in range(range_start, range_stop+1)
]

# Peek at the first 22 rows of the learned transition matrix.
learned_trans_mat[0:22, :]


In [76]:
# Smoothed occupancy density per position bin (x axis is the bin index).
plt.plot(occupancy)
plt.show()

In [77]:
# Variant of the no-spike check that skips the occupancy division: each
# ntrode contributes exp(-(bin duration / 10) * (rate + 1)).
likelihood_test = np.ones(len(x_bins))

plt.figure( figsize=(40,20))
# Left column: running product so far; right column: this ntrode's rate.
# NOTE(review): the 8x2 grid fits at most 8 entries in firing_rate.
for tet_ii, tet_fr in enumerate(firing_rate.values()):
    likelihood_test *= np.exp(-dec_bin_size/30000/10 * (tet_fr + 1))
    # Zero the running product inside every (start, end) range of boundary_ind.
    for bound in boundary_ind:
        likelihood_test[bound[0]:bound[1]] = 0
    
    #likelihood_test = likelihood_test / (likelihood_test.sum() * pos_bin_delta)
    
    plt.subplot(8,2,tet_ii*2 + 1)
    plt.plot( likelihood_test)
    plt.subplot(8,2, tet_ii*2+2)
    plt.plot(x_bins, tet_fr)
    
plt.show()

In [78]:
# Display the spike decode table (rows are truncated by display.max_rows).
spike_decode

In [79]:
# One figure per ntrode: smoothed, normalised firing rate against position.
for tet_fr in firing_rate.values():
    plt.figure()
    plt.plot(x_bins, ( tet_fr))

plt.show()

In [80]:
# Inspect a single decode bin: per-spike decodes, then the four stages of the filter.
plt_ii = 79900
num_spikes = bin_num_spikes[plt_ii]
# max(..., 1) keeps the figure width positive when the bin has no spikes.
plt.figure(figsize=[10*max(num_spikes, 1),5])
plt_spk_dec = spike_decode.query('dec_bin == {}'.format(plt_ii))

# Keep the original 4-column grid, widening it only when the bin holds more
# than 4 spikes (a fixed grid of 4 rejects a fifth panel).
num_panels = max(4, num_spikes)

# Column range follows pos_num_bins instead of a hard-coded 'x149', matching
# how the decode columns are selected elsewhere in the notebook.
for subplt_ii, spk_dec_single in enumerate(plt_spk_dec.loc[:, 'x0':'x{:d}'.format(pos_num_bins-1)].values):
    plt.subplot(1,num_panels,subplt_ii+1)
    plt.plot(x_bins, spk_dec_single)

plt.figure(figsize=[30,5])
print("ii {} had {} spikes".format(plt_ii, num_spikes))
plt.subplot(1,4,1)
plt.plot(x_bins, dec_est[plt_ii,:])
plt.title('aggregate observations')
plt.subplot(1,4,2)
plt.plot(x_bins, likelihoods[plt_ii,:])
plt.title('likelihoods')
plt.subplot(1,4,3)
# Posterior of the preceding bin, i.e. the prior that fed this bin's update.
plt.plot(x_bins, posteriors[plt_ii-1,:])
plt.title('last_posterior')
plt.subplot(1,4,4)
plt.plot(x_bins, posteriors[plt_ii,:])
plt.title('posteriors')
plt.show()

In [None]:
# Ntrode id and position of the spikes in the last group left over from the
# decode loop. The column is 'ntrode_id' everywhere else in the notebook;
# 'trode_id' did not match that name.
spikes_in_bin.loc[:, ['ntrode_id', 'position']]

In [None]:
#plt.plot(firing_rate)
# firing_rate is a dict of per-ntrode arrays, so it cannot be multiplied by a
# scalar directly; evaluate and draw one curve per ntrode instead.
for tet_fr in firing_rate.values():
    plt.plot(np.exp(-dec_bin_size/30000 * tet_fr))
plt.show()
