In [86]:
import random

import spacy
from spacy.util import minibatch, compounding
from spacy.kb import KnowledgeBase

from src.utils import LanguageService

In [73]:
nlp = spacy.load("en_core_web_md")
vocab = nlp.vocab
# The KB stores Span.vector values (see get_word_vector_from_doc), so size it from the
# loaded model's word-vector width instead of hard-coding 300, which only matches models
# whose vectors happen to be 300-dimensional.
# NOTE(review): assumes the docs whose span vectors are added to the KB use vectors of the
# same width as this model (e.g. the doc returned by LanguageService) — confirm.
kb = KnowledgeBase(vocab=vocab, entity_vector_length=vocab.vectors_length)

In [74]:
URL = "https://www.cnbc.com/2020/09/22/palantir-says-it-expects-42percent-revenue-growth-this-year-to-1point06-billion.html"
ls = LanguageService()
# NOTE(review): download_article is project code (src.utils); from its use below it appears
# to return a spaCy Doc with sentence boundaries set — confirm.
doc = ls.download_article(URL)
# Normalise whitespace in each sentence: every run of whitespace (including newlines)
# becomes a single space, and leading/trailing whitespace is dropped. Deleting newlines
# outright would merge the two words around a line break into one token-less blob, hiding
# entity names from the exact-text search used to build the training data.
sentences = [" ".join(sent.text.split()) for sent in doc.sents]

In [75]:
# Seed entities for the knowledge base. Each entry holds:
#   name   - literal text searched for in the article sentences
#   label  - NER label used for the training annotations
#   KB_QID - Wikidata QID, used as the entity ID in the KnowledgeBase
ENTITIES = [
    {"name": name, "label": label, "KB_QID": qid}
    for name, label, qid in (
        ("Palantir", "ORG", "Q2047336"),
        ("Alex Karp", "PERSON", "Q19560940"),
        ("Elysee Palace", "FAC", "Q188190"),
        ("Paris", "GPE", "Q90"),
    )
]


def find_all_mention_spans(text, substring):
    """Return the character offsets of every literal occurrence of `substring` in `text`.

    Occurrences are found left to right and do not overlap (``re.finditer`` semantics).

    :param text: the text to search.
    :param substring: the literal string to look for. It is escaped before matching, so
        regex metacharacters in an entity name (``$``, ``.``, ``(`` ...) are matched
        literally instead of being interpreted as a pattern.
    :return: list of ``(start, end)`` tuples; empty when there is no occurrence or when
        `substring` is empty (an empty pattern would otherwise match at every position).
    """
    import re

    if not substring:
        return []
    return [(m.start(), m.end()) for m in re.finditer(re.escape(substring), text)]


def generate_training_data(entities, sentences):
    """Build NER training examples from exact-text entity mentions.

    For each sentence, every literal occurrence of each entity's ``name`` becomes a
    ``(start, end, label)`` annotation; sentences without any mention are skipped.

    :param entities: iterable of dicts with at least ``"name"`` and ``"label"`` keys
        (same shape as ``ENTITIES``).
    :param sentences: iterable of sentence strings.
    :return: list of ``(sentence, {"entities": [(start, end, label), ...]})`` tuples, the
        (text, annotations) format passed to ``nlp.update``.
    """
    train_data = []
    for sentence in sentences:
        # Use the `entities` argument (not the global ENTITIES) so callers control
        # which entities are annotated.
        # NOTE(review): if one entity name contains another, their spans overlap, which
        # spaCy rejects as conflicting annotations — keep names non-overlapping.
        mentions = [
            (start, end, entity["label"])
            for entity in entities
            for start, end in find_all_mention_spans(sentence, entity["name"])
        ]
        if mentions:
            train_data.append((sentence, {"entities": mentions}))
    return train_data

In [76]:
TRAIN_DATA = generate_training_data(ENTITIES, sentences)
# Fail here, with a clear message, instead of letting the training loop run on nothing.
assert TRAIN_DATA, "no ENTITIES name occurs verbatim in any sentence; nothing to train on"

In [79]:
def get_word_vector_from_doc(word, doc):
    """Return the vector of the first token-aligned occurrence of `word` in `doc`.

    :param word: literal text to look up (e.g. an entity name).
    :param doc: a spaCy ``Doc``; its ``.text`` is searched and ``Doc.char_span`` turns a
        match into a ``Span``.
    :return: ``Span.vector`` of the first occurrence whose character offsets fall on
        token boundaries.
    :raises RuntimeError: if `word` does not occur in ``doc.text``, or if none of its
        occurrences line up with token boundaries.
    """
    spans = find_all_mention_spans(doc.text, word)
    if len(spans) == 0:
        raise RuntimeError(f"{word} not found in document")
    # Doc.char_span returns None when the offsets cut through a token (e.g. the match is
    # part of a longer word). Try every occurrence rather than only the first, and fail
    # with a clear message instead of an AttributeError on None.
    for start, end in spans:
        span = doc.char_span(start, end)
        if span is not None:
            return span.vector
    raise RuntimeError(f"{word} found in document but never on token boundaries")

In [80]:
# Register each seed entity in the KnowledgeBase under its Wikidata QID.
# IDs already present are skipped, so re-running this cell does not add them twice.
existing_qids = set(kb.get_entity_strings())
for entity in ENTITIES:
    qid = entity['KB_QID']
    if qid in existing_qids:
        continue
    name = entity['name']
    kb.add_entity(
        entity=qid,
        # frequency = number of literal occurrences of the name in the article text
        freq=doc.text.count(name),
        entity_vector=get_word_vector_from_doc(name, doc),
    )
    existing_qids.add(qid)

  kb.add_entity(


In [87]:
# Fine-tune only the NER component on TRAIN_DATA: every other pipe is disabled during
# the updates.
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "ner"]
n_iter = 10

with nlp.disable_pipes(*other_pipes):
    for _ in range(n_iter):
        random.shuffle(TRAIN_DATA)  # new example order each pass (shuffles in place)
        losses = {}
        # batch size starts at 4 and grows by a factor of 1.001 per batch, capped at 32
        batches = minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
        for batch in batches:
            texts, annotations = zip(*batch)
            nlp.update(texts, annotations, drop=0.35, losses=losses)
        print("Losses", losses)

# Sanity-check the trained model. NOTE: these are the training sentences, not a held-out
# set, so this only shows the model reproduces its own training labels.
# A separate name is used so the article `doc` from the download cell is not overwritten
# with the last training sentence (the KB cell reads `doc`).
for text, _ in TRAIN_DATA:
    sample_doc = nlp(text)
    print("Entities", [(ent.text, ent.label_) for ent in sample_doc.ents])
    print("Tokens", [(t.text, t.ent_type_, t.ent_iob) for t in sample_doc])
    print()

Losses {'ner': 35.909622068924364}
Losses {'ner': 26.541565783525584}
Losses {'ner': 28.069616295517562}
Losses {'ner': 18.780868357628048}
Losses {'ner': 5.333915633486223}
Losses {'ner': 9.809165925092229}
Losses {'ner': 3.1300867514058615}
Losses {'ner': 4.783773336845864}
Losses {'ner': 2.987163951148775}
Losses {'ner': 2.008933871972911}
Entities [('Palantir', 'ORG')]
Tokens [('Palantir', 'ORG', 3), ("'s", '', 2), ('update', '', 2), ('shows', '', 2), ('that', '', 2), ('revenue', '', 2), ('growth', '', 2), ('is', '', 2), ('accelerating', '', 2), ('from', '', 2), ('2019', '', 2), (',', '', 2), ('when', '', 2), ('the', '', 2), ('company', '', 2), ('reported', '', 2), ('a', '', 2), ('25', '', 2), ('%', '', 2), ('increase', '', 2), ('to', '', 2), ('$', '', 2), ('742.6', '', 2), ('million', '', 2), ('.', '', 2)]

Entities [('Palantir', 'ORG')]
Tokens [(':', '', 2), ('Palantir', 'ORG', 3), ('prepares', '', 2), ('to', '', 2), ('go', '', 2), ('public', '', 2)]

Entities [('Palantir', 'ORG'

In [89]:
# Run the fine-tuned pipeline over the whole article, rebuilt from the cleaned sentences
# (the same text the NER training examples were drawn from).
d2 = nlp(" ".join(sentences))

In [91]:
# Entities the fine-tuned model finds in the full article text.
d2.ents

(Palantir,
 Alex Karp,
 the Elysee Palace,
 Paris,
 Palantir,
 Palantir,
 Palantir,
 Palantir,
 Palantir,
 Palantir,
 9.17,
 Palantir)

## Next steps (planned workflow)

1. Run the article through the model, noting the extracted entities.
    - For entities that don't resolve to a KnowledgeBase entry, ask the user to look them up on Wikidata.

2. Add the new entities to the KnowledgeBase.

3. Retrain the model.

4. Re-run the article through the model.


In [52]:
# 1. Run the article through the model, noting extracted entities.
# `article` is not assigned anywhere else in this notebook, so the cell depended on a
# leftover kernel variable. Rebuild it from the cleaned sentences.
article = " ".join(sentences)
doc = nlp(article)

#    - For those that don't resolve to a KnowledgeBase entity, ask the user to find them
#      on Wikidata.
# NOTE(review): the pipeline appears to have no entity linker, so kb_id_ is '' even for
# entities already in the KB. Skip mention texts that contain, or are contained in, a
# curated ENTITIES name. This avoids re-asking about known entities. It also avoids
# adding overlapping names (e.g. "the Elysee Palace" vs "Elysee Palace"), which would
# produce conflicting NER annotations. Each distinct text is asked about only once.
known_names = [entity["name"] for entity in ENTITIES]
unresolved_texts = list(dict.fromkeys(
    ent.text
    for ent in doc.ents
    if ent.kb_id_ == ""
    and not any(name in ent.text or ent.text in name for name in known_names)
))

unrecognized_entities = []
for text in unresolved_texts:
    qid = input(f"Wikidata QID for {text} (press Enter to skip): ")
    if qid != "":
        # TODO: determine from Wikidata SPARQL query on the QID doc
        determined_label = input(f"NER annotation tag for {text}: ")
        unrecognized_entities.append({"name": text, "label": determined_label, "KB_QID": qid})


# 2. Add the new entities to the KnowledgeBase.
ner = nlp.get_pipe("ner")
existing_qids = set(kb.get_entity_strings())
for entity in unrecognized_entities:
    name = entity["name"]
    qid = entity["KB_QID"]
    # The user may have typed a label the NER model has not seen; register it before
    # retraining (adding a label the model already has does nothing).
    ner.add_label(entity["label"])
    if qid in existing_qids:
        continue  # already in the KB, e.g. when this cell is re-run
    kb.add_entity(
        entity=qid,
        freq=doc.text.count(name),
        entity_vector=get_word_vector_from_doc(name, doc),
    )
    existing_qids.add(qid)


# 3. Retrain the model.
# Fold the new entities into ENTITIES and rebuild TRAIN_DATA. Otherwise retraining would
# only repeat the original examples and never see the newly labeled mentions.
ENTITIES += unrecognized_entities
TRAIN_DATA = generate_training_data(ENTITIES, sentences)

other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "ner"]
N_ITER = 10

with nlp.disable_pipes(*other_pipes):
    for _ in range(N_ITER):
        random.shuffle(TRAIN_DATA)
        losses = {}
        batches = minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
        for batch in batches:
            texts, annotations = zip(*batch)
            nlp.update(texts, annotations, drop=0.35, losses=losses)
        print("Losses", losses)


# 4. Re-run the article through the retrained model.
doc = nlp(article)

In [54]:
# Inspect whether the first recognized entity carries a knowledge-base ID ('' = not linked).
doc.ents[0].kb_id_

''