In [1]:
from pathlib import Path
from math import floor
from typing import List
from datetime import datetime

import time
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from matplotlib import pyplot as plt

# Simple end-to-end plan

1. **Load data**  
   * Read `enwik9` as bytes, keep ints 0-255.  
   * Split once into train / validation.

2. **Build dataset**  
   * Each sample: `ctx_len+1` consecutive bytes.  
   * Inputs = the first `ctx_len` bytes of the sample; targets = the same window shifted by one byte (bytes 1…`ctx_len`).  
   * Wrap in `DataLoader(batch_size=64, shuffle=True)`.

3. **Set up model**  
   * Use your `CTransformer` (d_model = 256, ctx_len = 256, mask=True).  
   * Move to `device = torch.device("cuda")`.

4. **Configure training**  
   * Loss: `nn.CrossEntropyLoss()`.  
   * Optimiser: `AdamW(lr=3e-4, weight_decay=1e-4)`.  
   * Use mixed precision (`torch.cuda.amp`) and clip grads at 1.0.  
   * Scheduler: warm-up 200 steps → cosine decay.

5. **Training loop**  
   * For each batch: forward → loss → backward → step → scheduler.  
   * Log train loss every 100 steps.

6. **Validate & checkpoint**  
   * Run no-grad pass on validation set each epoch.  
   * Save `state_dict` when val loss improves.

7. **Quick sampling test**  
   * In `eval` mode, feed a prompt, apply softmax to the logits of the last position, sample the next byte, append it, and repeat.  
   * If output starts to look like English, training is on track.


# Modules

In [2]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        k: width of the input/output embeddings; must be divisible by `heads`.
        heads: number of attention heads.
        mask: when true, each position may only attend to itself and to
            earlier positions.
    """

    def __init__(self, k, heads=4, mask=False):
        super().__init__()
        assert k % heads == 0
        self.mask = mask
        self.k, self.heads = k, heads

        # Bias-free projections of the input into key / query / value space.
        self.tokeys = nn.Linear(k, k, bias=False)
        self.toqueries = nn.Linear(k, k, bias=False)
        self.tovalues = nn.Linear(k, k, bias=False)

        # Mixes the concatenated head outputs back into a single k-wide vector.
        self.unifyheads = nn.Linear(k, k)

    @staticmethod
    def _fold_heads(proj, b, t, h, s):
        """Reshape (b, t, h*s) into (b*h, t, s) so each head is its own batch entry."""
        per_head = proj.view(b, t, h, s).transpose(1, 2)
        return per_head.contiguous().view(b * h, t, s)

    def forward(self, x):
        """Attend over `x` of shape (batch, time, k); the result has the same shape."""
        b, t, k = x.size()
        h = self.heads
        s = k // h  # width of a single head

        keys = self._fold_heads(self.tokeys(x), b, t, h, s)
        queries = self._fold_heads(self.toqueries(x), b, t, h, s)
        values = self._fold_heads(self.tovalues(x), b, t, h, s)

        # (b*h, t, s) @ (b*h, s, t) -> (b*h, t, t) raw attention scores
        scores = torch.bmm(queries, keys.transpose(1, 2))

        if self.mask:
            # Entries strictly above the diagonal are "future" positions;
            # -inf there turns into zero weight after the softmax.
            rows, cols = torch.triu_indices(t, t, offset=1, device=scores.device)
            scores[:, rows, cols] = float('-inf')

        weights = F.softmax(scores / (s ** 0.5), dim=2)

        # (b*h, t, t) @ (b*h, t, s) -> (b*h, t, s), then put the heads side by side again
        mixed = torch.bmm(weights, values).view(b, h, t, s)
        mixed = mixed.transpose(1, 2).contiguous().view(b, t, s * h)

        return self.unifyheads(mixed)
        
        

In [3]:
class TransformerBlock(nn.Module):
    """One transformer layer: masked self-attention followed by a ReLU MLP.

    Each sub-layer is wrapped as `LayerNorm(sublayer(x) + x)`, i.e. the
    normalisation is applied after the residual addition.

    Args:
        k: embedding width.
        heads: number of attention heads.
    """

    def __init__(self, k, heads):
        super().__init__()

        # The mask is always on here: the author notes that a full day of
        # training was lost to a run with mask=False.
        self.attention = SelfAttention(k, heads=heads, mask=True)

        self.norm1 = nn.LayerNorm(k)
        self.norm2 = nn.LayerNorm(k)

        # Position-wise MLP that widens to 4*k and projects back to k.
        hidden = 4 * k
        self.ff = nn.Sequential(
            nn.Linear(k, hidden),
            nn.ReLU(),
            nn.Linear(hidden, k),
        )

    def forward(self, x):
        """Apply attention and the MLP, each with a residual connection and LayerNorm."""
        x = self.norm1(self.attention(x) + x)
        return self.norm2(self.ff(x) + x)
        
        

In [4]:
class CTransformer(nn.Module):
    """Causal transformer language model over integer token IDs.

    Args:
        vocab_size: number of distinct token IDs (and of output classes).
        ctx_len: longest sequence the model accepts; this is the number of
            learned positional embeddings.
        d_model: embedding width used throughout the network.
        n_layers: number of stacked `TransformerBlock`s.
        n_heads: attention heads per block.
    """

    def __init__(self, vocab_size=256, ctx_len=256, d_model=256, n_layers=12, n_heads=8):
        super().__init__()

        # token & position embeddings
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Parameter(torch.randn(ctx_len, d_model))

        # stack of transformer blocks
        self.blocks = nn.ModuleList(
            [TransformerBlock(d_model, n_heads) for _ in range(n_layers)]
        )
        self.ln_f = nn.LayerNorm(d_model)

        # Output projection shares its weight matrix with the token embedding
        # (both are vocab_size x d_model), so only one copy is learned.
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        self.head.weight = self.tok_emb.weight

    def forward(self, idx):
        '''
        idx: (batch, seq_len) tensor of character IDs, with seq_len <= ctx_len
        returns logits: (batch, seq_len, vocab_size)

        Raises ValueError if seq_len is larger than the positional table.
        '''
        _, t = idx.shape
        max_t = self.pos_emb.shape[0]
        if t > max_t:
            # Without this check the addition below fails with an opaque
            # broadcasting error, because pos_emb[:t] silently stops at max_t rows.
            raise ValueError(
                f"sequence length {t} exceeds the model's ctx_len of {max_t}"
            )

        x = self.tok_emb(idx) + self.pos_emb[:t]
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        return self.head(x)
        

# Load Data

In [5]:
class ByteDataset(torch.utils.data.Dataset):
    """Random fixed-length windows over a 1-D byte array for next-byte prediction."""

    def __init__(self, data: np.ndarray, ctx_len: int):
        '''
        data: 1-D uint8 array of bytes (int 0-255)
        ctx_len: sequence length
        Each item returns (input_idxs, target_idxs), both of length ctx_len
        '''
        # Stored with the array's own dtype; windows are cast to int64 one at a
        # time in __getitem__ rather than converting the whole array up front.
        self.data = torch.from_numpy(data)
        self.ctx_len = ctx_len

    def __len__(self):
        '''
        Tells torch.DataLoader how many samples there are.

        This is the number of start offsets i for which data[i : i + ctx_len + 1]
        is a full window (0 when the data is too short to hold one).
        '''
        return max(0, len(self.data) - self.ctx_len)

    def __getitem__(self, _):
        '''
        Return a uniformly random (inputs, targets) window; the index is ignored.

        Raises IndexError if the data is too short to contain a single window.
        '''
        # Shuffling happens here instead of via DataLoader(..., shuffle=True),
        # which the original author found broke the notebook (suspected memory issue).
        # NOTE(review): the author also warned that this may not work with
        # multiple workers -- verify before raising num_workers.
        n_windows = len(self)
        if n_windows == 0:
            raise IndexError(
                f"need at least {self.ctx_len + 1} bytes to form a sample, "
                f"got {len(self.data)}"
            )

        # randint's upper bound is exclusive, so passing n_windows covers every
        # valid start offset, including the last one.
        i = torch.randint(0, n_windows, (1,)).item()

        chunk = self.data[i: i + self.ctx_len + 1]
        return chunk[:-1].long(), chunk[1:].long()



In [6]:
# --- Data source and split ---
data_path = '/mnt/c/ad_astra/data/enwik9/enwik9'
val_frac = 0.01       # fraction of the bytes held out for validation

# --- Model / batch geometry ---
vocab_size = 256      # one class per possible byte value
ctx_len = 256         # bytes of context per training sample
batch_size = 192

# --- DataLoader options ---
shuffle = False       # ByteDataset picks a random window itself
pin_memory = True
num_workers = 0

In [7]:
# Memory-map the corpus read-only instead of reading the whole file into memory
# (the earlier approach was: raw = Path(data_path).read_bytes()).
raw = np.memmap(data_path, dtype=np.uint8, mode='r')

# Expose the mapped bytes as a 1-D uint8 ndarray; `data` is what gets split below.
data = np.frombuffer(raw, dtype=np.uint8)

In [8]:
# The last `val_frac` of the bytes is held out for validation; the rest is training data.
split = len(data) - floor(len(data) * val_frac)
train_data, val_data = data[:split], data[split:]
train_ds, val_ds = ByteDataset(train_data, ctx_len), ByteDataset(val_data, ctx_len)


class _EndlessLoader:
    """Iterator over a DataLoader that starts a new pass whenever one runs out.

    A bare `iter(DataLoader(...))` can be consumed only once, while the training
    code keeps calling `next()` on the same object cycle after cycle, so it
    would eventually raise StopIteration. `len()` is forwarded to the wrapped
    loader (number of batches in one pass).
    """

    def __init__(self, loader):
        self._loader = loader
        self._batches = iter(loader)

    def __iter__(self):
        return self

    def __len__(self):
        return len(self._loader)

    def __next__(self):
        try:
            return next(self._batches)
        except StopIteration:
            # Current pass is finished: begin another. An empty loader still
            # raises StopIteration from the call below.
            self._batches = iter(self._loader)
            return next(self._batches)


def _make_loader(dataset):
    """Build a DataLoader for `dataset` from the notebook-level settings."""
    # Jimmy Notes
    # pin_memory: page locks data so DMA can stream directly to GPU
    # drop_last: drops final batch if dataset size isn't divisible by batch_size
    #            so every batch is the same size
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        pin_memory=pin_memory,
        num_workers=num_workers,
        drop_last=True
    )


train_loader = _EndlessLoader(_make_loader(train_ds))
val_loader = _EndlessLoader(_make_loader(val_ds))

  self.data = torch.from_numpy(data) #data.astype(np.int64))


# Training Code

In [9]:
def run_epoch(model, loader, max_steps, train=True):
    """Run `max_steps` batches through `model`, optimising when `train` is true.

    Args:
        model: network called as `model(xb)`; its output is treated as logits
            whose last dimension indexes the classes.
        loader: source of `(xb, yb)` batches. Either a live iterator (anything
            with `__next__`, e.g. `iter(DataLoader(...))`) or a re-iterable such
            as a `DataLoader`, which is restarted whenever a pass runs out.
        max_steps: number of batches to process.
        train: if true, back-propagate and step the optimiser and scheduler;
            otherwise only evaluate, with autograd disabled.

    Returns:
        A list with one entry per step: the running mean of the loss up to and
        including that step.

    Raises:
        StopIteration: if `loader` is a bare iterator and it is exhausted.

    Relies on the notebook globals `device`, `autocast`, `criterion`, `scaler`,
    `optimizer` and `scheduler`.
    """
    # A bare iterator cannot be rewound; anything else is re-iterated on demand.
    restartable = not hasattr(loader, '__next__')
    batches = iter(loader) if restartable else loader

    total, steps = 0.0, 0
    running_means = []
    model.train(train)

    # Used as a context manager so the previous autograd mode is restored on
    # exit (including on exceptions); a plain call would leave gradients
    # globally disabled after a validation pass.
    with torch.set_grad_enabled(train):
        while steps < max_steps:
            # NOTE(review): there is no explicit CUDA synchronisation, so the
            # wall-clock figures below may not reflect actual GPU time.
            t0 = time.perf_counter()
            try:
                xb, yb = next(batches)
            except StopIteration:
                if not restartable:
                    raise
                batches = iter(loader)
                xb, yb = next(batches)
            xb, yb = xb.to(device), yb.to(device)
            data_ms = (time.perf_counter() - t0) * 1e3

            assert xb.is_cuda and next(model.parameters()).is_cuda, "data or model not on GPU!"
            fwd_start = time.perf_counter()
            with autocast():
                # model(xb) rather than model.forward(xb) so module hooks run.
                # Returns shape (batch, ctx_len, vocab_size)
                logits = model(xb)

                # N = batch_size * ctx_len
                # logits.view: (N, vocab_size)
                # yb.view: (N, )
                # CrossEntropyLoss interprets yb as the index of the correct class.
                loss = criterion(logits.view(-1, logits.size(-1)), yb.view(-1))
            fwd_ms = (time.perf_counter() - fwd_start) * 1e3

            bwd_start = time.perf_counter()
            if train:
                # Scale the loss (to avoid underflow), then back-propagate.
                scaler.scale(loss).backward()
                # Undo the scaling on the gradients so clipping sees true magnitudes.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                # Update the weights, then let the scaler adjust its scale factor.
                scaler.step(optimizer)
                scaler.update()
                # Clear the gradients of every registered parameter.
                optimizer.zero_grad(set_to_none=True)
                # Advance the learning-rate schedule by one step.
                scheduler.step()
            bwd_ms = (time.perf_counter() - bwd_start) * 1e3

            total += loss.item()
            steps += 1
            running_means.append(total / steps)
            if train and steps % 1_000 == 0:
                print(
                    f"step {steps:>8,}  |  "
                    f"data {data_ms:6.1f} ms  "
                    f"fwd {fwd_ms:6.1f} ms  "
                    f"bwd+opt {bwd_ms:6.1f} ms  "
                    f"train‑loss {total/steps:.4f}"
                )
    return running_means

def sample(model, 
           prompt: List[int], 
           n_tokens=256, 
           temperature=1.0
          ):
    """Autoregressively generate `n_tokens` bytes after `prompt`.

    Args:
        model: module mapping a (1, t) LongTensor of IDs to (1, t, vocab) logits.
            If it has a `pos_emb` attribute, its first dimension is taken as the
            longest sequence the model accepts.
        prompt: token IDs (ints 0-255) to condition on.
        n_tokens: how many new tokens to draw.
        temperature: logits are divided by this before the softmax.

    Returns:
        `bytes` holding the prompt followed by the generated tokens.

    Side effect: leaves `model` in eval mode.
    """
    model.eval()

    # Build the input on whatever device the model lives on rather than relying
    # on a notebook-level `device` global.
    first_param = next(model.parameters(), None)
    dev = first_param.device if first_param is not None else torch.device('cpu')

    # Once the running text is longer than the positional table, only the most
    # recent `max_ctx` tokens are fed to the model; feeding more would fail.
    pos_emb = getattr(model, 'pos_emb', None)
    max_ctx = pos_emb.shape[0] if pos_emb is not None else None

    # [None] adds the leading batch dimension (batch_size = 1).
    idx = torch.tensor(prompt, dtype=torch.long, device=dev)[None]

    # No intermediate results are needed because nothing is back-propagated.
    with torch.no_grad():
        for _ in range(n_tokens):
            window = idx if max_ctx is None else idx[:, -max_ctx:]
            # Logits for the last position, with temperature applied.
            logits = model(window)[:, -1] / temperature
            # Draw one token from the resulting categorical distribution.
            next_id = torch.multinomial(F.softmax(logits, dim=-1), 1)
            idx = torch.cat([idx, next_id], dim=1)

    # idx[0] is always a 1-D sequence. squeeze() would collapse a single token
    # to a scalar, and bytes(int) would then return that many zero bytes.
    return bytes(idx[0].tolist())
    
    

# Train

In [10]:
from torch.cuda.amp import GradScaler, autocast

device = 'cuda'


def model_gen():
    """Build a freshly initialised CTransformer with the notebook's sizes and move it to `device`."""
    net = CTransformer(
        vocab_size=vocab_size,
        ctx_len=ctx_len,
        d_model=256,
        n_layers=12,
        n_heads=8,
    )
    return net.to(device)


In [11]:
# Instantiate model, or load checkpoint
model = model_gen()
train_loss_hist = []
val_loss_hist = []
tokens_seen = 0

# loading block
load_path = 'models/ctransformer_enwik9_2025-07-27_232950.pt'
checkpoint = torch.load(load_path, map_location=device)
model.load_state_dict(checkpoint['weights'])
train_loss_hist = checkpoint['train_loss_hist']
val_loss_hist = checkpoint['val_loss_hist']
tokens_seen = checkpoint['tokens_seen']

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30_000, eta_min=1e-5)
scaler = GradScaler()

  scaler = GradScaler()


In [12]:
# Number of recorded training-loss entries (run_epoch appends one running mean per step).
len(train_loss_hist)

150000

In [13]:
# Token counter stored in the loaded checkpoint (the training loop adds batch_size * ctx_len per step).
checkpoint['tokens_seen']

7372800000

In [21]:
# exp() of each validation pass's mean cross-entropy loss, i.e. per-byte perplexity after each cycle.
np.exp(val_loss_hist)

array([2.67481427, 2.55015864, 2.50768377, 2.52870738, 2.55466045,
       2.5462475 , 2.49911525, 2.42671841, 2.39561883, 2.41323156,
       2.46038368, 2.47579492, 2.44270166, 2.38764034, 2.35806139])

In [None]:
# main training loop

# Start from the best validation loss already on record, so a resumed run only
# checkpoints when it actually beats the restored history.
best_val = min(val_loss_hist, default=float('inf'))

train_steps_per_cycle = 10_000
val_steps_per_cycle = 1_000
# `tokens_seen` is deliberately NOT reset here: it is initialised in the
# model-setup cell (0, or the value restored from the checkpoint).


def _save_checkpoint(prefix):
    """Write weights, loss histories and the token counter to models/<prefix>_<timestamp>.pt and return the path."""
    Path('models').mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime("%Y-%m-%d_%H%M%S")
    save_path = f'models/{prefix}_{timestamp}.pt'
    torch.save({
        'train_loss_hist': train_loss_hist,
        'val_loss_hist': val_loss_hist, 
        'tokens_seen': tokens_seen,
        'weights': model.state_dict()
    }, save_path)
    return save_path


def _dump_loss_histories():
    """Write both loss histories to text files in the working directory, one value per line."""
    with open('val_loss_hist.txt', "w", encoding="utf-8") as f:
        f.write("\n".join(map(str, val_loss_hist)))
    with open('train_loss_hist.txt', "w", encoding="utf-8") as f:
        f.write("\n".join(map(str, train_loss_hist)))


try:
    for cycle in range(5_000):
        print(f'Cycle {cycle:>6}')
        train_loss_hist += run_epoch(model, train_loader, train_steps_per_cycle, train=True)
        train_loss = train_loss_hist[-1]
        tokens_seen += batch_size * ctx_len * train_steps_per_cycle

        val_loss = run_epoch(model, val_loader, val_steps_per_cycle, train=False)[-1]
        val_loss_hist.append(val_loss)
        print(f'cycle {cycle:03d} train {train_loss:.4f} val {val_loss:.4f}')

        if val_loss < best_val:
            best_val = val_loss

            save_path = _save_checkpoint('ctransformer_enwik9')
            print(f'checkpoint saved: {save_path}')
            print(f'{val_loss_hist=}')
            # Quick qualitative check: non-ASCII bytes are dropped for display.
            print(sample(model, list(b"The "), 200).decode("ascii", "ignore"))

except KeyboardInterrupt:
    # Manual stop: keep the current weights and the loss curves, then re-raise
    # so Jupyter stops execution.
    save_path = _save_checkpoint('ctransformer_enwik9_interrupt')
    _dump_loss_histories()
    print(f'\nInterrupted — checkpoint saved to {save_path}')
    raise

except Exception:
    # Any other failure: keep the loss curves, then let the error propagate.
    _dump_loss_histories()
    raise

# Sample

In [19]:
# Generate 200 bytes continuing the prompt; bytes that are not ASCII are dropped when decoding for display.
print(sample(model, list(b"OpenAI is"), 200).decode("ascii", "ignore"))

OpenAI is a &quot;distorted biography&quot; dispute that includes [[SoundExperience]] and [[Ashburner's Institute]] obsolete.

A 2001 soundtrack, the star in [[London]] appears on-dip-script, briefly turning S


In [18]:
# Same check with a different prompt: 200 generated bytes, non-ASCII bytes dropped for display.
print(sample(model, list(b"Traditional film school "), 200).decode("ascii", "ignore"))

Traditional film school founded by Mabbuja Steel in April [[2002]]. The large motets Jacksonville had it to be  read in [[Lonely River (Manila, California)|Lonely River]]. The Leadership motets Jacksonvilli as dedicated to t
