From efd1042665b875457467ce1c855d95116f733ede Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Tue, 12 Nov 2024 14:20:10 -0800 Subject: [PATCH 01/15] adding initial basnet files --- keras_hub/src/models/basnet/__init__.py | 5 + keras_hub/src/models/basnet/basnet.py | 371 ++++++++++++++++++ .../models/basnet/basnet_image_converter.py | 8 + .../src/models/basnet/basnet_preprocessor.py | 14 + keras_hub/src/models/basnet/basnet_presets.py | 22 ++ keras_hub/src/models/basnet/basnet_test.py | 41 ++ 6 files changed, 461 insertions(+) create mode 100644 keras_hub/src/models/basnet/__init__.py create mode 100644 keras_hub/src/models/basnet/basnet.py create mode 100644 keras_hub/src/models/basnet/basnet_image_converter.py create mode 100644 keras_hub/src/models/basnet/basnet_preprocessor.py create mode 100644 keras_hub/src/models/basnet/basnet_presets.py create mode 100644 keras_hub/src/models/basnet/basnet_test.py diff --git a/keras_hub/src/models/basnet/__init__.py b/keras_hub/src/models/basnet/__init__.py new file mode 100644 index 0000000000..9a49eacce0 --- /dev/null +++ b/keras_hub/src/models/basnet/__init__.py @@ -0,0 +1,5 @@ +from keras_hub.src.models.basnet.basnet import BASNet +from keras_hub.src.models.basnet.basnet_presets import basnet_presets +from keras_hub.src.utils.preset_utils import register_presets + +register_presets(basnet_presets, BASNet) diff --git a/keras_hub/src/models/basnet/basnet.py b/keras_hub/src/models/basnet/basnet.py new file mode 100644 index 0000000000..d7c44c17bc --- /dev/null +++ b/keras_hub/src/models/basnet/basnet.py @@ -0,0 +1,371 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor +from keras_hub.src.models.image_segmenter import ImageSegmenter +from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_hub.src.models.resnet.resnet_backbone import ( + apply_basic_block as resnet_basic_block, +) + + +@keras_hub_export( + [ + "keras_hub.models.BASNet", + ] +) +class BASNet(ImageSegmenter): + """ + A Keras model implementing the BASNet architecture for semantic + segmentation. + + References: + - [BASNet: Boundary-Aware Segmentation Network for Mobile and Web Applications](https://arxiv.org/abs/2101.04704) + + Args: + backbone: `keras.Model`. The backbone network for the model that is + used as a feature extractor for BASNet prediction encoder. Currently + supported backbones are ResNet18 and ResNet34. Default backbone is + `keras_cv.models.ResNet34Backbone()`. + (Note: Do not specify `image_shape` within the backbone. + Please provide these while initializing the 'BASNet' model.) + num_classes: int, the number of classes for the segmentation model. + image_shape: optional shape tuple, defaults to (None, None, 3). + projection_filters: int, number of filters in the convolution layer + projecting low-level features from the `backbone`. + prediction_heads: (Optional) List of `keras.layers.Layer` defining + the prediction module head for the model. If not provided, a + default head is created with a Conv2D layer followed by resizing. + refinement_head: (Optional) a `keras.layers.Layer` defining the + refinement module head for the model. If not provided, a default + head is created with a Conv2D layer. + + Example: + ```python + + import keras_hub + + images = np.ones(shape=(1, 288, 288, 3)) + labels = np.zeros(shape=(1, 288, 288, 1)) + + # Note: Do not specify `image_shape` within the backbone. 
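    # The backbone must keep its default `image_shape=(None, None, 3)`;
    # pass the concrete shape to `BASNet` below instead.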
+ backbone = keras_hub.models.ResNetBackbone.from_preset( + "resnet_18_imagenet", + load_weights=False + ) + model = keras_hub.models.BASNet( + backbone=backbone, + num_classes=1, + image_shape=[288, 288, 3], + ) + + # Evaluate model + output = model(images) + pred_labels = output[0] + + # Train model + model.compile( + optimizer="adam", + loss=keras.losses.BinaryCrossentropy(from_logits=False), + metrics=["accuracy"], + ) + model.fit(images, labels, epochs=3) + ``` + """ # noqa: E501 + + backbone_cls = ResNetBackbone + preprocessor_cls = BASNetPreprocessor + + def __init__( + self, + backbone, + num_classes, + image_shape=(None, None, 3), + projection_filters=64, + prediction_heads=None, + refinement_head=None, + preprocessor=None, + **kwargs, + ): + if not isinstance(backbone, keras.layers.Layer) or not isinstance( + backbone, keras.Model + ): + raise ValueError( + "Argument `backbone` must be a `keras.layers.Layer` instance" + f" or `keras.Model`. Received instead" + f" backbone={backbone} (of type {type(backbone)})." + ) + + if tuple(backbone.image_shape) != (None, None, 3): + raise ValueError( + "Do not specify `image_shape` within the" + " 'BASNet' backbone. \nPlease provide `image_shape`" + " while initializing the 'BASNet' model." + ) + + # === Functional Model === + inputs = keras.layers.Input(shape=image_shape) + x = inputs + + if prediction_heads is None: + prediction_heads = [] + for size in (1, 2, 4, 8, 16, 32, 32): + head_layers = [ + keras.layers.Conv2D( + num_classes, kernel_size=(3, 3), padding="same" + ) + ] + if size != 1: + head_layers.append( + keras.layers.UpSampling2D( + size=size, interpolation="bilinear" + ) + ) + prediction_heads.append(keras.Sequential(head_layers)) + + if refinement_head is None: + refinement_head = keras.Sequential( + [ + keras.layers.Conv2D( + num_classes, kernel_size=(3, 3), padding="same" + ), + ] + ) + + # Prediction model. + predict_model = basnet_predict( + x, backbone, projection_filters, prediction_heads + ) + + # Refinement model. + refine_model = basnet_rrm( + predict_model, projection_filters, refinement_head + ) + + outputs = refine_model.outputs # Combine outputs. 
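        # The refinement output is listed first, followed by the prediction
        # module's seven side outputs (one per decoder stage plus the bridge),
        # giving eight tensors in total; the sigmoid below maps each one to a
        # per-pixel probability at the input resolution.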
+ outputs.extend(predict_model.outputs) + + outputs = [ + keras.layers.Activation("sigmoid", dtype="float32")(_) + for _ in outputs + ] + + super().__init__(inputs=inputs, outputs=outputs, **kwargs) + + # === Config === + self.backbone = backbone + self.num_classes = num_classes + self.image_shape = image_shape + self.projection_filters = projection_filters + self.prediction_heads = prediction_heads + self.refinement_head = refinement_head + self.preprocessor = preprocessor + + def get_config(self): + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "image_shape": self.image_shape, + "projection_filters": self.projection_filters, + "prediction_heads": [ + keras.saving.serialize_keras_object(prediction_head) + for prediction_head in self.prediction_heads + ], + "refinement_head": keras.saving.serialize_keras_object( + self.refinement_head + ), + } + ) + return config + + @classmethod + def from_config(cls, config): + if "prediction_heads" in config and isinstance( + config["prediction_heads"], list + ): + for i in range(len(config["prediction_heads"])): + if isinstance(config["prediction_heads"][i], dict): + config["prediction_heads"][i] = keras.layers.deserialize( + config["prediction_heads"][i] + ) + + if "refinement_head" in config and isinstance( + config["refinement_head"], dict + ): + config["refinement_head"] = keras.layers.deserialize( + config["refinement_head"] + ) + return super().from_config(config) + + +def convolution_block(x_input, filters, dilation=1): + """ + Apply convolution + batch normalization + ReLU activation. + + Args: + x_input: Input keras tensor. + filters: int, number of output filters in the convolution. + dilation: int, dilation rate for the convolution operation. + Defaults to 1. + + Returns: + A tensor with convolution, batch normalization, and ReLU + activation applied. + """ + x = keras.layers.Conv2D( + filters, (3, 3), padding="same", dilation_rate=dilation + )(x_input) + x = keras.layers.BatchNormalization()(x) + return keras.layers.Activation("relu")(x) + + +def get_resnet_block(_resnet, block_num): + """ + Extract and return a specific ResNet block. + + Args: + _resnet: `keras.Model`. ResNet model instance. + block_num: int, block number to extract. + + Returns: + A Keras Model representing the specified ResNet block. + """ + + extractor_levels = ["P2", "P3", "P4", "P5"] + num_blocks = _resnet.stackwise_num_blocks + if block_num == 0: + x = _resnet.get_layer("pool1_pool").output + else: + x = _resnet.pyramid_outputs[extractor_levels[block_num - 1]] + y = _resnet.get_layer( + f"stack{block_num}_block{num_blocks[block_num]-1}_add" + ).output + return keras.models.Model( + inputs=x, + outputs=y, + name=f"resnet_block{block_num + 1}", + ) + + +def basnet_predict(x_input, backbone, filters, segmentation_heads): + """ + BASNet Prediction Module. + + This module outputs a coarse label map by integrating heavy + encoder, bridge, and decoder blocks. + + Args: + x_input: Input keras tensor. + backbone: `keras.Model`. The backbone network used as a feature + extractor for BASNet prediction encoder. + filters: int, the number of filters. + segmentation_heads: List of `keras.layers.Layer`, A list of Keras + layers serving as the segmentation head for prediction module. + + + Returns: + A Keras Model that integrates the encoder, bridge, and decoder + blocks for coarse label map prediction. 
+ """ + num_stages = 6 + + x = x_input + + # -------------Encoder-------------- + x = keras.layers.Conv2D(filters, kernel_size=(3, 3), padding="same")(x) + + encoder_blocks = [] + for i in range(num_stages): + if i < 4: # First four stages are adopted from ResNet backbone. + x = get_resnet_block(backbone, i)(x) + encoder_blocks.append(x) + else: # Last 2 stages consist of three basic resnet blocks. + x = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(x) + for j in range(3): + x = resnet_basic_block( + x, + filters=x.shape[3], + conv_shortcut=False, + name=f"v1_basic_block_{i + 1}_{j + 1}", + ) + encoder_blocks.append(x) + + # -------------Bridge------------- + x = convolution_block(x, filters=filters * 8, dilation=2) + x = convolution_block(x, filters=filters * 8, dilation=2) + x = convolution_block(x, filters=filters * 8, dilation=2) + encoder_blocks.append(x) + + # -------------Decoder------------- + decoder_blocks = [] + for i in reversed(range(num_stages)): + if i != (num_stages - 1): # Except first, scale other decoder stages. + x = keras.layers.UpSampling2D(size=2, interpolation="bilinear")(x) + + x = keras.layers.concatenate([encoder_blocks[i], x], axis=-1) + x = convolution_block(x, filters=filters * 8) + x = convolution_block(x, filters=filters * 8) + x = convolution_block(x, filters=filters * 8) + decoder_blocks.append(x) + + decoder_blocks.reverse() # Change order from last to first decoder stage. + decoder_blocks.append(encoder_blocks[-1]) # Copy bridge to decoder. + + # -------------Side Outputs-------------- + decoder_blocks = [ + segmentation_head(decoder_block) # Prediction segmentation head. + for segmentation_head, decoder_block in zip( + segmentation_heads, decoder_blocks + ) + ] + + return keras.models.Model(inputs=[x_input], outputs=decoder_blocks) + + +def basnet_rrm(base_model, filters, segmentation_head): + """ + BASNet Residual Refinement Module (RRM). + + This module outputs a fine label map by integrating light encoder, + bridge, and decoder blocks. + + Args: + base_model: Keras model used as the base or coarse label map. + filters: int, the number of filters. + segmentation_head: a `keras.layers.Layer`, A Keras layer serving + as the segmentation head for refinement module. + + Returns: + A Keras Model that constructs the Residual Refinement Module (RRM). + """ + num_stages = 4 + + x_input = base_model.output[0] + + # -------------Encoder-------------- + x = keras.layers.Conv2D(filters, kernel_size=(3, 3), padding="same")( + x_input + ) + + encoder_blocks = [] + for _ in range(num_stages): + x = convolution_block(x, filters=filters) + encoder_blocks.append(x) + x = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(x) + + # -------------Bridge-------------- + x = convolution_block(x, filters=filters) + + # -------------Decoder-------------- + for i in reversed(range(num_stages)): + x = keras.layers.UpSampling2D(size=2, interpolation="bilinear")(x) + x = keras.layers.concatenate([encoder_blocks[i], x], axis=-1) + x = convolution_block(x, filters=filters) + + x = segmentation_head(x) # Refinement segmentation head. 
+ + # ------------- refined = coarse + residual + x = keras.layers.Add()([x_input, x]) # Add prediction + refinement output + + return keras.models.Model(inputs=base_model.input, outputs=[x]) diff --git a/keras_hub/src/models/basnet/basnet_image_converter.py b/keras_hub/src/models/basnet/basnet_image_converter.py new file mode 100644 index 0000000000..45f364fd4e --- /dev/null +++ b/keras_hub/src/models/basnet/basnet_image_converter.py @@ -0,0 +1,8 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone + + +@keras_hub_export("keras_hub.layers.BASNetImageConverter") +class BASNetImageConverter(ImageConverter): + backbone_cls = ResNetBackbone diff --git a/keras_hub/src/models/basnet/basnet_preprocessor.py b/keras_hub/src/models/basnet/basnet_preprocessor.py new file mode 100644 index 0000000000..7faa64cdd9 --- /dev/null +++ b/keras_hub/src/models/basnet/basnet_preprocessor.py @@ -0,0 +1,14 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.basnet.basnet_image_converter import ( + BASNetImageConverter, +) +from keras_hub.src.models.image_segmenter_preprocessor import ( + ImageSegmenterPreprocessor, +) +from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone + + +@keras_hub_export("keras_hub.models.BASNetPreprocessor") +class BASNetPreprocessor(ImageSegmenterPreprocessor): + backbone_cls = ResNetBackbone + image_converter_cls = BASNetImageConverter diff --git a/keras_hub/src/models/basnet/basnet_presets.py b/keras_hub/src/models/basnet/basnet_presets.py new file mode 100644 index 0000000000..1ed728643b --- /dev/null +++ b/keras_hub/src/models/basnet/basnet_presets.py @@ -0,0 +1,22 @@ +"""BASNet model preset configurations.""" + +basnet_presets = { + "basnet_resnet18": { + "metadata": { + "description": "BASNet with a ResNet18 v1 backbone.", + "params": 98780872, + "official_name": "BASNet", + "path": "basnet", + }, + "kaggle_handle": "", # TODO + }, + "basnet_resnet34": { + "metadata": { + "description": "BASNet with a ResNet34 v1 backbone.", + "params": 108896456, + "official_name": "BASNet", + "path": "basnet", + }, + "kaggle_handle": "", # TODO + }, +} diff --git a/keras_hub/src/models/basnet/basnet_test.py b/keras_hub/src/models/basnet/basnet_test.py new file mode 100644 index 0000000000..7e05f5c6d7 --- /dev/null +++ b/keras_hub/src/models/basnet/basnet_test.py @@ -0,0 +1,41 @@ +import pytest +from keras import ops + +from keras_hub.src.models.basnet.basnet import BASNet +from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor +from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_hub.src.tests.test_case import TestCase + + +class BASNetTest(TestCase): + def setUp(self): + self.images = ops.ones((2, 64, 64, 3)) + self.backbone = ResNetBackbone( + input_conv_filters=[64], + input_conv_kernel_sizes=[7], + stackwise_num_filters=[64, 128, 256, 512], + stackwise_num_blocks=[2, 2, 2, 2], + stackwise_num_strides=[1, 2, 2, 2], + block_type="basic_block", + ) + self.preprocessor = BASNetPreprocessor() + self.init_kwargs = { + "backbone": self.backbone, + "preprocessor": self.preprocessor, + "num_classes": 1, + } + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=BASNet, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) + + def test_end_to_end_model_predict(self): + model = BASNet(**self.init_kwargs) 
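        # BASNet returns eight sigmoid maps: the refinement output followed by
        # the seven side outputs of the prediction module, hence the `* 8`
        # in the shape assertion below.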
+ outputs = model.predict(self.images) + self.assertAllEqual( + [output.shape for output in outputs], [(2, 64, 64, 1)] * 8 + ) From 9c3d8d855f03fccfe0f4875add2d1ca1a4ff709b Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Tue, 12 Nov 2024 14:28:42 -0800 Subject: [PATCH 02/15] run api_gen.sh --- keras_hub/api/layers/__init__.py | 3 +++ keras_hub/api/models/__init__.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index 2a29cdb64e..80d42c1e43 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -34,6 +34,9 @@ from keras_hub.src.layers.preprocessing.random_deletion import RandomDeletion from keras_hub.src.layers.preprocessing.random_swap import RandomSwap from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker +from keras_hub.src.models.basnet.basnet_image_converter import ( + BASNetImageConverter, +) from keras_hub.src.models.clip.clip_image_converter import CLIPImageConverter from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import ( DeepLabV3ImageConverter, diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index dd85a97a45..53890d8105 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -29,6 +29,8 @@ BartSeq2SeqLMPreprocessor, ) from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer +from keras_hub.src.models.basnet.basnet import BASNet +from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor from keras_hub.src.models.bert.bert_backbone import BertBackbone from keras_hub.src.models.bert.bert_masked_lm import BertMaskedLM from keras_hub.src.models.bert.bert_masked_lm_preprocessor import ( From e51a89140e140949e9e2f37cb6744a01eb42b718 Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Thu, 5 Dec 2024 11:57:41 -0800 Subject: [PATCH 03/15] Addressing Matt minor changes except backbone class changes --- keras_hub/src/models/basnet/__init__.py | 4 +- keras_hub/src/models/basnet/basnet.py | 73 +++++++++++++++---- .../models/basnet/basnet_image_converter.py | 2 +- .../src/models/basnet/basnet_preprocessor.py | 2 +- keras_hub/src/models/basnet/basnet_presets.py | 2 +- keras_hub/src/models/basnet/basnet_test.py | 29 +++++++- 6 files changed, 89 insertions(+), 23 deletions(-) diff --git a/keras_hub/src/models/basnet/__init__.py b/keras_hub/src/models/basnet/__init__.py index 9a49eacce0..02d96b59cc 100644 --- a/keras_hub/src/models/basnet/__init__.py +++ b/keras_hub/src/models/basnet/__init__.py @@ -1,5 +1,5 @@ -from keras_hub.src.models.basnet.basnet import BASNet +from keras_hub.src.models.basnet.basnet import BASNetImageSegmenter from keras_hub.src.models.basnet.basnet_presets import basnet_presets from keras_hub.src.utils.preset_utils import register_presets -register_presets(basnet_presets, BASNet) +register_presets(basnet_presets, BASNetImageSegmenter) \ No newline at end of file diff --git a/keras_hub/src/models/basnet/basnet.py b/keras_hub/src/models/basnet/basnet.py index d7c44c17bc..9ce56771fe 100644 --- a/keras_hub/src/models/basnet/basnet.py +++ b/keras_hub/src/models/basnet/basnet.py @@ -9,12 +9,8 @@ ) -@keras_hub_export( - [ - "keras_hub.models.BASNet", - ] -) -class BASNet(ImageSegmenter): +@keras_hub_export("keras_hub.models.BASNetImageSegmenter") +class BASNetImageSegmenter(ImageSegmenter): """ A Keras model implementing the BASNet architecture for semantic segmentation. 
@@ -25,8 +21,7 @@ class BASNet(ImageSegmenter): Args: backbone: `keras.Model`. The backbone network for the model that is used as a feature extractor for BASNet prediction encoder. Currently - supported backbones are ResNet18 and ResNet34. Default backbone is - `keras_cv.models.ResNet34Backbone()`. + supported backbones are ResNet18 and ResNet34. (Note: Do not specify `image_shape` within the backbone. Please provide these while initializing the 'BASNet' model.) num_classes: int, the number of classes for the segmentation model. @@ -42,7 +37,6 @@ class BASNet(ImageSegmenter): Example: ```python - import keras_hub images = np.ones(shape=(1, 288, 288, 3)) @@ -53,7 +47,7 @@ class BASNet(ImageSegmenter): "resnet_18_imagenet", load_weights=False ) - model = keras_hub.models.BASNet( + model = keras_hub.models.BASNetImageSegmenter( backbone=backbone, num_classes=1, image_shape=[288, 288, 3], @@ -70,8 +64,8 @@ class BASNet(ImageSegmenter): metrics=["accuracy"], ) model.fit(images, labels, epochs=3) - ``` - """ # noqa: E501 + ``` + """ backbone_cls = ResNetBackbone preprocessor_cls = BASNetPreprocessor @@ -145,9 +139,13 @@ def __init__( outputs = refine_model.outputs # Combine outputs. outputs.extend(predict_model.outputs) + output_names = ["refine_out"] + [ + f"predict_out_{i}" for i in range(1, len(outputs)) + ] + outputs = [ - keras.layers.Activation("sigmoid", dtype="float32")(_) - for _ in outputs + keras.layers.Activation("sigmoid", name=output_name)(output) + for output, output_name in zip(outputs, output_names) ] super().__init__(inputs=inputs, outputs=outputs, **kwargs) @@ -161,6 +159,51 @@ def __init__( self.refinement_head = refinement_head self.preprocessor = preprocessor + def compile( + self, + optimizer="auto", + loss="auto", + metrics="auto", + **kwargs, + ): + """Configures the `BASNet` task for training. + + `BASNet` extends the default compilation signature + of `keras.Model.compile` with defaults for `optimizer` and `loss`. To + override these defaults, pass any value to these arguments during + compilation. + + Args: + optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer` + instance. Defaults to `"auto"`, which uses the default + optimizer for `BASNet`. See `keras.Model.compile` and + `keras.optimizers` for more info on possible `optimizer` + values. + loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance. + Defaults to `"auto"`, in which case the default loss + computation of `BASNet` will be applied. + See `keras.Model.compile` and `keras.losses` for more info on + possible `loss` values. + metrics: `"auto"`, or a list of metrics to be evaluated by + the model during training and testing. Defaults to `"auto"`, + where a `keras.metrics.Accuracy` will be applied to track the + accuracy of the model during training. + See `keras.Model.compile` and `keras.metrics` for + more info on possible `metrics` values. + **kwargs: See `keras.Model.compile` for a full list of arguments + supported by the compile method. 
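
        For example, assuming `model` is a constructed `BASNet` task, the
        defaults can be overridden with standard Keras objects:

        ```python
        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=1e-4),
            loss=keras.losses.BinaryCrossentropy(),
            metrics=[keras.metrics.Accuracy()],
        )
        ```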
+ """ + if loss == "auto": + loss = keras.losses.BinaryCrossentropy() + if metrics == "auto": + metrics = {"refine_out": keras.metrics.Accuracy()} + super().compile( + optimizer=optimizer, + loss=loss, + metrics=metrics, + **kwargs, + ) + def get_config(self): config = super().get_config() config.update( @@ -368,4 +411,4 @@ def basnet_rrm(base_model, filters, segmentation_head): # ------------- refined = coarse + residual x = keras.layers.Add()([x_input, x]) # Add prediction + refinement output - return keras.models.Model(inputs=base_model.input, outputs=[x]) + return keras.models.Model(inputs=base_model.input, outputs=[x]) \ No newline at end of file diff --git a/keras_hub/src/models/basnet/basnet_image_converter.py b/keras_hub/src/models/basnet/basnet_image_converter.py index 45f364fd4e..467367f75f 100644 --- a/keras_hub/src/models/basnet/basnet_image_converter.py +++ b/keras_hub/src/models/basnet/basnet_image_converter.py @@ -5,4 +5,4 @@ @keras_hub_export("keras_hub.layers.BASNetImageConverter") class BASNetImageConverter(ImageConverter): - backbone_cls = ResNetBackbone + backbone_cls = ResNetBackbone \ No newline at end of file diff --git a/keras_hub/src/models/basnet/basnet_preprocessor.py b/keras_hub/src/models/basnet/basnet_preprocessor.py index 7faa64cdd9..e93fba5ed9 100644 --- a/keras_hub/src/models/basnet/basnet_preprocessor.py +++ b/keras_hub/src/models/basnet/basnet_preprocessor.py @@ -11,4 +11,4 @@ @keras_hub_export("keras_hub.models.BASNetPreprocessor") class BASNetPreprocessor(ImageSegmenterPreprocessor): backbone_cls = ResNetBackbone - image_converter_cls = BASNetImageConverter + image_converter_cls = BASNetImageConverter \ No newline at end of file diff --git a/keras_hub/src/models/basnet/basnet_presets.py b/keras_hub/src/models/basnet/basnet_presets.py index 1ed728643b..7a4b0a27d3 100644 --- a/keras_hub/src/models/basnet/basnet_presets.py +++ b/keras_hub/src/models/basnet/basnet_presets.py @@ -19,4 +19,4 @@ }, "kaggle_handle": "", # TODO }, -} +} \ No newline at end of file diff --git a/keras_hub/src/models/basnet/basnet_test.py b/keras_hub/src/models/basnet/basnet_test.py index 7e05f5c6d7..ca5baf5685 100644 --- a/keras_hub/src/models/basnet/basnet_test.py +++ b/keras_hub/src/models/basnet/basnet_test.py @@ -1,7 +1,7 @@ import pytest from keras import ops -from keras_hub.src.models.basnet.basnet import BASNet +from keras_hub.src.models.basnet.basnet import BASNetImageSegmenter from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone from keras_hub.src.tests.test_case import TestCase @@ -10,6 +10,9 @@ class BASNetTest(TestCase): def setUp(self): self.images = ops.ones((2, 64, 64, 3)) + self.labels = ops.concatenate( + (ops.zeros((2, 32, 64, 1)), ops.ones((2, 32, 64, 1))), axis=1 + ) self.backbone = ResNetBackbone( input_conv_filters=[64], input_conv_kernel_sizes=[7], @@ -24,18 +27,38 @@ def setUp(self): "preprocessor": self.preprocessor, "num_classes": 1, } + self.train_data = (self.images, self.labels) + + def test_basics(self): + self.run_task_test( + cls=BASNetImageSegmenter, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=[(2, 64, 64, 1)] * 8, + ) @pytest.mark.large def test_saved_model(self): self.run_model_saving_test( - cls=BASNet, + cls=BASNetImageSegmenter, init_kwargs=self.init_kwargs, input_data=self.images, ) def test_end_to_end_model_predict(self): - model = BASNet(**self.init_kwargs) + model = 
BASNetImageSegmenter(**self.init_kwargs) outputs = model.predict(self.images) self.assertAllEqual( [output.shape for output in outputs], [(2, 64, 64, 1)] * 8 ) + + @pytest.mark.skip(reason="disabled until preset's been uploaded to Kaggle") + @pytest.mark.extra_large + def test_all_presets(self): + for preset in BASNetImageSegmenter.presets: + self.run_preset_test( + cls=BASNetImageSegmenter, + preset=preset, + input_data=self.images, + expected_output_shape=[(2, 64, 64, 1)] * 8, + ) \ No newline at end of file From 4d5dc6ccff4543f704bfecffb06261eedfda9c1e Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Thu, 5 Dec 2024 13:15:22 -0800 Subject: [PATCH 04/15] Fixing lint errors and run api_gen --- keras_hub/api/models/__init__.py | 2 +- keras_hub/src/models/basnet/__init__.py | 2 +- keras_hub/src/models/basnet/basnet.py | 2 +- keras_hub/src/models/basnet/basnet_image_converter.py | 2 +- keras_hub/src/models/basnet/basnet_preprocessor.py | 2 +- keras_hub/src/models/basnet/basnet_presets.py | 2 +- keras_hub/src/models/basnet/basnet_test.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 9d28a4df63..44478a0db4 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -29,7 +29,7 @@ BartSeq2SeqLMPreprocessor, ) from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer -from keras_hub.src.models.basnet.basnet import BASNet +from keras_hub.src.models.basnet.basnet import BASNetImageSegmenter from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor from keras_hub.src.models.bert.bert_backbone import BertBackbone from keras_hub.src.models.bert.bert_masked_lm import BertMaskedLM diff --git a/keras_hub/src/models/basnet/__init__.py b/keras_hub/src/models/basnet/__init__.py index 02d96b59cc..340f40d09a 100644 --- a/keras_hub/src/models/basnet/__init__.py +++ b/keras_hub/src/models/basnet/__init__.py @@ -2,4 +2,4 @@ from keras_hub.src.models.basnet.basnet_presets import basnet_presets from keras_hub.src.utils.preset_utils import register_presets -register_presets(basnet_presets, BASNetImageSegmenter) \ No newline at end of file +register_presets(basnet_presets, BASNetImageSegmenter) diff --git a/keras_hub/src/models/basnet/basnet.py b/keras_hub/src/models/basnet/basnet.py index 9ce56771fe..420eee6185 100644 --- a/keras_hub/src/models/basnet/basnet.py +++ b/keras_hub/src/models/basnet/basnet.py @@ -411,4 +411,4 @@ def basnet_rrm(base_model, filters, segmentation_head): # ------------- refined = coarse + residual x = keras.layers.Add()([x_input, x]) # Add prediction + refinement output - return keras.models.Model(inputs=base_model.input, outputs=[x]) \ No newline at end of file + return keras.models.Model(inputs=base_model.input, outputs=[x]) diff --git a/keras_hub/src/models/basnet/basnet_image_converter.py b/keras_hub/src/models/basnet/basnet_image_converter.py index 467367f75f..45f364fd4e 100644 --- a/keras_hub/src/models/basnet/basnet_image_converter.py +++ b/keras_hub/src/models/basnet/basnet_image_converter.py @@ -5,4 +5,4 @@ @keras_hub_export("keras_hub.layers.BASNetImageConverter") class BASNetImageConverter(ImageConverter): - backbone_cls = ResNetBackbone \ No newline at end of file + backbone_cls = ResNetBackbone diff --git a/keras_hub/src/models/basnet/basnet_preprocessor.py b/keras_hub/src/models/basnet/basnet_preprocessor.py index e93fba5ed9..7faa64cdd9 100644 --- a/keras_hub/src/models/basnet/basnet_preprocessor.py +++ 
b/keras_hub/src/models/basnet/basnet_preprocessor.py @@ -11,4 +11,4 @@ @keras_hub_export("keras_hub.models.BASNetPreprocessor") class BASNetPreprocessor(ImageSegmenterPreprocessor): backbone_cls = ResNetBackbone - image_converter_cls = BASNetImageConverter \ No newline at end of file + image_converter_cls = BASNetImageConverter diff --git a/keras_hub/src/models/basnet/basnet_presets.py b/keras_hub/src/models/basnet/basnet_presets.py index 7a4b0a27d3..1ed728643b 100644 --- a/keras_hub/src/models/basnet/basnet_presets.py +++ b/keras_hub/src/models/basnet/basnet_presets.py @@ -19,4 +19,4 @@ }, "kaggle_handle": "", # TODO }, -} \ No newline at end of file +} diff --git a/keras_hub/src/models/basnet/basnet_test.py b/keras_hub/src/models/basnet/basnet_test.py index ca5baf5685..fa7c1c9b64 100644 --- a/keras_hub/src/models/basnet/basnet_test.py +++ b/keras_hub/src/models/basnet/basnet_test.py @@ -61,4 +61,4 @@ def test_all_presets(self): preset=preset, input_data=self.images, expected_output_shape=[(2, 64, 64, 1)] * 8, - ) \ No newline at end of file + ) From f018bbfdf205509ccc04ad347c81b6c41345be8e Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Fri, 13 Dec 2024 14:24:28 -0800 Subject: [PATCH 05/15] separate backbone, add compute_loss --- keras_hub/api/models/__init__.py | 354 +++++------------ keras_hub/src/models/basnet/__init__.py | 4 +- keras_hub/src/models/basnet/basnet.py | 348 ++-------------- .../src/models/basnet/basnet_backbone.py | 370 ++++++++++++++++++ .../src/models/basnet/basnet_backbone_test.py | 36 ++ .../models/basnet/basnet_image_converter.py | 4 +- .../src/models/basnet/basnet_preprocessor.py | 6 +- keras_hub/src/models/basnet/basnet_presets.py | 2 +- keras_hub/src/models/basnet/basnet_test.py | 20 +- 9 files changed, 546 insertions(+), 598 deletions(-) create mode 100644 keras_hub/src/models/basnet/basnet_backbone.py create mode 100644 keras_hub/src/models/basnet/basnet_backbone_test.py diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 9748bc40b6..660263dff3 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -4,54 +4,34 @@ since your modifications would be overwritten. 
""" + from keras_hub.src.models.albert.albert_backbone import AlbertBackbone from keras_hub.src.models.albert.albert_masked_lm import AlbertMaskedLM -from keras_hub.src.models.albert.albert_masked_lm_preprocessor import ( - AlbertMaskedLMPreprocessor, -) -from keras_hub.src.models.albert.albert_text_classifier import ( - AlbertTextClassifier, -) -from keras_hub.src.models.albert.albert_text_classifier import ( - AlbertTextClassifier as AlbertClassifier, -) -from keras_hub.src.models.albert.albert_text_classifier_preprocessor import ( - AlbertTextClassifierPreprocessor, -) -from keras_hub.src.models.albert.albert_text_classifier_preprocessor import ( - AlbertTextClassifierPreprocessor as AlbertPreprocessor, -) +from keras_hub.src.models.albert.albert_masked_lm_preprocessor import AlbertMaskedLMPreprocessor +from keras_hub.src.models.albert.albert_text_classifier import AlbertTextClassifier +from keras_hub.src.models.albert.albert_text_classifier import AlbertTextClassifier as AlbertClassifier +from keras_hub.src.models.albert.albert_text_classifier_preprocessor import AlbertTextClassifierPreprocessor +from keras_hub.src.models.albert.albert_text_classifier_preprocessor import AlbertTextClassifierPreprocessor as AlbertPreprocessor from keras_hub.src.models.albert.albert_tokenizer import AlbertTokenizer from keras_hub.src.models.backbone import Backbone from keras_hub.src.models.bart.bart_backbone import BartBackbone from keras_hub.src.models.bart.bart_seq_2_seq_lm import BartSeq2SeqLM -from keras_hub.src.models.bart.bart_seq_2_seq_lm_preprocessor import ( - BartSeq2SeqLMPreprocessor, -) +from keras_hub.src.models.bart.bart_seq_2_seq_lm_preprocessor import BartSeq2SeqLMPreprocessor from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer from keras_hub.src.models.basnet.basnet import BASNetImageSegmenter +from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor from keras_hub.src.models.bert.bert_backbone import BertBackbone from keras_hub.src.models.bert.bert_masked_lm import BertMaskedLM -from keras_hub.src.models.bert.bert_masked_lm_preprocessor import ( - BertMaskedLMPreprocessor, -) +from keras_hub.src.models.bert.bert_masked_lm_preprocessor import BertMaskedLMPreprocessor from keras_hub.src.models.bert.bert_text_classifier import BertTextClassifier -from keras_hub.src.models.bert.bert_text_classifier import ( - BertTextClassifier as BertClassifier, -) -from keras_hub.src.models.bert.bert_text_classifier_preprocessor import ( - BertTextClassifierPreprocessor, -) -from keras_hub.src.models.bert.bert_text_classifier_preprocessor import ( - BertTextClassifierPreprocessor as BertPreprocessor, -) +from keras_hub.src.models.bert.bert_text_classifier import BertTextClassifier as BertClassifier +from keras_hub.src.models.bert.bert_text_classifier_preprocessor import BertTextClassifierPreprocessor +from keras_hub.src.models.bert.bert_text_classifier_preprocessor import BertTextClassifierPreprocessor as BertPreprocessor from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer from keras_hub.src.models.bloom.bloom_backbone import BloomBackbone from keras_hub.src.models.bloom.bloom_causal_lm import BloomCausalLM -from keras_hub.src.models.bloom.bloom_causal_lm_preprocessor import ( - BloomCausalLMPreprocessor, -) +from keras_hub.src.models.bloom.bloom_causal_lm_preprocessor import BloomCausalLMPreprocessor from keras_hub.src.models.bloom.bloom_tokenizer import BloomTokenizer from 
keras_hub.src.models.causal_lm import CausalLM from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor @@ -60,309 +40,155 @@ from keras_hub.src.models.clip.clip_text_encoder import CLIPTextEncoder from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer from keras_hub.src.models.clip.clip_vision_encoder import CLIPVisionEncoder -from keras_hub.src.models.csp_darknet.csp_darknet_backbone import ( - CSPDarkNetBackbone, -) -from keras_hub.src.models.csp_darknet.csp_darknet_image_classifier import ( - CSPDarkNetImageClassifier, -) -from keras_hub.src.models.deberta_v3.deberta_v3_backbone import ( - DebertaV3Backbone, -) -from keras_hub.src.models.deberta_v3.deberta_v3_masked_lm import ( - DebertaV3MaskedLM, -) -from keras_hub.src.models.deberta_v3.deberta_v3_masked_lm_preprocessor import ( - DebertaV3MaskedLMPreprocessor, -) -from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier import ( - DebertaV3TextClassifier, -) -from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier import ( - DebertaV3TextClassifier as DebertaV3Classifier, -) -from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor import ( - DebertaV3TextClassifierPreprocessor, -) -from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor import ( - DebertaV3TextClassifierPreprocessor as DebertaV3Preprocessor, -) -from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import ( - DebertaV3Tokenizer, -) -from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import ( - DeepLabV3Backbone, -) -from keras_hub.src.models.deeplab_v3.deeplab_v3_image_segmeter_preprocessor import ( - DeepLabV3ImageSegmenterPreprocessor, -) -from keras_hub.src.models.deeplab_v3.deeplab_v3_segmenter import ( - DeepLabV3ImageSegmenter, -) +from keras_hub.src.models.csp_darknet.csp_darknet_backbone import CSPDarkNetBackbone +from keras_hub.src.models.csp_darknet.csp_darknet_image_classifier import CSPDarkNetImageClassifier +from keras_hub.src.models.deberta_v3.deberta_v3_backbone import DebertaV3Backbone +from keras_hub.src.models.deberta_v3.deberta_v3_masked_lm import DebertaV3MaskedLM +from keras_hub.src.models.deberta_v3.deberta_v3_masked_lm_preprocessor import DebertaV3MaskedLMPreprocessor +from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier import DebertaV3TextClassifier +from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier import DebertaV3TextClassifier as DebertaV3Classifier +from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor import DebertaV3TextClassifierPreprocessor +from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor import DebertaV3TextClassifierPreprocessor as DebertaV3Preprocessor +from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import DebertaV3Tokenizer +from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import DeepLabV3Backbone +from keras_hub.src.models.deeplab_v3.deeplab_v3_image_segmeter_preprocessor import DeepLabV3ImageSegmenterPreprocessor +from keras_hub.src.models.deeplab_v3.deeplab_v3_segmenter import DeepLabV3ImageSegmenter from keras_hub.src.models.densenet.densenet_backbone import DenseNetBackbone -from keras_hub.src.models.densenet.densenet_image_classifier import ( - DenseNetImageClassifier, -) -from keras_hub.src.models.densenet.densenet_image_classifier_preprocessor import ( - DenseNetImageClassifierPreprocessor, -) -from keras_hub.src.models.distil_bert.distil_bert_backbone import ( - DistilBertBackbone, -) -from 
keras_hub.src.models.distil_bert.distil_bert_masked_lm import ( - DistilBertMaskedLM, -) -from keras_hub.src.models.distil_bert.distil_bert_masked_lm_preprocessor import ( - DistilBertMaskedLMPreprocessor, -) -from keras_hub.src.models.distil_bert.distil_bert_text_classifier import ( - DistilBertTextClassifier, -) -from keras_hub.src.models.distil_bert.distil_bert_text_classifier import ( - DistilBertTextClassifier as DistilBertClassifier, -) -from keras_hub.src.models.distil_bert.distil_bert_text_classifier_preprocessor import ( - DistilBertTextClassifierPreprocessor, -) -from keras_hub.src.models.distil_bert.distil_bert_text_classifier_preprocessor import ( - DistilBertTextClassifierPreprocessor as DistilBertPreprocessor, -) -from keras_hub.src.models.distil_bert.distil_bert_tokenizer import ( - DistilBertTokenizer, -) -from keras_hub.src.models.efficientnet.efficientnet_backbone import ( - EfficientNetBackbone, -) -from keras_hub.src.models.efficientnet.efficientnet_image_classifier import ( - EfficientNetImageClassifier, -) -from keras_hub.src.models.efficientnet.efficientnet_image_classifier_preprocessor import ( - EfficientNetImageClassifierPreprocessor, -) +from keras_hub.src.models.densenet.densenet_image_classifier import DenseNetImageClassifier +from keras_hub.src.models.densenet.densenet_image_classifier_preprocessor import DenseNetImageClassifierPreprocessor +from keras_hub.src.models.distil_bert.distil_bert_backbone import DistilBertBackbone +from keras_hub.src.models.distil_bert.distil_bert_masked_lm import DistilBertMaskedLM +from keras_hub.src.models.distil_bert.distil_bert_masked_lm_preprocessor import DistilBertMaskedLMPreprocessor +from keras_hub.src.models.distil_bert.distil_bert_text_classifier import DistilBertTextClassifier +from keras_hub.src.models.distil_bert.distil_bert_text_classifier import DistilBertTextClassifier as DistilBertClassifier +from keras_hub.src.models.distil_bert.distil_bert_text_classifier_preprocessor import DistilBertTextClassifierPreprocessor +from keras_hub.src.models.distil_bert.distil_bert_text_classifier_preprocessor import DistilBertTextClassifierPreprocessor as DistilBertPreprocessor +from keras_hub.src.models.distil_bert.distil_bert_tokenizer import DistilBertTokenizer +from keras_hub.src.models.efficientnet.efficientnet_backbone import EfficientNetBackbone +from keras_hub.src.models.efficientnet.efficientnet_image_classifier import EfficientNetImageClassifier +from keras_hub.src.models.efficientnet.efficientnet_image_classifier_preprocessor import EfficientNetImageClassifierPreprocessor from keras_hub.src.models.electra.electra_backbone import ElectraBackbone from keras_hub.src.models.electra.electra_tokenizer import ElectraTokenizer from keras_hub.src.models.f_net.f_net_backbone import FNetBackbone from keras_hub.src.models.f_net.f_net_masked_lm import FNetMaskedLM -from keras_hub.src.models.f_net.f_net_masked_lm_preprocessor import ( - FNetMaskedLMPreprocessor, -) +from keras_hub.src.models.f_net.f_net_masked_lm_preprocessor import FNetMaskedLMPreprocessor from keras_hub.src.models.f_net.f_net_text_classifier import FNetTextClassifier -from keras_hub.src.models.f_net.f_net_text_classifier import ( - FNetTextClassifier as FNetClassifier, -) -from keras_hub.src.models.f_net.f_net_text_classifier_preprocessor import ( - FNetTextClassifierPreprocessor, -) -from keras_hub.src.models.f_net.f_net_text_classifier_preprocessor import ( - FNetTextClassifierPreprocessor as FNetPreprocessor, -) +from 
keras_hub.src.models.f_net.f_net_text_classifier import FNetTextClassifier as FNetClassifier +from keras_hub.src.models.f_net.f_net_text_classifier_preprocessor import FNetTextClassifierPreprocessor +from keras_hub.src.models.f_net.f_net_text_classifier_preprocessor import FNetTextClassifierPreprocessor as FNetPreprocessor from keras_hub.src.models.f_net.f_net_tokenizer import FNetTokenizer from keras_hub.src.models.falcon.falcon_backbone import FalconBackbone from keras_hub.src.models.falcon.falcon_causal_lm import FalconCausalLM -from keras_hub.src.models.falcon.falcon_causal_lm_preprocessor import ( - FalconCausalLMPreprocessor, -) +from keras_hub.src.models.falcon.falcon_causal_lm_preprocessor import FalconCausalLMPreprocessor from keras_hub.src.models.falcon.falcon_tokenizer import FalconTokenizer from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone from keras_hub.src.models.flux.flux_model import FluxBackbone from keras_hub.src.models.flux.flux_text_to_image import FluxTextToImage -from keras_hub.src.models.flux.flux_text_to_image_preprocessor import ( - FluxTextToImagePreprocessor, -) +from keras_hub.src.models.flux.flux_text_to_image_preprocessor import FluxTextToImagePreprocessor from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM -from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import ( - GemmaCausalLMPreprocessor, -) +from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import GemmaCausalLMPreprocessor from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer from keras_hub.src.models.gpt2.gpt2_backbone import GPT2Backbone from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM -from keras_hub.src.models.gpt2.gpt2_causal_lm_preprocessor import ( - GPT2CausalLMPreprocessor, -) +from keras_hub.src.models.gpt2.gpt2_causal_lm_preprocessor import GPT2CausalLMPreprocessor from keras_hub.src.models.gpt2.gpt2_preprocessor import GPT2Preprocessor from keras_hub.src.models.gpt2.gpt2_tokenizer import GPT2Tokenizer from keras_hub.src.models.gpt_neo_x.gpt_neo_x_backbone import GPTNeoXBackbone from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm import GPTNeoXCausalLM -from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm_preprocessor import ( - GPTNeoXCausalLMPreprocessor, -) +from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm_preprocessor import GPTNeoXCausalLMPreprocessor from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer from keras_hub.src.models.image_classifier import ImageClassifier -from keras_hub.src.models.image_classifier_preprocessor import ( - ImageClassifierPreprocessor, -) +from keras_hub.src.models.image_classifier_preprocessor import ImageClassifierPreprocessor from keras_hub.src.models.image_object_detector import ImageObjectDetector -from keras_hub.src.models.image_object_detector_preprocessor import ( - ImageObjectDetectorPreprocessor, -) +from keras_hub.src.models.image_object_detector_preprocessor import ImageObjectDetectorPreprocessor from keras_hub.src.models.image_segmenter import ImageSegmenter -from keras_hub.src.models.image_segmenter_preprocessor import ( - ImageSegmenterPreprocessor, -) +from keras_hub.src.models.image_segmenter_preprocessor import ImageSegmenterPreprocessor from keras_hub.src.models.image_to_image import ImageToImage from keras_hub.src.models.inpaint import Inpaint from keras_hub.src.models.llama.llama_backbone import LlamaBackbone from 
keras_hub.src.models.llama.llama_causal_lm import LlamaCausalLM -from keras_hub.src.models.llama.llama_causal_lm_preprocessor import ( - LlamaCausalLMPreprocessor, -) +from keras_hub.src.models.llama.llama_causal_lm_preprocessor import LlamaCausalLMPreprocessor from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone from keras_hub.src.models.llama3.llama3_causal_lm import Llama3CausalLM -from keras_hub.src.models.llama3.llama3_causal_lm_preprocessor import ( - Llama3CausalLMPreprocessor, -) +from keras_hub.src.models.llama3.llama3_causal_lm_preprocessor import Llama3CausalLMPreprocessor from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer from keras_hub.src.models.masked_lm import MaskedLM from keras_hub.src.models.masked_lm_preprocessor import MaskedLMPreprocessor from keras_hub.src.models.mistral.mistral_backbone import MistralBackbone from keras_hub.src.models.mistral.mistral_causal_lm import MistralCausalLM -from keras_hub.src.models.mistral.mistral_causal_lm_preprocessor import ( - MistralCausalLMPreprocessor, -) +from keras_hub.src.models.mistral.mistral_causal_lm_preprocessor import MistralCausalLMPreprocessor from keras_hub.src.models.mistral.mistral_tokenizer import MistralTokenizer from keras_hub.src.models.mit.mit_backbone import MiTBackbone from keras_hub.src.models.mit.mit_image_classifier import MiTImageClassifier -from keras_hub.src.models.mit.mit_image_classifier_preprocessor import ( - MiTImageClassifierPreprocessor, -) +from keras_hub.src.models.mit.mit_image_classifier_preprocessor import MiTImageClassifierPreprocessor from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone -from keras_hub.src.models.mobilenet.mobilenet_image_classifier import ( - MobileNetImageClassifier, -) +from keras_hub.src.models.mobilenet.mobilenet_image_classifier import MobileNetImageClassifier from keras_hub.src.models.opt.opt_backbone import OPTBackbone from keras_hub.src.models.opt.opt_causal_lm import OPTCausalLM -from keras_hub.src.models.opt.opt_causal_lm_preprocessor import ( - OPTCausalLMPreprocessor, -) +from keras_hub.src.models.opt.opt_causal_lm_preprocessor import OPTCausalLMPreprocessor from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer -from keras_hub.src.models.pali_gemma.pali_gemma_backbone import ( - PaliGemmaBackbone, -) -from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm import ( - PaliGemmaCausalLM, -) -from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm_preprocessor import ( - PaliGemmaCausalLMPreprocessor, -) -from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import ( - PaliGemmaTokenizer, -) +from keras_hub.src.models.pali_gemma.pali_gemma_backbone import PaliGemmaBackbone +from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm import PaliGemmaCausalLM +from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm_preprocessor import PaliGemmaCausalLMPreprocessor +from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import PaliGemmaTokenizer from keras_hub.src.models.phi3.phi3_backbone import Phi3Backbone from keras_hub.src.models.phi3.phi3_causal_lm import Phi3CausalLM -from keras_hub.src.models.phi3.phi3_causal_lm_preprocessor import ( - Phi3CausalLMPreprocessor, -) +from keras_hub.src.models.phi3.phi3_causal_lm_preprocessor import Phi3CausalLMPreprocessor from keras_hub.src.models.phi3.phi3_tokenizer import Phi3Tokenizer from keras_hub.src.models.preprocessor import Preprocessor from 
keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone -from keras_hub.src.models.resnet.resnet_image_classifier import ( - ResNetImageClassifier, -) -from keras_hub.src.models.resnet.resnet_image_classifier_preprocessor import ( - ResNetImageClassifierPreprocessor, -) +from keras_hub.src.models.resnet.resnet_image_classifier import ResNetImageClassifier +from keras_hub.src.models.resnet.resnet_image_classifier_preprocessor import ResNetImageClassifierPreprocessor from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone -from keras_hub.src.models.retinanet.retinanet_object_detector import ( - RetinaNetObjectDetector, -) -from keras_hub.src.models.retinanet.retinanet_object_detector_preprocessor import ( - RetinaNetObjectDetectorPreprocessor, -) +from keras_hub.src.models.retinanet.retinanet_object_detector import RetinaNetObjectDetector +from keras_hub.src.models.retinanet.retinanet_object_detector_preprocessor import RetinaNetObjectDetectorPreprocessor from keras_hub.src.models.roberta.roberta_backbone import RobertaBackbone from keras_hub.src.models.roberta.roberta_masked_lm import RobertaMaskedLM -from keras_hub.src.models.roberta.roberta_masked_lm_preprocessor import ( - RobertaMaskedLMPreprocessor, -) -from keras_hub.src.models.roberta.roberta_text_classifier import ( - RobertaTextClassifier, -) -from keras_hub.src.models.roberta.roberta_text_classifier import ( - RobertaTextClassifier as RobertaClassifier, -) -from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import ( - RobertaTextClassifierPreprocessor, -) -from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import ( - RobertaTextClassifierPreprocessor as RobertaPreprocessor, -) +from keras_hub.src.models.roberta.roberta_masked_lm_preprocessor import RobertaMaskedLMPreprocessor +from keras_hub.src.models.roberta.roberta_text_classifier import RobertaTextClassifier +from keras_hub.src.models.roberta.roberta_text_classifier import RobertaTextClassifier as RobertaClassifier +from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import RobertaTextClassifierPreprocessor +from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import RobertaTextClassifierPreprocessor as RobertaPreprocessor from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer from keras_hub.src.models.sam.sam_backbone import SAMBackbone from keras_hub.src.models.sam.sam_image_segmenter import SAMImageSegmenter -from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import ( - SAMImageSegmenterPreprocessor, -) +from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import SAMImageSegmenterPreprocessor from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone -from keras_hub.src.models.segformer.segformer_image_segmenter import ( - SegFormerImageSegmenter, -) -from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import ( - SegFormerImageSegmenterPreprocessor, -) +from keras_hub.src.models.segformer.segformer_image_segmenter import SegFormerImageSegmenter +from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import SegFormerImageSegmenterPreprocessor from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor -from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import ( - StableDiffusion3Backbone, -) -from 
keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_image_to_image import ( - StableDiffusion3ImageToImage, -) -from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_inpaint import ( - StableDiffusion3Inpaint, -) -from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image import ( - StableDiffusion3TextToImage, -) -from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import ( - StableDiffusion3TextToImagePreprocessor, -) +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import StableDiffusion3Backbone +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_image_to_image import StableDiffusion3ImageToImage +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_inpaint import StableDiffusion3Inpaint +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image import StableDiffusion3TextToImage +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import StableDiffusion3TextToImagePreprocessor from keras_hub.src.models.t5.t5_backbone import T5Backbone from keras_hub.src.models.t5.t5_preprocessor import T5Preprocessor from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer from keras_hub.src.models.task import Task from keras_hub.src.models.text_classifier import TextClassifier from keras_hub.src.models.text_classifier import TextClassifier as Classifier -from keras_hub.src.models.text_classifier_preprocessor import ( - TextClassifierPreprocessor, -) +from keras_hub.src.models.text_classifier_preprocessor import TextClassifierPreprocessor from keras_hub.src.models.text_to_image import TextToImage from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier -from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import ( - VGGImageClassifierPreprocessor, -) +from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import VGGImageClassifierPreprocessor from keras_hub.src.models.vit.vit_backbone import ViTBackbone from keras_hub.src.models.vit.vit_image_classifier import ViTImageClassifier -from keras_hub.src.models.vit.vit_image_classifier_preprocessor import ( - ViTImageClassifierPreprocessor, -) +from keras_hub.src.models.vit.vit_image_classifier_preprocessor import ViTImageClassifierPreprocessor from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer -from keras_hub.src.models.xlm_roberta.xlm_roberta_backbone import ( - XLMRobertaBackbone, -) -from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm import ( - XLMRobertaMaskedLM, -) -from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm_preprocessor import ( - XLMRobertaMaskedLMPreprocessor, -) -from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier import ( - XLMRobertaTextClassifier, -) -from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier import ( - XLMRobertaTextClassifier as XLMRobertaClassifier, -) -from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import ( - XLMRobertaTextClassifierPreprocessor, -) -from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import ( - XLMRobertaTextClassifierPreprocessor as XLMRobertaPreprocessor, -) -from keras_hub.src.models.xlm_roberta.xlm_roberta_tokenizer import ( - 
XLMRobertaTokenizer, -) +from keras_hub.src.models.xlm_roberta.xlm_roberta_backbone import XLMRobertaBackbone +from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm import XLMRobertaMaskedLM +from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm_preprocessor import XLMRobertaMaskedLMPreprocessor +from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier import XLMRobertaTextClassifier +from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier import XLMRobertaTextClassifier as XLMRobertaClassifier +from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import XLMRobertaTextClassifierPreprocessor +from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import XLMRobertaTextClassifierPreprocessor as XLMRobertaPreprocessor +from keras_hub.src.models.xlm_roberta.xlm_roberta_tokenizer import XLMRobertaTokenizer from keras_hub.src.models.xlnet.xlnet_backbone import XLNetBackbone from keras_hub.src.tokenizers.tokenizer import Tokenizer diff --git a/keras_hub/src/models/basnet/__init__.py b/keras_hub/src/models/basnet/__init__.py index 340f40d09a..8df584cab4 100644 --- a/keras_hub/src/models/basnet/__init__.py +++ b/keras_hub/src/models/basnet/__init__.py @@ -1,5 +1,5 @@ -from keras_hub.src.models.basnet.basnet import BASNetImageSegmenter +from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone from keras_hub.src.models.basnet.basnet_presets import basnet_presets from keras_hub.src.utils.preset_utils import register_presets -register_presets(basnet_presets, BASNetImageSegmenter) +register_presets(basnet_presets, BASNetBackbone) \ No newline at end of file diff --git a/keras_hub/src/models/basnet/basnet.py b/keras_hub/src/models/basnet/basnet.py index 420eee6185..8fc304d073 100644 --- a/keras_hub/src/models/basnet/basnet.py +++ b/keras_hub/src/models/basnet/basnet.py @@ -1,12 +1,9 @@ import keras from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor from keras_hub.src.models.image_segmenter import ImageSegmenter -from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone -from keras_hub.src.models.resnet.resnet_backbone import ( - apply_basic_block as resnet_basic_block, -) @keras_hub_export("keras_hub.models.BASNetImageSegmenter") @@ -16,24 +13,13 @@ class BASNetImageSegmenter(ImageSegmenter): segmentation. References: - - [BASNet: Boundary-Aware Segmentation Network for Mobile and Web Applications](https://arxiv.org/abs/2101.04704) + [BASNet: Boundary-Aware Segmentation Network for Mobile and Web Applications](https://arxiv.org/abs/2101.04704) Args: - backbone: `keras.Model`. The backbone network for the model that is - used as a feature extractor for BASNet prediction encoder. Currently - supported backbones are ResNet18 and ResNet34. - (Note: Do not specify `image_shape` within the backbone. - Please provide these while initializing the 'BASNet' model.) - num_classes: int, the number of classes for the segmentation model. - image_shape: optional shape tuple, defaults to (None, None, 3). - projection_filters: int, number of filters in the convolution layer - projecting low-level features from the `backbone`. - prediction_heads: (Optional) List of `keras.layers.Layer` defining - the prediction module head for the model. If not provided, a - default head is created with a Conv2D layer followed by resizing. 
- refinement_head: (Optional) a `keras.layers.Layer` defining the - refinement module head for the model. If not provided, a default - head is created with a Conv2D layer. + backbone: A `keras_hub.models.BASNetBackbone` instance. + preprocessor: `None`, a `keras_hub.models.Preprocessor` instance, + a `keras.Layer` instance, or a callable. If `None` no preprocessing + will be applied to the inputs. Example: ```python @@ -42,22 +28,21 @@ class BASNetImageSegmenter(ImageSegmenter): images = np.ones(shape=(1, 288, 288, 3)) labels = np.zeros(shape=(1, 288, 288, 1)) - # Note: Do not specify `image_shape` within the backbone. - backbone = keras_hub.models.ResNetBackbone.from_preset( + image_encoder = keras_hub.models.ResNetBackbone.from_preset( "resnet_18_imagenet", load_weights=False ) - model = keras_hub.models.BASNetImageSegmenter( - backbone=backbone, + backbone = keras_hub.models.BASNetBackbone( + image_encoder, num_classes=1, - image_shape=[288, 288, 3], + image_shape=[288, 288, 3] ) + model = keras_hub.models.BASNetImageSegmenter(backbone) - # Evaluate model - output = model(images) - pred_labels = output[0] + # Evaluate the model + pred_labels = model(images) - # Train model + # Train the model model.compile( optimizer="adam", loss=keras.losses.BinaryCrossentropy(from_logits=False), @@ -67,98 +52,36 @@ class BASNetImageSegmenter(ImageSegmenter): ``` """ - backbone_cls = ResNetBackbone + backbone_cls = BASNetBackbone preprocessor_cls = BASNetPreprocessor def __init__( self, backbone, - num_classes, - image_shape=(None, None, 3), - projection_filters=64, - prediction_heads=None, - refinement_head=None, preprocessor=None, **kwargs, ): - if not isinstance(backbone, keras.layers.Layer) or not isinstance( - backbone, keras.Model - ): - raise ValueError( - "Argument `backbone` must be a `keras.layers.Layer` instance" - f" or `keras.Model`. Received instead" - f" backbone={backbone} (of type {type(backbone)})." - ) - - if tuple(backbone.image_shape) != (None, None, 3): - raise ValueError( - "Do not specify `image_shape` within the" - " 'BASNet' backbone. \nPlease provide `image_shape`" - " while initializing the 'BASNet' model." - ) # === Functional Model === - inputs = keras.layers.Input(shape=image_shape) - x = inputs - - if prediction_heads is None: - prediction_heads = [] - for size in (1, 2, 4, 8, 16, 32, 32): - head_layers = [ - keras.layers.Conv2D( - num_classes, kernel_size=(3, 3), padding="same" - ) - ] - if size != 1: - head_layers.append( - keras.layers.UpSampling2D( - size=size, interpolation="bilinear" - ) - ) - prediction_heads.append(keras.Sequential(head_layers)) - - if refinement_head is None: - refinement_head = keras.Sequential( - [ - keras.layers.Conv2D( - num_classes, kernel_size=(3, 3), padding="same" - ), - ] - ) - - # Prediction model. - predict_model = basnet_predict( - x, backbone, projection_filters, prediction_heads - ) - - # Refinement model. - refine_model = basnet_rrm( - predict_model, projection_filters, refinement_head - ) - - outputs = refine_model.outputs # Combine outputs. 
- outputs.extend(predict_model.outputs) - - output_names = ["refine_out"] + [ - f"predict_out_{i}" for i in range(1, len(outputs)) - ] - - outputs = [ - keras.layers.Activation("sigmoid", name=output_name)(output) - for output, output_name in zip(outputs, output_names) - ] - - super().__init__(inputs=inputs, outputs=outputs, **kwargs) + x = backbone.input + outputs = backbone(x) + # only return the refinement module's output as final prediction + outputs = outputs["refine_out"] + super().__init__(inputs=x, outputs=outputs, **kwargs) # === Config === self.backbone = backbone - self.num_classes = num_classes - self.image_shape = image_shape - self.projection_filters = projection_filters - self.prediction_heads = prediction_heads - self.refinement_head = refinement_head self.preprocessor = preprocessor + def compute_loss(self, x, y, y_pred, *args, **kwargs): + # train BASNet's prediction and refinement module outputs against the same ground truth data + outputs = self.backbone(x) + losses = [] + for output in outputs.values(): + losses.append(super().compute_loss(x, y, output, *args, **kwargs)) + loss = keras.ops.sum(losses, axis=0) + return loss + def compile( self, optimizer="auto", @@ -196,219 +119,10 @@ def compile( if loss == "auto": loss = keras.losses.BinaryCrossentropy() if metrics == "auto": - metrics = {"refine_out": keras.metrics.Accuracy()} + metrics = [keras.metrics.Accuracy()] super().compile( optimizer=optimizer, loss=loss, metrics=metrics, **kwargs, - ) - - def get_config(self): - config = super().get_config() - config.update( - { - "num_classes": self.num_classes, - "image_shape": self.image_shape, - "projection_filters": self.projection_filters, - "prediction_heads": [ - keras.saving.serialize_keras_object(prediction_head) - for prediction_head in self.prediction_heads - ], - "refinement_head": keras.saving.serialize_keras_object( - self.refinement_head - ), - } - ) - return config - - @classmethod - def from_config(cls, config): - if "prediction_heads" in config and isinstance( - config["prediction_heads"], list - ): - for i in range(len(config["prediction_heads"])): - if isinstance(config["prediction_heads"][i], dict): - config["prediction_heads"][i] = keras.layers.deserialize( - config["prediction_heads"][i] - ) - - if "refinement_head" in config and isinstance( - config["refinement_head"], dict - ): - config["refinement_head"] = keras.layers.deserialize( - config["refinement_head"] - ) - return super().from_config(config) - - -def convolution_block(x_input, filters, dilation=1): - """ - Apply convolution + batch normalization + ReLU activation. - - Args: - x_input: Input keras tensor. - filters: int, number of output filters in the convolution. - dilation: int, dilation rate for the convolution operation. - Defaults to 1. - - Returns: - A tensor with convolution, batch normalization, and ReLU - activation applied. - """ - x = keras.layers.Conv2D( - filters, (3, 3), padding="same", dilation_rate=dilation - )(x_input) - x = keras.layers.BatchNormalization()(x) - return keras.layers.Activation("relu")(x) - - -def get_resnet_block(_resnet, block_num): - """ - Extract and return a specific ResNet block. - - Args: - _resnet: `keras.Model`. ResNet model instance. - block_num: int, block number to extract. - - Returns: - A Keras Model representing the specified ResNet block. 
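The `compute_loss` override above applies the same ground-truth mask to every backbone output, the refined map plus all seven coarse side outputs, and sums the per-output losses. A minimal standalone sketch of that hybrid supervision, using placeholder tensors and the default binary cross-entropy (shapes and fill values are illustrative only):

```python
import numpy as np
import keras

bce = keras.losses.BinaryCrossentropy()
y_true = np.zeros((2, 64, 64, 1), "float32")        # one mask shared by all outputs
# Stand-ins for the backbone's output dict: refine_out plus predict_out_1..7.
outputs = {
    name: np.full((2, 64, 64, 1), 0.5, "float32")
    for name in ["refine_out"] + [f"predict_out_{i}" for i in range(1, 8)]
}
losses = [bce(y_true, y_pred) for y_pred in outputs.values()]
total_loss = keras.ops.sum(losses, axis=0)          # summed, mirroring compute_loss above
```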
- """ - - extractor_levels = ["P2", "P3", "P4", "P5"] - num_blocks = _resnet.stackwise_num_blocks - if block_num == 0: - x = _resnet.get_layer("pool1_pool").output - else: - x = _resnet.pyramid_outputs[extractor_levels[block_num - 1]] - y = _resnet.get_layer( - f"stack{block_num}_block{num_blocks[block_num]-1}_add" - ).output - return keras.models.Model( - inputs=x, - outputs=y, - name=f"resnet_block{block_num + 1}", - ) - - -def basnet_predict(x_input, backbone, filters, segmentation_heads): - """ - BASNet Prediction Module. - - This module outputs a coarse label map by integrating heavy - encoder, bridge, and decoder blocks. - - Args: - x_input: Input keras tensor. - backbone: `keras.Model`. The backbone network used as a feature - extractor for BASNet prediction encoder. - filters: int, the number of filters. - segmentation_heads: List of `keras.layers.Layer`, A list of Keras - layers serving as the segmentation head for prediction module. - - - Returns: - A Keras Model that integrates the encoder, bridge, and decoder - blocks for coarse label map prediction. - """ - num_stages = 6 - - x = x_input - - # -------------Encoder-------------- - x = keras.layers.Conv2D(filters, kernel_size=(3, 3), padding="same")(x) - - encoder_blocks = [] - for i in range(num_stages): - if i < 4: # First four stages are adopted from ResNet backbone. - x = get_resnet_block(backbone, i)(x) - encoder_blocks.append(x) - else: # Last 2 stages consist of three basic resnet blocks. - x = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(x) - for j in range(3): - x = resnet_basic_block( - x, - filters=x.shape[3], - conv_shortcut=False, - name=f"v1_basic_block_{i + 1}_{j + 1}", - ) - encoder_blocks.append(x) - - # -------------Bridge------------- - x = convolution_block(x, filters=filters * 8, dilation=2) - x = convolution_block(x, filters=filters * 8, dilation=2) - x = convolution_block(x, filters=filters * 8, dilation=2) - encoder_blocks.append(x) - - # -------------Decoder------------- - decoder_blocks = [] - for i in reversed(range(num_stages)): - if i != (num_stages - 1): # Except first, scale other decoder stages. - x = keras.layers.UpSampling2D(size=2, interpolation="bilinear")(x) - - x = keras.layers.concatenate([encoder_blocks[i], x], axis=-1) - x = convolution_block(x, filters=filters * 8) - x = convolution_block(x, filters=filters * 8) - x = convolution_block(x, filters=filters * 8) - decoder_blocks.append(x) - - decoder_blocks.reverse() # Change order from last to first decoder stage. - decoder_blocks.append(encoder_blocks[-1]) # Copy bridge to decoder. - - # -------------Side Outputs-------------- - decoder_blocks = [ - segmentation_head(decoder_block) # Prediction segmentation head. - for segmentation_head, decoder_block in zip( - segmentation_heads, decoder_blocks - ) - ] - - return keras.models.Model(inputs=[x_input], outputs=decoder_blocks) - - -def basnet_rrm(base_model, filters, segmentation_head): - """ - BASNet Residual Refinement Module (RRM). - - This module outputs a fine label map by integrating light encoder, - bridge, and decoder blocks. - - Args: - base_model: Keras model used as the base or coarse label map. - filters: int, the number of filters. - segmentation_head: a `keras.layers.Layer`, A Keras layer serving - as the segmentation head for refinement module. - - Returns: - A Keras Model that constructs the Residual Refinement Module (RRM). 
- """ - num_stages = 4 - - x_input = base_model.output[0] - - # -------------Encoder-------------- - x = keras.layers.Conv2D(filters, kernel_size=(3, 3), padding="same")( - x_input - ) - - encoder_blocks = [] - for _ in range(num_stages): - x = convolution_block(x, filters=filters) - encoder_blocks.append(x) - x = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(x) - - # -------------Bridge-------------- - x = convolution_block(x, filters=filters) - - # -------------Decoder-------------- - for i in reversed(range(num_stages)): - x = keras.layers.UpSampling2D(size=2, interpolation="bilinear")(x) - x = keras.layers.concatenate([encoder_blocks[i], x], axis=-1) - x = convolution_block(x, filters=filters) - - x = segmentation_head(x) # Refinement segmentation head. - - # ------------- refined = coarse + residual - x = keras.layers.Add()([x_input, x]) # Add prediction + refinement output - - return keras.models.Model(inputs=base_model.input, outputs=[x]) + ) \ No newline at end of file diff --git a/keras_hub/src/models/basnet/basnet_backbone.py b/keras_hub/src/models/basnet/basnet_backbone.py new file mode 100644 index 0000000000..29599d1e30 --- /dev/null +++ b/keras_hub/src/models/basnet/basnet_backbone.py @@ -0,0 +1,370 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.resnet.resnet_backbone import ( + apply_basic_block as resnet_basic_block, +) + + +@keras_hub_export("keras_hub.models.BASNetBackbone") +class BASNetBackbone(Backbone): + """ + A Keras model implementing the BASNet architecture for semantic + segmentation. + + References: + - [BASNet: Boundary-Aware Segmentation Network for Mobile and Web Applications](https://arxiv.org/abs/2101.04704) + + Args: + image_encoder: A `keras_hub.models.ResNetBackbone` instance. The + backbone network for the model that is used as a feature extractor + for BASNet prediction encoder. Currently supported backbones are + ResNet18 and ResNet34. + (Note: Do not specify `image_shape` within the backbone. + Please provide these while initializing the 'BASNetBackbone' model) + num_classes: int, the number of classes for the segmentation model. + image_shape: optional shape tuple, defaults to (None, None, 3). + projection_filters: int, number of filters in the convolution layer + projecting low-level features from the `backbone`. + prediction_heads: (Optional) List of `keras.layers.Layer` defining + the prediction module head for the model. If not provided, a + default head is created with a Conv2D layer followed by resizing. + refinement_head: (Optional) a `keras.layers.Layer` defining the + refinement module head for the model. If not provided, a default + head is created with a Conv2D layer. + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the model's computations and weights. + + """ + + def __init__( + self, + image_encoder, + num_classes, + image_shape=(None, None, 3), + projection_filters=64, + prediction_heads=None, + refinement_head=None, + dtype=None, + **kwargs, + ): + if not isinstance(image_encoder, keras.layers.Layer) or not isinstance( + image_encoder, keras.Model + ): + raise ValueError( + "Argument `image_encoder` must be a `keras.layers.Layer` instance" + f" or `keras.Model`. Received instead" + f" image_encoder={image_encoder} (of type {type(image_encoder)})." 
+ ) + + if tuple(image_encoder.image_shape) != (None, None, 3): + raise ValueError( + "Do not specify `image_shape` within the" + " `BASNetBackbone`'s image_encoder. \nPlease provide `image_shape`" + " while initializing the 'BASNetBackbone' model." + ) + + # === Functional Model === + inputs = keras.layers.Input(shape=image_shape) + x = inputs + + if prediction_heads is None: + prediction_heads = [] + for size in (1, 2, 4, 8, 16, 32, 32): + head_layers = [ + keras.layers.Conv2D( + num_classes, + kernel_size=(3, 3), + padding="same", + dtype=dtype, + ) + ] + if size != 1: + head_layers.append( + keras.layers.UpSampling2D( + size=size, interpolation="bilinear", dtype=dtype + ) + ) + prediction_heads.append(keras.Sequential(head_layers)) + + if refinement_head is None: + refinement_head = keras.Sequential( + [ + keras.layers.Conv2D( + num_classes, + kernel_size=(3, 3), + padding="same", + dtype=dtype, + ), + ] + ) + + # Prediction model. + predict_model = basnet_predict( + x, image_encoder, projection_filters, prediction_heads, dtype=dtype + ) + + # Refinement model. + refine_model = basnet_rrm( + predict_model, projection_filters, refinement_head, dtype=dtype + ) + + outputs = refine_model.outputs # Combine outputs. + outputs.extend(predict_model.outputs) + + output_names = ["refine_out"] + [ + f"predict_out_{i}" for i in range(1, len(outputs)) + ] + + outputs = { + output_name: keras.layers.Activation( + "sigmoid", name=output_name, dtype=dtype + )(output) + for output, output_name in zip(outputs, output_names) + } + + super().__init__(inputs=inputs, outputs=outputs, dtype=dtype, **kwargs) + + # === Config === + self.image_encoder = image_encoder + self.num_classes = num_classes + self.image_shape = image_shape + self.projection_filters = projection_filters + self.prediction_heads = prediction_heads + self.refinement_head = refinement_head + + def get_config(self): + config = super().get_config() + config.update( + { + "image_encoder": keras.saving.serialize_keras_object( + self.image_encoder + ), + "num_classes": self.num_classes, + "image_shape": self.image_shape, + "projection_filters": self.projection_filters, + "prediction_heads": [ + keras.saving.serialize_keras_object(prediction_head) + for prediction_head in self.prediction_heads + ], + "refinement_head": keras.saving.serialize_keras_object( + self.refinement_head + ), + } + ) + return config + + @classmethod + def from_config(cls, config): + if "image_encoder" in config: + config["image_encoder"] = keras.layers.deserialize( + config["image_encoder"] + ) + if "prediction_heads" in config and isinstance( + config["prediction_heads"], list + ): + for i in range(len(config["prediction_heads"])): + if isinstance(config["prediction_heads"][i], dict): + config["prediction_heads"][i] = keras.layers.deserialize( + config["prediction_heads"][i] + ) + + if "refinement_head" in config and isinstance( + config["refinement_head"], dict + ): + config["refinement_head"] = keras.layers.deserialize( + config["refinement_head"] + ) + return super().from_config(config) + + +def convolution_block(x_input, filters, dilation=1, dtype=None): + """ + Apply convolution + batch normalization + ReLU activation. + + Args: + x_input: Input keras tensor. + filters: int, number of output filters in the convolution. + dilation: int, dilation rate for the convolution operation. + Defaults to 1. + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the model's computations and weights. 
+ + Returns: + A tensor with convolution, batch normalization, and ReLU + activation applied. + """ + x = keras.layers.Conv2D( + filters, (3, 3), padding="same", dilation_rate=dilation, dtype=dtype + )(x_input) + x = keras.layers.BatchNormalization(dtype=dtype)(x) + return keras.layers.Activation("relu", dtype=dtype)(x) + + +def get_resnet_block(_resnet, block_num): + """ + Extract and return a specific ResNet block. + + Args: + _resnet: `keras.Model`. ResNet model instance. + block_num: int, block number to extract. + + Returns: + A Keras Model representing the specified ResNet block. + """ + + extractor_levels = ["P2", "P3", "P4", "P5"] + num_blocks = _resnet.stackwise_num_blocks + if block_num == 0: + x = _resnet.get_layer("pool1_pool").output + else: + x = _resnet.pyramid_outputs[extractor_levels[block_num - 1]] + y = _resnet.get_layer( + f"stack{block_num}_block{num_blocks[block_num]-1}_add" + ).output + return keras.models.Model( + inputs=x, + outputs=y, + name=f"resnet_block{block_num + 1}", + ) + + +def basnet_predict(x_input, backbone, filters, segmentation_heads, dtype=None): + """ + BASNet Prediction Module. + + This module outputs a coarse label map by integrating heavy + encoder, bridge, and decoder blocks. + + Args: + x_input: Input keras tensor. + backbone: `keras.Model`. The backbone network used as a feature + extractor for BASNet prediction encoder. + filters: int, the number of filters. + segmentation_heads: List of `keras.layers.Layer`, A list of Keras + layers serving as the segmentation head for prediction module. + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the model's computations and weights. + + + Returns: + A Keras Model that integrates the encoder, bridge, and decoder + blocks for coarse label map prediction. + """ + num_stages = 6 + + x = x_input + + # -------------Encoder-------------- + x = keras.layers.Conv2D( + filters, kernel_size=(3, 3), padding="same", dtype=dtype + )(x) + + encoder_blocks = [] + for i in range(num_stages): + if i < 4: # First four stages are adopted from ResNet backbone. + x = get_resnet_block(backbone, i)(x) + encoder_blocks.append(x) + else: # Last 2 stages consist of three basic resnet blocks. + x = keras.layers.MaxPool2D( + pool_size=(2, 2), strides=(2, 2), dtype=dtype + )(x) + for j in range(3): + x = resnet_basic_block( + x, + filters=x.shape[3], + conv_shortcut=False, + name=f"v1_basic_block_{i + 1}_{j + 1}", + dtype=dtype, + ) + encoder_blocks.append(x) + + # -------------Bridge------------- + x = convolution_block(x, filters=filters * 8, dilation=2, dtype=dtype) + x = convolution_block(x, filters=filters * 8, dilation=2, dtype=dtype) + x = convolution_block(x, filters=filters * 8, dilation=2, dtype=dtype) + encoder_blocks.append(x) + + # -------------Decoder------------- + decoder_blocks = [] + for i in reversed(range(num_stages)): + if i != (num_stages - 1): # Except first, scale other decoder stages. + x = keras.layers.UpSampling2D( + size=2, interpolation="bilinear", dtype=dtype + )(x) + + x = keras.layers.concatenate([encoder_blocks[i], x], axis=-1) + x = convolution_block(x, filters=filters * 8, dtype=dtype) + x = convolution_block(x, filters=filters * 8, dtype=dtype) + x = convolution_block(x, filters=filters * 8, dtype=dtype) + decoder_blocks.append(x) + + decoder_blocks.reverse() # Change order from last to first decoder stage. + decoder_blocks.append(encoder_blocks[-1]) # Copy bridge to decoder. 
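For orientation, the decoder features collected above (plus the copied bridge) sit at progressively coarser resolutions, and the default prediction heads defined earlier upsample each side output back to the input size with factors (1, 2, 4, 8, 16, 32, 32). A rough sketch of that arithmetic, assuming the 288x288 input used in the docstring example:

```python
# Illustrative arithmetic only: per-stage feature resolutions for a 288x288
# input, and the bilinear upsampling factor each default head applies.
input_size = 288
upsample_factors = (1, 2, 4, 8, 16, 32, 32)          # as in the default heads above
stage_sizes = [input_size // f for f in upsample_factors]
print(stage_sizes)  # [288, 144, 72, 36, 18, 9, 9] -> every side output returns to 288x288
```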
+ + # -------------Side Outputs-------------- + decoder_blocks = [ + segmentation_head(decoder_block) # Prediction segmentation head. + for segmentation_head, decoder_block in zip( + segmentation_heads, decoder_blocks + ) + ] + + return keras.models.Model(inputs=[x_input], outputs=decoder_blocks) + + +def basnet_rrm(base_model, filters, segmentation_head, dtype=None): + """ + BASNet Residual Refinement Module (RRM). + + This module outputs a fine label map by integrating light encoder, + bridge, and decoder blocks. + + Args: + base_model: Keras model used as the base or coarse label map. + filters: int, the number of filters. + segmentation_head: a `keras.layers.Layer`, A Keras layer serving + as the segmentation head for refinement module. + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the model's computations and weights. + + Returns: + A Keras Model that constructs the Residual Refinement Module (RRM). + """ + num_stages = 4 + + x_input = base_model.output[0] + + # -------------Encoder-------------- + x = keras.layers.Conv2D( + filters, kernel_size=(3, 3), padding="same", dtype=dtype + )(x_input) + + encoder_blocks = [] + for _ in range(num_stages): + x = convolution_block(x, filters=filters) + encoder_blocks.append(x) + x = keras.layers.MaxPool2D( + pool_size=(2, 2), strides=(2, 2), dtype=dtype + )(x) + + # -------------Bridge-------------- + x = convolution_block(x, filters=filters, dtype=dtype) + + # -------------Decoder-------------- + for i in reversed(range(num_stages)): + x = keras.layers.UpSampling2D( + size=2, interpolation="bilinear", dtype=dtype + )(x) + x = keras.layers.concatenate([encoder_blocks[i], x], axis=-1) + x = convolution_block(x, filters=filters) + + x = segmentation_head(x) # Refinement segmentation head. 
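The refinement head output above becomes a residual that the following lines add back onto the coarse map (`x_input`, the prediction module's first output), so the RRM only learns a correction; at the caller's level the result surfaces through the backbone's output dict. A minimal usage sketch, mirroring the encoder configuration used in the new tests further below (the in-repo import paths are assumptions and may differ from the public API):

```python
import numpy as np

from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone
from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone

image_encoder = ResNetBackbone(
    input_conv_filters=[64],
    input_conv_kernel_sizes=[7],
    stackwise_num_filters=[64, 128, 256, 512],
    stackwise_num_blocks=[2, 2, 2, 2],
    stackwise_num_strides=[1, 2, 2, 2],
    block_type="basic_block",
)
backbone = BASNetBackbone(image_encoder, num_classes=1)

outputs = backbone(np.ones((1, 64, 64, 3), "float32"))  # dict of named outputs
coarse = outputs["predict_out_1"]   # coarse map from the prediction module
refined = outputs["refine_out"]     # coarse + RRM residual, after sigmoid
```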
+ + # ------------- refined = coarse + residual + x = keras.layers.Add(dtype=dtype)( + [x_input, x] + ) # Add prediction + refinement output + + return keras.models.Model(inputs=base_model.input, outputs=[x]) \ No newline at end of file diff --git a/keras_hub/src/models/basnet/basnet_backbone_test.py b/keras_hub/src/models/basnet/basnet_backbone_test.py new file mode 100644 index 0000000000..ca7efe3a3b --- /dev/null +++ b/keras_hub/src/models/basnet/basnet_backbone_test.py @@ -0,0 +1,36 @@ +from keras import ops + +from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone +from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_hub.src.tests.test_case import TestCase + + +class BASNetBackboneTest(TestCase): + def setUp(self): + self.images = ops.ones((2, 64, 64, 3)) + self.image_encoder = ResNetBackbone( + input_conv_filters=[64], + input_conv_kernel_sizes=[7], + stackwise_num_filters=[64, 128, 256, 512], + stackwise_num_blocks=[2, 2, 2, 2], + stackwise_num_strides=[1, 2, 2, 2], + block_type="basic_block", + ) + self.init_kwargs = { + "image_encoder": self.image_encoder, + "num_classes": 1, + } + + def test_backbone_basics(self): + output_names = ["refine_out"] + [ + f"predict_out_{i}" for i in range(1, 8) + ] + expected_output_shape = {name: (2, 64, 64, 1) for name in output_names} + self.run_backbone_test( + cls=BASNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.images, + expected_output_shape=expected_output_shape, + run_mixed_precision_check=False, + run_quantization_check=False, + ) \ No newline at end of file diff --git a/keras_hub/src/models/basnet/basnet_image_converter.py b/keras_hub/src/models/basnet/basnet_image_converter.py index 45f364fd4e..347447526e 100644 --- a/keras_hub/src/models/basnet/basnet_image_converter.py +++ b/keras_hub/src/models/basnet/basnet_image_converter.py @@ -1,8 +1,8 @@ from keras_hub.src.api_export import keras_hub_export from keras_hub.src.layers.preprocessing.image_converter import ImageConverter -from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone @keras_hub_export("keras_hub.layers.BASNetImageConverter") class BASNetImageConverter(ImageConverter): - backbone_cls = ResNetBackbone + backbone_cls = BASNetBackbone \ No newline at end of file diff --git a/keras_hub/src/models/basnet/basnet_preprocessor.py b/keras_hub/src/models/basnet/basnet_preprocessor.py index 7faa64cdd9..9fdb855e73 100644 --- a/keras_hub/src/models/basnet/basnet_preprocessor.py +++ b/keras_hub/src/models/basnet/basnet_preprocessor.py @@ -1,14 +1,14 @@ from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone from keras_hub.src.models.basnet.basnet_image_converter import ( BASNetImageConverter, ) from keras_hub.src.models.image_segmenter_preprocessor import ( ImageSegmenterPreprocessor, ) -from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone @keras_hub_export("keras_hub.models.BASNetPreprocessor") class BASNetPreprocessor(ImageSegmenterPreprocessor): - backbone_cls = ResNetBackbone - image_converter_cls = BASNetImageConverter + backbone_cls = BASNetBackbone + image_converter_cls = BASNetImageConverter \ No newline at end of file diff --git a/keras_hub/src/models/basnet/basnet_presets.py b/keras_hub/src/models/basnet/basnet_presets.py index 1ed728643b..7a4b0a27d3 100644 --- a/keras_hub/src/models/basnet/basnet_presets.py +++ 
b/keras_hub/src/models/basnet/basnet_presets.py @@ -19,4 +19,4 @@ }, "kaggle_handle": "", # TODO }, -} +} \ No newline at end of file diff --git a/keras_hub/src/models/basnet/basnet_test.py b/keras_hub/src/models/basnet/basnet_test.py index fa7c1c9b64..a0af8e8e4e 100644 --- a/keras_hub/src/models/basnet/basnet_test.py +++ b/keras_hub/src/models/basnet/basnet_test.py @@ -2,6 +2,7 @@ from keras import ops from keras_hub.src.models.basnet.basnet import BASNetImageSegmenter +from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone from keras_hub.src.tests.test_case import TestCase @@ -13,7 +14,7 @@ def setUp(self): self.labels = ops.concatenate( (ops.zeros((2, 32, 64, 1)), ops.ones((2, 32, 64, 1))), axis=1 ) - self.backbone = ResNetBackbone( + self.image_encoder = ResNetBackbone( input_conv_filters=[64], input_conv_kernel_sizes=[7], stackwise_num_filters=[64, 128, 256, 512], @@ -21,11 +22,14 @@ def setUp(self): stackwise_num_strides=[1, 2, 2, 2], block_type="basic_block", ) + self.backbone = BASNetBackbone( + image_encoder=self.image_encoder, + num_classes=1, + ) self.preprocessor = BASNetPreprocessor() self.init_kwargs = { "backbone": self.backbone, "preprocessor": self.preprocessor, - "num_classes": 1, } self.train_data = (self.images, self.labels) @@ -34,7 +38,7 @@ def test_basics(self): cls=BASNetImageSegmenter, init_kwargs=self.init_kwargs, train_data=self.train_data, - expected_output_shape=[(2, 64, 64, 1)] * 8, + expected_output_shape=(2, 64, 64, 1), ) @pytest.mark.large @@ -47,10 +51,8 @@ def test_saved_model(self): def test_end_to_end_model_predict(self): model = BASNetImageSegmenter(**self.init_kwargs) - outputs = model.predict(self.images) - self.assertAllEqual( - [output.shape for output in outputs], [(2, 64, 64, 1)] * 8 - ) + output = model.predict(self.images) + self.assertAllEqual(output.shape, (2, 64, 64, 1)) @pytest.mark.skip(reason="disabled until preset's been uploaded to Kaggle") @pytest.mark.extra_large @@ -60,5 +62,5 @@ def test_all_presets(self): cls=BASNetImageSegmenter, preset=preset, input_data=self.images, - expected_output_shape=[(2, 64, 64, 1)] * 8, - ) + expected_output_shape=(2, 64, 64, 1), + ) \ No newline at end of file From 756e8a111be209bdb56c5e6e9347cc20f0718c6b Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Fri, 13 Dec 2024 15:00:59 -0800 Subject: [PATCH 06/15] reverting unwanted changes to branch --- keras_hub/api/models/__init__.py | 354 ++++++++++++----- keras_hub/src/models/basnet/__init__.py | 4 +- keras_hub/src/models/basnet/basnet.py | 348 ++++++++++++++-- .../src/models/basnet/basnet_backbone.py | 370 ------------------ .../src/models/basnet/basnet_backbone_test.py | 36 -- .../models/basnet/basnet_image_converter.py | 4 +- .../src/models/basnet/basnet_preprocessor.py | 6 +- keras_hub/src/models/basnet/basnet_presets.py | 2 +- keras_hub/src/models/basnet/basnet_test.py | 20 +- 9 files changed, 598 insertions(+), 546 deletions(-) delete mode 100644 keras_hub/src/models/basnet/basnet_backbone.py delete mode 100644 keras_hub/src/models/basnet/basnet_backbone_test.py diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 660263dff3..9748bc40b6 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -4,34 +4,54 @@ since your modifications would be overwritten. 
""" - from keras_hub.src.models.albert.albert_backbone import AlbertBackbone from keras_hub.src.models.albert.albert_masked_lm import AlbertMaskedLM -from keras_hub.src.models.albert.albert_masked_lm_preprocessor import AlbertMaskedLMPreprocessor -from keras_hub.src.models.albert.albert_text_classifier import AlbertTextClassifier -from keras_hub.src.models.albert.albert_text_classifier import AlbertTextClassifier as AlbertClassifier -from keras_hub.src.models.albert.albert_text_classifier_preprocessor import AlbertTextClassifierPreprocessor -from keras_hub.src.models.albert.albert_text_classifier_preprocessor import AlbertTextClassifierPreprocessor as AlbertPreprocessor +from keras_hub.src.models.albert.albert_masked_lm_preprocessor import ( + AlbertMaskedLMPreprocessor, +) +from keras_hub.src.models.albert.albert_text_classifier import ( + AlbertTextClassifier, +) +from keras_hub.src.models.albert.albert_text_classifier import ( + AlbertTextClassifier as AlbertClassifier, +) +from keras_hub.src.models.albert.albert_text_classifier_preprocessor import ( + AlbertTextClassifierPreprocessor, +) +from keras_hub.src.models.albert.albert_text_classifier_preprocessor import ( + AlbertTextClassifierPreprocessor as AlbertPreprocessor, +) from keras_hub.src.models.albert.albert_tokenizer import AlbertTokenizer from keras_hub.src.models.backbone import Backbone from keras_hub.src.models.bart.bart_backbone import BartBackbone from keras_hub.src.models.bart.bart_seq_2_seq_lm import BartSeq2SeqLM -from keras_hub.src.models.bart.bart_seq_2_seq_lm_preprocessor import BartSeq2SeqLMPreprocessor +from keras_hub.src.models.bart.bart_seq_2_seq_lm_preprocessor import ( + BartSeq2SeqLMPreprocessor, +) from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer from keras_hub.src.models.basnet.basnet import BASNetImageSegmenter -from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor from keras_hub.src.models.bert.bert_backbone import BertBackbone from keras_hub.src.models.bert.bert_masked_lm import BertMaskedLM -from keras_hub.src.models.bert.bert_masked_lm_preprocessor import BertMaskedLMPreprocessor +from keras_hub.src.models.bert.bert_masked_lm_preprocessor import ( + BertMaskedLMPreprocessor, +) from keras_hub.src.models.bert.bert_text_classifier import BertTextClassifier -from keras_hub.src.models.bert.bert_text_classifier import BertTextClassifier as BertClassifier -from keras_hub.src.models.bert.bert_text_classifier_preprocessor import BertTextClassifierPreprocessor -from keras_hub.src.models.bert.bert_text_classifier_preprocessor import BertTextClassifierPreprocessor as BertPreprocessor +from keras_hub.src.models.bert.bert_text_classifier import ( + BertTextClassifier as BertClassifier, +) +from keras_hub.src.models.bert.bert_text_classifier_preprocessor import ( + BertTextClassifierPreprocessor, +) +from keras_hub.src.models.bert.bert_text_classifier_preprocessor import ( + BertTextClassifierPreprocessor as BertPreprocessor, +) from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer from keras_hub.src.models.bloom.bloom_backbone import BloomBackbone from keras_hub.src.models.bloom.bloom_causal_lm import BloomCausalLM -from keras_hub.src.models.bloom.bloom_causal_lm_preprocessor import BloomCausalLMPreprocessor +from keras_hub.src.models.bloom.bloom_causal_lm_preprocessor import ( + BloomCausalLMPreprocessor, +) from keras_hub.src.models.bloom.bloom_tokenizer import BloomTokenizer from 
keras_hub.src.models.causal_lm import CausalLM from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor @@ -40,155 +60,309 @@ from keras_hub.src.models.clip.clip_text_encoder import CLIPTextEncoder from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer from keras_hub.src.models.clip.clip_vision_encoder import CLIPVisionEncoder -from keras_hub.src.models.csp_darknet.csp_darknet_backbone import CSPDarkNetBackbone -from keras_hub.src.models.csp_darknet.csp_darknet_image_classifier import CSPDarkNetImageClassifier -from keras_hub.src.models.deberta_v3.deberta_v3_backbone import DebertaV3Backbone -from keras_hub.src.models.deberta_v3.deberta_v3_masked_lm import DebertaV3MaskedLM -from keras_hub.src.models.deberta_v3.deberta_v3_masked_lm_preprocessor import DebertaV3MaskedLMPreprocessor -from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier import DebertaV3TextClassifier -from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier import DebertaV3TextClassifier as DebertaV3Classifier -from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor import DebertaV3TextClassifierPreprocessor -from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor import DebertaV3TextClassifierPreprocessor as DebertaV3Preprocessor -from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import DebertaV3Tokenizer -from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import DeepLabV3Backbone -from keras_hub.src.models.deeplab_v3.deeplab_v3_image_segmeter_preprocessor import DeepLabV3ImageSegmenterPreprocessor -from keras_hub.src.models.deeplab_v3.deeplab_v3_segmenter import DeepLabV3ImageSegmenter +from keras_hub.src.models.csp_darknet.csp_darknet_backbone import ( + CSPDarkNetBackbone, +) +from keras_hub.src.models.csp_darknet.csp_darknet_image_classifier import ( + CSPDarkNetImageClassifier, +) +from keras_hub.src.models.deberta_v3.deberta_v3_backbone import ( + DebertaV3Backbone, +) +from keras_hub.src.models.deberta_v3.deberta_v3_masked_lm import ( + DebertaV3MaskedLM, +) +from keras_hub.src.models.deberta_v3.deberta_v3_masked_lm_preprocessor import ( + DebertaV3MaskedLMPreprocessor, +) +from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier import ( + DebertaV3TextClassifier, +) +from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier import ( + DebertaV3TextClassifier as DebertaV3Classifier, +) +from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor import ( + DebertaV3TextClassifierPreprocessor, +) +from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor import ( + DebertaV3TextClassifierPreprocessor as DebertaV3Preprocessor, +) +from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import ( + DebertaV3Tokenizer, +) +from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import ( + DeepLabV3Backbone, +) +from keras_hub.src.models.deeplab_v3.deeplab_v3_image_segmeter_preprocessor import ( + DeepLabV3ImageSegmenterPreprocessor, +) +from keras_hub.src.models.deeplab_v3.deeplab_v3_segmenter import ( + DeepLabV3ImageSegmenter, +) from keras_hub.src.models.densenet.densenet_backbone import DenseNetBackbone -from keras_hub.src.models.densenet.densenet_image_classifier import DenseNetImageClassifier -from keras_hub.src.models.densenet.densenet_image_classifier_preprocessor import DenseNetImageClassifierPreprocessor -from keras_hub.src.models.distil_bert.distil_bert_backbone import DistilBertBackbone -from keras_hub.src.models.distil_bert.distil_bert_masked_lm import 
DistilBertMaskedLM -from keras_hub.src.models.distil_bert.distil_bert_masked_lm_preprocessor import DistilBertMaskedLMPreprocessor -from keras_hub.src.models.distil_bert.distil_bert_text_classifier import DistilBertTextClassifier -from keras_hub.src.models.distil_bert.distil_bert_text_classifier import DistilBertTextClassifier as DistilBertClassifier -from keras_hub.src.models.distil_bert.distil_bert_text_classifier_preprocessor import DistilBertTextClassifierPreprocessor -from keras_hub.src.models.distil_bert.distil_bert_text_classifier_preprocessor import DistilBertTextClassifierPreprocessor as DistilBertPreprocessor -from keras_hub.src.models.distil_bert.distil_bert_tokenizer import DistilBertTokenizer -from keras_hub.src.models.efficientnet.efficientnet_backbone import EfficientNetBackbone -from keras_hub.src.models.efficientnet.efficientnet_image_classifier import EfficientNetImageClassifier -from keras_hub.src.models.efficientnet.efficientnet_image_classifier_preprocessor import EfficientNetImageClassifierPreprocessor +from keras_hub.src.models.densenet.densenet_image_classifier import ( + DenseNetImageClassifier, +) +from keras_hub.src.models.densenet.densenet_image_classifier_preprocessor import ( + DenseNetImageClassifierPreprocessor, +) +from keras_hub.src.models.distil_bert.distil_bert_backbone import ( + DistilBertBackbone, +) +from keras_hub.src.models.distil_bert.distil_bert_masked_lm import ( + DistilBertMaskedLM, +) +from keras_hub.src.models.distil_bert.distil_bert_masked_lm_preprocessor import ( + DistilBertMaskedLMPreprocessor, +) +from keras_hub.src.models.distil_bert.distil_bert_text_classifier import ( + DistilBertTextClassifier, +) +from keras_hub.src.models.distil_bert.distil_bert_text_classifier import ( + DistilBertTextClassifier as DistilBertClassifier, +) +from keras_hub.src.models.distil_bert.distil_bert_text_classifier_preprocessor import ( + DistilBertTextClassifierPreprocessor, +) +from keras_hub.src.models.distil_bert.distil_bert_text_classifier_preprocessor import ( + DistilBertTextClassifierPreprocessor as DistilBertPreprocessor, +) +from keras_hub.src.models.distil_bert.distil_bert_tokenizer import ( + DistilBertTokenizer, +) +from keras_hub.src.models.efficientnet.efficientnet_backbone import ( + EfficientNetBackbone, +) +from keras_hub.src.models.efficientnet.efficientnet_image_classifier import ( + EfficientNetImageClassifier, +) +from keras_hub.src.models.efficientnet.efficientnet_image_classifier_preprocessor import ( + EfficientNetImageClassifierPreprocessor, +) from keras_hub.src.models.electra.electra_backbone import ElectraBackbone from keras_hub.src.models.electra.electra_tokenizer import ElectraTokenizer from keras_hub.src.models.f_net.f_net_backbone import FNetBackbone from keras_hub.src.models.f_net.f_net_masked_lm import FNetMaskedLM -from keras_hub.src.models.f_net.f_net_masked_lm_preprocessor import FNetMaskedLMPreprocessor +from keras_hub.src.models.f_net.f_net_masked_lm_preprocessor import ( + FNetMaskedLMPreprocessor, +) from keras_hub.src.models.f_net.f_net_text_classifier import FNetTextClassifier -from keras_hub.src.models.f_net.f_net_text_classifier import FNetTextClassifier as FNetClassifier -from keras_hub.src.models.f_net.f_net_text_classifier_preprocessor import FNetTextClassifierPreprocessor -from keras_hub.src.models.f_net.f_net_text_classifier_preprocessor import FNetTextClassifierPreprocessor as FNetPreprocessor +from keras_hub.src.models.f_net.f_net_text_classifier import ( + FNetTextClassifier as FNetClassifier, +) +from 
keras_hub.src.models.f_net.f_net_text_classifier_preprocessor import ( + FNetTextClassifierPreprocessor, +) +from keras_hub.src.models.f_net.f_net_text_classifier_preprocessor import ( + FNetTextClassifierPreprocessor as FNetPreprocessor, +) from keras_hub.src.models.f_net.f_net_tokenizer import FNetTokenizer from keras_hub.src.models.falcon.falcon_backbone import FalconBackbone from keras_hub.src.models.falcon.falcon_causal_lm import FalconCausalLM -from keras_hub.src.models.falcon.falcon_causal_lm_preprocessor import FalconCausalLMPreprocessor +from keras_hub.src.models.falcon.falcon_causal_lm_preprocessor import ( + FalconCausalLMPreprocessor, +) from keras_hub.src.models.falcon.falcon_tokenizer import FalconTokenizer from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone from keras_hub.src.models.flux.flux_model import FluxBackbone from keras_hub.src.models.flux.flux_text_to_image import FluxTextToImage -from keras_hub.src.models.flux.flux_text_to_image_preprocessor import FluxTextToImagePreprocessor +from keras_hub.src.models.flux.flux_text_to_image_preprocessor import ( + FluxTextToImagePreprocessor, +) from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM -from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import GemmaCausalLMPreprocessor +from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import ( + GemmaCausalLMPreprocessor, +) from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer from keras_hub.src.models.gpt2.gpt2_backbone import GPT2Backbone from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM -from keras_hub.src.models.gpt2.gpt2_causal_lm_preprocessor import GPT2CausalLMPreprocessor +from keras_hub.src.models.gpt2.gpt2_causal_lm_preprocessor import ( + GPT2CausalLMPreprocessor, +) from keras_hub.src.models.gpt2.gpt2_preprocessor import GPT2Preprocessor from keras_hub.src.models.gpt2.gpt2_tokenizer import GPT2Tokenizer from keras_hub.src.models.gpt_neo_x.gpt_neo_x_backbone import GPTNeoXBackbone from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm import GPTNeoXCausalLM -from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm_preprocessor import GPTNeoXCausalLMPreprocessor +from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm_preprocessor import ( + GPTNeoXCausalLMPreprocessor, +) from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer from keras_hub.src.models.image_classifier import ImageClassifier -from keras_hub.src.models.image_classifier_preprocessor import ImageClassifierPreprocessor +from keras_hub.src.models.image_classifier_preprocessor import ( + ImageClassifierPreprocessor, +) from keras_hub.src.models.image_object_detector import ImageObjectDetector -from keras_hub.src.models.image_object_detector_preprocessor import ImageObjectDetectorPreprocessor +from keras_hub.src.models.image_object_detector_preprocessor import ( + ImageObjectDetectorPreprocessor, +) from keras_hub.src.models.image_segmenter import ImageSegmenter -from keras_hub.src.models.image_segmenter_preprocessor import ImageSegmenterPreprocessor +from keras_hub.src.models.image_segmenter_preprocessor import ( + ImageSegmenterPreprocessor, +) from keras_hub.src.models.image_to_image import ImageToImage from keras_hub.src.models.inpaint import Inpaint from keras_hub.src.models.llama.llama_backbone import LlamaBackbone from keras_hub.src.models.llama.llama_causal_lm import LlamaCausalLM -from 
keras_hub.src.models.llama.llama_causal_lm_preprocessor import LlamaCausalLMPreprocessor +from keras_hub.src.models.llama.llama_causal_lm_preprocessor import ( + LlamaCausalLMPreprocessor, +) from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone from keras_hub.src.models.llama3.llama3_causal_lm import Llama3CausalLM -from keras_hub.src.models.llama3.llama3_causal_lm_preprocessor import Llama3CausalLMPreprocessor +from keras_hub.src.models.llama3.llama3_causal_lm_preprocessor import ( + Llama3CausalLMPreprocessor, +) from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer from keras_hub.src.models.masked_lm import MaskedLM from keras_hub.src.models.masked_lm_preprocessor import MaskedLMPreprocessor from keras_hub.src.models.mistral.mistral_backbone import MistralBackbone from keras_hub.src.models.mistral.mistral_causal_lm import MistralCausalLM -from keras_hub.src.models.mistral.mistral_causal_lm_preprocessor import MistralCausalLMPreprocessor +from keras_hub.src.models.mistral.mistral_causal_lm_preprocessor import ( + MistralCausalLMPreprocessor, +) from keras_hub.src.models.mistral.mistral_tokenizer import MistralTokenizer from keras_hub.src.models.mit.mit_backbone import MiTBackbone from keras_hub.src.models.mit.mit_image_classifier import MiTImageClassifier -from keras_hub.src.models.mit.mit_image_classifier_preprocessor import MiTImageClassifierPreprocessor +from keras_hub.src.models.mit.mit_image_classifier_preprocessor import ( + MiTImageClassifierPreprocessor, +) from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone -from keras_hub.src.models.mobilenet.mobilenet_image_classifier import MobileNetImageClassifier +from keras_hub.src.models.mobilenet.mobilenet_image_classifier import ( + MobileNetImageClassifier, +) from keras_hub.src.models.opt.opt_backbone import OPTBackbone from keras_hub.src.models.opt.opt_causal_lm import OPTCausalLM -from keras_hub.src.models.opt.opt_causal_lm_preprocessor import OPTCausalLMPreprocessor +from keras_hub.src.models.opt.opt_causal_lm_preprocessor import ( + OPTCausalLMPreprocessor, +) from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer -from keras_hub.src.models.pali_gemma.pali_gemma_backbone import PaliGemmaBackbone -from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm import PaliGemmaCausalLM -from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm_preprocessor import PaliGemmaCausalLMPreprocessor -from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import PaliGemmaTokenizer +from keras_hub.src.models.pali_gemma.pali_gemma_backbone import ( + PaliGemmaBackbone, +) +from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm import ( + PaliGemmaCausalLM, +) +from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm_preprocessor import ( + PaliGemmaCausalLMPreprocessor, +) +from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import ( + PaliGemmaTokenizer, +) from keras_hub.src.models.phi3.phi3_backbone import Phi3Backbone from keras_hub.src.models.phi3.phi3_causal_lm import Phi3CausalLM -from keras_hub.src.models.phi3.phi3_causal_lm_preprocessor import Phi3CausalLMPreprocessor +from keras_hub.src.models.phi3.phi3_causal_lm_preprocessor import ( + Phi3CausalLMPreprocessor, +) from keras_hub.src.models.phi3.phi3_tokenizer import Phi3Tokenizer from keras_hub.src.models.preprocessor import Preprocessor from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone -from 
keras_hub.src.models.resnet.resnet_image_classifier import ResNetImageClassifier -from keras_hub.src.models.resnet.resnet_image_classifier_preprocessor import ResNetImageClassifierPreprocessor +from keras_hub.src.models.resnet.resnet_image_classifier import ( + ResNetImageClassifier, +) +from keras_hub.src.models.resnet.resnet_image_classifier_preprocessor import ( + ResNetImageClassifierPreprocessor, +) from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone -from keras_hub.src.models.retinanet.retinanet_object_detector import RetinaNetObjectDetector -from keras_hub.src.models.retinanet.retinanet_object_detector_preprocessor import RetinaNetObjectDetectorPreprocessor +from keras_hub.src.models.retinanet.retinanet_object_detector import ( + RetinaNetObjectDetector, +) +from keras_hub.src.models.retinanet.retinanet_object_detector_preprocessor import ( + RetinaNetObjectDetectorPreprocessor, +) from keras_hub.src.models.roberta.roberta_backbone import RobertaBackbone from keras_hub.src.models.roberta.roberta_masked_lm import RobertaMaskedLM -from keras_hub.src.models.roberta.roberta_masked_lm_preprocessor import RobertaMaskedLMPreprocessor -from keras_hub.src.models.roberta.roberta_text_classifier import RobertaTextClassifier -from keras_hub.src.models.roberta.roberta_text_classifier import RobertaTextClassifier as RobertaClassifier -from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import RobertaTextClassifierPreprocessor -from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import RobertaTextClassifierPreprocessor as RobertaPreprocessor +from keras_hub.src.models.roberta.roberta_masked_lm_preprocessor import ( + RobertaMaskedLMPreprocessor, +) +from keras_hub.src.models.roberta.roberta_text_classifier import ( + RobertaTextClassifier, +) +from keras_hub.src.models.roberta.roberta_text_classifier import ( + RobertaTextClassifier as RobertaClassifier, +) +from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import ( + RobertaTextClassifierPreprocessor, +) +from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import ( + RobertaTextClassifierPreprocessor as RobertaPreprocessor, +) from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer from keras_hub.src.models.sam.sam_backbone import SAMBackbone from keras_hub.src.models.sam.sam_image_segmenter import SAMImageSegmenter -from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import SAMImageSegmenterPreprocessor +from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import ( + SAMImageSegmenterPreprocessor, +) from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone -from keras_hub.src.models.segformer.segformer_image_segmenter import SegFormerImageSegmenter -from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import SegFormerImageSegmenterPreprocessor +from keras_hub.src.models.segformer.segformer_image_segmenter import ( + SegFormerImageSegmenter, +) +from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import ( + SegFormerImageSegmenterPreprocessor, +) from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor -from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import StableDiffusion3Backbone -from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_image_to_image import StableDiffusion3ImageToImage -from 
keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_inpaint import StableDiffusion3Inpaint -from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image import StableDiffusion3TextToImage -from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import StableDiffusion3TextToImagePreprocessor +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import ( + StableDiffusion3Backbone, +) +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_image_to_image import ( + StableDiffusion3ImageToImage, +) +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_inpaint import ( + StableDiffusion3Inpaint, +) +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image import ( + StableDiffusion3TextToImage, +) +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import ( + StableDiffusion3TextToImagePreprocessor, +) from keras_hub.src.models.t5.t5_backbone import T5Backbone from keras_hub.src.models.t5.t5_preprocessor import T5Preprocessor from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer from keras_hub.src.models.task import Task from keras_hub.src.models.text_classifier import TextClassifier from keras_hub.src.models.text_classifier import TextClassifier as Classifier -from keras_hub.src.models.text_classifier_preprocessor import TextClassifierPreprocessor +from keras_hub.src.models.text_classifier_preprocessor import ( + TextClassifierPreprocessor, +) from keras_hub.src.models.text_to_image import TextToImage from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier -from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import VGGImageClassifierPreprocessor +from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import ( + VGGImageClassifierPreprocessor, +) from keras_hub.src.models.vit.vit_backbone import ViTBackbone from keras_hub.src.models.vit.vit_image_classifier import ViTImageClassifier -from keras_hub.src.models.vit.vit_image_classifier_preprocessor import ViTImageClassifierPreprocessor +from keras_hub.src.models.vit.vit_image_classifier_preprocessor import ( + ViTImageClassifierPreprocessor, +) from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer -from keras_hub.src.models.xlm_roberta.xlm_roberta_backbone import XLMRobertaBackbone -from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm import XLMRobertaMaskedLM -from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm_preprocessor import XLMRobertaMaskedLMPreprocessor -from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier import XLMRobertaTextClassifier -from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier import XLMRobertaTextClassifier as XLMRobertaClassifier -from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import XLMRobertaTextClassifierPreprocessor -from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import XLMRobertaTextClassifierPreprocessor as XLMRobertaPreprocessor -from keras_hub.src.models.xlm_roberta.xlm_roberta_tokenizer import XLMRobertaTokenizer +from keras_hub.src.models.xlm_roberta.xlm_roberta_backbone import ( + XLMRobertaBackbone, +) +from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm import ( + 
XLMRobertaMaskedLM, +) +from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm_preprocessor import ( + XLMRobertaMaskedLMPreprocessor, +) +from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier import ( + XLMRobertaTextClassifier, +) +from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier import ( + XLMRobertaTextClassifier as XLMRobertaClassifier, +) +from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import ( + XLMRobertaTextClassifierPreprocessor, +) +from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import ( + XLMRobertaTextClassifierPreprocessor as XLMRobertaPreprocessor, +) +from keras_hub.src.models.xlm_roberta.xlm_roberta_tokenizer import ( + XLMRobertaTokenizer, +) from keras_hub.src.models.xlnet.xlnet_backbone import XLNetBackbone from keras_hub.src.tokenizers.tokenizer import Tokenizer diff --git a/keras_hub/src/models/basnet/__init__.py b/keras_hub/src/models/basnet/__init__.py index 8df584cab4..340f40d09a 100644 --- a/keras_hub/src/models/basnet/__init__.py +++ b/keras_hub/src/models/basnet/__init__.py @@ -1,5 +1,5 @@ -from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone +from keras_hub.src.models.basnet.basnet import BASNetImageSegmenter from keras_hub.src.models.basnet.basnet_presets import basnet_presets from keras_hub.src.utils.preset_utils import register_presets -register_presets(basnet_presets, BASNetBackbone) \ No newline at end of file +register_presets(basnet_presets, BASNetImageSegmenter) diff --git a/keras_hub/src/models/basnet/basnet.py b/keras_hub/src/models/basnet/basnet.py index 8fc304d073..420eee6185 100644 --- a/keras_hub/src/models/basnet/basnet.py +++ b/keras_hub/src/models/basnet/basnet.py @@ -1,9 +1,12 @@ import keras from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor from keras_hub.src.models.image_segmenter import ImageSegmenter +from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_hub.src.models.resnet.resnet_backbone import ( + apply_basic_block as resnet_basic_block, +) @keras_hub_export("keras_hub.models.BASNetImageSegmenter") @@ -13,13 +16,24 @@ class BASNetImageSegmenter(ImageSegmenter): segmentation. References: - [BASNet: Boundary-Aware Segmentation Network for Mobile and Web Applications](https://arxiv.org/abs/2101.04704) + - [BASNet: Boundary-Aware Segmentation Network for Mobile and Web Applications](https://arxiv.org/abs/2101.04704) Args: - backbone: A `keras_hub.models.BASNetBackbone` instance. - preprocessor: `None`, a `keras_hub.models.Preprocessor` instance, - a `keras.Layer` instance, or a callable. If `None` no preprocessing - will be applied to the inputs. + backbone: `keras.Model`. The backbone network for the model that is + used as a feature extractor for BASNet prediction encoder. Currently + supported backbones are ResNet18 and ResNet34. + (Note: Do not specify `image_shape` within the backbone. + Please provide these while initializing the 'BASNet' model.) + num_classes: int, the number of classes for the segmentation model. + image_shape: optional shape tuple, defaults to (None, None, 3). + projection_filters: int, number of filters in the convolution layer + projecting low-level features from the `backbone`. + prediction_heads: (Optional) List of `keras.layers.Layer` defining + the prediction module head for the model. 
If not provided, a + default head is created with a Conv2D layer followed by resizing. + refinement_head: (Optional) a `keras.layers.Layer` defining the + refinement module head for the model. If not provided, a default + head is created with a Conv2D layer. Example: ```python @@ -28,21 +42,22 @@ class BASNetImageSegmenter(ImageSegmenter): images = np.ones(shape=(1, 288, 288, 3)) labels = np.zeros(shape=(1, 288, 288, 1)) - image_encoder = keras_hub.models.ResNetBackbone.from_preset( + # Note: Do not specify `image_shape` within the backbone. + backbone = keras_hub.models.ResNetBackbone.from_preset( "resnet_18_imagenet", load_weights=False ) - backbone = keras_hub.models.BASNetBackbone( - image_encoder, + model = keras_hub.models.BASNetImageSegmenter( + backbone=backbone, num_classes=1, - image_shape=[288, 288, 3] + image_shape=[288, 288, 3], ) - model = keras_hub.models.BASNetImageSegmenter(backbone) - # Evaluate the model - pred_labels = model(images) + # Evaluate model + output = model(images) + pred_labels = output[0] - # Train the model + # Train model model.compile( optimizer="adam", loss=keras.losses.BinaryCrossentropy(from_logits=False), @@ -52,36 +67,98 @@ class BASNetImageSegmenter(ImageSegmenter): ``` """ - backbone_cls = BASNetBackbone + backbone_cls = ResNetBackbone preprocessor_cls = BASNetPreprocessor def __init__( self, backbone, + num_classes, + image_shape=(None, None, 3), + projection_filters=64, + prediction_heads=None, + refinement_head=None, preprocessor=None, **kwargs, ): + if not isinstance(backbone, keras.layers.Layer) or not isinstance( + backbone, keras.Model + ): + raise ValueError( + "Argument `backbone` must be a `keras.layers.Layer` instance" + f" or `keras.Model`. Received instead" + f" backbone={backbone} (of type {type(backbone)})." + ) + + if tuple(backbone.image_shape) != (None, None, 3): + raise ValueError( + "Do not specify `image_shape` within the" + " 'BASNet' backbone. \nPlease provide `image_shape`" + " while initializing the 'BASNet' model." + ) # === Functional Model === - x = backbone.input - outputs = backbone(x) - # only return the refinement module's output as final prediction - outputs = outputs["refine_out"] - super().__init__(inputs=x, outputs=outputs, **kwargs) + inputs = keras.layers.Input(shape=image_shape) + x = inputs + + if prediction_heads is None: + prediction_heads = [] + for size in (1, 2, 4, 8, 16, 32, 32): + head_layers = [ + keras.layers.Conv2D( + num_classes, kernel_size=(3, 3), padding="same" + ) + ] + if size != 1: + head_layers.append( + keras.layers.UpSampling2D( + size=size, interpolation="bilinear" + ) + ) + prediction_heads.append(keras.Sequential(head_layers)) + + if refinement_head is None: + refinement_head = keras.Sequential( + [ + keras.layers.Conv2D( + num_classes, kernel_size=(3, 3), padding="same" + ), + ] + ) + + # Prediction model. + predict_model = basnet_predict( + x, backbone, projection_filters, prediction_heads + ) + + # Refinement model. + refine_model = basnet_rrm( + predict_model, projection_filters, refinement_head + ) + + outputs = refine_model.outputs # Combine outputs. 
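+        # The refinement output is listed first, followed by the coarse side
+        # outputs from the prediction module (one per decoder stage plus the bridge).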
+ outputs.extend(predict_model.outputs) + + output_names = ["refine_out"] + [ + f"predict_out_{i}" for i in range(1, len(outputs)) + ] + + outputs = [ + keras.layers.Activation("sigmoid", name=output_name)(output) + for output, output_name in zip(outputs, output_names) + ] + + super().__init__(inputs=inputs, outputs=outputs, **kwargs) # === Config === self.backbone = backbone + self.num_classes = num_classes + self.image_shape = image_shape + self.projection_filters = projection_filters + self.prediction_heads = prediction_heads + self.refinement_head = refinement_head self.preprocessor = preprocessor - def compute_loss(self, x, y, y_pred, *args, **kwargs): - # train BASNet's prediction and refinement module outputs against the same ground truth data - outputs = self.backbone(x) - losses = [] - for output in outputs.values(): - losses.append(super().compute_loss(x, y, output, *args, **kwargs)) - loss = keras.ops.sum(losses, axis=0) - return loss - def compile( self, optimizer="auto", @@ -119,10 +196,219 @@ def compile( if loss == "auto": loss = keras.losses.BinaryCrossentropy() if metrics == "auto": - metrics = [keras.metrics.Accuracy()] + metrics = {"refine_out": keras.metrics.Accuracy()} super().compile( optimizer=optimizer, loss=loss, metrics=metrics, **kwargs, - ) \ No newline at end of file + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "image_shape": self.image_shape, + "projection_filters": self.projection_filters, + "prediction_heads": [ + keras.saving.serialize_keras_object(prediction_head) + for prediction_head in self.prediction_heads + ], + "refinement_head": keras.saving.serialize_keras_object( + self.refinement_head + ), + } + ) + return config + + @classmethod + def from_config(cls, config): + if "prediction_heads" in config and isinstance( + config["prediction_heads"], list + ): + for i in range(len(config["prediction_heads"])): + if isinstance(config["prediction_heads"][i], dict): + config["prediction_heads"][i] = keras.layers.deserialize( + config["prediction_heads"][i] + ) + + if "refinement_head" in config and isinstance( + config["refinement_head"], dict + ): + config["refinement_head"] = keras.layers.deserialize( + config["refinement_head"] + ) + return super().from_config(config) + + +def convolution_block(x_input, filters, dilation=1): + """ + Apply convolution + batch normalization + ReLU activation. + + Args: + x_input: Input keras tensor. + filters: int, number of output filters in the convolution. + dilation: int, dilation rate for the convolution operation. + Defaults to 1. + + Returns: + A tensor with convolution, batch normalization, and ReLU + activation applied. + """ + x = keras.layers.Conv2D( + filters, (3, 3), padding="same", dilation_rate=dilation + )(x_input) + x = keras.layers.BatchNormalization()(x) + return keras.layers.Activation("relu")(x) + + +def get_resnet_block(_resnet, block_num): + """ + Extract and return a specific ResNet block. + + Args: + _resnet: `keras.Model`. ResNet model instance. + block_num: int, block number to extract. + + Returns: + A Keras Model representing the specified ResNet block. 
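+
+    Example:
+    ```python
+    # Illustrative sketch (not part of the public API): reuse the second
+    # ResNet stage as a standalone sub-model. Preset name follows the class
+    # docstring example above.
+    resnet = keras_hub.models.ResNetBackbone.from_preset(
+        "resnet_18_imagenet", load_weights=False
+    )
+    stage_2 = get_resnet_block(resnet, 1)
+    ```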
+ """ + + extractor_levels = ["P2", "P3", "P4", "P5"] + num_blocks = _resnet.stackwise_num_blocks + if block_num == 0: + x = _resnet.get_layer("pool1_pool").output + else: + x = _resnet.pyramid_outputs[extractor_levels[block_num - 1]] + y = _resnet.get_layer( + f"stack{block_num}_block{num_blocks[block_num]-1}_add" + ).output + return keras.models.Model( + inputs=x, + outputs=y, + name=f"resnet_block{block_num + 1}", + ) + + +def basnet_predict(x_input, backbone, filters, segmentation_heads): + """ + BASNet Prediction Module. + + This module outputs a coarse label map by integrating heavy + encoder, bridge, and decoder blocks. + + Args: + x_input: Input keras tensor. + backbone: `keras.Model`. The backbone network used as a feature + extractor for BASNet prediction encoder. + filters: int, the number of filters. + segmentation_heads: List of `keras.layers.Layer`, A list of Keras + layers serving as the segmentation head for prediction module. + + + Returns: + A Keras Model that integrates the encoder, bridge, and decoder + blocks for coarse label map prediction. + """ + num_stages = 6 + + x = x_input + + # -------------Encoder-------------- + x = keras.layers.Conv2D(filters, kernel_size=(3, 3), padding="same")(x) + + encoder_blocks = [] + for i in range(num_stages): + if i < 4: # First four stages are adopted from ResNet backbone. + x = get_resnet_block(backbone, i)(x) + encoder_blocks.append(x) + else: # Last 2 stages consist of three basic resnet blocks. + x = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(x) + for j in range(3): + x = resnet_basic_block( + x, + filters=x.shape[3], + conv_shortcut=False, + name=f"v1_basic_block_{i + 1}_{j + 1}", + ) + encoder_blocks.append(x) + + # -------------Bridge------------- + x = convolution_block(x, filters=filters * 8, dilation=2) + x = convolution_block(x, filters=filters * 8, dilation=2) + x = convolution_block(x, filters=filters * 8, dilation=2) + encoder_blocks.append(x) + + # -------------Decoder------------- + decoder_blocks = [] + for i in reversed(range(num_stages)): + if i != (num_stages - 1): # Except first, scale other decoder stages. + x = keras.layers.UpSampling2D(size=2, interpolation="bilinear")(x) + + x = keras.layers.concatenate([encoder_blocks[i], x], axis=-1) + x = convolution_block(x, filters=filters * 8) + x = convolution_block(x, filters=filters * 8) + x = convolution_block(x, filters=filters * 8) + decoder_blocks.append(x) + + decoder_blocks.reverse() # Change order from last to first decoder stage. + decoder_blocks.append(encoder_blocks[-1]) # Copy bridge to decoder. + + # -------------Side Outputs-------------- + decoder_blocks = [ + segmentation_head(decoder_block) # Prediction segmentation head. + for segmentation_head, decoder_block in zip( + segmentation_heads, decoder_blocks + ) + ] + + return keras.models.Model(inputs=[x_input], outputs=decoder_blocks) + + +def basnet_rrm(base_model, filters, segmentation_head): + """ + BASNet Residual Refinement Module (RRM). + + This module outputs a fine label map by integrating light encoder, + bridge, and decoder blocks. + + Args: + base_model: Keras model used as the base or coarse label map. + filters: int, the number of filters. + segmentation_head: a `keras.layers.Layer`, A Keras layer serving + as the segmentation head for refinement module. + + Returns: + A Keras Model that constructs the Residual Refinement Module (RRM). 
+ """ + num_stages = 4 + + x_input = base_model.output[0] + + # -------------Encoder-------------- + x = keras.layers.Conv2D(filters, kernel_size=(3, 3), padding="same")( + x_input + ) + + encoder_blocks = [] + for _ in range(num_stages): + x = convolution_block(x, filters=filters) + encoder_blocks.append(x) + x = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(x) + + # -------------Bridge-------------- + x = convolution_block(x, filters=filters) + + # -------------Decoder-------------- + for i in reversed(range(num_stages)): + x = keras.layers.UpSampling2D(size=2, interpolation="bilinear")(x) + x = keras.layers.concatenate([encoder_blocks[i], x], axis=-1) + x = convolution_block(x, filters=filters) + + x = segmentation_head(x) # Refinement segmentation head. + + # ------------- refined = coarse + residual + x = keras.layers.Add()([x_input, x]) # Add prediction + refinement output + + return keras.models.Model(inputs=base_model.input, outputs=[x]) diff --git a/keras_hub/src/models/basnet/basnet_backbone.py b/keras_hub/src/models/basnet/basnet_backbone.py deleted file mode 100644 index 29599d1e30..0000000000 --- a/keras_hub/src/models/basnet/basnet_backbone.py +++ /dev/null @@ -1,370 +0,0 @@ -import keras - -from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.models.backbone import Backbone -from keras_hub.src.models.resnet.resnet_backbone import ( - apply_basic_block as resnet_basic_block, -) - - -@keras_hub_export("keras_hub.models.BASNetBackbone") -class BASNetBackbone(Backbone): - """ - A Keras model implementing the BASNet architecture for semantic - segmentation. - - References: - - [BASNet: Boundary-Aware Segmentation Network for Mobile and Web Applications](https://arxiv.org/abs/2101.04704) - - Args: - image_encoder: A `keras_hub.models.ResNetBackbone` instance. The - backbone network for the model that is used as a feature extractor - for BASNet prediction encoder. Currently supported backbones are - ResNet18 and ResNet34. - (Note: Do not specify `image_shape` within the backbone. - Please provide these while initializing the 'BASNetBackbone' model) - num_classes: int, the number of classes for the segmentation model. - image_shape: optional shape tuple, defaults to (None, None, 3). - projection_filters: int, number of filters in the convolution layer - projecting low-level features from the `backbone`. - prediction_heads: (Optional) List of `keras.layers.Layer` defining - the prediction module head for the model. If not provided, a - default head is created with a Conv2D layer followed by resizing. - refinement_head: (Optional) a `keras.layers.Layer` defining the - refinement module head for the model. If not provided, a default - head is created with a Conv2D layer. - dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype - to use for the model's computations and weights. - - """ - - def __init__( - self, - image_encoder, - num_classes, - image_shape=(None, None, 3), - projection_filters=64, - prediction_heads=None, - refinement_head=None, - dtype=None, - **kwargs, - ): - if not isinstance(image_encoder, keras.layers.Layer) or not isinstance( - image_encoder, keras.Model - ): - raise ValueError( - "Argument `image_encoder` must be a `keras.layers.Layer` instance" - f" or `keras.Model`. Received instead" - f" image_encoder={image_encoder} (of type {type(image_encoder)})." - ) - - if tuple(image_encoder.image_shape) != (None, None, 3): - raise ValueError( - "Do not specify `image_shape` within the" - " `BASNetBackbone`'s image_encoder. 
\nPlease provide `image_shape`" - " while initializing the 'BASNetBackbone' model." - ) - - # === Functional Model === - inputs = keras.layers.Input(shape=image_shape) - x = inputs - - if prediction_heads is None: - prediction_heads = [] - for size in (1, 2, 4, 8, 16, 32, 32): - head_layers = [ - keras.layers.Conv2D( - num_classes, - kernel_size=(3, 3), - padding="same", - dtype=dtype, - ) - ] - if size != 1: - head_layers.append( - keras.layers.UpSampling2D( - size=size, interpolation="bilinear", dtype=dtype - ) - ) - prediction_heads.append(keras.Sequential(head_layers)) - - if refinement_head is None: - refinement_head = keras.Sequential( - [ - keras.layers.Conv2D( - num_classes, - kernel_size=(3, 3), - padding="same", - dtype=dtype, - ), - ] - ) - - # Prediction model. - predict_model = basnet_predict( - x, image_encoder, projection_filters, prediction_heads, dtype=dtype - ) - - # Refinement model. - refine_model = basnet_rrm( - predict_model, projection_filters, refinement_head, dtype=dtype - ) - - outputs = refine_model.outputs # Combine outputs. - outputs.extend(predict_model.outputs) - - output_names = ["refine_out"] + [ - f"predict_out_{i}" for i in range(1, len(outputs)) - ] - - outputs = { - output_name: keras.layers.Activation( - "sigmoid", name=output_name, dtype=dtype - )(output) - for output, output_name in zip(outputs, output_names) - } - - super().__init__(inputs=inputs, outputs=outputs, dtype=dtype, **kwargs) - - # === Config === - self.image_encoder = image_encoder - self.num_classes = num_classes - self.image_shape = image_shape - self.projection_filters = projection_filters - self.prediction_heads = prediction_heads - self.refinement_head = refinement_head - - def get_config(self): - config = super().get_config() - config.update( - { - "image_encoder": keras.saving.serialize_keras_object( - self.image_encoder - ), - "num_classes": self.num_classes, - "image_shape": self.image_shape, - "projection_filters": self.projection_filters, - "prediction_heads": [ - keras.saving.serialize_keras_object(prediction_head) - for prediction_head in self.prediction_heads - ], - "refinement_head": keras.saving.serialize_keras_object( - self.refinement_head - ), - } - ) - return config - - @classmethod - def from_config(cls, config): - if "image_encoder" in config: - config["image_encoder"] = keras.layers.deserialize( - config["image_encoder"] - ) - if "prediction_heads" in config and isinstance( - config["prediction_heads"], list - ): - for i in range(len(config["prediction_heads"])): - if isinstance(config["prediction_heads"][i], dict): - config["prediction_heads"][i] = keras.layers.deserialize( - config["prediction_heads"][i] - ) - - if "refinement_head" in config and isinstance( - config["refinement_head"], dict - ): - config["refinement_head"] = keras.layers.deserialize( - config["refinement_head"] - ) - return super().from_config(config) - - -def convolution_block(x_input, filters, dilation=1, dtype=None): - """ - Apply convolution + batch normalization + ReLU activation. - - Args: - x_input: Input keras tensor. - filters: int, number of output filters in the convolution. - dilation: int, dilation rate for the convolution operation. - Defaults to 1. - dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype - to use for the model's computations and weights. - - Returns: - A tensor with convolution, batch normalization, and ReLU - activation applied. 
- """ - x = keras.layers.Conv2D( - filters, (3, 3), padding="same", dilation_rate=dilation, dtype=dtype - )(x_input) - x = keras.layers.BatchNormalization(dtype=dtype)(x) - return keras.layers.Activation("relu", dtype=dtype)(x) - - -def get_resnet_block(_resnet, block_num): - """ - Extract and return a specific ResNet block. - - Args: - _resnet: `keras.Model`. ResNet model instance. - block_num: int, block number to extract. - - Returns: - A Keras Model representing the specified ResNet block. - """ - - extractor_levels = ["P2", "P3", "P4", "P5"] - num_blocks = _resnet.stackwise_num_blocks - if block_num == 0: - x = _resnet.get_layer("pool1_pool").output - else: - x = _resnet.pyramid_outputs[extractor_levels[block_num - 1]] - y = _resnet.get_layer( - f"stack{block_num}_block{num_blocks[block_num]-1}_add" - ).output - return keras.models.Model( - inputs=x, - outputs=y, - name=f"resnet_block{block_num + 1}", - ) - - -def basnet_predict(x_input, backbone, filters, segmentation_heads, dtype=None): - """ - BASNet Prediction Module. - - This module outputs a coarse label map by integrating heavy - encoder, bridge, and decoder blocks. - - Args: - x_input: Input keras tensor. - backbone: `keras.Model`. The backbone network used as a feature - extractor for BASNet prediction encoder. - filters: int, the number of filters. - segmentation_heads: List of `keras.layers.Layer`, A list of Keras - layers serving as the segmentation head for prediction module. - dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype - to use for the model's computations and weights. - - - Returns: - A Keras Model that integrates the encoder, bridge, and decoder - blocks for coarse label map prediction. - """ - num_stages = 6 - - x = x_input - - # -------------Encoder-------------- - x = keras.layers.Conv2D( - filters, kernel_size=(3, 3), padding="same", dtype=dtype - )(x) - - encoder_blocks = [] - for i in range(num_stages): - if i < 4: # First four stages are adopted from ResNet backbone. - x = get_resnet_block(backbone, i)(x) - encoder_blocks.append(x) - else: # Last 2 stages consist of three basic resnet blocks. - x = keras.layers.MaxPool2D( - pool_size=(2, 2), strides=(2, 2), dtype=dtype - )(x) - for j in range(3): - x = resnet_basic_block( - x, - filters=x.shape[3], - conv_shortcut=False, - name=f"v1_basic_block_{i + 1}_{j + 1}", - dtype=dtype, - ) - encoder_blocks.append(x) - - # -------------Bridge------------- - x = convolution_block(x, filters=filters * 8, dilation=2, dtype=dtype) - x = convolution_block(x, filters=filters * 8, dilation=2, dtype=dtype) - x = convolution_block(x, filters=filters * 8, dilation=2, dtype=dtype) - encoder_blocks.append(x) - - # -------------Decoder------------- - decoder_blocks = [] - for i in reversed(range(num_stages)): - if i != (num_stages - 1): # Except first, scale other decoder stages. - x = keras.layers.UpSampling2D( - size=2, interpolation="bilinear", dtype=dtype - )(x) - - x = keras.layers.concatenate([encoder_blocks[i], x], axis=-1) - x = convolution_block(x, filters=filters * 8, dtype=dtype) - x = convolution_block(x, filters=filters * 8, dtype=dtype) - x = convolution_block(x, filters=filters * 8, dtype=dtype) - decoder_blocks.append(x) - - decoder_blocks.reverse() # Change order from last to first decoder stage. - decoder_blocks.append(encoder_blocks[-1]) # Copy bridge to decoder. - - # -------------Side Outputs-------------- - decoder_blocks = [ - segmentation_head(decoder_block) # Prediction segmentation head. 
- for segmentation_head, decoder_block in zip( - segmentation_heads, decoder_blocks - ) - ] - - return keras.models.Model(inputs=[x_input], outputs=decoder_blocks) - - -def basnet_rrm(base_model, filters, segmentation_head, dtype=None): - """ - BASNet Residual Refinement Module (RRM). - - This module outputs a fine label map by integrating light encoder, - bridge, and decoder blocks. - - Args: - base_model: Keras model used as the base or coarse label map. - filters: int, the number of filters. - segmentation_head: a `keras.layers.Layer`, A Keras layer serving - as the segmentation head for refinement module. - dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype - to use for the model's computations and weights. - - Returns: - A Keras Model that constructs the Residual Refinement Module (RRM). - """ - num_stages = 4 - - x_input = base_model.output[0] - - # -------------Encoder-------------- - x = keras.layers.Conv2D( - filters, kernel_size=(3, 3), padding="same", dtype=dtype - )(x_input) - - encoder_blocks = [] - for _ in range(num_stages): - x = convolution_block(x, filters=filters) - encoder_blocks.append(x) - x = keras.layers.MaxPool2D( - pool_size=(2, 2), strides=(2, 2), dtype=dtype - )(x) - - # -------------Bridge-------------- - x = convolution_block(x, filters=filters, dtype=dtype) - - # -------------Decoder-------------- - for i in reversed(range(num_stages)): - x = keras.layers.UpSampling2D( - size=2, interpolation="bilinear", dtype=dtype - )(x) - x = keras.layers.concatenate([encoder_blocks[i], x], axis=-1) - x = convolution_block(x, filters=filters) - - x = segmentation_head(x) # Refinement segmentation head. - - # ------------- refined = coarse + residual - x = keras.layers.Add(dtype=dtype)( - [x_input, x] - ) # Add prediction + refinement output - - return keras.models.Model(inputs=base_model.input, outputs=[x]) \ No newline at end of file diff --git a/keras_hub/src/models/basnet/basnet_backbone_test.py b/keras_hub/src/models/basnet/basnet_backbone_test.py deleted file mode 100644 index ca7efe3a3b..0000000000 --- a/keras_hub/src/models/basnet/basnet_backbone_test.py +++ /dev/null @@ -1,36 +0,0 @@ -from keras import ops - -from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone -from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone -from keras_hub.src.tests.test_case import TestCase - - -class BASNetBackboneTest(TestCase): - def setUp(self): - self.images = ops.ones((2, 64, 64, 3)) - self.image_encoder = ResNetBackbone( - input_conv_filters=[64], - input_conv_kernel_sizes=[7], - stackwise_num_filters=[64, 128, 256, 512], - stackwise_num_blocks=[2, 2, 2, 2], - stackwise_num_strides=[1, 2, 2, 2], - block_type="basic_block", - ) - self.init_kwargs = { - "image_encoder": self.image_encoder, - "num_classes": 1, - } - - def test_backbone_basics(self): - output_names = ["refine_out"] + [ - f"predict_out_{i}" for i in range(1, 8) - ] - expected_output_shape = {name: (2, 64, 64, 1) for name in output_names} - self.run_backbone_test( - cls=BASNetBackbone, - init_kwargs=self.init_kwargs, - input_data=self.images, - expected_output_shape=expected_output_shape, - run_mixed_precision_check=False, - run_quantization_check=False, - ) \ No newline at end of file diff --git a/keras_hub/src/models/basnet/basnet_image_converter.py b/keras_hub/src/models/basnet/basnet_image_converter.py index 347447526e..45f364fd4e 100644 --- a/keras_hub/src/models/basnet/basnet_image_converter.py +++ b/keras_hub/src/models/basnet/basnet_image_converter.py @@ -1,8 
+1,8 @@ from keras_hub.src.api_export import keras_hub_export from keras_hub.src.layers.preprocessing.image_converter import ImageConverter -from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone +from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone @keras_hub_export("keras_hub.layers.BASNetImageConverter") class BASNetImageConverter(ImageConverter): - backbone_cls = BASNetBackbone \ No newline at end of file + backbone_cls = ResNetBackbone diff --git a/keras_hub/src/models/basnet/basnet_preprocessor.py b/keras_hub/src/models/basnet/basnet_preprocessor.py index 9fdb855e73..7faa64cdd9 100644 --- a/keras_hub/src/models/basnet/basnet_preprocessor.py +++ b/keras_hub/src/models/basnet/basnet_preprocessor.py @@ -1,14 +1,14 @@ from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone from keras_hub.src.models.basnet.basnet_image_converter import ( BASNetImageConverter, ) from keras_hub.src.models.image_segmenter_preprocessor import ( ImageSegmenterPreprocessor, ) +from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone @keras_hub_export("keras_hub.models.BASNetPreprocessor") class BASNetPreprocessor(ImageSegmenterPreprocessor): - backbone_cls = BASNetBackbone - image_converter_cls = BASNetImageConverter \ No newline at end of file + backbone_cls = ResNetBackbone + image_converter_cls = BASNetImageConverter diff --git a/keras_hub/src/models/basnet/basnet_presets.py b/keras_hub/src/models/basnet/basnet_presets.py index 7a4b0a27d3..1ed728643b 100644 --- a/keras_hub/src/models/basnet/basnet_presets.py +++ b/keras_hub/src/models/basnet/basnet_presets.py @@ -19,4 +19,4 @@ }, "kaggle_handle": "", # TODO }, -} \ No newline at end of file +} diff --git a/keras_hub/src/models/basnet/basnet_test.py b/keras_hub/src/models/basnet/basnet_test.py index a0af8e8e4e..fa7c1c9b64 100644 --- a/keras_hub/src/models/basnet/basnet_test.py +++ b/keras_hub/src/models/basnet/basnet_test.py @@ -2,7 +2,6 @@ from keras import ops from keras_hub.src.models.basnet.basnet import BASNetImageSegmenter -from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone from keras_hub.src.tests.test_case import TestCase @@ -14,7 +13,7 @@ def setUp(self): self.labels = ops.concatenate( (ops.zeros((2, 32, 64, 1)), ops.ones((2, 32, 64, 1))), axis=1 ) - self.image_encoder = ResNetBackbone( + self.backbone = ResNetBackbone( input_conv_filters=[64], input_conv_kernel_sizes=[7], stackwise_num_filters=[64, 128, 256, 512], @@ -22,14 +21,11 @@ def setUp(self): stackwise_num_strides=[1, 2, 2, 2], block_type="basic_block", ) - self.backbone = BASNetBackbone( - image_encoder=self.image_encoder, - num_classes=1, - ) self.preprocessor = BASNetPreprocessor() self.init_kwargs = { "backbone": self.backbone, "preprocessor": self.preprocessor, + "num_classes": 1, } self.train_data = (self.images, self.labels) @@ -38,7 +34,7 @@ def test_basics(self): cls=BASNetImageSegmenter, init_kwargs=self.init_kwargs, train_data=self.train_data, - expected_output_shape=(2, 64, 64, 1), + expected_output_shape=[(2, 64, 64, 1)] * 8, ) @pytest.mark.large @@ -51,8 +47,10 @@ def test_saved_model(self): def test_end_to_end_model_predict(self): model = BASNetImageSegmenter(**self.init_kwargs) - output = model.predict(self.images) - self.assertAllEqual(output.shape, (2, 64, 64, 1)) + outputs = 
model.predict(self.images) + self.assertAllEqual( + [output.shape for output in outputs], [(2, 64, 64, 1)] * 8 + ) @pytest.mark.skip(reason="disabled until preset's been uploaded to Kaggle") @pytest.mark.extra_large @@ -62,5 +60,5 @@ def test_all_presets(self): cls=BASNetImageSegmenter, preset=preset, input_data=self.images, - expected_output_shape=(2, 64, 64, 1), - ) \ No newline at end of file + expected_output_shape=[(2, 64, 64, 1)] * 8, + ) From 8e4163af53c57ff8762af4f9f55ca8c2d4009717 Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Fri, 13 Dec 2024 15:09:58 -0800 Subject: [PATCH 07/15] separate backbone, compute_loss, fix tests --- keras_hub/src/models/basnet/__init__.py | 4 +- keras_hub/src/models/basnet/basnet.py | 348 ++-------------- .../src/models/basnet/basnet_backbone.py | 370 ++++++++++++++++++ .../src/models/basnet/basnet_backbone_test.py | 36 ++ .../models/basnet/basnet_image_converter.py | 4 +- .../src/models/basnet/basnet_preprocessor.py | 6 +- keras_hub/src/models/basnet/basnet_presets.py | 2 +- keras_hub/src/models/basnet/basnet_test.py | 20 +- 8 files changed, 456 insertions(+), 334 deletions(-) create mode 100644 keras_hub/src/models/basnet/basnet_backbone.py create mode 100644 keras_hub/src/models/basnet/basnet_backbone_test.py diff --git a/keras_hub/src/models/basnet/__init__.py b/keras_hub/src/models/basnet/__init__.py index 340f40d09a..8df584cab4 100644 --- a/keras_hub/src/models/basnet/__init__.py +++ b/keras_hub/src/models/basnet/__init__.py @@ -1,5 +1,5 @@ -from keras_hub.src.models.basnet.basnet import BASNetImageSegmenter +from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone from keras_hub.src.models.basnet.basnet_presets import basnet_presets from keras_hub.src.utils.preset_utils import register_presets -register_presets(basnet_presets, BASNetImageSegmenter) +register_presets(basnet_presets, BASNetBackbone) \ No newline at end of file diff --git a/keras_hub/src/models/basnet/basnet.py b/keras_hub/src/models/basnet/basnet.py index 420eee6185..8fc304d073 100644 --- a/keras_hub/src/models/basnet/basnet.py +++ b/keras_hub/src/models/basnet/basnet.py @@ -1,12 +1,9 @@ import keras from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor from keras_hub.src.models.image_segmenter import ImageSegmenter -from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone -from keras_hub.src.models.resnet.resnet_backbone import ( - apply_basic_block as resnet_basic_block, -) @keras_hub_export("keras_hub.models.BASNetImageSegmenter") @@ -16,24 +13,13 @@ class BASNetImageSegmenter(ImageSegmenter): segmentation. References: - - [BASNet: Boundary-Aware Segmentation Network for Mobile and Web Applications](https://arxiv.org/abs/2101.04704) + [BASNet: Boundary-Aware Segmentation Network for Mobile and Web Applications](https://arxiv.org/abs/2101.04704) Args: - backbone: `keras.Model`. The backbone network for the model that is - used as a feature extractor for BASNet prediction encoder. Currently - supported backbones are ResNet18 and ResNet34. - (Note: Do not specify `image_shape` within the backbone. - Please provide these while initializing the 'BASNet' model.) - num_classes: int, the number of classes for the segmentation model. - image_shape: optional shape tuple, defaults to (None, None, 3). 
- projection_filters: int, number of filters in the convolution layer - projecting low-level features from the `backbone`. - prediction_heads: (Optional) List of `keras.layers.Layer` defining - the prediction module head for the model. If not provided, a - default head is created with a Conv2D layer followed by resizing. - refinement_head: (Optional) a `keras.layers.Layer` defining the - refinement module head for the model. If not provided, a default - head is created with a Conv2D layer. + backbone: A `keras_hub.models.BASNetBackbone` instance. + preprocessor: `None`, a `keras_hub.models.Preprocessor` instance, + a `keras.Layer` instance, or a callable. If `None` no preprocessing + will be applied to the inputs. Example: ```python @@ -42,22 +28,21 @@ class BASNetImageSegmenter(ImageSegmenter): images = np.ones(shape=(1, 288, 288, 3)) labels = np.zeros(shape=(1, 288, 288, 1)) - # Note: Do not specify `image_shape` within the backbone. - backbone = keras_hub.models.ResNetBackbone.from_preset( + image_encoder = keras_hub.models.ResNetBackbone.from_preset( "resnet_18_imagenet", load_weights=False ) - model = keras_hub.models.BASNetImageSegmenter( - backbone=backbone, + backbone = keras_hub.models.BASNetBackbone( + image_encoder, num_classes=1, - image_shape=[288, 288, 3], + image_shape=[288, 288, 3] ) + model = keras_hub.models.BASNetImageSegmenter(backbone) - # Evaluate model - output = model(images) - pred_labels = output[0] + # Evaluate the model + pred_labels = model(images) - # Train model + # Train the model model.compile( optimizer="adam", loss=keras.losses.BinaryCrossentropy(from_logits=False), @@ -67,98 +52,36 @@ class BASNetImageSegmenter(ImageSegmenter): ``` """ - backbone_cls = ResNetBackbone + backbone_cls = BASNetBackbone preprocessor_cls = BASNetPreprocessor def __init__( self, backbone, - num_classes, - image_shape=(None, None, 3), - projection_filters=64, - prediction_heads=None, - refinement_head=None, preprocessor=None, **kwargs, ): - if not isinstance(backbone, keras.layers.Layer) or not isinstance( - backbone, keras.Model - ): - raise ValueError( - "Argument `backbone` must be a `keras.layers.Layer` instance" - f" or `keras.Model`. Received instead" - f" backbone={backbone} (of type {type(backbone)})." - ) - - if tuple(backbone.image_shape) != (None, None, 3): - raise ValueError( - "Do not specify `image_shape` within the" - " 'BASNet' backbone. \nPlease provide `image_shape`" - " while initializing the 'BASNet' model." - ) # === Functional Model === - inputs = keras.layers.Input(shape=image_shape) - x = inputs - - if prediction_heads is None: - prediction_heads = [] - for size in (1, 2, 4, 8, 16, 32, 32): - head_layers = [ - keras.layers.Conv2D( - num_classes, kernel_size=(3, 3), padding="same" - ) - ] - if size != 1: - head_layers.append( - keras.layers.UpSampling2D( - size=size, interpolation="bilinear" - ) - ) - prediction_heads.append(keras.Sequential(head_layers)) - - if refinement_head is None: - refinement_head = keras.Sequential( - [ - keras.layers.Conv2D( - num_classes, kernel_size=(3, 3), padding="same" - ), - ] - ) - - # Prediction model. - predict_model = basnet_predict( - x, backbone, projection_filters, prediction_heads - ) - - # Refinement model. - refine_model = basnet_rrm( - predict_model, projection_filters, refinement_head - ) - - outputs = refine_model.outputs # Combine outputs. 
- outputs.extend(predict_model.outputs) - - output_names = ["refine_out"] + [ - f"predict_out_{i}" for i in range(1, len(outputs)) - ] - - outputs = [ - keras.layers.Activation("sigmoid", name=output_name)(output) - for output, output_name in zip(outputs, output_names) - ] - - super().__init__(inputs=inputs, outputs=outputs, **kwargs) + x = backbone.input + outputs = backbone(x) + # only return the refinement module's output as final prediction + outputs = outputs["refine_out"] + super().__init__(inputs=x, outputs=outputs, **kwargs) # === Config === self.backbone = backbone - self.num_classes = num_classes - self.image_shape = image_shape - self.projection_filters = projection_filters - self.prediction_heads = prediction_heads - self.refinement_head = refinement_head self.preprocessor = preprocessor + def compute_loss(self, x, y, y_pred, *args, **kwargs): + # train BASNet's prediction and refinement module outputs against the same ground truth data + outputs = self.backbone(x) + losses = [] + for output in outputs.values(): + losses.append(super().compute_loss(x, y, output, *args, **kwargs)) + loss = keras.ops.sum(losses, axis=0) + return loss + def compile( self, optimizer="auto", @@ -196,219 +119,10 @@ def compile( if loss == "auto": loss = keras.losses.BinaryCrossentropy() if metrics == "auto": - metrics = {"refine_out": keras.metrics.Accuracy()} + metrics = [keras.metrics.Accuracy()] super().compile( optimizer=optimizer, loss=loss, metrics=metrics, **kwargs, - ) - - def get_config(self): - config = super().get_config() - config.update( - { - "num_classes": self.num_classes, - "image_shape": self.image_shape, - "projection_filters": self.projection_filters, - "prediction_heads": [ - keras.saving.serialize_keras_object(prediction_head) - for prediction_head in self.prediction_heads - ], - "refinement_head": keras.saving.serialize_keras_object( - self.refinement_head - ), - } - ) - return config - - @classmethod - def from_config(cls, config): - if "prediction_heads" in config and isinstance( - config["prediction_heads"], list - ): - for i in range(len(config["prediction_heads"])): - if isinstance(config["prediction_heads"][i], dict): - config["prediction_heads"][i] = keras.layers.deserialize( - config["prediction_heads"][i] - ) - - if "refinement_head" in config and isinstance( - config["refinement_head"], dict - ): - config["refinement_head"] = keras.layers.deserialize( - config["refinement_head"] - ) - return super().from_config(config) - - -def convolution_block(x_input, filters, dilation=1): - """ - Apply convolution + batch normalization + ReLU activation. - - Args: - x_input: Input keras tensor. - filters: int, number of output filters in the convolution. - dilation: int, dilation rate for the convolution operation. - Defaults to 1. - - Returns: - A tensor with convolution, batch normalization, and ReLU - activation applied. - """ - x = keras.layers.Conv2D( - filters, (3, 3), padding="same", dilation_rate=dilation - )(x_input) - x = keras.layers.BatchNormalization()(x) - return keras.layers.Activation("relu")(x) - - -def get_resnet_block(_resnet, block_num): - """ - Extract and return a specific ResNet block. - - Args: - _resnet: `keras.Model`. ResNet model instance. - block_num: int, block number to extract. - - Returns: - A Keras Model representing the specified ResNet block. 
- """ - - extractor_levels = ["P2", "P3", "P4", "P5"] - num_blocks = _resnet.stackwise_num_blocks - if block_num == 0: - x = _resnet.get_layer("pool1_pool").output - else: - x = _resnet.pyramid_outputs[extractor_levels[block_num - 1]] - y = _resnet.get_layer( - f"stack{block_num}_block{num_blocks[block_num]-1}_add" - ).output - return keras.models.Model( - inputs=x, - outputs=y, - name=f"resnet_block{block_num + 1}", - ) - - -def basnet_predict(x_input, backbone, filters, segmentation_heads): - """ - BASNet Prediction Module. - - This module outputs a coarse label map by integrating heavy - encoder, bridge, and decoder blocks. - - Args: - x_input: Input keras tensor. - backbone: `keras.Model`. The backbone network used as a feature - extractor for BASNet prediction encoder. - filters: int, the number of filters. - segmentation_heads: List of `keras.layers.Layer`, A list of Keras - layers serving as the segmentation head for prediction module. - - - Returns: - A Keras Model that integrates the encoder, bridge, and decoder - blocks for coarse label map prediction. - """ - num_stages = 6 - - x = x_input - - # -------------Encoder-------------- - x = keras.layers.Conv2D(filters, kernel_size=(3, 3), padding="same")(x) - - encoder_blocks = [] - for i in range(num_stages): - if i < 4: # First four stages are adopted from ResNet backbone. - x = get_resnet_block(backbone, i)(x) - encoder_blocks.append(x) - else: # Last 2 stages consist of three basic resnet blocks. - x = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(x) - for j in range(3): - x = resnet_basic_block( - x, - filters=x.shape[3], - conv_shortcut=False, - name=f"v1_basic_block_{i + 1}_{j + 1}", - ) - encoder_blocks.append(x) - - # -------------Bridge------------- - x = convolution_block(x, filters=filters * 8, dilation=2) - x = convolution_block(x, filters=filters * 8, dilation=2) - x = convolution_block(x, filters=filters * 8, dilation=2) - encoder_blocks.append(x) - - # -------------Decoder------------- - decoder_blocks = [] - for i in reversed(range(num_stages)): - if i != (num_stages - 1): # Except first, scale other decoder stages. - x = keras.layers.UpSampling2D(size=2, interpolation="bilinear")(x) - - x = keras.layers.concatenate([encoder_blocks[i], x], axis=-1) - x = convolution_block(x, filters=filters * 8) - x = convolution_block(x, filters=filters * 8) - x = convolution_block(x, filters=filters * 8) - decoder_blocks.append(x) - - decoder_blocks.reverse() # Change order from last to first decoder stage. - decoder_blocks.append(encoder_blocks[-1]) # Copy bridge to decoder. - - # -------------Side Outputs-------------- - decoder_blocks = [ - segmentation_head(decoder_block) # Prediction segmentation head. - for segmentation_head, decoder_block in zip( - segmentation_heads, decoder_blocks - ) - ] - - return keras.models.Model(inputs=[x_input], outputs=decoder_blocks) - - -def basnet_rrm(base_model, filters, segmentation_head): - """ - BASNet Residual Refinement Module (RRM). - - This module outputs a fine label map by integrating light encoder, - bridge, and decoder blocks. - - Args: - base_model: Keras model used as the base or coarse label map. - filters: int, the number of filters. - segmentation_head: a `keras.layers.Layer`, A Keras layer serving - as the segmentation head for refinement module. - - Returns: - A Keras Model that constructs the Residual Refinement Module (RRM). 
- """ - num_stages = 4 - - x_input = base_model.output[0] - - # -------------Encoder-------------- - x = keras.layers.Conv2D(filters, kernel_size=(3, 3), padding="same")( - x_input - ) - - encoder_blocks = [] - for _ in range(num_stages): - x = convolution_block(x, filters=filters) - encoder_blocks.append(x) - x = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(x) - - # -------------Bridge-------------- - x = convolution_block(x, filters=filters) - - # -------------Decoder-------------- - for i in reversed(range(num_stages)): - x = keras.layers.UpSampling2D(size=2, interpolation="bilinear")(x) - x = keras.layers.concatenate([encoder_blocks[i], x], axis=-1) - x = convolution_block(x, filters=filters) - - x = segmentation_head(x) # Refinement segmentation head. - - # ------------- refined = coarse + residual - x = keras.layers.Add()([x_input, x]) # Add prediction + refinement output - - return keras.models.Model(inputs=base_model.input, outputs=[x]) + ) \ No newline at end of file diff --git a/keras_hub/src/models/basnet/basnet_backbone.py b/keras_hub/src/models/basnet/basnet_backbone.py new file mode 100644 index 0000000000..29599d1e30 --- /dev/null +++ b/keras_hub/src/models/basnet/basnet_backbone.py @@ -0,0 +1,370 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.resnet.resnet_backbone import ( + apply_basic_block as resnet_basic_block, +) + + +@keras_hub_export("keras_hub.models.BASNetBackbone") +class BASNetBackbone(Backbone): + """ + A Keras model implementing the BASNet architecture for semantic + segmentation. + + References: + - [BASNet: Boundary-Aware Segmentation Network for Mobile and Web Applications](https://arxiv.org/abs/2101.04704) + + Args: + image_encoder: A `keras_hub.models.ResNetBackbone` instance. The + backbone network for the model that is used as a feature extractor + for BASNet prediction encoder. Currently supported backbones are + ResNet18 and ResNet34. + (Note: Do not specify `image_shape` within the backbone. + Please provide these while initializing the 'BASNetBackbone' model) + num_classes: int, the number of classes for the segmentation model. + image_shape: optional shape tuple, defaults to (None, None, 3). + projection_filters: int, number of filters in the convolution layer + projecting low-level features from the `backbone`. + prediction_heads: (Optional) List of `keras.layers.Layer` defining + the prediction module head for the model. If not provided, a + default head is created with a Conv2D layer followed by resizing. + refinement_head: (Optional) a `keras.layers.Layer` defining the + refinement module head for the model. If not provided, a default + head is created with a Conv2D layer. + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the model's computations and weights. + + """ + + def __init__( + self, + image_encoder, + num_classes, + image_shape=(None, None, 3), + projection_filters=64, + prediction_heads=None, + refinement_head=None, + dtype=None, + **kwargs, + ): + if not isinstance(image_encoder, keras.layers.Layer) or not isinstance( + image_encoder, keras.Model + ): + raise ValueError( + "Argument `image_encoder` must be a `keras.layers.Layer` instance" + f" or `keras.Model`. Received instead" + f" image_encoder={image_encoder} (of type {type(image_encoder)})." 
+ ) + + if tuple(image_encoder.image_shape) != (None, None, 3): + raise ValueError( + "Do not specify `image_shape` within the" + " `BASNetBackbone`'s image_encoder. \nPlease provide `image_shape`" + " while initializing the 'BASNetBackbone' model." + ) + + # === Functional Model === + inputs = keras.layers.Input(shape=image_shape) + x = inputs + + if prediction_heads is None: + prediction_heads = [] + for size in (1, 2, 4, 8, 16, 32, 32): + head_layers = [ + keras.layers.Conv2D( + num_classes, + kernel_size=(3, 3), + padding="same", + dtype=dtype, + ) + ] + if size != 1: + head_layers.append( + keras.layers.UpSampling2D( + size=size, interpolation="bilinear", dtype=dtype + ) + ) + prediction_heads.append(keras.Sequential(head_layers)) + + if refinement_head is None: + refinement_head = keras.Sequential( + [ + keras.layers.Conv2D( + num_classes, + kernel_size=(3, 3), + padding="same", + dtype=dtype, + ), + ] + ) + + # Prediction model. + predict_model = basnet_predict( + x, image_encoder, projection_filters, prediction_heads, dtype=dtype + ) + + # Refinement model. + refine_model = basnet_rrm( + predict_model, projection_filters, refinement_head, dtype=dtype + ) + + outputs = refine_model.outputs # Combine outputs. + outputs.extend(predict_model.outputs) + + output_names = ["refine_out"] + [ + f"predict_out_{i}" for i in range(1, len(outputs)) + ] + + outputs = { + output_name: keras.layers.Activation( + "sigmoid", name=output_name, dtype=dtype + )(output) + for output, output_name in zip(outputs, output_names) + } + + super().__init__(inputs=inputs, outputs=outputs, dtype=dtype, **kwargs) + + # === Config === + self.image_encoder = image_encoder + self.num_classes = num_classes + self.image_shape = image_shape + self.projection_filters = projection_filters + self.prediction_heads = prediction_heads + self.refinement_head = refinement_head + + def get_config(self): + config = super().get_config() + config.update( + { + "image_encoder": keras.saving.serialize_keras_object( + self.image_encoder + ), + "num_classes": self.num_classes, + "image_shape": self.image_shape, + "projection_filters": self.projection_filters, + "prediction_heads": [ + keras.saving.serialize_keras_object(prediction_head) + for prediction_head in self.prediction_heads + ], + "refinement_head": keras.saving.serialize_keras_object( + self.refinement_head + ), + } + ) + return config + + @classmethod + def from_config(cls, config): + if "image_encoder" in config: + config["image_encoder"] = keras.layers.deserialize( + config["image_encoder"] + ) + if "prediction_heads" in config and isinstance( + config["prediction_heads"], list + ): + for i in range(len(config["prediction_heads"])): + if isinstance(config["prediction_heads"][i], dict): + config["prediction_heads"][i] = keras.layers.deserialize( + config["prediction_heads"][i] + ) + + if "refinement_head" in config and isinstance( + config["refinement_head"], dict + ): + config["refinement_head"] = keras.layers.deserialize( + config["refinement_head"] + ) + return super().from_config(config) + + +def convolution_block(x_input, filters, dilation=1, dtype=None): + """ + Apply convolution + batch normalization + ReLU activation. + + Args: + x_input: Input keras tensor. + filters: int, number of output filters in the convolution. + dilation: int, dilation rate for the convolution operation. + Defaults to 1. + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the model's computations and weights. 
+ + Returns: + A tensor with convolution, batch normalization, and ReLU + activation applied. + """ + x = keras.layers.Conv2D( + filters, (3, 3), padding="same", dilation_rate=dilation, dtype=dtype + )(x_input) + x = keras.layers.BatchNormalization(dtype=dtype)(x) + return keras.layers.Activation("relu", dtype=dtype)(x) + + +def get_resnet_block(_resnet, block_num): + """ + Extract and return a specific ResNet block. + + Args: + _resnet: `keras.Model`. ResNet model instance. + block_num: int, block number to extract. + + Returns: + A Keras Model representing the specified ResNet block. + """ + + extractor_levels = ["P2", "P3", "P4", "P5"] + num_blocks = _resnet.stackwise_num_blocks + if block_num == 0: + x = _resnet.get_layer("pool1_pool").output + else: + x = _resnet.pyramid_outputs[extractor_levels[block_num - 1]] + y = _resnet.get_layer( + f"stack{block_num}_block{num_blocks[block_num]-1}_add" + ).output + return keras.models.Model( + inputs=x, + outputs=y, + name=f"resnet_block{block_num + 1}", + ) + + +def basnet_predict(x_input, backbone, filters, segmentation_heads, dtype=None): + """ + BASNet Prediction Module. + + This module outputs a coarse label map by integrating heavy + encoder, bridge, and decoder blocks. + + Args: + x_input: Input keras tensor. + backbone: `keras.Model`. The backbone network used as a feature + extractor for BASNet prediction encoder. + filters: int, the number of filters. + segmentation_heads: List of `keras.layers.Layer`, A list of Keras + layers serving as the segmentation head for prediction module. + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the model's computations and weights. + + + Returns: + A Keras Model that integrates the encoder, bridge, and decoder + blocks for coarse label map prediction. + """ + num_stages = 6 + + x = x_input + + # -------------Encoder-------------- + x = keras.layers.Conv2D( + filters, kernel_size=(3, 3), padding="same", dtype=dtype + )(x) + + encoder_blocks = [] + for i in range(num_stages): + if i < 4: # First four stages are adopted from ResNet backbone. + x = get_resnet_block(backbone, i)(x) + encoder_blocks.append(x) + else: # Last 2 stages consist of three basic resnet blocks. + x = keras.layers.MaxPool2D( + pool_size=(2, 2), strides=(2, 2), dtype=dtype + )(x) + for j in range(3): + x = resnet_basic_block( + x, + filters=x.shape[3], + conv_shortcut=False, + name=f"v1_basic_block_{i + 1}_{j + 1}", + dtype=dtype, + ) + encoder_blocks.append(x) + + # -------------Bridge------------- + x = convolution_block(x, filters=filters * 8, dilation=2, dtype=dtype) + x = convolution_block(x, filters=filters * 8, dilation=2, dtype=dtype) + x = convolution_block(x, filters=filters * 8, dilation=2, dtype=dtype) + encoder_blocks.append(x) + + # -------------Decoder------------- + decoder_blocks = [] + for i in reversed(range(num_stages)): + if i != (num_stages - 1): # Except first, scale other decoder stages. + x = keras.layers.UpSampling2D( + size=2, interpolation="bilinear", dtype=dtype + )(x) + + x = keras.layers.concatenate([encoder_blocks[i], x], axis=-1) + x = convolution_block(x, filters=filters * 8, dtype=dtype) + x = convolution_block(x, filters=filters * 8, dtype=dtype) + x = convolution_block(x, filters=filters * 8, dtype=dtype) + decoder_blocks.append(x) + + decoder_blocks.reverse() # Change order from last to first decoder stage. + decoder_blocks.append(encoder_blocks[-1]) # Copy bridge to decoder. 
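+    # Each decoder stage, plus the bridge, is routed through its own
+    # segmentation head so the model can be supervised at every scale.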
+ + # -------------Side Outputs-------------- + decoder_blocks = [ + segmentation_head(decoder_block) # Prediction segmentation head. + for segmentation_head, decoder_block in zip( + segmentation_heads, decoder_blocks + ) + ] + + return keras.models.Model(inputs=[x_input], outputs=decoder_blocks) + + +def basnet_rrm(base_model, filters, segmentation_head, dtype=None): + """ + BASNet Residual Refinement Module (RRM). + + This module outputs a fine label map by integrating light encoder, + bridge, and decoder blocks. + + Args: + base_model: Keras model used as the base or coarse label map. + filters: int, the number of filters. + segmentation_head: a `keras.layers.Layer`, A Keras layer serving + as the segmentation head for refinement module. + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the model's computations and weights. + + Returns: + A Keras Model that constructs the Residual Refinement Module (RRM). + """ + num_stages = 4 + + x_input = base_model.output[0] + + # -------------Encoder-------------- + x = keras.layers.Conv2D( + filters, kernel_size=(3, 3), padding="same", dtype=dtype + )(x_input) + + encoder_blocks = [] + for _ in range(num_stages): + x = convolution_block(x, filters=filters) + encoder_blocks.append(x) + x = keras.layers.MaxPool2D( + pool_size=(2, 2), strides=(2, 2), dtype=dtype + )(x) + + # -------------Bridge-------------- + x = convolution_block(x, filters=filters, dtype=dtype) + + # -------------Decoder-------------- + for i in reversed(range(num_stages)): + x = keras.layers.UpSampling2D( + size=2, interpolation="bilinear", dtype=dtype + )(x) + x = keras.layers.concatenate([encoder_blocks[i], x], axis=-1) + x = convolution_block(x, filters=filters) + + x = segmentation_head(x) # Refinement segmentation head. 
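+    # The refinement head maps the decoder features to a num_classes-channel
+    # residual map, which is added to the coarse prediction below.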
+ + # ------------- refined = coarse + residual + x = keras.layers.Add(dtype=dtype)( + [x_input, x] + ) # Add prediction + refinement output + + return keras.models.Model(inputs=base_model.input, outputs=[x]) \ No newline at end of file diff --git a/keras_hub/src/models/basnet/basnet_backbone_test.py b/keras_hub/src/models/basnet/basnet_backbone_test.py new file mode 100644 index 0000000000..ca7efe3a3b --- /dev/null +++ b/keras_hub/src/models/basnet/basnet_backbone_test.py @@ -0,0 +1,36 @@ +from keras import ops + +from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone +from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_hub.src.tests.test_case import TestCase + + +class BASNetBackboneTest(TestCase): + def setUp(self): + self.images = ops.ones((2, 64, 64, 3)) + self.image_encoder = ResNetBackbone( + input_conv_filters=[64], + input_conv_kernel_sizes=[7], + stackwise_num_filters=[64, 128, 256, 512], + stackwise_num_blocks=[2, 2, 2, 2], + stackwise_num_strides=[1, 2, 2, 2], + block_type="basic_block", + ) + self.init_kwargs = { + "image_encoder": self.image_encoder, + "num_classes": 1, + } + + def test_backbone_basics(self): + output_names = ["refine_out"] + [ + f"predict_out_{i}" for i in range(1, 8) + ] + expected_output_shape = {name: (2, 64, 64, 1) for name in output_names} + self.run_backbone_test( + cls=BASNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.images, + expected_output_shape=expected_output_shape, + run_mixed_precision_check=False, + run_quantization_check=False, + ) \ No newline at end of file diff --git a/keras_hub/src/models/basnet/basnet_image_converter.py b/keras_hub/src/models/basnet/basnet_image_converter.py index 45f364fd4e..347447526e 100644 --- a/keras_hub/src/models/basnet/basnet_image_converter.py +++ b/keras_hub/src/models/basnet/basnet_image_converter.py @@ -1,8 +1,8 @@ from keras_hub.src.api_export import keras_hub_export from keras_hub.src.layers.preprocessing.image_converter import ImageConverter -from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone @keras_hub_export("keras_hub.layers.BASNetImageConverter") class BASNetImageConverter(ImageConverter): - backbone_cls = ResNetBackbone + backbone_cls = BASNetBackbone \ No newline at end of file diff --git a/keras_hub/src/models/basnet/basnet_preprocessor.py b/keras_hub/src/models/basnet/basnet_preprocessor.py index 7faa64cdd9..9fdb855e73 100644 --- a/keras_hub/src/models/basnet/basnet_preprocessor.py +++ b/keras_hub/src/models/basnet/basnet_preprocessor.py @@ -1,14 +1,14 @@ from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone from keras_hub.src.models.basnet.basnet_image_converter import ( BASNetImageConverter, ) from keras_hub.src.models.image_segmenter_preprocessor import ( ImageSegmenterPreprocessor, ) -from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone @keras_hub_export("keras_hub.models.BASNetPreprocessor") class BASNetPreprocessor(ImageSegmenterPreprocessor): - backbone_cls = ResNetBackbone - image_converter_cls = BASNetImageConverter + backbone_cls = BASNetBackbone + image_converter_cls = BASNetImageConverter \ No newline at end of file diff --git a/keras_hub/src/models/basnet/basnet_presets.py b/keras_hub/src/models/basnet/basnet_presets.py index 1ed728643b..7a4b0a27d3 100644 --- a/keras_hub/src/models/basnet/basnet_presets.py +++ 
b/keras_hub/src/models/basnet/basnet_presets.py @@ -19,4 +19,4 @@ }, "kaggle_handle": "", # TODO }, -} +} \ No newline at end of file diff --git a/keras_hub/src/models/basnet/basnet_test.py b/keras_hub/src/models/basnet/basnet_test.py index fa7c1c9b64..a0af8e8e4e 100644 --- a/keras_hub/src/models/basnet/basnet_test.py +++ b/keras_hub/src/models/basnet/basnet_test.py @@ -2,6 +2,7 @@ from keras import ops from keras_hub.src.models.basnet.basnet import BASNetImageSegmenter +from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone from keras_hub.src.tests.test_case import TestCase @@ -13,7 +14,7 @@ def setUp(self): self.labels = ops.concatenate( (ops.zeros((2, 32, 64, 1)), ops.ones((2, 32, 64, 1))), axis=1 ) - self.backbone = ResNetBackbone( + self.image_encoder = ResNetBackbone( input_conv_filters=[64], input_conv_kernel_sizes=[7], stackwise_num_filters=[64, 128, 256, 512], @@ -21,11 +22,14 @@ def setUp(self): stackwise_num_strides=[1, 2, 2, 2], block_type="basic_block", ) + self.backbone = BASNetBackbone( + image_encoder=self.image_encoder, + num_classes=1, + ) self.preprocessor = BASNetPreprocessor() self.init_kwargs = { "backbone": self.backbone, "preprocessor": self.preprocessor, - "num_classes": 1, } self.train_data = (self.images, self.labels) @@ -34,7 +38,7 @@ def test_basics(self): cls=BASNetImageSegmenter, init_kwargs=self.init_kwargs, train_data=self.train_data, - expected_output_shape=[(2, 64, 64, 1)] * 8, + expected_output_shape=(2, 64, 64, 1), ) @pytest.mark.large @@ -47,10 +51,8 @@ def test_saved_model(self): def test_end_to_end_model_predict(self): model = BASNetImageSegmenter(**self.init_kwargs) - outputs = model.predict(self.images) - self.assertAllEqual( - [output.shape for output in outputs], [(2, 64, 64, 1)] * 8 - ) + output = model.predict(self.images) + self.assertAllEqual(output.shape, (2, 64, 64, 1)) @pytest.mark.skip(reason="disabled until preset's been uploaded to Kaggle") @pytest.mark.extra_large @@ -60,5 +62,5 @@ def test_all_presets(self): cls=BASNetImageSegmenter, preset=preset, input_data=self.images, - expected_output_shape=[(2, 64, 64, 1)] * 8, - ) + expected_output_shape=(2, 64, 64, 1), + ) \ No newline at end of file From ac961c5fc8570f55b3d53200872826d5ae8fd664 Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Mon, 16 Dec 2024 13:20:13 -0800 Subject: [PATCH 08/15] Fix format issues, removed presets file --- keras_hub/api/models/__init__.py | 1 + keras_hub/src/models/basnet/__init__.py | 2 +- keras_hub/src/models/basnet/basnet.py | 9 +++----- .../src/models/basnet/basnet_backbone.py | 18 +++++---------- .../src/models/basnet/basnet_backbone_test.py | 2 +- .../models/basnet/basnet_image_converter.py | 2 +- .../src/models/basnet/basnet_preprocessor.py | 2 +- keras_hub/src/models/basnet/basnet_presets.py | 22 ------------------- keras_hub/src/models/basnet/basnet_test.py | 2 +- 9 files changed, 15 insertions(+), 45 deletions(-) delete mode 100644 keras_hub/src/models/basnet/basnet_presets.py diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 9748bc40b6..7c7adbf97c 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -30,6 +30,7 @@ ) from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer from keras_hub.src.models.basnet.basnet import BASNetImageSegmenter +from keras_hub.src.models.basnet.basnet_backbone import 
BASNetBackbone from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor from keras_hub.src.models.bert.bert_backbone import BertBackbone from keras_hub.src.models.bert.bert_masked_lm import BertMaskedLM diff --git a/keras_hub/src/models/basnet/__init__.py b/keras_hub/src/models/basnet/__init__.py index 8df584cab4..953164a973 100644 --- a/keras_hub/src/models/basnet/__init__.py +++ b/keras_hub/src/models/basnet/__init__.py @@ -2,4 +2,4 @@ from keras_hub.src.models.basnet.basnet_presets import basnet_presets from keras_hub.src.utils.preset_utils import register_presets -register_presets(basnet_presets, BASNetBackbone) \ No newline at end of file +register_presets(basnet_presets, BASNetBackbone) diff --git a/keras_hub/src/models/basnet/basnet.py b/keras_hub/src/models/basnet/basnet.py index 8fc304d073..6e0b4840d5 100644 --- a/keras_hub/src/models/basnet/basnet.py +++ b/keras_hub/src/models/basnet/basnet.py @@ -8,9 +8,7 @@ @keras_hub_export("keras_hub.models.BASNetImageSegmenter") class BASNetImageSegmenter(ImageSegmenter): - """ - A Keras model implementing the BASNet architecture for semantic - segmentation. + """A Keras model implementing the BASNet architecture for semantic segmentation. References: [BASNet: Boundary-Aware Segmentation Network for Mobile and Web Applications](https://arxiv.org/abs/2101.04704) @@ -61,7 +59,6 @@ def __init__( preprocessor=None, **kwargs, ): - # === Functional Model === x = backbone.input outputs = backbone(x) @@ -74,7 +71,7 @@ def __init__( self.preprocessor = preprocessor def compute_loss(self, x, y, y_pred, *args, **kwargs): - # train BASNet's prediction and refinement module outputs against the same ground truth data + # train BASNet's Prediction and RRM module outputs against the same gt data. outputs = self.backbone(x) losses = [] for output in outputs.values(): @@ -125,4 +122,4 @@ def compile( loss=loss, metrics=metrics, **kwargs, - ) \ No newline at end of file + ) diff --git a/keras_hub/src/models/basnet/basnet_backbone.py b/keras_hub/src/models/basnet/basnet_backbone.py index 29599d1e30..917709be38 100644 --- a/keras_hub/src/models/basnet/basnet_backbone.py +++ b/keras_hub/src/models/basnet/basnet_backbone.py @@ -9,9 +9,7 @@ @keras_hub_export("keras_hub.models.BASNetBackbone") class BASNetBackbone(Backbone): - """ - A Keras model implementing the BASNet architecture for semantic - segmentation. + """A Keras model implementing the BASNet architecture for semantic segmentation. References: - [BASNet: Boundary-Aware Segmentation Network for Mobile and Web Applications](https://arxiv.org/abs/2101.04704) @@ -180,8 +178,7 @@ def from_config(cls, config): def convolution_block(x_input, filters, dilation=1, dtype=None): - """ - Apply convolution + batch normalization + ReLU activation. + """Apply convolution + batch normalization + ReLU activation. Args: x_input: Input keras tensor. @@ -203,8 +200,7 @@ def convolution_block(x_input, filters, dilation=1, dtype=None): def get_resnet_block(_resnet, block_num): - """ - Extract and return a specific ResNet block. + """Extract and return a specific ResNet block. Args: _resnet: `keras.Model`. ResNet model instance. @@ -231,8 +227,7 @@ def get_resnet_block(_resnet, block_num): def basnet_predict(x_input, backbone, filters, segmentation_heads, dtype=None): - """ - BASNet Prediction Module. + """BASNet Prediction Module. This module outputs a coarse label map by integrating heavy encoder, bridge, and decoder blocks. 
@@ -315,8 +310,7 @@ def basnet_predict(x_input, backbone, filters, segmentation_heads, dtype=None): def basnet_rrm(base_model, filters, segmentation_head, dtype=None): - """ - BASNet Residual Refinement Module (RRM). + """BASNet Residual Refinement Module (RRM). This module outputs a fine label map by integrating light encoder, bridge, and decoder blocks. @@ -367,4 +361,4 @@ def basnet_rrm(base_model, filters, segmentation_head, dtype=None): [x_input, x] ) # Add prediction + refinement output - return keras.models.Model(inputs=base_model.input, outputs=[x]) \ No newline at end of file + return keras.models.Model(inputs=base_model.input, outputs=[x]) diff --git a/keras_hub/src/models/basnet/basnet_backbone_test.py b/keras_hub/src/models/basnet/basnet_backbone_test.py index ca7efe3a3b..41d56d40fc 100644 --- a/keras_hub/src/models/basnet/basnet_backbone_test.py +++ b/keras_hub/src/models/basnet/basnet_backbone_test.py @@ -33,4 +33,4 @@ def test_backbone_basics(self): expected_output_shape=expected_output_shape, run_mixed_precision_check=False, run_quantization_check=False, - ) \ No newline at end of file + ) diff --git a/keras_hub/src/models/basnet/basnet_image_converter.py b/keras_hub/src/models/basnet/basnet_image_converter.py index 347447526e..399858c174 100644 --- a/keras_hub/src/models/basnet/basnet_image_converter.py +++ b/keras_hub/src/models/basnet/basnet_image_converter.py @@ -5,4 +5,4 @@ @keras_hub_export("keras_hub.layers.BASNetImageConverter") class BASNetImageConverter(ImageConverter): - backbone_cls = BASNetBackbone \ No newline at end of file + backbone_cls = BASNetBackbone diff --git a/keras_hub/src/models/basnet/basnet_preprocessor.py b/keras_hub/src/models/basnet/basnet_preprocessor.py index 9fdb855e73..cfe580aa4b 100644 --- a/keras_hub/src/models/basnet/basnet_preprocessor.py +++ b/keras_hub/src/models/basnet/basnet_preprocessor.py @@ -11,4 +11,4 @@ @keras_hub_export("keras_hub.models.BASNetPreprocessor") class BASNetPreprocessor(ImageSegmenterPreprocessor): backbone_cls = BASNetBackbone - image_converter_cls = BASNetImageConverter \ No newline at end of file + image_converter_cls = BASNetImageConverter diff --git a/keras_hub/src/models/basnet/basnet_presets.py b/keras_hub/src/models/basnet/basnet_presets.py deleted file mode 100644 index 7a4b0a27d3..0000000000 --- a/keras_hub/src/models/basnet/basnet_presets.py +++ /dev/null @@ -1,22 +0,0 @@ -"""BASNet model preset configurations.""" - -basnet_presets = { - "basnet_resnet18": { - "metadata": { - "description": "BASNet with a ResNet18 v1 backbone.", - "params": 98780872, - "official_name": "BASNet", - "path": "basnet", - }, - "kaggle_handle": "", # TODO - }, - "basnet_resnet34": { - "metadata": { - "description": "BASNet with a ResNet34 v1 backbone.", - "params": 108896456, - "official_name": "BASNet", - "path": "basnet", - }, - "kaggle_handle": "", # TODO - }, -} \ No newline at end of file diff --git a/keras_hub/src/models/basnet/basnet_test.py b/keras_hub/src/models/basnet/basnet_test.py index a0af8e8e4e..720195046e 100644 --- a/keras_hub/src/models/basnet/basnet_test.py +++ b/keras_hub/src/models/basnet/basnet_test.py @@ -63,4 +63,4 @@ def test_all_presets(self): preset=preset, input_data=self.images, expected_output_shape=(2, 64, 64, 1), - ) \ No newline at end of file + ) From 2ccfde9f9285aaa89dcb873332245cf0cbd9ca2a Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Mon, 16 Dec 2024 13:36:36 -0800 Subject: [PATCH 09/15] adding deleted presets file in previous commit --- 
keras_hub/src/models/basnet/basnet_presets.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 keras_hub/src/models/basnet/basnet_presets.py diff --git a/keras_hub/src/models/basnet/basnet_presets.py b/keras_hub/src/models/basnet/basnet_presets.py new file mode 100644 index 0000000000..1ed728643b --- /dev/null +++ b/keras_hub/src/models/basnet/basnet_presets.py @@ -0,0 +1,22 @@ +"""BASNet model preset configurations.""" + +basnet_presets = { + "basnet_resnet18": { + "metadata": { + "description": "BASNet with a ResNet18 v1 backbone.", + "params": 98780872, + "official_name": "BASNet", + "path": "basnet", + }, + "kaggle_handle": "", # TODO + }, + "basnet_resnet34": { + "metadata": { + "description": "BASNet with a ResNet34 v1 backbone.", + "params": 108896456, + "official_name": "BASNet", + "path": "basnet", + }, + "kaggle_handle": "", # TODO + }, +} From 7665914c3fd0310d48aea7070fa78a9199aa4d63 Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Tue, 17 Dec 2024 10:02:32 -0800 Subject: [PATCH 10/15] Fix format issues, removed presets, newline issues --- keras_hub/src/models/basnet/basnet.py | 11 +++---- .../src/models/basnet/basnet_backbone.py | 32 +++++++++++-------- keras_hub/src/models/basnet/basnet_presets.py | 21 +----------- 3 files changed, 24 insertions(+), 40 deletions(-) diff --git a/keras_hub/src/models/basnet/basnet.py b/keras_hub/src/models/basnet/basnet.py index 6e0b4840d5..6ca7bd6851 100644 --- a/keras_hub/src/models/basnet/basnet.py +++ b/keras_hub/src/models/basnet/basnet.py @@ -8,10 +8,7 @@ @keras_hub_export("keras_hub.models.BASNetImageSegmenter") class BASNetImageSegmenter(ImageSegmenter): - """A Keras model implementing the BASNet architecture for semantic segmentation. - - References: - [BASNet: Boundary-Aware Segmentation Network for Mobile and Web Applications](https://arxiv.org/abs/2101.04704) + """BASNet image segmentation task. Args: backbone: A `keras_hub.models.BASNetBackbone` instance. @@ -71,13 +68,13 @@ def __init__( self.preprocessor = preprocessor def compute_loss(self, x, y, y_pred, *args, **kwargs): - # train BASNet's Prediction and RRM module outputs against the same gt data. + # train BASNet's prediction and refinement module outputs against the + # same ground truth data outputs = self.backbone(x) losses = [] for output in outputs.values(): losses.append(super().compute_loss(x, y, output, *args, **kwargs)) - loss = keras.ops.sum(losses, axis=0) - return loss + return keras.ops.sum(losses, axis=0) def compile( self, diff --git a/keras_hub/src/models/basnet/basnet_backbone.py b/keras_hub/src/models/basnet/basnet_backbone.py index 917709be38..617086a9f9 100644 --- a/keras_hub/src/models/basnet/basnet_backbone.py +++ b/keras_hub/src/models/basnet/basnet_backbone.py @@ -9,10 +9,12 @@ @keras_hub_export("keras_hub.models.BASNetBackbone") class BASNetBackbone(Backbone): - """A Keras model implementing the BASNet architecture for semantic segmentation. + """BASNet architecture for semantic segmentation. - References: - - [BASNet: Boundary-Aware Segmentation Network for Mobile and Web Applications](https://arxiv.org/abs/2101.04704) + A Keras model implementing the BASNet architecture described in [BASNet: + Boundary-Aware Segmentation Network for Mobile and Web Applications]( + https://arxiv.org/abs/2101.04704). BASNet uses a predict-refine + architecture for highly accurate image segmentation. Args: image_encoder: A `keras_hub.models.ResNetBackbone` instance. 
The @@ -33,7 +35,6 @@ class BASNetBackbone(Backbone): head is created with a Conv2D layer. dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype to use for the model's computations and weights. - """ def __init__( @@ -51,16 +52,17 @@ def __init__( image_encoder, keras.Model ): raise ValueError( - "Argument `image_encoder` must be a `keras.layers.Layer` instance" - f" or `keras.Model`. Received instead" - f" image_encoder={image_encoder} (of type {type(image_encoder)})." + "Argument `image_encoder` must be a `keras.layers.Layer`" + f" instance or `keras.Model`. Received instead" + f" image_encoder={image_encoder} (of type" + f" {type(image_encoder)})." ) if tuple(image_encoder.image_shape) != (None, None, 3): raise ValueError( "Do not specify `image_shape` within the" - " `BASNetBackbone`'s image_encoder. \nPlease provide `image_shape`" - " while initializing the 'BASNetBackbone' model." + " `BASNetBackbone`'s image_encoder. \nPlease provide" + " `image_shape` while initializing the 'BASNetBackbone' model." ) # === Functional Model === @@ -178,7 +180,8 @@ def from_config(cls, config): def convolution_block(x_input, filters, dilation=1, dtype=None): - """Apply convolution + batch normalization + ReLU activation. + """ + Apply convolution + batch normalization + ReLU activation. Args: x_input: Input keras tensor. @@ -200,7 +203,8 @@ def convolution_block(x_input, filters, dilation=1, dtype=None): def get_resnet_block(_resnet, block_num): - """Extract and return a specific ResNet block. + """ + Extract and return a specific ResNet block. Args: _resnet: `keras.Model`. ResNet model instance. @@ -227,7 +231,8 @@ def get_resnet_block(_resnet, block_num): def basnet_predict(x_input, backbone, filters, segmentation_heads, dtype=None): - """BASNet Prediction Module. + """ + BASNet Prediction Module. This module outputs a coarse label map by integrating heavy encoder, bridge, and decoder blocks. @@ -310,7 +315,8 @@ def basnet_predict(x_input, backbone, filters, segmentation_heads, dtype=None): def basnet_rrm(base_model, filters, segmentation_head, dtype=None): - """BASNet Residual Refinement Module (RRM). + """ + BASNet Residual Refinement Module (RRM). This module outputs a fine label map by integrating light encoder, bridge, and decoder blocks. 
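For readers following the API reshuffle in this commit, the pieces now compose as follows. This is a minimal sketch assuming the constructor arguments used in the tests earlier in this series (a ResNet18-style encoder and `num_classes=1`); it is illustrative rather than part of the patch:

```python
import numpy as np

from keras_hub.src.models.basnet.basnet import BASNetImageSegmenter
from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone
from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor
from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone

# ResNet18-style image encoder. Leave `image_shape` unset: BASNetBackbone
# rejects encoders whose image_shape is anything other than (None, None, 3).
image_encoder = ResNetBackbone(
    input_conv_filters=[64],
    input_conv_kernel_sizes=[7],
    stackwise_num_filters=[64, 128, 256, 512],
    stackwise_num_blocks=[2, 2, 2, 2],
    stackwise_num_strides=[1, 2, 2, 2],
    block_type="basic_block",
)

# The backbone wires up the prediction and refinement modules and returns a
# dict of side outputs ("refine_out", "predict_out_1" ... "predict_out_7").
backbone = BASNetBackbone(image_encoder=image_encoder, num_classes=1)

# The task model exposes only the refined map as its prediction.
segmenter = BASNetImageSegmenter(
    backbone=backbone,
    preprocessor=BASNetPreprocessor(),
)

images = np.ones((2, 64, 64, 3))
masks = segmenter.predict(images)  # expected shape: (2, 64, 64, 1)
```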
diff --git a/keras_hub/src/models/basnet/basnet_presets.py b/keras_hub/src/models/basnet/basnet_presets.py index 1ed728643b..3d96ab7885 100644 --- a/keras_hub/src/models/basnet/basnet_presets.py +++ b/keras_hub/src/models/basnet/basnet_presets.py @@ -1,22 +1,3 @@ """BASNet model preset configurations.""" -basnet_presets = { - "basnet_resnet18": { - "metadata": { - "description": "BASNet with a ResNet18 v1 backbone.", - "params": 98780872, - "official_name": "BASNet", - "path": "basnet", - }, - "kaggle_handle": "", # TODO - }, - "basnet_resnet34": { - "metadata": { - "description": "BASNet with a ResNet34 v1 backbone.", - "params": 108896456, - "official_name": "BASNet", - "path": "basnet", - }, - "kaggle_handle": "", # TODO - }, -} +basnet_presets = {} From a32e2f3d810ccd42b16d24a34a135591a43207d6 Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Tue, 17 Dec 2024 10:07:18 -0800 Subject: [PATCH 11/15] Fix for docstrings --- keras_hub/src/models/basnet/basnet_backbone.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/keras_hub/src/models/basnet/basnet_backbone.py b/keras_hub/src/models/basnet/basnet_backbone.py index 617086a9f9..3ae15fc9bb 100644 --- a/keras_hub/src/models/basnet/basnet_backbone.py +++ b/keras_hub/src/models/basnet/basnet_backbone.py @@ -180,8 +180,7 @@ def from_config(cls, config): def convolution_block(x_input, filters, dilation=1, dtype=None): - """ - Apply convolution + batch normalization + ReLU activation. + """Apply convolution + batch normalization + ReLU activation. Args: x_input: Input keras tensor. @@ -203,8 +202,7 @@ def convolution_block(x_input, filters, dilation=1, dtype=None): def get_resnet_block(_resnet, block_num): - """ - Extract and return a specific ResNet block. + """Extract and return a specific ResNet block. Args: _resnet: `keras.Model`. ResNet model instance. @@ -231,8 +229,7 @@ def get_resnet_block(_resnet, block_num): def basnet_predict(x_input, backbone, filters, segmentation_heads, dtype=None): - """ - BASNet Prediction Module. + """BASNet Prediction Module. This module outputs a coarse label map by integrating heavy encoder, bridge, and decoder blocks. @@ -315,8 +312,7 @@ def basnet_predict(x_input, backbone, filters, segmentation_heads, dtype=None): def basnet_rrm(base_model, filters, segmentation_head, dtype=None): - """ - BASNet Residual Refinement Module (RRM). + """BASNet Residual Refinement Module (RRM). This module outputs a fine label map by integrating light encoder, bridge, and decoder blocks. From 7963b4761cce954a2f094e3eed205d979a871cfb Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Tue, 17 Dec 2024 10:17:39 -0800 Subject: [PATCH 12/15] Add basnet conversion script --- .../convert_basnet_checkpoints.py | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 tools/checkpoint_conversion/convert_basnet_checkpoints.py diff --git a/tools/checkpoint_conversion/convert_basnet_checkpoints.py b/tools/checkpoint_conversion/convert_basnet_checkpoints.py new file mode 100644 index 0000000000..89e1c3ba53 --- /dev/null +++ b/tools/checkpoint_conversion/convert_basnet_checkpoints.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 + +""" +Converts BASNet model weights to KerasHub. + +Usage: python3 convert_basnet_checkpoint.py + +Downloads BASNet modelweights with ResNet34 backbone and converts them to a +KerasHub model. Credits for model training go to Hamid Ali +(https://github.com/hamidriasat/BASNet). 
+ +Requirements: +pip3 install -q git+https://github.com/keras-team/keras-hub.git +pip3 install -q gdown +""" + +import gdown + +from keras_hub.src.models.basnet.basnet import BASNetImageSegmenter +from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone +from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor +from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone + +# download weights +gdown.download( + "https://drive.google.com/uc?id=1OWKouuAQ7XpXZbWA3mmxDPrFGW71Axrg", + "basnet_weights.h5", +) + +# instantiate ResNet34 +image_encoder = ResNetBackbone( + input_conv_filters=[64], + input_conv_kernel_sizes=[7], + stackwise_num_filters=[64, 128, 256, 512], + stackwise_num_blocks=[3, 4, 6, 3], + stackwise_num_strides=[1, 2, 2, 2], + block_type="basic_block", +) + +# instantiate BASNet and load pretrained weights +preprocessor = BASNetPreprocessor() +backbone = BASNetBackbone(image_encoder=image_encoder, num_classes=1) +basnet = BASNetImageSegmenter(backbone=backbone, preprocessor=preprocessor) +backbone.load_weights("basnet_weights.h5") + +# save the preset +basnet.save_to_preset("basnet") \ No newline at end of file From ffc846917f7e276cc5033822ee4d02b72aaca40b Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Tue, 17 Dec 2024 10:24:55 -0800 Subject: [PATCH 13/15] Fix format issue in conversion script --- tools/checkpoint_conversion/convert_basnet_checkpoints.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tools/checkpoint_conversion/convert_basnet_checkpoints.py b/tools/checkpoint_conversion/convert_basnet_checkpoints.py index 89e1c3ba53..6ba6248fbe 100644 --- a/tools/checkpoint_conversion/convert_basnet_checkpoints.py +++ b/tools/checkpoint_conversion/convert_basnet_checkpoints.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 -""" -Converts BASNet model weights to KerasHub. +"""Converts BASNet model weights to KerasHub. 
Usage: python3 convert_basnet_checkpoint.py From c64d589747a1f6a8f3a58d14ec7f1bbc2fdeeeca Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Tue, 17 Dec 2024 10:30:33 -0800 Subject: [PATCH 14/15] Fix format issue --- tools/checkpoint_conversion/convert_basnet_checkpoints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/checkpoint_conversion/convert_basnet_checkpoints.py b/tools/checkpoint_conversion/convert_basnet_checkpoints.py index 6ba6248fbe..b9db04c163 100644 --- a/tools/checkpoint_conversion/convert_basnet_checkpoints.py +++ b/tools/checkpoint_conversion/convert_basnet_checkpoints.py @@ -43,4 +43,4 @@ backbone.load_weights("basnet_weights.h5") # save the preset -basnet.save_to_preset("basnet") \ No newline at end of file +basnet.save_to_preset("basnet") From 32335863aa28f348b2d7b73fb17e29e04d5bae30 Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Tue, 17 Dec 2024 15:31:12 -0800 Subject: [PATCH 15/15] Fix pytorch GPU test,removed weight conversion script --- .../src/models/basnet/basnet_backbone_test.py | 4 +- keras_hub/src/models/basnet/basnet_test.py | 8 ++-- .../convert_basnet_checkpoints.py | 46 ------------------- 3 files changed, 6 insertions(+), 52 deletions(-) delete mode 100644 tools/checkpoint_conversion/convert_basnet_checkpoints.py diff --git a/keras_hub/src/models/basnet/basnet_backbone_test.py b/keras_hub/src/models/basnet/basnet_backbone_test.py index 41d56d40fc..dfb665d2c3 100644 --- a/keras_hub/src/models/basnet/basnet_backbone_test.py +++ b/keras_hub/src/models/basnet/basnet_backbone_test.py @@ -1,4 +1,4 @@ -from keras import ops +import numpy as np from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone @@ -7,7 +7,7 @@ class BASNetBackboneTest(TestCase): def setUp(self): - self.images = ops.ones((2, 64, 64, 3)) + self.images = np.ones((2, 64, 64, 3)) self.image_encoder = ResNetBackbone( input_conv_filters=[64], input_conv_kernel_sizes=[7], diff --git a/keras_hub/src/models/basnet/basnet_test.py b/keras_hub/src/models/basnet/basnet_test.py index 720195046e..4147d43c7c 100644 --- a/keras_hub/src/models/basnet/basnet_test.py +++ b/keras_hub/src/models/basnet/basnet_test.py @@ -1,5 +1,5 @@ +import numpy as np import pytest -from keras import ops from keras_hub.src.models.basnet.basnet import BASNetImageSegmenter from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone @@ -10,9 +10,9 @@ class BASNetTest(TestCase): def setUp(self): - self.images = ops.ones((2, 64, 64, 3)) - self.labels = ops.concatenate( - (ops.zeros((2, 32, 64, 1)), ops.ones((2, 32, 64, 1))), axis=1 + self.images = np.ones((2, 64, 64, 3)) + self.labels = np.concatenate( + (np.zeros((2, 32, 64, 1)), np.ones((2, 32, 64, 1))), axis=1 ) self.image_encoder = ResNetBackbone( input_conv_filters=[64], diff --git a/tools/checkpoint_conversion/convert_basnet_checkpoints.py b/tools/checkpoint_conversion/convert_basnet_checkpoints.py deleted file mode 100644 index b9db04c163..0000000000 --- a/tools/checkpoint_conversion/convert_basnet_checkpoints.py +++ /dev/null @@ -1,46 +0,0 @@ -#!/usr/bin/env python3 - -"""Converts BASNet model weights to KerasHub. - -Usage: python3 convert_basnet_checkpoint.py - -Downloads BASNet modelweights with ResNet34 backbone and converts them to a -KerasHub model. Credits for model training go to Hamid Ali -(https://github.com/hamidriasat/BASNet). 
- -Requirements: -pip3 install -q git+https://github.com/keras-team/keras-hub.git -pip3 install -q gdown -""" - -import gdown - -from keras_hub.src.models.basnet.basnet import BASNetImageSegmenter -from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone -from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor -from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone - -# download weights -gdown.download( - "https://drive.google.com/uc?id=1OWKouuAQ7XpXZbWA3mmxDPrFGW71Axrg", - "basnet_weights.h5", -) - -# instantiate ResNet34 -image_encoder = ResNetBackbone( - input_conv_filters=[64], - input_conv_kernel_sizes=[7], - stackwise_num_filters=[64, 128, 256, 512], - stackwise_num_blocks=[3, 4, 6, 3], - stackwise_num_strides=[1, 2, 2, 2], - block_type="basic_block", -) - -# instantiate BASNet and load pretrained weights -preprocessor = BASNetPreprocessor() -backbone = BASNetBackbone(image_encoder=image_encoder, num_classes=1) -basnet = BASNetImageSegmenter(backbone=backbone, preprocessor=preprocessor) -backbone.load_weights("basnet_weights.h5") - -# save the preset -basnet.save_to_preset("basnet")
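The training behavior these commits converge on is deep supervision: `compute_loss` scores every side output of the backbone against the same ground truth mask and sums the per-output losses. Below is a small self-contained sketch of that aggregation; the helper name `basnet_style_loss` and the binary cross-entropy choice are assumptions for illustration, not code from the patch:

```python
import keras
import numpy as np


def basnet_style_loss(y_true, side_outputs, loss_fn=None):
    """Sum one loss per side output, all supervised by the same ground truth."""
    loss_fn = loss_fn or keras.losses.BinaryCrossentropy(from_logits=False)
    per_output = [loss_fn(y_true, y_pred) for y_pred in side_outputs]
    return keras.ops.sum(keras.ops.stack(per_output), axis=0)


# Eight side outputs (the refined map plus seven coarse maps), matching the
# backbone's output dict in this series.
y_true = np.zeros((2, 64, 64, 1), dtype="float32")
side_outputs = [
    np.full((2, 64, 64, 1), 0.5, dtype="float32") for _ in range(8)
]
total = basnet_style_loss(y_true, side_outputs)
print(float(total))  # roughly 8 * BCE(target=0, prediction=0.5), i.e. ~5.5
```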