In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import json
import os

import spykshrk.realtime.simulator.nspike_data as nspike_data

# Notebook-wide pandas display settings (attribute form of pd.set_option).
#pd.set_option('float_format', '{:,.2f}'.format)
#pd.set_option('display.width', 180)
pd.options.display.precision = 4
pd.options.display.max_rows = 10

# Short alias for pandas' IndexSlice helper.
idx = pd.IndexSlice

In [2]:
# Load merged rec HDF store based on config

config_file = '/opt/data36/daliu/realtime/spykshrk/ripple_dec/bond.config.json'
# Context manager so the config file handle is closed as soon as parsing finishes
# (a bare json.load(open(...)) leaves it open until garbage collection).
with open(config_file, 'r') as config_fh:
    config = json.load(config_fh)

# The merged record file lives in the configured output directory and is named
# after the configured prefix.
hdf_file = os.path.join(config['files']['output_dir'],
                        '{}.rec_merged.h5'.format(config['files']['prefix']))

# Opened read-only. The store is deliberately left open under the name `store`
# so that other cells can keep reading from it.
store = pd.HDFStore(hdf_file, mode='r')
# 'rec_3': per-spike decode table (read below via 'timestamp' and 'x0'..'xN' columns).
spike_decode = store['rec_3']
# 'rec_11': stimulation lockout records ('lockout_num', 'lockout_state', 'timestamp').
stim_lockout = store['rec_11']

In [3]:
# Reshape the lockout records to one row per lockout_num, holding the timestamp
# recorded for each lockout_state; columns are then ordered state 1, state 0.
stim_lockout_ranges = (
    stim_lockout
    .pivot(index='lockout_num', columns='lockout_state', values='timestamp')
    .reindex(columns=[1, 0])
)


In [4]:
# Get table with decode for each spike and generate decode bin mask

dec_bin_size = 750     # Decode bin size in samples (usually 30kHz)

# Bin 0 is anchored on the dec_bin_size grid point at or before the first spike,
# not on the first spike itself. The bin time axis built for the decode estimate
# starts at floor(first_timestamp / dec_bin_size) * dec_bin_size, so anchoring
# the bins on the raw first timestamp would leave the bin edges and their time
# labels out of step by up to dec_bin_size - 1 samples.
# .iloc[0] takes the first row by position, independent of the index labels.
first_spike_time = spike_decode['timestamp'].iloc[0]
first_grid_index = np.floor(first_spike_time / dec_bin_size)

dec_bins = (np.floor(spike_decode['timestamp'] / dec_bin_size) - first_grid_index).astype('int')
# Stored on the table so each spike carries the id of the decode bin it falls in.
spike_decode['dec_bin'] = dec_bins

In [5]:
# Loop through each bin and generate normalized posterior estimate of location

pos_cfg = config['encoder']['position']
pos_num_bins = pos_cfg['bins']
# Width of a single position bin: configured position range divided by bin count.
pos_bin_delta = (pos_cfg['upper'] - pos_cfg['lower']) / pos_num_bins

dec_bin_ids = np.unique(dec_bins)
# One row per decode bin id from 0 to the largest id, one column per position bin.
# Bin ids that contain no spikes keep their all-zero row.
dec_est = np.zeros([dec_bin_ids[-1]+1, pos_num_bins])

# Time (in samples) of the dec_bin_size grid point at or before the first spike.
start_bin_time = np.floor(spike_decode['timestamp'].iloc[0] / dec_bin_size) * dec_bin_size
dec_est_bin_time = start_bin_time    # same value, kept under its original second name

# Start time of every row of dec_est. Built from an integer range so its length
# always equals len(dec_est), whatever the floating-point value of the start.
dec_bin_times = start_bin_time + dec_bin_size * np.arange(len(dec_est))

first_pos_col = 'x0'
last_pos_col = 'x{:d}'.format(pos_num_bins-1)

# A single groupby pass visits the bins in ascending id order with rows in table
# order, replacing a full-table boolean mask per bin.
for bin_id, spikes_in_bin in spike_decode.groupby('dec_bin'):
    dec_in_bin = np.ones(pos_num_bins)
    for dec in spikes_in_bin.loc[:, first_pos_col:last_pos_col].values:
        # Multiply in this spike's per-position values, then renormalize at once
        # so that sum(dec_in_bin) * pos_bin_delta == 1 after every spike.
        dec_in_bin *= dec
        dec_in_bin = dec_in_bin / (np.sum(dec_in_bin) * pos_bin_delta)

    dec_est[bin_id, :] = dec_in_bin

In [6]:
# Get real position

# Keyword arguments for the animal description come straight from the config.
animal_info_kwargs = config['simulator']['nspike_animal_info']
nspike_anim = nspike_data.AnimalInfo(**animal_info_kwargs)

# Position stream for that animal; its `data` attribute is the table used below.
pos = nspike_data.PosMatDataStream(nspike_anim)
pos_data = pos.data


In [7]:
# Transform position into simpler table with only linear position
pos_data_time = pos_data.loc[:, 'time']

# Explicit copy: the columns added below are written to an independent frame
# instead of a slice of pos_data, which avoids chained-assignment warnings and
# any ambiguity about whether pos_data itself is being modified.
pos_data_linpos = pos_data.loc[:,'lin_dist_well'].copy()
pos_data_linpos.loc[:, 'lin_vel_center'] = pos_data.loc[:,('lin_vel', 'well_center')]
pos_data_linpos.loc[:, 'seg_idx'] = pos_data.loc[:,('seg_idx', 0)]
# Position time scaled by 30000 to serve as the index.
# NOTE(review): presumably seconds -> samples on a 30 kHz clock; confirm.
pos_data_linpos.loc[:,'timestamps'] = pos_data_time*30000
pos_data_linpos = pos_data_linpos.set_index('timestamps')

In [8]:
# Convert real pos to realtime system linear map (single linear coordinate)

seg_idx = pos_data_linpos['seg_idx']

# Rows are split by seg_idx and each group reads its own distance column:
# seg 1 -> 'well_center' unchanged, segs 2/3 -> 'well_left' + 150,
# segs 4/5 -> 'well_right' + 300. All three share the name 'linpos_flat'.
center_pos_flat = pos_data_linpos.loc[seg_idx == 1, 'well_center'].rename('linpos_flat')
left_pos_flat = (pos_data_linpos.loc[seg_idx.isin([2, 3]), 'well_left'] + 150).rename('linpos_flat')
right_pos_flat = (pos_data_linpos.loc[seg_idx.isin([4, 5]), 'well_right'] + 300).rename('linpos_flat')

# Stack the three pieces and restore index (timestamp) order.
linpos_flat = pd.concat([center_pos_flat, left_pos_flat, right_pos_flat]).sort_index()

In [16]:
# Display the spike timestamps divided by 30000.
spike_decode['timestamp'].div(30000)

In [22]:

# Plot the decode estimate as a heat map for each time window, overlaid with the
# flattened linear position and the stimulation lockout spans.

samples_per_sec = 30000            # divisor used for every samples -> seconds conversion below
plt_ranges = [[5200, 5400]]        # [start, end] of each window, in seconds
# Fixed positions of the horizontal dashed guide lines.
# NOTE(review): presumably segment boundaries on the flattened axis; confirm.
guide_positions = (74, 148, 256, 298, 407)

# These do not depend on the window, so they are converted once.
stim_lockout_ranges_sec = stim_lockout_ranges/samples_per_sec
linpos_index_s = linpos_flat.index / samples_per_sec

for plt_range in plt_ranges:
    # Keep lockouts whose state-1 timestamp is after the window start and whose
    # state-0 timestamp is before the window end.
    stim_lockout_range_sec_sub = stim_lockout_ranges_sec[(stim_lockout_ranges_sec[1] > plt_range[0]) &
                                                         (stim_lockout_ranges_sec[0] < plt_range[1])]

    plt.figure(figsize=[400,15])

    # Rows of dec_est whose bin start time lies strictly inside the window.
    dec_bin_mask = ((dec_bin_times > plt_range[0]*samples_per_sec) &
                    (dec_bin_times < plt_range[1]*samples_per_sec))
    # Transposed so time runs along x and position bins along y.
    # NOTE(review): the y extent 0..450 is hard-coded; confirm it matches the
    # configured position range.
    plt.imshow(dec_est[dec_bin_mask].transpose(),
               extent=[plt_range[0], plt_range[1], 0, 450], origin='lower', aspect='auto', cmap='hot', zorder=0)

    index_mask = (linpos_index_s > plt_range[0]) & (linpos_index_s < plt_range[1])
    plt.plot(linpos_index_s[index_mask],
             linpos_flat.values[index_mask], 'c.', zorder=1, markersize=5)

    # The loop variable is deliberately NOT called `stim_lockout`: that name is
    # the lockout table loaded from the HDF store, and rebinding it here would
    # break any later re-run of the cell that pivots it.
    for lockout_span in stim_lockout_range_sec_sub.values:
        plt.axvspan(lockout_span[0], lockout_span[1], facecolor='#AAAAAA', alpha=0.3)

    for guide in guide_positions:
        plt.plot(plt_range, [guide, guide], '--', color='gray')

    # Inside the loop so every window's figure gets its own colorbar.
    plt.colorbar()

plt.show()