# LSTM+CRF实现序列标注

$$\begin{align}P(y|x) = \frac{\exp(\text{Score}(x, y))}{\sum_{y'} \exp(\text{Score}(x, y'))}\end{align}$$

$$\begin{align}\text{Score}(x,y) = \sum_i \log \psi_i(x,y)\end{align}$$

$$\begin{align}\text{Score}(x,y) = \sum_i \log \psi_\text{EMIT}(y_i \rightarrow x_i) + \log \psi_\text{TRANS}(y_{i-1} \rightarrow y_i)\end{align}$$

$$\begin{align}= \sum_i h_i[y_i] + \textbf{P}_{y_{i-1}, y_i}\end{align}$$

In [1]:
import mindspore
import mindspore.nn as nn
import mindspore.numpy as mnp
from mindspore import Parameter
from mindspore.common.initializer import initializer, Uniform

def sequence_mask(seq_length, max_length, batch_first=False):
    """Build a 0/1 validity mask from per-sample sequence lengths.

    Position ``t`` of a sample is marked 1 when ``t < seq_length`` for that
    sample and 0 otherwise.

    Args:
        seq_length: Tensor of sequence lengths; its dtype is reused for the
            position vector. For a 1-D ``(batch_size,)`` input the comparison
            broadcasts to ``(batch_size, max_length)``.
        max_length: Number of positions to generate.
        batch_first: When true, return the mask as computed; otherwise swap
            its first two axes before returning.

    Returns:
        An int64 mask tensor.
    """
    positions = mnp.arange(0, max_length, 1, seq_length.dtype)
    # Append a trailing axis so each length is compared against every position.
    lengths_column = seq_length.view(seq_length.shape + (1,))
    mask = (positions < lengths_column).astype(mindspore.int64)
    return mask if batch_first else mask.swapaxes(0, 1)

class CRF(nn.Cell):
    """Linear-chain conditional random field layer.

    Holds learnable start, end and pairwise transition scores. Called with
    ``tags`` it returns the negative log-likelihood of those tags; called
    without ``tags`` it runs the Viterbi forward pass and returns the final
    scores plus the back-pointer history consumed by :meth:`post_decode`.

    Args:
        num_tags: Number of distinct tags; must be positive.
        batch_first: Whether inputs are laid out ``(batch, seq, ...)`` instead
            of ``(seq, batch, ...)``.
        reduction: How the per-sample loss is reduced: ``'none'``, ``'sum'``,
            ``'mean'`` or ``'token_mean'`` (sum divided by the number of
            unmasked tokens).

    Raises:
        ValueError: If ``num_tags`` is not positive or ``reduction`` is not
            one of the supported values.
    """

    def __init__(self, num_tags: int, batch_first: bool = False, reduction: str = 'sum') -> None:
        if num_tags <= 0:
            raise ValueError(f'invalid number of tags: {num_tags}')
        super().__init__()
        if reduction not in ('none', 'sum', 'mean', 'token_mean'):
            raise ValueError(f'invalid reduction: {reduction}')
        self.num_tags = num_tags
        self.batch_first = batch_first
        self.reduction = reduction
        self.start_transitions = Parameter(initializer(Uniform(0.1), (num_tags,)), name='start_transitions')
        self.end_transitions = Parameter(initializer(Uniform(0.1), (num_tags,)), name='end_transitions')
        # transitions[i, j] is the score of moving from tag i to tag j.
        self.transitions = Parameter(initializer(Uniform(0.1), (num_tags, num_tags)), name='transitions')

    def construct(self, emissions, tags=None, seq_length=None):
        """Dispatch to decoding when ``tags`` is None, otherwise to the loss."""
        if tags is None:
            return self._decode(emissions, seq_length)
        return self._forward(emissions, tags, seq_length)

    def _forward(self, emissions, tags=None, seq_length=None):
        """Return the (reduced) negative log-likelihood of ``tags``.

        When ``seq_length`` is None every sequence is treated as full length.
        """
        if self.batch_first:
            batch_size, max_length = tags.shape
            # Everything below works on the time-major layout.
            emissions = emissions.swapaxes(0, 1)
            tags = tags.swapaxes(0, 1)
        else:
            max_length, batch_size = tags.shape

        if seq_length is None:
            seq_length = mnp.full((batch_size,), max_length, mindspore.int64)

        # shape: (max_length, batch_size)
        mask = sequence_mask(seq_length, max_length)

        # shape: (batch_size,)
        numerator = self._compute_score(emissions, tags, seq_length - 1, mask)
        # shape: (batch_size,)
        denominator = self._compute_normalizer(emissions, mask)
        # log-partition minus gold-path score, i.e. the negative log-likelihood
        # shape: (batch_size,)
        nll = denominator - numerator

        if self.reduction == 'none':
            return nll
        elif self.reduction == 'sum':
            return nll.sum()
        elif self.reduction == 'mean':
            return nll.mean()
        # 'token_mean': average over the number of valid (unmasked) tokens
        return nll.sum() / mask.astype(emissions.dtype).sum()

    def _decode(self, emissions, seq_length=None):
        """Run the Viterbi forward pass and return ``(score, history)``.

        When ``seq_length`` is None every sequence is treated as full length.
        """
        if self.batch_first:
            batch_size, max_length = emissions.shape[:2]
            emissions = emissions.swapaxes(0, 1)
        else:
            # Time-major input is (seq_length, batch_size, num_tags), so the
            # leading axis is the sequence length, not the batch size.
            max_length, batch_size = emissions.shape[:2]

        if seq_length is None:
            seq_length = mnp.full((batch_size,), max_length, mindspore.int64)

        # shape: (max_length, batch_size)
        mask = sequence_mask(seq_length, max_length)

        return self._viterbi_decode(emissions, mask)

    def _log_sum_exp(self, scores, axis):
        """Numerically stable ``log(sum(exp(scores)))`` along ``axis``.

        Subtracting the per-slice maximum before exponentiating keeps
        ``exp`` from overflowing for large scores; the maximum is added back
        afterwards, so the result is mathematically unchanged.
        """
        peak = scores.max(axis=axis)
        shifted = scores - peak.expand_dims(axis)
        return peak + mnp.log(mnp.sum(mnp.exp(shifted), axis=axis))

    def _compute_score(self, emissions, tags, seq_ends, mask):
        """Score the given tag sequences (numerator of the likelihood)."""
        # emissions: (seq_length, batch_size, num_tags)
        # tags: (seq_length, batch_size)
        # mask: (seq_length, batch_size)

        seq_length, batch_size = tags.shape
        mask = mask.astype(emissions.dtype)

        # Start transition score and first emission
        # shape: (batch_size,)
        score = self.start_transitions[tags[0]]
        score += emissions[0, mnp.arange(batch_size), tags[0]]

        for i in range(1, seq_length):
            # Transition score to next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += self.transitions[tags[i - 1], tags[i]] * mask[i]

            # Emission score for next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += emissions[i, mnp.arange(batch_size), tags[i]] * mask[i]

        # End transition score, taken at each sample's last valid timestep
        # shape: (batch_size,)
        last_tags = tags[seq_ends, mnp.arange(batch_size)]
        # shape: (batch_size,)
        score += self.end_transitions[last_tags]

        return score

    def _compute_normalizer(self, emissions, mask):
        """Compute the log-partition function with the forward algorithm."""
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)

        seq_length = emissions.shape[0]

        # Start transition score and first emission; score has size of
        # (batch_size, num_tags) where for each batch, the j-th column stores
        # the score that the first timestep has tag j
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]

        for i in range(1, seq_length):
            # Broadcast score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.expand_dims(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emissions = emissions[i].expand_dims(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the sum of scores of all
            # possible tag sequences so far that end with transitioning from tag i to tag j
            # and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emissions

            # Sum over all possible current tags, but we're in score space, so a sum
            # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of
            # all possible tag sequences so far, that end in tag i
            # shape: (batch_size, num_tags)
            next_score = self._log_sum_exp(next_score, 1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # shape: (batch_size, num_tags)
            score = mnp.where(mask[i].expand_dims(1), next_score, score)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Sum (log-sum-exp) over all possible tags
        # shape: (batch_size,)
        return self._log_sum_exp(score, 1)

    def _viterbi_decode(self, emissions, mask):
        """Viterbi forward pass; returns final scores and back-pointers.

        Returns:
            ``(score, history)`` where ``score`` has shape
            ``(batch_size, num_tags)`` and ``history`` is a tuple with one
            ``(batch_size, num_tags)`` back-pointer tensor per timestep after
            the first.
        """
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)

        seq_length = mask.shape[0]

        # Start transition and first emission
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]
        history = ()

        # score is a tensor of size (batch_size, num_tags) where for every batch,
        # value at column j stores the score of the best tag sequence so far that ends
        # with tag j
        # history saves where the best tags candidate transitioned from; this is used
        # when we trace back the best tag sequence

        # Viterbi algorithm recursive case: we compute the score of the best tag sequence
        # for every possible next tag
        for i in range(1, seq_length):
            # Broadcast viterbi score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.expand_dims(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emission = emissions[i].expand_dims(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the score of the best
            # tag sequence so far that ends with transitioning from tag i to tag j and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emission

            # Find the maximum score over all possible current tag
            # shape: (batch_size, num_tags)
            indices = next_score.argmax(axis=1)
            next_score = next_score.max(axis=1)
            # Set score to the next score if this timestep is valid (mask == 1)
            # and save the index that produces the next score
            # shape: (batch_size, num_tags)
            score = mnp.where(mask[i].expand_dims(1), next_score, score)
            history += (indices,)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        return score, history

    def post_decode(self, score, history, seq_length):
        """Trace back the best tag path for every sample.

        Args:
            score: Final Viterbi scores, shape ``(batch_size, num_tags)``.
            history: Back-pointer tensors as returned by the decode pass.
            seq_length: Per-sample sequence lengths, shape ``(batch_size,)``.

        Returns:
            A list with one list of int tag ids per sample, each as long as
            that sample's sequence length.
        """
        batch_size = seq_length.shape[0]

        # Copy each tensor to the host once instead of once per element
        # inside the loops below.
        seq_ends = (seq_length - 1).asnumpy()
        # Best final tag per sample, shape: (batch_size,)
        best_last_tags = score.argmax(axis=1).asnumpy()
        back_pointers = [hist.asnumpy() for hist in history]

        best_tags_list = []
        for idx in range(batch_size):
            # Start from the tag which maximizes the score at the last timestep
            best_tags = [int(best_last_tags[idx])]
            # Only the first seq_ends[idx] back-pointers belong to valid
            # timesteps; walk them backwards to recover the earlier tags.
            for hist in reversed(back_pointers[:int(seq_ends[idx])]):
                best_tags.append(int(hist[idx][best_tags[-1]]))

            # Reverse the order because we start from the last timestep
            best_tags.reverse()
            best_tags_list.append(best_tags)

        return best_tags_list

In [2]:
class BiLSTM_CRF(nn.Cell):
    """Sequence tagger: embedding -> bidirectional LSTM -> dense -> CRF.

    Args:
        vocab_size: Number of entries in the embedding table.
        embedding_dim: Size of each token embedding.
        hidden_dim: Input width of the tag projection layer; each LSTM
            direction uses ``hidden_dim // 2`` units (assumes ``hidden_dim``
            is even so both directions together match this width).
        num_tags: Number of output tags.
        padding_idx: Token index passed to the embedding as the padding entry.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_tags, padding_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        # Half of hidden_dim per direction.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, bidirectional=True, batch_first=True)
        # Projects LSTM features to per-tag emission scores.
        self.hidden2tag = nn.Dense(hidden_dim, num_tags)
        self.crf = CRF(num_tags, batch_first=True)

    def construct(self, inputs, seq_length, tags=None):
        """Return the CRF output for a batch.

        With ``tags`` supplied this is whatever the CRF returns for training;
        with ``tags`` left as None it is the CRF's decoding output.
        """
        embedded = self.embedding(inputs)
        lstm_out, _ = self.lstm(embedded, seq_length=seq_length)
        emissions = self.hidden2tag(lstm_out)
        return self.crf(emissions, tags, seq_length)

In [3]:
# Model sizes for the toy example.
embedding_dim = 5
hidden_dim = 4

# Each sample pairs a tokenized sentence with one B/I/O tag per token.
training_data = [
    (
        "the wall street journal reported today that apple corporation made money".split(),
        "B I I I O O O B I O O".split(),
    ),
    (
        "georgia tech is a university in georgia".split(),
        "B I O O O O B".split(),
    ),
]

# Index 0 is reserved for padding; every other word gets the next free index
# in order of first appearance.
word_to_idx = {'<pad>': 0}
for sentence, tags in training_data:
    for word in sentence:
        word_to_idx.setdefault(word, len(word_to_idx))

tag_to_idx = {tag: idx for idx, tag in enumerate(("B", "I", "O"))}

In [4]:
# Vocabulary size, counting the '<pad>' entry.
len(word_to_idx)

18

In [5]:
# Size the tagger from the vocabulary and tag dictionaries.
model = BiLSTM_CRF(
    vocab_size=len(word_to_idx),
    embedding_dim=embedding_dim,
    hidden_dim=hidden_dim,
    num_tags=len(tag_to_idx),
)
# Plain SGD over all trainable parameters, with weight decay.
optimizer = nn.SGD(
    model.trainable_params(),
    learning_rate=0.01,
    weight_decay=1e-4,
)

In [6]:
# Wrap the model and optimizer into a single cell; the training loop calls it
# once per step and uses its return value as the loss.
train_one_step = nn.TrainOneStepCell(model, optimizer)

In [7]:
def prepare_sequence(seqs, word_to_idx, tag_to_idx):
    """Convert (tokens, tags) pairs into padded index tensors.

    Every sample is right-padded to the length of the longest token list:
    tokens with ``word_to_idx['<pad>']`` and tags with ``tag_to_idx['O']``.

    Args:
        seqs: Iterable of ``(tokens, tags)`` pairs.
        word_to_idx: Mapping from token to index.
        tag_to_idx: Mapping from tag name to index.

    Returns:
        Three int64 tensors: padded token indices, padded tag indices, and
        the original (unpadded) length of each sample.
    """
    token_rows, tag_rows, lengths = [], [], []
    longest = max([len(pair[0]) for pair in seqs])

    for tokens, tag_names in seqs:
        n_pad = longest - len(tokens)
        lengths.append(len(tokens))

        row = [word_to_idx[w] for w in tokens]
        row_tags = [tag_to_idx[t] for t in tag_names]
        # The padding entries are only looked up when padding is needed.
        row.extend(word_to_idx['<pad>'] for _ in range(n_pad))
        row_tags.extend(tag_to_idx['O'] for _ in range(n_pad))

        token_rows.append(row)
        tag_rows.append(row_tags)

    return (
        mindspore.Tensor(token_rows, mindspore.int64),
        mindspore.Tensor(tag_rows, mindspore.int64),
        mindspore.Tensor(lengths, mindspore.int64),
    )

In [8]:
# Turn the whole training set into one padded batch of tensors.
data, label, seq_length = prepare_sequence(training_data, word_to_idx, tag_to_idx)
# Display the three tensor shapes as a sanity check.
data.shape, label.shape, seq_length.shape

((2, 11), (2, 11), (2,))

In [9]:
from tqdm import tqdm

steps = 500
# Run one update on the full batch per step and show the latest loss on the
# progress bar.
with tqdm(total=steps) as progress_bar:
    for _ in range(steps):
        loss = train_one_step(data, seq_length, label)
        progress_bar.set_postfix(loss=loss)
        progress_bar.update(1)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:22<00:00, 21.76it/s, loss=[0.81230736 0.29350662]]


In [10]:
# Calling the model without tags returns the decoding outputs instead of a loss.
# NOTE(review): `histroy` is a misspelling of `history`; it is left as-is here
# because the decoding cell below reads this exact name.
score, histroy = model(data, seq_length)
score

Tensor(shape=[2, 3], dtype=Float32, value=
[[ 2.58174553e+01,  2.04221268e+01,  2.78829422e+01],
 [ 1.92055836e+01,  1.57861795e+01,  1.70711060e+01]])

In [11]:
# Trace back the best tag sequence for each sample from the decoding outputs.
predict = model.crf.post_decode(score, histroy, seq_length)

In [12]:
# Predicted tag indices per sample (see tag_to_idx for the index meaning).
predict

[[0, 1, 1, 1, 2, 2, 2, 0, 1, 2, 2], [0, 1, 2, 2, 2, 2, 0]]