In [None]:
# 对话型 RAG
# https://mp.weixin.qq.com/s/aJzrSobkbxJPEfD9BUd-WA

In [3]:
import os
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain_community.embeddings import DashScopeEmbeddings
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams


load_dotenv()

llm = ChatOpenAI(
    model="qwen3-32b",
    temperature=0.5,
    base_url=os.environ.get("COMPATIBLE_BASE_URL"),
    api_key=os.environ.get("COMPATIBLE_API_KEY"),
    streaming=True,
    extra_body={"enable_thinking": False},
)

embeddings = DashScopeEmbeddings(
    model="text-embedding-v4",
    dashscope_api_key=os.environ.get("OPENAI_API_KEY"),)



client = QdrantClient(host="localhost", port=6333)

vector_store = QdrantVectorStore(
    client=client,
    collection_name="rag_from_scratch",
    embedding=embeddings,
)

retriever = vector_store.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 2},
)

In [4]:
from langgraph.graph import MessagesState, StateGraph


graph_builder = StateGraph(MessagesState)

In [5]:
from langchain_core.tools import tool


@tool(response_format="content_and_artifact")
def retrieve(query: str):
    """Retrieve information related to a query."""
    retrieved_docs = retriever.invoke(query)
    serialized = "\n\n".join(
        (f"Source: {doc.metadata}\nContent: {doc.page_content}")
        for doc in retrieved_docs
    )
    return serialized, retrieved_docs

In [6]:
from langchain_core.messages import SystemMessage
from langgraph.prebuilt import ToolNode


llm_with_tools = llm.bind_tools([retrieve])


# Step 1: Generate an AIMessage that may include a tool-call to be sent.
def query_or_respond(state: MessagesState):
    """Generate tool call for retrieval or respond."""
    response = llm_with_tools.invoke(state["messages"])
    # MessagesState appends messages to state instead of overwriting
    return {"messages": [response]}


# Step 2: Execute the retrieval.
tools = ToolNode([retrieve])


# Step 3: Generate a response using the retrieved content.
def generate(state: MessagesState):
    """Generate answer."""
    # Get generated ToolMessages
    recent_tool_messages = []
    for message in reversed(state["messages"]):
        if message.type == "tool":
            recent_tool_messages.append(message)
        else:
            break
    tool_messages = recent_tool_messages[::-1]

    # Format into prompt
    docs_content = "\n\n".join(doc.content for doc in tool_messages)
    system_message_content = (
        "You are an assistant for question-answering tasks. "
        "Use the following pieces of retrieved context to answer "
        "the question. If you don't know the answer, say that you "
        "don't know. Use three sentences maximum and keep the "
        "answer concise."
        "\n\n"
        f"{docs_content}"
    )
    conversation_messages = [
        message
        for message in state["messages"]
        if message.type in ("human", "system")
        or (message.type == "ai" and not message.tool_calls)
    ]
    prompt = [SystemMessage(system_message_content)] + conversation_messages

    # Run
    response = llm.invoke(prompt)
    return {"messages": [response]}

In [7]:
from langgraph.graph import END
from langgraph.prebuilt import ToolNode, tools_condition

graph_builder.add_node(query_or_respond)
graph_builder.add_node(tools)
graph_builder.add_node(generate)

graph_builder.set_entry_point("query_or_respond")
graph_builder.add_conditional_edges(
    "query_or_respond",
    tools_condition,
    {END: END, "tools": "tools"},
)
graph_builder.add_edge("tools", "generate")
graph_builder.add_edge("generate", END)

graph = graph_builder.compile()

In [9]:
from IPython.display import Image, display

display(Image(graph.get_graph().draw_mermaid_png()))

ValueError: Failed to reach https://mermaid.ink/ API while trying to render your graph after 1 retries. To resolve this issue:
1. Check your internet connection and try again
2. Try with higher retry settings: `draw_mermaid_png(..., max_retries=5, retry_delay=2.0)`
3. Use the Pyppeteer rendering method which will render your graph locally in a browser: `draw_mermaid_png(..., draw_method=MermaidDrawMethod.PYPPETEER)`

In [18]:
input_message = "你好"

for step in graph.stream(
    {"messages": [{"role": "user", "content": input_message}]},
    stream_mode="values",
):
    step["messages"][-1].pretty_print()


你好

你好！有什么我可以帮你的吗？


In [19]:
input_message = "耐克，包括匡威在美国有多少个配送中心？"

for step in graph.stream(
    {"messages": [{"role": "user", "content": input_message}]},
    stream_mode="values",
):
    step["messages"][-1].pretty_print()


耐克，包括匡威在美国有多少个配送中心？
Tool Calls:
  retrieve (call_cdbac307a38c4b2fb2bc0c)
 Call ID: call_cdbac307a38c4b2fb2bc0c
  Args:
    query: 耐克，包括匡威在美国有多少个配送中心？
Name: retrieve

Source: {'producer': 'Wdesk Fidelity Content Translations Version 008.001.016', 'creator': 'Workiva', 'creationdate': '2023-07-20T22:09:22+00:00', 'source': '/Users/gavinyao/Downloads/nke-10k-2023.pdf', 'file_path': '/Users/gavinyao/Downloads/nke-10k-2023.pdf', 'total_pages': 106, 'format': 'PDF 1.7', 'title': 'Nike 2023 Proxy', 'author': 'anonymous', 'subject': '', 'keywords': '', 'moddate': '2023-07-26T15:13:52+08:00', 'trapped': '', 'modDate': "D:20230726151352+08'00'", 'creationDate': 'D:20230720220922Z', 'page': 27, 'start_index': 1064, '_id': '4844d61d-07f0-4d17-9802-aa5b4417fe87', '_collection_name': 'rag_from_scratch'}
Content: In the United States, NIKE has eight significant distribution centers. Five are located in or near Memphis, Tennessee, two of
which are owned and three of which are leased. Two other distri

In [20]:
from langgraph.checkpoint.memory import MemorySaver

memory = MemorySaver()
graph = graph_builder.compile(checkpointer=memory)

In [21]:
# Specify an ID for the thread
config = {"configurable": {"thread_id": "abc123"}}


input_message = "耐克，包括匡威在美国有多少个配送中心？"

for step in graph.stream(
    {"messages": [{"role": "user", "content": input_message}]},
    stream_mode="values",
    config=config,
):
    step["messages"][-1].pretty_print()


耐克，包括匡威在美国有多少个配送中心？
Tool Calls:
  retrieve (call_3a62085e843f482c854e05)
 Call ID: call_3a62085e843f482c854e05
  Args:
    query: 耐克，包括匡威在美国有多少个配送中心？
Name: retrieve

Source: {'producer': 'Wdesk Fidelity Content Translations Version 008.001.016', 'creator': 'Workiva', 'creationdate': '2023-07-20T22:09:22+00:00', 'source': '/Users/gavinyao/Downloads/nke-10k-2023.pdf', 'file_path': '/Users/gavinyao/Downloads/nke-10k-2023.pdf', 'total_pages': 106, 'format': 'PDF 1.7', 'title': 'Nike 2023 Proxy', 'author': 'anonymous', 'subject': '', 'keywords': '', 'moddate': '2023-07-26T15:13:52+08:00', 'trapped': '', 'modDate': "D:20230726151352+08'00'", 'creationDate': 'D:20230720220922Z', 'page': 27, 'start_index': 1064, '_id': '4844d61d-07f0-4d17-9802-aa5b4417fe87', '_collection_name': 'rag_from_scratch'}
Content: In the United States, NIKE has eight significant distribution centers. Five are located in or near Memphis, Tennessee, two of
which are owned and three of which are leased. Two other distri

In [22]:
input_message = "在美国之外有哪些？"

for step in graph.stream(
    {"messages": [{"role": "user", "content": input_message}]},
    stream_mode="values",
    config=config,
):
    step["messages"][-1].pretty_print()


在美国之外有哪些？

耐克在美国之外的重要配送中心包括：

1. 比利时的Laakdal
2. 中国的太仓（Taicang）
3. 日本的富士见泽（Tomisato）
4. 韩国的议政府（Icheon）

这些配送中心均为耐克所拥有。
