In [2]:
from typing import Annotated, List
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from typing_extensions import TypedDict
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from langgraph.prebuilt import tools_condition
from langchain_openai import ChatOpenAI
import yaml
import logging
from vector_store import VectorStoreManager
from tools import Tools
from pathlib import Path
from logger import LoggerManager
import tiktoken
from IPython.display import Image, display
from time import sleep
import json
from langchain_core.messages import ToolMessage

# Load the configuration file
def load_config():
    """Read ``config.yaml`` from the current working directory.

    Progress and failures are reported with ``print``.

    Returns:
        The object produced by ``yaml.safe_load`` for the file.

    Raises:
        FileNotFoundError: ``config.yaml`` does not exist.
        ValueError: the file parsed to ``None``.
        Exception: any other error raised while opening or parsing the file;
            it is printed and then re-raised unchanged.
    """
    try:
        cfg_file = Path("config.yaml")
        resolved = cfg_file.absolute()
        print(f"尝试加载配置文件: {resolved}")

        if not cfg_file.exists():
            raise FileNotFoundError(f"配置文件不存在: {resolved}")

        with cfg_file.open('r', encoding='utf-8') as stream:
            parsed = yaml.safe_load(stream)

        if parsed is None:
            raise ValueError("配置文件为空或格式错误")

        print("成功加载配置文件")
        return parsed
    except Exception as err:
        # Report the problem to the console, then let the caller see the error.
        print(f"加载配置文件时出错: {err}")
        raise
        
# Keep the conversation within the configured token budget
def trim_messages(messages: List[BaseMessage]) -> List[BaseMessage]:
    """Shrink ``messages`` when their total token count exceeds the limit.

    Token counts come from the module-level ``tokenizer`` applied to
    ``str(msg.content)``. Settings are read from ``config['messages']``:
    ``max_context_length``, ``trim_strategy`` and ``keep_latest``.

    Args:
        messages: Conversation messages, oldest first.

    Returns:
        ``messages`` unchanged when the total is within the limit. Otherwise,
        for the ``'summary'`` strategy, one ``AIMessage`` holding a
        model-written summary followed by the latest ``keep_latest``
        messages; for any other strategy, only the latest ``keep_latest``
        messages. If anything fails, the error is logged and the latest
        ``keep_latest`` messages are returned (or ``messages`` unchanged when
        that setting cannot be read either).
    """

    def latest(count):
        # ``messages[-0:]`` is the WHOLE list, so a non-positive count must be
        # handled explicitly to really mean "keep nothing".
        return messages[-count:] if count > 0 else []

    try:
        # Tokenize each message once; the counts are reused for the warning below.
        token_counts = [len(tokenizer.encode(str(msg.content))) for msg in messages]
        total_tokens = sum(token_counts)
        logger.info(f"当前消息总token数: {total_tokens}")

        settings = config['messages']
        max_tokens = settings['max_context_length']
        if total_tokens <= max_tokens:
            return messages

        for count in token_counts:
            if count > max_tokens:
                logger.warning(f"发现过长消息: {count} tokens")

        if settings['trim_strategy'] != 'summary':
            return latest(settings['keep_latest'])

        # ``summary_llm`` is not defined in the visible part of this module.
        # Use it when some other code defines it, otherwise fall back to the
        # main chat model so the summary strategy does not fail with NameError.
        summarizer = globals().get('summary_llm') or globals().get('llm')
        if summarizer is None:
            logger.warning("未找到可用于总结的模型，改用截断策略")
            return latest(settings['keep_latest'])

        history_text = "\n".join(
            f"{msg.__class__.__name__}: {msg.content}" for msg in messages
        )
        summary_prompt = f"""
            请高度浓缩以下对话历史的关键信息，保留最重要的上下文。
            总结应简洁明了，确保包含对话的核心主题和关键决策。
            
            对话历史：
            {history_text}
            """

        summary_response = summarizer.invoke([HumanMessage(content=summary_prompt)])
        summary_message = AIMessage(content=f"对话历史总结: {summary_response.content}")
        return [summary_message] + latest(settings['keep_latest'])

    except Exception as e:
        logger.error(f"消息裁剪失败: {str(e)}", exc_info=True)
        try:
            return latest(config['messages']['keep_latest'])
        except (KeyError, TypeError):
            # The setting itself is missing or malformed, so there is no safe
            # cut-off; hand the messages back untouched instead of raising again.
            return messages

# Query the vector store
def search_vector_db(query: str, k: int = 3) -> str:
    """Search the vector store for content similar to ``query``.

    Args:
        query: Text passed to ``vector_store_manager.similarity_search``.
        k: Forwarded as the ``k`` keyword of that call (default 3).

    Returns:
        A numbered, newline-separated listing of each hit's ``page_content``
        (with its ``metadata`` appended when that is truthy), or a short
        message when the manager is falsy, nothing was found, or the search
        raised an exception.
    """
    if not vector_store_manager:
        return "错误: 向量存储管理器未初始化"

    try:
        hits = vector_store_manager.similarity_search(query, k=k)
        if not hits:
            return "未找到相关内容"

        entries = []
        for position, doc in enumerate(hits, start=1):
            entry = f"{position}. {doc.page_content}"
            if doc.metadata:
                entry += ' 元数据: ' + str(doc.metadata)
            entries.append(entry)

        return "找到以下相关内容：\n" + "\n".join(entries)
    except AttributeError as err:
        # Logged under a different message than other failures.
        logger.error(f"向量数据库搜索失败: {str(err)}")
        return f"搜索失败: {str(err)}"
    except Exception as err:
        logger.error(f"未知错误: {str(err)}")
        return f"搜索失败: {str(err)}"


class State(TypedDict):
    """State schema handed to ``StateGraph`` when the graph is built.

    The ``chatbot`` node reads both keys and returns new values for both.
    """
    # Message list, annotated with ``add_messages`` from ``langgraph.graph.message``.
    messages: Annotated[list, add_messages]
    # Conversation history; declared without any ``Annotated`` metadata.
    chat_history: List[BaseMessage]  # Added conversation-history field

def _dedupe_by_id(messages):
    """Return ``messages`` without later repeats of an already-seen ``id``."""
    seen_ids = set()
    unique = []
    for msg in messages:
        msg_id = getattr(msg, 'id', None)
        # Messages without an id cannot be recognised as duplicates of each
        # other, so they are always kept.
        if msg_id is not None:
            if msg_id in seen_ids:
                continue
            seen_ids.add(msg_id)
        unique.append(msg)
    return unique


def _is_tool_response(message, tool_call_id):
    """Tell whether ``message`` carries the response to ``tool_call_id``."""
    return (
        isinstance(message, (AIMessage, ToolMessage))
        and hasattr(message, 'tool_call_id')
        and message.tool_call_id == tool_call_id
    )


def _pair_tool_calls(messages):
    """Place each tool call's response directly after the calling message."""
    paired = []
    consumed = set()  # positions of responses already emitted next to their call
    for index, message in enumerate(messages):
        if index in consumed:
            continue
        paired.append(message)

        extras = getattr(message, 'additional_kwargs', None) or {}
        if not (isinstance(message, AIMessage) and 'tool_calls' in extras):
            continue

        for tool_call in extras['tool_calls']:
            tool_call_id = tool_call['id']
            # Look ahead for the response without discarding what lies between.
            match = next(
                (
                    pos
                    for pos in range(index + 1, len(messages))
                    if pos not in consumed
                    and _is_tool_response(messages[pos], tool_call_id)
                ),
                None,
            )
            if match is None:
                # A tool call must be answered by a tool message, so the
                # placeholder is a ToolMessage rather than an AIMessage.
                paired.append(
                    ToolMessage(
                        content="Tool response not found",
                        tool_call_id=tool_call_id,
                    )
                )
            else:
                consumed.add(match)
                paired.append(messages[match])
    return paired


def chatbot(state: State):
    """Graph node: send the conversation to the model and return its reply.

    The stored ``chat_history`` and the current ``messages`` are merged,
    de-duplicated by message id, trimmed with ``trim_messages`` and arranged
    so every tool call is followed by its response (or a placeholder) before
    ``llm_with_tools`` is invoked.

    Args:
        state: Mapping with a ``messages`` key and an optional
            ``chat_history`` key.

    Returns:
        A dict with ``messages`` (a one-element list holding the reply) and
        ``chat_history`` (the messages that were sent plus the reply). On any
        error the exception is logged and ``messages`` holds a fixed apology
        while ``chat_history`` is returned as it was read.
    """
    # Bound before the try block so the error handler can always return it,
    # even when validation fails before the state has been read.
    chat_history = []
    try:
        if not isinstance(state, dict) or "messages" not in state:
            raise ValueError("无效的状态对象")

        current_messages = state["messages"]
        chat_history = state.get("chat_history") or []

        if not current_messages:
            return {
                "messages": [AIMessage(content="没有收到消息")],
                "chat_history": chat_history
            }

        unique_messages = _dedupe_by_id(list(chat_history) + list(current_messages))
        trimmed_messages = trim_messages(unique_messages)

        if not all(isinstance(msg, BaseMessage) for msg in trimmed_messages):
            raise ValueError("所有的消息应该都是 BaseMessages 类型")

        final_messages = _pair_tool_calls(trimmed_messages)
        response = llm_with_tools.invoke(final_messages)
        return {
            "messages": [response],
            "chat_history": final_messages + [response]
        }
    except Exception as e:
        logger.error(f"Chatbot 函数执行中出现错误: {str(e)}", exc_info=True)
        return {
            "messages": [AIMessage(content="对话出现问题，请稍后再试。")],
            "chat_history": chat_history
        }

config = load_config()

# Basic logging configuration. ``logging.basicConfig`` has no ``log_file_path``
# keyword and raises ValueError for unknown keywords when it actually
# configures the root logger, so the log file path is passed through the
# supported ``filename`` keyword, placed inside the configured directory.
log_dir = Path(config['logging']['directory'])
log_dir.mkdir(parents=True, exist_ok=True)
logging.basicConfig(
    level=getattr(logging, config['logging']['level']),  # log level, looked up by name
    format=config['logging']['format'],  # record format string
    filename=log_dir / config['logging']['filename_pattern'],  # log file inside log_dir
)

# Module logger, supplied by the project's LoggerManager
logger_manager = LoggerManager(config)
logger = logger_manager.get_logger()

# Tokenizer for the encoding named in the configuration
tokenizer = tiktoken.get_encoding(config['model']['encoding_name'])

# Vector store manager
vector_store_manager = VectorStoreManager(config)

# Tool collection, built on top of the vector store manager
tools_instance = Tools(config, vector_store_manager)

# OpenAI chat model
llm = ChatOpenAI(
    model=config['model']['name'],  # model name
    temperature=config['model']['temperature'],  # sampling temperature
)

# Bind the tools to the chat model
llm_with_tools = llm.bind_tools(tools_instance.tool_list)

# In-memory checkpointer
memory = MemorySaver()

# Graph builder over the State schema
graph_builder = StateGraph(State)

# Chatbot node
graph_builder.add_node("chatbot", chatbot)

# Tool node
tool_node = ToolNode(tools=tools_instance.tool_list)
graph_builder.add_node("tools", tool_node)

# Conditional edges leaving the chatbot node, routed by tools_condition
graph_builder.add_conditional_edges(
    "chatbot",
    tools_condition,
)

# Edge from the tool node back to the chatbot node
graph_builder.add_edge("tools", "chatbot")

# The chatbot node is the entry point
graph_builder.set_entry_point("chatbot")

# Compile the graph with the in-memory checkpointer
graph = graph_builder.compile(checkpointer=memory)

# try:
#     display(Image(graph.get_graph().draw_mermaid_png()))
# except Exception as e:
#     logger.error(f"图形可视化失败: {str(e)}")

if __name__ == "__main__":
    # Interactive console loop: read a line, stream it through the graph,
    # print the replies, and carry the conversation history into the next turn.
    # Entering "q" (any case) ends the chat.
    current_chat_history = []  # conversation history carried between turns
    while True:
        user_input = input("你: ")
        if user_input.lower() == "q":
            print("聊天结束。")
            break

        inputs = {
            "messages": [HumanMessage(content=user_input)],
            "chat_history": current_chat_history,  # history from the previous turns
        }
        graph_settings = config['graph']
        config_dict = {
            "configurable": {
                "thread_id": graph_settings.get('thread_id', 'default_thread'),
                "checkpoint_ns": graph_settings.get('checkpoint_ns', 'default_namespace'),
                "checkpoint_id": graph_settings.get('checkpoint_id', 'default_checkpoint'),
            }
        }
        try:
            for output in graph.stream(inputs, config=config_dict):
                node_update = output.get('chatbot', {})
                replies = node_update.get('messages', [])
                has_history = 'chat_history' in node_update

                if replies and isinstance(replies[0], BaseMessage):
                    first_reply = replies[0]
                    if (hasattr(first_reply, 'additional_kwargs')
                            and 'tool_calls' in first_reply.additional_kwargs):
                        # NOTE(review): the graph also routes these calls to its
                        # "tools" node; confirm that process_tool_call here does
                        # not execute each tool a second time.
                        for tool_call in first_reply.additional_kwargs['tool_calls']:
                            # Strip a stray 'self' argument before dispatching the call.
                            call_args = json.loads(tool_call['function']['arguments'])
                            call_args.pop('self', None)
                            tool_call['function']['arguments'] = json.dumps(call_args)
                            tool_response = tools_instance.process_tool_call(tool_call)

                            # Wrap the tool output in a message tagged with the call id.
                            new_message = AIMessage(
                                content=tool_response['output'],
                                additional_kwargs={'tool_call_id': tool_response['tool_call_id']}
                            )

                            # Record the tool output in this update's history.
                            if has_history:
                                node_update['chat_history'].append(new_message)

                    # Print every reply that has non-empty content.
                    for reply in replies:
                        if reply.content is not None and reply.content != "":
                            print(f"AI: {reply.content}")

                # Carry the updated history into the next turn.
                if has_history:
                    current_chat_history = node_update['chat_history']

        except Exception as e:
            logger.error(f"聊天流处理错误: {str(e)}", exc_info=True)
            print("AI: 对话中出现问题，请稍后再试。")

        sleep(1)  # short pause between turns

尝试加载配置文件: /home/gaozheng/openai-quickstart/gaozheng/config.yaml
成功加载配置文件


你:  你好啊


AI: 你好！有什么我可以帮助你的吗？


你:  q


聊天结束。
