In [1]:
%config ZMQInteractiveShell.ast_node_interactivity = "all"
%pprint

Pretty printing has been turned OFF


In [2]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [3]:
import sys
import torch
import torch.nn as nn
sys.path.append("../d2l_func/")
from data_prepare import load_data_jay_song, data_iter_random, data_iter_consecutive, to_onehot
from model_train import train_rnn, train_rnn_pytorch
from set_seed import set_seed

## GRU

当时间步过小或者过大时，RNN往往会出现梯度爆炸或者梯度消失的问题
- 梯度爆炸：可以通过梯度剪裁来解决
- 梯度消失：可以通过GRU/LSTM等来进行缓解

GRU包含重置门和更新门
- 重置门用于捕捉短期依赖关系
- 更新门用于捕捉长期依赖关系
- 重置门：$r_t = \delta(x_tw_{xr} + h_{t-1}w_{hr} + b_r)$
- 更新门：$z_t = \delta(x_tw_{xz} + h_{t-1}w_{hz} + b_z)$
- 候选隐藏层状态：$\widetilde h_t = tanh(x_tw_{xh} + r_t*(h_{t-1}w_hh) + b_h)$
- 当前时间步的隐藏层：$h_t = z*h_{t-1} + (1-z)\widetilde h_t$
- 输出层：$y_t = h_tw_{hy} + b_y$

### 自定义实现

#### 定义网络参数

In [4]:
def get_params(input_num, hidden_num, output_num, device):
    def _ones(shape):
        weight = nn.Parameter(torch.normal(0, 0.01, size=shape, device=device), requires_grad=True)
        return weight
    
    def _zeros(shape):
        bias = nn.Parameter(torch.zeros(shape, device=device), requires_grad=True)
        return bias
    
    def _three():
        return (
            _ones((input_num, hidden_num)),
            _ones((hidden_num, hidden_num)),
            _zeros(hidden_num),
        )
    
    w_xr, w_hr, b_r = _three()
    w_xz, w_hz, b_z = _three()
    w_xh, w_hh, b_h = _three()
    w_hy = _ones((hidden_num, output_num))
    b_y = _zeros((output_num))
    return nn.ParameterList([w_xr, w_hr, b_r, w_xz, w_hz, b_z, w_xh, w_hh, b_h, w_hy, b_y])

#### 定义gru层结构

In [5]:
from functools import reduce

def init_hidden_state(batch_size, hidden_num, device):
    return torch.zeros(batch_size, hidden_num, device=device)


def gru(inputs, h_state, params):
    w_xr, w_hr, b_r, w_xz, w_hz, b_z, w_xh, w_hh, b_h, w_hy, b_y = params
    outputs = []
    
    # inputs.shape is (num_step, batch_size, vocab_size)
    for x in inputs:
        rt = torch.sigmoid(torch.mm(x, w_xr) + torch.mm(h_state, w_hr) + b_r)
        zt = torch.sigmoid(torch.mm(x, w_xz) + torch.mm(h_state, w_hz) + b_z)
        h_candidate = torch.tanh(torch.mm(x, w_xh) + rt*torch.mm(h_state, w_hh) + b_h)
        h_state = zt*h_state + (1-zt)*h_candidate
        y = torch.mm(h_state, w_hy) + b_y
        outputs.append(y.unsqueeze(0))
        
    return reduce(lambda x, y: torch.cat((x, y)), outputs), h_state

In [6]:
# 验证
vocab_size, hidden_num = 15, 10
x = torch.arange(10).view(2, 5)
inputs = to_onehot(x, vocab_size, "cuda")
h_state = init_hidden_state(x.shape[0], hidden_num, "cuda")
params = get_params(vocab_size, hidden_num, vocab_size, "cuda")
outputs, h_state = gru(inputs, h_state, params)
outputs.shape, h_state.shape

(torch.Size([5, 2, 15]), torch.Size([2, 10]))

#### 字符级别预测

In [6]:
def predict_rnn(prefix, pred_num, model, init_hidden_state, hidden_num, 
                params, char_to_idx, vocab_set, vocab_size, device):
    outputs = [char_to_idx[prefix[0]]]
    h_state = init_hidden_state(1, hidden_num, device)
    
    for i in range(len(prefix) + pred_num - 1):
        # inputs.shape is (1, 1, vocab_size)
        inputs = to_onehot(torch.tensor(outputs[-1]).view(-1, 1), vocab_size, device)
        # y.shape is (1, 1, vocab_size), h_state.shape is (1, hidden_num)
        y, h_state = model(inputs, h_state, params)
        
        if i + 1 < len(prefix):
            outputs.append(char_to_idx[prefix[i+1]])
        else:
            outputs.append(y.argmax(dim=2).item())
            
    return "".join(vocab_set[i] for i in outputs)

In [8]:
# 验证
# load data
corpus_index, char_to_idx, vocab_set, vocab_size = load_data_jay_song()
# params
hidden_num, device = 256, "cuda"
params = get_params(vocab_size, hidden_num, vocab_size, device)
predict_rnn("分开", 10, gru, init_hidden_state, hidden_num, params, char_to_idx, vocab_set, vocab_size, device)

'分开鐘鐘真脚脚典盖起现现'

#### 梯度剪裁

In [7]:
def grad_clipping(params, clipping_theta, device):
    # l2 norm
    norm = torch.zeros(1, device=device)
    # cumsum all grad data
    for param in params:
        norm += (param.grad.data ** 2).sum()
        
    norm = norm.sqrt()
    
    # grad explode
    if norm > clipping_theta:
        for param in params:
            param.grad.data *= (clipping_theta / norm) 

#### 训练

In [8]:
# training
import numpy as np
from sqdm import sqdm
from optim import sgd

def train_rnn(epoch_num, batch_num, model, loss, get_params, init_hidden_state, hidden_num, batch_size,
              lr, data_iter, prefixs, pred_num, corpus_index, char_to_idx, vocab_set, vocab_size, 
              num_step, predict_rnn, clipping_theta=1e-2, random_sample=True, device="cuda"):
    
    # init(use in calculate perplexity)
    l_sum, n_class = 0., 0.
    # training bar
    process_bar = sqdm()
    # init params
    params = get_params(vocab_size, hidden_num, vocab_size, device)
    
    for epoch in range(epoch_num):
        print(f"Epoch [{epoch+1}/{epoch_num}]")
        # sample in consecutive
        if not random_sample:
            h_state = init_hidden_state(batch_size, hidden_num, device)
        for x, y in data_iter(corpus_index, batch_size, vocab_size, device):
            # 原始x的shape为(batch_size, num_step)，onehot后的shape为(num_step, batch_size, vocab_size)
            x = to_onehot(x, vocab_size, device)
            if random_sample:
                h_state = init_hidden_state(batch_size, hidden_num, device)
            else:
                # 脱离计算图，使得上一时刻的隐藏状态变成叶子节点，防止在销毁计算图后（隐藏节点还存在），因反向传播到更早的
                # 隐藏层时刻（不在当前计算图内）而出错
                h_state.detach_()
                
            # model
            # outputs.shape is (num_step, batch_size, vocab_size), h_state.shape is (batch_size, hidden_num)
            outputs, h_state = model(x, h_state, params)
            # change output.shape --> (num_step, batch_size, vocab_size), 主要是为了方便计算loss
            outputs = outputs.view(-1, outputs.shape[-1])
            # 原始y的shape为(batch_size, num_step), ---> (num_step, batch_size) ---> 1维向量
            # 转置后变成内存不连续，使用contiguous变成连续的向量
            y = y.transpose(0, 1).contiguous().view(-1)
            
            # 计算loss, 标签需要是int
            l = loss(outputs, y.long())
            # grad clear
            if params[0].grad is not None:
                for param in params:
                    param.grad.data.zero_()
            # grad backward
            l.backward()
            # grad clipping
            grad_clipping(params, clipping_theta, device)
            # update grad
            sgd(params, lr)
            
            # calculate l_sum
            l_sum += l.item() * y.shape[0]
            n_class += y.shape[0]
            
            # calculate perplexity
            try:
                perplexity = np.exp(l_sum / n_class)
            except OverflowError:
                perplexity = float("inf")
                
            # training bar
            process_bar.show_process(batch_num, 1, perplexity)
        print("\n")
        # predict
        for prefix in prefixs:
            print(f"prefix-{prefix}: ", predict_rnn(prefix, pred_num, model, init_hidden_state, hidden_num, 
                                                    params, char_to_idx, vocab_set, vocab_size, device))
        print("\n")

In [9]:
# load data
corpus_index, char_to_idx, vocab_set, vocab_size = load_data_jay_song()

super_params = {
        "epoch_num": 5,
        "model": gru,
        "loss": nn.CrossEntropyLoss(),
        "init_hidden_state": init_hidden_state,
        "hidden_num": 256,
        "get_params": get_params,
        "batch_size": 2,
        "num_step": 32,
        "corpus_index": corpus_index,
        "data_iter": data_iter_random,
        "lr": 100,
        "char_to_idx": char_to_idx,
        "vocab_set": vocab_set,
        "vocab_size": vocab_size,
        "predict_rnn": predict_rnn,
        "pred_num": 50,
        "prefixs": ["分开", "不分开"],
        #     "random_sample": False
    }

super_params["batch_num"] = len(list(data_iter_random(corpus_index, super_params["batch_size"],
                                                      super_params["num_step"], "cpu")))

train_rnn(**super_params)

Epoch [1/5]
12/989 [------------------------------] - train_loss: 1244.3197, train_score: -, test_loss: -, test_score: -

prefix-分开:  分开      的的 的的 的的 的的 的的 的的 的的 的的 的的 的的 的的 的的 的的 的的 的的
prefix-不分开:  不分开     的的 的的 的的 的的 的的 的的 的的 的的 的的 的的 的的 的的 的的 的的 的的 


Epoch [2/5]
21/989 [------------------------------] - train_loss: 998.7580, train_score: -, test_loss: -, test_score: --

KeyboardInterrupt: 

### 简单实现

#### 网络定义

In [4]:
class RNNModel(nn.Module):
    def __init__(self, rnn_layer, vocab_size):
        super(RNNModel, self).__init__()
        self.rnn = rnn_layer
        self.hidden_num = self.rnn.hidden_size * (2 if self.rnn.bidirectional else 1)
        self.vocab_size = vocab_size
        self.fc = nn.Linear(hidden_num, vocab_size)
        self.h_state = None
        
    def forward(self, x, h_state):
        # x.shape is (num_step, batch_size, vocab_size)
        y, self.h_state = self.rnn(x, h_state)
        return self.fc(y), self.h_state

#### 预测

In [19]:
def train_rnn_pytorch(prefix, pred_num, model, char_to_idx, vocab_size, vocab_set, device):
    outputs = [char_to_idx[prefix[0]]]
    h_state = None
    
    for i in range(len(prefix) + pred_num - 1):
        inputs = to_onehot(torch.tensor(outputs[-1]).view(-1, 1), vocab_size, device)
        if h_state is not None:
            if isinstance(h_state, tuple): # lstm , (h,c)
                h_state = (h_state[0].to(device), h_state[1].to(device))
            else:
                h_state = h_state.to(device)
                
        y, h_state = model(inputs, h_state)
        if i + 1 < len(prefix):
            outputs.append(char_to_idx[prefix[i+1]])
        else:
            outputs.append(y.argmax(dim=2).item())
            
    return "".join(vocab_set[i] for i in outputs)

In [20]:
# 验证
# load data
hidden_num = 256
corpus_index, char_to_idx, vocab_set, vocab_size = load_data_jay_song()
rnn_layer = nn.GRU(vocab_size, hidden_num)
model = RNNModel(rnn_layer, vocab_size)
model = model.cuda()
train_rnn_pytorch("分开", 10, model, char_to_idx, vocab_size, vocab_set, "cuda")

'分开八斤斤野宝斤宝苛赞赞'

### 训练

In [21]:
# training
import numpy as np
from sqdm import sqdm
from optim import sgd

def train_rnn_pytorch(epoch_num, batch_num, model, loss, optimizer, batch_size, lr, data_iter, prefixs, 
                      pred_num, corpus_index, char_to_idx, vocab_set, vocab_size, num_step, predict_rnn_pytorch, 
                      clipping_theta=1e-2, random_sample=True, device="cuda"):
    
    # init(use in calculate perplexity)
    l_sum, n_class = 0., 0.
    # training bar
    process_bar = sqdm()
    
    for epoch in range(epoch_num):
        print(f"Epoch [{epoch+1}/{epoch_num}]")
        # sample in consecutive
        if not random_sample:
            h_state = None
        for x, y in data_iter(corpus_index, batch_size, vocab_size, device):
            # 原始x的shape为(batch_size, num_step)，onehot后的shape为(num_step, batch_size, vocab_size)
            x = to_onehot(x, vocab_size, device)
            if random_sample:
                h_state = None
            else:
                # 脱离计算图，使得上一时刻的隐藏状态变成叶子节点，防止在销毁计算图后（隐藏节点还存在），因反向传播到更早的
                # 隐藏层时刻（不在当前计算图内）而出错
                if h_state is not None:
                    if isinstance(h_state, tuple): # lstm, state: (h, c)
                        h_state = (h_state[0].deatch(), h_state[1].deatch())
                    else:
                        h_state.detach_()
                
            # model
            # outputs.shape is (num_step, batch_size, vocab_size), h_state.shape is (batch_size, hidden_num)
            outputs, h_state = model(x, h_state)
            # change output.shape --> (num_step, batch_size, vocab_size), 主要是为了方便计算loss
            outputs = outputs.view(-1, outputs.shape[-1])
            # 原始y的shape为(batch_size, num_step), ---> (num_step, batch_size) ---> 1维向量
            # 转置后变成内存不连续，使用contiguous变成连续的向量
            y = y.transpose(0, 1).contiguous().view(-1)
            
            # 计算loss, 标签需要是int
            l = loss(outputs, y.long())
            # grad clear
            optimizer.zero_grad()
            # grad backward
            l.backward()
            # grad clipping
            grad_clipping(params, clipping_theta, device)
            # update grad
            optimizer.step()
            
            # calculate l_sum
            l_sum += l.item() * y.shape[0]
            n_class += y.shape[0]
            
            # calculate perplexity
            try:
                perplexity = np.exp(l_sum / n_class)
            except OverflowError:
                perplexity = float("inf")
                
            # training bar
            process_bar.show_process(batch_num, 1, perplexity)
        print("\n")
        # predict
        for prefix in prefixs:
            print(f"prefix-{prefix}: ", predict_rnn(prefix, pred_num, model, init_hidden_state, hidden_num, 
                                                    params, char_to_idx, vocab_set, vocab_size, device))
        print("\n")

In [None]:
hidden_num = 256
rnn_layer = nn.GRU(vocab_size, hidden_num)
model = RNNModel(rnn_layer, vocab_size)
model = model.cuda()
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

params = {
    "epoch_num": 250,
    "model": model,
    "loss": loss,
    "optimizer": optimizer,
    "batch_size": 64,
    "num_step": 32,
    "corpus_index": corpus_index,
    "data_iter": data_iter_consecutive,
    "char_to_idx": char_to_idx,
    "vocab_set": vocab_set,
    "vocab_size": vocab_size,
    "predict_rnn_pytorch": predict_rnn_pytorch,
    "pred_num": 50,
    "prefixs": ["分开", "不分开"],
    "random_sample": False
}

params["batch_num"] = len(list(data_iter_consecutive(corpus_index, params["batch_size"],
                                                     params["num_step"], "cpu")))

train_rnn_pytorch(**params)