diff --git a/engine/controllers/swagger.cc b/engine/controllers/swagger.cc index 81d5fdf45..94f2af402 100644 --- a/engine/controllers/swagger.cc +++ b/engine/controllers/swagger.cc @@ -630,10 +630,21 @@ Json::Value SwaggerController::generateOpenAPISpec() { "#/components/schemas/ChatMessage"; schemas["ChatCompletionRequest"]["properties"]["stream"]["type"] = "boolean"; schemas["ChatCompletionRequest"]["properties"]["engine"]["type"] = "string"; + schemas["ChatCompletionRequest"]["properties"]["tools"]["type"] = "array"; + schemas["ChatCompletionRequest"]["properties"]["tools"]["items"]["$ref"] = + "#/components/schemas/ToolsCall"; + schemas["ChatCompletionRequest"]["properties"]["tools_call_in_user_message"] + ["type"] = "boolean"; + schemas["ChatCompletionRequest"]["properties"]["tools_call_in_user_message"] + ["default"] = false; + schemas["ToolsCall"]["type"] = "object"; schemas["ChatMessage"]["type"] = "object"; schemas["ChatMessage"]["properties"]["role"]["type"] = "string"; schemas["ChatMessage"]["properties"]["content"]["type"] = "string"; + schemas["ChatMessage"]["properties"]["tools"]["type"] = "array"; + schemas["ChatMessage"]["properties"]["tools"]["items"]["$ref"] = + "#/components/schemas/ToolsCall"; schemas["ChatCompletionResponse"]["type"] = "object"; // Add properties based on your implementation diff --git a/engine/services/inference_service.cc b/engine/services/inference_service.cc index eebc7e2ee..a8d9a3166 100644 --- a/engine/services/inference_service.cc +++ b/engine/services/inference_service.cc @@ -389,4 +389,4 @@ bool InferenceService::HasFieldInReq(std::shared_ptr json_body, } return true; } -} // namespace services \ No newline at end of file +} // namespace services diff --git a/engine/utils/function_calling/common.h b/engine/utils/function_calling/common.h index cd47ab529..d01f22423 100644 --- a/engine/utils/function_calling/common.h +++ b/engine/utils/function_calling/common.h @@ -51,8 +51,9 @@ inline std::string ReplaceCustomFunctions(const std::string& original, } inline bool HasTools(const std::shared_ptr& request) { - return request->isMember("tools") && (*request)["tools"].isArray() && - (*request)["tools"].size() > 0; + return (request->isMember("tools") && (*request)["tools"].isArray() && + (*request)["tools"].size() > 0) || + request->get("tools_call_in_user_message", false).asBool(); } inline std::string ProcessTools(const std::shared_ptr& request) { @@ -149,7 +150,7 @@ inline void UpdateMessages(std::string& system_prompt, Json::Value tool_choice = request->get("tool_choice", "auto"); if (tool_choice.isString() && tool_choice.asString() == "required") { system_prompt += - "\n\nYou must use a function to answer the user's question."; + "\n\nYou must call a function to answer the user's question."; } else if (!tool_choice.isString()) { system_prompt += @@ -158,10 +159,14 @@ inline void UpdateMessages(std::string& system_prompt, "' to answer the user's question."; } + bool tools_call_in_user_message = + request->get("tools_call_in_user_message", false).asBool(); + bool original_stream_config = (*request).get("stream", false).asBool(); // (*request)["grammar"] = function_calling_utils::gamma_json; (*request)["stream"] = false; //when using function calling, disable stream automatically because we need to parse the response to get function name and params + if (!request->isMember("messages") || !(*request)["messages"].isArray() || (*request)["messages"].empty()) { // If no messages, add the system prompt as the first message @@ -170,21 +175,34 @@ inline void UpdateMessages(std::string& system_prompt, systemMessage["content"] = system_prompt; (*request)["messages"].append(systemMessage); } else { - Json::Value& firstMessage = (*request)["messages"][0]; - if (firstMessage["role"] == "system") { - bool addCustomPrompt = - request->get("add_custom_system_prompt", true).asBool(); - if (addCustomPrompt) { - firstMessage["content"] = - system_prompt + "\n" + firstMessage["content"].asString(); + + if (tools_call_in_user_message) { + for (Json::Value& message : (*request)["messages"]) { + if (message["role"] == "user" && message.isMember("tools") && + message["tools"].isArray() && message["tools"].size() > 0) { + message["content"] = system_prompt + "\n User question: " + + message["content"].asString(); + } } } else { - // If the first message is not a system message, prepend the system prompt - Json::Value systemMessage; - systemMessage["role"] = "system"; - systemMessage["content"] = system_prompt; - (*request)["messages"].insert(0, systemMessage); + Json::Value& firstMessage = (*request)["messages"][0]; + if (firstMessage["role"] == "system") { + bool addCustomPrompt = + request->get("add_custom_system_prompt", true).asBool(); + if (addCustomPrompt) { + firstMessage["content"] = + system_prompt + "\n" + firstMessage["content"].asString(); + } + } else { + // If the first message is not a system message, prepend the system prompt + Json::Value systemMessage; + systemMessage["role"] = "system"; + systemMessage["content"] = system_prompt; + (*request)["messages"].insert(0, systemMessage); + } } + + // transform last message role to tool if it is a function call Json::Value& lastMessage = (*request)["messages"][(*request)["messages"].size() - 1]; if (lastMessage.get("role", "") == "tool") {