From 7d5698e18e7da11874d6e05f22ae987e9622cd14 Mon Sep 17 00:00:00 2001
From: Thien Tran
Date: Sun, 20 Aug 2023 13:36:38 +0800
Subject: [PATCH] Add CaiT (#19)

---
 README.md                            |  11 +-
 tests/test_cait.py                   |  27 +++
 tests/test_vit.py                    |   9 +-
 vision_toolbox/backbones/__init__.py |   1 +
 vision_toolbox/backbones/cait.py     | 243 +++++++++++++++++++++++++++
 vision_toolbox/backbones/convnext.py |  23 ++-
 vision_toolbox/backbones/swin.py     |  60 ++++---
 vision_toolbox/backbones/vit.py      | 104 +++++++-----
 vision_toolbox/components.py         |  12 ++
 9 files changed, 411 insertions(+), 79 deletions(-)
 create mode 100644 tests/test_cait.py
 create mode 100644 vision_toolbox/backbones/cait.py

diff --git a/README.md b/README.md
index ac45b43..b451d86 100644
--- a/README.md
+++ b/README.md
@@ -36,13 +36,20 @@ For object detection, usually only the last 4 feature maps are used. It is the r
 outputs = model.get_feature_maps(inputs)[-4:]  # last 4 feature maps
 ```

-## Backbones
+## Backbones with ported weights
+
+- ViT
+- MLP-Mixer
+- CaiT
+- Swin
+- ConvNeXt and ConvNeXt-V2
+
+## Backbones trained by me

 Implemented backbones:

 - [Darknet](#darknet)
 - [VoVNet](#vovnet)
-- [PatchConvNet](#patchconvnet)

 ### ImageNet pre-training

diff --git a/tests/test_cait.py b/tests/test_cait.py
new file mode 100644
index 0000000..41fca49
--- /dev/null
+++ b/tests/test_cait.py
@@ -0,0 +1,27 @@
+import timm
+import torch
+
+from vision_toolbox.backbones import CaiT
+
+
+def test_forward():
+    m = CaiT.from_config("xxs_24", 224)
+    m(torch.randn(1, 3, 224, 224))
+
+
+def test_resize_pe():
+    m = CaiT.from_config("xxs_24", 224)
+    m(torch.randn(1, 3, 224, 224))
+    m.resize_pe(256)
+    m(torch.randn(1, 3, 256, 256))
+
+
+def test_from_pretrained():
+    m = CaiT.from_config("xxs_24", 224, True).eval()
+    x = torch.randn(1, 3, 224, 224)
+    out = m(x)
+
+    m_timm = timm.create_model("cait_xxs24_224.fb_dist_in1k", pretrained=True, num_classes=0).eval()
+    out_timm = m_timm(x)
+
+    torch.testing.assert_close(out, out_timm, rtol=2e-5, atol=2e-5)
diff --git a/tests/test_vit.py b/tests/test_vit.py
index add5ae8..207bfc0 100644
--- a/tests/test_vit.py
+++ b/tests/test_vit.py
@@ -4,15 +4,20 @@
 from vision_toolbox.backbones import ViT


+def test_forward():
+    m = ViT.from_config("Ti_16", 224)
+    m(torch.randn(1, 3, 224, 224))
+
+
 def test_resize_pe():
-    m = ViT.from_config("Ti", 16, 224)
+    m = ViT.from_config("Ti_16", 224)
     m(torch.randn(1, 3, 224, 224))
     m.resize_pe(256)
     m(torch.randn(1, 3, 256, 256))


 def test_from_pretrained():
-    m = ViT.from_config("Ti", 16, 224, True).eval()
+    m = ViT.from_config("Ti_16", 224, True).eval()
     x = torch.randn(1, 3, 224, 224)
     out = m(x)

diff --git a/vision_toolbox/backbones/__init__.py b/vision_toolbox/backbones/__init__.py
index a724f7c..eeae4b1 100644
--- a/vision_toolbox/backbones/__init__.py
+++ b/vision_toolbox/backbones/__init__.py
@@ -1,3 +1,4 @@
+from .cait import CaiT
 from .convnext import ConvNeXt
 from .darknet import Darknet, DarknetYOLOv5
 from .mlp_mixer import MLPMixer
diff --git a/vision_toolbox/backbones/cait.py b/vision_toolbox/backbones/cait.py
new file mode 100644
index 0000000..88346a0
--- /dev/null
+++ b/vision_toolbox/backbones/cait.py
@@ -0,0 +1,243 @@
+# https://arxiv.org/abs/2103.17239
+# https://github.com/facebookresearch/deit
+
+from __future__ import annotations
+
+from functools import partial
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+from .base import _act, _norm
+from .vit import MHA, ViTBlock
+
+
+# basically attention pooling
+class ClassAttention(MHA):
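+    # Only the class token (index 0) forms the query; all tokens act as keys/values, so the output is a single token.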
+    def forward(self, x: Tensor) -> Tensor:
+        q = self.q_proj(x[:, 0]).unflatten(-1, (self.n_heads, -1)).unsqueeze(2)  # (B, n_heads, 1, head_dim)
+        k = self.k_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)  # (B, n_heads, L, head_dim)
+        v = self.v_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)
+
+        if hasattr(F, "scaled_dot_product_attention"):
+            out = F.scaled_dot_product_attention(q, k, v, None, self.dropout if self.training else 0.0)
+        else:
+            attn = (q * self.scale) @ k.transpose(-1, -2)
+            out = F.dropout(torch.softmax(attn, -1), self.dropout, self.training) @ v
+
+        return self.out_proj(out.transpose(1, 2).flatten(2))  # (B, n_heads, 1, head_dim) -> (B, 1, n_heads * head_dim)
+
+
+# does not support flash attention
+class TalkingHeadAttention(MHA):
+    def __init__(self, d_model: int, n_heads: int, bias: bool = True, dropout: float = 0.0) -> None:
+        super().__init__(d_model, n_heads, bias, dropout)
+        self.talking_head_proj = nn.Sequential(
+            nn.Conv2d(n_heads, n_heads, 1),  # impl as 1x1 conv to avoid permuting data
+            nn.Softmax(-1),
+            nn.Conv2d(n_heads, n_heads, 1),
+            nn.Dropout(dropout),
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        q = self.q_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)  # (B, n_heads, L, head_dim)
+        k = self.k_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)
+        v = self.v_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)
+
+        attn = q @ (k * self.scale).transpose(-1, -2)
+        out = self.talking_head_proj(attn) @ v
+        out = out.transpose(-2, -3).flatten(-2)
+        out = self.out_proj(out)
+        return out
+
+
+class CaiTCABlock(ViTBlock):
+    def __init__(
+        self,
+        d_model: int,
+        n_heads: int,
+        bias: bool = True,
+        mlp_ratio: float = 4.0,
+        dropout: float = 0.0,
+        layer_scale_init: float | None = 1e-6,
+        stochastic_depth: float = 0.0,
+        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
+        act: _act = nn.GELU,
+    ) -> None:
+        # fmt: off
+        super().__init__(
+            d_model, n_heads, bias, mlp_ratio, dropout,
+            layer_scale_init, stochastic_depth, norm, act,
+            partial(ClassAttention, d_model, n_heads, bias, dropout),
+        )
+        # fmt: on
+
+    def forward(self, x: Tensor, cls_token: Tensor) -> Tensor:
+        cls_token = cls_token + self.mha(torch.cat((cls_token, x), 1))
+        cls_token = cls_token + self.mlp(cls_token)
+        return cls_token
+
+
+class CaiTSABlock(ViTBlock):
+    def __init__(
+        self,
+        d_model: int,
+        n_heads: int,
+        bias: bool = True,
+        mlp_ratio: float = 4.0,
+        dropout: float = 0.0,
+        layer_scale_init: float | None = 1e-6,
+        stochastic_depth: float = 0.0,
+        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
+        act: _act = nn.GELU,
+    ) -> None:
+        # fmt: off
+        super().__init__(
+            d_model, n_heads, bias, mlp_ratio, dropout,
+            layer_scale_init, stochastic_depth, norm, act,
+            partial(TalkingHeadAttention, d_model, n_heads, bias, dropout),
+        )
+        # fmt: on
+
+
+class CaiT(nn.Module):
+    def __init__(
+        self,
+        d_model: int,
+        sa_depth: int,
+        ca_depth: int,
+        n_heads: int,
+        patch_size: int,
+        img_size: int,
+        bias: bool = True,
+        mlp_ratio: float = 4.0,
+        dropout: float = 0.0,
+        layer_scale_init: float | None = 1e-6,
+        stochastic_depth: float = 0.0,
+        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
+        act: _act = nn.GELU,
+    ) -> None:
+        assert img_size % patch_size == 0
+        super().__init__()
+        self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size)
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
+        self.pe = nn.Parameter(torch.empty(1, (img_size // patch_size) ** 2, d_model))
+        nn.init.normal_(self.pe, 0, 0.02)
+
+        self.sa_layers = nn.Sequential()
+        for _ in range(sa_depth):
+            block = CaiTSABlock(
+                d_model, n_heads, bias, mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm, act
+            )
+            self.sa_layers.append(block)
+
+        self.ca_layers = nn.ModuleList()
+        for _ in range(ca_depth):
+            block = CaiTCABlock(
+                d_model, n_heads, bias, mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm, act
+            )
+            self.ca_layers.append(block)
+
+        self.norm = norm(d_model)
+
+    def forward(self, imgs: Tensor) -> Tensor:
+        patches = self.patch_embed(imgs).flatten(2).transpose(1, 2)  # (N, C, H, W) -> (N, H*W, C)
+        patches = self.sa_layers(patches + self.pe)
+
+        cls_token = self.cls_token.expand(patches.shape[0], -1, -1)
+        for block in self.ca_layers:
+            cls_token = block(patches, cls_token)
+        return self.norm(cls_token.squeeze(1))
+
+    @torch.no_grad()
+    def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None:
+        old_size = int(self.pe.shape[1] ** 0.5)
+        new_size = size // self.patch_embed.weight.shape[2]
+        pe = self.pe.unflatten(1, (old_size, old_size)).permute(0, 3, 1, 2)
+        pe = F.interpolate(pe, (new_size, new_size), mode=interpolation_mode)
+        pe = pe.permute(0, 2, 3, 1).flatten(1, 2)
+        self.pe = nn.Parameter(pe)
+
+    @staticmethod
+    def from_config(variant: str, img_size: int, pretrained: bool = False) -> CaiT:
+        variant, sa_depth = variant.split("_")
+
+        d_model = dict(xxs=192, xs=288, s=384, m=768)[variant]
+        sa_depth = int(sa_depth)
+        ca_depth = 2
+        n_heads = d_model // 48
+        patch_size = 16
+        m = CaiT(d_model, sa_depth, ca_depth, n_heads, patch_size, img_size)
+
+        if pretrained:
+            ckpt = dict(
+                xxs_24_224="XXS24_224.pth",
+                xxs_24_384="XXS24_384.pth",
+                xxs_36_224="XXS36_224.pth",
+                xxs_36_384="XXS36_384.pth",
+                xs_24_384="XS24_384.pth",
+                s_24_224="S24_224.pth",
+                s_24_384="S24_384.pth",
+                s_36_384="S36_384.pth",
+                m_36_384="M36_384.pth",
+                m_48_448="M48_448.pth",
+            )[f"{variant}_{sa_depth}_{img_size}"]
+            base_url = "https://dl.fbaipublicfiles.com/deit/"
+            state_dict = torch.hub.load_state_dict_from_url(base_url + ckpt)["model"]
+            state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+            m.load_official_ckpt(state_dict)
+
+        return m
+
+    @torch.no_grad()
+    def load_official_ckpt(self, state_dict: dict[str, Tensor]) -> None:
+        def copy_(m: nn.Linear | nn.LayerNorm, prefix: str):
+            m.weight.copy_(state_dict.pop(prefix + ".weight").view(m.weight.shape))
+            m.bias.copy_(state_dict.pop(prefix + ".bias"))
+
+        copy_(self.patch_embed, "patch_embed.proj")
+        self.cls_token.copy_(state_dict.pop("cls_token"))
+        self.pe.copy_(state_dict.pop("pos_embed"))
+
+        for i, sa_block in enumerate(self.sa_layers):
+            sa_block: CaiTSABlock
+            prefix = f"blocks.{i}."
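+            # remap the official CaiT block weights onto this implementation's layer names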
+
+            copy_(sa_block.mha[0], prefix + "norm1")
+            q_w, k_w, v_w = state_dict.pop(prefix + "attn.qkv.weight").chunk(3, 0)
+            sa_block.mha[1].q_proj.weight.copy_(q_w)
+            sa_block.mha[1].k_proj.weight.copy_(k_w)
+            sa_block.mha[1].v_proj.weight.copy_(v_w)
+            q_b, k_b, v_b = state_dict.pop(prefix + "attn.qkv.bias").chunk(3, 0)
+            sa_block.mha[1].q_proj.bias.copy_(q_b)
+            sa_block.mha[1].k_proj.bias.copy_(k_b)
+            sa_block.mha[1].v_proj.bias.copy_(v_b)
+            copy_(sa_block.mha[1].out_proj, prefix + "attn.proj")
+            copy_(sa_block.mha[1].talking_head_proj[0], prefix + "attn.proj_l")
+            copy_(sa_block.mha[1].talking_head_proj[2], prefix + "attn.proj_w")
+            sa_block.mha[2].gamma.copy_(state_dict.pop(prefix + "gamma_1"))
+
+            copy_(sa_block.mlp[0], prefix + "norm2")
+            copy_(sa_block.mlp[1].linear1, prefix + "mlp.fc1")
+            copy_(sa_block.mlp[1].linear2, prefix + "mlp.fc2")
+            sa_block.mlp[2].gamma.copy_(state_dict.pop(prefix + "gamma_2"))
+
+        for i, ca_block in enumerate(self.ca_layers):
+            ca_block: CaiTCABlock
+            prefix = f"blocks_token_only.{i}."
+
+            copy_(ca_block.mha[0], prefix + "norm1")
+            copy_(ca_block.mha[1].q_proj, prefix + "attn.q")
+            copy_(ca_block.mha[1].k_proj, prefix + "attn.k")
+            copy_(ca_block.mha[1].v_proj, prefix + "attn.v")
+            copy_(ca_block.mha[1].out_proj, prefix + "attn.proj")
+            ca_block.mha[2].gamma.copy_(state_dict.pop(prefix + "gamma_1"))
+
+            copy_(ca_block.mlp[0], prefix + "norm2")
+            copy_(ca_block.mlp[1].linear1, prefix + "mlp.fc1")
+            copy_(ca_block.mlp[1].linear2, prefix + "mlp.fc2")
+            ca_block.mlp[2].gamma.copy_(state_dict.pop(prefix + "gamma_2"))
+
+        copy_(self.norm, "norm")
+        assert len(state_dict) == 2
diff --git a/vision_toolbox/backbones/convnext.py b/vision_toolbox/backbones/convnext.py
index 67bb132..d9b04f2 100644
--- a/vision_toolbox/backbones/convnext.py
+++ b/vision_toolbox/backbones/convnext.py
@@ -10,7 +10,7 @@
 import torch
 from torch import Tensor, nn

-from ..components import Permute, StochasticDepth
+from ..components import LayerScale, Permute, StochasticDepth
 from .base import BaseBackbone, _act, _norm


@@ -34,12 +34,14 @@ def __init__(
         d_model: int,
         expansion_ratio: float = 4.0,
         bias: bool = True,
-        layer_scale_init: float = 1e-6,
+        layer_scale_init: float | None = 1e-6,
         stochastic_depth: float = 0.0,
         norm: _norm = partial(nn.LayerNorm, eps=1e-6),
         act: _act = nn.GELU,
         v2: bool = False,
     ) -> None:
+        if v2:
+            layer_scale_init = None
         super().__init__()
         hidden_dim = int(d_model * expansion_ratio)
         self.layers = nn.Sequential(
@@ -51,17 +53,12 @@ def __init__(
             act(),
             GlobalResponseNorm(hidden_dim) if v2 else nn.Identity(),
             nn.Linear(hidden_dim, d_model, bias=bias),
+            LayerScale(d_model, layer_scale_init) if layer_scale_init is not None else nn.Identity(),
+            StochasticDepth(stochastic_depth),
         )
-        self.layer_scale = (
-            nn.Parameter(torch.full((d_model,), layer_scale_init)) if layer_scale_init > 0 and not v2 else None
-        )
-        self.drop = StochasticDepth(stochastic_depth)

     def forward(self, x: Tensor) -> Tensor:
-        out = self.layers(x)
-        if self.layer_scale is not None:
-            out = out * self.layer_scale
-        return x + self.drop(out)
+        return x + self.layers(x)


 class ConvNeXt(BaseBackbone):
@@ -71,7 +68,7 @@ def __init__(
         depths: tuple[int, ...],
         expansion_ratio: float = 4.0,
         bias: bool = True,
-        layer_scale_init: float = 1e-6,
+        layer_scale_init: float | None = 1e-6,
         stochastic_depth: float = 0.0,
         norm: _norm = partial(nn.LayerNorm, eps=1e-6),
         act: _act = nn.GELU,
@@ -187,8 +184,8 @@ def copy_(m: nn.Conv2d | nn.Linear | nn.LayerNorm, prefix: str):
                 block.layers[6].beta.copy_(state_dict.pop(prefix + "grn.beta").squeeze())
+ "grn.beta").squeeze()) copy_(block.layers[7], prefix + "pwconv2") - if block.layer_scale is not None: - block.layer_scale.copy_(state_dict.pop(prefix + "gamma")) + if isinstance(block.layers[8], LayerScale): + block.layers[8].gamma.copy_(state_dict.pop(prefix + "gamma")) # FCMAE checkpoints don't contain head norm if "norm.weight" in state_dict: diff --git a/vision_toolbox/backbones/swin.py b/vision_toolbox/backbones/swin.py index 6de5beb..1bd12d4 100644 --- a/vision_toolbox/backbones/swin.py +++ b/vision_toolbox/backbones/swin.py @@ -4,12 +4,13 @@ from __future__ import annotations import itertools +from functools import partial import torch from torch import Tensor, nn from .base import BaseBackbone, _act, _norm -from .vit import MHA, MLP +from .vit import MHA, ViTBlock def window_partition(x: Tensor, window_size: int) -> tuple[Tensor, int, int]: @@ -85,7 +86,7 @@ def forward(self, x: Tensor) -> Tensor: return x -class SwinBlock(nn.Module): +class SwinBlock(ViTBlock): def __init__( self, input_size: int, @@ -96,19 +97,18 @@ def __init__( mlp_ratio: float = 4.0, bias: bool = True, dropout: float = 0.0, + layer_scale_init: float | None = None, + stochastic_depth: float = 0.0, norm: _norm = nn.LayerNorm, act: _act = nn.GELU, ) -> None: - super().__init__() - self.norm1 = norm(d_model) - self.mha = WindowAttention(input_size, d_model, n_heads, window_size, shift, bias, dropout) - self.norm2 = norm(d_model) - self.mlp = MLP(d_model, int(d_model * mlp_ratio), dropout, act) - - def forward(self, x: Tensor) -> Tensor: - x = x + self.mha(self.norm1(x)) - x = x + self.mlp(self.norm2(x)) - return x + # fmt: off + super().__init__( + d_model, n_heads, bias, mlp_ratio, dropout, + layer_scale_init, stochastic_depth, norm, act, + partial(WindowAttention, input_size, d_model, n_heads, window_size, shift, bias, dropout), + ) + # fmt: on class PatchMerging(nn.Module): @@ -137,6 +137,8 @@ def __init__( mlp_ratio: float = 4.0, bias: bool = True, dropout: float = 0.0, + layer_scale_init: float | None = None, + stochastic_depth: float = 0.0, norm: _norm = nn.LayerNorm, act: _act = nn.GELU, ) -> None: @@ -162,7 +164,12 @@ def __init__( for i in range(depth): shift = (i % 2) and input_size > window_size - block = SwinBlock(input_size, d_model, n_heads, window_size, shift, mlp_ratio, bias, dropout, norm, act) + # fmt: off + block = SwinBlock( + input_size, d_model, n_heads, window_size, shift, mlp_ratio, + bias, dropout, layer_scale_init, stochastic_depth, norm, act, + ) + # fmt: on stage.append(block) self.stages.append(stage) @@ -234,18 +241,25 @@ def rearrange(p): prefix = f"layers.{stage_idx}.blocks.{block_idx - 1}." 
             block_idx += 1

-            if block.mha.attn_mask is not None:
-                torch.testing.assert_close(block.mha.attn_mask, state_dict.pop(prefix + "attn_mask"))
+            copy_(block.mha[0], prefix + "norm1")
+            if block.mha[1].attn_mask is not None:
+                torch.testing.assert_close(block.mha[1].attn_mask, state_dict.pop(prefix + "attn_mask"))
             torch.testing.assert_close(
-                block.mha.relative_pe_index, state_dict.pop(prefix + "attn.relative_position_index")
+                block.mha[1].relative_pe_index, state_dict.pop(prefix + "attn.relative_position_index")
             )
-            copy_(block.norm1, prefix + "norm1")
-            copy_(block.mha.in_proj, prefix + "attn.qkv")
-            copy_(block.mha.out_proj, prefix + "attn.proj")
-            block.mha.relative_pe_table.copy_(state_dict.pop(prefix + "attn.relative_position_bias_table").T)
-            copy_(block.norm2, prefix + "norm2")
-            copy_(block.mlp.linear1, prefix + "mlp.fc1")
-            copy_(block.mlp.linear2, prefix + "mlp.fc2")
+            q_w, k_w, v_w = state_dict.pop(prefix + "attn.qkv.weight").chunk(3, 0)
+            block.mha[1].q_proj.weight.copy_(q_w)
+            block.mha[1].k_proj.weight.copy_(k_w)
+            block.mha[1].v_proj.weight.copy_(v_w)
+            q_b, k_b, v_b = state_dict.pop(prefix + "attn.qkv.bias").chunk(3, 0)
+            block.mha[1].q_proj.bias.copy_(q_b)
+            block.mha[1].k_proj.bias.copy_(k_b)
+            block.mha[1].v_proj.bias.copy_(v_b)
+            copy_(block.mha[1].out_proj, prefix + "attn.proj")
+            block.mha[1].relative_pe_table.copy_(state_dict.pop(prefix + "attn.relative_position_bias_table").T)
+            copy_(block.mlp[0], prefix + "norm2")
+            copy_(block.mlp[1].linear1, prefix + "mlp.fc1")
+            copy_(block.mlp[1].linear2, prefix + "mlp.fc2")

         copy_(self.head_norm, "norm")
         assert len(state_dict) == 2  # head.weight and head.bias
diff --git a/vision_toolbox/backbones/vit.py b/vision_toolbox/backbones/vit.py
index c1e4ed8..bcd2dbd 100644
--- a/vision_toolbox/backbones/vit.py
+++ b/vision_toolbox/backbones/vit.py
@@ -12,6 +12,7 @@
 import torch.nn.functional as F
 from torch import Tensor, nn

+from ..components import LayerScale, StochasticDepth
 from ..utils import torch_hub_download
 from .base import _act, _norm

@@ -19,15 +20,19 @@
 class MHA(nn.Module):
     def __init__(self, d_model: int, n_heads: int, bias: bool = True, dropout: float = 0.0) -> None:
         super().__init__()
-        self.in_proj = nn.Linear(d_model, d_model * 3, bias)
+        self.q_proj = nn.Linear(d_model, d_model, bias)
+        self.k_proj = nn.Linear(d_model, d_model, bias)
+        self.v_proj = nn.Linear(d_model, d_model, bias)
         self.out_proj = nn.Linear(d_model, d_model, bias)
         self.n_heads = n_heads
         self.dropout = dropout
         self.scale = (d_model // n_heads) ** (-0.5)

     def forward(self, x: Tensor, attn_bias: Tensor | None = None) -> Tensor:
-        qkv = self.in_proj(x)
-        q, k, v = qkv.unflatten(-1, (3, self.n_heads, -1)).transpose(-2, -4).unbind(-3)  # (B, n_heads, L, head_dim)
+        q = self.q_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)  # (B, n_heads, L, head_dim)
+        k = self.k_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)
+        v = self.v_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)
+
         if hasattr(F, "scaled_dot_product_attention"):
             out = F.scaled_dot_product_attention(q, k, v, attn_bias, self.dropout if self.training else 0.0)
         else:
@@ -58,26 +63,39 @@ def __init__(
         bias: bool = True,
         mlp_ratio: float = 4.0,
         dropout: float = 0.0,
+        layer_scale_init: float | None = None,
+        stochastic_depth: float = 0.0,
         norm: _norm = partial(nn.LayerNorm, eps=1e-6),
         act: _act = nn.GELU,
+        attention: type[nn.Module] | None = None,
     ) -> None:
+        if attention is None:
+            attention = partial(MHA, d_model, n_heads, bias, dropout)
         super().__init__()
-        self.norm1 = norm(d_model)
-        self.mha = MHA(d_model, n_heads, bias, dropout)
-        self.norm2 = norm(d_model)
-        self.mlp = MLP(d_model, int(d_model * mlp_ratio), dropout, act)
+        self.mha = nn.Sequential(
+            norm(d_model),
+            attention(),
+            LayerScale(d_model, layer_scale_init) if layer_scale_init is not None else nn.Identity(),
+            StochasticDepth(stochastic_depth),
+        )
+        self.mlp = nn.Sequential(
+            norm(d_model),
+            MLP(d_model, int(d_model * mlp_ratio), dropout, act),
+            LayerScale(d_model, layer_scale_init) if layer_scale_init is not None else nn.Identity(),
+            StochasticDepth(stochastic_depth),
+        )

     def forward(self, x: Tensor) -> Tensor:
-        x = x + self.mha(self.norm1(x))
-        x = x + self.mlp(self.norm2(x))
+        x = x + self.mha(x)
+        x = x + self.mlp(x)
         return x


 class ViT(nn.Module):
     def __init__(
         self,
-        n_layers: int,
         d_model: int,
+        depth: int,
         n_heads: int,
         patch_size: int,
         img_size: int,
@@ -85,6 +103,8 @@ def __init__(
         bias: bool = True,
         mlp_ratio: float = 4.0,
         dropout: float = 0.0,
+        layer_scale_init: float | None = None,
+        stochastic_depth: float = 0.0,
         norm: _norm = partial(nn.LayerNorm, eps=1e-6),
         act: _act = nn.GELU,
     ) -> None:
@@ -99,15 +119,17 @@ def __init__(
         self.pe = nn.Parameter(torch.empty(1, pe_size, d_model))
         nn.init.normal_(self.pe, 0, 0.02)

-        self.layers = nn.Sequential(
-            *[ViTBlock(d_model, n_heads, bias, mlp_ratio, dropout, norm, act) for _ in range(n_layers)]
-        )
+        self.layers = nn.Sequential()
+        for _ in range(depth):
+            block = ViTBlock(d_model, n_heads, bias, mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm, act)
+            self.layers.append(block)
+
         self.norm = norm(d_model)

     def forward(self, imgs: Tensor) -> Tensor:
         out = self.patch_embed(imgs).flatten(2).transpose(1, 2)  # (N, C, H, W) -> (N, H*W, C)
         if self.cls_token is not None:
             out = torch.cat([self.cls_token.expand(out.shape[0], -1, -1), out], 1)
         out = self.layers(out + self.pe)
         out = self.norm(out)
         out = out[:, 0] if self.cls_token is not None else out.mean(1)
@@ -129,17 +151,21 @@ def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None:
         self.pe = nn.Parameter(pe)

     @staticmethod
-    def from_config(variant: str, patch_size: int, img_size: int, pretrained: bool = False) -> ViT:
-        n_layers, d_model, n_heads = dict(
-            Ti=(12, 192, 3),
-            S=(12, 384, 6),
-            B=(12, 768, 12),
-            L=(24, 1024, 16),
-            H=(32, 1280, 16),
+    def from_config(variant: str, img_size: int, pretrained: bool = False) -> ViT:
+        variant, patch_size = variant.split("_")
+
+        d_model, depth, n_heads = dict(
+            Ti=(192, 12, 3),
+            S=(384, 12, 6),
+            B=(768, 12, 12),
+            L=(1024, 24, 16),
+            H=(1280, 32, 16),
         )[variant]
-        m = ViT(n_layers, d_model, n_heads, patch_size, img_size)
+        patch_size = int(patch_size)
+        m = ViT(d_model, depth, n_heads, patch_size, img_size)

         if pretrained:
+            assert img_size == 224
             ckpt = {
                 ("Ti", 16): "Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz",
                 ("S", 32): "S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0.npz",
@@ -150,8 +176,6 @@ def from_config(variant: str, patch_size: int, img_size: int, pretrained: bool =
             }[(variant, patch_size)]
             base_url = "https://storage.googleapis.com/vit_models/augreg/"
             m.load_jax_weights(torch_hub_download(base_url + ckpt))
-            if img_size != 224:
-                m.resize_pe(img_size)

         return m

@@ -173,21 +197,23 @@ def get_w(key: str) -> Tensor:
             prefix = f"Transformer/encoderblock_{idx}/"
             mha_prefix = prefix + "MultiHeadDotProductAttention_1/"

-            layer.norm1.weight.copy_(get_w(prefix + "LayerNorm_0/scale"))
-            layer.norm1.bias.copy_(get_w(prefix + "LayerNorm_0/bias"))
+ "LayerNorm_0/bias")) - w = torch.stack([get_w(mha_prefix + x + "/kernel") for x in ["query", "key", "value"]], 1) - b = torch.stack([get_w(mha_prefix + x + "/bias") for x in ["query", "key", "value"]], 0) - layer.mha.in_proj.weight.copy_(w.flatten(1).T) - layer.mha.in_proj.bias.copy_(b.flatten()) - layer.mha.out_proj.weight.copy_(get_w(mha_prefix + "out/kernel").flatten(0, 1).T) - layer.mha.out_proj.bias.copy_(get_w(mha_prefix + "out/bias")) - - layer.norm2.weight.copy_(get_w(prefix + "LayerNorm_2/scale")) - layer.norm2.bias.copy_(get_w(prefix + "LayerNorm_2/bias")) - layer.mlp.linear1.weight.copy_(get_w(prefix + "MlpBlock_3/Dense_0/kernel").T) - layer.mlp.linear1.bias.copy_(get_w(prefix + "MlpBlock_3/Dense_0/bias")) - layer.mlp.linear2.weight.copy_(get_w(prefix + "MlpBlock_3/Dense_1/kernel").T) - layer.mlp.linear2.bias.copy_(get_w(prefix + "MlpBlock_3/Dense_1/bias")) + layer.mha[0].weight.copy_(get_w(prefix + "LayerNorm_0/scale")) + layer.mha[0].bias.copy_(get_w(prefix + "LayerNorm_0/bias")) + layer.mha[1].q_proj.weight.copy_(get_w(mha_prefix + "query/kernel").flatten(1).T) + layer.mha[1].k_proj.weight.copy_(get_w(mha_prefix + "key/kernel").flatten(1).T) + layer.mha[1].v_proj.weight.copy_(get_w(mha_prefix + "value/kernel").flatten(1).T) + layer.mha[1].q_proj.bias.copy_(get_w(mha_prefix + "query/bias").flatten()) + layer.mha[1].k_proj.bias.copy_(get_w(mha_prefix + "key/bias").flatten()) + layer.mha[1].v_proj.bias.copy_(get_w(mha_prefix + "value/bias").flatten()) + layer.mha[1].out_proj.weight.copy_(get_w(mha_prefix + "out/kernel").flatten(0, 1).T) + layer.mha[1].out_proj.bias.copy_(get_w(mha_prefix + "out/bias")) + + layer.mlp[0].weight.copy_(get_w(prefix + "LayerNorm_2/scale")) + layer.mlp[0].bias.copy_(get_w(prefix + "LayerNorm_2/bias")) + layer.mlp[1].linear1.weight.copy_(get_w(prefix + "MlpBlock_3/Dense_0/kernel").T) + layer.mlp[1].linear1.bias.copy_(get_w(prefix + "MlpBlock_3/Dense_0/bias")) + layer.mlp[1].linear2.weight.copy_(get_w(prefix + "MlpBlock_3/Dense_1/kernel").T) + layer.mlp[1].linear2.bias.copy_(get_w(prefix + "MlpBlock_3/Dense_1/bias")) self.norm.weight.copy_(get_w("Transformer/encoder_norm/scale")) self.norm.bias.copy_(get_w("Transformer/encoder_norm/bias")) diff --git a/vision_toolbox/components.py b/vision_toolbox/components.py index d813166..455cacd 100644 --- a/vision_toolbox/components.py +++ b/vision_toolbox/components.py @@ -178,3 +178,15 @@ def forward(self, x: Tensor) -> Tensor: def extra_repr(self) -> str: return f"p={self.p}" + + +class LayerScale(nn.Module): + def __init__(self, dim: int, init: float) -> None: + super().__init__() + self.gamma = nn.Parameter(torch.full((dim,), init)) + + def forward(self, x: Tensor) -> Tensor: + return x * self.gamma + + def extra_repr(self) -> str: + return f"gamma={self.gamma}"