In [1]:
import os
import time
import pickle
import itertools
from collections import namedtuple
import numpy as np
import matplotlib.pyplot as plt

from genEM3.data.wkwdata import WkwData,DataSource
from genEM3.util.path import get_data_dir
import genEM3.data.annotation as annotation 

In [2]:
# Load the datasource description (JSON) and build the base dataset from it
json_dir = os.path.join(get_data_dir(), 'debris_clean_added_bboxes2_wiggle_datasource.json')
config = WkwData.config_wkwdata(json_dir)
dataset = WkwData.init_from_config(config)

# Derive the data source lists: one with the normal bounding boxes (patch-wise
# dataset) and one with larger bounding boxes used for annotation
margin = 35
roi_size = 140
source_dict = annotation.patch_source_list_from_dataset(
    dataset=dataset,
    margin=margin,
    roi_size=roi_size,
)

# Build one dataset per source list; the patch shape is read from the first
# source of each list (presumably elements 3:6 of the bbox are its extent - confirm)
dataset_dict = {}
for key, cur_source in source_dict.items():
    first_bbox = cur_source[0].input_bbox
    cur_patch_shape = tuple(first_bbox[3:6])
    cur_config = WkwData.config_wkwdata(
        datasources_json_path=None,
        input_shape=cur_patch_shape,
        output_shape=cur_patch_shape,
    )
    dataset_dict[key] = WkwData.init_from_config(cur_config, cur_source)

# All datasets must contain the same number of examples
dataset_lengths = [len(d) for d in dataset_dict.values()]
assert len(set(dataset_lengths)) <= 1

# Split the full index range into chunks of 1000 examples
range_size = 1000
list_ranges = annotation.divide_range(
    total_size=len(dataset_dict['large']),
    chunk_size=range_size,
)

In [3]:
from functools import partial
import ipywidgets as widgets
from IPython.display import display,clear_output

In [4]:
# Parameters
# Annotated classes and the captions of the answer buttons offered for each of them
target_classes = ['Myelin', 'Debris']
button_names = [['No', 'Yes'], ['No', 'Yes']]
# Annotations to save: one (index, labels) pair per example of the first range,
# every label starting out as None (= not annotated yet)
annotation_list = [(index, dict.fromkeys(target_classes)) for index in list_ranges[0]]
# Global variable holding the index of the example currently shown
current_index = 0
# Output area that captures everything printed/displayed by the callbacks
out_layout = {
    'border': '1px solid black',
    'width': '50%',
    'align_self': 'center',
}
out = widgets.Output(layout=out_layout)


def annotation_fun(i):
    """Call ``annotation.display_example`` for example ``i`` of the 'large' dataset."""
    return annotation.display_example(i,
                                      dataset=dataset_dict['large'],
                                      margin=margin,
                                      roi_size=roi_size)


# Render the first example inside the output area
with out:
    annotation_fun(current_index)

def set_annotation(button_value, target_class):
    """Store ``button_value`` under ``target_class`` for the current example.

    The labels of the example at ``current_index`` are updated in place and
    then printed into the ``out`` widget.
    """
    _, labels = annotation_list[current_index]
    labels[target_class] = button_value
    with out:
        print(labels)

def get_button(button_name, target_class):
    """Create a 'Yes'/'No' button that records an annotation when clicked.

    Args:
        button_name: Caption of the button, either 'Yes' or 'No'. A 'Yes'
            button stores 1.0, a 'No' button stores 0.0.
        target_class: Annotation key handed to ``set_annotation``
            (e.g. 'Myelin' or 'Debris').

    Returns:
        The ``widgets.Button`` with its click handler attached.

    Raises:
        AssertionError: If ``button_name`` is neither 'Yes' nor 'No'.
    """
    # The check concerns the button caption, not the target class
    assert button_name in ('Yes', 'No'), \
        f"Button name should be either 'Yes' or 'No', got {button_name!r}"
    button_value = 1.0 if button_name == 'Yes' else 0.0
    button = widgets.Button(description=button_name, disabled=False)

    def on_click(b):
        # The clicked button itself is not needed: value and class come from the closure
        set_annotation(button_value, target_class)

    button.on_click(on_click)
    return button

# Build one centered row per target class: its label followed by its answer buttons
target_buttons = []
for index, target in enumerate(target_classes):
    cur_target_buttons = [get_button(b_name, target) for b_name in button_names[index]]
    row_children = [widgets.Label(target+': '), *cur_target_buttons]
    row = widgets.HBox(row_children,
                       layout=widgets.Layout(justify_content='center'))
    target_buttons.append(row)
# Previous/Next button
def show_next(btn, relative_pos=0):
    """Move ``relative_pos`` examples away from the current one and render it.

    Serves as click handler of the Previous/Next buttons (bound through
    ``functools.partial``) and is called with ``relative_pos=0`` to re-render
    the current example.

    Args:
        btn: Object passed by the caller (the clicked button); unused.
        relative_pos: Signed offset added to ``current_index``.
    """
    global current_index
    # Keep the index inside the slider bounds. The slider clamps its own value,
    # so an unclamped index would drift away from what the slider shows (and a
    # negative index would silently wrap around in ``annotation_list``).
    new_index = min(max(current_index + relative_pos, progress.min), progress.max)
    current_index = new_index
    if progress.value != new_index:
        # Changing the slider value notifies its 'value' observer, which
        # re-renders the example itself; rendering here as well would draw twice.
        progress.value = new_index
        return
    with out:
        clear_output(wait=True)
        annotation_fun(current_index)

# Navigation row: step one example backwards (b1) or forwards (b2)
b1 = widgets.Button(description="Previous")
b1.on_click(partial(show_next, relative_pos=-1))
b2 = widgets.Button(description="Next")
b2.on_click(partial(show_next, relative_pos=1))
control_buttons = widgets.HBox(
    [widgets.Label('Controlers: '), b1, b2],
    layout=widgets.Layout(justify_content='center'),
)

# Slider over the indices of the first range, used to show/select the current example
first_range = list_ranges[0]
progress = widgets.IntSlider(
    value=None,  # the value is assigned right after construction
    min=min(first_range),
    max=max(first_range),
    step=1,
    description='Index:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    layout=widgets.Layout(width='100%', justify_content='center'),
)

# Start the slider at the example that is currently shown
progress.value = current_index
def on_change(v):
    """Slider observer: adopt the slider's new value and re-render the example.

    Args:
        v: Change notification; only its ``'new'`` entry is read.
    """
    global current_index
    current_index = v['new']
    # Zero offset: just redraw the example at the freshly adopted index
    show_next(0, relative_pos=0)


# React only to changes of the slider's value
progress.observe(on_change, names='value')


# Stack slider, per-class button rows, navigation row and output area, then show them
selection = widgets.VBox([progress, *target_buttons, control_buttons, out])
display(selection)


VBox(children=(IntSlider(value=0, continuous_update=False, description='Index:', layout=Layout(justify_content…

In [42]:
# Create the annotation module's widget for the first index range of the
# 'large' dataset and show it
w = annotation.Widget(
    dataset=dataset_dict['large'],
    index_range=list_ranges[0],
)
w.show_widget()

VBox(children=(IntSlider(value=0, continuous_update=False, description='Index:', layout=Layout(justify_content…

In [36]:
# Set the widget's current index directly.
# NOTE(review): whether this also redraws the widget depends on annotation.Widget - confirm
w.current_index = 100

In [None]:
# Scratch: stack four plain No/Yes buttons vertically
# (the stray closing parenthesis that made this cell a SyntaxError is removed)
button_descriptions = ['No', 'Yes', 'No', 'Yes']
widgets.VBox([widgets.Button(description=desc, disabled=False)
              for desc in button_descriptions])

In [None]:
# Scratch: map each target class to a horizontal row holding its No/Yes buttons
# (the HBox call was never closed, so the dict comprehension did not parse)
button_descriptions = [['No', 'Yes'], ['No', 'Yes']]
target_classes = ['Myelin', 'Debris']
{target: widgets.HBox([widgets.Button(description=desc, disabled=False)
                       for desc in button_descriptions[index]])
 for index, target in enumerate(target_classes)}