## Let's Build a Toy Model
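We will walk through the following steps:
 - Create a 4x4 toy image
 - Split it into 2x2 patches (4 patches total)
 - Flatten each patch into a vector of length 4
 - Apply a linear projection that maps each 4-dimensional patch to an 8-dimensional embedding
 - Add positional embeddings to the patch embeddings
 - Run a mini self-attention layer over the 4 embedded patches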
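
As a quick sanity check on the numbers in this plan, the patch count and the flattened patch length follow directly from the image and patch sizes. The variable names below are illustrative and are not used elsewhere in the notebook.

```python
# Illustrative names; the values match the toy setup described above.
image_size = 4   # the image is image_size x image_size
patch_size = 2   # each patch is patch_size x patch_size
channels = 1     # grayscale
embed_dim = 8    # target embedding size

# Non-overlapping patches tile the image in a grid.
num_patches = (image_size // patch_size) ** 2
# Each flattened patch holds every pixel of the patch across all channels.
patch_dim = patch_size * patch_size * channels

print(num_patches, patch_dim, embed_dim)
```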

In [2]:
import torch
import torch.nn as nn

### Make a toy image
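A 4x4 grayscale image with a single channel. The pixel values run from 0 to 15, which makes it easy to track where each pixel ends up.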

In [18]:
img = torch.arange(16).reshape(4,4).float()
print(img)

tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]])


### Split into Patches
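We use a patch size of 2x2. `unfold(dim, size, step)` slides a window of length `size` along `dim` with stride `step`, so applying it along the rows and then the columns yields a 2x2 grid of 2x2 patches.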

In [19]:
patches = img.unfold(0,2,2).unfold(1,2,2)
print(patches.shape)

torch.Size([2, 2, 2, 2])


In [20]:
print(patches)

tensor([[[[ 0.,  1.],
          [ 4.,  5.]],

         [[ 2.,  3.],
          [ 6.,  7.]]],


        [[[ 8.,  9.],
          [12., 13.]],

         [[10., 11.],
          [14., 15.]]]])
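
To make the index layout of the unfolded tensor concrete, here is a small sketch that rebuilds the same patches with `reshape` and `permute` and compares the two results. This is only a cross-check under the toy sizes used here, not a replacement for the cell above; the names `patches_unfold` and `patches_reshape` are illustrative.

```python
import torch

img = torch.arange(16).reshape(4, 4).float()

# Patches via unfold, as in the cell above.
# Index order: (patch_row, patch_col, row_in_patch, col_in_patch)
patches_unfold = img.unfold(0, 2, 2).unfold(1, 2, 2)

# Same patches via reshape + permute.
# reshape gives (patch_row, row_in_patch, patch_col, col_in_patch);
# permute swaps the two middle axes to match the unfold layout.
patches_reshape = img.reshape(2, 2, 2, 2).permute(0, 2, 1, 3)

print(torch.equal(patches_unfold, patches_reshape))
print(patches_unfold[0, 1])  # the top-right patch
```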


### Flatten Patches

In [21]:
patches = patches.contiguous().view(-1,2*2)
print(patches)

tensor([[ 0.,  1.,  4.,  5.],
        [ 2.,  3.,  6.,  7.],
        [ 8.,  9., 12., 13.],
        [10., 11., 14., 15.]])
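
The `.contiguous()` call in the cell above matters because `unfold` returns a strided view of the original image rather than a fresh copy. The sketch below illustrates this: `.view` is expected to fail on the unfolded tensor, while `.reshape` copies when needed and gives the same result. The names `flat_a`, `flat_b`, and `view_failed` are illustrative.

```python
import torch

img = torch.arange(16).reshape(4, 4).float()
patches = img.unfold(0, 2, 2).unfold(1, 2, 2)

# unfold returns a view with non-standard strides, so the tensor is not contiguous.
print(patches.is_contiguous())

# .view requires compatible strides, so it is expected to raise on this tensor.
try:
    patches.view(-1, 2 * 2)
    view_failed = False
except RuntimeError:
    view_failed = True
print(view_failed)

# .reshape copies when necessary and matches .contiguous().view
flat_a = patches.contiguous().view(-1, 2 * 2)
flat_b = patches.reshape(-1, 2 * 2)
print(torch.equal(flat_a, flat_b))
```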


In [7]:
patches.shape

torch.Size([4, 4])

### Linear Projection (Patch Embedding)

In [8]:
# In 4 --> out 8
linear = nn.Linear(4,8)
embedded = linear(patches)
print("Shape: ", embedded.shape)
print(embedded)

Shape:  torch.Size([4, 8])
tensor([[-2.3452, -1.7857,  3.0163, -1.2413,  0.2759, -2.3244,  3.0924, -0.2377],
        [-3.3043, -3.2541,  6.2068, -1.6844,  0.7843, -2.8811,  3.2565,  0.3049],
        [-6.1814, -7.6593, 15.7786, -3.0138,  2.3094, -4.5511,  3.7487,  1.9329],
        [-7.1404, -9.1278, 18.9691, -3.4569,  2.8178, -5.1077,  3.9128,  2.4755]],
       grad_fn=<AddmmBackward0>)
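
As an aside, the split, flatten, and project steps can also be expressed as one strided convolution, which is a common way to implement patch embedding. The sketch below is not what this notebook does. It assumes a `Conv2d` whose kernel size and stride both equal the patch size, copies the linear layer's parameters into it, and checks that both routes agree. The names `conv` and `embedded_conv` are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
img = torch.arange(16).reshape(4, 4).float()
patches = img.unfold(0, 2, 2).unfold(1, 2, 2).reshape(-1, 4)

# Route 1: flatten patches, then apply a linear layer (as in the notebook).
linear = nn.Linear(4, 8)
embedded = linear(patches)  # (4, 8)

# Route 2 (assumption, not the notebook's method): a strided convolution.
conv = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=2, stride=2)
with torch.no_grad():
    # Each row of the linear weight becomes one 1x2x2 convolution kernel.
    conv.weight.copy_(linear.weight.view(8, 1, 2, 2))
    conv.bias.copy_(linear.bias)

out = conv(img[None, None])                                  # (1, 8, 2, 2)
embedded_conv = out.flatten(2).transpose(1, 2).squeeze(0)    # (4, 8)

print(torch.allclose(embedded, embedded_conv, atol=1e-4))
```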


**Random Weights and Biases for Now**

The embeddings are arbitrary at this point because `nn.Linear` initializes its weights and biases randomly; they only become meaningful after training. Let's inspect the weights and biases.

In [9]:
print(linear.weight)

Parameter containing:
tensor([[-0.0786,  0.0531, -0.0895, -0.3645],
        [-0.0107, -0.4378, -0.0383, -0.2474],
        [ 0.4826,  0.4910,  0.1220,  0.4997],
        [ 0.3217, -0.4536,  0.2054, -0.2951],
        [ 0.3103,  0.0790, -0.4601,  0.3250],
        [ 0.1173,  0.1120, -0.4532, -0.0544],
        [-0.4349, -0.1241,  0.3976,  0.2434],
        [ 0.4765, -0.1958,  0.4708, -0.4802]], requires_grad=True)


In [10]:
print(linear.bias)

Parameter containing:
tensor([-0.2175,  0.0424, -0.4613, -0.1339,  0.4123, -0.3515,  0.4091,  0.4758],
       requires_grad=True)
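
With the weights and bias in hand, the projection can be checked by hand: `nn.Linear` computes `x @ W.T + b`. For the first patch `[0, 1, 4, 5]` and the first printed weight row, this gives 0.0531 - 0.3580 - 1.8225 - 0.2175 ≈ -2.345, which agrees with the first printed embedding entry up to the rounding of the displayed values. The sketch below runs the same check in code with a freshly initialized layer, so its numbers will differ from the printout above; `manual` is an illustrative name.

```python
import torch
import torch.nn as nn

patches = torch.tensor([[ 0.,  1.,  4.,  5.],
                        [ 2.,  3.,  6.,  7.],
                        [ 8.,  9., 12., 13.],
                        [10., 11., 14., 15.]])

linear = nn.Linear(4, 8)

# The weight is stored as (out_features, in_features), hence the transpose.
print(linear.weight.shape, linear.bias.shape)

# Manual projection: one matrix multiply plus a broadcast bias.
manual = patches @ linear.weight.T + linear.bias
print(torch.allclose(linear(patches), manual, atol=1e-5))
```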


### Add Positional Embedding
Just as words in a sentence need an order, patches in an image need location information. We’ll add simple positional embeddings (learnable vectors) so that each vector encodes both content (the pixels) and position (where the patch came from).

In [22]:
# learnable positional embeddings for 4 patches, each 8-dim
# (initialized to zeros, so the sum below equals `embedded` until training updates them)
pos_embed = nn.Parameter(torch.zeros(4,8))

# Add to patch embeddings
embedded_with_pos = embedded + pos_embed
print(embedded_with_pos.shape)

torch.Size([4, 8])


In [23]:
embedded_with_pos

tensor([[-2.3452, -1.7857,  3.0163, -1.2413,  0.2759, -2.3244,  3.0924, -0.2377],
        [-3.3043, -3.2541,  6.2068, -1.6844,  0.7843, -2.8811,  3.2565,  0.3049],
        [-6.1814, -7.6593, 15.7786, -3.0138,  2.3094, -4.5511,  3.7487,  1.9329],
        [-7.1404, -9.1278, 18.9691, -3.4569,  2.8178, -5.1077,  3.9128,  2.4755]],
       grad_fn=<AddBackward0>)
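
The values above are identical to the patch embeddings because the positional embeddings start at zero. The sketch below illustrates this, shows that the zero-initialized parameter still receives gradients, and contrasts it with a small random initialization. The random initialization is an assumption for illustration, not the notebook's choice, and `embedded` here is a random stand-in rather than the tensor computed earlier.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embedded = torch.randn(4, 8)  # stand-in for the patch embeddings

# Zero initialization: adding it changes nothing before training.
pos_zero = nn.Parameter(torch.zeros(4, 8))
print(torch.equal(embedded + pos_zero, embedded))

# The parameter still receives gradients, so training can move it away from zero.
(embedded + pos_zero).sum().backward()
print(pos_zero.grad.shape)

# Illustrative alternative (an assumption): small random values.
pos_rand = nn.Parameter(torch.empty(4, 8))
nn.init.normal_(pos_rand, std=0.02)
print(torch.equal(embedded + pos_rand, embedded))
```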

### Mini Self-Attention Layer
Let’s mimic a tiny Transformer attention head on these 4 patch embeddings.

In [27]:
# Embedding size
d_model = 8

# Query/key dimension
d_k = d_model // 2

# projection matrices
W_q = nn.Linear(d_model, d_k, bias=False)
W_k = nn.Linear(d_model, d_k, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

# (4, d_k)
Q = W_q(embedded_with_pos)
# (4, d_k)
K = W_k(embedded_with_pos)
# (4, d_model)
V = W_v(embedded_with_pos)

# Attention scores (4,4)
scores = Q@K.T/(d_k ** 0.5)
attn_weights = scores.softmax(dim=-1)

# Weighted sum of values
attn_output = attn_weights @ V
print(attn_output.shape)

torch.Size([4, 8])
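
To see what the attention weights mean, the sketch below repeats the computation on a random stand-in input and checks two properties: each row of the weight matrix sums to 1, and each output row is a weighted average of the rows of `V`. The input `x` and the name `row0` are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_k = 8, 4
x = torch.randn(4, d_model)  # stand-in for the 4 patch embeddings

W_q = nn.Linear(d_model, d_k, bias=False)
W_k = nn.Linear(d_model, d_k, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)
Q, K, V = W_q(x), W_k(x), W_v(x)

# One score per (query patch, key patch) pair, scaled by sqrt(d_k).
scores = Q @ K.T / (d_k ** 0.5)
attn_weights = scores.softmax(dim=-1)
print(attn_weights.shape)

# Softmax over the last dimension makes every row sum to 1.
print(attn_weights.sum(dim=-1))

# Output row 0 is the attention-weighted average of the value vectors.
row0 = sum(attn_weights[0, j] * V[j] for j in range(4))
print(torch.allclose(row0, (attn_weights @ V)[0], atol=1e-5))
```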


In [29]:
attn_output

tensor([[-1.6244e-01, -1.9080e+00,  1.5089e+00,  1.4638e+00,  1.9021e+00,
         -9.1543e-01,  5.5643e-02,  1.8941e+00],
        [ 6.4958e-02, -2.3306e+00,  1.4405e+00,  1.9354e+00,  2.1489e+00,
         -9.6926e-01,  4.9013e-02,  2.3018e+00],
        [ 1.5942e+00, -5.1724e+00,  9.8078e-01,  5.1069e+00,  3.8086e+00,
         -1.3313e+00,  4.4255e-03,  5.0439e+00],
        [ 1.8317e+00, -5.6138e+00,  9.0936e-01,  5.5996e+00,  4.0664e+00,
         -1.3875e+00, -2.5008e-03,  5.4699e+00]], grad_fn=<MmBackward0>)
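
This attention computation also illustrates why positional embeddings are needed. Without them, shuffling the patches simply shuffles the outputs in the same way, so the layer cannot tell where a patch came from. The sketch below demonstrates this with random stand-in inputs; the helper `attend`, the permutation `perm`, and the random `pos` are hypothetical names introduced for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_k = 8, 4
W_q = nn.Linear(d_model, d_k, bias=False)
W_k = nn.Linear(d_model, d_k, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

def attend(x):
    """Single-head self-attention, mirroring the cell above."""
    Q, K, V = W_q(x), W_k(x), W_v(x)
    weights = (Q @ K.T / (d_k ** 0.5)).softmax(dim=-1)
    return weights @ V

x = torch.randn(4, d_model)        # stand-in patch embeddings
perm = torch.tensor([2, 0, 3, 1])  # an arbitrary reordering of the patches

# Without position info: shuffling inputs just shuffles outputs.
same_without_pos = torch.allclose(attend(x)[perm], attend(x[perm]), atol=1e-5)
print(same_without_pos)

# With non-zero positional embeddings tied to slots, the order now matters.
pos = torch.randn(4, d_model)
same_with_pos = torch.allclose(attend(x + pos)[perm], attend(x[perm] + pos), atol=1e-5)
print(same_with_pos)
```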