In [68]:
import torch
from torchcrf import CRF
import torch.nn as nn
import torch.optim as optim

# Seed PyTorch's global RNG so randomly initialised parameters are repeatable across runs.
torch.manual_seed(1)

<torch._C.Generator at 0x20510c08b50>

In [69]:
class BiLSTM_CRF(nn.Module):
    """Bidirectional LSTM encoder followed by a linear-chain CRF tagging layer.

    Token ids are embedded, encoded by a single-layer bidirectional LSTM,
    projected to per-tag emission scores, and then scored (training) or
    Viterbi-decoded (inference) by a ``torchcrf.CRF``. All tensors are
    batch-first.
    """

    def __init__(self,
                 vocab_size,  # number of distinct words in the vocabulary
                 embedding_dim,  # size of each word embedding vector
                 hidden_dim,  # total BiLSTM output size (2 x the per-direction hidden size)
                 num_tags):  # number of distinct tag labels
        """Build the embedding, BiLSTM, projection and CRF layers.

        Raises:
            ValueError: if ``hidden_dim`` is not a positive even number. Each
                LSTM direction gets ``hidden_dim // 2`` units, so an odd value
                would make the concatenated LSTM output one unit narrower than
                the input size of ``hidden2tag``.
        """
        super(BiLSTM_CRF, self).__init__()
        if hidden_dim <= 0 or hidden_dim % 2 != 0:
            raise ValueError(
                "hidden_dim must be a positive even number, got {!r}".format(hidden_dim)
            )
        self.word_embeds = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, num_layers=1, bidirectional=True, batch_first=True)
        self.hidden2tag = nn.Linear(hidden_dim, num_tags)  # projection applied to the LSTM output
        self.crf = CRF(num_tags=num_tags, batch_first=True)

    def _get_lstm_features(self, sentence):
        """Return per-token emission scores for a batch of token-id sequences.

        Args:
            sentence: integer tensor of shape [batch_size, seq_len].

        Returns:
            Float tensor of shape [batch_size, seq_len, num_tags]. These are
            unnormalised scores (CRF emission scores), not probabilities.
        """
        # embeds.shape = [batch_size, seq_len, embedding_dim]
        embeds = self.word_embeds(sentence)
        # lstm_out.shape = [batch_size, seq_len, hidden_dim]
        lstm_out, _ = self.lstm(embeds)
        # emissions.shape = [batch_size, seq_len, num_tags]
        emissions = self.hidden2tag(lstm_out)
        return emissions

    def forward(self, sentence, tags, mask=None):
        """Return the CRF negative log-likelihood of ``tags`` given ``sentence``.

        Args:
            sentence: integer tensor of shape [batch_size, seq_len].
            tags: integer tensor of shape [batch_size, seq_len] with gold tag ids.
            mask: optional tensor of shape [batch_size, seq_len] marking real
                (non-padding) positions; forwarded unchanged to the CRF.
                ``None`` (the default) treats every position as real.
                Note that only the CRF honours the mask: the LSTM still reads
                padded positions.

        Returns:
            Scalar loss tensor, using the CRF's default reduction over the batch.
        """
        emissions = self._get_lstm_features(sentence)
        log_likelihood = self.crf(emissions, tags, mask=mask)
        loss = -log_likelihood
        return loss

    def decode(self, sentence, mask=None):
        """Return the highest-scoring tag sequence for each sentence in the batch.

        Args:
            sentence: integer tensor of shape [batch_size, seq_len].
            mask: optional tensor of shape [batch_size, seq_len] marking real
                (non-padding) positions; forwarded unchanged to the CRF.

        Returns:
            A list with one list of tag ids per batch element.
        """
        emissions = self._get_lstm_features(sentence)
        best_path = self.crf.decode(emissions, mask=mask)
        return best_path

In [70]:
# Toy corpus: each entry is (list of words, list of B/I/O tags), aligned one-to-one.
training_data = [
    (
        "the wall street journal reported today that apple corporation made money".split(),
        "B I I I O O O B I O O".split(),
    ),
    (
        "georgia tech is a university in georgia".split(),
        "B I O O O O B".split(),
    ),
]

# Give every distinct word the next free integer id, in first-seen order.
word_to_ix = {}
for sentence, tags in training_data:
    for word in sentence:
        # setdefault leaves the id of an already-seen word untouched.
        word_to_ix.setdefault(word, len(word_to_ix))
word_to_ix

{'the': 0,
 'wall': 1,
 'street': 2,
 'journal': 3,
 'reported': 4,
 'today': 5,
 'that': 6,
 'apple': 7,
 'corporation': 8,
 'made': 9,
 'money': 10,
 'georgia': 11,
 'tech': 12,
 'is': 13,
 'a': 14,
 'university': 15,
 'in': 16}

In [71]:
# Map each tag label to its integer id (ids follow the order listed here).
# NOTE(review): "<START>" and "<STOP>" never occur as targets in training_data,
# and torchcrf's CRF keeps its own start/end transition scores, so these two
# entries only enlarge the tag space -- consider dropping them.
tag_to_ix = {label: idx for idx, label in enumerate(("B", "I", "O", "<START>", "<STOP>"))}
tag_to_ix

{'B': 0, 'I': 1, 'O': 2, '<START>': 3, '<STOP>': 4}

In [72]:
# Tiny model for the toy corpus: 5-dim embeddings, 4-dim BiLSTM output (2 per direction).
model = BiLSTM_CRF(
    vocab_size=len(word_to_ix),
    embedding_dim=5,
    hidden_dim=4,
    num_tags=len(tag_to_ix),
)
# Plain SGD over every model parameter, with a small L2 penalty (weight decay).
optimizer = optim.SGD(
    model.parameters(),
    lr=0.01,
    weight_decay=1e-4,
)

In [73]:
# Encode the first training example as a batch of one: shape [1, sentence_length].
precheck_sent = torch.tensor(
    [word_to_ix[w] for w in training_data[0][0]], dtype=torch.long
).unsqueeze(0)
precheck_tags = torch.tensor(
    [tag_to_ix[t] for t in training_data[0][1]], dtype=torch.long
).unsqueeze(0)
precheck_tags  # gold tag ids

tensor([[0, 1, 1, 1, 2, 2, 2, 0, 1, 2, 2]])

In [74]:
# Sanity check with the untrained model; gradients are not needed here.
with torch.no_grad():
    untrained_loss = model(precheck_sent, precheck_tags)
    untrained_path = model.decode(precheck_sent)
    print(untrained_loss)  # loss is expected to be large
    print(untrained_path)  # decoded path is expected to be wrong

tensor(18.8844)
[[3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 3]]


In [75]:
# Encode every training example once up front: the index tensors are identical
# in every epoch, so rebuilding them inside the loop is wasted work.
# Each tensor is a batch of one, shape [1, sentence_length].
encoded_training_data = [
    (
        torch.tensor([[word_to_ix[w] for w in sentence]], dtype=torch.long),
        torch.tensor([[tag_to_ix[t] for t in tags]], dtype=torch.long),
    )
    for sentence, tags in training_data
]

for epoch in range(300):
    for sentence_in, targets in encoded_training_data:
        model.zero_grad()  # clear gradients accumulated by the previous step
        loss = model(sentence_in, targets)  # CRF negative log-likelihood
        loss.backward()
        optimizer.step()

# Check predictions after training, on the same example used for the pre-training check.
with torch.no_grad():
    trained_loss = model(precheck_sent, precheck_tags)
    trained_path = model.decode(precheck_sent)
    print(trained_loss)  # loss is expected to be small
    print(trained_path)  # decoded path is expected to match the gold tags

tensor(0.9706)
[[0, 1, 1, 1, 2, 2, 2, 0, 1, 2, 2]]
