In [5]:
import os
import random
import time
import datetime
import sys
import torch
import torch.nn as nn
import numpy as np
import mlflow
import ujson
from itertools import accumulate
import math
from math import ceil
import traceback
import string

from contextlib import contextmanager

from packaging import version

import copy
import faiss
import threading
import queue

import pandas as pd

from argparse import ArgumentParser

from collections import defaultdict, OrderedDict

from transformers import BertTokenizerFast, BertPreTrainedModel, BertModel, BertConfig, AutoConfig, AutoTokenizer

In [7]:
# Number of GPUs the FaissIndexGPU helper uses (it copies this value into `self.ngpu`).
# Hard-coded to 1 here; the commented-out line is the alternative that asks faiss instead.
# NUM = faiss.get_num_gpus()
NUM = 1

In [8]:
# indexing/loaders.py
'''
from colbert.utils.utils import print_message
'''
def get_parts(directory):
    """
    Enumerate the integer-named '.pt' part files found in `directory`.

    Returns a triple `(parts, parts_paths, samples_paths)`:
      * parts          -- the part numbers, sorted numerically;
      * parts_paths    -- '<directory>/<n>.pt' for every part number;
      * samples_paths  -- '<directory>/<n>.sample' for every part number
                          (these paths are built, not checked for existence).

    Raises AssertionError unless the part numbers are exactly 0..N-1.
    """
    suffix = '.pt'

    part_ids = []
    for entry in os.listdir(directory):
        if entry.endswith(suffix):
            part_ids.append(int(entry[:-len(suffix)]))

    # Sort as integers (not strings) so that e.g. part 10 comes after part 9.
    part_ids.sort()

    assert list(range(len(part_ids))) == part_ids, part_ids

    parts_paths = [os.path.join(directory, f'{part_id}{suffix}') for part_id in part_ids]
    samples_paths = [os.path.join(directory, f'{part_id}.sample') for part_id in part_ids]

    return part_ids, parts_paths, samples_paths


def load_doclens(directory, flatten=True):
    """
    Load the 'doclens.<part>.json' file of every index part in `directory`.

    Args:
        directory: index directory; its parts are discovered with `get_parts`.
        flatten: when True (default) the per-part lists are concatenated into one
            flat list; when False a list with one entry per part is returned.

    Returns:
        The decoded JSON contents, flattened or per part as requested.
    """
    parts, _, _ = get_parts(directory)

    all_doclens = []
    for part in parts:
        doclens_path = os.path.join(directory, 'doclens.{}.json'.format(part))

        # Use a context manager so each file handle is closed right away instead of
        # being left to the garbage collector.
        with open(doclens_path) as doclens_file:
            all_doclens.append(ujson.load(doclens_file))

    if flatten:
        all_doclens = [x for sub_doclens in all_doclens for x in sub_doclens]

    return all_doclens

In [9]:
# indexing/index_manager.py
'''from colbert.utils.utils import print_message'''
class IndexManager:
    """Remembers an embedding dimension and writes tensors to disk."""

    def __init__(self, dim):
        # Stored for callers; nothing in this class reads it back.
        self.dim = dim

    def save(self, tensor, path_prefix):
        """Serialize `tensor` to the file `path_prefix` with `torch.save`."""
        torch.save(tensor, path_prefix)


def load_index_part(filename, verbose=True):
    """
    Load one serialized index part from `filename`.

    Args:
        filename: path to a file written with `torch.save`.
        verbose: accepted for interface compatibility; not used by this function.

    Returns:
        The loaded object. If the file holds a list of tensors (older format),
        the tensors are concatenated into a single tensor.

    NOTE(review): `torch.load` unpickles the file, so only load parts that this
    pipeline produced itself.
    """
    part = torch.load(filename)

    # Backward compatibility: older parts were saved as a list of tensors.
    if isinstance(part, list):
        part = torch.cat(part)

    return part

In [10]:
# indexing/faiss_index.py
'''
from colbert.indexing.faiss_index_gpu import FaissIndexGPU
from colbert.utils.utils import print_message
'''
class FaissIndex:
    """
    CPU-owned faiss IVFPQ index plus a `FaissIndexGPU` helper that is used for
    training and adding whenever the helper reports at least one GPU.
    """

    def __init__(self, dim, partitions):
        self.dim = dim
        self.partitions = partitions

        self.gpu = FaissIndexGPU()
        self.quantizer, self.index = self._create_index()

        # Number of vectors added so far; also the id assigned to the next vector on the GPU path.
        self.offset = 0

    def _create_index(self):
        """Build the flat-L2 coarse quantizer and the IVFPQ index on top of it."""
        coarse_quantizer = faiss.IndexFlatL2(self.dim)  # faiss.IndexHNSWFlat(dim, 32)

        # The trailing 16 and 8 are the product-quantizer settings
        # (presumably 16 sub-quantizers of 8 bits each -- confirm against the faiss docs).
        ivfpq_index = faiss.IndexIVFPQ(coarse_quantizer, self.dim, self.partitions, 16, 8)

        return coarse_quantizer, ivfpq_index

    def train(self, train_data):
        """Train the index on `train_data`, printing the elapsed training time in seconds."""
        print_message(f"#> Training now (using {self.gpu.ngpu} GPUs)...")

        if self.gpu.ngpu > 0:
            self.gpu.training_initialize(self.index, self.quantizer)

        started_at = time.time()
        self.index.train(train_data)
        print(time.time() - started_at)

        if self.gpu.ngpu > 0:
            self.gpu.training_finalize()

    def add(self, data):
        """Add the rows of `data` to the index and advance `self.offset` by their count."""
        print_message(f"Add data with shape {data.shape} (offset = {self.offset})..")

        if self.gpu.ngpu > 0:
            # The GPU clone of the index is created lazily, right before the first add.
            if self.offset == 0:
                self.gpu.adding_initialize(self.index)

            self.gpu.add(self.index, data, self.offset)
        else:
            self.index.add(data)

        self.offset += data.shape[0]

    def save(self, output_path):
        """Write the index to `output_path` after setting a default `nprobe` of 10."""
        print_message(f"Writing index to {output_path} ...")

        self.index.nprobe = 10  # just a default
        faiss.write_index(self.index, output_path)

In [11]:
#indexing/faiss_index_gpu.py

class FaissIndexGPU():
    """
    GPU helper for a CPU-owned faiss IVF index.

    It can (a) lend a GPU clustering index to the CPU index during training and
    (b) add vectors through a GPU clone of the index, periodically copying the
    encoded vectors back into the CPU index.

    The number of GPUs comes from the module-level `NUM`. When it is 0 the
    constructor returns early and none of the other attributes are set; every
    other method asserts `ngpu > 0`.
    """

    def __init__(self):
        self.ngpu = NUM

        if self.ngpu == 0:
            return

        self.tempmem = 1 << 33  # handed to setTempMemory (2**33; presumably bytes -- confirm)
        self.max_add_per_gpu = 1 << 25
        self.max_add = self.max_add_per_gpu * self.ngpu  # flush threshold and reserved capacity
        self.add_batch_size = 65536  # rows sent to the GPU index per add_with_ids call

        self.gpu_resources = self._prepare_gpu_resources()

    def _prepare_gpu_resources(self):
        """Create one `StandardGpuResources` per GPU, applying `self.tempmem` when non-negative."""
        print_message(f"Preparing resources for {self.ngpu} GPUs.")

        gpu_resources = []

        for _ in range(self.ngpu):
            res = faiss.StandardGpuResources()
            if self.tempmem >= 0:
                res.setTempMemory(self.tempmem)
            gpu_resources.append(res)

        return gpu_resources

    def _make_vres_vdev(self):
        """
        return vectors of device ids and resources useful for gpu_multiple
        """

        assert self.ngpu > 0

        vres = faiss.GpuResourcesVector()
        vdev = faiss.IntVector()

        for i in range(self.ngpu):
            vdev.push_back(i)
            vres.push_back(self.gpu_resources[i])

        return vres, vdev

    def training_initialize(self, index, quantizer):
        """
        Install a GPU copy of `quantizer` as the clustering index used while training `index`.

        The index and quantizer should be owned by caller.
        Prints the setup time in seconds.
        """

        assert self.ngpu > 0

        s = time.time()
        self.index_ivf = faiss.extract_index_ivf(index)
        self.clustering_index = faiss.index_cpu_to_all_gpus(quantizer)
        self.index_ivf.clustering_index = self.clustering_index
        print(time.time() - s)

    def training_finalize(self):
        """Replace the GPU clustering index with a CPU copy once training is done."""
        assert self.ngpu > 0

        s = time.time()
        self.index_ivf.clustering_index = faiss.index_gpu_to_cpu(self.index_ivf.clustering_index)
        print(time.time() - s)

    def adding_initialize(self, index):
        """
        Build the (sharded) GPU clone of `index` that `add` writes into.

        The index should be owned by caller.
        """

        assert self.ngpu > 0

        self.co = faiss.GpuMultipleClonerOptions()
        self.co.useFloat16 = True
        self.co.useFloat16CoarseQuantizer = False
        self.co.usePrecomputed = False
        self.co.indicesOptions = faiss.INDICES_CPU
        self.co.verbose = True
        self.co.reserveVecs = self.max_add
        self.co.shard = True
        assert self.co.shard_type in (0, 1, 2)

        self.vres, self.vdev = self._make_vres_vdev()
        self.gpu_index = faiss.index_cpu_to_gpu_multiple(self.vres, self.vdev, index, self.co)

    def add(self, index, data, offset):
        """
        Add the rows of `data` to `index` via the GPU index, in batches.

        Row `i` of `data` gets id `offset + i`. Whenever the GPU index holds more
        than `self.max_add` vectors, and once more at the end, its contents are
        copied into the CPU `index`.

        Raises AssertionError if `index.ntotal` does not equal `offset + len(data)`
        afterwards.
        """
        assert self.ngpu > 0

        t0 = time.time()
        nb = data.shape[0]

        for i0 in range(0, nb, self.add_batch_size):
            i1 = min(i0 + self.add_batch_size, nb)
            xs = data[i0:i1]

            # Ask for 64-bit ids explicitly: np.arange's default integer width is
            # platform-dependent, while faiss ids are 64-bit.
            ids = np.arange(offset + i0, offset + i1, dtype=np.int64)
            self.gpu_index.add_with_ids(xs, ids)

            if self.max_add > 0 and self.gpu_index.ntotal > self.max_add:
                self._flush_to_cpu(index, nb, offset)

            print('\r%d/%d (%.3f s)  ' % (i0, nb, time.time() - t0), end=' ')
            sys.stdout.flush()

        if self.gpu_index.ntotal > 0:
            self._flush_to_cpu(index, nb, offset)

        assert index.ntotal == offset+nb, (index.ntotal, offset+nb, offset, nb)
        print(f"add(.) time: %.3f s \t\t--\t\t index.ntotal = {index.ntotal}" % (time.time() - t0))

    def _flush_to_cpu(self, index, nb, offset):
        """Copy what each GPU shard currently holds into `index`, then empty the shard."""
        print("Flush indexes to CPU")

        for i in range(self.ngpu):
            # With a single GPU the clone is the index itself rather than a shard container.
            index_src_gpu = faiss.downcast_index(self.gpu_index if self.ngpu == 1 else self.gpu_index.at(i))
            index_src = faiss.index_gpu_to_cpu(index_src_gpu)

            index_src.copy_subset_to(index, 0, offset, offset+nb)
            index_src_gpu.reset()
            index_src_gpu.reserveMemory(self.max_add)

        if self.ngpu > 1:
            # The method name differs between faiss versions. Only fall back when the
            # first name is missing, so real failures (and KeyboardInterrupt) are not
            # silently swallowed.
            try:
                self.gpu_index.sync_with_shard_indexes()
            except AttributeError:
                self.gpu_index.syncWithSubIndexes()

In [12]:
# indexing/faiss.py
'''
from colbert.utils.utils import print_message, grouper
from colbert.indexing.loaders import get_parts
from colbert.indexing.index_manager import load_index_part
from colbert.indexing.faiss_index import FaissIndex
'''

def get_faiss_index_name(args, offset=None, endpos=None):
    """
    Build the file name of a faiss index: 'ivfpq[.<partitions>][.<offset>-<endpos>].faiss'.

    The partitions component is omitted when `args.partitions` is None, and the
    range component is omitted when `offset` is None.
    """
    components = ['ivfpq']

    if args.partitions is not None:
        components.append(f'{args.partitions}')

    if offset is not None:
        components.append(f'{offset}-{endpos}')

    components.append('faiss')

    return '.'.join(components)


def load_sample(samples_paths, sample_fraction=None):
    """
    Load every file in `samples_paths` and stack the results into one float numpy array.

    When `sample_fraction` is truthy, each part is first reduced to
    `int(rows * sample_fraction)` rows chosen with `torch.randint`, i.e. uniformly
    at random with replacement.
    """
    chunks = []

    for path in samples_paths:
        print_message(f"#> Loading {path} ...")
        embeddings = load_index_part(path)

        if sample_fraction:
            num_rows = embeddings.size(0)
            num_kept = int(num_rows * sample_fraction)
            picked_rows = torch.randint(0, high=num_rows, size=(num_kept,))
            embeddings = embeddings[picked_rows]

        chunks.append(embeddings)

    stacked = torch.cat(chunks).float().numpy()

    print("#> Sample has shape", stacked.shape)

    return stacked


def prepare_faiss_index(slice_samples_paths, partitions, sample_fraction=None):
    """
    Create a `FaissIndex` and train it on the sample loaded from `slice_samples_paths`.

    The index dimension is taken from the last axis of the loaded sample.
    Returns the trained (still empty) index.
    """
    vectors = load_sample(slice_samples_paths, sample_fraction=sample_fraction)

    faiss_index = FaissIndex(vectors.shape[-1], partitions)

    print_message("#> Training with the vectors...")
    faiss_index.train(vectors)
    print_message("Done training!\n")

    return faiss_index


# Number of part files that index_faiss groups together (via grouper) into one
# sub-collection, which is then loaded, concatenated and added to the index in one go.
SPAN = 3


def index_faiss(args):
    """
    Train and populate one faiss index per slice of the parts found in `args.index_path`.

    Reads from `args`: `index_path`, `sample`, `slices`, `partitions`.

    The parts are split into `args.slices` contiguous slices. For each slice an
    index is trained (on the '.sample' files, or on a random fraction of the full
    parts when `args.sample` is set), filled with all vectors of the slice, and
    written next to the parts. Loading happens on a background thread, `SPAN`
    part files at a time, so that disk reads overlap with `index.add`.

    Raises:
        AssertionError: if an output file already exists.
        Exception: any error raised while loading parts on the background thread
            is re-raised here on the calling thread.
    """
    print_message("#> Starting..")

    parts, parts_paths, samples_paths = get_parts(args.index_path)

    if args.sample is not None:
        assert args.sample, args.sample
        print_message(f"#> Training with {round(args.sample * 100.0, 1)}% of *all* embeddings (provided --sample).")
        samples_paths = parts_paths

    num_parts_per_slice = math.ceil(len(parts) / args.slices)

    for slice_idx, part_offset in enumerate(range(0, len(parts), num_parts_per_slice)):
        part_endpos = min(part_offset + num_parts_per_slice, len(parts))

        slice_parts_paths = parts_paths[part_offset:part_endpos]
        slice_samples_paths = samples_paths[part_offset:part_endpos]

        if args.slices == 1:
            faiss_index_name = get_faiss_index_name(args)
        else:
            faiss_index_name = get_faiss_index_name(args, offset=part_offset, endpos=part_endpos)

        output_path = os.path.join(args.index_path, faiss_index_name)
        print_message(f"#> Processing slice #{slice_idx+1} of {args.slices} (range {part_offset}..{part_endpos}).")
        print_message(f"#> Will write to {output_path}.")

        assert not os.path.exists(output_path), output_path

        index = prepare_faiss_index(slice_samples_paths, args.partitions, args.sample)

        # maxsize=1: the loader stays at most one sub-collection ahead of the consumer.
        loaded_parts = queue.Queue(maxsize=1)

        def _loader_thread(thread_parts_paths, output_queue):
            try:
                for filenames in grouper(thread_parts_paths, SPAN, fillvalue=None):
                    # grouper pads the last group with None; skip the padding.
                    sub_collection = [load_index_part(filename) for filename in filenames if filename is not None]
                    sub_collection = torch.cat(sub_collection)
                    sub_collection = sub_collection.float().numpy()
                    output_queue.put(sub_collection)
            except Exception as ex:
                # Hand the failure to the consumer; otherwise it would wait on get() forever.
                output_queue.put(ex)

        # daemon=True: if the consumer fails while the loader is blocked on put(),
        # the stuck loader must not keep the interpreter from exiting.
        thread = threading.Thread(target=_loader_thread, args=(slice_parts_paths, loaded_parts), daemon=True)
        thread.start()

        print_message("#> Indexing the vectors...")

        for filenames in grouper(slice_parts_paths, SPAN, fillvalue=None):
            print_message("#> Loading", filenames, "(from queue)...")
            sub_collection = loaded_parts.get()

            if isinstance(sub_collection, Exception):
                raise sub_collection

            print_message("#> Processing a sub_collection with shape", sub_collection.shape)
            index.add(sub_collection)

        print_message("Done indexing!")

        index.save(output_path)

        print_message(f"\n\nDone! All complete (for slice #{slice_idx+1} of {args.slices})!")

        thread.join()

In [13]:
# utils/utils.py
def print_message(*s, condition=True):
    """
    Build a timestamped message from the space-joined `str()` of every argument.

    The message is printed (flushed) only when `condition` is true, but it is
    returned in either case.
    """
    text = ' '.join(map(str, s))
    stamp = datetime.datetime.now().strftime("%b %d, %H:%M:%S")
    msg = f"[{stamp}] {text}"

    if condition:
        print(msg, flush=True)

    return msg
def timestamp():
    """Return the current local time formatted as 'YYYY-MM-DD_HH.MM.SS'."""
    return datetime.datetime.now().strftime("%Y-%m-%d_%H.%M.%S")
def create_directory(path):
    """Create `path` (including parents) unless it already exists; report which case applied."""
    already_there = os.path.exists(path)

    print('\n')

    if already_there:
        print_message("#> Note: Output directory", path, 'already exists\n\n')
        return

    print_message("#> Creating directory", path, '\n\n')
    os.makedirs(path)
def distributed_init(rank):
    """
    Work out the number of ranks from the WORLD_SIZE environment variable and,
    when there is more than one, initialize the NCCL process group.

    Returns `(nranks, is_distributed)`; `nranks` is at least 1.
    """
    world_size = os.environ.get('WORLD_SIZE')
    nranks = 1 if world_size is None else max(1, int(world_size))
    is_distributed = nranks > 1

    if rank == 0:
        print('nranks =', nranks, '\t num_gpus =', 1)  # torch.cuda.device_count())

    if is_distributed:
        # The GPU count is pinned to 1 here instead of torch.cuda.device_count().
        num_gpus = 1
        torch.cuda.set_device(rank % num_gpus)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    return nranks, is_distributed
def distributed_barrier(rank):
    """
    Synchronize all processes of the default process group.

    Does nothing when `rank` is negative (the non-distributed default) or when
    no process group has been initialized, e.g. a single-process run started
    with rank 0.
    """
    if rank < 0:
        return

    # is_available() must be checked first: is_initialized() only exists when
    # torch was built with distributed support.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()
def grouper(iterable, n, fillvalue=None):
    """
    Collect data into fixed-length chunks or blocks
        Example: grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
        Source: https://docs.python.org/3/library/itertools.html#itertools-recipes

    Returns an iterator of `n`-tuples; the last tuple is padded with `fillvalue`.
    """
    # Imported locally because the notebook's import cell only pulls in
    # `accumulate` from itertools.
    import itertools

    # The same iterator repeated n times: zip_longest draws n consecutive items per tuple.
    args = [iter(iterable)] * n
    return itertools.zip_longest(*args, fillvalue=fillvalue)

In [14]:
# utils/logging.py
class Logger():
    """
    Per-run logging to the console, to files under '<run.path>/logs/', to MLflow
    and (lazily) to TensorBoard.

    Only the main process (rank -1 or 0) sets up MLflow and creates the logs
    directory; most logging methods are no-ops on other ranks.

    Args:
        rank: process rank; -1 and 0 count as "main".
        run: object providing `path`, `experiments_root`, `experiment`, `script`
            and `name` (the run manager).
    """

    def __init__(self, rank, run):
        self.rank = rank
        self.is_main = self.rank in [-1, 0]
        self.run = run
        self.logs_path = os.path.join(self.run.path, "logs/")

        if self.is_main:
            self._init_mlflow()
            self.initialized_tensorboard = False
            create_directory(self.logs_path)

    def _init_mlflow(self):
        """Point MLflow at '<experiments_root>/logs/mlruns/' and tag the run."""
        mlflow.set_tracking_uri('file://' + os.path.join(self.run.experiments_root, "logs/mlruns/"))
        mlflow.set_experiment('/'.join([self.run.experiment, self.run.script]))

        mlflow.set_tag('experiment', self.run.experiment)
        mlflow.set_tag('name', self.run.name)
        mlflow.set_tag('path', self.run.path)

    def _init_tensorboard(self):
        """Create the TensorBoard writer under '<experiments_root>/logs/tensorboard/'."""
        # Imported here because nothing else in the notebook brings SummaryWriter
        # into scope; without it the first log_metric call raised NameError.
        from torch.utils.tensorboard import SummaryWriter

        root = os.path.join(self.run.experiments_root, "logs/tensorboard/")
        logdir = '__'.join([self.run.experiment, self.run.script, self.run.name])
        logdir = os.path.join(root, logdir)

        self.writer = SummaryWriter(log_dir=logdir)
        self.initialized_tensorboard = True

    def _log_exception(self, etype, value, tb):
        """Print the formatted traceback and store it as 'exception.txt' (main process only)."""
        if not self.is_main:
            return

        output_path = os.path.join(self.logs_path, 'exception.txt')
        trace = ''.join(traceback.format_exception(etype, value, tb)) + '\n'
        print_message(trace, '\n\n')

        self.log_new_artifact(output_path, trace)

    def _log_all_artifacts(self):
        """Upload everything under the logs directory to MLflow (main process only)."""
        if not self.is_main:
            return

        mlflow.log_artifacts(self.logs_path)

    def _log_args(self, args):
        """
        Record the run's arguments (main process only).

        int/float/str/bool attributes of `args` become MLflow params;
        `args.input_arguments` is dumped to 'args.json' and the command line to
        'args.txt'.
        """
        if not self.is_main:
            return

        for key in vars(args):
            value = getattr(args, key)
            if type(value) in [int, float, str, bool]:
                mlflow.log_param(key, value)

        with open(os.path.join(self.logs_path, 'args.json'), 'w') as output_metadata:
            ujson.dump(args.input_arguments.__dict__, output_metadata, indent=4)
            output_metadata.write('\n')

        with open(os.path.join(self.logs_path, 'args.txt'), 'w') as output_metadata:
            output_metadata.write(' '.join(sys.argv) + '\n')

    def log_metric(self, name, value, step, log_to_mlflow=True):
        """Log a scalar to TensorBoard and, unless disabled, to MLflow (main process only)."""
        if not self.is_main:
            return

        # The TensorBoard writer is created on first use.
        if not self.initialized_tensorboard:
            self._init_tensorboard()

        if log_to_mlflow:
            mlflow.log_metric(name, value, step=step)
        self.writer.add_scalar(name, value, step)

    def log_new_artifact(self, path, content):
        """Write `content` to `path` and register that file as an MLflow artifact."""
        with open(path, 'w') as f:
            f.write(content)

        mlflow.log_artifact(path)

    def warn(self, *args):
        """Print a '[WARNING]' message and append it to 'warnings.txt' in the logs directory."""
        msg = print_message('[WARNING]', '\t', *args)

        with open(os.path.join(self.logs_path, 'warnings.txt'), 'a') as output_metadata:
            output_metadata.write(msg + '\n\n\n')

    def info_all(self, *args):
        """Print a message on every rank, prefixed with the rank."""
        print_message('[' + str(self.rank) + ']', '\t', *args)

    def info(self, *args):
        """Print a message on the main process only."""
        if self.is_main:
            print_message(*args)

In [15]:
# utils/runs.py
class _RunManager():
    def __init__(self):
        self.experiments_root = None
        self.experiment = None
        self.path = None
        self.script = self._get_script_name()
        self.name = self._generate_default_run_name()
        self.original_name = self.name
        self.exit_status = 'FINISHED'

        self._logger = None
        self.start_time = time.time()

    def init(self, rank, root, experiment, name):
        assert '/' not in experiment, experiment
        assert '/' not in name, name

        self.experiments_root = os.path.abspath(root)
        self.experiment = experiment
        self.name = name
        self.path = os.path.join(self.experiments_root, self.experiment, self.script, self.name)

        if rank < 1:
            if os.path.exists(self.path):
                print('\n\n')
                print_message("It seems that ", self.path, " already exists.")
                print_message("Do you want to overwrite it? \t yes/no \n")

                # TODO: This should timeout and exit (i.e., fail) given no response for 60 seconds.

                response = input()
                if response.strip() != 'yes':
                    assert not os.path.exists(self.path), self.path
            else:
                create_directory(self.path)

        distributed_barrier(rank)

        self._logger = Logger(rank, self)
        self._log_args = self._logger._log_args
        self.warn = self._logger.warn
        self.info = self._logger.info
        self.info_all = self._logger.info_all
        self.log_metric = self._logger.log_metric
        self.log_new_artifact = self._logger.log_new_artifact

    def _generate_default_run_name(self):
        return timestamp()

    def _get_script_name(self):
        return os.path.basename('main'.__file__) if '__file__' in dir('main') else 'none'
#         return os.path.basename(__main__.__file__) if '__file__' in dir(__main__) else 'none'


    @contextmanager
    def context(self, consider_failed_if_interrupted=True):
        try:
            yield

        except KeyboardInterrupt as ex:
            print('\n\nInterrupted\n\n')
            self._logger._log_exception(ex.__class__, ex, ex.__traceback__)
            self._logger._log_all_artifacts()

            if consider_failed_if_interrupted:
                self.exit_status = 'KILLED'  # mlflow.entities.RunStatus.KILLED

            sys.exit(128 + 2)

        except Exception as ex:
            self._logger._log_exception(ex.__class__, ex, ex.__traceback__)
            self._logger._log_all_artifacts()

            self.exit_status = 'FAILED'  # mlflow.entities.RunStatus.FAILED

            raise ex

        finally:
            total_seconds = str(time.time() - self.start_time) + '\n'
            original_name = str(self.original_name)
            name = str(self.name)

            self.log_new_artifact(os.path.join(self._logger.logs_path, 'elapsed.txt'), total_seconds)
            self.log_new_artifact(os.path.join(self._logger.logs_path, 'name.original.txt'), original_name)
            self.log_new_artifact(os.path.join(self._logger.logs_path, 'name.txt'), name)

            self._logger._log_all_artifacts()

            mlflow.end_run(status=self.exit_status)


# Module-level run manager shared by the cells below (Arguments reads Run.name and
# calls Run.init; main() wraps its work in Run.context()).
Run = _RunManager()

In [17]:
#utils/parser.py
class Arguments():
    """
    Thin wrapper around `ArgumentParser` that pre-registers the options shared
    by the scripts (--root, --experiment, --run, --local_rank) and, in `parse`,
    also initializes distributed state, the faiss thread count and the global `Run`.
    """

    def __init__(self, description):
        self.parser = ArgumentParser(description=description)
        self.checks = []  # callables invoked with the parsed args by check_arguments

        self.add_argument('--root', dest='root', default='/')
        self.add_argument('--experiment', dest='experiment', default='MSMARCO-psg')
        self.add_argument('--run', dest='run', default=Run.name)

        self.add_argument('--local_rank', dest='rank', default=-1, type=int)

    def add_model_parameters(self):
        """Register the model-related options."""
        # Core Arguments
        self.add_argument('--similarity', dest='similarity', default='cosine', choices=['cosine', 'l2'])
        self.add_argument('--dim', dest='dim', default=128, type=int)
        self.add_argument('--query_maxlen', dest='query_maxlen', default=32, type=int)
        self.add_argument('--doc_maxlen', dest='doc_maxlen', default=180, type=int)

        # Filtering-related Arguments
        self.add_argument('--mask-punctuation', dest='mask_punctuation', default=False, action='store_true')

    def add_index_use_input(self):
        """Register the options that locate an existing index and set its partition count."""
        self.add_argument('--index_root', dest='index_root', default='/')
        self.add_argument('--index_name', dest='index_name', default='MSMARCO.L2.32x200k')
        self.add_argument('--partitions', dest='partitions', default=32768, type=int)

    def add_argument(self, *args, **kw_args):
        """Forward to `ArgumentParser.add_argument`."""
        return self.parser.add_argument(*args, **kw_args)

    def check_arguments(self, args):
        """Run every registered check against the parsed `args`."""
        for check in self.checks:
            check(args)

    def parse(self):
        """
        Parse the command line and finish process-wide setup.

        Besides the parsed options, the returned namespace carries
        `input_arguments` (a deep copy of the raw parsed options), `nranks`,
        `distributed` and `nthreads`.
        """
        args = self.parser.parse_args()
        self.check_arguments(args)

        args.input_arguments = copy.deepcopy(args)

        args.nranks, args.distributed = distributed_init(args.rank)

        # os.cpu_count() returns None when the count cannot be determined; treat
        # that as 1 so that max() does not fail comparing None with an int.
        cpu_count = os.cpu_count() or 1
        args.nthreads = int(max(cpu_count, faiss.omp_get_max_threads()) * 0.8)
        args.nthreads = max(1, args.nthreads // args.nranks)

        if args.nranks > 1:
            print_message(f"#> Restricting number of threads for FAISS to {args.nthreads} per process",
                          condition=(args.rank == 0))
            faiss.omp_set_num_threads(args.nthreads)

        Run.init(args.rank, args.root, args.experiment, args.run)
        Run._log_args(args)
        Run.info(args.input_arguments.__dict__, '\n')

        return args

In [None]:
## index_faiss.py
'''
from colbert.utils.runs import Run
from colbert.utils.parser import Arguments
from colbert.indexing.faiss import index_faiss
from colbert.indexing.loaders import load_doclens
'''
def main():
    """
    Entry point: parse the options, validate them, and build the faiss index
    for the ColBERT index at '<index_root>/<index_name>'.
    """
    random.seed(12345)

    cli = Arguments(description='Faiss indexing for end-to-end retrieval with ColBERT.')
    cli.add_index_use_input()

    cli.add_argument('--sample', dest='sample', default=0.3, type=float)
    cli.add_argument('--slices', dest='slices', default=1, type=int)
    # NOTE(review): presumably here to absorb the `-f <file>` argument a Jupyter
    # kernel is launched with -- confirm.
    cli.add_argument('-f')

    args = cli.parse()
    assert args.slices >= 1
    assert args.sample is None or (0.0 < args.sample < 1.0), args.sample

    with Run.context():
        args.index_path = os.path.join(args.index_root, args.index_name)
        assert os.path.exists(args.index_path), args.index_path

        num_embeddings = sum(load_doclens(args.index_path))
        print("#> num_embeddings =", num_embeddings)

        if args.partitions is None:
            # Smallest power of two that is >= 8 * sqrt(num_embeddings).
            exponent = math.ceil(math.log2(8 * math.sqrt(num_embeddings)))
            args.partitions = 1 << exponent

            print('\n\n')
            Run.warn("You did not specify --partitions!")
            Run.warn("Default computation chooses", args.partitions,
                     f"partitions (for {num_embeddings} embeddings)")
            print('\n\n')

        index_faiss(args)


if __name__ == "__main__":
    main()