In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from d2l import torch as d2l

In [2]:
# Define the RNN layer and run it once on random input to inspect output shapes.
vocab_size = 128  # vocabulary size; with a real corpus this would be len(vocab)
num_steps, batch_size = 35, 32  # sequence length and minibatch size
num_hiddens = 256
rnn_layer = nn.RNN(input_size=vocab_size, hidden_size=num_hiddens)
# Initial hidden state has shape (num_layers * num_directions, batch_size, num_hiddens).
# Derive it from the layer so it stays valid if num_layers/bidirectional are changed.
num_directions = 2 if rnn_layer.bidirectional else 1
state = torch.zeros((num_directions * rnn_layer.num_layers, batch_size, num_hiddens))

# Time-major input (num_steps, batch_size, vocab_size): nn.RNN defaults to batch_first=False.
X = torch.rand(size=(num_steps, batch_size, vocab_size))
Y, state_new = rnn_layer(X, state)
assert Y.shape == (num_steps, batch_size, num_directions * num_hiddens)
assert state_new.shape == state.shape

In [3]:
# RNN model: a recurrent layer followed by a vocabulary-sized linear output layer.
class RNNModel(nn.Module):
    """Wrap a PyTorch recurrent layer with a linear projection onto the vocabulary.

    Args:
        rnn_layer: an ``nn.RNN``, ``nn.GRU`` or ``nn.LSTM`` instance. Inputs are
            one-hot encoded, so its ``input_size`` is expected to equal ``vocab_size``.
        vocab_size: number of tokens in the vocabulary.
    """

    def __init__(self, rnn_layer, vocab_size):
        super().__init__()
        self.rnn = rnn_layer
        self.vocab_size = vocab_size
        self.num_hiddens = self.rnn.hidden_size
        # A bidirectional layer concatenates forward and backward outputs,
        # so num_directions is 2 and the projection takes twice the hidden size.
        self.num_directions = 2 if self.rnn.bidirectional else 1
        self.linear = nn.Linear(self.num_hiddens * self.num_directions, self.vocab_size)
        # Most recent hidden state produced by forward(); None until the first call.
        self.state = None

    def forward(self, inputs, state):
        """Run the recurrent layer over a batch of token indices.

        Args:
            inputs: integer token indices of shape (batch_size, num_steps); they are
                transposed to time-major order (assumes the wrapped layer uses the
                default ``batch_first=False``).
            state: initial hidden state, e.g. from ``begin_state``.

        Returns:
            A tuple ``(output, state)`` where ``output`` has shape
            (num_steps * batch_size, vocab_size) and ``state`` is the final hidden state.
        """
        # one_hot returns int64; recurrent layers require floating input, so cast
        # to the dtype of the model's parameters.
        X = F.one_hot(inputs.T.long(), self.vocab_size).to(self.linear.weight.dtype)
        Y, self.state = self.rnn(X, state)
        # Flatten time and batch: Y becomes (num_steps * batch_size, num_directions * num_hiddens),
        # and the linear layer maps it to (num_steps * batch_size, vocab_size).
        output = self.linear(Y.reshape((-1, Y.shape[-1])))
        return output, self.state

    def begin_state(self, device, batch_size=1):
        """Return an all-zero initial state for the wrapped layer.

        Each tensor has shape (num_directions * num_layers, batch_size, num_hiddens).
        For ``nn.LSTM`` a ``(hidden, cell)`` tuple of two such tensors is returned;
        for other layers (``nn.RNN``, ``nn.GRU``) a single tensor is returned.
        """
        shape = (self.num_directions * self.rnn.num_layers, batch_size, self.num_hiddens)
        if isinstance(self.rnn, nn.LSTM):
            return (torch.zeros(shape, device=device),
                    torch.zeros(shape, device=device))
        return torch.zeros(shape, device=device)