In [4]:
# Make the working source tree importable without installing spykshrk through
# setuptools: resolve the source root relative to this notebook and put it on
# python's module search path.
import sys, os
root_depth = 3
# `_dh` is IPython's directory history; its first entry is the directory the
# notebook kernel was started in.
notebook_dir = globals()['_dh'][0]
root_path = os.path.abspath(os.path.join(notebook_dir, '../'*root_depth))
# Strip every existing occurrence of the root so re-running this cell never
# leaves duplicates, then register it exactly once at the end of the path.
while root_path in sys.path:
    sys.path.remove(root_path)
sys.path.append(root_path)
# Alternatively set root path as current working directory
#os.chdir(root_path)

In [5]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json
import os
import holoviews as hv

from spykshrk.realtime.simulator import nspike_data

from spykshrk.franklab.pp_decoder.pp_clusterless import OfflinePPDecoder
from spykshrk.franklab.data_containers import EncodeSettings, DecodeSettings, SpikeObservation, \
                                              LinearPosition, StimLockout, Posteriors
from spykshrk.franklab.pp_decoder.visualization import DecodeVisualizer

# Load the Cython IPython extension (makes the %%cython cell magic available)
%load_ext Cython

# Render matplotlib figures inline in the notebook output
%matplotlib inline

# pandas display settings: 4 digits of precision, and truncate rendered tables
# to at most 10 rows and 15 columns
#pd.set_option('float_format', '{:,.2f}'.format)
pd.set_option('display.precision', 4)
pd.set_option('display.max_rows', 10)
pd.set_option('display.max_columns', 15)
#pd.set_option('display.width', 180)

# Select bokeh as the HoloViews rendering backend
hv.extension('bokeh')

# Shorthand for writing MultiIndex slices
idx = pd.IndexSlice


In [6]:
# Load config file and data

# NOTE(review): absolute, user-specific path -- adjust for your checkout.
config_file = '/home/daliu/Src/spykshrk_realtime/config/bond_single.json'
# Read the JSON config inside a context manager so the file handle is closed
# as soon as parsing finishes instead of being left to the garbage collector.
with open(config_file, 'r') as config_fp:
    config = json.load(config_fp)
# Only the first configured day / epoch is used in this notebook
day = config['simulator']['nspike_animal_info']['days'][0]
epoch = config['simulator']['nspike_animal_info']['epochs'][0]

# Main hdf5 data source file name
hdf_file = os.path.join(config['files']['output_dir'],
                        '{}.rec_merged.h5'.format(config['files']['prefix']))

# Extract just encode and decode settings from config
encode_settings = EncodeSettings(config)
decode_settings = DecodeSettings(config)

# Open data file (read-only).
# NOTE(review): the store is left open so other cells can keep reading from it;
# call store.close() once the notebook no longer needs it.
store = pd.HDFStore(hdf_file, mode='r')

# Encapsulate the spike observation pandas table ('rec_3') in its container
observ_obj = SpikeObservation.from_realtime(store['rec_3'], day=day, epoch=epoch, enc_settings=encode_settings)

# Grab stimulation lockout times ('rec_11')
stim_lockout = StimLockout.from_realtime(store['rec_11'], enc_settings=encode_settings)

# Grab animal linearized real position
nspike_anim = nspike_data.AnimalInfo(**config['simulator']['nspike_animal_info'])
pos = nspike_data.PosMatDataStream(nspike_anim)
pos_data = pos.data

# Encapsulate linear position
lin_obj = LinearPosition.from_nspike_posmat(pos_data, encode_settings)

In [7]:
# List the names of the extra metadata attributes attached to the LinearPosition
# container (presumably pandas' `_metadata` subclassing hook -- confirm in data_containers)
lin_obj._metadata

['kwds', 'history', 'desc', 'user_key', 'sampling_rate', 'arm_coord']

In [8]:
# Linearized position data; an example of a pandas table that uses a MultiIndex
# for both the rows (day, epoch, timestamp, time) and the columns
lin_obj

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,lin_dist_well,lin_dist_well,lin_dist_well,lin_vel,lin_vel,lin_vel,seg_idx
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,well_center,well_left,well_right,well_center,well_left,well_right,seg_idx
day,epoch,timestamp,time,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2
4,1,73830339,2461.0,27.8,142.1,144.5,7.5,134.4,136.4,1.0
4,1,73831341,2461.0,26.9,143.0,145.4,7.5,134.4,136.4,1.0
4,1,73832343,2461.1,25.5,144.3,146.8,6.7,134.5,136.6,1.0
4,1,73833342,2461.1,24.6,145.2,147.7,5.8,134.1,136.1,1.0
4,1,73834344,2461.1,23.3,146.6,149.0,4.9,133.2,135.1,1.0
4,1,...,...,...,...,...,...,...,...,...
4,1,102145374,3404.8,7.0,162.8,165.3,-4.0,-128.6,-130.5,1.0
4,1,102146376,3404.9,7.1,162.8,165.2,-4.2,-131.4,-133.4,1.0
4,1,102147378,3404.9,7.5,162.3,164.8,-4.4,-133.7,-135.7,1.0
4,1,102148377,3404.9,7.5,162.3,164.8,-4.5,-135.4,-137.5,1.0


In [9]:
# Upsample the position data to bins of 30 timestamp samples, using backfill to interpolate.
# The returned table gains a `bin` column holding the running bin number.
lin_obj.get_resampled(30)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,lin_dist_well,lin_dist_well,lin_dist_well,lin_vel,lin_vel,lin_vel,seg_idx,bin
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,well_center,well_left,well_right,well_center,well_left,well_right,seg_idx,Unnamed: 11_level_1
day,epoch,timestamp,time,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2
4,1,73830360,2461.0,,,,,,,,0
4,1,73830390,2461.0,,,,,,,,1
4,1,73830420,2461.0,,,,,,,,2
4,1,73830450,2461.0,,,,,,,,3
4,1,73830480,2461.0,,,,,,,,4
4,1,...,...,...,...,...,...,...,...,...,...
4,1,102149250,3405.0,45.5,124.4,126.9,11.3,-11.3,-11.3,1.0,943963
4,1,102149280,3405.0,101.7,68.1,137.6,0.1,-0.1,0.1,2.0,943964
4,1,102149310,3405.0,84.3,85.6,120.1,21.4,-23.5,20.8,2.0,943965
4,1,102149340,3405.0,170.1,203.5,2.2,-0.1,-0.1,0.1,5.0,943966


In [10]:
# Downsample the position data to bins of 30000 timestamp samples
# (1 s, since time = timestamp / 30000), dropping data points
lin_obj.get_resampled(30000)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,lin_dist_well,lin_dist_well,lin_dist_well,lin_vel,lin_vel,lin_vel,seg_idx,bin
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,well_center,well_left,well_right,well_center,well_left,well_right,seg_idx,Unnamed: 11_level_1
day,epoch,timestamp,time,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2
4,1,73860000,2462.0,,,,,,,,0
4,1,73890000,2463.0,,,,,,,,1
4,1,73920000,2464.0,,,,,,,,2
4,1,73950000,2465.0,,,,,,,,3
4,1,73980000,2466.0,,,,,,,,4
4,1,...,...,...,...,...,...,...,...,...,...
4,1,102000000,3400.0,,,,,,,,938
4,1,102030000,3401.0,,,,,,,,939
4,1,102060000,3402.0,,,,,,,,940
4,1,102090000,3403.0,,,,,,,,941


In [11]:
# Observation distribution of each spike in a single epoch. This is calculated and cached from
# an encoding model in the realtime module. Currently this is only valid for a single epoch's data.
# Each row is one spike; columns x000..x449 hold that spike's distribution values.

observ_obj

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,rec_ind,timestamp,elec_grp_id,position,x000,x001,x002,...,x444,x445,x446,x447,x448,x449,time
day,epoch,timestamp,time,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
4,1,73830048,2461.0,1,73830048,29,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,2461.0
4,1,73830066,2461.0,1,73830066,13,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,2461.0
4,1,73830144,2461.0,2,73830144,14,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,2461.0
4,1,73830192,2461.0,6,73830192,14,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,2461.0
4,1,73830204,2461.0,5,73830204,13,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,2461.0
4,1,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4,1,102149649,3405.0,237333,102149649,11,7.6,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,3405.0
4,1,102149697,3405.0,55281,102149697,12,7.6,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,3405.0
4,1,102149817,3405.0,96729,102149817,17,7.6,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,3405.0
4,1,102149925,3405.0,237337,102149925,11,7.6,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,3405.0


In [12]:
# Assign each spike to a bin of 300 timestamp samples (10 ms at 30 kHz) based on its timestamp.
# The displayed result carries the bin number in dec_bin, the bin's first timestamp in
# dec_bin_start, and a num_missing_bins column.

observ_obj.update_observations_bins(300)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,rec_ind,timestamp,elec_grp_id,position,x000,x001,x002,...,x447,x448,x449,time,dec_bin,dec_bin_start,num_missing_bins
day,epoch,timestamp,time,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
4,1,73830048,2461.0,1,73830048,29,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,2461.0,0,73830000,0
4,1,73830066,2461.0,1,73830066,13,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,2461.0,0,73830000,0
4,1,73830144,2461.0,2,73830144,14,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,2461.0,0,73830000,0
4,1,73830192,2461.0,6,73830192,14,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,2461.0,0,73830000,0
4,1,73830204,2461.0,5,73830204,13,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,2461.0,0,73830000,0
4,1,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4,1,102149649,3405.0,237333,102149649,11,7.6,0.0,0.0,0.0,...,0.0,0.0,0.0,3405.0,94398,102149400,0
4,1,102149697,3405.0,55281,102149697,12,7.6,0.0,0.0,0.0,...,0.0,0.0,0.0,3405.0,94398,102149400,0
4,1,102149817,3405.0,96729,102149817,17,7.6,0.0,0.0,0.0,...,0.0,0.0,0.0,3405.0,94399,102149700,0
4,1,102149925,3405.0,237337,102149925,11,7.6,0.0,0.0,0.0,...,0.0,0.0,0.0,3405.0,94399,102149700,0


In [14]:
# For each time bin, compute the product of the distributions stored in columns x000:x449.
# This estimates the probability distribution of position at each time bin. Refer to
# spykshrk.franklab.pp_decoder.pp_clusterless.OfflinePPDecoder.calc_observation_intensity
# for analysis code that uses groupby.

# Re-bin the spikes into bins of 3000 timestamp samples (100 ms at 30 kHz),
# then group them by their bin number.
spike_decode = observ_obj.update_observations_bins(3000)
groups = spike_decode.groupby('dec_bin')

def prod_dist(df, first_col='x000', last_col='x449', sampling_rate=30000):
    """Multiply the per-spike distributions of one time bin into a single distribution.

    Args:
        df: rows of one ``dec_bin`` group. Must be non-empty, contain the
            distribution columns ``first_col`` through ``last_col`` (label
            slice, inclusive) and a ``dec_bin_start`` column, and have a row
            index whose first two levels are day and epoch.
        first_col: label of the first distribution column.
        last_col: label of the last distribution column.
        sampling_rate: timestamp samples per second, used to convert the bin
            start timestamp into seconds.

    Returns:
        pandas.Series holding the normalized product under the distribution
        column labels, followed by ``day``, ``epoch``, ``timestamp`` and
        ``time`` entries describing the bin.
    """
    dist = df.loc[:, first_col:last_col]
    # Take the labels and bin count from the data instead of repeating a
    # hard-coded 450 in several places.
    bin_labels = list(dist.columns)
    norm_prod = np.ones(len(bin_labels))
    for row in dist.values:
        norm_prod = norm_prod * row
        # Renormalize after every factor so the running product does not
        # underflow. If the product sums to zero the division yields NaN.
        norm_prod = norm_prod / norm_prod.sum()
    prod_ser = pd.Series(norm_prod, index=bin_labels)
    # Day / epoch come from the last row of the group, read explicitly from the
    # index rather than from a loop variable that outlives its loop.
    last_index = df.index[-1]
    prod_ser['day'] = last_index[0]
    prod_ser['epoch'] = last_index[1]
    bin_start = df['dec_bin_start'].iloc[0]
    prod_ser['timestamp'] = bin_start
    prod_ser['time'] = bin_start / sampling_rate

    return prod_ser

# Reduce every dec_bin group to one normalized distribution row, then promote
# the bin's day / epoch / timestamp / time entries to the row index.
observ_binned = groups.apply(prod_dist).set_index(['day', 'epoch', 'timestamp', 'time'])

# Wrap the binned table in a Posteriors container, carrying over the settings
# and the history of the binned spike table it was derived from.
post = Posteriors.from_dataframe(
    observ_binned,
    enc_settings=encode_settings,
    dec_settings=decode_settings,
    history=spike_decode.history,
)

In [15]:
# Map the per-well distance / velocity columns onto a single axis; the result
# has linpos_flat, linvel_flat and seg_idx columns
lin_obj.get_mapped_single_axis()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,linpos_flat,linvel_flat,seg_idx
day,epoch,timestamp,time,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
4,1,73830339,2461.0,27.8,7.5,1.0
4,1,73831341,2461.0,26.9,7.5,1.0
4,1,73832343,2461.1,25.5,6.7,1.0
4,1,73833342,2461.1,24.6,5.8,1.0
4,1,73834344,2461.1,23.3,4.9,1.0
4,1,...,...,...,...,...
4,1,102145374,3404.8,7.0,-4.0,1.0
4,1,102146376,3404.9,7.1,-4.2,1.0
4,1,102147378,3404.9,7.5,-4.4,1.0
4,1,102148377,3404.9,7.5,-4.5,1.0


In [16]:
# Convert linearized position segments onto a single axis to match the decoded position mapping.
# This function uses the query method of pandas DataFrames.
# e.g.:
# right_pos_flat = (self.pos_data.query('@self.pos_data.seg_idx.seg_idx == 4 | '
#                                       '@self.pos_data.seg_idx.seg_idx == 5').
#                   loc[:, ('lin_dist_well', 'well_right')]) + self.arm_coord[2][0]

# Resample to bins of 3000 timestamp samples first, then flatten onto one axis
single_axis_lin_pos = lin_obj.get_resampled(3000).get_mapped_single_axis()

In [17]:
# Display the copy history of this LinearPosition object, starting from the first instantiation.
# `history` is a list; each entry appears to be the table as it stood at an earlier step.

single_axis_lin_pos.history

[                            lin_dist_well                          lin_vel                      seg_idx
                               well_center well_left well_right well_center well_left well_right seg_idx
 day epoch timestamp time                                                                               
 4   1     73830339  2,461.0          27.8     142.1      144.5         7.5     134.4      136.4     1.0
           73831341  2,461.0          26.9     143.0      145.4         7.5     134.4      136.4     1.0
           73832343  2,461.1          25.5     144.3      146.8         6.7     134.5      136.6     1.0
           73833342  2,461.1          24.6     145.2      147.7         5.8     134.1      136.1     1.0
           73834344  2,461.1          23.3     146.6      149.0         4.9     133.2      135.1     1.0
 ...                                   ...       ...        ...         ...       ...        ...     ...
           102145374 3,404.8           7.0     162.8   

In [18]:
# Stim lockout ranges obtained from the digital output state: one row per lockout_num,
# with the on / off edges given both as timestamps and as times

stim_lockout

Unnamed: 0_level_0,timestamp,timestamp,time,time
Unnamed: 0_level_1,on,off,on,off
lockout_num,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
0,73986480.0,73993983.0,2466.2,2466.5
1,74216224.0,74223739.0,2473.9,2474.1
2,74348644.0,74356160.0,2478.3,2478.5
3,74524703.0,74532221.0,2484.2,2484.4
4,74602943.0,74610459.0,2486.8,2487.0
...,...,...,...,...
186,101754259.0,101761761.0,3391.8,3392.1
187,101831984.0,101839499.0,3394.4,3394.6
188,101869202.0,101876720.0,3395.6,3395.9
189,101899082.0,101906584.0,3396.6,3396.9


In [19]:
%%output backend='bokeh' size=400 holomap='scrubber'
%%opts RGB { +framewise} [height=100 width=250 colorbar=True]
%%opts Points {+framewise} [height=100 width=250] (marker='o' size=4 alpha=0.5)

dec_viz = DecodeVisualizer(post, linpos=lin_obj, enc_settings=encode_settings)

dec_viz.plot_all_dynamic(stream=hv.streams.RangeXY(), plt_range=10, slide=10)


Unexpected plot option 'colorbar' for RGB in loaded backend 'bokeh'.

Similar keywords in the currently active 'bokeh' renderer are: ['toolbar', 'bgcolor']

If you believe this keyword is correct, please make sure the backend has been imported or loaded with the hv.extension.

In [20]:
# Echo the visualizer object itself (plain repr only; the plots come from plot_all_dynamic)
dec_viz

<spykshrk.franklab.pp_decoder.visualization.DecodeVisualizer at 0x7f9ad3a1a198>