In [1]:
cd /home/daliu/Src/spykshrk_realtime/

In [2]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import patches
import json
import os
import scipy.signal
import functools

import math

from spykshrk.realtime.simulator import nspike_data

from spykshrk.franklab.pp_decoder.util import gaussian, normal2D, apply_no_anim_boundary, simplify_pos_pandas
from spykshrk.franklab.pp_decoder.pp_clusterless import OfflinePPDecoder
from spykshrk.franklab.pp_decoder.data_containers import EncodeSettings, DecodeSettings, SpikeObservation, \
                                                         LinearPosition, StimLockout, Posteriors, \
                                                         FlatLinearPosition, SpikeWaves, SpikeFeatures, \
                                                         pos_col_format

from spykshrk.franklab.pp_decoder.visualization import DecodeVisualizer
from spykshrk.franklab.pp_decoder.decode_error import LinearDecodeError
from spykshrk.franklab.pp_decoder.util import normal_pdf_int_lookup

import dask
import dask.dataframe as dd
import dask.array as da

import pickle
        
# Enable the %%cython cell magic.
# NOTE(review): no %%cython cell is visible in this notebook, so this may be unused.
%load_ext Cython

# Render matplotlib figures inline in the notebook output.
%matplotlib inline

# pandas display settings for this notebook. The commented-out options are
# kept for quick toggling.
#pd.set_option('float_format', '{:,.2f}'.format)
#pd.set_option('display.width', 80)
pd.set_option('display.precision', 4,
              'display.max_rows', 10,
              'display.max_columns', 15)

# Shorthand for building MultiIndex slices, e.g. df.loc[idx[...], :].
idx = pd.IndexSlice

# Default font size for every matplotlib figure drawn in this notebook.
matplotlib.rcParams.update({'font.size': 14})



In [3]:
# (Re)create a local dask cluster. When this cell is re-run, first close the
# client/cluster left over from the previous run so their workers are not leaked.
from dask.distributed import Client, LocalCluster

N_DASK_WORKERS = 15

try:
    # Close the client before the cluster it is connected to.
    client.close()
    cluster.close()
except NameError:
    # First run in this kernel: nothing to close yet.
    print("No cluster or client")
except Exception as exc:
    # Best-effort cleanup of a stale cluster; report it but still start a new one.
    # (Unlike a bare except, this does not swallow KeyboardInterrupt/SystemExit.)
    print("Could not close previous cluster/client: {!r}".format(exc))

cluster = LocalCluster(n_workers=N_DASK_WORKERS)
client = Client(cluster)

In [4]:
%%time
# Load the JSON config, then the animal's position and spike data it points to.

#config_file = '/opt/data36/daliu/realtime/spykshrk/ripple_dec/bond.config.json'
#config_file = '/opt/data36/daliu/realtime/spykshrk/dec_60uv/bond.config.json'
config_file = '/home/daliu/Src/spykshrk_realtime/config/bond_single.json'
# Context manager so the config file handle is closed immediately.
with open(config_file, 'r') as config_fp:
    config = json.load(config_fp)

# First day/epoch listed in the config.
day = config['simulator']['nspike_animal_info']['days'][0]
epoch = config['simulator']['nspike_animal_info']['epochs'][0]
# NOTE(review): reassigned to a hardcoded value in the decode cell below.
time_bin_size = config['pp_decoder']['bin_size']

# Extract just encode and decode settings from config
encode_settings = EncodeSettings(config)
decode_settings = DecodeSettings(config)

# Grab animal linearized real position
nspike_anim = nspike_data.AnimalInfo(**config['simulator']['nspike_animal_info'])
pos = nspike_data.PosMatDataStream(nspike_anim)
pos_data = pos.data

spk = nspike_data.SpkDataStream(nspike_anim)
spk_data = SpikeWaves(spk.data)

# Encapsulate linear position
lin_obj = LinearPosition.from_nspike_posmat(pos_data, encode_settings)

In [6]:
%%time
# Per-spike peak amplitude on each channel: take the max over the waveform
# columns, then pivot so every channel becomes its own column.
spk_amp = SpikeFeatures(
    spk_data.max(axis=1)
            .to_frame()
            .pivot_table(index=['day', 'epoch', 'elec_grp', 'timestamp', 'time'],
                         columns='channel', values=0)
)

# Keep only spikes above the amplitude threshold.
# NOTE(review): threshold units follow spk_data's amplitude units — confirm.
spk_amp_thresh = spk_amp.get_above_threshold(60)


In [7]:
%%time

def compute_observ_tet(dec_spk, enc_spk, tet_lin_pos, encode_settings):
    """Compute a position observation distribution for each decoded spike of one tetrode.

    Every decoded spike's marks are compared with every encoding spike's marks
    via ``normal_pdf_int_lookup``. The per-feature contributions are multiplied
    together, and the result weights a Gaussian position kernel centred on each
    encoding spike's linear position. Each output row is normalized to sum to 1.

    Parameters
    ----------
    dec_spk : DataFrame
        Marks of the spikes to decode, one row per spike (presumably one
        column per channel — confirm against the caller).
    enc_spk : DataFrame
        Marks of the encoding spikes, with the same feature columns as ``dec_spk``.
    tet_lin_pos : DataFrame
        Linear position of each encoding spike. It must have a ``'linpos_flat'``
        column with one row per row of ``enc_spk`` (required by the matmul below).
    encode_settings : EncodeSettings
        Supplies ``pos_bins``, ``pos_kernel_std`` and ``mark_kernel_std``.

    Returns
    -------
    DataFrame
        Indexed like ``dec_spk``, with one ``pos_col_format`` column per position bin.
    """
    # Import the submodule explicitly. ``import scipy`` alone does not load
    # scipy.stats on older SciPy, and this function runs in dask worker
    # processes that may not have imported it yet.
    import scipy.stats

    # (n_enc, n_pos_bins): Gaussian kernel over the position bins, centred on
    # each encoding spike's linear position.
    pos_distrib_tet = scipy.stats.norm.pdf(np.expand_dims(encode_settings.pos_bins, 0),
                                           np.expand_dims(tet_lin_pos['linpos_flat'], 1),
                                           encode_settings.pos_kernel_std)

    # Broadcast decode marks (n_dec, 1, n_feat) against encode marks
    # (1, n_enc, n_feat). The result is presumably elementwise, giving
    # (n_dec, n_enc, n_feat) — confirm against normal_pdf_int_lookup.
    mark_contrib = normal_pdf_int_lookup(np.expand_dims(dec_spk, 1),
                                         np.expand_dims(enc_spk, 0),
                                         encode_settings.mark_kernel_std)

    # Joint mark weight of each decode/encode spike pair: product over features.
    all_contrib = np.prod(mark_contrib, axis=2)

    # (n_dec, n_pos_bins): mark-weighted sum of the encoding position kernels.
    observ = np.matmul(all_contrib, pos_distrib_tet)

    # Normalize each row to sum to 1.
    # NOTE(review): a row summing to 0 (e.g. empty enc_spk, or no encoding
    # spike near a decoded spike in mark space) becomes NaN here.
    observ = observ / observ.sum(axis=1, keepdims=True)

    num_pos_bins = observ.shape[1]
    return pd.DataFrame(observ, index=dec_spk.index,
                        columns=[pos_col_format(pos_ii, num_pos_bins)
                                 for pos_ii in range(num_pos_bins)])


# Build one lazy dask task per tetrode. Each task decodes every spike on that
# tetrode against an encoding model built only from spikes whose resampled
# position passes the velocity threshold.
chunksize = 2000
# Passed to get_above_velocity.
# NOTE(review): units are whatever that method expects (presumably cm/s) — confirm.
VEL_THRESH = 10.

# Empty frame describing compute_observ_tet's output columns (one per position
# bin). It is the same for every tetrode, so build it once outside the loop.
df_meta = pd.DataFrame([], columns=[pos_col_format(ii, encode_settings.pos_num_bins)
                                    for ii in range(encode_settings.pos_num_bins)])

grp = spk_amp_thresh.groupby('elec_grp')
task = []
for tet_id, spk_tet in grp:
    spk_tet.index = spk_tet.index.droplevel('elec_grp')
    tet_lin_pos = (lin_obj.get_irregular_resampled(spk_tet.index.get_level_values('timestamp'))
                   .get_mapped_single_axis())

    # Velocity threshold on spikes and position.
    # NOTE(review): reindex yields NaN rows for any label of tet_lin_pos_thresh
    # that is missing from spk_tet — assumes the two indexes share a format.
    tet_lin_pos_thresh = tet_lin_pos.get_above_velocity(VEL_THRESH)
    spk_tet_thresh = spk_tet.reindex(tet_lin_pos_thresh.index)

    # Decode from all spikes
    dask_spk_tet = dd.from_pandas(spk_tet.get_simple_index(), chunksize=chunksize)

    # Setup decode of all spikes from encoding of velocity threshold spikes
    task.append(dask_spk_tet.map_partitions(
        functools.partial(compute_observ_tet,
                          enc_spk=spk_tet_thresh,
                          tet_lin_pos=tet_lin_pos_thresh,
                          encode_settings=encode_settings),
        meta=df_meta))

# One result per tetrode, in groupby iteration order.
results = dask.compute(*task)


In [8]:
%%time
# Re-attach each tetrode's full spike index (including elec_grp) to its dask
# result. `results` was produced in the iteration order of
# spk_amp_thresh.groupby('elec_grp'), so pair them up by position.
grp = spk_amp_thresh.groupby('elec_grp')
assert len(results) == grp.ngroups, \
    'expected one decode result per elec_grp: {} results, {} groups'.format(len(results), grp.ngroups)

observ_tet_list = []
for tet_result, (tet_id, grp_spk) in zip(results, grp):
    assert len(tet_result) == len(grp_spk), \
        'elec_grp {}: decode result has {} rows, expected {}'.format(tet_id, len(tet_result),
                                                                      len(grp_spk))
    # Not in place, so `results` stays as computed and this cell can be re-run.
    observ_tet_list.append(tet_result.set_index(grp_spk.index))

observ = pd.concat(observ_tet_list)
observ_obj = SpikeObservation.from_df(observ.sort_index(level=['day', 'epoch',
                                                               'timestamp', 'elec_grp']))

# Move elec_grp out of the index and into a column.
observ_obj['elec_grp'] = observ_obj.index.get_level_values('elec_grp')
observ_obj.index = observ_obj.index.droplevel('elec_grp')

# NOTE(review): this column assignment aligns on index. After dropping elec_grp,
# the index may contain duplicates (simultaneous spikes on different tetrodes) —
# confirm the resampled position frame lines up with observ_obj row for row.
observ_obj['position'] = (lin_obj.get_irregular_resampled(observ_obj.index.get_level_values('timestamp'))
                          .get_mapped_single_axis()['linpos_flat'])

In [9]:
%%time
# Run the point-process (PP) decoder with the learned transition matrix.
# NOTE(review): this replaces the time_bin_size read from
# config['pp_decoder']['bin_size'] in the load cell — confirm 300 is intended.
time_bin_size = 300

decoder = OfflinePPDecoder(
    lin_obj=lin_obj,
    observ_obj=observ_obj,
    encode_settings=encode_settings,
    decode_settings=decode_settings,
    which_trans_mat='learned',
    time_bin_size=time_bin_size,
)

posteriors = decoder.run_decoder()

In [10]:
# Observation distributions of the first three spikes in `observ`, one line per spike.
# Build the column list the same way compute_observ_tet names its columns,
# instead of hardcoding 'x000':'x449' (which assumes exactly 450 bins).
pos_cols = [pos_col_format(ii, encode_settings.pos_num_bins)
            for ii in range(encode_settings.pos_num_bins)]

fig, ax = plt.subplots()
ax.plot(observ.loc[:, pos_cols].iloc[0:3].values.T)
ax.set_xlabel('position bin')
ax.set_ylabel('observation (row-normalized)')
ax.set_title('Observation distribution, first 3 spikes')
plt.show()

In [11]:
# Same first three spikes, explicitly re-normalized so each curve sums to 1.
# NOTE(review): compute_observ_tet already normalizes every row, so this should
# match the previous plot up to floating-point error.
pos_cols = [pos_col_format(ii, encode_settings.pos_num_bins)
            for ii in range(encode_settings.pos_num_bins)]

# (n_pos_bins, 3): one column per spike.
first_observ = observ.loc[:, pos_cols].iloc[0:3].values.T
first_observ_norm = first_observ / first_observ.sum(axis=0)

fig, ax = plt.subplots()
ax.plot(first_observ_norm)
ax.set_xlabel('position bin')
ax.set_ylabel('observation (normalized per spike)')
ax.set_title('Re-normalized observation, first 3 spikes')
plt.show()

In [12]:
# Binned observations of the first two decode time bins, one line per bin.
# Column names are derived via pos_col_format instead of hardcoding 'x000':'x449'.
# NOTE(review): assumes binned_observ uses the same position-column names as `observ`.
pos_cols = [pos_col_format(ii, encode_settings.pos_num_bins)
            for ii in range(encode_settings.pos_num_bins)]

fig, ax = plt.subplots()
ax.plot(decoder.binned_observ.loc[:, pos_cols].iloc[0:2].values.T)
ax.set_xlabel('position bin')
ax.set_ylabel('binned observation')
ax.set_title('Binned observation, first 2 time bins')
plt.show()

In [13]:
# One figure per entry of decoder.firing_rate. The loop variable no longer
# shadows the builtin `id`.
for firing_key, firing in decoder.firing_rate.items():
    fig, ax = plt.subplots()
    ax.plot(firing)
    ax.set_title('firing_rate[{!r}]'.format(firing_key))

plt.show()

In [14]:
## Plot posteriors
plt_ranges = [[2461, 2641]]
#plt_ranges = [[2461, 3405]]
#plt_ranges = [[2930, 3000]]
#plt_ranges = [[3295, 3325]]

for plt_range in plt_ranges:
    
    fig, ax = plt.subplots(figsize=[400,10])
    DecodeVisualizer.plot_decode_image(posteriors, plt_range, encode_settings, x_tick=10)
    print(ax)
    DecodeVisualizer.plot_linear_pos(lin_obj, plt_range)
    #DecodeVisualizer.plot_stim_lockout(ax, stim_lockout, plt_range, encode_settings.arm_coordinates[2][1] + 10)
    
    #plt.xlim(plt_range)
    
plt.show()