In [1]:
# Make the spykshrk source tree importable without installing it through setuptools:
# put the directory `root_depth` levels above the notebook's directory on sys.path.
import sys, os
root_depth = 2
# IPython keeps its directory history in `_dh`; entry 0 is the directory the kernel
# started in. Fall back to the current working directory when this code runs outside
# IPython (where `_dh` does not exist and the old globals()['_dh'] lookup raised KeyError).
notebook_dir = globals().get('_dh', [os.getcwd()])[0]
root_path = os.path.abspath(os.path.join(notebook_dir, *([os.pardir] * root_depth)))
# Drop every existing copy first so re-running this cell never accumulates duplicates,
# then append exactly one entry.
sys.path[:] = [entry for entry in sys.path if entry != root_path]
sys.path.append(root_path)
# Alternatively set root path as current working directory
#os.chdir(root_path)

In [2]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib
import matplotlib.pyplot as plt
import json
import os
import scipy.signal
import holoviews as hv

import warnings

from spykshrk.realtime.decoder_process import PointProcessDecoder

from spykshrk.realtime.simulator import nspike_data

from spykshrk.franklab.pp_decoder.util import gaussian, normal2D, apply_no_anim_boundary, simplify_pos_pandas
from spykshrk.franklab.pp_decoder.pp_clusterless import OfflinePPDecoder
from spykshrk.franklab.data_containers import EncodeSettings, DecodeSettings, SpikeObservation, \
                                              LinearPosition, StimLockout, Posteriors, FlatLinearPosition
from spykshrk.franklab.pp_decoder.visualization import DecodeVisualizer, DecodeErrorVisualizer

from spykshrk.franklab.pp_decoder.decode_error import LinearDecodeError
    
from spykshrk.franklab.franklab_data import FrankAnimalInfo, FrankFilenameParser, FrankDataInfo

    
# Notebook-wide pandas display settings; set_option accepts several
# (pattern, value) pairs in a single call, applied in order.
pd.set_option('display.precision', 4,
              'display.max_rows', 10,
              'display.max_columns', 15)

# Shorthand for MultiIndex column slicing, e.g. dec_error.loc[:, idx[:, 'abs_error']]
idx = pd.IndexSlice
matplotlib.rcParams['font.size'] = 28

# Load both HoloViews plotting backends. The later call leaves bokeh active;
# individual cells switch to matplotlib with %%output backend='matplotlib'.
hv.extension('matplotlib')
hv.extension('bokeh')

In [3]:
# Load merged rec HDF store based on config

config_file = '/opt/data36/daliu/realtime/spykshrk/dec_60uv_300samp/bond.config.json'
# config_file = '/opt/data36/daliu/realtime/spykshrk/dec_bond04_run1/bond.config.json'
# Parse the run configuration. The context manager closes the file immediately;
# json.load(open(...)) left the handle open until garbage collection.
with open(config_file, 'r') as config_fh:
    config = json.load(config_fh)

# Analyse the first day/epoch listed for the simulated animal.
day = config['simulator']['nspike_animal_info']['days'][0]
epoch = config['simulator']['nspike_animal_info']['epochs'][0]
time_bin_size = config['pp_decoder']['bin_size']

# Main hdf5 data source file name: <output_dir>/<prefix>.rec_merged.h5
hdf_file = os.path.join(config['files']['output_dir'],
                        '{}.rec_merged.h5'.format(config['files']['prefix']))

# Extract just encode and decode settings from config
encode_settings = EncodeSettings(config)
decode_settings = DecodeSettings(config)

# Read the realtime record tables needed below. store[key] returns an in-memory
# DataFrame, so the read-only HDF5 file is closed as soon as the containers are built.
with pd.HDFStore(hdf_file, mode='r') as store:
    # Encapsulate Spike Observation panda table in container
    observ_obj = SpikeObservation.from_realtime(store['rec_3'], day=day, epoch=epoch,
                                                enc_settings=encode_settings)

    # Posteriors produced by the realtime decoder
    realtime_posteriors = Posteriors.from_realtime(store['rec_4'], day=day, epoch=epoch,
                                                   enc_settings=encode_settings)

    # Grab stimulation lockout times
    stim_lockout = StimLockout.from_realtime(store['rec_11'], enc_settings=encode_settings)


# Grab animal linearized real position
nspike_anim = nspike_data.AnimalInfo(**config['simulator']['nspike_animal_info'])
pos = nspike_data.PosMatDataStream(nspike_anim)
pos_data = pos.data

# Encapsulate linear position
lin_obj = LinearPosition.from_nspike_posmat(pos_data, encode_settings)

# Ripple data (RipplesConsData) for the same animal; used for ripple-locked plots below.
ripcons = nspike_data.RipplesConsData(nspike_anim)
ripdata = ripcons.data_obj

In [4]:
%%output backend='bokeh' size=400 holomap='scrubber'
%%opts RGB { +framewise} [height=100 width=250 colorbar=True]
%%opts Points {+framewise} [height=100 width=250] (marker='o' size=4 alpha=0.5)


## Plot posteriors

plt_ranges = [[2461 + 250, 2461 + 400]]

realtime_dec_viz = DecodeVisualizer(realtime_posteriors, linpos=lin_obj, 
                           enc_settings=encode_settings)

plt2 = realtime_dec_viz.plot_all_dynamic(stream=hv.streams.RangeXY(), plt_range=10, slide=10)

plt2

In [5]:
# Decoded position per time bin = column label of the largest posterior value.
# Labels are a one-character prefix followed by the bin number, so the prefix is
# stripped and the rest converted to int.
dec_est_pos = (realtime_posteriors.get_distribution_view()
               .idxmax(axis=1)
               .map(lambda label: int(label[1:]))
               .to_frame(name='est_pos'))

# Actual linear position resampled onto the decoder's time bins.
resamp_lin_obj = lin_obj.get_resampled(time_bin_size).get_pd_no_multiindex()

# Error table comparing true and decoded positions; `dec_error` is only ever bound
# to the resulting table, never to the calculator object.
dec_error = LinearDecodeError().calc_error_table(resamp_lin_obj, dec_est_pos,
                                                 encode_settings.arm_coordinates, 2)

abs_errors = dec_error.loc[:, idx[:, 'abs_error']]
print("Median:")
print(abs_errors.median())
print("Mean:")
print(abs_errors.mean())

In [6]:
# Show the decode error table built in the previous cell (pandas truncates the
# display according to the display.max_rows option set in the configuration cell).
dec_error

In [7]:
%%opts ErrorBars {+framewise} [height=500 width=1000] (line_color=Cycle(values=['#FF0099', '#99FF00', '#5555FF']) line_width=1 upper_head=TeeHead(size=0) lower_head=TeeHead(size=0))
%%opts Points {+framewise} [height=500 width=1000] (color=Cycle(values=['#FF0099', '#99FF00', '#5555FF']))
%%output holomap='scrubber'

#warnings.filterwarnings(action='')

dec_viz = DecodeErrorVisualizer(dec_error)

dmap = dec_viz.plot_arms_error(2774, 10)

dmap

In [6]:
%%output backend='matplotlib' size=200
%%opts Image {+axiswise} [height=300 width=300 aspect=1]
%%opts RGB {+axiswise}
%%opts Curve {+axiswise} [aspect=1] (line_dash='dashed' color='#AAAAAA' linestyle='--' alpha=0.5)
%%opts Points {+axiswise} [aspect=1] (marker='*' size=14)


dec_viz = DecodeVisualizer(realtime_posteriors, linpos=lin_obj, riptimes=ripdata.get_above_maxthresh(5), enc_settings=encode_settings)

rip_plots = dec_viz.plot_ripple_grid(2)
for plt_grp in rip_plots:
    display(plt_grp)

In [6]:
# Disabled alternative: read the offline posterior straight from the pyBond analysis
# HDF file. The FrankDataInfo cell below is what actually assigns `offline_posterior`.
#offline_posterior = Posteriors._from_hdf_store('/opt/data36/daliu/pyBond/analysis/bond_decode.h5','/analysis',
#                                               'decode/clusterless/offline/posterior', 'learned_trans_mat')

In [7]:
# Automatic post-mortem debugging is left off so "Run All" / headless execution cannot
# block on an interactive pdb prompt when a cell raises. Uncomment while debugging.
# %pdb on

In [7]:
# Offline decode results, compared side by side with the realtime posteriors further down.
anim = FrankAnimalInfo('/opt/data36/daliu/', 'pyBond')
decode_info = FrankDataInfo(anim, 'decode')
# List the available 'decode' datasets so the index chosen below can be checked.
display(decode_info.entries)
# Loads dataset index 0 — presumably the first entry listed above; NOTE(review): confirm
# this is the intended run.
offline_posterior = decode_info.load_single_dataset_ind(0)

In [8]:
%%opts Image {+axiswise} [height=300 width=300 aspect=1]
%%opts Curve.arm_bound {+axiswise} [aspect=1] (line_dash='dashed' color='#AAAAAA' linestyle='--' alpha=0.5)
%%opts Points {+axiswise} [aspect=1] (marker='*' size=18)
%%opts NdLayout {+axiswise}
%%output backend='matplotlib' size=200

off_dec_viz = DecodeVisualizer(offline_posterior, linpos=lin_obj, riptimes=ripdata.get_above_maxthresh(5), enc_settings=encode_settings)

online_rip_plots = dec_viz.plot_ripple_grid(1,1)
offline_rip_plots = off_dec_viz.plot_ripple_grid(1,1)

for ii, subplot in enumerate(online_rip_plots):
    display((subplot + offline_rip_plots[ii]).cols(2))
    


In [11]:
%%opts Image {+axiswise} [height=300 width=300 aspect=1]
%%opts Curve.arm_bound {+axiswise} [aspect=1] (line_dash='dashed' color='#AAAAAA' linestyle='--' alpha=0.5)
%%opts Points {+axiswise} [aspect=1] (marker='*' size=14)
%%opts NdLayout {+axiswise}
%%output backend='matplotlib' size=200

# Ripple plots from the realtime visualizer. Relies on `dec_viz` still being the
# DecodeVisualizer from the ripple-grid cell.
# NOTE(review): the meaning of the argument 8 is not visible here — confirm against
# DecodeVisualizer.plot_ripple_all.
dec_viz.plot_ripple_all(8)

In [12]:
# Compare where ripple group 201 ends in the offline vs. realtime posteriors by printing
# the last timestamp (by position in the index) of each: offline first, then realtime.
rip1 = offline_posterior.query('ripple_grp==201').get_distribution_view()
rip2 = realtime_posteriors.query('ripple_grp==201').get_distribution_view()

for rip_dist in (rip1, rip2):
    last_ts = rip_dist.index.get_level_values('timestamp')[-1]
    print(last_ts)