In [9]:
from bm25 import BM25

from transformers import AutoTokenizer
from datasets import load_from_disk

import numpy as np

import time
import json
from contextlib import contextmanager

In [10]:
with open('../../data/wikipedia_documents.json', 'r', encoding='utf-8') as f:
    wiki = json.load(f)

wiki_contents = list(dict.fromkeys([v['text'] for v in wiki.values()]))

In [11]:
print('wikipedia data length:', len(wiki.keys()))

print('wikipedia set data length:', len(wiki_contents))

wikipedia data length: 60613
wikipedia set data length: 56737


In [12]:
org_dataset = load_from_disk('../../data/train_dataset')

org_dataset

DatasetDict({
    train: Dataset({
        features: ['__index_level_0__', 'answers', 'context', 'document_id', 'id', 'question', 'title'],
        num_rows: 3952
    })
    validation: Dataset({
        features: ['__index_level_0__', 'answers', 'context', 'document_id', 'id', 'question', 'title'],
        num_rows: 240
    })
})

In [13]:
MODEL_NAME = 'bert-base-multilingual-cased'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)

bm25 = BM25(
    b=0.5,
    k1=2.0,
    tokenizer=tokenizer,
    sparse_embedding_path='/opt/ml/input/data/sparse_embedding.bin',
    tfidfv_path='/opt/ml/input/data/tfidv.bin')

passage embedding data found.
tfidfvectorizer data found.


In [14]:
bm25.fit()

In [15]:
@contextmanager
def timer(name):
    t0 = time.time()
    yield
    print(f"[{name}] done in {time.time() - t0:.3f} s")

In [16]:
for i in range(1, 21):
    with timer(f'TOP K: {i}'):
        TOPK = i
        doc_scores = []
        doc_indices = []
        for i in range(len(org_dataset['validation'])):
            result = bm25.transform(org_dataset['validation']['question'][i])
            sorted_result = np.argsort(result)[::-1]
            doc_scores.append(result[sorted_result].tolist()[:TOPK])
            doc_indices.append(sorted_result.tolist()[:TOPK])

        correct = 0
        for idx, doc_indice in enumerate(doc_indices):
            for jdx, indice in enumerate(doc_indice):
                if org_dataset['validation']['context'][idx] == wiki_contents[indice]:
                    correct += 1
        
        print(f"Total Validation Score: {correct/len(org_dataset['validation'])*100}%")

Total Validation Score: 26.666666666666668%
[TOP K: 1] done in 172.276 s
Total Validation Score: 38.75%
[TOP K: 2] done in 172.159 s
Total Validation Score: 44.166666666666664%
[TOP K: 3] done in 166.244 s
Total Validation Score: 47.083333333333336%
[TOP K: 4] done in 163.462 s
Total Validation Score: 51.24999999999999%
[TOP K: 5] done in 164.845 s
Total Validation Score: 53.75%
[TOP K: 6] done in 164.505 s
Total Validation Score: 56.666666666666664%
[TOP K: 7] done in 165.733 s
Total Validation Score: 59.166666666666664%
[TOP K: 8] done in 166.908 s
Total Validation Score: 61.25000000000001%
[TOP K: 9] done in 173.472 s
Total Validation Score: 63.33333333333333%
[TOP K: 10] done in 173.595 s
Total Validation Score: 63.74999999999999%
[TOP K: 11] done in 169.838 s
Total Validation Score: 64.58333333333334%
[TOP K: 12] done in 171.741 s
Total Validation Score: 65.41666666666667%
[TOP K: 13] done in 170.330 s
Total Validation Score: 65.83333333333333%
[TOP K: 14] done in 174.678 s
Total 