# Basic RNN

In [10]:
import torch
import torch.optim as optim
import numpy as np

In [11]:
# Random seed to make results deterministic and reproducible
torch.manual_seed(0)

<torch._C.Generator at 0x7947b280fe70>

In [28]:
# 학습시킬 문장 준비
text_data = 'Fall in love with artificial intelligence'

# 사전 만들기
char_set = list(set(text_data))
char_dict = {c: i for i, c in enumerate(char_set)}

# hyper parameters
dict_size = len(char_dict)
input_size = len(char_set)
hidden_size = len(char_set)
learning_rate = 0.05

# 데이터셋 준비
text_idx = [char_dict[c] for c in text_data]
x_data = [text_idx[:-1]]
x_one_hot = [np.eye(dict_size)[x] for x in x_data]
y_data = [text_idx[1:]]

# 학습을 위한 데이터형식 변환
X = torch.FloatTensor(x_one_hot)
Y = torch.LongTensor(y_data)

In [29]:
# declare RNN
rnn = torch.nn.RNN(input_size, hidden_size, batch_first=True)
# batch_first guarantees the order of output = (B, S, F)

# loss & optimizer setting
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(rnn.parameters(), learning_rate)

In [30]:
# training
for i in range(100):
    optimizer.zero_grad()
    outputs, _status = rnn(X)
    loss = criterion(outputs.view(-1, input_size), Y.view(-1))
    loss.backward()
    optimizer.step()

    result = outputs.data.numpy().argmax(axis=2)
    result_str = ''.join([char_set[c] for c in np.squeeze(result)])
    print(i, "loss: ", loss.item(), "prediction: ", result, "true Y: ",
          y_data, "prediction str: ", result_str)

0 loss:  2.7659363746643066 prediction:  [[ 2  2 12 12  1  1  2  2 12 15 10 12  2  7  1 12  1  2  1 10 12  1 15  1
   6  1 15 12  2  1  2 12 12 15 12  1 12 12  1  6]] true Y:  [[4, 0, 0, 14, 2, 12, 14, 0, 6, 8, 5, 14, 1, 2, 15, 13, 14, 4, 11, 15, 2, 9, 2, 7, 2, 4, 0, 14, 2, 12, 15, 5, 0, 0, 2, 10, 5, 12, 7, 5]] prediction str:  iinnwwiintgnicwnwiwgnwtwowtniwinntnwnnwo
1 loss:  2.549696683883667 prediction:  [[ 2  0  2  2  2 12  0  2  2 15  2  2  2  0  2  2  2  2  0 15  2  2  0  0
   2 12  0  2  2 12  2  2  2  2  2  2  2  2  0  2]] true Y:  [[4, 0, 0, 14, 2, 12, 14, 0, 6, 8, 5, 14, 1, 2, 15, 13, 14, 4, 11, 15, 2, 9, 2, 7, 2, 4, 0, 14, 2, 12, 15, 5, 0, 0, 2, 10, 5, 12, 7, 5]] prediction str:  iliiinliitiiiliiiiltiillinliiniiiiiiiili
2 loss:  2.397393226623535 prediction:  [[2 0 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
  2 2 2 2]] true Y:  [[4, 0, 0, 14, 2, 12, 14, 0, 6, 8, 5, 14, 1, 2, 15, 13, 14, 4, 11, 15, 2, 9, 2, 7, 2, 4, 0, 14, 2, 12, 15, 5, 0, 0, 2, 10, 5