<a href="https://colab.research.google.com/github/lizhieffe/canonical_llm_impl/blob/main/LLM_from_scratch_chap_03_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Tutorial for "LLM from Scratch" Chapter 03

https://drive.google.com/drive/u/1/folders/1a9jbhCJr_dddOT-m-4G9MgBTpOdaCs7Q

In [None]:
# @title Install Dependencies
!pip install uv && !uv pip install --system -r https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/refs/heads/main/requirements.txt

Collecting uv
  Downloading uv-0.7.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Downloading uv-0.7.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.4/17.4 MB[0m [31m24.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: uv
Successfully installed uv-0.7.2
/bin/bash: line 1: !uv: command not found


In [None]:
# @title Imports

import torch
import torch.nn as nn

# Simple Attention Calculation

In this section, we calculate a simplified version of attention that has no trainable weights.
- Q, K, and V are all the same matrix: the raw input embeddings.

- Use **softmax** for normalization:
  1. It handles extreme values well.
  2. Its outputs are positive and sum to 1, so they can be used directly as probabilities.


In [None]:
# @title Impl with iteration

# Calculate the attention from inputs[1] to other tokens. (Q * K)

torch.manual_seed(1337)
inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your
     [0.55, 0.87, 0.66], # journey (x^2)
     [0.57, 0.85, 0.64], # starts
     [0.22, 0.58, 0.33], # with
     [0.77, 0.25, 0.10], # one
     [0.05, 0.80, 0.55]] # step
)
print(f"inputs: {inputs.shape}")

# The second token ("journey") is the query for the whole cell.
query = inputs[1]
attn_scores_2 = torch.empty(inputs.shape[0])
print(f"{attn_scores_2.shape=}")
print()

# Unnormalized score = dot product of the query with each token embedding.
# The loop variable is not called `input`, which would shadow the Python builtin.
for i, token_emb in enumerate(inputs):
  attn_scores_2[i] = torch.dot(query, token_emb)

# Normalization Option 1: Naive
# This doesn't handle extreme value well
attn_scores_2_naive_normalized = attn_scores_2 / attn_scores_2.sum()
print(f"{attn_scores_2_naive_normalized=}")
# Two-sided check: a plain `sum - 1 < eps` would also pass for any sum below 1.
assert torch.isclose(attn_scores_2_naive_normalized.sum(), torch.tensor(1.0))
print()

# Normalization Option 2: Softmax
# This handles extreme value well
attn_scores_2_softmax_normalized = torch.nn.functional.softmax(attn_scores_2, dim=-1)
print(f"{attn_scores_2_softmax_normalized=}")
assert torch.isclose(attn_scores_2_softmax_normalized.sum(), torch.tensor(1.0))

# Calculate the context vector ((Q * K) * V):
# the attention-weighted sum of all token embeddings.
context_vec_2 = torch.zeros(query.shape)
for weight, token_emb in zip(attn_scores_2_softmax_normalized, inputs):
  context_vec_2 += weight * token_emb
print(f"{context_vec_2=}")

inputs: torch.Size([6, 3])
attn_scores_2.shape=torch.Size([6])

attn_scores_2_naive_normalized=tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])

attn_scores_2_softmax_normalized=tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
context_vec_2=tensor([0.4419, 0.6515, 0.5683])


In [None]:
# @title Impl with MatMul

inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your
     [0.55, 0.87, 0.66], # journey (x^2)
     [0.57, 0.85, 0.64], # starts
     [0.22, 0.58, 0.33], # with
     [0.77, 0.25, 0.10], # one
     [0.05, 0.80, 0.55]] # step
)
# No trainable weights yet: queries, keys and values are all the raw embeddings.
Q = K = V = inputs
print(f"{Q.shape=}")

# All pairwise dot products at once, then a row-wise softmax. [N, N]
attn = torch.nn.functional.softmax(Q @ K.transpose(-2, -1), dim=-1)
# Every row of attention weights must sum to 1.
assert all(torch.isclose(row.sum(), torch.tensor(1.0)) for row in attn)
assert attn.shape == (6, 6)

# One context vector per token: the attention-weighted sum of the values. [N, D]
y = attn @ V
assert y.shape == (6, 3)
print(f"{y[1]=}")
# Must match the context vector computed by the iterative implementation.
assert torch.allclose(y[1], context_vec_2)

Q.shape=torch.Size([6, 3])
y[1]=tensor([0.4419, 0.6515, 0.5683])


# Trainable ATTN

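- Dividing the attention scores by `HIDDEN_DIM ** 0.5` improves **training performance**. As `HIDDEN_DIM` grows, the dot products grow in magnitude, which pushes the softmax output toward a near one-hot distribution whose gradients are very small (vanishing gradients). Scaling by the square root of the key dimension keeps the scores in a range where softmax still yields useful gradients.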

In [None]:
# Width of each token embedding; matches the last dimension of `inputs` (3).
EMB_DIM = 3
# Width of the projected query/key/value vectors in the "Impl version 1/2" cells.
HIDDEN_DIM = 256
# Seed the global RNG so the randomly initialised weights are reproducible.
# As the last expression of the cell, the returned Generator is echoed as output.
torch.manual_seed(123)

<torch._C.Generator at 0x7c38bc6efa90>

In [None]:
# @title Impl version 1

# Projection matrices mapping an EMB_DIM-wide embedding to HIDDEN_DIM features.
# requires_grad=False: no gradients are tracked for these weights.
wk = torch.nn.Parameter(torch.randn(EMB_DIM, HIDDEN_DIM), requires_grad=False)  # [E, H]
wq = torch.nn.Parameter(torch.randn(EMB_DIM, HIDDEN_DIM), requires_grad=False)  # [E, H]
wv = torch.nn.Parameter(torch.randn(EMB_DIM, HIDDEN_DIM), requires_grad=False)  # [E, H]

keys = inputs @ wk    # [N, H]
queries = inputs @ wq # [N, H]
values = inputs @ wv  # [N, H]
assert keys.shape == (inputs.shape[0], HIDDEN_DIM)

# Row i holds the scores of query i against every key, so the softmax over the
# last dimension normalizes over keys (queries @ keys^T, not keys @ queries^T).
attn = queries @ keys.transpose(-1, -2) # [N, N]
# Scale by sqrt(key width) to keep the softmax out of its saturated region.
attn = attn / HIDDEN_DIM ** 0.5
attn = torch.nn.functional.softmax(attn, dim=-1)
assert attn.shape == (inputs.shape[0], inputs.shape[0])
# Every row of attention weights must sum to 1.
assert torch.allclose(attn.sum(dim=-1), torch.ones(inputs.shape[0]))

y = attn @ values # [N, H]
assert y.shape == (inputs.shape[0], HIDDEN_DIM)

In [None]:
# @title Impl version 2

# Same computation as "Impl version 1", but the projections are nn.Linear layers
# (bias enabled by default) instead of raw weight matrices.
wk = torch.nn.Linear(EMB_DIM, HIDDEN_DIM)
wq = torch.nn.Linear(EMB_DIM, HIDDEN_DIM)
wv = torch.nn.Linear(EMB_DIM, HIDDEN_DIM)

K = wk(inputs) # [N, H]
Q = wq(inputs) # [N, H]
V = wv(inputs) # [N, H]
assert K.shape == (inputs.shape[0], HIDDEN_DIM)

# Row i holds the scores of query i against every key, so the softmax over the
# last dimension normalizes over keys (Q @ K^T, not K @ Q^T).
attn = Q @ K.transpose(-1, -2) # [N, N]
# Scale by sqrt(key width) to keep the softmax out of its saturated region.
attn = attn / HIDDEN_DIM ** 0.5
attn = torch.nn.functional.softmax(attn, dim=-1)
assert attn.shape == (inputs.shape[0], inputs.shape[0])
# Every row of attention weights must sum to 1.
assert torch.allclose(attn.sum(dim=-1), torch.ones(inputs.shape[0]))

y = attn @ V # [N, H]
assert y.shape == (inputs.shape[0], HIDDEN_DIM)

In [None]:
# @title impl as python class - V1

class SelfAttention_v1(nn.Module):
  """Single-head scaled dot-product self-attention with raw weight matrices."""

  def __init__(self, d_in, d_out):
    """Ctor.

    Args:
      d_in: Input (embedding) dimension.
      d_out: Output (hidden) dimension.
    """
    super().__init__()

    self.wk = torch.nn.Parameter(torch.randn(d_in, d_out))  # [E, H]
    self.wq = torch.nn.Parameter(torch.randn(d_in, d_out))  # [E, H]
    self.wv = torch.nn.Parameter(torch.randn(d_in, d_out))  # [E, H]

  def forward(self, x):
    """Computes one context vector per token.

    Args:
      x: [..., N, E] token embeddings.

    Returns:
      [..., N, H] context vectors.
    """
    assert x.shape[-1] == self.wk.shape[0]

    keys = x @ self.wk    # [N, H]
    queries = x @ self.wq # [N, H]
    values = x @ self.wv  # [N, H]

    # Row i holds the scores of query i against every key.
    attn = queries @ keys.transpose(-1, -2) # [N, N]
    # Scale by this module's own key width, not by a notebook-level constant.
    attn = attn / keys.shape[-1] ** 0.5
    attn = nn.functional.softmax(attn, dim=-1)

    return attn @ values # [N, H]

# Test
# Re-seed so the randomly initialised projection weights are reproducible.
torch.manual_seed(123)
# Project the EMB_DIM-wide embeddings down to 2 output features.
sa_v1 = SelfAttention_v1(EMB_DIM, 2)
y = sa_v1(inputs)
# One 2-wide context vector per input token.
assert y.shape == (inputs.shape[0], 2)
# Bare last expression: the notebook displays the tensor.
y

tensor([[0.2867, 0.3880],
        [0.2868, 0.3881],
        [0.2868, 0.3881],
        [0.2868, 0.3874],
        [0.2868, 0.3869],
        [0.2868, 0.3878]], grad_fn=<MmBackward0>)

## impl as python class - V2

Compared to V1, V2 uses `nn.Linear()` instead of raw `nn.Parameter()`. The benefits are:
1. With the bias disabled, `nn.Linear()` performs the same matrix multiplication as the hand-written `x @ W`, so nothing is lost by switching.
2. `nn.Linear()` uses a better-tuned weight initialization scheme, leading to more stable and effective training.

In [None]:
class SelfAttention_v2(nn.Module):
  """Single-head scaled dot-product self-attention built on nn.Linear projections."""

  def __init__(self, d_in, d_out, qkv_bias=False):
    """Ctor.

    Args:
      d_in: Input (embedding) dimension.
      d_out: Output (hidden) dimension.
      qkv_bias: Whether the query/key/value projections use a bias term.
    """
    super().__init__()

    self.wk = torch.nn.Linear(d_in, d_out, bias=qkv_bias)  # E -> H
    self.wq = torch.nn.Linear(d_in, d_out, bias=qkv_bias)  # E -> H
    self.wv = torch.nn.Linear(d_in, d_out, bias=qkv_bias)  # E -> H

  def forward(self, x):
    """Computes one context vector per token.

    Args:
      x: [..., N, E] token embeddings.

    Returns:
      [..., N, H] context vectors.
    """
    keys = self.wk(x)    # [N, H]
    queries = self.wq(x) # [N, H]
    values = self.wv(x)  # [N, H]

    # Row i holds the scores of query i against every key.
    attn = queries @ keys.transpose(-1, -2) # [N, N]
    # Scale by this module's own key width, not by a notebook-level constant.
    attn = attn / keys.shape[-1] ** 0.5
    attn = nn.functional.softmax(attn, dim=-1)

    return attn @ values # [N, H]

# Test
# Re-seed so the nn.Linear weight initialisation is reproducible.
torch.manual_seed(123)
# Project the EMB_DIM-wide embeddings down to 2 output features.
sa_v2 = SelfAttention_v2(EMB_DIM, 2)
y = sa_v2(inputs)
# One 2-wide context vector per input token.
assert y.shape == (inputs.shape[0], 2)
# Bare last expression: the notebook displays the tensor.
y

tensor([[-0.5284, -0.1061],
        [-0.5283, -0.1063],
        [-0.5283, -0.1063],
        [-0.5280, -0.1063],
        [-0.5281, -0.1062],
        [-0.5280, -0.1063]], grad_fn=<MmBackward0>)

In [None]:
# @title Swap the v2's weights to v1

# Cross-check the two implementations: once they share weights they must
# produce the same output.

# nn.Linear keeps its weight as [d_out, d_in], while SelfAttention_v1 stores
# [d_in, d_out] matrices, hence the transpose.
for proj in ("wk", "wq", "wv"):
  getattr(sa_v1, proj).data = getattr(sa_v2, proj).weight.data.T

assert torch.allclose(sa_v1(inputs), sa_v2(inputs))

# Causal Attention

Add
1. causal mask
2. dropout on the attention weights, after the softmax (another variant applies dropout after the `@ V` step, but it is less common)

In [None]:
class CausalAttention(nn.Module):
  """Single-head self-attention with a causal mask and attention dropout."""

  def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
    """Ctor.

    Args:
      d_in: Input (embedding) dimension.
      d_out: Output (hidden) dimension.
      context_length: Maximum sequence length; sets the size of the causal mask.
      dropout: Dropout probability applied to the attention weights.
      qkv_bias: Whether the query/key/value projections use a bias term.
    """
    super().__init__()

    self.wk = torch.nn.Linear(d_in, d_out, bias=qkv_bias)  # E -> H
    self.wq = torch.nn.Linear(d_in, d_out, bias=qkv_bias)  # E -> H
    self.wv = torch.nn.Linear(d_in, d_out, bias=qkv_bias)  # E -> H
    self.dropout = nn.Dropout(p=dropout)
    # Ones strictly above the diagonal mark the "future" positions to hide.
    self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

  def forward(self, x):
    """Computes causally masked context vectors.

    Args:
      x: [..., N, E] token embeddings with N <= context_length.

    Returns:
      [..., N, H] context vectors; position i only attends to positions <= i.
    """
    n = x.shape[-2]

    keys = self.wk(x)    # [N, H]
    queries = self.wq(x) # [N, H]
    values = self.wv(x)  # [N, H]

    # Row i holds the scores of query i against every key.
    attn = queries @ keys.transpose(-1, -2) # [N, N]

    # Make the top triangle -inf so that they will be 0 after softmax.
    # [:n, :n] truncates the mask to the actual number of input tokens.
    attn = attn.masked_fill(self.mask.bool()[:n, :n], -torch.inf)

    # Scale by this module's own key width, not by a notebook-level constant.
    attn = attn / keys.shape[-1] ** 0.5
    attn = nn.functional.softmax(attn, dim=-1)
    attn = self.dropout(attn)

    return attn @ values # [N, H]

# Test
# Re-seed so the weight initialisation is reproducible.
torch.manual_seed(123)
# The mask is sized to the number of tokens in `inputs`.
sa_v3 = CausalAttention(EMB_DIM, 2, context_length=inputs.shape[-2], dropout=0.5)

# Build a batch of two identical sequences to exercise the batch dimension.
batch = torch.stack((inputs, inputs))
print(f"{batch.shape=}")

y = sa_v3(batch)
print(f"{y.shape=}")

batch.shape=torch.Size([2, 6, 3])
y.shape=torch.Size([2, 6, 2])


# Multi-head attention

Run the attention mechanism several times in parallel (one independent set of weights per head) and concatenate the per-head outputs.

## Less efficient implementation

Each head matmul is calculated separately

In [None]:
class MultiHeadAttention_v1(nn.Module):
  """Multi-head attention assembled from independent CausalAttention heads.

  Every head owns its own projections and is evaluated separately; the per-head
  outputs are concatenated along the feature dimension.
  """

  def __init__(self, d_in, d_out, context_length, num_heads, dropout, qkv_bias=False):
    super().__init__()
    assert d_out % num_heads == 0, "d_out must be divisible by num_heads!"

    # d_out is split evenly across the heads.
    head_dim = d_out // num_heads
    self.heads = nn.ModuleList()
    for _ in range(num_heads):
      self.heads.append(CausalAttention(d_in, head_dim, context_length, dropout, qkv_bias))

  def forward(self, x):
    per_head = [head(x) for head in self.heads]
    return torch.cat(per_head, dim=-1)

# Test

# Two heads of 8 features each (16 // 2), concatenated back to 16.
mha = MultiHeadAttention_v1(EMB_DIM, 16, context_length=inputs.shape[-2], num_heads=2, dropout=0.5)
y = mha(batch)
print(f"{y.shape=}")
# Batch and token dimensions are preserved; the feature dimension is d_out.
assert y.shape == (batch.shape[0], batch.shape[1], 16)

y.shape=torch.Size([2, 6, 16])


## Efficient implementation

All heads matmul are combined.

In [None]:
class MultiHeadAttention(nn.Module):
  """Causal multi-head self-attention with all heads computed in one batched matmul."""

  def __init__(self, d_in, d_out, context_length, num_heads, dropout, qkv_bias=False):
    """Ctor.

    Args:
      d_in: Input (embedding) dimension.
      d_out: Total output dimension across all heads; must be divisible by num_heads.
      context_length: Maximum sequence length; sets the size of the causal mask.
      num_heads: Number of attention heads.
      dropout: Dropout probability applied to the attention weights.
      qkv_bias: Whether the query/key/value projections use a bias term.
    """
    super().__init__()

    assert d_out % num_heads == 0, "d_out must be divisible by num_heads!"

    self.heads = num_heads
    self.head_dim = d_out // num_heads

    self.wk = nn.Linear(d_in, d_out, bias=qkv_bias) # E -> H
    self.wq = nn.Linear(d_in, d_out, bias=qkv_bias) # E -> H
    self.wv = nn.Linear(d_in, d_out, bias=qkv_bias) # E -> H
    self.dropout = nn.Dropout(p=dropout)
    # Mixes the concatenated head outputs.
    self.out_proj = nn.Linear(d_out, d_out)
    # Ones strictly above the diagonal mark the "future" positions to hide.
    self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

  def forward(self, x):
    """Forward.

    Args:
      x: [B, N, E]

    Returns:
      [B, N, H]
    """
    b, n, _ = x.shape

    k = self.wk(x) # [B, N, H]
    q = self.wq(x) # [B, N, H]
    v = self.wv(x) # [B, N, H]

    # Split H into (HEADS, HEAD_DIM) and move HEADS next to the batch dimension.
    k = k.view(b, n, self.heads, self.head_dim).transpose(1, 2) # [B, HEADS, N, HEAD_DIM]
    q = q.view(b, n, self.heads, self.head_dim).transpose(1, 2) # [B, HEADS, N, HEAD_DIM]
    v = v.view(b, n, self.heads, self.head_dim).transpose(1, 2) # [B, HEADS, N, HEAD_DIM]

    attn = q @ k.transpose(-1, -2) # [B, HEADS, N, N]
    assert attn.shape == (b, self.heads, n, n)

    # [:n, :n] is to truncate to the length of input tokens.
    attn = attn.masked_fill(self.mask.bool()[:n, :n], -torch.inf)

    # Scale by the per-head key width.
    attn = attn / self.head_dim ** 0.5
    attn = nn.functional.softmax(attn, dim=-1)
    attn = self.dropout(attn)

    res = attn @ v # [B, HEADS, N, HEAD_DIM]
    # Put the heads back side by side; transpose() makes the tensor
    # non-contiguous, so contiguous() is required before view().
    res = res.transpose(1, 2).contiguous().view(b, n, -1) # [B, N, H]

    return self.out_proj(res)  # [B, N, H]

# Test
# NOTE: rebinds `mha`, replacing the MultiHeadAttention_v1 instance created earlier.
mha = MultiHeadAttention(EMB_DIM, 16, context_length=inputs.shape[-2], num_heads=2, dropout=0.5)
y = mha(batch)
# Batch and token dimensions are preserved; the feature dimension is d_out.
assert y.shape == (batch.shape[0], batch.shape[1], 16)
print(f"{y.shape=}")

After causal: attn[0][0]=tensor([[-0.4106,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.6259, -0.3525,    -inf,    -inf,    -inf,    -inf],
        [-0.6083, -0.3261, -0.3280,    -inf,    -inf,    -inf],
        [-0.3780, -0.2722, -0.2715, -0.1127,    -inf,    -inf],
        [-0.1209,  0.2420,  0.2340,  0.1942,  0.0236,    -inf],
        [-0.5692, -0.5374, -0.5333, -0.2587, -0.3094, -0.3200]],
       grad_fn=<SelectBackward0>)
After softmax: attn[0][0]=tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4759, 0.5241, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3116, 0.3443, 0.3441, 0.0000, 0.0000, 0.0000],
        [0.2395, 0.2487, 0.2487, 0.2631, 0.0000, 0.0000],
        [0.1838, 0.2090, 0.2084, 0.2055, 0.1934, 0.0000],
        [0.1580, 0.1598, 0.1600, 0.1764, 0.1732, 0.1726]],
       grad_fn=<SelectBackward0>)
y.shape=torch.Size([2, 6, 16])
