In [1]:
import torch
import torch.nn as nn

Long Short-Term Memory (LSTM) networks are a type of recurrent neural network (RNN) designed to capture long-term dependencies in sequential data. Standard RNNs struggle with vanishing or exploding gradients when learning long-range dependencies; LSTMs add a cell state and three gates (input, forget, and output) that control the flow of information. The gates regulate what is written to the cell state, what is forgotten, and what is passed on to the next step, which lets LSTMs retain information over long sequences. This comes at a price. The gates are not fixed: each one is computed from the current input and the previous hidden state, using weights learned from data during training. Together with the candidate cell update, that gives four weight blocks where a vanilla RNN has one, so an LSTM has four times as many parameters as an RNN with the same input and hidden sizes.

![LSTM gates diagram from d2l.ai](https://d2l.ai/_images/lstm-0.svg)
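
The gate mechanism described above can be written out as a set of update equations. The version below is a sketch under two assumptions: biases are omitted, and row vectors multiply the weights from the left. The block names $W_1^{(k)}$ and $W_2^{(k)}$ with $k \in \{i, f, g, o\}$ are notation introduced here; they stand for the four row blocks of PyTorch's `weight_ih_l0` and `weight_hh_l0`, stacked in the order input gate, forget gate, cell candidate, output gate.

$$
\begin{aligned}
I_t &= \sigma\left(X_t W_1^{(i)T} + H_{t-1} W_2^{(i)T}\right) \\
F_t &= \sigma\left(X_t W_1^{(f)T} + H_{t-1} W_2^{(f)T}\right) \\
\tilde{C}_t &= \tanh\left(X_t W_1^{(g)T} + H_{t-1} W_2^{(g)T}\right) \\
O_t &= \sigma\left(X_t W_1^{(o)T} + H_{t-1} W_2^{(o)T}\right) \\
C_t &= F_t \odot C_{t-1} + I_t \odot \tilde{C}_t \\
H_t &= O_t \odot \tanh(C_t)
\end{aligned}
$$

Here $\sigma$ is the logistic sigmoid and $\odot$ is the elementwise product.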

In [2]:
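# Single-layer LSTM with input size 1, hidden size 3 and no bias terms
model = nn.LSTM(1,3,bias=False)
# A[0] is weight_ih_l0 (input-to-hidden), A[1] is weight_hh_l0 (hidden-to-hidden)
A = list(model.parameters())
for i in A:
    print(i.shape)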

torch.Size([12, 1])
torch.Size([12, 3])
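
The leading dimension of 12 is four blocks of `hid = 3` rows, one block per gate plus one for the cell candidate. As a quick check of how this compares with a plain RNN, the sketch below counts the parameters of both layers at the same sizes. The helper `count_params` is a name made up for this example, not part of PyTorch.

```python
import torch.nn as nn

def count_params(module):
    # Hypothetical helper: total number of scalar parameters in a module
    return sum(p.numel() for p in module.parameters())

rnn = nn.RNN(1, 3, bias=False)
lstm = nn.LSTM(1, 3, bias=False)

n_rnn = count_params(rnn)    # 3*1 + 3*3
n_lstm = count_params(lstm)  # 12*1 + 12*3

# Parameter names show which tensor is which
for name, p in lstm.named_parameters():
    print(name, tuple(p.shape))

print(n_rnn, n_lstm, n_lstm // n_rnn)
```

With these sizes the RNN has 12 parameters and the LSTM has 48, a ratio of 4.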


In [3]:
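X = torch.randn(100,1)  # unbatched sequence: 100 time steps, 1 feature each
H = torch.zeros(1,3)    # initial hidden state: (num_layers, hid)
C = torch.zeros(1,3)    # initial cell state: (num_layers, hid)
# alpha holds the hidden state at every step; beta and gamma are the final H and C
alpha, (beta,gamma) = model(X, (H,C))
alpha.shape, beta.shape, gamma.shape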

(torch.Size([100, 3]), torch.Size([1, 3]), torch.Size([1, 3]))
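
To make the roles of the three return values concrete, the sketch below checks two properties of a single-layer, unidirectional `nn.LSTM`: the last row of the output equals the final hidden state, and processing a sequence in two halves while carrying the state forward gives the same output as one pass. It builds its own model and data, so the variable names are local to this example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(1, 3, bias=False)
X = torch.randn(100, 1)
H0 = torch.zeros(1, 3)
C0 = torch.zeros(1, 3)

with torch.no_grad():
    out, (h_n, c_n) = lstm(X, (H0, C0))

    # The last output row is the final hidden state
    last_matches = torch.allclose(out[-1], h_n[0])

    # Split the sequence and pass the state from the first half to the second
    out_a, (h_mid, c_mid) = lstm(X[:50], (H0, C0))
    out_b, _ = lstm(X[50:], (h_mid, c_mid))
    split_matches = torch.allclose(torch.cat([out_a, out_b]), out, atol=1e-6)

print(out.shape, h_n.shape, c_n.shape, last_matches, split_matches)
```

The second check is why the pair `(H, C)` is returned: it is all the layer needs to continue from where it stopped.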

In [4]:
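def LSTM_scratch(X, H, C, W_1, W_2,hid=3):
    # W_1 (weight_ih) and W_2 (weight_hh) each stack four blocks of hid rows,
    # in the order: input gate, forget gate, cell candidate, output gate.
    # Transpose each block so that row vectors can multiply it from the left.
    W_1 = [W_1[hid*i:hid*(i+1)].transpose(0,1) for i in range(4)]
    W_2 = [W_2[hid*i:hid*(i+1)].transpose(0,1) for i in range(4)]

    out_scratch = []
    for i in range(0,len(X)):
        # Gates and candidate, computed from the current input and previous H
        I = torch.sigmoid(torch.matmul(X[i],W_1[0])+torch.matmul(H,W_2[0]))
        F = torch.sigmoid(torch.matmul(X[i],W_1[1])+torch.matmul(H,W_2[1]))
        C_tilde = torch.tanh(torch.matmul(X[i],W_1[2])+torch.matmul(H,W_2[2]))
        O = torch.sigmoid(torch.matmul(X[i],W_1[3])+torch.matmul(H,W_2[3]))
        C = F*C+I*C_tilde      # new cell state
        H = O*torch.tanh(C)    # new hidden state
        out_scratch.append(H)
    # Each H is (1, hid): stacking gives (len(X), 1, hid), reshaped to (len(X), hid)
    out_scratch = torch.stack(out_scratch).reshape(len(X),hid)
    return out_scratch, (H,C)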

In [5]:
a, (b,c) = LSTM_scratch(X,H,C,A[0],A[1],3)
a.shape, b.shape, c.shape

(torch.Size([100, 3]), torch.Size([1, 3]), torch.Size([1, 3]))

In [6]:
torch.allclose(a,alpha), torch.allclose(b, beta), torch.allclose(c,gamma)

(True, True, True)
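
The same comparison can be repeated with the bias terms switched on. The sketch below is a variant written for this purpose, not the notebook's `LSTM_scratch`: the function name `lstm_fused` is made up, it computes all four blocks with one matrix product per weight matrix and splits the result with `chunk`, and it adds the two bias vectors `bias_ih_l0` and `bias_hh_l0` that `nn.LSTM` creates when `bias=True`.

```python
import torch
import torch.nn as nn

def lstm_fused(X, H, C, W_ih, W_hh, b_ih=None, b_hh=None):
    # Hypothetical variant of LSTM_scratch with optional biases
    outputs = []
    for x in X:                               # x: (input_size,)
        gates = x @ W_ih.T + H @ W_hh.T       # (1, 4*hid)
        if b_ih is not None:
            gates = gates + b_ih + b_hh
        # Blocks come in PyTorch's order: input, forget, candidate, output
        i, f, g, o = gates.chunk(4, dim=-1)
        C = torch.sigmoid(f) * C + torch.sigmoid(i) * torch.tanh(g)
        H = torch.sigmoid(o) * torch.tanh(C)
        outputs.append(H)
    return torch.cat(outputs, dim=0), (H, C)

torch.manual_seed(0)
model_b = nn.LSTM(1, 3, bias=True)
X = torch.randn(100, 1)
H0 = torch.zeros(1, 3)
C0 = torch.zeros(1, 3)

with torch.no_grad():
    ref_out, (ref_h, ref_c) = model_b(X, (H0, C0))
    out, (h, c) = lstm_fused(
        X, H0, C0,
        model_b.weight_ih_l0, model_b.weight_hh_l0,
        model_b.bias_ih_l0, model_b.bias_hh_l0,
    )

ok = (
    torch.allclose(out, ref_out, atol=1e-5)
    and torch.allclose(h, ref_h, atol=1e-5)
    and torch.allclose(c, ref_c, atol=1e-5)
)
print(ok)
```

The explicit `atol` leaves room for small floating-point differences, since the order of operations here need not match PyTorch's internal kernel.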

**Remark:**
$H$ is a $(1\times hid)$ matrix and each gate's block of $W_2$ is a square $(hid\times hid)$ matrix. Contrary to what the formulas in the [documentation of LSTM](https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html) suggest, we cannot multiply $W_2$ by $H$ directly. The multiplication should be either $W_2\cdot H^T$ or $H\cdot W_2^T$. [A pull request](https://github.com/pytorch/pytorch/pull/138191) has been submitted to address this issue in the documentation.
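
The shape argument can be checked directly. In the sketch below, `W` is a random stand-in for one $(hid\times hid)$ block of $W_2$, not a trained weight. It shows that `W @ H` is rejected, that the two suggested forms agree up to a transpose, and that `H @ W` (without the transpose) runs but gives a different result.

```python
import torch

torch.manual_seed(0)
hid = 3
H = torch.randn(1, hid)    # row vector, as in the notebook
W = torch.randn(hid, hid)  # stand-in for one block of weight_hh

# (hid, hid) @ (1, hid) has mismatched inner dimensions
try:
    W @ H
    direct_ok = True
except RuntimeError:
    direct_ok = False

left = H @ W.T             # (1, hid)
right = (W @ H.T).T        # (hid, 1) transposed back to (1, hid)
same = torch.allclose(left, right, atol=1e-6)

# Shapes allow H @ W, but it is a different product
differs = not torch.allclose(H @ W, left)

print(direct_ok, same, differs)
```

The last check is the one to watch for in practice: because the block is square, dropping the transpose raises no error and silently changes the numbers.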