### 09 从零实现一个角色扮演的聊天机器人

In [None]:
# Install/upgrade the LangChain xAI (Grok) integration quietly (-U upgrade, -qq quiet).
# NOTE(review): a later cell imports `langchain.prompts` (from the `langchain`
# package) and `tiktoken`, neither of which is requested here; confirm they are
# preinstalled in the runtime, or import ChatPromptTemplate/MessagesPlaceholder
# from `langchain_core.prompts` instead.
%pip install -Uqq langchain-xai

In [None]:
from google.colab import userdata
from langchain_xai import ChatXAI

# Chat model for the role-play bot: xAI's Grok, authenticated with the API key
# stored in the Colab secret named 'xai_api_key'.
# NOTE(review): confirm that the "grok-beta" model id is still served by xAI.
chat_model = ChatXAI(
    model="grok-beta",
    xai_api_key=userdata.get('xai_api_key'),
)

In [None]:
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.runnables import RunnableWithMessageHistory
import json
import os
from typing import List
import tiktoken
from langchain_core.messages import SystemMessage, trim_messages, BaseMessage, HumanMessage, AIMessage, ToolMessage
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder

def str_token_counter(text: str) -> int:
    """Count the tokens in *text* using tiktoken's ``o200k_base`` encoding.

    Args:
        text: The string to tokenize.

    Returns:
        The number of token ids ``o200k_base`` produces for *text*.
    """
    token_ids = tiktoken.get_encoding("o200k_base").encode(text)
    return len(token_ids)

def _content_text(content) -> str:
    """Return the text portion of a message ``content`` for token counting.

    ``content`` is either a plain string or a list of content blocks (assumed to
    follow LangChain's content-block format). In a list, plain strings and
    ``{"type": "text", "text": ...}`` blocks contribute their text. All other
    blocks (e.g. images) are ignored.
    """
    if isinstance(content, str):
        return content
    parts = []
    for block in content:
        if isinstance(block, str):
            parts.append(block)
        elif isinstance(block, dict) and block.get("type") == "text":
            parts.append(block.get("text") or "")
    return "".join(parts)


def tiktoken_counter(messages: List[BaseMessage]) -> int:
    """Approximate the token count of a chat request made of *messages*.

    The total is a fixed starting overhead, plus a fixed overhead per message,
    plus the tokens of each message's role, its text content and its optional
    ``name``. All string token counts come from ``str_token_counter``.

    Args:
        messages: Messages to count. ``content`` may be a string or a list of
            content blocks; for a list, only its text parts are counted.

    Returns:
        The approximate total number of tokens.

    Raises:
        ValueError: If a message is not a Human/AI/Tool/System message.
    """
    # Fixed overheads; presumably mirrors the OpenAI cookbook token-counting
    # recipe for chat messages (assumed, not verified against xAI's tokenizer).
    num_tokens = 3
    tokens_per_message = 3
    tokens_per_name = 1
    for msg in messages:
        # isinstance (not exact type) so chunk subclasses such as streamed
        # AI message chunks map to their parent role.
        if isinstance(msg, HumanMessage):
            role = "user"
        elif isinstance(msg, AIMessage):
            role = "assistant"
        elif isinstance(msg, ToolMessage):
            role = "tool"
        elif isinstance(msg, SystemMessage):
            role = "system"
        else:
            raise ValueError(f"Unsupported messages type {msg.__class__}")
        num_tokens += (
                tokens_per_message
                + str_token_counter(role)
                # content can be a list of blocks; tiktoken only accepts str.
                + str_token_counter(_content_text(msg.content))
        )
        if msg.name:
            num_tokens += tokens_per_name + str_token_counter(msg.name)
    return num_tokens

# Keep only the most recent messages that fit in 4096 tokens, as counted by
# tiktoken_counter. With strategy="last", trimming can leave an AI reply at the
# front of the history. start_on="human" drops any such leading non-human
# messages, so the conversation handed to the model opens with a user turn.
# NOTE(review): the system prompt below is added by `prompt` after trimming, so
# `include_system` only matters if a SystemMessage ever reaches the trimmer.
trimmer = trim_messages(
    max_tokens=4096,
    strategy="last",
    token_counter=tiktoken_counter,
    include_system=True,
    start_on="human",
)

# Role-play prompt. System text (Chinese): "You are now playing Confucius; reply
# in Confucius's style as much as possible, but do not use the phrase '子曰'
# ('The Master said')." The conversation goes into the "messages" placeholder.
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "你现在扮演孔子的角色，尽量按照孔子的风格回复，不要出现‘子曰’",
        ),
        MessagesPlaceholder(variable_name="messages"),
    ]
)

# Chat message history persisted to a JSON file, one file per session.
class FileChatMessageHistory(BaseChatMessageHistory):
    """Chat message history stored in ``<file_path>/<session_id>.json``.

    The file holds a JSON list of ``{"type": "human" | "ai", "content": ...}``
    objects. Every message that is not a ``HumanMessage`` is saved with type
    ``"ai"`` and is reloaded as an ``AIMessage``.
    """

    def __init__(self, session_id: str, file_path: str = "chat_histories"):
        """Create the history for *session_id* and load any saved messages.

        Args:
            session_id: Session identifier; used verbatim as the file-name stem.
                NOTE(review): it is not sanitized, so a value containing path
                separators would resolve outside ``file_path``.
            file_path: Directory holding the per-session JSON files; created if
                it does not exist.
        """
        super().__init__()
        self.session_id = session_id
        os.makedirs(file_path, exist_ok=True)
        self.file_path = os.path.join(file_path, f"{session_id}.json")
        self.messages = self._load_messages()

    def _load_messages(self) -> List[BaseMessage]:
        """Read the saved messages; return ``[]`` if the file is missing or unreadable."""
        if not os.path.exists(self.file_path):
            return []
        try:
            with open(self.file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            return [
                HumanMessage(content=msg['content']) if msg['type'] == 'human'
                else AIMessage(content=msg['content'])
                for msg in data
            ]
        except Exception as e:
            # Deliberate best effort: a corrupt or malformed file starts an
            # empty history instead of crashing the chat loop.
            print(f"加载历史消息时出错: {str(e)}")
            return []

    def _save_messages(self) -> None:
        """Write all in-memory messages to disk, replacing the file atomically.

        The JSON is written to a sibling ``.tmp`` file first and then moved over
        the real file with ``os.replace``. An interrupted write therefore cannot
        leave a truncated file, which ``_load_messages`` would discard, losing
        the whole history.
        """
        tmp_path = f"{self.file_path}.tmp"
        try:
            data = [
                {
                    'type': 'human' if isinstance(msg, HumanMessage) else 'ai',
                    'content': msg.content
                }
                for msg in self.messages
            ]
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, self.file_path)
        except Exception as e:
            # Best effort: report and keep chatting; the in-memory history is intact.
            print(f"保存消息时出错: {str(e)}")

    def add_message(self, message: BaseMessage) -> None:
        """Append one message and persist the history."""
        self.messages.append(message)
        self._save_messages()

    def add_messages(self, messages) -> None:
        """Append several messages and persist once, rather than once per message.

        Args:
            messages: Iterable of ``BaseMessage`` objects to append in order.
        """
        self.messages.extend(messages)
        self._save_messages()

    def clear(self) -> None:
        """Drop all in-memory messages and delete the session file, if any."""
        self.messages = []
        if os.path.exists(self.file_path):
            os.remove(self.file_path)

# Process-wide cache mapping session id -> FileChatMessageHistory.
session_histories = {}

def get_session_history(session_id: str) -> BaseChatMessageHistory:
    """Return the cached file-backed history for *session_id*.

    On the first request for a session, a ``FileChatMessageHistory`` is built
    (it loads any messages already saved for that session) and cached. Later
    requests return that same object.
    """
    try:
        return session_histories[session_id]
    except KeyError:
        history = FileChatMessageHistory(session_id=session_id)
        session_histories[session_id] = history
        return history

# Trim -> prompt -> model pipeline wrapped in LangChain's
# RunnableWithMessageHistory, which resolves each call's history via
# get_session_history using the "session_id" from the call's config.
with_message_history = RunnableWithMessageHistory(
    runnable=trimmer | prompt | chat_model,
    get_session_history=get_session_history,
)

# Fixed session id. get_session_history persists it under chat_histories/confucious.json.
config = {"configurable": {"session_id": "confucious"}}

# Interactive loop: quit with 'exit' (any case), Ctrl-D, or Ctrl-C at the prompt.
while True:
    try:
        user_input = input("You:> ").strip()
    except (EOFError, KeyboardInterrupt):
        # Closed stdin / Ctrl-D or Ctrl-C: leave cleanly instead of a traceback.
        print()
        break
    if user_input.lower() == 'exit':
        break
    if not user_input:
        # Don't send empty turns to the model or store them in the history.
        continue
    stream = with_message_history.stream(
        {"messages": [HumanMessage(content=user_input)]},
        config=config
    )
    # Print the reply incrementally as chunks arrive.
    for chunk in stream:
        print(chunk.content, end='', flush=True)
    print()