## Sentence Transformers
* Model card: https://huggingface.co/jhgan/ko-sroberta-multitask
* Source repository: https://github.com/jhgan00/ko-sentence-transformers

In [1]:
import torch
from sentence_transformers import SentenceTransformer

# Pick the best available accelerator: CUDA GPU first, then Apple-silicon MPS,
# otherwise fall back to the CPU. Each backend is selected only when its own
# availability check passes (the previous code tested CUDA but then asked for MPS).
# The getattr guard keeps this working on PyTorch builds without torch.backends.mps.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

sbert_model_name = "jhgan/ko-sroberta-multitask"
# Hand the device to the constructor through its `device` argument instead of
# moving the model afterwards with .to().
model = SentenceTransformer(sbert_model_name, device=str(device))

# Two sample sentences; encode() returns one embedding row per sentence as a NumPy array.
sentences = ["안녕하세요?", "한국어 문장 임베딩을 위한 버트 모델입니다."]
embedding = model.encode(sentences, convert_to_numpy=True)

embedding

array([[-0.37510464, -0.7733839 ,  0.5927711 , ...,  0.57923526,
         0.32683483, -0.6508965 ],
       [-0.09361704, -0.18191524, -0.19230816, ..., -0.03165802,
         0.30412534, -0.2679362 ]], dtype=float32)

## HuggingFace Transformers

In [2]:
from transformers import AutoTokenizer, AutoModel
import torch


# Mean pooling: average token embeddings, weighted by the attention mask so padded positions are ignored
def mean_pooling(model_output, attention_mask):
    """Collapse per-token embeddings into one vector per sequence.

    Args:
        model_output: Indexable model result whose first element holds the
            token embeddings (assumed shape (batch, seq_len, hidden) — confirm
            against the model in use).
        attention_mask: Mask with one entry per token; it is given a trailing
            axis and expanded to the token-embedding shape before use.

    Returns:
        The mask-weighted sum over dimension 1 divided by the mask total over
        dimension 1.
    """
    token_vectors = model_output[0]
    # Broadcast the mask to the embedding shape so it can weight every feature.
    mask = attention_mask.unsqueeze(-1).expand(token_vectors.size()).float()
    weighted_sum = (token_vectors * mask).sum(1)
    # Clamp the divisor so a fully masked row does not divide by zero.
    token_counts = torch.clamp(mask.sum(1), min=1e-9)
    return weighted_sum / token_counts


# Input texts to embed
sentences = ['This is an example sentence', 'Each sentence is converted']

# Fetch the tokenizer and the encoder weights for the same Hub checkpoint
hf_model_name = 'jhgan/ko-sroberta-multitask'
tokenizer = AutoTokenizer.from_pretrained(hf_model_name)
model = AutoModel.from_pretrained(hf_model_name)

# Tokenize the whole list as one padded, truncated batch of PyTorch tensors
encoded_input = tokenizer(
    sentences,
    padding=True,
    truncation=True,
    return_tensors='pt',
)

# Forward pass with gradient tracking disabled
with torch.no_grad():
    model_output = model(**encoded_input)

# Reduce the token embeddings to one vector per sentence via mask-aware mean pooling
sentence_embeddings = mean_pooling(
    model_output,
    encoded_input['attention_mask'],
)
sentence_embeddings

tensor([[ 0.3026, -0.5754,  0.2507,  ...,  0.4219, -0.0588, -0.5399],
        [-0.2494, -0.2275, -0.2123,  ..., -0.1167, -0.0520, -0.4261]])