In [1]:
import torch
import torch.optim as optim
import numpy as np

In [2]:
# Random seed to make results deterministic and reproducible
torch.manual_seed(0)

<torch._C.Generator at 0x20a2dd547b0>

In [3]:
# sample
sample = "hihello"

In [4]:
# make dictionary
char_set = list(set(sample))
char_dic = {c: i for i, c in enumerate(char_set)}

In [5]:
# hyper paremeters
input_size = len(char_set)
hidden_size = len(char_set)
learning_rate = 0.2

In [6]:
# data setting
sample_idx = [char_dic[c] for c in sample]
x_data = [sample_idx[:-1]]
x_one_hot = [np.eye(input_size)[x] for x in x_data]
y_data = [sample_idx[1:]]

In [7]:
# transform as torch tensor variable
X = torch.FloatTensor(x_one_hot)
Y = torch.LongTensor(y_data)

In [8]:
# declare RNN
rnn = torch.nn.RNN(input_size, hidden_size, batch_first=True)  
# batch_first guarantees the order of output = (B, S, F)

In [9]:
# loss & optimizer setting
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(rnn.parameters(), learning_rate)

In [10]:
# start training
for i in range(1000):
    optimizer.zero_grad()
    outputs, _status = rnn(X)
    loss = criterion(outputs.view(-1, input_size), Y.view(-1))
    loss.backward()
    optimizer.step()

    result = outputs.data.numpy().argmax(axis=2)
    result_str = ''.join([char_set[c] for c in np.squeeze(result)])
    print(i, "loss: ", loss.item(), 
          "prediction: ", result, 
          "true Y: ", y_data, 
          "prediction str: ", result_str)

0 loss:  1.633445382118225 prediction:  [[1 1 1 1 1 1]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  oooooo
1 loss:  1.3027483224868774 prediction:  [[4 3 4 3 1 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihihol
2 loss:  1.14564049243927 prediction:  [[4 3 1 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  iholll
3 loss:  0.9507662653923035 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
4 loss:  0.8943259716033936 prediction:  [[4 3 2 3 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihehll
5 loss:  0.797957718372345 prediction:  [[4 3 2 4 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  iheill
6 loss:  0.7660830020904541 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
7 loss:  0.7567178606987 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
8 loss:  0.737369954586029 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
9 loss:  

73 loss:  0.52195143699646 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
74 loss:  0.5218602418899536 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
75 loss:  0.5217610001564026 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
76 loss:  0.5216742157936096 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
77 loss:  0.5215799808502197 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
78 loss:  0.5214971899986267 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
79 loss:  0.5214075446128845 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
80 loss:  0.5213276743888855 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
81 loss:  0.5212426781654358 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  i

145 loss:  0.5179557204246521 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
146 loss:  0.5179216265678406 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
147 loss:  0.5178877711296082 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
148 loss:  0.5178543329238892 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
149 loss:  0.517821192741394 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
150 loss:  0.5177883505821228 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
151 loss:  0.5177558064460754 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
152 loss:  0.517723560333252 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
153 loss:  0.5176916122436523 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] predictio

217 loss:  0.5161068439483643 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
218 loss:  0.5160874724388123 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
219 loss:  0.5160682201385498 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
220 loss:  0.5160490870475769 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
221 loss:  0.5160301327705383 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
222 loss:  0.5160112380981445 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
223 loss:  0.5159924626350403 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
224 loss:  0.5159738063812256 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
225 loss:  0.5159552693367004 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] predict

289 loss:  0.5149679780006409 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
290 loss:  0.5149549841880798 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
291 loss:  0.5149421691894531 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
292 loss:  0.5149293541908264 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
293 loss:  0.514916718006134 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
294 loss:  0.5149040222167969 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
295 loss:  0.5148913860321045 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
296 loss:  0.5148788690567017 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
297 loss:  0.5148663520812988 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] predicti

361 loss:  0.5141738653182983 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
362 loss:  0.5141644477844238 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
363 loss:  0.5141550898551941 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
364 loss:  0.5141457915306091 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
365 loss:  0.5141364336013794 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
366 loss:  0.5141271948814392 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
367 loss:  0.5141180157661438 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
368 loss:  0.5141087174415588 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
369 loss:  0.514099657535553 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] predicti

433 loss:  0.5135779976844788 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
434 loss:  0.5135707259178162 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
435 loss:  0.5135634541511536 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
436 loss:  0.5135562419891357 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
437 loss:  0.5135490298271179 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
438 loss:  0.5135418772697449 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
439 loss:  0.5135347247123718 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
440 loss:  0.5135275721549988 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
441 loss:  0.5135204792022705 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] predict

505 loss:  0.5131082534790039 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
506 loss:  0.5131023526191711 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
507 loss:  0.5130965709686279 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
508 loss:  0.5130906701087952 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
509 loss:  0.5130849480628967 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
510 loss:  0.5130791664123535 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
511 loss:  0.5130734443664551 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
512 loss:  0.5130676627159119 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
513 loss:  0.5130619406700134 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] predict

577 loss:  0.5127515196800232 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
578 loss:  0.5127618312835693 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
579 loss:  0.5127542614936829 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
580 loss:  0.5127816796302795 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
581 loss:  0.5127649307250977 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
582 loss:  0.5128089785575867 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
583 loss:  0.512765109539032 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
584 loss:  0.5127964615821838 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
585 loss:  0.5127443671226501 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] predicti

649 loss:  0.5124563574790955 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
650 loss:  0.5124748349189758 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
651 loss:  0.5124472975730896 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
652 loss:  0.5124655961990356 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
653 loss:  0.5124390721321106 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
654 loss:  0.5124581456184387 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
655 loss:  0.5124318599700928 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
656 loss:  0.5124518871307373 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
657 loss:  0.5124247670173645 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] predict

721 loss:  0.5121913552284241 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
722 loss:  0.5122120380401611 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
723 loss:  0.5121846199035645 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
724 loss:  0.512205183506012 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
725 loss:  0.5121778845787048 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
726 loss:  0.5121984481811523 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
727 loss:  0.51217120885849 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
728 loss:  0.5121919512748718 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
729 loss:  0.5121645927429199 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction

793 loss:  0.5119662880897522 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
794 loss:  0.5119883418083191 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
795 loss:  0.511960506439209 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
796 loss:  0.5119826197624207 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
797 loss:  0.5119547843933105 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
798 loss:  0.511976957321167 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
799 loss:  0.5119490623474121 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
800 loss:  0.5119711756706238 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] prediction str:  ihelll
801 loss:  0.5119432806968689 prediction:  [[4 3 2 0 0 0]] true Y:  [[4, 3, 2, 0, 0, 1]] predictio