In [1]:
# Make the working source tree importable so spykshrk does not have to be installed via setuptools:
# the directory `root_depth` levels above this notebook is placed on sys.path exactly once.
import sys, os
root_depth = 2
notebook_dir = globals()['_dh'][0]
root_path = os.path.abspath(os.path.join(notebook_dir, *(['..'] * root_depth)))
# Drop every existing copy of root_path (in place, keeping the same sys.path list), then append it.
sys.path[:] = [entry for entry in sys.path if entry != root_path]
sys.path.append(root_path)
# Alternative: make the root the current working directory instead -> os.chdir(root_path)

In [2]:
import enum
import functools
import itertools
import warnings

import IPython
import holoviews as hv
import numpy as np
import pandas as pd
import scipy as sp
import scipy.stats  # ensures `sp.stats` is loaded on SciPy versions without lazy submodule import
import torch
from holoviews import dim, opts

from IPython.utils.text import columnize
from IPython.lib.pretty import pprint, pretty
from IPython.display import display_html

# Spykshrk modules for data analysis
from spykshrk.franklab.data_containers import FlatLinearPosition, SpikeFeatures, \
        EncodeSettings, pos_col_format, SpikeObservation
from spykshrk.franklab.generator.place.data_generator import UnitNormalGenerator, TetrodeUniformUnitNormalGenerator
from spykshrk.franklab.pp_decoder.util import normal_pdf_int_lookup, gaussian
from spykshrk.franklab.pp_decoder.pp_clusterless import OfflinePPEncoder, OfflinePPDecoder
from spykshrk.franklab.pp_decoder.visualization import DecodeVisualizer
from spykshrk.franklab.visualization import LinPosVisualizer, TetrodeVisualizer
from spykshrk.franklab.pp_decoder.util import apply_no_anim_boundary
from spykshrk.util import AttrDict



# Load both HoloViews plotting backends, bokeh first and then matplotlib.
hv.extension('bokeh')
hv.extension('matplotlib')

# Dark styling for both backends.
hv.renderer('bokeh').theme = "dark_minimal"
import matplotlib.pyplot as plt
plt.style.use('dark_background')
plt.rcParams.update({'figure.facecolor': 'black'})

# Compact pandas output: 4-digit precision, at most 10 rows and 15 columns.
pd.set_option('display.precision', 4,
              'display.max_rows', 10,
              'display.max_columns', 15)



In [3]:
class InconsistentDataWarning(Warning):
    """Warning raised when positions cannot be mapped consistently between coordinate systems."""

In [4]:
# Encoding and decoding settings shared by the simulator and the decoding algorithm.

@enum.unique
class Arm(enum.Enum):
    """W-track arm labels attached to simulated position samples."""
    center = 1
    left = 2
    right = 3

@enum.unique
class Direction(enum.Enum):
    """Running-direction labels attached to simulated position samples."""
    forward = 1
    reverse = 2

# Values that several encode settings below must agree on; change them here only.
POS_NUM_BINS = 800    # number of unit-width position bins on the linearised W-track
POS_KERNEL_STD = 3    # std (in bins) of the position kernel, also stored as 'pos_kernel_std'

encode_settings = AttrDict({'sampling_rate': 1000,
                            'pos_bins': np.arange(0, POS_NUM_BINS, 1),
                            'pos_bin_edges': np.arange(0, POS_NUM_BINS + 0.1, 1),
                            'pos_bin_delta': 1,
                            # Gaussian kernel centred on the middle of the track.
                            'pos_kernel': sp.stats.norm.pdf(np.arange(0, POS_NUM_BINS, 1),
                                                            POS_NUM_BINS / 2, POS_KERNEL_STD),
                            'pos_kernel_std': POS_KERNEL_STD,
                            'mark_kernel_std': 20,
                            'pos_num_bins': POS_NUM_BINS,
                            'pos_col_names': [pos_col_format(ii, POS_NUM_BINS) for ii in range(POS_NUM_BINS)],
                            # The six segments of wtrack_arm_coordinates below, as [x1, x2] sorted by x1.
                            'arm_coordinates': [[0, 100], [100, 250], [250, 400], [400, 500], [500, 650], [650, 800]],
                            'wtrack_arm_coordinates': AttrDict(center=AttrDict(forward=AttrDict(x1=0, x2=100, len=100),
                                                                               reverse=AttrDict(x1=400, x2=500, len=100)),
                                                               left=AttrDict(forward=AttrDict(x1=100, x2=250, len=150),
                                                                             reverse=AttrDict(x1=250, x2=400, len=150)),
                                                               right=AttrDict(forward=AttrDict(x1=500, x2=650, len=150),
                                                                              reverse=AttrDict(x1=650, x2=800, len=150))),
                            'vel': 3,
                            'spk_amp': 60})

decode_settings = AttrDict({'trans_smooth_std': 5,
                            'trans_uniform_gain': 0.001,
                            'time_bin_size': 10})
                            

In [5]:
def wtrack_simulate_arm_const_speed(arm_range, trial_time, trial_len, sampling_rate):
    """Sample positions along one arm segment at constant speed.

    The whole trial (length ``trial_len``) takes ``trial_time``; this arm therefore takes the
    proportional share of that time, and positions are emitted every 1/``sampling_rate``.

    Args:
        arm_range: segment with ``x1``, ``x2`` and ``len`` attributes.
        trial_time: duration of the whole trial.
        trial_len: total length of the trial the arm belongs to.
        sampling_rate: samples per unit time.

    Returns:
        ``np.arange(arm_range.x1, arm_range.x2, step)`` with the constant-speed step.
    """
    time_in_arm = trial_time * arm_range.len / trial_len
    step = arm_range.len / sampling_rate / time_in_arm
    return np.arange(arm_range.x1, arm_range.x2, step)

def wtrack_simulate_trial_const_speed(arm1_range, arm1_enum, arm2_range, arm2_enum, trial_time, sampling_rate):
    """Simulate one trial: traverse arm1 then arm2 at a single constant speed.

    Both arms share the speed implied by covering ``arm1_range.len + arm2_range.len`` in
    ``trial_time``.

    Returns:
        (positions, arm labels) as 1-D numpy arrays of equal length; each label array entry is
        the enum of the arm the corresponding sample lies in.
    """
    combined_len = arm1_range.len + arm2_range.len
    # A leading empty float array keeps the dtype of the original np.append accumulation.
    position_parts = [np.array([])]
    label_parts = [np.array([])]
    for seg_range, seg_enum in ((arm1_range, arm1_enum), (arm2_range, arm2_enum)):
        seg_pos = wtrack_simulate_arm_const_speed(seg_range, trial_time, combined_len, sampling_rate)
        position_parts.append(seg_pos)
        label_parts.append(np.tile(seg_enum, len(seg_pos)))
    return np.concatenate(position_parts), np.concatenate(label_parts)

def wtrack_simulate_run_const_speed(arm1_dir, arm1_enum, arm2_dir, arm2_enum, trial_time, sampling_rate):
    """Simulate an out-and-back run between two arms.

    The outbound leg goes arm1 -> arm2 over the arms' ``forward`` segments and is labelled
    ``Direction.forward``. The return leg goes arm2 -> arm1 over their ``reverse`` segments and
    is labelled ``Direction.reverse``.

    Returns:
        (positions, arm labels, direction labels) as 1-D numpy arrays of equal length.
    """
    legs = ((arm1_dir.forward, arm1_enum, arm2_dir.forward, arm2_enum, Direction.forward),
            (arm2_dir.reverse, arm2_enum, arm1_dir.reverse, arm1_enum, Direction.reverse))
    pos_chunks, arm_chunks, dir_chunks = [np.array([])], [np.array([])], [np.array([])]
    for first_range, first_enum, second_range, second_enum, leg_direction in legs:
        leg_pos, leg_arm = wtrack_simulate_trial_const_speed(first_range, first_enum,
                                                             second_range, second_enum,
                                                             trial_time, sampling_rate)
        pos_chunks.append(leg_pos)
        arm_chunks.append(leg_arm)
        dir_chunks.append(np.tile(leg_direction, len(leg_pos)))
    return np.concatenate(pos_chunks), np.concatenate(arm_chunks), np.concatenate(dir_chunks)

def wtrack_simulate_series_const_speed(w_coor, trial_time, sampling_rate):
    """Simulate a full W-track series: a center<->left run followed by a center<->right run.

    Args:
        w_coor: W-track coordinates with ``center``/``left``/``right`` arms, each holding
            ``forward``/``reverse`` segments.
        trial_time: duration of each one-way trial.
        sampling_rate: samples per unit time.

    Returns:
        (positions, arm labels, direction labels) as 1-D numpy arrays of equal length.
    """
    outer_arms = ((w_coor.left, Arm.left), (w_coor.right, Arm.right))
    series_pos, series_arm, series_dir = [np.array([])], [np.array([])], [np.array([])]
    for outer_coor, outer_enum in outer_arms:
        run_pos, run_arm, run_dir = wtrack_simulate_run_const_speed(w_coor.center, Arm.center,
                                                                    outer_coor, outer_enum,
                                                                    trial_time, sampling_rate)
        series_pos.append(run_pos)
        series_arm.append(run_arm)
        series_dir.append(run_dir)
    return np.concatenate(series_pos), np.concatenate(series_arm), np.concatenate(series_dir)
    

In [6]:
# Simulate one constant-speed W-track series, then repeat it to build the full epoch.
trial_time = 10

pos_series, pos_series_enum, pos_series_dir_enum = wtrack_simulate_series_const_speed(encode_settings.wtrack_arm_coordinates,
                                                                                      trial_time, encode_settings.sampling_rate)

pos_epoch = pos_series.astype(float)  # float copy, so later edits never alias pos_series

# Constant velocity: position step between consecutive samples divided by the sample period.
sample_period = 1 / encode_settings.sampling_rate
pos_vel = np.ones(len(pos_epoch)) * (pos_epoch[1] - pos_epoch[0]) / sample_period

# Repeat the series: each of the num_dup doubling passes doubles the data, i.e. 2**num_dup copies.
num_dup = 2
num_repeats = 2 ** num_dup
pos_epoch = np.tile(pos_epoch, num_repeats)
pos_vel = np.tile(pos_vel, num_repeats)
pos_series_enum = np.tile(pos_series_enum, num_repeats)
pos_series_dir_enum = np.tile(pos_series_dir_enum, num_repeats)

# Derive sample indices and times from the integer index so their length always equals
# len(pos_epoch); a float-step np.arange can yield one element too many or too few.
pos_timestamp = np.arange(0, len(pos_epoch), 1)
pos_time = pos_timestamp / encode_settings.sampling_rate


In [7]:
linpos_flat = FlatLinearPosition.from_numpy_single_epoch(1, 1, pos_timestamp, pos_epoch, pos_vel,
                                                         encode_settings.sampling_rate,
                                                         encode_settings.wtrack_arm_coordinates)

# Attach the simulated arm and direction labels as categorical columns aligned to linpos_flat's index.
linpos_flat['arm'], linpos_flat['direction'] = (
    pd.Series(index=linpos_flat.index, data=labels, dtype='category')
    for labels in (pos_series_enum, pos_series_dir_enum))

In [8]:
# Fill colour per arm, and hatch pattern per direction for bokeh and for matplotlib.
arm_colormap = {'center': 'darkorange', 'left': 'pink', 'right': 'cyan'}

direction_hatchmap = {'forward': 'right_diagonal_line', 'reverse': 'left_diagonal_line'}
mpl_direction_hatchmap = {'forward': '/', 'reverse': '\\'}

def wtrack_linear_plot_hook(plot, element, arm, direction):
    """Post-render hook colouring a segment polygon by arm and hatching it by direction.

    On bokeh the rendered glyph is edited directly; on matplotlib the element's options are
    updated in place (``clone=False``). Other backends are left untouched.
    """
    backend = hv.Store.current_backend
    if backend == 'bokeh':
        glyph = plot.handles['glyph']
        glyph.fill_color = arm_colormap[arm]
        glyph.hatch_pattern = direction_hatchmap[direction]
        glyph.line_color = None
        glyph.hatch_color = 'grey'
    elif backend == 'matplotlib':
        element.opts(hatch=mpl_direction_hatchmap[direction], facecolor=arm_colormap[arm], clone=False)

def wtrack_linear_plot_init_hook(plot, element, arm, direction):
    """Initial hook setting a segment polygon's matplotlib face colour (by arm) and hatch (by direction).

    Does nothing on bokeh, where wtrack_linear_plot_polygons styles via wtrack_linear_plot_hook instead.
    """
    if hv.Store.current_backend == 'bokeh':
        return
    if hv.Store.current_backend == 'matplotlib':
        # Work on a copy rather than mutating the existing style mapping. Merge 'hatch' into the
        # same dict, so a style that already carries a hatch cannot raise
        # "got multiple values for keyword argument 'hatch'".
        style_kwargs = dict(plot.style.kwargs)
        style_kwargs['facecolor'] = arm_colormap[arm]
        style_kwargs['hatch'] = mpl_direction_hatchmap[direction]
        plot.style = opts.Polygons(**style_kwargs)

def pos_hook(plot, element):
    """Post-render hook for position plots; currently a no-op on every backend."""
    backend = hv.Store.current_backend
    if backend not in ('bokeh', 'matplotlib'):
        return None
    # Backend-specific styling for position plots would go here.
    return None
    
def pos_init_hook(plot, element):
    """Initial hook for position plots; currently a no-op on every backend."""
    backend = hv.Store.current_backend
    if backend not in ('bokeh', 'matplotlib'):
        return None
    # Backend-specific initial styling for position plots would go here.
    return None
        
def wtrack_linear_plot_polygons(arm, direction, pos_time, w_coor):
    """Return a polygon shading one arm/direction segment across the whole time axis.

    The box spans ``pos_time[0]..pos_time[-1]`` horizontally and the segment's ``x1..x2``
    vertically. Styling is attached per backend:
    - bokeh: a post-render hook (wtrack_linear_plot_hook);
    - matplotlib: an initial hook (wtrack_linear_plot_init_hook).
    Any other backend gets the unstyled polygon, where previously it failed with UnboundLocalError.

    Args:
        arm: arm key into ``w_coor`` ('center', 'left' or 'right').
        direction: direction key ('forward' or 'reverse').
        pos_time: time samples; only the first and last entries are used.
        w_coor: W-track coordinates with ``x1``/``x2`` per arm and direction.
    """
    t_start, t_end = pos_time[0], pos_time[-1]
    segment = w_coor[arm][direction]
    time_total = t_end - t_start
    time_center = time_total / 2 + t_start
    y_total = segment.x2 - segment.x1
    y_center = y_total / 2 + segment.x1

    poly = hv.Polygons(hv.Box(time_center, y_center, (time_total, y_total)))

    backend = hv.Store.current_backend
    if backend == 'bokeh':
        return poly.opts(hooks=[functools.partial(wtrack_linear_plot_hook, arm=arm, direction=direction)])
    if backend == 'matplotlib':
        return poly.opts(initial_hooks=[functools.partial(wtrack_linear_plot_init_hook,
                                                          arm=arm, direction=direction)])
    return poly

def plot_position(time, pos, color='royalblue', fig_size=400, frame_width=800, aspect=3):
    """Scatter plot of position against time, sized for the active HoloViews backend.

    ``frame_width`` is used only on bokeh and ``fig_size`` only on matplotlib. Other backends
    get the scatter with just ``color`` applied, where previously the function returned None.
    """
    scatter = hv.Scatter((time, pos))
    backend = hv.Store.current_backend
    if backend == 'bokeh':
        return scatter.opts(color=color, frame_width=frame_width, aspect=aspect)
    if backend == 'matplotlib':
        return scatter.opts(color=color, fig_size=fig_size, aspect=aspect)
    return scatter.opts(color=color)
        


In [9]:
hv.output(backend='matplotlib')
plt.style.use('dark_background')
plt.rcParams['figure.facecolor'] = 'black'

w_coor = encode_settings.wtrack_arm_coordinates
# Shade each segment, in the same layer order as before, and overlay the simulated trajectory on top.
functools.reduce(lambda overlay, layer: overlay * layer,
                 [wtrack_linear_plot_polygons(arm=seg_arm, direction=seg_dir, pos_time=pos_time, w_coor=w_coor)
                  for seg_arm, seg_dir in [('center', 'forward'), ('left', 'forward'), ('left', 'reverse'),
                                           ('center', 'reverse'), ('right', 'forward'), ('right', 'reverse')]]
                 + [plot_position(pos_time, pos_epoch)])

In [10]:
class WtrackTransRecomposer(AttrDict):
    """Container for a decomposed transition matrix.

    Currently it only stores the matrix; no recomposition logic is implemented yet (TODO).

    Args:
        decomposed_trans_mat: the decomposed transition matrix, stored unchanged.
    """
    def __init__(self, decomposed_trans_mat):
        # Initialise the AttrDict base before setting attributes, as WtrackLinposDecomposed does.
        # Skipping it leaves undone any setup that AttrDict.__init__ performs.
        super().__init__()
        self.decomposed_trans_mat = decomposed_trans_mat
        

In [11]:
class WtrackLinposDecomposed(AttrDict):
    """Decompose W-track linear position into per-segment ``prev | main | next`` windows.

    Two cyclic orderings of the six (arm, direction) segments of
    ``encode_settings.wtrack_arm_coordinates`` are used: clockwise (``cw``) and
    counter-clockwise (``ccw``). For each ordering, every segment is laid out in a new
    coordinate system between copies of the segments that precede (``prev``) and follow
    (``next``) it. The windows of successive segments are placed back to back.

    Everything is stored as attributes of this AttrDict:
    - the per-segment windows (``armcoord_cw``/``armcoord_ccw``);
    - their outer spans, bin ranges and bin counts;
    - the ``main``-only windows;
    - sorted simplified ranges with their bins;
    - ``sel_data[rot][window]``, whose ``decomposed``/``wtrack`` entries hold matching bin
      indices (``ind``) and column names (``col``, built with pos_col_format).

    Args:
        encode_settings: settings providing ``wtrack_arm_coordinates`` (per arm/direction
            ``x1``, ``x2``, ``len``) and ``pos_bin_delta``.
        bin_size: step used to enumerate bins of the simplified arm coordinates.
    """
    def __init__(self, encode_settings, bin_size=1):
        super().__init__()

        self.encode_settings = encode_settings
        self.wtrack_armcoord = self.encode_settings.wtrack_arm_coordinates

        self.bin_size = bin_size
        self.segment_order_cw = [('center', 'forward'),
                                 ('left', 'forward'),
                                 ('left', 'reverse'),
                                 ('right', 'forward'),
                                 ('right', 'reverse'),
                                 ('center', 'reverse')]
        self.segment_order_ccw = [('center', 'forward'),
                                  ('right', 'forward'),
                                  ('right', 'reverse'),
                                  ('left', 'forward'),
                                  ('left', 'reverse'),
                                  ('center', 'reverse')]
        self.armcoord_cw = self._segment_decomposed_with_buffer(self.segment_order_cw, self.wtrack_armcoord)
        self.armcoord_ccw = self._segment_decomposed_with_buffer(self.segment_order_ccw, self.wtrack_armcoord)
        self.armcoord_cw_num_bins = self._max_window_coord(self.armcoord_cw)
        self.armcoord_ccw_num_bins = self._max_window_coord(self.armcoord_ccw)
        self.wtrack_armcoord_cw = self._outer_spans(self.armcoord_cw)
        self.wtrack_armcoord_ccw = self._outer_spans(self.armcoord_ccw)
        self.cw_bin_range = self._bin_range(self.wtrack_armcoord_cw)
        self.cw_num_bins = self.cw_bin_range[1] - self.cw_bin_range[0]
        self.ccw_bin_range = self._bin_range(self.wtrack_armcoord_ccw)
        self.ccw_num_bins = self.ccw_bin_range[1] - self.ccw_bin_range[0]
        self.wtrack_armcoord_main_cw = self._main_windows(self.armcoord_cw)
        self.wtrack_armcoord_main_ccw = self._main_windows(self.armcoord_ccw)
        self.simple_armcoord_cw, self.simple_armcoord_bins_cw = \
                self._create_wtrack_decomposed_simple_armcoord(self.armcoord_cw, self.bin_size)
        self.simple_armcoord_ccw, self.simple_armcoord_bins_ccw = \
                self._create_wtrack_decomposed_simple_armcoord(self.armcoord_ccw, self.bin_size)
        self.simple_main_armcoord_cw, self.simple_main_armcoord_bins_cw = \
                self._create_wtrack_decomposed_simple_main_armcoord(self.armcoord_cw, self.bin_size)
        self.simple_main_armcoord_ccw, self.simple_main_armcoord_bins_ccw = \
                self._create_wtrack_decomposed_simple_main_armcoord(self.armcoord_ccw, self.bin_size)

        # Index/column selections per rotation and window. Attributes are looked up with getattr
        # rather than eval on a built source string.
        self.sel_data = AttrDict()
        for rot_k in ('cw', 'ccw'):
            decomp_armcoord = getattr(self, 'armcoord_' + rot_k)
            decomp_num_bins = getattr(self, rot_k + '_num_bins')
            self.sel_data[rot_k] = AttrDict(main=self._sel_main(decomp_armcoord, self.wtrack_armcoord,
                                                                decomp_num_bins, self.encode_settings))
            for ord_k in ('prev', 'next'):
                self.sel_data[rot_k][ord_k] = self._sel_prev_next(ord_k, decomp_armcoord, self.wtrack_armcoord,
                                                                  decomp_num_bins, self.encode_settings)

    @staticmethod
    def _max_window_coord(decomp_armcoord):
        """Largest x1/x2 value over every prev/main/next window."""
        return max(max(window.x1, window.x2)
                   for arm in decomp_armcoord.values()
                   for direct in arm.values()
                   for window in (direct.prev, direct.main, direct.next))

    @staticmethod
    def _outer_spans(decomp_armcoord):
        """Per arm/direction envelope (min x1, max x2) of its prev/main/next windows."""
        return AttrDict({arm: AttrDict({direct: AttrDict(x1=min(direct_v.prev.x1, direct_v.main.x1, direct_v.next.x1),
                                                         x2=max(direct_v.prev.x2, direct_v.main.x2, direct_v.next.x2))
                                        for direct, direct_v in arm_v.items()})
                         for arm, arm_v in decomp_armcoord.items()})

    @staticmethod
    def _bin_range(span_coord):
        """[min x1, max x2] over all arm/direction spans."""
        return [min(direct.x1 for arm in span_coord.values() for direct in arm.values()),
                max(direct.x2 for arm in span_coord.values() for direct in arm.values())]

    @staticmethod
    def _main_windows(decomp_armcoord):
        """Per arm/direction ``main`` window only."""
        return AttrDict({arm: AttrDict({direct: direct_v.main for direct, direct_v in arm_v.items()})
                         for arm, arm_v in decomp_armcoord.items()})

    @staticmethod
    def _ranges_to_bins(ranges, bin_size):
        """Concatenate np.arange(lo, hi, bin_size) over ranges.

        The leading empty float array keeps the float dtype, and the empty-input result, of an
        np.append accumulation.
        """
        return np.concatenate([np.array([])] + [np.arange(lo, hi, bin_size) for lo, hi in ranges])

    @staticmethod
    def _segment_decomposed_with_buffer(segment_order, wtrack_arm_coord):
        """Lay out each segment as back-to-back prev/main/next windows in traversal order.

        Returns:
            AttrDict arm -> direction -> AttrDict with ``prev``/``main``/``next`` windows
            (``x1``, ``x2``, ``len``) and the neighbouring segment keys ``prev_seg``/``next_seg``.
            If a segment appears more than once, the first layout wins (setdefault).
        """
        seg_decomposed = AttrDict()
        seg_offset = 0
        num_segs = len(segment_order)
        for ii, seg in enumerate(segment_order):
            # Neighbours in the cyclic order; index -1 wraps to the last segment.
            prev_seg = segment_order[ii - 1]
            next_seg = segment_order[(ii + 1) % num_segs]

            prev_seg_len = wtrack_arm_coord[prev_seg[0]][prev_seg[1]].len
            main_seg_len = wtrack_arm_coord[seg[0]][seg[1]].len
            next_seg_len = wtrack_arm_coord[next_seg[0]][next_seg[1]].len

            prev_seg_start = seg_offset
            main_seg_start = prev_seg_start + prev_seg_len
            next_seg_start = main_seg_start + main_seg_len
            next_seg_end = next_seg_start + next_seg_len
            seg_offset += prev_seg_len + main_seg_len + next_seg_len

            seg_decomposed.setdefault(seg[0], AttrDict()).setdefault(
                seg[1], AttrDict(prev=AttrDict(x1=prev_seg_start, x2=main_seg_start, len=prev_seg_len),
                                 main=AttrDict(x1=main_seg_start, x2=next_seg_start, len=main_seg_len),
                                 next=AttrDict(x1=next_seg_start, x2=next_seg_end, len=next_seg_len),
                                 prev_seg=prev_seg,
                                 next_seg=next_seg))

        return seg_decomposed

    @staticmethod
    def _create_wtrack_decomposed_simple_armcoord(decomp_armcoord, bin_size=1):
        """Outer (x1, x2) span of every segment, sorted by x1, plus the bins covering them."""
        simple_armcoord = sorted(((min(direct.prev.x1, direct.main.x1, direct.next.x1),
                                   max(direct.prev.x2, direct.main.x2, direct.next.x2))
                                  for arm in decomp_armcoord.values() for direct in arm.values()),
                                 key=lambda tup: tup[0])
        return simple_armcoord, WtrackLinposDecomposed._ranges_to_bins(simple_armcoord, bin_size)

    @staticmethod
    def _create_wtrack_decomposed_simple_main_armcoord(decomp_armcoord, bin_size=1):
        """``main`` (x1, x2) window of every segment, sorted by x1, plus the bins covering them."""
        simple_armcoord = sorted(((direct.main.x1, direct.main.x2)
                                  for arm in decomp_armcoord.values() for direct in arm.values()),
                                 key=lambda tup: tup[0])
        return simple_armcoord, WtrackLinposDecomposed._ranges_to_bins(simple_armcoord, bin_size)

    @staticmethod
    def _sel_main(decomp_armcoord, wtrack_armcoord, decomp_num_bins, encode_settings):
        """Matching decomposed/W-track bin indices and column names for every ``main`` window."""
        delta = encode_settings.pos_bin_delta
        main_sel = AttrDict(decomposed=AttrDict(), wtrack=AttrDict())
        main_sel.decomposed['ind'] = np.concatenate([np.arange(direct_v.main.x1, direct_v.main.x2, delta)
                                                     for arm_v in decomp_armcoord.values()
                                                     for direct_v in arm_v.values()])
        main_sel.decomposed['col'] = pos_col_format(main_sel.decomposed['ind'], decomp_num_bins)
        main_sel.wtrack['ind'] = np.concatenate([np.arange(wtrack_armcoord[arm][direct].x1,
                                                           wtrack_armcoord[arm][direct].x2, delta)
                                                 for arm, arm_v in decomp_armcoord.items()
                                                 for direct in arm_v.keys()])
        # NOTE(review): W-track columns are formatted with the decomposed bin count; confirm intended.
        main_sel.wtrack['col'] = pos_col_format(main_sel.wtrack['ind'], decomp_num_bins)

        return main_sel

    @staticmethod
    def _sel_prev_next(prev_next, decomp_armcoord, wtrack_armcoord, decomp_num_bins, encode_settings):
        """Like _sel_main, but for the ``prev`` or ``next`` buffers (``prev_next`` selects which).

        The W-track side uses the coordinates of the neighbouring segment the buffer copies.
        """
        delta = encode_settings.pos_bin_delta
        sel = AttrDict(decomposed=AttrDict(), wtrack=AttrDict())
        sel.decomposed['ind'] = np.concatenate([np.arange(direct_v[prev_next].x1, direct_v[prev_next].x2, delta)
                                                for arm_v in decomp_armcoord.values()
                                                for direct_v in arm_v.values()])
        sel.decomposed['col'] = pos_col_format(sel.decomposed['ind'], decomp_num_bins)
        # (arm, direction) key of the neighbouring segment each buffer copies, in the same order.
        neighbour_segs = [direct_v[prev_next + '_seg']
                          for arm_v in decomp_armcoord.values()
                          for direct_v in arm_v.values()]
        sel.wtrack['ind'] = np.concatenate([np.arange(wtrack_armcoord[seg_arm][seg_dir].x1,
                                                      wtrack_armcoord[seg_arm][seg_dir].x2, delta)
                                            for seg_arm, seg_dir in neighbour_segs])
        sel.wtrack['col'] = pos_col_format(sel.wtrack['ind'], decomp_num_bins)

        return sel

In [12]:
wtrack_decompose = WtrackLinposDecomposed(encode_settings=encode_settings, bin_size=1)

In [13]:
# Encoding settings for the clockwise decomposed coordinate system. This is a shallow copy of the
# W-track settings with every position-bin dependent entry resized to the decomposed track.
encode_settings_decomp = AttrDict(encode_settings)
encode_settings_decomp.pos_num_bins = wtrack_decompose.armcoord_cw_num_bins
encode_settings_decomp.pos_bins = np.arange(0, encode_settings_decomp.pos_num_bins)
encode_settings_decomp.pos_bin_edges = np.arange(0, encode_settings_decomp.pos_num_bins + 0.0001, 1)
encode_settings_decomp.pos_kernel_std = 3
# Gaussian kernel centred on the middle of the decomposed track.
encode_settings_decomp.pos_kernel = sp.stats.norm.pdf(np.arange(0, encode_settings_decomp.pos_num_bins, 1),
                                                      encode_settings_decomp.pos_num_bins/2,
                                                      encode_settings_decomp.pos_kernel_std)
# Assigned once; previously this identical assignment appeared twice.
encode_settings_decomp.pos_col_names = pos_col_format(range(encode_settings_decomp.pos_num_bins),
                                                      encode_settings_decomp.pos_num_bins)
encode_settings_decomp.arm_coordinates = wtrack_decompose.simple_armcoord_cw
encode_settings_decomp.wtrack_decomp_arm_coordinates = wtrack_decompose.wtrack_armcoord_cw


In [14]:
def wtrack_remap_to_decomposed(linpos_flat, wtrack_arm_coord, wtrack_decomposed_cw, wtrack_decomposed_ccw):
    """Add clockwise and counter-clockwise decomposed positions to a W-track position table.

    Rows are grouped by their ``arm``/``direction`` labels. The group keys must expose ``.name``
    matching the keys of the coordinate mappings. Each sample keeps its offset from the start of
    its W-track segment, re-based onto that segment's ``main`` window in each decomposition.

    Args:
        linpos_flat: table with ``linpos_flat``, ``arm`` and ``direction`` columns.
        wtrack_arm_coord: W-track coordinates (``x1`` per arm/direction).
        wtrack_decomposed_cw: clockwise decomposition (``main.x1`` per arm/direction).
        wtrack_decomposed_ccw: counter-clockwise decomposition (``main.x1`` per arm/direction).

    Returns:
        Plain ``pd.DataFrame`` with all input columns plus ``linpos_cw`` and ``linpos_ccw``,
        sorted by index. It is empty when the input has no groups.
    """
    pieces = []
    for (arm, direction), table in linpos_flat.groupby(['arm', 'direction']):
        arm_coord_range = wtrack_arm_coord[arm.name][direction.name]
        decomposed_range_cw = wtrack_decomposed_cw[arm.name][direction.name]
        decomposed_range_ccw = wtrack_decomposed_ccw[arm.name][direction.name]
        # Plain DataFrame copy, so the result type does not depend on linpos_flat's subclass.
        decomposed_table = pd.DataFrame(table).copy()
        segment_offset = decomposed_table['linpos_flat'] - arm_coord_range.x1
        decomposed_table['linpos_cw'] = segment_offset + decomposed_range_cw.main.x1
        decomposed_table['linpos_ccw'] = segment_offset + decomposed_range_ccw.main.x1
        pieces.append(decomposed_table)

    if not pieces:
        return pd.DataFrame()
    # A single concat instead of repeated DataFrame.append, which is quadratic and was removed
    # in pandas 2.0.
    return pd.concat(pieces).sort_index()

In [15]:
decomposed_linpos = wtrack_remap_to_decomposed(linpos_flat=linpos_flat,
                                               wtrack_arm_coord=encode_settings.wtrack_arm_coordinates,
                                               wtrack_decomposed_cw=wtrack_decompose.armcoord_cw,
                                               wtrack_decomposed_ccw=wtrack_decompose.armcoord_ccw)

In [16]:
# [x1, x2] of every prev/main/next window of the clockwise decomposition; the overall min and
# max set the y-limits of the plot below.
all_decon_range = np.array([[window.x1, window.x2]
                            for arm_coords in wtrack_decompose.armcoord_cw.values()
                            for dir_coords in arm_coords.values()
                            for window in (dir_coords.prev, dir_coords.main, dir_coords.next)])
decon_min, decon_max = all_decon_range.min(), all_decon_range.max()

In [17]:
hv.output(backend='bokeh', size=150)

# Clockwise decomposed position per sample, with y-limits covering every decomposed window.
hv.Points(decomposed_linpos['linpos_cw'],
          kdims=[('samples', 'Samples'), ('linpos_cw', 'Position (cm)')]).opts(ylim=(decon_min, decon_max),
                                                                              aspect=2)

In [18]:
pprint(wtrack_decompose.simple_main_armcoord_ccw, max_seq_length=100000, newline=' ')

In [19]:
# Seed NumPy's global RNG so the simulated tetrode, and the spikes drawn from it in the next cell,
# are reproducible under Restart & Run All.
# NOTE(review): assumes the spykshrk generators draw from NumPy's global RNG - confirm.
SIM_SEED = 0
np.random.seed(SIM_SEED)

tet_gen = TetrodeUniformUnitNormalGenerator(sampling_rate=encode_settings.sampling_rate,
                                            num_marks=4,
                                            num_units=200,
                                            mark_mean_range=(60, 200),
                                            mark_cov_range=(20, 40),
                                            firing_rate_range=(5, 40),
                                            pos_field_range=wtrack_decompose.simple_main_armcoord_cw,
                                            pos_field_bins=wtrack_decompose.simple_main_armcoord_bins_cw,
                                            pos_field_var_range=(2,10))

In [20]:
[spk_amps, unit_spks] = tet_gen.simulate_tetrode_over_pos(decomposed_linpos, col_name='linpos_cw')

In [21]:
tet_viz = TetrodeVisualizer(spk_amps,
                            decomposed_linpos,
                            unit_spks)
# Position column 'linpos_cw' plotted against mark channels 'c00' and 'c01'.
tet_viz.plot_color_3d_dynamic('linpos_cw',
                              'c00', 'c01')

In [22]:
# Display the simulated spike amplitudes; the rendered table is truncated by the pandas display options set in the setup cell.
spk_amps

In [54]:
torch.cuda.get_device_properties(device=0)

In [52]:
# Scratch cell: an incomplete "torch.cuda." expression (apparently an interactive leftover) was removed
# because it is a SyntaxError that stops Restart & Run All.

In [56]:

def _print_cuda_memory_gib():
    """Print allocated, reserved, peak-allocated and peak-reserved CUDA memory, in GiB."""
    gib = 2 ** 30
    # torch renamed memory_cached -> memory_reserved; use the new name when available and fall
    # back to the deprecated one on older versions.
    reserved = getattr(torch.cuda, 'memory_reserved', None) or torch.cuda.memory_cached
    max_reserved = getattr(torch.cuda, 'max_memory_reserved', None) or torch.cuda.max_memory_cached
    print(torch.cuda.memory_allocated() / gib, reserved() / gib,
          torch.cuda.max_memory_allocated() / gib, max_reserved() / gib)

if torch.cuda.is_available():
    _print_cuda_memory_gib()
    torch.cuda.empty_cache()  # release cached, unused blocks, then report again
    _print_cuda_memory_gib()
else:
    print('CUDA is not available; no GPU memory statistics to report.')


In [27]:
%%prun -r -s cumulative
# Build one encoding model per decomposed orientation (cw and ccw) and estimate the position
# distribution of every encoded spike. Both share all settings except the position column.
_encoder_kwargs = dict(linflat=decomposed_linpos, enc_spk_amp=spk_amps, dec_spk_amp=spk_amps,
                       encode_settings=encode_settings_decomp, decode_settings=decode_settings,
                       chunk_size=20000, cuda=True)
encoder_cw = OfflinePPEncoder(linflat_col_name='linpos_cw', **_encoder_kwargs)
encoder_ccw = OfflinePPEncoder(linflat_col_name='linpos_ccw', **_encoder_kwargs)
observ_cw = encoder_cw.run_encoder()
observ_ccw = encoder_ccw.run_encoder()


In [167]:
hv.output(backend='matplotlib')

# Consecutive observations overlaid in each frame of the DynamicMap below.
OBSERV_PER_FRAME = 5

# Estimated position distribution (one column per decomposed position bin) and true position of each observation.
sel_distrib = observ_cw.loc[:, pos_col_format(0, encode_settings_decomp.pos_num_bins):
                            pos_col_format(encode_settings_decomp.pos_num_bins - 1,
                                           encode_settings_decomp.pos_num_bins)]

sel_pos = observ_cw.loc[:, 'position']
max_prob = sel_distrib.max().max()

def plot_observ(ind):
    """Overlay the distributions and true positions of observations ind .. ind + OBSERV_PER_FRAME - 1."""
    plot_list = []
    for ii in range(OBSERV_PER_FRAME):
        # The x extent follows the decomposed bin count, replacing a stale hard-coded 2050.
        plot_list.append(hv.Curve(sel_distrib.iloc[ind + ii],
                                  extents=(0, 0, encode_settings_decomp.pos_num_bins, max_prob)))
        plot_list.append(hv.Points((sel_pos.iloc[ind + ii], [0.005])))
    return hv.Overlay(plot_list).opts(fig_size=400, aspect=3)

dmap = hv.DynamicMap(plot_observ, kdims=['ind'])
# The last valid frame starts at len - OBSERV_PER_FRAME, so the range stop is one past it.
dmap.redim.values(ind=list(range(0, len(observ_cw) - OBSERV_PER_FRAME + 1, OBSERV_PER_FRAME)))

In [25]:
def decomposed_remap_to_wtrack(pos, decomposed_armcoord, wtrack_armcoord):
    """Map positions from a decomposed coordinate system back onto W-track coordinates.

    Positions inside a segment's ``main`` window ``[main.x1, main.x2)`` keep their offset from the
    window start and are re-based onto that segment's W-track ``x1``. Windows are half-open to
    match the ``np.arange(x1, x2, ...)`` ranges used to build them: ``main.x2`` is the first bin of
    the following ``next`` buffer. Positions in no main window (buffers, out of range, NaN) become
    NaN, and an ``InconsistentDataWarning`` is emitted.

    Args:
        pos: array-like of decomposed positions (any shape).
        decomposed_armcoord: arm -> direction -> object with a ``main`` window (``x1``, ``x2``).
        wtrack_armcoord: arm -> direction -> W-track segment with ``x1``.

    Returns:
        float ndarray with the same shape as ``pos``.
    """
    pos = np.asarray(pos)
    wtrack_pos = np.full(pos.shape, np.nan)
    for arm_k, arm_v in decomposed_armcoord.items():
        for direct_k, direct_v in arm_v.items():
            main = direct_v.main
            in_main = (pos >= main.x1) & (pos < main.x2)
            wtrack_pos[in_main] = pos[in_main] - main.x1 + wtrack_armcoord[arm_k][direct_k].x1

    if np.isnan(wtrack_pos).any():
        warnings.warn('Position in decomposed main coordinate system does not match wtrack coordinates.', InconsistentDataWarning)

    return wtrack_pos

In [152]:
class WtrackRecomposer(AttrDict):
    """Recombine the cw / ccw decomposed encodings into W-track-sized decoder inputs.

    ``wtrack_decomposed.sel_data[rotation][section]`` is expected to provide, for each
    rotation in ``rotations`` and each section ('main' plus the entries of ``orders``),
    matching selections into the decomposed coordinate system
    (``['decomposed']['ind']`` / ``['decomposed']['col']``) and into the W-track
    coordinate system (``['wtrack']['ind']``).

    After construction the instance holds:
        observ: recomposed observation table (``SpikeObservation``) whose
            ``'position'`` column is remapped to W-track coordinates.
        prob_no_spike: dict of tetrode id -> length ``pos_num_bins`` array.
        trans_mat: ``(pos_num_bins, pos_num_bins)`` transition matrix.
    """
    # Decomposition rotations and the non-'main' track sections that get stitched back in.
    rotations = ['cw', 'ccw']
    orders = ['prev', 'next']

    def __init__(self, encoder_cw, encoder_ccw, wtrack_decomposed, encode_settings):
        """Recompose observations, no-spike probabilities and transition matrices.

        Args:
            encoder_cw, encoder_ccw: encoders whose ``observ_obj``, ``prob_no_spike``
                and ``trans_mat['learned']`` are read.
            wtrack_decomposed: decomposition object providing ``sel_data`` (and, for the
                observation remap, ``encode_settings``, ``armcoord_cw`` and
                ``wtrack_armcoord``).
            encode_settings: W-track encode settings (``pos_num_bins``, ``pos_col_names``).
        """
        self.encoder_cw = encoder_cw
        self.encoder_ccw = encoder_ccw
        self.wtrack_decomposed = wtrack_decomposed
        self.encode_settings = encode_settings
        self.observ = self._wtrack_recompose_observ(encoder_cw.observ_obj, encoder_ccw.observ_obj,
                                                    self.wtrack_decomposed, self.encode_settings)
        self.prob_no_spike = self._wtrack_recompose_prob_no_spike(encoder_cw.prob_no_spike,
                                                                  encoder_ccw.prob_no_spike,
                                                                  self.wtrack_decomposed, self.encode_settings)
        self.trans_mat = self._wtrack_recompose_trans_mat(encoder_cw.trans_mat['learned'],
                                                          encoder_ccw.trans_mat['learned'],
                                                          self.wtrack_decomposed, self.encode_settings)

    @staticmethod
    def _wtrack_recompose_prob_no_spike(prob_no_spike_cw, prob_no_spike_ccw,
                                        wtrack_decomposed, encoding_settings):
        """Scatter each tetrode's cw 'main'-section no-spike probabilities onto W-track bins.

        Uses the ``wtrack_decomposed`` / ``encoding_settings`` arguments (previously the
        notebook globals ``wtrack_decompose`` / ``encode_settings`` were read instead).
        W-track bins not covered by the cw 'main' selection stay 0.

        NOTE(review): ``prob_no_spike_ccw`` is accepted but not used, matching the
        original behavior — confirm that only the cw 'main' section is intended.

        Returns:
            dict mapping tetrode id -> array of length ``encoding_settings.pos_num_bins``.
        """
        main_sel = wtrack_decomposed.sel_data['cw']['main']
        wtrack_ind = main_sel['wtrack']['ind']
        decomp_ind = main_sel['decomposed']['ind']

        prob_no_spike = {}
        for tet_id, prob_no_spike_tet in prob_no_spike_cw.items():
            recomposed = np.zeros(encoding_settings.pos_num_bins)
            recomposed[wtrack_ind] = prob_no_spike_tet[decomp_ind]
            prob_no_spike[tet_id] = recomposed
        return prob_no_spike

    @staticmethod
    def _wtrack_recompose_trans_mat(trans_mat_cw, trans_mat_ccw,
                                    wtrack_decomposed, encode_settings):
        """Sum the per-rotation section blocks of the decomposed transition matrices.

        For each rotation this adds the main->main block plus, for every entry of
        ``orders``, both the (order, main) and (main, order) blocks. The main->main
        block is added once per rotation, so cw and ccw contributions are summed there.

        Returns:
            ``(pos_num_bins, pos_num_bins)`` numpy array.
        """
        # Explicit lookup instead of eval('trans_mat_' + rot_k).
        decomp_trans_mats = {'cw': trans_mat_cw, 'ccw': trans_mat_ccw}
        trans_mat = np.zeros([encode_settings.pos_num_bins] * 2)
        for rot_k in WtrackRecomposer.rotations:
            rot_trans_mat = decomp_trans_mats[rot_k]
            trans_mat += WtrackRecomposer._wtrack_recompose_trans_mat_part(
                rot_k, 'main', 'main', rot_trans_mat, wtrack_decomposed, encode_settings)
            for ord_k in WtrackRecomposer.orders:
                # Yields (ord_k, 'main') and ('main', ord_k): both cross-section blocks.
                for order1, order2 in itertools.permutations([ord_k, 'main']):
                    trans_mat += WtrackRecomposer._wtrack_recompose_trans_mat_part(
                        rot_k, order1, order2, rot_trans_mat, wtrack_decomposed, encode_settings)
        return trans_mat

    @staticmethod
    def _wtrack_recompose_trans_mat_part(rotation, order1, order2, decomp_trans_mat,
                                         wtrack_decomposed, encode_settings):
        """Place one section block of a decomposed transition matrix into W-track bins.

        Returns a zero ``(pos_num_bins, pos_num_bins)`` matrix in which entry
        ``[wtrack_ind2[i], wtrack_ind1[j]]`` is set to
        ``decomp_trans_mat[decomp_ind2[i], decomp_ind1[j]]``, where ``*_ind1`` come from
        section ``order1`` and ``*_ind2`` from section ``order2`` of ``rotation``.
        """
        sel = wtrack_decomposed.sel_data[rotation]
        trans_mat = np.zeros([encode_settings.pos_num_bins] * 2)
        # np.ix_ builds the same (order2 x order1) cross-product gather the previous
        # meshgrid-based indexing produced.
        trans_mat[np.ix_(sel[order2]['wtrack']['ind'], sel[order1]['wtrack']['ind'])] = \
            decomp_trans_mat[np.ix_(sel[order2]['decomposed']['ind'], sel[order1]['decomposed']['ind'])]
        return trans_mat

    @staticmethod
    def _wtrack_recompose_observ(observ_cw, observ_ccw, wtrack_decompose, encode_settings):
        """Build the W-track observation table from the cw / ccw observation tables.

        The cw 'main' section fills the base distribution; every ``orders`` section of
        both rotations is then added on top. Non-position columns come from
        ``observ_cw.get_other_view()``, and ``'position'`` is remapped with
        ``decomposed_remap_to_wtrack``.

        NOTE(review): the result is wrapped with ``wtrack_decompose.encode_settings``
        rather than the ``encode_settings`` argument — confirm this is intended.
        """
        # Explicit lookup instead of eval('observ_' + rot_k).
        decomp_observs = {'cw': observ_cw, 'ccw': observ_ccw}
        main_sel = wtrack_decompose.sel_data['cw']['main']

        observ = pd.DataFrame(np.zeros((observ_cw.shape[0], len(main_sel['wtrack']['ind']))),
                              columns=encode_settings.pos_col_names, index=observ_cw.index)

        observ.iloc[:, main_sel['wtrack']['ind']] = \
            observ_cw.loc[:, main_sel['decomposed']['col']].values

        for rot_k in WtrackRecomposer.rotations:
            for ord_k in WtrackRecomposer.orders:
                part_sel = wtrack_decompose.sel_data[rot_k][ord_k]
                observ.iloc[:, part_sel['wtrack']['ind']] += \
                    decomp_observs[rot_k].loc[:, part_sel['decomposed']['col']].values

        observ = observ.join(observ_cw.get_other_view())
        observ = SpikeObservation.create_default(observ, enc_settings=wtrack_decompose.encode_settings)

        observ['position'] = decomposed_remap_to_wtrack(observ['position'],
                                                        wtrack_decompose.armcoord_cw,
                                                        wtrack_decompose.wtrack_armcoord)

        return observ


In [153]:
# Stitch the cw / ccw encodings back together on the full W-track coordinate system.
wtrack_recomposer = WtrackRecomposer(encoder_cw=encoder_cw,
                                     encoder_ccw=encoder_ccw,
                                     wtrack_decomposed=wtrack_decompose,
                                     encode_settings=encode_settings)

In [172]:
hv.output(backend='bokeh')
# Row-reversed view ([::-1] is the same as np.flip(..., axis=0)), shown with an inverted y axis.
hv.Image(wtrack_recomposer.trans_mat[::-1], bounds=(0, 0, 800, 800)).opts(invert_yaxis=True)


In [175]:
# Upper triangle only (np.triu zeroes entries below the main diagonal), row-reversed for display.
hv.Image(np.triu(wtrack_recomposer.trans_mat)[::-1], bounds=(0, 0, 800, 800)).opts(invert_yaxis=True)


In [44]:
hv.output(backend='matplotlib')

# Setup plot to visualize estimated position distribution of the recomposed W-track observations.
# Number of consecutive observations overlaid in each DynamicMap frame.
num_overlay = 5

sel_distrib = wtrack_recomposer.observ.loc[:, pos_col_format(0, encode_settings.pos_num_bins):
                                           pos_col_format(encode_settings.pos_num_bins-1,
                                                          encode_settings.pos_num_bins)]

sel_pos = wtrack_recomposer.observ.loc[:, 'position']
max_prob = sel_distrib.max().max()

def plot_observ(ind, distrib=sel_distrib, pos=sel_pos, y_max=max_prob, n_rows=num_overlay):
    """Overlay ``n_rows`` consecutive observation distributions starting at row ``ind``.

    Each row is drawn as a curve over the W-track position bins plus a point marking
    that row's 'position' value. The data are bound as default arguments (evaluated
    when this cell runs), so the callback does not depend on whatever the global
    ``sel_distrib`` / ``sel_pos`` / ``max_prob`` hold at render time.
    """
    plot_list = []
    for row in range(ind, ind + n_rows):
        # NOTE(review): x extent 800 is hard-coded; presumably the W-track length.
        plot_list.append(hv.Curve(distrib.iloc[row], extents=(0, 0, 800, y_max)))
        plot_list.append(hv.Points((pos.iloc[row], [0.005])))
    return hv.Overlay(plot_list).opts(fig_size=400, aspect=3)

dmap = hv.DynamicMap(plot_observ, kdims=['ind'])
# Only offer start indices whose whole window fits; previously the last frames indexed
# past the end of sel_distrib and raised IndexError.
dmap.redim.values(ind=list(range(0, len(sel_distrib) - num_overlay + 1, num_overlay)))

In [156]:
%%time
# Run PP decoding algorithm
time_bin_size = 10

decoder = OfflinePPDecoder(observ_obj=observ, trans_mat=wtrack_recomposer.trans_mat, 
                           prob_no_spike=wtrack_recomposer.prob_no_spike,
                           encode_settings=encode_settings, decode_settings=decode_settings, 
                           time_bin_size=time_bin_size)

posteriors = decoder.run_decoder()

In [166]:
# Visualize the decoded posteriors with the bokeh backend.
# The commented-out %%output / %%opts cell magics below appear to be superseded by the
# hv.output(...) call and the dmap.opts(...) call further down; kept only for reference.
#%%output backend='bokeh' size=300 holomap='scrubber'
#%%opts RGB { +framewise} [height=100 width=350]
#%%opts Points (marker='o' color='#AAAAFF' size=1 alpha=0.05)
#%%opts Polygons (color='grey', alpha=0.3 fill_color='grey' fill_alpha=0.3)
#%%opts Image {+framewise}
hv.output(backend='bokeh')

# NaN posterior entries are replaced with 0 before plotting. `linpos_flat` is not defined in
# this section — presumably set by an earlier cell. heatmap_max=0.15 presumably caps the
# heatmap colour scale — TODO confirm against DecodeVisualizer.
dec_viz = DecodeVisualizer(posteriors.fillna(0), linpos=linpos_flat, enc_settings=encode_settings, heatmap_max=0.15)

# A RangeXY stream is passed in; presumably the visualizer re-renders for the current x/y range.
dmap = dec_viz.plot_all_dynamic(stream=hv.streams.RangeXY())

dmap.opts(aspect=3, frame_width=600)

dmap