In [1]:
import faiss
import pyterrier as pt
import ujson
import numpy as np
import pandas as pd

import itertools
import threading
import queue

from colbert.modeling.inference import ModelInference
from colbert.evaluation.loaders import load_colbert
from pyterrier_colbert import load_checkpoint
# Monkeypatch ColBERT's loader module so that it uses the load_checkpoint
# imported from pyterrier_colbert above (the author's "downloading version")
# instead of the one it was defined with.
import colbert.evaluation.loaders

# Rebind the attribute on the module object ...
colbert.evaluation.loaders.load_checkpoint = load_checkpoint
# ... and in the globals dict that load_model looks names up in.
# NOTE(review): if load_model is defined in that same module, the two
# assignments touch the same namespace and this one is redundant - confirm.
colbert.evaluation.loaders.load_model.__globals__['load_checkpoint'] = load_checkpoint
from colbert.utils.utils import print_message
import pickle
from colbert.indexing.index_manager import IndexManager
from warnings import warn
from colbert.modeling import colbert as CBERT

In [2]:
# Initialise PyTerrier; the output below reports the Terrier version it loaded.
pt.init()

PyTerrier 0.10.0 has loaded Terrier 5.8 (built by craigm on 2023-11-01 18:05) and terrier-helper 0.0.8

No etc/terrier.properties, using terrier.default.properties for bootstrap configuration.


In [3]:
from pyterrier_colbert.preprocessing import DatasetPreprocessor, TokenRemover, HFTokenizer, NLTKTokenizer
from transformers import AutoTokenizer, AutoModelForMaskedLM
import ir_datasets, ir_measures

In [4]:
def load_colbert(args):
    """Create a ColBERT model and load the weights named by ``args.checkpoint``.

    Note: this definition replaces the ``load_colbert`` that the first cell
    imported from ``colbert.evaluation.loaders``.

    Parameters
    ----------
    args : object
        Any object exposing the attributes ``query_maxlen``, ``doc_maxlen``,
        ``dim``, ``similarity``, ``mask_punctuation`` and ``checkpoint``.

    Returns
    -------
    tuple
        ``(model, checkpoint)`` - the model switched to eval mode, and the
        value returned by ``load_checkpoint``.
    """
    print_message("#> Loading model checkpoint.")

    model_options = dict(
        query_maxlen=args.query_maxlen,
        doc_maxlen=args.doc_maxlen,
        dim=args.dim,
        similarity_metric=args.similarity,
        mask_punctuation=args.mask_punctuation,
    )
    model = CBERT.ColBERT.from_pretrained('bert-base-uncased', **model_options)

    # The target device is picked from the GPU count that FAISS reports.
    # NOTE(review): a CPU-only FAISS build would presumably report 0 GPUs even
    # when torch can see CUDA - confirm this is the intended signal.
    if faiss.get_num_gpus() > 0:
        target_device = 'cuda:0'
    else:
        target_device = 'cpu'
    model = model.to(target_device)

    # Weights are loaded after the model has been moved to its device.
    loaded_checkpoint = load_checkpoint(args.checkpoint, model)
    model.eval()

    print('\n')

    return model, loaded_checkpoint

In [5]:
class Object:
    """Empty attribute container; instances are filled by plain attribute assignment."""

In [6]:
# Location of the ColBERT checkpoint archive; stored on args.checkpoint and
# passed to load_checkpoint by load_colbert.
# NOTE(review): this is a plain-http URL and the file is loaded as model
# weights - check whether the host also serves https and prefer it if so.
checkpoint="http://www.dcs.gla.ac.uk/~craigm/colbert.dnn.zip"

In [7]:
# Settings consumed by load_colbert: model dimensions, similarity function,
# maximum query/document lengths and the checkpoint location.
args = Object()
vars(args).update(
    similarity='cosine',
    dim=128,
    query_maxlen=32,
    doc_maxlen=180,
    checkpoint=checkpoint,
    mask_punctuation=False,
)

In [8]:
# HuggingFace tokenizer for bert-base-uncased, plus its HFTokenizer wrapper.
wordpiece = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
hf_tokenizer = HFTokenizer(tokenizer=wordpiece)
# NLTK-based tokenizer (treebank variant); this is the one handed to every
# TokenRemover and DatasetPreprocessor further down.
nltk_tokenizer = NLTKTokenizer(tokenizer_type='treebank')

In [9]:
# Load the model and its checkpoint.
# NOTE: this rebinds the global name `colbert`, which the first cell bound to
# the colbert package (via `import colbert.evaluation.loaders`); from here on
# `colbert` refers to the model object, not the package.
colbert, model_checkpoint = load_colbert(args)

[Mar 11, 14:22:22] #> Loading model checkpoint.


Some weights of ColBERT were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['linear.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[Mar 11, 14:22:28] #> Loading checkpoint http://www.dcs.gla.ac.uk/~craigm/colbert.dnn.zip




[Mar 11, 14:22:42] #> checkpoint['epoch'] = 0
[Mar 11, 14:22:42] #> checkpoint['batch'] = 44500




In [10]:
from ir_datasets import create_dataset

In [11]:
# Build an ir_datasets dataset from the local minimarco TSV/TREC files
# (documents, queries and relevance judgements).
dataset = create_dataset(
    docs_tsv='minimarco/msmarco-passage-trec-dl-2019-docs.tsv',
    queries_tsv='minimarco/msmarco-passage-trec-dl-2019-queries.tsv',
    qrels_trec='minimarco/msmarco-passage-trec-dl-2019-qrels.tsv',
)

In [12]:
# Wrap the locally built dataset in a PyTerrier dataset object so its topics
# and qrels can be fetched with get_topics() / get_qrels() later on.
# defer_load=True presumably stops IRDSDataset from resolving the
# 'irds:minimarco' id up front - TODO confirm.
irds_dataset = pt.datasets.IRDSDataset(irds_id='irds:minimarco', defer_load=True)
# Inject the dataset object directly.
# NOTE(review): _irds_ref is a private attribute, so this may break when
# PyTerrier is upgraded.
irds_dataset._irds_ref = dataset

In [13]:
# Stopword list files shared by the removers below. Naming the paths once
# avoids the copy-paste slips that the repeated literals invite.
STOPWORDS_EN = 'stopwords/stopwords-en.txt'
STOPWORDS_LIMITED = 'stopwords/stopwords-limited.txt'
STOPWORDS_PUNCT = 'stopwords/stopwords-punctuations.txt'

# Full English list, without / with the punctuation list.
en_remover = TokenRemover(tokenizer=nltk_tokenizer, stopwords_files=[STOPWORDS_EN])
en_punc_remover = TokenRemover(tokenizer=nltk_tokenizer, stopwords_files=[STOPWORDS_EN, STOPWORDS_PUNCT])

# Same lists with stopword_max_length=2.
# (presumably keeps only stopwords of at most that length - confirm against TokenRemover)
en2_remover = TokenRemover(tokenizer=nltk_tokenizer, stopwords_files=[STOPWORDS_EN], stopword_max_length=2)
en2_punc_remover = TokenRemover(tokenizer=nltk_tokenizer, stopwords_files=[STOPWORDS_EN, STOPWORDS_PUNCT], stopword_max_length=2)

# Same lists with stopword_max_length=4.
en4_remover = TokenRemover(tokenizer=nltk_tokenizer, stopwords_files=[STOPWORDS_EN], stopword_max_length=4)
en4_punc_remover = TokenRemover(tokenizer=nltk_tokenizer, stopwords_files=[STOPWORDS_EN, STOPWORDS_PUNCT], stopword_max_length=4)

# The limited stopword list, without / with the punctuation list.
lim_remover = TokenRemover(tokenizer=nltk_tokenizer, stopwords_files=[STOPWORDS_LIMITED])
lim_punc_remover = TokenRemover(tokenizer=nltk_tokenizer, stopwords_files=[STOPWORDS_LIMITED, STOPWORDS_PUNCT])

# "en2" configuration with 'the' added to the stopword set. The set is rebound
# (not updated in place) so a set object shared with another remover is never mutated.
en2the_remover = TokenRemover(tokenizer=nltk_tokenizer, stopwords_files=[STOPWORDS_EN], stopword_max_length=2)
en2the_remover.stopwords = en2the_remover.stopwords | {'the'}

# Punctuation variant of en2the. FIX: this was previously built from the
# limited list, which disagreed with en2the_remover and with the "en2" name;
# it now uses the English list like its sibling.
# NOTE(review): if index.clean.en2thepunc.minimarco was built with the old
# definition, it has to be rebuilt to match - confirm.
en2the_punc_remover = TokenRemover(tokenizer=nltk_tokenizer, stopwords_files=[STOPWORDS_EN, STOPWORDS_PUNCT], stopword_max_length=2)
en2the_punc_remover.stopwords = en2the_punc_remover.stopwords | {'the'}

# Punctuation list only.
punc_only = TokenRemover(tokenizer=nltk_tokenizer, stopwords_files=[STOPWORDS_PUNCT])

In [14]:
# (label, remover) pairs evaluated in this run. The labels are also the middle
# part of the index directory names (index.clean.<label>.minimarco).
# Not part of this run: ('en4', en4_remover), ('en4punc', en4_punc_remover),
# ('few', lim_remover), ('fewpunc', lim_punc_remover).
cleaners = [
    ('en', en_remover),
    ('enpunc', en_punc_remover),
    ('en2', en2_remover),
    ('en2punc', en2_punc_remover),
    ('en2the', en2the_remover),
    ('en2thepunc', en2the_punc_remover),
    ('punc', punc_only),
]

In [15]:
# One preprocessed view of the dataset per cleaner, keyed by the cleaner's label.
datasets_cleaned = dict(
    (
        label,
        DatasetPreprocessor(
            dataset=irds_dataset,
            tokenizer=nltk_tokenizer,
            preprocessor=remover,
        ),
    )
    for label, remover in cleaners
)

In [16]:
import pyterrier_colbert.indexing
import torch
import os
from pyterrier_colbert.ranking import ColBERTFactory

In [17]:
# Registry of end-to-end retrievers, keyed by run label; starts empty here.
retrievers = {}

# Baseline: the index built over the uncleaned collection.
pyterrier_colbert_factory = ColBERTFactory(
    (colbert, model_checkpoint),
    "./indexes/",
    "index.base.minimarco",
    faisstype='mmap',
)
colbert_e2e = pyterrier_colbert_factory.end_to_end()
retrievers.update(base=colbert_e2e)

[Mar 11, 14:22:43] #> Loading the FAISS index from ./indexes/index.base.minimarco/ivfpq.256.faiss ..
[Mar 11, 14:22:43] #> Building the emb2pid mapping..
[Mar 11, 14:22:43] len(self.emb2pid) = 716547
Loading reranking index, memtype=mem


Loading index shards to memory: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  8.61shard/s]


In [18]:
# Labels only, in the same order as `cleaners`.
cleaner_names = [label for label, *_ in cleaners]

In [19]:
# Add one end-to-end retriever per cleaned index, registered under the cleaner's label.
for name in cleaner_names:
    factory = ColBERTFactory(
        (colbert, model_checkpoint),
        './indexes',
        f'index.clean.{name}.minimarco',
        faisstype='mmap',
    )
    rete2e = factory.end_to_end()
    retrievers.update({name: rete2e})

[Mar 11, 14:22:45] #> Loading the FAISS index from ./indexes/index.clean.en.minimarco/ivfpq.256.faiss ..
[Mar 11, 14:22:45] #> Building the emb2pid mapping..
[Mar 11, 14:22:45] len(self.emb2pid) = 482356
Loading reranking index, memtype=mem


Loading index shards to memory: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 15.15shard/s]


[Mar 11, 14:22:45] #> Loading the FAISS index from ./indexes/index.clean.enpunc.minimarco/ivfpq.256.faiss ..
[Mar 11, 14:22:45] #> Building the emb2pid mapping..
[Mar 11, 14:22:46] len(self.emb2pid) = 426165
Loading reranking index, memtype=mem


Loading index shards to memory: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  7.99shard/s]


[Mar 11, 14:22:46] #> Loading the FAISS index from ./indexes/index.clean.en2.minimarco/ivfpq.256.faiss ..
[Mar 11, 14:22:47] #> Building the emb2pid mapping..
[Mar 11, 14:22:47] len(self.emb2pid) = 629420
Loading reranking index, memtype=mem


Loading index shards to memory: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  7.56shard/s]


[Mar 11, 14:22:47] #> Loading the FAISS index from ./indexes/index.clean.en2punc.minimarco/ivfpq.256.faiss ..
[Mar 11, 14:22:47] #> Building the emb2pid mapping..
[Mar 11, 14:22:47] len(self.emb2pid) = 575290
Loading reranking index, memtype=mem


Loading index shards to memory: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  6.55shard/s]


[Mar 11, 14:22:48] #> Loading the FAISS index from ./indexes/index.clean.en2the.minimarco/ivfpq.256.faiss ..
[Mar 11, 14:22:48] #> Building the emb2pid mapping..
[Mar 11, 14:22:49] len(self.emb2pid) = 603554
Loading reranking index, memtype=mem


Loading index shards to memory: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  4.89shard/s]


[Mar 11, 14:22:49] #> Loading the FAISS index from ./indexes/index.clean.en2thepunc.minimarco/ivfpq.256.faiss ..
[Mar 11, 14:22:49] #> Building the emb2pid mapping..
[Mar 11, 14:22:49] len(self.emb2pid) = 562565
Loading reranking index, memtype=mem


Loading index shards to memory: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  5.83shard/s]


[Mar 11, 14:22:50] #> Loading the FAISS index from ./indexes/index.clean.punc.minimarco/ivfpq.256.faiss ..
[Mar 11, 14:22:50] #> Building the emb2pid mapping..
[Mar 11, 14:22:50] len(self.emb2pid) = 660872
Loading reranking index, memtype=mem


Loading index shards to memory: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  7.86shard/s]


In [20]:
# Compare every cleaned-index retriever against the uncleaned 'base' run on
# nDCG@10 and 'mrt'. Because a baseline is given, the result table also has
# the per-metric "+", "-" and "p-value" columns.
system_names = list(retrievers)

# Resolve the baseline by label rather than hard-coding position 0. This still
# picks 'base' if the insertion order of `retrievers` changes, and fails with
# a ValueError instead of silently testing against the wrong system if 'base'
# is missing.
baseline_position = system_names.index('base')

# NOTE(review): several systems are tested against the baseline without a
# multiple-testing correction; pt.Experiment accepts a `correction` argument
# - consider whether it should be set.
pt.Experiment(
    [retrievers[label] for label in system_names],
    irds_dataset.get_topics(),
    irds_dataset.get_qrels(),
    eval_metrics=["ndcg_cut_10", 'mrt'],
    names=system_names,
    baseline=baseline_position,
    highlight='bold',
    round=2,
)

Unnamed: 0,name,ndcg_cut_10,mrt,ndcg_cut_10 +,ndcg_cut_10 -,ndcg_cut_10 p-value
0,base,0.73,265.67,,,
1,en,0.7,218.49,17.0,22.0,0.03749
2,enpunc,0.69,195.12,13.0,28.0,0.013167
3,en2,0.73,189.44,24.0,14.0,0.724551
4,en2punc,0.71,260.02,18.0,20.0,0.062662
5,en2the,0.73,266.49,23.0,15.0,0.678887
6,en2thepunc,0.71,265.0,19.0,20.0,0.113418
7,punc,0.72,261.3,16.0,23.0,0.046621
