In [11]:
import torch
from torch import nn
import torch.nn.functional as F
import einops

## 1. Introduction

**NOTE: most implementations are based on https://github.com/lucidrains/vit-pytorch.**

To implement ViT, which can be used for image classification, we first need to implement the encoder of the Transformer.  

![architecture](assets/vit.png)

## 2. Implement Encoder of Transformer
The encoder contains two special modules, which are `FeedForward` and `Attention`.  
Let's implement them.

![](./assets/transformer.jpg)

## 2.1 Attention
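Attention is a module that determines which positions of the input each position should attend to.  
It maps three inputs, the query (Q), the key (K) and the value (V), to an output.  

In the self-attention used by the ViT encoder, Q, K and V are all linear projections of the same input sequence (the patch embeddings). In encoder-decoder attention, by contrast, Q comes from the decoder while K and V come from the encoder.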

![](./assets/attention.png)



### Scaled Dot-Product Attention
$Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V$

where $d_k$ is the dimension of the key vectors.

$\frac{QK^T}{\sqrt{d_k}}$ is the scaled similarity between Q and K.  
Attention retrieves the values (V), weighted by how well their keys (K) match the query (Q).

In [220]:
class OptionalMask(nn.Module):
    """Mask out attention scores between positions that should not interact.

    When no mask is given the scores are returned untouched, so the module can
    always be placed between the score computation and the softmax.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, mask=None):
        """Fill masked-out score entries with the most negative finite value.

        Args:
            x: attention scores of shape [batch_size, heads, n, n].
            mask: optional boolean tensor whose dimensions after the first
                flatten to ``n - 1`` entries (one per patch); ``True`` marks
                positions that may be attended to. Its first dimension must be
                ``batch_size`` or 1.

        Returns:
            ``x`` itself when ``mask`` is None; otherwise a new tensor shaped
            like ``x`` in which every (i, j) pair involving a masked-out
            position holds ``-finfo(x.dtype).max``, so that it gets
            (numerically) zero weight after the softmax.

        Raises:
            AssertionError: if the padded mask length differs from ``n``.
        """
        if mask is None:
            return x

        mask_value = -torch.finfo(x.dtype).max

        # Flatten the per-patch mask to [batch, n - 1] and prepend one always-True
        # entry for the extra leading token (the class token ViT concatenates in
        # front of the patch sequence).
        mask = F.pad(mask.flatten(1), (1, 0), value=True)
        assert mask.shape[-1] == x.shape[-1], "mask has incorrect dimensions"

        # Pairwise mask with an explicit singleton head axis: [batch, 1, n, n].
        # Without that axis the mask's batch dimension would be aligned with the
        # heads dimension of `x` and only a batch-1 mask would broadcast.
        pair_mask = mask[:, None, None, :] & mask[:, None, :, None]

        # Out-of-place fill so the caller's tensor is never modified.
        return x.masked_fill(~pair_mask, mask_value)

In [246]:
class ScaledDotProductAttention(nn.Module):
    """Compute softmax(Q K^T / sqrt(d_k)) V, with an optional mask on the scores."""

    def __init__(self, dk):
        super().__init__()
        self.opt_mask = OptionalMask()
        self.scale = dk ** -0.5  # multiplicative factor 1 / sqrt(d_k)

    def forward(self, q, k, v, mask=None):
        """Attend from every query position to every key position.

        Args:
            q, k: [batch_size, heads, num_patches, dk]
            v: [batch_size, heads, num_patches, dv]
            mask: optional mask handed to ``OptionalMask`` together with the scores.

        Returns:
            Tensor of shape [batch_size, heads, num_patches, dv].
        """
        # Scaled similarities Q K^T / sqrt(d_k): [batch_size, heads, num_patches, num_patches]
        scores = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale

        # Normalise over the key axis after masking.
        weights = self.opt_mask(scores, mask).softmax(dim=-1)

        # Weighted sum of the values.
        return torch.einsum("bhij, bhjd->bhid", weights, v)

In [247]:
# Shape check for ScaledDotProductAttention on random inputs.
batch_size = 5
heads = 16   # This value will be defined at Multi-Head Attention
num_patches = 65
dk = 32 # (=dq)
dv = 29

# Use `heads` (defined above); the previous `head` was an undefined name.
q = torch.randn(batch_size, heads, num_patches, dk)
k = torch.randn(batch_size, heads, num_patches, dk)
v = torch.randn(batch_size, heads, num_patches, dv)
attn = ScaledDotProductAttention(dk)
out = attn(q, k, v) # [batch_size, heads, num_patches, dv]
print(out.shape)

torch.Size([5, 16, 65, 29])


### Multi-Head Attention
$MultiHead(Q, K, V) = Concat(head_1, ..., head_h)W^{O}$

where  

$head_i = Attention(QW_{i}^{Q}, KW_{i}^{K}, VW_{i}^{V})$

where

- $W_{i}^{Q} \in {\bf R}^{d_{model} \times d_{k}}$: weight matrix for Q of head $i$
- $W_{i}^{K} \in {\bf R}^{d_{model} \times d_{k}}$: weight matrix for K of head $i$
- $W_{i}^{V} \in {\bf R}^{d_{model} \times d_{v}}$: weight matrix for V of head $i$
- $W^{O} \in {\bf R}^{hd_{v} \times d_{model}}$: output weight matrix applied to the concatenated heads

In [325]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention: project, attend per head, merge, project back.

    Q, K and V are all produced from the same input by one bias-free linear
    layer. For ease of implementation every head uses the same width for
    queries, keys and values (dk == dv == dim_head).

    Args:
        dim: size of the last dimension of the input and of the output.
        heads: number of attention heads (h).
        dim_head: width of each head.
        dropout: dropout probability applied after the output projection.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        self.heads = heads # number of heads in Multi-head attention layer (h)
        # Total width of all heads together. This must come from the `dim_head`
        # argument, not from a notebook-level global.
        inner_dim = dim_head * heads
        # Each head attends over `dim_head`-wide vectors, so that is the d_k
        # that defines the 1 / sqrt(d_k) scale.
        self.attn = ScaledDotProductAttention(dim_head)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask=None):
        """Apply multi-head self-attention.

        Args:
            x: [batch_size, num_patches, dim]
            mask: optional mask forwarded to the scaled dot-product attention.

        Returns:
            Tensor of shape [batch_size, num_patches, dim].
        """
        # Three tensors of shape [batch_size, num_patches, heads * dim_head]
        qkv = self.to_qkv(x).chunk(3, dim=-1)

        # Split the head axis out: [batch_size, heads, num_patches, dim_head]
        q, k, v = (
            einops.rearrange(t, 'b n (h d) -> b h n d', h=self.heads)
            for t in qkv
        )

        # [batch_size, heads, num_patches, dim_head]
        out = self.attn(q, k, v, mask)

        # Merge the heads again: [batch_size, num_patches, heads * dim_head]
        out = einops.rearrange(out, 'b h n d -> b n (h d)')

        # [batch_size, num_patches, dim]
        return self.to_out(out)

In [326]:
# Shape check for MultiHeadAttention on a random input sequence.
batch_size, heads, num_patches = 5, 16, 65
dim = 32   # size of the last dimension of the input and the output
dm = 1024  # passed as `dim_head`, i.e. the width of each attention head

x = torch.randn(batch_size, num_patches, dim)
attn = MultiHeadAttention(dim, heads=heads, dim_head=dm)
out = attn(x)  # [batch_size, num_patches, dim]
print(out.shape)

torch.Size([5, 65, 32])


## 2.2 Feed Forward

In [328]:
class Residual(nn.Module):
    """Wrap a callable with a skip connection: ``fn(x, **kwargs) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Run the wrapped callable on ``x`` and add ``x`` back onto the result."""
        transformed = self.fn(x, **kwargs)
        return transformed + x

class PreNorm(nn.Module):
    """Apply LayerNorm over the last ``dim`` features, then the wrapped callable."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Normalise ``x`` and pass it, with any keyword arguments, to ``fn``."""
        normed = self.norm(x)
        return self.fn(normed, **kwargs)
        
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: input and output feature size.
        hidden_dim: width of the hidden layer.
        dropout: dropout probability used after the activation and after the
            output projection.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        # Created in the same order as they sit in the Sequential container.
        expand = nn.Linear(dim, hidden_dim)
        project = nn.Linear(hidden_dim, dim)
        self.net = nn.Sequential(
            expand,
            nn.GELU(),
            nn.Dropout(dropout),
            project,
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        return self.net(x)

In [329]:
class Transformer(nn.Module):
    """Stack of ``depth`` encoder layers.

    Each layer is a pre-norm multi-head attention block followed by a pre-norm
    feed-forward block, both wrapped in a residual connection.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
        super().__init__()
        blocks = []
        for _ in range(depth):
            attention = MultiHeadAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout)
            feed_forward = FeedForward(dim, mlp_dim, dropout = dropout)
            blocks.append(nn.ModuleList([
                Residual(PreNorm(dim, attention)),
                Residual(PreNorm(dim, feed_forward)),
            ]))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x, mask = None):
        """Run ``x`` through every layer; the shape of ``x`` is preserved.

        ``mask`` is passed as a keyword argument to each attention block only.
        """
        for attention_block, feed_forward_block in self.layers:
            x = feed_forward_block(attention_block(x, mask = mask))
        return x

In [330]:
class ViT(nn.Module):
    """Vision Transformer classifier.

    The image is cut into square patches, each patch is flattened and linearly
    embedded, a learnable class token is prepended, learnable position
    embeddings are added, and the sequence is run through the Transformer
    encoder. The pooled representation goes through a LayerNorm + Linear head.

    Args:
        image_size: side length of the (square) input images, in pixels.
        patch_size: side length of each square patch, in pixels.
        num_classes: number of output logits.
        dim: embedding size used throughout the encoder.
        depth: number of encoder layers.
        heads: number of attention heads per layer.
        mlp_dim: hidden width of the feed-forward blocks.
        pool: ``'mean'`` averages over all tokens; any other value uses the
            first (class) token.
        channels: number of image channels.
        dim_head: width of each attention head.
        dropout: dropout probability inside the encoder.
        emb_dropout: dropout probability applied to the embedded sequence.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2
        self.patch_size = patch_size
        # One position embedding per patch plus one for the class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        self.pool = pool
        self.to_latent = nn.Identity()
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img, mask = None):
        """Classify a batch of images.

        Args:
            img: [batch, channels, height, width]; height and width must be
                multiples of ``patch_size``.
            mask: optional mask over patches, forwarded to the encoder.

        Returns:
            Logits of shape [batch, num_classes].
        """
        p = self.patch_size
        # Shape examples below use img = [5, 3, 256, 256], p = 32, dim = 1024.

        # x: [batch, nh * nw, cp]
        # where nh = height / p, nw = width / p (number of patches per side)
        # and cp = p * p * channels (all pixels of one patch).
        # [5, 256 / 32 * 256 / 32, 3 * 32 * 32] = [5, 64, 3072]
        x = einops.rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)

        # linear(cp, dim) = linear(3072, 1024) -> [5, 64, 1024]
        x = self.patch_to_embedding(x)

        b, n, _ = x.shape # batch size, number of patches

        # [1, 1, dim] -> [batch, 1, dim]
        cls_tokens = einops.repeat(self.cls_token, '() n d -> b n d', b = b)

        # [batch, 1, dim] + [batch, nh * nw, dim] -> [batch, nh * nw + 1, dim]
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)
        x = self.transformer(x, mask)

        # [batch, nh * nw + 1, dim] -> [batch, dim]
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
        x = self.to_latent(x)

        return self.mlp_head(x)  # [batch, num_classes]

In [332]:
# End-to-end shape check: five random 256x256 RGB images -> 1000 class logits.
img = torch.randn(5, 3, 256, 256)

# Optional mask designating which of the 8 x 8 patches to attend to (all of them here).
mask = torch.ones(1, 8, 8, dtype=torch.bool)

v = ViT(
    image_size=256,
    patch_size=32,
    num_classes=1000,
    dim=1024,  # last dimension of output tensor after linear transformation
    depth=6,
    heads=16,
    mlp_dim=2048,
    dropout=0.1,
    emb_dropout=0.1,
)

preds = v(img, mask=mask)  # [batch_size, num_classes]
print(preds.shape)

torch.Size([5, 1000])
