In [None]:
# bigger model (174k parameters) compared to previous one (5.4k parameters) + using ReLU for nonlinearity

In [1]:
!wget https://gist.githubusercontent.com/blakesanie/dde3a2b7e698f52f389532b4b52bc254/raw/76fe1b5e9efcf0d2afdfd78b0bfaa737ad0a67d3/shakespeare.txt

--2025-08-22 22:38:31--  https://gist.githubusercontent.com/blakesanie/dde3a2b7e698f52f389532b4b52bc254/raw/76fe1b5e9efcf0d2afdfd78b0bfaa737ad0a67d3/shakespeare.txt
Resolving gist.githubusercontent.com (gist.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to gist.githubusercontent.com (gist.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5436475 (5.2M) [text/plain]
Saving to: ‘shakespeare.txt’


2025-08-22 22:38:31 (88.9 MB/s) - ‘shakespeare.txt’ saved [5436475/5436475]



In [2]:
!wc -lwc shakespeare.txt

 124185  899588 5436475 shakespeare.txt


In [3]:
with open("shakespeare.txt") as f:
  data = f.read()

print(len(data))

data[:1000]

5436475


"  From fairest creatures we desire increase,\n  That thereby beauty's rose might never die,\n  But as the riper should by time decease,\n  His tender heir might bear his memory:\n  But thou contracted to thine own bright eyes,\n  Feed'st thy light's flame with self-substantial fuel,\n  Making a famine where abundance lies,\n  Thy self thy foe, to thy sweet self too cruel:\n  Thou that art now the world's fresh ornament,\n  And only herald to the gaudy spring,\n  Within thine own bud buriest thy content,\n  And tender churl mak'st waste in niggarding:\n    Pity the world, or else this glutton be,\n    To eat the world's due, by the grave and thee.\n\n\n                     2\n  When forty winters shall besiege thy brow,\n  And dig deep trenches in thy beauty's field,\n  Thy youth's proud livery so gazed on now,\n  Will be a tattered weed of small worth held:\n  Then being asked, where all thy beauty lies,\n  Where all the treasure of thy lusty days;\n  To say within thine own deep sunk

In [4]:
chars = list(set(data))
print(len(chars))

ctoi = {ch:i for i, ch in enumerate(chars)}
itoc = {i:ch for ch, i in ctoi.items()}

assert len(ctoi) == len(itoc)


def encoder(text):
  return [ctoi[ch] for ch in text]

def decoder(tokens):
  return "".join([itoc[token] for token in tokens])

84


In [5]:
test_str = "hello world!"
assert test_str == decoder(encoder(test_str))

In [6]:
import torch

train_size = 0.9

num_epochs = 10
batch_size = 1024
emb_dim = 256
lr = 2e-3
context_window_size = 32  # meaningless in this format, where the simple model is only using the previous token to predict next token

torch.manual_seed(2025)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_data, test_data = encoder(data[:int(train_size*len(data))]), encoder(data[int(train_size*len(data)):])
train_data_t, test_data_t = torch.tensor(train_data, dtype=torch.long).to(device), torch.tensor(test_data, dtype=torch.long).to(device)

len(train_data), len(test_data)

(4892827, 543648)

In [7]:
type(train_data), type(train_data_t), type(test_data), type(test_data_t)

(list, torch.Tensor, list, torch.Tensor)

In [8]:
num_batches = int(len(train_data)/batch_size*0.9)  # 0.9 multiplier is an scrappy way of making sure get_batch doesn't go out of bounds
num_test_batches = int(len(test_data)/batch_size*0.9)

print(f"{num_batches=}, {num_test_batches=}")

num_batches=4300, num_test_batches=477


In [9]:
# types are explicitly not included

class Model(torch.nn.Module):
  def __init__(self, vocab_size=len(chars), emb_dim=emb_dim, context_window_size=context_window_size):
    super().__init__()
    self.embedding_layer = torch.nn.Embedding(vocab_size, emb_dim)
    self.linear_layer1 = torch.nn.Linear(emb_dim, emb_dim)
    self.linear_layer2 = torch.nn.Linear(emb_dim, emb_dim)
    self.linear_layer3 = torch.nn.Linear(emb_dim, vocab_size)
    self.softmax = torch.nn.Softmax(-1)
    self.vocab_size = vocab_size
    self.context_window_size = context_window_size
    self.loss = torch.nn.CrossEntropyLoss()

  def forward(self, X, y=None):
    emb = self.embedding_layer(X)
    int_X = self.linear_layer1(emb)
    int_X = torch.nn.functional.relu(int_X)
    int_X = self.linear_layer2(int_X)
    int_X = torch.nn.functional.relu(int_X)
    logits = self.linear_layer3(int_X)
    probs = self.softmax(logits)
    if y is None:
      return probs, None
    else:
      return probs, self.loss(logits.view(-1, self.vocab_size), y.view(-1))  # bug fix from previous version. CrossEntropy requires logits not probabilities

  def generate(self, prompt=""):
    # assuming len(prompt) < context_window_size
    # TODO: truncate if not
    X = torch.tensor(encoder(prompt), dtype=torch.long).to(device)

    while len(X) < self.context_window_size:
      probs, _ = self(X)
      next_token = torch.multinomial(probs[-1], 1)
      X = torch.cat((X,next_token))

      # TODO: stop on end token...
    return decoder(X.tolist())


model = Model().to(device)

In [10]:
# num of trainable parameters in the model:
sum(p.numel() for p in model.parameters())

174676

In [11]:
model.generate("Hello w")

"Hello wNNV><GdTt}|rIbU&eLo:'VQ]K"

In [12]:
train_X, train_y = train_data_t[:-1], train_data_t[1:]
test_X, test_y = test_data_t[:-1], test_data_t[1:]

def get_batch(X, y, idx, batch_size=16):
  batch_x, batch_y = [], []
  for i in range(batch_size):
    batch_x.append(X[idx*batch_size+i  :idx*batch_size+i+context_window_size])
    batch_y.append(y[idx*batch_size+i+1:idx*batch_size+i+context_window_size+1])

  return torch.stack(batch_x), torch.stack(batch_y)

In [13]:
from tqdm import tqdm

In [14]:
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# TODO: wrap in a function instead of duplicating code
model.eval()
with torch.no_grad():
  test_loss = 0
  for i in range(num_test_batches):
    X, y = get_batch(test_X, test_y, i, batch_size)
    probs, loss = model(X, y)
    test_loss += loss.item()
  test_loss /= num_test_batches
print(f"random weight {test_loss=}")

for epoch in range(num_epochs):
  model.train()
  epoch_loss = 0
  for i in tqdm(range(num_batches)):
    X, y = get_batch(train_X, train_y, i, batch_size)
    probs, loss = model(X, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epoch_loss += loss.item()
  epoch_loss /= num_batches

  model.eval()
  with torch.no_grad():
    test_loss = 0
    for i in range(num_test_batches):
      X, y = get_batch(test_X, test_y, i, batch_size)
      probs, loss = model(X, y)
      test_loss += loss.item()
    test_loss /= num_test_batches

  print(f"{epoch=}, {epoch_loss=}; {test_loss=}")


random weight test_loss=4.450759590796705


100%|██████████| 4300/4300 [01:10<00:00, 61.04it/s]


epoch=0, epoch_loss=2.829165631504946; test_loss=2.898080866791667


100%|██████████| 4300/4300 [01:09<00:00, 61.45it/s]


epoch=1, epoch_loss=2.818565161616303; test_loss=2.893548139236258


100%|██████████| 4300/4300 [01:09<00:00, 61.82it/s]


epoch=2, epoch_loss=2.8175291816578354; test_loss=2.8851794461784124


100%|██████████| 4300/4300 [01:10<00:00, 61.20it/s]


epoch=3, epoch_loss=2.8201700493346813; test_loss=2.893317534488702


100%|██████████| 4300/4300 [01:11<00:00, 59.92it/s]


epoch=4, epoch_loss=2.821109535250553; test_loss=2.8765543346884868


100%|██████████| 4300/4300 [01:09<00:00, 61.52it/s]


epoch=5, epoch_loss=2.8214588141441346; test_loss=2.8785422493076926


100%|██████████| 4300/4300 [01:10<00:00, 60.66it/s]


epoch=6, epoch_loss=2.8217318461662115; test_loss=2.8806357218784355


100%|██████████| 4300/4300 [01:09<00:00, 61.54it/s]


epoch=7, epoch_loss=2.8225284137836724; test_loss=2.8748602032411523


100%|██████████| 4300/4300 [01:10<00:00, 61.30it/s]


epoch=8, epoch_loss=2.823509245196054; test_loss=2.8791867317143724


100%|██████████| 4300/4300 [01:09<00:00, 61.71it/s]


epoch=9, epoch_loss=2.823858835586282; test_loss=2.873954288614621


In [15]:
model.generate("Hello w")

'Hello ws o  hmeyuuwyudeyuyu\n t t'

In [16]:
model.generate("From fairest creatures")

'From fairest creatureste\nESehtrt'

In [17]:
model.generate("That thereby")

'That therebyagt oe   n hmn a ait'

In [18]:
model.generate("tattered weed")

'tattered weedatsI oaraeoefsn. h '