#### We want to determine which questions will be answered with a single PERSON. In this notebook, we will:
0. Load in positives and negatives and format into a training and validation set
1. Train and save a CNN text classifier using spaCy and thinc

##### Step 0. Load and format the training data

In [1]:
import pickle
from sklearn.model_selection import train_test_split

# Read the pickled example records; each record is indexed by 'question' below.
# NOTE(review): pickle.load executes arbitrary code from the file -- only use
# it on data files you produced yourself.
with open("../data/positives.pkl", "rb") as positives_file:
    positives = pickle.load(positives_file)
with open("../data/negatives.pkl", "rb") as negatives_file:
    negatives = pickle.load(negatives_file)

# Question text for each class.
pos_texts = [record['question'] for record in positives]
neg_texts = [record['question'] for record in negatives]

# One fresh label dict per example (a separate dict object for every record).
pos_cats = [{"POSITIVE": True, "NEGATIVE": False} for _ in positives]
neg_cats = [{"POSITIVE": False, "NEGATIVE": True} for _ in negatives]

# Hold out 20% of the combined examples as the dev set; the fixed
# random_state makes the split repeatable across runs.
all_texts = pos_texts + neg_texts
all_cats = pos_cats + neg_cats
train_texts, dev_texts, train_cats, dev_cats = train_test_split(
    all_texts,
    all_cats,
    test_size=0.2,
    random_state=1,
)

##### Step 1. Train and save the CNN text classifier

In [4]:
# Step 1. Train the model
import os
import random
from pathlib import Path

import spacy
from spacy.util import minibatch, compounding


def evaluate(tokenizer, textcat, texts, cats):
    """Score the text categorizer on held-out examples.

    Each text is passed through ``tokenizer`` and the results are streamed
    through ``textcat.pipe``. For every predicted label that also appears in
    the matching gold dict (except "NEGATIVE", which is skipped), the score
    and the gold value are each compared against 0.5 and the outcome is
    tallied as a true/false positive/negative.

    Args:
        tokenizer: Callable applied to each text before it reaches the pipe.
        textcat: Object whose ``pipe(docs)`` yields items exposing a ``cats``
            mapping of label -> score.
        texts: Iterable of input texts.
        cats: Sequence of gold label dicts, indexed in the same order as
            ``texts``.

    Returns:
        Dict with keys "textcat_p", "textcat_r" and "textcat_f" holding
        precision, recall and F-score.
    """
    true_pos = 0.0
    # The tiny starting values keep the divisions below away from zero.
    false_pos = 1e-8
    false_neg = 1e-8
    true_neg = 0.0  # tallied but not part of the returned metrics

    token_stream = (tokenizer(text) for text in texts)
    for position, doc in enumerate(textcat.pipe(token_stream)):
        gold = cats[position]
        for label, score in doc.cats.items():
            if label not in gold or label == "NEGATIVE":
                continue
            if score >= 0.5:
                if gold[label] >= 0.5:
                    true_pos += 1.0
                elif gold[label] < 0.5:
                    false_pos += 1.0
            elif score < 0.5:
                if gold[label] < 0.5:
                    true_neg += 1
                elif gold[label] >= 0.5:
                    false_neg += 1

    precision = true_pos / (true_pos + false_pos)
    recall = true_pos / (true_pos + false_neg)
    f_score = (
        0.0
        if (precision + recall) == 0
        else 2 * (precision * recall) / (precision + recall)
    )
    return {"textcat_p": precision, "textcat_r": recall, "textcat_f": f_score}


# Build a blank English pipeline and attach a text categorizer to it.
nlp = spacy.blank("en")

if "textcat" not in nlp.pipe_names:
    textcat = nlp.create_pipe(
        "textcat", config={"exclusive_classes": True, "architecture": "simple_cnn"}
    )
    nlp.add_pipe(textcat, last=True)
else:
    textcat = nlp.get_pipe("textcat")

# Register the two labels used in the training annotations.
textcat.add_label("POSITIVE")
textcat.add_label("NEGATIVE")

# Pair each text with its annotation dict in the {"cats": {...}} form that
# nlp.update receives below.
train_data = list(zip(train_texts, [{"cats": cats} for cats in train_cats]))
# get names of other pipes to disable them during training
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "textcat"]
with nlp.disable_pipes(*other_pipes):  # only train textcat
    optimizer = nlp.begin_training()
    print("Training the model...")
    print("{:^5}\t{:^5}\t{:^5}\t{:^5}".format("LOSS", "P", "R", "F"))
    # Batch-size generator: starts at 4 and compounds by 1.001 up to 32. It is
    # created once, outside the epoch loop, so it keeps growing across epochs.
    batch_sizes = compounding(4.0, 32.0, 1.001)
    for i in range(10):
        losses = {}
        random.shuffle(train_data)
        batches = minibatch(train_data, size=batch_sizes)
        for batch in batches:
            texts, annotations = zip(*batch)
            nlp.update(texts, annotations, sgd=optimizer, drop=0.2, losses=losses)
        # Evaluate on the dev set with the optimizer's averaged parameters
        # swapped in for the duration of the block.
        with textcat.model.use_params(optimizer.averages):
            scores = evaluate(nlp.tokenizer, textcat, dev_texts, dev_cats)
        print(
            "{0:.3f}\t{1:.3f}\t{2:.3f}\t{3:.3f}".format(  # print a simple table
                losses["textcat"],
                scores["textcat_p"],
                scores["textcat_r"],
                scores["textcat_f"],
            )
        )
        # Save a checkpoint for every epoch so the best one can be picked later.
        folder_name = f"../models/epoch{i}"
        # makedirs(..., exist_ok=True) rather than mkdir: the cell must survive
        # being re-run (directory already exists) and a fresh checkout where
        # ../models has not been created yet. With plain mkdir either case
        # raised only after a full epoch of training had already been spent.
        os.makedirs(folder_name, exist_ok=True)
        with nlp.use_params(optimizer.averages):
            nlp.to_disk(folder_name)

Training the model...
LOSS 	  P  	  R  	  F  
4.270	0.706	0.805	0.752
0.167	0.717	0.806	0.759
0.151	0.728	0.800	0.762
0.133	0.731	0.791	0.760
0.119	0.731	0.787	0.758
0.102	0.729	0.776	0.752
0.086	0.729	0.769	0.749
0.072	0.734	0.771	0.752
0.061	0.733	0.755	0.744
0.054	0.729	0.743	0.736


##### Step 2. Randomly sample results from textcat on test data

In [14]:
import pickle
# Load the raw test records; each record is indexed by 'question' and
# 'answer' below.
with open('../data/raw_test.pkl', 'rb') as test_file:
    raw_test = pickle.load(test_file)

# Lazy stream of question strings, consumed once by the pipe call below.
raw_test_questions = (item['question'] for item in raw_test)

# Model checkpoint written after the third training epoch (epoch index 2).
nlp_infer = spacy.load("../models/epoch2")

# Pair every question with its full record so the record comes back out of
# the pipeline next to the processed doc.
scored_records = nlp_infer.pipe(zip(raw_test_questions, raw_test), as_tuples=True)
for position, (doc, context) in enumerate(scored_records):
    print(f"Question: {context['question']}")
    print(f"Answer: {context['answer']}")
    print(f"Score: {doc.cats}" + "\n")
    # NOTE(review): the limit is checked after printing, so positions 0..10
    # (11 examples) are shown before the loop stops -- confirm 11 is intended.
    if position >= 10:
        break

Question: 'This type of yoga is Sanskrit for "discipline of force" & it's better than none'
Answer: hatha yoga
Score: {'POSITIVE': 9.62425401667133e-05, 'NEGATIVE': 0.9999037981033325}

Question: '4 treaties to mitigate the horrors of war were signed in this city in August 1949'
Answer: Geneva
Score: {'POSITIVE': 7.687673496548086e-05, 'NEGATIVE': 0.9999231100082397}

Question: 'On Dec. 13, 1937 Japan took over the city of Nanking in this Asian country after heavy fighting'
Answer: China
Score: {'POSITIVE': 0.00034473929554224014, 'NEGATIVE': 0.9996552467346191}

Question: 'It's the island where Fay Wray first encountered King Kong; to think of its name, use your "head"'
Answer: Skull Island
Score: {'POSITIVE': 0.0018007828621193767, 'NEGATIVE': 0.9981992840766907}

Question: 'The Metropolitan Museum of Art paid a record $143,352 for the oldest hand-painted complete deck of these'
Answer: Playing cards
Score: {'POSITIVE': 0.001831972156651318, 'NEGATIVE': 0.9981679916381836}

Question:

##### Step 3: Investigate the Positive hits from textcat on test data

In [None]:
import pickle
# Load the raw test records; each record is indexed by 'question' below.
with open('../data/raw_test.pkl', 'rb') as test_file:
    raw_test = pickle.load(test_file)

# Lazy stream of question strings, consumed once by the pipe call below.
raw_test_questions = (item['question'] for item in raw_test)

nlp_infer = spacy.load("../models/epoch2")

# Keep (scores, record) for every test question whose POSITIVE score is
# strictly above 0.5.
hits = [
    (doc.cats, context)
    for doc, context in nlp_infer.pipe(
        zip(raw_test_questions, raw_test), as_tuples=True
    )
    if doc.cats['POSITIVE'] > 0.5
]