# Char_RNN (using a lot of sentence as data)

In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

In [16]:
sentence = ("if you want to build a ship, don't drum up people together to "
            "collect wood and don't assign them tasks and work, but rather "
            "teach them to long for the endless immensity of the sea.")

In [17]:
char_set = list(set(sentence))
char_dic = {c: i for i, c in enumerate(char_set)}
print(char_dic)

{'f': 0, 'y': 1, 'o': 2, 'b': 3, 'r': 4, 't': 5, "'": 6, 'm': 7, 'g': 8, 'c': 9, 'u': 10, '.': 11, ',': 12, 'e': 13, 'h': 14, 'p': 15, 'a': 16, 'i': 17, 'k': 18, 'w': 19, 's': 20, 'd': 21, 'n': 22, ' ': 23, 'l': 24}


In [18]:
dic_size = len(char_dic)
print(dic_size)

25


In [19]:
hidden_size = dic_size
timesteps = 10
lr = 0.1

In [20]:
x_data = []
y_data = []

for i in range(0, len(sentence) - timesteps):
    x_str = sentence[i:i + timesteps]
    y_str = sentence[i+1: i + timesteps + 1]
    print(i, x_str, '=>', y_str)

    x_data.append([char_dic[c] for c in x_str])
    y_data.append([char_dic[c] for c in y_str])

0 if you wan => f you want
1 f you want =>  you want 
2  you want  => you want t
3 you want t => ou want to
4 ou want to => u want to 
5 u want to  =>  want to b
6  want to b => want to bu
7 want to bu => ant to bui
8 ant to bui => nt to buil
9 nt to buil => t to build
10 t to build =>  to build 
11  to build  => to build a
12 to build a => o build a 
13 o build a  =>  build a s
14  build a s => build a sh
15 build a sh => uild a shi
16 uild a shi => ild a ship
17 ild a ship => ld a ship,
18 ld a ship, => d a ship, 
19 d a ship,  =>  a ship, d
20  a ship, d => a ship, do
21 a ship, do =>  ship, don
22  ship, don => ship, don'
23 ship, don' => hip, don't
24 hip, don't => ip, don't 
25 ip, don't  => p, don't d
26 p, don't d => , don't dr
27 , don't dr =>  don't dru
28  don't dru => don't drum
29 don't drum => on't drum 
30 on't drum  => n't drum u
31 n't drum u => 't drum up
32 't drum up => t drum up 
33 t drum up  =>  drum up p
34  drum up p => drum up pe
35 drum up pe => rum up peo
36

In [30]:
x_data[0], y_data[0]

([17, 0, 23, 1, 2, 10, 23, 19, 16, 22], [0, 23, 1, 2, 10, 23, 19, 16, 22, 5])

In [31]:
x_one_hot = np.eye(dic_size)[x_data]
X = torch.FloatTensor(x_one_hot)
Y = torch.LongTensor(y_data)

In [32]:
print(X.shape)
print(Y.shape)

torch.Size([170, 10, 25])
torch.Size([170, 10])


In [33]:
class Net(nn.Module):
    def __init__(self, input_dim, hidden_dim, layers):
        super(Net, self).__init__()
        self.rnn = nn.RNN(input_dim, hidden_dim, num_layers=layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x):
        x, _status = self.rnn(x)
        x = self.fc(x)
        return x

In [36]:
net = Net(dic_size, hidden_size, 2)

In [37]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr)

In [38]:
outputs = net(X)
print(outputs.view(-1, dic_size).shape)

torch.Size([1700, 25])


In [39]:
print(Y.view(-1).shape)

torch.Size([1700])


In [40]:
for i in range(100):
    optimizer.zero_grad()
    outputs = net(X)
    loss = criterion(outputs.view(-1, dic_size), Y.view(-1))
    loss.backward()
    optimizer.step()

    
    results = outputs.argmax(dim=2)
    predict_str = ""
    for j, result in enumerate(results):
        if j == 0:
            predict_str += ''.join([char_set[t] for t in result])
        else:
            predict_str += char_set[result[-1]]
    print("{} try: {}".format(i,predict_str))

0 try: cdddcdoccddddddddudddcddduudcccddcdcddcdcodudcuddcddcdddddddddduddddudccuddcdccccddddkddddccdddccoddddccdccccdddccdoddddddddddcddooddccddddudcccddddddcddcddddddccdcudddccdduddcudc
1 try: p co pc ppc   poc   p coppp   ppc o pcc p c c p cc o cpo pc o p p cc o p c pp poppc o s pcpo o c p    pppopop c c c c c c pc c c c c c p po pc oppc o c pop p ppc  popp o  ppo po p
2 try:       oro   or o r ro   oro  oro   rrs rr  o s   r   s or o rr  rrs  rorr   r o rr orr   r   r    r      o  r  orr   rs   r  rs  rr o   r  rr   rro rr  rr    rorro       rror   or
3 try:   nt  nt        n    n          n                          n        n          n        n       n             n     n                           n       n            nn      n     
4 try: tyatmihtmihihihihiihihmihhiihiiihhaiiihmiahihihiihiiihmihhihaahhhhiihmihiiihahihhihiiihihmiihiihhihiihhhiihihihiiihihhiiiihhhiiihhhiihhiahihhhiihmihiiiihtiihhhihihaihhiiiihihihhih
5 try: tb dma ha d  a  dt   ta       a d   a da    t tt   t  a   