In [1]:
cd /home/daliu/Src/spykshrk_realtime/

In [54]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import patches
import json
import os
import scipy.signal
import dask
import dask.dataframe as dd

import holoviews as hv
from holoviews.operation.datashader import aggregate, shade, datashade, dynspread, regrid
from holoviews.operation import decimate

from spykshrk.realtime.simulator import nspike_data

from spykshrk.franklab.pp_decoder.util import gaussian, normal2D, apply_no_anim_boundary, simplify_pos_pandas
from spykshrk.franklab.pp_decoder.pp_clusterless import OfflinePPDecoder
from spykshrk.franklab.pp_decoder.data_containers import EncodeSettings, DecodeSettings, SpikeObservation, \
                                                         LinearPosition, StimLockout, Posteriors, \
                                                         FlatLinearPosition, pos_col_format

from spykshrk.franklab.pp_decoder.visualization import DecodeVisualizer, DecodeErrorVisualizer
from spykshrk.franklab.pp_decoder.decode_error import LinearDecodeError

        
%load_ext Cython

%matplotlib inline

#pd.set_option('float_format', '{:,.2f}'.format)
pd.set_option('display.precision', 4)
pd.set_option('display.max_rows', 10)
pd.set_option('display.max_columns', 15)
#pd.set_option('display.width', 80)

idx = pd.IndexSlice

#matplotlib.rcParams.update({'font.size': 18})

hv.extension('matplotlib')
hv.extension('bokeh')
hv.Store.renderers['bokeh'].webgl = False

In [3]:
from holoviews import Store
# TeeHead must be in the notebook namespace: the ErrorBars %%opts line further down
# constructs TeeHead(size=0) for upper_head/lower_head.
from bokeh.models.arrow_heads import TeeHead

# Register extra ErrorBars style options per backend so %%opts can set them.
for _style_names, _backend in ((['upper_head', 'lower_head'], 'bokeh'),
                               (['ecolor'], 'matplotlib')):
    Store.add_style_opts(hv.ErrorBars, _style_names, backend=_backend)


In [4]:
from dask.distributed import Client, LocalCluster

# Size of the local dask worker pool.
N_WORKERS = 20
THREADS_PER_WORKER = 2

# Tear down any client/cluster left over from a previous run of this cell so re-running
# it does not leave old worker processes behind. Close the client *before* the cluster it
# is connected to. A name that simply doesn't exist yet means "nothing to close"; any
# other failure is reported rather than hidden by a bare `except:`, and one failure no
# longer prevents the other object from being closed.
_previous = [(name, globals()[name]) for name in ('client', 'cluster') if name in globals()]
if not _previous:
    print("No cluster or client")
for _prev_name, _prev_obj in _previous:
    try:
        _prev_obj.close()
    except Exception as err:
        print("Could not close previous {}: {!r}".format(_prev_name, err))

cluster = LocalCluster(n_workers=N_WORKERS, threads_per_worker=THREADS_PER_WORKER)
client = Client(cluster)

In [5]:
%%time
# Load merged rec HDF store based on config

#config_file = '/opt/data36/daliu/realtime/spykshrk/ripple_dec/bond.config.json'
#config_file = '/opt/data36/daliu/realtime/spykshrk/dec_60uv/bond.config.json'
config_file = '/home/daliu/Src/spykshrk_realtime/config/bond_single.json'
config = json.load(open(config_file, 'r'))

day = config['simulator']['nspike_animal_info']['days'][0]
epoch = config['simulator']['nspike_animal_info']['epochs'][0]
time_bin_size = config['pp_decoder']['bin_size']

# Main hdf5 data source file name
hdf_file = os.path.join(config['files']['output_dir'],
                        '{}.rec_merged.h5'.format(config['files']['prefix']))

# Extract just encode and decode settings from config
encode_settings = EncodeSettings(config)
decode_settings = DecodeSettings(config)

# Open data file
store = pd.HDFStore(hdf_file, mode='r')

# Encapsulate Spike Observation panda table in container
observ_obj = SpikeObservation.from_realtime(store['rec_3'], day=day, epoch=epoch, enc_settings=encode_settings)

# Grab stimulation lockout times
stim_lockout = StimLockout.from_realtime(store['rec_11'], enc_settings=encode_settings)

# Grab animal linearized real position
nspike_anim = nspike_data.AnimalInfo(**config['simulator']['nspike_animal_info'])
pos = nspike_data.PosMatDataStream(nspike_anim)
pos_data = pos.data

# Encapsulate linear position
lin_obj = LinearPosition.from_nspike_posmat(pos_data, encode_settings)

In [6]:
# Show the SpikeObservation container built from store['rec_3'] as this cell's output.
observ_obj

In [7]:
#%pdb

In [8]:
%%prun -r -s cumulative

# Run PP decoding algorithm
time_bin_size = 30

decoder = OfflinePPDecoder(lin_obj=lin_obj, observ_obj=observ_obj,
                           encode_settings=encode_settings, decode_settings=decode_settings, 
                           which_trans_mat='learned', time_bin_size=time_bin_size)

posteriors = decoder.run_decoder()

In [52]:
%%output backend='bokeh' size=400 holomap='scrubber'
%%opts RGB { +framewise} [height=100 width=250 colorbar=True]
%%opts Image {+framewise} [height=100 width=250 aspect=3 colorbar=True] (cmap='hot')
#%%opts Points {+framewise} [height=100 width=250] (marker='o' size=5)

import functools



dec_viz = DecodeVisualizer(posteriors, linpos=lin_obj, 
                      
     enc_settings=encode_settings)

dec_viz.plot_all_dynamic(stream=hv.streams.RangeXY(), plt_range=10, slide=10)



In [55]:
# Decoded posterior for the first 1000 rows of the distribution view, transposed so the
# distribution columns become image rows. Kept under its own name: previously it was
# assigned to `img` and immediately overwritten, so the computation was wasted.
# Swap it in for `img` below to run the regrid experiment on real data.
posterior_img = posteriors.get_distribution_view().iloc[0:1000].values.T

# Synthetic sin(r^2) test pattern on a 100x100 grid centred at 0, used by the
# regrid/shade experiment in this cell.
Y, X = (np.mgrid[0:100, 0:100] - 50.) / 20.
img = hv.Image(np.sin(X**2 + Y**2))

def func(x_range, y_range):
    """Regrid the global `img` to the requested view range and shade it.

    Used as the DynamicMap callback below, driven by a RangeXY stream.

    Args:
        x_range: x extent of the current view, supplied by the stream.
        y_range: y extent of the current view, supplied by the stream.

    Returns:
        The shaded element, with its extents pinned to +/-0.3 on both axes.
    """
    pinned_extents = (-0.3, -0.3, 0.3, 0.3)
    source = hv.Image(img, extents=pinned_extents)
    regridded = regrid(source, dynamic=False, x_range=x_range, y_range=y_range)
    shaded = shade(regridded, dynamic=False)
    # Re-apply the extents on the shaded output as well.
    shaded.extents = pinned_extents
    return shaded

hv_img2 = hv.Image(img, bounds=(-0.5, -0.5, 0.5, 0.5), extents=(-0.2, -0.2, 0.2, 0.2))
# `redim` returns a new element instead of modifying hv_img2 in place, so the result must
# be reassigned. Previously it was discarded and the x range was never applied.
hv_img2 = hv_img2.redim(x={'range': (-0.2, 0.2)})
re2 = regrid(hv_img2)

# Side by side: range-driven dynamic regrid+shade vs. the static regrid of hv_img2.
hv.DynamicMap(func, streams=[hv.streams.RangeXY()]) + re2

In [None]:
# Point estimate of position: for each row, take the posterior column with the largest
# value and parse the integer after the first character of its label.
# NOTE(review): assumes labels are a one-char prefix + index (cf. pos_col_format).
# Vectorized .str slicing replaces the previous row-by-row Python lambda.
dec_est_pos = (posteriors.get_distribution_view()
               .idxmax(axis=1)
               .str[1:]
               .astype(int)
               .to_frame('est_pos'))

dec_est_pos = FlatLinearPosition.create_default(dec_est_pos, encode_settings.sampling_rate,
                                                encode_settings.arm_coordinates, parent=posteriors)

# Uses whatever time_bin_size is currently defined: the config value from the loading
# cell, or the override from the decoding cell if that has run. It must match the bin
# size the posteriors were decoded with.
resamp_lin_obj = lin_obj.get_resampled(time_bin_size).get_pd_no_multiindex()

error_obj = LinearDecodeError()

# NOTE(review): meaning of the trailing `2` argument isn't visible here — confirm in
# LinearDecodeError.calc_error_table and give it a named constant.
error_table = error_obj.calc_error_table(resamp_lin_obj, dec_est_pos,
                                         encode_settings.arm_coordinates, 2)

abs_error = error_table.loc[:, idx[:, 'abs_error']]
print(abs_error.median())
print(abs_error.mean())
print(error_table.loc[:, idx[:, 'abs_error']].mean())

In [None]:
%%output backend='bokeh' size=400 holomap='scrubber'
%%opts Points {+framewise} [height=100 width=250 aspect=2] (color=Cycle(values=['#FF0099', '#99FF00', '#5555FF']))
%%opts ErrorBars {+framewise} [height=100 width=250 aspect=2 ] (ecolor=Cycle(values=['#FF0099', '#99FF00', '#5555FF']) alpha=0.5 line_width=1 upper_head=TeeHead(size=0) lower_head=TeeHead(size=0))

err_viz = DecodeErrorVisualizer(error_table)

dmap = err_viz.plot_arms_error_dmap(slide_interval=10, plot_interval=10)

dmap