In [34]:
import torch
import torch.nn as nn
# Seed PyTorch's global RNG so the torch.randn draws in later cells are repeatable.
SEED = 42
torch.manual_seed(SEED)

<torch._C.Generator at 0x109dfcdd0>

In [35]:
# Notebook-wide constants.
#   d       - size of the last axis of the demo Q tensors built further down
#   max_pos - same name and value as the RoPE constructor default
#   base    - same name and value (as an int) as the RoPE constructor default
d, max_pos, base = 512, 4096, 10000

In [36]:
# Toy table: 4 angle indices (columns) for 3 positions (rows).
angles = torch.arange(4)
m = torch.arange(3)
m_angles = m[:, None] * angles[None, :]  # outer product, entry [p, k] = p * k
print(m_angles)

# Duplicate every angle into two neighbouring columns by writing the same
# matrix into the even-indexed and the odd-indexed columns of a wider table.
cache = torch.zeros(m_angles.shape[0], 2 * m_angles.shape[1])
print(m_angles[:, ::2])   # even-indexed columns of the angle table
print(m_angles[:, 1::2])  # odd-indexed columns of the angle table
cache[:, ::2] = cache[:, 1::2] = m_angles
print(cache)

tensor([[0, 0, 0, 0],
        [0, 1, 2, 3],
        [0, 2, 4, 6]])
tensor([[0, 0],
        [0, 2],
        [0, 4]])
tensor([[0, 0],
        [1, 3],
        [2, 6]])
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 1., 2., 2., 3., 3.],
        [0., 0., 2., 2., 4., 4., 6., 6.]])


In [37]:
# Broadcasting over the batch dimension.
#
# X   has shape (bs, seq_len, d)
# pos has shape (seq_len, d)
# Adding them applies the same `pos` table to every batch entry.
X = torch.zeros(2, 3, 4)
pos = torch.randn(3, 4)
result = X + pos
print(result[0])
print(result[1])

tensor([[ 0.3367,  0.1288,  0.2345,  0.2303],
        [-1.1229, -0.1863,  2.2082, -0.6380],
        [ 0.4617,  0.2674,  0.5349,  0.8094]])
tensor([[ 0.3367,  0.1288,  0.2345,  0.2303],
        [-1.1229, -0.1863,  2.2082, -0.6380],
        [ 0.4617,  0.2674,  0.5349,  0.8094]])


In [42]:
class RoPE(nn.Module):
    """Rotary position embedding (RoPE) backed by precomputed cos/sin tables.

    For position ``m`` and channel pair ``i`` the rotation angle is
    ``m * base ** (-2 * i / dim)``. Each angle is stored twice, in columns
    ``2 * i`` and ``2 * i + 1``, so the tables can be multiplied element-wise
    with an input whose last axis has size ``dim``.

    Args:
        dim: size of the last axis of the tensors to rotate; must be even.
        max_pos: number of positions for which angles are precomputed.
        base: base of the geometric progression of rotation frequencies.

    Attributes:
        cos: tensor of shape (max_pos, dim) holding the cosines.
        sin: tensor of shape (max_pos, dim) holding the sines.

    Raises:
        ValueError: if ``dim`` is odd.
    """

    def __init__(self, dim = 512, max_pos = 4096, base = 10000.0):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"dim must be even to form channel pairs, got {dim}")
        self.dim = dim
        self.base = base
        self.max_pos = max_pos

        m = torch.arange(0, self.max_pos, 1, dtype=torch.float32)  # positions
        i = torch.arange(0, self.dim // 2, 1)                      # pair indices
        theta = self.base ** (-2 * i / self.dim)                   # per-pair frequency
        m_theta = torch.outer(m, theta)                            # (max_pos, dim // 2)

        # Two separate tables. Binding both names to one tensor
        # (`cos = sin = zeros(...)`) would let the sine write below overwrite
        # the cosines, leaving cos identical to sin.
        cos = torch.zeros(self.max_pos, self.dim)
        sin = torch.zeros(self.max_pos, self.dim)
        cos[:, 0::2] = cos[:, 1::2] = torch.cos(m_theta)  # each cosine in both columns of its pair
        sin[:, 0::2] = sin[:, 1::2] = torch.sin(m_theta)  # each sine in both columns of its pair

        # Non-persistent buffers: they move with the module (.to / .cuda) but
        # are kept out of state_dict, which therefore stays empty as before.
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)

    def apply_rope(self, X):
        '''
            Rotate every channel pair of X by its position-dependent angle.

            input: X[bs, n_heads, seq_len, head_dim], head_dim == self.dim
            output: tensor of the same shape as X

            raises ValueError if head_dim != self.dim or seq_len > self.max_pos
        '''
        bs, n_heads, seq_len, d = X.shape
        if d != self.dim:
            raise ValueError(f"last dimension of X is {d}, expected {self.dim}")
        if seq_len > self.max_pos:
            raise ValueError(f"seq_len {seq_len} exceeds max_pos {self.max_pos}")

        # Companion tensor: each pair (x0, x1) becomes (-x1, x0).
        X_shift = torch.zeros_like(X)
        X_shift[..., 0::2] = -X[..., 1::2]
        X_shift[..., 1::2] = X[..., 0::2]

        # (x0, x1) -> (x0*cos - x1*sin, x1*cos + x0*sin); the two leading None
        # axes broadcast the tables over batch and heads.
        Y = self.cos[None, None, :seq_len, :] * X + \
            self.sin[None, None, :seq_len, :] * X_shift

        return Y

    def forward(self, X):
        """Alias of apply_rope so the module can be called directly."""
        return self.apply_rope(X)

# Build the module from the notebook-level constants so the tables always
# match the size of the Q tensors created from `d` in the cells below.
rope = RoPE(dim=d, max_pos=max_pos, base=base)

In [43]:
bs, seq_len = 3, 100

# Single-head Q: draw (bs, seq_len, d), then insert a heads axis of size 1.
Q = torch.randn(bs, seq_len, d)[:, None]
rope_q = rope.apply_rope(Q)

In [44]:
n_heads = 8

# Multi-head Q: every head carries a full d-sized feature vector here.
Q = torch.randn((bs, n_heads, seq_len, d))
rope_q = rope.apply_rope(Q)