In [1]:
import os.path

LINKS_DATA = "../data/links-v3.tsv"
OUTPUT_FOLDER = "../data/v4"
PROCESS_COUNT = 16

In [2]:
import typing


def make_range(s: str) -> typing.Tuple[int, int]:
    r = [int(i) for i in s.split("-")[:2]]
    return r[0], r[1]


def read_lines(filename):
    with open(filename, "r", encoding="utf-8") as f:
        for line in f.readlines():
            try:
                line = line.strip().split("\t")
                para_id = line[0]
                text = line[1]
                ranges = [make_range(s) for s in line[2:]]
                yield dict(para_id=para_id, text=text.lower(), ranges=ranges)
            except Exception as e:
                print(e)
                print(line)
                raise


def line_count(filename):
    with open(filename, "r", encoding="utf-8") as f:
        return len(f.readlines())


In [3]:
data_count = line_count(LINKS_DATA)
data_count


414289

In [4]:
import sys

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from chategw.util import split_every
import tqdm
import os.path
import spacy
from spacy.tokens import DocBin, Doc

spacy.require_cpu()
nlp = spacy.load("en_core_web_lg")
if not os.path.exists(OUTPUT_FOLDER):
    os.mkdir(OUTPUT_FOLDER, mode=0o755)

In [5]:
allowed_labels = {'PERSON', 'DATE', 'LOC', 'GPE'}


def read_docs(count: int, lines: typing.Iterable[str]) -> typing.Iterable[Doc]:
    for chunk in split_every(count, tqdm.tqdm(lines, total=data_count, desc="Loading paragraphs", unit="paragraph")):
        docs = nlp.pipe([r['text'] for r in chunk], n_process=PROCESS_COUNT, batch_size=512)
        for doc, row in zip(docs, chunk):
            raw_entities = [ent for ent in doc.ents if ent.label_ in allowed_labels]
            entities = []
            for link in row['ranges']:
                entities.append(doc.char_span(link[0], link[1], label="REFERENCE"))
            has_ranges = bool(entities)
            has_entities = False
            for e in raw_entities:
                can_add = True
                for added_entity in entities:
                    if (added_entity.start - 1 <= e.start <= added_entity.end + 1) or \
                            (added_entity.start - 1 <= e.end <= added_entity.end + 1):
                        can_add = False
                        break
                if can_add:
                    entities.append(e)
                    has_entities = True
            try:
                if has_ranges and has_entities:
                    doc.ents = entities
                    yield doc
            except ValueError:
                continue

In [6]:
all_docs = []
for row in read_docs(16_000, read_lines(LINKS_DATA)):
    all_docs.append(row)

Loading paragraphs: 100%|██████████| 414289/414289 [06:26<00:00, 1071.89paragraph/s]


In [7]:
len(all_docs)

226110

In [8]:
from sklearn.model_selection import train_test_split
train_set, validation_set = train_test_split(all_docs, test_size=0.25)
validation_set, test_set = train_test_split(validation_set, test_size=0.3)

In [9]:
def save_docs(docs : typing.List[Doc], filename: str):
    db = DocBin()
    for r in docs:
        db.add(r)
    db.to_disk(filename)

In [10]:
save_docs(train_set, os.path.join(OUTPUT_FOLDER, "train.spacy"))
save_docs(validation_set, os.path.join(OUTPUT_FOLDER, "dev.spacy"))
save_docs(test_set, os.path.join(OUTPUT_FOLDER, "test.spacy"))