In [None]:
# Import Torch & Models requirements
import time
from pathlib import Path

import cv2
import torch
import torchvision

#!pip install ultralytics
from ultralytics import YOLO

In [None]:
def setup_device():
    """Choose the torch device used for inference.

    Returns:
        torch.device: ``cuda`` when CUDA is available (GPU 0 is made current
        and the cuDNN autotuner is switched on first), otherwise ``cpu``.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")

    torch.cuda.set_device(0)  # make the first GPU the current device
    # Let cuDNN benchmark convolution algorithms; can speed up repeated
    # inference when input shapes stay constant.
    torch.backends.cudnn.benchmark = True
    return torch.device("cuda")


device = setup_device()
print("device: ", device)

In [None]:
# COCO 80 classes. List position doubles as the numeric class ID, so the
# order must match the model's class indices.
class_names_80 = [
    # 0-9
    "person", "bicycle", "car", "motorcycle", "airplane",
    "bus", "train", "truck", "boat", "traffic light",
    # 10-19
    "fire hydrant", "stop sign", "parking meter", "bench", "bird",
    "cat", "dog", "horse", "sheep", "cow",
    # 20-29
    "elephant", "bear", "zebra", "giraffe", "backpack",
    "umbrella", "handbag", "tie", "suitcase", "frisbee",
    # 30-39
    "skis", "snowboard", "sports ball", "kite", "baseball bat",
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
    # 40-49
    "wine glass", "cup", "fork", "knife", "spoon",
    "bowl", "banana", "apple", "sandwich", "orange",
    # 50-59
    "broccoli", "carrot", "hot dog", "pizza", "donut",
    "cake", "chair", "couch", "potted plant", "bed",
    # 60-69
    "dining table", "toilet", "tv", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "microwave", "oven",
    # 70-79
    "toaster", "sink", "refrigerator", "book", "clock",
    "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
]

In [None]:
class YOLOConfig:
    """Bundle a loaded YOLO model with the thresholds used when running it.

    Args:
        model_name: Checkpoint file stem, e.g. ``"yolov8s"``; weights are
            loaded from ``{weights_dir}/{model_name}.pt``.
        device: Device the loaded model is moved to via ``.to(device)``.
        conf_thres: Confidence threshold later passed as ``conf=`` to
            inference (default 0.6, the previous hard-coded value).
        iou_thres: IoU threshold later passed as ``iou=`` to inference
            (default 0.3, the previous hard-coded value).
        weights_dir: Directory holding the ``.pt`` checkpoints.
    """

    def __init__(self, model_name, device, conf_thres=0.6, iou_thres=0.3,
                 weights_dir="../checkpoints/model_weights"):
        self.model_name = model_name
        # Kept as a "/"-joined string so the default path is unchanged.
        self.model_path = "{}/{}.pt".format(weights_dir, model_name)
        self.conf_thres = conf_thres
        self.iou_thres = iou_thres
        self.model = YOLO(self.model_path).to(device)

# Load the "yolov8s" checkpoint onto the device chosen by setup_device().
yolo_handler = YOLOConfig(model_name="yolov8s", device=device)

In [None]:
# Allowed classes
class AllowedClasses:
    """Translate human-readable class names into the numeric IDs a model uses.

    Args:
        model_name: Model identifier. Any name containing ``"yolo"``
            (case-insensitive) uses the 80-class COCO list ``class_names_80``.

    Raises:
        ValueError: If ``model_name`` does not identify a supported model.
    """

    def __init__(self, model_name):
        if "yolo" not in model_name.lower():
            # Previously print + exit(0): that reported success for an error
            # and stopped the whole interpreter/kernel. Raise instead.
            raise ValueError("Model not supported: {!r}".format(model_name))
        self.class_names = class_names_80

    def get_allowed_classes(self, names=("car", "truck")):
        """Return the class IDs for ``names`` (default: car and truck).

        The result is also stored on ``self.allowed_classes``.

        Args:
            names: Iterable of class names; each must be in ``self.class_names``.

        Returns:
            list[int]: Positions of ``names`` in ``self.class_names``, in the
            same order as ``names``.

        Raises:
            ValueError: If any name is not a known class.
        """
        names = list(names)
        unknown = [name for name in names if name not in self.class_names]
        if unknown:
            raise ValueError("Unknown class name(s): {}".format(unknown))

        self.allowed_classes = [self.class_names.index(name) for name in names]
        return self.allowed_classes
    
# Numeric class IDs used to filter tracking results (car and truck).
allowed_classes = AllowedClasses(model_name="yolo").get_allowed_classes()

In [None]:
# Run tracking on a video, optionally dump per-frame results, and display
# annotated frames. Press "q" in the display window to stop early.
VIDEO_PATH = "../data/input/Video1.mp4"
PRINT_RESULTS = True  # set False to silence the per-frame result dump

cap = cv2.VideoCapture(VIDEO_PATH)
if not cap.isOpened():
    # Previously an unopenable video silently produced no output.
    cap.release()
    raise OSError("Could not open video: {}".format(VIDEO_PATH))

# Title the window after the checkpoint actually loaded (e.g. "yolov8s"),
# instead of a hard-coded model name that did not match.
window_title = "{} Detection".format(Path(yolo_handler.model_path).stem)

try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break  # end of stream or read failure

        # Inference with thresholds and class filter
        results = yolo_handler.model.track(
            frame,
            conf=yolo_handler.conf_thres,
            iou=yolo_handler.iou_thres,
            classes=allowed_classes,
            persist=True,  # keep track IDs consistent across frames
        )
        result = results[0]  # a single frame in -> read the first Results

        if PRINT_RESULTS:
            print(result.boxes)  # Boxes object for bbox outputs
            print(result.boxes.cls)  # Class of detected objects
            print(result.boxes.conf)  # Confidence of detected objects
            print(result.boxes.xyxy)  # Bounding boxes in (x1, y1, x2, y2) format
            print(result.masks)  # Mask for segmentation model only
            print(result.keypoints)  # Keypoints for keypoint model only
            print(result.probs)  # Class probabilities for classification model only
            print(result.names)  # Class names
            print(result.orig_shape)  # Original image shape (height, width)
            print(result.speed)  # Per-stage timing (preprocess, inference, postprocess)

        # Draw detections
        annotated_frame = result.plot(line_width=1, font_size=0.8)
        cv2.imshow(window_title, annotated_frame)

        if cv2.waitKey(1) & 0xFF == ord("q"):
            break
finally:
    # Release the capture and close windows even if inference raises.
    cap.release()
    cv2.destroyAllWindows()
