# YaRN (Yet another RoPE extensioN)

## Introduction

[YaRN (Yet another RoPE extensioN method)](https://arxiv.org/abs/2309.00071) is a compute-efficient way to extend the context window of RoPE-based models.

### Background

RoPE injects position information through rotation, which lets the model learn position without additional parameters while preserving vector norms. However, RoPE has a fundamental limitation: models can only handle sequences no longer than those they have seen during training.

RoPE encodes position through rotations whose angle $\phi_m(p)$ depends on a per-dimension frequency $\theta_m$:
$$
\phi_m(p) = p \cdot \theta_m \quad \text{where} \quad \theta_m = B^{\frac{-2m}{d_{head}}}
$$ 
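Here $p$ is the token position, $m$ is the dimension-pair index, and $B$ is the base (typically 10000). When position $p$ exceeds the longest sequence seen in training, the model encounters rotation angles it has not learned, which can hurt inference performance.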
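
To make the formula concrete, here is a minimal sketch in plain Python (kept dependency-free on purpose). The helper names `rope_frequencies` and `rope_angle`, the tiny `head_dim=8`, and the positions 2047 and 4096 are illustrative assumptions, not part of any library.

```python
def rope_frequencies(head_dim, base=10000.0):
    """theta_m = base^(-2m / head_dim) for m = 0 .. head_dim/2 - 1."""
    return [base ** (-2.0 * m / head_dim) for m in range(head_dim // 2)]

def rope_angle(p, theta_m):
    """Rotation angle phi_m(p) = p * theta_m."""
    return p * theta_m

# Hypothetical tiny head so the numbers are easy to read.
theta = rope_frequencies(head_dim=8)

# Compare the last position of a 2048-token training window with a
# position beyond it.
for p in (2047, 4096):
    angles = [round(rope_angle(p, t), 3) for t in theta]
    print(f"p={p}: {angles}")
```

With these assumed values the slowest dimension has $\theta_3 \approx 0.001$, so training on 2048 tokens only covers angles up to roughly 2.05 rad in that dimension, while position 4096 lands at roughly 4.1 rad, an angle that dimension never produced during training.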

RoPE's frequency design is inherently multi-scale. Across the whole embedding dimensions:
- **High frequencies** (low dimension $m$): capture fine-grained positional relationships, allowing the model to learn *local attention patterns*
- **Low frequencies** (high dimension $m$): capture coarse-grained positional relationships, allowing the model to learn *global attention patterns*

For each token, RoPE encodes its position with the rotation frequencies of all embedding dimensions. This naturally becomes a multi-scale position signature, and the model learns to use different attention heads to focus on different frequency ranges: some specialize in local relationships, others in global relationships. On the flip side, this makes naive position scaling to account for unseen context lengths very difficult for RoPE.
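
The multi-scale structure is easier to see through the wavelength of each dimension, that is, the number of tokens it takes for that dimension to complete one full rotation ($2\pi / \theta_m$). The sketch below assumes `head_dim=128` and `base=10000`; the function name `rope_wavelengths` is hypothetical.

```python
import math

def rope_wavelengths(head_dim, base=10000.0):
    """Tokens needed for each dimension pair to complete one full rotation."""
    wavelengths = []
    for m in range(head_dim // 2):
        theta_m = base ** (-2.0 * m / head_dim)
        wavelengths.append(2.0 * math.pi / theta_m)
    return wavelengths

wavelengths = rope_wavelengths(head_dim=128)

# Print a few dimensions from fastest (m = 0) to slowest (m = 63).
for m in (0, 16, 32, 48, 63):
    print(f"m={m:2d}  wavelength ~ {wavelengths[m]:,.1f} tokens")
```

Under these assumptions the wavelengths grow geometrically from about 6 tokens in the fastest dimension to tens of thousands of tokens in the slowest one, which is why a single scaling rule rarely suits every dimension.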

### Previous Scaling Attempts

#### Linear Interpolation (Position Interpolation)

Introduced by Chen et al. from Meta in their 2023 paper [Extending Context Window of Large Language Models via Position Interpolation](https://arxiv.org/abs/2306.15595), Linear Interpolation compresses position space uniformly to account for the target context length:

$$
s = \frac{L'}{L} \quad , \quad m' = \frac{m}{s}  \quad \text{(m is position index)}
$$


This paper uses a different notation from the RoPE paper, which can be confusing: in the RoPE paper, $m$ stands for the dimension index and $p$ for the token position. To stay consistent, let's keep $p$ for the position index and rewrite the full Linear Interpolation RoPE formula:

$$
s = \frac{L'}{L} \quad , \quad p' = \frac{p}{s}
$$
$$
\phi'(p,m) = p' \cdot B^{\frac{-2m}{d_{head}}} = \frac{p}{s} \cdot B^{\frac{-2m}{d_{head}}}
$$


In Python:

In [1]:
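import torch

def linear_interpolation_rope(max_seq_len, head_dim, original_max_len, base=10000.0):
    scale = max_seq_len / original_max_len  # s = L' / L
    positions = torch.arange(max_seq_len) / scale  # Compress positions: p' = p / s
    # The rest of RoPE is unchanged: theta_m = base^(-2m / head_dim)
    m = torch.arange(head_dim // 2)
    theta_m = base ** (-2.0 * m / head_dim)
    angles = positions[:, None] * theta_m[None, :]  # (max_seq_len, head_dim // 2)
    return torch.cos(angles), torch.sin(angles)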

The problem with this approach is that uniform compression shrinks every frequency by the same factor. High-frequency dimensions are hit hardest: neighboring tokens become harder to tell apart, which hurts fine-grained local attention.
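
A small numeric sketch of this trade-off, in plain Python. The helper `pi_angle` and the lengths 2048 and 8192 are assumptions chosen for illustration.

```python
def pi_angle(p, m, head_dim, scale, base=10000.0):
    """Position Interpolation angle: (p / s) * base^(-2m / head_dim)."""
    theta_m = base ** (-2.0 * m / head_dim)
    return (p / scale) * theta_m

head_dim, L, L_ext = 128, 2048, 8192
s = L_ext / L  # scale factor s = L' / L

# The largest compressed position stays inside the trained range [0, L).
max_compressed_position = (L_ext - 1) / s

# Angle step between adjacent tokens in the fastest dimension (m = 0),
# without and with interpolation.
step_original = pi_angle(1, 0, head_dim, scale=1.0)
step_interpolated = pi_angle(1, 0, head_dim, scale=s)

print(f"max compressed position: {max_compressed_position}")
print(f"adjacent-token step: {step_original} rad -> {step_interpolated} rad")
```

Compression keeps every position inside the trained range, but the angular gap between neighboring tokens in the fastest dimension shrinks by the full factor $s$.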

#### NTK-aware Scaling

NTK (Neural Tangent Kernel)-aware scaling emerged from the Reddit community r/LocalLLaMA in mid-2023. It was developed through community experimentation and implemented in popular inference tools such as [llama.cpp](https://github.com/ggml-org/llama.cpp), [text-generation-webui](https://github.com/oobabooga/text-generation-webui), and various Hugging Face model implementations.

Instead of scaling positions, this approach scales $\theta_m$ by an amount that depends on dimension $m$: it slows the angular growth most at lower frequencies (higher dimensions) and leaves the highest frequencies (lowest dimensions) almost untouched.

The formula:

$$
\theta_m = B^{\frac{-2m}{d_{head}}} \quad \text{(RoPE)}
$$
$$
\theta_m^{\text{NTK}} = \theta_m \cdot s^{\frac{-2m}{d_{head}-2}} \quad \text{where} \quad s = \frac{L'}{L}
$$
$$
\phi_m^{NTK}(p) = {p} \cdot \theta_m^{\text{NTK}}
$$

This formula was derived empirically rather than through a formal mathematical derivation. The intuition behind it:
- Low dimensions (high frequency): $m$ is small, so frequencies get minimal scaling, preserving fine-grained local attention
- High dimensions (low frequency): $m$ is large, so frequencies get aggressive scaling. This works because the coarse-grained global positional relationships (represented in high dimensions) are more robust to compression
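
NTK-aware scaling is commonly written as a change of base, $B' = B \cdot s^{d_{head}/(d_{head}-2)}$, which is algebraically the same as multiplying each $\theta_m$ by $s^{-2m/(d_{head}-2)}$. The plain-Python sketch below checks that equivalence numerically. The function names and the example values (`head_dim=128`, `s=4`) are assumptions for illustration.

```python
import math

def ntk_frequencies(head_dim, scale, base=10000.0):
    """theta_m * s^(-2m / (head_dim - 2)): per-dimension form."""
    return [
        base ** (-2.0 * m / head_dim) * scale ** (-2.0 * m / (head_dim - 2))
        for m in range(head_dim // 2)
    ]

def ntk_frequencies_via_base(head_dim, scale, base=10000.0):
    """Same frequencies, obtained by enlarging the base instead."""
    new_base = base * scale ** (head_dim / (head_dim - 2))
    return [new_base ** (-2.0 * m / head_dim) for m in range(head_dim // 2)]

head_dim, s = 128, 4.0
per_dim = ntk_frequencies(head_dim, s)
via_base = ntk_frequencies_via_base(head_dim, s)
original = ntk_frequencies(head_dim, 1.0)  # s = 1 recovers plain RoPE

# Ratio of scaled to original frequency at the two ends of the spectrum.
print("fastest dim ratio:", per_dim[0] / original[0])
print("slowest dim ratio:", per_dim[-1] / original[-1])
print("forms agree:", all(math.isclose(a, b) for a, b in zip(per_dim, via_base)))
```

The fastest dimension keeps its frequency, while the slowest one is divided by $s$, matching what Linear Interpolation would do to that dimension alone.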

In [2]:
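def ntk_scaling_rope(max_seq_len, head_dim, original_max_len, base=10000.0):
    scale = max_seq_len / original_max_len  # s = L' / L
    m = torch.arange(head_dim // 2)
    # original theta_m
    theta_m = base ** (-2.0 * m / head_dim)
    # NTK-aware scaling: m = 0 is left untouched, the last dimension is divided by s
    theta_m_ntk = theta_m * (scale ** (-2.0 * m / (head_dim - 2)))
    # The rest of RoPE is unchanged
    positions = torch.arange(max_seq_len)
    angles = positions[:, None] * theta_m_ntk[None, :]  # (max_seq_len, head_dim // 2)
    return torch.cos(angles), torch.sin(angles)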

#### Dynamic NTK Scaling

At inference time, multiple forward passes are often performed with varying sequence lengths from 1 to the maximum context length (e.g., autoregressive token generation). Throughout this inference cycle, we can either:
- apply the same positional embedding scaling (PI or NTK-aware scaling) at every step, or
- update the positional embedding scaling for every sequence length from 1 to max_seq_length

The first method makes the model perform sub-optimally for sequence lengths shorter than max_seq_length and degrade abruptly once the sequence length exceeds max_seq_length. The second method allows the model to degrade gracefully. When combined with NTK-aware scaling, the second method is called **Dynamic NTK Scaling**.
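
A minimal sketch of the second method, assuming the scale is recomputed as $s = \max(1, l'/L)$ for the current sequence length $l'$, which is how the YaRN paper describes dynamic scaling. The function names are hypothetical, and real inference stacks may compute the dynamic factor differently.

```python
def dynamic_scale(current_len, original_max_len):
    """s = max(1, l' / L), recomputed for the current sequence length l'."""
    return max(1.0, current_len / original_max_len)

def ntk_frequencies(head_dim, scale, base=10000.0):
    """NTK-aware frequencies, written as a change of base."""
    new_base = base * scale ** (head_dim / (head_dim - 2))
    return [new_base ** (-2.0 * m / head_dim) for m in range(head_dim // 2)]

original_max_len = 2048
for current_len in (512, 2048, 4096, 8192):
    s = dynamic_scale(current_len, original_max_len)
    slowest = ntk_frequencies(128, s)[-1]
    print(f"len={current_len:5d}  s={s:.1f}  slowest theta={slowest:.3e}")
```

While the sequence fits in the original window, the scale stays at 1 and the frequencies are plain RoPE. Scaling only kicks in, and only as much as needed, once the sequence grows past $L$.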

### YaRN

YaRN was the first formal academic paper that rigorously analyzed and improved upon community experimentation to address the fundamental issue of extending the context window of RoPE models: different frequency bands require different scaling strategies. It was introduced by Peng et al. in their 2023 paper [YaRN: Efficient Context Window Extension of Large Language Models](https://arxiv.org/abs/2309.00071).

Standard RoPE fails when extending beyond the maximum context window that the model has seen in training. To address this, Linear Interpolation (LI) applies a uniform position compression, and NTK-aware scaling applies dimension-dependent scaling. YaRN improves on both with:
- **NTK-by-parts interpolation** with a ramp function: NTK-by-parts partitions RoPE frequencies into regions with different scaling strategies, while the ramp function smoothly transitions between the interpolation and extrapolation regions to avoid discontinuities.
  - High frequencies (local patterns): extrapolation (preserve as-is)
  - Low frequencies (global patterns): interpolation (compress by $s$)
- **Attention scaling temperature**: add a temperature factor `t` to maintain proper attention entropy
   


With this design YaRN achieves strong long-context performance with minimal fine-tuning (only ~400 steps of continued pretraining/fine-tuning).

The core YaRN formula:

**NTK-by-parts Interpolation with Ramp Function**


Following the paper, the high-frequency dimensions are not interpolated at all, while the lowest-frequency dimensions are always interpolated. In particular:
- For small wavelengths $\lambda$ (much smaller than $L$, i.e., high frequencies / low dimension $m$), don't interpolate (leave $\theta_m$ unchanged)
- For large wavelengths $\lambda$ (equal to or larger than $L$, i.e., low frequencies / high dimension $m$), interpolate (divide $\theta_m$ by $s$) and avoid any extrapolation
- For dimensions in between, use a bit of both (a weighted sum)

So the NTK-by-parts interpolation formula is a weighted sum of interpolation (scaling) and extrapolation (not scaling). The weight is dependent on dimension $m$:

$$
\theta_m' = (1 - \gamma (r(m)))\frac{\theta_m}{s} + \gamma(r(m))\theta_m
$$

Where:

- $s$ is the context length scale factor
- $r(m)$ is the ratio between the original context size $L$ and the wavelength at dimension $m$, $\lambda_m$:
$$
s = \frac{L'}{L}, \quad r(m) = \frac{L}{\lambda_m}
$$ 

- The wavelength $\lambda_m$ is calculated from that dimension's frequency $\theta_m$:
$$
\lambda_m = \frac{2\pi}{\theta_m} = \frac{2\pi}{B^{\frac{-2m}{d_{head}}}} = 2\pi B^{\frac{2m}{d_{head}}}
$$


- $\gamma(r)$ is the **Ramp function**:
$$
\gamma(r) = 
\begin{cases} 
0, & \text{if } r < \alpha \\ 
1, & \text{if } r > \beta \\   
\frac{r - \alpha}{\beta - \alpha}, & \text{otherwise }
\end{cases}
$$

Where $\alpha$ and $\beta$ are two tunable parameters that should be tuned on a case-by-case basis. E.g., the authors found that for Llama family models, good values are $\alpha=1$ and $\beta=32$.
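
As a worked example of the ramp, the plain-Python sketch below evaluates $r(m)$, $\gamma$, and the blended $\theta_m'$ for three dimensions. The configuration (`head_dim=128`, `orig_len=4096`, `scale=4`) and the names `ramp` and `yarn_theta` are assumptions for illustration.

```python
import math

def ramp(r, alpha=1.0, beta=32.0):
    """gamma(r): 0 below alpha, 1 above beta, linear in between."""
    if r < alpha:
        return 0.0
    if r > beta:
        return 1.0
    return (r - alpha) / (beta - alpha)

def yarn_theta(m, head_dim, orig_len, scale, base=10000.0, alpha=1.0, beta=32.0):
    """Return (theta_m', gamma) for dimension m under NTK-by-parts."""
    theta_m = base ** (-2.0 * m / head_dim)
    wavelength = 2.0 * math.pi / theta_m
    r = orig_len / wavelength  # rotations completed within the original context
    g = ramp(r, alpha, beta)
    return (1.0 - g) * theta_m / scale + g * theta_m, g

head_dim, orig_len, scale = 128, 4096, 4.0
for m in (0, 32, 63):
    theta_m = 10000.0 ** (-2.0 * m / head_dim)
    theta_new, g = yarn_theta(m, head_dim, orig_len, scale)
    print(f"m={m:2d}  gamma={g:.3f}  theta'/theta={theta_new / theta_m:.3f}")
```

With these assumed values, the fastest dimension has $\gamma = 1$ and keeps its frequency, the slowest has $\gamma = 0$ and is divided by $s$, and a middle dimension such as $m = 32$ falls on the ramp and gets a blend of the two.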

**Attention Scaling Temperature**

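YaRN also introduces a temperature $t$ on the logits before the attention softmax, where $q_i$ and $k_j$ are the query and key vectors at token positions $i$ and $j$:

$$
\text{softmax}\left(\frac{q_i k_j^\top}{t\sqrt{d_k}}\right)
$$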

This addresses the shift in the attention score distribution that occurs when the context length is extended. With a longer context window, the attention mechanism sees more tokens, which affects the distribution of attention scores:
- More tokens to attend to -> attention is spread thinner
- Average attention entropy increases

This can dilute attention patterns and weaken the focus on important tokens, which degrades model performance.
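
The effect can be illustrated with a toy experiment in plain Python. The logits here are random Gaussian numbers rather than real attention scores, and `t = 0.8` is an arbitrary temperature below 1, so this sketch only shows the direction of the effect.

```python
import math
import random

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    mx = max(logits)
    exps = [math.exp(x - mx) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(probs):
    """Shannon entropy in nats."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)

random.seed(0)
# Toy logits for one query over a short and an extended context.
logits_short = [random.gauss(0.0, 1.0) for _ in range(2048)]
logits_long = logits_short + [random.gauss(0.0, 1.0) for _ in range(8192 - 2048)]

t = 0.8  # hypothetical temperature < 1; dividing by t sharpens the softmax
entropy_short = entropy(softmax(logits_short))
entropy_long = entropy(softmax(logits_long))
entropy_long_sharpened = entropy(softmax([x / t for x in logits_long]))

print(f"short context entropy:     {entropy_short:.3f}")
print(f"long context entropy:      {entropy_long:.3f}")
print(f"long context, t={t}:      {entropy_long_sharpened:.3f}")
```

Adding more tokens raises the entropy of the attention distribution, and dividing the logits by a temperature below 1 pulls it back down.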

For LLaMA and Llama 2 models, the authors recommend the following, which gives $t < 1$ for $s > 1$ and therefore sharpens the softmax:
$$
\sqrt{\frac{1}{t}} = 0.1 \ln(s) + 1
$$
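
For a feel of the magnitudes, the sketch below evaluates this recommendation for a few scale factors. The function names `yarn_mscale` and `yarn_temperature` are hypothetical, and leaving the scale at 1 when $s \le 1$ is an assumption for the no-extension case.

```python
import math

def yarn_mscale(scale):
    """sqrt(1/t) = 0.1 * ln(s) + 1; assumed to be 1 when s <= 1."""
    if scale <= 1.0:
        return 1.0
    return 0.1 * math.log(scale) + 1.0

def yarn_temperature(scale):
    """Recover t from sqrt(1/t)."""
    return 1.0 / yarn_mscale(scale) ** 2

for s in (1.0, 4.0, 16.0, 32.0):
    print(f"s={s:4.0f}  sqrt(1/t)={yarn_mscale(s):.4f}  t={yarn_temperature(s):.4f}")
```

Because $\sqrt{1/t}$ multiplies both the query and the key, applying it to each of them scales the logits by $1/t$ without touching the attention code itself.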


## Implementation

Let's implement YaRN following the formulas.

In [15]:
import torch
import math

def build_yarn_rope_cache(
    dim: int,
    max_seq_len: int,   # L'
    orig_seq_len: int,  # L
    base: float = 10000.0,
    alpha: float = 1.0,
    beta: float = 32.0,
    device = None,
    dtype = torch.float32
):
    # Scale factor s = L'/L
    s = max_seq_len / orig_seq_len

    # Original RoPE frequencies
    # theta_m = B^(-2m/d_head)
    m = torch.arange(dim // 2, device=device, dtype=dtype)
    theta_m = base ** (-2.0 * m / dim)

    # Wavelengths
    # lambda_m = 2pi / theta_m
    lambda_m = 2 * math.pi / theta_m

    # ratio r(m) = L / lambda_m
    r_m = orig_seq_len / lambda_m

    # ramp function gamma(r)
    gamma = torch.zeros_like(r_m)
    # if r < alpha: gamma = 0
    # if r > beta: gamma = 1
    # else: gamma = (r - alpha) / (beta - alpha)
    mask_middle = (r_m >= alpha) & (r_m <= beta)
    mask_high = r_m > beta
    gamma[mask_middle] = (r_m[mask_middle] - alpha) / (beta - alpha)
    gamma[mask_high] = 1.0

    # NTK-by-parts interpolation
    # theta_m' = (1 - gamma(r(m))) * theta_m / s + gamma(r(m)) * theta_m
    theta_yarn = (1 - gamma) * (theta_m / s) + gamma * theta_m

    # Attention scaling temperature
    # mscale = sqrt(1/t) = 0.1 * ln(s) + 1; no scaling when s <= 1
    mscale = 0.1 * math.log(s) + 1.0 if s > 1 else 1.0

    return theta_yarn, mscale

An example of applying YaRN inside RoPE:

In [16]:
def apply_yarn_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    positions: torch.Tensor,
    max_seq_len: int,
    orig_seq_len: int = 2048,
    base: float = 10000.0,
    alpha: float = 1.0,
    beta: float = 32.0
):
    head_dim = q.size(-1)
    device = q.device

    # Get theta_yarn, mscale
    theta_m_yarn, mscale = build_yarn_rope_cache(
        head_dim, max_seq_len, orig_seq_len, base, alpha, beta, device, q.dtype
    )

    # RoPE
    # Compute rotation angles: phi_m(p) = p * theta_m'
    # shape: (positions, head_dim // 2)
    angles = positions[:, None] * theta_m_yarn[None, :]
    cos = torch.cos(angles)
    sin = torch.sin(angles)

    ## Split q, k into even/odd pairs
    q_even = q[..., 0::2]
    q_odd = q[..., 1::2]
    k_even = k[..., 0::2]
    k_odd = k[..., 1::2]
    ## Rotation matrix
    q_rotated_even = cos * q_even - sin * q_odd
    q_rotated_odd = sin * q_even + cos * q_odd
    k_rotated_even = cos * k_even - sin * k_odd
    k_rotated_odd = sin * k_even + cos * k_odd
    ## Interleave back
    q_rotated = torch.empty_like(q)
    q_rotated[..., 0::2] = q_rotated_even
    q_rotated[..., 1::2] = q_rotated_odd
    k_rotated = torch.empty_like(k)
    k_rotated[..., 0::2] = k_rotated_even
    k_rotated[..., 1::2] = k_rotated_odd

    # Apply temperature mscale
    q_rotated = q_rotated * mscale
    k_rotated = k_rotated * mscale

    return q_rotated, k_rotated



An example of using YaRN RoPE in multi-head attention (MHA):

In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class YaRNRopeMHA(nn.Module):
    def __init__(self, d_model, num_heads, orig_max_length = 2048, 
                 dropout=0.1, alpha=1.0, beta=32.0, base=10000.0):
        super().__init__()
        self.d_model=d_model
        self.num_heads=num_heads
        self.head_dim = d_model // num_heads
        self.orig_max_length = orig_max_length
        self.alpha = alpha
        self.beta = beta
        self.base = base

        # Initialize Q, K, V, output projections
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        # Initialize dropout
        self.dropout = nn.Dropout(dropout)
    
    def _apply_yarn_rope(self, x, positions, max_seq_len):
        device = x.device

        theta_m_yarn, mscale = build_yarn_rope_cache(
            self.head_dim, 
            max_seq_len, self.orig_max_length, 
            self.base, self.alpha, self.beta, 
            device, x.dtype
        )
        
        # positions: (batch_size, seq_len) -> angles: (batch_size, seq_len, head_dim//2)
        angles = positions.unsqueeze(-1).float() * theta_m_yarn.unsqueeze(0).unsqueeze(0)
        cos = torch.cos(angles)
        sin = torch.sin(angles)

        # Add num_heads dimension: (batch_size, seq_len, head_dim//2) -> (batch_size, seq_len, 1, head_dim//2)
        cos = cos.unsqueeze(2)
        sin = sin.unsqueeze(2)
        
        x_even = x[..., 0::2]
        x_odd = x[..., 1::2]
        x_rotated_even = cos * x_even - sin * x_odd
        x_rotated_odd = sin * x_even + cos * x_odd
        x_rotated = torch.empty_like(x)
        x_rotated[..., 0::2] = x_rotated_even
        x_rotated[..., 1::2] = x_rotated_odd
        x_rotated = x_rotated * mscale

        return x_rotated
    
    def forward(self, x, positions, max_seq_len=None, mask=None):
        batch_size, seq_len, d_model = x.size()
        
        if max_seq_len is None:
            max_seq_len = self.orig_max_length
        
        Q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)

        # Apply YaRN RoPE to Q, K
        Q = self._apply_yarn_rope(Q, positions, max_seq_len)
        K = self._apply_yarn_rope(K, positions, max_seq_len)

        # Parallel multi-head attention
        Q = Q.transpose(1,2)
        K = K.transpose(1,2)
        V = V.transpose(1,2)

        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        attn_values = torch.matmul(attn_weights, V)

        # Reshape
        output = attn_values.transpose(1,2).reshape(batch_size, seq_len, d_model)
        output = self.out_proj(output)

        return output

Test extending context length:

In [20]:
# model config
d_model = 512
num_heads = 8
orig_seq_len = 2048

mha = YaRNRopeMHA(d_model, num_heads, orig_seq_len)

batch_size = 2
extended_seq_len = 8192 # 4x

# Input
x = torch.randn(batch_size, extended_seq_len, d_model)
positions = torch.arange(extended_seq_len).unsqueeze(0).expand(batch_size, -1)

# Forward pass with extended context
output = mha(x, positions, max_seq_len=extended_seq_len)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Context extension: {extended_seq_len / orig_seq_len}x")


Input shape: torch.Size([2, 8192, 512])
Output shape: torch.Size([2, 8192, 512])
Context extension: 4.0x


## References

- EleutherAI blog: [Extending the RoPE](https://blog.eleuther.ai/yarn/) 2023/11/13