In [1]:
# Make the working spykshrk source tree importable without installing it through setuptools:
# the repository root (root_depth directories above this notebook) is put on sys.path exactly once.
import sys, os
root_depth = 2
# `_dh` is IPython's directory-history list; its first entry is the directory the kernel started in.
notebook_dir = globals()['_dh'][0]
root_path = os.path.abspath(os.path.join(notebook_dir, *([os.pardir] * root_depth)))
# Strip every existing copy of root_path (in place, so the sys.path object itself is kept)
# and then append a single fresh entry, so re-running this cell never piles up duplicates.
sys.path[:] = [path_entry for path_entry in sys.path if path_entry != root_path]
sys.path.append(root_path)
# Alternatively set root path as current working directory
#os.chdir(root_path)

In [2]:
import pandas as pd
import numpy as np
import scipy as sp
import json
import os
import holoviews as hv

from spykshrk.realtime.simulator import nspike_data

from spykshrk.franklab.data_containers import EncodeSettings, DecodeSettings, SpikeObservation, \
                                              LinearPosition, StimLockout, Posteriors, FlatLinearPosition
from spykshrk.franklab.pp_decoder.visualization import DecodeVisualizer, DecodeErrorVisualizer
from spykshrk.franklab.franklab_data import FrankAnimalInfo, FrankFilenameParser, FrankDataInfo

# Load both HoloViews plotting backends. The later call should leave bokeh as the
# default backend; individual cells below pick a backend with `%%output backend=...`.
hv.extension('matplotlib')
hv.extension('bokeh')

In [3]:
# Load merged rec HDF store based on config

#config_file = '/opt/data36/daliu/realtime/spykshrk/dec_60uv_30samp/bond.config.json'
config_file = '/opt/data36/daliu/realtime/spykshrk/wang_sim_test/bond.config.json'

# Parse the run config inside a context manager so the file handle is closed right away
# (json.load(open(...)) leaves the handle open until garbage collection).
with open(config_file, 'r') as config_fh:
    config = json.load(config_fh)

# Only the first configured day/epoch is analysed in this notebook.
animal_info = config['simulator']['nspike_animal_info']
day = animal_info['days'][0]
epoch = animal_info['epochs'][0]
time_bin_size = config['pp_decoder']['bin_size']

# Main hdf5 data source file name
hdf_file = os.path.join(config['files']['output_dir'],
                        '{}.rec_merged.h5'.format(config['files']['prefix']))

# Extract just encode and decode settings from config
encode_settings = EncodeSettings(config)
decode_settings = DecodeSettings(config)

# Open data file (read-only).
# NOTE(review): the store stays open for the rest of the session; call store.close()
# once no later cell needs to read from it.
store = pd.HDFStore(hdf_file, mode='r')

# Encapsulate Spike Observation panda table in container (realtime record 'rec_3')
observ_obj = SpikeObservation.from_realtime(store['rec_3'], day=day, epoch=epoch, enc_settings=encode_settings)

# Realtime decoder posteriors (realtime record 'rec_4')
realtime_posteriors = Posteriors.from_realtime(store['rec_4'], day=day, epoch=epoch,
                                               enc_settings=encode_settings)

# Grab stimulation lockout times (realtime record 'rec_11')
stim_lockout = StimLockout.from_realtime(store['rec_11'], enc_settings=encode_settings)


# Grab animal linearized real position
nspike_anim = nspike_data.AnimalInfo(**animal_info)
pos = nspike_data.PosMatDataStream(nspike_anim)
pos_data = pos.data

# Encapsulate linear position
lin_obj = LinearPosition.from_nspike_posmat(pos_data, encode_settings)

# Ripple event data (RipplesConsData); later cells select ripples from it with get_above_maxthresh
ripcons = nspike_data.RipplesConsData(nspike_anim)
ripdata = ripcons.data_obj

In [4]:
# Offline decode results for the same animal; the cells below compare them against the
# realtime posteriors loaded from the HDF store.
anim = FrankAnimalInfo('/opt/data36/daliu/', 'pyBond')
decode_info = FrankDataInfo(anim, 'decode')
display(decode_info.entries)
# Index of the decode dataset to load. It was previously an unexplained literal; check it
# against the entries table displayed above before changing datasets.
OFFLINE_DECODE_DATASET_IND = 3
offline_posterior = decode_info.load_single_dataset_ind(OFFLINE_DECODE_DATASET_IND)

In [5]:

# One DecodeVisualizer per decode, realtime first and offline second. Both share the linear
# position, the encoder settings and the ripples returned by get_above_maxthresh(5)
# (threshold units are defined by RipplesConsData; not visible here).
dec_viz, off_dec_viz = (
    DecodeVisualizer(posterior, linpos=lin_obj, riptimes=ripdata.get_above_maxthresh(5),
                     enc_settings=encode_settings)
    for posterior in (realtime_posteriors, offline_posterior)
)


In [6]:
%%output backend='bokeh' size=400 holomap='scrubber'
%%opts RGB { +framewise} [height=100 width=250 aspect=2]
%%opts Points [height=100 width=250 aspect=2 ] (marker='o' color='#AAAAFF' size=2 alpha=0.7)
%%opts Polygons (color='grey', alpha=0.5 fill_color='grey' fill_alpha=0.5)

# Interactive bokeh view of the realtime decode, driven by a scrubber over `values`.
# The %%output/%%opts cell magics above must stay the first lines of this cell, so
# explanatory comments can only go here.
# NOTE(review): `values` looks like ripple start times shifted 0.5 earlier (assumes
# `starttime` is in seconds), and plt_range/slide look like window width/step.
# Confirm against DecodeVisualizer.plot_all_dynamic.
dec_viz.plot_all_dynamic(stream=hv.streams.RangeXY(), plt_range=1, slide=1, values=ripdata['starttime']-.5)


In [7]:
# Size of the offline posterior table in GiB (pandas' default, shallow memory accounting).
offline_posterior.memory_usage().sum() / 1024 ** 3

In [8]:
%%opts Image {+axiswise} [height=300 width=300 aspect=1]
%%opts Curve.arm_bound {+axiswise} [aspect=1] (line_dash='dashed' color='#AAAAAA' linestyle='--' alpha=0.5)
%%opts Points {+axiswise} [aspect=1] (marker='*' size=18)
%%opts NdLayout {+axiswise}
%%output backend='matplotlib' size=200

online_rip_plots = dec_viz.plot_ripple_grid(1,1)
offline_rip_plots = off_dec_viz.plot_ripple_grid(1,1)

for ii, subplot in enumerate(online_rip_plots):
    display((subplot + offline_rip_plots[ii]).cols(2))
    


In [10]:
def ripple_posterior_map_error(first_posterior, second_posterior, pos_bins=None):
    """Absolute difference between the MAP positions of two posteriors at shared timestamps.

    For every timestamp present in both posteriors, the column holding the largest
    posterior value is mapped to a position through ``pos_bins``. The absolute
    difference between the two resulting positions is then recorded.

    Parameters
    ----------
    first_posterior, second_posterior : Posteriors-like
        Tables with a ``timestamp`` index level that support ``.query`` and
        ``.get_distribution_view()``. Assumed to hold one posterior row per timestamp
        whose columns line up with ``pos_bins`` (TODO confirm for realtime posteriors).
    pos_bins : array-like, optional
        Position value of each distribution column. Defaults to the notebook-level
        ``encode_settings.pos_bins``, which is what this function always used before.

    Returns
    -------
    rip_timestamps : np.ndarray
        Sorted, unique timestamps common to both posteriors.
    map_error : np.ndarray
        ``|MAP(first) - MAP(second)|`` for each entry of ``rip_timestamps``.
    map_error_mean, map_error_std : float
        Mean and population standard deviation of ``map_error``. Both are NaN when
        the posteriors share no timestamps.
    """
    if pos_bins is None:
        # Backward-compatible default: the encoder settings loaded earlier in the notebook.
        pos_bins = encode_settings.pos_bins
    pos_bins = np.asarray(pos_bins)

    rip_timestamps = np.intersect1d(first_posterior.index.get_level_values('timestamp'),
                                    second_posterior.index.get_level_values('timestamp'))

    first_map = []
    second_map = []
    for timestamp in rip_timestamps:
        # `@timestamp` in the query strings refers to this loop variable.
        rip_first_slice = first_posterior.query('timestamp == @timestamp').get_distribution_view()
        rip_second_slice = second_posterior.query('timestamp == @timestamp').get_distribution_view()
        # np.argmax flattens its input, so this relies on a single row per timestamp.
        first_map.append(pos_bins[np.argmax(rip_first_slice.values)])
        second_map.append(pos_bins[np.argmax(rip_second_slice.values)])

    map_error = np.abs(np.array(first_map) - np.array(second_map))
    if map_error.size == 0:
        # No shared timestamps: report NaN explicitly instead of calling np.mean/np.std on an
        # empty array (which emits RuntimeWarnings).
        return rip_timestamps, map_error, np.nan, np.nan

    return rip_timestamps, map_error, np.mean(map_error), np.std(map_error)

# Hand-picked ripple event ids (values of `event_grp`) compared in this and the following cells.
ripple_ids = [2, 8, 18, 40, 123, 180, 199, 203, 208, 221, 235]

# apply_time_event(...) does not depend on ripple_id, so evaluate it once up front instead of
# repeating the same whole-posterior work on every loop iteration.
rt_post = realtime_posteriors.apply_time_event(ripdata.get_above_maxthresh(5))
off_post = offline_posterior.apply_time_event(ripdata.get_above_maxthresh(5))

for ripple_id in ripple_ids:
    rip_rt_post = rt_post.query('event_grp == @ripple_id')
    rip_off_post = off_post.query('event_grp == @ripple_id')

    map_timestamps, map_error, map_error_mean, map_error_std = ripple_posterior_map_error(rip_rt_post, rip_off_post)
    print('{}: mean: {:.02f} std: {:.02f}'.format(ripple_id, map_error_mean, map_error_std))
    

In [11]:
def ripple_posterior_wasserstein_distance(first_posterior, second_posterior, pos_bins=None):
    """Wasserstein-1 distance between two posteriors over position, at each shared timestamp.

    Each posterior row is treated as a distribution over the position axis: the bin
    positions ``pos_bins`` are the support and the posterior values are the weights.
    The previous version passed the posterior probabilities themselves as sample
    values. That measured how the *sets of probability values* differ and ignored
    which position each probability belongs to.

    Parameters
    ----------
    first_posterior, second_posterior : Posteriors-like
        Tables with a ``timestamp`` index level that support ``.query`` and
        ``.get_distribution_view()``. Assumed to hold one row per timestamp with one
        column per entry of ``pos_bins`` (TODO confirm for realtime posteriors).
    pos_bins : array-like, optional
        Position value of each distribution column. Defaults to the notebook-level
        ``encode_settings.pos_bins``.

    Returns
    -------
    rip_timestamps : np.ndarray
        Sorted, unique timestamps common to both posteriors.
    was_dists : list of float
        Distance for each entry of ``rip_timestamps``. NaN where either posterior is not
        a valid distribution (non-finite, negative, or all-zero values).
    was_dist_mean, was_dist_std : float
        Mean and population standard deviation over the finite entries of
        ``was_dists``. Both are NaN when there are none.
    """
    # Import the submodule explicitly: on older SciPy `import scipy as sp` does not load
    # `sp.stats`, so the attribute only existed if some other package had imported it.
    from scipy.stats import wasserstein_distance

    def _is_distribution(weights):
        """True if `weights` can be normalised into a probability distribution."""
        return bool(np.all(np.isfinite(weights)) and np.all(weights >= 0) and weights.sum() > 0)

    if pos_bins is None:
        # Backward-compatible default: the encoder settings loaded earlier in the notebook.
        pos_bins = encode_settings.pos_bins
    pos_bins = np.asarray(pos_bins, dtype=float)

    rip_timestamps = np.intersect1d(first_posterior.index.get_level_values('timestamp'),
                                    second_posterior.index.get_level_values('timestamp'))

    was_dists = []
    for timestamp in rip_timestamps:
        # `@timestamp` in the query strings refers to this loop variable.
        first_weights = np.asarray(
            first_posterior.query('timestamp == @timestamp').get_distribution_view().values,
            dtype=float).reshape(-1)
        second_weights = np.asarray(
            second_posterior.query('timestamp == @timestamp').get_distribution_view().values,
            dtype=float).reshape(-1)
        if _is_distribution(first_weights) and _is_distribution(second_weights):
            was_dist = wasserstein_distance(pos_bins, pos_bins,
                                            u_weights=first_weights, v_weights=second_weights)
        else:
            # Distance is undefined; keep a NaN placeholder so entries stay aligned with rip_timestamps.
            was_dist = np.nan
        was_dists.append(was_dist)

    finite_dists = np.asarray(was_dists, dtype=float)
    finite_dists = finite_dists[np.isfinite(finite_dists)]
    if finite_dists.size == 0:
        return rip_timestamps, was_dists, np.nan, np.nan

    return rip_timestamps, was_dists, np.mean(finite_dists), np.std(finite_dists)

# apply_time_event(...) does not depend on ripple_id, so evaluate it once up front instead of
# repeating the same whole-posterior work on every loop iteration.
rt_post = realtime_posteriors.apply_time_event(ripdata.get_above_maxthresh(5))
off_post = offline_posterior.apply_time_event(ripdata.get_above_maxthresh(5))

for ripple_id in ripple_ids:
    rip_rt_post = rt_post.query('event_grp == @ripple_id')
    rip_off_post = off_post.query('event_grp == @ripple_id')

    was_timestamps, was_dist, was_dist_mean, was_dist_std = ripple_posterior_wasserstein_distance(rip_rt_post, rip_off_post)
    print('{}: mean: {:f} std: {:f}'.format(ripple_id, was_dist_mean, was_dist_std))
    

In [12]:
%%opts Image {+axiswise} [height=300 width=300 aspect=1]
%%opts Curve.arm_bound {+axiswise} [aspect=1] (line_dash='dashed' color='#AAAAAA' linestyle='--' alpha=0.5)
%%opts Curve.was {+framewise} [apply_ranges=False]
%%opts Points {+axiswise} [aspect=1] (marker='*' size=18)
%%opts NdLayout {+axiswise}
%%output backend='matplotlib' size=200
# NOTE(review): Curve.arm_bound sets both `line_dash` (a bokeh style key) and `linestyle`
# (the matplotlib one) while this cell renders with matplotlib. Confirm the bokeh-only
# key is ignored rather than rejected.

def overlay(first, plot, element):
    """Redraw the line traces of `first` on a twin y-axis of an existing matplotlib `plot`.

    `first` is rendered through HoloViews' matplotlib renderer. Each of its lines is
    copied onto a secondary axis created with `twinx()` on `plot`'s axis, drawn as a
    black dashed line, and the twin axis takes over `first`'s y tick positions and
    y label (tick labels are not copied). `element` is accepted only to match the
    (plot, element) hook arguments and is not used.
    """
    source_axis = hv.Store.renderers['matplotlib'].get_plot(first).handles['axis']
    twin_axis = plot.handles['axis'].twinx()

    twin_axis.set_yticks(source_axis.get_yticks())
    twin_axis.set_ylabel(source_axis.get_ylabel())

    for source_line in source_axis.lines:
        (copied_line,) = twin_axis.plot(*source_line.get_data())
        copied_line.set_color('k')
        copied_line.set_linestyle('--')

def tmp(plot, element, first=None):
    """Plot hook that overlays a reference element's lines onto `plot` via `overlay`.

    Parameters
    ----------
    plot, element
        Forwarded unchanged to `overlay` (the positional hook arguments).
    first : optional
        Element whose lines are overlaid. Defaults to a notebook-global `a`, which is what
        the original version always read. No visible cell defines `a`, so pass `first`
        explicitly (e.g. via functools.partial) when using this as a hook.

    Raises
    ------
    NameError
        If `first` is omitted and no global `a` exists. The message names the fix,
        rather than a bare "name 'a' is not defined".
    """
    if first is None:
        if 'a' not in globals():
            raise NameError("tmp(): no reference element given - pass first=... "
                            "or define a global `a` before using this hook")
        first = globals()['a']
    overlay(first, plot, element)

# apply_time_event(...) does not depend on ripple_id, so evaluate it once up front instead of
# repeating the same whole-posterior work on every loop iteration.
rt_post = realtime_posteriors.apply_time_event(ripdata.get_above_maxthresh(5))
off_post = offline_posterior.apply_time_event(ripdata.get_above_maxthresh(5))

for ripple_id in ripple_ids:
    rip_rt_post = rt_post.query('event_grp == @ripple_id')
    rip_off_post = off_post.query('event_grp == @ripple_id')

    map_timestamp, map_error, map_error_mean, map_error_std = ripple_posterior_map_error(rip_rt_post, rip_off_post)

    was_timestamp, was_dist, was_dist_mean, was_dist_std = ripple_posterior_wasserstein_distance(rip_rt_post, rip_off_post)

    # Per-timestamp error traces, plotted against sample index within the ripple. Only the
    # Wasserstein curve gets framewise normalisation (the call binds tighter than `+`).
    error_plots = hv.Curve(map_error, group='map') + hv.Curve(was_dist, group='was')(norm=dict(framewise=True))

    # Realtime decode, offline decode, then the two error traces, laid out two per row.
    display((dec_viz.plot_ripple_all(ripple_id) + off_dec_viz.plot_ripple_all(ripple_id) + error_plots).cols(2))

    print('MAP: id {}: mean: {:.02f} std: {:.02f}'.format(ripple_id, map_error_mean, map_error_std))
    print(' WASSERSTEIN: id {}: mean: {:.2e} std: {:.2e}'.format(ripple_id, was_dist_mean, was_dist_std))

    