In [1]:
#!/usr/bin/env python
# coding: utf8
"""Derived from https://raw.githubusercontent.com/explosion/spaCy/master/examples/training/train_textcat.py
"""
import plac
import random
from pathlib import Path
from tqdm import tqdm
import json

import spacy
from spacy.util import minibatch, compounding
import re
import mwparserfromhell as mwparser
from smart_open import smart_open
import dill
from itertools import islice, groupby
from operator import itemgetter


In [2]:
def read_ndjson(file):
    """Load lines from new-line delimited JSON file"""
    for line in file:
        yield json.loads(line)

In [3]:
def has_pov(doc):
    return bool(len(doc.filter_templates(matches="POV", recursive=True, flags=re.I)))

In [4]:
def load_wp10(filename):
    with smart_open(filename, "r") as f:
        for doc in tqdm(read_ndjson(f)):
            if doc['wp10'] in ("FA", "GA"):
                text = doc['wikitext']
                parsed = mwparser.parse(text)
                yield (wiki2plaintext(parsed), has_pov(parsed))
                
def load_npov(filename):
    with smart_open(filename, "r") as f:
        for doc in tqdm(read_ndjson(f)):
            text = doc['content']
            parsed = mwparser.parse(text)
            yield (wiki2plaintext(parsed), has_pov(parsed))                

In [5]:
def wiki2plaintext(parsed):
    """Convert Wiki markup to plain text."""
    re_image_wl = re.compile('^(?:File|Image|Media):', flags=re.I | re.U)
    bad_template_names = {
        'reflist', 'notelist', 'notelist-ua', 'notelist-lr', 'notelist-ur', 'notelist-lg'}
    bad_tags = {'ref', 'table'}

    def is_bad_wikilink(obj):
        return bool(re_image_wl.match(str(obj.title)))

    def is_bad_tag(obj):
        return str(obj.tag) in bad_tags

    def is_bad_template(obj):
        return obj.name.lower() in bad_template_names


    texts = []
    # strip out references, tables, and file/image links
    # then concatenate the stripped text of each section
    sections = parsed.get_sections(flat=True, include_lead=True,
                                     include_headings=False)
    for i, section in enumerate(sections):
        for obj in section.ifilter_wikilinks(matches=is_bad_wikilink,
                                             recursive=True):
            try:
                section.remove(obj)
            except Exception:
                continue
        for obj in section.ifilter_templates(matches=is_bad_template,
                                             recursive=True):
            try:
                section.remove(obj)
            except Exception:
                continue
        for obj in section.ifilter_tags(matches=is_bad_tag, recursive=True):
            try:
                section.remove(obj)
            except Exception:
                continue
        texts.append(section.strip_code().strip())

    return '\n\n'.join(texts)


In [20]:
filename_wp10 = "../data/enwiki.labeling_revisions.w_text.nettrom_30k.ndjson.gz"
filename_npov = "../data/npov.ndjson.gz"
corpus = list(itertools.chain(load_wp10(filename_wp10), load_npov(filename_npov)))

with open("../data/npov-experiment-corpus.pkl", "wb") as f:
    dill.dump(corpus, f)

32424it [16:34, 32.60it/s]  
6467it [05:19, 20.23it/s]


In [36]:
with open("../data/npov-experiment-corpus.pkl", "rb") as f:
    corpus = dill.load(f)

In [31]:
split = 0.8
train_data = []
dev_data = []
for pov, g in itertools.groupby(corpus, itemgetter(1)):
    # POV = TRUE
    groupvals = list(g)
    limit = int(len(groupvals) * split)
    for i, x in enumerate(groupvals):
        if i <= limit:
            train_data.append((x[0], {'cats': {'POV': bool(pov)}}))
        else:
            dev_data.append((x[0], {'POV': bool(pov)}))
random.shuffle(train_data)
random.shuffle(dev_data)   
dev_texts, dev_cats = list(zip(*dev_data))

In [32]:
# output directory
output_dir = "../data"
# model name
model = "en_core_web_lg"
nlp = spacy.load(model)
print("Loaded model '%s'" % model)

Loaded model 'en_core_web_lg'


In [33]:
# add the text classifier to the pipeline if it doesn't exist
# nlp.create_pipe works for built-ins that are registered with spaCy
if 'textcat' not in nlp.pipe_names:
    textcat = nlp.create_pipe('textcat')
    nlp.add_pipe(textcat, last=True)
# otherwise, get it, so we can add labels to it
else:
    textcat = nlp.get_pipe('textcat')

# add label to text classifier
textcat.add_label('POV')

1

In [34]:
n_train = len(train_data)
n_dev = len(dev_data)
n_texts = n_train + n_dev
print(f"Using {n_texts} examples ({n_train} training, {n_dev} evaluation)")


Using 16958 examples (13754 training, 3204 evaluation)


In [35]:
# number of training iterations
n_iter = 2
drop = 0.2
size = compounding(4., 32., 1.001)

# get names of other pipes to disable them during training
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != 'textcat']
with nlp.disable_pipes(*other_pipes):  # only train textcat
    optimizer = nlp.begin_training()
    print("Training the model...")
    print('{:^5}\t{:^5}\t{:^5}\t{:^5}'.format('LOSS', 'P', 'R', 'F'))
    for i in range(n_iter):
        losses = {}
        # batch up the examples using spaCy's minibatch
        batches = minibatch(train_data, size=size)
        for batch in batches:
            texts, annotations = zip(*batch)
            nlp.update(texts, annotations, sgd=optimizer, drop=drop, losses=losses)
        with textcat.model.use_params(optimizer.averages):
            # evaluate on the dev data split off in load_data()
            scores = evaluate(nlp.tokenizer, textcat, zip(*dev_data))
        print('{0:.3f}\t{1:.3f}\t{2:.3f}\t{3:.3f}'  # print a simple table
              .format(losses['textcat'], scores['textcat_p'],
                      scores['textcat_r'], scores['textcat_f']))

Training the model...
LOSS 	  P  	  R  	  F  


KeyboardInterrupt: 

In [None]:
# test the trained model
test_text = "This movie sucked"
doc = nlp(test_text)
print(test_text, doc.cats)

if output_dir is not None:
    output_dir = Path(output_dir)
    if not output_dir.exists():
        output_dir.mkdir()
    nlp.to_disk(output_dir)
    print("Saved model to", output_dir)

    # test the saved model
    print("Loading from", output_dir)
    nlp2 = spacy.load(output_dir)
    doc2 = nlp2(test_text)
    print(test_text, doc2.cats)

In [None]:
def evaluate(tokenizer, textcat, texts, cats):
    docs = (tokenizer(text) for text in texts)
    tp = 1e-8  # True positives
    fp = 1e-8  # False positives
    fn = 1e-8  # False negatives
    tn = 1e-8  # True negatives
    for i, doc in enumerate(textcat.pipe(docs)):
        gold = cats[i]
        for label, score in doc.cats.items():
            if label not in gold:
                continue
            if score >= 0.5 and gold[label] >= 0.5:
                tp += 1.
            elif score >= 0.5 and gold[label] < 0.5:
                fp += 1.
            elif score < 0.5 and gold[label] < 0.5:
                tn += 1
            elif score < 0.5 and gold[label] >= 0.5:
                fn += 1
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f_score = 2 * (precision * recall) / (precision + recall)
    return {'textcat_p': precision, 'textcat_r': recall, 'textcat_f': f_score}