In [1]:
!wget https://gist.githubusercontent.com/blakesanie/dde3a2b7e698f52f389532b4b52bc254/raw/76fe1b5e9efcf0d2afdfd78b0bfaa737ad0a67d3/shakespeare.txt

--2025-08-23 00:06:47--  https://gist.githubusercontent.com/blakesanie/dde3a2b7e698f52f389532b4b52bc254/raw/76fe1b5e9efcf0d2afdfd78b0bfaa737ad0a67d3/shakespeare.txt
Resolving gist.githubusercontent.com (gist.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to gist.githubusercontent.com (gist.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5436475 (5.2M) [text/plain]
Saving to: ‘shakespeare.txt.1’


2025-08-23 00:06:47 (86.5 MB/s) - ‘shakespeare.txt.1’ saved [5436475/5436475]



In [2]:
!wc -lwc shakespeare.txt

 124185  899588 5436475 shakespeare.txt


In [3]:
with open("shakespeare.txt") as f:
  data = f.read()

print(len(data))

data[:1000]

5436475


"  From fairest creatures we desire increase,\n  That thereby beauty's rose might never die,\n  But as the riper should by time decease,\n  His tender heir might bear his memory:\n  But thou contracted to thine own bright eyes,\n  Feed'st thy light's flame with self-substantial fuel,\n  Making a famine where abundance lies,\n  Thy self thy foe, to thy sweet self too cruel:\n  Thou that art now the world's fresh ornament,\n  And only herald to the gaudy spring,\n  Within thine own bud buriest thy content,\n  And tender churl mak'st waste in niggarding:\n    Pity the world, or else this glutton be,\n    To eat the world's due, by the grave and thee.\n\n\n                     2\n  When forty winters shall besiege thy brow,\n  And dig deep trenches in thy beauty's field,\n  Thy youth's proud livery so gazed on now,\n  Will be a tattered weed of small worth held:\n  Then being asked, where all thy beauty lies,\n  Where all the treasure of thy lusty days;\n  To say within thine own deep sunk

In [4]:
chars = list(set(data))
print(len(chars))

ctoi = {ch:i for i, ch in enumerate(chars)}
itoc = {i:ch for ch, i in ctoi.items()}

assert len(ctoi) == len(itoc)


def encoder(text):
  return [ctoi[ch] for ch in text]

def decoder(tokens):
  return "".join([itoc[token] for token in tokens])

84


In [5]:
test_str = "hello world!"
assert test_str == decoder(encoder(test_str))

In [6]:
import torch
import torch.nn.functional as F

train_size = 0.9

num_epochs = 10
batch_size = 1024
emb_dim = 256
lr = 2e-3
context_window_size = 128

torch.manual_seed(2025)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_data, test_data = encoder(data[:int(train_size*len(data))]), encoder(data[int(train_size*len(data)):])
train_data_t, test_data_t = torch.tensor(train_data, dtype=torch.long).to(device), torch.tensor(test_data, dtype=torch.long).to(device)

len(train_data), len(test_data)

(4892827, 543648)

In [7]:
type(train_data), type(train_data_t), type(test_data), type(test_data_t)

(list, torch.Tensor, list, torch.Tensor)

In [8]:
num_batches = int(len(train_data)/batch_size*0.9)  # 0.9 multiplier is an scrappy way of making sure get_batch doesn't go out of bounds
num_test_batches = int(len(test_data)/batch_size*0.9)

print(f"{num_batches=}, {num_test_batches=}")

num_batches=4300, num_test_batches=477


In [9]:
# types are explicitly not included

def scaled_dot_product_attention(Q, K, V, mask=None):
  # QKV are of shape [batch_size, seq_len, d_k]
  # mask is a tensor of shape [seq_len, seq_len]
  scores = (Q @ K.transpose(-2, -1)) / Q.shape[-1]**0.5
  # K.transpose(-2, -1) shape = [batch_size, d_k, seq_len]
  # Q @ K^T shape = [batch_size, seq_len, seq_len]
  if mask is not None:
    scores = scores.masked_fill(mask==0, float("-inf"))
  return F.softmax(scores, dim=-1) @ V  # shape = [batch_size, seq_len, d_k]


def causal_mask(seq_len):
  # create a lower trianguar mask
  return torch.tril(torch.ones(seq_len, seq_len)).bool()  # shape = [seq_len, seq_len]


class SelfAttention(torch.nn.Module):
  def __init__(self, emb_dim=emb_dim):
    super().__init__()
    self.emb_dim = emb_dim
    self.q_proj = torch.nn.Linear(emb_dim, emb_dim)
    self.k_proj = torch.nn.Linear(emb_dim, emb_dim)
    self.v_proj = torch.nn.Linear(emb_dim, emb_dim)

    self.output = torch.nn.Linear(emb_dim, emb_dim)

  def forward(self, X):
    # X shape = [batch_size, seq_len, emb_dim]
    Q = self.q_proj(X)
    K = self.k_proj(X)
    V = self.v_proj(X)

    seq_len = Q.shape[1]
    mask = causal_mask(seq_len).unsqueeze(0).to(device)

    return self.output(scaled_dot_product_attention(Q, K, V, mask=mask))


class TransformerBlock(torch.nn.Module):
  def __init__(self, emb_dim=emb_dim):
    super().__init__()
    self.attention = SelfAttention(emb_dim=emb_dim)
    self.linear = torch.nn.Linear(emb_dim, emb_dim)
    self.ln1 = torch.nn.LayerNorm(emb_dim)
    self.ln2 = torch.nn.LayerNorm(emb_dim)

  def forward(self, X):
    res = X
    X = self.attention(X)
    X += res
    X = self.ln1(X)

    res = X
    X = self.linear(X)
    X += res
    X = self.ln2(X)

    return X


class Model(torch.nn.Module):
  def __init__(self, vocab_size=len(chars), emb_dim=emb_dim, context_window_size=context_window_size):
    super().__init__()
    self.embedding_layer = torch.nn.Embedding(vocab_size, emb_dim)
    self.transformer = TransformerBlock(emb_dim=emb_dim)
    self.linear = torch.nn.Linear(emb_dim, vocab_size)
    # TODO: add positional encoding

    self.softmax = torch.nn.Softmax(-1)
    self.vocab_size = vocab_size
    self.context_window_size = context_window_size
    self.loss = torch.nn.CrossEntropyLoss()

  def forward(self, X, y=None):
    # X shape = [batch_size, seq_len]
    emb = self.embedding_layer(X)  # shape = [batch_size, seq_len, emb_dim]
    logits = self.linear(self.transformer(emb))  # single block, single attention head
    probs = self.softmax(logits)
    if y is None:
      return probs, None
    else:
      return probs, self.loss(logits.view(-1, self.vocab_size), y.view(-1))  # bug fix from previous version. CrossEntropy requires logits not probabilities

  def generate(self, prompt=""):
    # assuming len(prompt) < context_window_size
    # TODO: truncate if not
    X = torch.tensor(encoder(prompt), dtype=torch.long).to(device)
    X = X.unsqueeze(0) # to convert shape to [batch_size, seq_len]

    while X.shape[1] < self.context_window_size:
      probs, _ = self(X)
      next_token = torch.multinomial(probs[0][-1], 1)
      X = torch.cat((X,next_token.unsqueeze(0)), dim=1)
      # TODO: stop on end token...
    return decoder(X.squeeze(0).tolist())


model = Model().to(device)

In [10]:
# num of trainable parameters in the model:
sum(p.numel() for p in model.parameters())

373076

In [11]:
model.generate("Hello w")

'Hello wTTgvBSp[Oq5&_aPjAqI85Ze|?Mc0g\'BPA]iH&vR<AIX>ZG>64OXXi.3a.t3yYEN,oMFZzeiSV)Sj!4"OokJ_H!](ZtM|.y.BWF8Ug.\'e0\n6tSiE[xr!Rg9[dE'

In [12]:
train_X, train_y = train_data_t[:-1], train_data_t[1:]
test_X, test_y = test_data_t[:-1], test_data_t[1:]

def get_batch(X, y, idx, batch_size=16):
  batch_x, batch_y = [], []
  for i in range(batch_size):
    batch_x.append(X[idx*batch_size+i  :idx*batch_size+i+context_window_size])
    batch_y.append(y[idx*batch_size+i+1:idx*batch_size+i+context_window_size+1])

  return torch.stack(batch_x), torch.stack(batch_y)

In [13]:
from tqdm import tqdm

In [14]:
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# TODO: wrap in a function instead of duplicating code
model.eval()
with torch.no_grad():
  test_loss = 0
  for i in range(num_test_batches):
    X, y = get_batch(test_X, test_y, i, batch_size)
    probs, loss = model(X, y)
    test_loss += loss.item()
  test_loss /= num_test_batches
print(f"random weight {test_loss=}")

for epoch in range(num_epochs):
  model.train()
  epoch_loss = 0
  for i in tqdm(range(num_batches)):
    X, y = get_batch(train_X, train_y, i, batch_size)
    probs, loss = model(X, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epoch_loss += loss.item()
  epoch_loss /= num_batches

  model.eval()
  with torch.no_grad():
    test_loss = 0
    for i in range(num_test_batches):
      X, y = get_batch(test_X, test_y, i, batch_size)
      probs, loss = model(X, y)
      test_loss += loss.item()
    test_loss /= num_test_batches

  print(f"{epoch=}, {epoch_loss=}; {test_loss=}")


random weight test_loss=4.598102062753162


100%|██████████| 4300/4300 [10:04<00:00,  7.12it/s]


epoch=0, epoch_loss=2.8294481539171796; test_loss=2.9057920998747244


100%|██████████| 4300/4300 [09:55<00:00,  7.22it/s]


epoch=1, epoch_loss=2.8139832000954206; test_loss=2.8980872766026913


100%|██████████| 4300/4300 [09:52<00:00,  7.25it/s]


epoch=2, epoch_loss=2.816237937128821; test_loss=2.894635406929992


100%|██████████| 4300/4300 [09:47<00:00,  7.31it/s]


epoch=3, epoch_loss=2.8146807001912317; test_loss=2.8855199184057847


100%|██████████| 4300/4300 [09:45<00:00,  7.35it/s]


epoch=4, epoch_loss=2.8080637375698534; test_loss=2.8754669925201863


100%|██████████| 4300/4300 [09:44<00:00,  7.36it/s]


epoch=5, epoch_loss=2.8028719617599664; test_loss=2.8768195901027016


100%|██████████| 4300/4300 [09:44<00:00,  7.36it/s]


epoch=6, epoch_loss=2.8022412790254103; test_loss=2.874900992811351


100%|██████████| 4300/4300 [09:42<00:00,  7.38it/s]


epoch=7, epoch_loss=2.800742205520009; test_loss=2.8812722799912938


100%|██████████| 4300/4300 [09:43<00:00,  7.37it/s]


epoch=8, epoch_loss=2.798622443675995; test_loss=2.874489603302514


100%|██████████| 4300/4300 [09:43<00:00,  7.37it/s]


epoch=9, epoch_loss=2.799012745535651; test_loss=2.875386122637575


In [15]:
model.generate("Hello w")

"Hello wselorh '\n   a Admk\n EIAHOMRI nani  os\n  oekot\n   oeol-vn M ie.Hrs hd\n nyIE.TTI hr,igt.  etam ete  Ibty\n hn hwlsem,shnwe o"

In [16]:
model.generate("From fairest creatures")

'From fairest creatures hty\n od\n e e,atihn   a eMO\n i; oa ftnmv  ign UHv BTM eom T n e\n lso Rmtru opc OR Or  al\n  omrtwno oe  rse'

In [17]:
model.generate("That thereby")

'That thereby hol alece ue ulIgotesmte hvrn;oo hntr u\n ocn gol\n o rm o.OIkr\n e e Y ah htinmsfPa,a rwikoe  ol AO o ohmsk\n aeko\n ue'

In [18]:
model.generate("tattered weed")

'tattered weeduk\n  RBNEIliee,tetwts\n sfsaeslwr o hs m spes UIl or wn  CNYut r S r, hl\n  h ne   rasaetfi o  ygo,ica om an,dwfrglv '

In [19]:
from google.colab import runtime

# Disconnects and deletes the current runtime
runtime.unassign()