Skip to content

Commit

Permalink
Update YOLOv8 Loss (#1858)
Browse files Browse the repository at this point in the history
* Update YOLOv8 Loss

* Fix Tests

* Update

* Update

* compute_ciou

* Update

* Docs

* Update Loss

* format

* format

* format

* Fix

* Update

* Format

* Fix

* Format

* Format

* Fix
  • Loading branch information
IMvision12 committed Jun 8, 2023
1 parent fb0a28d commit 078ec30
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 129 deletions.
1 change: 1 addition & 0 deletions keras_cv/bounding_box/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from keras_cv.bounding_box.formats import XYWH
from keras_cv.bounding_box.formats import XYXY
from keras_cv.bounding_box.formats import YXYX
from keras_cv.bounding_box.iou import compute_ciou
from keras_cv.bounding_box.iou import compute_iou
from keras_cv.bounding_box.mask_invalid_detections import (
mask_invalid_detections,
Expand Down
74 changes: 74 additions & 0 deletions keras_cv/bounding_box/iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains functions to compute ious of bounding boxes."""
import math

import tensorflow as tf

from keras_cv import bounding_box
Expand Down Expand Up @@ -172,3 +174,75 @@ def compute_iou(
)
iou_lookup_table = tf.where(background_mask, mask_val_t, res)
return iou_lookup_table


def compute_ciou(box1, box2, bounding_box_format, eps=1e-7):
    """Computes the Complete IoU (CIoU) between two sets of bounding boxes.

    CIoU extends plain IoU with two penalty terms: the squared distance
    between the box centers, normalized by the squared diagonal of the
    smallest box enclosing both inputs, and an aspect-ratio consistency
    term. The boxes are compared element-wise, so `box1` and `box2` must
    have broadcast-compatible shapes whose last dimension has length 4.

    Args:
        box1: tensor representing the first bounding box(es) with shape
            `(..., 4)`.
        box2: tensor representing the second bounding box(es) with shape
            `(..., 4)`.
        bounding_box_format: a case-insensitive string (for example, "xyxy").
            Each bounding box is defined by these 4 values. For detailed
            information on the supported formats, see the [KerasCV bounding
            box documentation](https://keras.io/api/keras_cv/bounding_box/formats/).
        eps: a small float used to avoid division by zero. Defaults to 1e-7.

    Returns:
        A tensor holding the CIoU score of each pair of boxes. The
        coordinate axis is kept with length 1 (the coordinates are obtained
        with `tf.split`), so callers typically squeeze the last axis. Higher
        is better; a loss is obtained as `1 - ciou`.
    """
    # Work in corner format, keeping relative coordinates relative.
    target_format = "xyxy"
    if bounding_box.is_relative(bounding_box_format):
        target_format = bounding_box.as_relative(target_format)

    box1 = bounding_box.convert_format(
        box1, source=bounding_box_format, target=target_format
    )
    box2 = bounding_box.convert_format(
        box2, source=bounding_box_format, target=target_format
    )

    b1_x1, b1_y1, b1_x2, b1_y2 = tf.split(box1, 4, axis=-1)
    b2_x1, b2_y1, b2_x2, b2_y2 = tf.split(box2, 4, axis=-1)
    # eps is added to the heights only: they are the denominators of the
    # width / height ratios fed to atan below.
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps

    # Intersection area (clamped at zero for non-overlapping boxes).
    inter_w = tf.math.maximum(
        tf.math.minimum(b1_x2, b2_x2) - tf.math.maximum(b1_x1, b2_x1), 0
    )
    inter_h = tf.math.maximum(
        tf.math.minimum(b1_y2, b2_y2) - tf.math.maximum(b1_y1, b2_y1), 0
    )
    inter = inter_w * inter_h

    # Union area.
    union = w1 * h1 + w2 * h2 - inter + eps

    iou = inter / union

    # Width and height of the smallest box enclosing both inputs.
    cw = tf.math.maximum(b1_x2, b2_x2) - tf.math.minimum(b1_x1, b2_x1)
    ch = tf.math.maximum(b1_y2, b2_y2) - tf.math.minimum(b1_y1, b2_y1)
    c2 = cw**2 + ch**2 + eps  # enclosing-box diagonal, squared

    # Squared distance between the two box centers.
    rho2 = (
        (b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2
        + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2
    ) / 4

    # Aspect-ratio consistency term: v = (4 / pi^2) * (delta atan)^2.
    # Only the arctangent difference is squared; the 4 / pi^2 factor must
    # stay outside the power.
    v = (4 / math.pi**2) * tf.pow(tf.atan(w2 / h2) - tf.atan(w1 / h1), 2)
    # Trade-off weight for the aspect-ratio term.
    alpha = v / (v - iou + (1 + eps))

    return iou - (rho2 / c2 + v * alpha)
57 changes: 4 additions & 53 deletions keras_cv/losses/ciou_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,10 @@
# limitations under the License.


import math

import tensorflow as tf
from tensorflow import keras

from keras_cv import bounding_box
from keras_cv.bounding_box.iou import compute_ciou


class CIoULoss(keras.losses.Loss):
Expand Down Expand Up @@ -68,55 +66,6 @@ def __init__(self, bounding_box_format, eps=1e-7, **kwargs):
self.eps = eps
self.bounding_box_format = bounding_box_format

def compute_ciou(self, boxes1, boxes2):
    """Computes the Complete IoU (CIoU) between two sets of bounding boxes.

    The boxes are interpreted in `self.bounding_box_format` and compared
    element-wise. `self.eps` guards the divisions against zero.

    Args:
        boxes1: tensor of boxes whose last dimension has length 4.
        boxes2: tensor of boxes whose last dimension has length 4.

    Returns:
        A tensor holding the CIoU score of each pair of boxes. The
        coordinate axis is kept with length 1 (the coordinates are obtained
        with `tf.split`).
    """
    # Work in corner format, keeping relative coordinates relative.
    target_format = "xyxy"
    if bounding_box.is_relative(self.bounding_box_format):
        target_format = bounding_box.as_relative(target_format)

    boxes1 = bounding_box.convert_format(
        boxes1, source=self.bounding_box_format, target=target_format
    )
    boxes2 = bounding_box.convert_format(
        boxes2, source=self.bounding_box_format, target=target_format
    )

    b1_x1, b1_y1, b1_x2, b1_y2 = tf.split(boxes1, 4, axis=-1)
    b2_x1, b2_y1, b2_x2, b2_y2 = tf.split(boxes2, 4, axis=-1)
    # eps is added to the heights only: they are the denominators of the
    # width / height ratios fed to atan below.
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + self.eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + self.eps

    # Intersection area (clamped at zero for non-overlapping boxes).
    inter_w = tf.math.maximum(
        tf.math.minimum(b1_x2, b2_x2) - tf.math.maximum(b1_x1, b2_x1), 0
    )
    inter_h = tf.math.maximum(
        tf.math.minimum(b1_y2, b2_y2) - tf.math.maximum(b1_y1, b2_y1), 0
    )
    inter = inter_w * inter_h

    # Union area.
    union = w1 * h1 + w2 * h2 - inter + self.eps

    iou = inter / union

    # Width and height of the smallest box enclosing both inputs.
    cw = tf.math.maximum(b1_x2, b2_x2) - tf.math.minimum(b1_x1, b2_x1)
    ch = tf.math.maximum(b1_y2, b2_y2) - tf.math.minimum(b1_y1, b2_y1)
    c2 = cw**2 + ch**2 + self.eps  # enclosing-box diagonal, squared

    # Squared distance between the two box centers.
    rho2 = (
        (b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2
        + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2
    ) / 4

    # Aspect-ratio consistency term: v = (4 / pi^2) * (delta atan)^2.
    # Only the arctangent difference is squared; the 4 / pi^2 factor must
    # stay outside the power.
    v = (4 / math.pi**2) * tf.pow(tf.atan(w2 / h2) - tf.atan(w1 / h1), 2)
    # Trade-off weight for the aspect-ratio term.
    alpha = v / (v - iou + (1 + self.eps))
    ciou = iou - (rho2 / c2 + v * alpha)
    return ciou

def call(self, y_true, y_pred):
y_pred = tf.convert_to_tensor(y_pred)
y_true = tf.cast(y_true, y_pred.dtype)
Expand All @@ -141,7 +90,9 @@ def call(self, y_true, y_pred):
f"y_pred={y_pred.shape[-2]}."
)

ciou = tf.squeeze(self.compute_ciou(y_true, y_pred), axis=-1)
ciou = tf.squeeze(
compute_ciou(y_true, y_pred, self.bounding_box_format), axis=-1
)
return 1 - ciou

def get_config(self):
Expand Down
21 changes: 13 additions & 8 deletions keras_cv/models/object_detection/yolo_v8/yolo_v8_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import warnings

import tensorflow as tf
from keras import layers
from tensorflow import keras

import keras_cv
from keras_cv import bounding_box
from keras_cv.losses.ciou_loss import CIoULoss
from keras_cv.models.backbones.backbone_presets import backbone_presets
from keras_cv.models.backbones.backbone_presets import (
backbone_presets_with_weights,
Expand All @@ -28,9 +30,6 @@
from keras_cv.models.object_detection.yolo_v8.yolo_v8_detector_presets import (
yolo_v8_detector_presets,
)
from keras_cv.models.object_detection.yolo_v8.yolo_v8_iou_loss import (
YOLOV8IoULoss,
)
from keras_cv.models.object_detection.yolo_v8.yolo_v8_label_encoder import (
YOLOV8LabelEncoder,
)
Expand Down Expand Up @@ -378,7 +377,7 @@ class YOLOV8Detector(Task):
# Train model
model.compile(
classification_loss='binary_crossentropy',
box_loss='iou',
box_loss='ciou',
optimizer=tf.optimizers.SGD(global_clipnorm=10.0),
jit_compile=False,
)
Expand Down Expand Up @@ -459,7 +458,7 @@ def compile(
Args:
box_loss: a Keras loss to use for box offset regression. A
preconfigured loss is provided when the string "iou" is passed.
preconfigured loss is provided when the string "ciou" is passed.
classification_loss: a Keras loss to use for box classification. A
preconfigured loss is provided when the string
"binary_crossentropy" is passed.
Expand All @@ -474,12 +473,18 @@ def compile(
raise ValueError("User metrics not yet supported for YOLOV8")

if isinstance(box_loss, str):
if box_loss == "iou":
box_loss = YOLOV8IoULoss(reduction="sum")
if box_loss == "ciou":
box_loss = CIoULoss(bounding_box_format="xyxy", reduction="sum")
elif box_loss == "iou":
warnings.warn(
"YOLOV8 recommends using CIoU loss, but was configured to "
"use standard IoU. Consider using `box_loss='ciou'` "
"instead."
)
else:
raise ValueError(
f"Invalid box loss for YOLOV8Detector: {box_loss}. Box "
"loss should be a keras.Loss or the string 'iou'."
"loss should be a keras.Loss or the string 'ciou'."
)
if isinstance(classification_loss, str):
if classification_loss == "binary_crossentropy":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_fit(self):
yolo.compile(
optimizer="adam",
classification_loss="binary_crossentropy",
box_loss="iou",
box_loss="ciou",
)
xs, ys = _create_bounding_box_dataset(bounding_box_format)

Expand All @@ -66,7 +66,7 @@ def test_throws_with_ragged_tensors(self):
yolo.compile(
optimizer="adam",
classification_loss="binary_crossentropy",
box_loss="iou",
box_loss="ciou",
)
xs, ys = _create_bounding_box_dataset(bounding_box_format)
ys = bounding_box.to_ragged(ys)
Expand Down Expand Up @@ -108,7 +108,7 @@ def test_bad_loss(self):
ValueError,
"Invalid classification loss",
):
yolo.compile(box_loss="iou", classification_loss="bad_loss")
yolo.compile(box_loss="ciou", classification_loss="bad_loss")

@parameterized.named_parameters(
("tf_format", "tf", "model"),
Expand Down
63 changes: 0 additions & 63 deletions keras_cv/models/object_detection/yolo_v8/yolo_v8_iou_loss.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from tensorflow import keras
from tensorflow.keras import layers

from keras_cv.models.object_detection.yolo_v8.yolo_v8_iou_loss import bbox_iou
from keras_cv.bounding_box.iou import compute_ciou


def select_highest_overlaps(mask_pos, overlaps, max_num_boxes):
Expand Down Expand Up @@ -297,7 +297,10 @@ def get_box_metrics(

gt_boxes = tf.repeat(tf.expand_dims(gt_bboxes, axis=2), na, axis=2)

iou = tf.squeeze(bbox_iou(gt_boxes, pd_boxes), axis=-1)
iou = tf.squeeze(
compute_ciou(gt_boxes, pd_boxes, bounding_box_format="xyxy"),
axis=-1,
)
iou = tf.where(iou > 0, iou, 0)

iou = tf.reshape(iou, (-1, max_num_boxes, na))
Expand Down

0 comments on commit 078ec30

Please sign in to comment.