# Decoder

In [1]:
import torch 
from torch import nn
import torch.nn.functional as F
import math

In [2]:
class DecoderLayer(nn.Module):
  """
  One Transformer decoder layer in post-norm form.

  It applies three sublayers: self-attention over the decoder sequence,
  encoder-decoder (cross) attention, and a position-wise feed-forward
  network. Each sublayer is followed by dropout, a residual connection and
  LayerNorm.

  :d_model: model / embedding dimension
  :ffn_hidden: hidden dimension of the position-wise feed-forward network
  :n_head: number of attention heads
  :drop_prob: dropout probability
  """
  def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
    super(DecoderLayer, self).__init__()
    # Multi-head self-attention inside the decoder
    self.attention1 = MultiHeadAttention(d_model, n_head)
    self.norm1 = LayerNorm(d_model)
    self.dropout1 = nn.Dropout(drop_prob)
    # Encoder-decoder (cross) attention: queries come from the decoder,
    # keys/values come from the encoder output.
    self.cross_attention = MultiHeadAttention(d_model, n_head)
    self.norm2 = LayerNorm(d_model)
    # forward() applies dropout after cross-attention, so it must exist here
    # (it was previously missing and raised AttributeError).
    self.dropout2 = nn.Dropout(drop_prob)
    # Position-wise feed-forward network
    self.ffn = PositionwiseFeedForward(d_model, ffn_hidden, drop_prob)
    self.norm3 = LayerNorm(d_model)
    self.dropout3 = nn.Dropout(drop_prob)

  def forward(self, dec_output, enc_output, t_mask=None, s_mask=None):
    """
    Run one decoder layer.

    :dec_output: decoder input, (batch_size, seq_len, d_model)
    :enc_output: encoder output, (batch_size, seq_len, d_model); the encoder
                 length may differ from the decoder length (assumption,
                 depends on MultiHeadAttention)
    :t_mask: mask for decoder self-attention, (batch_size, seq_len, seq_len)
    :s_mask: mask for encoder-decoder attention
    :return: tensor combined residually with dec_output, presumably the same
             shape as dec_output (depends on the sublayer modules)
    """
    # 1) Decoder self-attention, then dropout, residual and norm
    residual = dec_output
    x = self.attention1(dec_output, dec_output, dec_output, t_mask)
    x = self.dropout1(x)
    x = self.norm1(x + residual)

    # 2) Cross-attention: query from the decoder, key/value from the encoder
    residual = x
    x = self.cross_attention(x, enc_output, enc_output, s_mask)
    x = self.dropout2(x)
    x = self.norm2(x + residual)

    # 3) Feed-forward network. The residual must be the output of the
    #    cross-attention sublayer, not the earlier self-attention output.
    residual = x
    x = self.ffn(x)
    x = self.dropout3(x)
    x = self.norm3(x + residual)
    return x

In [3]:
class Decoder(nn.Module):
  """
  Transformer decoder.

  It consists of a target embedding, a stack of DecoderLayer modules, and a
  final linear projection onto the decoder vocabulary.

  :dec_voc_size: decoder (target) vocabulary size
  :max_len: maximum sequence length
  :d_model: embedding / model dimension
  :ffn_hidden: hidden dimension of the feed-forward networks
  :n_head: number of attention heads
  :n_layers: number of stacked decoder layers
  :drop_prob: dropout probability
  :device: forwarded to TransformerEmbedding
  """
  def __init__(self, dec_voc_size, max_len, d_model, ffn_hidden, n_head, n_layers, drop_prob, device):
    super(Decoder, self).__init__()
    self.embedding = TransformerEmbedding(dec_voc_size, d_model, max_len, drop_prob, device)
    self.layers = nn.ModuleList([
      DecoderLayer(d_model, ffn_hidden, n_head, drop_prob)
      for _ in range(n_layers)
    ])
    self.fc = nn.Linear(d_model, dec_voc_size)

  def forward(self, dec, enc, t_mask=None, s_mask=None):
    """
    Decode the target input while attending to the encoder output.

    :dec: target-side input to be embedded; presumably token ids of shape
          (batch_size, tgt_len). Confirm against TransformerEmbedding.
    :enc: encoder output, used as key/value in every layer's cross-attention
    :t_mask: target mask for decoder self-attention
    :s_mask: source mask for encoder-decoder attention
    :return: unnormalized logits whose last dimension is dec_voc_size. No
             softmax is applied here.
    """
    # Embed the target input. The original code mistakenly embedded `enc`.
    x = self.embedding(dec)
    # Pass through the stacked decoder layers
    for layer in self.layers:
      x = layer(x, enc, t_mask, s_mask)
    # Linear projection to vocabulary size (logits)
    return self.fc(x)