<a href="https://colab.research.google.com/github/finardi/Ranking/blob/main/4_Cobert_Index.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
!pip install -q ujson
!pip install -q transformers

In [None]:
import gc
import os
import pickle
import random
import string
from functools import partial

import numpy as np
import pandas as pd
import torch
import ujson
from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
 
# better pandas viz: notebook display settings
pd.set_option('display.max_columns', 100)
# `expand_frame_repr` is a boolean option (wrap wide frames over several lines);
# it was previously given the integer 100, which only worked because it is truthy.
pd.set_option('display.expand_frame_repr', True)
pd.set_option('display.max_colwidth', 700)
pd.set_option('display.max_rows', 5000)
  
# save/load pickles
def pickle_file(path, data=None):
    """Load a pickle from *path* or, when *data* is given, save it there.

    Args:
        path: file to read from or write to.
        data: object to pickle. ``None`` (the default) selects load mode, so a
            ``None`` payload cannot be saved through this helper.

    Returns:
        The unpickled object in load mode, ``None`` in save mode.

    WARNING: unpickling can execute arbitrary code; only load trusted files.
    """
    loading = data is None
    with open(path, 'rb' if loading else 'wb') as fh:
        if loading:
            return pickle.load(fh)
        pickle.dump(data, fh, protocol=pickle.HIGHEST_PROTOCOL)
 
# Project root on the mounted Google Drive; `data/...` and `index/...` paths are
# appended to it by plain string concatenation, so it must end with '/'.
path_base = (
    '/content/drive/MyDrive/ColBERT/'
    'ColBERT - FAQ Receita Federal/'
)

In [None]:
# =============
# ✨ Constants
# =============
bsize: int = 16          # N: triples per training batch (also the default doc sub-batch size)
query_maxlen: int = 48   # every query is padded/truncated to exactly this many tokens
doc_maxlen: int = 128    # documents are truncated to at most this many tokens
path_model: str = 'bert-base-multilingual-uncased'  # HF checkpoint for model and tokenizers

# ==================
# ✨ QueryTokenizer
# ==================
class QueryTokenizer():
    """Tokenizer for ColBERT queries.

    Queries are wrapped as ``[CLS] tokens [SEP]`` and right-padded with
    ``[MASK]`` up to ``query_maxlen`` tokens (query augmentation).

    Args:
        query_maxlen: target query length in tokens, special tokens included.
        path_tokenizer: name or local path passed to ``BertTokenizerFast.from_pretrained``.
    """

    def __init__(self, query_maxlen, path_tokenizer):
        self.tok = BertTokenizerFast.from_pretrained(path_tokenizer)
        self.query_maxlen = query_maxlen

        self.cls_token, self.cls_token_id = self.tok.cls_token, self.tok.cls_token_id
        self.sep_token, self.sep_token_id = self.tok.sep_token, self.tok.sep_token_id
        self.mask_token, self.mask_token_id = self.tok.mask_token, self.tok.mask_token_id

    def _n_mask(self, n_content):
        """Number of [MASK] tokens to append after ``[CLS] + n_content tokens + [SEP]``.

        No query-marker token is inserted, so the special tokens take 2 positions
        (the previous ``+ 3`` produced sequences one token short of ``query_maxlen``).
        Over-long queries get a non-positive count, i.e. no padding and NO truncation
        (unlike ``tensorize``).
        """
        return self.query_maxlen - (n_content + 2)

    def tokenize(self, batch_text, add_special_tokens=False):
        """Return wordpiece tokens per text; with ``add_special_tokens`` wrap and [MASK]-pad them."""
        assert isinstance(batch_text, (list, tuple)), type(batch_text)

        tokens = [self.tok.tokenize(x, add_special_tokens=False) for x in batch_text]

        if not add_special_tokens:
            return tokens

        prefix, suffix = [self.cls_token], [self.sep_token]
        return [prefix + lst + suffix + [self.mask_token] * self._n_mask(len(lst)) for lst in tokens]

    def encode(self, batch_text, add_special_tokens=False):
        """Return token ids per text; with ``add_special_tokens`` wrap and [MASK]-pad them."""
        assert isinstance(batch_text, (list, tuple)), type(batch_text)

        ids = self.tok(batch_text, add_special_tokens=False)['input_ids']

        if not add_special_tokens:
            return ids

        prefix, suffix = [self.cls_token_id], [self.sep_token_id]
        return [prefix + lst + suffix + [self.mask_token_id] * self._n_mask(len(lst)) for lst in ids]

    def tensorize(self, batch_text, bsize=None):
        """Tensorize queries to exactly ``query_maxlen`` tokens (longer ones are truncated).

        Padding positions are replaced by [MASK] in ``ids``; ``mask`` (the attention
        mask) keeps them at 0.

        Returns:
            ``(ids, mask)`` of shape (len(batch_text), query_maxlen), or, when ``bsize``
            is given, a list of such pairs with at most ``bsize`` rows each.
        """
        assert isinstance(batch_text, (list, tuple)), type(batch_text)

        obj = self.tok(batch_text, padding='max_length', truncation=True,
                       return_tensors='pt', max_length=self.query_maxlen)

        ids, mask = obj['input_ids'], obj['attention_mask']

        # Query augmentation: padding becomes [MASK]. Use the tokenizer's own pad id
        # instead of assuming it is 0.
        ids[ids == self.tok.pad_token_id] = self.mask_token_id

        if bsize:
            return _split_into_batches(ids, mask, bsize)

        return ids, mask

# ================
# ✨ DocTokenizer
# ================
class DocTokenizer():
    """Tokenizer for ColBERT documents: ``[CLS] tokens [SEP]``, no [MASK] padding.

    Args:
        doc_maxlen: maximum document length in tokens, applied by ``tensorize``.
        path_tokenizer: name or local path passed to ``BertTokenizerFast.from_pretrained``.
    """

    def __init__(self, doc_maxlen, path_tokenizer):
        self.tok = BertTokenizerFast.from_pretrained(path_tokenizer)
        self.doc_maxlen = doc_maxlen

        self.cls_token = self.tok.cls_token
        self.cls_token_id = self.tok.cls_token_id
        self.sep_token = self.tok.sep_token
        self.sep_token_id = self.tok.sep_token_id

    def tokenize(self, batch_text, add_special_tokens=False):
        """Return wordpiece tokens per text, optionally wrapped in [CLS] ... [SEP]."""
        assert type(batch_text) in (list, tuple), type(batch_text)

        pieces = [self.tok.tokenize(text, add_special_tokens=False) for text in batch_text]
        if add_special_tokens:
            pieces = [[self.cls_token, *piece, self.sep_token] for piece in pieces]
        return pieces

    def encode(self, batch_text, add_special_tokens=False):
        """Return token ids per text, optionally wrapped in the [CLS]/[SEP] ids."""
        assert type(batch_text) in (list, tuple), type(batch_text)

        id_lists = self.tok(batch_text, add_special_tokens=False)['input_ids']
        if add_special_tokens:
            id_lists = [[self.cls_token_id, *row, self.sep_token_id] for row in id_lists]
        return id_lists

    def tensorize(self, batch_text, bsize=None):
        """Pad to the longest document (capped at ``doc_maxlen``) and return tensors.

        Returns:
            ``(ids, mask)`` without ``bsize``; otherwise ``(batches, reverse_indices)``
            where the batches are length-sorted and ``reverse_indices`` restores input order.
        """
        assert type(batch_text) in (list, tuple), type(batch_text)

        encoded = self.tok(batch_text, padding='longest', truncation='longest_first',
                           return_tensors='pt', max_length=self.doc_maxlen)
        ids = encoded['input_ids']
        mask = encoded['attention_mask']

        if not bsize:
            return ids, mask

        ids, mask, reverse_indices = _sort_by_length(ids, mask, bsize)
        return _split_into_batches(ids, mask, bsize), reverse_indices

# =====================
# ✨ tensorize triples
# =====================
def tensorize_triples(query_tokenizer, doc_tokenizer, queries, positives, negatives, bsize):
    """Tensorize (query, positive, negative) triples into ColBERT training batches.

    Args:
        query_tokenizer: object whose ``tensorize(list)`` returns ``(ids, mask)`` (e.g. QueryTokenizer).
        doc_tokenizer: object whose ``tensorize(list)`` returns ``(ids, mask)`` (e.g. DocTokenizer).
        queries, positives, negatives: equal-length sequences; element i of each forms triple i.
        bsize: triples per returned batch; must divide ``len(queries)``. ``None`` means one
            batch holding every triple.

    Returns:
        List of ``(Q, D)`` pairs. ``Q = (ids, mask)`` repeats the batch's queries twice and
        ``D = (ids, mask)`` holds the batch's positives followed by its negatives, so row i
        of ``Q`` is paired with row i of ``D``.
    """
    assert len(queries) == len(positives) == len(negatives)

    N = len(queries)
    # Any bsize dividing N is accepted (the former `bsize == N` assertion made
    # LazyBatcher's `bsize // accumsteps` sub-batching and `bsize=None` unusable).
    if bsize is None:
        bsize = N
    assert bsize > 0 and N % bsize == 0, (N, bsize)

    Q_ids, Q_mask = query_tokenizer.tensorize(queries)
    # Positives and negatives are tokenized together so they share one padded length.
    D_ids, D_mask = doc_tokenizer.tensorize(list(positives) + list(negatives))
    D_ids, D_mask = D_ids.view(2, N, -1), D_mask.view(2, N, -1)

    # Compute max among {length of i^th positive, length of i^th negative} for i in N
    maxlens = D_mask.sum(-1).max(0).values

    # Sort triples by that length so each batch groups documents of similar length.
    indices = maxlens.sort().indices
    Q_ids, Q_mask = Q_ids[indices], Q_mask[indices]
    D_ids, D_mask = D_ids[:, indices], D_mask[:, indices]

    (positive_ids, negative_ids), (positive_mask, negative_mask) = D_ids, D_mask

    batches = []
    for q_ids, q_mask, p_ids, p_mask, n_ids, n_mask in zip(
            Q_ids.split(bsize), Q_mask.split(bsize),
            positive_ids.split(bsize), positive_mask.split(bsize),
            negative_ids.split(bsize), negative_mask.split(bsize)):
        # Each query appears twice: paired first with its positive, then with its negative.
        Q = (torch.cat((q_ids, q_ids)), torch.cat((q_mask, q_mask)))
        D = (torch.cat((p_ids, n_ids)), torch.cat((p_mask, n_mask)))
        batches.append((Q, D))

    return batches

# =============
# ✨ Aux funcs
# =============
def _sort_by_length(ids, mask, bsize):
    if ids.size(0) <= bsize:
        return ids, mask, torch.arange(ids.size(0))

    indices = mask.sum(-1).sort().indices
    reverse_indices = indices.sort().indices

    return ids[indices], mask[indices], reverse_indices

def _split_into_batches(ids, mask, bsize):
    batches = []
    for offset in range(0, ids.size(0), bsize):
        batches.append((ids[offset:offset+bsize], mask[offset:offset+bsize]))

    return batches

# ===============
# ✨ LazyBatcher
# ===============
class LazyBatcher():
    """Single-pass iterator over tensorized (query, positive, negative) triples.

    Reads, from ``path + 'data/'`` with a ``_TRAIN``/``_VALID`` suffix chosen by ``mode``:
    ``df_FAQ_triplet_IDS_*.parquet.gzip`` (columns ``qid``, ``pos_pid``, ``neg_pid``),
    ``qid_to_query_*`` (pickled qid -> query mapping) and ``pid_to_doc_*`` (pickled
    pid -> document mapping).

    Each ``next()`` returns ``tensorize_triples`` output for the next ``bsize`` triples,
    called with sub-batch size ``bsize // accumsteps``. A trailing partial batch is
    dropped, and ``__iter__`` does not rewind, so an instance is exhausted after one pass.

    Raises:
        ValueError: if ``mode`` is not 'train' or 'valid'.
    """

    # mode -> filename suffix of that split's data files.
    _SPLIT_SUFFIX = {'train': 'TRAIN', 'valid': 'VALID'}

    def __init__(self, bsize, path, path_tokenizer, query_maxlen, doc_maxlen, mode='train', accumsteps=1):
        # Validate first so a bad mode fails fast, before tokenizers are loaded.
        if mode not in self._SPLIT_SUFFIX:
            raise ValueError(f"mode must be 'train' or 'valid', got {mode!r}")

        self.bsize, self.accumsteps = bsize, accumsteps
        self.query_tokenizer = QueryTokenizer(query_maxlen=query_maxlen, path_tokenizer=path_tokenizer)
        self.doc_tokenizer = DocTokenizer(doc_maxlen=doc_maxlen, path_tokenizer=path_tokenizer)
        self.tensorize_triples = partial(tensorize_triples, self.query_tokenizer, self.doc_tokenizer)
        self.position = 0
        self.mode = mode

        # Read from the `path` argument (previously the global `path_base` was used,
        # silently ignoring `path`).
        self.triples = self._load_triples(path)
        self.queries = self._load_queries(path)
        self.collection = self._load_collection(path)

    def _split_file(self, path, stem):
        """Return ``'<path>data/<stem>_<SUFFIX>'`` for the current mode; ``path`` must end with '/'."""
        return f"{path}data/{stem}_{self._SPLIT_SUFFIX[self.mode]}"

    def _load_triples(self, path):
        """Load the split's triples as a list of ``(qid, pos_pid, neg_pid)`` tuples."""
        df_triplet = pd.read_parquet(self._split_file(path, 'df_FAQ_triplet_IDS') + '.parquet.gzip')
        return list(zip(df_triplet.qid.values, df_triplet.pos_pid.values, df_triplet.neg_pid.values))

    def _load_queries(self, path):
        """Load the split's pickled qid -> query mapping."""
        return pickle_file(self._split_file(path, 'qid_to_query'))

    def _load_collection(self, path):
        """Load the split's pickled pid -> document mapping."""
        return pickle_file(self._split_file(path, 'pid_to_doc'))

    def __iter__(self):
        return self

    def __len__(self):
        """Number of triples (not batches)."""
        return len(self.triples)

    def __next__(self):
        # Take the next `bsize` triples starting at the current position.
        offset, endpos = self.position, min(self.position + self.bsize, len(self.triples))
        self.position = endpos

        # Drop a trailing partial batch (collate requires exactly `bsize` triples).
        if offset + self.bsize > len(self.triples):
            raise StopIteration

        queries, positives, negatives = [], [], []
        for qid, pos_pid, neg_pid in self.triples[offset:endpos]:
            queries.append(self.queries[qid])
            positives.append(self.collection[pos_pid])
            negatives.append(self.collection[neg_pid])

        return self.collate(queries, positives, negatives)

    def collate(self, queries, positives, negatives):
        """Tensorize one full batch with sub-batch size ``bsize // accumsteps``."""
        assert len(queries) == len(positives) == len(negatives) == self.bsize

        return self.tensorize_triples(queries, positives, negatives, self.bsize // self.accumsteps)

# ===========
# ✨ ColBERT
# ===========
class ColBERT(BertPreTrainedModel):
    """ColBERT late-interaction model: a BERT encoder followed by a linear projection to ``dim``.

    Args:
        config: BERT config (supplied by ``from_pretrained``).
        query_maxlen, doc_maxlen: stored on the model; ModelInference builds its tokenizers from them.
        mask_punctuation: if True, punctuation tokens are zeroed out of document embeddings.
        dim: output embedding size.
        similarity_metric: 'cosine' (dot product of L2-normalized vectors) or 'l2'.

    ``query``/``doc`` move their inputs to the model's own device (``self.device``)
    instead of a global, so the model works wherever it was placed.
    """

    def __init__(self, config, query_maxlen, doc_maxlen, mask_punctuation, dim=128, similarity_metric='cosine'):

        super().__init__(config)

        self.query_maxlen = query_maxlen
        self.doc_maxlen = doc_maxlen
        self.similarity_metric = similarity_metric
        self.dim = dim

        self.mask_punctuation = mask_punctuation
        # Symbols and token ids excluded from document embeddings by `mask` (used as a set).
        self.skiplist = {}

        if self.mask_punctuation:
            # NOTE: tokenizer comes from the module-level `path_model`; it must match the checkpoint.
            self.tokenizer = BertTokenizerFast.from_pretrained(path_model)
            self.skiplist = {w: True
                             for symbol in string.punctuation
                             for w in [symbol, self.tokenizer.encode(symbol, add_special_tokens=False)[0]]}

        self.bert = BertModel(config)
        self.linear = torch.nn.Linear(config.hidden_size, dim, bias=False)

        self.init_weights()

    def forward(self, Q, D):
        """Score ``Q = (ids, mask)`` against ``D = (ids, mask)`` row by row."""
        return self.score(self.query(*Q), self.doc(*D))

    def query(self, input_ids, attention_mask):
        """Encode queries into L2-normalized ``(batch, query_len, dim)`` embeddings."""
        input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device)
        Q = self.bert(input_ids, attention_mask=attention_mask)[0]
        Q = self.linear(Q)

        return torch.nn.functional.normalize(Q, p=2, dim=2)

    def doc(self, input_ids, attention_mask, keep_dims=True):
        """Encode documents, zeroing padding and skiplisted positions before normalization.

        Returns a ``(batch, doc_len, dim)`` tensor when ``keep_dims`` is True; otherwise a
        list with one float16 CPU tensor per document holding only its unmasked rows.
        """
        input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device)
        D = self.bert(input_ids, attention_mask=attention_mask)[0]
        D = self.linear(D)

        mask = torch.tensor(self.mask(input_ids), device=self.device).unsqueeze(2).float()
        D = D * mask

        D = torch.nn.functional.normalize(D, p=2, dim=2)

        if not keep_dims:
            D, mask = D.cpu().to(dtype=torch.float16), mask.cpu().bool().squeeze(-1)
            D = [d[mask[idx]] for idx, d in enumerate(D)]

        return D

    def score(self, Q, D):
        """MaxSim: for each query token take its best-matching doc token, then sum over query tokens."""
        if self.similarity_metric == 'cosine':
            return (Q @ D.permute(0, 2, 1)).max(2).values.sum(1)

        assert self.similarity_metric == 'l2'
        return (-1.0 * ((Q.unsqueeze(2) - D.unsqueeze(1))**2).sum(-1)).max(-1).values.sum(-1)

    def mask(self, input_ids):
        """Per-position keep flags: False for skiplisted ids and for id 0.

        Assumes 0 is the [PAD] id (true for BERT vocabularies) — TODO confirm for other tokenizers.
        """
        mask = [[(x not in self.skiplist) and (x != 0) for x in d] for d in input_ids.cpu().tolist()]
        return mask

# ==================
# ✨ ModelInference
# ==================
class ModelInference():
    """Gradient-free inference wrapper around a ColBERT model.

    Args:
        colbert: ColBERT model already in eval mode.
        path_model: tokenizer name/path for the query and document tokenizers.
    """

    def __init__(self, colbert, path_model):
        assert colbert.training is False

        self.colbert = colbert
        self.query_tokenizer = QueryTokenizer(colbert.query_maxlen, path_tokenizer=path_model)
        self.doc_tokenizer = DocTokenizer(colbert.doc_maxlen, path_tokenizer=path_model)

    @staticmethod
    def _to_cpu(x, to_cpu):
        """Return ``x`` (a tensor or a list of tensors) moved to CPU when ``to_cpu`` is True."""
        if not to_cpu:
            return x
        if isinstance(x, list):
            # colbert.doc(keep_dims=False) returns a list, which has no `.cpu()`.
            return [t.cpu() for t in x]
        return x.cpu()

    def query(self, *args, to_cpu=False, **kw_args):
        """Run ``colbert.query`` without gradients, optionally moving the result to CPU."""
        with torch.no_grad():
            Q = self.colbert.query(*args, **kw_args)
            return self._to_cpu(Q, to_cpu)

    def doc(self, *args, to_cpu=False, **kw_args):
        """Run ``colbert.doc`` without gradients, optionally moving the result to CPU."""
        with torch.no_grad():
            D = self.colbert.doc(*args, **kw_args)
            return self._to_cpu(D, to_cpu)

    def queryFromText(self, queries, bsize=None, to_cpu=False):
        """Encode query strings, in chunks of ``bsize`` when given; ``to_cpu`` applies on both paths."""
        if bsize:
            batches = self.query_tokenizer.tensorize(queries, bsize=bsize)
            batches = [self.query(input_ids, attention_mask, to_cpu=to_cpu) for input_ids, attention_mask in batches]
            return torch.cat(batches)

        input_ids, attention_mask = self.query_tokenizer.tensorize(queries)
        return self.query(input_ids, attention_mask, to_cpu=to_cpu)

    def docFromText(self, docs, bsize=None, keep_dims=True, to_cpu=False):
        """Encode document strings, in length-sorted chunks of ``bsize`` when given.

        Returns a zero-padded 3-D tensor when ``keep_dims`` is True, else a list with one
        tensor per document; either way results follow the input order of ``docs``.
        """
        if bsize:
            batches, reverse_indices = self.doc_tokenizer.tensorize(docs, bsize=bsize)

            batches = [self.doc(input_ids, attention_mask, keep_dims=keep_dims, to_cpu=to_cpu)
                       for input_ids, attention_mask in batches]

            if keep_dims:
                D = _stack_3D_tensors(batches)
                return D[reverse_indices]

            D = [d for batch in batches for d in batch]
            return [D[idx] for idx in reverse_indices.tolist()]

        input_ids, attention_mask = self.doc_tokenizer.tensorize(docs)
        return self.doc(input_ids, attention_mask, keep_dims=keep_dims, to_cpu=to_cpu)

    def score(self, Q, D, mask=None, lengths=None, explain=False):
        """Late-interaction (MaxSim) scores of one query against a batch of documents.

        Computes ``D @ Q``, optionally zeroes document positions past each document's
        length, takes the max over document positions and sums over query positions.

        Args:
            Q: query embeddings; presumably (dim, query_len) so that ``D @ Q`` is
               (n_docs, doc_len, query_len) — TODO confirm against callers.
            D: document embeddings indexed as (n_docs, doc_len, ...).
            mask: optional (n_docs, doc_len) tensor of valid document positions.
            lengths: optional per-document lengths; mutually exclusive with ``mask``.
            explain: not implemented.

        Returns:
            CPU tensor with one score per document.

        Raises:
            NotImplementedError: if ``explain`` is True.
        """
        if explain:
            # Raise explicitly: `assert False` would be skipped under `python -O`.
            raise NotImplementedError("explain=True is not implemented")

        if lengths is not None:
            assert mask is None, "don't supply both mask and lengths"

            # Build the mask on D's device (not a global) so CPU tensors work too.
            mask = torch.arange(D.size(1), device=D.device) + 1
            mask = mask.unsqueeze(0) <= lengths.to(D.device).unsqueeze(-1)

        scores = (D @ Q)
        scores = scores if mask is None else scores * mask.unsqueeze(-1)
        scores = scores.max(1)

        return scores.values.sum(-1).cpu()

def _stack_3D_tensors(groups):
    bsize = sum([x.size(0) for x in groups])
    maxlen = max([x.size(1) for x in groups])
    hdim = groups[0].size(2)

    output = torch.zeros(bsize, maxlen, hdim, device=groups[0].device, dtype=groups[0].dtype)

    offset = 0
    for x in groups:
        endpos = offset + x.size(0)
        output[offset:endpos, :x.size(1)] = x
        offset = endpos

    return output    

# - - - - -
dataloader_train = LazyBatcher(
    bsize=bsize,
    path=path_base,
    path_tokenizer=path_model,
    query_maxlen=query_maxlen,
    doc_maxlen=doc_maxlen,
    mode='train'
    )

# Smoke test: tensorize every training batch once.
# NOTE: this consumes `dataloader_train`; LazyBatcher does not rewind on re-iteration.
print('batches:')
for i, batches in enumerate(dataloader_train):
    print(f' {i}.', end='')

# When the cell is re-run, release the previously loaded model before loading a new one.
try:
    del colbert
except NameError:
    pass  # first run: no model loaded yet
else:
    gc.collect()
    torch.cuda.empty_cache()

# Fall back to CPU when no GPU is available instead of failing on `.to('cuda')`.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

print()

colbert = ColBERT.from_pretrained(
    path_model,
    query_maxlen=query_maxlen,
    doc_maxlen=doc_maxlen,
    dim=128,
    similarity_metric='cosine',
    mask_punctuation=False).to(DEVICE)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=871891.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1715180.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=28.0, style=ProgressStyle(description_w…


batches:
 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35. 36. 37.


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=625.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=672271273.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at bert-base-multilingual-uncased were not used when initializing ColBERT: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing ColBERT from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ColBERT from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ColBERT were not initialized from the model checkpoint at bert-base-multilingual-uncased and are newly initialized: ['linear.weight'

In [None]:
# ✨ load colbert from checkpoint
# map_location puts the saved tensors directly on DEVICE, whatever device they
# were saved from.
colbert.load_state_dict(torch.load(path_base + 'data/EPOCH_3_FAQ', map_location=DEVICE))
print('\nmodel loaded!\n')


model loaded!



In [None]:
# =============
# ✨ Indexing
# =============
def save_tensor(tensor:torch.Tensor, path_prefix:str):
    """Save ``tensor`` with ``torch.save`` to the file ``path_prefix`` (a full file path, despite the name).

    The annotation is ``torch.Tensor`` (the type); ``torch.tensor`` is a factory function.
    """
    torch.save(tensor, path_prefix)

def load_index_part(filename, verbose=True):
    """Load one saved index part as a single tensor.

    Parts stored as a list of tensors are concatenated along dim 0 (backward
    compatibility). ``verbose`` is accepted for interface compatibility but unused.
    """
    part = torch.load(filename)
    if isinstance(part, list):  # for backward compatibility
        part = torch.cat(part)
    return part

# Last ModelInference built, reused while the same model object is being indexed.
_inference_cache = {'model': None, 'inference': None}

def _get_inference(model):
    """Return a ModelInference for *model*, rebuilding it only when a different model is passed."""
    if _inference_cache['model'] is not model:
        _inference_cache['inference'] = ModelInference(model, path_model)
        _inference_cache['model'] = model
    return _inference_cache['inference']

def encode_batch(batch_idx:int, batch:list):
    """Encode a batch of document texts with the global ``colbert`` model.

    Args:
        batch_idx: index of the batch (unused; kept for interface compatibility).
        batch: list of document strings.

    Returns:
        ``(embs, doclens)``: ``embs`` concatenates along dim 0 the per-document embeddings
        returned by ``docFromText(keep_dims=False)``; ``doclens[i]`` is the number of rows
        contributed by ``batch[i]``. Sub-batches of the module-level ``bsize`` are used.
    """
    colbert.eval()
    # Building ModelInference loads two tokenizers; reuse it across batches instead
    # of reloading them for every batch.
    inference = _get_inference(colbert)
    with torch.no_grad():
        embs = inference.docFromText(batch, bsize=bsize, keep_dims=False)
        assert type(embs) is list
        assert len(embs) == len(batch)
        local_doclens = [d.size(0) for d in embs]
        embs = torch.cat(embs)
    return embs, local_doclens

def save_batch(batch_idx:int, embs:torch.Tensor, doclens:list, index_path:str):
    """Write one index part into ``index_path`` (created if missing).

    Files:
        ``{batch_idx}.pt``: all embeddings of the batch.
        ``{batch_idx}.sample``: ``embs.size(0) // 20`` rows drawn uniformly at random
            with replacement (about 5%).
        ``doclens.{batch_idx}.json``: JSON list with the embedding count of each document.
    """
    os.makedirs(index_path, exist_ok=True)
    output_path = os.path.join(index_path, f"{batch_idx}.pt")
    output_sample_path = os.path.join(index_path, f"{batch_idx}.sample")
    doclens_path = os.path.join(index_path, f'doclens.{batch_idx}.json')

    # Save the embeddings.
    save_tensor(embs, output_path)

    n_rows = embs.size(0)
    if n_rows > 0:
        sample_idx = torch.randint(low=0, high=n_rows, size=(n_rows // 20,))
    else:
        # Guard against calling randint with high=0 for an empty batch.
        sample_idx = torch.empty(0, dtype=torch.long)
    save_tensor(embs[sample_idx], output_sample_path)

    # Save the doclens.
    with open(doclens_path, 'w') as output_doclens:
        ujson.dump(doclens, output_doclens)

def get_collection_batch(collection:list, bsize=bsize):
    """Yield consecutive slices of ``collection`` with at most ``bsize`` items each.

    The default ``bsize`` is the module-level ``bsize`` captured when this function
    is defined. Nothing is yielded for an empty collection.

    Raises:
        ValueError: (on first iteration) if ``bsize`` is not positive; previously a
            zero size silently yielded nothing.
    """
    if bsize < 1:
        raise ValueError(f"bsize must be >= 1, got {bsize}")
    for offset in range(0, len(collection), bsize):
        yield collection[offset:offset + bsize]

def run_index(collection:dict, bsize:int, index_path:str, verbose=True):
    """Encode every document of ``collection`` and save one index part per batch.

    Documents are taken in ``collection.values()`` order, so part ``k`` holds
    documents ``k*bsize`` .. ``(k+1)*bsize - 1`` of that order. Each part is written
    by ``save_batch`` as ``{k}.pt``, ``{k}.sample`` and ``doclens.{k}.json``.

    Args:
        collection: mapping whose values are the document strings.
        bsize: documents per index part.
        index_path: output directory (created if missing).
        verbose: print a one-line progress summary per batch.
    """
    os.makedirs(index_path, exist_ok=True)
    docs = list(collection.values())
    for batch_idx, batch in enumerate(get_collection_batch(docs, bsize=bsize)):
        embs, doclens = encode_batch(batch_idx, batch)
        save_batch(batch_idx, embs, doclens, index_path)
        if verbose:
            # A summary instead of the raw documents, which flooded the notebook output.
            print(f'batch {batch_idx}: {len(batch)} docs, {embs.size(0)} embeddings')

# - - - - -
# Build one on-disk index per split, 64 documents per index part.

# ✨ Train indexes
print('Indexing train set')
collection_train = pickle_file(f'{path_base}data/collection_TRAIN')
run_index(collection=collection_train, bsize=64, index_path=f'{path_base}index/train_index')

# ✨ Valid indexes
print('\nIndexing valid set')
collection_valid = pickle_file(f'{path_base}data/collection_VALID')
run_index(collection=collection_valid, bsize=64, index_path=f'{path_base}index/valid_index')

Indexing train set
0 ['sim.  doação modal ou onerosa é aquela que traz consigo um encargo para o donatário. os valores recebidos  em função desse encargo estão sujeitos ao recolhimento mensal (carnê-leão), se recebidos de pessoa física  ou, na fonte, se pagos por pessoa jurídica, e na declaração de ajuste.', 'sim.  contribuinte, independentemente da opção pelo desconto simplificado ou não, deve informar como  rendimento tributável o valor dos aluguéis recebidos, podendo excluir os impostos, as taxas e os emolumentos  incidentes sobre o bem que produzir o rendimento, desde que o ônus desses encargos tenha sido  exclusivamente do declarante.', 'contribuinte pode solicitar confirmação do pagamento na unidade de atendimento da secretaria especial da  receita federal do brasil   de sua jurisdição fiscal.', 'não. os valores recebidos a título de pensão em cumprimento de acordo ou decisão judicial, ou ainda por  escritura pública, inclusive a prestação de alimentos provisionais, estão abrangi