In [1]:
# Add the repository's source root to Python's import path (sys.path)
# so spykshrk can be imported without installing it through setuptools.
import sys, os 
# Number of directory levels between this notebook and the repository root.
root_depth = 2
# `_dh` is IPython's directory-history list; `_dh[0]` is the directory the kernel
# started in.  Fall back to the current working directory when this code runs
# outside IPython (plain script / module import), where `_dh` does not exist.
notebook_dir = globals().get('_dh', [os.getcwd()])[0]
root_path = os.path.abspath(os.path.join(notebook_dir, *(['..'] * root_depth)))
# Drop every existing copy of root_path, then append exactly one, so re-running
# this cell never accumulates duplicate sys.path entries.
sys.path[:] = [entry for entry in sys.path if entry != root_path]
sys.path.append(root_path)
# Alternatively set root path as current working directory
#os.chdir(root_path)

In [2]:
import numpy as np
import scipy as sp
import pandas as pd
import functools
import itertools
import holoviews as hv
from holoviews import dim, opts
import enum
import IPython
import warnings
import torch

from IPython.utils.text import columnize
from IPython.lib.pretty import pprint, pretty
from IPython.display import display_html

# Spykshrk modules for data analysis
from spykshrk.franklab.data_containers import FlatLinearPosition, SpikeFeatures, \
        EncodeSettings, pos_col_format, SpikeObservation
from spykshrk.franklab.generator.place.data_generator import UnitNormalGenerator, TetrodeUniformUnitNormalGenerator
from spykshrk.franklab.pp_decoder.util import normal_pdf_int_lookup, gaussian
from spykshrk.franklab.pp_decoder.pp_clusterless import OfflinePPEncoder, OfflinePPDecoder
from spykshrk.franklab.pp_decoder.wtrack_mapping import WtrackLinposDecomposer, WtrackLinposRecomposer
from spykshrk.franklab.pp_decoder.visualization import DecodeVisualizer
from spykshrk.franklab.visualization import LinPosVisualizer, TetrodeVisualizer
from spykshrk.franklab.pp_decoder.util import apply_no_anim_boundary
from spykshrk.util import AttrDict



# Register both HoloViews plotting backends.  Later cells select one explicitly
# with hv.output(backend=...).
hv.extension('bokeh')
hv.extension('matplotlib')
hv.renderer('bokeh').theme = "dark_minimal"

# Match matplotlib output to the dark bokeh theme.
import matplotlib.pyplot as plt
plt.style.use('dark_background')
plt.rcParams['figure.facecolor'] = 'black'

# Keep pandas reprs compact.
for _pd_option, _pd_value in (('display.precision', 4),
                              ('display.max_rows', 10),
                              ('display.max_columns', 15)):
    pd.set_option(_pd_option, _pd_value)



In [3]:
class InconsistentDataWarning(Warning):
    """Warning category for position data that do not line up with the expected
    track coordinates (raised by decomposed_pos_remap_to_wtrack)."""

In [4]:
# Encoding and decoding settings for both simulator and algorithm

@enum.unique
class Arm(enum.Enum):
    """Arms of the W-track; members are numbered 1, 2, 3 in declaration order."""
    center = enum.auto()
    left = enum.auto()
    right = enum.auto()

@enum.unique
class Direction(enum.Enum):
    """Running direction along a W-track segment; numbered 1, 2 in declaration order."""
    forward = enum.auto()
    reverse = enum.auto()
    
# The linearized W-track uses POS_NUM_BINS one-unit position bins and a Gaussian
# position kernel of width POS_KERNEL_STD bins.  Every setting below that depends
# on these numbers is derived from them, so the entries cannot drift apart.
POS_NUM_BINS = 800
POS_KERNEL_STD = 3

encode_settings = AttrDict({'sampling_rate': 1000,
                            'pos_bins': np.arange(0, POS_NUM_BINS, 1),
                            # One more edge than bins; the +0.1 makes arange include POS_NUM_BINS.
                            'pos_bin_edges': np.arange(0, POS_NUM_BINS + 0.1, 1),
                            'pos_bin_delta': 1,
                            # Gaussian kernel sampled on the bin grid, centred on the middle bin.
                            'pos_kernel': sp.stats.norm.pdf(np.arange(0, POS_NUM_BINS, 1),
                                                            POS_NUM_BINS // 2, POS_KERNEL_STD),
                            'pos_kernel_std': POS_KERNEL_STD,
                            'mark_kernel_std': 20,
                            'pos_num_bins': POS_NUM_BINS,
                            'pos_col_names': [pos_col_format(ii, POS_NUM_BINS) for ii in range(POS_NUM_BINS)],
                            # [x1, x2] of each linearized segment; same ranges as wtrack_arm_coordinates
                            # (center fwd, left fwd, left rev, center rev, right fwd, right rev).
                            'arm_coordinates': [[0, 100], [100, 250], [250, 400], [400, 500], [500, 650], [650, 800]],
                            'wtrack_arm_coordinates': AttrDict(center=AttrDict(forward=AttrDict(x1=0, x2=100, len=100),
                                                                               reverse=AttrDict(x1=400, x2=500, len=100)),
                                                               left=AttrDict(forward=AttrDict(x1=100, x2=250, len=150),
                                                                             reverse=AttrDict(x1=250, x2=400, len=150)),
                                                               right=AttrDict(forward=AttrDict(x1=500, x2=650, len=150),
                                                                              reverse=AttrDict(x1=650, x2=800, len=150))),
                            'vel': 3,
                            'spk_amp': 60})

decode_settings = AttrDict({'trans_smooth_std': 5,
                            'trans_uniform_gain': 0.001,
                            'time_bin_size': 10})
                            

In [5]:
def wtrack_simulate_arm_const_speed(arm_range, trial_time, trial_len, sampling_rate):
    """Sample positions along one arm segment while moving at a constant speed.

    The trial covers ``trial_len`` position units in ``trial_time``.  With
    ``sampling_rate`` samples per time unit, each sample therefore advances
    ``trial_len / (trial_time * sampling_rate)`` position units.  Samples start at
    ``arm_range.x1`` and stay strictly below ``arm_range.x2``.

    Args:
        arm_range: object with ``x1``/``x2`` attributes (segment start, exclusive end).
            A ``len`` attribute may be present; it is not needed because it cancels
            out of the step size.
        trial_time: duration of the whole trial, in the time unit of ``sampling_rate``.
        trial_len: total position length covered during the trial.
        sampling_rate: samples per time unit.

    Returns:
        1-D float ndarray of positions.
    """
    step = trial_len / (trial_time * sampling_rate)
    # Count samples explicitly instead of np.arange(x1, x2, step).  With a float
    # step, arange's length is ceil((x2 - x1) / step), and rounding noise can add
    # an extra sample at/past x2 (e.g. np.arange(1, 1.3, 0.1) has 4 elements).
    # Rounding the ratio before ceil strips that noise but keeps x2 exclusive.
    num_samples = max(0, int(np.ceil(np.round((arm_range.x2 - arm_range.x1) / step, 9))))
    return arm_range.x1 + np.arange(num_samples) * step

def wtrack_simulate_trial_const_speed(arm1_range, arm1_enum, arm2_range, arm2_enum, trial_time, sampling_rate):
    """Simulate one trial: traverse ``arm1_range`` and then ``arm2_range`` at one constant speed.

    The shared speed comes from the combined ``len`` of both segments covered in
    ``trial_time``.

    Returns:
        (positions, arm_labels): a float ndarray of positions, plus an object ndarray
        of the same length labelling each sample with ``arm1_enum`` or ``arm2_enum``.
    """
    combined_len = arm1_range.len + arm2_range.len
    position_chunks = []
    label_chunks = []
    for seg_range, seg_enum in ((arm1_range, arm1_enum), (arm2_range, arm2_enum)):
        seg_pos = wtrack_simulate_arm_const_speed(seg_range, trial_time, combined_len, sampling_rate)
        position_chunks.append(seg_pos)
        label_chunks.append(np.full(len(seg_pos), seg_enum, dtype=object))
    return np.concatenate(position_chunks), np.concatenate(label_chunks)

def wtrack_simulate_run_const_speed(arm1_dir, arm1_enum, arm2_dir, arm2_enum, trial_time, sampling_rate):
    """Simulate an out-and-back run between two arms.

    Outbound leg: ``arm1_dir.forward`` then ``arm2_dir.forward``, labelled
    ``Direction.forward``.  Return leg: ``arm2_dir.reverse`` then ``arm1_dir.reverse``,
    labelled ``Direction.reverse``.

    Returns:
        (positions, arm_labels, direction_labels), each concatenated over both legs.
    """
    legs = (
        (arm1_dir.forward, arm1_enum, arm2_dir.forward, arm2_enum, Direction.forward),
        (arm2_dir.reverse, arm2_enum, arm1_dir.reverse, arm1_enum, Direction.reverse),
    )
    pos_parts, arm_parts, dir_parts = [], [], []
    for first_range, first_enum, second_range, second_enum, leg_direction in legs:
        leg_pos, leg_arms = wtrack_simulate_trial_const_speed(first_range, first_enum,
                                                              second_range, second_enum,
                                                              trial_time, sampling_rate)
        pos_parts.append(leg_pos)
        arm_parts.append(leg_arms)
        dir_parts.append(np.full(len(leg_pos), leg_direction, dtype=object))
    return np.concatenate(pos_parts), np.concatenate(arm_parts), np.concatenate(dir_parts)

def wtrack_simulate_series_const_speed(w_coor, trial_time, sampling_rate):
    """Simulate a W-track series: a center<->left run followed by a center<->right run.

    Args:
        w_coor: W-track coordinates with ``center``/``left``/``right`` entries, each
            holding ``forward``/``reverse`` segments.
        trial_time: duration of each trial.
        sampling_rate: samples per time unit.

    Returns:
        (positions, arm_labels, direction_labels), concatenated over both runs.
    """
    outer_arms = ((w_coor.left, Arm.left), (w_coor.right, Arm.right))
    runs = [wtrack_simulate_run_const_speed(w_coor.center, Arm.center, outer_coor, outer_enum,
                                            trial_time, sampling_rate)
            for outer_coor, outer_enum in outer_arms]
    positions, arm_labels, dir_labels = zip(*runs)
    return np.concatenate(positions), np.concatenate(arm_labels), np.concatenate(dir_labels)
    

In [6]:
# Simulate one W-track series (center<->left run, then center<->right run) at
# constant speed; `trial_time` is in the time unit of encode_settings.sampling_rate.
trial_time = 10

pos_series, pos_series_enum, pos_series_dir_enum = wtrack_simulate_series_const_speed(
    encode_settings.wtrack_arm_coordinates, trial_time, encode_settings.sampling_rate)

pos_epoch = np.asarray(pos_series, dtype=float)

# Speed is constant by construction, so the first finite difference holds for every
# sample.  The denominator is the sample period (1 / sampling_rate).
pos_vel = np.ones(len(pos_epoch)) * (pos_epoch[1] - pos_epoch[0]) / (1 / encode_settings.sampling_rate)

# Duplicate: num_dup successive doublings == repeating the series 2**num_dup times.
num_dup = 2
num_repeats = 2 ** num_dup
pos_epoch = np.tile(pos_epoch, num_repeats)
pos_vel = np.tile(pos_vel, num_repeats)
pos_series_enum = np.tile(pos_series_enum, num_repeats)
pos_series_dir_enum = np.tile(pos_series_dir_enum, num_repeats)

# Derive timestamps and times from the sample index, so both always have exactly
# len(pos_epoch) entries.  np.arange with a float step (1 / sampling_rate) can be
# off by one sample.
pos_timestamp = np.arange(len(pos_epoch))
pos_time = pos_timestamp / encode_settings.sampling_rate

assert len(pos_time) == len(pos_epoch) == len(pos_vel) == len(pos_series_enum) == len(pos_series_dir_enum)


In [7]:
# Wrap the simulated trajectory as a single-epoch FlatLinearPosition.
# NOTE(review): the two leading 1s are presumably day/epoch identifiers - confirm
# against FlatLinearPosition.from_numpy_single_epoch.
linpos_flat = FlatLinearPosition.from_numpy_single_epoch(1, 1, pos_timestamp, pos_epoch, pos_vel, 
                                                         encode_settings.sampling_rate, 
                                                         encode_settings.wtrack_arm_coordinates)

# Attach per-sample arm and running-direction labels (the Arm / Direction enum
# members produced by the simulator) as categorical columns aligned to the index.
linpos_flat['arm'] = pd.Series(index=linpos_flat.index, data=pos_series_enum, dtype='category')
linpos_flat['direction'] = pd.Series(index=linpos_flat.index, data=pos_series_dir_enum, dtype='category')

In [8]:
# Fill colour for each W-track arm.
arm_colormap = {'center': 'darkorange', 'left': 'pink', 'right': 'cyan'}

# Hatch pattern for each running direction: bokeh hatch names and the matching
# matplotlib hatch strings.
direction_hatchmap = {'forward': 'right_diagonal_line', 'reverse': 'left_diagonal_line'}
mpl_direction_hatchmap = {'forward': '/', 'reverse': '\\'}

def wtrack_linear_plot_hook(plot, element, arm, direction):
    """Render hook that colours and hatches one W-track segment polygon.

    Args:
        plot: HoloViews plot object handed to hooks.
        element: the element being rendered.
        arm: key into ``arm_colormap``.
        direction: key into ``direction_hatchmap`` / ``mpl_direction_hatchmap``.
    """
    backend = hv.Store.current_backend
    if backend == 'matplotlib':
        element.opts(hatch=mpl_direction_hatchmap[direction], facecolor=arm_colormap[arm], clone=False)
        return
    if backend != 'bokeh':
        return
    # bokeh: style the polygon glyph model directly.
    glyph = plot.handles['glyph']
    glyph.fill_color = arm_colormap[arm]
    glyph.hatch_pattern = direction_hatchmap[direction]
    glyph.line_color = None
    glyph.hatch_color = 'grey'

def wtrack_linear_plot_init_hook(plot, element, arm, direction):
    """Initial hook that sets facecolor/hatch for one W-track segment polygon (matplotlib only).

    Does nothing for the bokeh backend or any other backend.

    Args:
        plot: HoloViews plot object handed to hooks; its ``style`` is replaced.
        element: the element being rendered (unused).
        arm: key into ``arm_colormap``.
        direction: key into ``mpl_direction_hatchmap``.
    """
    if hv.Store.current_backend != 'matplotlib':
        return
    # Work on a copy instead of mutating the plot's current style options in place.
    # Set `hatch` through the dict so that an existing `hatch` entry is overwritten
    # rather than passed twice to opts.Polygons (which would raise TypeError).
    style_kwargs = dict(plot.style.kwargs)
    style_kwargs['facecolor'] = arm_colormap[arm]
    style_kwargs['hatch'] = mpl_direction_hatchmap[direction]
    plot.style = opts.Polygons(**style_kwargs)

def pos_hook(plot, element):
    """Placeholder render hook for position plots.

    The bokeh and matplotlib branches are both empty, so this currently has no effect.
    """
    if hv.Store.current_backend in ('bokeh', 'matplotlib'):
        pass  # no backend-specific styling yet
    
def pos_init_hook(plot, element):
    """Placeholder initial hook for position plots.

    The bokeh and matplotlib branches are both empty, so this currently has no effect.
    """
    if hv.Store.current_backend in ('bokeh', 'matplotlib'):
        pass  # no backend-specific setup yet
        
def wtrack_linear_plot_polygons(arm, direction, pos_time, w_coor):
    """Return a shaded rectangle marking one arm/direction segment of the W-track.

    The rectangle spans the full time axis (``pos_time[0]`` .. ``pos_time[-1]``) and
    the segment's position range ``w_coor[arm][direction].x1`` .. ``x2``.  Styling is
    attached through backend-specific hooks (``wtrack_linear_plot_hook`` for bokeh,
    ``wtrack_linear_plot_init_hook`` for matplotlib).  Other backends get an
    unstyled polygon; previously they raised UnboundLocalError.

    Args:
        arm: 'center', 'left' or 'right'.
        direction: 'forward' or 'reverse'.
        pos_time: sequence of sample times; only the first and last are used.
        w_coor: W-track coordinates indexed as ``w_coor[arm][direction]``.

    Returns:
        hv.Polygons
    """
    segment = w_coor[arm][direction]
    t_start, t_end = pos_time[0], pos_time[-1]
    time_total = t_end - t_start
    y_total = segment.x2 - segment.x1

    # hv.Box takes a centre point plus a (width, height) spec.
    box = hv.Box(time_total / 2 + t_start, y_total / 2 + segment.x1, (time_total, y_total))
    poly = hv.Polygons(box)

    backend = hv.Store.current_backend
    if backend == 'bokeh':
        hooks = [functools.partial(wtrack_linear_plot_hook, arm=arm, direction=direction)]
        return poly.opts(hooks=hooks)
    if backend == 'matplotlib':
        init_hooks = [functools.partial(wtrack_linear_plot_init_hook, arm=arm, direction=direction)]
        return poly.opts(initial_hooks=init_hooks)
    return poly

def plot_position(time, pos, color='royalblue', fig_size=400, frame_width=800, aspect=3):
    """Scatter plot of position over time, sized for the active HoloViews backend.

    Args:
        time: sample times (x).
        pos: positions (y).
        color: marker colour.
        fig_size: matplotlib figure size (percent); ignored for bokeh.
        frame_width: bokeh frame width in pixels; ignored for matplotlib.
        aspect: plot aspect ratio.

    Returns:
        hv.Scatter.  For backends other than bokeh/matplotlib the scatter is returned
        unstyled instead of implicitly returning None, which would break `*` overlays.
    """
    scatter = hv.Scatter((time, pos))
    backend = hv.Store.current_backend
    if backend == 'bokeh':
        return scatter.opts(color=color, frame_width=frame_width, aspect=aspect)
    if backend == 'matplotlib':
        return scatter.opts(color=color, fig_size=fig_size, aspect=aspect)
    return scatter
        


In [9]:
hv.output(backend='matplotlib')
plt.style.use('dark_background')
plt.rcParams['figure.facecolor'] = 'black'

w_coor = encode_settings.wtrack_arm_coordinates
# One shaded band per arm/direction segment, stacked in this order, with the
# simulated trajectory drawn on top.
segment_order = [('center', 'forward'), ('left', 'forward'), ('left', 'reverse'),
                 ('center', 'reverse'), ('right', 'forward'), ('right', 'reverse')]
first_arm, first_direction = segment_order[0]
wtrack_overlay = wtrack_linear_plot_polygons(arm=first_arm, direction=first_direction,
                                             pos_time=pos_time, w_coor=w_coor)
for seg_arm, seg_direction in segment_order[1:]:
    wtrack_overlay = wtrack_overlay * wtrack_linear_plot_polygons(arm=seg_arm, direction=seg_direction,
                                                                  pos_time=pos_time, w_coor=w_coor)
wtrack_overlay = wtrack_overlay * plot_position(pos_time, pos_epoch)
wtrack_overlay

In [10]:
# Decompose the W-track linear position for the encoders below.  Judging by the
# `linpos_cw` / `linpos_ccw` columns used later, this is presumably a clockwise /
# counter-clockwise representation - confirm in WtrackLinposDecomposer.
wtrack_decomposed = WtrackLinposDecomposer(linpos_flat, encode_settings)

In [11]:
# Encoding settings for the decomposed track: start from a copy of encode_settings
# and resize the position grid to the decomposed (clockwise) coordinate system.
encode_settings_decomp = AttrDict(encode_settings)
encode_settings_decomp.pos_num_bins = wtrack_decomposed.armcoord_cw_num_bins
encode_settings_decomp.pos_col_names = pos_col_format(range(encode_settings_decomp.pos_num_bins),
                                                      encode_settings_decomp.pos_num_bins)
encode_settings_decomp.pos_bins = np.arange(0, encode_settings_decomp.pos_num_bins)
# The small epsilon makes arange include the final edge (== pos_num_bins).
encode_settings_decomp.pos_bin_edges = np.arange(0, encode_settings_decomp.pos_num_bins + 0.0001, 1)
# Same kernel width as the undecomposed track, re-centred on the decomposed grid.
encode_settings_decomp.pos_kernel_std = encode_settings.pos_kernel_std
encode_settings_decomp.pos_kernel = sp.stats.norm.pdf(np.arange(0, encode_settings_decomp.pos_num_bins, 1),
                                                      encode_settings_decomp.pos_num_bins/2,
                                                      encode_settings_decomp.pos_kernel_std)
encode_settings_decomp.arm_coordinates = wtrack_decomposed.simple_armcoord_cw
encode_settings_decomp.wtrack_decomp_arm_coordinates = wtrack_decomposed.wtrack_armcoord_cw


In [12]:
# (x1, x2) bounds of the prev/main/next segment for every arm and direction of the
# decomposed clockwise coordinates; their overall min/max bound the y-axis below.
decon_segments = [
    (segment.x1, segment.x2)
    for arm_dirs in wtrack_decomposed.armcoord_cw.values()
    for direction_coords in arm_dirs.values()
    for segment in (direction_coords.prev, direction_coords.main, direction_coords.next)
]
all_decon_range = np.array(decon_segments)
decon_min, decon_max = all_decon_range.min(), all_decon_range.max()

In [13]:
hv.output(backend='bokeh', size=150)

# Decomposed clockwise position of every sample; the y-axis is limited to the
# full decomposed coordinate range (decon_min .. decon_max).
linpos_cw_kdims = [('samples', 'Samples'), ('linpos_cw', 'Position (cm)')]
linpos_cw_points = hv.Points(wtrack_decomposed.decomp_linpos['linpos_cw'], kdims=linpos_cw_kdims)
linpos_cw_points.opts(aspect=2, ylim=(decon_min, decon_max))

In [14]:
# Seed numpy's global RNG so the simulated units (and the spikes drawn from them in
# the next cell) are reproducible under Restart & Run All.
# NOTE(review): assumes TetrodeUniformUnitNormalGenerator draws from numpy's global
# RNG - confirm; if it keeps its own generator, this seed has no effect.
SIM_SEED = 0
np.random.seed(SIM_SEED)

tet_gen = TetrodeUniformUnitNormalGenerator(sampling_rate=encode_settings.sampling_rate,
                                            num_marks=4,
                                            num_units=200,
                                            mark_mean_range=(60, 200),
                                            mark_cov_range=(20, 40),
                                            firing_rate_range=(5, 40),
                                            pos_field_range=wtrack_decomposed.simple_main_armcoord_cw,
                                            pos_field_bins=wtrack_decomposed.simple_main_armcoord_bins_cw,
                                            pos_field_var_range=(2, 10))

In [15]:
# Simulate tetrode spikes over the decomposed clockwise position column.
# `spk_amps` feeds the encoders below; `unit_spks` is only used by TetrodeVisualizer.
spk_amps, unit_spks = tet_gen.simulate_tetrode_over_pos(wtrack_decomposed.decomp_linpos, col_name='linpos_cw')

In [16]:
# Visual check of the simulated marks: decomposed clockwise position against mark
# columns 'c00' and 'c01' (presumably the first two tetrode channels - confirm).
tet_viz = TetrodeVisualizer(spk_amps, wtrack_decomposed.decomp_linpos, unit_spks)
tet_viz.plot_color_3d_dynamic('linpos_cw', 'c00', 'c01')

In [17]:
%%time
#%%prun -r -s cumulative

# Setup encoding model and estimate the position distribution of each spike being encoded
encoder_cw = OfflinePPEncoder(linflat=wtrack_decomposed.decomp_linpos, enc_spk_amp=spk_amps, dec_spk_amp=spk_amps, 
                              encode_settings=encode_settings_decomp, decode_settings=decode_settings,
                              linflat_col_name='linpos_cw', chunk_size=15000, cuda=True)
encoder_ccw = OfflinePPEncoder(linflat=wtrack_decomposed.decomp_linpos, enc_spk_amp=spk_amps, dec_spk_amp=spk_amps, 
                               encode_settings=encode_settings_decomp, decode_settings=decode_settings,
                               linflat_col_name='linpos_ccw', chunk_size=15000, cuda=True)
observ_cw = encoder_cw.run_encoder()
observ_ccw = encoder_ccw.run_encoder()


In [18]:
hv.output(backend='matplotlib')

# Setup plot to visualize estimated position distribution (clockwise encoder output)
sel_distrib = observ_cw.loc[:, pos_col_format(0, encode_settings_decomp.pos_num_bins):
                            pos_col_format(encode_settings_decomp.pos_num_bins - 1,
                                           encode_settings_decomp.pos_num_bins)]
sel_pos = observ_cw.loc[:, 'position']
max_prob = sel_distrib.max().max()


def plot_observ(ind, distrib=sel_distrib, positions=sel_pos, x_max=2050, y_max=max_prob, window=5):
    """Overlay the estimated position distributions of up to `window` consecutive spikes.

    The data are bound as default arguments when the function is defined.  This
    DynamicMap therefore keeps showing the clockwise encoder output even after later
    cells rebind the globals ``sel_distrib`` / ``sel_pos`` / ``max_prob``.
    Rows past the end of ``distrib`` are skipped, so the last window may be partial.

    NOTE(review): ``x_max=2050`` is the hard-coded right edge of the x-axis;
    presumably the decomposed bin count - verify against
    encode_settings_decomp.pos_num_bins.
    """
    plot_list = []
    for row in range(ind, min(ind + window, len(distrib))):
        plot_list.append(hv.Curve(distrib.iloc[row], extents=(0, 0, x_max, y_max)))
        # Marker at the spike's actual position, drawn just above zero.
        plot_list.append(hv.Points((positions.iloc[row], [0.005])))
    return hv.Overlay(plot_list).opts(fig_size=400, aspect=3)


dmap = hv.DynamicMap(plot_observ, kdims=['ind'])
# Windows start every 5 rows; clipping in plot_observ makes the tail windows safe.
dmap.redim.values(ind=list(range(0, len(observ_cw), 5)))

In [19]:
def decomposed_pos_remap_to_wtrack(pos, decomposed_armcoord, wtrack_armcoord):
    """Map positions from the decomposed track's *main* segments back onto W-track coordinates.

    For every arm/direction, positions inside ``decomposed_armcoord[arm][dir].main``
    (inclusive on both ends) are shifted so that ``main.x1`` lands on
    ``wtrack_armcoord[arm][dir].x1``.  Where main ranges share an endpoint, the last
    matching segment in iteration order wins.

    Args:
        pos: array-like of positions in the decomposed coordinate system.
        decomposed_armcoord: mapping arm -> mapping direction -> object whose
            ``main`` segment exposes ``x1``/``x2``.
        wtrack_armcoord: mapping arm -> mapping direction -> object exposing ``x1``.

    Returns:
        float ndarray shaped like ``pos``; entries outside every main segment are NaN.

    Warns:
        InconsistentDataWarning: if any position could not be mapped.  It is reported
        at the caller's line.
    """
    pos = np.asarray(pos, dtype=float)
    wtrack_pos = np.full(pos.shape, np.nan)
    for arm_k, arm_v in decomposed_armcoord.items():
        for direct_k, direct_v in arm_v.items():
            main = direct_v.main
            dir_pos_ind = (pos >= main.x1) & (pos <= main.x2)
            wtrack_pos[dir_pos_ind] = pos[dir_pos_ind] - main.x1 + wtrack_armcoord[arm_k][direct_k].x1

    if np.isnan(wtrack_pos).any():
        warnings.warn('Position in decomposed main coordinate system does not match wtrack coordinates.',
                      InconsistentDataWarning, stacklevel=2)

    return wtrack_pos

In [20]:
# Combine the clockwise and counter-clockwise encoders with the decomposition,
# presumably mapping results back to W-track coordinates (confirm in
# WtrackLinposRecomposer).  Later cells read `.trans_mat`, `.observ` and `.prob_no_spike`.
wtrack_recomposed = WtrackLinposRecomposer(encoder_cw, encoder_ccw, wtrack_decomposed, encode_settings)

In [21]:
hv.output(backend='bokeh')
# Transition-matrix heatmap: rows are flipped, then the y-axis is inverted for display.
# Bounds come from the matrix shape (x = columns, y = rows), so the axes stay in bin
# units without a hard-coded 800.
trans_rows, trans_cols = wtrack_recomposed.trans_mat.shape
hv.Image(np.flip(wtrack_recomposed.trans_mat, axis=0), bounds=(0, 0, trans_cols, trans_rows)).opts(invert_yaxis=True)


In [22]:
# Upper triangle of the transition matrix, displayed the same way as the full matrix.
# Bounds come from the matrix shape instead of a hard-coded 800.
trans_rows, trans_cols = wtrack_recomposed.trans_mat.shape
hv.Image(np.flip(np.triu(wtrack_recomposed.trans_mat), axis=0), bounds=(0, 0, trans_cols, trans_rows)).opts(invert_yaxis=True)


In [23]:
hv.output(backend='matplotlib')

# Setup plot to visualize estimated position distribution (recomposed W-track observations)
sel_distrib = wtrack_recomposed.observ.loc[:, pos_col_format(0, encode_settings.pos_num_bins):
                                           pos_col_format(encode_settings.pos_num_bins - 1,
                                                          encode_settings.pos_num_bins)]
sel_pos = wtrack_recomposed.observ.loc[:, 'position']
max_prob = sel_distrib.max().max()


def plot_observ(ind, distrib=sel_distrib, positions=sel_pos,
                x_max=encode_settings.pos_num_bins, y_max=max_prob, window=5):
    """Overlay the estimated position distributions of up to `window` consecutive spikes.

    The data are bound as default arguments at definition time, so the plot does not
    depend on later rebinding of the globals ``sel_distrib`` / ``sel_pos`` / ``max_prob``.
    Rows past the end of ``distrib`` are skipped.  Previously the last slider values
    (``ind`` near ``len(distrib)``) raised IndexError via ``iloc[ind + ii]``.
    """
    plot_list = []
    for row in range(ind, min(ind + window, len(distrib))):
        plot_list.append(hv.Curve(distrib.iloc[row], extents=(0, 0, x_max, y_max)))
        # Marker at the spike's actual position, drawn just above zero.
        plot_list.append(hv.Points((positions.iloc[row], [0.005])))
    return hv.Overlay(plot_list).opts(fig_size=400, aspect=3)


dmap = hv.DynamicMap(plot_observ, kdims=['ind'])
dmap.redim.values(ind=list(range(0, len(sel_distrib), 5)))

In [24]:
%%time
# Run PP decoding algorithm
time_bin_size = 10

decoder = OfflinePPDecoder(observ_obj=wtrack_recomposed.observ, trans_mat=wtrack_recomposed.trans_mat, 
                           prob_no_spike=wtrack_recomposed.prob_no_spike,
                           encode_settings=encode_settings, decode_settings=decode_settings, 
                           time_bin_size=time_bin_size)

posteriors = decoder.run_decoder()

In [25]:
hv.output(backend='bokeh')

# Decoded posterior visualisation; NaN posteriors are shown as zero probability.
# The true linear position is passed in, presumably to be overlaid on the heatmap.
filled_posteriors = posteriors.fillna(0)
dec_viz = DecodeVisualizer(filled_posteriors, linpos=linpos_flat, enc_settings=encode_settings, heatmap_max=0.15)

# RangeXY stream so the dynamic plot can respond to the current x/y view range.
xy_range_stream = hv.streams.RangeXY()
dmap = dec_viz.plot_all_dynamic(stream=xy_range_stream)
dmap.opts(aspect=3, frame_width=600)

dmap