In [1]:
# Make the working source tree importable without installing spykshrk through
# setuptools: the source root sits `root_depth` directories above the notebook.
import os
import sys

root_depth = 2
notebook_dir = globals()['_dh'][0]
root_path = os.path.abspath(os.path.join(notebook_dir, *(['..'] * root_depth)))

# Drop every existing occurrence first so the root ends up on python's path
# exactly once, as the last entry.
while root_path in sys.path:
    sys.path.remove(root_path)
sys.path.append(root_path)
# Alternatively set root path as current working directory
#os.chdir(root_path)

In [2]:
import pandas as pd
import numpy as np
import scipy as sp
import math

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import patches
import json
import os
import scipy.signal
import functools
import holoviews as hv

from spykshrk.util import AttrDict
import spykshrk.franklab.filterframework_util as ff_util

from spykshrk.realtime.simulator import nspike_data

from spykshrk.franklab.pp_decoder.util import gaussian, normal2D, apply_no_anim_boundary, simplify_pos_pandas, \
                                                normal_pdf_int_lookup
from spykshrk.franklab.pp_decoder.pp_clusterless import OfflinePPDecoder, OfflinePPEncoder
from spykshrk.franklab.data_containers import DataFrameClass, EncodeSettings, DecodeSettings, SpikeObservation, \
                                              LinearPosition, StimLockout, Posteriors, \
                                              FlatLinearPosition, SpikeWaves, SpikeFeatures, \
                                              pos_col_format, DayEpochTimeSeries

from spykshrk.franklab.pp_decoder.visualization import DecodeVisualizer
from spykshrk.franklab.pp_decoder.decode_error import LinearDecodeError

from spykshrk.franklab.franklab_data import FrankAnimalInfo, FrankFilenameParser, FrankDataInfo

import dask
import dask.dataframe as dd
import dask.array as da
import multiprocessing

import cloudpickle
        
%load_ext Cython

%matplotlib inline

hv.extension('matplotlib')
hv.extension('bokeh')
#pd.set_option('float_format', '{:,.2f}'.format)
pd.set_option('display.precision', 4)
pd.set_option('display.max_rows', 10)
pd.set_option('display.max_columns', 15)
#pd.set_option('display.width', 80)

idx = pd.IndexSlice

matplotlib.rcParams.update({'font.size': 14})

hv.renderer('bokeh').theme = "dark_minimal"
import matplotlib.pyplot as plt
plt.style.use('dark_background')
plt.rcParams['figure.facecolor'] = 'black'


In [3]:
from holoviews import Store
# NOTE(review): TeeHead is not referenced in any cell visible in this notebook.
from bokeh.models.arrow_heads import TeeHead
# Register 'linestyle' as a style option for hv.Curve under the matplotlib backend;
# later cells pass linestyle='--' in their `%%opts Curve` lines.
Store.add_style_opts(hv.Curve, ['linestyle'], backend='matplotlib')


In [4]:
# Shut down any dask cluster/client left over from an earlier run of this cell.
try:
    cluster.close()
    client.close()
except NameError:
    # Neither name is defined in this notebook while the LocalCluster setup
    # below stays commented out.
    print("No cluster or client running")
except Exception as err:
    # Still best effort, but unlike a bare `except:` this does not swallow
    # KeyboardInterrupt/SystemExit, and it reports what actually went wrong.
    print("Could not close cluster/client cleanly: {!r}".format(err))

#from dask.distributed import Client, LocalCluster

#cluster = LocalCluster(n_workers=5, memory_pause_fraction=0.5)
#client = Client(cluster)

#min_worker_memory = np.inf
#for w in cluster.workers:
#    min_worker_memory = min(min_worker_memory, w.memory_limit)


# Hand dask a thread pool of size 1 instead of a distributed cluster.
dask.config.set(pool=multiprocessing.pool.ThreadPool(1))
# Only referenced from commented-out code in the visible notebook
# (the dask_worker_memory argument of the encoder).
min_worker_memory = 10e9


In [17]:
# Load merged rec HDF store based on config

#config_file = '/opt/data36/daliu/realtime/spykshrk/ripple_dec/bond.config.json'
#config_file = '/opt/data36/daliu/realtime/spykshrk/dec_60uv/bond.config.json'
config_file = '/home/daliu/Src/spykshrk_realtime/config/bond_single.json'
# Read through a context manager so the file handle is closed as soon as the JSON
# is parsed, instead of being left open for the life of the kernel.
with open(config_file, 'r') as config_fh:
    config = json.load(config_fh)

# Tetrode subset to use for small tests: 5, 11, 12, 14, 19
config['simulator']['nspike_animal_info']['tetrodes'] = [1, 2, 4, 5, 7, 10, 11, 12, 13, 14, 17, 18,
                                                         19, 20, 22, 23, 27, 29]
#config['simulator']['nspike_animal_info']['tetrodes'] = [5, 11, 12, 14, 19]

config['simulator']['nspike_animal_info']['base_dir'] = '/opt/databackup/daliu/other/mkarlsso'

# First configured day/epoch and the decoder bin size, as stored in the config
day = config['simulator']['nspike_animal_info']['days'][0]
epoch = config['simulator']['nspike_animal_info']['epochs'][0]
time_bin_size = config['pp_decoder']['bin_size']

# Change config
# sim_num only tags the dataset names ('decode_sim<N>') used when results are saved
sim_num = 3
config['encoder']['position_kernel']['std'] = 1
config['pp_decoder']['trans_mat_smoother_std'] = 2
config['pp_decoder']['trans_mat_uniform_gain'] = 0.01
config['encoder']['mark_kernel']['std'] = 10
config['encoder']['spk_amp'] = 60
config['encoder']['vel'] = 2

# Extract just encode and decode settings from config
# (built after the overrides above so they see the modified values)
encode_settings = EncodeSettings(config)
decode_settings = DecodeSettings(config)

In [6]:

# Grab animal linearized real position
nspike_anim = nspike_data.AnimalInfo(**config['simulator']['nspike_animal_info'])
pos = nspike_data.PosMatDataStream(nspike_anim)
pos_data = pos.data

# Spike data for the same animal, wrapped in a SpikeWaves container
spk = nspike_data.SpkDataStream(nspike_anim)
spk_data = SpikeWaves.from_df(spk.data, encode_settings)

# Encapsulate linear position
lin_obj = LinearPosition.from_nspike_posmat(pos_data, encode_settings)
# Single-axis version of the linear position; this is what the encoder and the
# visualizer cells below are given.
linflat_obj = lin_obj.get_mapped_single_axis()

# Ripple data for the same animal; passed as `riptimes` to DecodeVisualizer below
ripcons = nspike_data.RipplesConsData(nspike_anim)
ripdata = ripcons.data_obj

In [7]:
# Display the linear position container built in the previous cell.
lin_obj

In [8]:
# One value per row of spk_data: the maximum across its columns.
peak_per_row = spk_data.max(axis=1).to_frame()

# Spread the 'channel' index level into columns so every spike row carries one
# peak value per channel (column 0 is the unnamed column produced by to_frame()).
peak_by_channel = peak_per_row.pivot_table(
    index=['day', 'epoch', 'elec_grp_id', 'timestamp', 'time'],
    columns='channel',
    values=0,
)

# Wrap as spike features and keep the spikes above the configured amplitude threshold.
spk_amp = SpikeFeatures.create_default(df=peak_by_channel, sampling_rate=30000)
spk_amp_thresh = spk_amp.get_above_threshold(encode_settings.spk_amp)


In [9]:
# Flattened linear position resampled against the thresholded spikes
# (presumably one position row per spike — confirm in get_irregular_resampled).
linflat_spkindex = linflat_obj.get_irregular_resampled(spk_amp_thresh)
# Encoding keeps only rows whose absolute linear velocity reaches encode_settings.vel.
linflat_spkindex_encode_velthresh = linflat_spkindex.query('abs(linvel_flat) >= @encode_settings.vel')
# Decoding applies no velocity threshold: every row is kept.
linflat_spkindex_decode_velthresh = linflat_spkindex
   
# No re-indexing happens here; this is just another name for spk_amp_thresh.
spk_amp_thresh_index_match = spk_amp_thresh

In [10]:
# Display the velocity threshold used to select encoding spikes.
encode_settings.vel

In [11]:
# Pick the spikes whose index entries appear in the velocity-thresholded position
# frames. `Index.values` replaces `Index.get_values()`, which pandas deprecated in
# 0.25 and removed in 1.0; both give the same array of index labels.
spk_amp_thresh_encode = spk_amp_thresh_index_match.loc[linflat_spkindex_encode_velthresh.index.values]
spk_amp_thresh_encode.sort_index(inplace=True)

# Same selection for decoding, which uses the unthresholded position frame.
spk_amp_thresh_decode = spk_amp_thresh_index_match.loc[linflat_spkindex_decode_velthresh.index.values]
spk_amp_thresh_decode.sort_index(inplace=True)

In [15]:
%%time
#%%prun -r -s cumulative

# Build the offline encoder from the flattened position and the encode/decode spike
# sets, then run it; the result (observ_obj) is the input to the decoder cell.
encoder = OfflinePPEncoder(linflat=linflat_obj, enc_spk_amp=spk_amp_thresh_encode, dec_spk_amp=spk_amp_thresh_decode,
                           encode_settings=encode_settings, decode_settings=decode_settings, chunk_size=100000
                           #dask_worker_memory=min_worker_memory)
                           ) 
#task = encoder.setup_encoder_dask()
observ_obj = encoder.run_encoder()

In [13]:
%%time
# Run PP decoding algorithm
# NOTE(review): this rebinds time_bin_size to a literal 30, replacing the value read
# from config['pp_decoder']['bin_size'] in the config cell — confirm the two agree.
time_bin_size = 30

# The decoder consumes the encoder's observations together with its 'simple'
# transition matrix and its prob_no_spike.
decoder = OfflinePPDecoder(observ_obj=observ_obj, trans_mat=encoder.trans_mat['simple'], 
                           prob_no_spike=encoder.prob_no_spike,
                           encode_settings=encode_settings, decode_settings=decode_settings, 
                           time_bin_size=time_bin_size)

posteriors = decoder.run_decoder()

In [14]:
# Display the encoder output.
observ_obj

In [None]:
# Display the decoder's `likelihoods` attribute.
decoder.likelihoods

In [None]:
# Display the decoder output.
posteriors

In [None]:
# NaN-ignoring sum along axis 1 of the posterior distribution view.
# NOTE(review): the slice [40:40] is always empty, so this cell shows nothing;
# presumably a non-empty window (e.g. [40:50]) was intended — confirm.
np.nansum(posteriors.get_distribution_view(), axis=1)[40:40]

In [None]:
# check to make sure no bins were skipped:
# consecutive posterior timestamps must differ by exactly one decode bin. The step is
# compared with time_bin_size (the width given to the decoder) instead of a literal
# 30, so the check stays correct if the bin size is changed. The output lists the
# positions where the step differs; an empty result means nothing was skipped.
np.nonzero(np.diff(posteriors.index.get_level_values('timestamp')) != time_bin_size)

In [None]:
# Write the posteriors and the ripple data into the example analysis store, both
# under the dataset name of the current simulation number.
analysis_store = '/opt/data36/daliu/pyBond/analysis/bond_decode_example.h5'
sim_dataset = 'decode_sim' + str(sim_num)

posteriors._to_hdf_store(analysis_store, '/analysis',
                         'example01/bond/decode/clusterless/offline/day04/epoch01/',
                         sim_dataset, overwrite=True)
ripdata._to_hdf_store(analysis_store, '/analysis',
                      'example01/cons_ripple/day04/epoch01/',
                      sim_dataset, overwrite=True)

In [None]:
# Open the 'decode_example' data info for animal 'pyBond' under /opt/data36/daliu/
# and show the data paths and entries it reports.
frank_anim = FrankAnimalInfo('/opt/data36/daliu/', 'pyBond')
data_info = FrankDataInfo(frank_anim, 'decode_example')
display(frank_anim.data_paths)
display(data_info.entries)

In [None]:
def _flat_column_label(entry):
    """Turn one column key of lin_obj into its flattened label.

    'lin_dist_well' and 'lin_vel' keys become 'dist_<second>' / 'vel_<second>',
    'seg_idx' keys keep only their second element, any other key whose first two
    elements are both longer than one character is joined with '_', and everything
    else is returned unchanged.
    """
    head = entry[0]
    if head == 'lin_dist_well':
        return 'dist_' + entry[1]
    if head == 'lin_vel':
        return 'vel_' + entry[1]
    if head == 'seg_idx':
        return entry[1]
    if len(head) > 1 and len(entry[1]) > 1:
        return head + '_' + entry[1]
    return entry


# Work on a copy so lin_obj keeps its original columns.
lin_obj_flat_col = lin_obj.copy()
lin_obj_flat_col.columns = [_flat_column_label(entry) for entry in lin_obj_flat_col.columns]
lin_obj_flat_col

In [None]:
# Write the four processing-stage objects into the example processing store,
# each under its own dataset name, in the order listed.
processing_store = '/opt/data36/daliu/pyBond/processing/bond_processing_example.h5'
processing_datasets = (
    (spk_data, 'spk_wave'),
    (spk_amp, 'spk_amp'),
    (lin_obj_flat_col, 'lin_pos'),
    (linflat_obj, 'linflat_pos'),
)
for data_obj, dataset_name in processing_datasets:
    data_obj._to_hdf_store(processing_store, '/processing',
                           'example01/bond/day04/epoch01/', dataset_name, overwrite=True)


In [None]:
# Save the same processing-stage objects through the FrankDataInfo interface.
# NOTE(review): 'lin_pos' is saved from lin_obj here, while the direct HDF cell
# writes the flattened-column copy (lin_obj_flat_col) — confirm which is intended.
for dataset_name, data_obj in (('spk_amp', spk_amp),
                               ('lin_pos', lin_obj),
                               ('linflat_pos', linflat_obj)):
    data_info.save_single_data('/processing', 'example01/bond/day04/epoch01',
                               dataset_name, data_obj, overwrite=True)


In [None]:

#test1 = Posteriors._from_hdf_store('/opt/data36/daliu/pyBond/analysis/bond_decode.h5','/analysis',
#                                   'decode/clusterless/offline/posterior', 'simple_trans_mat')

In [15]:
# Show the 'simple' transition matrix with NaNs drawn as 0.
# nan_to_num's second positional parameter is `copy`, not the fill value: the former
# `nan_to_num(x, 0)` meant copy=False, which can overwrite the NaNs inside
# encoder.trans_mat['simple'] itself. The defaults (copy=True, nan=0.0) produce the
# same image without touching the encoder's matrix.
hv.Image(np.nan_to_num(encoder.trans_mat['simple']))

In [16]:
%%output backend='matplotlib' size=300
%%opts Points (s=200 marker='^' )
%%opts Curve [aspect=3]
%%opts Text (text_align='left')

sel_distrib = observ_obj.loc[:, pos_col_format(0,encode_settings.pos_num_bins):         
                             pos_col_format(encode_settings.pos_num_bins-1,
                                            encode_settings.pos_num_bins)]
    
sel_pos = observ_obj.loc[:, 'position']

max_prob = sel_distrib.max().max()/2

def plot_observ(big_bin, small_bin):
    bin_id = small_bin + 10000 * big_bin
    spks_in_bin = sel_distrib.loc[observ_obj['dec_bin'] == bin_id, :]
    pos_in_bin = sel_pos.loc[observ_obj['dec_bin'] == bin_id, :]
    
    num_spks = len(spks_in_bin)
    plot_list = []
    if num_spks == 0:
        plot_list.append(hv.Curve((0,[max_prob-0.01]), 
                                   extents=(0, 0, encode_settings.pos_bins[-1], max_prob)))
    for spk_observ, pos_observ in zip(spks_in_bin.values, pos_in_bin.values):
        plot_list.append(hv.Curve(spk_observ, 
                                  extents=(0, 0, encode_settings.pos_bins[-1], max_prob)))

        plot_list.append(hv.Points((pos_observ, [max_prob-0.01])))
    return hv.Overlay(plot_list) * hv.Text(50,max_prob-0.05, "num_spks: {num_spks}\n"
                                           "Timestamp: {timestamp}\nTime: {time}".
                                           format(num_spks=num_spks, timestamp=time_bin_size*bin_id,
                                                  time=time_bin_size*bin_id/30000))

#Ind = Stream.define('stuff', ind=0)

dmap = hv.DynamicMap(plot_observ, kdims=['big_bin', 'small_bin'], label="test")
#dmap = hv.DynamicMap(plot_observ, kdims=
#                     [hv.Dimension('bin_id', range=(0, observ_obj['dec_bin'].iloc[-1]), step=1)])
#dmap = hv.DynamicMap(plot_observ, kdims=
#                     [hv.Dimension('bin_id', values=observ_obj['dec_bin'].unique())])

#dmap.redim.values(bin_id=range(0, observ_obj['dec_bin'].iloc[-1]))
dmap.redim.range(small_bin=(0, 1000), big_bin=(0, observ_obj['dec_bin'].iloc[-1]/1000 + 1))
#dmap.redim.range(bin_id=(0, observ_obj['dec_bin'].iloc[-1]))
#dmap.redim.values(bin_id=[4,5])


In [17]:
# Total of posteriors.memory_usage() divided by 1e9 (gigabytes, given that pandas
# reports memory_usage in bytes).
posteriors.memory_usage().sum()/1e9

In [38]:
%%output backend='bokeh' size=300 holomap='scrubber'
%%opts RGB { +framewise} [height=100 width=350]
%%opts Points (marker='o' color='#AAAAFF' size=1 alpha=0.05)
%%opts Polygons (color='grey', alpha=0.3 fill_color='grey' fill_alpha=0.3)
#%%opts Image {+framewise}
# Build the visualizer from the posteriors (NaNs replaced by 0), the flattened linear
# position and the ripple data. The ripple-plot cells below reuse this `dec_viz`
# until a later cell rebinds it.
dec_viz = DecodeVisualizer(posteriors.fillna(0), linpos=linflat_obj, riptimes=ripdata, enc_settings=encode_settings, heatmap_max=0.15)

# Dynamic plot driven by a RangeXY stream
dec_viz.plot_all_dynamic(stream=hv.streams.RangeXY())


In [None]:
%%opts NdLayout [shared_axes=False]
%%output size=100

dmap = dec_viz.plot_ripple_dynamic()

plot_list = []
plt_grp_size = 12
plt_grps = range(math.ceil(ripdata.get_num_events()/plt_grp_size))
plt_range_low = np.array(plt_grps) * plt_grp_size
plt_range_high = np.append(plt_range_low[0:-1] + plt_grp_size, ripdata.get_num_events())

for plt_grp, ind_low, ind_high in zip(plt_grps, plt_range_low, plt_range_high):
    plot_list.append(hv.NdLayout(dmap[set(range(ind_low, ind_high))]).cols(3))


#for plt_grp in plt_grps
#hv.NdLayout(dmap[set(range(ripdata.get_num_events()))]).cols(3)

In [39]:
%%opts Image {+axiswise} [height=300 width=300 aspect=3]
%%opts Curve {+axiswise} [aspect=3] (line_dash='dashed' color='#AAAAAA' linestyle='--' alpha=0.5)
%%opts Points {+axiswise} [aspect=3] (marker='*' size=14)
%%opts NdLayout {+axiswise}
%%output backend='matplotlib' size=600

event_ids = ripdata.find_events([2585.42, 2791, 2938.2, 3180.2, 3263.40, 3337.4])
plt = hv.Layout().opts(shared_axes=False)
for id in event_ids:
    plt += dec_viz.plot_ripple_all(id)

plt.cols(1).opts(shared_axes=False)

In [42]:
%%opts Image {+axiswise} [height=300 width=300 aspect=3]
%%opts Curve {+axiswise} [aspect=3] (line_dash='dashed' color='#AAAAAA' linestyle='--' alpha=0.5)
%%opts Points {+axiswise} [aspect=3] (marker='*' size=14)
%%opts NdLayout {+axiswise}
%%output backend='matplotlib' size=600

event_ids = [22, 70, 73, 99, 123, 143, 161, 174, 180, 199, 201, 229, 245]
for id in event_ids:
    plt += dec_viz.plot_ripple_all(id)

plt.cols(1).opts(shared_axes=False)

In [40]:
%%opts Image {+axiswise} [height=300 width=300 aspect=1]
%%opts Curve.arm_bound {+axiswise} [aspect=1] (line_dash='dashed' color='#AAAAAA' linestyle='--' alpha=0.5)
%%opts Points {+axiswise} [aspect=1] (marker='*' size=14)
%%opts NdLayout {+axiswise}
%%output backend='matplotlib' size=200

# Full ripple plot for event id 2, using whichever `dec_viz` is currently bound.
dec_viz.plot_ripple_all(2)

In [41]:
%%opts Image {+axiswise} [height=300 width=300 aspect=1]
%%opts Curve.arm_bound {+axiswise} [aspect=1] (line_dash='dashed' color='#AAAAAA' linestyle='--' alpha=0.5)
%%opts Points {+axiswise} [aspect=1] (marker='*' size=18)
%%opts NdLayout {+axiswise}
%%output backend='matplotlib' size=200

# NOTE: this rebinds `dec_viz`. Unlike the earlier visualizer it is built from
# ripdata.get_above_maxthresh(5) and without heatmap_max, so any cell run after this
# one sees the restricted ripple set.
dec_viz = DecodeVisualizer(posteriors.fillna(0), linpos=linflat_obj, riptimes=ripdata.get_above_maxthresh(5), enc_settings=encode_settings)

# Display every group returned by plot_ripple_grid, each with unshared axes.
rip_plots = dec_viz.plot_ripple_grid(2)
for plt_grp in rip_plots:
    display(plt_grp.opts(shared_axes=False))

In [None]:
%%output size=300
# Full ripple plot for event id 242, using whichever `dec_viz` is currently bound
# (the result depends on which of the two DecodeVisualizer cells ran last).
dec_viz.plot_ripple_all(242)