In [1]:
# Compared to previous version
  # multiple head attention (vs single attention head in previous version)
  # positional encoding (wasn't implemented in previous version)
  # multiple transformer blocks (vs single block)
  # larger embedding dimention (512 vs 256)
  # nonlinearity in transformer's feedforward part + 2 linear layers instead of 1
  # pre-LayerNorm which is more common in LLMs (vs post LayerNorm in previous version)

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
!wget https://gist.githubusercontent.com/blakesanie/dde3a2b7e698f52f389532b4b52bc254/raw/76fe1b5e9efcf0d2afdfd78b0bfaa737ad0a67d3/shakespeare.txt

--2025-08-24 22:59:39--  https://gist.githubusercontent.com/blakesanie/dde3a2b7e698f52f389532b4b52bc254/raw/76fe1b5e9efcf0d2afdfd78b0bfaa737ad0a67d3/shakespeare.txt
Resolving gist.githubusercontent.com (gist.githubusercontent.com)... 185.199.108.133, 185.199.110.133, 185.199.111.133, ...
Connecting to gist.githubusercontent.com (gist.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5436475 (5.2M) [text/plain]
Saving to: ‘shakespeare.txt’


2025-08-24 22:59:42 (367 MB/s) - ‘shakespeare.txt’ saved [5436475/5436475]



In [4]:
!wc -lwc shakespeare.txt

 124185  899588 5436475 shakespeare.txt


In [5]:
with open("shakespeare.txt") as f:
  data = f.read()

print(len(data))

data[:1000]

5436475


"  From fairest creatures we desire increase,\n  That thereby beauty's rose might never die,\n  But as the riper should by time decease,\n  His tender heir might bear his memory:\n  But thou contracted to thine own bright eyes,\n  Feed'st thy light's flame with self-substantial fuel,\n  Making a famine where abundance lies,\n  Thy self thy foe, to thy sweet self too cruel:\n  Thou that art now the world's fresh ornament,\n  And only herald to the gaudy spring,\n  Within thine own bud buriest thy content,\n  And tender churl mak'st waste in niggarding:\n    Pity the world, or else this glutton be,\n    To eat the world's due, by the grave and thee.\n\n\n                     2\n  When forty winters shall besiege thy brow,\n  And dig deep trenches in thy beauty's field,\n  Thy youth's proud livery so gazed on now,\n  Will be a tattered weed of small worth held:\n  Then being asked, where all thy beauty lies,\n  Where all the treasure of thy lusty days;\n  To say within thine own deep sunk

In [6]:
chars = list(set(data))
print(len(chars))

ctoi = {ch:i for i, ch in enumerate(chars)}
itoc = {i:ch for ch, i in ctoi.items()}

assert len(ctoi) == len(itoc)


def encoder(text):
  return [ctoi[ch] for ch in text]

def decoder(tokens):
  return "".join([itoc[token] for token in tokens])

84


In [7]:
test_str = "hello world!"
assert test_str == decoder(encoder(test_str))

In [8]:
import torch
import torch.nn.functional as F

train_size = 0.9

num_epochs = 5
batch_size = 1024 + 512
emb_dim = 512
num_heads = 16
num_blocks=4
lr = 2e-3
context_window_size = 128

torch.manual_seed(2025)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_data, test_data = encoder(data[:int(train_size*len(data))]), encoder(data[int(train_size*len(data)):])
train_data_t, test_data_t = torch.tensor(train_data, dtype=torch.long).to(device), torch.tensor(test_data, dtype=torch.long).to(device)

len(train_data), len(test_data)

(4892827, 543648)

In [9]:
type(train_data), type(train_data_t), type(test_data), type(test_data_t)

(list, torch.Tensor, list, torch.Tensor)

In [10]:
num_batches = int(len(train_data)/batch_size*0.9)  # 0.9 multiplier is an scrappy way of making sure get_batch doesn't go out of bounds
num_test_batches = int(len(test_data)/batch_size*0.9)

print(f"{num_batches=}, {num_test_batches=}")

num_batches=2866, num_test_batches=318


In [11]:
# types are explicitly not included

# def scaled_dot_product_attention(Q, K, V, mask=None):
#   # QKV are of shape [batch_size, seq_len, d_k]
#   # mask is a tensor of shape [seq_len, seq_len]
#   scores = (Q @ K.transpose(-2, -1)) / Q.shape[-1]**0.5
#   # K.transpose(-2, -1) shape = [batch_size, d_k, seq_len]
#   # Q @ K^T shape = [batch_size, seq_len, seq_len]
#   if mask is not None:
#     scores = scores.masked_fill(mask==0, float("-inf"))
#   return F.softmax(scores, dim=-1) @ V  # shape = [batch_size, seq_len, d_k]


def causal_mask(seq_len):
  # create a lower trianguar mask
  return torch.tril(torch.ones(seq_len, seq_len)).bool()  # shape = [seq_len, seq_len]


# class SelfAttention(torch.nn.Module):
#   def __init__(self, emb_dim=emb_dim):
#     super().__init__()
#     self.emb_dim = emb_dim
#     self.q_proj = torch.nn.Linear(emb_dim, emb_dim)
#     self.k_proj = torch.nn.Linear(emb_dim, emb_dim)
#     self.v_proj = torch.nn.Linear(emb_dim, emb_dim)

#     self.output = torch.nn.Linear(emb_dim, emb_dim)

#   def forward(self, X):
#     # X shape = [batch_size, seq_len, emb_dim]
#     Q = self.q_proj(X)
#     K = self.k_proj(X)
#     V = self.v_proj(X)

#     seq_len = Q.shape[1]
#     mask = causal_mask(seq_len).unsqueeze(0).to(device)

#     return self.output(scaled_dot_product_attention(Q, K, V, mask=mask))

class MultiHeadAttention(torch.nn.Module):
  def __init__(self, emb_dim, num_heads):
    super().__init__()
    assert emb_dim % num_heads == 0  # emb_dim is divisible by num_heads
    self.q_proj = torch.nn.Linear(emb_dim, emb_dim)
    self.k_proj = torch.nn.Linear(emb_dim, emb_dim)
    self.v_proj = torch.nn.Linear(emb_dim, emb_dim)

    self.output = torch.nn.Linear(emb_dim, emb_dim)

    self.emb_dim = emb_dim
    self.num_heads = num_heads
    self.head_dim = emb_dim // num_heads

  def forward(self, X):
    batch_size, seq_len, _ = X.shape

    Q = self.q_proj(X)
    K = self.k_proj(X)
    V = self.v_proj(X)

    # split into heads and change shae [batch_size, seq_len, num_heads, head_dim] -> [batch_size, num_heads, seq_len, head_dim]
    Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
    K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
    V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    # scaled dot product attention
    scores = Q @ K.transpose(-2, -1) / (self.head_dim**0.5)
    mask = causal_mask(seq_len).unsqueeze(0).unsqueeze(0).to(device)  # shape = [1, 1, seq_len, seq_len]
    scores = scores.masked_fill(mask==0, float("-inf"))
    attn = F.softmax(scores, dim=-1) @ V  # shape = [batch_size, num_heads, seq_len, head_dim]

    # concat heads
    attn = attn.transpose(1, 2).reshape(batch_size, seq_len, self.emb_dim)
    return self.output(attn)


class TransformerBlock(torch.nn.Module):
  def __init__(self, emb_dim=emb_dim):
    super().__init__()
    self.attention = MultiHeadAttention(emb_dim=emb_dim, num_heads=num_heads)
    self.ffn = torch.nn.Sequential(
        torch.nn.Linear(emb_dim, emb_dim*4),
        torch.nn.ReLU(),
        torch.nn.Linear(emb_dim*4, emb_dim)
    )
    self.ln1 = torch.nn.LayerNorm(emb_dim)
    self.ln2 = torch.nn.LayerNorm(emb_dim)

  def forward(self, X):
    res = X
    X = self.ln1(X)
    X = res + self.attention(X)

    res = X
    X = self.ln2(X)  # pre-LN
    X = res + self.ffn(X)

    return X


class Model(torch.nn.Module):
  def __init__(self, vocab_size=len(chars), emb_dim=emb_dim, seq_len=context_window_size, num_blocks=num_blocks):
    super().__init__()
    self.embedding_layer = torch.nn.Embedding(vocab_size, emb_dim)
    self.positional_emb = torch.nn.Embedding(seq_len, emb_dim)
    self.transformer = torch.nn.ModuleList([TransformerBlock(emb_dim=emb_dim) for _ in range(num_blocks)])  # note: using nn.Sequential([TB]) will reference the same instance of TB -> all of them will share the same weights which is not what we want
    self.linear = torch.nn.Linear(emb_dim, vocab_size)
    self.softmax = torch.nn.Softmax(-1)

    self.vocab_size = vocab_size
    self.seq_len = seq_len
    self.loss = torch.nn.CrossEntropyLoss()

  def forward(self, X, y=None):
    # X shape = [batch_size, seq_len]
    emb = self.embedding_layer(X)  # shape = [batch_size, seq_len, emb_dim]
    positions = torch.arange(emb.shape[1]).unsqueeze(0).to(device)  # shape = [batch_size, seq_len]
    pos_emb = self.positional_emb(positions)  # shape = [batch_size, seq_len, emb_dim]
    x = emb + pos_emb
    for block in self.transformer:
      x = block(x)
    logits = self.linear(x)
    if y is None:
      return logits, None
    else:
      # during training
      return logits, self.loss(logits.view(-1, self.vocab_size), y.view(-1))  # bug fix from previous version. CrossEntropy requires logits not probabilities

  def generate(self, prompt=""):
    # assuming len(prompt) < seq_len
    # TODO: truncate if not
    X = torch.tensor(encoder(prompt), dtype=torch.long).to(device)
    X = X.unsqueeze(0) # to convert shape to [batch_size, seq_len]

    while X.shape[1] < self.seq_len:
      logits, _ = self(X)
      probs = self.softmax(logits)
      next_token = torch.multinomial(probs[0][-1], 1)
      X = torch.cat((X,next_token.unsqueeze(0)), dim=1)
      # TODO: stop on end token...
    return decoder(X.squeeze(0).tolist())


model = Model().to(device)

In [12]:
# num of trainable parameters in the model:
sum(p.numel() for p in model.parameters())

12761172

In [13]:
model.generate("Hello w")

'Hello w  [IG>vCpmP`JG93t:wbU[|s`1Hg"kK93fnGm_)\n35k5ca6emprNmz8GzR`Vh-\nZYD([]PrtjKbo 4"JAWuJ0_f\nhROsz}zX?([vRzyqgur\nEr-9s71u8hFgI'

In [14]:
train_X, train_y = train_data_t[:-1], train_data_t[1:]
test_X, test_y = test_data_t[:-1], test_data_t[1:]

def get_batch(X, y, idx, batch_size=16):
  batch_x, batch_y = [], []
  for i in range(batch_size):
    batch_x.append(X[idx*batch_size+i  :idx*batch_size+i+context_window_size])
    batch_y.append(y[idx*batch_size+i+1:idx*batch_size+i+context_window_size+1])

  return torch.stack(batch_x), torch.stack(batch_y)

In [15]:
from tqdm import tqdm

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

def eval_model():
  model.eval()
  with torch.no_grad():
    test_loss = 0
    for i in tqdm(range(num_test_batches)):
      X, y = get_batch(test_X, test_y, i, batch_size)
      _, loss = model(X, y)
      test_loss += loss.item()
    test_loss /= num_test_batches
  return test_loss

test_loss = eval_model()
print(f"random weight {test_loss=}")

for epoch in range(num_epochs):
  model.train()
  epoch_loss = 0
  for i in tqdm(range(num_batches)):
    X, y = get_batch(train_X, train_y, i, batch_size)
    _, loss = model(X, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epoch_loss += loss.item()
  epoch_loss /= num_batches

  test_loss = eval_model()

  print(f"{epoch=}, {epoch_loss=}; {test_loss=}")

In [17]:
model.generate("Hello w")

'Hello wsad efeyu rwotyhuosl-afiigtdr. hrwade,tig\n   Ptlin ecue. etr yu hvl,bidhl o aais oraia. o ok esoee!Isnnto. i\n   fo eedmsl'

In [18]:
model.generate("From fairest creatures")

'From fairest creatures\n    o temtn hweee. u adanl,brnmdy\n  .Konewmnt eeso o loes frmii. hrnlknsmne.HRTIGo? h eoee\n   KTcibefl,h '

In [19]:
model.generate("That thereby")

"That therebymnpesnl telswymn o ohhs\n   I nvnjrn o yuuhl adiu uodsmk o oraios\n   Adiue hrwrt al lmdane ada'.Iupi e'nnwksml.      "

In [20]:
model.generate("tattered weed")

'tattered weed, o lvntytin: Eees\n   T o oe,Syesr ltikshpateK;btotes\n SEVUT,aefttnsos uae o Taaso hnibl saos o apto, o lvrupe t ss'

In [21]:
import time

torch.save(
    {
      "model_state_dict": model.state_dict(),
      "optimizer_state_dict": optimizer.state_dict(),
      "epoch": epoch,
      "train_loss": epoch_loss,
      "test_loss": test_loss
     },
    f"/content/drive/MyDrive/transformer_model_checkpoints/{int(time.time())}_transformer_model.pth")

In [22]:
from google.colab import runtime

# Disconnects and deletes the current runtime
runtime.unassign()