In [62]:
import requests
import json
import weaviate

from transformers import AutoModel, AutoTokenizer
import torch
import weaviate.classes as wvc
import numpy as np

In [23]:
model_name = "BAAI/bge-m3"  # 모델 이름
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
model.eval()

XLMRobertaModel(
  (embeddings): XLMRobertaEmbeddings(
    (word_embeddings): Embedding(250002, 1024, padding_idx=1)
    (position_embeddings): Embedding(8194, 1024, padding_idx=1)
    (token_type_embeddings): Embedding(1, 1024)
    (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): XLMRobertaEncoder(
    (layer): ModuleList(
      (0-23): 24 x XLMRobertaLayer(
        (attention): XLMRobertaAttention(
          (self): XLMRobertaSelfAttention(
            (query): Linear(in_features=1024, out_features=1024, bias=True)
            (key): Linear(in_features=1024, out_features=1024, bias=True)
            (value): Linear(in_features=1024, out_features=1024, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): XLMRobertaSelfOutput(
            (dense): Linear(in_features=1024, out_features=1024, bias=True)
            (LayerNorm): LayerNorm((1024,), eps=1e-05, elementw

In [78]:
sentences=["안녕"]

In [79]:
print("Sentence embeddings:", sentence_embeddings)

Sentence embeddings: tensor([[-0.0160,  0.0335, -0.0276,  ...,  0.0172, -0.0421,  0.0203],
        [-0.0199,  0.0241, -0.0336,  ..., -0.0021, -0.0322,  0.0461],
        [-0.0133,  0.0142, -0.0392,  ...,  0.0147, -0.0407,  0.0138]])


In [160]:
def do_embedding(sentences: list, model=model):
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        model_output = model(**encoded_input)
        # Perform pooling. In this case, cls pooling.
        sentence_embeddings = model_output[0][:, 0]
    sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
    return sentence_embeddings[0]

In [162]:
do_embedding(sentences)

tensor([-0.0160,  0.0335, -0.0276,  ...,  0.0172, -0.0421,  0.0203])

In [93]:
client = weaviate.connect_to_local(
    host="127.0.0.1",  # Use a string to specify the host
    port=8080,
    grpc_port=50051,
)

In [259]:
client.collections.delete("Question")
my_vecs = client.collections.create(
        "Question",
        vectorizer_config=wvc.config.Configure.Vectorizer.none(),
        vector_index_config=wvc.config.Configure.VectorIndex.hnsw(
            distance_metric=wvc.config.VectorDistances.COSINE # select prefered distance metric
        ),
    )

In [260]:
my_new_questions = [
    {"answer": "투썸플레이스", "question": "비싼 카페는?", "category": "카페"},
    {"answer": "백다방", "question": "저렴하고 간얼음 파는 카페는?", "category": "카페"},
    {"answer": "메가커피", "question": "쫌 싸지만 손흥민이 광고하는 카페", "category": "카페"},
    {"answer": "플래그원", "question": "두가지 원두 커피를 제공하는 공유오피스", "category": "공유오피스"},
    {"answer": "스타벅스", "question": "플래그원 1층에 있는 카페", "category": "카페"},
    {"answer": "에그드랍", "question": "커피랑 토스트 같이 파는 곳", "category": "음식점"},
    {"answer": "토스트식당", "question": "토스트는 파는데 커피는 안먹고 싶을 때 가는 가게", "category": "음식점"},
    {"answer": "이삭토스트", "question": "토스트랑 커피랑 생과일쥬스도 파는 곳", "category": "음식점"},
    {"answer": "공차", "question": "카페인데 버블티가 주력인 곳", "category": "카페"},
    {"answer": "커피빈", "question": "잘잘한 얼음을 넣은 커피를 파는 곳", "category": "카페"},
]

In [261]:
questions = client.collections.get("Question")
question_list = list()
for my_question in my_new_questions:
    question_list.append(wvc.data.DataObject(
        properties=my_question, 
        vector=do_embedding(my_question['question']).numpy().tolist()
    ))
questions.data.insert_many(question_list)

BatchObjectReturn(_all_responses=[UUID('616be6a6-6c4f-44ea-b088-60fbf0719bbd'), UUID('d1c4fee2-5e9d-4020-bca4-66a6ecc11e4a'), UUID('18e00cc2-059b-468a-a127-e902ee40a811'), UUID('5847a2fe-af09-4eb8-862b-e356162a6df8'), UUID('865893ab-71ae-487d-b961-acc4927bb684'), UUID('c336dcd3-a26e-499e-b0d5-f66c2195b667'), UUID('0479c5c9-3f40-431e-b196-06b5dc233f5b'), UUID('2125d8cf-b457-475b-b596-9f8bf500523f'), UUID('d115a267-3ae0-4f15-b7d9-de089fea98e7'), UUID('624141fe-ea8f-4423-882e-157dbc5c2cc0')], elapsed_seconds=0.0068950653076171875, errors={}, uuids={0: UUID('616be6a6-6c4f-44ea-b088-60fbf0719bbd'), 1: UUID('d1c4fee2-5e9d-4020-bca4-66a6ecc11e4a'), 2: UUID('18e00cc2-059b-468a-a127-e902ee40a811'), 3: UUID('5847a2fe-af09-4eb8-862b-e356162a6df8'), 4: UUID('865893ab-71ae-487d-b961-acc4927bb684'), 5: UUID('c336dcd3-a26e-499e-b0d5-f66c2195b667'), 6: UUID('0479c5c9-3f40-431e-b196-06b5dc233f5b'), 7: UUID('2125d8cf-b457-475b-b596-9f8bf500523f'), 8: UUID('d115a267-3ae0-4f15-b7d9-de089fea98e7'), 9: UUID

In [262]:
question = "커피도 먹고 토스트도 먹을 수 있는 곳은?"

questions = client.collections.get("Question")
query_vector = do_embedding(question).numpy().tolist()
resps = questions.query.near_vector(
    near_vector=query_vector,
    limit=2,
    return_metadata=wvc.query.MetadataQuery(certainty=True)
)

for obj in resps.objects:
    print(obj.properties)

{'answer': '에그드랍', 'question': '커피랑 토스트 같이 파는 곳', 'category': '음식점'}
{'answer': '\x08이삭토스트', 'question': '토스트랑 커피랑 생과일쥬스도 파는 곳', 'category': '음식점'}


In [263]:
question = "잘잘한 얼음이 먹고싶을 때 가는 카페는?"

questions = client.collections.get("Question")
query_vector = do_embedding(question).numpy().tolist()
resps = questions.query.near_vector(
    near_vector=query_vector,
    limit=2,
    return_metadata=wvc.query.MetadataQuery(certainty=True)
)

for obj in resps.objects:
    print(obj.properties)

{'answer': '커피빈', 'question': '잘잘한 얼음을 넣은 커피를 파는 곳', 'category': '카페'}
{'answer': '백다방', 'question': '저렴하고 간얼음 파는 카페는?', 'category': '카페'}


In [264]:
question = "플래그원에서 커피 먹을 수 있는 공간은 어디야?"

questions = client.collections.get("Question")
query_vector = do_embedding(question).numpy().tolist()
resps = questions.query.near_vector(
    near_vector=query_vector,
    limit=2,
    return_metadata=wvc.query.MetadataQuery(certainty=True)
)

for obj in resps.objects:
    print(obj.properties)

{'answer': '\x08스타벅스', 'question': '플래그원 1층에 있는 카페', 'category': '카페'}
{'answer': '공차', 'question': '카페인데 버블티가 주력인 곳', 'category': '카페'}
