In [1]:
cd /home/daliu/Src/spykshrk_realtime/

In [2]:
import pandas as pd
import numpy as np
import scipy as sp
import math

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import patches
import json
import os
import scipy.signal
import functools
import holoviews as hv

from spykshrk.util import AttrDict
import spykshrk.franklab.filterframework_util as ff_util

from spykshrk.realtime.simulator import nspike_data

from spykshrk.franklab.pp_decoder.util import gaussian, normal2D, apply_no_anim_boundary, simplify_pos_pandas, \
                                                normal_pdf_int_lookup
from spykshrk.franklab.pp_decoder.pp_clusterless import OfflinePPDecoder, OfflinePPEncoder
from spykshrk.franklab.data_containers import DataFrameClass, EncodeSettings, DecodeSettings, SpikeObservation, \
                                              LinearPosition, StimLockout, Posteriors, \
                                              FlatLinearPosition, SpikeWaves, SpikeFeatures, \
                                              pos_col_format, DayEpochTimeSeries

from spykshrk.franklab.pp_decoder.visualization import DecodeVisualizer
from spykshrk.franklab.pp_decoder.decode_error import LinearDecodeError

import dask
import dask.dataframe as dd
import dask.array as da
import multiprocessing

import cloudpickle
        
%load_ext Cython

%matplotlib inline

hv.extension('matplotlib')
hv.extension('bokeh')
#pd.set_option('float_format', '{:,.2f}'.format)
pd.set_option('display.precision', 4)
pd.set_option('display.max_rows', 6)
pd.set_option('display.max_columns', 15)
#pd.set_option('display.width', 80)

idx = pd.IndexSlice

matplotlib.rcParams.update({'font.size': 14})



In [3]:
from holoviews import Store
from bokeh.models.arrow_heads import TeeHead
# Register 'linestyle' as an allowed style option for hv.Curve on the matplotlib backend,
# so the `linestyle=...` entries in later `%%opts Curve` lines are accepted.
Store.add_style_opts(hv.Curve, ['linestyle'], backend='matplotlib')


In [4]:
"""try:
    cluster.close()
    client.close()
except:
    print("No cluster or client running")
    
from dask.distributed import Client, LocalCluster

cluster = LocalCluster(n_workers=20, threads_per_worker=2)
client = Client(cluster)

min_worker_memory = np.inf
for w in cluster.workers:
    min_worker_memory = min(min_worker_memory, w.memory_limit)
"""

# Import the submodules explicitly: `import dask` and `import multiprocessing` on their own
# do not guarantee that `dask.multiprocessing` and `multiprocessing.pool` are available as
# attributes (they only are if some other module happened to import them first).
import dask.multiprocessing
import multiprocessing.pool

# Use dask's multiprocessing scheduler backed by a pool of 20 worker processes.
dask.set_options(get=dask.multiprocessing.get, pool=multiprocessing.pool.Pool(20))
# Passed to OfflinePPEncoder as `dask_worker_memory` (10e9 — presumably bytes; confirm).
min_worker_memory = 10e9

In [5]:
# Load merged rec HDF store based on config

#config_file = '/opt/data36/daliu/realtime/spykshrk/ripple_dec/bond.config.json'
#config_file = '/opt/data36/daliu/realtime/spykshrk/dec_60uv/bond.config.json'
config_file = '/home/daliu/Src/spykshrk_realtime/config/bond_single.json'
# Read the config through a context manager so the file handle is closed deterministically
# (the previous `json.load(open(...))` left it open until garbage collection).
with open(config_file, 'r') as config_fh:
    config = json.load(config_fh)

# Override the tetrode list from the config file with the set used for this analysis.
config['simulator']['nspike_animal_info']['tetrodes'] = [1, 2, 4, 5, 7, 10, 11, 12, 13, 14, 17, 18,
                                                         19, 20, 22, 23, 27, 29]

# Only the first configured day/epoch is used.
day = config['simulator']['nspike_animal_info']['days'][0]
epoch = config['simulator']['nspike_animal_info']['epochs'][0]
time_bin_size = config['pp_decoder']['bin_size']

# Change config
config['encoder']['position_kernel']['std'] = 1
config['pp_decoder']['trans_mat_smoother_std'] = 2
config['pp_decoder']['trans_mat_uniform_gain'] = 0.01

# Extract just encode and decode settings from config
# (built after the overrides above so they see the modified values).
encode_settings = EncodeSettings(config)
decode_settings = DecodeSettings(config)

# Grab animal linearized real position
nspike_anim = nspike_data.AnimalInfo(**config['simulator']['nspike_animal_info'])
pos = nspike_data.PosMatDataStream(nspike_anim)
pos_data = pos.data

# Spike waveforms for the same animal, wrapped in the SpikeWaves container.
spk = nspike_data.SpkDataStream(nspike_anim)
spk_data = SpikeWaves.from_df(spk.data, encode_settings)

# Encapsulate linear position
lin_obj = LinearPosition.from_nspike_posmat(pos_data, encode_settings)
linflat_obj = lin_obj.get_mapped_single_axis()

# Ripple event data for the same animal.
ripcons = nspike_data.RipplesConsData(nspike_anim)
ripdata = ripcons.data_obj

In [6]:
# Maximum along axis 1 of each spike-waveform row, as a one-column frame (column label 0).
spk_peak = spk_data.max(axis=1).to_frame()
# Pivot so that each value of the 'channel' level becomes its own column.
spk_peak_by_chan = spk_peak.pivot_table(index=['day', 'epoch', 'elec_grp_id', 'timestamp', 'time'],
                                        columns='channel', values=0)
spk_amp = SpikeFeatures.create_default(df=spk_peak_by_chan, sampling_rate=30000)
# Keep only the spikes that get_above_threshold(60) retains.
spk_amp_thresh = spk_amp.get_above_threshold(60)


In [7]:
%%time
#%%prun -r -s cumulative

encoder = OfflinePPEncoder(linflat=linflat_obj, spk_amp=spk_amp_thresh, speed_thresh=2,
                           encode_settings=encode_settings, dask_worker_memory=min_worker_memory)
#task = encoder.setup_encoder_dask()
results = encoder.run_encoder()

In [8]:
# NOTE(review): `_` is IPython's "most recent displayed output". The preceding cell runs
# under `%%time` and ends in an assignment, so on a fresh Restart & Run All `_` is not
# guaranteed to hold profiling data — presumably this was meant to capture the result of
# the commented-out `%%prun -r` in that cell. Confirm before relying on `encode_prof`.
encode_prof = _

In [9]:
#%%time
tet_ids = np.unique(spk_amp.index.get_level_values('elec_grp_id'))

# Give each per-tetrode encoder result the index of that tetrode's thresholded spikes.
# This pairs results with groups purely by position, so it assumes `results` is ordered
# the same way as the groupby over 'elec_grp_id'.
grp = spk_amp_thresh.groupby('elec_grp_id')
observ_tet_list = []
for tet_ii, (tet_id, grp_spk) in enumerate(grp):
    tet_result = results[tet_ii]
    tet_result.set_index(grp_spk.index, inplace=True)
    observ_tet_list.append(tet_result)

# Stack all tetrodes and order the rows by day, epoch, timestamp, then tetrode.
observ = pd.concat(observ_tet_list)
observ_sorted = observ.sort_index(level=['day', 'epoch', 'timestamp', 'elec_grp_id'])
observ_obj = SpikeObservation.create_default(observ_sorted, encode_settings.sampling_rate)

# Move the tetrode id out of the index and into a regular column.
observ_obj['elec_grp_id'] = observ_obj.index.get_level_values('elec_grp_id')
observ_obj.index = observ_obj.index.droplevel('elec_grp_id')

# Attach the flattened linear position resampled at each spike timestamp.
spk_timestamps = observ_obj.index.get_level_values('timestamp')
linflat_at_spks = lin_obj.get_irregular_resampled(spk_timestamps).get_mapped_single_axis()
observ_obj['position'] = linflat_at_spks['linpos_flat']

In [10]:
# Show the assembled spike-observation table as this cell's output.
observ_obj

In [11]:
%%time
# Run PP decoding algorithm
time_bin_size = 30

decoder = OfflinePPDecoder(lin_obj=lin_obj, observ_obj=observ_obj,
                           encode_settings=encode_settings, decode_settings=decode_settings, 
                           which_trans_mat='learned', time_bin_size=time_bin_size)

posteriors = decoder.run_decoder()

In [12]:
import os
# Scratch check of how the HDF store path components join; the result is only displayed,
# not stored.
os.path.join('/analysis', 'decode/clusterless/offline/posterior', 'run1')

In [13]:
# Write the posteriors into the bond_decode.h5 analysis store, replacing any existing
# entry (overwrite=True). The positional arguments are presumably store file, base path,
# sub-path and label — confirm against `_to_hdf_store`.
posteriors._to_hdf_store(
    '/opt/data36/daliu/pyBond/analysis/bond_decode.h5',
    '/analysis',
    'decode/clusterless/offline/posterior',
    'learned_trans_mat',
    overwrite=True,
)

In [14]:
# Round-trip check: read back the entry this notebook writes ('learned_trans_mat').
# The previous 'run1' label is never written by this notebook, so it could only resolve
# to stale data from an earlier session or fail on a fresh store.
test1 = Posteriors._from_hdf_store('/opt/data36/daliu/pyBond/analysis/bond_decode.h5','/analysis',
                                   'decode/clusterless/offline/posterior', 'learned_trans_mat')

In [15]:
# First entry of memory_usage(), scaled by 1e6 (bytes -> MB if memory_usage() reports bytes).
# NOTE(review): for a plain pandas DataFrame the first entry is the index's footprint only,
# not the whole frame — if the total was intended, `.sum()` would be needed. Confirm.
posteriors.memory_usage()[0]/1e6

In [16]:
# Render the decoder's transition matrix (`decoder.trans_mat`) as an image.
hv.Image(decoder.trans_mat)

In [17]:
%%output backend='matplotlib' size=300
%%opts Points (s=200 marker='^' )
%%opts Curve [aspect=3]
%%opts Text (text_align='left')

sel_distrib = observ_obj.loc[:, pos_col_format(0,encode_settings.pos_num_bins):         
                             pos_col_format(encode_settings.pos_num_bins-1,
                                            encode_settings.pos_num_bins)]
    
sel_pos = observ_obj.loc[:, 'position']

max_prob = sel_distrib.max().max()/2

def plot_observ(big_bin, small_bin):
    bin_id = small_bin + 10000 * big_bin
    spks_in_bin = sel_distrib.loc[observ_obj['dec_bin'] == bin_id, :]
    pos_in_bin = sel_pos.loc[observ_obj['dec_bin'] == bin_id, :]
    
    num_spks = len(spks_in_bin)
    plot_list = []
    if num_spks == 0:
        plot_list.append(hv.Curve((0,[max_prob-0.01]), 
                                   extents=(0, 0, encode_settings.pos_bins[-1], max_prob)))
    for spk_observ, pos_observ in zip(spks_in_bin.values, pos_in_bin.values):
        plot_list.append(hv.Curve(spk_observ, 
                                  extents=(0, 0, encode_settings.pos_bins[-1], max_prob)))

        plot_list.append(hv.Points((pos_observ, [max_prob-0.01])))
    return hv.Overlay(plot_list) * hv.Text(50,max_prob-0.05, "num_spks: {num_spks}\n"
                                           "Timestamp: {timestamp}\nTime: {time}".
                                           format(num_spks=num_spks, timestamp=time_bin_size*bin_id,
                                                  time=time_bin_size*bin_id/30000))

#Ind = Stream.define('stuff', ind=0)

dmap = hv.DynamicMap(plot_observ, kdims=['big_bin', 'small_bin'], label="test")
#dmap = hv.DynamicMap(plot_observ, kdims=
#                     [hv.Dimension('bin_id', range=(0, observ_obj['dec_bin'].iloc[-1]), step=1)])
#dmap = hv.DynamicMap(plot_observ, kdims=
#                     [hv.Dimension('bin_id', values=observ_obj['dec_bin'].unique())])

#dmap.redim.values(bin_id=range(0, observ_obj['dec_bin'].iloc[-1]))
dmap.redim.range(small_bin=(0, 1000), big_bin=(0, observ_obj['dec_bin'].iloc[-1]/1000 + 1))
#dmap.redim.range(bin_id=(0, observ_obj['dec_bin'].iloc[-1]))
#dmap.redim.values(bin_id=[4,5])


In [18]:
%%output backend='bokeh' size=400 holomap='scrubber'
%%opts RGB { +framewise} [height=100 width=250 aspect=2 colorbar=True]
%%opts Points [height=100 width=250 aspect=2 ] (marker='o' color='#AAAAFF' size=2 alpha=0.7)
%%opts Polygons (color='grey', alpha=0.5 fill_color='grey' fill_alpha=0.5)
#%%opts Image {+framewise}
dec_viz = DecodeVisualizer(posteriors, linpos=lin_obj, riptimes=ripdata, enc_settings=encode_settings)

dec_viz.plot_all_dynamic(stream=hv.streams.RangeXY(), plt_range=1, slide=1, values=ripdata['starttime']-.5)


In [19]:
%%opts NdLayout [shared_axes=False]
%%output size=100

dmap = dec_viz.plot_ripple_dynamic()

plot_list = []
plt_grp_size = 12
plt_grps = range(math.ceil(ripdata.get_num_events()/plt_grp_size))
plt_range_low = np.array(plt_grps) * plt_grp_size
plt_range_high = np.append(plt_range_low[0:-1] + plt_grp_size, ripdata.get_num_events())

for plt_grp, ind_low, ind_high in zip(plt_grps, plt_range_low, plt_range_high):
    plot_list.append(hv.NdLayout(dmap[set(range(ind_low, ind_high))]).cols(3))


#for plt_grp in plt_grps
#hv.NdLayout(dmap[set(range(ripdata.get_num_events()))]).cols(3)

In [20]:
%%opts Image {+axiswise} [height=300 width=300 aspect=3]
%%opts Curve {+axiswise} [aspect=2] (line_dash='dashed' color='#AAAAAA' linestyle='--' alpha=0.5)
%%opts Points {+axiswise} [aspect=2] (marker='*' size=14)
%%opts NdLayout {+axiswise}
%%output backend='matplotlib' size=600

event_ids = ripdata.find_events([2585.42, 2791, 2938.2, 3180.2, 3263.40, 3337.4])
plt = hv.Layout()
for id in event_ids:
    plt += dec_viz.plot_ripple_all(id)

plt.cols(1)

In [21]:
%%opts Image {+axiswise} [height=300 width=300 aspect=1]
%%opts Curve.arm_bound {+axiswise} [aspect=1] (line_dash='dashed' color='#AAAAAA' linestyle='--' alpha=0.5)
%%opts Points {+axiswise} [aspect=1] (marker='*' size=14)
%%opts NdLayout {+axiswise}
%%output backend='matplotlib' size=200

# Show the plot_ripple_all output for argument 2 (presumably a ripple event id, as in the
# loop over find_events results in the previous cell — confirm).
dec_viz.plot_ripple_all(2)

In [22]:
%%opts Image {+axiswise} [height=300 width=300 aspect=1]
%%opts Curve {+axiswise} [aspect=1] (line_dash='dashed' color='#AAAAAA' linestyle='--' alpha=0.5)
%%opts Points {+axiswise} [aspect=1] (marker='*' size=14)
%%opts NdLayout {+axiswise}
%%output backend='matplotlib' size=200

dec_viz = DecodeVisualizer(posteriors, linpos=lin_obj, riptimes=ripdata.get_above_maxthresh(5), enc_settings=encode_settings)

rip_plots = dec_viz.plot_ripple_grid(2)
for plt_grp in rip_plots:
    display(plt_grp)

In [23]:
%%output size=300
# Show the plot_ripple_all output for argument 242, using whichever `dec_viz` was bound
# most recently (the previous cell rebinds it to a thresholded ripple set).
dec_viz.plot_ripple_all(242)

In [24]:
# Scratch check: display plt_range_high with 270 appended. np.append returns a new array
# and the result is not stored, so plt_range_high itself is unchanged.
np.append(plt_range_high, [270])