<a href="https://colab.research.google.com/github/chefPony/nn_zero_to_hero/blob/master/gpt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Building a Transformer from scratch

## Load Data

In [1]:
# We always start with a dataset to train on. Let's download the tiny shakespeare dataset
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2024-09-22 13:12:28--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2024-09-22 13:12:28 (18.1 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [2]:
with open("input.txt", "r", encoding="utf-8") as f:
    text = f.read()

In [3]:
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [4]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {c: i for i, c in enumerate(chars)}
itos = {i: c for c, i in stoi.items()}

encode = lambda x: [stoi[c] for c in x]
decode = lambda x: [itos[i] for i in x]
print(vocab_size)
print("".join(chars))
print(itos)
print(stoi)

65

 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
{0: '\n', 1: ' ', 2: '!', 3: '$', 4: '&', 5: "'", 6: ',', 7: '-', 8: '.', 9: '3', 10: ':', 11: ';', 12: '?', 13: 'A', 14: 'B', 15: 'C', 16: 'D', 17: 'E', 18: 'F', 19: 'G', 20: 'H', 21: 'I', 22: 'J', 23: 'K', 24: 'L', 25: 'M', 26: 'N', 27: 'O', 28: 'P', 29: 'Q', 30: 'R', 31: 'S', 32: 'T', 33: 'U', 34: 'V', 35: 'W', 36: 'X', 37: 'Y', 38: 'Z', 39: 'a', 40: 'b', 41: 'c', 42: 'd', 43: 'e', 44: 'f', 45: 'g', 46: 'h', 47: 'i', 48: 'j', 49: 'k', 50: 'l', 51: 'm', 52: 'n', 53: 'o', 54: 'p', 55: 'q', 56: 'r', 57: 's', 58: 't', 59: 'u', 60: 'v', 61: 'w', 62: 'x', 63: 'y', 64: 'z'}
{'\n': 0, ' ': 1, '!': 2, '$': 3, '&': 4, "'": 5, ',': 6, '-': 7, '.': 8, '3': 9, ':': 10, ';': 11, '?': 12, 'A': 13, 'B': 14, 'C': 15, 'D': 16, 'E': 17, 'F': 18, 'G': 19, 'H': 20, 'I': 21, 'J': 22, 'K': 23, 'L': 24, 'M': 25, 'N': 26, 'O': 27, 'P': 28, 'Q': 29, 'R': 30, 'S': 31, 'T': 32, 'U': 33, 'V': 34, 'W': 35, 'X': 36, 'Y': 37, 'Z': 38, 'a': 39, 'b

In [5]:
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

encoded_data = torch.tensor([stoi[c] for c in text])
print(encoded_data.shape)
print(encoded_data[:1000])

train_data = encoded_data[:int(len(text) * 0.9)]
val_data = encoded_data[int(len(text) * 0.9):]
print(train_data.shape)
print(val_data.shape)

torch.Size([1115394])
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
        47, 59, 57

In [6]:
torch.manual_seed(42)

def get_batch(data: torch.Tensor, batch_size: int, block_size: int):
  ix = torch.randint(low=0, high=data.shape[0] - block_size, size=(batch_size,))
  xb = torch.stack([data[i : i+block_size] for i in ix]).to(device)
  yb = torch.stack([data[i+1 : i+block_size+1] for i in ix]).to(device)
  return xb, yb

@torch.no_grad()
def evaluate_model(model, batch_size, block_size, num_batches=100):
  loss_tr, loss_va = torch.zeros((num_batches, )), torch.zeros((num_batches, ))
  model.eval()
  for k in range(num_batches):
    xb_tr, yb_tr = get_batch(train_data, batch_size, block_size)
    xb_va, yb_va = get_batch(val_data, batch_size, block_size)
    yhtr, loss_tr[k] = model(xb_tr, yb_tr)
    yhva, loss_va[k] = model(xb_va, yb_va)
  model.train()
  return loss_tr.mean().item(), loss_va.mean().item()

xb, yb = get_batch(encoded_data, 4, 8)
print(xb.shape)
print(yb.shape)
print("batch")
print(xb)
print("target")
print(yb)

torch.Size([4, 8])
torch.Size([4, 8])
batch
tensor([[42,  1, 58, 46, 59, 57,  1, 21],
        [54, 56, 47, 43, 57, 58, 11,  0],
        [49, 47, 52, 45, 12,  1, 58, 46],
        [58, 46, 53, 59, 58,  1, 56, 43]], device='cuda:0')
target
tensor([[ 1, 58, 46, 59, 57,  1, 21,  1],
        [56, 47, 43, 57, 58, 11,  0, 37],
        [47, 52, 45, 12,  1, 58, 46, 53],
        [46, 53, 59, 58,  1, 56, 43, 42]], device='cuda:0')


## Baseline: Bigram Model

In [7]:
import torch.nn as nn
from torch.nn import functional as F

class BigramModel(nn.Module):

  def __init__(self, vocab_size: int):
    super().__init__()
    self.logits = nn.Embedding(num_embeddings=vocab_size, embedding_dim=vocab_size)


  def forward(self, x, targets=None):
    logits = self.logits(x) # (B, T, C)
    if targets is None:
      loss = None
    else:
       B, T, C = logits.shape
       loss = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))
    return logits, loss

  @torch.no_grad
  def generate(self, idx, max_new_tokens):
    B, T = idx.shape
    for _ in range(max_new_tokens):
      logits, _ = self(idx[:, -T:])  #(B, n_vocab)
      probs = F.softmax(logits[:, -1, :], dim=-1)
      next_idx = torch.multinomial(probs, 1)
      idx = torch.cat([idx, next_idx], dim=1)
    return idx


xb, yb = get_batch(train_data, batch_size=8, block_size=2)
bigram = BigramModel(vocab_size).to(device)
print(bigram(xb, yb))
print("".join(decode(bigram.generate(torch.tensor([[0 , 0]], device=device), 100).squeeze().tolist())))

(tensor([[[-0.1827,  0.0524, -1.8020,  ..., -0.4538,  0.6346, -1.4856],
         [-1.1441,  0.3383,  1.6992,  ...,  0.9254,  1.4805,  0.3449]],

        [[-0.1827,  0.0524, -1.8020,  ..., -0.4538,  0.6346, -1.4856],
         [-1.2800,  0.1359, -1.2744,  ...,  1.1272,  0.5445, -0.2186]],

        [[-1.0800,  1.4510, -0.3488,  ...,  2.1158,  0.2643, -0.2391],
         [ 0.4121, -1.9089, -0.0616,  ..., -0.6875,  0.2056, -0.7192]],

        ...,

        [[-0.4899, -1.5937,  0.9481,  ...,  0.8930,  1.6673, -1.0136],
         [ 2.5165, -0.0862,  0.1101,  ..., -0.6886, -0.2301,  0.0784]],

        [[ 1.0688,  1.0354, -1.0889,  ..., -0.9309, -0.7496, -1.1346],
         [-1.1441,  0.3383,  1.6992,  ...,  0.9254,  1.4805,  0.3449]],

        [[ 0.4310, -0.2231,  0.2790,  ..., -0.3801, -0.2620, -0.5226],
         [-1.0800,  1.4510, -0.3488,  ...,  2.1158,  0.2643, -0.2391]]],
       device='cuda:0', grad_fn=<EmbeddingBackward0>), tensor(4.6520, device='cuda:0', grad_fn=<NllLossBackward0>))


n
3

In [8]:
lr = 1.
block_size = 2
batch_size = 16
n_iter = 20000
model = BigramModel(vocab_size).to(device)

optimizer = torch.optim.Adam(params=model.parameters(), lr=1.)

for i in range(n_iter):
  xb, yb = get_batch(train_data, batch_size, 2)
  _, loss = model(xb, yb)

  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()

  if i % 1000 == 0:
    loss_tr, loss_va = evaluate_model(model, batch_size, block_size)
    print(f"{i}/{n_iter} Train: {loss_tr:.4f} Validation: {loss_va:.4f}")

0/20000 Train: 4.3754 Validation: 4.4458
1000/20000 Train: 3.7335 Validation: 3.9622
2000/20000 Train: 3.7469 Validation: 3.9400
3000/20000 Train: 3.6187 Validation: 3.6909
4000/20000 Train: 3.8301 Validation: 3.8622
5000/20000 Train: 3.8224 Validation: 3.9138
6000/20000 Train: 3.6752 Validation: 3.9924
7000/20000 Train: 3.8573 Validation: 3.9302
8000/20000 Train: 3.7383 Validation: 3.7476
9000/20000 Train: 3.7376 Validation: 3.9901
10000/20000 Train: 3.9125 Validation: 4.0334
11000/20000 Train: 3.9315 Validation: 3.9911
12000/20000 Train: 3.6196 Validation: 3.7821
13000/20000 Train: 3.7282 Validation: 3.9559
14000/20000 Train: 3.9681 Validation: 4.1536
15000/20000 Train: 3.7993 Validation: 4.0437
16000/20000 Train: 3.7760 Validation: 3.8002
17000/20000 Train: 3.8936 Validation: 3.8511
18000/20000 Train: 3.5515 Validation: 3.6870
19000/20000 Train: 3.6228 Validation: 3.8555


In [9]:
print("".join(decode(model.generate(torch.zeros((1, 1), device=device).long(), 1000).squeeze().tolist())))


Thin inorde
WA falat mat E:
Anothot thind me.
A ngur, hat ghe me machavy-g houe.

Fre nge be hore ckese we WA m'st RD:
He avy qur, Bethas?

WAR ming IN at laminorn GHe hoo R het R ROMartharng R ackest R R GHe bbele he qud pe, R t mue t he hy R R Rellllll hooring ham'stinot whar, w ste r, aknorinorvy st WAnode hivast me.
te
Thareas?
WA R hing astr, atessotrinorulllam'shookel sellambe Rit at fatinornesorirorud amby t d I R hone fullld cke st hape d R in
SThorstrt nostng nortartelamard R STue.
He t n Rinorin I:
acke g amastape
Anoke WAR trit tre het
Woonot RCl GHe ambule WAR faiche d E:
Thorncke R RClle asorin R trorin he.

Thal ID:
Yorin tiee RWARist teloowe at le d WAD qud lstelotullllst R t fad, g st t atisor amering merorin nothe d nor WAR anorsod n ne harickn
He forexchist What atot norstrinootckingr ne;
WAR fusort pe am'st amam'st avy, I g soouetueld ortie;
Whe Be amallatrngo g hackior nerinok R Mavy WARer Whothelld R it d t R Whasheleloo grustinosorie mir d havy:

Thort R Wame t t

## Transformer model

In [10]:
import math

class HeadAttention(nn.Module):

  def __init__(self, block_size, n_embd, head_size, dropout=0):
    super().__init__()
    self.n_embd = n_embd
    self.head_size = head_size
    self.K = nn.Linear(self.n_embd, self.head_size, bias=False)
    self.Q = nn.Linear(self.n_embd, self.head_size, bias=False)
    self.V = nn.Linear(self.n_embd, self.head_size, bias=False)
    self.dropout = nn.Dropout(dropout)
    self.register_buffer("tril", torch.tril(torch.ones((block_size, block_size))))

  def forward(self, x, y=None):
    B, T, C = x.shape
    # x[i, j, :] stores the char and positional info for token[i, j]
    # K, Q layers use that information to compute the query and key for each token
    k, q = self.K(x), self.Q(x) # (B, T, H), (B, T, H)
    v = self.V(x) # (B, T, v)
    x = q @ k.transpose(-2, -1) * (self.head_size**-0.5) # (B, T, T)
    # mask to disable flow of information from future tokens
    x = x.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
    x = F.softmax(x, dim=-1)
    x = self.dropout(x)
    x = x @ v # (B, T, v)
    return x

class MultiHeadAttention(nn.Module):

  def __init__(self, block_size, n_embd, n_heads, dropout=0):
    super().__init__()
    self.head_size = n_embd // n_heads
    self.n_heads = n_heads
    self.n_embd = n_embd
    self.heads = nn.ModuleList([
        HeadAttention(block_size, self.n_embd, self.head_size, dropout=dropout)
        for _ in range(n_heads)])
    self.proj = nn.Linear(self.n_embd, self.n_embd)

  def forward(self, x):
    x = torch.cat([h(x) for h in self.heads], dim=-1)
    x = self.proj(x)
    return x

class LayerNorm(nn.Module):

  def __init__(self, dim, eps=1e-8):
    super().__init__()
    self.eps = eps
    self.gamma = torch.ones(dim, device=device)
    self.beta = torch.zeros(dim, device=device)

  def forward(self, x):
    mu = torch.mean(x, dim=-1, keepdim=True) # (B, T, 1)
    # The paper used the non corrected variance
    sigma = torch.var(x, dim=-1, correction=0, keepdim=True) # (B, T, 1)
    # (B, T, C) (B)
    x = (x - mu) * (sigma + self.eps)**-0.5 * self.gamma + self.beta
    return x

class FeedForward(nn.Module):

  def __init__(self, input_dim, hidden_dim, output_dim, activation, dropout=0):
    super().__init__()
    self.input_dim = input_dim
    self.hidden_dim = hidden_dim
    self.output_dim = output_dim
    self.activation = activation
    self.net = nn.Sequential(*[
        nn.Linear(self.input_dim, self.hidden_dim),
        activation,
        nn.Linear(self.hidden_dim, self.output_dim),
        nn.Dropout(dropout)
    ])

  def forward(self, x):
    return self.net(x)

class Block(nn.Module):

  def __init__(self, block_size, n_embd, n_heads, dropout=0):
    super().__init__()
    head_size = n_embd // n_heads
    self.ma = MultiHeadAttention(block_size, n_embd, n_heads, dropout)
    self.fa = FeedForward(n_embd, n_embd, n_embd, nn.GELU(), dropout)
    self.ln1 = LayerNorm(n_embd)
    self.ln2 = LayerNorm(n_embd)

  def forward(self, x):
    # skip connections
    x = x + self.ma(self.ln1(x))
    x = x + self.fa(self.ln2(x))
    return x

class Transformer(nn.Module):

  def __init__(self, block_size, n_embd, n_blocks, n_heads, dropout=0):
    super().__init__()
    self.block_size = block_size
    self.C = nn.Embedding(vocab_size, n_embd)
    self.P = nn.Embedding(block_size, n_embd)
    self.blocks = nn.Sequential(*[
        Block(block_size, n_embd, n_heads, dropout)
        for _ in range(n_blocks)])
    self.ln = LayerNorm(n_embd)
    self.lm_head = nn.Linear(n_embd, vocab_size)
    self.register_buffer("pos", torch.arange(0, self.block_size, 1))

  def forward(self, x, y=None):
    B, T = x.shape
    x = self.C(x) + self.P(self.pos[:T])
    x = self.blocks(x)
    x = self.ln(x)
    logits = self.lm_head(x)
    if y is None:
      loss = None
    else:
      B, T, C = logits.shape
      loss = F.cross_entropy(logits.view(B*T, C), y.view(B*T))
    return logits, loss

  @torch.no_grad
  def generate(self, idx, max_new_tokens=100):
    self.eval()
    for _ in range(max_new_tokens):
      logits, _ = self(idx[:, -self.block_size:])  #(B, n_vocab)
      probs = F.softmax(logits[:, -1, :], dim=-1)
      next_idx = torch.multinomial(probs, num_samples=1)
      idx = torch.cat([idx, next_idx], dim=1)
    self.train()
    return idx

x = torch.randint(0, 65, size=(4, 8), device=device)
l = Transformer(8, 16, 2, 2).to(device)
l(x)[0].shape

out = l.generate(torch.zeros((1, 1), device=device, dtype=torch.long), 100)
for o in out:
  print("".join(decode(o.tolist())))
  print("-------------------------")


FQiXg ;GYK;QhCXgZzmSj:jJnSEVi,Pd,AaxSsKB,mG&'h;:jccYr?nw,e!I,CY-3cZ UJ'l?FFdmJrakVlWzNMzESAf;PEECwVF
-------------------------


In [11]:
n_iter = 5000
batch_size = 32
block_size = 32
embd = 64
n_blocks = 4
n_heads = 4
dropout = 0.
eval_interval = 200
eval_batches = 200

print(device)
model = Transformer(block_size, embd, n_blocks, n_heads, dropout).to(device)
print(f"Parameters {sum(p.numel() for p in model.parameters())}")
model.to(device)

optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

for i in range(n_iter):
  xb, yb = get_batch(train_data, batch_size, block_size)
  yhat, loss = model(xb, yb)
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()

  if i % eval_interval == 0:
    loss_tr, loss_va = evaluate_model(model, batch_size, block_size, eval_batches)
    print(f"Step {i}/{n_iter} Train: {loss_tr:.4f} Validation: {loss_va:.4f}")

cuda
Parameters 109505
Step 0/5000 Train: 4.2083 Validation: 4.2181
Step 200/5000 Train: 2.4857 Validation: 2.4891
Step 400/5000 Train: 2.3223 Validation: 2.3298
Step 600/5000 Train: 2.2150 Validation: 2.2328
Step 800/5000 Train: 2.1209 Validation: 2.1566
Step 1000/5000 Train: 2.0502 Validation: 2.0925
Step 1200/5000 Train: 2.0008 Validation: 2.0581
Step 1400/5000 Train: 1.9486 Validation: 2.0264
Step 1600/5000 Train: 1.9158 Validation: 2.0032
Step 1800/5000 Train: 1.8762 Validation: 1.9796
Step 2000/5000 Train: 1.8500 Validation: 1.9584
Step 2200/5000 Train: 1.8223 Validation: 1.9590
Step 2400/5000 Train: 1.8030 Validation: 1.9336
Step 2600/5000 Train: 1.7747 Validation: 1.9218
Step 2800/5000 Train: 1.7675 Validation: 1.9028
Step 3000/5000 Train: 1.7581 Validation: 1.8790
Step 3200/5000 Train: 1.7283 Validation: 1.8719
Step 3400/5000 Train: 1.7232 Validation: 1.8781
Step 3600/5000 Train: 1.7156 Validation: 1.8539
Step 3800/5000 Train: 1.7008 Validation: 1.8453
Step 4000/5000 Train: 1.

In [12]:
v = torch.zeros((1, 1), device=device).long()
out = model.generate(v, 1000)
print("".join(decode(out[0, :].tolist())))
print("-------------------------")


Mird, o' be alie; pointise and stray unler the that hares of these rian, he hence's lake, caDe,--
I'll be joy!

POLINGBRET:
Morest you she whom of this
For may some thou haven shay the has was great!
As are three thus jourmer; there thou go, besing toble.
Let deliff my norts, he Towgs? weak one bound acconstandines years and not for itell come;
Nid then matter was he dow would mother is his madand mosiness of him.

Securs Murdencehed actit rest.
Therefore Send well, liboodver her unno say
Untitineven
WoUt another persite on as this sun,
Dett to thumb. Cawilly thou younding of thee
one you but can the're's nove sake Strorse,
Sicicilt detule meets to the changes reds steechar?
'Thy is those no hole the are bafore ansford?
He daid noble frear dies lo;
And I am atthe stake it out all gentlent commindedness,
You queeth'd mend het's lead: my aragely bother.

YORK:
BENVINt:
I'rry, and Lords this desere slight.

MERCUTIO:
No, nustievos whithe shall up, and abins formoth!
What lowers them oppe