In [2]:
import torch
from torch import nn
from d2l import torch as d2l

In [3]:
batch_size, num_steps = 32, 35
# Seed torch's RNG so the torch.randn initialization in get_params (and the
# nn.GRU initialization further down) is reproducible on Restart & Run All.
# NOTE(review): the d2l data iterator may also draw from Python's `random`
# module; seed that as well if fully reproducible runs are required.
SEED = 0
torch.manual_seed(SEED)
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

### 初始化模型参数

In [4]:
def get_params(vocab_size: int, num_hiddens: int, device: "torch.device | str") -> "list[torch.Tensor]":
    """Create the trainable parameters of a from-scratch GRU.

    Args:
        vocab_size: Vocabulary size. Inputs and outputs are both tokens drawn
            from the vocabulary, so the input width and the output width both
            equal ``vocab_size`` (the per-token input encoding is presumably
            done by the model wrapper, not here -- TODO confirm).
        num_hiddens: Number of hidden units.
        device: Device on which every parameter tensor is allocated.

    Returns:
        A list of 11 tensors, all with ``requires_grad=True``, in the order
        ``[W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]``:
        update gate, reset gate, candidate hidden state, then output layer.
        ``gru`` unpacks them in exactly this order.
    """
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        # Gaussian init scaled by 0.01 so the initial weights are small.
        return torch.randn(size=shape, device=device) * 0.01

    def three():
        # (input->hidden weight, hidden->hidden weight, zero bias) for one gate.
        return (
            normal((num_inputs, num_hiddens)),
            normal((num_hiddens, num_hiddens)),
            torch.zeros(num_hiddens, device=device),
        )

    W_xz, W_hz, b_z = three()  # update gate
    W_xr, W_hr, b_r = three()  # reset gate
    W_xh, W_hh, b_h = three()  # candidate hidden state

    W_hq = normal((num_hiddens, num_outputs))  # output layer
    b_q = torch.zeros(num_outputs, device=device)

    params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]

    # Enable gradient tracking so the training loop can update these tensors.
    for param in params:
        param.requires_grad_(True)
    return params

### 初始隐状态

In [5]:
def init_gru_state(batch_size, num_hiddens, device):
    """Build the initial GRU state.

    Returns a 1-tuple holding an all-zero hidden state of shape
    (batch_size, num_hiddens) allocated on ``device``. The tuple wrapper
    matches how ``gru`` unpacks its state (``H, = state``).
    """
    zero_hidden = torch.zeros((batch_size, num_hiddens), device=device)
    return (zero_hidden,)

### 定义 forward 函数：结构与 RNN 相同，只是多了门控相关的公式

- 重置门：$R_t = \sigma(X_t W_{xr} + H_{t-1} W_{hr} + b_r)$
- 更新门：$Z_t = \sigma(X_t W_{xz} + H_{t-1} W_{hz} + b_z)$
- 候选隐状态：$\tilde{H}_t = \tanh(X_t W_{xh} + (R_t \odot H_{t-1}) W_{hh} + b_h)$
- 隐状态：$H_t = Z_t \odot H_{t-1} + (1 - Z_t) \odot \tilde{H}_t$


In [6]:
def gru(inputs, state, params):
    """Run a from-scratch GRU over a sequence, one time step at a time.

    Args:
        inputs: Iterable over time steps; each step tensor ``x_t`` is
            right-multiplied by the input weights, so presumably of shape
            (batch_size, vocab_size) -- TODO confirm against the model wrapper.
        state: 1-tuple ``(H,)`` with the previous hidden state, as produced by
            ``init_gru_state``.
        params: The 11 tensors returned by ``get_params``, in that exact order.

    Returns:
        ``(outputs, (H,))``: ``outputs`` is every step's ``H @ W_hq + b_q``
        concatenated along dim 0, and ``H`` is the final hidden state.
        An empty ``inputs`` makes ``torch.cat`` fail.
    """
    (W_xz, W_hz, b_z,
     W_xr, W_hr, b_r,
     W_xh, W_hh, b_h,
     W_hq, b_q) = params
    (hidden,) = state
    step_outputs = []
    for x_t in inputs:
        update_gate = torch.sigmoid(x_t @ W_xz + hidden @ W_hz + b_z)
        reset_gate = torch.sigmoid(x_t @ W_xr + hidden @ W_hr + b_r)
        # The reset gate scales how much of the previous state feeds the candidate.
        candidate = torch.tanh(x_t @ W_xh + (reset_gate * hidden) @ W_hh + b_h)
        # The update gate interpolates between the old state and the candidate.
        hidden = update_gate * hidden + (1 - update_gate) * candidate
        step_outputs.append(hidden @ W_hq + b_q)
    return torch.cat(step_outputs, dim=0), (hidden,)

In [None]:
# Hyperparameters and training for the from-scratch GRU.
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
# Use the vocab_size computed above (rather than recomputing len(vocab)) so
# every cell agrees on the vocabulary width.
model = d2l.RNNModelScratch(vocab_size, num_hiddens, device, get_params,
                            init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

### 简洁实现

In [None]:
# Concise implementation: PyTorch's built-in nn.GRU layer replaces the
# hand-written get_params / init_gru_state / gru; the training loop is unchanged.
num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens)
# Module.to returns the module itself, so the device move can be chained.
model = d2l.RNNModel(gru_layer, vocab_size).to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

### 总结
- 从零实现 GRU 时，与 RNN 相比只需替换参数初始化（`get_params`）、初始隐状态（`init_gru_state`）和前向计算函数（`gru`，即换成 GRU 的门控公式），训练流程不变。
- 简洁实现中只需把循环层换成 `nn.GRU`，其余代码保持不变。