In [1]:
import numpy as np
import torch
from torch import nn
from torch import optim

In [8]:
def make_batch(sentence, word_dict, max_len=None, n_class=None):
    """Build one-hot (prefix -> next word) training pairs from a sentence.

    For every word position ``i`` except the last, the input is the prefix
    ``words[:i + 1]`` mapped through ``word_dict``, right-padded with index 0
    up to ``max_len`` and one-hot encoded into ``n_class`` columns; the target
    is the index of ``words[i + 1]``.

    Args:
        sentence: Whitespace-separated text.
        word_dict: Mapping from every word in ``sentence`` to an integer index
            in ``range(n_class)``.
        max_len: Length every input sequence is padded to. Defaults to the
            number of words in ``sentence`` (what the notebook's global
            ``max_len`` is set to).
        n_class: Width of the one-hot encoding. Defaults to ``len(word_dict)``
            (what the notebook's global ``n_class`` is set to).

    Returns:
        ``(input_batch, target_batch)``: a list of ``(max_len, n_class)`` float
        arrays and a list of integer target indices, one entry per prefix.

    Raises:
        ValueError: if the longest prefix is longer than ``max_len``.
        KeyError: if a word of ``sentence`` is missing from ``word_dict``.
    """
    words = sentence.split()
    if max_len is None:
        max_len = len(words)
    if n_class is None:
        n_class = len(word_dict)
    # The longest prefix has len(words) - 1 tokens; anything longer than max_len
    # would get negative padding and produce ragged arrays that fail downstream.
    if len(words) - 1 > max_len:
        raise ValueError(
            f"sentence has {len(words)} words; prefixes of up to "
            f"{len(words) - 1} tokens do not fit max_len={max_len}"
        )

    # Row k of the identity matrix is the one-hot vector for index k; build it once.
    one_hot = np.eye(n_class)
    input_batch = []
    target_batch = []
    for i in range(len(words) - 1):
        prefix = [word_dict[w] for w in words[:i + 1]]
        # NOTE(review): 0 is also the index of a real word in word_dict, so the
        # padding is indistinguishable from that word for the model.
        padded = prefix + [0] * (max_len - len(prefix))
        input_batch.append(one_hot[padded])
        target_batch.append(word_dict[words[i + 1]])

    return input_batch, target_batch

In [9]:
class BILSTM(nn.Module):
    """Bidirectional LSTM followed by a linear projection to class logits.

    Args:
        n_class: Size of each (one-hot) input vector and number of output classes.
        n_hidden: Hidden size of each LSTM direction.
    """

    def __init__(self, n_class, n_hidden):
        super(BILSTM, self).__init__()
        self.n_class = n_class
        self.n_hidden = n_hidden
        self.lstm = nn.LSTM(input_size=n_class, hidden_size=n_hidden, bidirectional=True)
        # Forward and backward outputs are concatenated, hence 2 * n_hidden features.
        self.W = nn.Linear(n_hidden * 2, n_class, bias=False)
        # Separate bias parameter (instead of Linear's own bias), initialised to ones;
        # kept as-is so existing state_dict keys remain valid.
        self.b = nn.Parameter(torch.ones([n_class]))

    def forward(self, X):
        """Return class logits for each sequence in a batch.

        Args:
            X: Batch-first input of shape ``(batch, seq_len, n_class)``.

        Returns:
            Tensor of shape ``(batch, n_class)``.
        """
        # nn.LSTM is constructed with the default batch_first=False, so put time first.
        seq_first = X.transpose(0, 1)

        # Zero initial states created on X's device and dtype, so the model keeps
        # working after .to(device) / .double(); one state per layer per direction.
        num_states = self.lstm.num_layers * (2 if self.lstm.bidirectional else 1)
        hidden_state = X.new_zeros(num_states, X.size(0), self.n_hidden)
        cell_state = X.new_zeros(num_states, X.size(0), self.n_hidden)

        outputs, _ = self.lstm(seq_first, (hidden_state, cell_state))
        # Output at the final time step: [forward state after the whole sequence ;
        # backward state after the last step only].
        # NOTE(review): the backward half has seen only the final (often padded)
        # position, so it contributes little information here.
        last_step = outputs[-1]
        return self.W(last_step) + self.b


In [10]:
# Hidden size of each LSTM direction.
n_hidden = 5
sentence = (
    'Lorem ipsum dolor sit amet consectetur adipisicing elit '
    'sed do eiusmod tempor incididunt ut labore et dolore magna '
    'aliqua Ut enim ad minim veniam quis nostrud exercitation'
)
# Sort the vocabulary so the word <-> index assignment is identical on every kernel
# restart (iteration order of a set of str depends on per-process hash randomization).
vocab = sorted(set(sentence.split()))
word_dict = {w: i for i, w in enumerate(vocab)}
# Derive the inverse mapping from word_dict itself so the two can never disagree.
number_dict = {i: w for w, i in word_dict.items()}
n_class = len(word_dict)
max_len = len(sentence.split())

In [11]:
# Weight initialisation is random; seed right before construction so the model
# (and therefore the training run below) is reproducible on Restart & Run All.
torch.manual_seed(0)
model = BILSTM(n_class, n_hidden)

In [13]:
criterion = nn.CrossEntropyLoss()
# The optimizer binds to the parameters of the *current* `model`; re-run this cell
# whenever the model cell is re-run, otherwise training updates stale weights.
optimizer = optim.Adam(model.parameters(), lr=0.001)

input_arrays, target_indices = make_batch(sentence, word_dict)
# Stack into a single ndarray first: building a tensor from a list of ndarrays is
# very slow (PyTorch warns about it), and the legacy FloatTensor/LongTensor
# constructors are replaced by explicit dtypes.
input_batch = torch.as_tensor(np.asarray(input_arrays), dtype=torch.float32)  # (batch, max_len, n_class)
target_batch = torch.as_tensor(target_indices, dtype=torch.long)              # (batch,)

In [14]:
# Full-batch training: every prefix goes through one forward/backward pass per
# epoch, and the loss is logged every 1000 epochs.
for step in range(1, 10001):
    optimizer.zero_grad()
    logits = model(input_batch)
    loss = criterion(logits, target_batch)
    if step % 1000 == 0:
        print(f"Epoch: {step:04d} cost = {loss.item():.6f}")
    loss.backward()
    optimizer.step()


Epoch: 1000 cost = 1.869559
Epoch: 2000 cost = 1.605071
Epoch: 3000 cost = 1.424430
Epoch: 4000 cost = 1.299399
Epoch: 5000 cost = 1.222594
Epoch: 6000 cost = 1.165366
Epoch: 7000 cost = 1.129363
Epoch: 8000 cost = 1.096126
Epoch: 9000 cost = 1.081305
Epoch: 10000 cost = 0.714586


In [15]:
# Greedy next-word prediction for every prefix; inference only, so skip autograd
# instead of reaching into the discouraged `.data` attribute.
with torch.no_grad():
    logits = model(input_batch)
# Index of the highest logit per row, kept as a (batch, 1) column like before.
predict = logits.argmax(dim=1, keepdim=True)
print(sentence)
# squeeze(1) (not squeeze()) stays 1-D even for a single-row batch, so iteration works.
print([number_dict[n.item()] for n in predict.squeeze(1)])

Lorem ipsum dolor sit amet consectetur adipisicing elit sed do eiusmod tempor incididunt ut labore et dolore magna aliqua Ut enim ad minim veniam quis nostrud exercitation
['ipsum', 'ipsum', 'sit', 'amet', 'consectetur', 'adipisicing', 'elit', 'sed', 'do', 'eiusmod', 'tempor', 'ut', 'ut', 'labore', 'et', 'dolore', 'magna', 'aliqua', 'enim', 'enim', 'ad', 'minim', 'veniam', 'quis', 'nostrud', 'exercitation']
