In [1]:
import json
import torch
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict

In [2]:
model = torch.load('base_model_3.pt', map_location='cuda:0').eval()

In [3]:
class InferDataset(Dataset):
    def __init__(self, candidates, texts, questions):
        self.candidates = candidates
        self.texts = texts
        self.questions = questions

    def __len__(self):
        return len(self.candidates)

    def __getitem__(self, item):
        question, context, label = self.candidates[item]
        return {'question': self.questions[question], 'context': self.texts[context], 'labels': torch.FloatTensor([label]),
                'question_id': question, 'context_id': context}

In [4]:
def data_load(args):
    with open(args.data_path + '/paragraph_context.json', 'r', encoding='utf-8') as f:
        texts = json.load(f)

    with open(args.data_path + '/question_context.json', 'r', encoding='utf-8') as f:
        questions = json.load(f)

    with open(args.data_path + '/test_labels.json', 'r', encoding='utf-8') as f:
        labels_dummy = json.load(f)
        labels = defaultdict(str)

        for i in labels_dummy.keys():
            for qu in labels_dummy[i]:
                labels[qu] = i

    with open(args.data_path + '/bm25_okt_result.json', 'r', encoding='utf-8') as f:
        bm25_result = json.load(f)

    candidates = []
    for i in bm25_result.keys():
        check = False
        for j in bm25_result[i]:
            if labels[i] == j:
                label = 1
            else:
                label = 0
            candidates.append([i, j, label])

    infer_dataset = InferDataset(candidates, texts, questions)
    infer_loader = DataLoader(infer_dataset, batch_size=args.batch_size, shuffle=False)
    return infer_loader

In [5]:
import easydict
from tqdm import tqdm
args = easydict.EasyDict({
    'data_path': 'dataset',
    'batch_size': 32
})

In [6]:
infer_loader = data_load(args)

In [7]:
reranking_answer = defaultdict(list)
for feature in tqdm(infer_loader):
    question_id = feature['question_id']
    context_id = feature['context_id']
    value = list(model(feature).cpu().detach().numpy())
    for q, c, v in zip(question_id, context_id, value):
        reranking_answer[q].append([c, v])

100%|████████████████████████████████████████████████████████████████████████| 218865/218865 [6:51:12<00:00,  8.87it/s]


In [8]:
context_corpus = json.load(open('dataset/paragraph_context.json', 'r', encoding='utf8')) 
question_corpus = json.load(open('dataset/question_context.json', 'r', encoding='utf8'))
train_labels_json = json.load(open('dataset/train_labels.json', 'r', encoding='utf8'))
test_labels_json = json.load(open('dataset/test_labels.json', 'r', encoding='utf8'))

In [9]:
test_label_clean = {}
for test_label in test_labels_json.keys():
    for test in test_labels_json[test_label]:
        test_label_clean[test] = test_label

In [10]:
mrr = 0

for clean in test_label_clean.keys():
    try:
        mrr += 1 / [key for key, value in sorted(reranking_answer[clean], key=lambda x:x[1])].index(test_label_clean[clean]) + 1
    except:
        pass
    
print(mrr / len(test_label_clean))

0.9873783021507243
