Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
torch._C._jit_set_profiling_mode(False)

# transformer models don't support many of the spatial / feature based model functionalities
NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', 'mixer_*']
NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*']
NUM_NON_STD = len(NON_STD_FILTERS)

# exclude models that cause specific test failures
Expand Down
2 changes: 1 addition & 1 deletion timm/models/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .inplace_abn import InplaceAbn
from .linear import Linear
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp
from .mlp import Mlp, GluMlp, GatedMlp
from .norm import GroupNorm
from .norm_act import BatchNormAct2d, GroupNormAct
from .padding import get_padding, get_same_padding, pad_same
Expand Down
34 changes: 32 additions & 2 deletions timm/models/layers/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@ def __init__(self, in_features, hidden_features=None, out_features=None, act_lay
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features * 2)
assert hidden_features % 2 == 0
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.fc2 = nn.Linear(hidden_features // 2, out_features)
self.drop = nn.Dropout(drop)

def forward(self, x):
Expand All @@ -47,3 +48,32 @@ def forward(self, x):
x = self.fc2(x)
x = self.drop(x)
return x


class GatedMlp(nn.Module):
""" MLP as used in gMLP
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
gate_layer=None, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
if gate_layer is not None:
assert hidden_features % 2 == 0
self.gate = gate_layer(hidden_features)
hidden_features = hidden_features // 2 # FIXME base reduction on gate property?
else:
self.gate = nn.Identity()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)

def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.gate(x)
x = self.fc2(x)
x = self.drop(x)
return x
Loading