In [1]:
import os
import sys

# Make the repository root (the notebook's parent directory) importable so that the
# local `swin`, `ssd` and `centernet` packages used below can be imported.
# NOTE(review): assumes the kernel's working directory is the folder containing this
# notebook — confirm if the notebook is launched from elsewhere.
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
# Guard so that re-running this cell does not append duplicate entries to sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)

In [2]:
import torch
import torch.nn as nn

from swin.transformer import SwinBlock, PatchMerging, StageModule, swin_t

# Swin Block

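Each Swin block consists of a residual window multi-head self-attention layer followed by a residual feed-forward (MLP) layer. Blocks are used in successive pairs: a regular-window block followed by a shifted-window block. A transformer stage stacks such pairs.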

![Swin Block](swin-block.png)

In [3]:
# Input image: 3 channels, 224 x 224 pixels.
img_height = img_width = 224
in_channels = 3

# Swin block hyper-parameters.
hidden_dim = 96   # embedding size of every patch after the patch partition
patch_size = 4    # patch edge in pixels; also the downscaling factor of the partition
window_size = 7   # edge of the attention window, counted in patches
num_heads = 3     # attention heads per block
head_dim = 32     # dimension of each attention head

In [4]:
# Split the image into non-overlapping patch_size x patch_size patches and embed each to hidden_dim.
patch_partition = PatchMerging(in_channels=in_channels, out_channels=hidden_dim, downscaling_factor=patch_size)

# The regular and shifted blocks share every hyper-parameter except `shifted`.
swin_block_kwargs = dict(
    embed_dim=hidden_dim,
    num_heads=num_heads,
    head_dim=head_dim,
    mlp_dim=hidden_dim * 4,
    window_size=window_size,
    relative_pos_embedding=True,
)
normal_block = SwinBlock(shifted=False, **swin_block_kwargs)
shifted_block = SwinBlock(shifted=True, **swin_block_kwargs)

In [5]:
# Push one random image through the patch partition and a regular/shifted block pair.
# Only shapes are inspected, so autograd is disabled to avoid building a graph.
x = torch.rand(1, in_channels, img_height, img_width)
with torch.no_grad():
    x = patch_partition(x)
    print("Patch partition", x.shape)
    x = normal_block(x)
    print("Swin Block", x.shape)
    x = shifted_block(x)
    print("Shifted Swin Block", x.shape)

# The patch partition yields (batch, H / patch_size, W / patch_size, hidden_dim),
# and both Swin blocks keep that shape.
expected_shape = (1, img_height // patch_size, img_width // patch_size, hidden_dim)
assert x.shape == expected_shape, f"unexpected shape {tuple(x.shape)}"

Patch partition torch.Size([1, 56, 56, 96])
Swin Block torch.Size([1, 56, 56, 96])
Shifted Swin Block torch.Size([1, 56, 56, 96])


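The patch partition turns the 224 × 224 image into a 56 × 56 grid of patches (224 / 4 = 56), each embedded as a 96-dimensional vector. The two Swin blocks keep this shape. Window attention (over 7 × 7 = 49-patch windows) and the relative position embedding mix information between patches without changing the number of patches or their dimension.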

# Transformer Stage

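Each stage downsamples its input by `downscaling_factor` and then applies `layers` Swin blocks. Across the four stages below, the spatial resolution shrinks from 56 × 56 to 7 × 7 while the channel count doubles at every stage (96 → 192 → 384 → 768).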

In [6]:
input_dim = 3      # image channels
hidden_dim = 96    # channels produced by the first stage (C)
output_dim = 21    # output size of the classification head

# Per-stage settings. Stage i outputs hidden_dim * 2**i channels (C, 2C, 4C, 8C).
# Keeping them in one table avoids the drift that copy-pasted StageModule(...) calls invite.
stage_layers = (2, 2, 6, 2)          # Swin blocks per stage
stage_heads = (3, 6, 12, 24)         # attention heads per stage
stage_downscaling = (4, 2, 2, 2)     # spatial downscaling factor of each stage

stages = []
channels = input_dim
for stage_idx, (n_layers, n_heads, factor) in enumerate(zip(stage_layers, stage_heads, stage_downscaling)):
    out_channels = hidden_dim * 2 ** stage_idx
    stages.append(StageModule(
        in_channels=channels,
        hidden_dim=out_channels,
        layers=n_layers,
        downscaling_factor=factor,
        num_heads=n_heads,
        head_dim=32,
        window_size=7,
        relative_pos_embedding=True))
    channels = out_channels
stage_1, stage_2, stage_3, stage_4 = stages

# Classification head; its input size is the last stage's channel count (hidden_dim * 8).
mlp_head = nn.Sequential(nn.LayerNorm(channels), nn.Linear(channels, output_dim))

In [7]:
# Run one random 224 x 224 image through the four stages and the classification head.
# Only shapes are inspected, so autograd is disabled to avoid building a graph.
x = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    for stage_idx, stage in enumerate((stage_1, stage_2, stage_3, stage_4), start=1):
        x = stage(x)
        print(f"Stage {stage_idx}:", x.shape)
    # Stage outputs are (batch, channels, H, W); averaging dims 2 and 3 is global
    # average pooling, leaving (batch, channels) for the head.
    x = x.mean(dim=[2, 3])
    print("Reduce x via averaging last 2 dimensions", x.shape)
    x = mlp_head(x)
    print("Output:", x.shape)

Stage 1: torch.Size([1, 96, 56, 56])
Stage 2: torch.Size([1, 192, 28, 28])
Stage 3: torch.Size([1, 384, 14, 14])
Stage 4: torch.Size([1, 768, 7, 7])
Reduce x via averaging last 2 dimensions torch.Size([1, 768])
Output: torch.Size([1, 21])


In [8]:
# Full tiny Swin model end to end; only the output shape is of interest, so no autograd graph is built.
tiny_swin_model = swin_t()
x = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    y = tiny_swin_model(x)
# Kept as the last top-level expression so Jupyter displays it.
y.shape

torch.Size([1, 21])

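# Transformer as Backbone

For detection, the backbone uses downscaling factors (2, 2, 2, 1) instead of (4, 2, 2, 2). The total stride drops from 32 to 8, so a 448 × 448 input yields a 56 × 56 feature map for the CenterNet heads to predict on.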

In [9]:
from swin.model import SwinTransformerBackbone, TransformerCenterNet
from ssd.model import SingleShotDetector
from centernet.model import CenterNet

In [10]:
# Backbone configuration shared by the packaged backbone and the stage-by-stage
# walkthrough below, so the two cannot drift apart. Previously the hand-built
# stage_3/stage_4 used layers/heads that did not match `layers`/`heads` here.
backbone_channels = 3
backbone_hidden_dim = 96
backbone_layers = (2, 2, 6, 2)
backbone_heads = (3, 6, 12, 24)
backbone_window_size = 7
backbone_downscaling = (2, 2, 2, 1)  # last stage keeps resolution: overall stride 8

backbone = SwinTransformerBackbone(
    channels=backbone_channels,
    hidden_dim=backbone_hidden_dim,
    layers=backbone_layers,
    heads=backbone_heads,
    window_size=backbone_window_size,
    downscaling_factors=backbone_downscaling
)

# Rebuild the same four stages by hand to show how the feature-map size evolves.
# NOTE(review): head_dim=32 and relative_pos_embedding=True are assumed to match what
# SwinTransformerBackbone uses internally — confirm in swin/model.py.
# NOTE: this rebinds stage_1..stage_4 (as the original cell did), replacing the
# classification stages defined earlier; re-run that cell before reusing them.
backbone_stages = []
channels = backbone_channels
for stage_idx, (n_layers, n_heads, factor) in enumerate(zip(backbone_layers, backbone_heads, backbone_downscaling)):
    out_channels = backbone_hidden_dim * 2 ** stage_idx
    backbone_stages.append(StageModule(
        in_channels=channels,
        hidden_dim=out_channels,
        layers=n_layers,
        downscaling_factor=factor,
        num_heads=n_heads,
        head_dim=32,
        window_size=backbone_window_size,
        relative_pos_embedding=True))
    channels = out_channels
stage_1, stage_2, stage_3, stage_4 = backbone_stages

# Shape walkthrough only, so autograd is disabled.
x = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    for stage in backbone_stages:
        x = stage(x)
        print(x.shape)

torch.Size([1, 96, 112, 112])
torch.Size([1, 192, 56, 56])
torch.Size([1, 384, 28, 28])
torch.Size([1, 768, 28, 28])


In [11]:
# Backbone feature map for a 448 x 448 input (overall stride 8); shape check only, so no autograd graph.
x = torch.rand(1, 3, 448, 448)
with torch.no_grad():
    features = backbone(x)
# Kept as the last top-level expression so Jupyter displays it.
features.shape

torch.Size([1, 768, 56, 56])

In [12]:
# Swin backbone + CenterNet-style heads: the model returns a classification map and a regression map.
transformer = TransformerCenterNet()
x = torch.rand(1, 3, 448, 448)
# Only output shapes are inspected, so autograd is disabled.
with torch.no_grad():
    cls, reg = transformer(x)
print("Classification", cls.shape)
print("Regression", reg.shape)

Classification torch.Size([1, 21, 56, 56])
Regression torch.Size([1, 4, 56, 56])


In [13]:
centernet = CenterNet()
ssd = SingleShotDetector()


def count_parameters(module):
    """Return the total number of scalar values across every tensor in `module.parameters()`."""
    return sum(param.numel() for param in module.parameters())


# Compare model sizes side by side.
for model_name, model in (("Transformer", transformer), ("Centernet", centernet), ("SSD", ssd)):
    print(model_name, count_parameters(model))

Transformer 30888497
Centernet 18917977
SSD 26284974
