Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor

## What's New

### May 14, 2021
* Add EfficientNet-V2 official model defs w/ ported weights from official [Tensorflow/Keras](https://github.com/google/automl/tree/master/efficientnetv2) impl.
* 1k trained variants: `tf_efficientnetv2_s/m/l`
* 21k trained variants: `tf_efficientnetv2_s/m/l_21k`
* 21k pretrained -> 1k fine-tuned: `tf_efficientnetv2_s/m/l_21ft1k`
* v2 models w/ v1 scaling: `tf_efficientnet_v2_b0` through `b3`
* Rename my prev V2 guess `efficientnet_v2s` -> `efficientnetv2_rw_s`
* Some blank `efficientnetv2_*` models in-place for future native PyTorch training

### May 5, 2021
* Add MLP-Mixer models and port pretrained weights from [Google JAX impl](https://github.com/google-research/vision_transformer/tree/linen)
* Add CaiT models and pretrained weights from [FB](https://github.com/facebookresearch/deit)
Expand Down
475 changes: 397 additions & 78 deletions timm/models/efficientnet.py

Large diffs are not rendered by default.

232 changes: 74 additions & 158 deletions timm/models/efficientnet_blocks.py

Large diffs are not rendered by default.

110 changes: 72 additions & 38 deletions timm/models/efficientnet_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,55 @@
import torch.nn as nn

from .efficientnet_blocks import *
from .layers import CondConv2d, get_condconv_initializer
from .layers import CondConv2d, get_condconv_initializer, get_act_layer, make_divisible

__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights"]
__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights",
'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT']

_logger = logging.getLogger(__name__)


_DEBUG_BUILDER = False

# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
# NOTE: momentum varies btw .99 and .9997 depending on source
# .99 in official TF TPU impl
# .9997 (/w .999 in search space) for paper
BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
BN_EPS_TF_DEFAULT = 1e-3
_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)


def get_bn_args_tf():
return _BN_ARGS_TF.copy()


def resolve_bn_args(kwargs):
bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {}
bn_momentum = kwargs.pop('bn_momentum', None)
if bn_momentum is not None:
bn_args['momentum'] = bn_momentum
bn_eps = kwargs.pop('bn_eps', None)
if bn_eps is not None:
bn_args['eps'] = bn_eps
return bn_args


def resolve_act_layer(kwargs, default='relu'):
act_layer = kwargs.pop('act_layer', default)
if isinstance(act_layer, str):
act_layer = get_act_layer(act_layer)
return act_layer


def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None, round_limit=0.9):
"""Round number of filters based on depth multiplier."""
if not multiplier:
return channels
return make_divisible(channels * multiplier, divisor, channel_min, round_limit=round_limit)


def _log_info_if(msg, condition):
if condition:
_logger.info(msg)
Expand Down Expand Up @@ -63,11 +105,13 @@ def _decode_block_str(block_str):
block_type = ops[0] # take the block type off the front
ops = ops[1:]
options = {}
noskip = False
skip = None
for op in ops:
# string options being checked on individual basis, combine if they grow
if op == 'noskip':
noskip = True
skip = False # force no skip connection
elif op == 'skip':
skip = True # force a skip connection
elif op.startswith('n'):
# activation fn
key = op[0]
Expand All @@ -94,7 +138,7 @@ def _decode_block_str(block_str):
act_layer = options['n'] if 'n' in options else None
exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
fake_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def
force_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def

num_repeat = int(options['r'])
# each type of block has different valid arguments, fill accordingly
Expand All @@ -106,10 +150,10 @@ def _decode_block_str(block_str):
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
exp_ratio=float(options['e']),
se_ratio=float(options['se']) if 'se' in options else None,
se_ratio=float(options['se']) if 'se' in options else 0.,
stride=int(options['s']),
act_layer=act_layer,
noskip=noskip,
noskip=skip is False,
)
if 'cc' in options:
block_args['num_experts'] = int(options['cc'])
Expand All @@ -119,11 +163,11 @@ def _decode_block_str(block_str):
dw_kernel_size=_parse_ksize(options['k']),
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
se_ratio=float(options['se']) if 'se' in options else None,
se_ratio=float(options['se']) if 'se' in options else 0.,
stride=int(options['s']),
act_layer=act_layer,
pw_act=block_type == 'dsa',
noskip=block_type == 'dsa' or noskip,
noskip=block_type == 'dsa' or skip is False,
)
elif block_type == 'er':
block_args = dict(
Expand All @@ -132,11 +176,11 @@ def _decode_block_str(block_str):
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
exp_ratio=float(options['e']),
fake_in_chs=fake_in_chs,
se_ratio=float(options['se']) if 'se' in options else None,
force_in_chs=force_in_chs,
se_ratio=float(options['se']) if 'se' in options else 0.,
stride=int(options['s']),
act_layer=act_layer,
noskip=noskip,
noskip=skip is False,
)
elif block_type == 'cn':
block_args = dict(
Expand All @@ -145,6 +189,7 @@ def _decode_block_str(block_str):
out_chs=int(options['c']),
stride=int(options['s']),
act_layer=act_layer,
skip=skip is True,
)
else:
assert False, 'Unknown block type (%s)' % block_type
Expand Down Expand Up @@ -219,74 +264,63 @@ class EfficientNetBuilder:
https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py

"""
def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
output_stride=32, pad_type='', act_layer=None, se_kwargs=None,
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0., feature_location='',
verbose=False):
self.channel_multiplier = channel_multiplier
self.channel_divisor = channel_divisor
self.channel_min = channel_min
def __init__(self, output_stride=32, pad_type='', round_chs_fn=round_channels,
act_layer=None, norm_layer=None, se_layer=None, drop_path_rate=0., feature_location=''):
self.output_stride = output_stride
self.pad_type = pad_type
self.round_chs_fn = round_chs_fn
self.act_layer = act_layer
self.se_kwargs = se_kwargs
self.norm_layer = norm_layer
self.norm_kwargs = norm_kwargs
self.se_layer = se_layer
self.drop_path_rate = drop_path_rate
if feature_location == 'depthwise':
# old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense
_logger.warning("feature_location=='depthwise' is deprecated, using 'expansion'")
feature_location = 'expansion'
self.feature_location = feature_location
assert feature_location in ('bottleneck', 'expansion', '')
self.verbose = verbose
self.verbose = _DEBUG_BUILDER

# state updated during build, consumed by model
self.in_chs = None
self.features = []

def _round_channels(self, chs):
return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)

def _make_block(self, ba, block_idx, block_count):
drop_path_rate = self.drop_path_rate * block_idx / block_count
bt = ba.pop('block_type')
ba['in_chs'] = self.in_chs
ba['out_chs'] = self._round_channels(ba['out_chs'])
if 'fake_in_chs' in ba and ba['fake_in_chs']:
# FIXME this is a hack to work around mismatch in origin impl input filters
ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
ba['norm_layer'] = self.norm_layer
ba['norm_kwargs'] = self.norm_kwargs
ba['out_chs'] = self.round_chs_fn(ba['out_chs'])
if 'force_in_chs' in ba and ba['force_in_chs']:
# NOTE this is a hack to work around mismatch in TF EdgeEffNet impl
ba['force_in_chs'] = self.round_chs_fn(ba['force_in_chs'])
ba['pad_type'] = self.pad_type
# block act fn overrides the model default
ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
assert ba['act_layer'] is not None
if bt == 'ir':
ba['norm_layer'] = self.norm_layer
if bt != 'cn':
ba['se_layer'] = self.se_layer
ba['drop_path_rate'] = drop_path_rate
ba['se_kwargs'] = self.se_kwargs

if bt == 'ir':
_log_info_if(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
if ba.get('num_experts', 0) > 0:
block = CondConvResidual(**ba)
else:
block = InvertedResidual(**ba)
elif bt == 'ds' or bt == 'dsa':
ba['drop_path_rate'] = drop_path_rate
ba['se_kwargs'] = self.se_kwargs
_log_info_if(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
block = DepthwiseSeparableConv(**ba)
elif bt == 'er':
ba['drop_path_rate'] = drop_path_rate
ba['se_kwargs'] = self.se_kwargs
_log_info_if(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
block = EdgeResidual(**ba)
elif bt == 'cn':
_log_info_if(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
block = ConvBnAct(**ba)
else:
assert False, 'Uknkown block type (%s) while building model.' % bt
self.in_chs = ba['out_chs'] # update in_chs for arg of next block

self.in_chs = ba['out_chs'] # update in_chs for arg of next block
return block

def __call__(self, in_chs, model_block_args):
Expand Down
5 changes: 2 additions & 3 deletions timm/models/ghostnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@


from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .layers import SelectAdaptivePool2d, Linear, hard_sigmoid
from .efficientnet_blocks import SqueezeExcite, ConvBnAct, make_divisible
from .layers import SelectAdaptivePool2d, Linear, hard_sigmoid, make_divisible
from .efficientnet_blocks import SqueezeExcite, ConvBnAct
from .helpers import build_model_with_cfg
from .registry import register_model

Expand Down Expand Up @@ -110,7 +110,6 @@ def __init__(self, in_chs, mid_chs, out_chs, dw_kernel_size=3,
nn.BatchNorm2d(out_chs),
)


def forward(self, x):
shortcut = x

Expand Down
22 changes: 13 additions & 9 deletions timm/models/hardcorenas.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from functools import partial

import torch.nn as nn
from .efficientnet_builder import decode_arch_def, resolve_bn_args
from .mobilenetv3 import MobileNetV3, MobileNetV3Features, build_model_with_cfg, default_cfg_for_features
from .layers import hard_sigmoid
from .efficientnet_blocks import resolve_act_layer
from .registry import register_model

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .efficientnet_blocks import SqueezeExcite
from .efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args
from .helpers import build_model_with_cfg, default_cfg_for_features
from .layers import get_act_fn
from .mobilenetv3 import MobileNetV3, MobileNetV3Features
from .registry import register_model


def _cfg(url='', **kwargs):
Expand Down Expand Up @@ -35,15 +39,15 @@ def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs):

"""
num_features = 1280

se_layer = partial(
SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), force_act_layer=nn.ReLU, reduce_from_block=False, divisor=8)
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
num_features=num_features,
stem_size=32,
channel_multiplier=1,
norm_kwargs=resolve_bn_args(kwargs),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
act_layer=resolve_act_layer(kwargs, 'hard_swish'),
se_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisor=8),
se_layer=se_layer,
**kwargs,
)

Expand Down
6 changes: 3 additions & 3 deletions timm/models/layers/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ def parse(x):
to_ntuple = _ntuple


def make_divisible(v, divisor=8, min_value=None):
def make_divisible(v, divisor=8, min_value=None, round_limit=.9):
min_value = min_value or divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
if new_v < round_limit * v:
new_v += divisor
return new_v
return new_v
Loading