In [11]:
# This snippet of code properly adds the working source root path to python's path
# so you no longer have to install spykshrk through setuptools
import sys, os
root_depth = 3
notebook_dir = globals()['_dh'][0]
root_path = os.path.abspath(os.path.join(notebook_dir, '../'*root_depth))
# Add to python's path
try:
    while True:
        sys.path.remove(root_path)
except ValueError:
    # no more root paths
    pass
sys.path.append(root_path)
# Alternatively set root path as current working directory
#os.chdir(root_path)

In [281]:
import numpy as np
import scipy as sp
import pandas as pd
import functools
import holoviews as hv
from holoviews import dim, opts

# Spykshrk modules for data analysis
from spykshrk.franklab.data_containers import FlatLinearPosition, SpikeFeatures, \
        EncodeSettings, pos_col_format, SpikeObservation
from spykshrk.franklab.generator.place.data_generator import UnitNormalGenerator, TetrodeUniformUnitNormalGenerator
from spykshrk.franklab.pp_decoder.util import normal_pdf_int_lookup, gaussian
from spykshrk.franklab.pp_decoder.pp_clusterless import OfflinePPEncoder, OfflinePPDecoder
from spykshrk.franklab.pp_decoder.visualization import DecodeVisualizer
from spykshrk.franklab.visualization import LinPosVisualizer, TetrodeVisualizer
from spykshrk.franklab.pp_decoder.util import apply_no_anim_boundary
from spykshrk.util import AttrDict

hv.extension('bokeh')
hv.extension('matplotlib')

hv.renderer('bokeh').theme = "dark_minimal"
import matplotlib.pyplot as plt
plt.style.use('dark_background')
plt.rcParams['figure.facecolor'] = 'black'

In [338]:
# Encoding and decoding settings for both simulator and algorithm

encode_settings = AttrDict({'sampling_rate': 1000,
                            'pos_bins': np.arange(0,800,1),
                            'pos_bin_edges': np.arange(0,800.1,1),
                            'pos_bin_delta': 1,
                            'pos_kernel': sp.stats.norm.pdf(np.arange(0,800,1), 400, 1),
                            'pos_kernel_std': 1, 
                            'mark_kernel_std': int(20), 
                            'pos_num_bins': 800,
                            'pos_col_names': [pos_col_format(ii, 800) for ii in range(800)],
                            'arm_coordinates': [[0, 100], [100, 250], [250,400], [400, 500], [500, 650], [650, 800]],
                            'wtrack_arm_coordinates': AttrDict(center=AttrDict(forward=AttrDict(x1=0, x2=100, len=100), 
                                                                               reverse=AttrDict(x1=400, x2=500, len=100)),
                                                               left=AttrDict(forward=AttrDict(x1=100, x2=250, len=150),
                                                                             reverse=AttrDict(x1=250, x2=400, len=150)),
                                                               right=AttrDict(forward=AttrDict(x1=500, x2=650, len=150),
                                                                              reverse=AttrDict(x1=650, x2=800, len=150))),
                            'vel': 3,
                            'spk_amp': 60})

decode_settings = AttrDict({'trans_smooth_std': 5,
                            'trans_uniform_gain': 0.001,
                            'time_bin_size': 10})
                            

In [339]:
def wtrack_simulate_segment_const_speed(arm_range, trial_time, trial_len, sampling_rate):
    pos_seg = np.arange(arm_range.x1, arm_range.x2, 
                        arm_range.len/sampling_rate/(trial_time*arm_range.len/(trial_len)))
    return pos_seg

def wtrack_simulate_trial_const_speed(arm1_range, arm2_range, trial_time, sampling_rate):
    pos_trial = np.array([])

    pos_seg = wtrack_simulate_segment_const_speed(arm1_range, trial_time, 
                                                  arm1_range.len + arm2_range.len, sampling_rate)
    pos_trial = np.append(pos_trial, pos_seg)

    pos_seg = wtrack_simulate_segment_const_speed(arm2_range, trial_time, 
                                                  arm1_range.len + arm2_range.len, sampling_rate)
    pos_trial = np.append(pos_trial, pos_seg)
    
    return pos_trial

def wtrack_simulate_run_const_speed(arm1_dir, arm2_dir, trial_time, sampling_rate):
    pos_run = np.array([])

    pos_trial = wtrack_simulate_trial_const_speed(arm1_dir.forward, arm2_dir.forward, trial_time, sampling_rate)
    pos_run = np.append(pos_run, pos_trial)

    pos_trial = wtrack_simulate_trial_const_speed(arm2_dir.reverse, arm1_dir.reverse, trial_time, sampling_rate)
    pos_run = np.append(pos_run, pos_trial)

    return pos_run

def wtrack_simulate_series_const_speed(w_coor, trial_time, sampling_rate):
    pos_series = np.array([])
    
    pos_run = wtrack_simulate_run_const_speed(w_coor.center, w_coor.left, trial_time, sampling_rate)
    pos_series = np.append(pos_series, pos_run)

    pos_run = wtrack_simulate_run_const_speed(w_coor.center, w_coor.right, trial_time, sampling_rate)
    pos_series = np.append(pos_series, pos_run)
    
    return pos_series
    

In [346]:
trial_time = 10

pos_epoch = np.array([])

pos_series = wtrack_simulate_series_const_speed(encode_settings.wtrack_arm_coordinates, trial_time, encode_settings.sampling_rate)

pos_epoch = np.append(pos_epoch, pos_series)

pos_time = np.arange(0, pos_epoch.size/encode_settings.sampling_rate, 1/encode_settings.sampling_rate)


In [348]:
pos_time.size

40004

In [455]:
arm_colormap = dict(center='darkorange', left='pink', right='cyan')

direction_hatchmap = dict(forward='right_diagonal_line', reverse='left_diagonal_line')
mpl_direction_hatchmap = dict(forward='/', reverse='\\')

def wtrack_linear_plot_hook(plot, element, arm, direction):
    if hv.Store.current_backend == 'bokeh':
        plot.handles['glyph'].fill_color = arm_colormap[arm]
        plot.handles['glyph'].hatch_pattern = direction_hatchmap[direction]
        plot.handles['glyph'].line_color = None
        plot.handles['glyph'].hatch_color = 'grey'
        plot.handles['plot'].width = 600
    elif hv.Store.current_backend == 'matplotlib':
        plot.handles['fig'].set_size_inches(10, 10)
        plot.handles['fig'].get_axes()[0].set_aspect(0.5)
        #plot.handles['fig'].gca(aspect=2)
        #plot.handles['artist'].set_color(None)
        #plot.handles['artist'].set_hatch(mpl_direction_hatchmap[direction])
        element.opts(hatch=mpl_direction_hatchmap[direction], facecolor=arm_colormap[arm], clone=False)
        #plot.handles['artist'].set_color(arm_colormap[arm])

def wtrack_linear_plot_init_hook(plot, element, arm, direction):
    if hv.Store.current_backend == 'bokeh':
        pass
    elif hv.Store.current_backend == 'matplotlib':
        plot_kwargs = plot.style.kwargs
        plot_kwargs['facecolor'] = color=arm_colormap[arm]
        plot.style = opts.Polygons(**plot_kwargs, hatch=mpl_direction_hatchmap[direction])

def pos_hook(plot, element):
    #plot.handles['plot'].frame_width = 600
    print(plot.handles.keys())
    print(plot.handles['glyph'].marker)
    plot.handles['glyph'].line_color='royalblue'
    
    #plot.handles['glyph'].opts(color='royalblue')

        
def wtrack_linear_plot_polygons(arm, direction, pos_time, w_coor):
    time_range = (pos_time[0], pos_time[-1])
    y_range = (w_coor[arm][direction].x1, w_coor[arm][direction].x2)
    time_total = time_range[1] - time_range[0]
    time_center = time_total/2 + time_range[0]
    y_total = y_range[1] - y_range[0]
    y_center = y_total/2 + y_range[0]
    
    box = hv.Box(time_center, y_center, (time_total, y_total))    
    
    init_hooks = [functools.partial(wtrack_linear_plot_init_hook, arm=arm, direction=direction)]
    
    hooks = [functools.partial(wtrack_linear_plot_hook, arm=arm, direction=direction)]
    
    if hv.Store.current_backend == 'bokeh':
        poly = hv.Polygons(box).opts(hooks=hooks)
    elif hv.Store.current_backend == 'matplotlib':
        poly = hv.Polygons(box).opts(initial_hooks=init_hooks,
                                     hooks=hooks)
    return poly

In [456]:
hv.output(backend='bokeh')
#hv.output(backend='matplotlib', fig='pdf')

w_coor = encode_settings.wtrack_arm_coordinates
wtrack_linear_plot_polygons(arm='center', direction='forward', pos_time=pos_time, w_coor=w_coor) * \
    wtrack_linear_plot_polygons(arm='left', direction='forward', pos_time=pos_time, w_coor=w_coor) * \
    wtrack_linear_plot_polygons( arm='left', direction='reverse', pos_time=pos_time, w_coor=w_coor) * \
    wtrack_linear_plot_polygons(arm='center', direction='reverse', pos_time=pos_time, w_coor=w_coor) * \
    wtrack_linear_plot_polygons(arm='right', direction='forward', pos_time=pos_time, w_coor=w_coor) * \
    wtrack_linear_plot_polygons(arm='right', direction='reverse', pos_time=pos_time, w_coor=w_coor) * \
    hv.Scatter((pos_time, pos_epoch)).opts(hooks=[pos_hook])
#hv.Scatter((pos_time, pos_epoch)).opts(color='royalblue', fig_size=400, aspect=2)

dict_keys(['xaxis', 'x_range', 'yaxis', 'y_range', 'plot', 'previous_id', 'source', 'cds', 'selected', 'glyph', 'glyph_renderer'])
circle


In [507]:
encode_settings.wtrack_arm_coordinates['center']['forward']

{'x1': 0, 'x2': 100, 'len': 100}

In [531]:
class WtrackLinposDecomposed:
    def __init__(self, wtrack_arm_coord):
        self.segment_order_cw = [('center', 'forward'),
                                 ('left', 'forward'), 
                                 ('left', 'reverse'),
                                 ('right', 'forward'),
                                 ('right', 'reverse'),
                                 ('center', 'reverse')]
        self.segment_order_ccw = [('center', 'forward'),
                                  ('right', 'forward'), 
                                  ('right', 'reverse'),
                                  ('left', 'forward'),
                                  ('left', 'reverse'),
                                  ('center', 'reverse')]
        self.segment_decomp_armcord_cw = self._segment_decomposed_with_buffer(self.segment_order_cw, wtrack_arm_coord)
        self.segment_decomp_armcord_ccw = self._segment_decomposed_with_buffer(self.segment_order_ccw, wtrack_arm_coord)
        
    @staticmethod
    def _segment_decomposed_with_buffer(segment_order, wtrack_arm_coord):
        seg_decomposed = AttrDict()
        seg_offset = 0
        for ii, seg in enumerate(segment_order):
            prev_seg = segment_order[ii-1]
            next_seg = segment_order[(ii+1)%len(segment_order)]

            prev_seg_len = wtrack_arm_coord[prev_seg[0]][prev_seg[1]].len
            main_seg_len = wtrack_arm_coord[seg[0]][seg[1]].len
            next_seg_len = wtrack_arm_coord[next_seg[0]][next_seg[1]].len

            prev_seg_start = seg_offset
            main_seg_start = prev_seg_start + prev_seg_len
            next_seg_start = main_seg_start + main_seg_len
            next_seg_end = next_seg_start + next_seg_len
            seg_total_len = prev_seg_len + main_seg_len + next_seg_len
            seg_offset += seg_total_len

            arm_dict = seg_decomposed.setdefault(seg[0], AttrDict())
            dir_dict = arm_dict.setdefault(seg[1], AttrDict(prev=[prev_seg_start, main_seg_start], 
                                                            main=[main_seg_start, next_seg_start],
                                                            next=[next_seg_start, next_seg_end],
                                                            prev_seg=prev_seg,
                                                            next_seg=next_seg))
            
        return seg_decomposed
     

In [532]:
wtrack_decompose = WtrackLinposDecomposed(encode_settings.wtrack_arm_coordinates)
wtrack_decomp_armcoor_cw = wtrack_decompose.segment_decomp_armcord_cw
wtrack_decomp_armcoor_ccw = wtrack_decompose.segment_decomp_armcord_ccw


In [533]:
display(wtrack_decomp_armcoor_cw)
display(wtrack_decomp_armcoor_ccw)

{'center': {'forward': {'prev': [0, 100],
   'main': [100, 200],
   'next': [200, 350],
   'prev_seg': ('center', 'reverse'),
   'next_seg': ('left', 'forward')},
  'reverse': {'prev': [2050, 2200],
   'main': [2200, 2300],
   'next': [2300, 2400],
   'prev_seg': ('right', 'reverse'),
   'next_seg': ('center', 'forward')}},
 'left': {'forward': {'prev': [350, 450],
   'main': [450, 600],
   'next': [600, 750],
   'prev_seg': ('center', 'forward'),
   'next_seg': ('left', 'reverse')},
  'reverse': {'prev': [750, 900],
   'main': [900, 1050],
   'next': [1050, 1200],
   'prev_seg': ('left', 'forward'),
   'next_seg': ('right', 'forward')}},
 'right': {'forward': {'prev': [1200, 1350],
   'main': [1350, 1500],
   'next': [1500, 1650],
   'prev_seg': ('left', 'reverse'),
   'next_seg': ('right', 'reverse')},
  'reverse': {'prev': [1650, 1800],
   'main': [1800, 1950],
   'next': [1950, 2050],
   'prev_seg': ('right', 'forward'),
   'next_seg': ('center', 'reverse')}}}

{'center': {'forward': {'prev': [0, 100],
   'main': [100, 200],
   'next': [200, 350],
   'prev_seg': ('center', 'reverse'),
   'next_seg': ('right', 'forward')},
  'reverse': {'prev': [2050, 2200],
   'main': [2200, 2300],
   'next': [2300, 2400],
   'prev_seg': ('left', 'reverse'),
   'next_seg': ('center', 'forward')}},
 'right': {'forward': {'prev': [350, 450],
   'main': [450, 600],
   'next': [600, 750],
   'prev_seg': ('center', 'forward'),
   'next_seg': ('right', 'reverse')},
  'reverse': {'prev': [750, 900],
   'main': [900, 1050],
   'next': [1050, 1200],
   'prev_seg': ('right', 'forward'),
   'next_seg': ('left', 'forward')}},
 'left': {'forward': {'prev': [1200, 1350],
   'main': [1350, 1500],
   'next': [1500, 1650],
   'prev_seg': ('right', 'reverse'),
   'next_seg': ('left', 'reverse')},
  'reverse': {'prev': [1650, 1800],
   'main': [1800, 1950],
   'next': [1950, 2050],
   'prev_seg': ('left', 'forward'),
   'next_seg': ('center', 'reverse')}}}

In [491]:
encode_settings.wtrack_arm_coordinates

{'center': {'forward': {'x1': 0, 'x2': 100, 'len': 100},
  'reverse': {'x1': 400, 'x2': 500, 'len': 100}},
 'left': {'forward': {'x1': 100, 'x2': 250, 'len': 150},
  'reverse': {'x1': 250, 'x2': 400, 'len': 150}},
 'right': {'forward': {'x1': 500, 'x2': 650, 'len': 150},
  'reverse': {'x1': 650, 'x2': 800, 'len': 150}}}