In [79]:
import torch
import torch.nn as nn
import torch.optim as optim

# Seed PyTorch's global random number generator so repeated runs draw the same random values.
torch.manual_seed(1)

<torch._C.Generator at 0x1cd01a97b70>

In [80]:
# Sentinel tags marking the virtual beginning and end of every tag sequence.
START_TAG = "<START>"
STOP_TAG = "<STOP>"

# Tag name -> integer index; indices follow the order of the tuple below.
tag_to_ix = {
    tag_name: tag_index
    for tag_index, tag_name in enumerate(("B", "I", "O", START_TAG, STOP_TAG))
}
tag_to_ix

{'B': 0, 'I': 1, 'O': 2, '<START>': 3, '<STOP>': 4}

In [81]:
def log_sum_exp(vec):
    """Return log(sum(exp(vec))) over every element of ``vec``, computed stably.

    Used by the CRF forward algorithm to add up path scores in log space.

    Args:
        vec: Tensor of scores. Callers in this file pass shape
            [1, tagset_size], but any shape is accepted because the
            reduction runs over the flattened tensor.

    Returns:
        A 0-dim tensor holding the log-sum-exp of all elements of ``vec``.
    """
    # torch.logsumexp performs the max-shift internally; unlike a manual
    # "subtract the max" implementation it also yields -inf (not NaN) when
    # every entry is -inf, and it does not require a 2-D [1, N] input.
    return torch.logsumexp(vec.reshape(-1), dim=0)

In [82]:
class BiLSTM_CRF(nn.Module):
    """Bidirectional LSTM encoder with a linear-chain CRF layer on top.

    The LSTM plus a linear layer produce per-token emission scores for every
    tag; a learned transition matrix scores tag-to-tag moves. Training
    minimizes the negative log-likelihood of the gold tag sequence
    (``neg_log_likelihood``); ``forward`` decodes the best tag sequence with
    the Viterbi algorithm. One sentence is processed per call (batch size 1).
    """

    def __init__(self,
                 vocab_size,  # number of distinct words in the vocabulary
                 tag_to_ix,  # dict: tag name -> tag index (must contain START_TAG and STOP_TAG)
                 embedding_dim,  # size of each word embedding vector
                 hidden_dim):  # total BiLSTM output size; each direction uses hidden_dim // 2, so it must be even
        super(BiLSTM_CRF, self).__init__()
        self.hidden_dim = hidden_dim
        self.tag_to_ix = tag_to_ix
        self.tagset_size = len(tag_to_ix)  # number of distinct tags

        self.word_embeds = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, num_layers=1, bidirectional=True)
        self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size)  # projects LSTM output to tag scores

        # Transition score matrix, shape [tagset_size, tagset_size].
        # transitions[i, j] is the score of moving FROM tag j TO tag i.
        self.transitions = nn.Parameter(torch.randn(self.tagset_size, self.tagset_size))
        # Hard constraints, written in place without autograd tracking.
        with torch.no_grad():
            # Nothing may transition into START_TAG.
            self.transitions[tag_to_ix[START_TAG], :] = -10000
            # Nothing may transition out of STOP_TAG.
            self.transitions[:, tag_to_ix[STOP_TAG]] = -10000

    def _get_lstm_features(self, sentence):
        """Compute emission scores for one sentence.

        Args:
            sentence: 1-D tensor of word indices, shape [sen_len].

        Returns:
            Tensor of shape [sen_len, tagset_size] with the unnormalized
            score of each tag at each position.
        """
        # embeds.shape = [sen_len, 1, embedding_dim]  (batch dimension of 1)
        embeds = self.word_embeds(sentence).view(len(sentence), 1, -1)
        # lstm_out.shape = [sen_len, 1, hidden_dim]
        lstm_out, _ = self.lstm(embeds)
        # Drop the batch dimension: [sen_len, hidden_dim]
        lstm_out = lstm_out.view(len(sentence), self.hidden_dim)
        # lstm_feats.shape = [sen_len, tagset_size]
        lstm_feats = self.hidden2tag(lstm_out)
        return lstm_feats

    def _score_sentence(self, feats, tags):
        """Score one given tag sequence (transitions plus emissions).

        Args:
            feats: Emission scores, shape [sen_len, tagset_size].
            tags: 1-D long tensor of tag indices, shape [sen_len].

        Returns:
            Tensor of shape [1] holding the path score, including the
            START -> first tag and last tag -> STOP transitions.
        """
        # Create helper tensors on the same device as the inputs so the model
        # also works when it has been moved off the CPU.
        score = torch.zeros(1, dtype=feats.dtype, device=feats.device)
        # Prepend the START_TAG index so tags[i] is the tag preceding position i.
        start_tag = torch.tensor([self.tag_to_ix[START_TAG]], dtype=torch.long, device=tags.device)
        tags = torch.cat([start_tag, tags])
        for i, feat in enumerate(feats):
            # feat.shape = [tagset_size]
            # transitions[tags[i + 1], tags[i]]: transition score tags[i] -> tags[i + 1]
            # feat[tags[i + 1]]: emission score of the tag at this position
            score = score + self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]]
        # Transition from the last tag to STOP_TAG.
        score = score + self.transitions[self.tag_to_ix[STOP_TAG], tags[-1]]
        return score

    def _forward_alg(self, feats):
        """Forward algorithm: log of the summed exp-scores over all tag paths.

        Args:
            feats: Emission scores, shape [sen_len, tagset_size].

        Returns:
            0-dim tensor with the log partition function.
        """
        # Start with all mass on START_TAG (log-space: 0 there, very negative elsewhere).
        init_alphas = torch.full((1, self.tagset_size), -10000.,
                                 dtype=feats.dtype, device=feats.device)
        init_alphas[0][self.tag_to_ix[START_TAG]] = 0.
        forward_var = init_alphas  # log-scores accumulated up to the previous step

        for feat in feats:
            # feat.shape = [tagset_size]
            alphas_t = []

            for next_tag in range(self.tagset_size):  # one target tag per iteration
                # Emission score of next_tag, reshaped to [1, 1] and broadcast
                # to [1, tagset_size]; it is the same for every previous tag.
                emit_score = feat[next_tag].view(1, -1).expand(1, self.tagset_size)
                # Scores of transitioning from each previous tag into next_tag,
                # shape [1, tagset_size].
                trans_score = self.transitions[next_tag].view(1, -1)
                next_tag_var = forward_var + trans_score + emit_score
                alphas_t.append(log_sum_exp(next_tag_var).view(1))

            forward_var = torch.cat(alphas_t).view(1, -1)  # becomes "previous" for the next step
        terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]
        # Total score over all paths.
        alpha = log_sum_exp(terminal_var)
        return alpha

    def neg_log_likelihood(self, sentence, tags):
        """CRF loss: log-partition score minus the score of the gold tag path.

        Args:
            sentence: 1-D tensor of word indices, shape [sen_len].
            tags: 1-D long tensor of gold tag indices, shape [sen_len].

        Returns:
            Tensor of shape [1] with the negative log-likelihood.
        """
        # feats.shape = [sen_len, tagset_size]
        feats = self._get_lstm_features(sentence)

        forward_score = self._forward_alg(feats)  # log-sum over all paths
        gold_score = self._score_sentence(feats, tags)  # score of the gold (reference) path
        return forward_score - gold_score

    def _viterbi_decode(self, feats):
        """Find the highest-scoring tag sequence with the Viterbi algorithm.

        Args:
            feats: Emission scores, shape [sen_len, tagset_size].

        Returns:
            List of 0-dim long tensors, one tag index per input position.
        """
        backpointers = []

        init_vvars = torch.full((1, self.tagset_size), -10000.,
                                dtype=feats.dtype, device=feats.device)
        init_vvars[0][self.tag_to_ix[START_TAG]] = 0
        # forward_var.shape = [1, tagset_size]: best score ending in each tag so far
        forward_var = init_vvars

        for feat in feats:
            # feat.shape = [tagset_size]
            bptrs_t = []  # best previous tag for each current tag
            viterbivars_t = []  # best score for each current tag (before emission)

            for next_tag in range(self.tagset_size):
                # The best previous tag depends only on the previous scores and
                # the transition into next_tag; the emission score is the same
                # for every previous tag, so it is added after the max below.
                # next_tag_var.shape = [1, tagset_size]
                next_tag_var = forward_var + self.transitions[next_tag]
                best_tag_id = torch.argmax(next_tag_var)
                bptrs_t.append(best_tag_id)
                viterbivars_t.append(torch.max(next_tag_var, dim=1).values)

            forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1)
            backpointers.append(bptrs_t)

        terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]
        best_tag_id = torch.argmax(terminal_var)

        # Follow the back pointers to decode the best path.
        best_path = [best_tag_id]
        for bptrs_t in reversed(backpointers):
            best_tag_id = bptrs_t[best_tag_id]
            best_path.append(best_tag_id)

        # The walk ends on the virtual START_TAG, which is not part of the output.
        start = best_path.pop()
        assert start == self.tag_to_ix[START_TAG]
        best_path.reverse()
        return best_path

    def forward(self, sentence):
        """Predict the best tag sequence for one sentence.

        Args:
            sentence: 1-D tensor of word indices, shape [sen_len].

        Returns:
            List of 0-dim long tensors with the decoded tag index per token.
        """
        lstm_feats = self._get_lstm_features(sentence)
        tag_seq = self._viterbi_decode(lstm_feats)
        return tag_seq

In [83]:
# Toy corpus: each example is (list of tokens, list of B/I/O tags), built by
# whitespace-splitting the raw sentence and tag strings.
training_data = [
    (text.split(), labels.split())
    for text, labels in (
        ("the wall street journal reported today that apple corporation made money",
         "B I I I O O O B I O O"),
        ("georgia tech is a university in georgia",
         "B I O O O O B"),
    )
]

# Vocabulary: each new word gets the next free index, in order of first appearance.
word_to_ix = {}
for sentence, tags in training_data:
    for word in sentence:
        word_to_ix.setdefault(word, len(word_to_ix))
word_to_ix

{'the': 0,
 'wall': 1,
 'street': 2,
 'journal': 3,
 'reported': 4,
 'today': 5,
 'that': 6,
 'apple': 7,
 'corporation': 8,
 'made': 9,
 'money': 10,
 'georgia': 11,
 'tech': 12,
 'is': 13,
 'a': 14,
 'university': 15,
 'in': 16}

In [84]:
# Build the tagger over the toy vocabulary and tag set.
model = BiLSTM_CRF(
    vocab_size=len(word_to_ix),
    tag_to_ix=tag_to_ix,
    embedding_dim=5,
    hidden_dim=4,
)
# Plain SGD with a small L2 penalty (weight decay) on all model parameters.
optimizer = optim.SGD(params=model.parameters(), lr=0.01, weight_decay=1e-4)

In [85]:
# Encode the first training example as index tensors for a before/after sanity check.
first_words, first_tags = training_data[0]
precheck_sent = torch.tensor([word_to_ix[token] for token in first_words], dtype=torch.long)
precheck_tags = torch.tensor([tag_to_ix[label] for label in first_tags], dtype=torch.long)
precheck_tags  # gold tag indices

tensor([0, 1, 1, 1, 2, 2, 2, 0, 1, 2, 2])

In [86]:
# Check predictions before training (no gradients needed for evaluation).
with torch.no_grad():
    untrained_loss = model.neg_log_likelihood(precheck_sent, precheck_tags)
    untrained_path = model(precheck_sent)
    print(untrained_loss)  # loss is expected to be large at this point
    print(untrained_path)  # decoded path is expected to be wrong at this point

tensor([13.4273])
[tensor(1), tensor(2), tensor(2), tensor(2), tensor(2), tensor(2), tensor(2), tensor(2), tensor(2), tensor(2), tensor(1)]


In [87]:
# Encode every training example once up front: the index tensors are identical
# in every epoch, so rebuilding them inside the loop is wasted work.
encoded_training_data = [
    (torch.tensor([word_to_ix[w] for w in sentence], dtype=torch.long),
     torch.tensor([tag_to_ix[t] for t in tags], dtype=torch.long))
    for sentence, tags in training_data
]

for epoch in range(300):
    for sentence_in, targets in encoded_training_data:
        # PyTorch accumulates gradients, so clear them before each example.
        model.zero_grad()
        loss = model.neg_log_likelihood(sentence_in, targets)  # CRF negative log-likelihood
        loss.backward()
        optimizer.step()

# Check predictions after training
with torch.no_grad():
    print(model.neg_log_likelihood(precheck_sent, precheck_tags))  # loss should now be small
    print(model(precheck_sent))  # decoded path should now match the gold tags

tensor([0.8400])
[tensor(0), tensor(1), tensor(1), tensor(1), tensor(2), tensor(2), tensor(2), tensor(0), tensor(1), tensor(2), tensor(2)]
