In [4]:
import collections
import torch
import math
from torch import nn
from d2l.torch import d2l


In [5]:
class seq2seqEncoder(d2l.Encoder):
    """GRU encoder for sequence-to-sequence learning.

    Token indices are embedded and then run through a (possibly multi-layer)
    GRU that uses PyTorch's default time-major layout.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_size: Dimension of each token embedding (the GRU input size).
        num_hiddens: Hidden size of the GRU.
        num_layers: Number of stacked GRU layers.
        dropout: Dropout probability handed to ``nn.GRU``.
        **kwargs: Forwarded unchanged to ``d2l.Encoder.__init__``.
    """

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
        super().__init__(**kwargs)
        # Keep the creation order (embedding first, GRU second) so that the
        # parameters are initialised in the same order under a fixed seed.
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(
            input_size=embed_size,
            hidden_size=num_hiddens,
            num_layers=num_layers,
            dropout=dropout,
        )

    def forward(self, X, *args):
        """Encode a minibatch of token indices.

        Args:
            X: Integer tensor of token indices, laid out batch-first
                (``(4, 7)`` in the notebook's demo).
            *args: Accepted and ignored.

        Returns:
            The ``(output, state)`` tuple produced by the GRU. For the demo
            input the shapes are ``(7, 4, 16)`` and ``(2, 4, 16)``, i.e.
            ``(num_steps, batch_size, num_hiddens)`` and
            ``(num_layers, batch_size, num_hiddens)``.
        """
        # Swap the first two axes so the time axis comes first, which is what
        # nn.GRU expects when batch_first is left at its default.
        time_major = self.embedding(X).permute(1, 0, 2)
        # No initial state is supplied, so the GRU starts from zeros.
        return self.rnn(time_major)

In [23]:
# Build a small two-layer encoder and put it in inference mode.
encoder = seq2seqEncoder(
    vocab_size=10,
    embed_size=8,
    num_hiddens=16,
    num_layers=2,
)
encoder.eval()

# Dummy minibatch of token index 0: 4 sequences, 7 time steps each.
X = torch.zeros(4, 7, dtype=torch.long)
output, state = encoder(X)
# Display both shapes as the cell result.
(output.shape, state.shape)

(torch.Size([7, 4, 16]), torch.Size([2, 4, 16]))

In [33]:
class Seq2SeqDecoder(d2l.Decoder):
    """RNN (GRU) decoder for sequence-to-sequence learning.

    At every time step the GRU input is the embedded target token
    concatenated with a context vector, namely the last-layer slice
    (``state[-1]``) of the state passed to ``forward``.

    Args:
        vocab_size: Size of the target vocabulary; also the width of the
            final linear layer.
        embed_size: Dimension of each token embedding.
        num_hiddens: Hidden size of the GRU. The incoming state is used both
            as the GRU's initial state and as the context, so it has to be
            compatible with ``num_layers`` and ``num_hiddens``.
        num_layers: Number of stacked GRU layers.
        dropout: Dropout probability handed to ``nn.GRU``.
        verbose: When True (the default, which keeps the original behaviour)
            ``forward`` prints the shape of every intermediate tensor. Pass
            ``verbose=False`` to silence the trace, e.g. inside a training
            or prediction loop where ``forward`` runs many times.
        **kwargs: Forwarded unchanged to ``d2l.Decoder.__init__``.
    """

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, verbose=True, **kwargs):
        super(Seq2SeqDecoder, self).__init__(**kwargs)
        self.verbose = verbose
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # The GRU consumes the embedding concatenated with the context
        # vector, hence the input width of embed_size + num_hiddens.
        self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers,
                          dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, *args):
        """Return the decoder's initial state.

        Args:
            enc_outputs: The ``(output, state)`` pair returned by the encoder.
            *args: Accepted and ignored.

        Returns:
            ``enc_outputs[1]``, i.e. the encoder's final hidden state.
        """
        return enc_outputs[1]

    def _print_shape(self, label, tensor):
        """Print ``label`` followed by ``tensor.shape`` if ``self.verbose``."""
        if self.verbose:
            print(label, tensor.shape)

    def forward(self, X, state):
        """Decode a minibatch of target tokens.

        Args:
            X: Integer tensor of token indices, laid out batch-first
                (``(4, 7)`` in the notebook's demo).
            state: Hidden state of shape
                ``(num_layers, batch_size, num_hiddens)``.

        Returns:
            A tuple ``(output, state)`` where ``output`` has shape
            ``(batch_size, num_steps, vocab_size)`` and ``state`` has shape
            ``(num_layers, batch_size, num_hiddens)``.

        Note:
            The trace labels spell out the shapes expected for the
            notebook's demo input (batch 4, 7 steps, embed 8, hidden 16,
            vocab 10); the shape printed after each label is the actual one.
        """
        # The embedding output is (batch_size, num_steps, embed_size);
        # permute it to the time-major layout the GRU expects.
        X = self.embedding(X).permute(1, 0, 2)
        self._print_shape('X ebd [7, 4, 8]', X)
        self._print_shape('state[-1] (4, 16)', state[-1])
        # Broadcast the context so that it has the same num_steps as X.
        context = state[-1].repeat(X.shape[0], 1, 1)
        self._print_shape('contxt (7, 4, 16)', context)

        # Concatenate along the feature axis.
        X_and_context = torch.cat((X, context), 2)
        self._print_shape('X_and_context (7, 4, 24)', X_and_context)
        output, state = self.rnn(X_and_context, state)
        self._print_shape('output (7, 4, 16)', output)
        output = self.dense(output)
        self._print_shape('fc output (7, 4, 10)', output)
        # Back to batch-first for the caller.
        output = output.permute(1, 0, 2)
        self._print_shape('permute output (4, 7, 10)', output)
        # Shape of output: (batch_size, num_steps, vocab_size)
        # Shape of state: (num_layers, batch_size, num_hiddens)
        return output, state


# Shape check for the decoder, reusing the dummy batch `X` and the `encoder`
# built in the earlier cell.
print('X (4, 7) ', X.shape)
decoder = Seq2SeqDecoder(
    vocab_size=10,
    embed_size=8,
    num_hiddens=16,
    num_layers=2,
)
decoder.eval()

# Seed the decoder with the encoder's final hidden state, then decode.
state = decoder.init_state(encoder(X))
output, state = decoder(X, state)
# Display both shapes as the cell result.
(output.shape, state.shape)


X (4, 7)  torch.Size([4, 7])
X ebd [7, 4, 8] torch.Size([7, 4, 8])
state[-1] (4, 16) torch.Size([4, 16])
contxt (7, 4, 16) torch.Size([7, 4, 16])
X_and_context (7, 4, 24) torch.Size([7, 4, 24])
output (7, 4, 16) torch.Size([7, 4, 16])
fc output (7, 4, 10) torch.Size([7, 4, 10])
permute output (4, 7, 10) torch.Size([4, 7, 10])


(torch.Size([4, 7, 10]), torch.Size([2, 4, 16]))