## 의미 검색 유사도 

### 데이터셋 준비

In [1]:
from datasets import load_dataset

# Load only the training split of the KLUE "mrc" configuration.
# Dataset card: https://huggingface.co/datasets/klue/klue
klue_mrc_dataset = load_dataset(path='klue', name='mrc', split='train')
klue_mrc_dataset

README.md:   0%|          | 0.00/22.5k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/21.4M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/8.68M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/17554 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/5841 [00:00<?, ? examples/s]

Dataset({
    features: ['title', 'context', 'news_category', 'source', 'guid', 'is_impossible', 'question_type', 'question', 'answers'],
    num_rows: 17554
})

In [2]:
# Work on a small subset only: the first `row_size` rows, in original order.
row_size = 1000

# `select` picks the leading rows directly. The previous
# `train_test_split(..., shuffle=False)['train']` returned the same rows but
# also built a test split that was thrown away. The size is clamped so a
# dataset shorter than `row_size` yields all of its rows instead of raising.
subset_size = min(row_size, len(klue_mrc_dataset))
klue_mrc_dataset_train = klue_mrc_dataset.select(range(subset_size))

klue_mrc_dataset_train

Dataset({
    features: ['title', 'context', 'news_category', 'source', 'guid', 'is_impossible', 'question_type', 'question', 'answers'],
    num_rows: 1000
})

In [3]:
# Sanity check of the subset size as a (rows, columns) pair.
(klue_mrc_dataset_train.num_rows, klue_mrc_dataset_train.num_columns)

(1000, 9)

In [4]:
# Integer indexing returns one example as a dict; list its field names.
sample_row = klue_mrc_dataset_train[10]
sample_row.keys()


dict_keys(['title', 'context', 'news_category', 'source', 'guid', 'is_impossible', 'question_type', 'question', 'answers'])

### 데이터 임베딩

In [5]:
from sentence_transformers import SentenceTransformer
# Sentence-embedding model, referenced by its Hugging Face Hub id.
# The repr displayed by this cell reports the sequence-length limit, the
# pooling mode and the embedding width.
embedding_model_name = 'snunlp/KR-SBERT-V40K-klueNLI-augSTS'
sentence_model = SentenceTransformer(model_name_or_path=embedding_model_name)
sentence_model

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/4.02k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/707 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/467M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/394 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/336k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/967k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

SentenceTransformer(
  (0): Transformer({'max_seq_length': 128, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
)

In [6]:
# Embed every passage in the 'context' column; the result has one row per
# passage and one column per embedding dimension.
context_texts = klue_mrc_dataset_train['context']
context_embedded = sentence_model.encode(context_texts)
context_embedded.shape

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

(1000, 768)

### 검색 인덱스 생성(KNN 알고리즘 사용)

In [7]:
!pip install faiss-cpu faiss-gpu -qqq

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m30.7/30.7 MB[0m [31m52.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.5/85.5 MB[0m [31m17.6 MB/s[0m eta [36m0:00:00[0m
[?25h

In [8]:
import faiss  # FAISS (from Meta): library used here for vector-distance search

# Flat L2 index: stores the raw vectors and compares the query against them by
# L2 distance. Its dimensionality must equal the embedding width, i.e. the
# second axis of `context_embedded`.
embedding_dim = context_embedded.shape[1]
index_knn = faiss.IndexFlatL2(embedding_dim)
index_knn

<faiss.swigfaiss_avx2.IndexFlatL2; proxy of <Swig Object of type 'faiss::IndexFlatL2 *' at 0x783d12b846f0> >

In [9]:
# Store the passage embeddings in the index (it acts like a small in-memory
# vector table). Row i of `context_embedded` gets id i, which is what maps
# search results back to dataset rows.
# Clear the index first so that re-running this cell does not append the same
# vectors a second time and produce duplicate search hits.
index_knn.reset()
index_knn.add(context_embedded)

In [10]:
# Show the concrete index class next to the index object's repr.
index_class = type(index_knn)
(index_class, index_knn)

(faiss.swigfaiss_avx2.IndexFlatL2,
 <faiss.swigfaiss_avx2.IndexFlatL2; proxy of <Swig Object of type 'faiss::IndexFlatL2 *' at 0x783d12b846f0> >)

In [11]:
# Usage example: embed a free-text query with the same model as the passages.
# The query is wrapped in a list so the result keeps a leading batch axis.
query = '이번 년도에는 언제 비가 많이 올까?'
query_batch = [query]
query_embedded = sentence_model.encode(query_batch)
query_embedded.shape

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

(1, 768)

In [12]:
# Retrieve the `top_k` nearest passages for the query. `search` returns two
# arrays with one row per query, so row 0 belongs to the single query above.
top_k = 3
distances, indices = index_knn.search(query_embedded, top_k)

# Read the column once. Indexing the dataset by column name inside the loop
# would fetch the whole 'context' column again on every iteration.
contexts = klue_mrc_dataset_train['context']

for idx in indices[0]:
    # FAISS pads the result with -1 when the index holds fewer than `top_k`
    # vectors; skip those so -1 is not mistaken for "last row".
    if idx < 0:
        continue
    # Print a 30-character preview of each retrieved passage.
    print(contexts[idx][:30])


올여름 장마가 17일 제주도에서 시작됐다. 서울 등 중
써머스플랫폼(대표 김기범)이 운영하는 '에누리 가격비교
연구 결과에 따르면, 오리너구리의 눈은 대부분의 포유류
