-
Notifications
You must be signed in to change notification settings - Fork 301
YOLOV8 port to keras-hub #1899
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
oarriaga
wants to merge
111
commits into
keras-team:master
Choose a base branch
from
oarriaga:master
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
YOLOV8 port to keras-hub #1899
Changes from all commits
Commits
Show all changes
111 commits
Select commit
Hold shift + click to select a range
80e4589
Add regression loss for object detectors
oarriaga d155f28
Add missing mask function for invalid detections
oarriaga 5388655
Add multibackend non maximum supression layer
oarriaga c8ffdd1
Add abstract object detector task class
oarriaga 0c32ef7
Add YOLOV8 backbone and detector with keras-hub only imports
oarriaga f1be3d5
Add previous backbone and detector presets as template
oarriaga 6d2df25
Update API with new functions for YOLOV8
oarriaga 13ca589
Add new API modules for YOLOV8 following previous keras-cv structure
oarriaga 9756681
Move NMS layer to model directory
oarriaga 3bb7ec7
Remove backend gating for non-max supression layer and use keras ops …
oarriaga ac3ab62
Add missing args to docstrings
oarriaga 27b03c5
Remove unnecessary linter exception
oarriaga 524a280
Add better docstrings of internal nms functions arguments
oarriaga 96ea865
Rename single idx variables to more readable name
oarriaga 1f46bbd
Add generic layer test to not trainable NMS layer
oarriaga 0883fcf
Move CIOU loss inside YOLOV8 model directory
oarriaga 1d0ff43
Add changes from automatic refactorer
oarriaga 79e433f
Fix docstring for new loss location
oarriaga d6e264f
Rename YOLOV8 object detector model
oarriaga 3a16688
Change docstring to default argument type naming
oarriaga 36091a4
Add standard keras-hub model build separation
oarriaga aa8db7a
Add convertion script for YOLOV8 backbone models
oarriaga f083c83
Update conversion script to work for all YOLOV8 backbones
oarriaga 00f335a
Merge branch 'keras-team:master' into master
oarriaga 64a58be
Refactor layers and models
oarriaga 05af893
Change docstring to include only keras ops
oarriaga 67097b1
Remove tensorflow import to check for ragged tensors
oarriaga 3c3e50c
Revert back to original masking implementation
oarriaga 26d995d
Revert back to TOOD fix and fix typo
oarriaga 7b53837
Fix optimizer pass to default compile method
oarriaga f920a48
Add non-max suppression layer to API
oarriaga 18c6743
Add automatic linter changes
oarriaga 38fe6ff
Change preset names to point to the KerasHub Kaggle repository
oarriaga 0e6d443
Add non-max suppression layer to API
oarriaga 7be4e85
Start port for pascalvoc preset
oarriaga dc912bd
Merge branch 'master' into master
oarriaga 9729a98
Update public API with YOLOV8 preprocessors
oarriaga 832d850
Remove image rescaling from backbone
oarriaga 184211d
Add preprocessor field and change shape to use backticks
oarriaga db7e123
Fix import name for decode and encode functions
oarriaga 853ae2b
Add default YOLOV8 image preprocessor
oarriaga 4341d05
Add default YOLOV8 object detector preprocessor
oarriaga 7e5e28d
Extend checkpoint conversion to include backbones and object detector…
oarriaga ac4519d
Change base class to ImageObjectDetector class
oarriaga 5a9f327
Remove object detector base class
oarriaga 499a263
Remove unnecessary comments
oarriaga 843a05c
Remove from API previous base ObjectDetector Task class
oarriaga 3175b52
Add changes from automatic formatter
oarriaga 3138a46
Fix serialization to include backbone, label encoder, prediction deco…
oarriaga 6fa961c
Add preset reload for testing numerics
oarriaga 906889e
Update preset versions and register detector presets
oarriaga b3af94b
Fix unit tests
oarriaga da91ea1
Add changes from automatic formatter
oarriaga 7fc29d4
Rename block function names
oarriaga a88b246
Fix typo with shape loss docstring
oarriaga 925db4c
Fix suppression spelling typos
oarriaga 408f1f7
Fix docstring dividing typo
oarriaga 714733d
Change link to keras box formats
oarriaga 9430d33
Remove tensorflow import from test file
oarriaga b1c12c1
Add passing run task test
oarriaga 6bf964a
Fix docstring tensor shape
oarriaga 928a11b
Rename all gt to full variable names
oarriaga a2ca5d2
Add default label encoder in docstring
oarriaga 3396c92
Add check for bounding box format
oarriaga 65bcdce
Move mask invalid detection function to model directory
oarriaga 971801e
Add automatic format changes
oarriaga 65c82b5
Remove label encoder from exposed API
oarriaga 9266742
Remove public API for NonMaxSuppression YOLOV8 layer
oarriaga 9b2b20f
Remove outdated TODOs
oarriaga 73bcd55
Change epsilon value to core Keras default
oarriaga 103b734
Remove noqa from presets
oarriaga 61e84d2
Change docstring spacing to match KerasHub default
oarriaga 4c5c6a2
Fix docstring spacing
oarriaga 77c3464
Add empty line after function description
oarriaga 04f9470
Remove all noqa from YOLOV8 files
oarriaga 91a418e
Add missing markdown syntax to links and shapes
oarriaga 2ebac2c
Fix docstring to include one liner
oarriaga 5f6f105
Remove path and official name from presets
oarriaga 5ae005c
Change config update to standard format
oarriaga 08d4022
Remove bounding box init file and update import across YOLOV8
oarriaga cf343f9
Remove presets with no pre-trained weights
oarriaga 060506e
Order import order based on isort
oarriaga 9c3a996
Remove use of tf_keras_only for testing box masking function
oarriaga 355060a
Replace custom model save test for integrated run movel saving test
oarriaga a676b5f
Remove unused imports and apply automatic linter
oarriaga 56dbd82
Change model name to include full TaskName
oarriaga 0d3ca7e
Fix weight conversion to include new model name and current user API
oarriaga 1f5bf55
Merge branch 'keras-team:master' into master
oarriaga ba31cb8
Update CIOU loss to use keras bounding box utils
oarriaga 4728e1e
Remove custom keras-cv bounding box masking function
oarriaga d7fa73c
Remove custom non-maximum suppression layer
oarriaga d13c811
Update preprocessor with new abstract class module name
oarriaga b7ca66c
Update label encoder to use the keras bounding box module
oarriaga 62b49a2
Apply changes from automatic formatter
oarriaga 6611fc7
Update detector to use default NMS layer and keras bounding box module
oarriaga 1b20e19
Update tests to use the keras bounding box module
oarriaga e988804
Update presets to point to new Kaggle version
oarriaga 4d04999
Fix conversion script to use new NMS keys
oarriaga da047a7
Merge backbone and task presets into a single file
oarriaga 2e8253b
Delete comments
oarriaga 7ad98dc
Delete comment and TODO
oarriaga 80be872
Add smaller image size for forward pass test
oarriaga 170a11c
Add explicit default value for CIOU loss epsilon
oarriaga 157d097
Add comment to show how to run conversion file
oarriaga 07476e7
Add model.fit to docstring and add missing CIoULoss input
oarriaga 48d4953
Update conversion script to account for new backbone arguments
oarriaga cd13735
Update backbone arguments to default keras-hub API
oarriaga 6a36125
Update all presets with new backbone arguments
oarriaga 7d36dd8
Change parent class to pyramid backbone
oarriaga b058a6f
Add YOLOV8 backbone test
oarriaga 8b46ec3
Modify error message to be more detailed about what is expected, what…
oarriaga File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from keras_hub.src.models.yolo_v8.yolo_v8_backbone import YOLOV8Backbone | ||
from keras_hub.src.models.yolo_v8.yolo_v8_detector import ( | ||
YOLOV8ImageObjectDetector, | ||
) | ||
from keras_hub.src.models.yolo_v8.yolo_v8_presets import backbone_presets | ||
from keras_hub.src.models.yolo_v8.yolo_v8_presets import detector_presets | ||
from keras_hub.src.utils.preset_utils import register_presets | ||
|
||
register_presets(backbone_presets, YOLOV8Backbone) | ||
register_presets(detector_presets, YOLOV8ImageObjectDetector) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import keras | ||
from keras import ops | ||
from keras.utils.bounding_boxes import compute_ciou | ||
|
||
|
||
class CIoULoss(keras.losses.Loss): | ||
"""Implements the Complete IoU (CIoU) Loss | ||
|
||
CIoU loss is an extension of GIoU loss, which further improves the IoU | ||
optimization for object detection. CIoU loss not only penalizes the | ||
bounding box coordinates but also considers the aspect ratio and center | ||
distance of the boxes. The length of the last dimension should be 4 to | ||
represent the bounding boxes. | ||
|
||
Args: | ||
bounding_box_format: a case-insensitive string (for example, "xyxy"). | ||
Each bounding box is defined by these 4 values. For detailed | ||
information on the supported formats, see the [Keras bounding box | ||
documentation](https://github.com/keras-team/keras/blob/master/ | ||
keras/src/layers/preprocessing/image_preprocessing/ | ||
bounding_boxes/formats.py). | ||
epsilon: (optional) float, a small value added to avoid division by | ||
zero and stabilize calculations. Defaults 1e-07. | ||
|
||
References: | ||
- [CIoU paper](https://arxiv.org/pdf/2005.03572.pdf) | ||
|
||
Example: | ||
```python | ||
y_true = np.random.uniform( | ||
size=(5, 10, 4), | ||
low=0, | ||
high=10) | ||
y_pred = np.random.uniform( | ||
size=(5, 10, 4), | ||
low=0, | ||
high=10) | ||
loss = keras_hub.src.models.yolo_v8.ciou_loss.CIoULoss("xyxy") | ||
loss(y_true, y_pred).numpy() | ||
oarriaga marked this conversation as resolved.
Show resolved
Hide resolved
|
||
``` | ||
|
||
Usage with the `compile()` API: | ||
```python | ||
model.compile(optimizer="adam", loss=CIoULoss("xyxy")) | ||
model.fit(y_true, y_pred) | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, bounding_box_format, epsilon=1e-07, image_shape=None, **kwargs | ||
): | ||
super().__init__(**kwargs) | ||
box_formats = [ | ||
"xywh", | ||
"center_xywh", | ||
"center_yxhw", | ||
"rel_xywh", | ||
"xyxy", | ||
"rel_xyxy", | ||
"yxyx", | ||
"rel_yxyx", | ||
] | ||
if bounding_box_format not in box_formats: | ||
raise ValueError( | ||
f"Invalid bounding box format: '{bounding_box_format}'. " | ||
f"Expected one of {box_formats}. " | ||
"Ensure that the string format is correctly spelled." | ||
) | ||
self.bounding_box_format = bounding_box_format | ||
self.epsilon = epsilon | ||
self.image_shape = image_shape | ||
|
||
def call(self, y_true, y_pred): | ||
y_pred = ops.convert_to_tensor(y_pred) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. checking if y_pred is a tensor and the dtype before converting could improve efficiency |
||
y_true = ops.cast(y_true, y_pred.dtype) | ||
|
||
if y_pred.shape[-1] != 4: | ||
raise ValueError( | ||
"CIoULoss expects y_pred.shape[-1] to be 4 to represent the " | ||
f"bounding boxes. Received y_pred.shape[-1]={y_pred.shape[-1]}." | ||
) | ||
|
||
if y_true.shape[-1] != 4: | ||
raise ValueError( | ||
"CIoULoss expects y_true.shape[-1] to be 4 to represent the " | ||
f"bounding boxes. Received y_true.shape[-1]={y_true.shape[-1]}." | ||
) | ||
|
||
if y_true.shape[-2] != y_pred.shape[-2]: | ||
raise ValueError( | ||
"CIoULoss expects number of boxes in y_pred to be equal to the " | ||
"number of boxes in y_true. Received number of boxes in " | ||
f"y_true={y_true.shape[-2]} and number of boxes in " | ||
f"y_pred={y_pred.shape[-2]}." | ||
) | ||
|
||
oarriaga marked this conversation as resolved.
Show resolved
Hide resolved
|
||
ciou = compute_ciou( | ||
y_true, y_pred, self.bounding_box_format, self.image_shape | ||
) | ||
return 1 - ciou | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"epsilon": self.epsilon, | ||
} | ||
) | ||
return config |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import numpy as np | ||
from absl.testing import parameterized | ||
|
||
from keras_hub.src.models.yolo_v8.ciou_loss import CIoULoss | ||
from keras_hub.src.tests.test_case import TestCase | ||
|
||
|
||
class CIoUTest(TestCase): | ||
def test_output_shape(self): | ||
y_true = np.random.uniform(size=(2, 2, 4), low=0, high=10) | ||
y_pred = np.random.uniform(size=(2, 2, 4), low=0, high=20) | ||
|
||
ciou_loss = CIoULoss(bounding_box_format="xywh") | ||
|
||
self.assertAllEqual(ciou_loss(y_true, y_pred).shape, ()) | ||
|
||
def test_output_shape_reduction_none(self): | ||
y_true = np.random.uniform(size=(2, 2, 4), low=0, high=10) | ||
y_pred = np.random.uniform(size=(2, 2, 4), low=0, high=20) | ||
|
||
ciou_loss = CIoULoss(bounding_box_format="xyxy", reduction="none") | ||
|
||
self.assertAllEqual( | ||
[2, 2], | ||
ciou_loss(y_true, y_pred).shape, | ||
) | ||
|
||
def test_output_shape_relative_formats(self): | ||
y_true = [ | ||
[0.0, 0.0, 0.1, 0.1], | ||
[0.0, 0.0, 0.2, 0.3], | ||
[0.4, 0.5, 0.5, 0.6], | ||
[0.2, 0.2, 0.3, 0.3], | ||
] | ||
|
||
y_pred = [ | ||
[0.0, 0.0, 0.5, 0.6], | ||
[0.0, 0.0, 0.7, 0.3], | ||
[0.4, 0.5, 0.5, 0.6], | ||
[0.2, 0.1, 0.3, 0.3], | ||
] | ||
|
||
ciou_loss = CIoULoss(bounding_box_format="xyxy") | ||
|
||
self.assertAllEqual(ciou_loss(y_true, y_pred).shape, ()) | ||
|
||
@parameterized.named_parameters( | ||
("xyxy", "xyxy"), | ||
("rel_xyxy", "rel_xyxy"), | ||
) | ||
def test_output_value(self, name): | ||
y_true = [ | ||
[0, 0, 1, 1], | ||
[0, 0, 2, 3], | ||
[4, 5, 3, 6], | ||
[2, 2, 3, 3], | ||
] | ||
|
||
y_pred = [ | ||
[0, 0, 5, 6], | ||
[0, 0, 7, 3], | ||
[4, 5, 5, 6], | ||
[2, 1, 3, 3], | ||
] | ||
expected_loss = 1.03202 | ||
divyashreepathihalli marked this conversation as resolved.
Show resolved
Hide resolved
|
||
ciou_loss = CIoULoss(bounding_box_format="xyxy") | ||
if name == "rel_xyxy": | ||
scale_factor = 1 / 640.0 | ||
y_true = np.array(y_true) * scale_factor | ||
y_pred = np.array(y_pred) * scale_factor | ||
|
||
self.assertAllClose( | ||
ciou_loss(y_true, y_pred), expected_loss, atol=0.005 | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
from keras import ops | ||
oarriaga marked this conversation as resolved.
Show resolved
Hide resolved
|
||
from keras.layers import Input | ||
from keras.layers import MaxPooling2D | ||
|
||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone | ||
from keras_hub.src.models.yolo_v8.yolo_v8_layers import apply_conv_bn | ||
from keras_hub.src.models.yolo_v8.yolo_v8_layers import apply_CSP | ||
|
||
|
||
def apply_stem(x, stem_width, activation): | ||
x = apply_conv_bn(x, stem_width // 2, 3, 2, activation, "stem_1") | ||
x = apply_conv_bn(x, stem_width, 3, 2, activation, "stem_2") | ||
return x | ||
|
||
|
||
def apply_fast_SPP(x, pool_size=5, activation="swish", name="spp_fast"): | ||
input_channels = x.shape[-1] | ||
hidden_channels = int(input_channels // 2) | ||
x = apply_conv_bn(x, hidden_channels, 1, 1, activation, f"{name}_pre") | ||
pool_kwargs = {"strides": 1, "padding": "same"} | ||
p1 = MaxPooling2D(pool_size, **pool_kwargs, name=f"{name}_pool1")(x) | ||
p2 = MaxPooling2D(pool_size, **pool_kwargs, name=f"{name}_pool2")(p1) | ||
p3 = MaxPooling2D(pool_size, **pool_kwargs, name=f"{name}_pool3")(p2) | ||
x = ops.concatenate([x, p1, p2, p3], axis=-1) | ||
x = apply_conv_bn(x, input_channels, 1, 1, activation, f"{name}_output") | ||
return x | ||
|
||
|
||
def apply_yolo_block(x, block_arg, channels, depth, block_depth, activation): | ||
name = f"stack{block_arg + 1}" | ||
if block_arg >= 1: | ||
x = apply_conv_bn(x, channels, 3, 2, activation, f"{name}_downsample") | ||
x = apply_CSP(x, -1, depth, True, 0.5, activation, f"{name}_c2f") | ||
if block_arg == len(block_depth) - 1: | ||
x = apply_fast_SPP(x, 5, activation, f"{name}_spp_fast") | ||
return x | ||
|
||
|
||
def stackwise_yolo_blocks(x, stackwise_depth, stackwise_channels, activation): | ||
pyramid_level_inputs = {"P1": get_tensor_input_name(x)} | ||
iterator = enumerate(zip(stackwise_channels, stackwise_depth)) | ||
block_args = (stackwise_depth, activation) | ||
for stack_arg, (channel, depth) in iterator: | ||
x = apply_yolo_block(x, stack_arg, channel, depth, *block_args) | ||
pyramid_level_inputs[f"P{stack_arg + 2}"] = get_tensor_input_name(x) | ||
return x, pyramid_level_inputs | ||
|
||
|
||
def get_tensor_input_name(tensor): | ||
return tensor._keras_history.operation.name | ||
|
||
|
||
def build_pyramid_outputs(model, level_to_layer_name): | ||
pyramid_outputs = {} | ||
for level_name, layer_name in level_to_layer_name.items(): | ||
pyramid_outputs[level_name] = model.get_layer(layer_name).output | ||
return pyramid_outputs | ||
|
||
|
||
@keras_hub_export("keras_hub.models.YOLOV8Backbone") | ||
class YOLOV8Backbone(FeaturePyramidBackbone): | ||
"""Implements the YOLOV8 backbone for object detection. | ||
|
||
This backbone is a variant of the `CSPDarkNetBackbone` architecture. | ||
|
||
For transfer learning use cases, make sure to read the | ||
[guide to transfer learning & fine-tuning](https://keras.io/guides/ | ||
transfer_learning/). | ||
|
||
Args: | ||
stackwise_channels: A list of int. The number of channels for each dark | ||
level in the model. | ||
stackwise_depth: A list of int. The depth for each dark level in the | ||
model. | ||
include_rescaling: bool. Rescale the inputs. If set to | ||
True, inputs will be passed through a `Rescaling(1/255.0)` layer. | ||
activation: str. The activation functions to use in the backbone to | ||
use in the CSPDarkNet blocks. Defaults to "swish". | ||
image_shape: optional shape tuple, defaults to `(None, None, 3)`. | ||
|
||
Returns: | ||
A `keras.Model` instance. | ||
|
||
Examples: | ||
```python | ||
input_data = tf.ones(shape=(8, 224, 224, 3)) | ||
|
||
# Pretrained backbone | ||
model = keras_hub.models.YOLOV8Backbone.from_preset( | ||
"yolo_v8_xs_backbone_coco" | ||
) | ||
output = model(input_data) | ||
|
||
# Randomly initialized backbone with a custom config | ||
model = keras_hub.models.YOLOV8Backbone( | ||
stackwise_channels=[128, 256, 512, 1024], | ||
stackwise_depth=[3, 9, 9, 3], | ||
) | ||
output = model(input_data) | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, | ||
stackwise_channels, | ||
stackwise_depth, | ||
activation="swish", | ||
image_shape=(None, None, 3), | ||
**kwargs, | ||
): | ||
inputs = Input(shape=image_shape) | ||
stem_width = stackwise_channels[0] | ||
x = apply_stem(inputs, stem_width, activation) | ||
x, pyramid_level_inputs = stackwise_yolo_blocks( | ||
x, stackwise_depth, stackwise_channels, activation | ||
) | ||
super().__init__(inputs=inputs, outputs=x, **kwargs) | ||
self.pyramid_level_inputs = pyramid_level_inputs | ||
self.pyramid_outputs = build_pyramid_outputs(self, pyramid_level_inputs) | ||
self.stackwise_channels = stackwise_channels | ||
self.stackwise_depth = stackwise_depth | ||
self.activation = activation | ||
self.image_shape = image_shape | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"image_shape": self.image_shape, | ||
"stackwise_channels": self.stackwise_channels, | ||
"stackwise_depth": self.stackwise_depth, | ||
"activation": self.activation, | ||
} | ||
) | ||
return config |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
from keras_hub.src.models.yolo_v8.yolo_v8_backbone import YOLOV8Backbone | ||
from keras_hub.src.tests.test_case import TestCase | ||
|
||
|
||
class YOLOV8BackboneTest(TestCase): | ||
def setUp(self): | ||
self.init_kwargs = { | ||
"stackwise_channels": [64, 128, 256, 512], | ||
"stackwise_depth": [1, 2, 2, 1], | ||
"activation": "swish", | ||
"image_shape": (32, 32, 3), | ||
} | ||
self.input_size = 32 | ||
self.input_data = np.ones( | ||
(2, self.input_size, self.input_size, 3), dtype="float32" | ||
) | ||
|
||
def test_backbone_basics(self): | ||
self.run_vision_backbone_test( | ||
cls=YOLOV8Backbone, | ||
init_kwargs=self.init_kwargs, | ||
input_data=self.input_data, | ||
expected_output_shape=(2, 1, 1, 512), | ||
expected_pyramid_output_keys=["P1", "P2", "P3", "P4", "P5"], | ||
expected_pyramid_image_sizes=[ | ||
(8, 8), | ||
(8, 8), | ||
(4, 4), | ||
(2, 2), | ||
(1, 1), | ||
], | ||
run_mixed_precision_check=False, | ||
run_data_format_check=False, | ||
) | ||
|
||
@pytest.mark.large | ||
def test_saved_model(self): | ||
self.run_model_saving_test( | ||
cls=YOLOV8Backbone, | ||
init_kwargs=self.init_kwargs, | ||
input_data=self.input_data, | ||
) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.