Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 21d77bc

Browse files
Feat/function calling (#1472)
* chore: change update to patch * fix: swagger * fix: pull api * chore: refactor server controller * fix: update status * feat: mimic openai function calling api with llama3.1 * chore: remove unnecessary cout * feat: add tool choice option to api * feat: add unitest * chore: format code --------- Co-authored-by: vansangpfiev <vansangpfiev@gmail.com>
1 parent 77fb12f commit 21d77bc

File tree

5 files changed

+477
-5
lines changed

5 files changed

+477
-5
lines changed

engine/controllers/server.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#include "utils/cpuid/cpu_info.h"
66
#include "utils/engine_constants.h"
77
#include "utils/file_manager_utils.h"
8-
8+
#include "utils/function_calling/common.h"
99
using namespace inferences;
1010
using json = nlohmann::json;
1111
namespace inferences {
@@ -168,6 +168,7 @@ void server::ProcessStreamRes(std::function<void(const HttpResponsePtr&)> cb,
168168
void server::ProcessNonStreamRes(std::function<void(const HttpResponsePtr&)> cb,
169169
services::SyncQueue& q) {
170170
auto [status, res] = q.wait_and_pop();
171+
function_calling_utils::PostProcessResponse(res);
171172
auto resp = cortex_utils::CreateCortexHttpJsonResponse(res);
172173
resp->setStatusCode(
173174
static_cast<drogon::HttpStatusCode>(status["status_code"].asInt()));

engine/services/inference_service.cc

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "utils/cpuid/cpu_info.h"
33
#include "utils/engine_constants.h"
44
#include "utils/file_manager_utils.h"
5+
#include "utils/function_calling/common.h"
56

67
namespace services {
78

@@ -41,11 +42,17 @@ cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
4142
LOG_WARN << "Engine is not loaded yet";
4243
return cpp::fail(std::make_pair(stt, res));
4344
}
45+
46+
function_calling_utils::PreprocessRequest(json_body);
47+
Json::Value tool_choice = json_body->get("tool_choice", Json::Value::null);
4448
std::get<EngineI*>(engines_[ne].engine)
45-
->HandleChatCompletion(json_body,
46-
[q](Json::Value status, Json::Value res) {
47-
q->push(std::make_pair(status, res));
48-
});
49+
->HandleChatCompletion(
50+
json_body, [q, tool_choice](Json::Value status, Json::Value res) {
51+
if (!tool_choice.isNull()) {
52+
res["tool_choice"] = tool_choice;
53+
}
54+
q->push(std::make_pair(status, res));
55+
});
4956
return {};
5057
}
5158

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
#include <memory>
2+
#include "gtest/gtest.h"
3+
#include "json/json.h"
4+
#include "utils/function_calling/common.h"
5+
6+
class FunctionCallingUtilsTest : public ::testing::Test {
7+
protected:
8+
std::shared_ptr<Json::Value> createTestRequest() {
9+
auto request = std::make_shared<Json::Value>();
10+
(*request)["tools"] = Json::Value(Json::arrayValue);
11+
return request;
12+
}
13+
};
14+
15+
TEST_F(FunctionCallingUtilsTest, ReplaceCustomFunctions) {
16+
std::string original = "Test <CUSTOM_FUNCTIONS> placeholder";
17+
std::string replacement = "Custom function";
18+
std::string result =
19+
function_calling_utils::ReplaceCustomFunctions(original, replacement);
20+
EXPECT_EQ(result, "Test Custom function placeholder");
21+
}
22+
23+
TEST_F(FunctionCallingUtilsTest, HasTools) {
24+
auto request = createTestRequest();
25+
EXPECT_FALSE(function_calling_utils::HasTools(request));
26+
27+
(*request)["tools"].append(Json::Value());
28+
EXPECT_TRUE(function_calling_utils::HasTools(request));
29+
30+
(*request)["tools"] = "random";
31+
EXPECT_FALSE(function_calling_utils::HasTools(request));
32+
33+
(*request)["tools"] = Json::Value::null;
34+
EXPECT_FALSE(function_calling_utils::HasTools(request));
35+
}
36+
37+
TEST_F(FunctionCallingUtilsTest, ProcessTools) {
38+
auto request = createTestRequest();
39+
Json::Value tool;
40+
tool["type"] = "function";
41+
tool["function"]["name"] = "test_function";
42+
tool["function"]["description"] = "Test description";
43+
(*request)["tools"].append(tool);
44+
45+
std::string result = function_calling_utils::ProcessTools(request);
46+
EXPECT_TRUE(
47+
result.find("Use the function 'test_function' to: Test description") !=
48+
std::string::npos);
49+
}
50+
51+
TEST_F(FunctionCallingUtilsTest, ParseMultipleFunctionStrings) {
52+
std::string input =
53+
"<function=func1>{\"arg\":\"value1\"}</"
54+
"function><function=func2>{\"arg\":\"value2\"}</function>";
55+
Json::Value result =
56+
function_calling_utils::ParseMultipleFunctionStrings(input);
57+
58+
ASSERT_EQ(result.size(), 2);
59+
EXPECT_EQ(result[0]["function"]["name"].asString(), "func1");
60+
EXPECT_EQ(result[0]["function"]["arguments"].asString(),
61+
"{\"arg\":\"value1\"}");
62+
EXPECT_EQ(result[1]["function"]["name"].asString(), "func2");
63+
EXPECT_EQ(result[1]["function"]["arguments"].asString(),
64+
"{\"arg\":\"value2\"}");
65+
}
66+
67+
TEST_F(FunctionCallingUtilsTest, ConvertJsonToFunctionStrings) {
68+
Json::Value jsonArray(Json::arrayValue);
69+
Json::Value function1, function2;
70+
function1["function"]["name"] = "func1";
71+
function1["function"]["arguments"] = "{\"arg\":\"value1\"}";
72+
function2["function"]["name"] = "func2";
73+
function2["function"]["arguments"] = "{\"arg\":\"value2\"}";
74+
jsonArray.append(function1);
75+
jsonArray.append(function2);
76+
77+
std::string result =
78+
function_calling_utils::ConvertJsonToFunctionStrings(jsonArray);
79+
EXPECT_EQ(result,
80+
"<function=func1>{\"arg\":\"value1\"}</"
81+
"function><function=func2>{\"arg\":\"value2\"}</function>");
82+
}
83+
84+
TEST_F(FunctionCallingUtilsTest, CreateCustomFunctionsString) {
85+
auto request = createTestRequest();
86+
Json::Value tool;
87+
tool["type"] = "function";
88+
tool["function"]["name"] = "test_function";
89+
tool["function"]["description"] = "Test description";
90+
(*request)["tools"].append(tool);
91+
92+
std::string result =
93+
function_calling_utils::CreateCustomFunctionsString(request);
94+
EXPECT_TRUE(result.find("```") != std::string::npos);
95+
EXPECT_TRUE(
96+
result.find("Use the function 'test_function' to: Test description") !=
97+
std::string::npos);
98+
}
99+
100+
TEST_F(FunctionCallingUtilsTest, IsValidToolChoiceFormat) {
101+
Json::Value validTool;
102+
validTool["type"] = "function";
103+
validTool["function"]["name"] = "test_function";
104+
EXPECT_TRUE(function_calling_utils::IsValidToolChoiceFormat(validTool));
105+
106+
Json::Value invalidTool;
107+
EXPECT_FALSE(function_calling_utils::IsValidToolChoiceFormat(invalidTool));
108+
}
109+
110+
TEST_F(FunctionCallingUtilsTest, UpdateMessages) {
111+
auto request = createTestRequest();
112+
std::string system_prompt = "Original prompt";
113+
(*request)["messages"] = Json::Value(Json::arrayValue);
114+
115+
function_calling_utils::UpdateMessages(system_prompt, request);
116+
117+
ASSERT_TRUE((*request)["messages"].isArray());
118+
EXPECT_EQ((*request)["messages"][0]["role"].asString(), "system");
119+
EXPECT_EQ((*request)["messages"][0]["content"].asString(), system_prompt);
120+
}
121+
122+
TEST_F(FunctionCallingUtilsTest, PreprocessRequest) {
123+
auto request = createTestRequest();
124+
Json::Value tool;
125+
tool["type"] = "function";
126+
tool["function"]["name"] = "test_function";
127+
tool["function"]["description"] = "Test description";
128+
(*request)["tools"].append(tool);
129+
130+
function_calling_utils::PreprocessRequest(request);
131+
132+
ASSERT_TRUE((*request)["messages"].isArray());
133+
EXPECT_TRUE((*request)["messages"][0]["content"].asString().find(
134+
"Test description") != std::string::npos);
135+
}
136+
137+
TEST_F(FunctionCallingUtilsTest, PostProcessResponse) {
138+
Json::Value response;
139+
response["choices"] = Json::Value(Json::arrayValue);
140+
Json::Value choice;
141+
choice["message"]["content"] =
142+
"<function=test_function>{\"arg\":\"value\"}</function>";
143+
response["choices"].append(choice);
144+
145+
function_calling_utils::PostProcessResponse(response);
146+
147+
EXPECT_EQ(response["choices"][0]["message"]["content"].asString(), "");
148+
EXPECT_TRUE(response["choices"][0]["message"]["tool_calls"].isArray());
149+
EXPECT_EQ(
150+
response["choices"][0]["message"]["tool_calls"][0]["function"]["name"]
151+
.asString(),
152+
"test_function");
153+
EXPECT_EQ(response["choices"][0]["message"]["tool_calls"][0]["function"]
154+
["arguments"]
155+
.asString(),
156+
"{\"arg\":\"value\"}");
157+
}

0 commit comments

Comments
 (0)