In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json
import os

from spykshrk.realtime.simulator import nspike_data

from spykshrk.franklab.pp_decoder.pp_clusterless import OfflinePPDecoder
from spykshrk.franklab.pp_decoder.data_containers import EncodeSettings, DecodeSettings, SpikeObservation, \
                                                         LinearPositionContainer

%load_ext Cython

%matplotlib inline

#pd.set_option('float_format', '{:,.2f}'.format)
pd.set_option('display.precision', 4)
pd.set_option('display.max_rows', 10)
#pd.set_option('display.width', 180)

idx = pd.IndexSlice


In [17]:
# Load config file and data

config_file = '/home/daliu/Src/spykshrk_realtime/config/bond_single.json'
config = json.load(open(config_file, 'r'))

# Main hdf5 data source file name
hdf_file = os.path.join(config['files']['output_dir'],
                        '{}.rec_merged.h5'.format(config['files']['prefix']))

# Extract just encode and decode settings from config
encode_settings = EncodeSettings(config)
decode_settings = DecodeSettings(config)

# Open data file
store = pd.HDFStore(hdf_file, mode='r')

# Encapsulate Spike Observation panda table in container
observ_obj = SpikeObservation(store['rec_3'])

# Grab stimulation lockout times
stim_lockout = store['rec_11']

# Grab animal linearized real position
nspike_anim = nspike_data.AnimalInfo(**config['simulator']['nspike_animal_info'])
pos = nspike_data.PosMatDataStream(nspike_anim)
pos_data = pos.data

# Encapsulate linear position
lin_obj = LinearPositionContainer(pos_data, encode_settings)

In [14]:
# Linearized position data, example of MultiIndexing pandas table
lin_obj.pos_data

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,time,time,lin_dist_well,lin_dist_well,lin_dist_well,seg_idx,lin_vel,lin_vel,lin_vel
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,time,timestamp,well_center,well_left,well_right,seg_idx,well_center,well_left,well_right
day,epoch,timestamp,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2
4,1,73830339.0,2461.0,73830339.0,27.8,142.1,144.5,1.0,7.5,134.4,136.4
4,1,73831341.0,2461.0,73831341.0,26.9,143.0,145.4,1.0,7.5,134.4,136.4
4,1,73832343.0,2461.1,73832343.0,25.5,144.3,146.8,1.0,6.7,134.5,136.6
4,1,73833342.0,2461.1,73833342.0,24.6,145.2,147.7,1.0,5.8,134.1,136.1
4,1,73834344.0,2461.1,73834344.0,23.3,146.6,149.0,1.0,4.9,133.2,135.1
4,1,...,...,...,...,...,...,...,...,...,...
4,1,102145374.0,3404.8,102145374.0,7.0,162.8,165.3,1.0,-4.0,-128.6,-130.5
4,1,102146376.0,3404.9,102146376.0,7.1,162.8,165.2,1.0,-4.2,-131.4,-133.4
4,1,102147378.0,3404.9,102147378.0,7.5,162.3,164.8,1.0,-4.4,-133.7,-135.7
4,1,102148377.0,3404.9,102148377.0,7.5,162.3,164.8,1.0,-4.5,-135.4,-137.5


In [4]:
# Up sampling position data to 30 samples/bin, using backfill to interpolate
lin_obj.get_resampled(30)

<spykshrk.franklab.pp_decoder.data_containers.LinearPositionContainer at 0x7f00217c4e10>

In [5]:
# Down sampling position data to 30000 samples/bin, dropping data points
lin_obj.get_resampled(30000)

<spykshrk.franklab.pp_decoder.data_containers.LinearPositionContainer at 0x7f0064809e80>

In [18]:
# Observation distribution of each spike in a single epoch. This is calculated and cached from 
# an encoding model in the realtime module. Currently this is only valid for a single epoch's data.

observ_obj.spike_dec

Unnamed: 0_level_0,Unnamed: 1_level_0,position,rec_ind,x000,x001,x002,x003,x004,x005,x006,x007,...,x440,x441,x442,x443,x444,x445,x446,x447,x448,x449
timestamp,ntrode_id,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
73830048,29,0.0,1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
73830066,13,0.0,1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
73830144,14,0.0,2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
73830192,14,0.0,6,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
73830204,13,0.0,5,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
102149649,11,7.6,237333,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
102149697,12,7.6,55281,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
102149817,17,7.6,96729,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
102149925,11,7.6,237337,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [19]:
# Assign bins (300 samples == 10ms) to each spike based on its timestamp, bins stored as dec_bin column.

observ_obj.get_observations_bin_assigned(300)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,position,rec_ind,x000,x001,x002,x003,x004,x005,x006,x007,...,x441,x442,x443,x444,x445,x446,x447,x448,x449,dec_bin_start
timestamp,ntrode_id,dec_bin,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1
73830048,29,0,0.0,1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,73830000
73830066,13,0,0.0,1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,73830000
73830144,14,0,0.0,2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,73830000
73830192,14,0,0.0,6,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,73830000
73830204,13,0,0.0,5,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,73830000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
102149649,11,94398,7.6,237333,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,102149400
102149697,12,94398,7.6,55281,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,102149400
102149817,17,94399,7.6,96729,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,102149700
102149925,11,94399,7.6,237337,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,102149700


In [10]:
observ_obj.spike_dec.pivot_table(index=['dec_bin', 'dec_bin_start', 'timestamp', 'ntrode_id'])
                                 

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,position,x000,x001,x002,x003,x004,x005,x006,x007,x008,...,x440,x441,x442,x443,x444,x445,x446,x447,x448,x449
dec_bin,dec_bin_start,timestamp,ntrode_id,rec_ind,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1
0,73830000,73830048,29,1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0,73830000,73830066,13,1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0,73830000,73830144,14,2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0,73830000,73830192,14,6,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0,73830000,73830204,13,5,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
94398,102149400,102149649,11,237333,7.6,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
94398,102149400,102149697,12,55281,7.6,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
94399,102149700,102149817,17,96729,7.6,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
94399,102149700,102149925,11,237337,7.6,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [11]:
observ_obj.spike_dec.loc[:, ['x000', 'x001']]

Unnamed: 0,x000,x001
0,0.0,0.0
1,0.0,0.0
2,0.0,0.0
3,0.0,0.0
4,0.0,0.0
...,...,...
303106,0.0,0.0
303107,0.0,0.0
303108,0.0,0.0
303109,0.0,0.0


In [None]:
# For each time bin, compute the product of the distribution stored in columns x0:x449.
# This estimates the probability distribution of position at each time bin. Refer to 
# spykshrk.franklab.pp_decoder.pp_clusterless.OfflinePPDecoder.calc_observation_intensity
# for analysis code that uses groupby.

spike_decode = observ_obj.get_observations_bin_assigned(3000)
groups = spike_decode.groupby('dec_bin')

def prod_dist(df):
    norm_prod = np.ones(450)
    for row in df.loc[:,'x0':'x449'].values:
        norm_prod = norm_prod * row
        norm_prod = norm_prod / norm_prod.sum()
    prod_ser = pd.Series(norm_prod, index=['x{}'.format(bin_id) for bin_id in range(450)])
    prod_ser['timestamp'] = df['dec_bin_start'].iloc[0]
    return prod_ser

observ_binned = groups.apply(prod_dist)

observ_binned

In [None]:
# Convert linearized position segments onto a single axis to match the decoded position mapping.
# This function uses the query command of Panda dataframes.
# e.g.:
# right_pos_flat = (self.pos_data.query('@self.pos_data.seg_idx.seg_idx == 4 | '
#                                       '@self.pos_data.seg_idx.seg_idx == 5').
#                   loc[:, ('lin_dist_well', 'well_right')]) + self.arm_coord[2][0]

single_axis_lin_pos = lin_obj.get_resampled(3000).get_mapped_single_axis()

In [None]:
# Convert stim lockout from digital output state into time intervals

stim_lockout_ranges = stim_lockout.pivot(index='lockout_num',columns='lockout_state', values='timestamp')
stim_lockout_ranges = stim_lockout_ranges.reindex(columns=[1,0])


In [None]:
%matplotlib inline
plt.figure(figsize=[20,10])
plt.imshow(observ_binned[3000:4500].T, origin='lower', aspect='auto', cmap='hot', zorder=0)
plt.plot(single_axis_lin_pos.values[3000:4500], '.')
plt.show()