# Question Parsing
Many techniques can be used to parse questions, but none is trivial and none has good coverage.  
Instead, we will write a set of rules and monitor their coverage and overlap on the training data.  
We begin by extracting entities or key terms. Attributes, predicates, and prepositions will follow later.

In [3]:
# Load all questions 
import json
import spacy
import numpy as np
nlp = spacy.load('en_core_web_lg')
from pytorch_pretrained_bert.tokenization import BertTokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [82]:
for word in nlp.Defaults.stop_words:
    lex = nlp.vocab[word]
    lex.is_stop = True

In [2]:
with open("dataset/train-v2.0.json", 'r') as handle:
    jdata = json.load(handle)
    data = jdata['data']
contexts = []
questions = []
unanswerable = []
answerable = []
for i in range(len(data)):
    section = data[i]['paragraphs']
    for sec in section:
        context = sec['context']
        contexts.append(context)
        qas = sec['qas']
        for j in range(len(qas)):
            question = qas[j]['question']
            questions.append(question)
            label = qas[j]['is_impossible']
            if label:
                unanswerable.append((len(contexts)-1, len(questions)-1))
            else:
                answerable.append((len(contexts)-1, len(questions)-1))
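
The index pairs built above are `(context_index, question_index)` tuples, split by the `is_impossible` flag. As a minimal sketch, the same traversal can be run on a tiny in-memory stand-in for the dataset file. The toy content and the names `toy` and `index_questions` are assumptions made for illustration; only the nesting (`data` → `paragraphs` → `qas`) mirrors what the loading code reads.

```python
# Toy stand-in with the same nesting the loader walks (content is made up).
toy = {
    "data": [
        {
            "title": "Example",
            "paragraphs": [
                {
                    "context": "The letter was addressed to the council.",
                    "qas": [
                        {"question": "Who was the letter addressed to?",
                         "is_impossible": False},
                        {"question": "Who burned the letter?",
                         "is_impossible": True},
                    ],
                }
            ],
        }
    ]
}

def index_questions(data):
    """Collect contexts and questions plus (context_idx, question_idx) pairs."""
    contexts, questions, answerable, unanswerable = [], [], [], []
    for article in data:
        for paragraph in article["paragraphs"]:
            contexts.append(paragraph["context"])
            cid = len(contexts) - 1
            for qa in paragraph["qas"]:
                questions.append(qa["question"])
                pair = (cid, len(questions) - 1)
                # Route the pair by the answerability flag.
                (unanswerable if qa["is_impossible"] else answerable).append(pair)
    return contexts, questions, answerable, unanswerable

toy_contexts, toy_questions, toy_answerable, toy_unanswerable = index_questions(toy["data"])
print(toy_answerable, toy_unanswerable)  # [(0, 0)] [(0, 1)]
```

Both questions share context index 0, which is why every rule later receives the question text together with its paragraph.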

# Rules

In [163]:
q_words = ['what', 'when','where', 'who', 'how', 'why', 'which']
def rule_nsubj(q, _):
    q = nlp(q)
    for chunk in q.noun_chunks:
        if chunk.root.dep_ == 'nsubj' or chunk.root.dep_ == 'nsubjpass':
            if chunk[0].text.lower() in q_words: 
                if len(chunk) > 1:
                    if chunk[1].text.lower() == 'many':
                        if len(chunk[2:]):
                            return chunk[2:]
                        else:
                            return None
                    return chunk[1:]
                else:
                    return None
            else:
                return chunk

def rule_ents(q, _):
    q = nlp(q)
    ents = list(q.ents)
    if not ents:
        return None
    if len(ents) == 1:
        return ents[0]
    # Prefer entities that are not the object of a preposition.
    return_ents = []
    for ent in ents:
        if ent.root.dep_ != 'pobj':
            return_ents.append(ent)
    if len(return_ents) == 1:
        return return_ents[0]
    # Among the rest, return the first entity that is not a subject.
    for ent in return_ents:
        if ent.root.dep_ not in ('nsubj', 'nsubjpass'):
            return ent

def overlapping_spans(q, c):
    qt = tokenizer.tokenize(q)
    ct = tokenizer.tokenize(c)
    # Index every context n-gram of up to 9 word pieces.
    context_ngram_set = set()
    for i in range(len(ct)):
        for j in range(1, 10):
            context_ngram_set.add(tuple(ct[i:i+j]))
    output_set = set()
    skip = 0
    for i in range(len(qt)):
        if skip:
            skip -= 1
            continue
        longest = None
        for j in range(1, 10):
            if i + j > len(qt):
                break
            span = tuple(qt[i:i+j])
            if span in context_ngram_set:
                longest = span
            elif longest:
                break
        # Record the match even when it runs to the end of the question
        # or reaches the maximum n-gram length.
        if longest:
            output_set.add(longest)
            skip = len(longest) - 1
    return output_set

def rule_token_match(q, c):
    spans = overlapping_spans(q,c)
    kept_spans = []
    for span in spans:
        for token in span:
            if token not in nlp.vocab or not nlp.vocab[token].is_stop:
                kept_spans.append(span)
                break  # one non-stop-word token is enough to keep the span
    max_span = None
    for span in kept_spans:
        if max_span is None:
            max_span = span
            continue
        if len(span) > len(max_span):
            max_span = span
    return max_span
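
To make the span-matching idea easier to check by hand, here is a hedged, self-contained sketch of greedy longest-match against a context n-gram index. It uses whitespace tokens instead of BERT word pieces, and the name `longest_shared_spans` and its `max_n` parameter are hypothetical, not part of the notebook. This sketch also records a match that runs to the end of the question.

```python
def longest_shared_spans(q_tokens, c_tokens, max_n=9):
    """Greedily collect the longest question spans that also occur in the context."""
    # Index every context n-gram of length 1..max_n.
    ngrams = set()
    for i in range(len(c_tokens)):
        for n in range(1, max_n + 1):
            ngrams.add(tuple(c_tokens[i:i + n]))

    found = set()
    i = 0
    while i < len(q_tokens):
        longest = None
        for n in range(1, max_n + 1):
            if i + n > len(q_tokens):
                break
            span = tuple(q_tokens[i:i + n])
            if span in ngrams:
                longest = span
            elif longest:
                break  # the match stopped growing
        if longest:
            found.add(longest)
            i += len(longest)  # jump past the matched span
        else:
            i += 1
    return found

q_toks = "who was the letter addressed to ?".split()
c_toks = "the letter was addressed to the council .".split()
spans = longest_shared_spans(q_toks, c_toks)
print(sorted(spans))  # [('addressed', 'to'), ('the', 'letter'), ('was',)]
```

A stop-word filter applied afterwards would drop `('was',)` and leave the two content-bearing spans as candidates.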

In [154]:
def parse_all(idx_tuples, rules):
    total = len(idx_tuples)
    coverage = 0
    overlap = np.zeros((len(rules), len(rules)))
    for i, (cid, qid) in enumerate(idx_tuples):
        c = contexts[cid]
        q = questions[qid]
        overlapping_rules = set()
        for rule_i, rule in enumerate(rules):
            ent = rule(q,c)
            if ent:
                overlapping_rules.add(rule_i)
        for r in overlapping_rules:
            for rr in overlapping_rules:
                overlap[r,rr] += 1
        if len(overlapping_rules) > 0:
            coverage += 1
    print("Coverage: ", coverage / float(total))
    print("Overlap matrix: ")
    print(overlap)

def gen_input_and_target(ent):
    if isinstance(ent, tuple):  # already tokenized
        ent_tokenized = ent
    else:
        ent_tokenized = tokenizer.tokenize(ent.text)
    return ent_tokenized

def gen_dataset(idx_tuples, rules):
    data = {}
    for i, (cid, qid) in enumerate(idx_tuples):
        c = contexts[cid]
        q = questions[qid]
        for rule_i, rule in enumerate(rules):
            ent = rule(q,c)
            if ent:
                target = gen_input_and_target(ent)
                if cid not in data:
                    data[cid] = set()
                data[cid].add(tuple(target))
    return data

In [152]:
parse_all(answerable[:100], [rule_nsubj, rule_ents, rule_token_match])

Coverage:  1.0
Overlap matrix: 
[[ 79.  72.  79.]
 [ 72.  92.  92.]
 [ 79.  92. 100.]]
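
Reading the matrix: `parse_all` increments `overlap[r, rr]` whenever rules `r` and `rr` both fire on the same question, so the diagonal counts how often each rule fired on its own (79, 92, and 100 of the 100 questions here) and the off-diagonal entries count co-firing. The sketch below reproduces that bookkeeping on a small, made-up firing table. The array `fired` is hypothetical data, not taken from the run above.

```python
import numpy as np

# Hypothetical firing record: one row per question, one column per rule.
fired = np.array([
    [1, 1, 1],
    [0, 1, 1],
    [0, 0, 1],
    [1, 0, 1],
])

# Entry (r, rr) counts the questions on which rules r and rr both fired.
overlap = fired.T @ fired

# The diagonal gives per-rule counts; divide by the number of questions.
per_rule_coverage = np.diag(overlap) / len(fired)

# Total coverage: fraction of questions on which at least one rule fired.
total_coverage = fired.any(axis=1).mean()

print(overlap)
print(per_rule_coverage, total_coverage)
```

In this toy table the third rule fires on every question, so the total coverage is 1.0 even though the first two rules each cover only half of the questions.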


In [167]:
labels = gen_dataset(answerable, [rule_nsubj, rule_ents, rule_token_match])

In [165]:
import pickle

In [168]:
with open("entity_labels_v2.pkl", "wb") as f:
    pickle.dump(labels, f)

In [169]:
len(labels)

18880
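
Note that `labels` is keyed by context index, so this count is the number of contexts that received at least one label, not the number of questions. As a small sketch of the stored structure (a dict mapping a context index to a set of token tuples), the following round-trips a made-up example through `pickle`. The toy dictionary and the temporary file name are assumptions for illustration.

```python
import os
import pickle
import tempfile

# Made-up labels in the same shape: {context_index: {token tuples}}.
toy_labels = {
    0: {("the", "letter"), ("addressed", "to")},
    1: {("manchuria",)},
}

# Write to a temporary location so the real label file is not touched.
path = os.path.join(tempfile.mkdtemp(), "toy_labels.pkl")
with open(path, "wb") as f:
    pickle.dump(toy_labels, f)

with open(path, "rb") as f:
    restored = pickle.load(f)

print(len(restored), sorted(restored[0]))
# 2 [('addressed', 'to'), ('the', 'letter')]
```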

# Spot checks

In [16]:
import random

In [123]:
cid, qid = random.choice(answerable[:500])
c = contexts[cid]
q = questions[qid]
print(questions[qid])
print(rule_nsubj(q,c))
print(rule_ents(q,c))
print(rule_token_match(q,c))

Who was the letter addressed to?
None
None
('addressed', 'to')


In [91]:
for token in nlp(q):
    print(token, token.pos_)

What NOUN
was VERB
Angela PROPN
Merkel PROPN
serving VERB
as ADP
in ADP
relation NOUN
to ADP
the DET
letter NOUN
? PUNCT


In [173]:
for i, question in enumerate(questions):
    if "first to invade Manchuria" in question:
        print(i)


107295


In [177]:
contexts[15759]

"A series of international crises strained the League to its limits, the earliest being the invasion of Manchuria by Japan and the Abyssinian crisis of 1935/36 in which Italy invaded Abyssinia, one of the only free African nations at that time. The League tried to enforce economic sanctions upon Italy, but to no avail. The incident highlighted French and British weakness, exemplified by their reluctance to alienate Italy and lose her as their ally. The limited actions taken by the Western powers pushed Mussolini's Italy towards alliance with Hitler's Germany anyway. The Abyssinian war showed Hitler how weak the League was and encouraged the remilitarization of the Rhineland in flagrant disregard of the Treaty of Versailles. This was the first in a series of provocative acts culminating in the invasion of Poland in September 1939 and the beginning of the Second World War."

In [176]:
for a,b in answerable:
    if b == 107295:
        print(a)

15759


# Issues
* nsubj rule:
    * split possessives
    * omit question words (e.g., which song...)