## 01. RAG를 위한 기본 정보와 Prompt Template & 함수 세팅

In [2]:
import os
from dotenv import load_dotenv
# Load the .env file before importing the project helpers below, in case they
# read environment variables at import time (NOTE(review): confirm in lib/).
load_dotenv()

from lib.opensearch import getOpenSearchClient
aoss_client = getOpenSearchClient()

# Fail fast on missing configuration: os.getenv would silently return None and
# every search below would then be issued with index=None.
vector_index_name = os.getenv("AOSS_VECTOR_INDEX")
if not vector_index_name:
    raise RuntimeError(
        "AOSS_VECTOR_INDEX is not set; define it in .env or the environment"
    )

from lib.bedrock import get_embedding_output, get_llm_output


In [3]:
# Prompt shared by the RAG functions in this notebook. Retrieved chunks are
# joined into {context} and the user's question fills {question} via
# str.format(); the resulting text (including the leading newline) is what
# gets passed to get_llm_output.
prompt_template = """
You're a helpful assistant to answer the question.
Use the following pieces of <CONTEXT> to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

<CONTEXT>
{context}
</CONTEXT>

Question: {question}
Helpful Answer:"""

## 02. Semantic Search

In [4]:
def get_semantic_rag(user_query, k=5):
    """Answer `user_query` with RAG over the top-`k` semantic (k-NN) matches.

    Embeds the query, runs a k-NN query on the `content_embeddings` field of
    `vector_index_name`, joins the `content` of the returned hits (blank-line
    separated) as context, and sends the filled `prompt_template` to the LLM.

    Args:
        user_query: Natural-language question.
        k: Number of nearest neighbours to request; also used as the search
            `size` so the two cannot drift apart. Defaults to 5 (the value
            previously hard-coded).

    Returns:
        dict with "llm_input" (the rendered prompt) and "llm_output"
        (whatever `get_llm_output` returns).
    """
    vector = get_embedding_output(user_query)
    vector_query = {
        "query": {
            "knn": {
                "content_embeddings": {
                    "vector": vector,
                    "k": k,
                }
            }
        }
    }

    # size == k so at most the requested neighbours are returned.
    response = aoss_client.search(index=vector_index_name, body=vector_query, size=k)
    vector_search_results = [hit["_source"]["content"] for hit in response["hits"]["hits"]]

    context_data = "\n\n".join(vector_search_results)
    llm_input = prompt_template.format(context=context_data, question=user_query)
    llm_output = get_llm_output(llm_input)

    return {"llm_input": llm_input, "llm_output": llm_output}

In [7]:
# Semantic-only retrieval for a question answered by a single code-table row.
output = get_semantic_rag(user_query="요우커지정일의 코드가 어떻게 되나요?")
print(output["llm_output"])

제공된 컨텍스트에서 "요우커지정일"의 코드는 "9Y"입니다. 테이블에 나와 있는 데이터에 따르면 "Sales" 카테고리에서 "CD_VAL"이 "요우커지정일"인 경우 "CMM_CD"가 "9Y"입니다.


## 03. Hybrid Search

In [25]:

def get_normalized_result(search_results, add_meta, weight=1.0):
    """Convert an OpenSearch search response into weighted, max-normalized hits.

    Each hit's `_score` is divided by the response's max score (so the best
    hit scores 1.0) and multiplied by `weight`.

    Args:
        search_results: Search response dict; reads ["hits"]["hits"] and
            ["hits"]["max_score"], and from each hit `_id`, `_score`,
            `_source.content` and `_source.metadata`.
        add_meta: Label stored under "meta" on every result (e.g. "vector"
            or "lexical") to record which search produced it.
        weight: Multiplier applied to the normalized score.

    Returns:
        List of dicts with keys "doc_id", "score", "content", "meta",
        "metadata", in the same order as the response hits; [] if no hits.
    """
    hits = search_results["hits"]["hits"]
    if not hits:
        return []

    max_score = search_results["hits"].get("max_score")
    if max_score is None:
        # max_score may be null/absent; fall back to the largest hit score.
        max_score = max(float(hit["_score"]) for hit in hits)
    max_score = float(max_score)

    results = []
    for hit in hits:
        # Guard against division by zero when every hit scored 0.
        normalized_score = float(hit["_score"]) / max_score if max_score else 0.0
        results.append({
            "doc_id": hit["_id"],
            "score": normalized_score * weight,
            "content": hit["_source"]["content"],
            "meta": add_meta,
            "metadata": hit["_source"]["metadata"],
        })

    return results

def _fuse_hybrid_results(vector_result, keyword_result, threshold, result_limit):
    """Merge normalized vector and lexical hits into one score-ranked list.

    A document returned by both searches keeps only its vector entry. Entries
    with score <= `threshold` are dropped; the rest are ranked by score
    (highest first) and truncated to `result_limit`.
    """
    vector_ids = {item["doc_id"] for item in vector_result}
    merged = list(vector_result) + [
        item for item in keyword_result if item["doc_id"] not in vector_ids
    ]
    kept = [item for item in merged if item["score"] > threshold]
    # Rank BEFORE truncating: otherwise the vector hits (listed first) fill the
    # whole limit and lexical-only hits never reach the context. sorted() is
    # stable, so ties keep vector-before-lexical order.
    ranked = sorted(kept, key=lambda item: item["score"], reverse=True)
    return ranked[:result_limit]


def get_hybrid_rag(user_query, result_limit=5, vec_weight=0.6, lex_weight=0.55,
                   threshold=0.05, knn_k=5, candidate_size=10):
    """Answer `user_query` with RAG over fused semantic + lexical search hits.

    Runs a k-NN query on `content_embeddings` and a `match` query on `content`
    against `vector_index_name`, normalizes and weights each result set with
    `get_normalized_result`, fuses them, and sends the joined contents of the
    best hits to the LLM via `prompt_template`.

    Args:
        user_query: Natural-language question.
        result_limit: Maximum number of fused hits used as context.
        vec_weight: Weight applied to normalized vector scores.
        lex_weight: Weight applied to normalized lexical scores.
        threshold: Fused hits must score strictly above this to be kept.
        knn_k: `k` for the k-NN query.
        candidate_size: `size` for each of the two searches.

    Returns:
        dict with "llm_input" (rendered prompt), "llm_output" (LLM response)
        and "sorted_items" (the fused hits used as context, highest score
        first).
    """
    # Semantic (vector) search.
    vector = get_embedding_output(user_query)
    vector_query = {
        "query": {
            "knn": {
                "content_embeddings": {
                    "vector": vector,
                    "k": knn_k,
                }
            }
        }
    }
    vector_response = aoss_client.search(
        index=vector_index_name, body=vector_query, size=candidate_size
    )
    vector_result = get_normalized_result(vector_response, "vector", vec_weight)

    # Lexical (full-text) search.
    keyword_query = {"query": {"match": {"content": user_query}}}
    keyword_response = aoss_client.search(
        index=vector_index_name, body=keyword_query, size=candidate_size
    )
    keyword_result = get_normalized_result(keyword_response, "lexical", lex_weight)

    sorted_items = _fuse_hybrid_results(vector_result, keyword_result, threshold, result_limit)

    context_data = "\n\n".join(item["content"] for item in sorted_items)
    llm_input = prompt_template.format(context=context_data, question=user_query)
    llm_output = get_llm_output(llm_input)
    return {"llm_input": llm_input, "llm_output": llm_output, "sorted_items": sorted_items}

In [26]:
# Hybrid retrieval: a code lookup where exact keyword matches help.
output = get_hybrid_rag(user_query="Return 카테고리에 03 코드는 어떤 값인가요?")
print(output["llm_output"])

Return 카테고리에서 CD_VAL 값이 '03'인 행의 내용은 "수거불필요:택배사분실"입니다.


In [29]:
# Show the source-document metadata of each hit used as hybrid context.
for ranked_item in output["sorted_items"]:
    source_meta = ranked_item["metadata"]["document_metadata"]
    print(source_meta)


{'title': '페이지내 이미지 테스트', 'id': '63897603', 'source': 'https://lge-web-project.atlassian.net/wiki/spaces/SD/pages/63897603', 'when': '2024-06-02T01:32:45.382Z'}
{'title': '페이지내 이미지 테스트', 'id': '63897603', 'source': 'https://lge-web-project.atlassian.net/wiki/spaces/SD/pages/63897603', 'when': '2024-06-02T01:32:45.382Z'}
{'title': '페이지내 이미지 테스트', 'id': '63897603', 'source': 'https://lge-web-project.atlassian.net/wiki/spaces/SD/pages/63897603', 'when': '2024-06-02T01:32:45.382Z'}
{'title': '페이지내 이미지 테스트', 'id': '63897603', 'source': 'https://lge-web-project.atlassian.net/wiki/spaces/SD/pages/63897603', 'when': '2024-06-02T01:32:45.382Z'}
