Commit
fix organize agent's history messages without recalculating tokens (langgenius#4324)

Co-authored-by: chenyongzhao <chenyz@mama.cn>
2 people authored and dengpeng committed Jun 16, 2024
1 parent 268cd05 commit 61a19a4
Showing 7 changed files with 219 additions and 40 deletions.
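At a glance: each agent runner now passes its stored history, together with the messages of the current turn, through the new AgentHistoryPromptTransform and uses the trimmed result instead of the raw history, so the history is cut down to the model's remaining token budget rather than being passed through without any token accounting. A minimal sketch of the call pattern used in the diffs below (simplified; the runner attributes and the current_turn_messages list are assumed to already be prepared elsewhere):

from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform

# Trim the stored history so that history + current-turn messages fit the token budget.
# If no memory is attached, only the system prompts from the history are kept.
trimmed_history = AgentHistoryPromptTransform(
    model_config=self.model_config,                 # model and credentials, used for token counting
    prompt_messages=current_turn_messages,          # current turn: query, thoughts, tool outputs
    history_messages=self.history_prompt_messages,  # previously organized history
    memory=self.memory,                             # TokenBufferMemory
).get_prompt()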
5 changes: 4 additions & 1 deletion api/core/agent/base_agent_runner.py
@@ -128,6 +128,8 @@ def __init__(self, tenant_id: str,
self.files = application_generate_entity.files
else:
self.files = []
self.query = None
self._current_thoughts: list[PromptMessage] = []

def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
-> AgentChatAppGenerateEntity:
@@ -464,7 +466,7 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
for message in messages:
if message.id == self.message.id:
continue

result.append(self.organize_agent_user_prompt(message))
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
if agent_thoughts:
@@ -545,3 +547,4 @@ def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
return UserPromptMessage(content=prompt_message_contents)
else:
return UserPromptMessage(content=message.query)

10 changes: 9 additions & 1 deletion api/core/agent/cot_agent_runner.py
@@ -15,6 +15,7 @@
ToolPromptMessage,
UserPromptMessage,
)
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool.tool import Tool
from core.tools.tool_engine import ToolEngine
@@ -373,14 +374,21 @@ def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit])

return message

def _organize_historic_prompt_messages(self) -> list[PromptMessage]:
def _organize_historic_prompt_messages(self, current_session_messages: list[PromptMessage] = None) -> list[PromptMessage]:
"""
organize historic prompt messages
"""
result: list[PromptMessage] = []
scratchpad: list[AgentScratchpadUnit] = []
current_scratchpad: AgentScratchpadUnit = None

self.history_prompt_messages = AgentHistoryPromptTransform(
model_config=self.model_config,
prompt_messages=current_session_messages or [],
history_messages=self.history_prompt_messages,
memory=self.memory
).get_prompt()

for message in self.history_prompt_messages:
if isinstance(message, AssistantPromptMessage):
current_scratchpad = AgentScratchpadUnit(
12 changes: 9 additions & 3 deletions api/core/agent/cot_chat_agent_runner.py
@@ -32,9 +32,6 @@ def _organize_prompt_messages(self) -> list[PromptMessage]:
# organize system prompt
system_message = self._organize_system_prompt()

# organize historic prompt messages
historic_messages = self._historic_prompt_messages

# organize current assistant messages
agent_scratchpad = self._agent_scratchpad
if not agent_scratchpad:
@@ -57,6 +54,13 @@ def _organize_prompt_messages(self) -> list[PromptMessage]:
query_messages = UserPromptMessage(content=self._query)

if assistant_messages:
# organize historic prompt messages
historic_messages = self._organize_historic_prompt_messages([
system_message,
query_messages,
*assistant_messages,
UserPromptMessage(content='continue')
])
messages = [
system_message,
*historic_messages,
@@ -65,6 +69,8 @@ def _organize_prompt_messages(self) -> list[PromptMessage]:
UserPromptMessage(content='continue')
]
else:
# organize historic prompt messages
historic_messages = self._organize_historic_prompt_messages([system_message, query_messages])
messages = [system_message, *historic_messages, query_messages]

# join all messages
4 changes: 2 additions & 2 deletions api/core/agent/cot_completion_agent_runner.py
@@ -19,11 +19,11 @@ def _organize_instruction_prompt(self) -> str:

return system_prompt

def _organize_historic_prompt(self) -> str:
def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str:
"""
Organize historic prompt
"""
historic_prompt_messages = self._historic_prompt_messages
historic_prompt_messages = self._organize_historic_prompt_messages(current_session_messages)
historic_prompt = ""

for message in historic_prompt_messages:
69 changes: 36 additions & 33 deletions api/core/agent/fc_agent_runner.py
@@ -17,28 +17,26 @@
ToolPromptMessage,
UserPromptMessage,
)
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Message

logger = logging.getLogger(__name__)

class FunctionCallAgentRunner(BaseAgentRunner):

def run(self,
message: Message, query: str, **kwargs: Any
) -> Generator[LLMResultChunk, None, None]:
"""
Run FunctionCall agent application
"""
self.query = query
app_generate_entity = self.application_generate_entity

app_config = self.app_config

prompt_template = app_config.prompt_template.simple_prompt_template or ''
prompt_messages = self.history_prompt_messages
prompt_messages = self._init_system_message(prompt_template, prompt_messages)
prompt_messages = self._organize_user_query(query, prompt_messages)

# convert tools into ModelRuntime Tool format
tool_instances, prompt_messages_tools = self._init_prompt_tools()

@@ -81,6 +79,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
)

# recalc llm max tokens
prompt_messages = self._organize_prompt_messages()
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
@@ -203,7 +202,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
else:
assistant_message.content = response

prompt_messages.append(assistant_message)
self._current_thoughts.append(assistant_message)

# save thought
self.save_agent_thought(
@@ -265,12 +264,14 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
}

tool_responses.append(tool_response)
prompt_messages = self._organize_assistant_message(
tool_call_id=tool_call_id,
tool_call_name=tool_call_name,
tool_response=tool_response['tool_response'],
prompt_messages=prompt_messages,
)
if tool_response['tool_response'] is not None:
self._current_thoughts.append(
ToolPromptMessage(
content=tool_response['tool_response'],
tool_call_id=tool_call_id,
name=tool_call_name,
)
)

if len(tool_responses) > 0:
# save agent thought
@@ -300,8 +301,6 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):

iteration_step += 1

prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)

self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
@@ -393,24 +392,6 @@ def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = No

return prompt_messages

def _organize_assistant_message(self, tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
"""
Organize assistant message
"""
prompt_messages = deepcopy(prompt_messages)

if tool_response is not None:
prompt_messages.append(
ToolPromptMessage(
content=tool_response,
tool_call_id=tool_call_id,
name=tool_call_name,
)
)

return prompt_messages

def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
For now, GPT supports both fc and vision at the first iteration.
@@ -428,4 +409,26 @@ def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
for content in prompt_message.content
])

return prompt_messages
return prompt_messages

def _organize_prompt_messages(self):
prompt_template = self.app_config.prompt_template.simple_prompt_template or ''
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
query_prompt_messages = self._organize_user_query(self.query, [])

self.history_prompt_messages = AgentHistoryPromptTransform(
model_config=self.model_config,
prompt_messages=[*query_prompt_messages, *self._current_thoughts],
history_messages=self.history_prompt_messages,
memory=self.memory
).get_prompt()

prompt_messages = [
*self.history_prompt_messages,
*query_prompt_messages,
*self._current_thoughts
]
if len(self._current_thoughts) != 0:
# clear image contents from user messages after the first iteration
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
return prompt_messages
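The net effect of the fc_agent_runner.py changes above: the prompt is no longer mutated in place across iterations; assistant and tool messages are instead accumulated in self._current_thoughts, and _organize_prompt_messages rebuilds the full prompt on every iteration, re-applying the history transform each time. A compact sketch of the resulting assembly order (illustrative only; the trimming step is elided):

# Per-iteration prompt assembly, mirroring _organize_prompt_messages above:
# 1. (re)insert the system prompt into the stored history
# 2. trim the history against the tokens needed by the query and current thoughts
# 3. concatenate trimmed history + current query + current thoughts
prompt_messages = [
    *self.history_prompt_messages,  # trimmed history, system prompt first
    *query_prompt_messages,         # current user query (plus any file contents)
    *self._current_thoughts,        # assistant/tool messages produced in this run
]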
82 changes: 82 additions & 0 deletions api/core/prompt/agent_history_prompt_transform.py
@@ -0,0 +1,82 @@
from typing import Optional, cast

from core.app.entities.app_invoke_entities import (
ModelConfigWithCredentialsEntity,
)
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
PromptMessage,
SystemPromptMessage,
UserPromptMessage,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_transform import PromptTransform


class AgentHistoryPromptTransform(PromptTransform):
"""
History Prompt Transform for Agent App
"""
def __init__(self,
model_config: ModelConfigWithCredentialsEntity,
prompt_messages: list[PromptMessage],
history_messages: list[PromptMessage],
memory: Optional[TokenBufferMemory] = None,
):
self.model_config = model_config
self.prompt_messages = prompt_messages
self.history_messages = history_messages
self.memory = memory

def get_prompt(self) -> list[PromptMessage]:
prompt_messages = []
num_system = 0
for prompt_message in self.history_messages:
if isinstance(prompt_message, SystemPromptMessage):
prompt_messages.append(prompt_message)
num_system += 1

if not self.memory:
return prompt_messages

max_token_limit = self._calculate_rest_token(self.prompt_messages, self.model_config)

model_type_instance = self.model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)

curr_message_tokens = model_type_instance.get_num_tokens(
self.memory.model_instance.model,
self.memory.model_instance.credentials,
self.history_messages
)
if curr_message_tokens <= max_token_limit:
return self.history_messages

# number of prompts appended for the current message
num_prompt = 0
# append prompt messages in desc order
for prompt_message in self.history_messages[::-1]:
if isinstance(prompt_message, SystemPromptMessage):
continue
prompt_messages.append(prompt_message)
num_prompt += 1
# each message (turn) starts with a UserPromptMessage
if isinstance(prompt_message, UserPromptMessage):
curr_message_tokens = model_type_instance.get_num_tokens(
self.memory.model_instance.model,
self.memory.model_instance.credentials,
prompt_messages
)
# if the token count overflows the limit, drop all prompts of the current message and break
if curr_message_tokens > max_token_limit:
prompt_messages = prompt_messages[:-num_prompt]
break
num_prompt = 0
# return prompt messages in asc order
message_prompts = prompt_messages[num_system:]
message_prompts.reverse()

# merge system and message prompt
prompt_messages = prompt_messages[:num_system]
prompt_messages.extend(message_prompts)
return prompt_messages
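In short, get_prompt keeps every SystemPromptMessage and, if the whole history already fits the budget, returns it unchanged; otherwise it walks the non-system history from newest to oldest and drops whole turns (a UserPromptMessage plus the assistant and tool messages that follow it) once the running token count would exceed the budget. A small worked example, assuming, as in the unit test below, that get_num_tokens simply counts messages:

from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    SystemPromptMessage,
    UserPromptMessage,
)

history = [
    SystemPromptMessage(content='sys'),
    UserPromptMessage(content='turn 1'),         # oldest turn
    AssistantPromptMessage(content='answer 1'),
    UserPromptMessage(content='turn 2'),         # newest turn
    AssistantPromptMessage(content='answer 2'),
]
# With a budget of 3 "tokens" (i.e. 3 messages), the oldest turn no longer fits,
# so that whole turn is dropped and the result is:
#   [SystemPromptMessage('sys'), UserPromptMessage('turn 2'), AssistantPromptMessage('answer 2')]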
77 changes: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
from unittest.mock import MagicMock

from core.app.entities.app_invoke_entities import (
ModelConfigWithCredentialsEntity,
)
from core.entities.provider_configuration import ProviderModelBundle
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from models.model import Conversation


def test_get_prompt():
prompt_messages = [
SystemPromptMessage(content='System Template'),
UserPromptMessage(content='User Query'),
]
history_messages = [
SystemPromptMessage(content='System Prompt 1'),
UserPromptMessage(content='User Prompt 1'),
AssistantPromptMessage(content='Assistant Thought 1'),
ToolPromptMessage(content='Tool 1-1', name='Tool 1-1', tool_call_id='1'),
ToolPromptMessage(content='Tool 1-2', name='Tool 1-2', tool_call_id='2'),
SystemPromptMessage(content='System Prompt 2'),
UserPromptMessage(content='User Prompt 2'),
AssistantPromptMessage(content='Assistant Thought 2'),
ToolPromptMessage(content='Tool 2-1', name='Tool 2-1', tool_call_id='3'),
ToolPromptMessage(content='Tool 2-2', name='Tool 2-2', tool_call_id='4'),
UserPromptMessage(content='User Prompt 3'),
AssistantPromptMessage(content='Assistant Thought 3'),
]

# use the message count instead of tokens for testing
def side_effect_get_num_tokens(*args):
return len(args[2])
large_language_model_mock = MagicMock(spec=LargeLanguageModel)
large_language_model_mock.get_num_tokens = MagicMock(side_effect=side_effect_get_num_tokens)

provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle)
provider_model_bundle_mock.model_type_instance = large_language_model_mock

model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
model_config_mock.model = 'openai'
model_config_mock.credentials = {}
model_config_mock.provider_model_bundle = provider_model_bundle_mock

memory = TokenBufferMemory(
conversation=Conversation(),
model_instance=model_config_mock
)

transform = AgentHistoryPromptTransform(
model_config=model_config_mock,
prompt_messages=prompt_messages,
history_messages=history_messages,
memory=memory
)

max_token_limit = 5
transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
result = transform.get_prompt()

assert len(result) <= max_token_limit
assert len(result) == 4

max_token_limit = 20
transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
result = transform.get_prompt()

assert len(result) <= max_token_limit
assert len(result) == 12
