In [1]:
import _collections
import math
import torch
from torch import nn
from d2l import torch as d2l

# 实现rnn编码器

In [2]:
class Seq2SeqEncoder(d2l.Encoder):
    """GRU encoder for sequence-to-sequence learning.

    Token indices are looked up in an embedding table and passed through a
    (possibly multi-layer) GRU. The encoder has no output layer; it only
    produces hidden states for a decoder to consume.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_size: Width of each embedding vector (GRU input size).
        num_hiddens: GRU hidden size.
        num_layers: Number of stacked GRU layers.
        dropout: Dropout probability forwarded to ``nn.GRU``.
        **kwargs: Forwarded to the base class constructor.
    """

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
        super().__init__(**kwargs)
        # Maps every token index to a dense vector of width embed_size.
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(
            input_size=embed_size,
            hidden_size=num_hiddens,
            num_layers=num_layers,
            dropout=dropout,
        )
        # Intentionally no output (projection) layer on the encoder side.

    def forward(self, X, *args):
        """Encode a batch of token indices.

        Args:
            X: Index tensor laid out as (batch_size, num_steps).
            *args: Ignored.

        Returns:
            A tuple ``(output, state)`` exactly as returned by the GRU:
            ``output`` is (num_steps, batch_size, num_hiddens) and ``state``
            is (num_layers, batch_size, num_hiddens).
        """
        embedded = self.embedding(X)
        # The GRU is time-major (batch_first is left at its default), so swap
        # the batch and time axes before feeding it.
        time_major = embedded.permute(1, 0, 2)
        # No initial state is passed, so the GRU starts from zeros.
        return self.rnn(time_major)

# 实例化上述编码器，并检查输出的形状

In [8]:
# Instantiate a small two-layer encoder and push a dummy batch through it
# to inspect the output shapes.
encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
encoder.eval()
X = torch.zeros((4, 7), dtype=torch.long)
# X is laid out as (batch_size, sentence length)
output, state = encoder(X)
output.shape

torch.Size([7, 4, 16])

In [9]:
# Final hidden state of the encoder: (num_layers, batch_size, num_hiddens).
state.shape

torch.Size([2, 4, 16])

# 实现rnn解码器

In [15]:
class Seq2SeqDecoder(d2l.Decoder):
    """GRU decoder for sequence-to-sequence learning.

    At every time step the token embedding is concatenated with a context
    vector (the top-layer hidden state handed in as ``state``), run through
    a GRU, and projected onto the vocabulary by a linear layer.

    Args:
        vocab_size: Size of the embedding table and of the output projection.
        embed_size: Width of each embedding vector.
        num_hiddens: GRU hidden size (also the width of the context vector).
        num_layers: Number of stacked GRU layers.
        dropout: Dropout probability forwarded to ``nn.GRU``.
        **kwargs: Forwarded to the base class constructor.
    """

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
        super().__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # The GRU input is the embedding concatenated with the context
        # vector, hence the embed_size + num_hiddens input width.
        self.rnn = nn.GRU(
            input_size=embed_size + num_hiddens,
            hidden_size=num_hiddens,
            num_layers=num_layers,
            dropout=dropout,
        )
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, *args):
        """Return the element at index 1 of ``enc_outputs``.

        ``enc_outputs`` is expected to be the encoder's ``(output, state)``
        pair, so this picks the encoder's final hidden state.
        """
        return enc_outputs[1]

    def forward(self, X, state):
        """Decode a batch of token indices.

        Args:
            X: Index tensor laid out as (batch_size, num_steps).
            state: Hidden state of shape (num_layers, batch_size, num_hiddens).

        Returns:
            A tuple ``(output, state)`` where ``output`` is
            (batch_size, num_steps, vocab_size) and ``state`` is
            (num_layers, batch_size, num_hiddens).
        """
        # After the permute the embeddings are time-major:
        # (num_steps, batch_size, embed_size).
        embedded = self.embedding(X).permute(1, 0, 2)
        num_steps = embedded.shape[0]
        # Tile the top-layer hidden state along the time axis so that every
        # step sees the same context vector.
        context = state[-1].repeat(num_steps, 1, 1)
        rnn_input = torch.cat([embedded, context], dim=2)
        hidden_seq, state = self.rnn(rnn_input, state)
        # Project to vocabulary scores and go back to batch-major layout.
        scores = self.dense(hidden_seq).permute(1, 0, 2)
        return scores, state

# 实例化解码器

In [16]:
# Instantiate a decoder with the same hyper-parameters as the encoder and
# decode the same dummy batch, seeding it with the encoder's final state.
decoder = Seq2SeqDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
decoder.eval()
state = decoder.init_state(encoder(X))
output, state = decoder(X, state)
output.shape, state.shape
# output is (batch_size, num_steps, vocab_size): one prediction per sample per time step

(torch.Size([4, 7, 10]), torch.Size([2, 4, 16]))

# 通过零值化屏蔽不相关的项

In [26]:
def sequence_mask(X, valid_len, value=0):
    """Overwrite the entries of ``X`` that lie beyond each row's valid length.

    For row ``i``, every position ``j >= valid_len[i]`` along dimension 1 is
    set to ``value``. ``X`` is modified in place and also returned.

    Args:
        X: Tensor whose dimension 1 is the sequence axis.
        valid_len: 1-D tensor holding one valid length per row of ``X``.
        value: Fill value written to the masked positions.

    Returns:
        The same tensor object ``X`` after masking.
    """
    num_steps = X.size(1)
    positions = torch.arange(num_steps, dtype=torch.float32, device=X.device)
    # Broadcast (1, num_steps) against (batch, 1): True marks positions to keep.
    keep = positions.unsqueeze(0) < valid_len.unsqueeze(1)
    X[~keep] = value
    return X

# Demo: keep the first 1 and 2 entries of the two rows, zero out the rest.
# Note that sequence_mask writes into X in place, so X itself is changed.
X = torch.tensor([[1, 2, 3], [4, 5, 6]])
sequence_mask(X, torch.tensor([1, 2]))

tensor([[1, 0, 0],
        [4, 5, 0]])

# 通过扩展softmax交叉熵损失函数来遮蔽不相关的预测

In [26]:
class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
    """带屏蔽的softmax交叉熵损失函数"""

    def forward(self, pred, label, valid_len):
        weights = torch.ones_like(label)