# Extending context window of large language models via position interpolation

```{note}
We present Position Interpolation (PI), which extends the context window sizes of
RoPE-based pretrained LLMs such as LLaMA models to up to 32768 tokens with minimal fine-tuning (within 1000 steps), while
demonstrating strong empirical results on various tasks that require long context.
Meanwhile, models extended by Position Interpolation
preserve quality relatively well on tasks within their original context window.
```

## Background: Rotary Position Embedding (RoPE)

Transformer models require explicit positional information to be injected, typically in the form of
positional encodings, to represent the order of inputs. We consider Rotary Position Embedding, which is the position encoding used in the LLaMA model.

Given a position index $m\in[0, c)$ and an embedding vector $\mathbf{x} := [x_0, x_1, \dots, x_{d-1}]^{\intercal}$, where
$d$ is the dimension of the attention head, RoPE defines a vector-valued complex function $f(\mathbf{x}, m)$ as
follows:

$$f(\mathbf{x}, m) = \left[(x_{0} + ix_{1})e^{im\theta_{0}}, (x_{2} + ix_{3})e^{im\theta_{1}},\dots,(x_{d-2} + ix_{d-1})e^{im\theta_{d/2-1}}\right]^{\intercal}$$

where $i:=\sqrt{-1}$ is the imaginary unit and $\theta_{j}=10000^{-2j/d}$. Using RoPE, the self-attention score

$$
\begin{aligned}
a(m,n) &= \text{Re}\left \langle f(\mathbf{q}, m), f(\mathbf{k}, n)  \right \rangle \\
&= \text{Re}\left[\sum_{j=0}^{d/2-1}(q_{2j} + iq_{2j+1})(k_{2j} - ik_{2j+1})e^{i(m-n)\theta_{j}}\right]\\
&= \sum_{j=0}^{d/2-1}(q_{2j}k_{2j} + q_{2j+1}k_{2j+1})\cos((m-n)\theta_{j}) + (q_{2j}k_{2j+1} - q_{2j+1}k_{2j})\sin((m-n)\theta_{j})\\
&=: a(m-n)
\end{aligned}
$$

depends only on the relative position $m - n$ through trigonometric functions. Here $\mathbf{q}$ and $\mathbf{k}$ are the
query and key vectors for a specific attention head. At each layer, RoPE is applied to both the query and
key embeddings when computing attention scores.
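
The relative-position property can be checked numerically. The following is a minimal sketch rather than part of any reference implementation: the helper names `rope` and `score` are hypothetical, and it assumes PyTorch and an even head dimension $d$.

```python
import torch


def rope(x: torch.Tensor, m: float, base: float = 10000.0) -> torch.Tensor:
    """Hypothetical helper: f(x, m) as a complex vector of length d/2."""
    d = x.shape[-1]
    # theta_j = base^(-2j/d) for j = 0, ..., d/2 - 1
    theta = base ** (-torch.arange(0, d, 2, dtype=torch.float64) / d)
    # Pair up (x_0, x_1), (x_2, x_3), ... as complex numbers x_{2j} + i x_{2j+1}
    x_c = torch.view_as_complex(x.to(torch.float64).reshape(-1, 2))
    # Multiply each pair by e^{i m theta_j}
    return x_c * torch.polar(torch.ones_like(theta), m * theta)


def score(q: torch.Tensor, k: torch.Tensor, m: float, n: float) -> torch.Tensor:
    """Re<f(q, m), f(k, n)>, with the complex conjugate taken on the key side."""
    return (rope(q, m) * rope(k, n).conj()).sum().real


torch.manual_seed(0)
q, k = torch.randn(8), torch.randn(8)
a_near = score(q, k, 5, 2)      # m - n = 3
a_far = score(q, k, 105, 102)   # m - n = 3, both positions shifted by 100
print(torch.allclose(a_near, a_far))
```

Shifting both positions by the same offset leaves the score unchanged up to floating-point error, which is what $a(m, n) = a(m-n)$ states.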

## Introduction

Large language models (LLMs) typically come with a pre-defined context window size. For example,
inputs to LLaMA models must be fewer than 2048 tokens. This pre-set
context window limit is frequently exceeded in applications such as conducting long conversations,
summarizing long documents, or executing long-term planning. For these applications, LLMs with
longer context windows are preferred. However, training an LLM from scratch with long context
windows requires significant investments. This naturally leads to a question: Can we extend the
context window of an existing pre-trained LLM?

One straightforward approach is to fine-tune an existing pre-trained Transformer with a longer context
window. However, empirically, we found that models trained this way adapt to long context
windows very slowly.

While certain techniques such as ALiBi and LeX enable length
extrapolation of Transformers, i.e., training on short context windows and running inference on longer ones,
many existing pre-trained LLMs, including LLaMA, use positional encodings
that have weak extrapolation properties.

In this work, we introduce Position Interpolation to enable context window extensions for certain
existing pre-trained LLMs, including LLaMA. The key idea is that, instead of extrapolating, we directly
down-scale the position indices so that the maximum position index matches the previous context
window limit in the pre-training stage.

![](../images/extending.png)

## Direct extrapolation

Ideally, a model trained on a context window of size $L = 2048$ would still work
reasonably well on longer context windows, although it may not be able to leverage information
that appears beyond $L$. For example, to answer a question located at position 3000, a model trained with a
maximal window size of $L = 2048$ cannot leverage evidence provided at position 0, but can still
leverage evidence provided at position 2900. In reality, however, we see catastrophic
behavior: a question at position 3000 cannot be answered correctly, even if the evidence is
located at position 2900.

What is the reason behind this? How can this happen if the attention score $a(m-n)$ decays as the relative
distance $|m - n|$ increases, according to the RoPE paper? It turns out that the upper bound derived in the RoPE paper
may be too loose. In fact, if we treat all trigonometric functions $e^{is\theta_{j}}$ as basis functions,
we can think of the self-attention score as a basis expansion:

$$a(s) = \text{Re}\left[\sum_{j=0}^{d/2-1}h_{j}e^{is\theta_{j}}\right]$$

Now the issue becomes clear: $a(s)$ can be small in magnitude in the range $[0, 2048]$ but take huge values outside this
region. The underlying reason is that the trigonometric family $e^{is\theta_{j}}$ (with sufficiently large $d$) is
a universal approximator and can fit arbitrary functions. Therefore, for $a(s)$, there always exist
coefficients $\{h_j\}$ (i.e., key and query) that correspond to small function values in $[0, 2048]$ but
much larger values in regions beyond.

## Position Interpolation

Instead of extrapolating the attention score to $s>L$, we replace RoPE $f$ by $f'$, defined as follows:

$$f'(\mathbf{x}, m) = f(\mathbf{x}, \frac{mL}{L'})$$

where $L'$ is the longer context window. We call this transformation of the position encoding Position Interpolation. In this step, we reduce
position indices from $[0,L')$ to $[0,L)$ to match the original range of indices before computing RoPE. Since we align the ranges of position indices and relative distances before
and after extension, we mitigate the effect of context window extension on attention score computation,
which makes it easier for the model to adapt.
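
As a concrete case, with $L = 2048$ and $L' = 4096$, position 3000 is mapped to 1500, and the relative distance between positions 3000 and 2900 shrinks from 100 to 50. The sketch below shows one way the rescaling could be folded into the precomputation of the rotary frequencies. It is an illustration under stated assumptions, not the authors' code: the function name `precompute_freqs_cis_interpolated` and the `scale` parameter are hypothetical, and the layout follows the LLaMA-style `precompute_freqs_cis` helper.

```python
import torch


def precompute_freqs_cis_interpolated(
    dim: int, end: int, theta: float = 10000.0, scale: float = 1.0
) -> torch.Tensor:
    """Hypothetical variant: row m holds e^{i (m * scale) theta_j}.

    With scale = L / L', position indices in [0, L') are mapped into [0, L).
    """
    # theta_j = theta^(-2j/dim)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    # Down-scaled (possibly fractional) position indices m * L / L'
    t = torch.arange(end, dtype=torch.float32) * scale
    angles = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(angles), angles)  # complex64


L, L_ext = 2048, 4096  # original and extended context window sizes
plain = precompute_freqs_cis_interpolated(64, L)
interp = precompute_freqs_cis_interpolated(64, L_ext, scale=L / L_ext)
print(plain.shape, interp.shape)
```

With a scale of $1/2$, every even row $m$ of `interp` coincides with row $m/2$ of `plain`, while odd rows fall halfway between two positions seen during pre-training.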

**Interpolation bound.** For the attention score $a(s) = \text{Re}\left[\sum_{j=0}^{d/2-1}h_{j}e^{is\theta_{j}}\right]$, where $\theta_{j} = c^{-2j/d}$, its interpolated value $a(s)$ for $s\in[s_{1},s_{2}]$ is bounded as follows:

$$|a(s) - a_{\text{linear}}(s)|\le d(\max_{j}|h_{j}|)\frac{(s-s_{1})(s_{2}-s)}{8\ln{c}}$$

where $a_{\text{linear}}(s)$ is the linear interpolation of the two grid points $a(s_1)$ and $a(s_2)$, which are known to
behave well, as enforced by LLM pre-training:

$$a_{\text{linear}}(s):=(1-\lambda(s))a(s_{1}) + \lambda(s)a(s_{2}),\quad \lambda(s):=\frac{s-s_{1}}{s_{2}-s_{1}}$$

Notably, our method of rescaling position indices does not introduce extra weights or modify
the model architecture in any way. This makes it attractive in practical applications, since most
infrastructure and optimization for the original model can be reused after the extension.

**Fine-tuning.** We can further fine-tune the interpolated model using the next-token prediction task
with interpolated position encodings on the extended context window size, using a pre-training corpus
such as the Pile. In the next section, we show that our fine-tuning process
only needs tens to hundreds of thousands of examples.

## Experiments

## Llama implementation

```python
import torch
from torch import nn
from typing import Tuple


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
    and the end index 'end'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.

    Args:
        dim (int): Dimension of the frequency tensor.
        end (int): End index for precomputing frequencies.
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.

    Returns:
        torch.Tensor: Precomputed frequency tensor with complex exponentials.
    """
    # (theta_0, theta_1, ..., theta_{d/2-1})
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    # torch.polar(abs, angle)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis
```

The $m$-th row of `freqs_cis` is:

$$\left[e^{im\theta_{0}}, e^{im\theta_{1}}, \dots, e^{im\theta_{d/2-1}}\right]$$

where $\theta_{j}=10000^{-2j/d}$.

```python
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    Reshape frequency tensor for broadcasting it with another tensor.

    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

    Args:
        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.

    Returns:
        torch.Tensor: Reshaped frequency tensor.

    Raises:
        AssertionError: If the frequency tensor doesn't match the expected shape.
        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings.
        xk (torch.Tensor): Key tensor to apply rotary embeddings.
        freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
```

One row of `xq_` is:

$$\left[(x_{0} + ix_{1}), (x_{2} + ix_{3}),\dots,(x_{d-2} + ix_{d-1})\right]$$

```python
xq = torch.arange(8).reshape(1, 2, 4)
xq
```

```
tensor([[[0, 1, 2, 3],
         [4, 5, 6, 7]]])
```

```python
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xq_
```

```
tensor([[[0.+1.j, 2.+3.j],
         [4.+5.j, 6.+7.j]]])
```

```python
freqs_cis = precompute_freqs_cis(4, 2)
freqs_cis
```

```
tensor([[1.0000+0.0000j, 1.0000+0.0000j],
        [0.5403+0.8415j, 0.9999+0.0100j]])
```

```python
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
freqs_cis
```

```
tensor([[[1.0000+0.0000j, 1.0000+0.0000j],
         [0.5403+0.8415j, 0.9999+0.0100j]]])
```

```python
xq_ * freqs_cis
```

```
tensor([[[ 0.0000+1.0000j,  2.0000+3.0000j],
         [-2.0461+6.0674j,  5.9297+7.0596j]]])
```

```python
torch.view_as_real(xq_ * freqs_cis)
```

```
tensor([[[[ 0.0000,  1.0000],
          [ 2.0000,  3.0000]],

         [[-2.0461,  6.0674],
          [ 5.9297,  7.0596]]]])
```
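
The complex multiplication in this walk-through is the same operation as rotating each pair $(x_{2j}, x_{2j+1})$ by the angle $m\theta_{j}$ in the plane. The sketch below checks this for the pair $(4, 5)$ at position $m = 1$ with $\theta_{0} = 1$, which corresponds to the first entry of the second row above. The helper `rotate_pair` is a hypothetical name introduced only for this illustration.

```python
import math
from typing import Tuple

import torch


def rotate_pair(x0: float, x1: float, angle: float) -> Tuple[float, float]:
    """Hypothetical helper: rotate the 2-D point (x0, x1) counter-clockwise by `angle` radians."""
    c, s = math.cos(angle), math.sin(angle)
    return x0 * c - x1 * s, x0 * s + x1 * c


m, theta0 = 1, 1.0  # position index and theta_0 = 10000^0

# Complex form: (x0 + i x1) * e^{i m theta_0}
z = torch.tensor(4.0 + 5.0j) * torch.polar(torch.tensor(1.0), torch.tensor(m * theta0))

# Real form: 2x2 rotation applied to (x0, x1)
r0, r1 = rotate_pair(4.0, 5.0, m * theta0)

print(z.real.item(), z.imag.item())
print(r0, r1)
```

Both forms give approximately $(-2.0461, 6.0674)$, matching the value in the output above.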