In [1]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import json
import os
import scipy.signal

from spykshrk.realtime.decoder_process import PointProcessDecoder

import spykshrk.realtime.simulator.nspike_data as nspike_data

# Notebook-wide pandas display settings: 4 significant digits of precision and
# at most 10 rows shown for any DataFrame/Series.
_display_options = {
    'display.precision': 4,
    'display.max_rows': 10,
}
for _option_name, _option_value in _display_options.items():
    pd.set_option(_option_name, _option_value)

# Shorthand for building MultiIndex slices, e.g. df.loc[idx[a:b, :], :].
idx = pd.IndexSlice

def gaussian(x, mu, sig):
    """Unnormalized Gaussian bump exp(-(x - mu)^2 / (2 * sig^2)).

    Works elementwise (and with broadcasting) on scalars or numpy arrays; the
    peak value is 1 at x == mu.
    """
    squared_offset = np.power(x - mu, 2.)
    twice_variance = 2 * np.power(sig, 2.)
    return np.exp(-squared_offset / twice_variance)

def normal2D(x, y, sig):
    """Unnormalized isotropic 2-D Gaussian centered at the origin.

    Returns exp(-(x^2 + y^2) / (2 * sig^2)) elementwise; value 1 at (0, 0).
    """
    radius_sq = np.power(x, 2.) + np.power(y, 2.)
    return np.exp(-radius_sq / (2 * np.power(sig, 2.)))


def apply_no_anim_boundary(x_bins, bounds, image):
    """Zero the position bins that fall inside "no-animal" ranges.

    Parameters
    ----------
    x_bins : array_like
        Position of each bin, sorted ascending.
    bounds : sequence of float
        Flat list of range endpoints ``[start0, end0, start1, end1, ...]``.
        Any number of pairs is accepted. For each pair, bins whose position
        satisfies ``start < x <= end`` are zeroed (``searchsorted`` with
        ``side='right'``).
    image : numpy.ndarray
        1-D array over position bins, or 2-D array over (position, position).
        Modified in place.

    Returns
    -------
    numpy.ndarray
        ``image`` itself. For 2-D input both the rows and the columns of the
        excluded bins are zeroed.

    Raises
    ------
    ValueError
        If ``bounds`` does not hold an even number of values, or ``image`` is
        neither 1-D nor 2-D.
    """
    boundary_ind = np.searchsorted(x_bins, bounds, side='right')
    if boundary_ind.size % 2:
        raise ValueError('bounds must contain (start, end) pairs; got {} values'.format(boundary_ind.size))
    if image.ndim not in (1, 2):
        raise ValueError('image must be 1-D or 2-D; got {}-D'.format(image.ndim))

    # One row per (start_index, end_index) pair.
    for start_ind, end_ind in np.reshape(boundary_ind, [-1, 2]):
        if image.ndim == 1:
            image[start_ind:end_ind] = 0
        else:
            image[start_ind:end_ind, :] = 0
            image[:, start_ind:end_ind] = 0
    return image

In [2]:
%%time
# Load merged rec HDF store based on config

config_file = '/opt/data36/daliu/realtime/spykshrk/dec_100uv/bond.config.json'
config = json.load(open(config_file, 'r'))

hdf_file = os.path.join(config['files']['output_dir'],
                        '{}.rec_merged.h5'.format(config['files']['prefix']))

store = pd.HDFStore(hdf_file, mode='r')
spike_decode = store['rec_3']
stim_lockout = store['rec_11']

In [3]:
%%time
stim_lockout_ranges = stim_lockout.pivot(index='lockout_num',columns='lockout_state', values='timestamp')
stim_lockout_ranges = stim_lockout_ranges.reindex(columns=[1,0])


In [4]:
%%time
# Get table with decode for each spike and generate decode bin mask

dec_bin_size = 3000     # Decode bin size in samples (usually 30kHz)

x_no_anim_bounds = [69, 150, 150+102, 300, 300+104, 450]

dec_bins = np.floor((spike_decode['timestamp'] - spike_decode['timestamp'][0])/dec_bin_size).astype('int')
spike_decode['dec_bin'] = dec_bins


pos_upper = config['encoder']['position']['upper']
pos_lower = config['encoder']['position']['lower']
pos_num_bins = config['encoder']['position']['bins']
pos_bin_delta = ((pos_upper - pos_lower) / pos_num_bins)

x_bins = np.linspace(0, pos_bin_delta*(pos_num_bins-1), pos_num_bins)
x_bin_edges = np.linspace(0, pos_bin_delta*(pos_num_bins), pos_num_bins+1)

pos_kernel = gaussian(x_bins, x_bins[int(len(x_bins)/2)], 3)


In [5]:
# Get real position: build the animal description from the simulator section
# of the config and load its position data stream.
nspike_anim = nspike_data.AnimalInfo(
    **config['simulator']['nspike_animal_info']
)
pos = nspike_data.PosMatDataStream(nspike_anim)
pos_data = pos.data


In [6]:
# Transform position into a simpler table: per-well linear distance plus center
# velocity and segment index, indexed by timestamp in samples.
pos_data_time = pos_data.loc[:, 'time']

# .copy() makes this an independent frame, so adding columns below never writes
# into pos_data and does not trigger SettingWithCopyWarning.
pos_data_linpos = pos_data.loc[:, 'lin_dist_well'].copy()
pos_data_linpos.loc[:, 'lin_vel_center'] = pos_data.loc[:, ('lin_vel', 'well_center')]
pos_data_linpos.loc[:, 'seg_idx'] = pos_data.loc[:, ('seg_idx', 0)]
# NOTE(review): assumes 'time' is in seconds and timestamps are 30 kHz samples.
pos_data_linpos.loc[:, 'timestamps'] = pos_data_time*30000
pos_data_linpos = pos_data_linpos.set_index('timestamps')

In [7]:
# Convert real position to the realtime system's single linear coordinate:
# the center arm (seg_idx 1) keeps its distance from the center well, the left
# arm (seg_idx 2-3) is offset by 150 and the right arm (seg_idx 4-5) by 300.
center_pos_flat = pos_data_linpos.loc[pos_data_linpos['seg_idx'] == 1,
                                      'well_center'].rename('linpos_flat')

left_pos_flat = (pos_data_linpos.loc[(pos_data_linpos['seg_idx'] == 2) |
                                     (pos_data_linpos['seg_idx'] == 3),
                                     'well_left'] + 150).rename('linpos_flat')

right_pos_flat = (pos_data_linpos.loc[(pos_data_linpos['seg_idx'] == 4) |
                                      (pos_data_linpos['seg_idx'] == 5),
                                      'well_right'] + 300).rename('linpos_flat')

# Merge the three arms back into one time-ordered series.
linpos_flat = pd.concat([center_pos_flat, left_pos_flat, right_pos_flat]).sort_index()


In [8]:
%%time
# Compute artificial gaussian state transition matrix

# Setup transition matrix
transition_mat = np.ones([pos_num_bins, pos_num_bins])
for bin_ii in range(pos_num_bins):
    transition_mat[bin_ii, :] = gaussian(x_bins, x_bins[bin_ii], 3)

# uniform offset
uniform_gain = 0.01
uniform_dist = np.ones(transition_mat.shape)

# normalize transition matrix
transition_mat = transition_mat/( transition_mat.sum(axis=0)[None,:])
    
#normalize uniform offset
uniform_dist = uniform_dist/( uniform_dist.sum(axis=0)[None,:])
    
# apply uniform offset
transition_mat = transition_mat * (1 - uniform_gain) + uniform_dist * uniform_gain 

    
plt.figure(figsize=[10,10])
plt.imshow(transition_mat, cmap='hot')
plt.colorbar()
plt.show()

In [9]:

# Peek at the per-spike decode table (output is truncated by display.max_rows).
spike_decode

In [None]:

# Run the realtime PointProcessDecoder over the recorded spikes bin by bin, so
# its posteriors can be compared with the reference decode computed below.
# NOTE(review): the decoder's time_bin_size comes from config['pp_decoder'],
# while dec_bin was computed with dec_bin_size = 3000 — confirm they match.
pp_decoder = PointProcessDecoder(pos_range=[config['encoder']['position']['lower'],
                                            config['encoder']['position']['upper']],
                                 pos_bins=config['encoder']['position']['bins'],
                                 time_bin_size=config['pp_decoder']['bin_size'])

pp_decoder.select_ntrodes(config['simulator']['nspike_animal_info']['tetrodes'])

# Highest decode-bin index; bins are 0-based, so there are num_time_bins + 1 rows.
num_time_bins = spike_decode.loc[:, 'dec_bin'].max()

groups = spike_decode.groupby('dec_bin')

# Start "before" bin 0 so any empty bins preceding the first spike-bearing bin,
# including bin 0 itself, are stepped through the decoder as no-spike bins.
last_bin_id = -1

spykshrk_posteriors = np.zeros([num_time_bins+1, pos_num_bins])

for bin_id, spikes_in_bin in groups:
    # Advance through every empty bin between the previous spike-bearing bin and
    # this one; the range is empty when the two are adjacent.
    for bin_no_spk_id in range(last_bin_id + 1, bin_id):
        posterior = pp_decoder.increment_no_spike_bin()
        spykshrk_posteriors[bin_no_spk_id, :] = posterior

    # Feed each spike's per-position decode (columns x0..x{N-1}) to its ntrode.
    for ntrode_id, dec in zip(spikes_in_bin.loc[:, 'ntrode_id'].values,
                              spikes_in_bin.loc[:, 'x0':'x{:d}'.format(pos_num_bins-1)].values):
        pp_decoder.add_observation(ntrode_id, dec)

    posterior = pp_decoder.increment_bin()
    spykshrk_posteriors[bin_id, :] = posterior
    last_bin_id = bin_id
    

In [None]:
%%time
# Loop through each bin and generate the observation distribution from spikes in bin

dec_bin_ids = np.unique(dec_bins)
dec_est = np.zeros([dec_bin_ids[-1]+1, pos_num_bins])

start_bin_time = np.floor(spike_decode['timestamp'][0] / dec_bin_size) * dec_bin_size
dec_bin_times = np.arange(start_bin_time, start_bin_time + dec_bin_size * len(dec_est), dec_bin_size)

# initialize conditional intensity function
firing_rate = {ntrode_id: np.zeros(pos_num_bins) for ntrode_id in spike_decode['ntrode_id'].unique()}

groups = spike_decode.groupby('dec_bin')
bin_num_spikes = [0] * len(dec_est)

for bin_id, spikes_in_bin in groups:
    dec_in_bin = np.ones(pos_num_bins)
    
    
    bin_num_spikes[bin_id] = len(spikes_in_bin)
    
    # Count spikes for occupancy firing rate (conditional intensity function)
    for ntrode_id, pos in  spikes_in_bin.loc[:, ('ntrode_id', 'position')].values:
        firing_rate[ntrode_id][np.searchsorted(x_bins, pos, side='right') - 1] += 1
    
    for dec_ii, dec in enumerate(spikes_in_bin.loc[:, 'x0':'x{:d}'.format(pos_num_bins-1)].values):
        smooth_dec = np.convolve(dec, pos_kernel, mode='same')
        dec_in_bin = dec_in_bin * smooth_dec
        dec_in_bin = dec_in_bin / (np.sum(dec_in_bin) * pos_bin_delta)

    dec_est[bin_id, :] = dec_in_bin
    
    
# Smooth and normalize firing rate (conditional intensity function)
for fr_key in firing_rate.keys():
    firing_rate[fr_key] = np.convolve(firing_rate[fr_key], pos_kernel, mode='same')

    firing_rate[fr_key] = apply_no_anim_boundary(x_bins, x_no_anim_bounds, firing_rate[fr_key])
    
    firing_rate[fr_key] = firing_rate[fr_key] / (firing_rate[fr_key].sum() * pos_bin_delta)

In [None]:
# Precompute, per ntrode, the probability of no spike in one decode bin at each
# position: exp(-dt * firing_rate / occupancy), dt = decode bin length.

# Position occupancy as a density over the decode position bins.
# density=True replaces `normed`, which NumPy 1.24 removed from np.histogram.
occupancy, occ_bin_edges = np.histogram(linpos_flat, bins=x_bin_edges, density=True)

occupancy = np.convolve(occupancy, pos_kernel, mode='same')

prob_no_spike = {}
for tet_id, tet_fr in firing_rate.items():
    with np.errstate(divide='ignore', invalid='ignore'):
        rate_ratio = tet_fr / occupancy
    # 0/0 (no spikes and zero occupancy at a position) gives NaN, which would
    # turn every normalized likelihood into NaN; treat it as zero rate.
    rate_ratio[np.isnan(rate_ratio)] = 0
    # NOTE(review): assumes 30 kHz timestamps, so dec_bin_size/30000 is seconds.
    prob_no_spike[tet_id] = np.exp(-dec_bin_size/30000 * rate_ratio)


In [None]:
# Compute the likelihood of each decode bin. Bins with spikes start from their
# combined observation distribution (dec_est); empty bins start from all ones.
# Every bin is then multiplied by each ntrode's no-spike probability and
# normalized to a density over position.

likelihoods = np.ones(dec_est.shape)

for dec_ind, (num_spikes, dec_est_bin) in enumerate(zip(bin_num_spikes, dec_est)):
    if num_spikes > 0:
        likelihoods[dec_ind, :] = dec_est_bin

    for prob_no in prob_no_spike.values():
        likelihoods[dec_ind, :] *= prob_no

    # Normalize
    likelihoods[dec_ind, :] /= likelihoods[dec_ind, :].sum() * pos_bin_delta
    

In [None]:
# Iteratively calculate posterior: propagate the previous posterior (uniform to
# start) through the transition matrix, weight by this bin's likelihood, and
# normalize to a density over position.
posteriors = np.zeros(dec_est.shape)
last_posterior = np.ones(pos_num_bins)

for like_ii in range(len(likelihoods)):
    like = likelihoods[like_ii]
    unnormalized = like * (transition_mat * last_posterior).sum(axis=1)
    posteriors[like_ii, :] = unnormalized / (unnormalized.sum() * pos_bin_delta)
    last_posterior = posteriors[like_ii, :]

In [None]:
def plot_decode_2d(dec_est, dec_bin_times, stim_lockout_ranges, linpos_flat, plt_range,
                   samp_rate=30000, pos_max=450, arm_bounds=(74, 148, 256, 298, 407),
                   lockout_marker_y=440):
    """Plot a decoded position distribution over a time window on the current axes.

    Parameters
    ----------
    dec_est : numpy.ndarray
        Decoded distribution, one row per decode bin and one column per
        position bin.
    dec_bin_times : numpy.ndarray
        Start time (samples) of each row of ``dec_est``; must have the same
        length as ``dec_est``.
    stim_lockout_ranges : pandas.DataFrame
        Lockout timestamps (samples) with column ``1`` used as span start and
        column ``0`` as span end.
    linpos_flat : pandas.Series
        Linearized actual position indexed by timestamp (samples).
    plt_range : sequence of two floats
        ``[start, end]`` of the window to plot, in seconds.
    samp_rate : float, optional
        Samples per second used to convert timestamps to seconds.
    pos_max : float, optional
        Upper end of the position axis used for the image extent.
    arm_bounds : iterable of float, optional
        Positions at which dashed gray horizontal guide lines are drawn.
    lockout_marker_y : float, optional
        Height at which lockout start/end markers are drawn.
    """
    stim_lockout_ranges_sec = stim_lockout_ranges / samp_rate
    # Keep lockouts whose start is after the window start and whose end is
    # before the window end.
    stim_lockout_range_sec_sub = stim_lockout_ranges_sec[
        (stim_lockout_ranges_sec[1] > plt_range[0]) & (stim_lockout_ranges_sec[0] < plt_range[1])]

    time_mask = (dec_bin_times > plt_range[0] * samp_rate) & (dec_bin_times < plt_range[1] * samp_rate)
    plt.imshow(dec_est[time_mask].transpose(),
               extent=[plt_range[0], plt_range[1], 0, pos_max], origin='lower', aspect='auto',
               cmap='hot', zorder=0)

    plt.colorbar()

    # Plot linear position
    linpos_index_s = linpos_flat.index / samp_rate
    index_mask = (linpos_index_s > plt_range[0]) & (linpos_index_s < plt_range[1])

    plt.plot(linpos_index_s[index_mask],
             linpos_flat.values[index_mask], 'c.', zorder=1, markersize=5)

    # Lockout start/end markers joined by a line, plus a shaded span for each.
    plt.plot(stim_lockout_range_sec_sub.values.transpose(),
             np.tile([[lockout_marker_y], [lockout_marker_y]], [1, len(stim_lockout_range_sec_sub)]),
             'c-*')

    for lockout in stim_lockout_range_sec_sub.values:
        plt.axvspan(lockout[0], lockout[1], facecolor='#AAAAAA', alpha=0.3)

    for arm_bound in arm_bounds:
        plt.plot(plt_range, [arm_bound, arm_bound], '--', color='gray')

    plt.xlabel('time (s)')
    plt.ylabel('linear position')


In [None]:
# Start time of each decode bin converted from samples to seconds
# (assumes 30 kHz timestamps), for inspection.
dec_bin_times/30000

In [None]:
# Time windows (seconds) to plot. Each window gets one figure comparing the
# realtime decoder's posteriors (top) with the reference posteriors (bottom).
# NOTE(review): 2461 looks like a session time offset — document its origin.
plt_ranges = [[2461 + 800, 2461 + 900]]

for plt_range in plt_ranges:
    plt.figure(figsize=[100, 20])
    for subplot_num, decoded in enumerate([spykshrk_posteriors, posteriors], start=1):
        plt.subplot(2, 1, subplot_num)
        plot_decode_2d(decoded, dec_bin_times, stim_lockout_ranges, linpos_flat, plt_range)

plt.show()