In [1]:
cd /home/daliu/Src/spykshrk_realtime/

In [38]:
%pdb

In [2]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import patches
import json
import os
import scipy.signal
import dask
import dask.dataframe as dd

from spykshrk.realtime.simulator import nspike_data

from spykshrk.franklab.pp_decoder.util import gaussian, normal2D, apply_no_anim_boundary, simplify_pos_pandas
from spykshrk.franklab.pp_decoder.pp_clusterless import OfflinePPDecoder
from spykshrk.franklab.pp_decoder.data_containers import EncodeSettings, DecodeSettings, SpikeObservation, \
                                                         LinearPosition, StimLockout, Posteriors, \
                                                         FlatLinearPosition, pos_col_format

from spykshrk.franklab.pp_decoder.visualization import DecodeVisualizer
from spykshrk.franklab.pp_decoder.decode_error import LinearDecodeError

        
%load_ext Cython

%matplotlib inline

#pd.set_option('float_format', '{:,.2f}'.format)
pd.set_option('display.precision', 4)
pd.set_option('display.max_rows', 10)
pd.set_option('display.max_columns', 15)
#pd.set_option('display.width', 80)

idx = pd.IndexSlice

matplotlib.rcParams.update({'font.size': 28})


In [3]:
from dask.distributed import Client, LocalCluster

# Number of local dask worker processes to start.
n_dask_workers = 15

# Release the client/cluster left over from a previous run of this cell.
# Close the client before the cluster it is connected to.
try:
    client.close()
    cluster.close()
except NameError:
    # Fresh kernel: nothing has been created yet. Any other error from close()
    # is a real problem and is no longer silently swallowed.
    print("No cluster or client")

cluster = LocalCluster(n_workers=n_dask_workers)
client = Client(cluster)

In [4]:
%%time
# Load merged rec HDF store based on config

#config_file = '/opt/data36/daliu/realtime/spykshrk/ripple_dec/bond.config.json'
#config_file = '/opt/data36/daliu/realtime/spykshrk/dec_60uv/bond.config.json'
config_file = '/home/daliu/Src/spykshrk_realtime/config/bond_single.json'
config = json.load(open(config_file, 'r'))

day = config['simulator']['nspike_animal_info']['days'][0]
epoch = config['simulator']['nspike_animal_info']['epochs'][0]
time_bin_size = config['pp_decoder']['bin_size']

# Main hdf5 data source file name
hdf_file = os.path.join(config['files']['output_dir'],
                        '{}.rec_merged.h5'.format(config['files']['prefix']))

# Extract just encode and decode settings from config
encode_settings = EncodeSettings(config)
decode_settings = DecodeSettings(config)

# Open data file
store = pd.HDFStore(hdf_file, mode='r')

# Encapsulate Spike Observation panda table in container
observ_obj = SpikeObservation.create_default(store['rec_3'], day=day, epoch=epoch)

# Grab stimulation lockout times
stim_lockout = StimLockout.create_default(store['rec_11'])

# Grab animal linearized real position
nspike_anim = nspike_data.AnimalInfo(**config['simulator']['nspike_animal_info'])
pos = nspike_data.PosMatDataStream(nspike_anim)
pos_data = pos.data

# Encapsulate linear position
lin_obj = LinearPosition.from_nspike_posmat(pos_data, encode_settings)

In [None]:
%%time
# Run PP decoding algorithm
time_bin_size = 300

decoder = OfflinePPDecoder(lin_obj=lin_obj, observ_obj=observ_obj,
                           encode_settings=encode_settings, decode_settings=decode_settings, 
                           which_trans_mat='learned', time_bin_size=time_bin_size)

posteriors = decoder.run_decoder()

In [7]:
# Inspect the SpikeObservation container built in the load cell.
observ_obj

In [16]:
# Scratch: map each column name produced by pos_col_format(ii, 100), ii in [0, 450),
# to numpy's 8-byte float dtype code 'f8'.
# NOTE(review): 100 and 450 are unexplained here; presumably tied to position binning.
{pos_col_format(ii, 100): 'f8' for ii in range(450)}

In [14]:
# Inspect the decoder's binned observations (attribute of OfflinePPDecoder; contents not visible here).
decoder.binned_observ

In [9]:
## Plot posteriors (wide overview window)
plt_ranges = [[2461, 2641]]
# Other windows of interest:
#plt_ranges = [[2461, 3405]]
#plt_ranges = [[2930, 3000]]
#plt_ranges = [[3295, 3325]]

# y position for the stim-lockout markers: arm_coordinates[2][1] plus a 10-unit margin
# (presumably just beyond the end of the last arm). Loop-invariant, so computed once.
stim_lockout_y = encode_settings.arm_coordinates[2][1] + 10

for plt_range in plt_ranges:
    fig, ax = plt.subplots(figsize=[200, 10])
    DecodeVisualizer.plot_decode_image(posteriors, plt_range, encode_settings, x_tick=10)
    DecodeVisualizer.plot_linear_pos(lin_obj, plt_range)
    DecodeVisualizer.plot_stim_lockout(ax, stim_lockout, plt_range, stim_lockout_y)

plt.show()

In [10]:
## Plot posteriors (zoomed-in window)
# Other windows of interest: [3260, 3280], [3175, 3200], [3102, 3106]
plt_ranges = [[2943, 2950]]

# Same marker height for every window: arm_coordinates[2][1] plus a 10-unit margin.
stim_lockout_y = encode_settings.arm_coordinates[2][1] + 10

for plt_range in plt_ranges:
    fig, ax = plt.subplots(figsize=[50, 10])
    DecodeVisualizer.plot_decode_image(posteriors, plt_range, encode_settings, x_tick=1)
    DecodeVisualizer.plot_linear_pos(lin_obj, plt_range)
    DecodeVisualizer.plot_stim_lockout(ax, stim_lockout, plt_range, stim_lockout_y)
    # Clip the current axes to exactly the requested window.
    plt.xlim(plt_range)

plt.show()

In [12]:
# Decoded position = column label of the posterior's argmax in each time bin.
# Labels are parsed by dropping their first character and converting the rest to int
# (presumably the prefix added by pos_col_format). Vectorized instead of a per-row lambda.
dec_est_pos = (posteriors.get_distribution_view()
               .idxmax(axis=1)
               .str[1:]
               .astype(int)
               .to_frame(name='est_pos'))

dec_est_pos = FlatLinearPosition.create_default(dec_est_pos, parent=posteriors)

# Resample real position at the same bin size the decoder used.
resamp_lin_obj = lin_obj.get_resampled(time_bin_size).get_pd_no_multiindex()

error_obj = LinearDecodeError()

# NOTE(review): meaning of the trailing `2` argument is not visible here; confirm in calc_error_table.
error_table = error_obj.calc_error_table(resamp_lin_obj, dec_est_pos,
                                         encode_settings.arm_coordinates, 2)

# Median and mean absolute error per arm, as one displayed table instead of two printed Series.
error_table.loc[:, idx[:, 'abs_error']].agg(['median', 'mean'])

In [13]:
# Arm order shared by the plotted columns and the error-bar array: pandas pairs yerr
# with DataFrame columns positionally, so both must follow the same order.
error_arms = ['center', 'left', 'right']
error_cols = ['plt_error_up', 'plt_error_down']

real_pos = error_table.loc[:, idx[:, 'real_pos']].reindex(
    columns=pd.MultiIndex.from_product([error_arms, ['real_pos']],
                                       names=error_table.columns.names))

errors_wide = error_table.loc[:, idx[:, error_cols]].reindex(
    columns=pd.MultiIndex.from_product([error_arms, error_cols]))

# (rows, arm x [up, down]) -> (arm, [up, down], rows): the M x 2 x N layout pandas
# expects for asymmetric error bars on an M-column frame.
error_bars = (errors_wide.values
              .reshape(len(errors_wide), len(error_arms), len(error_cols))
              .transpose(1, 2, 0))
# NOTE(review): pandas reads index 0 of the middle axis as the *lower* error and index 1
# as the upper; here index 0 is 'plt_error_up'. Confirm against LinearDecodeError's definitions.

real_pos.plot(figsize=[100, 10], style='o', yerr=error_bars)
plt.show()

In [None]:

pd.set_option('display.max_colwidth',2)

df = pd.DataFrame(np.zeros([10,10]), 
             index=pd.MultiIndex.from_arrays([range(10), range(0, -10, -1)],
                                             names=['electrode group id', 'short']))
df.style.set_table_styles([dict(selector='th', props=[('max-width', '50px')])])