In [1]:
# Environment setup for a CPU-only, single-threaded run. The variables are set
# before `import torch` so they are already in place when torch (and the
# CUDA/OpenMP runtimes it brings in) start up.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""  # expose no GPUs -> CPU-only execution
os.environ["OMP_NUM_THREADS"] = "1"  # a single OpenMP worker thread
os.environ["OMP_WAIT_POLICY"] = "PASSIVE"  # idle OpenMP threads yield rather than spin

import torch
torch.set_num_threads(1)  # limit torch intra-op parallelism to one thread

# Imported for its side effects: importing it produced the
# "#> Running WARP Setup Code." line in this cell's output.
import colbert.warp.setup

from colbert.warp.config import WARPRunConfig
from colbert.warp.searcher import WARPSearcher
from colbert.warp.data.queries import WARPQueries

# BEIR datasets evaluated in this notebook. Their order fixes the iteration
# order of the evaluation loop and the column order of the results table.
DATASETS = ["nfcorpus", "scifact", "scidocs", "fiqa", "webis-touche2020", "quora"]

def config_for_dataset(dataset, nranks=4, nbits=4):
    """Build the WARP run configuration for the test split of one BEIR dataset.

    The configuration uses the unquantized model (``optim=None``).

    Args:
        dataset: BEIR dataset name; must be one of ``DATASETS``.
        nranks: forwarded to ``WARPRunConfig`` (defaults to 4, as before).
        nbits: forwarded to ``WARPRunConfig`` (defaults to 4, as before).

    Returns:
        A ``WARPRunConfig`` with ``dataset="beir"``, ``collection=dataset`` and
        ``datasplit="test"``.

    Raises:
        ValueError: if ``dataset`` is not listed in ``DATASETS``.
    """
    # Explicit check rather than ``assert``: assertions are stripped under
    # ``python -O``, and this message tells the caller which names are valid.
    if dataset not in DATASETS:
        raise ValueError(f"Unknown dataset {dataset!r}; expected one of {DATASETS}")
    return WARPRunConfig(
        nranks=nranks,
        dataset="beir",
        collection=dataset,
        datasplit="test",
        nbits=nbits,
        optim=None,  # None selects the unquantized model
    )

def prepare_beir_dataset(dataset):
    """Return ``(queries, searcher)`` for one BEIR dataset's test split.

    Both objects are built from the same ``config_for_dataset`` configuration;
    the queries are constructed first, then the searcher.
    """
    run_config = config_for_dataset(dataset=dataset)
    # Tuple elements are evaluated left to right, preserving construction order.
    return WARPQueries(run_config), WARPSearcher(run_config)

No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'


#> Running WARP Setup Code.


In [2]:
# Recall@100 (in %) on a subset of the BEIR benchmarks, as reported in the XTR paper.
# Source: https://arxiv.org/pdf/2304.01982 (Page 17, Table D.1)
# Every score tuple below is in the column order NF, SF, SD, FQ, TO, QU.
xtr_paper_stats = {
    setting: {
        model: dict(zip(("NF", "SF", "SD", "FQ", "TO", "QU"), scores))
        for model, scores in models.items()
    }
    for setting, models in {
        "One Retriever per Domain": {
            "GenQ":            (28.0, 89.3, 33.2, 61.8, 45.1, 98.9),
            "PTR_retriever":   (30.6, 91.8, 41.6, 76.5, 47.5, 99.6),
        },
        "One Retriever for All": {
            "BM25":            (25.0, 90.8, 35.6, 53.9, 53.8, 97.3),
            "ColBERT":         (25.4, 87.8, 34.4, 60.3, 43.9, 98.9),
            "GTR_base":        (27.5, 87.2, 34.0, 67.0, 44.3, 99.6),
            "T5-ColBERT_base": (27.6, 91.3, 34.2, 63.0, 49.9, 97.9),
            "XTR_base":        (28.0, 90.5, 34.8, 63.5, 50.8, 98.9),
            "GTR_xxl":         (30.0, 90.0, 36.6, 78.0, 46.6, 99.7),
            "T5-ColBERT_xxl":  (29.0, 94.6, 38.5, 72.5, 50.1, 99.1),
            "XTR_xxl":         (30.7, 95.0, 39.4, 73.0, 52.7, 99.3),
        },
    }.items()
}

# Models listed below the separator line in the results table.
xxl_models = ["GTR_xxl", "T5-ColBERT_xxl", "XTR_xxl"]

# Dataset name -> column abbreviation used in xtr_paper_stats.
dataset_xtr_map = {
    "nfcorpus": "NF",
    "scifact": "SF",
    "scidocs": "SD",
    "fiqa": "FQ",
    "webis-touche2020": "TO",
    "quora": "QU",
}

# Dataset name -> display name for the results table header.
dataset_to_name = {
    "nfcorpus": "NFCORPUS",
    "scifact": "SciFact",
    "scidocs": "SCIDOCS",
    "fiqa": "FiQA-2018",
    "webis-touche2020": "Touche-2020",
    "quora": "Quora",
}

In [3]:
from tqdm import tqdm

# For each dataset: build its queries and searcher, retrieve the top k=100
# results for every test query, and evaluate those rankings against the
# dataset's qrels at k=100. The per-dataset evaluation result is stored in
# `metrics` under the dataset name.
metrics = {}
for dataset in tqdm(DATASETS):
    queries, searcher = prepare_beir_dataset(dataset=dataset)
    rankings = searcher.search_all(
        queries,
        k=100,
        batched=False,
        show_progress=False,
    )
    metrics[dataset] = rankings.evaluate(queries.qrels, k=100)

  0%|                                                                                                                                           | 0/6 [00:00<?, ?it/s]

[Sep 01, 01:35:29] #> Loading the queries from /lfs/1/scheerer/datasets/beir/datasets/nfcorpus/questions.test.tsv ...
[Sep 01, 01:35:29] #> Got 323 queries. All QIDs are unique.

[Sep 01, 01:35:29] #> Loading collection...
0M 




[Sep 01, 01:35:31] #> Loading buckets...
[Sep 01, 01:35:31] #> Loading codec...
[Sep 01, 01:35:31] Loading segmented_lookup_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
[Sep 01, 01:35:31] #> Loading repacked residuals...
[Sep 01, 01:35:31] Loading precompute_topk_centroids_cpp extension (set WARP_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
[Sep 01, 01:35:31] Loading decompress_centroid_embeds_strided_repacked_cpp extension (set WARP_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
[Sep 01, 01:35:32] Loading compute_candidate_scores_cpp extension (set WARP_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
nprobe 12 t_prime 3000
#> Loading collection_map found in /lfs/1/scheerer/datasets/beir/datasets/nfcorpus/collection.tsv


  0%|          | 0/3633 [00:00<?, ?it/s]

 17%|█████████████████████▊                                                                                                             | 1/6 [00:38<03:10, 38.08s/it]

[Sep 01, 01:36:07] #> Loading the queries from /lfs/1/scheerer/datasets/beir/datasets/scifact/questions.test.tsv ...
[Sep 01, 01:36:07] #> Got 300 queries. All QIDs are unique.

[Sep 01, 01:36:07] #> Loading collection...
0M 




[Sep 01, 01:36:08] #> Loading buckets...
[Sep 01, 01:36:08] #> Loading codec...
[Sep 01, 01:36:08] #> Loading repacked residuals...
nprobe 12 t_prime 5000
#> Loading collection_map found in /lfs/1/scheerer/datasets/beir/datasets/scifact/collection.tsv


  0%|          | 0/5183 [00:00<?, ?it/s]

 33%|███████████████████████████████████████████▋                                                                                       | 2/6 [01:14<02:29, 37.26s/it]

[Sep 01, 01:36:43] #> Loading the queries from /lfs/1/scheerer/datasets/beir/datasets/scidocs/questions.test.tsv ...
[Sep 01, 01:36:44] #> Got 1000 queries. All QIDs are unique.

[Sep 01, 01:36:44] #> Loading collection...
0M 
[Sep 01, 01:36:45] #> Loading buckets...
[Sep 01, 01:36:45] #> Loading codec...
[Sep 01, 01:36:45] #> Loading repacked residuals...
nprobe 12 t_prime 8000
#> Loading collection_map found in /lfs/1/scheerer/datasets/beir/datasets/scidocs/collection.tsv


  0%|          | 0/25657 [00:00<?, ?it/s]

 50%|█████████████████████████████████████████████████████████████████▌                                                                 | 3/6 [03:13<03:42, 74.24s/it]

[Sep 01, 01:38:42] #> Loading the queries from /lfs/1/scheerer/datasets/beir/datasets/fiqa/questions.test.tsv ...
[Sep 01, 01:38:42] #> Got 648 queries. All QIDs are unique.

[Sep 01, 01:38:42] #> Loading collection...
0M 
[Sep 01, 01:38:44] #> Loading buckets...
[Sep 01, 01:38:44] #> Loading codec...
[Sep 01, 01:38:45] #> Loading repacked residuals...
nprobe 12 t_prime 10000
#> Loading collection_map found in /lfs/1/scheerer/datasets/beir/datasets/fiqa/collection.tsv


  0%|          | 0/57638 [00:00<?, ?it/s]

 67%|███████████████████████████████████████████████████████████████████████████████████████▎                                           | 4/6 [04:32<02:32, 76.16s/it]

[Sep 01, 01:40:01] #> Loading the queries from /lfs/1/scheerer/datasets/beir/datasets/webis-touche2020/questions.test.tsv ...
[Sep 01, 01:40:01] #> Got 49 queries. All QIDs are unique.

[Sep 01, 01:40:01] #> Loading collection...
0M 
[Sep 01, 01:40:05] #> Loading buckets...
[Sep 01, 01:40:05] #> Loading codec...
[Sep 01, 01:40:08] #> Loading repacked residuals...
nprobe 18 t_prime 50000
#> Loading collection_map found in /lfs/1/scheerer/datasets/beir/datasets/webis-touche2020/collection.tsv


  0%|          | 0/382545 [00:00<?, ?it/s]

 83%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                     | 5/6 [04:56<00:57, 57.31s/it]

[Sep 01, 01:40:25] #> Loading the queries from /lfs/1/scheerer/datasets/beir/datasets/quora/questions.test.tsv ...
[Sep 01, 01:40:25] #> Got 10000 queries. All QIDs are unique.

[Sep 01, 01:40:25] #> Loading collection...
0M 
[Sep 01, 01:40:27] #> Loading buckets...
[Sep 01, 01:40:27] #> Loading codec...
[Sep 01, 01:40:27] #> Loading repacked residuals...
nprobe 12 t_prime 10000
#> Loading collection_map found in /lfs/1/scheerer/datasets/beir/datasets/quora/collection.tsv


  0%|          | 0/522931 [00:00<?, ?it/s]

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [25:31<00:00, 255.31s/it]


In [20]:
from tabulate import tabulate, SEPARATING_LINE

# Table header: empty label column, one display name per dataset, then the mean.
headers = ["", *(dataset_to_name[name] for name in DATASETS), "Avg."]

def append_model_metrics(data, model_names, setting="One Retriever for All"):
    """Append one table row per model with its paper-reported Recall@100 scores.

    Each row is ``[model_name, score_1, ..., score_n, avg]``. The scores follow
    the order of ``DATASETS``, and ``avg`` is their mean rounded to one decimal.

    Args:
        data: list of table rows; extended in place.
        model_names: keys of ``xtr_paper_stats[setting]`` to add, in order.
        setting: which group of ``xtr_paper_stats`` to read the scores from.
            It defaults to the "One Retriever for All" group; pass
            "One Retriever per Domain" to add e.g. GenQ / PTR_retriever rows.

    Raises:
        KeyError: if ``setting`` or a model name is not in ``xtr_paper_stats``.
    """
    for model_name in model_names:
        results = xtr_paper_stats[setting][model_name]
        scores = [results[dataset_xtr_map[dataset]] for dataset in DATASETS]
        # Average over the scores actually collected, not a separate global.
        avg = round(sum(scores) / len(scores), 1)
        data.append([model_name] + scores + [avg])

# Rows: paper-reported non-XXL models, then this notebook's WARP run, then a
# separator, then the paper's XXL models.
data = []
append_model_metrics(
    data,
    [name for name in xtr_paper_stats["One Retriever for All"] if name not in xxl_models],
)

# WARP recall@100 as a percentage, rounded to the paper's one-decimal precision.
warp_scores = [
    round(100 * metrics[dataset]["recall@100"], 1) for dataset in DATASETS
]
warp_avg = round(sum(warp_scores) / len(DATASETS), 1)
data.append(["XTR_base / WARP", *warp_scores, warp_avg])

data.append([SEPARATING_LINE])
append_model_metrics(data, xxl_models)

print(tabulate(data, headers, floatfmt=".1f"))

                   NFCORPUS    SciFact    SCIDOCS    FiQA-2018    Touche-2020    Quora    Avg.
---------------  ----------  ---------  ---------  -----------  -------------  -------  ------
BM25                   25.0       90.8       35.6         53.9           53.8     97.3    59.4
ColBERT                25.4       87.8       34.4         60.3           43.9     98.9    58.4
GTR_base               27.5       87.2       34.0         67.0           44.3     99.6    59.9
T5-ColBERT_base        27.6       91.3       34.2         63.0           49.9     97.9    60.6
XTR_base               28.0       90.5       34.8         63.5           50.8     98.9    61.1
XTR_base / WARP        28.3       92.4       36.4         60.8           51.1     98.7    61.3
---------------  ----------  ---------  ---------  -----------  -------------  -------  ------
GTR_xxl                30.0       90.0       36.6         78.0           46.6     99.7    63.5
T5-ColBERT_xxl         29.0       94.6       38.5 