In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from einops import rearrange, reduce, repeat
import torch.nn.functional as F

In [2]:
def shape_eq(tensor, shape):
    return tensor.shape == torch.Size(shape)

## ViT

### patch embeding

In [3]:
from argparse import Namespace
args = {
    'embed_dim': 768,
    'patch_size': 16,
    'num_heads': 8
}
args = Namespace(**args)
x = torch.rand(1, 3, 224, 224)
# patch embedding
proj = nn.Conv2d(3, args.embed_dim, # 16x16x3=768, 16x14=224, 14x14=196
    kernel_size=args.patch_size,
    stride=args.patch_size)
ln = nn.LayerNorm((196, 768))
x = rearrange(proj(x), 'b d h w -> b (h w) d') # 一张图片变成196个768维的token
x = ln(x)
x.shape

torch.Size([1, 196, 768])

### MHSA

In [4]:
head_dim = args.embed_dim // args.num_heads
scale = head_dim ** -0.5 # 1/sqrt(d_k)
qkv = nn.Linear(args.embed_dim, args.embed_dim*3)
attn_drop = nn.Dropout(0.5)
proj_fuse = nn.Linear(args.embed_dim, args.embed_dim)
fuse_drop = nn.Dropout(0.5)

# input [B N D] = [1, 196, 768]
residual = x
x = qkv(x)
print(x.shape)
x = rearrange(x, 'b n (m h d) -> m b h n d', m=3, h=args.num_heads)
x.shape

torch.Size([1, 196, 2304])


torch.Size([3, 1, 8, 196, 96])

In [5]:
q, k, v = torch.unbind(x, dim=0)
q.shape, k.shape, v.shape

(torch.Size([1, 8, 196, 96]),
 torch.Size([1, 8, 196, 96]),
 torch.Size([1, 8, 196, 96]))

In [6]:
attn = torch.einsum('bhik,bhjk->bhij', q, k) * scale # QK^T/sqrt(d_k)
assert shape_eq(attn, [1, 8, 196, 196])

In [7]:
attn = attn_drop(torch.softmax(attn, dim=-1)) # 归一化

In [8]:
v = torch.einsum('bhik,bhkj->bhij', attn, v) # softmax(QK^T/sqrt(d_k)) @ v
assert shape_eq(v, [1, 8, 196, 96])
v = rearrange(v, 'b h n d -> b n (h d)')
assert shape_eq(v, [1, 196, 768])

In [9]:
# 多头融合
x = fuse_drop(proj_fuse(v))
assert shape_eq(x, [1, 196, 768])

In [10]:
x = x + residual

### FFN

In [11]:
# input [B N D] = [1, 196, 768]
mlp = nn.Sequential(
    nn.Linear(args.embed_dim, args.embed_dim*2),
    nn.GELU(),
    nn.Dropout(0.5),
    nn.Linear(args.embed_dim*2, args.embed_dim),
    nn.Dropout(0.5)
)

residual = x
x = mlp(x)
x = x + residual
assert shape_eq(x, [1, 196, 768])

## Swin transformer

### Patch Embedding

In [12]:
class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    """
    def __init__(self,
        patch_size=4, # 下采样倍数
        in_c=3, 
        embed_dim=96, # C
        norm_layer=None
    ):
        super().__init__()
        patch_size = (patch_size, patch_size)
        self.patch_size = patch_size
        self.in_chans = in_c
        self.embed_dim = embed_dim
        # Patch partition & Linear projection
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, H, W = x.shape

        # padding
        # 如果输入图片的H，W不是patch_size的整数倍，需要进行padding
        pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
        if pad_input:
            # to pad the last 3 dimensions,
            # (W_left, W_right, H_top,H_bottom, C_front, C_back)
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
                          0, self.patch_size[0] - H % self.patch_size[0],
                          0, 0))

        # 下采样patch_size倍
        x = self.proj(x)
        _, _, H, W = x.shape
        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W

In [15]:
patch_proj = PatchEmbed(4, 3, 96)
x = torch.rand(1, 3, 224, 224)
tokens, H, W = patch_proj(x)
assert shape_eq(tokens, [1, 56*56,96])
H, W

(56, 56)

### Patch Merging

In [16]:
class PatchMerging(nn.Module):
    r""" Patch Merging Layer.
    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """
        x: B, H*W, C
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # padding
        # 如果输入feature map的H，W不是2的整数倍，需要进行padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            # to pad the last 3 dimensions, starting from the last dimension and moving forward.
            # (C_front, C_back, W_left, W_right, H_top, H_bottom)
            # 注意这里的Tensor通道是[B, H, W, C]，所以会和官方文档有些不同
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]
        x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]
        x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]
        x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]
        x = torch.cat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C]
        x = x.view(B, -1, 4 * C)  # [B, H/2*W/2, 4*C]

        x = self.norm(x)
        x = self.reduction(x)  # [B, H/2*W/2, 2*C]

        return x

In [73]:
pm = PatchMerging(96)
x = tokens # [1, 56*56, 96]
print(f'input shape: {x.shape}')
H=W=56
x = x.view(1, H, W, 96)
x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]
x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]
x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]
x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]
x = torch.cat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C]
x = x.view(1, -1, 4 * 96)  # [B, H/2*W/2, 4*C]
x.shape

input shape: torch.Size([1, 3136, 96])


torch.Size([1, 784, 384])

In [83]:
z = rearrange(tokens, 'b (h w) d -> b h w d', h=56)
z = rearrange(z, 'b (p1 h) (p2 w) d -> b (h w) (p1 p2 d)', p1=2, p2=2)
z.shape

torch.Size([1, 784, 384])

In [84]:
z == x

tensor([[[ True,  True,  True,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ...,  True,  True,  True]]])

In [65]:
y = pm(tokens, 56, 56)
y.shape

torch.Size([1, 784, 192])

In [70]:
x = torch.arange(16)
rearrange(x, '(p1 p2 h w) -> h w (p1 p2)', p1=2, p2=2, h=2)

tensor([[[ 0,  4,  8, 12],
         [ 1,  5,  9, 13]],

        [[ 2,  6, 10, 14],
         [ 3,  7, 11, 15]]])