In [1]:
import torch
import math

In [54]:

def get_sinusoidal_embeddings(n_positions: int, dim: int, base: float = 10000.0) -> torch.Tensor:
    """Generate sinusoidal positional embeddings.

    For position ``p`` and frequency index ``i`` the table holds
    ``sin(p * w_i)`` in column ``2i`` and ``cos(p * w_i)`` in column ``2i + 1``,
    where ``w_i = base ** (-2i / dim)``. Sines and cosines are therefore
    *interleaved* along the last axis, not stored as two contiguous halves.

    Args:
        n_positions: Number of positions (rows) to generate.
        dim: Embedding width (columns). Odd widths are allowed; the final
            column is then a sine without a matching cosine.
        base: Base of the geometric frequency progression. The default
            (10000.0) reproduces the previous hard-coded behaviour.

    Returns:
        Tensor of shape ``(n_positions, dim)`` in the default floating-point
        dtype (float32 unless the global default was changed).
    """
    if dim == 0:
        # The frequency formula below divides by dim; an empty table is the only sensible result.
        return torch.zeros((n_positions, 0))
    position = torch.arange(n_positions, dtype=torch.float).unsqueeze(1)  # (n_positions, 1)
    # w_i = exp(-2i * ln(base) / dim) == base ** (-2i / dim); one entry per sin/cos pair.
    div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(base) / dim))
    sinusoidal_emb = torch.zeros((n_positions, dim))
    sinusoidal_emb[:, 0::2] = torch.sin(position * div_term)
    # With an odd dim there are ceil(dim/2) sine columns but only floor(dim/2)
    # cosine columns, so the last frequency has no cosine partner.
    sinusoidal_emb[:, 1::2] = torch.cos(position * div_term[: dim // 2])
    return sinusoidal_emb

def _rotate_pairs(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
    """Rotate each feature pair (x[..., 2i], x[..., 2i+1]) by the angle whose sine/cosine are sin[..., i]/cos[..., i]."""
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    rotated = torch.stack(
        (
            x_even * cos - x_odd * sin,  # new even features
            x_odd * cos + x_even * sin,  # new odd features
        ),
        dim=-1,
    )  # (..., dim // 2, 2)
    # Merging the trailing pair axis restores the original interleaved feature order.
    return rotated.flatten(-2)


def apply_rotary_position_embeddings(
    sinusoidal_pos: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    interleaved: bool = True,
):
    """Apply rotary position embeddings (RoPE) to query and key tensors.

    Every adjacent feature pair ``(x[2i], x[2i+1])`` is rotated by the angle
    whose sine/cosine are taken from ``sinusoidal_pos`` for frequency ``i``::

        out[2i]   = x[2i]   * cos_i - x[2i+1] * sin_i
        out[2i+1] = x[2i+1] * cos_i + x[2i]   * sin_i

    Args:
        sinusoidal_pos: Sine/cosine table whose last dimension equals the last
            dimension of ``q``/``k``. After halving, it must broadcast against
            them. For example, a ``(seq_len, head_dim)`` table row ``j``
            rotates ``q[..., j, :]`` when ``q`` is ``(..., seq_len, head_dim)``.
        q: Query tensor; its last dimension must be even.
        k: Key tensor; its last dimension must be even.
        interleaved: If True (default), ``sinusoidal_pos`` uses the layout built
            by ``get_sinusoidal_embeddings`` (sines in even columns, cosines in
            odd columns). If False, the first half of the last dimension holds
            the sines and the second half the cosines (``[sin | cos]``).

    Returns:
        ``(q_rot, k_rot)``: the rotated tensors, with the same shapes as ``q`` and
        ``k`` whenever ``sinusoidal_pos`` broadcasts into them.
    """
    if interleaved:
        # sin(p*w_0), cos(p*w_0), sin(p*w_1), ... -> separate the two strided halves.
        sin, cos = sinusoidal_pos[..., 0::2], sinusoidal_pos[..., 1::2]
    else:
        # [sin | cos] layout: contiguous halves.
        sin, cos = sinusoidal_pos.chunk(2, dim=-1)
    return _rotate_pairs(q, sin, cos), _rotate_pairs(k, sin, cos)


In [46]:
# Build a 1024-position x 768-channel table and check its shape.
emb = get_sinusoidal_embeddings(n_positions=1024, dim=768)
emb.size()

torch.Size([1024, 768])

In [47]:
emb  # display the table; torch elides the middle rows and columns

tensor([[ 0.0000,  1.0000,  0.0000,  ...,  1.0000,  0.0000,  1.0000],
        [ 0.8415,  0.5403,  0.8543,  ...,  0.9997, -0.8791,  0.4767],
        [ 0.9093, -0.4161,  0.8880,  ...,  0.9986, -0.8381, -0.5455],
        ...,
        [ 0.0176, -0.9998,  0.3572,  ...,  0.2463,  0.0208, -0.9998],
        [-0.8318, -0.5550, -0.6122,  ..., -0.1584,  0.8870, -0.4618],
        [-0.9165,  0.4001, -0.9937,  ...,  0.4197,  0.8310,  0.5563]])

In [48]:
emb[:, ::2].shape  # sine columns (even indices)

torch.Size([1024, 384])

In [49]:
emb[..., 1::2].shape  # cosine columns (odd indices)

torch.Size([1024, 384])

In [50]:
# Toy vector 1..10 for comparing strided slicing with chunk().
a = torch.arange(1, 11)
a[1::2]  # every second element, starting at index 1


tensor([ 2,  4,  6,  8, 10])

In [51]:
# chunk() splits into two *contiguous* halves, unlike the strided [0::2]/[1::2] slices,
# so chunking an interleaved sin/cos table (as built by get_sinusoidal_embeddings)
# would not separate the sines from the cosines.
torch.chunk(a, 2, dim=-1)

(tensor([1, 2, 3, 4, 5]), tensor([ 6,  7,  8,  9, 10]))

In [52]:
# Same sizes as `emb`, inspected row-wise: row 1 is position p = 1.
result = get_sinusoidal_embeddings(n_positions=1024, dim=768)
result[1]  # [sin(w_0), cos(w_0), sin(w_1), cos(w_1), ...] at p = 1


tensor([ 8.4147e-01,  5.4030e-01,  8.5434e-01,  5.1972e-01,  8.6699e-01,
         4.9832e-01,  8.7940e-01,  4.7608e-01,  8.9152e-01,  4.5298e-01,
         9.0331e-01,  4.2900e-01,  9.1471e-01,  4.0412e-01,  9.2567e-01,
         3.7832e-01,  9.3615e-01,  3.5159e-01,  9.4609e-01,  3.2391e-01,
         9.5541e-01,  2.9527e-01,  9.6407e-01,  2.6565e-01,  9.7198e-01,
         2.3505e-01,  9.7908e-01,  2.0347e-01,  9.8529e-01,  1.7090e-01,
         9.9052e-01,  1.3735e-01,  9.9470e-01,  1.0282e-01,  9.9773e-01,
         6.7315e-02,  9.9952e-01,  3.0865e-02,  9.9998e-01, -6.5122e-03,
         9.9900e-01, -4.4787e-02,  9.9647e-01, -8.3922e-02,  9.9230e-01,
        -1.2387e-01,  9.8636e-01, -1.6459e-01,  9.7855e-01, -2.0600e-01,
         9.6875e-01, -2.4803e-01,  9.5684e-01, -2.9061e-01,  9.4270e-01,
        -3.3363e-01,  9.2622e-01, -3.7699e-01,  9.0727e-01, -4.2056e-01,
         8.8573e-01, -4.6420e-01,  8.6150e-01, -5.0776e-01,  8.3446e-01,
        -5.5106e-01,  8.0452e-01, -5.9393e-01,  7.7

In [53]:
result[..., 0]  # first column: sin(p) for every position p (w_0 = 1)

tensor([ 0.0000,  0.8415,  0.9093,  ...,  0.0176, -0.8318, -0.9165])

In [55]:
dim = 768
# Recompute the frequencies used by get_sinusoidal_embeddings: w_i = 10000^(-2i/dim).
torch.exp(-math.log(10000.0) / dim * torch.arange(0, dim, 2).float())

tensor([1.0000e+00, 9.7630e-01, 9.5316e-01, 9.3057e-01, 9.0852e-01, 8.8699e-01,
        8.6596e-01, 8.4544e-01, 8.2540e-01, 8.0584e-01, 7.8674e-01, 7.6810e-01,
        7.4989e-01, 7.3212e-01, 7.1477e-01, 6.9783e-01, 6.8129e-01, 6.6515e-01,
        6.4938e-01, 6.3399e-01, 6.1897e-01, 6.0430e-01, 5.8997e-01, 5.7599e-01,
        5.6234e-01, 5.4901e-01, 5.3600e-01, 5.2330e-01, 5.1090e-01, 4.9879e-01,
        4.8697e-01, 4.7543e-01, 4.6416e-01, 4.5316e-01, 4.4242e-01, 4.3193e-01,
        4.2170e-01, 4.1170e-01, 4.0195e-01, 3.9242e-01, 3.8312e-01, 3.7404e-01,
        3.6517e-01, 3.5652e-01, 3.4807e-01, 3.3982e-01, 3.3177e-01, 3.2390e-01,
        3.1623e-01, 3.0873e-01, 3.0142e-01, 2.9427e-01, 2.8730e-01, 2.8049e-01,
        2.7384e-01, 2.6735e-01, 2.6102e-01, 2.5483e-01, 2.4879e-01, 2.4289e-01,
        2.3714e-01, 2.3152e-01, 2.2603e-01, 2.2067e-01, 2.1544e-01, 2.1034e-01,
        2.0535e-01, 2.0049e-01, 1.9573e-01, 1.9110e-01, 1.8657e-01, 1.8214e-01,
        1.7783e-01, 1.7361e-01, 1.6950e-

In [67]:
math.log(1e4)  # ln(10000): scales the frequency exponents in get_sinusoidal_embeddings

9.210340371976184