In [1]:
import time
import math
import numpy as np
import torch
from torch import nn,optim
import torch.nn.functional as F
import random
import zipfile
device='cpu'

In [2]:
def load_data_jay_lyrics():
    with zipfile.ZipFile(r'F:\study\ml\ebooks3\6\jaychou_lyrics.txt.zip') as zin:
        with zin.open('jaychou_lyrics.txt') as f:
            corpus_chars=f.read().decode('utf-8')
    corpus_chars=corpus_chars.replace('\n',' ').replace('\r',' ')
    corpus_chars=corpus_chars[0:10000]
    idx_to_char=list(set(corpus_chars))
    char_to_idx=dict([( char,i ) for i , char in enumerate(idx_to_char)])
    vocab_size=len(char_to_idx)
    corpus_indices=[char_to_idx[char] for char in corpus_chars]
    return corpus_indices, char_to_idx, idx_to_char, vocab_size

In [3]:
corpus_indices, char_to_idx, idx_to_char, vocab_size=load_data_jay_lyrics()

In [5]:
num_hiddens=256
rnn_layer=nn.RNN(input_size=vocab_size,hidden_size=num_hiddens)
rnn_layer

RNN(1027, 256)

In [9]:
num_steps=35
batch_size=2
state=None
X=torch.rand(num_steps,batch_size,vocab_size)
Y,state_new=rnn_layer(X,state)
print('X shape: ',X.shape)
print('Y.shape: ',Y.shape,len(state_new),state_new[0].shape)

X shape:  torch.Size([35, 2, 1027])
Y.shape:  torch.Size([35, 2, 256]) 1 torch.Size([2, 256])


In [10]:
num_hiddens=256
rnn_layer2=nn.RNN(input_size=vocab_size,hidden_size=num_hiddens,num_layers=2)
rnn_layer2

RNN(1027, 256, num_layers=2)

In [11]:
num_steps=35
batch_size=2
state=None
X=torch.rand(num_steps,batch_size,vocab_size)
Y,state_new=rnn_layer2(X,state)
print('X shape: ',X.shape)
print('Y.shape: ',Y.shape,len(state_new),state_new[0].shape)

X shape:  torch.Size([35, 2, 1027])
Y.shape:  torch.Size([35, 2, 256]) 2 torch.Size([2, 256])


In [25]:
def one_hot(x,n_class):
    x=x.long()
    res=torch.zeros(x.shape[0],n_class)
    res.scatter_(1,x.view(-1,1),1)
    return res

In [26]:
def to_onehot(x,n_class):
    return [one_hot(x[:,i],n_class) for i in range(x.shape[1])]

In [27]:
x=torch.arange(10).view(2,5)
to_onehot(x,10)

[tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]]),
 tensor([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]]),
 tensor([[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]]),
 tensor([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]]),
 tensor([[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])]

In [48]:
class RNNModel(nn.Module):
    def __init__(self,rnn_layer,vocab_size):
        super().__init__()
        self.rnn=rnn_layer
        self.hidden_size=rnn_layer.hidden_size*(2 if rnn_layer.bidirectional else 1)
        self.vocab_size=vocab_size
        self.dense=nn.Linear(self.hidden_size,vocab_size)
        self.state=None
        
    def forward(self,inputs,state):
        X=to_onehot(inputs,self.vocab_size)
        Y,self.state=self.rnn(torch.stack(X),state)
        print('Y shape : ',Y.shape)
        print('Y view shape : ',Y.view(-1,Y.shape[-1]).shape)
        output=self.dense(Y.view(-1,Y.shape[-1]))
        print('output shape : ',output.shape)
        return output,self.state

In [49]:
def predict_rnn_pytorch(prefix,num_chars,model,vocab_size,device,idx_to_char,char_to_idx):
    state=None
    output=[char_to_idx[prefix[0]]]
    for t in range(num_chars+len(prefix)-1):
        X=torch.Tensor([output[-1]]).view(1,1)
        if state is not None:
            if isinstance(state ,tuple):
                state=(state[0].to(device),state[1].to(device))
            else:
                state=state.to(device)
        (Y,state)=model(X,state)
        if t < len(prefix)-1:
            output.append(char_to_idx[prefix[t+1]])
        else:
            output.append(int(Y.argmax(dim=1).item()))
    return ''.join([idx_to_char[i] for i in output])

In [50]:
model=RNNModel(rnn_layer2,vocab_size)
predict_rnn_pytorch('分开',10,model,vocab_size,device,idx_to_char,char_to_idx)

Y shape :  torch.Size([1, 1, 256])
Y view shape :  torch.Size([1, 256])
output shape :  torch.Size([1, 1027])
Y shape :  torch.Size([1, 1, 256])
Y view shape :  torch.Size([1, 256])
output shape :  torch.Size([1, 1027])
Y shape :  torch.Size([1, 1, 256])
Y view shape :  torch.Size([1, 256])
output shape :  torch.Size([1, 1027])
Y shape :  torch.Size([1, 1, 256])
Y view shape :  torch.Size([1, 256])
output shape :  torch.Size([1, 1027])
Y shape :  torch.Size([1, 1, 256])
Y view shape :  torch.Size([1, 256])
output shape :  torch.Size([1, 1027])
Y shape :  torch.Size([1, 1, 256])
Y view shape :  torch.Size([1, 256])
output shape :  torch.Size([1, 1027])
Y shape :  torch.Size([1, 1, 256])
Y view shape :  torch.Size([1, 256])
output shape :  torch.Size([1, 1027])
Y shape :  torch.Size([1, 1, 256])
Y view shape :  torch.Size([1, 256])
output shape :  torch.Size([1, 1027])
Y shape :  torch.Size([1, 1, 256])
Y view shape :  torch.Size([1, 256])
output shape :  torch.Size([1, 1027])
Y shape : 

'分开蔓联自联联自联联自联'