In [1]:
import json
import re
import qianfan
import json
import qianfan
import functools

from typing import List, Tuple, Union, Any

from langchain.agents import BaseSingleActionAgent, AgentExecutor
from langchain.callbacks.base import Callbacks
from langchain.schema import AgentAction, BaseMessage, HumanMessage, AIMessage, AgentFinish
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from langchain.tools import StructuredTool
from langchain.pydantic_v1 import BaseModel, Field
from langchain.chat_models import QianfanChatEndpoint
from langchain.embeddings import CacheBackedEmbeddings, QianfanEmbeddingsEndpoint
from langchain.schema import Document
from langchain.storage import LocalFileStore
from langchain.vectorstores.faiss import FAISS

In [2]:
# Priming prompt (in Chinese) sent as the first human turn by CustomAgent. It tells the model
# that it may call external tools and lists them through the `{tool_list}` placeholder. That
# placeholder is filled with str.format, so any literal brace added to this text must be doubled.
# It requires a tool call to be exactly two lines:
#   "Action: <tool name>" / "Action Input: <JSON arguments>"
# It asks for a plain answer when no tool is needed. It also asks the model to acknowledge
# with "好的" ("OK") before the conversation starts.
INITIAL_PROMPT = """在接下来的所有对话中，你可以使用外部的工具来回答问题。
你必须按照规定的格式来使用工具，当你使用工具时，我会在下一轮对话给你工具调用结果，然后你应该根据实际结果判断是否需要进一步使用工具，或给出你的回答。
工具可能有多个，每个工具由名称、描述、参数组成，参数符合标准的json schema。

下面是工具列表:
{tool_list}
如果你需要使用外部工具，那么你的输出必须按照如下格式，只包含2行，不需要输出任何解释或其他无关内容:
Action: 使用的工具名称
Action Input: 使用工具的参数，json格式

如果你不需要使用外部工具，不需要输出Action和Action Input，请输出你的回答。

如果你明白了，请直接回答"好的"，然后让我们开始。"""


class CustomAgent(BaseSingleActionAgent):
    """Single-action agent that drives a chat model through a plain-text tool protocol.

    Every planning step replays the whole conversation. That is: the ``INITIAL_PROMPT``
    listing ``tools``, the canned "好的" acknowledgement, and the user input. After those
    comes one ``Action``/``Action Input`` + tool-result exchange per previous step.
    The model's reply is parsed into an ``AgentAction`` (tool call) or an ``AgentFinish``
    (final answer).
    """

    # Tools advertised to the model in the initial prompt.
    tools: List[BaseTool]
    # Chat model queried once per planning step.
    llm: BaseLanguageModel

    def _generate_initial_prompt(self) -> str:
        """Render ``INITIAL_PROMPT`` with one name / description / JSON-schema entry per tool.

        Entries are separated by a line containing a single ``-``.
        """
        entries = [
            "名称: {tool_name}\n"
            "描述: {tool_description}\n"
            "参数: {{\"type\": \"object\", \"properties\": {tool_args_dict}}}\n".format(
                tool_name=tool.name,
                tool_description=tool.description,
                tool_args_dict=json.dumps(tool.args, ensure_ascii=False),
            )
            for tool in self.tools
        ]
        return INITIAL_PROMPT.format(tool_list="-\n".join(entries))

    @staticmethod
    def _format_tool_input(tool_input: Any) -> str:
        """Serialise a previous tool input for replay, as JSON unless it is already text.

        The prompt asks the model to emit JSON arguments. Replaying JSON, rather than a
        Python ``repr`` with single quotes, keeps the history consistent with that format.
        """
        if isinstance(tool_input, str):
            return tool_input
        return json.dumps(tool_input, ensure_ascii=False)

    def _convert_intermediate_steps_into_message(self, steps: List[Tuple[AgentAction, str]], user_input: str) \
            -> List[BaseMessage]:
        """Build the full chat history for the next model call.

        Args:
            steps: ``(action, observation)`` pairs executed so far, oldest first.
            user_input: the user's original request.

        Returns:
            The priming prompt, the acknowledgement, and the user request, followed by one
            AI/Human message pair per executed step.
        """
        messages: List[BaseMessage] = [
            HumanMessage(content=self._generate_initial_prompt()),
            AIMessage(content="好的"),
            HumanMessage(content=user_input),
        ]
        for action, tool_result in steps:
            messages += [
                AIMessage(content="Action: {}\nAction Input: {}".format(
                    action.tool, self._format_tool_input(action.tool_input))),
                # Observations go back as text; str() guards against non-string tool outputs.
                HumanMessage(content=str(tool_result)),
            ]
        return messages

    @staticmethod
    def _decode_tool_input(raw_input: str) -> Any:
        """Decode the ``Action Input`` payload produced by the model.

        The decoders are tried in order:

        1. Strict JSON. Any trailing text after the first JSON value is ignored.
        2. A Python literal, since models often echo dict reprs with single quotes.
        3. The legacy heuristic: swap ``'`` for ``"`` and parse as JSON.

        Raises:
            json.JSONDecodeError: (a ``ValueError``) if none of the decoders succeed.
        """
        import ast  # local import: only needed for the Python-literal fallback

        text = raw_input.strip()
        try:
            value, _ = json.JSONDecoder().raw_decode(text)
            return value
        except json.JSONDecodeError:
            pass
        try:
            # literal_eval only evaluates literals, so it is safe on model output.
            return ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError):
            pass
        # Last resort; breaks on values containing apostrophes, hence the attempts above.
        return json.loads(text.replace("'", "\""))

    def _parse_output(self, return_message: BaseMessage) -> Union[AgentAction, AgentFinish]:
        """Turn the model reply into a tool call or a final answer.

        A reply of the form ``Action: <tool>`` newline ``Action Input: <args>`` becomes an
        ``AgentAction``. Anything else is returned verbatim as an ``AgentFinish``, whose
        log is the message's ``additional_kwargs``.

        Raises:
            ValueError: if the reply has the action form but its arguments cannot be decoded.
        """
        content = return_message.content
        # The tool name must stay on the first line. DOTALL lets the argument span several
        # lines, e.g. pretty-printed JSON.
        match = re.match(r"^\s*Action:[ \t]*([^\n]*?)[ \t\r]*\nAction Input:[ \t]*(.*?)\s*$",
                         content, re.DOTALL)
        if match is None:
            return AgentFinish(return_values={"output": content},
                               log=str(return_message.additional_kwargs))

        tool_name, raw_input = match.group(1), match.group(2)
        print(f"\ntool used: {tool_name}\ntool input: {raw_input}")
        return AgentAction(tool=tool_name, tool_input=self._decode_tool_input(raw_input), log="")

    def plan(self, intermediate_steps: List[Tuple[AgentAction, str]], callbacks: Callbacks = None, **kwargs: Any) -> \
            Union[AgentAction, AgentFinish]:
        """Ask the model for the next step, given the tool calls made so far.

        ``kwargs`` carries the chain inputs; ``user_input`` (see ``input_keys``) is required.
        """
        messages = self._convert_intermediate_steps_into_message(intermediate_steps, kwargs["user_input"])
        # Forward the executor's callbacks so the model call is visible to tracing/logging.
        return self._parse_output(self.llm.predict_messages(messages, callbacks=callbacks))

    async def aplan(self, intermediate_steps: List[Tuple[AgentAction, str]], callbacks: Callbacks = None,
                    **kwargs: Any) -> Union[AgentAction, AgentFinish]:
        """Async planning is not supported by this agent."""
        raise NotImplementedError("not implementation")

    @property
    def input_keys(self) -> List[str]:
        """The single chain input this agent reads: the user's request."""
        return ["user_input"]

In [3]:
def json_dump(func):
    """Decorate ``func`` so its return value comes back as a JSON string.

    The value is wrapped as ``{"result": <value>}`` and serialised with
    ``ensure_ascii=False``, so non-ASCII text stays readable. The wrapper keeps
    ``func``'s metadata via ``functools.wraps``.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        payload = {"result": func(*args, **kwargs)}
        return json.dumps(payload, ensure_ascii=False)

    return wrapper

# Arguments of the `text_translate` tool; exposed to the model through `tool.args`.
class TextTranslateSchema(BaseModel):
    # Text to translate.
    text: str = Field(description="需要翻译的文本")
    # Target language as an ISO 639-1 code.
    target_language: str = Field(description="目标语言，格式为ISO 639-1语言代码")

# Translation tool; json_dump wraps the qianfan result as {"result": ...}.
TextTranslateTool = StructuredTool.from_function(
    name="text_translate",
    description="将指定的文本翻译到目标语言",
    args_schema=TextTranslateSchema,
    func=json_dump(qianfan.Tools.translate),
)

# Arguments of the `ocr` tool; exposed to the model through `tool.args`.
class OCRSchema(BaseModel):
    # URL of the image whose text should be recognised.
    image_url: str = Field(description="图像url")

# OCR tool: recognise text in the image at `image_url`; result wrapped as {"result": ...}.
OCRTool = StructuredTool.from_function(
    name="ocr",
    description="识别图像链接中的文字",
    args_schema=OCRSchema,
    func=json_dump(qianfan.Tools.ocr),
)

# Arguments of the `tts` tool; exposed to the model through `tool.args`.
class TTSSchema(BaseModel):
    # Text to synthesise.
    text: str = Field(description="文本")

# Text-to-speech tool; result wrapped as {"result": ...} by json_dump.
# NOTE(review): json_dump requires the tts result to be JSON-serialisable. Confirm that
# qianfan.Tools.tts does not return raw audio bytes.
TTSTool = StructuredTool.from_function(
    name="tts",
    description="将文本转换为语音",
    args_schema=TTSSchema,
    func=json_dump(qianfan.Tools.tts),
)

# Arguments of the `text_similarity` tool; exposed to the model through `tool.args`.
class TextSimilaritySchema(BaseModel):
    # First text to compare.
    text1: str = Field(description="文本1")
    # Second text to compare.
    text2: str = Field(description="文本2")

# Similarity tool comparing text1 and text2; result wrapped as {"result": ...}.
TextSimilarityTool = StructuredTool.from_function(
    name="text_similarity",
    description="比较两个文本的相似度",
    args_schema=TextSimilaritySchema,
    func=json_dump(qianfan.Tools.text_similarity),
)

# Arguments of the `text_correction` tool; exposed to the model through `tool.args`.
class TextCorrectionSchema(BaseModel):
    # Original text to be corrected.
    text: str = Field(description="原始文本")

# Correction tool: fixes characters/words, punctuation, grammar, proper names and
# addresses in the given text (per its description). Result wrapped as {"result": ...}.
TextCorrectionTool = StructuredTool.from_function(
    name="text_correction",
    description="为给定的原始文本进行纠错，支持字词、标点、语法、专名、地址纠错",
    args_schema=TextCorrectionSchema,
    func=json_dump(qianfan.Tools.text_correction),
)

# Arguments of the `text_to_image` tool; exposed to the model through `tool.args`.
class TextToImageSchema(BaseModel):
    # Textual description of the image to generate.
    text: str = Field(description="图片的文本描述")

# Image-generation tool. It returns the `b64_image` field of the first entry in the
# response's `data` list.
# NOTE(review): unlike the other tools this is not wrapped by json_dump, and the raw base64
# payload becomes the agent observation sent back to the LLM. It may be very large; consider
# persisting the image and returning a reference instead.
# The lambda parameter must stay named `text`: StructuredTool calls it with the schema's
# field names as keyword arguments.
TextToImageTool = StructuredTool.from_function(
    name="text_to_image",
    description="基于文本描述生成图片",
    args_schema=TextToImageSchema,
    func=lambda text: qianfan.Text2Image().do(text)["data"][0]["b64_image"],
)

# Arguments of the `address` tool; exposed to the model through `tool.args`.
class AddressSchema(BaseModel):
    # Text to extract address information from.
    text: str = Field(description="文本")

# Address-extraction tool; serialises qianfan's result directly as JSON.
# NOTE(review): unlike the json_dump-based tools, the output is not wrapped in
# {"result": ...}. Consider aligning it with the other tools.
# The lambda parameter must stay named `text`: StructuredTool calls it with the schema's
# field names as keyword arguments.
AddressTool = StructuredTool.from_function(
    name="address",
    description="识别文本中的地址信息，包括省、市、区县、街道、详细地址、姓名、电话等",
    args_schema=AddressSchema,
    func=lambda text: json.dumps(qianfan.Tools.extract_address(text), ensure_ascii=False),
)

In [9]:
full_tools = [TextToImageTool, TTSTool, TextSimilarityTool, OCRTool, TextCorrectionTool, AddressTool, TextTranslateTool]
# Name -> tool lookup, used to map retrieved documents back to tool objects.
tool_map = {tool.name: tool for tool in full_tools}
# One document per tool ("<name>\n<description>"), so tools can be retrieved by
# embedding similarity to the user's query.
tool_docs = [
    Document(page_content=tool.name + "\n" + tool.description, metadata={"name": tool.name})
    for tool in full_tools
]

# A single embeddings client serves as both the underlying embedder and the source of the
# cache namespace (its model name), so cached vectors are keyed per embedding model.
underlying_embeddings = QianfanEmbeddingsEndpoint()
cached_embedding = CacheBackedEmbeddings.from_bytes_store(
    underlying_embeddings, LocalFileStore("./cache/"), namespace=underlying_embeddings.model
)
vector_store = FAISS.from_documents(tool_docs, cached_embedding)
retriever = vector_store.as_retriever()


def run(query, top_k: int = 3, model: str = "ERNIE-Bot-4"):
    """Answer ``query`` with a ``CustomAgent`` that only sees the tools most relevant to it.

    Tools are ranked by ``retriever``, which indexes one "<name>\\n<description>" document
    per tool. Only the first ``top_k`` results are offered to the model, which keeps the
    tool list in the prompt short.

    Args:
        query: the user's request, passed to the agent as its ``user_input``.
        top_k: maximum number of retrieved tools to expose. The retriever itself bounds how
            many documents come back, so larger values have no further effect.
        model: Qianfan chat model name used to drive the agent.

    Returns:
        The agent's final answer.

    Raises:
        ValueError: if ``top_k`` is negative.
    """
    if top_k < 0:
        raise ValueError("top_k must be non-negative, got {}".format(top_k))

    documents = retriever.get_relevant_documents(query)[:top_k]
    relevant_tools = [tool_map[doc.metadata["name"]] for doc in documents]

    llm_qianfan = QianfanChatEndpoint(model=model, streaming=False)
    agent_qianfan = CustomAgent(tools=relevant_tools, llm=llm_qianfan)
    agent = AgentExecutor(agent=agent_qianfan, tools=relevant_tools, verbose=True)
    result = agent.run(query)

    print("==========")
    print("Query: {}".format(query))
    print("Tools: {}".format([tool.name for tool in relevant_tools]))
    print("Result: {}".format(result))
    print("==========")
    return result

In [10]:
# Demo: ask the agent to correct a typo ("区" should be "去"). As the cell's last expression,
# the returned answer is also displayed as the cell output. In the recorded run below, no
# "tool used:" line was printed, i.e. the model answered directly without calling a tool.
run("为下面的文本纠错：我想区吃饭")



> Entering new AgentExecutor chain...
{'id': 'as-pwahizr8tx', 'object': 'chat.completion', 'created': 1699344782, 'result': '文本纠错后的正确句子应该是：“我想去吃饭。”\n\n原句中的错别字是“区”，正确的字应该是“去”，表示想要做某件事情的意思。因此，将“区”改为“去”即可纠正该文本的错误。', 'is_truncated': False, 'need_clear_history': False, 'usage': {'prompt_tokens': 460, 'completion_tokens': 54, 'total_tokens': 514}}

> Finished chain.
Query: 为下面的文本纠错：我想区吃饭
Tools: ['text_correction', 'text_translate', 'address']
Result: 文本纠错后的正确句子应该是：“我想去吃饭。”

原句中的错别字是“区”，正确的字应该是“去”，表示想要做某件事情的意思。因此，将“区”改为“去”即可纠正该文本的错误。


'文本纠错后的正确句子应该是：“我想去吃饭。”\n\n原句中的错别字是“区”，正确的字应该是“去”，表示想要做某件事情的意思。因此，将“区”改为“去”即可纠正该文本的错误。'