In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OMP_WAIT_POLICY"] = "PASSIVE"

import torch
torch.set_num_threads(1)

import colbert.warp.setup

from colbert.warp.config import WARPRunConfig
from colbert.warp.searcher import WARPSearcher
from colbert.warp.data.queries import WARPQueries

DATASETS = ["lifestyle", "writing", "recreation", "technology", "science", "pooled"]

def config_for_dataset(dataset):
    """Build the WARP run configuration for one LoTTE collection.

    Args:
        dataset: Name of the LoTTE collection; must be an entry of ``DATASETS``.

    Returns:
        A ``WARPRunConfig`` targeting the ``"lotte"`` dataset with the given
        collection, ``"search"`` queries on the ``"test"`` split, ``nranks=4``,
        ``nbits=4`` and ``optim=None``.

    Raises:
        AssertionError: If ``dataset`` is not listed in ``DATASETS``.
    """
    assert dataset in DATASETS
    # optim=None: per the original author's note this keeps the unquantized
    # model -- TODO confirm against WARPRunConfig.
    return WARPRunConfig(
        nranks=4,
        dataset="lotte",
        collection=dataset,
        type_="search",
        datasplit="test",
        nbits=4,
        optim=None,
    )

def prepare_lotte_dataset(dataset):
    """Create the query set and searcher for one LoTTE collection.

    Args:
        dataset: Name of the LoTTE collection, forwarded to
            ``config_for_dataset``.

    Returns:
        A ``(queries, searcher)`` tuple: a ``WARPQueries`` and a
        ``WARPSearcher``, both built from the same run configuration.
    """
    run_config = config_for_dataset(dataset=dataset)
    # Tuple elements are evaluated left to right, so the queries are still
    # constructed before the searcher.
    return WARPQueries(run_config), WARPSearcher(run_config)

No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'


#> Running WARP Setup Code.


In [18]:
# Reference results copied from the XTR paper.
# Source: https://arxiv.org/pdf/2304.01982 (Page 6, Table 2)
# Metric: Success@5 on the LoTTE Search benchmarks.
# Layout: model name -> {title-cased dataset name -> score}.
xtr_paper_stats = {
    "BM25": dict(
        Writing=60.3, Recreation=56.5, Science=32.7, Technology=41.8, Lifestyle=63.8, Pooled=48.3
    ),
    "ColBERT": dict(
        Writing=74.7, Recreation=68.5, Science=53.6, Technology=61.9, Lifestyle=80.2, Pooled=67.3
    ),
    "GTR_base": dict(
        Writing=74.1, Recreation=65.7, Science=49.8, Technology=58.1, Lifestyle=82.0, Pooled=65.0
    ),
    "XTR_base": dict(
        Writing=77.0, Recreation=69.4, Science=54.9, Technology=63.2, Lifestyle=82.1, Pooled=69.0
    ),
    "Splade_v2": dict(
        Writing=77.1, Recreation=69.0, Science=55.4, Technology=62.4, Lifestyle=82.3, Pooled=68.9
    ),
    "ColBERT_v2": dict(
        Writing=80.1, Recreation=72.3, Science=56.7, Technology=66.1, Lifestyle=84.7, Pooled=71.6
    ),
    "GTR_xxl": dict(
        Writing=83.9, Recreation=78.0, Science=60.0, Technology=69.5, Lifestyle=87.4, Pooled=76.0
    ),
    "XTR_xxl": dict(
        Writing=83.3, Recreation=79.3, Science=60.8, Technology=73.7, Lifestyle=89.1, Pooled=77.3
    ),
}

# Model groups that are listed in their own sections of the comparison table.
# Both stay plain lists because they are concatenated with `+` further down.
cross_encoder_hard_neg_models = ["Splade_v2", "ColBERT_v2"]
xxl_models = ["GTR_xxl", "XTR_xxl"]

In [3]:
from tqdm import tqdm

# Single source of truth for the cutoff: the retrieval depth passed to
# `search_all` and the cutoff passed to `evaluate` must be the same value,
# so both read this constant instead of repeating the literal.
# NOTE(review): later cells read the "Success@5" key; presumably that key
# follows this value, so changing it means updating those cells -- confirm.
EVAL_K = 5

# Maps each collection name in DATASETS to the result of `rankings.evaluate`.
metrics = {}
for dataset in tqdm(DATASETS):
    # Build the queries and the searcher for this collection.
    queries, searcher = prepare_lotte_dataset(dataset=dataset)
    # Retrieve the top EVAL_K results per query, unbatched and without an
    # inner progress bar (the outer tqdm already reports progress).
    rankings = searcher.search_all(queries, k=EVAL_K, batched=False, show_progress=False)
    # Score the rankings against the relevance judgements of the query set.
    metrics[dataset] = rankings.evaluate(queries.qrels, k=EVAL_K)

  0%|                                                                                                                                           | 0/6 [00:00<?, ?it/s]

[Sep 01, 01:45:36] #> Loading the queries from /lfs/1/scheerer/datasets/lotte/lotte/lifestyle/test/questions.search.tsv ...
[Sep 01, 01:45:36] #> Got 661 queries. All QIDs are unique.

[Sep 01, 01:45:36] #> Loading collection...
0M 




[Sep 01, 01:45:39] #> Loading buckets...
[Sep 01, 01:45:39] #> Loading codec...
[Sep 01, 01:45:39] Loading segmented_lookup_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
[Sep 01, 01:45:40] #> Loading repacked residuals...
[Sep 01, 01:45:40] Loading precompute_topk_centroids_cpp extension (set WARP_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
[Sep 01, 01:45:41] Loading decompress_centroid_embeds_strided_repacked_cpp extension (set WARP_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
[Sep 01, 01:45:41] Loading compute_candidate_scores_cpp extension (set WARP_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
nprobe 12 t_prime 10000


 17%|█████████████████████▊                                                                                                             | 1/6 [01:25<07:05, 85.16s/it]

[Sep 01, 01:47:01] #> Loading the queries from /lfs/1/scheerer/datasets/lotte/lotte/writing/test/questions.search.tsv ...
[Sep 01, 01:47:01] #> Got 1071 queries. All QIDs are unique.

[Sep 01, 01:47:01] #> Loading collection...
0M 




[Sep 01, 01:47:04] #> Loading buckets...
[Sep 01, 01:47:04] #> Loading codec...
[Sep 01, 01:47:05] #> Loading repacked residuals...
nprobe 12 t_prime 10000


 33%|███████████████████████████████████████████▎                                                                                      | 2/6 [03:37<07:32, 113.20s/it]

[Sep 01, 01:49:14] #> Loading the queries from /lfs/1/scheerer/datasets/lotte/lotte/recreation/test/questions.search.tsv ...
[Sep 01, 01:49:14] #> Got 924 queries. All QIDs are unique.

[Sep 01, 01:49:14] #> Loading collection...
0M 




[Sep 01, 01:49:16] #> Loading buckets...
[Sep 01, 01:49:16] #> Loading codec...
[Sep 01, 01:49:17] #> Loading repacked residuals...
nprobe 12 t_prime 10000


 50%|█████████████████████████████████████████████████████████████████                                                                 | 3/6 [05:36<05:46, 115.51s/it]

[Sep 01, 01:51:12] #> Loading the queries from /lfs/1/scheerer/datasets/lotte/lotte/technology/test/questions.search.tsv ...
[Sep 01, 01:51:12] #> Got 596 queries. All QIDs are unique.

[Sep 01, 01:51:12] #> Loading collection...
0M 




[Sep 01, 01:51:17] #> Loading buckets...
[Sep 01, 01:51:17] #> Loading codec...
[Sep 01, 01:51:21] #> Loading repacked residuals...
nprobe 18 t_prime 50000


 67%|██████████████████████████████████████████████████████████████████████████████████████▋                                           | 4/6 [07:38<03:56, 118.19s/it]

[Sep 01, 01:53:15] #> Loading the queries from /lfs/1/scheerer/datasets/lotte/lotte/science/test/questions.search.tsv ...
[Sep 01, 01:53:15] #> Got 617 queries. All QIDs are unique.

[Sep 01, 01:53:15] #> Loading collection...
0M 1M 




[Sep 01, 01:53:25] #> Loading buckets...
[Sep 01, 01:53:25] #> Loading codec...
[Sep 01, 01:53:41] #> Loading repacked residuals...
nprobe 24 t_prime 50000


 83%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                     | 5/6 [11:04<02:29, 149.99s/it]

[Sep 01, 01:56:41] #> Loading the queries from /lfs/1/scheerer/datasets/lotte/lotte/pooled/test/questions.search.tsv ...
[Sep 01, 01:56:41] #> Got 3869 queries. All QIDs are unique.

[Sep 01, 01:56:41] #> Loading collection...
0M 1M 2M 




[Sep 01, 01:56:56] #> Loading buckets...
[Sep 01, 01:56:56] #> Loading codec...
[Sep 01, 01:57:19] #> Loading repacked residuals...
nprobe 24 t_prime 100000


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [33:12<00:00, 332.06s/it]


In [5]:
# Show the raw per-dataset evaluation results collected by the search loop.
metrics

{'lifestyle': {'provenance': {'query_type': 'search', 'dataset': 'lifestyle'},
  'metrics': {'Success@5': 0.8229954614220878}},
 'writing': {'provenance': {'query_type': 'search', 'dataset': 'writing'},
  'metrics': {'Success@5': 0.7684407096171803}},
 'recreation': {'provenance': {'query_type': 'search',
   'dataset': 'recreation'},
  'metrics': {'Success@5': 0.6904761904761905}},
 'technology': {'provenance': {'query_type': 'search',
   'dataset': 'technology'},
  'metrics': {'Success@5': 0.6409395973154363}},
 'science': {'provenance': {'query_type': 'search', 'dataset': 'science'},
  'metrics': {'Success@5': 0.5607779578606159}},
 'pooled': {'provenance': {'query_type': 'search', 'dataset': 'pooled'},
  'metrics': {'Success@5': 0.6805376066166968}}}

In [7]:
# Print Success@5 for every evaluated dataset as a percentage rounded to one
# decimal place, one "<dataset> <score>" line per dataset.
for dataset in metrics:
    results = metrics[dataset]
    result_success5 = round(results["metrics"]["Success@5"] * 100, 1)
    print(f"{dataset} {result_success5}")

lifestyle 82.3
writing 76.8
recreation 69.0
technology 64.1
science 56.1
pooled 68.1


In [28]:
from tabulate import tabulate, SEPARATING_LINE

# Table header: an empty label column, one title-cased column per entry of
# DATASETS, and a trailing average column.
headers = ["", *(dataset.title() for dataset in DATASETS), "Avg."]

def append_model_metrics(data, model_names):
    """Append one table row per model to ``data``, using ``xtr_paper_stats``.

    Each row is ``[model_name, <score per entry of DATASETS>, average]``,
    where the scores follow the order of ``DATASETS`` and the average is
    rounded to one decimal place.

    Note: the average is taken over every entry of ``DATASETS``, which
    includes the "pooled" column alongside the individual collections.

    Args:
        data: List of table rows; mutated in place.
        model_names: Iterable of keys of ``xtr_paper_stats`` to add, in order.

    Raises:
        KeyError: If a model name or a title-cased dataset name is missing
            from ``xtr_paper_stats``.
    """
    # xtr_paper_stats is keyed by title-cased dataset names.
    columns = [dataset.title() for dataset in DATASETS]
    for name in model_names:
        per_dataset = xtr_paper_stats[name]
        row_scores = [per_dataset[column] for column in columns]
        mean_score = round(sum(row_scores) / len(DATASETS), 1)
        data.append([name, *row_scores, mean_score])

# Rows of the comparison table, filled section by section.
data = []

# Section 1: every paper model that belongs to neither special group,
# in the insertion order of xtr_paper_stats.
append_model_metrics(
    data,
    [name for name in xtr_paper_stats if name not in cross_encoder_hard_neg_models + xxl_models],
)

# This notebook's own run: Success@5 scaled to a percentage with one decimal,
# averaged over every entry of DATASETS exactly like the paper rows.
warp_scores = [round(metrics[name]["metrics"]["Success@5"] * 100, 1) for name in DATASETS]
warp_avg = round(sum(warp_scores) / len(DATASETS), 1)
data.append(["XTR_base / WARP", *warp_scores, warp_avg])

# Section 2: the models listed in cross_encoder_hard_neg_models.
data.append([SEPARATING_LINE])
append_model_metrics(data, cross_encoder_hard_neg_models)

# Section 3: the models listed in xxl_models.
data.append([SEPARATING_LINE])
append_model_metrics(data, xxl_models)

# Render as plain text with one decimal place for every float.
print(tabulate(data, headers, floatfmt=".1f"))

                   Lifestyle    Writing    Recreation    Technology    Science    Pooled    Avg.
---------------  -----------  ---------  ------------  ------------  ---------  --------  ------
BM25                    63.8       60.3          56.5          41.8       32.7      48.3    50.6
ColBERT                 80.2       74.7          68.5          61.9       53.6      67.3    67.7
GTR_base                82.0       74.1          65.7          58.1       49.8      65.0    65.8
XTR_base                82.1       77.0          69.4          63.2       54.9      69.0    69.3
XTR_base / WARP         82.3       76.8          69.0          64.1       56.1      68.1    69.4
---------------  -----------  ---------  ------------  ------------  ---------  --------  ------
Splade_v2               82.3       77.1          69.0          62.4       55.4      68.9    69.2
ColBERT_v2              84.7       80.1          72.3          66.1       56.7      71.6    71.9
---------------  -----------  