In [None]:
import random
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from datasets import load_dataset, Dataset

# Load a model
model = SentenceTransformer('../training/models/german-nq-granite-embedding-107m-multilingual-exclude-pooling-prompts/checkpoint-4560')
#model = SentenceTransformer("ibm-granite/granite-embedding-107m-multilingual")



In [63]:
revosax = load_dataset("csv", data_files="../data/training/training-data.csv", split="train")
revosax = revosax.rename_column("result", "query").rename_column("chunk", "answer")
revosax = revosax.select_columns(['query', 'answer'])
revosax = revosax.train_test_split(test_size=0.2, seed=12)

train_dataset: Dataset = revosax["train"]
eval_dataset: Dataset = revosax["test"]

In [57]:
eval_dataset

Dataset({
    features: ['query', 'answer'],
    num_rows: 12158
})

In [53]:
#test = eval_dataset.map(lambda x: x["query"])

queries = {str(i): q for i, q in enumerate(eval_dataset["query"])}
corpus  = {str(i): a for i, a in enumerate(eval_dataset["answer"])}

relevant_docs = {qid: {qid} for qid in queries.keys()}


In [14]:
# For this dataset, we want to concatenate the title and texts for the corpus
corpus = corpus.map(lambda x: {'text': x['title'] + " " + x['text']}, remove_columns=['title'])

# Shrink the corpus size heavily to only the relevant documents + 30,000 random documents
required_corpus_ids = set(map(str, relevant_docs_data["corpus-id"]))
required_corpus_ids |= set(random.sample(corpus["_id"], k=30_000))
corpus = corpus.filter(lambda x: x["_id"] in required_corpus_ids)

# Convert the datasets to dictionaries
corpus = dict(zip(corpus["_id"], corpus["text"]))  # Our corpus (cid => document)
queries = dict(zip(queries["_id"], queries["text"]))  # Our queries (qid => question)
relevant_docs = {}  # Query ID to relevant documents (qid => set([relevant_cids])
for qid, corpus_ids in zip(relevant_docs_data["query-id"], relevant_docs_data["corpus-id"]):
    qid = str(qid)
    corpus_ids = str(corpus_ids)
    if qid not in relevant_docs:
        relevant_docs[qid] = set()
    relevant_docs[qid].add(corpus_ids)


Filter: 100%|██████████| 382545/382545 [00:00<00:00, 684215.81 examples/s]


In [61]:

# Given queries, a corpus and a mapping with relevant documents, the InformationRetrievalEvaluator computes different IR metrics.
ir_evaluator = InformationRetrievalEvaluator(
    queries=queries,
    corpus=corpus,
    relevant_docs=relevant_docs,
    name="revosax-test-eval",
)
results = ir_evaluator(model)

print(ir_evaluator.primary_metric)
# => "BeIR-touche2020-test_cosine_map@100"
print(results[ir_evaluator.primary_metric])
# => 0.29335196224364596

revosax-test-eval_cosine_ndcg@10
0.7573468749744551


In [62]:
results

{'revosax-test-eval_cosine_accuracy@1': 0.5786313538410923,
 'revosax-test-eval_cosine_accuracy@3': 0.8112354005593025,
 'revosax-test-eval_cosine_accuracy@5': 0.8670011515051818,
 'revosax-test-eval_cosine_accuracy@10': 0.9207928935680211,
 'revosax-test-eval_cosine_precision@1': 0.5786313538410923,
 'revosax-test-eval_cosine_precision@3': 0.2704118001864342,
 'revosax-test-eval_cosine_precision@5': 0.17340023030103638,
 'revosax-test-eval_cosine_precision@10': 0.09207928935680212,
 'revosax-test-eval_cosine_recall@1': 0.5786313538410923,
 'revosax-test-eval_cosine_recall@3': 0.8112354005593025,
 'revosax-test-eval_cosine_recall@5': 0.8670011515051818,
 'revosax-test-eval_cosine_recall@10': 0.9207928935680211,
 'revosax-test-eval_cosine_ndcg@10': 0.7573468749744551,
 'revosax-test-eval_cosine_mrr@10': 0.7041161088002624,
 'revosax-test-eval_cosine_map@100': 0.7074822036268906}

In [None]:
{'revosax-test-eval_cosine_accuracy@1': 0.673877282447771,
 'revosax-test-eval_cosine_accuracy@3': 0.8992432965948347,
 'revosax-test-eval_cosine_accuracy@5': 0.9430827438723475,
 'revosax-test-eval_cosine_accuracy@10': 0.9742556341503537,
 'revosax-test-eval_cosine_precision@1': 0.673877282447771,
 'revosax-test-eval_cosine_precision@3': 0.29974776553161153,
 'revosax-test-eval_cosine_precision@5': 0.1886165487744695,
 'revosax-test-eval_cosine_precision@10': 0.09742556341503539,
 'revosax-test-eval_cosine_recall@1': 0.673877282447771,
 'revosax-test-eval_cosine_recall@3': 0.8992432965948347,
 'revosax-test-eval_cosine_recall@5': 0.9430827438723475,
 'revosax-test-eval_cosine_recall@10': 0.9742556341503537,
 'revosax-test-eval_cosine_ndcg@10': 0.8373392143352189,
 'revosax-test-eval_cosine_mrr@10': 0.7919388435859073,
 'revosax-test-eval_cosine_map@100': 0.7932480849366578}