diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py
index 06c746990d2a3b..6ec3e55a40ebc2 100644
--- a/api/core/agent/base_agent_runner.py
+++ b/api/core/agent/base_agent_runner.py
@@ -35,6 +35,7 @@
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.prompt.utils.extract_thread_messages import extract_thread_messages
+from core.prompt.utils.image_detail_config import image_detail_config_for_prompt_file
 from core.tools.__base.tool import Tool
 from core.tools.entities.tool_entities import (
     ToolParameter,
@@ -120,7 +121,8 @@ def __init__(
         model_schema = llm_model.get_model_schema(model_instance.model_name, model_instance.credentials)
         features = model_schema.features if model_schema and model_schema.features else []
         self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
-        self.files = application_generate_entity.files if ModelFeature.VISION in features else []
+        supports_file_context = ModelFeature.VISION in features or ModelFeature.DOCUMENT in features
+        self.files = application_generate_entity.files if supports_file_context else []
         self.query: str | None = ""
         self._current_thoughts: list[PromptMessage] = []
 
@@ -542,7 +544,7 @@ def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
                 prompt_message_contents.append(
                     file_manager.to_prompt_message_content(
                         file,
-                        image_detail_config=image_detail_config,
+                        image_detail_config=image_detail_config_for_prompt_file(file, image_detail_config),
                     )
                 )
             prompt_message_contents.append(TextPromptMessageContent(data=message.query))
diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py
index 2b2e26987e88f7..dcd2ae71eced85 100644
--- a/api/core/agent/cot_chat_agent_runner.py
+++ b/api/core/agent/cot_chat_agent_runner.py
@@ -12,6 +12,7 @@
 from graphon.model_runtime.utils.encoders import jsonable_encoder
 
 from core.agent.cot_agent_runner import CotAgentRunner
+from core.prompt.utils.image_detail_config import image_detail_config_for_prompt_file
 
 
 class CotChatAgentRunner(CotAgentRunner):
@@ -56,7 +57,7 @@ def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> l
             prompt_message_contents.append(
                 file_manager.to_prompt_message_content(
                     file,
-                    image_detail_config=image_detail_config,
+                    image_detail_config=image_detail_config_for_prompt_file(file, image_detail_config),
                 )
             )
         prompt_message_contents.append(TextPromptMessageContent(data=query))
diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py
index fdffde85d01a84..2bad14c0353722 100644
--- a/api/core/agent/fc_agent_runner.py
+++ b/api/core/agent/fc_agent_runner.py
@@ -21,6 +21,7 @@
 from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
 
 from core.agent.base_agent_runner import BaseAgentRunner
+from core.prompt.utils.image_detail_config import image_detail_config_for_prompt_file
 from core.agent.errors import AgentMaxIterationError
 from core.app.apps.base_app_queue_manager import PublishFrom
 from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
@@ -415,7 +416,7 @@ def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage])
             prompt_message_contents.append(
                 file_manager.to_prompt_message_content(
                     file,
-                    image_detail_config=image_detail_config,
+                    image_detail_config=image_detail_config_for_prompt_file(file, image_detail_config),
                 )
             )
         prompt_message_contents.append(TextPromptMessageContent(data=query))
diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py
index 5809d6f74a7bb9..3d973d1546e63b 100644
--- a/api/core/memory/token_buffer_memory.py
+++ b/api/core/memory/token_buffer_memory.py
@@ -17,6 +17,7 @@
 from core.app.file_access import DatabaseFileAccessController
 from core.model_manager import ModelInstance
 from core.prompt.utils.extract_thread_messages import extract_thread_messages
+from core.prompt.utils.image_detail_config import image_detail_config_for_prompt_file
 from extensions.ext_database import db
 from factories import file_factory
 from models.model import AppMode, Conversation, Message, MessageFile
@@ -111,7 +112,7 @@ def _build_prompt_message_with_files(
         for file in file_objs:
             prompt_message = file_manager.to_prompt_message_content(
                 file,
-                image_detail_config=detail,
+                image_detail_config=image_detail_config_for_prompt_file(file, detail),
             )
             prompt_message_contents.append(prompt_message)
         prompt_message_contents.append(TextPromptMessageContent(data=text_content))
diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py
index 19b5e9223a8939..55fc333f64df69 100644
--- a/api/core/prompt/advanced_prompt_transform.py
+++ b/api/core/prompt/advanced_prompt_transform.py
@@ -19,6 +19,7 @@
 from core.model_manager import ModelInstance
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.prompt_transform import PromptTransform
+from core.prompt.utils.image_detail_config import image_detail_config_for_prompt_file
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 
 
@@ -133,7 +134,10 @@ def _get_completion_model_prompt_messages(
             prompt_message_contents: list[PromptMessageContentUnionTypes] = []
             for file in files:
                 prompt_message_contents.append(
-                    file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
+                    file_manager.to_prompt_message_content(
+                        file,
+                        image_detail_config=image_detail_config_for_prompt_file(file, image_detail_config),
+                    )
                 )
             prompt_message_contents.append(TextPromptMessageContent(data=prompt))
 
@@ -215,7 +219,10 @@ def _get_chat_model_prompt_messages(
             if files and query is not None:
                 for file in files:
                     prompt_message_contents.append(
-                        file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
+                        file_manager.to_prompt_message_content(
+                            file,
+                            image_detail_config=image_detail_config_for_prompt_file(file, image_detail_config),
+                        )
                    )
                 prompt_message_contents.append(TextPromptMessageContent(data=query))
 
@@ -230,7 +237,10 @@ def _get_chat_model_prompt_messages(
                 # get last user message content and add files
                 for file in files:
                     prompt_message_contents.append(
-                        file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
+                        file_manager.to_prompt_message_content(
+                            file,
+                            image_detail_config=image_detail_config_for_prompt_file(file, image_detail_config),
+                        )
                     )
                 prompt_message_contents.append(TextPromptMessageContent(data=cast(str, last_message.content)))
 
@@ -238,7 +248,10 @@ def _get_chat_model_prompt_messages(
                 else:
                     for file in files:
                         prompt_message_contents.append(
-                            file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
+                            file_manager.to_prompt_message_content(
+                                file,
+                                image_detail_config=image_detail_config_for_prompt_file(file, image_detail_config),
+                            )
                         )
                     prompt_message_contents.append(TextPromptMessageContent(data=""))
 
@@ -246,7 +259,10 @@ def _get_chat_model_prompt_messages(
         else:
             for file in files:
                 prompt_message_contents.append(
-                    file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
+                    file_manager.to_prompt_message_content(
+                        file,
+                        image_detail_config=image_detail_config_for_prompt_file(file, image_detail_config),
+                    )
                 )
             prompt_message_contents.append(TextPromptMessageContent(data=query))
diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py
index dc8391a6a56708..d9c84a8ed56944 100644
--- a/api/core/prompt/simple_prompt_transform.py
+++ b/api/core/prompt/simple_prompt_transform.py
@@ -19,6 +19,7 @@
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.prompt.entities.advanced_prompt_entities import MemoryConfig
 from core.prompt.prompt_transform import PromptTransform
+from core.prompt.utils.image_detail_config import image_detail_config_for_prompt_file
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from models.model import AppMode
 
@@ -297,12 +298,18 @@ def _get_last_user_message(
         if files:
             for file in files:
                 prompt_message_contents.append(
-                    file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
+                    file_manager.to_prompt_message_content(
+                        file,
+                        image_detail_config=image_detail_config_for_prompt_file(file, image_detail_config),
+                    )
                 )
         if context_files:
             for file in context_files:
                 prompt_message_contents.append(
-                    file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
+                    file_manager.to_prompt_message_content(
+                        file,
+                        image_detail_config=image_detail_config_for_prompt_file(file, image_detail_config),
+                    )
                 )
         if prompt_message_contents:
             prompt_message_contents.append(TextPromptMessageContent(data=prompt))
diff --git a/api/core/prompt/utils/image_detail_config.py b/api/core/prompt/utils/image_detail_config.py
new file mode 100644
index 00000000000000..863cc782cea72a
--- /dev/null
+++ b/api/core/prompt/utils/image_detail_config.py
@@ -0,0 +1,11 @@
+from graphon.file import File, FileType
+from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent
+
+
+def image_detail_config_for_prompt_file(
+    file: File,
+    image_detail_config: ImagePromptMessageContent.DETAIL | None,
+) -> ImagePromptMessageContent.DETAIL | None:
+    if file.type == FileType.IMAGE:
+        return image_detail_config
+    return None
diff --git a/api/tests/unit_tests/core/agent/test_base_agent_runner.py b/api/tests/unit_tests/core/agent/test_base_agent_runner.py
index db4b293b163640..6092681a944d10 100644
--- a/api/tests/unit_tests/core/agent/test_base_agent_runner.py
+++ b/api/tests/unit_tests/core/agent/test_base_agent_runner.py
@@ -660,6 +660,82 @@ def test_init_sets_stream_tool_call_and_files(self, mocker):
         assert runner.dataset_tools == ["ds_tool"]
         assert runner.agent_thought_count == 2
 
+    def test_init_keeps_files_for_document_models(self, mocker):
+        session = mocker.MagicMock()
+        session.scalar.return_value = 0
+        mocker.patch.object(module.db, "session", session)
+
+        mocker.patch.object(BaseAgentRunner, "organize_agent_history", return_value=[])
+        mocker.patch.object(module.DatasetRetrieverTool, "get_dataset_tools", return_value=[])
+
+        llm = mocker.MagicMock()
+        llm.get_model_schema.return_value = mocker.MagicMock(
+            features=[module.ModelFeature.DOCUMENT]
+        )
+        model_instance = mocker.MagicMock(model_type_instance=llm, model="m", credentials="c")
+
+        app_config = mocker.MagicMock()
+        app_config.app_id = "app1"
+        app_config.agent = None
+        app_config.dataset = None
+        app_config.additional_features = None
+
+        app_generate = mocker.MagicMock(invoke_from="test", inputs={}, files=["doc-file"])
+        message = mocker.MagicMock(id="msg1", conversation_id="conv1")
+
+        runner = BaseAgentRunner(
+            tenant_id="tenant",
+            application_generate_entity=app_generate,
+            conversation=mocker.MagicMock(),
+            app_config=app_config,
+            model_config=mocker.MagicMock(),
+            config=mocker.MagicMock(),
+            queue_manager=mocker.MagicMock(),
+            message=message,
+            user_id="user",
+            model_instance=model_instance,
+        )
+
+        assert runner.files == ["doc-file"]
+
+    def test_init_drops_files_when_model_has_no_file_features(self, mocker):
+        session = mocker.MagicMock()
+        session.scalar.return_value = 0
+        mocker.patch.object(module.db, "session", session)
+
+        mocker.patch.object(BaseAgentRunner, "organize_agent_history", return_value=[])
+        mocker.patch.object(module.DatasetRetrieverTool, "get_dataset_tools", return_value=[])
+
+        llm = mocker.MagicMock()
+        llm.get_model_schema.return_value = mocker.MagicMock(
+            features=[module.ModelFeature.STREAM_TOOL_CALL]
+        )
+        model_instance = mocker.MagicMock(model_type_instance=llm, model="m", credentials="c")
+
+        app_config = mocker.MagicMock()
+        app_config.app_id = "app1"
+        app_config.agent = None
+        app_config.dataset = None
+        app_config.additional_features = None
+
+        app_generate = mocker.MagicMock(invoke_from="test", inputs={}, files=["file1"])
+        message = mocker.MagicMock(id="msg1", conversation_id="conv1")
+
+        runner = BaseAgentRunner(
+            tenant_id="tenant",
+            application_generate_entity=app_generate,
+            conversation=mocker.MagicMock(),
+            app_config=app_config,
+            model_config=mocker.MagicMock(),
+            config=mocker.MagicMock(),
+            queue_manager=mocker.MagicMock(),
+            message=message,
+            user_id="user",
+            model_instance=model_instance,
+        )
+
+        assert runner.files == []
+
 
 class TestBaseAgentRunnerCoverage:
     def test_convert_tool_skips_non_llm_param(self, runner, mocker):
diff --git a/api/tests/unit_tests/core/prompt/utils/test_image_detail_config.py b/api/tests/unit_tests/core/prompt/utils/test_image_detail_config.py
new file mode 100644
index 00000000000000..c7d09dc7d2f65b
--- /dev/null
+++ b/api/tests/unit_tests/core/prompt/utils/test_image_detail_config.py
@@ -0,0 +1,18 @@
+from unittest.mock import MagicMock
+
+from graphon.file import FileType
+from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent
+
+from core.prompt.utils.image_detail_config import image_detail_config_for_prompt_file
+
+
+def test_image_detail_only_for_image_files():
+    image_file = MagicMock()
+    image_file.type = FileType.IMAGE
+    doc_file = MagicMock()
+    doc_file.type = FileType.DOCUMENT
+    detail = ImagePromptMessageContent.DETAIL.HIGH
+
+    assert image_detail_config_for_prompt_file(image_file, detail) is detail
+    assert image_detail_config_for_prompt_file(doc_file, detail) is None
+    assert image_detail_config_for_prompt_file(doc_file, None) is None
diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx
index 56345890ff5c3d..ce89e4c219c95e 100644
--- a/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx
+++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx
@@ -87,6 +87,7 @@ const ChatItem: FC = ({
   const currentProvider = textGenerationModelList.find(item => item.provider === modelAndParameter.provider)
   const currentModel = currentProvider?.models.find(model => model.model === modelAndParameter.model)
   const supportVision = currentModel?.features?.includes(ModelFeatureEnum.vision)
+  const supportDocument = currentModel?.features?.includes(ModelFeatureEnum.document)
 
   const configData = {
     ...config,
@@ -105,7 +106,7 @@
     parent_message_id: getLastAnswer(chatList)?.id || null,
   }
 
-  if ((config.file_upload as any).enabled && files?.length && supportVision)
+  if ((config.file_upload as any).enabled && files?.length && (supportVision || supportDocument))
     data.files = files
 
   handleSend(
diff --git a/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx b/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx
index a9f9f1116bdd20..148d28b8b59d51 100644
--- a/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx
+++ b/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx
@@ -96,6 +96,7 @@ const DebugWithSingleModel = (
   const currentProvider = textGenerationModelList.find(item => item.provider === modelConfig.provider)
   const currentModel = currentProvider?.models.find(model => model.model === modelConfig.model_id)
   const supportVision = currentModel?.features?.includes(ModelFeatureEnum.vision)
+  const supportDocument = currentModel?.features?.includes(ModelFeatureEnum.document)
 
   const configData = {
     ...config,
@@ -114,7 +115,7 @@
     parent_message_id: (isRegenerate ? parentAnswer?.id : getLastAnswer(chatList)?.id) || null,
   }
 
-  if ((config.file_upload as any)?.enabled && files?.length && supportVision)
+  if ((config.file_upload as any)?.enabled && files?.length && (supportVision || supportDocument))
     data.files = files
 
   handleSend(
diff --git a/web/app/components/app/configuration/debug/index.tsx b/web/app/components/app/configuration/debug/index.tsx
index 88bca7111ce69f..d514094612bc72 100644
--- a/web/app/components/app/configuration/debug/index.tsx
+++ b/web/app/components/app/configuration/debug/index.tsx
@@ -344,11 +344,13 @@ const Debug: FC = ({
 
   const handleVisionConfigInMultipleModel = useCallback(() => {
     if (debugWithMultipleModel && mode) {
-      const supportedVision = multipleModelConfigs.some((modelConfig) => {
+      const supportedFileContext = multipleModelConfigs.some((modelConfig) => {
         const currentProvider = textGenerationModelList.find(modelItem => modelItem.provider === modelConfig.provider)
         const currentModel = currentProvider?.models.find(model => model.model === modelConfig.model)
 
-        return currentModel?.features?.includes(ModelFeatureEnum.vision)
+        return !!currentModel?.features?.some(f =>
+          f === ModelFeatureEnum.vision || f === ModelFeatureEnum.document,
+        )
       })
       const {
         features,
@@ -358,7 +360,7 @@ const Debug: FC = ({
       const newFeatures = produce(features, (draft) => {
         draft.file = {
           ...draft.file,
-          enabled: supportedVision,
+          enabled: supportedFileContext,
        }
       })
       setFeatures(newFeatures)
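
Below is a minimal, self-contained sketch of the feature gate this patch adds in BaseAgentRunner.__init__. The ModelFeature enum and files_for_model helper here are illustrative stand-ins rather than the real graphon types; only the gating logic mirrors the diff:

    from enum import Enum


    class ModelFeature(Enum):  # stand-in for the real graphon ModelFeature enum
        VISION = "vision"
        DOCUMENT = "document"
        STREAM_TOOL_CALL = "stream-tool-call"


    def files_for_model(files: list[str], features: list[ModelFeature]) -> list[str]:
        # Mirrors the new check: attachments are kept when the model can consume
        # images (VISION) or documents (DOCUMENT); otherwise they are dropped.
        supports_file_context = ModelFeature.VISION in features or ModelFeature.DOCUMENT in features
        return files if supports_file_context else []


    assert files_for_model(["doc.pdf"], [ModelFeature.DOCUMENT]) == ["doc.pdf"]
    assert files_for_model(["img.png"], [ModelFeature.VISION]) == ["img.png"]
    assert files_for_model(["notes.txt"], [ModelFeature.STREAM_TOOL_CALL]) == []

Per-file detail still only applies to images: image_detail_config_for_prompt_file passes the configured detail through for FileType.IMAGE and returns None for every other file type, so document attachments never carry an image-only detail setting.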