<a href="https://colab.research.google.com/github/finardi/Ranking/blob/main/6_Cobert_Retrieval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
!pip install -q ujson
!apt install -q libomp-dev
!pip install -q transformers
!python -m pip -q install --upgrade faiss faiss-gpu

In [None]:
import gc
import os
import sys
import math
import time
import ujson
import torch
import faiss
import queue
import random
import pickle
import threading
import traceback
import itertools
import numpy as np
import pandas as pd
from math import ceil
from itertools import islice
from functools import partial
from multiprocessing import Pool
from itertools import accumulate
from collections import defaultdict, OrderedDict
from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast

# better pandas viz: raise the notebook-wide DataFrame display limits so wide
# frames and long text cells are shown instead of being elided.
pd.set_option('display.max_columns', 100)  
# NOTE(review): pandas documents `expand_frame_repr` as a boolean switch; the
# value 100 is merely truthy here -- confirm that was the intent.
pd.set_option('display.expand_frame_repr', 100)
pd.set_option('max_colwidth', 700)
pd.set_option('display.max_rows', 5000)
  
# save/load pickles
def pickle_file(path, data=None):
    """Load a pickle from ``path``, or store ``data`` there.

    When ``data`` is ``None`` the file at ``path`` is read and the unpickled
    object is returned. For any other ``data`` the object is pickled to
    ``path`` with the highest available protocol and ``None`` is returned.
    """
    if data is not None:
        with open(path, 'wb') as sink:
            pickle.dump(data, sink, protocol=pickle.HIGHEST_PROTOCOL)
        return None

    with open(path, 'rb') as source:
        return pickle.load(source)
 
# path base: root folder that the data files and the model checkpoint used in
# this notebook are read from (every later path is built as path_base + 'data/...').
# NOTE(review): this is a Google Drive location; presumably Drive is mounted
# by a step outside this view -- confirm.
path_base = '/content/drive/MyDrive/ColBERT/ColBERT - FAQ Receita Federal/'

In [None]:
# =============
# ✨ Constants
# =============
# Number of triples per training batch; also read as a module-level global by
# IndexRanker when it sizes its gather buffers.
bsize = 16 # N
# Fixed token length every query is padded/truncated to.
query_maxlen = 48
# Upper bound on the token length of a document.
doc_maxlen = 128
# Hugging Face model name used for both the tokenizers and the BERT weights.
path_model = 'bert-base-multilingual-uncased'

# ==================
# ✨ QueryTokenizer
# ==================
class QueryTokenizer():
    """Tokenize queries to a fixed length, filling the tail with ``[MASK]``.

    Every query is laid out as ``[CLS] tokens... [SEP] [MASK]...`` so that
    the total length equals ``query_maxlen``.
    """

    def __init__(self, query_maxlen, path_tokenizer):
        """Load the tokenizer at ``path_tokenizer`` and cache its special tokens."""
        self.tok = BertTokenizerFast.from_pretrained(path_tokenizer)
        self.query_maxlen = query_maxlen

        self.cls_token, self.cls_token_id = self.tok.cls_token, self.tok.cls_token_id
        self.sep_token, self.sep_token_id = self.tok.sep_token, self.tok.sep_token_id
        self.mask_token, self.mask_token_id = self.tok.mask_token, self.tok.mask_token_id

    def _wrap_and_pad(self, body, cls, sep, mask):
        """Return ``[cls] + body + [sep]`` right-padded with ``mask`` to ``query_maxlen``.

        Exactly two special items (CLS and SEP) are added, so the padding
        count is ``query_maxlen - (len(body) + 2)``. A negative count yields
        no padding; over-long bodies are not truncated here.
        """
        n_pad = self.query_maxlen - (len(body) + 2)
        return [cls] + body + [sep] + [mask] * n_pad

    def tokenize(self, batch_text, add_special_tokens=False):
        """Split each text into word pieces.

        Args:
            batch_text: list or tuple of strings.
            add_special_tokens: when true, wrap each sequence in CLS/SEP and
                pad it with the mask token up to ``query_maxlen``.

        Returns:
            A list with one list of token strings per input text.
        """
        assert type(batch_text) in [list, tuple], (type(batch_text))

        tokens = [self.tok.tokenize(x, add_special_tokens=False) for x in batch_text]

        if not add_special_tokens:
            return tokens

        return [self._wrap_and_pad(lst, self.cls_token, self.sep_token, self.mask_token)
                for lst in tokens]

    def encode(self, batch_text, add_special_tokens=False):
        """Same as :meth:`tokenize` but returns vocabulary ids instead of strings."""
        assert type(batch_text) in [list, tuple], (type(batch_text))

        ids = self.tok(batch_text, add_special_tokens=False)['input_ids']

        if not add_special_tokens:
            return ids

        return [self._wrap_and_pad(lst, self.cls_token_id, self.sep_token_id, self.mask_token_id)
                for lst in ids]

    def tensorize(self, batch_text, bsize=None):
        """Encode a batch of queries into fixed-length tensors.

        Args:
            batch_text: list or tuple of strings.
            bsize: optional batch size; when given the result is split into
                ``(ids, mask)`` chunks of at most ``bsize`` rows.

        Returns:
            ``(ids, mask)`` tensors, or a list of such pairs when ``bsize``
            is set. Padding positions in ``ids`` hold the mask-token id.
        """
        assert type(batch_text) in [list, tuple], (type(batch_text))

        obj = self.tok(batch_text, padding='max_length', truncation=True,
                       return_tensors='pt', max_length=self.query_maxlen)

        ids, mask = obj['input_ids'], obj['attention_mask']

        # Replace padding with [MASK]; ask the tokenizer for its pad id rather
        # than assuming it is 0.
        ids[ids == self.tok.pad_token_id] = self.mask_token_id

        if bsize:
            batches = _split_into_batches(ids, mask, bsize)
            return batches

        return ids, mask

# ================
# ✨ DocTokenizer
# ================
class DocTokenizer():
    """Tokenize documents, padding each batch only up to its longest member."""

    def __init__(self, doc_maxlen, path_tokenizer):
        """Load the tokenizer at ``path_tokenizer`` and cache CLS/SEP tokens."""
        self.tok = BertTokenizerFast.from_pretrained(path_tokenizer)
        self.doc_maxlen = doc_maxlen

        self.cls_token = self.tok.cls_token
        self.cls_token_id = self.tok.cls_token_id
        self.sep_token = self.tok.sep_token
        self.sep_token_id = self.tok.sep_token_id

    def tokenize(self, batch_text, add_special_tokens=False):
        """Return word pieces per text, optionally wrapped in CLS ... SEP."""
        assert type(batch_text) in [list, tuple], (type(batch_text))

        pieces = [self.tok.tokenize(text, add_special_tokens=False) for text in batch_text]

        if add_special_tokens:
            pieces = [[self.cls_token] + seq + [self.sep_token] for seq in pieces]

        return pieces

    def encode(self, batch_text, add_special_tokens=False):
        """Return vocabulary ids per text, optionally wrapped in CLS ... SEP ids."""
        assert type(batch_text) in [list, tuple], (type(batch_text))

        encoded = self.tok(batch_text, add_special_tokens=False)['input_ids']

        if add_special_tokens:
            encoded = [[self.cls_token_id] + seq + [self.sep_token_id] for seq in encoded]

        return encoded

    def tensorize(self, batch_text, bsize=None):
        """Encode documents into ``(ids, mask)`` tensors.

        Sequences are truncated to ``doc_maxlen`` and padded to the longest
        one in the call. With ``bsize`` set, rows are first reordered by
        ``_sort_by_length`` and then chunked; the return value is then
        ``(batches, reverse_indices)`` where ``reverse_indices`` restores the
        caller's original order.
        """
        assert type(batch_text) in [list, tuple], (type(batch_text))

        encoding = self.tok(batch_text, padding='longest', truncation='longest_first',
                            return_tensors='pt', max_length=self.doc_maxlen)

        ids = encoding['input_ids']
        mask = encoding['attention_mask']

        if not bsize:
            return ids, mask

        ids, mask, reverse_indices = _sort_by_length(ids, mask, bsize)
        return _split_into_batches(ids, mask, bsize), reverse_indices

# =====================
# ✨ tensorize triples
# =====================
def tensorize_triples(query_tokenizer, doc_tokenizer, queries, positives, negatives, bsize):
    """Turn (query, positive, negative) text triples into model-ready batches.

    All three lists must have the same length, and ``bsize`` must equal that
    length (asserted below). Triples are reordered by the longer of their
    two passages before being chunked.

    Returns:
        A list of ``(Q, D)`` pairs. ``Q`` is ``(ids, mask)`` with the query
        rows repeated twice; ``D`` is ``(ids, mask)`` with the positive rows
        followed by the negative rows.
    """
    assert len(queries) == len(positives) == len(negatives)
    assert bsize is None or len(queries) % bsize == 0

    num_triples = len(queries)
    assert bsize == num_triples

    q_ids_all, q_mask_all = query_tokenizer.tensorize(queries)
    d_ids_all, d_mask_all = doc_tokenizer.tensorize(positives + negatives)

    # Row 0 of the leading axis holds the positives, row 1 the negatives.
    d_ids_all = d_ids_all.view(2, num_triples, -1)
    d_mask_all = d_mask_all.view(2, num_triples, -1)

    # Per triple: max(length of positive, length of negative), then sort by it.
    order = d_mask_all.sum(-1).max(0).values.sort().indices

    q_ids_all, q_mask_all = q_ids_all[order], q_mask_all[order]
    d_ids_all, d_mask_all = d_ids_all[:, order], d_mask_all[:, order]

    pos_ids, neg_ids = d_ids_all
    pos_mask, neg_mask = d_mask_all

    chunked = zip(
        _split_into_batches(q_ids_all, q_mask_all, bsize),
        _split_into_batches(pos_ids, pos_mask, bsize),
        _split_into_batches(neg_ids, neg_mask, bsize),
    )

    return [
        ((torch.cat((q_ids, q_ids)), torch.cat((q_mask, q_mask))),
         (torch.cat((p_ids, n_ids)), torch.cat((p_mask, n_mask))))
        for (q_ids, q_mask), (p_ids, p_mask), (n_ids, n_mask) in chunked
    ]

# =============
# ✨ Aux funcs
# =============
def _sort_by_length(ids, mask, bsize):
    if ids.size(0) <= bsize:
        return ids, mask, torch.arange(ids.size(0))

    indices = mask.sum(-1).sort().indices
    reverse_indices = indices.sort().indices

    return ids[indices], mask[indices], reverse_indices

def _split_into_batches(ids, mask, bsize):
    batches = []
    for offset in range(0, ids.size(0), bsize):
        batches.append((ids[offset:offset+bsize], mask[offset:offset+bsize]))

    return batches

# ===============
# ✨ LazyBatcher
# ===============
class LazyBatcher():
    """Iterator that yields tokenized training batches from id triples.

    Triples ``(qid, pos_pid, neg_pid)``, the qid->query mapping and the
    pid->document mapping are loaded from files under ``path``; which files
    are read depends on ``mode`` (``'train'`` or ``'valid'``).
    """

    def __init__(self, bsize, path, path_tokenizer, query_maxlen, doc_maxlen, mode='train', accumsteps=1):
        """Build the tokenizers and load the data files.

        Args:
            bsize: number of triples per yielded batch.
            path: base directory (with trailing separator) holding ``data/``.
            path_tokenizer: tokenizer name or path for both tokenizers.
            query_maxlen: fixed query length in tokens.
            doc_maxlen: maximum document length in tokens.
            mode: ``'train'`` or ``'valid'``; selects the file set.
            accumsteps: divisor applied to ``bsize`` when tensorizing.
        """
        self.bsize, self.accumsteps = bsize, accumsteps
        self.query_tokenizer = QueryTokenizer(query_maxlen=query_maxlen, path_tokenizer=path_tokenizer)
        self.doc_tokenizer = DocTokenizer(doc_maxlen=doc_maxlen, path_tokenizer=path_tokenizer)
        self.tensorize_triples = partial(tensorize_triples, self.query_tokenizer, self.doc_tokenizer)
        self.position = 0
        self.mode = mode

        # Read from the directory the caller passed in rather than from the
        # notebook-level `path_base` global.
        self.triples = self._load_triples(path)
        self.queries = self._load_queries(path)
        self.collection = self._load_collection(path)

    def _load_triples(self, path):
        """Read the triples parquet for the current mode as a list of tuples."""
        if self.mode == 'train':
            path = path+'data/df_FAQ_triplet_IDS_TRAIN.parquet.gzip'
        elif self.mode == 'valid':
            path = path+'data/df_FAQ_triplet_IDS_VALID.parquet.gzip'

        df_triplet = pd.read_parquet(path)

        return list(zip(
            df_triplet.qid.values,
            df_triplet.pos_pid.values,
            df_triplet.neg_pid.values,
        ))

    def _load_queries(self, path):
        """Unpickle the qid->query mapping for the current mode (None for other modes)."""
        if self.mode == 'train':
            return pickle_file(path+'data/qid_to_query_TRAIN')
        elif self.mode == 'valid':
            return pickle_file(path+'data/qid_to_query_VALID')

    def _load_collection(self, path):
        """Unpickle the pid->document mapping for the current mode (None for other modes)."""
        if self.mode == 'train':
            return pickle_file(path+'data/pid_to_doc_TRAIN')
        elif self.mode == 'valid':
            return pickle_file(path+'data/pid_to_doc_VALID')

    def __iter__(self):
        return self

    def __len__(self):
        """Number of triples (not number of batches)."""
        return len(self.triples)

    def __next__(self):
        """Return the next full batch; stop once fewer than ``bsize`` triples remain."""
        offset, endpos = self.position, min(self.position + self.bsize, len(self.triples))
        self.position = endpos

        # A trailing partial batch is dropped rather than yielded.
        if offset + self.bsize > len(self.triples):
            raise StopIteration

        queries, positives, negatives = [], [], []

        for qid, pos_pid, neg_pid in self.triples[offset:endpos]:
            queries.append(self.queries[qid])
            positives.append(self.collection[pos_pid])
            negatives.append(self.collection[neg_pid])

        return self.collate(queries, positives, negatives)

    def collate(self, queries, positives, negatives):
        """Tensorize one batch of texts; all three lists must hold ``bsize`` items."""
        assert len(queries) == len(positives) == len(negatives) == self.bsize

        return self.tensorize_triples(queries, positives, negatives, self.bsize // self.accumsteps)

# ===========
# ✨ ColBERT
# ===========
class ColBERT(BertPreTrainedModel):
    """BERT encoder with a linear projection, scored by late interaction.

    Queries and documents are encoded token by token, projected to ``dim``
    dimensions and L2-normalised; ``score`` then sums, over query tokens,
    the best match against any document token.
    """

    def __init__(self, config, query_maxlen, doc_maxlen, mask_punctuation, dim=128, similarity_metric='cosine'):
        """Create the encoder.

        Args:
            config: BERT configuration handed to ``BertModel``.
            query_maxlen: stored for use by tokenizers built from this model.
            doc_maxlen: stored for use by tokenizers built from this model.
            mask_punctuation: when true, punctuation tokens are added to the
                skiplist and zeroed out of document embeddings.
            dim: output embedding size of the linear projection.
            similarity_metric: ``'cosine'`` or ``'l2'``.
        """
        super(ColBERT, self).__init__(config)

        self.query_maxlen = query_maxlen
        self.doc_maxlen = doc_maxlen
        self.similarity_metric = similarity_metric
        self.dim = dim

        self.mask_punctuation = mask_punctuation
        self.skiplist = {}

        if self.mask_punctuation:
            # Imported here because the notebook's import cell does not
            # provide `string`; without it this branch raised NameError.
            import string

            # Tokenizer comes from the module-level `path_model`.
            self.tokenizer = BertTokenizerFast.from_pretrained(path_model)
            # Holds both the symbol and its first token id as keys.
            self.skiplist = {w: True
                             for symbol in string.punctuation
                             for w in [symbol, self.tokenizer.encode(symbol, add_special_tokens=False)[0]]}

        self.bert = BertModel(config)
        self.linear = torch.nn.Linear(config.hidden_size, dim, bias=False)

        self.init_weights()

    def forward(self, Q, D):
        """Score query inputs ``Q`` against document inputs ``D`` (each an ``(ids, mask)`` pair)."""
        return self.score(self.query(*Q), self.doc(*D))

    def query(self, input_ids, attention_mask):
        """Encode queries into L2-normalised per-token embeddings on ``DEVICE``."""
        input_ids, attention_mask = input_ids.to(DEVICE), attention_mask.to(DEVICE)
        Q = self.bert(input_ids, attention_mask=attention_mask)[0]
        Q = self.linear(Q)

        return torch.nn.functional.normalize(Q, p=2, dim=2)

    def doc(self, input_ids, attention_mask, keep_dims=True):
        """Encode documents into L2-normalised per-token embeddings.

        Skiplisted tokens and tokens with id 0 are multiplied by zero before
        normalisation. With ``keep_dims=False`` the result is moved to CPU
        as float16 and returned as a list with one tensor per document that
        contains only the unmasked rows.
        """
        input_ids, attention_mask = input_ids.to(DEVICE), attention_mask.to(DEVICE)
        D = self.bert(input_ids, attention_mask=attention_mask)[0]
        D = self.linear(D)

        mask = torch.tensor(self.mask(input_ids), device=DEVICE).unsqueeze(2).float()
        D = D * mask

        D = torch.nn.functional.normalize(D, p=2, dim=2)

        if not keep_dims:
            D, mask = D.cpu().to(dtype=torch.float16), mask.cpu().bool().squeeze(-1)
            D = [d[mask[idx]] for idx, d in enumerate(D)]

        return D

    def score(self, Q, D):
        """Late-interaction score: per query token, best document token; summed over the query."""
        if self.similarity_metric == 'cosine':
            return (Q @ D.permute(0, 2, 1)).max(2).values.sum(1)

        assert self.similarity_metric == 'l2'
        # Negative squared distance, so that larger still means closer.
        return (-1.0 * ((Q.unsqueeze(2) - D.unsqueeze(1))**2).sum(-1)).max(-1).values.sum(-1)

    def mask(self, input_ids):
        """Return nested lists of booleans: True where a token is kept (not skiplisted, not id 0)."""
        mask = [[(x not in self.skiplist) and (x != 0) for x in d] for d in input_ids.cpu().tolist()]
        return mask

# ==================
# ✨ ModelInference
# ==================
class ModelInference():
    """Gradient-free wrapper that encodes raw text with a ColBERT model."""

    def __init__(self, colbert, path_model):
        """Wrap ``colbert`` (must be in eval mode) and build matching tokenizers."""
        assert colbert.training is False

        self.colbert = colbert
        self.query_tokenizer = QueryTokenizer(colbert.query_maxlen, path_tokenizer=path_model)
        self.doc_tokenizer = DocTokenizer(colbert.doc_maxlen, path_tokenizer=path_model)

    def query(self, *args, to_cpu=False, **kw_args):
        """Run ``colbert.query`` without gradients; optionally move the result to CPU."""
        with torch.no_grad():
            Q = self.colbert.query(*args, **kw_args)
            return Q.cpu() if to_cpu else Q

    def doc(self, *args, to_cpu=False, **kw_args):
        """Run ``colbert.doc`` without gradients; optionally move the result to CPU.

        ``colbert.doc(..., keep_dims=False)`` returns a list of tensors, so
        the CPU transfer is applied per element in that case instead of
        calling ``.cpu()`` on the list.
        """
        with torch.no_grad():
            D = self.colbert.doc(*args, **kw_args)

            if not to_cpu:
                return D
            if torch.is_tensor(D):
                return D.cpu()
            return [d.cpu() for d in D]

    def queryFromText(self, queries, bsize=None, to_cpu=False):
        """Tokenize and encode ``queries``; batched when ``bsize`` is set.

        ``to_cpu`` is honoured on both the batched and the single-shot path.
        """
        if bsize:
            batches = self.query_tokenizer.tensorize(queries, bsize=bsize)
            batches = [self.query(input_ids, attention_mask, to_cpu=to_cpu) for input_ids, attention_mask in batches]
            return torch.cat(batches)

        input_ids, attention_mask = self.query_tokenizer.tensorize(queries)
        return self.query(input_ids, attention_mask, to_cpu=to_cpu)

    def docFromText(self, docs, bsize=None, keep_dims=True, to_cpu=False):
        """Tokenize and encode ``docs``, returned in the caller's order.

        With ``keep_dims=True`` the result is one zero-padded 3-D tensor;
        otherwise a list with one tensor per document. ``to_cpu`` is
        honoured on both the batched and the single-shot path.
        """
        if bsize:
            batches, reverse_indices = self.doc_tokenizer.tensorize(docs, bsize=bsize)

            batches = [self.doc(input_ids, attention_mask, keep_dims=keep_dims, to_cpu=to_cpu)
                       for input_ids, attention_mask in batches]

            if keep_dims:
                D = _stack_3D_tensors(batches)
                return D[reverse_indices]

            # Undo the length-sorting done by the tokenizer.
            D = [d for batch in batches for d in batch]
            return [D[idx] for idx in reverse_indices.tolist()]

        input_ids, attention_mask = self.doc_tokenizer.tensorize(docs)
        return self.doc(input_ids, attention_mask, keep_dims=keep_dims, to_cpu=to_cpu)

    def score(self, Q, D, mask=None, lengths=None, explain=False):
        """Score documents ``D`` against ``Q`` with max-over-dim-1 then sum.

        Either ``mask`` or ``lengths`` (not both) may be given to zero out
        padded document positions; ``lengths`` is converted to a mask where
        position ``i`` is kept when ``i + 1 <= length``. ``explain`` is not
        implemented. The result is returned on CPU.
        """
        if lengths is not None:
            assert mask is None, "don't supply both mask and lengths"

            mask = torch.arange(D.size(1), device=DEVICE) + 1
            mask = mask.unsqueeze(0) <= lengths.to(DEVICE).unsqueeze(-1)

        scores = (D @ Q)
        scores = scores if mask is None else scores * mask.unsqueeze(-1)
        scores = scores.max(1)

        if explain:
            assert False, "TODO"

        return scores.values.sum(-1).cpu()

def _stack_3D_tensors(groups):
    bsize = sum([x.size(0) for x in groups])
    maxlen = max([x.size(1) for x in groups])
    hdim = groups[0].size(2)

    output = torch.zeros(bsize, maxlen, hdim, device=groups[0].device, dtype=groups[0].dtype)

    offset = 0
    for x in groups:
        endpos = offset + x.size(0)
        output[offset:endpos, :x.size(1)] = x
        offset = endpos

    return output    

# - - - - -
# Build the training loader from the notebook-level constants.
dataloader_train = LazyBatcher(
    bsize=bsize,
    path=path_base,
    path_tokenizer=path_model,
    query_maxlen=query_maxlen,
    doc_maxlen=doc_maxlen,
    mode='train'
    )

# Walk the loader once, printing the index of every batch it yields. This
# exhausts the iterator (its position is not reset afterwards).
print('batches:')
for i, batches in enumerate(dataloader_train):
    print(f' {i }.', end ='')

# Release a model left over from an earlier run of this cell. On the first
# run `colbert` does not exist yet, which is the only failure expected here,
# so only NameError is swallowed.
try:
    del colbert
except NameError:
    pass
else:
    gc.collect()
    torch.cuda.empty_cache()

DEVICE = 'cuda'

print()

# Initialise ColBERT from the pretrained BERT weights; the projection layer
# is newly initialised and is overwritten when the checkpoint is loaded.
colbert = ColBERT.from_pretrained(
    path_model,
    query_maxlen=query_maxlen,
    doc_maxlen=doc_maxlen,
    dim=128,
    similarity_metric='cosine',
    mask_punctuation=False).to(DEVICE)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=871891.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1715180.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=28.0, style=ProgressStyle(description_w…


batch:
 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35. 36. 37.


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=625.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=672271273.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at bert-base-multilingual-uncased were not used when initializing ColBERT: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing ColBERT from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ColBERT from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ColBERT were not initialized from the model checkpoint at bert-base-multilingual-uncased and are newly initialized: ['linear.weight'

In [None]:
# ✨ load colbert from checkpoint
# Overwrite the freshly initialised weights with the state dict stored under
# path_base + 'data/EPOCH_3_FAQ'.
# NOTE(review): torch.load is called without map_location, so tensors are
# restored to the device they were saved from -- confirm this matches DEVICE.
colbert.load_state_dict(torch.load(path_base+'data/EPOCH_3_FAQ'))
print('\nmodel loaded!\n')


model loaded!



# Ranking


In [None]:
# =========================
# ✨ Aux funcs for ranking
# =========================
def batch(group, bsize, provide_offset=False):
    """Yield consecutive slices of ``group`` holding at most ``bsize`` items.

    Args:
        group: any sliceable sequence supporting ``len``.
        bsize: maximum number of items per slice.
        provide_offset: when true, yield ``(offset, slice)`` pairs instead
            of bare slices.

    Raises:
        ValueError: if a slice comes back empty while items remain (a
            non-positive ``bsize``), which would otherwise loop forever.
            An empty ``group`` yields nothing regardless of ``bsize``.
    """
    offset = 0
    while offset < len(group):
        chunk = group[offset: offset + bsize]
        if len(chunk) == 0:
            raise ValueError(f"bsize must be a positive integer, got {bsize!r}")
        yield ((offset, chunk) if provide_offset else chunk)
        offset += len(chunk)

def torch_percentile(tensor, p):
    """Return the ``p``-th percentile of a 1-D tensor as a Python number.

    The value is the k-th smallest element with ``k = int(p * n / 100)``.
    For short tensors that expression can truncate to 0, which
    ``Tensor.kthvalue`` rejects (k is 1-based), so k is clamped to at
    least 1.

    Args:
        tensor: 1-D tensor.
        p: integer percentile in ``[1, 100]``.
    """
    assert p in range(1, 100+1)
    assert tensor.dim() == 1

    k = max(1, int(p * tensor.size(0) / 100.0))

    return tensor.kthvalue(k).values.item()

def load_index_part(filename, verbose=True):
    """Load one saved index part with ``torch.load``.

    A part stored as a list of tensors (kept for backward compatibility) is
    concatenated into a single tensor. ``verbose`` is accepted but not used.
    """
    loaded = torch.load(filename)

    # for backward compatibility
    return torch.cat(loaded) if type(loaded) == list else loaded

def load_doclens(directory, flatten=True):
    """Read the ``doclens.<part>.json`` file of every index part in ``directory``.

    Args:
        directory: index directory; its parts are discovered by ``get_parts``.
        flatten: when true, return one flat list over all parts; otherwise
            one list per part, in part order.

    Returns:
        The loaded document lengths, flat or grouped by part.
    """
    parts, _, _ = get_parts(directory)

    all_doclens = []
    for part in parts:
        doclens_path = os.path.join(directory, 'doclens.{}.json'.format(part))
        # Context manager so each handle is closed as soon as it is parsed.
        with open(doclens_path) as handle:
            all_doclens.append(ujson.load(handle))

    if flatten:
        # The `flatten` argument shadows the module-level helper of the same
        # name, hence the inline comprehension.
        all_doclens = [x for sub_doclens in all_doclens for x in sub_doclens]

    return all_doclens

def get_parts(directory):
    """List the numbered ``<n>.pt`` index parts found in ``directory``.

    The part numbers must be exactly ``0..k-1`` (asserted).

    Returns:
        ``(parts, parts_paths, samples_paths)``: the sorted part numbers,
        the matching ``<n>.pt`` paths and the matching ``<n>.sample`` paths.
    """
    suffix = '.pt'

    part_ids = []
    for entry in os.listdir(directory):
        if entry.endswith(suffix):
            part_ids.append(int(entry[:-len(suffix)]))

    # Sort numerically so that part 10 comes after part 9.
    part_ids.sort()

    assert list(range(len(part_ids))) == part_ids, part_ids

    tensor_paths = [os.path.join(directory, f'{part_id}{suffix}') for part_id in part_ids]
    sample_paths = [os.path.join(directory, f'{part_id}.sample') for part_id in part_ids]

    return part_ids, tensor_paths, sample_paths

def flatten(L):
    """Return a single list with the items of every iterable in ``L``, in order."""
    return list(itertools.chain.from_iterable(L))

def uniq(l):
    """Return the distinct items of ``l`` as a list (order is that of a set)."""
    distinct = set(l)
    return list(distinct)

# ===========
# ✨ Metrics
# ===========
class Metrics:
    """Running MRR / Recall / Success sums at several cut-off depths."""

    def __init__(self, mrr_depths:set, recall_depths:set, success_depths:set, total_queries=None):
        """Start every requested depth at 0.0 and remember ``total_queries``."""
        self.results = {}
        self.mrr_sums = dict.fromkeys(mrr_depths, 0.0)
        self.recall_sums = dict.fromkeys(recall_depths, 0.0)
        self.success_sums = dict.fromkeys(success_depths, 0.0)
        self.total_queries = total_queries

    def get_result(self, query_idx, query_key, ranking, gold_positives):
        """Record one query's ranking and add it to the running sums.

        Args:
            query_idx: position of the query; must be at least the number of
                results stored so far.
            query_key: unique key under which the ranking is stored.
            ranking: sequence of 3-tuples whose middle item is the pid.
            gold_positives: duplicate-free collection of relevant pids.

        A ranking without any relevant pid is stored but adds nothing.
        """
        assert query_key not in self.results
        assert len(self.results) <= query_idx
        assert len(set(gold_positives)) == len(gold_positives)
        assert len(set([pid for _, pid, _ in ranking])) == len(ranking)

        self.results[query_key] = ranking

        # 0-based ranks at which a relevant pid appears.
        hit_ranks = [rank for rank, (_, pid, _) in enumerate(ranking) if pid in gold_positives]

        if not hit_ranks:
            return

        best_rank = hit_ranks[0]

        for depth in self.mrr_sums:
            self.mrr_sums[depth] += (1.0 / (best_rank+1.0)) if best_rank < depth else 0.0

        for depth in self.success_sums:
            self.success_sums[depth] += 1.0 if best_rank < depth else 0.0

        for depth in self.recall_sums:
            found = sum(1 for rank in hit_ranks if rank < depth)
            self.recall_sums[depth] += found / len(gold_positives)

    def print_metrics(self, query_idx):
        """Print each sum divided by ``query_idx + 1``, grouped by metric."""
        rule = '- '*10
        denominator = query_idx+1.0

        print(rule)
        sections = (
            ('MRR', self.mrr_sums),
            ('Recall', self.recall_sums),
            ('Success', self.success_sums),
        )
        for label, sums in sections:
            for depth in sorted(sums):
                value = sums[depth] / denominator
                print(f"{label}@{str(depth):<2} = {value:.3}")
            print(rule)

def qid2passages(qid, topK_pids, collection, depth):
    """Return up to ``depth`` passages for ``qid``.

    With a ``collection`` the first ``depth`` pids in ``topK_pids[qid]`` are
    looked up in it. Without one, the passages are read from ``topK_docs``.
    """
    if collection is None:
        # NOTE(review): `topK_docs` is not a parameter; it is resolved from
        # module globals at call time -- confirm it is defined before use.
        return topK_docs[qid][:depth]

    wanted_pids = topK_pids[qid][:depth]
    return [collection[pid] for pid in wanted_pids]

In [None]:
# ===============
# ✨ IndexRanker
# ===============
class IndexRanker():
    """Score a query against selected documents stored in one flat embedding tensor.

    Documents are gathered through fixed-width strided windows over the flat
    tensor (one window width per entry in ``self.strides``), masked down to
    their true length, and scored with a max over window positions followed
    by a sum.
    """

    def __init__(self, tensor, doclens):
        """Prepare offsets, window widths, views and gather buffers.

        Args:
            tensor: the embedding tensor; its last dimension is the
                embedding size.
                # assumes a contiguous 2-D (num_embeddings, dim) layout -- TODO confirm
            doclens: per-document embedding counts (a Python sequence).
        """
        self.tensor = tensor
        self.doclens = doclens

        self.maxsim_dtype = torch.float32
        # Prefix sums: entry i is the number of embeddings before document i.
        self.doclens_pfxsum = [0] + list(accumulate(self.doclens))

        self.doclens = torch.tensor(self.doclens)
        self.doclens_pfxsum = torch.tensor(self.doclens_pfxsum)

        self.dim = self.tensor.size(-1)

        # Window widths: the 25th/50th/75th percentile document lengths plus
        # the maximum length, de-duplicated and sorted ascending.
        self.strides = [torch_percentile(self.doclens, p) for p in [25, 50, 75]]
        self.strides.append(self.doclens.max().item())
        self.strides = sorted(list(set(self.strides)))

        self.views = self._create_views(self.tensor)
        # NOTE(review): the buffer capacity is the module-level `bsize`, not
        # a value owned by this class -- confirm it is large enough for the
        # number of candidates gathered per group in rank().
        self.buffers = self._create_buffers(bsize, self.tensor.dtype, {'cpu', 'cuda:0'})

    def _create_views(self, tensor):
        """Build one overlapping-window view of ``tensor`` per stride.

        Each view has shape ``(outdim, stride, dim)``; with element strides
        ``(dim, dim, 1)`` consecutive windows start ``dim`` elements apart.
        """
        views = []

        for stride in self.strides:
            outdim = tensor.size(0) - stride + 1
            view = torch.as_strided(tensor, (outdim, stride, self.dim), (self.dim, self.dim, 1))
            views.append(view)

        return views

    def _create_buffers(self, max_bsize, dtype, devices):
        """Allocate, per device, one zeroed ``(max_bsize, stride, dim)`` buffer per stride.

        CPU buffers are allocated in pinned memory.
        """
        buffers = {}

        for device in devices:
            buffers[device] = [torch.zeros(max_bsize, stride, self.dim, dtype=dtype,
                                           device=device, pin_memory=(device == 'cpu'))
                               for stride in self.strides]

        return buffers

    def rank(self, Q, pids, views=None, shift=0):
        """Score ``Q`` against every document in ``pids``.

        Args:
            Q: query tensor whose first dimension is 1 or ``len(pids)``.
                # assumes Q is laid out so that `D @ Q` is valid, i.e.
                # (batch, dim, query_len) -- TODO confirm against caller
            pids: non-empty list or tensor of document indices local to
                this ranker.
            views: optional replacement for ``self.views``.
            shift: value subtracted from every document offset before the
                views are indexed.

        Returns:
            A list of scores aligned with ``pids``.
        """
        assert len(pids) > 0
        assert Q.size(0) in [1, len(pids)]

        Q = Q.contiguous().to(DEVICE).to(dtype=self.maxsim_dtype)

        views = self.views if views is None else views
        VIEWS_DEVICE = views[0].device

        # Buffers are keyed by the device string of the views.
        D_buffers = self.buffers[str(VIEWS_DEVICE)]

        raw_pids = pids if type(pids) is list else pids.tolist()
        pids = torch.tensor(pids) if type(pids) is list else pids

        doclens, offsets = self.doclens[pids], self.doclens_pfxsum[pids]

        # For each document, count the strides strictly shorter than it; that
        # count is the index of the first stride able to hold the document.
        assignments = (doclens.unsqueeze(1) > torch.tensor(self.strides).unsqueeze(0) + 1e-6).sum(-1)

        one_to_n = torch.arange(len(raw_pids))
        output_pids, output_scores, output_permutation = [], [], []

        for group_idx, stride in enumerate(self.strides):
            locator = (assignments == group_idx)

            # Skip strides that no requested document was assigned to.
            if locator.sum() < 1e-5:
                continue

            group_pids, group_doclens, group_offsets = pids[locator], doclens[locator], offsets[locator]
            group_Q = Q if Q.size(0) == 1 else Q[locator]

            group_offsets = group_offsets.to(VIEWS_DEVICE) - shift
            # Gather each distinct (consecutive) offset once, then expand back.
            group_offsets_uniq, group_offsets_expand = torch.unique_consecutive(group_offsets, return_inverse=True)

            D_size = group_offsets_uniq.size(0)
            D = torch.index_select(views[group_idx], 0, group_offsets_uniq, out=D_buffers[group_idx][:D_size])
            D = D.to(DEVICE)
            D = D[group_offsets_expand.to(DEVICE)].to(dtype=self.maxsim_dtype)

            # Keep window position i only when i + 1 <= document length; the
            # rest of the window belongs to the following documents.
            mask = torch.arange(stride, device=DEVICE) + 1
            mask = mask.unsqueeze(0) <= group_doclens.to(DEVICE).unsqueeze(-1)

            scores = (D @ group_Q) * mask.unsqueeze(-1)
            # Max over window positions (dim 1), then sum over the last dim.
            scores = scores.max(1).values.sum(-1).cpu()

            output_pids.append(group_pids)
            output_scores.append(scores)
            output_permutation.append(one_to_n[locator])

        # Invert the grouping so that results line up with the input order.
        output_permutation = torch.cat(output_permutation).sort().indices
        output_pids = torch.cat(output_pids)[output_permutation].tolist()
        output_scores = torch.cat(output_scores)[output_permutation].tolist()

        assert len(raw_pids) == len(output_pids)
        assert len(raw_pids) == len(output_scores)
        assert raw_pids == output_pids

        return output_scores

# =============
# ✨ IndexPart
# =============
class IndexPart():
    """A contiguous range of index parts loaded into memory, plus a ranker over it."""

    def __init__(self, directory, dim=128, part_range=None, verbose=True):
        """Load the selected parts of the index stored in ``directory``.

        Args:
            directory: index directory with ``<n>.pt`` and ``doclens.<n>.json`` files.
            dim: embedding size of the stored vectors.
            part_range: optional ``range`` of parts to load; all parts when None.
            verbose: forwarded to ``load_index_part``.
        """
        if part_range is None:
            first_part, last_part = 0, None
        else:
            first_part, last_part = part_range.start, part_range.stop
        selected = slice(first_part, last_part)

        # Load parts metadata
        all_parts, all_parts_paths, _ = get_parts(directory)
        self.parts = all_parts[selected]
        self.parts_paths = all_parts_paths[selected]

        # Load doclens metadata
        all_doclens = load_doclens(directory, flatten=False)

        # Global pid range covered by the selected parts.
        self.doc_offset = sum(len(part_doclens) for part_doclens in all_doclens[:first_part])
        self.doc_endpos = sum(len(part_doclens) for part_doclens in all_doclens[:last_part])
        self.pids_range = range(self.doc_offset, self.doc_endpos)

        self.parts_doclens = all_doclens[selected]
        self.doclens = flatten(self.parts_doclens)
        self.num_embeddings = sum(self.doclens)

        self.tensor = self._load_parts(dim, verbose)
        self.ranker = IndexRanker(self.tensor, self.doclens)

    def _load_parts(self, dim, verbose):
        """Copy every selected part into one float16 tensor.

        The tensor has 512 extra zero rows after the last embedding.
        """
        merged = torch.zeros(self.num_embeddings + 512, dim, dtype=torch.float16)

        start = 0
        for filename, part_doclens in zip(self.parts_paths, self.parts_doclens):
            print("> Loading", filename, "\n")

            stop = start + sum(part_doclens)
            merged[start:stop] = load_index_part(filename, verbose=verbose)
            start = stop

        return merged

    def pid_in_range(self, pid):
        """Tell whether global ``pid`` belongs to the loaded parts."""
        return pid in self.pids_range

    def rank(self, Q, pids):
        """Score ``Q`` against global ``pids``, all of which must be in range."""
        assert Q.size(0) in [1, len(pids)], (Q.size(0), len(pids))
        assert all(pid in self.pids_range for pid in pids), self.pids_range

        # The ranker works with pids local to this part.
        local_pids = [pid - self.doc_offset for pid in pids]

        return self.ranker.rank(Q, local_pids)

# ======================
# ✨ FaissIndex ranking
# ======================
class FaissIndex():
    """FAISS-backed candidate generator mapping query embeddings to passage ids."""

    def __init__(self, index_path, faiss_index_path, nprobe=8, part_range=None):
        """Load the FAISS index and build the embedding-id -> pid table.

        Args:
            index_path: directory holding the ``doclens.<n>.json`` files.
            faiss_index_path: path of the serialized FAISS index. If the
                second-to-last dot-separated field of its basename looks
                like ``A-B`` it is read as the part range the index covers.
            nprobe: value assigned to the index's ``nprobe``.
            part_range: optional range of parts the caller intends to use;
                when given together with a ranged filename, its first and
                last element must fall inside the filename's range.
        """
        print("> Loading the FAISS index from", faiss_index_path, "..")

        faiss_part_range = os.path.basename(faiss_index_path).split('.')[-2].split('-')

        if len(faiss_part_range) == 2:
            faiss_part_range = range(*map(int, faiss_part_range))
            # Only validate when the caller actually restricted the parts;
            # indexing into a None part_range would raise TypeError.
            if part_range is not None:
                assert part_range[0] in faiss_part_range, (part_range, faiss_part_range)
                assert part_range[-1] in faiss_part_range, (part_range, faiss_part_range)
        else:
            faiss_part_range = None

        self.part_range = part_range
        self.faiss_part_range = faiss_part_range

        self.faiss_index = faiss.read_index(faiss_index_path)
        self.faiss_index.nprobe = nprobe

        print("> Building the emb2pid mapping")
        all_doclens = flatten(load_doclens(index_path, flatten=False))

        total_num_embeddings = sum(all_doclens)
        self.emb2pid = torch.zeros(total_num_embeddings, dtype=torch.int)

        # Document `pid` owns the next `dlength` consecutive embedding slots.
        offset_doclens = 0
        for pid, dlength in enumerate(all_doclens):
            self.emb2pid[offset_doclens: offset_doclens + dlength] = pid
            offset_doclens += dlength

        print(f"len(self.emb2pid): {len(self.emb2pid)}" )

        # Worker pool used to de-duplicate pids for very large query batches.
        # NOTE(review): the pool is never closed by this class.
        self.parallel_pool = Pool(16)

    def retrieve(self, faiss_depth, Q, verbose=False):
        """Return, per query in ``Q``, the list of unique candidate pids."""
        embedding_ids = self.queries_to_embedding_ids(faiss_depth, Q, verbose=verbose)
        pids = self.embedding_ids_to_pids(embedding_ids, verbose=verbose)

        return pids

    def queries_to_embedding_ids(self, faiss_depth, Q, verbose=True):
        """Search the index with every query embedding.

        Args:
            faiss_depth: number of neighbours requested per query embedding.
            Q: tensor of shape ``(num_queries, embeddings_per_query, dim)``.

        Returns:
            Tensor of shape ``(num_queries, embeddings_per_query * faiss_depth)``
            holding (possibly repeated) embedding ids.
        """
        # Flatten into a matrix for the faiss search.
        num_queries, embeddings_per_query, dim = Q.size()
        Q_faiss = Q.view(num_queries * embeddings_per_query, dim).cpu().contiguous()

        embeddings_ids = []
        # Search in chunks of 5000 queries' worth of embeddings.
        faiss_bsize = embeddings_per_query * 5000
        for offset in range(0, Q_faiss.size(0), faiss_bsize):
            endpos = min(offset + faiss_bsize, Q_faiss.size(0))

            some_Q_faiss = Q_faiss[offset:endpos].float().numpy()
            _, some_embedding_ids = self.faiss_index.search(some_Q_faiss, faiss_depth)
            embeddings_ids.append(torch.from_numpy(some_embedding_ids))

        embedding_ids = torch.cat(embeddings_ids)

        # Reshape to (number of queries, non-unique embedding IDs per query)
        embedding_ids = embedding_ids.view(num_queries, embeddings_per_query * embedding_ids.size(1))

        return embedding_ids

    def embedding_ids_to_pids(self, embedding_ids, verbose=True):
        """Map embedding ids to pids and de-duplicate them per query."""
        # Find unique PIDs per query.
        all_pids = self.emb2pid[embedding_ids]

        all_pids = all_pids.tolist()

        # Only pay the multiprocessing overhead for large batches.
        if len(all_pids) > 5000:
            all_pids = list(self.parallel_pool.map(uniq, all_pids))
        else:
            all_pids = list(map(uniq, all_pids))

        return all_pids

# ==========
# ✨ Ranker
# ==========
class Ranker():
    """End-to-end ranker: FAISS candidate retrieval followed by exact scoring."""

    def __init__(self, index_path, faiss_index_path, inference, faiss_depth=1024):
        """Load the embedding index and, unless ``faiss_depth`` is None, the FAISS index.

        ``self.retrieve`` only exists when ``faiss_depth`` is not None.
        """
        self.inference = inference
        self.faiss_depth = faiss_depth

        if faiss_depth is not None:
            self.faiss_index = FaissIndex(index_path, faiss_index_path, nprobe=8, part_range=None)
            # Bind the depth so that callers only pass the query tensor.
            self.retrieve = partial(self.faiss_index.retrieve, self.faiss_depth)

        self.index = IndexPart(index_path, dim=inference.colbert.dim, part_range=None, verbose=True)

    def encode(self, queries):
        """Encode a list/tuple of query strings; batches of 512 above 512 queries."""
        assert type(queries) in [list, tuple], type(queries)

        chunk_size = 512 if len(queries) > 512 else None

        return self.inference.queryFromText(queries, bsize=chunk_size)

    def rank(self, Q, pids=None):
        """Rank candidate passages for a single encoded query.

        Args:
            Q: encoded query whose first dimension is 1.
            pids: candidate pids (list/tuple of ints); retrieved through
                FAISS when None.

        Returns:
            ``(pids, scores)`` sorted by descending score; with no
            candidates the pids are returned unchanged with an empty list.
        """
        if pids is None:
            pids = self.retrieve(Q, verbose=False)[0]

        assert type(pids) in [list, tuple], type(pids)
        assert Q.size(0) == 1, (len(pids), Q.size())
        assert all(type(pid) is int for pid in pids)

        if len(pids) == 0:
            return pids, []

        # The index expects the last two query dimensions swapped.
        raw_scores = self.index.rank(Q.permute(0, 2, 1), pids)

        ordering = torch.tensor(raw_scores).sort(descending=True)
        ranked_pids = torch.tensor(pids)[ordering.indices].tolist()

        return ranked_pids, ordering.values.tolist()

# ==================
# ✨ Faiss retrieve
# ==================
def retrieve(index_path, faiss_index_path, colbert, metrics, queries, qrels, topK_pids, collection, depth):
    """Run FAISS retrieval + ranking for every query and report metrics.

    Args:
        index_path: location handed to ``Ranker`` for the passage index.
        faiss_index_path: location handed to ``Ranker`` for the FAISS index.
        colbert: model wrapped in ``ModelInference`` (with the module-level
            ``path_model``).
        metrics: object exposing ``get_result`` and ``print_metrics``.
        queries: mapping ``qid -> query text``.
        qrels: mapping ``qid -> relevant pids``; when falsy no metrics are
            recorded or printed.
        topK_pids, collection: forwarded unchanged to ``qid2passages``.
        depth: used both as the ranker's ``faiss_depth`` and as the cut-off
            on each ranking.

    Returns:
        None. Results are accumulated in ``metrics`` and printed.
    """
    ranker = Ranker(
        index_path,
        faiss_index_path,
        ModelInference(colbert, path_model),
        faiss_depth=depth,
    )

    all_qids = list(queries.keys())

    # The batch size passed is len(queries), i.e. the full set of query ids.
    for qid_chunk in batch(all_qids, len(queries), provide_offset=False):
        chunk_texts = [queries[qid] for qid in qid_chunk]

        # Phase 1: encode and rank one query at a time, synchronising CUDA
        # before and after each query.
        rankings = []
        for text in chunk_texts:
            torch.cuda.synchronize('cuda:0')
            encoded = ranker.encode([text])
            ranked_pids, ranked_scores = ranker.rank(encoded)
            torch.cuda.synchronize()
            rankings.append(zip(ranked_pids, ranked_scores))

        # Phase 2: evaluate the rankings collected above.
        print('Evaluate faiss retrieval')
        for query_idx, (qid, ranking) in enumerate(zip(qid_chunk, rankings)):
            # NOTE(review): the passage lookup is keyed by qid only, not by
            # pid -- confirm this is intended.
            rank_qid = []
            for pid, score in itertools.islice(ranking, depth):
                rank_qid.append((score, pid, qid2passages(qid, topK_pids, collection, depth)))

            if not qrels:
                continue

            metrics.get_result(query_idx, qid, rank_qid, qrels[qid])

            # Progress report on every 25th query only.
            if query_idx % 25 != 0:
                continue

            print(f'\n[{query_idx}]. Query: {queries[qid]}')

            # First ranked entry whose pid is relevant for this query, if any.
            first_hit = next(
                (
                    (position, hit_score, hit_passage)
                    for position, (hit_score, hit_pid, hit_passage) in enumerate(rank_qid, 1)
                    if hit_pid in qrels[qid]
                ),
                None,
            )
            if first_hit is not None:
                i, score, passage = first_hit
                print(f"Found at position: {i} with score {score:.3}")
                print(passage)

            metrics.print_metrics(query_idx)

In [None]:
# ===========================
# ✨ Run retrieval train set
# ===========================
# Load the pickled TRAIN artefacts stored under <path_base>/data.
topK_pids_train = pickle_file(f'{path_base}data/topK_pids_TRAIN')
topK_docs_train = pickle_file(f'{path_base}data/topK_docs_TRAIN')
queries_train = pickle_file(f'{path_base}data/queries_TRAIN')
qrels_train = pickle_file(f'{path_base}data/qrels_TRAIN')
collection_train = pickle_file(f'{path_base}data/collection_TRAIN')

print('TRAIN OBJECTS')
# Queries and both top-K structures must hold the same number of entries.
assert len(queries_train) == len(topK_docs_train) == len(topK_pids_train)
print(
    f'\tlen(queries_train):    {len(queries_train)}',
    f'\tlen(topK_docs_train):  {len(topK_docs_train)}',
    f'\tlen(topK_pids_train):  {len(topK_pids_train)}',
    f'\tlen(collection_train): {len(collection_train)}',
    sep='\n',
)

# Fresh metrics object for this run, with cut-offs 1/3/5/10/20 for each
# of the three depth arguments.
metrics = Metrics(
    mrr_depths={1, 3, 5, 10, 20},
    recall_depths={1, 3, 5, 10, 20},
    success_depths={1, 3, 5, 10, 20},
)

# Retrieve and evaluate against the TRAIN index, 50 candidates deep.
retrieve(
    index_path=f'{path_base}index/train_index',
    faiss_index_path=f'{path_base}index/train_index/ivfpq.50.faiss',
    colbert=colbert,
    metrics=metrics,
    queries=queries_train,
    qrels=qrels_train,
    topK_pids=topK_pids_train,
    collection=collection_train,
    depth=50,
)

In [None]:
# ===========================
# ✨ Run retrieval valid set
# ===========================
# Load the pickled VALID artefacts stored under <path_base>/data.
topK_pids_valid = pickle_file(f'{path_base}data/topK_pids_VALID')
topK_docs_valid = pickle_file(f'{path_base}data/topK_docs_VALID')
queries_valid = pickle_file(f'{path_base}data/queries_VALID')
qrels_valid = pickle_file(f'{path_base}data/qrels_VALID')
collection_valid = pickle_file(f'{path_base}data/collection_VALID')

print('\nVALID OBJECTS')
# Queries and both top-K structures must hold the same number of entries.
assert len(queries_valid) == len(topK_docs_valid) == len(topK_pids_valid)
print(
    f'\tlen(queries_valid):    {len(queries_valid)}',
    f'\tlen(topK_docs_valid):  {len(topK_docs_valid)}',
    f'\tlen(topK_pids_valid):  {len(topK_pids_valid)}',
    f'\tlen(collection_valid): {len(collection_valid)}',
    sep='\n',
)

# Fresh metrics object for this run, with cut-offs 1/3/5/10/20 for each
# of the three depth arguments.
metrics = Metrics(
    mrr_depths={1, 3, 5, 10, 20},
    recall_depths={1, 3, 5, 10, 20},
    success_depths={1, 3, 5, 10, 20},
)

# Retrieve and evaluate against the VALID index, 50 candidates deep.
retrieve(
    index_path=f'{path_base}index/valid_index',
    faiss_index_path=f'{path_base}index/valid_index/ivfpq.50.faiss',
    colbert=colbert,
    metrics=metrics,
    queries=queries_valid,
    qrels=qrels_valid,
    topK_pids=topK_pids_valid,
    collection=collection_valid,
    depth=50,
)