diff --git a/keras_hub/src/models/vit_det/vit_det_backbone.py b/keras_hub/src/models/vit_det/vit_det_backbone.py index 94f7887c44..7d5883409e 100644 --- a/keras_hub/src/models/vit_det/vit_det_backbone.py +++ b/keras_hub/src/models/vit_det/vit_det_backbone.py @@ -31,7 +31,7 @@ class ViTDetBackbone(Backbone): global_attention_layer_indices (list): Indexes for blocks using global attention. image_shape (tuple[int], optional): The size of the input image in - `(H, W, C)` format. Defaults to `(1024, 1024, 3)`. + `(H, W, C)` format. Defaults to `(None, None, 3)`. patch_size (int, optional): the patch size to be supplied to the Patching layer to turn input images into a flattened sequence of patches. Defaults to `16`. @@ -79,7 +79,7 @@ def __init__( intermediate_dim, num_heads, global_attention_layer_indices, - image_shape=(1024, 1024, 3), + image_shape=(None, None, 3), patch_size=16, num_output_channels=256, use_bias=True,