## Embedding model 의 문장 유사도 테스트
### [KoSimCSE-supervised-roberta-large](https://huggingface.co/daekeun-ml/KoSimCSE-supervised-roberta-large)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
from transformers import AutoModel, AutoTokenizer, logging

class SimCSEConfig(PretrainedConfig):
    """Configuration class used by ``SimCSEModel`` (its ``config_class``).

    Adds a ``version`` attribute; every other keyword argument is forwarded
    unchanged to ``PretrainedConfig``.
    """

    def __init__(self, version: float = 1.0, **kwargs) -> None:
        """Store ``version`` and hand the remaining kwargs to the base class.

        Args:
            version: Value recorded on the config as ``self.version``.
            **kwargs: Passed through to ``PretrainedConfig.__init__``.
        """
        self.version = version
        super().__init__(**kwargs)

class SimCSEModel(PreTrainedModel):
    """Sentence encoder that wraps a pretrained backbone.

    The sentence embedding is the backbone's last hidden state at sequence
    position 0. A dense + tanh head is applied on top of that vector only
    while the module is in training mode; otherwise the raw vector is
    returned.
    """

    config_class = SimCSEConfig

    def __init__(self, config):
        """Build the backbone and the training-time projection head.

        Args:
            config: Model configuration. It must expose a ``base_model``
                attribute, which is passed to ``AutoModel.from_pretrained``
                to create the backbone.
        """
        super().__init__(config)
        self.backbone = AutoModel.from_pretrained(config.base_model)
        self.hidden_size: int = self.backbone.config.hidden_size
        self.dense = nn.Linear(self.hidden_size, self.hidden_size)
        self.activation = nn.Tanh()

    def forward(
        self,
        input_ids: Tensor,
        attention_mask: Tensor = None,
        # RoBERTa variants don't have token_type_ids, so this argument is optional
        token_type_ids: Tensor = None,
    ) -> Tensor:
        """Encode a batch of token ids into one vector per sequence.

        Args:
            input_ids: Token ids, shape ``(batch_size, seq_len)``.
            attention_mask: Optional mask, shape ``(batch_size, seq_len)``.
            token_type_ids: Optional segment ids, forwarded to the backbone
                as-is (``None`` when the tokenizer does not produce them).

        Returns:
            The hidden state at position 0 of every sequence. In training
            mode it is additionally passed through ``dense`` and ``tanh``.
        """
        # The local variable is intentionally left unannotated: the backbone's
        # output class is not imported in this file, and naming it here would
        # be an undefined-name reference.
        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        # Keep only the first token's hidden state for each sequence.
        emb = outputs.last_hidden_state[:, 0]

        # The projection head is used during training only.
        if self.training:
            emb = self.activation(self.dense(emb))

        return emb

# Load pre-trained model
# The model and the tokenizer are both loaded from the same checkpoint id,
# so the tokenizer's vocabulary matches the weights being loaded.
model = SimCSEModel.from_pretrained("daekeun-ml/KoSimCSE-supervised-roberta-large")
tokenizer = AutoTokenizer.from_pretrained("daekeun-ml/KoSimCSE-supervised-roberta-large")


  from .autonotebook import tqdm as notebook_tqdm
config.json: 100%|██████████| 582/582 [00:00<00:00, 1.06MB/s]
model.safetensors: 100%|██████████| 1.35G/1.35G [01:04<00:00, 20.9MB/s]
config.json: 100%|██████████| 547/547 [00:00<00:00, 1.35MB/s]
model.safetensors: 100%|██████████| 1.35G/1.35G [00:24<00:00, 54.2MB/s]
Some weights of RobertaModel were not initialized from the model checkpoint at klue/roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
tokenizer_config.json: 100%|██████████| 415/415 [00:00<00:00, 561kB/s]
vocab.txt: 100%|██████████| 248k/248k [00:00<00:00, 507kB/s]
tokenizer.json: 100%|██████████| 752k/752k [00:00<00:00, 998kB/s] 
special_tokens_map.json: 100%|██████████| 173/173 [00:00<00:00, 797kB/s]


tensor([[92.9861]], grad_fn=<MulBackward0>) tensor([[83.0575]], grad_fn=<MulBackward0>)


In [3]:
def show_embedding_score(tokenizer, model, sentences):
    """Print pairwise similarity scores and report the most similar pair.

    The sentences are tokenized as one padded batch, encoded by ``model``,
    and every unordered pair is scored with ``cal_score``. The scores are
    printed on one line in pair order (1-2, 1-3, ..., 2-3, ...), followed by
    a line naming the best-scoring pair using 1-based sentence numbers.

    Args:
        tokenizer: Callable that turns ``sentences`` into keyword arguments
            for ``model`` (invoked with ``padding=True, truncation=True,
            return_tensors="pt"``).
        model: Callable returning one embedding row per input sentence.
        sentences: Two or more sentences to compare.

    Raises:
        ValueError: If fewer than two embeddings are produced, since there
            is no pair to compare.
    """
    from itertools import combinations

    inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")

    # Pure inference: no autograd graph is needed for printing scores.
    with torch.no_grad():
        embeddings = model(**inputs)

    num_sentences = embeddings.shape[0]
    if num_sentences < 2:
        raise ValueError("at least two sentences are required to compare similarity")

    # combinations() yields (0, 1), (0, 2), (1, 2), ... in lexicographic order.
    pairs = list(combinations(range(num_sentences), 2))
    scores = [
        cal_score(embeddings[i, :], embeddings[j, :]).item()
        for i, j in pairs
    ]

    print(*scores)

    # max() keeps the first maximal element, so ties go to the earliest pair.
    best = max(range(len(scores)), key=scores.__getitem__)
    first, second = pairs[best]
    print(f"{first + 1},{second + 1}번 문장이 {scores[best]}로 가장 유사합니다.")

def cal_score(a, b):
    """Return the cosine similarity between rows of ``a`` and ``b``, times 100.

    Args:
        a: Tensor of shape ``(dim,)`` or ``(n, dim)``.
        b: Tensor of shape ``(dim,)`` or ``(m, dim)``.

    Returns:
        Tensor of shape ``(n, m)`` (1-D inputs count as a single row) whose
        entry ``[i, j]`` is ``100 * cos(a[i], b[j])``. A zero vector scores 0
        against everything instead of producing NaN.
    """
    if a.dim() == 1:
        a = a.unsqueeze(0)
    if b.dim() == 1:
        b = b.unsqueeze(0)

    # F.normalize divides by max(norm, eps), so an all-zero row stays zero
    # rather than turning into NaN through a 0/0 division.
    a_unit = F.normalize(a, p=2, dim=1)
    b_unit = F.normalize(b, p=2, dim=1)

    return torch.mm(a_unit, b_unit.t()) * 100

In [4]:
# Inference example
# Three Korean questions about the same store; English glosses:
#   1. "Is the Bundang E-Mart store open this Sunday?"
#   2. "Is Bundang E-Mart open on Sunday?"
#   3. "Until what time is the Bundang E-Mart store open on Saturday?"
sentences = ['이번 주 일요일에 분당 이마트 점은 문을 여나요?',
             '일요일에 분당 이마트는 문 열어요?',
             '분당 이마트 점은 토요일에 몇 시까지 하나요']

# The model is moved to the CPU before scoring.
show_embedding_score(tokenizer, model.cpu(), sentences)


92.98614501953125 83.0574722290039 77.81117248535156
1,2번 문장이 92.98614501953125로 가장 유사합니다.
