remove options to use custom norm and activation
gau-nernst committed Aug 20, 2023
1 parent bb7e4f7 commit 3658016
Showing 6 changed files with 56 additions and 82 deletions.
28 changes: 10 additions & 18 deletions vision_toolbox/backbones/cait.py
@@ -9,7 +9,6 @@
import torch.nn.functional as F
from torch import Tensor, nn

- from .base import _act, _norm
from .vit import MHA, ViT, ViTBlock


@@ -62,13 +61,12 @@ def __init__(
dropout: float = 0.0,
layer_scale_init: float | None = 1e-6,
stochastic_depth: float = 0.0,
- norm: _norm = partial(nn.LayerNorm, eps=1e-6),
- act: _act = nn.GELU,
+ norm_eps: float = 1e-6,
) -> None:
# fmt: off
super().__init__(
d_model, n_heads, bias, mlp_ratio, dropout,
- layer_scale_init, stochastic_depth, norm, act,
+ layer_scale_init, stochastic_depth, norm_eps,
partial(ClassAttention, d_model, n_heads, bias, dropout),
)
# fmt: on
@@ -89,13 +87,12 @@ def __init__(
dropout: float = 0.0,
layer_scale_init: float | None = 1e-6,
stochastic_depth: float = 0.0,
- norm: _norm = partial(nn.LayerNorm, eps=1e-6),
- act: _act = nn.GELU,
+ norm_eps: float = 1e-6,
) -> None:
# fmt: off
super().__init__(
d_model, n_heads, bias, mlp_ratio, dropout,
- layer_scale_init, stochastic_depth, norm, act,
+ layer_scale_init, stochastic_depth, norm_eps,
partial(TalkingHeadAttention, d_model, n_heads, bias, dropout),
)
# fmt: on
@@ -115,8 +112,7 @@ def __init__(
dropout: float = 0.0,
layer_scale_init: float | None = 1e-6,
stochastic_depth: float = 0.0,
- norm: _norm = partial(nn.LayerNorm, eps=1e-6),
- act: _act = nn.GELU,
+ norm_eps: float = 1e-6,
) -> None:
assert img_size % patch_size == 0
super().__init__()
@@ -127,19 +123,15 @@

self.sa_layers = nn.Sequential()
for _ in range(sa_depth):
- block = CaiTSABlock(
-     d_model, n_heads, bias, mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm, act
- )
- self.sa_layers.append(block)
+ blk = CaiTSABlock(d_model, n_heads, bias, mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm_eps)
+ self.sa_layers.append(blk)

self.ca_layers = nn.ModuleList()
for _ in range(ca_depth):
- block = CaiTCABlock(
-     d_model, n_heads, bias, mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm, act
- )
- self.ca_layers.append(block)
+ blk = CaiTCABlock(d_model, n_heads, bias, mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm_eps)
+ self.ca_layers.append(blk)

- self.norm = norm(d_model)
+ self.norm = nn.LayerNorm(d_model, norm_eps)

def forward(self, imgs: Tensor) -> Tensor:
patches = self.patch_embed(imgs).flatten(2).transpose(1, 2) # (N, C, H, W) -> (N, H*W, C)
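The cait.py diff shows the pattern repeated across the whole commit: the norm and act constructor arguments (arbitrary normalization/activation callables) are removed, nn.LayerNorm and nn.GELU are hard-coded, and only the LayerNorm epsilon remains configurable as norm_eps. A minimal before/after sketch based on how CaiT builds its blocks above; the concrete d_model/n_heads values are illustrative, not taken from the repository:

    from vision_toolbox.backbones.cait import CaiTSABlock

    # Before this commit (sketch of the removed keyword arguments):
    #   CaiTSABlock(192, 4, norm=partial(nn.LayerNorm, eps=1e-6), act=nn.GELU)

    # After this commit, LayerNorm + GELU are fixed and only the eps is exposed:
    block = CaiTSABlock(192, 4, norm_eps=1e-6)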
26 changes: 11 additions & 15 deletions vision_toolbox/backbones/convnext.py
@@ -5,13 +5,11 @@

from __future__ import annotations

- from functools import partial

import torch
from torch import Tensor, nn

from ..components import LayerScale, Permute, StochasticDepth
- from .base import BaseBackbone, _act, _norm
+ from .base import BaseBackbone


class GlobalResponseNorm(nn.Module):
@@ -36,8 +34,7 @@ def __init__(
bias: bool = True,
layer_scale_init: float | None = 1e-6,
stochastic_depth: float = 0.0,
- norm: _norm = partial(nn.LayerNorm, eps=1e-6),
- act: _act = nn.GELU,
+ norm_eps: float = 1e-6,
v2: bool = False,
) -> None:
if v2:
@@ -48,9 +45,9 @@
Permute(0, 3, 1, 2),
nn.Conv2d(d_model, d_model, 7, padding=3, groups=d_model, bias=bias),
Permute(0, 2, 3, 1),
- norm(d_model),
+ nn.LayerNorm(d_model, norm_eps),
nn.Linear(d_model, hidden_dim, bias=bias),
- act(),
+ nn.GELU(),
GlobalResponseNorm(hidden_dim) if v2 else nn.Identity(),
nn.Linear(hidden_dim, d_model, bias=bias),
LayerScale(d_model, layer_scale_init) if layer_scale_init is not None else nn.Identity(),
@@ -70,12 +67,11 @@ def __init__(
bias: bool = True,
layer_scale_init: float | None = 1e-6,
stochastic_depth: float = 0.0,
- norm: _norm = partial(nn.LayerNorm, eps=1e-6),
- act: _act = nn.GELU,
+ norm_eps: float = 1e-6,
v2: bool = False,
) -> None:
super().__init__()
- self.stem = nn.Sequential(nn.Conv2d(3, d_model, 4, 4), Permute(0, 2, 3, 1), norm(d_model))
+ self.stem = nn.Sequential(nn.Conv2d(3, d_model, 4, 4), Permute(0, 2, 3, 1), nn.LayerNorm(d_model, norm_eps))

stochastic_depth_rates = torch.linspace(0, stochastic_depth, sum(depths))
self.stages = nn.Sequential()
@@ -85,7 +81,7 @@ def __init__(
if stage_idx > 0:
# equivalent to PatchMerging in SwinTransformer
downsample = nn.Sequential(
- norm(d_model),
+ nn.LayerNorm(d_model, norm_eps),
Permute(0, 3, 1, 2),
nn.Conv2d(d_model, d_model * 2, 2, 2),
Permute(0, 2, 3, 1),
@@ -97,12 +93,12 @@

for block_idx in range(depth):
rate = stochastic_depth_rates[sum(depths[:stage_idx]) + block_idx]
- block = ConvNeXtBlock(d_model, expansion_ratio, bias, layer_scale_init, rate, norm, act, v2)
+ block = ConvNeXtBlock(d_model, expansion_ratio, bias, layer_scale_init, rate, norm_eps, v2)
stage.append(block)

self.stages.append(stage)

- self.head_norm = norm(d_model)
+ self.norm = nn.LayerNorm(d_model, norm_eps)

def get_feature_maps(self, x: Tensor) -> list[Tensor]:
out = [self.stem(x)]
@@ -111,7 +107,7 @@ def get_feature_maps(self, x: Tensor) -> list[Tensor]:
return out[-1:]

def forward(self, x: Tensor) -> Tensor:
- return self.head_norm(self.get_feature_maps(x)[-1].mean((1, 2)))
+ return self.norm(self.get_feature_maps(x)[-1].mean((1, 2)))

@staticmethod
def from_config(variant: str, v2: bool = False, pretrained: bool = False) -> ConvNeXt:
@@ -189,7 +185,7 @@ def copy_(m: nn.Conv2d | nn.Linear | nn.LayerNorm, prefix: str):

# FCMAE checkpoints don't contain head norm
if "norm.weight" in state_dict:
- copy_(self.head_norm, "norm")
+ copy_(self.norm, "norm")
assert len(state_dict) == 2
else:
assert len(state_dict) == 0
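A sketch of constructing the new ConvNeXtBlock directly, mirroring the positional call inside ConvNeXt.__init__ above (d_model, expansion_ratio, bias, layer_scale_init, stochastic_depth, norm_eps, v2); the concrete sizes are illustrative assumptions:

    import torch
    from vision_toolbox.backbones.convnext import ConvNeXtBlock

    # norm_eps replaces the old norm/act pair; GELU is now hard-coded in the block.
    block = ConvNeXtBlock(96, 4.0, True, 1e-6, 0.0, 1e-6, False)
    x = torch.randn(1, 56, 56, 96)  # channels-last layout, matching the Permute layers above
    y = block(x)                    # output keeps the input shape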
13 changes: 4 additions & 9 deletions vision_toolbox/backbones/deit.py
@@ -4,13 +4,10 @@

from __future__ import annotations

- from functools import partial

import torch
from torch import Tensor, nn

from ..components import LayerScale
- from .base import _act, _norm
from .vit import ViT, ViTBlock


@@ -27,13 +24,12 @@ def __init__(
dropout: float = 0.0,
layer_scale_init: float | None = None,
stochastic_depth: float = 0.0,
- norm: _norm = partial(nn.LayerNorm, eps=1e-6),
- act: _act = nn.GELU,
+ norm_eps: float = 1e-6,
) -> None:
# fmt: off
super().__init__(
d_model, depth, n_heads, patch_size, img_size, True, bias, mlp_ratio,
- dropout, layer_scale_init, stochastic_depth, norm, act
+ dropout, layer_scale_init, stochastic_depth, norm_eps
)
# fmt: on
self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model))
@@ -133,13 +129,12 @@ def __init__(
dropout: float = 0.0,
layer_scale_init: float | None = 1e-6,
stochastic_depth: float = 0.0,
- norm: _norm = partial(nn.LayerNorm, eps=1e-6),
- act: _act = nn.GELU,
+ norm_eps: float = 1e-6,
):
# fmt: off
super().__init__(
d_model, depth, n_heads, patch_size, img_size, cls_token, bias,
- mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm, act
+ mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm_eps,
)
# fmt: on

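The two DeiT-style classes above forward their arguments positionally to ViT.__init__, which implies ViT's new parameter order even though the vit.py diff (the sixth changed file) is not shown here. A sketch with illustrative values, assuming ViT requires no parameters beyond those forwarded above:

    from vision_toolbox.backbones.vit import ViT

    # Positional order taken from the super().__init__() calls above:
    # d_model, depth, n_heads, patch_size, img_size, cls_token, bias, mlp_ratio,
    # dropout, layer_scale_init, stochastic_depth, norm_eps
    vit = ViT(192, 12, 3, 16, 224, True, True, 4.0, 0.0, None, 0.0, 1e-6)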
20 changes: 8 additions & 12 deletions vision_toolbox/backbones/mlp_mixer.py
@@ -3,15 +3,13 @@

from __future__ import annotations

- from functools import partial
from typing import Mapping

import numpy as np
import torch
from torch import Tensor, nn

from ..utils import torch_hub_download
- from .base import _act, _norm
from .vit import MLP


@@ -22,15 +20,14 @@ def __init__(
d_model: int,
mlp_ratio: tuple[int, int] = (0.5, 4.0),
dropout: float = 0.0,
- norm: _norm = partial(nn.LayerNorm, eps=1e-6),
- act: _act = nn.GELU,
+ norm_eps: float = 1e-6,
) -> None:
tokens_mlp_dim, channels_mlp_dim = [int(d_model * ratio) for ratio in mlp_ratio]
super().__init__()
- self.norm1 = norm(d_model)
- self.token_mixing = MLP(n_tokens, tokens_mlp_dim, dropout, act)
- self.norm2 = norm(d_model)
- self.channel_mixing = MLP(d_model, channels_mlp_dim, dropout, act)
+ self.norm1 = nn.LayerNorm(d_model, norm_eps)
+ self.token_mixing = MLP(n_tokens, tokens_mlp_dim, dropout)
+ self.norm2 = nn.LayerNorm(d_model, norm_eps)
+ self.channel_mixing = MLP(d_model, channels_mlp_dim, dropout)

def forward(self, x: Tensor) -> Tensor:
# x -> (B, n_tokens, d_model)
@@ -48,17 +45,16 @@ def __init__(
img_size: int,
mlp_ratio: tuple[float, float] = (0.5, 4.0),
dropout: float = 0.0,
- norm: _norm = partial(nn.LayerNorm, eps=1e-6),
- act: _act = nn.GELU,
+ norm_eps: float = 1e-6,
) -> None:
assert img_size % patch_size == 0
super().__init__()
self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size)
n_tokens = (img_size // patch_size) ** 2
self.layers = nn.Sequential(
- *[MixerBlock(n_tokens, d_model, mlp_ratio, dropout, norm, act) for _ in range(n_layers)]
+ *[MixerBlock(n_tokens, d_model, mlp_ratio, dropout, norm_eps) for _ in range(n_layers)]
)
- self.norm = norm(d_model)
+ self.norm = nn.LayerNorm(d_model, norm_eps)

def forward(self, x: Tensor) -> Tensor:
x = self.patch_embed(x).flatten(2).transpose(1, 2) # (N, C, H, W) -> (N, H*W, C)
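A sketch of the simplified MixerBlock signature shown above; the activation argument is gone from the MLP calls (GELU is presumably fixed inside MLP, whose file is not part of this diff) and only the LayerNorm eps remains. Token count and width are illustrative:

    import torch
    from vision_toolbox.backbones.mlp_mixer import MixerBlock

    block = MixerBlock(n_tokens=196, d_model=768, norm_eps=1e-6)
    x = torch.randn(1, 196, 768)  # (batch, n_tokens, d_model), per the comment in forward()
    y = block(x)                  # same shape, assuming the usual Mixer token/channel mixing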
30 changes: 14 additions & 16 deletions vision_toolbox/backbones/swin.py
@@ -9,7 +9,7 @@
import torch
from torch import Tensor, nn

- from .base import BaseBackbone, _act, _norm
+ from .base import BaseBackbone
from .vit import MHA, ViTBlock


@@ -99,22 +99,21 @@ def __init__(
dropout: float = 0.0,
layer_scale_init: float | None = None,
stochastic_depth: float = 0.0,
- norm: _norm = nn.LayerNorm,
- act: _act = nn.GELU,
+ norm_eps: float = 1e-5,
) -> None:
# fmt: off
super().__init__(
d_model, n_heads, bias, mlp_ratio, dropout,
- layer_scale_init, stochastic_depth, norm, act,
+ layer_scale_init, stochastic_depth, norm_eps,
partial(WindowAttention, input_size, d_model, n_heads, window_size, shift, bias, dropout),
)
# fmt: on


class PatchMerging(nn.Module):
- def __init__(self, d_model: int, norm: _norm = nn.LayerNorm) -> None:
+ def __init__(self, d_model: int, norm_eps: float = 1e-5) -> None:
super().__init__()
- self.norm = norm(d_model * 4)
+ self.norm = nn.LayerNorm(d_model * 4, norm_eps)
self.reduction = nn.Linear(d_model * 4, d_model * 2, False)

def forward(self, x: Tensor) -> Tensor:
@@ -139,22 +138,21 @@ def __init__(
dropout: float = 0.0,
layer_scale_init: float | None = None,
stochastic_depth: float = 0.0,
- norm: _norm = nn.LayerNorm,
- act: _act = nn.GELU,
+ norm_eps: float = 1e-5,
) -> None:
assert img_size % patch_size == 0
assert d_model % n_heads == 0
super().__init__()
self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size)
- self.norm = norm(d_model)
+ self.patch_norm = nn.LayerNorm(d_model, norm_eps)
self.dropout = nn.Dropout(dropout)

input_size = img_size // patch_size
self.stages = nn.Sequential()
for i, (depth, window_size) in enumerate(zip(depths, window_sizes)):
stage = nn.Sequential()
if i > 0:
- downsample = PatchMerging(d_model, norm)
+ downsample = PatchMerging(d_model, norm_eps)
input_size //= 2
d_model *= 2
n_heads *= 2
@@ -167,23 +165,23 @@ def __init__(
# fmt: off
block = SwinBlock(
input_size, d_model, n_heads, window_size, shift, mlp_ratio,
- bias, dropout, layer_scale_init, stochastic_depth, norm, act,
+ bias, dropout, layer_scale_init, stochastic_depth, norm_eps,
)
# fmt: on
stage.append(block)

self.stages.append(stage)

- self.head_norm = norm(d_model)
+ self.norm = nn.LayerNorm(d_model, norm_eps)

def get_feature_maps(self, x: Tensor) -> list[Tensor]:
- out = [self.dropout(self.norm(self.patch_embed(x).permute(0, 2, 3, 1)))]
+ out = [self.dropout(self.patch_norm(self.patch_embed(x).permute(0, 2, 3, 1)))]
for stage in self.stages:
out.append(stage(out[-1]))
return out[1:]

def forward(self, x: Tensor) -> Tensor:
- return self.head_norm(self.get_feature_maps(x)[-1]).mean((1, 2))
+ return self.norm(self.get_feature_maps(x)[-1]).mean((1, 2))

def resize_pe(self, img_size: int) -> None:
raise NotImplementedError()
@@ -222,7 +220,7 @@ def copy_(m: nn.Linear | nn.LayerNorm, prefix: str) -> None:
m.bias.copy_(state_dict.pop(prefix + ".bias"))

copy_(self.patch_embed, "patch_embed.proj")
- copy_(self.norm, "patch_embed.norm")
+ copy_(self.patch_norm, "patch_embed.norm")

for stage_idx, stage in enumerate(self.stages):
if stage_idx > 0:
@@ -261,5 +259,5 @@ def rearrange(p):
copy_(block.mlp[1].linear1, prefix + "mlp.fc1")
copy_(block.mlp[1].linear2, prefix + "mlp.fc2")

- copy_(self.head_norm, "norm")
+ copy_(self.norm, "norm")
assert len(state_dict) == 2 # head.weight and head.bias
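Besides switching to norm_eps, swin.py also renames the patch-embedding norm to patch_norm and the final norm from head_norm to norm, which is what the copy_ calls above now target. A sketch of the simplified PatchMerging, with an illustrative width:

    from vision_toolbox.backbones.swin import PatchMerging

    merge = PatchMerging(96, norm_eps=1e-5)
    print(merge.norm)       # LayerNorm over 4 * d_model = 384 features
    print(merge.reduction)  # Linear(in_features=384, out_features=192, bias=False)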