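#### Sentence Transformation and Embeddings

Encode sentences into dense vectors with the `paraphrase-multilingual-mpnet-base-v2` model: first through the high-level `sentence-transformers` API, then manually with `transformers` plus mean pooling over the token embeddings.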

#### Acknowledgements

- Model card: [sentence-transformers/paraphrase-multilingual-mpnet-base-v2](https://huggingface.co/sentence-transformers/paraphrase-multilingual-mpnet-base-v2)

#### Packages

In [1]:
import sentence_transformers as pkg_sentence_transformers
import transformers as pkg_transformers
import torch as pkg_torch

#### Usage examples

In [2]:
# Hugging Face Hub id of the model; shared by the SentenceTransformer and the raw transformers cells below.
model_path = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"

In [3]:
# Build the high-level SentenceTransformer wrapper for the model identified by `model_path`.
model = pkg_sentence_transformers.SentenceTransformer(
    model_name_or_path=model_path,
)

In [4]:
# Two sample inputs for the SentenceTransformer `encode` demo below.
example_sentences = [
    "This is an example sentence. Each sentence is converted.",
    "Female variant of a King is Queen. Male variant of a Queen is King."
]
# The original, misspelled name is kept as an alias so existing references still resolve.
example_setences = example_sentences

In [5]:
# Embed every example sentence in a single encode() call; the printed array
# has one row per sentence (see the output below).
example_embeddings = model.encode(
    example_setences,
)
print(example_embeddings)

[[ 0.11049473 -0.2600527  -0.01425929 ...  0.03306267 -0.0403426
  -0.14017496]
 [ 0.23600417 -0.12095904 -0.0118967  ...  0.02033678  0.12211031
  -0.0387219 ]]


In [6]:
def mean_pooling(model_output, attention_mask):
    """Average the token embeddings of each sequence, ignoring padding positions.

    Args:
        model_output: Output of a transformers model call; element 0 is used as
            the token embeddings (presumably shape (batch, seq_len, hidden) — TODO confirm).
        attention_mask: Tokenizer attention mask, 1 for real tokens and 0 for padding
            (presumably shape (batch, seq_len) — TODO confirm).

    Returns:
        A float tensor holding one pooled embedding per sequence (dimension 1 is reduced).
    """
    embeddings = model_output[0]
    # Broadcast the mask across the hidden dimension so padded tokens contribute zero.
    weights = attention_mask.unsqueeze(-1).expand_as(embeddings).float()
    summed = (embeddings * weights).sum(dim=1)
    # Clamp the per-sequence token count so an all-padding row divides by 1e-9, not 0.
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts


In [7]:
# Sentences we want sentence embeddings for
sentences = ['This is an example sentence', 'Each sentence is converted']

# Load the tokenizer and the raw transformer from the HuggingFace Hub.
# A distinct name is used so the SentenceTransformer `model` loaded earlier is not
# clobbered (re-running the `model.encode(...)` cell would otherwise fail).
tokenizer = pkg_transformers.AutoTokenizer.from_pretrained(model_path)
hf_model = pkg_transformers.AutoModel.from_pretrained(model_path)

# Tokenize sentences into PyTorch tensors; padding=True pads to the longest sentence in the batch
encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

# Compute token embeddings; gradients are not needed for inference
with pkg_torch.no_grad():
    model_output = hf_model(**encoded_input)

# Perform pooling. In this case, mean pooling over the non-padding tokens.
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
assert sentence_embeddings.shape[0] == len(sentences), "expected one embedding per sentence"

print("Sentence embeddings")
print(sentence_embeddings)

Sentence embeddings
tensor([[ 0.1432, -0.2308, -0.0139,  ...,  0.0399,  0.1009, -0.1994],
        [ 0.0400, -0.2041, -0.0131,  ...,  0.0263, -0.2010, -0.1452]])
