In [1]:
import numpy as np
import scipy as sp
import scipy.stats
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import holoviews as hv

import json
import functools

import dask
import dask.dataframe as dd
import dask.array as da

from spykshrk.franklab.pp_decoder.data_containers import FlatLinearPosition, SpikeFeatures, \
        EncodeSettings, pos_col_format, SpikeObservation
from spykshrk.franklab.pp_decoder.util import normal_pdf_int_lookup, gaussian
from spykshrk.franklab.pp_decoder.pp_clusterless import OfflinePPDecoder
from spykshrk.franklab.pp_decoder.visualization import DecodeVisualizer

# Load both HoloViews plotting backends; individual cells below select one
# with `%%output backend=...`.
hv.extension('matplotlib')
hv.extension('bokeh')
# Keep DataFrame reprs compact: 4 digits of precision, at most 10 rows and 15 columns.
pd.set_option('display.precision', 4)
pd.set_option('display.max_rows', 10)
pd.set_option('display.max_columns', 15)

In [2]:
#%pdb

In [3]:
from dask.distributed import Client, LocalCluster

# Shut down any cluster/client left over from a previous run of this cell so
# that re-running it does not stack up worker processes.
try:
    cluster.close()
    client.close()
except NameError:
    # First run in this kernel: nothing has been created yet.
    print("No cluster or client")
except Exception as err:
    # Best effort: a stale or already-closed handle must not block the restart
    # below. Unlike a bare `except:`, KeyboardInterrupt/SystemExit still propagate.
    print("Could not cleanly close previous cluster/client: {!r}".format(err))

# NOTE(review): the worker count is hard-coded; confirm it suits the machine
# this notebook runs on.
cluster = LocalCluster(n_workers=15)
client = Client(cluster)

In [4]:
class UnitGenerator:
    """Simulated unit with a Gaussian place field and Gaussian-distributed spike marks.

    Parameters
    ----------
    elec_grp_id : value written into the 'elec_grp_id' index level of every
        simulated spike.
    mark_mean : mean of the mark (spike feature) distribution, one entry per channel.
    mark_cov : diagonal entries of the mark covariance matrix (expanded with
        ``np.diag``, so channels are independent).
    pos_mean : centre of the place field.
    pos_var : width of the place field.
        NOTE(review): this value is passed to ``scipy.stats.norm`` as ``scale``,
        which scipy defines as a standard deviation, so despite the name it is
        used as a std rather than a variance. Left unchanged so the simulated
        data stay the same; confirm which was intended.
    peak_fr : firing rate at the centre of the place field.
    sampling_rate : number of position samples per unit time; ``peak_fr /
        sampling_rate`` is the per-sample spike probability at the field centre.
    """

    def __init__(self, elec_grp_id, mark_mean, mark_cov, pos_mean, pos_var, peak_fr, sampling_rate):
        self.elec_grp_id = elec_grp_id
        self.mark_mean = mark_mean
        self.mark_cov = mark_cov
        self.pos_mean = pos_mean
        self.pos_var = pos_var
        self.rv_marks = sp.stats.multivariate_normal(mean=mark_mean, cov=np.diag(mark_cov))
        self.rv_pos = sp.stats.norm(loc=pos_mean, scale=pos_var)
        self.peak_fr = peak_fr
        self.sampling_rate = sampling_rate

    def simulate_spikes_over_pos(self, linpos_flat):
        """Draw spikes along a position trace.

        Parameters
        ----------
        linpos_flat : DataFrame-like with a 'linpos_flat' column and a MultiIndex
            whose levels are named 'day', 'epoch', 'timestamp' and 'time'.
            The input is not modified.

        Returns
        -------
        spk_amp : SpikeFeatures
            One row per spike, columns 'c00'..'c03' (int32 marks), indexed by
            (day, epoch, elec_grp_id, timestamp, time).
        mark_linpos : same type as ``linpos_flat``
            The position rows at which spikes occurred, with the same index
            layout as ``spk_amp``.
        prob_field : numpy.ndarray
            Place-field value at every position sample, scaled so that it is 1
            at ``pos_mean``.
        """
        # Place field normalised to 1 at its centre.
        prob_field = self.rv_pos.pdf(linpos_flat['linpos_flat'].values) / self.rv_pos.pdf(self.pos_mean)

        # One Bernoulli draw per position sample.
        spike_train = sp.stats.bernoulli(p=self.peak_fr / self.sampling_rate * prob_field).rvs()
        sample_num = np.nonzero(spike_train)[0]

        # The spike train is 0/1, so the number of spikes equals the number of
        # non-zero samples; this avoids a Python-level sum over the whole array.
        # atleast_2d restores the row axis that scipy squeezes away for a single draw.
        marks = np.atleast_2d(self.rv_marks.rvs(len(sample_num))).astype('i4')

        # Append a constant 'elec_grp_id' level to the index of the spiking samples.
        # Built from level values rather than levels/labels so it does not depend
        # on the MultiIndex `labels` attribute, which newer pandas no longer has.
        time_ind = linpos_flat.index[sample_num]
        level_values = [time_ind.get_level_values(ii) for ii in range(time_ind.nlevels)]
        level_values.append([self.elec_grp_id] * len(time_ind))
        new_ind = pd.MultiIndex.from_arrays(level_values,
                                            names=list(time_ind.names) + ['elec_grp_id'])
        new_ind = new_ind.reorder_levels(['day', 'epoch', 'elec_grp_id', 'timestamp', 'time'])

        spk_amp = SpikeFeatures(marks, columns=['c00', 'c01', 'c02', 'c03'],
                                index=new_ind)

        # Explicit copy: adding a column to a slice of the caller's frame would
        # otherwise trigger SettingWithCopyWarning.
        mark_linpos = linpos_flat.iloc[sample_num].copy()
        mark_linpos['elec_grp_id'] = self.elec_grp_id
        mark_linpos = mark_linpos.set_index('elec_grp_id', append=True)
        mark_linpos = mark_linpos.reorder_levels(['day', 'epoch', 'elec_grp_id', 'timestamp', 'time'])

        return spk_amp, mark_linpos, prob_field
    
class AttrDict(dict):
    """Dictionary whose entries can also be read and written as attributes."""

    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        # Alias the instance namespace to the mapping itself, so `obj.key`
        # and `obj['key']` refer to the same storage.
        self.__dict__ = self

In [5]:
# Encoder configuration: 100 unit-width position bins on [0, 100], a Gaussian
# position kernel (centre 50, std 1) sampled on those bins, and the kernel
# widths used when building spike observations.
encode_settings = AttrDict(
    sampling_rate=1000,
    pos_bins=np.arange(0, 100, 1),
    pos_bin_edges=np.arange(0, 100.1, 1),
    pos_bin_delta=1,
    pos_kernel=sp.stats.norm.pdf(np.arange(0, 100, 1), 50, 1),
    pos_kernel_std=1,
    mark_kernel_std=20,
    pos_num_bins=100,
    pos_col_names=[pos_col_format(ii, 100) for ii in range(100)],
    arm_coordinates=[[0, 100]],
)

# Decoder configuration, handed to the decoder further down.
decode_settings = AttrDict(
    trans_smooth_std=5,
    trans_uniform_gain=0.001,
    time_bin_size=10,
)
                            

# Seed NumPy's global RNG so the simulated units and spikes are reproducible
# under Restart & Run All. The scipy.stats draws in UnitGenerator use this
# global state by default.
random_seed = 0
np.random.seed(random_seed)

sampling_rate = 1000

# Position trace in three segments: a cosine sweep, a stationary hold at the
# last swept position, then a second cosine sweep. The segment lengths add up
# to len(pos_time).
pos_time = np.arange(0,100000,1)
pos_run = 50*np.cos(pos_time[0:17272]/(500*np.pi))+50
pos_run = np.append(pos_run, ([pos_run[-1]]*39478))
pos_run = np.append(pos_run, 50*np.cos(pos_time[56750:100000]/(500*np.pi))+50)
pos_vel = np.concatenate([[0], np.diff(pos_run) * sampling_rate])

linpos_flat_obj = FlatLinearPosition.from_numpy_single_epoch(1, 1, pos_time, pos_run, pos_vel, sampling_rate)

# Ranges from which the per-unit parameters are drawn.
mark_mean_range = [20,1000]
mark_cov_range = [200,500]
num_marks = 4
num_units = 50
firing_rate_range = [20,50]

pos_field_range = [0, 100]
pos_field_var_range = [5,10]

unit_mean = np.random.randint(*mark_mean_range, [num_units, num_marks])
unit_cov = np.random.randint(*mark_cov_range, [num_units, num_marks])

# Random place fields (disabled); fields are instead spread evenly over the
# track with a common width.
#unit_pos_mean = np.random.randint(*pos_field_range, [num_units])
#unit_pos_var = np.random.randint(*pos_field_var_range, [num_units])

unit_pos_mean = np.linspace(*pos_field_range, num_units)
unit_pos_var = np.array([pos_field_range[1]/num_units*2]*num_units)

unit_fr = np.random.randint(*firing_rate_range, [num_units])

units = {}
unit_spks = {}
# Collect each unit's marks and concatenate once after the loop; appending to a
# DataFrame inside the loop re-copies all previous rows on every iteration.
unit_mark_list = []
for unit_ii in range(num_units):
    units[unit_ii] = UnitGenerator(elec_grp_id=1,
                                   mark_mean=unit_mean[unit_ii,:], mark_cov=unit_cov[unit_ii,:],
                                   pos_mean=unit_pos_mean[unit_ii], pos_var=unit_pos_var[unit_ii],
                                   peak_fr=unit_fr[unit_ii], sampling_rate=sampling_rate)

    unit_marks, mark_pos, field = units[unit_ii].simulate_spikes_over_pos(linpos_flat_obj)
    # Per-unit table of marks joined with the position at each spike.
    unit_spks[unit_ii] = unit_marks.merge(mark_pos, how='outer', left_index=True, right_index=True)

    unit_mark_list.append(unit_marks)

# All units share one electrode group: pool their spikes in time order.
spk_amps = pd.concat(unit_mark_list)
spk_amps.sort_index(level='timestamp', inplace=True)

# NOTE(review): drop_duplicates compares the mark columns only (the index is
# ignored), so distinct spikes that happen to carry identical marks collapse
# into one row. Confirm this is intended rather than de-duplicating timestamps.
spk_amps.drop_duplicates(inplace=True)


In [6]:
%%output size=200 backend='matplotlib'
%%opts Points [aspect=2] (marker='.')

# Scatter plot of the simulated position trace `pos_run`.
hv.Points(pos_run)

In [7]:
%%output backend='matplotlib'
%opts Scatter3D {+framewise}
%opts Overlay {+framewise}

from holoviews.streams import Stream, param


def mark_plots(elevation, azimuth):
    """Overlay one 3D scatter per simulated unit, viewed from the given camera angles.

    Each scatter shows spike position ('linpos_flat') against two of the mark
    channels ('c01', 'c02') for one entry of the global ``unit_spks``.

    Parameters
    ----------
    elevation, azimuth : camera angles; converted with ``int()`` before being
        passed to the plot options.

    Returns
    -------
    hv.Overlay of ``hv.Scatter3D`` elements, one per unit.
    """
    # The plotting backend is chosen by the `%%output` cell magic on the first
    # line of the cell; cell magics are only valid there, not inside a function body.
    scatter = [hv.Scatter3D(unit_df.loc[:, ['linpos_flat', 'c01', 'c02']])
               for unit_df in unit_spks.values()]
    overlay = hv.Overlay(scatter)
    overlay = overlay.opts({'Scatter3D': {'plot': {'fig_size': 400, 'azimuth': int(azimuth),
                                                   'elevation': int(elevation)},
                                          'norm': {'framewise': True}}})
    return overlay


# Static alternative, kept for reference: it calls mark_plots up front for
# every (elevation, azimuth) pair instead of on demand.
#holo = hv.HoloMap({(e,a): mark_plots(e,a) for e in range(0, 181, 20)
#                   for a in range(-90,91,20)}, kdims=['e','a'])
#holo

# Build views on demand: mark_plots is called for the selected angles, which
# are restricted to 5-degree steps.
dmap = (
    hv.DynamicMap(callback=mark_plots, kdims=['elevation', 'azimuth'], cache_size=1)
    .redim.values(elevation=range(0, 181, 5), azimuth=range(-90, 91, 5))
    .opts(norm=dict(framewise=True))
)
dmap

In [8]:
%%time

def compute_observ_tet(dec_spk, enc_spk, tet_lin_pos, occupancy, encode_settings):
    """Build a normalised position distribution for every spike to be decoded.

    Parameters
    ----------
    dec_spk : DataFrame of marks for the spikes to decode; its index becomes
        the index of the result.
    enc_spk : marks of the encoding spikes (assumed 2-D, one row per spike,
        with the same channels as ``dec_spk``).
    tet_lin_pos : mapping/frame whose 'linpos_flat' entry gives the position of
        each encoding spike.
    occupancy : per-bin occupancy that the result is divided by.
    encode_settings : object providing ``pos_bins``, ``pos_kernel_std`` and
        ``mark_kernel_std``.

    Returns
    -------
    pandas.DataFrame with one row per decoded spike and one column per position
    bin (named via ``pos_col_format``); every row is divided by its own sum.
    """
    # Gaussian bump over the position bins, centred on each encoding spike's position.
    enc_positions = np.expand_dims(tet_lin_pos['linpos_flat'], 1)
    bin_centers = np.expand_dims(encode_settings.pos_bins, 0)
    pos_kernels = sp.stats.norm.pdf(bin_centers, enc_positions, encode_settings.pos_kernel_std)

    # Per-channel similarity between every (decode spike, encode spike) pair,
    # combined across channels by taking the product over the last axis.
    per_channel = normal_pdf_int_lookup(np.expand_dims(dec_spk, 1),
                                        np.expand_dims(enc_spk, 0),
                                        encode_settings.mark_kernel_std)
    mark_weights = np.prod(per_channel, axis=2)

    # Weight each encoding spike's position bump by its mark similarity,
    # then normalise by occupancy.
    occ_normalised = np.matmul(mark_weights, pos_kernels) / occupancy

    # Scale every row by its own total.
    row_totals = occ_normalised.sum(axis=1)[:, np.newaxis]
    observ = occ_normalised / row_totals

    n_bins = observ.shape[1]
    bin_columns = [pos_col_format(bin_ii, n_bins) for bin_ii in range(n_bins)]
    return pd.DataFrame(observ, index=dec_spk.index, columns=bin_columns)


# Occupancy: density-normalised histogram of position over the bin edges,
# smoothed with the position kernel. `density=True` replaces the `normed=True`
# keyword that NumPy has removed; for these equal-width bins the result is the same.
occ, _ = np.histogram(a=linpos_flat_obj['linpos_flat'], bins=encode_settings.pos_bin_edges, density=True)
# Full convolution, then keep the middle len(occ) samples.
occ = np.convolve(occ, encode_settings.pos_kernel)[int(len(occ)/2):int(len(occ)*3/2)]

# Empty frame describing the output columns of compute_observ_tet for dask.
# It is identical for every electrode group, so it is built once.
df_meta = pd.DataFrame([], columns=[pos_col_format(ii, encode_settings.pos_num_bins)
                                    for ii in range(encode_settings.pos_num_bins)])

grp = spk_amps.groupby('elec_grp_id')
observations = {}
task = []
chunksize = 100
for tet_id, spk_tet in grp:
    spk_tet.index = spk_tet.index.droplevel('elec_grp_id')
    # Position at the time of every spike in this electrode group.
    tet_lin_pos = linpos_flat_obj.reindex(spk_tet.index)

    # Velocity threshold on spikes and position (currently disabled: all
    # spikes are used for encoding)
    #tet_lin_pos_thresh = tet_lin_pos.get_above_velocity(10.)
    #spk_tet_thresh = spk_tet.reindex(tet_lin_pos_thresh.index)
    tet_lin_pos_thresh = tet_lin_pos
    spk_tet_thresh = spk_tet

    # Decode from all spikes, split into partitions of `chunksize` rows
    dask_spk_tet = dd.from_pandas(spk_tet.get_simple_index(), chunksize=chunksize)

    # Setup decode of all spikes from encoding of velocity threshold spikes
    task.append(dask_spk_tet.map_partitions(functools.partial(compute_observ_tet, enc_spk=spk_tet_thresh,
                                                              tet_lin_pos=tet_lin_pos_thresh,
                                                              occupancy=occ,
                                                              encode_settings=encode_settings),
                                            meta=df_meta))

# One result frame per electrode group, in groupby iteration order.
results = dask.compute(*task)


In [9]:
%%output backend='bokeh'
%%opts Histogram [height=300, width=600]
hist1 = hv.Histogram((occ, encode_settings.pos_bin_edges))

hist2 = hv.Histogram(np.histogram(a=linpos_flat_obj['linpos_flat'], 
                                  bins=encode_settings.pos_bin_edges,normed=True))

layout = hv.Layout(hist1 + hist2).cols(1)
layout

In [10]:
# Sanity check: number of position-bin edges.
len(encode_settings.pos_bin_edges)

In [11]:
%%time
# Re-attach the full spike index to each per-group result and merge them into
# a single SpikeObservation table.
tet_ids = np.unique(spk_amps.index.get_level_values('elec_grp_id'))
observ_tet_list = []
grp = spk_amps.groupby('elec_grp_id')
# results[tet_ii] is paired with the tet_ii-th group: this relies on this
# groupby iterating in the same order as the one that built the dask tasks.
for tet_ii, (tet_id, grp_spk) in enumerate(grp):
    tet_result = results[tet_ii]
    # NOTE: mutates the frame stored in `results` in place.
    tet_result.set_index(grp_spk.index, inplace=True)
    observ_tet_list.append(tet_result)

observ = pd.concat(observ_tet_list)
observ_obj = SpikeObservation(observ.sort_index(level=['day', 'epoch', 
                                                               'timestamp', 'elec_grp_id']), )

# Move elec_grp_id from the index into a regular column.
observ_obj['elec_grp_id'] = observ_obj.index.get_level_values('elec_grp_id')
observ_obj.index = observ_obj.index.droplevel('elec_grp_id')

# Attach the position at each spike; the assignment aligns on the index.
observ_obj['position'] = (linpos_flat_obj['linpos_flat'])

In [13]:
%%output backend='matplotlib' size=200
%%opts Points (s=400 marker='x')

sel_distrib = observ_obj.loc[:, pos_col_format(0,encode_settings.pos_num_bins):         
                                 pos_col_format(encode_settings.pos_num_bins-1,
                                                encode_settings.pos_num_bins)]
    
sel_pos = observ_obj.loc[:, 'position']

max = sel_distrib.max().max()
    
def plot_observ(ind):
    
        
    plot_list = []
    for ii in range(5):
        plot_list.append(hv.Curve(sel_distrib.iloc[ind+ii], extents=(0, 0, 100, max)))
        plot_list.append(hv.Points((sel_pos.iloc[ind+ii], [0.005])))
    return hv.Overlay(plot_list)
        
#Ind = Stream.define('stuff', ind=0)

dmap = hv.DynamicMap(plot_observ, kdims=['ind'])
dmap.redim.values(ind=list(range(0, len(observ_obj)-5, 5)))

In [14]:
%%time
# Run PP decoding algorithm
time_bin_size = 30

decoder = OfflinePPDecoder(lin_obj=linpos_flat_obj, observ_obj=observ_obj,
                           encode_settings=encode_settings, decode_settings=decode_settings, 
                           which_trans_mat='learned', time_bin_size=time_bin_size)

posteriors = decoder.run_decoder()

In [15]:
%%output backend='bokeh' size=400 holomap='scrubber'
%%opts RGB { +framewise} [height=100 width=250 colorbar=True]
%%opts Points {+framewise} [height=100 width=250] (marker='o' size=4 alpha=0.5)

dec_viz = DecodeVisualizer(posteriors, linpos=linpos_flat_obj, enc_settings=encode_settings)

dec_viz.plot_all_dynamic(stream=hv.streams.RangeXY(), plt_range=100, slide=10)
