In [1]:
import torch
import torch.nn as nn


# model
class RNN(nn.Module):
    """Simple recurrent network unrolled step by step over a sequence.

    At each time step ``t`` the hidden state is updated as
    ``h_t = sigmoid(Wh(h_{t-1}) + Wx(x_t))`` with ``h_0 = 0``, and an output
    ``y_t = Wy(h_t)`` is produced.

    Args:
        hidden_dim: Size of the hidden state.
        input_dim: Number of features per time step of the input.
        output_dim: Number of features per time step of the output.
    """

    def __init__(self, hidden_dim=2, input_dim=2, output_dim=1):
        super().__init__()

        self.hidden_dim = hidden_dim
        # Hidden-to-hidden, input-to-hidden and hidden-to-output maps
        # (each nn.Linear carries its own bias).
        self.Wh = nn.Linear(hidden_dim, hidden_dim)
        self.Wx = nn.Linear(input_dim, hidden_dim)
        self.Wy = nn.Linear(hidden_dim, output_dim)
        # Activation applied to the hidden-state pre-activation.
        self.act = nn.Sigmoid()

    def forward(self, x):
        """Run the recurrence over every time step of ``x``.

        Args:
            x: Tensor of shape ``(batch_size, seq_length, input_dim)``.

        Returns:
            Tensor of shape ``(batch_size, seq_length, output_dim)`` holding
            the output produced at every time step.
        """
        batch_size, seq_length = x.size(0), x.size(1)

        if seq_length == 0:
            # torch.stack cannot stack an empty list; return an empty sequence.
            return x.new_zeros(batch_size, 0, self.Wy.out_features)

        # Initial hidden state of zeros. It is a plain tensor, not an
        # nn.Parameter: it is not meant to be trained. It is created with
        # x.new_zeros so it follows the input's device and dtype
        # (GPU / float64 inputs would otherwise fail).
        hidden = x.new_zeros(batch_size, self.hidden_dim)

        outs = []
        for t in range(seq_length):
            hidden = self.act(self.Wh(hidden) + self.Wx(x[:, t, :]))
            outs.append(self.Wy(hidden))

        # Stacking along dim 1 yields (batch_size, seq_length, output_dim)
        # directly, without a separate permute.
        return torch.stack(outs, dim=1)
    

### feed data to RNN ###
# Dimensions are defined once so the model's input_dim always matches the
# feature dimension of the generated data.
batch_size = 128
seq_len = 10
input_dim = 2

# Model is built before the data is sampled, keeping the same RNG order
# (weight init first, then the input draw).
model = RNN(hidden_dim=2, input_dim=input_dim, output_dim=1)
print(model)

# Standard-normal input of shape (batch_size, seq_len, input_dim).
x = torch.normal(0, 1, size=(batch_size, seq_len, input_dim))
# Printed explicitly so the shape is shown when run as a script too.
print(model(x).shape)

RNN(
  (Wh): Linear(in_features=2, out_features=2, bias=True)
  (Wx): Linear(in_features=2, out_features=2, bias=True)
  (Wy): Linear(in_features=2, out_features=1, bias=True)
  (act): Sigmoid()
)


torch.Size([128, 10, 1])