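# 05. Subgraphs and Nested Graphs

## Learning Objectives
- Understand what subgraphs are and why they help
- Learn how to design nested graph structures
- Compile and execute subgraphs
- Understand how state is passed and isolated between graphs
- Apply modular design patterns

## Core Concepts
Subgraphs are a key tool in LangGraph for building complex applications:
1. **Modular design**: break a complex system into independent subsystems
2. **State isolation**: each subgraph can manage its own state
3. **Reusability**: a subgraph can be reused in multiple places
4. **Hierarchy**: graphs can be nested to multiple levels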

In [None]:
# Environment setup
from typing import TypedDict, Annotated, List, Dict, Any
from langgraph.graph import StateGraph, END
import json
import random
from datetime import datetime

print("Environment ready")

## 1. Basic Subgraph Example

Let's start with a simple subgraph:

In [None]:
# Define the subgraph state
class SubGraphState(TypedDict):
    input_data: str
    processed_data: str
    step_count: int
    logs: List[str]

# Subgraph node functions
def preprocess_data(state: SubGraphState) -> SubGraphState:
    """Preprocess the input data."""
    input_data = state.get("input_data", "")
    processed = f"preprocessed_{input_data}"

    print(f"📝 Preprocessing: {input_data} -> {processed}")

    return {
        "processed_data": processed,
        "step_count": state.get("step_count", 0) + 1,
        "logs": state.get("logs", []) + [f"Preprocessing done: {processed}"]
    }

def validate_data(state: SubGraphState) -> SubGraphState:
    """Validate the processed data."""
    processed_data = state.get("processed_data", "")
    is_valid = len(processed_data) > 5  # Simple validation rule

    status = "valid" if is_valid else "invalid"
    print(f"✅ Validation: {processed_data} -> {status}")

    return {
        "step_count": state.get("step_count", 0) + 1,
        "logs": state.get("logs", []) + [f"Validation result: {status}"]
    }

def transform_data(state: SubGraphState) -> SubGraphState:
    """Transform the processed data."""
    processed_data = state.get("processed_data", "")
    transformed = f"transformed_{processed_data.upper()}"

    print(f"🔄 Transforming: {processed_data} -> {transformed}")

    return {
        "processed_data": transformed,
        "step_count": state.get("step_count", 0) + 1,
        "logs": state.get("logs", []) + [f"Transformation done: {transformed}"]
    }

# Build the subgraph
def create_data_processing_subgraph():
    subgraph = StateGraph(SubGraphState)

    # Add nodes
    subgraph.add_node("preprocess", preprocess_data)
    subgraph.add_node("validate", validate_data)
    subgraph.add_node("transform", transform_data)

    # Wire up the flow
    subgraph.set_entry_point("preprocess")
    subgraph.add_edge("preprocess", "validate")
    subgraph.add_edge("validate", "transform")
    subgraph.add_edge("transform", END)

    return subgraph.compile()

# Test the subgraph
data_subgraph = create_data_processing_subgraph()

# Run the subgraph
test_input = "sample_data"
result = data_subgraph.invoke({
    "input_data": test_input,
    "step_count": 0,
    "logs": []
})

print("\nSubgraph result:")
print(json.dumps(result, ensure_ascii=False, indent=2))

## 2. Nested Graph Structure

Build a main graph in which one node invokes the subgraph:

In [None]:
# Main graph state
class MainGraphState(TypedDict):
    raw_input: str
    batch_data: List[str]
    processing_results: List[Dict[str, Any]]
    final_output: str
    total_steps: int

# Main graph node functions
def prepare_batch(state: MainGraphState) -> MainGraphState:
    """Prepare the batch data."""
    raw_input = state.get("raw_input", "")

    # Simulate splitting the input into batches
    batch_data = [f"{raw_input}_batch_{i}" for i in range(3)]

    print(f"📦 Prepared batches: {batch_data}")

    return {
        "batch_data": batch_data,
        "processing_results": [],
        "total_steps": 1
    }

def process_with_subgraph(state: MainGraphState) -> MainGraphState:
    """Process each batch with the subgraph."""
    batch_data = state.get("batch_data", [])
    processing_results = []

    print("🔄 Starting batch processing...")

    # Invoke the subgraph once per batch
    for i, data in enumerate(batch_data):
        print(f"\n--- Processing batch {i+1}: {data} ---")

        # Each call starts from a fresh subgraph state
        subgraph_result = data_subgraph.invoke({
            "input_data": data,
            "step_count": 0,
            "logs": []
        })

        processing_results.append({
            "batch_id": i,
            "input": data,
            "output": subgraph_result.get("processed_data", ""),
            "steps": subgraph_result.get("step_count", 0),
            "logs": subgraph_result.get("logs", [])
        })

    return {
        "processing_results": processing_results,
        "total_steps": state.get("total_steps", 0) + 1
    }

def aggregate_results(state: MainGraphState) -> MainGraphState:
    """Aggregate the processing results."""
    processing_results = state.get("processing_results", [])

    # Combine the outputs of all batches
    all_outputs = [result["output"] for result in processing_results]
    final_output = " | ".join(all_outputs)

    total_sub_steps = sum(result["steps"] for result in processing_results)

    print(f"📊 Aggregated result: {final_output}")
    print(f"📈 Total subgraph steps: {total_sub_steps}")

    return {
        "final_output": final_output,
        "total_steps": state.get("total_steps", 0) + 1
    }

# Build the main graph
def create_main_graph():
    main_graph = StateGraph(MainGraphState)

    # Add nodes
    main_graph.add_node("prepare", prepare_batch)
    main_graph.add_node("process", process_with_subgraph)
    main_graph.add_node("aggregate", aggregate_results)

    # Wire up the flow
    main_graph.set_entry_point("prepare")
    main_graph.add_edge("prepare", "process")
    main_graph.add_edge("process", "aggregate")
    main_graph.add_edge("aggregate", END)

    return main_graph.compile()

# Test the nested graph
main_graph = create_main_graph()

print("\n=== Nested graph test run ===")
nested_result = main_graph.invoke({
    "raw_input": "user_data"
})

print("\nNested graph final result:")
print(f"Input: {nested_result.get('raw_input', 'N/A')}")
print(f"Batch data: {nested_result.get('batch_data', [])}")
print(f"Final output: {nested_result.get('final_output', 'N/A')}")
print(f"Main graph total steps: {nested_result.get('total_steps', 0)}")

# Show the detailed result for each batch
print("\nPer-batch details:")
for result in nested_result.get('processing_results', []):
    print(f"Batch {result['batch_id']}: {result['input']} -> {result['output']} ({result['steps']} steps)")

## 3. State Passing and Isolation

This example shows how state is managed across separate graphs:

In [None]:
# Define separate state types for parent and child
class ParentState(TypedDict):
    user_id: str
    session_data: Dict[str, Any]
    child_results: List[Dict[str, Any]]
    global_config: Dict[str, Any]

class ChildState(TypedDict):
    task_id: str
    local_data: str
    processing_status: str
    internal_state: Dict[str, Any]
    inherited_config: Dict[str, Any]

# Subgraph: task processor
def initialize_task(state: ChildState) -> ChildState:
    """Initialize the task."""
    task_id = state.get("task_id", f"task_{random.randint(1000, 9999)}")
    config = state.get("inherited_config", {})

    print(f"🚀 Initializing task: {task_id}")
    print(f"📋 Inherited config: {config}")

    return {
        "task_id": task_id,
        "processing_status": "initialized",
        "internal_state": {
            "start_time": datetime.now().isoformat(),
            "steps_completed": 0
        }
    }

def execute_task(state: ChildState) -> ChildState:
    """Execute the task."""
    task_id = state.get("task_id", "")
    local_data = state.get("local_data", "")
    internal_state = state.get("internal_state", {})

    print(f"⚙️ Executing task: {task_id}, data: {local_data}")

    # Simulate the actual work
    processed_data = f"processed_{local_data}"

    updated_internal = {
        **internal_state,
        "steps_completed": internal_state.get("steps_completed", 0) + 1,
        "last_processed": processed_data
    }

    return {
        "local_data": processed_data,
        "processing_status": "completed",
        "internal_state": updated_internal
    }

def finalize_task(state: ChildState) -> ChildState:
    """Finalize the task."""
    task_id = state.get("task_id", "")
    internal_state = state.get("internal_state", {})

    print(f"✅ Finalized task: {task_id}")

    final_internal = {
        **internal_state,
        "end_time": datetime.now().isoformat(),
        "status": "finalized"
    }

    return {
        "processing_status": "finalized",
        "internal_state": final_internal
    }

# Build the subgraph
def create_task_subgraph():
    task_graph = StateGraph(ChildState)
    
    task_graph.add_node("init", initialize_task)
    task_graph.add_node("execute", execute_task)
    task_graph.add_node("finalize", finalize_task)
    
    task_graph.set_entry_point("init")
    task_graph.add_edge("init", "execute")
    task_graph.add_edge("execute", "finalize")
    task_graph.add_edge("finalize", END)
    
    return task_graph.compile()

# Parent graph nodes
def setup_session(state: ParentState) -> ParentState:
    """Set up the session."""
    user_id = state.get("user_id", f"user_{random.randint(100, 999)}")

    session_data = {
        "session_id": f"session_{random.randint(1000, 9999)}",
        "start_time": datetime.now().isoformat(),
        "tasks_created": 0
    }

    global_config = {
        "timeout": 30,
        "retry_count": 3,
        "debug_mode": True
    }

    print(f"👤 Setting up session for user: {user_id}")
    print(f"🔧 Global config: {global_config}")

    return {
        "user_id": user_id,
        "session_data": session_data,
        "global_config": global_config,
        "child_results": []
    }

def orchestrate_tasks(state: ParentState) -> ParentState:
    """Orchestrate task execution."""
    session_data = state.get("session_data", {})
    global_config = state.get("global_config", {})

    # Create the subgraph instance
    task_subgraph = create_task_subgraph()

    # Simulate creating several tasks
    tasks_to_create = ["task_a", "task_b", "task_c"]
    child_results = []

    print("\n🎯 Orchestrating tasks...")

    for i, task_data in enumerate(tasks_to_create):
        print(f"\n--- Running subtask {i+1}: {task_data} ---")

        # Build the subgraph input explicitly (state isolation):
        # the child sees only what the parent chooses to pass in
        child_input = {
            "task_id": f"{task_data}_{i}",
            "local_data": task_data,
            "processing_status": "pending",
            "internal_state": {},
            "inherited_config": global_config  # Pass down the global config
        }

        # Run the subgraph
        child_result = task_subgraph.invoke(child_input)

        # Collect the result (keep only what the parent needs)
        child_results.append({
            "task_id": child_result.get("task_id"),
            "status": child_result.get("processing_status"),
            "output_data": child_result.get("local_data"),
            "execution_summary": {
                "steps": child_result.get("internal_state", {}).get("steps_completed", 0),
                "duration": "calculated_duration"  # Placeholder; a real application would compute the time difference
            }
        })

    # Update the session data
    updated_session = {
        **session_data,
        "tasks_created": len(tasks_to_create),
        "tasks_completed": len([r for r in child_results if r["status"] == "finalized"])
    }

    return {
        "session_data": updated_session,
        "child_results": child_results
    }

def generate_report(state: ParentState) -> ParentState:
    """Generate the report."""
    user_id = state.get("user_id", "")
    session_data = state.get("session_data", {})
    child_results = state.get("child_results", [])

    print("\n📊 Generating execution report...")
    print(f"User: {user_id}")
    print(f"Session: {session_data.get('session_id', 'N/A')}")
    print(f"Tasks created: {session_data.get('tasks_created', 0)}")
    print(f"Tasks completed: {session_data.get('tasks_completed', 0)}")

    print("\nTask details:")
    for result in child_results:
        print(f"- {result['task_id']}: {result['status']} -> {result['output_data']}")

    return state

# Build the parent graph
def create_parent_graph():
    parent_graph = StateGraph(ParentState)
    
    parent_graph.add_node("setup", setup_session)
    parent_graph.add_node("orchestrate", orchestrate_tasks)
    parent_graph.add_node("report", generate_report)
    
    parent_graph.set_entry_point("setup")
    parent_graph.add_edge("setup", "orchestrate")
    parent_graph.add_edge("orchestrate", "report")
    parent_graph.add_edge("report", END)
    
    return parent_graph.compile()

# Test state isolation
print("\n=== State isolation demo ===")
parent_graph = create_parent_graph()

isolation_result = parent_graph.invoke({
    "user_id": "demo_user"
})

print("\n=== Final state summary ===")
print(f"User ID: {isolation_result.get('user_id')}")
print(f"Session info: {isolation_result.get('session_data')}")
print(f"Number of subtask results: {len(isolation_result.get('child_results', []))}")
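
In `orchestrate_tasks`, the parent keeps only a summary of each child result, and the `duration` field is left as a placeholder string. One way to fill it in is to derive the duration from the `start_time` and `end_time` that the child already records in `internal_state`. The sketch below assumes those two ISO-8601 timestamps are present; `summarize_child_result` is a hypothetical helper, and the sample timestamps are made-up illustrative values.

```python
from datetime import datetime
from typing import Any, Dict, Optional


def summarize_child_result(child_result: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only what the parent needs and derive the duration in seconds."""
    internal = child_result.get("internal_state", {})
    start = internal.get("start_time")
    end = internal.get("end_time")

    duration: Optional[float] = None
    if start and end:
        # Both values are ISO strings produced by datetime.isoformat()
        delta = datetime.fromisoformat(end) - datetime.fromisoformat(start)
        duration = delta.total_seconds()

    return {
        "task_id": child_result.get("task_id"),
        "status": child_result.get("processing_status"),
        "output_data": child_result.get("local_data"),
        "execution_summary": {
            "steps": internal.get("steps_completed", 0),
            "duration_seconds": duration,
        },
    }


# Hand-written sample shaped like a ChildState result (values are illustrative)
example_child_result = {
    "task_id": "task_a_0",
    "local_data": "processed_task_a",
    "processing_status": "finalized",
    "internal_state": {
        "start_time": "2024-01-01T10:00:00",
        "end_time": "2024-01-01T10:00:02.500000",
        "steps_completed": 1,
    },
    "inherited_config": {"timeout": 30},
}

child_summary = summarize_child_result(example_child_result)
print(child_summary)
```

Because the helper returns a new dict with a fixed set of keys, the child's `internal_state` and `inherited_config` never leak into the parent state, which is the isolation property this section is about.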

## 4. Case Study: Multi-Stage Data Analysis System

Build a more complex analysis system composed of several specialized subgraphs:

In [None]:
# Data-cleaning subgraph
class CleaningState(TypedDict):
    raw_data: List[str]
    cleaned_data: List[str]
    cleaning_report: Dict[str, Any]

def remove_duplicates(state: CleaningState) -> CleaningState:
    raw_data = state.get("raw_data", [])
    # Deduplicate while keeping first-seen order (set() would not preserve it)
    cleaned_data = list(dict.fromkeys(raw_data))
    removed_count = len(raw_data) - len(cleaned_data)

    print(f"🧹 Deduplication: removed {removed_count} duplicate(s)")

    return {
        "cleaned_data": cleaned_data,
        "cleaning_report": {"duplicates_removed": removed_count}
    }

def filter_invalid(state: CleaningState) -> CleaningState:
    cleaned_data = state.get("cleaned_data", [])
    valid_data = [item for item in cleaned_data if len(item) >= 3]  # Simple length filter
    filtered_count = len(cleaned_data) - len(valid_data)

    print(f"🔍 Filtering: dropped {filtered_count} invalid item(s)")

    # Build a new report dict instead of mutating the one held in state
    report = {**state.get("cleaning_report", {}), "invalid_filtered": filtered_count}

    return {
        "cleaned_data": valid_data,
        "cleaning_report": report
    }

# Build the data-cleaning subgraph
def create_cleaning_subgraph():
    cleaning_graph = StateGraph(CleaningState)
    cleaning_graph.add_node("deduplicate", remove_duplicates)
    cleaning_graph.add_node("filter", filter_invalid)
    
    cleaning_graph.set_entry_point("deduplicate")
    cleaning_graph.add_edge("deduplicate", "filter")
    cleaning_graph.add_edge("filter", END)
    
    return cleaning_graph.compile()

# Analysis subgraph
class AnalysisState(TypedDict):
    clean_data: List[str]
    statistics: Dict[str, Any]
    insights: List[str]

def calculate_stats(state: AnalysisState) -> AnalysisState:
    clean_data = state.get("clean_data", [])
    
    stats = {
        "total_items": len(clean_data),
        "avg_length": sum(len(item) for item in clean_data) / len(clean_data) if clean_data else 0,
        "unique_chars": len(set(''.join(clean_data))),
        "longest_item": max(clean_data, key=len) if clean_data else ""
    }
    
    print(f"📊 Computed statistics: {stats}")
    
    return {"statistics": stats}

def generate_insights(state: AnalysisState) -> AnalysisState:
    stats = state.get("statistics", {})
    insights = []
    
    if stats.get("total_items", 0) > 10:
        insights.append("The dataset is relatively large")

    if stats.get("avg_length", 0) > 5:
        insights.append("Items are relatively long on average")

    if stats.get("unique_chars", 0) > 20:
        insights.append("Character diversity is high")

    print(f"💡 Generated insights: {insights}")
    
    return {"insights": insights}

# Build the analysis subgraph
def create_analysis_subgraph():
    analysis_graph = StateGraph(AnalysisState)
    analysis_graph.add_node("stats", calculate_stats)
    analysis_graph.add_node("insights", generate_insights)
    
    analysis_graph.set_entry_point("stats")
    analysis_graph.add_edge("stats", "insights")
    analysis_graph.add_edge("insights", END)
    
    return analysis_graph.compile()

# State of the main analysis system
class AnalysisSystemState(TypedDict):
    input_data: List[str]
    cleaning_results: Dict[str, Any]
    analysis_results: Dict[str, Any]
    final_report: Dict[str, Any]

# Main system nodes
def load_data(state: AnalysisSystemState) -> AnalysisSystemState:
    # Simulate loading data
    input_data = [
        "data1", "data2", "data1", "abc", "xyz", "data3", "ab", "data4", "xyz", "data5",
        "long_data_item", "another_long_item", "short", "x", "medium_data"
    ]

    print(f"📁 Loaded data: {len(input_data)} items")
    return {"input_data": input_data}

def execute_cleaning(state: AnalysisSystemState) -> AnalysisSystemState:
    input_data = state.get("input_data", [])
    
    # Invoke the cleaning subgraph
    cleaning_subgraph = create_cleaning_subgraph()
    cleaning_result = cleaning_subgraph.invoke({
        "raw_data": input_data
    })
    
    return {"cleaning_results": cleaning_result}

def execute_analysis(state: AnalysisSystemState) -> AnalysisSystemState:
    cleaning_results = state.get("cleaning_results", {})
    clean_data = cleaning_results.get("cleaned_data", [])
    
    # Invoke the analysis subgraph
    analysis_subgraph = create_analysis_subgraph()
    analysis_result = analysis_subgraph.invoke({
        "clean_data": clean_data
    })
    
    return {"analysis_results": analysis_result}

def compile_report(state: AnalysisSystemState) -> AnalysisSystemState:
    input_data = state.get("input_data", [])
    cleaning_results = state.get("cleaning_results", {})
    analysis_results = state.get("analysis_results", {})
    
    final_report = {
        "summary": {
            "original_items": len(input_data),
            "cleaned_items": len(cleaning_results.get("cleaned_data", [])),
            "data_quality_score": 85  # Hard-coded placeholder score
        },
        "cleaning": cleaning_results.get("cleaning_report", {}),
        "statistics": analysis_results.get("statistics", {}),
        "insights": analysis_results.get("insights", []),
        "timestamp": datetime.now().isoformat()
    }
    
    print("\n📄 Final report:")
    print(json.dumps(final_report, ensure_ascii=False, indent=2))
    
    return {"final_report": final_report}

# Build the main analysis system
def create_analysis_system():
    system_graph = StateGraph(AnalysisSystemState)
    
    system_graph.add_node("load", load_data)
    system_graph.add_node("clean", execute_cleaning)
    system_graph.add_node("analyze", execute_analysis)
    system_graph.add_node("report", compile_report)
    
    system_graph.set_entry_point("load")
    system_graph.add_edge("load", "clean")
    system_graph.add_edge("clean", "analyze")
    system_graph.add_edge("analyze", "report")
    system_graph.add_edge("report", END)
    
    return system_graph.compile()

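# Test the complete system
print("\n=== Multi-stage data analysis system demo ===")
analysis_system = create_analysis_system()

# Seed one state key explicitly: some LangGraph versions may reject an empty
# input dict. The value is overwritten by load_data, so the result is unchanged.
system_result = analysis_system.invoke({"input_data": []})

print("\n=== System run complete ===")
report = system_result.get("final_report", {})
print(f"Data quality score: {report.get('summary', {}).get('data_quality_score', 'N/A')}")
print(f"Items retained: {report.get('summary', {}).get('cleaned_items', 0)}/{report.get('summary', {}).get('original_items', 0)}")
print(f"Key insights: {', '.join(report.get('insights', []))}")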
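
`execute_cleaning` and `execute_analysis` follow the same three steps: build the child input from the parent state, invoke the subgraph, and store the result under one parent key. A small factory can capture that pattern. The sketch below is one possible shape, not part of the lesson code; `make_subgraph_node` and `EchoSubgraph` are hypothetical, and the stub stands in for a compiled subgraph so the example runs without LangGraph.

```python
from typing import Any, Callable, Dict

State = Dict[str, Any]


def make_subgraph_node(
    subgraph: Any,
    build_input: Callable[[State], State],
    output_key: str,
) -> Callable[[State], State]:
    """Return a node function that runs `subgraph` and stores its result under `output_key`."""

    def node(state: State) -> State:
        child_input = build_input(state)             # parent state -> child input
        child_result = subgraph.invoke(child_input)  # run the child graph
        return {output_key: child_result}            # child result -> parent update

    return node


class EchoSubgraph:
    """Stand-in for a compiled subgraph; only `.invoke` is required here."""

    def invoke(self, child_input: State) -> State:
        return {"cleaned_data": sorted(set(child_input["raw_data"]))}


clean_node = make_subgraph_node(
    EchoSubgraph(),
    build_input=lambda s: {"raw_data": s.get("input_data", [])},
    output_key="cleaning_results",
)

node_update = clean_node({"input_data": ["b", "a", "b"]})
print(node_update)
```

In the system above, the first argument would presumably be `create_cleaning_subgraph()` or `create_analysis_subgraph()`. Passing the compiled graph in once also means it is built a single time rather than on every node call.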

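## 5. Exercises

### Exercise 1: E-commerce Order Processing System
Design an e-commerce order processing system built from several subgraphs:
- Order validation subgraph
- Inventory management subgraph
- Payment processing subgraph
- Shipping arrangement subgraph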

In [None]:
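# Exercise 1: implement the e-commerce order processing system
# TODO: define the state and nodes for each subgraph
# TODO: implement state passing between the subgraphs
# TODO: handle failures and rollback logic

print("Implement the e-commerce order processing system here")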
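
As a possible starting point for the rollback part of this exercise, the sketch below pairs each step with a compensating action and undoes completed steps in reverse order when a later step fails. It is plain Python rather than a reference solution; `run_with_rollback` and the four step functions are hypothetical, and in a full answer each action could wrap a subgraph invocation.

```python
from typing import Callable, List, Tuple

# (step name, action, compensating action)
Step = Tuple[str, Callable[[dict], None], Callable[[dict], None]]


def run_with_rollback(order: dict, steps: List[Step]) -> dict:
    """Run steps in order; on failure, undo the completed steps in reverse."""
    completed: List[Step] = []
    try:
        for step in steps:
            name, action, _ = step
            action(order)
            completed.append(step)
            order["log"].append(f"done:{name}")
    except Exception as exc:
        order["error"] = str(exc)
        for name, _, compensate in reversed(completed):
            compensate(order)
            order["log"].append(f"undone:{name}")
    return order


def reserve_stock(order: dict) -> None:
    order["stock_reserved"] = True


def release_stock(order: dict) -> None:
    order["stock_reserved"] = False


def charge_payment(order: dict) -> None:
    if order["amount"] > order["balance"]:
        raise ValueError("insufficient funds")
    order["paid"] = True


def refund_payment(order: dict) -> None:
    order["paid"] = False


order_steps: List[Step] = [
    ("inventory", reserve_stock, release_stock),
    ("payment", charge_payment, refund_payment),
]

# The payment step fails, so the stock reservation is released again
failed_order = run_with_rollback({"amount": 100, "balance": 50, "log": []}, order_steps)
print(failed_order)
```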

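### Exercise 2: Document Processing Pipeline
Create a document processing system that includes:
- Document parsing subgraph
- Content extraction subgraph
- Format conversion subgraph
- Quality check subgraph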

In [None]:
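# Exercise 2: implement the document processing pipeline
# TODO: design a modular document processing flow
# TODO: implement subgraphs for the different formats
# TODO: add error recovery and quality assurance

print("Implement the document processing pipeline here")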

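## Summary

In this lesson we covered the core concepts of subgraphs and nested graphs:

### Key Takeaways:
1. **Modular design**: break a complex system into independent subsystems
2. **State isolation**: each subgraph maintains its own state space
3. **Reusability**: a subgraph can be reused in multiple places
4. **Hierarchy**: multi-level nesting supports complex architectures
5. **State passing**: design the information flow between parent and child graphs deliberately

### Best Practices:
- **Clear boundaries**: define each subgraph's responsibilities and boundaries explicitly
- **State design**: keep state structures focused and avoid tight coupling between graphs
- **Error handling**: implement error handling and recovery at the subgraph level
- **Performance**: account for the overhead of subgraph invocations
- **Testing strategy**: test subgraphs in isolation as well as the system as a whole

### Use Cases:
- Large enterprise applications
- Microservice architectures
- Data processing pipelines
- Workflow engines
- Multi-stage task systems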
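
The error-handling practice above can be sketched as a small wrapper that retries a subgraph call and returns a result record instead of letting the exception escape into the parent graph. This is an illustrative sketch only; `safe_invoke` and `flaky_invoke` are hypothetical names, and in the lesson's code the callable passed in would be something like `data_subgraph.invoke`.

```python
from typing import Any, Callable, Dict

State = Dict[str, Any]


def safe_invoke(
    invoke: Callable[[State], State],
    child_input: State,
    retries: int = 2,
) -> State:
    """Call `invoke`, retrying on failure; return a result record instead of raising."""
    last_error = ""
    for attempt in range(1, retries + 2):  # first try plus `retries` retries
        try:
            return {"ok": True, "attempts": attempt, "result": invoke(child_input)}
        except Exception as exc:
            last_error = f"{type(exc).__name__}: {exc}"
    return {"ok": False, "attempts": retries + 1, "error": last_error}


call_counter = {"count": 0}


def flaky_invoke(child_input: State) -> State:
    """Stand-in for a subgraph call that fails once and then succeeds."""
    call_counter["count"] += 1
    if call_counter["count"] < 2:
        raise RuntimeError("temporary failure")
    return {"processed_data": child_input["input_data"].upper()}


outcome = safe_invoke(flaky_invoke, {"input_data": "abc"})
print(outcome)
```

Because the wrapper takes any callable, the same function also supports the testing practice: a stub can replace the real subgraph when the parent node is tested on its own.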

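## Next Lesson Preview

In the next lesson, "Persistence and Checkpoints", we will cover:
- How the checkpoint mechanism works
- Using MemorySaver
- Configuring persistent storage
- Restoring and rewinding state
- Thread management and configuration
- Reliability guarantees for production environments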