In [2]:
import torch
from torch import nn
import torch.nn.functional as F

### RMSNorm

In [3]:
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable gain.

    Args:
        hidden_size: Length of the learnable ``weight`` vector (initialised to
            ones); it is multiplied against the last dimension of the input.
        eps: Constant added to the mean square before the reciprocal square root.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def _norm(self, x):
        """Scale ``x`` by the reciprocal RMS of its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalise in float32, cast back to the input dtype, then apply the gain."""
        normalised = self._norm(x.to(torch.float32))
        return self.weight * normalised.to(x.dtype)

    # Alternative implementation (2)
    # def forward1(self, x):
    #     output = self._norm(x.float()).type_as(x)
    #     return self.weight * output
    
# RMSNorm test
# Smoke test: run RMSNorm on a random (2, 3) tensor and print input and output.
# No seed is set, so the printed values change on every run.
x = torch.rand(2, 3)
print(x)
r = RMSNorm(x.size(-1))  # hidden_size is the size of x's last dimension (3)
print(r(x))

tensor([[0.7842, 0.1649, 0.0308],
        [0.5024, 0.9393, 0.4886]])
tensor([[1.6937, 0.3562, 0.0665],
        [0.7425, 1.3882, 0.7222]], grad_fn=<MulBackward0>)


### RoPE

In [10]:
class LlamaRotaryEmbedding(torch.nn.Module):
    """Cache of cos/sin tables for rotary position embeddings (RoPE).

    The tables are built once for ``max_position_embeddings`` positions and are
    rebuilt when ``forward`` is asked for a longer sequence.

    Args:
        dim: Rotary dimension. For even ``dim``, ``dim // 2`` inverse frequencies
            are generated and duplicated, so each table row has ``dim`` columns.
        max_position_embeddings: Number of positions cached at construction.
        base: Base of the geometric progression of frequencies.
        device: Device on which the inverse frequencies are created.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # inv_freq[i] = base ** (-2i / dim)
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        # persistent=False keeps the buffer out of state_dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cached cos/sin tables for positions ``0 .. seq_len - 1``."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        # Outer product: freqs[p, i] = position p * inv_freq[i]
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        # Stored as [1, 1, seq_len, columns] so they broadcast over batch and heads.
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return the ``(cos, sin)`` tables for the first ``seq_len`` positions.

        Args:
            x: Tensor laid out as [bs, num_attention_heads, seq_len, head_size].
                Its dtype (and its device, when the cache must grow) is used
                for the returned tables.
            seq_len: Number of positions to return. When omitted it is taken
                from ``x.shape[-2]``.

        Returns:
            Tuple ``(cos, sin)``, each of shape [1, 1, seq_len, columns] in
            ``x.dtype``.
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            # The declared default used to crash on the comparison below
            # (`None > int`); fall back to the sequence axis of x instead.
            seq_len = x.shape[-2]

        # If seq_len exceeds the cached length, rebuild a larger RoPE cache;
        # otherwise just slice the existing cache.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

In [12]:
def rotate_half(x):
    """Split the last dimension at its midpoint and return ``cat(-back, front)``."""

    # This differs from the original RoPE formulation, which pairs interleaved
    # (even/odd) elements. The original logic would be:
    # x1 = x[..., 0::2]
    # x2 = x[..., 1::2]
    # res = torch.cat((x1, x2), dim=-1)
    # res[..., 0::2] = -x2
    # res[..., 1::2] = x1
    # return res

    midpoint = x.shape[-1] // 2
    front = x[..., :midpoint]
    back = x[..., midpoint:]
    return torch.cat((-back, front), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """Rotate ``q`` and ``k`` using the cos/sin rows selected by ``position_ids``.

    Returns the tuple ``(q_embed, k_embed)``.
    """

    def _rows_for_positions(table):
        # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
        flat = table.squeeze(1).squeeze(0)  # [seq_len, dim]
        return flat[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]

    cos_sel = _rows_for_positions(cos)
    sin_sel = _rows_for_positions(sin)

    def _rotate(t):
        return (t * cos_sel) + (rotate_half(t) * sin_sel)

    return _rotate(q), _rotate(k)

In [11]:
# Smoke-test the rotary cache: request the cos/sin tables for the first 4 positions.
# On this path (seq_len fits in the cache) only the dtype of `x` is consulted.
x = torch.randn(1,8,4,2)
rope = LlamaRotaryEmbedding(dim=8)
# Call the module itself rather than .forward() so nn.Module hooks are honoured.
cos, sin = rope(x, seq_len=4)
print(cos.shape)
print(cos)

torch.Size([1, 1, 4, 8])
tensor([[[[ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
            1.0000],
          [ 0.5403,  0.9950,  0.9999,  1.0000,  0.5403,  0.9950,  0.9999,
            1.0000],
          [-0.4161,  0.9801,  0.9998,  1.0000, -0.4161,  0.9801,  0.9998,
            1.0000],
          [-0.9900,  0.9553,  0.9996,  1.0000, -0.9900,  0.9553,  0.9996,
            1.0000]]]])


In [43]:
# Exponents 2i/dim that a rotary embedding raises `base` to, shown for dim = 128.
# arange(0, dim, 2) already has dim // 2 entries for even dim, so the slice is a no-op here.
dim = 128
torch.arange(0, dim, 2)[:(dim //2)].float() / dim

tensor([0.0000, 0.0156, 0.0312, 0.0469, 0.0625, 0.0781, 0.0938, 0.1094, 0.1250,
        0.1406, 0.1562, 0.1719, 0.1875, 0.2031, 0.2188, 0.2344, 0.2500, 0.2656,
        0.2812, 0.2969, 0.3125, 0.3281, 0.3438, 0.3594, 0.3750, 0.3906, 0.4062,
        0.4219, 0.4375, 0.4531, 0.4688, 0.4844, 0.5000, 0.5156, 0.5312, 0.5469,
        0.5625, 0.5781, 0.5938, 0.6094, 0.6250, 0.6406, 0.6562, 0.6719, 0.6875,
        0.7031, 0.7188, 0.7344, 0.7500, 0.7656, 0.7812, 0.7969, 0.8125, 0.8281,
        0.8438, 0.8594, 0.8750, 0.8906, 0.9062, 0.9219, 0.9375, 0.9531, 0.9688,
        0.9844])

In [47]:
import numpy as np
# Build complex numbers from polar coordinates:
#   z = magnitude * (cos(angle) + i * sin(angle))
# The magnitude tensor is deliberately not named `abs`: that would shadow the
# Python builtin for the rest of the kernel session.
magnitude = torch.tensor([1, 2], dtype=torch.float64)
angle = torch.tensor([np.pi / 2, 5 * np.pi / 4], dtype=torch.float64)
z = torch.polar(magnitude, angle)
z

tensor([ 6.1232e-17+1.0000j, -1.4142e+00-1.4142j], dtype=torch.complex128)

In [6]:
# Random (32, 128) tensor used as scratch input; the cell displays its shape.
x = torch.rand(32, 128)
x.size()

torch.Size([32, 128])

In [9]:
# Step through the half-split / concatenate pattern by hand to check shapes.
# NOTE(review): relies on `x` left over from an earlier cell; the recorded
# outputs ([32, 64] and [32, 128]) imply x was (32, 128) - confirm run order.
x1 = x[..., : x.shape[-1] // 2]  # first half of the last dimension
print(x1.size())
x2 = x[..., x.shape[-1] // 2 :]  # remaining half of the last dimension
torch.cat((-x2, x1), dim=-1).size()  # concatenation restores the original width

torch.Size([32, 64])


torch.Size([32, 128])

In [21]:
# Indexing with None inserts a new size-1 axis: two leading Nones turn the
# (1, 8, 4, 2) tensor into (1, 1, 1, 8, 4, 2). Note that this rebinds the global `x`.
x = torch.randn(1,8,4,2)
x[None, None, :]

tensor([[[[[[-0.9850, -2.2869],
            [ 1.4804, -0.7685],
            [ 0.5947, -0.8731],
            [-0.9780,  0.7070]],

           [[ 0.4824,  0.6247],
            [ 0.3330, -0.8440],
            [-2.0476,  1.5151],
            [ 0.2602,  0.1515]],

           [[ 0.3006, -0.2799],
            [ 1.6850,  1.4726],
            [-0.2695,  0.4490],
            [ 0.8820, -2.0848]],

           [[-0.0554, -1.0951],
            [-0.1584,  0.7699],
            [ 0.5442, -0.0554],
            [ 0.7802, -0.5489]],

           [[-1.0475,  1.6059],
            [-1.1929, -0.5152],
            [-0.0339,  1.4537],
            [ 1.2854,  0.8239]],

           [[-0.4184, -1.2918],
            [ 0.7324, -0.5918],
            [-0.7403,  0.6899],
            [ 0.7254,  0.4882]],

           [[-0.4869,  0.2572],
            [-0.2051,  0.6930],
            [-0.9310,  0.6159],
            [ 0.1253, -0.3576]],

           [[ 1.4549,  0.6704],
            [-1.4409,  0.6654],
            [-0.2931, -0.6