In [13]:
from utils import *

import torch
import torch.nn as nn
import torch.nn.functional as F

In [14]:
# Seed PyTorch's global RNG once, before anything random happens, so the input
# below and the randomly initialised layers created by the later cells are
# reproducible when the notebook is run top to bottom (Restart & Run All).
torch.manual_seed(0)

x = torch.randn(32,3,224,224) # b c h w
show(x)

torch.Size([32, 3, 224, 224])


In [15]:
PATCH_SIZE = 16

def patch_embedding(x, patch_size=PATCH_SIZE, embed_dim=None):
    """Turn an image batch into a ViT token sequence.

    A strided convolution (kernel_size == stride == patch_size) projects every
    non-overlapping patch to an ``embed_dim`` vector. A learnable class token is
    prepended. A learnable positional embedding is then added; without it,
    self-attention cannot tell which patch came from where.

    All layers/parameters are created fresh (randomly / zero initialised) on
    every call, so this demonstrates shapes and data flow, not trained weights.

    Args:
        x: image batch of shape (b, c, h, w); h and w must be divisible by
            ``patch_size``.
        patch_size: side length, in pixels, of each square patch.
        embed_dim: token width n_x; defaults to ``c * patch_size**2``, the
            number of values in one patch.

    Returns:
        Tensor of shape (b, 1 + n_t, embed_dim), where
        n_t = (h // patch_size) * (w // patch_size).
    """
    b, c, h, w = x.shape
    assert h % patch_size == 0, f"height {h} is not divisible by patch size {patch_size}"
    assert w % patch_size == 0, f"width {w} is not divisible by patch size {patch_size}"

    n_h = h // patch_size
    n_w = w // patch_size
    n_t = n_h * n_w  # number of patch tokens

    if embed_dim is None:
        embed_dim = c * patch_size**2

    conv_proj = nn.Conv2d(
        in_channels=c,
        out_channels=embed_dim,
        kernel_size=patch_size,
        stride=patch_size,
    )

    show(x, "input")

    x = conv_proj(x)     # b n_x n_h n_w
    show(x)

    x = x.flatten(2)     # b n_x n_t
    show(x)

    x = x.permute(0,2,1) # b n_t n_x
    show(x)

    cls_token = nn.Parameter(torch.zeros(1,1,embed_dim))
    cls_token = cls_token.expand(b,-1,-1)

    x = torch.cat([cls_token, x], dim=1) # b (1+n_t) n_x

    # Learnable absolute position embedding, one vector per sequence slot
    # (class token included), broadcast over the batch.
    pos_embedding = nn.Parameter(torch.empty(1, 1 + n_t, embed_dim).normal_(std=0.02))
    x = x + pos_embedding # b (1+n_t) n_x
    show(x, "output")

    return x

x = patch_embedding(x)

torch.Size([32, 3, 224, 224]) <- input
torch.Size([32, 768, 14, 14])
torch.Size([32, 768, 196])
torch.Size([32, 196, 768])
torch.Size([32, 197, 768]) <- output


In [16]:
def multi_head_attention(x, num_heads):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Q, K and V are linear projections of ``x`` and are split into
    ``num_heads`` heads of width n_x // num_heads. Each head computes
    softmax(Q K^T / sqrt(head_dim)) V. The heads are concatenated and then
    mixed by an output projection (W^O).

    All projections are created fresh (random init) on every call, so this
    illustrates shapes and data flow rather than a trained layer.

    Args:
        x: tokens of shape (b, n_t, n_x); n_x must be divisible by num_heads.
        num_heads: number of attention heads.

    Returns:
        Tensor of shape (b, n_t, n_x).
    """
    b, n_t, n_x = x.shape
    assert n_x % num_heads == 0, f"embedding dim {n_x} is not divisible by {num_heads} heads"
    head_dim = n_x // num_heads

    q_proj = nn.Linear(n_x, n_x)
    k_proj = nn.Linear(n_x, n_x)
    v_proj = nn.Linear(n_x, n_x)
    # Output projection W^O: without it the heads are merely concatenated and
    # never mixed with each other.
    out_proj = nn.Linear(n_x, n_x)

    show(x, "input")

    Q = q_proj(x)
    K = k_proj(x)
    V = v_proj(x)
    show(V)

    Q = Q.view(b, n_t, num_heads, -1) # b n_t h n_x'
    K = K.view(b, n_t, num_heads, -1) # b n_t h n_x'
    V = V.view(b, n_t, num_heads, -1) # b n_t h n_x'
    show(V)

    Q = Q.transpose(1,2) # b h n_t n_x'
    K = K.transpose(1,2) # b h n_t n_x'
    V = V.transpose(1,2) # b h n_t n_x'
    show(V)

    A = (Q @ K.transpose(-2,-1)) / head_dim**0.5 # (b h n_t n_x')@(b h n_x' n_t) = (b h n_t n_t)
    A = A.softmax(-1) # b h n_t n_t

    # No mask is applied: every query row is a softmax over ALL n_t keys and
    # takes a weighted mix of every row of V.
    # ---softmax---   --v--   --v--
    # ---softmax---   --v--   --v--
    # ---softmax--- @ --v-- = --v--
    # ---softmax---   --v--   --v--
    # ---softmax---   --v--   --v--

    # For contrast only (NOT done here): a causal / decoder model would mask A
    # to a lower-triangular pattern before the softmax, so row i only sees
    # tokens 0..i.
    # -               --v--   --v--
    # ----            --v--   --v--
    # ---so--       @ --v-- = --v--
    # ---soft---      --v--   --v--
    # ---softmax---   --v--   --v--

    V = A @ V # (b h n_t n_t)@(b h n_t n_x') = b h n_t n_x'
    show(V)

    V = V.transpose(1,2) # b n_t h n_x'
    show(V)

    V = V.contiguous().view(b, n_t, -1) # concatenate heads: b n_t n_x
    V = out_proj(V)                     # mix heads: b n_t n_x
    show(V, "output")

    return V


def attention_block(x, num_heads=8):
    """Pre-norm attention sub-layer: LayerNorm, then multi-head self-attention.

    The residual connection is not added here; ``encoder_block`` adds it.

    Args:
        x: tokens of shape (b, n_t, n_x); n_x must be divisible by num_heads.
        num_heads: number of attention heads (default 8).

    Returns:
        Tensor of shape (b, n_t, n_x).
    """
    _, _, n_x = x.shape  # unpacking also enforces a 3-D input
    norm = nn.LayerNorm(n_x)

    x = norm(x)
    x = multi_head_attention(x, num_heads=num_heads)
    return x


def mlp_block(x, mlp_ratio=4):
    """Pre-norm position-wise feed-forward sub-layer.

    Computes LayerNorm -> Linear(n_x, hidden) -> GELU -> Linear(hidden, n_x)
    independently for every token. The residual connection is added by
    ``encoder_block``. Layers are created fresh (random init) on every call.

    Args:
        x: tokens of shape (b, n_t, n_x).
        mlp_ratio: hidden width as a multiple of n_x; the default 4 is the
            usual Transformer/ViT expansion factor.

    Returns:
        Tensor of shape (b, n_t, n_x).
    """
    _, _, n_x = x.shape  # unpacking also enforces a 3-D input
    hidden_dim = int(n_x * mlp_ratio)

    norm = nn.LayerNorm(n_x)
    fc1 = nn.Linear(n_x, hidden_dim)
    fc2 = nn.Linear(hidden_dim, n_x)
    gelu = nn.GELU()

    x = norm(x)
    x = fc1(x)
    x = gelu(x)
    x = fc2(x)
    return x


def encoder_block(x):
    """One pre-norm Transformer encoder layer.

    Runs the attention sub-layer and then the MLP sub-layer, each wrapped in a
    residual (skip) connection, so the token shape (b, n_t, n_x) is unchanged.
    """
    after_attn = x + attention_block(x)
    after_mlp = after_attn + mlp_block(after_attn)
    return after_mlp


def encoder(x, num_layers=3):
    """Stack of ``num_layers`` pre-norm encoder blocks plus a final LayerNorm.

    Each pre-norm block normalises only the inputs of its sub-layers. The
    residual stream leaving the last block is therefore un-normalised, and the
    ViT encoder applies one last LayerNorm before the classification head.

    Args:
        x: tokens of shape (b, n_t, n_x).
        num_layers: number of encoder blocks to apply (default 3).

    Returns:
        Tensor of shape (b, n_t, n_x).
    """
    for _ in range(num_layers):
        x = encoder_block(x)
    final_norm = nn.LayerNorm(x.shape[-1])
    return final_norm(x)

encoded = encoder(x)
show(encoded)

# Pool the sequence by keeping only position 0 of the encoder output: the
# class token that patch_embedding prepended.
x = encoded[:, 0, :]
show(x)

torch.Size([32, 197, 768]) <- input
torch.Size([32, 197, 768])
torch.Size([32, 197, 8, 96])
torch.Size([32, 8, 197, 96])
torch.Size([32, 8, 197, 96])
torch.Size([32, 197, 8, 96])
torch.Size([32, 197, 768]) <- output
torch.Size([32, 197, 768]) <- input
torch.Size([32, 197, 768])
torch.Size([32, 197, 8, 96])
torch.Size([32, 8, 197, 96])
torch.Size([32, 8, 197, 96])
torch.Size([32, 197, 8, 96])
torch.Size([32, 197, 768]) <- output
torch.Size([32, 197, 768]) <- input
torch.Size([32, 197, 768])
torch.Size([32, 197, 8, 96])
torch.Size([32, 8, 197, 96])
torch.Size([32, 8, 197, 96])
torch.Size([32, 197, 8, 96])
torch.Size([32, 197, 768]) <- output
torch.Size([32, 197, 768])
torch.Size([32, 768])


In [17]:
def heads(x, num_classes=2, hidden_dim=100):
    """Classification head: Linear -> tanh -> Linear on the pooled representation.

    Layers are created fresh (random init) on every call. No softmax is
    applied, so the result is unnormalised per-class scores (logits).

    Args:
        x: pooled features of shape (b, n_x).
        num_classes: number of output classes (default 2).
        hidden_dim: width of the hidden layer (default 100).

    Returns:
        Tensor of shape (b, num_classes).
    """
    _, n_x = x.shape  # unpacking also enforces a 2-D input
    fc1 = nn.Linear(n_x, hidden_dim)
    fc2 = nn.Linear(hidden_dim, num_classes)
    tanh = nn.Tanh()

    x = fc1(x)
    x = tanh(x)
    x = fc2(x)
    return x

# Map the pooled class-token features to unnormalised class scores (logits).
# NOTE(review): this rebinds `x`, so re-running this cell applies the head to
# its own output — run the notebook top to bottom.
x = heads(x)
show(x)

torch.Size([32, 2])
