In [37]:
import os
import json
import glob
import pickle
import redis
import numpy as np
from tqdm import tqdm
from collections import defaultdict

In [38]:
with open("/nas/luka-group/keming/dataset/CodRED/baseline/data/rawdata/dev_evi.json") as f:
    dev_evi = json.load(f)

In [39]:
redisd = redis.Redis(host='localhost', port=6379, decode_responses=True)

# Loading DPR utils

In [5]:
# Loading auxillary data
with open("../data/title2id.json") as f:
    title2id = json.load(f)
with open("../data/id2title.json") as f:
    id2title = json.load(f)

In [5]:
question_embedding = np.load("/nas/luka-group/DPR/outputs/question_tensor_dev_finetune.npy")

In [7]:
question_embedding.shape, len(dev_evi)

((3497, 768), 3497)

In [16]:
finetune_path = "/nas/luka-group/DPR/outputs/codred_finetune"
def load_passages(path):
    passage_data_dir = path
    passage_pattern = "codred_passages*"
    passage_files = glob.glob(os.path.join(passage_data_dir, passage_pattern))

    passage_arrays = defaultdict(list)
    for file_path in passage_files:
        with open(file_path, "rb") as f:
            data = pickle.load(f)
            pbar = tqdm(total=len(data))
            for idx, embedding in data:
                doc_id, passage_id = idx.split(":")[1].split("_")
                title = id2title[doc_id]
                passage_arrays[title].append((int(passage_id), embedding))
                pbar.update(1)
    for key, value in tqdm(passage_arrays.items()):
        sorted_array = sorted(value, key=lambda x: x[0])
        passage_arrays[key] = np.array([each[1] for each in sorted_array])
    return passage_arrays
passage_arrays_finetune = load_passages(finetune_path)

 99%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████  | 1025825/1038692 [00:01<00:00, 909897.73it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1038692/1038692 [00:07<00:00, 130092.92it/s][A

  7%|██████████▉                                                                                                                                                | 73712/1038692 [00:00<00:01, 737117.19it/s][A
 15%|██████████████████████▊                                                                                                                                   | 153458/1038692 [00:00<00:01, 772607.92it/s][A
 22%|██████████████████████████████████▏                                                                                                                       | 230719/10

 60%|█████████████████████████████████████████████████████████████████████████████████████████████▋                                                              | 154827/257881 [00:04<00:02, 48051.43it/s][A
 62%|█████████████████████████████████████████████████████████████████████████████████████████████████▏                                                          | 160710/257881 [00:04<00:01, 51162.55it/s][A
 64%|████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                       | 165961/257881 [00:05<00:01, 51546.99it/s][A
 66%|███████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                    | 171156/257881 [00:05<00:02, 32990.65it/s][A
 68%|██████████████████████████████████████████████████████████████████████████████████████████████████████████                                                  | 17532

In [31]:
passage_arrays = passage_arrays_finetune

# Augmented DPR

In [34]:
finetune_path = "/nas/luka-group/DPR/outputs/codred_finetune_ext"
def load_passages(path):
    passage_data_dir = path
    passage_pattern = "codred_passages*"
    passage_files = glob.glob(os.path.join(passage_data_dir, passage_pattern))

    passage_arrays = defaultdict(list)
    for file_path in passage_files:
        with open(file_path, "rb") as f:
            data = pickle.load(f)
            pbar = tqdm(total=len(data))
            for idx, embedding in data:
                doc_id, passage_id = idx.split(":")[1].split("_")
                title = id2title[doc_id]
                passage_arrays[title].append((int(passage_id), embedding))
                pbar.update(1)
    for key, value in tqdm(passage_arrays.items()):
        sorted_array = sorted(value, key=lambda x: x[0])
        passage_arrays[key] = np.array([each[1] for each in sorted_array])
    return passage_arrays
passage_arrays = load_passages(finetune_path)

 95%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉       | 1238540/1298363 [00:01<00:00, 763939.19it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1298363/1298363 [00:10<00:00, 129680.82it/s][A

  6%|██████████                                                                                                                                                 | 84349/1298365 [00:00<00:01, 843486.78it/s][A
 14%|████████████████████▊                                                                                                                                     | 175430/1298365 [00:00<00:01, 883084.15it/s][A
 21%|███████████████████████████████▌                                                                                                                          | 266339/12

257881

# Constraint utils

In [17]:
def build_reverse_idx(doch, doct):
    doch_title = doch['title']
    doct_title = doct['title']
    
    entity_reverse_idx = {}
    doc_reverse_idx = {}
    for entity in doch['entities']:
        if 'Q' in entity:
            entity_idx = 'Q' + str(entity['Q'])
            for span in entity['spans']:
                passage_idx = doch_title + "_" + str(span[2])
                if entity_idx not in entity_reverse_idx:
                    entity_reverse_idx[entity_idx] = set()
                entity_reverse_idx[entity_idx].add(passage_idx)
                if passage_idx not in doc_reverse_idx:
                    doc_reverse_idx[passage_idx] = set()
                doc_reverse_idx[passage_idx].add(entity_idx)
            
    for entity in doct['entities']:
        if 'Q' in entity:
            entity_idx = 'Q' + str(entity['Q'])
            for span in entity['spans']:
                passage_idx = doct_title + "_" + str(span[2])
                if entity_idx not in entity_reverse_idx:
                    entity_reverse_idx[entity_idx] = set()
                entity_reverse_idx[entity_idx].add(passage_idx)
                if passage_idx not in doc_reverse_idx:
                    doc_reverse_idx[passage_idx] = set()
                doc_reverse_idx[passage_idx].add(entity_idx)
    return entity_reverse_idx, doc_reverse_idx

In [18]:
def get_neighbor(passage_id, path_entities, entity_reverse_idx, doc_reverse_idx):
    if "_" not in passage_id:
        next_passages = entity_reverse_idx[passage_id]
        return set(zip(next_passages, [passage_id] * len(next_passages)))
    else:
        shared_entities = doc_reverse_idx[passage_id].difference(path_entities)
        output = []
        for entity in shared_entities:
            for next_passage in entity_reverse_idx[entity]:
                if next_passage != passage_id:
                    output.append((next_passage, entity))
        return set(output)

# Inference

In [20]:
sample = dev_evi[0]
sample

{'h': 'Fridley',
 't': 'Mississippi River',
 'r': 'P206',
 'doc_h': 'Medtronic',
 'doc_t': 'Interstate 94',
 'evis_h': [[3, 3]],
 'evis_t': [[3, 0]],
 'id': 259079,
 'key': 'Q985235#Q1497'}

In [52]:
def get_dpr_result(passage_list, query_array, passage_arrays):
    cand_passages = [each.split("_") for each in passage_list]
    cand_arrays = np.array([passage_arrays[title][int(idx)]for title, idx in cand_passages])
    scores = np.matmul(query_array, cand_arrays.T)
    result = sorted(zip(passage_list, scores), key=lambda x: -x[1])
    return result

In [88]:
def get_extend_span_from_pid(doc, pid, max_length=256):
    passage_mapping = doc['passage_mapping']

    start_idx = passage_mapping[pid-1] if pid != 0 else 0
    end_idx = passage_mapping[pid]

    length = end_idx - start_idx
    res_span = (max_length - length) // 2

    start_idx = max(0, start_idx - res_span)
    end_idx = min(passage_mapping[-1], end_idx + res_span)
    return start_idx, end_idx

In [68]:
def get_cover_sentences(doc, start_id, end_id):
    cover_sentences = []
    sentence_mapping = doc['sentence_mapping']
    for pid, passage in enumerate(sentence_mapping):
        for sid, sentence_end in enumerate(passage):
            sentence_begin = sentence_mapping[pid][sid-1] if sid != 0 else\
                                sentence_mapping[pid-1][-1] if pid != 0 else 0
            if start_id <= sentence_begin and end_id >= sentence_end:
                cover_sentences.append((pid, sid))
    return cover_sentences

In [106]:
def find_best_head_tail_passages(sample):
    doch_title = sample['doc_h']
    doct_title = sample['doc_t']
    doch = json.loads(redisd.get('codred-doc-'+doch_title))
    doct = json.loads(redisd.get('codred-doc-'+doct_title))
    entity_reverse_idx, doc_reverse_idx = build_reverse_idx(doch, doct)
    
    h, t = sample['key'].split("#")
    
    h_set = list(entity_reverse_idx[h])
    t_set = list(entity_reverse_idx[t])
    
    best_h_pid = int(get_dpr_result(h_set, query_array, passage_arrays)[0][0].split("_")[1])
    best_t_pid = int(get_dpr_result(t_set, query_array, passage_arrays)[0][0].split("_")[1])
    
    #best_h_pid = int(h_set[-1].split("_")[1])
    #best_t_pid = int(t_set[-1].split("_")[1])
    
    start_idx_h, end_idx_h = get_extend_span_from_pid(doch, best_h_pid)
    start_idx_t, end_idx_t = get_extend_span_from_pid(doct, best_t_pid)
    
    h_cover_sentences = get_cover_sentences(doch, start_idx_h, end_idx_h)
    t_cover_sentences = get_cover_sentences(doct, start_idx_t, end_idx_t)

    return h_cover_sentences, t_cover_sentences

In [110]:
path_recall = []
sentence_recall = []

evis_num = []
for sample in dev_evi:
    head_cover_sentences, tail_cover_sentences = find_best_head_tail_passages(sample)

    predict = [
        f"{sample['doc_h']}_{pid}_{sid}" for pid, sid in head_cover_sentences
    ] + [
        f"{sample['doc_t']}_{pid}_{sid}" for pid, sid in tail_cover_sentences
    ]
    predict = set(predict)
    
    evis = [
        f"{sample['doc_h']}_{pid}_{sid}" for pid, sid in sample['evis_h']
    ] + [
        f"{sample['doc_t']}_{pid}_{sid}" for pid, sid in sample['evis_t']
    ]
    evis = set(evis)
    
    evis_num.append(len(evis))

    if evis.issubset(predict):
        path_recall.append(1)
    else:
        path_recall.append(0)

    sentence_recall.append((len(evis.intersection(predict)), len(evis)))

In [111]:
path_recall_lq_3hop = [path_recall[i] for i in range(len(path_recall)) if evis_num[i] <= 3]
path_recall_ge_3hop = [path_recall[i] for i in range(len(path_recall)) if evis_num[i] > 3]

sentence_recall_lq_3hop = [sentence_recall[i] for i in range(len(sentence_recall)) if evis_num[i] <= 3]
sentence_recall_ge_3hop = [sentence_recall[i] for i in range(len(sentence_recall)) if evis_num[i] > 3]

print(sum(path_recall_lq_3hop)/len(path_recall_lq_3hop), sum(path_recall_ge_3hop)/len(path_recall_ge_3hop))
print(
    sum([each[0] for each in sentence_recall_lq_3hop])/sum([each[1] for each in sentence_recall_lq_3hop]),
    sum([each[0] for each in sentence_recall_ge_3hop])/sum([each[1] for each in sentence_recall_ge_3hop])
)
print(
    sum(path_recall)/len(path_recall),
    sum([each[0] for each in sentence_recall])/sum([each[1] for each in sentence_recall])
)

0.15634920634920635 0.06258381761287439
0.42983490566037735 0.36439895336309064
0.09636831569917072 0.3779445868424265
