In [1]:
# Display the value of every top-level expression in a cell, not only the last one.
%config ZMQInteractiveShell.ast_node_interactivity = "all"
# Toggle IPython pretty-printing (the recorded output below shows it ends up OFF).
%pprint

Pretty printing has been turned OFF


In [2]:
import os

# Ask CUDA to launch kernels synchronously so that a device-side error is
# reported at the call that caused it (debugging aid; slows execution).
# NOTE(review): this is only effective if set before CUDA is initialised.
os.environ.update({"CUDA_LAUNCH_BLOCKING": "1"})

In [4]:
import sys
import torch
import numpy as np
import torch.nn as nn
sys.path.append("../d2l_func/")
from data_prepare import load_data_jay_song, data_iter_random, data_iter_consecutive, to_onehot
from model_train import train_rnn, train_rnn_pytorch
from set_seed import set_seed

## LSTM

LSTM有输入门，遗忘门，输出门
- 输入门：$i_t = \sigma(x_t w_{xi} + h_{t-1} w_{hi} + b_i)$
- 遗忘门：$f_t = \sigma(x_t w_{xf} + h_{t-1} w_{hf} + b_f)$
- 候选元胞状态：$\widetilde{c}_t = \tanh(x_t w_{xc} + h_{t-1} w_{hc} + b_c)$
- 元胞状态：$c_t = i_t \odot \widetilde{c}_t + f_t \odot c_{t-1}$
- 输出门：$o_t = \sigma(x_t w_{xo} + h_{t-1} w_{ho} + b_o)$
- 隐藏层状态：$h_t = o_t \odot \tanh(c_t)$

其中 $\sigma$ 为 sigmoid 函数，$\odot$ 表示逐元素相乘。

### 自定义实现

#### 定义网络参数

In [14]:
def get_params(input_num, hidden_num, output_num, device):
    """Create the trainable tensors of a hand-written LSTM and its output layer.

    Weight matrices are sampled from a normal distribution with mean 0 and
    standard deviation 0.01; bias vectors start at zero. Everything is
    allocated on ``device``.

    Args:
        input_num: size of each input vector.
        hidden_num: size of the hidden / cell state.
        output_num: size of each output vector.
        device: device on which the tensors are created.

    Returns:
        nn.ParameterList ordered as
        [w_xi, w_hi, b_i, w_xf, w_hf, b_f, w_xo, w_ho, b_o,
         w_xc, w_hc, b_c, w_hy, b_y].
    """
    def gaussian_weight(rows, cols):
        values = torch.normal(0, 0.01, size=(rows, cols), device=device)
        return nn.Parameter(values, requires_grad=True)

    def zero_bias(length):
        return nn.Parameter(torch.zeros(length, device=device), requires_grad=True)

    tensors = []
    # Input gate, forget gate, output gate and candidate cell state all share
    # the same layout: input->hidden weight, hidden->hidden weight, bias.
    for _ in range(4):
        tensors += [
            gaussian_weight(input_num, hidden_num),
            gaussian_weight(hidden_num, hidden_num),
            zero_bias(hidden_num),
        ]
    # Output (projection) layer.
    tensors += [gaussian_weight(hidden_num, output_num), zero_bias(output_num)]

    return nn.ParameterList(tensors)

#### 定义网络结构

In [23]:
from functools import reduce

def init_hidden_state(batch_size, hidden_num, device):
    """Return the initial LSTM state as a pair of zero tensors.

    Args:
        batch_size: number of sequences processed in parallel.
        hidden_num: size of the hidden / cell state.
        device: device on which the tensors are created.

    Returns:
        Tuple ``(h, c)`` of two separate zero tensors, each of shape
        (batch_size, hidden_num).
    """
    state_shape = (batch_size, hidden_num)
    hidden = torch.zeros(state_shape, device=device)
    cell = torch.zeros(state_shape, device=device)
    return hidden, cell


def lstm(inputs, h_state, params):
    """Run a single-layer LSTM over a sequence and project every step to the output space.

    Args:
        inputs: sequence iterated over its first axis; each step is a 2-D
            (batch_size, vocab_size) tensor, e.g. a
            (num_step, batch_size, vocab_size) tensor.
        h_state: tuple ``(h, c)`` with the hidden and cell state, each of shape
            (batch_size, hidden_num).
        params: the 14 tensors in the order
            [w_xi, w_hi, b_i, w_xf, w_hf, b_f, w_xo, w_ho, b_o,
             w_xc, w_hc, b_c, w_hy, b_y].

    Returns:
        ``(outputs, (h, c))`` where ``outputs`` has shape
        (num_step, batch_size, output_num) and ``(h, c)`` is the state after
        the last step.
    """
    (w_xi, w_hi, b_i,
     w_xf, w_hf, b_f,
     w_xo, w_ho, b_o,
     w_xc, w_hc, b_c,
     w_hy, b_y) = params
    h, c = h_state
    step_outputs = []

    for x in inputs:
        it = torch.sigmoid(torch.mm(x, w_xi) + torch.mm(h, w_hi) + b_i)   # input gate
        ft = torch.sigmoid(torch.mm(x, w_xf) + torch.mm(h, w_hf) + b_f)   # forget gate
        ot = torch.sigmoid(torch.mm(x, w_xo) + torch.mm(h, w_ho) + b_o)   # output gate
        c_candidate = torch.tanh(torch.mm(x, w_xc) + torch.mm(h, w_hc) + b_c)
        # Keep part of the old cell state (forget gate) and add the gated candidate.
        c = it * c_candidate + ft * c
        h = ot * torch.tanh(c)
        step_outputs.append(torch.mm(h, w_hy) + b_y)

    # Stack once at the end: concatenating pairwise inside a reduce re-copies
    # the accumulated outputs at every step (quadratic in the sequence length).
    return torch.stack(step_outputs), (h, c)

In [24]:
# Sanity check: push a tiny one-hot batch through the hand-written LSTM.
hidden_num = 10
vocab_size = 15
device = "cuda"

x = torch.arange(10).view(2, 5)
inputs = to_onehot(x, vocab_size, device)
params = get_params(vocab_size, hidden_num, vocab_size, device)
h_state = init_hidden_state(inputs.shape[1], hidden_num, device)
outputs, h_state = lstm(inputs, h_state, params)

# Shapes of: outputs / hidden state / cell state
tuple(t.shape for t in (outputs, h_state[0], h_state[1]))

(torch.Size([5, 2, 15]), torch.Size([2, 10]), torch.Size([2, 10]))

#### 预测

In [26]:
def predict_rnn(prefix, pred_num, model, init_hidden_state, hidden_num, 
                params, char_to_idx, vocab_set, vocab_size, device):
    """Greedily generate text with a hand-written recurrent model.

    The model is first fed the characters of ``prefix`` one at a time (its
    predictions are ignored while the prefix lasts), then its own arg-max
    prediction is fed back in until ``pred_num`` new characters exist.

    Args:
        prefix: non-empty seed string; every character must be a key of
            ``char_to_idx``.
        pred_num: number of characters to generate after the prefix.
        model: callable ``model(inputs, h_state, params) -> (y, h_state)``.
        init_hidden_state: callable ``(batch_size, hidden_num, device)``
            returning the initial state.
        hidden_num: size of the hidden state.
        params: model parameters forwarded to ``model``.
        char_to_idx: mapping from character to vocabulary index.
        vocab_set: indexable mapping from vocabulary index to character.
        vocab_size: number of entries in the vocabulary.
        device: device used for the inputs and the state.

    Returns:
        The prefix followed by the generated characters, as one string.
    """
    outputs = [char_to_idx[prefix[0]]]
    h_state = init_hidden_state(1, hidden_num, device)
    prefix_len = len(prefix)

    # `step` is the position of the character about to be appended.
    for step in range(1, prefix_len + pred_num):
        # A single character as a one-step, batch-of-one input.
        last_char = torch.tensor(outputs[-1]).view(-1, 1)
        inputs = to_onehot(last_char, vocab_size, device)
        y, h_state = model(inputs, h_state, params)

        if step < prefix_len:
            # Still warming up on the prefix: use the given character.
            next_idx = char_to_idx[prefix[step]]
        else:
            # Past the prefix: use the most likely character.
            next_idx = y.argmax(dim=2).item()
        outputs.append(next_idx)

    return "".join(vocab_set[i] for i in outputs)

In [27]:
# Smoke test: generate text with freshly initialised (untrained) parameters.
corpus_index, char_to_idx, vocab_set, vocab_size = load_data_jay_song()

hidden_num = 256
device = "cuda"
params = get_params(vocab_size, hidden_num, vocab_size, device)

predict_rnn(
    prefix="分开",
    pred_num=10,
    model=lstm,
    init_hidden_state=init_hidden_state,
    hidden_num=hidden_num,
    params=params,
    char_to_idx=char_to_idx,
    vocab_set=vocab_set,
    vocab_size=vocab_size,
    device=device,
)

'分开顽繁曲狈耿跡掏狞墙台'

#### 训练

In [29]:
# Reload the corpus so this cell does not depend on the validation cell above.
corpus_index, char_to_idx, vocab_set, vocab_size = load_data_jay_song()

# Keyword arguments forwarded to train_rnn.
# ("random_sample" is deliberately not passed in this run.)
super_params = dict(
    epoch_num=5,
    rnn=lstm,
    loss=nn.CrossEntropyLoss(),
    init_hidden_state=init_hidden_state,
    hidden_num=256,
    get_params=get_params,
    batch_size=2,
    num_step=32,
    corpus_index=corpus_index,
    data_iter=data_iter_random,
    lr=100,
    char_to_idx=char_to_idx,
    vocab_set=vocab_set,
    vocab_size=vocab_size,
    predict_rnn=predict_rnn,
    pred_num=50,
    prefixs=["分开", "不分开"],
)

# Number of batches per epoch: drain one pass of the sampler on the CPU and
# count what it yields.
super_params["batch_num"] = sum(
    1
    for _ in data_iter_random(
        corpus_index, super_params["batch_size"], super_params["num_step"], "cpu"
    )
)

train_rnn(**super_params)

Epoch [1/5]
989/989 [>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] - train_loss: 345.9668, train_score: -, test_loss: -, test_score: -

prefix-分开:  分开 我们的爱情 我的爱爱不出 我的爱爱不出 我的爱爱不出 我的爱爱不出 我的爱爱不出 我的爱爱不出 我
prefix-不分开:  不分开 我们的爱情 我的爱爱不出 我的爱爱不出 我的爱爱不出 我的爱爱不出 我的爱爱不出 我的爱爱不出 我


Epoch [2/5]
989/989 [>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] - train_loss: 231.9450, train_score: -, test_loss: -, test_score: -

prefix-分开:  分开 我们的感觉 我们的感觉 我们的感觉 我们的感觉 我们的感觉 我们的感觉 我们的感觉 我们的感觉 我
prefix-不分开:  不分开 我们的感觉 我们的感觉 我们的感觉 我们的感觉 我们的感觉 我们的感觉 我们的感觉 我们的感觉 我


Epoch [3/5]
989/989 [>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] - train_loss: 166.0830, train_score: -, test_loss: -, test_score: -

prefix-分开:  分开 我不能够想 我不要再想 我不能再想 我不要我 我不要我 我不要我 我不要我 我不要我 我不要我 我
prefix-不分开:  不分开 我用第一人称 在我的等待 你说我爱你 你说我爱你 你说我爱你 你说我爱你 你说我爱你 你说我爱你 


Epoch [4/5]
989/989 [>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] - train_loss: 125.3085, train_score: -, test_loss: -, test_score: -

prefix-分开:  分开 我不能再想 我不能再想 我不再再想 我不再再想 我不再再想 我不再再想 我不再再想 我不再再想 我
prefix-不分开:  不分开 我知道你不要我  我不再再想 我不是我要的天堂景象 你只

### 简洁实现

#### 网络结构

In [30]:
class RNNModel(nn.Module):
    """Recurrent layer followed by a linear projection onto the vocabulary.

    Args:
        rnn_layer: a recurrent module exposing ``hidden_size`` and
            ``bidirectional`` (e.g. ``nn.RNN``, ``nn.GRU``, ``nn.LSTM``).
        vocab_size: number of output classes of the projection layer.
    """

    def __init__(self, rnn_layer, vocab_size):
        super(RNNModel, self).__init__()
        self.rnn = rnn_layer
        # A bidirectional layer concatenates both directions, doubling the feature size.
        self.hidden_num = self.rnn.hidden_size * (2 if self.rnn.bidirectional else 1)
        self.vocab_size = vocab_size
        # Fixed: the attribute name was garbled (`hidden_numden_numden_num`),
        # which raised AttributeError as soon as the model was constructed.
        self.fc = nn.Linear(self.hidden_num, vocab_size)
        self.h_state = None

    def forward(self, x, h_state):
        """Apply the recurrent layer, then project every time step.

        Args:
            x: input passed straight to the recurrent layer.
            h_state: initial state for the recurrent layer (``None`` lets the
                layer use its default).

        Returns:
            ``(logits, h_state)``: the projected outputs of the recurrent
            layer and the state it returned (also kept on ``self.h_state``).
        """
        y, self.h_state = self.rnn(x, h_state)
        return self.fc(y), self.h_state

#### 预测

In [None]:
# 预测
def train_rnn_pytorch(prefix, pred_num, model, char_to_idx, vocab_size, vocab_set, device):
    outputs = [char_to_idx[prefix[0]]]
    h_state = None
    
    for i in range(len(prefix) + pred_num - 1):
        inputs = to_onehot(torch.tensor(outputs[-1]).view(-1, 1), vocab_size, device)
        if h_state is not None:
            if isinstance(h_state, tuple): # lstm , (h,c)
                h_state = (h_state[0].to(device), h_state[1].to(device))
            else:
                h_state = h_state.to(device)
                
        y, h_state = model(inputs, h_state)
        if i + 1 < len(prefix):
            outputs.append(char_to_idx[prefix[i+1]])
        else:
            outputs.append(y.argmax(dim=2).item())
            
    return "".join(vocab_set[i] for i in outputs)

In [31]:
# 验证
# load data
hidden_num = 256
corpus_index, char_to_idx, vocab_set, vocab_size = load_data_jay_song()
rnn_layer = nn.GRU(vocab_size, hidden_num)
model = RNNModel(rnn_layer, vocab_size)
model = model.cuda()
train_rnn_pytorch("分开", 10, model, char_to_idx, vocab_size, vocab_set, "cuda")

TypeError: train_rnn_pytorch() missing 8 required positional arguments: 'num_step', 'batch_size', 'char_to_idx', 'vocab_set', 'vocab_size', 'prefixs', 'pred_num', and 'predict_rnn_pytorch'

#### 训练

In [None]:
# Model, loss and optimiser for the nn.LSTM-based variant.
hidden_num = 256
rnn_layer = nn.LSTM(vocab_size, hidden_num)
model = RNNModel(rnn_layer, vocab_size).cuda()
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Keyword arguments forwarded to train_rnn_pytorch.
params = dict(
    epoch_num=250,
    model=model,
    loss=loss,
    optimizer=optimizer,
    batch_size=64,
    num_step=32,
    corpus_index=corpus_index,
    data_iter=data_iter_consecutive,
    char_to_idx=char_to_idx,
    vocab_set=vocab_set,
    vocab_size=vocab_size,
    predict_rnn_pytorch=predict_rnn_pytorch,
    pred_num=50,
    prefixs=["分开", "不分开"],
    random_sample=False,
)

# Number of batches per epoch: drain one pass of the iterator on the CPU and
# count what it yields.
params["batch_num"] = sum(
    1
    for _ in data_iter_consecutive(
        corpus_index, params["batch_size"], params["num_step"], "cpu"
    )
)

train_rnn_pytorch(**params)