In [1]:
import torch
import torch.nn as nn
import torch.nn.init

In [2]:
class LSTMCell(nn.Module):
    """
    Single cell of LSTM
    """

    def __init__(self, input_size: int, hidden_size: int, bias: bool = True):
        """
        Initialize long short term memory cell
        Parameters
        --------
          input_size: int
            The number of expected features in the input x
          hidden_size: int
            The number of features in the hidden state h
          bias: bool
            Optional, if False, the layer doesn't use bias weights b_ih and b_hh
            Default: True
        Returns
        -------
        None

        """
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        # input gate (i)
        self.input_x2i = nn.Linear(input_size, hidden_size, bias=bias)
        self.input_h2i = nn.Linear(hidden_size, hidden_size, bias=bias)

        # forgot gate (f)
        self.forgot_x2f = nn.Linear(input_size, hidden_size, bias=bias)
        self.forgot_h2f = nn.Linear(hidden_size, hidden_size, bias=bias)

        # cell vector (c)
        self.cell_x2c = nn.Linear(input_size, hidden_size, bias=bias)
        self.cell_h2c = nn.Linear(hidden_size, hidden_size, bias=bias)

        # almost output (o)
        self.output_x2o = nn.Linear(input_size, hidden_size, bias=bias)
        self.output_h20 = nn.Linear(hidden_size, hidden_size, bias=bias)

        self.init_parameters()

    def input_gate(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        """
        x: size (batch_size, input_size)
        h: size (batch_size, hidden_size)
        i: size (batch_size, hidden_size)
        """
        x_t = self.input_x2i(x)
        hs_pre = self.input_h2i(h)
        acti = nn.Sigmoid()
        i = acti(x_t + hs_pre)
        return i

    def forgot_gate(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        """
        x: size (batch_size, input_size)
        h: size (batch_size, hidden_size)
        f: size (batch_size, hidden_size)
        """
        x_t = self.forgot_x2f(x)
        hs_pre = self.forgot_h2f(h)
        acti = nn.Sigmoid()
        f = acti(x_t + hs_pre)
        return f

    def cell_vector(
        self,
        i: torch.Tensor,
        f: torch.Tensor,
        x: torch.Tensor,
        h: torch.Tensor,
        c_pre: torch.Tensor,
    ) -> torch.Tensor:
        """
        x: size (batch_size, input_size)
        h: size (batch_size, hidden_size)
        c: size (batch_size, hidden_size)
        """
        x_t = self.cell_x2c(x)
        hs_pre = self.cell_h2c(h)
        acti = nn.Tanh()
        c = f * c_pre + i * acti(x_t + hs_pre)
        return c

    def almost_output(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        """
        x: size (batch_size, input_size)
        h: size (batch_size, hidden_size)
        out: size (batch_size, hidden_size)
        """
        x_t = self.output_x2o(x)
        hs_pre = self.output_h20(h)
        acti = nn.Sigmoid()
        out = acti(x_t + hs_pre)
        return out

    def forward(
        self, x: torch.Tensor, h_n_c: tuple[torch.Tensor, torch.Tensor] = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the forward propagation of the LSTM cell
        Parameters
        --------
            input: torch.Tensor
                Input tensor of shape (batch_size, input_size).
            h_n_c: tuple[torch.Tensor, torch.Tensor]
                Previous hidden state tensor of shape
                ((batch_size, hidden_size), (batch_size, hidden_size))
                Default is None, the initial hidden state is set to zeros.
        Returns
        -------
            hs: tuple[torch.Tensor, torch.Tensor]
                Output hidden state tensor of shape
                ((batch_size, hidden_size), (batch_size, hidden_size)).
        """
        if h_n_c is None:
            h_n_c = torch.zeros(x.size(0), self.hidden_size)
            h_n_c = (h_n_c, h_n_c)
        (hs_pre, c_pre) = h_n_c
        i = self.input_gate(x, hs_pre)
        f = self.forgot_gate(x, hs_pre)
        c = self.cell_vector(i, f, x, hs_pre, c_pre)
        o = self.almost_output(x, hs_pre)
        acti = nn.Tanh()
        hs = o * acti(c)
        return (hs, c)

    def init_parameters(self) -> None:
        """
        Initialize the weights and biases of the LSTM cell
        followed by Xavier normalization
        Parameters
        --------
            None
        Returns
        -------
            None
        """
        for name, param in self.named_parameters():
            if "weight" in name:
                torch.nn.init.xavier_uniform_(param)
            if "bias" in name:
                param = param.view(1, param.size(0))
                torch.nn.init.xavier_uniform_(param)

In [11]:
class LSTM(nn.Module):
    """
    Implements a multi-layer LSTM model.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        num_layers: int,
        bias: bool = True,
    ) -> None:
        """
        Initialize long short term memory
        Parameters
        --------
          input_size: int
            The number of expected features in the input x
          hidden_size: int
            The number of features in the hidden state h
          bias: bool
            Optional, if False, the layer doesn't use bias weights b_ih and b_hh.
            Default: True
        Returns
        -------
          None
        """
        super(LSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.output_size = output_size
        self.fc = nn.Linear(self.hidden_size, self.output_size)
        self.init_cell_list()

    def forward(
        self, input: torch.Tensor, hs_pre: tuple[torch.Tensor, torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Forward pass of the LSTM model.
        Parameters
        --------
          input: torch.Tensor
            The input tensor of shape (batch_size, sequence_length, input_size)
          hs_pre: tuple[torch.Tensor, torch.Tensor]
            Previous hidden state tensor of shape
            ((batch_size, hidden_size), (batch_size, hidden_size))
            Default is None, the initial hidden state is set to zeros.
        Returns 
        --------
          output: torch.Tensor
            Output hidden state tensor of shape (batch_size, hidden_size)
        """
        if hs_pre is None:
            hs_pre = torch.zeros(self.num_layers, input.size(0), self.hidden_size)
        output = []
        hidden_layers = list(hs_pre) + list(hs_pre)
        for t in range(input.size(1)):
            for layer in range(self.num_layers):
                if layer == 0:
                    hidden = self.lstm_cell_list[layer].forward(
                        input[:, t, :],
                        (hidden_layers[layer][0], hidden_layers[layer][1]),
                    )
                else:
                    hidden = self.lstm_cell_list[layer].forward(
                        hidden_layers[layer - 1][0],
                        (hidden_layers[layer][0], hidden_layers[layer][1]),
                    )
                hidden_layers[layer] = hidden
            output.append(hidden[0])
        out = output[-1].squeeze()
        out = self.fc(out)
        return out
    

    def init_cell_list(self) -> None:
        """
        Initializes the LSTM cell list based on the number of layers.
        """
        self.lstm_cell_list = nn.ModuleList()
        self.lstm_cell_list.append(
            LSTMCell(self.input_size, self.hidden_size, self.bias)
        )
        for _ in range(1, self.num_layers):
            self.lstm_cell_list.append(
                LSTMCell(self.hidden_size, self.hidden_size, self.bias)
            )

In [12]:
input_dim = 28  # input dimension
hidden_dim = 100  # hidden layer dimension
layer_dim = 1  # number of hidden layers
output_dim = 10
batch_size = 100
seq_len = 28
test = LSTM(28, 100, 10, 1, True)
x = torch.rand(seq_len, batch_size, input_dim)

In [14]:
print(test.forward(x).size(), test.forward(x))

torch.Size([28, 10]) tensor([[-0.0553, -0.1602, -0.0626,  0.1215, -0.0583, -0.0262, -0.0771,  0.0806,
          0.0807, -0.1749],
        [-0.0426, -0.1778, -0.0602,  0.1115, -0.0759, -0.0095, -0.0238,  0.0848,
         -0.0086, -0.1316],
        [-0.0715, -0.1779, -0.0900,  0.1281, -0.0933, -0.0164, -0.0571,  0.0880,
         -0.0069, -0.1538],
        [-0.0678, -0.2332, -0.1064,  0.0642, -0.0592, -0.0532, -0.0900,  0.0213,
          0.0957, -0.1392],
        [-0.0528, -0.1844, -0.0919,  0.0895, -0.0864, -0.0220, -0.0283,  0.0564,
          0.0245, -0.1547],
        [-0.0653, -0.1616, -0.0835,  0.1054, -0.0535, -0.0735, -0.0390,  0.0123,
         -0.0159, -0.0699],
        [ 0.0066, -0.1381, -0.0798,  0.0270, -0.0076, -0.0236, -0.1256,  0.0473,
          0.0929, -0.1238],
        [ 0.0110, -0.2156, -0.0829,  0.0544, -0.0336, -0.0467, -0.0835,  0.0675,
          0.0370, -0.0876],
        [-0.0459, -0.2389, -0.0461,  0.0885, -0.1012, -0.0276, -0.0891, -0.0149,
          0.0454, -0.1350]