In [1]:
# Run this notebook with python3
import json

import pandas as pd
import h5py
import numpy as np
from bokeh.plotting import (figure,
                            show,
                            output_notebook,
                            output_file,
                           )
from bokeh.models import (ColumnDataSource,
                          CDSView,
                          GroupFilter,
                          BooleanFilter,
                          LabelSet,
                          HoverTool,
                         )
from bokeh.layouts import column
from bokeh.layouts import gridplot

from spikelib.utils import datasets_to_array

In [4]:
target_path = '../../data/sorting/target_units_MR-0092t2.result-1.json'
# NOTE(review): from the usage below, the JSON appears to map a unit type
# (e.g. 'valid') to a list of template ids — confirm against the sorting output.
with open(target_path, 'r') as f:
    targets = json.load(f)

# One (unit_name, type) row per template id, named 'temp_<id>'. These names are
# presumably the same keys get_dateset uses as its index — verify.
target_units = [
    ('temp_{}'.format(unit_id), unit_type)
    for unit_type, unit_ids in targets.items()
    for unit_id in unit_ids
]
class_type = (
    pd.DataFrame(target_units, columns=['name', 'type'])
    .set_index('name')
)
# get_dateset aligns this frame on the unit name; pandas cannot align on
# duplicate labels, so fail here with a clear message instead of later.
assert class_type.index.is_unique, 'a unit is listed under more than one type'


## Plot receptive-field (RF) fits

In [11]:
def get_dateset(fanalysis, wn_name, class_=None, px_um=50):
    """Retrieve the spatial RF characterization of one white-noise protocol.

    Reads the group ``/sta/<wn_name>/spatial/char/`` from the analysis HDF5
    file, converts the ellipse parameters from pixels to micrometres and adds
    the extra columns used for plotting.

    Parameters
    ----------
    fanalysis : str
        Path to the HDF5 file with STA information.
    wn_name : str
        Name of the white-noise protocol in the '/sta/' group of the file.
    class_ : scalar, pandas.Series or single-column pandas.DataFrame, optional
        Value(s) stored in the 'class' column. A Series/DataFrame is aligned
        on the index (unit name); units absent from it get NaN.
    px_um : float, optional
        Pixel size in micrometres used for the conversion (default 50).

    Returns
    -------
    pandas.DataFrame
        One row per key returned by ``datasets_to_array`` (presumably unit
        names). It has the columns listed in the group's 'col_name' attribute
        plus 'radius', 'unit_name' and 'class'. 'semia'/'semib' are scaled by
        ``2 * px_um`` and 'x0'/'y0' by ``px_um``.
    """
    group_name = '/sta/{}/spatial/char/'.format(wn_name)
    # Open read-only: h5py < 3 defaults to mode 'a', which creates the file
    # if it does not exist and opens it for writing.
    with h5py.File(fanalysis, 'r') as panalysis:
        col_attr = panalysis[group_name].attrs['col_name']
        # Fixed-length string attributes may be returned as bytes.
        if isinstance(col_attr, bytes):
            col_attr = col_attr.decode('utf-8')
        col_name = col_attr.split(',')
        array, key = datasets_to_array(panalysis[group_name])
    data = pd.DataFrame(data=array, index=key, columns=col_name)
    # Pixel -> um. The axes are doubled, presumably semi-axis -> full length,
    # since these columns feed Bokeh ellipse width/height directly.
    # Whole-column assignment (rather than .loc[:, col]) lets the dtype change
    # freely, e.g. an int column times a float px_um.
    data['semia'] = data['semia'] * px_um * 2
    data['semib'] = data['semib'] * px_um * 2
    data['x0'] = data['x0'] * px_um
    data['y0'] = data['y0'] * px_um
    # semia/semib are doubled at this point, so sqrt(semia*semib)/2 is the
    # radius of the circle with the same area as the ellipse.
    data['radius'] = np.sqrt(data['semia'] * data['semib']) / 2
    data['unit_name'] = data.index.values
    data['class'] = class_

    return data

In [9]:
fanalysis = '../../data/processed_protocols/MR-0092t2_modified_analysis_of_protocols_150um_merge.hdf5'


def _valid_units_source(wn_name):
    """Build the data source and its 'valid'-only view for one protocol.

    Returns a ``(ColumnDataSource, CDSView)`` pair. The view keeps the rows
    whose 'class' column equals 'valid'.
    """
    source = ColumnDataSource(get_dateset(fanalysis, wn_name, class_type))
    only_valid = GroupFilter(column_name='class', group='valid')
    return source, CDSView(source=source, filters=[only_valid])


char_cds_nd2, view_nd2 = _valid_units_source('nd2-255')
char_cds_nd3, view_nd3 = _valid_units_source('nd3-255')
char_cds_nd4, view_nd4 = _valid_units_source('nd4-255')
char_cds_nd5, view_nd5 = _valid_units_source('nd5-255')

In [10]:
output_file('rfs.html')

# Styling shared by the four ND panels.
tools = 'pan,wheel_zoom,xbox_select,reset'
plot_kwargs = {
    'tools': tools,
    'plot_width': 500,
    'plot_height': 500,
    # Axis limits in um. NOTE(review): -4*50 and 40*42 look like pixel counts
    # times pixel sizes — confirm the intended field of view.
    'x_range': (-4*50, 40*42),
    'y_range': (-4*50, 40*42),
    'x_axis_label': 'x axis [um]',
    'y_axis_label': 'y axis [um]',
}
# Ellipse glyph bound to the columns produced by get_dateset.
# NOTE(review): the meaning of the '64' legend label is not evident here.
ell_kwargs = {
    'x': 'x0',
    'y': 'y0',
    'width': 'semia',
    'height': 'semib',
    'angle': 'angle',
    'alpha': 0.2,
    'line_color': 'blue',
    'angle_units': 'deg',
    'fill_color': 'blue',
    'legend': '64',
}
# Text styling for unit labels; only used by the disabled LabelSet line below.
label_kwargs = {
    'level': 'glyph',
    'x_offset': 5,
    'y_offset': 5,
    'render_mode': 'canvas',
    'text_alpha': 0.4,
}


def _rf_panel(title, source, view):
    """Create one figure and draw the rows selected by `view` as ellipses."""
    panel = figure(title=title, **plot_kwargs)
    panel.ellipse(source=source, view=view, **ell_kwargs)
    return panel


fig_nd2 = _rf_panel('RF MR-0092 ND2', char_cds_nd2, view_nd2)
fig_nd3 = _rf_panel('RF MR-0092 ND3', char_cds_nd3, view_nd3)
fig_nd4 = _rf_panel('RF MR-0092 ND4', char_cds_nd4, view_nd4)
fig_nd5 = _rf_panel('RF MR-0092 ND5', char_cds_nd5, view_nd5)

# Link pan/zoom across panels: every panel reuses fig_nd5's range objects.
for _panel in (fig_nd2, fig_nd3, fig_nd4):
    _panel.x_range = fig_nd5.x_range
    _panel.y_range = fig_nd5.y_range
del _panel

# fig_nd2.add_layout(LabelSet(x='x0', y='y0', text='unit_name', source=char_cds_nd2))

show(gridplot([[fig_nd2, fig_nd3], [fig_nd4, fig_nd5]]))