In [1]:
import os
import time
import threading

import torch
import numpy as np
import cv2 as cv
from ultralytics import YOLO
import ipywidgets as widgets
from IPython.display import display
print(torch.backends.mps.is_available())

True


In [2]:
video_dir = 'data/license/videos'
# model = YOLO('yolov8m.pt')
model = YOLO('/Users/eric/Desktop/2-Career/Projects/ObjectDetectionLL/runs/detect/train/weights/best.pt')
video_files = sorted([f for f in os.listdir(video_dir) if f.endswith('.mp4')])
video_file = f'{video_dir}/{video_files[0]}'

In [3]:
capture = cv.VideoCapture(video_file)
RED = (0, 0, 255)
GREEN = (0, 255, 0)
BLUE = (255, 0, 0)
fps = int(capture.get(cv.CAP_PROP_FPS))

# Global variables
keep_playing = True
is_paused = False

# Create widgets
image_widget = widgets.Image(format='jpeg', height=800, width=800)
play_pause_button = widgets.ToggleButton(
    value=False, description='Play/Pause', icon='play')
stop_button = widgets.Button(description="Exit (rerun code)", icon='stop')
restart_button = widgets.Button(description="Restart", icon='refresh')
rewind_3s_button = widgets.Button(description="3s", icon='fast-backward')
rewind_1s_button = widgets.Button(description="1s", icon='backward')
fastforward_1s_button = widgets.Button(description="1s", icon='forward')
fastforward_3s_button = widgets.Button(description="3s", icon='fast-forward')


def display_frame(frame):
    results = model(frame, device='mps', verbose=False)
    result = results[0]
    bboxes = np.array(result.boxes.xyxy.cpu(), dtype="int")
    classes = np.array(result.boxes.cls.cpu(), dtype="int")

    # create a dict of colors for each class
    colors = {
        '0': (0, 0, 255),
        '1': (0, 255, 0),
        '2': (255, 0, 0),
        '3': (0, 255, 255),
        '4': (255, 0, 255),
        '5': (255, 255, 0),
        '6': (255, 255, 255),
        '7': (0, 0, 128),
        '8': (0, 128, 0),
        '9': (128, 0, 0),
        '10': (0, 128, 128),
        '11': (128, 0, 128),
        '12': (128, 128, 0),
        '13': (128, 128, 128),
        '14': (0, 0, 64),
        '15': (0, 64, 0),
        '16': (64, 0, 0),
        '17': (0, 64, 64),
        '18': (64, 0, 64),
        '19': (64, 64, 0),
        '20': (64, 64, 64),
    } 
    for classification, bbox in zip(classes, bboxes):
        x1, y1, x2, y2 = bbox
        cv.rectangle(frame, (x1, y1), (x2, y2), colors[classification], 2)
        cv.putText(frame, str(classification), (x1, y1 - 5),
                   cv.FONT_HERSHEY_SIMPLEX, 0.5, colors[classification], 2)

    _, frame_data = cv.imencode('.jpeg', frame)
    image_widget.value = frame_data.tobytes()
    time.sleep(0.1)


def play():
    global keep_playing, is_paused
    keep_playing = True
    while keep_playing:
        if not is_paused:
            ret, frame = capture.read()
            if not ret:
                break
            display_frame(frame)
        else:
            time.sleep(0.1)


def pause(b):
    global is_paused
    is_paused = True


def resume(b):
    global is_paused
    is_paused = False


def stop(b):
    global keep_playing
    keep_playing = False
    capture.release()


def restart(b):
    global is_paused
    is_paused = True
    capture.set(cv.CAP_PROP_POS_FRAMES, 0)
    is_paused = False


def rewind_1s(b):
    global is_paused
    current_state = is_paused  # Save the current state
    is_paused = True           # Pause the video temporarily
    current_time = capture.get(cv.CAP_PROP_POS_MSEC)
    new_time = max(0, current_time - 1000)
    capture.set(cv.CAP_PROP_POS_MSEC, new_time)
    if current_state:
        ret, frame = capture.read()
        if ret:
            display_frame(frame)
    is_paused = current_state  # Restore the original state


def rewind_3s(b):
    global is_paused
    current_state = is_paused  # Save the current state
    is_paused = True           # Pause the video temporarily
    current_time = capture.get(cv.CAP_PROP_POS_MSEC)
    new_time = max(0, current_time - 3000)
    capture.set(cv.CAP_PROP_POS_MSEC, new_time)
    if current_state:
        ret, frame = capture.read()
        if ret:
            display_frame(frame)
    is_paused = current_state  # Restore the original state


def fastforward_1s(b):
    global is_paused
    current_state = is_paused  # Save the current state
    is_paused = True           # Pause the video temporarily
    current_time = capture.get(cv.CAP_PROP_POS_MSEC)
    total_time = capture.get(cv.CAP_PROP_FRAME_COUNT) * 1000 / fps
    new_time = min(total_time, current_time + 3000) 
    capture.set(cv.CAP_PROP_POS_MSEC, new_time)
    if current_state:
        ret, frame = capture.read()
        if ret:
            display_frame(frame)
    is_paused = current_state  # Restore the original state

def fastforward_3s(b):
    global is_paused
    current_state = is_paused  # Save the current state
    is_paused = True           # Pause the video temporarily
    current_time = capture.get(cv.CAP_PROP_POS_MSEC)
    total_time = capture.get(cv.CAP_PROP_FRAME_COUNT) * 1000 / fps
    new_time = min(total_time, current_time + 3000) 
    capture.set(cv.CAP_PROP_POS_MSEC, new_time)
    if current_state:
        ret, frame = capture.read()
        if ret:
            display_frame(frame)
    is_paused = current_state  # Restore the original state


# Event Handlers
play_pause_button.observe(lambda change: resume(
    None) if change.new else pause(None), 'value')
stop_button.on_click(stop)
restart_button.on_click(restart)
rewind_3s_button.on_click(rewind_3s)
rewind_1s_button.on_click(rewind_1s)
fastforward_1s_button.on_click(fastforward_1s)
fastforward_3s_button.on_click(fastforward_3s)

In [4]:
# Start video in a new thread
threading.Thread(target=play).start()
# threading.Thread(target=play)
# Display widgets
display(image_widget)
display(widgets.HBox([play_pause_button, rewind_3s_button, rewind_1s_button, fastforward_1s_button, fastforward_3s_button, restart_button, stop_button]))

Image(value=b'', format='jpeg', height='800', width='800')

HBox(children=(ToggleButton(value=False, description='Play/Pause', icon='play'), Button(description='3s', icon…

  elif mps and getattr(torch, 'has_mps', False) and torch.backends.mps.is_available() and TORCH_2_X:
