# 手撕 RoPE

本质就是将 Q、K 矩阵乘以一个随位置 $m$ 变化的旋转矩阵，使得旋转后的注意力分数 $q_m^\top k_n$ 只依赖于相对位置 $m - n$。

![RoPE](imgs/RoPE.png)

In [None]:
import torch

### 计算 $\cos(m f_i)$ 和 $\sin(m f_i)$
 - $m$ 是当前 token 在其 sequence 中的位置，如 "我爱学习" 中，"学"对应的 $m$ 为 2
 - $i$ 是旋转维度对的下标：参与旋转的 $d$ 个维度（本例中 $d$ 即 hidden_size）两两成对，$i$ 的取值范围是 $[0, \frac{d}{2} - 1]$，对应的频率为 $f_i = \text{base}^{-2i/d}$（即代码中的 `inv_freq`）

In [None]:
def compute_default_rope_parameters(dim=896, base=1000000.0):
    """Return the RoPE inverse frequencies ``f_i = base ** (-2i / dim)``.

    Args:
        dim: Number of channels that get rotated. Must be even, because
            channels are rotated in pairs.
        base: The RoPE ``theta``. The default ``1e6`` matches Qwen2.5's
            ``rope_theta``.

    Returns:
        A float32 tensor of shape ``(dim // 2,)`` holding ``f_0 .. f_{dim/2 - 1}``.

    Raises:
        ValueError: If ``dim`` is odd.
    """
    if dim % 2 != 0:
        raise ValueError(f"RoPE dim must be even, got {dim}")

    # Exponents 0, 2, ..., dim - 2 give one frequency per channel pair; they
    # are generated as int64 and then cast so the even steps are exact.
    exponents = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
    inv_freq = 1.0 / (base ** exponents)

    return inv_freq


def rotary_emb(hidden_size, position_ids):
    """Build the RoPE cosine/sine tables for the given token positions.

    Args:
        hidden_size: Number of channels to rotate (the ``d`` in ``f_i``). In
            this notebook it equals the last dimension of q/k. Must be even.
        position_ids: Integer tensor of shape ``(batch, seq_len)`` holding each
            token's position ``m``.

    Returns:
        ``(cos, sin)``, each a float32 tensor of shape
        ``(batch, seq_len, hidden_size)`` on ``position_ids.device``. The last
        axis is laid out as ``[f_0 .. f_{d/2-1}, f_0 .. f_{d/2-1}]`` (times
        ``m``), matching the half-split pairing used by ``rotate_half``.

    Raises:
        ValueError: If ``hidden_size`` is odd (cos/sin would otherwise come out
            one channel wider than q/k).
    """
    if hidden_size % 2 != 0:
        raise ValueError(f"hidden_size must be even for RoPE, got {hidden_size}")

    # Build the frequencies on the same device as the positions; otherwise the
    # matmul below fails when position_ids lives on a GPU.
    inv_freq = compute_default_rope_parameters(dim=hidden_size).to(
        device=position_ids.device, dtype=torch.float32
    )

    # (batch, d/2, 1) @ (batch, 1, seq_len) -> (batch, d/2, seq_len), then
    # transpose to (batch, seq_len, d/2): the angle m * f_i for every pair.
    inv_freq_expanded = inv_freq[None, :, None].expand(position_ids.shape[0], -1, 1)
    position_ids_expanded = position_ids[:, None, :].float()
    freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)

    # Duplicate so channel i and channel i + d/2 share the same angle.
    emb = torch.cat((freqs, freqs), dim=-1)

    return emb.cos(), emb.sin()

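### 对 Q/K 乘以由 $\cos(m f_i)$ 和 $\sin(m f_i)$ 构成的旋转矩阵（即进行旋转位置编码）

注：这里的实现（`rotate_half`）将第 $i$ 维与第 $i + \frac{d}{2}$ 维配对旋转，而不是相邻的第 $2i$、$2i+1$ 维；因此 cos/sin 表由两份相同的频率拼接而成。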

In [None]:
def rotate_half(x):
    """Swap the two halves of the last axis and negate the (new) first half.

    For a last axis split as ``[a, b]`` (``a`` has ``n // 2`` channels), the
    result is ``[-b, a]``, i.e. the 90-degree rotation used by RoPE.
    """
    width = x.shape[-1]
    cut = width // 2
    front, back = x.split([cut, width - cut], dim=-1)
    return torch.cat((back.neg(), front), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate the query and key tensors with precomputed RoPE cos/sin tables.

    Each input ``x`` is mapped to ``x * cos + rotate_half(x) * sin``.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): Cosine table, e.g. shaped
            ``(batch, seq_len, head_dim)`` as produced by ``rotary_emb``.
        sin (`torch.Tensor`): Sine table with the same shape as ``cos``.
        position_ids (`torch.Tensor`, *optional*): Deprecated and ignored;
            kept only so existing call sites keep working.
        unsqueeze_dim (`int`, *optional*, defaults to 1): Axis at which a
            singleton dimension is inserted into ``cos`` and ``sin`` so they
            broadcast against ``q`` and ``k``. With ``cos``/``sin`` of shape
            ``(batch, seq_len, head_dim)``, use 1 when q/k are
            ``(batch, heads, seq_len, head_dim)`` and 2 when they are
            ``(batch, seq_len, heads, head_dim)``.

    Returns:
        `tuple(torch.Tensor)`: ``(q_embed, k_embed)``, the rotated query and
        key tensors.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        # Same rotation for q and k; only the input tensor differs.
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)

In [None]:
# NOTE: HIDDEN_SIZE plays the role of the per-head dimension in this demo:
# q/k below are shaped (batch, heads, seq_len, HIDDEN_SIZE), and it is also the
# number of channels RoPE rotates.
BATCH_SIZE = 2
SEQ_LEN = 4
HIDDEN_SIZE = 32
NUM_HEADS = 4


# Positions 0..SEQ_LEN-1 for every sequence in the batch, shape
# (BATCH_SIZE, SEQ_LEN). Derived from the constants (instead of hard-coded) so
# the demo stays consistent when they are changed.
position_ids = torch.arange(SEQ_LEN).unsqueeze(0).expand(BATCH_SIZE, -1)
cos, sin = rotary_emb(hidden_size=HIDDEN_SIZE, position_ids=position_ids)
print(f"cos/sin shape: {cos.shape}")

query_states = torch.randn(BATCH_SIZE, NUM_HEADS, SEQ_LEN, HIDDEN_SIZE)
key_states = torch.randn(BATCH_SIZE, NUM_HEADS, SEQ_LEN, HIDDEN_SIZE)

print(f"query_states shape is: {query_states.shape}")
# cos/sin are (batch, seq_len, dim); the default unsqueeze_dim=1 inserts the
# heads axis so they broadcast against (batch, heads, seq_len, dim).
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
print(f"query_states(after rotary_pos_emb) shape is: {query_states.shape}")