In [1]:
# PatchEmbeddings
import torch
torch.manual_seed(43)
class PatchEmbeddings(torch.nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_tuple, self.patch_tuple = (img_size, img_size), (patch_size, patch_size)
        
        self.num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.embed_dim = in_chans * patch_size * patch_size
        
        self.proj = torch.nn.Conv2d(in_chans, self.embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_tuple[0] and W == self.img_tuple[1], \
        f"Input image size ({H}*{W}) doesn't match model ({self.img_tuple[0]}*{self.img_tuple[1]})."
        a = self.proj(x)      #torch.Size([8, 768, 14, 14])
        b = a.flatten(2)      #torch.Size([8, 768, 196])
        c = b.transpose(1, 2) #torch.Size([8, 196, 768])
        return c

sample_input = torch.randn(8, 3, 224, 224)
patch_embed = PatchEmbeddings(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
output1 = patch_embed(sample_input)
print(output1.shape)  

"""                  B  C   H    W
INPUT:   torch.Size([8, 3, 224, 224])
                     B P^2 Embed Dim
OUTPUT1: torch.Size([8, 196, 768])

A: OUTPUT --> Concatenate Positional Embeddings --> B: torch.Size([8, 197, 768])
C: Positional Embeddings torch.Size([8, 197, 768])
B + C = Combined Embeddings torch.Size([8, 197, 768])
"""

from timm.models.layers import PatchEmbed
patch_embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768, bias=False)
output2 = patch_embed(sample_input)
print(output2.shape)

print("Is equal: ", torch.equal(output1, output2))

  from .autonotebook import tqdm as notebook_tqdm


torch.Size([8, 196, 768])
torch.Size([8, 196, 768])
Is equal:  False


In [2]:
# Attention Layer
import torch
from einops import rearrange
class Attention(torch.nn.Module):
    def __init__(self, dim, heads=8, dim_head=64):
        super().__init__()
        inner_dim = heads * dim_head 
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm   = torch.nn.LayerNorm(dim)

        self.softmax = torch.nn.Softmax(dim=-1)

        self.to_qkv = torch.nn.Linear(dim, inner_dim*3, bias=False)
        self.to_out = torch.nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x):
        """
        INPUT:  torch.Size([8, 197, 768])
        dim=768, heads=8, dim_head=64
                            b   n   h*d*3
        to_qkv: torch.Size([8, 197, 1536])
                                     h*d
        chunk:  torch.Size([8, 197,  512]) * 3]
        """
        x = self.norm(x)
        qkv = self.to_qkv(x)
        qkv = qkv.chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
        
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.softmax(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

attention = Attention(768)
x = torch.randn(8, 197, 768)
output = attention(x)
print(f"INPUT:  {x.shape}\nOUTPUT: {output.shape}")

INPUT:  torch.Size([8, 197, 768])
OUTPUT: torch.Size([8, 197, 768])


In [3]:
# FeedForward Layer
import torch
class FeedForward(torch.nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.LayerNorm(dim),
            torch.nn.Linear(dim, hidden_dim),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_dim, dim),
        )
    def forward(self, x):
        return self.net(x)

x = torch.randn(8, 197, 768)
ff = FeedForward(768, 3072)
output = ff(x)
print(f"INPUT:  {x.shape}\nOUTPUT: {output.shape}")

INPUT:  torch.Size([8, 197, 768])
OUTPUT: torch.Size([8, 197, 768])


In [4]:
# Transformer Layer
import torch
class Transformer(torch.nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.transformer_layers = torch.nn.ModuleList([])
        for _ in range(depth):
            self.transformer_layers.append(
                torch.nn.ModuleList([
                    Attention(dim, heads = heads, dim_head = dim_head),
                    FeedForward(dim, mlp_dim)
                ])
            )
    def forward(self, x):
        for attn, ff in self.transformer_layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

x = torch.randn(8, 197, 768)
transformer = Transformer(768, 12, 8, 64, 3072)
output = transformer(x)
print(f"INPUT:  {x.shape}\nOUTPUT: {output.shape}")

INPUT:  torch.Size([8, 197, 768])
OUTPUT: torch.Size([8, 197, 768])


# Swin Transformer

In [5]:
# PatchMerging
import torch
class PatchMerging(torch.nn.Module):
    # PatchMerging作用： 输入分辨率减半，通道数翻倍
    # 这里的分辨率是转换成patches的分辨率，不是原图像素的分辨率
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (torch.nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=torch.nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.linear_layer = torch.nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
        
        #B: Batch; H: Patch_Height; W: Patch_Width; C: Channel
        x = x.view(B, H, W, C)

        # 0::2 偶数  1::2 奇数
        x0 = x[:, 0::2, 0::2, :]             # B, H/2, W/2, C
        x1 = x[:, 1::2, 0::2, :]             # B, H/2, W/2, C
        x2 = x[:, 0::2, 1::2, :]             # B, H/2, W/2, C
        x3 = x[:, 1::2, 1::2, :]             # B, H/2, W/2, C
        x = torch.cat([x0, x1, x2, x3], -1)  # B, H/2, W/2, 4*C
        x = x.view(B, -1, 4 * C)             # B, H/2*W/2, 4*C

        x = self.norm(x)
        x = self.linear_layer(x) 
        
        # H, W -> H/2, W/2 (image reduced by 3/4)
        # C -> 4C -> 2C

        return x
    
    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops

patch_size = 16
x = torch.randn(8, 196, 768)
patch_merging = PatchMerging((14, 14), 768)
output = patch_merging(x)
print(f"INPUT:  {x.shape}\nOUTPUT: {output.shape}")

from timm.models.swin_transformer import PatchMerging
output2 = PatchMerging((14, 14), 768)(x)
print(f"INPUT:  {x.shape}\nOUTPUT: {output2.shape}")

INPUT:  torch.Size([8, 196, 768])
OUTPUT: torch.Size([8, 49, 1536])
INPUT:  torch.Size([8, 196, 768])
OUTPUT: torch.Size([8, 49, 1536])


In [6]:
import torch
def window_partition(x, w_size):
    B, H, W, C = x.shape
    x = x.view(B, H//w_size, w_size, W//w_size, w_size, C)
    windows = x.permute(0,1,3,2,4,5).contiguous().view(-1, w_size, w_size, C)
    return windows

def window_reverse(windows, w_size, H, W):
    B = int(windows.shape[0] / (H * W / w_size / w_size))
    x = windows.view(B, H // w_size, W // w_size, w_size, w_size, -1)
    x = x.permute(0,1,3,2,4,5).contiguous().view(B, H, W, -1)
    return x

x = torch.randn(8, 14, 14, 64)
print(f"INPUT:   {x.shape}")
windows = window_partition(x, 7)
x = window_reverse(windows, 7, 14, 14)
print(f"W_SIZE:  {windows.shape}\nReverse: {x.shape}")

"""
INPUT  (B, H, W, C)
OUTPUT (num_win*B, H//w_size, W//w_size, w_size, w_size, C)
num_win = H//w_size * W//w_size
""" if False else None

INPUT:   torch.Size([8, 14, 14, 64])
W_SIZE:  torch.Size([32, 7, 7, 64])
Reverse: torch.Size([8, 14, 14, 64])


In [24]:
#W-MSA: Windows Attention
import torch
from typing import Optional
from timm.models.layers import trunc_normal_

def get_relative_position_index(win_h, win_w):
    # get pair-wise relative position index for each token inside the window
    coords = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)]))  # 2, Wh, Ww
    coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
    relative_coords[:, :, 0] += win_h - 1  # shift to start from 0
    relative_coords[:, :, 1] += win_w - 1
    relative_coords[:, :, 0] *= 2 * win_w - 1
    return relative_coords.sum(-1)  # Wh*Ww, Wh*Ww

class WindowAttention(torch.nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        head_dim (int): Number of channels per head (dim // num_heads if not set)
        window_size (tuple[int]): The height and width of the window.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, num_heads, head_dim=None, window_size=7, qkv_bias=True, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        win_h, win_w = self.window_size
        self.window_area = win_h * win_w
        self.num_heads = num_heads
        head_dim = head_dim or dim // num_heads
        attn_dim = head_dim * num_heads
        self.scale = head_dim ** -0.5

        # define a parameter table of relative position bias, shape: 2*Wh-1 * 2*Ww-1, nH
        self.relative_position_bias_table = torch.nn.Parameter(torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads))

        # get pair-wise relative position index for each token inside the window
        self.register_buffer("relative_position_index", get_relative_position_index(win_h, win_w))

        self.qkv       = torch.nn.Linear(dim, attn_dim * 3, bias=qkv_bias)
        self.attn_drop = torch.nn.Dropout(attn_drop)
        self.proj      = torch.nn.Linear(attn_dim, dim)
        self.proj_drop = torch.nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = torch.nn.Softmax(dim=-1)

    def _get_rel_pos_bias(self) -> torch.Tensor:
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)].view(self.window_area, self.window_area, -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        return relative_position_bias.unsqueeze(0)

    def forward(self, x, mask: Optional[torch.Tensor] = None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        attn = attn + self._get_rel_pos_bias()

        if mask is not None:
            num_win = mask.shape[0]
            attn = attn.view(B_ // num_win, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

x = torch.randn(8, 49, 96)
window_attention = WindowAttention(dim=96, num_heads=3, head_dim=None, window_size=(7, 7), qkv_bias=True, attn_drop=0.0, proj_drop=0.0)
output = window_attention(x)
print(output.shape)

torch.Size([8, 49, 96])


In [26]:
from timm.models.swin_transformer import WindowAttention

x = torch.randn(8, 49, 96)
window_attention = WindowAttention(dim=96, num_heads=3)
output = window_attention(x)
print(output.shape)

torch.Size([8, 49, 96])


In [7]:
import torch
from timm.models.swin_transformer import SwinTransformer

st = SwinTransformer(img_size=448, num_classes=10)
x = torch.randn(1, 3, 448, 448)
output = st(x)
print(output.shape)

torch.Size([1, 10])
