Remove custom norm and act #21

Merged · 1 commit · Aug 20, 2023
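This PR drops the injectable `norm`/`act` factory arguments (the `_norm` and `_act` aliases from `base`) from the backbone constructors and hardcodes `nn.LayerNorm` and `nn.GELU`, leaving a single `norm_eps` float as the only normalization knob. A minimal sketch of the before/after construction, assuming only the signatures visible in the diff below:

```python
from functools import partial

from torch import nn

d_model = 192

# before: callers injected a norm factory and an activation class
norm = partial(nn.LayerNorm, eps=1e-6)
old_norm, old_act = norm(d_model), nn.GELU()

# after: LayerNorm/GELU are hardcoded; eps is LayerNorm's 2nd positional arg
new_norm, new_act = nn.LayerNorm(d_model, 1e-6), nn.GELU()

assert old_norm.eps == new_norm.eps == 1e-6
```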
28 changes: 10 additions & 18 deletions vision_toolbox/backbones/cait.py
@@ -9,7 +9,6 @@
import torch.nn.functional as F
from torch import Tensor, nn

-from .base import _act, _norm
from .vit import MHA, ViT, ViTBlock


@@ -62,13 +61,12 @@ def __init__(
dropout: float = 0.0,
layer_scale_init: float | None = 1e-6,
stochastic_depth: float = 0.0,
-        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
-        act: _act = nn.GELU,
+        norm_eps: float = 1e-6,
) -> None:
# fmt: off
super().__init__(
d_model, n_heads, bias, mlp_ratio, dropout,
-            layer_scale_init, stochastic_depth, norm, act,
+            layer_scale_init, stochastic_depth, norm_eps,
partial(ClassAttention, d_model, n_heads, bias, dropout),
)
# fmt: on
@@ -89,13 +87,12 @@ def __init__(
dropout: float = 0.0,
layer_scale_init: float | None = 1e-6,
stochastic_depth: float = 0.0,
-        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
-        act: _act = nn.GELU,
+        norm_eps: float = 1e-6,
) -> None:
# fmt: off
super().__init__(
d_model, n_heads, bias, mlp_ratio, dropout,
-            layer_scale_init, stochastic_depth, norm, act,
+            layer_scale_init, stochastic_depth, norm_eps,
partial(TalkingHeadAttention, d_model, n_heads, bias, dropout),
)
# fmt: on
@@ -115,8 +112,7 @@ def __init__(
dropout: float = 0.0,
layer_scale_init: float | None = 1e-6,
stochastic_depth: float = 0.0,
-        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
-        act: _act = nn.GELU,
+        norm_eps: float = 1e-6,
) -> None:
assert img_size % patch_size == 0
super().__init__()
@@ -127,19 +123,15 @@ def __init__(

self.sa_layers = nn.Sequential()
for _ in range(sa_depth):
-            block = CaiTSABlock(
-                d_model, n_heads, bias, mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm, act
-            )
-            self.sa_layers.append(block)
+            blk = CaiTSABlock(d_model, n_heads, bias, mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm_eps)
+            self.sa_layers.append(blk)

self.ca_layers = nn.ModuleList()
for _ in range(ca_depth):
-            block = CaiTCABlock(
-                d_model, n_heads, bias, mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm, act
-            )
-            self.ca_layers.append(block)
+            blk = CaiTCABlock(d_model, n_heads, bias, mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm_eps)
+            self.ca_layers.append(blk)

-        self.norm = norm(d_model)
+        self.norm = nn.LayerNorm(d_model, norm_eps)

def forward(self, imgs: Tensor) -> Tensor:
patches = self.patch_embed(imgs).flatten(2).transpose(1, 2) # (N, C, H, W) -> (N, H*W, C)
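For context, `ClassAttention` (constructed via `partial` above) is the CaiT-style attention where only the class token forms the query, so each CA layer costs O(N) rather than O(N²). Its body is elided from this diff, so the following is a hedged sketch based on the CaiT design, with the q/k/v projections omitted and PyTorch 2's `F.scaled_dot_product_attention` standing in for the repo's attention math:

```python
import torch
import torch.nn.functional as F

N, L, d_model, n_heads = 2, 196, 192, 4
head_dim = d_model // n_heads
patches = torch.randn(N, L, d_model)    # patch tokens coming out of sa_layers
cls_token = torch.zeros(N, 1, d_model)  # the query: only the class token

# keys/values come from [cls_token, patches]; only the class token is updated
ctx = torch.cat([cls_token, patches], dim=1)
q = cls_token.view(N, 1, n_heads, head_dim).transpose(1, 2)
kv = ctx.view(N, L + 1, n_heads, head_dim).transpose(1, 2)
out = F.scaled_dot_product_attention(q, kv, kv)
cls_token = out.transpose(1, 2).reshape(N, 1, d_model)
print(cls_token.shape)  # torch.Size([2, 1, 192])
```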
26 changes: 11 additions & 15 deletions vision_toolbox/backbones/convnext.py
@@ -5,13 +5,11 @@

from __future__ import annotations

-from functools import partial
-
import torch
from torch import Tensor, nn

from ..components import LayerScale, Permute, StochasticDepth
-from .base import BaseBackbone, _act, _norm
+from .base import BaseBackbone


class GlobalResponseNorm(nn.Module):
@@ -36,8 +34,7 @@ def __init__(
bias: bool = True,
layer_scale_init: float | None = 1e-6,
stochastic_depth: float = 0.0,
-        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
-        act: _act = nn.GELU,
+        norm_eps: float = 1e-6,
v2: bool = False,
) -> None:
if v2:
@@ -48,9 +45,9 @@
Permute(0, 3, 1, 2),
nn.Conv2d(d_model, d_model, 7, padding=3, groups=d_model, bias=bias),
Permute(0, 2, 3, 1),
-            norm(d_model),
+            nn.LayerNorm(d_model, norm_eps),
nn.Linear(d_model, hidden_dim, bias=bias),
-            act(),
+            nn.GELU(),
GlobalResponseNorm(hidden_dim) if v2 else nn.Identity(),
nn.Linear(hidden_dim, d_model, bias=bias),
LayerScale(d_model, layer_scale_init) if layer_scale_init is not None else nn.Identity(),
@@ -70,12 +67,11 @@ def __init__(
bias: bool = True,
layer_scale_init: float | None = 1e-6,
stochastic_depth: float = 0.0,
-        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
-        act: _act = nn.GELU,
+        norm_eps: float = 1e-6,
v2: bool = False,
) -> None:
super().__init__()
-        self.stem = nn.Sequential(nn.Conv2d(3, d_model, 4, 4), Permute(0, 2, 3, 1), norm(d_model))
+        self.stem = nn.Sequential(nn.Conv2d(3, d_model, 4, 4), Permute(0, 2, 3, 1), nn.LayerNorm(d_model, norm_eps))

stochastic_depth_rates = torch.linspace(0, stochastic_depth, sum(depths))
self.stages = nn.Sequential()
@@ -85,7 +81,7 @@ def __init__(
if stage_idx > 0:
# equivalent to PatchMerging in SwinTransformer
downsample = nn.Sequential(
-                    norm(d_model),
+                    nn.LayerNorm(d_model, norm_eps),
Permute(0, 3, 1, 2),
nn.Conv2d(d_model, d_model * 2, 2, 2),
Permute(0, 2, 3, 1),
@@ -97,12 +93,12 @@

for block_idx in range(depth):
rate = stochastic_depth_rates[sum(depths[:stage_idx]) + block_idx]
-                block = ConvNeXtBlock(d_model, expansion_ratio, bias, layer_scale_init, rate, norm, act, v2)
+                block = ConvNeXtBlock(d_model, expansion_ratio, bias, layer_scale_init, rate, norm_eps, v2)
stage.append(block)

self.stages.append(stage)

-        self.head_norm = norm(d_model)
+        self.norm = nn.LayerNorm(d_model, norm_eps)

def get_feature_maps(self, x: Tensor) -> list[Tensor]:
out = [self.stem(x)]
@@ -111,7 +107,7 @@ def get_feature_maps(self, x: Tensor) -> list[Tensor]:
return out[-1:]

def forward(self, x: Tensor) -> Tensor:
-        return self.head_norm(self.get_feature_maps(x)[-1].mean((1, 2)))
+        return self.norm(self.get_feature_maps(x)[-1].mean((1, 2)))

@staticmethod
def from_config(variant: str, v2: bool = False, pretrained: bool = False) -> ConvNeXt:
@@ -189,7 +185,7 @@ def copy_(m: nn.Conv2d | nn.Linear | nn.LayerNorm, prefix: str):

# FCMAE checkpoints don't contain head norm
if "norm.weight" in state_dict:
copy_(self.head_norm, "norm")
copy_(self.norm, "norm")
assert len(state_dict) == 2
else:
assert len(state_dict) == 0
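`GlobalResponseNorm` (used in the V2 path above) is only partially visible in this diff. Below is a self-contained sketch of GRN as described in the ConvNeXt V2 paper, for the channels-last `(N, H, W, C)` layout this file uses; the repo's exact implementation may differ in details such as the eps placement:

```python
import torch
from torch import Tensor, nn


class GlobalResponseNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        gx = x.norm(p=2, dim=(1, 2), keepdim=True)            # aggregate over H, W per channel
        nx = gx / (gx.mean(dim=-1, keepdim=True) + self.eps)  # divisive normalization over C
        return self.gamma * (x * nx) + self.beta + x          # calibrate + residual


x = torch.randn(2, 7, 7, 64)
print(GlobalResponseNorm(64)(x).shape)  # torch.Size([2, 7, 7, 64])
```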
13 changes: 4 additions & 9 deletions vision_toolbox/backbones/deit.py
@@ -4,13 +4,10 @@

from __future__ import annotations

-from functools import partial
-
import torch
from torch import Tensor, nn

from ..components import LayerScale
-from .base import _act, _norm
from .vit import ViT, ViTBlock


@@ -27,13 +24,12 @@ def __init__(
dropout: float = 0.0,
layer_scale_init: float | None = None,
stochastic_depth: float = 0.0,
-        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
-        act: _act = nn.GELU,
+        norm_eps: float = 1e-6,
) -> None:
# fmt: off
super().__init__(
d_model, depth, n_heads, patch_size, img_size, True, bias, mlp_ratio,
-            dropout, layer_scale_init, stochastic_depth, norm, act
+            dropout, layer_scale_init, stochastic_depth, norm_eps
)
# fmt: on
self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model))
@@ -133,13 +129,12 @@ def __init__(
dropout: float = 0.0,
layer_scale_init: float | None = 1e-6,
stochastic_depth: float = 0.0,
-        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
-        act: _act = nn.GELU,
+        norm_eps: float = 1e-6,
):
# fmt: off
super().__init__(
d_model, depth, n_heads, patch_size, img_size, cls_token, bias,
-            mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm, act
+            mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm_eps,
)
# fmt: on

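The `dist_token` added in this file's first hunk mirrors the class token: DeiT prepends both learnable tokens to the patch sequence before the transformer blocks. A hedged sketch of that concatenation (the exact token order in the repo is an assumption):

```python
import torch

N, d_model = 2, 192
patches = torch.randn(N, 196, d_model)  # 224 / 16 = 14, so 14 * 14 patch tokens
cls_token = torch.zeros(1, 1, d_model).expand(N, -1, -1)
dist_token = torch.zeros(1, 1, d_model).expand(N, -1, -1)  # matches the nn.Parameter init above

x = torch.cat([cls_token, dist_token, patches], dim=1)
print(x.shape)  # torch.Size([2, 198, 192])
```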
20 changes: 8 additions & 12 deletions vision_toolbox/backbones/mlp_mixer.py
@@ -3,15 +3,13 @@

from __future__ import annotations

-from functools import partial
from typing import Mapping

import numpy as np
import torch
from torch import Tensor, nn

from ..utils import torch_hub_download
-from .base import _act, _norm
from .vit import MLP


@@ -22,15 +20,14 @@ def __init__(
d_model: int,
mlp_ratio: tuple[int, int] = (0.5, 4.0),
dropout: float = 0.0,
-        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
-        act: _act = nn.GELU,
+        norm_eps: float = 1e-6,
) -> None:
tokens_mlp_dim, channels_mlp_dim = [int(d_model * ratio) for ratio in mlp_ratio]
super().__init__()
-        self.norm1 = norm(d_model)
-        self.token_mixing = MLP(n_tokens, tokens_mlp_dim, dropout, act)
-        self.norm2 = norm(d_model)
-        self.channel_mixing = MLP(d_model, channels_mlp_dim, dropout, act)
+        self.norm1 = nn.LayerNorm(d_model, norm_eps)
+        self.token_mixing = MLP(n_tokens, tokens_mlp_dim, dropout)
+        self.norm2 = nn.LayerNorm(d_model, norm_eps)
+        self.channel_mixing = MLP(d_model, channels_mlp_dim, dropout)

def forward(self, x: Tensor) -> Tensor:
# x -> (B, n_tokens, d_model)
@@ -48,17 +45,16 @@ def __init__(
img_size: int,
mlp_ratio: tuple[float, float] = (0.5, 4.0),
dropout: float = 0.0,
-        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
-        act: _act = nn.GELU,
+        norm_eps: float = 1e-6,
) -> None:
assert img_size % patch_size == 0
super().__init__()
self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size)
n_tokens = (img_size // patch_size) ** 2
self.layers = nn.Sequential(
-            *[MixerBlock(n_tokens, d_model, mlp_ratio, dropout, norm, act) for _ in range(n_layers)]
+            *[MixerBlock(n_tokens, d_model, mlp_ratio, dropout, norm_eps) for _ in range(n_layers)]
)
-        self.norm = norm(d_model)
+        self.norm = nn.LayerNorm(d_model, norm_eps)

def forward(self, x: Tensor) -> Tensor:
x = self.patch_embed(x).flatten(2).transpose(1, 2) # (N, C, H, W) -> (N, H*W, C)
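The forward comment above notes `x -> (B, n_tokens, d_model)`: token mixing transposes so its MLP runs across tokens, while channel mixing runs across features. A runnable miniature of that residual structure, with plain `nn.Linear` MLPs standing in for the repo's `MLP` class:

```python
import torch
from torch import Tensor, nn


class TinyMixerBlock(nn.Module):
    def __init__(self, n_tokens: int, d_model: int, norm_eps: float = 1e-6) -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model, norm_eps)
        self.token_mixing = nn.Sequential(nn.Linear(n_tokens, n_tokens), nn.GELU())
        self.norm2 = nn.LayerNorm(d_model, norm_eps)
        self.channel_mixing = nn.Sequential(nn.Linear(d_model, d_model), nn.GELU())

    def forward(self, x: Tensor) -> Tensor:  # x: (B, n_tokens, d_model)
        # token mixing acts along the token axis, so transpose in and out
        x = x + self.token_mixing(self.norm1(x).transpose(1, 2)).transpose(1, 2)
        x = x + self.channel_mixing(self.norm2(x))
        return x


x = torch.randn(2, 16, 32)
print(TinyMixerBlock(16, 32)(x).shape)  # torch.Size([2, 16, 32])
```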
30 changes: 14 additions & 16 deletions vision_toolbox/backbones/swin.py
@@ -9,7 +9,7 @@
import torch
from torch import Tensor, nn

-from .base import BaseBackbone, _act, _norm
+from .base import BaseBackbone
from .vit import MHA, ViTBlock


@@ -99,22 +99,21 @@ def __init__(
dropout: float = 0.0,
layer_scale_init: float | None = None,
stochastic_depth: float = 0.0,
-        norm: _norm = nn.LayerNorm,
-        act: _act = nn.GELU,
+        norm_eps: float = 1e-5,
) -> None:
# fmt: off
super().__init__(
d_model, n_heads, bias, mlp_ratio, dropout,
-            layer_scale_init, stochastic_depth, norm, act,
+            layer_scale_init, stochastic_depth, norm_eps,
partial(WindowAttention, input_size, d_model, n_heads, window_size, shift, bias, dropout),
)
# fmt: on


class PatchMerging(nn.Module):
-    def __init__(self, d_model: int, norm: _norm = nn.LayerNorm) -> None:
+    def __init__(self, d_model: int, norm_eps: float = 1e-5) -> None:
super().__init__()
-        self.norm = norm(d_model * 4)
+        self.norm = nn.LayerNorm(d_model * 4, norm_eps)
self.reduction = nn.Linear(d_model * 4, d_model * 2, False)

def forward(self, x: Tensor) -> Tensor:
@@ -139,22 +138,21 @@ def __init__(
dropout: float = 0.0,
layer_scale_init: float | None = None,
stochastic_depth: float = 0.0,
-        norm: _norm = nn.LayerNorm,
-        act: _act = nn.GELU,
+        norm_eps: float = 1e-5,
) -> None:
assert img_size % patch_size == 0
assert d_model % n_heads == 0
super().__init__()
self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size)
-        self.norm = norm(d_model)
+        self.patch_norm = nn.LayerNorm(d_model, norm_eps)
self.dropout = nn.Dropout(dropout)

input_size = img_size // patch_size
self.stages = nn.Sequential()
for i, (depth, window_size) in enumerate(zip(depths, window_sizes)):
stage = nn.Sequential()
if i > 0:
-                downsample = PatchMerging(d_model, norm)
+                downsample = PatchMerging(d_model, norm_eps)
input_size //= 2
d_model *= 2
n_heads *= 2
@@ -167,23 +165,23 @@
# fmt: off
block = SwinBlock(
input_size, d_model, n_heads, window_size, shift, mlp_ratio,
-                    bias, dropout, layer_scale_init, stochastic_depth, norm, act,
+                    bias, dropout, layer_scale_init, stochastic_depth, norm_eps,
)
# fmt: on
stage.append(block)

self.stages.append(stage)

-        self.head_norm = norm(d_model)
+        self.norm = nn.LayerNorm(d_model, norm_eps)

def get_feature_maps(self, x: Tensor) -> list[Tensor]:
-        out = [self.dropout(self.norm(self.patch_embed(x).permute(0, 2, 3, 1)))]
+        out = [self.dropout(self.patch_norm(self.patch_embed(x).permute(0, 2, 3, 1)))]
for stage in self.stages:
out.append(stage(out[-1]))
return out[1:]

def forward(self, x: Tensor) -> Tensor:
-        return self.head_norm(self.get_feature_maps(x)[-1]).mean((1, 2))
+        return self.norm(self.get_feature_maps(x)[-1]).mean((1, 2))

def resize_pe(self, img_size: int) -> None:
raise NotImplementedError()
@@ -222,7 +220,7 @@ def copy_(m: nn.Linear | nn.LayerNorm, prefix: str) -> None:
m.bias.copy_(state_dict.pop(prefix + ".bias"))

copy_(self.patch_embed, "patch_embed.proj")
copy_(self.norm, "patch_embed.norm")
copy_(self.patch_norm, "patch_embed.norm")

for stage_idx, stage in enumerate(self.stages):
if stage_idx > 0:
@@ -261,5 +259,5 @@ def rearrange(p):
copy_(block.mlp[1].linear1, prefix + "mlp.fc1")
copy_(block.mlp[1].linear2, prefix + "mlp.fc2")

copy_(self.head_norm, "norm")
copy_(self.norm, "norm")
assert len(state_dict) == 2 # head.weight and head.bias
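`PatchMerging.forward` is elided from this diff. A hedged sketch of the standard Swin 2x2 merge that matches the `norm`/`reduction` modules defined above, for channels-last `(N, H, W, C)` features with even H and W; the exact gather order in the repo is an assumption:

```python
import torch
from torch import Tensor, nn


class PatchMergingSketch(nn.Module):
    def __init__(self, d_model: int, norm_eps: float = 1e-5) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(d_model * 4, norm_eps)
        self.reduction = nn.Linear(d_model * 4, d_model * 2, False)

    def forward(self, x: Tensor) -> Tensor:  # x: (N, H, W, C)
        n, h, w, c = x.shape
        x = x.view(n, h // 2, 2, w // 2, 2, c).permute(0, 1, 3, 2, 4, 5)
        x = x.reshape(n, h // 2, w // 2, 4 * c)  # each 2x2 window -> one 4C vector
        return self.reduction(self.norm(x))      # 4C -> 2C, spatial halved


x = torch.randn(1, 8, 8, 96)
print(PatchMergingSketch(96)(x).shape)  # torch.Size([1, 4, 4, 192])
```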