In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset,TensorDataset,DataLoader
dtype=torch.FloatTensor

In [2]:
sentences=['i like cat','i love coffee','i hate milk']
sentences_list=" ".join(sentences).split()
vocab=list(set(sentences_list))
word2idx={w:i for i,w in enumerate(vocab)}
idx2word={i:w for i,w in enumerate(vocab)}
V=len(vocab)

In [4]:
def make_data(sentences):
    X=[]
    y=[]
    for sen in sentences:
        sen=sen.split()
        X_tmp=[word2idx[w] for w in sen[:-1]]
        y_tmp=word2idx[sen[-1]]

        X.append(X_tmp)
        y.append(y_tmp)
    return X,y

In [5]:
X,y=make_data(sentences)
X,y=torch.LongTensor(X),torch.LongTensor(y)
dataset=TensorDataset(X,y)
dataloader=DataLoader(dataset,2,True)
m=2
n_step=2
n_hidden=10

In [20]:
class NNLM(nn.Module):
    def __init__(self):
        super(NNLM,self).__init__()
        self.C=nn.Embedding(V,m)
        self.H=nn.Parameter(torch.randn(n_step*m,n_hidden).type(dtype))
        self.d=nn.Parameter(torch.randn(n_hidden).type(dtype))
        
        self.U=nn.Parameter(torch.randn(n_hidden,V).type(dtype))
        self.b=nn.Parameter(torch.randn(V).type(dtype))
        self.W=nn.Parameter(torch.randn(n_step*m,V).type(dtype))
    def forward(self,X):
        X=self.C(X)
        X=X.view(-1,n_step*m)
        hidden_out=torch.tanh(self.d+torch.mm(X,self.H))
        y=self.b+torch.mm(X,self.W)+torch.mm(hidden_out,self.U)
        return y
model=NNLM()
optim=torch.optim.Adam(model.parameters(),lr=1e-3)
criterion=nn.CrossEntropyLoss()

In [22]:
for epoch in range(500):
    for batch_x,batch_y in dataloader:
        pred=model(batch_x)
        loss=criterion(pred,batch_y)

        if (epoch+1) %100 ==0:
            print(epoch+1,loss.item())
        optim.zero_grad()
        loss.backward()
        optim.step()

100 0.4558974802494049
100 2.7638697624206543
200 0.37769126892089844
200 0.17931468784809113
300 0.07017400860786438
300 0.1931810975074768
400 0.06138978898525238
400 0.05610883608460426
500 0.030168922618031502
500 0.054158832877874374


In [40]:
pred=model(X)
indices=pred.max(1,keepdim=True)[1]
print(indices)
print(indices.squeeze())
print([index.item() for index in indices.squeeze()])
print([idx2word[index.item()] for index in indices.squeeze()])

tensor([[4],
        [2],
        [6]])
tensor([4, 2, 6])
[4, 2, 6]
['cat', 'coffee', 'milk']


In [43]:
pred=model(X).max(1)[1]
pred

tensor([4, 2, 6])