In [1]:
import time
import math
import numpy as np
import torch
from torch import nn,optim
import torch.nn.functional as F
import sys
import zipfile

In [2]:
device='cpu'

In [3]:
def load_data_jay_lyrics(): 
    with zipfile.ZipFile(r'F:\study\ml\ebooks3\6\jaychou_lyrics.txt.zip') as zif:
        with zif.open('jaychou_lyrics.txt') as f:
            corpus_chars =f.read().decode('utf-8')
    corpus_chars=corpus_chars.replace('\n',' ').replace('\r',' ')
    corpus_chars=corpus_chars[0:10000]
    
    idx_to_char=list(set(corpus_chars))
    char_to_idx=dict([(char,i) for i,char in enumerate(idx_to_char)])
    vocab_size=len(char_to_idx)
    corpus_indices=[char_to_idx[char] for char in corpus_chars]
    
    return corpus_indices,char_to_idx,idx_to_char,vocab_size

In [4]:
corpus_indices,char_to_idx,idx_to_char,vocab_size=load_data_jay_lyrics()

In [5]:
num_inputs,num_hiddens,num_outputs=vocab_size,256,vocab_size

In [63]:
def get_params():
    def _one(shape):
        ts=torch.Tensor(np.random.normal(0,0.01,size=shape))
        return torch.nn.Parameter(ts,requires_grad=True)
    def _three():
        return(_one((num_inputs,num_hiddens)),
               _one((num_hiddens,num_hiddens)),
               torch.nn.Parameter(torch.zeros(num_hiddens,dtype=torch.float32),requires_grad=True)
        )
    
    W_xz,W_hz,b_z=_three()
    W_xr,W_hr,b_r=_three()
    W_xh,W_hh,b_h=_three()
    
    W_hq=_one((num_hiddens,num_outputs))
    b_q=torch.nn.Parameter(torch.zeros(num_outputs,dtype=torch.float32),requires_grad=True)
    
    return nn.ParameterList([W_xz,W_hz,b_z,W_xr,W_hr,b_r,W_xh,W_hh,b_h,W_hq,b_q])

In [118]:
def grad_clipping(params,theta,device):
    norm=torch.Tensor([0.0])
    for p in params:
        norm +=(p.grad.data **2).sum()
    norm = np.sqrt(norm.item())
    if norm > theta:
        for p in params:
            p.grad.data *= (theta/norm)

In [119]:
def one_hot(x,n_class,dtype=torch.float32):
    x=x.long()
    res=torch.zeros(x.shape[0],n_class,dtype=dtype,device=x.device)
    res.scatter_(1,x.view(-1,1),1)
    return res

In [120]:
def init_rnn_state(batch_size,num_hiddens,device):
    return (torch.zeros((batch_size,num_hiddens)),)

In [121]:
def to_onehot(x,n_class):
    return [one_hot(x[:,i],n_class) for i in range(x.shape[1])]

In [122]:
def predict_rnn_pytorch(prefix,num_chars,model,vocab_size,device,idx_to_char,char_to_idx):
    state=None
    output=[char_to_idx[prefix[0]]]
    for t in range(num_chars+len(prefix)-1):
        X=torch.tensor([output[-1]],device=device).view(-1,1)
        if state is not None:
            if isinstance(state,tuple):
                state=(state[0].to(device),state[1].to(device))
            else:
                state=state.to(device)
        (Y,state)=model(X,state)
        if t<len(prefix) -1:
            output.append(char_to_idx[prefix[t+1]])
        else:
            output.append(int(Y.argmax(dim=1).item()))
    return ''.join([idx_to_char[i] for i in output])

In [123]:
def gru(inputs,state,params):
    W_xz,W_hz,b_z,W_xr,W_hr,b_r,W_xh,W_hh,b_h,W_hq,b_q=params
#     print('state len :',len(state))
    H,=state
    outputs=[]
    for X in inputs:
        Z=torch.sigmoid(torch.matmul(X,W_xz)+torch.matmul(H,W_hz)+b_z)
        R=torch.sigmoid(torch.matmul(X,W_xr)+torch.matmul(H,W_hr)+b_r)
        H_tilda=torch.tanh(torch.matmul(X,W_xh)+R*torch.matmul(H,W_hh)+b_h)
        H=Z*H+(1-Z)*H_tilda
        Y=torch.matmul(H,W_hq)+b_q
        outputs.append(Y)
    return outputs,(H,)

In [124]:
def data_iter_consecutive(corpus_indices,batch_size,num_steps,device=None):
    if device is None:
        device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    corpus_indices=torch.tensor(corpus_indices,dtype=torch.float32,device=device)
    data_len=len(corpus_indices)
    batch_len=data_len//batch_size
    indices=corpus_indices[0:batch_size*batch_len].view(batch_size,batch_len)
    epoch_size=(batch_len-1) // num_steps
    for i in range(epoch_size):
        i=i+num_steps
        X=indices[:,i:i+num_steps]
        Y=indices[:,i+1:i+num_steps+1]
        yield X,Y

In [125]:
def predict_rnn(prefix,num_chars,rnn,params,init_rnn_state,num_hiddens,vocab_size,device,idx_to_char,char_to_idx):
    state=init_rnn_state(1,num_hiddens,device)
    output=[char_to_idx[prefix[0]]]
    for t in range(num_chars+len(prefix) -1):
        print('before X :',torch.tensor([[output[-1]]],device=device))
        X= to_onehot(torch.tensor([[output[-1]]],device=device),vocab_size)
        print('X Input shape : ',len(X))
        (Y,state)=rnn(X,state,params)
        print('Y shape :',len(Y))
        print('Y :',Y)
        if t<len(prefix) -1:
            output.append(char_to_idx[prefix[t+1]])
        else:
            output.append(int(Y[0].argmax(dim=1).item()))
    return ''.join([idx_to_char[i] for i in output])

In [129]:
def sgd(params,lr,batch_size):
    for param in params:
        param.data -= lr * param.grad / batch_size

In [136]:
def train_and_predict_rnn(rnn, get_params, init_run_state, num_hiddens,
                          vocab_size,device, corpus_indices, idx_to_char, 
                          char_to_idx,
                          is_random_iter, num_epochs,num_steps, lr, clipping_theta,
                          batch_size, pred_period, pred_len, prefixes):
    if is_random_iter:
        data_iter_fn = data_iter_random
    else:
        data_iter_fn = data_iter_consecutive

    params = get_params()
    loss = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        if not is_random_iter:
            state = init_rnn_state(batch_size, num_hiddens,device)
        l_sum, n, start = 0.0, 0, time.time()
        data_iter = data_iter_fn(corpus_indices, batch_size, num_steps)
        for X, Y in data_iter:
            if is_random_iter:
                state = init_rnn_state(batch_size, num_hiddens,device)
            else:
                for s in state:
                    s.detach()
#             print('----')
#             print('X : ',X)
#             print('state shape',state.shape)
            print('X shape : ',X.shape)
            inputs = to_onehot(X, vocab_size)
#             print('state len :',len(state))
#             print('state[0] len :',state[0].shape)
            (outputs, state) = rnn(inputs, state, params)

            #print()
            outputs = torch.cat(outputs, dim=0)
            y = torch.transpose(Y, 0, 1).contiguous().view(-1)
            l = loss(outputs, y.long())

            if params[0].grad is not None:
                for param in params:
                    param.grad.data.zero_()
            l.backward(retain_graph=True)
            grad_clipping(params, clipping_theta,device)
            sgd(params, lr, 1)
            l_sum += l.item() * y.shape[0]
            n += y.shape[0]

        if (epoch + 1) % pred_period == 0:
            print('epoch %d,perplexity %f,time %.2f sec' %
                  (epoch + 1, math.exp(l_sum / n), time.time() - start))
            for prefix in prefixes:
                print(
                    ' -',
                    predict_rnn(prefix, pred_len, rnn, params, init_run_state,
                                num_hiddens, vocab_size, device ,idx_to_char,
                                char_to_idx))

In [137]:
num_epochs, num_steps, batch_size, lr, clipping_theta = 160, 35, 32,1e2, 1e-2
pred_period, pred_len, prefixes = 40, 50, ['分开', '不分开']

In [None]:
train_and_predict_rnn(gru, get_params, init_gru_state, num_hiddens, vocab_size,
                      device, corpus_indices, idx_to_char, char_to_idx, False,
                      num_epochs, num_steps, lr, clipping_theta, batch_size,
                      pred_period, pred_len, prefixes)

X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape 

X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape 

Y shape : 1
Y : [tensor([[ 7.5125,  0.1526, -1.1135,  ..., -1.1291,  0.7372,  1.4731]],
       grad_fn=<AddBackward0>)]
before X : tensor([[0]])
X Input shape :  1
Y shape : 1
Y : [tensor([[ 3.4144,  1.0695, -1.1182,  ..., -1.1066,  3.9451,  4.2275]],
       grad_fn=<AddBackward0>)]
before X : tensor([[480]])
X Input shape :  1
Y shape : 1
Y : [tensor([[ 1.5381,  2.1408, -0.8844,  ..., -1.1797,  6.0400,  4.0983]],
       grad_fn=<AddBackward0>)]
before X : tensor([[630]])
X Input shape :  1
Y shape : 1
Y : [tensor([[ 1.8057,  1.9882, -0.6526,  ..., -1.0898,  4.9740,  1.8881]],
       grad_fn=<AddBackward0>)]
before X : tensor([[860]])
X Input shape :  1
Y shape : 1
Y : [tensor([[ 7.5125,  0.1526, -1.1135,  ..., -1.1291,  0.7372,  1.4731]],
       grad_fn=<AddBackward0>)]
before X : tensor([[0]])
X Input shape :  1
Y shape : 1
Y : [tensor([[ 3.4144,  1.0695, -1.1182,  ..., -1.1066,  3.9451,  4.2275]],
       grad_fn=<AddBackward0>)]
before X : tensor([[480]])
X Input shape :  1
Y shape 

X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape 

Y shape : 1
Y : [tensor([[-0.7975,  1.0893, -2.0349,  ..., -1.4300,  6.7310, 11.5500]],
       grad_fn=<AddBackward0>)]
before X : tensor([[1026]])
X Input shape :  1
Y shape : 1
Y : [tensor([[-0.6759,  0.6415, -1.9276,  ..., -0.9631,  5.4665,  7.8741]],
       grad_fn=<AddBackward0>)]
before X : tensor([[671]])
X Input shape :  1
Y shape : 1
Y : [tensor([[ 1.2770,  4.3164, -1.1997,  ..., -1.4521,  2.6009,  3.4997]],
       grad_fn=<AddBackward0>)]
before X : tensor([[860]])
X Input shape :  1
Y shape : 1
Y : [tensor([[10.9208, -0.9031, -2.1402,  ..., -1.5806, -3.1350,  1.3148]],
       grad_fn=<AddBackward0>)]
before X : tensor([[72]])
X Input shape :  1
Y shape : 1
Y : [tensor([[ 4.3777,  0.1287, -1.8735,  ..., -1.9394,  4.0720,  6.8025]],
       grad_fn=<AddBackward0>)]
before X : tensor([[911]])
X Input shape :  1
Y shape : 1
Y : [tensor([[-1.1001,  2.9581, -2.3975,  ..., -1.8495, 11.5425,  7.9919]],
       grad_fn=<AddBackward0>)]
before X : tensor([[581]])
X Input shape :  1
Y sh

X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape 

X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape :  torch.Size([32, 35])
X shape 