# Libraries

In [45]:
import torch
import torch.nn as nn
import math

# Parameters

In [71]:
# training parameters
B = 32 # batch size

# image parameters
C = 3
H = 128
W = 128
x = torch.rand(B, C, H, W)

#model parameters
D = 64 # hidden size
P = 4 #patch size
# nn.Unfold(kernel_size=P, stride=P) silently drops any remainder rows/columns,
# so require exact tiling to keep N consistent with the number of patches produced.
if H % P or W % P:
    raise ValueError(f"image size {H}x{W} must be divisible by the patch size P={P}")
N = (H // P) * (W // P) # number of patch tokens (exact integer arithmetic)
k = 4 # number of attention heads
# The k head outputs (Dh each) are concatenated back to D features, so D must split evenly.
if D % k:
    raise ValueError(f"hidden size D={D} must be divisible by the number of heads k={k}")
Dh = D // k # attention head size
p = 0.1 # dropout rate
mlp_size = D*4 # mlp size
L = 4 # number of transformer blocks
n_classes = 3 # number of classes

print("B:", B)
print("C:", C)
print("H:", H)
print("W:", W)
print("D:", D)
print("P:", P)
print("N:", N)
print("k:", k)
print("Dh:", Dh)
print("p:", p)
print("mlp_size:", mlp_size)
print("L:", L)
print("n_classes:", n_classes)

B: 32
C: 3
H: 128
W: 128
D: 64
P: 4
N: 1024
k: 4
Dh: 16
p: 0.1
mlp_size: 256
L: 4
n_classes: 3


# Embedding


In [60]:
# Image Embeddings [Patch, Class, with Position Embeddings]

class Embedding(nn.Module):
    """Turn a batch of images into the ViT input token sequence z0.

    Each image is cut into non-overlapping P x P patches (nn.Unfold). Every
    flattened patch is linearly projected to D features. A learnable class
    token is prepended, and a learnable position embedding is added.
    Relies on the module-level hyper-parameters P, C, D, N and p.
    """

    def __init__(self):
        super(Embedding, self).__init__()

        self.unfold = nn.Unfold(kernel_size=P, stride=P) # function to create patch vectors (x_p^i)
        self.project = nn.Linear(P**2 * C, D) # patch tokens (E)
        # Class token (x_class) and position embedding (E_pos) are stored with a
        # batch dimension of 1 and expanded to the real batch size in forward().
        # Storing them per-batch would only replicate the same values and tie the
        # parameter shape to one particular batch size.
        self.cls_token = nn.Parameter(torch.randn((1, 1, D))) # unbatched class token (x_class)
        self.pos_embedding = nn.Parameter(torch.randn(1, N+1, D)) # unbatched position embedding (E_pos)
        self.dropout = nn.Dropout(p) #dropout

    def forward(self, x):
        """Embed images ``x`` of shape (batch, C, H, W).

        Returns:
            z0 of shape (batch, N + 1, D): the class token followed by the N
            projected patch tokens, plus position embeddings, after dropout.

        Raises:
            ValueError: if the images do not yield exactly N patches.
        """
        print("###Embedding###")
        print("input image:", x.shape)
        # Take the batch size from the input instead of the global B, so any
        # batch size (e.g. a smaller last batch) works.
        batch_size = x.shape[0]
        x = self.unfold(x).transpose(1,2) # patch vectors (x_p^i)
        print("x_p^i:", x.shape)
        num_patches = x.shape[1]
        expected_patches = self.pos_embedding.shape[1] - 1
        if num_patches != expected_patches:
            raise ValueError(
                f"got {num_patches} patches but the position embedding expects "
                f"{expected_patches}; check the image size against H, W and P"
            )
        x = self.project(x)
        print("x_p^i*E: ", x.shape) # tokens for patches (x_p^i*E)
        cls_token = self.cls_token # unbatched class token (x_class)
        print("unbatched x_class:", cls_token.shape)
        cls_token = cls_token.expand(batch_size, -1, -1) # batched class token (x_class)
        print("x_class:", cls_token.shape)
        x = torch.cat((cls_token, x), dim = 1) # final image token embedding
        print("patch embedding:", x.shape)
        pos_embedding = self.pos_embedding # unbatched position embedding (E_pos)
        print("unbatched E_pos:", pos_embedding.shape)
        pos_embedding = pos_embedding.expand(batch_size, -1, -1) # batched position embedding (E_pos)
        print("E_pos:", pos_embedding.shape)
        z0 = x + pos_embedding # adding the batched position and image embedding
        print("z0:", z0.shape)
        z0 = self.dropout(z0) # dropout
        return z0

In [61]:
# Smoke test: run the embedding on the random batch `x` and print the shape
# produced at each stage of the patch/class/position embedding.
model = Embedding()
y = model(x)

###Embedding###
input image: torch.Size([32, 3, 128, 128])
x_p^i: torch.Size([32, 1024, 48])
x_p^i*E:  torch.Size([32, 1024, 64])
unbatched x_class: torch.Size([1, 1, 64])
x_class: torch.Size([32, 1, 64])
patch embedding: torch.Size([32, 1025, 64])
unbatched E_pos: torch.Size([1, 1025, 64])
E_pos: torch.Size([32, 1025, 64])
z0: torch.Size([32, 1025, 64])


# Single Head Attention

In [62]:
# Single Head Attention


class Single_Head_Attention(nn.Module):
    """A single scaled dot-product self-attention head.

    One fused linear layer (U_qkv) maps D-dimensional tokens to queries, keys
    and values of Dh features each. The head then returns
    softmax(q k^T / sqrt(Dh)) v.
    """

    def __init__(self):
        super().__init__()

        self.U_qkv = nn.Linear(D, 3 * Dh)  # fused projection producing [q | k | v]
        self.softmax = nn.Softmax(dim=-1)  # normalise scores over the key axis

    def forward(self, z):
        """Attend over the tokens of ``z`` (batch, tokens, D); return (batch, tokens, Dh)."""
        print("###Single Head Attention###")
        print("z:", z.shape)
        projected = self.U_qkv(z)
        print("qkv:", projected.shape)
        # The fused output stores q, k and v back to back, Dh features each.
        query, key, value = (projected[:, :, i * Dh:(i + 1) * Dh] for i in range(3))
        print("q:", query.shape)
        print("k:", key.shape)
        print("v:", value.shape)
        scores = query @ key.transpose(-2, -1) / math.sqrt(Dh)  # scaled similarities
        print("qkTbysqrtDh:", scores.shape)
        weights = self.softmax(scores)
        print("A:", weights.shape)
        attended = weights @ value
        print("SA(z):", attended.shape)
        return attended

In [63]:
# Feed the embedded tokens through one attention head; the printed shapes show
# the head mapping the last dimension from D to Dh.
model = Embedding()
y = model(x)
model = Single_Head_Attention()
z = model(y)

###Embedding###
input image: torch.Size([32, 3, 128, 128])
x_p^i: torch.Size([32, 1024, 48])
x_p^i*E:  torch.Size([32, 1024, 64])
unbatched x_class: torch.Size([1, 1, 64])
x_class: torch.Size([32, 1, 64])
patch embedding: torch.Size([32, 1025, 64])
unbatched E_pos: torch.Size([1, 1025, 64])
E_pos: torch.Size([32, 1025, 64])
z0: torch.Size([32, 1025, 64])
###Single Head Attention###
z: torch.Size([32, 1025, 64])
qkv: torch.Size([32, 1025, 48])
q: torch.Size([32, 1025, 16])
k: torch.Size([32, 1025, 16])
v: torch.Size([32, 1025, 16])
qkTbysqrtDh: torch.Size([32, 1025, 1025])
A: torch.Size([32, 1025, 1025])
SA(z): torch.Size([32, 1025, 16])


# Multi Head Self Attention


In [64]:
# Multi Head Self Attention

class Multi_Head_Self_Attention(nn.Module):
    """Multi-head self-attention: k independent heads mixed by U_msa.

    Each Single_Head_Attention maps (batch, tokens, D) to (batch, tokens, Dh).
    The k head outputs are concatenated on the last axis, projected by U_msa
    and passed through dropout. U_msa takes D inputs, so k * Dh must equal D.
    """

    def __init__(self):
        super(Multi_Head_Self_Attention, self).__init__()

        self.heads = nn.ModuleList([Single_Head_Attention() for _ in range(k)]) # k heads
        self.U_msa = nn.Linear(D, D) # U_msa
        self.dropout = nn.Dropout(p) #dropout

    def forward(self, z):
        """Return MSA(z) with the same (batch, tokens, D) shape as ``z``."""
        print("###Multi Head Attention###")
        print("z:", z.shape)
        ConSAz = torch.cat([head(z) for head in self.heads], dim = -1) # [SA_1(z); ...; SA_k(z)]
        print("ConSA(z):", ConSAz.shape)
        # Project the concatenated head outputs (not the raw input z); otherwise
        # the attention heads would be computed and discarded.
        msaz = self.U_msa(ConSAz) # MSA(z)
        print("MSA(z):", msaz.shape)
        msaz = self.dropout(msaz) # dropout

        return msaz

In [65]:
# Feed the embedded tokens through multi-head self-attention (k heads); the
# printed shapes trace each head and the final projection.
model = Embedding()
y = model(x)
model = Multi_Head_Self_Attention()
z = model(y)

###Embedding###
input image: torch.Size([32, 3, 128, 128])
x_p^i: torch.Size([32, 1024, 48])
x_p^i*E:  torch.Size([32, 1024, 64])
unbatched x_class: torch.Size([1, 1, 64])
x_class: torch.Size([32, 1, 64])
patch embedding: torch.Size([32, 1025, 64])
unbatched E_pos: torch.Size([1, 1025, 64])
E_pos: torch.Size([32, 1025, 64])
z0: torch.Size([32, 1025, 64])
###Multi Head Attention###
z: torch.Size([32, 1025, 64])
###Single Head Attention###
z: torch.Size([32, 1025, 64])
qkv: torch.Size([32, 1025, 48])
q: torch.Size([32, 1025, 16])
k: torch.Size([32, 1025, 16])
v: torch.Size([32, 1025, 16])
qkTbysqrtDh: torch.Size([32, 1025, 1025])
A: torch.Size([32, 1025, 1025])
SA(z): torch.Size([32, 1025, 16])
###Single Head Attention###
z: torch.Size([32, 1025, 64])
qkv: torch.Size([32, 1025, 48])
q: torch.Size([32, 1025, 16])
k: torch.Size([32, 1025, 16])
v: torch.Size([32, 1025, 16])
qkTbysqrtDh: torch.Size([32, 1025, 1025])
A: torch.Size([32, 1025, 1025])
SA(z): torch.Size([32, 1025, 16])
###Single 

# MLP

In [66]:
# MLP

class MLP(nn.Module):
    """Transformer feed-forward sub-layer.

    Computes Dropout(U_mlp2(Dropout(GELU(U_mlp(z))))): expand D -> mlp_size,
    apply a single GELU, then project back to D. There is no activation after
    the output projection, matching the two-layer, one-GELU MLP of ViT.
    """

    def __init__(self):
        super(MLP, self).__init__()

        self.U_mlp = nn.Linear(D, mlp_size)
        self.gelu = nn.GELU()
        self.U_mlp2 = nn.Linear(mlp_size, D)
        self.dropout = nn.Dropout(p)

    def forward(self, z):
        """Map tokens ``z`` (..., D) to (..., D)."""
        print("###MLP###")
        print("z:", z.shape)
        z = self.U_mlp(z) # mlp
        print("mlp(z):", z.shape)
        z = self.gelu(z) # gelu
        print("gelu(mlp(z)):", z.shape)
        z = self.dropout(z) # dropout
        z = self.U_mlp2(z) # mlp2
        print("mlp2(gelu(mlp(z))):", z.shape)
        # The output stays linear: a GELU here would clip negative values before
        # the residual addition in the transformer block.
        z = self.dropout(z) # dropout

        return z

In [67]:
# Chain embedding -> multi-head self-attention -> MLP. Every module here is
# freshly (randomly) initialised, so this cell checks shapes, not learned behaviour.
model = Embedding()
y = model(x)
model = Multi_Head_Self_Attention()
z = model(y)
model = MLP()
z = model(z)

###Embedding###
input image: torch.Size([32, 3, 128, 128])
x_p^i: torch.Size([32, 1024, 48])
x_p^i*E:  torch.Size([32, 1024, 64])
unbatched x_class: torch.Size([1, 1, 64])
x_class: torch.Size([32, 1, 64])
patch embedding: torch.Size([32, 1025, 64])
unbatched E_pos: torch.Size([1, 1025, 64])
E_pos: torch.Size([32, 1025, 64])
z0: torch.Size([32, 1025, 64])
###Multi Head Attention###
z: torch.Size([32, 1025, 64])
###Single Head Attention###
z: torch.Size([32, 1025, 64])
qkv: torch.Size([32, 1025, 48])
q: torch.Size([32, 1025, 16])
k: torch.Size([32, 1025, 16])
v: torch.Size([32, 1025, 16])
qkTbysqrtDh: torch.Size([32, 1025, 1025])
A: torch.Size([32, 1025, 1025])
SA(z): torch.Size([32, 1025, 16])
###Single Head Attention###
z: torch.Size([32, 1025, 64])
qkv: torch.Size([32, 1025, 48])
q: torch.Size([32, 1025, 16])
k: torch.Size([32, 1025, 16])
v: torch.Size([32, 1025, 16])
qkTbysqrtDh: torch.Size([32, 1025, 1025])
A: torch.Size([32, 1025, 1025])
SA(z): torch.Size([32, 1025, 16])
###Single 

# Transformer Block

In [68]:
# Transformer Block

class Transformer_Block(nn.Module):
    """Pre-norm transformer encoder block.

    Computes z' = z + MSA(LN1(z)) and then z'' = z' + MLP(LN2(z')).
    Both sub-layers return D features, so the residual additions line up.
    """

    def __init__(self):
        super().__init__()

        # Registration order is kept as-is: it fixes the parameter/state_dict order.
        self.layernorm_1 = nn.LayerNorm(D)
        self.msa = Multi_Head_Self_Attention()
        self.layernorm_2 = nn.LayerNorm(D)
        self.mlp = MLP()

    def forward(self, z):
        """Apply the attention and feed-forward sub-layers, each with a skip connection."""
        print("###Transformer Block###")
        print("z:", z.shape)

        # attention sub-layer
        hidden = self.layernorm_1(z)
        print("layernorm_1(z):", hidden.shape)
        hidden = self.msa(hidden)
        print("msa(layernorm_1(z)):", hidden.shape)
        after_attention = z + hidden
        print("z + msa(layernorm_1(z)):", after_attention.shape)

        # feed-forward sub-layer
        hidden = self.layernorm_2(after_attention)
        print("layernorm_2(z + msa(layernorm_1(z))):", hidden.shape)
        hidden = self.mlp(hidden)
        print("mlp(layernorm_2(z + msa(layernorm_1(z)))):", hidden.shape)
        output = after_attention + hidden
        print("z2 + mlp(layernorm_2(z + msa(layernorm_1(z)))):", output.shape)
        return output

In [69]:
# Run a single transformer block on the embedded tokens and print the shape
# after each sub-layer and residual addition.
model = Embedding()
y = model(x)
model = Transformer_Block()
z = model(y)

###Embedding###
input image: torch.Size([32, 3, 128, 128])
x_p^i: torch.Size([32, 1024, 48])
x_p^i*E:  torch.Size([32, 1024, 64])
unbatched x_class: torch.Size([1, 1, 64])
x_class: torch.Size([32, 1, 64])
patch embedding: torch.Size([32, 1025, 64])
unbatched E_pos: torch.Size([1, 1025, 64])
E_pos: torch.Size([32, 1025, 64])
z0: torch.Size([32, 1025, 64])
###Transformer Block###
z: torch.Size([32, 1025, 64])
layernorm_1(z): torch.Size([32, 1025, 64])
###Multi Head Attention###
z: torch.Size([32, 1025, 64])
###Single Head Attention###
z: torch.Size([32, 1025, 64])
qkv: torch.Size([32, 1025, 48])
q: torch.Size([32, 1025, 16])
k: torch.Size([32, 1025, 16])
v: torch.Size([32, 1025, 16])
qkTbysqrtDh: torch.Size([32, 1025, 1025])
A: torch.Size([32, 1025, 1025])
SA(z): torch.Size([32, 1025, 16])
###Single Head Attention###
z: torch.Size([32, 1025, 64])
qkv: torch.Size([32, 1025, 48])
q: torch.Size([32, 1025, 16])
k: torch.Size([32, 1025, 16])
v: torch.Size([32, 1025, 16])
qkTbysqrtDh: torch.Si

# ViT (everything together)

In [73]:
# ViT

class ViT(nn.Module):
    """Vision Transformer classifier.

    Pipeline: patch/position embedding -> L transformer blocks -> LayerNorm.
    The normalised class token is then passed to a linear head that produces
    n_classes unnormalised scores per image.
    """

    def __init__(self):
        super().__init__()

        # Keep this construction order: it fixes parameter order and random init order.
        self.embedding = Embedding()
        self.transformer_encoder = nn.ModuleList([Transformer_Block() for _ in range(L)])
        self.layernorm = nn.LayerNorm(D)
        self.U_mlp = nn.Linear(D, n_classes)

    def forward(self, x):
        """Classify images ``x``; returns a (batch, n_classes) tensor."""
        print("###ViT###")
        print("input image:", x.shape)
        tokens = self.embedding(x)
        print("z:", tokens.shape)
        for encoder_block in self.transformer_encoder:
            tokens = encoder_block(tokens)
        print("z:", tokens.shape)
        tokens = self.layernorm(tokens)
        print("layernorm(z):", tokens.shape)
        cls_features = tokens[:, 0]  # class-token representation (position 0)
        print("z:", cls_features.shape)
        scores = self.U_mlp(cls_features)
        print("mlp(layernorm(z)):", scores.shape)
        return scores

In [74]:
# End-to-end forward pass of the full ViT on the random batch `x`; `y` holds
# one row of n_classes unnormalised class scores per image.
model = ViT()
y = model(x)

###ViT###
input image: torch.Size([32, 3, 128, 128])
###Embedding###
input image: torch.Size([32, 3, 128, 128])
x_p^i: torch.Size([32, 1024, 48])
x_p^i*E:  torch.Size([32, 1024, 64])
unbatched x_class: torch.Size([1, 1, 64])
x_class: torch.Size([32, 1, 64])
patch embedding: torch.Size([32, 1025, 64])
unbatched E_pos: torch.Size([1, 1025, 64])
E_pos: torch.Size([32, 1025, 64])
z0: torch.Size([32, 1025, 64])
z: torch.Size([32, 1025, 64])
###Transformer Block###
z: torch.Size([32, 1025, 64])
layernorm_1(z): torch.Size([32, 1025, 64])
###Multi Head Attention###
z: torch.Size([32, 1025, 64])
###Single Head Attention###
z: torch.Size([32, 1025, 64])
qkv: torch.Size([32, 1025, 48])
q: torch.Size([32, 1025, 16])
k: torch.Size([32, 1025, 16])
v: torch.Size([32, 1025, 16])
qkTbysqrtDh: torch.Size([32, 1025, 1025])
A: torch.Size([32, 1025, 1025])
SA(z): torch.Size([32, 1025, 16])
###Single Head Attention###
z: torch.Size([32, 1025, 64])
qkv: torch.Size([32, 1025, 48])
q: torch.Size([32, 1025, 16]