In [1]:
# Put the parent directory on the module search path; presumably this is what
# lets the project-local `eval` package (imported in the following cells) resolve
# when the notebook is run from its own folder.
import sys

sys.path.extend([".."])

In [2]:
from eval.dataset import RAGDataset

# Source dataset file that the hard negatives are mined for.
source_dataset_path = "../data/airbench_qa_healthcare_zh_synthesis.json"
ds = RAGDataset.from_file(source_dataset_path)

In [5]:
from eval.sentence_transformer_model import SentenceTransformerEncoder
from eval.embedding_model import EmbeddingModelPreparer, EmbeddingModelRetriever

# Encoder configuration. The value of `query_instruction_for_retrieval` reads in
# English: "Generate a representation for this sentence to retrieve relevant
# articles:" (presumably applied to queries only -- confirm against the encoder).
encoder_options = {
    "model_name_or_path": "BAAI/bge-small-zh-v1.5",
    "normalize_embeddings": True,
    "query_instruction_for_retrieval": "为这个句子生成表示以用于检索相关文章：",
    "max_query_length": 512,
    "max_passage_length": 512,
}
embedding_model = SentenceTransformerEncoder(**encoder_options)

# Callable that turns a corpus into a vector store; it is given "../.cache" as
# its cache directory.
preparer = EmbeddingModelPreparer(embedding_model, cache_dir="../.cache")

# Retrieve 100 candidates per query; negatives are later drawn from these.
retriever = EmbeddingModelRetriever(embedding_model, search_top_k=100)

In [6]:
print(">>> Preparing corpus embeddings and faiss index...")
# Run the preparer over the dataset corpus; the resulting vector store is what
# the retriever is pointed at in the retrieval step.
corpus = ds.corpus
vectorstore = preparer(corpus)

>>> Preparing corpus embeddings and faiss index...


Chunks: 100%|██████████| 4/4 [00:00<00:00,  8.75it/s]
`embedding_function` is expected to be an Embeddings object, support for passing in a function will soon be removed.


In [7]:
print(">>> Retrieving...")
# Only the queries of the "train" split are searched against the vector store.
to_retrieve_queries = ds.get_queries_split("train")
retriever_result = retriever(
    queries=to_retrieve_queries,
    vectorstore=vectorstore,
)

>>> Retrieving...


Chunks: 100%|██████████| 10/10 [00:01<00:00,  9.46it/s]
Retrieving: 100%|██████████| 9810/9810 [02:41<00:00, 60.60it/s]


In [10]:
import random

print(">>> Hard negative mining...")
# For every query that has at least one relevant document, randomly pick
# NEGATIVES_PER_QUERY documents from the retrieved ranking, starting at rank
# NEGATIVE_RANK_START. Starting at rank 50 (rather than 0) skips the top hits,
# because the embedding model used for retrieval may not have judged those
# reliably (translated from the original note).
NEGATIVE_RANK_START = 50  # number of top-ranked retrieved documents to skip
NEGATIVES_PER_QUERY = 15  # number of negatives to sample per query

random.seed(42)  # fixed seed so the sampled negatives are reproducible
negative_docs_count = 0
short_queries_count = 0  # queries whose candidate pool was smaller than requested
for query_id, docs in retriever_result.items():
    relevant_doc_ids = ds.relevant_docs[query_id]
    if len(relevant_doc_ids) == 0:
        # No positive document for this query: no negatives are mined for it.
        continue
    positive_ids = set(relevant_doc_ids)
    ranked_doc_ids = [doc_id for doc_id, _ in docs]
    # A document labelled relevant must never be stored as a negative, even if
    # the retriever ranked it in the sampled range.
    candidate_ids = [
        doc_id
        for doc_id in ranked_doc_ids[NEGATIVE_RANK_START:]
        if doc_id not in positive_ids
    ]
    # random.sample raises ValueError when asked for more items than the
    # population holds, so cap the sample size at the pool size.
    sample_size = min(NEGATIVES_PER_QUERY, len(candidate_ids))
    if sample_size < NEGATIVES_PER_QUERY:
        short_queries_count += 1
    negative_docs = random.sample(candidate_ids, sample_size)
    negative_docs_count += len(negative_docs)
    ds.negative_docs[query_id] = negative_docs
print(f">>> Total negative docs: {negative_docs_count}")
if short_queries_count:
    # Only reported when it happens, so the usual output is unchanged.
    print(f">>> Queries with fewer than {NEGATIVES_PER_QUERY} negatives: {short_queries_count}")

>>> Hard negative mining...
>>> Total negative docs: 147150


In [11]:
# Save the dataset (including the negatives assigned above) under a different
# file name than the one it was loaded from.
hard_negative_dataset_path = "../data/airbench_qa_healthcare_zh_synthesis_hard_negative.json"
ds.save(hard_negative_dataset_path)