In [46]:
import torch
import torch.nn as nn
from torch import autograd
from torch.autograd import Variable
from torchvision import datasets, transforms, models
import sys

In [139]:
idx2char = ['h', 'i', 'e', 'l', 'o']
char2idx = {ch: i for i, ch in enumerate(idx2char)}

# Teach hihell -> ihello: at every position, predict the next character.
x_str, y_str = 'hihell', 'ihello'
x_data = [char2idx[ch] for ch in x_str]   # [0, 1, 0, 2, 3, 3]
y_data = [char2idx[ch] for ch in y_str]   # [1, 0, 2, 3, 3, 4]

# Row i is the one-hot encoding of idx2char[i] (identity matrix as nested lists).
one_hot_lookup = [[int(row == col) for col in range(len(idx2char))]
                  for row in range(len(idx2char))]
x_one_hot = [one_hot_lookup[x] for x in x_data]

# A single batch of samples, so convert to tensors only once.
# Inputs are float one-hot rows; labels are int64 class indices, which is the
# target dtype nn.CrossEntropyLoss expects.
inputs = torch.tensor(x_one_hot, dtype=torch.float32)
labels = torch.tensor(y_data, dtype=torch.long)

num_classes = len(idx2char)
input_size = num_classes   # one-hot size
hidden_size = num_classes  # RNN output is used directly as the class logits
batch_size = 1        # one sentence
sequence_length = 1   # one character per RNN call
num_layers = 1        # one-layer rnn

In [145]:
class Model(nn.Module):
    """Vanilla RNN that consumes one one-hot character per forward call.

    Configuration is read from the notebook-level globals defined in the data
    cell: ``input_size``, ``hidden_size``, ``num_layers``, ``num_classes``,
    ``batch_size`` and ``sequence_length``.
    """

    def __init__(self):
        super().__init__()
        # The raw RNN output is used directly as class logits, so its width
        # must equal the number of classes. Otherwise `out.view(-1, num_classes)`
        # would either fail or silently regroup features into bogus rows.
        if hidden_size != num_classes:
            raise ValueError(
                "hidden_size (%d) must equal num_classes (%d): the RNN output "
                "is used directly as logits" % (hidden_size, num_classes))
        # Pass num_layers explicitly so the module matches the hidden state
        # allocated by init_hidden(); nn.RNN would otherwise default to 1 layer.
        self.rnn = nn.RNN(input_size=input_size, hidden_size=hidden_size,
                          num_layers=num_layers, batch_first=True)

    def forward(self, x, hidden):
        """Run one RNN step.

        Args:
            x: tensor with batch_size * sequence_length * input_size elements
               (here: one one-hot character row).
            hidden: previous hidden state, as produced by init_hidden() or a
               previous call.

        Returns:
            (hidden, out): the new hidden state, and the logits reshaped to
            (batch_size * sequence_length, num_classes). Note the order:
            hidden first.
        """
        x = x.view(batch_size, sequence_length, input_size)
        out, hidden = self.rnn(x, hidden)
        out = out.view(-1, num_classes)
        return hidden, out

    def init_hidden(self):
        """Return an all-zero initial hidden state of shape
        (num_layers, batch_size, hidden_size)."""
        return torch.zeros(num_layers, batch_size, hidden_size)

In [159]:
# Seed before building the model so weight init (and the logged losses) are reproducible.
torch.manual_seed(0)

model = Model()
print(model)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

num_epochs = 100

for epoch in range(num_epochs):
    optimizer.zero_grad()
    loss = 0
    # Fresh state each epoch: the whole sequence is replayed from the start.
    hidden = model.init_hidden()

    # Feed the sequence one character at a time, carrying the hidden state
    # forward and summing the per-step losses (truncated BPTT over the sentence).
    for x, y in zip(inputs, labels):
        hidden, output = model(x, hidden)
        # CrossEntropyLoss expects a shape-(1,) int64 target; `y` is a 0-d
        # tensor (float if labels were built with torch.Tensor), so cast and reshape.
        loss += criterion(output, y.long().view(1))

    print("epoch: %d, loss: %1.3f" % (epoch + 1, loss.item()))
    loss.backward()
    optimizer.step()

print("Learning finished!")


Model(
  (rnn): RNN(5, 5, batch_first=True)
)
epoch: 1, loss: 9.615
epoch: 2, loss: 8.513
epoch: 3, loss: 7.677
epoch: 4, loss: 6.981
epoch: 5, loss: 6.433
epoch: 6, loss: 6.073
epoch: 7, loss: 5.829
epoch: 8, loss: 5.581
epoch: 9, loss: 5.401
epoch: 10, loss: 5.222
epoch: 11, loss: 5.071
epoch: 12, loss: 4.959
epoch: 13, loss: 4.853
epoch: 14, loss: 4.728
epoch: 15, loss: 4.589
epoch: 16, loss: 4.446
epoch: 17, loss: 4.302
epoch: 18, loss: 4.161
epoch: 19, loss: 4.028
epoch: 20, loss: 3.909
epoch: 21, loss: 3.808
epoch: 22, loss: 3.759
epoch: 23, loss: 3.723
epoch: 24, loss: 3.670
epoch: 25, loss: 3.634
epoch: 26, loss: 3.593
epoch: 27, loss: 3.539
epoch: 28, loss: 3.498
epoch: 29, loss: 3.480
epoch: 30, loss: 3.443
epoch: 31, loss: 3.420
epoch: 32, loss: 3.406
epoch: 33, loss: 3.381
epoch: 34, loss: 3.358
epoch: 35, loss: 3.343
epoch: 36, loss: 3.326
epoch: 37, loss: 3.307
epoch: 38, loss: 3.293
epoch: 39, loss: 3.281
epoch: 40, loss: 3.269
epoch: 41, loss: 3.256
epoch: 42, loss: 3.2