In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

# Legacy tensor-type alias for CPU float32 tensors; passing it to `Tensor.type(...)` casts to that type.
dtype = torch.FloatTensor

In [13]:
sentences = ["i like dog", "i love coffee", "i hate milk"]
# Tokenize with str.split() (any run of whitespace) so the vocabulary agrees with
# make_batch's tokenization; split(' ') would turn double spaces into '' tokens.
words = ' '.join(sentences).split()
# Sorted unique tokens: sorting makes the word -> index assignment identical across
# kernel restarts, whereas set() iteration order for strings depends on hash
# randomization (PYTHONHASHSEED).
# NOTE: despite its name, words_dict is a list (index -> word), not a dict.
words_dict = sorted(set(words))
num2word = {index: word for index, word in enumerate(words_dict)}
word2num = {word: index for index, word in enumerate(words_dict)}

n_class = len(words_dict)  # vocabulary size = number of output classes
n_step = 2  # context length fed to the model; each sentence should have n_step + 1 words
n_hidden = 2  # width of the tanh hidden layer
embedding_size = 2  # dimension of each word embedding

def make_batch(sentences, vocab=None):
    """Split each sentence into (context word ids, next-word id).

    Every word except the last becomes the model input; the last word is the
    prediction target.

    Args:
        sentences: iterable of whitespace-separated sentences.
        vocab: optional mapping word -> integer id. Defaults to the notebook's
            global ``word2num``.

    Returns:
        (input_batch, target_batch): a list of id lists (one per sentence) and a
        list of target ids, in the same order as ``sentences``.

    Raises:
        ValueError: if a sentence has fewer than two words (no context or no target).
        KeyError: if a word is not in ``vocab``.
    """
    if vocab is None:
        # Resolved at call time so re-running the vocabulary cell is picked up.
        vocab = word2num
    input_batch = []
    target_batch = []
    for sen in sentences:
        tokens = sen.split()
        if len(tokens) < 2:
            raise ValueError('need at least two words (context + target), got %r' % (sen,))
        *context, last = tokens
        input_batch.append([vocab[w] for w in context])
        target_batch.append(vocab[last])
    return input_batch, target_batch

class NNLM(nn.Module):
    """Feed-forward neural language model.

    Computes next-word logits  y = b + X W + tanh(d + X H) U,  where X is the
    concatenation of the embeddings of the context words.

    Hyperparameters fall back to the notebook globals ``n_class``, ``n_step``,
    ``n_hidden`` and ``embedding_size`` when not passed, so ``NNLM()`` behaves
    as before.
    """

    def __init__(self, vocab_size=None, context_size=None, hidden_size=None, embed_dim=None):
        super(NNLM, self).__init__()
        # Globals are only looked up when an argument is omitted.
        vocab_size = n_class if vocab_size is None else vocab_size
        context_size = n_step if context_size is None else context_size
        hidden_size = n_hidden if hidden_size is None else hidden_size
        embed_dim = embedding_size if embed_dim is None else embed_dim

        # Stored so forward() does not depend on notebook globals.
        self.n_step = context_size
        self.embedding_size = embed_dim
        in_features = context_size * embed_dim

        # Creation order matches the original so a fixed seed yields the same init.
        self.C = nn.Embedding(vocab_size, embed_dim)  # word embedding table
        self.H = nn.Parameter(torch.randn(in_features, hidden_size, dtype=torch.float32))  # input -> hidden
        self.W = nn.Parameter(torch.randn(in_features, vocab_size, dtype=torch.float32))  # direct input -> output
        self.d = nn.Parameter(torch.randn(hidden_size, dtype=torch.float32))  # hidden bias
        self.U = nn.Parameter(torch.randn(hidden_size, vocab_size, dtype=torch.float32))  # hidden -> output
        self.b = nn.Parameter(torch.randn(vocab_size, dtype=torch.float32))  # output bias

    def forward(self, X):
        """Return unnormalized next-word logits, one row per example.

        X: LongTensor of word ids. Embeddings are flattened to
        ``n_step * embedding_size`` features per row, so each example presumably
        holds ``n_step`` ids. TODO confirm for callers passing other shapes.
        """
        X = self.C(X)
        X = X.view(-1, self.n_step * self.embedding_size)
        tanh = torch.tanh(torch.mm(X, self.H) + self.d)
        return torch.mm(tanh, self.U) + torch.mm(X, self.W) + self.b
    
torch.manual_seed(0)  # make weight initialization (and therefore the loss curve) reproducible

model = NNLM()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Plain tensors suffice: autograd.Variable has been a no-op wrapper since PyTorch 0.4.
input_batch, target_batch = make_batch(sentences)
input_batch = torch.tensor(input_batch, dtype=torch.long)  # one row of context word ids per sentence
target_batch = torch.tensor(target_batch, dtype=torch.long)  # one next-word id per sentence

for epoch in range(5000):
    optimizer.zero_grad()
    predict = model(input_batch)  # logits, one row per sentence
    loss = criterion(predict, target_batch)
    if epoch % 100 == 0:
        print('epoch:%d loss:%f' % (epoch, loss.item()))
    loss.backward()
    optimizer.step()

# Test: greedy next-word prediction on the training contexts (no autograd needed).
with torch.no_grad():
    prediction = model(input_batch).argmax(dim=1, keepdim=True)  # (batch, 1) word ids
# squeeze(1) rather than squeeze(): a 1-sentence batch must stay iterable, not become 0-d.
print([sen.split()[:-1] for sen in sentences], '->',
      [num2word[n] for n in prediction.squeeze(1).tolist()])


epoch:0 loss:3.273170
epoch:100 loss:2.161475
epoch:200 loss:1.446672
epoch:300 loss:1.048739
epoch:400 loss:0.818006
epoch:500 loss:0.660088
epoch:600 loss:0.539275
epoch:700 loss:0.439288
epoch:800 loss:0.352862
epoch:900 loss:0.279044
epoch:1000 loss:0.218977
epoch:1100 loss:0.172366
epoch:1200 loss:0.137162
epoch:1300 loss:0.110754
epoch:1400 loss:0.090818
epoch:1500 loss:0.075571
epoch:1600 loss:0.063727
epoch:1700 loss:0.054377
epoch:1800 loss:0.046883
epoch:1900 loss:0.040791
epoch:2000 loss:0.035773
epoch:2100 loss:0.031592
epoch:2200 loss:0.028071
epoch:2300 loss:0.025078
epoch:2400 loss:0.022512
epoch:2500 loss:0.020296
epoch:2600 loss:0.018367
epoch:2700 loss:0.016680
epoch:2800 loss:0.015194
epoch:2900 loss:0.013880
epoch:3000 loss:0.012712
epoch:3100 loss:0.011669
epoch:3200 loss:0.010735
epoch:3300 loss:0.009894
epoch:3400 loss:0.009136
epoch:3500 loss:0.008450
epoch:3600 loss:0.007827
epoch:3700 loss:0.007260
epoch:3800 loss:0.006744
epoch:3900 loss:0.006271
epoch:4000 l