In [None]:
# implementing Rotary Position Embeddings from https://arxiv.org/pdf/2104.09864
# the implementation is not fully vectorized and could be improved for efficiency

In [9]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [10]:
!wget https://gist.githubusercontent.com/blakesanie/dde3a2b7e698f52f389532b4b52bc254/raw/76fe1b5e9efcf0d2afdfd78b0bfaa737ad0a67d3/shakespeare.txt

--2025-09-02 01:39:50--  https://gist.githubusercontent.com/blakesanie/dde3a2b7e698f52f389532b4b52bc254/raw/76fe1b5e9efcf0d2afdfd78b0bfaa737ad0a67d3/shakespeare.txt
Resolving gist.githubusercontent.com (gist.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to gist.githubusercontent.com (gist.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5436475 (5.2M) [text/plain]
Saving to: ‘shakespeare.txt’


2025-09-02 01:39:51 (55.1 MB/s) - ‘shakespeare.txt’ saved [5436475/5436475]



In [11]:
!wc -lwc shakespeare.txt

 124185  899588 5436475 shakespeare.txt


In [12]:
with open("shakespeare.txt") as f:
  data = f.read()

print(len(data))

data[:1000]

5436475


"  From fairest creatures we desire increase,\n  That thereby beauty's rose might never die,\n  But as the riper should by time decease,\n  His tender heir might bear his memory:\n  But thou contracted to thine own bright eyes,\n  Feed'st thy light's flame with self-substantial fuel,\n  Making a famine where abundance lies,\n  Thy self thy foe, to thy sweet self too cruel:\n  Thou that art now the world's fresh ornament,\n  And only herald to the gaudy spring,\n  Within thine own bud buriest thy content,\n  And tender churl mak'st waste in niggarding:\n    Pity the world, or else this glutton be,\n    To eat the world's due, by the grave and thee.\n\n\n                     2\n  When forty winters shall besiege thy brow,\n  And dig deep trenches in thy beauty's field,\n  Thy youth's proud livery so gazed on now,\n  Will be a tattered weed of small worth held:\n  Then being asked, where all thy beauty lies,\n  Where all the treasure of thy lusty days;\n  To say within thine own deep sunk

In [13]:
chars = list(set(data))
print(len(chars))

ctoi = {ch:i for i, ch in enumerate(chars)}
itoc = {i:ch for ch, i in ctoi.items()}

assert len(ctoi) == len(itoc)


def encoder(text):
  return [ctoi[ch] for ch in text]

def decoder(tokens):
  return "".join([itoc[token] for token in tokens])

84


In [14]:
test_str = "hello world!"
assert test_str == decoder(encoder(test_str))

In [15]:
import math
import torch
import torch.nn.functional as F

train_size = 0.9

num_epochs = 2
batch_size = 1024 + 512
emb_dim = 512
num_heads = 16
num_blocks=4
lr = 2e-3
context_window_size = 128

torch.manual_seed(2025)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_data, test_data = encoder(data[:int(train_size*len(data))]), encoder(data[int(train_size*len(data)):])
train_data_t, test_data_t = torch.tensor(train_data, dtype=torch.long).to(device), torch.tensor(test_data, dtype=torch.long).to(device)

len(train_data), len(test_data)

(4892827, 543648)

In [16]:
type(train_data), type(train_data_t), type(test_data), type(test_data_t)

(list, torch.Tensor, list, torch.Tensor)

In [17]:
num_batches = int(len(train_data)/batch_size*0.9)  # 0.9 multiplier is an scrappy way of making sure get_batch doesn't go out of bounds
num_test_batches = int(len(test_data)/batch_size*0.9)

print(f"{num_batches=}, {num_test_batches=}")

num_batches=2866, num_test_batches=318


In [18]:
# interleaving (re combining even and odd matrices) trick
o = torch.tensor([1,3,5]).unsqueeze(0).unsqueeze(0).unsqueeze(0)
e = torch.tensor([2,4,6]).unsqueeze(0).unsqueeze(0).unsqueeze(0)

s = torch.stack([o, e], dim=-1)
print(s)

s.flatten(-2)

tensor([[[[[1, 2],
           [3, 4],
           [5, 6]]]]])


tensor([[[[1, 2, 3, 4, 5, 6]]]])

In [42]:
def causal_mask(seq_len):
  # create a lower trianguar mask
  return torch.tril(torch.ones(seq_len, seq_len)).bool()  # shape = [seq_len, seq_len]

class MultiHeadAttention(torch.nn.Module):
  def __init__(self, emb_dim, num_heads):
    super().__init__()
    assert emb_dim % num_heads == 0  # emb_dim is divisible by num_heads
    self.q_proj = torch.nn.Linear(emb_dim, emb_dim)
    self.k_proj = torch.nn.Linear(emb_dim, emb_dim)
    self.v_proj = torch.nn.Linear(emb_dim, emb_dim)

    self.output = torch.nn.Linear(emb_dim, emb_dim)

    self.emb_dim = emb_dim
    self.num_heads = num_heads
    self.head_dim = emb_dim // num_heads

    # context_window_size is the maximum sequence length
    self.register_buffer("cos_table", torch.tensor(
        [
            [math.cos(m*(10_000**(-2*d/self.head_dim))) for d in range(self.head_dim//2)]
            for m in range(context_window_size)
        ]
    ))
    self.register_buffer("sin_table", torch.tensor(
        [
            [math.sin(m*(10_000**(-2*d/self.head_dim))) for d in range(self.head_dim//2)]
            for m in range(context_window_size)
        ]
    ))


  def forward(self, X):
    batch_size, seq_len, _ = X.shape

    Q = self.q_proj(X)
    K = self.k_proj(X)
    V = self.v_proj(X)

    Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim)  # shape = [batch_size, seq_len, num_heads, head_dim]
    K = K.view(batch_size, seq_len, self.num_heads, self.head_dim)
    V = V.view(batch_size, seq_len, self.num_heads, self.head_dim)

    # RoPE
    Q_even = Q[:, :, :, 0::2]  # shape = [batch_size, seq_len, num_heads, head_dim/2]
    Q_odd = Q[:, :, :, 1::2]
    K_even = K[:, :, :, 0::2]
    K_odd = K[:, :, :, 1::2]

    cos = self.cos_table[:seq_len].unsqueeze(0).unsqueeze(2).to(device)  # shape = [1, seq_len, 1, head_dim/2]
    sin = self.sin_table[:seq_len].unsqueeze(0).unsqueeze(2).to(device)

    Qp_even = Q_even * cos - Q_odd * sin  # shape = [batch_size, seq_len, num_heads, head_dim/2]
    Qp_odd = Q_even * sin + Q_odd * cos

    Kp_even = K_even * cos - K_odd * sin  # shape = [batch_size, seq_len, num_heads, head_dim/2]
    Kp_odd = K_even * sin + K_odd * cos

    Qp_stacked = torch.stack([Qp_even, Qp_odd], dim=-1)  # shape = [batch_size, seq_len, num_heads, head_dim/2, 2]
    Qp = Qp_stacked.flatten(-2)  # shape = [batch_size, seq_len, num_heads, head_dim]

    Kp_stacked = torch.stack([Kp_even, Kp_odd], dim=-1)  # shape = [batch_size, seq_len, num_heads, head_dim/2, 2]
    Kp = Kp_stacked.flatten(-2)  # shape = [batch_size, seq_len, num_heads, head_dim]

    # split into heads and change shae [batch_size, seq_len, num_heads, head_dim] -> [batch_size, num_heads, seq_len, head_dim]
    Qp = Qp.transpose(1, 2)
    Kp = Kp.transpose(1, 2)
    V = V.transpose(1, 2)

    # scaled dot product attention
    scores = Qp @ Kp.transpose(-2, -1) / (self.head_dim**0.5)
    mask = causal_mask(seq_len).unsqueeze(0).unsqueeze(0).to(device)  # shape = [1, 1, seq_len, seq_len]
    scores = scores.masked_fill(mask==0, float("-inf"))
    attn = F.softmax(scores, dim=-1) @ V  # shape = [batch_size, num_heads, seq_len, head_dim]

    # concat heads
    attn = attn.transpose(1, 2).reshape(batch_size, seq_len, self.emb_dim)
    return self.output(attn)


class TransformerBlock(torch.nn.Module):
  def __init__(self, emb_dim=emb_dim):
    super().__init__()
    self.attention = MultiHeadAttention(emb_dim=emb_dim, num_heads=num_heads)
    self.ffn = torch.nn.Sequential(
        torch.nn.Linear(emb_dim, emb_dim*4),
        torch.nn.ReLU(),
        torch.nn.Linear(emb_dim*4, emb_dim)
    )
    self.ln1 = torch.nn.LayerNorm(emb_dim)
    self.ln2 = torch.nn.LayerNorm(emb_dim)

  def forward(self, X):
    res = X
    X = self.ln1(X)
    X = res + self.attention(X)

    res = X
    X = self.ln2(X)  # pre-LN
    X = res + self.ffn(X)

    return X


class Model(torch.nn.Module):
  def __init__(self, vocab_size=len(chars), emb_dim=emb_dim, seq_len=context_window_size, num_blocks=num_blocks):
    super().__init__()
    self.embedding_layer = torch.nn.Embedding(vocab_size, emb_dim)
    # self.positional_emb = torch.nn.Embedding(seq_len, emb_dim)  #
    self.transformer = torch.nn.ModuleList([TransformerBlock(emb_dim=emb_dim) for _ in range(num_blocks)])  # note: using nn.Sequential([TB]) will reference the same instance of TB -> all of them will share the same weights which is not what we want
    self.linear = torch.nn.Linear(emb_dim, vocab_size)
    self.softmax = torch.nn.Softmax(-1)

    self.vocab_size = vocab_size
    self.seq_len = seq_len
    self.loss = torch.nn.CrossEntropyLoss()

  def forward(self, X, y=None):
    # X shape = [batch_size, seq_len]
    emb = self.embedding_layer(X)  # shape = [batch_size, seq_len, emb_dim]
    # positions = torch.arange(emb.shape[1]).unsqueeze(0).to(device)  # shape = [batch_size, seq_len]
    # pos_emb = self.positional_emb(positions)  # shape = [batch_size, seq_len, emb_dim]
    x = emb #+ pos_emb
    for block in self.transformer:
      x = block(x)
    logits = self.linear(x)
    if y is None:
      return logits, None
    else:
      # during training
      return logits, self.loss(logits.view(-1, self.vocab_size), y.view(-1))  # bug fix from previous version. CrossEntropy requires logits not probabilities

  def generate(self, prompt=""):
    # assuming len(prompt) < seq_len
    # TODO: truncate if not
    X = torch.tensor(encoder(prompt), dtype=torch.long).to(device)
    X = X.unsqueeze(0) # to convert shape to [batch_size, seq_len]

    while X.shape[1] < self.seq_len:
      logits, _ = self(X)
      probs = self.softmax(logits)
      next_token = torch.multinomial(probs[0][-1], 1)
      X = torch.cat((X,next_token.unsqueeze(0)), dim=1)
      # TODO: stop on end token...
    return decoder(X.squeeze(0).tolist())


model = Model().to(device)

In [43]:
# num of trainable parameters in the model:
sum(p.numel() for p in model.parameters())

12695636

In [44]:
model.generate("Hello w")

'Hello wBc\nxKg:bO1ia]Z)xwLOI4qNp7y4!KIbjp:`VEOG[uH6JAl_lx7)VmdES2yW[[}cK&Fr7sdp(!4-Jlg"sJAGrK-&6 S]dz&5J\nm!f6(LhhIy"cQi.Hl4HN;swn'

In [45]:
train_X, train_y = train_data_t[:-1], train_data_t[1:]
test_X, test_y = test_data_t[:-1], test_data_t[1:]

def get_batch(X, y, idx, batch_size=16):
  batch_x, batch_y = [], []
  for i in range(batch_size):
    batch_x.append(X[idx*batch_size+i  :idx*batch_size+i+context_window_size])
    batch_y.append(y[idx*batch_size+i+1:idx*batch_size+i+context_window_size+1])

  return torch.stack(batch_x), torch.stack(batch_y)

In [46]:
from tqdm import tqdm

In [47]:
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

def eval_model():
  model.eval()
  with torch.no_grad():
    test_loss = 0
    for i in tqdm(range(num_test_batches)):
      X, y = get_batch(test_X, test_y, i, batch_size)
      _, loss = model(X, y)
      test_loss += loss.item()
    test_loss /= num_test_batches
  return test_loss

test_loss = eval_model()
print(f"random weight {test_loss=}")

for epoch in range(num_epochs):
  model.train()
  epoch_loss = 0
  for i in tqdm(range(num_batches)):
    X, y = get_batch(train_X, train_y, i, batch_size)
    _, loss = model(X, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epoch_loss += loss.item()
  epoch_loss /= num_batches

  test_loss = eval_model()

  print(f"{epoch=}, {epoch_loss=}; {test_loss=}")

100%|██████████| 318/318 [02:11<00:00,  2.42it/s]


random weight test_loss=4.499420387939837


100%|██████████| 2866/2866 [59:01<00:00,  1.24s/it]
100%|██████████| 318/318 [02:11<00:00,  2.42it/s]


epoch=0, epoch_loss=2.2278645228690728; test_loss=2.3363271835465103


100%|██████████| 2866/2866 [59:00<00:00,  1.24s/it]
100%|██████████| 318/318 [02:11<00:00,  2.43it/s]

epoch=1, epoch_loss=1.9792931257056259; test_loss=2.26364684967125





In [48]:
model.generate("Hello w")

"Hello wse, nywrnes. o\n   Tel etyyto hms rak i rmh srwreteccyb' eel, i a\n   bdttce fulfrtne hrwrogo o ur t a ue\n   Wltrc halt,tys"

In [49]:
model.generate("From fairest creatures")

'From fairest creaturest oe o id lie i: i ecal crptanw\n    n s a xadiigrs n cut fre,Iwt wsa lbya\n   Adageo nte rs n cus rnon o i '

In [50]:
model.generate("That thereby")

'That therebyI adiu a efct al,ai\n   T ufahos,wad.Iw er temie,I,t  wus\n   Adiusadefets hubadnsai,bdflcses o oe\n   Ichknc n nvdaet,'

In [51]:
model.generate("tattered weed")

'tattered weedmn\n   Flte fre nn,wr  rag n Smlscmr os u,yblnnnuh\n   Acin,smie eqelnrnepiget,tndwrkigt n ocae\n   Tlmslyd nwokrs uge'

In [52]:
import time

torch.save(
    {
      "model_state_dict": model.state_dict(),
      "optimizer_state_dict": optimizer.state_dict(),
      "epoch": epoch,
      "train_loss": epoch_loss,
      "test_loss": test_loss
     },
    f"/content/drive/MyDrive/transformer_model_checkpoints/{int(time.time())}_rope_transformer_model.pth")

In [53]:
from google.colab import runtime

# Disconnects and deletes the current runtime
runtime.unassign()