<a href="https://colab.research.google.com/github/ohmreborn/rnn-lstm-gru-pytorch-from-scratch/blob/main/explain_rnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import numpy as np

In [None]:
class RNNCell(nn.Module):
    """Single-step Elman RNN cell built from two linear maps.

    One step computes ``hy = act(x2h(input) + h2h(hx))`` where ``act`` is
    ``tanh`` or ``relu``.

    Args:
        input_size: number of features in each input vector.
        hidden_size: number of features in the hidden state.
        bias: whether both linear maps carry a bias term.
        nonlinearity: ``"tanh"`` or ``"relu"``.

    Raises:
        ValueError: if ``nonlinearity`` is neither ``"tanh"`` nor ``"relu"``.
    """

    def __init__(self, input_size, hidden_size, bias=True, nonlinearity="tanh"):
        super().__init__()
        self.input_size, self.hidden_size = input_size, hidden_size
        self.bias, self.nonlinearity = bias, nonlinearity
        if nonlinearity != "tanh" and nonlinearity != "relu":
            raise ValueError("Invalid nonlinearity selected for RNN.")

        # Input-to-hidden and hidden-to-hidden projections (created in this
        # order so parameter ordering in state_dict stays x2h, then h2h).
        self.x2h = nn.Linear(input_size, hidden_size, bias=bias)
        self.h2h = nn.Linear(hidden_size, hidden_size, bias=bias)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw every parameter from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))."""
        bound = 1.0 / np.sqrt(self.hidden_size)
        with torch.no_grad():
            for param in self.parameters():
                param.uniform_(-bound, bound)

    def forward(self, input, hx=None):
        """Advance the cell by one time step.

        Args:
            input: tensor of shape (batch_size, input_size).
            hx: previous hidden state of shape (batch_size, hidden_size);
                when omitted, zeros matching ``input``'s dtype and device
                are used.

        Returns:
            The new hidden state, of shape (batch_size, hidden_size).
        """
        if hx is None:
            hx = input.new_zeros(input.size(0), self.hidden_size)

        pre_activation = self.x2h(input) + self.h2h(hx)
        activate = torch.tanh if self.nonlinearity == "tanh" else torch.relu
        return activate(pre_activation)

In [None]:
class SimpleRNN(nn.Module):
    """Multi-layer Elman RNN assembled from stacked ``RNNCell`` layers.

    The first cell maps ``input_size`` features to ``hidden_size``; every
    further cell maps ``hidden_size`` to ``hidden_size``.

    Args:
        input_size: number of features in each input time step.
        hidden_size: number of features in every layer's hidden state.
        num_layers: number of stacked cells (must be at least 1).
        bias: whether the cells' linear maps carry bias terms.
        output_size: stored on the module but currently unused, because no
            output projection layer is attached.
        activation: ``'tanh'`` or ``'relu'``, applied in every cell.

    Raises:
        ValueError: if ``activation`` is not ``'tanh'``/``'relu'`` or if
            ``num_layers`` is smaller than 1.
    """

    def __init__(self, input_size, hidden_size, num_layers, bias, output_size, activation='tanh'):
        super().__init__()
        if activation not in ('tanh', 'relu'):
            raise ValueError("Invalid activation.")
        # With zero layers forward() would have nothing to emit per time step,
        # so reject that configuration up front.
        if num_layers < 1:
            raise ValueError("num_layers must be at least 1.")

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.output_size = output_size

        # Only the first layer sees the raw input; deeper layers consume the
        # hidden state of the layer below.
        layer_input_sizes = [input_size] + [hidden_size] * (num_layers - 1)
        self.rnn_cell_list = nn.ModuleList(
            [RNNCell(in_features, hidden_size, bias, activation)
             for in_features in layer_input_sizes]
        )

        # NOTE: no output head is attached; add e.g.
        # nn.Linear(hidden_size, output_size) here if one is needed.

    def forward(self, input, hx=None):
        """Run the stacked cells over a batch-first sequence.

        Args:
            input: tensor of shape (batch_size, sequence_length, input_size).
            hx: initial hidden states, either a tensor of shape
                (num_layers, batch_size, hidden_size) or a sequence of
                per-layer (batch_size, hidden_size) tensors such as the list
                returned by a previous call. Defaults to zeros created with
                ``input``'s dtype and device.

        Returns:
            A tuple ``(outputs, hidden)``:
            ``outputs`` stacks the top layer's hidden state for every time
            step and has shape (sequence_length, batch_size, hidden_size) --
            time-major, unlike ``nn.RNN(batch_first=True)``;
            ``hidden`` is a list with one (batch_size, hidden_size) tensor
            per layer holding the final hidden states.

        The sequence must contain at least one time step, since
        ``torch.stack`` cannot stack an empty list.
        """
        if hx is None:
            # Allocate next to the input rather than on "whatever GPU exists":
            # a CPU model on a CUDA-capable machine must stay on the CPU.
            h0 = input.new_zeros(self.num_layers, input.size(0), self.hidden_size)
        else:
            h0 = hx

        hidden = [h0[layer] for layer in range(self.num_layers)]

        outs = []
        for t in range(input.size(1)):
            layer_input = input[:, t, :]
            for layer, cell in enumerate(self.rnn_cell_list):
                hidden[layer] = cell(layer_input, hidden[layer])
                # The freshly updated state feeds the next layer up.
                layer_input = hidden[layer]
            # After the inner loop this is the top layer's state at step t.
            outs.append(layer_input)

        return torch.stack(outs), hidden

In [None]:
# Two-layer from-scratch RNN: 10 input features, 20 hidden units per layer.
model = SimpleRNN(
    input_size=10,
    hidden_size=20,
    num_layers=2,
    bias=True,
    output_size=20,
)

In [None]:
# Maps each nn.RNN parameter name to the SimpleRNN state_dict key that holds
# the corresponding tensor (x2h <-> input-to-hidden, h2h <-> hidden-to-hidden).
weight = dict(
    weight_ih_l0='rnn_cell_list.0.x2h.weight',
    weight_hh_l0='rnn_cell_list.0.h2h.weight',
    bias_ih_l0='rnn_cell_list.0.x2h.bias',
    bias_hh_l0='rnn_cell_list.0.h2h.bias',
    weight_ih_l1='rnn_cell_list.1.x2h.weight',
    weight_hh_l1='rnn_cell_list.1.h2h.weight',
    bias_ih_l1='rnn_cell_list.1.x2h.bias',
    bias_hh_l1='rnn_cell_list.1.h2h.bias',
)

In [None]:
from collections import OrderedDict

# Re-key SimpleRNN's parameters under the names nn.RNN expects, driven by the
# `weight` mapping defined in the earlier cell so the key pairs live in one
# place instead of being spelled out a second time here.
state = model.state_dict()
d = OrderedDict(
    (rnn_key, state[cell_key]) for rnn_key, cell_key in weight.items()
)

d

In [None]:
# Reference implementation: PyTorch's built-in RNN with the same dimensions,
# loaded with the re-keyed weights from `d` so both models share parameters.
rnn = nn.RNN(input_size=10, hidden_size=20, num_layers=2, batch_first=True)
rnn.load_state_dict(d)
rnn

In [None]:
# Random test batch: (batch=3, sequence_length=5, input_size=10).
input = torch.randn((3, 5, 10))
# Random initial hidden state: (num_layers=2, batch=3, hidden_size=20).
h0 = torch.randn((2, 3, 20))

In [None]:
# Run the same batch and initial state through both implementations.
# `output` is batch-first; `o` from SimpleRNN is stacked time-major
# (sequence_length, batch, hidden_size) and `h` is a per-layer list.
output, hn = rnn(input, hx=h0)
o, h = model(input, hx=h0)


In [None]:
# Final hidden state from the built-in nn.RNN; compare it by eye with the
# stacked SimpleRNN hidden states displayed in the following cell.
hn

In [None]:
# SimpleRNN returns its final hidden states as a per-layer list; stack them
# into one (num_layers, batch, hidden_size) tensor so it lines up with `hn`.
torch.stack(h)