In [1]:
import operator
import os
import sys

current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_dir, '..'))
sys.path.append(parent_dir)
from utils.env_util import *
from langgraph_utils.common_util import gen_mermaid
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langchain_openai import ChatOpenAI
from typing import Annotated
from langgraph.types import Send
from pydantic import BaseModel, Field

In [2]:
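# Define the prompts (the commented-out lines below are unused Chinese-language variants of the English prompts that follow).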
# subjects_prompt = """随机生成4个与 {topic} 相关的关键字"""
# joke_prompt = """生成一条关于 {subject} 的笑话"""
# best_joke_prompt = """下面是一些关于 {topic} 的笑话，选择最好的一个，返回其ID（ID 从0开始）。

# {jokes}"""

# Prompt templates, filled with str.format by the graph nodes.
#   subjects_prompt  -> placeholder {topic}
#   joke_prompt      -> placeholder {subject}
#   best_joke_prompt -> placeholders {topic}, {jokes}
# best_joke indexes state["jokes"] with the returned ID, so the prompt states
# explicitly that IDs are zero-based. Without that, a model answering "1" for
# the first joke selects the wrong one, or runs past the end of the list.
subjects_prompt = """Generate a comma separated list of between 2 and 5 examples related to: {topic}."""
joke_prompt = """Generate a joke about {subject}"""
best_joke_prompt = """Below are a bunch of jokes about {topic}. Select the best one! Return the ID of the best one (IDs are zero-based: the first joke has ID 0).

{jokes}"""


# Structured-output schemas. Each class is passed to
# `model.with_structured_output(...)` by one of the graph nodes below.
# NOTE(review): class docstrings would likely be forwarded as part of the
# schema shown to the LLM, so these are documented with `#` comments instead.

# Returned by generate_topics: the subjects derived from the topic.
class Subjects(BaseModel):
    subjects: list[str]

# Returned by generate_joke: a single joke for one subject.
class Joke(BaseModel):
    joke: str

# Returned by best_joke: the zero-based index into state["jokes"].
# Only a lower bound (ge=0) is enforced here; the number of jokes is only
# known at runtime, so no upper bound is declared.
class BestJoke(BaseModel):
    id: int = Field(description="Index of the best joke, starting with 0", ge=0)

# Shared chat model used by every node. Credentials and endpoint come from the
# project's env helpers (utils.env_util); the model id is fixed in this cell.
# NOTE(review): this id differs from the DEFAULT_MODEL printed when env_util
# is imported (Qwen/QwQ-32B) — confirm which one is intended.
# NOTE(review): the nodes rely on with_structured_output, so the endpoint must
# support tool/JSON output for this model — confirm.
model = ChatOpenAI(
    model='deepseek-ai/DeepSeek-R1-Distill-Qwen-7B',
    api_key=get_openai_api_key(),
    base_url=get_openai_base_url(),
)

class OverallState(TypedDict):
    """Graph-wide state shared by generate_topics, continue_to_jokes and best_joke."""

    # Input topic supplied by the caller, e.g. {"topic": "animal"}.
    topic: str
    # Subjects produced by generate_topics; each one is fanned out to generate_joke.
    subjects: list[str]
    # Jokes from every parallel generate_joke run. The operator.add reducer
    # concatenates each run's one-element list instead of overwriting it.
    jokes: Annotated[list[str], operator.add]
    # The joke text chosen by best_joke.
    best_selected_joke: str


# Per-branch state: the payload each generate_joke call receives via Send.
class JokeState(TypedDict):
    """Input for a single generate_joke invocation."""

    # The subject to write a joke about.
    subject: str

# Map step, part 1: expand one topic into several subjects.
def generate_topics(state: OverallState):
    """Ask the model for subjects related to ``state["topic"]``.

    The prompt asks for a *comma separated* list, while the structured-output
    schema is ``list[str]``. A model may therefore return a single element such
    as ``"cats, dogs"``, or blank entries. Every returned element is split on
    commas and stripped, and empty pieces are dropped. Without this, one joke
    would be generated for the whole comma-joined string, or for an empty
    subject.

    Args:
        state: graph state; only ``topic`` is read.

    Returns:
        Partial state update ``{"subjects": [...]}`` with the cleaned subjects.
    """
    prompt = subjects_prompt.format(topic=state["topic"])
    response = model.with_structured_output(Subjects).invoke(prompt)
    subjects = [
        piece.strip()
        for item in response.subjects
        for piece in item.split(",")
        if piece.strip()
    ]
    print(f"⚙️生成主题：{subjects}")
    return {"subjects": subjects}


# Map step, part 2: write one joke for a single subject (runs once per Send).
def generate_joke(state: JokeState):
    """Produce a joke about ``state["subject"]``.

    Returns a one-element ``jokes`` list. The ``operator.add`` reducer declared
    on ``OverallState.jokes`` concatenates the lists from all parallel runs.
    """
    subject_name = state["subject"]
    request = joke_prompt.format(subject=subject_name)
    joke_text = model.with_structured_output(Joke).invoke(request).joke
    print(f"⚙️生成[{subject_name}]笑话：{joke_text}")
    return {"jokes": [joke_text]}


def continue_to_jokes(state: OverallState):
    """Fan out: dispatch every subject to its own ``generate_joke`` run.

    Returns a list of ``Send`` objects, one per entry in ``state["subjects"]``.
    Each ``Send`` names the target node and carries the state that node
    receives (a ``JokeState``-shaped ``{"subject": ...}`` dict).
    """
    dispatches = []
    for subject in state["subjects"]:
        dispatches.append(Send("generate_joke", {"subject": subject}))
    return dispatches


def best_joke(state: OverallState):
    """Reduce step: pick the single best joke out of all generated jokes.

    Each joke is prefixed with its zero-based index, so the ID the model
    returns is the same index used to look the joke up in ``state["jokes"]``.
    Previously the jokes were joined without IDs, and the model had to guess
    the numbering.

    Args:
        state: graph state; reads ``topic`` and the aggregated ``jokes``.

    Returns:
        Partial state update ``{"best_selected_joke": <joke text>}``. If
        ``jokes`` is empty, the value is ``""`` and the model is not called.
        If the model returns an ID outside the list, the first joke is used
        and a warning is printed.
    """
    jokes = state["jokes"]
    if not jokes:
        # Nothing to choose from; skip the LLM call and avoid an IndexError.
        return {"best_selected_joke": ""}

    numbered_jokes = "\n\n".join(f"ID {idx}: {joke}" for idx, joke in enumerate(jokes))
    prompt = best_joke_prompt.format(topic=state["topic"], jokes=numbered_jokes)

    print("=" * 80)
    print(prompt)
    print("=" * 80)

    response = model.with_structured_output(BestJoke).invoke(prompt)
    best_id = response.id
    if not 0 <= best_id < len(jokes):
        # BestJoke only enforces ge=0, so an over-large ID (e.g. a 1-based
        # answer for the last joke) would otherwise raise IndexError here.
        print(f"⚠️ best_joke: model returned out-of-range ID {best_id}; falling back to ID 0")
        best_id = 0
    return {"best_selected_joke": jokes[best_id]}


# Assemble the map-reduce graph:
#   START -> generate_topics -(Send fan-out)-> generate_joke (one per subject)
#         -> best_joke -> END
graph = StateGraph(OverallState)

for node_name, node_fn in (
    ("generate_topics", generate_topics),
    ("generate_joke", generate_joke),
    ("best_joke", best_joke),
):
    graph.add_node(node_name, node_fn)

graph.add_edge(START, "generate_topics")
# continue_to_jokes returns Send objects; the path list declares the only node it targets.
graph.add_conditional_edges("generate_topics", continue_to_jokes, ["generate_joke"])
# Fan-in: every generate_joke branch feeds best_joke, which then finishes the run.
for edge_src, edge_dst in (("generate_joke", "best_joke"), ("best_joke", END)):
    graph.add_edge(edge_src, edge_dst)
app = graph.compile()

# Export the compiled graph as a Mermaid file via the project helper.
gen_mermaid(app, "map_reduce.mmd")

🙈 OPENAI_API_KEY: sk-************************ (redacted)
👀 DEFAULT_MODEL: Qwen/QwQ-32B
👀 OPENAI_BASE_URL: https://api.siliconflow.cn/v1
✏️ 已生成 mermaid 文件 /Users/yuki/codes/pythonProject/Agent/langgraph_demo/resources/map_reduce.mmd


In [None]:
# Run the graph end to end on one topic and print each streamed chunk.
for chunk in app.stream({"topic": "animal"}):
    print(chunk)