diff --git a/engine/controllers/server.cc b/engine/controllers/server.cc index 0c41b0f74..87094528b 100644 --- a/engine/controllers/server.cc +++ b/engine/controllers/server.cc @@ -5,7 +5,7 @@ #include "utils/cpuid/cpu_info.h" #include "utils/engine_constants.h" #include "utils/file_manager_utils.h" - +#include "utils/function_calling/common.h" using namespace inferences; using json = nlohmann::json; namespace inferences { @@ -168,6 +168,7 @@ void server::ProcessStreamRes(std::function cb, void server::ProcessNonStreamRes(std::function cb, services::SyncQueue& q) { auto [status, res] = q.wait_and_pop(); + function_calling_utils::PostProcessResponse(res); auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); resp->setStatusCode( static_cast(status["status_code"].asInt())); diff --git a/engine/services/inference_service.cc b/engine/services/inference_service.cc index 08dbc33aa..eebc7e2ee 100644 --- a/engine/services/inference_service.cc +++ b/engine/services/inference_service.cc @@ -2,6 +2,7 @@ #include "utils/cpuid/cpu_info.h" #include "utils/engine_constants.h" #include "utils/file_manager_utils.h" +#include "utils/function_calling/common.h" namespace services { @@ -41,11 +42,17 @@ cpp::result InferenceService::HandleChatCompletion( LOG_WARN << "Engine is not loaded yet"; return cpp::fail(std::make_pair(stt, res)); } + + function_calling_utils::PreprocessRequest(json_body); + Json::Value tool_choice = json_body->get("tool_choice", Json::Value::null); std::get(engines_[ne].engine) - ->HandleChatCompletion(json_body, - [q](Json::Value status, Json::Value res) { - q->push(std::make_pair(status, res)); - }); + ->HandleChatCompletion( + json_body, [q, tool_choice](Json::Value status, Json::Value res) { + if (!tool_choice.isNull()) { + res["tool_choice"] = tool_choice; + } + q->push(std::make_pair(status, res)); + }); return {}; } diff --git a/engine/test/components/test_function_calling.cc b/engine/test/components/test_function_calling.cc new file mode 100644 index 000000000..7a4810b29 --- /dev/null +++ b/engine/test/components/test_function_calling.cc @@ -0,0 +1,157 @@ +#include +#include "gtest/gtest.h" +#include "json/json.h" +#include "utils/function_calling/common.h" + +class FunctionCallingUtilsTest : public ::testing::Test { + protected: + std::shared_ptr createTestRequest() { + auto request = std::make_shared(); + (*request)["tools"] = Json::Value(Json::arrayValue); + return request; + } +}; + +TEST_F(FunctionCallingUtilsTest, ReplaceCustomFunctions) { + std::string original = "Test placeholder"; + std::string replacement = "Custom function"; + std::string result = + function_calling_utils::ReplaceCustomFunctions(original, replacement); + EXPECT_EQ(result, "Test Custom function placeholder"); +} + +TEST_F(FunctionCallingUtilsTest, HasTools) { + auto request = createTestRequest(); + EXPECT_FALSE(function_calling_utils::HasTools(request)); + + (*request)["tools"].append(Json::Value()); + EXPECT_TRUE(function_calling_utils::HasTools(request)); + + (*request)["tools"] = "random"; + EXPECT_FALSE(function_calling_utils::HasTools(request)); + + (*request)["tools"] = Json::Value::null; + EXPECT_FALSE(function_calling_utils::HasTools(request)); +} + +TEST_F(FunctionCallingUtilsTest, ProcessTools) { + auto request = createTestRequest(); + Json::Value tool; + tool["type"] = "function"; + tool["function"]["name"] = "test_function"; + tool["function"]["description"] = "Test description"; + (*request)["tools"].append(tool); + + std::string result = function_calling_utils::ProcessTools(request); + EXPECT_TRUE( + result.find("Use the function 'test_function' to: Test description") != + std::string::npos); +} + +TEST_F(FunctionCallingUtilsTest, ParseMultipleFunctionStrings) { + std::string input = + "{\"arg\":\"value1\"}{\"arg\":\"value2\"}"; + Json::Value result = + function_calling_utils::ParseMultipleFunctionStrings(input); + + ASSERT_EQ(result.size(), 2); + EXPECT_EQ(result[0]["function"]["name"].asString(), "func1"); + EXPECT_EQ(result[0]["function"]["arguments"].asString(), + "{\"arg\":\"value1\"}"); + EXPECT_EQ(result[1]["function"]["name"].asString(), "func2"); + EXPECT_EQ(result[1]["function"]["arguments"].asString(), + "{\"arg\":\"value2\"}"); +} + +TEST_F(FunctionCallingUtilsTest, ConvertJsonToFunctionStrings) { + Json::Value jsonArray(Json::arrayValue); + Json::Value function1, function2; + function1["function"]["name"] = "func1"; + function1["function"]["arguments"] = "{\"arg\":\"value1\"}"; + function2["function"]["name"] = "func2"; + function2["function"]["arguments"] = "{\"arg\":\"value2\"}"; + jsonArray.append(function1); + jsonArray.append(function2); + + std::string result = + function_calling_utils::ConvertJsonToFunctionStrings(jsonArray); + EXPECT_EQ(result, + "{\"arg\":\"value1\"}{\"arg\":\"value2\"}"); +} + +TEST_F(FunctionCallingUtilsTest, CreateCustomFunctionsString) { + auto request = createTestRequest(); + Json::Value tool; + tool["type"] = "function"; + tool["function"]["name"] = "test_function"; + tool["function"]["description"] = "Test description"; + (*request)["tools"].append(tool); + + std::string result = + function_calling_utils::CreateCustomFunctionsString(request); + EXPECT_TRUE(result.find("```") != std::string::npos); + EXPECT_TRUE( + result.find("Use the function 'test_function' to: Test description") != + std::string::npos); +} + +TEST_F(FunctionCallingUtilsTest, IsValidToolChoiceFormat) { + Json::Value validTool; + validTool["type"] = "function"; + validTool["function"]["name"] = "test_function"; + EXPECT_TRUE(function_calling_utils::IsValidToolChoiceFormat(validTool)); + + Json::Value invalidTool; + EXPECT_FALSE(function_calling_utils::IsValidToolChoiceFormat(invalidTool)); +} + +TEST_F(FunctionCallingUtilsTest, UpdateMessages) { + auto request = createTestRequest(); + std::string system_prompt = "Original prompt"; + (*request)["messages"] = Json::Value(Json::arrayValue); + + function_calling_utils::UpdateMessages(system_prompt, request); + + ASSERT_TRUE((*request)["messages"].isArray()); + EXPECT_EQ((*request)["messages"][0]["role"].asString(), "system"); + EXPECT_EQ((*request)["messages"][0]["content"].asString(), system_prompt); +} + +TEST_F(FunctionCallingUtilsTest, PreprocessRequest) { + auto request = createTestRequest(); + Json::Value tool; + tool["type"] = "function"; + tool["function"]["name"] = "test_function"; + tool["function"]["description"] = "Test description"; + (*request)["tools"].append(tool); + + function_calling_utils::PreprocessRequest(request); + + ASSERT_TRUE((*request)["messages"].isArray()); + EXPECT_TRUE((*request)["messages"][0]["content"].asString().find( + "Test description") != std::string::npos); +} + +TEST_F(FunctionCallingUtilsTest, PostProcessResponse) { + Json::Value response; + response["choices"] = Json::Value(Json::arrayValue); + Json::Value choice; + choice["message"]["content"] = + "{\"arg\":\"value\"}"; + response["choices"].append(choice); + + function_calling_utils::PostProcessResponse(response); + + EXPECT_EQ(response["choices"][0]["message"]["content"].asString(), ""); + EXPECT_TRUE(response["choices"][0]["message"]["tool_calls"].isArray()); + EXPECT_EQ( + response["choices"][0]["message"]["tool_calls"][0]["function"]["name"] + .asString(), + "test_function"); + EXPECT_EQ(response["choices"][0]["message"]["tool_calls"][0]["function"] + ["arguments"] + .asString(), + "{\"arg\":\"value\"}"); +} \ No newline at end of file diff --git a/engine/utils/function_calling/common.h b/engine/utils/function_calling/common.h new file mode 100644 index 000000000..cd47ab529 --- /dev/null +++ b/engine/utils/function_calling/common.h @@ -0,0 +1,264 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include "llama3.1.h" +#include "utils/logging_utils.h" +namespace function_calling_utils { +constexpr auto custom_template_function = ""; + +constexpr auto gamma_json = R"( +root ::= object +value ::= object | array | string | number | ("true" | "false" | "null") ws + +object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" ws + +array ::= + "[" ws ( + value + ("," ws value)* + )? "]" ws + +string ::= + "\"" ( + [^"\\\x7F\x00-\x1F] | + "\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4}) # escapes + )* "\"" ws + +number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [0-9] [1-9]{0,15})? ws + +# Optional space: by convention, applied in this grammar after literal chars when allowed +ws ::= | " " | "\n" [ \t]{0,20})"; + +inline std::string ReplaceCustomFunctions(const std::string& original, + const std::string& replacement) { + std::string result = original; + + size_t pos = result.find(custom_template_function); + if (pos != std::string::npos) { + result.replace(pos, std::string(custom_template_function).length(), + replacement); + } + + return result; +} + +inline bool HasTools(const std::shared_ptr& request) { + return request->isMember("tools") && (*request)["tools"].isArray() && + (*request)["tools"].size() > 0; +} + +inline std::string ProcessTools(const std::shared_ptr& request) { + if (!HasTools(request)) { + return ""; + } + + std::ostringstream result; + result << "\n"; + + const Json::Value& tools = (*request)["tools"]; + for (const auto& tool : tools) { + if (tool["type"] == "function") { + const Json::Value& function = tool["function"]; + result << "Use the function '" << function["name"].asString() + << "' to: " << function["description"].asString() << "\n"; + + Json::FastWriter writer; + std::string jsonString = writer.write(tool); + result << jsonString << "\n"; + } + } + + return result.str(); +} + +inline Json::Value ParseMultipleFunctionStrings(const std::string& input) { + Json::Value results(Json::arrayValue); + + // Regular expression to match the function name and arguments + std::regex functionRegex("]+)>(.+?)"); + + // Iterator for regex matches + auto words_begin = + std::sregex_iterator(input.begin(), input.end(), functionRegex); + auto words_end = std::sregex_iterator(); + + for (std::sregex_iterator i = words_begin; i != words_end; ++i) { + std::smatch match = *i; + if (match.size() == 3) { + Json::Value function; + function["type"] = "function"; + function["function"]["name"] = match[1].str(); + function["function"]["arguments"] = match[2].str(); + results.append(function); + } + } + + return results; +} + +inline std::string ConvertJsonToFunctionStrings(const Json::Value& jsonArray) { + if (!jsonArray.isArray()) { + return ""; // Return empty string if input is not an array + } + + std::ostringstream result; + + for (const auto& function : jsonArray) { + auto function_json = function.get("function", {}); + if (function_json.isMember("name") && function_json.isMember("arguments")) { + result << "" + << function_json["arguments"].asString() << ""; + } + } + return result.str(); +} + +// Helper function to parse a JSON string to Json +inline Json::Value ParseJsonString(const std::string& jsonString) { + Json::Value root; + Json::Reader reader; + reader.parse(jsonString, root); + return root; +} + +inline std::string CreateCustomFunctionsString( + std::shared_ptr request) { + std::string customFunctions = ProcessTools(request); + if (customFunctions.empty()) { + return ""; // No custom functions found + } + + return "```\n" + customFunctions + "```"; +} +inline bool IsValidToolChoiceFormat(const Json::Value& root) { + return root.isObject() && root.isMember("type") && root["type"].isString() && + root["type"].asString() == "function" && root.isMember("function") && + root["function"].isObject() && root["function"].isMember("name") && + root["function"]["name"].isString(); +} +inline void UpdateMessages(std::string& system_prompt, + std::shared_ptr request) { + Json::Value tool_choice = request->get("tool_choice", "auto"); + if (tool_choice.isString() && tool_choice.asString() == "required") { + system_prompt += + "\n\nYou must use a function to answer the user's question."; + } else if (!tool_choice.isString()) { + + system_prompt += + "\n\nNow this is your first priority: You must call the function '" + + tool_choice["function"]["name"].asString() + + "' to answer the user's question."; + } + + bool original_stream_config = (*request).get("stream", false).asBool(); + // (*request)["grammar"] = function_calling_utils::gamma_json; + (*request)["stream"] = + false; //when using function calling, disable stream automatically because we need to parse the response to get function name and params + if (!request->isMember("messages") || !(*request)["messages"].isArray() || + (*request)["messages"].empty()) { + // If no messages, add the system prompt as the first message + Json::Value systemMessage; + systemMessage["role"] = "system"; + systemMessage["content"] = system_prompt; + (*request)["messages"].append(systemMessage); + } else { + Json::Value& firstMessage = (*request)["messages"][0]; + if (firstMessage["role"] == "system") { + bool addCustomPrompt = + request->get("add_custom_system_prompt", true).asBool(); + if (addCustomPrompt) { + firstMessage["content"] = + system_prompt + "\n" + firstMessage["content"].asString(); + } + } else { + // If the first message is not a system message, prepend the system prompt + Json::Value systemMessage; + systemMessage["role"] = "system"; + systemMessage["content"] = system_prompt; + (*request)["messages"].insert(0, systemMessage); + } + Json::Value& lastMessage = + (*request)["messages"][(*request)["messages"].size() - 1]; + if (lastMessage.get("role", "") == "tool") { + lastMessage["role"] = function_calling_llama3_1_utils::tool_role; + (*request)["stream"] = + original_stream_config; // if role is tool then should restore stream config to original value + } + } + for (Json::Value& message : (*request)["messages"]) { + if (message["role"] == "assistant" && message.isMember("tool_calls")) { + const Json::Value& tool_calls = message["tool_calls"]; + if (!tool_calls.isNull() && tool_calls.isArray() && + tool_calls.size() > 0) { + message["content"] = ConvertJsonToFunctionStrings(tool_calls); + message["tool_calls"] = {}; + } + } + } +} +inline void PreprocessRequest(std::shared_ptr request) { + if (!function_calling_utils::HasTools(request)) { + return; // Exit if no tools present + } + if (request->get("tool_choice", "auto").isString()) { + std::string tool_choice = request->get("tool_choice", "auto").asString(); + if (tool_choice == "none") { + return; // Exit if tool_choice is none + } + } + std::string customFunctionsString = + function_calling_utils::CreateCustomFunctionsString(request); + std::string new_system_prompt = + function_calling_utils::ReplaceCustomFunctions( + function_calling_llama3_1_utils::system_prompt, + customFunctionsString); + UpdateMessages(new_system_prompt, request); +} + +inline void PostProcessResponse(Json::Value& response) { + if (!response.isMember("choices") || !response["choices"].isArray() || + response["choices"].empty()) { + // If there are no choices or the structure is incorrect, do nothing + return; + } + + // Get a reference to the first choice + Json::Value& firstChoice = response["choices"][0]; + + // Check if the choice has a message with content + if (firstChoice.isMember("message") && + firstChoice["message"].isMember("content")) { + std::string content = firstChoice["message"]["content"].asString(); + + // Create a new structure for tool_calls + Json::Value toolCall = ParseMultipleFunctionStrings(content); + if (toolCall.size() > 0) { + // Add tool_calls to the message + if (response.get("tool_choice", "auto").isString()) { + std::string tool_choice = + response.get("tool_choice", "auto").asString(); + if (tool_choice == "auto") { + firstChoice["finish_reason"] = "tool_calls"; + } else { + firstChoice["finish_reason"] = "stop"; + } + } + + firstChoice["message"]["tool_calls"] = toolCall; + + // Clear the content as it's now represented in tool_calls + firstChoice["message"]["content"] = ""; + } + } + + // Add any additional post-processing logic here +} +} // namespace function_calling_utils diff --git a/engine/utils/function_calling/llama3.1.h b/engine/utils/function_calling/llama3.1.h new file mode 100644 index 000000000..5c2e6ffdb --- /dev/null +++ b/engine/utils/function_calling/llama3.1.h @@ -0,0 +1,43 @@ +#pragma once + +namespace function_calling_llama3_1_utils { +constexpr auto system_prompt = R"( +Environment: ipython +Tools: brave_search, wolfram_alpha +Cutting Knowledge Date: December 2023 +Today Date: 20 September 2024 + +# Tool Instructions +- Always execute python code in messages that you share. +- When looking for real time information use relevant functions if available else fallback to brave_search + +You have access to the following CUSTOM functions: + + + + +If a you choose to call a function ONLY reply in the following format: +<{start_tag}={function_name}>{parameters}{end_tag} +where + +start_tag => ` a JSON dict with the function argument name as key and function argument value as value. +end_tag => `` + +Here is an example, +{"example_name": "example_value"} + +Reminder: +- Function calls MUST follow the specified format +- Required parameters MUST be specified +- Only call one function at a time +- Put the entire function call reply on one line +- Always add your sources when using search results to answer the user query +- If can not find correct parameters corresponding to function, ask user again to provide. +- No explanation are needed when calling a function. + +You are a helpful assistant. +)"; + +constexpr auto tool_role = "<|eot_id|>\n<|start_header_id|>ipython<|end_header_id|>\n"; +} // namespace function_calling_llama3_1_utils