## 의미 검색 유사도

### 데이터셋 준비

In [1]:
from datasets import load_dataset

# https://huggingface.co/datasets/klue/klue/viewer/mrc
klue_mrc_dataset = load_dataset('klue','mrc', split='train')
klue_mrc_dataset

README.md:   0%|          | 0.00/22.5k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/21.4M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/8.68M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/17554 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/5841 [00:00<?, ? examples/s]

Dataset({
    features: ['title', 'context', 'news_category', 'source', 'guid', 'is_impossible', 'question_type', 'question', 'answers'],
    num_rows: 17554
})

In [5]:
# 일부 데이터만 사용 row 1000개 정도
row_size = 1000
# klue_mrc_dataset[:1000]
klue_mrc_dataset_train = klue_mrc_dataset.train_test_split(train_size = row_size
                                 , shuffle=False)['train']
klue_mrc_dataset_train

Dataset({
    features: ['title', 'context', 'news_category', 'source', 'guid', 'is_impossible', 'question_type', 'question', 'answers'],
    num_rows: 1000
})

In [7]:
klue_mrc_dataset_train.shape

(1000, 9)

In [9]:
# klue_mrc_dataset_train[10]
klue_mrc_dataset_train[10].keys()

dict_keys(['title', 'context', 'news_category', 'source', 'guid', 'is_impossible', 'question_type', 'question', 'answers'])

### 데이터 임베딩

In [11]:
from sentence_transformers import SentenceTransformer
sentence_model = SentenceTransformer('snunlp/KR-SBERT-V40K-klueNLI-augSTS')
sentence_model

modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/4.02k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/707 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/467M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/394 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/336k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/967k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

SentenceTransformer(
  (0): Transformer({'max_seq_length': 128, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
)

In [12]:
context_embedded = sentence_model.encode(klue_mrc_dataset_train['context'])
context_embedded.shape

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

(1000, 768)

### 검색 인덱스 생성(KNN 알고리즘 사용)

In [14]:
!pip install faiss-cpu faiss-gpu

Collecting faiss-gpu
  Downloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Downloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (85.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.5/85.5 MB[0m [31m17.7 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: faiss-gpu
Successfully installed faiss-gpu-1.7.2


In [18]:
import faiss # 메타 API로 벡터 거리 계산용

index_knn = faiss.IndexFlatL2(context_embedded.shape[1]) # KNN 알고리즘 초기화, 768에 벡터공간 할당
index_knn

<faiss.swigfaiss_avx2.IndexFlatL2; proxy of <Swig Object of type 'faiss::IndexFlatL2 *' at 0x7851192aed60> >

In [19]:
# 인덱스에 임베딩 저장 : 테이블 생성 유사(메모리용 벡터데이터베이스 유사)
index_knn.add(context_embedded)

In [20]:
type(index_knn), index_knn

(faiss.swigfaiss_avx2.IndexFlatL2,
 <faiss.swigfaiss_avx2.IndexFlatL2; proxy of <Swig Object of type 'faiss::IndexFlatL2 *' at 0x7851192aed60> >)

In [24]:
## 활용
query = '이번 연도에는 언제 비가 많이 올까?'
query_embedded = sentence_model.encode([query])
query_embedded.shape

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

(1, 768)

In [27]:
# index_knn.search(query_embedded, 3)
distances, indices = index_knn.search(query_embedded, 3)

for idx in indices[0]:
    print(klue_mrc_dataset_train['context'][idx][:30])

올여름 장마가 17일 제주도에서 시작됐다. 서울 등 중
연구 결과에 따르면, 오리너구리의 눈은 대부분의 포유류
연구 결과에 따르면, 오리너구리의 눈은 대부분의 포유류
