In [21]:
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.autograd import Variable

In [25]:
training_data = [("The dog ate the apple".split(),
                  ["DET", "NN", "V", "DET", "NN"]),
                 ("Everybody read that book".split(), ["NN", "V", "DET",
                                                       "NN"])]

word_to_idx = {}
tag_to_idx = {}
for context, tag in training_data:
    for word in context:
        if word not in word_to_idx:
            word_to_idx[word] = len(word_to_idx)
    for label in tag:
        if label not in tag_to_idx:
            tag_to_idx[label] = len(tag_to_idx)
alphabet = 'abcdefghijklmnopqrstuvwxyz'
character_to_idx = {}
for i in range(len(alphabet)):
    character_to_idx[alphabet[i]] = i

In [41]:
class CharLSTM(nn.Module):
    """
    对每个词的字符进行嵌入并学习得到一个结果
    """
    def __init__(self, n_char, char_dim, char_hidden):
        super(CharLSTM, self).__init__()
        self.char_embedding = nn.Embedding(n_char, char_dim)
        self.char_lstm = nn.LSTM(char_dim, char_hidden, batch_first=True)
        
    def forward(self, x):
        x = self.char_embedding(x)
        _, h = self.char_lstm(x)
        
        return h[0]  # 返回一个词的字符输出数据

In [30]:
class LSTMTagger(nn.Module):
    def __init__(self, n_word, n_char, char_dim, n_dim, char_hidden, n_hidden, n_tag):
        super(LSTMTagger, self).__init__()
        self.word_embedding = nn.Embedding(n_word, n_dim)
        self.char_lstm = CharLSTM(n_char, char_dim, char_hidden)
        self.lstm = nn.LSTM(n_dim + char_hidden, n_hidden, batch_first = True)
        self.linear1 = nn.Linear(n_hidden, n_tag)
        
    def forward(self, x, word):
        char = torch.FloatTensor()
        for each in word:
            char_list = []
            for letter in each:
                char_list.append(character_to_idx[letter.lower()])
            char_list = torch.LongTensor(char_list)
            char_list = char_list.unsqueeze(0)
            if torch.cuda.is_available():
                tempchar = self.char_lstm(Variable(char_list).cuda())
            else:
                tempchar = self.char_lstm(Variable(char_list))
            tempchar = tempchar.squeeze(0)
            char = torch.cat((char, tempchar.cpu().data), 0)
        if torch.cuda.is_available():
            char = char.cuda()
        char = Variable(char)
        x = self.word_embedding(x)
        x = torch.cat((x, char), 1)
        x = x.unsqueeze(0)
        x, _ = self.lstm(x)
        x = x.squeeze(0)
        x = self.linear1(x)
        y = F.log_softmax(x)
        return y

In [42]:
model = LSTMTagger(
    len(word_to_idx), len(character_to_idx), 10, 100, 50, 128, len(tag_to_idx))
if torch.cuda.is_available():
    model = model.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-2)


def make_sequence(x, dic):
    idx = [dic[i] for i in x]
    idx = Variable(torch.LongTensor(idx))
    return idx


for epoch in range(300):
    print('*' * 10)
    print('epoch {}'.format(epoch + 1))
    running_loss = 0
    for data in training_data:
        word, tag = data
        word_list = make_sequence(word, word_to_idx)
        tag = make_sequence(tag, tag_to_idx)
        if torch.cuda.is_available():
            word_list = word_list.cuda()
            tag = tag.cuda()
        # forward
        out = model(word_list, word)
        loss = criterion(out, tag)
        running_loss += loss.data[0]
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('Loss: {}'.format(running_loss / len(data)))
print()
input = make_sequence("Everybody ate the apple".split(), word_to_idx)
if torch.cuda.is_available():
    input = input.cuda()

out = model(input, "Everybody ate the apple".split())
print(out)

**********
epoch 1
Loss: 1.0942506790161133
**********
epoch 2
Loss: 1.0895805358886719
**********
epoch 3
Loss: 1.0849435329437256
**********
epoch 4
Loss: 1.0803377628326416
**********
epoch 5
Loss: 1.0757620334625244
**********
epoch 6
Loss: 1.0712149143218994
**********
epoch 7
Loss: 1.066694736480713
**********
epoch 8
Loss: 1.0622005462646484
**********
epoch 9
Loss: 1.0577306747436523
**********
epoch 10
Loss: 1.0532841682434082
**********
epoch 11
Loss: 1.0488595962524414
**********
epoch 12
Loss: 1.0444557666778564
**********
epoch 13
Loss: 1.040071725845337
**********
epoch 14
Loss: 1.0357060432434082
**********
epoch 15
Loss: 1.0313578844070435
**********
epoch 16
Loss: 1.0270261764526367
**********
epoch 17
Loss: 1.0227097272872925
**********
epoch 18
Loss: 1.0184077024459839
**********
epoch 19
Loss: 1.014119029045105
**********
epoch 20




Loss: 1.009842872619629
**********
epoch 21
Loss: 1.0055782794952393
**********
epoch 22
Loss: 1.0013244152069092
**********
epoch 23
Loss: 0.9970804452896118
**********
epoch 24
Loss: 0.9928455948829651
**********
epoch 25
Loss: 0.9886189699172974
**********
epoch 26
Loss: 0.9843999147415161
**********
epoch 27
Loss: 0.9801875352859497
**********
epoch 28
Loss: 0.9759813547134399
**********
epoch 29
Loss: 0.9717804789543152
**********
epoch 30
Loss: 0.9675843119621277
**********
epoch 31
Loss: 0.9633922576904297
**********
epoch 32
Loss: 0.959203839302063
**********
epoch 33
Loss: 0.955018162727356
**********
epoch 34
Loss: 0.9508348107337952
**********
epoch 35
Loss: 0.9466532468795776
**********
epoch 36
Loss: 0.9424730539321899
**********
epoch 37
Loss: 0.9382935762405396
**********
epoch 38
Loss: 0.934114396572113
**********
epoch 39
Loss: 0.9299349784851074
**********
epoch 40
Loss: 0.9257550239562988
**********
epoch 41
Loss: 0.921574056148529
**********
epoch 42
Loss: 0.9173916

Loss: 0.28140169382095337
**********
epoch 211
Loss: 0.27910757064819336
**********
epoch 212
Loss: 0.2768353223800659
**********
epoch 213
Loss: 0.27458474040031433
**********
epoch 214
Loss: 0.27235573530197144
**********
epoch 215
Loss: 0.2701480984687805
**********
epoch 216
Loss: 0.267961710691452
**********
epoch 217
Loss: 0.2657964527606964
**********
epoch 218
Loss: 0.26365214586257935
**********
epoch 219
Loss: 0.2615286707878113
**********
epoch 220
Loss: 0.2594257593154907
**********
epoch 221
Loss: 0.2573433816432953
**********
epoch 222
Loss: 0.25528138875961304
**********
epoch 223
Loss: 0.2532394826412201
**********
epoch 224
Loss: 0.25121766328811646
**********
epoch 225
Loss: 0.24921566247940063
**********
epoch 226
Loss: 0.24723342061042786
**********
epoch 227
Loss: 0.24527069926261902
**********
epoch 228
Loss: 0.24332736432552338
**********
epoch 229
Loss: 0.2414032220840454
**********
epoch 230
Loss: 0.23949816823005676
**********
epoch 231
Loss: 0.237611979246139