In [None]:
# boiler plate for training and evaluation using a dummy simple model that uses previous token to predict next token

In [1]:
!wget https://gist.githubusercontent.com/blakesanie/dde3a2b7e698f52f389532b4b52bc254/raw/76fe1b5e9efcf0d2afdfd78b0bfaa737ad0a67d3/shakespeare.txt

--2025-08-22 22:05:08--  https://gist.githubusercontent.com/blakesanie/dde3a2b7e698f52f389532b4b52bc254/raw/76fe1b5e9efcf0d2afdfd78b0bfaa737ad0a67d3/shakespeare.txt
Resolving gist.githubusercontent.com (gist.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to gist.githubusercontent.com (gist.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5436475 (5.2M) [text/plain]
Saving to: ‘shakespeare.txt.5’


2025-08-22 22:05:08 (186 MB/s) - ‘shakespeare.txt.5’ saved [5436475/5436475]



In [2]:
!wc -lwc shakespeare.txt

 124185  899588 5436475 shakespeare.txt


In [3]:
with open("shakespeare.txt") as f:
  data = f.read()

print(len(data))

data[:1000]

5436475


"  From fairest creatures we desire increase,\n  That thereby beauty's rose might never die,\n  But as the riper should by time decease,\n  His tender heir might bear his memory:\n  But thou contracted to thine own bright eyes,\n  Feed'st thy light's flame with self-substantial fuel,\n  Making a famine where abundance lies,\n  Thy self thy foe, to thy sweet self too cruel:\n  Thou that art now the world's fresh ornament,\n  And only herald to the gaudy spring,\n  Within thine own bud buriest thy content,\n  And tender churl mak'st waste in niggarding:\n    Pity the world, or else this glutton be,\n    To eat the world's due, by the grave and thee.\n\n\n                     2\n  When forty winters shall besiege thy brow,\n  And dig deep trenches in thy beauty's field,\n  Thy youth's proud livery so gazed on now,\n  Will be a tattered weed of small worth held:\n  Then being asked, where all thy beauty lies,\n  Where all the treasure of thy lusty days;\n  To say within thine own deep sunk

In [4]:
chars = list(set(data))
print(len(chars))

ctoi = {ch:i for i, ch in enumerate(chars)}
itoc = {i:ch for ch, i in ctoi.items()}

assert len(ctoi) == len(itoc)


def encoder(text):
  return [ctoi[ch] for ch in text]

def decoder(tokens):
  return "".join([itoc[token] for token in tokens])

84


In [5]:
test_str = "hello world!"
assert test_str == decoder(encoder(test_str))

In [6]:
import torch

train_size = 0.9

num_epochs = 10
batch_size = 1024
emb_dim = 32
lr = 1e-3
context_window_size = 32  # meaningless in this format, where the simple model is only using the previous token to predict next token

torch.manual_seed(2025)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_data, test_data = encoder(data[:int(train_size*len(data))]), encoder(data[int(train_size*len(data)):])
train_data_t, test_data_t = torch.tensor(train_data, dtype=torch.long).to(device), torch.tensor(test_data, dtype=torch.long).to(device)

len(train_data), len(test_data)

(4892827, 543648)

In [7]:
type(train_data), type(train_data_t), type(test_data), type(test_data_t)

(list, torch.Tensor, list, torch.Tensor)

In [8]:
num_batches = int(len(train_data)/batch_size*0.9)
num_test_batches = int(len(test_data)/batch_size*0.9)

print(f"{num_batches=}, {num_test_batches=}")

num_batches=4300, num_test_batches=477


In [9]:
# types are explicitly not included

class Model(torch.nn.Module):
  def __init__(self, vocab_size=len(chars), emb_dim=emb_dim, context_window_size=context_window_size):
    super().__init__()
    self.embedding_layer = torch.nn.Embedding(vocab_size, emb_dim)
    self.linear_layer = torch.nn.Linear(emb_dim, vocab_size)
    self.softmax = torch.nn.Softmax(-1)
    self.vocab_size = vocab_size
    self.context_window_size = context_window_size
    self.loss = torch.nn.CrossEntropyLoss()

  def forward(self, X, y=None):
    emb = self.embedding_layer(X)
    logits = self.linear_layer(emb)
    probs = self.softmax(logits)
    if y is None:
      return probs, None
    else:
      return probs, self.loss(probs.view(-1, self.vocab_size), y.view(-1))

  def generate(self, prompt=""):
    # assuming len(prompt) < context_window_size
    # TODO: truncate if not
    X = torch.tensor(encoder(prompt), dtype=torch.long).to(device)

    while len(X) < self.context_window_size:
      probs, _ = self(X)
      next_token = torch.multinomial(probs[-1], 1)
      X = torch.concat((X,next_token))

      # TODO: stop on end token...
    return decoder(X.tolist())


model = Model().to(device)

In [10]:
model.generate("Hello w")

"Hello wiiAEFf'2r!y,NRs ;X57yAdg}"

In [11]:
train_X, train_y = train_data_t[:-1], train_data_t[1:]
test_X, test_y = test_data_t[:-1], test_data_t[1:]

def get_batch(X, y, idx, batch_size=16):
  batch_x, batch_y = [], []
  for i in range(batch_size):
    batch_x.append(X[idx*batch_size+i  :idx*batch_size+i+context_window_size])
    batch_y.append(y[idx*batch_size+i+1:idx*batch_size+i+context_window_size+1])

  return torch.stack(batch_x), torch.stack(batch_y)

In [12]:
from tqdm import tqdm

In [13]:
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# TODO: wrap in a function instead of duplicating code
model.eval()
with torch.no_grad():
  test_loss = 0
  for i in range(num_test_batches):
    X, y = get_batch(test_X, test_y, i, batch_size)
    probs, loss = model(X, y)
    test_loss += loss.item()
  test_loss /= num_test_batches
print(f"random weight {test_loss=}")

for epoch in range(num_epochs):
  model.train()
  epoch_loss = 0
  for i in tqdm(range(num_batches)):
    X, y = get_batch(train_X, train_y, i, batch_size)
    probs, loss = model(X, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epoch_loss += loss.item()
  epoch_loss /= num_batches

  model.eval()
  with torch.no_grad():
    test_loss = 0
    for i in range(num_test_batches):
      X, y = get_batch(test_X, test_y, i, batch_size)
      probs, loss = model(X, y)
      test_loss += loss.item()
    test_loss /= num_test_batches

  print(f"{epoch=}, {epoch_loss=}; {test_loss=}")


random weight test_loss=4.431516471398951


100%|██████████| 4300/4300 [00:42<00:00, 101.37it/s]


epoch=0, epoch_loss=4.23637455629748; test_loss=4.225788273401481


100%|██████████| 4300/4300 [00:42<00:00, 102.24it/s]


epoch=1, epoch_loss=4.2212907846029415; test_loss=4.225944372093153


100%|██████████| 4300/4300 [00:41<00:00, 102.95it/s]


epoch=2, epoch_loss=4.221148357557696; test_loss=4.225657027222575


100%|██████████| 4300/4300 [00:42<00:00, 101.24it/s]


epoch=3, epoch_loss=4.221384235703668; test_loss=4.225590437963074


100%|██████████| 4300/4300 [00:41<00:00, 104.10it/s]


epoch=4, epoch_loss=4.219121327843777; test_loss=4.221889990680623


100%|██████████| 4300/4300 [00:42<00:00, 101.11it/s]


epoch=5, epoch_loss=4.218222948174144; test_loss=4.221361400196387


100%|██████████| 4300/4300 [00:42<00:00, 101.26it/s]


epoch=6, epoch_loss=4.218035585603048; test_loss=4.22091009431915


100%|██████████| 4300/4300 [00:41<00:00, 103.36it/s]


epoch=7, epoch_loss=4.217927774551303; test_loss=4.221243087600612


100%|██████████| 4300/4300 [00:41<00:00, 104.03it/s]


epoch=8, epoch_loss=4.217904448398324; test_loss=4.2212355841630655


100%|██████████| 4300/4300 [00:42<00:00, 100.59it/s]


epoch=9, epoch_loss=4.218024053906285; test_loss=4.221235454207446


In [14]:
model.generate("Hello w")

'Hello w                         '

In [18]:
# num of trainable parameters in the model:
sum(p.numel() for p in model.parameters())

5460