<a href="https://colab.research.google.com/github/finardi/Ranking/blob/main/5_Cobert_Faiss.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
!pip install -q ujson
!apt install -q libomp-dev
!pip install -q transformers
!python -m pip -q install --upgrade faiss faiss-gpu

In [None]:
# Standard library
import gc
import itertools
import math
import os
import pickle
import random
import string
import sys
import time
from functools import partial

# Third-party
import ujson
import torch
import faiss
import numpy as np
import pandas as pd
from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
 
# Better pandas visualization: show wide/long frames and long text cells in full.
pd.options.display.max_columns = 100
pd.options.display.expand_frame_repr = 100
pd.options.display.max_colwidth = 700
pd.options.display.max_rows = 5000
  
# save/load pickles
def pickle_file(path, data=None):
    """Load or save a pickle at ``path``.

    With ``data=None`` the file is read and the unpickled object returned;
    otherwise ``data`` is written with the highest pickle protocol and ``None``
    is returned.

    NOTE: ``pickle.load`` can execute arbitrary code -- only load trusted files.
    """
    if data is not None:
        with open(path, 'wb') as out_file:
            pickle.dump(data, out_file, protocol=pickle.HIGHEST_PROTOCOL)
        return None

    with open(path, 'rb') as in_file:
        return pickle.load(in_file)
 
# path base: root folder (a Google Drive path mounted in Colab) that holds the
# `data/` and `index/` sub-folders used below. It must end with '/' because the
# rest of the notebook builds paths by plain string concatenation (path_base + '...').
path_base = '/content/drive/MyDrive/ColBERT/ColBERT - FAQ Receita Federal/'

In [None]:
# =============
# ✨ Constants
# =============
# bsize is N, the number of triples per batch; the *_maxlen values cap the
# token length of queries and documents respectively.
bsize, query_maxlen, doc_maxlen = 16, 48, 128
path_model = 'bert-base-multilingual-uncased'  # HF checkpoint for tokenizer + encoder

# ==================
# ✨ QueryTokenizer
# ==================
class QueryTokenizer():
    """Tokenizer wrapper for ColBERT queries.

    Every query is brought to exactly ``query_maxlen`` tokens, with the padding
    filled by [MASK] (ColBERT's query augmentation):

    * ``tensorize`` pads/truncates with the HF tokenizer, then swaps the padding
      ids for the [MASK] id;
    * ``tokenize`` / ``encode`` with ``add_special_tokens=True`` build
      ``[CLS] tokens [SEP] [MASK] ...`` by hand, truncating ``tokens`` if needed.
    """
    def __init__(self, query_maxlen, path_tokenizer):
        self.tok = BertTokenizerFast.from_pretrained(path_tokenizer)
        self.query_maxlen = query_maxlen

        self.cls_token, self.cls_token_id = self.tok.cls_token, self.tok.cls_token_id
        self.sep_token, self.sep_token_id = self.tok.sep_token, self.tok.sep_token_id
        self.mask_token, self.mask_token_id = self.tok.mask_token, self.tok.mask_token_id

    def _wrap_and_pad(self, seq, cls, sep, mask):
        """Return ``[cls] + seq + [sep]`` padded with ``mask`` to exactly ``query_maxlen`` items.

        ``seq`` is truncated so the two special tokens still fit. This mirrors
        ``truncation=True`` in ``tensorize``, so both code paths agree on length.
        """
        body = list(seq)[:max(self.query_maxlen - 2, 0)]
        wrapped = [cls] + body + [sep]
        # Only two special tokens are added, so pad by query_maxlen - len(wrapped).
        return wrapped + [mask] * (self.query_maxlen - len(wrapped))

    def tokenize(self, batch_text, add_special_tokens=False):
        """Split each query into word-piece strings.

        Args:
            batch_text: list or tuple of query strings.
            add_special_tokens: if True, each result has exactly ``query_maxlen``
                entries: ``[CLS] ... [SEP]`` followed by [MASK] padding.
        """
        assert type(batch_text) in [list, tuple], (type(batch_text))

        tokens = [self.tok.tokenize(x, add_special_tokens=False) for x in batch_text]

        if not add_special_tokens:
            return tokens

        return [self._wrap_and_pad(lst, self.cls_token, self.sep_token, self.mask_token)
                for lst in tokens]

    def encode(self, batch_text, add_special_tokens=False):
        """Like ``tokenize`` but returns token ids instead of token strings."""
        assert type(batch_text) in [list, tuple], (type(batch_text))

        ids = self.tok(batch_text, add_special_tokens=False)['input_ids']

        if not add_special_tokens:
            return ids

        return [self._wrap_and_pad(lst, self.cls_token_id, self.sep_token_id, self.mask_token_id)
                for lst in ids]

    def tensorize(self, batch_text, bsize=None):
        """Tokenize queries to ``(ids, mask)`` tensors of shape ``(len(batch_text), query_maxlen)``.

        Padding positions (attention mask 0) receive the [MASK] id; the attention
        mask itself is returned unchanged. With ``bsize``, returns a list of
        ``(ids, mask)`` chunks of at most ``bsize`` rows instead.
        """
        assert type(batch_text) in [list, tuple], (type(batch_text))

        obj = self.tok(batch_text, padding='max_length', truncation=True,
                       return_tensors='pt', max_length=self.query_maxlen)

        ids, mask = obj['input_ids'], obj['attention_mask']

        # Query augmentation: replace padding with [MASK]. Selecting padding via
        # the attention mask avoids assuming the pad id is 0.
        ids[mask == 0] = self.mask_token_id

        if bsize:
            batches = _split_into_batches(ids, mask, bsize)
            return batches

        return ids, mask

# ================
# ✨ DocTokenizer
# ================
class DocTokenizer():
    """Tokenizer wrapper for ColBERT documents (passages).

    Wraps the ``BertTokenizerFast`` found at ``path_tokenizer``; ``tensorize``
    pads to the longest document in the batch and truncates to ``doc_maxlen``.
    """
    def __init__(self, doc_maxlen, path_tokenizer):
        self.tok = BertTokenizerFast.from_pretrained(path_tokenizer)
        self.doc_maxlen = doc_maxlen

        hf_tok = self.tok
        self.cls_token = hf_tok.cls_token
        self.cls_token_id = hf_tok.cls_token_id
        self.sep_token = hf_tok.sep_token
        self.sep_token_id = hf_tok.sep_token_id

    def tokenize(self, batch_text, add_special_tokens=False):
        """Word-piece strings per document, optionally wrapped as ``[CLS] ... [SEP]``."""
        assert type(batch_text) in (list, tuple), (type(batch_text))

        pieces = [self.tok.tokenize(text, add_special_tokens=False) for text in batch_text]
        if add_special_tokens:
            pieces = [[self.cls_token, *piece, self.sep_token] for piece in pieces]
        return pieces

    def encode(self, batch_text, add_special_tokens=False):
        """Token ids per document, optionally wrapped with the [CLS]/[SEP] ids."""
        assert type(batch_text) in (list, tuple), (type(batch_text))

        id_lists = self.tok(batch_text, add_special_tokens=False)['input_ids']
        if add_special_tokens:
            id_lists = [[self.cls_token_id, *row, self.sep_token_id] for row in id_lists]
        return id_lists

    def tensorize(self, batch_text, bsize=None):
        """Tokenize to ``(ids, mask)`` tensors padded to the longest document.

        With ``bsize``, rows are sorted by length, split into chunks of ``bsize``
        and returned together with the indices that restore the input order.
        """
        assert type(batch_text) in (list, tuple), (type(batch_text))

        encoded = self.tok(batch_text, padding='longest', truncation='longest_first',
                           return_tensors='pt', max_length=self.doc_maxlen)
        ids = encoded['input_ids']
        mask = encoded['attention_mask']

        if not bsize:
            return ids, mask

        ids, mask, reverse_indices = _sort_by_length(ids, mask, bsize)
        return _split_into_batches(ids, mask, bsize), reverse_indices

# =====================
# ✨ tensorize triples
# =====================
def tensorize_triples(query_tokenizer, doc_tokenizer, queries, positives, negatives, bsize):
    """Tokenize (query, positive, negative) triples into ColBERT training batches.

    Triples are reordered by the longer of their two documents (ascending) and
    split into chunks of ``bsize`` triples; ``bsize=None`` means one chunk holding
    all of them. Each chunk becomes ``(Q, D)`` where ``Q = (ids, mask)`` contains
    every query twice and ``D = (ids, mask)`` contains the positives followed by
    the negatives, so row ``i`` of Q pairs with row ``i`` of D.

    Args:
        query_tokenizer / doc_tokenizer: objects exposing ``tensorize(texts)``.
        queries, positives, negatives: equally long lists of strings.
        bsize: triples per chunk; must divide ``len(queries)`` (or be None).

    Returns:
        list of ``(Q, D)`` tuples; empty when there are no triples.
    """
    assert len(queries) == len(positives) == len(negatives)

    N = len(queries)
    if N == 0:
        return []
    if bsize is None:
        bsize = N
    # bsize may be smaller than N (LazyBatcher passes bsize // accumsteps for
    # gradient accumulation) but must split the triples evenly.
    assert N % bsize == 0, (N, bsize)

    Q_ids, Q_mask = query_tokenizer.tensorize(queries)
    D_ids, D_mask = doc_tokenizer.tensorize(positives + negatives)
    D_ids, D_mask = D_ids.view(2, N, -1), D_mask.view(2, N, -1)

    # Compute max among {length of i^th positive, length of i^th negative} for i \in N
    maxlens = D_mask.sum(-1).max(0).values

    # Sort by maxlens
    indices = maxlens.sort().indices
    Q_ids, Q_mask = Q_ids[indices], Q_mask[indices]
    D_ids, D_mask = D_ids[:, indices], D_mask[:, indices]

    (positive_ids, negative_ids), (positive_mask, negative_mask) = D_ids, D_mask

    query_batches = _split_into_batches(Q_ids, Q_mask, bsize)
    positive_batches = _split_into_batches(positive_ids, positive_mask, bsize)
    negative_batches = _split_into_batches(negative_ids, negative_mask, bsize)

    batches = []
    for (q_ids, q_mask), (p_ids, p_mask), (n_ids, n_mask) in zip(query_batches, positive_batches, negative_batches):
        Q = (torch.cat((q_ids, q_ids)), torch.cat((q_mask, q_mask)))
        D = (torch.cat((p_ids, n_ids)), torch.cat((p_mask, n_mask)))
        batches.append((Q, D))

    return batches

# =============
# ✨ Aux funcs
# =============
def _sort_by_length(ids, mask, bsize):
    if ids.size(0) <= bsize:
        return ids, mask, torch.arange(ids.size(0))

    indices = mask.sum(-1).sort().indices
    reverse_indices = indices.sort().indices

    return ids[indices], mask[indices], reverse_indices

def _split_into_batches(ids, mask, bsize):
    batches = []
    for offset in range(0, ids.size(0), bsize):
        batches.append((ids[offset:offset+bsize], mask[offset:offset+bsize]))

    return batches

# ===============
# ✨ LazyBatcher
# ===============
class LazyBatcher():
    """Iterates over (query, positive, negative) triples and yields tokenized ColBERT batches.

    Triples are read as ids (``qid``, ``pos_pid``, ``neg_pid``) from a parquet file
    and resolved to text through the ``qid_to_query_*`` / ``pid_to_doc_*`` pickles,
    all located under ``<path>data/`` (``path`` must end with a separator).

    Each ``next()`` consumes ``bsize`` triples; a trailing incomplete batch is
    dropped. The cursor ``self.position`` is not rewound by ``__iter__``, so a
    second pass requires setting it back to 0. ``len()`` is the number of
    triples, not of batches.
    """
    # mode -> upper-case suffix used in the on-disk file names.
    _MODE_SUFFIX = {'train': 'TRAIN', 'valid': 'VALID'}

    def __init__(self, bsize, path, path_tokenizer, query_maxlen, doc_maxlen, mode='train', accumsteps=1):
        self.bsize, self.accumsteps = bsize, accumsteps
        self.mode = mode
        self._suffix()  # fail fast on an unknown mode, before loading tokenizers

        self.query_tokenizer = QueryTokenizer(query_maxlen=query_maxlen, path_tokenizer=path_tokenizer)
        self.doc_tokenizer = DocTokenizer(doc_maxlen=doc_maxlen, path_tokenizer=path_tokenizer)
        self.tensorize_triples = partial(tensorize_triples, self.query_tokenizer, self.doc_tokenizer)
        self.position = 0

        # All data files are resolved relative to the `path` argument.
        self.triples = self._load_triples(path)
        self.queries = self._load_queries(path)
        self.collection = self._load_collection(path)

    def _suffix(self):
        """File-name suffix for ``self.mode``; raises ValueError for an unknown mode."""
        try:
            return self._MODE_SUFFIX[self.mode]
        except KeyError:
            raise ValueError(
                f"Unknown mode {self.mode!r}; expected one of {sorted(self._MODE_SUFFIX)}"
            ) from None

    def _load_triples(self, path):
        """Return the triples as a list of ``(qid, pos_pid, neg_pid)`` tuples."""
        df_triplet = pd.read_parquet(f"{path}data/df_FAQ_triplet_IDS_{self._suffix()}.parquet.gzip")
        return list(zip(
            df_triplet.qid.values,
            df_triplet.pos_pid.values,
            df_triplet.neg_pid.values,
        ))

    def _load_queries(self, path):
        """Return the pickled qid -> query text mapping for this mode."""
        return pickle_file(f"{path}data/qid_to_query_{self._suffix()}")

    def _load_collection(self, path):
        """Return the pickled pid -> document text mapping for this mode."""
        return pickle_file(f"{path}data/pid_to_doc_{self._suffix()}")

    def __iter__(self):
        return self

    def __len__(self):
        return len(self.triples)

    def __next__(self):
        # [offset, endpos) is the window of triples forming this batch.
        offset, endpos = self.position, min(self.position + self.bsize, len(self.triples))
        self.position = endpos

        # Stop instead of emitting a batch smaller than bsize.
        if offset + self.bsize > len(self.triples):
            raise StopIteration

        queries, positives, negatives = [], [], []
        for qid, pos_pid, neg_pid in self.triples[offset:endpos]:
            queries.append(self.queries[qid])
            positives.append(self.collection[pos_pid])
            negatives.append(self.collection[neg_pid])

        return self.collate(queries, positives, negatives)

    def collate(self, queries, positives, negatives):
        """Tokenize one batch, split into ``accumsteps`` sub-batches of ``bsize // accumsteps`` triples."""
        assert len(queries) == len(positives) == len(negatives) == self.bsize

        return self.tensorize_triples(queries, positives, negatives, self.bsize // self.accumsteps)

# ===========
# ✨ ColBERT
# ===========
class ColBERT(BertPreTrainedModel):
    """ColBERT late-interaction model: BERT encoder + bias-free linear projection to ``dim``.

    Args:
        config: BERT config used to build the encoder.
        query_maxlen, doc_maxlen: stored on the model; ModelInference reads them
            to build its tokenizers.
        mask_punctuation: if True, punctuation (``string.punctuation``) is zeroed
            out of document embeddings via ``mask``.
        dim: size of each per-token output embedding.
        similarity_metric: 'cosine' (dot product of L2-normalized embeddings) or 'l2'.

    Inputs are moved to the model's own device (``self.device``), so the model
    works wherever it was placed with ``.to(...)``, with no global device variable.
    """
    def __init__(self, config, query_maxlen, doc_maxlen, mask_punctuation, dim=128, similarity_metric='cosine'):

        super().__init__(config)

        self.query_maxlen = query_maxlen
        self.doc_maxlen = doc_maxlen
        self.similarity_metric = similarity_metric
        self.dim = dim

        self.mask_punctuation = mask_punctuation
        self.skiplist = {}

        if self.mask_punctuation:
            # NOTE(review): uses the notebook-global `path_model`, not `config`.
            self.tokenizer = BertTokenizerFast.from_pretrained(path_model)
            # Both the raw symbol and its first word-piece id go in the skiplist;
            # mask() compares token ids against it.
            self.skiplist = {w: True
                             for symbol in string.punctuation
                             for w in [symbol, self.tokenizer.encode(symbol, add_special_tokens=False)[0]]}

        self.bert = BertModel(config)
        self.linear = torch.nn.Linear(config.hidden_size, dim, bias=False)

        self.init_weights()

    def forward(self, Q, D):
        """Score queries against documents; ``Q`` and ``D`` are ``(input_ids, attention_mask)`` pairs."""
        return self.score(self.query(*Q), self.doc(*D))

    def query(self, input_ids, attention_mask):
        """Encode queries into L2-normalized per-token embeddings."""
        input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device)
        Q = self.bert(input_ids, attention_mask=attention_mask)[0]
        Q = self.linear(Q)

        return torch.nn.functional.normalize(Q, p=2, dim=2)

    def doc(self, input_ids, attention_mask, keep_dims=True):
        """Encode documents into L2-normalized per-token embeddings.

        Tokens rejected by ``mask`` (padding id 0 and skiplist entries) are zeroed.
        With ``keep_dims=False`` returns a list of float16 CPU tensors, one per
        document, containing only the kept tokens.
        """
        input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device)
        D = self.bert(input_ids, attention_mask=attention_mask)[0]
        D = self.linear(D)

        mask = torch.tensor(self.mask(input_ids), device=self.device).unsqueeze(2).float()
        D = D * mask

        D = torch.nn.functional.normalize(D, p=2, dim=2)

        if not keep_dims:
            D, mask = D.cpu().to(dtype=torch.float16), mask.cpu().bool().squeeze(-1)
            D = [d[mask[idx]] for idx, d in enumerate(D)]

        return D

    def score(self, Q, D):
        """Late-interaction score: per query token, max similarity over doc tokens, summed."""
        if self.similarity_metric == 'cosine':
            return (Q @ D.permute(0, 2, 1)).max(2).values.sum(1)

        assert self.similarity_metric == 'l2'
        return (-1.0 * ((Q.unsqueeze(2) - D.unsqueeze(1))**2).sum(-1)).max(-1).values.sum(-1)

    def mask(self, input_ids):
        """Nested list of booleans: False for padding (id 0) and skiplist tokens, else True."""
        mask = [[(x not in self.skiplist) and (x != 0) for x in d] for d in input_ids.cpu().tolist()]
        return mask

# ==================
# ✨ ModelInference
# ==================
class ModelInference():
    """Inference-only wrapper around a ColBERT model (no gradients).

    ``colbert`` must already be in eval mode. Tokenizers are built from
    ``path_model`` using the model's ``query_maxlen`` / ``doc_maxlen``.
    """
    def __init__(self, colbert, path_model):
        assert colbert.training is False

        self.colbert = colbert
        self.query_tokenizer = QueryTokenizer(colbert.query_maxlen, path_tokenizer=path_model)
        self.doc_tokenizer = DocTokenizer(colbert.doc_maxlen, path_tokenizer=path_model)

    def query(self, *args, to_cpu=False, **kw_args):
        """Run ``colbert.query`` under ``torch.no_grad``; optionally move the result to CPU."""
        with torch.no_grad():
            Q = self.colbert.query(*args, **kw_args)
            return Q.cpu() if to_cpu else Q

    def doc(self, *args, to_cpu=False, **kw_args):
        """Run ``colbert.doc`` under ``torch.no_grad``; optionally move the result to CPU."""
        with torch.no_grad():
            D = self.colbert.doc(*args, **kw_args)
            if not to_cpu:
                return D
            # keep_dims=False makes colbert.doc return a list of per-document
            # tensors, which has no .cpu(); move each element instead.
            if isinstance(D, list):
                return [d.cpu() for d in D]
            return D.cpu()

    def queryFromText(self, queries, bsize=None, to_cpu=False):
        """Encode query strings, in chunks of ``bsize`` if given."""
        if bsize:
            batches = self.query_tokenizer.tensorize(queries, bsize=bsize)
            batches = [self.query(input_ids, attention_mask, to_cpu=to_cpu) for input_ids, attention_mask in batches]
            return torch.cat(batches)

        input_ids, attention_mask = self.query_tokenizer.tensorize(queries)
        return self.query(input_ids, attention_mask, to_cpu=to_cpu)

    def docFromText(self, docs, bsize=None, keep_dims=True, to_cpu=False):
        """Encode document strings.

        With ``bsize``, docs are length-sorted, encoded in chunks and restored to
        input order. ``keep_dims=True`` returns one zero-padded 3-D tensor;
        ``keep_dims=False`` returns a list with one tensor per document.
        """
        if bsize:
            batches, reverse_indices = self.doc_tokenizer.tensorize(docs, bsize=bsize)

            batches = [self.doc(input_ids, attention_mask, keep_dims=keep_dims, to_cpu=to_cpu)
                       for input_ids, attention_mask in batches]

            if keep_dims:
                D = _stack_3D_tensors(batches)
                return D[reverse_indices]

            D = [d for batch in batches for d in batch]
            return [D[idx] for idx in reverse_indices.tolist()]

        input_ids, attention_mask = self.doc_tokenizer.tensorize(docs)
        return self.doc(input_ids, attention_mask, keep_dims=keep_dims, to_cpu=to_cpu)

    def score(self, Q, D, mask=None, lengths=None, explain=False):
        """MaxSim scores of a query against a batch of documents, returned on CPU.

        Computes ``D @ Q``, zeroes masked document positions, takes the max over
        dim 1 and sums the rest. Assumes D is (num_docs, doc_len, dim) and Q is
        (dim, query_len) -- TODO confirm with callers. ``lengths`` (valid tokens
        per document) is a shorthand for building ``mask``.
        """
        if lengths is not None:
            assert mask is None, "don't supply both mask and lengths"

            # Build the mask where the scores live (D's device), not on a global device.
            mask = torch.arange(D.size(1), device=D.device) + 1
            mask = mask.unsqueeze(0) <= lengths.to(D.device).unsqueeze(-1)

        scores = (D @ Q)
        scores = scores if mask is None else scores * mask.unsqueeze(-1)
        scores = scores.max(1)

        if explain:
            assert False, "TODO"

        return scores.values.sum(-1).cpu()

def _stack_3D_tensors(groups):
    bsize = sum([x.size(0) for x in groups])
    maxlen = max([x.size(1) for x in groups])
    hdim = groups[0].size(2)

    output = torch.zeros(bsize, maxlen, hdim, device=groups[0].device, dtype=groups[0].dtype)

    offset = 0
    for x in groups:
        endpos = offset + x.size(0)
        output[offset:endpos, :x.size(1)] = x
        offset = endpos

    return output    

# - - - - -
dataloader_train = LazyBatcher(
    bsize=bsize,
    path=path_base,
    path_tokenizer=path_model,
    query_maxlen=query_maxlen,
    doc_maxlen=doc_maxlen,
    mode='train'
    )

# Sanity check: build every training batch once and print its index.
print('batches:')
for i, batches in enumerate(dataloader_train):
    print(f' {i}.', end='')
# LazyBatcher never rewinds its cursor (`position`) by itself, so the loop above
# left it exhausted; rewind it so the loader can be iterated again.
dataloader_train.position = 0

# Release a model left over from a previous run of this cell.
try:
    del colbert
except NameError:
    pass  # first run: no model to release yet
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Fall back to CPU when no GPU is present instead of failing on .to('cuda').
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

print()

colbert = ColBERT.from_pretrained(
    path_model,
    query_maxlen=query_maxlen,
    doc_maxlen=doc_maxlen,
    dim=128,
    similarity_metric='cosine',
    mask_punctuation=False).to(DEVICE)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=871891.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1715180.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=28.0, style=ProgressStyle(description_w…


batches:
 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35. 36. 37.


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=625.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=672271273.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at bert-base-multilingual-uncased were not used when initializing ColBERT: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']
- This IS expected if you are initializing ColBERT from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ColBERT from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ColBERT were not initialized from the model checkpoint at bert-base-multilingual-uncased and are newly initialized: ['linear.weight'

In [None]:
# ✨ load colbert from checkpoint
# map_location restores the tensors directly onto DEVICE, so a checkpoint saved
# on GPU also loads on a CPU-only runtime.
colbert.load_state_dict(torch.load(path_base+'data/EPOCH_3_FAQ', map_location=DEVICE))
colbert.eval()  # ModelInference asserts colbert.training is False
print('\nmodel loaded!\n')


model loaded!



# Faiss

In [None]:
# =================
# ✨ FaissIndexGPU
# =================
class FaissIndexGPU():
    """GPU helper for FaissIndex: GPU clustering during training and GPU-buffered adds.

    If ``faiss.get_num_gpus()`` reports no GPU, only ``ngpu = 0`` is set and every
    other method asserts; FaissIndex checks ``ngpu`` before calling them.
    """
    def __init__(self):
        self.ngpu = faiss.get_num_gpus()

        if self.ngpu == 0:
            return

        self.tempmem = 1 << 33  # temp memory passed to setTempMemory per GPU (8 GiB if bytes)
        self.max_add_per_gpu = 1 << 25  # vectors a GPU may hold before flushing to CPU
        self.max_add = self.max_add_per_gpu * self.ngpu
        self.add_batch_size = 65536  # vectors per add_with_ids call

        self.gpu_resources = self._prepare_gpu_resources()

    def _prepare_gpu_resources(self):
        """Create one ``StandardGpuResources`` per GPU, applying ``tempmem`` when non-negative."""
        print(f"\nPreparing resources for {self.ngpu} GPUs.")

        gpu_resources = []

        for _ in range(self.ngpu):
            res = faiss.StandardGpuResources()
            if self.tempmem >= 0:
                res.setTempMemory(self.tempmem)
            gpu_resources.append(res)

        return gpu_resources

    def _make_vres_vdev(self):
        """Pack the per-GPU resources and device ids into faiss vectors for multi-GPU cloning."""
        assert self.ngpu > 0

        vres = faiss.GpuResourcesVector()
        vdev = faiss.Int32Vector()

        for i in range(self.ngpu):
            vdev.push_back(i)
            vres.push_back(self.gpu_resources[i])

        return vres, vdev

    def training_initialize(self, index, quantizer):
        """Make the IVF part of ``index`` cluster with a GPU copy of ``quantizer`` during train()."""
        assert self.ngpu > 0

        self.index_ivf = faiss.extract_index_ivf(index)
        self.clustering_index = faiss.index_cpu_to_all_gpus(quantizer)
        self.index_ivf.clustering_index = self.clustering_index

    def training_finalize(self):
        """Move the clustering index back to CPU once training is done."""
        assert self.ngpu > 0

        self.index_ivf.clustering_index = faiss.index_gpu_to_cpu(self.index_ivf.clustering_index)

    def adding_initialize(self, index):
        """Clone the trained CPU ``index`` to all GPUs (sharded, float16) for adding vectors."""
        assert self.ngpu > 0

        self.co = faiss.GpuMultipleClonerOptions()
        self.co.useFloat16 = True
        self.co.useFloat16CoarseQuantizer = False
        self.co.usePrecomputed = False
        self.co.indicesOptions = faiss.INDICES_CPU
        self.co.verbose = True
        self.co.reserveVecs = self.max_add
        self.co.shard = True
        assert self.co.shard_type in (0, 1, 2)

        self.vres, self.vdev = self._make_vres_vdev()
        self.gpu_index = faiss.index_cpu_to_gpu_multiple(self.vres, self.vdev, index, self.co)

    def add(self, index, data, offset):
        """Add ``data`` to the CPU ``index`` through the GPU clone, with ids ``offset .. offset+len(data)-1``.

        Vectors go in chunks of ``add_batch_size``; the GPU copy is flushed into
        ``index`` whenever it holds more than ``max_add`` vectors, and once at the end.
        """
        assert self.ngpu > 0

        nb = data.shape[0]

        for i0 in range(0, nb, self.add_batch_size):
            i1 = min(i0 + self.add_batch_size, nb)
            xs = data[i0:i1]

            self.gpu_index.add_with_ids(xs, np.arange(offset+i0, offset+i1))

            if self.max_add > 0 and self.gpu_index.ntotal > self.max_add:
                self._flush_to_cpu(index, nb, offset)
            sys.stdout.flush()

        if self.gpu_index.ntotal > 0:
            self._flush_to_cpu(index, nb, offset)

        assert index.ntotal == offset+nb, (index.ntotal, offset+nb, offset, nb)

    def _flush_to_cpu(self, index, nb, offset):
        """Copy vectors with ids in ``[offset, offset+nb)`` from each GPU shard into ``index``, then empty the shards."""
        for i in range(self.ngpu):
            index_src_gpu = faiss.downcast_index(self.gpu_index if self.ngpu == 1 else self.gpu_index.at(i))
            index_src = faiss.index_gpu_to_cpu(index_src_gpu)

            index_src.copy_subset_to(index, 0, offset, offset+nb)
            index_src_gpu.reset()
            index_src_gpu.reserveMemory(self.max_add)

        if self.ngpu > 1:
            # Presumably the sync method name differs between faiss versions;
            # only a missing method falls through to the alternative name.
            try:
                self.gpu_index.sync_with_shard_indexes()
            except AttributeError:
                self.gpu_index.syncWithSubIndexes()

# ==============
# ✨ FaissIndex
# ==============
class FaissIndex():
    """IVF-PQ index (``IndexIVFPQ`` over an ``IndexFlatL2`` coarse quantizer), GPU-assisted when available.

    Args:
        dim: vector dimensionality.
        partitions: number of IVF lists (coarse centroids).
        m: number of PQ sub-quantizers (faiss requires ``dim % m == 0``). Default 16.
        nbits: bits per PQ sub-code. Default 8.
    """
    def __init__(self, dim, partitions, m=16, nbits=8):
        self.dim = dim
        self.partitions = partitions
        self.m = m
        self.nbits = nbits

        self.gpu = FaissIndexGPU()
        self.quantizer, self.index = self._create_index()
        self.offset = 0  # number of vectors added so far (= next id to assign)

    def _create_index(self):
        """Build the coarse quantizer and the IVF-PQ index on top of it."""
        quantizer = faiss.IndexFlatL2(self.dim)  # faiss.IndexHNSWFlat(dim, 32)
        index = faiss.IndexIVFPQ(quantizer, self.dim, self.partitions, self.m, self.nbits)

        return quantizer, index

    def train(self, train_data):
        """Train the index on ``train_data``, using the GPU for clustering when present."""
        if self.gpu.ngpu > 0:
            self.gpu.training_initialize(self.index, self.quantizer)

        self.index.train(train_data)

        if self.gpu.ngpu > 0:
            self.gpu.training_finalize()

    def add(self, data):
        """Append ``data`` to the index; ids continue from the previous ``add`` call."""
        print(f"\tAdd data with shape {data.shape} (offset = {self.offset})\n")

        if self.gpu.ngpu > 0 and self.offset == 0:
            self.gpu.adding_initialize(self.index)

        if self.gpu.ngpu > 0:
            self.gpu.add(self.index, data, self.offset)
        else:
            self.index.add(data)

        self.offset += data.shape[0]

    def save(self, output_path, nprobe=10):
        """Set the search-time ``nprobe`` on the index and write it to ``output_path``."""
        print(f"\nWriting index to {output_path}")

        self.index.nprobe = nprobe
        faiss.write_index(self.index, output_path)

# ===========
# ✨ Aux fcs
# ===========
def load_sample(samples_paths, sample_fraction=None):
    """Load embedding shards saved with ``torch.save`` and return them as one float32 numpy array.

    A shard may be a tensor or a list of tensors (concatenated along dim 0).
    With ``sample_fraction``, each shard is reduced to ``int(rows * fraction)``
    rows drawn uniformly at random *with replacement* (``torch.randint``).
    """
    def read_shard(shard_path):
        print(f"> Loading {shard_path}")
        tensor = torch.load(shard_path)

        if type(tensor) == list:
            tensor = torch.cat(tensor)
        if not sample_fraction:
            return tensor

        n_rows = tensor.size(0)
        picked = torch.randint(0, high=n_rows, size=(int(n_rows * sample_fraction),))
        return tensor[picked]

    stacked = torch.cat([read_shard(p) for p in samples_paths])
    return stacked.float().numpy()

def get_parts(directory):
    """List the numbered embedding shards (``0.pt``, ``1.pt``, ...) in ``directory``.

    Shard numbers must be exactly ``0..n-1``. Returns ``(parts, parts_paths,
    samples_paths)``, where ``samples_paths`` points at the matching ``N.sample``
    files.
    """
    extension = '.pt'

    # Sort numerically (not lexically), so 10.pt comes after 9.pt.
    parts = sorted(
        int(name[:-len(extension)])
        for name in os.listdir(directory)
        if name.endswith(extension)
    )
    assert list(range(len(parts))) == parts, parts

    parts_paths = [os.path.join(directory, f'{part}{extension}') for part in parts]
    samples_paths = [os.path.join(directory, f'{part}.sample') for part in parts]

    return parts, parts_paths, samples_paths

def prepare_faiss_index(slice_samples_paths, partitions, sample_fraction=None):
    """Load training vectors from ``slice_samples_paths`` and return a FaissIndex trained on them.

    The index dimensionality comes from the last axis of the loaded vectors.
    """
    vectors = load_sample(slice_samples_paths, sample_fraction=sample_fraction)
    index = FaissIndex(vectors.shape[-1], partitions)

    print("> Training with the vectors...")
    index.train(vectors)
    print("  Done training!\n")

    return index

# ===============
# ✨ index_faiss
# ===============
def index_faiss(index_path, partitions, sample):
    """Build an IVF-PQ index from the embedding shards in ``index_path`` and save it there.

    Training data: if ``sample`` is given, that fraction of rows is drawn from
    the full ``N.pt`` shards; if ``sample`` is None, the ``N.sample`` files are
    used. All ``N.pt`` shards are then added and the index is written to
    ``<index_path>/ivfpq.<partitions>.faiss``.
    """
    _, parts_paths, samples_paths = get_parts(index_path)

    train_paths = samples_paths
    if sample is not None:
        print(f"Training with {round(sample * 100.0, 1)}% of all embeddings provided.\n")
        train_paths = parts_paths

    output_path = os.path.join(index_path, f'ivfpq.{partitions!s}.faiss')

    index = prepare_faiss_index(train_paths, partitions, sample)

    for shard_path in parts_paths:
        embeddings = torch.load(shard_path)
        if type(embeddings) == list:
            embeddings = torch.cat(embeddings)
        index.add(embeddings.float().numpy())

    index.save(output_path)

    print("\nDone! All complete!")

In [None]:
# - - - - -
# ✨ Build faiss indexes for train set: 100 IVF partitions, trained on a 30% sample
index_faiss(index_path=path_base + 'index/train_index', partitions=100, sample=0.30)

Training with 30.0% of all embeddings provided.

> Loading /content/drive/MyDrive/ColBERT/ColBERT - FAQ Receita Federal//index/train_index/0.pt
> Loading /content/drive/MyDrive/ColBERT/ColBERT - FAQ Receita Federal//index/train_index/1.pt
> Loading /content/drive/MyDrive/ColBERT/ColBERT - FAQ Receita Federal//index/train_index/2.pt
> Loading /content/drive/MyDrive/ColBERT/ColBERT - FAQ Receita Federal//index/train_index/3.pt
> Loading /content/drive/MyDrive/ColBERT/ColBERT - FAQ Receita Federal//index/train_index/4.pt
> Loading /content/drive/MyDrive/ColBERT/ColBERT - FAQ Receita Federal//index/train_index/5.pt
> Loading /content/drive/MyDrive/ColBERT/ColBERT - FAQ Receita Federal//index/train_index/6.pt
> Loading /content/drive/MyDrive/ColBERT/ColBERT - FAQ Receita Federal//index/train_index/7.pt
> Loading /content/drive/MyDrive/ColBERT/ColBERT - FAQ Receita Federal//index/train_index/8.pt
> Loading /content/drive/MyDrive/ColBERT/ColBERT - FAQ Receita Federal//index/train_index/9.pt



In [None]:
# ✨ Build faiss indexes for valid set: 50 IVF partitions, trained on a 30% sample
index_faiss(index_path=path_base + 'index/valid_index', partitions=50, sample=0.30)

Training with 30.0% of all embeddings provided.

> Loading /content/drive/MyDrive/ColBERT/ColBERT - FAQ Receita Federal//index/valid_index/0.pt
> Loading /content/drive/MyDrive/ColBERT/ColBERT - FAQ Receita Federal//index/valid_index/1.pt

Preparing resources for 1 GPUs.
> Training with the vectors...
  Done training!

	Add data with shape (6166, 128) (offset = 0)

	Add data with shape (384, 128) (offset = 6166)


Writing index to /content/drive/MyDrive/ColBERT/ColBERT - FAQ Receita Federal//index/valid_index/ivfpq.50.faiss

Done! All complete!
