# Imports

In [None]:
import sys
sys.path.append('../')
import tensorflow as tf
from tensorflow import keras
from helper import set_model_config, plot_loss
from helper import visualize_object_predictions, visualize_object_detection_samples
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.callbacks import ModelCheckpoint
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import keras_cv
from keras_cv import bounding_box

# Load config

In [None]:
# Fetch the settings registered for the 'pascal_yolo' model; later cells read
# config['batch_size'], config['learning_rate'] and config['global_clipnorm'].
config  = set_model_config(model_name='pascal_yolo')
# Bare expression so the notebook echoes the loaded configuration as cell output
config

# Data preparation pipeline

In [None]:
# Load the Pascal-Voc dataset

# Build the {class index: class name} lookup from the dataset metadata
def get_class_mapping(dataset_info):
    """Return a dict mapping each integer class id to its label name.

    The names are read from ``dataset_info.features['objects']['label'].names``
    and numbered from 0 in the order they appear there.
    """
    label_names = dataset_info.features['objects']['label'].names
    return dict(enumerate(label_names))


# Unpackage the raw TFDS format into the Keras-CV format
def unpackage_raw_tfds_inputs(inputs, bounding_box_format):
    """Convert one raw TFDS example into the KerasCV object-detection dict.

    Args:
        inputs: Example dict; this function reads ``inputs["image"]``,
            ``inputs["objects"]["bbox"]`` and ``inputs["objects"]["label"]``.
        bounding_box_format: Target KerasCV bounding-box format (e.g.
            ``"xywh"``). The source boxes are converted from ``"rel_yxyx"``.

    Returns:
        ``{"images": ..., "bounding_boxes": {"classes": ..., "boxes": ...}}``
        where the image, the classes and the boxes are all cast to float32.
    """
    image = inputs["image"]
    boxes = keras_cv.bounding_box.convert_format(
        inputs["objects"]["bbox"],
        images=image,
        source="rel_yxyx",
        target=bounding_box_format,
    )

    bounding_boxes = {
        "classes": tf.cast(inputs["objects"]["label"], dtype=tf.float32),
        # Keep the boxes in floating point: an int32 cast truncates the
        # converted coordinates and collapses every box to 0 whenever a
        # relative ("rel_*") target format is requested.
        "boxes": tf.cast(boxes, dtype=tf.float32),
    }
    return {"images": tf.cast(image, tf.float32), "bounding_boxes": bounding_boxes}

# Turn a batched KerasCV dict into an (images, bounding boxes) tuple
def unpack_batch_dicts(inputs):
    """Split a batch dict into ``(images, bounding_boxes)``.

    The class ids are cast to int32 (value-wise, so ragged batches are
    supported) and the resulting boxes dict is passed through
    ``bounding_box.to_dense`` with ``max_boxes=100``.
    """
    ground_truth = inputs['bounding_boxes']
    int_classes = tf.ragged.map_flat_values(
        tf.cast, ground_truth['classes'], dtype=tf.int32
    )
    dense_ground_truth = bounding_box.to_dense(
        {"classes": int_classes, "boxes": ground_truth['boxes']},
        max_boxes=100,
    )
    return inputs["images"], dense_ground_truth

# Keras-CV compatible loader: converts, (optionally) shuffles and batches a split
def load_pascal_voc(split, dataset, bounding_box_format):
    """Load one split of a TFDS detection dataset as ragged batches.

    Args:
        split: Name of the split to load. ``'test'`` is batched in groups of 8
            without shuffling; any other split is shuffled and batched with
            ``config['batch_size']``.
        dataset: TFDS dataset name passed to ``tfds.load``.
        bounding_box_format: Target box format for the converted examples.

    Returns:
        ``(dataset, dataset_info)`` as produced by ``tfds.load`` followed by
        the conversion and batching steps below.
    """
    raw_ds, info = tfds.load(dataset, split=split, with_info=True, shuffle_files=True)

    def to_keras_cv(example):
        # Re-key the raw example into the Keras-CV images/bounding_boxes layout
        return unpackage_raw_tfds_inputs(example, bounding_box_format=bounding_box_format)

    converted = raw_ds.map(to_keras_cv, num_parallel_calls=tf.data.AUTOTUNE)

    # Ragged batches hold elements of different sizes  # TODO: Do we need this?
    if split == 'test':
        return converted.ragged_batch(8, drop_remainder=True), info

    shuffled = converted.shuffle(
        config['batch_size'] * 4, reshuffle_each_iteration=True
    )
    return shuffled.ragged_batch(config['batch_size'], drop_remainder=True), info

# Training-time augmentation pipeline built from Keras-CV layers that transform
# the bounding boxes together with the images
augmenter = keras.Sequential()
augmenter.add(
    keras_cv.layers.RandomFlip(mode="horizontal", bounding_box_format="xywh")
)
augmenter.add(
    keras_cv.layers.RandomShear(x_factor=0.2, y_factor=0.2, bounding_box_format="xywh")
)
augmenter.add(
    keras_cv.layers.JitteredResize(
        target_size=(480, 480),
        scale_factor=(0.75, 1.3),
        bounding_box_format="xywh",
    )
)

# Inference inputs pre-processing for our test and validation sets.
# Evaluation has to be deterministic, so use a plain resize to 480x480 (aspect
# ratio preserved through padding) rather than the random scale jitter that is
# applied to the training set.
inf_preprocess = keras_cv.layers.Resizing(
    480,
    480,
    pad_to_aspect_ratio=True,
    bounding_box_format="xywh",
)

# Build the train / validation / test pipelines for Pascal VOC 2007 in "xywh"
# format; the dataset metadata is kept from the training split only
ds_train, ds_info = load_pascal_voc("train", "voc/2007", "xywh")
ds_val, _ = load_pascal_voc("validation", "voc/2007", "xywh")
ds_test, _ = load_pascal_voc("test", "voc/2007", "xywh")

# Dataset information and samples visualization

In [None]:
# Summarize the dataset metadata returned alongside the training split
print("----Pascal-Voc dataset information-----:")
# Example counts per split, as recorded in the dataset metadata
print(f"Number of training examples: {ds_info.splits['train'].num_examples}")
print(f"Number of validation examples: {ds_info.splits['validation'].num_examples}")
print(f"Number of test examples: {ds_info.splits['test'].num_examples}")
print(f"Dataset splits available: {list(ds_info.splits.keys())}")
# Object classes come from the 'label' feature nested under 'objects'
print("Number of Classes:", len(ds_info.features["objects"]["label"].names))
print(f"Class names: {ds_info.features['objects']['label'].names}")

In [None]:
# Show a 2x4 grid of training samples with their ground-truth boxes
with plt.style.context("dark_background"):
    visualize_object_detection_samples(
        ds_train,
        value_range=(0, 255),
        rows=2,
        cols=4,
        bounding_box_format="xywh",
        class_mapping=get_class_mapping(ds_info),
    )

# Apply augmentations and pre-processing

In [None]:
# Training set: augment, convert to (images, boxes) tuples, then prefetch
ds_train = (
    ds_train
    .map(augmenter, num_parallel_calls=tf.data.AUTOTUNE)
    .map(unpack_batch_dicts, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)

# Validation set: inference pre-processing, tuple conversion, prefetch
ds_val = (
    ds_val
    .map(inf_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .map(unpack_batch_dicts, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)

# Test set: inference pre-processing only (batches stay in dict form), shuffled
# with a buffer sized by the number of test examples
ds_test = (
    ds_test
    .map(inf_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(ds_info.splits['test'].num_examples)
    .prefetch(tf.data.AUTOTUNE)
)

# Load a backbone architecture from Keras-CV and create our own YoloV8 model

In [None]:
# Load a YoloV8 backbone pre-trained on the COCO dataset.
# NOTE: the "_coco" suffix selects the preset that ships with weights; the
# plain "yolo_v8_xs_backbone" preset only describes the architecture.
backbone = keras_cv.models.YOLOV8Backbone.from_preset("yolo_v8_xs_backbone_coco")
backbone.summary()

In [None]:
# Assemble our custom YOLO detector on top of the smallest available backbone

# Non-Maximum Suppression module used to decode the raw predictions
prediction_decoder = keras_cv.layers.NonMaxSuppression(
    bounding_box_format="xywh",
    from_logits=True,
    # NOTE: Decrease the required threshold to make predictions get pruned out
    iou_threshold=0.3,
    # NOTE: Tune confidence threshold for predictions to pass NMS
    confidence_threshold=0.7,
)

# One output class per entry of the dataset's class mapping
model = keras_cv.models.YOLOV8Detector(
    backbone=backbone,
    num_classes=len(get_class_mapping(ds_info)),
    bounding_box_format="xywh",
    fpn_depth=1,
    prediction_decoder=prediction_decoder,
)
model.summary()

# Set training callbacks and train the model

In [None]:
# Stop training after 5 epochs without improvement of the validation loss
early_stopping = EarlyStopping(monitor='val_loss', patience=5)

# Always save the best model: keep only the weights with the lowest val_loss
saving_cb = ModelCheckpoint(
    filepath='../trained_models/pascal_yolo_model/best_weights.h5',
    save_weights_only=True,
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    verbose=1,
)

# Use the PyCOCO metrics callback to track the mAP across different box sizes
# for all classes, evaluated on the first 20 validation batches
metrics_cb = keras_cv.callbacks.PyCOCOCallback(
    ds_val.take(20), bounding_box_format="xywh"
)

# Compile and train
model.compile(
    optimizer=tf.keras.optimizers.Adam(
        learning_rate=config['learning_rate'],
        global_clipnorm=config['global_clipnorm'],
    ),
    classification_loss='binary_crossentropy',
    box_loss="ciou",
)

# Fit on the (augmented) training split and monitor on the validation split.
# Fitting on validation batches would leave ds_train unused and make val_loss
# a measurement on the training data itself.
# NOTE: epochs is kept at 1 as in the original run; raise it for real training,
# since early stopping with patience=5 cannot trigger within a single epoch.
history = model.fit(
    ds_train,
    validation_data=ds_val,
    epochs=1,
    callbacks=[early_stopping, saving_cb, metrics_cb],
)

# Save the model
model.save("../trained_models/pascal_yolo_model/saved_model")

# Plot losses

In [None]:
# Draw the loss curves recorded in the training history
with plt.style.context("dark_background"):
    plot_loss(history, model_type="object_detection")

# Load a trained model and predict on the test set

In [None]:
# Load a trained model and visualize predictions.
# The path is the one the training cell passes to model.save(...), and the
# loader comes from the tensorflow.keras module already imported by this file
# so that saving and loading go through the same Keras implementation.
trained_model = keras.models.load_model(
    "../trained_models/pascal_yolo_model/saved_model"
)

with plt.style.context('dark_background'):
    visualize_object_predictions(
        trained_model,
        dataset=ds_test,
        bounding_box_format='xywh',
        class_mapping=get_class_mapping(ds_info),
    )