In [10]:
import os
from typing import Dict
import torch
from transformers import AutoModel, AutoTokenizer
from src.constants import SEPARATOR_TOKEN, CLS_TOKEN
from src.utils import pool_and_normalize
from src.datasets_loader import prepare_tokenizer
from src.preprocessing_utils import truncate_sentences
from mteb import MTEB
from abc import ABC, abstractmethod

In [2]:
# Root folder for the MTEB result files. The default is the shared mount this
# notebook was developed against; set the MTEB_OUTPUT_FOLDER environment
# variable to redirect results without editing the notebook.
OUTPUT_FOLDER = os.environ.get(
    "MTEB_OUTPUT_FOLDER", "/mnt/colab_public/datasets/joao/mteb_results/"
)
# Prefer the first GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing when the models are moved.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
# Batch size passed to `evaluation.run` in the evaluation cells.
BATCH_SIZE = 32
# Passed to the encoders as `max_input_len` (forwarded to `truncate_sentences`).
MAX_INPUT_LEN = 10000
# Passed to the encoders as `maximum_token_len` (the tokenizer's `max_length`).
MAX_TOKEN_LEN = 1024

In [3]:
def set_device(inputs: Dict[str, torch.Tensor], device: str) -> Dict[str, torch.Tensor]:
    """Return a new dict holding every tensor of ``inputs`` moved to ``device``.

    Args:
        inputs: Mapping from a name to a tensor (e.g. a tokenizer's output).
        device: Target device, passed to ``Tensor.to`` for each value.

    Returns:
        A fresh plain ``dict`` with the same keys in the same order; the
        ``inputs`` mapping itself is not modified.
    """
    return {name: tensor.to(device) for name, tensor in inputs.items()}

In [4]:
class BaseEncoder(torch.nn.Module, ABC):
    """Shared scaffolding for the sentence encoders evaluated in this notebook.

    Loads a tokenizer and a pretrained model for ``model_name`` and provides
    ``encode``, which batches the input, delegates each batch to the
    subclass's ``forward`` and returns one NumPy embedding per sentence.
    """

    def __init__(self, device, max_input_len, maximum_token_len, model_name):
        """Load the tokenizer and the model, and place the model on ``device``.

        Args:
            device: Device the model is moved to; subclasses also send their
                tokenized inputs there.
            max_input_len: Limit forwarded to ``truncate_sentences`` by
                ``encode``.
            maximum_token_len: Stored for subclasses, which pass it to the
                tokenizer as ``max_length``.
            model_name: Identifier handed to ``prepare_tokenizer`` and
                ``AutoModel.from_pretrained``.
        """
        super().__init__()

        self.model_name = model_name
        self.tokenizer = prepare_tokenizer(model_name)
        # Move the model to the device requested by the caller (not to a
        # module-level global) so it always matches the device the subclasses
        # send their inputs to via ``self.device``.
        self.encoder = AutoModel.from_pretrained(model_name, use_auth_token=True).to(device).eval()
        self.device = device
        self.max_input_len = max_input_len
        self.maximum_token_len = maximum_token_len

    @abstractmethod
    def forward(self, input_sentences):
        """Embed one batch of sentences.

        Implementations must return a tensor whose first dimension indexes the
        sentences of the batch, because ``encode`` concatenates the per-batch
        results along that dimension.
        """

    def encode(self, input_sentences, batch_size=32, **kwargs):
        """Embed ``input_sentences`` in batches and return NumPy arrays.

        Args:
            input_sentences: Sequence of sentences to embed.
            batch_size: Number of sentences per ``forward`` call.
            **kwargs: Accepted and ignored, so callers may pass extra options.

        Returns:
            A list with one NumPy array per input sentence (an empty list when
            there is nothing to encode).

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        # NOTE(review): presumably shortens each sentence to ``max_input_len``
        # characters -- confirm against src.preprocessing_utils.
        truncated_input_sentences = truncate_sentences(input_sentences, self.max_input_len)

        n_sentences = len(truncated_input_sentences)
        if n_sentences == 0:
            # torch.cat rejects an empty list, so short-circuit here.
            return []

        embedding_batch_list = []

        # Inference only: no autograd graph is needed for any of the batches.
        with torch.no_grad():
            for start_idx in range(0, n_sentences, batch_size):
                # Slicing clamps at the end, so the last batch may be shorter.
                batch = truncated_input_sentences[start_idx:start_idx + batch_size]
                embedding_batch_list.append(self.forward(batch).detach().cpu())

        input_sentences_embedding = torch.cat(embedding_batch_list)

        return [emb.squeeze().numpy() for emb in input_sentences_embedding]

class BigCodeEncoder(BaseEncoder):
    """Encoder wrapper around the ``bigcode/bigcode-encoder`` checkpoint."""

    def __init__(self, device, max_input_len, maximum_token_len):
        """Initialise the base encoder with the BigCode checkpoint name."""
        super().__init__(
            device,
            max_input_len,
            maximum_token_len,
            model_name="bigcode/bigcode-encoder",
        )

    def forward(self, input_sentences):
        """Embed one batch of sentences.

        Each sentence is wrapped as ``CLS_TOKEN + sentence + SEPARATOR_TOKEN``
        and tokenized (padded to the longest item, truncated to
        ``self.maximum_token_len`` tokens). The batch is then run through the
        encoder, and the last entry of ``hidden_states`` is reduced with
        ``pool_and_normalize``.
        """
        wrapped_sentences = [
            f"{CLS_TOKEN}{sentence}{SEPARATOR_TOKEN}" for sentence in input_sentences
        ]
        tokenized = self.tokenizer(
            wrapped_sentences,
            padding="longest",
            max_length=self.maximum_token_len,
            truncation=True,
            return_tensors="pt",
        )

        model_inputs = set_device(tokenized, self.device)
        model_outputs = self.encoder(**model_inputs)

        # The mask handed to the pooling helper is the tokenizer's own output,
        # not the copy that was moved to ``self.device``.
        return pool_and_normalize(model_outputs.hidden_states[-1], tokenized.attention_mask)

class CodeBERT(BaseEncoder):
    """Encoder wrapper around the ``microsoft/codebert-base`` checkpoint."""

    def __init__(self, device, max_input_len, maximum_token_len):
        """Initialise the base encoder, then swap in the stock tokenizer."""
        super().__init__(
            device,
            max_input_len,
            maximum_token_len,
            model_name="microsoft/codebert-base",
        )

        # Replace the tokenizer built by the base class with the one loaded
        # directly through AutoTokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")

    def forward(self, input_sentences):
        """Embed one batch of sentences with the model's ``pooler_output``.

        The sentences are tokenized as-is (padded to the longest item,
        truncated to ``self.maximum_token_len`` tokens), moved to
        ``self.device`` and passed to the encoder.
        """
        tokenized = self.tokenizer(
            list(input_sentences),
            padding="longest",
            max_length=self.maximum_token_len,
            truncation=True,
            return_tensors="pt",
        )
        on_device = set_device(tokenized, self.device)

        # Only the ids and the mask are forwarded, positionally.
        model_outputs = self.encoder(on_device["input_ids"], on_device["attention_mask"])

        return model_outputs["pooler_output"]



In [5]:
# Build both encoders with the same device and length limits.
codebert = CodeBERT(
    device=DEVICE,
    max_input_len=MAX_INPUT_LEN,
    maximum_token_len=MAX_TOKEN_LEN,
)
bigcode_model = BigCodeEncoder(
    device=DEVICE,
    max_input_len=MAX_INPUT_LEN,
    maximum_token_len=MAX_TOKEN_LEN,
)

Some weights of the model checkpoint at bigcode/bigcode-encoder were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
# Smoke test: one natural-language string and one code snippet.
input_sentences = ["Hello world!!", "def my_sum(a, b): return a+b"]

# Both calls rely on the default batch size of `encode`.
codebert_embeddings = codebert.encode(input_sentences)
bigcode_model_embeddings = bigcode_model.encode(input_sentences)


In [12]:
# Only the BiorxivClusteringS2S task is selected here. A broader alternative:
# evaluation = MTEB(task_types=['Clustering', 'Retrieval'], task_categories=['s2s'])
evaluation = MTEB(tasks=["BiorxivClusteringS2S"])

In [8]:
# Run the selected MTEB tasks on the BigCode encoder; `output_folder` points at
# the "bigcode_encoder" sub-folder of OUTPUT_FOLDER.
results_bigcode_encoder = evaluation.run(
    bigcode_model,
    overwrite_results=True,
    batch_size=BATCH_SIZE,
    output_folder=os.path.join(OUTPUT_FOLDER, "bigcode_encoder"),
)

# Last expression of the cell, so the result dict is displayed.
results_bigcode_encoder

{'BiorxivClusteringS2S': {'mteb_version': '1.0.1',
  'dataset_revision': '258694dd0231531bc1fd9de6ceb52a0853c6d908',
  'mteb_dataset_name': 'BiorxivClusteringS2S',
  'test': {'v_measure': 0.15253719531046367,
   'v_measure_std': 0.006991505947132362,
   'evaluation_time': 127.73}}}

In [None]:
# Run the selected MTEB tasks on CodeBERT; `output_folder` points at the
# "codebert" sub-folder of OUTPUT_FOLDER.
results_codebert = evaluation.run(
    codebert,
    overwrite_results=True,
    batch_size=BATCH_SIZE,
    output_folder=os.path.join(OUTPUT_FOLDER, "codebert"),
)

# Last expression of the cell, so the result dict is displayed.
results_codebert