# mini-Transformer (from scratch)
$\textbf{Goal:}$ implement a tiny decoder-only Transformer and train it on character data (no external libs beyond PyTorch). You will learn tokenization (chars), attention, causal masks, training loop, sampling.

In [1]:
import math
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import warnings
warnings.filterwarnings('ignore')

device = torch.device(
    "cuda" if torch.cuda.is_available() else
    "mps" if torch.backends.mps.is_available() else
    "cpu"
)

BLOCK_SIZE = 64
BATCH_SIZE = 64
LR         = 3e-4
SEED       = 42
torch.manual_seed(SEED)

<torch._C.Generator at 0x111a0ceb0>

In [2]:
text = """
Jalāl al-Dīn Muḥammad Rūmī, or simply Rumi, was a 13th-century poet, Hanafi 
faqih, Maturidi theologian, and Sufi mystic born during the Khwarazmian Empire. 
Rumi's works are written in his mother tongue, Persian. He occasionally used the 
Arabic language and single Turkish and Greek words in his verse."""

# convert text to characters
chars = list(set(text))

# string to integer
stoi = {c: i for i, c in enumerate(chars)}

# integer to string
itos = {i: c for c, i in stoi.items()}

# encode
encode = lambda s: torch.tensor([stoi[c] for c in s], dtype=torch.long)

# decode
decode = lambda t: ''.join(itos[int(i)] for i in t)

In [3]:
data = encode(text)
def get_batch(block_size=16, batch_size=32):
    ix = torch.randint(0, len(data) - block_size - 1, (batch_size,))
    x = torch.stack([data[i: i+block_size] for i in ix])             # (batch_size, block_size)
    y = torch.stack([data[i+1: i+1+block_size] for i in ix])         # (batch_size, block_size)

    return x.to(device), y.to(device)

### Model

In [4]:
class Head(nn.Module):
    def __init__(self, n_embed, head_size, block_size, dropout):
        super(Head, self).__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size))) # lower triangle (masking)

    def forward(self, x):
        B, T, C = x.shape
        k, q, v = self.key(x), self.query(x), self.value(x)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
        att = att.masked_fill(self.tril[:T, :T]==0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)
        out = att @ v
        
        return out

In [5]:
class MultiHead(nn.Module):
    def __init__(self, n_embed, n_head, block_size, dropout):
        super(MultiHead, self).__init__()
        head_size = n_embed // n_head
        self.heads = nn.ModuleList([Head(n_embed, head_size, block_size, dropout) for _ in range(n_head)])
        self.proj = nn.Linear(n_embed, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)   # each h: (B, T, head_size) -> concat: (B, T, n_head*head_size)
        proj = self.proj(out)                                 # n_head*head_size = embed_dim -> project them to embed_dim
        return self.dropout(proj)

In [6]:
class FeedForward(nn.Module):
    def __init__(self, n_embed, dropout):
        super(FeedForward, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4*n_embed),
            nn.GELU(),
            nn.Linear(4*n_embed, n_embed),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

In [7]:
class Block(nn.Module):
    def __init__(self, n_embed, n_head, block_size, dropout):
        super(Block, self).__init__()
        self.mh = MultiHead(n_embed, n_head, block_size, dropout)
        self.ff = FeedForward(n_embed, dropout)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        x_p = self.ln1(x)
        x = x + self.mh(x_p)
        x_p = self.ln2(x)
        x = x + self.ff(x_p)
        
        return x

In [8]:
class TinyGPT(nn.Module):
    def __init__(self, vocab_size, block_size, n_embed=128, n_head=4, n_layer=4, dropout=0.1):
        super(TinyGPT, self).__init__()
        self.token_embed = nn.Embedding(vocab_size, n_embed)
        self.pos_embed = nn.Embedding(block_size, n_embed)
        self.blocks = nn.Sequential(
            *[Block(n_embed, n_head, block_size, dropout) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embed)
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        token = self.token_embed(idx)
        pos = self.pos_embed(torch.arange(T, device=idx.device))
        x = token + pos
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
            
        return logits, loss

In [9]:
vocab_size = len(chars)
model = TinyGPT(vocab_size, BLOCK_SIZE).to(device)
optimizer = optim.AdamW(model.parameters(), lr=LR)

### Train

In [10]:
num_epochs = 1000
for epoch in range(num_epochs):
    x_batch, y_batch = get_batch(block_size=BLOCK_SIZE, batch_size=BATCH_SIZE)
    _, loss = model(x_batch, y_batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item(): .4f}")

Epoch 100/1000, Loss:  1.6484
Epoch 200/1000, Loss:  0.4358
Epoch 300/1000, Loss:  0.1481
Epoch 400/1000, Loss:  0.0896
Epoch 500/1000, Loss:  0.0778
Epoch 600/1000, Loss:  0.0650
Epoch 700/1000, Loss:  0.0572
Epoch 800/1000, Loss:  0.0520
Epoch 900/1000, Loss:  0.0503
Epoch 1000/1000, Loss:  0.0469


### Generation

In [11]:
def generate(prompt="Jalāl ", block_size=16, max_new_tokens=400, temperature=0.8, top_k=50):
    idx = encode(prompt).unsqueeze(0).to(device)
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits, _ = model(idx[:, -block_size:])
            logits = logits[:, -1, :] / temperature
            if top_k:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')

            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, next_id], dim=1)
            
    return decode(idx[0].cpu())
    
print(generate(block_size=BLOCK_SIZE))

Jalāl al-Dīn Muḥammad Rūmī, or simply Rumi, was a 13th-century poet, Hanafi 
faqih, Maturidi theologian, and Sufi mystic born during the Khwarazmian Empire. 
Rumi's works are written in his mother tongue, Persian. He occasionally used the 
Arabic language and single Turkish and Greek words in his verserserdsiothersiangue, Persian. He occasionally used the 
Arabic language and single Turkish and Greek wo


## Evaluation

In [12]:
def distinct_n_score(text, n=1):
    tokens = list(text)
    if len(tokens) < n:
        return 0.0
    ngrams = [tuple(tokens[i:i+n]) for i in range(len(tokens) - n + 1)]
    
    return len(set(ngrams)) / len(ngrams)

In [13]:
from collections import Counter

def evaluate_model(model, data, block_size=8, batch_size=8):
    """
    Evaluates TinyGPT on a dataset and computes several metrics:
    - Cross-Entropy Loss
    - Perplexity (PPL)
    - Bits Per Character (BPC)
    - Next-character Accuracy
    - Distinct-1 / Distinct-2 (diversity of generated text)
    """

    model.eval()
    losses = []
    correct = 0
    total = 0

    # Break data into batches for evaluation
    with torch.no_grad():
        # x_batch, y_batch = get_batch(batch_size=batch_size)
        # for x, y in zip(x_batch, y_batch):
        for i in range(0, len(data) - block_size - 1, batch_size):
            x = torch.stack([data[j:j+block_size] for j in range(i, min(i+batch_size, len(data) - block_size - 1))])
            y = torch.stack([data[j+1:j+1+block_size] for j in range(i, min(i+batch_size, len(data) - block_size - 1))])

            x, y = x.to(device), y.to(device)
            logits, loss = model(x, y)
            losses.append(loss.item())

            # accuracy (next-character prediction)
            preds = torch.argmax(logits, dim=-1)
            correct += (preds == y).sum().item()
            total += y.numel()

    # average cross-entropy loss
    avg_loss = sum(losses) / len(losses)
    
    # perplexity
    perplexity = math.exp(avg_loss)
    
    # bits per character (BPC)
    bpc = avg_loss / math.log(2)
    
    # next-character prediction accuracy
    accuracy = correct / total

    # generate text for diversity metrics
    generated_text = generate(prompt="Jalāl ", max_new_tokens=400)
    distinct1 = distinct_n_score(generated_text, 1)
    distinct2 = distinct_n_score(generated_text, 2)

    model.train()

    return {
        "CrossEntropyLoss": avg_loss,
        "Perplexity": perplexity,
        "BitsPerCharacter": bpc,
        "Accuracy": accuracy,
        "Distinct-1": distinct1,
        "Distinct-2": distinct2,
        "GeneratedSample": generated_text
    }

In [14]:
eval_data = encode(text)

metrics = evaluate_model(model, eval_data, block_size=BLOCK_SIZE, batch_size=BATCH_SIZE)

print("\n=== Evaluation Results ===")
print(f"Cross-Entropy Loss: {metrics['CrossEntropyLoss']:.4f}")
print(f"Perplexity (PPL):   {metrics['Perplexity']:.2f}")
print(f"Bits Per Char:      {metrics['BitsPerCharacter']:.3f} bits")
print(f"Accuracy:           {metrics['Accuracy']*100:.2f}%")
print(f"Distinct-1:         {metrics['Distinct-1']:.3f}")
print(f"Distinct-2:         {metrics['Distinct-2']:.3f}")
print("\n=== Generated Sample ===\n")
print(metrics["GeneratedSample"])


=== Evaluation Results ===
Cross-Entropy Loss: 0.0399
Perplexity (PPL):   1.04
Bits Per Char:      0.058 bits
Accuracy:           98.48%
Distinct-1:         0.113
Distinct-2:         0.432

=== Generated Sample ===

Jalāl al-Dīn Muḥammad Rūmī, or simply Rumi, was a 13th-century poet, Hanafi 
faqih, Maturidi theologian, and Sufi mystic born during the Khwarazmian Empire. 
Rumi's works are written in his mother tongue, Persian. He occasionally used the 
Arabic language and single Turksh are waritten in his mother tongue, Persian. He occasionally used the 
Arabic language and Sufi mystic born during the Khwarazmian Em
