- Here's the [original paper](https://arxiv.org/pdf/2104.09864) on rotary positional embeddings (RoPE).

In [3]:
import torch
from typing import Tuple

In [4]:
# functions are from the llama 3 model implementation (https://github.com/meta-llama/llama3/blob/main/llama/model.py)
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Build the table of unit-magnitude complex rotations used by RoPE.

    Entry ``[t, k]`` of the result is ``cos(t * w_k) + i * sin(t * w_k)``
    with ``w_k = theta ** (-2k / dim)``.

    Args:
        dim: Width of the vectors that will be rotated; one frequency is
            produced per pair of components (``dim // 2`` in total).
        end: Number of positions ``0 .. end - 1`` to tabulate.
        theta: Base of the geometric frequency progression.

    Returns:
        A complex64 tensor of shape ``(end, dim // 2)``.
    """
    # Exponents 0/dim, 2/dim, 4/dim, ...; the slice caps the count at dim // 2
    # (it only matters when dim is odd).
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)

    positions = torch.arange(end, device=inv_freq.device, dtype=torch.float32)

    # angles[t, k] = t * w_k
    angles = torch.outer(positions, inv_freq)

    # Polar form with magnitude 1 everywhere -> pure rotations.
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1]), f"{freqs_cis.shape} != ({x.shape[1]}, {x.shape[-1]})"
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors by the given complex rotation table.

    Adjacent components of the last axis are paired up as (real, imag),
    multiplied by ``freqs_cis``, and unpacked back to real values.

    Args:
        xq: Query tensor; its last axis must have even length.
        xk: Key tensor; its last axis must have even length.
        freqs_cis: Complex rotations; must pass the shape check in
            ``reshape_for_broadcast`` against the complex view of ``xq``.

    Returns:
        ``(xq_out, xk_out)`` with the shapes and dtypes of ``xq`` and ``xk``.
    """

    def _pair_as_complex(t: torch.Tensor) -> torch.Tensor:
        # (..., d) -> (..., d / 2, 2) -> complex (..., d / 2); computed in float32.
        paired = t.float().reshape(*t.shape[:-1], -1, 2)
        return torch.view_as_complex(paired)

    q_complex = _pair_as_complex(xq)
    k_complex = _pair_as_complex(xk)

    # The same reshaped table (derived from the query) is used for both inputs.
    rotations = reshape_for_broadcast(freqs_cis, q_complex)

    def _rotate(z: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
        # Complex multiply, then unpack (real, imag) and merge back into one axis.
        real_pairs = torch.view_as_real(z * rotations)
        merged = real_pairs.reshape(*like.shape[:-1], -1)
        return merged.type_as(like)

    return _rotate(q_complex, xq), _rotate(k_complex, xk)

# Demo: build the rotation table, rotate random q/k tensors, then check one
# rotated component pair by hand against the closed-form 2-D rotation.
dim = 2048 // 32  # presumably model width 2048 split over 32 heads — confirm
max_seq_len = 1024 * 2
theta = 500000.0
batch_size = 2
seq_len = 1024  # not used below: the demo tensors span the full max_seq_len
embed_dim = dim

# One complex rotation per (position, component pair): (max_seq_len, dim // 2).
freqs_cis = precompute_freqs_cis(dim, max_seq_len, theta)

print("freqs_cis.size()", freqs_cis.size())

# The last axis must match the `dim` used for freqs_cis, otherwise the shape
# assertion in reshape_for_broadcast fails; use embed_dim rather than a literal.
xq = torch.randn(batch_size, max_seq_len, embed_dim)
xk = torch.randn(batch_size, max_seq_len, embed_dim)

print("xq.size()", xq.size(), "xk.size()", xk.size())

(xq_r, xk_r) = apply_rotary_emb(xq, xk, freqs_cis)
print("xq_r.size()", xq_r.size(), "xk_r.size()", xk_r.size())

# Verify the rotary embedding for a single position and component pair.

m = 100  # position to check
k = 0  # first dimension pair (0, 1)

# Frequency of pair k. With k = 0 the exponent is 0, so freq is 1.0 and the
# angle equals m regardless of theta.
freq = 1.0 / (theta ** (2 * k / dim))  # θ_k
angle = m * freq  # m * θ_k

# The two input components that make up pair k.
x_2k = xq[0, m, 2 * k]
x_2k1 = xq[0, m, 2 * k + 1]

# Expected output: the 2-D rotation of (x_2k, x_2k1) by `angle`.
cos_angle = torch.cos(torch.tensor(angle))
sin_angle = torch.sin(torch.tensor(angle))
xq_r_expected_2k = x_2k * cos_angle - x_2k1 * sin_angle
xq_r_expected_2k1 = x_2k * sin_angle + x_2k1 * cos_angle

# Compare with what apply_rotary_emb produced.
print("Expected xq_r[0, m, 2k]:", xq_r_expected_2k.item())
print("Actual xq_r[0, m, 2k]:", xq_r[0, m, 2 * k].item())
print("Expected xq_r[0, m, 2k+1]:", xq_r_expected_2k1.item())
print("Actual xq_r[0, m, 2k+1]:", xq_r[0, m, 2 * k + 1].item())

freqs_cis.size() torch.Size([2048, 32])
xq.size() torch.Size([2, 2048, 64]) xk.size() torch.Size([2, 2048, 64])
xq_r.size() torch.Size([2, 2048, 64]) xk_r.size() torch.Size([2, 2048, 64])
Expected xq_r[0, m, 2k]: -0.31764817237854004
Actual xq_r[0, m, 2k]: -0.31764817237854004
Expected xq_r[0, m, 2k+1]: -0.4906976521015167
Actual xq_r[0, m, 2k+1]: -0.4906976521015167
