# baseline

將`關鍵字`比對換成`向量相似度`比對。

> 請將目前使用關鍵字比對的 route_by_query，改為使用向量相似度進行分類，並設一個合理的相似度門檻，根據檢索結果的分數判斷是否走 RAG 流程。  
例如用向量相似度及自訂 threshold 決定要不要分到 retriever。

> Hint：similarity_search_with_score(...)  
可參考去年的讀書會 R4：向量資料庫的基本操作

In [None]:
!pip install -q langchain langgraph transformers bitsandbytes langchain-huggingface langchain-community chromadb

In [2]:
from transformers import BitsAndBytesConfig, AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_core.documents import Document
from langchain.vectorstores import Chroma
from typing_extensions import TypedDict, List
from langgraph.graph import StateGraph, END
from langchain_core.runnables import RunnableLambda

In [3]:
docs_text = """
火影代數	姓名	師傅	徒弟
初代	千手柱間	無明確記載	猿飛日斬、水戶門炎、轉寢小春
二代	千手扉間	千手柱間（兄長）	猿飛日斬、志村團藏、宇智波鏡等
三代	猿飛日斬	千手柱間、千手扉間	自來也、大蛇丸、千手綱手（傳說三忍）
四代	波風湊	自來也	旗木卡卡西、宇智波帶土、野原琳
五代	千手綱手	猿飛日斬	春野櫻、志乃等（主要為春野櫻）
六代	旗木卡卡西	波風湊	漩渦鳴人、宇智波佐助、春野櫻（第七班）
七代	漩渦鳴人	自來也、旗木卡卡西	木葉丸等（主要為木葉丸）
"""

## 1. Load Model

In [None]:
# Specify the pre-trained model identifier
model_id = "MediaTek-Research/Breeze-7B-Instruct-v1_0"

# Configure 4-bit quantization
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    llm_int8_threshold=6.0,
)

# Load the causal language model with quantization
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=quant_config,
    trust_remote_code=True
)

# Load the tokenizer for the specified model
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

# Create a text-generation pipeline
generator = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    do_sample=True,
    temperature=0.4,
    return_full_text=False
)

In [None]:
# Initialize the Hugging Face embedding model for Chinese text
embedding_model = HuggingFaceEmbeddings(
    model_name="infgrad/stella-base-zh-v3-1792d",
    encode_kwargs={"normalize_embeddings": True}
)

# Convert each paragraph of raw text into a Document object for indexing
docs = [Document(page_content=txt.strip())
        for txt in docs_text.strip().split("\n\n")
]

# Define where to persist the vector store
persist_path = "document_store"
collection_name = "naruto_collection"

# Create a Chroma vector store from the documents and embeddings
vectorstore = Chroma.from_documents(
    documents=docs,
    embedding=embedding_model,
    persist_directory=persist_path,
    collection_name=collection_name,
    collection_metadata={"hnsw:space": "cosine"}
)

## 2. Define State & Notes

In [6]:
class RAGState(TypedDict):
    # Define the state structure passed between RAG nodes
    query: str
    docs: List[Document]
    answer: str

def retrieve_node(state: RAGState) -> RAGState:
    # Retrieval step: fetch the top-k most similar documents for the given query
    query = state["query"]
    docs = vectorstore.similarity_search(query, k=3)
    return {"query": query, "docs": docs, "answer": ""}

def generate_node(state: RAGState) -> RAGState:
    # Generation step with context: use retrieved documents to craft the answer
    query, docs = state["query"], state["docs"]
    context = "\n".join([d.page_content for d in docs])
    prompt = (
        f"你是一個知識型助手，請根據以下內容回答問題：\n\n"
        f"內容：{context}\n\n"
        f"問題：{query}\n\n回答："
    )
    output = generator(prompt, max_new_tokens=200)[0]["generated_text"]
    return {"query": query, "docs": docs, "answer": output}

def direct_generate_node(state: RAGState) -> RAGState:
    # Direct generation step: answer the query without any retrieved context
    query = state["query"]
    prompt = f"請回答以下問題：{query}\n\n回答："
    output = generator(prompt, max_new_tokens=200)[0]["generated_text"]
    return {"query": query, "docs": [], "answer": output}

In [14]:
def route_by_query(state: RAGState) -> str:
    # Route the query based on vector similarity to the 'naruto' document collection
    query = state["query"]
    results = vectorstore.similarity_search_with_score(query, k=1)

    if results:
        doc, score = results[0]
        cosine_sim = 1 - score
        threshold = 0.5
        print(f"route: cosine_sim = {cosine_sim:.4f}")
        choice = "naruto" if cosine_sim >= threshold else "general"
    else:
        choice = "general"

    print(f"跑到 → {choice}")
    return choice

## 3. Build StateGraph

In [15]:
# Initialize the StateGraph with the RAGState
graph_builder = StateGraph(RAGState)

# Define "condition" as the entry point for the graph
graph_builder.set_entry_point("condition")

# Add nodes for the graph
graph_builder.add_node("condition", RunnableLambda(lambda x: x))
graph_builder.add_node("retriever", RunnableLambda(retrieve_node))
graph_builder.add_node("generator", RunnableLambda(generate_node))
graph_builder.add_node("direct_generator", RunnableLambda(direct_generate_node))

# Create conditional edges from "condition" based on the result of route_by_query
graph_builder.add_conditional_edges(
    source="condition",
    path=RunnableLambda(route_by_query),
    path_map={
        "naruto": "retriever",
        "general": "direct_generator",
    }
)

# Add edges for the graph
graph_builder.add_edge("retriever", "generator")
graph_builder.add_edge("generator", END)
graph_builder.add_edge("direct_generator", END)

# Compile the graph into an executable RAG workflow
graph = graph_builder.compile()

## 4. Results

In [17]:
print("開始對話吧（輸入 q 結束）")

while True:
    user_input = input("使用者: ")
    if user_input.strip().lower() in ["q", "quit", "exit"]:
        print("掰啦！")
        break

    init_state: RAGState = {
        "query": user_input,
        "docs": [],
        "answer": ""
    }

    result = graph.invoke(init_state)
    raw_output = result["answer"]

    answer_text = raw_output.split("回答：")[-1].strip()
    print("回答：", answer_text)
    print("===" * 20, "\n")

開始對話吧（輸入 q 結束）
使用者: 誰是第四代火影
route: cosine_sim = 0.7115
跑到 → naruto
回答： 第四代火影是波風湊。

使用者: 第三代火影的師父
route: cosine_sim = 0.7533
跑到 → naruto
回答： 第三代火影猿飛日斬的師父是千手柱間。

使用者: 什麼是MCP
route: cosine_sim = 0.2900
跑到 → general
回答： MCP是Metal Clad Cable的縮寫，它指的是金屬外覆電纜。金屬外覆電纜是一種具有金屬外覆的電纜，主要用於傳輸電力或信號。它由金屬外覆層和絕緣層組成，金屬外覆層通常是銅或鋁，絕緣層通常是聚乙烯、聚氯乙烯、聚乙烯醇樹脂等材料。MCP具有良好的電氣性能，如低電阻、低電容、低電感等，因此它在電力傳輸和信號傳輸方面具有很高的性能。

使用者: q
掰啦！


# advance

改成能支援多輪問答（Multi-turn RAG），並能根據前面的query判斷問題。

> 請將 RAGState 加入 history 欄位，並在生成回答時，將歷史對話與當前問題一起組成 prompt。

> Hint：
```
class MultiTurnRAGState(TypedDict):  
    history: List[str]  
    query: str  
    docs: List[Document]  
    answer: str
```



## 2. Define State & Notes

In [33]:
class MultiTurnRAGState(TypedDict):
    # Define the state structure that holds the chat history
    history: List[str]
    query: str
    docs: List[Document]
    answer: str

In [34]:
def retrieve_node_multi(state: MultiTurnRAGState) -> MultiTurnRAGState:
    # Retrieval step: fetch the top-k most similar documents for the given query
    history = state["history"].copy()
    query = state["query"]
    docs = vectorstore.similarity_search(query, k=3)
    return {"history": history, "query": query, "docs": docs, "answer": ""}

def generate_node_multi(state: MultiTurnRAGState) -> MultiTurnRAGState:
    # Generation step with context: use retrieved documents to craft the answer
    history = state["history"].copy()
    query, docs = state["query"], state["docs"]
    context = "\n".join([d.page_content for d in docs])
    combined_history = "\n".join(history[:])
    prompt = (
        f"你是一個知識型助手，請根據以下內容與歷史詢問記錄回答問題：\n\n"
        f"內容：{context}\n\n"
        f"歷史：{combined_history}\n\n"
        f"問題：{query}\n\n回答："
    )
    output = generator(prompt, max_new_tokens=200)[0]["generated_text"]
    history.append(query)
    print(f"retrieve combined query: {history}")
    return {"history": history, "query": query, "docs": docs, "answer": output}

def direct_generate_node_multi(state: MultiTurnRAGState) -> MultiTurnRAGState:
    # Direct generation step: answer the query without any retrieved context
    history = state["history"].copy()
    query = state["query"]
    combined_history = "\n".join(history[:])
    prompt = (
        f"請根據以下歷史詢問記錄回答問題：\n\n"
        f"歷史：{combined_history}\n\n"
        f"問題：{query}\n\n回答："
    )
    output = generator(prompt, max_new_tokens=200)[0]["generated_text"]
    history.append(query)
    print(f"retrieve combined query: {history}")
    return {"history": history, "query": query, "docs": [], "answer": output}

## 3. Build StateGraph

In [35]:
# Initialize the StateGraph with the MultiTurnRAGState
graph_builder_multi = StateGraph(MultiTurnRAGState)

# Define "condition" as the entry point for the graph
graph_builder_multi.set_entry_point("condition")

# Add nodes for the graph
graph_builder_multi.add_node("condition", RunnableLambda(lambda x: x))
graph_builder_multi.add_node("retriever", RunnableLambda(retrieve_node_multi))
graph_builder_multi.add_node("generator", RunnableLambda(generate_node_multi))
graph_builder_multi.add_node("direct_generator", RunnableLambda(direct_generate_node_multi))

# Create conditional edges from "condition" based on the result of route_by_query
graph_builder_multi.add_conditional_edges(
    source="condition",
    path=RunnableLambda(route_by_query),
    path_map={
        "naruto": "retriever",
        "general": "direct_generator",
    }
)

# Add edges for the graph
graph_builder_multi.add_edge("retriever", "generator")
graph_builder_multi.add_edge("generator", END)
graph_builder_multi.add_edge("direct_generator", END)

# Compile the graph into an executable RAG workflow
graph = graph_builder_multi.compile()

## 4. Results

In [36]:
global_history: List[str] = []

print("開始對話吧（輸入 q 結束）")
while True:
    user_input = input("使用者: ")
    if user_input.strip().lower() in ["q", "quit", "exit"]:
        print("掰啦！")
        break

    state = {"history": global_history, "query": user_input}
    result = graph.invoke(state)

    answer = result["answer"].split("回答：")[-1].strip()
    print("AI 助理:", answer)
    print("===" * 60, "\n")

    global_history = result["history"]

開始對話吧（輸入 q 結束）
使用者: 誰是第五代火影
route: cosine_sim = 0.7275
跑到 → naruto
retrieve combined query: ['誰是第五代火影']
AI 助理: 第五代火影是千手綱手。

使用者: 他的徒弟有誰
route: cosine_sim = 0.6065
跑到 → naruto
retrieve combined query: ['誰是第五代火影', '他的徒弟有誰']
AI 助理: 第五代火影千手綱手的主要徒弟是春野櫻。

使用者: 他們是住在屍魂界嗎
route: cosine_sim = 0.4240
跑到 → general
retrieve combined query: ['誰是第五代火影', '他的徒弟有誰', '他們是住在屍魂界嗎']
AI 助理: 是的，第五代火影是鳴人。他的徒弟是佐助和漩渦鳴人。他們是住在屍魂界。

使用者: 他們認識自來也嗎
route: cosine_sim = 0.5643
跑到 → naruto
retrieve combined query: ['誰是第五代火影', '他的徒弟有誰', '他們是住在屍魂界嗎', '他們認識自來也嗎']
AI 助理: 第五代火影是千手綱手。他的徒弟主要是春野櫻。

他們是住在屍魂界嗎？春野櫻是住在屍魂界，但千手綱手本人不是。

他們認識自來也嗎？是的，千手綱手認識自來也。自來也是猿飛日斬的徒弟，也是旗木卡卡西的師傅，因此春野櫻和自來也之間有師徒關係。

使用者: q
掰啦！
