<a href="https://colab.research.google.com/github/chen-star/llm_model_trainings/blob/main/3_transformer_impl_attention_qkv.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V$

```
Q = XW_Q
K = XW_K
V = XW_V
```

# ✈ Imports

In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

import time

# [0] 🇹 Randomly generate data

In [2]:
# parameters
vocabulary_size = 40       # number of distinct token ids
batch_size = 4             # sequences per batch
embedding_dimension = 13   # width of each token embedding vector
context_window_size = 8    # tokens per sequence

# Seed PyTorch's global RNG so that every random draw made after this cell
# (random token ids, randomly initialised layer weights) is identical on each
# Restart & Run All. The trailing semicolon hides the returned Generator repr.
random_seed = 42
torch.manual_seed(random_seed);

In [3]:
# Draw one random token id per (sequence, position); ids fall in [0, vocabulary_size).
token_ids = torch.randint(
    0,
    vocabulary_size,
    (batch_size, context_window_size),
)
print(f' token_ids.shape: {token_ids.shape}.\n\n token_ids: \n {token_ids}')

 token_ids.shape: torch.Size([4, 8]).

 token_ids: 
 tensor([[ 4,  0, 31,  9, 37, 29, 33,  7],
        [30, 16, 18,  1, 38, 31,  7, 34],
        [ 2,  6, 18, 37,  2, 11, 31, 15],
        [30, 31, 29,  0, 26,  6, 14, 10]])


In [4]:
# Embedding table: one embedding_dimension-wide vector for each of the vocabulary_size token ids.
embedding_layer = nn.Embedding(
    num_embeddings=vocabulary_size,
    embedding_dim=embedding_dimension,
)

In [5]:
# Map every token id to its embedding vector; this appends an embedding axis
# to the [batch_size, context_window_size] grid of ids.
X = embedding_layer(token_ids)
print(f"X.shape: {X.shape}") # [batch_size, context_window_size, embedding_dimension]

X.shape: torch.Size([4, 8, 13])


# [1] ✍ Manually compute Attention(Q, K, V)

In [6]:
# W_Q, W_K, W_V: three independent, bias-free square projections of the embedding axis.
# The generator is consumed left to right, so the layers are built in q, k, v order.
q_layer, k_layer, v_layer = (
    nn.Linear(embedding_dimension, embedding_dimension, bias=False)
    for _ in range(3)
)
print(f"q_layer.weight.shape: {q_layer.weight.shape}")
print(f"k_layer.weight.shape: {k_layer.weight.shape}")
print(f"v_layer.weight.shape: {v_layer.weight.shape}")

q_layer.weight.shape: torch.Size([13, 13])
k_layer.weight.shape: torch.Size([13, 13])
v_layer.weight.shape: torch.Size([13, 13])


In [7]:
# Project the same input X three ways:
#   Q = X W_Q,  K = X W_K,  V = X W_V
Q, K, V = q_layer(X), k_layer(X), v_layer(X)
print(f"Q.shape: {Q.shape}")
print(f"K.shape: {K.shape}")
print(f"V.shape: {V.shape}")

Q.shape: torch.Size([4, 8, 13])
K.shape: torch.Size([4, 8, 13])
V.shape: torch.Size([4, 8, 13])


Actual implementation ...

$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V$

In [10]:
# Q K^T: raw attention scores, one dot product per (query position, key position) pair.
# These are unnormalised dot products: related to cosine similarity, but without
# dividing by the vector norms.
TK = K.transpose(-2, -1)  # swap only the last two axes; the batch axis stays in place
QTK = Q @ TK

# Q K^T / sqrt(d_k): d_k is read from K itself rather than from a notebook global,
# so the scaling stays correct if the key width ever differs from the embedding width.
QTK_scaled = QTK / (K.size(-1) ** 0.5)

# M: causal mask, True where a query may attend (key position <= query position).
# One [query_len, key_len] mask is enough; it broadcasts over the batch axis.
M = torch.ones(Q.size(-2), K.size(-2), dtype=torch.bool, device=QTK.device).tril()

# Q K^T / sqrt(d_k) + M: future positions are set to -inf so softmax gives them
# zero weight. masked_fill returns a new tensor, so QTK_scaled keeps the unmasked
# scaled scores instead of being overwritten in place.
QTK_timed = QTK_scaled.masked_fill(~M, -torch.inf)

# softmax(Q K^T / sqrt(d_k) + M): normalise over key positions (last axis).
QTK_softmax = F.softmax(QTK_timed, dim=-1)

# softmax(Q K^T / sqrt(d_k) + M) V: attention-weighted sum of the value vectors.
attention_score_matrix_manual = QTK_softmax @ V

In [11]:
# The manual attention result should have the same shape as Q, K and V.
print(f"attention_score_matrix_manual.shape: {attention_score_matrix_manual.shape}")

attention_score_matrix_manual.shape: torch.Size([4, 8, 13])


# [2] 💻 Compute Attention(Q, K, V) with PyTorch

In [13]:
# Same computation through PyTorch's built-in scaled dot-product attention;
# is_causal=True requests causal masking (no attention to later positions).
attention_score_matrix_pytorch = F.scaled_dot_product_attention(
    query=Q,
    key=K,
    value=V,
    is_causal=True,
)

In [14]:
# Report the shape of the built-in attention result for comparison with the manual one.
print(f'attention_score_matrix_pytorch.shape: {attention_score_matrix_pytorch.shape}')

attention_score_matrix_pytorch.shape: torch.Size([4, 8, 13])


# [3] ⛳ Compare

In [15]:
# Show the first batch element from both implementations and their element-wise difference.
print(f"Manual: \n {attention_score_matrix_manual[0,:,]}")
print(f"Pytorch: \n {attention_score_matrix_pytorch[0,:,]}")
print(f"Diff: \n {attention_score_matrix_manual[0,:,] - attention_score_matrix_pytorch[0,:,]}")

# Eyeballing a printed diff is easy to get wrong and only covers one batch element:
# fail loudly if the two results disagree anywhere beyond floating-point tolerance.
torch.testing.assert_close(
    attention_score_matrix_manual.detach(),
    attention_score_matrix_pytorch.detach(),
)

Manual: 
 tensor([[-0.3066, -0.3039, -0.3997, -0.4372, -0.5935, -0.5126, -0.0176,  0.0851,
          0.1074, -0.3081,  0.2435, -0.3186, -0.3887],
        [-0.4857,  0.0479, -0.1943, -0.2638, -0.5883, -0.3097, -0.0661, -0.0534,
         -0.2109, -0.0634,  0.3379, -0.1422,  0.0852],
        [-0.3166,  0.0014, -0.4696,  0.0573, -0.3855, -0.5474,  0.2758, -0.1532,
         -0.3485,  0.2125,  0.1137,  0.0491,  0.0582],
        [-0.4392,  0.0222, -0.3426,  0.1745, -0.3594, -0.4893,  0.2037, -0.0517,
         -0.3698, -0.0495,  0.3026,  0.1931,  0.2092],
        [-0.3153,  0.3241, -0.1499,  0.4223, -0.2624, -0.0969,  0.1800, -0.1778,
         -0.3410, -0.1705,  0.2683,  0.1369,  0.3790],
        [-0.1241,  0.2801, -0.1456,  0.3337, -0.2018, -0.1631,  0.1281, -0.1919,
         -0.3247,  0.0469,  0.0958,  0.1180,  0.2497],
        [-0.2506,  0.2859, -0.2585,  0.4661, -0.1626, -0.0459,  0.2230, -0.1093,
         -0.4111, -0.1289,  0.1709,  0.1349,  0.3194],
        [-0.1553,  0.3317, -0.2657,  0