From 52d464e0b192378c1bab194b65f5123b7d76cdd7 Mon Sep 17 00:00:00 2001 From: dengyingxu1 Date: Tue, 5 Aug 2025 16:08:44 +0800 Subject: [PATCH 1/9] feat: add support for passing tools parameter to chat template. --- xllm/api_service/api_service.cpp | 18 +- xllm/api_service/chat_service_impl.cpp | 135 +++++++++ xllm/core/chat_template/tools_converter.cpp | 284 ++++++++++++++++++ xllm/core/chat_template/tools_converter.h | 46 +++ .../chat_template/jinja_chat_template.cpp | 33 +- xllm/core/framework/request/request_params.h | 4 + xllm/core/runtime/llm_master.cpp | 9 +- xllm/proto/chat.proto | 16 + 8 files changed, 533 insertions(+), 12 deletions(-) create mode 100644 xllm/core/chat_template/tools_converter.cpp create mode 100644 xllm/core/chat_template/tools_converter.h diff --git a/xllm/api_service/api_service.cpp b/xllm/api_service/api_service.cpp index d632a5cf..feb2295c 100644 --- a/xllm/api_service/api_service.cpp +++ b/xllm/api_service/api_service.cpp @@ -15,7 +15,7 @@ #include "models.pb.h" #include "service_impl_factory.h" #include "xllm_metrics.h" - +#include namespace xllm { APIService::APIService(Master* master, @@ -71,7 +71,7 @@ void APIService::CompletionsHttp(::google::protobuf::RpcController* controller, auto ctrl = reinterpret_cast(controller); std::string attachment = std::move(ctrl->request_attachment().to_string()); - + LOG(INFO) << "attachment:" << attachment; std::string error; auto st = json2pb::JsonToProtoMessage(attachment, req_pb, &error); if (!st) { @@ -106,10 +106,10 @@ void ChatCompletionsImpl(std::unique_ptr& service, std::string attachment = std::move(ctrl->request_attachment().to_string()); std::string error; - auto st = json2pb::JsonToProtoMessage(attachment, req_pb, &error); - if (!st) { - ctrl->SetFailed(error); - LOG(ERROR) << "parse json to proto failed: " << error; + auto json_status = google::protobuf::util::JsonStringToMessage(attachment, req_pb); + if (!json_status.ok()) { + ctrl->SetFailed(json_status.ToString()); + LOG(ERROR) 
<< "parse json to proto failed: " << json_status.ToString(); return; } @@ -175,7 +175,7 @@ void APIService::EmbeddingsHttp(::google::protobuf::RpcController* controller, auto ctrl = reinterpret_cast(controller); std::string attachment = std::move(ctrl->request_attachment().to_string()); - + LOG(INFO) << "attachment:" << attachment; std::string error; auto st = json2pb::JsonToProtoMessage(attachment, req_pb, &error); if (!st) { @@ -289,7 +289,7 @@ void APIService::LinkCluster(::google::protobuf::RpcController* controller, auto ctrl = reinterpret_cast(controller); std::string attachment = std::move(ctrl->request_attachment().to_string()); - + LOG(INFO) << "attachment:" << attachment; std::string error; auto st = json2pb::JsonToProtoMessage(attachment, req_pb, &error); if (!st) { @@ -344,7 +344,7 @@ void APIService::UnlinkCluster(::google::protobuf::RpcController* controller, auto ctrl = reinterpret_cast(controller); std::string attachment = std::move(ctrl->request_attachment().to_string()); - + LOG(INFO) << "attachment:" << attachment; std::string error; auto st = json2pb::JsonToProtoMessage(attachment, req_pb, &error); if (!st) { diff --git a/xllm/api_service/chat_service_impl.cpp b/xllm/api_service/chat_service_impl.cpp index 571a27e0..aec356b5 100644 --- a/xllm/api_service/chat_service_impl.cpp +++ b/xllm/api_service/chat_service_impl.cpp @@ -18,6 +18,13 @@ #include "core/runtime/vlm_master.h" #include "core/util/utils.h" #include "core/util/uuid.h" +#include "chat_template/chat_template.h" +#include "chat_template/tools_converter.h" +#include "common/instance_name.h" +#include "common/uuid.h" +#include "request/request_params.h" +#include "util/utils.h" +#include namespace xllm { namespace { @@ -46,6 +53,108 @@ void set_logprobs(proto::ChatChoice* choice, } } } +struct ToolsInfo { + std::vector tools; + std::string tool_choice; + bool has_tools = false; +}; + +ToolsInfo extract_tools_info(const proto::ChatRequest& request) { + ToolsInfo info; + + if 
(request.tools_size() > 0) { + info.has_tools = true; + info.tools.reserve(request.tools_size()); + + for (const auto& proto_tool : request.tools()) { + Tool tool; + tool.type = proto_tool.type(); + tool.function.name = proto_tool.function().name(); + tool.function.description = proto_tool.function().description(); + // tool.function.parameters = proto_tool.function().parameters(); + // std::cerr << "proto_tool.function().parameters():" << proto_tool.function().parameters() << std::endl; + + std::string parameters_json_str; + if (proto_tool.function().has_parameters()) { + google::protobuf::util::JsonPrintOptions options; + options.add_whitespace = false; + options.preserve_proto_field_names = true; + auto status = google::protobuf::util::MessageToJsonString( + proto_tool.function().parameters(), ¶meters_json_str, options); + if (!status.ok()) { + LOG(WARNING) << "Failed to convert parameters Struct to JSON: " + << status.message() << ", tool: " << tool.function.name; + parameters_json_str = "{}"; + } + } else { + parameters_json_str = "{}"; + } + tool.function.parameters = parameters_json_str; + std::cerr << "parameters_json_str:" << parameters_json_str << std::endl; + + info.tools.push_back(std::move(tool)); + } + + if (request.has_tool_choice()) { + info.tool_choice = request.tool_choice(); + } else { + info.tool_choice = "auto"; + } + } + + return info; +} +struct ToolsInfo { + std::vector tools; + std::string tool_choice; + bool has_tools = false; +}; + +ToolsInfo extract_tools_info(const proto::ChatRequest& request) { + ToolsInfo info; + + if (request.tools_size() > 0) { + info.has_tools = true; + info.tools.reserve(request.tools_size()); + + for (const auto& proto_tool : request.tools()) { + Tool tool; + tool.type = proto_tool.type(); + tool.function.name = proto_tool.function().name(); + tool.function.description = proto_tool.function().description(); + // tool.function.parameters = proto_tool.function().parameters(); + // std::cerr << 
"proto_tool.function().parameters():" << proto_tool.function().parameters() << std::endl; + + std::string parameters_json_str; + if (proto_tool.function().has_parameters()) { + google::protobuf::util::JsonPrintOptions options; + options.add_whitespace = false; + options.preserve_proto_field_names = true; + auto status = google::protobuf::util::MessageToJsonString( + proto_tool.function().parameters(), ¶meters_json_str, options); + if (!status.ok()) { + LOG(WARNING) << "Failed to convert parameters Struct to JSON: " + << status.message() << ", tool: " << tool.function.name; + parameters_json_str = "{}"; + } + } else { + parameters_json_str = "{}"; + } + tool.function.parameters = parameters_json_str; + std::cerr << "parameters_json_str:" << parameters_json_str << std::endl; + + info.tools.push_back(std::move(tool)); + } + + if (request.has_tool_choice()) { + info.tool_choice = request.tool_choice(); + } else { + info.tool_choice = "auto"; + } + } + + return info; +} template bool send_delta_to_client_brpc(std::shared_ptr call, @@ -287,6 +396,27 @@ void MMChatServiceImpl::process_async(std::shared_ptr call) { return; } + ToolsInfo tools_info = extract_tools_info(rpc_request); + // 打印所有工具信息 + std::cerr << "Tools Information:" << std::endl; + std::cerr << "Has tools: " << (tools_info.has_tools ? 
"true" : "false") << std::endl; + std::cerr << "Tool choice: " << tools_info.tool_choice << std::endl; + + if (tools_info.has_tools) { + std::cerr << "Number of tools: " << tools_info.tools.size() << std::endl; + for (size_t i = 0; i < tools_info.tools.size(); ++i) { + const auto& tool = tools_info.tools[i]; + std::cerr << "Tool #" << i + 1 << ":" << std::endl; + std::cerr << " Type: " << tool.type << std::endl; + std::cerr << " Function name: " << tool.function.name << std::endl; + std::cerr << " Function description: " << tool.function.description << std::endl; + std::cerr << " Function parameters: " << tool.function.parameters << std::endl; + } + } else { + std::cerr << "No tools in this request" << std::endl; + } + + RequestParams request_params( rpc_request, call->get_x_request_id(), call->get_x_request_time()); @@ -305,6 +435,11 @@ void MMChatServiceImpl::process_async(std::shared_ptr call) { include_usage = rpc_request.stream_options().include_usage(); } + if ((tools_info.has_tools)) { + request_params.tools = std::move(tools_info.tools); + request_params.tool_choice = std::move(tools_info.tool_choice); + } + // schedule the request master_->handle_request( std::move(messages), diff --git a/xllm/core/chat_template/tools_converter.cpp b/xllm/core/chat_template/tools_converter.cpp new file mode 100644 index 00000000..e40b13da --- /dev/null +++ b/xllm/core/chat_template/tools_converter.cpp @@ -0,0 +1,284 @@ +#include "tools_converter.h" + +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace llm { + +std::string ToolsConverter::convert_tools_to_json(const std::vector& tools) { + if (tools.empty()) { + return "[]"; + } + + nlohmann::json tools_json = nlohmann::json::array(); + + for (const auto& tool : tools) { + nlohmann::json tool_json; + tool_json["type"] = tool.type; + + nlohmann::json function_json; + function_json["name"] = tool.function.name; + function_json["description"] = tool.function.description; + + try { + 
if (!tool.function.parameters.empty()) { + function_json["parameters"] = nlohmann::json::parse(tool.function.parameters); + } else { + function_json["parameters"] = nlohmann::json::object(); + } + } catch (const nlohmann::json::exception& e) { + LOG(WARNING) << "Failed to parse tool parameters JSON: " << e.what() + << ", tool: " << tool.function.name; + function_json["parameters"] = nlohmann::json::object(); + } + + tool_json["function"] = function_json; + tools_json.push_back(tool_json); + } + + return tools_json.dump(2); +} + +std::string ToolsConverter::convert_tools_to_prompt( + const std::vector& tools, + const std::string& tool_choice) { + if (tools.empty()) { + return ""; + } + + std::ostringstream prompt; + prompt << "You have access to the following functions:\n\n"; + + for (const auto& tool : tools) { + prompt << "Function: " << tool.function.name << "\n"; + prompt << "Description: " << tool.function.description << "\n"; + + try { + if (!tool.function.parameters.empty()) { + auto params_json = nlohmann::json::parse(tool.function.parameters); + prompt << "Parameters: " << params_json.dump(2) << "\n"; + } + } catch (const nlohmann::json::exception& e) { + LOG(WARNING) << "Failed to parse parameters for tool: " << tool.function.name; + } + + prompt << "\n"; + } + + if (tool_choice == "required") { + prompt << "You MUST call one of the above functions. "; + } else if (tool_choice == "auto") { + prompt << "You may call one of the above functions if needed. 
"; + } + + prompt << "To call a function, respond with a JSON object in the following format:\n"; + prompt << "{\n"; + prompt << " \"tool_calls\": [\n"; + prompt << " {\n"; + prompt << " \"id\": \"call_\",\n"; + prompt << " \"type\": \"function\",\n"; + prompt << " \"function\": {\n"; + prompt << " \"name\": \"function_name\",\n"; + prompt << " \"arguments\": \"{\\\"param1\\\": \\\"value1\\\"}\"\n"; + prompt << " }\n"; + prompt << " }\n"; + prompt << " ]\n"; + prompt << "}\n\n"; + + return prompt.str(); +} + +std::vector ToolsConverter::parse_tool_calls_from_text( + const std::string& model_output) { + std::vector tool_calls; + + auto json_blocks = extract_json_blocks(model_output); + + for (const auto& json_block : json_blocks) { + auto parsed_calls = parse_tool_calls_from_json(json_block); + tool_calls.insert(tool_calls.end(), parsed_calls.begin(), parsed_calls.end()); + } + + return tool_calls; +} + +std::vector ToolsConverter::parse_tool_calls_from_json( + const std::string& json_str) { + std::vector tool_calls; + + try { + auto json_obj = nlohmann::json::parse(clean_json_string(json_str)); + + if (json_obj.contains("tool_calls") && json_obj["tool_calls"].is_array()) { + for (const auto& call_json : json_obj["tool_calls"]) { + auto tool_call = parse_single_function_call(call_json); + if (tool_call.has_value()) { + tool_calls.push_back(tool_call.value()); + } + } + } + else if (json_obj.contains("function") || json_obj.contains("name")) { + auto tool_call = parse_single_function_call(json_obj); + if (tool_call.has_value()) { + tool_calls.push_back(tool_call.value()); + } + } + } catch (const nlohmann::json::exception& e) { + LOG(WARNING) << "Failed to parse tool calls JSON: " << e.what(); + } + + return tool_calls; +} + +bool ToolsConverter::validate_tool_call_arguments( + const ToolCall& tool_call, + const std::vector& available_tools) { + auto tool_it = std::find_if(available_tools.begin(), available_tools.end(), + [&](const Tool& tool) { + return 
tool.function.name == tool_call.function_name; + }); + + if (tool_it == available_tools.end()) { + LOG(WARNING) << "Tool not found: " << tool_call.function_name; + return false; + } + + try { + nlohmann::json::parse(tool_call.function_arguments); + } catch (const nlohmann::json::exception& e) { + LOG(WARNING) << "Invalid arguments JSON for tool " << tool_call.function_name + << ": " << e.what(); + return false; + } + + return validate_json_schema(tool_call.function_arguments, tool_it->function.parameters); +} + +std::string ToolsConverter::generate_tool_call_id() { + static std::random_device rd; + static std::mt19937 gen(rd()); + static std::uniform_int_distribution<> dis(100000, 999999); + + return "call_" + std::to_string(dis(gen)); +} + +std::string ToolsConverter::format_tool_choice(const std::string& tool_choice) { + if (tool_choice == "auto" || tool_choice == "none" || tool_choice == "required") { + return tool_choice; + } + return "auto"; + + +std::optional ToolsConverter::parse_single_function_call( + const nlohmann::json& json_obj) { + try { + ToolCall tool_call; + + if (json_obj.contains("id")) { + tool_call.id = json_obj["id"].get(); + } else { + tool_call.id = generate_tool_call_id(); + } + + if (json_obj.contains("type")) { + tool_call.type = json_obj["type"].get(); + } else { + tool_call.type = "function"; + } + + if (json_obj.contains("function")) { + const auto& func_json = json_obj["function"]; + if (func_json.contains("name")) { + tool_call.function_name = func_json["name"].get(); + } + if (func_json.contains("arguments")) { + if (func_json["arguments"].is_string()) { + tool_call.function_arguments = func_json["arguments"].get(); + } else { + tool_call.function_arguments = func_json["arguments"].dump(); + } + } + } + else if (json_obj.contains("name")) { + tool_call.function_name = json_obj["name"].get(); + if (json_obj.contains("arguments")) { + if (json_obj["arguments"].is_string()) { + tool_call.function_arguments = 
json_obj["arguments"].get(); + } else { + tool_call.function_arguments = json_obj["arguments"].dump(); + } + } + } + + if (!tool_call.function_name.empty()) { + return tool_call; + } + } catch (const nlohmann::json::exception& e) { + LOG(WARNING) << "Failed to parse single function call: " << e.what(); + } + + return std::nullopt; +} + +bool ToolsConverter::validate_json_schema( + const std::string& json_str, + const std::string& schema_str) { + try { + nlohmann::json::parse(json_str); + if (!schema_str.empty()) { + nlohmann::json::parse(schema_str); + } + return true; + } catch (const nlohmann::json::exception& e) { + return false; + } + + +} + +std::string ToolsConverter::clean_json_string(const std::string& raw_json) { + std::string cleaned = raw_json; + + cleaned = std::string(absl::StripAsciiWhitespace(cleaned)); + + cleaned = absl::StrReplaceAll(cleaned, {{"```json", ""}, {"```", ""}}); + + return cleaned; +} + +std::vector ToolsConverter::extract_json_blocks(const std::string& text) { + std::vector json_blocks; + + std::regex json_regex(R"(\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\})"); + std::sregex_iterator iter(text.begin(), text.end(), json_regex); + std::sregex_iterator end; + + for (; iter != end; ++iter) { + std::string match = iter->str(); + try { + nlohmann::json::parse(match); + json_blocks.push_back(match); + } catch (const nlohmann::json::exception&) { + continue; + } + } + + if (json_blocks.empty()) { + try { + nlohmann::json::parse(clean_json_string(text)); + json_blocks.push_back(clean_json_string(text)); + } catch (const nlohmann::json::exception&) { + } + } + + return json_blocks; +} + +} // namespace llm \ No newline at end of file diff --git a/xllm/core/chat_template/tools_converter.h b/xllm/core/chat_template/tools_converter.h new file mode 100644 index 00000000..650bd670 --- /dev/null +++ b/xllm/core/chat_template/tools_converter.h @@ -0,0 +1,46 @@ +#pragma once + +#include +#include +#include +#include +#include "chat_template.h" + +namespace llm { 
+ +class ToolsConverter { + public: + static std::string convert_tools_to_json(const std::vector& tools); + + static std::string convert_tools_to_prompt( + const std::vector& tools, + const std::string& tool_choice = "auto"); + + static std::vector parse_tool_calls_from_text( + const std::string& model_output); + + static std::vector parse_tool_calls_from_json( + const std::string& json_str); + + static bool validate_tool_call_arguments( + const ToolCall& tool_call, + const std::vector& available_tools); + + static std::string generate_tool_call_id(); + + static std::string format_tool_choice(const std::string& tool_choice); + + private: + static std::optional parse_single_function_call( + const nlohmann::json& json_obj); + + static bool validate_json_schema( + const std::string& json_str, + const std::string& schema_str); + + static std::string clean_json_string(const std::string& raw_json); + + static std::vector extract_json_blocks(const std::string& text); +}; + +} // namespace llm \ No newline at end of file diff --git a/xllm/core/framework/chat_template/jinja_chat_template.cpp b/xllm/core/framework/chat_template/jinja_chat_template.cpp index 119b7a2f..4dbd8069 100644 --- a/xllm/core/framework/chat_template/jinja_chat_template.cpp +++ b/xllm/core/framework/chat_template/jinja_chat_template.cpp @@ -23,6 +23,12 @@ JinjaChatTemplate::JinjaChatTemplate(const TokenizerArgs& args) : args_(args) { std::optional JinjaChatTemplate::apply( const ChatMessages& messages) const { + const std::vector empty_tools; + return apply(messages, empty_tools); +} + +std::optional JinjaChatTemplate::apply( + const ChatMessages& messages, const std::vector& tools) const { // convert the messages to json object nlohmann::ordered_json messages_json = nlohmann::json::array(); for (const auto& message : messages) { @@ -38,14 +44,37 @@ std::optional JinjaChatTemplate::apply( messages_json.push_back(message_json); } - // apply the template - return apply(messages_json); + + // convert tools 
to json object + nlohmann::ordered_json tools_json = nlohmann::json::array(); + if (!tools.empty()) { + try { + // Use ToolsConverter to convert tools to JSON string, then parse it + std::string tools_json_str = ToolsConverter::convert_tools_to_json(tools); + tools_json = nlohmann::json::parse(tools_json_str); + } catch (const std::exception& e) { + LOG(WARNING) << "Failed to convert tools to JSON: " << e.what(); + // Continue with empty tools array + } + } + + // apply the template with tools + return apply(messages_json, tools_json); } std::optional JinjaChatTemplate::apply( nlohmann::ordered_json& messages) const { + // Call the overloaded method with empty tools + nlohmann::ordered_json empty_tools = nlohmann::json::array(); + return apply(messages, empty_tools); +} + +std::optional JinjaChatTemplate::apply( + nlohmann::ordered_json& messages, + const nlohmann::ordered_json& tools) const { minja::chat_template_inputs input; input.messages = messages; + input.tools = tools; input.add_generation_prompt = true; minja::chat_template_options options; diff --git a/xllm/core/framework/request/request_params.h b/xllm/core/framework/request/request_params.h index 225ece91..19330523 100644 --- a/xllm/core/framework/request/request_params.h +++ b/xllm/core/framework/request/request_params.h @@ -100,6 +100,10 @@ struct RequestParams { // decode address. 
std::string decode_address; + + std::vector tools; + std::string tool_choice = "auto"; + bool has_tools() const { return !tools.empty(); } }; } // namespace xllm diff --git a/xllm/core/runtime/llm_master.cpp b/xllm/core/runtime/llm_master.cpp index ec2af81b..f9a0cdbb 100644 --- a/xllm/core/runtime/llm_master.cpp +++ b/xllm/core/runtime/llm_master.cpp @@ -422,7 +422,14 @@ std::shared_ptr LLMMaster::generate_request( const RequestParams& sp, OutputCallback callback) { Timer timer; - auto prompt = chat_template_->apply(messages); + std::optional prompt; + if (sp.has_tools()) { + auto tools = sp.tools; + prompt = chat_template_->apply(messages, tools); + } else { + prompt = chat_template_->apply(messages); + } + if (!prompt.has_value()) { CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, "Failed to construct prompt from messages"); diff --git a/xllm/proto/chat.proto b/xllm/proto/chat.proto index e59ab705..0fbb820f 100644 --- a/xllm/proto/chat.proto +++ b/xllm/proto/chat.proto @@ -4,6 +4,19 @@ option go_package = "jd.com/jd-infer/xllm;xllm"; package xllm.proto; import "common.proto"; +import "google/protobuf/struct.proto"; + +message Function { + string name = 1; + string description = 2; + google.protobuf.Struct parameters = 3; + // string parameters = 3; +} + +message Tool { + string type = 1; // "function" + Function function = 2; +} message ChatMessage { // the role of the messages author. One of "system", "user", "assistant". @@ -107,6 +120,9 @@ message ChatRequest { optional string service_request_id = 25; Routing routing = 26; + + repeated Tool tools = 27; + optional string tool_choice = 28; } message ChatLogProbData { From ff3bc00bd098e2a88c0b82e284e3e4ea2be5dca7 Mon Sep 17 00:00:00 2001 From: dengyingxu1 Date: Wed, 6 Aug 2025 23:45:24 +0800 Subject: [PATCH 2/9] feat: add Qwen3 non-streaming tool call support. 
--- xllm/api_service/chat_service_impl.cpp | 244 +++++++++++++++- xllm/core/CMakeLists.txt | 1 + xllm/core/common/global_flags.cpp | 7 + xllm/core/common/global_flags.h | 2 + xllm/core/function_call/CMakeLists.txt | 15 + xllm/core/function_call/base_detector.cpp | 47 +++ xllm/core/function_call/base_detector.h | 99 +++++++ xllm/core/function_call/function_call.h | 69 +++++ .../function_call/function_call_parser.cpp | 267 ++++++++++++++++++ .../core/function_call/function_call_parser.h | 83 ++++++ xllm/core/function_call/qwen25_detector.h | 199 +++++++++++++ xllm/core/function_call/types.h | 160 +++++++++++ xllm/core/util/uuid.cpp | 10 + xllm/core/util/uuid.h | 4 + xllm/proto/chat.proto | 13 + xllm/xllm.cpp | 3 +- 16 files changed, 1214 insertions(+), 9 deletions(-) create mode 100644 xllm/core/function_call/CMakeLists.txt create mode 100644 xllm/core/function_call/base_detector.cpp create mode 100644 xllm/core/function_call/base_detector.h create mode 100644 xllm/core/function_call/function_call.h create mode 100644 xllm/core/function_call/function_call_parser.cpp create mode 100644 xllm/core/function_call/function_call_parser.h create mode 100644 xllm/core/function_call/qwen25_detector.h create mode 100644 xllm/core/function_call/types.h diff --git a/xllm/api_service/chat_service_impl.cpp b/xllm/api_service/chat_service_impl.cpp index aec356b5..6fdb97af 100644 --- a/xllm/api_service/chat_service_impl.cpp +++ b/xllm/api_service/chat_service_impl.cpp @@ -25,6 +25,7 @@ #include "request/request_params.h" #include "util/utils.h" #include +#include "function_call/function_call.h" namespace xllm { namespace { @@ -282,6 +283,193 @@ bool send_result_to_client_brpc(std::shared_ptr call, return call->write_and_finish(response); } +std::string generate_tool_call_id() { + return "call_" + llm::generate_uuid(); +} + +void convert_tool_calls_to_proto( + const std::vector& tool_calls, + proto::ChatMessage* message) { + + for (const auto& call : tool_calls) { + if 
(!call.is_valid()) { + LOG(WARNING) << "Invalid tool call: " << call.to_string(); + continue; + } + + auto* proto_tool_call = message->add_tool_calls(); + proto_tool_call->set_id(call.id.empty() ? generate_tool_call_id() : call.id); + proto_tool_call->set_type(call.type.empty() ? "function" : call.type); + + auto* function = proto_tool_call->mutable_function(); + function->set_name(call.function_name); + function->set_arguments(call.function_arguments); + + LOG(INFO) << "Converted tool call: " << call.function_name + << " with args: " << call.function_arguments; + } +} + + +bool handle_function_call_response( + const RequestOutput& req_output, + std::shared_ptr call_data, + const std::string& request_id, + int64_t created_time, + const std::string& model, + const std::string& parser_format) { + + auto& fc_interface = llm::function_call::FunctionCallInterface::getInstance(); + + fc_interface.setPreferredFormat(parser_format); + + auto parse_result = fc_interface.parse(req_output.outputs[0].text); + + if (parse_result.has_tool_calls()) { + auto& response = call_data->response(); + response.set_object("chat.completion"); + response.set_id(request_id); + response.set_created(created_time); + response.set_model(model); + + auto* choice = response.add_choices(); + choice->set_index(0); + choice->set_finish_reason("tool_calls"); + + auto* message = choice->mutable_message(); + message->set_role("assistant"); + + if (!parse_result.normal_text.empty()) { + std::string cleaned_text = parse_result.normal_text; + boost::algorithm::trim(cleaned_text); + if (!cleaned_text.empty()) { + message->set_content(cleaned_text); + } + } + + convert_tool_calls_to_proto(parse_result.tool_calls, message); + + if (req_output.usage.has_value()) { + const auto& usage = req_output.usage.value(); + auto* proto_usage = response.mutable_usage(); + proto_usage->set_prompt_tokens(static_cast(usage.num_prompt_tokens)); + proto_usage->set_completion_tokens(static_cast(usage.num_generated_tokens)); + 
proto_usage->set_total_tokens(static_cast(usage.num_total_tokens)); + } + + LOG(INFO) << "Function call detected: " << parse_result.tool_calls.size() << " calls"; + return call_data->write_and_finish(response); + + } else if (parse_result.has_error) { + LOG(WARNING) << "Function call parsing error: " << parse_result.error_message; + return send_result_to_client_brpc(call_data, request_id, created_time, model, req_output); + } else { + return send_result_to_client_brpc(call_data, request_id, created_time, model, req_output); + } +} + + +bool handle_streaming_function_calls( + const RequestOutput& req_output, + std::shared_ptr call_data, + std::unordered_set* first_message_sent, + const std::string& request_id, + int64_t created_time, + const std::string& model, + const std::string& parser_format, + bool include_usage) { + + auto& fc_interface = llm::function_call::FunctionCallInterface::getInstance(); + + fc_interface.setPreferredFormat(parser_format); + + auto stream_result = fc_interface.parseStreaming(req_output.outputs[0].text); + + auto& response = call_data->response(); + + for (const auto& completed_call : stream_result.completed_calls) { + const auto& index = req_output.outputs[0].index; + if (first_message_sent->find(index) == first_message_sent->end()) { + response.Clear(); + response.set_object("chat.completion.chunk"); + response.set_id(request_id); + response.set_created(created_time); + response.set_model(model); + auto* choice = response.add_choices(); + choice->set_index(index); + auto* message = choice->mutable_delta(); + message->set_role("assistant"); + first_message_sent->insert(index); + if (!call_data->write(response)) { + return false; + } + } + + response.Clear(); + response.set_object("chat.completion.chunk"); + response.set_id(request_id); + response.set_created(created_time); + response.set_model(model); + + auto* choice = response.add_choices(); + choice->set_index(index); + + auto* delta = choice->mutable_delta(); + auto* tool_call = 
delta->add_tool_calls(); + tool_call->set_id(completed_call.id.empty() ? generate_tool_call_id() : completed_call.id); + tool_call->set_type(completed_call.type.empty() ? "function" : completed_call.type); + + auto* function = tool_call->mutable_function(); + function->set_name(completed_call.function_name); + function->set_arguments(completed_call.function_arguments); + + if (!call_data->write(response)) { + return false; + } + } + + if (!stream_result.completed_calls.empty()) { + response.Clear(); + response.set_object("chat.completion.chunk"); + response.set_id(request_id); + response.set_created(created_time); + response.set_model(model); + + auto* choice = response.add_choices(); + choice->set_index(req_output.outputs[0].index); + choice->set_finish_reason("tool_calls"); + choice->mutable_delta(); + + if (!call_data->write(response)) { + return false; + } + } + + if (include_usage && req_output.usage.has_value()) { + response.Clear(); + const auto& usage = req_output.usage.value(); + response.set_object("chat.completion.chunk"); + response.set_id(request_id); + response.set_created(created_time); + response.set_model(model); + auto* proto_usage = response.mutable_usage(); + proto_usage->set_prompt_tokens(static_cast(usage.num_prompt_tokens)); + proto_usage->set_completion_tokens(static_cast(usage.num_generated_tokens)); + proto_usage->set_total_tokens(static_cast(usage.num_total_tokens)); + if (!call_data->write(response)) { + return false; + } + } + + if (req_output.finished || req_output.cancelled) { + response.Clear(); + return call_data->finish(); + } + + return true; +} + + } // namespace ChatServiceImpl::ChatServiceImpl(LLMMaster* master, @@ -452,7 +640,8 @@ void MMChatServiceImpl::process_async(std::shared_ptr call) { include_usage = include_usage, first_message_sent = std::unordered_set(), request_id = request_params.request_id, - created_time = absl::ToUnixSeconds(absl::Now())]( + created_time = absl::ToUnixSeconds(absl::Now()), + has_tools = 
request_params.has_tools()]( const RequestOutput& req_output) mutable -> bool { if (req_output.status.has_value()) { const auto& status = req_output.status.value(); @@ -471,15 +660,54 @@ void MMChatServiceImpl::process_async(std::shared_ptr call) { master->get_rate_limiter()->decrease_one_request(); } + std::string parser_format = master->options().tool_call_parser().value_or(""); if (stream) { - return send_delta_to_client_brpc(call, - include_usage, - &first_message_sent, - request_id, - created_time, - model, - req_output); + if (has_tools && !parser_format.empty()) { + LOG(ERROR) << "Tool call does not support streaming output"; + return send_delta_to_client_brpc(call_data, + include_usage, + &first_message_sent, + request_id, + created_time, + model, + req_output); + // return handle_streaming_function_calls( + // req_output, call_data, &first_message_sent, + // request_id, created_time, model, parser_format, include_usage); + } else { + // send delta to client + return send_delta_to_client_brpc(call_data, + include_usage, + &first_message_sent, + request_id, + created_time, + model, + req_output); + } } + + if (has_tools && !parser_format.empty()) { + //debug2 + auto& interface = llm::function_call::FunctionCallInterface::getInstance(); + + if (parser_format != "auto") { + interface.setPreferredFormat(parser_format); + } + auto result = interface.parse(req_output.outputs[0].text); + + std::cerr << "正常文本: " << result.normal_text << std::endl; + + for (const auto& call : result.tool_calls) { + std::cerr << "函数: " << call.function_name << std::endl; + std::cerr << "参数: " << call.function_arguments << std::endl; + } + } + if (has_tools && !parser_format.empty()) { + return handle_function_call_response( + req_output, call_data, + request_id, created_time, model, parser_format); + } + return send_result_to_client_brpc( call, request_id, created_time, model, req_output); }); diff --git a/xllm/core/CMakeLists.txt b/xllm/core/CMakeLists.txt index d0c8f682..ff88a505 
100644 --- a/xllm/core/CMakeLists.txt +++ b/xllm/core/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(common) +add_subdirectory(function_call) add_subdirectory(distributed_runtime) add_subdirectory(framework) add_subdirectory(kernels) diff --git a/xllm/core/common/global_flags.cpp b/xllm/core/common/global_flags.cpp index e858a411..78a3fee7 100644 --- a/xllm/core/common/global_flags.cpp +++ b/xllm/core/common/global_flags.cpp @@ -174,6 +174,13 @@ DEFINE_bool(enable_atb_comm_multiprocess, false, "whether to use multiprocess mode."); +// --- function call config --- + +DEFINE_string(tool_call_parser, + "", + "Specify the parser for handling tool-call interactions. " + "Options include: 'qwen25'"); + DEFINE_bool(enable_atb_spec_kernel, false, "whether to use ATB speculative kernel."); diff --git a/xllm/core/common/global_flags.h b/xllm/core/common/global_flags.h index 67d903f2..d69b41cc 100644 --- a/xllm/core/common/global_flags.h +++ b/xllm/core/common/global_flags.h @@ -98,6 +98,8 @@ DECLARE_bool(disable_custom_kernels); DECLARE_bool(enable_atb_comm_multiprocess); +DECLARE_string(tool_call_parser); + DECLARE_bool(enable_atb_spec_kernel); DECLARE_string(etcd_addr); diff --git a/xllm/core/function_call/CMakeLists.txt b/xllm/core/function_call/CMakeLists.txt new file mode 100644 index 00000000..5bd6a6e4 --- /dev/null +++ b/xllm/core/function_call/CMakeLists.txt @@ -0,0 +1,15 @@ +cc_library( + NAME function_call + HDRS + types.h + base_detector.h + qwen25_detector.h + function_call_parser.h + function_call.h + SRCS + base_detector.cpp + function_call_parser.cpp + DEPS + nlohmann_json::nlohmann_json + glog::glog +) \ No newline at end of file diff --git a/xllm/core/function_call/base_detector.cpp b/xllm/core/function_call/base_detector.cpp new file mode 100644 index 00000000..bda66683 --- /dev/null +++ b/xllm/core/function_call/base_detector.cpp @@ -0,0 +1,47 @@ +#include "base_detector.h" +#include "qwen25_detector.h" +#include +#include + +namespace llm { +namespace 
function_call { + +std::unique_ptr DetectorFactory::create_detector(ModelFormat format) { + switch (format) { + case ModelFormat::QWEN25: + return std::make_unique(); + default: + return nullptr; + } +} + +std::vector> DetectorFactory::create_all_detectors() { + std::vector> detectors; + + detectors.push_back(std::make_unique()); + + return detectors; +} + +ModelFormat DetectorFactory::infer_format_from_model_name(const std::string& model_name) { + std::string lower_name = model_name; + std::transform(lower_name.begin(), lower_name.end(), lower_name.begin(), ::tolower); + + if (lower_name.find("qwen") != std::string::npos) { + return ModelFormat::QWEN25; + } + + return ModelFormat::UNKNOWN; +} + +std::string DetectorFactory::get_format_name(ModelFormat format) { + switch (format) { + case ModelFormat::QWEN25: + return "qwen25"; + default: + return "unknown"; + } +} + +} // namespace function_call +} // namespace llm \ No newline at end of file diff --git a/xllm/core/function_call/base_detector.h b/xllm/core/function_call/base_detector.h new file mode 100644 index 00000000..c934085a --- /dev/null +++ b/xllm/core/function_call/base_detector.h @@ -0,0 +1,99 @@ +#pragma once + +#include "types.h" +#include +#include +#include +#include "common/uuid.h" +namespace llm { +namespace function_call { + +class BaseFormatDetector { +public: + virtual ~BaseFormatDetector() = default; + + virtual bool detect(const std::string& text) const = 0; + + virtual FormatDetectionResult detect_format(const std::string& text) const = 0; + + virtual ParseResult parse_calls(const std::string& text) const = 0; + + virtual StreamingParseResult parse_streaming(const std::string& chunk) = 0; + + virtual void reset_streaming_state() = 0; + + virtual ModelFormat get_format() const = 0; + + virtual std::string get_format_name() const = 0; + + virtual StructureInfo get_structure_info() const = 0; + + virtual EBNFGrammar generate_ebnf_grammar(const ConstraintOptions& options = {}) const = 0; + + 
virtual bool validate_call_format(const ToolCallItem& call) const = 0; + +protected: + std::string generate_call_id() const { + return "call_" + llm::generate_uuid(); + } + + std::string clean_json_string(const std::string& json_str) const { + std::string cleaned = json_str; + size_t start = cleaned.find_first_not_of(" \t\n\r"); + if (start == std::string::npos) return ""; + size_t end = cleaned.find_last_not_of(" \t\n\r"); + cleaned = cleaned.substr(start, end - start + 1); + return cleaned; + } + + bool is_valid_json(const std::string& json_str) const { + try { + int brace_count = 0; + bool in_string = false; + bool escaped = false; + + for (char c : json_str) { + if (escaped) { + escaped = false; + continue; + } + + if (c == '\\') { + escaped = true; + continue; + } + + if (c == '"') { + in_string = !in_string; + continue; + } + + if (!in_string) { + if (c == '{') brace_count++; + else if (c == '}') brace_count--; + } + } + + return brace_count == 0 && !in_string; + } catch (...) { + return false; + } + } +}; + +class DetectorFactory { +public: + static std::unique_ptr create_detector(ModelFormat format); + + static std::vector> create_all_detectors(); + + static ModelFormat infer_format_from_model_name(const std::string& model_name); + + static std::string get_format_name(ModelFormat format); + +private: + DetectorFactory() = default; +}; + +} // namespace function_call +} // namespace llm \ No newline at end of file diff --git a/xllm/core/function_call/function_call.h b/xllm/core/function_call/function_call.h new file mode 100644 index 00000000..d789cc0b --- /dev/null +++ b/xllm/core/function_call/function_call.h @@ -0,0 +1,69 @@ +#pragma once + +#include "types.h" +#include "function_call_parser.h" + +namespace llm { +namespace function_call { + +class FunctionCallInterface { +public: + static FunctionCallInterface& getInstance() { + static FunctionCallInterface instance; + return instance; + } + + ParseResult parse(const std::string& text) { + return 
parser_.parse_auto(text); + } + + ParseResult parse(const std::string& text, const std::string& format) { + ModelFormat model_format = ModelFormat::UNKNOWN; + if (format == "qwen" || format == "qwen25") { + model_format = ModelFormat::QWEN25; + } + return parser_.parse_with_format(text, model_format); + } + + StreamingParseResult parseStreaming(const std::string& chunk) { + return parser_.parse_streaming_auto(chunk); + } + + bool hasFunction(const std::string& text) { + return utils::has_function_calls(text); + } + + std::string detectFormat(const std::string& text) { + return utils::detect_best_format(text); + } + + std::string generateConstraints(const std::vector& function_names, + const std::string& format = "auto") { + return utils::generate_ebnf_constraints(function_names, format); + } + + void setPreferredFormat(const std::string& model_name) { + parser_.set_preferred_format(model_name); + } + + void resetStreamingState() { + parser_.reset_all_streaming_states(); + } + +private: + FunctionCallInterface() = default; + FunctionCallParser parser_; +}; + +} // namespace function_call + +inline function_call::ParseResult parse_function_calls(const std::string& text) { + return function_call::FunctionCallInterface::getInstance().parse(text); +} + +inline bool has_function_calls(const std::string& text) { + auto result = function_call::FunctionCallInterface::getInstance().parse(text); + return result.has_tool_calls(); +} + +} // namespace llm \ No newline at end of file diff --git a/xllm/core/function_call/function_call_parser.cpp b/xllm/core/function_call/function_call_parser.cpp new file mode 100644 index 00000000..1a419311 --- /dev/null +++ b/xllm/core/function_call/function_call_parser.cpp @@ -0,0 +1,267 @@ +#include "function_call_parser.h" +#include "base_detector.h" +#include +#include + +namespace llm { +namespace function_call { + +FunctionCallParser::FunctionCallParser() : preferred_format_(ModelFormat::UNKNOWN) { + initialize_detectors(); +} + +void 
FunctionCallParser::initialize_detectors() { + detectors_ = DetectorFactory::create_all_detectors(); + + for (auto& detector : detectors_) { + if (detector) { + format_detectors_[detector->get_format()] = std::move(detector); + } + } + detectors_.clear(); +} + +void FunctionCallParser::set_preferred_format(ModelFormat format) { + preferred_format_ = format; +} + +void FunctionCallParser::set_preferred_format(const std::string& model_name) { + preferred_format_ = infer_format_from_model_name(model_name); +} + +ParseResult FunctionCallParser::parse_auto(const std::string& text) { + if (preferred_format_ != ModelFormat::UNKNOWN) { + auto detector = get_detector(preferred_format_); + if (detector && detector->detect(text)) { + return detector->parse_calls(text); + } + } + + for (auto& [format, detector] : format_detectors_) { + if (detector && detector->detect(text)) { + return detector->parse_calls(text); + } + } + + return {}; +} + +ParseResult FunctionCallParser::parse_with_format(const std::string& text, ModelFormat format) { + auto detector = get_detector(format); + if (!detector) { + LOG(WARNING) << "Unsupported format: " << static_cast(format); + return {}; + } + + return detector->parse_calls(text); +} + +StreamingParseResult FunctionCallParser::parse_streaming_auto(const std::string& chunk) { + if (preferred_format_ != ModelFormat::UNKNOWN) { + auto detector = get_detector(preferred_format_); + if (detector) { + return detector->parse_streaming(chunk); + } + } + + for (auto& [format, detector] : format_detectors_) { + if (detector) { + auto result = detector->parse_streaming(chunk); + if (result.has_completed_calls() || result.has_partial_call()) { + return result; + } + } + } + + StreamingParseResult empty_result; + return empty_result; +} + +StreamingParseResult FunctionCallParser::parse_streaming_with_format(const std::string& chunk, ModelFormat format) { + auto detector = get_detector(format); + if (!detector) { + StreamingParseResult result; + 
result.has_error = true; + result.error_message = "Unsupported format"; + return result; + } + + return detector->parse_streaming(chunk); +} + +std::vector FunctionCallParser::detect_formats(const std::string& text) { + std::vector results; + + for (auto& [format, detector] : format_detectors_) { + if (detector) { + auto result = detector->detect_format(text); + results.push_back(result); + } + } + + std::sort(results.begin(), results.end(), + [](const FormatDetectionResult& a, const FormatDetectionResult& b) { + return a.confidence > b.confidence; + }); + + return results; +} + +FormatDetectionResult FunctionCallParser::get_best_format(const std::string& text) { + auto results = detect_formats(text); + if (!results.empty()) { + return results[0]; + } + + FormatDetectionResult empty_result; + return empty_result; +} + +bool FunctionCallParser::validate_calls(const std::vector& calls, ModelFormat format) { + auto detector = get_detector(format); + if (!detector) { + return false; + } + + for (const auto& call : calls) { + if (!detector->validate_call_format(call)) { + return false; + } + } + + return true; +} + +std::string FunctionCallParser::generate_constraints(const std::vector& function_names, + ModelFormat format, + const ConstraintOptions& options) { + ModelFormat target_format = format; + if (target_format == ModelFormat::UNKNOWN) { + target_format = preferred_format_; + } + if (target_format == ModelFormat::UNKNOWN) { + target_format = ModelFormat::QWEN25; // 默认格式 + } + + auto detector = get_detector(target_format); + if (!detector) { + return ""; + } + + ConstraintOptions modified_options = options; + modified_options.allowed_functions = function_names; + + auto grammar = detector->generate_ebnf_grammar(modified_options); + return grammar.to_string(); +} + +void FunctionCallParser::reset_all_streaming_states() { + for (auto& [format, detector] : format_detectors_) { + if (detector) { + detector->reset_streaming_state(); + } + } +} + +void 
FunctionCallParser::reset_streaming_state(ModelFormat format) { + auto detector = get_detector(format); + if (detector) { + detector->reset_streaming_state(); + } +} + +std::vector FunctionCallParser::get_supported_formats() const { + std::vector formats; + for (const auto& [format, detector] : format_detectors_) { + if (detector) { + formats.push_back(format); + } + } + return formats; +} + +std::string FunctionCallParser::get_format_name(ModelFormat format) const { + return DetectorFactory::get_format_name(format); +} + +bool FunctionCallParser::is_format_supported(ModelFormat format) const { + return format_detectors_.find(format) != format_detectors_.end(); +} + +BaseFormatDetector* FunctionCallParser::get_detector(ModelFormat format) { + auto it = format_detectors_.find(format); + if (it != format_detectors_.end()) { + return it->second.get(); + } + return nullptr; +} + +ModelFormat FunctionCallParser::infer_format_from_model_name(const std::string& model_name) { + return DetectorFactory::infer_format_from_model_name(model_name); +} + +namespace utils { + +std::vector parse_function_calls(const std::string& text) { + static FunctionCallParser parser; + return parser.parse_auto(text).tool_calls; +} + +std::vector parse_function_calls(const std::string& text, const std::string& format) { + static FunctionCallParser parser; + + ModelFormat model_format = ModelFormat::UNKNOWN; + if (format == "qwen25" || format == "qwen") { + model_format = ModelFormat::QWEN25; + } + + if (model_format == ModelFormat::UNKNOWN) { + return parser.parse_auto(text).tool_calls; + } + + return parser.parse_with_format(text, model_format).tool_calls; +} + +bool has_function_calls(const std::string& text) { + static FunctionCallParser parser; + auto calls = parser.parse_auto(text); + return calls.has_tool_calls(); +} + +std::string detect_best_format(const std::string& text) { + static FunctionCallParser parser; + auto result = parser.get_best_format(text); + return 
parser.get_format_name(result.format); +} + +std::string generate_ebnf_constraints(const std::vector& function_names, + const std::string& format) { + static FunctionCallParser parser; + + ModelFormat model_format = ModelFormat::UNKNOWN; + if (format == "qwen25" || format == "qwen") { + model_format = ModelFormat::QWEN25; + } + + return parser.generate_constraints(function_names, model_format); +} + +bool validate_function_call_format(const ToolCallItem& call, const std::string& format) { + static FunctionCallParser parser; + + ModelFormat model_format = ModelFormat::UNKNOWN; + if (format == "qwen25" || format == "qwen") { + model_format = ModelFormat::QWEN25; + } + + if (model_format == ModelFormat::UNKNOWN) { + return false; + } + + return parser.validate_calls({call}, model_format); +} + +} // namespace utils + +} // namespace function_call +} // namespace llm \ No newline at end of file diff --git a/xllm/core/function_call/function_call_parser.h b/xllm/core/function_call/function_call_parser.h new file mode 100644 index 00000000..67f9d50c --- /dev/null +++ b/xllm/core/function_call/function_call_parser.h @@ -0,0 +1,83 @@ +#pragma once + +#include "types.h" +#include "base_detector.h" +#include +#include +#include +#include + +namespace llm { +namespace function_call { + +class FunctionCallParser { +private: + std::vector> detectors_; + std::unordered_map> format_detectors_; + ModelFormat preferred_format_; + +public: + FunctionCallParser(); + ~FunctionCallParser() = default; + + FunctionCallParser(const FunctionCallParser&) = delete; + FunctionCallParser& operator=(const FunctionCallParser&) = delete; + + void set_preferred_format(ModelFormat format); + void set_preferred_format(const std::string& model_name); + + ParseResult parse_auto(const std::string& text); + + ParseResult parse_with_format(const std::string& text, ModelFormat format); + + StreamingParseResult parse_streaming_auto(const std::string& chunk); + + StreamingParseResult 
parse_streaming_with_format(const std::string& chunk, ModelFormat format); + + std::vector detect_formats(const std::string& text); + + FormatDetectionResult get_best_format(const std::string& text); + + bool validate_calls(const std::vector& calls, ModelFormat format); + + std::string generate_constraints(const std::vector& function_names, + ModelFormat format = ModelFormat::UNKNOWN, + const ConstraintOptions& options = {}); + + void reset_all_streaming_states(); + + void reset_streaming_state(ModelFormat format); + + std::vector get_supported_formats() const; + + std::string get_format_name(ModelFormat format) const; + + bool is_format_supported(ModelFormat format) const; + +private: + void initialize_detectors(); + + BaseFormatDetector* get_detector(ModelFormat format); + + ModelFormat infer_format_from_model_name(const std::string& model_name); +}; + +namespace utils { + +std::vector parse_function_calls(const std::string& text); + +std::vector parse_function_calls(const std::string& text, const std::string& format); + +bool has_function_calls(const std::string& text); + +std::string detect_best_format(const std::string& text); + +std::string generate_ebnf_constraints(const std::vector& function_names, + const std::string& format = "auto"); + +bool validate_function_call_format(const ToolCallItem& call, const std::string& format); + +} // namespace utils + +} // namespace function_call +} // namespace llm \ No newline at end of file diff --git a/xllm/core/function_call/qwen25_detector.h b/xllm/core/function_call/qwen25_detector.h new file mode 100644 index 00000000..ac0c81ea --- /dev/null +++ b/xllm/core/function_call/qwen25_detector.h @@ -0,0 +1,199 @@ +#pragma once + +#include "base_detector.h" +#include +#include + +namespace llm { +namespace function_call { + +class Qwen25Detector : public BaseFormatDetector { +private: + std::regex start_pattern_; + std::regex end_pattern_; + std::regex full_pattern_; + std::string buffer_; + +public: + Qwen25Detector() + 
: start_pattern_(R"(<tool_call>)"), + end_pattern_(R"(</tool_call>)"), + full_pattern_(R"(<tool_call>\s*(\{.*?\})\s*</tool_call>)", + std::regex_constants::ECMAScript) {} + + ~Qwen25Detector() override = default; + + bool detect(const std::string& text) const override { + return std::regex_search(text, start_pattern_) && + std::regex_search(text, end_pattern_); + } + + FormatDetectionResult detect_format(const std::string& text) const override { + FormatDetectionResult result; + result.format = ModelFormat::QWEN25; + + bool has_start = std::regex_search(text, start_pattern_); + bool has_end = std::regex_search(text, end_pattern_); + + if (has_start && has_end) { + result.confidence = 0.95; + result.reason = "Found complete <tool_call> tags"; + } else if (has_start) { + result.confidence = 0.7; + result.reason = "Found opening <tool_call> tag"; + } else { + result.confidence = 0.0; + result.reason = "No Qwen2.5 format markers found"; + } + + result.structure_info = get_structure_info(); + return result; + } + + ParseResult parse_calls(const std::string& text) const override { + ParseResult result; + std::string normal_text = text; + + std::sregex_iterator iter(text.begin(), text.end(), full_pattern_); + std::sregex_iterator end; + + std::vector<std::pair<size_t, size_t>> match_positions; + + for (; iter != end; ++iter) { + const std::smatch& match = *iter; + std::string json_content = match[1].str(); + + match_positions.push_back({match.position(), match.length()}); + + ToolCallItem call; + call.id = generate_call_id(); + call.type = "function"; + + if (parse_json_content(json_content, call)) { + call.state = ParseState::COMPLETED; + result.tool_calls.push_back(call); + } else { + call.state = ParseState::ERROR; + call.error = "Failed to parse JSON content: " + json_content; + result.tool_calls.push_back(call); + } + } + + for (auto it = match_positions.rbegin(); it != match_positions.rend(); ++it) { + normal_text.erase(it->first, it->second); + } + + result.normal_text = normal_text; + return result; + } + + StreamingParseResult parse_streaming(const 
std::string& chunk) override { + buffer_ += chunk; + StreamingParseResult result; + + auto completed_result = parse_calls(buffer_); + result.completed_calls = completed_result.tool_calls; + result.normal_text = completed_result.normal_text; + + if (std::regex_search(buffer_, start_pattern_) && + !std::regex_search(buffer_, end_pattern_)) { + ToolCallItem partial; + partial.id = "partial_call"; + partial.type = "function"; + partial.state = ParseState::PARSING; + result.partial_call = partial; + } + + if (!completed_result.tool_calls.empty()) { + size_t last_end = buffer_.rfind("</tool_call>"); + if (last_end != std::string::npos) { + last_end += 12; // Length of "</tool_call>" + result.remaining_text = buffer_.substr(last_end); + buffer_ = result.remaining_text; + } + } + + return result; + } + + void reset_streaming_state() override { + buffer_.clear(); + } + + ModelFormat get_format() const override { + return ModelFormat::QWEN25; + } + + std::string get_format_name() const override { + return "qwen25"; + } + + StructureInfo get_structure_info() const override { + StructureInfo info("qwen25", "<tool_call>", "</tool_call>"); + info.patterns["function_call"] = R"(<tool_call>\s*\{.*?\}\s*</tool_call>)"; + info.patterns["json_content"] = R"(\{.*?\})"; + return info; + } + + EBNFGrammar generate_ebnf_grammar(const ConstraintOptions& options) const override { + EBNFGrammar grammar; + grammar.start_rule = "tool_calls"; + + if (options.allow_multiple_calls) { + grammar.add_rule(EBNFRule("tool_calls", "tool_call+")); + } else { + grammar.add_rule(EBNFRule("tool_calls", "tool_call")); + } + + grammar.add_rule(EBNFRule("tool_call", "\"<tool_call>\" ws json_object ws \"</tool_call>\"")); + + grammar.add_rule(EBNFRule("json_object", "\"{\" ws json_members ws \"}\"")); + grammar.add_rule(EBNFRule("json_members", "json_member (\",\" ws json_member)*")); + grammar.add_rule(EBNFRule("json_member", "json_string \":\" ws json_value")); + + grammar.add_rule(EBNFRule("json_value", "json_string | json_object | json_array")); + grammar.add_rule(EBNFRule("json_string", 
"\"\\\"\" [^\"\\\\]* \"\\\"\"")); + grammar.add_rule(EBNFRule("json_array", "\"[\" ws (json_value (\",\" ws json_value)*)? ws \"]\"")); + + grammar.add_rule(EBNFRule("ws", "[ \\t\\n\\r]*", true)); + + return grammar; + } + + bool validate_call_format(const ToolCallItem& call) const override { + if (call.function_name.empty()) return false; + if (call.function_arguments.empty()) return false; + return is_valid_json(call.function_arguments); + } + +private: + bool parse_json_content(const std::string& json_str, ToolCallItem& call) const { + try { + auto json_obj = nlohmann::json::parse(clean_json_string(json_str)); + + if (json_obj.contains("name") && json_obj["name"].is_string()) { + call.function_name = json_obj["name"].get(); + } else { + return false; + } + + if (json_obj.contains("arguments")) { + if (json_obj["arguments"].is_string()) { + call.function_arguments = json_obj["arguments"].get(); + } else { + call.function_arguments = json_obj["arguments"].dump(); + } + } else { + call.function_arguments = "{}"; + } + + return true; + } catch (const nlohmann::json::exception& e) { + call.error = "JSON parse error: " + std::string(e.what()); + return false; + } + } +}; + +} // namespace function_call +} // namespace llm \ No newline at end of file diff --git a/xllm/core/function_call/types.h b/xllm/core/function_call/types.h new file mode 100644 index 00000000..9b14857e --- /dev/null +++ b/xllm/core/function_call/types.h @@ -0,0 +1,160 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace llm { +namespace function_call { + +enum class ParseState { + PENDING, + PARSING, + COMPLETED, + ERROR +}; + +struct ToolCallItem { + std::string id; + std::string type; + std::string function_name; + std::string function_arguments; + ParseState state = ParseState::PENDING; + std::optional error; + + ToolCallItem() = default; + ToolCallItem(const std::string& id, const std::string& type, + const std::string& name, const std::string& args) + : id(id), 
type(type), function_name(name), function_arguments(args) {} + + bool is_valid() const { + return state == ParseState::COMPLETED && + !function_name.empty() && + !function_arguments.empty() && + !error.has_value(); + } + + std::string to_string() const { + std::string result = "ToolCallItem{"; + result += "id='" + id + "', "; + result += "type='" + type + "', "; + result += "function_name='" + function_name + "', "; + result += "arguments='" + function_arguments + "', "; + result += "state=" + std::to_string(static_cast(state)); + if (error.has_value()) { + result += ", error='" + error.value() + "'"; + } + result += "}"; + return result; + } +}; + +struct StreamingParseResult { + std::string normal_text; + std::vector completed_calls; + std::optional partial_call; + std::string remaining_text; + bool has_error = false; + std::string error_message; + + bool has_completed_calls() const { + return !completed_calls.empty(); + } + + bool has_partial_call() const { + return partial_call.has_value(); + } + + void clear() { + completed_calls.clear(); + partial_call.reset(); + remaining_text.clear(); + has_error = false; + error_message.clear(); + } +}; +struct ParseResult { + std::string normal_text; + std::vector tool_calls; + bool has_error = false; + std::string error_message; + + bool has_tool_calls() const { + return !tool_calls.empty(); + } + + void clear() { + normal_text.clear(); + tool_calls.clear(); + has_error = false; + error_message.clear(); + } +}; + +struct StructureInfo { + std::string format_name; + std::string start_marker; + std::string end_marker; + std::unordered_map patterns; + + StructureInfo() = default; + StructureInfo(const std::string& name, const std::string& start, const std::string& end) + : format_name(name), start_marker(start), end_marker(end) {} +}; + +enum class ModelFormat { + QWEN25, + UNKNOWN +}; + +struct FormatDetectionResult { + ModelFormat format = ModelFormat::UNKNOWN; + double confidence = 0.0; + std::string reason; + 
StructureInfo structure_info; + + bool is_valid() const { + return format != ModelFormat::UNKNOWN && confidence > 0.5; + } +}; + +struct EBNFRule { + std::string name; + std::string definition; + bool is_terminal = false; + + EBNFRule() = default; + EBNFRule(const std::string& name, const std::string& def, bool terminal = false) + : name(name), definition(def), is_terminal(terminal) {} +}; + +struct EBNFGrammar { + std::vector rules; + std::string start_rule; + + void add_rule(const EBNFRule& rule) { + rules.push_back(rule); + } + + std::string to_string() const { + std::string result; + for (const auto& rule : rules) { + result += rule.name + " ::= " + rule.definition + "\n"; + } + return result; + } +}; + +struct ConstraintOptions { + bool allow_multiple_calls = true; + bool require_arguments = true; + bool strict_json = true; + std::vector allowed_functions; + + ConstraintOptions() = default; +}; + +} // namespace function_call +} // namespace llm \ No newline at end of file diff --git a/xllm/core/util/uuid.cpp b/xllm/core/util/uuid.cpp index 18c917fb..c86fa969 100644 --- a/xllm/core/util/uuid.cpp +++ b/xllm/core/util/uuid.cpp @@ -18,4 +18,9 @@ std::string ShortUUID::random(size_t len) { return uuid; } +std::string generate_uuid(size_t len) { + static thread_local ShortUUID uuid_generator; + return uuid_generator.random(len); +} + } // namespace xllm \ No newline at end of file diff --git a/xllm/core/util/uuid.h b/xllm/core/util/uuid.h index 0c2604be..42c647c5 100644 --- a/xllm/core/util/uuid.h +++ b/xllm/core/util/uuid.h @@ -18,4 +18,6 @@ class ShortUUID { absl::BitGen gen_; }; +std::string generate_uuid(size_t len = 22); + } // namespace xllm \ No newline at end of file diff --git a/xllm/proto/chat.proto b/xllm/proto/chat.proto index 0fbb820f..6329f9dc 100644 ---
a/xllm/proto/chat.proto +++ b/xllm/proto/chat.proto @@ -18,6 +18,17 @@ message Tool { Function function = 2; } +message ToolCall { + string id = 1; + string type = 2; // "function" + FunctionCall function = 3; +} + +message FunctionCall { + string name = 1; + string arguments = 2; // JSON string +} + message ChatMessage { // the role of the messages author. One of "system", "user", "assistant". optional string role = 1; @@ -30,6 +41,8 @@ message ChatMessage { // TODO: add function call support // FunctionCall function_call = 4; + repeated ToolCall tool_calls = 3; + optional string tool_call_id = 4; } diff --git a/xllm/xllm.cpp b/xllm/xllm.cpp index 018d15eb..1868d9bc 100644 --- a/xllm/xllm.cpp +++ b/xllm/xllm.cpp @@ -110,7 +110,8 @@ int run() { .enable_schedule_overlap(FLAGS_enable_schedule_overlap) .kv_cache_transfer_mode(FLAGS_kv_cache_transfer_mode) .etcd_addr(FLAGS_etcd_addr) - .enable_service_routing(FLAGS_enable_service_routing); + .enable_service_routing(FLAGS_enable_service_routing) + .tool_call_parser(FLAGS_tool_call_parser); InstanceName::name()->set_name(options.instance_name().value_or("")); From 66f405883590ff70eb6de6f09e34ff32a4ac52a6 Mon Sep 17 00:00:00 2001 From: dengyingxu1 Date: Sun, 10 Aug 2025 18:01:40 +0800 Subject: [PATCH 3/9] refactor: optimize tool parameter parsing implementation. 
--- xllm/api_service/api_service.cpp | 5 +- xllm/api_service/chat_service_impl.cpp | 512 ++++++------------ xllm/core/chat_template/tools_converter.cpp | 284 ---------- xllm/core/chat_template/tools_converter.h | 46 -- .../chat_template/jinja_chat_template.cpp | 125 ++++- xllm/core/framework/request/request_params.h | 8 +- xllm/core/runtime/llm_master.cpp | 5 +- xllm/proto/chat.proto | 7 +- 8 files changed, 301 insertions(+), 691 deletions(-) delete mode 100644 xllm/core/chat_template/tools_converter.cpp delete mode 100644 xllm/core/chat_template/tools_converter.h diff --git a/xllm/api_service/api_service.cpp b/xllm/api_service/api_service.cpp index feb2295c..6539e3af 100644 --- a/xllm/api_service/api_service.cpp +++ b/xllm/api_service/api_service.cpp @@ -1,6 +1,7 @@ #include "api_service.h" #include +#include #include #include @@ -71,7 +72,6 @@ void APIService::CompletionsHttp(::google::protobuf::RpcController* controller, auto ctrl = reinterpret_cast(controller); std::string attachment = std::move(ctrl->request_attachment().to_string()); - LOG(INFO) << "attachment:" << attachment; std::string error; auto st = json2pb::JsonToProtoMessage(attachment, req_pb, &error); if (!st) { @@ -175,7 +175,6 @@ void APIService::EmbeddingsHttp(::google::protobuf::RpcController* controller, auto ctrl = reinterpret_cast(controller); std::string attachment = std::move(ctrl->request_attachment().to_string()); - LOG(INFO) << "attachment:" << attachment; std::string error; auto st = json2pb::JsonToProtoMessage(attachment, req_pb, &error); if (!st) { @@ -289,7 +288,6 @@ void APIService::LinkCluster(::google::protobuf::RpcController* controller, auto ctrl = reinterpret_cast(controller); std::string attachment = std::move(ctrl->request_attachment().to_string()); - LOG(INFO) << "attachment:" << attachment; std::string error; auto st = json2pb::JsonToProtoMessage(attachment, req_pb, &error); if (!st) { @@ -344,7 +342,6 @@ void APIService::UnlinkCluster(::google::protobuf::RpcController* 
controller, auto ctrl = reinterpret_cast(controller); std::string attachment = std::move(ctrl->request_attachment().to_string()); - LOG(INFO) << "attachment:" << attachment; std::string error; auto st = json2pb::JsonToProtoMessage(attachment, req_pb, &error); if (!st) { diff --git a/xllm/api_service/chat_service_impl.cpp b/xllm/api_service/chat_service_impl.cpp index 6fdb97af..842575d6 100644 --- a/xllm/api_service/chat_service_impl.cpp +++ b/xllm/api_service/chat_service_impl.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -19,17 +20,38 @@ #include "core/util/utils.h" #include "core/util/uuid.h" #include "chat_template/chat_template.h" -#include "chat_template/tools_converter.h" #include "common/instance_name.h" #include "common/uuid.h" +#include "function_call/function_call.h" #include "request/request_params.h" #include "util/utils.h" -#include -#include "function_call/function_call.h" namespace xllm { namespace { + +std::string generate_tool_call_id() { return "call_" + llm::generate_uuid(); } + +void convert_tool_calls_to_proto( + const std::vector& tool_calls, + proto::ChatMessage* message) { + for (const auto& call : tool_calls) { + if (!call.is_valid()) { + LOG(WARNING) << "Invalid tool call: " << call.to_string(); + continue; + } + + auto* proto_tool_call = message->add_tool_calls(); + proto_tool_call->set_id(call.id.empty() ? generate_tool_call_id() + : call.id); + proto_tool_call->set_type(call.type.empty() ? 
"function" : call.type); + + auto* function = proto_tool_call->mutable_function(); + function->set_name(call.function_name); + function->set_arguments(call.function_arguments); + } +} + void set_logprobs(proto::ChatChoice* choice, const std::optional>& logprobs) { if (!logprobs.has_value() || logprobs.value().empty()) { @@ -54,108 +76,6 @@ void set_logprobs(proto::ChatChoice* choice, } } } -struct ToolsInfo { - std::vector tools; - std::string tool_choice; - bool has_tools = false; -}; - -ToolsInfo extract_tools_info(const proto::ChatRequest& request) { - ToolsInfo info; - - if (request.tools_size() > 0) { - info.has_tools = true; - info.tools.reserve(request.tools_size()); - - for (const auto& proto_tool : request.tools()) { - Tool tool; - tool.type = proto_tool.type(); - tool.function.name = proto_tool.function().name(); - tool.function.description = proto_tool.function().description(); - // tool.function.parameters = proto_tool.function().parameters(); - // std::cerr << "proto_tool.function().parameters():" << proto_tool.function().parameters() << std::endl; - - std::string parameters_json_str; - if (proto_tool.function().has_parameters()) { - google::protobuf::util::JsonPrintOptions options; - options.add_whitespace = false; - options.preserve_proto_field_names = true; - auto status = google::protobuf::util::MessageToJsonString( - proto_tool.function().parameters(), ¶meters_json_str, options); - if (!status.ok()) { - LOG(WARNING) << "Failed to convert parameters Struct to JSON: " - << status.message() << ", tool: " << tool.function.name; - parameters_json_str = "{}"; - } - } else { - parameters_json_str = "{}"; - } - tool.function.parameters = parameters_json_str; - std::cerr << "parameters_json_str:" << parameters_json_str << std::endl; - - info.tools.push_back(std::move(tool)); - } - - if (request.has_tool_choice()) { - info.tool_choice = request.tool_choice(); - } else { - info.tool_choice = "auto"; - } - } - - return info; -} -struct ToolsInfo { - 
std::vector tools; - std::string tool_choice; - bool has_tools = false; -}; - -ToolsInfo extract_tools_info(const proto::ChatRequest& request) { - ToolsInfo info; - - if (request.tools_size() > 0) { - info.has_tools = true; - info.tools.reserve(request.tools_size()); - - for (const auto& proto_tool : request.tools()) { - Tool tool; - tool.type = proto_tool.type(); - tool.function.name = proto_tool.function().name(); - tool.function.description = proto_tool.function().description(); - // tool.function.parameters = proto_tool.function().parameters(); - // std::cerr << "proto_tool.function().parameters():" << proto_tool.function().parameters() << std::endl; - - std::string parameters_json_str; - if (proto_tool.function().has_parameters()) { - google::protobuf::util::JsonPrintOptions options; - options.add_whitespace = false; - options.preserve_proto_field_names = true; - auto status = google::protobuf::util::MessageToJsonString( - proto_tool.function().parameters(), ¶meters_json_str, options); - if (!status.ok()) { - LOG(WARNING) << "Failed to convert parameters Struct to JSON: " - << status.message() << ", tool: " << tool.function.name; - parameters_json_str = "{}"; - } - } else { - parameters_json_str = "{}"; - } - tool.function.parameters = parameters_json_str; - std::cerr << "parameters_json_str:" << parameters_json_str << std::endl; - - info.tools.push_back(std::move(tool)); - } - - if (request.has_tool_choice()) { - info.tool_choice = request.tool_choice(); - } else { - info.tool_choice = "auto"; - } - } - - return info; -} template bool send_delta_to_client_brpc(std::shared_ptr call, @@ -250,7 +170,9 @@ bool send_result_to_client_brpc(std::shared_ptr call, const std::string& request_id, int64_t created_time, const std::string& model, - const RequestOutput& req_output) { + const RequestOutput& req_output, + bool has_tools = false, + const std::string& parser_format = "") { auto& response = call->response(); response.set_object("chat.completion"); 
response.set_id(request_id); @@ -264,9 +186,35 @@ bool send_result_to_client_brpc(std::shared_ptr call, set_logprobs(choice, output.logprobs); auto* message = choice->mutable_message(); message->set_role("assistant"); - message->set_content(output.text); - if (output.finish_reason.has_value()) { - choice->set_finish_reason(output.finish_reason.value()); + + auto setOutputAndFinishReason = [&]() { + message->set_content(output.text); + if (output.finish_reason.has_value()) { + choice->set_finish_reason(output.finish_reason.value()); + } + }; + + if (has_tools && !parser_format.empty()) { + auto& fc_interface = + llm::function_call::FunctionCallInterface::getInstance(); + fc_interface.setPreferredFormat(parser_format); + auto parse_result = fc_interface.parse(output.text); + if (parse_result.has_tool_calls()) { + choice->set_finish_reason("tool_calls"); + if (!parse_result.normal_text.empty()) { + std::string cleaned_text = parse_result.normal_text; + boost::algorithm::trim(cleaned_text); + if (!cleaned_text.empty()) { + message->set_content(cleaned_text); + } + } + convert_tool_calls_to_proto(parse_result.tool_calls, message); + } else { + LOG(WARNING) << "Function call parsing error."; + setOutputAndFinishReason(); + } + } else { + setOutputAndFinishReason(); } } @@ -283,190 +231,96 @@ bool send_result_to_client_brpc(std::shared_ptr call, return call->write_and_finish(response); } -std::string generate_tool_call_id() { - return "call_" + llm::generate_uuid(); -} - -void convert_tool_calls_to_proto( - const std::vector& tool_calls, - proto::ChatMessage* message) { - - for (const auto& call : tool_calls) { - if (!call.is_valid()) { - LOG(WARNING) << "Invalid tool call: " << call.to_string(); - continue; - } - - auto* proto_tool_call = message->add_tool_calls(); - proto_tool_call->set_id(call.id.empty() ? generate_tool_call_id() : call.id); - proto_tool_call->set_type(call.type.empty() ? 
"function" : call.type); - - auto* function = proto_tool_call->mutable_function(); - function->set_name(call.function_name); - function->set_arguments(call.function_arguments); - - LOG(INFO) << "Converted tool call: " << call.function_name - << " with args: " << call.function_arguments; - } -} +} // namespace +ChatServiceImpl::ChatServiceImpl(LLMMaster* master, + const std::vector& models) + : APIServiceImpl(master, models) {} -bool handle_function_call_response( - const RequestOutput& req_output, - std::shared_ptr call_data, - const std::string& request_id, - int64_t created_time, - const std::string& model, - const std::string& parser_format) { - - auto& fc_interface = llm::function_call::FunctionCallInterface::getInstance(); - - fc_interface.setPreferredFormat(parser_format); - - auto parse_result = fc_interface.parse(req_output.outputs[0].text); - - if (parse_result.has_tool_calls()) { - auto& response = call_data->response(); - response.set_object("chat.completion"); - response.set_id(request_id); - response.set_created(created_time); - response.set_model(model); - - auto* choice = response.add_choices(); - choice->set_index(0); - choice->set_finish_reason("tool_calls"); - - auto* message = choice->mutable_message(); - message->set_role("assistant"); - - if (!parse_result.normal_text.empty()) { - std::string cleaned_text = parse_result.normal_text; - boost::algorithm::trim(cleaned_text); - if (!cleaned_text.empty()) { - message->set_content(cleaned_text); - } - } - - convert_tool_calls_to_proto(parse_result.tool_calls, message); - - if (req_output.usage.has_value()) { - const auto& usage = req_output.usage.value(); - auto* proto_usage = response.mutable_usage(); - proto_usage->set_prompt_tokens(static_cast(usage.num_prompt_tokens)); - proto_usage->set_completion_tokens(static_cast(usage.num_generated_tokens)); - proto_usage->set_total_tokens(static_cast(usage.num_total_tokens)); - } - - LOG(INFO) << "Function call detected: " << parse_result.tool_calls.size() 
<< " calls"; - return call_data->write_and_finish(response); - - } else if (parse_result.has_error) { - LOG(WARNING) << "Function call parsing error: " << parse_result.error_message; - return send_result_to_client_brpc(call_data, request_id, created_time, model, req_output); - } else { - return send_result_to_client_brpc(call_data, request_id, created_time, model, req_output); +// chat_async for brpc +void ChatServiceImpl::process_async_impl(std::shared_ptr call) { + const auto& rpc_request = call->request(); + // check if model is supported + const auto& model = rpc_request.model(); + if (!models_.contains(model)) { + call->finish_with_error(StatusCode::UNKNOWN, "Model not supported"); + return; } -} + // Check if the request is being rate-limited. + if (master_->get_rate_limiter()->is_limited()) { + call->finish_with_error( + StatusCode::RESOURCE_EXHAUSTED, + "The number of concurrent requests has reached the limit."); + return; + } -bool handle_streaming_function_calls( - const RequestOutput& req_output, - std::shared_ptr call_data, - std::unordered_set* first_message_sent, - const std::string& request_id, - int64_t created_time, - const std::string& model, - const std::string& parser_format, - bool include_usage) { - - auto& fc_interface = llm::function_call::FunctionCallInterface::getInstance(); - - fc_interface.setPreferredFormat(parser_format); - - auto stream_result = fc_interface.parseStreaming(req_output.outputs[0].text); - - auto& response = call_data->response(); - - for (const auto& completed_call : stream_result.completed_calls) { - const auto& index = req_output.outputs[0].index; - if (first_message_sent->find(index) == first_message_sent->end()) { - response.Clear(); - response.set_object("chat.completion.chunk"); - response.set_id(request_id); - response.set_created(created_time); - response.set_model(model); - auto* choice = response.add_choices(); - choice->set_index(index); - auto* message = choice->mutable_delta(); - 
message->set_role("assistant"); - first_message_sent->insert(index); - if (!call_data->write(response)) { - return false; - } - } - - response.Clear(); - response.set_object("chat.completion.chunk"); - response.set_id(request_id); - response.set_created(created_time); - response.set_model(model); - - auto* choice = response.add_choices(); - choice->set_index(index); - - auto* delta = choice->mutable_delta(); - auto* tool_call = delta->add_tool_calls(); - tool_call->set_id(completed_call.id.empty() ? generate_tool_call_id() : completed_call.id); - tool_call->set_type(completed_call.type.empty() ? "function" : completed_call.type); - - auto* function = tool_call->mutable_function(); - function->set_name(completed_call.function_name); - function->set_arguments(completed_call.function_arguments); - - if (!call_data->write(response)) { - return false; - } + RequestParams request_params( + rpc_request, call->get_x_request_id(), call->get_x_request_time()); + std::vector messages; + messages.reserve(rpc_request.messages_size()); + for (const auto& message : rpc_request.messages()) { + messages.emplace_back(message.role(), message.content()); } - - if (!stream_result.completed_calls.empty()) { - response.Clear(); - response.set_object("chat.completion.chunk"); - response.set_id(request_id); - response.set_created(created_time); - response.set_model(model); - - auto* choice = response.add_choices(); - choice->set_index(req_output.outputs[0].index); - choice->set_finish_reason("tool_calls"); - choice->mutable_delta(); - - if (!call_data->write(response)) { - return false; - } + + bool include_usage = false; + if (rpc_request.has_stream_options()) { + include_usage = rpc_request.stream_options().include_usage(); } - - if (include_usage && req_output.usage.has_value()) { - response.Clear(); - const auto& usage = req_output.usage.value(); - response.set_object("chat.completion.chunk"); - response.set_id(request_id); - response.set_created(created_time); - 
response.set_model(model); - auto* proto_usage = response.mutable_usage(); - proto_usage->set_prompt_tokens(static_cast(usage.num_prompt_tokens)); - proto_usage->set_completion_tokens(static_cast(usage.num_generated_tokens)); - proto_usage->set_total_tokens(static_cast(usage.num_total_tokens)); - if (!call_data->write(response)) { - return false; + std::optional> prompt_tokens = std::nullopt; + if (rpc_request.has_routing()) { + prompt_tokens = std::vector{}; + prompt_tokens->reserve(rpc_request.routing().token_ids_size()); + for (int i = 0; i < rpc_request.routing().token_ids_size(); i++) { + prompt_tokens->emplace_back(rpc_request.routing().token_ids(i)); } + + request_params.decode_address = rpc_request.routing().decode_name(); } - - if (req_output.finished || req_output.cancelled) { - response.Clear(); - return call_data->finish(); - } - - return true; + + master_->handle_request( + std::move(messages), + std::move(prompt_tokens), + std::move(request_params), + [call, + model, + master = master_, + stream = request_params.streaming, + include_usage = include_usage, + first_message_sent = std::unordered_set(), + request_id = request_params.request_id, + created_time = absl::ToUnixSeconds(absl::Now())]( + const RequestOutput& req_output) mutable -> bool { + if (req_output.status.has_value()) { + const auto& status = req_output.status.value(); + if (!status.ok()) { + // Reduce the number of concurrent requests when a + // request is finished with error. + master->get_rate_limiter()->decrease_one_request(); + + return call->finish_with_error(status.code(), status.message()); + } + } + + // Reduce the number of concurrent requests when a request + // is finished or canceled. 
+ if (req_output.finished || req_output.cancelled) { + master->get_rate_limiter()->decrease_one_request(); + } + + if (stream) { + // send delta to client + return send_delta_to_client_brpc(call, + include_usage, + &first_message_sent, + request_id, + created_time, + model, + req_output); + } + return send_result_to_client_brpc( + call, request_id, created_time, model, req_output); + }); } @@ -584,27 +438,6 @@ void MMChatServiceImpl::process_async(std::shared_ptr call) { return; } - ToolsInfo tools_info = extract_tools_info(rpc_request); - // 打印所有工具信息 - std::cerr << "Tools Information:" << std::endl; - std::cerr << "Has tools: " << (tools_info.has_tools ? "true" : "false") << std::endl; - std::cerr << "Tool choice: " << tools_info.tool_choice << std::endl; - - if (tools_info.has_tools) { - std::cerr << "Number of tools: " << tools_info.tools.size() << std::endl; - for (size_t i = 0; i < tools_info.tools.size(); ++i) { - const auto& tool = tools_info.tools[i]; - std::cerr << "Tool #" << i + 1 << ":" << std::endl; - std::cerr << " Type: " << tool.type << std::endl; - std::cerr << " Function name: " << tool.function.name << std::endl; - std::cerr << " Function description: " << tool.function.description << std::endl; - std::cerr << " Function parameters: " << tool.function.parameters << std::endl; - } - } else { - std::cerr << "No tools in this request" << std::endl; - } - - RequestParams request_params( rpc_request, call->get_x_request_id(), call->get_x_request_time()); @@ -623,9 +456,16 @@ void MMChatServiceImpl::process_async(std::shared_ptr call) { include_usage = rpc_request.stream_options().include_usage(); } - if ((tools_info.has_tools)) { - request_params.tools = std::move(tools_info.tools); - request_params.tool_choice = std::move(tools_info.tool_choice); + if (rpc_request.tools_size() > 0) { + request_params.proto_tools.assign(rpc_request.tools().begin(), + rpc_request.tools().end()); + + // TODO: Implement support for 'required' option in tool_choice. 
+ if (rpc_request.has_tool_choice()) { + request_params.tool_choice = rpc_request.tool_choice(); + } else { + request_params.tool_choice = "auto"; + } } // schedule the request @@ -660,52 +500,30 @@ void MMChatServiceImpl::process_async(std::shared_ptr call) { master->get_rate_limiter()->decrease_one_request(); } - std::string parser_format = master->options().tool_call_parser().value_or(""); + std::string parser_format = + master->options().tool_call_parser().value_or(""); if (stream) { if (has_tools && !parser_format.empty()) { LOG(ERROR) << "Tool call does not support streaming output"; - return send_delta_to_client_brpc(call_data, - include_usage, - &first_message_sent, - request_id, - created_time, - model, - req_output); - // return handle_streaming_function_calls( - // req_output, call_data, &first_message_sent, - // request_id, created_time, model, parser_format, include_usage); - } else { - // send delta to client - return send_delta_to_client_brpc(call_data, - include_usage, - &first_message_sent, - request_id, - created_time, - model, - req_output); } + // send delta to client + return send_delta_to_client_brpc(call_data, + include_usage, + &first_message_sent, + request_id, + created_time, + model, + req_output); } if (has_tools && !parser_format.empty()) { - //debug2 - auto& interface = llm::function_call::FunctionCallInterface::getInstance(); - - if (parser_format != "auto") { - interface.setPreferredFormat(parser_format); - } - auto result = interface.parse(req_output.outputs[0].text); - - std::cerr << "正常文本: " << result.normal_text << std::endl; - - for (const auto& call : result.tool_calls) { - std::cerr << "函数: " << call.function_name << std::endl; - std::cerr << "参数: " << call.function_arguments << std::endl; - } - } - if (has_tools && !parser_format.empty()) { - return handle_function_call_response( - req_output, call_data, - request_id, created_time, model, parser_format); + return send_result_to_client_brpc(call_data, + request_id, + 
created_time, + model, + req_output, + has_tools, + parser_format); } return send_result_to_client_brpc( diff --git a/xllm/core/chat_template/tools_converter.cpp b/xllm/core/chat_template/tools_converter.cpp deleted file mode 100644 index e40b13da..00000000 --- a/xllm/core/chat_template/tools_converter.cpp +++ /dev/null @@ -1,284 +0,0 @@ -#include "tools_converter.h" - -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace llm { - -std::string ToolsConverter::convert_tools_to_json(const std::vector& tools) { - if (tools.empty()) { - return "[]"; - } - - nlohmann::json tools_json = nlohmann::json::array(); - - for (const auto& tool : tools) { - nlohmann::json tool_json; - tool_json["type"] = tool.type; - - nlohmann::json function_json; - function_json["name"] = tool.function.name; - function_json["description"] = tool.function.description; - - try { - if (!tool.function.parameters.empty()) { - function_json["parameters"] = nlohmann::json::parse(tool.function.parameters); - } else { - function_json["parameters"] = nlohmann::json::object(); - } - } catch (const nlohmann::json::exception& e) { - LOG(WARNING) << "Failed to parse tool parameters JSON: " << e.what() - << ", tool: " << tool.function.name; - function_json["parameters"] = nlohmann::json::object(); - } - - tool_json["function"] = function_json; - tools_json.push_back(tool_json); - } - - return tools_json.dump(2); -} - -std::string ToolsConverter::convert_tools_to_prompt( - const std::vector& tools, - const std::string& tool_choice) { - if (tools.empty()) { - return ""; - } - - std::ostringstream prompt; - prompt << "You have access to the following functions:\n\n"; - - for (const auto& tool : tools) { - prompt << "Function: " << tool.function.name << "\n"; - prompt << "Description: " << tool.function.description << "\n"; - - try { - if (!tool.function.parameters.empty()) { - auto params_json = nlohmann::json::parse(tool.function.parameters); - prompt << "Parameters: " 
<< params_json.dump(2) << "\n"; - } - } catch (const nlohmann::json::exception& e) { - LOG(WARNING) << "Failed to parse parameters for tool: " << tool.function.name; - } - - prompt << "\n"; - } - - if (tool_choice == "required") { - prompt << "You MUST call one of the above functions. "; - } else if (tool_choice == "auto") { - prompt << "You may call one of the above functions if needed. "; - } - - prompt << "To call a function, respond with a JSON object in the following format:\n"; - prompt << "{\n"; - prompt << " \"tool_calls\": [\n"; - prompt << " {\n"; - prompt << " \"id\": \"call_\",\n"; - prompt << " \"type\": \"function\",\n"; - prompt << " \"function\": {\n"; - prompt << " \"name\": \"function_name\",\n"; - prompt << " \"arguments\": \"{\\\"param1\\\": \\\"value1\\\"}\"\n"; - prompt << " }\n"; - prompt << " }\n"; - prompt << " ]\n"; - prompt << "}\n\n"; - - return prompt.str(); -} - -std::vector ToolsConverter::parse_tool_calls_from_text( - const std::string& model_output) { - std::vector tool_calls; - - auto json_blocks = extract_json_blocks(model_output); - - for (const auto& json_block : json_blocks) { - auto parsed_calls = parse_tool_calls_from_json(json_block); - tool_calls.insert(tool_calls.end(), parsed_calls.begin(), parsed_calls.end()); - } - - return tool_calls; -} - -std::vector ToolsConverter::parse_tool_calls_from_json( - const std::string& json_str) { - std::vector tool_calls; - - try { - auto json_obj = nlohmann::json::parse(clean_json_string(json_str)); - - if (json_obj.contains("tool_calls") && json_obj["tool_calls"].is_array()) { - for (const auto& call_json : json_obj["tool_calls"]) { - auto tool_call = parse_single_function_call(call_json); - if (tool_call.has_value()) { - tool_calls.push_back(tool_call.value()); - } - } - } - else if (json_obj.contains("function") || json_obj.contains("name")) { - auto tool_call = parse_single_function_call(json_obj); - if (tool_call.has_value()) { - tool_calls.push_back(tool_call.value()); - } - } - } 
catch (const nlohmann::json::exception& e) { - LOG(WARNING) << "Failed to parse tool calls JSON: " << e.what(); - } - - return tool_calls; -} - -bool ToolsConverter::validate_tool_call_arguments( - const ToolCall& tool_call, - const std::vector& available_tools) { - auto tool_it = std::find_if(available_tools.begin(), available_tools.end(), - [&](const Tool& tool) { - return tool.function.name == tool_call.function_name; - }); - - if (tool_it == available_tools.end()) { - LOG(WARNING) << "Tool not found: " << tool_call.function_name; - return false; - } - - try { - nlohmann::json::parse(tool_call.function_arguments); - } catch (const nlohmann::json::exception& e) { - LOG(WARNING) << "Invalid arguments JSON for tool " << tool_call.function_name - << ": " << e.what(); - return false; - } - - return validate_json_schema(tool_call.function_arguments, tool_it->function.parameters); -} - -std::string ToolsConverter::generate_tool_call_id() { - static std::random_device rd; - static std::mt19937 gen(rd()); - static std::uniform_int_distribution<> dis(100000, 999999); - - return "call_" + std::to_string(dis(gen)); -} - -std::string ToolsConverter::format_tool_choice(const std::string& tool_choice) { - if (tool_choice == "auto" || tool_choice == "none" || tool_choice == "required") { - return tool_choice; - } - return "auto"; - - -std::optional ToolsConverter::parse_single_function_call( - const nlohmann::json& json_obj) { - try { - ToolCall tool_call; - - if (json_obj.contains("id")) { - tool_call.id = json_obj["id"].get(); - } else { - tool_call.id = generate_tool_call_id(); - } - - if (json_obj.contains("type")) { - tool_call.type = json_obj["type"].get(); - } else { - tool_call.type = "function"; - } - - if (json_obj.contains("function")) { - const auto& func_json = json_obj["function"]; - if (func_json.contains("name")) { - tool_call.function_name = func_json["name"].get(); - } - if (func_json.contains("arguments")) { - if (func_json["arguments"].is_string()) { - 
tool_call.function_arguments = func_json["arguments"].get(); - } else { - tool_call.function_arguments = func_json["arguments"].dump(); - } - } - } - else if (json_obj.contains("name")) { - tool_call.function_name = json_obj["name"].get(); - if (json_obj.contains("arguments")) { - if (json_obj["arguments"].is_string()) { - tool_call.function_arguments = json_obj["arguments"].get(); - } else { - tool_call.function_arguments = json_obj["arguments"].dump(); - } - } - } - - if (!tool_call.function_name.empty()) { - return tool_call; - } - } catch (const nlohmann::json::exception& e) { - LOG(WARNING) << "Failed to parse single function call: " << e.what(); - } - - return std::nullopt; -} - -bool ToolsConverter::validate_json_schema( - const std::string& json_str, - const std::string& schema_str) { - try { - nlohmann::json::parse(json_str); - if (!schema_str.empty()) { - nlohmann::json::parse(schema_str); - } - return true; - } catch (const nlohmann::json::exception& e) { - return false; - } - - -} - -std::string ToolsConverter::clean_json_string(const std::string& raw_json) { - std::string cleaned = raw_json; - - cleaned = std::string(absl::StripAsciiWhitespace(cleaned)); - - cleaned = absl::StrReplaceAll(cleaned, {{"```json", ""}, {"```", ""}}); - - return cleaned; -} - -std::vector ToolsConverter::extract_json_blocks(const std::string& text) { - std::vector json_blocks; - - std::regex json_regex(R"(\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\})"); - std::sregex_iterator iter(text.begin(), text.end(), json_regex); - std::sregex_iterator end; - - for (; iter != end; ++iter) { - std::string match = iter->str(); - try { - nlohmann::json::parse(match); - json_blocks.push_back(match); - } catch (const nlohmann::json::exception&) { - continue; - } - } - - if (json_blocks.empty()) { - try { - nlohmann::json::parse(clean_json_string(text)); - json_blocks.push_back(clean_json_string(text)); - } catch (const nlohmann::json::exception&) { - } - } - - return json_blocks; -} - -} // namespace 
llm \ No newline at end of file diff --git a/xllm/core/chat_template/tools_converter.h b/xllm/core/chat_template/tools_converter.h deleted file mode 100644 index 650bd670..00000000 --- a/xllm/core/chat_template/tools_converter.h +++ /dev/null @@ -1,46 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include "chat_template.h" - -namespace llm { - -class ToolsConverter { - public: - static std::string convert_tools_to_json(const std::vector& tools); - - static std::string convert_tools_to_prompt( - const std::vector& tools, - const std::string& tool_choice = "auto"); - - static std::vector parse_tool_calls_from_text( - const std::string& model_output); - - static std::vector parse_tool_calls_from_json( - const std::string& json_str); - - static bool validate_tool_call_arguments( - const ToolCall& tool_call, - const std::vector& available_tools); - - static std::string generate_tool_call_id(); - - static std::string format_tool_choice(const std::string& tool_choice); - - private: - static std::optional parse_single_function_call( - const nlohmann::json& json_obj); - - static bool validate_json_schema( - const std::string& json_str, - const std::string& schema_str); - - static std::string clean_json_string(const std::string& raw_json); - - static std::vector extract_json_blocks(const std::string& text); -}; - -} // namespace llm \ No newline at end of file diff --git a/xllm/core/framework/chat_template/jinja_chat_template.cpp b/xllm/core/framework/chat_template/jinja_chat_template.cpp index 4dbd8069..e2ec1c0a 100644 --- a/xllm/core/framework/chat_template/jinja_chat_template.cpp +++ b/xllm/core/framework/chat_template/jinja_chat_template.cpp @@ -1,6 +1,7 @@ #include "jinja_chat_template.h" #include +#include #include #include @@ -23,8 +24,8 @@ JinjaChatTemplate::JinjaChatTemplate(const TokenizerArgs& args) : args_(args) { std::optional JinjaChatTemplate::apply( const ChatMessages& messages) const { - const std::vector empty_tools; - return 
apply(messages, empty_tools); + const std::vector empty_tools; + return apply(messages, empty_tools); } std::optional JinjaChatTemplate::apply( @@ -81,6 +82,126 @@ std::optional JinjaChatTemplate::apply( return template_->apply(input, options); } +std::optional JinjaChatTemplate::apply( + const ChatMessages& messages, + const std::vector& proto_tools) const { + // convert the messages to json object + nlohmann::ordered_json messages_json = nlohmann::json::array(); + for (const auto& message : messages) { + nlohmann::ordered_json message_json; + message_json["role"] = message.role; + message_json["content"] = message.content; + messages_json.push_back(message_json); + } + + // convert protobuf tools to json object + nlohmann::ordered_json tools_json = nlohmann::json::array(); + if (!proto_tools.empty()) { + try { + for (const auto& proto_tool : proto_tools) { + nlohmann::ordered_json tool_json; + tool_json["type"] = proto_tool.type(); + + nlohmann::ordered_json function_json; + function_json["name"] = proto_tool.function().name(); + function_json["description"] = proto_tool.function().description(); + + if (proto_tool.function().has_parameters()) { + std::string parameters_json_str; + google::protobuf::util::JsonPrintOptions options; + options.add_whitespace = false; + options.preserve_proto_field_names = true; + auto status = google::protobuf::util::MessageToJsonString( + proto_tool.function().parameters(), + ¶meters_json_str, + options); + if (status.ok()) { + function_json["parameters"] = + nlohmann::json::parse(parameters_json_str); + } else { + LOG(WARNING) << "Failed to convert parameters Struct to JSON: " + << status.message() + << ", tool: " << proto_tool.function().name(); + function_json["parameters"] = nlohmann::json::object(); + } + } else { + function_json["parameters"] = nlohmann::json::object(); + } + + tool_json["function"] = function_json; + tools_json.push_back(tool_json); + } + } catch (const std::exception& e) { + LOG(WARNING) << "Failed to 
convert protobuf tools to JSON: " << e.what(); + // Continue with empty tools array + tools_json = nlohmann::json::array(); + } + } + + // apply the template with tools + return apply(messages_json, tools_json); +} + +std::optional JinjaChatTemplate::apply( + const ChatMessages& messages, + const std::vector& proto_tools) const { + // convert the messages to json object + nlohmann::ordered_json messages_json = nlohmann::json::array(); + for (const auto& message : messages) { + nlohmann::ordered_json message_json; + message_json["role"] = message.role; + message_json["content"] = message.content; + messages_json.push_back(message_json); + } + + // convert protobuf tools to json object + nlohmann::ordered_json tools_json = nlohmann::json::array(); + if (!proto_tools.empty()) { + try { + for (const auto& proto_tool : proto_tools) { + nlohmann::ordered_json tool_json; + tool_json["type"] = proto_tool.type(); + + nlohmann::ordered_json function_json; + function_json["name"] = proto_tool.function().name(); + function_json["description"] = proto_tool.function().description(); + + if (proto_tool.function().has_parameters()) { + std::string parameters_json_str; + google::protobuf::util::JsonPrintOptions options; + options.add_whitespace = false; + options.preserve_proto_field_names = true; + auto status = google::protobuf::util::MessageToJsonString( + proto_tool.function().parameters(), + ¶meters_json_str, + options); + if (status.ok()) { + function_json["parameters"] = + nlohmann::json::parse(parameters_json_str); + } else { + LOG(WARNING) << "Failed to convert parameters Struct to JSON: " + << status.message() + << ", tool: " << proto_tool.function().name(); + function_json["parameters"] = nlohmann::json::object(); + } + } else { + function_json["parameters"] = nlohmann::json::object(); + } + + tool_json["function"] = function_json; + tools_json.push_back(tool_json); + } + } catch (const std::exception& e) { + LOG(WARNING) << "Failed to convert protobuf tools to JSON: " 
<< e.what(); + // Continue with empty tools array + tools_json = nlohmann::json::array(); + } + } + + // apply the template with tools + return apply(messages_json, tools_json); +} + nlohmann::ordered_json JinjaChatTemplate::get_mm_content( const Message::MMContentVec& vec) const { nlohmann::ordered_json content_json = nlohmann::json::array(); diff --git a/xllm/core/framework/request/request_params.h b/xllm/core/framework/request/request_params.h index 19330523..4fac5def 100644 --- a/xllm/core/framework/request/request_params.h +++ b/xllm/core/framework/request/request_params.h @@ -5,6 +5,7 @@ #include #include "chat.pb.h" +#include "common/macros.h" #include "completion.pb.h" #include "core/common/macros.h" #include "embedding.pb.h" @@ -101,9 +102,12 @@ struct RequestParams { // decode address. std::string decode_address; - std::vector tools; + // decode address. + std::string decode_address; + + std::vector proto_tools; std::string tool_choice = "auto"; - bool has_tools() const { return !tools.empty(); } + bool has_tools() const { return !proto_tools.empty(); } }; } // namespace xllm diff --git a/xllm/core/runtime/llm_master.cpp b/xllm/core/runtime/llm_master.cpp index f9a0cdbb..b42f3953 100644 --- a/xllm/core/runtime/llm_master.cpp +++ b/xllm/core/runtime/llm_master.cpp @@ -422,10 +422,9 @@ std::shared_ptr LLMMaster::generate_request( const RequestParams& sp, OutputCallback callback) { Timer timer; - std::optional prompt; + std::optional prompt; if (sp.has_tools()) { - auto tools = sp.tools; - prompt = chat_template_->apply(messages, tools); + prompt = chat_template_->apply(messages, sp.proto_tools); } else { prompt = chat_template_->apply(messages); } diff --git a/xllm/proto/chat.proto b/xllm/proto/chat.proto index 6329f9dc..cfc03b01 100644 --- a/xllm/proto/chat.proto +++ b/xllm/proto/chat.proto @@ -19,9 +19,10 @@ message Tool { } message ToolCall { - string id = 1; - string type = 2; // "function" - FunctionCall function = 3; + optional uint32 index = 1; + 
optional string id = 2; + string type = 3; // "function" + FunctionCall function = 4; } message FunctionCall { From 8528e6a8afd49b8682157b2562caa03f6853be47 Mon Sep 17 00:00:00 2001 From: dengyingxu1 Date: Mon, 11 Aug 2025 20:46:13 +0800 Subject: [PATCH 4/9] refactor: optimize function_call structure for better maintainability. --- xllm/api_service/chat_service_impl.cpp | 108 ++++--- xllm/core/function_call/CMakeLists.txt | 26 +- xllm/core/function_call/base_detector.cpp | 47 --- xllm/core/function_call/base_detector.h | 99 ------ .../function_call/base_format_detector.cpp | 66 ++++ .../core/function_call/base_format_detector.h | 64 ++++ xllm/core/function_call/core_types.h | 62 ++++ xllm/core/function_call/function_call.h | 73 ++--- .../function_call/function_call_parser.cpp | 287 ++++-------------- .../core/function_call/function_call_parser.h | 88 +++--- xllm/core/function_call/qwen25_detector.cpp | 68 +++++ xllm/core/function_call/qwen25_detector.h | 197 +----------- xllm/core/function_call/types.h | 160 ---------- 13 files changed, 473 insertions(+), 872 deletions(-) delete mode 100644 xllm/core/function_call/base_detector.cpp delete mode 100644 xllm/core/function_call/base_detector.h create mode 100644 xllm/core/function_call/base_format_detector.cpp create mode 100644 xllm/core/function_call/base_format_detector.h create mode 100644 xllm/core/function_call/core_types.h create mode 100644 xllm/core/function_call/qwen25_detector.cpp delete mode 100644 xllm/core/function_call/types.h diff --git a/xllm/api_service/chat_service_impl.cpp b/xllm/api_service/chat_service_impl.cpp index 842575d6..b116f357 100644 --- a/xllm/api_service/chat_service_impl.cpp +++ b/xllm/api_service/chat_service_impl.cpp @@ -23,6 +23,7 @@ #include "common/instance_name.h" #include "common/uuid.h" #include "function_call/function_call.h" +#include "function_call/core_types.h" #include "request/request_params.h" #include "util/utils.h" @@ -32,26 +33,62 @@ namespace { std::string 
generate_tool_call_id() { return "call_" + llm::generate_uuid(); } -void convert_tool_calls_to_proto( - const std::vector& tool_calls, - proto::ChatMessage* message) { - for (const auto& call : tool_calls) { - if (!call.is_valid()) { - LOG(WARNING) << "Invalid tool call: " << call.to_string(); - continue; - } +struct ToolCallResult { + std::optional> tool_calls; + std::string text; + std::string finish_reason; +}; + +ToolCallResult process_tool_calls( + const std::string& text, + const std::vector& tools, + const std::string& parser_format, + const std::string& finish_reason) { + + ToolCallResult result; + result.text = text; + result.finish_reason = finish_reason; - auto* proto_tool_call = message->add_tool_calls(); - proto_tool_call->set_id(call.id.empty() ? generate_tool_call_id() - : call.id); - proto_tool_call->set_type(call.type.empty() ? "function" : call.type); + function_call::FunctionCallParser parser(tools, parser_format); - auto* function = proto_tool_call->mutable_function(); - function->set_name(call.function_name); - function->set_arguments(call.function_arguments); + if (!parser.has_tool_call(text)) { + return result; } + + if (result.finish_reason == "stop") { + result.finish_reason = "tool_calls"; + } + + try { + auto [parsed_text, call_info_list] = parser.parse_non_stream(text); + result.text = parsed_text; + + std::vector tool_calls; + tool_calls.reserve(call_info_list.size()); + + for (const auto& call_info : call_info_list) { + proto::ToolCall tool_call; + tool_call.set_id("call_" + generate_uuid()); + tool_call.set_type("function"); + + auto* function = tool_call.mutable_function(); + if (call_info.name) { + function->set_name(*call_info.name); + } + function->set_arguments(call_info.parameters); + + tool_calls.push_back(tool_call); + } + + result.tool_calls = std::move(tool_calls); + } catch (const std::exception& e) { + LOG(ERROR) << "Tool call parsing error: " << e.what(); + } + + return result; } + void set_logprobs(proto::ChatChoice* 
choice, const std::optional>& logprobs) { if (!logprobs.has_value() || logprobs.value().empty()) { @@ -172,7 +209,8 @@ bool send_result_to_client_brpc(std::shared_ptr call, const std::string& model, const RequestOutput& req_output, bool has_tools = false, - const std::string& parser_format = "") { + const std::string& parser_format = "", + const std::vector& tools = {}) { auto& response = call->response(); response.set_object("chat.completion"); response.set_id(request_id); @@ -195,23 +233,23 @@ bool send_result_to_client_brpc(std::shared_ptr call, }; if (has_tools && !parser_format.empty()) { - auto& fc_interface = - llm::function_call::FunctionCallInterface::getInstance(); - fc_interface.setPreferredFormat(parser_format); - auto parse_result = fc_interface.parse(output.text); - if (parse_result.has_tool_calls()) { - choice->set_finish_reason("tool_calls"); - if (!parse_result.normal_text.empty()) { - std::string cleaned_text = parse_result.normal_text; - boost::algorithm::trim(cleaned_text); - if (!cleaned_text.empty()) { - message->set_content(cleaned_text); - } + std::string finish_reason; + if (output.finish_reason.has_value()) { + finish_reason = output.finish_reason.value(); + } + + auto result = process_tool_calls(output.text, tools, parser_format, finish_reason); + + message->set_content(result.text); + + if (result.tool_calls) { + for (const auto& tool_call : *result.tool_calls) { + *message->add_tool_calls() = tool_call; } - convert_tool_calls_to_proto(parse_result.tool_calls, message); - } else { - LOG(WARNING) << "Function call parsing error."; - setOutputAndFinishReason(); + } + + if (!result.finish_reason.empty()) { + choice->set_finish_reason(result.finish_reason); } } else { setOutputAndFinishReason(); @@ -481,7 +519,8 @@ void MMChatServiceImpl::process_async(std::shared_ptr call) { first_message_sent = std::unordered_set(), request_id = request_params.request_id, created_time = absl::ToUnixSeconds(absl::Now()), - has_tools = 
request_params.has_tools()]( + has_tools = request_params.has_tools(), + proto_tools = request_params.proto_tools]( const RequestOutput& req_output) mutable -> bool { if (req_output.status.has_value()) { const auto& status = req_output.status.value(); @@ -523,7 +562,8 @@ void MMChatServiceImpl::process_async(std::shared_ptr call) { model, req_output, has_tools, - parser_format); + parser_format, + proto_tools); } return send_result_to_client_brpc( diff --git a/xllm/core/function_call/CMakeLists.txt b/xllm/core/function_call/CMakeLists.txt index 5bd6a6e4..7b9b5f6d 100644 --- a/xllm/core/function_call/CMakeLists.txt +++ b/xllm/core/function_call/CMakeLists.txt @@ -1,15 +1,21 @@ -cc_library( - NAME function_call +include(cc_library) +include(cc_test) + +cc_library ( + NAME + function_call HDRS - types.h - base_detector.h - qwen25_detector.h - function_call_parser.h - function_call.h + core_types.h + base_format_detector.h + qwen25_detector.h + function_call_parser.h + function_call.h SRCS - base_detector.cpp - function_call_parser.cpp + base_format_detector.cpp + qwen25_detector.cpp + function_call_parser.cpp DEPS nlohmann_json::nlohmann_json glog::glog -) \ No newline at end of file + proto::xllm_proto +) diff --git a/xllm/core/function_call/base_detector.cpp b/xllm/core/function_call/base_detector.cpp deleted file mode 100644 index bda66683..00000000 --- a/xllm/core/function_call/base_detector.cpp +++ /dev/null @@ -1,47 +0,0 @@ -#include "base_detector.h" -#include "qwen25_detector.h" -#include -#include - -namespace llm { -namespace function_call { - -std::unique_ptr DetectorFactory::create_detector(ModelFormat format) { - switch (format) { - case ModelFormat::QWEN25: - return std::make_unique(); - default: - return nullptr; - } -} - -std::vector> DetectorFactory::create_all_detectors() { - std::vector> detectors; - - detectors.push_back(std::make_unique()); - - return detectors; -} - -ModelFormat DetectorFactory::infer_format_from_model_name(const std::string& 
model_name) { - std::string lower_name = model_name; - std::transform(lower_name.begin(), lower_name.end(), lower_name.begin(), ::tolower); - - if (lower_name.find("qwen") != std::string::npos) { - return ModelFormat::QWEN25; - } - - return ModelFormat::UNKNOWN; -} - -std::string DetectorFactory::get_format_name(ModelFormat format) { - switch (format) { - case ModelFormat::QWEN25: - return "qwen25"; - default: - return "unknown"; - } -} - -} // namespace function_call -} // namespace llm \ No newline at end of file diff --git a/xllm/core/function_call/base_detector.h b/xllm/core/function_call/base_detector.h deleted file mode 100644 index c934085a..00000000 --- a/xllm/core/function_call/base_detector.h +++ /dev/null @@ -1,99 +0,0 @@ -#pragma once - -#include "types.h" -#include -#include -#include -#include "common/uuid.h" -namespace llm { -namespace function_call { - -class BaseFormatDetector { -public: - virtual ~BaseFormatDetector() = default; - - virtual bool detect(const std::string& text) const = 0; - - virtual FormatDetectionResult detect_format(const std::string& text) const = 0; - - virtual ParseResult parse_calls(const std::string& text) const = 0; - - virtual StreamingParseResult parse_streaming(const std::string& chunk) = 0; - - virtual void reset_streaming_state() = 0; - - virtual ModelFormat get_format() const = 0; - - virtual std::string get_format_name() const = 0; - - virtual StructureInfo get_structure_info() const = 0; - - virtual EBNFGrammar generate_ebnf_grammar(const ConstraintOptions& options = {}) const = 0; - - virtual bool validate_call_format(const ToolCallItem& call) const = 0; - -protected: - std::string generate_call_id() const { - return "call_" + llm::generate_uuid(); - } - - std::string clean_json_string(const std::string& json_str) const { - std::string cleaned = json_str; - size_t start = cleaned.find_first_not_of(" \t\n\r"); - if (start == std::string::npos) return ""; - size_t end = cleaned.find_last_not_of(" \t\n\r"); - cleaned 
= cleaned.substr(start, end - start + 1); - return cleaned; - } - - bool is_valid_json(const std::string& json_str) const { - try { - int brace_count = 0; - bool in_string = false; - bool escaped = false; - - for (char c : json_str) { - if (escaped) { - escaped = false; - continue; - } - - if (c == '\\') { - escaped = true; - continue; - } - - if (c == '"') { - in_string = !in_string; - continue; - } - - if (!in_string) { - if (c == '{') brace_count++; - else if (c == '}') brace_count--; - } - } - - return brace_count == 0 && !in_string; - } catch (...) { - return false; - } - } -}; - -class DetectorFactory { -public: - static std::unique_ptr create_detector(ModelFormat format); - - static std::vector> create_all_detectors(); - - static ModelFormat infer_format_from_model_name(const std::string& model_name); - - static std::string get_format_name(ModelFormat format); - -private: - DetectorFactory() = default; -}; - -} // namespace function_call -} // namespace llm \ No newline at end of file diff --git a/xllm/core/function_call/base_format_detector.cpp b/xllm/core/function_call/base_format_detector.cpp new file mode 100644 index 00000000..483b71be --- /dev/null +++ b/xllm/core/function_call/base_format_detector.cpp @@ -0,0 +1,66 @@ +#include "base_format_detector.h" +#include +#include +#include + +namespace llm { +namespace function_call { + +BaseFormatDetector::BaseFormatDetector() + : current_tool_id_(-1) + , current_tool_name_sent_(false) + , bot_token_("") + , eot_token_("") + , tool_call_separator_(", ") { +} + +std::unordered_map BaseFormatDetector::get_tool_indices(const std::vector& tools) { + std::unordered_map indices; + for (size_t i = 0; i < tools.size(); ++i) { + if (!tools[i].function().name().empty()) { + indices[tools[i].function().name()] = static_cast(i); + } + } + return indices; +} + +std::vector BaseFormatDetector::parse_base_json(const std::string& action_json, const std::vector& tools) { + auto tool_indices = get_tool_indices(tools); + 
std::vector results; + + // TODO: Replace with a more robust JSON library for better functionality and reliability + std::string trimmed = action_json; + trimmed.erase(0, trimmed.find_first_not_of(" \t\n\r")); + trimmed.erase(trimmed.find_last_not_of(" \t\n\r") + 1); + + if (trimmed.empty()) { + return results; + } + + std::regex name_regex("\"name\"\\s*:\\s*\"([^\"]+)\""); + std::regex args_regex("\"(?:parameters|arguments)\"\\s*:\\s*(\\{[^}]*\\})"); + + std::smatch name_match, args_match; + + if (std::regex_search(trimmed, name_match, name_regex)) { + std::string name = name_match[1].str(); + + if (tool_indices.find(name) != tool_indices.end()) { + std::string parameters = "{}"; + + if (std::regex_search(trimmed, args_match, args_regex)) { + parameters = args_match[1].str(); + } + + results.emplace_back(-1, name, parameters); + } else { + LOG(ERROR) << "Model attempted to call undefined function: " << name; + } + } + + return results; +} + + +} // namespace function_call +} // namespace llm \ No newline at end of file diff --git a/xllm/core/function_call/base_format_detector.h b/xllm/core/function_call/base_format_detector.h new file mode 100644 index 00000000..1f0ed28c --- /dev/null +++ b/xllm/core/function_call/base_format_detector.h @@ -0,0 +1,64 @@ +#pragma once + +#include "core_types.h" +#include "chat.pb.h" +#include +#include +#include +#include +#include + +namespace llm { +namespace function_call { + +class BaseFormatDetector { +public: + BaseFormatDetector(); + virtual ~BaseFormatDetector() = default; + + BaseFormatDetector(const BaseFormatDetector&) = delete; + BaseFormatDetector& operator=(const BaseFormatDetector&) = delete; + +protected: + // Streaming state management + // Buffer for accumulating incomplete patterns that arrive across multiple streaming chunks + std::string buffer_; + + // Stores complete tool call info (name and arguments) for each tool being parsed. + // Used by serving layer for completion handling when streaming ends. 
+ // Format: [{"name": str, "arguments": dict}, ...] + std::vector> prev_tool_call_arr_; + + // Index of currently streaming tool call. Starts at -1 (no active tool), + // increments as each tool completes. Tracks which tool's arguments are streaming. + int current_tool_id_; + + // Flag for whether current tool's name has been sent to client. + // Tool names sent first with empty parameters, then arguments stream incrementally. + bool current_tool_name_sent_; + + // Tracks raw JSON string content streamed to client for each tool's arguments. + // Critical for serving layer to calculate remaining content when streaming ends. + // Each index corresponds to a tool_id. Example: ['{"location": "San Francisco"', '{"temp": 72'] + std::vector streamed_args_for_tool_; + + // Token configuration (override in subclasses) + std::string bot_token_; + std::string eot_token_; + std::string tool_call_separator_; + + // Tool indices cache + std::unordered_map tool_indices_; + +public: + std::unordered_map get_tool_indices(const std::vector& tools); + + std::vector parse_base_json(const std::string& action_json, const std::vector& tools); + + virtual StreamingParseResult detect_and_parse(const std::string& text, const std::vector& tools) = 0; + + virtual bool has_tool_call(const std::string& text) = 0; +}; + +} // namespace function_call +} // namespace llm \ No newline at end of file diff --git a/xllm/core/function_call/core_types.h b/xllm/core/function_call/core_types.h new file mode 100644 index 00000000..165fd615 --- /dev/null +++ b/xllm/core/function_call/core_types.h @@ -0,0 +1,62 @@ +#pragma once + +#include +#include +#include +#include +#include "chat.pb.h" + +namespace llm { +namespace function_call { + +struct ToolCallItem { + int tool_index; + std::optional name; + std::string parameters; // JSON string + + ToolCallItem() : tool_index(-1), parameters("") {} + + ToolCallItem(int index, const std::optional& func_name, const std::string& params) + : tool_index(index), 
name(func_name), parameters(params) {} +}; + +struct StreamingParseResult { + std::string normal_text; + std::vector calls; + + StreamingParseResult() = default; + + StreamingParseResult(const std::string& text) : normal_text(text) {} + + StreamingParseResult(const std::vector& tool_calls) : calls(tool_calls) {} + + StreamingParseResult(const std::string& text, const std::vector& tool_calls) + : normal_text(text), calls(tool_calls) {} + + bool has_calls() const { + return !calls.empty(); + } + + void clear() { + normal_text.clear(); + calls.clear(); + } +}; + + +struct StructureInfo { + std::string begin; + std::string end; + std::string trigger; + + StructureInfo() = default; + + StructureInfo(const std::string& begin_str, const std::string& end_str, const std::string& trigger_str) + : begin(begin_str), end(end_str), trigger(trigger_str) {} +}; + + +using GetInfoFunc = std::function; + +} // namespace function_call +} // namespace llm \ No newline at end of file diff --git a/xllm/core/function_call/function_call.h b/xllm/core/function_call/function_call.h index d789cc0b..a4781ab3 100644 --- a/xllm/core/function_call/function_call.h +++ b/xllm/core/function_call/function_call.h @@ -1,69 +1,30 @@ #pragma once -#include "types.h" +#include "core_types.h" +#include "base_format_detector.h" +#include "qwen25_detector.h" #include "function_call_parser.h" namespace llm { namespace function_call { -class FunctionCallInterface { -public: - static FunctionCallInterface& getInstance() { - static FunctionCallInterface instance; - return instance; - } - - ParseResult parse(const std::string& text) { - return parser_.parse_auto(text); - } - - ParseResult parse(const std::string& text, const std::string& format) { - ModelFormat model_format = ModelFormat::UNKNOWN; - if (format == "qwen" || format == "qwen25") { - model_format = ModelFormat::QWEN25; - } - return parser_.parse_with_format(text, model_format); - } - - StreamingParseResult parseStreaming(const std::string& chunk) { 
- return parser_.parse_streaming_auto(chunk); - } - - bool hasFunction(const std::string& text) { - return utils::has_function_calls(text); - } - - std::string detectFormat(const std::string& text) { - return utils::detect_best_format(text); - } - - std::string generateConstraints(const std::vector& function_names, - const std::string& format = "auto") { - return utils::generate_ebnf_constraints(function_names, format); - } - - void setPreferredFormat(const std::string& model_name) { - parser_.set_preferred_format(model_name); - } - - void resetStreamingState() { - parser_.reset_all_streaming_states(); - } - -private: - FunctionCallInterface() = default; - FunctionCallParser parser_; -}; +using Parser = FunctionCallParser; +using Detector = BaseFormatDetector; +using QwenDetector = Qwen25Detector; -} // namespace function_call - -inline function_call::ParseResult parse_function_calls(const std::string& text) { - return function_call::FunctionCallInterface::getInstance().parse(text); +inline std::vector parse( + const std::string& text, + const std::vector& tools, + const std::string& format = "qwen25") { + return utils::parse_function_calls(text, tools, format); } -inline bool has_function_calls(const std::string& text) { - auto result = function_call::FunctionCallInterface::getInstance().parse(text); - return result.has_tool_calls(); +inline bool has_calls( + const std::string& text, + const std::string& format = "qwen25") { + return utils::has_function_calls(text, format); } + +} // namespace function_call } // namespace llm \ No newline at end of file diff --git a/xllm/core/function_call/function_call_parser.cpp b/xllm/core/function_call/function_call_parser.cpp index 1a419311..4e939fcb 100644 --- a/xllm/core/function_call/function_call_parser.cpp +++ b/xllm/core/function_call/function_call_parser.cpp @@ -1,264 +1,97 @@ #include "function_call_parser.h" -#include "base_detector.h" -#include -#include +#include +#include namespace llm { namespace function_call { 
-FunctionCallParser::FunctionCallParser() : preferred_format_(ModelFormat::UNKNOWN) { - initialize_detectors(); -} +const std::unordered_map FunctionCallParser::ToolCallParserEnum = { + {"qwen25", "qwen25"}, + {"qwen3", "qwen25"}, + // TODO + // {"llama3", "llama3"}, + // {"mistral", "mistral"}, + // {"deepseekv3", "deepseekv3"}, + // {"pythonic", "pythonic"}, + // {"kimi_k2", "kimi_k2"}, + // {"qwen3_coder", "qwen3_coder"}, + // {"glm45", "glm45"}, + // {"step3", "step3"}, +}; -void FunctionCallParser::initialize_detectors() { - detectors_ = DetectorFactory::create_all_detectors(); +FunctionCallParser::FunctionCallParser(const std::vector& tools, const std::string& tool_call_parser) + : tools_(tools) { - for (auto& detector : detectors_) { - if (detector) { - format_detectors_[detector->get_format()] = std::move(detector); - } + detector_ = create_detector(tool_call_parser); + if (!detector_) { + throw std::invalid_argument("Unsupported tool_call_parser: " + tool_call_parser); } - detectors_.clear(); -} - -void FunctionCallParser::set_preferred_format(ModelFormat format) { - preferred_format_ = format; } -void FunctionCallParser::set_preferred_format(const std::string& model_name) { - preferred_format_ = infer_format_from_model_name(model_name); +bool FunctionCallParser::has_tool_call(const std::string& text) const { + return detector_->has_tool_call(text); } -ParseResult FunctionCallParser::parse_auto(const std::string& text) { - if (preferred_format_ != ModelFormat::UNKNOWN) { - auto detector = get_detector(preferred_format_); - if (detector && detector->detect(text)) { - return detector->parse_calls(text); - } - } - - for (auto& [format, detector] : format_detectors_) { - if (detector && detector->detect(text)) { - return detector->parse_calls(text); - } - } +std::tuple> FunctionCallParser::parse_non_stream(const std::string& full_text) { + StreamingParseResult parsed_result = detector_->detect_and_parse(full_text, tools_); - return {}; -} - -ParseResult 
FunctionCallParser::parse_with_format(const std::string& text, ModelFormat format) { - auto detector = get_detector(format); - if (!detector) { - LOG(WARNING) << "Unsupported format: " << static_cast(format); - return {}; + if (!parsed_result.calls.empty()) { + return std::make_tuple(parsed_result.normal_text, parsed_result.calls); + } else { + return std::make_tuple(full_text, std::vector()); } - - return detector->parse_calls(text); } -StreamingParseResult FunctionCallParser::parse_streaming_auto(const std::string& chunk) { - if (preferred_format_ != ModelFormat::UNKNOWN) { - auto detector = get_detector(preferred_format_); - if (detector) { - return detector->parse_streaming(chunk); - } - } - - for (auto& [format, detector] : format_detectors_) { - if (detector) { - auto result = detector->parse_streaming(chunk); - if (result.has_completed_calls() || result.has_partial_call()) { - return result; - } - } - } - - StreamingParseResult empty_result; - return empty_result; -} - -StreamingParseResult FunctionCallParser::parse_streaming_with_format(const std::string& chunk, ModelFormat format) { - auto detector = get_detector(format); - if (!detector) { - StreamingParseResult result; - result.has_error = true; - result.error_message = "Unsupported format"; - return result; - } - - return detector->parse_streaming(chunk); -} -std::vector FunctionCallParser::detect_formats(const std::string& text) { - std::vector results; - - for (auto& [format, detector] : format_detectors_) { - if (detector) { - auto result = detector->detect_format(text); - results.push_back(result); - } - } - - std::sort(results.begin(), results.end(), - [](const FormatDetectionResult& a, const FormatDetectionResult& b) { - return a.confidence > b.confidence; - }); - - return results; -} - -FormatDetectionResult FunctionCallParser::get_best_format(const std::string& text) { - auto results = detect_formats(text); - if (!results.empty()) { - return results[0]; - } - - FormatDetectionResult 
empty_result; - return empty_result; -} - -bool FunctionCallParser::validate_calls(const std::vector& calls, ModelFormat format) { - auto detector = get_detector(format); - if (!detector) { - return false; - } - - for (const auto& call : calls) { - if (!detector->validate_call_format(call)) { - return false; - } - } - - return true; -} - -std::string FunctionCallParser::generate_constraints(const std::vector& function_names, - ModelFormat format, - const ConstraintOptions& options) { - ModelFormat target_format = format; - if (target_format == ModelFormat::UNKNOWN) { - target_format = preferred_format_; - } - if (target_format == ModelFormat::UNKNOWN) { - target_format = ModelFormat::QWEN25; // 默认格式 +std::unique_ptr FunctionCallParser::create_detector(const std::string& tool_call_parser) { + auto it = ToolCallParserEnum.find(tool_call_parser); + if (it == ToolCallParserEnum.end()) { + return nullptr; } - auto detector = get_detector(target_format); - if (!detector) { - return ""; + if (it->second == "qwen25") { + return std::make_unique(); } - ConstraintOptions modified_options = options; - modified_options.allowed_functions = function_names; + // if (tool_call_parser == "llama3") { + // return std::make_unique(); + // } + // if (tool_call_parser == "mistral") { + // return std::make_unique(); + // } - auto grammar = detector->generate_ebnf_grammar(modified_options); - return grammar.to_string(); -} - -void FunctionCallParser::reset_all_streaming_states() { - for (auto& [format, detector] : format_detectors_) { - if (detector) { - detector->reset_streaming_state(); - } - } -} - -void FunctionCallParser::reset_streaming_state(ModelFormat format) { - auto detector = get_detector(format); - if (detector) { - detector->reset_streaming_state(); - } -} - -std::vector FunctionCallParser::get_supported_formats() const { - std::vector formats; - for (const auto& [format, detector] : format_detectors_) { - if (detector) { - formats.push_back(format); - } - } - return 
formats; -} - -std::string FunctionCallParser::get_format_name(ModelFormat format) const { - return DetectorFactory::get_format_name(format); -} - -bool FunctionCallParser::is_format_supported(ModelFormat format) const { - return format_detectors_.find(format) != format_detectors_.end(); -} - -BaseFormatDetector* FunctionCallParser::get_detector(ModelFormat format) { - auto it = format_detectors_.find(format); - if (it != format_detectors_.end()) { - return it->second.get(); - } return nullptr; } -ModelFormat FunctionCallParser::infer_format_from_model_name(const std::string& model_name) { - return DetectorFactory::infer_format_from_model_name(model_name); -} - namespace utils { -std::vector parse_function_calls(const std::string& text) { - static FunctionCallParser parser; - return parser.parse_auto(text).tool_calls; -} - -std::vector parse_function_calls(const std::string& text, const std::string& format) { - static FunctionCallParser parser; - - ModelFormat model_format = ModelFormat::UNKNOWN; - if (format == "qwen25" || format == "qwen") { - model_format = ModelFormat::QWEN25; - } - - if (model_format == ModelFormat::UNKNOWN) { - return parser.parse_auto(text).tool_calls; +std::vector parse_function_calls( + const std::string& text, + const std::vector& tools, + const std::string& parser_type) { + + try { + FunctionCallParser parser(tools, parser_type); + auto [normal_text, calls] = parser.parse_non_stream(text); + return calls; + } catch (const std::exception& e) { + LOG(ERROR) << "Error parsing function calls: " << e.what(); + return {}; } - - return parser.parse_with_format(text, model_format).tool_calls; -} - -bool has_function_calls(const std::string& text) { - static FunctionCallParser parser; - auto calls = parser.parse_auto(text); - return calls.has_tool_calls(); } -std::string detect_best_format(const std::string& text) { - static FunctionCallParser parser; - auto result = parser.get_best_format(text); - return parser.get_format_name(result.format); -} 
-std::string generate_ebnf_constraints(const std::vector& function_names, - const std::string& format) { - static FunctionCallParser parser; - - ModelFormat model_format = ModelFormat::UNKNOWN; - if (format == "qwen25" || format == "qwen") { - model_format = ModelFormat::QWEN25; - } +bool has_function_calls( + const std::string& text, + const std::string& parser_type) { - return parser.generate_constraints(function_names, model_format); -} - -bool validate_function_call_format(const ToolCallItem& call, const std::string& format) { - static FunctionCallParser parser; - - ModelFormat model_format = ModelFormat::UNKNOWN; - if (format == "qwen25" || format == "qwen") { - model_format = ModelFormat::QWEN25; - } - - if (model_format == ModelFormat::UNKNOWN) { + try { + FunctionCallParser parser({}, parser_type); + return parser.has_tool_call(text); + } catch (const std::exception& e) { + LOG(ERROR) << "Error checking function calls: " << e.what(); return false; } - - return parser.validate_calls({call}, model_format); } } // namespace utils diff --git a/xllm/core/function_call/function_call_parser.h b/xllm/core/function_call/function_call_parser.h index 67f9d50c..0855b6a4 100644 --- a/xllm/core/function_call/function_call_parser.h +++ b/xllm/core/function_call/function_call_parser.h @@ -1,81 +1,61 @@ #pragma once -#include "types.h" -#include "base_detector.h" +#include "core_types.h" +#include "base_format_detector.h" +#include "qwen25_detector.h" #include #include #include #include +#include namespace llm { namespace function_call { class FunctionCallParser { +public: + static const std::unordered_map ToolCallParserEnum; + private: - std::vector> detectors_; - std::unordered_map> format_detectors_; - ModelFormat preferred_format_; - + std::unique_ptr detector_; + std::vector tools_; + public: - FunctionCallParser(); + + FunctionCallParser(const std::vector& tools, const std::string& tool_call_parser); + ~FunctionCallParser() = default; FunctionCallParser(const 
FunctionCallParser&) = delete; FunctionCallParser& operator=(const FunctionCallParser&) = delete; - - void set_preferred_format(ModelFormat format); - void set_preferred_format(const std::string& model_name); - - ParseResult parse_auto(const std::string& text); - - ParseResult parse_with_format(const std::string& text, ModelFormat format); - - StreamingParseResult parse_streaming_auto(const std::string& chunk); - - StreamingParseResult parse_streaming_with_format(const std::string& chunk, ModelFormat format); - - std::vector detect_formats(const std::string& text); - - FormatDetectionResult get_best_format(const std::string& text); - - bool validate_calls(const std::vector& calls, ModelFormat format); - - std::string generate_constraints(const std::vector& function_names, - ModelFormat format = ModelFormat::UNKNOWN, - const ConstraintOptions& options = {}); - - void reset_all_streaming_states(); - - void reset_streaming_state(ModelFormat format); - - std::vector get_supported_formats() const; - - std::string get_format_name(ModelFormat format) const; - - bool is_format_supported(ModelFormat format) const; - -private: - void initialize_detectors(); - - BaseFormatDetector* get_detector(ModelFormat format); - - ModelFormat infer_format_from_model_name(const std::string& model_name); -}; -namespace utils { + bool has_tool_call(const std::string& text) const; + + std::tuple> parse_non_stream(const std::string& full_text); + + // StructuralTagResponseFormat get_structure_tag(); -std::vector parse_function_calls(const std::string& text); + // std::tuple get_structure_constraint(const std::string& tool_choice); -std::vector parse_function_calls(const std::string& text, const std::string& format); + BaseFormatDetector* get_detector() const { return detector_.get(); } -bool has_function_calls(const std::string& text); +private: + std::unique_ptr create_detector(const std::string& tool_call_parser); +}; + +namespace utils { -std::string detect_best_format(const std::string& 
text); +std::vector parse_function_calls( + const std::string& text, + const std::vector& tools, + const std::string& parser_type = "qwen25" +); -std::string generate_ebnf_constraints(const std::vector& function_names, - const std::string& format = "auto"); +bool has_function_calls( + const std::string& text, + const std::string& parser_type = "qwen25" +); -bool validate_function_call_format(const ToolCallItem& call, const std::string& format); } // namespace utils diff --git a/xllm/core/function_call/qwen25_detector.cpp b/xllm/core/function_call/qwen25_detector.cpp new file mode 100644 index 00000000..aff593ad --- /dev/null +++ b/xllm/core/function_call/qwen25_detector.cpp @@ -0,0 +1,68 @@ +#include "qwen25_detector.h" +#include +#include + +namespace llm { +namespace function_call { + +Qwen25Detector::Qwen25Detector() : BaseFormatDetector() { + bot_token_ = "\n"; + eot_token_ = "\n"; + tool_call_separator_ = "\n"; +} + +bool Qwen25Detector::has_tool_call(const std::string& text) { + return text.find(bot_token_) != std::string::npos; +} + +StreamingParseResult Qwen25Detector::detect_and_parse(const std::string& text, const std::vector& tools) { + size_t idx = text.find(bot_token_); + std::string normal_text = (idx != std::string::npos) ? 
text.substr(0, idx) : text; + + while (!normal_text.empty() && std::isspace(normal_text.back())) { + normal_text.pop_back(); + } + + if (text.find(bot_token_) == std::string::npos) { + return StreamingParseResult(normal_text); + } + + std::string escaped_bot_token = bot_token_; + std::string escaped_eot_token = eot_token_; + + std::string pattern = escaped_bot_token + "(.*?)" + escaped_eot_token; + + size_t pos = 0; + while ((pos = pattern.find("\n", pos)) != std::string::npos) { + pattern.replace(pos, 1, "\\n"); + pos += 2; + } + + std::regex tool_call_regex(pattern, std::regex_constants::ECMAScript); + std::sregex_iterator iter(text.begin(), text.end(), tool_call_regex); + std::sregex_iterator end; + + std::vector calls; + + for (; iter != end; ++iter) { + std::smatch match = *iter; + std::string match_result = match[1].str(); + + try { + match_result.erase(0, match_result.find_first_not_of(" \t\n\r")); + match_result.erase(match_result.find_last_not_of(" \t\n\r") + 1); + + auto parsed_calls = parse_base_json(match_result, tools); + calls.insert(calls.end(), parsed_calls.begin(), parsed_calls.end()); + } catch (const std::exception& e) { + LOG(ERROR) << "Failed to parse JSON part: " << match_result + << ", JSON parse error: " << e.what(); + continue; + } + } + + return StreamingParseResult(normal_text, calls); +} + +} // namespace function_call +} // namespace llm \ No newline at end of file diff --git a/xllm/core/function_call/qwen25_detector.h b/xllm/core/function_call/qwen25_detector.h index ac0c81ea..010d4b46 100644 --- a/xllm/core/function_call/qwen25_detector.h +++ b/xllm/core/function_call/qwen25_detector.h @@ -1,198 +1,25 @@ #pragma once -#include "base_detector.h" -#include -#include +#include "base_format_detector.h" +#include namespace llm { namespace function_call { class Qwen25Detector : public BaseFormatDetector { -private: - std::regex start_pattern_; - std::regex end_pattern_; - std::regex full_pattern_; - std::string buffer_; - public: - 
Qwen25Detector() - : start_pattern_(R"()"), - end_pattern_(R"()"), - full_pattern_(R"(\s*(\{.*?\})\s*)", - std::regex_constants::ECMAScript) {} - - ~Qwen25Detector() override = default; - - bool detect(const std::string& text) const override { - return std::regex_search(text, start_pattern_) && - std::regex_search(text, end_pattern_); - } - - FormatDetectionResult detect_format(const std::string& text) const override { - FormatDetectionResult result; - result.format = ModelFormat::QWEN25; - - bool has_start = std::regex_search(text, start_pattern_); - bool has_end = std::regex_search(text, end_pattern_); - - if (has_start && has_end) { - result.confidence = 0.95; - result.reason = "Found complete tags"; - } else if (has_start) { - result.confidence = 0.7; - result.reason = "Found opening tag"; - } else { - result.confidence = 0.0; - result.reason = "No Qwen2.5 format markers found"; - } - - result.structure_info = get_structure_info(); - return result; - } - - ParseResult parse_calls(const std::string& text) const override { - ParseResult result; - std::string normal_text = text; - - std::sregex_iterator iter(text.begin(), text.end(), full_pattern_); - std::sregex_iterator end; - - std::vector> match_positions; - - for (; iter != end; ++iter) { - const std::smatch& match = *iter; - std::string json_content = match[1].str(); - - match_positions.push_back({match.position(), match.length()}); - - ToolCallItem call; - call.id = generate_call_id(); - call.type = "function"; - - if (parse_json_content(json_content, call)) { - call.state = ParseState::COMPLETED; - result.tool_calls.push_back(call); - } else { - call.state = ParseState::ERROR; - call.error = "Failed to parse JSON content: " + json_content; - result.tool_calls.push_back(call); - } - } - - for (auto it = match_positions.rbegin(); it != match_positions.rend(); ++it) { - normal_text.erase(it->first, it->second); - } - - result.normal_text = normal_text; - return result; - } - - StreamingParseResult 
parse_streaming(const std::string& chunk) override { - buffer_ += chunk; - StreamingParseResult result; - - auto completed_result = parse_calls(buffer_); - result.completed_calls = completed_result.tool_calls; - result.normal_text = completed_result.normal_text; - - if (std::regex_search(buffer_, start_pattern_) && - !std::regex_search(buffer_, end_pattern_)) { - ToolCallItem partial; - partial.id = "partial_call"; - partial.type = "function"; - partial.state = ParseState::PARSING; - result.partial_call = partial; - } - - if (!completed_result.tool_calls.empty()) { - size_t last_end = buffer_.rfind(""); - if (last_end != std::string::npos) { - last_end += 12; // Length of "" - result.remaining_text = buffer_.substr(last_end); - buffer_ = result.remaining_text; - } - } - - return result; - } - - void reset_streaming_state() override { - buffer_.clear(); - } - - ModelFormat get_format() const override { - return ModelFormat::QWEN25; - } - - std::string get_format_name() const override { - return "qwen25"; - } - - StructureInfo get_structure_info() const override { - StructureInfo info("qwen25", "", ""); - info.patterns["function_call"] = R"(\s*\{.*?\}\s*)"; - info.patterns["json_content"] = R"(\{.*?\})"; - return info; - } - - EBNFGrammar generate_ebnf_grammar(const ConstraintOptions& options) const override { - EBNFGrammar grammar; - grammar.start_rule = "tool_calls"; - - if (options.allow_multiple_calls) { - grammar.add_rule(EBNFRule("tool_calls", "tool_call+")); - } else { - grammar.add_rule(EBNFRule("tool_calls", "tool_call")); - } - - grammar.add_rule(EBNFRule("tool_call", "\"\" ws json_object ws \"\"")); - - grammar.add_rule(EBNFRule("json_object", "\"{\" ws json_members ws \"}\"")); - grammar.add_rule(EBNFRule("json_members", "json_member (\",\" ws json_member)*")); - grammar.add_rule(EBNFRule("json_member", "json_string \":\" ws json_value")); - - grammar.add_rule(EBNFRule("json_value", "json_string | json_object | json_array")); - 
grammar.add_rule(EBNFRule("json_string", "\"\\\"\" [^\"\\\\]* \"\\\"\"")); - grammar.add_rule(EBNFRule("json_array", "\"[\" ws (json_value (\",\" ws json_value)*)? ws \"]\"")); - - grammar.add_rule(EBNFRule("ws", "[ \\t\\n\\r]*", true)); - - return grammar; - } - - bool validate_call_format(const ToolCallItem& call) const override { - if (call.function_name.empty()) return false; - if (call.function_arguments.empty()) return false; - return is_valid_json(call.function_arguments); - } + Qwen25Detector(); + virtual ~Qwen25Detector() = default; + private: - bool parse_json_content(const std::string& json_str, ToolCallItem& call) const { - try { - auto json_obj = nlohmann::json::parse(clean_json_string(json_str)); - - if (json_obj.contains("name") && json_obj["name"].is_string()) { - call.function_name = json_obj["name"].get(); - } else { - return false; - } - - if (json_obj.contains("arguments")) { - if (json_obj["arguments"].is_string()) { - call.function_arguments = json_obj["arguments"].get(); - } else { - call.function_arguments = json_obj["arguments"].dump(); - } - } else { - call.function_arguments = "{}"; - } - - return true; - } catch (const nlohmann::json::exception& e) { - call.error = "JSON parse error: " + std::string(e.what()); - return false; - } - } + std::string normal_text_buffer_; // Buffer for handling partial end tokens + +public: + + bool has_tool_call(const std::string& text) override; + + StreamingParseResult detect_and_parse(const std::string& text, const std::vector& tools) override; }; } // namespace function_call diff --git a/xllm/core/function_call/types.h b/xllm/core/function_call/types.h deleted file mode 100644 index 9b14857e..00000000 --- a/xllm/core/function_call/types.h +++ /dev/null @@ -1,160 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include - -namespace llm { -namespace function_call { - -enum class ParseState { - PENDING, - PARSING, - COMPLETED, - ERROR -}; - -struct ToolCallItem { - std::string id; - 
std::string type; - std::string function_name; - std::string function_arguments; - ParseState state = ParseState::PENDING; - std::optional error; - - ToolCallItem() = default; - ToolCallItem(const std::string& id, const std::string& type, - const std::string& name, const std::string& args) - : id(id), type(type), function_name(name), function_arguments(args) {} - - bool is_valid() const { - return state == ParseState::COMPLETED && - !function_name.empty() && - !function_arguments.empty() && - !error.has_value(); - } - - std::string to_string() const { - std::string result = "ToolCallItem{"; - result += "id='" + id + "', "; - result += "type='" + type + "', "; - result += "function_name='" + function_name + "', "; - result += "arguments='" + function_arguments + "', "; - result += "state=" + std::to_string(static_cast(state)); - if (error.has_value()) { - result += ", error='" + error.value() + "'"; - } - result += "}"; - return result; - } -}; - -struct StreamingParseResult { - std::string normal_text; - std::vector completed_calls; - std::optional partial_call; - std::string remaining_text; - bool has_error = false; - std::string error_message; - - bool has_completed_calls() const { - return !completed_calls.empty(); - } - - bool has_partial_call() const { - return partial_call.has_value(); - } - - void clear() { - completed_calls.clear(); - partial_call.reset(); - remaining_text.clear(); - has_error = false; - error_message.clear(); - } -}; -struct ParseResult { - std::string normal_text; - std::vector tool_calls; - bool has_error = false; - std::string error_message; - - bool has_tool_calls() const { - return !tool_calls.empty(); - } - - void clear() { - normal_text.clear(); - tool_calls.clear(); - has_error = false; - error_message.clear(); - } -}; - -struct StructureInfo { - std::string format_name; - std::string start_marker; - std::string end_marker; - std::unordered_map patterns; - - StructureInfo() = default; - StructureInfo(const std::string& name, const 
std::string& start, const std::string& end) - : format_name(name), start_marker(start), end_marker(end) {} -}; - -enum class ModelFormat { - QWEN25, - UNKNOWN -}; - -struct FormatDetectionResult { - ModelFormat format = ModelFormat::UNKNOWN; - double confidence = 0.0; - std::string reason; - StructureInfo structure_info; - - bool is_valid() const { - return format != ModelFormat::UNKNOWN && confidence > 0.5; - } -}; - -struct EBNFRule { - std::string name; - std::string definition; - bool is_terminal = false; - - EBNFRule() = default; - EBNFRule(const std::string& name, const std::string& def, bool terminal = false) - : name(name), definition(def), is_terminal(terminal) {} -}; - -struct EBNFGrammar { - std::vector rules; - std::string start_rule; - - void add_rule(const EBNFRule& rule) { - rules.push_back(rule); - } - - std::string to_string() const { - std::string result; - for (const auto& rule : rules) { - result += rule.name + " ::= " + rule.definition + "\n"; - } - return result; - } -}; - -struct ConstraintOptions { - bool allow_multiple_calls = true; - bool require_arguments = true; - bool strict_json = true; - std::vector allowed_functions; - - ConstraintOptions() = default; -}; - -} // namespace function_call -} // namespace llm \ No newline at end of file From 9a2bac26080cbf20a90ef43ade749c4f5c7ed190 Mon Sep 17 00:00:00 2001 From: dengyingxu1 Date: Tue, 12 Aug 2025 14:11:21 +0800 Subject: [PATCH 5/9] refactor: optimize parse_base_json and tool call flow. 
--- xllm/api_service/api_service.cpp | 5 +- xllm/api_service/chat_service_impl.cpp | 114 +++++++------- .../function_call/base_format_detector.cpp | 133 ++++++++++------ .../core/function_call/base_format_detector.h | 112 ++++++++------ .../function_call/function_call_parser.cpp | 142 +++++++++--------- .../core/function_call/function_call_parser.h | 66 ++++---- xllm/core/function_call/qwen25_detector.cpp | 142 ++++++++++++------ xllm/core/function_call/qwen25_detector.h | 11 +- xllm/core/util/uuid.cpp | 10 -- xllm/core/util/uuid.h | 4 - 10 files changed, 419 insertions(+), 320 deletions(-) diff --git a/xllm/api_service/api_service.cpp b/xllm/api_service/api_service.cpp index 6539e3af..0f087293 100644 --- a/xllm/api_service/api_service.cpp +++ b/xllm/api_service/api_service.cpp @@ -106,7 +106,10 @@ void ChatCompletionsImpl(std::unique_ptr& service, std::string attachment = std::move(ctrl->request_attachment().to_string()); std::string error; - auto json_status = google::protobuf::util::JsonStringToMessage(attachment, req_pb); + google::protobuf::util::JsonParseOptions options; + options.ignore_unknown_fields = true; + auto json_status = + google::protobuf::util::JsonStringToMessage(attachment, req_pb, options); if (!json_status.ok()) { ctrl->SetFailed(json_status.ToString()); LOG(ERROR) << "parse json to proto failed: " << json_status.ToString(); diff --git a/xllm/api_service/chat_service_impl.cpp b/xllm/api_service/chat_service_impl.cpp index b116f357..348c6c82 100644 --- a/xllm/api_service/chat_service_impl.cpp +++ b/xllm/api_service/chat_service_impl.cpp @@ -22,8 +22,8 @@ #include "chat_template/chat_template.h" #include "common/instance_name.h" #include "common/uuid.h" -#include "function_call/function_call.h" #include "function_call/core_types.h" +#include "function_call/function_call.h" #include "request/request_params.h" #include "util/utils.h" @@ -31,44 +31,42 @@ namespace xllm { namespace { -std::string generate_tool_call_id() { return "call_" + 
llm::generate_uuid(); } - struct ToolCallResult { std::optional> tool_calls; std::string text; std::string finish_reason; }; -ToolCallResult process_tool_calls( - const std::string& text, - const std::vector& tools, - const std::string& parser_format, - const std::string& finish_reason) { - +ToolCallResult process_tool_calls(std::string text, + const std::vector& tools, + const std::string& parser_format, + std::string finish_reason) { ToolCallResult result; - result.text = text; - result.finish_reason = finish_reason; function_call::FunctionCallParser parser(tools, parser_format); if (!parser.has_tool_call(text)) { + result.text = std::move(text); + result.finish_reason = std::move(finish_reason); return result; } - if (result.finish_reason == "stop") { + if (finish_reason == "stop") { result.finish_reason = "tool_calls"; + } else { + result.finish_reason = std::move(finish_reason); } try { auto [parsed_text, call_info_list] = parser.parse_non_stream(text); - result.text = parsed_text; + result.text = std::move(parsed_text); std::vector tool_calls; tool_calls.reserve(call_info_list.size()); for (const auto& call_info : call_info_list) { proto::ToolCall tool_call; - tool_call.set_id("call_" + generate_uuid()); + tool_call.set_id(function_call::utils::generate_tool_call_id()); tool_call.set_type("function"); auto* function = tool_call.mutable_function(); @@ -77,7 +75,7 @@ ToolCallResult process_tool_calls( } function->set_arguments(call_info.parameters); - tool_calls.push_back(tool_call); + tool_calls.emplace_back(tool_call); } result.tool_calls = std::move(tool_calls); @@ -88,7 +86,6 @@ ToolCallResult process_tool_calls( return result; } - void set_logprobs(proto::ChatChoice* choice, const std::optional>& logprobs) { if (!logprobs.has_value() || logprobs.value().empty()) { @@ -208,7 +205,6 @@ bool send_result_to_client_brpc(std::shared_ptr call, int64_t created_time, const std::string& model, const RequestOutput& req_output, - bool has_tools = false, const 
std::string& parser_format = "", const std::vector& tools = {}) { auto& response = call->response(); @@ -225,34 +221,30 @@ bool send_result_to_client_brpc(std::shared_ptr call, auto* message = choice->mutable_message(); message->set_role("assistant"); - auto setOutputAndFinishReason = [&]() { + auto set_output_and_finish_reason = [&]() { message->set_content(output.text); if (output.finish_reason.has_value()) { choice->set_finish_reason(output.finish_reason.value()); } }; - if (has_tools && !parser_format.empty()) { - std::string finish_reason; - if (output.finish_reason.has_value()) { - finish_reason = output.finish_reason.value(); - } + if (!tools.empty() && !parser_format.empty()) { + auto result = process_tool_calls( + output.text, tools, parser_format, output.finish_reason.value_or("")); - auto result = process_tool_calls(output.text, tools, parser_format, finish_reason); - message->set_content(result.text); - + if (result.tool_calls) { - for (const auto& tool_call : *result.tool_calls) { - *message->add_tool_calls() = tool_call; + for (auto& tool_call : *result.tool_calls) { + *message->add_tool_calls() = std::move(tool_call); } } - + if (!result.finish_reason.empty()) { choice->set_finish_reason(result.finish_reason); } } else { - setOutputAndFinishReason(); + set_output_and_finish_reason(); } } @@ -519,7 +511,6 @@ void MMChatServiceImpl::process_async(std::shared_ptr call) { first_message_sent = std::unordered_set(), request_id = request_params.request_id, created_time = absl::ToUnixSeconds(absl::Now()), - has_tools = request_params.has_tools(), proto_tools = request_params.proto_tools]( const RequestOutput& req_output) mutable -> bool { if (req_output.status.has_value()) { @@ -539,35 +530,48 @@ void MMChatServiceImpl::process_async(std::shared_ptr call) { master->get_rate_limiter()->decrease_one_request(); } - std::string parser_format = + const std::string parser_format = master->options().tool_call_parser().value_or(""); + const bool has_tool_support = + 
!proto_tools.empty() && !parser_format.empty(); + if (stream) { - if (has_tools && !parser_format.empty()) { + if (has_tool_support) { + // TODO: Support tool call streaming output LOG(ERROR) << "Tool call does not support streaming output"; + return send_delta_to_client_brpc(call, + include_usage, + &first_message_sent, + request_id, + created_time, + model, + req_output); + } else { + // Stream response without tool support + return send_delta_to_client_brpc(call, + include_usage, + &first_message_sent, + request_id, + created_time, + model, + req_output); + } + } else { + if (has_tool_support) { + // Non-stream response with tool support + return send_result_to_client_brpc(call, + request_id, + created_time, + model, + req_output, + parser_format, + proto_tools); + } else { + // Non-stream response without tool support + return send_result_to_client_brpc( + call, request_id, created_time, model, req_output); } - // send delta to client - return send_delta_to_client_brpc(call_data, - include_usage, - &first_message_sent, - request_id, - created_time, - model, - req_output); - } - - if (has_tools && !parser_format.empty()) { - return send_result_to_client_brpc(call_data, - request_id, - created_time, - model, - req_output, - has_tools, - parser_format, - proto_tools); } - - return send_result_to_client_brpc( - call, request_id, created_time, model, req_output); }); } diff --git a/xllm/core/function_call/base_format_detector.cpp b/xllm/core/function_call/base_format_detector.cpp index 483b71be..da69363a 100644 --- a/xllm/core/function_call/base_format_detector.cpp +++ b/xllm/core/function_call/base_format_detector.cpp @@ -1,66 +1,99 @@ #include "base_format_detector.h" -#include -#include + #include +#include +#include namespace llm { namespace function_call { -BaseFormatDetector::BaseFormatDetector() - : current_tool_id_(-1) - , current_tool_name_sent_(false) - , bot_token_("") - , eot_token_("") - , tool_call_separator_(", ") { -} 
+BaseFormatDetector::BaseFormatDetector() + : current_tool_id_(-1), + current_tool_name_sent_(false), + bot_token_(""), + eot_token_(""), + tool_call_separator_(", ") {} -std::unordered_map BaseFormatDetector::get_tool_indices(const std::vector& tools) { - std::unordered_map indices; - for (size_t i = 0; i < tools.size(); ++i) { - if (!tools[i].function().name().empty()) { - indices[tools[i].function().name()] = static_cast(i); - } +std::unordered_map BaseFormatDetector::get_tool_indices( + const std::vector& tools) { + std::unordered_map indices; + for (size_t i = 0; i < tools.size(); ++i) { + if (!tools[i].function().name().empty()) { + indices[tools[i].function().name()] = static_cast(i); + } else { + LOG(ERROR) << "Tool at index " << i + << " has empty function name, skipping"; } - return indices; + } + return indices; } -std::vector BaseFormatDetector::parse_base_json(const std::string& action_json, const std::vector& tools) { - auto tool_indices = get_tool_indices(tools); - std::vector results; - - // TODO: Replace with a more robust JSON library for better functionality and reliability - std::string trimmed = action_json; - trimmed.erase(0, trimmed.find_first_not_of(" \t\n\r")); - trimmed.erase(trimmed.find_last_not_of(" \t\n\r") + 1); - - if (trimmed.empty()) { - return results; +std::vector BaseFormatDetector::parse_base_json( + const nlohmann::json& json_obj, + const std::vector& tools) { + auto tool_indices = get_tool_indices(tools); + std::vector results; + + std::vector actions; + if (json_obj.is_array()) { + for (const auto& item : json_obj) { + actions.emplace_back(item); } - - std::regex name_regex("\"name\"\\s*:\\s*\"([^\"]+)\""); - std::regex args_regex("\"(?:parameters|arguments)\"\\s*:\\s*(\\{[^}]*\\})"); - - std::smatch name_match, args_match; - - if (std::regex_search(trimmed, name_match, name_regex)) { - std::string name = name_match[1].str(); - - if (tool_indices.find(name) != tool_indices.end()) { - std::string parameters = "{}"; - - if 
(std::regex_search(trimmed, args_match, args_regex)) { - parameters = args_match[1].str(); - } - - results.emplace_back(-1, name, parameters); - } else { - LOG(ERROR) << "Model attempted to call undefined function: " << name; - } + } else { + actions.emplace_back(json_obj); + } + + for (const auto& act : actions) { + if (!act.is_object()) { + LOG(ERROR) << "Invalid tool call item, expected object, got: " + << act.type_name(); + continue; + } + + std::string name; + if (act.contains("name") && act["name"].is_string()) { + name = act["name"].get(); + } else { + LOG(ERROR) << "Invalid tool call: missing 'name' field or invalid type"; + continue; } - - return results; -} + if (tool_indices.find(name) == tool_indices.end()) { + LOG(ERROR) << "Model attempted to call undefined function: " << name; + continue; + } + + nlohmann::json parameters = nlohmann::json::object(); + + if (act.contains("parameters")) { + parameters = act["parameters"]; + } else if (act.contains("arguments")) { + parameters = act["arguments"]; + } else { + LOG(ERROR) << "No parameters or arguments field found for tool: " << name; + } + + if (!parameters.is_object()) { + LOG(ERROR) << "Invalid arguments type for tool: " << name + << ", expected object, got: " << parameters.type_name(); + parameters = nlohmann::json::object(); + } + + std::string parameters_str; + try { + parameters_str = parameters.dump( + -1, ' ', false, nlohmann::json::error_handler_t::ignore); + } catch (const std::exception& e) { + LOG(ERROR) << "Failed to serialize arguments for tool: " << name + << ", error: " << e.what(); + parameters_str = "{}"; + } + + results.emplace_back(-1, name, parameters_str); + } + + return results; +} } // namespace function_call } // namespace llm \ No newline at end of file diff --git a/xllm/core/function_call/base_format_detector.h b/xllm/core/function_call/base_format_detector.h index 1f0ed28c..38c68bd9 100644 --- a/xllm/core/function_call/base_format_detector.h +++ 
b/xllm/core/function_call/base_format_detector.h @@ -1,63 +1,75 @@ #pragma once -#include "core_types.h" -#include "chat.pb.h" #include + +#include +#include #include -#include #include -#include +#include + +#include "chat.pb.h" +#include "core_types.h" namespace llm { namespace function_call { class BaseFormatDetector { -public: - BaseFormatDetector(); - virtual ~BaseFormatDetector() = default; - - BaseFormatDetector(const BaseFormatDetector&) = delete; - BaseFormatDetector& operator=(const BaseFormatDetector&) = delete; - -protected: - // Streaming state management - // Buffer for accumulating incomplete patterns that arrive across multiple streaming chunks - std::string buffer_; - - // Stores complete tool call info (name and arguments) for each tool being parsed. - // Used by serving layer for completion handling when streaming ends. - // Format: [{"name": str, "arguments": dict}, ...] - std::vector> prev_tool_call_arr_; - - // Index of currently streaming tool call. Starts at -1 (no active tool), - // increments as each tool completes. Tracks which tool's arguments are streaming. - int current_tool_id_; - - // Flag for whether current tool's name has been sent to client. - // Tool names sent first with empty parameters, then arguments stream incrementally. - bool current_tool_name_sent_; - - // Tracks raw JSON string content streamed to client for each tool's arguments. - // Critical for serving layer to calculate remaining content when streaming ends. - // Each index corresponds to a tool_id. 
Example: ['{"location": "San Francisco"', '{"temp": 72'] - std::vector streamed_args_for_tool_; - - // Token configuration (override in subclasses) - std::string bot_token_; - std::string eot_token_; - std::string tool_call_separator_; - - // Tool indices cache - std::unordered_map tool_indices_; - -public: - std::unordered_map get_tool_indices(const std::vector& tools); - - std::vector parse_base_json(const std::string& action_json, const std::vector& tools); - - virtual StreamingParseResult detect_and_parse(const std::string& text, const std::vector& tools) = 0; - - virtual bool has_tool_call(const std::string& text) = 0; + public: + BaseFormatDetector(); + virtual ~BaseFormatDetector() = default; + + BaseFormatDetector(const BaseFormatDetector&) = delete; + BaseFormatDetector& operator=(const BaseFormatDetector&) = delete; + + protected: + // Streaming state management + // Buffer for accumulating incomplete patterns that arrive across multiple + // streaming chunks + std::string buffer_; + + // Stores complete tool call info (name and arguments) for each tool being + // parsed. Used by serving layer for completion handling when streaming ends. + // Format: [{"name": str, "arguments": dict}, ...] + std::vector> prev_tool_call_arr_; + + // Index of currently streaming tool call. Starts at -1 (no active tool), + // increments as each tool completes. Tracks which tool's arguments are + // streaming. + int current_tool_id_; + + // Flag for whether current tool's name has been sent to client. + // Tool names sent first with empty parameters, then arguments stream + // incrementally. + bool current_tool_name_sent_; + + // Tracks raw JSON string content streamed to client for each tool's + // arguments. Critical for serving layer to calculate remaining content when + // streaming ends. Each index corresponds to a tool_id. 
Example: + // ['{"location": "San Francisco"', '{"temp": 72'] + std::vector streamed_args_for_tool_; + + // Token configuration (override in subclasses) + std::string bot_token_; + std::string eot_token_; + std::string tool_call_separator_; + + // Tool indices cache + std::unordered_map tool_indices_; + + public: + std::unordered_map get_tool_indices( + const std::vector& tools); + + std::vector parse_base_json( + const nlohmann::json& json_obj, + const std::vector& tools); + + virtual StreamingParseResult detect_and_parse( + const std::string& text, + const std::vector& tools) = 0; + + virtual bool has_tool_call(const std::string& text) = 0; }; } // namespace function_call diff --git a/xllm/core/function_call/function_call_parser.cpp b/xllm/core/function_call/function_call_parser.cpp index 4e939fcb..eeb43b15 100644 --- a/xllm/core/function_call/function_call_parser.cpp +++ b/xllm/core/function_call/function_call_parser.cpp @@ -1,99 +1,105 @@ #include "function_call_parser.h" -#include + #include +#include +#include "common/uuid.h" +#include "qwen25_detector.h" namespace llm { namespace function_call { -const std::unordered_map FunctionCallParser::ToolCallParserEnum = { - {"qwen25", "qwen25"}, - {"qwen3", "qwen25"}, - // TODO - // {"llama3", "llama3"}, - // {"mistral", "mistral"}, - // {"deepseekv3", "deepseekv3"}, - // {"pythonic", "pythonic"}, - // {"kimi_k2", "kimi_k2"}, - // {"qwen3_coder", "qwen3_coder"}, - // {"glm45", "glm45"}, - // {"step3", "step3"}, +const std::unordered_map + FunctionCallParser::ToolCallParserEnum = { + {"qwen25", "qwen25"}, + {"qwen3", "qwen25"}, + // TODO + // {"llama3", "llama3"}, + // {"mistral", "mistral"}, + // {"deepseekv3", "deepseekv3"}, + // {"pythonic", "pythonic"}, + // {"qwen3_coder", "qwen3_coder"}, + // {"glm45", "glm45"}, + // {"step3", "step3"}, }; -FunctionCallParser::FunctionCallParser(const std::vector& tools, const std::string& tool_call_parser) +FunctionCallParser::FunctionCallParser(const std::vector& tools, + 
const std::string& tool_call_parser) : tools_(tools) { - - detector_ = create_detector(tool_call_parser); - if (!detector_) { - throw std::invalid_argument("Unsupported tool_call_parser: " + tool_call_parser); - } + detector_ = create_detector(tool_call_parser); + if (!detector_) { + throw std::invalid_argument("Unsupported tool_call_parser: " + + tool_call_parser); + } } bool FunctionCallParser::has_tool_call(const std::string& text) const { - return detector_->has_tool_call(text); + return detector_->has_tool_call(text); } -std::tuple> FunctionCallParser::parse_non_stream(const std::string& full_text) { - StreamingParseResult parsed_result = detector_->detect_and_parse(full_text, tools_); - - if (!parsed_result.calls.empty()) { - return std::make_tuple(parsed_result.normal_text, parsed_result.calls); - } else { - return std::make_tuple(full_text, std::vector()); - } -} +std::tuple> +FunctionCallParser::parse_non_stream(const std::string& full_text) { + StreamingParseResult parsed_result = + detector_->detect_and_parse(full_text, tools_); + if (!parsed_result.calls.empty()) { + return std::make_tuple(parsed_result.normal_text, parsed_result.calls); + } else { + return std::make_tuple(full_text, std::vector()); + } +} -std::unique_ptr FunctionCallParser::create_detector(const std::string& tool_call_parser) { - auto it = ToolCallParserEnum.find(tool_call_parser); - if (it == ToolCallParserEnum.end()) { - return nullptr; - } - - if (it->second == "qwen25") { - return std::make_unique(); - } - - // if (tool_call_parser == "llama3") { - // return std::make_unique(); - // } - // if (tool_call_parser == "mistral") { - // return std::make_unique(); - // } - +std::unique_ptr FunctionCallParser::create_detector( + const std::string& tool_call_parser) { + auto it = ToolCallParserEnum.find(tool_call_parser); + if (it == ToolCallParserEnum.end()) { return nullptr; + } + + if (it->second == "qwen25") { + return std::make_unique(); + } + + // if (tool_call_parser == "llama3") { 
+ // return std::make_unique(); + // } + // if (tool_call_parser == "mistral") { + // return std::make_unique(); + // } + + return nullptr; } namespace utils { std::vector parse_function_calls( - const std::string& text, + const std::string& text, const std::vector& tools, const std::string& parser_type) { - - try { - FunctionCallParser parser(tools, parser_type); - auto [normal_text, calls] = parser.parse_non_stream(text); - return calls; - } catch (const std::exception& e) { - LOG(ERROR) << "Error parsing function calls: " << e.what(); - return {}; - } + try { + FunctionCallParser parser(tools, parser_type); + auto [normal_text, calls] = parser.parse_non_stream(text); + return calls; + } catch (const std::exception& e) { + LOG(ERROR) << "Error parsing function calls: " << e.what(); + return {}; + } } - -bool has_function_calls( - const std::string& text, - const std::string& parser_type) { - - try { - FunctionCallParser parser({}, parser_type); - return parser.has_tool_call(text); - } catch (const std::exception& e) { - LOG(ERROR) << "Error checking function calls: " << e.what(); - return false; - } +bool has_function_calls(const std::string& text, + const std::string& parser_type) { + try { + FunctionCallParser parser({}, parser_type); + return parser.has_tool_call(text); + } catch (const std::exception& e) { + LOG(ERROR) << "Error checking function calls: " << e.what(); + return false; + } } +thread_local ShortUUID short_uuid; + +std::string generate_tool_call_id() { return "call_" + short_uuid.random(); } + } // namespace utils } // namespace function_call diff --git a/xllm/core/function_call/function_call_parser.h b/xllm/core/function_call/function_call_parser.h index 0855b6a4..8a59864c 100644 --- a/xllm/core/function_call/function_call_parser.h +++ b/xllm/core/function_call/function_call_parser.h @@ -1,62 +1,62 @@ #pragma once -#include "core_types.h" -#include "base_format_detector.h" -#include "qwen25_detector.h" -#include -#include #include -#include 
+#include #include +#include +#include + +#include "base_format_detector.h" +#include "core_types.h" namespace llm { namespace function_call { class FunctionCallParser { -public: - static const std::unordered_map ToolCallParserEnum; + public: + static const std::unordered_map ToolCallParserEnum; + + private: + std::unique_ptr detector_; + std::vector tools_; -private: - std::unique_ptr detector_; - std::vector tools_; + public: + FunctionCallParser(const std::vector& tools, + const std::string& tool_call_parser); -public: + ~FunctionCallParser() = default; - FunctionCallParser(const std::vector& tools, const std::string& tool_call_parser); - - ~FunctionCallParser() = default; - - FunctionCallParser(const FunctionCallParser&) = delete; - FunctionCallParser& operator=(const FunctionCallParser&) = delete; + FunctionCallParser(const FunctionCallParser&) = delete; + FunctionCallParser& operator=(const FunctionCallParser&) = delete; - bool has_tool_call(const std::string& text) const; + bool has_tool_call(const std::string& text) const; - std::tuple> parse_non_stream(const std::string& full_text); + std::tuple> parse_non_stream( + const std::string& full_text); - // StructuralTagResponseFormat get_structure_tag(); + // StructuralTagResponseFormat get_structure_tag(); - // std::tuple get_structure_constraint(const std::string& tool_choice); + // std::tuple get_structure_constraint(const + // std::string& tool_choice); - BaseFormatDetector* get_detector() const { return detector_.get(); } + BaseFormatDetector* get_detector() const { return detector_.get(); } -private: - std::unique_ptr create_detector(const std::string& tool_call_parser); + private: + std::unique_ptr create_detector( + const std::string& tool_call_parser); }; namespace utils { std::vector parse_function_calls( - const std::string& text, - const std::vector& tools, - const std::string& parser_type = "qwen25" -); - -bool has_function_calls( const std::string& text, - const std::string& parser_type = "qwen25" 
-); + const std::vector& tools, + const std::string& parser_type = "qwen25"); +bool has_function_calls(const std::string& text, + const std::string& parser_type = "qwen25"); +std::string generate_tool_call_id(); } // namespace utils } // namespace function_call diff --git a/xllm/core/function_call/qwen25_detector.cpp b/xllm/core/function_call/qwen25_detector.cpp index aff593ad..1dc0cb7c 100644 --- a/xllm/core/function_call/qwen25_detector.cpp +++ b/xllm/core/function_call/qwen25_detector.cpp @@ -1,67 +1,115 @@ #include "qwen25_detector.h" -#include + +#include #include +#include namespace llm { namespace function_call { Qwen25Detector::Qwen25Detector() : BaseFormatDetector() { - bot_token_ = "\n"; - eot_token_ = "\n"; - tool_call_separator_ = "\n"; + bot_token_ = "\n"; + eot_token_ = "\n"; + tool_call_separator_ = "\n"; + + std::string pattern = bot_token_ + "([\\s\\S]*?)" + eot_token_; + tool_call_regex_ = std::regex( + pattern, + std::regex_constants::ECMAScript | std::regex_constants::optimize); } bool Qwen25Detector::has_tool_call(const std::string& text) { - return text.find(bot_token_) != std::string::npos; + return text.find(bot_token_) != std::string::npos; } -StreamingParseResult Qwen25Detector::detect_and_parse(const std::string& text, const std::vector& tools) { - size_t idx = text.find(bot_token_); - std::string normal_text = (idx != std::string::npos) ? 
text.substr(0, idx) : text; - - while (!normal_text.empty() && std::isspace(normal_text.back())) { - normal_text.pop_back(); +std::string_view Qwen25Detector::trim_whitespace(std::string_view str) const { + const char* whitespace = " \t\n\r"; + + size_t start = str.find_first_not_of(whitespace); + if (start == std::string_view::npos) { + return std::string_view{}; + } + + size_t end = str.find_last_not_of(whitespace); + + return str.substr(start, end - start + 1); +} + +std::vector> Qwen25Detector::find_tool_call_ranges( + const std::string& text) const { + std::vector> ranges; + ranges.reserve(4); + + size_t search_pos = 0; + const size_t bot_token_len = bot_token_.length(); + const size_t eot_token_len = eot_token_.length(); + + while (search_pos < text.length()) { + size_t start_pos = text.find(bot_token_, search_pos); + if (start_pos == std::string::npos) { + break; } - - if (text.find(bot_token_) == std::string::npos) { - return StreamingParseResult(normal_text); + + size_t content_start = start_pos + bot_token_len; + size_t end_pos = text.find(eot_token_, content_start); + if (end_pos == std::string::npos) { + break; } - std::string escaped_bot_token = bot_token_; - std::string escaped_eot_token = eot_token_; - - std::string pattern = escaped_bot_token + "(.*?)" + escaped_eot_token; - - size_t pos = 0; - while ((pos = pattern.find("\n", pos)) != std::string::npos) { - pattern.replace(pos, 1, "\\n"); - pos += 2; + ranges.emplace_back(content_start, end_pos); + search_pos = end_pos + eot_token_len; + } + + return ranges; +} + +StreamingParseResult Qwen25Detector::detect_and_parse( + const std::string& text, + const std::vector& tools) { + size_t bot_token_pos = text.find(bot_token_); + + std::string normal_text; + if (bot_token_pos != std::string::npos) { + std::string_view normal_text_view(text.data(), bot_token_pos); + std::string_view trimmed = trim_whitespace(normal_text_view); + normal_text = std::string(trimmed); + } else { + std::string_view trimmed = 
trim_whitespace(text); + normal_text = std::string(trimmed); + return StreamingParseResult(normal_text); + } + + auto tool_call_ranges = find_tool_call_ranges(text); + + std::vector calls; + calls.reserve(tool_call_ranges.size()); + + for (const auto& range : tool_call_ranges) { + std::string_view content_view(text.data() + range.first, + range.second - range.first); + std::string_view trimmed_content = trim_whitespace(content_view); + + if (trimmed_content.empty()) { + continue; } - - std::regex tool_call_regex(pattern, std::regex_constants::ECMAScript); - std::sregex_iterator iter(text.begin(), text.end(), tool_call_regex); - std::sregex_iterator end; - - std::vector calls; - - for (; iter != end; ++iter) { - std::smatch match = *iter; - std::string match_result = match[1].str(); - - try { - match_result.erase(0, match_result.find_first_not_of(" \t\n\r")); - match_result.erase(match_result.find_last_not_of(" \t\n\r") + 1); - - auto parsed_calls = parse_base_json(match_result, tools); - calls.insert(calls.end(), parsed_calls.begin(), parsed_calls.end()); - } catch (const std::exception& e) { - LOG(ERROR) << "Failed to parse JSON part: " << match_result - << ", JSON parse error: " << e.what(); - continue; - } + + try { + std::string json_content(trimmed_content); + auto json_obj = nlohmann::json::parse(json_content); + auto parsed_calls = parse_base_json(json_obj, tools); + + calls.insert(calls.end(), + std::make_move_iterator(parsed_calls.begin()), + std::make_move_iterator(parsed_calls.end())); + } catch (const std::exception& e) { + LOG(ERROR) << "Failed to parse JSON part: " + << std::string(trimmed_content) + << ", JSON parse error: " << e.what(); + continue; } - - return StreamingParseResult(normal_text, calls); + } + + return StreamingParseResult(std::move(normal_text), std::move(calls)); } } // namespace function_call diff --git a/xllm/core/function_call/qwen25_detector.h b/xllm/core/function_call/qwen25_detector.h index 010d4b46..ef61a6cd 100644 --- 
a/xllm/core/function_call/qwen25_detector.h +++ b/xllm/core/function_call/qwen25_detector.h @@ -2,6 +2,8 @@ #include "base_format_detector.h" #include +#include +#include namespace llm { namespace function_call { @@ -13,10 +15,15 @@ class Qwen25Detector : public BaseFormatDetector { virtual ~Qwen25Detector() = default; private: - std::string normal_text_buffer_; // Buffer for handling partial end tokens + std::string normal_text_buffer_; + + std::regex tool_call_regex_; + + std::string_view trim_whitespace(std::string_view str) const; + + std::vector> find_tool_call_ranges(const std::string& text) const; public: - bool has_tool_call(const std::string& text) override; StreamingParseResult detect_and_parse(const std::string& text, const std::vector& tools) override; diff --git a/xllm/core/util/uuid.cpp b/xllm/core/util/uuid.cpp index c86fa969..18c917fb 100644 --- a/xllm/core/util/uuid.cpp +++ b/xllm/core/util/uuid.cpp @@ -18,14 +18,4 @@ std::string ShortUUID::random(size_t len) { return uuid; } -std::string generate_uuid(size_t len) { - static thread_local ShortUUID uuid_generator; - return uuid_generator.random(len); -} - -std::string generate_uuid(size_t len) { - static thread_local ShortUUID uuid_generator; - return uuid_generator.random(len); -} - } // namespace xllm \ No newline at end of file diff --git a/xllm/core/util/uuid.h b/xllm/core/util/uuid.h index 42c647c5..0c2604be 100644 --- a/xllm/core/util/uuid.h +++ b/xllm/core/util/uuid.h @@ -18,8 +18,4 @@ class ShortUUID { absl::BitGen gen_; }; -std::string generate_uuid(size_t len = 22); - -std::string generate_uuid(size_t len = 22); - } // namespace xllm \ No newline at end of file From 18d03ddd191e589558b5b90fd8040e4cd41f536c Mon Sep 17 00:00:00 2001 From: dengyingxu1 Date: Wed, 13 Aug 2025 21:53:21 +0800 Subject: [PATCH 6/9] refactor: parse proto::Tool during RequestParams construction. 
--- xllm/api_service/chat_service_impl.cpp | 42 +++---- .../core/framework/request/request_params.cpp | 79 +++++++++++++ xllm/core/framework/request/request_params.h | 18 ++- .../function_call/base_format_detector.cpp | 8 +- .../core/function_call/base_format_detector.h | 9 +- xllm/core/function_call/core_types.h | 105 +++++++++++------- xllm/core/function_call/function_call.h | 21 ++-- .../function_call/function_call_parser.cpp | 4 +- .../core/function_call/function_call_parser.h | 6 +- xllm/core/function_call/qwen25_detector.cpp | 2 +- xllm/core/function_call/qwen25_detector.h | 44 ++++---- xllm/core/runtime/llm_master.cpp | 2 +- 12 files changed, 220 insertions(+), 120 deletions(-) diff --git a/xllm/api_service/chat_service_impl.cpp b/xllm/api_service/chat_service_impl.cpp index 348c6c82..68cbe85f 100644 --- a/xllm/api_service/chat_service_impl.cpp +++ b/xllm/api_service/chat_service_impl.cpp @@ -37,10 +37,11 @@ struct ToolCallResult { std::string finish_reason; }; -ToolCallResult process_tool_calls(std::string text, - const std::vector& tools, - const std::string& parser_format, - std::string finish_reason) { +ToolCallResult process_tool_calls( + std::string text, + const std::vector& tools, + const std::string& parser_format, + std::string finish_reason) { ToolCallResult result; function_call::FunctionCallParser parser(tools, parser_format); @@ -200,13 +201,14 @@ bool send_delta_to_client_brpc(std::shared_ptr call, } template -bool send_result_to_client_brpc(std::shared_ptr call, - const std::string& request_id, - int64_t created_time, - const std::string& model, - const RequestOutput& req_output, - const std::string& parser_format = "", - const std::vector& tools = {}) { +bool send_result_to_client_brpc( + std::shared_ptr call, + const std::string& request_id, + int64_t created_time, + const std::string& model, + const RequestOutput& req_output, + const std::string& parser_format = "", + const std::vector& tools = {}) { auto& response = call->response(); 
response.set_object("chat.completion"); response.set_id(request_id); @@ -486,18 +488,6 @@ void MMChatServiceImpl::process_async(std::shared_ptr call) { include_usage = rpc_request.stream_options().include_usage(); } - if (rpc_request.tools_size() > 0) { - request_params.proto_tools.assign(rpc_request.tools().begin(), - rpc_request.tools().end()); - - // TODO: Implement support for 'required' option in tool_choice. - if (rpc_request.has_tool_choice()) { - request_params.tool_choice = rpc_request.tool_choice(); - } else { - request_params.tool_choice = "auto"; - } - } - // schedule the request master_->handle_request( std::move(messages), @@ -511,7 +501,7 @@ void MMChatServiceImpl::process_async(std::shared_ptr call) { first_message_sent = std::unordered_set(), request_id = request_params.request_id, created_time = absl::ToUnixSeconds(absl::Now()), - proto_tools = request_params.proto_tools]( + json_tools = request_params.tools]( const RequestOutput& req_output) mutable -> bool { if (req_output.status.has_value()) { const auto& status = req_output.status.value(); @@ -533,7 +523,7 @@ void MMChatServiceImpl::process_async(std::shared_ptr call) { const std::string parser_format = master->options().tool_call_parser().value_or(""); const bool has_tool_support = - !proto_tools.empty() && !parser_format.empty(); + !json_tools.empty() && !parser_format.empty(); if (stream) { if (has_tool_support) { @@ -565,7 +555,7 @@ void MMChatServiceImpl::process_async(std::shared_ptr call) { model, req_output, parser_format, - proto_tools); + json_tools); } else { // Non-stream response without tool support return send_result_to_client_brpc( diff --git a/xllm/core/framework/request/request_params.cpp b/xllm/core/framework/request/request_params.cpp index 75211217..6941cd00 100644 --- a/xllm/core/framework/request/request_params.cpp +++ b/xllm/core/framework/request/request_params.cpp @@ -25,6 +25,74 @@ std::string generate_chat_request_id() { } // namespace +nlohmann::json 
RequestParams::proto_value_to_json( + const google::protobuf::Value& pb_value) { + switch (pb_value.kind_case()) { + case google::protobuf::Value::kNullValue: + return nlohmann::json(nullptr); + + case google::protobuf::Value::kNumberValue: + return nlohmann::json(pb_value.number_value()); + + case google::protobuf::Value::kStringValue: + return nlohmann::json(pb_value.string_value()); + + case google::protobuf::Value::kBoolValue: + return nlohmann::json(pb_value.bool_value()); + + case google::protobuf::Value::kStructValue: + return proto_struct_to_json(pb_value.struct_value()); + + case google::protobuf::Value::kListValue: { + nlohmann::json array = nlohmann::json::array(); + const auto& list = pb_value.list_value(); + for (const auto& item : list.values()) { + array.push_back(proto_value_to_json(item)); + } + return array; + } + + case google::protobuf::Value::KIND_NOT_SET: + default: + return nlohmann::json(nullptr); + } +} + +nlohmann::json RequestParams::proto_struct_to_json( + const google::protobuf::Struct& pb_struct) { + nlohmann::json result = nlohmann::json::object(); + + for (const auto& field : pb_struct.fields()) { + result[field.first] = proto_value_to_json(field.second); + } + + return result; +} + +void RequestParams::parse_tools_from_proto( + const google::protobuf::RepeatedPtrField& proto_tools) { + tools.clear(); + tools.reserve(proto_tools.size()); + + for (const auto& proto_tool : proto_tools) { + function_call::JsonTool json_tool; + json_tool.type = proto_tool.type(); + + const auto& proto_function = proto_tool.function(); + json_tool.function.name = proto_function.name(); + json_tool.function.description = proto_function.description(); + + if (proto_function.has_parameters()) { + json_tool.function.parameters = + proto_struct_to_json(proto_function.parameters()); + } else { + json_tool.function.parameters = nlohmann::json::object(); + } + + tools.emplace_back(std::move(json_tool)); + } +} + RequestParams::RequestParams(const 
proto::CompletionRequest& request, const std::string& x_rid, const std::string& x_rtime) { @@ -157,6 +225,17 @@ void InitFromChatRequest(RequestParams& params, const ChatRequest& request) { params.streaming = false; } } + + // Parse tools from proto request + if (request.tools_size() > 0) { + parse_tools_from_proto(request.tools()); + + if (request.has_tool_choice()) { + tool_choice = request.tool_choice(); + } else { + tool_choice = "auto"; + } + } } } // namespace diff --git a/xllm/core/framework/request/request_params.h b/xllm/core/framework/request/request_params.h index 4fac5def..e2cea828 100644 --- a/xllm/core/framework/request/request_params.h +++ b/xllm/core/framework/request/request_params.h @@ -1,5 +1,6 @@ #pragma once #include +#include #include #include #include @@ -9,6 +10,7 @@ #include "completion.pb.h" #include "core/common/macros.h" #include "embedding.pb.h" +#include "function_call/core_types.h" #include "multimodal.pb.h" #include "request_output.h" @@ -102,12 +104,18 @@ struct RequestParams { // decode address. std::string decode_address; - // decode address. 
- std::string decode_address; - - std::vector proto_tools; + // JSON-based tools (replacing proto_tools) + std::vector tools; std::string tool_choice = "auto"; - bool has_tools() const { return !proto_tools.empty(); } + bool has_tools() const { return !tools.empty(); } + + private: + void parse_tools_from_proto( + const google::protobuf::RepeatedPtrField& proto_tools); + + nlohmann::json proto_struct_to_json( + const google::protobuf::Struct& pb_struct); + nlohmann::json proto_value_to_json(const google::protobuf::Value& pb_value); }; } // namespace xllm diff --git a/xllm/core/function_call/base_format_detector.cpp b/xllm/core/function_call/base_format_detector.cpp index da69363a..b4ccb3b3 100644 --- a/xllm/core/function_call/base_format_detector.cpp +++ b/xllm/core/function_call/base_format_detector.cpp @@ -15,11 +15,11 @@ BaseFormatDetector::BaseFormatDetector() tool_call_separator_(", ") {} std::unordered_map BaseFormatDetector::get_tool_indices( - const std::vector& tools) { + const std::vector& tools) { std::unordered_map indices; for (size_t i = 0; i < tools.size(); ++i) { - if (!tools[i].function().name().empty()) { - indices[tools[i].function().name()] = static_cast(i); + if (!tools[i].function.name.empty()) { + indices[tools[i].function.name] = static_cast(i); } else { LOG(ERROR) << "Tool at index " << i << " has empty function name, skipping"; @@ -30,7 +30,7 @@ std::unordered_map BaseFormatDetector::get_tool_indices( std::vector BaseFormatDetector::parse_base_json( const nlohmann::json& json_obj, - const std::vector& tools) { + const std::vector& tools) { auto tool_indices = get_tool_indices(tools); std::vector results; diff --git a/xllm/core/function_call/base_format_detector.h b/xllm/core/function_call/base_format_detector.h index 38c68bd9..7f677f7d 100644 --- a/xllm/core/function_call/base_format_detector.h +++ b/xllm/core/function_call/base_format_detector.h @@ -59,15 +59,14 @@ class BaseFormatDetector { public: std::unordered_map get_tool_indices( - 
const std::vector& tools); + const std::vector& tools); - std::vector parse_base_json( - const nlohmann::json& json_obj, - const std::vector& tools); + std::vector parse_base_json(const nlohmann::json& json_obj, + const std::vector& tools); virtual StreamingParseResult detect_and_parse( const std::string& text, - const std::vector& tools) = 0; + const std::vector& tools) = 0; virtual bool has_tool_call(const std::string& text) = 0; }; diff --git a/xllm/core/function_call/core_types.h b/xllm/core/function_call/core_types.h index 165fd615..d931fa42 100644 --- a/xllm/core/function_call/core_types.h +++ b/xllm/core/function_call/core_types.h @@ -1,60 +1,83 @@ #pragma once +#include +#include +#include #include #include -#include -#include -#include "chat.pb.h" namespace llm { namespace function_call { +struct JsonFunction { + std::string name; + std::string description; + nlohmann::json parameters; + + JsonFunction() = default; + JsonFunction(const std::string& func_name, + const std::string& desc, + const nlohmann::json& params) + : name(func_name), description(desc), parameters(params) {} +}; + +struct JsonTool { + std::string type; // "function" + JsonFunction function; + + JsonTool() : type("function") {} + JsonTool(const std::string& tool_type, const JsonFunction& func) + : type(tool_type), function(func) {} +}; + struct ToolCallItem { - int tool_index; - std::optional name; - std::string parameters; // JSON string - - ToolCallItem() : tool_index(-1), parameters("") {} - - ToolCallItem(int index, const std::optional& func_name, const std::string& params) - : tool_index(index), name(func_name), parameters(params) {} + int tool_index; + std::optional name; + std::string parameters; // JSON string + + ToolCallItem() : tool_index(-1), parameters("") {} + + ToolCallItem(int index, + const std::optional& func_name, + const std::string& params) + : tool_index(index), name(func_name), parameters(params) {} }; struct StreamingParseResult { - std::string normal_text; - 
std::vector calls; - - StreamingParseResult() = default; - - StreamingParseResult(const std::string& text) : normal_text(text) {} - - StreamingParseResult(const std::vector& tool_calls) : calls(tool_calls) {} - - StreamingParseResult(const std::string& text, const std::vector& tool_calls) - : normal_text(text), calls(tool_calls) {} - - bool has_calls() const { - return !calls.empty(); - } - - void clear() { - normal_text.clear(); - calls.clear(); - } -}; + std::string normal_text; + std::vector calls; + StreamingParseResult() = default; -struct StructureInfo { - std::string begin; - std::string end; - std::string trigger; - - StructureInfo() = default; - - StructureInfo(const std::string& begin_str, const std::string& end_str, const std::string& trigger_str) - : begin(begin_str), end(end_str), trigger(trigger_str) {} + StreamingParseResult(const std::string& text) : normal_text(text) {} + + StreamingParseResult(const std::vector& tool_calls) + : calls(tool_calls) {} + + StreamingParseResult(const std::string& text, + const std::vector& tool_calls) + : normal_text(text), calls(tool_calls) {} + + bool has_calls() const { return !calls.empty(); } + + void clear() { + normal_text.clear(); + calls.clear(); + } }; +struct StructureInfo { + std::string begin; + std::string end; + std::string trigger; + + StructureInfo() = default; + + StructureInfo(const std::string& begin_str, + const std::string& end_str, + const std::string& trigger_str) + : begin(begin_str), end(end_str), trigger(trigger_str) {} +}; using GetInfoFunc = std::function; diff --git a/xllm/core/function_call/function_call.h b/xllm/core/function_call/function_call.h index a4781ab3..70aac8bb 100644 --- a/xllm/core/function_call/function_call.h +++ b/xllm/core/function_call/function_call.h @@ -1,9 +1,9 @@ #pragma once -#include "core_types.h" #include "base_format_detector.h" -#include "qwen25_detector.h" +#include "core_types.h" #include "function_call_parser.h" +#include "qwen25_detector.h" namespace llm { 
namespace function_call { @@ -12,19 +12,16 @@ using Parser = FunctionCallParser; using Detector = BaseFormatDetector; using QwenDetector = Qwen25Detector; -inline std::vector parse( - const std::string& text, - const std::vector& tools, - const std::string& format = "qwen25") { - return utils::parse_function_calls(text, tools, format); +inline std::vector parse(const std::string& text, + const std::vector& tools, + const std::string& format = "qwen25") { + return utils::parse_function_calls(text, tools, format); } -inline bool has_calls( - const std::string& text, - const std::string& format = "qwen25") { - return utils::has_function_calls(text, format); +inline bool has_calls(const std::string& text, + const std::string& format = "qwen25") { + return utils::has_function_calls(text, format); } - } // namespace function_call } // namespace llm \ No newline at end of file diff --git a/xllm/core/function_call/function_call_parser.cpp b/xllm/core/function_call/function_call_parser.cpp index eeb43b15..87d8b4c3 100644 --- a/xllm/core/function_call/function_call_parser.cpp +++ b/xllm/core/function_call/function_call_parser.cpp @@ -22,7 +22,7 @@ const std::unordered_map // {"step3", "step3"}, }; -FunctionCallParser::FunctionCallParser(const std::vector& tools, +FunctionCallParser::FunctionCallParser(const std::vector& tools, const std::string& tool_call_parser) : tools_(tools) { detector_ = create_detector(tool_call_parser); @@ -73,7 +73,7 @@ namespace utils { std::vector parse_function_calls( const std::string& text, - const std::vector& tools, + const std::vector& tools, const std::string& parser_type) { try { FunctionCallParser parser(tools, parser_type); diff --git a/xllm/core/function_call/function_call_parser.h b/xllm/core/function_call/function_call_parser.h index 8a59864c..94a36933 100644 --- a/xllm/core/function_call/function_call_parser.h +++ b/xllm/core/function_call/function_call_parser.h @@ -18,10 +18,10 @@ class FunctionCallParser { private: std::unique_ptr 
detector_; - std::vector tools_; + std::vector tools_; public: - FunctionCallParser(const std::vector& tools, + FunctionCallParser(const std::vector& tools, const std::string& tool_call_parser); ~FunctionCallParser() = default; @@ -50,7 +50,7 @@ namespace utils { std::vector parse_function_calls( const std::string& text, - const std::vector& tools, + const std::vector& tools, const std::string& parser_type = "qwen25"); bool has_function_calls(const std::string& text, diff --git a/xllm/core/function_call/qwen25_detector.cpp b/xllm/core/function_call/qwen25_detector.cpp index 1dc0cb7c..7e008c6b 100644 --- a/xllm/core/function_call/qwen25_detector.cpp +++ b/xllm/core/function_call/qwen25_detector.cpp @@ -65,7 +65,7 @@ std::vector> Qwen25Detector::find_tool_call_ranges( StreamingParseResult Qwen25Detector::detect_and_parse( const std::string& text, - const std::vector& tools) { + const std::vector& tools) { size_t bot_token_pos = text.find(bot_token_); std::string normal_text; diff --git a/xllm/core/function_call/qwen25_detector.h b/xllm/core/function_call/qwen25_detector.h index ef61a6cd..ac78ec6e 100644 --- a/xllm/core/function_call/qwen25_detector.h +++ b/xllm/core/function_call/qwen25_detector.h @@ -1,32 +1,36 @@ #pragma once -#include "base_format_detector.h" +#include #include #include -#include + +#include "base_format_detector.h" namespace llm { namespace function_call { class Qwen25Detector : public BaseFormatDetector { -public: - Qwen25Detector(); - - virtual ~Qwen25Detector() = default; - -private: - std::string normal_text_buffer_; - - std::regex tool_call_regex_; - - std::string_view trim_whitespace(std::string_view str) const; - - std::vector> find_tool_call_ranges(const std::string& text) const; - -public: - bool has_tool_call(const std::string& text) override; - - StreamingParseResult detect_and_parse(const std::string& text, const std::vector& tools) override; + public: + Qwen25Detector(); + + virtual ~Qwen25Detector() = default; + + private: + 
std::string normal_text_buffer_; + + std::regex tool_call_regex_; + + std::string_view trim_whitespace(std::string_view str) const; + + std::vector> find_tool_call_ranges( + const std::string& text) const; + + public: + bool has_tool_call(const std::string& text) override; + + StreamingParseResult detect_and_parse( + const std::string& text, + const std::vector& tools) override; }; } // namespace function_call diff --git a/xllm/core/runtime/llm_master.cpp b/xllm/core/runtime/llm_master.cpp index b42f3953..183ea5c7 100644 --- a/xllm/core/runtime/llm_master.cpp +++ b/xllm/core/runtime/llm_master.cpp @@ -424,7 +424,7 @@ std::shared_ptr LLMMaster::generate_request( Timer timer; std::optional prompt; if (sp.has_tools()) { - prompt = chat_template_->apply(messages, sp.proto_tools); + prompt = chat_template_->apply(messages, sp.tools); } else { prompt = chat_template_->apply(messages); } From b07906cb995306cdbe3bcdb809dfa6422df2e590 Mon Sep 17 00:00:00 2001 From: dengyingxu1 Date: Wed, 13 Aug 2025 23:29:43 +0800 Subject: [PATCH 7/9] perf: optimize result handling with arena allocation and swap. 
--- xllm/api_service/chat_service_impl.cpp | 38 +++++++++++++++----------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/xllm/api_service/chat_service_impl.cpp b/xllm/api_service/chat_service_impl.cpp index 68cbe85f..7b764865 100644 --- a/xllm/api_service/chat_service_impl.cpp +++ b/xllm/api_service/chat_service_impl.cpp @@ -32,7 +32,7 @@ namespace { struct ToolCallResult { - std::optional> tool_calls; + std::optional> tool_calls; std::string text; std::string finish_reason; }; @@ -41,7 +41,8 @@ ToolCallResult process_tool_calls( std::string text, const std::vector& tools, const std::string& parser_format, - std::string finish_reason) { + std::string finish_reason, + google::protobuf::Arena* arena = nullptr) { ToolCallResult result; function_call::FunctionCallParser parser(tools, parser_format); @@ -62,21 +63,23 @@ ToolCallResult process_tool_calls( auto [parsed_text, call_info_list] = parser.parse_non_stream(text); result.text = std::move(parsed_text); - std::vector tool_calls; - tool_calls.reserve(call_info_list.size()); + google::protobuf::RepeatedPtrField tool_calls; for (const auto& call_info : call_info_list) { - proto::ToolCall tool_call; - tool_call.set_id(function_call::utils::generate_tool_call_id()); - tool_call.set_type("function"); + proto::ToolCall* tool_call = + arena ? 
google::protobuf::Arena::CreateMessage(arena) + : new proto::ToolCall(); - auto* function = tool_call.mutable_function(); + tool_call->set_id(function_call::utils::generate_tool_call_id()); + tool_call->set_type("function"); + + auto* function = tool_call->mutable_function(); if (call_info.name) { function->set_name(*call_info.name); } function->set_arguments(call_info.parameters); - tool_calls.emplace_back(tool_call); + tool_calls.AddAllocated(tool_call); } result.tool_calls = std::move(tool_calls); @@ -231,19 +234,22 @@ bool send_result_to_client_brpc( }; if (!tools.empty() && !parser_format.empty()) { - auto result = process_tool_calls( - output.text, tools, parser_format, output.finish_reason.value_or("")); + auto* arena = response.GetArena(); + auto result = process_tool_calls(output.text, + tools, + parser_format, + output.finish_reason.value_or(""), + arena); - message->set_content(result.text); + message->mutable_content()->swap(result.text); if (result.tool_calls) { - for (auto& tool_call : *result.tool_calls) { - *message->add_tool_calls() = std::move(tool_call); - } + auto& source_tool_calls = *result.tool_calls; + message->mutable_tool_calls()->Swap(&source_tool_calls); } if (!result.finish_reason.empty()) { - choice->set_finish_reason(result.finish_reason); + choice->mutable_finish_reason()->swap(result.finish_reason); } } else { set_output_and_finish_reason(); From daffeb841f4e2e7dad39036d914a9c0b3d8d9744 Mon Sep 17 00:00:00 2001 From: dengyingxu1 Date: Thu, 14 Aug 2025 22:26:51 +0800 Subject: [PATCH 8/9] feat: add non-streaming tool call support for kimi_k2 and deepseekv3. 
--- xllm/core/function_call/CMakeLists.txt | 22 + .../function_call/deepseekv3_detector.cpp | 175 +++++++ xllm/core/function_call/deepseekv3_detector.h | 31 ++ .../deepseekv3_detector_test.cpp | 418 ++++++++++++++++ xllm/core/function_call/function_call.h | 6 +- .../function_call/function_call_parser.cpp | 27 +- xllm/core/function_call/kimik2_detector.cpp | 152 ++++++ xllm/core/function_call/kimik2_detector.h | 54 +++ .../function_call/kimik2_detector_test.cpp | 455 ++++++++++++++++++ .../function_call/qwen25_detector_test.cpp | 328 +++++++++++++ 10 files changed, 1659 insertions(+), 9 deletions(-) create mode 100644 xllm/core/function_call/deepseekv3_detector.cpp create mode 100644 xllm/core/function_call/deepseekv3_detector.h create mode 100644 xllm/core/function_call/deepseekv3_detector_test.cpp create mode 100644 xllm/core/function_call/kimik2_detector.cpp create mode 100644 xllm/core/function_call/kimik2_detector.h create mode 100644 xllm/core/function_call/kimik2_detector_test.cpp create mode 100644 xllm/core/function_call/qwen25_detector_test.cpp diff --git a/xllm/core/function_call/CMakeLists.txt b/xllm/core/function_call/CMakeLists.txt index 7b9b5f6d..eb246a6a 100644 --- a/xllm/core/function_call/CMakeLists.txt +++ b/xllm/core/function_call/CMakeLists.txt @@ -8,14 +8,36 @@ cc_library ( core_types.h base_format_detector.h qwen25_detector.h + kimik2_detector.h + deepseekv3_detector.h function_call_parser.h function_call.h SRCS base_format_detector.cpp qwen25_detector.cpp + kimik2_detector.cpp + deepseekv3_detector.cpp function_call_parser.cpp DEPS nlohmann_json::nlohmann_json glog::glog proto::xllm_proto ) + +function(add_detector_test TEST_NAME) + cc_test( + NAME + ${TEST_NAME} + SRCS + ${TEST_NAME}.cpp + DEPS + :function_call + GTest::gtest + GTest::gtest_main + nlohmann_json::nlohmann_json + ) +endfunction() + +add_detector_test(qwen25_detector_test) +add_detector_test(kimik2_detector_test) +add_detector_test(deepseekv3_detector_test) diff --git 
a/xllm/core/function_call/deepseekv3_detector.cpp b/xllm/core/function_call/deepseekv3_detector.cpp new file mode 100644 index 00000000..e3e1cc90 --- /dev/null +++ b/xllm/core/function_call/deepseekv3_detector.cpp @@ -0,0 +1,175 @@ +#include "deepseekv3_detector.h" + +#include +#include +#include +#include + +namespace llm { +namespace function_call { + +DeepSeekV3Detector::DeepSeekV3Detector() : BaseFormatDetector() { + bot_token_ = "<|tool▁calls▁begin|>"; + eot_token_ = "<|tool▁calls▁end|>"; + tool_call_separator_ = ""; +} + +bool DeepSeekV3Detector::has_tool_call(const std::string& text) { + return text.find(bot_token_) != std::string::npos; +} + +std::string_view DeepSeekV3Detector::trim_whitespace( + std::string_view str) const { + const char* whitespace = " \t\n\r"; + + size_t start = str.find_first_not_of(whitespace); + if (start == std::string_view::npos) { + return std::string_view{}; + } + + size_t end = str.find_last_not_of(whitespace); + + return str.substr(start, end - start + 1); +} + +std::vector> +DeepSeekV3Detector::find_tool_call_ranges(const std::string& text) const { + std::vector> ranges; + ranges.reserve(4); + + const std::string call_begin = "<|tool▁call▁begin|>"; + const std::string call_end = "<|tool▁call▁end|>"; + + size_t search_pos = 0; + const size_t call_begin_len = call_begin.length(); + const size_t call_end_len = call_end.length(); + + while (search_pos < text.length()) { + size_t start_pos = text.find(call_begin, search_pos); + if (start_pos == std::string::npos) { + break; + } + + size_t content_start = start_pos + call_begin_len; + size_t end_pos = text.find(call_end, content_start); + if (end_pos == std::string::npos) { + break; + } + + ranges.emplace_back(content_start, end_pos); + search_pos = end_pos + call_end_len; + } + + return ranges; +} + +StreamingParseResult DeepSeekV3Detector::detect_and_parse( + const std::string& text, + const std::vector& tools) { + size_t bot_token_pos = text.find(bot_token_); + + std::string 
normal_text; + if (bot_token_pos != std::string::npos) { + std::string_view normal_text_view(text.data(), bot_token_pos); + std::string_view trimmed = trim_whitespace(normal_text_view); + normal_text = std::string(trimmed); + } else { + std::string_view trimmed = trim_whitespace(text); + normal_text = std::string(trimmed); + return StreamingParseResult(normal_text); + } + + auto tool_call_ranges = find_tool_call_ranges(text); + + std::vector calls; + calls.reserve(tool_call_ranges.size()); + + for (const auto& range : tool_call_ranges) { + std::string_view content_view(text.data() + range.first, + range.second - range.first); + std::string_view trimmed_content = trim_whitespace(content_view); + + if (trimmed_content.empty()) { + continue; + } + + try { + // Parse DeepSeek V3 format: function_name\n```json\n{args}\n``` + const std::string tool_sep = "<|tool▁sep|>"; + size_t sep_pos = trimmed_content.find(tool_sep); + if (sep_pos == std::string_view::npos) { + LOG(ERROR) << "Failed to find tool separator in: " + << std::string(trimmed_content); + continue; + } + + // Extract function name (between tool_sep and first newline) + size_t name_start = sep_pos + tool_sep.length(); + size_t name_end = trimmed_content.find('\n', name_start); + if (name_end == std::string_view::npos) { + LOG(ERROR) << "Failed to find function name end in: " + << std::string(trimmed_content); + continue; + } + + std::string_view func_name_view = + trimmed_content.substr(name_start, name_end - name_start); + std::string_view func_name_trimmed = trim_whitespace(func_name_view); + std::string func_name(func_name_trimmed); + + // Find JSON block (between ```json\n and \n```) + const std::string json_start = "```json\n"; + const std::string json_end = "\n```"; + + size_t json_start_pos = trimmed_content.find(json_start, name_end); + if (json_start_pos == std::string_view::npos) { + LOG(ERROR) << "Failed to find JSON start in: " + << std::string(trimmed_content); + continue; + } + + size_t 
json_content_start = json_start_pos + json_start.length(); + size_t json_end_pos = trimmed_content.find(json_end, json_content_start); + if (json_end_pos == std::string_view::npos) { + LOG(ERROR) << "Failed to find JSON end in: " + << std::string(trimmed_content); + continue; + } + + std::string_view json_view = trimmed_content.substr( + json_content_start, json_end_pos - json_content_start); + std::string_view json_trimmed = trim_whitespace(json_view); + + // Parse JSON arguments + nlohmann::json func_args; + try { + std::string json_content(json_trimmed); + func_args = nlohmann::json::parse(json_content); + } catch (const nlohmann::json::parse_error& e) { + LOG(ERROR) << "Failed to parse JSON arguments: " + << std::string(json_trimmed) + << ", JSON parse error: " << e.what(); + continue; + } + + // Create JSON object for parse_base_json + nlohmann::json match_json; + match_json["name"] = func_name; + match_json["parameters"] = func_args; + + auto parsed_calls = parse_base_json(match_json, tools); + calls.insert(calls.end(), + std::make_move_iterator(parsed_calls.begin()), + std::make_move_iterator(parsed_calls.end())); + } catch (const std::exception& e) { + LOG(ERROR) << "Failed to parse tool call: " + << std::string(trimmed_content) << ", error: " << e.what(); + continue; + } + } + + return StreamingParseResult(std::move(normal_text), std::move(calls)); +} + +} // namespace function_call +} // namespace llm \ No newline at end of file diff --git a/xllm/core/function_call/deepseekv3_detector.h b/xllm/core/function_call/deepseekv3_detector.h new file mode 100644 index 00000000..55468cc7 --- /dev/null +++ b/xllm/core/function_call/deepseekv3_detector.h @@ -0,0 +1,31 @@ +#pragma once + +#include + +#include "base_format_detector.h" + +namespace llm { +namespace function_call { + +class DeepSeekV3Detector : public BaseFormatDetector { + public: + DeepSeekV3Detector(); + + virtual ~DeepSeekV3Detector() = default; + + bool has_tool_call(const std::string& text) 
override; + + StreamingParseResult detect_and_parse( + const std::string& text, + const std::vector& tools) override; + + private: + std::string func_call_regex_; + std::string func_detail_regex_; + std::string_view trim_whitespace(std::string_view str) const; + std::vector> find_tool_call_ranges( + const std::string& text) const; +}; + +} // namespace function_call +} // namespace llm \ No newline at end of file diff --git a/xllm/core/function_call/deepseekv3_detector_test.cpp b/xllm/core/function_call/deepseekv3_detector_test.cpp new file mode 100644 index 00000000..5284f08f --- /dev/null +++ b/xllm/core/function_call/deepseekv3_detector_test.cpp @@ -0,0 +1,418 @@ +#include "deepseekv3_detector.h" + +#include + +#include +#include +#include + +namespace llm { +namespace function_call { + +class DeepSeekV3DetectorTest : public ::testing::Test { + protected: + void SetUp() override { + detector_ = std::make_unique(); + + // Setup test tools + nlohmann::json weather_params = { + {"type", "object"}, + {"properties", + {{"location", + {{"type", "string"}, + {"description", "The city and state, e.g. 
San Francisco, CA"}}}, + {"unit", {{"type", "string"}, {"enum", {"celsius", "fahrenheit"}}}}}}, + {"required", {"location"}}}; + + JsonFunction weather_func("get_current_weather", + "Get the current weather in a given location", + weather_params); + weather_tool_ = JsonTool("function", weather_func); + + nlohmann::json calculator_params = { + {"type", "object"}, + {"properties", + {{"expression", + {{"type", "string"}, + {"description", "Mathematical expression to evaluate"}}}}}, + {"required", {"expression"}}}; + + JsonFunction calculator_func( + "calculate", "Calculate mathematical expressions", calculator_params); + calculator_tool_ = JsonTool("function", calculator_func); + + tools_ = {weather_tool_, calculator_tool_}; + } + + std::unique_ptr detector_; + JsonTool weather_tool_; + JsonTool calculator_tool_; + std::vector tools_; +}; + +// Test constructor and basic properties +TEST_F(DeepSeekV3DetectorTest, ConstructorInitializesCorrectly) { + EXPECT_NE(detector_, nullptr); + + // Test basic token detection + std::string text_with_tool_call = + "Some text " + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>test\n`" + "``json\n{\"name\": " + "\"test\"}\n```<|tool▁call▁end|><|tool▁calls▁end|>"; + std::string text_without_tool_call = + "Just normal text without any tool calls"; + + EXPECT_TRUE(detector_->has_tool_call(text_with_tool_call)); + EXPECT_FALSE(detector_->has_tool_call(text_without_tool_call)); +} + +// Test has_tool_call method +TEST_F(DeepSeekV3DetectorTest, HasToolCallDetection) { + // Test text containing tool calls + EXPECT_TRUE(detector_->has_tool_call("<|tool▁calls▁begin|>")); + EXPECT_TRUE(detector_->has_tool_call( + "Previous text <|tool▁calls▁begin|>Following content")); + EXPECT_TRUE(detector_->has_tool_call( + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>test\n`" + "``json\n{\"name\": " + "\"test\"}\n```<|tool▁call▁end|><|tool▁calls▁end|>")); + + // Test text not containing tool calls + 
EXPECT_FALSE(detector_->has_tool_call("")); + EXPECT_FALSE(detector_->has_tool_call("Regular text")); + EXPECT_FALSE(detector_->has_tool_call("tool_calls without special tokens")); + EXPECT_FALSE(detector_->has_tool_call(" without unicode tokens")); +} + +// Test single tool call parsing +TEST_F(DeepSeekV3DetectorTest, SingleToolCallParsing) { + std::string text = + "Please help me check the weather " + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_" + "current_weather\n```json\n{\"location\": \"Beijing\", \"unit\": " + "\"celsius\"}\n```<|tool▁call▁end|><|tool▁calls▁end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Please help me check the weather"); + EXPECT_EQ(result.calls.size(), 1); + + const auto& call = result.calls[0]; + EXPECT_EQ(call.tool_index, -1); // Base class always returns -1 + EXPECT_TRUE(call.name.has_value()); + EXPECT_EQ(call.name.value(), "get_current_weather"); + + // Verify parameter JSON + nlohmann::json params = nlohmann::json::parse(call.parameters); + EXPECT_EQ(params["location"], "Beijing"); + EXPECT_EQ(params["unit"], "celsius"); +} + +// Test multiple tool calls parsing +TEST_F(DeepSeekV3DetectorTest, MultipleToolCallsParsing) { + std::string text = + "Please help me check the weather and calculate an expression " + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_" + "current_weather\n```json\n{\"location\": " + "\"Shanghai\"}\n```<|tool▁call▁end|>\n<|tool▁call▁begin|>function<" + "|tool▁sep|>calculate\n```json\n{\"expression\": \"2 + 3 * " + "4\"}\n```<|tool▁call▁end|><|tool▁calls▁end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, + "Please help me check the weather and calculate an expression"); + EXPECT_EQ(result.calls.size(), 2); + + // Verify first tool call + const auto& call1 = result.calls[0]; + EXPECT_EQ(call1.tool_index, -1); // Base class always returns -1 + 
EXPECT_TRUE(call1.name.has_value()); + EXPECT_EQ(call1.name.value(), "get_current_weather"); + + nlohmann::json params1 = nlohmann::json::parse(call1.parameters); + EXPECT_EQ(params1["location"], "Shanghai"); + + // Verify second tool call + const auto& call2 = result.calls[1]; + EXPECT_EQ(call2.tool_index, -1); // Base class always returns -1 + EXPECT_TRUE(call2.name.has_value()); + EXPECT_EQ(call2.name.value(), "calculate"); + + nlohmann::json params2 = nlohmann::json::parse(call2.parameters); + EXPECT_EQ(params2["expression"], "2 + 3 * 4"); +} + +// Test DeepSeekV3 specific format with exact tokens +TEST_F(DeepSeekV3DetectorTest, DeepSeekV3SpecificFormat) { + std::string text = + "I need weather info " + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_" + "current_weather\n```json\n{\"location\": " + "\"Tokyo\"}\n```<|tool▁call▁end|><|tool▁calls▁end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "I need weather info"); + EXPECT_EQ(result.calls.size(), 1); + + const auto& call = result.calls[0]; + EXPECT_TRUE(call.name.has_value()); + EXPECT_EQ(call.name.value(), "get_current_weather"); + + nlohmann::json params = nlohmann::json::parse(call.parameters); + EXPECT_EQ(params["location"], "Tokyo"); +} + +// Test invalid JSON handling +TEST_F(DeepSeekV3DetectorTest, InvalidJsonHandling) { + std::string text = + "Test invalid JSON " + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_" + "current_weather\n```json\n{\"location\": \"Beijing\", " + "invalid_json}\n```<|tool▁call▁end|><|tool▁calls▁end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Test invalid JSON"); + EXPECT_EQ(result.calls.size(), 0); // Invalid JSON should be ignored +} + +// Test empty tool call content +TEST_F(DeepSeekV3DetectorTest, EmptyToolCallContent) { + std::string text = + "Test empty content " + 
"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>\n```" + "json\n \t\n \n```<|tool▁call▁end|><|tool▁calls▁end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Test empty content"); + EXPECT_EQ(result.calls.size(), 0); // Empty content should be ignored +} + +// Test incomplete tool call (only start tag) +TEST_F(DeepSeekV3DetectorTest, IncompleteToolCall) { + std::string text = + "Incomplete tool call " + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_" + "current_weather\n```json\n{\"location\": \"Beijing\"}"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Incomplete tool call"); + EXPECT_EQ(result.calls.size(), 0); // Incomplete calls should be ignored +} + +// Test unknown tool name handling +TEST_F(DeepSeekV3DetectorTest, UnknownToolName) { + std::string text = + "Unknown tool " + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>" + "unknown_tool\n```json\n{\"param\": " + "\"value\"}\n```<|tool▁call▁end|><|tool▁calls▁end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Unknown tool"); + // Base class will skip unknown tools, so should be 0 calls + EXPECT_EQ(result.calls.size(), 0); +} + +// Test case with only normal text +TEST_F(DeepSeekV3DetectorTest, OnlyNormalText) { + std::string text = "This is a regular text without any tool calls."; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, + "This is a regular text without any tool calls."); + EXPECT_EQ(result.calls.size(), 0); + EXPECT_FALSE(result.has_calls()); +} + +// Test empty string input +TEST_F(DeepSeekV3DetectorTest, EmptyStringInput) { + std::string text = ""; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, ""); + EXPECT_EQ(result.calls.size(), 0); + EXPECT_FALSE(result.has_calls()); +} + +// Test whitespace-only input 
+TEST_F(DeepSeekV3DetectorTest, WhitespaceOnlyInput) { + std::string text = " \t\n\r "; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, ""); + EXPECT_EQ(result.calls.size(), 0); +} + +// Test complex nested JSON parameters +TEST_F(DeepSeekV3DetectorTest, ComplexNestedJsonParameters) { + std::string text = + "Complex parameter test " + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_" + "current_weather\n```json\n{\"location\": \"Beijing\", \"options\": " + "{\"include_forecast\": true, \"days\": 7, \"details\": " + "[\"temperature\", \"humidity\", " + "\"wind\"]}}\n```<|tool▁call▁end|><|tool▁calls▁end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Complex parameter test"); + EXPECT_EQ(result.calls.size(), 1); + + const auto& call = result.calls[0]; + EXPECT_EQ(call.tool_index, -1); // Base class always returns -1 + + nlohmann::json params = nlohmann::json::parse(call.parameters); + EXPECT_EQ(params["location"], "Beijing"); + EXPECT_TRUE(params["options"]["include_forecast"]); + EXPECT_EQ(params["options"]["days"], 7); + EXPECT_EQ(params["options"]["details"].size(), 3); +} + +// Test tool call in the middle of text +TEST_F(DeepSeekV3DetectorTest, ToolCallInMiddleOfText) { + std::string text = + "Previous text " + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>" + "calculate\n```json\n{\"expression\": " + "\"1+1\"}\n```<|tool▁call▁end|><|tool▁calls▁end|> Following text"; + + auto result = detector_->detect_and_parse(text, tools_); + + // Note: According to implementation, only text before tool call is preserved + // as normal_text + EXPECT_EQ(result.normal_text, "Previous text"); + EXPECT_EQ(result.calls.size(), 1); + + const auto& call = result.calls[0]; + EXPECT_EQ(call.tool_index, -1); // Base class always returns -1 + EXPECT_EQ(call.name.value(), "calculate"); +} + +// Test special characters handling +TEST_F(DeepSeekV3DetectorTest, 
SpecialCharactersHandling) { + std::string text = + "Special characters test " + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_" + "current_weather\n```json\n{\"location\": \"New York City\", \"note\": " + "\"Contains " + "symbols!@#$%^&*()_+=\"}\n```<|tool▁call▁end|><|tool▁calls▁end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Special characters test"); + EXPECT_EQ(result.calls.size(), 1); + + const auto& call = result.calls[0]; + EXPECT_EQ(call.tool_index, -1); // Base class always returns -1 + nlohmann::json params = nlohmann::json::parse(call.parameters); + EXPECT_EQ(params["location"], "New York City"); + EXPECT_EQ(params["note"], "Contains symbols!@#$%^&*()_+="); +} + +// Test whitespace trimming +TEST_F(DeepSeekV3DetectorTest, WhitespaceTrimming) { + std::string text_with_whitespace = + " \t\nPrevious text\r\n " + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_" + "current_weather\n```json\n{\"location\": " + "\"Beijing\"}\n```<|tool▁call▁end|><|tool▁calls▁end|> \t\r\n"; + + auto result = detector_->detect_and_parse(text_with_whitespace, tools_); + + // Verify normal text is correctly trimmed + EXPECT_EQ(result.normal_text, "Previous text"); + + // Verify tool call is correctly parsed + EXPECT_EQ(result.calls.size(), 1); + EXPECT_EQ(result.calls[0].tool_index, -1); // Base class always returns -1 +} + +// Test regex pattern matching edge cases +TEST_F(DeepSeekV3DetectorTest, RegexPatternEdgeCases) { + // Test with newlines in function name (should fail) + std::string text1 = + "Test " + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_" + "current\nweather\n```json\n{\"location\": " + "\"Beijing\"}\n```<|tool▁call▁end|><|tool▁calls▁end|>"; + auto result1 = detector_->detect_and_parse(text1, tools_); + EXPECT_EQ(result1.calls.size(), 0); // Should fail to match + + // Test with missing json markers + std::string text2 = + "Test " + 
"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_" + "current_weather\n{\"location\": " + "\"Beijing\"}\n<|tool▁call▁end|><|tool▁calls▁end|>"; + auto result2 = detector_->detect_and_parse(text2, tools_); + EXPECT_EQ(result2.calls.size(), + 0); // Should fail to match without ```json``` markers +} + +// Performance test: multiple tool calls +TEST_F(DeepSeekV3DetectorTest, PerformanceWithMultipleToolCalls) { + std::string text = "Performance test"; + + // Build text containing multiple tool calls + for (int i = 0; i < 10000; ++i) { + text += + " <|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>" + "calculate\n```json\n{\"expression\": \"" + + std::to_string(i) + " + " + std::to_string(i + 1) + + "\"}\n```<|tool▁call▁end|><|tool▁calls▁end|>"; + } + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Performance test"); + EXPECT_EQ(result.calls.size(), 10000); + + // Verify each tool call is correctly parsed + for (int i = 0; i < 10000; ++i) { + const auto& call = result.calls[i]; + EXPECT_EQ(call.tool_index, -1); // Base class always returns -1 + EXPECT_EQ(call.name.value(), "calculate"); + + nlohmann::json params = nlohmann::json::parse(call.parameters); + std::string expected_expr = + std::to_string(i) + " + " + std::to_string(i + 1); + EXPECT_EQ(params["expression"], expected_expr); + } +} + +// Test error handling with malformed tokens +TEST_F(DeepSeekV3DetectorTest, MalformedTokensHandling) { + // Test with incomplete start token + std::string text1 = + "Test " + "<|tool▁calls▁begi><|tool▁call▁begin|>function<|tool▁sep|>test\n```" + "json\n{}\n```<|tool▁call▁end|><|tool▁calls▁end|>"; + auto result1 = detector_->detect_and_parse(text1, tools_); + EXPECT_EQ(result1.normal_text, + "Test " + "<|tool▁calls▁begi><|tool▁call▁begin|>function<|tool▁sep|>" + "test\n```json\n{}\n```<|tool▁call▁end|><|tool▁calls▁end|>"); + EXPECT_EQ(result1.calls.size(), 0); + + // Test with incomplete end token + std::string 
text2 = + "Test " + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>test\n`" + "``json\n{}\n```<|tool▁call▁end|><|tool▁calls▁en|>"; + auto result2 = detector_->detect_and_parse(text2, tools_); + EXPECT_EQ(result2.normal_text, "Test"); + EXPECT_EQ(result2.calls.size(), 0); // Should not match incomplete pattern +} + +} // namespace function_call +} // namespace llm \ No newline at end of file diff --git a/xllm/core/function_call/function_call.h b/xllm/core/function_call/function_call.h index 70aac8bb..87950ebc 100644 --- a/xllm/core/function_call/function_call.h +++ b/xllm/core/function_call/function_call.h @@ -4,14 +4,12 @@ #include "core_types.h" #include "function_call_parser.h" #include "qwen25_detector.h" +#include "kimik2_detector.h" +#include "deepseekv3_detector.h" namespace llm { namespace function_call { -using Parser = FunctionCallParser; -using Detector = BaseFormatDetector; -using QwenDetector = Qwen25Detector; - inline std::vector parse(const std::string& text, const std::vector& tools, const std::string& format = "qwen25") { diff --git a/xllm/core/function_call/function_call_parser.cpp b/xllm/core/function_call/function_call_parser.cpp index 87d8b4c3..2fe6759c 100644 --- a/xllm/core/function_call/function_call_parser.cpp +++ b/xllm/core/function_call/function_call_parser.cpp @@ -5,6 +5,8 @@ #include "common/uuid.h" #include "qwen25_detector.h" +#include "kimik2_detector.h" +#include "deepseekv3_detector.h" namespace llm { namespace function_call { @@ -12,10 +14,11 @@ const std::unordered_map FunctionCallParser::ToolCallParserEnum = { {"qwen25", "qwen25"}, {"qwen3", "qwen25"}, + {"kimi_k2", "kimi_k2"}, + {"deepseekv3", "deepseekv3"}, // TODO // {"llama3", "llama3"}, // {"mistral", "mistral"}, - // {"deepseekv3", "deepseekv3"}, // {"pythonic", "pythonic"}, // {"qwen3_coder", "qwen3_coder"}, // {"glm45", "glm45"}, @@ -26,10 +29,16 @@ FunctionCallParser::FunctionCallParser(const std::vector& tools, const std::string& tool_call_parser) : 
tools_(tools) { detector_ = create_detector(tool_call_parser); - if (!detector_) { - throw std::invalid_argument("Unsupported tool_call_parser: " + - tool_call_parser); - } + CHECK(detector_ != nullptr) + << "Unsupported tool_call_parser: " << tool_call_parser + << ". Supported parsers are: " << [this]() { + std::string supported; + for (const auto& [key, value] : ToolCallParserEnum) { + if (!supported.empty()) supported += ", "; + supported += key; + } + return supported; + }(); } bool FunctionCallParser::has_tool_call(const std::string& text) const { @@ -58,6 +67,14 @@ std::unique_ptr FunctionCallParser::create_detector( if (it->second == "qwen25") { return std::make_unique(); } + + if (it->second == "kimi_k2") { + return std::make_unique(); + } + + if (it->second == "deepseekv3") { + return std::make_unique(); + } // if (tool_call_parser == "llama3") { // return std::make_unique(); diff --git a/xllm/core/function_call/kimik2_detector.cpp b/xllm/core/function_call/kimik2_detector.cpp new file mode 100644 index 00000000..72e383e6 --- /dev/null +++ b/xllm/core/function_call/kimik2_detector.cpp @@ -0,0 +1,152 @@ +#include "kimik2_detector.h" + +#include +#include +#include + +namespace llm { +namespace function_call { + +KimiK2Detector::KimiK2Detector() : BaseFormatDetector() { + // Initialize KimiK2 specific tokens + bot_token_ = "<|tool_calls_section_begin|>"; + eot_token_ = "<|tool_calls_section_end|>"; + tool_call_start_token_ = "<|tool_call_begin|>"; + tool_call_end_token_ = "<|tool_call_end|>"; + tool_call_argument_begin_token_ = "<|tool_call_argument_begin|>"; + + // Regex pattern for parsing tool calls with the following format: + // <|tool_call_begin|>functions.{func_name}:{index} + // <|tool_call_argument_begin|>{json_args}<|tool_call_end|> + // Note: C++ regex doesn't support named groups, so we use numbered groups: + // Group 1: tool_call_id (functions.{func_name}:{index}) + // Group 2: function_arguments ({json_args}) + std::string pattern = + 
R"(<\|tool_call_begin\|>\s*([\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(\{.*?\})\s*<\|tool_call_end\|>)"; + + try { + tool_call_regex_ = std::regex(pattern, std::regex_constants::ECMAScript); + } catch (const std::regex_error& e) { + LOG(ERROR) << "Failed to compile KimiK2 regex pattern: " << e.what(); + throw; + } + + last_arguments_ = ""; +} + +bool KimiK2Detector::has_tool_call(const std::string& text) { + // Check if the text contains the KimiK2 tool call section begin token + return text.find(bot_token_) != std::string::npos; +} + +StreamingParseResult KimiK2Detector::detect_and_parse( + const std::string& text, + const std::vector& tools) { + size_t bot_pos = text.find(bot_token_); + + std::string normal_text = + (bot_pos != std::string::npos) ? text.substr(0, bot_pos) : text; + + if (bot_pos == std::string::npos) { + return StreamingParseResult(normal_text); + } + + std::vector calls; + + try { + std::sregex_iterator iter( + text.begin() + bot_pos, text.end(), tool_call_regex_); + std::sregex_iterator end; + + for (; iter != end; ++iter) { + std::smatch match = *iter; + + if (match.size() >= 3) { + std::string tool_call_id = match[1].str(); + std::string function_arguments = match[2].str(); + + std::string function_name = extract_function_name(tool_call_id); + int function_index = extract_function_index(tool_call_id); + + calls.emplace_back(function_index, // Use the call index in the + // response, not tool position + function_name, // Function name + function_arguments // JSON parameters + ); + } + } + + return StreamingParseResult(normal_text, calls); + + } catch (const std::exception& e) { + LOG(ERROR) << "Error in KimiK2 detect_and_parse: " << e.what(); + // Return the normal text if parsing fails + return StreamingParseResult(normal_text); + } +} + +std::string KimiK2Detector::extract_function_name( + const std::string& tool_call_id) const { + // tool_call_id format: functions.{func_name}:{index} + // Example: functions.get_weather:0 + + try { + 
// Find the position of "functions." + size_t functions_pos = tool_call_id.find("functions."); + if (functions_pos == std::string::npos) { + LOG(WARNING) + << "Invalid tool_call_id format, missing 'functions.' prefix: " + << tool_call_id; + return ""; + } + + // Skip "functions." (10 characters) + size_t start_pos = functions_pos + 10; + + // Find the position of the last colon + size_t colon_pos = tool_call_id.find_last_of(':'); + if (colon_pos == std::string::npos || colon_pos <= start_pos) { + LOG(WARNING) << "Invalid tool_call_id format, missing ':' separator: " + << tool_call_id; + return ""; + } + + // Extract function name between "functions." and ":" + return tool_call_id.substr(start_pos, colon_pos - start_pos); + + } catch (const std::exception& e) { + LOG(ERROR) << "Error extracting function name from tool_call_id: " + << tool_call_id << ", error: " << e.what(); + return ""; + } +} + +int KimiK2Detector::extract_function_index( + const std::string& tool_call_id) const { + // tool_call_id format: functions.{func_name}:{index} + // Example: functions.get_weather:0 + + try { + // Find the position of the last colon + size_t colon_pos = tool_call_id.find_last_of(':'); + if (colon_pos == std::string::npos) { + LOG(WARNING) << "Invalid tool_call_id format, missing ':' separator: " + << tool_call_id; + return 0; + } + + // Extract index string after the colon + std::string index_str = tool_call_id.substr(colon_pos + 1); + + // Convert to integer + return std::stoi(index_str); + + } catch (const std::exception& e) { + LOG(ERROR) << "Error extracting function index from tool_call_id: " + << tool_call_id << ", error: " << e.what(); + return 0; + } +} + +} // namespace function_call +} // namespace llm \ No newline at end of file diff --git a/xllm/core/function_call/kimik2_detector.h b/xllm/core/function_call/kimik2_detector.h new file mode 100644 index 00000000..577abd27 --- /dev/null +++ b/xllm/core/function_call/kimik2_detector.h @@ -0,0 +1,54 @@ +#pragma once + 
+#include +#include + +#include "base_format_detector.h" + +namespace llm { +namespace function_call { + +/** + * Detector for Kimi K2 model function call format. + * + * Format Structure: + * ``` + * <|tool_calls_section_begin|> + * <|tool_call_begin|>functions.{func_name}:{index} + * <|tool_call_argument_begin|>{json_args}<|tool_call_end|> + * <|tool_calls_section_end|> + * ``` + * + * Reference: + * https://huggingface.co/moonshotai/Kimi-K2-Instruct/blob/main/docs/tool_call_guidance.md + */ +class KimiK2Detector : public BaseFormatDetector { + public: + KimiK2Detector(); + + virtual ~KimiK2Detector() = default; + + private: + std::string tool_call_start_token_; + std::string tool_call_end_token_; + std::string tool_call_argument_begin_token_; + + std::regex tool_call_regex_; + + std::string last_arguments_; + + public: + bool has_tool_call(const std::string& text) override; + + StreamingParseResult detect_and_parse( + const std::string& text, + const std::vector& tools) override; + + private: + std::string extract_function_name(const std::string& tool_call_id) const; + + int extract_function_index(const std::string& tool_call_id) const; +}; + +} // namespace function_call +} // namespace llm \ No newline at end of file diff --git a/xllm/core/function_call/kimik2_detector_test.cpp b/xllm/core/function_call/kimik2_detector_test.cpp new file mode 100644 index 00000000..45ee632f --- /dev/null +++ b/xllm/core/function_call/kimik2_detector_test.cpp @@ -0,0 +1,455 @@ +#include "kimik2_detector.h" + +#include + +#include +#include +#include + +namespace llm { +namespace function_call { + +class KimiK2DetectorTest : public ::testing::Test { + protected: + void SetUp() override { + detector_ = std::make_unique(); + + // Setup test tools + nlohmann::json weather_params = { + {"type", "object"}, + {"properties", + {{"location", + {{"type", "string"}, + {"description", "The city and state, e.g. 
San Francisco, CA"}}}, + {"unit", {{"type", "string"}, {"enum", {"celsius", "fahrenheit"}}}}}}, + {"required", {"location"}}}; + + JsonFunction weather_func("get_current_weather", + "Get the current weather in a given location", + weather_params); + weather_tool_ = JsonTool("function", weather_func); + + nlohmann::json calculator_params = { + {"type", "object"}, + {"properties", + {{"expression", + {{"type", "string"}, + {"description", "Mathematical expression to evaluate"}}}}}, + {"required", {"expression"}}}; + + JsonFunction calculator_func( + "calculate", "Calculate mathematical expressions", calculator_params); + calculator_tool_ = JsonTool("function", calculator_func); + + tools_ = {weather_tool_, calculator_tool_}; + } + + std::unique_ptr detector_; + JsonTool weather_tool_; + JsonTool calculator_tool_; + std::vector tools_; +}; + +// Test constructor and basic properties +TEST_F(KimiK2DetectorTest, ConstructorInitializesCorrectly) { + EXPECT_NE(detector_, nullptr); + + // Test basic token detection + std::string text_with_tool_call = + "Some text " + "<|tool_calls_section_begin|><|tool_call_begin|>functions.test:0 " + "<|tool_call_argument_begin|>{\"param\": " + "\"value\"}<|tool_call_end|><|tool_calls_section_end|>"; + std::string text_without_tool_call = + "Just normal text without any tool calls"; + + EXPECT_TRUE(detector_->has_tool_call(text_with_tool_call)); + EXPECT_FALSE(detector_->has_tool_call(text_without_tool_call)); +} + +// Test has_tool_call method +TEST_F(KimiK2DetectorTest, HasToolCallDetection) { + // Test text containing tool calls + EXPECT_TRUE(detector_->has_tool_call("<|tool_calls_section_begin|>")); + EXPECT_TRUE(detector_->has_tool_call( + "Previous text <|tool_calls_section_begin|>Following content")); + EXPECT_TRUE(detector_->has_tool_call( + "<|tool_calls_section_begin|><|tool_call_begin|>functions.test:0 " + "<|tool_call_argument_begin|>{\"param\": " + "\"value\"}<|tool_call_end|><|tool_calls_section_end|>")); + + // Test text 
not containing tool calls + EXPECT_FALSE(detector_->has_tool_call("")); + EXPECT_FALSE(detector_->has_tool_call("Regular text")); + EXPECT_FALSE( + detector_->has_tool_call("tool_calls_section_begin without brackets")); + EXPECT_FALSE( + detector_->has_tool_call("<|tool_call_begin|>functions.get_current_" + "weather:0 <|tool_call_argument_begin|>{\"location\": " + "\"Beijing\"}<|tool_call_end|><|tool_calls_section_end|> \t\r\n"; + + auto result = detector_->detect_and_parse(text_with_whitespace, tools_); + + // Verify normal text is correctly extracted + EXPECT_EQ(result.normal_text, " \t\nPrevious text\r\n "); + + // Verify tool call is correctly parsed + EXPECT_EQ(result.calls.size(), 1); + EXPECT_EQ(result.calls[0].tool_index, 0); +} + +// Test single tool call parsing +TEST_F(KimiK2DetectorTest, SingleToolCallParsing) { + std::string text = + "Please help me check the weather " + "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_current_" + "weather:0 <|tool_call_argument_begin|>{\"location\": \"Beijing\", " + "\"unit\": \"celsius\"}<|tool_call_end|><|tool_calls_section_end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Please help me check the weather "); + EXPECT_EQ(result.calls.size(), 1); + + const auto& call = result.calls[0]; + EXPECT_EQ(call.tool_index, 0); + EXPECT_TRUE(call.name.has_value()); + EXPECT_EQ(call.name.value(), "get_current_weather"); + + // Verify parameter JSON + nlohmann::json params = nlohmann::json::parse(call.parameters); + EXPECT_EQ(params["location"], "Beijing"); + EXPECT_EQ(params["unit"], "celsius"); +} + +// Test multiple tool calls parsing +TEST_F(KimiK2DetectorTest, MultipleToolCallsParsing) { + std::string text = + "Please help me check the weather and calculate an expression " + "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_current_" + "weather:0 <|tool_call_argument_begin|>{\"location\": " + 
"\"Shanghai\"}<|tool_call_end|><|tool_call_begin|>functions.calculate:1 " + "<|tool_call_argument_begin|>{\"expression\": \"2 + 3 * " + "4\"}<|tool_call_end|><|tool_calls_section_end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, + "Please help me check the weather and calculate an expression "); + EXPECT_EQ(result.calls.size(), 2); + + // Verify first tool call + const auto& call1 = result.calls[0]; + EXPECT_EQ(call1.tool_index, 0); + EXPECT_TRUE(call1.name.has_value()); + EXPECT_EQ(call1.name.value(), "get_current_weather"); + + nlohmann::json params1 = nlohmann::json::parse(call1.parameters); + EXPECT_EQ(params1["location"], "Shanghai"); + + // Verify second tool call + const auto& call2 = result.calls[1]; + EXPECT_EQ(call2.tool_index, 1); + EXPECT_TRUE(call2.name.has_value()); + EXPECT_EQ(call2.name.value(), "calculate"); + + nlohmann::json params2 = nlohmann::json::parse(call2.parameters); + EXPECT_EQ(params2["expression"], "2 + 3 * 4"); +} + +// Test invalid JSON handling +TEST_F(KimiK2DetectorTest, InvalidJsonHandling) { + std::string text = + "Test invalid JSON " + "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_current_" + "weather:0 <|tool_call_argument_begin|>{\"location\": \"Beijing\", " + "invalid_json}<|tool_call_end|><|tool_calls_section_end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Test invalid JSON "); + // KimiK2 detector should still parse the call even with invalid JSON, leaving + // JSON validation to higher levels + EXPECT_EQ(result.calls.size(), 1); + EXPECT_EQ(result.calls[0].name.value(), "get_current_weather"); +} + +// Test empty tool call content +TEST_F(KimiK2DetectorTest, EmptyToolCallContent) { + std::string text = + "Test empty content " + "<|tool_calls_section_begin|><|tool_calls_section_end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Test empty content 
"); + EXPECT_EQ(result.calls.size(), 0); // Empty content should be ignored +} + +// Test incomplete tool call (only start tag) +TEST_F(KimiK2DetectorTest, IncompleteToolCall) { + std::string text = + "Incomplete tool call " + "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_current_" + "weather:0"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Incomplete tool call "); + EXPECT_EQ(result.calls.size(), 0); // Incomplete calls should be ignored +} + +// Test unknown tool name handling +TEST_F(KimiK2DetectorTest, UnknownToolName) { + std::string text = + "Unknown tool " + "<|tool_calls_section_begin|><|tool_call_begin|>functions.unknown_tool:0 " + "<|tool_call_argument_begin|>{\"param\": " + "\"value\"}<|tool_call_end|><|tool_calls_section_end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Unknown tool "); + // KimiK2 detector should parse the call regardless of whether the tool is + // known + EXPECT_EQ(result.calls.size(), 1); + EXPECT_EQ(result.calls[0].name.value(), "unknown_tool"); +} + +// Test case with only normal text +TEST_F(KimiK2DetectorTest, OnlyNormalText) { + std::string text = "This is a regular text without any tool calls."; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, + "This is a regular text without any tool calls."); + EXPECT_EQ(result.calls.size(), 0); + EXPECT_FALSE(result.has_calls()); +} + +// Test empty string input +TEST_F(KimiK2DetectorTest, EmptyStringInput) { + std::string text = ""; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, ""); + EXPECT_EQ(result.calls.size(), 0); + EXPECT_FALSE(result.has_calls()); +} + +// Test whitespace-only input +TEST_F(KimiK2DetectorTest, WhitespaceOnlyInput) { + std::string text = " \t\n\r "; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, " \t\n\r 
"); + EXPECT_EQ(result.calls.size(), 0); +} + +// Test complex nested JSON parameters +TEST_F(KimiK2DetectorTest, ComplexNestedJsonParameters) { + std::string text = + "Complex parameter test " + "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_current_" + "weather:0 <|tool_call_argument_begin|>{\"location\": \"Beijing\", " + "\"options\": {\"include_forecast\": true, \"days\": 7, \"details\": " + "[\"temperature\", \"humidity\", " + "\"wind\"]}}<|tool_call_end|><|tool_calls_section_end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Complex parameter test "); + EXPECT_EQ(result.calls.size(), 1); + + const auto& call = result.calls[0]; + EXPECT_EQ(call.tool_index, 0); + + nlohmann::json params = nlohmann::json::parse(call.parameters); + EXPECT_EQ(params["location"], "Beijing"); + EXPECT_TRUE(params["options"]["include_forecast"]); + EXPECT_EQ(params["options"]["days"], 7); + EXPECT_EQ(params["options"]["details"].size(), 3); +} + +// Test tool call in the middle of text +TEST_F(KimiK2DetectorTest, ToolCallInMiddleOfText) { + std::string text = + "Previous text " + "<|tool_calls_section_begin|><|tool_call_begin|>functions.calculate:0 " + "<|tool_call_argument_begin|>{\"expression\": " + "\"1+1\"}<|tool_call_end|><|tool_calls_section_end|> Following text"; + + auto result = detector_->detect_and_parse(text, tools_); + + // Note: According to KimiK2 implementation, only text before tool call + // section is preserved as normal_text + EXPECT_EQ(result.normal_text, "Previous text "); + EXPECT_EQ(result.calls.size(), 1); + + const auto& call = result.calls[0]; + EXPECT_EQ(call.tool_index, 0); + EXPECT_EQ(call.name.value(), "calculate"); +} + +// Test special characters handling +TEST_F(KimiK2DetectorTest, SpecialCharactersHandling) { + std::string text = + "Special characters test " + "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_current_" + "weather:0 
<|tool_call_argument_begin|>{\"location\": \"New York City\", " + "\"note\": \"Contains " + "symbols!@#$%^&*()_+=\"}<|tool_call_end|><|tool_calls_section_end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Special characters test "); + EXPECT_EQ(result.calls.size(), 1); + + const auto& call = result.calls[0]; + EXPECT_EQ(call.tool_index, 0); + nlohmann::json params = nlohmann::json::parse(call.parameters); + EXPECT_EQ(params["location"], "New York City"); + EXPECT_EQ(params["note"], "Contains symbols!@#$%^&*()_+="); +} + +// Test function name extraction +TEST_F(KimiK2DetectorTest, FunctionNameExtraction) { + std::string text = + "Function name test " + "<|tool_calls_section_begin|><|tool_call_begin|>functions.my_custom_" + "function:5 <|tool_call_argument_begin|>{\"param\": " + "\"value\"}<|tool_call_end|><|tool_calls_section_end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Function name test "); + EXPECT_EQ(result.calls.size(), 1); + + const auto& call = result.calls[0]; + EXPECT_EQ(call.tool_index, 5); // Should extract the index correctly + EXPECT_EQ(call.name.value(), "my_custom_function"); +} + +// Test malformed function ID handling +TEST_F(KimiK2DetectorTest, MalformedFunctionIdHandling) { + std::string text = + "Malformed ID test " + "<|tool_calls_section_begin|><|tool_call_begin|>invalid_format " + "<|tool_call_argument_begin|>{\"param\": " + "\"value\"}<|tool_call_end|><|tool_calls_section_end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Malformed ID test "); + // Malformed format that doesn't match regex should result in no calls + EXPECT_EQ(result.calls.size(), 0); +} + +// Test malformed function ID that matches regex but has invalid format +TEST_F(KimiK2DetectorTest, MalformedButMatchingFunctionId) { + std::string text = + "Malformed but matching test " + 
"<|tool_calls_section_begin|><|tool_call_begin|>invalid.format:0 " + "<|tool_call_argument_begin|>{\"param\": " + "\"value\"}<|tool_call_end|><|tool_calls_section_end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Malformed but matching test "); + // Should parse but with empty function name due to missing "functions." + // prefix + EXPECT_EQ(result.calls.size(), 1); + const auto& call = result.calls[0]; + EXPECT_TRUE(call.name.has_value()); + EXPECT_EQ(call.name.value(), ""); // Empty function name for malformed ID + EXPECT_EQ(call.tool_index, 0); +} + +// Test multiple sections (edge case) +TEST_F(KimiK2DetectorTest, MultipleSections) { + std::string text = + "First section " + "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_current_" + "weather:0 <|tool_call_argument_begin|>{\"location\": " + "\"Beijing\"}<|tool_call_end|><|tool_calls_section_end|> Middle text " + "<|tool_calls_section_begin|><|tool_call_begin|>functions.calculate:1 " + "<|tool_call_argument_begin|>{\"expression\": " + "\"1+1\"}<|tool_call_end|><|tool_calls_section_end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + // Should extract text before first section + EXPECT_EQ(result.normal_text, "First section "); + // Should parse all tool calls from all sections + EXPECT_EQ(result.calls.size(), 2); + + EXPECT_EQ(result.calls[0].name.value(), "get_current_weather"); + EXPECT_EQ(result.calls[1].name.value(), "calculate"); +} + +// Performance test: many tool calls +TEST_F(KimiK2DetectorTest, PerformanceWithManyToolCalls) { + std::string text = "Performance test <|tool_calls_section_begin|>"; + + // Build text containing multiple tool calls + for (int i = 0; i < 10000; ++i) { + text += "<|tool_call_begin|>functions.calculate:" + std::to_string(i) + + " <|tool_call_argument_begin|>{\"expression\": \"" + + std::to_string(i) + " + " + std::to_string(i + 1) + + "\"}<|tool_call_end|>"; + } + text += 
"<|tool_calls_section_end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Performance test "); + EXPECT_EQ(result.calls.size(), 10000); + + // Verify each tool call is correctly parsed + for (int i = 0; i < 10000; ++i) { + const auto& call = result.calls[i]; + EXPECT_EQ(call.tool_index, i); + EXPECT_EQ(call.name.value(), "calculate"); + + nlohmann::json params = nlohmann::json::parse(call.parameters); + std::string expected_expr = + std::to_string(i) + " + " + std::to_string(i + 1); + EXPECT_EQ(params["expression"], expected_expr); + } +} + +// Test edge case: nested braces in JSON +TEST_F(KimiK2DetectorTest, NestedBracesInJson) { + std::string text = + "Nested braces test " + "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_current_" + "weather:0 <|tool_call_argument_begin|>{\"location\": \"Beijing\", " + "\"config\": {\"nested\": {\"deep\": " + "\"value\"}}}<|tool_call_end|><|tool_calls_section_end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Nested braces test "); + EXPECT_EQ(result.calls.size(), 1); + + const auto& call = result.calls[0]; + EXPECT_EQ(call.tool_index, 0); + EXPECT_EQ(call.name.value(), "get_current_weather"); + + nlohmann::json params = nlohmann::json::parse(call.parameters); + EXPECT_EQ(params["location"], "Beijing"); + EXPECT_EQ(params["config"]["nested"]["deep"], "value"); +} + +} // namespace function_call +} // namespace llm \ No newline at end of file diff --git a/xllm/core/function_call/qwen25_detector_test.cpp b/xllm/core/function_call/qwen25_detector_test.cpp new file mode 100644 index 00000000..6a6f1976 --- /dev/null +++ b/xllm/core/function_call/qwen25_detector_test.cpp @@ -0,0 +1,328 @@ +#include "qwen25_detector.h" + +#include + +#include +#include +#include + +namespace llm { +namespace function_call { + +class Qwen25DetectorTest : public ::testing::Test { + protected: + void SetUp() override { + detector_ = 
std::make_unique(); + + // Setup test tools + nlohmann::json weather_params = { + {"type", "object"}, + {"properties", + {{"location", + {{"type", "string"}, + {"description", "The city and state, e.g. San Francisco, CA"}}}, + {"unit", {{"type", "string"}, {"enum", {"celsius", "fahrenheit"}}}}}}, + {"required", {"location"}}}; + + JsonFunction weather_func("get_current_weather", + "Get the current weather in a given location", + weather_params); + weather_tool_ = JsonTool("function", weather_func); + + nlohmann::json calculator_params = { + {"type", "object"}, + {"properties", + {{"expression", + {{"type", "string"}, + {"description", "Mathematical expression to evaluate"}}}}}, + {"required", {"expression"}}}; + + JsonFunction calculator_func( + "calculate", "Calculate mathematical expressions", calculator_params); + calculator_tool_ = JsonTool("function", calculator_func); + + tools_ = {weather_tool_, calculator_tool_}; + } + + std::unique_ptr detector_; + JsonTool weather_tool_; + JsonTool calculator_tool_; + std::vector tools_; +}; + +// Test constructor and basic properties +TEST_F(Qwen25DetectorTest, ConstructorInitializesCorrectly) { + EXPECT_NE(detector_, nullptr); + + // Test basic token detection + std::string text_with_tool_call = + "Some text \n{\"name\": \"test\"}\n"; + std::string text_without_tool_call = + "Just normal text without any tool calls"; + + EXPECT_TRUE(detector_->has_tool_call(text_with_tool_call)); + EXPECT_FALSE(detector_->has_tool_call(text_without_tool_call)); +} + +// Test has_tool_call method +TEST_F(Qwen25DetectorTest, HasToolCallDetection) { + // Test text containing tool calls + EXPECT_TRUE(detector_->has_tool_call("\n")); + EXPECT_TRUE( + detector_->has_tool_call("Previous text \nFollowing content")); + EXPECT_TRUE(detector_->has_tool_call( + "\n{\"name\": \"test\"}\n")); + + // Test text not containing tool calls + EXPECT_FALSE(detector_->has_tool_call("")); + EXPECT_FALSE(detector_->has_tool_call("Regular text")); + 
EXPECT_FALSE(detector_->has_tool_call("tool_call without brackets")); + EXPECT_FALSE(detector_->has_tool_call("\n {\"name\": " + "\"get_current_weather\", \"arguments\": {\"location\": \"Beijing\"}} " + "\n \t\r\n"; + + auto result = detector_->detect_and_parse(text_with_whitespace, tools_); + + // Verify normal text is correctly trimmed + EXPECT_EQ(result.normal_text, "Previous text"); + + // Verify tool call is correctly parsed + EXPECT_EQ(result.calls.size(), 1); + EXPECT_EQ(result.calls[0].tool_index, -1); // Base class always returns -1 +} + +// Test single tool call parsing +TEST_F(Qwen25DetectorTest, SingleToolCallParsing) { + std::string text = + "Please help me check the weather \n{\"name\": " + "\"get_current_weather\", \"arguments\": {\"location\": \"Beijing\", " + "\"unit\": \"celsius\"}}\n"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Please help me check the weather"); + EXPECT_EQ(result.calls.size(), 1); + + const auto& call = result.calls[0]; + EXPECT_EQ(call.tool_index, -1); // Base class always returns -1 + EXPECT_TRUE(call.name.has_value()); + EXPECT_EQ(call.name.value(), "get_current_weather"); + + // Verify parameter JSON + nlohmann::json params = nlohmann::json::parse(call.parameters); + EXPECT_EQ(params["location"], "Beijing"); + EXPECT_EQ(params["unit"], "celsius"); +} + +// Test multiple tool calls parsing +TEST_F(Qwen25DetectorTest, MultipleToolCallsParsing) { + std::string text = + "Please help me check the weather and calculate an expression " + "\n{\"name\": \"get_current_weather\", \"arguments\": " + "{\"location\": \"Shanghai\"}}\n\n\n{\"name\": " + "\"calculate\", \"arguments\": {\"expression\": \"2 + 3 * " + "4\"}}\n"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, + "Please help me check the weather and calculate an expression"); + EXPECT_EQ(result.calls.size(), 2); + + // Verify first tool call + const auto& call1 = result.calls[0]; 
+ EXPECT_EQ(call1.tool_index, -1); // Base class always returns -1 + EXPECT_TRUE(call1.name.has_value()); + EXPECT_EQ(call1.name.value(), "get_current_weather"); + + nlohmann::json params1 = nlohmann::json::parse(call1.parameters); + EXPECT_EQ(params1["location"], "Shanghai"); + + // Verify second tool call + const auto& call2 = result.calls[1]; + EXPECT_EQ(call2.tool_index, -1); // Base class always returns -1 + EXPECT_TRUE(call2.name.has_value()); + EXPECT_EQ(call2.name.value(), "calculate"); + + nlohmann::json params2 = nlohmann::json::parse(call2.parameters); + EXPECT_EQ(params2["expression"], "2 + 3 * 4"); +} + +// Test invalid JSON handling +TEST_F(Qwen25DetectorTest, InvalidJsonHandling) { + std::string text = + "Test invalid JSON \n{\"name\": \"get_current_weather\", " + "\"arguments\": {\"location\": \"Beijing\", invalid_json}}\n"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Test invalid JSON"); + EXPECT_EQ(result.calls.size(), 0); // Invalid JSON should be ignored +} + +// Test empty tool call content +TEST_F(Qwen25DetectorTest, EmptyToolCallContent) { + std::string text = "Test empty content \n \t\n \n"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Test empty content"); + EXPECT_EQ(result.calls.size(), 0); // Empty content should be ignored +} + +// Test incomplete tool call (only start tag) +TEST_F(Qwen25DetectorTest, IncompleteToolCall) { + std::string text = + "Incomplete tool call \n{\"name\": \"get_current_weather\""; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Incomplete tool call"); + EXPECT_EQ(result.calls.size(), 0); // Incomplete calls should be ignored +} + +// Test unknown tool name handling +TEST_F(Qwen25DetectorTest, UnknownToolName) { + std::string text = + "Unknown tool \n{\"name\": \"unknown_tool\", \"arguments\": " + "{\"param\": \"value\"}}\n"; + + auto result = 
detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Unknown tool"); + // Base class will skip unknown tools, so should be 0 calls + EXPECT_EQ(result.calls.size(), 0); +} + +// Test case with only normal text +TEST_F(Qwen25DetectorTest, OnlyNormalText) { + std::string text = "This is a regular text without any tool calls."; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, + "This is a regular text without any tool calls."); + EXPECT_EQ(result.calls.size(), 0); + EXPECT_FALSE(result.has_calls()); +} + +// Test empty string input +TEST_F(Qwen25DetectorTest, EmptyStringInput) { + std::string text = ""; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, ""); + EXPECT_EQ(result.calls.size(), 0); + EXPECT_FALSE(result.has_calls()); +} + +// Test whitespace-only input +TEST_F(Qwen25DetectorTest, WhitespaceOnlyInput) { + std::string text = " \t\n\r "; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, ""); + EXPECT_EQ(result.calls.size(), 0); +} + +// Test complex nested JSON parameters +TEST_F(Qwen25DetectorTest, ComplexNestedJsonParameters) { + std::string text = + "Complex parameter test \n{\"name\": \"get_current_weather\", " + "\"arguments\": {\"location\": \"Beijing\", \"options\": " + "{\"include_forecast\": true, \"days\": 7, \"details\": " + "[\"temperature\", \"humidity\", \"wind\"]}}}\n"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Complex parameter test"); + EXPECT_EQ(result.calls.size(), 1); + + const auto& call = result.calls[0]; + EXPECT_EQ(call.tool_index, -1); // Base class always returns -1 + + nlohmann::json params = nlohmann::json::parse(call.parameters); + EXPECT_EQ(params["location"], "Beijing"); + EXPECT_TRUE(params["options"]["include_forecast"]); + EXPECT_EQ(params["options"]["days"], 7); + EXPECT_EQ(params["options"]["details"].size(), 3); 
+} + +// Test tool call in the middle of text +TEST_F(Qwen25DetectorTest, ToolCallInMiddleOfText) { + std::string text = + "Previous text \n{\"name\": \"calculate\", \"arguments\": " + "{\"expression\": \"1+1\"}}\n Following text"; + + auto result = detector_->detect_and_parse(text, tools_); + + // Note: According to implementation, only text before tool call is preserved + // as normal_text + EXPECT_EQ(result.normal_text, "Previous text"); + EXPECT_EQ(result.calls.size(), 1); + + const auto& call = result.calls[0]; + EXPECT_EQ(call.tool_index, -1); // Base class always returns -1 + EXPECT_EQ(call.name.value(), "calculate"); +} + +// Test special characters handling +TEST_F(Qwen25DetectorTest, SpecialCharactersHandling) { + std::string text = + "Special characters test \n{\"name\": " + "\"get_current_weather\", \"arguments\": {\"location\": \"New York " + "City\", \"note\": \"Contains symbols!@#$%^&*()_+=\"}}\n"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Special characters test"); + EXPECT_EQ(result.calls.size(), 1); + + const auto& call = result.calls[0]; + EXPECT_EQ(call.tool_index, -1); // Base class always returns -1 + nlohmann::json params = nlohmann::json::parse(call.parameters); + EXPECT_EQ(params["location"], "New York City"); + EXPECT_EQ(params["note"], "Contains symbols!@#$%^&*()_+="); +} + +// Performance test: many tool calls +TEST_F(Qwen25DetectorTest, PerformanceWithManyToolCalls) { + std::string text = "Performance test"; + + // Build text containing multiple tool calls + for (int i = 0; i < 10000; ++i) { + text += + " \n{\"name\": \"calculate\", \"arguments\": " + "{\"expression\": \"" + + std::to_string(i) + " + " + std::to_string(i + 1) + + "\"}}\n"; + } + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Performance test"); + EXPECT_EQ(result.calls.size(), 10000); + + // Verify each tool call is correctly parsed + for (int i = 0; i < 10000; ++i) { + 
const auto& call = result.calls[i]; + EXPECT_EQ(call.tool_index, -1); // Base class always returns -1 + EXPECT_EQ(call.name.value(), "calculate"); + + nlohmann::json params = nlohmann::json::parse(call.parameters); + std::string expected_expr = + std::to_string(i) + " + " + std::to_string(i + 1); + EXPECT_EQ(params["expression"], expected_expr); + } +} + +} // namespace function_call +} // namespace llm \ No newline at end of file From 31389d5ba96f0a9a07157327b78ee8c4c8489350 Mon Sep 17 00:00:00 2001 From: dengyingxu1 Date: Tue, 19 Aug 2025 16:39:29 +0800 Subject: [PATCH 9/9] feat: cherry-pick non-streaming tools call implementation. --- xllm/CMakeLists.txt | 1 + xllm/api_service/CMakeLists.txt | 1 + xllm/api_service/api_service.cpp | 1 - xllm/api_service/chat_service_impl.cpp | 230 +++++------------- xllm/core/CMakeLists.txt | 1 - xllm/core/common/options.h | 2 + xllm/core/common/types.h | 22 ++ .../chat_template/jinja_chat_template.cpp | 166 ++----------- .../chat_template/jinja_chat_template.h | 8 + .../core/framework/request/request_params.cpp | 145 +++++------ xllm/core/framework/request/request_params.h | 12 +- xllm/{core => }/function_call/CMakeLists.txt | 0 .../function_call/base_format_detector.cpp | 4 +- .../function_call/base_format_detector.h | 4 +- xllm/{core => }/function_call/core_types.h | 28 +-- .../function_call/deepseekv3_detector.cpp | 4 +- .../function_call/deepseekv3_detector.h | 4 +- .../deepseekv3_detector_test.cpp | 4 +- xllm/{core => }/function_call/function_call.h | 8 +- .../function_call/function_call_parser.cpp | 14 +- .../function_call/function_call_parser.h | 4 +- .../function_call/kimik2_detector.cpp | 4 +- .../function_call/kimik2_detector.h | 4 +- .../function_call/kimik2_detector_test.cpp | 4 +- .../function_call/qwen25_detector.cpp | 4 +- .../function_call/qwen25_detector.h | 4 +- .../function_call/qwen25_detector_test.cpp | 4 +- xllm/proto/chat.proto | 25 -- xllm/proto/common.proto | 25 ++ xllm/proto/multimodal.proto | 7 + 30 
files changed, 271 insertions(+), 473 deletions(-) rename xllm/{core => }/function_call/CMakeLists.txt (100%) rename xllm/{core => }/function_call/base_format_detector.cpp (98%) rename xllm/{core => }/function_call/base_format_detector.h (98%) rename xllm/{core => }/function_call/core_types.h (72%) rename xllm/{core => }/function_call/deepseekv3_detector.cpp (99%) rename xllm/{core => }/function_call/deepseekv3_detector.h (94%) rename xllm/{core => }/function_call/deepseekv3_detector_test.cpp (99%) rename xllm/{core => }/function_call/function_call.h (94%) rename xllm/{core => }/function_call/function_call_parser.cpp (98%) rename xllm/{core => }/function_call/function_call_parser.h (97%) rename xllm/{core => }/function_call/kimik2_detector.cpp (99%) rename xllm/{core => }/function_call/kimik2_detector.h (96%) rename xllm/{core => }/function_call/kimik2_detector_test.cpp (99%) rename xllm/{core => }/function_call/qwen25_detector.cpp (98%) rename xllm/{core => }/function_call/qwen25_detector.h (95%) rename xllm/{core => }/function_call/qwen25_detector_test.cpp (99%) diff --git a/xllm/CMakeLists.txt b/xllm/CMakeLists.txt index 0f00463f..9a4f7770 100644 --- a/xllm/CMakeLists.txt +++ b/xllm/CMakeLists.txt @@ -4,6 +4,7 @@ include_directories(.) 
add_subdirectory(api_service) add_subdirectory(core) +add_subdirectory(function_call) add_subdirectory(models) add_subdirectory(proto) add_subdirectory(processors) diff --git a/xllm/api_service/CMakeLists.txt b/xllm/api_service/CMakeLists.txt index f9ff45a5..8c7dd5ec 100644 --- a/xllm/api_service/CMakeLists.txt +++ b/xllm/api_service/CMakeLists.txt @@ -29,5 +29,6 @@ cc_library( proto::xllm_proto absl::flat_hash_set absl::random_random + :function_call ) diff --git a/xllm/api_service/api_service.cpp b/xllm/api_service/api_service.cpp index 0f087293..d882505f 100644 --- a/xllm/api_service/api_service.cpp +++ b/xllm/api_service/api_service.cpp @@ -16,7 +16,6 @@ #include "models.pb.h" #include "service_impl_factory.h" #include "xllm_metrics.h" -#include namespace xllm { APIService::APIService(Master* master, diff --git a/xllm/api_service/chat_service_impl.cpp b/xllm/api_service/chat_service_impl.cpp index 7b764865..d4e979d4 100644 --- a/xllm/api_service/chat_service_impl.cpp +++ b/xllm/api_service/chat_service_impl.cpp @@ -13,36 +13,29 @@ #include #include "core/common/instance_name.h" +#include "core/common/types.h" #include "core/framework/request/mm_input_helper.h" #include "core/framework/request/request_params.h" #include "core/runtime/llm_master.h" #include "core/runtime/vlm_master.h" #include "core/util/utils.h" #include "core/util/uuid.h" -#include "chat_template/chat_template.h" -#include "common/instance_name.h" -#include "common/uuid.h" -#include "function_call/core_types.h" #include "function_call/function_call.h" -#include "request/request_params.h" -#include "util/utils.h" namespace xllm { namespace { - struct ToolCallResult { std::optional> tool_calls; std::string text; std::string finish_reason; }; -ToolCallResult process_tool_calls( - std::string text, - const std::vector& tools, - const std::string& parser_format, - std::string finish_reason, - google::protobuf::Arena* arena = nullptr) { +ToolCallResult process_tool_calls(std::string text, + const 
std::vector& tools, + const std::string& parser_format, + std::string finish_reason, + google::protobuf::Arena* arena = nullptr) { ToolCallResult result; function_call::FunctionCallParser parser(tools, parser_format); @@ -204,14 +197,13 @@ bool send_delta_to_client_brpc(std::shared_ptr call, } template -bool send_result_to_client_brpc( - std::shared_ptr call, - const std::string& request_id, - int64_t created_time, - const std::string& model, - const RequestOutput& req_output, - const std::string& parser_format = "", - const std::vector& tools = {}) { +bool send_result_to_client_brpc(std::shared_ptr call, + const std::string& request_id, + int64_t created_time, + const std::string& model, + const RequestOutput& req_output, + const std::string& parser_format = "", + const std::vector& tools = {}) { auto& response = call->response(); response.set_object("chat.completion"); response.set_id(request_id); @@ -327,7 +319,8 @@ void ChatServiceImpl::process_async_impl(std::shared_ptr call) { include_usage = include_usage, first_message_sent = std::unordered_set(), request_id = request_params.request_id, - created_time = absl::ToUnixSeconds(absl::Now())]( + created_time = absl::ToUnixSeconds(absl::Now()), + json_tools = request_params.tools]( const RequestOutput& req_output) mutable -> bool { if (req_output.status.has_value()) { const auto& status = req_output.status.value(); @@ -346,110 +339,48 @@ void ChatServiceImpl::process_async_impl(std::shared_ptr call) { master->get_rate_limiter()->decrease_one_request(); } - if (stream) { - // send delta to client - return send_delta_to_client_brpc(call, - include_usage, - &first_message_sent, - request_id, - created_time, - model, - req_output); - } - return send_result_to_client_brpc( - call, request_id, created_time, model, req_output); - }); -} - - -} // namespace - -ChatServiceImpl::ChatServiceImpl(LLMMaster* master, - const std::vector& models) - : APIServiceImpl(master, models) {} - -// chat_async for brpc -void 
ChatServiceImpl::process_async_impl(std::shared_ptr call) { - const auto& rpc_request = call->request(); - // check if model is supported - const auto& model = rpc_request.model(); - if (!models_.contains(model)) { - call->finish_with_error(StatusCode::UNKNOWN, "Model not supported"); - return; - } - - // Check if the request is being rate-limited. - if (master_->get_rate_limiter()->is_limited()) { - call->finish_with_error( - StatusCode::RESOURCE_EXHAUSTED, - "The number of concurrent requests has reached the limit."); - return; - } - - RequestParams request_params( - rpc_request, call->get_x_request_id(), call->get_x_request_time()); - std::vector messages; - messages.reserve(rpc_request.messages_size()); - for (const auto& message : rpc_request.messages()) { - messages.emplace_back(message.role(), message.content()); - } - - bool include_usage = false; - if (rpc_request.has_stream_options()) { - include_usage = rpc_request.stream_options().include_usage(); - } - std::optional> prompt_tokens = std::nullopt; - if (rpc_request.has_routing()) { - prompt_tokens = std::vector{}; - prompt_tokens->reserve(rpc_request.routing().token_ids_size()); - for (int i = 0; i < rpc_request.routing().token_ids_size(); i++) { - prompt_tokens->emplace_back(rpc_request.routing().token_ids(i)); - } - - request_params.decode_address = rpc_request.routing().decode_name(); - } - - master_->handle_request( - std::move(messages), - std::move(prompt_tokens), - std::move(request_params), - [call, - model, - master = master_, - stream = request_params.streaming, - include_usage = include_usage, - first_message_sent = std::unordered_set(), - request_id = request_params.request_id, - created_time = absl::ToUnixSeconds(absl::Now())]( - const RequestOutput& req_output) mutable -> bool { - if (req_output.status.has_value()) { - const auto& status = req_output.status.value(); - if (!status.ok()) { - // Reduce the number of concurrent requests when a - // request is finished with error. 
- master->get_rate_limiter()->decrease_one_request(); - - return call->finish_with_error(status.code(), status.message()); - } - } - - // Reduce the number of concurrent requests when a request - // is finished or canceled. - if (req_output.finished || req_output.cancelled) { - master->get_rate_limiter()->decrease_one_request(); - } + const std::string parser_format = + master->options().tool_call_parser().value_or(""); + const bool has_tool_support = + !json_tools.empty() && !parser_format.empty(); if (stream) { - return send_delta_to_client_brpc(call, - include_usage, - &first_message_sent, - request_id, - created_time, - model, - req_output); + if (has_tool_support) { + // TODO: Support tool call streaming output + LOG(ERROR) << "Tool call does not support streaming output"; + return send_delta_to_client_brpc(call, + include_usage, + &first_message_sent, + request_id, + created_time, + model, + req_output); + } else { + // Stream response without tool support + return send_delta_to_client_brpc(call, + include_usage, + &first_message_sent, + request_id, + created_time, + model, + req_output); + } + } else { + if (has_tool_support) { + // Non-stream response with tool support + return send_result_to_client_brpc(call, + request_id, + created_time, + model, + req_output, + parser_format, + json_tools); + } else { + // Non-stream response without tool support + return send_result_to_client_brpc( + call, request_id, created_time, model, req_output); + } } - return send_result_to_client_brpc( - call, request_id, created_time, model, req_output); }); } @@ -506,8 +437,7 @@ void MMChatServiceImpl::process_async(std::shared_ptr call) { include_usage = include_usage, first_message_sent = std::unordered_set(), request_id = request_params.request_id, - created_time = absl::ToUnixSeconds(absl::Now()), - json_tools = request_params.tools]( + created_time = absl::ToUnixSeconds(absl::Now())]( const RequestOutput& req_output) mutable -> bool { if (req_output.status.has_value()) { 
const auto& status = req_output.status.value(); @@ -526,48 +456,18 @@ void MMChatServiceImpl::process_async(std::shared_ptr call) { master->get_rate_limiter()->decrease_one_request(); } - const std::string parser_format = - master->options().tool_call_parser().value_or(""); - const bool has_tool_support = - !json_tools.empty() && !parser_format.empty(); - if (stream) { - if (has_tool_support) { - // TODO: Support tool call streaming output - LOG(ERROR) << "Tool call does not support streaming output"; - return send_delta_to_client_brpc(call, - include_usage, - &first_message_sent, - request_id, - created_time, - model, - req_output); - } else { - // Stream response without tool support - return send_delta_to_client_brpc(call, - include_usage, - &first_message_sent, - request_id, - created_time, - model, - req_output); - } - } else { - if (has_tool_support) { - // Non-stream response with tool support - return send_result_to_client_brpc(call, - request_id, - created_time, - model, - req_output, - parser_format, - json_tools); - } else { - // Non-stream response without tool support - return send_result_to_client_brpc( - call, request_id, created_time, model, req_output); - } + // send delta to client + return send_delta_to_client_brpc(call, + include_usage, + &first_message_sent, + request_id, + created_time, + model, + req_output); } + return send_result_to_client_brpc( + call, request_id, created_time, model, req_output); }); } diff --git a/xllm/core/CMakeLists.txt b/xllm/core/CMakeLists.txt index ff88a505..d0c8f682 100644 --- a/xllm/core/CMakeLists.txt +++ b/xllm/core/CMakeLists.txt @@ -1,5 +1,4 @@ add_subdirectory(common) -add_subdirectory(function_call) add_subdirectory(distributed_runtime) add_subdirectory(framework) add_subdirectory(kernels) diff --git a/xllm/core/common/options.h b/xllm/core/common/options.h index 60219d9a..4a2043f7 100644 --- a/xllm/core/common/options.h +++ b/xllm/core/common/options.h @@ -95,6 +95,8 @@ class Options { 
PROPERTY(std::optional, etcd_addr); PROPERTY(bool, enable_service_routing) = false; + + PROPERTY(std::optional, tool_call_parser); }; } // namespace xllm diff --git a/xllm/core/common/types.h b/xllm/core/common/types.h index 8d3e6e82..b7d9bd1a 100644 --- a/xllm/core/common/types.h +++ b/xllm/core/common/types.h @@ -214,4 +214,26 @@ struct DeviceStats { // TODO: add more device stats }; +// Function call related types +struct JsonFunction { + std::string name; + std::string description; + nlohmann::json parameters; + + JsonFunction() = default; + JsonFunction(const std::string& func_name, + const std::string& desc, + const nlohmann::json& params) + : name(func_name), description(desc), parameters(params) {} +}; + +struct JsonTool { + std::string type; // "function" + JsonFunction function; + + JsonTool() : type("function") {} + JsonTool(const std::string& tool_type, const JsonFunction& func) + : type(tool_type), function(func) {} +}; + } // namespace xllm diff --git a/xllm/core/framework/chat_template/jinja_chat_template.cpp b/xllm/core/framework/chat_template/jinja_chat_template.cpp index e2ec1c0a..5e114b1c 100644 --- a/xllm/core/framework/chat_template/jinja_chat_template.cpp +++ b/xllm/core/framework/chat_template/jinja_chat_template.cpp @@ -1,7 +1,6 @@ #include "jinja_chat_template.h" #include -#include #include #include @@ -24,12 +23,20 @@ JinjaChatTemplate::JinjaChatTemplate(const TokenizerArgs& args) : args_(args) { std::optional JinjaChatTemplate::apply( const ChatMessages& messages) const { - const std::vector empty_tools; + const std::vector empty_tools; return apply(messages, empty_tools); } std::optional JinjaChatTemplate::apply( - const ChatMessages& messages, const std::vector& tools) const { + nlohmann::ordered_json& messages) const { + // Call the overloaded method with empty tools + nlohmann::ordered_json empty_tools = nlohmann::json::array(); + return apply(messages, empty_tools); +} + +std::optional JinjaChatTemplate::apply( + const ChatMessages& 
messages, + const std::vector& json_tools) const { // convert the messages to json object nlohmann::ordered_json messages_json = nlohmann::json::array(); for (const auto& message : messages) { @@ -45,31 +52,24 @@ std::optional JinjaChatTemplate::apply( messages_json.push_back(message_json); } - - // convert tools to json object + nlohmann::ordered_json tools_json = nlohmann::json::array(); - if (!tools.empty()) { - try { - // Use ToolsConverter to convert tools to JSON string, then parse it - std::string tools_json_str = ToolsConverter::convert_tools_to_json(tools); - tools_json = nlohmann::json::parse(tools_json_str); - } catch (const std::exception& e) { - LOG(WARNING) << "Failed to convert tools to JSON: " << e.what(); - // Continue with empty tools array - } + for (const auto& json_tool : json_tools) { + nlohmann::ordered_json tool_json; + tool_json["type"] = json_tool.type; + + nlohmann::ordered_json function_json; + function_json["name"] = json_tool.function.name; + function_json["description"] = json_tool.function.description; + function_json["parameters"] = json_tool.function.parameters; + + tool_json["function"] = function_json; + tools_json.push_back(tool_json); } - - // apply the template with tools + // apply the template return apply(messages_json, tools_json); } -std::optional JinjaChatTemplate::apply( - nlohmann::ordered_json& messages) const { - // Call the overloaded method with empty tools - nlohmann::ordered_json empty_tools = nlohmann::json::array(); - return apply(messages, empty_tools); -} - std::optional JinjaChatTemplate::apply( nlohmann::ordered_json& messages, const nlohmann::ordered_json& tools) const { @@ -82,126 +82,6 @@ std::optional JinjaChatTemplate::apply( return template_->apply(input, options); } -std::optional JinjaChatTemplate::apply( - const ChatMessages& messages, - const std::vector& proto_tools) const { - // convert the messages to json object - nlohmann::ordered_json messages_json = nlohmann::json::array(); - for (const 
auto& message : messages) { - nlohmann::ordered_json message_json; - message_json["role"] = message.role; - message_json["content"] = message.content; - messages_json.push_back(message_json); - } - - // convert protobuf tools to json object - nlohmann::ordered_json tools_json = nlohmann::json::array(); - if (!proto_tools.empty()) { - try { - for (const auto& proto_tool : proto_tools) { - nlohmann::ordered_json tool_json; - tool_json["type"] = proto_tool.type(); - - nlohmann::ordered_json function_json; - function_json["name"] = proto_tool.function().name(); - function_json["description"] = proto_tool.function().description(); - - if (proto_tool.function().has_parameters()) { - std::string parameters_json_str; - google::protobuf::util::JsonPrintOptions options; - options.add_whitespace = false; - options.preserve_proto_field_names = true; - auto status = google::protobuf::util::MessageToJsonString( - proto_tool.function().parameters(), - ¶meters_json_str, - options); - if (status.ok()) { - function_json["parameters"] = - nlohmann::json::parse(parameters_json_str); - } else { - LOG(WARNING) << "Failed to convert parameters Struct to JSON: " - << status.message() - << ", tool: " << proto_tool.function().name(); - function_json["parameters"] = nlohmann::json::object(); - } - } else { - function_json["parameters"] = nlohmann::json::object(); - } - - tool_json["function"] = function_json; - tools_json.push_back(tool_json); - } - } catch (const std::exception& e) { - LOG(WARNING) << "Failed to convert protobuf tools to JSON: " << e.what(); - // Continue with empty tools array - tools_json = nlohmann::json::array(); - } - } - - // apply the template with tools - return apply(messages_json, tools_json); -} - -std::optional JinjaChatTemplate::apply( - const ChatMessages& messages, - const std::vector& proto_tools) const { - // convert the messages to json object - nlohmann::ordered_json messages_json = nlohmann::json::array(); - for (const auto& message : messages) { - 
nlohmann::ordered_json message_json; - message_json["role"] = message.role; - message_json["content"] = message.content; - messages_json.push_back(message_json); - } - - // convert protobuf tools to json object - nlohmann::ordered_json tools_json = nlohmann::json::array(); - if (!proto_tools.empty()) { - try { - for (const auto& proto_tool : proto_tools) { - nlohmann::ordered_json tool_json; - tool_json["type"] = proto_tool.type(); - - nlohmann::ordered_json function_json; - function_json["name"] = proto_tool.function().name(); - function_json["description"] = proto_tool.function().description(); - - if (proto_tool.function().has_parameters()) { - std::string parameters_json_str; - google::protobuf::util::JsonPrintOptions options; - options.add_whitespace = false; - options.preserve_proto_field_names = true; - auto status = google::protobuf::util::MessageToJsonString( - proto_tool.function().parameters(), - ¶meters_json_str, - options); - if (status.ok()) { - function_json["parameters"] = - nlohmann::json::parse(parameters_json_str); - } else { - LOG(WARNING) << "Failed to convert parameters Struct to JSON: " - << status.message() - << ", tool: " << proto_tool.function().name(); - function_json["parameters"] = nlohmann::json::object(); - } - } else { - function_json["parameters"] = nlohmann::json::object(); - } - - tool_json["function"] = function_json; - tools_json.push_back(tool_json); - } - } catch (const std::exception& e) { - LOG(WARNING) << "Failed to convert protobuf tools to JSON: " << e.what(); - // Continue with empty tools array - tools_json = nlohmann::json::array(); - } - } - - // apply the template with tools - return apply(messages_json, tools_json); -} - nlohmann::ordered_json JinjaChatTemplate::get_mm_content( const Message::MMContentVec& vec) const { nlohmann::ordered_json content_json = nlohmann::json::array(); diff --git a/xllm/core/framework/chat_template/jinja_chat_template.h b/xllm/core/framework/chat_template/jinja_chat_template.h index 
cd2d7e4b..fcb95b1f 100644 --- a/xllm/core/framework/chat_template/jinja_chat_template.h +++ b/xllm/core/framework/chat_template/jinja_chat_template.h @@ -7,6 +7,7 @@ #include #include +#include "core/common/types.h" #include "framework/tokenizer/tokenizer_args.h" namespace xllm { @@ -52,10 +53,17 @@ class JinjaChatTemplate { std::optional apply(const ChatMessages& messages) const; + std::optional apply( + const ChatMessages& messages, + const std::vector& json_tools) const; + // expose this function for testing // apply the template to the values in the json object std::optional apply(nlohmann::ordered_json& messages) const; + std::optional apply(nlohmann::ordered_json& messages, + const nlohmann::ordered_json& tools) const; + private: nlohmann::ordered_json get_mm_content(const Message::MMContentVec& vec) const; diff --git a/xllm/core/framework/request/request_params.cpp b/xllm/core/framework/request/request_params.cpp index 6941cd00..10fd3b9b 100644 --- a/xllm/core/framework/request/request_params.cpp +++ b/xllm/core/framework/request/request_params.cpp @@ -25,74 +25,6 @@ std::string generate_chat_request_id() { } // namespace -nlohmann::json RequestParams::proto_value_to_json( - const google::protobuf::Value& pb_value) { - switch (pb_value.kind_case()) { - case google::protobuf::Value::kNullValue: - return nlohmann::json(nullptr); - - case google::protobuf::Value::kNumberValue: - return nlohmann::json(pb_value.number_value()); - - case google::protobuf::Value::kStringValue: - return nlohmann::json(pb_value.string_value()); - - case google::protobuf::Value::kBoolValue: - return nlohmann::json(pb_value.bool_value()); - - case google::protobuf::Value::kStructValue: - return proto_struct_to_json(pb_value.struct_value()); - - case google::protobuf::Value::kListValue: { - nlohmann::json array = nlohmann::json::array(); - const auto& list = pb_value.list_value(); - for (const auto& item : list.values()) { - array.push_back(proto_value_to_json(item)); - } - return 
array; - } - - case google::protobuf::Value::KIND_NOT_SET: - default: - return nlohmann::json(nullptr); - } -} - -nlohmann::json RequestParams::proto_struct_to_json( - const google::protobuf::Struct& pb_struct) { - nlohmann::json result = nlohmann::json::object(); - - for (const auto& field : pb_struct.fields()) { - result[field.first] = proto_value_to_json(field.second); - } - - return result; -} - -void RequestParams::parse_tools_from_proto( - const google::protobuf::RepeatedPtrField& proto_tools) { - tools.clear(); - tools.reserve(proto_tools.size()); - - for (const auto& proto_tool : proto_tools) { - function_call::JsonTool json_tool; - json_tool.type = proto_tool.type(); - - const auto& proto_function = proto_tool.function(); - json_tool.function.name = proto_function.name(); - json_tool.function.description = proto_function.description(); - - if (proto_function.has_parameters()) { - json_tool.function.parameters = - proto_struct_to_json(proto_function.parameters()); - } else { - json_tool.function.parameters = nlohmann::json::object(); - } - - tools.emplace_back(std::move(json_tool)); - } -} - RequestParams::RequestParams(const proto::CompletionRequest& request, const std::string& x_rid, const std::string& x_rtime) { @@ -162,6 +94,77 @@ RequestParams::RequestParams(const proto::CompletionRequest& request, } namespace { + +nlohmann::json proto_value_to_json(const google::protobuf::Value& pb_value); + +nlohmann::json proto_struct_to_json(const google::protobuf::Struct& pb_struct) { + nlohmann::json result = nlohmann::json::object(); + + for (const auto& field : pb_struct.fields()) { + result[field.first] = proto_value_to_json(field.second); + } + + return result; +} + +nlohmann::json proto_value_to_json(const google::protobuf::Value& pb_value) { + switch (pb_value.kind_case()) { + case google::protobuf::Value::kNullValue: + return nlohmann::json(nullptr); + + case google::protobuf::Value::kNumberValue: + return nlohmann::json(pb_value.number_value()); + + case 
google::protobuf::Value::kStringValue: + return nlohmann::json(pb_value.string_value()); + + case google::protobuf::Value::kBoolValue: + return nlohmann::json(pb_value.bool_value()); + + case google::protobuf::Value::kStructValue: + return proto_struct_to_json(pb_value.struct_value()); + + case google::protobuf::Value::kListValue: { + nlohmann::json array = nlohmann::json::array(); + const auto& list = pb_value.list_value(); + for (const auto& item : list.values()) { + array.push_back(proto_value_to_json(item)); + } + return array; + } + + case google::protobuf::Value::KIND_NOT_SET: + default: + return nlohmann::json(nullptr); + } +} + +std::vector parse_tools_from_proto( + const google::protobuf::RepeatedPtrField& proto_tools) { + std::vector tools; + tools.clear(); + tools.reserve(proto_tools.size()); + + for (const auto& proto_tool : proto_tools) { + xllm::JsonTool json_tool; + json_tool.type = proto_tool.type(); + + const auto& proto_function = proto_tool.function(); + json_tool.function.name = proto_function.name(); + json_tool.function.description = proto_function.description(); + + if (proto_function.has_parameters()) { + json_tool.function.parameters = + proto_struct_to_json(proto_function.parameters()); + } else { + json_tool.function.parameters = nlohmann::json::object(); + } + + tools.emplace_back(std::move(json_tool)); + } + return tools; +} + template void InitFromChatRequest(RequestParams& params, const ChatRequest& request) { if (request.has_request_id()) { @@ -228,12 +231,12 @@ void InitFromChatRequest(RequestParams& params, const ChatRequest& request) { // Parse tools from proto request if (request.tools_size() > 0) { - parse_tools_from_proto(request.tools()); + params.tools = parse_tools_from_proto(request.tools()); if (request.has_tool_choice()) { - tool_choice = request.tool_choice(); + params.tool_choice = request.tool_choice(); } else { - tool_choice = "auto"; + params.tool_choice = "auto"; } } } diff --git 
a/xllm/core/framework/request/request_params.h b/xllm/core/framework/request/request_params.h index e2cea828..f10089bb 100644 --- a/xllm/core/framework/request/request_params.h +++ b/xllm/core/framework/request/request_params.h @@ -9,8 +9,8 @@ #include "common/macros.h" #include "completion.pb.h" #include "core/common/macros.h" +#include "core/common/types.h" #include "embedding.pb.h" -#include "function_call/core_types.h" #include "multimodal.pb.h" #include "request_output.h" @@ -105,17 +105,9 @@ struct RequestParams { std::string decode_address; // JSON-based tools (replacing proto_tools) - std::vector tools; + std::vector tools; std::string tool_choice = "auto"; bool has_tools() const { return !tools.empty(); } - - private: - void parse_tools_from_proto( - const google::protobuf::RepeatedPtrField& proto_tools); - - nlohmann::json proto_struct_to_json( - const google::protobuf::Struct& pb_struct); - nlohmann::json proto_value_to_json(const google::protobuf::Value& pb_value); }; } // namespace xllm diff --git a/xllm/core/function_call/CMakeLists.txt b/xllm/function_call/CMakeLists.txt similarity index 100% rename from xllm/core/function_call/CMakeLists.txt rename to xllm/function_call/CMakeLists.txt diff --git a/xllm/core/function_call/base_format_detector.cpp b/xllm/function_call/base_format_detector.cpp similarity index 98% rename from xllm/core/function_call/base_format_detector.cpp rename to xllm/function_call/base_format_detector.cpp index b4ccb3b3..7d031e11 100644 --- a/xllm/core/function_call/base_format_detector.cpp +++ b/xllm/function_call/base_format_detector.cpp @@ -4,7 +4,7 @@ #include #include -namespace llm { +namespace xllm { namespace function_call { BaseFormatDetector::BaseFormatDetector() @@ -96,4 +96,4 @@ std::vector BaseFormatDetector::parse_base_json( } } // namespace function_call -} // namespace llm \ No newline at end of file +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/function_call/base_format_detector.h 
b/xllm/function_call/base_format_detector.h similarity index 98% rename from xllm/core/function_call/base_format_detector.h rename to xllm/function_call/base_format_detector.h index 7f677f7d..cf6c68f2 100644 --- a/xllm/core/function_call/base_format_detector.h +++ b/xllm/function_call/base_format_detector.h @@ -11,7 +11,7 @@ #include "chat.pb.h" #include "core_types.h" -namespace llm { +namespace xllm { namespace function_call { class BaseFormatDetector { @@ -72,4 +72,4 @@ class BaseFormatDetector { }; } // namespace function_call -} // namespace llm \ No newline at end of file +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/function_call/core_types.h b/xllm/function_call/core_types.h similarity index 72% rename from xllm/core/function_call/core_types.h rename to xllm/function_call/core_types.h index d931fa42..2f7b17cc 100644 --- a/xllm/core/function_call/core_types.h +++ b/xllm/function_call/core_types.h @@ -6,29 +6,13 @@ #include #include -namespace llm { -namespace function_call { - -struct JsonFunction { - std::string name; - std::string description; - nlohmann::json parameters; - - JsonFunction() = default; - JsonFunction(const std::string& func_name, - const std::string& desc, - const nlohmann::json& params) - : name(func_name), description(desc), parameters(params) {} -}; +#include "core/common/types.h" -struct JsonTool { - std::string type; // "function" - JsonFunction function; +namespace xllm { +namespace function_call { - JsonTool() : type("function") {} - JsonTool(const std::string& tool_type, const JsonFunction& func) - : type(tool_type), function(func) {} -}; +using JsonFunction = xllm::JsonFunction; +using JsonTool = xllm::JsonTool; struct ToolCallItem { int tool_index; @@ -82,4 +66,4 @@ struct StructureInfo { using GetInfoFunc = std::function; } // namespace function_call -} // namespace llm \ No newline at end of file +} // namespace xllm \ No newline at end of file diff --git 
a/xllm/core/function_call/deepseekv3_detector.cpp b/xllm/function_call/deepseekv3_detector.cpp similarity index 99% rename from xllm/core/function_call/deepseekv3_detector.cpp rename to xllm/function_call/deepseekv3_detector.cpp index e3e1cc90..3c9ea919 100644 --- a/xllm/core/function_call/deepseekv3_detector.cpp +++ b/xllm/function_call/deepseekv3_detector.cpp @@ -5,7 +5,7 @@ #include #include -namespace llm { +namespace xllm { namespace function_call { DeepSeekV3Detector::DeepSeekV3Detector() : BaseFormatDetector() { @@ -172,4 +172,4 @@ StreamingParseResult DeepSeekV3Detector::detect_and_parse( } } // namespace function_call -} // namespace llm \ No newline at end of file +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/function_call/deepseekv3_detector.h b/xllm/function_call/deepseekv3_detector.h similarity index 94% rename from xllm/core/function_call/deepseekv3_detector.h rename to xllm/function_call/deepseekv3_detector.h index 55468cc7..41e82c61 100644 --- a/xllm/core/function_call/deepseekv3_detector.h +++ b/xllm/function_call/deepseekv3_detector.h @@ -4,7 +4,7 @@ #include "base_format_detector.h" -namespace llm { +namespace xllm { namespace function_call { class DeepSeekV3Detector : public BaseFormatDetector { @@ -28,4 +28,4 @@ class DeepSeekV3Detector : public BaseFormatDetector { }; } // namespace function_call -} // namespace llm \ No newline at end of file +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/function_call/deepseekv3_detector_test.cpp b/xllm/function_call/deepseekv3_detector_test.cpp similarity index 99% rename from xllm/core/function_call/deepseekv3_detector_test.cpp rename to xllm/function_call/deepseekv3_detector_test.cpp index 5284f08f..df014b65 100644 --- a/xllm/core/function_call/deepseekv3_detector_test.cpp +++ b/xllm/function_call/deepseekv3_detector_test.cpp @@ -6,7 +6,7 @@ #include #include -namespace llm { +namespace xllm { namespace function_call { class DeepSeekV3DetectorTest : 
public ::testing::Test { @@ -415,4 +415,4 @@ TEST_F(DeepSeekV3DetectorTest, MalformedTokensHandling) { } } // namespace function_call -} // namespace llm \ No newline at end of file +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/function_call/function_call.h b/xllm/function_call/function_call.h similarity index 94% rename from xllm/core/function_call/function_call.h rename to xllm/function_call/function_call.h index 87950ebc..015f614a 100644 --- a/xllm/core/function_call/function_call.h +++ b/xllm/function_call/function_call.h @@ -2,12 +2,12 @@ #include "base_format_detector.h" #include "core_types.h" +#include "deepseekv3_detector.h" #include "function_call_parser.h" -#include "qwen25_detector.h" #include "kimik2_detector.h" -#include "deepseekv3_detector.h" +#include "qwen25_detector.h" -namespace llm { +namespace xllm { namespace function_call { inline std::vector parse(const std::string& text, @@ -22,4 +22,4 @@ inline bool has_calls(const std::string& text, } } // namespace function_call -} // namespace llm \ No newline at end of file +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/function_call/function_call_parser.cpp b/xllm/function_call/function_call_parser.cpp similarity index 98% rename from xllm/core/function_call/function_call_parser.cpp rename to xllm/function_call/function_call_parser.cpp index 2fe6759c..b4643973 100644 --- a/xllm/core/function_call/function_call_parser.cpp +++ b/xllm/function_call/function_call_parser.cpp @@ -3,11 +3,11 @@ #include #include -#include "common/uuid.h" -#include "qwen25_detector.h" -#include "kimik2_detector.h" +#include "core/util/uuid.h" #include "deepseekv3_detector.h" -namespace llm { +#include "kimik2_detector.h" +#include "qwen25_detector.h" +namespace xllm { namespace function_call { const std::unordered_map @@ -67,11 +67,11 @@ std::unique_ptr FunctionCallParser::create_detector( if (it->second == "qwen25") { return std::make_unique(); } - + if (it->second == 
"kimi_k2") { return std::make_unique(); } - + if (it->second == "deepseekv3") { return std::make_unique(); } @@ -120,4 +120,4 @@ std::string generate_tool_call_id() { return "call_" + short_uuid.random(); } } // namespace utils } // namespace function_call -} // namespace llm \ No newline at end of file +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/function_call/function_call_parser.h b/xllm/function_call/function_call_parser.h similarity index 97% rename from xllm/core/function_call/function_call_parser.h rename to xllm/function_call/function_call_parser.h index 94a36933..920b2a64 100644 --- a/xllm/core/function_call/function_call_parser.h +++ b/xllm/function_call/function_call_parser.h @@ -9,7 +9,7 @@ #include "base_format_detector.h" #include "core_types.h" -namespace llm { +namespace xllm { namespace function_call { class FunctionCallParser { @@ -60,4 +60,4 @@ std::string generate_tool_call_id(); } // namespace utils } // namespace function_call -} // namespace llm \ No newline at end of file +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/function_call/kimik2_detector.cpp b/xllm/function_call/kimik2_detector.cpp similarity index 99% rename from xllm/core/function_call/kimik2_detector.cpp rename to xllm/function_call/kimik2_detector.cpp index 72e383e6..e79e15ba 100644 --- a/xllm/core/function_call/kimik2_detector.cpp +++ b/xllm/function_call/kimik2_detector.cpp @@ -4,7 +4,7 @@ #include #include -namespace llm { +namespace xllm { namespace function_call { KimiK2Detector::KimiK2Detector() : BaseFormatDetector() { @@ -149,4 +149,4 @@ int KimiK2Detector::extract_function_index( } } // namespace function_call -} // namespace llm \ No newline at end of file +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/function_call/kimik2_detector.h b/xllm/function_call/kimik2_detector.h similarity index 96% rename from xllm/core/function_call/kimik2_detector.h rename to xllm/function_call/kimik2_detector.h 
index 577abd27..9cfa6627 100644 --- a/xllm/core/function_call/kimik2_detector.h +++ b/xllm/function_call/kimik2_detector.h @@ -5,7 +5,7 @@ #include "base_format_detector.h" -namespace llm { +namespace xllm { namespace function_call { /** @@ -51,4 +51,4 @@ class KimiK2Detector : public BaseFormatDetector { }; } // namespace function_call -} // namespace llm \ No newline at end of file +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/function_call/kimik2_detector_test.cpp b/xllm/function_call/kimik2_detector_test.cpp similarity index 99% rename from xllm/core/function_call/kimik2_detector_test.cpp rename to xllm/function_call/kimik2_detector_test.cpp index 45ee632f..024183da 100644 --- a/xllm/core/function_call/kimik2_detector_test.cpp +++ b/xllm/function_call/kimik2_detector_test.cpp @@ -6,7 +6,7 @@ #include #include -namespace llm { +namespace xllm { namespace function_call { class KimiK2DetectorTest : public ::testing::Test { @@ -452,4 +452,4 @@ TEST_F(KimiK2DetectorTest, NestedBracesInJson) { } } // namespace function_call -} // namespace llm \ No newline at end of file +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/function_call/qwen25_detector.cpp b/xllm/function_call/qwen25_detector.cpp similarity index 98% rename from xllm/core/function_call/qwen25_detector.cpp rename to xllm/function_call/qwen25_detector.cpp index 7e008c6b..224820bf 100644 --- a/xllm/core/function_call/qwen25_detector.cpp +++ b/xllm/function_call/qwen25_detector.cpp @@ -4,7 +4,7 @@ #include #include -namespace llm { +namespace xllm { namespace function_call { Qwen25Detector::Qwen25Detector() : BaseFormatDetector() { @@ -113,4 +113,4 @@ StreamingParseResult Qwen25Detector::detect_and_parse( } } // namespace function_call -} // namespace llm \ No newline at end of file +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/function_call/qwen25_detector.h b/xllm/function_call/qwen25_detector.h similarity index 95% rename from 
xllm/core/function_call/qwen25_detector.h rename to xllm/function_call/qwen25_detector.h index ac78ec6e..9f40e959 100644 --- a/xllm/core/function_call/qwen25_detector.h +++ b/xllm/function_call/qwen25_detector.h @@ -6,7 +6,7 @@ #include "base_format_detector.h" -namespace llm { +namespace xllm { namespace function_call { class Qwen25Detector : public BaseFormatDetector { @@ -34,4 +34,4 @@ class Qwen25Detector : public BaseFormatDetector { }; } // namespace function_call -} // namespace llm \ No newline at end of file +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/function_call/qwen25_detector_test.cpp b/xllm/function_call/qwen25_detector_test.cpp similarity index 99% rename from xllm/core/function_call/qwen25_detector_test.cpp rename to xllm/function_call/qwen25_detector_test.cpp index 6a6f1976..732bb69c 100644 --- a/xllm/core/function_call/qwen25_detector_test.cpp +++ b/xllm/function_call/qwen25_detector_test.cpp @@ -6,7 +6,7 @@ #include #include -namespace llm { +namespace xllm { namespace function_call { class Qwen25DetectorTest : public ::testing::Test { @@ -325,4 +325,4 @@ TEST_F(Qwen25DetectorTest, PerformanceWithManyToolCalls) { } } // namespace function_call -} // namespace llm \ No newline at end of file +} // namespace xllm \ No newline at end of file diff --git a/xllm/proto/chat.proto b/xllm/proto/chat.proto index cfc03b01..c41557fa 100644 --- a/xllm/proto/chat.proto +++ b/xllm/proto/chat.proto @@ -4,31 +4,6 @@ option go_package = "jd.com/jd-infer/xllm;xllm"; package xllm.proto; import "common.proto"; -import "google/protobuf/struct.proto"; - -message Function { - string name = 1; - string description = 2; - google.protobuf.Struct parameters = 3; - // string parameters = 3; -} - -message Tool { - string type = 1; // "function" - Function function = 2; -} - -message ToolCall { - optional uint32 index = 1; - optional string id = 2; - string type = 3; // "function" - FunctionCall function = 4; -} - -message FunctionCall { - string 
name = 1; - string arguments = 2; // JSON string -} message ChatMessage { // the role of the messages author. One of "system", "user", "assistant". diff --git a/xllm/proto/common.proto b/xllm/proto/common.proto index 6e621339..6897f6d0 100644 --- a/xllm/proto/common.proto +++ b/xllm/proto/common.proto @@ -3,6 +3,31 @@ syntax = "proto3"; option go_package = "jd.com/jd-infer/xllm;xllm"; package xllm.proto; +import "google/protobuf/struct.proto"; + +message Function { + string name = 1; + string description = 2; + google.protobuf.Struct parameters = 3; +} + +message Tool { + string type = 1; // "function" + Function function = 2; +} + +message ToolCall { + optional uint32 index = 1; + optional string id = 2; + string type = 3; // "function" + FunctionCall function = 4; +} + +message FunctionCall { + string name = 1; + string arguments = 2; // JSON string +} + message Usage { // the number of tokens in the prompt. optional int32 prompt_tokens = 1 [json_name="prompt_tokens"]; diff --git a/xllm/proto/multimodal.proto b/xllm/proto/multimodal.proto index 3e738154..41a3cfb2 100644 --- a/xllm/proto/multimodal.proto +++ b/xllm/proto/multimodal.proto @@ -31,6 +31,9 @@ message MMChatMessage { // the content of the message. null for assistant messages with function calls. repeated MMInputData content = 2; + + repeated ToolCall tool_calls = 3; + optional string tool_call_id = 4; } // Next Id: 27 @@ -120,5 +123,9 @@ message MMChatRequest { optional string service_request_id = 25; + Routing routing = 26; + + repeated Tool tools = 27; + optional string tool_choice = 28; }