[Long Short-Term Memory](https://dl.acm.org/doi/10.1162/neco.1997.9.8.1735)


[1997-LSTM.pdf](../papers/1997-LSTM.pdf)

![LSTM示意图](http://assets.hypervoid.top/img/2025/06/30/image-20250630170007753-a624.png)

$$
\begin{align*}
i_t &= \sigma(W_{ii}x_t + b_{ii} + W_{hi}h_{t-1} + b_{hi}) \\
f_t &= \sigma(W_{if}x_t + b_{if} + W_{hf}h_{t-1} + b_{hf}) \\
g_t &= \tanh(W_{ig}x_t + b_{ig} + W_{hg}h_{t-1} + b_{hg}) \\
o_t &= \sigma(W_{io}x_t + b_{io} + W_{ho}h_{t-1} + b_{ho}) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{align*}
$$

In [91]:
import torch
import torch.nn as nn

# Fix the RNG seeds so every random tensor below (and the printed outputs) is reproducible.
SEED = 20250630
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

In [92]:
# Toy problem dimensions.
batch_size = 2
seq_len = 3
input_size = 4   # features per time step
hidden_size = 5  # size of h_t and c_t
# Batch-first input sequence: (batch_size, seq_len, input_size).
# NOTE: `input` shadows the builtin, but later cells refer to it by this name.
input = torch.randn(batch_size, seq_len, input_size)
# Initial cell and hidden states, each (batch_size, hidden_size).
# Drawn in this order (c0 first) to keep the seeded random stream unchanged.
c0 = torch.randn(batch_size, hidden_size)
h0 = torch.randn(batch_size, hidden_size)

In [93]:
# Pytorch LSTM
# Reference implementation: PyTorch's built-in single-layer LSTM (batch-first input).
torch_lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
# Call the module itself rather than .forward() so any registered hooks run.
# nn.LSTM expects the initial states shaped (num_layers, batch, hidden_size),
# hence the unsqueeze(0).
out_torch, (h_torch, c_torch) = torch_lstm(
    input,
    (h0.unsqueeze(0), c0.unsqueeze(0)),
)
print(out_torch)
print(h_torch)
print(c_torch)
print("---Parameters---")
for k, v in torch_lstm.named_parameters():
    print(f"{k}:\t{v.shape}")
# Each weight/bias stacks the four gates (i, f, g, o) row-wise, so the
# leading dimension is 4 * hidden_size = 20:
#   weight_ih_l0: torch.Size([20, 4])  -> (4 * hidden_size, input_size)
#   weight_hh_l0: torch.Size([20, 5])  -> (4 * hidden_size, hidden_size)
#   bias_ih_l0:   torch.Size([20])
#   bias_hh_l0:   torch.Size([20])

tensor([[[ 0.4276,  0.2803,  0.0205, -0.0904, -0.0928],
         [ 0.3246,  0.0375,  0.1131, -0.0302, -0.4382],
         [ 0.2796, -0.2374,  0.1253, -0.0093, -0.2690]],

        [[ 0.3898, -0.1509,  0.0402,  0.0404,  0.3354],
         [ 0.2717, -0.2599,  0.1875, -0.0164,  0.3097],
         [ 0.2790, -0.4573,  0.1867,  0.0285,  0.3013]]],
       grad_fn=<TransposeBackward0>)
tensor([[[ 0.2796, -0.2374,  0.1253, -0.0093, -0.2690],
         [ 0.2790, -0.4573,  0.1867,  0.0285,  0.3013]]],
       grad_fn=<StackBackward0>)
tensor([[[ 0.6250, -0.4408,  0.2818, -0.0264, -0.4600],
         [ 0.7244, -0.9398,  0.4891,  0.1022,  0.4883]]],
       grad_fn=<StackBackward0>)
---Parameters---
weight_ih_l0:	torch.Size([20, 4])
weight_hh_l0:	torch.Size([20, 5])
bias_ih_l0:	torch.Size([20])
bias_hh_l0:	torch.Size([20])


In [None]:
from typing import Tuple
from torch import Tensor


class LstmCell(nn.Module):
    """One LSTM time step with all four gates fused into two weight matrices.

    The gate blocks are stacked row-wise in the order input (i), forget (f),
    cell (g), output (o):

    * ``w_i``: ``(4 * hidden_size, input_size)``
    * ``w_h``: ``(4 * hidden_size, hidden_size)``
    * ``b_i``, ``b_h``: ``(4 * hidden_size,)``

    This matches the shapes of ``nn.LSTM``'s ``weight_ih_l0`` / ``weight_hh_l0``
    / ``bias_ih_l0`` / ``bias_hh_l0``, so those tensors can be swapped in.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # randn values are placeholders that init_weights() overwrites.
        self.w_i = nn.Parameter(torch.randn(4 * hidden_size, input_size))
        self.w_h = nn.Parameter(torch.randn(4 * hidden_size, hidden_size))
        self.b_i = nn.Parameter(torch.randn(4 * hidden_size))
        self.b_h = nn.Parameter(torch.randn(4 * hidden_size))
        self.init_weights()

    def init_weights(self):
        """Xavier-uniform the weight matrices and zero both bias vectors."""
        nn.init.xavier_uniform_(self.w_i)
        nn.init.xavier_uniform_(self.w_h)
        nn.init.zeros_(self.b_i)
        nn.init.zeros_(self.b_h)

    def forward(
        self, x_t: Tensor, state: tuple[Tensor, Tensor]
    ) -> tuple[Tensor, Tensor]:
        """Advance the cell by one time step.

        Args:
            x_t: input at this step, shape ``(batch, input_size)``.
            state: ``(h_prev, c_prev)``, each ``(batch, hidden_size)``. A leading
                size-1 layer axis, ``(1, batch, hidden_size)`` as used by
                ``nn.LSTM``, is also accepted and preserved in the result.

        Returns:
            ``(h_next, c_next)``, shaped like the tensors in ``state``.
        """
        h_prev, c_prev = state
        # All four gate pre-activations at once: (..., 4 * hidden_size).
        gates = x_t @ self.w_i.T + h_prev @ self.w_h.T + self.b_i + self.b_h
        # Split on the last axis so this works for batch_size == 1 and with or
        # without a leading layer axis (no squeezing needed).
        i_gate, f_gate, g_gate, o_gate = gates.chunk(4, dim=-1)

        # input gate (i), forget gate (f), cell gate (g), output gate (o)
        i_t = torch.sigmoid(i_gate)
        f_t = torch.sigmoid(f_gate)
        g_t = torch.tanh(g_gate)
        o_t = torch.sigmoid(o_gate)
        c_next = f_t * c_prev + i_t * g_t
        h_next = o_t * torch.tanh(c_next)
        return h_next, c_next


class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.lstm_cell = LstmCell(input_size, hidden_size)
        self.leaky_relu = nn.LeakyReLU()

    def forward(self, x: Tensor, init: Tuple[Tensor, Tensor] | None = None):
        batch_size_, seq_len_, _ = x.shape
        h_out = torch.zeros((batch_size_, seq_len_, self.hidden_size))
        if init is None:
            h_prev = torch.zeros(batch_size, self.hidden_size, device=x.device)
            c_prev = torch.zeros(batch_size, self.hidden_size, device=x.device)
        else:
            h_prev, c_prev = init

        for i in range(seq_len):
            x_t = x[:, i, :]
            h_prev, c_prev = self.lstm_cell(x_t, (h_prev, c_prev))
            h_out[:, i, :] = h_prev
        return h_out, (h_prev, c_prev)

$$
\begin{align*}
i_t &= \sigma(W_{ii}x_t + b_{ii} + W_{hi}h_{t-1} + b_{hi}) \\
f_t &= \sigma(W_{if}x_t + b_{if} + W_{hf}h_{t-1} + b_{hf}) \\
g_t &= \tanh(W_{ig}x_t + b_{ig} + W_{hg}h_{t-1} + b_{hg}) \\
o_t &= \sigma(W_{io}x_t + b_{io} + W_{ho}h_{t-1} + b_{ho}) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{align*}
$$

In [95]:

# Custom LSTM loaded with torch_lstm's exact parameters, so the outputs must match.
my_lstm = LSTM(input_size, hidden_size)
# Assigning an nn.Parameter re-registers it on the cell. my_lstm therefore
# shares torch_lstm's weight and bias objects rather than copying them.
my_lstm.lstm_cell.w_i = torch_lstm.weight_ih_l0
my_lstm.lstm_cell.w_h = torch_lstm.weight_hh_l0
my_lstm.lstm_cell.b_i = torch_lstm.bias_ih_l0
my_lstm.lstm_cell.b_h = torch_lstm.bias_hh_l0
# Same input and initial states (with the num_layers axis) as the nn.LSTM run.
out_, (h_, c_) = my_lstm(
    input,
    (h0.unsqueeze(0), c0.unsqueeze(0)),
)

print(out_)
print(h_)
print(c_)
print("---Parameters---")
for k, v in my_lstm.named_parameters():
    print(f"{k}:\t{v.shape}")

tensor([[[ 0.4276,  0.2803,  0.0205, -0.0904, -0.0928],
         [ 0.3246,  0.0375,  0.1131, -0.0302, -0.4382],
         [ 0.2796, -0.2374,  0.1253, -0.0093, -0.2690]],

        [[ 0.3898, -0.1509,  0.0402,  0.0404,  0.3354],
         [ 0.2717, -0.2599,  0.1875, -0.0164,  0.3097],
         [ 0.2790, -0.4573,  0.1867,  0.0285,  0.3013]]], grad_fn=<CopySlices>)
tensor([[[ 0.2796, -0.2374,  0.1253, -0.0093, -0.2690],
         [ 0.2790, -0.4573,  0.1867,  0.0285,  0.3013]]],
       grad_fn=<MulBackward0>)
tensor([[[ 0.6250, -0.4408,  0.2818, -0.0264, -0.4600],
         [ 0.7244, -0.9398,  0.4891,  0.1022,  0.4883]]],
       grad_fn=<AddBackward0>)
---Parameters---
lstm_cell.w_i:	torch.Size([20, 4])
lstm_cell.w_h:	torch.Size([20, 5])
lstm_cell.b_i:	torch.Size([20])
lstm_cell.b_h:	torch.Size([20])


In [96]:
# The custom LSTM must reproduce nn.LSTM's final cell state, final hidden state
# and full output sequence (within torch.allclose's default tolerances).
for mine, reference in ((c_, c_torch), (h_, h_torch), (out_, out_torch)):
    assert torch.allclose(mine, reference)