In this notebook, we will implement the forward method of an LSTM from scratch and then verify our code against the implementation provided by PyTorch. As a starting point, here are the standard formulas for an LSTM cell.

$$
f_t = \sigma(W_{if} x_t  + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
g_t = \tanh(W_{ig} x_t  + b_{ig} +  W_{hg} h_{t-1} + b_{hg})  \\
i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi})    \\
o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho})   \\
c_t = f_t \odot c_{t-1} + i_t \odot g_t  \\
h_t = o_t \odot \tanh(c_t)
$$

Thus to determine the values of the various gates, we multiply the input $x_t$ with the matrices $W_{ii}, W_{if}, W_{ig}$ and $W_{io}$. Instead of doing this in four separate operations, we can also combine the four matrices into one large matrix of dimension $4 H \times E$, where $H$ is the dimension of the hidden layer and the cells and $E$ is the dimension of the input, and then carry out one large multiplication. The same holds for the bias and the matrices operating on the hidden state. Let us combine all matrices operating on the input into one matrix $W_{ih}$ and all matrices operating on the hidden state into one matrix $W_{hh}$.
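
To make the stacking argument concrete, here is a minimal sketch with randomly initialized matrices. The sizes and the block order (input, forget, candidate, output) are chosen purely for illustration. It checks that one multiplication with the stacked matrix yields the four separate products as consecutive blocks of $H$ rows.

```python
import torch

torch.manual_seed(0)
E, H = 5, 3  # illustrative input and hidden dimensions

# Four separate input weight matrices, one per gate, each of shape (H, E)
W_ii, W_if, W_ig, W_io = (torch.randn(H, E) for _ in range(4))
x_t = torch.randn(E)

# Stack them row-wise into one matrix of shape (4H, E)
# (the block order used here is an assumption made for this example)
W_ih = torch.cat([W_ii, W_if, W_ig, W_io], dim=0)

# One large multiplication ...
A = W_ih @ x_t

# ... contains the four separate products as consecutive blocks of H rows
separate = torch.cat([W_ii @ x_t, W_if @ x_t, W_ig @ x_t, W_io @ x_t])
assert A.shape == (4 * H,)
assert torch.allclose(A, separate, atol=1e-6)
```

The individual gate pre-activations can then be read off by slicing, for example `A[H:2*H]` for the second block.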

Which part of this matrix corresponds to which weight matrix is a matter of convention. We will use the layout that PyTorch uses under the hood (see [here](https://github.com/pytorch/pytorch/blob/4130e4f2848ac83baac38dc89d3b95630f39ce7f/torch/nn/modules/rnn.py#L664)).
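
As a small sketch of that layout: the PyTorch documentation describes `weight_ih_l0` as the row-wise concatenation of the input, forget, candidate cell and output blocks, in that order. Assuming this order, the combined parameters can be split back into the individual matrices of the formulas above. The sizes and the variable name `lstm` are chosen for illustration only.

```python
import torch

E, H = 5, 3  # illustrative sizes
lstm = torch.nn.LSTM(input_size=E, hidden_size=H)

# Each combined parameter has 4H rows; split it into four blocks of H rows.
# Assumed block order (as documented by PyTorch): input, forget, candidate, output.
W_ii, W_if, W_ig, W_io = lstm.weight_ih_l0.chunk(4, dim=0)
W_hi, W_hf, W_hg, W_ho = lstm.weight_hh_l0.chunk(4, dim=0)
b_ii, b_if, b_ig, b_io = lstm.bias_ih_l0.chunk(4, dim=0)
b_hi, b_hf, b_hg, b_ho = lstm.bias_hh_l0.chunk(4, dim=0)

assert W_if.shape == (H, E)
assert W_hf.shape == (H, H)
assert b_if.shape == (H,)
# The forget block is the second set of H rows
assert torch.equal(W_if, lstm.weight_ih_l0[H:2*H])
```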

The output that we return will again be the full hidden layer values of shape (L, H) as well as the last value of the hidden layer and the last value of the memory cell, combined into one tuple. We will also allow existing values for the hidden layer and the memory cell to be passed in.
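
For reference, this is the return structure that a single-layer `torch.nn.LSTM` produces for an unbatched input (a sketch assuming a PyTorch version that accepts unbatched sequences; the sizes are illustrative):

```python
import torch

E, H, L = 5, 3, 4  # illustrative sizes
lstm = torch.nn.LSTM(input_size=E, hidden_size=H)
x = torch.rand(L, E)  # unbatched sequence of length L

with torch.no_grad():
    out, (h_n, c_n) = lstm(x)

# Full hidden layer values, one row per time step
assert out.shape == (L, H)
# Last hidden and cell state, with a leading dimension for the single layer
assert h_n.shape == (1, H)
assert c_n.shape == (1, H)
# The last row of the output is the final hidden state
assert torch.allclose(out[-1], h_n[0])
```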

In [1]:
import torch

In [2]:
def forward(x, previous_state = None):
    # Uses the globals H, w_ih, w_hh, b_ih and b_hh, which must be defined before the call
    L = x.shape[0]
    if previous_state is None:
        hidden = torch.zeros(H)
        cells = torch.zeros(H)
    else:
        hidden, cells = previous_state
        hidden = hidden.squeeze(dim = 0)
        cells = cells.squeeze(dim = 0)
    _hidden = []
    _cells = []
    for i in range(L):
        _x = x[i]
        #
        # multiply w_ih and w_hh by x and h and add biases
        # 
        A = w_ih @ _x 
        A = A + w_hh @ hidden 
        A = A + b_ih + b_hh
        #
        # The value of the forget gate is obtained by taking the second set of H rows of the result
        # and applying the sigmoid function
        #
        ft = torch.sigmoid(A[H:2*H])
        #
        # Similarly, the input gate is the first block, the candidate cell the third block and the output gate
        # the last block
        #
        it = torch.sigmoid(A[0:H])
        gt = torch.tanh(A[2*H:3*H])
        ot = torch.sigmoid(A[3*H:4*H])
        #
        # New value of cell --> apply forget gate and add input gate times candidate cell
        #
        cells = ft * cells + it * gt
        #
        # New value of hidden layer is output gate times tanh of the cell value
        #
        hidden = ot * torch.tanh(cells)
        _cells.append(cells)
        _hidden.append(hidden)
    return torch.stack(_hidden), (hidden.unsqueeze(dim = 0), cells.unsqueeze(dim = 0))

Let us now take this for a test drive. We will create a PyTorch LSTM, extract its weights (our forward function follows the PyTorch conventions for the weight layout, so this step is easy), run both our forward function and the PyTorch network on the same input, and check that the results are the same.

In [3]:
E = 5
H = 3
L = 3
#
# Create PyTorch LSTM and extract weights
#
torchLSTM = torch.nn.LSTM(input_size = E, hidden_size = H)
w_ih = torchLSTM.weight_ih_l0
w_hh = torchLSTM.weight_hh_l0
b_ih = torchLSTM.bias_ih_l0
b_hh = torchLSTM.bias_hh_l0
assert w_ih.shape == (4*H, E), "Shape of w_ih not as expected"
assert w_hh.shape == (4*H, H), "Shape of w_hh not as expected"
#
# Create random input of dimensions L x E and feed it into
# both networks
#
x = torch.rand(L, E)
_out, (_h, _c) = torchLSTM(x)
out, (h, c) = forward(x)
#
# Output will be of shape (L, H)
#
assert out.shape == (L, H), "Shape of output not correct"
assert h.shape == (1, H), "Shape of h not correct"
assert c.shape == (1, H), "Shape of memory cell not correct"
#
# Make sure that outputs match
#
assert torch.allclose(_out, out), "Outputs do not match"
assert torch.allclose(_h, h), "Hidden layers do not match"
assert torch.allclose(_c, c), "Cells do not match"

Next, let us repeat the test, this time passing in the hidden and cell state obtained from the previous run.

In [4]:
x = torch.randn(1, E)
_out, (_h, _c) = torchLSTM(x, (_h, _c))
out, (h, c) = forward(x, (h, c))
#
# Output will be of shape (1, H), as the input now has length one
#
assert out.shape == (1, H), "Shape of output not correct"
assert h.shape == (1, H), "Shape of h not correct"
assert c.shape == (1, H), "Shape of memory cell not correct"
#
# Make sure that outputs match
#
assert torch.allclose(_out, out), "Outputs do not match"
assert torch.allclose(_h, h), "Hidden layers do not match"
assert torch.allclose(_c, c), "Cells do not match"
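
Passing the state along in this way should give the same result as processing the full sequence in one go. The following sketch checks this for the PyTorch implementation alone; the sizes, the split point and the variable names are chosen for illustration.

```python
import torch

torch.manual_seed(0)
E, H, L = 5, 3, 6  # illustrative sizes
lstm = torch.nn.LSTM(input_size=E, hidden_size=H)
x = torch.rand(L, E)

with torch.no_grad():
    # Process the whole sequence at once
    full_out, (full_h, full_c) = lstm(x)
    # Process the first four steps, then continue from the returned state
    first_out, state = lstm(x[:4])
    second_out, (h, c) = lstm(x[4:], state)

chunked_out = torch.cat([first_out, second_out], dim=0)
assert torch.allclose(full_out, chunked_out, atol=1e-6)
assert torch.allclose(full_h, h, atol=1e-6)
assert torch.allclose(full_c, c, atol=1e-6)
```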