From 90a4fd0102ff4ab8a72b653195cb3ef335584720 Mon Sep 17 00:00:00 2001
From: chenyongzhao
Date: Sun, 12 May 2024 21:28:37 +0800
Subject: [PATCH 1/3] fix: organize agent's history messages without
 recalculating tokens

---
 api/core/agent/base_agent_runner.py           | 112 ++----------
 .../prompt/agent_history_prompt_transform.py  | 165 ++++++++++++++++++
 2 files changed, 175 insertions(+), 102 deletions(-)
 create mode 100644 api/core/prompt/agent_history_prompt_transform.py

diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py
index 485633cab1b3d..8fb15f50a879f 100644
--- a/api/core/agent/base_agent_runner.py
+++ b/api/core/agent/base_agent_runner.py
@@ -42,6 +42,7 @@
 from extensions.ext_database import db
 from models.model import Conversation, Message, MessageAgentThought
 from models.tools import ToolConversationVariables
+from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
 
 logger = logging.getLogger(__name__)
 
@@ -84,9 +85,14 @@ def __init__(self, tenant_id: str,
         self.message = message
         self.user_id = user_id
         self.memory = memory
-        self.history_prompt_messages = self.organize_agent_history(
-            prompt_messages=prompt_messages or []
-        )
+        self.history_prompt_messages = AgentHistoryPromptTransform(
+            tenant_id=tenant_id,
+            app_config=app_config,
+            model_config=model_config,
+            message=message,
+            prompt_messages=prompt_messages or [],
+            memory=memory
+        ).get_prompt()
         self.variables_pool = variables_pool
         self.db_variables_pool = db_variables
         self.model_instance = model_instance
@@ -445,102 +451,4 @@ def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variab
         db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
         db.session.commit()
         db.session.close()
-
-    def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
-        """
-        Organize agent history
-        """
-        result = []
-        # check if there is a system message in the beginning of the conversation
-        for prompt_message in prompt_messages:
-            if isinstance(prompt_message, SystemPromptMessage):
-                result.append(prompt_message)
-
-        messages: list[Message] = db.session.query(Message).filter(
-            Message.conversation_id == self.message.conversation_id,
-        ).order_by(Message.created_at.asc()).all()
-
-        for message in messages:
-            if message.id == self.message.id:
-                continue
-
-            result.append(self.organize_agent_user_prompt(message))
-            agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
-            if agent_thoughts:
-                for agent_thought in agent_thoughts:
-                    tools = agent_thought.tool
-                    if tools:
-                        tools = tools.split(';')
-                        tool_calls: list[AssistantPromptMessage.ToolCall] = []
-                        tool_call_response: list[ToolPromptMessage] = []
-                        try:
-                            tool_inputs = json.loads(agent_thought.tool_input)
-                        except Exception as e:
-                            tool_inputs = { tool: {} for tool in tools }
-                        try:
-                            tool_responses = json.loads(agent_thought.observation)
-                        except Exception as e:
-                            tool_responses = { tool: agent_thought.observation for tool in tools }
-
-                        for tool in tools:
-                            # generate a uuid for tool call
-                            tool_call_id = str(uuid.uuid4())
-                            tool_calls.append(AssistantPromptMessage.ToolCall(
-                                id=tool_call_id,
-                                type='function',
-                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
-                                    name=tool,
-                                    arguments=json.dumps(tool_inputs.get(tool, {})),
-                                )
-                            ))
-                            tool_call_response.append(ToolPromptMessage(
-                                content=tool_responses.get(tool, agent_thought.observation),
-                                name=tool,
-                                tool_call_id=tool_call_id,
-                            ))
-
-                        result.extend([
-                            AssistantPromptMessage(
-                                content=agent_thought.thought,
-                                tool_calls=tool_calls,
-                            ),
-                            *tool_call_response
-                        ])
-                    if not tools:
-                        result.append(AssistantPromptMessage(content=agent_thought.thought))
-            else:
-                if message.answer:
-                    result.append(AssistantPromptMessage(content=message.answer))
-
-        db.session.close()
-
-        return result
-
-    def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
-        message_file_parser = MessageFileParser(
-            tenant_id=self.tenant_id,
-            app_id=self.app_config.app_id,
-        )
-
-        files = message.message_files
-        if files:
-            file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
-
-            if file_extra_config:
-                file_objs = message_file_parser.transform_message_files(
-                    files,
-                    file_extra_config
-                )
-            else:
-                file_objs = []
-
-            if not file_objs:
-                return UserPromptMessage(content=message.query)
-            else:
-                prompt_message_contents = [TextPromptMessageContent(data=message.query)]
-                for file_obj in file_objs:
-                    prompt_message_contents.append(file_obj.prompt_message_content)
-
-                return UserPromptMessage(content=prompt_message_contents)
-        else:
-            return UserPromptMessage(content=message.query)
+    
\ No newline at end of file
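For context on the change above: the deleted organize_agent_history accumulated every prior message with no token accounting, so a long conversation could push the prompt past the model's context window. A minimal, self-contained sketch of the difference in strategy (count_tokens and MAX_CONTEXT_TOKENS are illustrative stand-ins, not Dify APIs):

    # Hypothetical stand-ins for illustration only.
    MAX_CONTEXT_TOKENS = 4096

    def count_tokens(messages: list[str]) -> int:
        # crude proxy for the model runtime's get_num_tokens()
        return sum(len(m.split()) for m in messages)

    def naive_history(messages: list[str]) -> list[str]:
        # what the removed method effectively did: keep everything
        return messages

    def budgeted_history(messages: list[str], budget: int = MAX_CONTEXT_TOKENS) -> list[str]:
        # what the new transform does conceptually: keep the newest
        # messages that still fit under the remaining token budget
        kept: list[str] = []
        for message in reversed(messages):
            if count_tokens(kept + [message]) > budget:
                break
            kept.append(message)
        return list(reversed(kept))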
diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py
new file mode 100644
index 0000000000000..3a4a37f825188
--- /dev/null
+++ b/api/core/prompt/agent_history_prompt_transform.py
@@ -0,0 +1,165 @@
+import json
+import uuid
+from typing import Optional
+
+from core.app.entities.app_invoke_entities import (
+    AgentChatAppGenerateEntity,
+    ModelConfigWithCredentialsEntity,
+)
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    PromptMessageTool,
+    SystemPromptMessage,
+    TextPromptMessageContent,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
+from extensions.ext_database import db
+from core.model_runtime.model_providers import model_provider_factory
+from core.model_runtime.entities.model_entities import ModelType
+from core.prompt.prompt_transform import PromptTransform
+from core.file.message_file_parser import MessageFileParser
+from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
+from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
+from core.memory.token_buffer_memory import TokenBufferMemory
+from models.model import Conversation, Message, MessageAgentThought
+
+
+class AgentHistoryPromptTransform(PromptTransform):
+    """
+    History Prompt Transform for Agent App
+    """
+    def __init__(self,
+                 tenant_id: str,
+                 app_config: AgentChatAppConfig,
+                 model_config: ModelConfigWithCredentialsEntity,
+                 message: Message,
+                 prompt_messages: Optional[list[PromptMessage]] = None,
+                 memory: Optional[TokenBufferMemory] = None,
+                 ):
+        self.tenant_id = tenant_id
+        self.app_config = app_config
+        self.model_config = model_config
+        self.message = message
+        self.prompt_messages = prompt_messages or []
+        self.memory = memory
+
+    def get_prompt(self) -> list[PromptMessage]:
+        prompt_messages = []
+        # check if there is a system message in the beginning of the conversation
+        for prompt_message in self.prompt_messages:
+            if isinstance(prompt_message, SystemPromptMessage):
+                prompt_messages.append(prompt_message)
+
+        if not self.memory:
+            return prompt_messages
+
+        max_token_limit = self._calculate_rest_token(self.prompt_messages, self.model_config)
+
+        provider_instance = model_provider_factory.get_provider_instance(self.memory.model_instance.provider)
+        model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
+
+        messages: list[Message] = db.session.query(Message).filter(
+            Message.conversation_id == self.memory.conversation.id,
+        ).order_by(Message.created_at.desc()).all()
+
+        for message in messages:
+            if message.id == self.message.id:
+                continue
+
+            prompt_messages.append(self._organize_agent_user_prompt(message))
+            # number of appended prompts
+            num_prompt = 1
+            agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
+            if agent_thoughts:
+                for agent_thought in agent_thoughts:
+                    tools = agent_thought.tool
+                    if tools:
+                        tools = tools.split(';')
+                        tool_calls: list[AssistantPromptMessage.ToolCall] = []
+                        tool_call_response: list[ToolPromptMessage] = []
+                        try:
+                            tool_inputs = json.loads(agent_thought.tool_input)
+                        except Exception as e:
+                            tool_inputs = {tool: {} for tool in tools}
+                        try:
+                            tool_responses = json.loads(agent_thought.observation)
+                        except Exception as e:
+                            tool_responses = {tool: agent_thought.observation for tool in tools}
+
+                        for tool in tools:
+                            # generate a uuid for tool call
+                            tool_call_id = str(uuid.uuid4())
+                            tool_calls.append(AssistantPromptMessage.ToolCall(
+                                id=tool_call_id,
+                                type='function',
+                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                    name=tool,
+                                    arguments=json.dumps(tool_inputs.get(tool, {})),
+                                )
+                            ))
+                            tool_call_response.append(ToolPromptMessage(
+                                content=tool_responses.get(tool, agent_thought.observation),
+                                name=tool,
+                                tool_call_id=tool_call_id,
+                            ))
+
+                        prompt_messages.extend([
+                            AssistantPromptMessage(
+                                content=agent_thought.thought,
+                                tool_calls=tool_calls,
+                            ),
+                            *tool_call_response
+                        ])
+                        num_prompt += 1 + len(tool_call_response)
+                    if not tools:
+                        prompt_messages.append(AssistantPromptMessage(content=agent_thought.thought))
+                        num_prompt += 1
+            else:
+                if message.answer:
+                    prompt_messages.append(AssistantPromptMessage(content=message.answer))
+                    num_prompt += 1
+
+            curr_message_tokens = model_type_instance.get_num_tokens(
+                self.memory.model_instance.model,
+                self.memory.model_instance.credentials,
+                prompt_messages
+            )
+            # If tokens is overflow, drop all appended prompts in current message and break
+            if curr_message_tokens > max_token_limit:
+                prompt_messages = prompt_messages[:-num_prompt]
+                break
+
+        db.session.close()
+
+        return prompt_messages
+
+    def _organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
+        message_file_parser = MessageFileParser(
+            tenant_id=self.tenant_id,
+            app_id=self.app_config.app_id,
+        )
+
+        files = message.message_files
+        if files:
+            file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
+
+            if file_extra_config:
+                file_objs = message_file_parser.transform_message_files(
+                    files,
+                    file_extra_config
+                )
+            else:
+                file_objs = []
+
+            if not file_objs:
+                return UserPromptMessage(content=message.query)
+            else:
+                prompt_message_contents = [TextPromptMessageContent(data=message.query)]
+                for file_obj in file_objs:
+                    prompt_message_contents.append(file_obj.prompt_message_content)
+
+                return UserPromptMessage(content=prompt_message_contents)
+        else:
+            return UserPromptMessage(content=message.query)
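The new transform budgets history against whatever tokens remain once the current prompt is accounted for (the _calculate_rest_token call above, inherited from PromptTransform). A hedged sketch of that arithmetic, with assumed example numbers rather than real model limits:

    def calculate_rest_tokens(model_context_size: int,
                              max_new_tokens: int,
                              current_prompt_tokens: int) -> int:
        # tokens left for history after reserving room for the current
        # prompt and for the model's answer
        return max(model_context_size - max_new_tokens - current_prompt_tokens, 0)

    # e.g. a 16,384-token context minus 1,024 reserved completion tokens and
    # a 2,000-token system prompt + query leaves 13,360 tokens for history
    assert calculate_rest_tokens(16384, 1024, 2000) == 13360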
From 1c1cdaef7fd3f377511fa1277b71400816bef306 Mon Sep 17 00:00:00 2001
From: chenyongzhao
Date: Mon, 20 May 2024 23:51:04 +0800
Subject: [PATCH 2/3] fix: recalculate tokens before every iteration

---
 api/core/agent/base_agent_runner.py           | 113 +++++++++++-
 api/core/agent/cot_agent_runner.py            |  10 +-
 api/core/agent/cot_chat_agent_runner.py       |  12 +-
 api/core/agent/cot_completion_agent_runner.py |   4 +-
 api/core/agent/fc_agent_runner.py             |  72 ++++----
 .../prompt/agent_history_prompt_transform.py  | 173 +++++-------------
 .../test_agent_history_prompt_transform.py    |  77 ++++++++
 7 files changed, 284 insertions(+), 177 deletions(-)
 create mode 100644 api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py

diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py
index 8fb15f50a879f..6d6f326dc5833 100644
--- a/api/core/agent/base_agent_runner.py
+++ b/api/core/agent/base_agent_runner.py
@@ -42,7 +42,6 @@
 from extensions.ext_database import db
 from models.model import Conversation, Message, MessageAgentThought
 from models.tools import ToolConversationVariables
-from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
 
 logger = logging.getLogger(__name__)
 
@@ -85,14 +84,9 @@ def __init__(self, tenant_id: str,
         self.message = message
         self.user_id = user_id
         self.memory = memory
-        self.history_prompt_messages = AgentHistoryPromptTransform(
-            tenant_id=tenant_id,
-            app_config=app_config,
-            model_config=model_config,
-            message=message,
-            prompt_messages=prompt_messages or [],
-            memory=memory
-        ).get_prompt()
+        self.history_prompt_messages = self.organize_agent_history(
+            prompt_messages=prompt_messages or []
+        )
         self.variables_pool = variables_pool
         self.db_variables_pool = db_variables
         self.model_instance = model_instance
@@ -451,4 +445,103 @@ def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variab
         db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
         db.session.commit()
         db.session.close()
-    
\ No newline at end of file
+
+    def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
+        """
+        Organize agent history
+        """
+        result = []
+        # check if there is a system message in the beginning of the conversation
+        for prompt_message in prompt_messages:
+            if isinstance(prompt_message, SystemPromptMessage):
+                result.append(prompt_message)
+
+        messages: list[Message] = db.session.query(Message).filter(
+            Message.conversation_id == self.message.conversation_id,
+        ).order_by(Message.created_at.asc()).all()
+
+        for message in messages:
+            if message.id == self.message.id:
+                continue
+
+            result.append(self.organize_agent_user_prompt(message))
+            agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
+            if agent_thoughts:
+                for agent_thought in agent_thoughts:
+                    tools = agent_thought.tool
+                    if tools:
+                        tools = tools.split(';')
+                        tool_calls: list[AssistantPromptMessage.ToolCall] = []
+                        tool_call_response: list[ToolPromptMessage] = []
+                        try:
+                            tool_inputs = json.loads(agent_thought.tool_input)
+                        except Exception as e:
+                            tool_inputs = { tool: {} for tool in tools }
+                        try:
+                            tool_responses = json.loads(agent_thought.observation)
+                        except Exception as e:
+                            tool_responses = { tool: agent_thought.observation for tool in tools }
+
+                        for tool in tools:
+                            # generate a uuid for tool call
+                            tool_call_id = str(uuid.uuid4())
+                            tool_calls.append(AssistantPromptMessage.ToolCall(
+                                id=tool_call_id,
+                                type='function',
+                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                    name=tool,
+                                    arguments=json.dumps(tool_inputs.get(tool, {})),
+                                )
+                            ))
+                            tool_call_response.append(ToolPromptMessage(
+                                content=tool_responses.get(tool, agent_thought.observation),
+                                name=tool,
+                                tool_call_id=tool_call_id,
+                            ))
+
+                        result.extend([
+                            AssistantPromptMessage(
+                                content=agent_thought.thought,
+                                tool_calls=tool_calls,
+                            ),
+                            *tool_call_response
+                        ])
+                    if not tools:
+                        result.append(AssistantPromptMessage(content=agent_thought.thought))
+            else:
+                if message.answer:
+                    result.append(AssistantPromptMessage(content=message.answer))
+
+        db.session.close()
+
+        return result
+
+    def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
+        message_file_parser = MessageFileParser(
+            tenant_id=self.tenant_id,
+            app_id=self.app_config.app_id,
+        )
+
+        files = message.message_files
+        if files:
+            file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
+
+            if file_extra_config:
+                file_objs = message_file_parser.transform_message_files(
+                    files,
+                    file_extra_config
+                )
+            else:
+                file_objs = []
+
+            if not file_objs:
+                return UserPromptMessage(content=message.query)
+            else:
+                prompt_message_contents = [TextPromptMessageContent(data=message.query)]
+                for file_obj in file_objs:
+                    prompt_message_contents.append(file_obj.prompt_message_content)
+
+                return UserPromptMessage(content=prompt_message_contents)
+        else:
+            return UserPromptMessage(content=message.query)
+    
\ No newline at end of file
diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py
index 12554f42b31ae..8d25fba01dae9 100644
--- a/api/core/agent/cot_agent_runner.py
+++ b/api/core/agent/cot_agent_runner.py
@@ -15,6 +15,7 @@
     ToolPromptMessage,
     UserPromptMessage,
 )
+from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool.tool import Tool
 from core.tools.tool_engine import ToolEngine
@@ -373,7 +374,7 @@ def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit])
 
         return message
 
-    def _organize_historic_prompt_messages(self) -> list[PromptMessage]:
+    def _organize_historic_prompt_messages(self, prompt_messages: list[PromptMessage] = []) -> list[PromptMessage]:
         """
         organize historic prompt messages
         """
@@ -381,6 +382,13 @@ def _organize_historic_prompt_messages(self) -> list[PromptMessage]:
         scratchpad: list[AgentScratchpadUnit] = []
         current_scratchpad: AgentScratchpadUnit = None
 
+        self.history_prompt_messages = AgentHistoryPromptTransform(
+            model_config=self.model_config,
+            prompt_messages=prompt_messages,
+            history_messages=self.history_prompt_messages,
+            memory=self.memory
+        ).get_prompt()
+
         for message in self.history_prompt_messages:
             if isinstance(message, AssistantPromptMessage):
                 current_scratchpad = AgentScratchpadUnit(
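A side note on the signature introduced above: `prompt_messages: list[PromptMessage] = []` uses a mutable default, which Python evaluates once and shares across all calls. The default is only read here, so nothing breaks yet, but the idiom is fragile, and the third patch in this series replaces it with a None default. A standalone illustration of the pitfall:

    def buggy(items: list = []) -> list:
        items.append('x')          # mutates the one shared default list
        return items

    def fixed(items: list | None = None) -> list:
        items = items or []        # fresh list on every call
        items.append('x')
        return items

    assert buggy() == ['x']
    assert buggy() == ['x', 'x']   # surprise: state leaks between calls
    assert fixed() == ['x']
    assert fixed() == ['x']        # independent calls stay independent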
diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py
index a904f3e64175c..e8b05373ab1d9 100644
--- a/api/core/agent/cot_chat_agent_runner.py
+++ b/api/core/agent/cot_chat_agent_runner.py
@@ -32,9 +32,6 @@ def _organize_prompt_messages(self) -> list[PromptMessage]:
         # organize system prompt
         system_message = self._organize_system_prompt()
 
-        # organize historic prompt messages
-        historic_messages = self._historic_prompt_messages
-
         # organize current assistant messages
         agent_scratchpad = self._agent_scratchpad
         if not agent_scratchpad:
@@ -57,6 +54,13 @@ def _organize_prompt_messages(self) -> list[PromptMessage]:
         query_messages = UserPromptMessage(content=self._query)
 
         if assistant_messages:
+            # organize historic prompt messages
+            historic_messages = self._organize_historic_prompt_messages([
+                system_message,
+                query_messages,
+                *assistant_messages,
+                UserPromptMessage(content='continue')
+            ])
             messages = [
                 system_message,
                 *historic_messages,
                 query_messages,
                 *assistant_messages,
                 UserPromptMessage(content='continue')
             ]
         else:
+            # organize historic prompt messages
+            historic_messages = self._organize_historic_prompt_messages([system_message, query_messages])
             messages = [system_message, *historic_messages, query_messages]
 
         # join all messages
diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py
index 3f0298d5a3639..a43d60cd83b12 100644
--- a/api/core/agent/cot_completion_agent_runner.py
+++ b/api/core/agent/cot_completion_agent_runner.py
@@ -19,11 +19,11 @@ def _organize_instruction_prompt(self) -> str:
 
         return system_prompt
 
-    def _organize_historic_prompt(self) -> str:
+    def _organize_historic_prompt(self, prompt_message: list[PromptMessage] = []) -> str:
         """
         Organize historic prompt
         """
-        historic_prompt_messages = self._historic_prompt_messages
+        historic_prompt_messages = self._organize_historic_prompt_messages(prompt_message)
         historic_prompt = ""
 
         for message in historic_prompt_messages:
diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py
index a9b3a80073446..1750bd60b8d84 100644
--- a/api/core/agent/fc_agent_runner.py
+++ b/api/core/agent/fc_agent_runner.py
@@ -17,6 +17,7 @@
     ToolPromptMessage,
     UserPromptMessage,
 )
+from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool_engine import ToolEngine
 from models.model import Message
@@ -24,21 +25,21 @@
 logger = logging.getLogger(__name__)
 
 class FunctionCallAgentRunner(BaseAgentRunner):
+    _query: str = None
+    _current_thoughts: list[PromptMessage] = []
+
+
     def run(self,
             message: Message, query: str, **kwargs: Any
     ) -> Generator[LLMResultChunk, None, None]:
         """
         Run FunctionCall agent application
         """
+        self._query = query
         app_generate_entity = self.application_generate_entity
 
         app_config = self.app_config
 
-        prompt_template = app_config.prompt_template.simple_prompt_template or ''
-        prompt_messages = self.history_prompt_messages
-        prompt_messages = self._init_system_message(prompt_template, prompt_messages)
-        prompt_messages = self._organize_user_query(query, prompt_messages)
-
         # convert tools into ModelRuntime Tool format
         tool_instances, prompt_messages_tools = self._init_prompt_tools()
 
@@ -81,6 +82,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
             )
 
             # recalc llm max tokens
+            prompt_messages = self._organize_prompt_messages()
             self.recalc_llm_max_tokens(self.model_config, prompt_messages)
             # invoke model
             chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
@@ -203,7 +205,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
                 else:
                     assistant_message.content = response
 
-            prompt_messages.append(assistant_message)
+            self._current_thoughts.append(assistant_message)
 
             # save thought
             self.save_agent_thought(
@@ -265,12 +267,14 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
                 }
                 tool_responses.append(tool_response)
 
-                prompt_messages = self._organize_assistant_message(
-                    tool_call_id=tool_call_id,
-                    tool_call_name=tool_call_name,
-                    tool_response=tool_response['tool_response'],
-                    prompt_messages=prompt_messages,
-                )
+                if tool_response['tool_response'] is not None:
+                    self._current_thoughts.append(
+                        ToolPromptMessage(
+                            content=tool_response['tool_response'],
+                            tool_call_id=tool_call_id,
+                            name=tool_call_name,
+                        )
+                    )
 
             if len(tool_responses) > 0:
                 # save agent thought
@@ -300,8 +304,6 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
 
             iteration_step += 1
 
-            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
-
         self.update_db_variables(self.variables_pool, self.db_variables_pool)
         # publish end event
         self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
@@ -393,24 +395,6 @@ def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = No
 
         return prompt_messages
 
-    def _organize_assistant_message(self, tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
-                                    prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
-        """
-        Organize assistant message
-        """
-        prompt_messages = deepcopy(prompt_messages)
-
-        if tool_response is not None:
-            prompt_messages.append(
-                ToolPromptMessage(
-                    content=tool_response,
-                    tool_call_id=tool_call_id,
-                    name=tool_call_name,
-                )
-            )
-
-        return prompt_messages
-
     def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
         """
         As for now, gpt supports both fc and vision at the first iteration.
@@ -428,4 +412,26 @@ def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]
                     for content in prompt_message.content
                 ])
 
-        return prompt_messages
\ No newline at end of file
+        return prompt_messages
+
+    def _organize_prompt_messages(self):
+        prompt_template = self.app_config.prompt_template.simple_prompt_template or ''
+        self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
+        query_prompt_messages = self._organize_user_query(self._query, [])
+
+        self.history_prompt_messages = AgentHistoryPromptTransform(
+            model_config=self.model_config,
+            prompt_messages=[*query_prompt_messages, *self._current_thoughts],
+            history_messages=self.history_prompt_messages,
+            memory=self.memory
+        ).get_prompt()
+
+        prompt_messages = [
+            *self.history_prompt_messages,
+            *query_prompt_messages,
+            *self._current_thoughts
+        ]
+        if len(self._current_thoughts) != 0:
+            # clear messages after the first iteration
+            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
+        return prompt_messages
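The net effect of the fc_agent_runner changes above: instead of appending to one ever-growing prompt_messages list, the runner rebuilds the prompt on every iteration from the (re-trimmed) history, the current query, and the thoughts accumulated so far, so the token budget is re-checked before each model call. A simplified standalone sketch of that assembly order (plain strings stand in for PromptMessage objects; names are illustrative):

    def organize_prompt_messages(history: list[str],
                                 query: str,
                                 current_thoughts: list[str]) -> list[str]:
        # history is re-trimmed against the token budget each iteration,
        # so a growing current_thoughts squeezes old history out first
        return [*history, f'user: {query}', *current_thoughts]

    history = ['system: be helpful', 'user: hi', 'assistant: hello']
    thoughts = ['assistant: call weather_tool', 'tool: sunny, 22C']
    print(organize_prompt_messages(history, 'what is the weather?', thoughts))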
diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py
index 3a4a37f825188..af0075ea9154f 100644
--- a/api/core/prompt/agent_history_prompt_transform.py
+++ b/api/core/prompt/agent_history_prompt_transform.py
@@ -1,29 +1,16 @@
-import json
-import uuid
-from typing import Optional
+from typing import Optional, cast
 
 from core.app.entities.app_invoke_entities import (
-    AgentChatAppGenerateEntity,
     ModelConfigWithCredentialsEntity,
 )
+from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import (
-    AssistantPromptMessage,
     PromptMessage,
-    PromptMessageTool,
     SystemPromptMessage,
-    TextPromptMessageContent,
-    ToolPromptMessage,
     UserPromptMessage,
 )
-from extensions.ext_database import db
-from core.model_runtime.model_providers import model_provider_factory
-from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.prompt.prompt_transform import PromptTransform
-from core.file.message_file_parser import MessageFileParser
-from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
-from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
-from core.memory.token_buffer_memory import TokenBufferMemory
-from models.model import Conversation, Message, MessageAgentThought
 
 
 class AgentHistoryPromptTransform(PromptTransform):
@@ -31,135 +18,65 @@ class AgentHistoryPromptTransform(PromptTransform):
     History Prompt Transform for Agent App
     """
     def __init__(self,
-                 tenant_id: str,
-                 app_config: AgentChatAppConfig,
                  model_config: ModelConfigWithCredentialsEntity,
-                 message: Message,
-                 prompt_messages: Optional[list[PromptMessage]] = None,
+                 prompt_messages: list[PromptMessage],
+                 history_messages: list[PromptMessage],
                  memory: Optional[TokenBufferMemory] = None,
                  ):
-        self.tenant_id = tenant_id
-        self.app_config = app_config
         self.model_config = model_config
-        self.message = message
-        self.prompt_messages = prompt_messages or []
+        self.prompt_messages = prompt_messages
+        self.history_messages = history_messages
         self.memory = memory
 
     def get_prompt(self) -> list[PromptMessage]:
         prompt_messages = []
-        # check if there is a system message in the beginning of the conversation
-        for prompt_message in self.prompt_messages:
+        num_system = 0
+        for prompt_message in self.history_messages:
             if isinstance(prompt_message, SystemPromptMessage):
                 prompt_messages.append(prompt_message)
+                num_system += 1
 
         if not self.memory:
             return prompt_messages
 
         max_token_limit = self._calculate_rest_token(self.prompt_messages, self.model_config)
 
-        provider_instance = model_provider_factory.get_provider_instance(self.memory.model_instance.provider)
-        model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
-
-        messages: list[Message] = db.session.query(Message).filter(
-            Message.conversation_id == self.memory.conversation.id,
-        ).order_by(Message.created_at.desc()).all()
-
-        for message in messages:
-            if message.id == self.message.id:
-                continue
-
-            prompt_messages.append(self._organize_agent_user_prompt(message))
-            # number of appended prompts
-            num_prompt = 1
-            agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
-            if agent_thoughts:
-                for agent_thought in agent_thoughts:
-                    tools = agent_thought.tool
-                    if tools:
-                        tools = tools.split(';')
-                        tool_calls: list[AssistantPromptMessage.ToolCall] = []
-                        tool_call_response: list[ToolPromptMessage] = []
-                        try:
-                            tool_inputs = json.loads(agent_thought.tool_input)
-                        except Exception as e:
-                            tool_inputs = {tool: {} for tool in tools}
-                        try:
-                            tool_responses = json.loads(agent_thought.observation)
-                        except Exception as e:
-                            tool_responses = {tool: agent_thought.observation for tool in tools}
-
-                        for tool in tools:
-                            # generate a uuid for tool call
-                            tool_call_id = str(uuid.uuid4())
-                            tool_calls.append(AssistantPromptMessage.ToolCall(
-                                id=tool_call_id,
-                                type='function',
-                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
-                                    name=tool,
-                                    arguments=json.dumps(tool_inputs.get(tool, {})),
-                                )
-                            ))
-                            tool_call_response.append(ToolPromptMessage(
-                                content=tool_responses.get(tool, agent_thought.observation),
-                                name=tool,
-                                tool_call_id=tool_call_id,
-                            ))
-
-                        prompt_messages.extend([
-                            AssistantPromptMessage(
-                                content=agent_thought.thought,
-                                tool_calls=tool_calls,
-                            ),
-                            *tool_call_response
-                        ])
-                        num_prompt += 1 + len(tool_call_response)
-                    if not tools:
-                        prompt_messages.append(AssistantPromptMessage(content=agent_thought.thought))
-                        num_prompt += 1
-            else:
-                if message.answer:
-                    prompt_messages.append(AssistantPromptMessage(content=message.answer))
-                    num_prompt += 1
-
-            curr_message_tokens = model_type_instance.get_num_tokens(
-                self.memory.model_instance.model,
-                self.memory.model_instance.credentials,
-                prompt_messages
-            )
-            # If tokens is overflow, drop all appended prompts in current message and break
-            if curr_message_tokens > max_token_limit:
-                prompt_messages = prompt_messages[:-num_prompt]
-                break
+        model_type_instance = self.model_config.provider_model_bundle.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
 
-        db.session.close()
-
-        return prompt_messages
-
-    def _organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
-        message_file_parser = MessageFileParser(
-            tenant_id=self.tenant_id,
-            app_id=self.app_config.app_id,
+        curr_message_tokens = model_type_instance.get_num_tokens(
+            self.memory.model_instance.model,
+            self.memory.model_instance.credentials,
+            self.history_messages
         )
+        if curr_message_tokens <= max_token_limit:
+            return self.history_messages
 
-        files = message.message_files
-        if files:
-            file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
-
-            if file_extra_config:
-                file_objs = message_file_parser.transform_message_files(
-                    files,
-                    file_extra_config
+        # number of prompt has been appended in current message
+        num_prompt = 0
+        # append prompt messages in desc order
+        for prompt_message in self.history_messages[::-1]:
+            if isinstance(prompt_message, SystemPromptMessage):
+                continue
+            prompt_messages.append(prompt_message)
+            num_prompt += 1
+            # a message is start with UserPromptMessage
+            if isinstance(prompt_message, UserPromptMessage):
+                curr_message_tokens = model_type_instance.get_num_tokens(
+                    self.memory.model_instance.model,
+                    self.memory.model_instance.credentials,
+                    prompt_messages
                 )
-            else:
-                file_objs = []
-
-            if not file_objs:
-                return UserPromptMessage(content=message.query)
-            else:
-                prompt_message_contents = [TextPromptMessageContent(data=message.query)]
-                for file_obj in file_objs:
-                    prompt_message_contents.append(file_obj.prompt_message_content)
-
-                return UserPromptMessage(content=prompt_message_contents)
-        else:
-            return UserPromptMessage(content=message.query)
+                # if current message token is overflow, drop all the prompts in current message and break
+                if curr_message_tokens > max_token_limit:
+                    prompt_messages = prompt_messages[:-num_prompt]
+                    break
+                num_prompt = 0
+        # return prompt messages in asc order
+        message_prompts = prompt_messages[num_system:]
+        message_prompts.reverse()
+
+        # merge system and message prompt
+        prompt_messages = prompt_messages[:num_system]
+        prompt_messages.extend(message_prompts)
+        return prompt_messages
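To summarize the rewritten get_prompt above: it walks the flat history newest-first, buffers messages until it reaches a UserPromptMessage (the start of a turn), measures the running token count at that boundary, drops the first turn that overflows, then restores chronological order and re-attaches the system prompts. A standalone sketch under the same simplification the unit test below uses (token cost = message count):

    def trim_history(history: list[tuple[str, str]], max_tokens: int) -> list[tuple[str, str]]:
        system = [m for m in history if m[0] == 'system']
        kept: list[tuple[str, str]] = []
        pending = 0
        for role, text in reversed(history):
            if role == 'system':
                continue
            kept.append((role, text))
            pending += 1
            if role == 'user':                   # a whole turn is now buffered
                if len(system) + len(kept) > max_tokens:
                    kept = kept[:-pending]       # drop the overflowing turn
                    break
                pending = 0
        kept.reverse()
        return system + kept

    turns = [('system', 's'), ('user', 'u1'), ('assistant', 'a1'),
             ('user', 'u2'), ('assistant', 'a2')]
    assert trim_history(turns, 3) == [('system', 's'), ('user', 'u2'), ('assistant', 'a2')]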
diff --git a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py
new file mode 100644
index 0000000000000..9de268d762474
--- /dev/null
+++ b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py
@@ -0,0 +1,77 @@
+from unittest.mock import MagicMock
+
+from core.app.entities.app_invoke_entities import (
+    ModelConfigWithCredentialsEntity,
+)
+from core.entities.provider_configuration import ProviderModelBundle
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    SystemPromptMessage,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
+from models.model import Conversation
+
+
+def test_get_prompt():
+    prompt_messages = [
+        SystemPromptMessage(content='System Template'),
+        UserPromptMessage(content='User Query'),
+    ]
+    history_messages = [
+        SystemPromptMessage(content='System Prompt 1'),
+        UserPromptMessage(content='User Prompt 1'),
+        AssistantPromptMessage(content='Assistant Thought 1'),
+        ToolPromptMessage(content='Tool 1-1', name='Tool 1-1', tool_call_id='1'),
+        ToolPromptMessage(content='Tool 1-2', name='Tool 1-2', tool_call_id='2'),
+        SystemPromptMessage(content='System Prompt 2'),
+        UserPromptMessage(content='User Prompt 2'),
+        AssistantPromptMessage(content='Assistant Thought 2'),
+        ToolPromptMessage(content='Tool 2-1', name='Tool 2-1', tool_call_id='3'),
+        ToolPromptMessage(content='Tool 2-2', name='Tool 2-2', tool_call_id='4'),
+        UserPromptMessage(content='User Prompt 3'),
+        AssistantPromptMessage(content='Assistant Thought 3'),
+    ]
+
+    # use message number instead of token for testing
+    def side_effect_get_num_tokens(*args):
+        return len(args[2])
+    large_language_model_mock = MagicMock(spec=LargeLanguageModel)
+    large_language_model_mock.get_num_tokens = MagicMock(side_effect=side_effect_get_num_tokens)
+
+    provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle)
+    provider_model_bundle_mock.model_type_instance = large_language_model_mock
+
+    model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
+    model_config_mock.model = 'openai'
+    model_config_mock.credentials = {}
+    model_config_mock.provider_model_bundle = provider_model_bundle_mock
+
+    memory = TokenBufferMemory(
+        conversation=Conversation(),
+        model_instance=model_config_mock
+    )
+
+    transform = AgentHistoryPromptTransform(
+        model_config=model_config_mock,
+        prompt_messages=prompt_messages,
+        history_messages=history_messages,
+        memory=memory
+    )
+
+    max_token_limit = 5
+    transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
+    result = transform.get_prompt()
+
+    assert len(result) <= max_token_limit
+    assert len(result) == 4
+
+    max_token_limit = 20
+    transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
+    result = transform.get_prompt()
+
+    assert len(result) <= max_token_limit
+    assert len(result) == 12
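A worked check of the two assertions in the test above, under its convention that a prompt list's token count is simply its length (the mocked get_num_tokens):

    # Token cost == len(messages) per the mock above.
    total_history = 12      # 2 system prompts + 10 conversation messages
    newest_turn_kept = 4    # [sys1, sys2, 'User Prompt 3', 'Assistant Thought 3']
    with_prev_turn = 8      # plus 'User Prompt 2', its thought, and 2 tool messages

    assert with_prev_turn > 5 >= newest_turn_kept   # limit 5  -> trimmed to 4
    assert total_history <= 20                      # limit 20 -> all 12 kept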
From 0870ab6b35aa9bb01d07c5da98d8416f3e3d89c1 Mon Sep 17 00:00:00 2001
From: chenyongzhao
Date: Sat, 25 May 2024 19:20:28 +0800
Subject: [PATCH 3/3] fix: optimize code

---
 api/core/agent/base_agent_runner.py           | 2 ++
 api/core/agent/cot_agent_runner.py            | 4 ++--
 api/core/agent/cot_completion_agent_runner.py | 4 ++--
 api/core/agent/fc_agent_runner.py             | 7 ++-----
 4 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py
index 6d6f326dc5833..8baeeccd5ea1d 100644
--- a/api/core/agent/base_agent_runner.py
+++ b/api/core/agent/base_agent_runner.py
@@ -128,6 +128,8 @@ def __init__(self, tenant_id: str,
             self.files = application_generate_entity.files
         else:
             self.files = []
+        self.query = None
+        self._current_thoughts: list[PromptMessage] = []
 
     def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
             -> AgentChatAppGenerateEntity:
diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py
index 8d25fba01dae9..3d3209b5509d2 100644
--- a/api/core/agent/cot_agent_runner.py
+++ b/api/core/agent/cot_agent_runner.py
@@ -374,7 +374,7 @@ def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit])
 
         return message
 
-    def _organize_historic_prompt_messages(self, prompt_messages: list[PromptMessage] = []) -> list[PromptMessage]:
+    def _organize_historic_prompt_messages(self, current_session_messages: list[PromptMessage] = None) -> list[PromptMessage]:
         """
         organize historic prompt messages
         """
@@ -384,7 +384,7 @@ def _organize_historic_prompt_messages(self, prompt_messages: list[PromptMessage
 
         self.history_prompt_messages = AgentHistoryPromptTransform(
             model_config=self.model_config,
-            prompt_messages=prompt_messages,
+            prompt_messages=current_session_messages or [],
             history_messages=self.history_prompt_messages,
             memory=self.memory
         ).get_prompt()
diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py
index a43d60cd83b12..9e6eb54f4fe51 100644
--- a/api/core/agent/cot_completion_agent_runner.py
+++ b/api/core/agent/cot_completion_agent_runner.py
@@ -19,11 +19,11 @@ def _organize_instruction_prompt(self) -> str:
 
         return system_prompt
 
-    def _organize_historic_prompt(self, prompt_message: list[PromptMessage] = []) -> str:
+    def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str:
         """
         Organize historic prompt
         """
-        historic_prompt_messages = self._organize_historic_prompt_messages(prompt_message)
+        historic_prompt_messages = self._organize_historic_prompt_messages(current_session_messages)
         historic_prompt = ""
 
         for message in historic_prompt_messages:
diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py
index 1750bd60b8d84..d416a319a4022 100644
--- a/api/core/agent/fc_agent_runner.py
+++ b/api/core/agent/fc_agent_runner.py
@@ -25,9 +25,6 @@
 logger = logging.getLogger(__name__)
 
 class FunctionCallAgentRunner(BaseAgentRunner):
-    _query: str = None
-    _current_thoughts: list[PromptMessage] = []
-
     def run(self,
             message: Message, query: str, **kwargs: Any
     ) -> Generator[LLMResultChunk, None, None]:
@@ -35,7 +32,7 @@ def run(self,
         """
         Run FunctionCall agent application
         """
-        self._query = query
+        self.query = query
         app_generate_entity = self.application_generate_entity
 
         app_config = self.app_config
@@ -417,7 +414,7 @@ def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]
     def _organize_prompt_messages(self):
         prompt_template = self.app_config.prompt_template.simple_prompt_template or ''
         self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
-        query_prompt_messages = self._organize_user_query(self._query, [])
+        query_prompt_messages = self._organize_user_query(self.query, [])
 
         self.history_prompt_messages = AgentHistoryPromptTransform(
             model_config=self.model_config,