In [1]:
import torch
import torch.nn as nn

In [2]:
class RNN(nn.Module):
    """Minimal recurrent cell assembled from two linear layers.

    One step joins the current data with the previous hidden state, projects
    the result to a new hidden state (``i2h``) and projects that hidden state
    to an output (``h2o``). No non-linearity is applied between the layers.
    """

    # Model constructors may take arguments; here they set the layer widths.
    def __init__(self, data_size, hidden_size, output_size):
        """Create the two linear layers.

        Args:
            data_size: number of features in the data fed in at each step.
            hidden_size: number of features in the hidden state.
            output_size: number of features in the output.
        """
        super(RNN, self).__init__()
        self.hidden_size = hidden_size

        # i2h sees data and hidden state side by side, hence the summed width.
        self.i2h = nn.Linear(data_size + hidden_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, output_size)

    def forward(self, data, last_hidden):
        """Run one step.

        ``data`` and ``last_hidden`` are concatenated along dimension 1.

        Returns:
            A ``(hidden, output)`` tuple: the new hidden state and the output
            computed from it.
        """
        combined = torch.cat((data, last_hidden), 1)
        new_hidden = self.i2h(combined)
        return new_hidden, self.h2o(new_hidden)


# Build a network taking 50 data features, with a 20-wide hidden state and
# 10 outputs, then print its layer summary.
rnn = RNN(data_size=50, hidden_size=20, output_size=10)
print(rnn)

RNN(
  (i2h): Linear(in_features=70, out_features=20, bias=True)
  (h2o): Linear(in_features=20, out_features=10, bias=True)
)


In [4]:
# Apply the same network over several timesteps, sum the per-step losses and
# backpropagate once through the whole unrolled computation.
loss_fn = nn.MSELoss()

batch_size = 10
TIMESTEPS = 5

# Read the layer widths off the model so the fake data always matches it,
# instead of repeating the constructor arguments as magic numbers.
hidden_size = rnn.hidden_size
data_size = rnn.i2h.in_features - hidden_size
output_size = rnn.h2o.out_features

# Create some fake data (seeded so a re-run draws the same values)
torch.manual_seed(0)
batch = torch.randn(batch_size, data_size)
hidden = torch.randn(batch_size, hidden_size)
target = torch.randn(batch_size, output_size)

# backward() adds into the parameters' existing .grad, so clear any gradients
# left by a previous run of this cell to keep it idempotent.
rnn.zero_grad()

loss = 0
for _ in range(TIMESTEPS):
    # The hidden state produced by one step is fed into the next.
    hidden, output = rnn(batch, hidden)
    loss += loss_fn(output, target)

loss.backward()

In [9]:
# Show the shape of every parameter tensor of the model, in the order the
# model yields them.
param = list(rnn.parameters())

for tensor in param:
    print(tensor.size())

torch.Size([20, 70])
torch.Size([20])
torch.Size([10, 20])
torch.Size([10])
