In [None]:
pip install sacremoses

In [17]:
# Load pretrained tokenizer/model pairs from the Hugging Face model repository.
# Import here too: the notebook's main import cell comes later, so without this
# line the cell fails with NameError on a fresh kernel (Restart & Run All).
from transformers import AutoModel, AutoTokenizer

# PubMedBERT sentence-embedding model.
# (An alternative that was tried: "sentence-transformers/all-MiniLM-L6-v2".)
modelName = "neuml/pubmedbert-base-embeddings"
pubmedbert_tokenizer = AutoTokenizer.from_pretrained(modelName)
pubmedbert_model = AutoModel.from_pretrained(modelName)

# BioGPT; its tokenizer requires the `sacremoses` package installed above.
modelName = "microsoft/biogpt"
biogpt_tokenizer = AutoTokenizer.from_pretrained(modelName)
biogpt_model = AutoModel.from_pretrained(modelName)

In [19]:
from transformers import AutoTokenizer, AutoModel, BioGptModel
from sentence_transformers import SentenceTransformer
import torch


def mean_pooling(model_output, attention_mask):
    """Average token embeddings per sequence, weighting each position by the attention mask.

    Args:
        model_output: Model output whose first element holds the token
            embeddings, shaped like ``attention_mask`` plus one trailing
            hidden dimension.
        attention_mask: Per-token weights; presumably 1 for real tokens and
            0 for padding (the usual tokenizer convention).

    Returns:
        The masked mean over dimension 1 (the token axis). The denominator is
        clamped to 1e-9, so a fully masked sequence yields zeros, not NaN.
    """
    token_vectors = model_output[0]
    # Broadcast the mask across the hidden dimension so it weights every component.
    mask = attention_mask.unsqueeze(-1).expand(token_vectors.size()).float()
    token_counts = mask.sum(1).clamp(min=1e-9)
    masked_totals = (token_vectors * mask).sum(1)
    return masked_totals / token_counts

def embeddings(model, tokenizer, sentences, max_length=128):
    """Encode ``sentences`` into one fixed-size vector each using mean pooling.

    Args:
        model: A Hugging Face model whose first output holds per-token
            embeddings (e.g. an ``AutoModel`` instance).
        tokenizer: The tokenizer matching ``model``.
        sentences: A string or list of strings to embed.
        max_length: Inputs longer than this many tokens are truncated.
            Defaults to 128, the value previously hard-coded here.

    Returns:
        A tensor with one pooled row per input sentence (token axis averaged
        out by ``mean_pooling``), on the same device as the model inputs.
    """
    encoded_input = tokenizer(
        sentences,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )

    # Keep inputs on the model's device so this also works after model.to("cuda");
    # on a CPU model this is a no-op.
    device = getattr(model, "device", None)
    if device is not None:
        encoded_input = encoded_input.to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        model_output = model(**encoded_input)

    # The attention mask excludes padding positions from the average.
    return mean_pooling(model_output, encoded_input["attention_mask"])


In [20]:
# Sentences we want sentence embeddings for
sentences = [
    "This framework generates embeddings for each input sentence",
    "Sentences are passed as a list of string.",
    "The quick brown fox jumps over the lazy dog.",
]

# Each model maps sentences into its own embedding space, so compare sentences
# only within one model's output, never across the two tensors.
biogpt_embeddings = embeddings(biogpt_model, biogpt_tokenizer, sentences)
pubmedbert_embeddings = embeddings(pubmedbert_model, pubmedbert_tokenizer, sentences)

# Sanity checks: one row per input sentence and no NaN/inf from pooling.
# Label each printout so the two tensors can be told apart in the output.
for label, vectors in (("BioGPT", biogpt_embeddings), ("PubMedBERT", pubmedbert_embeddings)):
    assert vectors.shape[0] == len(sentences), (
        f"{label}: expected {len(sentences)} rows, got {vectors.shape[0]}"
    )
    assert torch.isfinite(vectors).all(), f"{label}: non-finite values in embeddings"
    print(f"{label} embeddings, shape {tuple(vectors.shape)}:")
    print(vectors)

tensor([[ 0.5246,  0.1851,  1.5204,  ...,  0.0856,  0.0222,  0.9674],
        [ 0.0245, -1.2456,  0.5576,  ..., -0.6955,  1.2123,  0.1388],
        [-0.7520,  0.5984,  0.2248,  ..., -0.1149,  0.9079,  0.3904]])
tensor([[-0.6495,  0.5809, -0.4687,  ..., -0.1336, -0.7726, -0.2508],
        [-0.9131,  0.6357,  0.5418,  ..., -0.2846, -0.9001,  0.1210],
        [-0.5454,  0.9630, -0.4064,  ...,  0.1175, -1.1538, -0.1450]])
