In [24]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset,TensorDataset,DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
dtype=torch.FloatTensor

In [25]:
sentences=['i like cat','i love coffee','i hate milk','i love you']
sentences_list=" ".join(sentences).split()
vocab=list(set(sentences_list))
word2idx={w:i for i,w in enumerate(vocab)}
idx2word={i:w for i,w in enumerate(vocab)}
V=len(vocab)

In [26]:
def make_data(sentences):
    X=[]
    y=[]
    for sen in sentences:
        sen=sen.split()
        X_tmp=[word2idx[w] for w in sen[:-1]]
        y_tmp=word2idx[sen[-1]]

        X.append(X_tmp)
        y.append(y_tmp)
    return X,y

In [31]:
X,y=make_data(sentences)
X,y=torch.LongTensor(X),torch.LongTensor(y)
print(X,y)
dataset=TensorDataset(X,y)
dataloader=DataLoader(dataset,2,True)
m=2
n_step=2
n_hidden=10

tensor([[0, 1],
        [0, 2],
        [0, 7],
        [0, 2]]) tensor([5, 3, 6, 4])


In [29]:
class NNLM(nn.Module):
    def __init__(self):
        super(NNLM,self).__init__()
        self.C=nn.Embedding(V,m)
        self.H=nn.Parameter(torch.randn(n_step*m,n_hidden).type(dtype))
        self.d=nn.Parameter(torch.randn(n_hidden).type(dtype))
        
        self.U=nn.Parameter(torch.randn(n_hidden,V).type(dtype))
        self.b=nn.Parameter(torch.randn(V).type(dtype))
        self.W=nn.Parameter(torch.randn(n_step*m,V).type(dtype))
    def forward(self,X):
        X=self.C(X)
        X=X.view(-1,n_step*m)
        hidden_out=torch.tanh(self.d+torch.mm(X,self.H))
        y=self.b+torch.mm(X,self.W)+torch.mm(hidden_out,self.U)
        return y
model=NNLM()
optim=torch.optim.Adam(model.parameters(),lr=1e-3)
criterion=nn.CrossEntropyLoss()

In [30]:
step=0
writer=SummaryWriter()
for epoch in range(500):
    loss_record=[]
    for batch_x,batch_y in tqdm(dataloader,position=0):
        pred=model(batch_x)
        loss=criterion(pred,batch_y)

        # if (epoch+1) %100 ==0:
        #     print(epoch+1,loss.item())
        optim.zero_grad()
        loss.backward()
        optim.step()
        loss_record.append(loss.item())
        step+=1
    mean_train_loss=sum(loss_record)/len(loss_record)
    writer.add_scalar('TrainLoss',mean_train_loss,step)

100%|██████████| 2/2 [00:00<00:00, 52.54it/s]
100%|██████████| 2/2 [00:00<00:00, 250.28it/s]
100%|██████████| 2/2 [00:00<00:00, 250.23it/s]
100%|██████████| 2/2 [00:00<00:00, 500.78it/s]
100%|██████████| 2/2 [00:00<00:00, 499.02it/s]
100%|██████████| 2/2 [00:00<00:00, 166.90it/s]
100%|██████████| 2/2 [00:00<00:00, 249.18it/s]
100%|██████████| 2/2 [00:00<00:00, 166.75it/s]
100%|██████████| 2/2 [00:00<00:00, 249.97it/s]
100%|██████████| 2/2 [00:00<00:00, 502.25it/s]
100%|██████████| 2/2 [00:00<00:00, 250.36it/s]
100%|██████████| 2/2 [00:00<00:00, 250.29it/s]
100%|██████████| 2/2 [00:00<00:00, 250.65it/s]
100%|██████████| 2/2 [00:00<00:00, 250.35it/s]
100%|██████████| 2/2 [00:00<00:00, 250.74it/s]
100%|██████████| 2/2 [00:00<00:00, 250.30it/s]
100%|██████████| 2/2 [00:00<00:00, 166.87it/s]
100%|██████████| 2/2 [00:00<00:00, 166.90it/s]
100%|██████████| 2/2 [00:00<00:00, 249.92it/s]
100%|██████████| 2/2 [00:00<00:00, 250.27it/s]
100%|██████████| 2/2 [00:00<00:00, 193.80it/s]
100%|█████████

In [40]:
pred=model(X)
indices=pred.max(1,keepdim=True)[1]
print(indices)
print(indices.squeeze())
print([index.item() for index in indices.squeeze()])
print([idx2word[index.item()] for index in indices.squeeze()])

tensor([[4],
        [2],
        [6]])
tensor([4, 2, 6])
[4, 2, 6]
['cat', 'coffee', 'milk']


In [43]:
pred=model(X).max(1)[1]
pred

tensor([4, 2, 6])