In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForPreTraining
from src.constants import SEPARATOR_TOKEN, CLS_TOKEN
from src.utils import pool_and_normalize
from src.datasets_loader import prepare_tokenizer
from mteb import MTEB
# from huggingface_hub import notebook_login

In [3]:
# Directory passed as `output_folder` to `MTEB.run` in the evaluation cell of this notebook.
OUTPUT_FOLDER = "/mnt/colab_public/datasets/joao/mteb_results/"

In [4]:
class embedding_model(torch.nn.Module):
    """Sentence-embedding wrapper around the ``bigcode/bigcode-encoder`` checkpoint.

    Bundles the tokenizer and the encoder so that a list of strings can be
    turned into pooled, normalized embeddings. ``encode`` returns one NumPy
    vector per input, which is the form in which this notebook hands the
    model to ``MTEB.run``.
    """

    def __init__(self, ):
        super().__init__()

        _model_name = "bigcode/bigcode-encoder"
        # Both downloads are requested with use_auth_token=True.
        self.tokenizer = prepare_tokenizer(AutoTokenizer.from_pretrained(_model_name, use_auth_token=True))
        self.encoder = AutoModelForPreTraining.from_pretrained(_model_name, use_auth_token=True)

    def forward(self, input_sentences):
        """Embed one batch of sentences.

        Args:
            input_sentences: iterable of strings; each is wrapped as
                ``CLS_TOKEN + sentence + SEPARATOR_TOKEN`` before tokenization.

        Returns:
            The result of ``pool_and_normalize`` applied to the encoder's last
            hidden state and the attention mask (a ``[2, 768]`` tensor for the
            two-sentence example in this notebook).
        """
        inputs = self.tokenizer(
            [f"{CLS_TOKEN}{sentence}{SEPARATOR_TOKEN}" for sentence in input_sentences],
            return_tensors="pt",
            padding=True,
        )
        # The tokenizer does not know where the encoder lives; move the batch
        # to the encoder's device so the module keeps working after e.g.
        # `model.to("cuda")`.
        inputs = inputs.to(next(self.encoder.parameters()).device)
        outputs = self.encoder(**inputs)

        return pool_and_normalize(outputs.hidden_states[-1], inputs.attention_mask)

    def encode(self, input_sentences, batch_size=32, **kwargs):
        """Embed sentences batch by batch with gradient tracking disabled.

        Args:
            input_sentences: sized, sliceable sequence of strings.
            batch_size: maximum number of sentences per forward pass (>= 1).
            **kwargs: accepted for caller compatibility and ignored.

        Returns:
            A list with one NumPy array per input sentence, in input order.
            An empty input yields an empty list.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        n_sentences = len(input_sentences)
        if n_sentences == 0:
            # torch.cat refuses an empty list, so short-circuit here.
            return []

        embedding_batch_list = []
        with torch.no_grad():
            for start_idx in range(0, n_sentences, batch_size):
                # Slicing clamps at the end, so the last batch may be shorter.
                batch = input_sentences[start_idx:start_idx + batch_size]
                embedding_batch_list.append(self.forward(batch).detach().cpu())

        input_sentences_embedding = torch.cat(embedding_batch_list)

        return [emb.squeeze().numpy() for emb in input_sentences_embedding]


In [5]:
# Build the wrapper; its __init__ fetches the tokenizer and encoder for
# "bigcode/bigcode-encoder" via from_pretrained(..., use_auth_token=True).
model = embedding_model()

In [6]:
# Smoke test: one natural-language sentence and one code snippet, embedded
# with a direct call to the module (i.e. through `forward`).
input_sentences = ["Hello world!!", "def my_sum(a, b): return a+b"]

embeddings = model(input_sentences)
embeddings.shape

torch.Size([2, 768])

In [7]:
# Consistency check: the batched `encode` path must reproduce exactly what the
# direct call produced for the same sentences.
embeddings_2 = torch.stack(
    [torch.from_numpy(row) for row in model.encode(input_sentences)]
)

assert (embeddings == embeddings_2).all().item()

In [None]:
# Select the MTEB tasks of type Clustering and Retrieval in the "s2s" category
# and run them against `model`, passing OUTPUT_FOLDER as the output folder.
evaluation = MTEB(
    task_types=["Clustering", "Retrieval"],
    task_categories=["s2s"],
)
results = evaluation.run(model, output_folder=OUTPUT_FOLDER)