In [30]:
import torch
from torch import nn
from collections import defaultdict

class Char_lstm(nn.Module):
    def __init__(self,n_char,char_dim,char_hidden):
        super(Char_lstm,self).__init__()
        self.emb=nn.Embedding(n_char,char_dim)
        self.lstm=nn.LSTM(char_dim,char_hidden)
    def forward(self,x):
        x=self.emb(x)
        out,_=self.lstm(x)
        return out[-1,:,:]

def make_vector(x,w2i):
    idx=[w2i[i.lower()] for i in x]
    idx=torch.LongTensor(idx)
    return idx
class Lstm_tagger(nn.Module):
    def __init__(self,n_word,n_char,char_dim,word_dim,char_hidden,word_hidden,n_tag):
        super(Lstm_tagger,self).__init__()
        self.char_lstm=Char_lstm(n_char,char_dim,char_hidden)
        self.word_embed=nn.Embedding(n_word,word_dim)
        self.word_lstm=nn.LSTM(word_dim+char_hidden,word_hidden)
        self.classify=nn.Linear(word_hidden,n_tag)
    def forward(self,word_list,words):
        #word_list和words一致的
        char=[]
        for word in words:#words是每一个单词
            char_list=make_vector(word,char_to_idx)
            char_list=char_list.unsqueeze(1)#batch=1
            char_infor=self.char_lstm(torch.tensor(char_list))
            char.append(char_infor)
        
        char=torch.stack(char,dim=0)
        x=self.word_embed(word_list).unsqueeze(1)
        x=torch.cat((x,char),dim=2)
        x,_=self.word_lstm(x)
        s,b,h=x.shape
        print(b)
        x=x.view(-1,h)
        out=self.classify(x)
        return out

training_data = [("The monkey ate the banana".split(),
                  ["DET", "NN", "V", "DET","NN"]),
                 ("The dog ate the bones".split(), 
                  ["DET", "NN", "V", "DET", "NN"])]
w2i=defaultdict(lambda :len(w2i))
t2i=defaultdict(lambda :len(t2i))
for context,tag in training_data:
    for words in context:
        index=w2i[words.lower()]
    for label in tag:
        index=t2i[label.lower()]
        
        
z = 'abcdefghijklmnopqrstuvwxyz'
char_to_idx={}
for i in range(len(z)):
    char_to_idx[z[i]]=i
#参数分别为：
#单词字典的长度
#字符字典的长度
#字符的emb_size
#单词的emb_size
#字符的lstm的hidden
#单词的lstm的hidden
#分类的类别
net=Lstm_tagger(len(w2i),len(char_to_idx),10, 100, 50, 128, len(t2i))
criterion=nn.CrossEntropyLoss()
optimizer=torch.optim.SGD(net.parameters(),lr=0.01)
for epoch in range(3):
    train_loss=0.0
    for words,tag in training_data:#word,tag是文本和文本标签
        word_list=make_vector(words,w2i)
        tag=make_vector(tag,t2i)
        out=net(word_list,words)
        loss=criterion(out,tag)
        train_loss+=loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print("epoch=%r,loss=%.4f"%(epoch,train_loss/len(training_data)))


net = net.eval()
test_sent = 'The dog ate the banana'
test = make_vector(test_sent.split(), w2i)

out = net(test, test_sent.split())
print(out.max(1)[1].data)



1
1
epoch=0,loss=1.1114
1
1
epoch=1,loss=1.1036
1
1
epoch=2,loss=1.0961
1
tensor([1, 2, 2, 1, 1])
