# Module 1.2: Transformer Architecture Deep Dive

**Goal**: Understand position encoding, RoPE, and core transformer components

**Time**: 50 minutes

**Concepts Covered**:
- Sinusoidal position encoding
- Rotary Position Embedding (RoPE)
- 2D/3D position embedding visualization
- RoPE vs absolute encoding comparison

## Setup
Install required packages

In [None]:
!pip install torch numpy matplotlib seaborn -q

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns

# Seed the PyTorch and NumPy random number generators so runs are reproducible.
torch.manual_seed(42)
np.random.seed(42)
# Use the seaborn dark-grid matplotlib style for the figures drawn below.
plt.style.use('seaborn-v0_8-darkgrid')

## Lesson 1: Sinusoidal Position Encoding (15 mins)

The original Transformer uses fixed sinusoidal position encodings to inject positional information.

In [None]:
def sinusoidal_position_encoding(seq_len, d_model):
    """
    Generate sinusoidal position encodings

    Each (even, odd) column pair shares one frequency:

        PE[pos, 2k]     = sin(pos / 10000 ** (2k / d_model))
        PE[pos, 2k + 1] = cos(pos / 10000 ** (2k / d_model))

    Args:
        seq_len: Sequence length
        d_model: Model dimension (odd values are allowed; the last column
            is then a sin column without a cos partner)

    Returns:
        pos_encoding: (seq_len, d_model) float32 tensor
    """
    # Column vector of positions so it broadcasts against the frequency row.
    positions = np.arange(seq_len, dtype=np.float64)[:, np.newaxis]

    # The even column index *is* 2k, so it is used directly as the exponent
    # numerator; multiplying it by 2 again would double the exponent.
    even_indices = np.arange(0, d_model, 2, dtype=np.float64)
    inv_freq = 1.0 / (10000.0 ** (even_indices / d_model))

    # angles[pos, k] = pos * inv_freq[k], shape (seq_len, ceil(d_model / 2))
    angles = positions * inv_freq

    pos_encoding = np.zeros((seq_len, d_model))
    # Even indices: sin
    pos_encoding[:, 0::2] = np.sin(angles)
    # Odd indices: cos of the SAME angle as the preceding sin column.
    # For odd d_model there is one fewer odd column than even columns.
    pos_encoding[:, 1::2] = np.cos(angles[:, : d_model // 2])

    return torch.tensor(pos_encoding, dtype=torch.float32)

# Generate position encodings
# NOTE: seq_len, d_model and pos_enc are notebook-level globals; pos_enc is
# reused by the heatmap cell, and seq_len/d_model are rebound by later cells.
seq_len = 50
d_model = 128
pos_enc = sinusoidal_position_encoding(seq_len, d_model)

# Report the table shape and the first five dimensions of the first and
# last positions.
print(f"Position encoding shape: {pos_enc.shape}")
print(f"First position: {pos_enc[0, :5]}")
print(f"Last position: {pos_enc[-1, :5]}")

In [None]:
# Visualize position encoding heatmap
plt.figure(figsize=(12, 8))
sns.heatmap(
    pos_enc.numpy()[:, :64].T,  # Show first 64 dimensions; transposed so rows are dimensions
    cmap='RdYlBu',
    cbar_kws={'label': 'Encoding Value'},
    xticklabels=10,
    yticklabels=10
)
plt.xlabel('Position in Sequence')
plt.ylabel('Dimension')
plt.title('Sinusoidal Position Encoding Heatmap', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# The wavelength grows with the dimension index (the exponent of 10000
# increases), so low-index dimensions oscillate fastest along the sequence.
print("\nInterpretation:")
print("- Each row is a dimension, each column is a position")
print("- Patterns repeat at different frequencies")
print("- Lower dimensions change rapidly, higher dimensions change slowly")

## Lesson 2: Rotary Position Embedding (RoPE) (20 mins)

RoPE rotates each pair of query/key features by an angle proportional to the token's position, so the query–key dot product depends only on the relative offset between the two positions.

In [None]:
def apply_rope(x, freqs_cis):
    """
    Rotate an input tensor with Rotary Position Embedding (RoPE)

    Adjacent feature pairs of ``x`` are treated as complex numbers and
    multiplied by the matching complex rotation in ``freqs_cis``.

    Args:
        x: Input tensor (..., seq_len, d_model)
        freqs_cis: Precomputed rotations stored as real pairs
            (seq_len, d_model // 2, 2)

    Returns:
        Rotated float tensor with the same shape as ``x``
    """
    # (..., seq_len, d_model) -> (..., seq_len, d_model // 2, 2), then view
    # each trailing pair as one complex number.
    pair_shape = (*x.shape[:-1], -1, 2)
    pairs = torch.view_as_complex(x.float().reshape(pair_shape))

    # Same treatment for the rotations: (real, imag) pairs -> complex.
    rotation = torch.view_as_complex(freqs_cis)

    # Complex multiplication performs the rotation; unpack back into real
    # pairs and merge the last two axes to restore the input layout.
    return torch.view_as_real(pairs * rotation).flatten(-2)

def precompute_freqs_cis(dim, end, theta=10000.0):
    """
    Build the table of RoPE rotations for every position

    Args:
        dim: Model dimension (must be even)
        end: Maximum sequence length
        theta: Base frequency

    Returns:
        freqs_cis: (end, dim // 2, 2) tensor holding (cos, sin) of each
            rotation angle
    """
    # One inverse frequency per feature pair: theta ** -(2k / dim)
    exponents = torch.arange(0, dim, 2).float() / dim
    inv_freq = 1.0 / (theta ** exponents)

    # Rotation angle for every (position, pair) combination.
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq)

    # Unit-magnitude complex numbers e^(i*angle) = cos(angle) + i*sin(angle),
    # returned as explicit (real, imag) pairs.
    unit_rotations = torch.polar(torch.ones_like(angles), angles)
    return torch.view_as_real(unit_rotations)

# Precompute RoPE frequencies
# NOTE: this rebinds the notebook-level d_model and seq_len used by earlier cells.
d_model = 64
seq_len = 32
freqs_cis = precompute_freqs_cis(d_model, seq_len)

# freqs_cis[p, k] is the (cos, sin) pair for position p and feature pair k,
# so [:2] shows the rotations applied to the first 4 feature dimensions.
print(f"RoPE frequencies shape: {freqs_cis.shape}")
print(f"First position frequencies (first 4 dims): {freqs_cis[0, :2, :]}")
print(f"Last position frequencies (first 4 dims): {freqs_cis[-1, :2, :]}")

In [None]:
# Visualize RoPE rotation in 2D
def visualize_rope_2d():
    """Visualize how RoPE rotates vectors in 2D space

    Draws the vector (1, 0) rotated by ``pos * freq`` radians for positions
    0 through 7, labelling each arrow ``P<pos>``.
    """
    # Create a simple 2D vector (kept as plain floats so NumPy and matplotlib
    # receive ordinary numbers instead of 0-dim torch tensors)
    vec_x, vec_y = 1.0, 0.0

    # Create 2D RoPE frequencies; with dim == 2 there is a single feature
    # pair whose frequency is 1 / 10000 ** (0 / 2) == 1.0
    dim = 2
    freqs = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
    base_freq = float(freqs[0])

    fig, ax = plt.subplots(figsize=(10, 10))

    for pos in range(8):
        # Standard 2D rotation of (vec_x, vec_y) by `angle` radians.
        angle = pos * base_freq
        rotated_x = vec_x * np.cos(angle) - vec_y * np.sin(angle)
        rotated_y = vec_x * np.sin(angle) + vec_y * np.cos(angle)

        ax.arrow(0, 0, rotated_x, rotated_y, head_width=0.1, head_length=0.1,
                 fc='blue', ec='blue', alpha=0.6, length_includes_head=True)
        # Place the label slightly beyond the arrow tip.
        ax.text(rotated_x * 1.2, rotated_y * 1.2, f'P{pos}', fontsize=10)

    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-1.5, 1.5)
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)
    ax.set_xlabel('Dimension 0')
    ax.set_ylabel('Dimension 1')
    ax.set_title('RoPE Rotation Visualization\n(Vector rotated by position)',
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

visualize_rope_2d()

In [None]:
# 3D visualization of position embeddings
def visualize_3d_embeddings():
    """Plot sinusoidal position encodings projected to 3D with PCA"""
    # Imported inside the function so a missing scikit-learn only affects
    # this visualization.
    from sklearn.decomposition import PCA

    # Sinusoidal embeddings for 20 positions, reduced from 64 to 3 dimensions
    encodings = sinusoidal_position_encoding(seq_len=20, d_model=64)
    coords = PCA(n_components=3).fit_transform(encodings.numpy())

    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection='3d')

    # One marker per position, coloured by position index
    scatter = ax.scatter(coords[:, 0], coords[:, 1], coords[:, 2],
                         c=range(len(coords)), cmap='viridis', s=100)

    # Faint segment between each position and its successor
    for start, stop in zip(coords[:-1], coords[1:]):
        ax.plot([start[0], stop[0]],
                [start[1], stop[1]],
                [start[2], stop[2]],
                'gray', alpha=0.3, linewidth=1)

    ax.set_xlabel('PC1')
    ax.set_ylabel('PC2')
    ax.set_zlabel('PC3')
    ax.set_title('3D Position Embedding Space\n(First 20 positions)',
                 fontsize=14, fontweight='bold')
    plt.colorbar(scatter, ax=ax, label='Position')
    plt.tight_layout()
    plt.show()

try:
    !pip install scikit-learn -q
    visualize_3d_embeddings()
except:
    print("Skipping 3D visualization (scikit-learn not available)")

## Lesson 3: RoPE vs Absolute Encoding Benchmark (15 mins)

Compare RoPE and absolute position encoding on a simple task.

In [None]:
class AbsolutePositionEncoding(nn.Module):
    """Additive absolute position encoding backed by a sinusoidal table"""
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        # Leading axis of size 1 lets the table broadcast over the batch.
        table = sinusoidal_position_encoding(max_len, d_model).unsqueeze(0)
        # Stored as a buffer: follows the module across devices, not trained.
        self.register_buffer('pe', table)

    def forward(self, x):
        # Add the first x.size(1) rows of the table to the input.
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]

class RoPEPositionEncoding(nn.Module):
    """Rotary Position Embedding applied directly to the input tensor"""
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        # Rotation table for positions 0..max_len-1, kept as a buffer.
        self.register_buffer('freqs_cis', precompute_freqs_cis(d_model, max_len))

    def forward(self, x):
        # Simplified: rotates x itself rather than separate Q/K projections.
        # Only the rotations for the first x.size(1) positions are used.
        seq_len = x.size(1)
        rotation = self.freqs_cis[:seq_len]
        return apply_rope(x, rotation)

# Simple attention with position encoding
def attention_with_pos(x, pos_encoder, use_rope=False):
    """Run scaled dot-product self-attention on position-encoded inputs

    Queries and keys come from ``pos_encoder(x)``; the values are the raw
    ``x``. Returns ``(output, attention_weights)``.
    """
    if not use_rope:
        # Absolute encoding: one encoded tensor serves as both Q and K.
        q = k = pos_encoder(x)
    else:
        # RoPE: the encoder is invoked separately for Q and for K.
        q = pos_encoder(x)
        k = pos_encoder(x)

    # Scale the dot products by sqrt(feature dimension) before the softmax.
    scale = np.sqrt(x.size(-1))
    logits = (q @ k.transpose(-2, -1)) / scale
    weights = logits.softmax(dim=-1)
    return weights @ x, weights

# Test on sequence
# NOTE: rebinds the notebook-level d_model and seq_len once more.
d_model = 32
seq_len = 16
batch_size = 1
x = torch.randn(batch_size, seq_len, d_model)

# Absolute encoding
# (seq_len is passed as max_len, so the table covers exactly this sequence)
abs_encoder = AbsolutePositionEncoding(d_model, seq_len)
out_abs, attn_abs = attention_with_pos(x, abs_encoder, use_rope=False)

# RoPE encoding
rope_encoder = RoPEPositionEncoding(d_model, seq_len)
out_rope, attn_rope = attention_with_pos(x, rope_encoder, use_rope=True)

# Shapes only; the attention maps themselves are plotted in the next cell.
print(f"Input shape: {x.shape}")
print(f"Absolute encoding output shape: {out_abs.shape}")
print(f"RoPE output shape: {out_rope.shape}")
print(f"\nAbsolute attention shape: {attn_abs.shape}")
print(f"RoPE attention shape: {attn_rope.shape}")

In [None]:
# Compare attention patterns
# Side-by-side heatmaps of the attention weights for the first (only) batch
# element: absolute encoding on the left, RoPE on the right.
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

sns.heatmap(attn_abs[0].detach().numpy(), ax=axes[0], cmap='YlOrRd', cbar=True)
axes[0].set_title('Absolute Position Encoding\nAttention Pattern', fontweight='bold')
axes[0].set_xlabel('Key Position')
axes[0].set_ylabel('Query Position')

sns.heatmap(attn_rope[0].detach().numpy(), ax=axes[1], cmap='YlOrRd', cbar=True)
axes[1].set_title('RoPE Position Encoding\nAttention Pattern', fontweight='bold')
axes[1].set_xlabel('Key Position')
axes[1].set_ylabel('Query Position')

plt.tight_layout()
plt.show()

# NOTE(review): the "More efficient" line below is not measured anywhere in
# this notebook — confirm or reword before relying on it.
print("\nKey Differences:")
print("- Absolute: Fixed patterns based on absolute positions")
print("- RoPE: Relative position awareness, better for variable-length sequences")
print("- RoPE: More efficient (no addition, just rotation)")

## Key Takeaways

✅ **Position Encoding**: Injects positional information into token embeddings

✅ **Sinusoidal Encoding**: Fixed, additive encoding used in original Transformer

✅ **RoPE**: Rotates query/key vectors by position, enabling relative position awareness

✅ **RoPE Advantages**: Better extrapolation, more efficient, relative position awareness

## Next Steps

Continue to **Module 1.3: Feed-Forward Networks & Normalization** to learn about:
- SwiGLU activation functions
- LayerNorm vs RMSNorm
- Memory profiling