In [0]:
'''
  code by Minho Ryu @bzantium
'''
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Vocabulary: the 26 lowercase letters, each assigned a class index by position.
char_arr = list('abcdefghijklmnopqrstuvwxyz')
word_dict = {ch: idx for idx, ch in enumerate(char_arr)}  # letter -> class index
number_dict = dict(enumerate(char_arr))  # class index -> letter
n_class = len(word_dict)  # number of classes (= vocabulary size)

# Training words: every character but the last is the input, the last is the target.
seq_data = ['make', 'need', 'coal', 'word', 'love', 'hate', 'live', 'home', 'hash', 'star']

# TextLSTM hyperparameters
n_step = 3  # length of each input prefix (all words above have 4 letters)
n_hidden = 128  # LSTM hidden-state size

def make_batch(seq_data, vocab=None, num_classes=None):
    """Build one-hot inputs and integer targets for next-character prediction.

    Each word is split into its prefix (every character but the last), which is
    one-hot encoded, and its final character, which becomes the class target.

    Args:
        seq_data: iterable of words. All words must share one length so their
            prefixes can be stacked into a single tensor.
        vocab: mapping from character to class index. Defaults to the
            module-level ``word_dict``.
        num_classes: width of the one-hot encoding. Defaults to the
            module-level ``n_class``.

    Returns:
        ``(inputs, targets)``. ``inputs`` is a float32 tensor of shape
        ``[batch_size, len(word) - 1, num_classes]``. ``targets`` is a long
        tensor of shape ``[batch_size]``.
    """
    if vocab is None:
        vocab = word_dict
    if num_classes is None:
        num_classes = n_class

    eye = np.eye(num_classes, dtype=np.float32)  # built once, not per word
    input_batch, target_batch = [], []
    for seq in seq_data:
        prefix = [vocab[ch] for ch in seq[:-1]]  # e.g. 'm', 'a', 'k' for 'make'
        input_batch.append(eye[prefix])
        target_batch.append(vocab[seq[-1]])  # e.g. 'e' for 'make'

    # Convert through one contiguous ndarray: building a tensor from a Python
    # list of ndarrays is very slow, and PyTorch warns about it.
    inputs = torch.from_numpy(np.array(input_batch, dtype=np.float32))
    targets = torch.tensor(target_batch, dtype=torch.long)
    return inputs, targets

class TextLSTM(nn.Module):
    """Single-layer LSTM that predicts the next character from a one-hot prefix.

    The LSTM's final hidden state is projected by a linear layer to
    ``n_class`` logits.
    """

    def __init__(self, n_hidden, n_class):
        super().__init__()
        self.lstm = nn.LSTM(input_size=n_class, hidden_size=n_hidden)
        self.linear = nn.Linear(n_hidden, n_class)

    def forward(self, X):
        """Return class logits for a batch of one-hot sequences.

        Args:
            X: float tensor of shape ``[batch_size, n_step, n_class]``.

        Returns:
            Logits of shape ``[batch_size, n_class]``.
        """
        # nn.LSTM defaults to batch_first=False, so move the time axis first.
        inputs = X.transpose(0, 1)  # [n_step, batch_size, n_class]

        # Zero initial states sized from this module's own LSTM (not a global
        # hyperparameter) and allocated on X's device with X's dtype.
        # Shape: [num_layers(=1) * num_directions(=1), batch_size, hidden_size]
        state_shape = (1, X.size(0), self.lstm.hidden_size)
        hidden_state = X.new_zeros(state_shape)
        cell_state = X.new_zeros(state_shape)

        _, (hidden, _) = self.lstm(inputs, (hidden_state, cell_state))
        # hidden: [1, batch_size, hidden_size] -> [batch_size, hidden_size]
        return self.linear(hidden.squeeze(0))

input_batch, target_batch = make_batch(seq_data)

model = TextLSTM(n_hidden, n_class)

# CrossEntropyLoss takes raw logits and integer class targets.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training: full-batch optimisation over all words at once.
n_epochs = 1000
log_every = 100  # print the loss every `log_every` epochs

for epoch in range(n_epochs):
    optimizer.zero_grad()

    output = model(input_batch)  # logits: [batch_size, n_class]
    loss = criterion(output, target_batch)
    if (epoch + 1) % log_every == 0:
        # .item() extracts the Python float from the 0-dim loss tensor before formatting.
        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.item()))

    loss.backward()
    optimizer.step()

# Inference: show the predicted last letter for each word's prefix.
inputs = [sen[:-1] for sen in seq_data]  # the same prefixes make_batch feeds the model

with torch.no_grad():  # inference only, so no autograd graph is needed
    # argmax over the class axis. tolist() keeps a 1-D list even for a batch of one.
    predict = model(input_batch).argmax(dim=1).tolist()
for prefix, class_idx in zip(inputs, predict):
    print(prefix, '->', number_dict[class_idx])

Epoch: 0100 cost = 0.545306
Epoch: 0200 cost = 0.035130
Epoch: 0300 cost = 0.009809
Epoch: 0400 cost = 0.004519
Epoch: 0500 cost = 0.002628
Epoch: 0600 cost = 0.001731
Epoch: 0700 cost = 0.001232
Epoch: 0800 cost = 0.000924
Epoch: 0900 cost = 0.000720
Epoch: 1000 cost = 0.000577
mak -> e
nee -> d
coa -> l
wor -> d
lov -> e
hat -> e
liv -> e
hom -> e
has -> h
sta -> r
