# Understanding Rotary Position Embedding (RoPE)

This notebook explains **RoPE (Rotary Position Embedding)**, the positional encoding mechanism used in LLaMA, GPT-NeoX, and other modern LLMs, step by step.

## Why RoPE?

In transformers, self-attention on its own is **permutation-equivariant**: without positional information, shuffling the input tokens merely shuffles the outputs, so the model cannot tell "dog bites man" from "man bites dog". We need to encode position.

Traditional approaches:
- **Learned absolute positional embeddings** (e.g., BERT, GPT-2): add a learned position vector to each token embedding
- **Sinusoidal embeddings** (original Transformer): add fixed sin/cos patterns, chosen in the hope of extrapolating to longer sequences

**RoPE** does something smarter: it encodes **relative** position directly into the attention mechanism by rotating query and key vectors in a way that makes their dot product depend on distance.

## Key Insight

Instead of adding positional info to embeddings, RoPE **rotates** Q and K vectors by angles proportional to their positions. The dot product between rotated vectors naturally captures relative distance.

Mathematical property:
```
Q_m · K_n = Rotate(Q, m) · Rotate(K, n) = g(Q, K, m - n)
```

For fixed `Q` and `K`, the attention score depends on the positions only through the *difference* `m - n`, not on their absolute values.
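
This property can be checked numerically. The following is a minimal sketch for a single 2D pair, assuming a rotation angle of `theta * position` with `theta = 1`; the helper names `rotate` and `score` are illustrative and not part of the notebook's later code:

```python
import math

def rotate(v, angle):
    """Rotate the 2D vector v = (x, y) counter-clockwise by `angle` radians."""
    x, y = v
    c, s = math.cos(angle), math.sin(angle)
    return (x * c - y * s, x * s + y * c)

def score(q, k, m, n, theta=1.0):
    """Dot product of q rotated to position m and k rotated to position n."""
    qm = rotate(q, theta * m)
    kn = rotate(k, theta * n)
    return qm[0] * kn[0] + qm[1] * kn[1]

q, k = (1.0, 0.5), (-0.3, 2.0)

# Same offset m - n = -3 at very different absolute positions
s_a = score(q, k, m=0, n=3)
s_b = score(q, k, m=40, n=43)
# A different offset (m - n = -4) gives a different score
s_c = score(q, k, m=0, n=4)

print(s_a, s_b, s_c)
```

The first two printed values should agree up to floating-point error, while the third differs.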


In [1]:
import torch
import torch.nn as nn
import numpy as np

# Set seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

## Step 1: 2D Rotation Basics

RoPE is built on **2D rotations**. Let's start with the simplest case.

A 2D vector `(x, y)` rotated by angle `θ` becomes:
```
x' = x * cos(θ) - y * sin(θ)
y' = x * sin(θ) + y * cos(θ)
```

In matrix form:
```
[x']   [cos(θ)  -sin(θ)] [x]
[y'] = [sin(θ)   cos(θ)] [y]
```
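
Two properties of this matrix matter for the rest of the notebook: rotations compose by adding angles, and they preserve vector length. A small sketch using only the standard library (the helpers `rotation_matrix` and `matvec` are hypothetical names chosen for illustration):

```python
import math

def rotation_matrix(theta):
    """2x2 counter-clockwise rotation matrix as nested lists."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s],
            [s,  c]]

def matvec(mat, vec):
    """Multiply a 2x2 matrix by a length-2 vector."""
    return [mat[0][0] * vec[0] + mat[0][1] * vec[1],
            mat[1][0] * vec[0] + mat[1][1] * vec[1]]

v = [1.0, 0.5]
a, b = 0.7, 1.1

# Rotating by a, then by b, equals a single rotation by a + b
two_steps = matvec(rotation_matrix(b), matvec(rotation_matrix(a), v))
one_step = matvec(rotation_matrix(a + b), v)

# Rotations preserve vector length
norm_before = math.hypot(*v)
norm_after = math.hypot(*one_step)

print(two_steps, one_step)
print(norm_before, norm_after)
```

Additivity of angles is what later turns two position-dependent rotations into a single rotation by the position difference.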


In [2]:
def rotate_2d(x, y, theta):
    """Rotate a 2D point (x, y) by angle theta."""
    x_rot = x * np.cos(theta) - y * np.sin(theta)
    y_rot = x * np.sin(theta) + y * np.cos(theta)
    return x_rot, y_rot

# Example
x, y = 1.0, 0.5
print(f"Original: ({x}, {y})")
print(f"After 90° rotation: {rotate_2d(x, y, np.pi/2)}")
print(f"After 180° rotation: {rotate_2d(x, y, np.pi)}")

Original: (1.0, 0.5)
After 90° rotation: (-0.49999999999999994, 1.0)
After 180° rotation: (-1.0, -0.4999999999999999)


## Step 2: Extending to Higher Dimensions

Attention heads have dimension `d_head` (e.g., 64, 128). RoPE applies 2D rotation to **pairs of dimensions**:

- Dimensions 0 and 1 form a pair → rotate by `θ_0 * pos`
- Dimensions 2 and 3 form a pair → rotate by `θ_1 * pos`
- Dimensions 4 and 5 form a pair → rotate by `θ_2 * pos`
- ...

Each pair has a **different base frequency** `θ_i`, computed as:
```
θ_i = 1 / (base^(2i / d))
```

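The default is `base = 10000`, the same constant used in sinusoidal embeddings. Because `θ_i` shrinks as `i` grows, early pairs rotate fastest (high frequency, sensitive to fine-grained, local position differences), while later pairs rotate slowly (low frequency, tracking coarse, long-range position).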
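
To make the frequency formula concrete, the sketch below converts each `θ_i` into a wavelength, i.e. the number of positions a pair needs to complete one full 2π rotation. The helper names `pair_frequencies` and `wavelengths` are assumptions made for this illustration:

```python
import math

def pair_frequencies(d_head, base=10000.0):
    """theta_i = base^(-2i / d_head) for pair index i = 0 .. d_head/2 - 1."""
    return [base ** (-2 * i / d_head) for i in range(d_head // 2)]

def wavelengths(freqs):
    """Number of positions each pair needs to complete one full turn."""
    return [2 * math.pi / f for f in freqs]

demo_freqs = pair_frequencies(d_head=8)
for i, (f, w) in enumerate(zip(demo_freqs, wavelengths(demo_freqs))):
    print(f"pair {i}: theta = {f:.4g}, wavelength ~ {w:.1f} positions")
```

For `d_head = 8`, pair 0 completes a turn roughly every 6.3 positions, while pair 3 needs roughly 6283 positions, so the rotation speed falls by a factor of 1000 from the first pair to the last.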


In [3]:
def compute_freqs(d_head, base=10000.0):
    """
    Compute rotation frequencies for RoPE.
    
    Args:
        d_head: Dimension of each attention head (must be even)
        base: Base for frequency computation (default 10000)
    
    Returns:
        freqs: Array of shape (d_head // 2,) with rotation frequencies
    """
    # Even dimension indices 2i for pair index i = 0, 1, ..., d_head//2 - 1
    two_i = np.arange(0, d_head, 2)  # [0, 2, 4, ...]
    
    # Compute θ_i = 1 / (base^(2i / d))
    freqs = 1.0 / (base ** (two_i / d_head))
    
    return freqs

# Example: 8-dimensional head
d_head = 8
freqs = compute_freqs(d_head)

print(f"Head dimension: {d_head}")
print(f"Number of rotation pairs: {len(freqs)}")
print(f"Frequencies: {freqs}")
print(f"\nFrequency ratio (first/last): {freqs[0] / freqs[-1]:.2f}")
print("Higher frequencies (early dims) capture fine-grained, local patterns.")
print("Lower frequencies (late dims) capture global patterns.")

Head dimension: 8
Number of rotation pairs: 4
Frequencies: [1.    0.1   0.01  0.001]

Frequency ratio (first/last): 1000.00
Higher frequencies (early dims) capture fine-grained, local patterns.
Lower frequencies (late dims) capture global patterns.


## Step 3: Applying RoPE to a Vector

Given:
- A vector `x` of dimension `d_head`
- Position `pos` in the sequence
- Frequencies `θ = [θ_0, θ_1, ..., θ_{d/2-1}]`

We split `x` into pairs and rotate each pair:
```python
for i in range(d_head // 2):
    x[2*i], x[2*i+1] = rotate_2d(x[2*i], x[2*i+1], θ_i * pos)
```

The rotation angle for each pair is `θ_i * pos` — higher positions = more rotation.


In [4]:
def apply_rope_naive(x, pos, freqs):
    """
    Apply RoPE rotation to vector x at position pos.
    
    Args:
        x: Vector of shape (d_head,)
        pos: Position in sequence (integer)
        freqs: Rotation frequencies, shape (d_head // 2,)
    
    Returns:
        Rotated vector of shape (d_head,)
    """
    d_head = len(x)
    x_rotated = x.copy()
    
    for i in range(d_head // 2):
        # Get the pair
        x0 = x[2*i]
        x1 = x[2*i + 1]
        
        # Rotation angle for this pair
        theta = freqs[i] * pos
        
        # Apply 2D rotation
        cos_theta = np.cos(theta)
        sin_theta = np.sin(theta)
        
        x_rotated[2*i]     = x0 * cos_theta - x1 * sin_theta
        x_rotated[2*i + 1] = x0 * sin_theta + x1 * cos_theta
    
    return x_rotated

# Example
d_head = 8
x = np.random.randn(d_head)
freqs = compute_freqs(d_head)

print("Original vector:")
print(x)
print("\nAfter RoPE at position 0:")
print(apply_rope_naive(x, pos=0, freqs=freqs))
print("\nAfter RoPE at position 5:")
print(apply_rope_naive(x, pos=5, freqs=freqs))
print("\nAfter RoPE at position 100:")
print(apply_rope_naive(x, pos=100, freqs=freqs))
print("\nNotice: Same vector, different positions → different rotations")

Original vector:
[ 0.49671415 -0.1382643   0.64768854  1.52302986 -0.23415337 -0.23413696
  1.57921282  0.76743473]

After RoPE at position 0:
[ 0.49671415 -0.1382643   0.64768854  1.52302986 -0.23415337 -0.23413696
  1.57921282  0.76743473]

After RoPE at position 5:
[ 0.00831403 -0.51553161 -0.16177924  1.64710287 -0.22215877 -0.24554714
  1.57535592  0.77532117]

After RoPE at position 100:
[ 0.3583137  -0.3707469   0.28510338 -1.63028723  0.07050585 -0.32353801
  1.4947077   0.92125896]

Notice: Same vector, different positions → different rotations


## Step 4: Efficient Implementation with Complex Numbers

The naive loop is slow. Instead, we use **complex number multiplication**.

Key insight: 2D rotation is complex multiplication:
```
(x + iy) * e^(iθ) = (x + iy) * (cos(θ) + i*sin(θ))
```

Steps:
1. Precompute `freqs_cis = e^(i * θ_i * pos)` for all positions
2. View vector as complex: `[x0, x1, x2, x3, ...]` → `[x0+i*x1, x2+i*x3, ...]`
3. Multiply by `freqs_cis`
4. View back as real

This is the approach used in Meta's reference LLaMA implementation.
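
As a quick sanity check of the key insight, the following sketch compares the explicit sin/cos formula with multiplication by `e^(iθ)` using Python's built-in complex numbers. The function names `rotate_formula` and `rotate_complex` are illustrative only:

```python
import cmath
import math

def rotate_formula(x, y, theta):
    """Rotate (x, y) by theta using the explicit sin/cos formula."""
    return (x * math.cos(theta) - y * math.sin(theta),
            x * math.sin(theta) + y * math.cos(theta))

def rotate_complex(x, y, theta):
    """Rotate (x, y) by theta via multiplication with e^(i*theta)."""
    z = complex(x, y) * cmath.exp(1j * theta)
    return (z.real, z.imag)

x0, y0, angle = 1.0, 0.5, 0.3
via_formula = rotate_formula(x0, y0, angle)
via_complex = rotate_complex(x0, y0, angle)

print(via_formula)
print(via_complex)
```

Both routes should print the same pair up to floating-point error, which is why a whole vector of pairs can be rotated with one elementwise complex multiply.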


In [6]:
def precompute_freqs_cis(d_head, max_seq_len=2048, base=10000.0):
    """
    Precompute complex rotation factors e^(i * θ_j * pos) for all positions.
    
    Args:
        d_head: Head dimension (must be even)
        max_seq_len: Maximum sequence length
        base: Frequency base
    
    Returns:
        Complex tensor of shape (max_seq_len, d_head // 2)
        Each entry is e^(i * θ_j * pos)
    """
    # Compute base frequencies
    freqs = compute_freqs(d_head, base)
    
    # Create position range
    positions = np.arange(max_seq_len)
    
    # Outer product: (max_seq_len, 1) x (1, d_head//2) = (max_seq_len, d_head//2)
    angles = np.outer(positions, freqs)
    
    # Convert to complex: e^(i*θ) = cos(θ) + i*sin(θ)
    freqs_cis = np.cos(angles) + 1j * np.sin(angles)
    
    return freqs_cis

# Precompute for sequence length 100
d_head = 8
max_seq_len = 100
freqs_cis = precompute_freqs_cis(d_head, max_seq_len)

print(f"Precomputed freqs_cis shape: {freqs_cis.shape}")
print(f"Type: {freqs_cis.dtype}")
print(f"\nFor position 0: {freqs_cis[0]}")
print(f"For position 10: {freqs_cis[10]}")
print("\nThis precomputation happens once and gets reused for all sequences.")

Precomputed freqs_cis shape: (100, 4)
Type: complex128

For position 0: [1.+0.j 1.+0.j 1.+0.j 1.+0.j]
For position 10: [-0.83907153-0.54402111j  0.54030231+0.84147098j  0.99500417+0.09983342j
  0.99995   +0.00999983j]

This precomputation happens once and gets reused for all sequences.


In [7]:
def apply_rope_complex(x, freqs_cis_pos):
    """
    Apply RoPE using complex multiplication (fast version).
    
    Args:
        x: Vector of shape (d_head,)
        freqs_cis_pos: Precomputed rotation for this position, shape (d_head // 2,)
    
    Returns:
        Rotated vector of shape (d_head,)
    """
    # View as complex: [x0, x1, x2, x3] -> [x0+i*x1, x2+i*x3]
    x_pairs = x.reshape(-1, 2)
    x_complex = x_pairs[:, 0] + 1j * x_pairs[:, 1]
    
    # Apply rotation via complex multiplication
    x_rotated_complex = x_complex * freqs_cis_pos
    
    # Convert back to real
    x_rotated = np.stack([x_rotated_complex.real, x_rotated_complex.imag], axis=-1)
    x_rotated = x_rotated.flatten()
    
    return x_rotated

# Verify both implementations match
x = np.random.randn(8)
freqs = compute_freqs(8)
freqs_cis = precompute_freqs_cis(8, 100)

pos = 42
result_naive = apply_rope_naive(x, pos, freqs)
result_complex = apply_rope_complex(x, freqs_cis[pos])

print("Naive implementation:")
print(result_naive)
print("\nComplex implementation:")
print(result_complex)
print("\nDifference (should be ~0):")
print(np.abs(result_naive - result_complex).max())
assert np.allclose(result_naive, result_complex)
print("\n✅ Both implementations produce identical results!")

Naive implementation:
[ 0.68505083  0.21326734 -0.17872323  0.63223269  1.00109309 -1.64833239
 -1.69978754 -0.63421692]

Complex implementation:
[ 0.68505083  0.21326734 -0.17872323  0.63223269  1.00109309 -1.64833239
 -1.69978754 -0.63421692]

Difference (should be ~0):
0.0

✅ Both implementations produce identical results!


## Step 5: The Magic — Relative Position Encoding

Here's why RoPE is special. When we compute attention:
```
score = Q_m · K_n
```

With RoPE:
```
score = Rotate(Q, m) · Rotate(K, n)
```

Due to rotation properties, this depends only on `m - n`, the **relative distance**.

Let's verify this experimentally.
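
The reason is easiest to see in complex form. For one pair, the real dot product equals `Re(q_i * conj(k_i))`, so after rotation the two position factors combine into `e^(i * θ_i * m) * e^(-i * θ_i * n) = e^(i * θ_i * (m - n))`. The sketch below compares the directly computed score with this closed form; all helper names (`to_complex_pairs`, `rope_pairs`, `pair_dot`, `score_from_offset`) and the example vectors are assumptions for illustration:

```python
import cmath

def to_complex_pairs(x):
    """[x0, x1, x2, x3, ...] -> [x0 + i*x1, x2 + i*x3, ...]"""
    return [complex(x[j], x[j + 1]) for j in range(0, len(x), 2)]

def rope_pairs(x, pos, freqs):
    """Rotate each complex pair of x by the angle freqs[i] * pos."""
    return [z * cmath.exp(1j * f * pos)
            for z, f in zip(to_complex_pairs(x), freqs)]

def pair_dot(a, b):
    """Real dot product of two vectors stored as lists of complex pairs."""
    return sum((za * zb.conjugate()).real for za, zb in zip(a, b))

def score_from_offset(q, k, offset, freqs):
    """Closed form: sum_i Re(q_i * conj(k_i) * e^(i * theta_i * offset))."""
    qc, kc = to_complex_pairs(q), to_complex_pairs(k)
    return sum((zq * zk.conjugate() * cmath.exp(1j * f * offset)).real
               for zq, zk, f in zip(qc, kc, freqs))

demo_freqs = [1.0, 0.01]           # theta_i for d_head = 4, base = 10000
q_demo = [0.5, -1.0, 2.0, 0.3]
k_demo = [1.5, 0.2, -0.7, 1.1]
m_pos, n_pos = 12, 7

direct = pair_dot(rope_pairs(q_demo, m_pos, demo_freqs),
                  rope_pairs(k_demo, n_pos, demo_freqs))
closed = score_from_offset(q_demo, k_demo, m_pos - n_pos, demo_freqs)

print(direct, closed)
```

The two printed numbers should match up to floating-point error, and the closed form never sees `m_pos` or `n_pos` individually.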


In [8]:
# Create two random vectors
q = np.random.randn(8)
k = np.random.randn(8)
freqs_cis = precompute_freqs_cis(8, 200)

# Compute attention scores for different (m, n) pairs with same distance
distance = 5
pairs = [(m, m + distance) for m in (0, 10, 50, 100)]  # All have distance = 5

print("Testing: Same distance, different absolute positions")
print("="*60)
scores = []
for m, n in pairs:
    q_rot = apply_rope_complex(q, freqs_cis[m])
    k_rot = apply_rope_complex(k, freqs_cis[n])
    score = np.dot(q_rot, k_rot)
    scores.append(score)
    print(f"Positions ({m:3d}, {n:3d}), distance={n-m}: score = {score:.6f}")

print(f"\nScore variance: {np.var(scores):.10f}")
assert np.allclose(scores, scores[0]), "scores should depend only on m - n"
print("✅ All scores nearly identical — depends only on distance!")

# Now try different distances
print("\n" + "="*60)
print("Testing: Different distances from same starting position")
print("="*60)
m = 10
distances = [0, 1, 5, 10, 20, 50]
for d in distances:
    n = m + d
    q_rot = apply_rope_complex(q, freqs_cis[m])
    k_rot = apply_rope_complex(k, freqs_cis[n])
    score = np.dot(q_rot, k_rot)
    print(f"Distance {d:2d}: score = {score:.6f}")

print("\n✅ Scores vary with distance, not absolute position!")

Testing: Same distance, different absolute positions
============================================================
Positions (  0,   5), distance=5: score = -1.844244
Positions ( 10,  15), distance=5: score = -1.844244
Positions ( 50,  55), distance=5: score = -1.844244
Positions (100, 105), distance=5: score = -1.844244

Score variance: 0.0000000000
✅ All scores nearly identical — depends only on distance!

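============================================================
Testing: Different distances from same starting position
============================================================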
Distance  0: score = -2.393375
Distance  1: score = -2.512100
Distance  5: score = -1.844244
Distance 10: score = -1.953404
Distance 20: score = -1.591033
Distance 50: score = -4.243366

✅ Scores vary with distance, not absolute position!


## Step 6: PyTorch Implementation

Now let's implement RoPE in PyTorch, matching the production code style.


In [9]:
class RotaryEmbedding(nn.Module):
    """
    Rotary Position Embedding (RoPE).
    
    Precomputes rotation frequencies and applies them to Q/K tensors.
    """
    
    def __init__(self, head_dim: int, max_seq_len: int = 2048, base: float = 10000.0):
        super().__init__()
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.base = base
        
        # Precompute frequencies
        freqs_cis = self._precompute_freqs_cis()
        # Register as buffer (not a parameter, moves with model to device)
        self.register_buffer('freqs_cis', freqs_cis, persistent=False)
    
    def _precompute_freqs_cis(self):
        # Compute base frequencies: θ_i = 1 / (base^(2i / d))
        i = torch.arange(0, self.head_dim, 2, dtype=torch.float32)
        freqs = 1.0 / (self.base ** (i / self.head_dim))
        
        # Positions
        positions = torch.arange(self.max_seq_len, dtype=torch.float32)
        
        # Outer product: (max_seq_len, head_dim // 2)
        angles = torch.outer(positions, freqs)
        
        # Convert to complex: e^(iθ) = cos(θ) + i*sin(θ)
        freqs_cis = torch.polar(torch.ones_like(angles), angles)
        
        return freqs_cis
    
    @staticmethod
    def apply_rotary(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        """
        Apply rotary embedding to input tensor.
        
        Args:
            x: Input tensor of shape (batch, heads, seq_len, head_dim)
            freqs_cis: Rotation frequencies of shape (seq_len, head_dim // 2)
        
        Returns:
            Rotated tensor of same shape as x
        """
        # Reshape x to pair up dimensions: (b, h, t, d) -> (b, h, t, d//2, 2)
        x_reshaped = x.float().reshape(*x.shape[:-1], -1, 2)
        
        # Convert to complex: (b, h, t, d//2)
        x_complex = torch.view_as_complex(x_reshaped)
        
        # Broadcast freqs_cis: (t, d//2) -> (1, 1, t, d//2)
        freqs_cis = freqs_cis.unsqueeze(0).unsqueeze(0)
        
        # Apply rotation via complex multiplication
        x_rotated = x_complex * freqs_cis
        
        # Convert back to real: (b, h, t, d//2, 2) -> (b, h, t, d)
        x_out = torch.view_as_real(x_rotated).flatten(-2)
        
        return x_out.type_as(x)
    
    def forward(self, seq_len: int, offset: int = 0) -> torch.Tensor:
        """
        Get rotation frequencies for a given sequence.
        
        Args:
            seq_len: Length of current sequence
            offset: Starting position (for KV cache)
        
        Returns:
            freqs_cis slice of shape (seq_len, head_dim // 2)
        """
        return self.freqs_cis[offset : offset + seq_len]

# Test the implementation
rope = RotaryEmbedding(head_dim=64, max_seq_len=512)
print(f"RoPE module initialized: head_dim={rope.head_dim}, max_seq_len={rope.max_seq_len}")
print(f"Precomputed freqs_cis shape: {rope.freqs_cis.shape}")

# Create dummy Q/K tensors: (batch, heads, seq_len, head_dim)
batch_size = 2
num_heads = 8
seq_len = 16
head_dim = 64

q = torch.randn(batch_size, num_heads, seq_len, head_dim)
k = torch.randn(batch_size, num_heads, seq_len, head_dim)

# Get frequencies for this sequence
freqs_cis = rope(seq_len)
print(f"\nFrequencies for seq_len={seq_len}: {freqs_cis.shape}")

# Apply RoPE
q_rotated = rope.apply_rotary(q, freqs_cis)
k_rotated = rope.apply_rotary(k, freqs_cis)

print(f"\nOriginal Q shape: {q.shape}")
print(f"Rotated Q shape: {q_rotated.shape}")
print(f"Shape preserved: {q.shape == q_rotated.shape}")

RoPE module initialized: head_dim=64, max_seq_len=512
Precomputed freqs_cis shape: torch.Size([512, 32])

Frequencies for seq_len=16: torch.Size([16, 32])

Original Q shape: torch.Size([2, 8, 16, 64])
Rotated Q shape: torch.Size([2, 8, 16, 64])
Shape preserved: True


## Step 7: Using RoPE in Attention

In practice, RoPE is applied in the attention layer:

```python
# Inside MultiHeadAttention.forward()
qkv = self.qkv_proj(x)  # (b, t, 3*d_out)
q, k, v = split_and_reshape(qkv)  # (b, h, t, d_h)

# Apply RoPE BEFORE KV-caching
if freqs_cis is not None:
    q = RotaryEmbedding.apply_rotary(q, freqs_cis)
    k = RotaryEmbedding.apply_rotary(k, freqs_cis)

# Then compute attention as usual
scores = q @ k.transpose(-2, -1) / (d_h ** 0.5)
attn = torch.softmax(scores, dim=-1)
output = attn @ v
```

**Key**: RoPE is applied to Q and K but **NOT V**. Position should influence only the attention weights (which tokens attend to which), not the content carried by the values.


In [10]:
def simple_attention_with_rope(q, k, v, rope_module, seq_len):
    """
    Simplified attention with RoPE.
    
    Args:
        q, k, v: Tensors of shape (batch, heads, seq_len, head_dim)
        rope_module: RotaryEmbedding instance
        seq_len: Sequence length
    
    Returns:
        (output, attn): attention output with the same shape as v, and the
        attention weights of shape (batch, heads, seq_len, seq_len)
    """
    # Get RoPE frequencies
    freqs_cis = rope_module(seq_len)
    
    # Apply RoPE to Q and K (NOT V)
    q = rope_module.apply_rotary(q, freqs_cis)
    k = rope_module.apply_rotary(k, freqs_cis)
    
    # Scaled dot-product attention
    head_dim = q.size(-1)
    scores = (q @ k.transpose(-2, -1)) / (head_dim ** 0.5)
    attn = torch.softmax(scores, dim=-1)
    output = attn @ v
    
    return output, attn

# Demo
rope = RotaryEmbedding(head_dim=32)
seq_len = 10

q = torch.randn(1, 4, seq_len, 32)
k = torch.randn(1, 4, seq_len, 32)
v = torch.randn(1, 4, seq_len, 32)

output, attn = simple_attention_with_rope(q, k, v, rope, seq_len)

print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attn.shape}")
print(f"\nAttention matrix for first head (first 5x5):")
print(attn[0, 0, :5, :5].detach().numpy())

Output shape: torch.Size([1, 4, 10, 32])
Attention weights shape: torch.Size([1, 4, 10, 10])

Attention matrix for first head (first 5x5):
[[0.03593255 0.08071355 0.23664361 0.05539827 0.04589043]
 [0.27033275 0.15126061 0.06344225 0.02910307 0.06042084]
 [0.02558617 0.06199391 0.15081383 0.08664306 0.08330887]
 [0.19717795 0.14761339 0.02668178 0.0051381  0.01747172]
 [0.17322601 0.11085868 0.04897138 0.1102547  0.03070857]]


## Summary

**RoPE (Rotary Position Embedding)** encodes position by rotating Q and K vectors:

1. **Split dimensions into pairs** — each pair gets rotated independently
2. **Different frequencies** — early dims use high freq (local), late dims use low freq (global)
3. **Position-dependent rotation** — angle = `θ_i * position`
4. **Relative encoding** — attention score `Q_m · K_n` depends only on `m - n`
5. **Efficient via complex numbers** — precompute `e^(iθ)`, multiply as complex

### Advantages over absolute embeddings:
- ✅ Naturally encodes **relative** position ("how far apart are tokens?")
- ✅ Can be applied at positions beyond the training length (learned absolute embeddings cannot), though quality typically degrades there without further adaptation
- ✅ No learned parameters — purely geometric
- ✅ Works seamlessly with **KV-cache** (position baked into cached keys)
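
The KV-cache point can be illustrated with a small sketch: each key is rotated once at its absolute position and stored, and a later query rotated at its own position still produces scores that depend only on relative offsets. The helper `rope_rotate`, the tiny example vectors, and the list-based `kv_cache` are assumptions for illustration, not the notebook's PyTorch module:

```python
import cmath

def rope_rotate(x, pos, freqs):
    """Rotate consecutive pairs of x by freqs[i] * pos; returns a flat list."""
    out = []
    for i, f in enumerate(freqs):
        z = complex(x[2 * i], x[2 * i + 1]) * cmath.exp(1j * f * pos)
        out += [z.real, z.imag]
    return out

def dot(a, b):
    """Plain real dot product of two equal-length lists."""
    return sum(u * v for u, v in zip(a, b))

demo_freqs = [1.0, 0.01]  # theta_i for d_head = 4, base = 10000
keys = [[1.0, 0.0, 0.5, 0.5], [0.2, -1.0, 0.3, 0.8], [-0.4, 0.6, 1.2, -0.1]]
query = [0.7, 0.1, -0.5, 0.9]

# Prefill: rotate each key once at its absolute position and cache the result
kv_cache = [rope_rotate(k, pos, demo_freqs) for pos, k in enumerate(keys)]

# Decode step: the new token sits at position len(kv_cache) (the "offset")
offset = len(kv_cache)
q_rot = rope_rotate(query, offset, demo_freqs)
cached_scores = [dot(q_rot, k_rot) for k_rot in kv_cache]

# Reference: unrotated query against keys rotated by the relative offset only
reference = [dot(query, rope_rotate(k, pos - offset, demo_freqs))
             for pos, k in enumerate(keys)]

print(cached_scores)
print(reference)
```

Because the two lists agree, cached keys never need to be re-rotated as the sequence grows; only the new query needs the rotation for its current position.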

### Where RoPE is used:
- LLaMA (all versions)
- GPT-NeoX
- PaLM
- Many modern open-source LLMs

---

## Next Steps

1. Integrate `RotaryEmbedding` into your `MultiHeadAttention` layer
2. Precompute `freqs_cis` once in the model constructor
3. Pass `freqs_cis` slice to each attention layer during forward pass
4. Train and see if RoPE improves long-range dependencies!
