Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ For a few months now, `timm` has been part of the Hugging Face ecosystem. Yearly
If you have a couple of minutes and want to participate in shaping the future of the ecosystem, please share your thoughts:
[**hf.co/oss-survey**](https://hf.co/oss-survey) 🙏

### Jan 5, 2023
* ConvNeXt-V2 models and weights added to existing `convnext.py`
* Paper: [ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](http://arxiv.org/abs/2301.00808)
* Reference impl: https://github.com/facebookresearch/ConvNeXt-V2 (NOTE: weights currently CC-BY-NC)

### Dec 23, 2022 🎄☃
* Add FlexiViT models and weights from https://github.com/google-research/big_vision (check out paper at https://arxiv.org/abs/2212.08013)
* NOTE currently resizing is static on model creation, on-the-fly dynamic / train patch size sampling is a WIP
Expand Down Expand Up @@ -396,6 +401,7 @@ A full version of the list below with source links can be found in the [document
* CoaT (Co-Scale Conv-Attentional Image Transformers) - https://arxiv.org/abs/2104.06399
* CoAtNet (Convolution and Attention) - https://arxiv.org/abs/2106.04803
* ConvNeXt - https://arxiv.org/abs/2201.03545
* ConvNeXt-V2 - http://arxiv.org/abs/2301.00808
* ConViT (Soft Convolutional Inductive Biases Vision Transformers)- https://arxiv.org/abs/2103.10697
* CspNet (Cross-Stage Partial Networks) - https://arxiv.org/abs/1911.11929
* DeiT - https://arxiv.org/abs/2012.12877
Expand All @@ -418,6 +424,7 @@ A full version of the list below with source links can be found in the [document
* Single-Path NAS - https://arxiv.org/abs/1904.02877
* TinyNet - https://arxiv.org/abs/2010.14819
* EVA - https://arxiv.org/abs/2211.07636
* FlexiViT - https://arxiv.org/abs/2212.08013
* GCViT (Global Context Vision Transformer) - https://arxiv.org/abs/2206.09959
* GhostNet - https://arxiv.org/abs/1911.11907
* gMLP - https://arxiv.org/abs/2105.08050
Expand Down
8 changes: 4 additions & 4 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
'*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*',
'swin*giant*']
'swin*giant*', 'convnextv2_huge*']
NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*', 'eva_giant*']
else:
EXCLUDE_FILTERS = []
Expand Down Expand Up @@ -129,7 +129,7 @@ def test_model_backward(model_name, batch_size):


@pytest.mark.timeout(300)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=NON_STD_FILTERS))
@pytest.mark.parametrize('model_name', list_models(exclude_filters=NON_STD_FILTERS, include_tags=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_default_cfgs(model_name, batch_size):
"""Run a single forward pass with each model"""
Expand Down Expand Up @@ -191,7 +191,7 @@ def test_model_default_cfgs(model_name, batch_size):


@pytest.mark.timeout(300)
@pytest.mark.parametrize('model_name', list_models(filter=NON_STD_FILTERS, exclude_filters=NON_STD_EXCLUDE_FILTERS))
@pytest.mark.parametrize('model_name', list_models(filter=NON_STD_FILTERS, exclude_filters=NON_STD_EXCLUDE_FILTERS, include_tags=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_default_cfgs_non_std(model_name, batch_size):
"""Run a single forward pass with each model"""
Expand Down Expand Up @@ -304,7 +304,7 @@ def test_model_forward_torchscript(model_name, batch_size):


@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FEAT_FILTERS))
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FEAT_FILTERS, include_tags=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_features(model_name, batch_size):
"""Run a single forward pass with each model in feature extraction mode"""
Expand Down
2 changes: 1 addition & 1 deletion timm/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from .inplace_abn import InplaceAbn
from .linear import Linear
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp, GlobalResponseNormMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
Expand Down
39 changes: 39 additions & 0 deletions timm/layers/grn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
""" Global Response Normalization Module

Based on the GRN layer presented in
`ConvNeXt-V2 - Co-designing and Scaling ConvNets with Masked Autoencoders` - https://arxiv.org/abs/2301.00808

This implementation
* works for both NCHW and NHWC tensor layouts
* uses affine param names matching existing torch norm layers
* slightly improves eager mode performance via fused addcmul

Hacked together by / Copyright 2023 Ross Wightman
"""

import torch
from torch import nn as nn


class GlobalResponseNorm(nn.Module):
""" Global Response Normalization layer
"""
def __init__(self, dim, eps=1e-6, channels_last=True):
super().__init__()
self.eps = eps
if channels_last:
self.spatial_dim = (1, 2)
self.channel_dim = -1
self.wb_shape = (1, 1, 1, -1)
else:
self.spatial_dim = (2, 3)
self.channel_dim = 1
self.wb_shape = (1, -1, 1, 1)

self.weight = nn.Parameter(torch.zeros(dim))
self.bias = nn.Parameter(torch.zeros(dim))

def forward(self, x):
x_g = x.norm(p=2, dim=self.spatial_dim, keepdim=True)
x_n = x_g / (x_g.mean(dim=self.channel_dim, keepdim=True) + self.eps)
return x + torch.addcmul(self.bias.view(self.wb_shape), self.weight.view(self.wb_shape), x * x_n)
97 changes: 86 additions & 11 deletions timm/layers/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,38 @@

Hacked together by / Copyright 2020 Ross Wightman
"""
from functools import partial

from torch import nn as nn

from .grn import GlobalResponseNorm
from .helpers import to_2tuple


class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
bias=True,
drop=0.,
use_conv=False,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])

def forward(self, x):
Expand All @@ -36,18 +49,29 @@ class GluMlp(nn.Module):
""" MLP w/ GLU style gating
See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, bias=True, drop=0.):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.Sigmoid,
bias=True,
drop=0.,
use_conv=False,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
assert hidden_features % 2 == 0
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
self.chunk_dim = 1 if use_conv else -1

self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.fc2 = nn.Linear(hidden_features // 2, out_features, bias=bias[1])
self.fc2 = linear_layer(hidden_features // 2, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])

def init_weights(self):
Expand All @@ -58,7 +82,7 @@ def init_weights(self):

def forward(self, x):
x = self.fc1(x)
x, gates = x.chunk(2, dim=-1)
x, gates = x.chunk(2, dim=self.chunk_dim)
x = x * self.act(gates)
x = self.drop1(x)
x = self.fc2(x)
Expand All @@ -70,8 +94,15 @@ class GatedMlp(nn.Module):
""" MLP as used in gMLP
"""
def __init__(
self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
gate_layer=None, bias=True, drop=0.):
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
gate_layer=None,
bias=True,
drop=0.,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
Expand Down Expand Up @@ -104,8 +135,15 @@ class ConvMlp(nn.Module):
""" MLP using 1x1 convs that keeps spatial dims
"""
def __init__(
self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU,
norm_layer=None, bias=True, drop=0.):
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.ReLU,
norm_layer=None,
bias=True,
drop=0.,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
Expand All @@ -124,3 +162,40 @@ def forward(self, x):
x = self.drop(x)
x = self.fc2(x)
return x


class GlobalResponseNormMlp(nn.Module):
""" MLP w/ Global Response Norm (see grn.py), nn.Linear or 1x1 Conv2d
"""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
bias=True,
drop=0.,
use_conv=False,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.grn = GlobalResponseNorm(hidden_features, channels_last=not use_conv)
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])

def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.grn(x)
x = self.fc2(x)
x = self.drop2(x)
return x
Loading