## 라이브러리

In [None]:
import re
import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

import torch
import pandas as pd
from transformers import AutoModel, AutoTokenizer
from google.cloud import storage

## 파라미터

In [None]:
# Identifier of the pretrained sentence-embedding model and tokenizer to load.
MODEL_NAME = 'BM-K/KoSimCSE-roberta-multitask'
# Name of the GCS storage bucket.
BUCKET_NAME = 'law-search'

## 모듈

In [None]:
# Patterns are compiled once at import time instead of being rebuilt per row.
# Angle-bracketed spans ("<...>"); non-greedy so each pair is matched separately.
_BRACKET_SPAN_RE = re.compile(r'\<(.*?)\>')
# Any character that is not a Hangul syllable, an ASCII letter, or a digit.
_NON_ALNUM_RE = re.compile(r'[^가-힣a-zA-Z0-9]')
# Runs of one or more consecutive spaces.
_MULTI_SPACE_RE = re.compile(r' +')


def _clean_law_text(law_text: str) -> str:
    """Remove bracketed spans and special characters, collapsing repeated spaces."""
    without_spans = _BRACKET_SPAN_RE.sub('', law_text).strip()
    # Special characters become spaces (not deleted) so adjacent words stay separated.
    alnum_only = _NON_ALNUM_RE.sub(' ', without_spans)
    return _MULTI_SPACE_RE.sub(' ', alnum_only).strip()


def _mean_pooling(model_output, attention_mask):
    """Average the token embeddings, weighting each position by the attention mask."""
    # The first element of the model output holds the per-token embeddings.
    token_embeddings = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    # The clamp guards against division by zero when the mask sums to zero.
    return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)


def convert_text_to_embedding(sub_df: pd.DataFrame) -> pd.DataFrame:
    """
    Convert the text in the ``law`` column into sentence embeddings.

    Each value of ``sub_df['law']`` is cleaned (angle-bracketed spans and
    special characters removed, repeated spaces collapsed), tokenized,
    passed through the ``MODEL_NAME`` model, and mean-pooled into a single
    vector.

    Args:
        sub_df: DataFrame that contains a ``law`` column of strings.

    Returns:
        A copy of ``sub_df`` with an added ``embedding`` column that holds
        one list of floats per row. The input frame is not modified.

    Note:
        The model and tokenizer are loaded on every call.
    """
    # Column names
    col_law = 'law'
    col_embedding = 'embedding'

    # Model & tokenizer
    model = AutoModel.from_pretrained(MODEL_NAME)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    embedding_list = []
    # Inference only: no gradients are needed for the forward pass or pooling.
    with torch.no_grad():
        # Iterate the column directly; iterrows() would build a full Series
        # per row just to read this one field.
        for law_text in sub_df[col_law]:
            encoded_input = tokenizer(
                _clean_law_text(law_text),
                padding=True,
                truncation=True,
                return_tensors='pt',
            )
            model_output = model(**encoded_input)
            sentence_embeddings = _mean_pooling(
                model_output, encoded_input['attention_mask']
            )
            # One text per forward pass, so index 0 is this row's pooled vector.
            embedding_list.append(sentence_embeddings[0].tolist())

    prepro_sub_df = sub_df.copy()
    prepro_sub_df[col_embedding] = embedding_list

    return prepro_sub_df