In [1]:
# Make the working source tree importable so spykshrk does not have to be
# installed through setuptools. The repository root is `root_depth` levels
# above the directory this notebook's kernel started in.
import sys
import os
root_depth = 2
# `_dh` is IPython's directory history; its first entry is the directory the
# kernel started in (for Jupyter, typically the notebook's directory). Fall
# back to the current working directory when `_dh` is not available.
notebook_dir = globals().get('_dh', [os.getcwd()])[0]
root_path = os.path.abspath(os.path.join(notebook_dir, '../'*root_depth))
# Drop every existing copy of the root, then put it first so the working tree
# takes precedence over any installed copy of spykshrk.
sys.path[:] = [path for path in sys.path if path != root_path]
sys.path.insert(0, root_path)
# Alternatively set root path as current working directory
#os.chdir(root_path)

In [2]:
import numpy as np
import scipy as sp
import pandas as pd
import functools
import itertools
import holoviews as hv
from holoviews import dim, opts
import enum
import IPython
import warnings
import torch

from IPython.utils.text import columnize
from IPython.lib.pretty import pprint, pretty
from IPython.display import display_html

# Spykshrk modules for data analysis
from spykshrk.franklab.data_containers import FlatLinearPosition, SpikeFeatures, \
        EncodeSettings, pos_col_format, SpikeObservation
from spykshrk.franklab.generator.place.spike_generator import UnitNormalGenerator, TetrodeUniformUnitNormalGenerator
from spykshrk.franklab.generator.place.pos_generator import WtrackPosConstSimulator
from spykshrk.franklab.pp_decoder.util import normal_pdf_int_lookup, gaussian
from spykshrk.franklab.pp_decoder.pp_clusterless import OfflinePPEncoder, OfflinePPDecoder
from spykshrk.franklab.pp_decoder.wtrack_mapping import WtrackLinposDecomposer, WtrackLinposRecomposer
from spykshrk.franklab.pp_decoder.visualization import DecodeVisualizer, WtrackLinposVisualizer
from spykshrk.franklab.visualization import LinPosVisualizer, TetrodeVisualizer
from spykshrk.franklab.pp_decoder.util import apply_no_anim_boundary
from spykshrk.util import AttrDict, AttrDictEnum
from spykshrk.franklab.wtrack import WtrackArm, Direction, Order, Rotation


# Load both HoloViews plotting backends, one call each. The later call
# presumably leaves matplotlib as the default until a cell switches with
# hv.output(backend=...).
hv.extension('bokeh')
hv.extension('matplotlib')
hv.renderer('bokeh').theme = "dark_minimal"

# Dark look for plain matplotlib figures too.
import matplotlib.pyplot as plt
plt.style.use('dark_background')
plt.rcParams['figure.facecolor'] = 'black'

# Text display: numpy arrays, pandas frames, then IPython's plain-text repr width.
np.set_printoptions(precision=4, linewidth=120)

pd.set_option('display.precision', 4)
pd.set_option('display.max_rows', 10)
pd.set_option('display.max_columns', 15)

plain_formatter = get_ipython().display_formatter.formatters['text/plain']
plain_formatter.max_width = 160

In [3]:
class InconsistentDataWarning(Warning):
    """Warning category for data that disagrees with a reference it should match,
    e.g. positions that cannot be mapped from one coordinate system to another."""

In [4]:
# Encoding and decoding settings for both simulator and algorithm.
# The number of position bins and the position-kernel width are defined once
# here so that every setting derived from them stays consistent.
POS_NUM_BINS = 800
POS_KERNEL_STD = 3

encode_settings = AttrDict({'sampling_rate': 1000,
                            # One-unit position bins; the edges include the final edge at POS_NUM_BINS.
                            'pos_bins': np.arange(0, POS_NUM_BINS, 1),
                            'pos_bin_edges': np.arange(0, POS_NUM_BINS + 0.1, 1),
                            'pos_bin_delta': 1,
                            # Gaussian kernel centred on the middle bin.
                            'pos_kernel': sp.stats.norm.pdf(np.arange(0, POS_NUM_BINS, 1),
                                                            POS_NUM_BINS // 2, POS_KERNEL_STD),
                            'pos_kernel_std': POS_KERNEL_STD,
                            'mark_kernel_std': 20,
                            'pos_num_bins': POS_NUM_BINS,
                            'pos_col_names': [pos_col_format(ii, POS_NUM_BINS) for ii in range(POS_NUM_BINS)],
                            # Hard-coded W-track segment boundaries; they must end at POS_NUM_BINS.
                            'arm_coordinates': [[0, 100], [100, 250], [250, 400], [400, 500], [500, 650], [650, 800]],
                            'wtrack_arm_coordinates': AttrDictEnum({WtrackArm.center: 
                                                                    AttrDictEnum({Direction.outbound: AttrDictEnum(x1=0, x2=100, len=100), 
                                                                                  Direction.inbound: AttrDictEnum(x1=400, x2=500, len=100)}),
                                                                    WtrackArm.left: 
                                                                    AttrDictEnum({Direction.outbound: AttrDictEnum(x1=100, x2=250, len=150),
                                                                                  Direction.inbound: AttrDictEnum(x1=250, x2=400, len=150)}),
                                                                    WtrackArm.right: 
                                                                    AttrDictEnum({Direction.outbound: AttrDictEnum(x1=500, x2=650, len=150),
                                                                                  Direction.inbound: AttrDictEnum(x1=650, x2=800, len=150)})}),
                            'vel': 3,
                            'spk_amp': 60})

# Guard: the hard-coded track geometry must span exactly the configured bins.
assert encode_settings['arm_coordinates'][-1][1] == POS_NUM_BINS, \
    'arm_coordinates do not end at POS_NUM_BINS'

decode_settings = AttrDict({'trans_smooth_std': 5,
                            'trans_uniform_gain': 0.001,
                            'time_bin_size': 10})
                            

In [5]:
# Simulate W-track linear position with the constant-position simulator
# (trial_time=10, num_series=3) on the track defined by encode_settings.
wtrack_pos_const_sim = WtrackPosConstSimulator(
    trial_time=10,
    num_series=3,
    encode_settings=encode_settings,
)
linpos_flat = wtrack_pos_const_sim.linpos_flat

In [6]:
# Draw the simulated linear position on the W-track layout using matplotlib,
# re-applying the dark matplotlib style right after switching backends.
hv.output(backend='matplotlib')
plt.style.use('dark_background')
plt.rcParams['figure.facecolor'] = 'black'

wtrack_viz = WtrackLinposVisualizer(linpos_flat, encode_settings)
wtrack_viz.plot

In [7]:
# Decompose the W-track position; decomp_linpos provides the 'linpos_cw' and
# 'linpos_ccw' columns used by the encoders below.
wtrack_decomposed = WtrackLinposDecomposer(
    linpos_flat,
    encode_settings,
)

In [8]:
hv.output(backend='matplotlib')
# Scatter of the clockwise-decomposed linear position.
(hv.Scatter(wtrack_decomposed.decomp_linpos['linpos_cw'])
   .opts(aspect=1, fig_size=400))

In [9]:
# Inspect the decomposed position table (row/column display limited by the pandas options set in the setup cell).
wtrack_decomposed.decomp_linpos

In [10]:
# Encoding settings for the clockwise-decomposed track: start from a copy of
# encode_settings and override every setting that depends on the number of
# position bins in the decomposed coordinate system.
encode_settings_decomp = AttrDict(encode_settings)
encode_settings_decomp.pos_num_bins = wtrack_decomposed.armcoord_cw_num_bins
encode_settings_decomp.pos_bins = np.arange(0, encode_settings_decomp.pos_num_bins)
# Small epsilon so arange includes the final edge at pos_num_bins.
encode_settings_decomp.pos_bin_edges = np.arange(0, encode_settings_decomp.pos_num_bins + 0.0001, 1)
encode_settings_decomp.pos_kernel_std = 3
encode_settings_decomp.pos_kernel = sp.stats.norm.pdf(np.arange(0, encode_settings_decomp.pos_num_bins, 1), 
                                                      encode_settings_decomp.pos_num_bins/2, 
                                                      encode_settings_decomp.pos_kernel_std)
# One column name per decomposed position bin (assigned once).
encode_settings_decomp.pos_col_names = pos_col_format(range(encode_settings_decomp.pos_num_bins), 
                                                      encode_settings_decomp.pos_num_bins) 
encode_settings_decomp.arm_coordinates = wtrack_decomposed.simple_armcoord_cw
encode_settings_decomp.wtrack_decomp_arm_coordinates = wtrack_decomposed.wtrack_armcoord_cw
# NOTE(review): 'wtrack_arm_coordinates' is inherited unchanged from
# encode_settings (original W-track frame) — confirm that is intended.


In [11]:
# Simulated tetrode generator.
# NOTE(review): if this generator draws from np.random, seed it before this
# cell so the simulated spikes are reproducible across runs.
tet_gen = TetrodeUniformUnitNormalGenerator(
    sampling_rate=encode_settings.sampling_rate,
    # units and their firing-rate range
    num_units=200,
    firing_rate_range=(5, 40),
    # mark (amplitude) distribution parameters
    num_marks=4,
    mark_mean_range=(60, 200),
    mark_cov_range=(20, 40),
    # place-field parameters, expressed in the clockwise-decomposed coordinates
    pos_field_range=wtrack_decomposed.simple_main_armcoord_cw,
    pos_field_bins=wtrack_decomposed.simple_main_armcoord_bins_cw,
    pos_field_var_range=(2, 10),
)

In [12]:
# Simulate the tetrode's spikes along the 'linpos_cw' column of the decomposed position trace.
spk_amps, unit_spks = tet_gen.simulate_tetrode_over_pos(
    wtrack_decomposed.decomp_linpos,
    col_name='linpos_cw',
)

In [13]:
# Interactive view of the simulated marks: position column 'linpos_cw'
# against mark columns 'c00' and 'c01'.
tet_viz = TetrodeVisualizer(spk_amps,
                            wtrack_decomposed.decomp_linpos,
                            unit_spks)
tet_viz.plot_color_3d_dynamic('linpos_cw',
                              'c00',
                              'c01')

In [14]:
%%time
#%%prun -r -s cumulative

# Setup encoding model and estimate the position distribution of each spike being encoded
encoder_cw = OfflinePPEncoder(linflat=wtrack_decomposed.decomp_linpos, enc_spk_amp=spk_amps, dec_spk_amp=spk_amps, 
                              encode_settings=encode_settings_decomp, decode_settings=decode_settings,
                              linflat_col_name='linpos_cw', chunk_size=15000, cuda=True)
encoder_ccw = OfflinePPEncoder(linflat=wtrack_decomposed.decomp_linpos, enc_spk_amp=spk_amps, dec_spk_amp=spk_amps, 
                               encode_settings=encode_settings_decomp, decode_settings=decode_settings,
                               linflat_col_name='linpos_ccw', chunk_size=15000, cuda=True)
observ_cw = encoder_cw.run_encoder()
observ_ccw = encoder_ccw.run_encoder()


In [15]:
hv.output(backend='matplotlib')

# Estimated position distribution of each encoded spike (clockwise decomposition).
sel_distrib = observ_cw.loc[:, pos_col_format(0, encode_settings_decomp.pos_num_bins):
                            pos_col_format(encode_settings_decomp.pos_num_bins-1,
                                           encode_settings_decomp.pos_num_bins)]
sel_pos = observ_cw.loc[:, 'position']
max_prob = sel_distrib.max().max()


def plot_observ(ind, *, distrib=sel_distrib, pos=sel_pos, max_prob=max_prob,
                x_max=encode_settings_decomp.pos_num_bins):
    """Overlay the position distributions of spikes ind..ind+4 with their positions.

    The data are bound as keyword-only defaults when the function is defined.
    A later cell rebinds the globals sel_distrib / sel_pos / max_prob, and this
    lazily evaluated DynamicMap must keep showing the clockwise observations.
    Only `ind` is a positional argument, so it is the single key dimension
    HoloViews passes in.
    """
    plot_list = []
    for ii in range(5):
        plot_list.append(hv.Curve(distrib.iloc[ind+ii], extents=(0, 0, x_max, max_prob)))
        plot_list.append(hv.Points((pos.iloc[ind+ii], [0.005])))
    return hv.Overlay(plot_list).opts(fig_size=400, aspect=3)


dmap = hv.DynamicMap(plot_observ, kdims=['ind'])
# Each frame reads rows ind..ind+4, so the last valid start is len-5.
dmap.redim.values(ind=list(range(0, len(sel_distrib) - 4, 5)))

In [None]:
def decomposed_pos_remap_to_wtrack(pos, decomposed_armcoord, wtrack_armcoord):
    """Map positions from the decomposed coordinate system onto the W-track.

    Each arm/direction segment of ``decomposed_armcoord`` has a ``main`` span
    ``[main.x1, main.x2]`` (closed). Positions inside that span are shifted so
    that ``main.x1`` lands on ``wtrack_armcoord[arm][direction].x1``.

    Args:
        pos: array-like of positions in decomposed coordinates (lists, arrays
            and pandas Series are accepted).
        decomposed_armcoord: nested mapping arm -> direction -> segment
            exposing ``.main.x1`` and ``.main.x2``.
        wtrack_armcoord: nested mapping arm -> direction -> segment exposing
            ``.x1``.

    Returns:
        float ndarray with the shape of ``pos``. Entries that fall in no main
        span (or were NaN on input) are NaN.

    Warns:
        InconsistentDataWarning: if any entry is NaN in the result.
    """
    pos = np.asarray(pos)
    wtrack_pos = np.full(pos.shape, np.nan)
    for arm_k, arm_v in decomposed_armcoord.items():
        for direct_k, direct_v in arm_v.items():
            main = direct_v.main
            # Closed interval: a position on an endpoint shared by two spans is
            # taken by whichever segment is iterated last.
            dir_pos_ind = (pos >= main.x1) & (pos <= main.x2)
            wtrack_pos[dir_pos_ind] = pos[dir_pos_ind] - main.x1 + wtrack_armcoord[arm_k][direct_k].x1

    if np.isnan(wtrack_pos).any():
        # stacklevel=2 attributes the warning to the caller's line.
        warnings.warn('Position in decomposed main coordinate system does not match wtrack coordinates.',
                      InconsistentDataWarning, stacklevel=2)

    return wtrack_pos

In [None]:
# Recompose the clockwise and counter-clockwise encoders onto the W-track
# described by encode_settings. The result provides the observ, trans_mat and
# prob_no_spike used for decoding below.
wtrack_recomposed = WtrackLinposRecomposer(
    encoder_cw,
    encoder_ccw,
    wtrack_decomposed,
    encode_settings,
)

In [None]:
hv.output(backend='bokeh')
# Recomposed transition matrix. The rows are flipped and the y axis inverted so
# that row i is drawn at y ~ i with y increasing downwards (matrix orientation).
# Bounds are (0, 0, n_cols, n_rows), taken from the matrix itself.
hv.Image(np.flip(wtrack_recomposed.trans_mat, axis=0),
         bounds=(0, 0,
                 wtrack_recomposed.trans_mat.shape[1],
                 wtrack_recomposed.trans_mat.shape[0])).opts(invert_yaxis=True)


In [None]:
# Set the backend explicitly so this cell does not depend on the previous one.
hv.output(backend='bokeh')
# Upper triangle of the recomposed transition matrix. The rows are flipped and
# the y axis inverted so that row i is drawn at y ~ i (matrix orientation).
# Bounds are (0, 0, n_cols, n_rows), taken from the matrix itself.
hv.Image(np.flip(np.triu(wtrack_recomposed.trans_mat), axis=0),
         bounds=(0, 0,
                 wtrack_recomposed.trans_mat.shape[1],
                 wtrack_recomposed.trans_mat.shape[0])).opts(invert_yaxis=True)


In [None]:
hv.output(backend='matplotlib')

# Estimated position distribution of each spike after recomposition onto the W-track.
sel_distrib = wtrack_recomposed.observ.loc[:, pos_col_format(0, encode_settings.pos_num_bins):
                                           pos_col_format(encode_settings.pos_num_bins-1,
                                                          encode_settings.pos_num_bins)]
sel_pos = wtrack_recomposed.observ.loc[:, 'position']
max_prob = sel_distrib.max().max()


def plot_observ(ind, *, distrib=sel_distrib, pos=sel_pos, max_prob=max_prob,
                x_max=encode_settings.pos_num_bins):
    """Overlay the position distributions of spikes ind..ind+4 with their positions.

    The data are bound as keyword-only defaults when the function is defined,
    so the lazily evaluated DynamicMap is unaffected if the globals are later
    rebound. Only `ind` is a positional argument, so it is the single key
    dimension HoloViews passes in.
    """
    plot_list = []
    for ii in range(5):
        plot_list.append(hv.Curve(distrib.iloc[ind+ii], extents=(0, 0, x_max, max_prob)))
        plot_list.append(hv.Points((pos.iloc[ind+ii], [0.005])))
    return hv.Overlay(plot_list).opts(fig_size=400, aspect=3)


dmap = hv.DynamicMap(plot_observ, kdims=['ind'])
# Each frame reads rows ind..ind+4, so starts must stop at len-5 to stay in range.
dmap.redim.values(ind=list(range(0, len(sel_distrib) - 4, 5)))

In [None]:
%%time
# Run PP decoding algorithm
time_bin_size = 10

decoder = OfflinePPDecoder(observ_obj=wtrack_recomposed.observ, trans_mat=wtrack_recomposed.trans_mat, 
                           prob_no_spike=wtrack_recomposed.prob_no_spike,
                           encode_settings=encode_settings, decode_settings=decode_settings, 
                           time_bin_size=time_bin_size)

posteriors = decoder.run_decoder()

In [None]:
hv.output(backend='bokeh')

# Decoded posteriors (NaNs shown as 0), visualized together with the simulated
# linear position. The RangeXY stream tracks the plot's current x/y range so
# the dynamic plot can respond to pan/zoom.
dec_viz = DecodeVisualizer(posteriors.fillna(0), linpos=linpos_flat, enc_settings=encode_settings, heatmap_max=0.15)

dmap = dec_viz.plot_all_dynamic(stream=hv.streams.RangeXY()).opts(aspect=3, frame_width=600)

dmap