In [None]:
import re
import sys
from multiprocessing import Pool
import spacy
from spacy.matcher import Matcher
from tqdm import tqdm
import nltk
import json
import string
import multiprocessing

In [None]:
# Load the pre-grounded CSQA training set, one JSON record per line.
with open("../input/csqa-with-subgraph/csqa/grounded/train.grounded.jsonl", "r") as grounded_file:
    res = list(map(json.loads, grounded_file))

In [None]:
# NOTE(review): pandas is not used in the visible cells; consider moving this import to the top cell or dropping it.
import pandas as pd
# Peek at the first grounded record.
res[0]

In [None]:
BLANK_STR = "___"


def replace_wh_word_with_blank(question_str: str) -> str:
    """Rewrite a question as a fill-in-the-blank (FITB) statement.

    Blank placement:
      * A wh-word (which/what/where/when/how/who/why) directly followed by "?" at
        the end of the text, or else the earliest wh-word followed by a space/comma
        with no later period (i.e. inside the last sentence), is replaced by
        ``BLANK_STR``. In that case the final "?" becomes ".".
      * Wh-words are matched as plain substrings, so e.g. "somewhat" also contains
        "what".
      * Otherwise a few fixed phrasings are handled explicitly.
      * Otherwise a blank is appended when the text does not end in "." or "?".
      * Finally, " this " / " this?" is replaced by a blank.

    Args:
        question_str: the question stem.

    Returns:
        The FITB statement. The last fallback may leave the text without any
        blank; ``get_fitb_from_question`` handles that case.
    """
    question_str = question_str.replace("What's", "What is")
    question_str = question_str.replace("whats", "what")
    question_str = question_str.replace("U.S.", "US")
    # Offsets below are measured on the lower-cased text and applied to question_str.
    lowered = question_str.lower()
    wh_word_offset_matches = []
    wh_words = ["which", "what", "where", "when", "how", "who", "why"]
    for wh in wh_words:
        # "people who ..." is a relative clause, not the question word.
        if wh == "who" and "people who" in question_str:
            continue
        # Some Turk-authored SciQ questions end with the wh-word, e.g.
        # "The passing of traits from parents to offspring is done through what?"
        m = re.search(wh + r"\?[^\.]*[\. ]*$", lowered)
        if m:
            wh_word_offset_matches = [(wh, m.start())]
            break
        # Otherwise, find the wh-word in the last sentence.
        m = re.search(wh + r"[ ,][^\.]*[\. ]*$", lowered)
        if m:
            wh_word_offset_matches.append((wh, m.start()))

    if wh_word_offset_matches:
        # Pick the earliest wh-word, e.g.
        # "Which is most likely needed when describing the change in position of an object?"
        wh_word_found, wh_word_start_offset = min(wh_word_offset_matches, key=lambda x: x[1])
        # Splice the blank in BEFORE stripping: the offsets were measured on the
        # unstripped text, so stripping first would shift them whenever the
        # question has leading whitespace.
        fitb_question = (question_str[:wh_word_start_offset] + BLANK_STR +
                         question_str[wh_word_start_offset + len(wh_word_found):])
        # Replace the final question mark with a period.
        fitb_question = re.sub(r"\?$", ".", fitb_question.strip())
        # Drop "of the following"/"of these": they make no sense without the answer
        # options, e.g. "Which of the following force ..." -> "___ force ...".
        final = fitb_question.replace(BLANK_STR + " of the following", BLANK_STR)
        return final.replace(BLANK_STR + " of these", BLANK_STR)

    if " them called?" in question_str:
        return question_str.replace(" them called?", " " + BLANK_STR + ".")
    if " meaning he was not?" in question_str:
        return question_str.replace(" meaning he was not?", " he was not " + BLANK_STR + ".")
    if " one of these?" in question_str:
        return question_str.replace(" one of these?", " " + BLANK_STR + ".")
    if re.match(r".*[^\.\?] *$", question_str):
        # No wh-word and no final period/question mark: append a blank, e.g.
        # "The gravitational force exerted by an object depends on its" -> "... its ___"
        return question_str + " " + BLANK_STR
    # Last resort: treat "this" before a space or "?" as the blank, as in Turk-authored
    # questions such as "Virtually every task performed by living organisms requires this?"
    return re.sub(r" this[ \?]", " ___ ", question_str)


# Backward-compatible alias for the previous (garbled) name of this function.
rDataFramewh_word_with_blank = replace_wh_word_with_blank

# Build a fill-in-the-blank (FITB) statement from a question, e.g.
# "George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?"
# -> "George wants to warm his hands quickly by rubbing them. ___ skin surface will produce the most heat."
def get_fitb_from_question(question_text: str) -> str:
    '''
    Return ``question_text`` rewritten as a FITB statement.

    The wh-word is replaced with a blank by ``replace_wh_word_with_blank``. If that
    result has no underscore on its first line, the trailing spaces, periods and
    question marks are removed from the stripped question and a blank is appended.
    '''
    candidate = replace_wh_word_with_blank(question_text)
    if re.match(".*_+.*", candidate):
        return candidate
    # No blank could be placed: trim the sentence ending and append one.
    trimmed = re.sub(r"[\.\? ]*$", "", question_text.strip())
    return trimmed + " " + BLANK_STR

def create_hypothesis(mdf_q: str, choice: str, ans_pos: bool) -> "str | tuple[str, tuple[int, int]]":
    '''
    Substitute an answer choice into the blank(s) of a fill-in-the-blank question.

    Args:
        mdf_q: the FITB question; every run of two or more underscores is a blank.
        choice: the answer-choice text to insert.
        ans_pos: when true, also return the character span of the inserted choice.

    Returns:
        The hypothesis string. When ``ans_pos`` is true, returns
        ``(hypothesis, (start, end))`` instead. ``start`` is the offset of the first
        blank in ``mdf_q``. ``end - start`` is the choice length, excluding a
        trailing ".", "?" or "!" when the blank ends the question.

    Raises:
        ValueError: if ``ans_pos`` is true and ``mdf_q`` contains no blank.
    '''
    # Capitalize at the start of a sentence (the blank opens the question or
    # follows a period); otherwise lower-case the choice.
    if mdf_q.startswith(BLANK_STR) or "." + BLANK_STR in mdf_q or ". " + BLANK_STR in mdf_q:
        # Slicing (not indexing) keeps an empty choice from raising IndexError.
        choice = choice[:1].upper() + choice[1:]
    else:
        choice = choice.lower()
    # A trailing period on the choice only belongs at the very end of the statement.
    if not mdf_q.endswith(BLANK_STR):
        choice = choice.rstrip(".")

    def _fill(text):
        # A function replacement inserts `text` literally. A string replacement
        # would treat backslashes or group references in answer text as a re template.
        return re.sub("__+", lambda _m: text, mdf_q)

    if not ans_pos:
        return _fill(choice)

    choice = choice.strip()
    m = re.search("__+", mdf_q)
    if m is None:
        raise ValueError(f"no blank found in question: {mdf_q!r}")
    start = m.start()
    # Leave trailing sentence punctuation out of the span when the blank ends the question.
    if mdf_q.endswith(BLANK_STR) and choice.endswith(('.', '?', '!')):
        length = len(choice) - 1
    else:
        length = len(choice)
    return _fill(choice), (start, start + length)

In [None]:
file_Path = "../input/csqa-with-subgraph/csqa/"
def create_output_dict(input_json: dict, statement: str, label: bool, ans_pos: bool, pos=None) -> dict:
    """Append one statement entry to ``input_json["statements"]``, creating the list if missing.

    The entry has the keys "label" and "statement". When ``ans_pos`` is truthy it
    also has "ans_pos", set to ``pos``. ``input_json`` is modified in place and
    also returned.
    """
    entry = {"label": label, "statement": statement}
    if ans_pos:
        entry["ans_pos"] = pos
    input_json.setdefault("statements", []).append(entry)
    return input_json
def convert_qajson(qa_file: str, op_file: str, ans_pos: bool = False):
    """Convert a CSQA-style JSON-lines file into statement JSON lines.

    For each question, every answer choice is substituted into the FITB form of the
    stem. Each resulting statement is appended to the record's "statements" list
    via ``create_output_dict``, labelled by whether the choice's label equals
    "answerKey". The augmented record is then written to ``op_file``, one JSON
    object per line. Blank input lines are skipped.

    NOTE(review): records without an "answerKey" are labelled as if "A" were correct.

    Args:
        qa_file: input JSON-lines path.
        op_file: output JSON-lines path (overwritten).
        ans_pos: also store the answer span of each statement under "ans_pos".
    """
    print("Start converting")
    # Count lines first so tqdm can show a total; `with` closes this extra handle.
    with open(qa_file, "r") as counter:
        nrow = sum(1 for _ in counter)
    with open(qa_file, "r") as f, open(op_file, "w") as op:
        for line in tqdm(f, total=nrow):
            if not line.strip():
                continue
            json_l = json.loads(line)
            # The FITB template depends only on the stem, so build it once per question.
            fitb = get_fitb_from_question(json_l['question']['stem'])
            answer_key = json_l.get("answerKey", "A")
            for choice in json_l["question"]["choices"]:
                pos = None
                if ans_pos:
                    statement, pos = create_hypothesis(fitb, choice["text"], ans_pos)
                else:
                    statement = create_hypothesis(fitb, choice["text"], ans_pos)
                create_output_dict(json_l, statement, choice["label"] == answer_key, ans_pos, pos)
            op.write(json.dumps(json_l))
            op.write("\n")
    print("Done")

In [None]:
# Convert the CSQA training split into statement form.
# NOTE(review): the output is named "test.json" although the input is the train split.
convert_qajson(
    qa_file="../input/csqa-with-subgraph/csqa/train_rand_split.jsonl",
    op_file="test.json",
)

In [None]:
# Read the converted statements back as raw JSON lines.
with open("./test.json") as converted:
    ls = converted.readlines()

In [None]:
# Inspect the last converted record (a raw JSON line, including its trailing newline).
ls[-1]

## Ground statements to ConceptNet concepts and extract subgraphs

In [None]:
# Concepts considered too generic; ground_mentioned_concepts skips them.
# NOTE(review): "-PRON-" is the spaCy v2 pronoun lemma; confirm it is still relevant for the spaCy version in use.
blacklist = {
    "-PRON-", "actually", "likely", "possibly", "want",
    "make", "my", "someone", "sometimes_people", "sometimes", "would", "want_to",
    "one", "something", "everybody", "somebody", "could", "could_be",
}

# English stopword list used by prune().
nltk.download('stopwords', quiet=True)
nltk_stopwords = nltk.corpus.stopwords.words('english')

# Lazily initialised state: set by ground() and, in worker processes, by ground_qa_pair().
CPNET_VOCAB = PATTERN_PATH = None
nlp = matcher = None

In [None]:
def load_matcher(nlp, pattern_path):
    """Build a spaCy ``Matcher`` with one rule per concept.

    ``pattern_path`` is a UTF-8 JSON object mapping each concept name to a value
    that is registered as that concept's single Matcher pattern.
    """
    with open(pattern_path, "r", encoding="utf8") as pattern_file:
        patterns_by_concept = json.load(pattern_file)
    concept_matcher = Matcher(nlp.vocab)
    for concept_name in patterns_by_concept:
        concept_matcher.add(concept_name, [patterns_by_concept[concept_name]])
    return concept_matcher
def load_cpnet_vocab(cpnet_vocab_path):
    '''
    Read the ConceptNet vocabulary file into a list, one entry per line.

    Each line is stripped and its underscores become spaces, so "ice_cream" is
    returned as "ice cream". Line order is preserved.
    '''
    with open(cpnet_vocab_path, "r", encoding="utf8") as vocab_file:
        return [entry.strip().replace("_", " ") for entry in vocab_file]
def lemmatize(nlp, concept):
    '''
    Return a one-element set holding the underscore-joined lemmas of ``concept``.

    Underscores in ``concept`` become spaces before it is run through ``nlp``.
    '''
    lemmas = (token.lemma_ for token in nlp(concept.replace("_", " ")))
    return {"_".join(lemmas)}
def match_mentioned_concepts(sents, answers, num_processes):
    '''
    Ground every (statement, answer) pair with ``ground_qa_pair`` in a process pool.

    Args:
        sents: statement strings.
        answers: answer strings, aligned one-to-one with ``sents``.
        num_processes: number of worker processes.

    Returns:
        A list with one ``ground_qa_pair`` result per pair, in input order.

    Raises:
        ValueError: if ``sents`` and ``answers`` differ in length. ``zip`` would
            otherwise silently drop the unmatched tail.
    '''
    if len(sents) != len(answers):
        raise ValueError(
            f"got {len(sents)} statements but {len(answers)} answers; they must be aligned one-to-one"
        )
    with Pool(num_processes) as p:
        # imap preserves input order; the list is fully consumed before the pool exits.
        return list(tqdm(p.imap(ground_qa_pair, zip(sents, answers)), total=len(sents)))
def ground_qa_pair(qa_pair):
    '''
    Ground one (statement, answer) pair to question and answer concepts.

    On first use in a process, builds the module-level spaCy pipeline and matcher.
    Question concepts are the statement's concepts minus the answer's concepts.
    If either group comes out empty, ``hard_ground`` is used as a fallback.

    Returns:
        {"sent": statement, "ans": answer, "qc": sorted list, "ac": sorted list}
    '''
    global nlp, matcher
    # Lazily build the pipeline and matcher (once per worker process).
    if nlp is None or matcher is None:
        nlp = spacy.load('en_core_web_sm', disable=['ner', 'parser', 'textcat'])
        nlp.add_pipe('sentencizer')
        matcher = load_matcher(nlp, PATTERN_PATH)
        print("done loading matcher")

    sentence, answer = qa_pair
    every_concept = ground_mentioned_concepts(nlp, matcher, sentence, answer)
    ans_concepts = ground_mentioned_concepts(nlp, matcher, answer)
    q_concepts = every_concept - ans_concepts

    # Fall back to plain vocabulary lookup when the matcher found nothing.
    if not q_concepts:
        q_concepts = hard_ground(nlp, sentence, CPNET_VOCAB)
    if not ans_concepts:
        ans_concepts = hard_ground(nlp, answer, CPNET_VOCAB)

    return {"sent": sentence, "ans": answer, "qc": sorted(q_concepts), "ac": sorted(ans_concepts)}
# Concept extraction: the core of the grounding step.
def ground_mentioned_concepts(nlp, matcher, s, ans=None):
    '''
    Collect the concepts mentioned in text ``s``.

    Args:
        nlp: the spaCy pipeline used for tokenizing and lemmatizing.
        matcher: a spaCy ``Matcher`` whose rule keys are concept names (see load_matcher).
        s: the text to ground; it is lower-cased first.
        ans: optional answer text. Matcher hits whose token span coincides with an
            occurrence of the whole answer phrase in ``s`` are skipped.

    Returns:
        A set of concept strings.
    '''
    s = s.lower()
    doc = nlp(s)
    matches = matcher(doc)
    ans_mentions = _answer_mentions(nlp, doc, ans)

    # Group candidate concepts by the surface text they matched.
    span_to_concepts = {}
    for match_id, start, end in matches:
        if (start, end) in ans_mentions:
            continue
        span = doc[start:end].text
        original_concept = nlp.vocab.strings[match_id]
        concepts = span_to_concepts.setdefault(span, set())
        concepts.add(original_concept)
        # Single-word concepts also contribute their lemma.
        if len(original_concept.split("_")) == 1:
            concepts.update(lemmatize(nlp, original_concept))

    mentioned_concepts = set()
    for span, concepts in span_to_concepts.items():
        # Shortest first, ties broken alphabetically. Set order of strings varies
        # with hash randomization, so a length-only sort made the [0:3] pick
        # differ between runs.
        concepts_sorted = sorted(concepts, key=lambda c: (len(c), c))
        shortest = concepts_sorted[0:3]

        for c in shortest:
            if c in blacklist:
                continue
            # lemmatize() returns a one-element set such as {"like_apple"}; prefer
            # that lemma form when it is also among the shortest candidates.
            intersect = lemmatize(nlp, c).intersection(shortest)
            if intersect:
                mentioned_concepts.add(next(iter(intersect)))
            else:
                mentioned_concepts.add(c)

        # A concept whose text equals the mention exactly is always kept.
        exact_match = {concept for concept in concepts_sorted if concept.replace("_", " ").lower() == span.lower()}
        assert len(exact_match) < 2
        mentioned_concepts.update(exact_match)

    return mentioned_concepts


def _answer_mentions(nlp, doc, ans):
    '''Return the (start, end) token spans in ``doc`` where the whole answer phrase occurs.'''
    if ans is None:
        return set()
    # ONE pattern made of every answer token in order. spaCy v3's
    # Matcher.add(key, patterns) takes a list of patterns, so passing one
    # single-token pattern per word would match each answer word on its own.
    pattern = [{'TEXT': token.text.lower()} for token in nlp(ans)]
    if not pattern:
        return set()
    ans_matcher = Matcher(nlp.vocab)
    ans_matcher.add(ans, [pattern])
    return {(start, end) for _, start, end in ans_matcher(doc)}
def hard_ground(nlp, sent, cpnet_vocab):
    '''
    Fallback grounding: look up token lemmas, and the whole re-joined text, in the vocabulary.

    Args:
        nlp: spaCy-style pipeline; ``nlp(text)`` must yield tokens with ``.text`` and ``.lemma_``.
        sent: text to ground; it is lower-cased first.
        cpnet_vocab: collection of concept strings (ground_qa_pair passes the
            space-separated list from load_cpnet_vocab).

    Returns:
        The set of matched strings. It may be empty, in which case a message is printed.
    '''
    doc = nlp(sent.lower())
    res = {t.lemma_ for t in doc if t.lemma_ in cpnet_vocab}
    sent = " ".join(t.text for t in doc)
    if sent in cpnet_vocab:
        res.add(sent)
    # Report with a plain check instead of a caught assert, so the message also
    # appears under `python -O`.
    if not res:
        print(f"for {sent}, concept not found in hard grounding.")
    return res
def prune(data, cpnet_vocab_path):
    '''
    Filter each item's "qc" and "ac" concept lists in place.

    For both lists:
      * A concept ending in "er" (or "e") is dropped when the same list also holds
        it without that suffix.
      * Concepts not present in the vocabulary file are dropped.

    In addition:
      * Question concepts containing any NLTK English stopword token are dropped.
      * Answer concepts are dropped only if every token is a stopword.

    NOTE(review): the vocabulary is compared in its raw underscore form, so
    multi-word concepts produced by hard_ground (space-joined) never survive.

    Returns:
        A new list holding the same (mutated) item dicts.
    '''
    # Reload the vocabulary in its original underscore form (load_cpnet_vocab
    # replaces "_" with " "). A set makes each lookup O(1) instead of a linear
    # scan over the whole vocabulary.
    with open(cpnet_vocab_path, "r", encoding="utf8") as fin:
        cpnet_vocab = {l.strip() for l in fin}
    stopwords = set(nltk_stopwords)

    def _is_suffixed_duplicate(c, concepts):
        # "Xer" / "Xe" is redundant when the stem "X" is also present.
        return (c[-2:] == "er" and c[:-2] in concepts) or (c[-1:] == "e" and c[:-1] in concepts)

    prune_data = []
    for item in tqdm(data):
        qc, ac = item["qc"], item["ac"]
        item["qc"] = [
            c for c in qc
            if not _is_suffixed_duplicate(c, qc)
            # remove all concepts having stopwords, including hard-grounded ones
            and not any(t in stopwords for t in c.split("_"))
            and c in cpnet_vocab
        ]
        item["ac"] = [
            c for c in ac
            if not _is_suffixed_duplicate(c, ac)
            and not all(t in stopwords for t in c.split("_"))
            and c in cpnet_vocab
        ]
        prune_data.append(item)
    return prune_data
def ground(statement_path, cpnet_vocab_path, pattern_path, output_path, num_processes=1):
    '''
    Ground every (statement, answer choice) pair of a statement file and save the result.

    Args:
        statement_path: JSON-lines file as written by convert_qajson. Each record
            needs "statements" and "question" -> "choices", aligned one-to-one.
        cpnet_vocab_path: ConceptNet vocabulary file, one concept per line.
        pattern_path: JSON file of Matcher patterns (see load_matcher).
        output_path: destination JSON-lines file, one grounding dict per pair.
        num_processes: worker processes used by match_mentioned_concepts.

    Raises:
        ValueError: if a record's statement count differs from its choice count.
    '''
    global PATTERN_PATH, CPNET_VOCAB
    # ground_qa_pair (run in the Pool workers) reads these module globals.
    # NOTE(review): they are only set on the first call, so later calls with other
    # paths reuse the first vocab/patterns. This also relies on the workers
    # inheriting module state (fork start method); confirm.
    if PATTERN_PATH is None:
        PATTERN_PATH = pattern_path
        CPNET_VOCAB = load_cpnet_vocab(cpnet_vocab_path)

    sents, answers = [], []
    with open(statement_path, "r") as fin:
        for line_no, line in enumerate(fin, start=1):
            # Iterated lines keep their "\n", so test the stripped text for blank lines.
            if not line.strip():
                continue
            # Example record:
            # {'answerKey': 'B', 'id': '...',
            #  'question': {'question_concept': 'magazines',
            #               'choices': [{'label': 'A', 'text': 'doctor'}, {'label': 'B', 'text': 'bookstore'}, ...],
            #               'stem': 'Where would you find magazines along side many other printed works?'},
            #  'statements': [{'label': False, 'statement': 'Doctor would you find magazines along side many other printed works.'}, ...]}
            j = json.loads(line)
            statements = j["statements"]
            choices = j["question"]["choices"]
            # Statements and choices are paired by position downstream.
            if len(statements) != len(choices):
                raise ValueError(
                    f"{statement_path}:{line_no}: {len(statements)} statements "
                    f"but {len(choices)} choices"
                )
            sents.extend(statement["statement"] for statement in statements)
            answers.extend(answer["text"] for answer in choices)

    res = match_mentioned_concepts(sents, answers, num_processes)
    res = prune(res, cpnet_vocab_path)

    with open(output_path, 'w') as fout:
        for dic in res:
            fout.write(json.dumps(dic) + '\n')

    print(f'grounded concepts saved to {output_path}')
    print()


In [None]:
# Ground the training statements with 4 worker processes; results go to a.json.
ground(
    statement_path="../input/csqa-with-subgraph/csqa/statement/train.statement.jsonl",
    cpnet_vocab_path="../input/cp-net-en/CPnet_en/concept.txt",
    pattern_path="../input/cp-net-en/CPnet_en/matcher_patterns.json",
    output_path="a.json",
    num_processes=4,
)