Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion keras_hub/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@
from keras_hub.src.models.sam.sam_backbone import SAMBackbone
from keras_hub.src.models.sam.sam_image_segmenter import SAMImageSegmenter
from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import (
SAMImageSegmenterPreprocessor as SamImageSegmenterPreprocessor,
SAMImageSegmenterPreprocessor,
)
from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM
from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor
Expand Down
5 changes: 0 additions & 5 deletions keras_hub/src/models/image_segmenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,6 @@ class ImageSegmenter(Task):
be used to load a pre-trained config and weights.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Default compilation.
self.compile()

def compile(
self,
optimizer="auto",
Expand Down
5 changes: 5 additions & 0 deletions keras_hub/src/models/sam/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from keras_hub.src.models.sam.sam_backbone import SAMBackbone
from keras_hub.src.models.sam.sam_presets import backbone_presets
from keras_hub.src.utils.preset_utils import register_presets

register_presets(backbone_presets, SAMBackbone)
12 changes: 11 additions & 1 deletion keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.image_segmenter_preprocessor import (
ImageSegmenterPreprocessor,
)
from keras_hub.src.models.sam.sam_backbone import SAMBackbone
from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter
from keras_hub.src.utils.tensor_utils import preprocessing_function


@keras_hub_export("keras_hub.models.SamImageSegmenterPreprocessor")
@keras_hub_export("keras_hub.models.SAMImageSegmenterPreprocessor")
class SAMImageSegmenterPreprocessor(ImageSegmenterPreprocessor):
backbone_cls = SAMBackbone
image_converter_cls = SAMImageConverter

@preprocessing_function
def call(self, x, y=None, sample_weight=None):
images = x["images"]
if self.image_converter:
x["images"] = self.image_converter(images)
return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
22 changes: 22 additions & 0 deletions keras_hub/src/models/sam/sam_image_segmenter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
import pytest

from keras_hub.src.models.sam.sam_backbone import SAMBackbone
from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter
from keras_hub.src.models.sam.sam_image_segmenter import SAMImageSegmenter
from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import (
SAMImageSegmenterPreprocessor,
)
from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder
from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder
from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone
Expand Down Expand Up @@ -53,8 +57,13 @@ def setUp(self):
prompt_encoder=self.prompt_encoder,
mask_decoder=self.mask_decoder,
)
self.image_converter = SAMImageConverter(
height=self.image_size, width=self.image_size, scale=1 / 255.0
)
self.preprocessor = SAMImageSegmenterPreprocessor(self.image_converter)
self.init_kwargs = {
"backbone": self.backbone,
"preprocessor": self.preprocessor,
}
self.inputs = {
"images": self.images,
Expand Down Expand Up @@ -102,3 +111,16 @@ def test_end_to_end_model_predict(self):
masks, iou_pred = outputs["masks"], outputs["iou_pred"]
self.assertAllEqual(masks.shape, (2, 4, 32, 32))
self.assertAllEqual(iou_pred.shape, (2, 4))

@pytest.mark.extra_large
def test_all_presets(self):
for preset in SAMImageSegmenter.presets:
self.run_preset_test(
cls=SAMImageSegmenter,
preset=preset,
input_data=self.inputs,
expected_output_shape={
"masks": [2, 2, 1],
"iou_pred": [2],
},
)
6 changes: 3 additions & 3 deletions keras_hub/src/models/sam/sam_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"path": "sam",
"model_card": "https://arxiv.org/abs/2304.02643",
},
"kaggle_handle": "kaggle://kerashub/sam/keras/sam_base_sa1b/1",
"kaggle_handle": "kaggle://kerashub/sam/keras/sam_base_sa1b/2",
},
"sam_large_sa1b": {
"metadata": {
Expand All @@ -19,7 +19,7 @@
"path": "sam",
"model_card": "https://arxiv.org/abs/2304.02643",
},
"kaggle_handle": "kaggle://kerashub/sam/keras/sam_large_sa1b/1",
"kaggle_handle": "kaggle://kerashub/sam/keras/sam_large_sa1b/2",
},
"sam_huge_sa1b": {
"metadata": {
Expand All @@ -29,6 +29,6 @@
"path": "sam",
"model_card": "https://arxiv.org/abs/2304.02643",
},
"kaggle_handle": "kaggle://kerashub/sam/keras/sam_huge_sa1b/1",
"kaggle_handle": "kaggle://kerashub/sam/keras/sam_huge_sa1b/2",
},
}
16 changes: 8 additions & 8 deletions tools/checkpoint_conversion/convert_sam_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from segment_anything import sam_model_registry

from keras_hub.src.models.sam.sam_backbone import SAMBackbone
from keras_hub.src.models.sam.sam_image_converter import SamImageConverter
from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter
from keras_hub.src.models.sam.sam_image_segmenter import SAMImageSegmenter
from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import (
SamImageSegmenterPreprocessor,
SAMImageSegmenterPreprocessor,
)
from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder
from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder
Expand Down Expand Up @@ -57,10 +57,10 @@ def build_sam_base_model():
prompt_encoder=prompt_encoder,
mask_decoder=mask_decoder,
)
sam_image_converter = SamImageConverter(
sam_image_converter = SAMImageConverter(
height=1024, width=1024, scale=1.0 / 255
)
sam_preprocessor = SamImageSegmenterPreprocessor(
sam_preprocessor = SAMImageSegmenterPreprocessor(
image_converter=sam_image_converter
)
sam_image_segmenter = SAMImageSegmenter(
Expand Down Expand Up @@ -106,10 +106,10 @@ def build_sam_large_model():
prompt_encoder=prompt_encoder,
mask_decoder=mask_decoder,
)
sam_image_converter = SamImageConverter(
sam_image_converter = SAMImageConverter(
height=1024, width=1024, scale=1.0 / 255
)
sam_preprocessor = SamImageSegmenterPreprocessor(
sam_preprocessor = SAMImageSegmenterPreprocessor(
image_converter=sam_image_converter
)
sam_image_segmenter = SAMImageSegmenter(
Expand Down Expand Up @@ -155,10 +155,10 @@ def build_sam_huge_model():
prompt_encoder=prompt_encoder,
mask_decoder=mask_decoder,
)
sam_image_converter = SamImageConverter(
sam_image_converter = SAMImageConverter(
height=1024, width=1024, scale=1.0 / 255
)
sam_preprocessor = SamImageSegmenterPreprocessor(
sam_preprocessor = SAMImageSegmenterPreprocessor(
image_converter=sam_image_converter
)
sam_image_segmenter = SAMImageSegmenter(
Expand Down