In [None]:
#!/usr/bin/env python3
# _*_ coding: utf-8 _*_

import sys

# Add the directory two levels above the working directory to the module
# search path. Presumably this is where the local `d2l` package imported by a
# later cell lives -- TODO confirm.
sys.path.append("../..")

In [2]:
from d2l import d2l_cn as d2l

#### Gated Recurrent Unit (GRU)

In [3]:
import numpy as np
import torch
from torch import nn as nn
from  torch import optim as optim
import torch.nn.functional as F

In [4]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [5]:
# Unpack the four values returned by d2l.load_data_jay_lyrics().
# NOTE(review): judging by the names these are the index-encoded corpus, the
# char<->index lookup tables and the vocabulary size -- confirm against d2l.
corpus_indices, char_to_idx, idx_to_char, vocab_size = d2l.load_data_jay_lyrics()

In [6]:
# Input and output widths both equal vocab_size, the dictionary size, i.e. the
# length of a one-hot vector.
# NOTE(review): `num_inouts` looks like a typo for `num_inputs`; the name is
# kept because get_params() reads it.
num_inouts = num_outputs = vocab_size
num_hiddens = 256

In [7]:
def get_params():
    """Create the trainable parameters of the from-scratch GRU.

    Reads the module-level names ``num_inouts``, ``num_hiddens``,
    ``num_outputs`` and ``device``.

    Returns:
        nn.ParameterList with 11 entries, in this order:
        W_xz, W_hz, b_z (update gate), W_xr, W_hr, b_r (reset gate),
        W_xh, W_hh, b_h (candidate hidden state), W_hq, b_q (output layer).
        Weights are float32 samples from a normal distribution with mean 0
        and standard deviation 0.01; biases are float32 zeros.
    """
    def _zero_bias(length):
        # 1-D all-zero bias vector.
        zeros = torch.zeros(length, device=device, dtype=torch.float32)
        return nn.Parameter(zeros, requires_grad=True)

    def _gaussian_weight(shape):
        # Sample with NumPy first, then move the values into a float32 tensor.
        samples = np.random.normal(0, 0.01, size=shape)
        weight = torch.tensor(samples, device=device, dtype=torch.float32)
        return nn.Parameter(weight, requires_grad=True)

    def _gate_triple():
        # Input-to-hidden weight, hidden-to-hidden weight, bias (sampled in that order).
        input_weight = _gaussian_weight((num_inouts, num_hiddens))
        hidden_weight = _gaussian_weight((num_hiddens, num_hiddens))
        return input_weight, hidden_weight, _zero_bias(num_hiddens)

    collected = []
    # Three structurally identical triples: update gate, reset gate,
    # candidate hidden state.
    for _ in range(3):
        collected.extend(_gate_triple())

    # Output layer.
    collected.append(_gaussian_weight((num_hiddens, num_outputs)))
    collected.append(_zero_bias(num_outputs))

    return nn.ParameterList(collected)
    

#### 定义GRU模型

In [8]:
def init_gru_state(batch_size, num_hiddens, device):
    """Build the initial GRU state.

    Args:
        batch_size: number of rows of the hidden-state tensor.
        num_hiddens: number of columns of the hidden-state tensor.
        device: device on which the tensor is allocated.

    Returns:
        A 1-tuple containing an all-zero tensor of shape
        (batch_size, num_hiddens).
    """
    initial_hidden = torch.zeros((batch_size, num_hiddens), device=device)
    return (initial_hidden,)

In [9]:
def gru(inputs, state, params):
    """Run a gated recurrent unit over a sequence of inputs.

    Args:
        inputs: iterable of per-step input tensors; each is matrix-multiplied
            with the input-to-hidden weights (presumably shaped
            (batch, num_inputs) -- confirm against the caller).
        state: 1-tuple holding the current hidden state.
        params: 11 tensors in the order W_xz, W_hz, b_z, W_xr, W_hr, b_r,
            W_xh, W_hh, b_h, W_hq, b_q.

    Returns:
        (outputs, (H,)): the list of per-step output tensors and a 1-tuple
        with the hidden state after the last step.
    """
    (W_xz, W_hz, b_z,
     W_xr, W_hr, b_r,
     W_xh, W_hh, b_h,
     W_hq, b_q) = params
    (hidden,) = state
    step_outputs = []

    for step_input in inputs:
        update_gate = torch.sigmoid(step_input @ W_xz + hidden @ W_hz + b_z)
        reset_gate = torch.sigmoid(step_input @ W_xr + hidden @ W_hr + b_r)
        # Candidate hidden state: the reset gate scales the previous hidden
        # state before it is multiplied by W_hh.
        candidate = torch.tanh(step_input @ W_xh + (reset_gate * hidden) @ W_hh + b_h)
        # The update gate blends the previous hidden state with the candidate.
        hidden = update_gate * hidden + (1 - update_gate) * candidate
        step_outputs.append(hidden @ W_hq + b_q)

    return step_outputs, (hidden,)

In [10]:
# Training settings passed to d2l.train_and_predict_rnn in the next cell.
num_epochs = 160
num_steps = 35
batch_size = 32
lr = 1e2  # the learning rate of 100 is deliberate (original author's note; reason not stated)
clipping_theta = 1e-2

# Remaining settings passed to the same call.
pred_period = 40
pred_len = 50
prefixes = ['分开', '不分开']

In [11]:
# Hand the from-scratch GRU (step function, parameter factory and state
# initialiser) plus the data and hyperparameters to d2l's training helper.
# NOTE(review): the meaning of the positional `False` is not visible here; it
# presumably selects the batch sampling mode -- confirm against
# d2l.train_and_predict_rnn.
# NOTE(review): the saved output of this cell ends in KeyboardInterrupt after
# printing 0-13, so the recorded run did not finish.
d2l.train_and_predict_rnn(gru, get_params, init_gru_state, num_hiddens,
                          vocab_size, device, corpus_indices, idx_to_char,
                          char_to_idx, False, num_epochs, num_steps, lr,
                          clipping_theta, batch_size, pred_period, pred_len,
                          prefixes)

0
1
2
3
4
5
6
7
8
9
10
11
12
13


KeyboardInterrupt: 

* 门控循环神经网络可以更好地捕捉时间序列中时间步距离较大的依赖关系。
* 门控循环单元引入了门的概念，从而修改了循环神经网络中隐藏状态的计算方式。它包括重置门、更新门、候选隐藏状态和隐藏状态。
* 重置门有助于捕捉时间序列里短期的依赖关系。
* 更新门有助于捕捉时间序列里长期的依赖关系。
