diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index 9903cb2579..ece61f4977 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -40,6 +40,9 @@ from keras_hub.src.models.densenet.densenet_image_converter import ( DenseNetImageConverter, ) +from keras_hub.src.models.efficientnet.efficientnet_image_converter import ( + EfficientNetImageConverter, +) from keras_hub.src.models.mit.mit_image_converter import MiTImageConverter from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import ( PaliGemmaImageConverter, diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 71c8ca9f82..27fa2f8353 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -128,6 +128,12 @@ from keras_hub.src.models.efficientnet.efficientnet_backbone import ( EfficientNetBackbone, ) +from keras_hub.src.models.efficientnet.efficientnet_image_classifier import ( + EfficientNetImageClassifier, +) +from keras_hub.src.models.efficientnet.efficientnet_image_classifier_preprocessor import ( + EfficientNetImageClassifierPreprocessor, +) from keras_hub.src.models.electra.electra_backbone import ElectraBackbone from keras_hub.src.models.electra.electra_tokenizer import ElectraTokenizer from keras_hub.src.models.f_net.f_net_backbone import FNetBackbone diff --git a/keras_hub/src/models/efficientnet/__init__.py b/keras_hub/src/models/efficientnet/__init__.py index e69de29bb2..c007fcca56 100644 --- a/keras_hub/src/models/efficientnet/__init__.py +++ b/keras_hub/src/models/efficientnet/__init__.py @@ -0,0 +1,9 @@ +from keras_hub.src.models.efficientnet.efficientnet_backbone import ( + EfficientNetBackbone, +) +from keras_hub.src.models.efficientnet.efficientnet_presets import ( + backbone_presets, +) +from keras_hub.src.utils.preset_utils import register_presets + +register_presets(backbone_presets, EfficientNetBackbone) diff --git 
a/keras_hub/src/models/efficientnet/efficientnet_backbone.py b/keras_hub/src/models/efficientnet/efficientnet_backbone.py index 059207ab5c..4016bb01e4 100644 --- a/keras_hub/src/models/efficientnet/efficientnet_backbone.py +++ b/keras_hub/src/models/efficientnet/efficientnet_backbone.py @@ -104,20 +104,23 @@ def __init__( depth_divisor=8, min_depth=8, input_shape=(None, None, 3), + data_format="channels_last", activation="swish", - include_initial_padding=False, + include_stem_padding=True, use_depth_divisor_as_min_depth=False, cap_round_filter_decrease=False, - stem_conv_padding="same", + stem_conv_padding="valid", batch_norm_momentum=0.9, + batch_norm_epsilon=1e-5, + projection_activation=None, **kwargs, ): image_input = keras.layers.Input(shape=input_shape) x = image_input # Intermediate result. - if include_initial_padding: + if include_stem_padding: x = keras.layers.ZeroPadding2D( - padding=self._correct_pad_downsample(x, 3), + padding=(1, 1), name="stem_conv_pad", )(x) @@ -136,6 +139,7 @@ def __init__( kernel_size=3, strides=2, padding=stem_conv_padding, + data_format=data_format, use_bias=False, kernel_initializer=conv_kernel_initializer(), name="stem_conv", @@ -143,6 +147,7 @@ def __init__( x = keras.layers.BatchNormalization( momentum=batch_norm_momentum, + epsilon=batch_norm_epsilon, name="stem_bn", )(x) x = keras.layers.Activation(activation, name="stem_activation")(x) @@ -206,10 +211,13 @@ def __init__( filters_out=output_filters, kernel_size=stackwise_kernel_sizes[i], strides=strides, + data_format=data_format, expand_ratio=stackwise_expansion_ratios[i], se_ratio=squeeze_and_excite_ratio, activation=activation, + projection_activation=projection_activation, dropout=dropout * block_id / blocks, + batch_norm_epsilon=batch_norm_epsilon, name=block_name, ) else: @@ -219,6 +227,7 @@ def __init__( expand_ratio=stackwise_expansion_ratios[i], kernel_size=stackwise_kernel_sizes[i], strides=strides, + data_format=data_format, se_ratio=squeeze_and_excite_ratio, 
activation=activation, dropout=dropout * block_id / blocks, @@ -241,15 +250,16 @@ def __init__( x = keras.layers.Conv2D( filters=top_filters, kernel_size=1, - padding="same", strides=1, + padding="same", + data_format="channels_last", kernel_initializer=conv_kernel_initializer(), use_bias=False, name="top_conv", - data_format="channels_last", )(x) x = keras.layers.BatchNormalization( momentum=batch_norm_momentum, + epsilon=batch_norm_epsilon, name="top_bn", )(x) x = keras.layers.Activation( @@ -268,6 +278,7 @@ def __init__( self.dropout = dropout self.depth_divisor = depth_divisor self.min_depth = min_depth + self.data_format = data_format self.activation = activation self.stackwise_kernel_sizes = stackwise_kernel_sizes self.stackwise_num_repeats = stackwise_num_repeats @@ -280,11 +291,13 @@ def __init__( self.stackwise_strides = stackwise_strides self.stackwise_block_types = stackwise_block_types - self.include_initial_padding = include_initial_padding + self.include_stem_padding = include_stem_padding self.use_depth_divisor_as_min_depth = use_depth_divisor_as_min_depth self.cap_round_filter_decrease = cap_round_filter_decrease self.stem_conv_padding = stem_conv_padding self.batch_norm_momentum = batch_norm_momentum + self.batch_norm_epsilon = batch_norm_epsilon + self.projection_activation = projection_activation def get_config(self): config = super().get_config() @@ -305,11 +318,13 @@ def get_config(self): "stackwise_squeeze_and_excite_ratios": self.stackwise_squeeze_and_excite_ratios, "stackwise_strides": self.stackwise_strides, "stackwise_block_types": self.stackwise_block_types, - "include_initial_padding": self.include_initial_padding, + "include_stem_padding": self.include_stem_padding, "use_depth_divisor_as_min_depth": self.use_depth_divisor_as_min_depth, "cap_round_filter_decrease": self.cap_round_filter_decrease, "stem_conv_padding": self.stem_conv_padding, "batch_norm_momentum": self.batch_norm_momentum, + "batch_norm_epsilon": self.batch_norm_epsilon, 
+ "projection_activation": self.projection_activation, } ) return config @@ -346,10 +361,13 @@ def _apply_efficientnet_block( kernel_size=3, strides=1, activation="swish", + projection_activation=None, expand_ratio=1, se_ratio=0.0, dropout=0.0, + batch_norm_epsilon=1e-5, name="", + data_format="channels_last", ): """An inverted residual block. @@ -375,12 +393,14 @@ def _apply_efficientnet_block( kernel_size=1, strides=1, padding="same", + data_format=data_format, use_bias=False, kernel_initializer=conv_kernel_initializer(), name=name + "expand_conv", )(inputs) x = keras.layers.BatchNormalization( axis=3, + epsilon=batch_norm_epsilon, name=name + "expand_bn", )(x) x = keras.layers.Activation( @@ -390,25 +410,23 @@ def _apply_efficientnet_block( x = inputs # Depthwise Convolution - if strides == 2: - x = keras.layers.ZeroPadding2D( - padding=self._correct_pad_downsample(x, kernel_size), - name=name + "dwconv_pad", - )(x) - conv_pad = "valid" - else: - conv_pad = "same" - + padding_pixels = kernel_size // 2 + x = keras.layers.ZeroPadding2D( + padding=(padding_pixels, padding_pixels), + name=name + "dwconv_pad", + )(x) x = keras.layers.DepthwiseConv2D( kernel_size=kernel_size, strides=strides, - padding=conv_pad, + padding="valid", + data_format=data_format, use_bias=False, depthwise_initializer=conv_kernel_initializer(), name=name + "dwconv", )(x) x = keras.layers.BatchNormalization( axis=3, + epsilon=batch_norm_epsilon, name=name + "dwconv_bn", )(x) x = keras.layers.Activation( @@ -427,6 +445,7 @@ def _apply_efficientnet_block( filters_se, 1, padding="same", + data_format=data_format, activation=activation, kernel_initializer=conv_kernel_initializer(), name=name + "se_reduce", @@ -435,6 +454,7 @@ def _apply_efficientnet_block( filters, 1, padding="same", + data_format=data_format, activation="sigmoid", kernel_initializer=conv_kernel_initializer(), name=name + "se_expand", @@ -453,11 +473,13 @@ def _apply_efficientnet_block( )(x) x = keras.layers.BatchNormalization( 
axis=3, + epsilon=batch_norm_epsilon, name=name + "project_bn", )(x) - x = keras.layers.Activation( - activation, name=name + "project_activation" - )(x) + if projection_activation: + x = keras.layers.Activation( + projection_activation, name=name + "projection_activation" + )(x) if strides == 1 and filters_in == filters_out: if dropout > 0: diff --git a/keras_hub/src/models/efficientnet/efficientnet_backbone_test.py b/keras_hub/src/models/efficientnet/efficientnet_backbone_test.py index 19e0898437..f31004b5dc 100644 --- a/keras_hub/src/models/efficientnet/efficientnet_backbone_test.py +++ b/keras_hub/src/models/efficientnet/efficientnet_backbone_test.py @@ -73,7 +73,7 @@ def test_valid_call_original_v1(self): "depth_coefficient": 1.0, "stackwise_block_types": ["v1"] * 7, "min_depth": None, - "include_initial_padding": True, + "include_stem_padding": True, "use_depth_divisor_as_min_depth": True, "cap_round_filter_decrease": True, "stem_conv_padding": "valid", diff --git a/keras_hub/src/models/efficientnet/efficientnet_image_classifier.py b/keras_hub/src/models/efficientnet/efficientnet_image_classifier.py new file mode 100644 index 0000000000..123af6d370 --- /dev/null +++ b/keras_hub/src/models/efficientnet/efficientnet_image_classifier.py @@ -0,0 +1,14 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.efficientnet.efficientnet_backbone import ( + EfficientNetBackbone, +) +from keras_hub.src.models.efficientnet.efficientnet_image_classifier_preprocessor import ( + EfficientNetImageClassifierPreprocessor, +) +from keras_hub.src.models.image_classifier import ImageClassifier + + +@keras_hub_export("keras_hub.models.EfficientNetImageClassifier") +class EfficientNetImageClassifier(ImageClassifier): + backbone_cls = EfficientNetBackbone + preprocessor_cls = EfficientNetImageClassifierPreprocessor diff --git a/keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py 
b/keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py new file mode 100644 index 0000000000..f70727e7f7 --- /dev/null +++ b/keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py @@ -0,0 +1,16 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.efficientnet.efficientnet_backbone import ( + EfficientNetBackbone, +) +from keras_hub.src.models.efficientnet.efficientnet_image_converter import ( + EfficientNetImageConverter, +) +from keras_hub.src.models.image_classifier_preprocessor import ( + ImageClassifierPreprocessor, +) + + +@keras_hub_export("keras_hub.models.EfficientNetImageClassifierPreprocessor") +class EfficientNetImageClassifierPreprocessor(ImageClassifierPreprocessor): + backbone_cls = EfficientNetBackbone + image_converter_cls = EfficientNetImageConverter diff --git a/keras_hub/src/models/efficientnet/efficientnet_image_classifier_test.py b/keras_hub/src/models/efficientnet/efficientnet_image_classifier_test.py new file mode 100644 index 0000000000..f420f3571e --- /dev/null +++ b/keras_hub/src/models/efficientnet/efficientnet_image_classifier_test.py @@ -0,0 +1,84 @@ +import pytest +from keras import ops + +from keras_hub.src.models.efficientnet.efficientnet_backbone import ( + EfficientNetBackbone, +) +from keras_hub.src.models.efficientnet.efficientnet_image_classifier import ( + EfficientNetImageClassifier, +) +from keras_hub.src.tests.test_case import TestCase + + +class EfficientNetImageClassifierTest(TestCase): + def setUp(self): + self.images = ops.ones((2, 16, 16, 3)) + self.labels = [0, 3] + backbone = EfficientNetBackbone( + width_coefficient=1.0, + depth_coefficient=1.0, + stackwise_kernel_sizes=[3, 3, 5, 3, 5, 5, 3], + stackwise_num_repeats=[1, 2, 2, 3, 3, 4, 1], + stackwise_input_filters=[32, 16, 24, 40, 80, 112, 192], + stackwise_output_filters=[16, 24, 40, 80, 112, 192, 320], + stackwise_expansion_ratios=[1, 6, 6, 6, 6, 6, 6], + stackwise_strides=[1, 2, 2, 
2, 1, 2, 1], + stackwise_block_types=["v1"] * 7, + stackwise_squeeze_and_excite_ratios=[0.25] * 7, + min_depth=None, + include_stem_padding=True, + use_depth_divisor_as_min_depth=True, + cap_round_filter_decrease=True, + stem_conv_padding="valid", + batch_norm_momentum=0.9, + batch_norm_epsilon=1e-5, + dropout=0, + projection_activation=None, + ) + self.init_kwargs = { + "backbone": backbone, + "num_classes": 1000, + } + self.train_data = (self.images, self.labels) + + def test_classifier_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_task_test( + cls=EfficientNetImageClassifier, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 2), + ) + + @pytest.mark.large + def test_smallest_preset(self): + # Test that our forward pass is stable! + image_batch = self.load_test_image()[None, ...] / 255.0 + self.run_preset_test( + cls=EfficientNetImageClassifier, + preset="efficientnet_b0_ra_imagenet", + input_data=image_batch, + expected_output_shape=(1, 1000), + expected_labels=[85], + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=EfficientNetImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in EfficientNetImageClassifier.presets: + self.run_preset_test( + cls=EfficientNetImageClassifier, + preset=preset, + init_kwargs={"num_classes": 2}, + input_data=self.images, + expected_output_shape=(2, 2), + ) diff --git a/keras_hub/src/models/efficientnet/efficientnet_image_converter.py b/keras_hub/src/models/efficientnet/efficientnet_image_converter.py new file mode 100644 index 0000000000..10634b8267 --- /dev/null +++ b/keras_hub/src/models/efficientnet/efficientnet_image_converter.py @@ -0,0 +1,10 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from 
keras_hub.src.models.efficientnet.efficientnet_backbone import ( + EfficientNetBackbone, +) + + +@keras_hub_export("keras_hub.layers.EfficientNetImageConverter") +class EfficientNetImageConverter(ImageConverter): + backbone_cls = EfficientNetBackbone diff --git a/keras_hub/src/models/efficientnet/efficientnet_presets.py b/keras_hub/src/models/efficientnet/efficientnet_presets.py new file mode 100644 index 0000000000..d157fde1e0 --- /dev/null +++ b/keras_hub/src/models/efficientnet/efficientnet_presets.py @@ -0,0 +1,29 @@ +"""EfficientNet preset configurations.""" + +backbone_presets = { + "efficientnet_b0_ra_imagenet": { + "metadata": { + "description": ( + "EfficientNet B0 model pre-trained on the ImageNet 1k dataset " + "with RandAugment recipe." + ), + "params": 5288548, + "official_name": "EfficientNet", + "path": "efficientnet", + "model_card": "https://arxiv.org/abs/1905.11946", + }, + "kaggle_handle": "kaggle://kerashub/efficientnet/keras/efficientnet_b0_ra_imagenet", + }, + "efficientnet_b1_ft_imagenet": { + "metadata": { + "description": ( + "EfficientNet B1 model fine-tuned on the ImageNet 1k dataset." 
+ ), + "params": 7794184, + "official_name": "EfficientNet", + "path": "efficientnet", + "model_card": "https://arxiv.org/abs/1905.11946", + }, + "kaggle_handle": "kaggle://kerashub/efficientnet/keras/efficientnet_b1_ft_imagenet", + }, +} diff --git a/keras_hub/src/models/efficientnet/fusedmbconv.py b/keras_hub/src/models/efficientnet/fusedmbconv.py index f340aba0b1..96f55a22b8 100644 --- a/keras_hub/src/models/efficientnet/fusedmbconv.py +++ b/keras_hub/src/models/efficientnet/fusedmbconv.py @@ -67,6 +67,7 @@ def __init__( expand_ratio=1, kernel_size=3, strides=1, + data_format="channels_last", se_ratio=0.0, batch_norm_momentum=0.9, activation="swish", @@ -79,6 +80,7 @@ def __init__( self.expand_ratio = expand_ratio self.kernel_size = kernel_size self.strides = strides + self.data_format = data_format self.se_ratio = se_ratio self.batch_norm_momentum = batch_norm_momentum self.activation = activation @@ -92,7 +94,7 @@ def __init__( strides=strides, kernel_initializer=CONV_KERNEL_INITIALIZER, padding="same", - data_format="channels_last", + data_format=data_format, use_bias=False, name=self.name + "expand_conv", ) @@ -115,6 +117,7 @@ def __init__( self.filters_se, 1, padding="same", + data_format=data_format, activation=self.activation, kernel_initializer=CONV_KERNEL_INITIALIZER, name=self.name + "se_reduce", @@ -124,6 +127,7 @@ def __init__( self.filters, 1, padding="same", + data_format=data_format, activation="sigmoid", kernel_initializer=CONV_KERNEL_INITIALIZER, name=self.name + "se_expand", @@ -135,7 +139,7 @@ def __init__( strides=1, kernel_initializer=CONV_KERNEL_INITIALIZER, padding="same", - data_format="channels_last", + data_format=data_format, use_bias=False, name=self.name + "project_conv", ) @@ -169,7 +173,8 @@ def call(self, inputs): # Squeeze and excite if 0 < self.se_ratio <= 1: se = keras.layers.GlobalAveragePooling2D( - name=self.name + "se_squeeze" + name=self.name + "se_squeeze", + data_format=self.data_format, )(x) if BN_AXIS == 1: se_shape = 
(self.filters, 1, 1) @@ -205,6 +210,7 @@ def get_config(self): "expand_ratio": self.expand_ratio, "kernel_size": self.kernel_size, "strides": self.strides, + "data_format": self.data_format, "se_ratio": self.se_ratio, "batch_norm_momentum": self.batch_norm_momentum, "activation": self.activation, diff --git a/keras_hub/src/models/efficientnet/mbconv.py b/keras_hub/src/models/efficientnet/mbconv.py index 08c3d437a4..392e62c04f 100644 --- a/keras_hub/src/models/efficientnet/mbconv.py +++ b/keras_hub/src/models/efficientnet/mbconv.py @@ -20,6 +20,7 @@ def __init__( expand_ratio=1, kernel_size=3, strides=1, + data_format="channels_last", se_ratio=0.0, batch_norm_momentum=0.9, activation="swish", @@ -79,6 +80,7 @@ def __init__( self.expand_ratio = expand_ratio self.kernel_size = kernel_size self.strides = strides + self.data_format = data_format self.se_ratio = se_ratio self.batch_norm_momentum = batch_norm_momentum self.activation = activation @@ -92,7 +94,7 @@ def __init__( strides=1, kernel_initializer=CONV_KERNEL_INITIALIZER, padding="same", - data_format="channels_last", + data_format=data_format, use_bias=False, name=self.name + "expand_conv", ) @@ -109,7 +111,7 @@ def __init__( strides=self.strides, depthwise_initializer=CONV_KERNEL_INITIALIZER, padding="same", - data_format="channels_last", + data_format=data_format, use_bias=False, name=self.name + "dwconv2", ) @@ -124,6 +126,7 @@ def __init__( self.filters_se, 1, padding="same", + data_format=data_format, activation=self.activation, kernel_initializer=CONV_KERNEL_INITIALIZER, name=self.name + "se_reduce", @@ -133,6 +136,7 @@ def __init__( self.filters, 1, padding="same", + data_format=data_format, activation="sigmoid", kernel_initializer=CONV_KERNEL_INITIALIZER, name=self.name + "se_expand", @@ -144,7 +148,7 @@ def __init__( strides=1, kernel_initializer=CONV_KERNEL_INITIALIZER, padding="same", - data_format="channels_last", + data_format=data_format, use_bias=False, name=self.name + "project_conv", ) @@ 
-183,7 +187,8 @@ def call(self, inputs): # Squeeze and excite if 0 < self.se_ratio <= 1: se = keras.layers.GlobalAveragePooling2D( - name=self.name + "se_squeeze" + name=self.name + "se_squeeze", + data_format=self.data_format, )(x) if BN_AXIS == 1: se_shape = (self.filters, 1, 1) @@ -215,6 +220,7 @@ def get_config(self): "expand_ratio": self.expand_ratio, "kernel_size": self.kernel_size, "strides": self.strides, + "data_format": self.data_format, "se_ratio": self.se_ratio, "batch_norm_momentum": self.batch_norm_momentum, "activation": self.activation, diff --git a/keras_hub/src/utils/timm/convert_efficientnet.py b/keras_hub/src/utils/timm/convert_efficientnet.py new file mode 100644 index 0000000000..609c26d355 --- /dev/null +++ b/keras_hub/src/utils/timm/convert_efficientnet.py @@ -0,0 +1,226 @@ +import math + +import numpy as np + +from keras_hub.src.models.efficientnet.efficientnet_backbone import ( + EfficientNetBackbone, +) + +backbone_cls = EfficientNetBackbone + + +VARIANT_MAP = { + "b0": { + "width_coefficient": 1.0, + "depth_coefficient": 1.0, + }, + "b1": { + "width_coefficient": 1.0, + "depth_coefficient": 1.1, + }, +} + + +def convert_backbone_config(timm_config): + timm_architecture = timm_config["architecture"] + + base_kwargs = { + "stackwise_kernel_sizes": [3, 3, 5, 3, 5, 5, 3], + "stackwise_num_repeats": [1, 2, 2, 3, 3, 4, 1], + "stackwise_input_filters": [32, 16, 24, 40, 80, 112, 192], + "stackwise_output_filters": [16, 24, 40, 80, 112, 192, 320], + "stackwise_expansion_ratios": [1, 6, 6, 6, 6, 6, 6], + "stackwise_strides": [1, 2, 2, 2, 1, 2, 1], + "stackwise_squeeze_and_excite_ratios": [ + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + ], + "stackwise_block_types": ["v1"] * 7, + "min_depth": None, + "include_stem_padding": True, + "use_depth_divisor_as_min_depth": True, + "cap_round_filter_decrease": True, + "stem_conv_padding": "valid", + "batch_norm_momentum": 0.9, + "batch_norm_epsilon": 1e-5, + "dropout": 0, + 
"projection_activation": None, + } + + variant = "_".join(timm_architecture.split("_")[1:]) + + if variant not in VARIANT_MAP: + raise ValueError( + f"Currently, the architecture {timm_architecture} is not supported." + ) + + base_kwargs.update(VARIANT_MAP[variant]) + + return base_kwargs + + +def convert_weights(backbone, loader, timm_config): + timm_architecture = timm_config["architecture"] + variant = "_".join(timm_architecture.split("_")[1:]) + + def port_conv2d(keras_layer_name, hf_weight_prefix, port_bias=True): + loader.port_weight( + backbone.get_layer(keras_layer_name).kernel, + hf_weight_key=f"{hf_weight_prefix}.weight", + hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)), + ) + + if port_bias: + loader.port_weight( + backbone.get_layer(keras_layer_name).bias, + hf_weight_key=f"{hf_weight_prefix}.bias", + ) + + def port_depthwise_conv2d( + keras_layer_name, + hf_weight_prefix, + port_bias=True, + depth_multiplier=1, + ): + + def convert_pt_conv2d_kernel(pt_kernel): + out_channels, in_channels_per_group, height, width = pt_kernel.shape + # PT Convs are depthwise convs if and only if in_channels_per_group == 1 + assert in_channels_per_group == 1 + pt_kernel = np.transpose(pt_kernel, (2, 3, 0, 1)) + in_channels = out_channels // depth_multiplier + return np.reshape( + pt_kernel, (height, width, in_channels, depth_multiplier) + ) + + loader.port_weight( + backbone.get_layer(keras_layer_name).kernel, + hf_weight_key=f"{hf_weight_prefix}.weight", + hook_fn=lambda x, _: convert_pt_conv2d_kernel(x), + ) + + if port_bias: + loader.port_weight( + backbone.get_layer(keras_layer_name).bias, + hf_weight_key=f"{hf_weight_prefix}.bias", + ) + + def port_batch_normalization(keras_layer_name, hf_weight_prefix): + loader.port_weight( + backbone.get_layer(keras_layer_name).gamma, + hf_weight_key=f"{hf_weight_prefix}.weight", + ) + loader.port_weight( + backbone.get_layer(keras_layer_name).beta, + hf_weight_key=f"{hf_weight_prefix}.bias", + ) + loader.port_weight( + 
backbone.get_layer(keras_layer_name).moving_mean, + hf_weight_key=f"{hf_weight_prefix}.running_mean", + ) + loader.port_weight( + backbone.get_layer(keras_layer_name).moving_variance, + hf_weight_key=f"{hf_weight_prefix}.running_var", + ) + # do we need num batches tracked? + + # Stem + port_conv2d("stem_conv", "conv_stem", port_bias=False) + port_batch_normalization("stem_bn", "bn1") + + # Stages + num_stacks = len(backbone.stackwise_kernel_sizes) + for stack_index in range(num_stacks): + + block_type = backbone.stackwise_block_types[stack_index] + expansion_ratio = backbone.stackwise_expansion_ratios[stack_index] + repeats = backbone.stackwise_num_repeats[stack_index] + + repeats = int( + math.ceil(VARIANT_MAP[variant]["depth_coefficient"] * repeats) + ) + + for block_idx in range(repeats): + + conv_pw_count = 0 + bn_count = 1 + conv_pw_name_map = ["conv_pw", "conv_pwl"] + + # 97 is the start of the lowercase alphabet. + letter_identifier = chr(block_idx + 97) + + if block_type == "v1": + keras_block_prefix = f"block{stack_index+1}{letter_identifier}_" + hf_block_prefix = f"blocks.{stack_index}.{block_idx}." 
+ + # Initial Expansion Conv + if expansion_ratio != 1: + port_conv2d( + keras_block_prefix + "expand_conv", + hf_block_prefix + conv_pw_name_map[conv_pw_count], + port_bias=False, + ) + conv_pw_count += 1 + port_batch_normalization( + keras_block_prefix + "expand_bn", + hf_block_prefix + f"bn{bn_count}", + ) + bn_count += 1 + + # Depthwise Conv + port_depthwise_conv2d( + keras_block_prefix + "dwconv", + hf_block_prefix + "conv_dw", + port_bias=False, + ) + port_batch_normalization( + keras_block_prefix + "dwconv_bn", + hf_block_prefix + f"bn{bn_count}", + ) + bn_count += 1 + + # Squeeze and Excite + port_conv2d( + keras_block_prefix + "se_reduce", + hf_block_prefix + "se.conv_reduce", + ) + port_conv2d( + keras_block_prefix + "se_expand", + hf_block_prefix + "se.conv_expand", + ) + + # Output/Projection + port_conv2d( + keras_block_prefix + "project", + hf_block_prefix + conv_pw_name_map[conv_pw_count], + port_bias=False, + ) + conv_pw_count += 1 + port_batch_normalization( + keras_block_prefix + "project_bn", + hf_block_prefix + f"bn{bn_count}", + ) + bn_count += 1 + + # Head/Top + port_conv2d("top_conv", "conv_head", port_bias=False) + port_batch_normalization("top_bn", "bn2") + + +def convert_head(task, loader, timm_config): + classifier_prefix = timm_config["pretrained_cfg"]["classifier"] + prefix = f"{classifier_prefix}." 
+ loader.port_weight( + task.output_dense.kernel, + hf_weight_key=prefix + "weight", + hook_fn=lambda x, _: np.transpose(np.squeeze(x)), + ) + loader.port_weight( + task.output_dense.bias, + hf_weight_key=prefix + "bias", + ) diff --git a/keras_hub/src/utils/timm/convert_efficientnet_test.py b/keras_hub/src/utils/timm/convert_efficientnet_test.py new file mode 100644 index 0000000000..014d4ae5fd --- /dev/null +++ b/keras_hub/src/utils/timm/convert_efficientnet_test.py @@ -0,0 +1,20 @@ +import pytest +from keras import ops + +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.image_classifier import ImageClassifier +from keras_hub.src.tests.test_case import TestCase + + +class TimmEfficientNetBackboneTest(TestCase): + @pytest.mark.large + def test_convert_efficientnet_backbone(self): + model = Backbone.from_preset("hf://timm/efficientnet_b0.ra_in1k") + outputs = model.predict(ops.ones((1, 224, 224, 3))) + self.assertEqual(outputs.shape, (1, 7, 7, 1280)) + + @pytest.mark.large + def test_convert_efficientnet_classifier(self): + model = ImageClassifier.from_preset("hf://timm/efficientnet_b0.ra_in1k") + outputs = model.predict(ops.ones((1, 512, 512, 3))) + self.assertEqual(outputs.shape, (1, 1000)) diff --git a/keras_hub/src/utils/timm/preset_loader.py b/keras_hub/src/utils/timm/preset_loader.py index 1524db8530..069a6425e4 100644 --- a/keras_hub/src/utils/timm/preset_loader.py +++ b/keras_hub/src/utils/timm/preset_loader.py @@ -4,6 +4,7 @@ from keras_hub.src.utils.preset_utils import PresetLoader from keras_hub.src.utils.preset_utils import jax_memory_cleanup from keras_hub.src.utils.timm import convert_densenet +from keras_hub.src.utils.timm import convert_efficientnet from keras_hub.src.utils.timm import convert_resnet from keras_hub.src.utils.timm import convert_vgg from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader @@ -19,6 +20,8 @@ def __init__(self, preset, config): self.converter = convert_densenet elif 
"vgg" in architecture: self.converter = convert_vgg + elif "efficientnet" in architecture: + self.converter = convert_efficientnet else: raise ValueError( "KerasHub has no converter for timm models " diff --git a/tools/checkpoint_conversion/convert_efficientnet_checkpoints.py b/tools/checkpoint_conversion/convert_efficientnet_checkpoints.py new file mode 100644 index 0000000000..5790d6130c --- /dev/null +++ b/tools/checkpoint_conversion/convert_efficientnet_checkpoints.py @@ -0,0 +1,115 @@ +""" +Convert efficientnet checkpoints. + +python tools/checkpoint_conversion/convert_efficientnet_checkpoints.py \ + --preset efficientnet_b0_ra_imagenet --upload_uri kaggle://kerashub/efficientnet/keras/efficientnet_b0_ra_imagenet +python tools/checkpoint_conversion/convert_efficientnet_checkpoints.py \ + --preset efficientnet_b1_ft_imagenet --upload_uri kaggle://kerashub/efficientnet/keras/efficientnet_b1_ft_imagenet +""" + +import os +import shutil + +import keras +import numpy as np +import PIL +import timm +import torch +from absl import app +from absl import flags + +import keras_hub + +PRESET_MAP = { + "efficientnet_b0_ra_imagenet": "timm/efficientnet_b0.ra_in1k", + "efficientnet_b1_ft_imagenet": "timm/efficientnet_b1.ft_in1k", +} +FLAGS = flags.FLAGS + + +flags.DEFINE_string( + "preset", + None, + "Must be a valid `EfficientNet` preset from KerasHub", + required=True, +) +flags.DEFINE_string( + "upload_uri", + None, + 'Could be "kaggle://keras/{variant}/keras/{preset}_imagenet"', + required=False, +) + + +def validate_output(keras_model, timm_model): + file = keras.utils.get_file( + origin=( + "https://storage.googleapis.com/keras-cv/" + "models/paligemma/cow_beach_1.png" + ) + ) + image = PIL.Image.open(file) + batch = np.array([image]) + + # Preprocess with Timm. + data_config = timm.data.resolve_model_data_config(timm_model) + data_config["crop_pct"] = 1.0 # Stop timm from cropping. 
+ transforms = timm.data.create_transform(**data_config, is_training=False) + timm_preprocessed = transforms(image) + timm_preprocessed = keras.ops.transpose(timm_preprocessed, axes=(1, 2, 0)) + timm_preprocessed = keras.ops.expand_dims(timm_preprocessed, 0) + # Preprocess with Keras. + batch = keras.ops.cast(batch, "float32") + keras_preprocessed = keras_model.preprocessor(batch) + # Call with Timm. Use the keras preprocessed image so we can keep modeling + # and preprocessing comparisons independent. + timm_batch = keras.ops.transpose(keras_preprocessed, axes=(0, 3, 1, 2)) + timm_batch = torch.from_numpy(np.array(timm_batch)) + timm_outputs = timm_model(timm_batch).detach().numpy() + timm_label = np.argmax(timm_outputs[0]) + + # Call with Keras. + keras_outputs = keras_model.predict(batch) + keras_label = np.argmax(keras_outputs[0]) + + print("🔶 Keras output:", keras_outputs[0, :10]) + print("🔶 TIMM output:", timm_outputs[0, :10]) + print("🔶 Keras label:", keras_label) + print("🔶 TIMM label:", timm_label) + modeling_diff = np.mean(np.abs(keras_outputs - timm_outputs)) + print("🔶 Modeling difference:", modeling_diff) + preprocessing_diff = np.mean(np.abs(keras_preprocessed - timm_preprocessed)) + print("🔶 Preprocessing difference:", preprocessing_diff) + + +def main(_): + preset = FLAGS.preset + if os.path.exists(preset): + shutil.rmtree(preset) + os.makedirs(preset) + + timm_name = PRESET_MAP[preset] + + print("✅ Loaded TIMM model.") + timm_model = timm.create_model(timm_name, pretrained=True) + timm_model = timm_model.eval() + + print("✅ Loaded KerasHub model.") + keras_model = keras_hub.models.ImageClassifier.from_preset( + "hf://" + timm_name, + ) + + validate_output(keras_model, timm_model) + + keras_model.save_to_preset(f"./{preset}") + print(f"🏁 Preset saved to ./{preset}") + + upload_uri = FLAGS.upload_uri + if upload_uri: + keras_hub.upload_preset(uri=upload_uri, preset=f"./{preset}") + print(f"🏁 Preset uploaded to {upload_uri}") + + +if __name__ == 
"__main__": + flags.mark_flag_as_required("preset") + app.run(main)