In [1]:
import os
import json
import pandas as pd
from tqdm import tqdm

from datasets import (
    Dataset,
    DatasetDict,
    Features,
    Value,
    Sequence,
    load_from_disk,
    load_metric
)
from transformers import AutoTokenizer
from retrieval import Retriever

In [2]:
# Training split: passages (`context` column) and their questions (`question` column).
datasets = load_from_disk("../data/train_dataset")
train_split = datasets['train']
queries = train_split['question']
documents = train_split['context']

In [3]:
data_path = "../data/"
context_path = "wikipedia_documents.json"

wiki_file = os.path.join(data_path, context_path)
with open(wiki_file, encoding="utf-8") as f:
    wiki = json.load(f)

# Unique passage texts from wikipedia_documents.json, kept in first-seen order
# (dict keys preserve insertion order, so this is an order-stable dedupe).
total_contexts = list(dict.fromkeys(entry["text"] for entry in wiki.values()))

In [4]:
# Top-k values swept for each retriever (each retriever uses its own list).
hybrid_topk, rerank_topk = [20], [500]
# Retriever names passed to Retriever(name=...).
retrievers = ['2s_rerank', 'hybridsearch']
result = []

In [5]:
# Tokenizer whose `tokenize` method is handed to each Retriever as its tokenize_fn.
bmtokenizer = AutoTokenizer.from_pretrained("HANTAEK/klue-roberta-large-korquad-v1-qa-finetuned")

# Evaluate on the first fifth of the training questions to keep the run short.
# Computed once with integer division (same value as the former int(len(queries)/5)).
n_eval = len(queries) // 5
if n_eval == 0:
    # Without this guard the accuracy division below raises ZeroDivisionError.
    raise ValueError(f"Need at least 5 training questions to evaluate, got {len(queries)}.")

data_dict = {'question': queries[:n_eval]}
query_dataset = Dataset.from_dict(data_dict)

# `top_k` stays a module-level name (ends up bound to the last retriever's list)
# for any later cell that reads it.
top_k = []
# total[i] is the list of accuracies for retrievers[i], one entry per top-k value.
total = []
for r_name in tqdm(retrievers, desc="Retrievers Progress"):
    retriever = Retriever(
        tokenize_fn=bmtokenizer.tokenize,
        data_path=data_path,
        context_path=context_path,
        name=r_name
    )
    top_k = rerank_topk if r_name == '2s_rerank' else hybrid_topk
    acc = []
    for stopk in tqdm(top_k, desc=f"Top-k Progress for {r_name}", leave=False):
        contexts = retriever.retrieve(query_dataset, topk=stopk)

        # Wrap the per-query loop with tqdm to show progress.
        # NOTE(review): only contexts[idx][0] is compared with the gold passage, so this
        # measures a first-position hit rate for a top-`stopk` retrieval, not "gold in
        # top-k". If `retrieve` returns all k passages per query, consider
        # `documents[idx] in contexts[idx]` — confirm against Retriever.retrieve.
        cnt = sum(
            1
            for idx in tqdm(range(n_eval), desc=f"Query Progress for {r_name}, Top-k={stopk}", leave=False)
            if documents[idx] == contexts[idx][0]
        )

        accuracy = cnt / n_eval
        print(accuracy)
        acc.append(accuracy)
    total.append(acc)

Retrievers Progress:   0%|          | 0/2 [00:00<?, ?it/s]

Lengths of unique contexts : 56737




Sparse retrieval:   0%|          | 0/790 [00:00<?, ?it/s]

Embedding pickle load.


In [None]:
import matplotlib.pyplot as plt

# Accuracies come from `total`, which the evaluation loop fills with one list per
# retriever (in `retrievers` order); `result` is never populated, so indexing it
# raised IndexError.
# Each retriever is swept over its own top-k list, so the bar positions and tick
# labels are rebuilt per retriever. The global `top_k` holds only the last
# retriever's list and would mislabel the other retriever's bars.
display_names = {'2s_rerank': 'Reranker', 'hybridsearch': 'Hybrid Search'}
bar_colors = {'2s_rerank': 'blue', 'hybridsearch': 'green'}

fig, ax = plt.subplots(figsize=(10, 6))
bar_width = 0.6
tick_positions, tick_labels = [], []
next_pos = 0
for r_name, accs in zip(retrievers, total):
    ks = rerank_topk if r_name == '2s_rerank' else hybrid_topk
    pairs = list(zip(ks, accs))  # (top-k value, accuracy) for this retriever
    xs = list(range(next_pos, next_pos + len(pairs)))
    ax.bar(
        xs,
        [accuracy for _, accuracy in pairs],
        width=bar_width,
        label=display_names.get(r_name, r_name),
        color=bar_colors.get(r_name),  # None falls back to matplotlib's colour cycle
    )
    tick_positions.extend(xs)
    tick_labels.extend(f"Top-{k}" for k, _ in pairs)
    next_pos += len(pairs)

ax.set_xticks(tick_positions)
ax.set_xticklabels(tick_labels)
ax.set_xlabel('Top-k')
ax.set_ylabel('Accuracy')
ax.set_ylim(0, 1)  # accuracies are fractions of n_eval queries
ax.set_title('Reranker vs Hybrid Search Performance')
ax.legend()
ax.grid(axis='y', linestyle='--', alpha=0.7)
plt.show()