<a href="https://colab.research.google.com/github/dominiksakic/zero_to_hero/blob/main/adv_03_transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GOALS:
- EX1: The n-dimensional tensor mastery challenge: Combine the `Head` and `MultiHeadAttention` into one class that processes all the heads in parallel, treating the heads as another batch dimension (answer is in nanoGPT). [CLEAR]
- EX2: Train the GPT on your own dataset of choice! What other data could be fun to blabber on about? (A fun advanced suggestion if you like: train a GPT to do addition of two numbers, i.e. a+b=c. You may find it helpful to predict the digits of c in reverse order, as the typical addition algorithm (that you're hoping it learns) would proceed right to left too. You may want to modify the data loader to simply serve random problems and skip the generation of train.bin, val.bin. You may want to mask out the loss at the input positions of a+b that just specify the problem using y=-1 in the targets (see CrossEntropyLoss ignore_index). Does your Transformer learn to add? Once you have this, swole doge project: build a calculator clone in GPT, for all of +-*/. Not an easy problem. You may need Chain of Thought traces.)
- EX3: Find a dataset that is very large, so large that you can't see a gap between train and val loss. Pretrain the transformer on this data, then initialize with that model and finetune it on tiny shakespeare with a smaller number of steps and lower learning rate. Can you obtain a lower validation loss by the use of pretraining?
- EX4: Read some transformer papers and implement one additional feature or change that people seem to use. Does it improve the performance of your GPT?

In [1]:
## STARTING CODE ##
import math
import torch
import torch.nn as nn
from torch.nn import functional as F

## LOAD DATA
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt


with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

## PREP DATA
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {ch:i  for i, ch in enumerate(chars)}
itos = {i : ch for i, ch in enumerate(chars)}

encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

data = torch.tensor(encode(text), dtype=torch.long)

n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

--2025-08-03 07:56:37--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-08-03 07:56:39 (168 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [2]:
# Deep dive of how the current implementation:

# out = torch.cat([h(x) for h in self.heads], dim=-1)

# input
B, T , C = 2, 4, 8
n_head  = 2
x = torch.randn(B, T, C)

print("head-1")
head_size = C // n_head
key = nn.Linear(8, head_size, bias=False)
query = nn.Linear(8, head_size, bias=False)
value = nn.Linear(8, head_size, bias=False)

B, T, C = x.shape
k = key(x)   # (B,T,hs)
print(f'k shape: {k.shape}')
q = query(x) # (B,T,hs)
print(f'q shape: {q.shape}')

wei = q @ k.transpose(-2, -1) * (C ** -0.5)  # (B,T,T)
wei = wei.masked_fill(torch.tril(torch.ones(T, T, device=x.device)) == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)

v = value(x)  # (B,T,hs)
out_h1 = wei @ v

print("head-2")
head_size = C // n_head
key = nn.Linear(8, head_size, bias=False)
query = nn.Linear(8, head_size, bias=False)
value = nn.Linear(8, head_size, bias=False)

B, T, C = x.shape
k = key(x)   # (B,T,hs)
print(f'k shape: {k.shape}')
q = query(x) # (B,T,hs)
print(f'q shape: {q.shape}')

wei = q @ k.transpose(-2, -1) * (C ** -0.5)  # (B,T,T)
wei = wei.masked_fill(torch.tril(torch.ones(T, T, device=x.device)) == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)

v = value(x)  # (B,T,hs)
out_h2 = wei @ v

print(out_h1.shape)
print(out_h2.shape)

result = torch.cat([out_h1, out_h2], dim=-1)
print(result.shape)

head-1
k shape: torch.Size([2, 4, 4])
q shape: torch.Size([2, 4, 4])
head-2
k shape: torch.Size([2, 4, 4])
q shape: torch.Size([2, 4, 4])
torch.Size([2, 4, 4])
torch.Size([2, 4, 4])
torch.Size([2, 4, 8])


In [3]:
### TOY EXAMPLE - STEP 1:
B, T , C = 2, 4, 8
n_head  = 2
head_dim = C // n_head
x = torch.randn(B, T, C)

# Step 2 One big QKV Projection
qkv_proj = torch.nn.Linear(C, 3 * C, bias=False)
qkv = qkv_proj(x)  # (B, T, 3 * C)
print(qkv.shape)

# Step 3: Split Into Q, K, V
q, k, v = qkv.chunk(3, dim=-1)
print(f'q shape: {q.shape}')
print(f'k shape: {k.shape}')
print(f'v shape: {v.shape}')


# Step-by-step:
print(f"\nq shape before with the two heads combined: {q.shape}")
print(f"From (B, T, C) to (B, T, n_head, head_dim)")
print(f"split the two heads: {q.view(B, T, n_head, head_dim).shape}")
print(f"# Then to (B, n_head, T, head_dim)")
print(f"rearange the two heads: {q.view(B, T, n_head, head_dim).transpose(1, 2).shape}")

q = q.view(B, T, n_head, head_dim).transpose(1, 2)  # (B, n_head, T, head_dim)
k = k.view(B, T, n_head, head_dim).transpose(1, 2)
v = v.view(B, T, n_head, head_dim).transpose(1, 2)

# Step 5: Scaled Dot Product Attention (Parrallel)
att = (q @ k.transpose(-2, -1)) / (head_dim ** 0.5)  # (B, n_head, T, T)
mask = torch.tril(torch.ones(T, T)).to(x.device)
att = att.masked_fill(mask == 0, float('-inf'))
att = F.softmax(att, dim=-1)

# Step 6: Weighted Sum over Values
out = att @ v  # (B, n_head, T, head_dim)
print(f"\nVal after attention: {out.shape}")


# Step 7 Merge Heads Back Together
print(f'\n Rearange the tensor to (B, T, n_head, n_dim): {out.transpose(1, 2).shape}')
print(f'\n From (B, T, n_head, n_dim) to (B, T, C): {out.transpose(1, 2).contiguous().view(B, T, C).shape}')
out = out.transpose(1, 2).contiguous().view(B, T, C)  # (B, T, C)

torch.Size([2, 4, 24])
q shape: torch.Size([2, 4, 8])
k shape: torch.Size([2, 4, 8])
v shape: torch.Size([2, 4, 8])

q shape before with the two heads combined: torch.Size([2, 4, 8])
From (B, T, C) to (B, T, n_head, head_dim)
split the two heads: torch.Size([2, 4, 2, 4])
# Then to (B, n_head, T, head_dim)
rearange the two heads: torch.Size([2, 2, 4, 4])

Val after attention: torch.Size([2, 2, 4, 4])

 Rearange the tensor to (B, T, n_head, n_dim): torch.Size([2, 4, 2, 4])

 From (B, T, n_head, n_dim) to (B, T, C): torch.Size([2, 4, 8])


In [4]:
# Side Quest Contiguous vs Non Contiguous Tensors

"""
Contiguous tensor is where the data is stored in a memory block in go go.

Non Contiguous tensor us where the shape of the data dosent match how the data is laid out.
"""

x = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(x)
print(x.is_contiguous())
print(x.view(-1))


x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = x.t()  # Transpose it → now shape (3, 2)
print(y)
print(y.is_contiguous())

print("\nUse contiguous to use view")
y = y.contiguous()
y.view(-1)

print("\nCheck Memory Layout with .stride()")
x = torch.randn(2, 3)
print(x)
print(x.stride())         # (3, 1)
print(x.t().stride())     # (1, 3)

tensor([[1, 2, 3],
        [4, 5, 6]])
True
tensor([1, 2, 3, 4, 5, 6])
tensor([[1, 4],
        [2, 5],
        [3, 6]])
False

Use contiguous to use view

Check Memory Layout with .stride()
tensor([[ 1.9366, -1.1943, -0.2380],
        [ 1.2358,  0.8549,  0.4589]])
(3, 1)
(1, 3)


In [5]:
# Strides
import torch

a = torch.tensor([1])
print("\nShape:", a.shape)
print("Strides:", a.stride())
print(a)

a = torch.tensor([1,2])
print("\nShape:", a.shape)
print("Strides:", a.stride())
print(a)

a = torch.tensor([[1],[1]])
print("\nShape:", a.shape)
print("Strides:", a.stride())
print(a)

a = torch.tensor([[1,2],[1,2]])
print("\nShape:", a.shape)
print("Strides:", a.stride())
print(a)

a = torch.tensor([[1,2,3],[1,2,3]])
print("\nShape:", a.shape)
print("Strides:", a.stride())
print(a)

a = torch.tensor([[1,2,3],[1,2,3]])
print("\nShape:", a.shape)
print("Strides:", a.stride())
print(a)

a = torch.tensor([[[1],[1]],[[1],[1]]])
print("\nShape:", a.shape)
print("Strides:", a.stride())
print(a)

a = torch.tensor([[[1,2],[1,2]],[[1,2],[1,2]]])
print("\nShape:", a.shape)
print("Strides:", a.stride())
print(a)


Shape: torch.Size([1])
Strides: (1,)
tensor([1])

Shape: torch.Size([2])
Strides: (1,)
tensor([1, 2])

Shape: torch.Size([2, 1])
Strides: (1, 1)
tensor([[1],
        [1]])

Shape: torch.Size([2, 2])
Strides: (2, 1)
tensor([[1, 2],
        [1, 2]])

Shape: torch.Size([2, 3])
Strides: (3, 1)
tensor([[1, 2, 3],
        [1, 2, 3]])

Shape: torch.Size([2, 3])
Strides: (3, 1)
tensor([[1, 2, 3],
        [1, 2, 3]])

Shape: torch.Size([2, 2, 1])
Strides: (2, 1, 1)
tensor([[[1],
         [1]],

        [[1],
         [1]]])

Shape: torch.Size([2, 2, 2])
Strides: (4, 2, 1)
tensor([[[1, 2],
         [1, 2]],

        [[1, 2],
         [1, 2]]])


In [6]:
y = torch.arange(0,12).view(2,2,3)
# Expect [[0, 1, 2],[3, 4, 5]]  [[6,7,8],[9, 10 ,11]]
print(y.stride())
print("i+1 travels along the axis 0, i+3 travels along axis 1, i+6 along axis 2")

(6, 3, 1)
i+1 travels along the axis 0, i+3 travels along axis 1, i+6 along axis 2


```
Index in memory   Corresponds to y[...]        Value
-----------------------------------------------------
0                 y[0, 0, 0]                   0
1                 y[0, 0, 1]                   1
2                 y[0, 0, 2]                   2
3                 y[0, 1, 0]                   3
4                 y[0, 1, 1]                   4
5                 y[0, 1, 2]                   5
6                 y[1, 0, 0]                   6
7                 y[1, 0, 1]                   7
8                 y[1, 0, 2]                   8
9                 y[1, 1, 0]                   9
10                y[1, 1, 1]                  10
11                y[1, 1, 2]                  11

```

- Axis 0 stride: +6 (0 → 6, 3 → 9)

- Axis 1 stride: +3 (0 → 3, 6 → 9)

- Axis 2 stride: +1 (0 → 1 → 2, etc.)

- Memory memory offset = i * 6 + j * 3 + k * 1
- i = 0 , j = 0, k = 0
- Offset is i * 6 + j * 3 + k * 1
- value 0
- i = 0 , j = 0, k = 1
- Offset is i * 6 + j * 3 + 1 * 1
- value 1
...

In [7]:
y_t = y.transpose(0, 1)
print(y_t.shape)
print(y_t.stride())
print(y_t)

torch.Size([2, 2, 3])
(3, 6, 1)
tensor([[[ 0,  1,  2],
         [ 6,  7,  8]],

        [[ 3,  4,  5],
         [ 9, 10, 11]]])


In [8]:
x = torch.arange(12)
print(x.view(6,2))
print(x.view(6,2).stride())
print("transpose 0,1")
print(x.view(6,2).transpose(0,1))
print(x.view(6,2).transpose(0,1).stride())

tensor([[ 0,  1],
        [ 2,  3],
        [ 4,  5],
        [ 6,  7],
        [ 8,  9],
        [10, 11]])
(2, 1)
transpose 0,1
tensor([[ 0,  2,  4,  6,  8, 10],
        [ 1,  3,  5,  7,  9, 11]])
(1, 2)


In [18]:
def apply_rope(q, k):
    # q, k: (B, T, C), where C must be even
    B, T, C = q.shape
    half = C // 2
    freqs = torch.exp(-torch.arange(0, half, dtype=torch.float32) * math.log(10000) / half).to(q.device)  # (half,)
    positions = torch.arange(T, device=q.device).float()  # (T,)
    angles = torch.einsum('t,d->td', positions, freqs)  # (T, half)
    sin = angles.sin().unsqueeze(0)  # (1, T, half)
    cos = angles.cos().unsqueeze(0)  # (1, T, half)

    q1, q2 = q[..., :half], q[..., half:]
    k1, k2 = k[..., :half], k[..., half:]
    q_rotated = torch.cat([q1 * cos - q2 * sin, q1 * sin + q2 * cos], dim=-1)
    k_rotated = torch.cat([k1 * cos - k2 * sin, k1 * sin + k2 * cos], dim=-1)
    return q_rotated, k_rotated

In [19]:
batch_size = 16
block_size = 32
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 64
n_head = 4
n_layer = 4
dropout = 0.0

torch.manual_seed(1337)


@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

In [20]:
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)   # (B,T,hs)
        q = self.query(x) # (B,T,hs)

        # 🌀 Apply RoPE here!
        q, k = apply_rope(q, k)

        wei = q @ k.transpose(-2, -1) * (C ** -0.5)  # (B,T,T)
        wei = wei.masked_fill(torch.tril(torch.ones(T, T, device=x.device)) == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        v = self.value(x)  # (B,T,hs)
        out = wei @ v      # (B,T,hs)
        return out


class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()

        # Goal: (B, T, C) → (B, T, n_head, head_dim)

        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out


class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

In [21]:
class GPTLanguageModelRoPE(nn.Module):

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

        # better init, not covered in the original GPT video, but important, will cover in followup video
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        x = self.blocks(tok_emb) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

model = GPTLanguageModelRoPE()
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

0.207681 M parameters


In [22]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [23]:
for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 4.1342, val loss 4.1351
step 100: train loss 2.5553, val loss 2.5588
step 200: train loss 2.3960, val loss 2.4000
step 300: train loss 2.2726, val loss 2.2647
step 400: train loss 2.1925, val loss 2.2111
step 500: train loss 2.1081, val loss 2.1352
step 600: train loss 2.0624, val loss 2.1023
step 700: train loss 2.0100, val loss 2.0619
step 800: train loss 1.9947, val loss 2.0386
step 900: train loss 1.9528, val loss 2.0154
step 1000: train loss 1.9155, val loss 1.9816
step 1100: train loss 1.8957, val loss 1.9876
step 1200: train loss 1.8621, val loss 1.9671
step 1300: train loss 1.8574, val loss 1.9641
step 1400: train loss 1.8243, val loss 1.9581
step 1500: train loss 1.8157, val loss 1.9177
step 1600: train loss 1.8071, val loss 1.9367
step 1700: train loss 1.7676, val loss 1.9163
step 1800: train loss 1.7758, val loss 1.9300
step 1900: train loss 1.7566, val loss 1.8964
step 2000: train loss 1.7513, val loss 1.8910
step 2100: train loss 1.7347, val loss 1.8899


In [24]:
# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))




CORIORS:
Nock?

STANLO:
Gree
Your take you encatizen my dage and bard thy put him to bard:
Are ane away, my fears accezorous
Yorks, to that I commillod!

ARCAS:
Helps, I in latter, drop the deep me nor timerabs!

AUFORD:
Though she couriby:
Suppiniss him young to send more signion.

DUKE OF GARS:

Go Willon him eimselves thought will no deeds poor of his but than nuntry brink;
And he must with allook, figh Prince, So my Had so, are you as
ards brings and thining mischard.

CRIUTIO:
Farew thoug


In [10]:
batch_size = 16
block_size = 32
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 64
n_head = 4
n_layer = 4
dropout = 0.0

torch.manual_seed(1337)


@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

In [11]:
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()

        # Goal: (B, T, C) → (B, T, n_head, head_dim)
        self.n_head = n_head
        self.head_size = head_size
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

        # self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])

    def forward(self, x):
        k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2)  # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2)  # (B, nh, T, hs)

        q, k = apply_rope(q, k)  # (B, nh, T, hs)

        wei = q @ k.transpose(-2, -1) * (self.head_size ** -0.5)  # (B, nh, T, T)
        wei = wei.masked_fill(torch.tril(torch.ones(T, T, device=x.device)) == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2)  # (B, nh, T, hs)
        out = wei @ v  # (B, nh, T, hs)
        out = out.transpose(1, 2).contiguous().view(B, T, self.n_head * self.head_size)  # (B, T, C)
        return self.dropout(self.proj(out))


class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

In [12]:
class GPTLanguageModelRoPE(nn.Module):

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

        # better init, not covered in the original GPT video, but important, will cover in followup video
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        x = self.blocks(tok_emb) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

model = GPTLanguageModelRoPE()
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

0.170817 M parameters


In [15]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [16]:
for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 4.2168, val loss 4.2160
step 100: train loss 2.5629, val loss 2.5662
step 200: train loss 2.3928, val loss 2.3917
step 300: train loss 2.2762, val loss 2.2927
step 400: train loss 2.1660, val loss 2.1880
step 500: train loss 2.1111, val loss 2.1362
step 600: train loss 2.0661, val loss 2.1055
step 700: train loss 2.0084, val loss 2.0638
step 800: train loss 1.9688, val loss 2.0263
step 900: train loss 1.9484, val loss 2.0119
step 1000: train loss 1.9265, val loss 2.0104
step 1100: train loss 1.9016, val loss 1.9851
step 1200: train loss 1.8652, val loss 1.9619
step 1300: train loss 1.8453, val loss 1.9359
step 1400: train loss 1.8210, val loss 1.9163
step 1500: train loss 1.8061, val loss 1.9274
step 1600: train loss 1.7991, val loss 1.9350
step 1700: train loss 1.7722, val loss 1.9135
step 1800: train loss 1.7670, val loss 1.8913
step 1900: train loss 1.7542, val loss 1.8879
step 2000: train loss 1.7535, val loss 1.8984
step 2100: train loss 1.7462, val loss 1.8882


In [17]:
# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))


Will before will affect to be made to bear to take OF IS:
Falleans bard thy pusquest. that dilahoate away, my feans, to zonour
Young-moof it her, this nowl, with is ensent, will is till vising,
Which now till about like dise of the mild him speak; and the tyban'st,
He mayore sign sweet
the sclike again Willo when evil so, and doubt, The shire
sto-LiKe him to
he kindness firs son; if his shate.

RICHARD:
Your Prince, Sometranger and tooblisa
ards be his greature hence as is no very sold
To for th


# Result:
- First Experiment - No Parallel heads
  - final step train loss 1.5858, val loss 1.7777
  - Execution Time: ~ 8min 30

- Second Experiment - Parallel heads:
  - final step: train loss 1.5866, val loss 1.7530
  - Execution Time: ~ 8min 30

# Observation
- Sequential heads have more parameters: 3 x n_head Linear layers.

