In [1]:
# Compared to previous version
#  * word level tokenizer (vs character level tokenizer)
#  * AdamW optimizer where weight decay is applied as a separate step (vs Adam in previous version where L2 regularization is just L2 penalty coupled with the gradient)
#  * Implementation of AdamW from scratch (vs importing it from torch.optim) from https://arxiv.org/pdf/1711.05101

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
!wget https://gist.githubusercontent.com/blakesanie/dde3a2b7e698f52f389532b4b52bc254/raw/76fe1b5e9efcf0d2afdfd78b0bfaa737ad0a67d3/shakespeare.txt

--2025-08-25 23:33:45--  https://gist.githubusercontent.com/blakesanie/dde3a2b7e698f52f389532b4b52bc254/raw/76fe1b5e9efcf0d2afdfd78b0bfaa737ad0a67d3/shakespeare.txt
Resolving gist.githubusercontent.com (gist.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to gist.githubusercontent.com (gist.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5436475 (5.2M) [text/plain]
Saving to: ‘shakespeare.txt.3’


2025-08-25 23:33:45 (101 MB/s) - ‘shakespeare.txt.3’ saved [5436475/5436475]



In [4]:
!wc -lwc shakespeare.txt

 124185  899588 5436475 shakespeare.txt


In [5]:
with open("shakespeare.txt") as f:
  data = f.read()

print(len(data))

data[:1000]

5436475


"  From fairest creatures we desire increase,\n  That thereby beauty's rose might never die,\n  But as the riper should by time decease,\n  His tender heir might bear his memory:\n  But thou contracted to thine own bright eyes,\n  Feed'st thy light's flame with self-substantial fuel,\n  Making a famine where abundance lies,\n  Thy self thy foe, to thy sweet self too cruel:\n  Thou that art now the world's fresh ornament,\n  And only herald to the gaudy spring,\n  Within thine own bud buriest thy content,\n  And tender churl mak'st waste in niggarding:\n    Pity the world, or else this glutton be,\n    To eat the world's due, by the grave and thee.\n\n\n                     2\n  When forty winters shall besiege thy brow,\n  And dig deep trenches in thy beauty's field,\n  Thy youth's proud livery so gazed on now,\n  Will be a tattered weed of small worth held:\n  Then being asked, where all thy beauty lies,\n  Where all the treasure of thy lusty days;\n  To say within thine own deep sunk

In [6]:
len(list(set(data.split(" "))))

85754

In [7]:
import re
data = re.sub(r"[^\w\s\n]", "", data)  # remove punctuation to reduce cardinality of the corpus
print(len(list(set(data.split(" ")))))
data = data.replace("\n", " \n ")  # handle end of line where \n is attach to the word prior to it
data = re.sub(r"\[ \t]+", " ", data)  # remove repetitve whitespaces (excluding \n)
print(len(list(set(data.split(" ")))))
data = data.lower()  # converting to lowercase to further reduce cardinality

words = list(set(data.split(" ")))
print(len(words))

wtoi = {ch:i for i, ch in enumerate(words)}
itoc = {i:ch for ch, i in wtoi.items()}

assert len(wtoi) == len(itoc)

def encoder(text):
  return [wtoi[ch] for ch in text.split(" ")]

def decoder(tokens):
  return " ".join([itoc[token] for token in tokens])

48004
34093
28166


In [8]:
test_str = "operation zeals"
assert test_str == decoder(encoder(test_str))

In [9]:
data[:1000]

'  from fairest creatures we desire increase \n   that thereby beautys rose might never die \n   but as the riper should by time decease \n   his tender heir might bear his memory \n   but thou contracted to thine own bright eyes \n   feedst thy lights flame with selfsubstantial fuel \n   making a famine where abundance lies \n   thy self thy foe to thy sweet self too cruel \n   thou that art now the worlds fresh ornament \n   and only herald to the gaudy spring \n   within thine own bud buriest thy content \n   and tender churl makst waste in niggarding \n     pity the world or else this glutton be \n     to eat the worlds due by the grave and thee \n  \n  \n                      2 \n   when forty winters shall besiege thy brow \n   and dig deep trenches in thy beautys field \n   thy youths proud livery so gazed on now \n   will be a tattered weed of small worth held \n   then being asked where all thy beauty lies \n   where all the treasure of thy lusty days \n   to say within thine 

In [10]:
import torch
import torch.nn.functional as F

train_size = 0.9  # % of dataset to be used for training, the remaining (1-x) will be used for validation

num_epochs = 5
batch_size = 128
emb_dim = 256
num_heads = 8
num_blocks=1
lr = 2e-3
context_window_size = 128

torch.manual_seed(2025)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_data, test_data = encoder(data[:int(train_size*len(data))]), encoder(data[int(train_size*len(data)):])
train_data_t, test_data_t = torch.tensor(train_data, dtype=torch.long).to(device), torch.tensor(test_data, dtype=torch.long).to(device)

len(train_data), len(test_data)

(1378219, 151729)

In [11]:
type(train_data), type(train_data_t), type(test_data), type(test_data_t)

(list, torch.Tensor, list, torch.Tensor)

In [12]:
num_batches = int(len(train_data)/batch_size*0.9)  # 0.9 multiplier is an scrappy way of making sure get_batch doesn't go out of bounds
num_test_batches = int(len(test_data)/batch_size*0.9)

print(f"{num_batches=}, {num_test_batches=}")

num_batches=9690, num_test_batches=1066


In [13]:
def causal_mask(seq_len):
  # create a lower trianguar mask
  return torch.tril(torch.ones(seq_len, seq_len)).bool()  # shape = [seq_len, seq_len]


class MultiHeadAttention(torch.nn.Module):
  def __init__(self, emb_dim, num_heads):
    super().__init__()
    assert emb_dim % num_heads == 0  # emb_dim is divisible by num_heads
    self.q_proj = torch.nn.Linear(emb_dim, emb_dim)
    self.k_proj = torch.nn.Linear(emb_dim, emb_dim)
    self.v_proj = torch.nn.Linear(emb_dim, emb_dim)

    self.output = torch.nn.Linear(emb_dim, emb_dim)

    self.emb_dim = emb_dim
    self.num_heads = num_heads
    self.head_dim = emb_dim // num_heads

  def forward(self, X):
    batch_size, seq_len, _ = X.shape

    Q = self.q_proj(X)
    K = self.k_proj(X)
    V = self.v_proj(X)

    # split into heads and change shae [batch_size, seq_len, num_heads, head_dim] -> [batch_size, num_heads, seq_len, head_dim]
    Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
    K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
    V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    # scaled dot product attention
    scores = Q @ K.transpose(-2, -1) / (self.head_dim**0.5)
    mask = causal_mask(seq_len).unsqueeze(0).unsqueeze(0).to(device)  # shape = [1, 1, seq_len, seq_len]
    scores = scores.masked_fill(mask==0, float("-inf"))
    attn = F.softmax(scores, dim=-1) @ V  # shape = [batch_size, num_heads, seq_len, head_dim]

    # concat heads
    attn = attn.transpose(1, 2).reshape(batch_size, seq_len, self.emb_dim)
    return self.output(attn)


class TransformerBlock(torch.nn.Module):
  def __init__(self, emb_dim=emb_dim):
    super().__init__()
    self.attention = MultiHeadAttention(emb_dim=emb_dim, num_heads=num_heads)
    self.ffn = torch.nn.Sequential(
        torch.nn.Linear(emb_dim, emb_dim*4),
        torch.nn.ReLU(),
        torch.nn.Linear(emb_dim*4, emb_dim)
    )
    self.ln1 = torch.nn.LayerNorm(emb_dim)
    self.ln2 = torch.nn.LayerNorm(emb_dim)

  def forward(self, X):
    res = X
    X = self.ln1(X)
    X = res + self.attention(X)

    res = X
    X = self.ln2(X)  # pre-LN
    X = res + self.ffn(X)

    return X


class Model(torch.nn.Module):
  def __init__(self, vocab_size=len(words), emb_dim=emb_dim, seq_len=context_window_size, num_blocks=num_blocks):
    super().__init__()
    self.embedding_layer = torch.nn.Embedding(vocab_size, emb_dim)
    self.positional_emb = torch.nn.Embedding(seq_len, emb_dim)
    self.transformer = torch.nn.ModuleList([TransformerBlock(emb_dim=emb_dim) for _ in range(num_blocks)])  # note: using nn.Sequential([TB]) will reference the same instance of TB -> all of them will share the same weights which is not what we want
    self.linear = torch.nn.Linear(emb_dim, vocab_size)
    self.softmax = torch.nn.Softmax(-1)

    self.vocab_size = vocab_size
    self.seq_len = seq_len
    self.loss = torch.nn.CrossEntropyLoss()

  def forward(self, X, y=None):
    # X shape = [batch_size, seq_len]
    emb = self.embedding_layer(X)  # shape = [batch_size, seq_len, emb_dim]
    positions = torch.arange(emb.shape[1]).unsqueeze(0).to(device)  # shape = [batch_size, seq_len]
    pos_emb = self.positional_emb(positions)  # shape = [batch_size, seq_len, emb_dim]
    x = emb + pos_emb
    for block in self.transformer:
      x = block(x)
    logits = self.linear(x)
    if y is None:
      return logits, None
    else:
      # during training
      return logits, self.loss(logits.view(-1, self.vocab_size), y.view(-1))  # bug fix from previous version. CrossEntropy requires logits not probabilities

  def generate(self, prompt=""):
    # assuming len(prompt) < seq_len
    # TODO: truncate if not
    X = torch.tensor(encoder(prompt), dtype=torch.long).to(device)
    X = X.unsqueeze(0) # to convert shape to [batch_size, seq_len]

    while X.shape[1] < self.seq_len:
      logits, _ = self(X)
      probs = self.softmax(logits)
      next_token = torch.multinomial(probs[0][-1], 1)
      X = torch.cat((X,next_token.unsqueeze(0)), dim=1)
      # TODO: stop on end token...
    return decoder(X.squeeze(0).tolist())


model = Model().to(device)

In [14]:
# num of trainable parameters in the model:
sum(p.numel() for p in model.parameters())

15271686

In [15]:
model.generate("fairest creatures")

'fairest creatures feastwon feeder malefactions self lovedst ely oerflowing sexton penker unfool missed betook shavet consorted seldshown drenchd cars meat whining mudded motive reading shallowa exact bravest brags nobler evend alack device birdlime iachimo womanqueller pretending anticipate bawds impious cavilling syracuse diaper dolor besieged shepherdesses belonging pheebus rawly visible absolutely centuries usherd badges peasant sperato netherstocks amnipotent 106 throught indicted bibblebabble springhalt miscreant merchant eaning oppresseth minola goes flush frontier retaind confront bladder prioress parsons sharply ewers carnation tosspots tongueless visardlike perversely capricious smallest dagonet trick bestregarded impressest disbranch studys unmeriting unshaken ecstasies selling gardners latterborn choir unsuspected disgracing discipled battlefield interpret evileyd feat heartblood forbiddenly waterish practices messaline childrens heady depended pastry swoon portotartarossa 

In [16]:
train_X, train_y = train_data_t[:-1], train_data_t[1:]
test_X, test_y = test_data_t[:-1], test_data_t[1:]

def get_batch(X, y, idx, batch_size=16):
  batch_x, batch_y = [], []
  for i in range(batch_size):
    batch_x.append(X[idx*batch_size+i  :idx*batch_size+i+context_window_size])
    batch_y.append(y[idx*batch_size+i+1:idx*batch_size+i+context_window_size+1])

  return torch.stack(batch_x), torch.stack(batch_y)

In [17]:
from tqdm import tqdm

In [18]:
class AdamWOptimizer:
  def __init__(self, params, lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0):
    self.params = list(params)
    self.lr = lr
    self.betas = betas
    self.eps = eps
    self.weight_decay = weight_decay
    self.t = 0

    # initialize moment estimates with zeros
    self.m = [torch.zeros_like(p) for p in self.params]
    self.v = [torch.zeros_like(p) for p in self.params]

  def step(self):
    # Algorithm 2 on page 3 of https://arxiv.org/pdf/1711.05101
    self.t += 1
    beta1, beta2 = self.betas

    for i, p in enumerate(self.params):
      if p.grad is None:
        continue

      g = p.grad.data

      self.m[i] = beta1 * self.m[i] + (1-beta1)*g
      self.v[i] = beta2 * self.v[i] + (1-beta2)*(g*g)

      # bias correction
      m_hat = self.m[i] / (1-beta1**self.t)
      v_hat = self.v[i] / (1-beta2**self.t)

      p.data = p.data - self.lr * (m_hat / (torch.sqrt(v_hat) + self.eps))

      if self.weight_decay != 0:
        p.data = p.data - self.lr * self.weight_decay * p.data

  def zero_grad(self):
    for p in self.params:
      if p.grad is not None:
        p.grad.zero_()

In [19]:
optimizer = AdamWOptimizer(model.parameters(), lr=lr, weight_decay=1e-2)

def eval_model():
  model.eval()
  with torch.no_grad():
    test_loss = 0
    for i in tqdm(range(num_test_batches)):
      X, y = get_batch(test_X, test_y, i, batch_size)
      _, loss = model(X, y)
      test_loss += loss.item()
    test_loss /= num_test_batches
  return test_loss

test_loss = eval_model()
print(f"random weight {test_loss=}")

for epoch in range(num_epochs):
  model.train()
  epoch_loss = 0
  for i in tqdm(range(num_batches)):
    X, y = get_batch(train_X, train_y, i, batch_size)
    _, loss = model(X, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epoch_loss += loss.item()
  epoch_loss /= num_batches

  test_loss = eval_model()

  print(f"{epoch=}, {epoch_loss=}; {test_loss=}")

100%|██████████| 1066/1066 [01:54<00:00,  9.27it/s]


random weight test_loss=10.49277364737992


100%|██████████| 9690/9690 [47:42<00:00,  3.38it/s]
100%|██████████| 1066/1066 [01:53<00:00,  9.35it/s]


epoch=0, epoch_loss=4.472849836622598; test_loss=4.831095782386429


100%|██████████| 9690/9690 [47:41<00:00,  3.39it/s]
100%|██████████| 1066/1066 [01:53<00:00,  9.37it/s]


epoch=1, epoch_loss=4.020324281088708; test_loss=4.902432272179265


100%|██████████| 9690/9690 [47:41<00:00,  3.39it/s]
100%|██████████| 1066/1066 [01:54<00:00,  9.35it/s]


epoch=2, epoch_loss=3.840051518314517; test_loss=4.923621431636095


100%|██████████| 9690/9690 [47:41<00:00,  3.39it/s]
100%|██████████| 1066/1066 [01:53<00:00,  9.37it/s]


epoch=3, epoch_loss=3.7174443885453345; test_loss=4.9355894362389705


100%|██████████| 9690/9690 [47:40<00:00,  3.39it/s]
100%|██████████| 1066/1066 [01:53<00:00,  9.37it/s]

epoch=4, epoch_loss=3.6275050752118645; test_loss=4.947761482693688





In [20]:
model.generate("fairest creatures")

'fairest creatures about neck \n  children  but very heavy i her sure is \n  katherina she might mean swear wilt keep well he her \n   yea provided it if were so you pray far the \n  petruchio madam pray what i ask this claudio might alone your if be it              exit with and \n  hortensio let head against friends o i adventure therefore my usage            exit forward \n  baptista ancient mother ill but me aught she tonight        \n    how her wife     '

In [21]:
model.generate("that thereby")

'that thereby i to this \n third powr as as to as as as as import when fools  with \n  katherina never but if know i it did her would scolding \n   younger she laid in as cried sister wearing \n  baptista \n  baptista morrow dowry bright lute her do husbandry                     katherina seen and the \n  gremio \n \n          baptista wonder i kate             baptista kiss lips these sooth did her \n    unto state him her hand his'

In [22]:
model.generate("tattered weed")

'tattered weed  clean she bereft moon silence me all grows \n    out his head them i greater than native \n    being did and you kept fair and with \n    of head in flies prefer hither me have \n    because lovst oft so why do know all \n    i my killd why a eye our you sister         lucentio no be \n  petruchio shall my liege peremptory daughters thinks all well       thither purchase abuses troublesome   baptista \n  petruchio   it a instrument of and up hands flag        '

In [24]:
import time

torch.save(
    {
      "model_state_dict": model.state_dict(),
      # "optimizer_state_dict": optimizer.state_dict(),
      "epoch": epoch,
      "train_loss": epoch_loss,
      "test_loss": test_loss
     },
    f"/content/drive/MyDrive/transformer_model_checkpoints/{int(time.time())}_transformer_model.pth")

In [25]:
from google.colab import runtime

# Disconnects and deletes the current runtime
runtime.unassign()