In [1]:
from datasets import concatenate_datasets, load_from_disk
from src.retriever.retrieval.retrieval import SparseRetrieval
from retriever.embedding.flag_embedding import DenseRetrieval
from transformers import AutoTokenizer
from tqdm import tqdm
from src.retriever.score.ranking import check_original_in_context, calculate_reverse_rank_score, calculate_linear_score
org_dataset = load_from_disk('./data/train_dataset')
print(org_dataset)
full_ds = concatenate_datasets(
        [
            org_dataset["train"].flatten_indices(),
            org_dataset["validation"].flatten_indices(),
        ]
    )  # train dev 를 합친 4192 개 질문에 대해 모두 테스트
print("*" * 40, "query dataset", "*" * 40)
print(full_ds)

  from .autonotebook import tqdm as notebook_tqdm


DatasetDict({
    train: Dataset({
        features: ['title', 'context', 'question', 'id', 'answers', 'document_id', '__index_level_0__'],
        num_rows: 3952
    })
    validation: Dataset({
        features: ['title', 'context', 'question', 'id', 'answers', 'document_id', '__index_level_0__'],
        num_rows: 240
    })
})
**************************************** query dataset ****************************************
Dataset({
    features: ['title', 'context', 'question', 'id', 'answers', 'document_id', '__index_level_0__'],
    num_rows: 4192
})


In [2]:
data_path = "./data/"
context_path = "wikipedia_documents.json"

In [4]:
import json
import os
with open(os.path.join(data_path, context_path), "r", encoding="utf-8") as f:
            wiki = json.load(f)
docs = list(dict.fromkeys([v["text"] for v in wiki.values()]))
print(f"Lengths of unique contexts : {len(docs)}")
tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased", use_fast=False,)
tokenized_docs = [tokenizer(doc) for doc in tqdm(docs, desc="Tokenizing...")]

Lengths of unique contexts : 56737


Tokenizing...:   0%|          | 0/56737 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1337 > 512). Running this sequence through the model will result in indexing errors
Tokenizing...: 100%|██████████| 56737/56737 [01:26<00:00, 659.58it/s]


In [None]:
# 위에서 선언한거 가져오기 Retriever
for b in [0.25, 0.5, 0.75, 1.0]:
    for k1 in [1.1, 1.3, 1.5, 1.7, 1.9]:
        retriever = SparseRetrieval(
            tokenize_fn=tokenizer.tokenize,
            data_path="./data/",
            context_path="wikipedia_documents.json",
            mode = "bm25",
            max_feature=200000,
            ngram_range=(1,2),
            tokenized_docs = tokenized_docs,
            b=b,
            k1=k1,
        )
        print(f'b: {b} & k1 : {k1}')
        retriever.get_sparse_embedding()
        df = retriever.retrieve(full_ds, topk=10)
        retriever.get_score(df)



Lengths of unique contexts : 56737
b: 0.25 & k1 : 1.1
Building bm25 embedding...
Start Initializing...


Tokenizing...: 100%|██████████| 56737/56737 [01:14<00:00, 758.94it/s]


Generating n-grams and building vocabulary...


Generating n-grams: 100%|██████████| 56737/56737 [00:22<00:00, 2473.04it/s]


Current mode : bm25
End Initialization


Calculating BM25: 100%|██████████| 57/57 [07:20<00:00,  7.73s/it]


Finish BM25 Embedding
bm25 embedding shape: (56737, 200000)
(4192, 200000) (56737, 200000)
result shape : (4192, 56737)
[query exhaustive search] done in 19.765 s


Sparse retrieval: 100%|██████████| 4192/4192 [00:00<00:00, 10695.75it/s]


correct retrieval 0.8322996183206107
reverse rank retrieval 0.48521718441181044
linear retrieval 0.752088091571962
Lengths of unique contexts : 56737
b: 0.25 & k1 : 1.3
Building bm25 embedding...
Start Initializing...


Tokenizing...: 100%|██████████| 56737/56737 [01:15<00:00, 754.89it/s]


Generating n-grams and building vocabulary...


Generating n-grams: 100%|██████████| 56737/56737 [00:22<00:00, 2484.12it/s]


Current mode : bm25
End Initialization


Calculating BM25: 100%|██████████| 57/57 [07:18<00:00,  7.69s/it]


Finish BM25 Embedding
bm25 embedding shape: (56737, 200000)
(4192, 200000) (56737, 200000)
result shape : (4192, 56737)
[query exhaustive search] done in 19.381 s


Sparse retrieval: 100%|██████████| 4192/4192 [00:00<00:00, 11050.64it/s]


correct retrieval 0.8356393129770993
reverse rank retrieval 0.4890297805197442
linear retrieval 0.7521429663755775
Lengths of unique contexts : 56737
b: 0.25 & k1 : 1.5
Building bm25 embedding...
Start Initializing...


Tokenizing...: 100%|██████████| 56737/56737 [01:14<00:00, 757.50it/s]


Generating n-grams and building vocabulary...


Generating n-grams: 100%|██████████| 56737/56737 [00:22<00:00, 2502.93it/s]


Current mode : bm25
End Initialization


Calculating BM25: 100%|██████████| 57/57 [07:17<00:00,  7.67s/it]


Finish BM25 Embedding
bm25 embedding shape: (56737, 200000)
(4192, 200000) (56737, 200000)
result shape : (4192, 56737)
[query exhaustive search] done in 20.525 s


Sparse retrieval: 100%|██████████| 4192/4192 [00:00<00:00, 10392.54it/s]


correct retrieval 0.8361164122137404
reverse rank retrieval 0.484263658173665
linear retrieval 0.7505740053221104
Lengths of unique contexts : 56737
b: 0.25 & k1 : 1.7
Building bm25 embedding...
Start Initializing...


Tokenizing...: 100%|██████████| 56737/56737 [01:16<00:00, 737.57it/s]


Generating n-grams and building vocabulary...


Generating n-grams: 100%|██████████| 56737/56737 [00:23<00:00, 2399.43it/s]


Current mode : bm25
End Initialization


Calculating BM25: 100%|██████████| 57/57 [07:40<00:00,  8.09s/it]


Finish BM25 Embedding
bm25 embedding shape: (56737, 200000)
(4192, 200000) (56737, 200000)
result shape : (4192, 56737)
[query exhaustive search] done in 21.169 s


Sparse retrieval: 100%|██████████| 4192/4192 [00:00<00:00, 11078.63it/s]


correct retrieval 0.8327767175572519
reverse rank retrieval 0.4825900502852024
linear retrieval 0.7487632477260269
Lengths of unique contexts : 56737
b: 0.25 & k1 : 1.9
Building bm25 embedding...
Start Initializing...


Tokenizing...: 100%|██████████| 56737/56737 [01:19<00:00, 712.10it/s]


Generating n-grams and building vocabulary...


Generating n-grams: 100%|██████████| 56737/56737 [00:28<00:00, 2025.91it/s]


Current mode : bm25
End Initialization


Calculating BM25:  54%|█████▍    | 31/57 [04:30<03:53,  9.00s/it]

In [5]:
retriever = SparseRetrieval(
            tokenize_fn=tokenizer.tokenize,
            data_path="./data/",
            context_path="wikipedia_documents.json",
            mode = "bm25",
            max_feature=200000,
            ngram_range=(1,2),
            tokenized_docs = tokenized_docs,
            b=0.25,
            k1=1.1,
        )
retriever.get_sparse_embedding()
df = retriever.retrieve(full_ds, topk=10)
retriever.get_score(df)

Lengths of unique contexts : 56737
Building bm25 embedding...
Start Initializing...
Pass Tokenizing
Generating n-grams and building vocabulary...


Generating n-grams: 100%|██████████| 56737/56737 [00:00<00:00, 169690.96it/s]


Current mode : bm25
End Initialization


Calculating BM25: 100%|██████████| 57/57 [00:00<00:00, 139.21it/s]


Finish BM25 Embedding
New embeddings calculated and saved.
bm25 embedding shape: (56737, 1)


AssertionError: query_vecs가 제대로 변환되지않음.

In [None]:
df = retriever.retrieve(full_ds, topk=10)
retriever.get_score(df)


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# 데이터 준비
data = {
    'max_feature': [48600, 50000, 80000, 100000, 200000, 400000, 1000000, 1784976, 2000000],
    'Correct': [0.6836832061, 0.7645515267, 0.7924618321, 0.8058206107, 0.8401717557, 0.8590171756, 0.8707061069, 0.8730916031, 0.8289599237],
    'Reverse_Rank': [0.3587869386, 0.4144341905, 0.4454477764, 0.4557101427, 0.4905500473, 0.5160778654, 0.5361060586, 0.5403996207, 0.4767265568],
    'Linear': [0.5986908422, 0.6757890968, 0.7080769905, 0.7213994905, 0.7603113276, 0.7819702961, 0.7966366694, 0.8001752019, 0.7452705811]
}

df = pd.DataFrame(data)

# 'no gram' 표시 추가
df['label'] = df['max_feature'].astype(str)
df.loc[0, 'label'] = '48600 (no gram)'

# 데이터 정렬
df = df.sort_values('max_feature')

# 그래프 스타일 설정
plt.figure(figsize=(14, 8))

# 선 그래프 그리기
plt.plot(df['max_feature'], df['Correct'], marker='o', label='Correct')
plt.plot(df['max_feature'], df['Reverse_Rank'], marker='s', label='Reverse Rank')
plt.plot(df['max_feature'], df['Linear'], marker='^', label='Linear')

# 그래프 꾸미기
plt.xscale('log')  # x축을 로그 스케일로 변경
plt.xlabel('Max Feature')
plt.ylabel('Score')
plt.title('Performance Metrics vs Max Feature')
plt.legend()
plt.grid(True, which="both", ls="-", alpha=0.2)

# x축 눈금 설정
plt.xticks(df['max_feature'], df['label'], rotation=45, ha='right')

plt.tight_layout()
plt.show()
