# 02. State Management and StateGraph

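## Learning Objectives
- Understand how LangGraph manages state
- Learn how to build graphs with StateGraph
- Use reducer functions to control how updates are merged
- Understand state annotations and the type system
- Implement more complex state-update logic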

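## Core Concepts

StateGraph is the most commonly used graph type in LangGraph. It provides:
1. **A typed state schema** (usually a `TypedDict`; the annotations describe the shape of the state but are not validated at runtime)
2. **Automatic merging** of each node's returned update into the shared state
3. **Built-in message handling** through the `add_messages` reducer
4. **Per-key update strategies** (overwrite by default, or a reducer you choose)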
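The next snippet uses only the standard library and does not depend on LangGraph. It illustrates points 1 and 4. A `TypedDict` instance is an ordinary `dict` at runtime, so a wrong type is accepted without complaint. The reducer for a key is stored as `Annotated` metadata, which is where StateGraph looks for it. `DemoState` is a hypothetical schema written for this illustration. The rule that the *last* metadata item is the reducer belongs to this sketch and is not a statement about LangGraph internals.

```python
import operator
from typing import Annotated, TypedDict, get_args, get_origin, get_type_hints

# Hypothetical schema, used only for this illustration
class DemoState(TypedDict):
    count: int                                    # no reducer: updates overwrite
    messages: Annotated[list[str], operator.add]  # reducer stored as Annotated metadata

# At runtime a TypedDict instance is an ordinary dict; the annotations are not checked
s = DemoState(count="not an int", messages=[])
print(type(s).__name__, s["count"])  # dict not an int

# Read each key's reducer back out of the Annotated metadata
reducers = {}
for key, hint in get_type_hints(DemoState, include_extras=True).items():
    if get_origin(hint) is Annotated:
        reducers[key] = get_args(hint)[-1]  # this sketch takes the last metadata item

print(reducers)
```

If you need runtime validation, add it inside your nodes. A `TypedDict` on its own only helps static type checkers.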

## 1. Environment Setup

```python
# Import the required libraries
from typing import TypedDict, Annotated, Sequence, Optional
import operator
import json
from datetime import datetime

from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages, Messages
from IPython.display import Image, display
```

## 2. StateGraph Basics

### 2.1 Defining the State Schema

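```python
# Basic state definition
class BasicState(TypedDict):
    """The simplest possible state schema"""
    count: int
    messages: list[str]
    status: str

# Create a basic StateGraph
basic_graph = StateGraph(BasicState)

# Node functions receive the current state and return a dict containing only
# the keys they want to change (a partial update), hence the `dict` return type.
def increment_counter(state: BasicState) -> dict:
    """Increment the counter"""
    print(f"Current count: {state.get('count', 0)}")
    return {
        "count": state.get("count", 0) + 1,
        # No reducer on this key, so the node must build the complete new list itself
        "messages": state.get("messages", []) + ["Counter incremented"],
        "status": "processing"
    }

def check_limit(state: BasicState) -> dict:
    """Check whether the count has reached the limit"""
    count = state.get("count", 0)
    if count >= 5:
        status = "limit_reached"
        message = "Limit reached"
    else:
        status = "under_limit"
        message = f"Current count {count}, under the limit"
    
    return {
        "messages": state.get("messages", []) + [message],
        "status": status
    }
```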

```python
# Build the graph
basic_graph.add_node("increment", increment_counter)
basic_graph.add_node("check", check_limit)

# Add the edges
basic_graph.set_entry_point("increment")
basic_graph.add_edge("increment", "check")
basic_graph.add_edge("check", END)

# Compile
basic_app = basic_graph.compile()

# Run with a partial input state
result = basic_app.invoke({"count": 3})
print("\nResult:")
print(json.dumps(result, ensure_ascii=False, indent=2))
```
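`BasicState` declares no reducers, so each key a node returns overwrites the stored value, and keys the node leaves out stay unchanged. The snippet below models this with plain `dict.update`. It is a pure-Python approximation, not LangGraph code, and you can use it to predict the result above. The two functions are simplified copies of the nodes.

```python
# Pure-Python model of "no reducer => overwrite" (not LangGraph code)

def increment_counter(state: dict) -> dict:
    return {
        "count": state.get("count", 0) + 1,
        "messages": state.get("messages", []) + ["Counter incremented"],
        "status": "processing",
    }

def check_limit(state: dict) -> dict:
    count = state.get("count", 0)
    if count >= 5:
        status, message = "limit_reached", "Limit reached"
    else:
        status, message = "under_limit", f"Current count {count}, under the limit"
    return {"messages": state.get("messages", []) + [message], "status": status}

state = {"count": 3}
for node in (increment_counter, check_limit):
    state.update(node(state))  # returned keys overwrite; omitted keys are kept

print(state)
# {'count': 4, 'messages': ['Counter incremented', 'Current count 4, under the limit'], 'status': 'under_limit'}
```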

## 3. Merging State with Reducer Functions

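A reducer defines how the value a node returns for a key is combined with the value already stored under that key. Keys without a reducer are simply overwritten by the latest update. Reducers are one of the core features of StateGraph.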
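The model below is a minimal pure-Python sketch of that rule, not LangGraph's implementation. `apply_update` is a hypothetical helper. It applies a key's reducer when one is registered and a previous value exists, and otherwise overwrites the value. The usage replays the two node updates from the `ReducerState` example in the next cell, leaving out `history` for brevity.

```python
import operator
from typing import Any, Callable

Reducer = Callable[[Any, Any], Any]

def apply_update(state: dict, update: dict, reducers: dict[str, Reducer]) -> dict:
    """Simplified model of merging one node update into the state (not LangGraph's code)."""
    new_state = dict(state)  # leave the previous state untouched
    for key, value in update.items():
        if key in reducers and key in new_state:
            new_state[key] = reducers[key](new_state[key], value)  # reducer combines old and new
        else:
            new_state[key] = value  # no reducer (or no previous value): overwrite
    return new_state

# Replay the two updates from the ReducerState example (history omitted)
reducers = {"messages": operator.add, "total": operator.add}
state = {"messages": ["Start"], "total": 5, "current_value": 1}
state = apply_update(state, {"messages": ["Added a new value"], "total": 10, "current_value": 42}, reducers)
state = apply_update(state, {"messages": ["Multiply: 42 * 2"], "total": 84, "current_value": 84}, reducers)
print(state)
# {'messages': ['Start', 'Added a new value', 'Multiply: 42 * 2'], 'total': 99, 'current_value': 84}
```

`total` accumulates (5 + 10 + 84 = 99), while `current_value` keeps only the last write.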

```python
# State definition with reducers
class ReducerState(TypedDict):
    # Annotated + operator.add as the reducer: lists are concatenated
    messages: Annotated[list[str], operator.add]
    # A custom reducer written as a lambda: numbers are summed
    total: Annotated[int, lambda x, y: x + y]
    # No reducer: the latest value overwrites the old one
    current_value: int
    # Accumulated list of history entries
    history: Annotated[list[dict], operator.add]

# Create the graph that uses reducers
reducer_graph = StateGraph(ReducerState)

def add_values(state: ReducerState) -> dict:
    """Add values to the state"""
    print("Adding new values...")
    return {
        "messages": ["Added a new value"],  # appended to the existing list
        "total": 10,  # added to the existing total
        "current_value": 42,  # overwrites the existing value
        "history": [{"action": "add", "value": 10, "time": datetime.now().isoformat()}]
    }

def multiply_values(state: ReducerState) -> dict:
    """Multiply the current value by 2"""
    current = state.get("current_value", 1)
    print(f"Multiplying current value {current} by 2")
    return {
        "messages": [f"Multiply: {current} * 2"],
        "total": current * 2,
        "current_value": current * 2,
        "history": [{"action": "multiply", "value": current * 2, "time": datetime.now().isoformat()}]
    }
```

```python
# Build and run the reducer graph
reducer_graph.add_node("add", add_values)
reducer_graph.add_node("multiply", multiply_values)

reducer_graph.set_entry_point("add")
reducer_graph.add_edge("add", "multiply")
reducer_graph.add_edge("multiply", END)

reducer_app = reducer_graph.compile()

# Run and observe the effect of the reducers
initial_state = {
    "messages": ["Start"],
    "total": 5,
    "current_value": 1,
    "history": []
}

print("Initial state:")
print(json.dumps(initial_state, ensure_ascii=False, indent=2))

result = reducer_app.invoke(initial_state)

# Expected: total = 5 + 10 + 84 = 99, current_value = 84, three messages, two history entries
print("\nFinal state:")
print(json.dumps(result, ensure_ascii=False, indent=2, default=str))
```

## 4. Custom Reducer Functions

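Besides standard functions such as `operator.add`, you can write your own reducers. A reducer is called with the current value and the new update, and it returns the merged value: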

```python
# Custom reducer functions
def merge_dicts(existing: dict, new: dict) -> dict:
    """Recursively (deep) merge two dicts without mutating either input"""
    result = existing.copy()
    for key, value in new.items():
        if key in result and isinstance(result[key], dict) and isinstance(value, dict):
            result[key] = merge_dicts(result[key], value)
        else:
            result[key] = value
    return result

def keep_last_n(existing: list, new: list, n: int = 5) -> list:
    """Keep only the last n elements"""
    combined = existing + new
    return combined[-n:]

# State that uses the custom reducers
class CustomReducerState(TypedDict):
    # Deep-merge dicts
    config: Annotated[dict, merge_dicts]
    # Keep only the 5 most recent messages
    recent_messages: Annotated[list, lambda x, y: keep_last_n(x, y, 5)]
    # Keep the maximum value
    max_value: Annotated[int, max]
    # Set union for de-duplication
    unique_tags: Annotated[set, lambda x, y: x.union(y) if x else y]
```
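Because `unique_tags` is a `set`, the merged value cannot be serialized by `json.dumps`, which is why the cell that runs this graph passes `default=str`. Sets are also unordered. If the state should stay JSON-friendly and keep insertion order, one option is a list reducer that skips duplicates. `union_ordered` is a hypothetical name for that sketch, and it would be used as `Annotated[list, union_ordered]`.

```python
import json

def union_ordered(existing: list, new: list) -> list:
    """Hypothetical reducer: append items from `new` not already present, keeping first-seen order."""
    result = list(existing or [])  # copy, so the stored value is never mutated
    for item in new:
        if item not in result:
            result.append(item)
    return result

tags = union_ordered(["tag1", "tag2"], ["tag2", "tag3", "tag4"])
print(json.dumps(tags))  # ["tag1", "tag2", "tag3", "tag4"]

# By contrast, a set cannot be serialized directly
try:
    json.dumps({"tag1", "tag2"})
except TypeError as exc:
    print("set is not JSON serializable:", exc)
```

The `item not in result` check costs O(n) per item. For a handful of tags this is fine, but for large lists you would keep a companion set of seen items.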

```python
# Create a graph that uses the custom reducers
custom_graph = StateGraph(CustomReducerState)

def process_data_1(state: CustomReducerState) -> dict:
    return {
        "config": {"level1": {"setting1": "value1"}},
        "recent_messages": ["Message 1", "Message 2"],
        "max_value": 10,
        "unique_tags": {"tag1", "tag2"}
    }

def process_data_2(state: CustomReducerState) -> dict:
    return {
        "config": {"level1": {"setting2": "value2"}, "level2": {"option": "enabled"}},
        "recent_messages": ["Message 3", "Message 4", "Message 5", "Message 6"],
        "max_value": 15,
        "unique_tags": {"tag2", "tag3", "tag4"}
    }

# Build the graph
custom_graph.add_node("process1", process_data_1)
custom_graph.add_node("process2", process_data_2)

custom_graph.set_entry_point("process1")
custom_graph.add_edge("process1", "process2")
custom_graph.add_edge("process2", END)

custom_app = custom_graph.compile()

# Run with an empty input
result = custom_app.invoke({})
print("Custom reducer result:")
print(json.dumps(result, ensure_ascii=False, indent=2, default=str))
```
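Reducers are plain functions, so you can unit-test them without building a graph. The sketch below applies each reducer directly to the two node updates above. The expected values follow from the reducer definitions: a deep merge of `config`, the last five messages, the larger of 10 and 15, and the union of the tag sets. The functions are copies of the ones defined earlier, so the block runs on its own.

```python
def merge_dicts(existing: dict, new: dict) -> dict:
    result = existing.copy()
    for key, value in new.items():
        if key in result and isinstance(result[key], dict) and isinstance(value, dict):
            result[key] = merge_dicts(result[key], value)
        else:
            result[key] = value
    return result

def keep_last_n(existing: list, new: list, n: int = 5) -> list:
    return (existing + new)[-n:]

update_1 = {"config": {"level1": {"setting1": "value1"}},
            "recent_messages": ["Message 1", "Message 2"],
            "max_value": 10, "unique_tags": {"tag1", "tag2"}}
update_2 = {"config": {"level1": {"setting2": "value2"}, "level2": {"option": "enabled"}},
            "recent_messages": ["Message 3", "Message 4", "Message 5", "Message 6"],
            "max_value": 15, "unique_tags": {"tag2", "tag3", "tag4"}}

# Apply each reducer to (first update, second update)
config = merge_dicts(update_1["config"], update_2["config"])
recent = keep_last_n(update_1["recent_messages"], update_2["recent_messages"], 5)
max_value = max(update_1["max_value"], update_2["max_value"])
tags = update_1["unique_tags"].union(update_2["unique_tags"])

print(config)     # {'level1': {'setting1': 'value1', 'setting2': 'value2'}, 'level2': {'option': 'enabled'}}
print(recent)     # ['Message 2', 'Message 3', 'Message 4', 'Message 5', 'Message 6']
print(max_value)  # 15
```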

## 5. Messages and add_messages

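LangGraph ships a dedicated reducer for chat messages, `add_messages`. It appends new messages to the list, but when a new message has the same ID as one already in the list, it replaces that message instead: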

```python
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, SystemMessage

# Use the built-in message reducer
class MessageState(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]
    context: str
    turn_count: int

# Create the message-handling graph
message_graph = StateGraph(MessageState)

def process_user_input(state: MessageState) -> dict:
    """Handle the user's input"""
    messages = state.get("messages", [])
    
    # Take the most recent message
    last_message = messages[-1] if messages else None
    
    print(f"Processing user message: {last_message.content if last_message else 'none'}")
    
    return {
        "messages": [SystemMessage(content="Processing your request...")],
        "context": "user_input_processed",
        "turn_count": state.get("turn_count", 0) + 1
    }

def generate_response(state: MessageState) -> dict:
    """Generate an AI response"""
    context = state.get("context", "")
    
    # Placeholder response logic (no LLM call)
    response = f"Based on context '{context}', my reply is: I understand your request and am working on it."
    
    print(f"Generated response: {response}")
    
    return {
        "messages": [AIMessage(content=response)],
        "context": "response_generated"
    }

def log_conversation(state: MessageState) -> dict:
    """Log the conversation"""
    messages = state.get("messages", [])
    turn_count = state.get("turn_count", 0)
    
    print(f"\nConversation turn: {turn_count}")
    print(f"Total messages: {len(messages)}")
    
    return {
        "messages": [SystemMessage(content=f"Conversation logged, {turn_count} turn(s) so far")]
    }
```

```python
# Build the message-handling graph
message_graph.add_node("process_input", process_user_input)
message_graph.add_node("generate", generate_response)
message_graph.add_node("log", log_conversation)

message_graph.set_entry_point("process_input")
message_graph.add_edge("process_input", "generate")
message_graph.add_edge("generate", "log")
message_graph.add_edge("log", END)

message_app = message_graph.compile()

# Test message handling
test_state = {
    "messages": [
        HumanMessage(content="Hi, I need some help")
    ],
    "context": "",
    "turn_count": 0
}

result = message_app.invoke(test_state)

print("\nFinal message list:")
for msg in result["messages"]:
    print(f"- [{msg.__class__.__name__}]: {msg.content}")
```
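The sketch below is a simplified pure-Python model of how `add_messages` merges by ID, intended only to make that rule concrete. `Msg` and `add_messages_sketch` are hypothetical stand-ins, not LangChain or LangGraph classes. The real reducer also handles cases this sketch ignores, such as converting message-like inputs into message objects and removing messages via `RemoveMessage`.

```python
import uuid
from dataclasses import dataclass, replace
from typing import Optional

@dataclass
class Msg:
    """Stand-in for a chat message; only `content` and `id` matter here."""
    content: str
    id: Optional[str] = None

def add_messages_sketch(left: list[Msg], right: list[Msg]) -> list[Msg]:
    """Simplified model of add_messages: replace by matching id, otherwise append."""
    merged = list(left)
    position = {m.id: i for i, m in enumerate(merged)}
    for msg in right:
        if msg.id is None:
            msg = replace(msg, id=str(uuid.uuid4()))  # give id-less messages a fresh id
        if msg.id in position:
            merged[position[msg.id]] = msg  # same id: replace the existing message
        else:
            position[msg.id] = len(merged)
            merged.append(msg)  # new id: append
    return merged

history = add_messages_sketch([], [Msg("Hi, I need help", id="h1")])
history = add_messages_sketch(history, [Msg("Processing your request...", id="s1")])
history = add_messages_sketch(history, [Msg("Request processed.", id="s1")])  # same id
history = add_messages_sketch(history, [Msg("Anything else?")])               # no id
print([m.content for m in history])
# ['Hi, I need help', 'Request processed.', 'Anything else?']
```

Replacing by ID is what allows a node to edit an earlier message, for example a status message, instead of piling up copies.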

## 6. A More Complex Example: Task Management System

Let's build a larger example that shows state management in a realistic workflow:

```python
from enum import Enum
from typing import List, Dict

# Task status enum
class TaskStatus(str, Enum):
    PENDING = "pending"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    FAILED = "failed"

# Task record
class Task(TypedDict):
    id: str
    title: str
    status: TaskStatus
    priority: int
    assigned_to: Optional[str]
    created_at: str
    updated_at: str

# State for the task-management workflow
class TaskManagementState(TypedDict):
    # Task list (append-only: operator.add concatenates lists)
    tasks: Annotated[List[Task], operator.add]
    # ID of the task currently being handled
    current_task_id: Optional[str]
    # Statistics (custom reducer: shallow dict merge, later values win)
    stats: Annotated[Dict[str, int], lambda x, y: {**x, **y}]
    # Operation log
    logs: Annotated[List[str], operator.add]
    # Error messages
    errors: Annotated[List[str], operator.add]
    # System status
    system_status: str
```
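With `operator.add`, the `tasks` list is append-only. As a result, the assign, process, and complete nodes in the next cell cannot change a task's `status` or `assigned_to`; they only write logs and statistics. If the task records themselves should change, one option is a reducer that upserts by `id`. The following is a sketch under that assumption. `upsert_tasks` is a hypothetical helper, and it would be used as `tasks: Annotated[List[Task], upsert_tasks]`.

```python
def upsert_tasks(existing: list[dict], updates: list[dict]) -> list[dict]:
    """Hypothetical reducer: merge fields into tasks whose id already exists, append the rest."""
    merged = [dict(t) for t in existing]  # copy so the stored list is not mutated
    position = {t["id"]: i for i, t in enumerate(merged)}
    for task in updates:
        if task["id"] in position:
            merged[position[task["id"]]].update(task)  # partial updates are allowed
        else:
            position[task["id"]] = len(merged)
            merged.append(dict(task))
    return merged

tasks = upsert_tasks([], [{"id": "t1", "title": "Task t1", "status": "pending", "assigned_to": None}])
tasks = upsert_tasks(tasks, [{"id": "t1", "assigned_to": "Agent-001"}])
tasks = upsert_tasks(tasks, [{"id": "t1", "status": "completed"}])
print(tasks)
# [{'id': 't1', 'title': 'Task t1', 'status': 'completed', 'assigned_to': 'Agent-001'}]
```

Because partial updates are merged, a node only needs to return `{"id": ..., "status": ...}` and not the whole task record.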

```python
# Create the task-management graph
task_graph = StateGraph(TaskManagementState)

def create_task(state: TaskManagementState) -> dict:
    """Create a new task"""
    import uuid
    
    task_id = str(uuid.uuid4())[:8]
    now = datetime.now().isoformat()
    new_task = Task(
        id=task_id,
        title=f"Task {task_id}",
        status=TaskStatus.PENDING,
        priority=1,
        assigned_to=None,
        created_at=now,
        updated_at=now
    )
    
    print(f"✅ Created task: {new_task['title']}")
    
    # Update the statistics
    current_stats = state.get("stats", {})
    pending_count = current_stats.get("pending", 0) + 1
    
    return {
        "tasks": [new_task],
        "current_task_id": task_id,
        "stats": {"pending": pending_count, "total": current_stats.get("total", 0) + 1},
        "logs": [f"Task {task_id} created"],
        "system_status": "task_created"
    }

def assign_task(state: TaskManagementState) -> dict:
    """Assign the current task"""
    current_task_id = state.get("current_task_id")
    
    if not current_task_id:
        return {
            "errors": ["No current task to assign"],
            "system_status": "error"
        }
    
    # Simulated assignment: only logged; the record in `tasks` is not modified
    assignee = "Agent-001"
    
    print(f"👤 Assigning task {current_task_id} to {assignee}")
    
    return {
        "logs": [f"Task {current_task_id} assigned to {assignee}"],
        "system_status": "task_assigned"
    }

def process_task(state: TaskManagementState) -> dict:
    """Process the current task"""
    current_task_id = state.get("current_task_id")
    
    if not current_task_id:
        return {
            "errors": ["No current task to process"],
            "system_status": "error"
        }
    
    print(f"⚙️ Processing task {current_task_id}")
    
    # Update the statistics (pending -> in_progress)
    current_stats = state.get("stats", {})
    
    return {
        "stats": {
            "in_progress": current_stats.get("in_progress", 0) + 1,
            "pending": max(0, current_stats.get("pending", 0) - 1)
        },
        "logs": [f"Task {current_task_id} is being processed"],
        "system_status": "processing"
    }

def complete_task(state: TaskManagementState) -> dict:
    """Complete the current task"""
    current_task_id = state.get("current_task_id")
    
    if not current_task_id:
        return {
            "errors": ["No current task to complete"],
            "system_status": "error"
        }
    
    print(f"✔️ Completed task {current_task_id}")
    
    # Update the statistics (in_progress -> completed)
    current_stats = state.get("stats", {})
    
    return {
        "stats": {
            "completed": current_stats.get("completed", 0) + 1,
            "in_progress": max(0, current_stats.get("in_progress", 0) - 1)
        },
        "logs": [f"Task {current_task_id} completed"],
        "current_task_id": None,  # clear the current task
        "system_status": "task_completed"
    }

def generate_report(state: TaskManagementState) -> dict:
    """Generate a report"""
    stats = state.get("stats", {})
    logs = state.get("logs", [])
    
    # Build the report line by line so every line is indented consistently
    lines = [
        "📊 Task Management Report",
        "================",
        f"Total tasks: {stats.get('total', 0)}",
        f"Pending: {stats.get('pending', 0)}",
        f"In progress: {stats.get('in_progress', 0)}",
        f"Completed: {stats.get('completed', 0)}",
        "",
        "Recent operations:",
        *(f"  - {log}" for log in logs[-5:]),
    ]
    report = "\n".join(lines)
    
    print(report)
    
    return {
        "logs": ["Report generated"],
        "system_status": "report_generated"
    }
```

```python
# Build the task-management workflow
task_graph.add_node("create", create_task)
task_graph.add_node("assign", assign_task)
task_graph.add_node("process", process_task)
task_graph.add_node("complete", complete_task)
task_graph.add_node("report", generate_report)

# Define the flow
task_graph.set_entry_point("create")
task_graph.add_edge("create", "assign")
task_graph.add_edge("assign", "process")
task_graph.add_edge("process", "complete")
task_graph.add_edge("complete", "report")
task_graph.add_edge("report", END)

# Compile
task_app = task_graph.compile()

# Visualize. Rendering the PNG may require network access or extra dependencies,
# so fall back to printing the Mermaid source if it fails.
try:
    display(Image(task_app.get_graph().draw_mermaid_png()))
except Exception:
    print(task_app.get_graph().draw_mermaid())
```

```python
# Run the task-management workflow
initial_task_state = {
    "tasks": [],
    "current_task_id": None,
    "stats": {"total": 0, "pending": 0, "in_progress": 0, "completed": 0},
    "logs": ["System started"],
    "errors": [],
    "system_status": "ready"
}

print("Starting the task-management workflow...\n")
result = task_app.invoke(initial_task_state)

print("\nFinal state:")
print(f"System status: {result['system_status']}")
print(f"Statistics: {result['stats']}")
print(f"Error count: {len(result['errors'])}")
```
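The `stats` reducer `{**x, **y}` is a shallow merge. Each node therefore reads the current counts and writes back absolute values. In this linear flow that works, and the final statistics should be total 1, pending 0, in_progress 0, completed 1. If two nodes that bump the same counter ran in parallel, however, both would compute from the same snapshot and one increment would be lost. The alternative sketched below has nodes return *deltas* and lets the reducer add them up. `sum_counts` is a hypothetical helper; the comments map each delta onto the node that would return it.

```python
def sum_counts(existing: dict[str, int], delta: dict[str, int]) -> dict[str, int]:
    """Hypothetical reducer: add each delta to the stored count (missing keys start at 0)."""
    merged = dict(existing)
    for key, change in delta.items():
        merged[key] = merged.get(key, 0) + change
    return merged

stats = {"total": 0, "pending": 0, "in_progress": 0, "completed": 0}
for delta in (
    {"total": 1, "pending": 1},           # what create_task would return
    {"pending": -1, "in_progress": 1},    # what process_task would return
    {"in_progress": -1, "completed": 1},  # what complete_task would return
):
    stats = sum_counts(stats, delta)

print(stats)  # {'total': 1, 'pending': 0, 'in_progress': 0, 'completed': 1}
```

Because addition is order-independent, two parallel `{"pending": 1}` deltas correctly yield +2.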

## 7. Parallel State Updates

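StateGraph can run several nodes in parallel within the same step and merge their updates. Any key that more than one of those nodes writes needs a reducer. Without one, LangGraph cannot decide which value to keep and raises an error: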

```python
import time

# State for parallel processing
class ParallelState(TypedDict):
    data: list[str]
    # The parallel branches all write these keys, so each one needs a reducer
    results: Annotated[dict, lambda x, y: {**x, **y}]
    processed_count: Annotated[int, operator.add]
    timestamps: Annotated[list, operator.add]

# Create the parallel-processing graph
parallel_graph = StateGraph(ParallelState)

def processor_a(state: ParallelState) -> dict:
    """Processor A"""
    time.sleep(0.1)  # simulate work
    
    print("🅰️ Processor A running")
    
    return {
        "results": {"processor_a": "done"},
        "processed_count": 1,
        "timestamps": [f"A: {datetime.now().isoformat()}"]
    }

def processor_b(state: ParallelState) -> dict:
    """Processor B"""
    time.sleep(0.1)  # simulate work
    
    print("🅱️ Processor B running")
    
    return {
        "results": {"processor_b": "done"},
        "processed_count": 1,
        "timestamps": [f"B: {datetime.now().isoformat()}"]
    }

def processor_c(state: ParallelState) -> dict:
    """Processor C"""
    time.sleep(0.1)  # simulate work
    
    print("©️ Processor C running")
    
    return {
        "results": {"processor_c": "done"},
        "processed_count": 1,
        "timestamps": [f"C: {datetime.now().isoformat()}"]
    }

def merge_results(state: ParallelState) -> dict:
    """Combine the results"""
    results = state.get("results", {})
    count = state.get("processed_count", 0)
    
    print(f"\n📊 Merging results: {count} tasks processed")
    print(f"Results: {results}")
    
    return {
        "results": {"final": f"Merged results from {len(results)} processors"}
    }
```

```python
# Build the parallel-processing graph
def router(state: ParallelState) -> list[str]:
    """Fan out to all three processors (returning a list of node names runs them in parallel)"""
    return ["processor_a", "processor_b", "processor_c"]

# Add the nodes
parallel_graph.add_node("start", lambda state: {"data": ["Start processing"]})
parallel_graph.add_node("processor_a", processor_a)
parallel_graph.add_node("processor_b", processor_b)
parallel_graph.add_node("processor_c", processor_c)
parallel_graph.add_node("merge", merge_results)

# Fan-out edges
parallel_graph.set_entry_point("start")
parallel_graph.add_conditional_edges(
    "start",
    router,
    {
        "processor_a": "processor_a",
        "processor_b": "processor_b",
        "processor_c": "processor_c"
    }
)

# All processors feed into merge. They run in the same step, so merge runs once
# in the following step, after all three have finished.
parallel_graph.add_edge("processor_a", "merge")
parallel_graph.add_edge("processor_b", "merge")
parallel_graph.add_edge("processor_c", "merge")
parallel_graph.add_edge("merge", END)

# Compile
parallel_app = parallel_graph.compile()

# Run
print("Starting parallel processing...\n")
result = parallel_app.invoke({})

print("\nFinal result:")
print(json.dumps(result, ensure_ascii=False, indent=2, default=str))
```
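The step-level merge can be modeled in a few lines of plain Python. This is a simplified sketch, not LangGraph's implementation. All updates produced in one step are collected. Keys with a reducer are folded through it, and a key with no reducer that receives more than one value is rejected. `apply_step` and `ConflictingUpdateError` are hypothetical names; LangGraph raises its own error type in this situation.

```python
import operator
from typing import Any, Callable

class ConflictingUpdateError(Exception):
    """Hypothetical error for this sketch (LangGraph uses its own error class)."""

def apply_step(state: dict, updates: list[dict], reducers: dict[str, Callable[[Any, Any], Any]]) -> dict:
    """Merge the updates from all nodes that ran in the same step (simplified model)."""
    writes: dict[str, list[Any]] = {}
    for update in updates:
        for key, value in update.items():
            writes.setdefault(key, []).append(value)

    new_state = dict(state)
    for key, values in writes.items():
        if key in reducers:
            for value in values:  # fold every write through the reducer
                new_state[key] = reducers[key](new_state[key], value) if key in new_state else value
        elif len(values) > 1:
            raise ConflictingUpdateError(f"{key!r} got {len(values)} values in one step and has no reducer")
        else:
            new_state[key] = values[0]  # a single write to a plain key: overwrite
    return new_state

reducers = {"results": lambda x, y: {**x, **y}, "processed_count": operator.add}
branch_updates = [
    {"results": {"processor_a": "done"}, "processed_count": 1},
    {"results": {"processor_b": "done"}, "processed_count": 1},
    {"results": {"processor_c": "done"}, "processed_count": 1},
]
state = apply_step({"results": {}, "processed_count": 0}, branch_updates, reducers)
print(state["processed_count"], sorted(state["results"]))  # 3 ['processor_a', 'processor_b', 'processor_c']

try:
    apply_step({}, [{"status": "a"}, {"status": "b"}], reducers)  # two writes, no reducer
except ConflictingUpdateError as exc:
    print(exc)
```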

## 8. Exercises

### Exercise 1: A Priority Task Queue
Build a StateGraph that manages a priority task queue in which higher-priority tasks are processed first.

```python
# Implement your priority task queue here
class PriorityTaskState(TypedDict):
    # Define your state schema
    pass

# Build and test your graph
```

### Exercise 2: State Rollback
Build a StateGraph that supports state rollback, so the most recent operation can be undone.

```python
# Implement the state-rollback mechanism here
```


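## Summary

In this lesson we covered:

1. **StateGraph basics**
   - Defining state and type annotations
   - Conventions for writing node functions (return a partial update)

2. **Reducer functions**
   - Standard Python functions used as reducers (`operator.add`, `max`, etc.)
   - Custom reducer functions
   - Strategies for merging state

3. **Message handling**
   - The `add_messages` reducer
   - Message types (`HumanMessage`, `AIMessage`, `SystemMessage`)

4. **Complex state management**
   - The task-management example
   - Parallel state updates
   - Keeping statistics and logs in state

5. **Best practices**
   - Use TypedDict to define a clear state schema
   - Choose reducers deliberately, especially for keys that accumulate values or are written by parallel nodes
   - Handle errors and edge cases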

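## Coming Up Next

In the next lesson we will cover:
- Conditional edges in detail
- Implementing complex decision logic
- Multi-path selection and dynamic routing
- Loops and recursive structures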