diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index f389052f8e..0e0d31d3a7 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -35,6 +35,9 @@ from keras_hub.src.layers.preprocessing.random_deletion import RandomDeletion from keras_hub.src.layers.preprocessing.random_swap import RandomSwap from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker +from keras_hub.src.models.basnet.basnet_image_converter import ( + BASNetImageConverter, +) from keras_hub.src.models.clip.clip_image_converter import CLIPImageConverter from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import ( DeepLabV3ImageConverter, diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index a37bad2e38..7c7adbf97c 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -29,6 +29,9 @@ BartSeq2SeqLMPreprocessor, ) from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer +from keras_hub.src.models.basnet.basnet import BASNetImageSegmenter +from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone +from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor from keras_hub.src.models.bert.bert_backbone import BertBackbone from keras_hub.src.models.bert.bert_masked_lm import BertMaskedLM from keras_hub.src.models.bert.bert_masked_lm_preprocessor import ( diff --git a/keras_hub/src/models/basnet/__init__.py b/keras_hub/src/models/basnet/__init__.py new file mode 100644 index 0000000000..953164a973 --- /dev/null +++ b/keras_hub/src/models/basnet/__init__.py @@ -0,0 +1,5 @@ +from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone +from keras_hub.src.models.basnet.basnet_presets import basnet_presets +from keras_hub.src.utils.preset_utils import register_presets + +register_presets(basnet_presets, BASNetBackbone) diff --git a/keras_hub/src/models/basnet/basnet.py 
b/keras_hub/src/models/basnet/basnet.py new file mode 100644 index 0000000000..6ca7bd6851 --- /dev/null +++ b/keras_hub/src/models/basnet/basnet.py @@ -0,0 +1,122 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone +from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor +from keras_hub.src.models.image_segmenter import ImageSegmenter + + +@keras_hub_export("keras_hub.models.BASNetImageSegmenter") +class BASNetImageSegmenter(ImageSegmenter): + """BASNet image segmentation task. + + Args: + backbone: A `keras_hub.models.BASNetBackbone` instance. + preprocessor: `None`, a `keras_hub.models.Preprocessor` instance, + a `keras.Layer` instance, or a callable. If `None` no preprocessing + will be applied to the inputs. + + Example: + ```python + import keras_hub + + images = np.ones(shape=(1, 288, 288, 3)) + labels = np.zeros(shape=(1, 288, 288, 1)) + + image_encoder = keras_hub.models.ResNetBackbone.from_preset( + "resnet_18_imagenet", + load_weights=False + ) + backbone = keras_hub.models.BASNetBackbone( + image_encoder, + num_classes=1, + image_shape=[288, 288, 3] + ) + model = keras_hub.models.BASNetImageSegmenter(backbone) + + # Evaluate the model + pred_labels = model(images) + + # Train the model + model.compile( + optimizer="adam", + loss=keras.losses.BinaryCrossentropy(from_logits=False), + metrics=["accuracy"], + ) + model.fit(images, labels, epochs=3) + ``` + """ + + backbone_cls = BASNetBackbone + preprocessor_cls = BASNetPreprocessor + + def __init__( + self, + backbone, + preprocessor=None, + **kwargs, + ): + # === Functional Model === + x = backbone.input + outputs = backbone(x) + # only return the refinement module's output as final prediction + outputs = outputs["refine_out"] + super().__init__(inputs=x, outputs=outputs, **kwargs) + + # === Config === + self.backbone = backbone + self.preprocessor = preprocessor + + def compute_loss(self, 
x, y, y_pred, *args, **kwargs): + # train BASNet's prediction and refinement module outputs against the + # same ground truth data + outputs = self.backbone(x) + losses = [] + for output in outputs.values(): + losses.append(super().compute_loss(x, y, output, *args, **kwargs)) + return keras.ops.sum(losses, axis=0) + + def compile( + self, + optimizer="auto", + loss="auto", + metrics="auto", + **kwargs, + ): + """Configures the `BASNet` task for training. + + `BASNet` extends the default compilation signature + of `keras.Model.compile` with defaults for `optimizer` and `loss`. To + override these defaults, pass any value to these arguments during + compilation. + + Args: + optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer` + instance. Defaults to `"auto"`, which uses the default + optimizer for `BASNet`. See `keras.Model.compile` and + `keras.optimizers` for more info on possible `optimizer` + values. + loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance. + Defaults to `"auto"`, in which case the default loss + computation of `BASNet` will be applied. + See `keras.Model.compile` and `keras.losses` for more info on + possible `loss` values. + metrics: `"auto"`, or a list of metrics to be evaluated by + the model during training and testing. Defaults to `"auto"`, + where a `keras.metrics.Accuracy` will be applied to track the + accuracy of the model during training. + See `keras.Model.compile` and `keras.metrics` for + more info on possible `metrics` values. + **kwargs: See `keras.Model.compile` for a full list of arguments + supported by the compile method. 
+ """ + if loss == "auto": + loss = keras.losses.BinaryCrossentropy() + if metrics == "auto": + metrics = [keras.metrics.Accuracy()] + super().compile( + optimizer=optimizer, + loss=loss, + metrics=metrics, + **kwargs, + ) diff --git a/keras_hub/src/models/basnet/basnet_backbone.py b/keras_hub/src/models/basnet/basnet_backbone.py new file mode 100644 index 0000000000..3ae15fc9bb --- /dev/null +++ b/keras_hub/src/models/basnet/basnet_backbone.py @@ -0,0 +1,366 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.resnet.resnet_backbone import ( + apply_basic_block as resnet_basic_block, +) + + +@keras_hub_export("keras_hub.models.BASNetBackbone") +class BASNetBackbone(Backbone): + """BASNet architecture for semantic segmentation. + + A Keras model implementing the BASNet architecture described in [BASNet: + Boundary-Aware Segmentation Network for Mobile and Web Applications]( + https://arxiv.org/abs/2101.04704). BASNet uses a predict-refine + architecture for highly accurate image segmentation. + + Args: + image_encoder: A `keras_hub.models.ResNetBackbone` instance. The + backbone network for the model that is used as a feature extractor + for BASNet prediction encoder. Currently supported backbones are + ResNet18 and ResNet34. + (Note: Do not specify `image_shape` within the backbone. + Please provide these while initializing the 'BASNetBackbone' model) + num_classes: int, the number of classes for the segmentation model. + image_shape: optional shape tuple, defaults to (None, None, 3). + projection_filters: int, number of filters in the convolution layer + projecting low-level features from the `backbone`. + prediction_heads: (Optional) List of `keras.layers.Layer` defining + the prediction module head for the model. If not provided, a + default head is created with a Conv2D layer followed by resizing. 
+ refinement_head: (Optional) a `keras.layers.Layer` defining the + refinement module head for the model. If not provided, a default + head is created with a Conv2D layer. + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the model's computations and weights. + """ + + def __init__( + self, + image_encoder, + num_classes, + image_shape=(None, None, 3), + projection_filters=64, + prediction_heads=None, + refinement_head=None, + dtype=None, + **kwargs, + ): + if not isinstance(image_encoder, keras.layers.Layer) and not isinstance( + image_encoder, keras.Model + ): + raise ValueError( + "Argument `image_encoder` must be a `keras.layers.Layer`" + f" instance or `keras.Model`. Received instead" + f" image_encoder={image_encoder} (of type" + f" {type(image_encoder)})." + ) + + if tuple(image_encoder.image_shape) != (None, None, 3): + raise ValueError( + "Do not specify `image_shape` within the" + " `BASNetBackbone`'s image_encoder. \nPlease provide" + " `image_shape` while initializing the 'BASNetBackbone' model." + ) + + # === Functional Model === + inputs = keras.layers.Input(shape=image_shape) + x = inputs + + if prediction_heads is None: + prediction_heads = [] + for size in (1, 2, 4, 8, 16, 32, 32): + head_layers = [ + keras.layers.Conv2D( + num_classes, + kernel_size=(3, 3), + padding="same", + dtype=dtype, + ) + ] + if size != 1: + head_layers.append( + keras.layers.UpSampling2D( + size=size, interpolation="bilinear", dtype=dtype + ) + ) + prediction_heads.append(keras.Sequential(head_layers)) + + if refinement_head is None: + refinement_head = keras.Sequential( + [ + keras.layers.Conv2D( + num_classes, + kernel_size=(3, 3), + padding="same", + dtype=dtype, + ), + ] + ) + + # Prediction model. + predict_model = basnet_predict( + x, image_encoder, projection_filters, prediction_heads, dtype=dtype + ) + + # Refinement model.
+ refine_model = basnet_rrm( + predict_model, projection_filters, refinement_head, dtype=dtype + ) + + outputs = refine_model.outputs # Combine outputs. + outputs.extend(predict_model.outputs) + + output_names = ["refine_out"] + [ + f"predict_out_{i}" for i in range(1, len(outputs)) + ] + + outputs = { + output_name: keras.layers.Activation( + "sigmoid", name=output_name, dtype=dtype + )(output) + for output, output_name in zip(outputs, output_names) + } + + super().__init__(inputs=inputs, outputs=outputs, dtype=dtype, **kwargs) + + # === Config === + self.image_encoder = image_encoder + self.num_classes = num_classes + self.image_shape = image_shape + self.projection_filters = projection_filters + self.prediction_heads = prediction_heads + self.refinement_head = refinement_head + + def get_config(self): + config = super().get_config() + config.update( + { + "image_encoder": keras.saving.serialize_keras_object( + self.image_encoder + ), + "num_classes": self.num_classes, + "image_shape": self.image_shape, + "projection_filters": self.projection_filters, + "prediction_heads": [ + keras.saving.serialize_keras_object(prediction_head) + for prediction_head in self.prediction_heads + ], + "refinement_head": keras.saving.serialize_keras_object( + self.refinement_head + ), + } + ) + return config + + @classmethod + def from_config(cls, config): + if "image_encoder" in config: + config["image_encoder"] = keras.layers.deserialize( + config["image_encoder"] + ) + if "prediction_heads" in config and isinstance( + config["prediction_heads"], list + ): + for i in range(len(config["prediction_heads"])): + if isinstance(config["prediction_heads"][i], dict): + config["prediction_heads"][i] = keras.layers.deserialize( + config["prediction_heads"][i] + ) + + if "refinement_head" in config and isinstance( + config["refinement_head"], dict + ): + config["refinement_head"] = keras.layers.deserialize( + config["refinement_head"] + ) + return super().from_config(config) + + +def 
convolution_block(x_input, filters, dilation=1, dtype=None): + """Apply convolution + batch normalization + ReLU activation. + + Args: + x_input: Input keras tensor. + filters: int, number of output filters in the convolution. + dilation: int, dilation rate for the convolution operation. + Defaults to 1. + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the model's computations and weights. + + Returns: + A tensor with convolution, batch normalization, and ReLU + activation applied. + """ + x = keras.layers.Conv2D( + filters, (3, 3), padding="same", dilation_rate=dilation, dtype=dtype + )(x_input) + x = keras.layers.BatchNormalization(dtype=dtype)(x) + return keras.layers.Activation("relu", dtype=dtype)(x) + + +def get_resnet_block(_resnet, block_num): + """Extract and return a specific ResNet block. + + Args: + _resnet: `keras.Model`. ResNet model instance. + block_num: int, block number to extract. + + Returns: + A Keras Model representing the specified ResNet block. + """ + + extractor_levels = ["P2", "P3", "P4", "P5"] + num_blocks = _resnet.stackwise_num_blocks + if block_num == 0: + x = _resnet.get_layer("pool1_pool").output + else: + x = _resnet.pyramid_outputs[extractor_levels[block_num - 1]] + y = _resnet.get_layer( + f"stack{block_num}_block{num_blocks[block_num]-1}_add" + ).output + return keras.models.Model( + inputs=x, + outputs=y, + name=f"resnet_block{block_num + 1}", + ) + + +def basnet_predict(x_input, backbone, filters, segmentation_heads, dtype=None): + """BASNet Prediction Module. + + This module outputs a coarse label map by integrating heavy + encoder, bridge, and decoder blocks. + + Args: + x_input: Input keras tensor. + backbone: `keras.Model`. The backbone network used as a feature + extractor for BASNet prediction encoder. + filters: int, the number of filters. + segmentation_heads: List of `keras.layers.Layer`, A list of Keras + layers serving as the segmentation head for prediction module. 
+ dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the model's computations and weights. + + + Returns: + A Keras Model that integrates the encoder, bridge, and decoder + blocks for coarse label map prediction. + """ + num_stages = 6 + + x = x_input + + # -------------Encoder-------------- + x = keras.layers.Conv2D( + filters, kernel_size=(3, 3), padding="same", dtype=dtype + )(x) + + encoder_blocks = [] + for i in range(num_stages): + if i < 4: # First four stages are adopted from ResNet backbone. + x = get_resnet_block(backbone, i)(x) + encoder_blocks.append(x) + else: # Last 2 stages consist of three basic resnet blocks. + x = keras.layers.MaxPool2D( + pool_size=(2, 2), strides=(2, 2), dtype=dtype + )(x) + for j in range(3): + x = resnet_basic_block( + x, + filters=x.shape[3], + conv_shortcut=False, + name=f"v1_basic_block_{i + 1}_{j + 1}", + dtype=dtype, + ) + encoder_blocks.append(x) + + # -------------Bridge------------- + x = convolution_block(x, filters=filters * 8, dilation=2, dtype=dtype) + x = convolution_block(x, filters=filters * 8, dilation=2, dtype=dtype) + x = convolution_block(x, filters=filters * 8, dilation=2, dtype=dtype) + encoder_blocks.append(x) + + # -------------Decoder------------- + decoder_blocks = [] + for i in reversed(range(num_stages)): + if i != (num_stages - 1): # Except first, scale other decoder stages. + x = keras.layers.UpSampling2D( + size=2, interpolation="bilinear", dtype=dtype + )(x) + + x = keras.layers.concatenate([encoder_blocks[i], x], axis=-1) + x = convolution_block(x, filters=filters * 8, dtype=dtype) + x = convolution_block(x, filters=filters * 8, dtype=dtype) + x = convolution_block(x, filters=filters * 8, dtype=dtype) + decoder_blocks.append(x) + + decoder_blocks.reverse() # Change order from last to first decoder stage. + decoder_blocks.append(encoder_blocks[-1]) # Copy bridge to decoder. 
+ + # -------------Side Outputs-------------- + decoder_blocks = [ + segmentation_head(decoder_block) # Prediction segmentation head. + for segmentation_head, decoder_block in zip( + segmentation_heads, decoder_blocks + ) + ] + + return keras.models.Model(inputs=[x_input], outputs=decoder_blocks) + + +def basnet_rrm(base_model, filters, segmentation_head, dtype=None): + """BASNet Residual Refinement Module (RRM). + + This module outputs a fine label map by integrating light encoder, + bridge, and decoder blocks. + + Args: + base_model: Keras model used as the base or coarse label map. + filters: int, the number of filters. + segmentation_head: a `keras.layers.Layer`, A Keras layer serving + as the segmentation head for refinement module. + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the model's computations and weights. + + Returns: + A Keras Model that constructs the Residual Refinement Module (RRM). + """ + num_stages = 4 + + x_input = base_model.output[0] + + # -------------Encoder-------------- + x = keras.layers.Conv2D( + filters, kernel_size=(3, 3), padding="same", dtype=dtype + )(x_input) + + encoder_blocks = [] + for _ in range(num_stages): + x = convolution_block(x, filters=filters, dtype=dtype) + encoder_blocks.append(x) + x = keras.layers.MaxPool2D( + pool_size=(2, 2), strides=(2, 2), dtype=dtype + )(x) + + # -------------Bridge-------------- + x = convolution_block(x, filters=filters, dtype=dtype) + + # -------------Decoder-------------- + for i in reversed(range(num_stages)): + x = keras.layers.UpSampling2D( + size=2, interpolation="bilinear", dtype=dtype + )(x) + x = keras.layers.concatenate([encoder_blocks[i], x], axis=-1) + x = convolution_block(x, filters=filters, dtype=dtype) + + x = segmentation_head(x) # Refinement segmentation head.
+ + # ------------- refined = coarse + residual + x = keras.layers.Add(dtype=dtype)( + [x_input, x] + ) # Add prediction + refinement output + + return keras.models.Model(inputs=base_model.input, outputs=[x]) diff --git a/keras_hub/src/models/basnet/basnet_backbone_test.py b/keras_hub/src/models/basnet/basnet_backbone_test.py new file mode 100644 index 0000000000..dfb665d2c3 --- /dev/null +++ b/keras_hub/src/models/basnet/basnet_backbone_test.py @@ -0,0 +1,36 @@ +import numpy as np + +from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone +from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_hub.src.tests.test_case import TestCase + + +class BASNetBackboneTest(TestCase): + def setUp(self): + self.images = np.ones((2, 64, 64, 3)) + self.image_encoder = ResNetBackbone( + input_conv_filters=[64], + input_conv_kernel_sizes=[7], + stackwise_num_filters=[64, 128, 256, 512], + stackwise_num_blocks=[2, 2, 2, 2], + stackwise_num_strides=[1, 2, 2, 2], + block_type="basic_block", + ) + self.init_kwargs = { + "image_encoder": self.image_encoder, + "num_classes": 1, + } + + def test_backbone_basics(self): + output_names = ["refine_out"] + [ + f"predict_out_{i}" for i in range(1, 8) + ] + expected_output_shape = {name: (2, 64, 64, 1) for name in output_names} + self.run_backbone_test( + cls=BASNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.images, + expected_output_shape=expected_output_shape, + run_mixed_precision_check=False, + run_quantization_check=False, + ) diff --git a/keras_hub/src/models/basnet/basnet_image_converter.py b/keras_hub/src/models/basnet/basnet_image_converter.py new file mode 100644 index 0000000000..399858c174 --- /dev/null +++ b/keras_hub/src/models/basnet/basnet_image_converter.py @@ -0,0 +1,8 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from keras_hub.src.models.basnet.basnet_backbone import 
BASNetBackbone + + +@keras_hub_export("keras_hub.layers.BASNetImageConverter") +class BASNetImageConverter(ImageConverter): + backbone_cls = BASNetBackbone diff --git a/keras_hub/src/models/basnet/basnet_preprocessor.py b/keras_hub/src/models/basnet/basnet_preprocessor.py new file mode 100644 index 0000000000..cfe580aa4b --- /dev/null +++ b/keras_hub/src/models/basnet/basnet_preprocessor.py @@ -0,0 +1,14 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone +from keras_hub.src.models.basnet.basnet_image_converter import ( + BASNetImageConverter, +) +from keras_hub.src.models.image_segmenter_preprocessor import ( + ImageSegmenterPreprocessor, +) + + +@keras_hub_export("keras_hub.models.BASNetPreprocessor") +class BASNetPreprocessor(ImageSegmenterPreprocessor): + backbone_cls = BASNetBackbone + image_converter_cls = BASNetImageConverter diff --git a/keras_hub/src/models/basnet/basnet_presets.py b/keras_hub/src/models/basnet/basnet_presets.py new file mode 100644 index 0000000000..3d96ab7885 --- /dev/null +++ b/keras_hub/src/models/basnet/basnet_presets.py @@ -0,0 +1,3 @@ +"""BASNet model preset configurations.""" + +basnet_presets = {} diff --git a/keras_hub/src/models/basnet/basnet_test.py b/keras_hub/src/models/basnet/basnet_test.py new file mode 100644 index 0000000000..4147d43c7c --- /dev/null +++ b/keras_hub/src/models/basnet/basnet_test.py @@ -0,0 +1,66 @@ +import numpy as np +import pytest + +from keras_hub.src.models.basnet.basnet import BASNetImageSegmenter +from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone +from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor +from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_hub.src.tests.test_case import TestCase + + +class BASNetTest(TestCase): + def setUp(self): + self.images = np.ones((2, 64, 64, 3)) + self.labels = np.concatenate( + (np.zeros((2, 32, 64, 1)), 
np.ones((2, 32, 64, 1))), axis=1 + ) + self.image_encoder = ResNetBackbone( + input_conv_filters=[64], + input_conv_kernel_sizes=[7], + stackwise_num_filters=[64, 128, 256, 512], + stackwise_num_blocks=[2, 2, 2, 2], + stackwise_num_strides=[1, 2, 2, 2], + block_type="basic_block", + ) + self.backbone = BASNetBackbone( + image_encoder=self.image_encoder, + num_classes=1, + ) + self.preprocessor = BASNetPreprocessor() + self.init_kwargs = { + "backbone": self.backbone, + "preprocessor": self.preprocessor, + } + self.train_data = (self.images, self.labels) + + def test_basics(self): + self.run_task_test( + cls=BASNetImageSegmenter, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 64, 64, 1), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=BASNetImageSegmenter, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) + + def test_end_to_end_model_predict(self): + model = BASNetImageSegmenter(**self.init_kwargs) + output = model.predict(self.images) + self.assertAllEqual(output.shape, (2, 64, 64, 1)) + + @pytest.mark.skip(reason="disabled until preset's been uploaded to Kaggle") + @pytest.mark.extra_large + def test_all_presets(self): + for preset in BASNetImageSegmenter.presets: + self.run_preset_test( + cls=BASNetImageSegmenter, + preset=preset, + input_data=self.images, + expected_output_shape=(2, 64, 64, 1), + )