In [1]:
import numpy as np
import torch
import torch.nn as nn

## nn.RNN from scratch

### 1) Definitions

In [2]:
input_sequence_length = 10
hidden_units = 1
input_dimension = 1

In [3]:
rnn_input = torch.randn(1, input_sequence_length, input_dimension)
print(rnn_input.shape)
rnn_input

torch.Size([1, 10, 1])


tensor([[[ 2.2424],
         [ 0.1419],
         [ 0.3966],
         [ 1.4313],
         [ 0.3641],
         [-0.6464],
         [-0.6286],
         [-0.5761],
         [ 0.7838],
         [ 0.4501]]])

### 2) nn.Parameter implementation

In [22]:
# W_ih = nn.Parameter(torch.randn(hidden_units, input_dimension))
# W_hh = nn.Parameter(torch.randn(hidden_units, hidden_units))

# b_ih = nn.Parameter(torch.randn(hidden_units))
# b_hh = nn.Parameter(torch.randn(hidden_units))

W_ih = nn.Parameter(torch.ones(hidden_units, input_dimension))
W_hh = nn.Parameter(torch.ones(hidden_units, hidden_units))

b_ih = nn.Parameter(torch.ones(hidden_units))
b_hh = nn.Parameter(torch.ones(hidden_units))

relu = nn.ReLU()

In [23]:
def get_h(h_t_prev, rnn_input_t):
    return relu(W_ih.T * rnn_input_t + b_ih + h_t_prev * W_hh.T + b_hh).detach()

In [31]:
h_array = [0]
for i in range(input_sequence_length):
    h_array.append(get_h(h_array[-1], rnn_input[:,i]))
h_array.pop(0)
output1 = np.array([x.detach()[0] for x in h_array])
output1

array([[ 4.242429],
       [ 6.38429 ],
       [ 8.780897],
       [12.212207],
       [14.576307],
       [15.929882],
       [17.301323],
       [18.72527 ],
       [21.509113],
       [23.959234]], dtype=float32)

### 3) nn.Linear implementation

In [7]:
class RNN2(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(RNN2, self).__init__()

        self.hidden_size = hidden_size

        self.i2h = nn.Linear(input_size, hidden_size)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        
        self.i2h.weight.data = torch.ones(self.i2h.weight.shape)
        self.i2h.bias.data = torch.ones(self.i2h.bias.shape)
        
        self.h2h.weight.data = torch.ones(self.h2h.weight.shape)
        self.h2h.bias.data = torch.ones(self.h2h.bias.shape)
        
        # self.h2o = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()

    def forward(self, input, hidden):
        # combined = torch.cat((input, hidden), 1)
        i2h = self.i2h(input)
        h2h = self.h2h(hidden)
        # output = self.h2o(hidden)
        output = self.relu(i2h + h2h)
        return output

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)

n_hidden = 1
rnn2 = RNN2(input_dimension, hidden_units)

In [32]:
hidden = rnn2.initHidden()

rnn2.zero_grad()

output = []

for i in range(input_sequence_length):
    hidden = rnn2(rnn_input[:,i], hidden)
    output.append(hidden.detach()[0])
output2 = np.array(output)
output2

array([[ 4.242429],
       [ 6.38429 ],
       [ 8.780897],
       [12.212207],
       [14.576307],
       [15.929882],
       [17.301323],
       [18.72527 ],
       [21.509113],
       [23.959234]], dtype=float32)

### 4) nn.RNN implementation

In [9]:
rnn = nn.RNN(input_size=input_dimension, hidden_size=hidden_units, num_layers=1, batch_first=True, nonlinearity='relu')

rnn.weight_ih_l0.data = torch.ones(rnn.weight_ih_l0.shape)
rnn.bias_ih_l0.data = torch.ones(rnn.bias_ih_l0.shape)
rnn.weight_hh_l0.data = torch.ones(rnn.weight_hh_l0.shape)
rnn.bias_hh_l0.data = torch.ones(rnn.bias_hh_l0.shape)

In [10]:
# h_0 = torch.tensor([[[h_array[0]]]], dtype=torch.float32)
h_0 = torch.zeros((1, 1, hidden_units), dtype=torch.float32)
h_0.shape

torch.Size([1, 1, 1])

In [33]:
output = rnn(rnn_input, h_0)
output3 = output[0].detach().numpy()
output3

array([[[ 4.242429],
        [ 6.38429 ],
        [ 8.780897],
        [12.212207],
        [14.576307],
        [15.929882],
        [17.301323],
        [18.72527 ],
        [21.509113],
        [23.959234]]], dtype=float32)

In [38]:
output1 == output2

array([[ True],
       [ True],
       [ True],
       [ True],
       [ True],
       [ True],
       [ True],
       [ True],
       [ True],
       [ True]])

In [39]:
output1 == output3

array([[[ True],
        [ True],
        [ True],
        [ True],
        [ True],
        [ True],
        [ True],
        [ True],
        [ True],
        [ True]]])