In [1]:
# Put the source root (root_depth directories above this notebook) on python's path
# so spykshrk can be imported without installing it through setuptools.
import sys, os
root_depth = 2
# IPython keeps its directory history in `_dh` (entry 0 is the starting directory); fall back to
# the current working directory when that is unavailable, e.g. when run as a plain script.
notebook_dir = globals().get('_dh', [os.getcwd()])[0]
root_path = os.path.abspath(os.path.join(notebook_dir, '../' * root_depth))
# Drop every existing copy first so re-running this cell never duplicates the entry
# (slice assignment keeps the same sys.path list object).
sys.path[:] = [entry for entry in sys.path if entry != root_path]
sys.path.append(root_path)
# Alternatively set root path as current working directory
#os.chdir(root_path)


In [2]:
import numpy as np
import scipy as sp
import pandas as pd
import functools
import itertools
import holoviews as hv
from holoviews import dim, opts
import enum
import IPython
import warnings
import torch
import collections

from IPython.utils.text import columnize
from IPython.lib.pretty import pprint, pretty
from IPython.display import display_html

# Spykshrk modules for data analysis
from spykshrk.franklab.data_containers import FlatLinearPosition, SpikeFeatures, \
        EncodeSettings, pos_col_format, SpikeObservation, Posteriors
from spykshrk.franklab.generator.place.spike_generator import UnitNormalGenerator, TetrodeUniformUnitNormalGenerator
from spykshrk.franklab.generator.place.pos_generator import WtrackPosConstSimulator
from spykshrk.franklab.pp_decoder.util import normal_pdf_int_lookup, gaussian
from spykshrk.franklab.pp_decoder.pp_clusterless import OfflinePPEncoder, OfflinePPLikelihood, OfflinePPPosterior, OfflinePPDecoder, OfflinePPIndicatorPosterior
from spykshrk.franklab.pp_decoder.wtrack_mapping import WtrackLinposDecomposer, WtrackLinposRecomposer
from spykshrk.franklab.pp_decoder.visualization import DecodeVisualizer, WtrackLinposVisualizer, DecodeStepVisualizer, MultiDecodeStepVisualizer
from spykshrk.franklab.visualization import LinPosVisualizer, TetrodeVisualizer
from spykshrk.franklab.pp_decoder.util import apply_no_anim_boundary
from spykshrk.franklab.wtrack import WtrackArm, Direction, Order, Rotation
from spykshrk.util import AttrDict, AttrDictEnum

# Register both holoviews plotting backends (bokeh first, then matplotlib).
hv.extension('bokeh')
hv.extension('matplotlib')

# Dark styling for bokeh output and for matplotlib figures.
hv.renderer('bokeh').theme = "dark_minimal"
import matplotlib.pyplot as plt
plt.style.use('dark_background')
plt.rcParams.update({
    'figure.facecolor': 'black',
    'figure.constrained_layout.use': False,
    'xtick.major.pad': 3.5,
    'ytick.major.pad': 10,
    'axes.formatter.limits': [-5, 5],
})

# Compact pandas table display.
pd.set_option('display.precision', 4)
pd.set_option('display.max_rows', 10)
pd.set_option('display.max_columns', 15)

# Limits for IPython's plain-text reprs.
plain_formatter = get_ipython().display_formatter.formatters['text/plain']
plain_formatter.max_width = 160
plain_formatter.max_seq_length = 20

# Compact numpy array printing.
np.set_printoptions(precision=4, linewidth=80, threshold=20, edgeitems=5)

In [3]:
class InconsistentDataWarning(Warning):
    """Custom ``Warning`` category (by its name, for inconsistent input data).

    NOTE(review): not raised or filtered in any of the cells shown here.
    """

In [4]:
# Encoding and decoding settings for both simulator and algorithm.

# Linearized-track size, bin width and position-kernel width, named once so the derived settings
# below (bins, edges, kernel, column names) cannot drift out of sync with each other.
POS_NUM_BINS = 800
POS_BIN_DELTA = 1
POS_KERNEL_STD = 2

encode_settings = AttrDict({'sampling_rate': 1000,
                            'pos_sampling_rate': 30,
                            'pos_bins': np.arange(0, POS_NUM_BINS, POS_BIN_DELTA),
                            # +0.1 so arange includes the final edge at POS_NUM_BINS.
                            'pos_bin_edges': np.arange(0, POS_NUM_BINS + 0.1, POS_BIN_DELTA),
                            'pos_bin_delta': POS_BIN_DELTA,
                            # Normal pdf over the position bins, centred mid-track.
                            'pos_kernel': sp.stats.norm.pdf(np.arange(0, POS_NUM_BINS, POS_BIN_DELTA),
                                                            POS_NUM_BINS // 2, POS_KERNEL_STD),
                            'pos_kernel_std': POS_KERNEL_STD,
                            'mark_kernel_std': 5,
                            'pos_num_bins': POS_NUM_BINS,
                            'pos_col_names': [pos_col_format(ii, POS_NUM_BINS) for ii in range(POS_NUM_BINS)],
                            # Six contiguous segments covering [0, POS_NUM_BINS]; the same boundaries
                            # appear per arm/direction in 'wtrack_arm_coordinates'.
                            'arm_coordinates': [[0, 100], [100, 250], [250, 400], [400, 500], [500, 650], [650, 800]],
                            'wtrack_arm_coordinates': AttrDictEnum({WtrackArm.center: 
                                                                    AttrDictEnum({Direction.outbound: AttrDictEnum(x1=0, x2=100, len=100), 
                                                                                  Direction.inbound: AttrDictEnum(x1=400, x2=500, len=100)}),
                                                                    WtrackArm.left: 
                                                                    AttrDictEnum({Direction.outbound: AttrDictEnum(x1=100, x2=250, len=150),
                                                                                  Direction.inbound: AttrDictEnum(x1=250, x2=400, len=150)}),
                                                                    WtrackArm.right: 
                                                                    AttrDictEnum({Direction.outbound: AttrDictEnum(x1=500, x2=650, len=150),
                                                                                  Direction.inbound: AttrDictEnum(x1=650, x2=800, len=150)})}),
                            'vel': 3,
                            'spk_amp': 60})

# The arm segments must be contiguous and end at the last position bin edge.
_arm_coords = encode_settings['arm_coordinates']
assert _arm_coords[0][0] == 0 and _arm_coords[-1][1] == POS_NUM_BINS, 'arm_coordinates must span the track'
assert all(prev[1] == nxt[0] for prev, nxt in zip(_arm_coords, _arm_coords[1:])), 'arm_coordinates must be contiguous'

decode_settings = AttrDict({'trans_smooth_std': 5,
                            'trans_uniform_gain': 0.001,
                            'time_bin_size': 10})
                            

In [5]:
# Simulated W-track trajectory (presumably constant velocity, per the simulator's name).
# Its positions are exposed as a flat linear-position table.
wtrack_pos_const_sim = WtrackPosConstSimulator(
    trial_time=10,
    num_series=3,
    encode_settings=encode_settings,
)
linpos_flat = wtrack_pos_const_sim.linpos_flat

In [6]:
# Render the W-track position plot with holoviews' matplotlib backend.
hv.output(backend='matplotlib')
# Re-apply the dark matplotlib styling from the setup cell.
# NOTE(review): presumably needed because switching backends can reset it — confirm.
plt.style.use('dark_background')
plt.rcParams['figure.facecolor'] = 'black'

wtrack_viz = WtrackLinposVisualizer(linpos_flat, encode_settings)
# Last expression: displayed as the cell output.
wtrack_viz.plot

In [7]:
# Decompose the flat W-track position. Its `decomp_linpos` table is used below with the columns
# 'linpos_cw', 'linpos_cw_prev', 'linpos_cw_next', 'linpos_ccw' and 'direction'.
wtrack_decomposed = WtrackLinposDecomposer(linpos_flat, encode_settings)

In [8]:
# Display the decomposed position table for inspection.
wtrack_decomposed.decomp_linpos

In [9]:
# Simulated tetrode: 200 units with 4 mark dimensions. Place fields are parameterised by the
# clockwise main-arm coordinates of the decomposed track.
# NOTE(review): no random seed is set here. If the generator is stochastic (presumably), results
# change on every run; seed whichever RNG it uses for reproducibility — TODO confirm.
pos_field_range_cw = wtrack_decomposed.simple_main_armcoord_cw
pos_field_bins_cw = wtrack_decomposed.simple_main_armcoord_bins_cw
tet_gen = TetrodeUniformUnitNormalGenerator(
    sampling_rate=encode_settings.sampling_rate,
    num_marks=4,
    num_units=200,
    mark_mean_range=(60, 600),
    mark_cov_range=(50, 120),
    firing_rate_range=(10, 60),
    pos_field_range=pos_field_range_cw,
    pos_field_bins=pos_field_bins_cw,
    pos_field_var_range=(2, 10),
)

In [10]:
resample_decomp_linpos = wtrack_decomposed.decomp_linpos.get_resampled(1)

# One simulation per position column, run in the order 'linpos_cw', '_prev', '_next'
# (dict comprehensions preserve insertion/evaluation order).
sim_by_col = {col: tet_gen.simulate_tetrode_over_pos(resample_decomp_linpos, col_name=col)
              for col in ('linpos_cw', 'linpos_cw_prev', 'linpos_cw_next')}
spk_amps, unit_spks = sim_by_col['linpos_cw']
spk_amps_prev, unit_spks_prev = sim_by_col['linpos_cw_prev']
spk_amps_next, unit_spks_next = sim_by_col['linpos_cw_next']

In [11]:
# Pool the spikes simulated from the 'linpos_cw', '_prev' and '_next' columns into one table
# ordered by timestamp.
# `DataFrame.append` was removed in pandas 2.0; pd.concat is the replacement. The old `append`
# finished with `__finalize__(self, method="append")`, which is kept so any `_metadata` on
# spk_amps' class (presumably a spykshrk container — confirm) carries over as before.
# NOTE(review): this rebinds `spk_amps`. Re-running this cell without re-running the previous one
# pools the extra spikes twice.
spk_amps = (pd.concat([spk_amps, spk_amps_prev, spk_amps_next])
            .__finalize__(spk_amps, method='append')
            .sort_index(level='timestamp'))

In [12]:

# Position resampled onto the spike-amplitude index (presumably the position at each spike).
spk_pos = linpos_flat.get_irregular_resampled(spk_amps)
tet_viz = TetrodeVisualizer(spk_amps, spk_pos, unit_spks)
# Colour the c00/c01 mark scatter by 'linpos_flat' ('linpos_cw' is an alternative colour column).
tet_viz.plot_color_3d_dynamic('linpos_flat', 'c00', 'c01')

In [13]:
# Attach position to every spike, then split the spikes by running direction.
# `@Direction` in the query strings is resolved from the notebook namespace by DataFrame.query.
# NOTE(review): spk_linpos_inbound / spk_linpos_outbound are not used by any later cell shown.
spks_linpos = (linpos_flat
               .get_irregular_resampled(spk_amps)
               .join(spk_amps))
spk_linpos_inbound = spks_linpos.query('direction==@Direction.inbound', engine='python')
spk_linpos_outbound = spks_linpos.query('direction==@Direction.outbound', engine='python')

In [14]:
class WtrackPPInOutEncoder:
    """Fit separate inbound and outbound W-track encoders.

    Positions and encoding spikes are split on the ``direction`` column of
    ``wtrack_decomposed.decomp_linpos``. For each direction, ``create_recomposed_encoder`` fits a
    'linpos_cw' and a 'linpos_ccw' OfflinePPEncoder and wraps the pair in a
    WtrackLinposRecomposer. Call ``run_encoder()`` to populate ``recomposed_encoder_out`` and
    ``recomposed_encoder_in``.
    """

    def __init__(self, wtrack_decomposed, enc_spk_amps, dec_spk_amps, encode_settings, decode_settings, chunk_size, cuda, norm):
        """
        Args:
            wtrack_decomposed: decomposer providing ``decomp_linpos`` and ``encode_settings_decomp``.
            enc_spk_amps: spike features used to build the encoding model.
            dec_spk_amps: spike features passed to every encoder as ``dec_spk_amp``.
            encode_settings: settings passed to the recomposer.
            decode_settings: settings passed to every OfflinePPEncoder.
            chunk_size, cuda, norm: forwarded unchanged to every OfflinePPEncoder.
        """
        self.wtrack_decomposed = wtrack_decomposed
        self.enc_spk_amps = enc_spk_amps
        self.dec_spk_amps = dec_spk_amps
        self.encode_settings = encode_settings
        self.decode_settings = decode_settings
        self.chunk_size = chunk_size
        self.cuda = cuda
        self.norm = norm

        # Filled in by run_encoder().
        self.recomposed_encoder_in = None
        self.recomposed_encoder_out = None

        decomp_linpos = self.wtrack_decomposed.decomp_linpos
        self.linpos_outbound = decomp_linpos.query('direction==@Direction.outbound', engine='python')
        self.linpos_inbound = decomp_linpos.query('direction==@Direction.inbound', engine='python')
        # Attach the decomposed position (including `direction`) to each encoding spike and split
        # by direction. The position columns are then dropped again, leaving the spike columns.
        self.spks_linpos = decomp_linpos.get_irregular_resampled(self.enc_spk_amps).join(self.enc_spk_amps)
        self.spks_linpos_outbound = (self.spks_linpos.query('direction==@Direction.outbound', engine='python')
                                     .drop(columns=decomp_linpos.columns))
        self.spks_linpos_inbound = (self.spks_linpos.query('direction==@Direction.inbound', engine='python')
                                    .drop(columns=decomp_linpos.columns))

    def run_encoder(self):
        """Fit the outbound and then the inbound recomposed encoder."""
        self.recomposed_encoder_out = self._build_direction_encoder(self.linpos_outbound, self.spks_linpos_outbound)
        self.recomposed_encoder_in = self._build_direction_encoder(self.linpos_inbound, self.spks_linpos_inbound)

    def _build_direction_encoder(self, decomp_linpos, enc_spk_amp):
        """Run create_recomposed_encoder for one direction using this instance's shared settings."""
        return WtrackPPInOutEncoder.create_recomposed_encoder(decomp_linpos=decomp_linpos,
                                                              enc_spk_amp=enc_spk_amp,
                                                              dec_spk_amp=self.dec_spk_amps,
                                                              wtrack_decomposed=self.wtrack_decomposed,
                                                              encode_settings=self.encode_settings,
                                                              decode_settings=self.decode_settings,
                                                              chunk_size=self.chunk_size,
                                                              cuda=self.cuda,
                                                              norm=self.norm)

    @staticmethod
    def create_recomposed_encoder(decomp_linpos, enc_spk_amp, dec_spk_amp, wtrack_decomposed, encode_settings, decode_settings, chunk_size=10000, cuda=True, norm=True):
        """Fit 'linpos_cw' and 'linpos_ccw' encoders on one direction and recompose them.

        Both OfflinePPEncoders use ``wtrack_decomposed.encode_settings_decomp``. The
        WtrackLinposRecomposer receives the original ``encode_settings``.

        Returns:
            WtrackLinposRecomposer built from the two fitted encoders.
        """
        # Construct both encoders first, then run them, in the same order as before.
        encoder_cw, encoder_ccw = (
            OfflinePPEncoder(linpos=decomp_linpos, enc_spk_amp=enc_spk_amp, dec_spk_amp=dec_spk_amp,
                             encode_settings=wtrack_decomposed.encode_settings_decomp, decode_settings=decode_settings,
                             linpos_col_name=col_name, chunk_size=chunk_size, cuda=cuda, norm=norm)
            for col_name in ('linpos_cw', 'linpos_ccw'))
        # The return values of run_encoder() are not needed here; the fitted encoder objects
        # themselves are handed to the recomposer.
        encoder_cw.run_encoder()
        encoder_ccw.run_encoder()

        return WtrackLinposRecomposer(encoder_cw, encoder_ccw, wtrack_decomposed, encode_settings)
        

In [15]:
%%time

multi_encoder = WtrackPPInOutEncoder(wtrack_decomposed=wtrack_decomposed,
                                     enc_spk_amps=spk_amps,   # same spikes used for encoding ...
                                     dec_spk_amps=spk_amps,   # ... and for decoding
                                     encode_settings=encode_settings,
                                     decode_settings=decode_settings,
                                     chunk_size=15000,
                                     cuda=True,
                                     norm=False)


In [16]:
%%time

# Fit the outbound and inbound recomposed encoders (sets recomposed_encoder_out / _in).
multi_encoder.run_encoder()

In [17]:
hv.output(backend='bokeh', size=90)

# Y-axis tick formatter for the probability curves; only bokeh gets a custom one.
if hv.Store.current_backend == 'bokeh':
    import bokeh.models.formatters
    yfmt = bokeh.models.formatters.BasicTickFormatter(precision=1, power_limit_high=0, power_limit_low=0)
else:
    # matplotlib (and any other backend) uses its default formatter. Previously an unrecognised
    # backend left `yfmt` undefined, so plot_observ raised NameError.
    yfmt = None

# Observations of the inbound recomposed encoder's clockwise OfflinePPEncoder.
observ_ex = multi_encoder.recomposed_encoder_in.encoder_cw.observ_obj
encode_settings_decomp = multi_encoder.wtrack_decomposed.encode_settings_decomp

# Position-bin columns of the observation table (first to last decomposed position bin).
sel_distrib = observ_ex.loc[:, pos_col_format(0, encode_settings_decomp.pos_num_bins):
                            pos_col_format(encode_settings_decomp.pos_num_bins - 1,
                                           encode_settings_decomp.pos_num_bins)]

sel_pos = observ_ex.loc[:, 'position']
max_prob = sel_distrib.max().max()

# Number of observations overlaid per page.
num_plot = 20

def plot_observ(ind):
    """Overlay the `num_plot` observation distributions starting at row `ind`.

    Each curve gets a marker at that row's 'position' value, drawn at 1/10 of the page's maximum.
    Requires ind + num_plot <= len(sel_distrib).
    """
    plot_list = []
    max_dist = sel_distrib.iloc[ind:ind + num_plot].max().max()
    for ii in range(num_plot):
        sel = sel_distrib.iloc[ind + ii]
        plot_list.append(hv.Curve(sel).options(framewise=True, xlim=(0, None), ylim=(0, max_dist), yformatter=yfmt))
        plot_list.append(hv.Points((sel_pos.iloc[ind + ii], [max_dist / 10])))
    return hv.Overlay(plot_list).opts(aspect=3)

dmap = hv.DynamicMap(plot_observ, kdims=['ind'])
# Start index of every full page. `+ 1` includes the last full page (start == len - num_plot),
# which the previous bound `len(observ_ex) - num_plot` excluded.
dmap.redim.values(ind=list(range(0, len(observ_ex) - num_plot + 1, num_plot)))

In [18]:
class ForwardReverseDecoder:
    """Decode the same observations with two transition matrices.

    "forward" uses the upper triangle (np.triu) and "reverse" the lower triangle (np.tril) of
    ``trans_mat``. Each triangle is divided by the sum of all its entries. Results land in
    ``posterior_forward`` / ``posterior_reverse`` after ``run_decoder()``.
    """

    def __init__(self, observ, trans_mat, prob_no_spike, encode_settings, decode_settings, time_bin_size):
        self.observ = observ
        self.trans_mat = trans_mat
        self.prob_no_spike = prob_no_spike
        self.encode_settings = encode_settings
        self.decode_settings = decode_settings
        self.time_bin_size = time_bin_size

        # Filled in by run_decoder().
        self.posterior_forward = None
        self.posterior_reverse = None

        self.trans_forward = self._scaled_triangle(np.triu, self.trans_mat)
        self.trans_reverse = self._scaled_triangle(np.tril, self.trans_mat)

        self.decoder_forward = self._make_decoder(self.trans_forward)
        self.decoder_reverse = self._make_decoder(self.trans_reverse)

    @staticmethod
    def _scaled_triangle(triangle, mat):
        """Return a copy of ``triangle(mat)`` divided in place by its grand total."""
        part = triangle(mat)
        part /= part.sum()
        return part

    def _make_decoder(self, trans):
        """OfflinePPDecoder over this instance's observations with transition matrix ``trans``."""
        return OfflinePPDecoder(observ_obj=self.observ, trans_mat=trans,
                                prob_no_spike=self.prob_no_spike, encode_settings=self.encode_settings,
                                decode_settings=self.decode_settings, time_bin_size=self.time_bin_size)

    def run_decoder(self):
        """Run the forward decoder, then the reverse decoder, storing each posterior."""
        forward_post = self.decoder_forward.run_decoder()
        self.posterior_forward = forward_post
        reverse_post = self.decoder_reverse.run_decoder()
        self.posterior_reverse = reverse_post


In [19]:
# Take the bin size from decode_settings (previously a separate literal 10) so the two cannot
# drift apart.
time_bin_size = decode_settings.time_bin_size

recomposed_in = multi_encoder.recomposed_encoder_in
decoder = ForwardReverseDecoder(observ=recomposed_in.observ, trans_mat=recomposed_in.trans_mat,
                                prob_no_spike=recomposed_in.prob_no_spike, encode_settings=multi_encoder.encode_settings,
                                decode_settings=decode_settings, time_bin_size=time_bin_size)
decoder.run_decoder()

In [20]:
class WtrackPPDecoder:
    """Four-state W-track decoder built on a fitted WtrackPPInOutEncoder.

    Indicator states (keys of ``indicator_states``):
        0: 'inbound, forward'    1: 'inbound, reverse'
        2: 'outbound, forward'   3: 'outbound, reverse'
    "forward" / "reverse" use the upper / lower triangle (np.triu / np.tril) of the matching
    encoder's transition matrix.

    ``calc()`` runs ``_calc_like``, ``_prepare_post`` and ``_calc_post`` in that order.
    """

    def __init__(self, in_out_encoder, chunk_size=None, cuda=None, time_bin_size=None):
        """
        Args:
            in_out_encoder: WtrackPPInOutEncoder whose run_encoder() has already been called, so
                recomposed_encoder_in / recomposed_encoder_out are populated.
            chunk_size: overrides in_out_encoder.chunk_size when not None.
            cuda: overrides in_out_encoder.cuda when not None (an explicit False is honoured).
            time_bin_size: overrides decode_settings['time_bin_size'] when not None.
        """
        self.in_out_encoder = in_out_encoder
        self.wtrack_decomposed = self.in_out_encoder.wtrack_decomposed
        self.encode_settings = self.in_out_encoder.encode_settings
        self.decode_settings = self.in_out_encoder.decode_settings
        # Compare against None: the previous truthiness test silently ignored cuda=False.
        self.chunk_size = self.in_out_encoder.chunk_size if chunk_size is None else chunk_size
        self.cuda = self.in_out_encoder.cuda if cuda is None else cuda
        # Use this decoder's own settings, not the notebook-level `time_bin_size` global, which
        # only exists if an unrelated earlier cell has been run.
        self.time_bin_size = self.decode_settings['time_bin_size'] if time_bin_size is None else time_bin_size

        self.encoder_in = self.in_out_encoder.recomposed_encoder_in
        self.encoder_out = self.in_out_encoder.recomposed_encoder_out

        self.trans_in_forward = self._normalized_triangle(np.triu, self.encoder_in.trans_mat)
        self.trans_in_reverse = self._normalized_triangle(np.tril, self.encoder_in.trans_mat)
        self.trans_out_forward = self._normalized_triangle(np.triu, self.encoder_out.trans_mat)
        self.trans_out_reverse = self._normalized_triangle(np.tril, self.encoder_out.trans_mat)

        self.like_cls_in = self._make_likelihood(self.encoder_in)
        self.like_in = None
        self.like_cls_out = self._make_likelihood(self.encoder_out)
        self.like_out = None

        self.indicator_states = AttrDict({})
        # Set by _calc_post().
        self.post_indicator_cls = None

    @staticmethod
    def _normalized_triangle(triangle, trans_mat):
        """Return ``triangle(trans_mat)`` (np.triu or np.tril) divided by its grand total.

        NOTE(review): this scales by the sum over the whole matrix, so columns are not
        individually normalised — confirm that is what the indicator posterior expects.
        """
        part = triangle(trans_mat)
        part /= part.sum()
        return part

    def _make_likelihood(self, encoder):
        """Build the float32 OfflinePPLikelihood for one recomposed encoder."""
        return OfflinePPLikelihood(observ=encoder.observ, trans_mat=encoder.trans_mat, prob_no_spike=encoder.prob_no_spike,
                                   encode_settings=self.encode_settings, decode_settings=self.decode_settings,
                                   time_bin_size=self.time_bin_size, dtype=np.float32)

    def calc(self):
        """Compute likelihoods, assemble the indicator states, then run the indicator posterior."""
        self._calc_like()
        self._prepare_post()
        self._calc_post()

    def _calc_like(self):
        """Compute the inbound and outbound likelihoods."""
        self.like_in = self.like_cls_in.calc()
        self.like_out = self.like_cls_out.calc()

    def _prepare_post(self):
        """Fill indicator_states[0..3] with the per-state inputs for the indicator posterior.

        Each state is an AttrDict with keys name, observ, linpos, likelihoods, trans_mat and
        prob_no_spike. Run after _calc_like(); before that, the likelihoods are still None.
        """
        state_specs = (('inbound, forward', self.encoder_in, self.like_in, self.trans_in_forward),
                       ('inbound, reverse', self.encoder_in, self.like_in, self.trans_in_reverse),
                       ('outbound, forward', self.encoder_out, self.like_out, self.trans_out_forward),
                       ('outbound, reverse', self.encoder_out, self.like_out, self.trans_out_reverse))
        for state_ind, (name, encoder, likelihoods, trans_mat) in enumerate(state_specs):
            self.indicator_states[state_ind] = AttrDict({'name': name, 'observ': encoder.observ, 'linpos': encoder.linpos,
                                                         'likelihoods': likelihoods, 'trans_mat': trans_mat,
                                                         'prob_no_spike': encoder.prob_no_spike})

    def _calc_post(self):
        """Run OfflinePPIndicatorPosterior over indicator_states; the result is kept in post_indicator_cls."""
        # NOTE(review): cuda is hard-coded to False and self.cuda is not used here — confirm the
        # indicator posterior is meant to always run on the CPU.
        self.post_indicator_cls = OfflinePPIndicatorPosterior(indicator_states=self.indicator_states, encode_settings=self.encode_settings,
                                                              decode_settings=self.decode_settings, cuda=False, dtype=np.float32)
        self.post_indicator_cls.calc()

In [21]:
# Build the four-state W-track decoder from the fitted in/out encoders. Only the likelihood step
# runs here; the remaining steps of WtrackPPDecoder.calc() run in the next cell.
wtrack_decoder = WtrackPPDecoder(in_out_encoder=multi_encoder)
wtrack_decoder._calc_like()

In [22]:
# Assemble the four indicator states and run the indicator posterior. Together with
# _calc_like() from the previous cell, these are the steps WtrackPPDecoder.calc() runs.
wtrack_decoder._prepare_post()
wtrack_decoder._calc_post()

In [23]:
# Sanity check: row-wise (axis=1) sum of the distribution view of indicator posterior 3.
# NOTE(review): index 3 is presumably the 'outbound, reverse' state and rows presumably time bins;
# confirm, and state what total is expected here.
wtrack_decoder.post_indicator_cls.indicator_posts[3].get_posterior().get_distribution_view().sum(axis=1)

In [24]:
# Posterior for indicator index 3 (presumably 'outbound, reverse', matching indicator_states[3]).
post_plot = wtrack_decoder.post_indicator_cls.indicator_posts[3].get_posterior()

hv.output(backend='bokeh', size=200)

# heatmap_max=0.15: presumably the upper limit of the heatmap colour scale — confirm.
dec_viz = DecodeVisualizer(post_plot, linpos=linpos_flat, enc_settings=encode_settings, heatmap_max=0.15)

# RangeXY stream: presumably re-renders for the current x/y view range on zoom/pan.
dmap = dec_viz.plot_all_dynamic(stream=hv.streams.RangeXY())

dmap.opts(aspect=3, frame_width=600)

dmap

In [25]:
# Step-by-step view of all four indicator states (small per-plot size).
hv.output(backend='bokeh', size=16)
multi_dec_viz = MultiDecodeStepVisualizer(indicator_states=wtrack_decoder.indicator_states, encode_settings=wtrack_decoder.encode_settings, decode_settings=wtrack_decoder.decode_settings)
multi_dec_viz.plot_all()

In [29]:
hv.output(backend='bokeh', size=30)

# Indicator state 0 is 'inbound, forward' (see WtrackPPDecoder._prepare_post).
# NOTE(review): _prepare_post does not set a 'posteriors' key. This cell relies on something else
# (presumably OfflinePPIndicatorPosterior.calc()) adding it to the shared state dict — confirm,
# otherwise `decoder_in.posteriors` fails. The execution count also jumps from 25 to 29, so
# re-run top-to-bottom to confirm this cell does not depend on removed cells.
decoder_in = wtrack_decoder.indicator_states[0]
step_viz = DecodeStepVisualizer(decoder_in.observ, decoder_in.likelihoods, decoder_in.posteriors, multi_encoder.linpos_inbound, encode_settings, decode_settings)

step_viz.plot_all()