In [None]:
!pip install mteb==1.1.1 datasets beir sentence_transformers

In [None]:
import datasets
from mteb import MTEB
from sentence_transformers import SentenceTransformer
import os


In [None]:
# SET YOUR HUGGING FACE TOKEN
# Leave the string empty to keep a token that is already exported in the
# environment; a non-empty value here takes precedence over it.
# Avoid committing a real token with the notebook.
HF_TOKEN = ""
if HF_TOKEN:
    os.environ['HF_TOKEN'] = HF_TOKEN
else:
    # Preserve an existing token; only fall back to the empty placeholder
    # when the variable is not set at all.
    os.environ.setdefault('HF_TOKEN', "")

# SET YOUR MAX SEQUENCE LENGTH
# Applied to the model via `model.max_seq_length` and used in the results path.
MAX_SEQUENCE_LENGTH = 512

# SET YOUR MODEL
# Alternative model to evaluate:
#model_name = "jinaai/jina-embeddings-v2-base-en"
model_name = "thenlper/gte-base"

In [None]:
# from mteb

import logging
from time import time
from typing import Dict, List

from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Transformer, WordEmbeddings
import os

from mteb.abstasks.AbsTask import AbsTask

# Module-level logger, named after the current module.
logger = logging.getLogger(__name__)

# Method names a model must expose as callables to be handed to BeIR directly;
# checked by AbsTaskRetrieval.is_dres_compatible.
DRES_METHODS = ["encode_queries", "encode_corpus"]

class AbsTaskRetrieval(AbsTask):
    """
    Abstract base class for dense-retrieval tasks evaluated through BeIR.
    Child-classes must populate the following attributes, each keyed by split name:
    self.corpus = Dict[split, Dict[id, Dict[str, str]]] #doc id => document fields such as title and text
    self.queries = Dict[split, Dict[id, str]] #query id => query text
    self.relevant_docs = Dict[split, Dict[id, Dict[id, score]]] #query id => {doc id: relevance}
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @staticmethod
    def is_dres_compatible(model):
        """Return True when `model` exposes every name in DRES_METHODS as a callable."""
        return all(callable(getattr(model, name, None)) for name in DRES_METHODS)

    def evaluate(
        self,
        model,
        split="test",
        batch_size=8,
        corpus_chunk_size=None,
        score_function="cos_sim",
        **kwargs
    ):
        """
        Retrieve documents for every query of `split` and score the ranking.

        Args:
            model: Object with `encode_queries`/`encode_corpus`; anything else
                is wrapped in DRESModel first.
            split: Key used to select corpus, queries and relevance judgements.
            batch_size: Forwarded to the BeIR search wrapper.
            corpus_chunk_size: Forwarded to the BeIR search wrapper. In the
                non-distributed case None is replaced by 50000; in the
                distributed case the value is passed through unchanged.
            score_function: Forwarded to EvaluateRetrieval.
            **kwargs: Forwarded to the BeIR search wrapper. The key
                `ignore_identical_ids` (default True) is additionally passed
                to the metric computation.

        Returns:
            Dict mapping `<metric>_at_<k>` to its value for ndcg, map, recall,
            precision and mrr, where `<k>` is the part after '@' in the metric
            names returned by BeIR.

        Raises:
            Exception: If the beir package cannot be imported.
        """
        try:
            from beir.retrieval.evaluation import EvaluateRetrieval
        except ImportError:
            raise Exception("Retrieval tasks require beir package. Please install it with `pip install mteb[beir]`")

        if not self.data_loaded:
            self.load_data()

        corpus = self.corpus[split]
        queries = self.queries[split]
        relevant_docs = self.relevant_docs[split]

        # Models exposing only `.encode` are adapted to the BeIR interface.
        if not self.is_dres_compatible(model):
            model = DRESModel(model)

        # A RANK environment variable signals a distributed (multi-GPU) launch.
        if os.getenv("RANK", None) is not None:
            from beir.retrieval.search.dense import (
                DenseRetrievalParallelExactSearch as DRPES,
            )
            searcher = DRPES(
                model,
                batch_size=batch_size,
                corpus_chunk_size=corpus_chunk_size,
                **kwargs,
            )
        else:
            from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
            chunk_size = 50000 if corpus_chunk_size is None else corpus_chunk_size
            searcher = DRES(
                model,
                batch_size=batch_size,
                corpus_chunk_size=chunk_size,
                **kwargs,
            )

        retriever = EvaluateRetrieval(searcher, score_function=score_function)  # or "cos_sim" or "dot"
        started_at = time()
        results = retriever.retrieve(corpus, queries)
        elapsed = time() - started_at
        logger.info("Time taken to retrieve: {:.2f} seconds".format(elapsed))

        k_values = retriever.k_values
        skip_identical = kwargs.get("ignore_identical_ids", True)
        ndcg, _map, recall, precision = retriever.evaluate(
            relevant_docs, results, k_values, ignore_identical_ids=skip_identical
        )
        mrr = retriever.evaluate_custom(relevant_docs, results, k_values, "mrr")

        # Re-key every "<name>@<k>" entry as "<prefix>_at_<k>", keeping the
        # ndcg, map, recall, precision, mrr ordering.
        scores = {}
        metric_groups = (
            ("ndcg", ndcg),
            ("map", _map),
            ("recall", recall),
            ("precision", precision),
            ("mrr", mrr),
        )
        for prefix, metrics in metric_groups:
            for metric_name, value in metrics.items():
                cutoff = metric_name.split('@')[1]
                scores[f"{prefix}_at_{cutoff}"] = value

        return scores


class DRESModel:
    """
    Dense Retrieval Exact Search (DRES) in BeIR requires an encode_queries & encode_corpus method.
    This class converts a MTEB model (with just an .encode method) into BeIR DRES format.
    """

    def __init__(self, model, sep=" ", **kwargs):
        """
        Args:
            model: Object exposing `.encode(sentences, batch_size=..., **kwargs)`.
            sep: String inserted between a document's title and its text.
            **kwargs: Accepted for call compatibility; not used.
        """
        self.model = model
        self.sep = sep
        self.use_sbert_model = isinstance(model, SentenceTransformer)

    def encode_queries(self, queries: List[str], batch_size: int, **kwargs):
        """Encode `queries` with the wrapped model and return whatever `.encode` returns."""
        if self.use_sbert_model:
            if isinstance(self.model._first_module(), Transformer):
                logger.info(f"Queries will be truncated to {self.model.get_max_seq_length()} tokens.")
            elif isinstance(self.model._first_module(), WordEmbeddings):
                logger.warning(
                    "Queries will not be truncated. This could lead to memory issues. In that case please lower the batch_size."
                )
        return self.model.encode(queries, batch_size=batch_size, **kwargs)

    def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int, **kwargs):
        """
        Encode documents with the wrapped model.

        Args:
            corpus: Either a list of per-document dicts (`{"text": ..., "title": ...}`,
                title optional) or a single column-oriented mapping
                (`{"text": [...], "title": [...]}`, title optional).
            batch_size: Forwarded to the wrapped model's `.encode`.
            **kwargs: Forwarded to the wrapped model's `.encode`.

        Returns:
            Whatever the wrapped model's `.encode` returns for the
            "title + sep + text" strings (stripped of surrounding whitespace).
        """
        # isinstance (rather than an exact type check) so dict subclasses such
        # as OrderedDict are treated as column-oriented input instead of being
        # iterated key by key.
        if isinstance(corpus, dict):
            texts = corpus["text"]
            if "title" in corpus:
                titles = corpus["title"]
                sentences = [(titles[i] + self.sep + texts[i]).strip() for i in range(len(texts))]
            else:
                sentences = [texts[i].strip() for i in range(len(texts))]
        else:
            sentences = [
                (doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip()
                for doc in corpus
            ]
        return self.model.encode(sentences, batch_size=batch_size, **kwargs)


In [None]:
class NarrativeQARetrieval(AbsTaskRetrieval):
    """
    Retrieval task built from one split of the NarrativeQA dataset.

    Every dataset row contributes one query (its question) whose only relevant
    document is the document attached to that same row.
    """

    # Dataset split that is loaded and used as the key of corpus/queries/relevant_docs.
    _EVAL_SPLIT = 'test'

    @property
    def description(self):
        """Task metadata: name, dataset id, reference, type, splits, languages and main score."""
        return {
            'name': 'NarrativeQARetrieval',
            'hf_hub_name': 'narrativeqa',
            'reference': 'https://metatext.io/datasets/narrativeqa',
            "description": (
                "NarrativeQA is a dataset for the task of question answering on long narratives. It consists of "
                "realistic QA instances collected from literature (fiction and non-fiction) and movie scripts. "
            ),
            "type": "Retrieval",
            "category": "s2p",
            "eval_splits": ["test"],
            "eval_langs": ["en"],
            "main_score": "ndcg_at_10",
        }

    def load_data(self, **kwargs):
        """
        Load the evaluation split and fill `queries`, `corpus` and `relevant_docs`.

        Does nothing when the data has already been loaded. Query ids are the
        row index as a string; document ids are the stringified document id,
        used identically for corpus keys and relevance judgements so the two
        always match. Rows sharing a document id collapse to one corpus entry.
        """
        if self.data_loaded:
            return

        data = datasets.load_dataset(self.description['hf_hub_name'], split=self._EVAL_SPLIT, trust_remote_code=True)

        # Single pass: every row carries the full document text, so iterating
        # the dataset once instead of three times avoids re-reading it.
        queries = {}
        corpus = {}
        relevant_docs = {}
        for i, row in enumerate(data):
            query_id = str(i)
            document = row['document']
            doc_id = str(document['id'])
            queries[query_id] = row['question']['text']
            corpus[doc_id] = {'text': document['text']}
            relevant_docs[query_id] = {doc_id: 1}

        self.queries = {self._EVAL_SPLIT: queries}
        self.corpus = {self._EVAL_SPLIT: corpus}
        self.relevant_docs = {self._EVAL_SPLIT: relevant_docs}

        # Print the count of queries, corpus and relevant_docs
        print(f"{self._EVAL_SPLIT} queries: {len(self.queries[self._EVAL_SPLIT])}")
        print(f"{self._EVAL_SPLIT} corpus: {len(self.corpus[self._EVAL_SPLIT])}")
        print(f"{self._EVAL_SPLIT} relevant docs: {len(self.relevant_docs[self._EVAL_SPLIT])}")

        self.data_loaded = True

In [None]:
# Load the embedding model chosen in the configuration cell.
# NOTE(review): trust_remote_code=True is presumably needed for hub models that
# ship custom code (e.g. the commented-out jina model) — confirm before running
# against an untrusted repository.
model = SentenceTransformer(model_name, trust_remote_code=True)
# Cap the model's input length at the configured value.
model.max_seq_length = MAX_SEQUENCE_LENGTH
# The task is requested by name; this string matches the 'name' entry of
# NarrativeQARetrieval.description, so that class must be defined before this cell runs.
evaluation = MTEB(tasks=["NarrativeQARetrieval"])
# The output folder encodes the sequence length and model name; because
# model_name contains a "/", it contributes two path components.
results = evaluation.run(model, output_folder=f"results/retrieval/narrativeqa/sequence_length/{MAX_SEQUENCE_LENGTH}/{model_name}")