In [33]:
import os
import random
import time
import datetime
import sys
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import mlflow
import ujson
from itertools import accumulate
from math import ceil
import traceback
import string

from contextlib import contextmanager

from packaging import version

import copy
import faiss


from argparse import ArgumentParser

from collections import defaultdict, OrderedDict

from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast

In [None]:
# Select the compute device once for the whole notebook: the first CUDA
# device when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")
print(DEVICE)

## utils

In [6]:
# colbert/utils/utils.py

def print_message(*s, condition=True):
    """Build a timestamped message from the given values and optionally print it.

    Args:
        *s: Values to show; each is converted with ``str`` and the pieces are
            joined with single spaces.
        condition: When truthy the message is printed (flushed immediately).

    Returns:
        The formatted message, ``"[<Mon DD, HH:MM:SS>] <text>"``, whether or
        not it was printed.
    """
    text = ' '.join(map(str, s))
    stamp = datetime.datetime.now().strftime("%b %d, %H:%M:%S")
    msg = "[{}] {}".format(stamp, text)

    if condition:
        print(msg, flush=True)

    return msg

def timestamp():
    """Return the current local time formatted as ``YYYY-MM-DD_HH.MM.SS``."""
    return datetime.datetime.now().strftime("%Y-%m-%d_%H.%M.%S")

def create_directory(path):
    """Create ``path`` (including parents) unless it already exists.

    A blank line and a timestamped note are printed in both cases, saying
    either that the directory already exists or that it is being created.
    """
    already_there = os.path.exists(path)
    print('\n')

    if already_there:
        print_message("#> Note: Output directory", path, 'already exists\n\n')
        return

    print_message("#> Creating directory", path, '\n\n')
    os.makedirs(path)
class NullContextManager(object):
    """Context manager that does nothing.

    Entering yields the ``dummy_resource`` given at construction; exiting
    performs no cleanup and never suppresses an exception.
    """

    def __init__(self, dummy_resource=None):
        self.dummy_resource = dummy_resource

    def __enter__(self):
        resource = self.dummy_resource
        return resource

    def __exit__(self, *args):
        # Returning None lets any in-flight exception propagate.
        return None
    
def load_checkpoint(path, model, optimizer=None, do_print=True):
    """Load a saved checkpoint into ``model`` (and optionally ``optimizer``).

    Args:
        path: Local file path, or an ``http:``/``https:`` URL, of the checkpoint.
        model: Module whose ``load_state_dict`` receives the stored weights.
        optimizer: If truthy, its state is restored from
            ``checkpoint['optimizer_state_dict']``.
        do_print: When true, progress messages and the stored ``epoch`` /
            ``batch`` values are printed.

    Returns:
        The loaded checkpoint dict. Its ``'model_state_dict'`` entry is
        replaced by a copy whose keys have any leading ``'module.'`` removed.

    Raises:
        KeyError: If the checkpoint lacks ``'model_state_dict'``, the
            optimizer state (when an optimizer is passed), or
            ``'epoch'`` / ``'batch'`` (when ``do_print`` is true).
        RuntimeError: If the weights cannot be loaded even with ``strict=False``.
    """
    if do_print:
        print_message("#> Loading checkpoint", path, "..")

    # NOTE(security): checkpoint files are deserialized by PyTorch; only load
    # checkpoints from sources you trust.
    if path.startswith(("http:", "https:")):
        checkpoint = torch.hub.load_state_dict_from_url(path, map_location='cpu')
    else:
        checkpoint = torch.load(path, map_location='cpu')

    # Drop the 'module.' prefix from parameter names so the keys match a bare
    # model (presumably the prefix comes from a parallel wrapper -- confirm).
    prefix = 'module.'
    new_state_dict = OrderedDict()
    for k, v in checkpoint['model_state_dict'].items():
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v

    checkpoint['model_state_dict'] = new_state_dict

    try:
        model.load_state_dict(new_state_dict)
    except RuntimeError:
        # Only key/shape mismatches (reported as RuntimeError) justify the
        # lenient retry; anything else, e.g. KeyboardInterrupt, must propagate.
        print_message("[WARNING] Loading checkpoint with strict=False")
        model.load_state_dict(new_state_dict, strict=False)

    if optimizer:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if do_print:
        print_message("#> checkpoint['epoch'] =", checkpoint['epoch'])
        print_message("#> checkpoint['batch'] =", checkpoint['batch'])

    return checkpoint

In [7]:
# colbert/utils/logging.py
'''
from torch.utils.tensorboard import SummaryWriter
from colbert.utils.utils import print_message, create_directory--
'''
from torch.utils.tensorboard import SummaryWriter

class Logger():
    """Per-run logger that writes to MLflow, TensorBoard and files under ``<run.path>/logs/``.

    Only the "main" process (rank -1 or 0) initializes MLflow and creates the
    log directory; most logging methods return immediately on other ranks.
    """

    def __init__(self, rank, run):
        self.rank = rank
        self.is_main = self.rank in [-1, 0]
        self.run = run
        self.logs_path = os.path.join(self.run.path, "logs/")

        if not self.is_main:
            return

        self._init_mlflow()
        # TensorBoard is set up lazily, on the first log_metric() call.
        self.initialized_tensorboard = False
        create_directory(self.logs_path)

    def _init_mlflow(self):
        """Point MLflow at ``<experiments_root>/logs/mlruns/`` and tag the run."""
        run = self.run
        tracking_dir = os.path.join(run.experiments_root, "logs/mlruns/")
        mlflow.set_tracking_uri('file://' + tracking_dir)
        mlflow.set_experiment('/'.join([run.experiment, run.script]))

        for tag in ('experiment', 'name', 'path'):
            mlflow.set_tag(tag, getattr(run, tag))

    def _init_tensorboard(self):
        """Create the SummaryWriter under ``<experiments_root>/logs/tensorboard/``."""
        tb_root = os.path.join(self.run.experiments_root, "logs/tensorboard/")
        run_dirname = '__'.join([self.run.experiment, self.run.script, self.run.name])

        self.writer = SummaryWriter(log_dir=os.path.join(tb_root, run_dirname))
        self.initialized_tensorboard = True

    def _log_exception(self, etype, value, tb):
        """Print a formatted traceback and store it as ``exception.txt`` (main rank only)."""
        if self.is_main:
            output_path = os.path.join(self.logs_path, 'exception.txt')
            trace = ''.join(traceback.format_exception(etype, value, tb)) + '\n'
            print_message(trace, '\n\n')

            self.log_new_artifact(output_path, trace)

    def _log_all_artifacts(self):
        """Upload everything in the logs directory to MLflow (main rank only)."""
        if self.is_main:
            mlflow.log_artifacts(self.logs_path)

    def _log_args(self, args):
        """Record run arguments in MLflow and in ``args.json`` / ``args.txt`` (main rank only).

        Only attributes whose exact type is int, float, str or bool are sent
        to MLflow as parameters.
        """
        if not self.is_main:
            return

        loggable_types = (int, float, str, bool)
        for key in vars(args):
            value = getattr(args, key)
            if type(value) in loggable_types:
                mlflow.log_param(key, value)

        json_path = os.path.join(self.logs_path, 'args.json')
        with open(json_path, 'w') as output_metadata:
            ujson.dump(args.input_arguments.__dict__, output_metadata, indent=4)
            output_metadata.write('\n')

        txt_path = os.path.join(self.logs_path, 'args.txt')
        with open(txt_path, 'w') as output_metadata:
            output_metadata.write(' '.join(sys.argv) + '\n')

    def log_metric(self, name, value, step, log_to_mlflow=True):
        """Write a scalar to TensorBoard and, unless disabled, to MLflow (main rank only)."""
        if not self.is_main:
            return

        if not self.initialized_tensorboard:
            self._init_tensorboard()

        if log_to_mlflow:
            mlflow.log_metric(name, value, step=step)

        self.writer.add_scalar(name, value, step)

    def log_new_artifact(self, path, content):
        """Write ``content`` to ``path`` (overwriting) and register the file with MLflow."""
        with open(path, 'w') as f:
            f.write(content)

        mlflow.log_artifact(path)

    def warn(self, *args):
        """Print a ``[WARNING]`` message and append it to ``warnings.txt``."""
        msg = print_message('[WARNING]', '\t', *args)

        warnings_path = os.path.join(self.logs_path, 'warnings.txt')
        with open(warnings_path, 'a') as output_metadata:
            output_metadata.write(msg + '\n\n\n')

    def info_all(self, *args):
        """Print a message on every rank, prefixed with this process's rank."""
        print_message(f'[{self.rank}]', '\t', *args)

    def info(self, *args):
        """Print a message on the main rank only."""
        if not self.is_main:
            return
        print_message(*args)

In [8]:
# colbert/utils/distributed.py
def distributed_init(rank):
    """Work out the number of ranks and, if more than one, join the process group.

    The rank count is read from the ``WORLD_SIZE`` environment variable and
    clamped to at least 1 (a missing variable counts as a single rank).

    Args:
        rank: Rank of this process; rank 0 prints a short summary.

    Returns:
        Tuple ``(nranks, is_distributed)`` where ``is_distributed`` is
        ``nranks > 1``.
    """
    world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 0
    nranks = max(1, world_size)
    is_distributed = nranks > 1

    if rank == 0:
        print('nranks =', nranks, '\t num_gpus =', torch.cuda.device_count())

    if not is_distributed:
        return nranks, is_distributed

    # Bind this process to a GPU chosen round-robin by rank before joining the group.
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')

    return nranks, is_distributed
def distributed_barrier(rank):
    """Call ``torch.distributed.barrier()`` when ``rank`` is non-negative; otherwise do nothing."""
    if not (rank >= 0):
        return
    torch.distributed.barrier()

In [9]:
# colbert/utils/runs.py
'''
import colbert.utils.distributed as distributed--
'''
'''
from colbert.utils.logging import Logger--
from colbert.utils.utils import timestamp, create_directory, print_message--
'''

class _RunManager():
    def __init__(self):
        self.experiments_root = None
        self.experiment = None
        self.path = None
        self.script = self._get_script_name()
        self.name = self._generate_default_run_name()
        self.original_name = self.name
        self.exit_status = 'FINISHED'

        self._logger = None
        self.start_time = time.time()

    def init(self, rank, root, experiment, name):
        assert '/' not in experiment, experiment
        assert '/' not in name, name

        self.experiments_root = os.path.abspath(root)
        self.experiment = experiment
        self.name = name
        self.path = os.path.join(self.experiments_root, self.experiment, self.script, self.name)

        if rank < 1:
            if os.path.exists(self.path):
                print('\n\n')
                print_message("It seems that ", self.path, " already exists.")
                print_message("Do you want to overwrite it? \t yes/no \n")

                # TODO: This should timeout and exit (i.e., fail) given no response for 60 seconds.

                response = input()
                if response.strip() != 'yes':
                    assert not os.path.exists(self.path), self.path
            else:
                create_directory(self.path)


        if rank >= 0:
            torch.distributed_barrier()

        self._logger = Logger(rank, self)
        self._log_args = self._logger._log_args
        self.warn = self._logger.warn
        self.info = self._logger.info
        self.info_all = self._logger.info_all
        self.log_metric = self._logger.log_metric
        self.log_new_artifact = self._logger.log_new_artifact

    def _generate_default_run_name(self):
        return timestamp()

    def _get_script_name(self):
        #return os.path.basename(__main__.__file__) if '__file__' in dir(__main__) else 'none'
        return os.path.basename('main'.__file__) if '__file__' in dir('main') else 'none'

    @contextmanager
    def context(self, consider_failed_if_interrupted=True):
        try:
            yield

        except KeyboardInterrupt as ex:
            print('\n\nInterrupted\n\n')
            self._logger._log_exception(ex.__class__, ex, ex.__traceback__)
            self._logger._log_all_artifacts()

            if consider_failed_if_interrupted:
                self.exit_status = 'KILLED'  # mlflow.entities.RunStatus.KILLED

            sys.exit(128 + 2)

        except Exception as ex:
            self._logger._log_exception(ex.__class__, ex, ex.__traceback__)
            self._logger._log_all_artifacts()

            self.exit_status = 'FAILED'  # mlflow.entities.RunStatus.FAILED

            raise ex

        finally:
            total_seconds = str(time.time() - self.start_time) + '\n'
            original_name = str(self.original_name)
            name = str(self.name)

            self.log_new_artifact(os.path.join(self._logger.logs_path, 'elapsed.txt'), total_seconds)
            self.log_new_artifact(os.path.join(self._logger.logs_path, 'name.original.txt'), original_name)
            self.log_new_artifact(os.path.join(self._logger.logs_path, 'name.txt'), name)

            self._logger._log_all_artifacts()

            mlflow.end_run(status=self.exit_status)


# Module-level run manager instance; other cells reach it by this name
# (e.g. load_colbert reports mismatches through Run.warn).
Run = _RunManager()

In [10]:
# colbert/utils/amp.py
'''
from contextlib import contextmanager
from colbert.utils.utils import NullContextManager--
from packaging import version
'''
# Short alias for the version parser from `packaging`.
v = version.parse
# True when the installed PyTorch is version 1.6 or newer (checked before enabling AMP).
PyTorch_over_1_6 = not (v(torch.__version__) < v('1.6'))

class MixedPrecisionManager():
    """Switchable wrapper around PyTorch automatic mixed precision (AMP).

    When ``activated`` is false, every method falls back to the ordinary
    full-precision backward/step and ``context()`` returns a no-op manager.
    """

    def __init__(self, activated):
        assert (not activated) or PyTorch_over_1_6, "Cannot use AMP for PyTorch version < 1.6"

        self.activated = activated

        # The gradient scaler only exists when AMP is switched on.
        if self.activated:
            self.scaler = torch.cuda.amp.GradScaler()

    def context(self):
        """Return an autocast context when AMP is active, else a NullContextManager."""
        if self.activated:
            return torch.cuda.amp.autocast()
        return NullContextManager()

    def backward(self, loss):
        """Back-propagate ``loss``, scaling it first when AMP is active."""
        target = self.scaler.scale(loss) if self.activated else loss
        target.backward()

    def step(self, colbert, optimizer):
        """Clip gradients to norm 2.0, take one optimizer step and zero the gradients.

        With AMP active the gradients are unscaled before clipping and the
        step goes through the gradient scaler, which is then updated.
        """
        if self.activated:
            self.scaler.unscale_(optimizer)

        torch.nn.utils.clip_grad_norm_(colbert.parameters(), 2.0)

        if self.activated:
            self.scaler.step(optimizer)
            self.scaler.update()
        else:
            optimizer.step()

        optimizer.zero_grad()

## modeling

### tokenization

In [11]:
# colbert/modeling/tokenization/utils.py 
def _sort_by_length(ids, mask, bsize):
    if ids.size(0) <= bsize:
        return ids, mask, torch.arange(ids.size(0))

    indices = mask.sum(-1).sort().indices
    reverse_indices = indices.sort().indices

    return ids[indices], mask[indices], reverse_indices


def _split_into_batches(ids, mask, bsize):
    batches = []
    for offset in range(0, ids.size(0), bsize):
        batches.append((ids[offset:offset+bsize], mask[offset:offset+bsize]))

    return batches

In [12]:
# colbert/modeling/tokenization/utils.py

def tensorize_triples(query_tokenizer, doc_tokenizer, queries, positives, negatives, bsize):
    """Tokenize (query, positive, negative) triples and group them into batches.

    Triples are ordered by the longer of their two passages (as measured by
    the document mask) before being split into batches of ``bsize`` triples.

    Returns:
        List of ``(Q, D)`` pairs. ``Q`` holds the batch's query ids/mask
        repeated twice; ``D`` holds the positive passages followed by the
        negative passages, so row ``i`` of ``Q`` lines up with row ``i`` of ``D``.
    """
    assert len(queries) == len(positives) == len(negatives)
    assert bsize is None or len(queries) % bsize == 0

    N = len(queries)
    Q_ids, Q_mask = query_tokenizer.tensorize(queries)
    D_ids, D_mask = doc_tokenizer.tensorize(positives + negatives)

    # Positives and negatives were tokenized together; separate them on a leading axis of size 2.
    D_ids = D_ids.view(2, N, -1)
    D_mask = D_mask.view(2, N, -1)

    # For each triple, the larger of the positive's and the negative's mask sum.
    longest_per_triple = D_mask.sum(-1).max(0).values

    order = longest_per_triple.sort().indices
    Q_ids, Q_mask = Q_ids[order], Q_mask[order]
    D_ids, D_mask = D_ids[:, order], D_mask[:, order]

    positive_ids, negative_ids = D_ids[0], D_ids[1]
    positive_mask, negative_mask = D_mask[0], D_mask[1]

    grouped = zip(_split_into_batches(Q_ids, Q_mask, bsize),
                  _split_into_batches(positive_ids, positive_mask, bsize),
                  _split_into_batches(negative_ids, negative_mask, bsize))

    batches = []
    for (q_ids, q_mask), (p_ids, p_mask), (n_ids, n_mask) in grouped:
        query_pair = (torch.cat((q_ids, q_ids)), torch.cat((q_mask, q_mask)))
        doc_pair = (torch.cat((p_ids, n_ids)), torch.cat((p_mask, n_mask)))
        batches.append((query_pair, doc_pair))

    return batches


def _sort_by_length(ids, mask, bsize):
    if ids.size(0) <= bsize:
        return ids, mask, torch.arange(ids.size(0))

    indices = mask.sum(-1).sort().indices
    reverse_indices = indices.sort().indices

    return ids[indices], mask[indices], reverse_indices


def _split_into_batches(ids, mask, bsize):
    batches = []
    for offset in range(0, ids.size(0), bsize):
        batches.append((ids[offset:offset+bsize], mask[offset:offset+bsize]))

    return batches

In [13]:
# colbert/modeling/tokenization/query_tokenization.py
'''
from transformers import BertTokenizerFast
from colbert.modeling.tokenization.utils import _split_into_batches
'''

class QueryTokenizer():
    """BERT tokenizer wrapper for queries.

    Queries get a ``[Q]`` marker (the id of ``[unused0]``) right after
    ``[CLS]`` and are padded to ``query_maxlen`` with ``[MASK]`` tokens.
    """

    def __init__(self, query_maxlen):
        self.tok = BertTokenizerFast.from_pretrained('bert-base-uncased')
        self.query_maxlen = query_maxlen

        self.Q_marker_token = '[Q]'
        self.Q_marker_token_id = self.tok.convert_tokens_to_ids('[unused0]')
        self.cls_token = self.tok.cls_token
        self.cls_token_id = self.tok.cls_token_id
        self.sep_token = self.tok.sep_token
        self.sep_token_id = self.tok.sep_token_id
        self.mask_token = self.tok.mask_token
        self.mask_token_id = self.tok.mask_token_id

        # Sanity-check the vocabulary ids the rest of the class relies on.
        assert self.Q_marker_token_id == 1 and self.mask_token_id == 103

    def tokenize(self, batch_text, add_special_tokens=False):
        """Tokenize each text into word pieces.

        With ``add_special_tokens`` each result is wrapped as
        ``[CLS] [Q] ... [SEP]`` and then extended with ``[MASK]`` tokens up to
        ``query_maxlen`` (longer inputs are not truncated).
        """
        assert type(batch_text) in [list, tuple], (type(batch_text))

        tokens = [self.tok.tokenize(x, add_special_tokens=False) for x in batch_text]

        if add_special_tokens:
            head = [self.cls_token, self.Q_marker_token]
            tail = [self.sep_token]

            padded = []
            for lst in tokens:
                # 3 accounts for [CLS], [Q] and [SEP]; a negative count adds no padding.
                num_masks = self.query_maxlen - (len(lst) + 3)
                padded.append(head + lst + tail + [self.mask_token] * num_masks)
            tokens = padded

        return tokens

    def encode(self, batch_text, add_special_tokens=False):
        """Same as ``tokenize`` but returns token ids instead of token strings."""
        assert type(batch_text) in [list, tuple], (type(batch_text))

        ids = self.tok(batch_text, add_special_tokens=False)['input_ids']

        if add_special_tokens:
            head = [self.cls_token_id, self.Q_marker_token_id]
            tail = [self.sep_token_id]

            padded = []
            for lst in ids:
                num_masks = self.query_maxlen - (len(lst) + 3)
                padded.append(head + lst + tail + [self.mask_token_id] * num_masks)
            ids = padded

        return ids

    def tensorize(self, batch_text, bsize=None):
        """Convert texts into padded id/mask tensors of width ``query_maxlen``.

        Returns:
            ``(ids, mask)`` when ``bsize`` is falsy, otherwise a list of
            ``(ids, mask)`` batches of at most ``bsize`` rows.
        """
        assert type(batch_text) in [list, tuple], (type(batch_text))

        # Prepend a placeholder token whose slot is overwritten by the [Q] marker below.
        batch_text = ['. ' + x for x in batch_text]

        obj = self.tok(batch_text, padding='max_length', truncation=True,
                       return_tensors='pt', max_length=self.query_maxlen)

        ids = obj['input_ids']
        mask = obj['attention_mask']

        # Put the [Q] marker in position 1 and turn every id-0 entry into [MASK].
        ids[:, 1] = self.Q_marker_token_id
        ids[ids == 0] = self.mask_token_id

        if bsize:
            return _split_into_batches(ids, mask, bsize)

        return ids, mask

In [14]:
# colbert/modeling/tokenization/doc_tokenization.py 

'''
from transformers import BertTokenizerFast
from colbert.modeling.tokenization.utils import _split_into_batches, _sort_by_length
'''

class DocTokenizer():
    """BERT tokenizer wrapper for documents.

    Documents get a ``[D]`` marker (the id of ``[unused1]``) right after
    ``[CLS]`` and are truncated to ``doc_maxlen`` when tensorized.
    """

    def __init__(self, doc_maxlen):
        self.tok = BertTokenizerFast.from_pretrained('bert-base-uncased')
        self.doc_maxlen = doc_maxlen

        self.D_marker_token = '[D]'
        self.D_marker_token_id = self.tok.convert_tokens_to_ids('[unused1]')
        self.cls_token = self.tok.cls_token
        self.cls_token_id = self.tok.cls_token_id
        self.sep_token = self.tok.sep_token
        self.sep_token_id = self.tok.sep_token_id

        # Sanity-check the vocabulary id used for the marker.
        assert self.D_marker_token_id == 2

    def tokenize(self, batch_text, add_special_tokens=False):
        """Tokenize each text into word pieces, optionally wrapped as ``[CLS] [D] ... [SEP]``."""
        assert type(batch_text) in [list, tuple], (type(batch_text))

        tokens = [self.tok.tokenize(x, add_special_tokens=False) for x in batch_text]

        if add_special_tokens:
            head = [self.cls_token, self.D_marker_token]
            tail = [self.sep_token]
            tokens = [head + lst + tail for lst in tokens]

        return tokens

    def encode(self, batch_text, add_special_tokens=False):
        """Same as ``tokenize`` but returns token ids instead of token strings."""
        assert type(batch_text) in [list, tuple], (type(batch_text))

        ids = self.tok(batch_text, add_special_tokens=False)['input_ids']

        if add_special_tokens:
            head = [self.cls_token_id, self.D_marker_token_id]
            tail = [self.sep_token_id]
            ids = [head + lst + tail for lst in ids]

        return ids

    def tensorize(self, batch_text, bsize=None):
        """Convert texts into id/mask tensors padded to the longest entry.

        Returns:
            ``(ids, mask)`` when ``bsize`` is falsy. Otherwise
            ``(batches, reverse_indices)``: rows are sorted by length and
            split into batches, and ``reverse_indices`` restores the input order.
        """
        assert type(batch_text) in [list, tuple], (type(batch_text))

        # Prepend a placeholder token whose slot is overwritten by the [D] marker below.
        batch_text = ['. ' + x for x in batch_text]

        obj = self.tok(batch_text, padding='longest', truncation='longest_first',
                       return_tensors='pt', max_length=self.doc_maxlen)

        ids = obj['input_ids']
        mask = obj['attention_mask']

        # Put the [D] marker in position 1.
        ids[:, 1] = self.D_marker_token_id

        if not bsize:
            return ids, mask

        ids, mask, reverse_indices = _sort_by_length(ids, mask, bsize)
        return _split_into_batches(ids, mask, bsize), reverse_indices

### colbert

In [15]:
# colbert/modeling/colbert.py
'''
from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
from colbert.parameters import DEVICE--
'''
class ColBERT(BertPreTrainedModel):
    """BERT encoder with a bias-free linear projection and late-interaction scoring.

    Queries and documents are encoded separately into L2-normalized
    per-token vectors of size ``dim``; ``score`` compares them using either
    'cosine' or 'l2' similarity.
    """

    def __init__(self, config, query_maxlen, doc_maxlen, mask_punctuation, dim=128, similarity_metric='cosine'):

        super(ColBERT, self).__init__(config)

        self.query_maxlen = query_maxlen
        self.doc_maxlen = doc_maxlen
        self.similarity_metric = similarity_metric
        self.dim = dim

        self.mask_punctuation = mask_punctuation
        self.skiplist = {}

        if self.mask_punctuation:
            # Skip list holds every punctuation symbol and the first token id it encodes to.
            self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
            skiplist = {}
            for symbol in string.punctuation:
                skiplist[symbol] = True
                skiplist[self.tokenizer.encode(symbol, add_special_tokens=False)[0]] = True
            self.skiplist = skiplist

        self.bert = BertModel(config)
        self.linear = nn.Linear(config.hidden_size, dim, bias=False)

        self.init_weights()

    def forward(self, Q, D):
        """Score ``Q`` against ``D``, each given as an ``(input_ids, attention_mask)`` pair."""
        query_vectors = self.query(*Q)
        doc_vectors = self.doc(*D)
        return self.score(query_vectors, doc_vectors)

    def query(self, input_ids, attention_mask):
        """Encode queries into per-token vectors, L2-normalized along dim 2."""
        input_ids = input_ids.to(DEVICE)
        attention_mask = attention_mask.to(DEVICE)

        hidden = self.bert(input_ids, attention_mask=attention_mask)[0]
        Q = self.linear(hidden)

        return torch.nn.functional.normalize(Q, p=2, dim=2)

    def doc(self, input_ids, attention_mask, keep_dims=True):
        """Encode documents into per-token vectors, zeroing skipped/padding tokens.

        With ``keep_dims`` true a single padded tensor is returned. Otherwise
        the result is a list with one float16 CPU tensor per document that
        contains only the rows kept by the mask.
        """
        input_ids = input_ids.to(DEVICE)
        attention_mask = attention_mask.to(DEVICE)

        hidden = self.bert(input_ids, attention_mask=attention_mask)[0]
        D = self.linear(hidden)

        mask = torch.tensor(self.mask(input_ids), device=DEVICE).unsqueeze(2).float()
        D = torch.nn.functional.normalize(D * mask, p=2, dim=2)

        if keep_dims:
            return D

        D = D.cpu().to(dtype=torch.float16)
        mask = mask.cpu().bool().squeeze(-1)
        return [d[mask[idx]] for idx, d in enumerate(D)]

    def score(self, Q, D):
        """For each query token take its best match over document tokens, then sum over query tokens."""
        if self.similarity_metric == 'cosine':
            token_similarities = Q @ D.permute(0, 2, 1)
            return token_similarities.max(2).values.sum(1)

        assert self.similarity_metric == 'l2'
        squared_distances = ((Q.unsqueeze(2) - D.unsqueeze(1))**2).sum(-1)
        return (-1.0 * squared_distances).max(-1).values.sum(-1)

    def mask(self, input_ids):
        """Return nested lists of booleans: True where the id is non-zero and not in the skip list."""
        rows = []
        for d in input_ids.cpu().tolist():
            rows.append([(x not in self.skiplist) and (x != 0) for x in d])
        return rows

### parameters

In [16]:
# colbert/parameters.py 
# Use the first CUDA device when available and fall back to the CPU otherwise,
# so that this cell does not force "cuda:0" on machines without a GPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Step counts collected into a set (duplicates such as 50*1000 collapse).
# Presumably the training steps at which checkpoints are kept -- the code
# that consumes this set is not part of this section; confirm.
SAVED_CHECKPOINTS = [32*1000, 100*1000, 150*1000, 200*1000, 300*1000, 400*1000]
SAVED_CHECKPOINTS += [10*1000, 20*1000, 30*1000, 40*1000, 50*1000, 60*1000, 70*1000, 80*1000, 90*1000]
SAVED_CHECKPOINTS += [25*1000, 50*1000, 75*1000]

SAVED_CHECKPOINTS = set(SAVED_CHECKPOINTS)

### inference

In [17]:
# colbert/modeling/inference.py
'''
from colbert.modeling.colbert import ColBERT--
from colbert.modeling.tokenization import QueryTokenizer, DocTokenizer--
from colbert.utils.amp import MixedPrecisionManager--
from colbert.parameters import DEVICE--
'''
class ModelInference():
    """Gradient-free inference helper around a ColBERT model in eval mode.

    Holds the model, a query tokenizer and a document tokenizer sized from
    the model's ``query_maxlen`` / ``doc_maxlen``, and a mixed-precision
    manager that wraps every forward pass.
    """

    def __init__(self, colbert: "ColBERT", amp=False):
        # Inference requires the model to have been switched to eval mode already.
        assert colbert.training is False

        self.colbert = colbert
        self.query_tokenizer = QueryTokenizer(colbert.query_maxlen)
        self.doc_tokenizer = DocTokenizer(colbert.doc_maxlen)

        self.amp_manager = MixedPrecisionManager(amp)

    def query(self, *args, to_cpu=False, **kw_args):
        """Run ``colbert.query`` without gradients; move the result to the CPU if ``to_cpu``."""
        with torch.no_grad():
            with self.amp_manager.context():
                Q = self.colbert.query(*args, **kw_args)
                return Q.cpu() if to_cpu else Q

    def doc(self, *args, to_cpu=False, **kw_args):
        """Run ``colbert.doc`` without gradients; move the result to the CPU if ``to_cpu``.

        The model returns a list of per-document tensors when called with
        ``keep_dims=False``; in that case each tensor is moved individually
        instead of calling ``.cpu()`` on the list.
        """
        with torch.no_grad():
            with self.amp_manager.context():
                D = self.colbert.doc(*args, **kw_args)

                if not to_cpu:
                    return D

                if isinstance(D, (list, tuple)):
                    return [d.cpu() for d in D]

                return D.cpu()

    def queryFromText(self, queries, bsize=None, to_cpu=False):
        """Tokenize and encode query strings, in batches of ``bsize`` when given."""
        if bsize:
            batches = self.query_tokenizer.tensorize(queries, bsize=bsize)
            batches = [self.query(input_ids, attention_mask, to_cpu=to_cpu) for input_ids, attention_mask in batches]
            return torch.cat(batches)

        input_ids, attention_mask = self.query_tokenizer.tensorize(queries)
        # Honor to_cpu on the unbatched path as well.
        return self.query(input_ids, attention_mask, to_cpu=to_cpu)

    def docFromText(self, docs, bsize=None, keep_dims=True, to_cpu=False):
        """Tokenize and encode document strings, in batches of ``bsize`` when given.

        With batching the tokenizer sorts documents by length, so the outputs
        are put back into input order before being returned. The result is a
        single padded tensor when ``keep_dims`` is true, else a list of tensors.
        """
        if bsize:
            batches, reverse_indices = self.doc_tokenizer.tensorize(docs, bsize=bsize)

            batches = [self.doc(input_ids, attention_mask, keep_dims=keep_dims, to_cpu=to_cpu)
                       for input_ids, attention_mask in batches]

            if keep_dims:
                D = _stack_3D_tensors(batches)
                return D[reverse_indices]

            D = [d for batch in batches for d in batch]
            return [D[idx] for idx in reverse_indices.tolist()]

        input_ids, attention_mask = self.doc_tokenizer.tensorize(docs)
        # Honor to_cpu on the unbatched path as well.
        return self.doc(input_ids, attention_mask, keep_dims=keep_dims, to_cpu=to_cpu)

    def score(self, Q, D, mask=None, lengths=None, explain=False):
        """Score documents ``D`` against ``Q`` as a sum of per-column maxima of ``D @ Q``.

        Args:
            Q: Query matrix that is right-multiplied onto ``D``.
            D: Batched document tensor; dim 1 is reduced with ``max``.
            mask: Optional multiplier over dims 0-1 of ``D @ Q``.
            lengths: Optional per-document lengths from which the mask is
                built (position ``j`` is kept when ``j + 1 <= length``).
                Must not be combined with ``mask``.
            explain: Not implemented; a truthy value fails an assertion.

        Returns:
            CPU tensor with one score per document.
        """
        if lengths is not None:
            assert mask is None, "don't supply both mask and lengths"

            mask = torch.arange(D.size(1), device=DEVICE) + 1
            mask = mask.unsqueeze(0) <= lengths.to(DEVICE).unsqueeze(-1)

        scores = (D @ Q)
        scores = scores if mask is None else scores * mask.unsqueeze(-1)
        scores = scores.max(1)

        if explain:
            assert False, "TODO"

        return scores.values.sum(-1).cpu()


def _stack_3D_tensors(groups):
    bsize = sum([x.size(0) for x in groups])
    maxlen = max([x.size(1) for x in groups])
    hdim = groups[0].size(2)

    output = torch.zeros(bsize, maxlen, hdim, device=groups[0].device, dtype=groups[0].dtype)

    offset = 0
    for x in groups:
        endpos = offset + x.size(0)
        output[offset:endpos, :x.size(1)] = x
        offset = endpos

    return output

## evaluation

In [18]:
def load_model(args, do_print=True):
    """Build a ColBERT model on ``DEVICE``, load ``args.checkpoint`` into it and switch to eval mode.

    Args:
        args: Object providing ``query_maxlen``, ``doc_maxlen``, ``dim``,
            ``similarity``, ``mask_punctuation`` and ``checkpoint``.
        do_print: Whether progress messages are printed.

    Returns:
        Tuple ``(colbert, checkpoint)``.
    """
    model_options = dict(query_maxlen=args.query_maxlen,
                         doc_maxlen=args.doc_maxlen,
                         dim=args.dim,
                         similarity_metric=args.similarity,
                         mask_punctuation=args.mask_punctuation)

    colbert = ColBERT.from_pretrained('bert-base-uncased', **model_options)
    colbert = colbert.to(DEVICE)

    print_message("#> Loading model checkpoint.", condition=do_print)

    checkpoint = load_checkpoint(args.checkpoint, colbert, do_print=do_print)

    # Switch to evaluation mode before handing the model out.
    colbert.eval()

    return colbert, checkpoint

In [70]:
# colbert/evaluation/loaders.py
def load_queries(queries_path):
    """Read a tab-separated queries file into an ordered ``{qid: query}`` mapping.

    Each line must start with an integer query id followed by the query
    text; any further columns are ignored.

    Raises:
        AssertionError: If a query id occurs more than once.
    """
    print_message("#> Loading the queries from", queries_path, "...")

    queries = OrderedDict()

    with open(queries_path) as f:
        for line in f:
            raw_qid, query, *_ = line.strip().split('\t')
            qid = int(raw_qid)

            assert (qid not in queries), ("Query QID", qid, "is repeated!")
            queries[qid] = query

    num_queries = len(queries)
    print_message("#> Got", num_queries, "queries. All QIDs are unique.\n")

    return queries


def load_qrels(qrels_path):
    """Read relevance judgments into an ordered ``{qid: [pid, ...]}`` mapping.

    Every line must hold four tab-separated integers ``qid, 0, pid, 1``.
    Returns ``None`` when ``qrels_path`` is ``None``.

    Raises:
        AssertionError: If a line's second/fourth column is not 0/1, or a
            passage is listed twice for the same query.
    """
    if qrels_path is None:
        return None

    print_message("#> Loading qrels from", qrels_path, "...")

    qrels = OrderedDict()
    with open(qrels_path, mode='r', encoding="utf-8") as f:
        for line in f:
            qid, x, pid, y = map(int, line.strip().split('\t'))
            assert x == 0 and y == 1
            qrels.setdefault(qid, []).append(pid)

    # No passage may be listed twice for the same query.
    assert all(len(pids) == len(set(pids)) for pids in qrels.values())

    total_positives = sum(len(pids) for pids in qrels.values())
    avg_positive = round(total_positives / len(qrels), 2)

    print_message("#> Loaded qrels for", len(qrels), "unique queries with",
                  avg_positive, "positives per query on average.\n")

    return qrels


def load_topK(topK_path):
    """Read a top-k file whose lines are ``qid <TAB> pid <TAB> query <TAB> passage``.

    Returns:
        Tuple ``(queries, topK_docs, topK_pids)`` of ordered dicts keyed by
        qid: the query text, the list of passages and the list of passage ids.

    Raises:
        AssertionError: If one qid appears with different query texts, or a
            pid is repeated within a query.
    """
    print_message("#> Loading the top-k per query from", topK_path, "...")

    queries = OrderedDict()
    topK_docs = OrderedDict()
    topK_pids = OrderedDict()

    progress_every = 10*1000*1000

    with open(topK_path) as f:
        for line_idx, line in enumerate(f):
            if line_idx and line_idx % progress_every == 0:
                print(line_idx, end=' ', flush=True)

            # The line is split as-is, without stripping.
            raw_qid, raw_pid, query, passage = line.split('\t')
            qid, pid = int(raw_qid), int(raw_pid)

            assert (qid not in queries) or (queries[qid] == query)
            queries[qid] = query
            topK_docs.setdefault(qid, []).append(passage)
            topK_pids.setdefault(qid, []).append(pid)

        print()

    assert all(len(pids) == len(set(pids)) for pids in topK_pids.values())

    Ks = [len(pids) for pids in topK_pids.values()]

    print_message("#> max(Ks) =", max(Ks), ", avg(Ks) =", round(sum(Ks) / len(Ks), 2))
    print_message("#> Loaded the top-k per query for", len(queries), "unique queries.\n")

    return queries, topK_docs, topK_pids


def load_topK_pids(topK_path, qrels):
    """Read only the ``qid`` and ``pid`` columns of a tab-separated top-k file.

    Args:
        topK_path: File whose lines start with ``qid <TAB> pid``; further
            columns are ignored.
        qrels: Relevance judgments used as the positives when the file
            itself yields none (may be ``None``).

    Returns:
        Tuple ``(topK_pids, topK_positives)``: a ``{qid: [pid, ...]}``
        mapping and the positives per query.

    Raises:
        AssertionError: If a pid is repeated within a query, or both
            ``qrels`` and file-derived positives are present.
    """
    print_message("#> Loading the top-k PIDs per query from", topK_path, "...")

    topK_pids = defaultdict(list)
    # NOTE(review): nothing in this function ever adds to topK_positives, so it
    # stays empty and the annotation branch further down never runs -- confirm
    # whether label parsing was removed on purpose.
    topK_positives = defaultdict(list)

    progress_every = 10*1000*1000

    with open(topK_path) as f:
        for line_idx, line in enumerate(f):
            if line_idx and line_idx % progress_every == 0:
                print(line_idx, end=' ', flush=True)

            raw_qid, raw_pid, *_ = line.strip().split('\t')
            qid, pid = int(raw_qid), int(raw_pid)

            topK_pids[qid].append(pid)

    assert all(len(pids) == len(set(pids)) for pids in topK_pids.values())
    assert all(len(pids) == len(set(pids)) for pids in topK_positives.values())

    # Make them sets for fast lookups later
    topK_positives = {qid: set(pids) for qid, pids in topK_positives.items()}

    Ks = [len(pids) for pids in topK_pids.values()]

    print_message("#> max(Ks) =", max(Ks), ", avg(Ks) =", round(sum(Ks) / len(Ks), 2))
    print_message("#> Loaded the top-k per query for", len(topK_pids), "unique queries.\n")

    if len(topK_positives) != 0:
        assert len(topK_pids) >= len(topK_positives)

        # Queries without any annotated positive get an empty entry.
        for qid in set.difference(set(topK_pids.keys()), set(topK_positives.keys())):
            topK_positives[qid] = []

        assert len(topK_pids) == len(topK_positives)

        total_positives = sum(len(pids) for pids in topK_positives.values())
        avg_positive = round(total_positives / len(topK_pids), 2)

        print_message("#> Concurrently got annotations for", len(topK_positives), "unique queries with",
                      avg_positive, "positives per query on average.\n")
    else:
        topK_positives = None

    assert qrels is None or topK_positives is None, "Cannot have both qrels and an annotated top-K file!"

    if topK_positives is None:
        topK_positives = qrels

    return topK_pids, topK_positives


def load_collection(collection_path):
    """Read a tab-separated collection file into a list of passages.

    Each line is ``pid <TAB> passage [<TAB> title ...]``. The pid must equal
    the zero-based line number (or be the literal ``id``). When a title
    column is present the stored text is ``"<title> | <passage>"``.
    """
    print_message("#> Loading collection...")

    collection = []
    lines_per_tick = 1000*1000

    with open(collection_path) as f:
        for line_idx, line in enumerate(f):
            if line_idx % lines_per_tick == 0:
                print(f'{line_idx // 1000 // 1000}M', end=' ', flush=True)

            pid, passage, *rest = line.strip().split('\t')
            assert pid == 'id' or int(pid) == line_idx

            if rest:
                passage = rest[0] + ' | ' + passage

            collection.append(passage)

    print()

    return collection


def load_colbert(args, do_print=True):
    """Load the model via ``load_model`` and warn about settings that differ from the checkpoint.

    For each of ``query_maxlen``, ``doc_maxlen``, ``dim``, ``similarity`` and
    ``amp`` a warning is issued through ``Run.warn`` when the value stored in
    ``checkpoint['arguments']`` differs from the one in ``args``. On ranks
    below 1 the stored arguments are also printed as JSON.

    Returns:
        Tuple ``(colbert, checkpoint)``.
    """
    colbert, checkpoint = load_model(args, do_print)

    # TODO: If the parameters below were not specified on the command line, their *checkpoint* values should be used.
    # I.e., not their purely (i.e., training) default values.

    has_arguments = 'arguments' in checkpoint

    for k in ['query_maxlen', 'doc_maxlen', 'dim', 'similarity', 'amp']:
        if not (has_arguments and hasattr(args, k)):
            continue

        saved_arguments = checkpoint['arguments']
        if k in saved_arguments and saved_arguments[k] != getattr(args, k):
            a, b = saved_arguments[k], getattr(args, k)
            Run.warn(f"Got checkpoint['arguments']['{k}'] != args.{k} (i.e., {a} != {b})")

    if has_arguments and args.rank < 1:
        print(ujson.dumps(checkpoint['arguments'], indent=4))

    if do_print:
        print('\n')

    return colbert, checkpoint

In [71]:
# colbert/evaluation/metrics.py 
'''
from colbert.utils.runs import Run--
'''
class Metrics:
    """Accumulates MRR, Success and Recall sums per cutoff depth and reports their averages.

    Averages are always computed as ``sum / (query_idx + 1)``, i.e. relative to
    the number of queries visited so far rather than the number of queries added.
    """

    def __init__(self, mrr_depths: set, recall_depths: set, success_depths: set, total_queries=None):
        """Create zeroed accumulators for every requested depth."""
        self.results = {}
        self.mrr_sums = dict.fromkeys(mrr_depths, 0.0)
        self.recall_sums = dict.fromkeys(recall_depths, 0.0)
        self.success_sums = dict.fromkeys(success_depths, 0.0)
        self.total_queries = total_queries

        self.max_query_idx = -1
        self.num_queries_added = 0

    def _averages(self, sums, query_idx):
        """Yield ``(depth, sums[depth] / (query_idx + 1))`` in increasing depth order."""
        for depth in sorted(sums):
            yield depth, sums[depth] / (query_idx+1.0)

    def add(self, query_idx, query_key, ranking, gold_positives):
        """Record one query's ranking and update the metric sums.

        ``ranking`` is a sequence of ``(score, pid, passage)`` triples; positions
        are zero-based indices into it. Queries whose ranking contains no gold
        positive are stored in ``results`` but contribute nothing to the sums.
        """
        self.num_queries_added += 1

        assert query_key not in self.results
        assert len(self.results) <= query_idx
        assert len(set(gold_positives)) == len(gold_positives)
        assert len({pid for _, pid, _ in ranking}) == len(ranking)

        self.results[query_key] = ranking

        hit_positions = [position for position, (_, pid, _) in enumerate(ranking) if pid in gold_positives]

        if not hit_positions:
            return

        best_position = hit_positions[0]

        for depth in self.mrr_sums:
            self.mrr_sums[depth] += (1.0 / (best_position+1.0)) if best_position < depth else 0.0

        for depth in self.success_sums:
            self.success_sums[depth] += 1.0 if best_position < depth else 0.0

        for depth in self.recall_sums:
            found_within_depth = sum(1 for position in hit_positions if position < depth)
            self.recall_sums[depth] += found_within_depth / len(gold_positives)

    def print_metrics(self, query_idx):
        """Print the running average of every metric, grouped by metric and sorted by depth."""
        labelled_sums = (("MRR@", self.mrr_sums),
                         ("Success@", self.success_sums),
                         ("Recall@", self.recall_sums))

        for label, sums in labelled_sums:
            for depth, average in self._averages(sums, query_idx):
                print(label + str(depth), "=", average)

    def log(self, query_idx):
        """Send the running averages to ``Run.log_metric`` using ``query_idx`` as the step."""
        assert query_idx >= self.max_query_idx
        self.max_query_idx = query_idx

        Run.log_metric("ranking/max_query_idx", query_idx, query_idx)
        Run.log_metric("ranking/num_queries_added", self.num_queries_added, query_idx)

        prefixed_sums = (("ranking/MRR.", self.mrr_sums),
                         ("ranking/Success.", self.success_sums),
                         ("ranking/Recall.", self.recall_sums))

        for prefix, sums in prefixed_sums:
            for depth, average in self._averages(sums, query_idx):
                Run.log_metric(prefix + str(depth), average, query_idx)

    def output_final_metrics(self, path, query_idx, num_queries):
        """Log (if not yet logged), print, and write the final averages to ``path`` as JSON."""
        assert query_idx + 1 == num_queries
        assert num_queries == self.total_queries

        if self.max_query_idx < query_idx:
            self.log(query_idx)

        self.print_metrics(query_idx)

        # A defaultdict so that a metric with no depths produces no key at all.
        output = defaultdict(dict)

        named_sums = (('mrr', self.mrr_sums),
                      ('success', self.success_sums),
                      ('recall', self.recall_sums))

        for name, sums in named_sums:
            for depth, average in self._averages(sums, query_idx):
                output[name][depth] = average

        with open(path, 'w') as f:
            ujson.dump(output, f, indent=4)
            f.write('\n')


def evaluate_recall(qrels, queries, topK_pids):
    """Print the average fraction of each query's relevant pids found among its top-K pids.

    Does nothing when ``qrels`` is None. Otherwise ``qrels`` and ``queries`` must
    have exactly the same keys. A query with an empty relevance list contributes
    a recall of 0 (its denominator is clamped to 1). The printed value is
    rounded to three decimals.
    """
    if qrels is None:
        return

    assert set(qrels.keys()) == set(queries.keys())

    recall_total = 0
    for qid, relevant_pids in qrels.items():
        retrieved_relevant = set(relevant_pids) & set(topK_pids[qid])
        recall_total += len(retrieved_relevant) / max(1.0, len(relevant_pids))

    mean_recall = round(recall_total / len(qrels), 3)
    print("Recall @ maximum depth =", mean_recall)


# TODO: If implicit qrels are used (for re-ranking), warn if a recall metric is requested + add an asterisk to output.

In [72]:
# colbert/evaluation/slow.py
def slow_rerank(args, query, pids, passages):
    """Score every candidate passage against the query and return them best-first.

    Uses ``args.inference`` to encode the query and the passages (in batches of
    ``args.bsize``) and ``args.colbert.score`` to score them.

    Returns:
        list of ``(score, pid, passage)`` tuples sorted by descending score.

    Raises:
        AssertionError: If the re-ordered pids contain duplicates.
    """
    model = args.colbert
    encoder = args.inference

    query_encoding = encoder.queryFromText([query])
    passage_encodings = encoder.docFromText(passages, bsize=args.bsize)

    # Sort on CPU; the result exposes the sorted values and the original indices.
    ordering = model.score(query_encoding, passage_encodings).cpu().sort(descending=True)

    order = ordering.indices.tolist()
    sorted_scores = ordering.values.tolist()

    sorted_pids = [pids[idx] for idx in order]
    sorted_passages = [passages[idx] for idx in order]

    assert len(set(sorted_pids)) == len(sorted_pids)

    return [triple for triple in zip(sorted_scores, sorted_pids, sorted_passages)]

In [73]:
# colbert/evaluation/ranking_logger.py
'''
from colbert.utils.utils import print_message, NullContextManager
from colbert.utils.runs import Run
'''
class RankingLogger():
    """Writes ranked lists to a TSV file and, optionally, an annotated companion file."""

    def __init__(self, directory, qrels=None, log_scores=False):
        """Remember where to write, the relevance judgments, and whether to log scores."""
        self.directory = directory
        self.qrels = qrels
        self.log_scores = log_scores
        self.filename = None
        self.also_save_annotations = None

    @contextmanager
    def context(self, filename, also_save_annotations=False):
        """Open the output file(s) under ``self.directory`` and yield this logger.

        ``self.f`` is the ranking file; ``self.g`` is the ``.annotated`` file, or
        whatever ``NullContextManager()`` yields when annotations are not requested.
        ``self.filename`` and ``self.also_save_annotations`` stay set after the
        context exits, so a given instance can enter ``context`` only once.
        """
        assert self.filename is None
        assert self.also_save_annotations is None

        full_path = os.path.join(self.directory, filename)
        self.filename = full_path
        self.also_save_annotations = also_save_annotations

        print_message("#> Logging ranked lists to {}".format(self.filename))

        with open(full_path, 'w') as ranking_file:
            self.f = ranking_file

            if also_save_annotations:
                annotations_manager = open(full_path + '.annotated', 'w')
            else:
                annotations_manager = NullContextManager()

            with annotations_manager as annotations_file:
                self.g = annotations_file
                yield self

    def log(self, qid, ranking, is_ranked=True, print_positions=[]):
        """Append one query's ranking to the open file(s).

        Each ``(score, pid, passage)`` entry becomes a ``qid, pid, rank[, score]``
        line; the rank is 1-based, or -1 when ``is_ranked`` is falsy. Entries whose
        rank is in ``print_positions`` are also printed to stdout.
        """
        positions_to_print = set(print_positions)

        ranking_lines, annotation_lines = [], []

        for index, (score, pid, passage) in enumerate(ranking):
            is_relevant = self.qrels and int(pid in self.qrels[qid])
            rank = index + 1 if is_ranked else -1

            row = [qid, pid, rank]
            if self.log_scores:
                row = row + [score]

            ranking_lines.append('\t'.join(map(str, row)) + "\n")
            if self.g:
                annotation_lines.append('\t'.join(map(str, [qid, pid, rank, is_relevant])) + "\n")

            if rank in positions_to_print:
                marker = ("** " if is_relevant else "") + str(rank)
                print("#> ( QID {} ) ".format(qid) + marker + ") ", pid, ":", score, '    ', passage)

        self.f.write(''.join(ranking_lines))
        if self.g:
            self.g.write(''.join(annotation_lines))

In [74]:
# colbert/evaluation/ranking.py
'''
from colbert.utils.runs import Run--
from colbert.utils.utils import print_message--

from colbert.evaluation.metrics import Metrics--
from colbert.evaluation.ranking_logger import RankingLogger--
from colbert.modeling.inference import ModelInference--

from colbert.evaluation.slow import slow_rerank--
'''

def evaluate(args):
    """Exhaustively re-rank each query's candidates and log rankings and metrics.

    Reads from ``args``: ``colbert``, ``amp``, ``qrels``, ``queries``,
    ``topK_pids``, ``depth``, ``collection`` (or ``topK_docs`` when the collection
    is None), ``shortcircuit`` and ``checkpoint['batch']``. Sets ``args.inference``
    and ``args.milliseconds``.

    Rankings are written to ``ranking.tsv`` under ``Run.path`` (with an annotated
    companion file when qrels are given). When qrels are given, running metrics
    are printed and logged after every query and the final metrics are written to
    ``ranking.metrics`` under ``Run.path``.
    """
    args.inference = ModelInference(args.colbert, amp=args.amp)
    qrels, queries, topK_pids = args.qrels, args.queries, args.topK_pids

    depth = args.depth
    collection = args.collection
    if collection is None:
        # Passage texts come straight from the top-K file when there is no collection.
        topK_docs = args.topK_docs

    def qid2passages(qid):
        """Return the texts of the first ``depth`` candidates of ``qid``."""
        if collection is not None:
            return [collection[pid] for pid in topK_pids[qid][:depth]]
        else:
            return topK_docs[qid][:depth]

    metrics = Metrics(mrr_depths={10, 100}, recall_depths={50, 200, 1000},
                      success_depths={5, 10, 20, 50, 100, 1000},
                      total_queries=len(queries))

    ranking_logger = RankingLogger(Run.path, qrels=qrels)

    args.milliseconds = []

    with ranking_logger.context('ranking.tsv', also_save_annotations=(qrels is not None)) as rlogger:
        with torch.no_grad():
            # Sort before shuffling so the visiting order depends only on the random state.
            keys = sorted(list(queries.keys()))
            random.shuffle(keys)

            for query_idx, qid in enumerate(keys):
                query = queries[qid]

                print_message(query_idx, qid, query, '\n')

                # Skip queries that have no relevant passage among their candidates.
                if qrels and args.shortcircuit and len(set.intersection(set(qrels[qid]), set(topK_pids[qid]))) == 0:
                    continue

                ranking = slow_rerank(args, query, topK_pids[qid], qid2passages(qid))

                # The positions must be passed by keyword: the third positional parameter
                # of RankingLogger.log is `is_ranked`, not `print_positions`.
                rlogger.log(qid, ranking, print_positions=[0, 1])

                if qrels:
                    metrics.add(query_idx, qid, ranking, qrels[qid])

                    # Report the best-ranked relevant passage, if any.
                    for i, (score, pid, passage) in enumerate(ranking):
                        if pid in qrels[qid]:
                            print("\n#> Found", pid, "at position", i+1, "with score", score)
                            print(passage)
                            break

                    metrics.print_metrics(query_idx)
                    metrics.log(query_idx)

                print_message("#> checkpoint['batch'] =", args.checkpoint['batch'], '\n')
                print("rlogger.filename =", rlogger.filename)

                # NOTE(review): nothing in this function appends to args.milliseconds,
                # so this only prints if code elsewhere records latencies there.
                if len(args.milliseconds) > 1:
                    print('Slow-Ranking Avg Latency =', sum(args.milliseconds[1:]) / len(args.milliseconds[1:]))

                print("\n\n")

        print("\n\n")
        # print('Avg Latency =', sum(args.milliseconds[1:]) / len(args.milliseconds[1:]))
        print("\n\n")

    print('\n\n')
    if qrels:
        assert query_idx + 1 == len(keys) == len(set(keys))
        metrics.output_final_metrics(os.path.join(Run.path, 'ranking.metrics'), query_idx, len(queries))
    print('\n\n')

In [75]:
# colbert/utils/parser.py 
'''
import colbert.utils.distributed as distributed--
from colbert.utils.runs import Run--
from colbert.utils.utils import print_message, timestamp, create_directory--
'''

class Arguments():
    """Wrapper around ``argparse.ArgumentParser`` that registers the option groups used here.

    The constructor registers the run-identification options; the ``add_*``
    methods register further groups on demand. ``parse`` parses the command
    line and then initialises distributed state, thread counts and ``Run``.
    """

    def __init__(self, description):
        """Create the parser and register ``--root``, ``--experiment``, ``--run``, ``--local_rank``."""
        self.parser = ArgumentParser(description=description)
        # Callables invoked with the parsed namespace by check_arguments().
        self.checks = []

        self.add_argument('--root', dest='root', default='/', type=str)
        self.add_argument('--experiment', dest='experiment', default='MSMARCO-psg', type=str)
        self.add_argument('--run', dest='run', default='msmarco.psg.l2', type=str)

        # Stored as `rank`; -1 is the default when the option is not given.
        self.add_argument('--local_rank', dest='rank', default=-1, type=int)

    def add_model_parameters(self):
        """Register the model hyperparameter options."""
        # Core Arguments
        self.add_argument('--similarity', dest='similarity', default='cosine', choices=['cosine', 'l2'])
        self.add_argument('--dim', dest='dim', default=128, type=int)
        self.add_argument('--query_maxlen', dest='query_maxlen', default=32, type=int)
        self.add_argument('--doc_maxlen', dest='doc_maxlen', default=180, type=int)

        # Filtering-related Arguments
        # NOTE: a store_true flag with default=True is always True; passing it changes nothing.
        self.add_argument('--mask-punctuation', dest='mask_punctuation', default=True, action='store_true')

    def add_model_inference_parameters(self):
        """Register the checkpoint path, batch size and AMP options."""
        self.add_argument('--checkpoint', dest='checkpoint', default="/", type=str)
        self.add_argument('--bsize', dest='bsize', default=128, type=int)
        # NOTE: as with --mask-punctuation, default=True makes this flag always True.
        self.add_argument('--amp', dest='amp', default=True, action='store_true')

    def add_ranking_input(self):
        """Register the paths of the queries, collection and qrels files."""
        self.add_argument('--queries', dest='queries', default='/queries.dev.small.tsv', type=str)
        self.add_argument('--collection', dest='collection', default='collection.tsv', type=str)
        self.add_argument('--qrels', dest='qrels', default='qrels.dev.small.tsv', type=str)

    # '/home/dilab/movie/data/'
    def add_reranking_input(self):
        """Register the ranking inputs plus the top-K file and short-circuit options."""
        self.add_ranking_input()
        self.add_argument('--topk', dest='topK', default='/top1000.dev', type=str)
        self.add_argument('--shortcircuit', dest='shortcircuit', default=False, action='store_true')

    def add_argument(self, *args, **kw_args):
        """Forward to ``ArgumentParser.add_argument`` and return its result."""
        return self.parser.add_argument(*args, **kw_args)

    def check_arguments(self, args):
        """Run every registered check on the parsed namespace."""
        for check in self.checks:
            check(args)

    def parse(self):
        """Parse ``sys.argv``, run the checks, and initialise the run environment.

        Besides the parsed options, the returned namespace carries
        ``input_arguments`` (a deep copy taken before anything is added),
        ``nranks``, ``distributed`` and ``nthreads``.
        """
        args = self.parser.parse_args()
        self.check_arguments(args)

        # Snapshot of the options exactly as parsed, before derived fields are attached.
        args.input_arguments = copy.deepcopy(args)

        args.nranks, args.distributed = distributed_init(args.rank)

        # os.cpu_count() returns None when the count cannot be determined; fall back
        # to 1 so that max() never compares None with an int.
        cpu_count = os.cpu_count() or 1
        args.nthreads = int(max(cpu_count, faiss.omp_get_max_threads()) * 0.8)
        # Split the thread budget across ranks, keeping at least one thread each.
        args.nthreads = max(1, args.nthreads // args.nranks)

        if args.nranks > 1:
            print_message(f"#> Restricting number of threads for FAISS to {args.nthreads} per process",
                          condition=(args.rank == 0))
            faiss.omp_set_num_threads(args.nthreads)

        Run.init(args.rank, args.root, args.experiment, args.run)
        Run._log_args(args)
        Run.info(args.input_arguments.__dict__, '\n')

        return args

In [None]:
# colbert/test.py
'''
from colbert.utils.parser import Arguments
from colbert.utils.runs import Run

from colbert.evaluation.loaders import load_colbert, load_topK, load_qrels
from colbert.evaluation.loaders import load_queries, load_topK_pids, load_collection
from colbert.evaluation.ranking import evaluate
from colbert.evaluation.metrics import evaluate_recall
'''

def main():
    """Parse the options, load the model and data, and run the exhaustive re-ranking evaluation."""
    random.seed(12345)

    cli = Arguments(description='Exhaustive (slow, not index-based) evaluation of re-ranking with ColBERT.')

    for group in ('add_model_parameters', 'add_model_inference_parameters', 'add_reranking_input'):
        getattr(cli, group)()

    # NOTE(review): presumably absorbs the `-f <file>` argument that a Jupyter
    # kernel receives on its command line - confirm.
    cli.add_argument('-f')

    cli.add_argument('--depth', dest='depth', type=int, default=None, required=False)

    args = cli.parse()

    with Run.context():
        args.colbert, args.checkpoint = load_colbert(args)
        args.qrels = load_qrels(args.qrels)

        if not (args.collection or args.queries):
            # Neither was given: queries and passage texts both come from the top-K file.
            args.queries, args.topK_docs, args.topK_pids = load_topK(args.topK)

        else:
            # A collection and a queries file must be supplied together.
            assert args.collection and args.queries

            args.queries = load_queries(args.queries)
            args.collection = load_collection(args.collection)
            args.topK_pids, args.qrels = load_topK_pids(args.topK, args.qrels)

        shortcircuit_is_valid = (not args.shortcircuit) or args.qrels
        assert shortcircuit_is_valid, \
            "Short-circuiting (i.e., applying minimal computation to queries with no positives in the re-ranked set) " \
            "can only be applied if qrels is provided."

        evaluate_recall(args.qrels, args.queries, args.topK_pids)
        evaluate(args)


# Run the evaluation when this code executes as the top-level module.
if __name__ == "__main__":
    main()