In [1]:
from mteb.evaluation.evaluators import RetrievalEvaluator
from typing import Any, Dict
from air_benchmark import AIRBench, Retriever
from FlagEmbedding import FlagModel
import os

## Enable a proxy so the datasets can be downloaded from Hugging Face.
# Both variables are set process-wide to the same local proxy endpoint
# (127.0.0.1:7890) before any download is triggered below.
os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'
os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'

class FlagEmbeddingModel:
    """Adapter exposing a FlagEmbedding ``FlagModel`` via the
    ``encode_corpus`` / ``encode_queries`` / ``encode`` methods that the
    retrieval evaluator in this file calls on its ``retriever`` object.

    Attributes:
        model_name: Last path component of ``model_path``; returned by ``str()``.
        model: The underlying ``FlagModel`` instance.
    """

    def __init__(self, model_path: str, **kwargs):
        """Load the embedding model stored at ``model_path``.

        Args:
            model_path: Local directory (or model identifier) handed to
                ``FlagModel``.
            **kwargs: Accepted for call-site compatibility; not used.
        """
        # normpath drops a trailing separator so that "dir/model/" still
        # yields "model" instead of an empty name.
        self.model_name = os.path.basename(os.path.normpath(model_path))
        self.model = FlagModel(
            model_name_or_path=model_path,
            query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章：",
            use_fp16=True,
        )

    def __str__(self) -> str:
        """Return the short model name."""
        return self.model_name

    def encode_corpus(self, corpus, **kwargs):
        """Embed corpus documents.

        Args:
            corpus: Sequence of plain strings, or of dicts carrying optional
                ``"title"`` and ``"text"`` keys. Dict entries are flattened
                to ``"<title> <text>"`` with surrounding whitespace stripped.
            **kwargs: Forwarded to :meth:`encode` (which ignores them).

        Returns:
            Whatever ``self.model.encode`` returns for the prepared texts.
        """
        input_texts = corpus
        # Guard the index so an empty corpus does not raise IndexError.
        if len(corpus) > 0 and isinstance(corpus[0], dict):
            input_texts = [
                "{} {}".format(doc.get("title", ""), doc.get("text", "")).strip()
                for doc in corpus
            ]
        return self.encode(input_texts, **kwargs)

    def encode_queries(self, queries, **kwargs):
        """Embed queries through the model's query-specific entry point.

        The model is configured with ``query_instruction_for_retrieval``;
        per the FlagEmbedding documentation that instruction is applied by
        ``FlagModel.encode_queries`` and not by ``FlagModel.encode``, so
        queries are routed there rather than through :meth:`encode`.

        Args:
            queries: Sequence of plain strings, or of dicts with a
                ``"text"`` key (stripped before encoding).
            **kwargs: Accepted for interface compatibility; not forwarded.

        Returns:
            Whatever ``self.model.encode_queries`` returns.
        """
        input_texts = queries
        # Guard the index so an empty query list does not raise IndexError.
        if len(queries) > 0 and isinstance(queries[0], dict):
            input_texts = [doc.get("text", "").strip() for doc in queries]
        return self.model.encode_queries(input_texts)

    def encode(self, sentences, **kwargs):
        """Embed ``sentences`` as-is with the underlying model.

        ``**kwargs`` is intentionally not forwarded: callers may pass
        options that the underlying ``encode`` does not accept.
        """
        return self.model.encode(sentences)


class EmbeddingModelRetriever(Retriever):
    """Retriever that runs searches through a ``RetrievalEvaluator`` built
    around the supplied embedding model."""

    def __init__(self, embedding_model, search_top_k: int = 1000, **kwargs):
        """Store the model and build the evaluator used for searching.

        Args:
            embedding_model: Object handed to ``RetrievalEvaluator`` as its
                ``retriever``.
            search_top_k: Passed to the base-class initializer; the
                resulting ``self.search_top_k`` becomes the evaluator's
                single ``k_values`` entry.
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``RetrievalEvaluator``.
        """
        self.embedding_model = embedding_model
        super().__init__(search_top_k)
        cutoffs = [self.search_top_k]
        self.retriever = RetrievalEvaluator(
            retriever=self.embedding_model, k_values=cutoffs, **kwargs
        )

    def __str__(self):
        """Return the string form of the wrapped embedding model."""
        return str(self.embedding_model)

    def __call__(
        self,
        corpus: Dict[str, Dict[str, Any]],
        queries: Dict[str, str],
        **kwargs,
    ):
        """Run the evaluator on ``corpus`` and ``queries`` and return its result.

        ``**kwargs`` is accepted but not forwarded to the evaluator.
        """
        return self.retriever(corpus=corpus, queries=queries)


# Values shared by several calls below, named once to keep them in sync.
BENCHMARK_VERSION = "AIR-Bench_24.05"
DATA_CACHE_DIR = "../resources/data/raw"
SEARCH_RESULTS_DIR = './results/search'

embedding_model = FlagEmbeddingModel('../resources/open_models/bge-large-zh-v1.5')

# Demo configuration: one task type, one domain, one language, one split.
evaluation = AIRBench(
    benchmark_version=BENCHMARK_VERSION,
    task_types=["qa"],
    domains=["finance"],
    languages=["zh"],
    splits=["dev"],
    cache_dir=DATA_CACHE_DIR,
)

# Raise corpus_chunk_size to 10_000_000 when encoding a large corpus to
# avoid multiple tqdm bars.
retriever = EmbeddingModelRetriever(
    embedding_model,
    search_top_k=1000,
    corpus_chunk_size=10000,
)

# Run the search, overwriting any previous results in the output directory.
evaluation.run(retriever, output_dir=SEARCH_RESULTS_DIR, overwrite=True)

# Compute metrics for the dev set and write them out as markdown.
evaluation.evaluate_dev(
    benchmark_version=BENCHMARK_VERSION,
    search_results_save_dir=SEARCH_RESULTS_DIR,
    output_method="markdown",
    output_path='./results/eval_dev_results.md',
    metrics=["ndcg_at_10", "recall_at_10"],
    cache_dir=DATA_CACHE_DIR,
)



  from .autonotebook import tqdm as notebook_tqdm


----------using 8*GPUs----------


    There is an imbalance between your GPUs. You may want to exclude GPU 4 which
    has less than 75% of the memory or cores of GPU 0. You can do so by setting
    the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
    environment variable.


Downloading readme: 100%|██████████| 978/978 [00:00<00:00, 3.58kB/s]
Inference Embeddings: 100%|██████████| 1/1 [00:05<00:00,  5.07s/it]
Inference Embeddings: 100%|██████████| 5/5 [00:16<00:00,  3.30s/it]
Inference Embeddings: 100%|██████████| 5/5 [00:16<00:00,  3.24s/it]
Inference Embeddings: 100%|██████████| 5/5 [00:15<00:00,  3.17s/it]
Inference Embeddings: 100%|██████████| 5/5 [00:15<00:00,  3.19s/it]
Inference Embeddings: 100%|██████████| 5/5 [00:16<00:00,  3.20s/it]
Inference Embeddings: 100%|██████████| 5/5 [00:15<00:00,  3.20s/it]
Inference Embeddings: 100%|██████████| 5/5 [00:15<00:00,  3.16s/it]
Inference Embeddings: 100%|██████████| 5/5 [00:16<00:00,  3.23s/it]
Inference Embeddings: 100%|██████████| 5/5 [00:16<00:00,  3.23s/it]
Inference Embeddings: 100%|██████████| 5/5 [00:16<00:00,  3.20s/it]
Inference Embeddings: 100%|██████████| 5/5 [00:16<00:00,  3.23s/it]
Inference Embeddings: 100%|██████████| 5/5 [00:16<00:00,  3.24s/it]
Inference Embeddings: 100%|██████████| 5/5 [00:

Downloading readme: 100%|██████████| 463/463 [00:00<00:00, 1.77kB/s]


Results saved to ./results/eval_dev_results.md
