In [1]:
import torch
import zipfile
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
print(torch.__version__)

cuda
1.8.0


读取zipfile文件

In [2]:
with zipfile.ZipFile("../../data/jaychou_lyrics.txt.zip")as zin:
    with zin.open("jaychou_lyrics.txt") as f:
        lyrics=f.read().decode("utf-8")
lyrics=lyrics.replace("\n"," ").replace("\r"," ")
print(len(lyrics))
lyrics=lyrics[0:10000]

63282


In [3]:
char_set=list(set(lyrics))
vocab_size=len(char_set)
char_indices=dict([(char,i)for i,char in enumerate(char_set)])
vocab_indices=[char_indices[char] for char in lyrics]


采集数据

In [4]:
import numpy as np
import random

def Dataset_iter(num_steps,batch_size,data,mode=None):
    def get_data(pos,num_steps):
        return data[pos:pos+num_steps]
    num_data=len(data)
    num_samples=(num_data-1)//num_steps
    num_batch=num_samples//batch_size
    example_list=list(range(num_samples))
    if mode=="nearest":#相邻采样
        indices=np.array(example_list[0:num_batch*batch_size]).reshape(batch_size,num_batch)
        for i in range(num_batch):
            x=[get_data(idx*num_steps,num_steps)for idx in indices[:,i]]
            y=[get_data(idx*num_steps+1,num_steps)for idx in indices[:,i]]
            yield torch.from_numpy(np.array(x)),torch.from_numpy(np.array(y))        
    if mode=="random":#随机采样
        random.shuffle(example_list) 
        for epoch in range(num_batch):
            batch_pos=batch_size*epoch
            batch_indices=example_list[batch_pos:batch_pos+batch_size]
            x=[get_data(idx*num_steps,num_steps)for idx in batch_indices]
            y=[get_data(idx*num_steps+1,num_steps)for idx in batch_indices]
            yield torch.from_numpy(np.array(x)),torch.from_numpy(np.array(y)) 
           


In [5]:
#将采样得到的数据表示成one-hot形式
#(num_batch,num_steps)
#变成：(num_steps,(num_batch,one-hot))

In [6]:
def one_hot(x,length,dtype=torch.float32):
    x=x.long()
    zeros=torch.zeros((1,length),dtype=torch.float32,device=device)
    zeros[0,x]=1
    return zeros
def to_onehot(x,length,dtype=torch.float32):
    batch_size,num_steps=x.shape
    x=x.T
    res=[]
    for row in x:
        for t in row:
            res.append(one_hot(x,length,dtype=dtype))
    return torch.cat(res).reshape(num_steps,batch_size,length)

In [7]:
#初始化模型参数
num_inputs,num_hiddens,num_outputs=vocab_size,256,vocab_size
print(vocab_size)
import torch.nn as nn
def get_params():
    def _one(shape):
        ts = torch.tensor(np.random.normal(0, 0.01, size=shape), device=device, dtype=torch.float32)
        return torch.nn.Parameter(ts, requires_grad=True)

    # 隐藏层参数
    W_xh = _one((num_inputs, num_hiddens))
    W_hh = _one((num_hiddens, num_hiddens))
    b_h = torch.nn.Parameter(torch.zeros(num_hiddens, device=device, requires_grad=True))
    # 输出层参数
    W_hq = _one((num_hiddens, num_outputs))
    b_q = torch.nn.Parameter(torch.zeros(num_outputs, device=device, requires_grad=True))
    return nn.ParameterList([W_xh, W_hh, b_h, W_hq, b_q])


1027


In [8]:
def init_rnn_state(batch_size, num_hiddens, device):
    return torch.zeros((batch_size, num_hiddens), device=device)
def rnn(inputs, state, params):
    # inputs和outputs皆为num_steps个形状为(batch_size, vocab_size)的矩阵
    W_xh, W_hh, b_h, W_hq, b_q = params
    H= state
    outputs = []
    for X in inputs:
        H = torch.tanh(torch.matmul(X, W_xh) + torch.matmul(H, W_hh) + b_h)
        Y = torch.matmul(H, W_hq) + b_q
        outputs.append(Y)
    return outputs, H

In [9]:
# num_steps=5
# batch_size=2

# for step ,data in enumerate (Dataset_iter(num_steps,batch_size,vocab_indices,"random")):
#     inputs=to_onehot(data[0],vocab_size).to(device)
#     label=to_onehot(data[1],vocab_size).to(device)
#     state=init_rnn_state(batch_size, num_hiddens, device)
#     params=get_params()
#     output,h,=rnn(inputs,state,params)
# # print(len(output),output[0].shape,h.shape)

In [10]:
#定义预测函数
def predict_rnn(prefix,num_chars,params,state,rnn,num_hiddens,vocab_size,char_indices,char_set):
    output=[char_indices[prefix[0]]]
    params=get_params()
    for t in range(num_chars+len(prefix)):
        params=get_params()
        x=one_hot(torch.tensor([output[0]]),vocab_size)
        Y,state=rnn(x,state,params)
        if t<len(prefix)-1:
            output.append(char_indices[prefix[t+1]])
        else:
            output.append(torch.argmax(Y[0]).item())
    return ''.join([char_set[i] for i in output])
# params=get_params()                   
# res=predict_rnn("分开",10,params,init_rnn_state,rnn,num_hiddens,vocab_size,char_indices,char_set)
# print(res)

In [11]:
#梯度剪裁
def grad_clipping(params, theta, device):
    norm = torch.tensor([0.0], device=device)
    for param in params:
        norm += (param.grad.data ** 2).sum()
    norm = norm.sqrt().item()#.item从具有一个元素的张量中取出元素值。
    if norm > theta:
        for param in params:
            param.grad.data *= (theta / norm)

In [12]:
def sgd(params,batch_size,lr,epoch):
    with torch.no_grad():
        if epoch>100:
            lr=0.1*lr
        for param in params:
            param-=lr*param.grad.data/batch_size
    return params
    

In [13]:
#定义training function

In [22]:
import math
def train_and_predict_rnn(epoches,get_params,init_rnn_state,rnn,batch_size,num_steps,num_hiddens,mode,theta,
                          device,vocab_indices,length,num_chars,vocab_size,char_indices,char_set,prefixes):
    loss=nn.CrossEntropyLoss()
    params=get_params()
    optimizer=torch.optim.Adam(params,lr=lr)
    for epoch in range(1,epoches+1):
        if mode!="random":
            state=init_rnn_state(batch_size, num_hiddens, device)
        data_iter=Dataset_iter(num_steps,batch_size,vocab_indices,mode)
        sum_loss=0.0
        n=0
        for x,y in data_iter:
            if mode=="random":
                state=init_rnn_state(batch_size, num_hiddens, device)
            else:
                for s in state:
                    s.detach()
            optimizer.zero_grad()
            x=to_onehot(x,length,torch.float32)
            output,state=rnn(x,state,params)
            output=torch.cat(output,0)
            y=torch.transpose(y,0,1).contiguous().view(-1).to(device)
            Loss=loss(output,y.long())#交叉熵损失得到的loss是平均的loss
            if params[0].grad is not None:
                for param in params:
                    param.grad.data.zero_()#zero_表示对值进行赋零值
            Loss.backward()#赋予参数各自的梯度
            grad_clipping(params, theta, device)
            optimizer.step()
            sum_loss+=Loss.item()*y.shape[0]
            n+=y.shape[0]
        if (epoch ) % pred_period == 0:
            print('epoch %d, perplexity %f' % (
                epoch , math.exp(sum_loss/ n)))
            for prefix in prefixes:
                print(' -',predict_rnn(prefix,num_chars,params,state,rnn,num_hiddens,vocab_size,char_indices,char_set))

            
        print(f"{epoch:03d}/{epoches}  loss:{math.exp(sum_loss/n):.5f} ")
            
    

In [23]:
#super parameters

num_hiddens=600
mode="random"
epoches, num_steps, batch_size, lr, theta,length = 250, 30, 32, 1e-3,1e-2,vocab_size
pred_period=30
num_chars=20
prefixes=["分开","不分开"]
train_and_predict_rnn(epoches,get_params,init_rnn_state,rnn,batch_size,num_steps,num_hiddens,mode,theta,
                          device,vocab_indices,length,num_chars,vocab_size,char_indices,char_set,prefixes)

001/250  loss:441.97344 
002/250  loss:310.58600 
003/250  loss:307.20501 
004/250  loss:304.40485 
005/250  loss:301.87610 
006/250  loss:302.05257 
007/250  loss:297.30244 
008/250  loss:294.11009 
009/250  loss:298.31595 
010/250  loss:296.63717 
011/250  loss:297.24218 
012/250  loss:292.55373 
013/250  loss:293.25268 
014/250  loss:296.66945 
015/250  loss:293.98364 
016/250  loss:290.89277 
017/250  loss:289.82135 
018/250  loss:292.16861 
019/250  loss:290.22875 
020/250  loss:290.11592 
021/250  loss:285.70855 
022/250  loss:289.30604 
023/250  loss:290.78913 
024/250  loss:287.22101 
025/250  loss:290.07637 
026/250  loss:286.25785 
027/250  loss:288.45658 
028/250  loss:287.94189 
029/250  loss:287.26532 
epoch 30, perplexity 292.466579
 - 分开苍我甜腿强掉我窝宇抄奏晶问词抱典硬城币乡承
 - 不分开怀唱凝都千灌蜡拳始向者狗寻准只形攻较藏蒙这
030/250  loss:292.46658 
031/250  loss:283.85787 
032/250  loss:287.86840 
033/250  loss:285.96110 
034/250  loss:287.09496 
035/250  loss:285.52128 
036/250  loss:282.87850 
037/250  loss

KeyboardInterrupt: 