## **RNN**

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

In [3]:
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)

In [4]:
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')

In [54]:
sentence = " ok computer"
char_set = list(set(sentence))
char_dic = {c : i for i, c in enumerate(char_set)}

In [55]:
vocab_sz = len(char_dic)
hidden_sz = len(char_dic)
input_sz = len(char_dic)

In [56]:
sen_idx = [char_dic[c] for c in sentence]
x_idx = sen_idx[:-1]
x_one_hot = [[np.eye(vocab_sz)[x] for x in x_idx]]
y_data = [sen_idx[1:]]

In [57]:
x_train = torch.FloatTensor(x_one_hot)
y_train = torch.LongTensor(y_data)

In [58]:
class RNN(nn.Module):
  def __init__(self, input_size, hidden_size, vocab_size):
    super(RNN, self).__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.vocab_size = vocab_size

    self.rnn = nn.RNN(input_size=self.input_size, hidden_size=self.hidden_size, batch_first=True)
    self.linear = nn.Linear(self.hidden_size, self.vocab_size)

  def forward(self, x):
    outputs, _ = self.rnn(x)
    x = self.linear(outputs)

    return x

model = RNN(input_size=input_sz, hidden_size=hidden_sz, vocab_size=vocab_sz).to(device)

In [59]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [60]:
epochs = 500

for epoch in range(epochs):
  model.train()

  outputs = model(x_train.to(device)) # forward propagation
  loss = criterion(outputs.view(-1, hidden_sz), y_train.view(-1).to(device))

  optimizer.zero_grad()
  loss.backward() # backward propagation
  optimizer.step() # update parameters

  result = outputs.data.numpy().argmax(axis=2)
  result_str = ''.join([char_set[idx] for idx in np.squeeze(result)])

  if epoch % 50 == 0 or epoch == epochs-1:
    print('loss : {} prediction : {}'.format(loss, result_str))

loss : 2.3719842433929443 prediction : ttktktktkkk
loss : 2.1735692024230957 prediction : tttcotpttkk
loss : 1.813103437423706 prediction : co computec
loss : 1.3812229633331299 prediction : oo computer
loss : 1.0392802953720093 prediction : ok computer
loss : 0.7643840312957764 prediction : ok computer
loss : 0.5527215600013733 prediction : ok computer
loss : 0.40871354937553406 prediction : ok computer
loss : 0.3132968544960022 prediction : ok computer
loss : 0.2474282681941986 prediction : ok computer
loss : 0.20059262216091156 prediction : ok computer
