In [1]:
# Move two directories up (presumably the repository root) so that the
# `from src.extensions...` import in the next cell resolves.
%cd ../../

/Users/s01948/Documents/eval-detection


In [2]:
from jupyter_bbox_widget import BBoxWidget
import numpy as np
from src.extensions.metrics.ot_cost import get_ot_cost, get_cmap
import ipywidgets as widgets
%matplotlib inline

In [3]:
# Remote demo image shown in the annotation widget (needs network access).
SAMPLE_IMG_URL = "http://farm8.staticflickr.com/7162/6767429191_69b495e08c_z.jpg"

In [4]:
from IPython.display import display


CLASS_LABELS = ["apple", "banana", "orange", "cup"]
n_class = len(CLASS_LABELS)

# Prediction classes come first (drawn green), then one "GT <class>"
# ground-truth class per label (drawn orange). Classes and colors are
# presumably paired by position by BBoxWidget.
bbox_widget = BBoxWidget(
    image=SAMPLE_IMG_URL,
    classes=CLASS_LABELS + [f"GT {name}" for name in CLASS_LABELS],
    colors=["green"] * n_class + ["orange"] * n_class,
    hide_buttons=True,
)

# Slider attached under the name "confidence"; format_bboxes reads it back
# as bbox["confidence"] for predicted boxes.
w_conf = widgets.FloatSlider(
    value=0.5,
    min=0,
    max=1.,
    description='Confidence',
)
bbox_widget.attach(w_conf, name="confidence")

def _bbox_class_name(label):
    """Return the bare class name encoded in a widget bbox label.

    Labels produced in this notebook look like ``"apple"`` (fresh box),
    ``"GT apple"`` (ground truth) or ``"apple|0.5"`` (prediction relabelled
    with its confidence). The ``"GT "`` prefix and any ``"|..."`` suffix are
    stripped.
    """
    if label.startswith("GT "):
        label = label[len("GT "):]
    return label.split("|", 1)[0]


def format_bboxes(bboxes, classes, return_orders=False):
    """Group widget bboxes by class into ``(N, 5)`` float32 arrays.

    Args:
        bboxes: list of bbox dicts (as in ``bbox_widget.bboxes``) with keys
            ``x``, ``y``, ``width``, ``height``, ``label`` and, for non-GT
            boxes, ``confidence``.
        classes: class names; one output array is produced per entry, in order.
        return_orders: if True, also return the indices into ``bboxes`` in the
            order the boxes appear across the per-class arrays.

    Returns:
        A list with one ``(N_c, 5)`` array per class whose rows are
        ``[x, y, x + width, y + height, score]`` (score is 1 for boxes whose
        label starts with "GT"). With ``return_orders`` it returns
        ``(formatted_bboxes, orders)``.
    """
    # Parse each label once and match exactly: substring matching would put a
    # box into several classes when one class name contains another
    # (e.g. "cup" vs "cupcake").
    class_names = [_bbox_class_name(bbox["label"]) for bbox in bboxes]
    orders = []
    formatted_bboxes = []
    for label in classes:
        rows = []
        for i, (bbox, name) in enumerate(zip(bboxes, class_names)):
            if name != label:
                continue
            # GT boxes are certain; predictions carry their attached confidence.
            conf = 1 if bbox["label"].startswith("GT") else bbox["confidence"]
            rows.append([bbox["x"], bbox["y"],
                         bbox["x"] + bbox["width"], bbox["y"] + bbox["height"],
                         conf])
            orders.append(i)
        # reshape keeps classes without boxes as (0, 5) rather than (0,)
        formatted_bboxes.append(np.asarray(rows, dtype=np.float32).reshape(-1, 5))
    if return_orders:
        return formatted_bboxes, orders
    return formatted_bboxes
    
def evaluate_bboxes(alpha=0.5, beta=0.6):
    """Compute the OC-cost between the GT and predicted boxes in ``bbox_widget``.

    Boxes whose label starts with "GT" are treated as ground truth, all others
    as predictions. Both sets are grouped per class in ``CLASS_LABELS`` with
    ``format_bboxes`` and passed to ``get_ot_cost``.

    Args:
        alpha, beta: forwarded to ``get_cmap`` when building the cost matrix;
            the defaults are the values previously hard-coded here.

    Returns:
        ``(otc, log)`` as returned by ``get_ot_cost(..., return_matrix=True)``.
    """
    # Snapshot once so GT and predictions come from the same widget state.
    all_bboxes = bbox_widget.bboxes
    gt_bboxes = format_bboxes(
        [b for b in all_bboxes if b["label"].startswith("GT")], CLASS_LABELS)
    bboxes = format_bboxes(
        [b for b in all_bboxes if not b["label"].startswith("GT")], CLASS_LABELS)

    def cmap_func(x, y):
        return get_cmap(x, y, alpha=alpha, beta=beta)

    otc, log = get_ot_cost(gt_bboxes, bboxes, cmap_func, return_matrix=True)
    return otc, log

# Output area where on_bbox_change prints the current OC-cost.
w_out = widgets.Output()

def update_label_conf():
    """Write the confidence into the labels of the selected box's class.

    Every non-GT box whose label contains the selected box's class name is
    relabelled as ``"<class>|<confidence>"``. Nothing happens when the
    selected box is a GT box, when the canvas is empty, when the selection
    index is out of range, or when the label matches no class in
    ``CLASS_LABELS``.
    """
    all_bboxes = bbox_widget.bboxes
    sel = bbox_widget.selected_index
    # Indexing an empty list or a stale selection (e.g. after deleting boxes)
    # would raise inside the observer. Negative indices still count from the
    # end, as before.
    if not -len(all_bboxes) <= sel < len(all_bboxes):
        return
    cur_label = all_bboxes[sel]["label"]

    if cur_label.startswith("GT"):
        return

    for c_name in CLASS_LABELS:
        if c_name in cur_label:
            break
    else:
        # Unknown label: don't fall through with the last class in the list.
        return

    # re-label bboxes of c_name class
    for i, b in enumerate(all_bboxes):
        if b["label"].startswith("GT") or c_name not in b["label"]:
            continue
        new_label = f"{c_name}|{b['confidence']}"
        # Skip no-ops; each real update presumably re-fires the `bboxes` observer.
        if new_label != b["label"]:
            bbox_widget._set_bbox_property(i, "label", new_label)

def on_bbox_change(change):
    """Observer for ``bbox_widget.bboxes``: refresh labels and the OC-cost readout.

    Args:
        change: traitlets change notification (unused).
    """
    # Exceptions raised in widget callbacks are not shown in the notebook.
    # Running inside the Output context renders any traceback in w_out instead
    # of leaving a silently blank readout.
    with w_out:
        update_label_conf()
        w_out.clear_output(wait=True)
        otc, _ = evaluate_bboxes()
        print(f"OC-cost: {otc:.3f}")
        
# Recompute labels and the OC-cost whenever the `bboxes` trait changes.
bbox_widget.observe(on_bbox_change, names=['bboxes'])

# Annotator on top, confidence slider below it, OC-cost readout last.
w_container = widgets.VBox([bbox_widget, w_conf, w_out])
display(w_container)

VBox(children=(BBoxWidget(classes=['apple', 'banana', 'orange', 'cup', 'GT apple', 'GT banana', 'GT orange', '…