In [1]:
import os
import json
import time
import faiss
import pickle
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from contextlib import contextmanager
from typing import List, Tuple, NoReturn, Any, Optional, Union
from sklearn.feature_extraction.text import TfidfVectorizer
from transformers import AutoTokenizer
from datasets import (
    Dataset,
    load_from_disk,
    concatenate_datasets,
    Features,
    Value,
    DatasetDict,
)
from retrieval import *
from tqdm import tqdm

In [2]:
dataset = "../data/train_dataset"
org_dataset = load_from_disk(dataset)
tokenizer = AutoTokenizer.from_pretrained(
        "klue/bert-base",
        use_fast=False,
        )
retriever = SparseRetrieval(
        tokenize_fn=tokenizer.tokenize,
        data_path="../data/",
        context_path="wikipedia_documents.json",
        is_bm25=True
    )
full_ds = concatenate_datasets(
        [
            org_dataset["train"].flatten_indices(),
            org_dataset["validation"].flatten_indices(),
        ]
    )

Loading cached processed dataset at ../data/train_dataset/train/cache-fbc57aa6e699fb0c.arrow
Loading cached processed dataset at ../data/train_dataset/validation/cache-d2fba0c42123b1d6.arrow


Lengths of unique wiki contexts : 56737


In [3]:
def topk_experiment(topK_list):
    result_dict = {}
    retriever.get_sparse_embedding()
    for topK in tqdm(topK_list):
        result_retriever = retriever.retrieve(org_dataset["train"], topk=topK)
        correct = 0
        for index in tqdm(range(len(result_retriever))):
            if  result_retriever['original_context'][index][:200] in result_retriever['context'][index]:
                correct += 1
        result_dict[topK] = correct/len(result_retriever)
    return result_dict

In [4]:
topK_list = [1,10,20,50]
result = topk_experiment(topK_list)
result

  0%|          | 0/4 [00:00<?, ?it/s]

Embedding bm25 pickle load.
[query exhaustive search] done in 597.814 s


HBox(children=(FloatProgress(value=0.0, description='Sparse retrieval: ', max=4192.0, style=ProgressStyle(desc…

100%|██████████| 4192/4192 [00:00<00:00, 31955.85it/s]
 25%|██▌       | 1/4 [09:59<29:58, 599.53s/it]


[query exhaustive search] done in 449.813 s


HBox(children=(FloatProgress(value=0.0, description='Sparse retrieval: ', max=4192.0, style=ProgressStyle(desc…

100%|██████████| 4192/4192 [00:00<00:00, 69154.74it/s]
 50%|█████     | 2/4 [17:30<18:29, 554.83s/it]


[query exhaustive search] done in 476.355 s


HBox(children=(FloatProgress(value=0.0, description='Sparse retrieval: ', max=4192.0, style=ProgressStyle(desc…






100%|██████████| 4192/4192 [00:00<00:00, 21048.12it/s]
 75%|███████▌  | 3/4 [25:29<08:52, 532.11s/it]

[query exhaustive search] done in 498.070 s


HBox(children=(FloatProgress(value=0.0, description='Sparse retrieval: ', max=4192.0, style=ProgressStyle(desc…

100%|██████████| 4192/4192 [00:00<00:00, 60744.80it/s]
100%|██████████| 4/4 [33:48<00:00, 507.10s/it]







{1: 0.5923187022900763,
 10: 0.8735687022900763,
 20: 0.9069656488549618,
 50: 0.9386927480916031}

In [5]:
retriever.get_sparse_embedding()

Embedding bm25 pickle load.


In [6]:
df = retriever.retrieve(full_ds,topk = 20)

[query exhaustive search] done in 460.285 s


HBox(children=(FloatProgress(value=0.0, description='Sparse retrieval: ', max=4192.0, style=ProgressStyle(desc…




In [7]:
df["correct"] = df["original_context"] == df["context"]
print("correct retrieval result by exhaustive search",
        df["correct"].sum() / len(df),)

correct retrieval result by exhaustive search 0.0


In [8]:
df["context"][1][:620] ==  df['original_context'][1][:620]

True

In [9]:
df["context"][1][620:630] ==  df['original_context'][1][620:630]

False

In [10]:
df['context'][1][500:630]

"해에서 다양한 기능을 인사조직관리의 목적, 경영의 목적을 위해서 다양한 분야를 통합하여 '유기적 기업 조직' 이해로 전환되었다. 이 통합적 접근방식은 과정, 시스템, 상황을 중심으로 하는 인사조직관리 방식을 형성했다. 인류 역사에서 "

In [11]:
df['original_context'][1][500:630]

"해에서 다양한 기능을 인사조직관리의 목적, 경영의 목적을 위해서 다양한 분야를 통합하여 '유기적 기업 조직' 이해로 전환되었다. 이 통합적 접근방식은 과정, 시스템, 상황을 중심으로 하는 인사조직관리 방식을 형성했다."

In [12]:
correct = 0
for index in range(len(df)):
    if df["context"][index][:400] == df['original_context'][index][:400]:
        correct += 1
print(correct/len(df))

0.5517652671755725


In [13]:
df['context'][1][:200] in df['original_context'][1]

True

In [14]:
len(df["context"][1]), len(df['original_context'][1])

(9724, 621)

In [15]:
topK_list[1, 10, 20, 50]

TypeError: list indices must be integers or slices, not tuple

In [None]:
result = topk_experiment(topK_list)

In [None]:
result