In [5]:
from transformers import AutoTokenizer, AutoModel
import torch
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
    
#Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask


#Load AutoModel from huggingface model repository
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/bert-base-nli-mean-tokens")
model = AutoModel.from_pretrained("sentence-transformers/bert-base-nli-mean-tokens").to(DEVICE)


In [15]:
# =============================
#  Input should be a list of sentences.
# =============================

def get_doc_emb(doc_text):
    global tokenizer
    global model
    global DEVICE
    
    encoded_input = tokenizer(sentences, padding=True, truncation=True, max_length=512, return_tensors='pt').to(DEVICE)
    with torch.no_grad():
        model_output = model(**encoded_input)

    # Perform pooling. In this case, mean pooling
    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
    # Take a mean of the sentences 
    return torch.mean(sentence_embeddings,dim=-2)


In [16]:

#Sentences we want sentence embeddings for
sentences = ['This framework generates embeddings for each input sentence',
             'Sentences are passed as a list of string.',
             'The quick brown fox jumps over the lazy dog.']


get_doc_emb(sentences)

tensor([-3.2809e-01,  3.5750e-01,  4.1518e-01,  1.6900e-01,  8.3606e-02,
         7.3211e-02, -1.5879e-01, -1.9684e-02, -1.9021e-01, -2.7730e-01,
        -6.0024e-01,  3.9147e-01,  1.7148e-01, -1.0057e-01,  2.3102e-02,
        -2.5747e-01, -8.4040e-02, -2.2876e-01,  1.7119e-01, -4.9106e-01,
        -2.7618e-01,  6.3767e-03, -9.7387e-01, -7.8169e-01,  7.6970e-01,
        -3.0025e-01,  1.9407e-01, -1.6629e-01, -8.6983e-01,  2.1038e-01,
        -3.9116e-01,  9.4829e-02,  8.0652e-01, -2.1120e-01, -2.6125e-01,
        -6.1115e-02,  5.6831e-01, -5.0122e-02,  1.3548e-01, -6.2303e-01,
         4.1839e-01, -1.0829e-01,  2.7848e-01, -6.3659e-02, -1.1550e+00,
        -3.1414e-01, -6.4800e-01, -3.0956e-02, -1.4968e-02, -1.5882e-01,
        -8.7936e-01,  2.1938e-01,  2.3919e-01,  1.1233e-01, -6.9157e-01,
         1.7533e-01,  5.2621e-01, -5.2849e-01, -6.3219e-02,  7.1167e-02,
        -2.8902e-01, -2.2216e-01,  2.6226e-01,  3.8673e-01, -4.7657e-01,
         1.8633e-01,  4.0936e-01, -2.1928e-01, -9.4

In [3]:
DEVICE

<torch.cuda.device at 0x7f524114a610>

In [20]:
mean_pool_layer = torch.nn.AvgPool1d(model.embeddings.word_embeddings.embedding_dim)

In [37]:
torch.mean(torch.stack( [sentence_embeddings,sentence_embeddings], dim=0),dim=-2).shape

torch.Size([2, 768])