# Using Our Margin-MSE trained ColBERT Checkpoint

The next step is to download our checkpoint and initialize the tokenizer and models:


In [18]:
from transformers import AutoTokenizer,AutoModel, PreTrainedModel,PretrainedConfig
from typing import Dict
import torch

class ColBERTConfig(PretrainedConfig):
    model_type = "ColBERT"
    bert_model: str
    compression_dim: int = 768
    dropout: float = 0.0
    return_vecs: bool = False
    trainable: bool = True

class ColBERT(PreTrainedModel):
    """
    ColBERT model from: https://arxiv.org/pdf/2004.12832.pdf
    We use a dot-product instead of cosine per term (slightly better)
    """
    config_class = ColBERTConfig
    base_model_prefix = "bert_model"

    def __init__(self,
                 cfg) -> None:
        super().__init__(cfg)
        
        self.bert_model = AutoModel.from_pretrained(cfg.bert_model)

        for p in self.bert_model.parameters():
            p.requires_grad = cfg.trainable

        self.compressor = torch.nn.Linear(self.bert_model.config.hidden_size, cfg.compression_dim)

    def forward(self,
                query: Dict[str, torch.LongTensor],
                document: Dict[str, torch.LongTensor]):

        query_vecs = self.forward_representation(query)
        document_vecs = self.forward_representation(document)

        score = self.forward_aggregation(query_vecs,document_vecs,query["attention_mask"],document["attention_mask"])
        return score

    def forward_representation(self,
                               tokens,
                               sequence_type=None) -> torch.Tensor:
        
        vecs = self.bert_model(**tokens)[0] # assuming a distilbert model here
        vecs = self.compressor(vecs)

        # if encoding only, zero-out the mask values so we can compress storage
        if sequence_type == "doc_encode" or sequence_type == "query_encode": 
            vecs = vecs * tokens["tokens"]["mask"].unsqueeze(-1)

        return vecs

    def forward_aggregation(self,query_vecs, document_vecs,query_mask,document_mask):
        
        # create initial term-x-term scores (dot-product)
        score = torch.bmm(query_vecs, document_vecs.transpose(2,1))

        # mask out padding on the doc dimension (mask by -1000, because max should not select those, setting it to 0 might select them)
        exp_mask = document_mask.bool().unsqueeze(1).expand(-1,score.shape[1],-1)
        score[~exp_mask] = - 10000

        # max pooling over document dimension
        score = score.max(-1).values

        # mask out paddding query values
        score[~(query_mask.bool())] = 0

        # sum over query values
        score = score.sum(-1)

        return score

#
# init the model & tokenizer (using the distilbert tokenizer)
#
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased") # honestly not sure if that is the best way to go, but it works :)
model = ColBERT.from_pretrained("sebastian-hofstaetter/colbert-distilbert-margin_mse-T2-msmarco")

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_projector.weight', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Now we are ready to use the model to encode two sample passage and query pairs (this would be the re-ranking mode, where we have a candidate list - for indexing or pre-compute mode you need to call forward_representation and forward_aggregation independently):

In [31]:
# our relevant example
passage1_input = tokenizer("We are very happy to show you the 🤗 Transformers library for pre-trained language models. We are helping the community work together towards the goal of advancing NLP 🔥.",return_tensors="pt")
# a non-relevant example
passage2_input = tokenizer("Hmm I don't like this new movie about transformers that i got from my local library. Those transformers are robots?",return_tensors="pt")

# the user query -> which should give us a better score for the first passage
query_input = tokenizer("what is the transformers library")
# adding the mask augmentation, we used 8 as the fixed number for training regardless of batch-size
# it has a somewhat (although not huge) positive impact on effectiveness, we hypothesize that might be due to the increased
# capacity of the query encoding, not so much because of the [MASK] pre-training, but who knows :)
query_input.input_ids += [103] * 8 # [MASK]
query_input.attention_mask += [1] * 8

query_input["input_ids"] = torch.LongTensor(query_input.input_ids).unsqueeze(0)
query_input["attention_mask"] = torch.LongTensor(query_input.attention_mask).unsqueeze(0)

#print("Passage 1 Tokenized:",passage1_input)
#print("Passage 2 Tokenized:",passage2_input)
#print("Query Tokenized:",query_input)

# note how we call the bert model for pairs, can be changed to: forward_representation and forward_aggregation
score_for_p1 = model.forward(query_input,[passage1_input, passage2_input]).squeeze(0)
# score_for_p2 = model.forward(query_input,passage2_input).squeeze(0)

print("---")
print("Score passage 1 <-> query: ",float(score_for_p1))
# print("Score passage 2 <-> query: ",float(score_for_p2))

TypeError: DistilBertModel(
  (embeddings): Embeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (layer): ModuleList(
      (0-5): 6 x TransformerBlock(
        (attention): MultiHeadSelfAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (ffn): FFN(
          (dropout): Dropout(p=0.1, inplace=False)
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Linear(in_features=3072, out_features=768, bias=True)
          (activation): GELUActivation()
        )
        (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      )
    )
  )
) argument after ** must be a mapping, not list

In [25]:
query_input["input_ids"].size()

torch.Size([1, 15])

In [28]:
tokenizer("hi there", return_tensors = "pt")

{'input_ids': tensor([[ 101, 7632, 2045,  102]]), 'attention_mask': tensor([[1, 1, 1, 1]])}

In [30]:
score_for_p1.cpu()

tensor(106.4638, grad_fn=<SqueezeBackward1>)