In [33]:
import torch
import torch.nn as nn

class SimpleRNN(nn.Module):
    
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleRNN, self).__init__()
        self.hidden_size = hidden_size
        # RNN layer
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        # Fully connected layer
        self.fc = nn.Sequential(
            nn.Linear(hidden_size, output_size)
            )
        self.embedding = nn.Embedding(output_size, input_size)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        # x [batch_size, seq_len]
        x = self.embedding(x) # batch seq_len 
        # Initialize hidden state with zeros
        h0 = torch.zeros(1, x.size(0), self.hidden_size).to(x.device)

        # RNN layer
        out, _ = self.rnn(x, h0)

        # Only take the output from the last time step
#         out = out[:, -1, :]

        # Fully connected layer
        out = self.fc(out)
        out = self.sigmoid(out)

        return out

In [34]:
from data_process import *

In [35]:
xx,yy=get_data(10)

In [36]:
xx

tensor([[ 66,  29,  27,  98,  59,  19,   6,  12,  46, 100],
        [ 74,   7,  15,  16,  45,  20,  40,  71, 118,  77],
        [ 10,  55,  68,  14,  70,  85,  36,  16,  59,  23],
        [  1,  59,  53,  31, 110,  81,  16, 106,  88,  90],
        [121,  16,  64,  39,  99, 105,  74, 113, 120,   2],
        [ 95,   7,  15,  75,   1,  59,  53,  31, 110,  81],
        [ 37,  59,  41,   9,  48,  22, 104,  43,   5,  83],
        [ 76, 120,   1, 102,  16,  79,  37,  59,  41,   9],
        [113, 131,  74,   7,  15,  16,  45,  20,  40,  71],
        [ 44, 118,  62, 116,  59,  13, 121,  14,  93, 127]])

In [37]:
xx.shape

torch.Size([10, 10])

In [38]:
embedding_size = 200

In [40]:
rnn = SimpleRNN(embedding_size,200,vocab_size)

In [41]:
probs = rnn(xx)

In [42]:
probs.shape

torch.Size([10, 10, 132])

In [43]:
yy.shape

torch.Size([10, 10])

In [44]:
criterion = nn.CrossEntropyLoss()

In [45]:
criterion(probs.view(-1,vocab_size),yy.view(-1))

tensor(4.8941, grad_fn=<NllLossBackward0>)

In [46]:
batch_n = 30

In [47]:
op = optim.SGD(rnn.parameters(),lr=0.001)

In [52]:
for i in range(100000):
    xx,yy = get_data(batch_n)
    probs = rnn(xx)
    loss =  criterion(probs.view(-1,vocab_size),yy.view(-1))
    if i % 1000 == 0:
        print(loss)
    op.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(rnn.parameters(), max_norm=1.0)
    op.step()

tensor(4.1314, grad_fn=<NllLossBackward0>)
tensor(4.1371, grad_fn=<NllLossBackward0>)
tensor(4.1357, grad_fn=<NllLossBackward0>)
tensor(4.1259, grad_fn=<NllLossBackward0>)
tensor(4.1302, grad_fn=<NllLossBackward0>)
tensor(4.1282, grad_fn=<NllLossBackward0>)
tensor(4.1258, grad_fn=<NllLossBackward0>)
tensor(4.1191, grad_fn=<NllLossBackward0>)
tensor(4.1222, grad_fn=<NllLossBackward0>)
tensor(4.1086, grad_fn=<NllLossBackward0>)
tensor(4.1161, grad_fn=<NllLossBackward0>)
tensor(4.1125, grad_fn=<NllLossBackward0>)
tensor(4.1039, grad_fn=<NllLossBackward0>)
tensor(4.1052, grad_fn=<NllLossBackward0>)
tensor(4.1075, grad_fn=<NllLossBackward0>)
tensor(4.0997, grad_fn=<NllLossBackward0>)
tensor(4.0973, grad_fn=<NllLossBackward0>)
tensor(4.1002, grad_fn=<NllLossBackward0>)
tensor(4.0962, grad_fn=<NllLossBackward0>)
tensor(4.0975, grad_fn=<NllLossBackward0>)
tensor(4.0900, grad_fn=<NllLossBackward0>)
tensor(4.0928, grad_fn=<NllLossBackward0>)
tensor(4.0919, grad_fn=<NllLossBackward0>)
tensor(4.08

In [53]:
start = '你'

xx = [char2index[x] for x in start]
for _ in range(100):
    input_x = torch.tensor(xx)
    probs = rnn(input_x.view(1,-1))[:,-1,:].view(-1) # vocabsize
    choice = torch.multinomial(probs,1)
    xx = xx + [choice.item()]

In [54]:
print(''.join([index2char[index] for index in xx]))

你，尽想于很关当引要的时一，想为。理。合理引线阶细正后，因因的，成知环角，入挫章。来困得是很容升成向学入关站在篇该事困难诉那阶导诉细的文在么难环旋在让快有硕你篇你时习门识及握一馈螺旋正些环人会让就解而必
