In [46]:
import torch
from torch import nn
import collections
import math

## 编码器

In [47]:
class Encoder(nn.Module):
    """Base encoder interface for the encoder-decoder architecture.

    This class only fixes the interface: ``forward`` must be overridden by a
    subclass, the implementation here always raises.
    """

    def __init__(self, **kwargs):
        # Keyword arguments are handed to nn.Module untouched.
        super().__init__(**kwargs)

    def forward(self, X, *args):
        """Encode ``X``.

        Raises:
            NotImplementedError: always; subclasses provide the real logic.
        """
        raise NotImplementedError

In [48]:
class Seq2SeqEncoder(Encoder):
    """GRU-based recurrent encoder for sequence-to-sequence learning."""

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
        super().__init__(**kwargs)
        # Embedding layer: token ids -> vectors of width embed_size.
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(
            input_size=embed_size,
            hidden_size=num_hiddens,
            num_layers=num_layers,
            dropout=dropout,
        )

    def forward(self, X, *args):
        """Embed the token ids in ``X`` and run them through the GRU.

        Returns:
            A ``(output, state)`` tuple exactly as produced by the GRU.
        """
        # After embedding: (batch_size, num_steps, embed_size).
        embedded = self.embedding(X)
        # The recurrent layer expects the time-step axis first, so swap the
        # first two axes.
        time_major = embedded.permute(1, 0, 2)
        rnn_output, final_state = self.rnn(time_major)
        return (rnn_output, final_state)


In [49]:
"""example
X = torch.arange(24).reshape(2,3,4)
permute_X = X.permute(1, 0, 2)
transpose_X = X.transpose(1, 0)
X, permute_X, transpose_X
"""
"""output
(tensor([[[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11]],
 
         [[12, 13, 14, 15],
          [16, 17, 18, 19],
          [20, 21, 22, 23]]]),
 tensor([[[ 0,  1,  2,  3],
          [12, 13, 14, 15]],
 
         [[ 4,  5,  6,  7],
          [16, 17, 18, 19]],
 
         [[ 8,  9, 10, 11],
          [20, 21, 22, 23]]]),
 tensor([[[ 0,  1,  2,  3],
          [12, 13, 14, 15]],
 
         [[ 4,  5,  6,  7],
          [16, 17, 18, 19]],
 
         [[ 8,  9, 10, 11],
          [20, 21, 22, 23]]]))
"""

'output\n(tensor([[[ 0,  1,  2,  3],\n          [ 4,  5,  6,  7],\n          [ 8,  9, 10, 11]],\n \n         [[12, 13, 14, 15],\n          [16, 17, 18, 19],\n          [20, 21, 22, 23]]]),\n tensor([[[ 0,  1,  2,  3],\n          [12, 13, 14, 15]],\n \n         [[ 4,  5,  6,  7],\n          [16, 17, 18, 19]],\n \n         [[ 8,  9, 10, 11],\n          [20, 21, 22, 23]]]),\n tensor([[[ 0,  1,  2,  3],\n          [12, 13, 14, 15]],\n \n         [[ 4,  5,  6,  7],\n          [16, 17, 18, 19]],\n \n         [[ 8,  9, 10, 11],\n          [20, 21, 22, 23]]]))\n'

In [50]:
# Smoke test: build a small 2-layer encoder and feed it a batch of 4 sequences,
# each holding 7 token ids (all zeros). `encoder` and `X` stay in the notebook
# namespace and are reused by the decoder demo cell.
encoder = Seq2SeqEncoder(
    vocab_size=10,
    embed_size=8,
    num_hiddens=16,
    num_layers=2
)
# Evaluation mode, so dropout (if any) is inactive for the shape check.
encoder.eval()
X = torch.zeros((4, 7), dtype=torch.long)
output, state = encoder(X)
output.shape, state.shape   ## state has shape (num_layers, batch_size, num_hiddens)

(torch.Size([7, 4, 16]), torch.Size([2, 4, 16]))

## 解码器

In [51]:
class Decoder(nn.Module):
    """Base decoder interface for the encoder-decoder architecture.

    Both ``init_state`` and ``forward`` are placeholders that subclasses must
    override; here they always raise.
    """

    def __init__(self, **kwargs):
        # Keyword arguments are handed to nn.Module untouched.
        super().__init__(**kwargs)

    def init_state(self, enc_outputs, *args):
        """Derive the decoder's initial state from the encoder outputs.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def forward(self, X, state):
        """Decode ``X`` given ``state``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

In [52]:
class Seq2SeqDecoder(Decoder):
    """GRU-based recurrent decoder for sequence-to-sequence learning."""

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
        super().__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # The GRU input is the token embedding concatenated with a context
        # vector of width num_hiddens, hence the widened input size.
        self.rnn = nn.GRU(
            input_size=embed_size + num_hiddens,
            hidden_size=num_hiddens,
            num_layers=num_layers,
            dropout=dropout,
        )
        # Projects each GRU output vector to vocabulary-sized scores.
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, *args):
        """Return the second element of ``enc_outputs`` as the initial state."""
        return enc_outputs[1]

    def forward(self, X, state):
        """Decode token ids ``X`` starting from ``state``.

        Returns:
            ``(output, state)``: the dense-layer scores with the first two
            axes swapped back, and the state returned by the GRU.
        """
        # Embed, then put the time-step axis first for the GRU.
        embedded = self.embedding(X).permute(1, 0, 2)
        step_count = embedded.shape[0]
        # Take the last slice of `state` along its first axis and repeat it
        # once per time step so it can be concatenated to every input.
        context = state[-1].repeat(step_count, 1, 1)
        rnn_input = torch.cat((embedded, context), 2)
        rnn_output, new_state = self.rnn(rnn_input, state)
        # Score the vocabulary, then swap the first two axes back.
        scores = self.dense(rnn_output).permute(1, 0, 2)
        return scores, new_state

In [53]:
"""example
a = torch.arange(12).reshape(1, 3, 4)
a, a.repeat(2, 1, 1), a.repeat(2, 2, 1) 
"""
"""output
(tensor([[[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11]]]),
 tensor([[[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11]],
 
         [[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11]]]),
 tensor([[[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11],
          [ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11]],
 
         [[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11],
          [ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11]]]))
"""

'output\n(tensor([[[ 0,  1,  2,  3],\n          [ 4,  5,  6,  7],\n          [ 8,  9, 10, 11]]]),\n tensor([[[ 0,  1,  2,  3],\n          [ 4,  5,  6,  7],\n          [ 8,  9, 10, 11]],\n \n         [[ 0,  1,  2,  3],\n          [ 4,  5,  6,  7],\n          [ 8,  9, 10, 11]]]),\n tensor([[[ 0,  1,  2,  3],\n          [ 4,  5,  6,  7],\n          [ 8,  9, 10, 11],\n          [ 0,  1,  2,  3],\n          [ 4,  5,  6,  7],\n          [ 8,  9, 10, 11]],\n \n         [[ 0,  1,  2,  3],\n          [ 4,  5,  6,  7],\n          [ 8,  9, 10, 11],\n          [ 0,  1,  2,  3],\n          [ 4,  5,  6,  7],\n          [ 8,  9, 10, 11]]]))\n'

In [54]:
# Smoke test: build a decoder with the same sizes as the encoder demo, seed its
# state from the encoder's output on the existing batch `X`, and check shapes.
decoder = Seq2SeqDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
decoder.eval()
# init_state picks the encoder's final state out of the (output, state) tuple.
state = decoder.init_state(encoder(X))
output, state = decoder(X, state)
output.shape, state.shape

(torch.Size([4, 7, 10]), torch.Size([2, 4, 16]))

## 损失函数

In [55]:
def sequence_mask(X, valid_len, value=0):
    """Overwrite the irrelevant (out-of-length) entries of each sequence.

    For row ``i`` of ``X``, every position along axis 1 whose index is not
    below ``valid_len[i]`` is set to ``value``.

    Note:
        ``X`` is modified in place and the same tensor object is returned.
    """
    num_steps = X.size(1)
    positions = torch.arange(num_steps, dtype=torch.float32, device=X.device)
    # Broadcast (1, num_steps) against (batch, 1): True where the position is
    # still inside the valid length of that row.
    keep = positions.unsqueeze(0) < valid_len.unsqueeze(1)
    X[torch.logical_not(keep)] = value
    return X

In [56]:
"""example
a = torch.arange(12)[None,:]
a
"""
"""output
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11]])
"""

'output\ntensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11]])\n'

In [57]:
"""example
a = torch.arange(12).reshape(2,6)
a[0] = -1
a
"""
"""output
tensor([[-1, -1, -1, -1, -1, -1],
        [ 6,  7,  8,  9, 10, 11]])
"""

'output\ntensor([[-1, -1, -1, -1, -1, -1],\n        [ 6,  7,  8,  9, 10, 11]])\n'

In [58]:
# 2-D demo: keep the first 1 entry of row 0 and the first 2 of row 1; the rest
# become the default fill value 0. Note this rebinds the notebook-level `X`.
X = torch.tensor([[1,2,3],[4,5,6]])
sequence_mask(X, torch.tensor([1,2]))

tensor([[1, 0, 0],
        [4, 5, 0]])

In [59]:
# 3-D demo: masking acts on axis 1, so whole trailing vectors past the valid
# length are filled with the custom value -1.
X = torch.ones(2,3,4)
sequence_mask(X, torch.tensor([1,2]), value=-1)

tensor([[[ 1.,  1.,  1.,  1.],
         [-1., -1., -1., -1.],
         [-1., -1., -1., -1.]],

        [[ 1.,  1.,  1.,  1.],
         [ 1.,  1.,  1.,  1.],
         [-1., -1., -1., -1.]]])

In [60]:
class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
    """Softmax cross-entropy loss that ignores positions past the valid length.

    Expected shapes:
        pred      -> (batch_size, num_steps, vocab_size)
        label     -> (batch_size, num_steps)
        valid_len -> (batch_size,)
    """

    def forward(self, pred, label, valid_len):
        """Return one loss value per sequence.

        Masked positions contribute zero, but the mean is still taken over
        every position along axis 1.
        """
        # Per-element losses are needed so they can be weighted; this
        # overwrites the instance's `reduction` setting on every call.
        self.reduction = 'none'
        # 1 inside the valid length, 0 elsewhere.
        mask = sequence_mask(torch.ones_like(label), valid_len)
        # The parent loss wants the class axis second: (batch, vocab, steps).
        per_token_loss = super().forward(pred.permute(0, 2, 1), label)
        return (per_token_loss * mask).mean(dim=1)


In [61]:
# Sanity check with identical predictions for all 10 classes and valid lengths
# 4, 2 and 0 over 4 steps: the second sequence's loss is half the first's and
# the third's is zero (see the recorded output).
loss = MaskedSoftmaxCELoss()
loss(torch.ones(3, 4, 10), torch.ones((3,4), dtype=torch.long), torch.tensor([4,2,0]))

tensor([2.3026, 1.1513, 0.0000])