# RNN - charseq

In [1]:
import torch
import torch.optim as optim
import numpy as np

In [2]:
# Random seed to make results deterministic and reproducible
torch.manual_seed(0)

<torch._C.Generator at 0x1e2114450b0>

In [3]:
sample = " if you want you"

In [4]:
# make dictionary
char_set = list(set(sample)) # set : 중복제거
char_dic = {c: i for i, c in enumerate(char_set)} # enumerate로 index와 character 같이 가져옴 & char:index로 매핑해서 dictionary로 만들어줌
print(char_dic)

{'t': 0, 'n': 1, 'y': 2, 'f': 3, 'a': 4, ' ': 5, 'u': 6, 'i': 7, 'w': 8, 'o': 9}


In [5]:
# hyper parameters
dic_size = len(char_dic)
hidden_size = len(char_dic)
learning_rate = 0.1

In [6]:
# data setting
sample_idx = [char_dic[c] for c in sample] # index 구하기 # 샘플에서 하나의 character 들을 가져오고 character를 index로 변환
x_data = [sample_idx[:-1]] # 마지막 제거
x_one_hot = [np.eye(dic_size)[x] for x in x_data]
y_data = [sample_idx[1:]] # 처음 제거

In [7]:
# transform as torch tensor variable
# numpy와 list로 구성된 data를 Pytorch의 Tensor로 변환
X = torch.FloatTensor(x_one_hot)
Y = torch.LongTensor(y_data)

In [8]:
# declare RNN
rnn = torch.nn.RNN(dic_size, hidden_size, batch_first=True) # batch_first=True 는 batch_dimension이 가장 앞으로

In [9]:
# loss & optimizer setting
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(rnn.parameters(), learning_rate) # 최적화

In [10]:
# start training
for i in range(50):
    optimizer.zero_grad()
    outputs, _status = rnn(X) # _status : 다음 input이 있으면 그다음 input은 RNN 안에서 계산할 때 쓰이게 될 hidden state -> 여기서는 주어진 모든 input을 다 처리하고 나오는 hidden state라서 따로 쓰이지는 않음
    loss = criterion(outputs.view(-1, dic_size), Y.view(-1))
    loss.backward()
    optimizer.step()
    
    # 실제로 model이 어떻게 예측했는지에 대해 알아보는 코드
    result = outputs.data.numpy().argmax(axis=2) # argmax : 가장 큰 숫자가 있는 index 가져옴
    result_str = ''.join([char_set[c] for c in np.squeeze(result)]) # 가장 큰 index의 character 가져온 후 join으로 하나의 string
    print(i, "loss: ", loss.item(), "prediction: ", result, "true Y: ", y_data, "prediction str: ", result_str)

0 loss:  2.1659820079803467 prediction:  [[7 0 5 7 8 6 2 7 7 7 8 5 7 7 6]] true Y:  [[7, 3, 5, 2, 9, 6, 5, 8, 4, 1, 0, 5, 2, 9, 6]] prediction str:  it iwuyiiiw iiu
1 loss:  1.8518487215042114 prediction:  [[7 4 5 7 9 6 5 8 4 5 5 5 6 5 6]] true Y:  [[7, 3, 5, 2, 9, 6, 5, 8, 4, 1, 0, 5, 2, 9, 6]] prediction str:  ia iou wa   u u
2 loss:  1.5781149864196777 prediction:  [[1 4 5 0 9 6 5 8 4 5 0 5 0 9 6]] true Y:  [[7, 3, 5, 2, 9, 6, 5, 8, 4, 1, 0, 5, 2, 9, 6]] prediction str:  na tou wa t tou
3 loss:  1.4471590518951416 prediction:  [[1 4 5 2 9 6 5 8 4 1 0 5 2 9 6]] true Y:  [[7, 3, 5, 2, 9, 6, 5, 8, 4, 1, 0, 5, 2, 9, 6]] prediction str:  na you want you
4 loss:  1.353110671043396 prediction:  [[2 3 5 2 9 6 5 8 4 1 0 5 2 9 6]] true Y:  [[7, 3, 5, 2, 9, 6, 5, 8, 4, 1, 0, 5, 2, 9, 6]] prediction str:  yf you want you
5 loss:  1.2719032764434814 prediction:  [[7 3 5 2 9 6 5 2 9 5 2 5 2 9 6]] true Y:  [[7, 3, 5, 2, 9, 6, 5, 8, 4, 1, 0, 5, 2, 9, 6]] prediction str:  if you yo y you
6 loss:  1.