In [None]:
import json
import spacy
from spacy.lang.en.stop_words import STOP_WORDS
from spacy.tokens import Token
def stop_words_getter(token):
    """Return True if the token, its lowercase form, or its lemma is a stop word."""
    if token.is_stop:
        return True
    return token.lower_ in STOP_WORDS or token.lemma_ in STOP_WORDS


# Register the getter as a custom token extension named "is_stop";
# force=True lets this cell be re-run without a "already exists" error.
Token.set_extension("is_stop", getter=stop_words_getter, force=True)

# Large English pipeline; the parser and NER components are switched off.
nlp = spacy.load("en_core_web_lg", disable=["parser", "ner"])
def get_nonstop_words(x):
    """Collect the distinct lemmas of the content-bearing tokens in ``x``.

    The text is run through the module-level ``nlp`` pipeline. A token is
    skipped when it is flagged as a stop word, currency, digit, punctuation,
    whitespace or number-like, or when its part-of-speech tag is ``PROPN``.

    Args:
        x: Text to analyse.

    Returns:
        A set holding the lemma of every token that was kept.
    """
    # NOTE(review): this reads the built-in ``token.is_stop`` flag, not the custom
    # ``token._.is_stop`` extension registered in this cell -- confirm that is intended.
    kept_lemmas = set()
    for token in nlp(x):
        if (
            token.is_stop
            or token.is_currency
            or token.is_digit
            or token.is_punct
            or token.is_space
            or token.like_num
            or token.pos_ == "PROPN"
        ):
            continue
        kept_lemmas.add(token.lemma_)
    return kept_lemmas


Top 20 passages & Best Answer on dev

In [None]:
import pickle
from collections import Counter

# Dev split, one JSON object per line; each record carries "passages" and "output".
data_file = "/dccstor/myu/data/kilt_eli5_dpr/eli5-dev-kilt-dpr.json"
with open(data_file, 'r') as f:
    data_lines = [json.loads(line.strip()) for line in f]

# Raw texts are kept next to the word sets so the best answer can be printed below.
text_passages = []
text_answers = []
for data in data_lines:
    passages = [x["text"] for x in data["passages"]]
    answers = [x["answer"] for x in data["output"] if "answer" in x]
    text_passages.append(passages)
    text_answers.append(answers)

# Pre-computed word sets per example (a list of sets for passages and for answers,
# as required by the union / intersection below).
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
with open("/dccstor/myu/experiments/eli5_analysis/all_psg.pkl", 'rb') as f:
    all_passages = pickle.load(f)
with open("/dccstor/myu/experiments/eli5_analysis/all_ans.pkl", 'rb') as f:
    all_answers = pickle.load(f)

max_overlaps = []
for idx in range(len(all_passages)):
    # [#shared words, #shared / #answer words, index of the best answer ("" = none found)]
    max_overlap = [0, 0, ""]
    # Pool the words of every retrieved passage for this example.
    psg = set().union(*all_passages[idx])
    for j, ans in enumerate(all_answers[idx]):
        overlap = len(psg & ans)
        if overlap > max_overlap[0]:
            max_overlap = [overlap, overlap/len(ans), j]
    # When no answer shares a word with the passages the index stays "", which
    # cannot be used to index the answer list; print an empty answer instead.
    best_idx = max_overlap[2]
    best_text = text_answers[idx][best_idx] if best_idx != "" else ""
    print(f"#overlap: {max_overlap[0]}; #overlap ratio: {max_overlap[1]:.2f}\n- Answer: {best_text}")
    print('------')
    max_overlaps.append(max_overlap)

print("average #overlap", sum([x[0] for x in max_overlaps])/len(max_overlaps))
print("average (#overlap/#gold_answer) of best passage & best answer overlap:", sum([x[1] for x in max_overlaps])/len(max_overlaps))
# Histogram of the best overlap size across examples.
cnt = Counter([x[0] for x in max_overlaps])
print("#overlap freq")
for overlap, freq in sorted(cnt.items()):
    print(overlap, freq)


In [None]:
# Construct the train dataset for passage-answer overlap.
# passage: top n passages
# answer: all answers whose meta score is >= 3
# NOTE(review): this header used to say ">= 30" while the filter below uses 3 --
# confirm which threshold is intended.
import json
import pickle

# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
with open("/dccstor/myu/experiments/eli5_analysis/all_ans_train.pkl", 'rb') as f:
    all_answers = pickle.load(f)
with open("/dccstor/myu/experiments/eli5_analysis/all_psg_train.pkl", "rb") as f:
    all_passages = pickle.load(f)

data_file = "/dccstor/myu/data/kilt_eli5_dpr/eli5-train-kilt-dpr.json"

with open(data_file, 'r') as f:
    data_lines = [json.loads(line.strip()) for line in f]

assert len(all_answers) == len(all_passages) == len(data_lines)

# One entry per (example, answer):
# [#shared words, #shared / #answer words, #shared / #passage words]
overlaps = []
# The context manager flushes and closes the output even if an assertion fails
# part-way through; the bare open() used before never closed the file.
with open("/dccstor/myu/data/kilt_eli5_dpr/eli5-train-kilt-oraclekg.json", 'w') as fw:
    for idx, data in enumerate(data_lines):
        answers_data = [x for x in data["output"] if "answer" in x and x["meta"]["score"] >= 3]
        kg_data = []
        assert len(answers_data) == len(all_answers[idx])
        # Assumes all_passages[idx] is already a single set of words (it is
        # intersected directly below) -- TODO confirm against the pickle contents.
        words_psg = all_passages[idx]

        for words_ans in all_answers[idx]:
            word_overlap = words_psg & words_ans
            kg_data.append(list(word_overlap))
            n_shared = len(word_overlap)
            overlaps.append([
                n_shared,
                n_shared/len(words_ans) if len(words_ans) > 0 else 0,
                # Guarded like the answer ratio so an empty passage set cannot crash the run.
                n_shared/len(words_psg) if len(words_psg) > 0 else 0,
            ])
        # Keep only the filtered answers and attach the shared vocabulary per answer.
        data["output"] = answers_data
        data["kg_vocab"] = kg_data
        fw.write(json.dumps(data)+'\n')

In [None]:
from collections import Counter

# Number of overlap entries, then the mean of each of the three columns.
num_entries = len(overlaps)
print(num_entries)
mean_shared = sum(row[0] for row in overlaps) / num_entries
mean_answer_ratio = sum(row[1] for row in overlaps) / num_entries
mean_passage_ratio = sum(row[2] for row in overlaps) / num_entries
print(mean_shared, mean_answer_ratio, mean_passage_ratio)

# Frequency of each shared-word count, in ascending order of the count.
cnt = Counter(row[0] for row in overlaps)
for k, v in sorted(cnt.items()):
    print(k, v)

In [None]:
import json

# Parse every line of the train file as one JSON record and keep them all in memory.
with open("/dccstor/myu/data/kilt_eli5/eli5-train-kilt.json") as f:
    data = list(map(json.loads, f))


In [None]:
# Analysis on the train dataset for passage-answer overlap.
# 20 passages; best answer
import json
import pickle

# Number of top passages pooled per example. `n` was never assigned anywhere in
# this notebook (NameError on a fresh kernel); 20 follows the cell header above.
n = 20

# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
with open("/dccstor/myu/experiments/eli5_analysis/all_ans_train.pkl", 'rb') as f:
    all_answers = pickle.load(f)
with open("/dccstor/myu/experiments/eli5_analysis/all_psg_train.pkl", "rb") as f:
    all_passages = pickle.load(f)

data_file = "/dccstor/myu/data/kilt_eli5_dpr/eli5-train-kilt-dpr.json"

with open(data_file, 'r') as f:
    data_lines = [json.loads(line.strip()) for line in f]

assert len(all_answers) == len(all_passages) == len(data_lines)

# One entry per example, for its best answer:
# [#shared words, recall (#shared / #answer words), #shared / #passage words]
overlaps = []
# This cell only analyses; the variant that wrote
# eli5-train-kilt-oraclekg-single.json is intentionally not executed here.
for idx, data in enumerate(data_lines):
    answers_data = [x for x in data["output"] if "answer" in x and x["meta"]["score"] >= 3]

    assert len(answers_data) == len(all_answers[idx])
    kg_data = []
    # Pool the word sets of the first n passages ONCE per example. Doing this
    # inside the answer loop re-applied union(*...) to the already pooled set,
    # which splits every word into its characters from the second answer onwards.
    # Assumes all_passages[idx] is a sequence of word collections (it is sliced
    # here) -- TODO confirm against the pickle contents.
    words_psg = set().union(*all_passages[idx][:n])
    max_recall = 0.0
    best_answer = None
    best_overlap = [0,0,0]
    for i, words_ans in enumerate(all_answers[idx]):
        word_overlap = words_psg & words_ans
        if len(words_ans) == 0:
            # Surface answers that lost every token during preprocessing.
            print(answers_data[i]["answer"])
        recall = len(word_overlap)/len(words_ans) if len(words_ans) > 0 else 0
        # ">=" means ties are resolved in favour of the later answer.
        if recall >= max_recall:
            max_recall = recall
            best_answer = answers_data[i]
            kg_data = list(word_overlap)
            best_overlap = [
                len(word_overlap),
                recall,
                # Guarded so an example with no passage words cannot crash the run.
                len(word_overlap)/len(words_psg) if len(words_psg) > 0 else 0,
            ]

    overlaps.append(best_overlap)

In [None]:
from collections import Counter

# Number of overlap entries, then the column means (both ratios shown as percentages).
num_entries = len(overlaps)
print(num_entries)
mean_shared = sum(row[0] for row in overlaps) / num_entries
answer_ratio_pct = sum(row[1] for row in overlaps) / num_entries * 100
passage_ratio_pct = sum(row[2] for row in overlaps) / num_entries * 100
print(mean_shared, answer_ratio_pct, passage_ratio_pct)

# Frequency of each shared-word count, in ascending order of the count.
cnt = Counter(row[0] for row in overlaps)
for k, v in sorted(cnt.items()):
    print(k, v)

In [None]:
# Construct the dev dataset for passage-answer overlap.
# top n passages; best answer
import json
import pickle

# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
with open("/dccstor/myu/experiments/eli5_analysis/all_ans.pkl", 'rb') as f:
    all_answers = pickle.load(f)
with open("/dccstor/myu/experiments/eli5_analysis/all_psg_dev.pkl", "rb") as f:
    all_passages = pickle.load(f)

data_file = "/dccstor/myu/data/kilt_eli5_dpr/eli5-dev-kilt-dpr.json"

with open(data_file, 'r') as f:
    data_lines = [json.loads(line.strip()) for line in f]

assert len(all_answers) == len(all_passages) == len(data_lines)

# One entry per example, for its best answer:
# [#shared words, recall (#shared / #answer words), #shared / #passage words]
overlaps = []
# The context manager flushes and closes the output even if an assertion fails
# part-way through; the bare open() used before never closed the file.
with open("/dccstor/myu/data/kilt_eli5_dpr/eli5-dev-kilt-oraclekg.json", 'w') as fw:
    for idx, data in enumerate(data_lines):
        answers_data = [x for x in data["output"] if "answer" in x]

        assert len(answers_data) == len(all_answers[idx])
        kg_data = []
        # Assumes all_passages[idx] is already a single set of words (it is
        # intersected directly below) -- TODO confirm against the pickle contents.
        words_psg = all_passages[idx]
        max_recall = 0.0
        best_answer = None
        best_overlap = [0,0,0]
        for i, words_ans in enumerate(all_answers[idx]):
            # Answers with no remaining words cannot be scored; skip them.
            if len(words_ans) == 0:
                continue
            word_overlap = words_psg & words_ans
            recall = len(word_overlap)/len(words_ans)
            # ">=" means ties are resolved in favour of the later answer.
            if recall >= max_recall:
                max_recall = recall
                best_answer = answers_data[i]
                kg_data = list(word_overlap)
                best_overlap = [
                    len(word_overlap),
                    recall,
                    # Guarded so an example with no passage words cannot crash the run.
                    len(word_overlap)/len(words_psg) if len(words_psg) > 0 else 0,
                ]
        # Every example must have at least one non-empty answer.
        assert best_answer is not None
        overlaps.append(best_overlap)
        # Keep only the best answer and the vocabulary it shares with the passages.
        data["output"] = [best_answer]
        data["kg_vocab"] = [kg_data]
        fw.write(json.dumps(data)+'\n')
print(len(overlaps))
print(sum([x[0] for x in overlaps])/len(overlaps),
    sum([x[1] for x in overlaps])/len(overlaps)*100,
    sum([x[2] for x in overlaps])/len(overlaps)*100)

In [None]:
from collections import Counter

# NOTE(review): overlap_kg_ans is not defined in the visible part of this
# notebook -- it must be defined before this cell is run.
max_overlaps = overlap_kg_ans(
    "/dccstor/myu/data/kilt_eli5_dpr/eli5-dev-kilt-dpr-kg-hop3.json",
    "/dccstor/myu/data/kilt_eli5_dpr/eli5-dev-kilt-dpr.json",
)
num_examples = len(max_overlaps)
print("average #overlap", sum(row[0] for row in max_overlaps) / num_examples)
print(
    "average (#overlap/#gold_answer) of 3-hop kg & best answer overlap:",
    sum(row[1] for row in max_overlaps) / num_examples,
)
print(
    "average (#overlap/#kg) of 3-hop kg & best answer overlap:",
    sum(row[2] for row in max_overlaps) / num_examples,
)

# Frequency of each overlap size, in ascending order of the size.
cnt = Counter(row[0] for row in max_overlaps)
print("#overlap freq")
for overlap, freq in sorted(cnt.items()):
    print(overlap, freq)

In [None]:
from collections import Counter

# NOTE(review): overlap_kg_ans is not defined in the visible part of this
# notebook -- it must be defined before this cell is run.
kg_hop3_file = "/dccstor/myu/data/kilt_eli5_dpr/eli5-dev-kilt-dpr-kg-hop3.json"
dev_dpr_file = "/dccstor/myu/data/kilt_eli5_dpr/eli5-dev-kilt-dpr.json"
max_overlaps = overlap_kg_ans(kg_hop3_file, dev_dpr_file)

total_shared = sum(entry[0] for entry in max_overlaps)
print("average #overlap", total_shared / len(max_overlaps))
total_answer_ratio = sum(entry[1] for entry in max_overlaps)
print("average (#overlap/#gold_answer) of 3-hop kg & best answer overlap:", total_answer_ratio / len(max_overlaps))

# Frequency of each overlap size, in ascending order of the size.
cnt = Counter(entry[0] for entry in max_overlaps)
print("#overlap freq")
for overlap, freq in sorted(cnt.items()):
    print(overlap, freq)

In [None]:
import json
import pickle

# Word sets of the predictions from the eli5_fid_greedy_ctx3_0729 run.
with open("/dccstor/myu/experiments/eli5_fid_greedy_ctx3_0729/eval_predictions.json", 'r') as f:
    preds_ctx3 = [get_nonstop_words(x["prediction_text"]) for x in json.load(f)]

# Word sets of the predictions from the eli5_fid_kghop2_greedy_ctx3_0802 run.
with open("/dccstor/myu/experiments/eli5_fid_kghop2_greedy_ctx3_0802/eval_predictions.json", 'r') as f:
    preds_kghop2 = [get_nonstop_words(x["prediction_text"]) for x in json.load(f)]

# Both results used to be assigned to `preds`, so the first (expensive) spaCy pass
# was thrown away immediately. Each run now keeps its own name; `preds` still
# refers to the kg-hop-2 word sets, exactly as it did after the old cell ran.
preds = preds_kghop2


In [None]:
def overlap_oracle(all_answers, all_preds):
    """Measure the word overlap between each answer set and its prediction set.

    Args:
        all_answers: Sequence of word sets, one per example.
        all_preds: Sequence of word sets, indexed in parallel with ``all_answers``.

    Returns:
        A list with one ``[n_shared, n_shared / len(answer), n_shared / len(pred)]``
        entry per element of ``all_answers``. A ratio whose denominator set is
        empty is reported as 0 instead of raising ``ZeroDivisionError``.

    Raises:
        IndexError: If ``all_preds`` has fewer entries than ``all_answers``.
    """
    overlaps = []
    for idx, answer_words in enumerate(all_answers):
        pred_words = all_preds[idx]
        n_shared = len(answer_words & pred_words)
        overlaps.append([
            n_shared,
            n_shared / len(answer_words) if answer_words else 0,
            n_shared / len(pred_words) if pred_words else 0,
        ])
    return overlaps
def print_overlap(overlaps):
    """Print the mean of the first three columns of ``overlaps`` on one line.

    Args:
        overlaps: Non-empty sequence of rows with at least three numeric fields each.
    """
    row_count = len(overlaps)
    column_means = [
        sum(row[col] for row in overlaps) / row_count
        for col in range(3)
    ]
    print(*column_means)


In [None]:
# Parse and show only the first line of the id2kg file.
with open("/dccstor/myu/experiments/knowledge_trie/eli5_openie_merge/id2kg.json") as f:
    first_line = f.readline()
    data = json.loads(first_line)
    print(data)

In [None]:
# Size of whatever `data` currently holds (the previous cell assigns it from the
# first line of id2kg.json).
len(data)