In [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns

In [31]:
class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, d_model, max_seq_len=2048, base=10000):
        super().__init__()
        # Calculate inverse frequency for RoPE
        inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model))
        print("Inv_Freq : ", inv_freq)
        
        t = torch.arange(max_seq_len).float()
        print("T : ", t)
        # Outer product of time steps and frequencies
        freqs = torch.outer(t, inv_freq)
        print("Freq : ", freqs)
        
        # Duplicate frequencies for sin and cos
        emb = torch.cat((freqs, freqs), dim=-1)
        print("sin : ", emb.sin())
        print("cos : ", emb.cos())
        
        # Register buffers to avoid repeated calculations
        self.register_buffer("cos_cached", emb.cos())
        self.register_buffer("sin_cached", emb.sin())

    def forward(self, x):
        # x shape: (Batch, Seq_Len, Heads, Head_Dim)
        seq_len = x.shape[1]
        print('seq len :', seq_len)
        
        # Slice cached values based on current sequence length
        cos = self.cos_cached[:seq_len, :].unsqueeze(0).unsqueeze(2)
        sin = self.sin_cached[:seq_len, :].unsqueeze(0).unsqueeze(2)
        print('sim under forward :', sin)
        print('cos under forward :', cos)
        
        # Apply rotation
        result = (x * cos) + (self._rotate_half(x) * sin)
        print('Result : ', result)
        return result

    def _rotate_half(self, x):
        # Split vector into two parts and rotate (-x2, x1)
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

In [27]:
d_model = 8
seq_len = 4
heads = 4
    
# Dummy Input (All ones to isolate positional effect)
q = torch.ones(1, seq_len, heads, d_model)
k = torch.ones(1, seq_len, heads, d_model)

In [32]:
# Apply RoPE
rope = RotaryPositionalEmbedding(d_model, max_seq_len=seq_len)
q_r, k_r = rope(q), rope(k)

# Calculate Attention Score (For Head 0 only)
score_rope = torch.matmul(q_r[0, :, 0, :], k_r[0, :, 0, :].T)

Inv_Freq :  tensor([1.0000, 0.1000, 0.0100, 0.0010])
T :  tensor([0., 1., 2., 3.])
Freq :  tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.0000e+00, 1.0000e-01, 1.0000e-02, 1.0000e-03],
        [2.0000e+00, 2.0000e-01, 2.0000e-02, 2.0000e-03],
        [3.0000e+00, 3.0000e-01, 3.0000e-02, 3.0000e-03]])
sin :  tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.8415, 0.0998, 0.0100, 0.0010, 0.8415, 0.0998, 0.0100, 0.0010],
        [0.9093, 0.1987, 0.0200, 0.0020, 0.9093, 0.1987, 0.0200, 0.0020],
        [0.1411, 0.2955, 0.0300, 0.0030, 0.1411, 0.2955, 0.0300, 0.0030]])
cos :  tensor([[ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000],
        [ 0.5403,  0.9950,  0.9999,  1.0000,  0.5403,  0.9950,  0.9999,  1.0000],
        [-0.4161,  0.9801,  0.9998,  1.0000, -0.4161,  0.9801,  0.9998,  1.0000],
        [-0.9900,  0.9553,  0.9996,  1.0000, -0.9900,  0.9553,  0.9996,  1.0000]])
seq len : 4
sim under forward : tens

In [33]:
score_rope

tensor([[8.0000, 7.0705, 5.1274, 3.9298],
        [7.0705, 8.0000, 7.0705, 5.1274],
        [5.1274, 7.0705, 8.0000, 7.0705],
        [3.9298, 5.1274, 7.0705, 8.0000]])

In [37]:
a = torch.tensor([1.0000, 0.1000, 0.0100, 0.0010])
b = torch.tensor([0., 1., 2., 3.])

print(torch.outer(b, a))

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.0000e+00, 1.0000e-01, 1.0000e-02, 1.0000e-03],
        [2.0000e+00, 2.0000e-01, 2.0000e-02, 2.0000e-03],
        [3.0000e+00, 3.0000e-01, 3.0000e-02, 3.0000e-03]])
