# 31、PyTorch GRU的原理及其手写复现

In [1]:
import torch
import torch.nn as nn

$$
\begin{align*}
r_t &= \sigma(W_{ir}x_t + b_{ir} + W_{hr}h_{(t-1)} + b_{hr}) \\
z_t &= \sigma(W_{iz}x_t + b_{iz} + W_{hz}h_{(t-1)} + b_{hz}) \\
n_t &= \tanh(W_{in}x_t + b_{in} + r_t \odot (W_{hn}h_{(t-1)} + b_{hn})) \\
h_t &= (1 - z_t) \odot n_t + z_t \odot h_{(t-1)}
\end{align*}
$$

In [2]:
from typing import Optional, Tuple
from torch import Tensor


def gru_forward(
    input: Tensor,
    init: Tensor,
    w_ih: Tensor,
    w_hh: Tensor,
    b_ih: Optional[Tensor] = None,
    b_hh: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
    """Run a single-layer, unidirectional GRU over a batch-first sequence.

    Hand-written version of the ``torch.nn.GRU`` recurrence:
    r_t = sigmoid(W_ir x_t + b_ir + W_hr h_{t-1} + b_hr)
    z_t = sigmoid(W_iz x_t + b_iz + W_hz h_{t-1} + b_hz)
    n_t = tanh(W_in x_t + b_in + r_t * (W_hn h_{t-1} + b_hn))
    h_t = (1 - z_t) * n_t + z_t * h_{t-1}

    Args:
        input: Input sequence, shape ``(batch, seq_len, input_size)``.
        init: Initial hidden state, shape ``(batch, hidden_size)``.
        w_ih: Input-to-hidden weights, shape ``(3 * hidden_size, input_size)``,
            rows stacked in gate order reset | update | new (the layout of
            ``nn.GRU.weight_ih_l0``).
        w_hh: Hidden-to-hidden weights, shape
            ``(3 * hidden_size, hidden_size)``, same gate order.
        b_ih: Input-to-hidden bias, shape ``(3 * hidden_size,)``, or ``None``
            for no bias (as with ``nn.GRU(bias=False)``).
        b_hh: Hidden-to-hidden bias, shape ``(3 * hidden_size,)``, or ``None``.

    Returns:
        ``(output, h_n)``: ``output`` has shape ``(batch, seq_len, hidden_size)``
        and holds the hidden state after every step; ``h_n`` is the final
        hidden state, shape ``(batch, hidden_size)``. When ``seq_len == 0``,
        ``output`` is empty and ``h_n`` is ``init``.
    """
    bs = input.shape[0]
    seq_len = input.shape[1]
    h_size = w_ih.shape[0] // 3

    # The input-side projections do not depend on the hidden state, so compute
    # them for all time steps at once: (batch, seq_len, 3 * hidden_size).
    gates_x = input @ w_ih.T
    if b_ih is not None:
        gates_x = gates_x + b_ih
    gx_r, gx_z, gx_n = gates_x.chunk(3, dim=-1)

    prev_h = init
    steps = []
    for t in range(seq_len):
        gates_h = prev_h @ w_hh.T
        if b_hh is not None:
            gates_h = gates_h + b_hh
        gh_r, gh_z, gh_n = gates_h.chunk(3, dim=-1)

        r_t = torch.sigmoid(gx_r[:, t] + gh_r)
        z_t = torch.sigmoid(gx_z[:, t] + gh_z)
        # b_hn sits inside the reset-gated term (it is part of gh_n), as in nn.GRU.
        n_t = torch.tanh(gx_n[:, t] + r_t * gh_n)
        prev_h = (1 - z_t) * n_t + z_t * prev_h
        steps.append(prev_h)

    # Stacking keeps the dtype/device of the computation (a bare torch.zeros
    # buffer would be CPU float32) and avoids in-place writes.
    if steps:
        output = torch.stack(steps, dim=1)
    else:
        output = gates_x.new_zeros(bs, 0, h_size)
    return output, prev_h

In [3]:
# Make the random input and the GRU's parameter initialization reproducible.
torch.manual_seed(0)

batch_size, seq_len = 2, 3
input_size, hidden_size = 3, 4
input = torch.randn((batch_size, seq_len, input_size))
# Non-zero initial state so the comparison also exercises the h_{t-1} path at t = 0.
h0 = torch.randn(batch_size, hidden_size)

gru_layer = nn.GRU(input_size, hidden_size, batch_first=True)
# Call the module itself (not .forward) so hooks run. nn.GRU expects h_0 as
# (num_layers, batch, hidden_size), hence unsqueeze(0).
out1, h1 = gru_layer(input, h0.unsqueeze(0))
for k, v in gru_layer.named_parameters():
    print(f"{k}\t{v.shape}")
print("---------")
out2, h2 = gru_forward(
    input,
    h0,
    gru_layer.weight_ih_l0,
    gru_layer.weight_hh_l0,
    gru_layer.bias_ih_l0,
    gru_layer.bias_hh_l0,
)
# nn.GRU returns h_n as (1, batch, hidden_size) while gru_forward returns
# (batch, hidden_size); compare explicitly instead of relying on broadcasting.
assert out1.shape == out2.shape == (batch_size, seq_len, hidden_size)
assert h1.shape == (1, batch_size, hidden_size)
assert torch.allclose(out1, out2, atol=1e-6)
assert torch.allclose(h1[0], h2, atol=1e-6)

weight_ih_l0	torch.Size([12, 3])
weight_hh_l0	torch.Size([12, 4])
bias_ih_l0	torch.Size([12])
bias_hh_l0	torch.Size([12])
---------
