# A. データの前処理及び変数、関数の宣言

In [100]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

sentences = ["i like dogs", "i love coffee", "i hate milk", "you like cats", "you love milk", "you hate coffee"]
dtype = torch.float

# --- Word processing ---
# Sort the vocabulary: iteration order of a set of str depends on hash
# randomization (PYTHONHASHSEED), so an unsorted list would assign different
# word indices on every interpreter run and make results irreproducible.
word_list = sorted(set(" ".join(sentences).split()))
word_dict = {w: i for i, w in enumerate(word_list)}  # word -> index
number_dict = {i: w for w, i in word_dict.items()}  # index -> word
n_class = len(word_dict)  # vocabulary size (one-hot width / number of classes)

# --- TextLSTM parameters ---
batch_size = len(sentences)  # the whole corpus is trained as a single batch
n_hidden = 5  # LSTM hidden-state size

def make_batch(sentences, vocab=None):
    """Build one-hot inputs and integer targets for next-word prediction.

    Each sentence is split on whitespace. Every word except the last forms the
    input context, and the last word is the prediction target.

    Args:
        sentences: Iterable of whitespace-separated sentences. Every word must
            be a key of ``vocab``.
        vocab: Mapping from word to integer index. Defaults to the module-level
            ``word_dict``. The one-hot width is ``len(vocab)``.

    Returns:
        ``(input_batch, target_batch)``: a list holding one float ndarray of
        shape ``(len(words) - 1, len(vocab))`` per sentence, and a list of the
        target word indices.

    Raises:
        KeyError: if a word is missing from ``vocab``.
    """
    if vocab is None:
        vocab = word_dict
    # Rows of the identity matrix are the one-hot vectors; build it once
    # rather than once per sentence.
    one_hot = np.eye(len(vocab))

    input_batch = []
    target_batch = []
    for sentence in sentences:
        words = sentence.split()
        context_ids = [vocab[w] for w in words[:-1]]
        input_batch.append(one_hot[context_ids])  # one-hot encoding of the context
        target_batch.append(vocab[words[-1]])
    return input_batch, target_batch

input_batch, target_batch = make_batch(sentences)
# Stack into a single ndarray first: torch.tensor on a list of ndarrays is slow
# (PyTorch warns about it). The inputs are fixed data, not trainable, so they
# need no gradient.
input_batch = torch.tensor(np.array(input_batch), dtype=torch.float32)
target_batch = torch.tensor(target_batch, dtype=torch.int64)

# B. LSTMの構築

In [115]:
class TextLSTM(nn.Module):
    """Single-layer LSTM that predicts the last word of a sentence from its context.

    The network reads a one-hot encoded context sequence and projects the
    output at the final time step onto unnormalized class scores (logits)
    over the vocabulary.

    Args:
        num_classes: Vocabulary size (one-hot input width and number of output
            classes). Defaults to the module-level ``n_class``.
        hidden_size: LSTM hidden-state width. Defaults to the module-level
            ``n_hidden``.
    """

    def __init__(self, num_classes=None, hidden_size=None):
        super().__init__()
        if num_classes is None:
            num_classes = n_class
        if hidden_size is None:
            hidden_size = n_hidden

        # No ``dropout`` argument: nn.LSTM only applies dropout between stacked
        # layers, so with the default num_layers=1 it was a no-op that merely
        # emitted a UserWarning.
        self.lstm = nn.LSTM(input_size=num_classes, hidden_size=hidden_size)
        # Output projection: last LSTM output -> vocabulary logits.
        # Default float dtype keeps these consistent with the LSTM's weights.
        self.W = nn.Parameter(torch.randn(hidden_size, num_classes))
        self.b = nn.Parameter(torch.randn(num_classes))
        # Not used by forward(): nn.CrossEntropyLoss consumes raw logits.
        # Kept so existing references to ``model.Softmax`` keep working.
        self.Softmax = nn.Softmax(dim=1)

    def forward(self, hidden_and_cell, X):
        """Return vocabulary logits for the word following each context sequence.

        Args:
            hidden_and_cell: ``(h_0, c_0)`` initial LSTM states, passed straight
                to ``nn.LSTM``; the calling code builds each as
                ``(1, batch, hidden_size)`` zeros.
            X: Batch-first input of shape ``(batch, seq_len, num_classes)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` with unnormalized scores.
        """
        # nn.LSTM is constructed with batch_first=False, so move time to dim 0.
        X = X.transpose(0, 1)
        outputs, _ = self.lstm(X, hidden_and_cell)
        last_step = outputs[-1]  # (batch, hidden_size): output at the final time step
        return torch.mm(last_step, self.W) + self.b

# C. 学習

In [118]:
model = TextLSTM()
criterion = nn.CrossEntropyLoss()  # takes raw logits, applies log-softmax internally
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Initial LSTM states are all-zero constants that nothing here modifies, so
# build them once. They need no gradient: only the model's parameters are
# optimized.
hidden = torch.zeros(1, batch_size, n_hidden)
cell = torch.zeros(1, batch_size, n_hidden)

for epoch in range(500):
    output = model((hidden, cell), input_batch)
    loss = criterion(output, target_batch)

    if (epoch + 1) % 100 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.item()))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

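UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.3 and num_layers=1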


Epoch: 0100 cost = 0.230437
Epoch: 0200 cost = 0.031391
Epoch: 0300 cost = 0.013321
Epoch: 0400 cost = 0.007688
Epoch: 0500 cost = 0.005116


# D. 検証

In [119]:
# Predict the final word of every training sentence from its context words.
model.eval()  # inference mode
hidden = torch.zeros(1, batch_size, n_hidden)
cell = torch.zeros(1, batch_size, n_hidden)
with torch.no_grad():  # inference only: no autograd graph needed
    logits = model((hidden, cell), input_batch)
# (batch, 1) index of the highest-scoring word for each sentence.
predict = logits.argmax(dim=1, keepdim=True)
# squeeze(1), not squeeze(): stays 1-D (iterable) even when batch_size == 1.
# [:-1] mirrors how make_batch splits context from target.
print([sen.split()[:-1] for sen in sentences], '->', [number_dict[n.item()] for n in predict.squeeze(1)])

[['i', 'like'], ['i', 'love'], ['i', 'hate'], ['you', 'like'], ['you', 'love'], ['you', 'hate']] -> ['dogs', 'coffee', 'milk', 'cats', 'milk', 'coffee']


In [120]:
# Print every trainable tensor, in the order model.parameters() yields them.
for weight in model.parameters():
    print(weight)

Parameter containing:
tensor([[-2.4433,  1.6995, -0.6711, -1.3300, -3.4730,  2.2199, -1.1504,  2.5870,
         -1.0595],
        [-0.8545, -0.8071, -0.5162,  2.5201, -1.1936, -2.3645,  0.1325,  2.4949,
         -0.8585],
        [ 1.4986, -3.7473,  0.1965, -1.1459,  1.3352,  1.8019,  2.3007,  1.5746,
          0.6256],
        [-1.3062,  2.1798,  0.8306, -2.7042, -1.2823, -2.3491, -0.8534,  1.6816,
          0.9794],
        [-0.8011,  1.5705, -2.0520,  1.4748, -1.0428, -2.8994, -0.8618, -2.3974,
         -0.7201]], requires_grad=True)
Parameter containing:
tensor([ 0.6973, -1.2723,  0.8722,  0.7686,  0.4770,  0.2070,  0.1465,  0.2058,
         0.1060], requires_grad=True)
Parameter containing:
tensor([[ 1.2907, -0.0568,  1.1087, -0.1635, -0.1776,  0.2999, -0.6543,  0.2586,
          0.4940],
        [ 0.5337,  0.0792,  0.8622, -0.3191,  0.7686,  0.1982,  0.7354,  0.1778,
          0.2376],
        [ 0.7075,  0.1099,  0.2279, -0.1378,  1.0445, -0.1929,  0.3203,  0.1997,
         -0.58

In [121]:
# Print only the size of each trainable tensor (same order as model.parameters()).
for tensor in model.parameters():
    print(tensor.size())

torch.Size([5, 9])
torch.Size([9])
torch.Size([20, 9])
torch.Size([20, 5])
torch.Size([20])
torch.Size([20])
