In [74]:
import torch 
import torch.nn as nn
import time
device = torch.device('cuda' if torch.cuda.is_available()else 'cpu')
import zipfile
import numpy as np
import random
import math

In [75]:
def load_data_lyrics():
    with zipfile.ZipFile('/home/data/jaychou_lyrics.txt.zip') as zin:
        with zin.open('jaychou_lyrics.txt')as f:
            corpus_chars = f.read().decode('utf-8')
    corpus_chars = corpus_chars.replace('\n',' ').replace('\r',' ')
    corpus_chars = corpus_chars[:10000]
    idx_to_char = list(set(corpus_chars))
    char_to_idx = dict([(char,i)for i,char in enumerate(idx_to_char)])
    vocab_size = len(idx_to_char)
    corpus_indices = [char_to_idx[char] for char in corpus_chars]
    return idx_to_char,char_to_idx,vocab_size,corpus_indices

In [76]:
idx_to_char,char_to_idx,vocab_size,corpus_indices = load_data_lyrics()
vocab_size

1027

In [77]:
def data_iter_random(corpus_indices,batch_size,num_steps,device=None):
    if device==None:
        device = torch.device('cuda' if torch.cuda.is_available()else 'cpu')  
    num_examples = (len(corpus_indices)-1)//num_steps
    epoch_examples = num_examples//batch_size
    example_indices = list(range(num_examples))
    random.shuffle(example_indices)
    def _data(pos):
        return corpus_indices[pos:pos+num_steps]
    for i in range(epoch_examples):
        i = i*batch_size
        batch_indices = example_indices[i:i+batch_size]
        x = [_data(j*num_steps)for j in batch_indices]
        y = [_data(j*num_steps+1)for j in batch_indices]
        yield torch.tensor(x,dtype=torch.float32,device=device),torch.tensor(y,dtype=torch.float32,device=device)

In [78]:
def data_iter_consecutive(corpus_indices,batch_size,num_steps,device=None):
    if device==None:
        device = torch.device('cuda' if torch.cuda.is_available()else 'cpu')  
    corpus_indices = torch.tensor(corpus_indices,dtype=torch.float32,device=device)
    data_len = len(corpus_indices)
    batch_len = data_len//batch_size
    indices = corpus_indices[0:batch_size*batch_len].view(batch_size,batch_len)
    epoch_size = (batch_len-1)//num_steps
    for i in range(epoch_size):
        i = i*num_steps
        X = indices[:,i:i+num_steps]
        Y = indices[:,i+1:i+num_steps+1]
        yield X,Y

In [79]:
idx_to_char,char_to_idx,vocab_size,corpus_indices = load_data_lyrics()
def one_hot(x,n_class,dtype = torch.float32):
    x = x.long()
    res = torch.zeros(x.shape[0],n_class,dtype=dtype,device=x.device)
    res.scatter_(1,x.view(-1,1),1)
    return res

def to_onehot(x,n_class):
    return [one_hot(x[:,i],n_class)for i in range(x.shape[1])]

In [80]:
num_inputs,num_hiddens,num_outputs = vocab_size,256,vocab_size
def get_params():
    def _one(shape):
        ts = torch.tensor(np.random.normal(0,0.01,size=shape),dtype=torch.float32,device=device)
        return nn.Parameter(ts,requires_grad=True)
    w_xh = _one((num_inputs,num_hiddens))
    w_hh = _one((num_hiddens,num_hiddens))
    b_h = nn.Parameter(torch.zeros(num_hiddens,requires_grad=True,device=device,dtype=torch.float32))
    w_qh = _one((num_hiddens,num_outputs))
    b_q = nn.Parameter(torch.zeros(num_outputs,requires_grad=True,device=device,dtype=torch.float32))
    return nn.ParameterList([w_xh,w_hh,b_h,w_qh,b_q])

In [81]:
def init_rnn_states(batch_size,num_hiddens,device):
    return (torch.zeros((batch_size,num_hiddens),device=device))

In [82]:
def rnn(inputs,state,params):
    w_xh,w_hh,b_h,w_qh,b_q = params
    H = state
    outputs =[]
    for X in inputs:
        H = torch.matmul(X,w_xh) + torch.matmul(H,w_hh) + b_h
        H = torch.tanh(H)
        y_hat = torch.matmul(H,w_qh) + b_q
        outputs.append(y_hat)
    return outputs,H 

In [83]:
x = torch.arange(10).view(2,5)
state = init_rnn_states(x.shape[0],num_hiddens,device)
inputs = to_onehot(x.to(device),vocab_size)
params = get_params()
outputs,state_new = rnn(inputs,state,params)
# print(len(outputs),outputs[0].shape,state_new[0].shape)
print(outputs)
print(outputs[0])

[tensor([[-0.0027,  0.0022, -0.0005,  ...,  0.0016, -0.0003,  0.0020],
        [-0.0008, -0.0002, -0.0007,  ..., -0.0017,  0.0004, -0.0007]],
       device='cuda:0', grad_fn=<AddBackward0>), tensor([[-7.4720e-04, -1.2215e-03, -3.5665e-04,  ..., -6.8670e-04,
         -7.5191e-05,  7.1814e-04],
        [ 6.2710e-04,  7.5085e-04,  7.6765e-04,  ..., -4.9642e-04,
         -2.5796e-03,  1.0057e-03]], device='cuda:0', grad_fn=<AddBackward0>), tensor([[ 1.1174e-03, -9.3348e-05,  8.4706e-04,  ..., -5.8103e-04,
          4.9575e-04,  1.4321e-03],
        [ 3.5850e-04,  3.9168e-04, -2.2048e-04,  ...,  1.7791e-03,
         -1.2169e-03,  2.2792e-03]], device='cuda:0', grad_fn=<AddBackward0>), tensor([[-2.2310e-03, -3.9051e-04, -1.0407e-04,  ...,  3.1057e-03,
         -7.4045e-04, -4.3612e-05],
        [-4.1038e-03, -9.8064e-04,  1.7955e-04,  ...,  1.1476e-03,
         -3.8410e-03,  4.6311e-04]], device='cuda:0', grad_fn=<AddBackward0>), tensor([[-9.7862e-05,  2.4274e-03,  2.7486e-04,  ..., -2.2465e

In [84]:
def pred_rnn(prefix,num_chars,rnn,params,init_rnn_states,num_hiddens,vocab_size,device,idx_to_char,char_to_idx):
    state = init_rnn_states(1,num_hiddens,device)
    outputs = [char_to_idx[prefix[0]]]
    for t in range(num_chars+len(prefix)+1):
        X = to_onehot(torch.tensor([[outputs[-1]]],device=device),vocab_size)
        Y,state = rnn(X,state,params)
        if t <len(prefix)-1:
            outputs.append(char_to_idx[prefix[t+1]])
        else:
            outputs.append(int(Y[0].argmax(dim=1).item()))
    return ''.join([idx_to_char[i]for i in outputs])

In [85]:
prefix = '爱情'
pred_rnn('爱情',10,rnn,params,init_rnn_states,num_hiddens,vocab_size,device,idx_to_char,char_to_idx)

'爱情被心婆毫闷原胸翰秒杵羞偷'

In [86]:
def clip_gradient(params,theta,device):
    norm = torch.tensor([0.0],device=device)
    for param in params:
        norm += (param.grad.data**2).sum()
    norm = norm.sqrt().item()
    if norm>theta:
        for param in params:
            param.grad.data *= (theta/norm)

In [87]:
def sgd(params,lr,batch_size):
    for param in params:
        param.data -= (lr*param.grad)/batch_size

In [88]:
def train_and_pred_rnn(rnn,get_params,init_rnn_states,num_hiddens,vocab_size,device,corpus_indices,
                        char_to_idx,idx_to_char,is_random_iter,num_epochs,num_steps,lr,clipping_theta,batch_size,
                        pred_period,pred_len,prefixes):
    if is_random_iter :
        data_iter_fn = data_iter_random
    else:
        data_iter_fn = data_iter_consecutive
    params = get_params()
    loss = nn.CrossEntropyLoss()
    start = time.time()
    for epoch in range(num_epochs):
        if not is_random_iter:
            state = init_rnn_states(batch_size,num_hiddens,device)
        l_sum = 0.0
        n = 0
        data_iter = data_iter_fn(corpus_indices,batch_size,num_steps)
        for X,y in data_iter:
            if is_random_iter:
                state = init_rnn_states(batch_size,num_hiddens,device)
            else:
                for s in state:
                    s.detach()
            inputs = to_onehot(X,vocab_size)
            (y_hat,state) = rnn(inputs,state,params)
            y_hat = torch.cat(y_hat,dim=0)
            y = torch.transpose(y,0,1).contiguous().view(-1)
            l = loss(y_hat,y.long())
            if params[0].grad is not None:
                for param in params:
                    param.grad.data.zero_()
            l.backward()
            clip_gradient(params,clipping_theta,device)
            sgd(params,lr,1)
            l_sum += l.item()*y.shape[0]
            n += y.shape[0]
        if (epoch+1)%pred_period == 0:
            print('epochs:%d,perplexity:%f,time:%.f sec'%(epoch+1,math.exp(l_sum/n),time.time()-start))
            start = time.time()
            for prefix in prefixes:
                print('-',pred_rnn(prefix,pred_len,rnn,params,init_rnn_states,num_hiddens,vocab_size,device,idx_to_char,char_to_idx))


In [89]:
batch_size = 35
num_steps = 32
lr = 1e2
clipping_theta = 1e-2
num_epochs =250
pred_period = 50
pred_len = 50
prefixes = ['爱情','我爱你']
train_and_pred_rnn(rnn,get_params,init_rnn_states,num_hiddens,vocab_size,device,corpus_indices,char_to_idx,idx_to_char,True,num_epochs,num_steps,lr,clipping_theta,batch_size,pred_period,pred_len,prefixes)





epochs:50,perplexity:68.786968,time:6 sec
- 爱情 我想要你想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要
- 我爱你的可爱女人 坏坏的让我疯狂的可爱女人 坏坏的让我疯狂的可爱女人 坏坏的让我疯狂的可爱女人 坏坏的让我疯狂
epochs:100,perplexity:10.783561,time:6 sec
- 爱情 我不要 你怎么 太九就颗三步四 干什么 干什么 一九四步三步四步望著天 看著星 一颗两颗三步四步 连成
- 我爱你 你爱我 开我怎么我想我 别发抖 干什么 一步四步三步四步望著天 看著星 一颗两颗三步四步 连成线背著背
epochs:150,perplexity:3.206898,time:5 sec
- 爱情 我想 我不要 想情我的太快就像 透过兵器最喜欢 双截棍柔中带刚 想要去童南嵩事 学要 有话段考倒着你打
- 我爱你的可爱 我想 你想要再不要 静涯尽透 你已一定热粥 我爱好好生活 不知不觉 你已经离开我 不知不觉 我跟
epochs:200,perplexity:1.668156,time:6 sec
- 爱情人妈 怎么我著天怪到 看来像没了你的 我说上这样牵着著 这是我都做得到 但那个人已经不是去 想著和你融化
- 我爱你 你爱我 开不了口 周杰伦 才离开没是久就开始 担心今宇的你去也起  相袋你说别单 所有 你想再考了我 
epochs:250,perplexity:1.375738,time:6 sec
- 爱情人定发过 一步承 干什么 呼吸吐这系自袋 干什么 干什么 气沉吐纳心自在 干什么 干什么 气沉病纳系自袋
- 我爱你太错 我想好这 牵小心外的溪边 默默等待 娘子 娘子 娘 再亮了 别给我抬起头 一话去对医药箱说 别怪我
