# YOLO experimentation

This notebook generates synthetic training data and trains an Ultralytics YOLO model to predict oriented bounding boxes (OBBs) around Pokémon cards in an image.

### notes

ONNX export: potentially faster inference, and the exported model can be used by many other libraries. See https://docs.ultralytics.com/integrations/onnx/


### setup

In [80]:
!pip install ultralytics

Defaulting to user installation because normal site-packages is not writeable


In [81]:
# imports
import math
import os
import random
import shutil
import urllib
import urllib.request
import zipfile

import tqdm
from PIL import Image, ImageEnhance, ImageStat
from ultralytics import YOLO


# Dataset creation

In [82]:
# configure dataset settings here

DATASET_PATH = "../datasets/YOLO_training"
OBJECTS_PATH = "../datasets/pokemon/data/images" # the objects you want the model to detect
OBJECTS_ASPECT_RATIO = 825 / 600  # card height / width (random_OBB uses w = h / ratio)
IMAGE_DIMENSION = 640  # side length (pixels) of the square generated images; also the training imgsz

# dataset generation settings
AMOUNT_TRAIN = 1000  # number of synthetic training images
AMOUNT_VAL = 25  # number of synthetic validation images
MAX_CARDS_PER_IMAGE = 6
ALLOW_OVERLAP = False  # if False, overlapping card placements are re-drawn (with a retry limit)
BLACK_AND_WHITE = False  # convert backgrounds and cards to grayscale
OBSTRUCTIONS = True  # randomly draw occluding shapes over some of the cards

# training settings
EPOCHS = 100

# reproducibility: the random choices made while generating the dataset go through
# Python's `random` module, so seeding it makes a run repeatable. File selection still
# depends on os.listdir order, which can differ between filesystems.
SEED = 42
random.seed(SEED)


### download images to use as backgrounds

In [83]:
# Download the COCO val2017 images once; they are used as random backgrounds.
url = "http://images.cocodataset.org/zips/val2017.zip"
output_path = DATASET_PATH + "/backgrounds/val2017.zip"
backgrounds_dir = DATASET_PATH + "/backgrounds/"
extracted_dir = os.path.join(DATASET_PATH, "backgrounds", "val2017")

# create directory (no-op if it already exists)
os.makedirs(backgrounds_dir, exist_ok=True)

if not os.path.exists(extracted_dir):
    # download and extract
    print("Downloading background images...")
    urllib.request.urlretrieve(url, output_path)
    print("Extracting")
    # Extract into a staging folder and move the result into place only after extraction
    # has finished. Otherwise an interrupted run leaves a partial val2017/ folder, and the
    # existence check above would skip the download from then on.
    # The archive is expected to contain a top-level val2017/ folder.
    staging_dir = os.path.join(backgrounds_dir, "_extracting")
    with zipfile.ZipFile(output_path, 'r') as zip_ref:
        zip_ref.extractall(staging_dir)
    os.replace(os.path.join(staging_dir, "val2017"), extracted_dir)
    shutil.rmtree(staging_dir, ignore_errors=True)

    # cleanup
    os.remove(output_path)

### Remove empty directories from objects path

This is necessary because card images are picked by descending into randomly chosen subdirectories, which fails if an empty directory is chosen.

In [84]:
# Remove all empty directories below OBJECTS_PATH, deepest first.
# os.walk builds each directory's `dirnames` before its children are visited, so that list
# is stale once an empty child has been deleted. Re-list the directory instead, so that a
# parent which only contained empty folders is removed as well.
for dirpath, _dirnames, _filenames in os.walk(OBJECTS_PATH, topdown=False):
    if os.path.normpath(dirpath) == os.path.normpath(OBJECTS_PATH):
        continue  # never delete the configured objects root itself
    if not os.listdir(dirpath):
        os.rmdir(dirpath)


### Utility functions

In [85]:
def load_random_background() -> str:
    """Pick one file at random from the extracted COCO val2017 folder and return its path."""
    folder = f"{DATASET_PATH}/backgrounds/val2017/"
    candidates = os.listdir(folder)
    return folder + random.choice(candidates)

def load_random_card(root: "str | None" = None) -> str:
    """Return the path of a randomly chosen card image below `root` (default: OBJECTS_PATH).

    Starting at `root`, a random entry is picked repeatedly, descending while the pick
    is a directory, so images may sit in nested subdirectories. If an empty directory
    is reached, the fallback card ``<root>/base1/base1-1.jpg`` is returned instead.
    """
    if root is None:
        root = OBJECTS_PATH

    rel_path = random.choice(os.listdir(root))
    while os.path.isdir(os.path.join(root, rel_path)):
        # list the directory we are about to pick from (relative to root, not to the CWD)
        entries = os.listdir(os.path.join(root, rel_path))
        if not entries:
            # return the first card that is downloaded in case all of the cards haven't been downloaded yet
            return os.path.join(root, "base1", "base1-1.jpg")
        rel_path = os.path.join(rel_path, random.choice(entries))
    return os.path.join(root, rel_path)

# oriented bounding box class
# includes overlap detection, and YOLO OBB format export
class OBB:
    """Rectangle of size w x h centred at (cx, cy) and rotated by `angle` radians.

    Coordinates are pixels of the generated image (y grows downwards). A point (x, y)
    in the box's local frame maps to
    (cx + x*cos(angle) - y*sin(angle), cy + x*sin(angle) + y*cos(angle)).
    """

    def __init__(self, cx: float, cy: float, w: float, h: float, angle: float):
        self.cx = cx  # center x
        self.cy = cy  # center y
        self.w = w    # width
        self.h = h    # height
        self.angle = angle  # rotation angle in radians

        # cached values reused by the corner and projection computations
        self.half_w = w / 2
        self.half_h = h / 2
        self.cos_a = math.cos(angle)
        self.sin_a = math.sin(angle)

    def get_YOLO_OBB_format(self, class_index: int, image_size: "float | None" = None) -> str:
        """Return a YOLO OBB label line ``class_index x1 y1 x2 y2 x3 y3 x4 y4``.

        Corner coordinates are divided by `image_size` (default: IMAGE_DIMENSION), so
        they lie in [0, 1] for a box inside a square image of that size. Corners are
        listed top-left, top-right, bottom-right, bottom-left in the box's own frame.
        """
        if image_size is None:
            image_size = IMAGE_DIMENSION

        # rotated half-extents
        hs = self.half_h * self.sin_a
        hc = self.half_h * self.cos_a
        ws = self.half_w * self.sin_a
        wc = self.half_w * self.cos_a

        # Top-left corner
        x1 = (self.cx - wc + hs) / image_size
        y1 = (self.cy - ws - hc) / image_size
        # Top-right corner
        x2 = (self.cx + wc + hs) / image_size
        y2 = (self.cy + ws - hc) / image_size
        # Bottom-right corner
        x3 = (self.cx + wc - hs) / image_size
        y3 = (self.cy + ws + hc) / image_size
        # Bottom-left corner
        x4 = (self.cx - wc - hs) / image_size
        y4 = (self.cy - ws + hc) / image_size

        return f"{class_index} {x1} {y1} {x2} {y2} {x3} {y3} {x4} {y4}"

    def _project_onto_axis(self, ax: float, ay: float) -> tuple:
        """Return ``(0, r)``, where r is this box's half-extent along the unit axis (ax, ay).

        The interval is relative to the box centre.
        """
        r = (self.half_w * abs(ax * self.cos_a + ay * self.sin_a) +
             self.half_h * abs(ax * -self.sin_a + ay * self.cos_a))
        return (0, r)

    def overlaps_with(self, other: 'OBB') -> bool:
        """Return True if this box and `other` intersect (touching counts as overlapping).

        Uses the Separating Axis Theorem: two rectangles are disjoint iff their
        projections are disjoint on one of the boxes' own axes.
        """
        tx = other.cx - self.cx
        ty = other.cy - self.cy

        axes = [
            (self.cos_a, self.sin_a),
            (-self.sin_a, self.cos_a),
            (other.cos_a, other.sin_a),
            (-other.sin_a, other.cos_a)
        ]

        for ax, ay in axes:
            reach = self._project_onto_axis(ax, ay)[1] + other._project_onto_axis(ax, ay)[1]
            # distance between the centres along this axis
            distance = tx * ax + ty * ay
            if abs(distance) > reach:
                return False  # separating axis found
        return True
    
def random_OBB(h: float) -> OBB:
    """Create an OBB of height `h` at a random position and rotation.

    The width follows from OBJECTS_ASPECT_RATIO. The centre is drawn so that it stays at
    least half the card diagonal away from every image edge, and the angle is uniform
    in [0, 2*pi).
    """
    width = h / OBJECTS_ASPECT_RATIO
    margin = math.sqrt(width**2 + h**2) / 2  # half of the card's diagonal
    lowest, highest = margin, IMAGE_DIMENSION - margin

    center_x = random.uniform(lowest, highest)
    center_y = random.uniform(lowest, highest)
    rotation = random.uniform(0, 2 * math.pi)
    return OBB(center_x, center_y, width, h, rotation)

### Images creation


In [None]:
def _place_card_boxes(card_amount: int, card_height: float) -> list[OBB]:
    """Choose positions for up to `card_amount` cards of height `card_height`.

    With ALLOW_OVERLAP every random box is kept. Otherwise a box that overlaps an
    already placed one is re-drawn. All cards share a budget of 10 re-draws, and a card
    whose candidate still collides once the budget is spent is dropped, so fewer than
    `card_amount` boxes may be returned.
    """
    boxes: list[OBB] = []
    redraws = 0
    for _ in range(card_amount):
        candidate = random_OBB(card_height)
        if ALLOW_OVERLAP:
            boxes.append(candidate)
            continue

        colliding = any(candidate.overlaps_with(obb) for obb in boxes)
        while colliding and redraws < 10:
            candidate = random_OBB(card_height)
            redraws += 1
            colliding = any(candidate.overlaps_with(obb) for obb in boxes)

        # keep only a candidate that was actually checked to be collision-free
        if not colliding:
            boxes.append(candidate)
    return boxes


def _add_random_obstructions(card_image: Image.Image) -> Image.Image:
    """Composite 0-10 random occluding shapes onto an RGBA card image.

    Each shape is a gray rectangle (one of 5 shades, alpha 150-255) with random
    rectangular holes cut out of it. It is rotated by a random angle and pasted at a
    random position on the card.
    """
    draw = Image.new('RGBA', card_image.size, (0, 0, 0, 0))
    obstruction_amount = random.randint(0, 10)

    # generate random shapes of varying size, shape, brightness, and alpha
    for _ in range(obstruction_amount):
        obs_w = random.randint(int(card_image.width * 0.1), int(card_image.width * 0.6))
        obs_h = random.randint(int(card_image.height * 0.1), int(card_image.height * 0.6))
        obs_x = random.randint(0, card_image.width - obs_w)
        obs_y = random.randint(0, card_image.height - obs_h)
        shade = random.randint(0, 4) * 255 // 4
        alpha = random.randint(150, 255)
        obstruction = Image.new('RGBA', (obs_w, obs_h), (shade, shade, shade, alpha))
        # Remove random shapes from the obstruction to make it less blocky. putalpha()
        # replaces the whole alpha channel, so the mask carries the chosen alpha;
        # otherwise every obstruction would be fully opaque.
        mask = Image.new('L', (obs_w, obs_h), alpha)
        for _ in range(random.randint(5, 15)):
            shape_w = random.randint(int(obs_w * 0.1), int(obs_w * 0.5))
            shape_h = random.randint(int(obs_h * 0.1), int(obs_h * 0.5))
            shape_x = random.randint(0, obs_w - shape_w)
            shape_y = random.randint(0, obs_h - shape_h)
            mask.paste(Image.new('L', (shape_w, shape_h), 0), (shape_x, shape_y))
        obstruction.putalpha(mask)
        obstruction = obstruction.rotate(random.uniform(0, 360), expand=True)
        draw.paste(obstruction, (obs_x, obs_y), obstruction)
    return Image.alpha_composite(card_image, draw)


def create_test_item():
    """Render one synthetic sample: random cards pasted onto a random background.

    Returns:
        (bg_image, boxes): an RGB PIL image of IMAGE_DIMENSION x IMAGE_DIMENSION pixels,
        and the OBBs of the cards actually pasted onto it. A card whose file cannot be
        opened is skipped and gets no label.
    """
    # smaller cards if more cards per image
    card_amount = random.randint(1, MAX_CARDS_PER_IMAGE)
    max_card_size = IMAGE_DIMENSION / math.ceil(math.sqrt(card_amount))
    # random_OBB keeps the card centre half a diagonal away from every edge, so the
    # diagonal h * sqrt(1 + (1 / OBJECTS_ASPECT_RATIO)**2) must not exceed the image size
    max_card_size = min(max_card_size, IMAGE_DIMENSION / math.sqrt(1 + (1 / OBJECTS_ASPECT_RATIO) ** 2))
    card_height = random.uniform(50, max_card_size)

    candidate_boxes = _place_card_boxes(card_amount, card_height)

    # load background
    with Image.open(load_random_background()) as src:
        bg_image = src.convert("RGB")
    if BLACK_AND_WHITE:
        bg_image = bg_image.convert("L").convert("RGB")
    bg_image = bg_image.resize((IMAGE_DIMENSION, IMAGE_DIMENSION))

    # average brightness of the background
    bg_brightness = ImageStat.Stat(bg_image.convert("L")).mean[0]

    # per-channel gains that pull a card towards the background's white balance
    r_avg, g_avg, b_avg = ImageStat.Stat(bg_image).mean
    gray_avg = (r_avg + g_avg + b_avg) / 3
    r_ratio = gray_avg / (r_avg + 1e-5)
    g_ratio = gray_avg / (g_avg + 1e-5)
    b_ratio = gray_avg / (b_avg + 1e-5)

    placed_boxes: list[OBB] = []
    for obb in candidate_boxes:
        # load card; unreadable or non-image files are skipped (and left unlabeled)
        try:
            with Image.open(load_random_card()) as src:
                card_image = src.convert("RGBA")
        except OSError:
            continue
        if BLACK_AND_WHITE:
            card_image = card_image.convert("L").convert("RGBA")
        card_image = card_image.resize((int(obb.w), int(obb.h)))

        # adjust brightness to match background
        card_brightness = ImageStat.Stat(card_image.convert("L")).mean[0]
        brightness_ratio = bg_brightness / (card_brightness + 1e-5)
        card_image = ImageEnhance.Brightness(card_image).enhance(brightness_ratio)

        # adjust white balance to match background
        r, g, b, a = card_image.split()
        r = r.point(lambda i: i * r_ratio)
        g = g.point(lambda i: i * g_ratio)
        b = b.point(lambda i: i * b_ratio)
        card_image = Image.merge('RGBA', (r, g, b, a))

        # apply obstruction if enabled
        if OBSTRUCTIONS and random.random() < 0.66:
            card_image = _add_random_obstructions(card_image)

        # Rotate the card. OBB corners are rotated by +angle in image coordinates (y down),
        # which is clockwise on screen, while PIL's rotate() turns counter-clockwise for
        # positive angles. Rotate by -angle so the pixels match the label.
        card_image = card_image.rotate(-math.degrees(obb.angle), expand=True)

        # expand=True keeps the card centred in the enlarged canvas
        paste_x = int(obb.cx - card_image.width / 2)
        paste_y = int(obb.cy - card_image.height / 2)

        # paste card onto background
        bg_image.paste(card_image, (paste_x, paste_y), card_image)
        placed_boxes.append(obb)

    return bg_image, placed_boxes

def save_image_and_labels(image: "Image.Image", obbs: "list[OBB]", index: int, train: bool = True,
                          dataset_path: "str | None" = None) -> bool:
    """Save `image` and its YOLO OBB label file as sample number `index`.

    Files are written to <dataset_path>/images/<split>/image_NNNNN.jpg and
    <dataset_path>/labels/<split>/image_NNNNN.txt, where split is "train" or "val".
    If any normalized corner coordinate falls outside [0, 1], the sample is skipped and
    neither file is written.

    Args:
        image: object with a ``save(path)`` method, normally a PIL image.
        obbs: boxes to write, all with class index 0.
        index: sample number used in the file names.
        train: write to the train split if True, otherwise to the val split.
        dataset_path: dataset root; defaults to DATASET_PATH.

    Returns:
        True if the sample was written, False if it was skipped.
    """
    root = DATASET_PATH if dataset_path is None else dataset_path
    split = "train" if train else "val"

    # create directories if they don't exist
    images_path = root + "/images/" + split
    labels_path = root + "/labels/" + split
    os.makedirs(images_path, exist_ok=True)
    os.makedirs(labels_path, exist_ok=True)

    # format every label once; the same strings are validated and then written
    label_lines = [obb.get_YOLO_OBB_format(0) for obb in obbs]
    for line in label_lines:
        coords = (float(num) for num in line.split(" ")[1:])
        if any(c < 0 or c > 1 for c in coords):
            print("out of bounds card")
            return False

    image.save(os.path.join(images_path, f"image_{index:05d}.jpg"))

    with open(os.path.join(labels_path, f"image_{index:05d}.txt"), 'w') as f:
        f.writelines(line + "\n" for line in label_lines)
    return True

def _generate_split(amount: int, train: bool) -> None:
    """Render `amount` samples with create_test_item() and save them to the train or val split."""
    split_name = "training" if train else "validation"
    for i in tqdm.tqdm(range(amount), desc=f"Generating {split_name} dataset"):
        img, obbs = create_test_item()
        save_image_and_labels(img, obbs, i, train=train)


_generate_split(AMOUNT_TRAIN, train=True)
_generate_split(AMOUNT_VAL, train=False)

# Write an absolute dataset root. Ultralytics may resolve a relative `path` against its own
# datasets directory setting rather than this notebook's working directory.
yaml_content = f"""
path: {os.path.abspath(DATASET_PATH)}
train: images/train
val: images/val

names:
    0: card
"""

with open(os.path.join(DATASET_PATH, "card_yolo_dataset.yaml"), 'w') as f:
    f.write(yaml_content)

Generating training dataset: 100%|██████████| 1000/1000 [01:14<00:00, 13.39it/s]
Generating validation dataset: 100%|██████████| 25/25 [00:01<00:00, 12.97it/s]


# Training

### model download

In [None]:
# Pretrained OBB checkpoint that is fine-tuned on the synthetic card dataset below.
pretrained_weights = "yolo26n-obb.pt"
model = YOLO(pretrained_weights)

### model training

In [None]:
# fine-tune on the dataset described by the generated YAML file
results = model.train(data=os.path.join(DATASET_PATH, "card_yolo_dataset.yaml"), epochs=EPOCHS, imgsz=IMAGE_DIMENSION)

# save the trained model inside DATASET_PATH, which is where the testing cell loads it from
model.save(os.path.join(DATASET_PATH, "card_yolo.pt"))

### model testing

In [None]:
model = YOLO(os.path.join(DATASET_PATH, "card_yolo.pt"))

metrics = model.val(data=os.path.join(DATASET_PATH, "card_yolo_dataset.yaml"))

# Only a cell's last expression is displayed, so gather all metrics into one value.
{
    "mAP50-95": metrics.box.map,
    "mAP50": metrics.box.map50,
    "mAP75": metrics.box.map75,
    "mAP50-95 per class": metrics.box.maps,
}