# Word RNN (using Embedding)

In [19]:
import torch
import torch.nn as nn
import torch.optim as optim

In [20]:
sentence = "Repeat is the best medicine for memory".split()

In [21]:
vocab = list(set(sentence))
print(vocab)

['best', 'medicine', 'Repeat', 'for', 'memory', 'is', 'the']


In [22]:
word2index = {tkn : i for i,tkn in enumerate(vocab, 1)}
word2index['<unk>'] = 0

In [23]:
print(word2index)

{'best': 1, 'medicine': 2, 'Repeat': 3, 'for': 4, 'memory': 5, 'is': 6, 'the': 7, '<unk>': 0}


In [24]:
index2word = {v:k for k,v in word2index.items()}
print(index2word)

{1: 'best', 2: 'medicine', 3: 'Repeat', 4: 'for', 5: 'memory', 6: 'is', 7: 'the', 0: '<unk>'}


In [25]:
def build_data(sentence, word2index):
    encoded = [word2index[token] for token in sentence]
    input_seq, label_seq = encoded[:-1], encoded[1:]
    input_seq = torch.LongTensor(input_seq).unsqueeze(0)
    label_seq = torch.LongTensor(label_seq).unsqueeze(0)

    return input_seq, label_seq

In [26]:
X, Y = build_data(sentence, word2index)

In [27]:
print(X.shape)
print(Y.shape)

torch.Size([1, 6])
torch.Size([1, 6])


In [28]:
class Net(nn.Module):
    def __init__(self, vocab_size, input_size, hidden_size, batch_first=True):
        super(Net, self).__init__()
        self.embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=input_size)
        self.rnn_layer = nn.RNN(input_size, hidden_size, batch_first=batch_first)
        self.linear = nn.Linear(hidden_size, vocab_size)
    def forward(self, x):
        output = self.embedding_layer(x)
        output, hidden = self.rnn_layer(output)
        output = self.linear(output)

        return output.view(-1, output.size(2))

In [29]:
vocab_size = len(word2index)

input_size = 5
hidden_size = 20

In [30]:
model = Net(vocab_size=vocab_size, input_size=input_size, hidden_size=hidden_size, batch_first=True)
loss_function = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters())

In [31]:
output = model(X)
print(output.shape)

torch.Size([6, 8])


In [32]:
decode = lambda y: [index2word.get(x) for x in y]

In [33]:
for step in range(201):
    optimizer.zero_grad()
    output = model(X)
    loss = loss_function(output, Y.view(-1))
    loss.backward()
    optimizer.step()

    if step % 40 == 0:
        pred = output.softmax(-1).argmax(-1).tolist()
        print(["Repeat"]+decode(pred))

['Repeat', 'Repeat', 'Repeat', '<unk>', '<unk>', 'medicine', 'for']
['Repeat', 'is', 'memory', 'best', 'medicine', 'for', 'memory']
['Repeat', 'is', 'the', 'best', 'medicine', 'for', 'memory']
['Repeat', 'is', 'the', 'best', 'medicine', 'for', 'memory']
['Repeat', 'is', 'the', 'best', 'medicine', 'for', 'memory']
['Repeat', 'is', 'the', 'best', 'medicine', 'for', 'memory']
