In [1]:
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO

# Load the yolov8n detection checkpoint from the project's local models folder.
# NOTE(review): the path has no file extension — presumably ultralytics resolves it to
# "../MODELs/yolov8n.pt"; confirm the checkpoint file name on disk.
MODEL_PATH = "../MODELs/yolov8n"
model = YOLO(MODEL_PATH)

In [3]:
# Font used to render the class/confidence labels on top of each bounding box.
LABEL_FONT_PATH = "../FONTs/STHeiti Light.ttc"
LABEL_FONT_SIZE = 16
font = ImageFont.truetype(LABEL_FONT_PATH, size=LABEL_FONT_SIZE)
# Half-open class-ID ranges [start, end) -> "<box colour>_<text colour>" hex pair.
# The two colours are packed into one key and joined by "_".
bg_color_ranges = {
    "#FF0000_#181818": [0, 1],
    "#FF9900_#181818": [1, 14],
    "#341A36_#FFFFFF": [14, 24],
    "#00C036_#181818": [24, 80],
}

# color_labels[class_id] -> colour pair for classes 0..79.
# When ranges overlap, the first matching range (dict order) wins. An ID covered by
# no range gets no entry at all, which would shift every later class's colour.
color_labels = []
for class_id in range(80):
    matching = [
        pair
        for pair, (start, end) in bg_color_ranges.items()
        if start <= class_id < end
    ]
    if matching:
        color_labels.append(matching[0])

In [21]:
CONF_THRESHOLD = 0.32  # detections whose 2-decimal confidence is below this are not drawn
BOX_WIDTH = 3          # outline width of each bounding box, in pixels
LABEL_HEIGHT = 16      # vertical space reserved above a box for its label (font is 16 pt)

img_path = "test_image.jpg"
# Normalise to RGB so the hex annotation colours render faithfully: grayscale, palette
# or CMYK JPEGs would otherwise get grey or remapped boxes. The `with` block closes the
# underlying file once the pixels have been decoded by convert().
with Image.open(img_path) as src_img:
    img = src_img.convert("RGB")

results = model(img, verbose=False)  # run inference on the single image
result = results[0]                  # results for that image

boxes = result.boxes.xyxy          # one [x1, y1, x2, y2] row per detection
cls = result.boxes.cls.tolist()    # class IDs; cast with int() before indexing
conf = result.boxes.conf.tolist()  # confidence score per detection
names = result.names               # class ID -> class name

draw = ImageDraw.Draw(img)
for xyxy, raw_class_id, score in zip(boxes.tolist(), cls, conf):
    # Threshold on the 2-decimal score so the cut-off agrees with the value displayed.
    if round(score, 2) < CONF_THRESHOLD:
        continue

    class_id = int(raw_class_id)
    box_color, text_color = color_labels[class_id].split("_")
    x1, y1, x2, y2 = xyxy

    # round(score * 100) rather than int(round(score, 2) * 100): the latter truncates
    # float artefacts, e.g. 0.57 * 100 == 56.99999999999999 would be shown as "56%".
    label = f"{names.get(class_id, 'Unknown')} {round(score * 100)}%"

    draw.rectangle([x1, y1, x2, y2], outline=box_color, width=BOX_WIDTH)

    # The label strip sits just above the box, clamped to the image top. It is sized from
    # the text's measured extent instead of a per-character width guess. Its bottom also
    # never stops short of the text when the box touches the top edge.
    label_y = max(y1 - LABEL_HEIGHT, 0)
    _, _, text_right, text_bottom = draw.textbbox((x1 + 2, label_y), label, font=font)
    draw.rectangle([x1, label_y, text_right + 2, max(y1, text_bottom)], fill=box_color)
    draw.text((x1 + 2, label_y), label, fill=text_color, font=font)

# NOTE(review): Image.show() opens an external viewer; inside Jupyter, `display(img)`
# would render inline instead — kept as-is to preserve current behaviour.
img.show()