
# Implementing a Viterbi Decoder and Evaluation for Sequence Labeling

In this assignment, you will build a Viterbi decoder for an LSTM named-entity recognition model. As we mentioned in class, recurrent and bidirectional recurrent neural networks, of which LSTMs are the most common examples, can be used to perform sequence labeling. Although these models encode information from the surrounding words in order to make predictions, there are no "hard" constraints on what tags can appear where.

These hard constraints are particularly important for tasks that label spans of more than one token. The most common example of a span-labeling task is named-entity recognition (NER). As described in Eisenstein, Jurafsky & Martin, and other texts, the goal of NER is to label spans of one or more words as _mentions_ of an _entity_, such as a person, location, organization, etc.

The most common approach to NER is to reduce it to a sequence-labeling task, where each token in the input is labeled either with an `O`, if it is "outside" any named-entity span, or with `B-TYPE`, if it is the first token in an entity of type `TYPE`, or with `I-TYPE`, if it is the second or later token in an entity of type `TYPE`. Distinguishing between the first and later tokens of an entity allows us to identify distinct entity spans even when they are adjacent.
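Here is a minimal Python sketch of this encoding, which maps entity spans to per-token tags. The `(start, end, type)` tuple with an exclusive `end` and the function name `spans_to_bio` are illustrative choices. They are not part of the assignment's code.

```python
from typing import List, Tuple

# A span is (start, end, type) with `end` exclusive, e.g. (7, 9, "PERSON").
Span = Tuple[int, int, str]


def spans_to_bio(num_tokens: int, spans: List[Span]) -> List[str]:
    """Encode non-overlapping entity spans as one BIO tag per token."""
    tags = ["O"] * num_tokens
    for start, end, etype in spans:
        tags[start] = "B-" + etype          # first token of the mention
        for i in range(start + 1, end):
            tags[i] = "I-" + etype          # second and later tokens
    return tags


# Two adjacent PERSON mentions stay distinct because each one begins with B-.
print(spans_to_bio(5, [(0, 2, "PERSON"), (2, 3, "PERSON")]))
# ['B-PERSON', 'I-PERSON', 'B-PERSON', 'O', 'O']
```

Without the `B-`/`I-` distinction, the two adjacent `PERSON` mentions would merge into a single three-token run.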

Common values of `TYPE` include `PER` for person, `LOC` for location, `DATE` for date, and so on. In the dataset we load below, there are 17 distinct types.

The span-labeling scheme just described implies that the labels on tokens must obey certain constraints: the tag `I-PER` must follow either `B-PER` or another `I-PER`. It cannot begin a sentence, and it cannot follow `O` or a tag for a different entity type, such as `B-LOC` or `I-LOC`. By themselves, LSTMs or bidirectional LSTMs cannot directly enforce these constraints. This is one reason why conditional random fields (CRFs), which _can_ enforce these constraints, are often layered on top of these recurrent models.
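The constraint can be expressed as a predicate on adjacent tags. The sketch below uses hypothetical helpers: `is_valid_transition` and `count_violations` are names chosen here, not AllenNLP functions. It counts one violation for each `I-TYPE` token that starts a sentence, follows `O`, or follows a tag of another type. These are the same cases that the `violations` function later in this notebook checks.

```python
from typing import List, Optional


def is_valid_transition(prev: Optional[str], curr: str) -> bool:
    """Return False if tag `curr` may not follow tag `prev` (None = sentence start)."""
    if not curr.startswith("I-"):
        return True                     # O and B-TYPE may appear anywhere
    if prev is None or prev == "O":
        return False                    # I-TYPE cannot open a span
    # I-TYPE must continue a B-TYPE or I-TYPE of the same TYPE
    return prev[2:] == curr[2:]


def count_violations(tags: List[str]) -> int:
    """Count tokens whose tag breaks the BIO constraints."""
    count, prev = 0, None
    for tag in tags:
        if not is_valid_transition(prev, tag):
            count += 1
        prev = tag
    return count


print(count_violations(["O", "I-EVENT", "I-EVENT"]))   # 1 (I-EVENT after O)
print(count_violations(["B-PER", "I-PER", "I-LOC"]))   # 1 (I-LOC after I-PER)
print(count_violations(["B-PER", "I-PER", "B-LOC"]))   # 0
```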

In this assignment, you will implement the simplest possible CRF: a CRF so simple that it does not require any training. Rather, it will assign weight 1 to any sequence of tags that obeys the constraints and weight 0 to any sequence of tags that violates them. The inputs to the CRF, which are analogous to the emission probabilities in an HMM, will come from an LSTM.
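Every weight is either 1 or 0, so each transition's log-weight is either 0 or -infinity. A sequence's total score is therefore the sum of its emission scores if it is valid, and -infinity otherwise. Decoding amounts to finding the highest-scoring valid sequence. The brute-force sketch below makes this concrete by enumerating all K^n tag sequences for a two-token toy input. The labels and emission scores are invented for illustration. Exhaustive search is only feasible at this tiny scale, and avoiding it is the point of the Viterbi algorithm.

```python
import itertools
import math
from typing import List, Optional, Sequence, Tuple


def log_weight(prev: Optional[str], curr: str) -> float:
    """log(1) = 0 for an allowed transition, log(0) = -inf for a forbidden one."""
    if curr.startswith("I-") and (prev is None or prev == "O" or prev[2:] != curr[2:]):
        return -math.inf
    return 0.0


def brute_force_decode(emissions: Sequence[Sequence[float]],
                       labels: List[str]) -> Tuple[List[str], float]:
    """Score every tag sequence and return the best one with its score."""
    best_tags, best_score = None, -math.inf
    for idx in itertools.product(range(len(labels)), repeat=len(emissions)):
        score, prev = 0.0, None
        for t, k in enumerate(idx):
            # transition log-weight plus emission score of tag k at token t
            score += log_weight(prev, labels[k]) + emissions[t][k]
            prev = labels[k]
        if score > best_score:
            best_tags, best_score = [labels[k] for k in idx], score
    return best_tags, best_score


labels = ["O", "B-ORG", "I-ORG"]
emissions = [[2.0, 1.5, 0.0],   # token 0
             [0.0, 0.5, 3.0]]   # token 1

# per-token argmax, ignoring the constraints
greedy = [labels[max(range(len(labels)), key=row.__getitem__)] for row in emissions]
print(greedy)                                  # ['O', 'I-ORG']
print(brute_force_decode(emissions, labels))   # (['B-ORG', 'I-ORG'], 4.5)
```

Per-token argmax picks `O` and then `I-ORG`, which is invalid. The constrained search instead settles on `B-ORG, I-ORG`, even though that sequence's emission total (4.5) is lower than the invalid sequence's total (5.0).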

But first, in order to test your decoder, you will also implement some functions to evaluate the output of an NER system according to two metrics:
1. You will count the number of _violations_ of the NER label constraints, i.e., how many times `I-TYPE` follows `O` or a tag of a different type or occurs at the beginning of a sentence. This number will be greater than 0 in the raw LSTM output, but should be 0 for your CRF output.
1. You will compute the _span-level_ precision, recall, and F1 of NER output. Although the baseline LSTM was trained to achieve high _token-level_ accuracy, token-level accuracy can be misleadingly high, since so many tokens are correctly labeled `O`. Span-level metrics instead ask: what proportion of spans predicted by the model line up exactly with spans in the gold standard (precision), and what proportion of spans in the gold standard were predicted by the model (recall)? Define a _span_ as a sequence of tags that starts with a `B-TYPE` followed by zero or more `I-TYPE` tags. Sequences consisting solely of `I-TYPE` tags don't count as spans. For more, see the original task definition: https://www.aclweb.org/anthology/W03-0419/.
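Below is a sketch of span extraction and span-level scoring for a single sentence, following the span definition above. The names `bio_to_spans` and `span_prf` are hypothetical, and this version scores one sentence at a time. To score a corpus, each span tuple would also need a sentence identifier. The example tags are the gold and baseline tags of the first validation sentence, which is printed in the Evaluation section below.

```python
from typing import List, Set, Tuple

Span = Tuple[int, int, str]  # (start, end_exclusive, type)


def bio_to_spans(tags: List[str]) -> Set[Span]:
    """Collect spans that start with B-TYPE and continue with I-TYPE of the same TYPE.

    I- tags that do not continue such a span are ignored, as the definition requires.
    """
    spans: Set[Span] = set()
    start, etype = None, None
    for i, tag in enumerate(tags + ["O"]):     # trailing "O" closes a span at sentence end
        if start is not None and tag != "I-" + etype:
            spans.add((start, i, etype))        # open span ends before token i
            start, etype = None, None
        if tag.startswith("B-"):
            start, etype = i, tag[2:]
    return spans


def span_prf(pred_tags: List[str], gold_tags: List[str]) -> Tuple[float, float, float]:
    """Span-level precision, recall, and F1 (exact match on start, end, and type)."""
    pred, gold = bio_to_spans(pred_tags), bio_to_spans(gold_tags)
    correct = len(pred & gold)
    precision = correct / len(pred) if pred else 0.0
    recall = correct / len(gold) if gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# Gold and baseline tags of the first validation sentence ("With a wave of his hand , ...")
gold = (["O"] * 7 + ["B-PERSON", "I-PERSON"] + ["O"] * 5 + ["B-CARDINAL"]
        + ["O"] * 7 + ["B-EVENT", "I-EVENT", "I-EVENT", "I-EVENT", "O"])
pred = list(gold)
pred[22] = "O"                                  # baseline tagged "the" as O, not B-EVENT

token_acc = sum(p == g for p, g in zip(pred, gold)) / len(gold)
print(round(token_acc, 3))                      # 0.963
print(span_prf(pred, gold))                     # precision 1.0, recall 2/3, F1 about 0.8
```

On this sentence, token accuracy is 26/27 (about 0.96), yet span recall is only 2/3. Tagging "the" as `O` leaves the `I-EVENT` run without a `B-EVENT`, so that run counts as neither a predicted span nor a correct one.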

We start by loading some code and data and then describe your tasks in more detail.

## Set Up Dependencies and Definitions

In [1]:
%%capture
!pip install --upgrade spacy==2.1.0 allennlp==0.9.0
import spacy

In [2]:
from typing import Iterator, List, Dict
import torch
import torch.optim as optim
import numpy as np
from allennlp.data import Instance
from allennlp.data.fields import TextField, SequenceLabelField
from allennlp.data.dataset_readers import DatasetReader
from allennlp.common.file_utils import cached_path
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token
from allennlp.data.vocabulary import Vocabulary
from allennlp.models import Model
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
from allennlp.training.metrics import CategoricalAccuracy
from allennlp.data.iterators import BucketIterator
from allennlp.training.trainer import Trainer
from allennlp.predictors import SentenceTaggerPredictor
from allennlp.data.dataset_readers import conll2003
import pandas as pd

torch.manual_seed(1)

<torch._C.Generator at 0x7f905f8367d0>

In [3]:
class LstmTagger(Model):
  def __init__(self,
               word_embeddings: TextFieldEmbedder,
               encoder: Seq2SeqEncoder,
               vocab: Vocabulary) -> None:
    super().__init__(vocab)
    self.word_embeddings = word_embeddings
    self.encoder = encoder
    self.hidden2tag = torch.nn.Linear(in_features=encoder.get_output_dim(),
                                      out_features=vocab.get_vocab_size('labels'))
    self.accuracy = CategoricalAccuracy()

  def forward(self,
              tokens: Dict[str, torch.Tensor],
              metadata,
              tags: torch.Tensor = None) -> Dict[str, torch.Tensor]:
    mask = get_text_field_mask(tokens)
    embeddings = self.word_embeddings(tokens)
    encoder_out = self.encoder(embeddings, mask)
    tag_logits = self.hidden2tag(encoder_out)
    output = {"tag_logits": tag_logits}
    if tags is not None:
      self.accuracy(tag_logits, tags, mask)
      output["loss"] = sequence_cross_entropy_with_logits(tag_logits, tags, mask)

    return output

  def get_metrics(self, reset: bool = False) -> Dict[str, float]:
    return {"accuracy": self.accuracy.get_metric(reset)}

## Import Data

In [4]:
reader = conll2003.Conll2003DatasetReader()
train_dataset = reader.read(cached_path('http://www.ccs.neu.edu/home/dasmith/onto.train.ner.sample'))
validation_dataset = reader.read(cached_path('http://www.ccs.neu.edu/home/dasmith/onto.development.ner.sample'))

from itertools import chain
vocab = Vocabulary.from_instances(chain(train_dataset, validation_dataset))



## Define and Train Model

In [5]:
EMBEDDING_DIM = 6
HIDDEN_DIM = 6
token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                            embedding_dim=EMBEDDING_DIM)
word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedding})
lstm = PytorchSeq2SeqWrapper(torch.nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM, bidirectional=False, batch_first=True))
model = LstmTagger(word_embeddings, lstm, vocab)
if torch.cuda.is_available():
    cuda_device = 0
    model = model.cuda(cuda_device)
else:
    cuda_device = -1
# optimizer = optim.AdamW(model.parameters(), lr=1e-4, eps=1e-8)
optimizer = optim.SGD(model.parameters(), lr=0.1)
iterator = BucketIterator(batch_size=2, sorting_keys=[("tokens", "num_tokens")])
iterator.index_with(vocab)
trainer = Trainer(model=model,
                  optimizer=optimizer,
                  iterator=iterator,
                  train_dataset=train_dataset,
                  validation_dataset=validation_dataset,
                  patience=10,
                  num_epochs=100,
                  cuda_device=cuda_device)
trainer.train()


{'best_epoch': 99,
 'best_validation_accuracy': 0.8836734693877552,
 'best_validation_loss': 0.3793452437821543,
 'epoch': 99,
 'peak_cpu_memory_MB': 3221.472,
 'peak_gpu_0_memory_MB': 1058,
 'training_accuracy': 0.9279737315962292,
 'training_cpu_memory_MB': 3221.472,
 'training_duration': '0:01:58.824019',
 'training_epochs': 99,
 'training_gpu_0_memory_MB': 1058,
 'training_loss': 0.18746380445160454,
 'training_start_epoch': 0,
 'validation_accuracy': 0.8836734693877552,
 'validation_loss': 0.3793452437821543}

## Evaluation

The simple code below loops over the validation set, applies the model to each example, and collects the input tokens, the gold-standard tags, and the model's predicted tags. It also shows how to access the ground truth and the model outputs for evaluation.

In [6]:
def tag_sentence(s):
  tag_ids = np.argmax(model.forward_on_instance(s)['tag_logits'], axis=-1)
  fields = zip(s['tokens'], s['tags'], [model.vocab.get_token_from_index(i, 'labels') for i in tag_ids])
  return list(fields)

baseline_output = [tag_sentence(i) for i in validation_dataset]
## Show the first example
baseline_output[0]

[(With, 'O', 'O'),
 (a, 'O', 'O'),
 (wave, 'O', 'O'),
 (of, 'O', 'O'),
 (his, 'O', 'O'),
 (hand, 'O', 'O'),
 (,, 'O', 'O'),
 (Peng, 'B-PERSON', 'B-PERSON'),
 (Dehuai, 'I-PERSON', 'I-PERSON'),
 (said, 'O', 'O'),
 (that, 'O', 'O'),
 (despite, 'O', 'O'),
 (being, 'O', 'O'),
 (over, 'O', 'O'),
 (100, 'B-CARDINAL', 'B-CARDINAL'),
 (regiments, 'O', 'O'),
 (,, 'O', 'O'),
 (let, 'O', 'O'),
 ('s, 'O', 'O'),
 (call, 'O', 'O'),
 (this, 'O', 'O'),
 (campaign, 'O', 'O'),
 (the, 'B-EVENT', 'O'),
 (Hundred, 'I-EVENT', 'I-EVENT'),
 (Regiments, 'I-EVENT', 'I-EVENT'),
 (Offensive, 'I-EVENT', 'I-EVENT'),
 (., 'O', 'O')]

Now, you can implement two evaluation functions: `violations` and `span_stats`.

### Violations

In [7]:
# count the number of NER label violations,
# such as O followed by I-TYPE or B-TYPE followed by
# I-OTHER_TYPE
# Take tagger output as input
def violations(predicted_type_list):

  count = 0

  # violation 1: I-TYPE following O (or I-TYPE starting sentence)
  # mark with 1 whenever we see the violation
  invalid_start = [[1 for i in range(len(sentence)) if (sentence[i][0] == 'I' and sentence[i-1][0] == 'O' and i != 0) or 
                  (i == 0 and sentence[i][0] == 'I')] for sentence in predicted_type_list]
  # sum up number of violations
  invalid_start = [item for sublist in invalid_start for item in sublist]
  count += np.sum(invalid_start)

  # violation 2: I-OTHER_TYPE following B-TYPE or I-TYPE
  # mark with 1 whenever we see the violation
  invalid_type = [[1 for i in range(len(sentence)) if 
                   (sentence[i][0] == 'I' and sentence[i-1][0] != 'O' and sentence[i-1][1] != sentence[i][1] and i != 0)] 
                  for sentence in predicted_type_list]
  # sum up number of violations
  invalid_type = [item for sublist in invalid_type for item in sublist]
  count += np.sum(invalid_type)

  return count

In [8]:
# get list of predicted [B/I/O tag, tag type (if exists)] for each sentence
predicted_type_list = [[tag[2].split('-') for tag in sentence] for sentence in baseline_output]
# get list of actual [B/I/O tag, tag type (if exists)] for each sentence
actual_type_list = [[tag[1].split('-') for tag in sentence] for sentence in baseline_output]

In [9]:
# number of violations
violations(predicted_type_list)

33

### Span Statistics

In [10]:
# create dataframe with one row per token and columns with predicted tag, predicted type, actual tag, actual type, document number
# easier to mark spans in a dataframe than in a list 
def create_dataframe(predicted_type_list, actual_type_list):

  tag_df = pd.DataFrame()

  # create dataframes for each record and concat together 
  for i in range(len(predicted_type_list)):
    # dataframe of predicted tags and types
    predicted_df = pd.DataFrame.from_records(predicted_type_list[i])
    # some records don't have any type because all O tags
    if len(predicted_df.columns) == 2: 
      predicted_df.columns = ['predicted_tag', 'predicted_type']
    else:
      predicted_df.columns = ['predicted_tag']
    predicted_df['num'] = i # mark which document

    # dataframe of actual tags and types
    actual_df = pd.DataFrame.from_records(actual_type_list[i])
    if len(actual_df.columns) == 2:
      actual_df.columns = ['actual_tag', 'actual_type']
    else:
      actual_df.columns = ['actual_tag']

    # merge actual and predicted
    tmp_df = pd.merge(predicted_df, actual_df, left_index = True, right_index = True)

    # concat to overall dataframe for all records
    tag_df = pd.concat([tag_df, tmp_df], sort = False)

  return tag_df.reset_index(drop = True)

In [11]:
# create indicator for each sequence of actual and predicted tags
# only consider valid spans 
def mark_sequences(tag_df, var_type):
  tagvar = var_type + '_tag'
  typevar = var_type + '_type'
  sequencevar = var_type + '_sequence'

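  # mark start of sequence with a unique nonzero id (index + 1),
  # so a B tag on row 0 is not mistaken for the 0 "continuation" marker
  tag_df[sequencevar] = np.where(tag_df[tagvar] == 'B', tag_df.index + 1, 0)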
  # sequence ends if different type var or different record
  tag_df[sequencevar] = np.where((tag_df[typevar] != tag_df[typevar].shift()) &
                                 (tag_df[sequencevar] == 0),
                                 np.nan, tag_df[sequencevar])
  tag_df[sequencevar] = np.where((tag_df.num != tag_df.num.shift()) &
                                 (tag_df[sequencevar] == 0), np.nan, tag_df[sequencevar])
  
  # drag forward sequence start indicator over entire sequence
  tag_df[sequencevar] = tag_df[sequencevar].replace(to_replace = 0, method = 'ffill')

  return tag_df

In [12]:
# return the span-level precision, recall, and F1
# Only count valid spans that start with a B tag,
# followed by zero or more I tags of the same type.
# This is harsher than the token-level metric that the
# LSTM was trained to optimize, but it is the standard way
# of evaluating NER systems.
# Take tagger output as input
def span_stats(predicted_type_list, actual_type_list):

  # mark actual vs predicted sequences
  # build a token-level dataframe: contiguous spans are easier to mark there than in nested lists
  tag_df = create_dataframe(predicted_type_list, actual_type_list)
  tag_df = mark_sequences(tag_df, 'actual')
  tag_df = mark_sequences(tag_df, 'predicted')

  # generate lists of actual and predicted sequences
  # sequence "signature" = index (to mark position) + tag + type,
  # so matching an actual sequence against a predicted one checks positions, tags, and types
  tag_df['sequence_signature_actual'] = tag_df.index.astype(str) + tag_df.actual_tag + tag_df.actual_type 
  # for each sequence, get list of signatures. When matching against predicted, all signatures must be present
  actual_sequences = tag_df.groupby('actual_sequence').sequence_signature_actual.unique().to_list()
  actual_sequences = [list(i) for i in actual_sequences]

  tag_df['sequence_signature_predicted'] = tag_df.index.astype(str) + tag_df.predicted_tag + tag_df.predicted_type 
  predicted_sequences = tag_df.groupby('predicted_sequence').sequence_signature_predicted.unique().to_list()
  predicted_sequences = [list(i) for i in predicted_sequences]

  # precision: percent of predicted sequences that are correct
  correct_sequence = 0
  for sequence in predicted_sequences:
    if sequence in actual_sequences:
      correct_sequence += 1 
  precision = correct_sequence / len(predicted_sequences) if predicted_sequences else 0.0

  # recall: percent of actual sequences that are predicted
  correct_sequence = 0
  for sequence in actual_sequences:
    if sequence in predicted_sequences:
      correct_sequence += 1
  recall = correct_sequence / len(actual_sequences) if actual_sequences else 0.0

  # f1 score (defined as 0 when precision and recall are both 0)
  f1 = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0.0

  return {'precision': precision,
          'recall': recall,
          'f1': f1}

In [13]:
span_stats(predicted_type_list, actual_type_list)

{'f1': 0.3055555555555556,
 'precision': 0.3793103448275862,
 'recall': 0.2558139534883721}

## Decoding

Now you can finally implement the simple Viterbi decoder. The `model` object, when applied to an input sentence, first calculates the scores for each possible output tag for each token. See the expression `model.forward_on_instance(s)['tag_logits']` in the code above.
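The `tag_logits` are unnormalized scores, not log-probabilities. The short NumPy sketch below, which uses random made-up scores, shows why decoding with raw logits is still reasonable. A log-softmax subtracts a single constant from all of a token's scores. Every complete tag sequence's total therefore shifts by the same amount, and the ranking of sequences, including which valid sequence is best, does not change.

```python
import numpy as np


def log_softmax(scores: np.ndarray) -> np.ndarray:
    """Row-wise log-softmax, shifted by each row's max for numerical stability."""
    shifted = scores - scores.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 5))        # made-up scores: 4 tokens x 5 tags
log_probs = log_softmax(logits)

# The gap between logits and log-probabilities is the same for every tag of a token ...
offsets = logits - log_probs
print(np.allclose(offsets, offsets[:, :1]))              # True
# ... and the log-probabilities are properly normalized.
print(np.allclose(np.exp(log_probs).sum(axis=-1), 1.0))  # True
```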

Then, you will construct a transition matrix. You can use the code below to get a list of the tags the model knows about. For a set of K tags, construct a K-by-K matrix whose entry for a pair of tags is log(1) = 0 when the transition between them is valid and log(0) = -infinity otherwise.
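The notebook's `create_transition_matrix` originally wrote the `O` row at index 0. That seems to rely on `O` being the most frequent label, because AllenNLP appears to assign label indices by frequency. The sketch below derives every constraint from the tag names instead. It assumes the labels are available as a plain list of strings. The function name `build_transition_matrix` and the separate `start` vector are hypothetical.

```python
import numpy as np
from typing import List, Tuple


def build_transition_matrix(labels: List[str]) -> Tuple[np.ndarray, np.ndarray]:
    """Return (transitions, start) log-weights for the BIO constraints.

    transitions[p, c] is 0.0 (log 1) if tag c may follow tag p, else -inf (log 0).
    start[c] is 0.0 if tag c may begin a sentence, else -inf.
    """
    K = len(labels)
    transitions = np.zeros((K, K))
    start = np.zeros(K)
    for c, curr in enumerate(labels):
        if not curr.startswith("I-"):
            continue                          # O and B-TYPE are always allowed
        start[c] = -np.inf                    # I-TYPE cannot start a sentence
        for p, prev in enumerate(labels):
            # I-TYPE may only follow B-TYPE or I-TYPE of the same TYPE
            if prev == "O" or prev[2:] != curr[2:]:
                transitions[p, c] = -np.inf
    return transitions, start


labels = ["O", "B-ORG", "I-ORG", "B-EVENT", "I-EVENT"]   # hypothetical label order
transitions, start = build_transition_matrix(labels)
print(transitions[labels.index("B-ORG"), labels.index("I-ORG")])    # 0.0
print(transitions[labels.index("O"), labels.index("I-ORG")])        # -inf
print(transitions[labels.index("I-EVENT"), labels.index("I-ORG")])  # -inf
```

In the notebook, such a list could be built with `[vocab.get_token_from_index(i, 'labels') for i in range(vocab.get_vocab_size('labels'))]`. The matrix then no longer depends on where `O` lands in the label vocabulary.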

Finally, implement a Viterbi decoder that takes the model object and a dataset object and outputs tagged data, just like the `tag_sentence` function above. It should use the Viterbi algorithm with the (max, plus) semiring: you'll be working with sums of log probabilities instead of products of probabilities.
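For comparison with the loop-based `decode` function below, here is a vectorized NumPy sketch of the same (max, plus) recursion. The function name `viterbi_decode` and its separate `start` vector are choices made for this sketch, not something the assignment prescribes. At each step, the previous column of best scores is added to every row of the transition matrix, and the max (and argmax) is taken down each column. Decoding a sentence of n tokens therefore costs O(nK²).

```python
import numpy as np
from typing import List, Tuple


def viterbi_decode(emissions: np.ndarray,
                   transitions: np.ndarray,
                   start: np.ndarray) -> Tuple[List[int], float]:
    """(max, plus) Viterbi decoding over log-space scores.

    emissions:   (n, K) per-token tag scores, e.g. the model's tag_logits
    transitions: (K, K) log-weights, transitions[prev, curr] in {0, -inf}
    start:       (K,) log-weights for the first tag, in {0, -inf}
    Returns the best path as tag indices and its total score.
    """
    n, _ = emissions.shape
    score = start + emissions[0]                     # best score of a path ending in each tag
    backpointers = np.zeros(emissions.shape, dtype=int)
    for t in range(1, n):
        # candidates[p, c]: best path ending in tag p, extended by the transition p -> c
        candidates = score[:, None] + transitions
        backpointers[t] = candidates.argmax(axis=0)  # best previous tag for each current tag
        score = candidates.max(axis=0) + emissions[t]
    # trace back from the best final tag
    best = int(score.argmax())
    path = [best]
    for t in range(n - 1, 0, -1):
        best = int(backpointers[t, best])
        path.append(best)
    return path[::-1], float(score.max())


NEG = -np.inf
labels = ["O", "B-ORG", "I-ORG", "B-EVENT", "I-EVENT"]
# BIO constraints for these labels, written out by hand (rows = previous tag)
transitions = np.array([
    [0, 0, NEG, 0, NEG],   # from O
    [0, 0, 0,   0, NEG],   # from B-ORG
    [0, 0, 0,   0, NEG],   # from I-ORG
    [0, 0, NEG, 0, 0],     # from B-EVENT
    [0, 0, NEG, 0, 0],     # from I-EVENT
])
start = np.array([0, 0, NEG, 0, NEG])           # I- tags cannot start a sentence
# invented scores for "The Central Military Commission"
emissions = np.array([
    [2.0, 1.5, 0.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0, 1.2],
    [0.0, 0.0, 2.0, 0.0, 0.5],
    [0.0, 0.0, 2.0, 0.0, 0.5],
])

greedy = [labels[k] for k in emissions.argmax(axis=1)]
path, total = viterbi_decode(emissions, transitions, start)
print(greedy)                             # ['O', 'I-EVENT', 'I-ORG', 'I-ORG']
print([labels[k] for k in path], total)   # ['B-ORG', 'I-ORG', 'I-ORG', 'I-ORG'] 6.5
```

The emission scores are invented so that a per-token argmax reproduces the invalid labeling O, I-EVENT, I-ORG, I-ORG discussed at the end of this notebook. With the constraints in place, the best valid path is B-ORG, I-ORG, I-ORG, I-ORG.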

Run your `violations` function on the output of this decoder to make sure that there are no invalid tag transitions. Also, compare the span-level metrics on `baseline_output` and your new output using your `span_stats` function.

In [90]:
def create_transition_matrix(vocab):
  # label vocabulary: maps index -> tag string
  tag_dict = vocab.get_index_to_token_vocabulary('labels')
  K = len(tag_dict)
  # start with log(1) = 0 everywhere, then mark invalid transitions with log(0) = -inf
  transition_matrix = np.zeros((K,K))

  # get lists of indexes of I tags and B tags
  Ipositions = [k for (k,v) in tag_dict.items() if v.startswith('I-')]
  Bpositions = [k for (k,v) in tag_dict.items() if v.startswith('B-')]

  # Invalid transition 1: from O to I
  # (look up the index of O by name instead of assuming it is 0)
  o_index = vocab.get_token_index('O', 'labels')
  for i in Ipositions:
    transition_matrix[o_index][i] = -np.inf

  # Invalid transition 2: from B-TYPE to I-OTHERTYPE
  for b in Bpositions:
    for i in Ipositions:
      # skip if same type
      if tag_dict[i].split('-')[1] == tag_dict[b].split('-')[1]:
        continue
      else:
        transition_matrix[b][i] = -np.inf

  # Invalid transition 3: from I-TYPE to I-OTHERTYPE
  for i in Ipositions:
    for i2 in Ipositions:
      # skip if same tag
      if i == i2:
        continue
      else:
        transition_matrix[i][i2] = -np.inf

  return transition_matrix, Ipositions

In [91]:
def decode(s, Ipositions):
  # note: relies on the global `transition_matrix` created by create_transition_matrix

  # run the model once per sentence: tag_logits has shape (num_tokens, K)
  tag_logits = model.forward_on_instance(s)['tag_logits']

  # matrices holding the best path score and backpointer for each tag-token combination
  # rows = tags, columns = tokens
  K = len(vocab.get_index_to_token_vocabulary('labels'))
  num_tokens = len(s['tokens'])
  viterbi = np.zeros((K, num_tokens))
  backpointer = np.zeros((K, num_tokens), dtype=int)

  # initial scores of each tag for the first token in the sentence
  viterbi[:,0] = tag_logits[0]
  # -infinity score for I tags (they cannot start a sentence)
  for i in Ipositions:
    viterbi[i,0] = -np.inf
  # backpointer[:,0] is never followed, so it stays 0

  # loop through tokens in sentence
  for token_index in range(1, num_tokens):
    # loop through possible tags
    for new_tag_index in range(K):
      bestprob = -np.inf
      bestpath = 0
      # loop through tags in prior position to find best path so far
      for prior_tag_index in range(K):
        # prior path score + transition score + emission score
        prob = (viterbi[prior_tag_index, token_index - 1]
                + transition_matrix[prior_tag_index, new_tag_index]
                + tag_logits[token_index, new_tag_index])
        # keep track of best seen
        if prob > bestprob:
          bestprob = prob
          bestpath = prior_tag_index
      # record best paths and scores
      viterbi[new_tag_index, token_index] = bestprob
      backpointer[new_tag_index, token_index] = bestpath

  # generate best path: follow backpointers backwards,
  # starting at the tag with the highest score at the final token
  back_position = int(np.argmax(viterbi[:, -1]))
  best_sequence = [back_position]
  for i in range(num_tokens - 1, 0, -1):
    back_position = int(backpointer[back_position, i])
    best_sequence.append(back_position)
  best_sequence = best_sequence[::-1] # reverse sequence

  # formatted output
  fields = zip(s['tokens'], s['tags'], [model.vocab.get_token_from_index(i, 'labels') for i in best_sequence])

  return list(fields)

In [92]:
transition_matrix, Ipositions = create_transition_matrix(vocab)
output = [decode(i, Ipositions) for i in validation_dataset]

In [94]:
# number of violations
# get list of predicted [B/I/O tag, tag type (if exists)] for each sentence
predicted_type_list = [[tag[2].split('-') for tag in sentence] for sentence in output]
# get list of actual [B/I/O tag, tag type (if exists)] for each sentence
actual_type_list = [[tag[1].split('-') for tag in sentence] for sentence in output]
violations(predicted_type_list)

0.0

The revised algorithm, which uses a transition matrix to enforce tag-ordering constraints, substantially improved both precision and recall on this validation sample. Precision rose from 0.38 to 0.45, recall from 0.26 to 0.47, and F1 from 0.31 to 0.46. By eliminating impossible paths, the decoder has a better chance of finding the correct labeling sequence. For example, the baseline model labels "The Central Military Commission" as O, I-EVENT, I-ORG, I-ORG. The improved model cannot generate this path because it contains multiple violations (I-EVENT after O, then I-ORG after I-EVENT). It therefore selects the highest-scoring valid path instead, which here is the correct sequence B-ORG, I-ORG, I-ORG, I-ORG.
    
More specifically, recall improved more than precision, which follows from how each metric is computed. A run of I- tags with no B- tag is not a span. In the baseline output, such a run therefore never enters precision's denominator (the number of predicted spans), but it still lowers recall, whose denominator is the fixed number of gold spans. When the constrained decoder repairs such a run into a valid span, recall gains a hit if the repaired span is correct. Precision's denominator grows at the same time, however, so precision gains less. The revised algorithm still improves precision because it also corrects spans that were previously valid but wrong. For example, the baseline model labels "600,000 - plus" as B-CARDINAL, I-ORG, I-PERSON. The lone B-CARDINAL is a valid but incorrect (incomplete) span. Because the following tags are impossible under the constraints, the improved model predicts the correct sequence B-CARDINAL, I-CARDINAL, I-CARDINAL, and precision improves.
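The toy sketch below illustrates this denominator argument. The span sets are hypothetical and written out by hand, and the `precision_recall` helper is illustrative, not notebook output. An orphan run of I- tags contributes no predicted span, so it lowers recall without ever entering precision's denominator. Once constrained decoding repairs the run into a valid span, that span enters the denominator whether or not it is correct.

```python
from typing import Set, Tuple

Span = Tuple[int, int, str]  # (start, end_exclusive, type)


def precision_recall(pred: Set[Span], gold: Set[Span]) -> Tuple[float, float]:
    """Span-level precision and recall from exact-match span sets."""
    correct = len(pred & gold)
    precision = correct / len(pred) if pred else 0.0
    recall = correct / len(gold) if gold else 0.0
    return precision, recall


# Hypothetical gold spans for one sentence
gold = {(0, 4, "ORG"), (6, 7, "CARDINAL")}

# Baseline: tokens 0-3 tagged O, I-EVENT, I-ORG, I-ORG contain no B- tag,
# so they produce no predicted span and never reach precision's denominator.
baseline = {(6, 7, "CARDINAL")}
print(precision_recall(baseline, gold))        # (1.0, 0.5)

# Constrained decoding repairs the run into a valid, correct span: recall rises.
repaired = {(0, 4, "ORG"), (6, 7, "CARDINAL")}
print(precision_recall(repaired, gold))        # (1.0, 1.0)

# A repair into a valid but wrong span enters precision's denominator as a miss.
repaired_wrong = {(1, 4, "ORG"), (6, 7, "CARDINAL")}
print(precision_recall(repaired_wrong, gold))  # (0.5, 0.5)
```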

In [95]:
span_stats(predicted_type_list, actual_type_list)

{'f1': 0.45977011494252873,
 'precision': 0.45454545454545453,
 'recall': 0.46511627906976744}