In [1]:
import sys
from collections import defaultdict
sys.path.append('..')

import numpy as np
import pandas as pd
import re
from tqdm.auto import tqdm
from retrieval_dense import DenseRetrieval
from ms_retrieval import SparseRetrieval
from utils_qa import tokenize

from datasets import load_from_disk, load_dataset
import torch

# Dense Retrieval 및 데이터셋 불러오기

In [2]:
dataset = load_from_disk('/opt/ml/input/data/data/train_dataset')

dense_retriever = DenseRetrieval(p_path='thingsu/koDPR_context', q_path='thingsu/koDPR_question',
                           bert_path='kykim/bert-kor-base')
sparse_retriever = SparseRetrieval(tokenize)

dense_retriever.get_dense_embedding()
sparse_retriever.get_sparse_embedding()

Lengths of unique contexts : 56737
Lengths of unique contexts : 56737
Embedding pickle load.
Embedding pickle load.


In [3]:
df_dense = dense_retriever.retrieve(dataset['validation'],topk=100)
df_sparse = sparse_retriever.retrieve(dataset['validation'],topk=100)

HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))


[transform] done in 4.765 s
[query exhaustive search] done in 7.120 s


HBox(children=(FloatProgress(value=0.0, description='Dense retrieval: ', max=240.0, style=ProgressStyle(descri…


[query exhaustive search] done in 5.560 s


HBox(children=(FloatProgress(value=0.0, description='Sparse retrieval: ', max=240.0, style=ProgressStyle(descr…




# 두 retrieval 합치기

In [8]:
best_score = 0
best_k = 0
total_grade = 0
correct_number = 0
for k in tqdm(np.arange(0.5, 1.5, 0.01)):
    j = 1
    
    dict_context_list = []
    for idx in range(len(df_sparse)):
        dict_context = defaultdict(float)
        data = df_sparse.loc[idx]
        for context_id, score in zip(data['context_id'], data['scores']):
            dict_context[context_id] = j * score
        dict_context_list.append(dict_context)

    for idx, dict_context in enumerate(dict_context_list):
        data = df_dense.loc[idx]
        for context_id, score in zip(data['context_id'], data['scores']):
            dict_context[context_id] += k * score

    context_score_pair_list = []
    for dict_context in dict_context_list:
        tmp_list = list(dict_context.items())
        tmp_list.sort(key=lambda x : x[1], reverse=True)
        context_score_pair_list.append(tmp_list)

    topk_20_prob = 0
    for idx in range(len(df_sparse)):
        original_cxt = df_sparse.loc[idx]['original_context']
        original_cxt = re.sub(r'\\n','\n', original_cxt) 
        original_cxt = re.sub(r'( )+',' ', original_cxt) 
        tmp_id_list = [cxt_id[0] for cxt_id in context_score_pair_list[idx]]
        for grade, cxt_id in enumerate(tmp_id_list[:20]):
            compare_cxt = sparse_retriever.contexts[cxt_id]
            compare_cxt = re.sub(r'\\n','\n', compare_cxt) 
            compare_cxt = re.sub(r'( )+',' ', compare_cxt)
            if original_cxt == compare_cxt:
                topk_20_prob += 1
                total_grade += grade + 1
                correct_number += 1
                break

    score = topk_20_prob/len(df_sparse)
    if score > best_score:
        best_score = score
        best_k = k
    print(score, k, total_grade/correct_number)
print(best_score, best_k)

HBox(children=(FloatProgress(value=0.0), HTML(value='')))

0.925 0.5 2.1981981981981984
0.925 0.51 2.1936936936936937
0.925 0.52 2.1831831831831834
0.925 0.53 2.1734234234234235
0.925 0.54 2.163963963963964
0.925 0.55 2.156906906906907
0.925 0.56 2.14993564993565
0.925 0.5700000000000001 2.143581081081081
0.925 0.5800000000000001 2.1366366366366365
0.925 0.5900000000000001 2.130630630630631
0.925 0.6000000000000001 2.124897624897625
0.925 0.6100000000000001 2.12012012012012
0.925 0.6200000000000001 2.1146916146916146
0.925 0.6300000000000001 2.109073359073359
0.925 0.6400000000000001 2.103903903903904
0.925 0.6500000000000001 2.099380630630631
0.925 0.6600000000000001 2.0945945945945947
0.925 0.6700000000000002 2.0895895895895897
0.925 0.6800000000000002 2.0848743480322427
0.925 0.6900000000000002 2.0806306306306306
0.925 0.7000000000000002 2.0763620763620763
0.925 0.7100000000000002 2.0720720720720722
0.925 0.7200000000000002 2.0675675675675675
0.925 0.7300000000000002 2.063626126126126
0.925 0.7400000000000002 2.06
0.925 0.7500000000000002 2

## topk 30안에 passage를 가장 잘 가지고 오는 계수 k는 1.1쯔음임을 알 수 있다.