In [None]:
!pip install mteb==1.1.1 datasets beir sentence_transformers

In [None]:
import datasets
from mteb import MTEB
from sentence_transformers import SentenceTransformer
import os

In [None]:
# SET YOUR HUGGINGFACE TOKEN
# Leave this empty to keep whatever HF_TOKEN is already exported in the
# environment; the variable is only written when a token is entered here, so
# an existing token is never replaced by an empty string.
HF_TOKEN = ""
if HF_TOKEN:
    os.environ['HF_TOKEN'] = HF_TOKEN

# SET YOUR CHUNKING SIZE (how many words per chunk)
CHUNKING_SIZE = 1000

# SET YOUR MODEL
#model_name = "jinaai/jina-embeddings-v2-base-en"
model_name = "thenlper/gte-base"

The `DenseRetrievalExactSearch` class below is taken from the BEIR benchmark, with the score calculation modified so that each document is scored by its best-matching (top-scoring) chunk.

In [None]:
from beir.retrieval.search import BaseSearch
from beir.retrieval.search.dense.util import cos_sim, dot_score
import logging
import torch
from typing import Dict
import heapq

logger = logging.getLogger(__name__)

# DenseRetrievalExactSearch is parent class for any dense model that can be used for retrieval
# Abstract class is BaseSearch
class DenseRetrievalExactSearch(BaseSearch):
    """Exact (brute-force) dense retrieval over a chunked corpus.

    The corpus passed to :meth:`search` is expected to hold chunks whose ids
    look like ``<document_id>_chunk_<n>``. Chunks are scored against every
    query, and the chunk hits are then collapsed so that each document is
    represented by the score of its best-scoring chunk.
    """

    # Separator between the base document id and the chunk index in a chunk id.
    CHUNK_ID_SEPARATOR = "_chunk_"

    def __init__(self, model, batch_size: int = 128, corpus_chunk_size: int = 50000, **kwargs):
        """
        Args:
            model: object providing ``encode_queries()`` and ``encode_corpus()``.
            batch_size: batch size forwarded to both encode calls.
            corpus_chunk_size: number of corpus entries encoded and scored per batch.
            **kwargs: optional ``show_progress_bar`` and ``convert_to_tensor``
                flags (both default to True), forwarded to the encode calls.
        """
        #model is class that provides encode_corpus() and encode_queries()
        self.model = model
        self.batch_size = batch_size
        self.score_functions = {'cos_sim': cos_sim, 'dot': dot_score}
        self.score_function_desc = {'cos_sim': "Cosine Similarity", 'dot': "Dot Product"}
        self.corpus_chunk_size = corpus_chunk_size
        self.show_progress_bar = kwargs.get("show_progress_bar", True)
        self.convert_to_tensor = kwargs.get("convert_to_tensor", True)
        self.results = {}

    @classmethod
    def _document_id(cls, chunk_id: str) -> str:
        """Return the base document id for a ``<document_id>_chunk_<n>`` chunk id.

        Ids that do not end in ``_chunk_<digits>`` are returned unchanged, so
        an unchunked corpus is not mangled or collapsed into one document.
        """
        base, separator, index = chunk_id.rpartition(cls.CHUNK_ID_SEPARATOR)
        if separator and index.isdigit():
            return base
        return chunk_id

    def search(self,
               corpus: Dict[str, Dict[str, str]],
               queries: Dict[str, str],
               top_k: int,
               score_function: str,
               return_sorted: bool = False,
               **kwargs) -> Dict[str, Dict[str, float]]:
        """Score every corpus chunk against every query and return per-document scores.

        Args:
            corpus: chunk id -> dict with ``"text"`` and optionally ``"title"``.
            queries: query id -> query text.
            top_k: number of best-scoring chunks kept per query.
            score_function: ``"cos_sim"`` or ``"dot"``.
            return_sorted: if True, each query's documents are ordered by
                descending score.

        Returns:
            query id -> {document id -> score of that document's best chunk}.
            Because ``top_k`` is applied to chunks before they are collapsed
            into documents, a query may end up with fewer than ``top_k``
            distinct documents.

        Raises:
            ValueError: if ``score_function`` is not a supported name.
        """
        if score_function not in self.score_functions:
            raise ValueError("score function: {} must be either (cos_sim) for cosine similarity or (dot) for dot product".format(score_function))

        logger.info("Encoding Queries...")
        query_ids = list(queries.keys())
        self.results = {qid: {} for qid in query_ids}
        queries = [queries[qid] for qid in queries]
        query_embeddings = self.model.encode_queries(
            queries, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar, convert_to_tensor=self.convert_to_tensor)

        logger.info("Sorting Corpus by document length (Longest first)...")

        corpus_ids = sorted(corpus, key=lambda k: len(corpus[k].get("title", "") + corpus[k].get("text", "")), reverse=True)
        corpus = [corpus[cid] for cid in corpus_ids]

        logger.info("Encoding Corpus in batches... Warning: This might take a while!")
        logger.info("Scoring Function: {} ({})".format(self.score_function_desc[score_function], score_function))

        itr = range(0, len(corpus), self.corpus_chunk_size)

        result_heaps = {qid: [] for qid in query_ids}  # Keep only the top-k chunks for each query
        for batch_num, corpus_start_idx in enumerate(itr):
            logger.info("Encoding Batch {}/{}...".format(batch_num+1, len(itr)))
            corpus_end_idx = min(corpus_start_idx + self.corpus_chunk_size, len(corpus))

            # Encode chunk of corpus
            sub_corpus_embeddings = self.model.encode_corpus(
                corpus[corpus_start_idx:corpus_end_idx],
                batch_size=self.batch_size,
                show_progress_bar=self.show_progress_bar,
                convert_to_tensor=self.convert_to_tensor
            )

            # Compute similarities using either cosine-similarity or dot product
            cos_scores = self.score_functions[score_function](query_embeddings, sub_corpus_embeddings)
            cos_scores[torch.isnan(cos_scores)] = -1

            # Get top-k values. The cap is the number of corpus entries in this
            # batch, i.e. the column count; indexing a fixed row to measure it
            # would fail when there is only a single query.
            batch_top_k = min(top_k + 1, cos_scores.size(1))
            cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk(
                cos_scores, batch_top_k, dim=1, largest=True, sorted=return_sorted)
            cos_scores_top_k_values = cos_scores_top_k_values.cpu().tolist()
            cos_scores_top_k_idx = cos_scores_top_k_idx.cpu().tolist()

            for query_itr in range(len(query_embeddings)):
                query_id = query_ids[query_itr]
                heap = result_heaps[query_id]
                for sub_corpus_id, score in zip(cos_scores_top_k_idx[query_itr], cos_scores_top_k_values[query_itr]):
                    corpus_id = corpus_ids[corpus_start_idx+sub_corpus_id]
                    if corpus_id != query_id:
                        if len(heap) < top_k:
                            # Push item on the heap
                            heapq.heappush(heap, (score, corpus_id))
                        else:
                            # If item is larger than the smallest in the heap, push it on the heap then pop the smallest element
                            heapq.heappushpop(heap, (score, corpus_id))

        # Collapse chunk hits into documents, keeping each document's best chunk score.
        for qid in result_heaps:
            document_scores = {}
            while result_heaps[qid]:
                score, chunk_id = heapq.heappop(result_heaps[qid])
                document_id = self._document_id(chunk_id)
                if document_id not in document_scores or score > document_scores[document_id]:
                    document_scores[document_id] = score
            self.results[qid] = document_scores

        if return_sorted:
            for qid in self.results:
                self.results[qid] = dict(sorted(self.results[qid].items(), key=lambda item: item[1], reverse=True))

        return self.results


# Short alias for the search class above; AbsTaskRetrieval.evaluate
# instantiates the searcher under this name.
DRES = DenseRetrievalExactSearch

The `AbsTaskRetrieval` class below is taken from MTEB, modified to use a default batch size of 1.

In [None]:
# from mteb

import logging
from time import time
from typing import Dict, List

from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Transformer, WordEmbeddings
import os

from mteb.abstasks.AbsTask import AbsTask

logger = logging.getLogger(__name__)

DRES_METHODS = ["encode_queries", "encode_corpus"]

class AbsTaskRetrieval(AbsTask):
    """
    Base class for retrieval tasks that are scored through BeIR's EvaluateRetrieval.
    Child-classes must populate the following attributes, each keyed by split name:
    self.corpus[split] = Dict[id, Dict[str, str]] #id => dict with document data such as title and text
    self.queries[split] = Dict[id, str] #id => query
    self.relevant_docs[split] = Dict[query_id, Dict[doc_id, score]] (shape as built by NarrativeQARetrieval in this notebook)
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @staticmethod
    def is_dres_compatible(model):
        """Return True when `model` exposes every method listed in DRES_METHODS as a callable."""
        return all(callable(getattr(model, method, None)) for method in DRES_METHODS)

    def evaluate(
        self,
        model,
        split="test",
        batch_size=1,
        corpus_chunk_size=None,
        score_function="cos_sim",
        **kwargs
    ):
        """
        Retrieve with `model` on one split and return the retrieval metrics.

        `model` is wrapped in DRESModel unless it already provides the DRES
        methods, then handed to DRES together with `batch_size`,
        `corpus_chunk_size` (50000 when None) and any extra keyword arguments.

        Returns a flat dict with keys of the form `<metric>_at_<k>` for
        ndcg, map, recall, precision and mrr.
        """
        try:
            from beir.retrieval.evaluation import EvaluateRetrieval
        except ImportError:
            raise Exception("Retrieval tasks require beir package. Please install it with `pip install mteb[beir]`")

        if not self.data_loaded:
            self.load_data()

        corpus = self.corpus[split]
        queries = self.queries[split]
        relevant_docs = self.relevant_docs[split]

        if not self.is_dres_compatible(model):
            model = DRESModel(model)

        if corpus_chunk_size is None:
            corpus_chunk_size = 50000

        searcher = DRES(
            model,
            batch_size=batch_size,
            corpus_chunk_size=corpus_chunk_size,
            **kwargs,
        )

        retriever = EvaluateRetrieval(searcher, score_function=score_function)  # or "cos_sim" or "dot"
        started_at = time()
        results = retriever.retrieve(corpus, queries)
        finished_at = time()
        logger.info("Time taken to retrieve: {:.2f} seconds".format(finished_at - started_at))

        ignore_identical_ids = kwargs.get("ignore_identical_ids", True)
        ndcg, _map, recall, precision = retriever.evaluate(
            relevant_docs, results, retriever.k_values, ignore_identical_ids=ignore_identical_ids
        )
        mrr = retriever.evaluate_custom(relevant_docs, results, retriever.k_values, "mrr")

        # Metric keys arrive as "<Name>@<k>"; re-key them as "<name>_at_<k>".
        scores = {}
        scores.update({f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()})
        scores.update({f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()})
        scores.update({f"recall_at_{k.split('@')[1]}": v for (k, v) in recall.items()})
        scores.update({f"precision_at_{k.split('@')[1]}": v for (k, v) in precision.items()})
        scores.update({f"mrr_at_{k.split('@')[1]}": v for (k, v) in mrr.items()})
        return scores


class DRESModel:
    """
    Adapter exposing the encode_queries / encode_corpus pair that BeIR's
    Dense Retrieval Exact Search (DRES) requires, on top of an MTEB model
    that only offers a single .encode method.
    """

    def __init__(self, model, sep=" ", **kwargs):
        self.model = model
        # Separator placed between a document's title and its text.
        self.sep = sep
        self.use_sbert_model = isinstance(model, SentenceTransformer)

    def encode_queries(self, queries: List[str], batch_size: int, **kwargs):
        """Encode the queries with the wrapped model, logging its truncation behaviour first."""
        if self.use_sbert_model:
            first_module = self.model._first_module()
            if isinstance(first_module, Transformer):
                logger.info(f"Queries will be truncated to {self.model.get_max_seq_length()} tokens.")
            elif isinstance(first_module, WordEmbeddings):
                logger.warning(
                    "Queries will not be truncated. This could lead to memory issues. In that case please lower the batch_size."
                )
        return self.model.encode(queries, batch_size=batch_size, **kwargs)

    def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int, **kwargs):
        """
        Encode documents with the wrapped model.

        `corpus` is either a list of per-document dicts or a single dict of
        parallel "title"/"text" lists. Each document becomes
        "<title><sep><text>" when a title is present, otherwise just its
        text; the result is stripped in both cases.
        """
        if type(corpus) is dict:
            # Column layout: one dict holding parallel lists.
            texts = corpus["text"]
            if "title" in corpus:
                titles = corpus["title"]
                sentences = [(titles[i] + self.sep + texts[i]).strip() for i in range(len(texts))]
            else:
                sentences = [texts[i].strip() for i in range(len(texts))]
        else:
            # Row layout: one dict per document.
            sentences = []
            for doc in corpus:
                if "title" in doc:
                    sentences.append((doc["title"] + self.sep + doc["text"]).strip())
                else:
                    sentences.append(doc["text"].strip())
        return self.model.encode(sentences, batch_size=batch_size, **kwargs)


In [None]:
from collections import defaultdict
class NarrativeQARetrieval(AbsTaskRetrieval):
    """NarrativeQA as a retrieval task over a word-chunked corpus.

    Every question is a query whose single relevant document is the document
    of the same dataset row. Documents are split into chunks of
    ``CHUNKING_SIZE`` whitespace-separated words, and the corpus holds those
    chunks under ids of the form ``<document_id>_chunk_<n>``.
    """

    _EVAL_SPLIT = 'test'

    @property
    def description(self):
        return {
            'name': 'NarrativeQARetrieval',
            'hf_hub_name': 'narrativeqa',
            'reference': 'https://metatext.io/datasets/narrativeqa',
            "description": (
                "NarrativeQA is a dataset for the task of question answering on long narratives. It consists of "
                "realistic QA instances collected from literature (fiction and non-fiction) and movie scripts. "
            ),
            "type": "Retrieval",
            "category": "s2p",
            "eval_splits": ["test"],
            "eval_langs": ["en"],
            "main_score": "ndcg_at_10",
        }

    @staticmethod
    def _chunk_document(doc_text, chunk_length):
        """Yield consecutive chunks of at most `chunk_length` whitespace-separated words."""
        words = doc_text.split()
        for start in range(0, len(words), chunk_length):
            yield ' '.join(words[start:start + chunk_length])

    def load_data(self, **kwargs):
        """Load the evaluation split and build queries, relevance labels and the chunked corpus.

        Does nothing when the data is already loaded.

        Raises:
            ValueError: if ``CHUNKING_SIZE`` is not a positive number of words.
        """
        if self.data_loaded:
            return

        # Change chunk size length here (words)
        chunk_length = CHUNKING_SIZE
        if chunk_length <= 0:
            # Fail before the dataset download; a negative size would otherwise
            # silently produce an empty corpus.
            raise ValueError(f"CHUNKING_SIZE must be a positive number of words, got {chunk_length!r}")

        data = datasets.load_dataset(self.description['hf_hub_name'], split=self._EVAL_SPLIT, trust_remote_code=True)

        # Single pass over the dataset: the query id is the row index, and the
        # document id is cast to str everywhere so relevance labels and corpus
        # keys always agree.
        queries = {}
        relevant_docs = {}
        documents = {}
        for row_index, row in enumerate(data):
            query_id = str(row_index)
            document = row['document']
            document_id = str(document['id'])
            queries[query_id] = row['question']['text']
            relevant_docs[query_id] = {document_id: 1}
            # Several rows may share a document; it is stored once under its id.
            documents[document_id] = document['text']

        # Chunk the documents
        chunked_corpus = {}
        for document_id, doc_text in documents.items():
            for chunk_index, chunk in enumerate(self._chunk_document(doc_text, chunk_length)):
                chunked_corpus[f'{document_id}_chunk_{chunk_index}'] = {'text': chunk}

        self.queries = {self._EVAL_SPLIT: queries}
        self.relevant_docs = {self._EVAL_SPLIT: relevant_docs}
        self.corpus = {self._EVAL_SPLIT: chunked_corpus}

        # Print the count of queries, corpus and relevant_docs
        print(f"{self._EVAL_SPLIT} queries: {len(self.queries[self._EVAL_SPLIT])}")
        print(f"{self._EVAL_SPLIT} corpus: {len(self.corpus[self._EVAL_SPLIT])}")
        print(f"{self._EVAL_SPLIT} relevant docs: {len(self.relevant_docs[self._EVAL_SPLIT])}")

        self.data_loaded = True

In [None]:
# Load the embedding model chosen in the configuration cell.
model = SentenceTransformer(model_name, trust_remote_code=True)
# Select the task by name; the string equals the 'name' entry of
# NarrativeQARetrieval.description (presumably how MTEB resolves it — confirm).
evaluation = MTEB(tasks=["NarrativeQARetrieval"])
# Run the evaluation, passing an output folder keyed by chunk size and model name.
results = evaluation.run(model, output_folder=f"results/retrieval/narrativeqa/chunking/{CHUNKING_SIZE}/{model_name}")