In [1]:
import json
import os
from collections import defaultdict
from typing import List, TypedDict

from kiwipiepy import Kiwi
from langchain_community.retrievers import BM25Retriever
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnableParallel
from langchain_milvus import Milvus
from langchain_naver import ChatClovaX, ClovaXEmbeddings
from langgraph.graph import END, StateGraph
from sentence_transformers import CrossEncoder

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Model clients. The API key is read from the environment instead of being
# hardcoded in the notebook (it would otherwise leak when the notebook is shared).
# NOTE(review): CLOVASTUDIO_API_KEY is the suggested variable name; the "" fallback
# reproduces the previous placeholder behavior when it is unset.
CLOVA_API_KEY = os.environ.get("CLOVASTUDIO_API_KEY", "")

# Deterministic (temperature=0) chat model used by the field-selector chain.
llm = ChatClovaX(
    model="HCX-DASH-002",
    max_tokens=1024,
    temperature=0,
    api_key=CLOVA_API_KEY,
)

# Embedding model passed to every Milvus vector store.
embeddings = ClovaXEmbeddings(
    model="bge-m3",
    api_key=CLOVA_API_KEY,
)


# System prompt for the field-selector LLM call. It asks the model to answer with a
# JSON-style list of field names chosen from the seven per-field collections
# (title, standards, outline, problems, opinion, criteria, action), always
# including "outline" and "problems", and gives few-shot examples.
# NOTE: this text is sent to the model verbatim -- treat it as runtime data.
# ChatPromptTemplate parses these strings as f-string templates, so any literal
# "{" or "}" added here must be doubled ("{{", "}}").
field_selector_system = """
당신은 감사 보고서 기반 RAG 시스템을 위한 필드 선택기(Field Selector)입니다.
사용자 질문을 읽고 아래 7개의 필드 중 관련 있는 필드를 판단해
반드시 ["outline", "problems", ...] 형태로 출력하세요.

중요 규칙:
- outline과 problems는 모든 질문에서 항상 기본적으로 포함합니다.
- 나머지 필드(title, standards, opinion, criteria, action)는
  질문의 의도와 직접 관련이 있을 때만 추가로 포함합니다.
- 출력은 중복 없이 소문자 문자열 리스트로만 제공합니다.
- 설명하지 말고 리스트만 출력하세요.

필드 설명:
- title: 사건 제목, 사안명
- standards: 법령·규정·기준 관련 질문
- outline: 사건의 개요·배경·전체 상황 (항상 포함)
- problems: 위반 사항·문제점·부적정 사례 (항상 포함)
- opinion: 관계기관 의견·평가·입장
- criteria: 향후 내부통제·절차 보완·개선 기준
- action: 처분·제재·후속 조치

출력 규칙:
1. 기본 출력: ["outline", "problems"]
2. 질문에 따라 아래와 같이 추가:
   - 규정 위반 / 법령 관련 → "standards"
   - 기관 입장 / 의견 → "opinion"
   - 개선 방안 / 내부통제 보완 → "criteria"
   - 처분 / 제재 / 조치 → "action"
   - 제목 / 사건명 → "title"

예시:
사용자 질문: "이 사건에서 어떤 규정이 위반되었는지 알려줘."
출력: ["outline", "problems", "standards"]

사용자 질문: "감사 결과에서 기관은 어떤 입장을 보였어?"
출력: ["outline", "problems", "opinion"]

사용자 질문: "사건의 배경과 전체적인 진행 과정을 설명해줘."
출력: ["outline", "problems"]

사용자 질문: "최종 처분은 무엇이었어?"
출력: ["outline", "problems", "action"]

사용자 질문: "앞으로 내부통제를 어떻게 보완해야 할까?"
출력: ["outline", "problems", "criteria"]
"""

# Human-turn template. {question} is the only template variable; the first line
# repeats the system prompt's rule that "outline" and "problems" are always
# included ("outline and problems are always included for every question").
field_selector_user = """
outline과 problems는 모든 질문에서 항상 기본적으로 포함합니다.

[Question]
{question}
"""

# Two-message prompt (system rules + human question) for the field selector.
field_selector_messages = [
    ("system", field_selector_system),
    ("human", field_selector_user),
]
field_selector_template = ChatPromptTemplate.from_messages(field_selector_messages)

# prompt -> LLM -> raw string; the field_selector node parses that string as a list.
field_selector_chain = field_selector_template | llm | StrOutputParser()

In [3]:
# Shared Kiwi morphological analyzer, created once and reused for every call.
kiwi = Kiwi()


def kiwi_tokenize(text):
    """Tokenize Korean text for BM25, keeping only noun-family morphemes.

    Args:
        text: Raw text (document chunk or query).

    Returns:
        List of morpheme surface forms whose POS tag starts with ``'N'``
        (Kiwi's noun-family tags, e.g. NNG/NNP), in their original order.
    """
    nouns = []
    for morpheme in kiwi.tokenize(text):
        if morpheme.tag.startswith('N'):
            nouns.append(morpheme.form)
    return nouns

def get_documents_from_milvus(vector_db, batch_size=1000):
    """Load every stored chunk of a Milvus collection as LangChain Documents.

    The result feeds the in-memory BM25 index, so it must cover the whole
    collection. A single ``query`` call is capped by Milvus' maximum result
    window (``offset + limit <= 16384``), which silently truncated larger
    collections. This pages through all rows with ``query_iterator`` instead.

    Args:
        vector_db: A ``langchain_milvus.Milvus`` store whose ``col`` attribute is
            used as a pymilvus ``Collection``. If ``col`` is missing or ``None``
            (presumably the collection does not exist yet), nothing is loaded.
        batch_size: Number of rows fetched per round-trip.

    Returns:
        A list of ``Document`` with ``page_content`` set to the ``text`` field
        and ``metadata == {"idx": <idx field>}``. Rows whose ``text`` is missing
        or empty are skipped.
    """
    collection = getattr(vector_db, "col", None)
    if collection is None:
        return []

    iterator = collection.query_iterator(
        batch_size=batch_size,
        output_fields=["text", "idx"],
    )
    docs = []
    try:
        while True:
            batch = iterator.next()
            if not batch:  # an empty page marks the end of the iteration
                break
            for item in batch:
                text = item.get("text")
                if text:
                    docs.append(Document(
                        page_content=text,
                        metadata={"idx": item.get("idx")},
                    ))
    finally:
        # Release the server-side iterator even if a batch fails mid-way.
        iterator.close()
    return docs

Quantization is not supported for ArchType::neon. Fall back to non-quantized model.


#### Hybrid Retriever Logic

In [4]:
class ManualEnsembleRetriever:
    """Fuse several retrievers' ranked results with weighted Reciprocal Rank Fusion.

    For each retriever's result list, a document at 0-based position ``rank``
    contributes ``weight / (rank + 1 + c)``. Contributions for the same document
    are summed across retrievers. Results are returned best-first and are not
    truncated.

    Args:
        retrievers: Objects exposing ``invoke(query) -> list[Document]``.
        weights: One weight per retriever; defaults to 0.5 for each.
        c: Smoothing constant added to every rank before taking the reciprocal.
    """

    def __init__(self, retrievers, weights=None, c=60):
        self.retrievers = retrievers
        self.weights = weights if weights else [0.5] * len(retrievers)
        self.c = c

    @staticmethod
    def _doc_key(doc):
        """Identity of a chunk for fusion: its text plus its ``idx`` metadata."""
        # Keying on text alone merged identical chunks coming from different
        # reports (e.g. the same regulation quoted in several reports), dropping
        # every idx but one before documents are grouped by idx downstream.
        return (doc.page_content, doc.metadata.get("idx"))

    def invoke(self, query):
        """Run every retriever on ``query`` and return documents sorted by RRF score."""
        results = [r.invoke(query) for r in self.retrievers]
        rrf_score = defaultdict(float)
        doc_map = {}

        for r_idx, docs in enumerate(results):
            weight = self.weights[r_idx]
            for rank, doc in enumerate(docs):
                doc_key = self._doc_key(doc)
                # Keep the first object seen for each key as the representative.
                doc_map.setdefault(doc_key, doc)
                rrf_score[doc_key] += weight * (1 / (rank + 1 + self.c))

        sorted_keys = sorted(rrf_score, key=rrf_score.__getitem__, reverse=True)
        return [doc_map[k] for k in sorted_keys]

def create_hybrid_retriever(vector_db, k=20):
    """Build a BM25 + dense-vector retriever over one Milvus collection.

    Args:
        vector_db: The Milvus vector store for one report field.
        k: Number of hits requested from each underlying retriever.

    Returns:
        A ``ManualEnsembleRetriever`` fusing BM25 (Kiwi noun tokens) and vector
        search with equal weights. If the collection yields no documents, only
        the plain vector retriever is returned.
    """
    corpus = get_documents_from_milvus(vector_db)
    dense = vector_db.as_retriever(search_kwargs={"k": k})

    # BM25 needs an in-memory corpus; without one, fall back to dense search only.
    if len(corpus) == 0:
        return dense

    sparse = BM25Retriever.from_documents(corpus, preprocess_func=kiwi_tokenize)
    sparse.k = k

    return ManualEnsembleRetriever(retrievers=[sparse, dense], weights=[0.5, 0.5])

#### Initialize Retrievers

In [None]:
# Connection settings shared by every per-field collection. The token is read
# from the environment instead of being hardcoded; the URI can be overridden the
# same way. NOTE(review): ZILLIZ_URI / ZILLIZ_TOKEN are suggested variable names.
ZILLIZ_URI = os.environ.get(
    "ZILLIZ_URI",
    "https://in03-6919b557b41d797.serverless.gcp-us-west1.cloud.zilliz.com",
)
ZILLIZ_TOKEN = os.environ.get("ZILLIZ_TOKEN", "")


def connect_field_db(collection_name):
    """Open the Milvus vector store for one report-field collection.

    Args:
        collection_name: Milvus collection name (one collection per report field).

    Returns:
        A ``Milvus`` store that embeds with the shared ``embeddings`` model.
    """
    return Milvus(
        embedding_function=embeddings,
        connection_args={"uri": ZILLIZ_URI, "token": ZILLIZ_TOKEN},
        collection_name=collection_name,
        auto_id=True,
    )


# One store per field; the names are kept so other cells can still refer to them.
title_db = connect_field_db("title")
standards_db = connect_field_db("standards")
outline_db = connect_field_db("outline")
problems_db = connect_field_db("problems")
opinion_db = connect_field_db("opinion")
criteria_db = connect_field_db("criteria")
action_db = connect_field_db("action")

# Field name -> hybrid retriever. The keys must match the names field_selector emits.
retrievers = {
    "title": create_hybrid_retriever(title_db),
    "standards": create_hybrid_retriever(standards_db),
    "outline": create_hybrid_retriever(outline_db),
    "problems": create_hybrid_retriever(problems_db),
    "opinion": create_hybrid_retriever(opinion_db),
    "criteria": create_hybrid_retriever(criteria_db),
    "action": create_hybrid_retriever(action_db),
}

In [6]:
# Cross-encoder used by rerank_documents to score (question, document) pairs.
RERANKER_MODEL_NAME = "BAAI/bge-reranker-v2-m3"
encoder_model = CrossEncoder(RERANKER_MODEL_NAME)

In [None]:
# Shared LangGraph state passed between the pipeline nodes:
#   question        - the user's query
#   selected_fields - report fields chosen by the field_selector node
#   documents       - documents produced by the latest node that returned them
# No reducer is declared, so each node's returned value replaces the previous one.
GraphState = TypedDict(
    "GraphState",
    {
        "question": str,
        "selected_fields": List[str],
        "documents": List[Document],
    },
)

In [None]:
def field_selector(state: GraphState) -> dict:
    """Ask the LLM which report fields are relevant to the question.

    The model is prompted to reply with a JSON list such as
    ``["outline", "problems", "standards"]``. The reply is parsed defensively:
    code fences are removed, and any prose around the outermost ``[...]`` is
    ignored. Non-string items are dropped, and names are lower-cased and
    de-duplicated. "outline" and "problems" are always included, as the prompt
    requires.

    Returns:
        ``{"selected_fields": [...]}``, starting with ``"outline", "problems"``.
        If the reply is not a JSON list, the result is exactly
        ``["outline", "problems"]``.
    """
    required_fields = ["outline", "problems"]

    question = state["question"]
    result_str = field_selector_chain.invoke({"question": question})
    cleaned_result = result_str.replace("```json", "").replace("```", "").strip()

    # Tolerate explanations around the list by keeping only the outermost [...].
    start, end = cleaned_result.find("["), cleaned_result.rfind("]")
    if start != -1 and end > start:
        cleaned_result = cleaned_result[start:end + 1]

    try:
        parsed = json.loads(cleaned_result)
    except json.JSONDecodeError:
        parsed = None
    if not isinstance(parsed, list):
        parsed = []

    selected_fields_list = list(required_fields)
    for item in parsed:
        if not isinstance(item, str):
            continue
        name = item.strip().lower()
        if name and name not in selected_fields_list:
            selected_fields_list.append(name)

    return {"selected_fields": selected_fields_list}


def retrieve_documents(state: GraphState) -> dict:
    """Run the hybrid retriever of every selected field in parallel.

    Field names with no entry in ``retrievers`` are ignored. Each returned
    document is a copy tagged with ``metadata["source_field"]``, the field it
    was retrieved from.

    Returns:
        ``{"documents": [...]}``: a flat list over all selected fields, or an
        empty list when no selected field has a retriever.
    """
    question = state["question"]
    target_fields = state.get("selected_fields", [])

    selected_retrievers = {
        field: RunnableLambda(retrievers[field].invoke)
        for field in target_fields
        if field in retrievers
    }
    if not selected_retrievers:
        return {"documents": []}

    # Retrieve from all selected fields concurrently.
    results = RunnableParallel(**selected_retrievers).invoke(question)

    flat_docs = []
    for field, docs in results.items():
        for doc in docs:
            # Tag a copy rather than mutating doc.metadata in place. The ensemble
            # returns the retrievers' own Document objects, and BM25Retriever
            # presumably hands back the documents it indexed. In-place edits would
            # therefore leak into the retrievers' stored corpus.
            flat_docs.append(Document(
                page_content=doc.page_content,
                metadata={**doc.metadata, "source_field": field},
            ))

    return {"documents": flat_docs}


def merge_documents(state: GraphState) -> dict:
    """Combine per-field chunks that share the same ``idx`` into one document.

    ``idx`` presumably identifies the source audit report. Documents without an
    ``idx`` are dropped. Within a group, chunks are ordered by field name, and
    each chunk is prefixed with an upper-cased ``[FIELD]`` header. Groups keep
    the order in which their ``idx`` first appeared.

    Returns:
        ``{"documents": [...]}`` where each document's metadata is
        ``{"idx": idx, "fields": [field names in output order]}``.
    """
    by_idx = defaultdict(list)
    for chunk in state.get("documents", []):
        chunk_idx = chunk.metadata.get("idx")
        if chunk_idx is None:
            continue
        field_name = chunk.metadata.get("source_field", "UNKNOWN")
        by_idx[chunk_idx].append((field_name, chunk.page_content))

    merged = []
    for chunk_idx, parts in by_idx.items():
        parts.sort(key=lambda part: part[0])
        sections = [f"[{name.upper()}]\n{body}" for name, body in parts]
        merged.append(Document(
            page_content="\n\n".join(sections),
            metadata={"idx": chunk_idx, "fields": [name for name, _ in parts]},
        ))

    return {"documents": merged}


def rerank_documents(state: GraphState) -> dict:
    """Rerank documents with the cross-encoder and keep the 5 best.

    Each document is scored as the pair ``(question, page_content)``. Ties keep
    their incoming order.

    Returns:
        ``{"documents": [...]}`` with at most 5 documents, highest score first.
    """
    top_k = 5
    question = state["question"]
    candidates = state.get("documents", [])

    if not candidates:
        return {"documents": []}

    relevance = encoder_model.predict(
        [[question, candidate.page_content] for candidate in candidates]
    )
    order = sorted(range(len(candidates)), key=lambda i: relevance[i], reverse=True)

    return {"documents": [candidates[i] for i in order[:top_k]]}

In [None]:
# Linear retrieval pipeline: select fields -> retrieve -> merge by idx -> rerank.
pipeline_steps = [
    ("field_selector", field_selector),
    ("retrieve_documents", retrieve_documents),
    ("merge_documents", merge_documents),
    ("rerank_documents", rerank_documents),
]

workflow = StateGraph(GraphState)
for step_name, step_fn in pipeline_steps:
    workflow.add_node(step_name, step_fn)

workflow.set_entry_point(pipeline_steps[0][0])

# Chain each step to the next one; the last step flows into END.
step_names = [name for name, _ in pipeline_steps]
for source, target in zip(step_names, step_names[1:] + [END]):
    workflow.add_edge(source, target)

app = workflow.compile()

In [1]:
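# NOTE: answer generation is not included. The graph ends at rerank_documents, so `app` returns only the top reranked documents.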