In [2]:
import json
import os
import random
from typing import Annotated, List, Optional, Dict, Any, Literal

from dotenv import load_dotenv

load_dotenv()

from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate

from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import AnyMessage, add_messages
from langgraph.checkpoint.memory import MemorySaver

from pydantic import BaseModel, Field, ConfigDict, field_validator

In [3]:
# Chat model shared by every agent: the "deepseek-chat" model reached through
# ChatOpenAI with a custom base URL; the API key is read from DEEPSEEK_API_KEY.
llm = ChatOpenAI(
    model="deepseek-chat",
    openai_api_key=os.getenv("DEEPSEEK_API_KEY"),
    openai_api_base='https://api.deepseek.com',
)

## 知识图谱检索类

In [None]:
# import os
# from typing import Dict, List, Any, Optional, Tuple
# import numpy as np
# from neo4j import GraphDatabase
# from dotenv import load_dotenv
# # import hnswlib  # 假设使用HNSW作为向量索引库
# from transformers import AutoTokenizer, AutoModel
# import torch



class KnowledgeGraphRetriever:
    """Retrieve an entity and its directly related entities from a Neo4j knowledge graph.

    Lookup is by exact entity name first; when that misses, the name is embedded with
    a HuggingFace model and matched against an HNSW vector index.

    NOTE(review): this class needs ``GraphDatabase`` (neo4j), ``hnswlib``, ``np``,
    ``torch`` and ``AutoTokenizer``/``AutoModel`` (transformers), whose imports are
    commented out in the preceding cell — they must be enabled before instantiation.
    """

    # rs_score cut-offs used to bucket related entities into high / medium / low relevance.
    HIGH_RELEVANCE_THRESHOLD = 0.6
    MEDIUM_RELEVANCE_THRESHOLD = 0.3

    def __init__(self):
        # Neo4j connection, configured from the environment.
        # NOTE(review): the "password" fallback is a hard-coded default credential;
        # consider requiring NEO4J_PASSWORD to be set explicitly.
        self.neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
        self.neo4j_user = os.getenv("NEO4J_USER", "neo4j")
        self.neo4j_password = os.getenv("NEO4J_PASSWORD", "password")
        self.neo4j_driver = GraphDatabase.driver(
            self.neo4j_uri,
            auth=(self.neo4j_user, self.neo4j_password)
        )

        # Vector index; vector_dim is the (reduced) dimension the index is built with.
        # Must be set before _load_vector_index, which reads it.
        self.vector_dim = 128
        self.vector_index = self._load_vector_index()

        # Model used to embed query text for the vector fallback.
        # NOTE(review): ChatGLM checkpoints usually need trust_remote_code=True to load — confirm.
        self.tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b")
        self.model = AutoModel.from_pretrained("THUDM/chatglm3-6b")
        self.model.eval()

    def _load_vector_index(self) -> "hnswlib.Index":
        """Create a cosine HNSW index and load it from VECTOR_INDEX_PATH.

        If loading fails the error is printed and an empty (but initialised) index is
        returned, so similarity lookups simply find nothing instead of failing later.
        The annotation is a string so the class can be defined without hnswlib imported.
        """
        index_path = os.getenv("VECTOR_INDEX_PATH", "./vector_index.bin")
        index = hnswlib.Index(space='cosine', dim=self.vector_dim)

        try:
            index.load_index(index_path)
            print(f"向量索引已加载，包含 {index.get_current_count()} 个向量")
        except Exception as exc:  # narrowed from a bare except (which also caught KeyboardInterrupt)
            print("向量索引加载失败，请确保已构建索引")
            print(f"  ({exc!r})")
            # An index that was never loaded/initialised cannot be queried safely;
            # give it minimal capacity so get_current_count() reports 0.
            index.init_index(max_elements=1)

        return index

    def _get_text_embedding(self, text: str) -> "np.ndarray":
        """Embed ``text`` as the mean of the model's last hidden state (first sample only)."""
        inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
        with torch.no_grad():
            outputs = self.model(**inputs)
            # NOTE(review): assumes last_hidden_state is (batch, seq, hidden) so dim=1 averages
            # over tokens; some ChatGLM implementations are sequence-first — confirm.
            # NOTE(review): no reduction to self.vector_dim is applied, so the model's hidden
            # size must already equal the index dimension — confirm.
            embeddings = outputs.last_hidden_state.mean(dim=1).numpy()
        return embeddings[0]

    def retrieve_by_entity_name(self, entity_name: str, limit: int = 10) -> Dict[str, Any]:
        """Retrieve an entity plus its related entities, knowledge paths and relevance stats.

        Args:
            entity_name: entity name, e.g. "二叉树深度优先遍历".
            limit: maximum number of related entities to return.

        Returns:
            A dict with "query_entity", "related_entities", "knowledge_paths" and
            "metadata", or ``{"error": ...}`` if no matching entity could be found.
        """
        # 1. Exact name match; otherwise fall back to the most similar vector hit.
        query_entity = self._get_entity_by_name(entity_name)
        if not query_entity:
            similar_entities = self._find_similar_entities_by_vector(entity_name, top_k=1)
            if similar_entities:
                query_entity = self._get_entity_by_id(similar_entities[0]["id"])

        # The vector hit's id may not exist in Neo4j, so re-check after the fallback.
        if not query_entity:
            return {"error": f"找不到与'{entity_name}'相关的实体"}

        # 2. Directly related entities (ordered by rs_score) and 3. their paths.
        related_entities = self._get_related_entities(query_entity["id"], limit)
        knowledge_paths = self._build_knowledge_paths(query_entity["id"], related_entities)

        # 4. Bucket related entities by rs_score.
        relevance_stats = {
            "high_relevance_count": 0,
            "medium_relevance_count": 0,
            "low_relevance_count": 0
        }
        for related in related_entities:
            # A relation without an rs_score property comes back as None; count it as low.
            rs_score = related["relation"]["rs_score"] or 0.0
            if rs_score >= self.HIGH_RELEVANCE_THRESHOLD:
                relevance_stats["high_relevance_count"] += 1
            elif rs_score >= self.MEDIUM_RELEVANCE_THRESHOLD:
                relevance_stats["medium_relevance_count"] += 1
            else:
                relevance_stats["low_relevance_count"] += 1

        # 5. Assemble the result.
        return {
            "query_entity": query_entity,
            "related_entities": related_entities,
            "knowledge_paths": knowledge_paths,
            "metadata": {
                "total_related_entities": len(related_entities),
                **relevance_stats,
                "query_timestamp": self._get_current_timestamp()
            }
        }

    @staticmethod
    def _record_to_entity(record) -> Dict[str, Any]:
        """Convert a record with id/name/type/definition/attributes columns into a dict."""
        return {key: record[key] for key in ("id", "name", "type", "definition", "attributes")}

    def _get_entity_by_name(self, entity_name: str) -> Optional[Dict[str, Any]]:
        """Return the first :Entity whose ``name`` equals ``entity_name``, or None."""
        with self.neo4j_driver.session() as session:
            result = session.run(
                """
                MATCH (e:Entity)
                WHERE e.name = $name
                RETURN e.id AS id, e.name AS name, e.type AS type, 
                       e.definition AS definition, e.attributes AS attributes
                LIMIT 1
                """,
                name=entity_name
            )
            record = result.single()
            return self._record_to_entity(record) if record else None

    def _get_entity_by_id(self, entity_id: str) -> Optional[Dict[str, Any]]:
        """Return the :Entity whose ``id`` equals ``entity_id``, or None."""
        with self.neo4j_driver.session() as session:
            result = session.run(
                """
                MATCH (e:Entity {id: $id})
                RETURN e.id AS id, e.name AS name, e.type AS type, 
                       e.definition AS definition, e.attributes AS attributes
                LIMIT 1
                """,
                id=entity_id
            )
            record = result.single()
            return self._record_to_entity(record) if record else None

    def _find_similar_entities_by_vector(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Return up to ``top_k`` entities whose vectors are nearest to ``text``.

        Each item is ``{"id": ..., "similarity": 1 - cosine_distance}``. An empty index
        yields an empty list (asking HNSW for more neighbours than it holds raises).
        """
        k = min(top_k, self.vector_index.get_current_count())
        if k <= 0:
            return []

        query_vector = self._get_text_embedding(text)
        labels, distances = self.vector_index.knn_query(query_vector, k=k)

        # Labels are assumed to map to ids formatted as E001, E002, ... (original convention).
        return [
            {"id": f"E{int(label):03d}", "similarity": float(1 - distance)}
            for label, distance in zip(labels[0], distances[0])
        ]

    def _get_related_entities(self, entity_id: str, limit: int = 10) -> List[Dict[str, Any]]:
        """Return outgoing neighbours of ``entity_id`` with relation details, highest rs_score first."""
        with self.neo4j_driver.session() as session:
            result = session.run(
                """
                MATCH (source:Entity {id: $id})-[r]->(target:Entity)
                RETURN source.id AS source_id, target.id AS target_id, 
                       target.name AS target_name, target.type AS target_type,
                       target.definition AS target_definition, target.attributes AS target_attributes,
                       type(r) AS relation_type, r.description AS relation_description,
                       r.rs_score AS rs_score, r.md_value AS md_value, 
                       r.dw_value AS dw_value, r.ss_value AS ss_value
                ORDER BY r.rs_score DESC
                LIMIT $limit
                """,
                id=entity_id,
                limit=limit
            )

            return [
                {
                    "entity": {
                        "id": record["target_id"],
                        "name": record["target_name"],
                        "type": record["target_type"],
                        "definition": record["target_definition"],
                        "attributes": record["target_attributes"]
                    },
                    "relation": {
                        "type": record["relation_type"],
                        "description": record["relation_description"],
                        "rs_score": record["rs_score"],
                        "md_value": record["md_value"],
                        "dw_value": record["dw_value"],
                        "ss_value": record["ss_value"]
                    }
                }
                for record in result
            ]

    def _build_knowledge_paths(self, entity_id: str, related_entities: List[Dict[str, Any]]) -> List[List[str]]:
        """Return one [source_name, relation_type, target_name] triple per related entity.

        The source name is read from Neo4j; an unknown ``entity_id`` yields an empty list.
        """
        with self.neo4j_driver.session() as session:
            result = session.run(
                "MATCH (e:Entity {id: $id}) RETURN e.name AS name",
                id=entity_id
            )
            record = result.single()

        entity_name = record["name"] if record else None
        if not entity_name:
            return []

        return [
            [entity_name, related["relation"]["type"], related["entity"]["name"]]
            for related in related_entities
        ]

    def _get_current_timestamp(self) -> str:
        """Return the current local time in ISO-8601 format."""
        from datetime import datetime
        return datetime.now().isoformat()

    def close(self):
        """Close the Neo4j driver."""
        self.neo4j_driver.close()

## 智能体间的通讯状态

知识图谱检索返回示例如下：

```json
{
  "query_entity": {
    "id": "E042",
    "name": "二叉树深度优先遍历",
    "type": "算法",
    "definition": "一种遍历二叉树的算法，从根节点开始，尽可能深地搜索树的分支，直到到达叶子节点，然后回溯到前一个节点继续搜索其他分支",
    "attributes": {
      "实现方式": ["递归实现", "非递归实现(栈)"],
      "时间复杂度": "O(n)",
      "空间复杂度": "最坏情况O(h)，h为树的高度"
    },
    "vector_embedding": [0.23, 0.45, ..., 0.12]  // 128维向量
  },
  "related_entities": [
    {
      "entity": {
        "id": "E043",
        "name": "前序遍历",
        "type": "算法",
        "definition": "一种特殊的深度优先遍历，访问顺序为：根节点、左子树、右子树"
      },
      "relation": {
        "type": "分类关系",
        "description": "前序遍历是深度优先遍历的一种特定实现",
        "rs_score": 0.82,  //高关联度
        "md_value": 0,
        "dw_value": 0.15,
        "ss_value": 0.91
      }
    },
    {
      "entity": {
        "id": "E044",
        "name": "广度优先遍历",
        "type": "算法",
        "definition": "一种遍历二叉树的算法，按层次从上到下、从左到右访问所有节点"
      },
      "relation": {
        "type": "比较关系",
        "description": "与深度优先遍历形成对比的另一种主要遍历方式",
        "rs_score": 0.58,  //中关联度
        "md_value": 0,
        "dw_value": 0.14,
        "ss_value": 0.65
      }
    },
    // 更多相关实体...
  ],
  "knowledge_paths": [
    ["二叉树深度优先遍历", "是...的一种", "树的遍历算法"],
    ["二叉树深度优先遍历", "可以通过", "递归"],
    ["二叉树深度优先遍历", "可以通过", "栈"]
  ],
  "metadata": {
    "total_related_entities": 12,
    "high_relevance_count": 4,
    "medium_relevance_count": 5,
    "low_relevance_count": 3,
    "query_timestamp": "2024-05-15T08:30:45Z"
  }
}
```

In [5]:
class State(BaseModel):
    """Graph state shared by router_agent, student_agent and teacher_agent."""
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Conversation history; node updates to this field are merged via the add_messages reducer.
    messages: Annotated[List[AnyMessage], add_messages] = Field(default_factory=list, title="对话列表")
    # Related knowledge-graph entities selected by router_agent for the downstream agent.
    knowledge_results: list = Field(default_factory=list, title="从存储数据结构知识图谱的向量数据库搜索到的相关知识点信息")
    current_knowledge_point: str = Field(default="", title="当前知识点")
    is_right: bool = Field(default=False, title="用户回复是否正确")
    # Routing hint read by router(): "teacher_agent", "student_agent" or "random".
    next_node: str = Field(default="teacher_agent", title="下一个节点")
    success: bool = Field(default=False, title="节点执行是否成功")
    log: str = Field(default="", title="节点执行日志")

    @field_validator("messages", mode="before")
    @classmethod
    def validate_messages(cls, v, info):
        """Normalise raw ``messages`` input to a list before per-item validation.

        None becomes an empty list, a list/tuple is copied into a list, and a single
        message is wrapped in a list.
        """
        # The previous version also merged with info.data["messages"], but info.data only
        # contains fields validated *before* this one and messages is the first field,
        # so that branch could never run.
        if v is None:
            return []
        if isinstance(v, (list, tuple)):
            return list(v)
        return [v]
    

## Router Agent

接收历史对话信息以及最新的用户回复进行路由选择

- 如果用户回复内容错误
    - 直接转向 Teacher Agent

- 如果用户回复内容正确，通过 rag 技术，在知识图谱中检索，进行知识点关联度判断
    - 关联度低，转向 Teacher Agent，进行总结
    - 关联度高，转向 Student Agent，进行追问
    - 关联度中，随机选择

In [6]:
class ReplyEvaluation(BaseModel):
    """
    用户回复评估结果
    - is_right: 用户回复是否正确
    - current_knowledge_point: 涉及的知识点
    """
    # NOTE: with_structured_output derives its schema from this model, so the docstring
    # and field titles are presumably sent to the LLM — they are kept verbatim.
    is_right: bool = Field(..., title="用户回复是否正确")
    current_knowledge_point: str = Field(..., title="涉及的主要知识点")


# rs_score cut-offs, matching the buckets in KnowledgeGraphRetriever.retrieve_by_entity_name
# (>= 0.6 high, >= 0.3 medium, otherwise low).
_HIGH_RELEVANCE_THRESHOLD = 0.6
_MEDIUM_RELEVANCE_THRESHOLD = 0.3


def _rs_score(related_entity):
    """Return a related entity's rs_score, treating a missing (None) score as 0.0."""
    return related_entity["relation"]["rs_score"] or 0.0


def router_agent(state: State) -> State:
    """Evaluate the user's latest reply and decide which agent responds next.

    An LLM judges whether the reply is correct and names the main knowledge point.
    For a correct reply the knowledge graph is queried and the route depends on the
    rs_score of related entities:

    - wrong reply                         -> teacher_agent
    - retrieval error / nothing related   -> teacher_agent
    - any high relevance (>= 0.6)         -> student_agent, with those entities
    - otherwise any medium (>= 0.3)       -> "random" (resolved by router), with those entities
    - only low relevance                  -> teacher_agent

    Any exception yields State(success=False, log=<error>); next_node then keeps
    its default value.
    """
    try:
        conversation = ""
        for message in state.messages:
            if isinstance(message, HumanMessage):
                conversation += f"user: {message.content}\n"
            elif isinstance(message, AIMessage):
                conversation += f"assistant: {message.content}\n"

        # Step 1: have the LLM judge correctness and extract the knowledge point.
        prompt = ChatPromptTemplate([
            ("system", "你是一名数据结构领域的专家，现在需要判断用户回复是否正确，并给出涉及的主要知识点"),
            ("human", "{conversation}"),
        ])
        reply_evaluation_chain = prompt | llm.with_structured_output(ReplyEvaluation)
        reply_evaluation = reply_evaluation_chain.invoke({"conversation": conversation})
        knowledge_point = reply_evaluation.current_knowledge_point

        if not reply_evaluation.is_right:
            return State(current_knowledge_point=knowledge_point, is_right=False, next_node="teacher_agent", success=True, log="用户回复错误，转向TeacherAgent")

        # Step 2: the reply is correct, so look up related knowledge points.
        # NOTE(review): retrieve_knowledge_for_routing is not defined in this notebook;
        # presumably it wraps KnowledgeGraphRetriever.retrieve_by_entity_name — confirm.
        kg_result = retrieve_knowledge_for_routing(knowledge_point)
        if "error" in kg_result:
            return State(current_knowledge_point=knowledge_point, is_right=True, next_node="teacher_agent", success=False, log="知识图谱检索失败，转向TeacherAgent")

        if kg_result["metadata"]["total_related_entities"] == 0:
            return State(current_knowledge_point=knowledge_point, is_right=True, next_node="teacher_agent", success=False, log="无相关知识点，转向TeacherAgent")

        related_entities = kg_result["related_entities"]

        # Step 3: route by relevance. The filters are complete lists (the old code returned
        # from inside the loop after the first entity) and use the retriever's >= thresholds.
        high_relevance_entities = [e for e in related_entities if _rs_score(e) >= _HIGH_RELEVANCE_THRESHOLD]
        if high_relevance_entities:
            return State(current_knowledge_point=knowledge_point, knowledge_results=high_relevance_entities, is_right=True, next_node="student_agent", success=True, log="关联度高，转向StudentAgent")

        medium_relevance_entities = [e for e in related_entities if _rs_score(e) >= _MEDIUM_RELEVANCE_THRESHOLD]
        if medium_relevance_entities:
            return State(current_knowledge_point=knowledge_point, knowledge_results=medium_relevance_entities, is_right=True, next_node="random", success=True, log="关联度中,随机选择")

        # Only low-relevance neighbours: the teacher summarises (previously fell through to None).
        return State(current_knowledge_point=knowledge_point, is_right=True, next_node="teacher_agent", success=True, log="关联度低，转向TeacherAgent")

    except Exception as e:
        return State(success=False, log=str(e))

## Student Agent

根据从 Router Agent 转来的状态进行回复

In [None]:
def student_agent(state: State) -> State:
    """Play a learner who asks the user a follow-up question.

    The system prompt embeds state.current_knowledge_point and, when router_agent
    supplied any, state.knowledge_results (serialised as JSON). Returns a State
    holding the model's reply message, or success=False with the error text.
    """
    try:
        conversation = ""
        for message in state.messages:
            if isinstance(message, HumanMessage):
                conversation += f"user: {message.content}\n"
            elif isinstance(message, AIMessage):
                conversation += f"assistant: {message.content}\n"

        if state.knowledge_results:
            system_template = """
            你是一名正在学习数据结构的学生， 针对当前知识点：{current_knowledge_point}，用户已经做了正确的讲解，请结合可能涉及的相关知识点：{knowledge_results}，进一步追问。
            """
        else:
            system_template = """
            你是一名正在学习数据结构的学生， 针对当前知识点：{current_knowledge_point}，用户已经做了正确的讲解，进一步追问。
            """

        # The tuple form takes the template *string*; the variables it declares must be
        # supplied to invoke (previously only "conversation" was passed).
        prompt = ChatPromptTemplate([
            ("system", system_template),
            ("human", "{conversation}"),
        ])
        student_agent_chain = prompt | llm
        reply = student_agent_chain.invoke({
            "conversation": conversation,
            "current_knowledge_point": state.current_knowledge_point,
            # default=str keeps non-JSON values (e.g. driver-specific types) from failing.
            "knowledge_results": json.dumps(state.knowledge_results, ensure_ascii=False, default=str),
        })
        # The chain already returns a message; wrapping it in AIMessage(content=...) was invalid.
        return State(success=True, log="StudentAgent生成成功", messages=[reply])

    except Exception as e:
        return State(success=False, log=str(e))


## Teacher Agent

根据从路由智能体转来的状态进行回复

- 如果 is_right == False，即用户回答错误，teacher_agent 进行错误分类与讲解建议
- 如果 is_right == True，那根据 log 的情况进行回复
    - 知识图谱搜索失败，根据历史对话信息适当回复
    - 无关联知识点或关联度低，总结评价
    - 中关联度，结合关联知识点，总结评价

In [9]:

def teacher_agent(state: State) -> State:
    """Reply as a data-structures expert.

    - wrong answer (not state.is_right): give explanation advice;
    - correct answer with knowledge_results: summarise using the related knowledge points;
    - correct answer otherwise: give a summary evaluation.

    Returns a State holding the model's reply message, or success=False with the error text.
    """
    try:
        conversation = ""
        for message in state.messages:
            if isinstance(message, HumanMessage):
                conversation += f"user: {message.content}\n"
            elif isinstance(message, AIMessage):
                conversation += f"assistant: {message.content}\n"

        if not state.is_right:
            system_template = """
            你是一名数据结构领域的专家， 针对当前知识点：{current_knowledge_point}，用户回答错误，请给出讲解建议。
            """
        elif state.knowledge_results:
            system_template = """
            你是一名数据结构领域的专家， 针对当前知识点：{current_knowledge_point}，用户回答正确，请结合可能涉及的相关知识点：{knowledge_results}，给出总结评价。
            """
        else:
            system_template = """
            你是一名数据结构领域的专家， 针对当前知识点：{current_knowledge_point}，用户回答正确，请给出总结评价。
            """

        # The tuple form takes the template *string*; every variable it declares must be
        # supplied to invoke (previously only "conversation" was passed).
        prompt = ChatPromptTemplate([
            ("system", system_template),
            ("human", "{conversation}"),
        ])
        teacher_agent_chain = prompt | llm
        reply = teacher_agent_chain.invoke({
            "conversation": conversation,
            "current_knowledge_point": state.current_knowledge_point,
            # default=str keeps non-JSON values (e.g. driver-specific types) from failing.
            "knowledge_results": json.dumps(state.knowledge_results, ensure_ascii=False, default=str),
        })
        # The chain already returns a message; wrapping it in AIMessage(content=...) was invalid.
        return State(success=True, log="TeacherAgent生成成功", messages=[reply])
    except Exception as e:
        return State(success=False, log=str(e))


## Router

In [10]:
def router(state: State) -> Literal["student_agent", "teacher_agent", "__end__"]:
    """Conditional-edge function: map state.next_node to the name of the next node.

    "random" is resolved here by choosing student_agent or teacher_agent uniformly;
    any unrecognised value ends the run. Requires the module-level ``random`` import.
    """
    next_node = state.next_node
    if next_node == "random":
        return random.choice(["student_agent", "teacher_agent"])
    if next_node in ("teacher_agent", "student_agent"):
        return next_node
    return "__end__"


## MainGraph

In [11]:
workflow = StateGraph(State)

# The "router" node runs router_agent, which evaluates the reply and writes next_node;
# router() is only the edge function that reads it. Registering router() itself as the
# node made the node return a bare string instead of a state update.
workflow.add_node("router", router_agent)
workflow.add_node("student_agent", student_agent)
workflow.add_node("teacher_agent", teacher_agent)

workflow.add_edge(START, "router")
workflow.add_conditional_edges(
    "router",
    router,
    {
        "student_agent": "student_agent",
        "teacher_agent": "teacher_agent",
        "__end__": END
    }
)

# The student's follow-up question ends the turn so the user can answer it; chaining it
# into teacher_agent would produce a summary before the user has replied.
workflow.add_edge("student_agent", END)
workflow.add_edge("teacher_agent", END)

# In-memory checkpointer: saved state lives only as long as this kernel.
memory = MemorySaver()

graph = workflow.compile(checkpointer=memory)


In [None]:
from IPython.display import Image, display

# xray controls how deeply nested subgraphs are expanded in the drawing; this graph has
# no subgraphs, so the value only matters if some are added later.
graph_png = graph.get_graph(xray=2).draw_mermaid_png()
display(Image(graph_png))