Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions keras_hub/api/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@
from keras_hub.src.models.resnet.resnet_image_converter import (
ResNetImageConverter,
)
from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder
from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder
from keras_hub.src.models.whisper.whisper_audio_converter import (
WhisperAudioConverter,
)
3 changes: 3 additions & 0 deletions keras_hub/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@
from keras_hub.src.models.image_classifier_preprocessor import (
ImageClassifierPreprocessor,
)
from keras_hub.src.models.image_segmenter import ImageSegmenter
from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone
from keras_hub.src.models.llama3.llama3_causal_lm import Llama3CausalLM
from keras_hub.src.models.llama3.llama3_causal_lm_preprocessor import (
Expand Down Expand Up @@ -255,6 +256,8 @@
RobertaTextClassifierPreprocessor as RobertaPreprocessor,
)
from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer
from keras_hub.src.models.sam.sam_backbone import SAMBackbone
from keras_hub.src.models.sam.sam_image_segmenter import SAMImageSegmenter
from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM
from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor
from keras_hub.src.models.t5.t5_backbone import T5Backbone
Expand Down
86 changes: 86 additions & 0 deletions keras_hub/src/models/image_segmenter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.task import Task


@keras_hub_export("keras_hub.models.ImageSegmenter")
class ImageSegmenter(Task):
"""Base class for all image segmentation tasks.

`ImageSegmenter` tasks wrap a `keras_hub.models.Task` and
a `keras_hub.models.Preprocessor` to create a model that can be used for
image segmentation.

All `ImageSegmenter` tasks include a `from_preset()` constructor which can
be used to load a pre-trained config and weights.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Default compilation.
self.compile()

def compile(
self,
optimizer="auto",
loss="auto",
*,
metrics="auto",
**kwargs,
):
"""Configures the `ImageSegmenter` task for training.

The `ImageSegmenter` task extends the default compilation signature of
`keras.Model.compile` with defaults for `optimizer`, `loss`, and
`metrics`. To override these defaults, pass any value
to these arguments during compilation.

Args:
optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
instance. Defaults to `"auto"`, which uses the default optimizer
for the given model and task. See `keras.Model.compile` and
`keras.optimizers` for more info on possible `optimizer` values.
loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
Defaults to `"auto"`, where a
`keras.losses.SparseCategoricalCrossentropy` loss will be
applied for the classification task. See
`keras.Model.compile` and `keras.losses` for more info on
possible `loss` values.
metrics: `"auto"`, or a list of metrics to be evaluated by
the model during training and testing. Defaults to `"auto"`,
where a `keras.metrics.SparseCategoricalAccuracy` will be
applied to track the accuracy of the model during training.
See `keras.Model.compile` and `keras.metrics` for
more info on possible `metrics` values.
**kwargs: See `keras.Model.compile` for a full list of arguments
supported by the compile method.
"""
if optimizer == "auto":
optimizer = keras.optimizers.Adam(5e-5)
if loss == "auto":
activation = getattr(self, "activation", None)
activation = keras.activations.get(activation)
from_logits = activation != keras.activations.softmax
loss = keras.losses.CategoricalCrossentropy(from_logits=from_logits)
if metrics == "auto":
metrics = [keras.metrics.CategoricalAccuracy()]
super().compile(
optimizer=optimizer,
loss=loss,
metrics=metrics,
**kwargs,
)
13 changes: 13 additions & 0 deletions keras_hub/src/models/sam/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
153 changes: 153 additions & 0 deletions keras_hub/src/models/sam/sam_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone


@keras_hub_export("keras_hub.models.SAMBackbone")
class SAMBackbone(Backbone):
"""A backbone for the Segment Anything Model (SAM).

Args:
image_encoder: `keras_hub.models.ViTDetBackbone`. A feature extractor for
the input images.
prompt_encoder: `keras_hub.layers.SAMPromptEncoder`. A Keras layer to
compute embeddings for points, box, and mask prompt.
mask_decoder: `keras_hub.layers.SAMMaskDecoder`. A Keras layer to
generate segmentation masks given the embeddings generated by the
backbone and the prompt encoder.
dtype: The dtype of the layer weights.

Example:
```python
image_size=128
batch_size=2
input_data = {
"images": np.ones(
(batch_size, image_size, image_size, 3),
dtype="float32",
),
"points": np.ones((batch_size, 1, 2), dtype="float32"),
"labels": np.ones((batch_size, 1), dtype="float32"),
"boxes": np.ones((batch_size, 1, 2, 2), dtype="float32"),
"masks": np.zeros(
(batch_size, 0, image_size, image_size, 1)
),
}
image_encoder = keras_hub.models.ViTDetBackbone(
hidden_size=16,
num_layers=16,
intermediate_dim=16 * 4,
num_heads=16,
global_attention_layer_indices=[2, 5, 8, 11],
patch_size=16,
num_output_channels=8,
window_size=2,
image_shape=(image_size, image_size, 3),
)
prompt_encoder = keras_hub.layers.SAMPromptEncoder(
hidden_size=8,
image_embedding_size=(8, 8),
input_image_size=(
image_size,
image_size,
),
mask_in_channels=16,
)
mask_decoder = keras_hub.layers.SAMMaskDecoder(
num_layers=2,
hidden_size=8,
intermediate_dim=32,
num_heads=8,
embedding_dim=8,
num_multimask_outputs=3,
iou_head_depth=3,
iou_head_hidden_dim=8,
)
backbone = keras_hub.models.SAMBackbone(
image_encoder=image_encoder,
prompt_encoder=prompt_encoder,
mask_decoder=mask_decoder,
image_shape=(image_size, image_size, 3),
)
backbone(input_data)
```
"""

def __init__(
self,
image_encoder,
prompt_encoder,
mask_decoder,
dtype=None,
**kwargs,
):
# === Layers ===
self.image_encoder = image_encoder
self.prompt_encoder = prompt_encoder
self.mask_decoder = mask_decoder
# === Functional model
image_input = self.image_encoder.input

inputs = {
"images": image_input,
"points": keras.Input(shape=[None, 2], name="points"),
"labels": keras.Input(shape=[None], name="labels"),
"boxes": keras.Input(shape=[None, 2, 2], name="boxes"),
"masks": keras.Input(shape=[None, None, None, 1], name="masks"),
}
image_embeddings = self.image_encoder.output
prompt_embeddings = self.prompt_encoder(**inputs)
outputs = {
"image_embeddings": image_embeddings,
}
outputs.update(prompt_embeddings)
super().__init__(
inputs=inputs,
outputs=outputs,
dtype=dtype,
**kwargs,
)

def get_config(self):
config = super().get_config()
config.update(
{
"image_encoder": keras.layers.serialize(self.image_encoder),
"prompt_encoder": keras.layers.serialize(self.prompt_encoder),
"mask_decoder": keras.layers.serialize(self.mask_decoder),
}
)
return config

@classmethod
def from_config(cls, config):
config.update(
{
"image_encoder": keras.layers.deserialize(
config["image_encoder"]
),
"prompt_encoder": keras.layers.deserialize(
config["prompt_encoder"]
),
"mask_decoder": keras.layers.deserialize(
config["mask_decoder"]
),
}
)

return super().from_config(config)
90 changes: 90 additions & 0 deletions keras_hub/src/models/sam/sam_backbone_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from keras_hub.src.models.sam.sam_backbone import SAMBackbone
from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder
from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder
from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone
from keras_hub.src.tests.test_case import TestCase


class SAMBackboneTest(TestCase):
def setUp(self):
self.batch_size = 2
self.image_size = 16
self.image_encoder = ViTDetBackbone(
hidden_size=16,
num_layers=16,
intermediate_dim=16 * 4,
num_heads=16,
global_attention_layer_indices=[2, 5, 8, 11],
patch_size=16,
num_output_channels=8,
window_size=2,
image_shape=(self.image_size, self.image_size, 3),
)
self.prompt_encoder = SAMPromptEncoder(
hidden_size=8,
image_embedding_size=(8, 8),
input_image_size=(
self.image_size,
self.image_size,
),
mask_in_channels=16,
)
self.mask_decoder = SAMMaskDecoder(
num_layers=2,
hidden_size=8,
intermediate_dim=32,
num_heads=8,
embedding_dim=8,
num_multimask_outputs=3,
iou_head_depth=3,
iou_head_hidden_dim=8,
)
self.init_kwargs = {
"image_encoder": self.image_encoder,
"prompt_encoder": self.prompt_encoder,
"mask_decoder": self.mask_decoder,
}

self.input_data = {
"images": np.ones(
(self.batch_size, self.image_size, self.image_size, 3),
dtype="float32",
),
"points": np.ones((self.batch_size, 1, 2), dtype="float32"),
"labels": np.ones((self.batch_size, 1), dtype="float32"),
"boxes": np.ones((self.batch_size, 1, 2, 2), dtype="float32"),
"masks": np.zeros(
(self.batch_size, 0, self.image_size, self.image_size, 1)
),
}

def test_backbone_basics(self):
self.run_backbone_test(
cls=SAMBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape={
"image_embeddings": (2, 1, 1, 8),
"prompt_sparse_embeddings": (2, 3, 8),
"prompt_dense_embeddings": (2, 8, 8, 8),
"prompt_dense_positional_embeddings": (1, 8, 8, 8),
},
run_mixed_precision_check=False,
run_quantization_check=False,
)
Loading