In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

In [5]:
n_hidden = 35
lr = 0.01
epochs = 1000

string = "hello pytorch. how long can a rnn cell remember"
chars = "abcdeghijklmnopqrstuvwxyz ?!.,:;01"
char_list = [i for i in chars]
n_letters = len(char_list)

In [51]:
def string_to_onehot(string):
  start = np.zeros(shape=(len(char_list)), dtype=int)
  end = np.zeros(shape=(len(char_list)), dtype=int)
  start[-2] = 1
  end[-1] = 1
  for i in string:
    idx = char_list.index(i)
    zero = np.zeros(shape=len(char_list), dtype=int)
    zero[idx] = 1
    start = np.vstack([start, zero])
  output = np.vstack([start, end])
  return output

In [38]:
def onehot_to_word(onehot_1):
  onehot = torch.Tensor.numpy(onehot_1)
  return char_list[onehot.argmax()]

In [45]:
class RNN(nn.Module):
  def __init__(self, input_size, hidden_size, output_size):
    super(RNN, self).__init__()

    self.input_size = input_size
    self.hidden_size = hidden_size
    self.output_size = output_size

    self.i2h = nn.Linear(input_size, hidden_size)
    self.h2h = nn.Linear(hidden_size, hidden_size)
    self.i2o = nn.Linear(hidden_size, output_size)
    self.act_fn = nn.Tanh()
  
  def forward(self, input, hidden):
    hidden = self.act_fn(self.i2h(input)+self.h2h(hidden))
    output = self.i2o(hidden)
    return output, hidden
  
  def init_hidden(self):
    return torch.zeros(1, self.hidden_size)

In [46]:
rnn = RNN(n_letters, n_hidden, n_letters)

In [47]:
loss_func = nn.MSELoss()
optimizer = torch.optim.Adam(rnn.parameters(), lr=lr)

In [55]:
one_hot = torch.from_numpy(string_to_onehot(string)).type_as(torch.FloatTensor())
one_hot.shape

torch.Size([49, 34])

In [72]:
for i in range(epochs):
  rnn.zero_grad()
  total_loss = 0
  hidden = rnn.init_hidden()

  for j in range(one_hot.size()[0]-1):# 49-1=48
    input = one_hot[j:j+1,:]
    target = one_hot[j+1]

    output, hidden = rnn.forward(input, hidden)
    loss = loss_func(output.view(-1), target.view(-1))
    total_loss += loss
    input = output
  
  total_loss.backward()
  optimizer.step()

  if i%10==0:
    print(total_loss)


tensor(2.4474, grad_fn=<AddBackward0>)
tensor(0.9752, grad_fn=<AddBackward0>)
tensor(0.6407, grad_fn=<AddBackward0>)
tensor(0.4472, grad_fn=<AddBackward0>)
tensor(0.3134, grad_fn=<AddBackward0>)
tensor(0.2323, grad_fn=<AddBackward0>)
tensor(0.1753, grad_fn=<AddBackward0>)
tensor(0.1392, grad_fn=<AddBackward0>)
tensor(0.1169, grad_fn=<AddBackward0>)
tensor(0.0988, grad_fn=<AddBackward0>)
tensor(0.0899, grad_fn=<AddBackward0>)
tensor(0.0762, grad_fn=<AddBackward0>)
tensor(0.0650, grad_fn=<AddBackward0>)
tensor(0.0579, grad_fn=<AddBackward0>)
tensor(0.0608, grad_fn=<AddBackward0>)
tensor(0.0491, grad_fn=<AddBackward0>)
tensor(0.0437, grad_fn=<AddBackward0>)
tensor(0.0388, grad_fn=<AddBackward0>)
tensor(0.0458, grad_fn=<AddBackward0>)
tensor(0.0369, grad_fn=<AddBackward0>)
tensor(0.0325, grad_fn=<AddBackward0>)
tensor(0.0290, grad_fn=<AddBackward0>)
tensor(0.0261, grad_fn=<AddBackward0>)
tensor(0.0236, grad_fn=<AddBackward0>)
tensor(0.0216, grad_fn=<AddBackward0>)
tensor(0.0351, grad_fn=<A

In [77]:
start = torch.zeros(1, len(char_list))
start[:, -2] = 1

with torch.no_grad():
  hidden = rnn.init_hidden()
  input = start
  output_string = ""

  for i in range(len(string)):
    output, hidden = rnn.forward(input, hidden)
    output_string += onehot_to_word(output.data)
    input = output

print(output_string)

hello pytorcemeem r eelcaeremeeeaemaaeaemeapaem
