In [3]:
from rank_bm25 import BM25Plus
from transformers import AutoTokenizer
from typing import List


class BM25PlusRetriever:
    """Lexical BM25+ retriever over a fixed list of documents.

    Documents and queries are converted to token-id sequences with a Hugging Face
    tokenizer, and those id sequences are what BM25+ scores. Only the tokenizer is
    used (no model forward pass), so the tokenizer's "sequence length is longer than
    the specified maximum" warning does not affect retrieval here.
    """

    def __init__(
        self,
        context,
        model_name_or_path = "Salesforce/codet5p-110m-embedding",
    ) -> None:
        """Tokenize every document and build the BM25+ index.

        Args:
            context: Sequence of document strings; `retrieve` returns items of it.
            model_name_or_path: Tokenizer to load with `AutoTokenizer.from_pretrained`.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.context = context

        tokenized_contexts = [self._tokenize(doc) for doc in self.context]
        self.bm25plus = BM25Plus(tokenized_contexts)

    def _tokenize(self, text) -> List[int]:
        """Return the token ids of `text` without its first and last id.

        The [1:-1] slice is meant to drop the special tokens the tokenizer wraps
        around the sequence. NOTE(review): assumes exactly one leading and one
        trailing special token; confirm if a different tokenizer is used.
        """
        return self.tokenizer(text)["input_ids"][1:-1]

    def retrieve(
        self,
        query,
        topk = 5,
    ) -> List:
        """Return up to `topk` documents with the highest BM25+ score for `query`.

        Documents are ordered by descending score. Among equal scores, the larger
        document index comes first. If `topk` exceeds the number of documents, all
        documents are returned. A `topk` <= 0 returns an empty list.
        """
        if topk <= 0:
            return []

        tokenized_query = self._tokenize(query)
        scores = self.bm25plus.get_scores(tokenized_query)

        # Reverse-sorting (score, index) pairs gives descending score and, for ties,
        # descending index.
        ranked = sorted(
            ((val, idx) for idx, val in enumerate(scores)),
            reverse=True,
        )

        # Iterate the slice itself: the corpus may hold fewer than `topk` documents,
        # so indexing with range(topk) would raise IndexError.
        return [self.context[idx] for _, idx in ranked[:topk]]

In [4]:
from datasets import load_dataset

DATASET_NAME = 'KonradSzafer/stackoverflow_python_preprocessed'
dataset = load_dataset(DATASET_NAME)

# Retrieval corpus: the answer text of every training row, in dataset order.
train_split = dataset['train']
answer_texts = [record['answer'] for record in train_split]

In [5]:
# Build the BM25+ index with every answer as one document.
bm25_retrieval = BM25PlusRetriever(context=answer_texts)

Token indices sequence length is longer than the specified maximum sequence length for this model (678 > 512). Running this sequence through the model will result in indexing errors


In [6]:
# Build a de-duplicated query list and, for each query, the list of all its answers,
# so that each question is evaluated only once.
# example
# unique_questions = [query_1, query_2, query_3, ...]
# all_answers = [[answer_1_for_query_1, answer_2_for_query_1, ...], [answer_1_for_query_2, answer_2_for_query_2, ...], ...]
#
# Rows are grouped by their question body, wherever they appear in the split (not
# only when duplicates are adjacent). The query text is "title\nquestion" taken
# from the first row of each group. Dicts preserve insertion order, so groups keep
# first-appearance order.

query_by_question = {}    # question body -> query text of its first occurrence
answers_by_question = {}  # question body -> all answers for it, in dataset order

for data in dataset['train']:
    question = data['question']
    if question not in answers_by_question:
        query_by_question[question] = data['title'] + '\n' + question
        answers_by_question[question] = []
    answers_by_question[question].append(data['answer'])

unique_questions = list(query_by_question.values())
all_answers = list(answers_by_question.values())
# First answer of each group (kept for compatibility with the previous version).
unique_answers = [answers[0] for answers in all_answers]

# Each query must line up with its own answer list.
assert len(unique_questions) == len(all_answers) == len(unique_answers)

In [8]:
# A query counts as retrieved (True) if ANY of its ground-truth answers appears,
# as an exact string match, among the top-k retrieved documents.
from tqdm.auto import tqdm

TOP_K = 5  # documents retrieved per query (same as BM25PlusRetriever.retrieve's default)

# zip() would silently stop at the shorter list; fail loudly on a mismatch instead.
assert len(unique_questions) == len(all_answers)

is_retrieved = []
for query, answers in tqdm(zip(unique_questions, all_answers), total=len(unique_questions)):
    retrieved_documents = bm25_retrieval.retrieve(query, topk=TOP_K)
    is_retrieved.append(any(answer in retrieved_documents for answer in answers))

962it [01:12, 13.35it/s]


In [9]:
# Fraction of queries with at least one ground-truth answer in the top-k results.
# is_retrieved holds bools, so sum() counts the hits directly. An empty list
# reports nan instead of raising ZeroDivisionError.
true_ratio = sum(is_retrieved) / len(is_retrieved) if is_retrieved else float('nan')
print(f'True Ratio:{true_ratio*100:0.2f}%')

True Ratio:63.10%
