fix: Fixed framework alignment on MobileNet V3 implementations #410

Merged (4 commits) on Aug 12, 2021

Changes from all commits
230 changes: 136 additions & 94 deletions doctr/models/backbones/mobilenet/tensorflow.py
@@ -3,6 +3,8 @@
 # This program is licensed under the Apache License version 2.
 # See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.
 
+# Greatly inspired by https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenetv3.py
+
 import tensorflow as tf
 from tensorflow.keras import layers
 from tensorflow.keras.models import Sequential
@@ -15,24 +17,11 @@
 
 default_cfgs: Dict[str, Dict[str, Any]] = {
     'mobilenet_v3_large': {
-        'input_shape': (512, 512, 3),
-        'out_chans': [16, 24, 24, 40, 40, 40, 80, 80, 80, 80, 112, 112, 160, 160, 160],
-        'kernels': [3, 3, 3, 5, 5, 5, 3, 3, 3, 3, 3, 3, 5, 5, 5],
-        'exp_chans': [16, 64, 72, 72, 120, 120, 240, 200, 184, 184, 480, 672, 672, 960, 960],
-        'strides': [1, 2, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1],
-        'use_squeeze': [False, False, False, True, True, True, False, False, False, False,
-                        True, True, True, True, True],
-        'use_swish': [False, False, False, False, False, False, True, True, True, True, True, True, True, True, True],
+        'input_shape': (224, 224, 3),
         'url': None
     },
     'mobilenet_v3_small': {
-        'input_shape': (512, 512, 3),
-        'out_chans': [16, 24, 24, 40, 40, 40, 48, 48, 96, 96, 96],
-        'kernels': [3, 3, 3, 5, 5, 5, 5, 5, 5, 5, 5],
-        'exp_chans': [16, 72, 88, 96, 240, 240, 120, 144, 288, 576, 576],
-        'strides': [2, 2, 1, 2, 1, 1, 1, 1, 2, 1, 1],
-        'use_squeeze': [True, False, False, True, True, True, True, True, True, True, True],
-        'use_swish': [False, False, False, True, True, True, True, True, True, True, True],
+        'input_shape': (224, 224, 3),
         'url': None
     }
 }
@@ -42,152 +31,204 @@ def hard_swish(x: tf.Tensor) -> tf.Tensor:
     return x * tf.nn.relu6(x + 3.) / 6.0
 
 
-class SqueezeExcitation(layers.Layer):
+def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
+
+
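For intuition on the new `_make_divisible` helper: it snaps a width-scaled channel count to the nearest multiple of `divisor`, then bumps up one step whenever rounding down would lose more than 10% of the requested width. A standalone check with illustrative values (not part of the diff):

    # assumes _make_divisible as defined above; values are illustrative
    assert _make_divisible(16 * 1.0, 8) == 16   # already a multiple of 8
    assert _make_divisible(37.5, 8) == 40       # snapped to the nearest multiple of 8
    assert _make_divisible(10, 8) == 16         # 8 would undershoot 10 by more than 10%, so bump up
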
+class SqueezeExcitation(Sequential):
     """Squeeze and Excitation.
     """
-    def __init__(self, chan: int) -> None:
-        super().__init__()
-        self.squeeze_seq = Sequential(
+    def __init__(self, chan: int, squeeze_factor: int = 4) -> None:
+        super().__init__(
             [
                 layers.GlobalAveragePooling2D(),
-                layers.Dense(chan, activation='relu'),
+                layers.Dense(chan // squeeze_factor, activation='relu'),
                 layers.Dense(chan, activation='hard_sigmoid'),
                 layers.Reshape((1, 1, chan))
             ]
         )
 
-    def call(self, inputs: tf.Tensor) -> tf.Tensor:
-        x = self.squeeze_seq(inputs)
+    def call(self, inputs: tf.Tensor, **kwargs: Any) -> tf.Tensor:
+        x = super().call(inputs, **kwargs)
         x = tf.math.multiply(inputs, x)
         return x
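
The reworked `SqueezeExcitation` now subclasses `Sequential` and squeezes the bottleneck by `squeeze_factor` before gating, matching the torchvision reference. A minimal usage sketch, assuming the class above (not part of the diff):

    import tensorflow as tf

    se = SqueezeExcitation(chan=64)           # class defined above
    x = tf.random.uniform((2, 32, 32, 64))
    y = se(x)                                 # same shape (2, 32, 32, 64), channels rescaled by gates in [0, 1]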


-class InvertedResidual(layers.Layer):
+class InvertedResidualConfig:
+    def __init__(
+        self,
+        input_channels: int,
+        kernel: int,
+        expanded_channels: int,
+        out_channels: int,
+        use_se: bool,
+        activation: str,
+        stride: int,
+        width_mult: float = 1,
+    ) -> None:
+        self.input_channels = self.adjust_channels(input_channels, width_mult)
+        self.kernel = kernel
+        self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)
+        self.out_channels = self.adjust_channels(out_channels, width_mult)
+        self.use_se = use_se
+        self.use_hs = activation == "HS"
+        self.stride = stride
+
+    @staticmethod
+    def adjust_channels(channels: int, width_mult: float):
+        return _make_divisible(channels * width_mult, 8)
+
+
+class InvertedResidual(layers.Layer):
     """InvertedResidual for mobilenet
 
     Args:
-        out_chan: the dimensionality of the output space.
-        kernel: kernel size for depthwise conv
-        exp_chan: expanded channels, used in squeeze and first conv
-        strides: strides in depthwise conv
-        use_squeeze: whether to use the squeeze & sequence module
-        use_swish: activation type, relu6 or hard_swish
-
-    Returns:
-        Output tensor.
+        conf: configuration object for inverted residual
     """
     def __init__(
         self,
-        out_chan: int,
-        kernel: int,
-        exp_chan: int,
-        strides: int,
-        use_squeeze: bool,
-        use_swish: bool,
+        conf: InvertedResidualConfig,
         **kwargs: Any,
     ) -> None:
+        _kwargs = {'input_shape': kwargs.pop('input_shape')} if isinstance(kwargs.get('input_shape'), tuple) else {}
         super().__init__(**kwargs)
-        self.out_chan = out_chan
-        self.strides = strides
 
-        _layers = [*conv_sequence(exp_chan, activation=hard_swish if use_swish else tf.nn.relu6, kernel_size=1)]
-
-        _layers.extend([
-            layers.DepthwiseConv2D(kernel, strides, depth_multiplier=1, padding='same'),
-            layers.BatchNormalization(),
-        ])
-
-        _layers.append(layers.Activation(hard_swish) if use_swish else layers.Activation(tf.nn.relu6))
-
-        if use_squeeze:
-            _layers.append(SqueezeExcitation(exp_chan))
-
-        _layers.extend(
-            [
-                layers.Conv2D(out_chan, 1, strides=(1, 1), padding='same'),
-                layers.BatchNormalization(),
-            ]
-        )
-        self.bottleneck_sequence = Sequential(_layers)
+        act_fn = hard_swish if conf.use_hs else tf.nn.relu
+
+        self.use_res_connect = conf.stride == 1 and conf.input_channels == conf.out_channels
+
+        _layers = []
+        # expand
+        if conf.expanded_channels != conf.input_channels:
+            _layers.extend(conv_sequence(conf.expanded_channels, act_fn, kernel_size=1, bn=True, **_kwargs))
+
+        # depth-wise
+        _layers.extend(conv_sequence(
+            conf.expanded_channels, act_fn, kernel_size=conf.kernel, strides=conf.stride, bn=True,
+            groups=conf.expanded_channels,
+        ))
+
+        if conf.use_se:
+            _layers.append(SqueezeExcitation(conf.expanded_channels))
+
+        # project
+        _layers.extend(conv_sequence(
+            conf.out_channels, None, kernel_size=1, bn=True,
+        ))
+
+        self.block = Sequential(_layers)
 
     def call(
         self,
-        inputs: tf.Tensor
+        inputs: tf.Tensor,
+        **kwargs: Any,
     ) -> tf.Tensor:
 
-        in_chan = inputs.shape[3]
-        use_residual = (self.strides == 1 and in_chan == self.out_chan)
-        x = self.bottleneck_sequence(inputs)
-        if use_residual:
-            x = tf.add(x, inputs)
-
-        return x
+        out = self.block(inputs, **kwargs)
+        if self.use_res_connect:
+            out = tf.add(out, inputs)
+
+        return out
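
To see how one bneck row of the paper maps onto the two classes above, here is a hedged sketch using a row from the small variant defined later in this file (values illustrative, not part of the diff):

    import tensorflow as tf

    conf = InvertedResidualConfig(16, 3, 72, 24, False, "RE", 2)   # C2 row of the small variant
    block = InvertedResidual(conf)
    out = block(tf.random.uniform((1, 112, 112, 16)))
    # stride 2 and input_channels != out_channels, so no residual: out.shape == (1, 56, 56, 24)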


 class MobileNetV3(Sequential):
 
     """Implements MobileNetV3, inspired from both:
     <https://github.com/xiaochus/MobileNetV3/tree/master/model>`_.
     and <https://pytorch.org/vision/stable/_modules/torchvision/models/mobilenetv3.html>`_.
     """
 
     def __init__(
         self,
+        layout: List[InvertedResidualConfig],
         input_shape: Tuple[int, int, int],
-        out_chans: List[int],
-        kernels: List[int],
-        exp_chans: List[int],
-        strides: List[int],
-        use_squeeze: List[bool],
-        use_swish: List[bool],
-        num_classes: Optional[int] = None,
         include_top: bool = False,
+        head_chans: int = 1024,
+        num_classes: int = 1000,
     ) -> None:
 
         _layers = [
-            *conv_sequence(16, strides=2, activation=hard_swish, kernel_size=3, input_shape=input_shape)
+            Sequential(conv_sequence(layout[0].input_channels, hard_swish, True, kernel_size=3, strides=2,
+                                     input_shape=input_shape), name="stem")
         ]
 
-        for i, (out, k, exp, s, use_sq, use_sw) in enumerate(
-            zip(out_chans, kernels, exp_chans, strides, use_squeeze, use_swish)
-        ):
+        for idx, conf in enumerate(layout):
             _layers.append(
-                InvertedResidual(out, k, exp, s, use_sq, use_sw, name=f"inverted_{i}"),
+                InvertedResidual(conf, name=f"inverted_{idx}"),
             )
 
-        _layers.extend(
-            [
-                *conv_sequence(exp_chans[-1], strides=1, activation=hard_swish, kernel_size=1, name="last_conv"),
-            ]
-        )
+        _layers.append(
+            Sequential(
+                conv_sequence(6 * layout[-1].out_channels, hard_swish, True, kernel_size=1),
+                name="final_block"
+            )
+        )
 
         if include_top:
-            _layers.append([
+            _layers.extend([
                 layers.GlobalAveragePooling2D(),
-                layers.Reshape((1, 1, exp_chans[-1])),
-                layers.Conv2D(1280, 1, padding='same'),
-                layers.Activation(hard_swish),
-                layers.Conv2D(num_classes, 1, padding='same', activation='softmax'),
+                layers.Dense(head_chans, activation=hard_swish),
+                layers.Dropout(0.2),
+                layers.Dense(num_classes, activation='softmax'),
             ])
 
         super().__init__(_layers)
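
Since the constructor is now fully layout-driven, any list of `InvertedResidualConfig` rows builds a valid network. A toy two-block sketch (illustrative, not part of the diff):

    import tensorflow as tf

    toy_layout = [
        InvertedResidualConfig(16, 3, 16, 16, True, "RE", 2),
        InvertedResidualConfig(16, 3, 72, 24, False, "RE", 2),
    ]
    net = MobileNetV3(toy_layout, input_shape=(224, 224, 3))
    # stem halves the map, each stride-2 block halves it again; final 1x1 conv has 6 * 24 = 144 channels
    out = net(tf.random.uniform((1, 224, 224, 3)))   # -> (1, 28, 28, 144)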


-def _mobilenet_v3(arch: str, pretrained: bool, input_shape: Optional[Tuple[int, int, int]] = None) -> MobileNetV3:
+def _mobilenet_v3(
+    arch: str,
+    pretrained: bool,
+    input_shape: Optional[Tuple[int, int, int]] = None,
+    **kwargs: Any
+) -> MobileNetV3:
     input_shape = input_shape or default_cfgs[arch]['input_shape']
 
+    # cf. Table 1 & 2 of the paper
+    if arch == "mobilenet_v3_small":
+        inverted_residual_setting = [
+            InvertedResidualConfig(16, 3, 16, 16, True, "RE", 2),  # C1
+            InvertedResidualConfig(16, 3, 72, 24, False, "RE", 2),  # C2
+            InvertedResidualConfig(24, 3, 88, 24, False, "RE", 1),
+            InvertedResidualConfig(24, 5, 96, 40, True, "HS", 2),  # C3
+            InvertedResidualConfig(40, 5, 240, 40, True, "HS", 1),
+            InvertedResidualConfig(40, 5, 240, 40, True, "HS", 1),
+            InvertedResidualConfig(40, 5, 120, 48, True, "HS", 1),
+            InvertedResidualConfig(48, 5, 144, 48, True, "HS", 1),
+            InvertedResidualConfig(48, 5, 288, 96, True, "HS", 2),  # C4
+            InvertedResidualConfig(96, 5, 576, 96, True, "HS", 1),
+            InvertedResidualConfig(96, 5, 576, 96, True, "HS", 1),
+        ]
+        head_chans = 1024
+    else:
+        inverted_residual_setting = [
+            InvertedResidualConfig(16, 3, 16, 16, False, "RE", 1),
+            InvertedResidualConfig(16, 3, 64, 24, False, "RE", 2),  # C1
+            InvertedResidualConfig(24, 3, 72, 24, False, "RE", 1),
+            InvertedResidualConfig(24, 5, 72, 40, True, "RE", 2),  # C2
+            InvertedResidualConfig(40, 5, 120, 40, True, "RE", 1),
+            InvertedResidualConfig(40, 5, 120, 40, True, "RE", 1),
+            InvertedResidualConfig(40, 3, 240, 80, False, "HS", 2),  # C3
+            InvertedResidualConfig(80, 3, 200, 80, False, "HS", 1),
+            InvertedResidualConfig(80, 3, 184, 80, False, "HS", 1),
+            InvertedResidualConfig(80, 3, 184, 80, False, "HS", 1),
+            InvertedResidualConfig(80, 3, 480, 112, True, "HS", 1),
+            InvertedResidualConfig(112, 3, 672, 112, True, "HS", 1),
+            InvertedResidualConfig(112, 5, 672, 160, True, "HS", 2),  # C4
+            InvertedResidualConfig(160, 5, 960, 160, True, "HS", 1),
+            InvertedResidualConfig(160, 5, 960, 160, True, "HS", 1),
+        ]
+        head_chans = 1280
     # Build the model
     model = MobileNetV3(
+        inverted_residual_setting,
         input_shape,
-        default_cfgs[arch]['out_chans'],
-        default_cfgs[arch]['kernels'],
-        default_cfgs[arch]['exp_chans'],
-        default_cfgs[arch]['strides'],
-        default_cfgs[arch]['use_squeeze'],
-        default_cfgs[arch]['use_swish'],
+        head_chans=head_chans,
+        **kwargs,
     )
     # Load pretrained parameters
     if pretrained:
@@ -197,7 +238,7 @@ def _mobilenet_v3(arch: str, pretrained: bool, input_shape: Optional[Tuple[int,
 
 
 def mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
-    """MobileNetV3 architecture as described in
+    """MobileNetV3-Small architecture as described in
     `"Searching for MobileNetV3",
     <https://arxiv.org/pdf/1905.02244.pdf>`_.
 
@@ -214,11 +255,12 @@ def mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
     Returns:
         A mobilenetv3_small model
     """
+
     return _mobilenet_v3('mobilenet_v3_small', pretrained, **kwargs)
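
A quick smoke test of the public constructor, mirroring the updated unit test at the bottom of this PR (shapes worked out from the strides above; not part of the diff):

    import tensorflow as tf

    model = mobilenet_v3_small(input_shape=(512, 512, 3))
    out = model(tf.random.uniform((2, 512, 512, 3)))
    # include_top defaults to False: out.shape == (2, 16, 16, 576), i.e. 512 / 2**5 spatially, 6 * 96 channels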


 def mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
-    """MobileNetV3 architecture as described in
+    """MobileNetV3-Large architecture as described in
     `"Searching for MobileNetV3",
     <https://arxiv.org/pdf/1905.02244.pdf>`_.
@@ -34,7 +34,7 @@
         'mean': (0.798, 0.785, 0.772),
         'std': (0.264, 0.2749, 0.287),
         'backbone': 'mobilenet_v3_small',
-        'fpn_layers': ["inverted_0", "inverted_2", "inverted_7", "last_conv"],
+        'fpn_layers': ["inverted_0", "inverted_2", "inverted_7", "final_block"],
         'fpn_channels': 128,
         'input_shape': (1024, 1024, 3),
         'rotated_bbox': False,
@@ -291,6 +291,7 @@ def _db_mobilenet(arch: str, pretrained: bool, input_shape: Tuple[int, int, int]
     feat_extractor = IntermediateLayerGetter(
         backbones.__dict__[_cfg['backbone']](
             input_shape=_cfg['input_shape'],
+            include_top=False,
         ),
         _cfg['fpn_layers'],
     )
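For context, a hedged sketch of what this feature extractor produces end to end; the import paths are assumptions based on this repository's layout, not part of the diff:

    import tensorflow as tf
    from doctr.models import backbones
    from doctr.models.utils import IntermediateLayerGetter   # assumed import path

    backbone = backbones.mobilenet_v3_small(input_shape=(1024, 1024, 3), include_top=False)
    feat_extractor = IntermediateLayerGetter(
        backbone, ["inverted_0", "inverted_2", "inverted_7", "final_block"],
    )
    feats = feat_extractor(tf.random.uniform((1, 1024, 1024, 3)))   # 4 feature maps, coarsest from "final_block"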
2 changes: 1 addition & 1 deletion doctr/models/utils/tensorflow.py
@@ -115,7 +115,7 @@ def __init__(
         model: Model,
         layer_names: List[str]
     ) -> None:
-        intermediate_fmaps = [model.get_layer(layer_name).output for layer_name in layer_names]
+        intermediate_fmaps = [model.get_layer(layer_name).get_output_at(0) for layer_name in layer_names]
         super().__init__(model.input, outputs=intermediate_fmaps)
 
     def __repr__(self) -> str:
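Background on this one-liner: a Keras layer that has been called more than once owns several output nodes, and the plain `.output` property raises in that case; `get_output_at(0)` pins the first node. A minimal repro of the underlying Keras behavior (illustrative, not part of the diff):

    from tensorflow.keras import layers

    shared = layers.Dense(4, name="shared")
    out_a = shared(layers.Input((8,)))
    out_b = shared(layers.Input((8,)))   # second call: the layer now has two inbound nodes
    first = shared.get_output_at(0)      # unambiguous: output tensor of the first call
    # shared.output                      # would raise AttributeError: layer has multiple inbound nodes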
2 changes: 1 addition & 1 deletion test/tensorflow/test_models_backbones_tf.py
@@ -16,7 +16,7 @@
 def test_classification_architectures(arch_name, top_implemented, input_shape, output_size):
     # Model
     batch_size = 2
-    model = backbones.__dict__[arch_name](pretrained=True)
+    model = backbones.__dict__[arch_name](pretrained=True, input_shape=input_shape)
     # Forward
     out = model(tf.random.uniform(shape=[batch_size, *input_shape], maxval=1, dtype=tf.float32))
     # Output checks