# Positional Encoding: Teaching Transformers About Position

Transformers process all positions in parallel, which creates a problem: **how does the model know about word order?** 

Self-attention treats its inputs as an unordered set of vectors, so without positional information "cat sat on mat" and "mat on sat cat" would look identical to the model!

## What You'll Learn

1. **The Position Problem** - Why transformers need positional information
2. **Sinusoidal Solution** - The elegant mathematical approach
3. **Addition vs Concatenation** - Why we add instead of concatenate
4. **Implementation** - Building positional encoding from scratch

Let's solve the position puzzle!

In [None]:
import sys
import os
sys.path.append('..')

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Tuple, Optional
import math

# Plot styling shared by every figure in this notebook.
plt.style.use('default')
sns.set_palette("husl")

# Seed both random number generators so random tensors/arrays repeat run to run.
for seed_fn in (torch.manual_seed, np.random.seed):
    seed_fn(42)

print("Environment setup complete!")

import sys
import os
sys.path.append('..')

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import math
from typing import Tuple, Optional

# NOTE(review): this cell repeats the setup cell above; re-running it re-applies
# the same plot style and re-seeds both RNGs with the same value. Safe to delete.
plt.style.use('default')
sns.set_palette("husl")
_SEED = 42
np.random.seed(_SEED)
torch.manual_seed(_SEED)
print("Environment setup complete!")

## 1. The Position Problem

Let's see why transformers need positional encoding:

# Two sentences containing the same words in a different order.
sentence1 = ["cat", "sat", "on", "mat"]
sentence2 = ["mat", "on", "sat", "cat"]

# Toy 3-d word vectors; nothing here records where a word appears.
word_embeddings = {
    "cat": [1, 0, 0],
    "sat": [0, 1, 0],
    "on": [0, 0, 1],
    "mat": [1, 1, 0],
}

words1 = list(map(word_embeddings.__getitem__, sentence1))
words2 = list(map(word_embeddings.__getitem__, sentence2))

# Order-insensitive aggregation: element-wise sum of each sentence's vectors.
sum1 = list(map(sum, zip(*words1)))
sum2 = list(map(sum, zip(*words2)))

print(f"Sentence 1: {' '.join(sentence1)}")
print(f"Sentence 2: {' '.join(sentence2)}")
print("Without position encoding:")
for idx, rep in enumerate((sum1, sum2), start=1):
    print(f"Representation {idx}: {rep}")
print(f"Identical? {sum1 == sum2}")
print("❌ Problem: Can't distinguish word order!")

In [None]:
## Sinusoidal Positional Encoding

**The Solution**: Use sine and cosine functions to create unique position signatures.

**Requirements**:
- Unique pattern for each position
- Bounded values (don't explode)  
- Smooth transitions between nearby positions
- Works for any sequence length

**Formula**: For position `pos` and dimension-pair index `i` (with `0 ≤ i < d_model/2`):
- Even dims: `PE(pos, 2i) = sin(pos / 10000^(2i/d_model))`
- Odd dims: `PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))`

def create_sinusoidal_encoding(max_len: int, d_model: int, base: float = 10000.0) -> torch.Tensor:
    """Build the fixed sinusoidal positional-encoding table.

    Row ``pos`` holds the encoding of position ``pos``::

        PE[pos, 2i]   = sin(pos / base**(2i / d_model))
        PE[pos, 2i+1] = cos(pos / base**(2i / d_model))

    Args:
        max_len: Number of positions (rows) to generate.
        d_model: Encoding width (columns). Odd widths are supported: the last
            (even-indexed) column gets a sine with no matching cosine.
        base: Wavelength base; the default 10000.0 matches the formula above.

    Returns:
        A float tensor of shape ``(max_len, d_model)`` with values in [-1, 1].
    """
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)

    # 1 / base**(2i / d_model), computed as exp(-2i * ln(base) / d_model).
    # One entry per even column, i.e. ceil(d_model / 2) entries.
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(base) / d_model))
    angles = position * div_term  # (max_len, ceil(d_model / 2))

    pe[:, 0::2] = torch.sin(angles)
    # There are only floor(d_model / 2) odd columns; slicing keeps an odd
    # d_model from raising a shape-mismatch error.
    pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

    return pe

# Build a small table so every value is easy to inspect.
max_len, d_model = 10, 8
pos_encoding = create_sinusoidal_encoding(max_len, d_model)

print(f"Positional encoding shape: {pos_encoding.shape}")
print(f"Value range: [{pos_encoding.min():.3f}, {pos_encoding.max():.3f}]")

# Print the first three rows (positions 0-2), rounded for readability.
for row_idx, row in enumerate(pos_encoding[:3]):
    print(f"Position {row_idx}: {[round(v, 3) for v in row.tolist()]}")

fig, (ax_heat, ax_curves) = plt.subplots(1, 2, figsize=(12, 6))

# Left panel: the whole table, dimensions on the y-axis and positions on the x-axis.
sns.heatmap(pos_encoding.T, cmap='RdBu_r', center=0, ax=ax_heat)
ax_heat.set_title('Positional Encoding Pattern')
ax_heat.set_xlabel('Position')
ax_heat.set_ylabel('Dimension')

# Right panel: a fast-varying pair (dims 0/1) versus a slow-varying pair (dims 6/7).
for col in (0, 1, 6, 7):
    ax_curves.plot(pos_encoding[:, col], label=f'Dim {col}')
ax_curves.set_title('Values by Position')
ax_curves.set_xlabel('Position')
ax_curves.set_ylabel('Value')
ax_curves.legend()
ax_curves.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("✅ Each position gets a unique bounded pattern!")

In [None]:
## Why Addition Instead of Concatenation?

**Critical design choice**: We ADD position encoding to word embeddings rather than concatenating them.

**Addition benefits**:
- Same dimensionality (efficient)
- Creates blended word-position representations
- Attention sees unified "word-at-position" features

**Concatenation problems**:
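- Increases dimensionality (doubling it when the position vector is as wide as the word vector), making every downstream layer more expensive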
- Separates word and position information
- Requires learning to combine them

word_emb = torch.tensor([1.0, 0.5, -0.3, 0.8])
pos_emb = pos_encoding[1][:4]  # first 4 dims of the position-1 encoding, to match word_emb

print(f"Word embedding:     {word_emb.tolist()}")
print(f"Position embedding: {[round(x, 3) for x in pos_emb.tolist()]}")

added = word_emb + pos_emb
concatenated = torch.cat([word_emb, pos_emb])

print("\nADDITION (what transformers use):")
print(f"Result: {[round(x, 3) for x in added.tolist()]} (shape: {added.shape})")
print("✅ Same dimensionality, blended representation")

print("\nCONCATENATION (alternative):")
print(f"Result: {[round(x, 3) for x in concatenated.tolist()]} (shape: {concatenated.shape})")
print("❌ Double dimensions, separated information")

# Solve original problem with positional encoding
print("\n🎉 SOLVING THE POSITION PROBLEM:")
emb1 = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [1, 1, 0, 0]]).float()  # cat sat on mat
emb2 = torch.tensor([[1, 1, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]]).float()  # mat on sat cat

pos_enc_4 = create_sinusoidal_encoding(4, 4)
combined1 = emb1 + pos_enc_4
combined2 = emb2 + pos_enc_4


def _as_vector_multiset(vectors: torch.Tensor, ndigits: int = 6) -> list:
    """Return the rows of a 2-D tensor as a sorted list of rounded tuples.

    Sorting discards row order, so two tensors compare equal exactly when one
    is a row-permutation of the other (up to ``ndigits`` decimal places).
    """
    return sorted(tuple(round(v, ndigits) for v in row) for row in vectors.tolist())


# Without PE, sentence 2's word vectors are just a reordering of sentence 1's,
# so any order-insensitive view of them is identical.
same_without_pe = _as_vector_multiset(emb1) == _as_vector_multiset(emb2)

# A plain sum is STILL identical after adding PE:
#   sum_i (word_i + pe_i) = sum_i word_i + sum_i pe_i,
# and neither term depends on which word sits at which position.
sum1, sum2 = combined1.sum(dim=0), combined2.sum(dim=0)
sums_identical = torch.allclose(sum1, sum2, atol=1e-6)

# What PE changes is each individual vector: every row now mixes word AND
# position, so the two sentences yield genuinely different sets of vectors.
are_different = _as_vector_multiset(combined1) != _as_vector_multiset(combined2)

print("After adding positional encoding:")
print(f"Same set of vectors without PE? {same_without_pe}")
print(f"Plain sums still identical with PE? {sums_identical} (a sum ignores order)")
print(f"Different representations? {are_different}")
if are_different:
    print("✅ Position encoding solved the word order problem!")
else:
    print("❌ Representations are still identical - word order was lost!")

In [None]:
## Complete Positional Embedding Layer

Combine word embeddings with positional encoding in a complete neural network layer.

class PositionalEmbedding(nn.Module):
    """Token embedding plus a fixed sinusoidal positional encoding.

    Args:
        vocab_size: Number of distinct token ids.
        max_len: Longest sequence the layer accepts (rows in the position table).
        d_model: Embedding width.
    """

    def __init__(self, vocab_size: int, max_len: int, d_model: int):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)

        # Stored as a buffer, not a Parameter: it is excluded from .parameters()
        # and therefore never trained.
        pos_encoding = create_sinusoidal_encoding(max_len, d_model)
        self.register_buffer('pos_encoding', pos_encoding)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Embed a (batch_size, seq_len) tensor of token ids.

        Returns:
            A (batch_size, seq_len, d_model) tensor: token embedding + positional encoding.

        Raises:
            ValueError: if ``seq_len`` exceeds the ``max_len`` the layer was built with.
        """
        batch_size, seq_len = token_ids.shape
        max_len = self.pos_encoding.size(0)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {max_len} of the positional encoding"
            )

        word_emb = self.token_embedding(token_ids)           # (batch, seq_len, d_model)
        pos_emb = self.pos_encoding[:seq_len].unsqueeze(0)   # (1, seq_len, d_model)
        # Every sequence in the batch gets the same position vectors.
        pos_emb = pos_emb.expand(batch_size, -1, -1)

        return word_emb + pos_emb

# Build the layer and push a small random batch through it.
vocab_size, max_len, d_model = 100, 20, 8
pos_emb_layer = PositionalEmbedding(vocab_size, max_len, d_model)

batch_size, seq_len = 2, 5
token_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
embeddings = pos_emb_layer(token_ids)

for label, shaped in (("Input", token_ids), ("Output", embeddings)):
    print(f"{label} shape: {shaped.shape}")

# Feed sequences made of a single repeated token id. The token embedding is the
# same everywhere, so any difference at the last position comes from the position term.
token_id = 42
print(f"\nSame token (ID={token_id}) at different positions:")
for position in range(4):
    repeated = torch.full((1, position + 1), token_id)
    vector = pos_emb_layer(repeated)[0, position]
    preview = [round(value, 3) for value in vector[:3].tolist()]
    print(f"Position {position}: {preview}...")

print("\n✅ Same token gets different embeddings at different positions!")

## Summary

You've mastered positional encoding - the key to teaching transformers about word order!

### Key Concepts:
1. **The Problem**: Transformers process all positions in parallel and need explicit position information
2. **Sinusoidal Solution**: Sine and cosine functions create unique, bounded position signatures  
3. **Addition > Concatenation**: Adding creates richer word-position interactions efficiently
4. **Implementation**: A position table (fixed sinusoidal or learned) added to the token embeddings inside a single embedding layer

### What's Next?
Now you understand all the core transformer components:
- **Tokenization** (notebook 0) - Text → numbers
- **Attention** (notebook 1) - How to focus on relevant information
- **Transformer blocks** (notebook 2) - Complete processing units
- **Position encoding** (notebook 3) - Understanding word order

Ready to see it all working together in a complete transformer! 🚀

In [None]:
from src.model.embeddings import GPTEmbedding

class PositionalEmbedding(nn.Module):
    """Token embedding plus a learned or sinusoidal positional embedding, then dropout.

    Args:
        vocab_size: Number of distinct token ids.
        max_len: Longest sequence the layer accepts.
        d_model: Embedding width.
        pos_type: ``"learned"`` (trainable ``nn.Embedding`` position table) or
            ``"sinusoidal"`` (fixed table stored as a buffer).
        dropout: Dropout probability applied to the summed embeddings.

    Raises:
        ValueError: if ``pos_type`` is not one of the supported values.
    """

    def __init__(self, vocab_size: int, max_len: int, d_model: int,
                 pos_type: str = "learned", dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len
        self.pos_type = pos_type

        # Token embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)

        # Positional embeddings
        if pos_type == "learned":
            self.pos_embedding = nn.Embedding(max_len, d_model)
        elif pos_type == "sinusoidal":
            # Buffer, not a Parameter: excluded from .parameters() and never trained.
            pos_encoding = create_sinusoidal_encoding(max_len, d_model)
            self.register_buffer('pos_encoding', pos_encoding)
        else:
            raise ValueError(f"Unknown pos_type: {pos_type}")

        self.dropout = nn.Dropout(dropout)

        # Small-std normal init for every trainable table: the token embeddings
        # always, and the position table when it is learned.
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        if pos_type == "learned":
            nn.init.normal_(self.pos_embedding.weight, std=0.02)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        """
        Forward pass: combine token and positional embeddings.

        Args:
            token_ids: Token IDs [batch_size, seq_len]

        Returns:
            Combined embeddings [batch_size, seq_len, d_model], after dropout.

        Raises:
            ValueError: if seq_len exceeds the max_len the layer was built with.
        """
        batch_size, seq_len = token_ids.shape
        if seq_len > self.max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.max_len}"
            )

        token_emb = self.token_embedding(token_ids)  # [batch_size, seq_len, d_model]

        if self.pos_type == "learned":
            positions = torch.arange(seq_len, device=token_ids.device)
            pos_emb = self.pos_embedding(positions)  # [seq_len, d_model]
        else:  # sinusoidal
            pos_emb = self.pos_encoding[:seq_len]    # [seq_len, d_model]

        # Every sequence in the batch gets the same position vectors.
        pos_emb = pos_emb.unsqueeze(0).expand(batch_size, -1, -1)  # [batch_size, seq_len, d_model]

        return self.dropout(token_emb + pos_emb)

# Test both types of positional embedding
vocab_size, max_len, d_model = 100, 20, 8
batch_size, seq_len = 2, 10

# Create test input
token_ids = torch.randint(0, vocab_size, (batch_size, seq_len))

# Both layers apply dropout (p=0.1 by default). In training mode that randomly
# zeroes outputs, so switch to eval() to make the outputs and heatmaps deterministic.
learned_emb = PositionalEmbedding(vocab_size, max_len, d_model, pos_type="learned").eval()
output_learned = learned_emb(token_ids)

sinusoidal_emb = PositionalEmbedding(vocab_size, max_len, d_model, pos_type="sinusoidal").eval()
output_sinusoidal = sinusoidal_emb(token_ids)

same_shape = output_learned.shape == output_sinusoidal.shape

print(f"Input token IDs shape: {token_ids.shape}")
print(f"Output embeddings shape: {output_learned.shape} (learned), "
      f"{output_sinusoidal.shape} (sinusoidal)")
print()

# Compare parameter counts; the sinusoidal table is a buffer, so it is not counted.
learned_params = sum(p.numel() for p in learned_emb.parameters())
sinusoidal_params = sum(p.numel() for p in sinusoidal_emb.parameters())

print("Parameter comparison:")
print(f"Learned embeddings:    {learned_params:,} parameters")
print(f"Sinusoidal embeddings: {sinusoidal_params:,} parameters")
print(f"Difference: {learned_params - sinusoidal_params:,} (positional embedding table)")

# Visualize the layer outputs (token + position) for the first sequence in the batch.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

sns.heatmap(output_learned[0].detach().T, cmap='viridis', ax=ax1,
            xticklabels=range(seq_len), yticklabels=range(d_model))
ax1.set_title('Token + Learned Positional Embeddings')
ax1.set_xlabel('Position')
ax1.set_ylabel('Dimension')

sns.heatmap(output_sinusoidal[0].detach().T, cmap='viridis', ax=ax2,
            xticklabels=range(seq_len), yticklabels=range(d_model))
ax2.set_title('Token + Sinusoidal Positional Encoding')
ax2.set_xlabel('Position')
ax2.set_ylabel('Dimension')

plt.tight_layout()
plt.show()

if same_shape:
    print("\n✅ Both embedding types work and produce the same output shape!")
else:
    print("\n❌ The two embedding types produced different output shapes!")