In [83]:
import torch
import torch.nn as nn
import zipfile
import random
# Pick the compute device once for the whole notebook: GPU when CUDA is
# available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
import numpy as np


In [84]:
def load_data_lyrics(zip_path='/home/data/jaychou_lyrics.txt.zip',
                     member='jaychou_lyrics.txt',
                     max_chars=10000):
    """Load the lyrics corpus and build a character-level vocabulary.

    Args:
        zip_path: Path of the zip archive holding the corpus.
        member: Name of the UTF-8 text file inside the archive.
        max_chars: Keep only the first ``max_chars`` characters
            (``None`` keeps the whole corpus).

    Returns:
        A 4-tuple, in this order:
          1. list of the distinct characters (position == index),
          2. dict mapping each character to its index,
          3. vocabulary size,
          4. list with the index of every character of the corpus.

        Note: the notebook unpacks items 1 and 2 into variables called
        ``char_to_idx`` and ``idx_to_char`` respectively, i.e. the names
        are swapped relative to what the objects contain. The return order
        is kept unchanged so existing callers continue to work.
    """
    with zipfile.ZipFile(zip_path) as zin:
        with zin.open(member) as f:
            chorpus_chars = f.read().decode('utf-8')
    # Treat line breaks as plain spaces so the corpus is one long sequence.
    chorpus_chars = chorpus_chars.replace('\n', ' ').replace('\r', ' ')
    chorpus_chars = chorpus_chars[:max_chars]

    # sorted() gives a reproducible index assignment; the iteration order of
    # a bare set of strings changes from one interpreter run to the next.
    vocab_chars = sorted(set(chorpus_chars))
    char_index = {char: i for i, char in enumerate(vocab_chars)}
    vocab_size = len(vocab_chars)

    # Must be a real list: the batch sampler calls len() on it and slices it,
    # neither of which a generator expression supports.
    chorpus_indices = [char_index[char] for char in chorpus_chars]
    return vocab_chars, char_index, vocab_size, chorpus_indices
    

In [85]:
def data_iter_random(chorpus_indices,batch_sizes,num_steps,device = None):
    """Yield random mini-batches of (input, target) index sequences.

    The corpus is cut into consecutive, non-overlapping windows of
    ``num_steps`` items; the windows are shuffled and grouped into batches.
    The target of each window is the same window shifted right by one item.

    Args:
        chorpus_indices: Sequence (or any iterable) of token indices.
        batch_sizes: Number of windows per yielded batch.
        num_steps: Length of every window.
        device: Device for the yielded tensors; defaults to CUDA when
            available, otherwise CPU.

    Yields:
        ``(x, y)`` float32 tensors of shape ``(batch_sizes, num_steps)``.
        Windows that do not fill a complete batch are dropped.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Generators / iterators support neither len() nor slicing; materialise
    # them so the sampler works on any iterable of indices.
    if not (hasattr(chorpus_indices, '__len__')
            and hasattr(chorpus_indices, '__getitem__')):
        chorpus_indices = list(chorpus_indices)

    # "-1" because the target window needs one extra item past the input.
    num_examples = (len(chorpus_indices) - 1) // num_steps
    num_batches = num_examples // batch_sizes
    example_indices = list(range(num_examples))
    random.shuffle(example_indices)
    # NOTE(review): debug print of the shuffled window order; kept so the
    # existing output is unchanged - consider removing it.
    print('num_indices', example_indices)

    def _window(pos):
        """Return the ``num_steps`` items starting at position ``pos``."""
        return chorpus_indices[pos:pos + num_steps]

    for batch_no in range(num_batches):
        first = batch_no * batch_sizes
        batch_indices = example_indices[first:first + batch_sizes]
        x = [_window(j * num_steps) for j in batch_indices]
        y = [_window(j * num_steps + 1) for j in batch_indices]
        yield (torch.tensor(x, dtype=torch.float32, device=device),
               torch.tensor(y, dtype=torch.float32, device=device))

In [86]:
# Sanity check of the random sampler on a toy corpus 0..29.
# Distinct names are used for the corpus and the loop targets so the loop
# does not overwrite the input list with a tensor.
demo_corpus = list(range(30))
print(demo_corpus)
for x_batch, y_batch in data_iter_random(demo_corpus, 2, 6):
    print(x_batch, y_batch)

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
num_indices [2, 0, 1, 3]
tensor([[12., 13., 14., 15., 16., 17.],
        [ 0.,  1.,  2.,  3.,  4.,  5.]]) tensor([[13., 14., 15., 16., 17., 18.],
        [ 1.,  2.,  3.,  4.,  5.,  6.]])
tensor([[ 6.,  7.,  8.,  9., 10., 11.],
        [18., 19., 20., 21., 22., 23.]]) tensor([[ 7.,  8.,  9., 10., 11., 12.],
        [19., 20., 21., 22., 23., 24.]])


In [87]:
# NOTE: the first two names are swapped relative to their contents:
# `char_to_idx` receives the list of vocabulary characters (index -> char)
# and `idx_to_char` receives the dict mapping each character to its index.
char_to_idx,idx_to_char,vocab_size,chorpus_indices = load_data_lyrics()
def one_hot(x, n_class, dtype=torch.float32):
    """One-hot encode a tensor of class indices.

    Args:
        x: Tensor of class indices; expected to be 1-D (one index per row
            of the result). It is cast to ``long`` before use.
        n_class: Width of the encoding (number of classes).
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``(x.shape[0], n_class)`` on ``x``'s device with a 1
        at each row's index position and 0 elsewhere.
    """
    indices = x.long()
    n_rows = indices.shape[0]
    encoded = torch.zeros(n_rows, n_class, dtype=dtype, device=indices.device)
    # scatter_ writes in place and hands back the same tensor.
    return encoded.scatter_(1, indices.view(-1, 1), 1)

In [88]:
def to_hot(x, n_class):
    """Turn a ``(batch, steps)`` index tensor into a list of one-hot tensors.

    Column ``t`` of ``x`` is encoded with ``one_hot`` and becomes element
    ``t`` of the returned list, so the list has ``x.shape[1]`` entries.
    """
    per_step = []
    for t in range(x.shape[1]):
        per_step.append(one_hot(x[:, t], n_class))
    return per_step

In [89]:
# Layer sizes: inputs and outputs are both vocabulary-sized, with 256
# hidden units in between.
num_hiddens = 256
num_inputs = num_outputs = vocab_size
print('train on:', device)
def get_param():
    """Create the trainable parameters of the RNN.

    Sizes come from the notebook-level ``num_inputs``, ``num_hiddens`` and
    ``num_outputs``; tensors are placed on the notebook-level ``device``.

    Returns:
        ``nn.ParameterList`` ordered as ``[w_xh, w_hh, b_h, w_hq, b_q]``.
        Weights are drawn from a normal distribution (mean 0, std 0.01);
        biases start at zero.
    """
    def _gaussian(shape):
        """Weight of the given shape sampled from N(0, 0.01)."""
        values = np.random.normal(0, 0.01, size=shape)
        tensor = torch.tensor(values, dtype=torch.float32, device=device)
        return nn.Parameter(tensor, requires_grad=True)

    def _zero_bias(size):
        """Zero-initialised bias vector of length ``size``."""
        bias = torch.zeros(size, requires_grad=True, device=device, dtype=torch.float32)
        return nn.Parameter(bias)

    # Hidden layer: input-to-hidden and hidden-to-hidden weights plus bias.
    w_xh = _gaussian((num_inputs, num_hiddens))
    w_hh = _gaussian((num_hiddens, num_hiddens))
    b_h = _zero_bias(num_hiddens)
    # Output layer: hidden-to-output weights plus bias.
    w_hq = _gaussian((num_hiddens, num_outputs))
    b_q = _zero_bias(num_outputs)
    return nn.ParameterList([w_xh, w_hh, b_h, w_hq, b_q])

train on: cpu


初始化隐藏状态

In [119]:
def init_rnn_states(batch_size, num_hiddens, device):
    """Return the initial hidden state: a zero tensor of shape
    ``(batch_size, num_hiddens)`` on ``device``.

    The state is a plain tensor, not a tuple.
    """
    state_shape = (batch_size, num_hiddens)
    initial_state = torch.zeros(state_shape, device=device)
    return initial_state

In [138]:
def rnn(inputs, state, params):
    """Run a tanh RNN over a sequence of per-step inputs.

    Args:
        inputs: Iterable of per-time-step input tensors.
        state: Hidden state tensor carried into the first step.
        params: The five parameters ``(w_xh, w_hh, b_h, w_hq, b_q)``.

    Returns:
        ``(outputs, hidden)`` where ``outputs`` is a list holding one output
        tensor per time step and ``hidden`` is the state after the last step.
    """
    w_xh, w_hh, b_h, w_hq, b_q = params
    hidden = state
    step_outputs = []
    for step_input in inputs:
        # New hidden state from the current input and the previous state.
        pre_activation = step_input @ w_xh + hidden @ w_hh + b_h
        hidden = torch.tanh(pre_activation)
        # Project the hidden state to the output space.
        step_outputs.append(hidden @ w_hq + b_q)
    return step_outputs, hidden

In [125]:
# Smoke test: push a (2, 5) batch of indices through the untrained RNN and
# check the shapes of the per-step outputs and of the final hidden state.
X = torch.arange(10).view(2,5)
state = init_rnn_states(X.shape[0],num_hiddens,device)
inputs = to_hot(X.to(device),vocab_size)
params = get_param()
outputs,state_new = rnn(inputs,state,params)
# The state is a plain tensor, so report its full shape; indexing it with
# [0] would select the first row and print only (num_hiddens,).
print(len(outputs),outputs[0].shape,state_new.shape)

5 torch.Size([2, 1027]) torch.Size([2, 256])


In [141]:
def predict_rnn(prefix,num_chars,rnn,params,init_rnn_states,num_hiddens,vocab_size,device,idx_to_char,char_to_idx):
    """Generate text greedily, starting from ``prefix``.

    The prefix is fed through the network one character at a time to warm up
    the hidden state; after that, the arg-max prediction of each step is fed
    back as the next input.

    Args:
        prefix: Non-empty string to start from; every character must be a key
            of ``idx_to_char``.
        num_chars: Number of characters to generate after the prefix.
        rnn: Callable ``rnn(inputs, state, params) -> (outputs, state)``.
        params: Model parameters forwarded to ``rnn``.
        init_rnn_states: Callable ``(batch_size, num_hiddens, device)``
            returning the initial hidden state.
        num_hiddens: Hidden-state width forwarded to ``init_rnn_states``.
        vocab_size: Width of the one-hot input encoding.
        device: Device on which the input tensors are created.
        idx_to_char: Mapping used here as *character -> index* (the name is
            swapped relative to its use; kept for caller compatibility).
        char_to_idx: Sequence used here as *index -> character* (likewise
            swapped).

    Returns:
        The prefix followed by ``num_chars`` generated characters.
    """
    state = init_rnn_states(1, num_hiddens, device)
    output = [idx_to_char[prefix[0]]]
    # Inference only: without no_grad every step would extend the autograd
    # graph through `state`, wasting memory for gradients that are never used.
    with torch.no_grad():
        for t in range(num_chars + len(prefix) - 1):
            # Feed the most recent character as a (1, 1) batch.
            X = to_hot(torch.tensor([[output[-1]]], device=device), vocab_size)
            Y, state = rnn(X, state, params)
            if t < len(prefix) - 1:
                # Still consuming the prefix: force the next known character.
                output.append(idx_to_char[prefix[t + 1]])
            else:
                # Past the prefix: take the most likely next character.
                output.append(int(Y[0].argmax(dim=1).item()))
    return ''.join(char_to_idx[i] for i in output)
     

In [144]:
# Generate 10 characters after the prefix with the untrained parameters.
predict_rnn(
    '爱情',
    10,
    rnn,
    params,
    init_rnn_states,
    num_hiddens,
    vocab_size,
    device,
    idx_to_char,
    char_to_idx,
)

0
491
1
2
3
4
5
6
7
8
9
10


'爱情故朵前爸步们便加u茶'