# 📚 BERT: Bidirectional Encoder Representations from Transformers

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/transformer_problems/blob/transformer/transformer_architectures/02_bert/demo.ipynb)

![Architecture](architecture.png)

### Key Innovations
- **Bidirectional**: See left AND right context
- **MLM**: Masked Language Modeling pretraining
- **NSP**: Next Sentence Prediction

In [None]:
!pip install torch matplotlib numpy -q
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt
import numpy as np

# Seed PyTorch's default RNG so weight initialisation and torch-sampled data
# are repeatable from run to run.
torch.manual_seed(42)

# Use the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using device: {device}')

## BERT Architecture

In [None]:
class BertEmbeddings(nn.Module):
    """BERT input embeddings: Token + Position + Segment.

    The three embeddings are summed element-wise, layer-normalised and then
    passed through dropout.

    Args:
        vocab_size: Number of entries in the token embedding table.
        d_model: Width of every embedding vector.
        max_len: Number of learned positions, i.e. the longest sequence that
            can be embedded.
        n_segments: Number of distinct segment ids.
        dropout: Dropout probability applied after the layer norm.
    """
    def __init__(self, vocab_size, d_model, max_len=512, n_segments=2, dropout=0.1):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.position_embed = nn.Embedding(max_len, d_model)
        self.segment_embed = nn.Embedding(n_segments, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, tokens, segments=None):
        """Embed a batch of token ids.

        Args:
            tokens: Integer tensor of token ids; dimension 1 is the sequence
                dimension.
            segments: Optional integer tensor of segment ids with the same
                shape as ``tokens``. Defaults to all zeros (single segment).

        Returns:
            Tensor with the shape of ``tokens`` plus a trailing ``d_model``
            dimension.

        Raises:
            ValueError: If the sequence is longer than the position table.
        """
        seq_len = tokens.size(1)
        max_len = self.position_embed.num_embeddings
        # Fail early with a readable message instead of an out-of-range
        # embedding lookup deep inside the position table.
        if seq_len > max_len:
            raise ValueError(
                f'Sequence length {seq_len} exceeds the maximum supported length {max_len}'
            )

        # Shape (1, seq_len): broadcasts over the batch dimension in the sum below.
        positions = torch.arange(seq_len, device=tokens.device).unsqueeze(0)

        if segments is None:
            segments = torch.zeros_like(tokens)

        x = self.token_embed(tokens) + self.position_embed(positions) + self.segment_embed(segments)
        return self.dropout(self.norm(x))

class BertAttention(nn.Module):
    """Multi-head self-attention with a fused query/key/value projection.

    Args:
        d_model: Model width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        # The per-head reshape in forward() only works when the width splits evenly.
        if d_model % n_heads != 0:
            raise ValueError(
                f'd_model ({d_model}) must be divisible by n_heads ({n_heads})'
            )
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_qkv = nn.Linear(d_model, 3 * d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(B, L, D)``.
            mask: Optional mask; positions where ``mask == 0`` are not attended
                to. A 2-D mask of shape ``(B, L)`` is treated as a per-key
                padding mask and expanded to ``(B, 1, 1, L)``. Any other mask
                must broadcast against the ``(B, n_heads, L, L)`` score tensor.
                When ``B == L`` a 2-D mask is ambiguous and is read as a
                padding mask; pass ``(1, 1, L, L)`` for a shared attention
                matrix in that case.

        Returns:
            Tuple ``(output, attn)``: the projected output of shape
            ``(B, L, D)`` and the (post-dropout) attention weights of shape
            ``(B, n_heads, L, L)``.
        """
        B, L, D = x.shape
        # (B, L, 3*D) -> (3, B, n_heads, L, d_k) so Q, K, V can be peeled off dim 0.
        qkv = self.W_qkv(x).reshape(B, L, 3, self.n_heads, self.d_k).permute(2, 0, 3, 1, 4)
        Q, K, V = qkv[0], qkv[1], qkv[2]

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            if mask.dim() == 2 and mask.shape == (B, L):
                # Key-padding mask: broadcast over heads and query positions.
                mask = mask[:, None, None, :]
            # Use the most negative finite value rather than -inf: masked keys
            # still get exactly zero weight after softmax, but a row whose keys
            # are all masked yields a uniform distribution instead of NaN.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attn = self.dropout(F.softmax(scores, dim=-1))
        # (B, n_heads, L, d_k) -> (B, L, n_heads, d_k) -> (B, L, D): re-join the heads.
        out = torch.matmul(attn, V).transpose(1, 2).reshape(B, L, D)
        return self.W_o(out), attn

class BertBlock(nn.Module):
    """One encoder layer: self-attention then a GELU feed-forward network.

    Each sub-layer is wrapped in a residual connection followed by layer
    normalisation (post-norm).
    """
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attn = BertAttention(d_model, n_heads, dropout)
        # Position-wise feed-forward: expand to d_ff, GELU, project back, dropout.
        feed_forward_layers = [
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        ]
        self.ffn = nn.Sequential(*feed_forward_layers)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run the layer on ``x``; ``mask`` is forwarded to the attention module."""
        # The attention module also returns its weights; only the output is used here.
        attended = self.attn(x, mask)[0]
        hidden = self.norm1(x + self.dropout(attended))
        return self.norm2(hidden + self.ffn(hidden))

class BERT(nn.Module):
    """Small BERT-style encoder with an MLM head and a 2-way [CLS] head.

    The model is an embedding layer followed by ``n_layers`` encoder blocks.
    ``forward`` returns the final hidden states; ``mlm_forward`` and
    ``cls_forward`` apply the respective linear heads on top of them.
    """
    def __init__(self, vocab_size, d_model=256, n_heads=4, n_layers=4, d_ff=512, dropout=0.1):
        super().__init__()
        self.embed = BertEmbeddings(vocab_size, d_model, dropout=dropout)
        self.blocks = nn.ModuleList()
        for _ in range(n_layers):
            self.blocks.append(BertBlock(d_model, n_heads, d_ff, dropout))
        self.mlm_head = nn.Linear(d_model, vocab_size)
        self.cls_head = nn.Linear(d_model, 2)  # For NSP

    def forward(self, tokens, segments=None, mask=None):
        """Return the hidden states produced by the last encoder block."""
        hidden = self.embed(tokens, segments)
        for layer in self.blocks:
            hidden = layer(hidden, mask)
        return hidden

    def mlm_forward(self, tokens, segments=None, mask=None):
        """Return per-position vocabulary logits for masked language modelling."""
        return self.mlm_head(self.forward(tokens, segments, mask))

    def cls_forward(self, tokens, segments=None, mask=None):
        """Return 2-way logits computed from the first position ([CLS] token)."""
        first_position = self.forward(tokens, segments, mask)[:, 0]
        return self.cls_head(first_position)

# Build a small demo model and report how many parameters it has in total.
model = BERT(5000, d_model=128, n_heads=4, n_layers=3)
print(f'BERT Parameters: {sum(map(torch.numel, model.parameters())):,}')

## Masked Language Modeling (MLM)

In [None]:
def create_mlm_data(tokens, vocab_size, mask_prob=0.15, mask_token=4):
    """
    Create MLM training data.

    Each regular token (id > 4) is selected for prediction with probability
    ``mask_prob``. Selected tokens are then corrupted as follows:
    - 80%: Replace with [MASK]
    - 10%: Replace with random token
    - 10%: Keep original

    Special tokens (ids 0-4) are never selected, so they are never corrupted
    and never contribute to the loss.

    Args:
        tokens: Integer tensor of token ids. It is not modified.
        vocab_size: Exclusive upper bound for random replacement ids.
        mask_prob: Probability of selecting each regular token.
        mask_token: Id written at positions that receive [MASK].

    Returns:
        Tuple ``(masked_tokens, labels)``, both shaped like ``tokens``.
        ``labels`` holds the original id at selected positions and -100
        (ignored by the loss) everywhere else.
    """
    first_regular_id = 5  # ids 0-4 are reserved for special tokens

    labels = tokens.clone()
    masked_tokens = tokens.clone()

    # Draw the selection first and THEN exclude special tokens. Zeroing the
    # random draw for special tokens would make them always fall below
    # mask_prob and therefore always be selected.
    is_regular = tokens >= first_regular_id
    selection_draw = torch.rand(tokens.shape, device=tokens.device)
    mask_positions = (selection_draw < mask_prob) & is_regular

    # Set labels for non-selected positions to -100 (ignore)
    labels[~mask_positions] = -100

    # A second, independent draw decides how each selected position is corrupted.
    strategy_draw = torch.rand(tokens.shape, device=tokens.device)

    # 80% -> [MASK]
    mask_token_positions = mask_positions & (strategy_draw < 0.8)
    masked_tokens[mask_token_positions] = mask_token

    # 10% -> random regular token, created on the same device/dtype as the input
    random_positions = mask_positions & (strategy_draw >= 0.8) & (strategy_draw < 0.9)
    n_random = int(random_positions.sum())
    if n_random:
        masked_tokens[random_positions] = torch.randint(
            first_regular_id, vocab_size, (n_random,),
            dtype=tokens.dtype, device=tokens.device,
        )

    # 10% -> keep original (do nothing)

    return masked_tokens, labels

# Smoke-test the MLM corruption on a tiny random batch and show the first row
# before corruption, after corruption, and as loss labels.
tokens = torch.randint(5, 100, (2, 10))
masked, labels = create_mlm_data(tokens, vocab_size=100)
for caption, batch_rows in (('Original:', tokens), ('Masked:  ', masked), ('Labels:  ', labels)):
    print(caption, batch_rows[0].tolist())

## Training on Tiny Dataset

In [None]:
# Create tiny text dataset
class TinyTextDataset:
    """In-memory dataset of random token sequences standing in for tokenized text.

    Args:
        vocab_size: Exclusive upper bound for generated token ids; ids start
            at 5 so that 0-4 stay free for special tokens.
        size: Number of sequences to generate.
        max_len: Maximum batch width including the leading [CLS] token.
            Generated sequences have between 10 and ``max_len - 1`` tokens,
            so ``max_len`` must be greater than 10.
    """
    def __init__(self, vocab_size=1000, size=2000, max_len=32):
        self.vocab_size = vocab_size
        self.max_len = max_len
        # Generate random sequences (simulating tokenized text)
        self.data = [torch.randint(5, vocab_size, (np.random.randint(10, max_len),)) for _ in range(size)]

    def get_batch(self, batch_size):
        """Sample a random batch (with replacement) and pad it to a common width.

        Every row starts with [CLS] (id 1), followed by the sequence, followed
        by zero padding.

        Args:
            batch_size: Number of sequences to draw.

        Returns:
            Tuple ``(padded, attention_mask)`` of shape ``(batch_size, width)``.
            ``padded`` is a long tensor of token ids; ``attention_mask`` is a
            float tensor with 1 at real tokens (including [CLS]) and 0 at padding.
        """
        indices = np.random.choice(len(self.data), batch_size)
        batch = [self.data[i] for i in indices]

        # Reserve one extra column for [CLS] so the longest sequence is not
        # truncated, but never exceed the configured maximum length.
        longest = max(len(seq) for seq in batch)
        width = min(longest + 1, self.max_len)
        padded = torch.zeros(batch_size, width, dtype=torch.long)
        attention_mask = torch.zeros(batch_size, width)

        for i, seq in enumerate(batch):
            # Only sequences that would overflow max_len are cut here.
            body = seq[:width - 1]
            padded[i, 0] = 1  # [CLS]
            padded[i, 1:len(body) + 1] = body
            attention_mask[i, :len(body) + 1] = 1

        return padded, attention_mask

# Instantiate the synthetic corpus used for training and report its size.
dataset = TinyTextDataset(1000, 3000)
print(f'Dataset size: {len(dataset.data)} sequences')

In [None]:
# Training
# Single source of truth for the vocabulary size shared by the model and the
# MLM data generator.
vocab_size = 1000
model = BERT(vocab_size=vocab_size, d_model=128, n_heads=4, n_layers=3, d_ff=256).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# Positions labelled -100 (not selected for prediction) are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=-100)

losses = []
n_epochs = 100
batch_size = 32

print('Training BERT with MLM...')
# NOTE: each "epoch" here is one optimisation step on one randomly sampled batch.
for epoch in range(n_epochs):
    model.train()

    tokens, attn_mask = dataset.get_batch(batch_size)
    tokens = tokens.to(device)

    # Create MLM data
    masked_tokens, mlm_labels = create_mlm_data(tokens, vocab_size=vocab_size)
    masked_tokens = masked_tokens.to(device)
    mlm_labels = mlm_labels.to(device)

    # Reshape the (batch, seq) padding mask to (batch, 1, 1, seq) so it
    # broadcasts over heads and query positions of the attention scores;
    # without it padded positions would be attended to like real tokens.
    key_mask = attn_mask[:, None, None, :].to(device)

    # Forward
    optimizer.zero_grad()
    logits = model.mlm_forward(masked_tokens, mask=key_mask)

    # Flatten to (batch * seq, vocab) vs (batch * seq,) for the loss.
    loss = criterion(logits.view(-1, logits.size(-1)), mlm_labels.view(-1))
    loss.backward()
    optimizer.step()

    losses.append(loss.item())

    if (epoch + 1) % 20 == 0:
        print(f'Epoch {epoch+1}: Loss = {loss.item():.4f}')

plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('MLM Loss')
plt.title('BERT MLM Training')
plt.grid(True, alpha=0.3)
plt.show()

In [None]:
# Test: Fill in the mask
# Switch to inference mode so dropout is disabled.
model.eval()

# One random sequence of 15 regular tokens with a single [MASK] (id 4) inserted.
test_tokens = torch.randint(5, 100, (1, 15)).to(device)
test_tokens[0, 5] = 4  # [MASK] at position 5

with torch.no_grad():
    logits = model.mlm_forward(test_tokens)
    # Highest-scoring vocabulary id at the masked position.
    predicted = logits[0, 5].argmax().item()

print(f'Input (with MASK at pos 5): {test_tokens[0].tolist()}')
print(f'Predicted token for MASK: {predicted}')

# The heading emoji was mis-decoded UTF-8 (mojibake); restored to the intended character.
print('\n🎯 Key Takeaways:')
print('1. BERT is encoder-only (no autoregressive generation)')
print('2. Bidirectional: Can see full context in both directions')
print('3. MLM: Learn by predicting masked tokens')
print('4. Great for classification, NER, QA tasks')