In [130]:
import torch
import torch.nn as nn

In [131]:
# Seed PyTorch's global random number generator so that the random
# initialisation performed later is repeatable from run to run.
torch.manual_seed(seed=1)

<torch._C.Generator at 0x7fba0bbd0e70>

In [132]:
# Single-layer recurrent network mapping 5 input features to 2 hidden units.
# batch_first=True: tensors are laid out as (batch, seq_len, features).
rnn_layer = nn.RNN(
    input_size=5,
    hidden_size=2,
    num_layers=1,
    batch_first=True,
)

In [133]:
# Echo the module's repr as a quick check of the configured sizes.
rnn_layer

RNN(5, 2, batch_first=True)

In [134]:
# Pull out the layer-0 parameters so the forward pass can be redone by hand.
# Weights: input-to-hidden and hidden-to-hidden.
w_xh, w_hh = rnn_layer.weight_ih_l0, rnn_layer.weight_hh_l0
# Biases paired with the two weight matrices above.
b_xh, b_hh = rnn_layer.bias_ih_l0, rnn_layer.bias_hh_l0

In [135]:
# Inspect the input-to-hidden weights: one row per hidden unit,
# one column per input feature.
print(w_xh)

Parameter containing:
tensor([[ 0.3643, -0.3121, -0.1371,  0.3319, -0.6657],
        [ 0.4241, -0.1455,  0.3597,  0.0983, -0.0866]], requires_grad=True)


In [136]:
# Report the shape of every layer-0 parameter, in weight-then-bias order.
for caption, tensor in (
    ('W_xh shape:', w_xh),
    ('W_hh shape:', w_hh),
    ('b_xh shape:', b_xh),
    ('b_hh shape:', b_hh),
):
    print(caption, tensor.shape)

W_xh shape: torch.Size([2, 5])
W_hh shape: torch.Size([2, 2])
b_xh shape: torch.Size([2])
b_hh shape: torch.Size([2])


In [137]:
# Toy input: 3 time steps, each a 5-feature vector filled with 1.0, 2.0
# and 3.0 respectively, stored as float32.
x_seq = torch.tensor(
    [[value] * 5 for value in (1.0, 2.0, 3.0)],
    dtype=torch.float32,
)
x_seq

tensor([[1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2.],
        [3., 3., 3., 3., 3.]])

In [138]:
# Preview the sequence with a leading batch axis: (batch=1, seq_len=3, features=5).
x_seq.reshape(1, 3, 5)

tensor([[[1., 1., 1., 1., 1.],
         [2., 2., 2., 2., 2.],
         [3., 3., 3., 3., 3.]]])

In [139]:
## Forward pass through the built-in RNN layer.
# x_seq is (3, 5) = (seq_len, input_size). The layer was built with
# batch_first=True, so a leading batch axis is added: (1, 3, 5).
output, hn = rnn_layer(x_seq.reshape(1, 3, 5))
print(output)    # hidden state at every time step -> (1, 3, 2)
print('hn', hn)  # hidden state after the last step only -> (1, 1, 2)

tensor([[[-0.3520,  0.5253],
         [-0.6842,  0.7607],
         [-0.8649,  0.9047]]], grad_fn=<TransposeBackward1>)
hn tensor([[[-0.8649,  0.9047]]], grad_fn=<StackBackward0>)


In [140]:
# Re-derive the RNN forward pass by hand, one time step at a time, and verify
# each step against the built-in layer's `output`:
#     h_t = tanh(x_t @ W_xh^T + b_xh + h_{t-1} @ W_hh^T + b_hh),  h_{-1} = 0
out_man = []  # manually computed hidden state for each time step
for t in range(x_seq.shape[0]):  # follow the sequence length, not a literal 3
    # Current time step as a (1, input_size) row vector.
    xt = x_seq[t].reshape(1, -1)
    print(f'Time step {t} =>')
    print('Input :', xt.numpy())

    # Input contribution only (pre-activation): x_t @ W_xh^T + b_xh.
    ht = xt @ w_xh.t() + b_xh
    print('Hidden :', ht.detach().numpy())

    if t > 0:
        print('t > 0')
        # Recurrence: feed back the hidden state produced at the previous step.
        prev_h = out_man[-1]
    else :
        print('t = 0 time')
        # No history yet, so the initial hidden state is all zeros.
        prev_h = torch.zeros_like(ht)

    # Add the recurrent contribution and its bias, then squash with tanh.
    ot = torch.tanh(ht + prev_h @ w_hh.t() + b_hh)
    out_man.append(ot)
    print('Output (manual) :', ot.detach().numpy())
    print('RNN output ', output[:, t].detach().numpy())

    # Fail loudly if the hand-rolled step ever diverges from nn.RNN, rather
    # than relying on comparing the printed numbers by eye.
    torch.testing.assert_close(ot.detach(), output[:, t].detach())



Time step 0 =>
Input : [[1. 1. 1. 1. 1.]]
Hidden : [[-0.4701929  0.5863904]]
t = 0 time
Output (manual) : [[-0.3519801   0.52525216]]
RNN output  [[-0.3519801   0.52525216]]
Time step 1 =>
Input : [[2. 2. 2. 2. 2.]]
Hidden : [[-0.88883156  1.2364397 ]]
t > 0
Output (manual) : [[-0.68424344  0.76074266]]
RNN output  [[-0.68424344  0.76074266]]
Time step 2 =>
Input : [[3. 3. 3. 3. 3.]]
Hidden : [[-1.3074701  1.886489 ]]
t > 0
Output (manual) : [[-0.8649416   0.90466356]]
RNN output  [[-0.8649416   0.90466356]]
