In [None]:
############################################
## Entity Tagging using Flair for NYT corpus
############################################

# Setup instructions
# 1. Download the Flair NER model "en-ner-conll03-v0.4.pt" into the directory this notebook is run from (it is loaded via a relative path).
# 2. Install flair: pip install flair



In [20]:
import re
import json
from collections import defaultdict
from flair.models import SequenceTagger
from flair.tokenization import SegtokSentenceSplitter

# Load the pre-trained NER model from a local file and build the sentence
# splitter. Both objects are read as module-level globals by get_tagged_field.
NER_MODEL_PATH = 'en-ner-conll03-v0.4.pt'
tagger = SequenceTagger.load(NER_MODEL_PATH)
splitter = SegtokSentenceSplitter()

def get_tagged_field(field_name, stories, max_story_len=1500):
    """Tag PER/ORG entities in one text field of every story.

    The text stored under ``field_name`` is split into sentences with the
    module-level ``splitter`` and tagged with the module-level ``tagger``.
    Each story dict is updated in place:

    * ``story[field_name]`` -- the sentences re-joined with single spaces.
    * ``story[field_name + "_sentences"]`` -- list of sentence strings.
    * ``story[field_name + "_entities"]`` -- list of dicts with keys
      ``surface``, ``type``, ``startCharOffset`` and ``endCharOffset``.
      Offsets index into the re-joined text; entities whose offsets do not
      reproduce their surface string are dropped.

    Progress, the average token length of the kept stories and the entity
    type distribution are printed to stdout.

    Args:
        field_name: Key of the text field to tag in each story dict.
        stories: Iterable of story dicts; each dict is mutated in place.
        max_story_len: Maximum number of tokens (summed over sentences) a
            story may have to be kept. Defaults to 1500.

    Returns:
        List of the story dicts whose token count is <= ``max_story_len``.
    """
    type_ctr = defaultdict(int)
    overall_story_len = 0
    selected_stories = []
    for idx, story_data in enumerate(stories):

        if idx % 100 == 0:
            print(idx, " document processed.")

        text = story_data[field_name].replace("\n\n", " ")

        # use splitter to split text into list of sentences
        sentences = splitter.split(text)

        # predict tags for sentences (annotates the sentence objects in place)
        tagger.predict(sentences)

        entities = []
        sentence_text = []
        story_len = 0
        # Character position of the current sentence inside the re-joined text.
        story_offset = 0
        for sentence in sentences:
            original_text = sentence.to_original_text()
            story_len += len(sentence.tokens)
            sentence_text.append(original_text)
            for entity in sentence.get_spans('ner'):
                entity_label = entity.labels[0].value
                if entity_label in ("ORG", "PER"):
                    start = entity.start_pos + story_offset
                    entities.append({"surface": entity.text,
                                     "type": entity_label,
                                     "startCharOffset": start,
                                     "endCharOffset": start + len(entity.text)})
            # +1 accounts for the single space inserted by " ".join below.
            story_offset += len(original_text) + 1

        joined_text = " ".join(sentence_text).strip()

        # Keep only entities whose offsets reproduce their surface string.
        # A new list is built instead of deleting while iterating, which
        # would skip the element following every deleted one.
        entities = [
            entity for entity in entities
            if joined_text[entity["startCharOffset"]:entity["endCharOffset"]] == entity["surface"]
        ]

        story_data[field_name] = joined_text
        story_data[field_name + "_sentences"] = sentence_text
        story_data[field_name + "_entities"] = entities

        if story_len <= max_story_len:
            selected_stories.append(story_data)
            overall_story_len += story_len
            for entity in entities:
                type_ctr[entity["type"]] += 1

    # Guard against division by zero when no story was kept.
    if selected_stories:
        average_story_len = overall_story_len / len(selected_stories)
    else:
        average_story_len = 0.0
    print("Average Story Length: ", average_story_len)
    print("Entity Distribution: ", type_ctr)
    return selected_stories



In [19]:
# Read the NYT corpus (one JSON object per line), tag entities in the article
# and abstract text of every story, and write the surviving stories as JSONL.

DATA_FNAME = "<Input jsonl file>"
OUT_FNAME = "<Output jsonl file with entity information>"

# The context manager closes the input file even if a line fails to parse.
stories = []
with open(DATA_FNAME, encoding="utf-8") as fin:
    for line in fin:
        if not line.strip():
            continue  # tolerate blank lines (e.g. a trailing newline at EOF)
        stories.append(json.loads(line))
print("Number of stories: ", len(stories))

# Each call returns only the stories that pass its length filter, so the
# second call runs on the stories kept by the first.
stories = get_tagged_field("article_text", stories)
print("Done article text")
stories = get_tagged_field("abstract_text", stories)
print("Done abstract text")

# newline="\n" keeps the output byte-identical across platforms (no "\r\n").
with open(OUT_FNAME, "w", encoding="utf-8", newline="\n") as fout:
    for story_data in stories:
        fout.write(json.dumps(story_data))
        fout.write("\n")
    