In [4]:
import numpy as np
import torch
import torch.nn as nn

def word_hot_encoding(sentences, word_dict=None):
    """One-hot encode every word of every sentence.

    Sentences are lower-cased and split on whitespace.

    Args:
        sentences: Iterable of strings. All sentences must contain the same
            number of words so the result can be stacked into one array.
        word_dict: Optional mapping ``word -> column index``. Indices must be
            the integers ``0 .. len(word_dict) - 1``. When omitted, a
            vocabulary is built from ``sentences`` exactly as before (its
            order follows ``set`` iteration order, which is not guaranteed to
            be stable between interpreter sessions). Pass an explicit mapping
            when the same indices are needed elsewhere, e.g. for decoding.

    Returns:
        A float ``numpy`` array of shape
        ``(n_sentences, n_words, vocabulary_size)``.

    Raises:
        ValueError: If the sentences do not all have the same word count.
        KeyError: If ``word_dict`` is given and a word is missing from it.
    """
    if word_dict is None:
        word_list = list(set(" ".join(sentences).lower().split()))
        word_dict = {w: i for i, w in enumerate(word_list)}
    list_size = len(word_dict)

    tokenized = [s.lower().split() for s in sentences]
    lengths = {len(words) for words in tokenized}
    if len(lengths) > 1:
        # np.array() on ragged input either fails with an obscure message or
        # silently yields an object array, depending on the NumPy version.
        raise ValueError(
            "all sentences must have the same number of words, got lengths {}".format(
                sorted(lengths)
            )
        )

    # Row i of the identity matrix is the one-hot vector for index i.
    identity = np.eye(list_size)
    s_array = [identity[[word_dict[word] for word in words]] for words in tokenized]

    return np.array(s_array)

class Network(torch.nn.Module):
    """Single-layer RNN followed by a linear read-out of the last time step.

    Args:
        vocab_size: Size of each input feature vector and number of output
            logits. Defaults to 9, the value that used to be hard-coded.
        hidden_size: Size of the RNN hidden state. Defaults to 5, the value
            that used to be hard-coded.
    """

    def __init__(self, vocab_size=9, hidden_size=5):
        super().__init__()
        # batch_first=True: inputs are indexed as (batch, step, feature).
        self.rnn = nn.RNN(vocab_size, hidden_size, batch_first=True)
        # Kept as a one-element Sequential so parameter names (seq.0.*) stay
        # the same as before.
        self.seq = nn.Sequential(
            nn.Linear(hidden_size, vocab_size),
        )

    def forward(self, x):
        """Return one row of ``vocab_size`` logits per batch element.

        The logits are computed from the RNN output at the last step of the
        sequence; outputs at earlier steps and the final hidden state are
        discarded.
        """
        outputs, _ = self.rnn(x)
        return self.seq(outputs[:, -1, :])

# Seed torch's global RNG so the initial weights are the same on every run.
# NOTE: the vocabulary order below still comes from set iteration, which can
# differ between interpreter sessions, so runs are not bit-for-bit identical.
torch.manual_seed(0)

F = Network()
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(F.parameters(), lr=0.5)


sentences = ["I like dog", "I love coffee", "I hate milk", "You like cat", "You love milk", "You hate coffee"]
# Built with the same expression word_hot_encoding uses internally, so these
# indices line up with the one-hot columns it returns. Keep the two in sync.
word_list = list(set(" ".join(sentences).lower().split()))
word_dict = {w: i for i, w in enumerate(word_list)}
number_dict = {i: w for i, w in enumerate(word_list)}

print(word_list)
print(word_dict)

s_array = word_hot_encoding(sentences)

# The first two words of each sentence are the input; the third is the label.
x = torch.tensor(s_array[:, :2, :], dtype=torch.float)
t = torch.tensor(s_array[:, 2, :], dtype=torch.long)
# The loss is fed class indices, so convert the one-hot labels once here
# instead of on every training step.
targets = t.argmax(dim=1)

epoch = 200
for e in range(epoch):
    loss_sum = 0.0
    # One optimizer step per sentence (batch size 1).
    for b in range(x.shape[0]):
        y = F(x[b:b+1])

        loss = loss_function(y, targets[b:b+1])
        # .item() yields a plain float, so the running total does not keep
        # each step's autograd graph alive.
        loss_sum += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if (e + 1) % 10 == 0:
        print("epoch {} | loss {}".format(e + 1, loss_sum))


['dog', 'cat', 'i', 'milk', 'coffee', 'love', 'like', 'you', 'hate']
{'dog': 0, 'cat': 1, 'i': 2, 'milk': 3, 'coffee': 4, 'love': 5, 'like': 6, 'you': 7, 'hate': 8}
epoch 10 | loss 6.359657287597656
epoch 20 | loss 2.753444194793701
epoch 30 | loss 0.2771744728088379
epoch 40 | loss 0.15811550617218018
epoch 50 | loss 0.1107105016708374
epoch 60 | loss 0.0847846195101738
epoch 70 | loss 0.06820619106292725
epoch 80 | loss 0.05649149417877197
epoch 90 | loss 0.047624897211790085
epoch 100 | loss 0.04064047336578369
epoch 110 | loss 0.035064566880464554
epoch 120 | loss 0.030606824904680252
epoch 130 | loss 0.027031967416405678
epoch 140 | loss 0.024139901623129845
epoch 150 | loss 0.02177196554839611
epoch 160 | loss 0.019808536395430565
epoch 170 | loss 0.01815774105489254
epoch 180 | loss 0.016754858195781708
epoch 190 | loss 0.015548719093203545
epoch 200 | loss 0.014501885510981083


In [5]:
# Predict the third word for every two-word prefix in `x` (one row per
# training sentence) and print the decoded word.
with torch.no_grad():  # inference only: no autograd graph is needed
    result = F(x)
result_arg = torch.argmax(result, dim=1)
for word_index in result_arg.tolist():
    print(number_dict[word_index])


dog
coffee
milk
cat
milk
coffee
