<a href="https://colab.research.google.com/github/karankulshrestha/ai-notebooks/blob/main/RoPE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch

In [3]:
# Even indices 0, 2, 4, 6: one index per (x, y) pair of an 8-dim embedding.
torch.arange(start=0, end=8, step=2)

tensor([0, 2, 4, 6])

In [4]:
# Inverse rotation frequency for each embedding pair: 1 / base^(2i / dim).
base, dim = 10000, 4
inv_freq = 1.0 / torch.pow(base, torch.arange(0, dim, 2).float() / dim)
inv_freq

tensor([1.0000, 0.0100])

In [5]:
# Token positions 0, 1, 2, cast to inv_freq's tensor type so they can be multiplied together.
t = torch.arange(start=0, end=3).type_as(inv_freq)
t

tensor([0., 1., 2.])

In [6]:
# Outer product: row i holds position i multiplied by every inverse frequency,
# i.e. the rotation angle of each embedding pair at that position.
f = torch.einsum(
    "i,j -> ij",
    t,
    inv_freq,
)
f

tensor([[0.0000, 0.0000],
        [1.0000, 0.0100],
        [2.0000, 0.0200]])

In [7]:
# Repeat the angles along the last axis so both halves of the embedding
# (the x half and the y half of every pair) share the same angle.
emb = torch.cat([f, f], dim=-1)
emb

tensor([[0.0000, 0.0000, 0.0000, 0.0000],
        [1.0000, 0.0100, 1.0000, 0.0100],
        [2.0000, 0.0200, 2.0000, 0.0200]])

In [8]:
torch.cos(emb)
# rows    -> 3 tokens
# columns -> 4 embedding dimensions

tensor([[ 1.0000,  1.0000,  1.0000,  1.0000],
        [ 0.5403,  0.9999,  0.5403,  0.9999],
        [-0.4161,  0.9998, -0.4161,  0.9998]])

In [9]:
# Insert two singleton axes after the token axis: (3, 4) -> (3, 1, 1, 4).
torch.cos(emb).unsqueeze(1).unsqueeze(1)

tensor([[[[ 1.0000,  1.0000,  1.0000,  1.0000]]],


        [[[ 0.5403,  0.9999,  0.5403,  0.9999]]],


        [[[-0.4161,  0.9998, -0.4161,  0.9998]]]])

In [10]:
class Rotary(torch.nn.Module):
  """Builds and caches the cos/sin tables used by rotary position embeddings.

  Args:
    dim: Size of the embedding axis to rotate. One frequency is generated per
      step of 2 over ``range(dim)`` and the result is duplicated, so an even
      ``dim`` yields tables whose last axis has exactly ``dim`` entries.
    base: Base of the geometric progression of frequencies.
  """

  def __init__(self, dim, base=10000):
    super().__init__()
    # inv_freq[i] = 1 / base^(2i / dim).
    # The exponent has to be divided by ``dim`` *before* exponentiation and in
    # floating point: ``base ** arange / dim`` parses as ``(base ** arange) / dim``
    # and runs on an integer tensor, which gives wrong frequencies and
    # overflows int64 for realistic values of ``dim``.
    exponents = torch.arange(0, dim, 2).float() / dim
    inv_freq = 1.0 / (base ** exponents)
    self.register_buffer("inv_freq", inv_freq)
    # Cache of the most recently built tables and the (length, device) they were built for.
    self.seq_len_cached = None
    self.device_cached = None
    self.cos_cached = None
    self.sin_cached = None

  def forward(self, x, seq_dim=1):
    """Return ``(cos, sin)`` tables for the sequence length of ``x``.

    Args:
      x: Tensor whose size along ``seq_dim`` is the sequence length and whose
        device is where the tables are placed. Its values are not read.
      seq_dim: Axis of ``x`` that indexes tokens.

    Returns:
      Tuple ``(cos, sin)``, each of shape ``(seq_len, 1, 1, 2 * len(inv_freq))``.
      The tables are reused on later calls with the same length and device.
    """
    seq_len = x.shape[seq_dim]
    # Rebuild when the length changes or when the input lives on another
    # device; keying on length alone would hand back tables on a stale device.
    if seq_len != self.seq_len_cached or x.device != self.device_cached:
      # Token indices, created alongside inv_freq so the product below never mixes devices.
      t = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
      # Angle of every embedding pair at every position: position * frequency.
      freqs = torch.einsum("i,j -> ij", t, self.inv_freq)
      # Duplicate the angles so both halves of the embedding share them.
      emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
      self.cos_cached = emb.cos()[:, None, None, :]
      self.sin_cached = emb.sin()[:, None, None, :]
      # Record the cache key only after the tables were built successfully.
      self.seq_len_cached = seq_len
      self.device_cached = x.device
    return self.cos_cached, self.sin_cached


In [11]:
def rotate_half(x):
  """Swap the two halves of the last axis of ``x``, negating the second half.

  The last axis is split at ``x.shape[-1] // 2``; the result is the negated
  second part followed by the first part, i.e. ``[a, b] -> [-b, a]``.
  """
  half = x.shape[-1] // 2
  front = x[..., :half]
  back = x[..., half:]
  return torch.cat((-back, front), dim=-1)

In [12]:
@torch.jit.script
def apply_rotary_pos_emb(q, k, cos, sin):
  """Rotate ``q`` and ``k`` using the given cos/sin tables.

  Each input is combined as ``v * cos + rotate_half(v) * sin``; ``cos`` and
  ``sin`` must therefore broadcast against ``q`` and ``k``.
  Returns the rotated ``(q, k)`` pair.
  """
  q_rotated = (q * cos) + (rotate_half(q) * sin)
  k_rotated = (k * cos) + (rotate_half(k) * sin)
  return q_rotated, k_rotated