diff --git a/keras_cv/bounding_box/__init__.py b/keras_cv/bounding_box/__init__.py
index 3cea65744e..65fb23e7fe 100644
--- a/keras_cv/bounding_box/__init__.py
+++ b/keras_cv/bounding_box/__init__.py
@@ -22,6 +22,7 @@
 from keras_cv.bounding_box.formats import XYWH
 from keras_cv.bounding_box.formats import XYXY
 from keras_cv.bounding_box.formats import YXYX
+from keras_cv.bounding_box.iou import compute_ciou
 from keras_cv.bounding_box.iou import compute_iou
 from keras_cv.bounding_box.mask_invalid_detections import (
     mask_invalid_detections,
diff --git a/keras_cv/bounding_box/iou.py b/keras_cv/bounding_box/iou.py
index c72810a029..b229db4fdc 100644
--- a/keras_cv/bounding_box/iou.py
+++ b/keras_cv/bounding_box/iou.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Contains functions to compute ious of bounding boxes."""
+import math
+
 import tensorflow as tf
 
 from keras_cv import bounding_box
@@ -172,3 +174,75 @@ def compute_iou(
     )
     iou_lookup_table = tf.where(background_mask, mask_val_t, res)
     return iou_lookup_table
+
+
+def compute_ciou(box1, box2, bounding_box_format, eps=1e-7):
+    """
+    Computes the Complete IoU (CIoU) between two bounding boxes or between
+    two batches of bounding boxes.
+
+    CIoU extends IoU with two additional penalty terms: the normalized
+    distance between the box centers and the consistency of the boxes'
+    aspect ratios. This makes CIoU loss a stronger regression target for
+    object detection than plain IoU or GIoU loss. The length of the last
+    dimension should be 4 to represent the bounding boxes.
+
+    Args:
+        box1 (tf.Tensor): Tensor representing the first bounding box with
+            shape (..., 4).
+        box2 (tf.Tensor): Tensor representing the second bounding box with
+            shape (..., 4).
+        bounding_box_format: a case-insensitive string (for example, "xyxy")
+            describing the format of the 4 box values. For detailed
+            information on the supported formats, see the [KerasCV bounding
+            box documentation](https://keras.io/api/keras_cv/bounding_box/formats/).
+        eps (float, optional): A small value to avoid division by zero.
+            Defaults to 1e-7.
+
+    Returns:
+        tf.Tensor: The CIoU between the two bounding boxes.
+    """
+    target_format = "xyxy"
+    if bounding_box.is_relative(bounding_box_format):
+        target_format = bounding_box.as_relative(target_format)
+
+    box1 = bounding_box.convert_format(
+        box1, source=bounding_box_format, target=target_format
+    )
+
+    box2 = bounding_box.convert_format(
+        box2, source=bounding_box_format, target=target_format
+    )
+
+    b1_x1, b1_y1, b1_x2, b1_y2 = tf.split(box1, 4, axis=-1)
+    b2_x1, b2_y1, b2_x2, b2_y2 = tf.split(box2, 4, axis=-1)
+    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
+    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
+
+    # Intersection area
+    inter = tf.math.maximum(
+        tf.math.minimum(b1_x2, b2_x2) - tf.math.maximum(b1_x1, b2_x1), 0
+    ) * tf.math.maximum(
+        tf.math.minimum(b1_y2, b2_y2) - tf.math.maximum(b1_y1, b2_y1), 0
+    )
+
+    # Union area
+    union = w1 * h1 + w2 * h2 - inter + eps
+
+    # IoU
+    iou = inter / union
+
+    cw = tf.math.maximum(b1_x2, b2_x2) - tf.math.minimum(
+        b1_x1, b2_x1
+    )  # convex (smallest enclosing box) width
+    ch = tf.math.maximum(b1_y2, b2_y2) - tf.math.minimum(
+        b1_y1, b2_y1
+    )  # convex height
+    c2 = cw**2 + ch**2 + eps  # convex diagonal squared
+    rho2 = (
+        (b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2
+        + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2
+    ) / 4  # squared distance between box centers
+    v = tf.pow((4 / math.pi**2) * (tf.atan(w2 / h2) - tf.atan(w1 / h1)), 2)
+    alpha = v / (v - iou + (1 + eps))
+
+    return iou - (rho2 / c2 + v * alpha)
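For reference, a minimal sketch of calling the new `compute_ciou` helper directly; the box values below are made up for illustration, and the printed values are approximate because of the `eps` stabilizer:

```python
import tensorflow as tf

from keras_cv.bounding_box.iou import compute_ciou

# Aligned pairs of boxes in "xyxy" format; the last dimension must be 4.
# Unlike compute_iou, which builds an N x M lookup table, compute_ciou is
# applied element-wise to matched box pairs (with broadcasting).
boxes1 = tf.constant([[0.0, 0.0, 10.0, 10.0], [0.0, 0.0, 10.0, 10.0]])
boxes2 = tf.constant([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 15.0, 15.0]])

ciou = compute_ciou(boxes1, boxes2, bounding_box_format="xyxy")
print(ciou.shape)  # (2, 1); callers typically squeeze the trailing axis
# First pair: close to 1.0 (identical boxes). Second pair: much lower,
# since the overlap is partial and the centers are offset.
```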
+ """ + target_format = "xyxy" + if bounding_box.is_relative(bounding_box_format): + target_format = bounding_box.as_relative(target_format) + + box1 = bounding_box.convert_format( + box1, source=bounding_box_format, target=target_format + ) + + box2 = bounding_box.convert_format( + box2, source=bounding_box_format, target=target_format + ) + b1_x1, b1_y1, b1_x2, b1_y2 = tf.split(box1, 4, axis=-1) + b2_x1, b2_y1, b2_x2, b2_y2 = tf.split(box2, 4, axis=-1) + w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps + w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps + + # Intersection area + inter = tf.math.maximum( + tf.math.minimum(b1_x2, b2_x2) - tf.math.maximum(b1_x1, b2_x1), 0 + ) * tf.math.maximum( + tf.math.minimum(b1_y2, b2_y2) - tf.math.maximum(b1_y1, b2_y1), 0 + ) + + # Union Area + union = w1 * h1 + w2 * h2 - inter + eps + + # IoU + iou = inter / union + + cw = tf.math.maximum(b1_x2, b2_x2) - tf.math.minimum( + b1_x1, b2_x1 + ) # convex (smallest enclosing box) width + ch = tf.math.maximum(b1_y2, b2_y2) - tf.math.minimum( + b1_y1, b2_y1 + ) # convex height + c2 = cw**2 + ch**2 + eps # convex diagonal squared + rho2 = ( + (b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2 + ) / 4 # center dist ** 2 + v = tf.pow((4 / math.pi**2) * (tf.atan(w2 / h2) - tf.atan(w1 / h1)), 2) + alpha = v / (v - iou + (1 + eps)) + + return iou - (rho2 / c2 + v * alpha) diff --git a/keras_cv/losses/ciou_loss.py b/keras_cv/losses/ciou_loss.py index 96b1d454cc..5302cfb394 100644 --- a/keras_cv/losses/ciou_loss.py +++ b/keras_cv/losses/ciou_loss.py @@ -13,12 +13,10 @@ # limitations under the License. -import math - import tensorflow as tf from tensorflow import keras -from keras_cv import bounding_box +from keras_cv.bounding_box.iou import compute_ciou class CIoULoss(keras.losses.Loss): @@ -68,55 +66,6 @@ def __init__(self, bounding_box_format, eps=1e-7, **kwargs): self.eps = eps self.bounding_box_format = bounding_box_format - def compute_ciou(self, boxes1, boxes2): - target_format = "xyxy" - if bounding_box.is_relative(self.bounding_box_format): - target_format = bounding_box.as_relative(target_format) - - boxes1 = bounding_box.convert_format( - boxes1, source=self.bounding_box_format, target=target_format - ) - - boxes2 = bounding_box.convert_format( - boxes2, source=self.bounding_box_format, target=target_format - ) - - b1_x1, b1_y1, b1_x2, b1_y2 = tf.split(boxes1, 4, axis=-1) - b2_x1, b2_y1, b2_x2, b2_y2 = tf.split(boxes2, 4, axis=-1) - w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + self.eps - w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + self.eps - - # Intersection area - inter = tf.math.maximum( - tf.math.minimum(b1_x2, b2_x2) - tf.math.maximum(b1_x1, b2_x1), 0 - ) * tf.math.maximum( - tf.math.minimum(b1_y2, b2_y2) - tf.math.maximum(b1_y1, b2_y1), 0 - ) - - # Union Area - union = w1 * h1 + w2 * h2 - inter + self.eps - - # IoU - iou = inter / union - - cw = tf.math.maximum(b1_x2, b2_x2) - tf.math.minimum( - b1_x1, b2_x1 - ) # convex (smallest enclosing box) width - ch = tf.math.maximum(b1_y2, b2_y2) - tf.math.minimum( - b1_y1, b2_y1 - ) # convex height - c2 = cw**2 + ch**2 + self.eps # convex diagonal squared - rho2 = ( - (b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 - + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2 - ) / 4 # center dist ** 2 - v = tf.pow( - (4 / math.pi**2) * (tf.atan(w2 / h2) - tf.atan(w1 / h1)), 2 - ) - alpha = v / (v - iou + (1 + self.eps)) - ciou = iou - (rho2 / c2 + v * alpha) - return ciou - def call(self, y_true, y_pred): y_pred = tf.convert_to_tensor(y_pred) y_true = tf.cast(y_true, y_pred.dtype) @@ 
diff --git a/keras_cv/models/object_detection/yolo_v8/yolo_v8_detector.py b/keras_cv/models/object_detection/yolo_v8/yolo_v8_detector.py
index 4f544d778b..360d01c145 100644
--- a/keras_cv/models/object_detection/yolo_v8/yolo_v8_detector.py
+++ b/keras_cv/models/object_detection/yolo_v8/yolo_v8_detector.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import copy
+import warnings
 
 import tensorflow as tf
 from keras import layers
@@ -19,6 +20,7 @@
 import keras_cv
 from keras_cv import bounding_box
+from keras_cv.losses.ciou_loss import CIoULoss
 from keras_cv.models.backbones.backbone_presets import backbone_presets
 from keras_cv.models.backbones.backbone_presets import (
     backbone_presets_with_weights,
 )
@@ -28,9 +30,6 @@
 from keras_cv.models.object_detection.yolo_v8.yolo_v8_detector_presets import (
     yolo_v8_detector_presets,
 )
-from keras_cv.models.object_detection.yolo_v8.yolo_v8_iou_loss import (
-    YOLOV8IoULoss,
-)
 from keras_cv.models.object_detection.yolo_v8.yolo_v8_label_encoder import (
     YOLOV8LabelEncoder,
 )
@@ -378,7 +377,7 @@ class YOLOV8Detector(Task):
     # Train model
     model.compile(
         classification_loss='binary_crossentropy',
-        box_loss='iou',
+        box_loss='ciou',
         optimizer=tf.optimizers.SGD(global_clipnorm=10.0),
         jit_compile=False,
     )
@@ -459,7 +458,7 @@ def compile(
 
         Args:
            box_loss: a Keras loss to use for box offset regression. A
-                preconfigured loss is provided when the string "iou" is passed.
+                preconfigured loss is provided when the string "ciou" is passed.
            classification_loss: a Keras loss to use for box classification. A
                preconfigured loss is provided when the string
                "binary_crossentropy" is passed.
@@ -474,12 +473,18 @@ def compile(
             raise ValueError("User metrics not yet supported for YOLOV8")
 
         if isinstance(box_loss, str):
-            if box_loss == "iou":
-                box_loss = YOLOV8IoULoss(reduction="sum")
+            if box_loss == "ciou":
+                box_loss = CIoULoss(bounding_box_format="xyxy", reduction="sum")
+            elif box_loss == "iou":
+                warnings.warn(
+                    "YOLOV8 recommends using CIoU loss, but was configured to "
+                    "use standard IoU. Consider using `box_loss='ciou'` "
+                    "instead."
+                )
             else:
                 raise ValueError(
                     f"Invalid box loss for YOLOV8Detector: {box_loss}. Box "
-                    "loss should be a keras.Loss or the string 'iou'."
+                    "loss should be a keras.Loss or the string 'ciou'."
                 )
         if isinstance(classification_loss, str):
             if classification_loss == "binary_crossentropy":
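A minimal sketch of compiling the detector with the new `box_loss="ciou"` string; the backbone preset name and hyperparameters here are only assumed examples, so substitute whichever preset and settings you actually use:

```python
import tensorflow as tf

import keras_cv

# Assumed example setup; any YOLOV8Backbone preset should work here.
model = keras_cv.models.YOLOV8Detector(
    num_classes=20,
    bounding_box_format="xywh",
    backbone=keras_cv.models.YOLOV8Backbone.from_preset("yolo_v8_xs_backbone"),
    fpn_depth=2,
)

# "ciou" resolves to CIoULoss(bounding_box_format="xyxy", reduction="sum").
# Passing the old string "iou" now triggers the warning added above instead
# of constructing the removed YOLOV8IoULoss.
model.compile(
    classification_loss="binary_crossentropy",
    box_loss="ciou",
    optimizer=tf.optimizers.SGD(global_clipnorm=10.0),
    jit_compile=False,
)
```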
diff --git a/keras_cv/models/object_detection/yolo_v8/yolo_v8_detector_test.py b/keras_cv/models/object_detection/yolo_v8/yolo_v8_detector_test.py
index e6360e8670..f6992e57c6 100644
--- a/keras_cv/models/object_detection/yolo_v8/yolo_v8_detector_test.py
+++ b/keras_cv/models/object_detection/yolo_v8/yolo_v8_detector_test.py
@@ -45,7 +45,7 @@ def test_fit(self):
         yolo.compile(
             optimizer="adam",
             classification_loss="binary_crossentropy",
-            box_loss="iou",
+            box_loss="ciou",
         )
         xs, ys = _create_bounding_box_dataset(bounding_box_format)
 
@@ -66,7 +66,7 @@ def test_throws_with_ragged_tensors(self):
         yolo.compile(
             optimizer="adam",
             classification_loss="binary_crossentropy",
-            box_loss="iou",
+            box_loss="ciou",
         )
         xs, ys = _create_bounding_box_dataset(bounding_box_format)
         ys = bounding_box.to_ragged(ys)
@@ -108,7 +108,7 @@ def test_bad_loss(self):
             ValueError,
             "Invalid classification loss",
         ):
-            yolo.compile(box_loss="iou", classification_loss="bad_loss")
+            yolo.compile(box_loss="ciou", classification_loss="bad_loss")
 
     @parameterized.named_parameters(
         ("tf_format", "tf", "model"),
diff --git a/keras_cv/models/object_detection/yolo_v8/yolo_v8_iou_loss.py b/keras_cv/models/object_detection/yolo_v8/yolo_v8_iou_loss.py
deleted file mode 100644
index c953f55f89..0000000000
--- a/keras_cv/models/object_detection/yolo_v8/yolo_v8_iou_loss.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# Copyright 2023 The KerasCV Authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import math
-
-import tensorflow as tf
-from tensorflow import keras
-
-
-def bbox_iou(box1, box2, eps=1e-7):
-    b1_x1, b1_y1, b1_x2, b1_y2 = tf.split(box1, 4, axis=-1)
-    b2_x1, b2_y1, b2_x2, b2_y2 = tf.split(box2, 4, axis=-1)
-    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
-    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
-
-    # Intersection area
-    inter = tf.math.maximum(
-        tf.math.minimum(b1_x2, b2_x2) - tf.math.maximum(b1_x1, b2_x1), 0
-    ) * tf.math.maximum(
-        tf.math.minimum(b1_y2, b2_y2) - tf.math.maximum(b1_y1, b2_y1), 0
-    )
-
-    # Union Area
-    union = w1 * h1 + w2 * h2 - inter + eps
-
-    # IoU
-    iou = inter / union
-
-    cw = tf.math.maximum(b1_x2, b2_x2) - tf.math.minimum(
-        b1_x1, b2_x1
-    )  # convex (smallest enclosing box) width
-    ch = tf.math.maximum(b1_y2, b2_y2) - tf.math.minimum(
-        b1_y1, b2_y1
-    )  # convex height
-    c2 = cw**2 + ch**2 + eps  # convex diagonal squared
-    rho2 = (
-        (b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2
-        + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2
-    ) / 4  # center dist ** 2
-    v = tf.pow((4 / math.pi**2) * (tf.atan(w2 / h2) - tf.atan(w1 / h1)), 2)
-    alpha = v / (v - iou + (1 + eps))
-
-    return iou - (rho2 / c2 + v * alpha)
-
-
-# TODO(ianstenbit): Use keras_cv.losses.IoULoss instead
-# (It needs to support CIoU as well as dynamic number of boxes)
-class YOLOV8IoULoss(keras.losses.Loss):
-    def call(self, y_true, y_pred):
-        # IoU loss
-        iou = bbox_iou(y_pred, y_true)
-
-        return 1.0 - iou
diff --git a/keras_cv/models/object_detection/yolo_v8/yolo_v8_label_encoder.py b/keras_cv/models/object_detection/yolo_v8/yolo_v8_label_encoder.py
index ec92ce8cff..c3cb6b8309 100644
--- a/keras_cv/models/object_detection/yolo_v8/yolo_v8_label_encoder.py
+++ b/keras_cv/models/object_detection/yolo_v8/yolo_v8_label_encoder.py
@@ -19,7 +19,7 @@
 from tensorflow import keras
 from tensorflow.keras import layers
 
-from keras_cv.models.object_detection.yolo_v8.yolo_v8_iou_loss import bbox_iou
+from keras_cv.bounding_box.iou import compute_ciou
 
 
 def select_highest_overlaps(mask_pos, overlaps, max_num_boxes):
@@ -297,7 +297,10 @@ def get_box_metrics(
 
         gt_boxes = tf.repeat(tf.expand_dims(gt_bboxes, axis=2), na, axis=2)
 
-        iou = tf.squeeze(bbox_iou(gt_boxes, pd_boxes), axis=-1)
+        iou = tf.squeeze(
+            compute_ciou(gt_boxes, pd_boxes, bounding_box_format="xyxy"),
+            axis=-1,
+        )
         iou = tf.where(iou > 0, iou, 0)
 
         iou = tf.reshape(iou, (-1, max_num_boxes, na))
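Finally, a small sketch of the call pattern the label encoder now uses, mainly to show the shapes involved; the random boxes are placeholders, so only the shapes (not the values) are meaningful:

```python
import tensorflow as tf

from keras_cv.bounding_box.iou import compute_ciou

# Shapes mirroring get_box_metrics: (batch, max_num_boxes, num_anchors, 4).
batch, max_num_boxes, num_anchors = 2, 3, 5
gt_boxes = tf.random.uniform((batch, max_num_boxes, num_anchors, 4))
pd_boxes = tf.random.uniform((batch, max_num_boxes, num_anchors, 4))

# Same call as the encoder: element-wise CIoU, squeezed on the last axis
# before the tf.where / tf.reshape post-processing.
iou = tf.squeeze(
    compute_ciou(gt_boxes, pd_boxes, bounding_box_format="xyxy"), axis=-1
)
print(iou.shape)  # (2, 3, 5)
```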