In [2]:
import torch
import torch.nn as nn
import torch.nn.init

In [3]:
class RNNCell(nn.Module):
    """
    Single cell of RNN model
    """

    def __init__(
        self, input_size: int, hidden_size: int, bias: bool, activation: str
    ) -> None:
        """
        Initialize recurrent neural network
        Parameters
        --------
            input_size: int
                Number of feature in the input x
            output_size: int
                Number of feature in the output y
            hidden_size: int
                Number of feature in the hidden state h
            bias: bool
                Whether to include a bias term in the linear transformations
            activation: str
                Activation function to apply to the hidden state, there are 2
                options: tanh and relu
        Returns
        -------
        nothing
        """
        super(RNNCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        if activation not in ["tanh", "relu"]:
            raise ValueError("Invalid activation function")
        self.activation = activation
        self.x2h = nn.Linear(input_size, hidden_size, bias=bias)
        self.h2h = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.init_parameters()

    def forward(self, input: torch.Tensor, hs_pre: torch.Tensor = None):
        """
        Computes the forward propagation of the RNN cell
        Parameters
        --------
            input: torch.Tensor
                Input tensor of shape (batch_size, input_size).
            hs_pre: torch.Tensor
                Previous hidden state tensor of shape (batch_size, hidden_size)
                 Default is None, the initial hidden state is set to zeros.
        Returns
        -------
            hs: torch.Tensor
                Output hidden state tensor of shape (batch_size, hidden_size).
        """
        if hs_pre is None:
            hs_pre = torch.zeros(input.size(0), self.hidden_size)
        hs = self.x2h(input) + self.h2h(hs_pre)
        if self.activation == "tanh":
            hs = torch.tanh(hs)
        else:
            hs = torch.relu(hs)
        return hs

    def init_parameters(self) -> None:
        """
        Initialize the weights and biases of the RNN cell
        followed by Xavier normalization
        Parameters
        --------
            None
        Returns
        -------
            None
        """
        for name, param in self.named_parameters():
            if "weight" in name:
                torch.nn.init.xavier_uniform_(param)
            if "bias" in name:
                param = param.view(1, param.size(0))
                torch.nn.init.xavier_uniform_(param)

In [12]:
class RNN(nn.Module):
    """
    Implement RNN model
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        num_layers: int,
        bias: bool,
        activation="str",
    ) -> None:
        """
        Recurrent Neural Network (RNN) model.
        Parameters
        --------
            input_size: int
                Number of features in the input x
            hidden_size: int
                Number of features in the hidden state h
            output_size: int
                Number of features in the output y
            num_layers: int
                Number of RNN cell layers
            bias: bool
                Whether to include a bias term in the linear transformations
            activation: str
                Activation function to apply to the hidden state,
                there are 2 options: tanh and relu
        Returns
        --------
        None
        """
        super(RNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.output_size = output_size
        if activation not in ["tanh", "relu"]:
            raise ValueError("Invalid activation function")
        self.fc = nn.Linear(hidden_size, output_size, bias = self.bias)
        self.init_layer(activation)

    def forward(self, input, hs_pre=None) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the forward propagation of the RNN model
        Parameters
        --------
            input: torch.Tensor
                Input tensor of shape (batch_size, sequence_length, input_size)
            hs_pre: torch.Tensor
                Previous hidden state tensor of shape
                (num_layers, batch_size, hidden_size).
                Default is None, the initial hidden state is set to zeros

        Returns: tuple[torch.Tensor, torch.Tensor]
            out: torch.Tensor
                Output tensor of shape (batch_size, output_size)
            hs_final: torch.Tensor
                Final hidden state of shape (num_layers, batch_size, hidden_size)
        """
        if hs_pre is None:
            hs_pre = torch.zeros(self.num_layers, input.size(0), self.hidden_size)
        output = []
        hidden_layers = list(hs_pre)
        for t in range(input.size(1)):
            for layer in range(self.num_layers):
                if layer == 0:
                    hidden = self.rnn_cell_list[layer].forward(
                        input[:, t, :], hidden_layers[layer]
                    )
                else:
                    hidden = self.rnn_cell_list[layer].forward(
                        hidden_layers[layer - 1], hidden_layers[layer]
                    )
                hidden_layers[layer] = hidden
            output.append(hidden)
        out = output[-1].squeeze()
        out = self.fc(out)
        hs_final = torch.stack(hidden_layers, dim=0)
        return out, hs_final

    def init_layer(self, activation: str):
        """
        Initialize the RNN cell list.
        Parameters
        --------
            None
        Returns
        -------
            activation: str
                Activation function to apply to the hidden state,
                there are 2 options: tanh and relu

        """
        self.rnn_cell_list = nn.ModuleList()
        if activation == "tanh":
            self.rnn_cell_list.append(
                RNNCell(self.input_size, self.hidden_size, self.bias, "tanh")
            )
            for _ in range(1, self.num_layers):
                self.rnn_cell_list.append(
                    RNNCell(self.hidden_size, self.hidden_size, self.bias, "tanh")
                )
        elif activation == "relu":
            self.rnn_cell_list.append(
                RNNCell(self.input_size, self.hidden_size, self.bias, "relu")
            )
            for _ in range(1, self.num_layers):
                self.rnn_cell_list.append(
                    RNNCell(self.hidden_size, self.hidden_size, self.bias, "relu")
                )

In [13]:
input_dim = 28  # input dimension
hidden_dim = 100  # hidden layer dimension
layer_dim = 1  # number of hidden layers
output_dim = 10
batch_size = 100
seq_len = 28
test = RNN(28, 100, 10, 1, True, "tanh")
x = torch.rand( seq_len, batch_size, input_dim)

In [14]:
print(test.forward(x)[0].size(), test.forward(x)[0])

torch.Size([28, 10]) tensor([[ 0.3240,  0.1318, -0.3015,  0.2144, -0.0220, -0.0358,  0.1316, -0.1483,
         -0.2284,  0.4579],
        [ 0.0450,  0.3146, -0.0773,  0.2080, -0.0566, -0.3204, -0.1115,  0.1054,
         -0.0486,  0.6179],
        [ 0.1036,  0.1987, -0.4135,  0.1749, -0.1761,  0.0322, -0.1121, -0.1917,
          0.1592,  0.5641],
        [ 0.2084,  0.2986, -0.5458,  0.1055,  0.1568, -0.2350, -0.0763, -0.0389,
         -0.1862,  0.2424],
        [ 0.2791,  0.0548, -0.1757,  0.2230,  0.0735, -0.3046, -0.0459, -0.1348,
          0.0254,  0.4173],
        [ 0.1985, -0.0448, -0.2171,  0.3067,  0.0311, -0.2138, -0.0950, -0.1290,
          0.0707,  0.6437],
        [ 0.2949,  0.1009, -0.4130,  0.1254, -0.0622, -0.4116, -0.0914,  0.0814,
          0.0252,  0.3516],
        [ 0.2482, -0.0756, -0.6072,  0.1497, -0.0643, -0.2225,  0.0069, -0.1661,
          0.0760,  0.3967],
        [ 0.2006,  0.2145, -0.4609,  0.1922, -0.2438, -0.1063, -0.0483,  0.0916,
          0.1057,  0.2624]

In [15]:
print(test.forward(x)[1].size(), test.forward(x)[1])

torch.Size([1, 28, 100]) tensor([[[-0.8112,  0.5459,  0.4067,  ..., -0.1415,  0.6768,  0.7035],
         [-0.4754,  0.6991,  0.4315,  ..., -0.6095,  0.7262,  0.7983],
         [ 0.0216,  0.6962,  0.2407,  ..., -0.4083,  0.5292,  0.8893],
         ...,
         [-0.5000,  0.3047,  0.2183,  ..., -0.5036,  0.6181,  0.7862],
         [-0.5872,  0.7028,  0.0917,  ..., -0.3595,  0.6892,  0.8751],
         [-0.5660,  0.7532,  0.4371,  ..., -0.0735,  0.5856,  0.7527]]],
       grad_fn=<StackBackward0>)
