In [1]:
cd /home/daliu/Src/spykshrk_realtime/

In [94]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import patches
import json
import os
import scipy.signal
import dask
import dask.dataframe as dd

import holoviews as hv
from holoviews.operation.datashader import aggregate, shade, datashade, dynspread
from holoviews.operation import decimate

from spykshrk.realtime.simulator import nspike_data

from spykshrk.franklab.pp_decoder.util import gaussian, normal2D, apply_no_anim_boundary, simplify_pos_pandas
from spykshrk.franklab.pp_decoder.pp_clusterless import OfflinePPDecoder
from spykshrk.franklab.pp_decoder.data_containers import EncodeSettings, DecodeSettings, SpikeObservation, \
                                                         LinearPosition, StimLockout, Posteriors, \
                                                         FlatLinearPosition, pos_col_format

from spykshrk.franklab.pp_decoder.visualization import DecodeVisualizer, DecodeErrorVisualizer
from spykshrk.franklab.pp_decoder.decode_error import LinearDecodeError

        
%load_ext Cython

%matplotlib inline

#pd.set_option('float_format', '{:,.2f}'.format)
pd.set_option('display.precision', 4)
pd.set_option('display.max_rows', 10)
pd.set_option('display.max_columns', 15)
#pd.set_option('display.width', 80)

idx = pd.IndexSlice

#matplotlib.rcParams.update({'font.size': 18})

hv.extension('matplotlib')
hv.extension('bokeh')
hv.Store.renderers['bokeh'].webgl = False

In [3]:
from bokeh.models.arrow_heads import TeeHead
from holoviews import Store

# Register 'upper_head' / 'lower_head' as valid bokeh style options on ErrorBars.
# TeeHead is imported here because later %%opts magics reference it by name.
_errorbar_head_opts = ['upper_head', 'lower_head']
Store.add_style_opts(hv.ErrorBars, _errorbar_head_opts, backend='bokeh')

In [4]:
from dask.distributed import Client, LocalCluster

# Size of the local dask worker pool.
N_WORKERS = 20
THREADS_PER_WORKER = 2

# Shut down any client/cluster left over from a previous run of this cell so that
# re-running it does not leak worker processes. The client is closed before the
# cluster it is attached to. A name that is not defined yet is expected; any other
# failure while closing is reported rather than silently swallowed (the previous
# bare `except:` also hid KeyboardInterrupt).
for _name in ('client', 'cluster'):
    _stale = globals().get(_name)
    if _stale is None:
        print("No {} to close".format(_name))
        continue
    try:
        _stale.close()
    except Exception as err:
        print("Failed to close existing {}: {!r}".format(_name, err))

cluster = LocalCluster(n_workers=N_WORKERS, threads_per_worker=THREADS_PER_WORKER)
client = Client(cluster)

In [5]:
%%time
# Load merged rec HDF store based on config

#config_file = '/opt/data36/daliu/realtime/spykshrk/ripple_dec/bond.config.json'
#config_file = '/opt/data36/daliu/realtime/spykshrk/dec_60uv/bond.config.json'
config_file = '/home/daliu/Src/spykshrk_realtime/config/bond_single.json'
config = json.load(open(config_file, 'r'))

day = config['simulator']['nspike_animal_info']['days'][0]
epoch = config['simulator']['nspike_animal_info']['epochs'][0]
time_bin_size = config['pp_decoder']['bin_size']

# Main hdf5 data source file name
hdf_file = os.path.join(config['files']['output_dir'],
                        '{}.rec_merged.h5'.format(config['files']['prefix']))

# Extract just encode and decode settings from config
encode_settings = EncodeSettings(config)
decode_settings = DecodeSettings(config)

# Open data file
store = pd.HDFStore(hdf_file, mode='r')

# Encapsulate Spike Observation panda table in container
observ_obj = SpikeObservation.from_realtime(store['rec_3'], day=day, epoch=epoch, enc_settings=encode_settings)

# Grab stimulation lockout times
stim_lockout = StimLockout.from_realtime(store['rec_11'], enc_settings=encode_settings)

# Grab animal linearized real position
nspike_anim = nspike_data.AnimalInfo(**config['simulator']['nspike_animal_info'])
pos = nspike_data.PosMatDataStream(nspike_anim)
pos_data = pos.data

# Encapsulate linear position
lin_obj = LinearPosition.from_nspike_posmat(pos_data, encode_settings)

In [6]:
# Inspect the spike observation table (row count capped by display.max_rows).
display(observ_obj)

In [7]:
#%pdb

In [8]:
%%prun -r -s cumulative

# Run PP decoding algorithm
time_bin_size = 30

decoder = OfflinePPDecoder(lin_obj=lin_obj, observ_obj=observ_obj,
                           encode_settings=encode_settings, decode_settings=decode_settings, 
                           which_trans_mat='learned', time_bin_size=time_bin_size)

posteriors = decoder.run_decoder()

In [9]:
#%%prun -r -s cumulative
# Spot-check the observation likelihood for the spikes in parallel_bin == 0.
# The bin-size argument previously hard-coded 30; pass the decoder's time_bin_size
# so the two cannot drift apart.
# NOTE(review): assumes the 4th argument is the time bin size — confirm against
# OfflinePPDecoder; this also relies on a private method (_calc_observation_single_bin).
first_bin_spikes = observ_obj.groupby('parallel_bin').get_group(0)
like = decoder._calc_observation_single_bin(first_bin_spikes,
                                            observ_obj['elec_grp_id'].unique(),
                                            decoder.prob_no_spike, time_bin_size,
                                            encode_settings)



In [10]:
## Plot posteriors over the whole session (relative time index)
# Windows inspected previously: [2461, 2641], [2461, 3405], [2930, 3100],
# [3295, 3325], [0, 600]. None is passed straight through as plt_range
# (presumably meaning "no cropping").
plt_ranges = [None]
dec_viz = DecodeVisualizer(posteriors.get_relative_index(),
                           linpos=lin_obj.get_relative_index(),
                           enc_settings=encode_settings)
for plt_range in plt_ranges:
    # Very wide canvas so individual time bins stay distinguishable.
    fig, ax = plt.subplots(figsize=[200, 10])
    dec_viz.plot_decode_image(plt_range=plt_range, x_tick=10)
    dec_viz.plot_linear_pos(plt_range=plt_range)
    # Stim-lockout overlay disabled here; to enable:
    #   DecodeVisualizer.plot_stim_lockout(ax, stim_lockout, plt_range,
    #                                      encode_settings.arm_coordinates[2][1] + 10)
plt.show()

In [11]:
## Plot zoomed posteriors with stim lockout (absolute time index)
# Other windows inspected: [3260, 3280], [3175, 3200], [3102, 3106].
plt_ranges = [[2943, 2950]]

dec_viz = DecodeVisualizer(posteriors, linpos=lin_obj, enc_settings=encode_settings)

# Last argument to plot_stim_lockout: arm_coordinates[2][1] offset by 10
# (presumably the vertical placement of the lockout markers).
lockout_y = encode_settings.arm_coordinates[2][1] + 10

for plt_range in plt_ranges:
    fig, ax = plt.subplots(figsize=[50, 10])
    dec_viz.plot_decode_image(plt_range, x_tick=1)
    dec_viz.plot_linear_pos(plt_range)
    dec_viz.plot_stim_lockout(ax, stim_lockout, plt_range, lockout_y)
    plt.xlim(plt_range)

plt.show()

In [12]:
# Map the linear position onto a single axis; the interactive posterior plot
# reads its 'linpos_flat' column.
lin_flat = lin_obj.get_mapped_single_axis()
display(lin_flat)

In [218]:
%%opts Image {+framewise} [height=100 width=250 aspect=3 colorbar=True] (cmap='hot')
%%opts Points {+framewise} (marker='o' size=2 alpha=0.5)
%%opts RGB [height=100 width=250]
%%output backend='matplotlib' size=400 holomap='scrubber'

def plot_post(time, plt_range=10, lookahead=5, lookbehind=5):
    
    behind_time = max(time-lookbehind, posteriors.index.get_level_values('time')[0])
    img_sel = posteriors.get_distribution_view().query('time > {} and time < {}'.
                                                       format(behind_time, time+plt_range+lookahead)).values.T
    img_sel = np.flip(img_sel, axis=0)
    linpos_sel = lin_flat.query('time > {} and time < {}'.
                                format(behind_time, time+plt_range+lookahead))['linpos_flat'].values
    linpos_sel_time = (lin_flat.query('time > {} and time < {}'.
                                     format(behind_time, time+plt_range+lookahead))['linpos_flat'].
                       index.get_level_values('time'))

    img = hv.Image(img_sel, bounds=(behind_time, 0, time+plt_range+lookahead, 450),
                   kdims=['sec', 'linpos'], vdims=['probability'], extents=(time, None, time + plt_range, None))
    img = img.redim(probability={'range':(0,0.5)})

    pos = hv.Points((linpos_sel_time, linpos_sel), kdims=['sec', 'linpos'], extents=(time, None, time + plt_range, None))
    #img.opts(plot={'apply_range': (time, None, time + plt_range, None)})
    #pos.opts(plot={'apply_range': (time, None, time + plt_range, None)})

    over = img * pos
    #over_plt = renderer.get_plot(over)
    return over

# Step the 'time' key from the first posterior time (inclusive) toward the last
# (exclusive) in increments of 10.
post_time = posteriors.index.get_level_values('time')
time_dim = hv.Dimension('time', values=np.arange(post_time[0], post_time[-1], 10))
dmap = hv.DynamicMap(plot_post, kdims=[time_dim])

dmap

In [216]:
# Decoded position = label of the highest-probability column in each time bin.
# Labels are assumed to be a one-character prefix followed by the integer bin
# (the prefix is stripped with x[1:]) — confirm against pos_col_format.
dec_est_pos = posteriors.get_distribution_view().idxmax(axis=1).apply(lambda x: int(x[1:])).to_frame()
dec_est_pos.columns = ['est_pos']

dec_est_pos = FlatLinearPosition.create_default(dec_est_pos, encode_settings.sampling_rate,
                                                encode_settings.arm_coordinates, parent=posteriors)

# Actual position resampled to the decoder's time bin size for comparison.
resamp_lin_obj = lin_obj.get_resampled(time_bin_size).get_pd_no_multiindex()

error_obj = LinearDecodeError()

# NOTE(review): meaning of the final argument (2) is not visible here — document it.
error_table = error_obj.calc_error_table(resamp_lin_obj, dec_est_pos,
                                         encode_settings.arm_coordinates, 2)

# Median and mean of the 'abs_error' sub-columns across all first-level column groups.
abs_error = error_table.loc[:, idx[:, 'abs_error']]
print(abs_error.median())
print(abs_error.mean())

In [217]:
%%opts Points {+framewise} [height=500 width=1000] (color=Cycle(values=['#FF0099', '#99FF00', '#5555FF']))
%%opts ErrorBars {+framewise} [height=500 width=1000] (line_color=Cycle(values=['#FF0099', '#99FF00', '#5555FF']) alpha=0.5 line_width=1 upper_head=TeeHead(size=0) lower_head=TeeHead(size=0))
%%output holomap='scrubber'

err_viz = DecodeErrorVisualizer(error_table)

dmap = err_viz.plot_arms_error_dmap(slide_interval=10, plot_interval=20)

dmap

In [None]:
%%opts Points {+framewise} [height=500 width=1000] (color=Cycle(values=['#FF0099', '#99FF00', '#5555FF']))
%%opts ErrorBars {+framewise} (line_color=Cycle(values=['#FF0099', '#99FF00', '#5555FF']) alpha=0.5 line_width=1 upper_head=TeeHead(size=0) lower_head=TeeHead(size=0))
%%output holomap='scrubber'

# NOTE(review): this only references the bound method `plot_arms_error`; it is
# never called, so the cell displays the method object rather than a plot.
# TODO: call it with the arguments it expects (its signature is not visible here).
err_viz.plot_arms_error

In [133]:
# Scratch check: adding two columns of a one-row HoloViews Table.
scratch_frame = pd.DataFrame([[1, 2, 3]], columns=['a', 'b', 'c'])
tab = hv.Table(scratch_frame)

tab['a'] + tab['b']