In [1]:
import torch
import torch.nn as nn
import torch.nn.init

In [2]:
class GRUCell(nn.Module):
    """
    Single cell of GRU
    """

    def __init__(self, input_size: int, hidden_size: int, bias: bool = True) -> None:
        """
        Initialize gated recurrent unit cell
        Parameters
        --------
          input_size: int
            The number of expected features in the input x
          hidden_size: int
            The number of features in the hidden state h
          bias: bool
            Optional, if False,the layer doesn't use bias weights b_ih and b_hh
            Default: True
        Returns
        -------
            None

        """
        super(GRUCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        # reset gate (r)
        self.reset_i2r = nn.Linear(input_size, hidden_size, bias=bias)
        self.reset_h2r = nn.Linear(hidden_size, hidden_size, bias=bias)

        # update gate (z)
        self.update_i2z = nn.Linear(input_size, hidden_size, bias=bias)
        self.update_h2z = nn.Linear(hidden_size, hidden_size, bias=bias)

        # almost output (n)
        self.output_i2n = nn.Linear(input_size, hidden_size, bias=bias)
        self.output_h2n = nn.Linear(hidden_size, hidden_size, bias=bias)

        self.init_parameters()

    def reset_gate(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        """
        x: size (batch_size, input_size)
        h: size (batch_size, hidden_size)
        r: size (batch_size, hidden_size)
        """
        x_t = self.reset_i2r(x)
        hs_pre = self.reset_h2r(h)
        acti = nn.Sigmoid()
        r = acti(x_t + hs_pre)
        return r

    def update_gate(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        """
        x: size (batch_size, input_size)
        h: size (batch_size, hidden_size)
        z: size (batch_size, hidden_size)
        """
        x_t = self.update_i2z(x)
        hs_pre = self.update_h2z(h)
        acti = nn.Sigmoid()
        z = acti(x_t + hs_pre)
        return z

    def almost_output(
        self, x: torch.Tensor, h: torch.Tensor, r: torch.Tensor
    ) -> torch.Tensor:
        """
        x: size (batch_size, input_size)
        h: size (batch_size, hidden_size)
        r: size (batch_size, hidden_size)
        n: size (batch_size, hidden_size)
        """
        x_t = self.output_i2n(x)
        hs_pre = self.output_h2n(h)
        acti = nn.Tanh()
        n = acti(x_t + hs_pre)
        return n

    def forward(self, x: torch.Tensor, h: torch.Tensor = None) -> torch.Tensor:
        """
        Computes the forward propagation of the GRU cell
        Parameters
        --------
            input: torch.Tensor
                Input tensor of shape (batch_size, input_size).
            hs_pre: torch.Tensor
                Previous hidden state tensor of shape (batch_size, hidden_size)
                 Default is None, the initial hidden state is set to zeros.
        Returns
        -------
            hs: torch.Tensor
                Output hidden state tensor of shape (batch_size, hidden_size).
        """
        if h is None:
            h = torch.zeros(x.size(0), self.hidden_size)
        r = self.reset_gate(x, h)
        z = self.update_gate(x, h)
        n = self.almost_output(x, h, r)
        hs = (1 - z) * n + z * h
        return hs

    def init_parameters(self) -> None:
        """
        Initialize the weights and biases of the RNN cell
        followed by Xavier normalization
        Parameters
        --------
            None
        Returns
        -------
            None
        """
        for name, param in self.named_parameters():
            if "weight" in name:
                torch.nn.init.xavier_uniform_(param)
            if "bias" in name:
                param = param.view(1, param.size(0))
                torch.nn.init.xavier_uniform_(param)

In [12]:
class GRU(nn.Module):
    """
    Implements a multi-layer GRU model.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        num_layers: int,
        bias: bool = True,
    ) -> None:
        """
        Initialize gated recurrent unit
        Parameters
        --------
          input_size: int
            The number of expected features in the input x
          hidden_size: int
            The number of features in the hidden state h
          num_layers: int
            The number of layers
          output_size: int
            The number of output features
          bias: bool
            Optional, if False, then the layer does not use bias weights
            b_ih and b_hh. Default: True
        Returns
        -------
            None

        """
        super(GRU, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.output_size = output_size
        self.bias = bias
        self.fc = nn.Linear(self.hidden_size, self.output_size)
        self.init_cell_list()

    def forward(self, input: torch.Tensor, hs_pre: torch.Tensor = None) ->tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass of the GRU model.
        Parameters
        --------
          input: torch.Tensor
            The input tensor of shape (batch_size, sequence_length, input_size)
          hs_pre: torch.Tensor
            Previous hidden state tensor of shape (batch_size, hidden_size)
            Default is None, the initial hidden state is set to zeros.
        Returns: tuple[torch.Tensor, torch.Tensor]
            out: torch.Tensor
                Output tensor of shape (batch_size, output_size)
            hs_final: torch.Tensor
                Final hidden state of shape (num_layers, batch_size, hidden_size)
        """
        if hs_pre is None:
            hs_pre = torch.zeros(self.num_layers, input.size(0), self.hidden_size)
        output = []
        hidden_layers = list(hs_pre)
        for t in range(input.size(1)):
            for layer in range(self.num_layers):
                if layer == 0:
                    hidden = self.gru_cell_list[layer].forward(
                        input[:, t, :], hidden_layers[layer]
                    )
                else:
                    hidden = self.gru_cell_list[layer].forward(
                        hidden_layers[layer - 1], hidden_layers[layer]
                    )
                hidden_layers[layer] = hidden
            output.append(hidden)
        out = output[-1].squeeze()
        hs_final = torch.stack(hidden_layers, dim=0)
        return out, hs_final 

    def init_cell_list(self):
        """
        Initializes the GRU cell list based on the number of layers.
        """
        self.gru_cell_list = nn.ModuleList()
        self.gru_cell_list.append(GRUCell(self.input_size, self.hidden_size, self.bias))
        for _ in range(1, self.num_layers):
            self.gru_cell_list.append(
                GRUCell(self.hidden_size, self.hidden_size, self.bias)
            )

In [13]:
input_dim = 28  # input dimension
hidden_dim = 100  # hidden layer dimension
layer_dim = 1  # number of hidden layers
output_dim = 10
batch_size = 100
seq_len = 28
test = GRU(28, 100, 10, 1, True)
x = torch.rand( seq_len, batch_size, input_dim)

In [14]:
print(test.forward(x)[0].size(), test.forward(x)[0])

torch.Size([28, 100]) tensor([[-0.2668, -0.5939, -0.0206,  ..., -0.5574, -0.2525,  0.4538],
        [-0.1837, -0.5600,  0.0968,  ..., -0.5274, -0.4458,  0.3581],
        [ 0.0568, -0.4267,  0.1676,  ..., -0.5104, -0.4412,  0.5746],
        ...,
        [-0.2874, -0.3609,  0.2265,  ..., -0.4313, -0.4350,  0.4398],
        [ 0.2242, -0.4099,  0.2173,  ..., -0.5283, -0.4649,  0.4902],
        [-0.1249, -0.3885,  0.0440,  ..., -0.1153, -0.2531,  0.6546]],
       grad_fn=<SqueezeBackward0>)


In [15]:
print(test.forward(x)[1].size(), test.forward(x)[1])

torch.Size([1, 28, 100]) tensor([[[-0.2668, -0.5939, -0.0206,  ..., -0.5574, -0.2525,  0.4538],
         [-0.1837, -0.5600,  0.0968,  ..., -0.5274, -0.4458,  0.3581],
         [ 0.0568, -0.4267,  0.1676,  ..., -0.5104, -0.4412,  0.5746],
         ...,
         [-0.2874, -0.3609,  0.2265,  ..., -0.4313, -0.4350,  0.4398],
         [ 0.2242, -0.4099,  0.2173,  ..., -0.5283, -0.4649,  0.4902],
         [-0.1249, -0.3885,  0.0440,  ..., -0.1153, -0.2531,  0.6546]]],
       grad_fn=<StackBackward0>)
