In [11]:
from typing import Mapping, Sequence, Dict, Optional, List, Iterable

import json, pickle
from copy import deepcopy

import spacy
from spacy.tokens import Doc, Span, Token
from utils import (
    ingest_json_document,
    evaluate_and_print,
)


from feature_extractors import (
    BiasFeature,
    TokenFeature,
    DigitFeature,
    UppercaseFeature,
    TitlecaseFeature,
    InitialTitlecaseFeature,
    PunctuationFeature,
    WordShapeFeature,
    LikelyAdjectiveFeature,
    AfterVerbFeature,
    WordVectorFeature,
    BrownClusterFeature,
    WindowedTokenFeatureExtractor,
)

from entity_recognizer import (
    BILOUEncoder,
    CRFsuiteEntityRecognizer,
)

In [2]:
nlp = spacy.load("en_core_web_sm", disable=["ner"])


def load_jsonl_docs(path: str, pipeline) -> List[Doc]:
    """Parse every non-blank line of a JSON-lines corpus into a document via ``ingest_json_document``.

    Args:
        path: Path to a ``.jsonl`` file with one JSON object per line.
        pipeline: The spaCy pipeline handed to ``ingest_json_document`` for each record.

    Returns:
        One ingested document per non-blank line, in file order.

    Blank / whitespace-only lines (e.g. an extra newline at the end of the file)
    are skipped, because ``json.loads("")`` raises ``JSONDecodeError``.
    Undecodable bytes are dropped (``errors='ignore'``), as in the original cells.
    NOTE(review): dropping bytes could shift character offsets if the records
    store span offsets into the text — confirm the corpus is clean UTF-8.
    """
    with open(path, 'r', encoding='utf8', errors='ignore') as corpus_file:
        return [
            ingest_json_document(json.loads(line), pipeline)
            for line in corpus_file
            if line.strip()
        ]


train_docs = load_jsonl_docs('data/corpus_train.jsonl', nlp)
dev_docs = load_jsonl_docs('data/corpus_dev.jsonl', nlp)

In [5]:
# Keep an untouched copy of the gold dev annotations, then strip entities from
# dev_docs so the model predicts on documents without entities.
#
# The snapshot is tied to the identity of the list created by the loading cell.
# Re-running *this* cell after dev_docs has been stripped (or otherwise modified)
# must not overwrite dev_gold; otherwise evaluation would silently compare
# predictions against entity-free "gold". Re-running the loading cell creates a
# new list object, which triggers a fresh snapshot here.
if globals().get("_dev_gold_source") is not dev_docs:
    dev_gold = deepcopy(dev_docs)
    _dev_gold_source = dev_docs
for doc in dev_docs:
    doc.ents = []

In [59]:
word_vector_file_path = "models/wiki-news-300d-1M-subword.magnitude"
brown_cluster_file_path = "models/rcv1.64M-c10240-p1.paths"

# Per-token feature extractors used by the model. The commented-out entries
# are not part of this configuration. The two resource paths above are only
# referenced by the disabled WordVectorFeature / BrownClusterFeature entries.
best_features = [
    BiasFeature(),
    TokenFeature(),
    UppercaseFeature(),
    TitlecaseFeature(),
    # InitialTitlecaseFeature(),
    DigitFeature(),
    PunctuationFeature(),
    WordShapeFeature(),
    # LikelyAdjectiveFeature(),
    # AfterVerbFeature(),
    # WordVectorFeature(word_vector_file_path, 1.0),
    # BrownClusterFeature(
    #     brown_cluster_file_path,
    #     use_full_paths=False,
    #     use_prefixes=True,
    #     prefixes=[4, 6, 10, 20],
    # ),
]

# Windowed extractor over the features above (second argument 2 — presumably
# the context window size; confirm in feature_extractors), paired with a BILOU
# label encoder.
crf_model = CRFsuiteEntityRecognizer(
    WindowedTokenFeatureExtractor(best_features, 2),
    BILOUEncoder(),
)

In [60]:
%%time
crf_model.train(train_docs, "ap", {"max_iterations": 100}, "tmp.model")

Wall time: 3.01 s


In [62]:
# Run the trained recognizer over each entity-stripped dev document, then score
# the predictions against the gold snapshot `dev_gold`.
dev_predicted = list(map(crf_model, dev_docs))
evaluate_and_print(dev_gold, dev_predicted)

Type	Prec	Rec	F1
ALL	62.24	55.35	58.59
AVATAR	57.81	46.84	51.75
GAME	88.89	88.89	88.89
ORG	80	60	68.57
PLAYER	46.71	46.99	46.85
SPONS	0	0	0
TOURN	52.83	71.79	60.87
