- CausalLM decoder 的流程是 input -> MHA -> FFN。其中 [MHA, FFN] 组成一个 block，整个模型由很多个 block 堆叠而成。
- FFN包括5个操作：升维度(h->4h)，激活函数，降维度(4h->h)，dropout和norm。
- 激活函数上，LLaMA 相对 GPT 的一个改进是把 GeLU 换成了 SwiGLU，因此多了一个门控（gate）投影矩阵。为了让参数量大致保持不变，FFN 的中间维度一般会从 4h 缩小为 4h * 2 / 3（即 8h/3）。
- normalization上，原版的 transformers 用 post-norm, 后面 gpt2, llama 系列用的是 pre-norm。
- normalization上，llama 系列一般用 RMSNorm 代替 GPT and transformers decoder 中的 LayerNorm。

In [5]:
import torch
import torch.nn as nn
import math

In [24]:
class DecoderBlock(nn.Module):
    """One post-norm causal Transformer decoder block: self-attention then FFN.

    Each sub-block adds its result to its input (residual connection) and
    applies LayerNorm to the sum.

    Args:
        hidden_dim: Size of the model (last) dimension of the input.
        num_heads: Number of attention heads; must divide ``hidden_dim``.
        dropout_p: Dropout probability used on the attention weights and on
            the FFN output.
    """

    def __init__(self, hidden_dim:int, num_heads: int, dropout_p: float=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        # MHA block: output projection, dropout and layer normalization,
        # but no activation function.
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        # Q, K and V are produced by one fused projection and split afterwards.
        self.qkv_proj = nn.Linear(hidden_dim, hidden_dim * 3)
        self.o_proj = nn.Linear(hidden_dim, hidden_dim)
        self.drop_att = nn.Dropout(dropout_p)
        self.norm_att = nn.LayerNorm(hidden_dim, eps=1e-5)

        # FFN block: up-projection (h -> 4h), activation, down-projection
        # (4h -> h), dropout, then normalization.
        self.up_proj = nn.Linear(hidden_dim, hidden_dim * 4)
        self.down_proj = nn.Linear(hidden_dim * 4, hidden_dim)
        self.act_ffn = nn.ReLU()
        self.drop_ffn = nn.Dropout(dropout_p)
        self.norm_ffn = nn.LayerNorm(hidden_dim, eps=1e-5)

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape (b, s, hidden_dim) into (b, num_heads, s, head_dim)."""
        return t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def att_block(self, X, att_mask):
        """Causal multi-head self-attention with residual connection and LayerNorm.

        Args:
            X: Input of shape (batch, seq, hidden_dim).
            att_mask: ``None`` or a mask broadcastable to
                (batch, num_heads, seq, seq); entries equal to 0 are masked
                out. The causal (lower-triangular) constraint is always
                applied on top of it.

        Returns:
            Tensor of shape (batch, seq, hidden_dim).
        """
        batch_size, seq_len, _ = X.shape
        Q, K, V = torch.split(self.qkv_proj(X), self.hidden_dim, dim=-1)

        q_state = self._split_heads(Q, batch_size, seq_len)
        k_state = self._split_heads(K, batch_size, seq_len)
        v_state = self._split_heads(V, batch_size, seq_len)

        # Scaled dot-product scores: (b, num_heads, s, s)
        att_value = (q_state @ k_state.transpose(-1, -2)) / math.sqrt(self.head_dim)

        # Build the causal mask on its own (s, s) grid and combine it with the
        # user mask by broadcasting. Calling .tril() on the user mask directly
        # would be wrong for masks such as (b, 1, 1, s), whose last two dims
        # are not a full (s, s) matrix.
        keep = torch.ones(seq_len, seq_len, dtype=torch.bool, device=X.device).tril()
        if att_mask is not None:
            keep = keep & (att_mask != 0)

        # Use the most negative finite value instead of -inf: rows where every
        # key is masked (e.g. left padding) then get finite, uniform weights
        # rather than NaN; all other rows are unaffected.
        att_value = att_value.masked_fill(~keep, torch.finfo(att_value.dtype).min)

        att_weight = self.drop_att(torch.softmax(att_value, dim=-1))  # (b, num_heads, s, s)
        o_state = att_weight @ v_state  # (b, num_heads, s, head_dim)
        # Merge heads back: (b, num_heads, s, head_dim) -> (b, s, hidden_dim)
        O = o_state.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        output = self.o_proj(O)

        return self.norm_att(X + output)

    def ffn_block(self, X):
        """Position-wise feed-forward network with residual connection and LayerNorm.

        Args:
            X: Input whose last dimension is hidden_dim.

        Returns:
            Tensor with the same shape as ``X``.
        """
        up = self.act_ffn(self.up_proj(X))
        down = self.down_proj(up)
        down = self.drop_ffn(down)
        return self.norm_ffn(X + down)

    def forward(self, X: torch.Tensor, att_mask=None) -> torch.Tensor:
        """Run the attention sub-block followed by the FFN sub-block.

        Args:
            X: Input of shape (batch, seq, hidden_dim).
            att_mask: Optional mask broadcastable to
                (batch, num_heads, seq, seq); 0 marks positions to ignore.

        Returns:
            Tensor of shape (batch, seq, hidden_dim).
        """
        att_output = self.att_block(X, att_mask)
        ffn_output = self.ffn_block(att_output)
        return ffn_output



In [25]:
# Smoke test: push a random batch through one decoder block.
x = torch.rand(3, 4, 64)  # (batch=3, seq=4, hidden_dim=64)
net = DecoderBlock(64, 8)  # hidden_dim=64, num_heads=8
# Per-sequence padding mask (1 = real token, 0 = padding), tiled from
# (batch, seq) to (batch, num_heads, seq, seq) = (3, 8, 4, 4).
mask = torch.tensor(
    [[1, 1, 1, 1], [1, 1, 0, 0], [1, 1, 1, 0]]
)[:, None, None, :].repeat(1, 8, 4, 1)
net(x, mask).shape