#### Simple RoPE Implementation:

- Resources:
    - [Llama explained video by Umar Jamil](https://www.youtube.com/watch?v=Mn_9W1nCFLo): see the RoPE section starting at the 24:30 timestamp
    - [Efficient NLP's RoPE explanation](https://www.youtube.com/watch?v=o29P0Kpobz0)
    - [RoPE paper](https://arxiv.org/abs/2104.09864)
    - [Transformers RoPE implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L73)
    - [RoPE Notes](https://github.com/garg-aayush/building-from-scratch/blob/main/gpt-2/notes/RoPE.md)

> Taken from: https://github.com/garg-aayush/building-from-scratch/blob/main/gpt-2/play-nbs/rope.ipynb



In [None]:
import torch
import torch.nn as nn

In [None]:
# RoPE module that rotates query/key feature pairs by position-dependent angles
class RotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) in the "rotate-half" layout.

    Feature ``i`` of the first half of the head dimension is paired with
    feature ``i + dim // 2`` of the second half, and each pair is rotated by
    the angle ``position * base ** (-2 * i / dim)``.

    Args:
        dim: Per-head feature size; must be even.
        max_seq_len: Number of positions whose cos/sin tables are precomputed
            and stored as buffers. Longer inputs are still accepted: their
            tables are computed on the fly for that call, and the stored
            buffers are left unchanged.
        base: Frequency base (10000 by default, as in the RoPE paper).
    """

    def __init__(self, dim, max_seq_len=2048, base=10000):
        super().__init__()
        assert dim % 2 == 0, f"RotaryEmbedding requires even head_dim, got {dim}"
        self.dim = dim
        self.max_seq_len_cached = max_seq_len
        # One inverse frequency per feature pair: base^(-2i/dim), i = 0 .. dim/2 - 1.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)  # moves with the module, not trainable
        cos, sin = self._compute_cos_sin(max_seq_len)
        # Cached tables shaped [1, 1, T, D] so they broadcast over batch and heads.
        self.register_buffer("cos_cached", cos)
        self.register_buffer("sin_cached", sin)

    def _compute_cos_sin(self, seq_len):
        """Return (cos, sin), each shaped [1, 1, seq_len, dim], for positions 0..seq_len-1."""
        positions = torch.arange(
            seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype
        )
        # Outer product -> rotation angle for every (position, pair).
        freqs = torch.outer(positions, self.inv_freq)
        # Duplicate so both halves of the head dimension use the same angles.
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos()[None, None, :, :], emb.sin()[None, None, :, :]

    def forward(self, x):
        """Apply RoPE to ``x``.

        Args:
            x: Tensor of shape [bs, num_heads, seq_len, head_dim] with
                ``head_dim == self.dim``; the position of each vector is its
                index along dim 2.

        Returns:
            Tensor with the same shape as ``x``.
        """
        seq_len = x.shape[2]
        if seq_len <= self.max_seq_len_cached:
            cos = self.cos_cached[:, :, :seq_len, :]
            sin = self.sin_cached[:, :, :seq_len, :]
        else:
            # Beyond the precomputed range, slicing would return too few rows
            # and fail to broadcast, so build the tables for this call only.
            cos, sin = self._compute_cos_sin(seq_len)
        cos = cos.to(dtype=x.dtype, device=x.device)
        sin = sin.to(dtype=x.dtype, device=x.device)

        half = self.dim // 2
        # [-x2, x1]: a 90-degree rotation of every (x1[i], x2[i]) feature pair.
        rotated = torch.cat((-x[..., half:], x[..., :half]), dim=-1)
        # Standard RoPE: x * cos(theta) + rot90(x) * sin(theta), per position.
        return x * cos + rotated * sin


In [None]:

# --- Device Setup ---
# Prefer an available accelerator; fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
print(f"--- Using device: {device} ---")

# Seed the RNG so the random inputs (and hence the checks) are reproducible.
torch.manual_seed(0)

# --- Hyperparameters for testing ---
batch_size = 2
num_heads = 4
seq_len = 16
head_dim = 32

# --- Test Initialization ---
print("--- Initializing Test ---")
# Instantiate the RoPE module and move it to the selected device
rope = RotaryEmbedding(dim=head_dim, max_seq_len=seq_len).to(device)
# Create the random input directly on the selected device (no host-side copy)
x = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
print("Test setup complete.")
print("-" * 25)

# --- Apply RoPE ---
y = rope(x)

# --- Test 1: Shape check ---
print("--- Test 1: Shape Check ---")
print(f"Input shape:  {x.shape}")
print(f"Output shape: {y.shape}")
assert x.shape == y.shape, "Shape mismatch after RoPE application"
print("✅ Shape check passed.")
print("-" * 25)

# --- Test 2: Norm preservation check ---
print("--- Test 2: Norm Preservation Check ---")
# RoPE is a rotation, so it should preserve the L2 norm of the vectors.
norm_x = torch.linalg.norm(x, dim=-1)
norm_y = torch.linalg.norm(y, dim=-1)
# Check if norms are close with a small tolerance
assert torch.allclose(norm_x, norm_y, atol=1e-6), "Norm preservation failed"
print("✅ Norm preservation check passed.")
print("-" * 25)

# --- Test 3: Relative position property ---
print("--- Test 3: Relative Position Property Check ---")
# The dot product between two rotated vectors should depend only on their
# relative position: <R_m q, R_n k> == <R_{m+d} q, R_{n+d} k> for any shift d.

# Pick two positions m, n and a shift d
m, n = 2, 8
d = 4
assert m + d < seq_len and n + d < seq_len, "Test positions are out of bounds"

# Rotated vectors from the output y at the original positions
q_m_rot = y[:, :, m, :]
k_n_rot = y[:, :, n, :]
dot_product1 = torch.sum(q_m_rot * k_n_rot, dim=-1)

# Rotate the *same* original vectors as if they sat at positions m+d and n+d.
# RoPE acts on each position independently, so the other rows can stay zero.
x2 = torch.zeros_like(x)  # already on x's device
x2[:, :, m + d, :] = x[:, :, m, :]  # original q is now at m+d
x2[:, :, n + d, :] = x[:, :, n, :]  # original k is now at n+d

# Apply RoPE to this new sparse tensor
y2 = rope(x2)
q_md_rot = y2[:, :, m + d, :]
k_nd_rot = y2[:, :, n + d, :]
dot_product2 = torch.sum(q_md_rot * k_nd_rot, dim=-1)

# The dot products should be equal because the relative distance (n-m) is the same as ((n+d)-(m+d)).
assert torch.allclose(dot_product1, dot_product2, atol=1e-6), "Relative position property failed"
print("✅ Relative position property check passed.")
print("-" * 25)

print("\n🎉 All tests passed successfully! 🎉")
