In [22]:
# general imports
import matplotlib.pyplot as plt
import cv2
import numpy as np
from IPython.display import display, Image
import ipywidgets as widgets
import threading

In [23]:
# ml imports 
import torch
import torchvision
import torchaudio

In [24]:
import os
from ultralytics import YOLO

In [29]:
def run_v5_model(model):
    """Show a live webcam feed annotated by a YOLOv5 torch.hub model.

    Displays a "Stop" toggle button, then starts a background thread that
    repeatedly reads a frame from camera 0 and mirrors it horizontally. It runs
    ``model`` on the frame and replaces the displayed image with the annotated
    frame from ``results.render()``. Pressing Stop ends the loop, clears the
    image and releases the camera.

    Args:
        model: a model loaded with ``torch.hub.load("ultralytics/yolov5", ...)``.
            It is called on a single frame, and its result must provide a
            ``render()`` that returns a list of annotated images. This assumes a
            YOLOv5 ``Detections`` object (TODO: confirm for other hub models).
    """
    # Stop button
    # ================
    stop_button = widgets.ToggleButton(
        value=False,
        description="Stop",
        disabled=False,
        button_style="danger",  # 'success', 'info', 'warning', 'danger' or ''
        tooltip="Stop the webcam stream",
        icon="square",  # (FontAwesome names without the `fa-` prefix)
    )

    # Display function
    # ================
    def view(button):
        """Capture -> inference -> display loop; runs until ``button`` is toggled on."""
        cap = cv2.VideoCapture(0)
        display_handle = display(None, display_id=True)
        try:
            # Check the button before every frame so pressing Stop exits the loop
            # instead of reading from an already-released camera.
            while not button.value:
                ok, frame = cap.read()
                if not ok:
                    # The camera could not be opened or stopped delivering frames.
                    # `frame` is None here, so there is nothing to process.
                    display_handle.update("Could not read a frame from camera 0; stopping.")
                    return
                frame = cv2.flip(frame, 1)  # mirror horizontally (if your camera reverses your image)

                results = model(frame)
                annotated = np.squeeze(results.render()[0])

                _, jpeg = cv2.imencode(".jpeg", annotated)
                display_handle.update(Image(data=jpeg.tobytes()))
            # Stopped by the user: clear the last frame.
            display_handle.update(None)
        finally:
            # Always hand the camera back, even if inference raised.
            cap.release()

    # Run
    # ================
    display(stop_button)
    # daemon=True so a still-running stream does not block interpreter shutdown.
    thread = threading.Thread(target=view, args=(stop_button,), daemon=True)
    thread.start()

# Reference: https://abauville.medium.com/display-your-live-webcam-feed-in-a-jupyter-notebook-using-opencv-d01eb75921d1

In [None]:
# Try run_v5_model with version 1 of the attention model. These are custom
# YOLOv5 weights trained on 75 focused and 75 unfocused images.
# force_reload=True makes torch.hub download the ultralytics/yolov5 repo again
# instead of reusing its local cache.
attention_model_v1 = torch.hub.load(
    "ultralytics/yolov5", "custom",
    path="model_iterations/v1/weights/best.pt", force_reload=True,
)
run_v5_model(attention_model_v1)

In [25]:
def run_v8_model(model, threshold=0.5):
    """Show a live webcam feed with boxes drawn from an ultralytics YOLO model.

    Displays a "Stop" toggle button, then starts a background thread that
    repeatedly reads a frame from camera 0 and mirrors it horizontally. It runs
    ``model`` on the frame and draws every detection whose score is above
    ``threshold``: a rectangle plus a "<class id> <CLASS NAME> <score>" label.
    The displayed image is then updated. Pressing Stop ends the loop, clears
    the image and releases the camera.

    Args:
        model: an ``ultralytics.YOLO`` model. ``model(frame)[0]`` must expose
            ``boxes.data`` rows of ``x1, y1, x2, y2, score, class_id`` and a
            ``names`` lookup from class index to label.
        threshold: detections with ``score <= threshold`` are not drawn.
            Defaults to 0.5.
    """
    # Box/label colours in OpenCV's BGR order, keyed by class index. The original
    # notes say 16 = focused and 15 = distracted. These are presumably indices in
    # the v4 dataset's class list; verify against `results.names`.
    class_colours = {
        16: (20, 200, 0),  # green
        15: (0, 0, 255),   # red
    }
    default_colour = (255, 255, 0)  # cyan, for any other class

    # Stop button
    # ================
    stop_button = widgets.ToggleButton(
        value=False,
        description="Stop",
        disabled=False,
        button_style="danger",  # 'success', 'info', 'warning', 'danger' or ''
        tooltip="Stop the webcam stream",
        icon="square",  # (FontAwesome names without the `fa-` prefix)
    )

    def draw_detections(frame, results):
        """Draw boxes and labels for detections above ``threshold`` onto ``frame`` in place."""
        for x1, y1, x2, y2, score, class_id in results.boxes.data.tolist():
            if score <= threshold:
                continue
            # tolist() yields floats; use an int for lookups and the label.
            cls = int(class_id)
            colour = class_colours.get(cls, default_colour)
            cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), colour, 4)
            cv2.putText(
                frame,
                f"{cls} {results.names[cls].upper()} {round(score, 4)}",
                (int(x1), int(y1 - 10)),  # just above the box's top-left corner
                cv2.FONT_HERSHEY_SIMPLEX,
                1.3,
                colour,
                3,
                cv2.LINE_AA,
            )

    # Display function
    # ================
    def view(button):
        """Capture -> inference -> display loop; runs until ``button`` is toggled on."""
        cap = cv2.VideoCapture(0)
        display_handle = display(None, display_id=True)
        try:
            # Check the button before every frame so pressing Stop exits the loop
            # instead of reading from an already-released camera.
            while not button.value:
                ok, frame = cap.read()
                if not ok:
                    # The camera could not be opened or stopped delivering frames.
                    # `frame` is None here, so there is nothing to process.
                    display_handle.update("Could not read a frame from camera 0; stopping.")
                    return
                frame = cv2.flip(frame, 1)  # mirror horizontally (if your camera reverses your image)

                # verbose=False stops ultralytics from logging one line per frame
                # into the notebook output.
                results = model(frame, verbose=False)[0]
                draw_detections(frame, results)

                _, jpeg = cv2.imencode(".jpeg", frame)
                display_handle.update(Image(data=jpeg.tobytes()))
            # Stopped by the user: clear the last frame.
            display_handle.update(None)
        finally:
            # Always hand the camera back, even if inference raised.
            cap.release()

    # Run
    # ================
    display(stop_button)
    # daemon=True so a still-running stream does not block interpreter shutdown.
    thread = threading.Thread(target=view, args=(stop_button,), daemon=True)
    thread.start()

In [None]:
# Try run_v8_model with the v4 weights, loaded through ultralytics' YOLO class.
attention_model_v4 = YOLO(
    "model_iterations/v4/weights/best.pt"
)
run_v8_model(attention_model_v4)