27 changes: 27 additions & 0 deletions keras_hub/src/models/efficientnet/efficientnet_backbone.py
@@ -54,6 +54,13 @@ class EfficientNetBackbone(FeaturePyramidBackbone):
MBConvBlock, but instead of using a depthwise convolution and a 1x1
output convolution block, fused blocks use a single 3x3 convolution
block.
stackwise_force_input_filters: list of ints, overrides
stackwise_input_filters if > 0. Primarily used to parameterize stem
filters (usually stackwise_input_filters[0]) differently than stack
input filters.
stackwise_nores_option: list of bools, toggles whether the residual
connection is skipped. If False (default), the stack uses residual
connections; if True, it does not.
min_depth: integer, minimum number of filters. Can be None and ignored
if use_depth_divisor_as_min_depth is set to True.
include_initial_padding: bool, whether to include initial zero padding
@@ -66,6 +73,8 @@ class EfficientNetBackbone(FeaturePyramidBackbone):
stem_conv_padding: str, can be 'same' or 'valid'. Padding for the stem.
batch_norm_momentum: float, momentum for the moving average calculation
in the batch normalization layers.
batch_norm_epsilon: float, epsilon for batch norm calculations. Used
in the denominator to prevent divide-by-zero errors.

Example:
```python
@@ -100,6 +109,8 @@ def __init__(
stackwise_squeeze_and_excite_ratios,
stackwise_strides,
stackwise_block_types,
stackwise_force_input_filters=[0] * 7,
stackwise_nores_option=[False] * 7,
dropout=0.2,
depth_divisor=8,
min_depth=8,
@@ -163,6 +174,8 @@ def __init__(
num_repeats = stackwise_num_repeats[i]
input_filters = stackwise_input_filters[i]
output_filters = stackwise_output_filters[i]
force_input_filters = stackwise_force_input_filters[i]
nores = stackwise_nores_option[i]

# Update block input and output filters based on depth multiplier.
input_filters = round_filters(
@@ -200,6 +213,16 @@ def __init__(
self._pyramid_outputs[f"P{curr_pyramid_level}"] = x
curr_pyramid_level += 1

if force_input_filters > 0:
input_filters = round_filters(
filters=force_input_filters,
width_coefficient=width_coefficient,
min_depth=min_depth,
depth_divisor=depth_divisor,
use_depth_divisor_as_min_depth=use_depth_divisor_as_min_depth,
cap_round_filter_decrease=cap_round_filter_decrease,
)

# 97 is the ASCII code for "a", the start of the lowercase alphabet.
letter_identifier = chr(j + 97)
stackwise_block_type = stackwise_block_types[i]
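
The override above feeds `force_input_filters` through the same `round_filters` helper as the regular widths. For reference, a minimal sketch of the conventional EfficientNet rounding rule that this helper follows in spirit (an approximation, not the exact keras-hub source; `use_depth_divisor_as_min_depth` is modeled here as falling back to the divisor when no explicit minimum is given):

```python
# Illustrative sketch of EfficientNet-style filter rounding.
def round_filters_sketch(filters, width_coefficient, depth_divisor=8,
                         min_depth=None, cap_round_filter_decrease=True):
    filters *= width_coefficient
    # Fall back to the divisor when no explicit minimum is given.
    minimum = min_depth if min_depth is not None else depth_divisor
    # Round to the nearest multiple of depth_divisor.
    new_filters = max(
        minimum,
        int(filters + depth_divisor / 2) // depth_divisor * depth_divisor,
    )
    if cap_round_filter_decrease and new_filters < 0.9 * filters:
        # Do not let rounding shrink the width by more than ~10%.
        new_filters += depth_divisor
    return int(new_filters)
```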
@@ -232,6 +255,8 @@ def __init__(
activation=activation,
dropout=dropout * block_id / blocks,
batch_norm_momentum=batch_norm_momentum,
batch_norm_epsilon=batch_norm_epsilon,
nores=nores,
name=block_name,
)
x = block(x)
@@ -291,6 +316,7 @@ def __init__(
self.stackwise_strides = stackwise_strides
self.stackwise_block_types = stackwise_block_types

self.stackwise_force_input_filters = stackwise_force_input_filters
self.include_stem_padding = include_stem_padding
self.use_depth_divisor_as_min_depth = use_depth_divisor_as_min_depth
self.cap_round_filter_decrease = cap_round_filter_decrease
@@ -318,6 +344,7 @@ def get_config(self):
"stackwise_squeeze_and_excite_ratios": self.stackwise_squeeze_and_excite_ratios,
"stackwise_strides": self.stackwise_strides,
"stackwise_block_types": self.stackwise_block_types,
"stackwise_force_input_filters": self.stackwise_force_input_filters,
"include_stem_padding": self.include_stem_padding,
"use_depth_divisor_as_min_depth": self.use_depth_divisor_as_min_depth,
"cap_round_filter_decrease": self.cap_round_filter_decrease,
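
To make the two new arguments concrete, a minimal construction sketch, assuming the `keras_hub.models` export path; the per-stack values below are illustrative placeholders loosely adapted from the test config, not a published variant:

```python
import keras_hub

backbone = keras_hub.models.EfficientNetBackbone(
    stackwise_kernel_sizes=[3, 3, 3, 3, 3, 3],    # placeholder values
    stackwise_num_repeats=[1, 2, 2, 3, 3, 4],     # placeholder values
    stackwise_input_filters=[24, 24, 32, 48, 96, 144],
    stackwise_output_filters=[24, 32, 48, 96, 144, 192],
    stackwise_expansion_ratios=[1, 4, 4, 4, 6, 6],
    stackwise_squeeze_and_excite_ratios=[0.0] * 6,
    stackwise_strides=[1, 2, 2, 2, 1, 2],
    stackwise_block_types=["fused"] * 3 + ["unfused"] * 3,
    # New: force the stack input width (0 means "no override")...
    stackwise_force_input_filters=[24, 0, 0, 0, 0, 0],
    # ...and drop the residual connection in selected stacks.
    stackwise_nores_option=[True] + [False] * 5,
    width_coefficient=1.0,
    depth_coefficient=1.0,
)
```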
14 changes: 5 additions & 9 deletions keras_hub/src/models/efficientnet/efficientnet_backbone_test.py
@@ -26,6 +26,8 @@ def setUp(self):
],
"stackwise_strides": [1, 2, 2, 2, 1, 2],
"stackwise_block_types": ["fused"] * 3 + ["unfused"] * 3,
"stackwise_force_input_filters": [0] * 6,
"stackwise_nores_option": [False] * 6,
"width_coefficient": 1.0,
"depth_coefficient": 1.0,
}
@@ -60,15 +62,9 @@ def test_valid_call_original_v1(self):
"stackwise_output_filters": [16, 24, 40, 80, 112, 192, 320],
"stackwise_expansion_ratios": [1, 6, 6, 6, 6, 6, 6],
"stackwise_strides": [1, 2, 2, 2, 1, 2, 1],
"stackwise_squeeze_and_excite_ratios": [
0.25,
0.25,
0.25,
0.25,
0.25,
0.25,
0.25,
],
"stackwise_squeeze_and_excite_ratios": [0.25] * 7,
"stackwise_force_input_filters": [0] * 7,
"stackwise_nores_option": [False] * 7,
"width_coefficient": 1.0,
"depth_coefficient": 1.0,
"stackwise_block_types": ["v1"] * 7,
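
For completeness, the usual smoke test over such a config looks roughly like this (a sketch: the input size is an assumption, and `init_kwargs` stands in for the dict built in `setUp`):

```python
import numpy as np
import keras_hub

model = keras_hub.models.EfficientNetBackbone(**init_kwargs)
features = model(np.ones((1, 224, 224, 3), dtype="float32"))
```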
45 changes: 42 additions & 3 deletions keras_hub/src/models/efficientnet/efficientnet_presets.py
@@ -12,18 +12,57 @@
"path": "efficientnet",
"model_card": "https://arxiv.org/abs/1905.11946",
},
"kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b0_ra_imagenet",
"kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b0_ra_imagenet/1",
},
"efficientnet_b1_ft_imagenet": {
"metadata": {
"description": (
"EfficientNet B1 model fine-trained on the ImageNet 1k dataset."
"EfficientNet B1 model fine-tuned on the ImageNet 1k dataset."
),
"params": 7794184,
"official_name": "EfficientNet",
"path": "efficientnet",
"model_card": "https://arxiv.org/abs/1905.11946",
},
"kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b1_ft_imagenet",
"kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b1_ft_imagenet/1",
},
"efficientnet_el_ra_imagenet": {
"metadata": {
"description": (
"EfficientNet-EdgeTPU Large model trained on the ImageNet 1k "
"dataset with RandAugment recipe."
),
"params": 10589712,
"official_name": "EfficientNet",
"path": "efficientnet",
"model_card": "https://arxiv.org/abs/1905.11946",
},
"kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_el_ra_imagenet/1",
},
"efficientnet_em_ra2_imagenet": {
"metadata": {
"description": (
"EfficientNet-EdgeTPU Medium model trained on the ImageNet 1k "
"dataset with RandAugment2 recipe."
),
"params": 6899496,
"official_name": "EfficientNet",
"path": "efficientnet",
"model_card": "https://arxiv.org/abs/1905.11946",
},
"kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_em_ra2_imagenet/1",
},
"efficientnet_es_ra_imagenet": {
"metadata": {
"description": (
"EfficientNet-EdgeTPU Small model trained on the ImageNet 1k "
"dataset with RandAugment recipe."
),
"params": 5438392,
"official_name": "EfficientNet",
"path": "efficientnet",
"model_card": "https://arxiv.org/abs/1905.11946",
},
"kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_es_ra_imagenet/1",
},
}
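
Once these versioned handles are live on Kaggle, loading one follows the standard KerasHub preset pattern; a sketch:

```python
import keras_hub

# Loads architecture + weights from the versioned Kaggle handle above.
backbone = keras_hub.models.EfficientNetBackbone.from_preset(
    "efficientnet_es_ra_imagenet"
)
```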
27 changes: 18 additions & 9 deletions keras_hub/src/models/efficientnet/fusedmbconv.py
@@ -47,6 +47,9 @@ class FusedMBConvBlock(keras.layers.Layer):
se_ratio: default 0.0, the Squeeze-and-Excitation ratio; the filters
used in that phase are chosen as the maximum of 1 and
input_filters*se_ratio
batch_norm_momentum: default 0.9, the BatchNormalization momentum
batch_norm_epsilon: default 1e-3, float, epsilon for batch norm
calculations. Used in the denominator to prevent divide-by-zero
errors.
activation: default "swish", the activation function used between
convolution operations
dropout: float, the optional dropout rate to apply before the output
@@ -70,8 +73,10 @@ def __init__(
data_format="channels_last",
se_ratio=0.0,
batch_norm_momentum=0.9,
batch_norm_epsilon=1e-3,
activation="swish",
dropout=0.2,
nores=False,
**kwargs
):
super().__init__(**kwargs)
@@ -83,8 +88,10 @@ def __init__(
self.data_format = data_format
self.se_ratio = se_ratio
self.batch_norm_momentum = batch_norm_momentum
self.batch_norm_epsilon = batch_norm_epsilon
self.activation = activation
self.dropout = dropout
self.nores = nores
self.filters = self.input_filters * self.expand_ratio
self.filters_se = max(1, int(input_filters * se_ratio))

@@ -101,18 +108,13 @@ def __init__(
self.bn1 = keras.layers.BatchNormalization(
axis=BN_AXIS,
momentum=self.batch_norm_momentum,
epsilon=self.batch_norm_epsilon,
name=self.name + "expand_bn",
)
self.act = keras.layers.Activation(
self.activation, name=self.name + "expand_activation"
)

self.bn2 = keras.layers.BatchNormalization(
axis=BN_AXIS,
momentum=self.batch_norm_momentum,
name=self.name + "bn",
)

self.se_conv1 = keras.layers.Conv2D(
self.filters_se,
1,
@@ -144,9 +146,10 @@ def __init__(
name=self.name + "project_conv",
)

self.bn3 = keras.layers.BatchNormalization(
self.bn2 = keras.layers.BatchNormalization(
axis=BN_AXIS,
momentum=self.batch_norm_momentum,
epsilon=self.batch_norm_epsilon,
name=self.name + "project_bn",
)

@@ -192,12 +195,16 @@ def call(self, inputs):

# Output phase:
x = self.output_conv(x)
x = self.bn3(x)
x = self.bn2(x)
if self.expand_ratio == 1:
x = self.act(x)

# Residual:
if self.strides == 1 and self.input_filters == self.output_filters:
if (
self.strides == 1
and self.input_filters == self.output_filters
and not self.nores
):
if self.dropout:
x = self.dropout_layer(x)
x = keras.layers.Add(name=self.name + "add")([x, inputs])
@@ -213,8 +220,10 @@ def get_config(self):
"data_format": self.data_format,
"se_ratio": self.se_ratio,
"batch_norm_momentum": self.batch_norm_momentum,
"batch_norm_epsilon": self.batch_norm_epsilon,
"activation": self.activation,
"dropout": self.dropout,
"nores": self.nores,
}

base_config = super().get_config()
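
A sketch of the case the new guard targets: a block whose stride and matching widths would otherwise qualify for the residual add, with `nores=True` suppressing it. The import path follows this repo's layout; argument names follow the attributes visible in the diff, and unlisted arguments are left at their defaults:

```python
from keras_hub.src.models.efficientnet.fusedmbconv import FusedMBConvBlock

block = FusedMBConvBlock(
    input_filters=32,
    output_filters=32,  # same width as the input...
    expand_ratio=1,
    strides=1,          # ...and stride 1: only nores=True blocks the skip
    batch_norm_epsilon=1e-3,
    nores=True,
)
```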
18 changes: 17 additions & 1 deletion keras_hub/src/models/efficientnet/mbconv.py
@@ -23,8 +23,10 @@ def __init__(
data_format="channels_last",
se_ratio=0.0,
batch_norm_momentum=0.9,
batch_norm_epsilon=1e-3,
activation="swish",
dropout=0.2,
nores=False,
**kwargs
):
"""Implementation of the MBConv block
@@ -60,6 +62,9 @@ def __init__(
is above 0. The filters used in this phase are chosen as the
maximum between 1 and input_filters*se_ratio
batch_norm_momentum: default 0.9, the BatchNormalization momentum
batch_norm_epsilon: default 1e-3, float, epsilon for batch norm
calculations. Used in the denominator to prevent divide-by-zero
errors.
activation: default "swish", the activation function used between
convolution operations
dropout: float, the optional dropout rate to apply before the output
@@ -83,8 +88,10 @@ def __init__(
self.data_format = data_format
self.se_ratio = se_ratio
self.batch_norm_momentum = batch_norm_momentum
self.batch_norm_epsilon = batch_norm_epsilon
self.activation = activation
self.dropout = dropout
self.nores = nores
self.filters = self.input_filters * self.expand_ratio
self.filters_se = max(1, int(input_filters * se_ratio))

@@ -101,6 +108,7 @@ def __init__(
self.bn1 = keras.layers.BatchNormalization(
axis=BN_AXIS,
momentum=self.batch_norm_momentum,
epsilon=self.batch_norm_epsilon,
name=self.name + "expand_bn",
)
self.act = keras.layers.Activation(
@@ -119,6 +127,7 @@ def __init__(
self.bn2 = keras.layers.BatchNormalization(
axis=BN_AXIS,
momentum=self.batch_norm_momentum,
epsilon=self.batch_norm_epsilon,
name=self.name + "bn",
)

@@ -156,6 +165,7 @@ def __init__(
self.bn3 = keras.layers.BatchNormalization(
axis=BN_AXIS,
momentum=self.batch_norm_momentum,
epsilon=self.batch_norm_epsilon,
name=self.name + "project_bn",
)

@@ -207,7 +217,11 @@ def call(self, inputs):
x = self.output_conv(x)
x = self.bn3(x)

if self.strides == 1 and self.input_filters == self.output_filters:
if (
self.strides == 1
and self.input_filters == self.output_filters
and not self.nores
):
if self.dropout:
x = self.dropout_layer(x)
x = keras.layers.Add(name=self.name + "add")([x, inputs])
@@ -223,8 +237,10 @@ def get_config(self):
"data_format": self.data_format,
"se_ratio": self.se_ratio,
"batch_norm_momentum": self.batch_norm_momentum,
"batch_norm_epsilon": self.batch_norm_epsilon,
"activation": self.activation,
"dropout": self.dropout,
"nores": self.nores,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
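
Since both blocks now serialize the new fields, a config round-trip should preserve them. A sketch, assuming the class is exported as `MBConvBlock` and with illustrative constructor arguments:

```python
from keras_hub.src.models.efficientnet.mbconv import MBConvBlock

block = MBConvBlock(
    input_filters=16, output_filters=16, nores=True, batch_norm_epsilon=1e-3
)
# get_config()/from_config() is the standard Keras serialization pattern;
# the new fields should survive the round trip.
clone = MBConvBlock.from_config(block.get_config())
assert clone.nores is True
assert clone.batch_norm_epsilon == 1e-3
```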