In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a causal mask.

    Queries, keys and values are produced by three bias-free linear maps,
    split into ``n_head`` heads of ``head_dim = n_embd // n_head`` features,
    attended with ``F.scaled_dot_product_attention(..., is_causal=True)`` and
    mixed by a final bias-free output projection.

    Args:
        config: object exposing ``n_head``, ``n_kv_head`` and ``n_embd``.
        layer_idx: index of the enclosing layer; only stored on the module.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.n_head = config.n_head
        # Recorded from the config only: the key/value projections below are
        # sized with n_head, so this value does not influence the computation.
        self.n_kv_head = config.n_kv_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0

        proj_width = self.n_head * self.head_dim
        self.c_q = nn.Linear(self.n_embd, proj_width, bias=False)
        self.c_k = nn.Linear(self.n_embd, proj_width, bias=False)
        self.c_v = nn.Linear(self.n_embd, proj_width, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

    def forward(self, x):
        """Apply causal attention to ``x`` of shape (B, T, C); returns (B, T, C)."""
        batch, steps, _ = x.size()
        per_head = (batch, steps, self.n_head, self.head_dim)

        # Project, split the channel axis into heads, then move heads in front
        # of time: (B, T, C) -> (B, T, H, D) -> (B, H, T, D).
        q = self.c_q(x).view(per_head).transpose(1, 2)
        k = self.c_k(x).view(per_head).transpose(1, 2)
        v = self.c_v(x).view(per_head).transpose(1, 2)

        attended = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        # Undo the head split: (B, H, T, D) -> (B, T, H, D) -> (B, T, H * D).
        merged = attended.transpose(1, 2).contiguous().view(batch, steps, -1)
        return self.c_proj(merged)

In [3]:
class MLP(nn.Module):
    """Feed-forward sub-layer: expand to 4x width, squared ReLU, project back.

    Both linear maps are bias-free.

    Args:
        config: object exposing ``n_embd`` (the model width).
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        self.c_fc = nn.Linear(width, hidden, bias=False)
        self.c_proj = nn.Linear(hidden, width, bias=False)

    def forward(self, x):
        """Return ``c_proj(relu(c_fc(x)) ** 2)``; the last dimension is preserved."""
        hidden = self.c_fc(x)
        activated = F.relu(hidden).square()
        return self.c_proj(activated)


In [4]:
def norm(x):
    return F.rms_norm(x, (x.size(-1),))

In [5]:
class Block(nn.Module):
    """One pre-norm transformer layer: attention then MLP, each in a residual branch.

    Args:
        config: forwarded to ``CausalSelfAttention`` and ``MLP``.
        layer_idx: forwarded to ``CausalSelfAttention``.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(self, x):
        """Return ``x`` after adding the attention branch and then the MLP branch."""
        # Each branch reads a normalised copy of the stream; the residual
        # stream itself is only ever added to.
        attn_out = self.attn(norm(x))
        x = x + attn_out
        mlp_out = self.mlp(norm(x))
        return x + mlp_out

In [6]:
class GPT(nn.Module):
    """Decoder-only transformer language model.

    The model sums a token embedding (``wte``) and a learned position
    embedding (``pte``), normalises the sum, runs it through ``n_layer``
    ``Block`` modules, normalises again and maps to vocabulary logits with a
    bias-free linear head.

    Args:
        config: object exposing ``vocab_size``, ``sequence_len``, ``n_embd``,
            ``n_layer`` and the attributes required by ``Block``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(config.vocab_size, config.n_embd),
            "pte": nn.Embedding(config.sequence_len, config.n_embd),
            "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)])
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

    def forward(self, idx, targets=None, loss_reduction='mean'):
        """Compute logits, and the loss when ``targets`` is supplied.

        Args:
            idx: integer token ids of shape (B, T), with T <= ``config.sequence_len``.
            targets: optional token ids of shape (B, T); positions equal to -1
                are ignored by the loss.
            loss_reduction: forwarded to ``F.cross_entropy`` as ``reduction``.

        Returns:
            ``logits`` of shape (B, T, vocab_size) when ``targets`` is None,
            otherwise the tuple ``(loss, logits)`` with logits cast to float32.

        Raises:
            ValueError: if T exceeds ``config.sequence_len``.
        """
        B, T = idx.size()
        if T > self.config.sequence_len:
            # The position table only has sequence_len rows; fail with a clear
            # message instead of an out-of-range embedding lookup.
            raise ValueError(
                f"sequence length {T} exceeds the model's maximum of {self.config.sequence_len}"
            )

        positions = torch.arange(T, device=idx.device)
        x = self.transformer.wte(idx) + self.transformer.pte(positions)
        x = norm(x)
        for block in self.transformer.h:
            x = block(x)
        x = norm(x)

        logits = self.lm_head(x)
        if targets is None:
            return logits

        # Cast only on the loss path, matching what callers already receive.
        logits = logits.float()
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            ignore_index=-1,
            reduction=loss_reduction,
        )
        return loss, logits

    @torch.no_grad()  # sampling never needs gradients; avoids building a graph per step
    def generate(self, idx, seq_len):
        """Sample ``seq_len`` new tokens after the prompt ``idx``.

        Args:
            idx: integer token ids of shape (B, T) used as the prompt.
            seq_len: number of tokens to append.

        Returns:
            Tensor of shape (B, T + seq_len) containing the prompt followed by
            the sampled tokens.
        """
        max_context = self.config.sequence_len
        for _ in range(seq_len):
            # Crop context if it exceeds sequence_len
            idx_crop = idx if idx.size(1) <= max_context else idx[:, -max_context:]
            logits = self(idx_crop)
            # Only the final position predicts the next token.
            probs = F.softmax(logits[:, -1, :], dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

In [7]:
# Data loading from shakespeare dataset

# We always start with a dataset to train on. Let's download the tiny shakespeare dataset
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2025-10-31 01:51:24--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.1’


2025-10-31 01:51:25 (29.7 MB/s) - ‘input.txt.1’ saved [1115394/1115394]

200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.1’


2025-10-31 01:51:25 (29.7 MB/s) - ‘input.txt.1’ saved [1115394/1115394]



In [8]:
# Load the entire corpus into memory as a single string
with open('input.txt', mode='r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

In [9]:
# Character-level encoding

In [10]:
# Number of distinct characters in the corpus, i.e. the character-level vocabulary size
vocab_size = len(set(text))

# Index at 90% of the corpus length; used below to slice the encoded data into train and validation parts
train_size = int(0.9 * len(text))

In [11]:
# Character <-> id lookup tables. The characters are sorted before numbering so the
# mapping is identical on every run; iterating a bare set of strings gives an order
# that can change between interpreter sessions.
chars = sorted(set(text))
encode = {ch: i for i, ch in enumerate(chars)}
decode = {i: ch for i, ch in enumerate(chars)}


def encoder(text):
    """Map a string to the list of its character ids."""
    return [encode[x] for x in text]


def decoder(ids):
    """Map an iterable of character ids back to the string they spell."""
    return ''.join(decode[i] for i in ids)


data = encoder(text)
train_data = data[:train_size]
# Validation starts exactly where training ends, so no token is dropped between the splits.
val_data = data[train_size:]

In [12]:
# Byte Pair Encoding (BPE) Implementation

In [13]:
# Convert text to bytes for BPE: the tokenizer starts from the raw UTF-8 byte values
text_bytes = text.encode("utf-8")
# Working token sequence as a list of ints; the training cell rebinds `ids` after every merge
ids = list(text_bytes)

from collections import defaultdict

def get_stats(ids):
    """Count how often each pair of neighbouring elements occurs in ``ids``.

    Returns a plain dict mapping ``(left, right)`` tuples to their number of
    occurrences, with keys in order of first appearance.
    """
    pair_counts = {}
    for pair in zip(ids, ids[1:]):
        if pair in pair_counts:
            pair_counts[pair] += 1
        else:
            pair_counts[pair] = 1
    return pair_counts

def merge(ids, pair, idx):
    """Return a new list in which each occurrence of ``pair`` is replaced by ``idx``.

    The scan runs left to right and matches do not overlap: once a pair is
    replaced, scanning resumes after its second element. ``ids`` itself is
    not modified.
    """
    merged = []
    last = len(ids) - 1
    pos = 0
    while pos <= last:
        is_match = pos < last and ids[pos] == pair[0] and ids[pos + 1] == pair[1]
        if is_match:
            merged.append(idx)
            pos += 2
            continue
        merged.append(ids[pos])
        pos += 1
    return merged

In [14]:
# Train BPE tokenizer
vocab_size_bpe = 276  # 256 bytes + 20 merges
num_merges = vocab_size_bpe - 256
merges = {}

# Always restart from the raw bytes so this cell can be re-run safely. Otherwise a
# second run would learn merges on top of the already merged sequence left in `ids`.
ids = list(text_bytes)

for i in range(num_merges):
    stats = get_stats(ids)
    if not stats:
        # Fewer than two tokens left: nothing more to merge.
        break
    # Find the most frequent pair
    pair = max(stats, key=stats.get)
    idx = 256 + i  # new token index
    print(f"merging {pair} into a new token {idx}")
    ids = merge(ids, pair, idx)
    merges[pair] = idx

print(f"Compression ratio: {len(text_bytes) / len(ids):.2f}")
# Report the vocabulary actually learned, which is smaller than the target if the loop stopped early.
print(f"Final vocab size: {256 + len(merges)}")
print(f"Number of merges: {len(merges)}")

merging (101, 32) into a new token 256
merging (116, 104) into a new token 257
merging (116, 104) into a new token 257
merging (116, 32) into a new token 258
merging (116, 32) into a new token 258
merging (115, 32) into a new token 259
merging (115, 32) into a new token 259
merging (100, 32) into a new token 260
merging (100, 32) into a new token 260
merging (44, 32) into a new token 261
merging (44, 32) into a new token 261
merging (111, 117) into a new token 262
merging (111, 117) into a new token 262
merging (101, 114) into a new token 263
merging (101, 114) into a new token 263
merging (105, 110) into a new token 264
merging (105, 110) into a new token 264
merging (121, 32) into a new token 265
merging (121, 32) into a new token 265
merging (97, 110) into a new token 266
merging (97, 110) into a new token 266
merging (58, 10) into a new token 267
merging (58, 10) into a new token 267
merging (111, 114) into a new token 268
merging (111, 114) into a new token 268
merging (111, 32) i

In [15]:
# Token id -> byte string table used for decoding.
# Ids 0-255 are the single bytes; each merged id is the concatenation of its two parts.
# `merges` is walked in insertion order, so both parts already exist when a merged id is built.
vocab = {byte_value: bytes([byte_value]) for byte_value in range(256)}
for merged_pair, merged_id in merges.items():
    left, right = merged_pair
    vocab[merged_id] = vocab[left] + vocab[right]

def decode_bpe(ids):
    """Turn a sequence of BPE token ids back into text.

    Each id is looked up in ``vocab``, the byte strings are concatenated, and
    the result is decoded as UTF-8 with invalid sequences replaced rather
    than raising.
    """
    raw_bytes = b"".join([vocab[token_id] for token_id in ids])
    return raw_bytes.decode("utf-8", errors="replace")

def encode_bpe(text):
    """Tokenize ``text`` into BPE ids using the learned ``merges``.

    Starting from the UTF-8 bytes, the pair with the lowest merge id that is
    present in the sequence is merged, and this repeats until no present
    pair has a learned merge or fewer than two tokens remain.
    """
    def merge_rank(candidate):
        # Pairs that were never learned rank last.
        return merges.get(candidate, float("inf"))

    tokens = list(text.encode("utf-8"))
    while len(tokens) >= 2:
        best_pair = min(get_stats(tokens), key=merge_rank)
        if best_pair not in merges:
            break
        tokens = merge(tokens, best_pair, merges[best_pair])
    return tokens

# Sanity check: a short sample must survive an encode -> decode round trip
test_text = "hello world! this is a test."
encoded = encode_bpe(test_text)
decoded = decode_bpe(encoded)
print(
    f"Original: {test_text}",
    f"Encoded: {encoded}",
    f"Decoded: {decoded}",
    f"Round-trip successful: {test_text == decoded}",
    sep="\n",
)

# Tokenize the whole corpus and split it 90% / 10% into train and validation
bpe_data = encode_bpe(text)
train_size_bpe = int(0.9 * len(bpe_data))
train_data_bpe, val_data_bpe = bpe_data[:train_size_bpe], bpe_data[train_size_bpe:]

print(
    f"Original text length: {len(text)}",
    f"BPE encoded length: {len(bpe_data)}",
    f"Compression ratio: {len(text) / len(bpe_data):.2f}",
    sep="\n",
)

Original: hello world! this is a test.
Encoded: [104, 101, 275, 269, 119, 268, 108, 100, 33, 273, 105, 259, 105, 259, 97, 32, 116, 101, 115, 116, 46]
Decoded: hello world! this is a test.
Round-trip successful: True
Original text length: 1115394
BPE encoded length: 882737
Compression ratio: 1.26
Original text length: 1115394
BPE encoded length: 882737
Compression ratio: 1.26


In [16]:
# Number of tokens in each training sequence; also the default for GPTConfig.sequence_len
block_size = 256
# Number of sequences sampled per batch by get_data
batch_size = 64

import torch
# Seed torch's global RNG, which get_data's torch.randint and the model's sampling draw from
torch.manual_seed(1443)

def get_data(split='train', device='cuda', use_bpe=True):
    """Sample a random batch of (input, target) token windows.

    Args:
        split: ``'train'`` selects the training data; any other value selects
            the validation data.
        device: device the batch is moved to (anything ``Tensor.to`` accepts,
            e.g. ``'cuda'``, ``'cpu'`` or a ``torch.device``). ``None`` leaves
            the batch on the CPU.
        use_bpe: use the BPE-encoded data when True, otherwise the
            character-level data.

    Returns:
        Tuple ``(x, y)`` of int64 tensors of shape (batch_size, block_size),
        where ``y`` is ``x`` shifted one token to the right.

    Raises:
        ValueError: if the selected data has no complete window of
            ``block_size + 1`` tokens.
    """
    if use_bpe:
        data = train_data_bpe if split == 'train' else val_data_bpe
    else:
        data = train_data if split == 'train' else val_data

    n_starts = len(data) - block_size
    if n_starts <= 0:
        raise ValueError(
            f"need more than {block_size} tokens to sample a batch, got {len(data)}"
        )

    # One randint call per batch, exactly as many start offsets as sequences.
    starts = torch.randint(n_starts, (batch_size,)).tolist()
    x = torch.tensor([data[s:s + block_size] for s in starts])
    y = torch.tensor([data[s + 1:s + block_size + 1] for s in starts])

    if device is None:
        return x, y
    # Honour whatever device was requested rather than special-casing the string 'cuda'.
    return x.to(device), y.to(device)

# Draw one BPE training batch (on get_data's default device)
train_x, train_y = get_data(split='train', use_bpe=True)

In [17]:
from dataclasses import dataclass

@dataclass
class GPTConfig:
    """Hyperparameters read by GPT and its sub-modules."""

    # Rows in the position-embedding table; GPT.generate crops its context to this length
    sequence_len: int = block_size
    # Rows in the token-embedding table and width of the output logits
    vocab_size: int = vocab_size_bpe  # Use BPE vocab size
    # Number of Block modules stacked in the model
    n_layer: int = 12
    # Number of attention heads; must divide n_embd
    n_head: int = 6
    # Stored by CausalSelfAttention; the visible attention code sizes keys/values with n_head instead
    n_kv_head: int = 6
    # Width of the embeddings and of the residual stream
    n_embd: int = 768

In [18]:

# Pass a config *instance*. Handing over the GPTConfig class itself only worked because
# every field has a class-level default, and would break as soon as a field is overridden.
model = GPT(GPTConfig())
model = model.cuda()
optim = torch.optim.AdamW(model.parameters(), lr=1e-4)
# Cosine schedule whose first restart comes after T_0=20 scheduler steps
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optim, T_0=20)

In [19]:
n_batches = len(bpe_data)//batch_size  # not used by the loop below, which runs a fixed number of steps
n_epochs = 10
steps_per_epoch = 100  # optimisation steps per "epoch"; batches are sampled at random
seq_len = 100  # tokens to sample after each epoch

losses = {}  # epoch -> list of per-step loss values

for epoch in range(n_epochs):
    for batch in range(steps_per_epoch):
        optim.zero_grad(set_to_none=True)
        train_x, train_y = get_data('train', device='cuda', use_bpe=True)
        # Call the module rather than .forward() so nn.Module hooks are honoured.
        loss, logits = model(train_x, train_y)
        loss.backward()
        optim.step()
        scheduler.step()
        # Read the loss off the device once and reuse the float for logging and history.
        loss_value = loss.item()
        print(f'Epoch: {epoch}, Batch: {batch}, Loss: {loss_value}')
        losses.setdefault(epoch, []).append(loss_value)

    # Generate text using BPE decoder; no gradients are needed while sampling
    with torch.no_grad():
        generated_ids = model.generate(idx=torch.zeros((1, 1), dtype=torch.long).cuda(), seq_len=seq_len)[0].tolist()
    generated_text = decode_bpe(generated_ids)
    print(f"Generated text: {generated_text[:200]}...")  # Show first 200 chars

Epoch: 0, Batch: 0, Loss: 5.8414177894592285
Epoch: 0, Batch: 1, Loss: 4.737900733947754
Epoch: 0, Batch: 1, Loss: 4.737900733947754
Epoch: 0, Batch: 2, Loss: 4.426560401916504
Epoch: 0, Batch: 2, Loss: 4.426560401916504
Epoch: 0, Batch: 3, Loss: 4.261688232421875
Epoch: 0, Batch: 3, Loss: 4.261688232421875
Epoch: 0, Batch: 4, Loss: 4.116363048553467
Epoch: 0, Batch: 4, Loss: 4.116363048553467
Epoch: 0, Batch: 5, Loss: 4.066916465759277
Epoch: 0, Batch: 5, Loss: 4.066916465759277
Epoch: 0, Batch: 6, Loss: 4.019004821777344
Epoch: 0, Batch: 6, Loss: 4.019004821777344
Epoch: 0, Batch: 7, Loss: 3.9871513843536377
Epoch: 0, Batch: 7, Loss: 3.9871513843536377
Epoch: 0, Batch: 8, Loss: 3.9429831504821777
Epoch: 0, Batch: 8, Loss: 3.9429831504821777
Epoch: 0, Batch: 9, Loss: 3.856572151184082
Epoch: 0, Batch: 9, Loss: 3.856572151184082
Epoch: 0, Batch: 10, Loss: 3.8263816833496094
Epoch: 0, Batch: 10, Loss: 3.8263816833496094
Epoch: 0, Batch: 11, Loss: 3.805037021636963
Epoch: 0, Batch: 11, L

In [20]:
# Sample a longer continuation: 500 new tokens after a prompt holding the single token id 0
generated_ids = model.generate(
    idx=torch.zeros((1, 1), dtype=torch.long).cuda(),
    seq_len=500,
)
generated_ids = generated_ids[0].tolist()
generated_text = decode_bpe(generated_ids)
print(generated_text)

 OPEPSONERLE:
Ay, what she we will the would Warwick;
Maday, capter'd to shall from be aunded how,
Fail drun no black to addvicts,
And now one born Catiol dispation,
Would she Is done thy like agbod
Mearten opinacly gapers.

DUKE VINCENTIO:
Call with is breasing to usurp, inderise.

JULIET:
You shall I long.

DUKE VINCENTIO:
A riclander; as him good, and the selff;
Were must one ve: proced is triquest.
He must to be dearth me dream traibleer
His came to admend, and for what with dok-morrow
Do the love and bear-destravam stabs.

DUKE VINCENTIO:
On, my lord, my lord from now; brother, in the soul any mole
Yet would 
