In [6]:
import torch


def Computing_RoPE_params(theta: float, dim: int, seq_len: int, dtype=torch.bfloat16, device=None):
    """
    Compute the RoPE rotation tables (cos, sin) for every position in a sequence.

    Position ``p`` and coordinate-pair index ``i`` (0 <= i < dim // 2) are
    rotated by the angle ``p * theta ** (-2 * i / dim)``.

    The angles are computed in float32, and only the final cos/sin values are
    cast to ``dtype``. Computing them directly in a low-precision dtype such as
    bfloat16 (8-bit mantissa) is badly wrong: integer positions above 256 are
    no longer exactly representable, and the rounding error of ``p * freq``
    grows with ``p`` until the rotation angle is off by whole radians.

    Args:
        theta: base that sets the rotation speed (e.g. 10000).
        dim: dimension of the vectors being rotated. Must be even, because
            coordinates are rotated in pairs.
        seq_len: the sequence length (number of positions).
        dtype: dtype of the returned tables.
        device: device on which to build the tables (None -> torch's current
            default device, same as before this parameter existed).
    Returns:
        (cos, sin), each of shape [seq_len, dim // 2] and dtype ``dtype``.
    Raises:
        ValueError: if ``dim`` is odd.
    """
    if dim % 2 != 0:
        raise ValueError(f"dim must be even to rotate coordinate pairs, got {dim}")
    half = dim // 2

    # Inverse frequencies theta^(-i / half) == theta^(-2i / dim), kept in float32.
    exponents = torch.arange(half, dtype=torch.float32, device=device) / half
    inv_freq = theta ** (-exponents)

    # Positions stay exact in float32 for any realistic sequence length.
    positions = torch.arange(seq_len, dtype=torch.float32, device=device)

    # angles[p, i] = p * inv_freq[i]; only the bounded cos/sin outputs are
    # rounded to the (possibly low-precision) target dtype.
    angles = torch.outer(positions, inv_freq)
    return angles.cos().to(dtype), angles.sin().to(dtype)

def Apply_RoPE(x: torch.Tensor, cos, sin):
    """
    Rotate interleaved coordinate pairs of queries or keys by the RoPE angles.

    Coordinates (2k, 2k + 1) of the last axis form pair k, and pair k at
    position p is rotated in its plane using cos[p, k] and sin[p, k].

    Args:
        x:   [batch, head, seq_len, dim]
        cos: [seq_len, dim/2]
        sin: [seq_len, dim/2]
    Returns:
        Tensor of shape [batch, head, seq_len, dim] holding the rotated vectors.
    """
    # Split the last axis into interleaved (even, odd) halves: [B, H, L, d_h/2].
    even = x[..., 0::2]
    odd = x[..., 1::2]

    # Prepend singleton batch/head axes so the tables broadcast over x.
    c = cos[None, None, :, :]
    s = sin[None, None, :, :]

    # Standard 2-D rotation applied to every pair.
    rotated_even = even * c - odd * s
    rotated_odd = even * s + odd * c

    # Stacking on a new trailing axis and merging it back re-interleaves the
    # halves: rotated_even lands on even indices, rotated_odd on odd indices.
    paired = torch.stack((rotated_even, rotated_odd), dim=-1)
    return paired.flatten(start_dim=-2)

# Demo: apply RoPE to a random bfloat16 tensor laid out as
# [batch, head, seq_len, head_dim] and check that the shape is unchanged.
B, H, L, d_h = 2, 4, 128, 64  # batch size, heads, sequence length, per-head dim
x = torch.randn(B, H, L, d_h, dtype=torch.bfloat16, device="cpu")

# cos/sin tables of shape [L, d_h // 2]: one angle per position and coordinate pair.
cos, sin = Computing_RoPE_params(theta=10000, dim=d_h, seq_len=L, dtype=torch.bfloat16)
x_rope = Apply_RoPE(x, cos, sin)

print(x_rope.shape)  # [2, 4, 128, 64]

torch.Size([2, 4, 128, 64])
