In [1]:
from transformers import AutoTokenizer
from datasets import load_from_disk

from sklearn.feature_extraction.text import TfidfVectorizer
import numpy as np
import pandas as pd

import os
import time
import json
import pickle
from contextlib import contextmanager

In [2]:
with open('../../data/wikipedia_documents.json', 'r', encoding='utf-8') as f:
    wiki = json.load(f)

wiki_contents = list(dict.fromkeys([v['text'] for v in wiki.values()]))

In [3]:
print('wikipedia data length:', len(wiki.keys()))

print('wikipedia set data length:', len(wiki_contents))

wikipedia data length: 60613
wikipedia set data length: 56737


In [4]:
MODEL_NAME = 'bert-base-multilingual-cased'
MAX_FEATURES = 50000

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)

sparse_embedding_path = f'../../data/{MODEL_NAME}_spare_embedding_{MAX_FEATURES}.bin'
tfidfv_path = f'../../data/{MODEL_NAME}_tfidfv_{MAX_FEATURES}.bin'

if os.path.isfile(sparse_embedding_path) and os.path.isfile(tfidfv_path):
    with open(sparse_embedding_path, "rb") as file:
        wiki_tfidf = pickle.load(file)
    with open(tfidfv_path, "rb") as file:
        tfidfv = pickle.load(file)
    print("Embedding pickle load.")
else:
    print("Build passage embedding")
    tfidfv = TfidfVectorizer(
        tokenizer=tokenizer.tokenize,
        ngram_range=(1, 2),
        max_features=MAX_FEATURES,
    )
    wiki_tfidf = tfidfv.fit_transform(wiki_contents)
    print(wiki_tfidf.shape)
    with open(sparse_embedding_path, "wb") as file:
        pickle.dump(wiki_tfidf, file)
    with open(tfidfv_path, "wb") as file:
        pickle.dump(tfidfv, file)
    print("Embedding pickle saved.")

Build passage embedding
(56737, 50000)
Embedding pickle saved.


In [5]:
print('wiki TF-IDF shape:', wiki_tfidf.shape)

wiki TF-IDF shape: (56737, 50000)


In [6]:
org_dataset = load_from_disk('../../data/train_dataset')

validation_query_tfidf = tfidfv.transform(org_dataset['validation']['question'])

In [7]:
result = validation_query_tfidf * wiki_tfidf.T

In [8]:
if not isinstance(result, np.ndarray):
    result = result.toarray()

In [9]:
result.shape

(240, 56737)

In [10]:
@contextmanager
def timer(name):
    t0 = time.time()
    yield
    print(f"[{name}] done in {time.time() - t0:.3f} s")

In [11]:
for i in range(1, 21):
    with timer(f'TOP K: {i}'):
        TOPK = i
        doc_scores = []
        doc_indices = []
        for j in range(result.shape[0]):
            sorted_result = np.argsort(result[j, :])[::-1]
            doc_scores.append(result[j, :][sorted_result].tolist()[:TOPK])
            doc_indices.append(sorted_result.tolist()[:TOPK])

        correct = 0
        for idx, doc_indice in enumerate(doc_indices):
            for jdx, indice in enumerate(doc_indice):
                if org_dataset['validation']['context'][idx] == wiki_contents[indice]:
                    correct += 1

        print(f"Total Validation Score: {correct/len(org_dataset['validation'])*100}%")

Total Validation Score: 30.0%
[TOP K: 1] done in 2.318 s
Total Validation Score: 42.5%
[TOP K: 2] done in 2.784 s
Total Validation Score: 48.75%
[TOP K: 3] done in 3.248 s
Total Validation Score: 52.916666666666664%
[TOP K: 4] done in 3.716 s
Total Validation Score: 56.666666666666664%
[TOP K: 5] done in 4.180 s
Total Validation Score: 59.166666666666664%
[TOP K: 6] done in 4.641 s
Total Validation Score: 60.416666666666664%
[TOP K: 7] done in 5.101 s
Total Validation Score: 63.33333333333333%
[TOP K: 8] done in 5.654 s
Total Validation Score: 64.16666666666667%
[TOP K: 9] done in 6.080 s
Total Validation Score: 65.83333333333333%
[TOP K: 10] done in 6.487 s
Total Validation Score: 66.66666666666666%
[TOP K: 11] done in 6.943 s
Total Validation Score: 67.5%
[TOP K: 12] done in 7.394 s
Total Validation Score: 69.58333333333333%
[TOP K: 13] done in 7.797 s
Total Validation Score: 70.0%
[TOP K: 14] done in 7.977 s
Total Validation Score: 72.08333333333333%
[TOP K: 15] done in 8.411 s
Tota

In [15]:
# train validation to csv

# df = pd.DataFrame(result)
# df.to_csv(f'Train-Validation_{MODEL_NAME}_TF-IDF_{MAX_FEATURES}.csv', index=False)
# df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,56727,56728,56729,56730,56731,56732,56733,56734,56735,56736
0,0.032986,0.026047,0.024240,0.007701,0.005911,0.011127,0.011298,0.005020,0.011858,0.002647,...,0.013147,0.008110,0.012381,0.009782,0.002888,0.004226,0.002619,0.012030,0.003203,0.026518
1,0.007339,0.019860,0.029288,0.011747,0.011813,0.012816,0.003902,0.004154,0.010746,0.005794,...,0.007553,0.004341,0.008579,0.008422,0.007678,0.006208,0.006326,0.015826,0.010910,0.019147
2,0.013272,0.015507,0.017290,0.022481,0.010147,0.018830,0.003052,0.020512,0.021803,0.026654,...,0.023372,0.005552,0.014942,0.006160,0.003321,0.001366,0.001454,0.005534,0.018505,0.041677
3,0.013961,0.013244,0.042514,0.015374,0.013515,0.011022,0.002062,0.009005,0.007059,0.006781,...,0.004216,0.011746,0.021438,0.013290,0.004997,0.001616,0.007361,0.005297,0.015753,0.008309
4,0.036077,0.037388,0.021183,0.039676,0.008077,0.018149,0.017656,0.008696,0.003918,0.004473,...,0.057901,0.011419,0.018165,0.013474,0.011167,0.005459,0.002065,0.013998,0.024257,0.011912
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
235,0.016216,0.025316,0.033025,0.011765,0.010704,0.009513,0.003256,0.003411,0.004990,0.007918,...,0.017736,0.006832,0.014500,0.014552,0.011009,0.002850,0.005861,0.009513,0.007332,0.015070
236,0.023798,0.033863,0.030881,0.017773,0.008964,0.011841,0.004095,0.007107,0.009754,0.003413,...,0.007840,0.009345,0.008885,0.008146,0.014225,0.014084,0.011765,0.009482,0.008016,0.004957
237,0.022359,0.019816,0.034036,0.031322,0.016362,0.015882,0.017011,0.005788,0.009300,0.003186,...,0.005056,0.006297,0.012105,0.017146,0.012690,0.029030,0.019632,0.007731,0.005988,0.009918
238,0.006850,0.008959,0.029253,0.001626,0.021925,0.007476,0.003084,0.011369,0.006983,0.012280,...,0.004607,0.018113,0.011806,0.007674,0.014057,0.000996,0.004079,0.072659,0.004825,0.000469


In [31]:
# test validation to csv

org_dataset_test = load_from_disk('../../data/test_dataset')
print(org_dataset_test)

test_query_tfidf = tfidfv.transform(org_dataset_test['validation']['question'])
print('Test Validation Shape:', test_query_tfidf.shape)

result_test = test_query_tfidf * wiki_tfidf.T
if not isinstance(result_test, np.ndarray):
    result_test = result_test.toarray()
print('Scores Shape & Type:', result_test.shape, type(result_test))

df = pd.DataFrame(result_test)
df.to_csv(f'Test-Validation_{MODEL_NAME}_TF-IDF_{MAX_FEATURES}.csv', index=False)
df

DatasetDict({
    validation: Dataset({
        features: ['id', 'question'],
        num_rows: 600
    })
})
Test Validation Shape: (600, 50000)
Scores Shape & Type: (600, 56737) <class 'numpy.ndarray'>


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,56727,56728,56729,56730,56731,56732,56733,56734,56735,56736
0,0.011996,0.026468,0.026873,0.005925,0.011956,0.034110,0.010423,0.030534,0.022683,0.005202,...,0.010826,0.003785,0.020876,0.007068,0.007030,0.003643,0.005593,0.013326,0.010441,0.005472
1,0.004959,0.017205,0.022597,0.007943,0.020559,0.011620,0.003559,0.004096,0.002458,0.002370,...,0.004763,0.005878,0.007505,0.003673,0.006072,0.002396,0.003894,0.004790,0.019373,0.008029
2,0.010809,0.020167,0.031249,0.013451,0.009174,0.016224,0.016455,0.006764,0.013155,0.006065,...,0.008505,0.007260,0.013507,0.015534,0.015449,0.016290,0.017183,0.021186,0.003932,0.017133
3,0.020243,0.018868,0.035771,0.021838,0.014150,0.009002,0.026163,0.005854,0.062501,0.007289,...,0.004878,0.014159,0.010680,0.009499,0.005630,0.008404,0.009535,0.020166,0.001526,0.001966
4,0.037384,0.120618,0.033301,0.007963,0.006321,0.010255,0.007743,0.001137,0.004559,0.004193,...,0.031373,0.043635,0.015161,0.015638,0.014245,0.002867,0.010846,0.005597,0.005500,0.008822
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
595,0.011734,0.035735,0.021735,0.013171,0.013557,0.010180,0.007524,0.011164,0.003967,0.003018,...,0.014869,0.008567,0.020801,0.009840,0.008847,0.008205,0.007935,0.023174,0.001552,0.020961
596,0.015475,0.024606,0.027032,0.015458,0.024717,0.025151,0.011160,0.020531,0.007747,0.012974,...,0.021999,0.016471,0.019127,0.023443,0.014118,0.007225,0.005546,0.005272,0.003830,0.008053
597,0.008702,0.015367,0.036231,0.021202,0.011367,0.023500,0.011737,0.003351,0.006662,0.005015,...,0.005372,0.017011,0.009997,0.012456,0.012527,0.005858,0.006135,0.010007,0.007417,0.019707
598,0.011179,0.018199,0.017569,0.007024,0.004335,0.015745,0.002602,0.002533,0.006936,0.014067,...,0.025002,0.002737,0.015138,0.016401,0.007991,0.002788,0.001255,0.004607,0.005191,0.028304
