
# Train a KerasCV-based YOLOv8 model on a specified dataset

This code is heavily based on the following tutorials:
  - [Object Detection with KerasCV](https://keras.io/guides/keras_cv/object_detection_keras_cv/#train-a-custom-object-detection-model)
  - [Efficient Object Detection with YOLOV8 and KerasCV](https://keras.io/examples/vision/yolov8/)

## Initialization

Imports:

In [None]:
from dataclasses import dataclass
import os
import sys
import csv
from pprint import pprint
import argparse

import tensorflow as tf
from tensorflow import keras
import keras_cv
from keras_cv import visualization
from keras_cv import bounding_box

from matplotlib import pyplot as plt

Visualization function:

In [None]:
def visualize_dataset(inputs, class_names, value_range=(0, 255), rows=2, cols=2, bounding_box_format=None):
    """
    Plots a gallery of images from the first batch of a dataset with their
    ground-truth bounding boxes drawn on top.

    Args:
        inputs: A batched dataset whose elements are dicts with "images" and
            "bounding_boxes" keys.
        class_names: Class-id -> label mapping passed to KerasCV as
            ``class_mapping``.
        value_range: Pixel value range of the images.
        rows, cols: Layout of the gallery grid.
        bounding_box_format: Format of the boxes in ``inputs``. When omitted,
            the module-level ``BOUNDING_BOX_FORMAT`` is used (looked up at call
            time).
    """
    # Resolve the default lazily. Used directly as a default argument value,
    # BOUNDING_BOX_FORMAT would be evaluated when this cell runs, which is
    # before the constants cell defines it (NameError).
    if bounding_box_format is None:
        bounding_box_format = BOUNDING_BOX_FORMAT

    inputs = next(iter(inputs.take(1)))
    images, bounding_boxes = inputs["images"], inputs["bounding_boxes"]
    visualization.plot_bounding_box_gallery(
        images,
        value_range=value_range,
        rows=rows,
        cols=cols,
        y_true=bounding_boxes,
        scale=5,
        font_scale=0.7,
        bounding_box_format=bounding_box_format,
        class_mapping=class_names,
    )
    plt.show()

Image loader:

In [None]:
def load_image(image_path):
    """
    Reads the file at ``image_path`` and decodes it as a 3-channel JPEG.

    Based on https://keras.io/examples/vision/yolov8/
    """
    raw_bytes = tf.io.read_file(image_path)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    return decoded

YOLO annotation file parser:

In [None]:
def parse_YOLO_annotations(annot_path: str):
    """
    Parses a YOLO annotation file

    Each non-blank line is expected to be
    ``<class_id> <center_x> <center_y> <width> <height> [...]``. Fields may be
    separated by any run of whitespace, and values after the fifth are ignored.

    Based on https://keras.io/examples/vision/yolov8/

    Args:
        annot_path: Path to the ``.txt`` annotation file.

    Returns:
        A ``(boxes, classes)`` pair. ``boxes`` is a list of ``(x, y, w, h)``
        tuples where ``(x, y)`` is the top-left corner, in the same units as
        the file. ``classes`` holds the integer class ids in the same order.

    Raises:
        ValueError: If a non-blank line has fewer than five fields or contains
            a non-numeric value.
    """

    boxes: list[tuple] = []
    classes: list[int] = []

    with open(annot_path, encoding="utf-8") as file:
        for line_no, line in enumerate(file, start=1):
            # str.split() tolerates repeated spaces/tabs and trailing
            # whitespace, unlike csv.reader(delimiter=" ").
            fields = line.split()
            if not fields:
                # Blank line (e.g. a trailing empty line): nothing to parse.
                continue
            if len(fields) < 5:
                raise ValueError(
                    f"{annot_path}:{line_no}: expected at least 5 fields, got {len(fields)}"
                )

            # Convert bounding box coords from center to top-left
            box_center_x, box_center_y, box_w, box_h = (float(val) for val in fields[1:5])
            box_x = box_center_x - box_w / 2
            box_y = box_center_y - box_h / 2
            boxes.append((box_x, box_y, box_w, box_h))

            classes.append(int(fields[0]))

    return boxes, classes

Evaluation callback (computes COCO metrics after each epoch and saves the model whenever mAP improves):

In [None]:
class EvaluateCOCOMetricsCallback(keras.callbacks.Callback):
    """
    Computes KerasCV COCO box metrics on ``data`` at the end of every epoch,
    merges them into the epoch logs, and saves the model whenever the mean
    average precision (the ``"MaP"`` entry) improves.

    Based on https://keras.io/examples/vision/yolov8/
    """

    def __init__(self, data, save_path, bounding_box_format=None):
        """
        Args:
            data: Iterable of ``(images, y_true)`` batches to evaluate on.
            save_path: Destination passed to ``self.model.save`` for the best
                model so far.
            bounding_box_format: Box format of targets and predictions. When
                omitted, the module-level ``BOUNDING_BOX_FORMAT`` is used.
        """
        super().__init__()
        self.data = data
        if bounding_box_format is None:
            bounding_box_format = BOUNDING_BOX_FORMAT
        self.metrics = keras_cv.metrics.BoxCOCOMetrics(
            bounding_box_format=bounding_box_format,
            # Very large frequency so the expensive COCO evaluation presumably
            # only runs when result(force=True) is called below (once per epoch).
            evaluate_freq=1e9,
        )

        self.save_path = save_path
        self.best_map = -1.0

    def on_epoch_end(self, epoch, logs=None):
        # Keras may call this with logs=None; use a fresh dict so update() works.
        if logs is None:
            logs = {}

        self.metrics.reset_state()
        for batch in self.data:
            images, y_true = batch[0], batch[1]
            y_pred = self.model.predict(images, verbose=0)
            self.metrics.update_state(y_true, y_pred)

        metrics = self.metrics.result(force=True)
        logs.update(metrics)

        # Compare and store as a plain Python float rather than a tensor.
        current_map = float(metrics["MaP"])
        if current_map > self.best_map:
            self.best_map = current_map
            self.model.save(self.save_path)  # Save the model when mAP improves

        return logs

## Define constants

In [None]:
# Dataset / model identity.
MODEL_NAME = "birdsyolo-v0.1"
NUM_CLASSES = 1

# Input locations, all derived from MODEL_NAME.
IMAGES_PATH = f"images/{MODEL_NAME}/images"
ANNOTATIONS_PATH = f"images/{MODEL_NAME}/annotations"
# NOTE(review): unlike the two paths above this lives under models/, and it is
# opened as a single text file of class names (one per line) — confirm the
# location is intended.
CLASS_NAMES_PATH = f"models/{MODEL_NAME}/annotations"

# Box formats: YOLO labels are parsed into relative top-left (x, y, w, h)
# boxes, which are converted to BOUNDING_BOX_FORMAT for training.
DATASET_BOUNDING_BOX_FORMAT = "rel_xywh"
BOUNDING_BOX_FORMAT = "xywh"

# Where the trained model is written.
SAVE_PATH = f"models/{MODEL_NAME}/model.keras"

# Data pipeline settings.
SPLIT_RATIO = 0.2  # fraction of examples held out for validation
BATCH_SIZE = 4
TARGET_SIZE = (640, 640)  # output size of the resize / jittered-resize layers
TAKE = 20  # batches per split to use; set to -1 to train on the full dataset

# Model and optimisation settings.
MODEL_BACKBONE = "resnet50_imagenet"
EPOCHS = 20
LEARNING_RATE = 0.001
GLOBAL_CLIPNORM = 10.0

## Prepare dataset

Load annotations and image paths:

In [None]:
image_paths = []
boxes = []
classes = []
# os.listdir() returns entries in arbitrary, filesystem-dependent order. The
# take/skip train/validation split later depends on this order, so sort to
# make the split reproducible.
for annot_file in sorted(os.listdir(ANNOTATIONS_PATH)):
    if not annot_file.endswith(".txt"):
        continue

    # Look for the JPEG sharing this annotation's basename (load_image only
    # decodes JPEG). Upper-case extensions matter on case-sensitive filesystems.
    basename = os.path.splitext(annot_file)[0]
    for ext in (".jpg", ".jpeg", ".JPG", ".JPEG"):
        img_path = os.path.join(IMAGES_PATH, basename + ext)
        if os.path.exists(img_path):
            break
    else:
        # No matching image: skip this annotation file.
        continue

    annot_boxes, annot_classes = parse_YOLO_annotations(os.path.join(ANNOTATIONS_PATH, annot_file))

    image_paths.append(img_path)
    boxes.append(annot_boxes)
    classes.append(annot_classes)

n_images = len(image_paths)

Load class names:

In [None]:
with open(CLASS_NAMES_PATH, "r", encoding="utf-8") as class_names_f:
    # readlines() keeps "\n" (and, with newline="", "\r") on every entry,
    # which ends up in the plotted labels, so strip each line.
    class_names = [line.strip() for line in class_names_f]

# Drop only *trailing* blank lines (e.g. a final newline) so every remaining
# entry keeps the position it has in the file.
while class_names and not class_names[-1]:
    class_names.pop()

Convert the data to ragged tensors (each image can have a different number of boxes) and build a dataset from them:

In [None]:
# Ragged tensors let each image carry a different number of boxes/classes.
image_paths, boxes, classes = [
    tf.ragged.constant(values) for values in (image_paths, boxes, classes)
]
data = tf.data.Dataset.from_tensor_slices((image_paths, classes, boxes))

Split the data into validation and training sets:

In [None]:
num_val = int(n_images * SPLIT_RATIO)

# The first num_val examples are held out for validation; the rest train.
val_data, train_data = data.take(num_val), data.skip(num_val)

Map the datasets through the image-loading and box-format conversion function:

In [None]:
def load_dataset(image_path, classes, boxes):
    """
    Builds one training example: decodes the image and converts its boxes
    from DATASET_BOUNDING_BOX_FORMAT to BOUNDING_BOX_FORMAT.

    Returns a dict with "images" (float32) and "bounding_boxes", itself a dict
    with float32 "classes" and the converted "boxes".
    """
    image = load_image(image_path)
    # The relative source format needs the image to recover pixel coordinates.
    dense_boxes = boxes.to_tensor()
    converted_boxes = keras_cv.bounding_box.convert_format(
        dense_boxes,
        source=DATASET_BOUNDING_BOX_FORMAT,
        target=BOUNDING_BOX_FORMAT,
        images=image,
    )
    return {
        "images": tf.cast(image, tf.float32),
        "bounding_boxes": {
            "classes": tf.cast(classes, dtype=tf.float32),
            "boxes": converted_boxes,
        },
    }

train_ds, val_ds = [
    split.map(load_dataset, num_parallel_calls=tf.data.AUTOTUNE)
    for split in (train_data, val_data)
]

Shuffle the training data, batch it with `ragged_batch`, and visualize:

In [None]:
# Small shuffle buffer (4 batches' worth) of decoded images, then ragged
# batching because images have differing numbers of boxes.
train_ds = (
    train_ds
    .shuffle(BATCH_SIZE * 4)
    .ragged_batch(BATCH_SIZE, drop_remainder=True)
)
visualize_dataset(
    train_ds,
    class_names=class_names,
    value_range=(0, 255),
    rows=2,
    cols=2,
)

Batch the validation data with `ragged_batch` and visualize:

In [None]:
# Keep the final partial batch. Dropping it would silently exclude up to
# BATCH_SIZE - 1 validation images from evaluation, and would leave the
# validation set empty when it has fewer than BATCH_SIZE images.
val_ds = val_ds.ragged_batch(BATCH_SIZE, drop_remainder=False)
visualize_dataset(val_ds, class_names=class_names)

Apply image augmentation to train data and visualize:

In [None]:
# Train-time augmentation: horizontal and vertical flips, then a jittered
# resize to TARGET_SIZE. RandomShear (x_factor=0.2, y_factor=0.2) is
# intentionally not used because it corrupts bounding box locations.
augmentation = keras.Sequential(
    layers=[
        *(
            keras_cv.layers.RandomFlip(mode=flip_mode, bounding_box_format=BOUNDING_BOX_FORMAT)
            for flip_mode in ("horizontal", "vertical")
        ),
        keras_cv.layers.JitteredResize(
            target_size=TARGET_SIZE,
            scale_factor=(0.75, 1.3),
            bounding_box_format=BOUNDING_BOX_FORMAT,
        ),
    ],
    name="augmentation",
)

train_ds = train_ds.map(augmentation, num_parallel_calls=tf.data.AUTOTUNE)
visualize_dataset(train_ds, class_names=class_names)

Resize the validation data and visualize:

In [None]:
# Deterministic resize to TARGET_SIZE for validation (padding keeps the
# aspect ratio).
resize = keras_cv.layers.Resizing(
    *TARGET_SIZE, bounding_box_format=BOUNDING_BOX_FORMAT, pad_to_aspect_ratio=True
)

val_ds = val_ds.map(resize, num_parallel_calls=tf.data.AUTOTUNE)
# Plot the resized validation data to check that boxes still line up.
visualize_dataset(val_ds, class_names=class_names)

Map the training and validation datasets to the `(images, bounding_boxes)` tuple format used for training:

In [None]:
def input_tuple(inputs, max_boxes=32):
    """
    Converts a batch dict into the ``(images, bounding_boxes)`` tuple used
    for training, densifying the ragged bounding boxes.

    Args:
        inputs: Dict with "images" and ragged "bounding_boxes" entries.
        max_boxes: Number of box slots per image in the dense output.
            NOTE(review): images with more boxes than this presumably lose
            the extras in ``bounding_box.to_dense``. Raise it for crowded
            scenes (e.g. flocks).

    Returns:
        ``(images, dense_bounding_boxes)``.
    """
    return inputs["images"], bounding_box.to_dense(
        inputs["bounding_boxes"], max_boxes=max_boxes
    )

train_ds, val_ds = [
    ds.map(input_tuple, num_parallel_calls=tf.data.AUTOTUNE)
    for ds in (train_ds, val_ds)
]

Prefetch:

In [None]:
# Let tf.data prepare upcoming batches while the current one is consumed.
train_ds, val_ds = [ds.prefetch(tf.data.AUTOTUNE) for ds in (train_ds, val_ds)]

Take a subset of batches (`TAKE = -1` keeps the full dataset):

In [None]:
# Limit each split to TAKE batches (-1 means no limit).
train_ds, val_ds = [ds.take(TAKE) for ds in (train_ds, val_ds)]

## Create model

In [None]:
# YOLOv8 detector built from the MODEL_BACKBONE preset, predicting
# NUM_CLASSES classes with boxes in BOUNDING_BOX_FORMAT.
model = keras_cv.models.YOLOV8Detector.from_preset(
    MODEL_BACKBONE,
    num_classes=NUM_CLASSES,
    bounding_box_format=BOUNDING_BOX_FORMAT,
)

## Train model

Create optimizer:

In [None]:
_sgd_kwargs = dict(
    learning_rate=LEARNING_RATE, momentum=0.9, global_clipnorm=GLOBAL_CLIPNORM
)
try:
    # Keep the legacy SGD where it is available (Keras 2 / TF <= 2.15).
    optimizer = keras.optimizers.legacy.SGD(**_sgd_kwargs)
except (AttributeError, ImportError):
    # Keras 3 no longer provides usable legacy optimizers (the namespace is
    # missing or its classes raise on construction); use the standard SGD.
    optimizer = keras.optimizers.SGD(**_sgd_kwargs)

Compile the model for training:

In [None]:
# Binary cross-entropy for the class scores, CIoU for box regression.
model.compile(
    optimizer=optimizer,
    box_loss="ciou",
    classification_loss="binary_crossentropy",
)

Train the model:

In [None]:
# The best model (by validation mAP) is written to the configured SAVE_PATH.
# Create its directory first so model.save() does not fail on a fresh checkout.
os.makedirs(os.path.dirname(SAVE_PATH) or ".", exist_ok=True)

hist = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=[EvaluateCOCOMetricsCallback(val_ds, SAVE_PATH)],
)

Plot the history: