In [None]:
# Imports
import sys
sys.path.append('../')
import tensorflow as tf
from tensorflow import keras
from helper import fast_benchmark, set_model_config
from helper import plot_loss
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import keras_cv
from keras_cv import visualization
from keras_cv import bounding_box

In [None]:
# Load the experiment configuration (batch size, optimizer, learning rate, epochs, ...)
# used by the cells below, and display it.
config = set_model_config(model_name='pascal_yolo')
config

In [None]:
# Load the Pascal-Voc dataset

# Visualize the Keras-CV compatible dataset
def visualize_object_detection_samples(inputs, value_range, rows, cols, bounding_box_format, class_mapping):
    inputs = next(iter(inputs.take(1)))
    images, bounding_boxes = inputs["images"], inputs["bounding_boxes"]
    visualization.plot_bounding_box_gallery(
        images,
        value_range=value_range,
        rows=rows,
        cols=cols,
        y_true=bounding_boxes,
        scale=5,
        font_scale=0.7,
        bounding_box_format=bounding_box_format,
        class_mapping=class_mapping,
    )

# Get the class mapping dictionary
def get_class_mapping(dataset_info):
    """Return ``{class_id: class_name}`` for the dataset's object labels.

    A class id is the position of its name in
    ``dataset_info.features['objects']['label'].names``.
    """
    label_names = dataset_info.features['objects']['label'].names
    return dict(enumerate(label_names))


# Unpackage the raw tfds formats into Keras-CV format
def unpackage_raw_tfds_inputs(inputs, bounding_box_format):
    """Convert one raw TFDS VOC example into the Keras-CV detection dict.

    The raw boxes are read as ``rel_yxyx`` and converted to
    ``bounding_box_format``. The image is passed along so relative
    coordinates can be scaled to the image size.

    Returns:
        ``{"images": float32 image,
           "bounding_boxes": {"classes": float32 labels, "boxes": float32 boxes}}``
    """
    objects = inputs["objects"]
    raw_image = inputs["image"]
    converted_boxes = keras_cv.bounding_box.convert_format(
        objects["bbox"],
        images=raw_image,
        source="rel_yxyx",
        target=bounding_box_format,
    )
    return {
        "images": tf.cast(raw_image, tf.float32),
        "bounding_boxes": {
            "classes": tf.cast(objects["label"], dtype=tf.float32),
            "boxes": tf.cast(converted_boxes, dtype=tf.float32),
        },
    }

# Unpack batch from dataset to tuple format function
def unpack_batch_dicts(inputs):
    """Split a Keras-CV batch dict into an ``(images, bounding_boxes)`` tuple for ``model.fit``."""
    images = inputs["images"]
    bounding_boxes = inputs["bounding_boxes"]
    return images, bounding_boxes

# Custom dataloader, compatible with Keras-CV, applies shuffling and batching
def load_pascal_voc(split, dataset, bounding_box_format, batch_size=None):
    """Load a TFDS Pascal VOC split as ragged batches in Keras-CV format.

    Args:
        split: TFDS split name, e.g. ``"train"``, ``"validation"`` or ``"test"``.
        dataset: TFDS dataset name, e.g. ``"voc/2007"``.
        bounding_box_format: Target Keras-CV box format, e.g. ``"xywh"``.
        batch_size: Batch size. If None, falls back to the notebook-level
            ``config['batch_size']`` so existing calls keep working.

    Returns:
        ``(dataset, dataset_info)``. ``dataset`` yields ragged batches of
        ``{"images": ..., "bounding_boxes": ...}`` dicts. Incomplete final
        batches are dropped.
    """
    if batch_size is None:
        batch_size = config['batch_size']

    ds, ds_info = tfds.load(dataset, split=split, with_info=True, shuffle_files=True)

    # Convert the images/bboxes to the Keras-CV API format
    ds = ds.map(
        lambda x: unpackage_raw_tfds_inputs(x, bounding_box_format=bounding_box_format),
        num_parallel_calls=tf.data.AUTOTUNE,
    )

    # Shuffle examples for every split except test; the buffer holds four
    # batches' worth of examples.
    if split != 'test':
        ds = ds.shuffle(batch_size * 4, reshuffle_each_iteration=True)

    # Images have different sizes, so batch them as ragged tensors.
    ds = ds.ragged_batch(batch_size, drop_remainder=True)

    return ds, ds_info

# Define augmenter module using custom object detection friendly ops from Keras-CV.
# Every layer is given the box format so boxes are transformed along with images.
IMAGE_SIZE = (480, 480)  # (height, width) fed to the detector for every split

augmenter = keras.Sequential(
    layers=[
        keras_cv.layers.RandomFlip(mode="horizontal", bounding_box_format="xywh"),
        keras_cv.layers.RandomShear(
            x_factor=0.2, y_factor=0.2, bounding_box_format="xywh"
        ),
        keras_cv.layers.JitteredResize(
            target_size=IMAGE_SIZE, scale_factor=(0.75, 1.3), bounding_box_format="xywh"
        ),
    ]
)

# Inference inputs pre-processing for our validation and test sets.
# JitteredResize randomly rescales and crops, so using it here would make
# evaluation metrics change between runs and could crop objects out of frame.
# Resize deterministically to the training resolution instead, padding to
# preserve the aspect ratio.
inf_preprocess = keras_cv.layers.Resizing(
    IMAGE_SIZE[0],
    IMAGE_SIZE[1],
    pad_to_aspect_ratio=True,
    bounding_box_format="xywh",
)

# Number of *batches* held in the test-set shuffle buffer (see below).
TEST_SHUFFLE_BATCHES = 8

# Load the three different pre-processed splits of our dataset
ds_train, ds_info = load_pascal_voc(
    split="train", dataset="voc/2007", bounding_box_format="xywh"
)
ds_val, _ = load_pascal_voc(
    split="validation", dataset="voc/2007", bounding_box_format="xywh"
)
ds_test, _ = load_pascal_voc(
    split="test", dataset="voc/2007", bounding_box_format="xywh"
)

# Apply augmentations and set prefetch option on training set
ds_train = ds_train.map(augmenter, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.map(unpack_batch_dicts, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

# Pre-process validation set
ds_val = ds_val.map(inf_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
ds_val = ds_val.map(unpack_batch_dicts, num_parallel_calls=tf.data.AUTOTUNE)
ds_val = ds_val.prefetch(tf.data.AUTOTUNE)

# Pre-process the test set into the same (images, bounding_boxes) tuples as
# train/val, so downstream cells can index batches positionally.
ds_test = ds_test.map(inf_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.map(unpack_batch_dicts, num_parallel_calls=tf.data.AUTOTUNE)
# Elements are already batches here. A buffer sized to the number of examples
# would exceed the whole split and load the entire preprocessed test set into
# memory before the first batch is yielded. A few batches is enough to vary
# which images get previewed.
ds_test = ds_test.shuffle(TEST_SHUFFLE_BATCHES)
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)

In [None]:
# Display the TFDS DatasetInfo object returned by tfds.load for the training split load
ds_info

In [None]:
# Summarize split sizes and the label vocabulary of the loaded dataset
label_names = ds_info.features["objects"]["label"].names
print("----Pascal-Voc dataset information-----:")
for split_key, split_label in (("train", "training"), ("validation", "validation"), ("test", "test")):
    print(f"Number of {split_label} examples: {ds_info.splits[split_key].num_examples}")
print(f"Dataset splits available: {list(ds_info.splits.keys())}")
print("Number of Classes:", len(label_names))
print(f"Class names: {label_names}")

In [None]:
# Preview a 2x4 grid of augmented training images with their ground-truth boxes
with plt.style.context('dark_background'):
    visualize_object_detection_samples(
        ds_train,
        value_range=(0, 255),
        rows=2,
        cols=4,
        bounding_box_format="xywh",
        class_mapping=get_class_mapping(ds_info),
    )

In [None]:
# Load the extra-small YOLOv8 backbone with weights pre-trained on COCO.
# The preset without the "_coco" suffix ("yolo_v8_xs_backbone") only builds the
# architecture with randomly initialised weights, which is not a pre-trained
# starting point.
backbone = keras_cv.models.YOLOV8Backbone.from_preset("yolo_v8_xs_backbone_coco")
backbone.summary()

In [None]:
# Build a YOLOv8 detector on top of the extra-small backbone, with one output
# class per Pascal VOC label and a single-layer feature pyramid.
model = keras_cv.models.YOLOV8Detector(
    num_classes=len(get_class_mapping(ds_info)),
    bounding_box_format="xywh",
    backbone=backbone,
    fpn_depth=1,
)
model.summary()

In [None]:
# Compile and configure the model for training.
# Build the optimizer named in the config. Any other value fails loudly instead
# of reaching model.compile with `optimizer` undefined, or silently reusing an
# optimizer left over from an earlier run of this cell.
optimizer_name = config['optimizer'].lower()
if optimizer_name == 'adam':
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=config['learning_rate'],
        global_clipnorm=config['global_clipnorm'],
    )
else:
    raise ValueError(
        f"Unsupported optimizer in config: {config['optimizer']!r} (only 'adam' is supported)"
    )

# Compile and train
model.compile(
    optimizer=optimizer,
    classification_loss='binary_crossentropy',
    box_loss="ciou",
)
history = model.fit(ds_train, validation_data=ds_val, epochs=config['training_epochs'])

# Plot with dark background
with plt.style.context('dark_background'):
    plot_loss(history, model_type='object_detection')

In [None]:
# Grab one batch of test images to visualize predictions on
from keras.models import load_model

test_iterator = iter(ds_test.take(1))
test_batch = next(test_iterator)
# ds_test elements may be Keras-CV dicts or (images, bounding_boxes) tuples,
# depending on whether unpack_batch_dicts was applied. Indexing a dict with [0]
# raises KeyError, so pick the images out of either structure.
single_image = test_batch["images"] if isinstance(test_batch, dict) else test_batch[0]
# NOTE(review): despite its name, this holds a whole batch of images
# (ds_test is batched in load_pascal_voc), not a single image.

# Load a trained model and visualize predictions
# trained_model = load_model('computer_vision/trained_models/pascal_yolo_model')