<a href="https://colab.research.google.com/github/aju22/RoPE-PyTorch/blob/main/RoPE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

https://zhuanlan.zhihu.com/p/702114274

## **Rotary Positional Embeddings (RoPE)**

*Rotary Position Embedding (RoPE) is a technique used in transformer-based models to incorporate positional information into token representations. Unlike traditional positional encodings that rely on sine and cosine functions, RoPE utilizes rotation matrices to encode both absolute and relative positional information. This method was proposed as a way to enhance the effectiveness of positional embeddings in transformers.*

## **The Math**

*Rotary encoding transforms pairs of features by rotating in the 2D plane. That is, it organizes the d features as d/2
pairs. Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it by an angle depending on the position of the token.*

$$\text{Let } x_m^{(1)} \text{ and } x_m^{(2)} \text {be two features of the key or}\\
 \text{query of any head at position m.}\\
 \text{For simplicity assume x has only two features. Then the transformation is:}$$


$$\text{RoPE}(x_m^{(1)}, x_m^{(2)}, m) =
\begin{bmatrix}
\cos(m\theta) & -\sin(m\theta) \\
\sin(m\theta) & \cos(m\theta)
\end{bmatrix}
\begin{bmatrix}
x_m^{(1)} \\
x_m^{(2)}
\end{bmatrix}
= \begin{bmatrix}
x_m^{(1)}\cos(m\theta) - x_m^{(2)}\sin(m\theta) \\
x_m^{(2)}\cos(m\theta) + x_m^{(1)}\sin(m\theta)
\end{bmatrix}$$


$$ \Theta = \left\{ \theta_i = 10000^{-\frac{2(i-1)}{d}},\ i \in [1, 2, \ldots, d/2] \right\} \text{ for the } d/2 \text{ pairs of features.}$$

## **The Intuition**

We would like to find a positional encoding function ***f(x,l)*** for an item ***x*** and its position ***l*** such that, for two items ***q*** and ***k*** at positions ***m*** and ***n***, the inner product between ***f(q,m)*** and ***f(k,n)*** is sensitive only to the values ***q*** and ***k*** and their relative position ***m - n***.

 A key piece of information is the geometric definition of the dot product between Euclidean vectors:

 $$q \cdot k = |q| |k| \cos \theta$$

 The RoPE embedding achieves this:

 \begin{align}
\mathrm{RoPE}(x, m) &= xe^{mi\theta} \\
\langle \mathrm{RoPE}(q_j, m), \mathrm{RoPE}(k_j, n)\rangle &= \langle q_j e^{mi\theta}, k_j e^{ni\theta} \rangle \\
&= q_j k_j e^{mi\theta} \overline{e^{ni\theta}} \\
&= q_j k_j e^{(m - n)i\theta} \\
&= \mathrm{RoPE}(q_j k_j, m - n)
\end{align}



In [None]:
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the upper half.

    The last axis is split at index ``x.shape[-1] // 2``. The result is the
    negated second part followed by the first part, i.e. ``[a, b] -> [-b, a]``.
    """
    half = x.shape[-1] // 2
    lower = x[..., :half]
    upper = x[..., half:]
    return torch.cat([upper.neg(), lower], dim=-1)

def apply_rope(q, k, seq_len, dim, base = 10000, device = None):
    """Apply rotary position embeddings (RoPE) to a query and a key tensor.

    Args:
        q: Query tensor. Its trailing dimensions must broadcast against a
            ``(seq_len, dim)`` table, e.g. ``(..., seq_len, dim)``.
        k: Key tensor, with the same broadcasting requirement as ``q``.
        seq_len: Number of positions ``0 .. seq_len - 1`` to encode.
        dim: Size of the rotated feature dimension (expected to be even so
            the features form ``dim / 2`` pairs).
        base: Base of the geometric frequency progression.
        device: Device on which the cos/sin tables are built; it should be
            the device that holds ``q`` and ``k``.

    Returns:
        A tuple ``(q_em, k_em)`` with the rotated query and key.
    """
    # theta_i = base^(-2i/dim), one frequency per feature pair.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
    t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
    # freqs[m, i] = m * theta_i
    freqs = torch.outer(t, inv_freq)
    # Half-split layout [theta_0..theta_{d/2-1}, theta_0..theta_{d/2-1}] to
    # match rotate_half, which pairs feature i with feature i + dim/2.
    emb = torch.cat([freqs, freqs], dim=-1)
    cos = emb.cos()
    sin = emb.sin()
    # Both tables cover every position; each tensor is rotated with its own
    # rotate_half and the tables are cast to that tensor's dtype.
    q_em = (q * cos.to(q.dtype)) + (rotate_half(q) * sin.to(q.dtype))
    k_em = (k * cos.to(k.dtype)) + (rotate_half(k) * sin.to(k.dtype))
    return q_em, k_em

In [None]:
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2
class Qwen2RotaryEmbedding(nn.Module):
    """Precompute and serve the cos/sin tables used by rotary position embeddings.

    The tables are stored as non-persistent buffers ``cos_cached`` and
    ``sin_cached`` of shape ``(max_seq_len_cached, dim)`` and are rebuilt
    when a longer sequence length is requested.

    Args:
        dim: Size of the rotated feature dimension (the head size).
        max_position_embeddings: Number of positions precomputed at
            construction time.
        base: Base of the geometric frequency progression.
        device: Device on which the inverse frequencies are created.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # This creates theta_i = base^(-2i/d), one frequency per feature pair.
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin tables for positions ``0 .. seq_len - 1``."""
        self.max_seq_len_cached = seq_len
        # Position indices from 0 to the maximum cached sequence length - 1.
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        # freqs[m, i] = m * theta_i. torch.outer here is equivalent to
        # torch.matmul(t.unsqueeze(-1), self.inv_freq.unsqueeze(0)).
        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation.
        # The paper's layout is [m*theta_0, m*theta_0, m*theta_1, m*theta_1, ...];
        # here it is [m*theta_0, m*theta_1, ..., m*theta_0, m*theta_1, ...].
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return the ``(cos, sin)`` tables for the first ``seq_len`` positions.

        Args:
            x: Tensor whose dtype (and device, when the cache is rebuilt) the
                tables follow; documented as
                ``[bs, num_attention_heads, seq_len, head_size]``.
            seq_len: Number of positions required. When ``None`` it is taken
                from ``x.shape[-2]``.

        Returns:
            Tuple ``(cos, sin)``, each of shape ``(seq_len, dim)`` in ``x.dtype``.
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            # The declared default used to crash on `None > int`; fall back to
            # the sequence axis of x instead.
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    This is the "rotation" step of RoPE. In the interleaved formulation
    (Eq. (31) of the referenced article) the rotated vector is
    [-q_1, q_0, -q_3, ...]. Because the cos/sin cache is laid out as two
    concatenated halves, the rotation is rearranged accordingly to
    [-q_{d//2}, -q_{d//2+1}, ..., q_0, q_1, ...]. This rearrangement does not
    affect the relative-position identity (Eq. (30) of the article): working
    out Rm^T * Rn gives the same result.
    """
    split = x.shape[-1] // 2
    first, second = x[..., :split], x[..., split:]
    return torch.cat((-second, first), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Rotate the query and key tensors with Rotary Position Embedding.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): Cosine table of the rotary embedding, indexed by position.
        sin (`torch.Tensor`): Sine table of the rotary embedding, indexed by position.
        position_ids (`torch.Tensor`):
            Position index of every token in q and k. Offset position ids can be passed here,
            for example when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which cos[position_ids] and sin[position_ids] are unsqueezed so that
            they broadcast against q and k. cos[position_ids] and sin[position_ids] have the shape
            [batch_size, seq_len, head_dim]; with q and k shaped [batch_size, heads, seq_len, head_dim]
            use unsqueeze_dim=1, and with q and k shaped [batch_size, seq_len, heads, head_dim]
            use unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` holding the rotated query tensor and the rotated key tensor.
    """
    # Pick the table rows for the requested positions and add the broadcast axis.
    cos_sel = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_sel = sin[position_ids].unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Rotation as shown in Eq. (31) of the referenced article.
        return (states * cos_sel) + (rotate_half(states) * sin_sel)

    return _rotate(q), _rotate(k)

if __name__ == "__main__":
    # Self-contained walk-through of how the rotary embedding is used inside an
    # attention layer. The sizes are small and arbitrary; they stand in for the
    # attributes a real attention module would read from `self`.
    import math

    batch_size, seq_length, past_key_values_length = 2, 5, 3
    num_heads, num_key_value_heads, head_dim = 4, 2, 8
    num_key_value_groups = num_heads // num_key_value_heads
    max_position_embeddings = 32
    rope_theta = 10000
    device = torch.device("cpu")

    query_states = torch.randn(batch_size, num_heads, seq_length, head_dim, device=device)
    key_states = torch.randn(batch_size, num_key_value_heads, seq_length, head_dim, device=device)
    value_states = torch.randn(batch_size, num_key_value_heads, seq_length, head_dim, device=device)

    # Positions start at past_key_values_length because the entries already in
    # the KV cache carry their own position encoding. The leading axis of size 1
    # lets cos[position_ids] broadcast over the batch dimension.
    position_ids = torch.arange(
        past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
    ).unsqueeze(0)
    rotary_emb = Qwen2RotaryEmbedding(
        head_dim,
        max_position_embeddings=max_position_embeddings,
        base=rope_theta,
    )
    cos, sin = rotary_emb(value_states, seq_len=seq_length + past_key_values_length)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

    # From here on it is the usual attention-weight computation.
    # repeat k/v heads if n_kv_heads < n_heads
    key_states = torch.repeat_interleave(key_states, num_key_value_groups, dim=1)
    value_states = torch.repeat_interleave(value_states, num_key_value_groups, dim=1)

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
    print("attn_weights shape:", tuple(attn_weights.shape))

In [None]:
class RotaryPositionalEmbeddings(nn.Module):
    """Rotary positional embedding that rotates all ``d`` features of its input.

    The position is taken from the first axis of the input. The cos/sin
    tables are cached with shape ``(seq_len, 1, 1, d)``, so inputs are
    expected to be 4-D with the sequence first and the features last
    (the usage example passes ``(seq_len, 1, 1, d)``).

    Args:
        d: Number of features to rotate; must be even so the features form
            ``d / 2`` pairs ``(i, i + d/2)``.
        base: Base of the geometric frequency progression.

    Raises:
        ValueError: If ``d`` is odd.
    """

    def __init__(self, d: int, base: int = 10_000):
        super().__init__()
        if d % 2 != 0:
            # An odd d cannot be split into (i, i + d/2) pairs; the cos/sin
            # tables would no longer line up with the rotated halves.
            raise ValueError(f"d must be even to form feature pairs, got {d}")
        self.base = base
        self.d = d
        self.cos_cached = None
        self.sin_cached = None

    def _build_cache(self, x: torch.Tensor):
        """Build the cos/sin tables unless a usable cache already exists."""
        seq_len = x.shape[0]
        cached = self.cos_cached
        # Reuse the cache only if it is long enough AND lives on x's device;
        # otherwise the multiplication in forward would mix devices.
        if cached is not None and cached.device == x.device and seq_len <= cached.shape[0]:
            return

        # THETA_i = base^(-2i/d), one frequency per feature pair.
        theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d))
        theta = theta.to(x.device)

        # Position index -> [0, 1, 2, ..., seq_len - 1]
        seq_idx = torch.arange(seq_len, device=x.device).float()

        # idx_theta[m, i] = m * THETA_i
        idx_theta = torch.outer(seq_idx, theta)

        # Repeat the d/2 angles so that features i and i + d/2 share an angle.
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

        # Insert two singleton axes so the tables broadcast over the middle axes.
        self.cos_cached = idx_theta2.cos()[:, None, None, :]
        self.sin_cached = idx_theta2.sin()[:, None, None, :]

    def _neg_half(self, x: torch.Tensor):
        """Return ``[-x_{d/2+1}, ..., -x_d, x_1, ..., x_{d/2}]`` along the last axis."""
        d_2 = self.d // 2
        # Ellipsis indexing keeps this independent of the number of leading axes.
        return torch.cat([-x[..., d_2:], x[..., :d_2]], dim=-1)

    def forward(self, x: torch.Tensor):
        """Rotate ``x`` according to the position given by its first axis.

        Args:
            x: Input whose first axis is the sequence position and whose last
                axis holds the ``d`` features.

        Returns:
            Tensor of the same shape as ``x`` with the rotation applied.
        """
        self._build_cache(x)

        seq_len = x.shape[0]
        neg_half_x = self._neg_half(x)

        # [x_i * cos(m*THETA_i) - x_{i+d/2} * sin(m*THETA_i), ...]
        x_rope = (x * self.cos_cached[:seq_len]) + (neg_half_x * self.sin_cached[:seq_len])

        return x_rope

*We pair up the positive cosines and negative half sines to get the final embeddings as follows*:

$$
\text{RoPE}\left(x_m^{(i)}, x_m^{(i+d/2)}, m\right) =
\begin{bmatrix}
x_m^{(i)}\cos(m\theta_i) - x_m^{(i+d/2)}\sin(m\theta_i) \\
x_m^{(i+d/2)}\cos(m\theta_i) + x_m^{(i)}\sin(m\theta_i)
\end{bmatrix}$$

# Test

In [None]:
x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]], dtype=torch.float)
x = x[:, None, None, :]

In [None]:
RotaryPositionalEmbeddings(4)(x)

tensor([[[[  1.0000,   2.0000,   3.0000,   4.0000]]],


        [[[ -2.8876,   4.9298,   6.6077,   7.0496]]],


        [[[-11.0967,   7.7984,   2.6198,  10.1580]]]])