In [72]:
from input.code.retrieval import SparseRetrieval

from transformers import AutoTokenizer
from datasets import load_from_disk

from sklearn.feature_extraction.text import TfidfVectorizer
import numpy as np

import os
import time
import json
import pickle
from contextlib import contextmanager

In [2]:
with open('../../input/data/wikipedia_documents.json', 'r', encoding='utf-8') as f:
    wiki = json.load(f)

wiki_contents = list(dict.fromkeys([v['text'] for v in wiki.values()]))

In [5]:
print('wikipedia data length:', len(wiki.keys()))

print('wikipedia set data length:', len(wiki_contents))

wikipedia data length: 60613
wikipedia set data length: 56737


In [11]:
MODEL_NAME = 'bert-base-multilingual-cased'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)

sparse_embedding_path = ''
tfidfv_path = ''

tfidfv = TfidfVectorizer(
    tokenizer=tokenizer.tokenize,
    ngram_range=(1, 2),
    max_features=50000,
)

if os.path.isfile(sparse_embedding_path) and os.path.isfile(tfidfv_path):
    with open(sparse_embedding_path, "rb") as file:
        p_embedding = pickle.load(file)
    with open(tfidfv_path, "rb") as file:
        tfidfv = pickle.load(file)
    print("Embedding pickle load.")
else:
    print("Build passage embedding")
    p_embedding = tfidfv.fit_transform(wiki_contents)
    print(p_embedding.shape)
    with open(sparse_embedding_path, "wb") as file:
        pickle.dump(p_embedding, file)
    with open(tfidfv_path, "wb") as file:
        pickle.dump(tfidfv, file)
    print("Embedding pickle saved.")

In [16]:
wiki_tfidf = tfidfv.fit_transform(wiki_contents)

In [17]:
print('wiki TF-IDF shape:', wiki_tfidf.shape)

wiki TF-IDF shape: (56737, 50000)


In [25]:
org_dataset = load_from_disk('../../input/data/train_dataset')

validation_query_tfidf = tfidfv.transform(org_dataset['validation']['question'])

In [35]:
result = validation_query_tfidf * wiki_tfidf.T

In [38]:
if not isinstance(result, np.ndarray):
    result = result.toarray()

In [39]:
result.shape

(240, 56737)

In [73]:
@contextmanager
def timer(name):
    t0 = time.time()
    yield
    print(f"[{name}] done in {time.time() - t0:.3f} s")

In [74]:
for i in range(1, 21):
    with timer(f'TOP K: {i}'):
        TOPK = i
        doc_scores = []
        doc_indices = []
        for i in range(result.shape[0]):
            sorted_result = np.argsort(result[i, :])[::-1]
            doc_scores.append(result[i, :][sorted_result].tolist()[:TOPK])
            doc_indices.append(sorted_result.tolist()[:TOPK])


        correct = 0
        for idx, doc_indice in enumerate(doc_indices):
            # print('QUERY:', org_dataset['validation']['question'][idx])
            # print('ORG PASSAGE\n', org_dataset['validation']['context'][idx])
            # for jdx, indice in enumerate(doc_indice):
            #     print(f'PREDICT {indice}\tSCORE {doc_scores[idx][jdx]}\n', wiki_contents[indice])
            # print('*'*40)
            
            # print('QUERY:', org_dataset['validation']['question'][idx], end=' ')
            for jdx, indice in enumerate(doc_indice):
                if org_dataset['validation']['context'][idx] == wiki_contents[indice]:
                    correct += 1
                    # print('YES!!')
                # else:
                #     print('NO')


        print(f"Total Validation Score: {correct/len(org_dataset['validation'])*100}%")

Total Validation Score: 30.0%
[TOP K: 1] done in 2.341 s
Total Validation Score: 42.5%
[TOP K: 2] done in 2.777 s
Total Validation Score: 48.75%
[TOP K: 3] done in 3.209 s
Total Validation Score: 52.916666666666664%
[TOP K: 4] done in 3.657 s
Total Validation Score: 56.666666666666664%
[TOP K: 5] done in 4.091 s
Total Validation Score: 59.166666666666664%
[TOP K: 6] done in 4.541 s
Total Validation Score: 60.416666666666664%
[TOP K: 7] done in 4.985 s
Total Validation Score: 63.33333333333333%
[TOP K: 8] done in 5.422 s
Total Validation Score: 64.16666666666667%
[TOP K: 9] done in 5.986 s
Total Validation Score: 65.83333333333333%
[TOP K: 10] done in 6.435 s
Total Validation Score: 66.66666666666666%
[TOP K: 11] done in 6.933 s
Total Validation Score: 67.5%
[TOP K: 12] done in 7.420 s
Total Validation Score: 69.58333333333333%
[TOP K: 13] done in 7.850 s
Total Validation Score: 70.0%
[TOP K: 14] done in 8.271 s
Total Validation Score: 72.08333333333333%
[TOP K: 15] done in 8.763 s
Tota