In [1]:
import re

from konlpy.tag import Komoran


class KonlpyTokenize:
    def __init__(self):
        self.noun_collector = Komoran()

    def __pre_regex(self, context):
        re_compile = re.compile("[^a-zA-Z0-9ㄱ-ㅣ가-힣\s\(\)\[\]?!.,\@\*\{\}\-\_\=\+]")
        context = re.sub("\s", " ", context)
        re_context = re_compile.sub(" ", context)
        return re_context

    def __pre_devide(self, context):
        if len(context) < 3000:
            return [context]
        else:
            return re.split(".\s|.\\n", context)

    def __context_tokenize(self, context):
        tokenized_list = []
        context = context.strip()
        if context == "":
            return tokenized_list
        noum_tokenize = self.noun_collector.pos(context)
        for word, tag in noum_tokenize:
            if tag == "NNG" or tag == "NNP":
                tokenized_list.append(word)
        return tokenized_list

    def tokenize_fn(self, context):
        context = self.__pre_regex(context)
        tokenized_list = []
        context_list = self.__pre_devide(context)
        for context in context_list:
            tokenized_list.extend(self.__context_tokenize(context))
        return tokenized_list


In [2]:
import os
import json
import time

# import faiss
import pickle
import numpy as np
import pandas as pd

from rank_bm25 import BM25Okapi
from tqdm.auto import tqdm
from contextlib import contextmanager
from typing import List, Tuple, NoReturn, Any, Optional, Union
from datasets import Dataset


@contextmanager
def timer(name):
    t0 = time.time()
    yield
    print(f"[{name}] done in {time.time() - t0:.3f} s")


class BM25SparseRetrieval:
    def __init__(
        self,
        tokenize_fn,
        data_path: Optional[str] = "/opt/ml/data/",
        context_path: Optional[str] = "wikipedia_documents.json",
    ) -> NoReturn:
        self.data_path = data_path
        with open(os.path.join(data_path, context_path), "r", encoding="utf-8") as f:
            wiki = json.load(f)

        self.contexts = list(
            dict.fromkeys([v["text"] for v in wiki.values()])
        )  # set 은 매번 순서가 바뀌므로
        print(f"Lengths of unique contexts : {len(self.contexts)}")
        self.ids = list(range(len(self.contexts)))

        self.tokenize_fn = tokenize_fn
        self.bm25 = None
        self.indexer = None  # build_faiss()로 생성합니다.

    def get_sparse_embedding(self, pickle_name="bm25api.bin") -> NoReturn:

        pickle_name = pickle_name
        emd_path = os.path.join(self.data_path, pickle_name)

        if os.path.isfile(emd_path):
            with open(emd_path, "rb") as file:
                self.bm25 = pickle.load(file)
            print("Embedding pickle load.")
        else:
            print("Build passage embedding")
            tokenized_contexts = list(map(self.tokenize_fn, tqdm(self.contexts)))
            self.bm25 = BM25Okapi(tokenized_contexts)
            with open(emd_path, "wb") as file:
                pickle.dump(self.bm25, file)
            print("Embedding pickle saved.")

    def retrieve(
        self, query_or_dataset: Union[str, Dataset], topk: Optional[int] = 1
    ) -> Union[Tuple[List, List], pd.DataFrame]:

        assert self.bm25 is not None, "get_sparse_embedding() 메소드를 먼저 수행해줘야합니다."

        if isinstance(query_or_dataset, str):
            doc_scores, doc_indices = self.get_relevant_doc(query_or_dataset, k=topk)
            print("[Search query]\n", query_or_dataset, "\n")

            for i in range(topk):
                print(f"Top-{i+1} passage with score {doc_scores[i]:4f}")
                print(self.contexts[doc_indices[i]])

            return (doc_scores, [self.contexts[doc_indices[i]] for i in range(topk)])

        elif isinstance(query_or_dataset, Dataset):

            # Retrieve한 Passage를 pd.DataFrame으로 반환합니다.
            total = []
            doc_scores, doc_indices = [], []
            with timer("query exhaustive search"):
                for question in tqdm(query_or_dataset["question"]):
                    doc_score, doc_indice = self.get_relevant_doc(question, k=topk)
                    doc_scores.append(doc_score)
                    doc_indices.append(doc_indice)
            for idx, example in enumerate(
                tqdm(query_or_dataset, desc="Sparse retrieval: ")
            ):
                tmp = {
                    # Query와 해당 id를 반환합니다.
                    "question": example["question"],
                    "id": example["id"],
                    # Retrieve한 Passage의 id, context를 반환합니다.
                    "context_id": doc_indices[idx],
                    "context": " ".join(doc_indices[idx]),
                }
                if "context" in example.keys() and "answers" in example.keys():
                    # validation 데이터를 사용하면 ground_truth context와 answer도 반환합니다.
                    tmp["original_context"] = example["context"]
                    tmp["answers"] = example["answers"]
                total.append(tmp)

            cqas = pd.DataFrame(total)
            return cqas

    def get_relevant_doc(self, query: str, k: Optional[int] = 1) -> Tuple[List, List]:

        """
        Arguments:
            query (str):
                하나의 Query를 받습니다.
            k (Optional[int]): 1
                상위 몇 개의 Passage를 반환할지 정합니다.
        Note:
            vocab 에 없는 이상한 단어로 query 하는 경우 assertion 발생 (예) 뙣뙇?
        """
        tokenized_query = self.tokenize_fn(query)
        raw_doc_scores = self.bm25.get_scores(tokenized_query)

        doc_scores_index_desc = np.argsort(-raw_doc_scores)
        doc_scores = raw_doc_scores[doc_scores_index_desc]

        doc_list = self.bm25.get_top_n(tokenized_query, self.contexts, k)

        return doc_scores[:k], doc_list


In [4]:
tokenize_fn = KonlpyTokenize().tokenize_fn

In [5]:
retrieval = BM25SparseRetrieval(tokenize_fn)

Lengths of unique contexts : 56737


In [6]:
retrieval.get_sparse_embedding()

Build passage embedding


  0%|          | 0/56737 [00:00<?, ?it/s]

Embedding pickle saved.


In [7]:
from datasets import load_from_disk

datasets = load_from_disk("/opt/ml/data/new_train_dataset")

In [9]:
k = 10
score = 0
wrong_set = []
for data in tqdm(datasets['train']):
  query = data['question']
  scores, retrieved_examples = retrieval.get_relevant_doc(query, k)
  if data['context'] in retrieved_examples:
    score += 1
  else: 
    wrong_set.append(data)

  0%|          | 0/3351 [00:00<?, ?it/s]

In [10]:
score/len(datasets['train']) * 100

84.48224410623695

In [11]:
k = 10
score2 = 0
wrong_set2 = []
for data in tqdm(datasets['validation']):
  query = data['question']
  scores, retrieved_examples = retrieval.get_relevant_doc(query, k)
  if data['context'] in retrieved_examples:
    score2 += 1
  else: 
    wrong_set2.append(data)

  0%|          | 0/841 [00:00<?, ?it/s]

In [12]:
score2 / len(datasets['validation']) * 100

83.82877526753865