# Word RNN (using Embedding)

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim

In [3]:
sentence = "Repeat is the best medicine for memory".split()

In [4]:
vocab = list(set(sentence))
print(vocab)

['for', 'the', 'best', 'Repeat', 'is', 'medicine', 'memory']


In [5]:
word2index = {tkn : i for i,tkn in enumerate(vocab, 1)}
word2index['<unk>'] = 0

In [6]:
print(word2index)

{'for': 1, 'the': 2, 'best': 3, 'Repeat': 4, 'is': 5, 'medicine': 6, 'memory': 7, '<unk>': 0}


In [7]:
index2word = {v:k for k,v in word2index.items()}
print(index2word)

{1: 'for', 2: 'the', 3: 'best', 4: 'Repeat', 5: 'is', 6: 'medicine', 7: 'memory', 0: '<unk>'}


In [11]:
def build_data(sentence, word2index):
    encoded = [word2index[token] for token in sentence]
    input_seq, label_seq = encoded[:-1], encoded[1:]
    input_seq = torch.LongTensor(input_seq).unsqueeze(0)
    label_seq = torch.LongTensor(label_seq).unsqueeze(0)

    return input_seq, label_seq

In [12]:
X, Y = build_data(sentence, word2index)
X, Y

(tensor([[4, 5, 2, 3, 6, 1]]), tensor([[5, 2, 3, 6, 1, 7]]))

In [13]:
print(X.shape)
print(Y.shape)

torch.Size([1, 6])
torch.Size([1, 6])


In [14]:
class Net(nn.Module):
    def __init__(self, vocab_size, input_size, hidden_size, batch_first=True):
        super(Net, self).__init__()
        self.embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=input_size)
        self.rnn_layer = nn.RNN(input_size, hidden_size, batch_first=batch_first)
        self.linear = nn.Linear(hidden_size, vocab_size)
    def forward(self, x):
        output = self.embedding_layer(x)
        output, hidden = self.rnn_layer(output)
        output = self.linear(output)

        return output.view(-1, output.size(2))

In [15]:
vocab_size = len(word2index)

input_size = 5
hidden_size = 20

In [16]:
model = Net(vocab_size=vocab_size, input_size=input_size, hidden_size=hidden_size, batch_first=True)
loss_function = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters())

In [17]:
output = model(X)
print(output.shape)

torch.Size([6, 8])


In [20]:
decode = lambda y: [index2word.get(x) for x in y]

In [21]:
for step in range(201):
    optimizer.zero_grad()
    output = model(X)
    loss = loss_function(output, Y.view(-1))
    loss.backward()
    optimizer.step()

    if step % 40 == 0:
        pred = output.softmax(-1).argmax(-1).tolist()
        print(["Repeat"]+decode(pred))

['Repeat', 'medicine', 'medicine', 'medicine', 'is', 'medicine', 'medicine']
['Repeat', 'is', 'the', 'best', 'medicine', 'for', 'memory']
['Repeat', 'is', 'the', 'best', 'medicine', 'for', 'memory']
['Repeat', 'is', 'the', 'best', 'medicine', 'for', 'memory']
['Repeat', 'is', 'the', 'best', 'medicine', 'for', 'memory']
['Repeat', 'is', 'the', 'best', 'medicine', 'for', 'memory']
