In [26]:
import torch
from torch import nn
from d2l import d2l
from torch.functional import F

In [57]:
class Model(nn.Module):
    def __init__(self, vocab_size, hidden_size, num_layers=1):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.RNN(self.vocab_size, self.hidden_size, self.num_layers)
        self.linear = nn.Linear(self.hidden_size, self.vocab_size)
        
    def forward(self, inputs, state):
        output, state = self.rnn(inputs, state)
        return self.linear(output), state

In [70]:
device = torch.device('cuda:0')
# device = torch.device('cpu')

batch_size = 128
num_steps = 10
train_iter, vocab = d2l.load_data_time_machine(batch_size=batch_size,
                                               num_steps=num_steps,
                                               token_type='char')


hidden_size = 1024
vocab_size = len(vocab)
num_layers = 4
net = Model(vocab_size=vocab_size, hidden_size=hidden_size, num_layers=num_layers)
net = net.to(device)

In [71]:
lr = 0.2
num_epochs = 1000

loss = nn.CrossEntropyLoss()
trainer = torch.optim.SGD(net.parameters(), lr=lr)

In [None]:
for epoch in range(num_epochs):
    state = None
    for X_idx, y_idx in train_iter:
        X = F.one_hot(X_idx.T, len(vocab)).to(torch.float32)
        y = F.one_hot(y_idx.T, len(vocab)).to(torch.float32)
        if state is None:
            state = torch.zeros(num_layers, batch_size, hidden_size)
        else:
            state.detach_()

        X = X.to(device)
        y = y.to(device)
        state = state.to(device)

        y_hat, state = net(X, state)
#         print(y.shape, y_hat.shape)
    #     print(X.shape)
    #     print(y.shape)
    #     print(y_hat.shape)
    #     print(state.shape)

        l = loss(y_hat, y)
        trainer.zero_grad()
        l.backward()
        trainer.step()
        print('epoch:', epoch, 'loss:', l.item(), end='\r')
    #     break

epoch: 585 loss: 9.7203321456909183