In [2]:
from datasets import concatenate_datasets, load_from_disk
from src.retriever.retrieval.sparse_retrieval import SparseRetrieval
from src.retriever.embedding.flag_embedding import DenseRetrieval
from transformers import AutoTokenizer
from tqdm import tqdm
from src.retriever.score.ranking import check_original_in_context, calculate_reverse_rank_score, calculate_linear_score
org_dataset = load_from_disk('./data/train_dataset')
print(org_dataset)
full_ds = concatenate_datasets(
        [
            org_dataset["train"].flatten_indices(),
            org_dataset["validation"].flatten_indices(),
        ]
    )  # train dev 를 합친 4192 개 질문에 대해 모두 테스트
print("*" * 40, "query dataset", "*" * 40)
print(full_ds)

DatasetDict({
    train: Dataset({
        features: ['title', 'context', 'question', 'id', 'answers', 'document_id', '__index_level_0__'],
        num_rows: 3952
    })
    validation: Dataset({
        features: ['title', 'context', 'question', 'id', 'answers', 'document_id', '__index_level_0__'],
        num_rows: 240
    })
})
**************************************** query dataset ****************************************
Dataset({
    features: ['title', 'context', 'question', 'id', 'answers', 'document_id', '__index_level_0__'],
    num_rows: 4192
})


In [8]:
from transformers import AutoTokenizer
from typing import List, Tuple
import matplotlib.pyplot as plt
from konlpy.tag import Okt, Kkma, Komoran, Hannanum, Mecab

def okt_tokenizer(text):
    okt = Okt()
    return okt.morphs(text)

def kkma_tokenizer(text):
    kkma = Kkma()
    return kkma.morphs(text)

def komoran_tokenizer(text):
    komoran = Komoran()
    return komoran.morphs(text)

def hannanum_tokenizer(text):
    hannanum = Hannanum()
    return hannanum.morphs(text)
def mecab_tokenizer(text):
    mecab = Mecab()
    return mecab.morphs(text)

def compare_tokenizers(tokenizers: List[Tuple[str, AutoTokenizer]],  full_ds):
    results = {}
    
    for name, tokenizer in tokenizers:
        print(f"Processing {name}...")
        retriever = SparseRetrieval(
            tokenize_fn=tokenizer,#.tokenize,
            data_path="./data/",
            context_path="wikipedia_documents.json",
            mode = "bm25",
            max_feature=1000000,
            ngram_range=(1,2),
            #tokenized_docs = tokenized_docs,
        )
        # Update tokenizer in retriever
        # Get sparse embedding and retrieve
        retriever.get_sparse_embedding()
        df = retriever.retrieve(full_ds, topk=10)
        
        # Get scores
        scores = retriever.get_score(df)
        results[name] = scores
    
    return results

def plot_results(results):
    plt.figure(figsize=(12, 6))
    for name, scores in results.items():
        plt.plot(scores, label=name)
    plt.xlabel('Query Index')
    plt.ylabel('Score')
    plt.title('Tokenizer Comparison for Retrieval')
    plt.legend()
    plt.show()

# Define tokenizers to compare
tokenizers = [
    #("KLUE RoBERTa Large", AutoTokenizer.from_pretrained("klue/roberta-large", use_fast=False)),
    #("KLUE RoBERTa Base", AutoTokenizer.from_pretrained("klue/roberta-base", use_fast=False)),
    #("KLUE RoBERTa Small", AutoTokenizer.from_pretrained("klue/roberta-small", use_fast=False)),
    #("KoBERT", AutoTokenizer.from_pretrained("skt/kobert-base-v1", use_fast=False).tokenize),
    ("KoELECTRA", AutoTokenizer.from_pretrained("monologg/koelectra-base-v3-discriminator", use_fast=False).tokenize),
    #("KoGPT2", AutoTokenizer.from_pretrained("skt/kogpt2-base-v2", use_fast=False).tokenize),
    ("Mecab", mecab_tokenizer),
    ("Okt", okt_tokenizer),
    ("Kkma", kkma_tokenizer),
    ("Komoran", komoran_tokenizer),
    ("Hannanum", hannanum_tokenizer),
]
results = compare_tokenizers(tokenizers, full_ds)

Processing KoELECTRA...
Lengths of unique contexts : 56737
Building bm25 embedding...
Start Initializing...


Tokenizing...: 100%|██████████| 56737/56737 [01:13<00:00, 776.02it/s]


Generating n-grams and building vocabulary...


Generating n-grams: 100%|██████████| 56737/56737 [00:23<00:00, 2401.67it/s]


Current mode : bm25
End Initialization


Calculating BM25: 100%|██████████| 57/57 [10:46<00:00, 11.34s/it]


Finish BM25 Embedding
bm25 embedding shape: (56737, 1000000)
(4192, 1000000) (56737, 1000000)
result shape : (4192, 56737)
[query exhaustive search] done in 38.100 s


Sparse retrieval: 100%|██████████| 4192/4192 [00:00<00:00, 12807.72it/s]


correct retrieval 0.8752385496183206
reverse rank retrieval 0.5425256927956311
linear retrieval 0.8051138530275673
Processing Mecab...
Lengths of unique contexts : 56737
Building bm25 embedding...
Start Initializing...


Tokenizing...:   0%|          | 0/56737 [00:00<?, ?it/s]


Exception: Install MeCab in order to use it: http://konlpy.org/en/latest/install/

In [None]:
# 위에서 선언한거 가져오기 Retriever
tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased", use_fast=False,)
retriever = SparseRetrieval(
    tokenize_fn=tokenizer.tokenize,
    data_path="./data/",
    context_path="wikipedia_documents.json",
    mode = "bm25",
    max_feature=1000000,
    ngram_range=(1,3),
    #tokenized_docs = tokenized_docs,
)
retriever.get_sparse_embedding()
df = retriever.retrieve(full_ds, topk=10)
retriever.get_score(df)



In [None]:
import json
import os
with open(os.path.join(data_path, context_path), "r", encoding="utf-8") as f:
            wiki = json.load(f)
docs = list(dict.fromkeys([v["text"] for v in wiki.values()]))
print(f"Lengths of unique contexts : {len(docs)}")
tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased", use_fast=False,)
tokenized_docs = [tokenizer(doc) for doc in tqdm(docs, desc="Tokenizing...")]

In [None]:
retriever = SparseRetrieval(
            tokenize_fn=tokenizer.tokenize,
            data_path="./data/",
            context_path="wikipedia_documents.json",
            mode = "bm25",
            max_feature=200000,
            ngram_range=(1,2),
            tokenized_docs = tokenized_docs,
            b=0.25,
            k1=1.1,
        )
retriever.get_sparse_embedding()
df = retriever.retrieve(full_ds, topk=10)
retriever.get_score(df)