In [1]:
import os

# Hide every CUDA device so torch runs on the CPU only, and cap the
# numeric backends at one thread each. All variables are set before
# `import torch` below.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ.update({
    "MKL_NUM_THREADS": "1",
    "NUMEXPR_NUM_THREADS": "1",
    "OMP_NUM_THREADS": "1",
})

import torch

# Also limit torch's own thread count to a single thread.
torch.set_num_threads(1)

In [34]:
from xtr.datasets import BEIR, BEIRDataset, LoTTE, LoTTEDataset
from xtr.config import XTRConfig, XTRModel, XTRScaNNIndexConfig, XTRBruteForceIndexConfig, XTRFAISSIndexConfig
from xtr.utils import xtr_tracker, canonical_index_name
from xtr.modeling.xtr import XTR

import json

In [18]:
def xtr_eval_latency(dataset, index_config, document_top_k, token_top_k):
    """Run XTR retrieval over a dataset's queries and evaluate the rankings.

    Args:
        dataset: Dataset object; this function reads its ``collection`` and
            ``queries`` attributes and calls its ``eval`` method.
        index_config: Index configuration, used both to derive the index
            name and to build the ``XTRConfig``.
        document_top_k: Forwarded to ``retrieve_docs`` as ``document_top_k``.
        token_top_k: Forwarded to ``retrieve_docs`` as ``token_top_k``.

    Returns:
        A ``(tracker, metrics)`` tuple: the tracker that was passed to
        ``retrieve_docs`` and the result of ``dataset.eval(rankings)``.
    """
    name = canonical_index_name(dataset=dataset, index_config=index_config)

    # override=False -- presumably reuses an index that already exists under
    # this name instead of rebuilding it (TODO confirm against XTRConfig).
    xtr_config = XTRConfig(
        index_name=name,
        model=XTRModel.BASE_EN,
        index_config=index_config,
        override=False,
    )
    # The model is pinned to the CPU explicitly.
    model = XTR(
        config=xtr_config,
        collection=dataset.collection,
        device=torch.device("cpu"),
    )

    latency_tracker = xtr_tracker(name=name)
    ranked = model.retrieve_docs(
        dataset.queries,
        document_top_k=document_top_k,
        token_top_k=token_top_k,
        tracker=latency_tracker,
    )
    metrics = dataset.eval(ranked)
    return latency_tracker, metrics

In [29]:
def xtr_run_configuration(dataset, index_config, document_top_k, token_top_k):
    """Evaluate one (dataset, index, top-k) setting and package the outcome.

    Delegates the actual run to ``xtr_eval_latency`` and returns a dict with
    three entries:

    * ``"config"``  -- the dataset name, index name and both top-k values,
    * ``"metrics"`` -- the evaluation result returned by ``xtr_eval_latency``,
    * ``"tracker"`` -- the tracker converted via ``tracker.as_dict()``.
    """
    run_tracker, run_metrics = xtr_eval_latency(
        dataset, index_config, document_top_k, token_top_k
    )

    summary = {}
    summary["config"] = {
        "dataset": dataset.name,
        "index": index_config.name,
        "document_top_k": document_top_k,
        "token_top_k": token_top_k,
    }
    summary["metrics"] = run_metrics
    summary["tracker"] = run_tracker.as_dict()
    return summary

In [32]:
def xtr_run_configurations(datasets, index_configs, document_top_k, token_top_k_values, override=False,
                           results_path="results.json"):
    """Run every (dataset, index_config, token_top_k) combination and checkpoint results.

    After each individual run the accumulated list of results is serialised
    as JSON to ``results_path``, so completed runs survive a later failure.

    Args:
        datasets: Iterable of datasets to evaluate.
        index_configs: Iterable of index configurations to evaluate.
        document_top_k: ``document_top_k`` value shared by all runs.
        token_top_k_values: Iterable of ``token_top_k`` values to sweep.
        override: When False (the default), refuse to run if ``results_path``
            already exists so that earlier results are not overwritten.
        results_path: Destination of the JSON results file. Defaults to
            ``"results.json"`` in the current working directory.

    Raises:
        AssertionError: If ``results_path`` exists and ``override`` is False.
    """
    if os.path.exists(results_path) and not override:
        raise AssertionError(
            f"{results_path} already exists; pass override=True to overwrite it."
        )
    results = []
    # The checkpoint is written to a sibling temp file and then swapped in, so
    # an interruption or serialisation error mid-write cannot truncate the
    # results saved by previous iterations.
    tmp_path = f"{results_path}.tmp"
    for dataset in datasets:
        for index_config in index_configs:
            for token_top_k in token_top_k_values:
                results.append(xtr_run_configuration(dataset, index_config,
                                                     document_top_k=document_top_k,
                                                     token_top_k=token_top_k))
                with open(tmp_path, "w") as file:
                    json.dump(results, file)
                os.replace(tmp_path, results_path)

In [36]:
# Sweep definition: one dataset, two index types, two token_top_k settings.
DATASETS = [
    LoTTEDataset(dataset=LoTTE.LIFESTYLE, datasplit="test"),
]
INDEX_CONFIGS = [
    XTRScaNNIndexConfig(),
    XTRFAISSIndexConfig(),
]
TOKEN_TOP_K_VALUES = [1_000, 40_000]

# Runs every combination; results are written to results.json as they finish.
xtr_run_configurations(
    datasets=DATASETS,
    index_configs=INDEX_CONFIGS,
    document_top_k=100,
    token_top_k_values=TOKEN_TOP_K_VALUES,
)

#> Loading collection from /lfs/1/scheerer/datasets/lotte/lotte/lifestyle/test/collection.tsv ...
0M 
#> Loading the queries from /lfs/1/scheerer/datasets/lotte/lotte/lifestyle/test/questions.search.tsv ...
#> Got 661 queries. All QIDs are unique.
Loading existing index from /future/u/scheerer/home/data/xtr-eval/indexes/LoTTE.LIFESTYLE.search.split=test.XTRIndexType.SCANN.


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 661/661 [02:21<00:00,  4.67it/s]


Loading existing index from /future/u/scheerer/home/data/xtr-eval/indexes/LoTTE.LIFESTYLE.search.split=test.XTRIndexType.SCANN.


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 661/661 [21:24<00:00,  1.94s/it]


Loading existing index from /future/u/scheerer/home/data/xtr-eval/indexes/LoTTE.LIFESTYLE.search.split=test.XTRIndexType.FAISS.


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 661/661 [12:01<00:00,  1.09s/it]


Loading existing index from /future/u/scheerer/home/data/xtr-eval/indexes/LoTTE.LIFESTYLE.search.split=test.XTRIndexType.FAISS.


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 661/661 [27:03<00:00,  2.46s/it]
