In [1]:
# Compared to previous version
# * everything is implemented from scratch via only the following that are imported from torch.nn: Module, ModuleList, Parameter
# * Dropout is added after embeddings, and inside transformer after attention and ffn before residuals are added

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
!wget https://gist.githubusercontent.com/blakesanie/dde3a2b7e698f52f389532b4b52bc254/raw/76fe1b5e9efcf0d2afdfd78b0bfaa737ad0a67d3/shakespeare.txt

--2025-08-26 01:45:10--  https://gist.githubusercontent.com/blakesanie/dde3a2b7e698f52f389532b4b52bc254/raw/76fe1b5e9efcf0d2afdfd78b0bfaa737ad0a67d3/shakespeare.txt
Resolving gist.githubusercontent.com (gist.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to gist.githubusercontent.com (gist.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5436475 (5.2M) [text/plain]
Saving to: ‘shakespeare.txt’


2025-08-26 01:45:10 (87.8 MB/s) - ‘shakespeare.txt’ saved [5436475/5436475]



In [4]:
!wc -lwc shakespeare.txt

 124185  899588 5436475 shakespeare.txt


In [5]:
with open("shakespeare.txt") as f:
  data = f.read()

print(len(data))

data[:1000]

5436475


"  From fairest creatures we desire increase,\n  That thereby beauty's rose might never die,\n  But as the riper should by time decease,\n  His tender heir might bear his memory:\n  But thou contracted to thine own bright eyes,\n  Feed'st thy light's flame with self-substantial fuel,\n  Making a famine where abundance lies,\n  Thy self thy foe, to thy sweet self too cruel:\n  Thou that art now the world's fresh ornament,\n  And only herald to the gaudy spring,\n  Within thine own bud buriest thy content,\n  And tender churl mak'st waste in niggarding:\n    Pity the world, or else this glutton be,\n    To eat the world's due, by the grave and thee.\n\n\n                     2\n  When forty winters shall besiege thy brow,\n  And dig deep trenches in thy beauty's field,\n  Thy youth's proud livery so gazed on now,\n  Will be a tattered weed of small worth held:\n  Then being asked, where all thy beauty lies,\n  Where all the treasure of thy lusty days;\n  To say within thine own deep sunk

In [6]:
len(list(set(data.split(" "))))

85754

In [7]:
import re
data = re.sub(r"[^\w\s\n]", "", data)  # remove punctuation to reduce cardinality of the corpus
print(len(list(set(data.split(" ")))))
data = data.replace("\n", " \n ")  # handle end of line where \n is attach to the word prior to it
data = re.sub(r"\[ \t]+", " ", data)  # remove repetitve whitespaces (excluding \n)
print(len(list(set(data.split(" ")))))
data = data.lower()  # converting to lowercase to further reduce cardinality

words = list(set(data.split(" ")))
print(len(words))

wtoi = {w:i for i, w in enumerate(words)}
itoc = {i:w for w, i in wtoi.items()}

assert len(wtoi) == len(itoc)

def encoder(text):
  return [wtoi[w] for w in text.split(" ")]

def decoder(tokens):
  return " ".join([itoc[token] for token in tokens])

48004
34093
28166


In [8]:
test_str = "operation zeals"
assert test_str == decoder(encoder(test_str))

In [9]:
data[:1000]

'  from fairest creatures we desire increase \n   that thereby beautys rose might never die \n   but as the riper should by time decease \n   his tender heir might bear his memory \n   but thou contracted to thine own bright eyes \n   feedst thy lights flame with selfsubstantial fuel \n   making a famine where abundance lies \n   thy self thy foe to thy sweet self too cruel \n   thou that art now the worlds fresh ornament \n   and only herald to the gaudy spring \n   within thine own bud buriest thy content \n   and tender churl makst waste in niggarding \n     pity the world or else this glutton be \n     to eat the worlds due by the grave and thee \n  \n  \n                      2 \n   when forty winters shall besiege thy brow \n   and dig deep trenches in thy beautys field \n   thy youths proud livery so gazed on now \n   will be a tattered weed of small worth held \n   then being asked where all thy beauty lies \n   where all the treasure of thy lusty days \n   to say within thine 

In [10]:
import torch
from tqdm import tqdm

train_size = 0.9  # % of dataset to be used for training, the remaining (1-x) will be used for validation

num_epochs = 5
batch_size = 128
emb_dim = 256
num_heads = 8
num_blocks=1
lr = 2e-3
context_window_size = 128

torch.manual_seed(2025)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_data, test_data = encoder(data[:int(train_size*len(data))]), encoder(data[int(train_size*len(data)):])
train_data_t, test_data_t = torch.tensor(train_data, dtype=torch.long).to(device), torch.tensor(test_data, dtype=torch.long).to(device)

len(train_data), len(test_data)

(1378219, 151729)

In [11]:
type(train_data), type(train_data_t), type(test_data), type(test_data_t)

(list, torch.Tensor, list, torch.Tensor)

In [12]:
num_batches = int(len(train_data)/batch_size*0.9)  # 0.9 multiplier is an scrappy way of making sure get_batch doesn't go out of bounds
num_test_batches = int(len(test_data)/batch_size*0.9)

print(f"{num_batches=}, {num_test_batches=}")

num_batches=9690, num_test_batches=1066


# Layers

In [13]:
class LinearLayer(torch.nn.Module):
  def __init__(self, in_features, out_features, bias=True):
    super().__init__()
    self.use_bias = bias
    self.in_features = in_features
    self.out_features = out_features

    # note: shape is [out_features, in_features] instead of [in_features, out_features] so each row of the weight matrix corresponds to one output neuron
    # hence, the weights for a single neuron are stored contiguously in memory.
    self.weight = torch.nn.Parameter(torch.randn(out_features, in_features)*0.01)  # shape = [out_features, in_features]

    if self.use_bias:
      self.bias = torch.nn.Parameter(torch.zeros(out_features))
    else:
      self.register_parameter('bias', None)

  def forward(self, X):
    # shape of X = [batch_size, in_features]
    y = X @ self.weight.t()  # shape = [batch_size, in_features] @ [in_features, out_features] = [batch_size, out_features]
    if self.use_bias:
      return y + self.bias  # shape = [batch_size, out_features]
    return y

x = torch.tensor([[1, 2], [2, 3]], dtype=torch.float32)

linear = LinearLayer(2, 1)

linear(x)

tensor([[-0.0065],
        [-0.0141]], grad_fn=<AddBackward0>)

In [14]:
class ReLU(torch.nn.Module):
  def __init__(self):
    super().__init__()

  def forward(self, X):
    return torch.maximum(X, torch.zeros_like(X))
    # alternatively coule be `return X.clamp(min=0)`

x = torch.tensor([[float("-inf"), -10], [1, float("inf")]])

relu = ReLU()

relu(x)

tensor([[0., 0.],
        [1., inf]])

In [15]:
class Embedding(torch.nn.Module):
  def __init__(self, vocab_size, emb_dim):
    super().__init__()
    self.vocab_size = vocab_size
    self.emb_dim = emb_dim

    self.weight = torch.nn.Parameter(torch.randn(vocab_size, emb_dim)*0.01)

  def forward(self, X):
    # shape of X = [batch_size, seq_len]
    return self.weight[X] # shape = [batch_size, seq_len, emb_dim]


x = torch.tensor([[0, 1], [0, 1]])

emb = Embedding(2, 5)

emb(x)

tensor([[[ 0.0120, -0.0018,  0.0100,  0.0079, -0.0050],
         [ 0.0200,  0.0007, -0.0104, -0.0008,  0.0129]],

        [[ 0.0120, -0.0018,  0.0100,  0.0079, -0.0050],
         [ 0.0200,  0.0007, -0.0104, -0.0008,  0.0129]]],
       grad_fn=<IndexBackward0>)

In [16]:
class LayerNorm(torch.nn.Module):
  def __init__(self, feature_dim, eps=1e-8):
    super().__init__()
    self.eps = eps
    self.gamma = torch.nn.Parameter(torch.ones(feature_dim)) # scale
    self.beta = torch.nn.Parameter(torch.zeros(feature_dim)) # shift

  def forward(self, X):
    # normalize along feature dimension
    # shape of X = [batch_size, seq_len, emb_dim]
    mean = X.mean(dim=2, keepdim=True)
    var = torch.var(X, dim=2, keepdim=True)
    X_norm = (X - mean) / torch.sqrt(var + self.eps)  # shape = same as input
    return self.gamma * X_norm + self.beta


x = torch.tensor([[[1,2,3], [-1, 0, 1]], [[2,3,4], [float("inf"), 0, 14]]])

ln = LayerNorm(3)

ln(x)

tensor([[[-1.,  0.,  1.],
         [-1.,  0.,  1.]],

        [[-1.,  0.,  1.],
         [nan, nan, nan]]], grad_fn=<AddBackward0>)

In [17]:
class Dropout(torch.nn.Module):
  # implementation of https://arxiv.org/abs/1207.0580
  def __init__(self, dropout_rate=0.1):
    super().__init__()
    self.dropout_rate = dropout_rate

  def forward(self, X):
    if self.training:
      mask = (torch.rand_like(X) > self.dropout_rate).float()
      return mask * X / (1-self.dropout_rate)
    else:
      return X

d = Dropout(0.2)
d.train()
x = torch.ones(5)
print(d(x))

d.eval()
print(d(x))

tensor([1.2500, 1.2500, 1.2500, 1.2500, 1.2500])
tensor([1., 1., 1., 1., 1.])


In [18]:
class Sequential(torch.nn.Module):
  def __init__(self, *layers):
    super().__init__()
    self.layers = torch.nn.ModuleList(layers) # if torch ModuleList is not used, layers won't register in model Parameters

  def forward(self, X):
    for layer in self.layers:
      X = layer(X)
    return X

s = Sequential(LinearLayer(3, 2), ReLU(), LinearLayer(2, 1))

print(list(s.parameters()))

[Parameter containing:
tensor([[-0.0032,  0.0034, -0.0070],
        [-0.0029, -0.0021, -0.0059]], requires_grad=True), Parameter containing:
tensor([0., 0.], requires_grad=True), Parameter containing:
tensor([[ 0.0056, -0.0096]], requires_grad=True), Parameter containing:
tensor([0.], requires_grad=True)]


In [19]:
class Softmax(torch.nn.Module):
  def __init__(self, dim=-1):
    super().__init__()
    self.dim = dim

  def forward(self, X):
    exp = torch.exp(X - X.max(dim=self.dim, keepdim=True).values)  # max is subtracted as a numerical trick before exponentiating to avoid large exponentials
    return exp / exp.sum(dim=self.dim, keepdim=True)


x = torch.tensor([[float("-inf"), -1, 0], [0, 1, 2], [0, 10, float("inf")]])
s = Softmax(-1)
s(x)

tensor([[0.0000, 0.2689, 0.7311],
        [0.0900, 0.2447, 0.6652],
        [   nan,    nan,    nan]])

In [20]:
def causal_mask(seq_len):
  # create a lower trianguar mask
  return torch.tril(torch.ones(seq_len, seq_len)).bool()  # shape = [seq_len, seq_len]

class MultiHeadAttention(torch.nn.Module):
  def __init__(self, emb_dim, num_heads):
    super().__init__()
    assert emb_dim % num_heads == 0  # emb_dim is divisible by num_heads
    self.q_proj = LinearLayer(emb_dim, emb_dim)
    self.k_proj = LinearLayer(emb_dim, emb_dim)
    self.v_proj = LinearLayer(emb_dim, emb_dim)

    self.output = LinearLayer(emb_dim, emb_dim)

    self.emb_dim = emb_dim
    self.num_heads = num_heads
    self.head_dim = emb_dim // num_heads
    self.softmax = Softmax(dim=-1)

  def forward(self, X):
    batch_size, seq_len, _ = X.shape

    Q = self.q_proj(X)
    K = self.k_proj(X)
    V = self.v_proj(X)

    # split into heads and change shae [batch_size, seq_len, num_heads, head_dim] -> [batch_size, num_heads, seq_len, head_dim]
    Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
    K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
    V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    # scaled dot product attention
    scores = Q @ K.transpose(-2, -1) / (self.head_dim**0.5)
    mask = causal_mask(seq_len).unsqueeze(0).unsqueeze(0).to(device)  # shape = [1, 1, seq_len, seq_len]
    scores = scores.masked_fill(mask==0, float("-inf"))
    attn = self.softmax(scores) @ V  # shape = [batch_size, num_heads, seq_len, head_dim]

    # concat heads
    attn = attn.transpose(1, 2).reshape(batch_size, seq_len, self.emb_dim)
    return self.output(attn)

In [21]:
class TransformerBlock(torch.nn.Module):
  def __init__(self, emb_dim=emb_dim):
    super().__init__()
    self.attention = MultiHeadAttention(emb_dim=emb_dim, num_heads=num_heads)
    self.attention_dropout = Dropout(0.1)
    self.ffn = Sequential(
        LinearLayer(emb_dim, emb_dim*4),
        ReLU(),
        LinearLayer(emb_dim*4, emb_dim)
    )
    self.ffn_dropout = Dropout(0.1)
    self.ln1 = LayerNorm(emb_dim)
    self.ln2 = LayerNorm(emb_dim)

  def forward(self, X):
    res = X
    X = self.ln1(X)
    X = res + self.attention_dropout(self.attention(X))

    res = X
    X = self.ln2(X)  # pre-LN
    X = res + self.ffn_dropout(self.ffn(X))

    return X

# CrossEntropy Loss Function

In [22]:
class CrossEntropyLoss(torch.nn.Module):
  def __init__(self):
    super().__init__()

  def forward(self, logits, y):
    # logits shape = [batch_size, num_classes]
    # y shape = [batch_size] values in [0, num_classes-1]
    max_logits = logits.max(dim=1, keepdim=True).values
    logits_stable = logits - max_logits
    log_sum_exp = torch.log(torch.exp(logits_stable).sum(dim=1))
    ohe = logits_stable.gather(1, y.unsqueeze(1)).squeeze(1)
    return (log_sum_exp-ohe).mean()


logits = torch.tensor([[2, 1, 0.1], [0.5, 2.5, 0.3]])

y = torch.tensor([0, 1])

ce = CrossEntropyLoss()
ce(logits, y)

tensor(0.3185)

# Transformer model

In [26]:
class Model(torch.nn.Module):
  def __init__(self, vocab_size=len(words), emb_dim=emb_dim, seq_len=context_window_size, num_blocks=num_blocks):
    super().__init__()
    self.embedding_layer = Embedding(vocab_size, emb_dim)
    self.positional_emb = Embedding(seq_len, emb_dim)
    self.embedding_dropout = Dropout(0.1)
    self.transformer = torch.nn.ModuleList([TransformerBlock(emb_dim=emb_dim) for _ in range(num_blocks)])  # note: using nn.Sequential([TB]) will reference the same instance of TB -> all of them will share the same weights which is not what we want
    self.linear = LinearLayer(emb_dim, vocab_size)
    self.softmax = Softmax(-1)

    self.vocab_size = vocab_size
    self.seq_len = seq_len
    self.loss = CrossEntropyLoss()

  def forward(self, X, y=None):
    # X shape = [batch_size, seq_len]
    emb = self.embedding_layer(X)  # shape = [batch_size, seq_len, emb_dim]
    positions = torch.arange(emb.shape[1]).unsqueeze(0).to(device)  # shape = [batch_size, seq_len]
    pos_emb = self.positional_emb(positions)  # shape = [batch_size, seq_len, emb_dim]
    x = emb + pos_emb
    x = self.embedding_dropout(x)
    for block in self.transformer:
      x = block(x)
    logits = self.linear(x)
    if y is None:
      return logits, None
    else:
      # during training
      return logits, self.loss(logits.view(-1, self.vocab_size), y.view(-1))  # bug fix from previous version. CrossEntropy requires logits not probabilities

  def generate(self, prompt=""):
    # assuming len(prompt) < seq_len
    # TODO: truncate if not
    X = torch.tensor(encoder(prompt), dtype=torch.long).to(device)
    X = X.unsqueeze(0) # to convert shape to [batch_size, seq_len]

    while X.shape[1] < self.seq_len:
      logits, _ = self(X)
      probs = self.softmax(logits)
      next_token = torch.multinomial(probs[0][-1], 1)
      X = torch.cat((X,next_token.unsqueeze(0)), dim=1)
      # TODO: stop on end token...
    return decoder(X.squeeze(0).tolist())


model = Model().to(device)

In [27]:
# num of trainable parameters in the model:
sum(p.numel() for p in model.parameters())

15271686

In [28]:
model.generate("fairest creatures")

'fairest creatures crieth gilliams loathd hallowmas griffin job tag oui florentines lequel harmed splenitive contenta livd swarming pickpurses memento ballad polecats powers noblemen lovingly pride spain cicatrice taunted siennas distempering paledead giglets christening shelvy notand heifer ills sap penitents bewitchment ragozine distasted turbulence repugnant lesser rites ladybird mutability commune astonished eyeoffending halcyon lass recomforture direness becomet drank overroasted swath distills deer mounting doit prevails kindreds dovedrawn halloa lust underbearing tutto afire moons cacaliban 74 referrd headpiece derision vial villages shorten exits unsaluted cucullus 50 cheered despiteful forrest bestows enskied fortunes students ifaith cell merlin loathsome cates strengthen minimo numbring residence housekeeping malcontents unblown often grandfathers salletherbs badge soles excite scarcecold eldergun ministred earthwhy ending weakhingd clay immediate tinsel survive rippd blindfo

# AdamW Optimizer

In [29]:
class AdamWOptimizer:
  def __init__(self, params, lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0):
    self.params = list(params)
    self.lr = lr
    self.betas = betas
    self.eps = eps
    self.weight_decay = weight_decay
    self.t = 0

    # initialize moment estimates with zeros
    self.m = [torch.zeros_like(p) for p in self.params]
    self.v = [torch.zeros_like(p) for p in self.params]

  def step(self):
    # Algorithm 2 on page 3 of https://arxiv.org/pdf/1711.05101
    self.t += 1
    beta1, beta2 = self.betas

    for i, p in enumerate(self.params):
      if p.grad is None:
        continue

      g = p.grad.data

      self.m[i] = beta1 * self.m[i] + (1-beta1)*g
      self.v[i] = beta2 * self.v[i] + (1-beta2)*(g*g)

      # bias correction
      m_hat = self.m[i] / (1-beta1**self.t)
      v_hat = self.v[i] / (1-beta2**self.t)

      p.data = p.data - self.lr * (m_hat / (torch.sqrt(v_hat) + self.eps))

      if self.weight_decay != 0:
        p.data = p.data - self.lr * self.weight_decay * p.data

  def zero_grad(self):
    for p in self.params:
      if p.grad is not None:
        p.grad.zero_()


optimizer = AdamWOptimizer(model.parameters(), lr=lr, weight_decay=1e-2)

# Training and evaluation loop

In [30]:
train_X, train_y = train_data_t[:-1], train_data_t[1:]
test_X, test_y = test_data_t[:-1], test_data_t[1:]

def get_batch(X, y, idx, batch_size=16):
  batch_x, batch_y = [], []
  for i in range(batch_size):
    batch_x.append(X[idx*batch_size+i  :idx*batch_size+i+context_window_size])
    batch_y.append(y[idx*batch_size+i+1:idx*batch_size+i+context_window_size+1])

  return torch.stack(batch_x), torch.stack(batch_y)

In [31]:
def eval_model():
  model.eval()
  with torch.no_grad():
    test_loss = 0
    for i in tqdm(range(num_test_batches)):
      X, y = get_batch(test_X, test_y, i, batch_size)
      _, loss = model(X, y)
      test_loss += loss.item()
    test_loss /= num_test_batches
  return test_loss

test_loss = eval_model()
print(f"random weight {test_loss=}")

for epoch in range(num_epochs):
  model.train()
  epoch_loss = 0
  for i in tqdm(range(num_batches)):
    X, y = get_batch(train_X, train_y, i, batch_size)
    _, loss = model(X, y)

    optimizer.zero_grad()  # reset grads
    loss.backward()  # compute new grads based on loss
    optimizer.step()  # update model weights based on grads in step above
    epoch_loss += loss.item()
  epoch_loss /= num_batches

  test_loss = eval_model()

  print(f"{epoch=}, {epoch_loss=}; {test_loss=}")

100%|██████████| 1066/1066 [02:19<00:00,  7.63it/s]


random weight test_loss=10.247481833703075


100%|██████████| 9690/9690 [1:05:06<00:00,  2.48it/s]
100%|██████████| 1066/1066 [02:20<00:00,  7.59it/s]


epoch=0, epoch_loss=4.377197290217298; test_loss=4.677172826781282


100%|██████████| 9690/9690 [1:05:12<00:00,  2.48it/s]
100%|██████████| 1066/1066 [02:20<00:00,  7.59it/s]


epoch=1, epoch_loss=4.008633051746524; test_loss=4.844488580947075


100%|██████████| 9690/9690 [1:05:16<00:00,  2.47it/s]
100%|██████████| 1066/1066 [02:20<00:00,  7.58it/s]


epoch=2, epoch_loss=3.802998978545422; test_loss=4.9632600542528325


100%|██████████| 9690/9690 [1:05:17<00:00,  2.47it/s]
100%|██████████| 1066/1066 [02:20<00:00,  7.58it/s]


epoch=3, epoch_loss=3.6557718116678566; test_loss=5.079243455885946


100%|██████████| 9690/9690 [1:05:14<00:00,  2.48it/s]
100%|██████████| 1066/1066 [02:20<00:00,  7.59it/s]

epoch=4, epoch_loss=3.5375746212323014; test_loss=5.16275384770549





In [32]:
model.generate("fairest creatures")

'fairest creatures blubbering me stand stand stand    ill without woo good me i and you me        your she profess your shall divorce \n  petruchio  and my watch hedgepriest paucas saving myself suitor the \n   mew in secrets thy and \n  wonted and shall petitioners let any between the that that   not me shall bras   katherine expected ware stride sit \n  sly vapours have much husband i her horse   beware venuto the name licio an \n    and fair sir you for unsatisfied bondage be \n  petruchio  i constant your ungracious brought to myself     we an form sir of'

In [33]:
model.generate("that thereby")

'that thereby a beams  more mine naild summer    know i mend against i you but it \n    and no we pardon mantuas \n  nurse lord all can \n   say he your sir tell when hath his \n    is true tell all spleen wooing count me dian ran \n  hortensio  that them are that shall repent she shall us \n  petruchio wrong wrappd she and as whereto to behaviour     steps a her of i carve as is offence     an can it impossible fears lawful those is admiral company    calm where shanks make beguile hair gawds \n    london banquet his'

In [34]:
model.generate("tattered weed")

'tattered weed  ingenious a grave  fully     and shall parentage lord wedded am tide  a \n  baptista welcome be good this will confirm happiness thine  rowland  parentage  whilst lad forward in blades ask amen     queen labour white hopeless a thousand \n  baptista instruct fair or have jested all all while like \n    my tis gentle fair look myself have shorter \n    though speak long doubt she his reasons of world mourn      i request never were twenty \n  lucentio said robbd from syracuse and malignant \n  othello  now whence you my blood thy is world modest \n  tranio pray'

# Save checkpoint

In [None]:
import time

torch.save(
    {
      "model_state_dict": model.state_dict(),
      "optimizer_state_dict": optimizer.state_dict(),
      "epoch": epoch,
      "train_loss": epoch_loss,
      "test_loss": test_loss
     },
    f"/content/drive/MyDrive/transformer_model_checkpoints/{int(time.time())}_transformer_model_from_scratch.pth")

In [None]:
from google.colab import runtime

# Disconnects and deletes the current runtime
runtime.unassign()