In [1]:
import dataclasses
import json
from dataclasses import dataclass
from os import path
from typing import Optional

import ipywidgets as widgets
import numpy as np
from IPython.display import display
from ipycanvas import Canvas, hold_canvas
from ipyevents import Event
from ipywidgets import Image

from juggle_tracker.labels import BallState, BallSequence, VideoLabels

In [2]:
# Design notes for the labeler UI. Bound to a named constant rather than `_`:
# assigning to `_` shadows IPython's last-output variable and stops IPython
# from updating it.
UI_DESIGN_NOTES = """
Labeler UI design.

Sidebar to select a ball or create a new ball.
Clicking a ball's button in the sidebar selects it.
Creating a new ball creates a ball at the center of the current frame.

The selected ball is orange.
Other balls in the frame are gray.

A/D keys go to prev/next frames.
If the selected ball ends at the current frame, moving past that end extends
it into the prev/next frame by copying its boundary state.

Clicking on the frame repositions the selected ball.

There is a textbox for quickly jumping to a numbered frame. (not done yet)

S toggles the selected ball between freefall and held label.
"""

In [14]:
# The state of the labeling UI.
@dataclass
class LabelingState:
    """Mutable state of the labeling UI.

    Holds the labels being edited, the UI cursor (current frame and selected
    ball), and the paths used to load frame images and save the labels.
    """

    # The current labels for the video.
    labels: 'VideoLabels'

    # The current frame being displayed by the UI.
    current_frame: int

    # The id of the ball that will be affected by the next click/keypress,
    # or `None` if no ball is selected.
    current_ball: Optional[str]

    # The directory with the images, e.g. 'data/cap3/img'.
    imgs_dir: str

    # The file where we load/save the labels, e.g. 'data/cap3/labels_1.txt'.
    labels_file: str

    def img_path(self):
        """Return the image path for `current_frame`, e.g. '<imgs_dir>/007.png'."""
        return path.join(self.imgs_dir, '%03d.png' % self.current_frame)

    def current_ball_state(self):
        """Return the selected ball's state in the current frame, or None.

        Returns None when no ball is selected, or when the selected ball has
        no state in the current frame (the frame lies outside
        [start_frame, start_frame + len(states))).
        """
        if self.current_ball is None:
            return None
        ball = self.labels.balls[self.current_ball]
        index = self.current_frame - ball.start_frame
        # Check explicitly: a negative index would otherwise wrap around to the
        # end of `states` and silently edit another frame's state.
        if not 0 <= index < len(ball.states):
            return None
        return ball.states[index]

In [19]:
# Paths for this labeling session. Built with `path.join` instead of backslash
# literals: '\c' and '\l' are invalid escape sequences (DeprecationWarning,
# SyntaxWarning since Python 3.12), and backslash separators only work on Windows.
imgs_dir = path.join('data', 'cap3', 'img')
labels_file = path.join('data', 'cap3', 'labels_1.txt')

# Resume from the existing labels file if there is one; otherwise start empty.
try:
    labels = VideoLabels.load(labels_file)
    print("loading existing labels")
except FileNotFoundError:
    print("new labels")
    labels = VideoLabels(balls={})

labeling_state = LabelingState(
    labels=labels,
    current_frame=0,
    current_ball=None,
    imgs_dir=imgs_dir,
    labels_file=labels_file,
)

loading existing labels


In [20]:
# Build the UI: the frame canvas on the left, a sidebar of ball buttons on the right.
labeling_widget_canvas = Canvas(height=480, width=640)
labeling_widget_sidebar = widgets.VBox(children=[])
labeling_widget = widgets.HBox(children=[labeling_widget_canvas, labeling_widget_sidebar])
# Key presses anywhere on the labeling widget are routed to handle_dom_event below.
events = Event(watched_events=['keydown'], source=labeling_widget)
display(labeling_widget)

def handle_new_ball(x, y):
    """Create a one-frame ball at (x, y) in the current frame and select it.

    The new ball starts at the current frame with a single held
    (freefall=False) state, becomes the selected ball, and the UI is redrawn.
    """
    labels = labeling_state.labels
    new_id = labels.next_ball_id()
    first_state = BallState(position=np.array([x, y]), freefall=False)
    labels.balls[new_id] = BallSequence(
        start_frame=labeling_state.current_frame,
        states=[first_state],
        color='orange',
    )
    labeling_state.current_ball = new_id
    rerender()

def handle_reposition_ball(x, y):
    """Move the selected ball's state in the current frame to (x, y).

    Does nothing when `current_ball_state()` returns None (e.g. no ball is
    selected). Does not redraw; handle_mouse_down calls rerender() afterwards.
    """
    state = labeling_state.current_ball_state()
    if state is None:
        # Without this guard the assignment below raises AttributeError.
        return
    state.position = np.array([x, y])
        
# Output widget that collects prints and tracebacks from the event handlers
# decorated with `debug_view_event.capture` (displayed in a later cell).
debug_view_event = widgets.Output()

@debug_view_event.capture(clear_output=True)
def handle_mouse_down(x, y):
    """Canvas click: move the selected ball (if any) to (x, y), then redraw."""
    has_selection = labeling_state.current_ball is not None
    if has_selection:
        handle_reposition_ball(x, y)
    rerender()

labeling_widget_canvas.on_mouse_down(handle_mouse_down)

def handle_frame_change(delta):
    """Move `delta` frames (the key handler passes +1 / -1) and redraw.

    If the selected ball's sequence ends at the frame being left, in the
    direction of travel, the sequence is extended into the new frame with a
    copy of its boundary state. The move is refused, leaving all state
    unchanged, when the target frame is negative or has no image file. This
    way a ball is never extended into a frame that cannot be displayed.
    """
    old_frame = labeling_state.current_frame
    new_frame = old_frame + delta

    # Validate the target frame before touching any labels. Otherwise 'a' on
    # frame 0 would extend the ball to frame -1, and rerender() would then
    # fail to load '-01.png'.
    if new_frame < 0:
        print("Already at the first frame")
        return
    labeling_state.current_frame = new_frame
    if not path.exists(labeling_state.img_path()):
        # Past the last image: undo the move.
        labeling_state.current_frame = old_frame
        print("No image for frame " + str(new_frame))
        return

    # Extend the current ball into this frame.
    if labeling_state.current_ball is not None:
        ball = labeling_state.labels.balls[labeling_state.current_ball]
        last_frame = ball.start_frame + len(ball.states) - 1
        if delta == -1 and ball.start_frame == old_frame:
            print("Extending backwards to " + str(ball.start_frame - 1))
            ball.start_frame -= 1
            # dataclasses.replace passes existing field values through, so
            # mutable fields (e.g. `position`) are presumably shared with the
            # copied state unless BallState copies them.
            ball.states.insert(0, dataclasses.replace(ball.states[0]))
        elif delta == 1 and last_frame == old_frame:
            print("Extending forwards to " + str(last_frame + 1))
            ball.states.append(dataclasses.replace(ball.states[-1]))

    rerender()

def handle_toggle_freefall():
    """Flip the selected ball's freefall/held label in the current frame and redraw.

    Does nothing, and does not redraw, when `current_ball_state()` returns None
    (e.g. no ball is selected).
    """
    # Rely on the helper's None result instead of re-checking `current_ball`,
    # so any "no state here" case is handled, not only "nothing selected".
    state = labeling_state.current_ball_state()
    if state is None:
        return
    state.freefall = not state.freefall
    rerender()
    
@debug_view_event.capture(clear_output=True)
def handle_dom_event(event):
    """Keyboard shortcuts: 'd' next frame, 'a' previous frame, 's' toggle freefall/held."""
    if event['event'] != 'keydown':
        return
    # Lambdas so the handlers are looked up when the key is pressed, as before.
    key_actions = {
        'd': lambda: handle_frame_change(1),
        'a': lambda: handle_frame_change(-1),
        's': lambda: handle_toggle_freefall(),
    }
    action = key_actions.get(event['key'])
    if action is not None:
        action()

events.on_dom_event(handle_dom_event)

@debug_view_event.capture(clear_output=True)
def handle_button_click(ball_id):
    """Sidebar button callback: make `ball_id` the selected ball and redraw."""
    labeling_state.current_ball = ball_id
    rerender()
    
def _render_canvas():
    """Redraw the current frame's image and every ball present in that frame."""
    canvas = labeling_widget_canvas
    with hold_canvas(canvas):
        canvas.clear()
        frame = Image.from_file(labeling_state.img_path())
        canvas.draw_image(frame, 0, 0)

        # It doesn't display the first few strokes for some reason, so just draw some extra strokes to make it
        # display the important ones.
        for i in range(10):
            canvas.stroke_circle(i * 10, 0, 5)

        for ball_id, ball in labeling_state.labels.balls.items():
            sequence_index = labeling_state.current_frame - ball.start_frame
            if not 0 <= sequence_index < len(ball.states):
                continue  # This ball has no state in the current frame.
            state = ball.states[sequence_index]
            is_selected = ball_id == labeling_state.current_ball
            canvas.stroke_style = 'orange' if is_selected else 'gray'
            x, y = state.position[0], state.position[1]
            canvas.stroke_circle(x, y, 10)

            # Label with the ball id; ' H' marks a held (not freefall) state.
            label_text = str(ball_id)
            if not state.freefall:
                label_text += ' H'
            canvas.stroke_text(label_text, x, y)


def _render_sidebar():
    """Rebuild the sidebar: one select button per ball plus an 'Add ball' button."""
    rows = []
    for ball_id in labeling_state.labels.balls:
        button = widgets.Button(description='Ball ' + str(ball_id))
        # Bind ball_id as a default argument so each button keeps its own id.
        button.on_click(lambda _, ball_id=ball_id: handle_button_click(ball_id))
        rows.append(button)

    # New balls start at the center of the canvas (the frame is drawn at (0, 0)).
    center_x = labeling_widget_canvas.width // 2
    center_y = labeling_widget_canvas.height // 2
    new_ball_button = widgets.Button(description='Add ball')
    new_ball_button.on_click(lambda _: handle_new_ball(center_x, center_y))
    rows.append(new_ball_button)

    old_rows = labeling_widget_sidebar.children
    labeling_widget_sidebar.children = rows
    # Buttons are recreated on every rerender; close the replaced ones so they
    # do not accumulate in the kernel's widget registry.
    for old_button in old_rows:
        old_button.close()


def rerender():
    """Redraw the canvas and rebuild the sidebar from `labeling_state`."""
    _render_canvas()
    _render_sidebar()

rerender()

HBox(children=(Canvas(height=480, width=640), VBox()))

In [22]:
# Show the Output widget that collects prints and tracebacks from the event
# handlers decorated with `debug_view_event.capture`.
display(debug_view_event)

Output()

In [23]:
# Save the edited labels to `labeling_state.labels_file`, the same path they
# were loaded from. Run this cell after labeling.
# NOTE(review): presumably overwrites any existing file — confirm in VideoLabels.save.
labeling_state.labels.save(labeling_state.labels_file)