In [1]:
import torch

In [2]:
# Toy dimensions, kept tiny so every intermediate tensor can be printed in full.
batch_size, seq_len = 2, 3
embed_dim, n_heads = 4, 2

In [3]:
# Seed the global RNG before the first random draw so that the random input
# below (and any randomly initialised layers created afterwards) are the same
# on every Restart & Run All.
torch.manual_seed(0)

# Random input batch of shape (batch_size, seq_len, embed_dim), drawn in one call.
x = torch.rand(batch_size, seq_len, embed_dim)
x.shape

torch.Size([2, 3, 4])

In [4]:
from torch import nn

In [5]:
# Query, key and value projections, created in that order. Each one maps
# embed_dim input features to n_heads * embed_dim output features.
w_q, w_k, w_v = (
    nn.Linear(embed_dim, n_heads * embed_dim) for _ in range(3)
)

In [15]:
from math import pi  # NOTE(review): unused in the visible cells; kept in case a later cell relies on it

# Rotary-embedding frequencies theta_i = 10000^(-2i / embed_dim), one per pair
# of features. The exponent must be negative: frequencies start at 1 and decay
# towards 1/10000, so later pairs rotate more slowly with position.
angles = [10000 ** (-2 * i / embed_dim) for i in range(embed_dim // 2)]

# Outer product positions (seq_len, 1) @ frequencies (1, embed_dim // 2) gives
# the phase for every (position, pair); exp(1j * phase) is the unit complex
# number that rotates a (real, imag) feature pair by that phase.
rotations = torch.exp(
    1j * torch.matmul(
        torch.arange(seq_len).float().reshape(-1, 1),
        torch.tensor(angles).reshape(1, -1),
    )
)

In [16]:
# One complex rotation per (position, feature pair).
rotations.size()

torch.Size([3, 2])

In [17]:
# Insert singleton batch and head axes so the rotations broadcast against
# tensors laid out as (batch, seq_len, heads, embed_dim // 2).
rotations = torch.reshape(rotations, (1, seq_len, 1, embed_dim // 2))

In [26]:
def rotary_positional_embedding(q, k, rotations):
    """Rotate query and key features by position-dependent complex factors.

    The last dimension of each tensor is split into consecutive pairs, each
    pair is interpreted as the (real, imaginary) parts of a complex number,
    multiplied element-wise by ``rotations``, and then flattened back to the
    tensor's original shape.

    Args:
        q: Query tensor; its last dimension must be even so it can be paired.
        k: Key tensor; same requirement as ``q``.
        rotations: Complex tensor that broadcasts against the paired (complex)
            view of ``q`` and ``k``.

    Returns:
        A ``(q, k)`` tuple of rotated tensors with the same shapes as the inputs.
    """

    def _rotate(t):
        # (..., d) -> (..., d / 2, 2): the trailing axis holds (real, imag).
        pairs = t.view(*t.shape[:-1], -1, 2)
        rotated = torch.view_as_complex(pairs) * rotations
        # Back to real pairs, then merge them into the original last dimension.
        return torch.view_as_real(rotated).view(t.shape)

    return _rotate(q), _rotate(k)

In [27]:
# Split the projected features into heads:
# (batch, seq_len, n_heads * embed_dim) -> (batch, seq_len, n_heads, embed_dim).
# Use batch_size for both tensors rather than a hard-coded 2, so the cell keeps
# working when the batch size in the config cell changes.
q = w_q(x).view(batch_size, seq_len, n_heads, embed_dim)
k = w_k(x).view(batch_size, seq_len, n_heads, embed_dim)

In [28]:
# Apply the rotary embedding to queries and keys.
# NOTE: q and k are overwritten, so re-running this cell rotates them again.
q, k = rotary_positional_embedding(q=q, k=k, rotations=rotations)

In [39]:
# The rotation must leave both shapes unchanged.
tuple(t.shape for t in (q, k))

(torch.Size([2, 3, 2, 4]), torch.Size([2, 3, 2, 4]))

In [43]:
# Swap axes 1 and 2 so the last two axes are the ones the matrix product
# below operates on.
q_, k_ = (t.transpose(1, 2) for t in (q, k))

In [44]:
# Shapes after swapping axes 1 and 2.
tuple(t.shape for t in (q_, k_))

(torch.Size([2, 2, 3, 4]), torch.Size([2, 2, 3, 4]))

In [45]:
# Scaled dot-product scores: every query against every key over the last
# axis, divided by sqrt(embed_dim).
att = torch.matmul(q_, k_.transpose(-1, -2)) / (embed_dim ** 0.5)

In [31]:
# Same scores via einsum, working directly on the un-transposed q and k:
# contract the feature index d, keep batch b, head h, query q and key k.
att2 = torch.einsum('bqhd,bkhd->bhqk', q, k)
att2 = att2 / (embed_dim ** 0.5)

In [46]:
# Shape of the matmul-based scores.
att.size()

torch.Size([2, 2, 3, 3])

In [47]:
# Shape of the einsum-based scores; should match att.
att2.size()

torch.Size([2, 2, 3, 3])

In [52]:
# Lower-triangular matrix of ones (zeros above the diagonal), with two leading
# singleton axes so it broadcasts over the first two axes of the scores.
mask = torch.ones(seq_len, seq_len).tril()[None, None, :, :]

In [53]:
# Two singleton axes followed by the (seq_len, seq_len) triangle.
mask.size()

torch.Size([1, 1, 3, 3])

In [54]:
# Preview only (att is not reassigned): scores with -inf wherever mask is 0.
att.masked_fill(mask=(mask == 0), value=float('-inf'))

tensor([[[[-0.2417,    -inf,    -inf],
          [-0.0382, -0.1512,    -inf],
          [ 0.1328, -0.0423, -0.1108]],

         [[ 0.2528,    -inf,    -inf],
          [ 0.2026,  0.1229,    -inf],
          [ 0.2069,  0.1221,  0.0035]]],


        [[[-0.1954,    -inf,    -inf],
          [ 0.0657, -0.0714,    -inf],
          [ 0.1476,  0.1208, -0.0688]],

         [[ 0.1839,    -inf,    -inf],
          [ 0.2650,  0.1993,    -inf],
          [ 0.1955,  0.2624, -0.0083]]]], grad_fn=<MaskedFillBackward0>)

In [55]:
# Mask, then normalise over the last axis; the -inf entries get weight 0.
# NOTE: att is overwritten, so re-running this cell would apply the softmax
# to the weights rather than to the raw scores.
att = att.masked_fill(mask == 0, float('-inf')).softmax(dim=-1)

In [56]:
# Attention weights after the masked softmax: positions where mask == 0 get
# weight 0 and the remaining entries along the last axis sum to 1.
att

tensor([[[[1.0000, 0.0000, 0.0000],
          [0.5282, 0.4718, 0.0000],
          [0.3812, 0.3200, 0.2988]],

         [[1.0000, 0.0000, 0.0000],
          [0.5199, 0.4801, 0.0000],
          [0.3657, 0.3360, 0.2984]]],


        [[[1.0000, 0.0000, 0.0000],
          [0.5342, 0.4658, 0.0000],
          [0.3599, 0.3503, 0.2898]],

         [[1.0000, 0.0000, 0.0000],
          [0.5164, 0.4836, 0.0000],
          [0.3467, 0.3706, 0.2827]]]], grad_fn=<SoftmaxBackward0>)

In [64]:
# Project the input to values, then split the feature axis into heads:
# (batch, seq_len, n_heads * embed_dim) -> (batch, seq_len, n_heads, embed_dim).
v = w_v(x)
v = v.view(batch_size, seq_len, n_heads, embed_dim)
v.size()

torch.Size([2, 3, 2, 4])

In [69]:
# Weighted sum of values per head; v has axes 1 and 2 swapped first so its
# layout lines up with att for the matrix product.
y2 = torch.matmul(att, v.transpose(1, 2))

In [75]:
# Swap axes 1 and 2 back again.
# NOTE: y2 is overwritten, so each re-run of this cell toggles the layout.
y2 = torch.transpose(y2, 1, 2)

In [72]:
# Weighted sum of values in one einsum: contract the key index k between
# att (b, h, q, k) and v (b, k, h, d), leaving the result laid out as
# (b, q, h, d).
y = torch.einsum('bhqk,bkhd->bqhd', att, v)
y

tensor([[[[-0.3518, -0.0209,  0.1209,  0.8380],
          [-0.2491, -0.6931, -0.0226, -0.6212]],

         [[-0.2574,  0.0482,  0.1386,  0.6419],
          [-0.3487, -0.5973,  0.1157, -0.5271]],

         [[-0.3084, -0.0112,  0.1889,  0.5804],
          [-0.3979, -0.4648,  0.1492, -0.3965]]],


        [[[-0.1169,  0.1278,  0.0732,  0.5716],
          [-0.3541, -0.6073,  0.1750, -0.5830]],

         [[-0.1740, -0.0159,  0.0567,  0.5575],
          [-0.3966, -0.5161,  0.2535, -0.5247]],

         [[-0.1855, -0.0683,  0.0909,  0.4691],
          [-0.4448, -0.3442,  0.2976, -0.3843]]]], grad_fn=<ViewBackward0>)

In [77]:
# Shape of the einsum result.
y.size()

torch.Size([2, 3, 2, 4])

In [79]:
# Output projection: maps the concatenated heads (n_heads * embed_dim
# features) back down to embed_dim features.
w_o = nn.Linear(in_features=n_heads * embed_dim, out_features=embed_dim)

In [81]:
# Merge everything after the (batch, seq_len) axes into one feature axis,
# then apply the output projection.
w_o(torch.reshape(y, (batch_size, seq_len, -1)))

tensor([[[-0.8762, -0.2059,  0.5416,  0.0838],
         [-0.7809, -0.3210,  0.5421,  0.0079],
         [-0.7104, -0.3361,  0.5693, -0.0568]],

        [[-0.7624, -0.3870,  0.4831, -0.0102],
         [-0.7504, -0.4115,  0.4714, -0.0752],
         [-0.6530, -0.4433,  0.4762, -0.1635]]], grad_fn=<ViewBackward0>)