In [1]:
# Word Embedding示例
import math

import torch
import torch.nn as nn


class Embeddings(nn.Module):
    """Token embedding lookup whose output is scaled by sqrt(d_model).

    Args:
        d_model: Size of each embedding vector.
        vocab: Number of rows in the embedding table.
    """

    def __init__(self, d_model, vocab):
        super().__init__()

        self.emb = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        """Embed the indices in ``x`` and multiply by sqrt(d_model).

        Args:
            x: Integer index tensor accepted by ``nn.Embedding``.

        Returns:
            Tensor shaped like ``x`` with a trailing axis of size ``d_model``.
        """
        # Needs the module-level `math` import to be in scope when called.
        return self.emb(x) * math.sqrt(self.d_model)

In [4]:
# 三角函数位置编码
class SinusoidalPositionalEncoding(nn.Module):
    """Add a fixed sine/cosine positional table to the input, then apply dropout.

    Even columns of the table hold ``sin(pos * w_i)`` and odd columns hold
    ``cos(pos * w_i)`` with ``w_i = exp(-2i * ln(10000) / d_model)``.

    Args:
        d_model: Size of the last axis of the table (``d`` in the formula).
            Both even and odd values are supported.
        dropout: Probability passed to ``nn.Dropout``.
        max_len: Number of positions precomputed, i.e. the longest sequence
            that ``forward`` can handle.
    """

    def __init__(self, d_model, dropout, max_len=5000):
        super().__init__()

        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2) *
                             -(math.log(10000.0) / d_model))  # [ceil(d_model / 2)]
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # drop the surplus frequency instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        # Stored as a buffer: follows the module across devices, not a parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:, :x.size(1)])``.

        Args:
            x: Tensor whose axis 1 indexes positions and whose shape broadcasts
                against ``[1, x.size(1), d_model]``.
        """
        # The input plus its positional information is the final embedding.
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)



In [6]:
# RoPE实现
import torch


def precompute_freqs_cis(dim: int, end: int, constant: float = 10000.0):
    """Build the complex rotation factors ``cos(m*theta_i) + j*sin(m*theta_i)``.

    :param dim: size of the last axis of q/k (typically emb_dim / head_num)
    :param end: number of positions (sequence length)
    :param constant: base of the frequency schedule (10000 by default)
    :return: complex tensor of shape [end, dim // 2]; row ``m`` holds the unit
        complex numbers for position ``m``, built with
        ``torch.polar(a, t) = a * (cos(t) + j*sin(t))``
    """
    # theta_i = 1 / constant^(2i / dim) for i = 0 .. dim//2 - 1  -> [dim // 2]
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    theta = 1.0 / (constant ** exponents)

    # Positions m = 0 .. end - 1  -> [end]
    positions = torch.arange(end, device=theta.device)

    # angles[m, i] = m * theta_i  -> [end, dim // 2]
    angles = torch.outer(positions, theta).float()

    # Unit-magnitude complex numbers (complex64): cosine in the real part,
    # sine in the imaginary part.
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """View ``freqs_cis`` so it broadcasts against ``x``.

    ``freqs_cis`` must have shape ``(x.shape[1], x.shape[-1])``. The result
    keeps those two sizes at axis 1 and the last axis and uses size 1 on every
    other axis, e.g. [1, length, 1, d/2] for a 4-D ``x``.
    """
    rank = x.ndim
    assert rank > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    # Start from all-ones and fill in the two axes that must line up with x.
    target = [1] * rank
    target[1] = x.shape[1]
    target[-1] = x.shape[-1]
    return freqs_cis.view(*target)


def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor, ):
    """Rotate queries and keys by the precomputed complex factors (RoPE).

    Args:
        xq: Query tensor, e.g. [bs, length, head, d]; axis 1 indexes positions
            and the last axis must have even size.
        xk: Key tensor laid out like ``xq``.
        freqs_cis: Complex tensor of shape [length, d/2].

    Returns:
        ``(xq_out, xk_out)`` with the shapes and dtypes of ``xq`` and ``xk``.
    """
    # Pair adjacent channels and read each pair as one complex number:
    # [q0, q1, ..., q(d-1)] -> [q0 + j*q1, q2 + j*q3, ..., q(d-2) + j*q(d-1)]
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # [..., d/2]
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)  # e.g. [1, length, 1, d/2]

    # Complex multiplication performs the rotation, e.g. for the first pair:
    # (q0 + j*q1) * (cos(m*theta_0) + j*sin(m*theta_0))
    #   = q0*cos(m*theta_0) - q1*sin(m*theta_0)
    #     + j*(q1*cos(m*theta_0) + q0*sin(m*theta_0))
    # view_as_real appends a trailing (real, imag) axis of size 2. Flattening
    # the last two axes interleaves them back into
    # [re0, im0, re1, im1, ..., re(d/2-1), im(d/2-1)].
    # flatten(-2) rather than flatten(3): identical for 4-D inputs, but also
    # merges the right axes when the input rank is not 4.
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)


# Demo inputs laid out as (bs, length, head, d). randn is called in q, k, v
# order so the random stream is consumed in the same sequence.
q = torch.randn(2, 10, 12, 32)  # last axis holds [q0, q1, ..., q(d-1)]
k = torch.randn(2, 10, 12, 32)
v = torch.randn(2, 10, 12, 32)
# One complex rotation factor per (position, channel pair): [10, 16].
freqs_cis = precompute_freqs_cis(32, 10, 10000.0)

q_new, k_new = apply_rotary_emb(q, k, freqs_cis)
print(q_new)


tensor([[[[ 1.3704e+00,  1.3158e+00, -7.1316e-01,  ..., -3.5290e-01,
           -3.3729e-01, -4.0218e-01],
          [-3.4972e+00, -8.7002e-02, -5.3975e-01,  ...,  1.6898e+00,
            5.8790e-01,  3.7120e-01],
          [ 2.8556e-01, -4.7846e-01, -1.0758e+00,  ..., -6.4680e-01,
            1.4346e+00, -2.7216e-01],
          ...,
          [ 9.8175e-01,  6.4523e-01,  1.1556e+00,  ...,  1.1800e-01,
            1.8284e-01,  7.7066e-02],
          [ 5.1603e-01, -4.1746e-02,  1.3845e+00,  ..., -6.6929e-01,
           -1.8631e+00,  1.0389e+00],
          [-1.2169e+00,  1.3014e+00, -1.0875e+00,  ...,  1.7334e+00,
            4.7693e-01,  1.1021e+00]],

         [[-8.0751e-01,  1.0754e+00, -2.4096e+00,  ..., -1.2121e-02,
            1.7142e+00,  1.3487e+00],
          [-5.9276e-01, -6.8661e-01,  4.1122e-01,  ..., -1.5627e+00,
            3.7396e-01,  8.9711e-02],
          [-9.6383e-01,  8.8765e-01,  8.3595e-01,  ...,  8.4422e-01,
           -1.1636e-02,  4.5636e-01],
          ...,
     

In [8]:
# ALiBi实现

import math

import torch
from torch import nn


def get_relative_positions(seq_len: int) -> torch.Tensor:
    """Return the ``[seq_len, seq_len]`` matrix of signed offsets ``j - i``.

    Entry ``[i, j]`` is the column index minus the row index: zero on the
    diagonal, negative below it and positive above it.
    """
    positions = torch.arange(seq_len)
    # Row vector minus column vector broadcasts to the full offset matrix.
    return positions[None, :] - positions[:, None]


def get_alibi_slope(num_heads):
    """Return per-head ALiBi slopes as a tensor of shape [num_heads, 1, 1].

    With ``x = (2 ** 8) ** (1 / num_heads)``, head ``i`` (0-based) gets the
    slope ``1 / x ** (i + 1)``. The two trailing unit axes let the result
    broadcast against a [length, length] matrix.
    """
    base = (2 ** 8) ** (1 / num_heads)
    slopes = []
    for exponent in range(1, num_heads + 1):
        slopes.append(1 / base ** exponent)
    return torch.tensor(slopes)[:, None, None]


# Demo tensors laid out as (bs, length, head, d). randn is called in q, k, v
# order so the random stream is consumed in the same sequence.
bs, length, head, d = 2, 10, 12, 32
q = torch.randn(bs, length, head, d)  # last axis holds [q0, q1, ..., q(d-1)]
k = torch.randn(bs, length, head, d)
v = torch.randn(bs, length, head, d)
m = get_alibi_slope(head)  # per-head slopes, [head, 1, 1]
scale = math.sqrt(d)

# Move heads ahead of positions: query becomes (bs, head, length, d) and key
# becomes (bs, head, d, length), so query @ key is (bs, head, length, length).
query = q.transpose(1, 2)
key = k.permute(0, 2, 3, 1)

# [length, length] * [head, 1, 1] -> [head, length, length]; the leading unit
# axis broadcasts over the batch. No causal mask is applied in this demo, so
# keys to the right of the query receive a positive bias.
bias = (get_relative_positions(length) * m)[None]
score = query @ key / scale + bias

print(score)

tensor([[[[-1.8129,  0.4675,  1.6018,  ...,  5.5109,  5.5519,  7.4765],
          [-0.5826, -0.2571,  0.2350,  ...,  4.7676,  6.2256,  4.1429],
          [-2.2080, -1.1375,  0.1413,  ...,  2.5882,  3.4505,  2.4373],
          ...,
          [-5.7270, -3.2969, -3.4568,  ...,  0.5973,  0.9581,  0.6717],
          [-5.8621, -4.4651, -2.2802,  ..., -1.6880,  0.7928,  2.5104],
          [-4.0265, -6.6177, -3.1242,  ..., -1.3786, -1.4675,  0.1974]],

         [[ 0.6194,  0.7367,  1.2648,  ...,  3.5302,  3.2856,  5.8354],
          [-0.7855, -0.1711,  1.1610,  ...,  4.6536,  5.0511,  1.6273],
          [ 0.5123, -0.7731, -1.6922,  ...,  1.7563,  2.6126,  3.0371],
          ...,
          [-4.4562, -2.0013, -1.3560,  ...,  0.0248,  0.8473,  0.6862],
          [-2.8158, -2.2977, -2.0008,  ...,  0.0652, -0.9086, -1.0932],
          [-3.4691, -2.5648, -3.9151,  ...,  0.9297, -0.7748,  1.2525]],

         [[ 0.2255,  0.2522,  1.2905,  ...,  2.8877,  1.0149,  2.9579],
          [-1.3499, -0.1326,  