In [1]:
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.autograd import Variable

In [3]:
training_data = [("The dog ate the apple".split(),
                  ["DET", "NN", "V", "DET", "NN"]),
                 ("Everybody read that book".split(), ["NN", "V", "DET",
                                                       "NN"])]
print(training_data)

[(['The', 'dog', 'ate', 'the', 'apple'], ['DET', 'NN', 'V', 'DET', 'NN']), (['Everybody', 'read', 'that', 'book'], ['NN', 'V', 'DET', 'NN'])]


In [5]:
word_to_idx = {}
tag_to_idx = {}
for context, tag in training_data:
    for word in context:
        if word not in word_to_idx:
            word_to_idx[word] = len(word_to_idx)
    for label in tag:
        if label not in tag_to_idx:
            tag_to_idx[label] = len(tag_to_idx)
alphabet = 'abcdefghijklmnopqrstuvwxyz'
character_to_idx = {}
for i in range(len(alphabet)):
    character_to_idx[alphabet[i]] = i
print(word_to_idx)
print(tag_to_idx)
print(character_to_idx)

{'The': 0, 'dog': 1, 'Everybody': 5, 'apple': 4, 'the': 3, 'book': 8, 'that': 7, 'ate': 2, 'read': 6}
{'DET': 0, 'NN': 1, 'V': 2}
{'o': 14, 'v': 21, 'i': 8, 'g': 6, 'h': 7, 'l': 11, 'd': 3, 'f': 5, 'a': 0, 'w': 22, 'e': 4, 'x': 23, 'u': 20, 'y': 24, 't': 19, 'm': 12, 's': 18, 'z': 25, 'r': 17, 'b': 1, 'k': 10, 'j': 9, 'n': 13, 'q': 16, 'c': 2, 'p': 15}


In [6]:
class CharLSTM(nn.Module):
    def __init__(self, n_char, char_dim, char_hidden):
        super(CharLSTM, self).__init__()
        self.char_embedding = nn.Embedding(n_char, char_dim)
        self.char_lstm = nn.LSTM(char_dim, char_hidden, batch_first=True)

    def forward(self, x):
        x = self.char_embedding(x)
        _, h = self.char_lstm(x)
        return h[0]

In [7]:
class LSTMTagger(nn.Module):
    def __init__(self, n_word, n_char, char_dim, n_dim, char_hidden, n_hidden,
                 n_tag):
        super(LSTMTagger, self).__init__()
        self.word_embedding = nn.Embedding(n_word, n_dim)
        self.char_lstm = CharLSTM(n_char, char_dim, char_hidden)
        self.lstm = nn.LSTM(n_dim + char_hidden, n_hidden, batch_first=True)
        self.linear1 = nn.Linear(n_hidden, n_tag)

    def forward(self, x, word):
        char = torch.FloatTensor()
        for each in word:
            char_list = []
            for letter in each:
                char_list.append(character_to_idx[letter.lower()])
            char_list = torch.LongTensor(char_list)
            char_list = char_list.unsqueeze(0)
            if torch.cuda.is_available():
                tempchar = self.char_lstm(Variable(char_list).cuda())
            else:
                tempchar = self.char_lstm(Variable(char_list))
            tempchar = tempchar.squeeze(0)
            char = torch.cat((char, tempchar.cpu().data), 0)
        if torch.cuda.is_available():
            char = char.cuda()
        char = Variable(char)
        x = self.word_embedding(x)
        x = torch.cat((x, char), 1)
        x = x.unsqueeze(0)
        x, _ = self.lstm(x)
        x = x.squeeze(0)
        x = self.linear1(x)
        y = F.log_softmax(x)
        return y

In [8]:
model = LSTMTagger(
    len(word_to_idx), len(character_to_idx), 10, 100, 50, 128, len(tag_to_idx))
if torch.cuda.is_available():
    model = model.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-2)

In [9]:
def make_sequence(x, dic):
    idx = [dic[i] for i in x]
    idx = Variable(torch.LongTensor(idx))
    return idx

In [11]:
for epoch in range(300):
    print('*' * 10)
    print('epoch {}'.format(epoch + 1))
    running_loss = 0
    for data in training_data:
        word, tag = data
        word_list = make_sequence(word, word_to_idx)
        tag = make_sequence(tag, tag_to_idx)
        if torch.cuda.is_available():
            word_list = word_list.cuda()
            tag = tag.cuda()
        # forward
        out = model(word_list, word)
        loss = criterion(out, tag)
        running_loss += loss.data.item()
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('Loss: {}'.format(running_loss / len(data)))
print()
input = make_sequence("Everybody ate the apple".split(), word_to_idx)
if torch.cuda.is_available():
    input = input.cuda()

out = model(input, "Everybody ate the apple".split())
print(out)

**********
epoch 1
Loss: 1.0878206491470337
**********
epoch 2
Loss: 1.0831924080848694
**********
epoch 3
Loss: 1.0785991549491882
**********
epoch 4
Loss: 1.0740392804145813
**********
epoch 5
Loss: 1.0695111751556396
**********
epoch 6
Loss: 1.0650132894515991
**********
epoch 7
Loss: 1.0605440735816956
**********
epoch 8
Loss: 1.0561020374298096
**********
epoch 9
Loss: 1.0516858100891113
**********
epoch 10
Loss: 1.0472939610481262
**********
epoch 11
Loss: 1.0429250597953796
**********
epoch 12
Loss: 1.038577914237976
**********
epoch 13
Loss: 1.03425133228302
**********
epoch 14
Loss: 1.0299437642097473
**********
epoch 15
Loss: 1.0256543159484863
**********
epoch 16
Loss: 1.021381676197052
**********
epoch 17
Loss: 1.0171247720718384
**********
epoch 18
Loss: 1.01288241147995
**********
epoch 19
Loss: 1.008653700351715
**********
epoch 20




Loss: 1.0044374465942383
**********
epoch 21
Loss: 1.0002326965332031
**********
epoch 22
Loss: 0.9960386157035828
**********
epoch 23
Loss: 0.9918541312217712
**********
epoch 24
Loss: 0.9876784980297089
**********
epoch 25
Loss: 0.9835106730461121
**********
epoch 26
Loss: 0.9793499410152435
**********
epoch 27
Loss: 0.9751953184604645
**********
epoch 28
Loss: 0.9710462391376495
**********
epoch 29
Loss: 0.9669018983840942
**********
epoch 30
Loss: 0.9627615809440613
**********
epoch 31
Loss: 0.9586243331432343
**********
epoch 32
Loss: 0.9544897973537445
**********
epoch 33
Loss: 0.9503572881221771
**********
epoch 34
Loss: 0.9462260007858276
**********
epoch 35
Loss: 0.942095547914505
**********
epoch 36
Loss: 0.9379652142524719
**********
epoch 37
Loss: 0.9338345527648926
**********
epoch 38
Loss: 0.9297029972076416
**********
epoch 39
Loss: 0.9255701303482056
**********
epoch 40
Loss: 0.9214353263378143
**********
epoch 41
Loss: 0.9172983467578888
**********
epoch 42
Loss: 0.913

Loss: 0.2608746513724327
**********
epoch 220
Loss: 0.2587933838367462
**********
epoch 221
Loss: 0.25673212110996246
**********
epoch 222
Loss: 0.25469083338975906
**********
epoch 223
Loss: 0.2526693046092987
**********
epoch 224
Loss: 0.25066740065813065
**********
epoch 225
Loss: 0.24868492782115936
**********
epoch 226
Loss: 0.2467217817902565
**********
epoch 227
Loss: 0.24477776885032654
**********
epoch 228
Loss: 0.24285274744033813
**********
epoch 229
Loss: 0.24094657599925995
**********
epoch 230
Loss: 0.23905903100967407
**********
epoch 231
Loss: 0.23719006776809692
**********
epoch 232
Loss: 0.2353394255042076
**********
epoch 233
Loss: 0.2335069626569748
**********
epoch 234
Loss: 0.23169253766536713
**********
epoch 235
Loss: 0.22989597916603088
**********
epoch 236
Loss: 0.22811711579561234
**********
epoch 237
Loss: 0.2263558730483055
**********
epoch 238
Loss: 0.22461198270320892
**********
epoch 239
Loss: 0.2228853553533554
**********
epoch 240
Loss: 0.2211758047342