# mini-Transformer (from scratch)
$\textbf{Goal:}$ implement a tiny encoder-only Transformer and train it on character data (no external libs beyond PyTorch). You will learn tokenization (chars), attention, causal masks, training loop, sampling.

In [1]:
import math
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import warnings
warnings.filterwarnings('ignore')

DEVICE = torch.device(
    "cuda" if torch.cuda.is_available() else
    "mps" if torch.backends.mps.is_available() else
    "cpu"
)

BLOCK_SIZE = 8
BATCH_SIZE = 32
LR         = 3e-4
MLM_PROB   = 0.1
SEED       = 42
torch.manual_seed(SEED)

<torch._C.Generator at 0x11470cf10>

In [2]:
text = """
Rumi was born to Persian parents in Balkh modern-day Afghanistan or Wakhsh a village 
on the East bank of the Wakhsh River known as Sangtuda in present-day Tajikistan.
The area, culturally adjacent to Balkh, is where Mawlânâ's father, Bahâ' uddîn Walad, 
was a preacher and jurist. He lived and worked there until 1212, when Rumi was aged 
around five and the family moved to Samarkand. Greater Balkh was at that time a major 
centre of Persian culture and Sufism had developed there for several centuries. The 
most important influences upon Rumi, besides his father, were the Persian poets Attar 
and Sanai. Rumi expresses his appreciation: "Attar was the spirit, Sanai his eyes twain, 
And in time thereafter, Came we in their train" and mentions in another poem: "Attar has 
traversed the seven cities of Love, We are still at the turn of one street". His father 
was also connected to the spiritual lineage of Najm al-Din Kubra. Rumi lived most of his 
life under the Persianate Seljuk Sultanate of Rum, where he produced his works and died 
in 1273 AD. He was buried in Konya, and his shrine became a place of pilgrimage. Upon his 
death, his followers and his son Sultan Walad founded the Mevlevi Order, also known as 
the Order of the Whirling Dervishes, famous for the Sufi dance known as the Sama ceremony. 
He was laid to rest beside his father, and over his remains a shrine was erected. A hagiographical 
account of him is described in Shams ud-Din Ahmad Aflāki's Manāqib ul-Ārifīn (written between 1318 
and 1353). This biography needs to be treated with care as it contains both legends and facts about 
Rumi. For example, Professor Franklin Lewis of the University of Chicago, author of the most complete 
biography on Rumi, has separate sections for the hagiographical biography of Rumi and the actual 
biography about him.
"""

In [3]:
# prepare dictionary and encoding function

# remove newlines
text = text.replace('\n', ' ')

# convert text to characters
words = text.split()

# size of vocabulary
vocab = list(set(words))
vocab_size = len(vocab)

# string to integer
stoi = {c: i for i, c in enumerate(vocab)}

stoi['[MASK]'] = len(stoi)
mask_token_id = stoi['[MASK]']
vocab_size = len(stoi)

# encode
encode = lambda s: torch.tensor([stoi[c] for c in s.split()], dtype=torch.long)

In [4]:
def mask_tokens(inputs, vocabsize, mask_token_id, mlm_prob=0.15):
    labels = inputs.clone()    # create a copy of the inputs
    probability_matrix = torch.full(labels.shape, mlm_prob)      # based on the masked language model (MLM)
    mask_indices = torch.bernoulli(probability_matrix).bool()    # indices we aim to mask by mask_token, e.e, [MASK]

    labels[~mask_indices] = -100                                 # ignore non-masked tokens in loss

    # replacing (80/10/10)
    # ~80% replace with mask_token
    # ~10% replace with random token
    # ~10% unchanged

    # replace with mask_token
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & mask_indices
    inputs[indices_replaced] = mask_token_id

    # replace with random token
    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & mask_indices & ~indices_replaced
    random_tokens = torch.randint(vocabsize, size=(indices_random.sum().item(),), dtype=torch.long)
    inputs[indices_random] = random_tokens

    return inputs, labels

In [5]:
data = encode(text)
def get_batch(mask_token_id, vocabsize, batch_size=32, block_size=8):
    ix = torch.randint(0, len(data) - block_size - 1, (batch_size,))
    x = torch.stack([data[i: i+block_size] for i in ix])             # (batch_size, block_size)
    y = torch.stack([data[i+1: i+1+block_size] for i in ix])         # (batch_size, block_size)
    x, y = mask_tokens(x.clone(), vocabsize, mask_token_id)

    return x.to(DEVICE), y.to(DEVICE)

### Model

In [6]:
class Head(nn.Module):
    def __init__(self, n_embed, head_size, dropout):
        super(Head, self).__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        B, T, C = x.shape
        k, q, v = self.key(x), self.query(x), self.value(x)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)
        out = att @ v
        
        return out

In [7]:
class MultiHead(nn.Module):
    def __init__(self, n_embed, n_head, dropout):
        super(MultiHead, self).__init__()
        head_size = n_embed // n_head
        self.heads = nn.ModuleList([Head(n_embed, head_size, dropout) for _ in range(n_head)])
        self.proj = nn.Linear(n_embed, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)   # each h: (B, T, head_size) -> concat: (B, T, n_head*head_size)
        proj = self.proj(out)                                 # n_head*head_size = embed_dim -> project them to embed_dim
        return self.dropout(proj)

In [8]:
class FeedForward(nn.Module):
    def __init__(self, n_embed, dropout):
        super(FeedForward, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4*n_embed),
            nn.GELU(),
            nn.Linear(4*n_embed, n_embed),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

In [9]:
class Block(nn.Module):
    def __init__(self, n_embed, n_head, dropout):
        super(Block, self).__init__()
        self.mh = MultiHead(n_embed, n_head, dropout)
        self.ff = FeedForward(n_embed, dropout)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        x_p = self.ln1(x)
        x = x + self.mh(x_p)
        x_p = self.ln2(x)
        x = x + self.ff(x_p)
        
        return x

In [10]:
class TinyBERT(nn.Module):
    def __init__(self, vocab_size, block_size, n_embed=128, n_head=4, n_layer=4, dropout=0.1):
        super(TinyBERT, self).__init__()
        self.token_embed = nn.Embedding(vocab_size, n_embed)
        self.pos_embed = nn.Embedding(block_size, n_embed)
        self.blocks = nn.Sequential(
            *[Block(n_embed, n_head, dropout) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embed)
        self.mlm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, labels=None):
        B, T = idx.shape
        token = self.token_embed(idx)
        pos = self.pos_embed(torch.arange(T, device=idx.device))
        x = token + pos
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.mlm_head(x)
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-100)
            
        return logits, loss

In [11]:
model = TinyBERT(vocab_size, BLOCK_SIZE).to(DEVICE)
optimizer = optim.AdamW(model.parameters(), lr=LR)

### Train

In [12]:
num_epochs = 1000
for epoch in range(num_epochs):
    x_batch, y_batch = get_batch(mask_token_id, vocab_size, batch_size=BATCH_SIZE, block_size=BLOCK_SIZE)
    logits, loss = model(x_batch, y_batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    mask = y_batch != -100
    preds = torch.argmax(logits, dim=-1)
    correct = (preds[mask] == y_batch[mask]).sum().item()
    acc = correct / mask.sum().item()

    if (epoch + 1) % 100 == 0:
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item(): .4f}, Accuracy: {acc * 100: .2f}")

Epoch 100/1000, Loss:  4.8240, Accuracy:  5.88
Epoch 200/1000, Loss:  3.9187, Accuracy:  16.67
Epoch 300/1000, Loss:  3.5331, Accuracy:  30.43
Epoch 400/1000, Loss:  2.6440, Accuracy:  46.81
Epoch 500/1000, Loss:  2.0034, Accuracy:  70.00
Epoch 600/1000, Loss:  1.5215, Accuracy:  65.96
Epoch 700/1000, Loss:  1.3033, Accuracy:  74.00
Epoch 800/1000, Loss:  1.0863, Accuracy:  86.67
Epoch 900/1000, Loss:  0.6710, Accuracy:  91.89
Epoch 1000/1000, Loss:  0.6087, Accuracy:  95.83


## Evaluation

In [13]:
from collections import Counter
import math

def evaluate_model(model, data, orig_data=None, mask_flag=True, block_size=32, batch_size=32, mask_token_id=None, vocab_size=None):
    """
    Evaluates TinyBERT (MLM) on a dataset and computes:
    - Cross-Entropy Loss
    - Perplexity (PPL)
    - Masked Token Accuracy
    - Bits Per Character (optional)
    """

    model.eval()
    losses = []
    correct, total = 0, 0

    with torch.no_grad():
        for i in range(0, len(data) - block_size - 1, batch_size):
            x = torch.stack([
                data[j:j+block_size] 
                for j in range(i, min(i+batch_size, len(data)-block_size-1))
            ]).clone()

            # mask tokens
            if mask_flag:
                x_true = x.clone()
                inputs, labels = mask_tokens(x, vocab_size, mask_token_id)
            else:
                x_true = torch.stack([
                    orig_data[j:j+block_size] 
                    for j in range(i, min(i+batch_size, len(data)-block_size-1))
                ]).clone()
                inputs = x.clone()
                labels = x.clone()
                labels[labels != mask_token_id] = -100
                
            mask = labels != -100
            if mask.sum().item() == 0:
                continue
                
            x_true, inputs, labels = x_true.to(DEVICE), inputs.to(DEVICE), labels.to(DEVICE)

            logits, loss = model(inputs, x_true)
            losses.append(loss.item())

            # compute accuracy only on masked tokens
            preds = torch.argmax(logits, dim=-1)
            correct += (preds[mask] == x_true[mask]).sum().item()
            total += mask.sum().item()

    avg_loss = sum(losses) / len(losses)
    perplexity = math.exp(avg_loss)
    accuracy = correct / total

    model.train()

    return {
        "CrossEntropyLoss": avg_loss,
        "Perplexity": perplexity,
        "MaskedAccuracy": accuracy
    }

In [14]:
orig_text = """
Rumi was born to Persian parents in Balkh modern-day Afghanistan or Wakhsh a village 
on the East bank of the Wakhsh River known as Sangtuda in present-day Tajikistan.
"""
orig_data = encode(orig_text)

eval_text = """
Rumi was born to [MASK] parents in Balkh modern-day Afghanistan or Wakhsh a village 
on the East bank of the [MASK] River known as Sangtuda in present-day Tajikistan.
"""

encode_eval = lambda s: torch.tensor([stoi.get(w, mask_token_id) for w in s.split()], dtype=torch.long)
eval_data = encode_eval(eval_text)

metrics = evaluate_model(model,
                         eval_data,
                         orig_data,
                         mask_flag=False,
                         block_size=BLOCK_SIZE, 
                         batch_size=BATCH_SIZE,
                         mask_token_id=mask_token_id,
                         vocab_size=vocab_size)

print("\n=== Evaluation Results ===")
print(f"Cross-Entropy Loss: {metrics['CrossEntropyLoss']:.4f}")
print(f"Perplexity (PPL)  : {metrics['Perplexity']:.2f}")
print(f"Masked Accuracy   : {metrics['MaskedAccuracy']*100:.2f}%")


=== Evaluation Results ===
Cross-Entropy Loss: 0.2245
Perplexity (PPL)  : 1.25
Masked Accuracy   : 100.00%
