fix organize agent's history messages without recalculating tokens #4324

Merged
merged 3 commits on May 29, 2024
5 changes: 4 additions & 1 deletion api/core/agent/base_agent_runner.py
@@ -128,6 +128,8 @@ def __init__(self, tenant_id: str,
self.files = application_generate_entity.files
else:
self.files = []
self.query = None
self._current_thoughts: list[PromptMessage] = []

def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
-> AgentChatAppGenerateEntity:
@@ -463,7 +465,7 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P
for message in messages:
if message.id == self.message.id:
continue

result.append(self.organize_agent_user_prompt(message))
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
if agent_thoughts:
@@ -544,3 +546,4 @@ def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
return UserPromptMessage(content=prompt_message_contents)
else:
return UserPromptMessage(content=message.query)

10 changes: 9 additions & 1 deletion api/core/agent/cot_agent_runner.py
@@ -15,6 +15,7 @@
ToolPromptMessage,
UserPromptMessage,
)
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool.tool import Tool
from core.tools.tool_engine import ToolEngine
@@ -373,14 +374,21 @@ def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit])

return message

def _organize_historic_prompt_messages(self) -> list[PromptMessage]:
def _organize_historic_prompt_messages(self, current_session_messages: list[PromptMessage] = None) -> list[PromptMessage]:
"""
organize historic prompt messages
"""
result: list[PromptMessage] = []
scratchpad: list[AgentScratchpadUnit] = []
current_scratchpad: AgentScratchpadUnit = None

self.history_prompt_messages = AgentHistoryPromptTransform(
model_config=self.model_config,
prompt_messages=current_session_messages or [],
history_messages=self.history_prompt_messages,
memory=self.memory
).get_prompt()

for message in self.history_prompt_messages:
if isinstance(message, AssistantPromptMessage):
current_scratchpad = AgentScratchpadUnit(
12 changes: 9 additions & 3 deletions api/core/agent/cot_chat_agent_runner.py
@@ -32,9 +32,6 @@ def _organize_prompt_messages(self) -> list[PromptMessage]:
# organize system prompt
system_message = self._organize_system_prompt()

# organize historic prompt messages
historic_messages = self._historic_prompt_messages

# organize current assistant messages
agent_scratchpad = self._agent_scratchpad
if not agent_scratchpad:
@@ -57,6 +54,13 @@ def _organize_prompt_messages(self) -> list[PromptMessage]:
query_messages = UserPromptMessage(content=self._query)

if assistant_messages:
# organize historic prompt messages
historic_messages = self._organize_historic_prompt_messages([
system_message,
query_messages,
*assistant_messages,
UserPromptMessage(content='continue')
])
messages = [
system_message,
*historic_messages,
Expand All @@ -65,6 +69,8 @@ def _organize_prompt_messages(self) -> list[PromptMessage]:
UserPromptMessage(content='continue')
]
else:
# organize historic prompt messages
historic_messages = self._organize_historic_prompt_messages([system_message, query_messages])
messages = [system_message, *historic_messages, query_messages]

# join all messages
4 changes: 2 additions & 2 deletions api/core/agent/cot_completion_agent_runner.py
@@ -19,11 +19,11 @@ def _organize_instruction_prompt(self) -> str:

return system_prompt

def _organize_historic_prompt(self) -> str:
def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str:
"""
Organize historic prompt
"""
historic_prompt_messages = self._historic_prompt_messages
historic_prompt_messages = self._organize_historic_prompt_messages(current_session_messages)
historic_prompt = ""

for message in historic_prompt_messages:
69 changes: 36 additions & 33 deletions api/core/agent/fc_agent_runner.py
@@ -17,28 +17,26 @@
ToolPromptMessage,
UserPromptMessage,
)
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Message

logger = logging.getLogger(__name__)

class FunctionCallAgentRunner(BaseAgentRunner):

def run(self,
message: Message, query: str, **kwargs: Any
) -> Generator[LLMResultChunk, None, None]:
"""
Run FunctionCall agent application
"""
self.query = query
app_generate_entity = self.application_generate_entity

app_config = self.app_config

prompt_template = app_config.prompt_template.simple_prompt_template or ''
prompt_messages = self.history_prompt_messages
prompt_messages = self._init_system_message(prompt_template, prompt_messages)
prompt_messages = self._organize_user_query(query, prompt_messages)

# convert tools into ModelRuntime Tool format
tool_instances, prompt_messages_tools = self._init_prompt_tools()

@@ -81,6 +79,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
)

# recalc llm max tokens
prompt_messages = self._organize_prompt_messages()
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
@@ -203,7 +202,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
else:
assistant_message.content = response

prompt_messages.append(assistant_message)
self._current_thoughts.append(assistant_message)

# save thought
self.save_agent_thought(
@@ -265,12 +264,14 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
}

tool_responses.append(tool_response)
prompt_messages = self._organize_assistant_message(
tool_call_id=tool_call_id,
tool_call_name=tool_call_name,
tool_response=tool_response['tool_response'],
prompt_messages=prompt_messages,
)
if tool_response['tool_response'] is not None:
self._current_thoughts.append(
ToolPromptMessage(
content=tool_response['tool_response'],
tool_call_id=tool_call_id,
name=tool_call_name,
)
)

if len(tool_responses) > 0:
# save agent thought
@@ -300,8 +301,6 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):

iteration_step += 1

prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)

self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
@@ -393,24 +392,6 @@ def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = No

return prompt_messages

def _organize_assistant_message(self, tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
"""
Organize assistant message
"""
prompt_messages = deepcopy(prompt_messages)

if tool_response is not None:
prompt_messages.append(
ToolPromptMessage(
content=tool_response,
tool_call_id=tool_call_id,
name=tool_call_name,
)
)

return prompt_messages

def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
As for now, gpt supports both fc and vision at the first iteration.
@@ -428,4 +409,26 @@ def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]
for content in prompt_message.content
])

return prompt_messages
return prompt_messages

def _organize_prompt_messages(self):
prompt_template = self.app_config.prompt_template.simple_prompt_template or ''
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
query_prompt_messages = self._organize_user_query(self.query, [])

self.history_prompt_messages = AgentHistoryPromptTransform(
model_config=self.model_config,
prompt_messages=[*query_prompt_messages, *self._current_thoughts],
history_messages=self.history_prompt_messages,
memory=self.memory
).get_prompt()

prompt_messages = [
*self.history_prompt_messages,
*query_prompt_messages,
*self._current_thoughts
]
if len(self._current_thoughts) != 0:
# after the first iteration, strip image content from user prompt messages
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
return prompt_messages
82 changes: 82 additions & 0 deletions api/core/prompt/agent_history_prompt_transform.py
@@ -0,0 +1,82 @@
from typing import Optional, cast

from core.app.entities.app_invoke_entities import (
ModelConfigWithCredentialsEntity,
)
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
PromptMessage,
SystemPromptMessage,
UserPromptMessage,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_transform import PromptTransform


class AgentHistoryPromptTransform(PromptTransform):
"""
History Prompt Transform for Agent App
"""
def __init__(self,
model_config: ModelConfigWithCredentialsEntity,
prompt_messages: list[PromptMessage],
history_messages: list[PromptMessage],
memory: Optional[TokenBufferMemory] = None,
):
self.model_config = model_config
self.prompt_messages = prompt_messages
self.history_messages = history_messages
self.memory = memory

def get_prompt(self) -> list[PromptMessage]:
prompt_messages = []
num_system = 0
for prompt_message in self.history_messages:
if isinstance(prompt_message, SystemPromptMessage):
prompt_messages.append(prompt_message)
num_system += 1

if not self.memory:
return prompt_messages

max_token_limit = self._calculate_rest_token(self.prompt_messages, self.model_config)

model_type_instance = self.model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)

curr_message_tokens = model_type_instance.get_num_tokens(
self.memory.model_instance.model,
self.memory.model_instance.credentials,
self.history_messages
)
if curr_message_tokens <= max_token_limit:
return self.history_messages

# number of prompts appended for the current conversation turn
num_prompt = 0
# walk the history backwards (newest messages first)
for prompt_message in self.history_messages[::-1]:
if isinstance(prompt_message, SystemPromptMessage):
continue
prompt_messages.append(prompt_message)
num_prompt += 1
# each conversation turn starts with a UserPromptMessage
if isinstance(prompt_message, UserPromptMessage):
curr_message_tokens = model_type_instance.get_num_tokens(
self.memory.model_instance.model,
self.memory.model_instance.credentials,
prompt_messages
)
# if the token count now exceeds the limit, drop this turn's prompts and stop
if curr_message_tokens > max_token_limit:
prompt_messages = prompt_messages[:-num_prompt]
break
num_prompt = 0
# restore chronological (ascending) order for the kept history turns
message_prompts = prompt_messages[num_system:]
message_prompts.reverse()

# merge the system prompts with the kept history turns
prompt_messages = prompt_messages[:num_system]
prompt_messages.extend(message_prompts)
return prompt_messages
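
For reference, a minimal usage sketch (not taken from the diff) of the no-memory fallback path in get_prompt, assuming the PromptMessage classes can be constructed with only `content`, as the unit test below does. Without a TokenBufferMemory the transform cannot count tokens, so it keeps only the system prompts from the history.

from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform

history = [
    SystemPromptMessage(content='system prompt'),
    UserPromptMessage(content='earlier user turn'),
    AssistantPromptMessage(content='earlier assistant turn'),
]
result = AgentHistoryPromptTransform(
    model_config=None,  # type-hinted as ModelConfigWithCredentialsEntity, but unused when memory is None
    prompt_messages=[UserPromptMessage(content='current query')],
    history_messages=history,
    memory=None,  # no memory: token counting is skipped and only system prompts survive
).get_prompt()
assert len(result) == 1 and isinstance(result[0], SystemPromptMessage)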
77 additions: new unit test for AgentHistoryPromptTransform
@@ -0,0 +1,77 @@
from unittest.mock import MagicMock

from core.app.entities.app_invoke_entities import (
ModelConfigWithCredentialsEntity,
)
from core.entities.provider_configuration import ProviderModelBundle
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from models.model import Conversation


def test_get_prompt():
prompt_messages = [
SystemPromptMessage(content='System Template'),
UserPromptMessage(content='User Query'),
]
history_messages = [
SystemPromptMessage(content='System Prompt 1'),
UserPromptMessage(content='User Prompt 1'),
AssistantPromptMessage(content='Assistant Thought 1'),
ToolPromptMessage(content='Tool 1-1', name='Tool 1-1', tool_call_id='1'),
ToolPromptMessage(content='Tool 1-2', name='Tool 1-2', tool_call_id='2'),
SystemPromptMessage(content='System Prompt 2'),
UserPromptMessage(content='User Prompt 2'),
AssistantPromptMessage(content='Assistant Thought 2'),
ToolPromptMessage(content='Tool 2-1', name='Tool 2-1', tool_call_id='3'),
ToolPromptMessage(content='Tool 2-2', name='Tool 2-2', tool_call_id='4'),
UserPromptMessage(content='User Prompt 3'),
AssistantPromptMessage(content='Assistant Thought 3'),
]

# use message count as a stand-in for token count in this test
def side_effect_get_num_tokens(*args):
return len(args[2])
large_language_model_mock = MagicMock(spec=LargeLanguageModel)
large_language_model_mock.get_num_tokens = MagicMock(side_effect=side_effect_get_num_tokens)

provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle)
provider_model_bundle_mock.model_type_instance = large_language_model_mock

model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
model_config_mock.model = 'openai'
model_config_mock.credentials = {}
model_config_mock.provider_model_bundle = provider_model_bundle_mock

memory = TokenBufferMemory(
conversation=Conversation(),
model_instance=model_config_mock
)

transform = AgentHistoryPromptTransform(
model_config=model_config_mock,
prompt_messages=prompt_messages,
history_messages=history_messages,
memory=memory
)

max_token_limit = 5
transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
result = transform.get_prompt()

assert len(result) <= max_token_limit
assert len(result) == 4

max_token_limit = 20
transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
result = transform.get_prompt()

assert len(result) <= max_token_limit
assert len(result) == 12
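
For reference, a short trace of why the first assertion yields exactly four messages; this is inferred from the get_prompt truncation logic above, not asserted by the test itself.

from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    SystemPromptMessage,
    UserPromptMessage,
)

# max_token_limit = 5, with get_num_tokens mocked to return the number of messages:
# - the two SystemPromptMessages are always kept                       -> 2
# - the newest turn (User Prompt 3 + Assistant Thought 3) still fits   -> 2 + 2 = 4 <= 5
# - the previous turn (User Prompt 2, its thought and two tool calls)
#   would overflow (2 + 2 + 4 = 8 > 5), so that whole turn is dropped
expected_with_limit_5 = [
    SystemPromptMessage(content='System Prompt 1'),
    SystemPromptMessage(content='System Prompt 2'),
    UserPromptMessage(content='User Prompt 3'),
    AssistantPromptMessage(content='Assistant Thought 3'),
]
# With max_token_limit = 20 the whole 12-message history already fits,
# so get_prompt returns it unchanged.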