In [9]:
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel

def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Mean-pool hidden states over dim 1, counting only unmasked positions.

    Args:
        last_hidden_states: Tensor that is summed over dim 1; presumably
            (batch, seq_len, hidden) from a transformer encoder — TODO confirm.
        attention_mask: Mask with one entry per position of the leading two
            dims of ``last_hidden_states``. Positions where it is zero are
            zeroed out before summing.

    Returns:
        The sum over dim 1 of the unmasked hidden states divided by the
        per-row sum of ``attention_mask``. A row whose mask sums to zero
        yields zeros instead of NaN.
    """
    keep = attention_mask[..., None].bool()
    summed = last_hidden_states.masked_fill(~keep, 0.0).sum(dim=1)

    counts = attention_mask.sum(dim=1)[..., None]
    # A fully masked row would otherwise compute 0 / 0 = NaN; dividing its
    # all-zero sum by 1 gives a zero vector. Non-zero counts are untouched.
    counts = counts.masked_fill(counts == 0, 1)

    return summed / counts

# Loaded (tokenizer, model) pairs keyed by checkpoint name, so repeated calls
# to embed_string do not reload the weights each time.
_E5_CACHE = {}


def _load_e5(model_name: str):
    """Return the (tokenizer, model) pair for ``model_name``, loading it on first use."""
    if model_name not in _E5_CACHE:
        _E5_CACHE[model_name] = (
            AutoTokenizer.from_pretrained(model_name),
            AutoModel.from_pretrained(model_name),
        )
    return _E5_CACHE[model_name]


def embed_string(text: str, model_name: str = 'intfloat/e5-large-v2', prefix: str = 'query: '):
    """Embed a single string and return the L2-normalized embedding as a nested list.

    Args:
        text: The string to embed.
        model_name: Checkpoint name passed to ``AutoTokenizer`` / ``AutoModel``.
        prefix: Text prepended to ``text`` before tokenization. The default
            ``'query: '`` is what this function has always prepended.
            NOTE(review): the E5 model card reportedly also defines a
            ``'passage: '`` prefix for documents — confirm before using it.

    Returns:
        ``embeddings.tolist()`` of the pooled, normalized tensor: a list
        containing one inner list of floats for the single input string.
    """
    # Function-scope import: the notebook's import cell binds only `F` and
    # `Tensor`, not the name `torch` itself.
    import torch

    tokenizer, model = _load_e5(model_name)

    # Inputs longer than 512 tokens are truncated by the tokenizer.
    inputs = tokenizer(prefix + text, max_length=512, padding=True, truncation=True, return_tensors='pt')

    # Only the forward activations are needed, so skip building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)

    # Mask-aware mean over the token dimension, then scale each row to unit L2 norm.
    embeddings = average_pool(outputs.last_hidden_state, inputs['attention_mask'])
    embeddings = F.normalize(embeddings, p=2, dim=1)

    return embeddings.tolist()

# Example query used to exercise embed_string.
text = "How much protein should a female eat?"
embeddings_list = embed_string(text)

# embed_string returns a nested list, so this prints the entire embedding vector.
print("Embeddings: ", embeddings_list)


Embeddings:  [[0.016367772594094276, -0.07768592238426208, 0.025325743481516838, -0.006067469250410795, -0.036973737180233, 0.017542056739330292, -0.032574061304330826, -0.07031966000795364, -0.031784724444150925, 0.0404752679169178, 0.04285208880901337, -0.02163163386285305, 0.03059859201312065, 0.018854863941669464, 0.004938546568155289, -0.0019983206875622272, -0.026725780218839645, -0.009781244210898876, 0.004824492614716291, -0.04963250458240509, 0.03962915018200874, -0.05389925092458725, -0.016336366534233093, 0.0413457490503788, -0.03457082808017731, -0.00037096842424944043, -0.011879738420248032, -0.0014438929501920938, -0.015483500435948372, -0.011211621575057507, 0.01299666054546833, -0.002892272314056754, 0.06803121417760849, -0.03911842778325081, 0.048065416514873505, -0.007100287359207869, -0.05535821616649628, -0.02251799777150154, 0.037919607013463974, 0.029768234118819237, 0.00014554608787875623, 0.029410487040877342, -0.05231502279639244, 0.005632266402244568, 0.001865

In [10]:
# Embed the example query and keep the result under the name `embeddings`,
# which the print(embeddings) cell reads.
text = "How much protein should a female eat?"
embeddings = embed_string(text)


In [8]:
# Print the raw nested list returned by embed_string.
# NOTE(review): this cell's execution count (8) is lower than that of the cell
# that assigns `embeddings` (10), which suggests the saved output came from an
# out-of-order run — re-run the notebook top to bottom to confirm.
print(embeddings)

[[0.016367772594094276, -0.07768592238426208, 0.025325743481516838, -0.006067469250410795, -0.036973737180233, 0.017542056739330292, -0.032574061304330826, -0.07031966000795364, -0.031784724444150925, 0.0404752679169178, 0.04285208880901337, -0.02163163386285305, 0.03059859201312065, 0.018854863941669464, 0.004938546568155289, -0.0019983206875622272, -0.026725780218839645, -0.009781244210898876, 0.004824492614716291, -0.04963250458240509, 0.03962915018200874, -0.05389925092458725, -0.016336366534233093, 0.0413457490503788, -0.03457082808017731, -0.00037096842424944043, -0.011879738420248032, -0.0014438929501920938, -0.015483500435948372, -0.011211621575057507, 0.01299666054546833, -0.002892272314056754, 0.06803121417760849, -0.03911842778325081, 0.048065416514873505, -0.007100287359207869, -0.05535821616649628, -0.02251799777150154, 0.037919607013463974, 0.029768234118819237, 0.00014554608787875623, 0.029410487040877342, -0.05231502279639244, 0.005632266402244568, 0.0018655768362805247