In [6]:
import spacy
from spacy.lang.en.stop_words import STOP_WORDS
from spacy.tokens import Token
def stop_words_getter(token):
    """Return True if the token, its lowercase form, or its lemma is a stop word."""
    if token.is_stop:
        return True
    return token.lower_ in STOP_WORDS or token.lemma_ in STOP_WORDS


# Register the getter as a custom token attribute named 'is_stop'.
# spaCy exposes custom extensions under the `token._` namespace
# (i.e. `token._.is_stop`); the built-in `token.is_stop` is a separate attribute.
Token.set_extension('is_stop', getter=stop_words_getter, force=True)
# Load the English pipeline with the parser and NER components disabled.
nlp = spacy.load("en_core_web_lg", disable=["parser","ner"])
def get_nonstop_words(x):
    """Return the set of lemmas of the tokens in text `x` that survive filtering.

    A token is dropped when it is a stop word, a currency symbol, a digit,
    punctuation, whitespace, number-like, or tagged as a proper noun.

    x: a string to run through the module-level `nlp` pipeline.
    Returns a set of lemma strings.
    """
    # NOTE(review): this reads the built-in `token.is_stop`, not the custom
    # extension registered with Token.set_extension (reachable as
    # `token._.is_stop`), so the lower_/lemma_ stop-word check is not applied
    # here. Confirm whether that is intended before changing it, since any
    # cached word sets were presumably built with the current behaviour.
    def _is_filtered(token):
        return (
            token.is_stop
            or token.is_currency
            or token.is_digit
            or token.is_punct
            or token.is_space
            or token.like_num
            or token.pos_ == "PROPN"
        )

    return {token.lemma_ for token in nlp(x) if not _is_filtered(token)}

In [79]:
import pickle
# Load the id -> knowledge-graph mapping (indexed by question id in the next cell).
# The context manager closes the file handle, which the bare
# pickle.load(open(...)) form left open.
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
with open("/dccstor/myu/experiments/knowledge_trie/eli5_openie_merge/id2kg.pickle", "rb") as f:
    all_tries = pickle.load(f)

In [96]:
# Inspect a single entry: list the keys stored under id "4lnk7x" that match "crash".
# The output below shows keys that all begin with "crash" -- presumably a
# prefix lookup on a trie; confirm against the trie library in use.
all_tries["4lnk7x"].keys("crash")

['crash odds antianxiety medications',
 'crash odds antianxiety meds',
 'crash odds antidepressants',
 'crash odds antihistamines',
 'crash odds 1.12',
 'crash odds 1.68',
 'crash odds lower',
 'crash odds meds',
 'crash odds opiates',
 'crash odds penicillin']

In [54]:
# Helpers for measuring word overlap between retrieved information
# (OpenIE KG vocabulary or the top-n passages) and the reference answers.
# For each question, the statistics of the best-matching answer are kept.

import json
import pickle

def compute_overlaps(all_info, all_answers, data_lines):
    '''
    For every example, find the answer with the highest word recall against
    the retrieved information and report its overlap statistics.

    all_info: a list of nonstop word sets, one per example
    all_answers: a list (one per example) of lists of nonstop word sets,
        one set per answer
    data_lines: a list of json records; the entries of record["output"] that
        contain an "answer" key must match all_answers[idx] in number

    Returns one row per example:
        [overlap count, recall, precision, answer size, info size]
    Answers with an empty word set are skipped; ties in recall go to the
    later answer. Raises AssertionError when the inputs are misaligned or
    when an example has no non-empty answer.
    '''
    assert len(all_answers) == len(all_info) == len(data_lines)
    overlaps = []
    for info_words, answer_sets, record in zip(all_info, all_answers, data_lines):
        answer_entries = [out for out in record["output"] if "answer" in out]

        assert len(answer_entries) == len(answer_sets)
        n_info = len(info_words)
        top_recall = 0.0
        chosen = None
        stats = [0, 0, 0, 0, 0]
        for entry, ans_words in zip(answer_entries, answer_sets):
            shared = info_words & ans_words
            n_ans = len(ans_words)
            if n_ans == 0:
                continue
            n_shared = len(shared)
            recall = n_shared / n_ans
            if recall < top_recall:
                continue
            # `>=` semantics: an equally good later answer replaces the earlier one.
            top_recall = recall
            chosen = entry
            precision = n_shared / n_info if n_info > 0 else 0
            stats = [n_shared, recall, precision, n_ans, n_info]
        assert chosen is not None
        overlaps.append(stats)
    return overlaps

def compute_overlap_gain(all_info1, all_info2, all_answers, data_lines):
    '''
    Count, per example, the answer words covered by source 2 but not by source 1.

    Only the first answer of each example (all_answers[idx][0]) is used.

    all_info1: a list of nonstop word sets (baseline source)
    all_info2: a list of nonstop word sets (source whose extra coverage is counted)
    all_answers: a list (one per example) of lists of nonstop word sets,
        one set per answer
    data_lines: a list of json records; the entries of record["output"] that
        contain an "answer" key must match all_answers[idx] in number

    Returns a list of ints, one per example.
    Raises AssertionError when the inputs are misaligned.
    '''
    # data_lines is part of the length check (as in compute_overlaps): the loop
    # iterates over it, so a shorter list would otherwise silently truncate
    # the result and a longer one would fail with an IndexError.
    assert len(all_answers) == len(all_info1) == len(all_info2) == len(data_lines)
    overlap_gains = []
    for idx, data in enumerate(data_lines):
        answers_data = [x for x in data["output"] if "answer" in x]

        assert len(answers_data) == len(all_answers[idx])

        words_ans = all_answers[idx][0]
        word_overlap1 = all_info1[idx] & words_ans
        word_overlap2 = all_info2[idx] & words_ans
        # Set difference: answer words found by source 2 that source 1 missed.
        overlap_gains.append(len(word_overlap2 - word_overlap1))
    return overlap_gains
 
def print_average_overlap(overlaps):
    '''
    Print the mean of each column of `overlaps`, one value per line.

    overlaps: a list of equal-length numeric rows; the number of columns is
        taken from the first row.

    Prints nothing when `overlaps` is empty (previously this raised an
    IndexError on overlaps[0]).
    '''
    if len(overlaps) == 0:
        return
    for i in range(len(overlaps[0])):
        print(sum([x[i] for x in overlaps])/len(overlaps))

def load_hops(kg_file):
    '''
    Read a JSON-lines file and return one nonstop word set per line.

    kg_file: path to a file whose lines are json objects with a "kg_vocab"
        list of strings; the strings are joined with spaces and passed
        through get_nonstop_words.

    Returns a list of sets, in file order.
    '''
    with open (kg_file, 'r') as handle:
        vocab_texts = [" ".join(json.loads(raw.strip())["kg_vocab"]) for raw in handle]
    return [get_nonstop_words(text) for text in vocab_texts]



In [27]:
# Path of the oracle-KG dev file (JSON lines) from which answers are loaded.
# The same path is opened for writing by the dataset-construction cell later
# in this notebook, so that cell must have been run before this file is read.
oraclekg_dev = "/dccstor/myu/data/kilt_eli5_dpr/eli5-dev-kilt-oraclekg.json"
def load_single_answer(fn):
    '''
    Read a JSON-lines file and return the nonstop words of each first answer.

    fn: path to a file whose lines are json objects; only
        record["output"][0]["answer"] is used.

    Returns a list with one single-element list per line, each holding the
    nonstop word set of that answer (the same nesting as a multi-answer list).
    '''
    with open(fn) as handle:
        first_answers = [json.loads(raw.strip())["output"][0]["answer"] for raw in handle]
    return [[get_nonstop_words(text)] for text in first_answers]

In [None]:
# One entry per line of the oracle-KG dev file: a single-element list holding
# the nonstop word set of that line's first answer (get_nonstop_words is run
# on every answer).
all_answers = load_single_answer(oraclekg_dev)


In [74]:
# Load the cached word sets. Context managers close the pickle files, which
# the bare pickle.load(open(...)) form left open.
# NOTE: pickle.load can execute arbitrary code; only load caches from a trusted source.
with open("/dccstor/myu/experiments/eli5_analysis/all_kgs_dev.pkl", "rb") as f:
    all_kgs = pickle.load(f)
all_kgs_hop1 = load_hops("/dccstor/myu/data/kilt_eli5_dpr/eli5-dev-kilt-dpr-kg-hop1.json")
all_kgs_hop2 = load_hops("/dccstor/myu/data/kilt_eli5_dpr/eli5-dev-kilt-dpr-kg-hop2.json")
all_kgs_hop3 = load_hops("/dccstor/myu/data/kilt_eli5_dpr/eli5-dev-kilt-dpr-kg-hop3.json")
# all_passages = pickle.load(open("/dccstor/myu/experiments/eli5_analysis/all_psg_dev.pkl", "rb"))
with open("/dccstor/myu/experiments/eli5_analysis/all_psg_dev_3.pkl", "rb") as f:
    all_passages_top3 = pickle.load(f)
# all_passages_top10 = pickle.load(open("/dccstor/myu/experiments/eli5_analysis/all_psg_dev_10.pkl", "rb"))

data_file = "/dccstor/myu/data/kilt_eli5_dpr/eli5-dev-kilt-oraclekg.json"

with open (data_file, 'r') as f:
    data_lines = [json.loads(line.strip()) for line in f]

# Each call prints five column means, in the row order produced by
# compute_overlaps: overlap count, recall, precision, answer size, info size.
# First the top-3 passages against the answers ...
psg_overlaps = compute_overlaps(all_passages_top3, all_answers, data_lines)
print_average_overlap(psg_overlaps)

# ... then the full KG vocabulary against the same answers.
kg_overlaps = compute_overlaps(all_kgs, all_answers, data_lines)
print_average_overlap(kg_overlaps)

9.280690112806901
0.13723187751949392
0.11400686376087445
81.06237558062375
83.83012607830126
16.196416721964166
0.22776100959069615
0.08555405115796175
81.06237558062375
194.6038487060385


In [72]:
# Spot-check one example: print the size and content of the word overlap
# between the answer and (1) the KG vocabulary, (2) the top-3 passages.
x = 7

answer_words = all_answers[x][0]
for info_words in (all_kgs[x], all_passages_top3[x]):
    shared_words = info_words & answer_words
    print(len(shared_words))
    print(shared_words)

15
{'date', 'different', 'app', 'people', 'research', 'tinder', 'right', 'device', 'casual', 'term', 'good', 'find', 'service', 'left', 'company'}
8
{'date', 'different', 'app', 'tinder', 'right', 'party', 'casual', 'service'}


In [77]:
# Average number of answer words that the second source covers and the first
# does not. The first pair measures what the top-3 passages add over the
# hop-3 KG; the second pair measures the reverse direction.
for baseline_words, candidate_words in (
    (all_kgs_hop3, all_passages_top3),
    (all_passages_top3, all_kgs_hop3),
):
    overlap_gain = compute_overlap_gain(baseline_words, candidate_words, all_answers, data_lines)
    print(sum(overlap_gain)/len(overlap_gain))

5.224950232249502
3.87193098871931


In [30]:
# construct dev dataset for passage-answer overlap
# top n passages; best answer
import json
import pickle
# Load the cached answer and passage word sets. Context managers close the
# pickle files, which the bare pickle.load(open(...)) form left open.
# NOTE: pickle.load can execute arbitrary code; only load caches from a trusted source.
with open("/dccstor/myu/experiments/eli5_analysis/all_ans.pkl", 'rb') as f:
    all_answers = pickle.load(f)
with open("/dccstor/myu/experiments/eli5_analysis/all_psg_dev.pkl", "rb") as f:
    all_passages = pickle.load(f)

data_file = "/dccstor/myu/data/kilt_eli5_dpr/eli5-dev-kilt-dpr.json"

with open (data_file, 'r') as f:
    data_lines = [json.loads(line.strip()) for line in f]

assert len(all_answers) == len(all_passages) == len(data_lines)
overlaps = []
# Write the output inside a context manager so the file is flushed and closed
# even if an assertion fails midway; it was previously never closed.
with open("/dccstor/myu/data/kilt_eli5_dpr/eli5-dev-kilt-oraclekg.json", 'w') as fw:
    for idx, data in enumerate(data_lines):
        answers_data = [x for x in data["output"] if "answer" in x]

        assert len(answers_data) == len(all_answers[idx])
        kg_data = []
        words_psg = all_passages[idx]
        max_count = 0
        best_answer = None
        best_overlap = [0,0,0,0,0]
        for i, words_ans in enumerate(all_answers[idx]):
            word_overlap = words_psg & words_ans
            if len(words_ans) == 0:
                # Answers with no nonstop words cannot be scored; skip them.
                continue
            # The best answer is chosen by raw overlap count (compute_overlaps
            # uses recall instead); ties go to the later answer because of `>=`.
            if len(word_overlap) >= max_count:
                max_count = len(word_overlap)
                best_answer = answers_data[i]
                kg_data = list(word_overlap)
                # Row layout: overlap count, recall, precision, answer size, passage size.
                # Precision is guarded against an empty passage set, as in compute_overlaps.
                best_overlap = [
                    len(word_overlap),
                    len(word_overlap)/len(words_ans),
                    len(word_overlap)/len(words_psg) if len(words_psg) > 0 else 0,
                    len(words_ans),
                    len(words_psg),
                ]
        assert best_answer is not None
        overlaps.append(best_overlap)
        # Keep only the best answer and store the shared words as the oracle KG vocabulary.
        data["output"] = [best_answer]
        data["kg_vocab"] = [kg_data]
        fw.write(json.dumps(data)+'\n')
print(len(overlaps))
print_average_overlap(overlaps)

1507
27.68546781685468
0.37868054415979324
0.06956693586894633
81.54147312541473
407.1453218314532


In [21]:
# Dump every per-question row collected in `overlaps`
# (columns: overlap count, recall, precision, answer size, passage size).
# NOTE: this rebinds the module-level name `x`, which another cell uses as an
# example index.
for x in overlaps:
    print(x)

[33, 0.4583333333333333, 0.062264150943396226, 72, 530]
[28, 0.509090909090909, 0.05737704918032787, 55, 488]
[11, 0.4230769230769231, 0.02466367713004484, 26, 446]
[9, 0.45, 0.017857142857142856, 20, 504]
[20, 0.40816326530612246, 0.055401662049861494, 49, 361]
[32, 0.43243243243243246, 0.0784313725490196, 74, 408]
[16, 0.37209302325581395, 0.04030226700251889, 43, 397]
[31, 0.40789473684210525, 0.06858407079646017, 76, 452]
[19, 0.27941176470588236, 0.0420353982300885, 68, 452]
[41, 0.2611464968152866, 0.08855291576673865, 157, 463]
[20, 0.36363636363636365, 0.05698005698005698, 55, 351]
[63, 0.23507462686567165, 0.12727272727272726, 268, 495]
[17, 0.29310344827586204, 0.05573770491803279, 58, 305]
[26, 0.4, 0.07449856733524356, 65, 349]
[28, 0.3146067415730337, 0.05415860735009671, 89, 517]
[30, 0.5769230769230769, 0.07159904534606205, 52, 419]
[15, 0.4166666666666667, 0.02830188679245283, 36, 530]
[35, 0.5223880597014925, 0.08215962441314555, 67, 426]
[13, 0.3170731707317073, 0.027