# Rotary Position Embedding (RoPE)

## Introduction

Self-attention is permutation-invariant with respect to the order of its context tokens. Each query token attends to all of its context tokens simultaneously, so, unlike RNNs and CNNs, Transformers have no built-in notion of word order.

Like earlier positional encoding techniques, **Rotary Position Embedding (RoPE)** was developed to give the model access to token position information. It was introduced by Jianlin Su et al. in their 2021 paper *[RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864)* and has been widely adopted in well-known open-source LLMs such as GPT-NeoX (EleutherAI), LLaMA/LLaMA-2/LLaMA-3, and the Qwen series, sometimes with modifications like YaRN, NTK-aware scaling, or Dynamic NTK.

RoPE is now the de facto technique for positional encoding in LLMs. One key advantage over sinusoidal positional encoding is that RoPE preserves the norm of the Q/K vectors while injecting position, because it rotates them instead of adding to them. An additive sinusoidal encoding can distort the magnitude and direction of the embedding and thus interfere with attention scores.
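To make the norm claim concrete, here is a small sketch in plain Python. The helper names `rotate_pair` and `add_sinusoid` are illustrative and do not come from any library, and the input values are made up. It rotates one 2D channel pair by an angle and, separately, adds a sine/cosine offset to the same pair, then compares the vector lengths.

```python
import math

def rotate_pair(x, y, angle):
    """Rotate the 2D pair (x, y) counter-clockwise by `angle` radians (RoPE-style)."""
    c, s = math.cos(angle), math.sin(angle)
    return x * c - y * s, x * s + y * c

def add_sinusoid(x, y, angle):
    """Additive sinusoidal-style encoding on one pair: add (sin, cos) of the angle."""
    return x + math.sin(angle), y + math.cos(angle)

pair = (3.0, 4.0)   # made-up channel pair with length 5
angle = 1.234       # arbitrary angle, for illustration only

rotated = rotate_pair(*pair, angle)
added = add_sinusoid(*pair, angle)

norm_original = math.hypot(*pair)
norm_rotated = math.hypot(*rotated)
norm_added = math.hypot(*added)

# The rotated pair keeps its length; the additive version does not.
print(norm_original, norm_rotated, norm_added)
```

The rotated pair has the same length as the original up to floating-point error, while the additive variant changes it.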

In practice, RoPE is considered especially beneficial when scaling to billions of parameters:
- It tends to enable more stable training
- It makes sequence length extension (context extension) practical, typically through scaling variants such as those listed above
- It tends to generalize better over long ranges

## The Math

Similar to sinusoidal positional encoding, RoPE maps token positions to multiple sine-cosine signals of different frequencies. Each signal is tied to one consecutive pair of model or head dimensions (channels $2m$ and $2m+1$). With position $p$ and dimension pair index $m$, the angle $\phi_m(p)$ is defined as:

$$
\phi_m(p) = p \cdot B^{-\frac{2m}{d}}
$$

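- $\phi$: the angle passed to sine/cosine, i.e. the position multiplied by the per-pair frequency $B^{-\frac{2m}{d}}$
- $p$: token position
- $B$: the base constant, 10000 by default
- $m$: dimension pair index, $m = 0, \dots, d/2 - 1$ (pair $m$ covers channels $2m$ and $2m+1$)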
- $d$: dimension size of the model or one attention head
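As a quick numeric illustration of this formula, the sketch below evaluates $\phi_m(p)$ for a tiny dimension size. The function name `rope_angles` and the choice $d = 8$ are assumptions made for the example only.

```python
def rope_angles(p, d, base=10000.0):
    """Return [phi_0(p), ..., phi_{d/2-1}(p)] with phi_m(p) = p * base**(-2m/d)."""
    assert d % 2 == 0, "d must be even so channels can be paired"
    return [p * base ** (-2.0 * m / d) for m in range(d // 2)]

d = 8  # tiny dimension size, chosen so the frequencies come out as powers of 10
for p in (0, 1, 2):
    print(p, [round(a, 4) for a in rope_angles(p, d)])
# 0 [0.0, 0.0, 0.0, 0.0]
# 1 [1.0, 0.1, 0.01, 0.001]
# 2 [2.0, 0.2, 0.02, 0.002]
```

Low pair indices rotate quickly with position, while high pair indices rotate slowly, which is what lets different pairs encode position at different scales.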

In sinusoidal positional encoding, $\phi$ is the angle used to compute the sine-cosine values, which are then **added** to the embeddings. In RoPE, it is the angle of the **rotation matrix** that multiplies each Q/K pair $(q_{2m}, q_{2m+1})$ or $(k_{2m}, k_{2m+1})$.

$$
\begin{bmatrix}
q'_{2m}\\ q'_{2m+1}
\end{bmatrix}
=
\begin{bmatrix}
\cos\phi_m(p) & -\sin\phi_m(p) \\
\sin\phi_m(p) & \cos\phi_m(p)
\end{bmatrix} \cdot
\begin{bmatrix}
q_{2m}\\ q_{2m+1}
\end{bmatrix}
$$
$$
\begin{bmatrix}
k'_{2m}\\ k'_{2m+1}
\end{bmatrix}
=
\begin{bmatrix}
\cos\phi_m(p) & -\sin\phi_m(p) \\
\sin\phi_m(p) & \cos\phi_m(p)
\end{bmatrix} \cdot
\begin{bmatrix}
k_{2m}\\ k_{2m+1}
\end{bmatrix}
$$
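A consequence of rotating both Q and K this way is that their dot product depends only on the difference between the two positions, not on the absolute positions. The sketch below checks this numerically in plain Python. The helpers `rope` and `dot` and the vectors `q` and `k` are made up for illustration and are not the document's implementation.

```python
import math

def rope(vec, p, base=10000.0):
    """Rotate each consecutive pair (vec[2m], vec[2m+1]) by phi_m(p) = p * base**(-2m/d)."""
    d = len(vec)
    assert d % 2 == 0, "vector length must be even"
    out = []
    for m in range(d // 2):
        phi = p * base ** (-2.0 * m / d)
        c, s = math.cos(phi), math.sin(phi)
        x, y = vec[2 * m], vec[2 * m + 1]
        out += [x * c - y * s, x * s + y * c]
    return out

def dot(a, b):
    """Plain dot product of two equal-length lists."""
    return sum(x * y for x, y in zip(a, b))

q = [0.5, -1.0, 2.0, 0.3]   # made-up query vector
k = [1.5, 0.2, -0.7, 1.1]   # made-up key vector

# Same offset (query 3 positions after the key), different absolute positions.
score_near = dot(rope(q, 5), rope(k, 2))
score_far = dot(rope(q, 105), rope(k, 102))

# A different offset (1 instead of 3) gives a different score.
score_other = dot(rope(q, 5), rope(k, 4))

print(score_near, score_far, score_other)
```

The first two scores agree up to floating-point error, while the third differs, which is the sense in which RoPE encodes relative position inside the attention score.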

### Math recap

A recap of the rotation matrix and the relevant trig identities (in case high-school math has gotten rusty).

#### Angle-sum and -difference Trigonometric Identities

$$
\begin{aligned}
\sin(\alpha + \beta) &= \sin\alpha\cos\beta + \cos\alpha\sin\beta \\
\sin(\alpha - \beta) &= \sin\alpha\cos\beta - \cos\alpha\sin\beta \\
\cos(\alpha + \beta) &= \cos\alpha\cos\beta - \sin\alpha\sin\beta \\
\cos(\alpha - \beta) &= \cos\alpha\cos\beta + \sin\alpha\sin\beta \\
\tan(\alpha + \beta) &= \frac{\tan\alpha + \tan\beta}{1 - \tan\alpha\tan\beta} \\
\tan(\alpha - \beta) &= \frac{\tan\alpha - \tan\beta}{1 + \tan\alpha\tan\beta}
\end{aligned}
$$

#### Rotation matrix

The following is the 2D **rotation matrix**. To rotate a vector counter-clockwise by the angle $\phi$, left-multiply the vector by this matrix. Writing the vector in polar coordinates makes the rotation easy to see.

$$
\begin{bmatrix}
\cos(\phi) & -\sin(\phi)\\
\sin(\phi) & \cos(\phi)
\end{bmatrix}
$$

A point in Cartesian coordinates $A(x, y)$ can also be written in polar form as $A(r\cos\alpha, r\sin\alpha)$, where $r$ is the distance from the origin and $\alpha$ is the angle measured counter-clockwise from the positive x-axis. Substituting this into the rotation matrix product:

$$ 
A = 
\begin{bmatrix}
x\\ y
\end{bmatrix}
= 
\begin{bmatrix}
r\cos\alpha\\ r\sin\alpha
\end{bmatrix}
$$
$$ 
A' = 
\begin{bmatrix}
x'\\ y'
\end{bmatrix}
=
\begin{bmatrix}
\cos\phi & -\sin\phi\\
\sin\phi & \cos\phi
\end{bmatrix} \cdot
\begin{bmatrix}
x\\ y
\end{bmatrix}
$$
$$
=
\begin{bmatrix}
\cos\phi & -\sin\phi\\
\sin\phi & \cos\phi
\end{bmatrix} \cdot
\begin{bmatrix}
r\cos\alpha\\ r\sin\alpha
\end{bmatrix}
$$
$$
=
\begin{bmatrix}
r(\cos\phi\cos\alpha-\sin\phi\sin\alpha)\\
r(\sin\phi\cos\alpha+\cos\phi\sin\alpha)
\end{bmatrix}
$$

Applying the angle-sum identities for $\cos(\alpha+\beta)$ and $\sin(\alpha+\beta)$, the above equals

$$
A' = 
\begin{bmatrix}
r\cos(\alpha + \phi) \\
r\sin(\alpha + \phi)
\end{bmatrix}
$$

This is exactly $A(r\cos\alpha, r\sin\alpha)$ rotated counter-clockwise by the angle $\phi$: the radius $r$ is unchanged and the angle becomes $\alpha + \phi$.
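The derivation can be checked numerically with a few lines of plain Python. The specific values ($r = 2$, $\alpha = 30°$, $\phi = 45°$) are arbitrary choices for the sketch.

```python
import math

r, alpha, phi = 2.0, math.radians(30), math.radians(45)

# Point A in Cartesian coordinates, built from its polar form.
x, y = r * math.cos(alpha), r * math.sin(alpha)

# Apply the 2x2 rotation matrix by hand.
x_rot = math.cos(phi) * x - math.sin(phi) * y
y_rot = math.sin(phi) * x + math.cos(phi) * y

# Convert the rotated point back to polar form.
r_rot = math.hypot(x_rot, y_rot)
alpha_rot = math.atan2(y_rot, x_rot)

print(r_rot, math.degrees(alpha_rot))  # approximately 2.0 and 75.0
```

The radius stays at 2 and the angle moves from 30° to 75°, matching $r\cos(\alpha+\phi)$ and $r\sin(\alpha+\phi)$.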

#### Complex number representation

The original RoFormer paper formulates RoPE in the **complex number** domain: each pair of consecutive channels of a query/key vector is viewed as the real and imaginary parts of a complex number, which is then multiplied by a unit complex number $e^{i\theta}$.

In the complex plane, multiplying a complex number $z = x + iy$ by $e^{i\theta}$ rotates it counter-clockwise by $\theta$ radians without changing its length:

$$
z' = z \cdot e^{i\theta}
$$

According to **Euler's formula**:

$$
e^{i\theta} = \cos\theta + i\sin\theta
$$

So multiplying $z$ by $e^{i\theta}$ gives us:

$$
\begin{aligned}
z' &= z \cdot e^{i\theta} \\
&= (x + iy)(\cos\theta + i\sin\theta) \\
&= x\cos\theta + ix\sin\theta + iy\cos\theta + i^2y\sin\theta \\
&= x\cos\theta + ix\sin\theta + iy\cos\theta + (-1)y\sin\theta \\
&= x\cos\theta - y\sin\theta + i(x\sin\theta + y\cos\theta)
\end{aligned}
$$

We can rewrite this in matrix form, where the real part is the new $x'$ and the imaginary part is the new $y'$:

$$
z' = 
\begin{bmatrix}
x' \\ y'
\end{bmatrix} = 
\begin{bmatrix}
\cos\theta & -\sin\theta \\
\sin\theta & \cos\theta
\end{bmatrix} \cdot
\begin{bmatrix}
x \\ y
\end{bmatrix}
$$

This tells us that multiplying $z = x + iy$ by $e^{i\theta}$ in the complex plane is exactly the same as left-multiplying the 2D vector $[x, y]^T$ by the 2D rotation matrix $R_{\theta}$.
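The equivalence is easy to confirm with Python's built-in complex numbers. This is only a sketch with arbitrary illustrative values for `x`, `y`, and `theta`.

```python
import cmath
import math

x, y, theta = 1.0, 2.0, 0.7   # arbitrary illustrative values

# Complex view: z' = z * e^{i*theta}
z = complex(x, y)
z_rot = z * cmath.exp(1j * theta)

# Matrix view: [x'; y'] = [[cos, -sin], [sin, cos]] @ [x; y]
x_rot = math.cos(theta) * x - math.sin(theta) * y
y_rot = math.sin(theta) * x + math.cos(theta) * y

print(z_rot.real, x_rot)    # real part matches x'
print(z_rot.imag, y_rot)    # imaginary part matches y'
print(abs(z), abs(z_rot))   # the length is unchanged
```

Both views produce the same pair of numbers up to floating-point error, so an implementation is free to choose whichever representation is more convenient.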

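The formulation in the original RoFormer paper is:
$$ z' = z \cdot e^{ip\theta_m} $$

where $p$ is the sequence position, $m$ indexes the 2D pairs of model dimension channels, and $\theta_m = B^{-\frac{2m}{d}}$ is the per-pair frequency, so that $p\theta_m = \phi_m(p)$.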

### Real World Implementation Math

Although the original paper uses complex notation, many open-source LLM implementations (for LLaMA, Qwen, etc.) implement RoPE in real space using the explicit 2x2 rotation formula:

$$
\begin{aligned}
q'_{2m} &= q_{2m}\cos\phi_m - q_{2m+1}\sin\phi_m \\
q'_{2m+1} &= q_{2m}\sin\phi_m + q_{2m+1}\cos\phi_m
\end{aligned}
$$

Possible reasons include:
- Many frameworks (e.g. PyTorch) had limited or no complex-number dtype support
- Explicit real-number operations can be faster on GPUs because no type conversions are needed
- It is easier to integrate into existing fused attention kernels like FlashAttention

## Implementation

Here is a minimal, straightforward implementation of RoPE that follows these formulas:

$$
\begin{gathered}
\theta_m = B^{-\frac{2m}{d_{\text{head}}}} \\
\phi_m(p) = p\cdot\theta_m \\
\begin{bmatrix}
x'_{2m} \\
x'_{2m+1}
\end{bmatrix}
=
\begin{bmatrix}
\cos\phi_m(p) & -\sin\phi_m(p) \\
\sin\phi_m(p) & \cos\phi_m(p)
\end{bmatrix}
\cdot
\begin{bmatrix}
x_{2m} \\
x_{2m+1}
\end{bmatrix}
\end{gathered}
$$

In [None]:
import torch

def build_rope_cache(max_seq_len: int, head_dim: int, base: float = 10000.0, 
                     pos_scale: float = 1.0, # set this >1.0 to compress positions for context extension (linear position interpolation)
                     device = None,
                     dtype = torch.float32):
    """
    Implements theta_m = base^(-2m/d_head), phi_m(p) = p/pos_scale * theta_m
    Build RoPE cache for cos(phi_m) and sin(phi_m) values
    
    Returns:
        cos, sin: [max_seq_len, head_dim//2], lookup tables of cosine and sine values ready to slice/gather and broadcast
    """
    assert head_dim % 2 == 0, "head_dim must be even for RoPE"

    # Create dimension indices m=0, 1, ..., head_dim//2-1 for each frequency
    m = torch.arange((head_dim // 2), device=device, dtype=dtype) 
    
    # Calculate all dimension frequencies theta_m = base^(-2m/d_head)
    theta_m = base ** (-2.0 * m / head_dim) # [head_dim//2]

    # Create sequence position indices with scaling
    positions = torch.arange(max_seq_len, device=device, dtype=dtype) / pos_scale # [max_seq_len]

    # Calculate angles for each position-dimension combination pair phi_m(pos) = pos * theta_m
    angles = positions[:, None] * theta_m[None, :] # [max_seq_len, head_dim//2]

    # Calculate cos, sin values and fill the lookup cache table
    cos = torch.cos(angles) # [max_seq_len, head_dim//2]
    sin = torch.sin(angles) # [max_seq_len, head_dim//2]

    return cos, sin

def apply_rope(x, cos, sin):
    """
    Apply RoPE rotation to input tensor using precomputed cos/sin values

    Inputs:
        x: input tensor - [batch_size, seq_len, num_heads, head_dim]
        cos: precomputed cosine values - [max_seq_len, head_dim//2]
        sin: precomputed sine values - [max_seq_len, head_dim//2]

    Returns:
        rotated tensor - [batch_size, seq_len, num_heads, head_dim]
    """
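    # Slice cos/sin to match input seq_len and add a singleton num_heads axis
    # so they broadcast against [batch_size, seq_len, num_heads, head_dim//2]
    seq_len = x.size(1)
    cos = cos[:seq_len, None, :]  # [seq_len, 1, head_dim//2]
    sin = sin[:seq_len, None, :]  # [seq_len, 1, head_dim//2]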

    # Split input to odd/even pairs
    x_even = x[..., 0::2]   # [batch_size, seq_len, num_heads, head_dim//2]
    x_odd = x[..., 1::2]    # [batch_size, seq_len, num_heads, head_dim//2]    

    # Apply rotation matrix multiplication:
    # [x_rotated_even; x_rotated_odd] = [cos -sin; sin cos] * [x_even; x_odd]
    x_rotated_even = cos * x_even - sin * x_odd
    x_rotated_odd = sin * x_even + cos * x_odd

    # Interleave back to original input size
    x_rotated = torch.empty_like(x)
    x_rotated[..., 0::2] = x_rotated_even
    x_rotated[..., 1::2] = x_rotated_odd

    return x_rotated
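
The `pos_scale` argument above divides positions before the angles are computed, so position $p$ with scale $s$ receives the angles that position $p/s$ would get without scaling. This is often called linear position interpolation. The sketch below mirrors only the angle computation in plain Python. The helper `angles_at` and the training length `trained_max` are hypothetical names and values chosen for illustration.

```python
def angles_at(position, head_dim, base=10000.0, pos_scale=1.0):
    """Angles phi_m for one position, mirroring the cache: (p / pos_scale) * base**(-2m/d)."""
    return [(position / pos_scale) * base ** (-2.0 * m / head_dim)
            for m in range(head_dim // 2)]

head_dim = 8
trained_max = 2048   # hypothetical training context length

# With pos_scale=2, position 4094 is mapped to the angles of position 2047,
# so it stays inside the range of angles produced for positions < trained_max.
scaled = angles_at(4094, head_dim, pos_scale=2.0)
unscaled = angles_at(2047, head_dim)

max_diff = max(abs(a - b) for a, b in zip(scaled, unscaled))
print(max_diff)  # effectively zero
```

The trade-off is that neighbouring positions end up closer together in angle, which is why scaled models are usually fine-tuned briefly at the longer context length.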


A sample MHA module using this implementation: 

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class RopeMHA(nn.Module):
    def __init__(self, d_model, num_heads, max_length=2048, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.head_dim = d_model // num_heads

        # initialize Q, K, V, output projections
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

        # Build RoPE cache (the helper registers rope_cos / rope_sin as buffers)
        self._build_rope_cache(max_length, self.head_dim)
    
    def _build_rope_cache(self, max_length: int, head_dim: int, 
                          base: float = 10000.0,
                          pos_scale: float = 1.0,
                          device: torch.device | None = None,
                          dtype = torch.float32):
        assert head_dim % 2 == 0, "head_dim must be even"
        # Calculate dimension half indices vector
        m = torch.arange(head_dim // 2, device=device, dtype=dtype)
        theta_m = base ** (- 2 * m / head_dim)
        # Build positions vector
        pos = torch.arange(max_length, device=device, dtype=dtype) / pos_scale
        # Compute angles of all dim-pos combinations
        angles = pos[:, None] * theta_m[None, :]
        # Compute cos/sin values and cache them
        cos = torch.cos(angles)
        sin = torch.sin(angles)

        self.register_buffer("rope_cos", cos)
        self.register_buffer("rope_sin", sin)
    
    def _apply_rope(self, x):
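        # Get cos and sin from cache; x is [batch_size, seq_len, num_heads, head_dim]
        seq_len = x.size(1)
        # Add a singleton num_heads axis so the tables broadcast: [seq_len, 1, head_dim//2]
        cos = self.rope_cos[:seq_len, None, :]
        sin = self.rope_sin[:seq_len, None, :]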
        # Split x to even and odd pairs
        x_even = x[..., 0::2]
        x_odd = x[..., 1::2]
        # Rotate
        x_rotated_even = cos * x_even - sin * x_odd
        x_rotated_odd = sin * x_even + cos * x_odd
        # Build the full rotated tensor
        x_rotated = torch.empty_like(x)
        x_rotated[..., 0::2] = x_rotated_even
        x_rotated[..., 1::2] = x_rotated_odd

        return x_rotated
        

    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.size()

        # Project to get Q, K, V, and reshape for multi-head
        Q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)

        # Apply RoPE to Q, K
        Q = self._apply_rope(Q)
        K = self._apply_rope(K)

        # Transpose for parallel multi-head attention
        Q = Q.transpose(1,2)
        K = K.transpose(1,2)
        V = V.transpose(1,2)

        attn_scores = torch.matmul(Q, K.transpose(-2,-1)) / math.sqrt(self.head_dim)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
        # Normalize over the key axis (last dim), then apply attention dropout
        attn_weights = self.dropout(F.softmax(attn_scores, dim=-1))
        attn_values = torch.matmul(attn_weights, V)

        output = self.out_proj(attn_values.transpose(1,2).reshape(batch_size, seq_len, d_model))

        return output