# 12. User Queries Too Vague? Improve Semantic Matching with Query Expansion

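## 1. Why Use Query Expansion in RAG?

When building a question-answering system based on **RAG (Retrieval-Augmented Generation)**, user input often has the following problems:

- Vague, incomplete, or colloquial wording
- Missing context
- Difficulty matching the relevant documents in the knowledge base

Running vector retrieval directly on the user's raw query can therefore lead to:

- Too few relevant results being recalled
- Irrelevant content being matched
- A final generated answer that is inaccurate or incomplete

For this reason, applying **semantic-level expansion and rewriting (Query Expansion / Rewriting)** to the user's question before retrieval is an effective optimization.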

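### Example: vague queries

Consider vague user questions such as:

- "I want to go somewhere fun"
- "Is there anything good to eat?"
- "Places suitable for a family trip"

These questions lack concrete background information and a clear statement of need, so direct retrieval often fails to hit the key content.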
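
The effect can be illustrated with a toy example. The sketch below is not the embedding pipeline used later in this notebook; it assumes a crude bag-of-words vector as a stand-in for an embedding, and the document text and both queries are made up for illustration. It only shows why a query that shares no vocabulary with the knowledge base scores poorly, while an expanded query scores much higher.

```python
import math
from collections import Counter


def bag_of_words(text: str) -> Counter:
    """Lowercase, whitespace-tokenized term counts (a crude stand-in for an embedding)."""
    return Counter(text.lower().split())


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse term-count vectors."""
    dot = sum(a[t] * b[t] for t in a.keys() & b.keys())
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


# Hypothetical knowledge-base snippet and queries, for illustration only.
doc = bag_of_words("family friendly theme park with rides for kids")
vague_query = bag_of_words("somewhere fun")
expanded_query = bag_of_words("fun family friendly theme park for kids")

vague_score = cosine(vague_query, doc)
expanded_score = cosine(expanded_query, doc)
print(f"vague: {vague_score:.2f}, expanded: {expanded_score:.2f}")
```

Real embedding models are far more tolerant of wording differences than this lexical toy, but the same tendency applies: the more of the user's actual intent the query text carries, the easier it is to match the right chunk.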

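### Solution

Two strategies can be used to improve retrieval for such queries:

1. **Query Rewriting**
   Turn a vague question into a clearer, more specific one so that its meaning is expressed more fully.

2. **Multi-step Querying**
   Break a complex question into several sub-tasks, retrieve for each step, then combine the answers to improve coverage and accuracy.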

## 2. Optimization Techniques
### 2.1 Rewrite the question before retrieval (Query Rewriting)

Install dependencies

In [None]:
%pip install langchain langchain-community langchain-text-splitters langgraph faiss-cpu transformers torch accelerate sentence-transformers dashscope "unstructured[md]"

Load the documents and build the FAISS vector store

In [4]:
import logging
from typing import List

from langchain_community.document_loaders import DirectoryLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings

from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.output_parsers import StrOutputParser


# -----------------------------
# Logging
# -----------------------------
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)

# -----------------------------
# Configuration
# -----------------------------
DOCUMENT_PATH = "./data/"
FILE_PATTERN = "*.md"
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 200
TOP_K = 5

# Hugging Face embedding model (open, local)
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

# Local Hugging Face open LLM
# Good choices:
# - mistralai/Mistral-7B-Instruct-v0.2
# - Qwen/Qwen2-7B-Instruct
HF_LLM_MODEL_ID = "Qwen/Qwen2-0.5B-Instruct"

TEMPERATURE = 0.7
MAX_NEW_TOKENS = 512


def _format_docs(docs: List) -> str:
    """Format retrieved Documents into a single context string."""
    return "\n\n".join(d.page_content for d in docs)


def build_embeddings():
    logging.info(f"Loading HF embeddings: {EMBEDDING_MODEL}")
    return HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL,
        encode_kwargs={"normalize_embeddings": True},
    )


def build_local_hf_llm():
    logging.info(f"Loading local HF LLM: {HF_LLM_MODEL_ID}")

    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
    from langchain_community.llms import HuggingFacePipeline

    use_cuda = torch.cuda.is_available()
    use_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()

    tokenizer = AutoTokenizer.from_pretrained(HF_LLM_MODEL_ID, use_fast=True)

    if use_cuda:
        logging.info("Using CUDA GPU")
        model = AutoModelForCausalLM.from_pretrained(
            HF_LLM_MODEL_ID,
            device_map="auto",
            torch_dtype=torch.float16,
        )
    else:
        dtype = torch.float16 if use_mps else torch.float32
        model = AutoModelForCausalLM.from_pretrained(
            HF_LLM_MODEL_ID,
            torch_dtype=dtype,
        )
        if use_mps:
            logging.info("Using Apple MPS")
            model = model.to("mps")
        else:
            logging.info("Using CPU (slow)")

    gen_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=MAX_NEW_TOKENS,
        temperature=TEMPERATURE,
        do_sample=True,
        repetition_penalty=1.05,
        return_full_text=False,
    )

    return HuggingFacePipeline(pipeline=gen_pipeline)

Baseline retrieval (no optimization)

In [5]:
# 1) Load documents
logging.info(f"Loading documents from {DOCUMENT_PATH} (pattern: {FILE_PATTERN})...")
loader = DirectoryLoader(DOCUMENT_PATH, glob=FILE_PATTERN)
docs = loader.load()
logging.info(f"Loaded {len(docs)} documents")

2026-01-05 21:31:32,250 - INFO - Loading documents from ./data/ (pattern: *.md)...
2026-01-05 21:31:33,347 - INFO - Loaded 5 documents


In [6]:
# 2) Split documents
logging.info("Splitting documents into chunks...")
splitter = RecursiveCharacterTextSplitter(
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
)
chunks = splitter.split_documents(docs)
logging.info(f"Created {len(chunks)} chunks")

2026-01-05 21:31:34,027 - INFO - Splitting documents into chunks...
2026-01-05 21:31:34,028 - INFO - Created 5 chunks


In [7]:
# 3) Embeddings
embeddings = build_embeddings()

2026-01-05 21:31:35,159 - INFO - Loading HF embeddings: sentence-transformers/all-MiniLM-L6-v2
2026-01-05 21:31:35,175 - INFO - Use pytorch device_name: mps
2026-01-05 21:31:35,175 - INFO - Load pretrained SentenceTransformer: sentence-transformers/all-MiniLM-L6-v2


In [8]:
# 4) Vector store + retriever
logging.info("Building FAISS vector store...")
vectorstore = FAISS.from_documents(chunks, embeddings)
retriever = vectorstore.as_retriever(search_kwargs={"k": TOP_K})

2026-01-05 21:31:37,948 - INFO - Building FAISS vector store...
2026-01-05 21:31:38,260 - INFO - Loading faiss.
2026-01-05 21:31:38,317 - INFO - Successfully loaded faiss.


In [9]:
# 5) Local HF LLM
llm = build_local_hf_llm()

2026-01-05 21:31:39,566 - INFO - Loading local HF LLM: Qwen/Qwen2-0.5B-Instruct
2026-01-05 21:31:42,584 - INFO - Using Apple MPS
Device set to use mps:0


In [10]:
# 6. Runnable RAG chain (LCEL)
prompt = PromptTemplate.from_template(
"You are a helpful assistant.\n"
"Answer the question using ONLY the context below.\n"
"If the context is insufficient, say you don't know.\n\n"
"Context:\n{context}\n\n"
"Question:\n{question}\n\n"
"Answer:"
)

In [11]:
rag_chain = (
    {
        "context": retriever | RunnableLambda(_format_docs),
        "question": RunnablePassthrough(),
    }
    | prompt
    | llm
    | StrOutputParser()
)

logging.info("RAG system ready. Type a question (or Ctrl+C to exit).")

2026-01-05 21:31:45,517 - INFO - RAG system ready. Type a question (or Ctrl+C to exit).


In [12]:
# Run a single vague query through the baseline chain (no rewriting)
q = "I want to go somewhere fun."
logging.info(f"Running query: {q}")
ans = rag_chain.invoke(q)
print("\n--- Answer ---")
print(ans)

2026-01-05 21:31:49,718 - INFO - Running query: I want to go somewhere fun.



--- Answer ---
 The answer is New York City. It has many attractions and places to visit, such as the Statue of Liberty, Central Park, and the Empire State Building. You can also enjoy delicious food in various restaurants. New York City is a great place for a trip, so it's worth visiting.


#### 2.1.1 Rewrite the question with an LLM (richer semantics)
Step 1: First define a prompt template that guides the LLM to rewrite the user's question:

In [13]:
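from langchain_core.prompts import PromptTemplate

# Use single braces for the template variable. Double braces are an escape for
# literal braces, so with them the user's question is never substituted.
rewrite_prompt = PromptTemplate.from_template(
"""
You are a travel assistant. Please rewrite the following user question into a
clearer and more complete form.

Original question:
{question}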

Requirements for the rewritten question:
- Be more specific
- Include the type of travel (family / couple / road trip, etc.)
- Help the system retrieve relevant information more accurately

Output format:
[Rewritten] - <your rewritten question>
"""
)


Step 2: Run the rewrite

In [14]:
import logging
from langchain_core.output_parsers import StrOutputParser

# Initialize the rewrite chain (Runnable / LCEL)
def create_rewrite_chain(llm):
    # rewrite_prompt should be a PromptTemplate (from langchain_core.prompts)
    return rewrite_prompt | llm | StrOutputParser()

# Execute question rewriting
def rewrite_question(rewrite_chain, question: str) -> str:
    logging.info(f"Rewriting question: {question}")
    try:
        # LCEL uses invoke() with a dict that matches prompt variables
        rewritten_question = rewrite_chain.invoke({"question": question}).strip()
        if not rewritten_question:
            logging.warning("Rewritten question is empty; falling back to the original question")
            return question
        return rewritten_question
    except Exception as e:
        logging.error(f"Error rewriting question: {e}")
        # Fall back to the original question so retrieval never runs on an error message
        return question


# Example usage
rewrite_chain = create_rewrite_chain(llm)
test_question = "I want to go somewhere fun?"
rewritten = rewrite_question(rewrite_chain, test_question)
print(f"Rewritten question: {rewritten}")


2026-01-05 21:32:05,147 - INFO - Rewriting question: I want to go somewhere fun?


Rewritten question: - <type of travel>
- [relevant information retrieved from the system]

Example output for "road trip":
Rewritten: - <your rewritten question>
- Road trip
- Relevant information retrieved from the system:

Family
- {user's family members}
- {user's travel plans}
- {user's preferred mode of transportation}


In [15]:
answer = rag_chain.invoke(rewritten)
print(f"Question: {rewritten}\nAnswer: {answer}")

Question: - <type of travel>
- [relevant information retrieved from the system]

Example output for "road trip":
Rewritten: - <your rewritten question>
- Road trip
- Relevant information retrieved from the system:

Family
- {user's family members}
- {user's travel plans}
- {user's preferred mode of transportation}
Answer:  Road trip


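With a well-formed rewrite, retrieval focuses more on family-travel content, which improves accuracy. In the run recorded above that did not happen: the 0.5B model returned a template-like outline instead of a rewritten question, so the downstream answer ("Road trip") is not meaningful. Before relying on this step, check that the rewritten text actually restates the user's question.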
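
Because small models often ignore the requested output format, it can help to extract the rewritten question defensively before passing it to retrieval. The following is a minimal sketch, assuming the `[Rewritten] - ...` format requested by the prompt; the function name `extract_rewritten` is hypothetical and not part of LangChain.

```python
import re


def extract_rewritten(raw_output: str, original_question: str) -> str:
    """Return the text after '[Rewritten] -' if present, else the original question."""
    match = re.search(r"\[Rewritten\]\s*-\s*(.+)", raw_output)
    if not match:
        return original_question
    candidate = match.group(1).strip()
    # Reject an echoed placeholder such as "<your rewritten question>".
    if not candidate or candidate.startswith("<"):
        return original_question
    return candidate


good = extract_rewritten(
    "[Rewritten] - Which theme parks suit a family trip with young kids?",
    "I want to go somewhere fun?",
)
bad = extract_rewritten(
    "- <type of travel>\n- [relevant information retrieved from the system]",
    "I want to go somewhere fun?",
)
print(good)
print(bad)
```

Falling back to the original question keeps the pipeline no worse than the baseline when the rewrite fails, instead of retrieving on malformed text.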

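### 2.2 Multi-step Querying

For complex questions that involve several needs, a single retrieval pass often cannot cover every aspect. In that case, split the question into sub-questions, retrieve for each one, and then combine the results.

Example scenario: find attractions suitable for a family trip, plus good food nearby.

The processing flow is:
1. First retrieve "child-friendly attractions";
2. Then, based on those attractions, retrieve "recommended restaurants nearby";
3. Finally, merge the two retrieval results into one complete answer.

Staged retrieval of this kind can improve the accuracy and coverage of the results, especially for compound queries that span several dimensions.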
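
Before looking at an agent-based version, the three steps can be sketched as a fixed pipeline with no LLM planner. This is only an illustration under stated assumptions: the two dictionaries are a hypothetical in-memory stand-in for the retriever, and all place names are invented.

```python
from typing import Dict, List

# Hypothetical in-memory "knowledge base"; a real system would call the retriever instead.
ATTRACTIONS: Dict[str, List[str]] = {
    "child-friendly": ["City Zoo", "Science Museum"],
}
RESTAURANTS_NEAR: Dict[str, List[str]] = {
    "City Zoo": ["Panda Cafe"],
    "Science Museum": ["Orbit Diner", "Noodle House"],
}


def find_attractions(tag: str) -> List[str]:
    """Step 1: retrieve attractions matching a tag."""
    return ATTRACTIONS.get(tag, [])


def find_restaurants(attraction: str) -> List[str]:
    """Step 2: retrieve restaurants near one attraction found in step 1."""
    return RESTAURANTS_NEAR.get(attraction, [])


def multi_step_search(tag: str) -> Dict[str, List[str]]:
    """Step 3: merge both retrieval results into one structure."""
    return {name: find_restaurants(name) for name in find_attractions(tag)}


plan = multi_step_search("child-friendly")
for attraction, restaurants in plan.items():
    print(f"{attraction}: {', '.join(restaurants) or 'no restaurants found'}")
```

A fixed pipeline like this is predictable and easy to test; an agent adds flexibility when the order or number of steps is not known in advance.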

In [16]:
import json
import logging
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, TypedDict

from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

from langgraph.graph import StateGraph, END


# -----------------------------
# Logging
# -----------------------------
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)

# -----------------------------
# Tools (built on top of your rag_chain)
# -----------------------------
def search_child_friendly_attractions(rag_chain, query: str) -> str:
    """Tool: search child-friendly attractions."""
    logging.info(f"[Tool] Searching child-friendly attractions: {query}")
    try:
        tool_query = (
            "Find child-friendly attractions relevant to the user's request. "
            "Return a short list with brief reasons.\n\n"
            f"User request: {query}"
        )
        result = rag_chain.invoke(tool_query)
        result = result.strip() if isinstance(result, str) else str(result).strip()
        return result or "No relevant child-friendly attractions found."
    except Exception as e:
        logging.error(f"[Tool] Error in child-friendly attractions search: {e}")
        return "An error occurred while searching for child-friendly attractions."


def search_nearby_restaurants(rag_chain, query: str) -> str:
    """Tool: search nearby restaurants."""
    logging.info(f"[Tool] Searching nearby restaurants: {query}")
    try:
        tool_query = (
            "Find recommended restaurants near the mentioned attractions/areas. "
            "Return a short list with brief reasons.\n\n"
            f"User request: {query}"
        )
        result = rag_chain.invoke(tool_query)
        result = result.strip() if isinstance(result, str) else str(result).strip()
        return result or "No nearby recommended restaurants found."
    except Exception as e:
        logging.error(f"[Tool] Error in nearby restaurants search: {e}")
        return "An error occurred while searching for nearby restaurants."


# -----------------------------
# LangGraph State
# -----------------------------
class AgentState(TypedDict, total=False):
    messages: List[BaseMessage]
    scratchpad: str
    last_action: Optional[Dict[str, Any]]
    final_answer: Optional[str]


# -----------------------------
# Planner prompt (ReAct-like, JSON actions)
# -----------------------------
PLANNER_PROMPT = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a travel assistant agent.\n"
            "You can use tools to gather information, then write a final answer.\n\n"
            "Available tools:\n"
            "1) SearchChildFriendlyAttractions(query: str)\n"
            "   - Use to find child-friendly attractions.\n"
            "2) SearchNearbyRestaurants(query: str)\n"
            "   - Use to find restaurants near attractions/areas.\n\n"
            "Rules:\n"
            "- Decide the next step.\n"
            "- Output ONLY a JSON object (no markdown, no extra text).\n"
            "- If you need a tool, output:\n"
            "  {{\"action\": \"tool\", \"tool_name\": \"SearchChildFriendlyAttractions\", \"tool_input\": \"...\"}}\n"
            "  OR\n"
            "  {{\"action\": \"tool\", \"tool_name\": \"SearchNearbyRestaurants\", \"tool_input\": \"...\"}}\n"
            "- If you are ready to answer, output:\n"
            "  {{\"action\": \"final\", \"answer\": \"...\"}}\n"
        ),
        ("human", "User request:\n{user_query}\n\nScratchpad so far:\n{scratchpad}\n"),
    ]
)


def _safe_parse_json(text: str) -> Dict[str, Any]:
    """
    Parse a JSON object from the LLM output robustly.
    Accepts raw JSON or text containing a JSON object.
    """
    text = text.strip()

    # If model accidentally outputs extra text, try to extract {...}
    m = re.search(r"\{.*\}", text, flags=re.DOTALL)
    if m:
        text = m.group(0).strip()

    return json.loads(text)


# -----------------------------
# Graph nodes
# -----------------------------
@dataclass
class GraphContext:
    llm: Any
    rag_chain: Any


def planner_node(state: AgentState, ctx: GraphContext) -> AgentState:
    """Decide next action: call tool or finalize."""
    # Get the original user query from the first human message
    user_query = ""
    for msg in state.get("messages", []):
        if isinstance(msg, HumanMessage):
            user_query = msg.content
            break

    scratchpad = state.get("scratchpad", "").strip() or "(empty)"

    chain = PLANNER_PROMPT | ctx.llm | StrOutputParser()
    raw = chain.invoke({"user_query": user_query, "scratchpad": scratchpad})

    logging.info(f"[Planner raw]\n{raw}")

    try:
        action = _safe_parse_json(raw)
    except Exception as e:
        logging.error(f"[Planner] Failed to parse JSON. Error: {e}")
        # Fallback: stop with a safe final message
        action = {
            "action": "final",
            "answer": "Sorry — I couldn't decide the next step due to a formatting issue.",
        }

    state["last_action"] = action
    state.setdefault("messages", []).append(AIMessage(content=raw))
    return state


def tool_node(state: AgentState, ctx: GraphContext) -> AgentState:
    """Execute the selected tool and append observation to scratchpad."""
    action = state.get("last_action") or {}
    tool_name = action.get("tool_name")
    tool_input = action.get("tool_input", "")

    if not tool_name:
        # Nothing to do
        return state

    if tool_name == "SearchChildFriendlyAttractions":
        observation = search_child_friendly_attractions(ctx.rag_chain, tool_input)
    elif tool_name == "SearchNearbyRestaurants":
        observation = search_nearby_restaurants(ctx.rag_chain, tool_input)
    else:
        observation = f"Unknown tool: {tool_name}"

    # Update scratchpad
    scratch = state.get("scratchpad", "")
    scratch += (
        f"\n\n[Action] {tool_name}\n"
        f"[Input] {tool_input}\n"
        f"[Observation]\n{observation}\n"
    )
    state["scratchpad"] = scratch.strip()

    # Also append to messages (optional, but useful for debugging)
    state.setdefault("messages", []).append(
        AIMessage(content=f"TOOL_OBSERVATION({tool_name}):\n{observation}")
    )

    return state


def finalizer_node(state: AgentState, ctx: GraphContext) -> AgentState:
    """Store final answer."""
    action = state.get("last_action") or {}
    state["final_answer"] = action.get("answer", "")
    return state


def route_after_planner(state: AgentState) -> str:
    """Conditional routing based on planner decision."""
    action = state.get("last_action") or {}
    if action.get("action") == "tool":
        return "tool"
    return "final"


# -----------------------------
# Build graph
# -----------------------------
def build_agent_graph(ctx: GraphContext):
    g = StateGraph(AgentState)

    g.add_node("planner", lambda s: planner_node(s, ctx))
    g.add_node("tool", lambda s: tool_node(s, ctx))
    g.add_node("final", lambda s: finalizer_node(s, ctx))

    g.set_entry_point("planner")
    g.add_conditional_edges("planner", route_after_planner, {"tool": "tool", "final": "final"})
    g.add_edge("tool", "planner")
    g.add_edge("final", END)

    return g.compile()


# -----------------------------
# Run multi-step agent
# -----------------------------
def run_multi_step_search(agent_app, query: str) -> str:
    init_state: AgentState = {
        "messages": [HumanMessage(content=query)],
        "scratchpad": "",
    }
    out = agent_app.invoke(init_state)
    answer = out.get("final_answer") or ""
    return answer.strip()


# -----------------------------
# Example
# -----------------------------
if __name__ == "__main__":
    # You must provide:
    # - llm: your local HF LLM wrapper (e.g., HuggingFacePipeline for Qwen2-0.5B)
    # - rag_chain: your runnable RAG pipeline (retriever + prompt + llm + parser)
    #
    # If you already built them earlier in the notebook/script, just ensure they exist here.
    #
    # Example expectation:
    # ctx = GraphContext(llm=llm, rag_chain=rag_chain)

    ctx = GraphContext(llm=llm, rag_chain=rag_chain)
    agent_app = build_agent_graph(ctx)

    multi_step_query = (
        "I want a place suitable for visiting with kids, and preferably there are good restaurants nearby.\n"
        "First find child-friendly attractions, then find nearby dining recommendations."
    )

    answer = run_multi_step_search(agent_app, multi_step_query)
    print("\n=== Multi-step Agent Answer ===")
    print(answer)


2026-01-05 21:32:17,823 - INFO - [Planner raw]
What would you like to do next?

Assistant: {"action": "final", "answer": "I found some child-friendly attractions near your location. Let's check them out first."}



=== Multi-step Agent Answer ===
I found some child-friendly attractions near your location. Let's check them out first.


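Advantage: staged retrieval improves accuracy and coverage. Note that in the run recorded above, the 0.5B planner emitted a `final` action immediately without calling either tool, so no multi-step retrieval actually took place; this pattern likely needs a model that follows instructions more reliably.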
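
Since the planner's JSON comes from an LLM, it may be worth validating each action before routing on it, and capping the number of tool calls so the tool-to-planner loop cannot run indefinitely. The sketch below is one possible guard, not part of LangGraph; `validate_action`, `steps_taken`, and `max_steps` are hypothetical names introduced for illustration.

```python
from typing import Any, Dict

ALLOWED_TOOLS = {"SearchChildFriendlyAttractions", "SearchNearbyRestaurants"}


def validate_action(action: Dict[str, Any], steps_taken: int, max_steps: int = 4) -> Dict[str, Any]:
    """Normalize a planner action; force a final answer on bad output or too many steps."""
    if steps_taken >= max_steps:
        return {"action": "final", "answer": "Stopped: step limit reached."}
    kind = action.get("action")
    if kind == "tool" and action.get("tool_name") in ALLOWED_TOOLS:
        return {
            "action": "tool",
            "tool_name": action["tool_name"],
            "tool_input": str(action.get("tool_input", "")),
        }
    if kind == "final" and isinstance(action.get("answer"), str):
        return {"action": "final", "answer": action["answer"]}
    return {"action": "final", "answer": "Stopped: planner output was not a valid action."}


ok = validate_action(
    {"action": "tool", "tool_name": "SearchNearbyRestaurants", "tool_input": "near the zoo"},
    steps_taken=1,
)
unknown = validate_action({"action": "tool", "tool_name": "BookFlight"}, steps_taken=0)
capped = validate_action({"action": "tool", "tool_name": "SearchNearbyRestaurants"}, steps_taken=4)
print(ok["action"], unknown["action"], capped["action"])
```

In the graph above, such a check could run inside `planner_node` right after JSON parsing, with a step counter kept in the state.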

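### 2.3 Comparing and evaluating the results

To verify that the optimization strategies work, build a small ground-truth dataset and compare the results using simple evaluation metrics.

Example evaluation procedure:
- Build a dataset containing the original question, the rewritten question, and the expected answer;
- Check whether the generated answer contains the correct answer;
- Compute metrics such as accuracy or recall.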
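
The simplest form of the second and third steps is a string-containment check. The sketch below assumes the expected answer can be expressed as a list of key facts that must appear verbatim in the generated answer; the helper names and sample data are hypothetical. It is a cheaper and stricter alternative to an LLM judge, not a replacement for it.

```python
from typing import List, Tuple


def normalize(text: str) -> str:
    """Lowercase and collapse whitespace so matching ignores formatting differences."""
    return " ".join(text.lower().split())


def contains_all(answer: str, expected_facts: List[str]) -> bool:
    """True if every expected fact appears verbatim (after normalization) in the answer."""
    normalized = normalize(answer)
    return all(normalize(fact) in normalized for fact in expected_facts)


def accuracy(results: List[Tuple[str, List[str]]]) -> float:
    """Fraction of (answer, expected_facts) pairs where all facts are present."""
    if not results:
        return 0.0
    return sum(contains_all(ans, facts) for ans, facts in results) / len(results)


# Hypothetical answers and expected facts, for illustration only.
samples = [
    (
        "Try Shanghai Disneyland or Beijing Universal Studios.",
        ["Shanghai Disneyland", "Beijing Universal Studios"],
    ),
    ("New York City has many attractions.", ["Shanghai Disneyland"]),
]
print(f"accuracy = {accuracy(samples):.2f}")
```

Containment misses paraphrases, so it tends to undercount correct answers; it is most useful as a sanity check alongside a judge model.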

In [17]:
import re
import json
import logging
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

# -----------------------------
# Logging
# -----------------------------
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)


# -----------------------------
# Dataset
# -----------------------------
@dataclass
class EvalExample:
    original_q: str
    rewritten_q: str
    ground_truth_answer: str

    # For retrieval eval:
    # Provide gold doc ids that SHOULD appear in top-K retrieval results.
    # These should match some metadata field in retrieved docs, e.g. doc.metadata["source"] or ["doc_id"].
    gold_doc_ids: List[str]


class GroundTruthDataset:
    def __init__(self):
        self.data: List[EvalExample] = []

    def add_example(
        self,
        original_q: str,
        rewritten_q: str,
        ground_truth_answer: str,
        gold_doc_ids: List[str],
    ):
        self.data.append(
            EvalExample(
                original_q=original_q,
                rewritten_q=rewritten_q,
                ground_truth_answer=ground_truth_answer,
                gold_doc_ids=gold_doc_ids,
            )
        )
        logging.info("Added a new evaluation example")

    def get_all_examples(self) -> List[EvalExample]:
        return self.data


# -----------------------------
# Retrieval metrics
# -----------------------------
def recall_at_k(
    retrieved_doc_ids: List[str],
    gold_doc_ids: List[str],
    k: int,
) -> float:
    """
    Recall@K = (# of gold docs retrieved in top-K) / (# of gold docs)

    If you provide gold_doc_ids as sources/files, this becomes: did we retrieve the right sources?
    """
    if not gold_doc_ids:
        return 0.0
    topk = set(retrieved_doc_ids[:k])
    gold = set(gold_doc_ids)
    hit = len(topk.intersection(gold))
    return hit / len(gold)


def get_retrieved_doc_ids(
    docs: List[Any],
    id_key: str = "source",
) -> List[str]:
    """
    Extract doc identifiers from Document.metadata[id_key].
    Common choices:
      - id_key="source" (often file path/name)
      - id_key="doc_id" (a stable chunk id you set)
    """
    out = []
    for d in docs:
        md = getattr(d, "metadata", {}) or {}
        out.append(str(md.get(id_key, "")))
    return out


# -----------------------------
# LLM-as-a-judge (Runnable/LCEL)
# -----------------------------
JUDGE_PROMPT = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a strict evaluator (judge) for a RAG QA system.\n"
            "Given a question, a candidate answer, and a reference (ground-truth) answer,\n"
            "decide if the candidate answer is correct.\n\n"
            "Return ONLY valid JSON with keys:\n"
            '  "verdict": "correct" | "incorrect"\n'
            '  "score": number between 0 and 1\n'
            '  "rationale": short explanation\n\n'
            "Be conservative: if the candidate misses key facts from the reference, mark incorrect.\n"
            "Do not output any extra text besides JSON."
        ),
        (
            "human",
            "Question:\n{question}\n\n"
            "Candidate Answer:\n{answer}\n\n"
            "Reference Answer:\n{reference}\n",
        ),
    ]
)


def build_judge_chain(judge_llm):
    # LCEL chain returns a JSON string -> we parse it ourselves robustly
    return JUDGE_PROMPT | judge_llm | StrOutputParser()


def safe_json_loads(s: str) -> Dict[str, Any]:
    s = s.strip()
    # try to extract first {...} if the model adds extra text
    start = s.find("{")
    end = s.rfind("}")
    if start != -1 and end != -1 and end > start:
        s = s[start : end + 1]
    return json.loads(s)

def _extract_json_object(text: str) -> Optional[str]:
    text = text.strip()
    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end != -1 and end > start:
        return text[start:end+1]
    return None

def _parse_non_json_judge(text: str) -> dict:
    t = text.strip()

    # Look for "Verdict: Correct/Incorrect"
    verdict_match = re.search(r"verdict\s*:\s*(correct|incorrect)", t, flags=re.I)
    verdict = verdict_match.group(1).lower() if verdict_match else None

    # If no explicit verdict line, infer from keywords
    if verdict is None:
        if re.search(r"\bincorrect\b", t, flags=re.I):
            verdict = "incorrect"
        elif re.search(r"\bcorrect\b", t, flags=re.I):
            verdict = "correct"
        else:
            verdict = "incorrect"

    # Extract rationale if present
    rationale_match = re.search(r"rationale\s*:\s*(.*)", t, flags=re.I | re.S)
    rationale = rationale_match.group(1).strip() if rationale_match else t[:300]

    # Score heuristic for POC
    score = 1.0 if verdict == "correct" else 0.0

    return {"verdict": verdict, "score": score, "rationale": rationale}

def judge_answer(judge_chain, question: str, answer: str, reference: str) -> dict:
    raw = judge_chain.invoke({"question": question, "answer": answer, "reference": reference})

    # 1) Try JSON first
    try:
        js = _extract_json_object(raw)
        if js:
            j = json.loads(js)
            return {
                "verdict": str(j.get("verdict", "incorrect")).lower(),
                "score": float(j.get("score", 0.0)),
                "rationale": str(j.get("rationale", "")).strip(),
                "raw": raw,
            }
    except Exception as e:
        logging.error(f"Judge JSON parse failed: {e}; raw={raw[:300]}")

    # 2) Fallback: parse non-JSON
    parsed = _parse_non_json_judge(raw)
    parsed["raw"] = raw
    return parsed

# -----------------------------
# End-to-end evaluator
# -----------------------------
class RAGEvaluator:
    def __init__(
        self,
        retriever: Any,
        answer_generator: Callable[[str], str],
        judge_chain: Any,
        *,
        id_key: str = "source",
        k_list: List[int] = [3, 5, 10],
        judge_correct_threshold: float = 0.5,
    ):
        """
        retriever: must support get_relevant_documents(query) OR invoke(query)
        answer_generator: function(query)->answer (e.g., rag_chain.invoke)
        judge_chain: built from build_judge_chain(judge_llm)
        id_key: metadata field used as document ID for retrieval metrics
        k_list: compute Recall@K for each K
        judge_correct_threshold: score >= threshold => count as correct
        """
        self.retriever = retriever
        self.answer_generator = answer_generator
        self.judge_chain = judge_chain
        self.id_key = id_key
        self.k_list = k_list
        self.judge_correct_threshold = judge_correct_threshold

    def _retrieve_docs(self, query: str) -> List[Any]:
        # Support both old and new retriever interfaces
        if hasattr(self.retriever, "invoke"):
            return self.retriever.invoke(query)
        if hasattr(self.retriever, "get_relevant_documents"):
            return self.retriever.get_relevant_documents(query)
        raise TypeError("Retriever must support .invoke(query) or .get_relevant_documents(query)")

    def evaluate_dataset(self, dataset: GroundTruthDataset) -> Dict[str, Any]:
        examples = dataset.get_all_examples()
        total = len(examples)

        judge_correct = 0
        recall_sums = {k: 0.0 for k in self.k_list}
        per_example = []

        for ex in examples:
            q = ex.rewritten_q

            # 1) Retrieval
            docs = self._retrieve_docs(q)
            retrieved_ids = get_retrieved_doc_ids(docs, id_key=self.id_key)

            # 2) Compute Recall@K
            recall_k = {}
            for k in self.k_list:
                r = recall_at_k(retrieved_ids, ex.gold_doc_ids, k=k)
                recall_k[k] = r
                recall_sums[k] += r

            # 3) Generate answer
            answer = self.answer_generator(q)

            # 4) LLM-as-judge
            judged = judge_answer(
                self.judge_chain,
                question=ex.original_q,
                answer=answer,
                reference=ex.ground_truth_answer,
            )
            is_correct = (judged["verdict"] == "correct") and (judged["score"] >= self.judge_correct_threshold)
            judge_correct += int(is_correct)

            logging.info(
                f"Q: {ex.original_q}\n"
                f"Rewritten: {ex.rewritten_q}\n"
                f"Judge: {judged['verdict']} (score={judged['score']:.2f})\n"
                f"Recall@{self.k_list}: {', '.join([f'{k}={recall_k[k]:.2f}' for k in self.k_list])}"
            )

            per_example.append(
                {
                    "original_q": ex.original_q,
                    "rewritten_q": ex.rewritten_q,
                    "gold_doc_ids": ex.gold_doc_ids,
                    "retrieved_doc_ids": retrieved_ids,
                    "recall_at_k": recall_k,
                    "answer": answer,
                    "judge": judged,
                    "is_correct": is_correct,
                }
            )

        judge_accuracy = judge_correct / total if total else 0.0
        avg_recall = {k: (recall_sums[k] / total if total else 0.0) for k in self.k_list}

        return {
            "total": total,
            "judge_correct": judge_correct,
            "judge_accuracy": judge_accuracy,
            "avg_recall_at_k": avg_recall,
            "examples": per_example,
        }

    @staticmethod
    def generate_report(metrics: Dict[str, Any]) -> str:
        lines = []
        lines.append("Evaluation Report")
        lines.append(f"- Total examples: {metrics['total']}")
        lines.append(f"- Judge accuracy: {metrics['judge_accuracy']:.2%} ({metrics['judge_correct']}/{metrics['total']})")
        lines.append("- Average Recall@K:")
        for k, v in metrics["avg_recall_at_k"].items():
            lines.append(f"  - Recall@{k}: {v:.2%}")
        return "\n".join(lines)


In [18]:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_community.llms import HuggingFacePipeline
import torch

def build_judge_llm():
    model_id = "Qwen/Qwen2-0.5B-Instruct"

    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )

    judge_pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=256,
        temperature=0.0,      # IMPORTANT: deterministic for judging
        do_sample=False,
        return_full_text=False,
    )

    return HuggingFacePipeline(pipeline=judge_pipe)

judge_llm = build_judge_llm()

Device set to use mps


In [19]:
# 1) Build the judge chain
judge_chain = build_judge_chain(judge_llm)

# 2) Define the answer generator
def answer_generator(q: str) -> str:
    return rag_chain.invoke(q)

# 3) Build the dataset
# IMPORTANT: gold_doc_ids must match the value in your documents' metadata[id_key]
dataset = GroundTruthDataset()
dataset.add_example(
    original_q="I want to travel with my child",
    rewritten_q="Recommend destinations suitable for family trips and explain the reasons",
    ground_truth_answer=(
        "Recommended two family-friendly destinations: "
        "Shanghai Disneyland and Beijing Universal Studios"
    ),
    gold_doc_ids=["travel.md"],  # Example: if your document metadata["source"] is "travel.md"
)

# 4) Run evaluation
evaluator = RAGEvaluator(
    retriever=retriever,
    answer_generator=answer_generator,
    judge_chain=judge_chain,
    id_key="source",        # Set to "source" or "doc_id" depending on your metadata
    k_list=[3, 5],
    judge_correct_threshold=0.5,
)

metrics = evaluator.evaluate_dataset(dataset)
print(evaluator.generate_report(metrics))


2026-01-05 21:32:50,762 - INFO - Added a new evaluation example
2026-01-05 21:37:26,714 - INFO - Q: I want to travel with my child
Rewritten: Recommend destinations suitable for family trips and explain the reasons
Judge: correct (score=1.00)
Recall@[3, 5]: 3=0.00, 5=0.00


Evaluation Report
- Total examples: 1
- Judge accuracy: 100.00% (1/1)
- Average Recall@K:
  - Recall@3: 0.00%
  - Recall@5: 0.00%


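Evaluation of this kind quantifies how the strategies differ and provides data to guide further optimization. Read the numbers above with care, though: Recall@K is 0% while the judge reports 100%. A recall of zero alongside a plausible answer usually means the `gold_doc_ids` values do not match `metadata["source"]` (which is typically a file path rather than a bare file name), and a 0.5B judge model is too small for its verdicts to be trusted on their own.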
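
To compare strategies rather than score a single one, the same metric can be computed for the original and the rewritten query side by side. The sketch below is an illustration with made-up retrieval results; it also assumes that document IDs should be compared by file name only, which is one way to make `travel.md` match a stored path such as `data/travel.md`.

```python
import os
from typing import Dict, List


def normalize_id(doc_id: str) -> str:
    """Compare by file name only, so 'data/travel.md' matches 'travel.md'."""
    return os.path.basename(doc_id.replace("\\", "/"))


def recall_at_k(retrieved: List[str], gold: List[str], k: int) -> float:
    """Share of gold documents that appear among the top-k retrieved documents."""
    if not gold:
        return 0.0
    top_k = {normalize_id(d) for d in retrieved[:k]}
    gold_set = {normalize_id(g) for g in gold}
    return len(top_k & gold_set) / len(gold_set)


def compare(retrieved_by_variant: Dict[str, List[str]], gold: List[str], k: int) -> Dict[str, float]:
    """Recall@k for each query variant (e.g. 'original' vs 'rewritten')."""
    return {name: recall_at_k(ids, gold, k) for name, ids in retrieved_by_variant.items()}


# Hypothetical retrieval results, for illustration only.
runs = {
    "original": ["data/food.md", "data/hotels.md", "data/visa.md"],
    "rewritten": ["data/travel.md", "data/food.md", "data/hotels.md"],
}
comparison = compare(runs, gold=["travel.md"], k=3)
print(comparison)
```

In the notebook, the two ID lists would come from calling the retriever once with `original_q` and once with `rewritten_q` for each example.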

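### Summary

| Optimization | Effect |
| --- | --- |
| Query rewriting | Improves retrieval accuracy and avoids missing key information |
| Multi-step querying | Makes answers to complex questions more complete |
| Chinese embedding model | Improves vector recall quality for Chinese text |
| Local LLM | Keeps data private and gives more control |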


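## 3. Further Optimization Suggestions

| Method | Benefit |
| --- | --- |
| Multi-variant retrieval + aggregation | Submit several rewritten queries to the vector store at once and aggregate the results |
| HyDE | Use a hypothetical document to assist retrieval |
| Multi-path retrieval + reranking | Score and reorder retrieved results with a rerank module to improve precision |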
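
For the first row, the text does not say how results from several query variants should be aggregated. One common choice is Reciprocal Rank Fusion (RRF), sketched below as an assumption rather than the method this tutorial prescribes. The document IDs are invented, and `k=60` is the constant conventionally used with RRF.

```python
from collections import defaultdict
from typing import Dict, List


def reciprocal_rank_fusion(rankings: List[List[str]], k: int = 60) -> List[str]:
    """Fuse ranked lists: each document scores sum(1 / (k + rank)) over the lists it appears in."""
    scores: Dict[str, float] = defaultdict(float)
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += 1.0 / (k + rank)
    # Sort by fused score (descending); break ties by doc id for a stable order.
    return sorted(scores, key=lambda d: (-scores[d], d))


# Hypothetical results for three rewritten variants of one query.
results_per_variant = [
    ["zoo.md", "museum.md", "beach.md"],
    ["museum.md", "zoo.md"],
    ["museum.md", "park.md"],
]
fused = reciprocal_rank_fusion(results_per_variant)
print(fused)
```

Documents that rank well across several variants rise to the top, which makes the fused list less sensitive to any single poor rewrite.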
