In [1]:
import torch
import torch.nn as nn
import torch.nn.init

In [12]:
class RNNCell(nn.Module):
    """
    Single cell of RNN model
    """

    def __init__(
        self, input_size: int, hidden_size: int, bias: bool, activation: str
    ) -> None:
        """
        Initialize recurrent neural network
        Parameters
        --------
            input_size: int
                Number of feature in the input x
            output_size: int
                Number of feature in the output y
            hidden_size: int
                Number of feature in the hidden state h
            bias: bool
                Whether to include a bias term in the linear transformations
            activation: str
                Activation function to apply to the hidden state, there are 2
                options: tanh and relu
        Returns
        -------
        nothing
        """
        super(RNNCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        if activation not in ["tanh", "relu"]:
            raise ValueError("Invalid activation function")
        self.activation = activation
        self.x2h = nn.Linear(input_size, hidden_size, bias=bias)
        self.h2h = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.init_parameters()

    def forward(self, input: torch.Tensor, hs_pre: torch.Tensor = None):
        """
        Computes the forward propagation of the RNN cell
        Parameters
        --------
            input: torch.Tensor
                Input tensor of shape (batch_size, input_size).
            hs_pre: torch.Tensor
                Previous hidden state tensor of shape (batch_size, hidden_size)
                 Default is None, the initial hidden state is set to zeros.
        Returns
        -------
            hs: torch.Tensor
                Output hidden state tensor of shape (batch_size, hidden_size).
        """
        if hs_pre is None:
            hs_pre = torch.zeros(input.size(0), self.hidden_size)
        hs = self.x2h(input) + self.h2h(hs_pre)
        if self.activation == "tanh":
            hs = torch.tanh(hs)
        else:
            hs = torch.relu(hs)
        return hs

    def init_parameters(self) -> None:
        """
        Initialize the weights and biases of the RNN cell
        followed by Xavier normalization
        Parameters
        --------
            None
        Returns
        -------
            None
        """
        for name, param in self.named_parameters():
            if "weight" in name:
                torch.nn.init.xavier_uniform_(param)
            if "bias" in name:
                param = param.view(1, param.size(0))
                torch.nn.init.xavier_uniform_(param)

In [13]:
class RNN(nn.Module):
    """
    Implement RNN model
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        num_layers: int,
        bias: bool,
        activation="str",
    ) -> None:
        """
        Recurrent Neural Network (RNN) model.
        Parameters
        --------
            input_size: int
                Number of features in the input x
            hidden_size: int
                Number of features in the hidden state h
            output_size: int
                Number of features in the output y
            num_layers: int
                Number of RNN cell layers
            bias: bool
                Whether to include a bias term in the linear transformations
            activation: str
                Activation function to apply to the hidden state,
                there are 2 options: tanh and relu
        Returns
        --------
        None
        """
        super(RNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.output_size = output_size
        if activation not in ["tanh", "relu"]:
            raise ValueError("Invalid activation function")
        self.fc = nn.Linear(hidden_size, output_size, bias = self.bias)
        self.init_layer(activation)

    def forward(self, input, hs_pre=None) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the forward propagation of the RNN model
        Parameters
        --------
            input: torch.Tensor
                Input tensor of shape (batch_size, sequence_length, input_size)
            hs_pre: torch.Tensor
                Previous hidden state tensor of shape
                (num_layers, batch_size, hidden_size).
                Default is None, the initial hidden state is set to zeros

        Returns: tuple[torch.Tensor, torch.Tensor]
            out: torch.Tensor
                Output tensor of shape (batch_size, output_size)
            hs_final: torch.Tensor
                Final hidden state of shape (num_layers, batch_size, hidden_size)
        """
        if hs_pre is None:
            hs_pre = torch.zeros(self.num_layers, input.size(0), self.hidden_size)
        output = []
        hidden_layers = list(hs_pre)
        for t in range(input.size(1)):
            for layer in range(self.num_layers):
                if layer == 0:
                    hidden = self.rnn_cell_list[layer].forward(
                        input[:, t, :], hidden_layers[layer]
                    )
                else:
                    hidden = self.rnn_cell_list[layer](
                        hidden_layers[layer - 1], hidden_layers[layer]
                    )
                hidden_layers[layer] = hidden
            output.append(hidden)
        out = output[-1].squeeze()
        out = self.fc(out)
        hs_final = torch.stack(hidden_layers, dim=0)
        return out, hs_final

    def init_layer(self, activation: str):
        """
        Initialize the RNN cell list.
        Parameters
        --------
            None
        Returns
        -------
            activation: str
                Activation function to apply to the hidden state,
                there are 2 options: tanh and relu

        """
        self.rnn_cell_list = nn.ModuleList()
        if activation == "tanh":
            self.rnn_cell_list.append(
                RNNCell(self.input_size, self.hidden_size, self.bias, "tanh")
            )
            for _ in range(1, self.num_layers):
                self.rnn_cell_list.append(
                    RNNCell(self.hidden_size, self.hidden_size, self.bias, "tanh")
                )
        elif activation == "relu":
            self.rnn_cell_list.append(
                RNNCell(self.input_size, self.hidden_size, self.bias, "relu")
            )
            for _ in range(1, self.num_layers):
                self.rnn_cell_list.append(
                    RNNCell(self.hidden_size, self.hidden_size, self.bias, "relu")
                )

In [14]:
input_dim = 28  # input dimension
hidden_dim = 100  # hidden layer dimension
layer_dim = 1  # number of hidden layers
output_dim = 10
batch_size = 100
seq_len = 28
test = RNN(28, 100, 10, 1, True, "tanh")
x = torch.rand( seq_len, batch_size, input_dim)

In [15]:
print(test.forward(x)[0].size(), test.forward(x)[0])

torch.Size([28, 10]) tensor([[ 0.0888,  0.3144,  0.3136, -0.0854, -0.0362,  0.0709,  0.8744,  0.5221,
          0.0130, -0.0313],
        [ 0.0687,  0.1406,  0.0350, -0.0355, -0.3912, -0.2070,  0.6919,  0.5058,
          0.0596, -0.2201],
        [ 0.0422,  0.2142,  0.1666, -0.2248, -0.0453, -0.0993,  0.2942,  0.2005,
         -0.1305, -0.1414],
        [-0.0882, -0.0853,  0.4190, -0.2123, -0.0616, -0.1202,  0.4652,  0.2199,
         -0.2855, -0.0063],
        [-0.1265,  0.2059,  0.3833, -0.2442, -0.3663, -0.2000,  0.6116,  0.5147,
         -0.0566, -0.0489],
        [-0.1466,  0.1200,  0.0643, -0.3003,  0.1424, -0.0600,  0.7734,  0.5787,
         -0.0382, -0.0642],
        [-0.2135,  0.0499,  0.1985, -0.3844, -0.1370, -0.0377,  0.7577,  0.4427,
         -0.0693, -0.2159],
        [-0.2944, -0.0195,  0.2573, -0.2044, -0.1323,  0.0147,  0.7252,  0.4054,
         -0.3397, -0.1801],
        [-0.1267,  0.1813,  0.3451, -0.0504, -0.3300, -0.0812,  0.6150,  0.6598,
          0.2830, -0.1583]

In [16]:
print(test.forward(x)[1].size(), test.forward(x)[1])

torch.Size([1, 28, 100]) tensor([[[ 0.0841, -0.4109, -0.1288,  ..., -0.3316, -0.0921, -0.1590],
         [ 0.3775, -0.0048, -0.4009,  ...,  0.2302,  0.5344, -0.6408],
         [ 0.0831, -0.6643, -0.4147,  ..., -0.1601,  0.0439, -0.4168],
         ...,
         [ 0.3204, -0.1186, -0.0134,  ..., -0.4212,  0.6957, -0.5847],
         [ 0.3830, -0.7463, -0.1409,  ..., -0.0421,  0.3670, -0.8187],
         [ 0.1241, -0.5083,  0.0676,  ...,  0.2457,  0.6423, -0.6868]]],
       grad_fn=<StackBackward0>)
