In [1]:
import os
import random
import time
import datetime
import sys
import torch
import torch.nn as nn
import numpy as np
import mlflow
import ujson
from itertools import accumulate
from math import ceil
import traceback
import string

from contextlib import contextmanager

from packaging import version

import copy
import faiss

import pandas as pd

from argparse import ArgumentParser

from collections import defaultdict, OrderedDict

from transformers import BertTokenizerFast, BertPreTrainedModel, BertModel, BertConfig, AutoConfig, AutoTokenizer

In [3]:
import platform

# Report the interpreter the *kernel* runs on. `!python --version` would instead report
# whichever `python` is first on PATH, which can be a different environment.
print("Python {}".format(platform.python_version()))
print("Torch version:{}".format(torch.__version__))
print("cuda version: {}".format(torch.version.cuda))
# A bare expression that is not the last line of a cell is never displayed, so print it.
print("cuda available: {}".format(torch.cuda.is_available()))

SEED = 12345

# Seed every RNG this notebook imports so repeated runs are reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# First CUDA device if available, otherwise CPU; the rest of the notebook moves tensors here.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(DEVICE)

In [5]:
# colbert/modeling/colbert.py
'''
from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
from colbert.parameters import DEVICE
'''
class ColBERT(BertPreTrainedModel):
    """ColBERT late-interaction model on top of a BERT encoder.

    Queries and documents are encoded independently: BERT hidden states are projected by a
    bias-free linear layer to ``dim`` and L2-normalised per token. Relevance is computed by
    ``score`` (sum over query tokens of the best-matching document token).

    Args:
        config: transformers config used to build the ``BertModel`` backbone.
        query_maxlen, doc_maxlen: stored on the instance; not used inside this class.
        mask_punctuation: when true, punctuation token ids are zeroed out of document
            embeddings (see ``mask``).
        dim: size of the projected token embeddings.
        similarity_metric: ``'cosine'`` or ``'l2'``.
    """

    def __init__(self, config, query_maxlen, doc_maxlen, mask_punctuation, dim=128, similarity_metric='cosine'):
        super(ColBERT, self).__init__(config)

        self.query_maxlen = query_maxlen
        self.doc_maxlen = doc_maxlen
        self.similarity_metric = similarity_metric
        self.dim = dim

        self.mask_punctuation = mask_punctuation
        self.skiplist = {}

        if self.mask_punctuation:
            self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
            # Record both the raw punctuation character and the id of its first wordpiece.
            # `mask` is fed integer token ids, so in practice only the id entries match.
            skiplist = {}
            for symbol in string.punctuation:
                first_piece_id = self.tokenizer.encode(symbol, add_special_tokens=False)[0]
                skiplist[symbol] = True
                skiplist[first_piece_id] = True
            self.skiplist = skiplist

        self.bert = BertModel(config)
        self.linear = nn.Linear(config.hidden_size, dim, bias=False)

        self.init_weights()

    def forward(self, Q, D):
        """Score query/document pairs; ``Q`` and ``D`` are ``(input_ids, attention_mask)`` tuples."""
        query_embs = self.query(*Q)
        doc_embs = self.doc(*D)
        return self.score(query_embs, doc_embs)

    def query(self, input_ids, attention_mask):
        """Encode queries into L2-normalised per-token embeddings (normalised along dim 2)."""
        hidden = self.bert(input_ids.to(DEVICE), attention_mask=attention_mask.to(DEVICE))[0]
        return torch.nn.functional.normalize(self.linear(hidden), p=2, dim=2)

    def doc(self, input_ids, attention_mask, keep_dims=True):
        """Encode documents into per-token embeddings with padding/skiplisted tokens zeroed.

        With ``keep_dims=True`` the padded batch tensor is returned. Otherwise the result is
        moved to the CPU as float16 and returned as a list with one tensor per document,
        containing only the rows whose mask entry is true.
        """
        input_ids = input_ids.to(DEVICE)
        attention_mask = attention_mask.to(DEVICE)

        D = self.linear(self.bert(input_ids, attention_mask=attention_mask)[0])

        keep = torch.tensor(self.mask(input_ids), device=DEVICE).unsqueeze(2).float()
        D = torch.nn.functional.normalize(D * keep, p=2, dim=2)

        if keep_dims:
            return D

        D = D.cpu().to(dtype=torch.float16)
        keep = keep.cpu().bool().squeeze(-1)
        return [doc_embs[keep[i]] for i, doc_embs in enumerate(D)]

    def score(self, Q, D):
        """Late-interaction score per pair: max over document tokens, summed over query tokens."""
        if self.similarity_metric != 'cosine':
            assert self.similarity_metric == 'l2'
            squared_dist = ((Q.unsqueeze(2) - D.unsqueeze(1))**2).sum(-1)
            return (-1.0 * squared_dist).max(-1).values.sum(-1)

        return (Q @ D.permute(0, 2, 1)).max(2).values.sum(1)

    def mask(self, input_ids):
        """Nested lists of bools: True for token ids that are non-zero and not skiplisted."""
        skip = self.skiplist
        return [[x != 0 and x not in skip for x in row] for row in input_ids.cpu().tolist()]

In [6]:
# colbert/modeling/tokenization/utils.py
def _sort_by_length(ids, mask, bsize):
    if ids.size(0) <= bsize:
        return ids, mask, torch.arange(ids.size(0))

    indices = mask.sum(-1).sort().indices
    reverse_indices = indices.sort().indices

    return ids[indices], mask[indices], reverse_indices


def _split_into_batches(ids, mask, bsize):
    batches = []
    for offset in range(0, ids.size(0), bsize):
        batches.append((ids[offset:offset+bsize], mask[offset:offset+bsize]))

    return batches

In [7]:
# colbert/modeling/tokenization/query_tokenization.py
'''
from transformers import BertTokenizerFast
from colbert.modeling.tokenization.utils import _split_into_batches
'''

class QueryTokenizer():
    """Turn raw query strings into ColBERT's fixed-length query input.

    Queries are laid out as ``[CLS] [Q] q_1 ... q_n [SEP] [MASK] ... [MASK]``. The ``[Q]``
    marker reuses the vocabulary slot ``[unused0]``, and trailing ``[MASK]`` tokens pad the
    query up to ``query_maxlen`` (the "[MASK] augmentation").
    """

    def __init__(self, query_maxlen):
        self.tok = BertTokenizerFast.from_pretrained('bert-base-uncased')
        self.query_maxlen = query_maxlen

        self.Q_marker_token, self.Q_marker_token_id = '[Q]', self.tok.convert_tokens_to_ids('[unused0]')
        self.cls_token, self.cls_token_id = self.tok.cls_token, self.tok.cls_token_id
        self.sep_token, self.sep_token_id = self.tok.sep_token, self.tok.sep_token_id
        self.mask_token, self.mask_token_id = self.tok.mask_token, self.tok.mask_token_id

        # Guard against a vocabulary whose ids differ from the ones this code assumes.
        assert self.Q_marker_token_id == 1 and self.mask_token_id == 103

    def tokenize(self, batch_text, add_special_tokens=False):
        """Return wordpiece strings per query; with ``add_special_tokens`` apply the layout above."""
        assert type(batch_text) in [list, tuple], (type(batch_text))

        tokens = [self.tok.tokenize(x, add_special_tokens=False) for x in batch_text]

        if not add_special_tokens:
            return tokens

        prefix, suffix = [self.cls_token, self.Q_marker_token], [self.sep_token]
        # No truncation here: queries longer than query_maxlen - 3 simply get no [MASK] padding.
        tokens = [prefix + lst + suffix + [self.mask_token] * (self.query_maxlen - (len(lst)+3)) for lst in tokens]

        return tokens

    def encode(self, batch_text, add_special_tokens=False):
        """Return token ids per query; with ``add_special_tokens`` apply the layout above."""
        assert type(batch_text) in [list, tuple], (type(batch_text))

        # (A leftover debug `print(ids)` used to dump every encoded batch to stdout here.)
        ids = self.tok(batch_text, add_special_tokens=False)['input_ids']

        if not add_special_tokens:
            return ids

        prefix, suffix = [self.cls_token_id, self.Q_marker_token_id], [self.sep_token_id]
        ids = [prefix + lst + suffix + [self.mask_token_id] * (self.query_maxlen - (len(lst)+3)) for lst in ids]

        return ids

    def tensorize(self, batch_text, bsize=None):
        """Tokenize to ``query_maxlen``-long tensors.

        Returns ``(ids, mask)``, or, when ``bsize`` is given, a list of ``(ids, mask)``
        batches (no reordering, so no reverse indices are needed). Padding ids (0) are
        replaced by ``[MASK]``; the attention mask is left exactly as the tokenizer returned it.
        """
        assert type(batch_text) in [list, tuple], (type(batch_text))

        # add a placeholder token for the [Q] marker (position 1 is overwritten below)
        batch_text = ['. ' + x for x in batch_text]

        obj = self.tok(batch_text, padding='max_length', truncation=True,
                       return_tensors='pt', max_length=self.query_maxlen)

        ids, mask = obj['input_ids'], obj['attention_mask']

        # postprocess for the [Q] marker and the [MASK] augmentation
        ids[:, 1] = self.Q_marker_token_id
        ids[ids == 0] = self.mask_token_id

        if bsize:
            batches = _split_into_batches(ids, mask, bsize)
            return batches

        return ids, mask

In [8]:
# colbert/modeling/tokenization/doc_tokenization.py 

'''
from transformers import BertTokenizerFast
from colbert.modeling.tokenization.utils import _split_into_batches, _sort_by_length
'''

class DocTokenizer():
    """Turn raw passages into ColBERT document input: ``[CLS] [D] d_1 ... d_n [SEP]``.

    The ``[D]`` marker reuses the vocabulary slot ``[unused1]``.
    """

    def __init__(self, doc_maxlen):
        self.tok = BertTokenizerFast.from_pretrained('bert-base-uncased')
        self.doc_maxlen = doc_maxlen

        tok = self.tok
        self.D_marker_token = '[D]'
        self.D_marker_token_id = tok.convert_tokens_to_ids('[unused1]')
        self.cls_token = tok.cls_token
        self.cls_token_id = tok.cls_token_id
        self.sep_token = tok.sep_token
        self.sep_token_id = tok.sep_token_id

        # Guard against a vocabulary whose [unused1] id differs from the one assumed here.
        assert self.D_marker_token_id == 2

    def tokenize(self, batch_text, add_special_tokens=False):
        """Return wordpiece strings per passage, optionally wrapped in ``[CLS] [D] ... [SEP]``."""
        assert type(batch_text) in [list, tuple], (type(batch_text))

        pieces = [self.tok.tokenize(text, add_special_tokens=False) for text in batch_text]

        if add_special_tokens:
            head, tail = [self.cls_token, self.D_marker_token], [self.sep_token]
            pieces = [head + p + tail for p in pieces]

        return pieces

    def encode(self, batch_text, add_special_tokens=False):
        """Return token ids per passage, optionally wrapped in ``[CLS] [D] ... [SEP]`` ids."""
        assert type(batch_text) in [list, tuple], (type(batch_text))

        ids = self.tok(batch_text, add_special_tokens=False)['input_ids']

        if add_special_tokens:
            head, tail = [self.cls_token_id, self.D_marker_token_id], [self.sep_token_id]
            ids = [head + seq + tail for seq in ids]

        return ids

    def tensorize(self, batch_text, bsize=None):
        """Tokenize passages (padded to the longest, truncated to ``doc_maxlen``).

        Returns ``(ids, mask)``. When ``bsize`` is given, rows are first sorted by length
        and ``(batches, reverse_indices)`` is returned instead.
        """
        assert type(batch_text) in [list, tuple], (type(batch_text))

        # Prepend a throwaway token so that position 1 can be overwritten with [D] below.
        prefixed = ['. ' + text for text in batch_text]

        encoded = self.tok(prefixed, padding='longest', truncation='longest_first',
                           return_tensors='pt', max_length=self.doc_maxlen)

        ids = encoded['input_ids']
        mask = encoded['attention_mask']
        ids[:, 1] = self.D_marker_token_id

        if not bsize:
            return ids, mask

        ids, mask, reverse_indices = _sort_by_length(ids, mask, bsize)
        return _split_into_batches(ids, mask, bsize), reverse_indices

In [9]:
# colbert/modeling/inference.py
'''
from colbert.modeling.colbert import ColBERT--
from colbert.modeling.tokenization import QueryTokenizer, DocTokenizer--
from colbert.utils.amp import MixedPrecisionManager--
from colbert.parameters import DEVICE--
'''
class ModelInference():
    """Gradient-free inference wrapper around an eval-mode ColBERT model.

    Owns a query and a document tokenizer sized from the model's ``query_maxlen`` and
    ``doc_maxlen``. All encoding runs under ``torch.no_grad()`` and the mixed-precision
    manager's context.
    """

    # String annotation: the class can be defined even when ColBERT is not yet in scope.
    def __init__(self, colbert: 'ColBERT', amp=False):
        assert colbert.training is False

        self.colbert = colbert
        self.query_tokenizer = QueryTokenizer(colbert.query_maxlen)
        self.doc_tokenizer = DocTokenizer(colbert.doc_maxlen)

        self.amp_manager = MixedPrecisionManager(amp)

    @staticmethod
    def _to_cpu(X):
        """Move a tensor, or each tensor of a list/tuple of tensors, to the CPU."""
        if isinstance(X, (list, tuple)):
            return [x.cpu() for x in X]
        return X.cpu()

    def query(self, *args, to_cpu=False, **kw_args):
        """Run ``colbert.query`` without gradients; optionally move the result to the CPU."""
        with torch.no_grad():
            with self.amp_manager.context():
                Q = self.colbert.query(*args, **kw_args)
                return self._to_cpu(Q) if to_cpu else Q

    def doc(self, *args, to_cpu=False, **kw_args):
        """Run ``colbert.doc`` without gradients; optionally move the result to the CPU.

        ``colbert.doc(..., keep_dims=False)`` returns a list of tensors, so ``to_cpu`` is
        applied per element in that case (calling ``.cpu()`` on the list would fail).
        """
        with torch.no_grad():
            with self.amp_manager.context():
                D = self.colbert.doc(*args, **kw_args)
                return self._to_cpu(D) if to_cpu else D

    def queryFromText(self, queries, bsize=None, to_cpu=False):
        """Tokenize and encode query strings, in chunks of ``bsize`` when given."""
        if bsize:
            batches = self.query_tokenizer.tensorize(queries, bsize=bsize)
            batches = [self.query(input_ids, attention_mask, to_cpu=to_cpu) for input_ids, attention_mask in batches]
            return torch.cat(batches)

        input_ids, attention_mask = self.query_tokenizer.tensorize(queries)
        # Previously `to_cpu` was silently ignored on this unbatched path.
        return self.query(input_ids, attention_mask, to_cpu=to_cpu)

    def docFromText(self, docs, bsize=None, keep_dims=True, to_cpu=False):
        """Tokenize and encode passages, in length-sorted chunks of ``bsize`` when given.

        Batched results are put back into the input order via ``reverse_indices``.
        """
        if bsize:
            batches, reverse_indices = self.doc_tokenizer.tensorize(docs, bsize=bsize)

            batches = [self.doc(input_ids, attention_mask, keep_dims=keep_dims, to_cpu=to_cpu)
                       for input_ids, attention_mask in batches]

            if keep_dims:
                # Batches can have different padded lengths; zero-pad them into one tensor.
                D = _stack_3D_tensors(batches)
                return D[reverse_indices]

            D = [d for batch in batches for d in batch]
            return [D[idx] for idx in reverse_indices.tolist()]

        input_ids, attention_mask = self.doc_tokenizer.tensorize(docs)
        # Previously `to_cpu` was silently ignored on this unbatched path.
        return self.doc(input_ids, attention_mask, keep_dims=keep_dims, to_cpu=to_cpu)

    def score(self, Q, D, mask=None, lengths=None, explain=False):
        """Late-interaction scores of ``D`` against ``Q``, returned on the CPU.

        Computes ``D @ Q``, optionally zeroes document positions beyond ``lengths`` (or
        outside ``mask``), takes the max over dim 1, and sums over the last dim.
        Presumably ``D`` is (num_docs, doc_len, dim) and ``Q`` is (…, dim, q_len) — TODO confirm.
        """
        if lengths is not None:
            assert mask is None, "don't supply both mask and lengths"

            mask = torch.arange(D.size(1), device=DEVICE) + 1
            mask = mask.unsqueeze(0) <= lengths.to(DEVICE).unsqueeze(-1)

        scores = (D @ Q)
        scores = scores if mask is None else scores * mask.unsqueeze(-1)
        scores = scores.max(1)

        if explain:
            assert False, "TODO"

        return scores.values.sum(-1).cpu()


def _stack_3D_tensors(groups):
    bsize = sum([x.size(0) for x in groups])
    maxlen = max([x.size(1) for x in groups])
    hdim = groups[0].size(2)

    output = torch.zeros(bsize, maxlen, hdim, device=groups[0].device, dtype=groups[0].dtype)

    offset = 0
    for x in groups:
        endpos = offset + x.size(0)
        output[offset:endpos, :x.size(1)] = x
        offset = endpos

    return output

In [10]:
# colbert/evaluation/loaders.py
def load_queries(queries_path):
    """Read a ``qid<TAB>query`` TSV file (no header) into an ``OrderedDict[int, str]``.

    Raises AssertionError if a qid appears more than once.
    """
    import csv

    queries = OrderedDict()

    print_message("#> Loading the queries from", queries_path, "...")

    # Read every field as raw text:
    #  - QUOTE_NONE: a query containing '"' must not be treated as a CSV-quoted field;
    #  - keep_default_na=False: queries such as "NA", "null" or "nan" stay strings, not NaN;
    #  - dtype=str: numeric-looking queries stay strings (the tokenizers do `'. ' + query`).
    file = pd.read_csv(queries_path, sep="\t", header=None, dtype=str,
                       quoting=csv.QUOTE_NONE, keep_default_na=False)

    for qid, query in zip(file[0], file[1]):
        qid = int(qid)
        assert (qid not in queries), ("Query QID", qid, "is repeated!")
        queries[qid] = query

    print_message("#> Got", len(queries), "queries. All QIDs are unique.\n")

    return queries
def load_qrels(qrels_path):
    """Read a TREC-style ``qid 0 pid 1`` TSV (no header) into ``OrderedDict[qid, list[pid]]``.

    Returns None when ``qrels_path`` is None. Asserts that column 1 is always 0, column 3 is
    always 1, and that no (qid, pid) pair is repeated.
    """
    if qrels_path is None:
        return None

    print_message("#> Loading qrels from", qrels_path, "...")

    qrels = OrderedDict()

    table = pd.read_csv(qrels_path, sep="\t", header=None)
    for row in range(len(table)):
        record = table.loc[row]
        qid, x, pid, y = (int(record[col]) for col in range(4))
        assert x == 0 and y == 1
        qrels.setdefault(qid, []).append(pid)

    assert all(len(pids) == len(set(pids)) for pids in qrels.values())

    avg_positive = round(sum(map(len, qrels.values())) / len(qrels), 2)

    print_message("#> Loaded qrels for", len(qrels), "unique queries with",
                  avg_positive, "positives per query on average.\n")

    return qrels
# colbert/evaluation/loaders.py
def load_colbert(args, do_print=True):
    """Load the model via ``load_model`` and warn about args that disagree with the checkpoint.

    Returns ``(model, checkpoint)``. When the checkpoint carries saved ``'arguments'``, each
    of a fixed set of keys is compared with the same attribute on ``args``. The saved
    arguments are printed as JSON when ``args.rank < 1``.
    """
    colbert, checkpoint = load_model(args, do_print)

    # TODO: If the parameters below were not specified on the command line, their *checkpoint* values should be used.
    # I.e., not their purely (i.e., training) default values.

    has_saved_args = 'arguments' in checkpoint

    for k in ['query_maxlen', 'doc_maxlen', 'dim', 'similarity', 'amp']:
        if not (has_saved_args and hasattr(args, k)):
            continue
        saved = checkpoint['arguments']
        if k in saved and saved[k] != getattr(args, k):
            a, b = saved[k], getattr(args, k)
            Run.warn(f"Got checkpoint['arguments']['{k}'] != args.{k} (i.e., {a} != {b})")

    if has_saved_args and args.rank < 1:
        print(ujson.dumps(checkpoint['arguments'], indent=4))

    if do_print:
        print('\n')

    return colbert, checkpoint
def load_model(args, do_print=True):
    """Build the model from ``args``, move it to ``DEVICE``, load ``args.checkpoint``, set eval mode.

    Returns ``(model, checkpoint)``.
    NOTE(review): ``KolBERT``, ``MCONFIG`` and ``load_checkpoint`` are not defined in this
    part of the notebook — presumably defined elsewhere; confirm they exist before running.
    """
    model = KolBERT(
        config=MCONFIG,
        query_maxlen=args.query_maxlen,
        doc_maxlen=args.doc_maxlen,
        dim=args.dim,
        similarity_metric=args.similarity,
        mask_punctuation=args.mask_punctuation,
    ).to(DEVICE)

    print_message("#> Loading model checkpoint.", condition=do_print)
    checkpoint = load_checkpoint(args.checkpoint, model, do_print=do_print)

    model.eval()
    return model, checkpoint

In [11]:
#evaluation/ranking_logger.py
from contextlib import contextmanager
'''from colbert.utils.utils import print_message, NullContextManager
from colbert.utils.runs import Run'''


class RankingLogger():
    """Write per-query ranked lists as TSV, optionally with a relevance-annotated copy.

    Use ``context(filename)`` to open the output file(s), then call ``log`` once per query.
    ``filename`` is never reset, so each logger can open a context only once (the asserts in
    ``context`` enforce this), and ``filename`` stays readable after the context exits.
    """

    def __init__(self, directory, qrels=None, log_scores=False):
        self.directory = directory
        self.qrels = qrels
        self.log_scores = log_scores
        self.filename = None
        self.also_save_annotations = None

    @contextmanager
    def context(self, filename, also_save_annotations=False):
        """Open ``directory/filename`` (and ``<that>.annotated`` if requested) for writing."""
        assert self.filename is None
        assert self.also_save_annotations is None

        path = os.path.join(self.directory, filename)
        self.filename = path
        self.also_save_annotations = also_save_annotations

        print_message("#> Logging ranked lists to {}".format(self.filename))

        with open(path, 'w') as f:
            self.f = f
            # NullContextManager comes from colbert.utils.utils; presumably it yields None so
            # `self.g` is falsy and annotation writes are skipped — TODO confirm.
            annotated_ctx = open(path + '.annotated', 'w') if also_save_annotations else NullContextManager()
            with annotated_ctx as g:
                self.g = g
                yield self

    def log(self, qid, ranking, is_ranked=True, print_positions=[]):
        """Append one query's ranking, given as ``(score, pid, passage)`` tuples in rank order.

        Each line is ``qid<TAB>pid<TAB>rank[<TAB>score]``; rank is 1-based, or -1 when
        ``is_ranked`` is false. Ranks listed in ``print_positions`` are also echoed to stdout.
        """
        positions_to_print = set(print_positions)

        lines = []
        annotated_lines = []

        for position, (score, pid, passage) in enumerate(ranking, start=1):
            is_relevant = self.qrels and int(pid in self.qrels[qid])
            rank = position if is_ranked else -1

            fields = [qid, pid, rank]
            if self.log_scores:
                fields.append(score)
            lines.append('\t'.join(map(str, fields)) + "\n")

            if self.g:
                annotated_lines.append('\t'.join(map(str, [qid, pid, rank, is_relevant])) + "\n")

            if rank in positions_to_print:
                marker = ("** " if is_relevant else "") + str(rank)
                print("#> ( QID {} ) ".format(qid) + marker + ") ", pid, ":", score, '    ', passage)

        self.f.write(''.join(lines))
        if self.g:
            self.g.write(''.join(annotated_lines))

In [12]:
# indexing/loaders.py
def get_parts(directory):
    """List the numbered index parts (``0.pt``, ``1.pt``, ...) in ``directory``.

    Returns ``(parts, parts_paths, samples_paths)``: the sorted integer part ids, the
    matching ``.pt`` paths, and the matching ``.sample`` paths. Asserts that the ids are
    exactly ``0..n-1``.
    """
    extension = '.pt'
    cut = len(extension)

    # Sort as integers so that 10.pt comes after 9.pt.
    parts = sorted(int(name[:-cut]) for name in os.listdir(directory) if name.endswith(extension))

    assert list(range(len(parts))) == parts, parts

    parts_paths = [os.path.join(directory, '{}{}'.format(part, extension)) for part in parts]
    samples_paths = [os.path.join(directory, '{}.sample'.format(part)) for part in parts]

    return parts, parts_paths, samples_paths
def load_doclens(directory, flatten=True):
    """Load the per-part ``doclens.<part>.json`` files for every index part in ``directory``.

    Returns a list with one loaded JSON value per part, or, when ``flatten`` is true, the
    concatenation of those lists.
    """
    parts, _, _ = get_parts(directory)

    doclens_filenames = [os.path.join(directory, 'doclens.{}.json'.format(filename)) for filename in parts]

    all_doclens = []
    for filename in doclens_filenames:
        # Close each file promptly instead of leaking the handle until garbage collection.
        with open(filename) as f:
            all_doclens.append(ujson.load(f))

    if flatten:
        all_doclens = [x for sub_doclens in all_doclens for x in sub_doclens]

    return all_doclens

In [13]:
# indexing/faiss.py
import threading
import queue

'''from colbert.utils.utils import print_message, grouper
from colbert.indexing.loaders import get_parts
from colbert.indexing.index_manager import load_index_part
from colbert.indexing.faiss_index import FaissIndex'''


def get_faiss_index_name(args, offset=None, endpos=None):
    """Build the FAISS index filename: ``ivfpq[.<partitions>][.<offset>-<endpos>].faiss``."""
    pieces = ['ivfpq']
    if args.partitions is not None:
        pieces.append(f'.{args.partitions}')
    if offset is not None:
        pieces.append(f'.{offset}-{endpos}')
    pieces.append('.faiss')
    return ''.join(pieces)

In [14]:
#indexing/index_manager.py
class IndexManager():
    """Persist index tensors to disk with ``torch.save``."""

    def __init__(self, dim):
        # Embedding dimension; stored on the instance but not used by save().
        self.dim = dim

    def save(self, tensor, path_prefix):
        """Serialize ``tensor`` to ``path_prefix``, used as the complete file path."""
        torch.save(obj=tensor, f=path_prefix)


def load_index_part(filename, verbose=True):
    """Load one saved index part. A part saved as a plain list of tensors is concatenated.

    ``verbose`` is accepted for API compatibility and is unused.
    """
    part = torch.load(filename)
    # Backward compatibility: older parts were stored as lists of tensors.
    return torch.cat(part) if type(part) is list else part

In [15]:
#ranking/faiss_index.py

'''from multiprocessing import Pool
from colbert.modeling.inference import ModelInference

from colbert.utils.utils import print_message, flatten, batch
from colbert.indexing.loaders import load_doclens'''


class FaissIndex():
    """Candidate generation: map query token embeddings to passage ids (pids) via FAISS.

    ``emb2pid`` maps every embedding id stored in the FAISS index back to the pid it came
    from. It is built from the per-passage embedding counts read by ``load_doclens``.

    Args:
        index_path: directory holding the index parts and their ``doclens.*.json`` files.
        faiss_index_path: FAISS index file. If the second-to-last dot-separated field of its
            name is ``<start>-<stop>``, the index covers only parts ``range(start, stop)``.
        nprobe: assigned to the FAISS index's ``nprobe`` (IVF cells probed per search).
        part_range: optional ``range`` of parts to restrict retrieved pids to.
    """

    def __init__(self, index_path, faiss_index_path, nprobe, part_range=None):
        print_message("#> Loading the FAISS index from", faiss_index_path, "..")

        faiss_part_range = os.path.basename(faiss_index_path).split('.')[-2].split('-')

        if len(faiss_part_range) == 2:
            faiss_part_range = range(*map(int, faiss_part_range))
            # Only validate a requested part_range. With part_range=None this used to raise
            # TypeError ('NoneType' object is not subscriptable).
            if part_range is not None:
                assert part_range[0] in faiss_part_range, (part_range, faiss_part_range)
                assert part_range[-1] in faiss_part_range, (part_range, faiss_part_range)
        else:
            faiss_part_range = None

        self.part_range = part_range
        self.faiss_part_range = faiss_part_range

        self.faiss_index = faiss.read_index(faiss_index_path)
        self.faiss_index.nprobe = nprobe

        print_message("#> Building the emb2pid mapping..")
        all_doclens = load_doclens(index_path, flatten=False)

        pid_offset = 0
        if faiss_part_range is not None:
            print(f"#> Restricting all_doclens to the range {faiss_part_range}.")
            pid_offset = len(flatten(all_doclens[:faiss_part_range.start]))
            all_doclens = all_doclens[faiss_part_range.start:faiss_part_range.stop]

        self.relative_range = None
        if self.part_range is not None:
            start = self.faiss_part_range.start if self.faiss_part_range is not None else 0
            a = len(flatten(all_doclens[:self.part_range.start - start]))
            b = len(flatten(all_doclens[:self.part_range.stop - start]))
            self.relative_range = range(a, b)
            print(f"self.relative_range = {self.relative_range}")

        all_doclens = flatten(all_doclens)

        total_num_embeddings = sum(all_doclens)
        self.emb2pid = torch.zeros(total_num_embeddings, dtype=torch.int)

        # Passage `pid` owns a contiguous run of `dlength` embedding ids.
        offset_doclens = 0
        for pid, dlength in enumerate(all_doclens):
            self.emb2pid[offset_doclens: offset_doclens + dlength] = pid_offset + pid
            offset_doclens += dlength

        print_message("len(self.emb2pid) =", len(self.emb2pid))

        # Worker pool used by embedding_ids_to_pids for large de-duplication jobs. It is never
        # closed explicitly.
        self.parallel_pool = Pool(16)

    def retrieve(self, faiss_depth, Q, verbose=False):
        """Return one list of unique candidate pids per query in ``Q``.

        ``faiss_depth`` neighbours are requested per query embedding. When a ``part_range``
        was given, only pids that fall in ``self.relative_range`` are kept.
        """
        embedding_ids = self.queries_to_embedding_ids(faiss_depth, Q, verbose=verbose)
        pids = self.embedding_ids_to_pids(embedding_ids, verbose=verbose)

        if self.relative_range is not None:
            pids = [[pid for pid in pids_ if pid in self.relative_range] for pids_ in pids]

        return pids

    def queries_to_embedding_ids(self, faiss_depth, Q, verbose=True):
        """Search FAISS with every query embedding.

        ``Q`` is unpacked as ``(num_queries, embeddings_per_query, dim)``. The result has
        shape ``(num_queries, embeddings_per_query * faiss_depth)`` (non-unique ids).
        """
        # Flatten into a matrix for the faiss search.
        num_queries, embeddings_per_query, dim = Q.size()
        Q_faiss = Q.view(num_queries * embeddings_per_query, dim).cpu().contiguous()

        # Search in large batches with faiss.
        print_message("#> Search in batches with faiss. \t\t",
                      f"Q.size() = {Q.size()}, Q_faiss.size() = {Q_faiss.size()}",
                      condition=verbose)

        embeddings_ids = []
        faiss_bsize = embeddings_per_query * 5000
        for offset in range(0, Q_faiss.size(0), faiss_bsize):
            endpos = min(offset + faiss_bsize, Q_faiss.size(0))

            print_message("#> Searching from {} to {}...".format(offset, endpos), condition=verbose)

            some_Q_faiss = Q_faiss[offset:endpos].float().numpy()
            _, some_embedding_ids = self.faiss_index.search(some_Q_faiss, faiss_depth)
            embeddings_ids.append(torch.from_numpy(some_embedding_ids))

        embedding_ids = torch.cat(embeddings_ids)

        # Reshape to (number of queries, non-unique embedding IDs per query)
        embedding_ids = embedding_ids.view(num_queries, embeddings_per_query * embedding_ids.size(1))

        return embedding_ids

    def embedding_ids_to_pids(self, embedding_ids, verbose=True):
        """Map each row of embedding ids to a de-duplicated list of pids."""
        # Find unique PIDs per query.
        print_message("#> Lookup the PIDs..", condition=verbose)

        # FAISS pads its result lists with -1 when it finds fewer than k neighbours (e.g. with
        # a small nprobe). Indexing emb2pid with -1 would silently return the *last* passage,
        # so such slots are mapped to -1 here and dropped after de-duplication.
        missing = embedding_ids < 0
        has_missing = bool(missing.any())
        all_pids = self.emb2pid[embedding_ids.clamp(min=0)]
        if has_missing:
            all_pids[missing] = -1

        print_message(f"#> Converting to a list [shape = {all_pids.size()}]..", condition=verbose)
        all_pids = all_pids.tolist()

        print_message("#> Removing duplicates (in parallel if large enough)..", condition=verbose)

        if len(all_pids) > 5000:
            all_pids = list(self.parallel_pool.map(uniq, all_pids))
        else:
            all_pids = list(map(uniq, all_pids))

        if has_missing:
            all_pids = [[pid for pid in pids if pid != -1] for pids in all_pids]

        print_message("#> Done with embedding_ids_to_pids().", condition=verbose)

        return all_pids


def uniq(l):
    """Return the distinct elements of ``l`` as a list (order is whatever the set yields)."""
    return [*set(l)]

In [16]:
# indexing/faiss_index_gpu.py
class FaissIndexGPU():
    """Helpers for training and populating a CPU FAISS index using all visible GPUs.

    When no GPU is found, ``__init__`` returns early and leaves the GPU attributes unset.
    Every other method asserts ``ngpu > 0``.
    """

    def __init__(self):
        self.ngpu = faiss.get_num_gpus()

        if self.ngpu == 0:
            return

        self.tempmem = 1 << 33  # passed to setTempMemory (bytes, so 8 GiB per GPU)
        self.max_add_per_gpu = 1 << 25  # vectors buffered per GPU before flushing to CPU
        self.max_add = self.max_add_per_gpu * self.ngpu
        self.add_batch_size = 65536

        self.gpu_resources = self._prepare_gpu_resources()

    def _prepare_gpu_resources(self):
        """Create one ``StandardGpuResources`` per GPU, capping its temporary memory."""
        print_message(f"Preparing resources for {self.ngpu} GPUs.")

        gpu_resources = []

        for _ in range(self.ngpu):
            res = faiss.StandardGpuResources()
            if self.tempmem >= 0:
                res.setTempMemory(self.tempmem)
            gpu_resources.append(res)

        return gpu_resources

    def _make_vres_vdev(self):
        """
        return vectors of device ids and resources useful for gpu_multiple
        """

        assert self.ngpu > 0

        vres = faiss.GpuResourcesVector()
        vdev = faiss.IntVector()

        for i in range(self.ngpu):
            vdev.push_back(i)
            vres.push_back(self.gpu_resources[i])

        return vres, vdev

    def training_initialize(self, index, quantizer):
        """
        Train on GPU by giving the IVF index a GPU clustering index built from ``quantizer``.

        The index and quantizer should be owned by caller.
        """

        assert self.ngpu > 0

        s = time.time()
        self.index_ivf = faiss.extract_index_ivf(index)
        self.clustering_index = faiss.index_cpu_to_all_gpus(quantizer)
        self.index_ivf.clustering_index = self.clustering_index
        print(time.time() - s)

    def training_finalize(self):
        """Move the clustering index set by ``training_initialize`` back to the CPU."""
        assert self.ngpu > 0

        s = time.time()
        self.index_ivf.clustering_index = faiss.index_gpu_to_cpu(self.index_ivf.clustering_index)
        print(time.time() - s)

    def adding_initialize(self, index):
        """
        Clone ``index`` to all GPUs (sharded, float16 storage) so that vectors can be added there.

        The index should be owned by caller.
        """

        assert self.ngpu > 0

        self.co = faiss.GpuMultipleClonerOptions()
        self.co.useFloat16 = True
        self.co.useFloat16CoarseQuantizer = False
        self.co.usePrecomputed = False
        self.co.indicesOptions = faiss.INDICES_CPU
        self.co.verbose = True
        self.co.reserveVecs = self.max_add
        self.co.shard = True
        assert self.co.shard_type in (0, 1, 2)

        self.vres, self.vdev = self._make_vres_vdev()
        self.gpu_index = faiss.index_cpu_to_gpu_multiple(self.vres, self.vdev, index, self.co)

    def add(self, index, data, offset):
        """Add ``data`` to ``index`` through the GPU clone, assigning ids ``offset .. offset+len-1``.

        Vectors are added in chunks of ``add_batch_size`` and flushed to the CPU ``index``
        whenever the GPU side holds more than ``max_add`` vectors, and once more at the end.
        """
        assert self.ngpu > 0

        t0 = time.time()
        nb = data.shape[0]

        for i0 in range(0, nb, self.add_batch_size):
            i1 = min(i0 + self.add_batch_size, nb)
            xs = data[i0:i1]

            self.gpu_index.add_with_ids(xs, np.arange(offset+i0, offset+i1))

            if self.max_add > 0 and self.gpu_index.ntotal > self.max_add:
                self._flush_to_cpu(index, nb, offset)

            print('\r%d/%d (%.3f s)  ' % (i0, nb, time.time() - t0), end=' ')
            sys.stdout.flush()

        if self.gpu_index.ntotal > 0:
            self._flush_to_cpu(index, nb, offset)

        assert index.ntotal == offset+nb, (index.ntotal, offset+nb, offset, nb)
        print(f"add(.) time: %.3f s \t\t--\t\t index.ntotal = {index.ntotal}" % (time.time() - t0))

    def _flush_to_cpu(self, index, nb, offset):
        """Copy the vectors with ids in ``[offset, offset+nb)`` from each GPU shard into ``index``, then clear the shards."""
        print("Flush indexes to CPU")

        for i in range(self.ngpu):
            index_src_gpu = faiss.downcast_index(self.gpu_index if self.ngpu == 1 else self.gpu_index.at(i))
            index_src = faiss.index_gpu_to_cpu(index_src_gpu)

            index_src.copy_subset_to(index, 0, offset, offset+nb)
            index_src_gpu.reset()
            index_src_gpu.reserveMemory(self.max_add)

        if self.ngpu > 1:
            # The shard-sync method was renamed across faiss versions. Fall back only when the
            # old name is missing; a bare `except:` also swallowed real errors and KeyboardInterrupt.
            try:
                self.gpu_index.sync_with_shard_indexes()
            except AttributeError:
                self.gpu_index.syncWithSubIndexes()

In [17]:
# ranking/retrieval.py
from multiprocessing import Pool
'''
from colbert.modeling.inference import ModelInference
from colbert.evaluation.ranking_logger import RankingLogger

from colbert.utils.utils import print_message, batch
from colbert.ranking.rankers import Ranker'''


def retrieve(args):
    """Rank every query in ``args.queries`` and log the results to ``ranking.tsv``.

    Each query is encoded and ranked on its own through ``Ranker``; the running average
    per-query latency is printed as it goes. At most ``args.depth`` results per query are
    written through a ``RankingLogger`` rooted at ``Run.path``.
    """
    import itertools

    inference = ModelInference(args.colbert, amp=args.amp)
    ranker = Ranker(args, inference, faiss_depth=args.faiss_depth)

    ranking_logger = RankingLogger(Run.path, qrels=None)
    milliseconds = 0

    # DEVICE falls back to the CPU when CUDA is unavailable, and torch.cuda.synchronize()
    # raises on such machines. Only synchronise (for accurate GPU timing) when CUDA exists.
    use_cuda = torch.cuda.is_available()

    with ranking_logger.context('ranking.tsv', also_save_annotations=False) as rlogger:
        queries = args.queries
        qids_in_order = list(queries.keys())

        for qoffset, qbatch in batch(qids_in_order, 100, provide_offset=True):
            qbatch_text = [queries[qid] for qid in qbatch]

            rankings = []

            for query_idx, q in enumerate(qbatch_text):
                if use_cuda:
                    torch.cuda.synchronize('cuda:0')
                s = time.time()

                Q = ranker.encode([q])
                pids, scores = ranker.rank(Q)

                if use_cuda:
                    torch.cuda.synchronize()
                milliseconds += (time.time() - s) * 1000.0

                if len(pids):
                    print(qoffset+query_idx, q, len(scores), len(pids), scores[0], pids[0],
                          milliseconds / (qoffset+query_idx+1), 'ms')

                rankings.append(zip(pids, scores))

            for query_idx, (qid, ranking) in enumerate(zip(qbatch, rankings)):
                query_idx = qoffset + query_idx

                if query_idx % 100 == 0:
                    print_message(f"#> Logging query #{query_idx} (qid {qid}) now...")

                ranking = [(score, pid, None) for pid, score in itertools.islice(ranking, args.depth)]
                rlogger.log(qid, ranking, is_ranked=True)

    print('\n\n')
    print(ranking_logger.filename)
    print("#> Done.")
    print('\n\n')

In [18]:
from functools import partial

'''from colbert.ranking.index_part import IndexPart
from colbert.ranking.faiss_index import FaissIndex
from colbert.utils.utils import flatten, zipstar'''


class Ranker():
    """Two-stage ranking: FAISS candidate generation, then exact scoring with ``IndexPart``.

    When ``faiss_depth`` is None, no FAISS index is loaded and ``rank`` must be given ``pids``.
    """

    def __init__(self, args, inference, faiss_depth=1024):
        self.inference = inference
        self.faiss_depth = faiss_depth

        if faiss_depth is not None:
            faiss_index = FaissIndex(args.index_path, args.faiss_index_path, args.nprobe, part_range=args.part_range)
            self.faiss_index = faiss_index
            self.retrieve = partial(faiss_index.retrieve, self.faiss_depth)

        self.index = IndexPart(args.index_path, dim=inference.colbert.dim, part_range=args.part_range, verbose=True)

    def encode(self, queries):
        """Encode query strings, batching by 512 only when there are more than 512."""
        assert type(queries) in [list, tuple], type(queries)

        bsize = 512 if len(queries) > 512 else None
        return self.inference.queryFromText(queries, bsize=bsize)

    def rank(self, Q, pids=None):
        """Score ``pids`` (or FAISS candidates when None) for a single query.

        Returns ``(pids, scores)`` sorted by descending score; both are empty for no candidates.
        """
        if pids is None:
            pids = self.retrieve(Q, verbose=False)[0]

        assert type(pids) in [list, tuple], type(pids)
        assert Q.size(0) == 1, (len(pids), Q.size())
        assert all(type(pid) is int for pid in pids)

        if len(pids) == 0:
            return pids, []

        raw_scores = self.index.rank(Q.permute(0, 2, 1), pids)

        order = torch.tensor(raw_scores).sort(descending=True)
        ranked_pids = torch.tensor(pids)[order.indices].tolist()
        return ranked_pids, order.values.tolist()

In [19]:
# ranking/index_ranker.py
BSIZE = 1 << 14


class IndexRanker():
    """Exact ColBERT MaxSim scoring over a flat, packed embedding tensor.

    ``tensor`` holds every document's token embeddings back to back; document
    ``pid`` starts at row ``doclens_pfxsum[pid]``. To score many documents
    without per-document Python loops, documents are bucketed by length into a
    few "strides" (the 90th-percentile length and the maximum length). Each
    stride gets an overlapping ``as_strided`` view whose row ``i`` is the window
    ``tensor[i : i + stride]`` (assumes ``tensor`` is contiguous).
    """

    def __init__(self, tensor, doclens):
        self.tensor = tensor
        self.doclens = doclens

        self.maxsim_dtype = torch.float32
        # Prefix sums give each document's starting row in `tensor`.
        self.doclens_pfxsum = [0] + list(accumulate(self.doclens))

        self.doclens = torch.tensor(self.doclens)
        self.doclens_pfxsum = torch.tensor(self.doclens_pfxsum)

        self.dim = self.tensor.size(-1)

        self.strides = [torch_percentile(self.doclens, p) for p in [90]]
        self.strides.append(self.doclens.max().item())
        self.strides = sorted(list(set(self.strides)))

        print_message(f"#> Using strides {self.strides}..")

        self.views = self._create_views(self.tensor)

        # Allocate GPU buffers only when a GPU exists. DEVICE falls back to CPU in
        # this notebook, and allocating on 'cuda:0' would crash there.
        buffer_devices = {'cpu'}
        if torch.cuda.is_available():
            buffer_devices.add('cuda:0')
        self.buffers = self._create_buffers(BSIZE, self.tensor.dtype, buffer_devices)

    def _create_views(self, tensor):
        """Return one overlapping (outdim, stride, dim) window view per stride."""
        views = []

        for stride in self.strides:
            outdim = tensor.size(0) - stride + 1
            # Row i of the view aliases rows i .. i+stride-1 of `tensor` (no copy).
            view = torch.as_strided(tensor, (outdim, stride, self.dim), (self.dim, self.dim, 1))
            views.append(view)

        return views

    def _create_buffers(self, max_bsize, dtype, devices):
        """Pre-allocate, per device, one (max_bsize, stride, dim) buffer per stride."""
        buffers = {}

        for device in devices:
            # Pinned host memory speeds up host->GPU copies, but pinning requires CUDA.
            pin = (device == 'cpu') and torch.cuda.is_available()
            buffers[device] = [torch.zeros(max_bsize, stride, self.dim, dtype=dtype,
                                           device=device, pin_memory=pin)
                               for stride in self.strides]

        return buffers

    def rank(self, Q, pids, views=None, shift=0):
        """Return MaxSim scores for ``pids``, in the same order as ``pids``.

        Args:
            Q: query embeddings with ``Q.size(0)`` equal to 1 (shared query) or
               ``len(pids)`` (one query per pair); used as the right operand of ``D @ Q``.
            pids: list or 1-D tensor of row-local document ids.
            views: stride views to gather from (defaults to ``self.views``).
            shift: row offset of ``views`` within the full tensor (used by batch_rank).
        """
        assert len(pids) > 0
        assert Q.size(0) in [1, len(pids)]

        Q = Q.contiguous().to(DEVICE).to(dtype=self.maxsim_dtype)

        views = self.views if views is None else views
        VIEWS_DEVICE = views[0].device

        D_buffers = self.buffers[str(VIEWS_DEVICE)]

        raw_pids = pids if type(pids) is list else pids.tolist()
        pids = torch.tensor(pids) if type(pids) is list else pids

        doclens, offsets = self.doclens[pids], self.doclens_pfxsum[pids]

        # Index of the smallest stride that is >= each document's length.
        assignments = (doclens.unsqueeze(1) > torch.tensor(self.strides).unsqueeze(0) + 1e-6).sum(-1)

        one_to_n = torch.arange(len(raw_pids))
        output_pids, output_scores, output_permutation = [], [], []

        for group_idx, stride in enumerate(self.strides):
            locator = (assignments == group_idx)

            if locator.sum() < 1e-5:
                continue

            group_pids, group_doclens, group_offsets = pids[locator], doclens[locator], offsets[locator]
            group_Q = Q if Q.size(0) == 1 else Q[locator]

            group_offsets = group_offsets.to(VIEWS_DEVICE) - shift
            # Gather each distinct consecutive start row once, then expand back.
            group_offsets_uniq, group_offsets_expand = torch.unique_consecutive(group_offsets, return_inverse=True)

            D_size = group_offsets_uniq.size(0)
            D = torch.index_select(views[group_idx], 0, group_offsets_uniq, out=D_buffers[group_idx][:D_size])
            D = D.to(DEVICE)
            D = D[group_offsets_expand.to(DEVICE)].to(dtype=self.maxsim_dtype)

            # Zero out window rows past each document's true length.
            mask = torch.arange(stride, device=DEVICE) + 1
            mask = mask.unsqueeze(0) <= group_doclens.to(DEVICE).unsqueeze(-1)

            scores = (D @ group_Q) * mask.unsqueeze(-1)
            scores = scores.max(1).values.sum(-1).cpu()

            output_pids.append(group_pids)
            output_scores.append(scores)
            output_permutation.append(one_to_n[locator])

        # Undo the grouping so outputs line up with the input pid order.
        output_permutation = torch.cat(output_permutation).sort().indices
        output_pids = torch.cat(output_pids)[output_permutation].tolist()
        output_scores = torch.cat(output_scores)[output_permutation].tolist()

        assert len(raw_pids) == len(output_pids)
        assert len(raw_pids) == len(output_scores)
        assert raw_pids == output_pids

        return output_scores

    def batch_rank(self, all_query_embeddings, all_query_indexes, all_pids, sorted_pids):
        """Score a large set of (query, pid) pairs in 50k-document sub-ranges.

        ``all_pids`` must be sorted ascending (``sorted_pids`` must be True). For
        each sub-range, the relevant slice of the embedding tensor is moved to
        DEVICE once and scored in chunks of BSIZE pairs. Returns a flat list of
        scores in the order of ``all_pids``.
        """
        assert sorted_pids is True

        ######

        scores = []
        range_start, range_end = 0, 0

        for pid_offset in range(0, len(self.doclens), 50_000):
            pid_endpos = min(pid_offset + 50_000, len(self.doclens))

            # Advance the [range_start, range_end) window over the sorted pids.
            range_start = range_start + (all_pids[range_start:] < pid_offset).sum()
            range_end = range_end + (all_pids[range_end:] < pid_endpos).sum()

            pids = all_pids[range_start:range_end]
            query_indexes = all_query_indexes[range_start:range_end]

            print_message(f"###--> Got {len(pids)} query--passage pairs in this sub-range {(pid_offset, pid_endpos)}.")

            if len(pids) == 0:
                continue

            print_message(f"###--> Ranking in batches the pairs #{range_start} through #{range_end} in this sub-range.")

            # +512 rows of slack so stride windows near the end stay in bounds.
            tensor_offset = self.doclens_pfxsum[pid_offset].item()
            tensor_endpos = self.doclens_pfxsum[pid_endpos].item() + 512

            collection = self.tensor[tensor_offset:tensor_endpos].to(DEVICE)
            views = self._create_views(collection)

            print_message(f"#> Ranking in batches of {BSIZE} query--passage pairs...")

            for batch_idx, offset in enumerate(range(0, len(pids), BSIZE)):
                if batch_idx % 100 == 0:
                    print_message("#> Processing batch #{}..".format(batch_idx))

                endpos = offset + BSIZE
                batch_query_index, batch_pids = query_indexes[offset:endpos], pids[offset:endpos]

                Q = all_query_embeddings[batch_query_index]

                scores.extend(self.rank(Q, batch_pids, views, shift=tensor_offset))

        return scores


def torch_percentile(tensor, p):
    """Return the p-th percentile (nearest-rank, lower) of a 1-D tensor as a Python scalar.

    Args:
        tensor: non-empty 1-D tensor.
        p: percentile in 1..100.

    The rank ``int(p * n / 100)`` is clamped to at least 1. Without the clamp,
    short inputs (e.g. a single document with p=90) produce rank 0, which
    ``kthvalue`` rejects.
    """
    assert p in range(1, 100+1)
    assert tensor.dim() == 1
    assert tensor.size(0) > 0, "torch_percentile needs a non-empty tensor"

    k = max(1, int(p * tensor.size(0) / 100.0))
    return tensor.kthvalue(k).values.item()


In [20]:
#ranking/index_part
'''from math import ceil
from itertools import accumulate
from colbert.utils.utils import print_message, dotdict, flatten

from colbert.indexing.loaders import get_parts, load_doclens
from colbert.indexing.index_manager import load_index_part
from colbert.ranking.index_ranker import IndexRanker'''


class IndexPart():
    """A contiguous range of index parts loaded into memory for exact ranking.

    Global passage ids in ``pids_range`` map to local rows by subtracting
    ``doc_offset``; scoring is delegated to an ``IndexRanker`` over the loaded
    float16 embedding tensor.
    """

    def __init__(self, directory, dim=128, part_range=None, verbose=True):
        if part_range is None:
            first_part, last_part = 0, None
        else:
            first_part, last_part = part_range.start, part_range.stop
        part_window = slice(first_part, last_part)

        # Which part files to load.
        all_parts, all_parts_paths, _ = get_parts(directory)
        self.parts = all_parts[part_window]
        self.parts_paths = all_parts_paths[part_window]

        # Per-part document lengths determine the global pid window.
        all_doclens = load_doclens(directory, flatten=False)

        self.doc_offset = sum(len(lens) for lens in all_doclens[:first_part])
        self.doc_endpos = sum(len(lens) for lens in all_doclens[:last_part])
        self.pids_range = range(self.doc_offset, self.doc_endpos)

        self.parts_doclens = all_doclens[part_window]
        self.doclens = flatten(self.parts_doclens)
        self.num_embeddings = sum(self.doclens)

        self.tensor = self._load_parts(dim, verbose)
        self.ranker = IndexRanker(self.tensor, self.doclens)

    def _load_parts(self, dim, verbose):
        """Concatenate every selected part into one tensor, plus 512 zero rows of padding."""
        tensor = torch.zeros(self.num_embeddings + 512, dim, dtype=torch.float16)

        if verbose:
            print_message("tensor.size() = ", tensor.size())

        start = 0
        for part_idx, filename in enumerate(self.parts_paths):
            print_message("|> Loading", filename, "...", condition=verbose)

            stop = start + sum(self.parts_doclens[part_idx])
            part = load_index_part(filename, verbose=verbose)

            tensor[start:stop] = part
            start = stop

        return tensor

    def pid_in_range(self, pid):
        """True when ``pid`` is one of the global passage ids held by this part range."""
        return pid in self.pids_range

    def rank(self, Q, pids):
        """
        Rank a single batch of Q x pids (e.g., 1k--10k pairs).
        """
        assert Q.size(0) in [1, len(pids)], (Q.size(0), len(pids))
        assert all(pid in self.pids_range for pid in pids), self.pids_range

        local_pids = [pid - self.doc_offset for pid in pids]
        return self.ranker.rank(Q, local_pids)

    def batch_rank(self, all_query_embeddings, query_indexes, pids, sorted_pids):
        """
        Rank a large, fairly dense set of query--passage pairs (e.g., 1M+ pairs).
        Higher overhead, much faster for large batches.
        """
        lower, upper = self.pids_range.start, self.pids_range.stop
        inside = (pids >= lower) & (pids < upper)
        assert inside.sum() == pids.size(0)

        local_pids = pids - self.doc_offset
        return self.ranker.batch_rank(all_query_embeddings, query_indexes, local_pids, sorted_pids)

In [21]:
# utils/utils.py
def print_message(*s, condition=True):
    """Build a "[<timestamp>] <args joined by spaces>" line.

    The line is printed (flushed) when ``condition`` is true, and returned either way.
    """
    body = ' '.join(map(str, s))
    stamp = datetime.datetime.now().strftime("%b %d, %H:%M:%S")
    msg = "[{}] {}".format(stamp, body)

    if condition:
        print(msg, flush=True)

    return msg
def timestamp():
    """Current local time formatted as YYYY-MM-DD_HH.MM.SS (safe for directory names)."""
    now = datetime.datetime.now()
    return now.strftime("%Y-%m-%d_%H.%M.%S")

def create_directory(path):
    """Create ``path`` (with parents) unless it already exists; log which case applied."""
    already_there = os.path.exists(path)
    print('\n')

    if not already_there:
        print_message("#> Creating directory", path, '\n\n')
        os.makedirs(path)
        return

    print_message("#> Note: Output directory", path, 'already exists\n\n')
        
def distributed_init(rank):
    """Read WORLD_SIZE and, if it is > 1, set up NCCL distributed training.

    Returns:
        (nranks, is_distributed): nranks is at least 1.
    """
    world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 0
    nranks = max(1, world_size)
    is_distributed = nranks > 1

    if rank == 0:
        print('nranks =', nranks, '\t num_gpus =', torch.cuda.device_count())

    if is_distributed:
        gpu_count = torch.cuda.device_count()
        # Map each rank onto a local GPU round-robin.
        torch.cuda.set_device(rank % gpu_count)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    return nranks, is_distributed

def distributed_barrier(rank):
    """Synchronize all ranks, but only when a process group actually exists.

    ``rank`` < 0 means non-distributed. A non-negative rank without an
    initialized process group (e.g. ``--local_rank 0`` with no WORLD_SIZE) is a
    single process, so there is nothing to wait for. Previously
    ``torch.distributed.barrier()`` raised in that case.
    """
    if rank < 0:
        return

    dist = torch.distributed
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
        
def load_checkpoint(path, model, optimizer=None, do_print=True):
    """Load a training checkpoint into ``model`` (and optionally ``optimizer``).

    Args:
        path: local file path, or an http(s) URL fetched via
            ``torch.hub.load_state_dict_from_url``.
        model: module that receives ``checkpoint['model_state_dict']``. A leading
            ``'module.'`` prefix (left by DataParallel/DDP wrappers) is stripped
            from every key first.
        optimizer: if not None, receives ``checkpoint['optimizer_state_dict']``.
        do_print: log progress and the stored epoch/batch counters.

    Returns:
        The checkpoint dict, with ``model_state_dict`` keys un-prefixed.

    Security: ``torch.load`` unpickles the file and can run arbitrary code, so
    only load checkpoints from trusted sources.
    """
    if do_print:
        print_message("#> Loading checkpoint", path, "..")

    if path.startswith("http:") or path.startswith("https:"):
        checkpoint = torch.hub.load_state_dict_from_url(path, map_location='cpu')
    else:
        # SECURITY: pickle-based load; never point this at untrusted files.
        checkpoint = torch.load(path, map_location='cpu')

    prefix = 'module.'
    checkpoint['model_state_dict'] = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in checkpoint['model_state_dict'].items()
    )

    try:
        model.load_state_dict(checkpoint['model_state_dict'])
    except RuntimeError:
        # load_state_dict raises RuntimeError on missing/unexpected keys. Retry
        # leniently, but do not swallow unrelated errors such as KeyboardInterrupt.
        print_message("[WARNING] Loading checkpoint with strict=False")
        model.load_state_dict(checkpoint['model_state_dict'], strict=False)

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if do_print:
        # .get: checkpoints saved without these counters should still load.
        print_message("#> checkpoint['epoch'] =", checkpoint.get('epoch'))
        print_message("#> checkpoint['batch'] =", checkpoint.get('batch'))

    return checkpoint
class NullContextManager(object):
    """No-op context manager.

    ``with`` yields ``dummy_resource`` and lets any exception propagate.
    """

    def __init__(self, dummy_resource=None):
        self.dummy_resource = dummy_resource

    def __enter__(self):
        resource = self.dummy_resource
        return resource

    def __exit__(self, *exc_info):
        # Returning None (falsy) means exceptions are never suppressed.
        return None
    
def grouper(iterable, n, fillvalue=None):
    """
    Collect data into fixed-length chunks or blocks.
        Example: grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
        Source: https://docs.python.org/3/library/itertools.html#itertools-recipes

    Returns a lazy iterator of n-tuples; the last tuple is padded with ``fillvalue``.
    """
    # The notebook's import cell only pulls `accumulate` from itertools, so the
    # module itself is bound locally here.
    import itertools

    # n references to the same iterator, so zip_longest pulls n consecutive items per tuple.
    args = [iter(iterable)] * n
    return itertools.zip_longest(*args, fillvalue=fillvalue)
def batch(group, bsize, provide_offset=False):
    """Yield consecutive slices of ``group`` with at most ``bsize`` items each.

    Args:
        group: any sliceable sequence (list, tuple, str, tensor, ...).
        bsize: positive slice length.
        provide_offset: if true, yield ``(offset, slice)`` pairs instead of slices.

    Raises:
        ValueError: when iteration starts and ``bsize`` < 1. Without this check
        the offset never advances and the generator loops forever.
    """
    if bsize < 1:
        raise ValueError(f"bsize must be a positive integer, got {bsize!r}")

    offset = 0
    while offset < len(group):
        L = group[offset: offset + bsize]
        yield ((offset, L) if provide_offset else L)
        offset += len(L)
class dotdict(dict):
    """
    dot.notation access to dictionary attributes
    Credit: derek73 @ https://stackoverflow.com/questions/2352181

    Missing keys raise AttributeError, not KeyError, on attribute access. This
    keeps ``hasattr``, ``getattr(obj, name, default)`` and ``copy.deepcopy``
    working: deepcopy probes for ``__deepcopy__`` via getattr, and a KeyError
    there used to crash it.
    """

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError as exc:
            raise AttributeError(key) from exc

    __setattr__ = dict.__setitem__

    def __delattr__(self, key):
        try:
            del self[key]
        except KeyError as exc:
            raise AttributeError(key) from exc


def flatten(L):
    """Concatenate the items of each inner iterable of ``L`` into a single list."""
    flat = []
    for inner in L:
        flat.extend(inner)
    return flat
def zipstar(L, lazy=False):
    """
    A much faster A, B, C = zip(*[(a, b, c), (a, b, c), ...])
    May return lists or tuples.

    An empty input is returned unchanged. Rows narrower than 100 are transposed
    into a list of lists. Wider rows go through ``zip``, returned as the lazy
    zip object when ``lazy`` is true, else as a list of tuples.
    """
    if not len(L):
        return L

    width = len(L[0])

    if width >= 100:
        transposed = zip(*L)
        return transposed if lazy else list(transposed)

    return [[row[col] for row in L] for col in range(width)]

In [22]:
# colbert/utils/amp.py
'''
from contextlib import contextmanager
from colbert.utils.utils import NullContextManager'''
from packaging import version

# AMP (torch.cuda.amp) needs PyTorch >= 1.6; MixedPrecisionManager asserts on this flag.
v = version.parse
PyTorch_over_1_6 = v('1.6') <= v(torch.__version__)

class MixedPrecisionManager():
    """Switch between torch.cuda.amp mixed precision and plain fp32 training.

    When ``activated``, a GradScaler scales the loss for backward and
    unscales gradients before clipping. In both modes gradients are clipped to
    norm 2.0 and reset after each optimizer step.
    """

    def __init__(self, activated):
        assert (not activated) or PyTorch_over_1_6, "Cannot use AMP for PyTorch version < 1.6"

        self.activated = activated

        if activated:
            self.scaler = torch.cuda.amp.GradScaler()

    def context(self):
        """autocast context under AMP, otherwise a no-op context manager."""
        if self.activated:
            return torch.cuda.amp.autocast()
        return NullContextManager()

    def backward(self, loss):
        """Backpropagate ``loss``, scaling it first when AMP is active."""
        target = self.scaler.scale(loss) if self.activated else loss
        target.backward()

    def step(self, kolbert, optimizer):
        """Clip gradients of ``kolbert`` to norm 2.0, step ``optimizer``, then zero grads."""
        if self.activated:
            # Clipping must see true (unscaled) gradient magnitudes.
            self.scaler.unscale_(optimizer)

        torch.nn.utils.clip_grad_norm_(kolbert.parameters(), 2.0)

        if self.activated:
            self.scaler.step(optimizer)
            self.scaler.update()
        else:
            optimizer.step()

        optimizer.zero_grad()

In [23]:
# utils/logging.py
class Logger():
    """Per-rank logging front-end for a run.

    Writes mlflow params/metrics/artifacts, TensorBoard scalars, and text files
    under ``<run.path>/logs/``. Only the main process (rank -1 or 0) initializes
    mlflow and creates the logs directory; most methods are no-ops on other ranks.
    """

    def __init__(self, rank, run):
        self.rank = rank
        self.is_main = self.rank in [-1, 0]
        self.run = run
        self.logs_path = os.path.join(self.run.path, "logs/")
        # Defined on every rank so the attribute always exists; the writer is
        # created lazily on the first log_metric call.
        self.initialized_tensorboard = False

        if self.is_main:
            self._init_mlflow()
            create_directory(self.logs_path)

    def _init_mlflow(self):
        """Point mlflow at <experiments_root>/logs/mlruns and tag the current run."""
        mlflow.set_tracking_uri('file://' + os.path.join(self.run.experiments_root, "logs/mlruns/"))
        mlflow.set_experiment('/'.join([self.run.experiment, self.run.script]))

        mlflow.set_tag('experiment', self.run.experiment)
        mlflow.set_tag('name', self.run.name)
        mlflow.set_tag('path', self.run.path)

    def _init_tensorboard(self):
        """Create the SummaryWriter under <experiments_root>/logs/tensorboard/."""
        # Imported lazily: the notebook's import cell never brings SummaryWriter
        # into scope, and TensorBoard is only needed once a metric is logged.
        from torch.utils.tensorboard import SummaryWriter

        root = os.path.join(self.run.experiments_root, "logs/tensorboard/")
        logdir = '__'.join([self.run.experiment, self.run.script, self.run.name])
        logdir = os.path.join(root, logdir)

        self.writer = SummaryWriter(log_dir=logdir)
        self.initialized_tensorboard = True

    def _log_exception(self, etype, value, tb):
        """Print a formatted traceback and save it as the exception.txt artifact (main rank only)."""
        if not self.is_main:
            return

        output_path = os.path.join(self.logs_path, 'exception.txt')
        trace = ''.join(traceback.format_exception(etype, value, tb)) + '\n'
        print_message(trace, '\n\n')

        self.log_new_artifact(output_path, trace)

    def _log_all_artifacts(self):
        """Upload the whole logs directory to mlflow (main rank only)."""
        if not self.is_main:
            return

        mlflow.log_artifacts(self.logs_path)

    def _log_args(self, args):
        """Log scalar args as mlflow params and dump args.json / args.txt (main rank only)."""
        if not self.is_main:
            return

        for key in vars(args):
            value = getattr(args, key)
            if type(value) in [int, float, str, bool]:
                mlflow.log_param(key, value)

        with open(os.path.join(self.logs_path, 'args.json'), 'w') as output_metadata:
            ujson.dump(args.input_arguments.__dict__, output_metadata, indent=4)
            output_metadata.write('\n')

        with open(os.path.join(self.logs_path, 'args.txt'), 'w') as output_metadata:
            output_metadata.write(' '.join(sys.argv) + '\n')

    def log_metric(self, name, value, step, log_to_mlflow=True):
        """Record a scalar in TensorBoard and, by default, mlflow (main rank only)."""
        if not self.is_main:
            return

        if not self.initialized_tensorboard:
            self._init_tensorboard()

        if log_to_mlflow:
            mlflow.log_metric(name, value, step=step)
        self.writer.add_scalar(name, value, step)

    def log_new_artifact(self, path, content):
        """Write ``content`` to ``path`` (overwriting) and log that file to mlflow."""
        with open(path, 'w') as f:
            f.write(content)

        mlflow.log_artifact(path)

    def warn(self, *args):
        """Print a [WARNING] line and append it to logs/warnings.txt."""
        msg = print_message('[WARNING]', '\t', *args)

        with open(os.path.join(self.logs_path, 'warnings.txt'), 'a') as output_metadata:
            output_metadata.write(msg + '\n\n\n')

    def info_all(self, *args):
        """Print a message prefixed with this process's rank, on every rank."""
        print_message('[' + str(self.rank) + ']', '\t', *args)

    def info(self, *args):
        """Print a message on the main rank only."""
        if self.is_main:
            print_message(*args)

In [24]:
# utils/runs.py
'''
import colbert.utils.distributed as distributed

from contextlib import contextmanager
from colbert.utils.logging import Logger
from colbert.utils.utils import timestamp, create_directory, print_message
'''
class _RunManager():
    def __init__(self):
        self.experiments_root = None
        self.experiment = None
        self.path = None
        self.script = self._get_script_name()
        self.name = self._generate_default_run_name()
        self.original_name = self.name
        self.exit_status = 'FINISHED'

        self._logger = None
        self.start_time = time.time()

    def init(self, rank, root, experiment, name):
        assert '/' not in experiment, experiment
        assert '/' not in name, name

        self.experiments_root = os.path.abspath(root)
        self.experiment = experiment
        self.name = name
        self.path = os.path.join(self.experiments_root, self.experiment, self.script, self.name)

        if rank < 1:
            if os.path.exists(self.path):
                print('\n\n')
                print_message("It seems that ", self.path, " already exists.")
                print_message("Do you want to overwrite it? \t yes/no \n")

                # TODO: This should timeout and exit (i.e., fail) given no response for 60 seconds.

                response = input()
                if response.strip() != 'yes':
                    assert not os.path.exists(self.path), self.path
            else:
                create_directory(self.path)

        distributed_barrier(rank)

        self._logger = Logger(rank, self)
        self._log_args = self._logger._log_args
        self.warn = self._logger.warn
        self.info = self._logger.info
        self.info_all = self._logger.info_all
        self.log_metric = self._logger.log_metric
        self.log_new_artifact = self._logger.log_new_artifact

    def _generate_default_run_name(self):
        return timestamp()

    def _get_script_name(self):
        return os.path.basename('main'.__file__) if '__file__' in dir('main') else 'none'
#         return os.path.basename(__main__.__file__) if '__file__' in dir(__main__) else 'none'

    @contextmanager
    def context(self, consider_failed_if_interrupted=True):
        try:
            yield

        except KeyboardInterrupt as ex:
            print('\n\nInterrupted\n\n')
            self._logger._log_exception(ex.__class__, ex, ex.__traceback__)
            self._logger._log_all_artifacts()

            if consider_failed_if_interrupted:
                self.exit_status = 'KILLED'  # mlflow.entities.RunStatus.KILLED

            sys.exit(128 + 2)

        except Exception as ex:
            self._logger._log_exception(ex.__class__, ex, ex.__traceback__)
            self._logger._log_all_artifacts()

            self.exit_status = 'FAILED'  # mlflow.entities.RunStatus.FAILED

            raise ex

        finally:
            total_seconds = str(time.time() - self.start_time) + '\n'
            original_name = str(self.original_name)
            name = str(self.name)

            self.log_new_artifact(os.path.join(self._logger.logs_path, 'elapsed.txt'), total_seconds)
            self.log_new_artifact(os.path.join(self._logger.logs_path, 'name.original.txt'), original_name)
            self.log_new_artifact(os.path.join(self._logger.logs_path, 'name.txt'), name)

            self._logger._log_all_artifacts()

            mlflow.end_run(status=self.exit_status)


# Module-level run singleton. Arguments uses Run.name as the default --run
# value and Arguments.parse() calls Run.init(); main() wraps its work in
# Run.context(). Constructing it fixes the default run name and the start time.
Run = _RunManager()

In [25]:
# utils/parser.py
class Arguments():
    """Wrapper around ``argparse.ArgumentParser`` with ColBERT argument groups.

    ``parse()`` also initializes distributed state, caps FAISS threads per
    process, and starts the global ``Run`` (run directory and argument logging).
    """

    def __init__(self, description):
        self.parser = ArgumentParser(description=description)
        # Callables applied to the parsed namespace by check_arguments().
        self.checks = []

        self.add_argument('--root', dest='root', default='/')
        self.add_argument('--experiment', dest='experiment', default='MSMARCO-psg')
        self.add_argument('--run', dest='run', default=Run.name)

        self.add_argument('--local_rank', dest='rank', default=-1, type=int)

    def add_model_parameters(self):
        # Core Arguments
        self.add_argument('--similarity', dest='similarity', default='cosine', choices=['cosine', 'l2'])
        self.add_argument('--dim', dest='dim', default=128, type=int)
        self.add_argument('--query_maxlen', dest='query_maxlen', default=32, type=int)
        self.add_argument('--doc_maxlen', dest='doc_maxlen', default=180, type=int)

        # Filtering-related Arguments
        # Masking is on by default. With only a store_true flag and default=True
        # it could never be disabled, so --no-mask-punctuation turns it off.
        self.add_argument('--mask-punctuation', dest='mask_punctuation', default=True, action='store_true')
        self.add_argument('--no-mask-punctuation', dest='mask_punctuation', action='store_false')

    def add_index_use_input(self):
        self.add_argument('--index_root', dest='index_root', default='/')
        self.add_argument('--index_name', dest='index_name', default='MSMARCO.L2.32x200k')
        self.add_argument('--partitions', dest='partitions', default=32768, type=int)

    def add_model_inference_parameters(self):
        self.add_argument('--checkpoint', dest='checkpoint', default='/', type=str)
        self.add_argument('--bsize', dest='bsize', default=256, type=int)
        # Off unless --amp is passed. The previous default of 180 (copied from
        # doc_maxlen) made args.amp truthy even without the flag.
        self.add_argument('--amp', dest='amp', default=False, action='store_true')

    def add_ranking_input(self):
        self.add_argument('--queries', dest='queries', default='/')
        self.add_argument('--collection', dest='collection', default=None)
        self.add_argument('--qrels', dest='qrels', default=None)

    def add_retrieval_input(self):
        self.add_index_use_input()
        self.add_argument('--nprobe', dest='nprobe', default=32, type=int)
        self.add_argument('--retrieve_only', dest='retrieve_only', default=False, action='store_true')

    def add_argument(self, *args, **kw_args):
        return self.parser.add_argument(*args, **kw_args)

    def check_arguments(self, args):
        """Run every registered check callable against the parsed namespace."""
        for check in self.checks:
            check(args)

    def parse(self):
        """Parse sys.argv, set up distributed/FAISS threading, and start the Run."""
        args = self.parser.parse_args()
        self.check_arguments(args)

        # Snapshot of the user-supplied arguments before derived fields are added.
        args.input_arguments = copy.deepcopy(args)

        args.nranks, args.distributed = distributed_init(args.rank)

        # os.cpu_count() returns None when the count cannot be determined.
        cpu_count = os.cpu_count() or 1
        args.nthreads = int(max(cpu_count, faiss.omp_get_max_threads()) * 0.8)
        args.nthreads = max(1, args.nthreads // args.nranks)

        if args.nranks > 1:
            print_message(f"#> Restricting number of threads for FAISS to {args.nthreads} per process",
                          condition=(args.rank == 0))
            faiss.omp_set_num_threads(args.nthreads)

        Run.init(args.rank, args.root, args.experiment, args.run)
        Run._log_args(args)
        Run.info(args.input_arguments.__dict__, '\n')

        return args

In [None]:
# retrieve.py
'''from colbert.utils.parser import Arguments
from colbert.utils.runs import Run

from colbert.evaluation.loaders import load_colbert, load_qrels, load_queries
from colbert.indexing.faiss import get_faiss_index_name
from colbert.ranking.retrieval import retrieve
from colbert.ranking.batch_retrieval import batch_retrieve'''
def main():
    """End-to-end ColBERT retrieval: parse args, load model/data, then run retrieval.

    Uses batch_retrieve when --batch is given, otherwise retrieve. The FAISS
    index path comes from --faiss_name if set, else from get_faiss_index_name.
    """
    random.seed(12345)

    parser = Arguments(description='End-to-end retrieval and ranking with ColBERT.')

    for add_group in (parser.add_model_parameters,
                      parser.add_model_inference_parameters,
                      parser.add_ranking_input,
                      parser.add_retrieval_input):
        add_group()

    parser.add_argument('--faiss_name', dest='faiss_name', default=None, type=str)
    parser.add_argument('--faiss_depth', dest='faiss_depth', default=1024, type=int)
    parser.add_argument('--part-range', dest='part_range', default=None, type=str)
    parser.add_argument('--batch', dest='batch', default=False, action='store_true')
    parser.add_argument('--depth', dest='depth', default=1000, type=int)

    # Absorbs the "-f <kernel.json>" argument a Jupyter kernel puts on sys.argv.
    parser.add_argument('-f')

    args = parser.parse()

    # A non-positive depth means "no limit".
    if not args.depth > 0:
        args.depth = None

    if args.part_range:
        first_part, last_part = (int(bound) for bound in args.part_range.split('..'))
        args.part_range = range(first_part, last_part)

    with Run.context():
        args.colbert, args.checkpoint = load_colbert(args)
        args.qrels = load_qrels(args.qrels)
        args.queries = load_queries(args.queries)

        args.index_path = os.path.join(args.index_root, args.index_name)

        faiss_file = args.faiss_name if args.faiss_name is not None else get_faiss_index_name(args)
        args.faiss_index_path = os.path.join(args.index_path, faiss_file)

        run_retrieval = batch_retrieve if args.batch else retrieve
        run_retrieval(args)


# Script entry point. NOTE(review): in a Jupyter kernel the cell namespace's
# __name__ is normally "__main__" too, so running this cell starts retrieval
# (main() defines '-f' so the kernel's own argv parses).
if __name__ == "__main__":
    main()