Skip to content

Commit

Permalink
More MobileNet-v4 fixes
Browse files Browse the repository at this point in the history
* missed final norm after post pooling 1x1 PW head conv
* improve repr of model by flipping a few modules to None when not used, nn.Sequential for MultiQueryAttention query/key/value/output
* allow layer scaling to be enabled/disabled at model variant level, conv variants don't use it
  • Loading branch information
rwightman committed May 24, 2024
1 parent 28d76a9 commit 7fe96e7
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 102 deletions.
85 changes: 35 additions & 50 deletions timm/layers/attention2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def __init__(
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = nn.BatchNorm2d,
use_bias: bool = False,
):
"""Initializer.
Expand All @@ -130,81 +131,74 @@ def __init__(
self.fused_attn = use_fused_attn()
self.drop = attn_drop

self.query = nn.Sequential()
if self.has_query_strides:
# FIXME dilation
self.query_down_pool = create_pool2d(
'avg',
kernel_size=self.query_strides,
padding=padding,
)
self.query_down_norm = norm_layer(dim)
else:
self.query_down_pool = nn.Identity()
self.query_down_norm = nn.Identity()

self.query_proj = create_conv2d(
self.query.add_module('down_pool', create_pool2d(
'avg',
kernel_size=self.query_strides,
padding=padding,
))
self.query.add_module('norm', norm_layer(dim))
self.query.add_module('proj', create_conv2d(
dim,
self.num_heads * self.key_dim,
kernel_size=1,
)
bias=use_bias,
))

self.key = nn.Sequential()
if kv_stride > 1:
self.key_down_conv = create_conv2d(
self.key.add_module('down_conv', create_conv2d(
dim,
dim,
kernel_size=dw_kernel_size,
stride=kv_stride,
dilation=dilation,
padding=padding,
depthwise=True,
)
self.key_down_norm = norm_layer(dim)
else:
self.key_down_conv = nn.Identity()
self.key_down_norm = nn.Identity()

self.key_proj = create_conv2d(
))
self.key.add_module('norm', norm_layer(dim))
self.key.add_module('proj', create_conv2d(
dim,
self.key_dim,
kernel_size=1,
padding=padding,
)
bias=use_bias,
))

self.value = nn.Sequential()
if kv_stride > 1:
self.value_down_conv = create_conv2d(
self.value.add_module('down_conv', create_conv2d(
dim,
dim,
kernel_size=dw_kernel_size,
stride=kv_stride,
dilation=dilation,
padding=padding,
depthwise=True,
)
self.value_down_norm = norm_layer(dim)
else:
self.value_down_conv = nn.Identity()
self.value_down_norm = nn.Identity()

self.value_proj = create_conv2d(
))
self.value.add_module('norm', norm_layer(dim))
self.value.add_module('proj', create_conv2d(
dim,
self.value_dim,
kernel_size=1,
)
bias=use_bias,
))

self.attn_drop = nn.Dropout(attn_drop)

self.output = nn.Sequential()
if self.has_query_strides:
self.upsampling = nn.Upsample(self.query_strides, mode='bilinear', align_corners=False)
else:
self.upsampling = nn.Identity()

self.out_proj = create_conv2d(
self.output.add_module('upsample', nn.Upsample(self.query_strides, mode='bilinear', align_corners=False))
self.output.add_module('proj', create_conv2d(
self.value_dim * self.num_heads,
dim_out,
kernel_size=1,
)
bias=use_bias,
))
self.output.add_module('drop', nn.Dropout(proj_drop))

self.proj_drop = nn.Dropout(proj_drop)
self.einsum = False

def _reshape_input(self, t: torch.Tensor):
Expand Down Expand Up @@ -237,21 +231,15 @@ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
"""Run layer computation."""
B, C, H, W = s = x.shape

q = self.query_down_pool(x)
q = self.query_down_norm(q)
q = self.query_proj(q)
q = self.query(x)
# desired q shape: [b, h, k, n x n] - [b, l, h, k]
q = self._reshape_projected_query(q, self.num_heads, self.key_dim)

k = self.key_down_conv(x)
k = self.key_down_norm(k)
k = self.key_proj(k)
k = self.key(x)
# output shape of k: [b, k, p], p = m x m
k = self._reshape_input(k)

v = self.value_down_conv(x)
v = self.value_down_norm(v)
v = self.value_proj(v)
v = self.value(x)
# output shape of v: [ b, p, k], p = m x m
v = self._reshape_input(v)

Expand Down Expand Up @@ -285,10 +273,7 @@ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):

# reshape o into [b, hk, n, n,]
o = self._reshape_output(o, self.num_heads, H // self.query_strides[0], W // self.query_strides[1])
o = self.upsampling(o)

x = self.out_proj(o)
x = self.proj_drop(x)
x = self.output(o)
return x


Expand Down
65 changes: 28 additions & 37 deletions timm/models/_efficientnet_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,12 @@ def feature_info(self, location):

def forward(self, x):
shortcut = x
#print('ii', x.shape)
#print('ii', x.shape) # FIXME debug s2d
if self.conv_s2d is not None:
x = self.conv_s2d(x)
x = self.bn_s2d(x)
#print('id', x.shape)
#print('id', x.shape) # FIXME debug s2d
x = self.conv_dw(x)
#print('od', x.shape)
x = self.bn1(x)
x = self.se(x)
x = self.conv_pw(x)
Expand Down Expand Up @@ -296,7 +295,8 @@ def forward(self, x):
class UniversalInvertedResidual(nn.Module):
""" Universal Inverted Residual Block
For MobileNetV4 - https://arxiv.org/abs/
For MobileNetV4 - https://arxiv.org/abs/2404.10518, referenced from
https://github.com/tensorflow/models/blob/d93c7e932de27522b2fa3b115f58d06d6f640537/official/vision/modeling/layers/nn_blocks.py#L778
"""

def __init__(
Expand Down Expand Up @@ -338,8 +338,9 @@ def __init__(
)
self.norm_dw_start = dw_norm_act_layer(in_chs, apply_act=False)
else:
self.conv_dw_start = nn.Identity()
self.norm_dw_start = nn.Identity()
# start is None when not used for cleaner repr
self.conv_dw_start = None
self.norm_dw_start = None

# Point-wise expansion
mid_chs = make_divisible(in_chs * exp_ratio)
Expand All @@ -359,6 +360,7 @@ def __init__(
)
self.norm_dw_mid = dw_norm_act_layer(mid_chs, inplace=True)
else:
# keeping mid as identity so it can be hooked more easily for features
self.conv_dw_mid = nn.Identity()
self.norm_dw_mid = nn.Identity()

Expand All @@ -379,7 +381,7 @@ def __init__(
)
self.norm_dw_end = dw_norm_act_layer(out_chs, apply_act=False)
else:
# dw_end rarely used so keeping it out of repr by not using None instead of nn.Identitty()
# end is None when not in use for cleaner repr
self.conv_dw_end = None
self.norm_dw_end = None

Expand All @@ -397,8 +399,9 @@ def feature_info(self, location):

def forward(self, x):
shortcut = x
x = self.conv_dw_start(x)
x = self.norm_dw_start(x)
if self.conv_dw_start is not None:
x = self.conv_dw_start(x)
x = self.norm_dw_start(x)
x = self.conv_pw(x)
x = self.norm_pw(x)
x = self.conv_dw_mid(x)
Expand All @@ -418,7 +421,8 @@ def forward(self, x):
class MobileAttention(nn.Module):
""" Mobile Attention Block
For MobileNetV4 - https://arxiv.org/abs/
For MobileNetV4 - https://arxiv.org/abs/2404.10518, referenced from
https://github.com/tensorflow/models/blob/d93c7e932de27522b2fa3b115f58d06d6f640537/official/vision/modeling/layers/nn_blocks.py#L1504
"""
def __init__(
self,
Expand Down Expand Up @@ -476,34 +480,21 @@ def __init__(
num_heads = in_chs // key_dim

if use_multi_query:
#if self.has_query_stride or self.kv_stride > 1:
self.attn = (
MultiQueryAttention2d(
in_chs,
dim_out=out_chs,
num_heads=num_heads,
key_dim=key_dim,
value_dim=value_dim,
query_strides=query_strides,
kv_stride=kv_stride,
dilation=dilation,
padding=pad_type,
dw_kernel_size=dw_kernel_size,
attn_drop=attn_drop,
proj_drop=proj_drop,
#bias=use_bias, # why not here if used w/ mhsa?
)
self.attn = MultiQueryAttention2d(
in_chs,
dim_out=out_chs,
num_heads=num_heads,
key_dim=key_dim,
value_dim=value_dim,
query_strides=query_strides,
kv_stride=kv_stride,
dilation=dilation,
padding=pad_type,
dw_kernel_size=dw_kernel_size,
attn_drop=attn_drop,
proj_drop=proj_drop,
#bias=use_bias, # why not here if used w/ mhsa?
)
# else:
# self.attn = MultiQueryAttentionV2(
# in_chs,
# dim_out=out_chs,
# num_heads=num_heads,
# key_dim=key_dim,
# value_dim=value_dim,
# attn_drop=attn_drop,
# proj_drop=proj_drop,
# )
else:
self.attn = Attention2d(
in_chs,
Expand Down
27 changes: 15 additions & 12 deletions timm/models/_efficientnet_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Hacked together by / Copyright 2019, Ross Wightman
"""
from typing import Callable, Optional

import logging
import math
Expand Down Expand Up @@ -321,15 +322,16 @@ class EfficientNetBuilder:
"""
def __init__(
self,
output_stride=32,
pad_type='',
round_chs_fn=round_channels,
se_from_exp=False,
act_layer=None,
norm_layer=None,
se_layer=None,
drop_path_rate=0.,
feature_location='',
output_stride: int = 32,
pad_type: str = '',
round_chs_fn: Callable = round_channels,
se_from_exp: bool = False,
act_layer: Optional[Callable] = None,
norm_layer: Optional[Callable] = None,
se_layer: Optional[Callable] = None,
drop_path_rate: float = 0.,
layer_scale_init_value: Optional[float] = None,
feature_location: str = '',
):
self.output_stride = output_stride
self.pad_type = pad_type
Expand All @@ -344,6 +346,7 @@ def __init__(
except TypeError:
self.se_has_ratio = False
self.drop_path_rate = drop_path_rate
self.layer_scale_init_value = layer_scale_init_value
if feature_location == 'depthwise':
# old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense
_logger.warning("feature_location=='depthwise' is deprecated, using 'expansion'")
Expand Down Expand Up @@ -402,13 +405,13 @@ def _make_block(self, ba, block_idx, block_count):
block = ConvBnAct(**ba)
elif bt == 'uir':
_log_info_if(' UniversalInvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
block = UniversalInvertedResidual(**ba)
block = UniversalInvertedResidual(**ba, layer_scale_init_value=self.layer_scale_init_value)
elif bt == 'mqa':
_log_info_if(' MobileMultiQueryAttention {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
block = MobileAttention(**ba, use_multi_query=True)
block = MobileAttention(**ba, use_multi_query=True, layer_scale_init_value=self.layer_scale_init_value)
elif bt == 'mha':
_log_info_if(' MobileMultiHeadAttention {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
block = MobileAttention(**ba)
block = MobileAttention(**ba, layer_scale_init_value=self.layer_scale_init_value)
else:
assert False, 'Unknown block type (%s) while building model.' % bt

Expand Down
Loading

0 comments on commit 7fe96e7

Please sign in to comment.