In [1]:
# Work from the repository root, presumably so the local `spykshrk` package and
# project-relative paths resolve. Set SPYKSHRK_REPO to override the default
# checkout location instead of editing this hard-coded path.
import os
REPO_DIR = os.environ.get('SPYKSHRK_REPO', '/home/daliu/Src/spykshrk_realtime/')
os.chdir(REPO_DIR)

In [30]:
import functools
import json
import math
import os
import pickle

import dask
import dask.array as da
import dask.dataframe as dd
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import scipy.signal
import scipy.stats  # sp.stats is used below; don't rely on scipy.signal importing it
from matplotlib import patches

from spykshrk.realtime.simulator import nspike_data

from spykshrk.franklab.pp_decoder.util import gaussian, normal2D, apply_no_anim_boundary, simplify_pos_pandas
from spykshrk.franklab.pp_decoder.pp_clusterless import OfflinePPDecoder
from spykshrk.franklab.pp_decoder.data_containers import EncodeSettings, DecodeSettings, SpikeObservation, \
                                                         LinearPosition, StimLockout, Posteriors, \
                                                         FlatLinearPosition, SpikeWaves, SpikeFeatures, \
                                                         pos_col_format
from spykshrk.franklab.pp_decoder.visualization import DecodeVisualizer
from spykshrk.franklab.pp_decoder.decode_error import LinearDecodeError
from spykshrk.franklab.pp_decoder.util import normal_pdf_int_lookup
        
%load_ext Cython

%matplotlib inline

#pd.set_option('float_format', '{:,.2f}'.format)
pd.set_option('display.precision', 4)
pd.set_option('display.max_rows', 10)
pd.set_option('display.max_columns', 15)
#pd.set_option('display.width', 80)

idx = pd.IndexSlice

matplotlib.rcParams.update({'font.size': 14})



In [3]:
# Shut down a Dask client/cluster left over from a previous run so the next
# cell can start fresh. The client is closed before the cluster it talks to.
try:
    client.close()
    cluster.close()
except NameError:
    # Fresh kernel: neither object has been created yet.
    print("No cluster or client")
except Exception as err:
    # Best-effort cleanup: report the problem but keep the notebook running.
    print("Error closing previous cluster/client: {!r}".format(err))

In [4]:
from dask.distributed import Client, LocalCluster

# Number of worker processes for the local Dask cluster; tune to the machine.
N_DASK_WORKERS = 15

cluster = LocalCluster(n_workers=N_DASK_WORKERS)
client = Client(cluster)

In [5]:
%%time
# Load merged rec HDF store based on config

#config_file = '/opt/data36/daliu/realtime/spykshrk/ripple_dec/bond.config.json'
#config_file = '/opt/data36/daliu/realtime/spykshrk/dec_60uv/bond.config.json'
config_file = '/home/daliu/Src/spykshrk_realtime/config/bond_single.json'
config = json.load(open(config_file, 'r'))

day = config['simulator']['nspike_animal_info']['days'][0]
epoch = config['simulator']['nspike_animal_info']['epochs'][0]
time_bin_size = config['pp_decoder']['bin_size']

# Extract just encode and decode settings from config
encode_settings = EncodeSettings(config)
decode_settings = DecodeSettings(config)

# Grab animal linearized real position
nspike_anim = nspike_data.AnimalInfo(**config['simulator']['nspike_animal_info'])
pos = nspike_data.PosMatDataStream(nspike_anim)
pos_data = pos.data

spk = nspike_data.SpkDataStream(nspike_anim)
spk_data = SpikeWaves(spk.data)

# Encapsulate linear position
lin_obj = LinearPosition.from_nspike_posmat(pos_data, encode_settings)

In [None]:
# Preview the raw spike table (first rows only) instead of rendering all of it.
spk.data.head()

In [6]:
%%time
spk_amp = spk_data.max(axis=1)
spk_amp = spk_amp.to_frame().pivot_table(index=['day','epoch','electrode_group_id','timestamp','time'], 
                                         columns='channel', values=0)
spk_amp= SpikeFeatures(spk_amp)
spk_amp_thresh = spk_amp.get_above_threshold(60)


In [None]:
%%time

import time

def compute_observ_tet(tet_id, chunk_ii, lin_obj, dec_spk_data,
                       enc_spk_index, enc_spk_data, encode_settings):
    
    print("Computing {}: {}".format(tet_id, chunk_ii))
    tet_lin_pos = (lin_obj.get_irregular_resampled(enc_spk_index.get_level_values('timestamp'))
                   .get_mapped_single_axis())
    pos_distrib_tet = sp.stats.norm.pdf(np.expand_dims(encode_settings.pos_bins, 0),
                                        np.expand_dims(tet_lin_pos['linpos_flat'],1), 
                                        encode_settings.pos_kernel_std)

    mark_contrib = normal_pdf_int_lookup(np.expand_dims(dec_spk_data, 1), 
                                         np.expand_dims(enc_spk_data,0), 
                                         encode_settings.mark_kernel_std)

    all_contrib = np.prod(mark_contrib, axis=2)

    print("Done {}".format(tet_id))

    return np.matmul(all_contrib, pos_distrib_tet)  


# One delayed task per (tetrode, chunk): each task scores up to `chunking`
# decoding spikes against every encoding spike recorded on the same tetrode.
# Group on the level the amplitude table is actually indexed by
# ('electrode_group_id'); there is no 'tet' level.
grp = spk_amp_thresh.groupby('electrode_group_id')
task = []
chunking = 1000
for tet_id, spk_tet in grp:
    tet_marks = spk_tet.values
    for chunk_ii, chunk_start in enumerate(range(0, len(tet_marks), chunking)):
        # Slicing clamps at the end, so the final chunk may be shorter.
        task.append(dask.delayed(compute_observ_tet)
                    (tet_id, chunk_ii, lin_obj,
                     tet_marks[chunk_start:chunk_start + chunking],
                     spk_tet.index, tet_marks,
                     encode_settings))

observ_results = dask.compute(*task)

In [28]:
%%time

def compute_observ_tet(dec_spk, enc_spk, tet_lin_pos, encode_settings):
    
    pos_distrib_tet = sp.stats.norm.pdf(np.expand_dims(encode_settings.pos_bins, 0),
                                        np.expand_dims(tet_lin_pos['linpos_flat'],1), 
                                        encode_settings.pos_kernel_std)

    mark_contrib = normal_pdf_int_lookup(np.expand_dims(dec_spk, 1), 
                                         np.expand_dims(enc_spk,0), 
                                         encode_settings.mark_kernel_std)

    all_contrib = np.prod(mark_contrib, axis=2)

    observ = np.matmul(all_contrib, pos_distrib_tet)
    
    return pd.DataFrame(observ, index=dec_spk.index, 
                        columns=[pos_col_format(pos_ii, observ.shape[1]) 
                                 for pos_ii in range(observ.shape[1])])


grp = spk_amp_thresh.groupby('electrode_group_id')
task = []
chunksize = 2000

# dask `meta` must describe what compute_observ_tet really returns: float
# columns, one per position bin, named with pos_col_format. The previous
# hard-coded `columns=range(450)` matched neither the names nor the dtype.
n_pos_bins = len(encode_settings.pos_bins)
observ_meta = pd.DataFrame(np.empty((0, n_pos_bins)),
                           columns=[pos_col_format(pos_ii, n_pos_bins)
                                    for pos_ii in range(n_pos_bins)])

for tet_id, spk_tet in grp:
    tet_lin_pos = (lin_obj.get_irregular_resampled(spk_tet.index.get_level_values('timestamp'))
                   .get_mapped_single_axis())
    dask_spk_tet = dd.from_pandas(spk_tet.get_simple_index(), chunksize=chunksize)

    # The encoding data is bound via functools.partial rather than passed as a
    # map_partitions argument, presumably so dask hands the full pandas frame to
    # every partition instead of trying to align/partition it.
    task.append(dask_spk_tet.map_partitions(
        functools.partial(compute_observ_tet, enc_spk=spk_tet,
                          tet_lin_pos=tet_lin_pos,
                          encode_settings=encode_settings),
        meta=observ_meta))

results = dask.compute(*task)


In [None]:
# Distinct tetrode ids; the amplitude table's tetrode level is 'electrode_group_id'.
tet_ids = np.unique(spk_amp.index.get_level_values('electrode_group_id'))

In [None]:
# Re-attach each tetrode's full multi-level spike index to its observation
# frame and stack all tetrodes. `results` was built by iterating the same
# groupby (same frame, same key), so zip pairs each group with its result.
grp = spk_amp_thresh.groupby('electrode_group_id')
assert len(results) == grp.ngroups, "results and tetrode groups are out of sync"

# set_index without inplace: re-running this cell never mutates `results`.
observ_tet_list = [tet_result.set_index(grp_spk.index)
                   for (tet_id, grp_spk), tet_result in zip(grp, results)]

observ = pd.concat(observ_tet_list)

In [None]:
# Interleave all tetrodes chronologically. The tetrode level is
# 'electrode_group_id' (there is no 'tet' level); sort_index returns a new frame.
observ_sorted = observ.sort_index(level=['day', 'epoch', 'timestamp', 'electrode_group_id'])
observ_sorted

In [None]:
# Overlay the per-position-bin observation curves of a few example spikes from
# the first tetrode group (after .T each plotted line is one spike).
example_spikes = slice(120, 125)
ax = results[0].iloc[example_spikes].T.plot(figsize=[15, 10])
ax.set(xlabel='position bin', ylabel='observation',
       title='Observations for spikes 120-124, first tetrode group');

In [None]:
%%prun -r -s cumulative
grp = spk_amp_thresh.groupby('tet')
observations = {}
for tet_id, spk_tet in grp:
    print("Starting tet {}".format(tet_id))
    tet_lin_pos = (lin_obj.get_irregular_resampled(spk_tet.index.get_level_values('timestamp'))
                   .get_mapped_single_axis())
    pos_distrib_tet = sp.stats.norm.pdf(np.expand_dims(encode_settings.pos_bins, 0),
                                        np.expand_dims(tet_lin_pos['linpos_flat'],1), 
                                        encode_settings.pos_kernel_std)
    marks_tet = spk_tet.values
    
    encoding_model = sp.stats.norm(loc=marks_tet,
                                      scale=encode_settings.mark_kernel_std)
                                      #x=np.expand_dims(marks_tet, 0)))
    
    #observations[tet_id] = np.matmul(np.squeeze(np.prod(mark_contrib, axis=2)), pos_distrib_test)
                    
    observations[tet_id] = np.zeros([marks_tet.shape[0], len(encode_settings.pos_bins)])
    
    for dec_spk_ii, dec_mark in enumerate(marks_tet):
        mark_contrib = np.prod(encoding_model.pdf(dec_mark), 
                               axis=1)
        observations[tet_id][dec_spk_ii, :] = np.matmul(mark_contrib, pos_distrib_tet)