In [36]:
import torch
import torch.nn as nn

class SimpleRNN(nn.Module):
    
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleRNN, self).__init__()
        self.hidden_size = hidden_size
        # RNN layer
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        # Fully connected layer
        self.fc = nn.Linear(hidden_size, output_size)
        self.embedding = nn.Embedding(output_size, input_size)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        # x [batch_size, seq_len]
        x = self.embedding(x) # batch seq_len 
        # Initialize hidden state with zeros
        h0 = torch.zeros(1, x.size(0), self.hidden_size).to(x.device)

        # RNN layer
        out, _ = self.rnn(x, h0)

        # Only take the output from the last time step
#         out = out[:, -1, :]

        # Fully connected layer
        out = self.fc(out)
        out = self.sigmoid(out)

        return out

In [37]:
from data_process import *

In [38]:
xx,yy=get_data(10)

In [39]:
xx

tensor([[2632,  790, 1424,  232, 1420,  790, 1075, 2783, 1383, 2356,  242, 2260,
         1902, 2111, 1130, 2674, 2603, 2603, 2519, 2638, 1420, 1465,  273,  468,
         2632, 1130, 1099, 2480,  329, 2991],
        [1424, 1420, 1216,  706,  587, 1954, 1199, 1420,  440, 1875, 1902, 1210,
         1210, 1047, 2795,  440,  389, 1420, 1216,  706,  587, 1007, 2827, 1420,
         2790, 1143, 2775,  526, 1420, 1045],
        [ 816,  217, 2114, 1259,  975, 1424, 1130, 2592, 1383, 1420, 1167,  389,
         2864, 2034, 1420, 2500, 2836,  406, 2668,  706, 2988,  222, 2478, 2914,
         1130, 2674, 2603, 2603, 1250, 1430],
        [ 216, 1420,   89,  200, 1420, 1902,  507, 1423, 2593, 1915, 1424, 1697,
         1369, 1473,   76, 2775, 1424, 1665,  587, 2718,  596, 1891,  570,  688,
         2332,  706, 2914, 2653, 1966, 2718],
        [1075, 2250, 1420, 1574, 1188, 1182,  206, 2928, 2683, 2436, 1420, 1188,
         1182,  816, 2025, 2088, 1420, 1683,  184,  273, 2025, 1188, 1182, 2802,
      

In [40]:
xx.shape

torch.Size([10, 30])

In [41]:
embedding_size = 60

In [42]:
rnn = SimpleRNN(embedding_size,50,vocab_size)

In [43]:
probs = rnn(xx)

In [44]:
probs.shape

torch.Size([10, 30, 3005])

In [45]:
yy.shape

torch.Size([10, 30])

In [46]:
criterion = nn.CrossEntropyLoss()

In [47]:
criterion(probs.view(-1,vocab_size),yy.view(-1))

tensor(8.0161, grad_fn=<NllLossBackward0>)

In [48]:
batch_n = 30

In [49]:
op = optim.Adam(rnn.parameters(),lr=0.0001)

In [None]:
for i in range(100000):
    xx,yy = get_data(batch_n)
    probs = rnn(xx)
    loss =  criterion(probs.view(-1,vocab_size),yy.view(-1))
    if i % 1000 == 0:
        print(loss)
    op.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(rnn.parameters(), max_norm=1.0)
    op.step()

tensor(7.2162, grad_fn=<NllLossBackward0>)
tensor(7.2286, grad_fn=<NllLossBackward0>)
tensor(7.2348, grad_fn=<NllLossBackward0>)
tensor(7.2371, grad_fn=<NllLossBackward0>)
tensor(7.2271, grad_fn=<NllLossBackward0>)
tensor(7.2195, grad_fn=<NllLossBackward0>)
tensor(7.2111, grad_fn=<NllLossBackward0>)
tensor(7.2078, grad_fn=<NllLossBackward0>)
tensor(7.2147, grad_fn=<NllLossBackward0>)
tensor(7.2100, grad_fn=<NllLossBackward0>)
tensor(7.2134, grad_fn=<NllLossBackward0>)
tensor(7.2046, grad_fn=<NllLossBackward0>)
tensor(7.2101, grad_fn=<NllLossBackward0>)
tensor(7.2088, grad_fn=<NllLossBackward0>)
tensor(7.2141, grad_fn=<NllLossBackward0>)
tensor(7.1964, grad_fn=<NllLossBackward0>)
tensor(7.1876, grad_fn=<NllLossBackward0>)
tensor(7.2134, grad_fn=<NllLossBackward0>)


In [60]:
start = '你们是谁？'

xx = [char2index[x] for x in start]
for _ in range(100):
    input_x = torch.tensor(xx)
    probs = rnn(input_x.view(1,-1))[:,-1,:].view(-1) # vocabsize
    choice = torch.multinomial(probs,1)
    xx = xx + [choice.item()]

In [61]:
# Decode the sampled token ids back to characters and show the generated text.
print(''.join(index2char[token_id] for token_id in xx))

你们是谁？成实跑也应树一串苍深突作市尾”面长…花西般寓金假后中本运爱香！月往水科门脸起小路妈给为王里心本猫也头朝好完可不适饭吧去国越出小草开一我同老多大个朝时爱用看传我只已好想吃“人像。什。刚翔成机算们向呢心也
