In [36]:
import numpy as np
import scipy as sp
import scipy.stats
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import holoviews as hv

import json
import functools

import dask
import dask.dataframe as dd
import dask.array as da

from spykshrk.franklab.pp_decoder.data_containers import FlatLinearPosition, SpikeFeatures, \
        EncodeSettings, pos_col_format, SpikeObservation
from spykshrk.franklab.pp_decoder.util import normal_pdf_int_lookup

# Render holoviews figures with the matplotlib backend.
hv.extension('matplotlib')

# Compact pandas display: 4 significant digits, at most 10 rows and 15 columns.
pd.set_option('display.precision', 4,
              'display.max_rows', 10,
              'display.max_columns', 15)

In [21]:
# Interactive-debugging toggle: `%pdb` without an argument flips automatic
# post-mortem debugging on/off for this kernel.
# NOTE(review): leftover debugging aid; consider removing before sharing.
%pdb

In [3]:
from dask.distributed import Client, LocalCluster

# Number of local dask worker processes used for the decoding step.
N_DASK_WORKERS = 15

# Re-running this cell would otherwise leak the previous cluster's worker
# processes, so shut down any client/cluster left over from an earlier run.
# The client is closed before the cluster it is connected to, and a failure
# closing one does not prevent closing the other.
_stale_handles = [(_name, globals().get(_name)) for _name in ('client', 'cluster')]
if all(_handle is None for _, _handle in _stale_handles):
    print("No cluster or client")
for _name, _handle in _stale_handles:
    if _handle is None:
        continue
    try:
        _handle.close()
    except Exception as _err:  # best-effort teardown: report instead of silently swallowing
        print("Could not close {}: {!r}".format(_name, _err))

cluster = LocalCluster(n_workers=N_DASK_WORKERS)
client = Client(cluster)

In [41]:
class UnitGenerator:
    """Simulate spikes and spike marks for one synthetic unit with a Gaussian place field.

    Parameters
    ----------
    elec_grp
        Electrode-group id written into the ``elec_grp`` index level of every spike.
    mark_mean : array-like, 1-D
        Mean of the multivariate-normal mark distribution (one entry per channel).
    mark_cov : array-like, 1-D
        Per-channel mark variances, used as the diagonal of the mark covariance.
    pos_mean : float
        Place-field centre, in the units of the ``linpos_flat`` column.
    pos_var : float
        Place-field width. NOTE(review): despite the name, this is passed to
        ``scipy.stats.norm`` as ``scale`` (a standard deviation), not a variance.
    peak_fr : float
        Firing rate at the field centre; divided by ``sampling_rate`` to obtain
        the peak per-sample spike probability.
    sampling_rate : float
        Position samples per unit time (presumably Hz).
    """

    def __init__(self, elec_grp, mark_mean, mark_cov, pos_mean, pos_var, peak_fr, sampling_rate):
        self.elec_grp = elec_grp
        self.mark_mean = mark_mean
        self.mark_cov = mark_cov
        self.pos_mean = pos_mean
        self.pos_var = pos_var
        # Mark distribution: independent channels (diagonal covariance).
        self.rv_marks = sp.stats.multivariate_normal(mean=mark_mean, cov=np.diag(mark_cov))
        # Place field: Gaussian over linear position (pos_var used as the std dev).
        self.rv_pos = sp.stats.norm(loc=pos_mean, scale=pos_var)
        self.peak_fr = peak_fr
        self.sampling_rate = sampling_rate

    def simulate_spikes_over_pos(self, linpos_flat):
        """Draw at most one spike per position sample, plus a mark vector per spike.

        The per-sample spike probability is ``peak_fr / sampling_rate`` scaled by
        the place field, which is normalized to 1 at ``pos_mean``.

        Parameters
        ----------
        linpos_flat : DataFrame-like
            Position samples with a ``'linpos_flat'`` column and a MultiIndex with
            levels ``day``, ``epoch``, ``timestamp`` and ``time``.

        Returns
        -------
        SpikeFeatures
            One row per spike, int32 mark columns ``c00``, ``c01``, ... (one per
            channel), indexed by ``(day, epoch, elec_grp, timestamp, time)``.
        """
        prob_field = (self.rv_pos.pdf(linpos_flat['linpos_flat'].values)
                      / self.rv_pos.pdf(self.pos_mean))

        spike_train = sp.stats.bernoulli(p=self.peak_fr / self.sampling_rate * prob_field).rvs()
        sample_num = np.nonzero(spike_train)[0]
        num_spikes = len(sample_num)

        # rvs() squeezes its output (one draw, or one channel, comes back with
        # fewer dimensions), so reshape to (num_spikes, num_channels) explicitly.
        # This also covers the zero-spike case.
        num_channels = len(np.atleast_1d(self.mark_mean))
        marks = np.reshape(self.rv_marks.rvs(num_spikes),
                           (num_spikes, num_channels)).astype('i4')

        # Output index: the spiking samples' position index plus a constant
        # elec_grp level, in (day, epoch, elec_grp, timestamp, time) order.
        # Built with get_level_values/from_arrays rather than MultiIndex.labels,
        # which was renamed to `codes` in pandas 0.24 and removed in pandas 1.0.
        time_ind = linpos_flat.index[sample_num]
        level_names = ['day', 'epoch', 'elec_grp', 'timestamp', 'time']
        level_values = [np.full(num_spikes, self.elec_grp) if name == 'elec_grp'
                        else time_ind.get_level_values(name)
                        for name in level_names]
        new_ind = pd.MultiIndex.from_arrays(level_values, names=level_names)

        mark_columns = ['c{:02d}'.format(ii) for ii in range(num_channels)]
        return SpikeFeatures(marks, columns=mark_columns, index=new_ind)
    
class AttrDict(dict):
    """A ``dict`` whose entries can also be read and written as attributes.

    ``d['key']`` and ``d.key`` refer to the same value, because the instance's
    attribute namespace is the dictionary itself.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Alias the attribute store to the mapping so item and attribute access share keys.
        self.__dict__ = self

In [42]:
# Seed NumPy's global RNG so the simulated units and spike trains are reproducible.
SEED = 0
np.random.seed(SEED)

# Encoder kernel widths; encode_settings is built from these so they cannot drift apart.
mark_kernel = 40
pos_kernel = 10

encode_settings = AttrDict({'pos_bins': np.arange(0, 100, 1),
                            'pos_kernel_std': pos_kernel,
                            'mark_kernel_std': mark_kernel,
                            'pos_num_bins': 100})

sampling_rate = 1000

# Synthetic trajectory: one position sample per time step.
pos_time = np.arange(0, 20000, 1)
# NOTE(review): 100*cos(...) spans [-100, 100], but pos_bins and the unit place-field
# centres only cover [0, 100); confirm the negative half of the trajectory is intended.
pos_run = 100 * np.cos(pos_time / (500 * np.pi))
# Sample-to-sample position difference scaled by sampling_rate; first sample is 0.
pos_vel = np.concatenate([[0], np.diff(pos_run) * sampling_rate])

linpos_flat_obj = FlatLinearPosition.from_numpy_single_epoch(1, 1, pos_time, pos_run, pos_vel, sampling_rate)

# [low, high) ranges for the randomly drawn unit parameters.
mark_mean_range = [20, 1000]
mark_cov_range = [200, 500]
num_marks = 4
num_units = 10
firing_rate_range = [20, 100]

pos_range = [0, 100]
pos_var_range = [10, 40]

unit_mean = np.random.randint(*mark_mean_range, [num_units, num_marks])
unit_cov = np.random.randint(*mark_cov_range, [num_units, num_marks])
unit_pos_mean = np.random.randint(*pos_range, [num_units])
unit_pos_var = np.random.randint(*pos_var_range, [num_units])
# Peak firing rates come from firing_rate_range (not pos_var_range).
unit_fr = np.random.randint(*firing_rate_range, [num_units])

# Simulate every unit on electrode group 1 and pool their spikes.
units = {}
unit_marks_list = []
for unit_ii in range(num_units):
    units[unit_ii] = UnitGenerator(elec_grp=1,
                                   mark_mean=unit_mean[unit_ii, :], mark_cov=unit_cov[unit_ii, :],
                                   pos_mean=unit_pos_mean[unit_ii], pos_var=unit_pos_var[unit_ii],
                                   peak_fr=unit_fr[unit_ii], sampling_rate=sampling_rate)
    unit_marks_list.append(units[unit_ii].simulate_spikes_over_pos(linpos_flat_obj))

# Concatenate once: DataFrame.append in a loop is quadratic and was removed in pandas 2.0.
# NOTE(review): drop_duplicates compares mark values only (not the index), so it also
# removes spikes with identical marks at different times; confirm that is intended.
spk_amps = (pd.concat(unit_marks_list)
            .sort_index(level='timestamp')
            .drop_duplicates())


In [32]:
# Display the pooled simulated spike marks (row/column output is truncated by the pandas display options).
spk_amps

In [43]:
%%time

def compute_observ_tet(dec_spk, enc_spk, tet_lin_pos, encode_settings):
    """Compute a row-normalized position likelihood for each decoded spike of one tetrode.

    Every decoded spike is compared with every encoding spike in mark space. The
    resulting weights combine Gaussian position kernels centred on the positions
    of the encoding spikes.

    Args:
        dec_spk: DataFrame of marks for the spikes to decode (one row per spike);
            its index labels the output rows.
        enc_spk: marks of the encoding spikes, with the same channels as ``dec_spk``.
        tet_lin_pos: table with a ``'linpos_flat'`` column holding one position per
            encoding spike, row-aligned with ``enc_spk``.
        encode_settings: object exposing ``pos_bins``, ``pos_kernel_std`` and
            ``mark_kernel_std``.

    Returns:
        DataFrame with one row per decoded spike and one column per position bin
        (named via ``pos_col_format``). Each row sums to 1; a row whose raw sum
        is 0 becomes NaN.
    """
    # Gaussian position kernel of each encoding spike, evaluated at every bin:
    # (n_enc, 1) positions broadcast against (1, n_bins) bin centres.
    bins_row = np.asarray(encode_settings.pos_bins)[np.newaxis, :]
    enc_pos_col = np.asarray(tet_lin_pos['linpos_flat'])[:, np.newaxis]
    enc_pos_kernel = sp.stats.norm.pdf(bins_row, enc_pos_col, encode_settings.pos_kernel_std)

    # Per-channel mark similarity between decoded (axis 0) and encoding (axis 1)
    # spikes. The lookup is assumed to broadcast elementwise; channels are on axis 2
    # and are multiplied together.
    per_channel = normal_pdf_int_lookup(np.asarray(dec_spk)[:, np.newaxis],
                                        np.asarray(enc_spk)[np.newaxis],
                                        encode_settings.mark_kernel_std)
    mark_weight = np.prod(per_channel, axis=2)

    likelihood = mark_weight @ enc_pos_kernel
    likelihood = likelihood / likelihood.sum(axis=1, keepdims=True)

    n_bins = likelihood.shape[1]
    bin_columns = [pos_col_format(bin_ii, n_bins) for bin_ii in range(n_bins)]
    return pd.DataFrame(likelihood, index=dec_spk.index, columns=bin_columns)


# Spikes per dask partition when decoding each tetrode.
chunksize = 100

# Output schema of compute_observ_tet: one float column per position bin. It does not
# depend on the tetrode, so it is built once outside the loop.
df_meta = pd.DataFrame(columns=[pos_col_format(ii, encode_settings.pos_num_bins)
                                for ii in range(encode_settings.pos_num_bins)],
                       dtype='float64')

grp = spk_amps.groupby('elec_grp')
task = []
for tet_id, spk_tet in grp:
    spk_tet.index = spk_tet.index.droplevel('elec_grp')
    tet_lin_pos = linpos_flat_obj.reindex(spk_tet.index)

    # Velocity threshold on spikes and position (currently disabled: all spikes encode).
    #tet_lin_pos_thresh = tet_lin_pos.get_above_velocity(10.)
    #spk_tet_thresh = spk_tet.reindex(tet_lin_pos_thresh.index)
    tet_lin_pos_thresh = tet_lin_pos
    spk_tet_thresh = spk_tet

    # Decode from all spikes. sort=False keeps the rows in spk_tet order, so the
    # next cell can re-attach the full MultiIndex to the results by position.
    dask_spk_tet = dd.from_pandas(spk_tet.get_simple_index(), chunksize=chunksize, sort=False)

    # Decode every spike against the encoding built from the (thresholded) spikes.
    decode_partition = functools.partial(compute_observ_tet,
                                         enc_spk=spk_tet_thresh,
                                         tet_lin_pos=tet_lin_pos_thresh,
                                         encode_settings=encode_settings)
    task.append(dask_spk_tet.map_partitions(decode_partition, meta=df_meta))

results = dask.compute(*task)


In [44]:
%%time
# Re-attach each tetrode's full (day, epoch, elec_grp, timestamp, time) index to its
# decoded rows. results[i] comes from the i-th group of the same groupby. This assumes
# rows come back in the group's original order, because the index is assigned by position.
grp = spk_amps.groupby('elec_grp')
assert len(results) == grp.ngroups, 'expected one decode result per electrode group'

observ_tet_list = []
for tet_result, (tet_id, grp_spk) in zip(results, grp):
    assert len(tet_result) == len(grp_spk), \
        'decode result length mismatch for elec_grp {}'.format(tet_id)
    # Non-mutating, so re-running this cell leaves `results` untouched.
    observ_tet_list.append(tet_result.set_index(grp_spk.index))

observ = pd.concat(observ_tet_list)
observ_obj = SpikeObservation.from_df(observ.sort_index(level=['day', 'epoch',
                                                               'timestamp', 'elec_grp']))

# Move elec_grp from the index into a column, leaving a (day, epoch, timestamp, time) index.
observ_obj['elec_grp'] = observ_obj.index.get_level_values('elec_grp')
observ_obj.index = observ_obj.index.droplevel('elec_grp')

# Linear position at each spike, aligned on the shared index.
observ_obj['position'] = linpos_flat_obj['linpos_flat']

In [45]:
# Display the per-spike position likelihoods with their elec_grp and position columns.
observ_obj

In [None]:
%%opts Scatter3D [fig_size=200]

grp = spk_amps.groupby('elec_grp')

# One 3-D scatter per electrode group of the first three mark channels. The
# channels are selected explicitly (the marks have four columns) and used as
# axis labels; each layer is labelled with its electrode group.
mark_xyz = ['c00', 'c01', 'c02']
scatter = [hv.Scatter3D(elec_spk[mark_xyz].values, kdims=mark_xyz,
                        label='elec_grp {}'.format(elec_id))
           for elec_id, elec_spk in grp]

overlay = hv.Overlay(scatter)

overlay

In [None]:
# Mark-space clusters of the first two simulated units (first three channels).
# `marks` is only a local inside UnitGenerator.simulate_spikes_over_pos and is not
# defined at notebook scope. Instead, draw samples from each unit's mark distribution
# and overlay both units (a cell only displays its last expression).
n_mark_samples = 500
hv.Overlay([hv.Scatter3D(units[unit_ii].rv_marks.rvs(n_mark_samples)[:, 0:3],
                         label='unit {}'.format(unit_ii))
            for unit_ii in (0, 1)])
