In [1]:
!pip install seqeval -q
!pip install flair -q

In [2]:
import pandas as pd
import io
from flair.data import Sentence
from flair.models import SequenceTagger
from transformers import LukeTokenizer, LukeForEntitySpanClassification
import timeit
import ast

import unicodedata

import numpy as np
import seqeval.metrics
import spacy
import torch
from tqdm import tqdm, trange


class FlairModel:
    """Word-level IOB2 tagger built on the ``flair/ner-english-large`` sequence tagger."""

    # Entity types searched for in the string form of a predicted label.
    _KNOWN_TYPES = ("ORG", "MISC", "PER", "LOC")

    def __init__(self):
        # load tagger
        self.tagger = SequenceTagger.load("flair/ner-english-large")

    @classmethod
    def _entity_type(cls, label):
        """Return the bare entity type for a predicted label.

        Only the first five characters of ``str(label)`` are inspected. If none
        of the known types occurs there, that five-character prefix is returned
        unchanged.
        """
        # presumably str(label) looks like "PER (0.99)" -- confirm against flair
        prefix = str(label)[:5]
        for known in cls._KNOWN_TYPES:
            if known in prefix:
                return known
        return prefix

    def get_entity_list(self, input_string):
        """Tag every space-separated word of ``input_string`` with an IOB2 label.

        Args:
            input_string: Text to tag. Words are the pieces produced by
                ``input_string.split(" ")``.

        Returns:
            A list with exactly one tag per word: ``"B-<type>"`` for the first
            word of an entity, ``"I-<type>"`` for its remaining words and
            ``"O"`` for everything else.
        """
        sentence = Sentence(input_string)
        # predict NER tags
        self.tagger.predict(sentence)
        named_entities = sentence.to_dict(tag_type='ner')["entities"]

        # Character span [start, end) of every space-separated word.
        word_spans = []
        cursor = 0
        for word in input_string.split(" "):
            word_spans.append((cursor, cursor + len(word)))
            cursor += len(word) + 1

        entity_list = ["O"] * len(word_spans)

        # Entities are located left to right so that a repeated mention maps to
        # its own occurrence instead of always hitting the first one.
        search_from = 0
        for entity in named_entities:
            entity_text = str(entity["text"])
            if not entity_text:
                continue
            start = input_string.find(entity_text, search_from)
            if start == -1:
                # The entity text does not appear verbatim in the input;
                # leave the affected words untagged.
                continue
            end = start + len(entity_text)
            search_from = end

            entity_type = self._entity_type(entity["labels"][0])
            # Every still-untagged word overlapping the entity is labelled, so
            # punctuation glued to a word ("Bob.") does not break tagging.
            covered = [
                index
                for index, (word_start, word_end) in enumerate(word_spans)
                if word_start < end and word_end > start and entity_list[index] == "O"
            ]
            for position, index in enumerate(covered):
                prefix = "B-" if position == 0 else "I-"
                entity_list[index] = prefix + entity_type

        return entity_list

In [3]:
class LukeModel:
    """Word-level IOB2 tagger built on LUKE entity-span classification."""

    def __init__(self):
        checkpoint = "studio-ousia/luke-large-finetuned-conll-2003"
        self.tokenizer = LukeTokenizer.from_pretrained(checkpoint)
        self.model = LukeForEntitySpanClassification.from_pretrained(checkpoint)

    def get_entity_list(self, input_text):
        """Tag every space-separated word of ``input_text`` with an IOB2 label.

        Every contiguous run of words is scored as a candidate entity span.
        Spans whose best class is not class 0 are then accepted greedily from
        the highest to the lowest winning logit, skipping any span that
        overlaps an already accepted one.

        Args:
            input_text: Text to tag; leading/trailing whitespace is stripped
                and words are the pieces of ``input_text.split(" ")``.

        Returns:
            A list with one tag per word: ``"B-<label>"``, ``"I-<label>"`` or
            ``"O"``, where ``<label>`` comes from ``model.config.id2label``.
        """
        input_text = input_text.strip()
        split_text = input_text.split(" ")
        num_words = len(split_text)

        # Character offsets of each word inside input_text.
        word_start_positions = []
        word_end_positions = []
        cursor = 0
        for word in split_text:
            word_start_positions.append(cursor)
            cursor += len(word)
            word_end_positions.append(cursor)
            cursor += 1  # the single separating space

        # All contiguous word runs, as character spans (for the tokenizer) and
        # as half-open word-index spans (for building the tag sequence).
        entity_spans = []
        word_spans = []
        for first in range(num_words):
            for last in range(first, num_words):
                entity_spans.append((word_start_positions[first], word_end_positions[last]))
                word_spans.append((first, last + 1))

        inputs = self.tokenizer(input_text, entity_spans=entity_spans, return_tensors="pt")
        with torch.no_grad():  # inference only, no gradients needed
            outputs = self.model(**inputs)

        # logits[0] drops the batch dimension: one row of class scores per span.
        best_scores, best_classes = outputs.logits[0].max(dim=-1)

        candidates = [
            (score, span, class_idx)
            for score, class_idx, span in zip(best_scores.tolist(), best_classes.tolist(), word_spans)
            if class_idx != 0
        ]

        entity_list = ["O"] * num_words
        # Most confident span first; a weaker span is dropped as soon as it
        # touches a word that has already been tagged.
        for _, (start, end), class_idx in sorted(candidates, key=lambda c: c[0], reverse=True):
            if all(tag == "O" for tag in entity_list[start:end]):
                label = str(self.model.config.id2label[class_idx])
                entity_list[start] = "B-" + label
                entity_list[start + 1:end] = ["I-" + label] * (end - start - 1)

        return entity_list


In [4]:
# Instantiate both wrappers once; each constructor loads its pretrained model
# (the LUKE tokenizer/model and the flair tagger respectively).
luke_model = LukeModel()
flair_model = FlairModel()

Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/14.6M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/33.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/0.98k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.66k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/877 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.09G [00:00<?, ?B/s]

Some weights of the model checkpoint at studio-ousia/luke-large-finetuned-conll-2003 were not used when initializing LukeForEntitySpanClassification: ['luke.embeddings.position_ids']
- This IS expected if you are initializing LukeForEntitySpanClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LukeForEntitySpanClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading:   0%|          | 0.00/2.24G [00:00<?, ?B/s]

2021-10-20 10:51:35,026 loading file /root/.flair/models/ner-english-large/07301f59bb8cb113803be316267f06ddf9243cdbba92a4c8067ef92442d2c574.554244d3476d97501a766a98078421817b14654496b86f2f7bd139dc502a4f29


Downloading:   0%|          | 0.00/4.83M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/8.68M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/513 [00:00<?, ?B/s]

In [5]:
# Smoke test: tag a short sentence with the LUKE wrapper (one tag per word).
luke_model.get_entity_list("hello I'm David Peletz")

['O', 'O', 'B-PER', 'I-PER']

In [6]:
# Smoke test: tag the same sentence with the flair wrapper for comparison.
flair_model.get_entity_list("hello I'm David Peletz")

['O', 'O', 'B-PER', 'I-PER']

In [3]:
# Download the testb set of the CoNLL-2003 dataset
!wget https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.testb

--2021-10-20 11:01:22--  https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.testb
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 748096 (731K) [text/plain]
Saving to: ‘eng.testb.1’


2021-10-20 11:01:22 (8.68 MB/s) - ‘eng.testb.1’ saved [748096/748096]



In [6]:
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model in inference mode on the selected device
model = LukeForEntitySpanClassification.from_pretrained("studio-ousia/luke-large-finetuned-conll-2003")
model.eval()
model.to(device)

# Load the tokenizer
tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-conll-2003")

Some weights of the model checkpoint at studio-ousia/luke-large-finetuned-conll-2003 were not used when initializing LukeForEntitySpanClassification: ['luke.embeddings.position_ids']
- This IS expected if you are initializing LukeForEntitySpanClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LukeForEntitySpanClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [7]:
def _build_document(words, labels, sentence_boundaries):
    """Package one document, making sure its sentence boundaries are closed.

    The boundary list must start at 0 and end at ``len(words)``; otherwise the
    first or last sentence has no enclosing pair of boundaries. Both ends are
    added here when the input file omitted the corresponding blank line.
    """
    if not sentence_boundaries or sentence_boundaries[0] != 0:
        sentence_boundaries.insert(0, 0)
    if sentence_boundaries[-1] != len(words):
        sentence_boundaries.append(len(words))
    return dict(
        words=words,
        labels=labels,
        sentence_boundaries=sentence_boundaries
    )


def load_documents(dataset_file):
    """Read a CoNLL-style file into a list of documents.

    Lines starting with ``-DOCSTART`` separate documents, blank lines separate
    sentences, and every other line contributes its first space-separated
    field as the word and its last field as the label.

    Args:
        dataset_file: Path of the file to read.

    Returns:
        A list of dicts with keys ``words``, ``labels`` and
        ``sentence_boundaries`` (word indices at which sentences start/end,
        beginning with 0 and ending with ``len(words)``).
    """
    documents = []
    words = []
    labels = []
    sentence_boundaries = []
    with open(dataset_file) as f:
        for line in f:
            line = line.rstrip()
            if line.startswith("-DOCSTART"):
                if words:
                    documents.append(_build_document(words, labels, sentence_boundaries))
                    words = []
                    labels = []
                    sentence_boundaries = []
                continue

            if not line:
                # consecutive blank lines must not record the same boundary twice
                if not sentence_boundaries or len(words) != sentence_boundaries[-1]:
                    sentence_boundaries.append(len(words))
            else:
                items = line.split(" ")
                words.append(items[0])
                labels.append(items[-1])

    if words:
        documents.append(_build_document(words, labels, sentence_boundaries))

    return documents


def load_examples(documents, max_token_length=510, max_mention_length=30):
    """Turn documents into one span-classification example per sentence.

    For each sentence the example holds the context text, the sentence's words
    and every candidate entity span, both as character offsets into the text
    and as half-open word-index ranges within the sentence.

    Relies on the module-level ``tokenizer`` (for subword counts) and on
    ``is_punctuation``.

    Args:
        documents: Dicts with ``words`` and ``sentence_boundaries`` keys, as
            produced by ``load_documents``.
        max_token_length: Subword budget for the context text of one example.
        max_mention_length: Subword limit used to filter candidate spans.

    Returns:
        A list of dicts with keys ``text``, ``words``, ``entity_spans`` and
        ``original_word_spans``.
    """
    examples = []

    for document in tqdm(documents):
        words = document["words"]
        subword_lengths = [len(tokenizer.tokenize(w)) for w in words]
        total_subword_length = sum(subword_lengths)
        sentence_boundaries = document["sentence_boundaries"]

        for i in range(len(sentence_boundaries) - 1):
            sentence_start, sentence_end = sentence_boundaries[i:i+2]
            if total_subword_length <= max_token_length:
                # if the total sequence length of the document is shorter than the
                # maximum token length, we simply use all words to build the sequence
                context_start = 0
                context_end = len(words)
            else:
                # if the total sequence length is longer than the maximum length, we add
                # the surrounding words of the target sentence to the sequence until it
                # reaches the maximum length
                context_start = sentence_start
                context_end = sentence_end
                cur_length = sum(subword_lengths[context_start:context_end])
                # The loop condition stops the expansion once both edges of the
                # document are reached; without it a single over-long sentence
                # spanning the whole document would loop forever.
                while context_start > 0 or context_end < len(words):
                    if context_start > 0:
                        if cur_length + subword_lengths[context_start - 1] <= max_token_length:
                            cur_length += subword_lengths[context_start - 1]
                            context_start -= 1
                        else:
                            break
                    if context_end < len(words):
                        if cur_length + subword_lengths[context_end] <= max_token_length:
                            cur_length += subword_lengths[context_end]
                            context_end += 1
                        else:
                            break

            sentence_words = words[sentence_start:sentence_end]
            sentence_subword_lengths = subword_lengths[sentence_start:sentence_end]

            # Rebuild the context text word by word, recording character
            # offsets only for the words of the target sentence.
            text = ""
            word_start_char_positions = []
            word_end_char_positions = []
            for word_index in range(context_start, context_end):
                word = words[word_index]
                in_sentence = sentence_start <= word_index < sentence_end
                # apostrophe-led words and single punctuation marks attach to
                # the previous word without a separating space
                if word[0] == "'" or (len(word) == 1 and is_punctuation(word)):
                    text = text.rstrip()
                if in_sentence:
                    word_start_char_positions.append(len(text))
                text += word
                if in_sentence:
                    word_end_char_positions.append(len(text))
                text += " "
            text = text.rstrip()

            entity_spans = []
            original_word_spans = []
            for word_start in range(len(sentence_words)):
                for word_end in range(word_start, len(sentence_words)):
                    # NOTE(review): the slice below excludes the span's last word
                    # (word_end) from the length check; kept unchanged so the set
                    # of scored spans stays the same -- confirm whether intended.
                    if sum(sentence_subword_lengths[word_start:word_end]) <= max_mention_length:
                        entity_spans.append(
                            (word_start_char_positions[word_start], word_end_char_positions[word_end])
                        )
                        original_word_spans.append(
                            (word_start, word_end + 1)
                        )

            examples.append(dict(
                text=text,
                words=sentence_words,
                entity_spans=entity_spans,
                original_word_spans=original_word_spans,
            ))

    return examples


def is_punctuation(char):
    """Tell whether a single character counts as punctuation.

    A character qualifies when its code point lies in one of the ASCII symbol
    ranges 33-47, 58-64, 91-96 or 123-126, or when its Unicode general
    category starts with "P".
    """
    code_point = ord(char)
    ascii_symbol_ranges = ((33, 47), (58, 64), (91, 96), (123, 126))
    if any(low <= code_point <= high for low, high in ascii_symbol_ranges):
        return True
    return unicodedata.category(char).startswith("P")

In [8]:
# Parse the downloaded test file into documents, then expand every sentence
# into an example (context text plus candidate entity spans).
test_documents = load_documents("eng.testb")
test_examples = load_examples(test_documents)

100%|██████████| 231/231 [00:03<00:00, 63.02it/s]


In [10]:
batch_size = 2
all_logits = []

# Score every candidate span of every example, a few examples at a time.
for batch_start_idx in trange(0, len(test_examples), batch_size):
    batch_examples = test_examples[batch_start_idx:batch_start_idx + batch_size]
    texts = [example["text"] for example in batch_examples]
    entity_spans = [example["entity_spans"] for example in batch_examples]

    inputs = tokenizer(texts, entity_spans=entity_spans, return_tensors="pt", padding=True)
    # Follow whatever device the model lives on instead of assuming CUDA.
    inputs = inputs.to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)
    # one list of per-span class scores per example, in example order
    all_logits.extend(outputs.logits.tolist())

100%|██████████| 1727/1727 [05:45<00:00,  4.99it/s]


In [11]:
# Gold labels of the whole test set, flattened document by document.
final_labels = []
for document in test_documents:
    final_labels.extend(document["labels"])

final_predictions = []
for example_index, example in enumerate(test_examples):
    logits = all_logits[example_index]
    max_logits = np.max(logits, axis=1)
    max_indices = np.argmax(logits, axis=1)
    original_spans = example["original_word_spans"]

    # Keep only the spans whose best class is not NIL (class 0).
    predictions = [
        (logit, span, model.config.id2label[index])
        for logit, index, span in zip(max_logits, max_indices, original_spans)
        if index != 0
    ]

    # Build the IOB2 sequence greedily: strongest span first, and a span is
    # skipped as soon as any of its words has already been tagged.
    predicted_sequence = ["O"] * len(example["words"])
    for _, span, label in sorted(predictions, key=lambda o: o[0], reverse=True):
        start, end = span
        if any(tag != "O" for tag in predicted_sequence[start:end]):
            continue
        predicted_sequence[start] = "B-" + label
        predicted_sequence[start + 1:end] = ["I-" + label] * (end - start - 1)

    final_predictions.extend(predicted_sequence)

In [12]:
# Entity-level precision / recall / F1 per type; the whole test set is passed
# to seqeval as a single label sequence.
print(seqeval.metrics.classification_report([final_labels], [final_predictions], digits=4)) 

              precision    recall  f1-score   support

         LOC     0.9558    0.9478    0.9518      1666
        MISC     0.8553    0.8688    0.8620       701
         ORG     0.9287    0.9496    0.9391      1647
         PER     0.9683    0.9719    0.9701      1602

   micro avg     0.9386    0.9453    0.9420      5616
   macro avg     0.9270    0.9345    0.9307      5616
weighted avg     0.9389    0.9453    0.9421      5616

