# 07. 并行执行和Map-Reduce

## 课程目标
- 掌握并行节点执行机制
- 学习Map-Reduce模式实现
- 使用Send API创建动态并行任务
- 实现结果聚合和同步策略
- 优化并行处理性能

## 核心概念

LangGraph支持强大的并行处理能力：
1. **并行节点**：同时执行多个节点
2. **Map-Reduce**：分布式处理模式
3. **Send API**：动态创建并行任务
4. **结果聚合**：收集和合并并行结果
5. **负载均衡**：优化资源利用

In [None]:
# 环境准备
import asyncio
import json
import operator
import random
import time
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from typing import TypedDict, Annotated, List, Dict, Any

from langgraph.constants import Send
from langgraph.graph import StateGraph, END, START

print("环境准备完成")

## 1. 基础并行执行

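让我们从简单的并行执行开始。注意：当多个并行节点在同一步中写入同一个状态键（例如 `results`）时，该键需要用 `Annotated` 声明 reducer（例如 `operator.add`），否则 LangGraph 会因并发写入而抛出 `InvalidUpdateError`：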

In [None]:
# 并行处理状态
class ParallelState(TypedDict):
    """Shared state for the fan-out / fan-in demo graph."""

    # Raw items handed to the graph by the caller.
    input_data: List[str]
    # One record per worker. Several worker nodes write this key during the
    # same step, so it carries a reducer: the per-worker lists are
    # concatenated with operator.add instead of competing for a single value.
    results: Annotated[List[Dict[str, Any]], operator.add]
    processing_time: float
    # Grand total of processed items, filled in by the aggregation node.
    total_items: int


def _worker_update(worker: str, processed_items: int) -> Dict[str, Any]:
    """Build the state delta that one worker contributes to ``results``."""
    record = {
        "worker": worker,
        "processed_items": processed_items,
        "status": "completed",
        "result": f"{worker}处理完成",
    }
    # Only the new record is returned; the reducer on ``results`` appends it.
    # Echoing the existing list back would duplicate earlier entries.
    return {"results": [record]}


# Parallel worker nodes
def worker_a(state: ParallelState) -> ParallelState:
    """Worker node A: simulate one second of work and report 3 items.

    Returns:
        A state update with a single-record ``results`` list.
    """
    print("🔧 Worker A 开始处理...")
    time.sleep(1)  # simulated processing time
    return _worker_update("A", 3)


def worker_b(state: ParallelState) -> ParallelState:
    """Worker node B: simulate 1.5 seconds of work and report 5 items.

    Returns:
        A state update with a single-record ``results`` list.
    """
    print("⚙️ Worker B 开始处理...")
    time.sleep(1.5)  # deliberately different from the other workers
    return _worker_update("B", 5)


def worker_c(state: ParallelState) -> ParallelState:
    """Worker node C: simulate 0.8 seconds of work and report 2 items.

    Returns:
        A state update with a single-record ``results`` list.
    """
    print("🛠️ Worker C 开始处理...")
    time.sleep(0.8)
    return _worker_update("C", 2)

def aggregate_results(state: ParallelState) -> ParallelState:
    """Sum the ``processed_items`` reported by every worker record.

    Reads ``results`` from the state (treated as empty when missing), prints
    a one-line summary, and returns the grand total under ``total_items``.
    """
    worker_records = state.get("results", [])

    total_processed = 0
    for record in worker_records:
        # Records without a count contribute nothing to the total.
        total_processed += record.get("processed_items", 0)

    print(f"📊 聚合完成: {len(worker_records)} 个工作节点, 总处理 {total_processed} 项")

    return {"total_items": total_processed}

# 创建并行执行图
def create_parallel_graph():
    """Build and compile the fan-out / fan-in demo graph.

    START fans out to the three worker nodes, and ``aggregate`` is joined to
    all of them before the graph ends.

    Returns:
        The compiled graph produced by ``StateGraph.compile()``.
    """
    graph = StateGraph(ParallelState)

    workers = {
        "worker_a": worker_a,
        "worker_b": worker_b,
        "worker_c": worker_c,
    }
    for name, node in workers.items():
        graph.add_node(name, node)
        # One explicit START edge per worker states the fan-out directly,
        # instead of calling the singular-sounding set_entry_point() thrice.
        graph.add_edge(START, name)
    graph.add_node("aggregate", aggregate_results)

    # A list of sources makes the join explicit: aggregate waits for every
    # worker rather than relying on them finishing in the same step.
    graph.add_edge(list(workers), "aggregate")
    graph.add_edge("aggregate", END)

    return graph.compile()

# Exercise the parallel graph and time one end-to-end run
parallel_app = create_parallel_graph()

print("=== 并行执行演示 ===")
start_time = time.time()

# Seed the state with the raw inputs and an empty results list.
result = parallel_app.invoke(
    {"input_data": ["data1", "data2", "data3"], "results": []}
)

execution_time = time.time() - start_time
worker_records = result.get('results', [])

print(f"\n执行时间: {execution_time:.2f} 秒")
print(f"处理结果: {len(worker_records)} 个工作节点完成")
print(f"总处理项目: {result.get('total_items', 0)} 项")

# Per-worker breakdown
for record in worker_records:
    print(f"- {record['worker']}: {record['result']} ({record['processed_items']} 项)")

## 2. Map-Reduce 模式

实现经典的Map-Reduce处理模式：

In [None]:
# Map-Reduce 状态
class MapReduceState(TypedDict):
    """State shared by the split -> map -> reduce graph."""

    # Numbers supplied by the caller.
    input_data: List[int]
    # Written by the split step and read when the map tasks are created. It
    # has to be declared here: keys that are not part of the schema are not
    # carried in the graph state.
    chunks: List[List[int]]
    # Every map task writes this key during the same step, so the partial
    # lists are merged with operator.add instead of competing for one value.
    map_results: Annotated[List[Dict[str, Any]], operator.add]
    # Final aggregate produced by the reduce step.
    reduce_result: Dict[str, Any]
    # Maximum number of items per chunk.
    chunk_size: int

def split_data(state: MapReduceState) -> MapReduceState:
    """Split ``input_data`` into consecutive chunks of ``chunk_size`` items.

    Args:
        state: Graph state. ``input_data`` defaults to an empty list and
            ``chunk_size`` to 3 when absent.

    Returns:
        A state update holding the list of chunks under ``chunks``. The last
        chunk may be shorter than ``chunk_size``.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
    """
    input_data = state.get("input_data", [])
    chunk_size = state.get("chunk_size", 3)

    # range() fails with an unrelated-looking message on a zero step and
    # yields nothing on a negative one, which would silently drop every
    # item. Reject both up front with a clear error.
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    chunks = [
        input_data[start:start + chunk_size]
        for start in range(0, len(input_data), chunk_size)
    ]

    print(f"📦 数据分割: {len(input_data)} 项分成 {len(chunks)} 块")

    return {"chunks": chunks}

# 创建动态Map节点
def create_map_tasks(state: MapReduceState) -> List[Send]:
    """Fan out one ``map_worker`` task per chunk found in the state.

    Each task is a ``Send`` whose payload carries the chunk's position
    (``chunk_id``) and its items (``chunk_data``). A missing ``chunks`` key
    is treated as an empty list, which yields no tasks.
    """
    tasks = [
        Send("map_worker", {"chunk_id": index, "chunk_data": items})
        for index, items in enumerate(state.get("chunks", []))
    ]

    print(f"🚀 创建 {len(tasks)} 个Map任务")
    return tasks

def map_worker(state: Dict[str, Any]) -> Dict[str, Any]:
    """Summarise one chunk: sum of squares, maximum, minimum and item count.

    Args:
        state: Task payload with ``chunk_id`` and ``chunk_data``.

    Returns:
        A state update whose ``map_results`` list holds the single summary
        record for this chunk. Maximum and minimum are 0 for an empty chunk.
    """
    chunk_id, numbers = state["chunk_id"], state["chunk_data"]

    print(f"🔄 Map {chunk_id}: 处理 {len(numbers)} 个数据项")

    # The "map" step of the demo: accumulate the squares of the chunk.
    squared_sum = 0
    for value in numbers:
        squared_sum += value * value

    if numbers:
        largest = max(numbers)
        smallest = min(numbers)
    else:
        largest = smallest = 0

    # Simulated processing time
    time.sleep(0.5)

    print(f"✅ Map {chunk_id} 完成: 平方和={squared_sum}")

    summary = {
        "chunk_id": chunk_id,
        "squared_sum": squared_sum,
        "max_value": largest,
        "min_value": smallest,
        "count": len(numbers),
    }
    return {"map_results": [summary]}

def reduce_worker(state: MapReduceState) -> MapReduceState:
    """Fold all map summaries into a single global summary.

    Reads ``map_results`` from the state (empty when missing) and returns the
    combined totals, extremes and mean square under ``reduce_result``. With
    no summaries, or a zero item count, the extremes and the mean are 0.
    """
    partials = state.get("map_results", [])

    print(f"🔄 Reduce: 聚合 {len(partials)} 个Map结果")

    total_squared_sum = 0
    total_count = 0
    maxima = []
    minima = []
    for partial in partials:
        total_squared_sum += partial["squared_sum"]
        maxima.append(partial["max_value"])
        minima.append(partial["min_value"])
        total_count += partial["count"]

    global_max = max(maxima) if maxima else 0
    global_min = min(minima) if minima else 0

    # Guard the division so an empty run reports 0 instead of raising.
    if total_count > 0:
        average_squared = total_squared_sum / total_count
    else:
        average_squared = 0

    print(f"✅ Reduce 完成: 总平方和={total_squared_sum}, 全局最大值={global_max}")

    return {
        "reduce_result": {
            "total_squared_sum": total_squared_sum,
            "global_max": global_max,
            "global_min": global_min,
            "total_count": total_count,
            "average_squared": average_squared,
        }
    }

# 创建Map-Reduce图
def create_mapreduce_graph():
    """Build and compile the split -> map -> reduce graph.

    ``split`` is the entry point. Its conditional edge calls
    ``create_map_tasks`` to dispatch the ``map_worker`` tasks, whose output
    flows into ``reduce`` before the graph ends.
    """
    builder = StateGraph(MapReduceState)

    # Register the nodes
    node_table = (
        ("split", split_data),
        ("map_worker", map_worker),
        ("reduce", reduce_worker),
    )
    for node_name, node_fn in node_table:
        builder.add_node(node_name, node_fn)

    # Wire the flow
    builder.set_entry_point("split")
    builder.add_conditional_edges("split", create_map_tasks, ["map_worker"])
    builder.add_edge("map_worker", "reduce")
    builder.add_edge("reduce", END)

    return builder.compile()

# Exercise the Map-Reduce graph
mapreduce_app = create_mapreduce_graph()

print("\n=== Map-Reduce 演示 ===")
test_data = list(range(1, 21))  # the integers 1 through 20
print(f"输入数据: {test_data}")

start_time = time.time()

# 20 items in chunks of 5
mapreduce_input = {"input_data": test_data, "chunk_size": 5}
mapreduce_result = mapreduce_app.invoke(mapreduce_input)

execution_time = time.time() - start_time

print(f"\n=== Map-Reduce 结果 ===")
print(f"执行时间: {execution_time:.2f} 秒")

# Fall back to an empty dict so the report prints zeros when reduce did not run.
reduce_result = mapreduce_result.get("reduce_result", {})
print(f"总平方和: {reduce_result.get('total_squared_sum', 0)}")
print(f"全局最大值: {reduce_result.get('global_max', 0)}")
print(f"全局最小值: {reduce_result.get('global_min', 0)}")
print(f"平均平方值: {reduce_result.get('average_squared', 0):.2f}")
print(f"处理总数: {reduce_result.get('total_count', 0)} 项")

## 3. 实践案例：大规模文本处理

构建一个大规模文本处理系统：

In [None]:
# 文本处理状态
class TextProcessingState(TypedDict):
    """State shared by the prepare -> process_batch -> aggregate graph."""

    # Documents to analyse.
    documents: List[str]
    # One statistics record per batch. Every batch task writes this key
    # during the same step, so the partial lists are merged with operator.add
    # instead of competing for a single value.
    processing_results: Annotated[List[Dict[str, Any]], operator.add]
    # Corpus-wide summary produced by the aggregation step.
    final_statistics: Dict[str, Any]
    # Number of documents per batch.
    batch_size: int

def prepare_documents(state: TextProcessingState) -> TextProcessingState:
    """Provide the demo corpus and the batch size used downstream.

    The incoming state is not consulted: the ten sample documents and a
    batch size of 3 are always returned.
    """
    # Stand-in for a large document collection
    corpus = [
        "This is document 1. It contains some text for analysis.",
        "Document 2 has different content and more words to process.",
        "The third document discusses various topics and concepts.",
        "Document 4 contains technical information about systems.",
        "Fifth document explores advanced algorithmic approaches.",
        "Document 6 covers machine learning and data science.",
        "The seventh document examines distributed computing.",
        "Document 8 focuses on parallel processing techniques.",
        "Ninth document analyzes performance optimization methods.",
        "The final document summarizes key findings and conclusions.",
    ]

    print(f"📚 准备了 {len(corpus)} 个文档进行处理")

    update = {"documents": corpus, "batch_size": 3}
    return update

def create_processing_tasks(state: TextProcessingState) -> List[Send]:
    """Fan out one ``process_batch`` task per batch of documents.

    ``documents`` defaults to an empty list and ``batch_size`` to 3. Each
    task is a ``Send`` whose payload carries the batch's index (``batch_id``)
    and its slice of documents (``batch_documents``).
    """
    documents = state.get("documents", [])
    batch_size = state.get("batch_size", 3)

    # Each start offset marks the first document of one batch.
    batch_starts = range(0, len(documents), batch_size)
    tasks = [
        Send("process_batch", {
            "batch_id": batch_id,
            "batch_documents": documents[start:start + batch_size],
        })
        for batch_id, start in enumerate(batch_starts)
    ]

    print(f"🚀 创建 {len(tasks)} 个处理批次")
    return tasks

def process_batch(state: Dict[str, Any]) -> Dict[str, Any]:
    """Compute word and character statistics for one batch of documents.

    Args:
        state: Task payload with ``batch_id`` and ``batch_documents``.

    Returns:
        A state update whose ``processing_results`` list holds one record
        with the document count, total words, total characters, the ten most
        frequent words of the batch, and the average document length in
        characters (0 for an empty batch).
    """
    batch_id = state["batch_id"]
    documents = state["batch_documents"]

    print(f"📄 批次 {batch_id}: 处理 {len(documents)} 个文档")

    word_counts = Counter()
    total_chars = 0
    for doc in documents:
        # Minimal normalisation: lower-case, drop periods and commas, then
        # split on whitespace.
        word_counts.update(doc.lower().replace('.', '').replace(',', '').split())
        total_chars += len(doc)
    total_words = sum(word_counts.values())

    batch_stats = {
        "batch_id": batch_id,
        "document_count": len(documents),
        "total_words": total_words,
        "total_characters": total_chars,
        # Keep the ten MOST FREQUENT words. Truncating by insertion order
        # would discard frequent words that merely appear late in the batch.
        "word_frequencies": dict(word_counts.most_common(10)),
        "avg_doc_length": total_chars / len(documents) if documents else 0,
    }

    # Simulated processing time
    time.sleep(0.8)

    print(f"✅ 批次 {batch_id} 完成: {total_words} 个单词, {total_chars} 个字符")

    return {
        "processing_results": [batch_stats]
    }

def aggregate_statistics(state: TextProcessingState) -> TextProcessingState:
    """Merge the per-batch records into corpus-wide statistics.

    Reads ``processing_results`` from the state (empty when missing), sums
    the counters, merges the word frequencies, and returns the summary under
    ``final_statistics``. Averages are 0 when no documents were processed.
    """
    batches = state.get("processing_results", [])

    print(f"📊 聚合 {len(batches)} 个批次的结果")

    # Accumulate the totals and the merged word frequencies in a single pass.
    total_docs = 0
    total_words = 0
    total_chars = 0
    combined_freq = {}
    for batch in batches:
        total_docs += batch["document_count"]
        total_words += batch["total_words"]
        total_chars += batch["total_characters"]
        for word, count in batch["word_frequencies"].items():
            combined_freq[word] = combined_freq.get(word, 0) + count

    # Highest-frequency words first; at most ten are kept.
    ranked = sorted(combined_freq.items(), key=lambda item: item[1], reverse=True)
    top_words = ranked[:10]

    if total_docs > 0:
        avg_words = total_words / total_docs
        avg_chars = total_chars / total_docs
    else:
        avg_words = 0
        avg_chars = 0

    print(f"✅ 聚合完成: {total_docs} 文档, {total_words} 单词")

    return {
        "final_statistics": {
            "total_documents": total_docs,
            "total_words": total_words,
            "total_characters": total_chars,
            "avg_words_per_doc": avg_words,
            "avg_chars_per_doc": avg_chars,
            "top_10_words": top_words,
            "processing_batches": len(batches),
        }
    }

# 创建文本处理图
def create_text_processing_graph():
    """Build and compile the prepare -> process_batch -> aggregate graph.

    ``prepare`` is the entry point. Its conditional edge calls
    ``create_processing_tasks`` to dispatch the ``process_batch`` tasks,
    whose output flows into ``aggregate`` before the graph ends.
    """
    builder = StateGraph(TextProcessingState)

    node_table = (
        ("prepare", prepare_documents),
        ("process_batch", process_batch),
        ("aggregate", aggregate_statistics),
    )
    for node_name, node_fn in node_table:
        builder.add_node(node_name, node_fn)

    builder.set_entry_point("prepare")
    builder.add_conditional_edges("prepare", create_processing_tasks, ["process_batch"])
    builder.add_edge("process_batch", "aggregate")
    builder.add_edge("aggregate", END)

    return builder.compile()

# Exercise the text-processing graph
text_app = create_text_processing_graph()

print("\n=== 大规模文本处理演示 ===")
start_time = time.time()

# Seed the run with one state key. An input dict that writes no state key at
# all is rejected by some LangGraph releases, and the prepare node supplies
# the documents (and the same batch size) itself.
text_result = text_app.invoke({"batch_size": 3})

execution_time = time.time() - start_time

print(f"\n=== 文本处理结果 ===")
print(f"总执行时间: {execution_time:.2f} 秒")

# Fall back to an empty dict so the report prints zeros when aggregation did not run.
stats = text_result.get("final_statistics", {})
print(f"处理文档数: {stats.get('total_documents', 0)}")
print(f"总单词数: {stats.get('total_words', 0)}")
print(f"总字符数: {stats.get('total_characters', 0)}")
print(f"平均每文档单词数: {stats.get('avg_words_per_doc', 0):.1f}")
print(f"处理批次数: {stats.get('processing_batches', 0)}")

# Show only the five most frequent words
print("\n高频词汇:")
for word, freq in stats.get('top_10_words', [])[:5]:
    print(f"  {word}: {freq} 次")

## 4. 练习题

### 练习1：并行图像处理系统
创建一个并行图像处理系统，支持多种滤镜效果：

In [None]:
# Exercise 1: implement a parallel image-processing system
# TODO: create several image-filter nodes
# TODO: process multiple images in parallel
# TODO: support different combinations of effects

print("请实现并行图像处理系统")

### 练习2：分布式数据分析
构建一个分布式数据分析系统：

In [None]:
# Exercise 2: implement a distributed data-analysis system
# TODO: shard the data and distribute the shards
# TODO: run the statistical analysis in parallel
# TODO: aggregate the results and generate a report

print("请实现分布式数据分析系统")

## 总结

在本课中，我们学习了并行执行和Map-Reduce模式：

### 关键要点：
1. **并行执行**：同时运行多个节点提高效率
2. **Map-Reduce**：分而治之的处理模式
3. **Send API**：动态创建并行任务
4. **结果聚合**：合并并行处理结果
5. **性能优化**：合理设计并行度和批次大小

### 最佳实践：
- **合理分片**：平衡负载和通信开销
- **错误处理**：处理并行任务中的异常
- **资源管理**：控制并发度避免资源耗尽
- **监控调试**：跟踪并行任务执行状态

### 应用场景：
- 大数据处理
- 图像/视频处理
- 科学计算
- 机器学习训练
- 爬虫和数据采集

## 下一课预告

在下一课《Human-in-the-Loop》中，我们将学习：
- 人机交互节点设计
- 中断机制实现
- 人工审批流程
- 输入验证和确认
- 交互式对话系统构建