## Let's Build a Toy Model
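In this notebook we will:
 - Create a 4x4 toy image
 - Split it into 2x2 patches (4 patches total)
 - Flatten each patch into a vector of length 4
 - Apply a linear projection that maps each 4-dimensional patch vector to an 8-dimensional embedding
 - Add learnable positional embeddings
 - Pass the result through a single self-attention head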
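The bookkeeping behind these steps is simple arithmetic. An H x W image with C channels, cut into non-overlapping P x P patches, yields (H/P)·(W/P) patches, each flattened to C·P·P values. The helper below (`patch_grid` is a hypothetical name used only for illustration, not a library function) makes that concrete; the 224x224 RGB image with 16x16 patches is the setting commonly used for ViT-Base.

```python
def patch_grid(height, width, channels, patch_size):
    """Return (num_patches, patch_dim) for non-overlapping square patches."""
    if height % patch_size or width % patch_size:
        raise ValueError("image size must be divisible by the patch size")
    num_patches = (height // patch_size) * (width // patch_size)
    patch_dim = channels * patch_size * patch_size  # length of each flattened patch
    return num_patches, patch_dim

print(patch_grid(4, 4, 1, 2))       # this toy example: (4, 4)
print(patch_grid(224, 224, 3, 16))  # 224x224 RGB image, 16x16 patches: (196, 768)
```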

In [1]:
import torch
import torch.nn as nn

### Make a toy image
A 4x4 grayscale (single-channel) image with pixel values 0 through 15, so every pixel is easy to track through the patching steps.

In [2]:
img = torch.arange(16).reshape(4,4).float()
print(img)

tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]])


### Split into Patches
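Use a patch size of 2x2. `unfold(dim, size, step)` slides a window of length `size` along dimension `dim` with stride `step`; applying it to both dimensions with size = step = 2 produces a 2x2 grid of non-overlapping 2x2 patches.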

In [3]:
patches = img.unfold(0,2,2).unfold(1,2,2)
print(patches.shape)

torch.Size([2, 2, 2, 2])


In [4]:
print(patches)

tensor([[[[ 0.,  1.],
          [ 4.,  5.]],

         [[ 2.,  3.],
          [ 6.,  7.]]],


        [[[ 8.,  9.],
          [12., 13.]],

         [[10., 11.],
          [14., 15.]]]])
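To see what the two `unfold` calls do, write out the indexing: entry `[i, j, k, l]` of the result is the pixel `img[2*i + k, 2*j + l]`, i.e. row offset `k` and column offset `l` inside patch `(i, j)`. The standalone sketch below rebuilds the toy image and checks this against a reshape-and-permute formulation, a common alternative way to patchify.

```python
import torch

img = torch.arange(16).reshape(4, 4).float()
p = 2  # patch size

# unfold(dim, size, step): window of length `size` slid along `dim` with stride `step`
via_unfold = img.unfold(0, p, p).unfold(1, p, p)  # (patch_row, patch_col, p, p)

# Same result with reshape + permute:
# (4, 4) -> (2, p, 2, p) splits each axis into (patch index, offset inside patch),
# then permute moves the two patch-index axes to the front
via_reshape = img.reshape(4 // p, p, 4 // p, p).permute(0, 2, 1, 3)

print(torch.equal(via_unfold, via_reshape))  # True
print(via_unfold[0, 1])                      # top-right patch: [[2., 3.], [6., 7.]]
```

Both give the same patches here. `unfold` is the more general tool, since it also handles overlapping windows (step < size), which the reshape trick cannot express.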


### Flatten Patches

In [5]:
# unfold returns a strided (non-contiguous) view, so copy it before calling view()
patches = patches.contiguous().view(-1,2*2)
print(patches)

tensor([[ 0.,  1.,  4.,  5.],
        [ 2.,  3.,  6.,  7.],
        [ 8.,  9., 12., 13.],
        [10., 11., 14., 15.]])


In [6]:
patches.shape

torch.Size([4, 4])
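Why `.contiguous()` before `.view()`? `unfold` returns a view with non-standard strides, and `view` only succeeds when the requested shape is compatible with the existing strides. This standalone sketch shows the plain `view` call failing and `reshape` producing the same flattened patches (it copies only when it has to).

```python
import torch

img = torch.arange(16).reshape(4, 4).float()
patches = img.unfold(0, 2, 2).unfold(1, 2, 2)  # (2, 2, 2, 2) strided view of img

print(patches.is_contiguous())  # False

# view() cannot merge dimensions whose strides are incompatible, so this raises
view_failed = False
try:
    patches.view(-1, 4)
except RuntimeError:
    view_failed = True
print(view_failed)  # True

flat_a = patches.contiguous().view(-1, 4)  # copy to standard layout, then view
flat_b = patches.reshape(-1, 4)            # reshape copies only when it must
print(torch.equal(flat_a, flat_b))  # True
print(flat_a[0])                    # first patch: tensor([0., 1., 4., 5.])
```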

### Linear Projection (Patch Embedding)

In [7]:
# Project each 4-dim flattened patch to an 8-dim embedding (in_features=4, out_features=8)
linear = nn.Linear(4,8)
embedded = linear(patches)
print("Shape: ", embedded.shape)
print(embedded)

Shape:  torch.Size([4, 8])
tensor([[-0.7326, -0.2139, -0.4815,  0.9404, -2.4267, -1.5183,  0.8126, -1.3019],
        [-0.2643, -1.4233, -1.2799,  0.9261, -3.5453, -2.5814,  1.9590, -2.5639],
        [ 1.1404, -5.0516, -3.6749,  0.8834, -6.9011, -5.7708,  5.3981, -6.3500],
        [ 1.6087, -6.2611, -4.4732,  0.8691, -8.0196, -6.8340,  6.5445, -7.6121]],
       grad_fn=<AddmmBackward0>)
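A linear layer applied to flattened non-overlapping patches is equivalent to a `Conv2d` whose kernel size and stride both equal the patch size, which is how many ViT implementations compute patch embeddings in practice. The sketch below, assuming a fixed seed and the same toy image, copies the `nn.Linear` parameters into a `Conv2d` and checks that both routes agree.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
p, d_model = 2, 8
img = torch.arange(16).reshape(4, 4).float()

# Linear route: patchify, flatten, project
patches = img.unfold(0, p, p).unfold(1, p, p).reshape(-1, p * p)  # (4, 4)
linear = nn.Linear(p * p, d_model)
out_linear = linear(patches)                                       # (4, 8)

# Conv route: kernel_size == stride == patch size, one output channel per embedding dim
conv = nn.Conv2d(in_channels=1, out_channels=d_model, kernel_size=p, stride=p)
with torch.no_grad():
    # Linear weight (8, 4) -> conv weight (8, 1, 2, 2); the 4 inputs are row-major within a patch
    conv.weight.copy_(linear.weight.view(d_model, 1, p, p))
    conv.bias.copy_(linear.bias)

x = img.unsqueeze(0).unsqueeze(0)                    # (batch=1, channels=1, 4, 4)
out_conv = conv(x)                                   # (1, 8, 2, 2)
out_conv = out_conv.flatten(2).transpose(1, 2)[0]    # (4 patches, 8)

print(torch.allclose(out_linear, out_conv, atol=1e-5))
```

The convolution skips the explicit unfold/flatten step and works directly on `(batch, channels, H, W)` tensors, which is why it is the usual choice for real images.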


**Random Weights and Biases for Now**

The embeddings carry no learned meaning yet: `nn.Linear` initializes its weight and bias randomly, so these values are arbitrary until the model is trained. Let's inspect the weight and bias.

In [8]:
print(linear.weight)

Parameter containing:
tensor([[ 0.3278,  0.1936, -0.1148, -0.1725],
        [-0.3879, -0.2004,  0.2324, -0.2488],
        [ 0.0540, -0.4328,  0.1982, -0.2186],
        [-0.4551,  0.3248, -0.0425,  0.1657],
        [-0.0325,  0.1682, -0.4246, -0.2704],
        [ 0.0720, -0.3882, -0.3585,  0.1431],
        [ 0.0539,  0.4325,  0.2371, -0.1504],
        [-0.2661, -0.2569,  0.1480, -0.2560]], requires_grad=True)


In [9]:
print(linear.bias)

Parameter containing:
tensor([ 0.3956,  0.3010,  0.2512, -0.0430,  0.4555, -0.4116,  0.1836, -0.3567],
       requires_grad=True)
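Nothing is hidden inside `nn.Linear`: it computes `x @ W.T + b`, with `W` stored as `(out_features, in_features)`. The first printed entry can be checked by hand from the first patch `[0, 1, 4, 5]`: 0.1936·1 - 0.1148·4 - 0.1725·5 + 0.3956 ≈ -0.7325, matching -0.7326 up to rounding of the displayed weights. The sketch below repeats the check for a freshly initialized layer (its random values will differ from those printed above) and confirms PyTorch's default initialization range of U(-1/sqrt(in_features), 1/sqrt(in_features)).

```python
import torch
import torch.nn as nn

patches = torch.tensor([[0., 1., 4., 5.],
                        [2., 3., 6., 7.],
                        [8., 9., 12., 13.],
                        [10., 11., 14., 15.]])
linear = nn.Linear(4, 8)

# nn.Linear computes x @ W.T + b, where W has shape (out_features, in_features)
manual = patches @ linear.weight.T + linear.bias
print(torch.allclose(manual, linear(patches), atol=1e-6))  # True

# Default init draws weight and bias from U(-1/sqrt(in_features), 1/sqrt(in_features));
# with in_features = 4 the bound is 0.5
bound = 1 / 4 ** 0.5
print(linear.weight.abs().max().item() <= bound)  # True
print(linear.bias.abs().max().item() <= bound)    # True
```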


### Add Positional Embedding
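Just like words in a sentence need order, patches in an image need location information: self-attention on its own treats its inputs as an unordered set, so without extra information the model cannot tell where a patch came from. We'll add simple learnable positional embeddings, one vector per patch, so that each vector encodes both content (pixels) and position (where the patch came from).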

In [10]:
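# Learnable positional embeddings: one 8-dim vector for each of the 4 patches
# Initialized to zeros here, so adding them changes nothing until they are trained
pos_embed = nn.Parameter(torch.zeros(4,8))

# Add to the patch embeddings (element-wise, both are (4, 8))
embedded_with_pos = embedded + pos_embed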
print(embedded_with_pos.shape)

torch.Size([4, 8])


In [11]:
embedded_with_pos

tensor([[-0.7326, -0.2139, -0.4815,  0.9404, -2.4267, -1.5183,  0.8126, -1.3019],
        [-0.2643, -1.4233, -1.2799,  0.9261, -3.5453, -2.5814,  1.9590, -2.5639],
        [ 1.1404, -5.0516, -3.6749,  0.8834, -6.9011, -5.7708,  5.3981, -6.3500],
        [ 1.6087, -6.2611, -4.4732,  0.8691, -8.0196, -6.8340,  6.5445, -7.6121]],
       grad_fn=<AddBackward0>)
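Because `pos_embed` starts at zero, `embedded_with_pos` above is identical to `embedded`. The sketch below, using random stand-in embeddings, contrasts that with a small random initialization (a truncated normal with std 0.02 is an assumed, commonly used choice, not something this notebook prescribes) and shows how a `(1, patches, d_model)` table broadcasts over a batch of images.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_patches, d_model = 4, 8
embedded = torch.randn(num_patches, d_model)  # stand-in for the patch embeddings

# Zero init: adding it is a no-op until training updates the parameter
pos_zero = nn.Parameter(torch.zeros(num_patches, d_model))
print(torch.equal(embedded + pos_zero, embedded))  # True

# Small random init (assumed choice: truncated normal, std 0.02)
pos_rand = nn.Parameter(torch.empty(1, num_patches, d_model))
nn.init.trunc_normal_(pos_rand, std=0.02)

# The leading dim of size 1 broadcasts the same table over every image in a batch
batch = torch.randn(3, num_patches, d_model)  # (batch, patches, d_model)
out = batch + pos_rand
print(out.shape)  # torch.Size([3, 4, 8])
```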

### Mini Self-Attention Layer
Let’s mimic a tiny Transformer attention head on these 4 patch embeddings.
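Written out, the head in the next cell computes scaled dot-product attention, where $X \in \mathbb{R}^{4 \times 8}$ is the matrix of patch embeddings (`embedded_with_pos`), $W_q, W_k \in \mathbb{R}^{8 \times 4}$, $W_v \in \mathbb{R}^{8 \times 8}$, and the softmax is taken row-wise:

$$
Q = XW_q,\qquad K = XW_k,\qquad V = XW_v
$$

$$
\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V
$$

With $d_k = 4$ the scores $QK^{\top}/\sqrt{d_k}$ form a $4 \times 4$ matrix scaled by $1/2$, and the output is $4 \times 8$. Note that `nn.Linear` stores its weight as `(out_features, in_features)`, i.e. the transpose of $W_q$, $W_k$, $W_v$ as written here.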

In [12]:
# Embedding size
d_model = 8

# Query/key dimension
d_k = d_model // 2

# projection matrices
W_q = nn.Linear(d_model, d_k, bias=False)
W_k = nn.Linear(d_model, d_k, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

# (4, d_k)
Q = W_q(embedded_with_pos)
# (4, d_k)
K = W_k(embedded_with_pos)
# (4, d_model)
V = W_v(embedded_with_pos)

# Scaled dot-product scores: (4, 4), one row per query patch
scores = Q @ K.T / (d_k ** 0.5)
# Softmax over the keys, so each row of weights sums to 1
attn_weights = scores.softmax(dim=-1)

# Weighted sum of values
attn_output = attn_weights @ V
print(attn_output.shape)

torch.Size([4, 8])


In [29]:
attn_output

tensor([[-1.6244e-01, -1.9080e+00,  1.5089e+00,  1.4638e+00,  1.9021e+00,
         -9.1543e-01,  5.5643e-02,  1.8941e+00],
        [ 6.4958e-02, -2.3306e+00,  1.4405e+00,  1.9354e+00,  2.1489e+00,
         -9.6926e-01,  4.9013e-02,  2.3018e+00],
        [ 1.5942e+00, -5.1724e+00,  9.8078e-01,  5.1069e+00,  3.8086e+00,
         -1.3313e+00,  4.4255e-03,  5.0439e+00],
        [ 1.8317e+00, -5.6138e+00,  9.0936e-01,  5.5996e+00,  4.0664e+00,
         -1.3875e+00, -2.5008e-03,  5.4699e+00]], grad_fn=<MmBackward0>)