In [1]:
cd /home/daliu/Src/spykshrk_realtime/

In [2]:
import cProfile
import functools
import json
import multiprocessing as mp
import os
import pickle
import pstats
from functools import partial

import dask
import dask.dataframe as dd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from spykshrk.realtime.simulator import nspike_data

from spykshrk.franklab.pp_decoder.util import gaussian, normal2D, apply_no_anim_boundary, simplify_pos_pandas
from spykshrk.franklab.pp_decoder.pp_clusterless import OfflinePPDecoder
from spykshrk.franklab.pp_decoder.data_containers import EncodeSettings, DecodeSettings, SpikeObservation, \
                                                         LinearPosition, StimLockout, Posteriors
from spykshrk.franklab.pp_decoder.visualization import DecodeVisualizer

from spykshrk.franklab.pp_decoder.decode_error import LinearDecodeError
# Notebook-wide pandas display settings: 4-digit precision and compact
# row/column truncation (set as option/value pairs in a single call).
pd.set_option('display.precision', 4,
              'display.max_rows', 10,
              'display.max_columns', 15)

# Shorthand for MultiIndex slicing inside .loc, e.g. df.loc[:, idx[:, 'abs_error']]
idx = pd.IndexSlice

In [3]:
%%time
# Load merged rec HDF store based on config

config_file = '/opt/data36/daliu/realtime/spykshrk/dec_60uv_300samp/bond.config.json'
#config_file = '/home/daliu/Src/spykshrk_realtime/config/bond_single.json'
#config_file = '/opt/data36/daliu/realtime/spykshrk/test/test.config.json'


config = json.load(open(config_file, 'r'))
day = config['simulator']['nspike_animal_info']['days'][0]
epoch = config['simulator']['nspike_animal_info']['epochs'][0]

# Main hdf5 data source file name
hdf_file = os.path.join(config['files']['output_dir'],
                        '{}.rec_merged.h5'.format(config['files']['prefix']))

# Extract just encode and decode settings from config
encode_settings = EncodeSettings(config)
decode_settings = DecodeSettings(config)

# Open data file
store = pd.HDFStore(hdf_file, mode='r')

# Encapsulate Spike Observation panda table in container
observ_obj = SpikeObservation.create_default(store['rec_3'], day=day, epoch=epoch)

# Grab stimulation lockout times
stim_lockout = StimLockout.create_default(store['rec_11'])


# Grab animal linearized real position
nspike_anim = nspike_data.AnimalInfo(**config['simulator']['nspike_animal_info'])
pos = nspike_data.PosMatDataStream(nspike_anim)
pos_data = pos.data

# Encapsulate linear position
lin_obj = LinearPosition.from_nspike_posmat(pos_data, encode_settings)

In [4]:
def _load_linear_position(animal_info, settings):
    """Load nspike position data described by ``animal_info`` and linearize it.

    Returns ``(AnimalInfo, PosMatDataStream, raw position data, LinearPosition)``
    so every intermediate remains available as a notebook global.
    """
    anim = nspike_data.AnimalInfo(**animal_info)
    stream = nspike_data.PosMatDataStream(anim)
    raw_pos = stream.data
    return anim, stream, raw_pos, LinearPosition.from_nspike_posmat(raw_pos, settings)


# Re-load the animal's linearized real position.
# NOTE(review): this repeats the last steps of the previous loading cell.
nspike_anim, pos, pos_data, lin_obj = _load_linear_position(
    config['simulator']['nspike_animal_info'], encode_settings)

In [12]:
# Inspect the spike observation container built by the loading cell.
observ_obj

In [None]:
%%prun -r
# Run PP decoding algorithm

time_bin_size = 300

decoder = OfflinePPDecoder(lin_obj=lin_obj, observ_obj=observ_obj,
                           encode_settings=encode_settings, decode_settings=decode_settings, 
                           which_trans_mat='uniform', time_bin_size=time_bin_size, parallel=True, bin_per_pool=1)

posteriors = decoder.run_decoder()

In [18]:
posteriors

In [None]:
p = _

In [None]:
# Top 40 profile entries, ordered by cumulative time
# (sort_stats returns the Stats object, so the calls chain).
p.sort_stats('cumtime').print_stats(40)

In [None]:
# List what the profiled function(s) matching 'view.py:514' call. The argument is
# a pstats restriction: a regular expression matched against 'file:line(function)'.
p.print_callees('view.py:514')

In [None]:
# Same listing as the earlier sort/print cell: 'cumulative' and 'cumtime' are
# equivalent pstats sort keys. NOTE(review): likely a duplicate cell.
p.sort_stats('cumulative').print_stats(40)

In [None]:
# Callees of _calc_observation, with entries ordered by cumulative time.
p.sort_stats('cumulative').print_callees('_calc_observation')

In [19]:
# Decoded position estimate: for each time bin, the posterior column with the
# highest value. Column labels are strings made of a one-character prefix
# followed by the bin number (hence the [1:] slice before converting to int).
dec_est_pos = (posteriors.get_distribution_view()
               .idxmax(axis=1)
               .str[1:]
               .astype('int64')
               .to_frame(name='est_pos'))

resamp_lin_obj = lin_obj.get_resampled(time_bin_size).get_pd_no_multiindex()

# Keep the calculator and the resulting error table under separate names.
error_calculator = LinearDecodeError()
# NOTE(review): the meaning of the trailing `2` is defined by
# LinearDecodeError.calc_error_table -- give it a named constant once confirmed.
dec_error = error_calculator.calc_error_table(resamp_lin_obj, dec_est_pos,
                                              encode_settings.arm_coordinates, 2)

# Median and mean absolute error for each top-level column (arm), shown as a
# single table instead of two print() dumps.
abs_error = dec_error.loc[:, idx[:, 'abs_error']]
pd.DataFrame({'median': abs_error.median(), 'mean': abs_error.mean()})

In [20]:
# Plot the real position per arm with decode-error bars.
arm_names = ['center', 'left', 'right']
error_cols = pd.MultiIndex.from_product([arm_names, ['plt_error_up', 'plt_error_down']])
real_pos_cols = pd.MultiIndex.from_product([arm_names, ['real_pos']])

error_bars = dec_error.loc[:, idx[:, ['plt_error_up', 'plt_error_down']]].reindex(columns=error_cols)
# pandas takes asymmetric yerr for a DataFrame as shape (n_columns, 2, n_rows).
# NOTE(review): matplotlib treats row 0 of each 2xN pair as the *lower* error and
# row 1 as the upper; here row 0 is 'plt_error_up' -- confirm this is intended.
error_bars = error_bars.values.reshape(len(error_bars), len(arm_names), 2).transpose(1, 2, 0)

# Put the plotted columns in the same arm order as the error bars, so each arm
# gets its own bars regardless of dec_error's column order.
real_pos = dec_error.loc[:, idx[:, 'real_pos']].reindex(columns=real_pos_cols)

fig, ax = plt.subplots(figsize=[100, 10])
real_pos.plot(ax=ax, style='o', yerr=error_bars)

plt.show()

In [19]:

# Re-bin the spike observations; return values are discarded, so these
# presumably update observ_obj in place.
# NOTE(review): 300 presumably mirrors `time_bin_size` used by the decoder and
# 30000 is the parallel chunk size; units are defined by SpikeObservation -- confirm.
observ_obj.update_observations_bins(300)
observ_obj.update_parallel_bins(30000)


In [33]:
# Build a dask version of the spike observations to experiment with a parallel
# groupby/apply pipeline (dask, dask.dataframe and numpy come from the import cell).

# `dask.set_options(get=...)` was deprecated in dask 0.18 and later removed;
# `dask.config.set(scheduler='processes')` is its replacement and, called
# directly (not as a context manager), sets the scheduler globally. Fall back to
# the old API on dask releases that predate `dask.config`.
if hasattr(dask, 'config'):
    dask.config.set(scheduler='processes')
else:
    dask.set_options(get=dask.multiprocessing.get)

observ_obj.update_observations_bins(300)
observ_obj.update_parallel_bins(30000)

# pd.DataFrame(...) turns the container (presumably a pandas subclass) into a
# plain DataFrame before handing it to dask.
# NOTE(review): with npartitions=1 the process scheduler has no partitions to run in parallel.
observ_dask = dd.from_pandas(pd.DataFrame(observ_obj.drop(['time', 'timestamp'], axis=1).reset_index()),
                             npartitions=1)

In [34]:
def step1(table):
    """Reduce one group's rows to their column-wise mean.

    ``table`` is the per-group pandas object handed over by the groupby apply.
    """
    column_means = table.mean(axis=0)
    return column_means

# Lazily build the per-dec_bin column means (step1 applied to each group); nothing
# runs until .compute(). Passing the dask frame itself as `meta` makes dask reuse
# its (empty) metadata as the declared output schema.
# NOTE(review): the actual output is indexed by dec_bin -- confirm that meta is adequate.
task = observ_dask.groupby('dec_bin').apply(step1, meta=observ_dask)

def step2(table):
    """Return ``table`` with every value halved (applied row by row in the pipeline)."""
    halved = table / 2
    return halved

# Chain a row-wise halving (step2) onto the grouped means; still lazy.
# NOTE(review): `task` is rebound here, so the groupby-only graph is no longer reachable by name.
task = task.apply(step2, axis=1, meta=observ_dask)

In [35]:
%%time
# Execute the lazy dask graph with the globally configured scheduler and time it.
results = task.compute()

In [29]:
# Inspect the computed output of the dask pipeline.
results

In [25]:
# Check what .compute() returned (presumably a pandas DataFrame -- confirm).
type(results)