Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 50 additions & 2 deletions timm/models/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from .efficientnet_builder import *
from .feature_hooks import FeatureHooks
from .registry import register_model
from .helpers import load_pretrained
from .helpers import load_pretrained, adapt_model_from_file
from .layers import SelectAdaptivePool2d
from timm.models.layers import create_conv2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
Expand Down Expand Up @@ -131,6 +131,16 @@ def _cfg(url='', **kwargs):
'efficientnet_lite4': _cfg(
url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),

'efficientnet_b1_pruned': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb1_pruned_9ebb3fe6.pth',
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'efficientnet_b2_pruned': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb2_pruned_203f55bc.pth',
input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'efficientnet_b3_pruned': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb3_pruned_5abcc29f.pth',
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),

'tf_efficientnet_b0': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
input_size=(3, 224, 224)),
Expand Down Expand Up @@ -482,9 +492,11 @@ def _create_model(model_kwargs, default_cfg, pretrained=False):
else:
load_strict = True
model_class = EfficientNet

variant = model_kwargs.pop('variant', '')
model = model_class(**model_kwargs)
model.default_cfg = default_cfg
if '_pruned' in variant:
model = adapt_model_from_file(model, variant)
if pretrained:
load_pretrained(
model,
Expand Down Expand Up @@ -730,6 +742,7 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
channel_multiplier=channel_multiplier,
act_layer=Swish,
norm_kwargs=resolve_bn_args(kwargs),
variant=variant,
**kwargs,
)
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
Expand Down Expand Up @@ -1229,6 +1242,41 @@ def efficientnet_lite4(pretrained=False, **kwargs):
return model




@register_model
def efficientnet_b1_pruned(pretrained=False, **kwargs):
""" EfficientNet-B1 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
variant = 'efficientnet_b1_pruned'
model = _gen_efficientnet(
variant, channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
return model


@register_model
def efficientnet_b2_pruned(pretrained=False, **kwargs):
""" EfficientNet-B2 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'efficientnet_b2_pruned', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
return model


@register_model
def efficientnet_b3_pruned(pretrained=False, **kwargs):
""" EfficientNet-B3 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet(
'efficientnet_b3_pruned', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
return model




@register_model
def tf_efficientnet_b0(pretrained=False, **kwargs):
""" EfficientNet-B0. Tensorflow compatible variant """
Expand Down
Loading