# Char_RNN (using a lot of sentence as data)

In [61]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

In [62]:
sentence = ("if you want to build a ship, don't drum up people together to "
            "collect wood and don't assign them tasks and work, but rather "
            "teach them to long for the endless immensity of the sea.")

In [63]:
char_set = list(set(sentence))
char_dic = {c: i for i, c in enumerate(char_set)}
print(char_dic)

{'s': 0, 'g': 1, 'h': 2, 'k': 3, 'e': 4, 'w': 5, 'd': 6, '.': 7, 'f': 8, 'r': 9, 'b': 10, ' ': 11, 'p': 12, 'y': 13, 'c': 14, 'a': 15, 'u': 16, 'i': 17, 'o': 18, 'n': 19, "'": 20, 'l': 21, 't': 22, ',': 23, 'm': 24}


In [64]:
dic_size = len(char_dic)
print(dic_size)

25


In [65]:
hidden_size = dic_size
timesteps = 10
lr = 0.1

In [66]:
x_data = []
y_data = []

for i in range(0, len(sentence) - timesteps):
    x_str = sentence[i:i + timesteps]
    y_str = sentence[i+1: i + timesteps + 1]
    print(i, x_str, '=>', y_str)

    x_data.append([char_dic[c] for c in x_str])
    y_data.append([char_dic[c] for c in y_str])

0 if you wan => f you want
1 f you want =>  you want 
2  you want  => you want t
3 you want t => ou want to
4 ou want to => u want to 
5 u want to  =>  want to b
6  want to b => want to bu
7 want to bu => ant to bui
8 ant to bui => nt to buil
9 nt to buil => t to build
10 t to build =>  to build 
11  to build  => to build a
12 to build a => o build a 
13 o build a  =>  build a s
14  build a s => build a sh
15 build a sh => uild a shi
16 uild a shi => ild a ship
17 ild a ship => ld a ship,
18 ld a ship, => d a ship, 
19 d a ship,  =>  a ship, d
20  a ship, d => a ship, do
21 a ship, do =>  ship, don
22  ship, don => ship, don'
23 ship, don' => hip, don't
24 hip, don't => ip, don't 
25 ip, don't  => p, don't d
26 p, don't d => , don't dr
27 , don't dr =>  don't dru
28  don't dru => don't drum
29 don't drum => on't drum 
30 on't drum  => n't drum u
31 n't drum u => 't drum up
32 't drum up => t drum up 
33 t drum up  =>  drum up p
34  drum up p => drum up pe
35 drum up pe => rum up peo
36

In [67]:
x_one_hot = np.eye(dic_size)[x_data]
X = torch.FloatTensor(x_one_hot)
Y = torch.LongTensor(y_data)

In [68]:
print(X.shape)
print(Y.shape)

torch.Size([170, 10, 25])
torch.Size([170, 10])


In [69]:
class Net(nn.Module):
    def __init__(self, input_dim, hidden_dim, layers):
        super(Net, self).__init__()
        self.rnn = nn.RNN(input_dim, hidden_dim, num_layers=layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x):
        x, _status = self.rnn(x)
        x = self.fc(x)
        return x

In [70]:
net = Net(dic_size, hidden_size, 2)

In [71]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr)

In [72]:
outputs = net(X)
print(outputs.view(-1, dic_size).shape)

torch.Size([1700, 25])


In [73]:
print(Y.view(-1).shape)

torch.Size([1700])


In [74]:
for i in range(100):
    optimizer.zero_grad()
    outputs = net(X)
    loss = criterion(outputs.view(-1, dic_size), Y.view(-1))
    loss.backward()
    optimizer.step()

    
    results = outputs.argmax(dim=2)
    predict_str = ""
    for j, result in enumerate(results):
        if j == 0:
            predict_str += ''.join([char_set[t] for t in result])
        else:
            predict_str += char_set[result[-1]]
    print("{} try: {}".format(i,predict_str))

0 try: mtmmuumtuumttumttmlumummtmtmtamumumtmummtuttuuutumtummuummmuumtumluamutumttuumttmumumuumtmumtummtluumtmuumttummumttammuuummmuuuummtummtlumtuumttumtlumtmumtumtttmuuutmlumummttmttum
1 try: ttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttt
2 try:        e                e e                       e                                                        e                                                                       
3 try: t..'t.y t. i..g ...i... ...g.. ....g .....c ...i........' ...i...c ..ci..'...c...c...c...'......... ...ci....g...c...g... ...............c...c ..gg..c ..' .......yc....' ........c
4 try: t t tlh dtlo tlo dtlo t dlh  tdlo dtdd idolo dtlodhtlhd tlo tlh dolo dh lh ltlo tlh  tlhd ldhodh htlo iddot tdh ttdlo tiddltdolh  odd d tlo dtloilo tdhodtdohiddl d dllodtlhoihl tt
5 try:     t h h                 h t t o              t       h  