In [2]:
import torch
from torch import nn

In [3]:
# Names of the recurrent layer classes in `torch.nn` that `RNN` is allowed to wrap.
RNN_TYPES = ['RNN', 'LSTM', 'GRU']


class RNN(nn.Module):
    """A recurrent layer followed by a linear projection of every time step.

    Args:
        embedding_dim: Input feature size given to the recurrent layer.
        hidden_dim: Hidden size of the recurrent layer, and therefore the
            input size of the linear projection.
        output_dim: Output size of the linear projection.
        rnn_type: Name of the `torch.nn` recurrent class to instantiate;
            must be one of `RNN_TYPES`.
        bias: Whether the recurrent layer and the projection learn bias
            terms. Defaults to False.

    Raises:
        ValueError: If `rnn_type` is not listed in `RNN_TYPES`.
    """

    def __init__(self, embedding_dim, hidden_dim, output_dim,
                 rnn_type='RNN', bias=False):
        super().__init__()
        self.output_dim = output_dim

        # Validate with an explicit raise rather than `assert`: assertions are
        # removed under `python -O`, which would let the getattr below resolve
        # any attribute of `torch.nn`.
        if rnn_type not in RNN_TYPES:
            raise ValueError(f'Use one of the following: {str(RNN_TYPES)}')

        RnnCell = getattr(nn, rnn_type)
        # batch_first=True: the recurrent layer takes and returns tensors
        # whose leading dimension is the batch.
        self.rnn = RnnCell(embedding_dim, hidden_dim, batch_first=True, bias=bias)
        self.fc = nn.Linear(hidden_dim, output_dim, bias=bias)

    def forward(self, X):
        """Run `X` through the recurrent layer and project each time step.

        Args:
            X: Batch-first input whose last dimension is `embedding_dim`.

        Returns:
            The projection of the recurrent layer's per-step outputs; the last
            dimension has size `output_dim`. The final hidden state returned
            by the recurrent layer is discarded.
        """
        rnn_out, _ = self.rnn(X)
        fc_out = self.fc(rnn_out)
        return fc_out

In [4]:
# Configuration for the toy example; tune these values here.
rnn_type = 'RNN'                       # which recurrent layer class to build
vocab_size = 5_000                     # number of target classes
embedding_dim, hidden_dim = 100, 20    # input feature size / recurrent hidden size
batch_size, seq_len = 32, 15           # sequences per batch / time steps per sequence

In [7]:
# Seed the global RNG so the random inputs, targets and initial weights created
# in this cell are identical on every run.
torch.manual_seed(0)

# Random stand-in for embedded inputs, shaped (batch_size, seq_len, embedding_dim).
X = torch.randn(batch_size, seq_len, embedding_dim)
# One random class index in [0, vocab_size) per time step, as a flat 1-D tensor.
y = torch.randint(vocab_size, (batch_size * seq_len,))

# Build the model from the configured `rnn_type` instead of a hard-coded 'RNN',
# so changing the configuration actually changes the layer that is used.
rnn = RNN(embedding_dim=embedding_dim,
          hidden_dim=hidden_dim,
          output_dim=vocab_size,
          rnn_type=rnn_type)

cel = nn.CrossEntropyLoss()

In [8]:
# Collapse the batch and time dimensions so each row holds the class scores of
# one time step, matching the flat 1-D targets in `y`. No `.squeeze()` is
# needed: the explicit target shape already fixes the result.
out = rnn(X).view(batch_size * seq_len, vocab_size)
cel(out, y)

tensor(8.6194, grad_fn=<NllLossBackward>)