In [7]:
import logging
import os
import sys
from utils.preprocess import prepare_datasets_with_setting
from typing import List, Callable, NoReturn, NewType, Any
import dataclasses
from datasets import load_metric, load_from_disk, Dataset, DatasetDict
from transformers import AutoConfig, AutoModelForQuestionAnswering, AutoTokenizer
import torch
from transformers import (
    DataCollatorWithPadding,
    EvalPrediction,
    HfArgumentParser,
    TrainingArguments,
    set_seed,
)

from tokenizers import Tokenizer
from tokenizers.models import WordPiece

from utils.trainer_qa import QuestionAnsweringTrainer

from arguments import (
    ModelArguments,
    DataTrainingArguments,
)

In [10]:
# Load the dataset that was previously saved to disk; the repr shown in the
# next cell reports a `train` and a `validation` split.
datasets = load_from_disk("../data/train_dataset")

In [23]:
# Display the loaded splits with their columns and row counts.
datasets

DatasetDict({
    train: Dataset({
        features: ['__index_level_0__', 'answers', 'context', 'document_id', 'id', 'question', 'title'],
        num_rows: 3952
    })
    validation: Dataset({
        features: ['__index_level_0__', 'answers', 'context', 'document_id', 'id', 'question', 'title'],
        num_rows: 240
    })
})

In [16]:
from retrieval import *

In [17]:
# Read the Wikipedia dump; `json` is expected to be in scope via
# `from retrieval import *`.
with open("../data/wikipedia_documents.json", "r", encoding="utf-8") as wiki_file:
    wiki = json.load(wiki_file)

# Collect each entry's "text", dropping exact duplicates while keeping
# first-seen order (dict keys are unique and insertion-ordered).
unique_texts = dict.fromkeys(entry["text"] for entry in wiki.values())
contexts = list(unique_texts)

In [19]:
# Hub checkpoint whose tokenizer is used below to tokenize passages and
# questions for BM25.
model_checkpoint = "klue/bert-base"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [20]:
# Alias for the tokenizer's `tokenize` method; later cells call it on both the
# corpus passages and the validation questions.
tokenize_fn = tokenizer.tokenize

In [22]:

# Tokenize every passage and build a BM25 index with MyBm25's default
# parameters.
#
# `tokenize_fn` is called with the text only, exactly as the cached-index cell
# and the query loop further down call it. The encoding options that used to be
# passed here (padding="max_length", truncation=True, return_tensors='pt') are
# meant for `tokenizer(...)` encoding calls, not for producing a token list for
# BM25, and they made this index inconsistent with how queries are tokenized.
p_embs = []  # never filled in the visible cells; kept so the name stays defined
tokenized_corpus = [tokenize_fn(p) for p in tqdm(contexts)]
bm25 = MyBm25(tokenized_corpus)


  0%|          | 0/56737 [00:00<?, ?it/s]

In [31]:
# Hyperparameters handed to MyBm25 when the index is (re)built below.
# NOTE(review): presumably the usual BM25 k1 / b / epsilon settings — confirm
# against the MyBm25 implementation in `retrieval`.
k1=1.5
b=0.75
epsilon=0.25

In [33]:
# Directory in which the pickled BM25 index is looked up / stored.
data_path = "../"

In [35]:
# Reuse a pickled BM25 index when one exists; otherwise build it from
# `contexts` and cache it for the next run.
# NOTE(review): the cache file name does not include k1 / b / epsilon, so
# changing those values does not trigger a rebuild while the file exists.
# NOTE(review): pickle.load can execute arbitrary code — only load a cache file
# you created yourself.
bm25_name = "bm25.bin"
bm25_path = os.path.join(data_path, bm25_name)
if not os.path.isfile(bm25_path):
    print("Building bm25... It may take 1 minute and 30 seconds...")
    # The corpus is tokenized up front. Per the original author's note, the
    # BM25 class runs a pool internally and tokenizing inside it causes
    # unexpected results.
    tokenized_corpus = [tokenize_fn(passage) for passage in contexts]
    bm25 = MyBm25(tokenized_corpus, k1=k1, b=b, epsilon=epsilon)
    with open(bm25_path, "wb") as cache_file:
        pickle.dump(bm25, cache_file)
    print("bm25 pickle saved.")
else:
    with open(bm25_path, "rb") as cache_file:
        bm25 = pickle.load(cache_file)
    print("Embedding bm25 pickle load.")

Embedding bm25 pickle load.


In [36]:
def get_top_n(bm25, query, documents, n=10):
    """Return the scores and indices of the `n` best-scoring documents.

    Args:
        bm25: Index object exposing `corpus_size` and `get_scores(query)`.
        query: Query passed unchanged to `bm25.get_scores`.
        documents: Collection the index was built from; only its length is
            checked against `bm25.corpus_size`.
        n: Number of top results to keep (default 10).

    Returns:
        A `(doc_score, top_n_idx)` pair: the scores of the selected documents
        and their indices, both ordered from highest to lowest score.

    Raises:
        AssertionError: If `len(documents)` differs from `bm25.corpus_size`.
    """
    assert bm25.corpus_size == len(documents), "The documents given don't match the index corpus!"

    all_scores = bm25.get_scores(query)

    # argsort is ascending, so reverse it before keeping the first n entries.
    best_first = np.argsort(all_scores)[::-1]
    top_n_idx = best_first[:n]

    return all_scores[top_n_idx], top_n_idx

In [72]:
# Score every validation question against the BM25 index; `doc_scores` ends up
# holding one `bm25.get_scores` result per question, in dataset order.
validation_questions = datasets['validation']['question']
doc_scores = []
for question in tqdm(validation_questions):
    doc_scores.append(bm25.get_scores(tokenize_fn(question)))
print("done!")


  0%|          | 0/240 [00:00<?, ?it/s]

done!


In [73]:
# Stack the per-question score vectors into one 2-D tensor (one row per
# question) and rank the corpus for each question, best score first.
doc_scores = torch.tensor(np.array(doc_scores))
# No .squeeze() here: with a single question it would drop the question axis
# and break the per-row indexing below.
ranks = torch.argsort(doc_scores, dim=1, descending=True)
k = 20

# For every question, collect the text of its k highest-ranked passages.
# .tolist() turns the indices into plain ints for indexing `contexts`.
top_k_indices = ranks[:, :k].tolist()
context_list = [[contexts[doc_idx] for doc_idx in row] for row in top_k_indices]

# Top-k retrieval accuracy: share of questions whose gold context is among the
# retrieved passages. The gold column is read once instead of on every
# iteration.
gold_contexts = datasets['validation']['context']
correct = sum(
    gold in retrieved for gold, retrieved in zip(gold_contexts, context_list)
)
print(correct/len(context_list))

0.9208333333333333


In [74]:
# Build the retrieval-augmented validation split: original answers and
# questions, with 'context' replaced by the list of retrieved passages for each
# question (`context_list`).
validation_split = datasets['validation']
validdata = Dataset.from_dict(
    {
        'answers': validation_split['answers'],
        'context': context_list,
        'question': validation_split['question'],
    }
)


In [75]:
# Display the original splits again (they are not modified by the cells above).
datasets

DatasetDict({
    train: Dataset({
        features: ['__index_level_0__', 'answers', 'context', 'document_id', 'id', 'question', 'title'],
        num_rows: 3952
    })
    validation: Dataset({
        features: ['__index_level_0__', 'answers', 'context', 'document_id', 'id', 'question', 'title'],
        num_rows: 240
    })
})

In [81]:
# Bundle the train and validation splits into a single DatasetDict.
# NOTE(review): `traindata` is not assigned in any cell visible in this
# notebook (only `validdata` is built above). It presumably came from a
# since-removed cell or from `from retrieval import *` — confirm, and restore
# the cell that builds it, otherwise this line fails on a fresh kernel.
topkdata = DatasetDict({"train":traindata,"validation":validdata})

In [82]:
# Persist the assembled top-k dataset to disk.
topkdata.save_to_disk("../data/topk_dataset/")