In [1]:
# 复现Transformer模型

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Literal

**TransformerEmbedding**

词嵌入层，将token id映射为 $d_{model}$ 维向量，并按照原论文将输出乘以 $\sqrt{d_{model}}$ 进行缩放。

In [None]:
class TransformerEmbedding(nn.Module):
    """Token embedding layer whose output is scaled by sqrt(d_model).

    Args:
        d_model: embedding dimension.
        vocab_size: number of rows in the embedding table.
    """

    def __init__(self, d_model, vocab_size):
        # nn.Module.__init__ must run before any submodule is assigned,
        # otherwise `self.emb = ...` raises AttributeError.
        super(TransformerEmbedding, self).__init__()
        self.emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
        self.d_model = d_model

    def forward(self, input_ids):
        """Look up ``input_ids`` and scale by sqrt(d_model); adds a trailing ``d_model`` axis."""
        return self.emb(input_ids) * math.sqrt(self.d_model)

**Positional Embedding**

Transformer中使用的位置编码为绝对位置编码，使用Sinusoidal函数生成：

$$
p_{i,2t} = \sin (i/10000^{2t/d})
$$

$$
p_{i,2t+1} = \cos (i/10000^{2t/d})
$$

In [3]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal (absolute) positional encoding.

    ``pe[i, 2t] = sin(i / 10000^(2t/d))`` and ``pe[i, 2t+1] = cos(i / 10000^(2t/d))``.

    Args:
        d_model: feature dimension of the inputs. Odd values are supported; the
            last column is then a sine column.
        max_len: maximum sequence length that can be encoded.
    """

    def __init__(self, d_model=768, max_len=512):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)  # positional encoding matrix
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)  # [max_len, 1]
        # 1 / 10000^(2t/d) computed in log space as exp(2t * -ln(10000) / d).
        # The scale factor must be applied *inside* exp.
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32)
            * (-math.log(10000.0) / d_model)
        )  # [ceil(d_model / 2)]
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Register directly as a buffer: it follows .to(device) and is saved in
        # state_dict, but is not trained. Assigning `self.pe` first would make
        # register_buffer raise KeyError ("attribute 'pe' already exists").
        self.register_buffer("pe", pe.unsqueeze(0))  # [1, max_len, d_model]

    def forward(self, x):
        """Add positional encodings to ``x``.

        Args:
            x: [batch, seq_len, d_model] with seq_len <= max_len.
        """
        return x + self.pe[:, : x.size(1)]

**MultiHeadAttention**

多头注意力机制，计算公式如下：
$$O_i = \text{softmax}{\frac{Q_iK_i^T}{\sqrt{d}}} V_i$$
其中 $Q_i$，$K_i$，$V_i$ 分别代表按注意力头拆分后第 $i$ 个头的查询、键、值矩阵，$d$ 为拆分后每个头的特征维度（即 $d_{model}/h$，$h$ 为头数）。
在具体实现过程中，注意dropout是作用在softmax之后的注意力权重上的。如果序列中存在padding，需要对注意力进行遮盖，具体做法是在softmax之前给需要遮盖的位置赋予一个非常小的负数，这样softmax之后这些位置的权重就会接近0。

In [12]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        d_model: model dimension; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout_prob: dropout probability applied to the attention weights
            after softmax.
    """

    def __init__(
        self,
        d_model,
        num_heads,
        dropout_prob,
    ):
        super(MultiHeadAttention, self).__init__()
        # Fail at construction time rather than on the first forward call.
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.dropout = nn.Dropout(dropout_prob)
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: [b, s_q, d_model].
            key: [b, s_k, d_model].
            value: [b, s_k, d_model].
            mask: optional. Positions where ``mask == 0`` are not attended to.
                A 3-D mask ``[b, s_q, s_k]`` is broadcast over the head axis.
                Any other mask must already broadcast against
                ``[b, h, s_q, s_k]``, e.g. ``[b, 1, 1, s_k]`` or ``[s_q, s_k]``.

        Returns:
            [b, s_q, d_model].
        """
        # Input projections.
        q = self.q_proj(query)  # [b, s_q, d_model]
        k = self.k_proj(key)  # [b, s_k, d_model]
        v = self.v_proj(value)  # [b, s_k, d_model]

        # Split into heads: [b, s, d_model] -> [b, h, s, head_dim].
        b, s_q, e = q.shape
        s_k = k.size(1)
        h, d = self.num_heads, self.head_dim
        q = q.reshape(b, s_q, h, d).transpose(1, 2)
        k = k.reshape(b, s_k, h, d).transpose(1, 2)
        v = v.reshape(b, s_k, h, d).transpose(1, 2)

        # Scaled dot-product scores. Scale by sqrt of the *per-head* dimension.
        att = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d)  # [b, h, s_q, s_k]
        if mask is not None:
            if mask.dim() == 3:
                # [b, s_q, s_k] -> [b, 1, s_q, s_k] so it broadcasts over heads.
                mask = mask.unsqueeze(1)
            # Use a large finite negative instead of -inf, so that a fully masked
            # row yields a uniform distribution instead of NaN.
            att = att.masked_fill(mask == 0, torch.finfo(att.dtype).min)
        att = torch.softmax(att, dim=-1)
        att = self.dropout(att)

        # Merge heads: [b, h, s_q, head_dim] -> [b, s_q, d_model].
        out = torch.matmul(att, v)
        out = out.transpose(1, 2).reshape(b, s_q, e)
        return self.out_proj(out)

**PositionWiseFeedForward**

位置感知前馈神经网络，对每个位置的向量单独地进行变换（各位置共享权重参数），而不是将序列的所有特征拼接起来进行前馈传播。它由两个线性层构成：第一个线性层通常会进行升维，然后经过激活函数和dropout，再通过第二个线性层映射回原来的维度。

In [None]:
class PositionWiseFeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    The same weights are applied independently at every position along the
    second-to-last axis.

    Args:
        in_feature_dim: input and output feature dimension.
        hidden_size: inner (expanded) dimension.
        dropout_prob: dropout probability applied after the activation.
    """

    def __init__(self, in_feature_dim=768, hidden_size=1024, dropout_prob=0.1):
        # Must initialise *this* class. The previous super(PositionalEncoding, self)
        # raised TypeError because self is not a PositionalEncoding.
        super(PositionWiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(in_feature_dim, hidden_size)
        self.linear2 = nn.Linear(hidden_size, in_feature_dim)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        """Apply the FFN to ``x`` of shape ``[..., in_feature_dim]``."""
        # ReLU and dropout commute (ReLU is positively homogeneous), so this
        # conventional order gives the same result as dropout-then-ReLU.
        return self.linear2(self.dropout(torch.relu(self.linear1(x))))

**LayerNorm**

层归一化，在特征维度上对特征进行缩放和偏移。公式如下：

$$\text{LayerNorm}(X) = \gamma \odot \frac{X-\mu}{\sigma+\epsilon} + \beta$$

其中 $\mu$ 是特征均值，$\sigma$ 是标准差：
$$\mu = \frac{1}{d}\sum_{i=1}^dx_i$$

$$\sigma=\sqrt{\frac{1}{d}\sum_{i=1}^d(x_i-\mu)^2}$$
$d$ 是特征维度。

$\gamma$ 是可学习的缩放参数，$\beta$ 是可学习的偏移参数，$\epsilon$ 确保不会除0。

In [8]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last axis.

    ``y = gamma * (x - mean) / (std + epsilon) + beta``, where ``std`` is the
    population (1/d) standard deviation over the last axis.

    Args:
        num_dim: size of the last (feature) axis.
        epsilon: small constant added to ``std`` to avoid division by zero.
    """

    def __init__(self, num_dim, epsilon=1e-6):
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(num_dim))  # learnable scale
        self.beta = nn.Parameter(torch.zeros(num_dim))  # learnable shift
        self.epsilon = epsilon

    def forward(self, x):
        """Normalise ``x`` of shape ``[..., num_dim]`` and return the result."""
        mean = x.mean(dim=-1, keepdim=True)  # [..., 1]
        centered = x - mean
        # Population std (divide by d), matching the formula and nn.LayerNorm;
        # torch.std defaults to the unbiased 1/(d-1) estimator.
        std = centered.pow(2).mean(dim=-1, keepdim=True).sqrt()
        return self.gamma * centered / (std + self.epsilon) + self.beta

**ResidualConnection**

残差连接，也就是结构图中的`Add&Norm`。在实现上包括两种方式：一种是PreNorm，先对输入进行LayerNorm，再送入子层，然后将子层的输出与原始输入相加；另一种是PostNorm，先将子层的输出和原始输入相加，再通过LayerNorm。PreNorm训练时梯度更稳定，适用于深层网络；PostNorm训练稳定性较差，但通常被认为最终性能更好。原始Transformer中采用的是PostNorm，而在更大规模的模型中通常使用PreNorm。

PreNorm：

$$
\bm{X} = \bm{X} + \text{SubLayer}(\text{LayerNorm}(\bm{X}))
$$

PostNorm：

$$
\bm{X} =  \text{LayerNorm}(\bm{X} +\text{SubLayer}(\bm{X}))
$$

In [None]:
class ResidualConnection(nn.Module):
    """Residual connection with layer normalisation ("Add & Norm").

    Args:
        num_dim: feature dimension normalised by the LayerNorm.
        mode: ``"PreNorm"`` computes ``x + sublayer(LayerNorm(x))``. Any other
            value, including ``None``, uses PostNorm:
            ``LayerNorm(x + sublayer(x))``.
    """

    def __init__(self, num_dim, mode: Literal["PreNorm", "PostNorm"] = "PostNorm"):
        super(ResidualConnection, self).__init__()
        self.mode = mode
        self.layernorm = LayerNorm(num_dim=num_dim)

    def forward(self, x, sublayer):
        """Apply ``sublayer`` (a callable on tensors shaped like ``x``) with a residual path."""
        if self.mode == "PreNorm":
            # PreNorm normalises the sublayer *input*; the residual stream itself
            # is left un-normalised.
            return x + sublayer(self.layernorm(x))
        return self.layernorm(x + sublayer(x))

**EncoderLayer**

Encoder的基础模块，结构如下图所示：
<div align="center">
<img src="img/TransformerEncoderLayer.png" width=250px align="center">
</div>

主要包含一个多头自注意力模块和一个前馈神经网络模块，每个模块外面各有一个残差连接（共两个）。

In [None]:
class EncoderLayer(nn.Module):
    """One Transformer encoder block: self-attention and FFN, each wrapped in Add & Norm.

    Args:
        d_model: model dimension.
        hidden_size: inner dimension of the feed-forward network.
        num_att_heads: number of attention heads.
        dropout_prob: dropout probability used by attention and the FFN.
        norm_mode: forwarded to both ResidualConnection modules.
    """

    def __init__(
        self,
        d_model=768,
        hidden_size=1024,
        num_att_heads=6,
        dropout_prob=0.1,
        norm_mode: Literal["PreNorm", "PostNorm"] = None,
    ):
        super(EncoderLayer, self).__init__()
        # Sub-block 1: multi-head self-attention + residual.
        self.multi_head_attention = MultiHeadAttention(
            d_model=d_model, num_heads=num_att_heads, dropout_prob=dropout_prob
        )
        self.res_connect1 = ResidualConnection(num_dim=d_model, mode=norm_mode)
        # Sub-block 2: position-wise feed-forward + residual.
        self.ffn = PositionWiseFeedForward(
            in_feature_dim=d_model, hidden_size=hidden_size, dropout_prob=dropout_prob
        )
        self.res_connect2 = ResidualConnection(num_dim=d_model, mode=norm_mode)

    def forward(self, x, mask):
        """Run both sub-blocks; ``mask`` is passed unchanged to self-attention."""

        def self_attend(hidden):
            # Query, key and value all come from the (possibly normalised) input.
            return self.multi_head_attention(hidden, hidden, hidden, mask)

        attended = self.res_connect1(x, self_attend)
        return self.res_connect2(attended, self.ffn)

**DecoderLayer**

Decoder的基础模块，结构图如下所示：
<div align="center">
<img src="img/TransformerDecoderLayer.png" width=250px>
</div>

主要包含三个部分：对target序列的掩码（因果）自注意力、以Encoder输出（source）作为键和值的交叉注意力，以及最后的前馈神经网络。每个部分外面各有一个残差连接。


In [14]:
class DecoderLayer(nn.Module):
    """One Transformer decoder block.

    The block runs masked self-attention, then cross-attention to the encoder
    output, then a feed-forward network. Each step is wrapped in Add & Norm.

    Args:
        d_model: model dimension.
        hidden_size: inner dimension of the feed-forward network.
        num_att_heads: number of attention heads.
        dropout_prob: dropout probability used by attention and the FFN.
        norm_mode: forwarded to all three ResidualConnection modules.
    """

    def __init__(
        self,
        d_model,
        hidden_size,
        num_att_heads,
        dropout_prob,
        norm_mode: Literal["PreNorm", "PostNorm"] = None,
    ):
        super(DecoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(
            d_model=d_model,
            num_heads=num_att_heads,
            dropout_prob=dropout_prob,
        )
        self.res_connect1 = ResidualConnection(
            num_dim=d_model,
            mode=norm_mode,
        )
        self.cross_attention = MultiHeadAttention(
            d_model=d_model,
            num_heads=num_att_heads,
            dropout_prob=dropout_prob,
        )
        self.res_connect2 = ResidualConnection(
            num_dim=d_model,
            mode=norm_mode,
        )
        self.ffn = PositionWiseFeedForward(
            in_feature_dim=d_model,
            hidden_size=hidden_size,
            dropout_prob=dropout_prob,
        )
        self.res_connect3 = ResidualConnection(
            num_dim=d_model,
            mode=norm_mode,
        )

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Decode ``tgt`` while attending to the encoder output ``src``.

        Args:
            src: encoder output, used as keys/values in cross-attention.
            tgt: decoder input, used as queries.
            src_mask: mask passed to cross-attention.
            tgt_mask: mask passed to self-attention (e.g. causal and padding).
        """
        # Masked self-attention over the target. The sublayer must use its own
        # argument (not `tgt`) so that PreNorm feeds it the normalised input.
        x = self.res_connect1(tgt, lambda h: self.self_attention(h, h, h, tgt_mask))
        # Cross-attention: queries from the decoder, keys/values from the encoder.
        x = self.res_connect2(x, lambda h: self.cross_attention(h, src, src, src_mask))
        x = self.res_connect3(x, self.ffn)
        return x

**Encoder**

由多个EncoderLayer构成

In [16]:
class Encoder(nn.Module):
    """Stack of ``num_layers`` EncoderLayer modules applied in sequence.

    Args:
        num_layers: number of stacked encoder layers.
        d_model: model dimension.
        hidden_size: inner dimension of each feed-forward network.
        num_att_heads: number of attention heads.
        dropout_prob: dropout probability.
        norm_mode: forwarded to every EncoderLayer.
    """

    def __init__(
        self,
        num_layers,
        d_model,
        hidden_size,
        num_att_heads,
        dropout_prob,
        norm_mode: Literal["PreNorm", "PostNorm"],
    ):
        # Required before assigning submodules.
        super(Encoder, self).__init__()
        # ModuleList needs an iterable of modules; build num_layers independent layers.
        self.layers = nn.ModuleList(
            [
                EncoderLayer(
                    d_model=d_model,
                    hidden_size=hidden_size,
                    num_att_heads=num_att_heads,
                    dropout_prob=dropout_prob,
                    norm_mode=norm_mode,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(self, x, sen_mask):
        """Feed ``x`` through every layer; ``sen_mask`` goes to each layer's self-attention."""
        for layer in self.layers:
            x = layer(x, sen_mask)
        return x

**Decoder**

由多个DecoderLayer堆叠而成

In [None]:
class Decoder(nn.Module):
    """Stack of ``num_layers`` DecoderLayer modules applied in sequence.

    Args:
        num_layers: number of stacked decoder layers.
        d_model: model dimension.
        hidden_size: inner dimension of each feed-forward network.
        num_att_heads: number of attention heads.
        dropout_prob: dropout probability.
        norm_mode: forwarded to every DecoderLayer.
    """

    def __init__(
        self,
        num_layers,
        d_model,
        hidden_size,
        num_att_heads,
        dropout_prob,
        norm_mode: Literal["PreNorm", "PostNorm"],
    ):
        # Required before assigning submodules.
        super(Decoder, self).__init__()
        # ModuleList needs an iterable of modules; build num_layers independent layers.
        self.layers = nn.ModuleList(
            [
                DecoderLayer(
                    d_model=d_model,
                    hidden_size=hidden_size,
                    num_att_heads=num_att_heads,
                    dropout_prob=dropout_prob,
                    norm_mode=norm_mode,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Decode ``tgt`` against encoder output ``src`` through all layers.

        Each layer's output becomes the next layer's target input, while
        ``src`` and both masks are shared by every layer.
        """
        x = tgt
        for layer in self.layers:
            x = layer(src, x, src_mask, tgt_mask)
        return x

**Predictor**

最后用来预测输出的模块，由一个线性层加log_softmax构成，输出词表上每个token的对数概率


In [18]:
class Predictor(nn.Module):
    """Output head: linear projection to the vocabulary followed by log_softmax.

    Args:
        d_model: input feature dimension.
        vocab_size: number of output classes (tokens).
    """

    def __init__(self, d_model, vocab_size):
        # super(Predictor) without `self` is unbound and never initialises the module.
        super(Predictor, self).__init__()
        self.cls = nn.Linear(in_features=d_model, out_features=vocab_size)

    def forward(self, x):
        """Return log-probabilities over the vocabulary: ``[..., d_model] -> [..., vocab_size]``."""
        return torch.log_softmax(self.cls(x), dim=-1)

**Transformer**

最后组装Transformer

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer.

    Source and target use separate embedding tables of the same ``vocab_size``
    and share one sinusoidal positional encoding.

    Args:
        d_model: model dimension.
        vocab_size: vocabulary size for both embeddings and the output head.
        hidden_size: inner dimension of the feed-forward networks.
        num_attention_heads: number of attention heads.
        num_encoder_layer: number of encoder layers.
        num_decoder_layer: number of decoder layers.
        dropout_prob: dropout probability.
        src_mask_idx: token id treated as padding in source sequences.
        tgt_mask_idx: token id treated as padding in target sequences.
        max_len: maximum sequence length supported by the positional encoding.
        norm_mode: ``"PreNorm"`` or ``"PostNorm"``, forwarded to every layer.
    """

    def __init__(
        self,
        d_model,
        vocab_size,
        hidden_size,
        num_attention_heads,
        num_encoder_layer,
        num_decoder_layer,
        dropout_prob,
        src_mask_idx,
        tgt_mask_idx,
        max_len,
        norm_mode,
    ):
        # super(Transformer) without `self` is unbound and never initialises the module.
        super(Transformer, self).__init__()
        # Positional encoding (shared by encoder and decoder).
        self.pos_embedding = PositionalEncoding(d_model=d_model, max_len=max_len)
        # Encoder embedding.
        self.encoder_embedding = TransformerEmbedding(
            d_model=d_model, vocab_size=vocab_size
        )
        # Encoder. The attribute name `Encoder` is kept for compatibility.
        self.Encoder = Encoder(
            num_layers=num_encoder_layer,
            d_model=d_model,
            hidden_size=hidden_size,
            num_att_heads=num_attention_heads,
            dropout_prob=dropout_prob,
            norm_mode=norm_mode,
        )
        # Decoder embedding.
        self.decoder_embedding = TransformerEmbedding(
            d_model=d_model, vocab_size=vocab_size
        )
        # Decoder.
        self.decoder = Decoder(
            num_layers=num_decoder_layer,
            d_model=d_model,
            hidden_size=hidden_size,
            num_att_heads=num_attention_heads,
            dropout_prob=dropout_prob,
            norm_mode=norm_mode,
        )
        # Output head.
        self.predictor = Predictor(d_model=d_model, vocab_size=vocab_size)
        # Dropout applied to (token embedding + positional encoding), as in the paper.
        self.dropout = nn.Dropout(dropout_prob)
        self.src_mask_idx = src_mask_idx
        self.tgt_mask_idx = tgt_mask_idx

    def make_padding_mask(self, q_input_ids, k_input_ids, q_pad_idx, k_pad_idx):
        """Build a pairwise padding mask.

        Args:
            q_input_ids: [b, l] query-side token ids.
            k_input_ids: [b, s] key-side token ids.
            q_pad_idx: padding id on the query side.
            k_pad_idx: padding id on the key side.

        Returns:
            Bool tensor [b, l, s], True where both positions are non-padding.
            Rows of padded queries are entirely False.
        """
        q_mask = q_input_ids.ne(q_pad_idx)  # [b, l]
        k_mask = k_input_ids.ne(k_pad_idx)  # [b, s]
        return q_mask.unsqueeze(-1) & k_mask.unsqueeze(-2)  # [b, l, s]

    @staticmethod
    def _key_padding_mask(input_ids, pad_idx):
        """Bool mask [b, 1, 1, s], True for non-padding keys; broadcasts over heads and queries."""
        return input_ids.ne(pad_idx)[:, None, None, :]

    @staticmethod
    def _causal_mask(seq_len, device):
        """Lower-triangular bool mask [seq_len, seq_len]: position i may attend to j <= i."""
        return torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).tril()

    def forward(self, src_input_ids, tgt_input_ids, src_mask=None):
        """Compute target log-probabilities.

        Args:
            src_input_ids: [b, s_src] source token ids.
            tgt_input_ids: [b, s_tgt] target (decoder input) token ids.
            src_mask: optional mask for encoder self-attention; it must broadcast
                against [b, h, s_src, s_src]. If None, it is derived from
                ``src_mask_idx``. Cross-attention always uses the mask derived
                from source padding.

        Returns:
            [b, s_tgt, vocab_size] log-probabilities.
        """
        # Key-only padding masks avoid fully masked query rows for padded queries.
        src_key_mask = self._key_padding_mask(src_input_ids, self.src_mask_idx)
        if src_mask is None:
            src_mask = src_key_mask
        tgt_len = tgt_input_ids.size(1)
        tgt_mask = self._key_padding_mask(
            tgt_input_ids, self.tgt_mask_idx
        ) & self._causal_mask(tgt_len, tgt_input_ids.device)  # [b, 1, s_tgt, s_tgt]

        src = self.dropout(self.pos_embedding(self.encoder_embedding(src_input_ids)))
        memory = self.Encoder(src, src_mask)

        tgt = self.dropout(self.pos_embedding(self.decoder_embedding(tgt_input_ids)))
        out = self.decoder(memory, tgt, src_key_mask, tgt_mask)
        return self.predictor(out)