Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 44 additions & 22 deletions timm/models/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,8 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwar

def _gen_mobilenet_v1(
variant, channel_multiplier=1.0, depth_multiplier=1.0,
fix_stem_head=False, head_conv=False, pretrained=False, **kwargs):
group_size=None, fix_stem_head=False, head_conv=False, pretrained=False, **kwargs
):
"""
Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
Paper: https://arxiv.org/abs/1801.04381
Expand All @@ -503,7 +504,12 @@ def _gen_mobilenet_v1(
round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
head_features = (1024 if fix_stem_head else max(1024, round_chs_fn(1024))) if head_conv else 0
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head),
block_args=decode_arch_def(
arch_def,
depth_multiplier=depth_multiplier,
fix_first_last=fix_stem_head,
group_size=group_size,
),
num_features=head_features,
stem_size=32,
fix_stem=fix_stem_head,
Expand All @@ -517,7 +523,9 @@ def _gen_mobilenet_v1(


def _gen_mobilenet_v2(
variant, channel_multiplier=1.0, depth_multiplier=1.0, fix_stem_head=False, pretrained=False, **kwargs):
variant, channel_multiplier=1.0, depth_multiplier=1.0,
group_size=None, fix_stem_head=False, pretrained=False, **kwargs
):
""" Generate MobileNet-V2 network
Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
Paper: https://arxiv.org/abs/1801.04381
Expand All @@ -533,7 +541,12 @@ def _gen_mobilenet_v2(
]
round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head),
block_args=decode_arch_def(
arch_def,
depth_multiplier=depth_multiplier,
fix_first_last=fix_stem_head,
group_size=group_size,
),
num_features=1280 if fix_stem_head else max(1280, round_chs_fn(1280)),
stem_size=32,
fix_stem=fix_stem_head,
Expand Down Expand Up @@ -613,7 +626,8 @@ def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):

def _gen_efficientnet(
variant, channel_multiplier=1.0, depth_multiplier=1.0, channel_divisor=8,
group_size=None, pretrained=False, **kwargs):
group_size=None, pretrained=False, **kwargs
):
"""Creates an EfficientNet model.

Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
Expand Down Expand Up @@ -661,7 +675,8 @@ def _gen_efficientnet(


def _gen_efficientnet_edge(
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs):
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs
):
""" Creates an EfficientNet-EdgeTPU model

Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/edgetpu
Expand Down Expand Up @@ -692,7 +707,8 @@ def _gen_efficientnet_edge(


def _gen_efficientnet_condconv(
variant, channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=1, pretrained=False, **kwargs):
variant, channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=1, pretrained=False, **kwargs
):
"""Creates an EfficientNet-CondConv model.

Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv
Expand Down Expand Up @@ -764,7 +780,8 @@ def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0


def _gen_efficientnetv2_base(
variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs
):
""" Creates an EfficientNet-V2 base model

Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
Expand All @@ -780,7 +797,7 @@ def _gen_efficientnetv2_base(
]
round_chs_fn = partial(round_channels, multiplier=channel_multiplier, round_limit=0.)
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier),
block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
num_features=round_chs_fn(1280),
stem_size=32,
round_chs_fn=round_chs_fn,
Expand All @@ -793,7 +810,8 @@ def _gen_efficientnetv2_base(


def _gen_efficientnetv2_s(
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, rw=False, pretrained=False, **kwargs):
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, rw=False, pretrained=False, **kwargs
):
""" Creates an EfficientNet-V2 Small model

Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
Expand Down Expand Up @@ -831,7 +849,9 @@ def _gen_efficientnetv2_s(
return model


def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
def _gen_efficientnetv2_m(
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs
):
""" Creates an EfficientNet-V2 Medium model

Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
Expand All @@ -849,7 +869,7 @@ def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0,
]

model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier),
block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
num_features=1280,
stem_size=24,
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
Expand All @@ -861,7 +881,9 @@ def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0,
return model


def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
def _gen_efficientnetv2_l(
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs
):
""" Creates an EfficientNet-V2 Large model

Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
Expand All @@ -879,7 +901,7 @@ def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0,
]

model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier),
block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
num_features=1280,
stem_size=32,
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
Expand All @@ -891,7 +913,9 @@ def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0,
return model


def _gen_efficientnetv2_xl(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
def _gen_efficientnetv2_xl(
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs
):
""" Creates an EfficientNet-V2 Xtra-Large model

Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
Expand All @@ -909,7 +933,7 @@ def _gen_efficientnetv2_xl(variant, channel_multiplier=1.0, depth_multiplier=1.0
]

model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier),
block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
num_features=1280,
stem_size=32,
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
Expand All @@ -923,7 +947,8 @@ def _gen_efficientnetv2_xl(variant, channel_multiplier=1.0, depth_multiplier=1.0

def _gen_efficientnet_x(
variant, channel_multiplier=1.0, depth_multiplier=1.0, channel_divisor=8,
group_size=None, version=1, pretrained=False, **kwargs):
group_size=None, version=1, pretrained=False, **kwargs
):
"""Creates an EfficientNet model.

Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
Expand Down Expand Up @@ -1069,9 +1094,7 @@ def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrai
return model


def _gen_tinynet(
variant, model_width=1.0, depth_multiplier=1.0, pretrained=False, **kwargs
):
def _gen_tinynet(variant, model_width=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
"""Creates a TinyNet model.
"""
arch_def = [
Expand Down Expand Up @@ -1183,8 +1206,7 @@ def _arch_def(chs: List[int], group_size: int):
return model


def _gen_test_efficientnet(
variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
def _gen_test_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
""" Minimal test EfficientNet generator.
"""
arch_def = [
Expand Down
16 changes: 11 additions & 5 deletions timm/models/mobilenetv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,9 @@ def _create_mnv3(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV
return model


def _gen_mobilenet_v3_rw(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
def _gen_mobilenet_v3_rw(
variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs
) -> MobileNetV3:
"""Creates a MobileNet-V3 model.

Ref impl: ?
Expand Down Expand Up @@ -450,7 +452,9 @@ def _gen_mobilenet_v3_rw(variant: str, channel_multiplier: float = 1.0, pretrain
return model


def _gen_mobilenet_v3(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
def _gen_mobilenet_v3(
variant: str, channel_multiplier: float = 1.0, group_size=None, pretrained: bool = False, **kwargs
) -> MobileNetV3:
"""Creates a MobileNet-V3 model.

Ref impl: ?
Expand Down Expand Up @@ -533,7 +537,7 @@ def _gen_mobilenet_v3(variant: str, channel_multiplier: float = 1.0, pretrained:
]
se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels)
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
block_args=decode_arch_def(arch_def, group_size=group_size),
num_features=num_features,
stem_size=16,
fix_stem=channel_multiplier < 0.75,
Expand Down Expand Up @@ -646,7 +650,9 @@ def _gen_lcnet(variant: str, channel_multiplier: float = 1.0, pretrained: bool =
return model


def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
def _gen_mobilenet_v4(
variant: str, channel_multiplier: float = 1.0, group_size=None, pretrained: bool = False, **kwargs,
) -> MobileNetV3:
"""Creates a MobileNet-V4 model.

Ref impl: ?
Expand Down Expand Up @@ -877,7 +883,7 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained:
assert False, f'Unknown variant {variant}.'

model_kwargs = dict(
block_args=decode_arch_def(arch_def),
block_args=decode_arch_def(arch_def, group_size=group_size),
head_bias=False,
head_norm=True,
num_features=num_features,
Expand Down