diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 5e7cbb1c9d..371277465a 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -254,7 +254,7 @@ from keras_hub.src.models.sam.sam_backbone import SAMBackbone from keras_hub.src.models.sam.sam_image_segmenter import SAMImageSegmenter from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import ( - SAMImageSegmenterPreprocessor as SamImageSegmenterPreprocessor, + SAMImageSegmenterPreprocessor, ) from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor diff --git a/keras_hub/src/models/image_segmenter.py b/keras_hub/src/models/image_segmenter.py index fcd45db07a..edee64ba05 100644 --- a/keras_hub/src/models/image_segmenter.py +++ b/keras_hub/src/models/image_segmenter.py @@ -16,11 +16,6 @@ class ImageSegmenter(Task): be used to load a pre-trained config and weights. """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # Default compilation. - self.compile() - def compile( self, optimizer="auto", diff --git a/keras_hub/src/models/sam/__init__.py b/keras_hub/src/models/sam/__init__.py index e69de29bb2..81dbd8800e 100644 --- a/keras_hub/src/models/sam/__init__.py +++ b/keras_hub/src/models/sam/__init__.py @@ -0,0 +1,5 @@ +from keras_hub.src.models.sam.sam_backbone import SAMBackbone +from keras_hub.src.models.sam.sam_presets import backbone_presets +from keras_hub.src.utils.preset_utils import register_presets + +register_presets(backbone_presets, SAMBackbone) diff --git a/keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py b/keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py index e47ce18614..a8dcdbe6a6 100644 --- a/keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +++ b/keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py @@ -1,12 +1,22 @@ +import keras + from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.image_segmenter_preprocessor import ( ImageSegmenterPreprocessor, ) from keras_hub.src.models.sam.sam_backbone import SAMBackbone from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter +from keras_hub.src.utils.tensor_utils import preprocessing_function -@keras_hub_export("keras_hub.models.SamImageSegmenterPreprocessor") +@keras_hub_export("keras_hub.models.SAMImageSegmenterPreprocessor") class SAMImageSegmenterPreprocessor(ImageSegmenterPreprocessor): backbone_cls = SAMBackbone image_converter_cls = SAMImageConverter + + @preprocessing_function + def call(self, x, y=None, sample_weight=None): + images = x["images"] + if self.image_converter: + x["images"] = self.image_converter(images) + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) diff --git a/keras_hub/src/models/sam/sam_image_segmenter_test.py b/keras_hub/src/models/sam/sam_image_segmenter_test.py index 510c290e2c..0d36c31db2 100644 --- a/keras_hub/src/models/sam/sam_image_segmenter_test.py +++ b/keras_hub/src/models/sam/sam_image_segmenter_test.py @@ -2,7 +2,11 @@ import pytest from keras_hub.src.models.sam.sam_backbone import SAMBackbone +from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter from keras_hub.src.models.sam.sam_image_segmenter import SAMImageSegmenter +from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import ( + SAMImageSegmenterPreprocessor, +) from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone @@ -53,8 +57,13 @@ def setUp(self): prompt_encoder=self.prompt_encoder, mask_decoder=self.mask_decoder, ) + self.image_converter = SAMImageConverter( + height=self.image_size, width=self.image_size, scale=1 / 255.0 + ) + self.preprocessor = SAMImageSegmenterPreprocessor(self.image_converter) self.init_kwargs = { "backbone": self.backbone, + "preprocessor": self.preprocessor, } self.inputs = { "images": self.images, @@ -102,3 +111,16 @@ def test_end_to_end_model_predict(self): masks, iou_pred = outputs["masks"], outputs["iou_pred"] self.assertAllEqual(masks.shape, (2, 4, 32, 32)) self.assertAllEqual(iou_pred.shape, (2, 4)) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in SAMImageSegmenter.presets: + self.run_preset_test( + cls=SAMImageSegmenter, + preset=preset, + input_data=self.inputs, + expected_output_shape={ + "masks": [2, 2, 1], + "iou_pred": [2], + }, + ) diff --git a/keras_hub/src/models/sam/sam_presets.py b/keras_hub/src/models/sam/sam_presets.py index d666c0bde1..7b7986662c 100644 --- a/keras_hub/src/models/sam/sam_presets.py +++ b/keras_hub/src/models/sam/sam_presets.py @@ -9,7 +9,7 @@ "path": "sam", "model_card": "https://arxiv.org/abs/2304.02643", }, - "kaggle_handle": "kaggle://kerashub/sam/keras/sam_base_sa1b/1", + "kaggle_handle": "kaggle://kerashub/sam/keras/sam_base_sa1b/2", }, "sam_large_sa1b": { "metadata": { @@ -19,7 +19,7 @@ "path": "sam", "model_card": "https://arxiv.org/abs/2304.02643", }, - "kaggle_handle": "kaggle://kerashub/sam/keras/sam_large_sa1b/1", + "kaggle_handle": "kaggle://kerashub/sam/keras/sam_large_sa1b/2", }, "sam_huge_sa1b": { "metadata": { @@ -29,6 +29,6 @@ "path": "sam", "model_card": "https://arxiv.org/abs/2304.02643", }, - "kaggle_handle": "kaggle://kerashub/sam/keras/sam_huge_sa1b/1", + "kaggle_handle": "kaggle://kerashub/sam/keras/sam_huge_sa1b/2", }, } diff --git a/tools/checkpoint_conversion/convert_sam_checkpoints.py b/tools/checkpoint_conversion/convert_sam_checkpoints.py index 9466453cbc..08f4f4a504 100644 --- a/tools/checkpoint_conversion/convert_sam_checkpoints.py +++ b/tools/checkpoint_conversion/convert_sam_checkpoints.py @@ -5,10 +5,10 @@ from segment_anything import sam_model_registry from keras_hub.src.models.sam.sam_backbone import SAMBackbone -from keras_hub.src.models.sam.sam_image_converter import SamImageConverter +from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter from keras_hub.src.models.sam.sam_image_segmenter import SAMImageSegmenter from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import ( - SamImageSegmenterPreprocessor, + SAMImageSegmenterPreprocessor, ) from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder @@ -57,10 +57,10 @@ def build_sam_base_model(): prompt_encoder=prompt_encoder, mask_decoder=mask_decoder, ) - sam_image_converter = SamImageConverter( + sam_image_converter = SAMImageConverter( height=1024, width=1024, scale=1.0 / 255 ) - sam_preprocessor = SamImageSegmenterPreprocessor( + sam_preprocessor = SAMImageSegmenterPreprocessor( image_converter=sam_image_converter ) sam_image_segmenter = SAMImageSegmenter( @@ -106,10 +106,10 @@ def build_sam_large_model(): prompt_encoder=prompt_encoder, mask_decoder=mask_decoder, ) - sam_image_converter = SamImageConverter( + sam_image_converter = SAMImageConverter( height=1024, width=1024, scale=1.0 / 255 ) - sam_preprocessor = SamImageSegmenterPreprocessor( + sam_preprocessor = SAMImageSegmenterPreprocessor( image_converter=sam_image_converter ) sam_image_segmenter = SAMImageSegmenter( @@ -155,10 +155,10 @@ def build_sam_huge_model(): prompt_encoder=prompt_encoder, mask_decoder=mask_decoder, ) - sam_image_converter = SamImageConverter( + sam_image_converter = SAMImageConverter( height=1024, width=1024, scale=1.0 / 255 ) - sam_preprocessor = SamImageSegmenterPreprocessor( + sam_preprocessor = SAMImageSegmenterPreprocessor( image_converter=sam_image_converter ) sam_image_segmenter = SAMImageSegmenter(