In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [50]:
class Network(nn.Module):
    """Stacked (unidirectional) RNN followed by a linear head applied to every timestep.

    Args:
        input_size: number of features per timestep in the input.
        output_size: number of features produced by the linear head.
        hidden_dim: size of the RNN hidden state (also the head's input size).
        n_layers: number of stacked RNN layers.
    """

    def __init__(self, input_size, output_size, hidden_dim, n_layers):
        super().__init__()

        self.hidden_dim = hidden_dim

        # batch_first=True: the first dimension of input and output tensors is the batch
        self.rnn = nn.RNN(input_size, hidden_dim, n_layers, batch_first=True)

        self.fc = nn.Linear(hidden_dim, output_size)

    def forward(self, x, hidden=None):
        """Run the RNN over ``x`` and apply the linear head to every timestep.

        Args:
            x: input of shape (batch, seq_len, input_size).
            hidden: initial hidden state of shape (n_layers, batch, hidden_dim),
                or None to let nn.RNN start from an all-zero state.

        Returns:
            Tuple ``(r_out, hidden, output)``:
              r_out  -- RNN outputs for every timestep, (batch, seq_len, hidden_dim)
              hidden -- final hidden state, (n_layers, batch, hidden_dim)
              output -- head outputs flattened over batch and time,
                        (batch * seq_len, output_size); row ``b * seq_len + t``
                        corresponds to batch element b at timestep t.
        """
        r_out, hidden = self.rnn(x, hidden)

        # Merge batch and time so the head sees one row per timestep.
        r_out_flat = r_out.reshape(-1, self.hidden_dim)
        output = self.fc(r_out_flat)

        return r_out, hidden, output

In [51]:
# Seed before building the model so the weight initialisation (and, when the
# cells run in order, the random inputs drawn below) is reproducible.
SEED = 0
torch.manual_seed(SEED)

model = Network(input_size=10, output_size=5, hidden_dim=3, n_layers=2)
model

Network(
  (rnn): RNN(10, 3, num_layers=2, batch_first=True)
  (fc): Linear(in_features=3, out_features=5, bias=True)
)

In [52]:
# input  -> (batch, seq_len, input_size)
# hidden -> (num_layers * num_directions, batch, hidden_dim)
# Feature, layer and hidden sizes are read from the model so these tensors
# stay consistent with it if the Network arguments are changed.
BATCH_SIZE = 5
SEQ_LEN = 4
num_directions = 2 if model.rnn.bidirectional else 1

input_tensor = torch.randn((BATCH_SIZE, SEQ_LEN, model.rnn.input_size))
hidden = torch.randn((model.rnn.num_layers * num_directions, BATCH_SIZE, model.rnn.hidden_size))

In [55]:
# rnn output -> (batch, seq_len, hidden_dim)
# hidden_out -> (num_layers * num_directions, batch, hidden_dim)
# fc output  -> (batch * seq_len, output_size)
# Run the forward pass once and unpack the returned (r_out, hidden, output)
# tuple, rather than calling the model once per element.
out, h_out, fc_out = model(input_tensor, hidden)
print("Shape of rnn_output:", out.shape)
print("Shape of hidden state:", h_out.shape)
print("Shape of final layer output:", fc_out.shape)

Shape of rnn_output: torch.Size([5, 4, 3])
Shape of hidden state: torch.Size([2, 5, 3])
Shape of final layer output: torch.Size([20, 5])


In [56]:
# Head outputs flattened over batch and time: (batch * seq_len, output_size),
# here (5 * 4, 5) = (20, 5). Row b * seq_len + t holds batch element b at timestep t.
fc_out

tensor([[-0.2847, -0.0406,  0.3030,  0.0081,  0.5255],
        [-0.4608, -0.1324,  0.0768,  0.3836, -0.0933],
        [ 0.0510,  0.1256,  0.4968, -0.2799,  0.5112],
        [-0.4781,  0.2612,  0.5003,  0.7011, -0.2417],
        [-0.6796,  0.2401,  0.4115,  1.0423, -0.5627],
        [-0.0991,  0.3143,  0.5867,  0.2506, -0.1416],
        [-0.1302,  0.2787,  0.6195,  0.1241,  0.2477],
        [-0.3642,  0.2454,  0.4932,  0.5395, -0.1933],
        [-0.1784,  0.4543,  0.7891,  0.3489,  0.0842],
        [-0.4610,  0.1843,  0.3604,  0.7321, -0.5137],
        [ 0.0405,  0.1482,  0.4878, -0.1872,  0.3283],
        [-0.1404,  0.0516,  0.4139, -0.1099,  0.5389],
        [-0.0539,  0.0978,  0.4538, -0.1555,  0.4621],
        [-0.1854,  0.2511,  0.5683,  0.2075,  0.1503],
        [-0.2954,  0.0885,  0.3332,  0.3252, -0.0936],
        [-0.0208,  0.3336,  0.6964,  0.0060,  0.3110],
        [-0.3243, -0.0073,  0.2000,  0.3435, -0.2134],
        [ 0.0105,  0.1691,  0.5372, -0.1883,  0.4645],
        [-