forked from langgenius/dify
Commit
fix organize agent's history messages without recalculating tokens (langgenius#4324)
Co-authored-by: chenyongzhao <chenyz@mama.cn>
Showing 7 changed files with 219 additions and 40 deletions.
api/core/prompt/agent_history_prompt_transform.py
@@ -0,0 +1,82 @@
from typing import Optional, cast

from core.app.entities.app_invoke_entities import (
    ModelConfigWithCredentialsEntity,
)
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
    PromptMessage,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_transform import PromptTransform


class AgentHistoryPromptTransform(PromptTransform):
    """
    History Prompt Transform for Agent App
    """
    def __init__(self,
                 model_config: ModelConfigWithCredentialsEntity,
                 prompt_messages: list[PromptMessage],
                 history_messages: list[PromptMessage],
                 memory: Optional[TokenBufferMemory] = None,
                 ):
        self.model_config = model_config
        self.prompt_messages = prompt_messages
        self.history_messages = history_messages
        self.memory = memory

    def get_prompt(self) -> list[PromptMessage]:
        prompt_messages = []
        num_system = 0
        for prompt_message in self.history_messages:
            if isinstance(prompt_message, SystemPromptMessage):
                prompt_messages.append(prompt_message)
                num_system += 1

        if not self.memory:
            return prompt_messages

        max_token_limit = self._calculate_rest_token(self.prompt_messages, self.model_config)

        model_type_instance = self.model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        curr_message_tokens = model_type_instance.get_num_tokens(
            self.memory.model_instance.model,
            self.memory.model_instance.credentials,
            self.history_messages
        )
        if curr_message_tokens <= max_token_limit:
            return self.history_messages

        # number of prompts appended so far for the current message group
        num_prompt = 0
        # append history messages in descending (newest-first) order
        for prompt_message in self.history_messages[::-1]:
            if isinstance(prompt_message, SystemPromptMessage):
                continue
            prompt_messages.append(prompt_message)
            num_prompt += 1
            # each logical message group starts with a UserPromptMessage,
            # so in reverse order the user message closes the group
            if isinstance(prompt_message, UserPromptMessage):
                curr_message_tokens = model_type_instance.get_num_tokens(
                    self.memory.model_instance.model,
                    self.memory.model_instance.credentials,
                    prompt_messages
                )
                # if the token budget is exceeded, drop the whole current group and stop
                if curr_message_tokens > max_token_limit:
                    prompt_messages = prompt_messages[:-num_prompt]
                    break
                num_prompt = 0
        # restore ascending (chronological) order for the non-system messages
        message_prompts = prompt_messages[num_system:]
        message_prompts.reverse()

        # merge system prompts with the trimmed message prompts
        prompt_messages = prompt_messages[:num_system]
        prompt_messages.extend(message_prompts)
        return prompt_messages
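To make the trimming logic easier to follow outside the codebase, here is a minimal self-contained sketch of the same idea: keep system messages, then walk the history newest-first and keep whole user-message groups until the token budget is exhausted. Everything in it (trim_history, the string roles, the word-count stand-in for get_num_tokens) is illustrative and not part of the Dify API:

# Minimal illustrative sketch, not Dify code: messages are plain
# (role, content) tuples and token counting is approximated by word count.
def trim_history(history: list[tuple[str, str]], max_tokens: int) -> list[tuple[str, str]]:
    def count_tokens(messages: list[tuple[str, str]]) -> int:
        # stand-in for model_type_instance.get_num_tokens(...)
        return sum(len(content.split()) for _, content in messages)

    # system messages are always kept, in their original order
    kept = [m for m in history if m[0] == 'system']
    num_system = len(kept)

    # everything fits: return the history unchanged
    if count_tokens(history) <= max_tokens:
        return history

    group_size = 0
    for message in reversed(history):
        if message[0] == 'system':
            continue
        kept.append(message)
        group_size += 1
        # in forward order a group starts with a user message, so in
        # reverse order the user message closes the group
        if message[0] == 'user':
            if count_tokens(kept) > max_tokens:
                kept = kept[:-group_size]  # drop the overflowing group whole
                break
            group_size = 0

    # restore chronological order for the non-system tail
    tail = kept[num_system:]
    tail.reverse()
    return kept[:num_system] + tail


history = [
    ('system', 'be helpful'),
    ('user', 'question one'),
    ('assistant', 'answer one with several extra words'),
    ('user', 'question two'),
    ('assistant', 'answer two'),
]
print(trim_history(history, max_tokens=8))
# [('system', 'be helpful'), ('user', 'question two'), ('assistant', 'answer two')]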
api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py (77 additions, 0 deletions)
@@ -0,0 +1,77 @@
from unittest.mock import MagicMock

from core.app.entities.app_invoke_entities import (
    ModelConfigWithCredentialsEntity,
)
from core.entities.provider_configuration import ProviderModelBundle
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from models.model import Conversation


def test_get_prompt():
    prompt_messages = [
        SystemPromptMessage(content='System Template'),
        UserPromptMessage(content='User Query'),
    ]
    history_messages = [
        SystemPromptMessage(content='System Prompt 1'),
        UserPromptMessage(content='User Prompt 1'),
        AssistantPromptMessage(content='Assistant Thought 1'),
        ToolPromptMessage(content='Tool 1-1', name='Tool 1-1', tool_call_id='1'),
        ToolPromptMessage(content='Tool 1-2', name='Tool 1-2', tool_call_id='2'),
        SystemPromptMessage(content='System Prompt 2'),
        UserPromptMessage(content='User Prompt 2'),
        AssistantPromptMessage(content='Assistant Thought 2'),
        ToolPromptMessage(content='Tool 2-1', name='Tool 2-1', tool_call_id='3'),
        ToolPromptMessage(content='Tool 2-2', name='Tool 2-2', tool_call_id='4'),
        UserPromptMessage(content='User Prompt 3'),
        AssistantPromptMessage(content='Assistant Thought 3'),
    ]

    # use message count instead of token count for testing
    def side_effect_get_num_tokens(*args):
        return len(args[2])
    large_language_model_mock = MagicMock(spec=LargeLanguageModel)
    large_language_model_mock.get_num_tokens = MagicMock(side_effect=side_effect_get_num_tokens)

    provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle)
    provider_model_bundle_mock.model_type_instance = large_language_model_mock

    model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
    model_config_mock.model = 'openai'
    model_config_mock.credentials = {}
    model_config_mock.provider_model_bundle = provider_model_bundle_mock

    memory = TokenBufferMemory(
        conversation=Conversation(),
        model_instance=model_config_mock
    )

    transform = AgentHistoryPromptTransform(
        model_config=model_config_mock,
        prompt_messages=prompt_messages,
        history_messages=history_messages,
        memory=memory
    )

    max_token_limit = 5
    transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
    result = transform.get_prompt()

    assert len(result) <= max_token_limit
    assert len(result) == 4

    max_token_limit = 20
    transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
    result = transform.get_prompt()

    assert len(result) <= max_token_limit
    assert len(result) == 12
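A note on the design, as far as it can be read from the code: truncation happens only at UserPromptMessage boundaries, so a user turn is always kept or dropped together with its assistant thought and tool results. This avoids leaving a ToolPromptMessage in the history whose originating assistant tool call was trimmed away, which chat completion APIs generally reject. Under the test's message-count tokenizer, a budget of 5 keeps the two system prompts plus the last user/assistant pair (4 messages), while a budget of 20 fits all 12 history messages, so they are returned untouched.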