<a href="https://colab.research.google.com/github/finardi/Ranking/blob/main/2_Cobert_Train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!nvidia-smi

Tue Jun  8 01:53:58 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.27       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   33C    P0    26W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
%%capture
!pip install -q transformers

In [None]:
import gc
import os
import pickle
import string
import time
from functools import partial

import numpy as np
import pandas as pd
import torch
from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast

# better pandas viz: widen the display limits so long FAQ texts and wide frames
# are shown in full. Full option keys are used (no reliance on abbreviated-key
# matching) and expand_frame_repr receives the boolean it expects.
pd.set_option('display.max_columns', 100)
pd.set_option('display.expand_frame_repr', True)
pd.set_option('display.max_colwidth', 700)
pd.set_option('display.max_rows', 5000)
  
# save/load pickles
def pickle_file(path, data=None):
    """Load a pickle from ``path`` or, when ``data`` is given, write it there.

    Args:
        path: file to read from or write to.
        data: object to serialise. When left as ``None`` the file is read instead.

    Returns:
        The unpickled object in read mode; ``None`` in write mode.

    NOTE: unpickling can execute arbitrary code, so only load trusted files.
    """
    writing = data is not None
    if writing:
        with open(path, 'wb') as sink:
            pickle.dump(data, sink, protocol=pickle.HIGHEST_PROTOCOL)
        return None

    with open(path, 'rb') as source:
        return pickle.load(source)
 
# path base: root folder for everything under 'data/' that the cells below read and write
# NOTE(review): absolute Colab/Drive path — presumably Google Drive must be mounted first; confirm
path_base =  '/content/drive/MyDrive/ColBERT/ColBERT - FAQ Receita Federal/'

In [None]:
# =============
# ✨ Constants
# =============
bsize = 16 # N: number of (query, positive, negative) triples per batch
query_maxlen = 48  # max_length given to the query tokenizer (queries are padded to it)
doc_maxlen = 128  # max_length given to the document tokenizer (truncation limit)
path_model = 'bert-base-multilingual-uncased'  # checkpoint name used for both tokenizers and models

# ==================
# ✨ QueryTokenizer
# ==================
class QueryTokenizer():
    """Tokenizer wrapper for queries that pads with ``[MASK]`` up to ``query_maxlen``.

    ``tokenize`` and ``encode`` work on plain Python lists; ``tensorize``
    returns padded PyTorch tensors.
    """

    def __init__(self, query_maxlen, path_tokenizer):
        """Load the fast BERT tokenizer from ``path_tokenizer`` and cache its special tokens."""
        self.tok = BertTokenizerFast.from_pretrained(path_tokenizer)
        self.query_maxlen = query_maxlen

        self.cls_token, self.cls_token_id = self.tok.cls_token, self.tok.cls_token_id
        self.sep_token, self.sep_token_id = self.tok.sep_token, self.tok.sep_token_id
        self.mask_token, self.mask_token_id = self.tok.mask_token, self.tok.mask_token_id

    def _mask_padding(self, n_tokens):
        """Return how many mask tokens follow ``[CLS] + tokens + [SEP]`` to reach ``query_maxlen``."""
        # Exactly two special tokens are added ([CLS] and [SEP]). Subtracting 3, as
        # the code did before, left every padded query one token short of query_maxlen.
        return max(0, self.query_maxlen - (n_tokens + 2))

    def tokenize(self, batch_text, add_special_tokens=False):
        """Split each text into word pieces.

        Args:
            batch_text: list or tuple of strings.
            add_special_tokens: when true, wrap each sequence in ``[CLS]`` / ``[SEP]``
                and append ``[MASK]`` tokens up to ``query_maxlen``.

        Returns:
            A list with one list of token strings per input text. Sequences that
            are already longer than ``query_maxlen - 2`` are not truncated here.
        """
        assert type(batch_text) in [list, tuple], (type(batch_text))

        tokens = [self.tok.tokenize(x, add_special_tokens=False) for x in batch_text]

        if not add_special_tokens:
            return tokens

        prefix, suffix = [self.cls_token], [self.sep_token]
        tokens = [prefix + lst + suffix + [self.mask_token] * self._mask_padding(len(lst)) for lst in tokens]

        return tokens

    def encode(self, batch_text, add_special_tokens=False):
        """Same as ``tokenize`` but returns vocabulary ids instead of token strings."""
        assert type(batch_text) in [list, tuple], (type(batch_text))

        ids = self.tok(batch_text, add_special_tokens=False)['input_ids']

        if not add_special_tokens:
            return ids

        prefix, suffix = [self.cls_token_id], [self.sep_token_id]
        ids = [prefix + lst + suffix + [self.mask_token_id] * self._mask_padding(len(lst)) for lst in ids]

        return ids

    def tensorize(self, batch_text, bsize=None):
        """Encode ``batch_text`` into tensors padded/truncated to ``query_maxlen``.

        Args:
            batch_text: list or tuple of strings.
            bsize: optional batch size. When given, the result is split into
                consecutive ``(ids, mask)`` batches of that size.

        Returns:
            ``(ids, mask)`` tensors, or a list of such pairs when ``bsize`` is set.
        """
        assert type(batch_text) in [list, tuple], (type(batch_text))

        obj = self.tok(batch_text, padding='max_length', truncation=True,
                       return_tensors='pt', max_length=self.query_maxlen)

        ids, mask = obj['input_ids'], obj['attention_mask']

        # Replace padding ids with the mask id; the attention mask is left untouched.
        # assumes the tokenizer's pad id is 0 — TODO confirm for non-BERT vocabularies
        ids[ids == 0] = self.mask_token_id

        if bsize:
            batches = _split_into_batches(ids, mask, bsize)
            return batches

        return ids, mask

# ================
# ✨ DocTokenizer
# ================
class DocTokenizer():
    """Tokenizer wrapper for documents (passages), truncated to ``doc_maxlen`` in ``tensorize``."""

    def __init__(self, doc_maxlen, path_tokenizer):
        """Load the fast BERT tokenizer from ``path_tokenizer`` and cache ``[CLS]`` / ``[SEP]``."""
        self.tok = BertTokenizerFast.from_pretrained(path_tokenizer)
        self.doc_maxlen = doc_maxlen

        self.cls_token = self.tok.cls_token
        self.cls_token_id = self.tok.cls_token_id
        self.sep_token = self.tok.sep_token
        self.sep_token_id = self.tok.sep_token_id

    def tokenize(self, batch_text, add_special_tokens=False):
        """Return word pieces per text, optionally wrapped in ``[CLS]`` ... ``[SEP]``."""
        assert type(batch_text) in [list, tuple], (type(batch_text))

        pieces = [self.tok.tokenize(text, add_special_tokens=False) for text in batch_text]
        if add_special_tokens:
            return [[self.cls_token] + seq + [self.sep_token] for seq in pieces]
        return pieces

    def encode(self, batch_text, add_special_tokens=False):
        """Return vocabulary ids per text, optionally wrapped in the ``[CLS]`` / ``[SEP]`` ids."""
        assert type(batch_text) in [list, tuple], (type(batch_text))

        encoded = self.tok(batch_text, add_special_tokens=False)['input_ids']
        if add_special_tokens:
            return [[self.cls_token_id] + seq + [self.sep_token_id] for seq in encoded]
        return encoded

    def tensorize(self, batch_text, bsize=None):
        """Encode ``batch_text`` into tensors padded to the longest item (at most ``doc_maxlen``).

        Without ``bsize`` the ``(ids, mask)`` pair is returned as is. With
        ``bsize`` the rows are length-sorted by ``_sort_by_length``, split into
        batches, and returned together with the indices that undo the sort.
        """
        assert type(batch_text) in [list, tuple], (type(batch_text))

        encoded = self.tok(batch_text, padding='longest', truncation='longest_first',
                           return_tensors='pt', max_length=self.doc_maxlen)
        ids = encoded['input_ids']
        mask = encoded['attention_mask']

        if not bsize:
            return ids, mask

        ids, mask, reverse_indices = _sort_by_length(ids, mask, bsize)
        return _split_into_batches(ids, mask, bsize), reverse_indices

# =====================
# ✨ tensorize triples
# =====================
def tensorize_triples(query_tokenizer, doc_tokenizer, queries, positives, negatives, bsize):
    """Turn aligned (query, positive, negative) texts into model-ready batches.

    Args:
        query_tokenizer: object whose ``tensorize(list)`` returns ``(ids, mask)``.
        doc_tokenizer: same contract, used for the passages.
        queries, positives, negatives: equally long lists of strings.
        bsize: batch size; must equal ``len(queries)`` (asserted below).

    Returns:
        A list of ``(Q, D)`` pairs. ``Q`` holds the query ids/masks repeated
        twice; ``D`` holds the positive passages followed by the negative ones,
        so row ``i`` and row ``i + N`` of a batch share the same query.
    """
    assert len(queries) == len(positives) == len(negatives)
    assert bsize is None or len(queries) % bsize == 0

    n_triples = len(queries)
    assert bsize == n_triples

    Q_ids, Q_mask = query_tokenizer.tensorize(queries)
    # Positives and negatives are tokenized together so they share one padded width.
    D_ids, D_mask = doc_tokenizer.tensorize(positives + negatives)
    D_ids = D_ids.view(2, n_triples, -1)
    D_mask = D_mask.view(2, n_triples, -1)

    # Per triple: length of the longer of its positive / negative passage.
    pair_lengths = D_mask.sum(-1).max(0).values

    # Reorder every tensor by that length, shortest first.
    order = pair_lengths.sort().indices
    Q_ids, Q_mask = Q_ids[order], Q_mask[order]
    D_ids, D_mask = D_ids[:, order], D_mask[:, order]

    positive_ids, negative_ids = D_ids
    positive_mask, negative_mask = D_mask

    grouped = zip(
        _split_into_batches(Q_ids, Q_mask, bsize),
        _split_into_batches(positive_ids, positive_mask, bsize),
        _split_into_batches(negative_ids, negative_mask, bsize),
    )

    return [
        (
            (torch.cat((q_ids, q_ids)), torch.cat((q_mask, q_mask))),
            (torch.cat((p_ids, n_ids)), torch.cat((p_mask, n_mask))),
        )
        for (q_ids, q_mask), (p_ids, p_mask), (n_ids, n_mask) in grouped
    ]

# =============
# ✨ Aux funcs
# =============
def _sort_by_length(ids, mask, bsize):
    if ids.size(0) <= bsize:
        return ids, mask, torch.arange(ids.size(0))

    indices = mask.sum(-1).sort().indices
    reverse_indices = indices.sort().indices

    return ids[indices], mask[indices], reverse_indices

def _split_into_batches(ids, mask, bsize):
    batches = []
    for offset in range(0, ids.size(0), bsize):
        batches.append((ids[offset:offset+bsize], mask[offset:offset+bsize]))

    return batches

# ===============
# ✨ LazyBatcher
# ===============
class LazyBatcher():
    """Iterator over tensorized batches of (query, positive, negative) triples.

    Triples of ids are read from a parquet file and resolved to text through
    two pickled lookup tables; each ``next()`` call tokenizes ``bsize`` triples.
    A trailing group smaller than ``bsize`` is dropped.

    Args:
        bsize: number of triples per batch.
        path: base directory that contains the ``data/`` folder.
        path_tokenizer: name or path given to both tokenizers.
        query_maxlen: maximum query length in tokens.
        doc_maxlen: maximum document length in tokens.
        mode: ``'train'`` or ``'valid'``; selects which files are read.
        accumsteps: divisor applied to ``bsize`` before tensorizing.
            NOTE: ``tensorize_triples`` asserts that the size it receives equals
            the number of triples, so values other than 1 trip that assert.

    Raises:
        ValueError: if ``mode`` is not ``'train'`` or ``'valid'``.
    """

    # File-name suffix per supported mode.
    _MODE_SUFFIX = {'train': 'TRAIN', 'valid': 'VALID'}

    def __init__(self, bsize, path, path_tokenizer, query_maxlen, doc_maxlen, mode='train', accumsteps=1):
        # Fail early: an unknown mode used to produce None lookups or a bad path later on.
        if mode not in self._MODE_SUFFIX:
            raise ValueError(f"mode must be one of {sorted(self._MODE_SUFFIX)}, got {mode!r}")

        self.bsize, self.accumsteps = bsize, accumsteps
        self.query_tokenizer = QueryTokenizer(query_maxlen=query_maxlen, path_tokenizer=path_tokenizer)
        self.doc_tokenizer = DocTokenizer(doc_maxlen=doc_maxlen, path_tokenizer=path_tokenizer)
        self.tensorize_triples = partial(tensorize_triples, self.query_tokenizer, self.doc_tokenizer)
        self.position = 0
        self.mode = mode

        # Read from the `path` argument; the global `path_base` was used here before,
        # so the argument was silently ignored.
        self.triples = self._load_triples(path)
        self.queries = self._load_queries(path)
        self.collection = self._load_collection(path)

    def _data_path(self, path, template):
        """Build ``<path>/data/<template>`` with ``{}`` replaced by the mode suffix."""
        if self.mode not in self._MODE_SUFFIX:
            raise ValueError(f"mode must be one of {sorted(self._MODE_SUFFIX)}, got {self.mode!r}")
        return os.path.join(path, 'data', template.format(self._MODE_SUFFIX[self.mode]))

    def _load_triples(self, path):
        """Read the ``(qid, pos_pid, neg_pid)`` triples of the current mode from parquet."""
        df_triplet = pd.read_parquet(self._data_path(path, 'df_FAQ_triplet_IDS_{}.parquet.gzip'))
        return list(zip(
            df_triplet.qid.values,
            df_triplet.pos_pid.values,
            df_triplet.neg_pid.values,
        ))

    def _load_queries(self, path):
        """Unpickle the qid -> query lookup of the current mode."""
        return pickle_file(self._data_path(path, 'qid_to_query_{}'))

    def _load_collection(self, path):
        """Unpickle the pid -> document lookup of the current mode."""
        return pickle_file(self._data_path(path, 'pid_to_doc_{}'))

    def __iter__(self):
        # The object is its own iterator; `position` is not reset here.
        return self

    def __len__(self):
        """Number of triples (not batches)."""
        return len(self.triples)

    def __next__(self):
        """Return the next tensorized batch, or stop when fewer than ``bsize`` triples remain."""
        offset, endpos = self.position, min(self.position + self.bsize, len(self.triples))
        self.position = endpos

        # Incomplete trailing batches are skipped.
        if offset + self.bsize > len(self.triples):
            raise StopIteration

        queries, positives, negatives = [], [], []

        for position in range(offset, endpos):
            query, pos, neg = self.triples[position]
            # Resolve ids to text.
            query, pos, neg = self.queries[query], self.collection[pos], self.collection[neg]
            queries.append(query)
            positives.append(pos)
            negatives.append(neg)

        return self.collate(queries, positives, negatives)

    def collate(self, queries, positives, negatives):
        """Tensorize one full batch of texts via ``tensorize_triples``."""
        assert len(queries) == len(positives) == len(negatives) == self.bsize

        return self.tensorize_triples(queries, positives, negatives, self.bsize // self.accumsteps)

# - - - - -
# Smoke test: build the training batcher and walk through it once,
# printing the index of every batch it yields.
dataloader_train = LazyBatcher(
    bsize=bsize, 
    path=path_base, 
    path_tokenizer=path_model,
    query_maxlen=query_maxlen,
    doc_maxlen=doc_maxlen,
    mode='train'
    )
print('batches:')
# the batcher is exhausted after this loop; later cells build a new one
for i, batches in enumerate(dataloader_train):
    print(f' {i }.', end ='')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=871891.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1715180.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=28.0, style=ProgressStyle(description_w…


batch:
 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35. 36. 37.

# Train
### Model

In [None]:
# =======================
# ✨ step 0: get a batch
# =======================
dataloader_train = LazyBatcher(
    bsize=bsize,
    path=path_base,
    path_tokenizer=path_model,
    query_maxlen=query_maxlen,
    doc_maxlen=doc_maxlen,
    mode='train'
    )

# Each item yielded by the batcher is a list of (queries, passages) sub-batches.
dl_0 = next(iter(dataloader_train))

# Take the first sub-batch directly (this replaces a for/enumerate loop that
# broke out on its first iteration).
queries, passages = dl_0[0]

q_input_ids, q_attention_mask = queries
print('q_input_ids', q_input_ids.shape)
print('q_attention_mask', q_attention_mask.shape, '\n')

d_input_ids, d_attention_mask = passages
print('d_input_ids',d_input_ids.shape)
print('d_attention_mask', d_attention_mask.shape)

q_input_ids torch.Size([32, 48])
q_attention_mask torch.Size([32, 48]) 

d_input_ids torch.Size([32, 128])
d_attention_mask torch.Size([32, 128])


In [None]:
# Release a previously created `model` when this cell is re-run. Only the
# "not defined yet" case is expected, so only NameError is caught.
try:
    del model
except NameError:
    pass  # fresh kernel: nothing to free
else:
    gc.collect()
    torch.cuda.empty_cache()

# =============================
# ✨ step 0.1: init BERT Model
# =============================
model = BertModel.from_pretrained(path_model, return_dict=True)

# =================================
# ✨ step 0.2: build a dense layer
# =================================
# Bias-free projection from BERT's hidden size down to 128 dimensions per token.
linear = torch.nn.Linear(model.config.hidden_size, 128, bias=False)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=625.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=672271273.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at bert-base-multilingual-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


# Prepare Q 

In [None]:
# ============================================================
# ✨ step 0: feed the query input_ids and attention_mask to BERT
# ============================================================
outs = model(input_ids=q_input_ids, attention_mask=q_attention_mask)
# per-token hidden states; the displayed shape is [32, 48, 768]
Q = outs['last_hidden_state']
Q.shape

torch.Size([32, 48, 768])

In [None]:
# ========================================
# ✨ step 1: project Q through the dense layer
# ========================================
# `linear` is the bias-free Linear(hidden_size, 128) created in the model-init cell.
# NOTE: this cell rebinds Q, so re-running it alone fails (Q no longer has hidden_size features).
Q = linear(Q)
Q.shape

torch.Size([32, 48, 128])

In [None]:
# ===========================
# ✨ step 2: L2-normalize Q
# ===========================
# dim=2 is the embedding axis, so each token vector is scaled to unit L2 norm
Q = torch.nn.functional.normalize(Q, p=2, dim=2)
Q.shape

torch.Size([32, 48, 128])

# Prepare D

In [None]:
# ========================================================
# ✨ step 0: feed the document input_ids and attention_mask to BERT
# ========================================================
outs = model(input_ids=d_input_ids, attention_mask=d_attention_mask)
# per-token hidden states; the displayed shape is [32, 128, 768]
D = outs['last_hidden_state']
D.shape

torch.Size([32, 128, 768])

In [None]:
# ========================================
# ✨ step 1: project D through the dense layer
# ========================================
# Same `linear` layer as for the queries.
# NOTE: this cell rebinds D, so re-running it alone fails (D no longer has hidden_size features).
D = linear(D)
D.shape

torch.Size([32, 128, 128])

In [None]:
# ==============================
# ✨ step 2: filter D with mask
# ==============================
# True where the token id is not 0, i.e. not padding.
# Vectorised comparison replaces the nested Python list comprehension.
# assumes the tokenizer's pad id is 0 — TODO confirm for non-BERT vocabularies
mask = (d_input_ids != 0).cpu()
print(mask.shape)
# Zero out the embeddings of padding positions.
D = D * mask.unsqueeze(2).float()
print(D.shape)

# ===========================
# ✨ step 3: normalize in L2
# ===========================
# dim=2 is the embedding axis; the zeroed padding rows stay zero.
D = torch.nn.functional.normalize(D, p=2, dim=2)
D.shape

torch.Size([32, 128])
torch.Size([32, 128, 128])


torch.Size([32, 128, 128])

# Get the score between Q and D

In [None]:
# ==========
# ✨ step 0: late-interaction score between Q and D
# ==========
print(f'Q shape: {Q.size()} -- D shape: {D.size()}\n')

# Dot product between every query token (b) and every document token (d) of the same pair (a).
scores = torch.einsum('abc, adc -> abd', Q, D)
print(f'Score shape:                       {scores.size()}')

# For each query token keep its best-matching document token.
scores = scores.max(2).values
print(f'Scores.max(2).values shape:        {scores.size()}')

# Sum over the query tokens: one score per (query, document) pair.
scores = scores.sum(1)
print(f'Scores.max(2).values.sum(1) shape: {scores.size()}')

# Round numerically; slicing str(value) breaks for values printed in scientific notation.
print(f'Scores: {[round(s, 4) for s in scores.tolist()]}')

Q shape: torch.Size([32, 48, 128]) -- D shape: torch.Size([32, 128, 128])

Score shape:                       torch.Size([32, 48, 128])
Scores.max(2).values shape:        torch.Size([32, 48])
Scores.max(2).values.sum(1) shape: torch.Size([32])
Scores: [32.9297, 30.2123, 33.2069, 42.3434, 26.3653, 34.8422, 33.7496, 30.2646, 32.624, 30.8361, 27.914, 34.2535, 28.7218, 31.8257, 32.0632, 32.5547, 29.2759, 23.2356, 27.2142, 23.6811, 25.1725, 29.007, 28.0527, 25.8241, 27.9958, 28.2469, 26.3455, 28.763, 27.407, 26.5836, 28.0287, 26.7049]


# Optimization

In [None]:
# ==============================================================
# ✨ step 0: build pseudo-labels for class 0 (torch zeros)
# ==============================================================
# one label per triple, so `labels` has bsize entries; the target is always 0
# because, after the reshape in the next step, column 0 holds the positive passage's score
labels = torch.zeros(bsize, dtype=torch.long)
print('labels shape', labels.shape, '\n')

labels shape torch.Size([16]) 



In [None]:
# ==============================
# ✨ step 1: reshape the scores
# ==============================
# `scores` holds the N positive scores followed by the N negative ones
# (tensorize_triples concatenates positives before negatives).
# view(2, -1) -> [2, N]; permute(1, 0) -> [N, 2], one (positive, negative) pair per row.
# NOTE: this cell rebinds `scores`; re-run the scoring cell before running it again.
scores = scores.view(2, -1).permute(1, 0)
print(scores.shape)
scores

torch.Size([16, 2])


tensor([[32.9297, 29.2759],
        [30.2123, 23.2356],
        [33.2070, 27.2143],
        [42.3435, 23.6811],
        [26.3653, 25.1725],
        [34.8423, 29.0071],
        [33.7497, 28.0528],
        [30.2647, 25.8242],
        [32.6241, 27.9959],
        [30.8361, 28.2470],
        [27.9140, 26.3455],
        [34.2535, 28.7630],
        [28.7219, 27.4071],
        [31.8258, 26.5837],
        [32.0633, 28.0287],
        [32.5547, 26.7050]], grad_fn=<PermuteBackward>)

In [None]:
# ====================
# ✨ step 2: get loss
# ====================
# init the CE Loss
criterion = torch.nn.CrossEntropyLoss()
# each row of `scores` is used as 2-class logits (col 0 = positive, col 1 = negative)
# with target class 0; the slice keeps `labels` as long as the number of score rows
loss = criterion(scores, labels[:scores.size(0)])
loss.item()

0.05318648740649223

# In class format

In [None]:
# ===========
# ✨ ColBERT
# ===========
class ColBERT(BertPreTrainedModel):
    """BERT encoder plus a bias-free linear projection, scored by late interaction.

    Queries and documents are encoded token by token, projected to ``dim``
    dimensions and L2-normalised. A (query, document) pair is scored by taking,
    for every query token, its maximum similarity over the document tokens and
    summing those maxima.

    Args:
        config: BERT configuration (supplied by ``from_pretrained``).
        query_maxlen: stored on the instance; not read by the methods below.
        doc_maxlen: stored on the instance; not read by the methods below.
        mask_punctuation: when true, punctuation tokens are excluded from the
            document embeddings in addition to padding.
        dim: output size of the projection layer.
        similarity_metric: ``'cosine'`` or ``'l2'``.
    """

    def __init__(self, config, query_maxlen, doc_maxlen, mask_punctuation, dim=128, similarity_metric='cosine'):

        super().__init__(config)

        self.query_maxlen = query_maxlen
        self.doc_maxlen = doc_maxlen
        self.similarity_metric = similarity_metric
        self.dim = dim

        self.mask_punctuation = mask_punctuation
        self.skiplist = {}

        if self.mask_punctuation:
            # NOTE: the tokenizer is loaded from the notebook-level `path_model`.
            self.tokenizer = BertTokenizerFast.from_pretrained(path_model)
            # Holds both the punctuation symbol and its first token id as keys.
            self.skiplist = {w: True
                             for symbol in string.punctuation
                             for w in [symbol, self.tokenizer.encode(symbol, add_special_tokens=False)[0]]}

        self.bert = BertModel(config)
        self.linear = torch.nn.Linear(config.hidden_size, dim, bias=False)

        self.init_weights()

    def forward(self, Q, D):
        """Score ``(ids, mask)`` query inputs ``Q`` against ``(ids, mask)`` document inputs ``D``."""
        return self.score(self.query(*Q), self.doc(*D))

    def query(self, input_ids, attention_mask):
        """Return L2-normalised per-token query embeddings."""
        # Follow the device the model lives on rather than a notebook-level
        # DEVICE global that is only defined in a later cell.
        device = self.device
        input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)
        Q = self.bert(input_ids, attention_mask=attention_mask)[0]
        Q = self.linear(Q)

        return torch.nn.functional.normalize(Q, p=2, dim=2)

    def doc(self, input_ids, attention_mask, keep_dims=True):
        """Return L2-normalised per-token document embeddings.

        Embeddings at masked positions (see ``mask``) are zeroed. With
        ``keep_dims=False`` the result is a list of float16 CPU tensors, one
        per document, holding only the unmasked token embeddings.
        """
        device = self.device
        input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)
        D = self.bert(input_ids, attention_mask=attention_mask)[0]
        D = self.linear(D)

        mask = torch.tensor(self.mask(input_ids), device=device).unsqueeze(2).float()
        D = D * mask

        D = torch.nn.functional.normalize(D, p=2, dim=2)

        if not keep_dims:
            D, mask = D.cpu().to(dtype=torch.float16), mask.cpu().bool().squeeze(-1)
            D = [d[mask[idx]] for idx, d in enumerate(D)]

        return D

    def score(self, Q, D):
        """Sum, over query tokens, of the maximum similarity to any document token."""
        if self.similarity_metric == 'cosine':
            # Inputs are unit-normalised, so the dot product is the cosine similarity.
            return (Q @ D.permute(0, 2, 1)).max(2).values.sum(1)

        assert self.similarity_metric == 'l2'
        # Negative squared L2 distance, so that a larger value still means "closer".
        return (-1.0 * ((Q.unsqueeze(2) - D.unsqueeze(1))**2).sum(-1)).max(-1).values.sum(-1)

    def mask(self, input_ids):
        """Return nested lists of booleans: True for tokens that are neither skipped nor id 0."""
        # assumes the tokenizer's pad id is 0 — TODO confirm for non-BERT vocabularies
        mask = [[(x not in self.skiplist) and (x != 0) for x in d] for d in input_ids.cpu().tolist()]
        return mask

In [None]:
# Release a previously trained `colbert` (and its GPU memory) when this cell is
# re-run. Only the "not defined yet" case is expected, so only NameError is caught.
try:
    del colbert
except NameError:
    pass  # fresh kernel: nothing to free
else:
    gc.collect()
    torch.cuda.empty_cache()

# Fall back to CPU when no GPU is attached instead of failing in `.to(DEVICE)`.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

colbert = ColBERT.from_pretrained(
    path_model,
    query_maxlen=query_maxlen,
    doc_maxlen=doc_maxlen,
    dim=128,
    similarity_metric='cosine',
    mask_punctuation=False).to(DEVICE)

criterion = torch.nn.CrossEntropyLoss()
# Target class is always 0: column 0 of the reshaped scores is the positive passage.
labels = torch.zeros(bsize, dtype=torch.long, device=DEVICE)
colbert.train()

optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, colbert.parameters()),
    lr=3e-06, eps=1e-8,
    )

N_EPOCHS = 3
training_stats = []
print(f'\n>>> Training for {N_EPOCHS} epochs <<<\n')

# Build the batcher once; tokenizers and data files do not change between epochs.
dataloader_train = LazyBatcher(
    bsize=bsize,
    path=path_base,
    path_tokenizer=path_model,
    query_maxlen=query_maxlen,
    doc_maxlen=doc_maxlen,
    mode='train'
)

# ==============
# ✨ train loop
# ==============
start_time = time.time()
for i in range(1, N_EPOCHS+1):
    # Rewind the batcher: it is its own iterator and does not reset itself.
    dataloader_train.position = 0

    # -------- this is the loop ---------
    train_loss, batch_number = 0.0, 0
    for batch_number, batches in enumerate(dataloader_train, 1):
        print('.', end='')

        optimizer.zero_grad()

        # Each item holds a single (queries, passages) sub-batch.
        queries, passages = batches[0]
        # [2N] scores -> [N, 2]: one (positive, negative) pair per row.
        scores = colbert(queries, passages).view(2, -1).permute(1, 0)
        loss = criterion(scores, labels[:scores.size(0)])

        train_loss += loss.item()

        loss.backward()
        optimizer.step()

    if batch_number == 0:
        raise RuntimeError('the training batcher produced no full batch; check bsize and the data files')

    # Plain float mean (previously a one-element numpy array converted to float).
    avg_loss = train_loss / batch_number

    # -------- save stats per epoch ---------
    training_stats.append(
        {
            'epoch': i,
            'training loss': avg_loss,
        }
    )
    print(f'\nEpoch [{i}/{N_EPOCHS}]:')
    print(f'\ttrain loss: {avg_loss:.4f}')

elapsed = float(time.time() - start_time)

print('.'*batch_number)
print(f'\nTime spent on training: {elapsed/60: .2f} min')

Some weights of the model checkpoint at bert-base-multilingual-uncased were not used when initializing ColBERT: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing ColBERT from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ColBERT from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ColBERT were not initialized from the model checkpoint at bert-base-multilingual-uncased and are newly initialized: ['linear.weight'


>>> Training for 3 epochs <<<

......................................
Epoch [1/3]:
	train loss: 0.2069
......................................
Epoch [2/3]:
	train loss: 0.1254
......................................
Epoch [3/3]:
	train loss: 0.0813
......................................

Time spent on training:  1.19 min


In [None]:
# Per-epoch training statistics, indexed by epoch number.
df_stats = pd.DataFrame(data=training_stats).set_index('epoch')
# Use the full option key: the bare 'precision' pattern also matches
# 'styler.format.precision' on recent pandas and raises an OptionError.
pd.set_option('display.precision', 3)
df_stats

Unnamed: 0_level_0,training loss
epoch,Unnamed: 1_level_1
1,0.207
2,0.125
3,0.081


In [None]:
# Toggle for writing the fine-tuned weights under path_base
SAVE = True

if SAVE:
    # only the state_dict is saved; the file name is fixed to 'EPOCH_3_FAQ' regardless of N_EPOCHS
    torch.save(colbert.state_dict(), path_base+'data/EPOCH_3_FAQ')

In [None]:
# ===================================
# ✨ test load model from checkpoint
# ===================================
# Rebuild the architecture from the base checkpoint, then overwrite its weights
# with the fine-tuned state_dict saved above.
colbert = ColBERT.from_pretrained(
    path_model,
    query_maxlen=query_maxlen,
    doc_maxlen=doc_maxlen,
    dim=128,
    similarity_metric='cosine',
    mask_punctuation=False).to(DEVICE)

# map_location keeps the load working when the checkpoint was written from a
# different device than the one in use now (e.g. saved on GPU, loaded on CPU).
colbert.load_state_dict(torch.load(path_base+'data/EPOCH_3_FAQ', map_location=DEVICE))

Some weights of the model checkpoint at bert-base-multilingual-uncased were not used when initializing ColBERT: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing ColBERT from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ColBERT from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ColBERT were not initialized from the model checkpoint at bert-base-multilingual-uncased and are newly initialized: ['linear.weight'

<All keys matched successfully>