In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import json
import os

import spykshrk.realtime.simulator.nspike_data as nspike_data


# Notebook-wide pandas display settings: 4 decimal places, at most 10 rows
# shown per frame. (Disabled alternatives: a '{:,.2f}' float_format and a
# display.width of 180.)
pd.options.display.precision = 4
pd.options.display.max_rows = 10

# Short alias for building label-based slices with .loc
idx = pd.IndexSlice

In [2]:
# Load merged rec HDF store based on config

#config_file = '/opt/data36/daliu/realtime/spykshrk/test/test_animal.config.json'
config_file = '/opt/data36/daliu/realtime/spykshrk/ripple_dec/bond.config.json'

# Read the JSON config inside a context manager so the file handle is closed
# as soon as parsing finishes (the bare open() call leaked it).
with open(config_file, 'r') as config_fh:
    config = json.load(config_fh)

# The merged record file lives in the configured output directory and is named
# after the configured prefix.
hdf_file = os.path.join(config['files']['output_dir'],
                        '{}.rec_merged.h5'.format(config['files']['prefix']))

# Opened read-only; the store is intentionally left open because later cells
# read further records from it (e.g. store['rec_11']).
store = pd.HDFStore(hdf_file, mode='r')

pos_num_bins = config['encoder']['position']['bins']

# Decoder output table; later cells read its 'timestamp', 'real_pos_time',
# 'real_pos' and 'x0'..'x449' columns.
decoder_df = store['rec_4']

In [10]:
# Average difference between the 'real_pos_time' and 'timestamp' columns of
# the decoder table (same units as those columns).
(decoder_df['real_pos_time'] - decoder_df['timestamp']).mean()

In [4]:
# Display the decoder table loaded from store['rec_4'] for inspection.
decoder_df

In [11]:
# Get real position

nspike_anim = nspike_data.AnimalInfo(**config['simulator']['nspike_animal_info'])
pos = nspike_data.PosMatDataStream(nspike_anim)
pos_data = pos.data

# Transform position into simpler table with only linear position
pos_data_time = pos_data.loc[:, 'time']

# Take an explicit copy of the 'lin_dist_well' sub-table: columns are added to
# it below, and assigning into an un-copied slice of pos_data raises
# SettingWithCopyWarning and leaves it ambiguous whether pos_data is touched.
pos_data_linpos = pos_data.loc[:, 'lin_dist_well'].copy()
pos_data_linpos.loc[:, 'lin_vel_center'] = pos_data.loc[:, ('lin_vel', 'well_center')]
pos_data_linpos.loc[:, 'seg_idx'] = pos_data.loc[:, ('seg_idx', 0)]
# NOTE(review): 30000 is presumably the sample clock rate (time in seconds ->
# timestamps in samples) - confirm against the acquisition settings.
pos_data_linpos.loc[:, 'timestamps'] = pos_data_time*30000
pos_data_linpos = pos_data_linpos.set_index('timestamps')

# Convert real pos to realtime system linear map (single linear coordinate)
# Segment 1 uses 'well_center' as is; segments 2-3 use 'well_left' shifted by
# 150; segments 4-5 use 'well_right' shifted by 300.
seg_idx = pos_data_linpos['seg_idx']
on_center = seg_idx == 1
on_left = (seg_idx == 2) | (seg_idx == 3)
on_right = (seg_idx == 4) | (seg_idx == 5)

# Single .loc[row_mask, column] selections replace the chained
# frame[mask][column] lookups.
center_pos_flat = pos_data_linpos.loc[on_center, 'well_center'].rename('linpos_flat')
left_pos_flat = (pos_data_linpos.loc[on_left, 'well_left'] + 150).rename('linpos_flat')
right_pos_flat = (pos_data_linpos.loc[on_right, 'well_right'] + 300).rename('linpos_flat')

# Stitch the three pieces back together in timestamp order.
linpos_flat = pd.concat([center_pos_flat, left_pos_flat, right_pos_flat])
linpos_flat = linpos_flat.sort_index()


In [12]:
# Stimulation lockout records from the merged store.
stim_lockout = store['rec_11']

# Reshape to one row per lockout_num with one timestamp column per
# lockout_state, then order the columns as state 1 followed by state 0.
# NOTE(review): presumably state 1 marks lockout start and state 0 its end -
# confirm against the recorder.
stim_lockout_ranges = (
    stim_lockout
    .pivot(index='lockout_num', columns='lockout_state', values='timestamp')
    .reindex(columns=[1, 0])
)

In [13]:
def plot_decode_2d(dec_est, dec_bin_times, stim_lockout_ranges, linpos_flat, plt_range,
                   sampling_rate=30000, pos_extent=450,
                   guide_positions=(74, 148, 256, 298, 407)):
    """Draw decoder estimates as a heat map with position and lockout overlays.

    Everything is drawn on the current matplotlib figure/axes.

    Parameters
    ----------
    dec_est : 2-D array
        Decoder estimates, one row per entry of ``dec_bin_times``. Rows whose
        time falls strictly inside the window are selected and transposed so
        that time runs along the x axis.
    dec_bin_times : array-like
        Time of each row of ``dec_est``, in timestamp units
        (seconds * ``sampling_rate``).
    stim_lockout_ranges : pandas.DataFrame
        Lockout timestamps with columns ``1`` and ``0`` (in that order), in
        timestamp units. Only rows with column ``1`` after the window start and
        column ``0`` before the window end are drawn.
    linpos_flat : pandas.Series
        Position values indexed by timestamp; drawn as cyan dots.
    plt_range : sequence of two numbers
        ``[start, stop]`` of the plotted window, in seconds.
    sampling_rate : number, optional
        Timestamp units per second used to convert between timestamps and
        seconds. Defaults to the previously hard-coded 30000.
    pos_extent : number, optional
        Upper bound of the y axis given to ``imshow``. Defaults to the
        previously hard-coded 450. The lockout marker line is drawn 10 units
        below it (440 with the default).
    guide_positions : iterable of numbers, optional
        y values at which dashed gray horizontal guide lines are drawn.
        NOTE(review): presumably track segment boundaries - confirm.
    """
    range_start, range_stop = plt_range[0], plt_range[1]

    # Lockout windows in seconds, restricted to those inside the plotted range.
    stim_lockout_ranges_sec = stim_lockout_ranges / sampling_rate
    stim_lockout_range_sec_sub = stim_lockout_ranges_sec[
        (stim_lockout_ranges_sec[1] > range_start) &
        (stim_lockout_ranges_sec[0] < range_stop)]

    # Plain boolean array so the mask indexes dec_est positionally regardless
    # of whether dec_bin_times is a Series or an ndarray.
    bin_mask = np.asarray((dec_bin_times > range_start * sampling_rate) &
                          (dec_bin_times < range_stop * sampling_rate))

    plt.imshow(dec_est[bin_mask].transpose(),
               extent=[range_start, range_stop, 0, pos_extent],
               origin='lower', aspect='auto', cmap='hot', zorder=0)

    plt.colorbar()

    # Plot linear position
    linpos_index_s = linpos_flat.index / sampling_rate
    index_mask = (linpos_index_s > range_start) & (linpos_index_s < range_stop)

    plt.plot(linpos_index_s[index_mask],
             linpos_flat.values[index_mask], 'c.', zorder=1, markersize=5)

    # One short horizontal marker per lockout near the top of the axis, plus a
    # translucent gray band spanning the lockout interval.
    lockout_spans = stim_lockout_range_sec_sub.values
    marker_y = pos_extent - 10
    plt.plot(lockout_spans.transpose(),
             np.tile([[marker_y], [marker_y]], [1, len(lockout_spans)]), 'c-*')

    for lockout_span in lockout_spans:
        plt.axvspan(lockout_span[0], lockout_span[1], facecolor='#AAAAAA', alpha=0.3)

    # Dashed horizontal guide lines across the whole window.
    for guide_y in guide_positions:
        plt.plot(plt_range, [guide_y, guide_y], '--', color='gray')


In [16]:
# Display the flattened linear-position series for inspection.
linpos_flat

In [23]:
# Windows to render, as [start, stop] pairs in seconds; one figure per window.
plt_ranges = [[2461, 3404]]

# Inputs that do not change between windows: the x0..x449 columns as a plain
# array, the bin timestamps, and the 'real_pos' column re-indexed by
# 'real_pos_time'.
decode_matrix = decoder_df.loc[:, 'x0':'x449'].values
decode_bin_times = decoder_df['timestamp']
real_pos_by_time = decoder_df.set_index('real_pos_time')['real_pos']

for plt_range in plt_ranges:
    # Very wide canvas so a long window stays readable when scrolled.
    plt.figure(figsize=(400, 10))
    plot_decode_2d(decode_matrix, decode_bin_times, stim_lockout_ranges,
                   real_pos_by_time, plt_range)
plt.show()