In [1]:
from pathlib import Path

import pandas as pd
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval import models
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
from tqdm.autonotebook import tqdm


In [2]:
# Project root; every dataset and index path below is derived from it.
BASE_DIR = Path(".").resolve()
# If this notebook is later moved into a subdirectory such as notebooks/, switch to:
# BASE_DIR = Path("..").resolve()

# Per-dataset locations: the data folder handed to the loader and the
# matching index directory, both rooted at BASE_DIR.
DATASETS = {
    dataset_name: {
        "data_dir": BASE_DIR / "dataset" / dataset_name,
        "faiss_dir": BASE_DIR / "indexes" / f"{dataset_name}-bge-faiss",
    }
    for dataset_name in ("clapnq", "cloud", "fiqa", "govt")
}


In [6]:
# Sanity check: display the resolved data directory of the "govt" dataset as a string.
str(DATASETS["govt"]["data_dir"])

'/Users/sandylin/Challenge in CL/dataset/govt'

In [3]:
MODEL_NAME = "BAAI/bge-base-en-v1.5"
BATCH_SIZE = 64  # batch size handed to the dense searcher; adjustable
SPLIT = "train"  # the qrels folder currently holds train.tsv
# Cutoffs k at which the metrics are reported.
K_VALUES = [1, 3, 5, 10]

# Initialize the embedding model.
embedding_model = models.SentenceBERT(MODEL_NAME)

# DenseRetrievalExactSearch: exact (non-approximate) dense search.
dres_model = DRES(embedding_model, batch_size=BATCH_SIZE)

# Pass K_VALUES so the retriever keeps max(K_VALUES) hits per query. Without
# it, EvaluateRetrieval uses its default cutoffs (up to 1000) and collects far
# more results per query than the evaluation below ever looks at.
retriever = EvaluateRetrieval(dres_model, k_values=K_VALUES, score_function="cos_sim")


In [4]:
def evaluate_dataset(dataset_name: str):
    """Run retrieval on one configured dataset and report nDCG and recall.

    Args:
        dataset_name: Key into the module-level ``DATASETS`` mapping.

    Returns:
        A dict with two entries, ``"ndcg"`` and ``"recall"``, holding the
        corresponding results of ``retriever.evaluate`` for ``K_VALUES``.

    Raises:
        KeyError: If ``dataset_name`` is not a key of ``DATASETS``.

    Note:
        Relies on the module-level ``retriever``, ``SPLIT`` and ``K_VALUES``.
        The ``faiss_dir`` entry of the dataset config is not used here.
    """
    dataset_dir = DATASETS[dataset_name]["data_dir"]

    print(f"Dataset: {dataset_name}")
    print("data_dir :", dataset_dir)

    # Load corpus, queries and relevance judgments through BEIR's loader.
    # TODO: unify some of these variable names later.
    loader = GenericDataLoader(data_folder=str(dataset_dir))
    corpus, queries, qrels = loader.load(split=SPLIT)

    print(f"#docs = {len(corpus)}, #queries = {len(queries)}")

    # Retrieve results for the queries against the corpus.
    results = retriever.retrieve(corpus, queries)

    # Score the run; MAP and precision are returned as well but not reported.
    print("start evaluation")
    ndcg_scores, _, recall_scores, _ = retriever.evaluate(qrels, results, K_VALUES)
    print("Done evaluation")

    return {"ndcg": ndcg_scores, "recall": recall_scores}


In [5]:
# Evaluate every dataset configured in DATASETS. Iterating the mapping itself,
# rather than a second hard-coded list of names, keeps this cell in sync with
# the configuration. Results are stored as each dataset finishes, so an
# interrupted run still keeps the datasets that completed.
all_metrics = {}

for name in DATASETS:
    all_metrics[name] = evaluate_dataset(name)

all_metrics

Dataset: clapnq
data_dir : /Users/sandylin/Challenge in CL/dataset/clapnq
faiss_dir: /Users/sandylin/Challenge in CL/indexes/clapnq-bge-faiss


  0%|          | 0/183408 [00:00<?, ?it/s]

#docs = 183408, #queries = 208


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/782 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Flatten the per-dataset metrics into one row per (dataset, k).
# BEIR's evaluate() keys its result dicts by metric label ("NDCG@10",
# "Recall@10"), not by the bare integer k, so the label is rebuilt for each
# cutoff; indexing with k directly raises KeyError.
rows = [
    {
        "dataset": ds_name,
        "k": k,
        "nDCG": metrics["ndcg"][f"NDCG@{k}"],
        "Recall": metrics["recall"][f"Recall@{k}"],
    }
    for ds_name, metrics in all_metrics.items()
    for k in K_VALUES
]

df_results = pd.DataFrame(rows)
df_results