diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index 928becf3c0..fe7a5ea3a1 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -72,3 +72,6 @@ from keras_hub.src.models.whisper.whisper_audio_converter import ( WhisperAudioConverter, ) +from keras_hub.src.models.yolo_v8.yolo_v8_image_converter import ( + YOLOV8ImageConverter, +) diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 248232312d..6d2b9c12ba 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -372,4 +372,11 @@ XLMRobertaTokenizer, ) from keras_hub.src.models.xlnet.xlnet_backbone import XLNetBackbone +from keras_hub.src.models.yolo_v8.yolo_v8_backbone import YOLOV8Backbone +from keras_hub.src.models.yolo_v8.yolo_v8_detector import ( + YOLOV8ImageObjectDetector, +) +from keras_hub.src.models.yolo_v8.yolo_v8_object_detector_preprocessor import ( + YOLOV8ImageObjectDetectorPreprocessor, +) from keras_hub.src.tokenizers.tokenizer import Tokenizer diff --git a/keras_hub/src/models/yolo_v8/__init__.py b/keras_hub/src/models/yolo_v8/__init__.py new file mode 100644 index 0000000000..4ca99f14a3 --- /dev/null +++ b/keras_hub/src/models/yolo_v8/__init__.py @@ -0,0 +1,10 @@ +from keras_hub.src.models.yolo_v8.yolo_v8_backbone import YOLOV8Backbone +from keras_hub.src.models.yolo_v8.yolo_v8_detector import ( + YOLOV8ImageObjectDetector, +) +from keras_hub.src.models.yolo_v8.yolo_v8_presets import backbone_presets +from keras_hub.src.models.yolo_v8.yolo_v8_presets import detector_presets +from keras_hub.src.utils.preset_utils import register_presets + +register_presets(backbone_presets, YOLOV8Backbone) +register_presets(detector_presets, YOLOV8ImageObjectDetector) diff --git a/keras_hub/src/models/yolo_v8/ciou_loss.py b/keras_hub/src/models/yolo_v8/ciou_loss.py new file mode 100644 index 0000000000..e6d48f5b36 --- /dev/null +++ b/keras_hub/src/models/yolo_v8/ciou_loss.py @@ -0,0 +1,109 @@ +import keras +from keras import ops +from keras.utils.bounding_boxes import compute_ciou + + +class CIoULoss(keras.losses.Loss): + """Implements the Complete IoU (CIoU) Loss + + CIoU loss is an extension of GIoU loss, which further improves the IoU + optimization for object detection. CIoU loss not only penalizes the + bounding box coordinates but also considers the aspect ratio and center + distance of the boxes. The length of the last dimension should be 4 to + represent the bounding boxes. + + Args: + bounding_box_format: a case-insensitive string (for example, "xyxy"). + Each bounding box is defined by these 4 values. For detailed + information on the supported formats, see the [Keras bounding box + documentation](https://github.com/keras-team/keras/blob/master/ + keras/src/layers/preprocessing/image_preprocessing/ + bounding_boxes/formats.py). + epsilon: (optional) float, a small value added to avoid division by + zero and stabilize calculations. Defaults 1e-07. 
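+        image_shape: (optional) tuple, the shape of the input images,
+            forwarded to the underlying IoU computation. Defaults to `None`.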
+ + References: + - [CIoU paper](https://arxiv.org/pdf/2005.03572.pdf) + + Example: + ```python + y_true = np.random.uniform( + size=(5, 10, 4), + low=0, + high=10) + y_pred = np.random.uniform( + size=(5, 10, 4), + low=0, + high=10) + loss = keras_hub.src.models.yolo_v8.ciou_loss.CIoULoss("xyxy") + loss(y_true, y_pred).numpy() + ``` + + Usage with the `compile()` API: + ```python + model.compile(optimizer="adam", loss=CIoULoss("xyxy")) + model.fit(y_true, y_pred) + ``` + """ + + def __init__( + self, bounding_box_format, epsilon=1e-07, image_shape=None, **kwargs + ): + super().__init__(**kwargs) + box_formats = [ + "xywh", + "center_xywh", + "center_yxhw", + "rel_xywh", + "xyxy", + "rel_xyxy", + "yxyx", + "rel_yxyx", + ] + if bounding_box_format not in box_formats: + raise ValueError( + f"Invalid bounding box format: '{bounding_box_format}'. " + f"Expected one of {box_formats}. " + "Ensure that the string format is correctly spelled." + ) + self.bounding_box_format = bounding_box_format + self.epsilon = epsilon + self.image_shape = image_shape + + def call(self, y_true, y_pred): + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.cast(y_true, y_pred.dtype) + + if y_pred.shape[-1] != 4: + raise ValueError( + "CIoULoss expects y_pred.shape[-1] to be 4 to represent the " + f"bounding boxes. Received y_pred.shape[-1]={y_pred.shape[-1]}." + ) + + if y_true.shape[-1] != 4: + raise ValueError( + "CIoULoss expects y_true.shape[-1] to be 4 to represent the " + f"bounding boxes. Received y_true.shape[-1]={y_true.shape[-1]}." + ) + + if y_true.shape[-2] != y_pred.shape[-2]: + raise ValueError( + "CIoULoss expects number of boxes in y_pred to be equal to the " + "number of boxes in y_true. Received number of boxes in " + f"y_true={y_true.shape[-2]} and number of boxes in " + f"y_pred={y_pred.shape[-2]}." 
+ ) + + ciou = compute_ciou( + y_true, y_pred, self.bounding_box_format, self.image_shape + ) + return 1 - ciou + + def get_config(self): + config = super().get_config() + config.update( + { + "epsilon": self.epsilon, + } + ) + return config diff --git a/keras_hub/src/models/yolo_v8/ciou_loss_test.py b/keras_hub/src/models/yolo_v8/ciou_loss_test.py new file mode 100644 index 0000000000..3a0e4b22ca --- /dev/null +++ b/keras_hub/src/models/yolo_v8/ciou_loss_test.py @@ -0,0 +1,74 @@ +import numpy as np +from absl.testing import parameterized + +from keras_hub.src.models.yolo_v8.ciou_loss import CIoULoss +from keras_hub.src.tests.test_case import TestCase + + +class CIoUTest(TestCase): + def test_output_shape(self): + y_true = np.random.uniform(size=(2, 2, 4), low=0, high=10) + y_pred = np.random.uniform(size=(2, 2, 4), low=0, high=20) + + ciou_loss = CIoULoss(bounding_box_format="xywh") + + self.assertAllEqual(ciou_loss(y_true, y_pred).shape, ()) + + def test_output_shape_reduction_none(self): + y_true = np.random.uniform(size=(2, 2, 4), low=0, high=10) + y_pred = np.random.uniform(size=(2, 2, 4), low=0, high=20) + + ciou_loss = CIoULoss(bounding_box_format="xyxy", reduction="none") + + self.assertAllEqual( + [2, 2], + ciou_loss(y_true, y_pred).shape, + ) + + def test_output_shape_relative_formats(self): + y_true = [ + [0.0, 0.0, 0.1, 0.1], + [0.0, 0.0, 0.2, 0.3], + [0.4, 0.5, 0.5, 0.6], + [0.2, 0.2, 0.3, 0.3], + ] + + y_pred = [ + [0.0, 0.0, 0.5, 0.6], + [0.0, 0.0, 0.7, 0.3], + [0.4, 0.5, 0.5, 0.6], + [0.2, 0.1, 0.3, 0.3], + ] + + ciou_loss = CIoULoss(bounding_box_format="xyxy") + + self.assertAllEqual(ciou_loss(y_true, y_pred).shape, ()) + + @parameterized.named_parameters( + ("xyxy", "xyxy"), + ("rel_xyxy", "rel_xyxy"), + ) + def test_output_value(self, name): + y_true = [ + [0, 0, 1, 1], + [0, 0, 2, 3], + [4, 5, 3, 6], + [2, 2, 3, 3], + ] + + y_pred = [ + [0, 0, 5, 6], + [0, 0, 7, 3], + [4, 5, 5, 6], + [2, 1, 3, 3], + ] + expected_loss = 1.03202 + ciou_loss = CIoULoss(bounding_box_format="xyxy") + if name == "rel_xyxy": + scale_factor = 1 / 640.0 + y_true = np.array(y_true) * scale_factor + y_pred = np.array(y_pred) * scale_factor + + self.assertAllClose( + ciou_loss(y_true, y_pred), expected_loss, atol=0.005 + ) diff --git a/keras_hub/src/models/yolo_v8/yolo_v8_backbone.py b/keras_hub/src/models/yolo_v8/yolo_v8_backbone.py new file mode 100644 index 0000000000..b81cd50dd1 --- /dev/null +++ b/keras_hub/src/models/yolo_v8/yolo_v8_backbone.py @@ -0,0 +1,136 @@ +from keras import ops +from keras.layers import Input +from keras.layers import MaxPooling2D + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone +from keras_hub.src.models.yolo_v8.yolo_v8_layers import apply_conv_bn +from keras_hub.src.models.yolo_v8.yolo_v8_layers import apply_CSP + + +def apply_stem(x, stem_width, activation): + x = apply_conv_bn(x, stem_width // 2, 3, 2, activation, "stem_1") + x = apply_conv_bn(x, stem_width, 3, 2, activation, "stem_2") + return x + + +def apply_fast_SPP(x, pool_size=5, activation="swish", name="spp_fast"): + input_channels = x.shape[-1] + hidden_channels = int(input_channels // 2) + x = apply_conv_bn(x, hidden_channels, 1, 1, activation, f"{name}_pre") + pool_kwargs = {"strides": 1, "padding": "same"} + p1 = MaxPooling2D(pool_size, **pool_kwargs, name=f"{name}_pool1")(x) + p2 = MaxPooling2D(pool_size, **pool_kwargs, name=f"{name}_pool2")(p1) + p3 = MaxPooling2D(pool_size, **pool_kwargs, 
name=f"{name}_pool3")(p2) + x = ops.concatenate([x, p1, p2, p3], axis=-1) + x = apply_conv_bn(x, input_channels, 1, 1, activation, f"{name}_output") + return x + + +def apply_yolo_block(x, block_arg, channels, depth, block_depth, activation): + name = f"stack{block_arg + 1}" + if block_arg >= 1: + x = apply_conv_bn(x, channels, 3, 2, activation, f"{name}_downsample") + x = apply_CSP(x, -1, depth, True, 0.5, activation, f"{name}_c2f") + if block_arg == len(block_depth) - 1: + x = apply_fast_SPP(x, 5, activation, f"{name}_spp_fast") + return x + + +def stackwise_yolo_blocks(x, stackwise_depth, stackwise_channels, activation): + pyramid_level_inputs = {"P1": get_tensor_input_name(x)} + iterator = enumerate(zip(stackwise_channels, stackwise_depth)) + block_args = (stackwise_depth, activation) + for stack_arg, (channel, depth) in iterator: + x = apply_yolo_block(x, stack_arg, channel, depth, *block_args) + pyramid_level_inputs[f"P{stack_arg + 2}"] = get_tensor_input_name(x) + return x, pyramid_level_inputs + + +def get_tensor_input_name(tensor): + return tensor._keras_history.operation.name + + +def build_pyramid_outputs(model, level_to_layer_name): + pyramid_outputs = {} + for level_name, layer_name in level_to_layer_name.items(): + pyramid_outputs[level_name] = model.get_layer(layer_name).output + return pyramid_outputs + + +@keras_hub_export("keras_hub.models.YOLOV8Backbone") +class YOLOV8Backbone(FeaturePyramidBackbone): + """Implements the YOLOV8 backbone for object detection. + + This backbone is a variant of the `CSPDarkNetBackbone` architecture. + + For transfer learning use cases, make sure to read the + [guide to transfer learning & fine-tuning](https://keras.io/guides/ + transfer_learning/). + + Args: + stackwise_channels: A list of int. The number of channels for each dark + level in the model. + stackwise_depth: A list of int. The depth for each dark level in the + model. + include_rescaling: bool. Rescale the inputs. If set to + True, inputs will be passed through a `Rescaling(1/255.0)` layer. + activation: str. The activation functions to use in the backbone to + use in the CSPDarkNet blocks. Defaults to "swish". + image_shape: optional shape tuple, defaults to `(None, None, 3)`. + + Returns: + A `keras.Model` instance. 
+ + Examples: + ```python + input_data = tf.ones(shape=(8, 224, 224, 3)) + + # Pretrained backbone + model = keras_hub.models.YOLOV8Backbone.from_preset( + "yolo_v8_xs_backbone_coco" + ) + output = model(input_data) + + # Randomly initialized backbone with a custom config + model = keras_hub.models.YOLOV8Backbone( + stackwise_channels=[128, 256, 512, 1024], + stackwise_depth=[3, 9, 9, 3], + ) + output = model(input_data) + ``` + """ + + def __init__( + self, + stackwise_channels, + stackwise_depth, + activation="swish", + image_shape=(None, None, 3), + **kwargs, + ): + inputs = Input(shape=image_shape) + stem_width = stackwise_channels[0] + x = apply_stem(inputs, stem_width, activation) + x, pyramid_level_inputs = stackwise_yolo_blocks( + x, stackwise_depth, stackwise_channels, activation + ) + super().__init__(inputs=inputs, outputs=x, **kwargs) + self.pyramid_level_inputs = pyramid_level_inputs + self.pyramid_outputs = build_pyramid_outputs(self, pyramid_level_inputs) + self.stackwise_channels = stackwise_channels + self.stackwise_depth = stackwise_depth + self.activation = activation + self.image_shape = image_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "image_shape": self.image_shape, + "stackwise_channels": self.stackwise_channels, + "stackwise_depth": self.stackwise_depth, + "activation": self.activation, + } + ) + return config diff --git a/keras_hub/src/models/yolo_v8/yolo_v8_backbone_test.py b/keras_hub/src/models/yolo_v8/yolo_v8_backbone_test.py new file mode 100644 index 0000000000..19eb3f1836 --- /dev/null +++ b/keras_hub/src/models/yolo_v8/yolo_v8_backbone_test.py @@ -0,0 +1,45 @@ +import numpy as np +import pytest + +from keras_hub.src.models.yolo_v8.yolo_v8_backbone import YOLOV8Backbone +from keras_hub.src.tests.test_case import TestCase + + +class YOLOV8BackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "stackwise_channels": [64, 128, 256, 512], + "stackwise_depth": [1, 2, 2, 1], + "activation": "swish", + "image_shape": (32, 32, 3), + } + self.input_size = 32 + self.input_data = np.ones( + (2, self.input_size, self.input_size, 3), dtype="float32" + ) + + def test_backbone_basics(self): + self.run_vision_backbone_test( + cls=YOLOV8Backbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 1, 1, 512), + expected_pyramid_output_keys=["P1", "P2", "P3", "P4", "P5"], + expected_pyramid_image_sizes=[ + (8, 8), + (8, 8), + (4, 4), + (2, 2), + (1, 1), + ], + run_mixed_precision_check=False, + run_data_format_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=YOLOV8Backbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/yolo_v8/yolo_v8_detector.py b/keras_hub/src/models/yolo_v8/yolo_v8_detector.py new file mode 100644 index 0000000000..abba79c0f8 --- /dev/null +++ b/keras_hub/src/models/yolo_v8/yolo_v8_detector.py @@ -0,0 +1,617 @@ +import keras +from keras import Model +from keras import ops +from keras.layers import Activation +from keras.layers import Concatenate +from keras.layers import Conv2D +from keras.layers import Input +from keras.layers import Reshape +from keras.losses import BinaryCrossentropy +from keras.optimizers import Adam +from keras.saving import deserialize_keras_object +from keras.saving import serialize_keras_object +from keras.utils.bounding_boxes import convert_format + +from keras_hub.src.api_export import keras_hub_export +from 
keras_hub.src.layers.modeling.non_max_supression import NonMaxSuppression +from keras_hub.src.models.object_detector import ( + ObjectDetector as ImageObjectDetector, +) +from keras_hub.src.models.yolo_v8.ciou_loss import CIoULoss +from keras_hub.src.models.yolo_v8.yolo_v8_backbone import YOLOV8Backbone +from keras_hub.src.models.yolo_v8.yolo_v8_label_encoder import ( + YOLOV8LabelEncoder, +) +from keras_hub.src.models.yolo_v8.yolo_v8_layers import apply_conv_bn +from keras_hub.src.models.yolo_v8.yolo_v8_layers import apply_CSP +from keras_hub.src.models.yolo_v8.yolo_v8_object_detector_preprocessor import ( + YOLOV8ImageObjectDetectorPreprocessor, +) + + +def unwrap_data(data): + if type(data) is dict: + return data["images"], data["bounding_boxes"] + else: + return data + + +def get_anchors(image_shape, strides=[8, 16, 32], base_anchors=[0.5, 0.5]): + """Gets anchor points for YOLOV8. + + YOLOV8 uses anchor points representing the center of proposed boxes, and + matches ground truth boxes to anchors based on center points. + + Args: + image_shape: tuple or list of two integers representing the height and + width of input images, respectively. + strides: tuple of list of integers, the size of the strides across the + image size that should be used to create anchors. + base_anchors: tuple or list of two integers representing offset from + `(0,0)` to start creating the center of anchor boxes, relative to + the stride. For example, using the default `(0.5, 0.5)` creates the + first anchor box for each stride such that its center is half of a + stride from the edge of the image. + + Returns: + A tuple of anchor centerpoints and anchor strides. Multiplying the + two together will yield the centerpoints in absolute x,y format. + + """ + base_anchors = ops.array(base_anchors, dtype="float32") + + all_anchors = [] + all_strides = [] + for stride in strides: + hh_centers = ops.arange(0, image_shape[0], stride) + ww_centers = ops.arange(0, image_shape[1], stride) + ww_grid, hh_grid = ops.meshgrid(ww_centers, hh_centers) + grid = ops.cast( + ops.reshape(ops.stack([hh_grid, ww_grid], 2), [-1, 1, 2]), + "float32", + ) + anchors = ( + ops.expand_dims( + base_anchors * ops.array([stride, stride], "float32"), 0 + ) + + grid + ) + anchors = ops.reshape(anchors, [-1, 2]) + all_anchors.append(anchors) + all_strides.append(ops.repeat(stride, anchors.shape[0])) + + all_anchors = ops.cast(ops.concatenate(all_anchors, axis=0), "float32") + all_strides = ops.cast(ops.concatenate(all_strides, axis=0), "float32") + + all_anchors = all_anchors / all_strides[:, None] + + # Swap the x and y coordinates of the anchors. 
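+    # The grid above is stacked as (row, col), i.e. (y, x); reorder to
+    # (x, y) so the anchor centers line up with the xyxy boxes built by
+    # `dist2bbox`.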
+ all_anchors = ops.concatenate( + [all_anchors[:, 1, None], all_anchors[:, 0, None]], axis=-1 + ) + return all_anchors, all_strides + + +def upsample(x, size=2): + return ops.repeat(ops.repeat(x, size, axis=1), size, axis=2) + + +def merge_upper_level(lower_level, upper_level, depth, name): + x = upsample(upper_level) + x = ops.concatenate([x, lower_level], axis=-1) + channels = lower_level.shape[-1] + return apply_CSP(x, channels, depth, False, 0.5, "swish", name) + + +def merge_lower_level(lower_level, upper_level, name): + x = apply_conv_bn(lower_level, lower_level.shape[-1], 3, 2, "swish", name) + x = ops.concatenate([x, upper_level], axis=-1) + channels = upper_level.shape[-1] + x = apply_CSP(x, channels, 2, False, 0.5, "swish", f"{name}_block") + return x + + +def apply_path_aggregation_FPN(features, depth=3, name="fpn"): + p3, p4, p5 = features + p4p5 = merge_upper_level(p4, p5, depth, f"{name}_p4p5") + p3p4p5 = merge_upper_level(p3, p4p5, depth, f"{name}_p3p4p5") + p3p4p5_d1 = merge_lower_level(p3p4p5, p4p5, f"{name}_p3p4p5_downsample1") + p3p4p5_d2 = merge_lower_level(p3p4p5_d1, p5, f"{name}_p3p4p5_downsample2") + return [p3p4p5, p3p4p5_d1, p3p4p5_d2] + + +def get_boxes_channels(x, default=64): + # If the input has a large resolution e.g. P3 has >256 channels, + # additional channels are used in intermediate conv layers. + return max(default, x.shape[-1] // 4) + + +def get_class_channels(x, num_classes): + # We use at least num_classes channels for intermediate conv layer for + # class predictions. In most cases, the P3 input has more channels than the + # number of classes, so we preserve those channels until the final layer. + return max(num_classes, x.shape[-1]) + + +def apply_boxes_block(x, boxes_channels, name): + x = apply_conv_bn(x, boxes_channels, 3, 1, "swish", f"{name}_box_1") + x = apply_conv_bn(x, boxes_channels, 3, 1, "swish", f"{name}_box_2") + BOX_REGRESSION_CHANNELS = 64 # 16 values per corner offset from center. 
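+    # i.e. 4 box sides (left/top/right/bottom) x 16 distribution bins,
+    # later collapsed into per-side distances by
+    # `decode_regression_to_boxes`.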
+ x = Conv2D(BOX_REGRESSION_CHANNELS, 1, name=f"{name}_box_3_conv")(x) + return x + + +def apply_class_block(x, class_channels, num_classes, name): + x = apply_conv_bn(x, class_channels, 3, 1, "swish", f"{name}_class_1") + x = apply_conv_bn(x, class_channels, 3, 1, "swish", f"{name}_class_2") + x = Conv2D(num_classes, 1, name=f"{name}_class_3_conv")(x) + x = Activation("sigmoid", name=f"{name}_classifier")(x) + return x + + +def apply_branch_head(x, boxes_channels, class_channels, num_classes, name): + boxes_predictions = apply_boxes_block(x, boxes_channels, name) + class_predictions = apply_class_block(x, class_channels, num_classes, name) + branch = ops.concatenate([boxes_predictions, class_predictions], axis=-1) + branch_shape = [-1, branch.shape[-1]] + branch = Reshape(branch_shape, name=f"{name}_output_reshape")(branch) + return branch + + +def apply_detection_head(inputs, num_classes, name="yolo_v8_head"): + boxes_channels = get_boxes_channels(inputs[0]) + class_channels = get_class_channels(inputs[0], num_classes) + branch_args = (boxes_channels, class_channels, num_classes) + outputs = [] + for feature_arg, feature in enumerate(inputs): + feature_name = f"{name}_{feature_arg + 1}" + outputs.append(apply_branch_head(feature, *branch_args, feature_name)) + + x = ops.concatenate(outputs, axis=1) + x = Activation("linear", dtype="float32", name="box_outputs")(x) + BOX_REGRESSION_CHANNELS = 64 + boxes_tensor = x[:, :, :BOX_REGRESSION_CHANNELS] + class_tensor = x[:, :, BOX_REGRESSION_CHANNELS:] + return {"boxes": boxes_tensor, "classes": class_tensor} + + +def add_no_op_for_pretty_print(x, name): + return Concatenate(axis=1, name=name)([x]) + + +def get_feature_extractor(model, layer_names, output_keys=None): + if not output_keys: + output_keys = layer_names + items = zip(output_keys, layer_names) + outputs = {key: model.get_layer(name).output for key, name in items} + return Model(inputs=model.inputs, outputs=outputs) + + +def get_backbone_pyramid_layer_names(backbone, level_names): + """Gets layer names from the provided pyramid levels inside backbone. + + Args: + backbone: Keras backbone model with the field "pyramid_level_inputs". + level_names: list of strings indicating the level names. + + Returns: + List of layer strings indicating the layer names of each level. + """ + layer_names = [] + for level_name in level_names: + layer_names.append(backbone.pyramid_level_inputs[level_name]) + return layer_names + + +def build_feature_extractor(backbone, level_names): + """Builds feature extractor directly from the level names + + Args: + backbone: Keras backbone model with the field "pyramid_level_inputs". + level_names: list of strings indicating the level names. + + Returns: + Keras Model with level names as outputs. + """ + layer_names = get_backbone_pyramid_layer_names(backbone, level_names) + items = zip(level_names, layer_names) + outputs = {key: backbone.get_layer(name).output for key, name in items} + return Model(inputs=backbone.inputs, outputs=outputs) + + +def extend_branches(inputs, extractor, FPN_depth): + """Extends extractor model with a feature pyramid network. + + Args: + inputs: tensor, with image input. + extractor: Keras Model with level names as outputs. + FPN_depth: integer representing the feature pyramid depth. + + Returns: + List of extended branch tensors. 
+ """ + features = list(extractor(inputs).values()) + branches = apply_path_aggregation_FPN(features, FPN_depth, name="pa_fpn") + return branches + + +def extend_backbone(backbone, level_names, FPN_depth): + """Extends backbone levels with a feature pyramid network. + + Args: + backbone: Keras backbone model with the field "pyramid_level_inputs". + level_names: list of strings indicating the level names. + trainable: boolean indicating if backbone should be optimized. + FPN_depth: integer representing the feature pyramid depth. + + Return: + Tuple with input image tensor, and list of extended branch tensors. + """ + feature_extractor = build_feature_extractor(backbone, level_names) + image = Input(feature_extractor.input_shape[1:]) + branches = extend_branches(image, feature_extractor, FPN_depth) + return image, branches + + +def decode_regression_to_boxes(preds): + """Decodes the results of the YOLOV8Detector forward-pass into boxes. + + Returns left / top / right / bottom predictions with respect to anchor + points. + + Each coordinate is encoded with 16 predicted values. Those predictions are + softmaxed and multiplied by [0..15] to make predictions. The resulting + predictions are relative to the stride of an anchor box + (and correspondingly relative to the scale of the feature map from which + the predictions came). + """ + BOX_REGRESSION_CHANNELS = 64 + preds_bbox = Reshape((-1, 4, BOX_REGRESSION_CHANNELS // 4))(preds) + preds_bbox = ops.nn.softmax(preds_bbox, axis=-1) * ops.arange( + BOX_REGRESSION_CHANNELS // 4, dtype="float32" + ) + return ops.sum(preds_bbox, axis=-1) + + +def dist2bbox(distance, anchor_points): + """Decodes distance predictions into xyxy boxes. + + Input left / top / right / bottom predictions are transformed into xyxy box + predictions based on anchor points. + + The resulting xyxy predictions must be scaled by the stride of their + corresponding anchor points to yield an absolute xyxy box. + """ + left_top, right_bottom = ops.split(distance, 2, axis=-1) + x1y1 = anchor_points - left_top + x2y2 = anchor_points + right_bottom + return ops.concatenate((x1y1, x2y2), axis=-1) # xyxy bbox + + +@keras_hub_export(["keras_hub.models.YOLOV8ImageObjectDetector"]) +class YOLOV8ImageObjectDetector(ImageObjectDetector): + """Implements the YOLOV8 architecture for object detection. + + Args: + backbone: `keras.Model`, must implement the `pyramid_level_inputs` + property with keys "P3", "P4", and "P5" and layer names as values. + A sensible backbone to use is the `keras_hub.models.YOLOV8Backbone`. + num_classes: integer, the number of classes in your dataset excluding + the background class. Classes should be represented by integers in + the range `[0, num_classes)`. + bounding_box_format: string, the format of bounding boxes of input + dataset. Refer [to the keras.io docs]( + https://keras.io/api/keras_cv/bounding_box/formats/) + for more details on supported bounding box formats. + fpn_depth: integer, a specification of the depth of the CSP blocks in + the Feature Pyramid Network. This is usually 1, 2, or 3, depending + on the size of your YOLOV8Detector model. We recommend using 3 for + "yolo_v8_l_backbone" and "yolo_v8_xl_backbone". Defaults to 2. + preprocessor: Optional. An instance of + `YOLOV8ImageObjectDetectorPreprocessor`or a custom preprocessor. + Handles image preprocessing before feeding into the backbone. + label_encoder: Optional. A `YOLOV8LabelEncoder` that is + responsible for transforming input boxes into trainable labels for + YOLOV8Detector. 
If not provided, the default + [Task-aligned sample scheme](https://arxiv.org/abs/2108.07755) is + used. + prediction_decoder: Optional. A `keras.layers.Layer` that is + responsible for transforming YOLOV8 predictions into usable + bounding boxes. If not provided, a default is provided. The + default `prediction_decoder` layer is a + `keras_hub.layers.NonMaxSuppression` layer, which uses + a Non-Max Suppression for box pruning. + + Example: + ```python + images = ops.ones(shape=(1, 512, 512, 3)) + labels = { + "boxes": ops.array([ + [ + [0, 0, 100, 100], + [100, 100, 200, 200], + [300, 300, 100, 100], + ] + ], dtype="float32"), + "classes": ops.array([[1, 1, 1]], dtype="int64"), + } + + model = keras_hub.models.YOLOV8ImageObjectDetector( + num_classes=20, + bounding_box_format="xywh", + backbone=keras_hub.models.YOLOV8Backbone.from_preset( + "yolo_v8_m_backbone_coco" + ), + fpn_depth=2 + ) + + # Evaluate model without box decoding and NMS + model(images) + + # Prediction with box decoding and NMS + model.predict(images) + + # Train model + model.compile( + optimizer=keras.optimizers.SGD(global_clipnorm=10.0), + jit_compile=False, + ) + model.fit(images, labels) + ``` + """ + + backbone_cls = YOLOV8Backbone + preprocessor_cls = YOLOV8ImageObjectDetectorPreprocessor + + def __init__( + self, + backbone, + num_classes, + bounding_box_format, + fpn_depth=2, + preprocessor=None, + label_encoder=None, + prediction_decoder=None, + **kwargs, + ): + level_names = ["P3", "P4", "P5"] + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + image, branches = extend_backbone(backbone, level_names, fpn_depth) + head = apply_detection_head(branches, num_classes) + boxes_tensor = add_no_op_for_pretty_print(head["boxes"], "box") + class_tensor = add_no_op_for_pretty_print(head["classes"], "class") + outputs = {"boxes": boxes_tensor, "classes": class_tensor} + super().__init__(inputs=image, outputs=outputs, **kwargs) + # === Config === + self.bounding_box_format = bounding_box_format + self._prediction_decoder = prediction_decoder or NonMaxSuppression( + bounding_box_format=bounding_box_format, + from_logits=False, + confidence_threshold=0.2, + iou_threshold=0.7, + ) + self.fpn_depth = fpn_depth + self.num_classes = num_classes + self.label_encoder = label_encoder or YOLOV8LabelEncoder( + num_classes=num_classes + ) + + def compile( + self, + optimizer="auto", + box_loss="auto", + classification_loss="auto", + box_loss_weight=7.5, + classification_loss_weight=0.5, + metrics=None, + **kwargs, + ): + """Configures the `ObjectDetector` task for training. + + The `ObjectDetector` task extends the default compilation signature of + `keras.Model.compile` with a default `optimizer`, default loss + functions `box_loss`, and `classification_loss` and default loss weights + "box_loss_weight" and "classification_loss_weight". + + `compile()` mirrors the standard Keras `compile()` method, but has one + key distinction -- two losses must be provided: `box_loss` and + `classification_loss`. + + Args: + box_loss: a Keras loss to use for box offset regression. A + preconfigured loss is provided when the string `"ciou"` is + passed. + classification_loss: a Keras loss to use for box classification. A + preconfigured loss is provided when the string + "binary_crossentropy" is passed. + box_loss_weight: (optional) float, a scaling factor for the box + loss. Defaults to 7.5. + classification_loss_weight: (optional) float, a scaling factor for + the classification loss. Defaults to 0.5. 
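+            metrics: not currently supported; passing anything other than
+                `None` raises a `ValueError`.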
+ kwargs: most other `keras.Model.compile()` arguments are supported + and propagated to the `keras.Model` class. + """ + if optimizer == "auto": + optimizer = Adam(0.001) + if box_loss in ["auto", "ciou"]: + box_loss = CIoULoss(bounding_box_format="xyxy", reduction="sum") + else: + raise ValueError("Invalid box loss. Use `auto` or `ciou`") + if classification_loss in ["auto", "binary_crossentropy"]: + classification_loss = BinaryCrossentropy(reduction="sum") + else: + raise ValueError( + "Invalid classification loss." + "Use `auto` or `binary_crossentropy`" + ) + if metrics is not None: + raise ValueError("User metrics not yet supported") + self.box_loss = box_loss + self.classification_loss = classification_loss + self.box_loss_weight = box_loss_weight + self.classification_loss_weight = classification_loss_weight + losses = { + "box": self.box_loss, + "class": self.classification_loss, + } + super(ImageObjectDetector, self).compile( + optimizer=optimizer, loss=losses, **kwargs + ) + + def train_step(self, *args): + # This is done for tf.data pipelines that don't unwrap dictionaries. + data = args[-1] + args = args[:-1] + x, y = unwrap_data(data) + return super().train_step(*args, (x, y)) + + def test_step(self, *args): + # This is done for tf.data pipelines that don't unwrap dictionaries. + data = args[-1] + args = args[:-1] + x, y = unwrap_data(data) + return super().test_step(*args, (x, y)) + + def compute_loss(self, x, y, y_pred, sample_weight=None, **kwargs): + box_pred, cls_pred = y_pred["boxes"], y_pred["classes"] + + pred_boxes = decode_regression_to_boxes(box_pred) + pred_scores = cls_pred + + anchor_points, stride_tensor = get_anchors(image_shape=x.shape[1:]) + stride_tensor = ops.expand_dims(stride_tensor, axis=-1) + + ground_truth_labels = y["classes"] + mask_ground_truth = ops.all(y["boxes"] > -1.0, axis=-1, keepdims=True) + ground_truth_bboxes = convert_format( + y["boxes"], + source=self.bounding_box_format, + target="xyxy", + # images=x, + height=ops.shape(x)[1], + width=ops.shape(x)[2], + ) + + pred_bboxes = dist2bbox(pred_boxes, anchor_points) + + target_bboxes, target_scores, fg_mask = self.label_encoder( + pred_scores, + ops.cast(pred_bboxes * stride_tensor, ground_truth_bboxes.dtype), + anchor_points * stride_tensor, + ground_truth_labels, + ground_truth_bboxes, + mask_ground_truth, + ) + + target_bboxes /= stride_tensor + target_scores_sum = ops.maximum(ops.sum(target_scores), 1) + box_weight = ops.expand_dims( + ops.sum(target_scores, axis=-1) * fg_mask, + axis=-1, + ) + + y_true = { + "box": target_bboxes * fg_mask[..., None], + "class": target_scores, + } + y_pred = { + "box": pred_bboxes * fg_mask[..., None], + "class": pred_scores, + } + sample_weights = { + "box": self.box_loss_weight * box_weight / target_scores_sum, + "class": self.classification_loss_weight / target_scores_sum, + } + + return super().compute_loss( + x=x, y=y_true, y_pred=y_pred, sample_weight=sample_weights, **kwargs + ) + + def decode_predictions(self, pred, data): + boxes = pred["boxes"] + scores = pred["classes"] + if isinstance(data, list) or isinstance(data, tuple): + images, _ = data + else: + images = data + boxes = decode_regression_to_boxes(boxes) + anchor_points, stride_tensor = get_anchors(image_shape=images.shape[1:]) + stride_tensor = ops.expand_dims(stride_tensor, axis=-1) + + box_preds = dist2bbox(boxes, anchor_points) * stride_tensor + box_preds = convert_format( + box_preds, + source="xyxy", + target=self.bounding_box_format, + # images=images, + 
height=ops.shape(images)[1], + width=ops.shape(images)[2], + ) + return self.prediction_decoder(box_preds, scores) + + def predict_step(self, *args): + outputs = super().predict_step(*args) + if isinstance(outputs, tuple): + return self.decode_predictions(outputs[0], args[-1]), outputs[1] + else: + return self.decode_predictions(outputs, args[-1]) + + @property + def prediction_decoder(self): + return self._prediction_decoder + + @prediction_decoder.setter + def prediction_decoder(self, prediction_decoder): + if prediction_decoder.bounding_box_format != self.bounding_box_format: + raise ValueError( + "Expected `prediction_decoder` and YOLOV8Detector to " + "use the same `bounding_box_format`, but got " + "`prediction_decoder.bounding_box_format=" + f"{prediction_decoder.bounding_box_format}`, and " + "`self.bounding_box_format=" + f"{self.bounding_box_format}`." + ) + self._prediction_decoder = prediction_decoder + self.make_predict_function(force=True) + self.make_train_function(force=True) + self.make_test_function(force=True) + + def get_config(self): + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "bounding_box_format": self.bounding_box_format, + "fpn_depth": self.fpn_depth, + "backbone": serialize_keras_object(self.backbone), + "label_encoder": serialize_keras_object(self.label_encoder), + "prediction_decoder": serialize_keras_object( + self._prediction_decoder + ), + } + ) + return config + + @classmethod + def from_config(cls, config): + config["backbone"] = deserialize_keras_object(config["backbone"]) + label_encoder = config.get("label_encoder") + if label_encoder is not None and isinstance(label_encoder, dict): + config["label_encoder"] = deserialize_keras_object(label_encoder) + prediction_decoder = config.get("prediction_decoder") + if prediction_decoder is not None and isinstance( + prediction_decoder, dict + ): + config["prediction_decoder"] = deserialize_keras_object( + prediction_decoder + ) + if "preprocessor" in config and isinstance( + config["preprocessor"], dict + ): + config["preprocessor"] = keras.layers.deserialize( + config["preprocessor"] + ) + return cls(**config) + # return super().from_config(**config) diff --git a/keras_hub/src/models/yolo_v8/yolo_v8_detector_test.py b/keras_hub/src/models/yolo_v8/yolo_v8_detector_test.py new file mode 100644 index 0000000000..c1a11c9c43 --- /dev/null +++ b/keras_hub/src/models/yolo_v8/yolo_v8_detector_test.py @@ -0,0 +1,323 @@ +import numpy as np +import pytest +from absl.testing import parameterized +from keras import ops +from keras.utils.bounding_boxes import convert_format + +import keras_hub +from keras_hub.src.layers.modeling.non_max_supression import NonMaxSuppression +from keras_hub.src.models.yolo_v8.yolo_v8_presets import detector_presets +from keras_hub.src.tests.test_case import TestCase + +test_backbone_presets = [ + "yolo_v8_xs_backbone_coco", + "yolo_v8_s_backbone_coco", + "yolo_v8_m_backbone_coco", + "yolo_v8_l_backbone_coco", + "yolo_v8_xl_backbone_coco", +] + + +def _create_bounding_box_dataset(bounding_box_format, image_size=512): + # Just about the easiest dataset you can have, all classes are 0, all boxes + # are exactly the same. [1, 1, 2, 2] are the coordinates in xyxy. 
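+    # The three boxes below are given in relative xywh and tiled across a
+    # batch of five copies of the same random image.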
+ xs = np.random.normal(size=(1, image_size, image_size, 3)) + xs = np.tile(xs, [5, 1, 1, 1]) + + y_classes = np.zeros((5, 3), "float32") + + ys = np.array( + [ + [0.1, 0.1, 0.23, 0.23], + [0.67, 0.75, 0.23, 0.23], + [0.25, 0.25, 0.23, 0.23], + ], + "float32", + ) + ys = np.expand_dims(ys, axis=0) + ys = np.tile(ys, [5, 1, 1]) + ys = ops.convert_to_numpy( + convert_format( + ys, + source="rel_xywh", + target=bounding_box_format, + # images=xs, + height=1, + width=1, + dtype="float32", + ) + ) + return xs, {"boxes": ys, "classes": y_classes} + + +class YOLOV8DetectorTest(TestCase): + @pytest.mark.large # Fit is slow, so mark these large. + def test_fit(self): + bounding_box_format = "xywh" + yolo = keras_hub.models.YOLOV8ImageObjectDetector( + num_classes=2, + fpn_depth=1, + bounding_box_format=bounding_box_format, + backbone=keras_hub.models.YOLOV8Backbone.from_preset( + "yolo_v8_xs_backbone_coco" + ), + ) + + yolo.compile( + optimizer="adam", + classification_loss="auto", + box_loss="auto", + ) + xs, ys = _create_bounding_box_dataset(bounding_box_format) + + yolo.fit(x=xs, y=ys, epochs=1) + + @pytest.mark.skip( + reason="to_ragged has been deleted and requires tensorflow" + ) + def test_fit_with_ragged_tensors(self): + bounding_box_format = "xywh" + yolo = keras_hub.models.YOLOV8ImageObjectDetector( + num_classes=2, + fpn_depth=1, + bounding_box_format=bounding_box_format, + backbone=keras_hub.models.YOLOV8Backbone.from_preset( + "yolo_v8_xs_backbone_coco" + ), + ) + + yolo.compile( + optimizer="adam", + classification_loss="auto", + box_loss="auto", + ) + xs, ys = _create_bounding_box_dataset(bounding_box_format) + # ys = to_ragged(ys) + yolo.fit(x=xs, y=ys, epochs=1) + + @pytest.mark.large # Fit is slow, so mark these large. + def test_fit_with_no_valid_gt_bbox(self): + bounding_box_format = "xywh" + yolo = keras_hub.models.YOLOV8ImageObjectDetector( + num_classes=1, + fpn_depth=1, + bounding_box_format=bounding_box_format, + backbone=keras_hub.models.YOLOV8Backbone.from_preset( + "yolo_v8_xs_backbone_coco" + ), + ) + + yolo.compile( + optimizer="adam", + classification_loss="auto", + box_loss="auto", + ) + xs, ys = _create_bounding_box_dataset(bounding_box_format) + # Make all bounding_boxes invalid and filter out them + ys["classes"] = -np.ones_like(ys["classes"]) + + yolo.fit(x=xs, y=ys, epochs=1) + + def test_trainable_weight_count(self): + yolo = keras_hub.models.YOLOV8ImageObjectDetector( + num_classes=2, + fpn_depth=1, + bounding_box_format="xywh", + backbone=keras_hub.models.YOLOV8Backbone.from_preset( + "yolo_v8_s_backbone_coco" + ), + ) + + self.assertEqual(len(yolo.trainable_weights), 195) + + def test_bad_loss(self): + yolo = keras_hub.models.YOLOV8ImageObjectDetector( + num_classes=2, + fpn_depth=1, + bounding_box_format="xywh", + backbone=keras_hub.models.YOLOV8Backbone.from_preset( + "yolo_v8_xs_backbone_coco" + ), + ) + + with self.assertRaisesRegex( + ValueError, + "Invalid box loss", + ): + yolo.compile(box_loss="bad_loss", classification_loss="auto") + + with self.assertRaisesRegex( + ValueError, + "Invalid classification loss", + ): + yolo.compile(box_loss="auto", classification_loss="bad_loss") + + @pytest.mark.large # Saving is slow, so mark these large. 
+ def test_saved_model(self): + init_kwargs = { + "num_classes": 20, + "bounding_box_format": "xywh", + "fpn_depth": 1, + "backbone": keras_hub.models.YOLOV8Backbone.from_preset( + "yolo_v8_xs_backbone_coco" + ), + } + + xs, _ = _create_bounding_box_dataset("xywh") + self.run_model_saving_test( + cls=keras_hub.models.YOLOV8ImageObjectDetector, + init_kwargs=init_kwargs, + input_data=xs, + ) + + def test_update_prediction_decoder(self): + yolo = keras_hub.models.YOLOV8ImageObjectDetector( + num_classes=2, + fpn_depth=1, + bounding_box_format="xywh", + backbone=keras_hub.models.YOLOV8Backbone.from_preset( + "yolo_v8_s_backbone_coco" + ), + prediction_decoder=NonMaxSuppression( + bounding_box_format="xywh", + from_logits=False, + confidence_threshold=0.0, + iou_threshold=1.0, + ), + ) + + image = np.ones((1, 512, 512, 3)) + + outputs = yolo.predict(image) + # We predicted at least 1 box with confidence_threshold 0 + self.assertGreater(outputs["boxes"].shape[0], 0) + + yolo.prediction_decoder = NonMaxSuppression( + bounding_box_format="xywh", + from_logits=False, + confidence_threshold=1.0, + iou_threshold=1.0, + ) + + outputs = yolo.predict(image) + # We predicted no boxes with confidence threshold 1 + self.assertAllEqual(outputs["boxes"], -np.ones_like(outputs["boxes"])) + self.assertAllEqual( + outputs["confidence"], -np.ones_like(outputs["confidence"]) + ) + self.assertAllEqual( + # outputs["classes"], -np.ones_like(outputs["classes"]) + outputs["labels"], + -np.ones_like(outputs["labels"]), + ) + + def test_yolov8_basics(self): + box_format = "xyxy" + xs, ys = _create_bounding_box_dataset(box_format) + backbone = keras_hub.models.YOLOV8Backbone.from_preset( + "yolo_v8_m_backbone_coco" + ) + scale = np.array(1.0 / 255).astype("float32") + xs = xs.astype("float32") + image_converter = keras_hub.layers.YOLOV8ImageConverter(scale=scale) + preprocessor = keras_hub.models.YOLOV8ImageObjectDetectorPreprocessor( + image_converter=image_converter + ) + + init_kwargs = { + "backbone": backbone, + "num_classes": 3, + "bounding_box_format": box_format, + "preprocessor": preprocessor, + } + self.run_task_test( + cls=keras_hub.models.YOLOV8ImageObjectDetector, + init_kwargs=init_kwargs, + train_data=(xs, ys), + batch_size=len(xs), + ) + + +@pytest.mark.large +class YOLOV8ImageObjectDetectorSmokeTest(TestCase): + # @pytest.mark.skip(reason="Missing non YOLOV8 presets in KerasHub") + @parameterized.named_parameters( + *[(preset, preset) for preset in test_backbone_presets] + ) + @pytest.mark.extra_large + def test_backbone_preset(self, preset): + backbone = keras_hub.models.YOLOV8Backbone.from_preset(preset) + model = keras_hub.models.YOLOV8ImageObjectDetector( + backbone=backbone, + num_classes=20, + bounding_box_format="xywh", + ) + xs, _ = _create_bounding_box_dataset(bounding_box_format="xywh") + output = model(xs) + + # 64 represents number of parameters in a box + # 5376 is the number of anchors for a 512x512 image + self.assertEqual(output["boxes"].shape, (xs.shape[0], 5376, 64)) + + @parameterized.named_parameters( + ("256x256", (256, 256)), + ("512x512", (512, 512)), + ) + def test_preset_with_forward_pass(self, image_size): + model = keras_hub.models.YOLOV8ImageObjectDetector.from_preset( + "yolo_v8_m_pascalvoc", + bounding_box_format="xywh", + ) + + image = np.ones((1, image_size[0], image_size[1], 3)) + encoded_predictions = model(image / 255.0) + + if image_size == (512, 512): + self.assertAllClose( + ops.convert_to_numpy(encoded_predictions["boxes"][0, 0:5, 0]), + [-0.8303556, 0.75213313, 
1.809204, 1.6576759, 1.4134747], + ) + self.assertAllClose( + ops.convert_to_numpy(encoded_predictions["classes"][0, 0:5, 0]), + [ + 7.6146556e-08, + 8.0103280e-07, + 9.7873999e-07, + 2.2314548e-06, + 2.5051115e-06, + ], + ) + if image_size == (256, 256): + self.assertAllClose( + ops.convert_to_numpy(encoded_predictions["boxes"][0, 0:5, 0]), + [-0.6900742, 0.62832844, 1.6327355, 1.539787, 1.311696], + ) + self.assertAllClose( + ops.convert_to_numpy(encoded_predictions["classes"][0, 0:5, 0]), + [ + 4.8766704e-08, + 2.8596392e-07, + 2.3858618e-07, + 3.8180931e-07, + 3.8900879e-07, + ], + ) + + +@pytest.mark.extra_large +class YOLOV8DetectorPresetFullTest(TestCase): + """ + Test the full enumeration of our presets. + This every presets for YOLOV8Detector and is only run manually. + Run with: + `pytest keras_hub/models/object_detection/yolo_v8/ + yolo_v8_detector_test.py --run_extra_large` + """ + + def test_load_yolo_v8_detector(self): + input_data = np.ones(shape=(2, 224, 224, 3)) + for preset in detector_presets: + model = keras_hub.models.YOLOV8ImageObjectDetector.from_preset( + preset, bounding_box_format="xywh" + ) + model(input_data) diff --git a/keras_hub/src/models/yolo_v8/yolo_v8_image_converter.py b/keras_hub/src/models/yolo_v8/yolo_v8_image_converter.py new file mode 100644 index 0000000000..b5f6cf5729 --- /dev/null +++ b/keras_hub/src/models/yolo_v8/yolo_v8_image_converter.py @@ -0,0 +1,8 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from keras_hub.src.models.yolo_v8.yolo_v8_backbone import YOLOV8Backbone + + +@keras_hub_export("keras_hub.layers.YOLOV8ImageConverter") +class YOLOV8ImageConverter(ImageConverter): + backbone_cls = YOLOV8Backbone diff --git a/keras_hub/src/models/yolo_v8/yolo_v8_label_encoder.py b/keras_hub/src/models/yolo_v8/yolo_v8_label_encoder.py new file mode 100644 index 0000000000..c63ec0c09f --- /dev/null +++ b/keras_hub/src/models/yolo_v8/yolo_v8_label_encoder.py @@ -0,0 +1,295 @@ +import keras +from keras import ops +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.validation import ( # noqa: E501 + densify_bounding_boxes as to_dense, +) +from keras.utils.bounding_boxes import compute_ciou + + +def is_anchor_center_within_box(anchors, ground_truth_bboxes): + return ops.all( + ops.logical_and( + ground_truth_bboxes[:, :, None, :2] < anchors, + ground_truth_bboxes[:, :, None, 2:] > anchors, + ), + axis=-1, + ) + + +def is_tensorflow_ragged(value): + if hasattr(value, "__class__"): + return ( + value.__class__.__name__ == "RaggedTensor" + and "tensorflow.python." in str(value.__class__.__module__) + ) + return False + + +class YOLOV8LabelEncoder(keras.layers.Layer): + """Encodes ground truth boxes to target boxes for training YOLOV8. + + Args: + num_classes: integer, the number of classes in the training dataset + max_anchor_matches: optional integer, the maximum number of anchors to + match with any given ground truth box. For example, when the default + 10 is used, the 10 candidate anchor points with the highest + alignment score are matched with a ground truth box. If less than 10 + candidate anchors exist, all candidates will be matched to the box. + alpha: float, a parameter to control the influence of class predictions + on the alignment score of an anchor box. This is the alpha parameter + in equation 9 of [TOOD](https://arxiv.org/pdf/2108.07755.pdf). 
+ beta: float, a parameter to control the influence of box IOUs on the + alignment score of an anchor box. This is the beta parameter in + equation 9 of [TOOD](https://arxiv.org/pdf/2108.07755.pdf). + epsilon: float, a small number used for numerical stability in division + (to avoid dividing by zero), and used as a threshold to eliminate + very small matches based on alignment scores of approximately zero. + + References: + - [Task-aligned sample assignment](https://arxiv.org/abs/2108.07755). + """ + + def __init__( + self, + num_classes, + max_anchor_matches=10, + alpha=0.5, + beta=6.0, + epsilon=1e-9, + **kwargs, + ): + super().__init__(**kwargs) + self.max_anchor_matches = max_anchor_matches + self.num_classes = num_classes + self.alpha = alpha + self.beta = beta + self.epsilon = epsilon + + def assign( + self, + scores, + decode_bboxes, + anchors, + ground_truth_labels, + ground_truth_bboxes, + ground_truth_mask, + ): + """Assigns ground-truth boxes to anchors. + + Uses the task-aligned assignment strategy for matching ground truth + and anchor boxes based on prediction scores and IoU. + """ + num_anchors = anchors.shape[0] + + # Box scores are the predicted scores for each anchor, ground truth box + # pair. Only the predicted score for the class of the ground_truth box + # is included # Shape: (B, num_ground_truth_boxes, num_anchors) + # (after transpose) + bbox_scores = ops.take_along_axis( + scores, + ops.cast(ops.maximum(ground_truth_labels[:, None, :], 0), "int32"), + axis=-1, + ) + bbox_scores = ops.transpose(bbox_scores, (0, 2, 1)) + + # Overlaps are the IoUs of each predicted box and each ground_truth box. + # Shape: (B, num_ground_truth_boxes, num_anchors) + overlaps = compute_ciou( + ops.expand_dims(ground_truth_bboxes, axis=2), + ops.expand_dims(decode_bboxes, axis=1), + bounding_box_format="xyxy", + ) + + # Alignment metrics are a combination of box scores and overlaps, per + # the task-aligned-assignment formula. + # Metrics are forced to 0 for boxes which have been masked in the + # ground_truth input (e.g. due to padding) + alignment_metrics = ops.power(bbox_scores, self.alpha) * ops.power( + overlaps, self.beta + ) + alignment_metrics = ops.where(ground_truth_mask, alignment_metrics, 0) + + # Only anchors which are inside of relevant ground_truth boxes are + # considered for assignment. + # This is a boolean tensor of shape + # (B, num_ground_truth_boxes, num_anchors) + matching_anchors_in_ground_truth_boxes = is_anchor_center_within_box( + anchors, ground_truth_bboxes + ) + alignment_metrics = ops.where( + matching_anchors_in_ground_truth_boxes, alignment_metrics, 0 + ) + + # The top-k highest alignment metrics are used to select K candidate + # anchors for each ground_truth box. + candidate_metrics, candidate_idxs = ops.top_k( + alignment_metrics, self.max_anchor_matches + ) + candidate_idxs = ops.where(candidate_metrics > 0, candidate_idxs, -1) + + # We now compute a dense grid of anchors and ground_truth boxes. + # This is useful for picking a ground_truth box when an anchor matches + # to 2, as well as returning to a dense format for a mask of which + # anchors have been matched. + anchors_matched_ground_truth_box = ops.zeros_like(overlaps) + for k in range(self.max_anchor_matches): + anchors_matched_ground_truth_box += ops.one_hot( + candidate_idxs[:, :, k], num_anchors + ) + + # We zero-out the overlap for anchor, ground_truth box pairs which + # don't match. 
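+        # `anchors_matched_ground_truth_box` is a 0/1 mask built from the
+        # top-k one-hot selections above.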
+ overlaps *= anchors_matched_ground_truth_box + # In cases where one anchor matches to 2 ground_truth boxes, we pick + # the ground_truth box with the highest overlap as a max. + ground_truth_box_matches_per_anchor = ops.argmax(overlaps, axis=1) + ground_truth_box_matches_per_anchor_mask = ops.max(overlaps, axis=1) > 0 + ground_truth_box_matches_per_anchor = ops.cast( + ground_truth_box_matches_per_anchor, "int32" + ) + + # We select the ground_truth boxes and labels that correspond to + # anchor matches. + bbox_labels = ops.take_along_axis( + ground_truth_bboxes, + ground_truth_box_matches_per_anchor[:, :, None], + axis=1, + ) + bbox_labels = ops.where( + ground_truth_box_matches_per_anchor_mask[:, :, None], + bbox_labels, + -1, + ) + class_labels = ops.take_along_axis( + ground_truth_labels, ground_truth_box_matches_per_anchor, axis=1 + ) + class_labels = ops.where( + ground_truth_box_matches_per_anchor_mask, class_labels, -1 + ) + + class_labels = ops.one_hot( + ops.cast(class_labels, "int32"), self.num_classes + ) + + # Finally, we normalize an anchor's class labels based on the relative + # strenground_truthh of the anchors match with the corresponding + # ground_truth box. + alignment_metrics *= anchors_matched_ground_truth_box + max_alignment_per_ground_truth_box = ops.max( + alignment_metrics, axis=-1, keepdims=True + ) + max_overlap_per_ground_truth_box = ops.max( + overlaps, axis=-1, keepdims=True + ) + + normalized_alignment_metrics = ops.max( + alignment_metrics + * max_overlap_per_ground_truth_box + / (max_alignment_per_ground_truth_box + self.epsilon), + axis=-2, + ) + class_labels *= normalized_alignment_metrics[:, :, None] + + # On TF backend, the final "4" becomes a dynamic shape so we include + # this to force it to a static shape of 4. This does not actually + # reshape the Tensor. + bbox_labels = ops.reshape(bbox_labels, (-1, num_anchors, 4)) + return ( + ops.stop_gradient(bbox_labels), + ops.stop_gradient(class_labels), + ops.stop_gradient( + # ops.cast(ground_truth_box_matches_per_anchor > -1, "float32") + ops.cast(ground_truth_box_matches_per_anchor > 0, "float32") + ), + ) + + def call( + self, + scores, + decode_bboxes, + anchors, + ground_truth_labels, + ground_truth_bboxes, + ground_truth_mask, + ): + """Computes target boxes and classes for anchors. + + Args: + scores: a Float Tensor of shape `(batch_size, num_anchors, + num_classes)` representing predicted class scores for each + anchor. + decode_bboxes: a Float Tensor of shape `(batch_size, num_anchors, + 4)` representing predicted boxes for each anchor. + anchors: a Float Tensor of shape `(batch_size, num_anchors, 2)` + representing the xy coordinates of the center of each anchor. + ground_truth_labels: a Float Tensor of shape + `(batch_size, num_ground_truth_boxes)` + representing the classes of ground truth boxes. + ground_truth_bboxes: a Float Tensor of shape + `(batch_size, num_ground_truth_boxes, 4)` + representing the ground truth bounding boxes in xyxy format. + ground_truth_mask: A Boolean Tensor of shape + `(batch_size, num_ground_truth_boxes)` + representing whether a box in `ground_truth_bboxes` is a real + box or a non-box that exists due to padding. + + Returns: + A tuple of the following: + - A Float Tensor of shape `(batch_size, num_anchors, 4)` + representing box targets for the model. + - A Float Tensor of shape `(batch_size, num_anchors, + num_classes)` + representing class targets for the model. 
+ - A Boolean Tensor of shape `(batch_size, num_anchors)` + representing whether each anchor was a match with a ground + truth box. Anchors that didn't match with a ground truth + box should be excluded from both class and box losses. + """ + if is_tensorflow_ragged(ground_truth_bboxes): + dense_bounding_boxes = to_dense( + {"boxes": ground_truth_bboxes, "classes": ground_truth_labels}, + ) + ground_truth_bboxes = dense_bounding_boxes["boxes"] + ground_truth_labels = dense_bounding_boxes["classes"] + + if is_tensorflow_ragged(ground_truth_mask): + ground_truth_mask = ground_truth_mask.to_tensor() + + max_num_boxes = ops.shape(ground_truth_bboxes)[1] + + # If there are no ground_truth boxes in the batch, we short-circuit and + # return empty targets to avoid NaNs. + return ops.cond( + ops.array(max_num_boxes > 0), + lambda: self.assign( + scores, + decode_bboxes, + anchors, + ground_truth_labels, + ground_truth_bboxes, + ground_truth_mask, + ), + lambda: ( + ops.zeros_like(decode_bboxes), + ops.zeros_like(scores), + ops.zeros_like(scores[..., 0]), + ), + ) + + def count_params(self): + # The label encoder has no weights, so we short-circuit the weight + # counting to avoid having to `build` this layer unnecessarily. + return 0 + + def get_config(self): + config = super().get_config() + config.update( + { + "max_anchor_matches": self.max_anchor_matches, + "num_classes": self.num_classes, + "alpha": self.alpha, + "beta": self.beta, + "epsilon": self.epsilon, + } + ) + return config diff --git a/keras_hub/src/models/yolo_v8/yolo_v8_layers.py b/keras_hub/src/models/yolo_v8/yolo_v8_layers.py new file mode 100644 index 0000000000..0b32f01e71 --- /dev/null +++ b/keras_hub/src/models/yolo_v8/yolo_v8_layers.py @@ -0,0 +1,65 @@ +from keras import ops +from keras.layers import Activation +from keras.layers import BatchNormalization +from keras.layers import Conv2D +from keras.layers import ZeroPadding2D + + +def apply_conv_bn( + x, + num_channels, + kernel_size=1, + strides=1, + activation="swish", + name="conv_bn", +): + if kernel_size > 1: + x = ZeroPadding2D(kernel_size // 2, name=f"{name}_pad")(x) + conv_kwargs = {"use_bias": False, "name": f"{name}_conv"} + x = Conv2D(num_channels, kernel_size, strides, "valid", **conv_kwargs)(x) + x = BatchNormalization(momentum=0.97, epsilon=1e-3, name=f"{name}_bn")(x) + x = Activation(activation, name=name)(x) + return x + + +def compute_hidden_channels(channels, expansion): + return int(channels * expansion) + + +def get_default_channels(channels, x): + return channels if channels > 0 else x.shape[-1] + + +def compute_short_and_deep(x, hidden_channels, activation, name): + x = apply_conv_bn(x, 2 * hidden_channels, 1, 1, activation, f"{name}_pre") + short, deep = ops.split(x, 2, axis=-1) + return short, deep + + +def apply_conv_block(y, channels, activation, shortcut, name): + x = apply_conv_bn(y, channels, 3, 1, activation, f"{name}_1") + x = apply_conv_bn(x, channels, 3, 1, activation, f"{name}_2") + x = (y + x) if shortcut else x + return x + + +def apply_CSP( + x, + channels=-1, + depth=2, + shortcut=True, + expansion=0.5, + activation="swish", + name="csp_block", +): + channels = get_default_channels(channels, x) + hidden_channels = compute_hidden_channels(channels, expansion) + short, deep = compute_short_and_deep(x, hidden_channels, activation, name) + out = [short, deep] + conv_args = (hidden_channels, activation, shortcut) + for depth_arg in range(depth): + deep = apply_conv_block(deep, *conv_args, f"{name}_pre_{depth_arg}") + out.append(deep) + out = 
ops.concatenate(out, axis=-1) + out = apply_conv_bn(out, channels, 1, 1, activation, f"{name}_output") + return out diff --git a/keras_hub/src/models/yolo_v8/yolo_v8_object_detector_preprocessor.py b/keras_hub/src/models/yolo_v8/yolo_v8_object_detector_preprocessor.py new file mode 100644 index 0000000000..236aa6b009 --- /dev/null +++ b/keras_hub/src/models/yolo_v8/yolo_v8_object_detector_preprocessor.py @@ -0,0 +1,14 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.object_detector_preprocessor import ( + ObjectDetectorPreprocessor as ImageObjectDetectorPreprocessor, +) +from keras_hub.src.models.yolo_v8.yolo_v8_backbone import YOLOV8Backbone +from keras_hub.src.models.yolo_v8.yolo_v8_image_converter import ( + YOLOV8ImageConverter, +) + + +@keras_hub_export("keras_hub.models.YOLOV8ImageObjectDetectorPreprocessor") +class YOLOV8ImageObjectDetectorPreprocessor(ImageObjectDetectorPreprocessor): + backbone_cls = YOLOV8Backbone + image_converter_cls = YOLOV8ImageConverter diff --git a/keras_hub/src/models/yolo_v8/yolo_v8_presets.py b/keras_hub/src/models/yolo_v8/yolo_v8_presets.py new file mode 100644 index 0000000000..a9c0ba23ae --- /dev/null +++ b/keras_hub/src/models/yolo_v8/yolo_v8_presets.py @@ -0,0 +1,56 @@ +backbone_presets = { + "yolo_v8_xs_backbone_coco": { + "metadata": { + "description": ( + "An extra small YOLOV8 backbone pretrained on COCO" + ), + "params": 1277680, + }, + "kaggle_handle": "kaggle://kerashub/yolov8/keras/yolo_v8_xs_backbone_coco/5", + }, + "yolo_v8_s_backbone_coco": { + "metadata": { + "description": ("A small YOLOV8 backbone pretrained on COCO"), + "params": 5089760, + }, + "kaggle_handle": "kaggle://kerashub/yolov8/keras/yolo_v8_s_backbone_coco/5", + }, + "yolo_v8_m_backbone_coco": { + "metadata": { + "description": ("A medium YOLOV8 backbone pretrained on COCO"), + "params": 11872464, + }, + "kaggle_handle": "kaggle://kerashub/yolov8/keras/yolo_v8_m_backbone_coco/5", + }, + "yolo_v8_l_backbone_coco": { + "metadata": { + "description": ("A large YOLOV8 backbone pretrained on COCO"), + "params": 19831744, + }, + "kaggle_handle": "kaggle://kerashub/yolov8/keras/yolo_v8_l_backbone_coco/5", + }, + "yolo_v8_xl_backbone_coco": { + "metadata": { + "description": ( + "An extra large YOLOV8 backbone pretrained on COCO" + ), + "params": 30972080, + }, + "kaggle_handle": "kaggle://kerashub/yolov8/keras/yolo_v8_xl_backbone_coco/5", + }, +} + + +detector_presets = { + "yolo_v8_m_pascalvoc": { + "metadata": { + "description": ( + "YOLOV8-M pretrained on PascalVOC 2012 object detection task, " + "which consists of 20 classes. This model achieves a final MaP " + "of 0.45 on the evaluation set." 
+ ), + "params": 25901004, + }, + "kaggle_handle": "kaggle://kerashub/yolov8/keras/yolo_v8_m_pascalvoc/5", + }, +} diff --git a/tools/checkpoint_conversion/convert_yolov8_checkpoints.py b/tools/checkpoint_conversion/convert_yolov8_checkpoints.py new file mode 100644 index 0000000000..b4f67212fc --- /dev/null +++ b/tools/checkpoint_conversion/convert_yolov8_checkpoints.py @@ -0,0 +1,149 @@ +# Install keras_cv: ``pip install keras-cv`` +# Call: ``python3 checkpoint_conversion.py`` +from pathlib import Path + +import numpy as np +from keras import ops +from keras_cv.models import YOLOV8Backbone as KerasCVYOLOV8Backbone +from keras_cv.models import YOLOV8Detector + +from keras_hub.layers import NonMaxSuppression +from keras_hub.layers import YOLOV8ImageConverter +from keras_hub.models import YOLOV8Backbone +from keras_hub.models import YOLOV8ImageObjectDetector +from keras_hub.models import YOLOV8ImageObjectDetectorPreprocessor +from keras_hub.src.models.yolo_v8.yolo_v8_label_encoder import ( + YOLOV8LabelEncoder, +) + + +def get_max_abs_error(output_A, output_B): + return ops.max(ops.abs(output_A - output_B)) + + +def validate_numerics(rng, preset_name, model_A, model_B, preprocessor_B): + random_data = rng.random((2, 224, 224, 3)).astype("float32") + output_A = model_A(random_data) + if preprocessor_B is not None: + output_B = model_B.predict(preprocessor_B(random_data)) + else: + output_B = model_B.predict(random_data) + output_A = ops.convert_to_numpy(output_A) + output_B = ops.convert_to_numpy(output_B) + is_valid = np.allclose(output_A, output_B, atol=1e-5) + print(f"Port {preset_name} with valid numerics: {is_valid}") + print("Max abs error", get_max_abs_error(output_A, output_B)) + assert is_valid + + +def validate_detector_numerics(rng, preset_name, model_A, model_B): + random_data = rng.random((2, 224, 224, 3)).astype("float32") + output_A = model_A.predict(random_data) + output_B = model_B.predict(random_data) + A_to_B_keys = { + "boxes": "boxes", + "confidence": "confidence", + "classes": "labels", + "num_detections": "num_detections", + } + for key_A, key_B in A_to_B_keys.items(): + x_A = output_A[key_A] + x_B = output_B[key_B] + x_A = ops.convert_to_numpy(x_A) + x_B = ops.convert_to_numpy(x_B) + is_valid = np.allclose(x_A, x_B, atol=1e-5) + print(f"Port '{preset_name}' '{key_A}' with valid numerics: {is_valid}") + print("Max abs error", get_max_abs_error(x_A, x_B)) + assert is_valid + + +def make_directory(root, preset): + preset_path_name = f"{root}/{preset}" + Path(preset_path_name).mkdir(parents=True, exist_ok=True) + return preset_path_name + + +def pass_weights_A_to_B(model_A, model_B, root_path): + model_B.set_weights(model_A.get_weights()) + + +def convert_backbone(ModelA, ModelB, weights_path, preset_name): + preset_path = make_directory(weights_path, preset_name) + model_A = ModelA.from_preset(preset_name) + config = model_A.get_config() + config.pop("include_rescaling") + config["image_shape"] = config.pop("input_shape") + model_B = ModelB(**config) + pass_weights_A_to_B(model_A, model_B, preset_path) + model_B.save_to_preset(preset_path) + return model_A, model_B + + +def build_detector_parts(config): + backbone_config = config["backbone"]["config"] + backbone_config.pop("include_rescaling") + backbone_config["image_shape"] = backbone_config.pop("input_shape") + backbone = YOLOV8Backbone(**backbone_config) + config["backbone"] = backbone + label_encoder = YOLOV8LabelEncoder(**config["label_encoder"]["config"]) + config["label_encoder"] = label_encoder + 
prediction_decoder = NonMaxSuppression( + **config["prediction_decoder"]["config"] + ) + config["prediction_decoder"] = prediction_decoder + return config + + +def build_preprocessor(): + image_converter = YOLOV8ImageConverter(scale=1.0 / 255) + preprocessor = YOLOV8ImageObjectDetectorPreprocessor( + image_converter=image_converter + ) + return preprocessor + + +def convert_detector(ModelA, ModelB, weights_path, preset_name): + model_A = ModelA.from_preset(preset_name) + config = model_A.get_config() + config = build_detector_parts(config) + config["preprocessor"] = build_preprocessor() + print(config) + model_B = ModelB(**config) + preset_path = make_directory(weights_path, preset_name) + pass_weights_A_to_B(model_A, model_B, preset_path) + model_B.save_to_preset(preset_path) + model_B = ModelB.from_preset(preset_path) + return model_A, model_B + + +if __name__ == "__main__": + import argparse + from functools import partial + + description = "Convert YOLOV8 keras-cv to keras-hub weights." + parser = argparse.ArgumentParser(description=description) + parser.add_argument("--seed", type=int, default=777) + parser.add_argument("--weights_path", type=str, default="YOLOV8") + args = parser.parse_args() + rng = np.random.default_rng(args.seed) + backbone_presets = [ + "yolo_v8_xs_backbone", + "yolo_v8_s_backbone", + "yolo_v8_m_backbone", + "yolo_v8_l_backbone", + "yolo_v8_xl_backbone", + "yolo_v8_xs_backbone_coco", + "yolo_v8_s_backbone_coco", + "yolo_v8_m_backbone_coco", + "yolo_v8_l_backbone_coco", + "yolo_v8_xl_backbone_coco", + ] + convert = partial(convert_backbone, KerasCVYOLOV8Backbone, YOLOV8Backbone) + for preset in backbone_presets: + model_A, model_B = convert(args.weights_path, preset) + validate_numerics(rng, preset, model_A, model_B, lambda x: x / 255.0) + preset = "yolo_v8_m_pascalvoc" + model_A, model_B = convert_detector( + YOLOV8Detector, YOLOV8ImageObjectDetector, args.weights_path, preset + ) + validate_detector_numerics(rng, preset, model_A, model_B)
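
A note on the conversion tool above: it depends on `keras-cv` (`pip install keras-cv`), and per its argparse block it accepts `--seed` (default 777) and `--weights_path` (default `"YOLOV8"`), writing one converted preset directory per preset name under that path. It can be run directly from the repository root, e.g. `python3 tools/checkpoint_conversion/convert_yolov8_checkpoints.py`; each converted model is then compared against its KerasCV counterpart with `np.allclose(..., atol=1e-5)` by `validate_numerics` / `validate_detector_numerics`.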
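The short-circuit noted in the label encoder earlier in this diff ("If there are no ground_truth boxes in the batch, we short-circuit and return empty targets to avoid NaNs") hinges on `ops.cond` returning zero-filled targets with the same structure as the real assignment, so downstream losses see an all-zero anchor mask instead of dividing by zero. A standalone sketch of that pattern, with illustrative shapes and a stand-in for `self.assign`:

```python
from keras import ops

# Illustrative shapes: (batch, num_anchors, ...). The real layer derives
# these from the backbone's anchor grid.
scores = ops.ones((2, 100, 20))             # per-anchor class scores
decode_bboxes = ops.ones((2, 100, 4))       # per-anchor decoded boxes
ground_truth_bboxes = ops.zeros((2, 0, 4))  # a batch with zero boxes


def run_assignment():
    # Stand-in for the real `self.assign(...)` call.
    return (
        ops.ones_like(decode_bboxes),
        ops.ones_like(scores),
        ops.ones_like(scores[..., 0]),
    )


def empty_targets():
    # Zero targets plus an all-zero anchor mask: no anchor contributes to
    # the box or class loss, so nothing downstream produces NaNs.
    return (
        ops.zeros_like(decode_bboxes),
        ops.zeros_like(scores),
        ops.zeros_like(scores[..., 0]),
    )


max_num_boxes = ops.shape(ground_truth_bboxes)[1]
box_targets, class_targets, anchor_mask = ops.cond(
    ops.array(max_num_boxes > 0), run_assignment, empty_targets
)
```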
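The helpers in `yolo_v8_layers.py` (`apply_conv_bn`, `apply_CSP`) are plain graph-building functions rather than `Layer` subclasses. A minimal sketch of wiring them into a Keras functional model, assuming the module path introduced in this diff; the input shape, strides, and channel counts are illustrative only:

```python
import keras

from keras_hub.src.models.yolo_v8.yolo_v8_layers import apply_conv_bn
from keras_hub.src.models.yolo_v8.yolo_v8_layers import apply_CSP

# Illustrative stem followed by a single CSP stage.
inputs = keras.Input(shape=(64, 64, 3))
x = apply_conv_bn(inputs, 32, kernel_size=3, strides=2, name="stem")
x = apply_CSP(x, channels=64, depth=2, shortcut=True, name="stage1_csp")
model = keras.Model(inputs, x, name="csp_demo")
model.summary()
```

With the default `expansion=0.5`, the CSP block projects to `2 * hidden_channels`, splits the result into a short and a deep path, appends `depth` residual conv blocks to the deep path, and fuses every intermediate output with a final 1x1 conv-BN back to `channels`.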
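`YOLOV8ImageObjectDetectorPreprocessor` only binds the backbone and image-converter classes; the actual rescaling comes from the `YOLOV8ImageConverter` it wraps. A sketch mirroring `build_preprocessor()` from the conversion script (a 1/255 rescale only; any resizing would be an additional, assumed argument on the converter):

```python
from keras_hub.layers import YOLOV8ImageConverter
from keras_hub.models import YOLOV8ImageObjectDetectorPreprocessor

# Rescale pixel values to [0, 1] before they reach the backbone.
image_converter = YOLOV8ImageConverter(scale=1.0 / 255)
preprocessor = YOLOV8ImageObjectDetectorPreprocessor(
    image_converter=image_converter
)
```

Attaching this preprocessor to `YOLOV8ImageObjectDetector` (as `convert_detector` does) lets `predict()` accept raw images in the 0-255 range.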
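Finally, the presets registered in `yolo_v8_presets.py` are exposed through the usual `from_preset` constructors. A sketch of loading a backbone and the PascalVOC detector, assuming the Kaggle handles above are reachable from the environment; the random batch and its 224x224 resolution are illustrative (the same size the conversion script uses for validation), and the output dictionary keys follow the mapping used in `validate_detector_numerics`:

```python
import numpy as np
import keras_hub

backbone = keras_hub.models.YOLOV8Backbone.from_preset(
    "yolo_v8_xs_backbone_coco"
)
detector = keras_hub.models.YOLOV8ImageObjectDetector.from_preset(
    "yolo_v8_m_pascalvoc"
)

images = np.random.uniform(0, 255, size=(1, 224, 224, 3)).astype("float32")
predictions = detector.predict(images)
print(predictions.keys())  # expected: boxes, labels, confidence, num_detections
```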