In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data 
import torch.optim as optim
dtype = torch.FloatTensor  # in this view, referenced only by the commented-out nn.Parameter variant of NNLM

In [2]:
# Toy corpus: the first two words of each sentence are the context, the last is the target.
sentences = ['i like cat', 'i love coffee', 'i hate milk']
sentences_list = " ".join(sentences).split()  # ['i', 'like', 'cat', 'i', 'love', 'coffee', ...]
# Sort the unique tokens so the word -> index mapping is the same on every run;
# the iteration order of a bare set of strings changes between interpreter
# sessions because of string hash randomization.
vocab = sorted(set(sentences_list))
word2idx = {w: i for i, w in enumerate(vocab)}
# Built as the exact inverse of word2idx so the two tables can never disagree.
idx2word = {i: w for w, i in word2idx.items()}
vocab_lens = len(vocab)  # number of distinct words

In [3]:
def make_data(sentences):
    """Turn sentences into (context indices, target index) training pairs.

    Each sentence is split on whitespace. Every word except the last is
    looked up in the module-level ``word2idx`` table to form the input row;
    the last word's index becomes the target.

    Args:
        sentences: iterable of whitespace-separated sentence strings.

    Returns:
        A tuple ``(input_data, target_data)``: a list of index lists and a
        list of target indices, both in sentence order.
    """
    tokenized = [sen.split() for sen in sentences]
    input_data = [[word2idx[word] for word in words[:-1]] for words in tokenized]
    target_data = [word2idx[words[-1]] for words in tokenized]
    return input_data, target_data

In [4]:
# Convert the index lists from make_data into LongTensors (integer indices
# for the embedding lookup and for the cross-entropy targets).
input_data, target_data = map(torch.LongTensor, make_data(sentences))
# Pair each context row with its target and serve shuffled mini-batches of 2.
dataset = Data.TensorDataset(input_data, target_data)
dataloader = Data.DataLoader(dataset=dataset, batch_size=2, shuffle=True)

dataset

<torch.utils.data.dataset.TensorDataset at 0x24ddda89e08>

In [5]:
# Model hyperparameters.
n_step = 2  # input sequence length: number of context words per sample
embedding_dim = 2  # size of each word embedding vector


In [6]:
class NNLM(nn.Module):
    """Feed-forward neural language model that scores the next word.

    The context-word embeddings are concatenated into one vector ``x`` and the
    output logits are ``linear2(x) + linear3(tanh(linear1(x)))``: a direct
    linear path from the embeddings plus a path through one tanh hidden layer.

    Args:
        vocab_size: number of words; defaults to the module-level ``vocab_lens``.
        embed_dim: embedding width; defaults to the module-level ``embedding_dim``.
        context_len: context words per sample; defaults to the module-level ``n_step``.
        n_hidden: width of the hidden layer (64, as in the original definition).
    """

    def __init__(self, vocab_size=None, embed_dim=None, context_len=None, n_hidden=64):
        super().__init__()
        # Fall back to the notebook-level settings so a bare NNLM() builds the
        # same architecture as before; explicit arguments make the class
        # usable without those globals.
        vocab_size = vocab_lens if vocab_size is None else vocab_size
        embed_dim = embedding_dim if embed_dim is None else embed_dim
        context_len = n_step if context_len is None else context_len
        # Width of the concatenated context embeddings fed to the linear layers.
        self.flat_dim = context_len * embed_dim

        self.table_lookup = nn.Embedding(vocab_size, embed_dim)
        self.linear1 = nn.Linear(self.flat_dim, n_hidden)
        self.linear2 = nn.Linear(self.flat_dim, vocab_size)
        # No bias here: linear2 already contributes the output bias term.
        self.linear3 = nn.Linear(n_hidden, vocab_size, bias=False)
        # Alternatively, define a parameter matrix directly:
        # self.U = nn.Parameter(torch.randn(n_hidden, V).type(dtype))

    def forward(self, x):
        """Return next-word logits of shape [batch_size, vocab_size].

        Args:
            x: LongTensor of word indices, shape [batch_size, n_step].
        """
        x = self.table_lookup(x)  # [batch_size, n_step, embedding_dim]
        x = x.view(-1, self.flat_dim)  # [batch_size, n_step * embedding_dim]
        # torch.tanh replaces the F.tanh alias, matching the reference
        # implementation kept in the comments of this notebook.
        hidden_out = torch.tanh(self.linear1(x))
        out = self.linear2(x) + self.linear3(hidden_out)
        return out
'''
1000 0.001605809316970408
1000 0.001312824198976159
2000 0.00028981463401578367
2000 0.0002321927313460037
3000 8.302579226437956e-05
3000 6.878139538457617e-05
4000 2.443760422465857e-05
4000 2.825220326485578e-05
5000 8.404219443036709e-06
5000 9.65590606938349e-06'''
# class NNLM(nn.Module):
#   def __init__(self):
#     super(NNLM, self).__init__()
#     self.C = nn.Embedding(vocab_lens, embedding_dim)
#     self.H = nn.Parameter(torch.randn(n_step * embedding_dim, 64).type(dtype))
#     self.d = nn.Parameter(torch.randn(64).type(dtype))
#     self.b = nn.Parameter(torch.randn(vocab_lens).type(dtype))
#     self.W = nn.Parameter(torch.randn(n_step * embedding_dim, vocab_lens).type(dtype))
#     self.U = nn.Parameter(torch.randn(64, vocab_lens).type(dtype))

#   def forward(self, X):
#     '''
#     X : [batch_size, n_step]
#     '''
#     X = self.C(X) # [batch_size, n_step, m]
#     X = X.view(-1, n_step * embedding_dim) # [batch_size, n_step * m]
#     hidden_out = torch.tanh(self.d + torch.mm(X, self.H)) # [batch_size, n_hidden]
#     output = self.b + torch.mm(X, self.W) + torch.mm(hidden_out, self.U)
#     return output
'''
1000 0.0013060837518423796
1000 0.0014915067004039884
2000 0.0002582333399914205
2000 0.00032884435495361686
3000 8.779368363320827e-05
3000 7.652943895664066e-05
4000 2.4616414521005936e-05
4000 3.5523738915799186e-05
5000 1.043075644702185e-05
5000 8.821448318485636e-06
'''
model = NNLM()
# Use the fully qualified torch.optim.Adam: the assignment below rebinds the
# name `optim` (imported at the top as the torch.optim module) to the optimizer
# instance, so writing `optim.Adam` here would raise AttributeError as soon as
# this cell is executed a second time. The name `optim` is kept because the
# training loop calls optim.step() / optim.zero_grad().
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
# Expects raw logits and integer class targets.
criterion = nn.CrossEntropyLoss()


In [7]:
# Train for 5000 epochs; every mini-batch triggers one optimizer update.
for epoch in range(5000):
    # Decided once per epoch: true on epochs 1000, 2000, ..., 5000.
    report = (epoch + 1) % 1000 == 0
    for datax, datay in dataloader:
        pre = model(datax)  # logits for this mini-batch
        loss = criterion(pre, datay)
        loss.backward()
        optim.step()
        optim.zero_grad()  # clear gradients before the next batch

        # Logged inside the batch loop, so a reporting epoch prints one line
        # per mini-batch.
        if report:
            print(epoch + 1, loss.item())
        
        



1000 0.001605809316970408
1000 0.001312824198976159
2000 0.00028981463401578367
2000 0.0002321927313460037
3000 8.302579226437956e-05
3000 6.878139538457617e-05
4000 2.443760422465857e-05
4000 2.825220326485578e-05
5000 8.404219443036709e-06
5000 9.65590606938349e-06


In [8]:
# Predict the next word for every training context.
with torch.no_grad():  # inference only, so skip autograd bookkeeping
    # Index of the highest logit per row; keepdim keeps the shape [num_rows, 1].
    pred = model(input_data).argmax(dim=1, keepdim=True)
# squeeze(1) removes only the kept class axis, so a single-row input still
# gives an iterable 1-D tensor; a bare squeeze() would collapse it to 0-d and
# the iteration below would fail.
print([idx2word[idx.item()] for idx in pred.squeeze(1)])

['cat', 'coffee', 'milk']
