diff --git a/xllm/CMakeLists.txt b/xllm/CMakeLists.txt index 0f00463f..9a4f7770 100644 --- a/xllm/CMakeLists.txt +++ b/xllm/CMakeLists.txt @@ -4,6 +4,7 @@ include_directories(.) add_subdirectory(api_service) add_subdirectory(core) +add_subdirectory(function_call) add_subdirectory(models) add_subdirectory(proto) add_subdirectory(processors) diff --git a/xllm/api_service/CMakeLists.txt b/xllm/api_service/CMakeLists.txt index f9ff45a5..8c7dd5ec 100644 --- a/xllm/api_service/CMakeLists.txt +++ b/xllm/api_service/CMakeLists.txt @@ -29,5 +29,6 @@ cc_library( proto::xllm_proto absl::flat_hash_set absl::random_random + :function_call ) diff --git a/xllm/api_service/api_service.cpp b/xllm/api_service/api_service.cpp index d632a5cf..d882505f 100644 --- a/xllm/api_service/api_service.cpp +++ b/xllm/api_service/api_service.cpp @@ -1,6 +1,7 @@ #include "api_service.h" #include +#include #include #include @@ -15,7 +16,6 @@ #include "models.pb.h" #include "service_impl_factory.h" #include "xllm_metrics.h" - namespace xllm { APIService::APIService(Master* master, @@ -71,7 +71,6 @@ void APIService::CompletionsHttp(::google::protobuf::RpcController* controller, auto ctrl = reinterpret_cast(controller); std::string attachment = std::move(ctrl->request_attachment().to_string()); - std::string error; auto st = json2pb::JsonToProtoMessage(attachment, req_pb, &error); if (!st) { @@ -106,10 +105,13 @@ void ChatCompletionsImpl(std::unique_ptr& service, std::string attachment = std::move(ctrl->request_attachment().to_string()); std::string error; - auto st = json2pb::JsonToProtoMessage(attachment, req_pb, &error); - if (!st) { - ctrl->SetFailed(error); - LOG(ERROR) << "parse json to proto failed: " << error; + google::protobuf::util::JsonParseOptions options; + options.ignore_unknown_fields = true; + auto json_status = + google::protobuf::util::JsonStringToMessage(attachment, req_pb, options); + if (!json_status.ok()) { + ctrl->SetFailed(json_status.ToString()); + LOG(ERROR) << "parse json to proto failed: " << json_status.ToString(); return; } @@ -175,7 +177,6 @@ void APIService::EmbeddingsHttp(::google::protobuf::RpcController* controller, auto ctrl = reinterpret_cast(controller); std::string attachment = std::move(ctrl->request_attachment().to_string()); - std::string error; auto st = json2pb::JsonToProtoMessage(attachment, req_pb, &error); if (!st) { @@ -289,7 +290,6 @@ void APIService::LinkCluster(::google::protobuf::RpcController* controller, auto ctrl = reinterpret_cast(controller); std::string attachment = std::move(ctrl->request_attachment().to_string()); - std::string error; auto st = json2pb::JsonToProtoMessage(attachment, req_pb, &error); if (!st) { @@ -344,7 +344,6 @@ void APIService::UnlinkCluster(::google::protobuf::RpcController* controller, auto ctrl = reinterpret_cast(controller); std::string attachment = std::move(ctrl->request_attachment().to_string()); - std::string error; auto st = json2pb::JsonToProtoMessage(attachment, req_pb, &error); if (!st) { diff --git a/xllm/api_service/chat_service_impl.cpp b/xllm/api_service/chat_service_impl.cpp index 571a27e0..d4e979d4 100644 --- a/xllm/api_service/chat_service_impl.cpp +++ b/xllm/api_service/chat_service_impl.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -12,16 +13,76 @@ #include #include "core/common/instance_name.h" +#include "core/common/types.h" #include "core/framework/request/mm_input_helper.h" #include "core/framework/request/request_params.h" #include "core/runtime/llm_master.h" #include "core/runtime/vlm_master.h" #include "core/util/utils.h" #include "core/util/uuid.h" +#include "function_call/function_call.h" namespace xllm { namespace { +struct ToolCallResult { + std::optional> tool_calls; + std::string text; + std::string finish_reason; +}; + +ToolCallResult process_tool_calls(std::string text, + const std::vector& tools, + const std::string& parser_format, + std::string finish_reason, + google::protobuf::Arena* arena = nullptr) { + ToolCallResult result; + + function_call::FunctionCallParser parser(tools, parser_format); + + if (!parser.has_tool_call(text)) { + result.text = std::move(text); + result.finish_reason = std::move(finish_reason); + return result; + } + + if (finish_reason == "stop") { + result.finish_reason = "tool_calls"; + } else { + result.finish_reason = std::move(finish_reason); + } + + try { + auto [parsed_text, call_info_list] = parser.parse_non_stream(text); + result.text = std::move(parsed_text); + + google::protobuf::RepeatedPtrField tool_calls; + + for (const auto& call_info : call_info_list) { + proto::ToolCall* tool_call = + arena ? google::protobuf::Arena::CreateMessage(arena) + : new proto::ToolCall(); + + tool_call->set_id(function_call::utils::generate_tool_call_id()); + tool_call->set_type("function"); + + auto* function = tool_call->mutable_function(); + if (call_info.name) { + function->set_name(*call_info.name); + } + function->set_arguments(call_info.parameters); + + tool_calls.AddAllocated(tool_call); + } + + result.tool_calls = std::move(tool_calls); + } catch (const std::exception& e) { + LOG(ERROR) << "Tool call parsing error: " << e.what(); + } + + return result; +} + void set_logprobs(proto::ChatChoice* choice, const std::optional>& logprobs) { if (!logprobs.has_value() || logprobs.value().empty()) { @@ -140,7 +201,9 @@ bool send_result_to_client_brpc(std::shared_ptr call, const std::string& request_id, int64_t created_time, const std::string& model, - const RequestOutput& req_output) { + const RequestOutput& req_output, + const std::string& parser_format = "", + const std::vector& tools = {}) { auto& response = call->response(); response.set_object("chat.completion"); response.set_id(request_id); @@ -154,9 +217,34 @@ bool send_result_to_client_brpc(std::shared_ptr call, set_logprobs(choice, output.logprobs); auto* message = choice->mutable_message(); message->set_role("assistant"); - message->set_content(output.text); - if (output.finish_reason.has_value()) { - choice->set_finish_reason(output.finish_reason.value()); + + auto set_output_and_finish_reason = [&]() { + message->set_content(output.text); + if (output.finish_reason.has_value()) { + choice->set_finish_reason(output.finish_reason.value()); + } + }; + + if (!tools.empty() && !parser_format.empty()) { + auto* arena = response.GetArena(); + auto result = process_tool_calls(output.text, + tools, + parser_format, + output.finish_reason.value_or(""), + arena); + + message->mutable_content()->swap(result.text); + + if (result.tool_calls) { + auto& source_tool_calls = *result.tool_calls; + message->mutable_tool_calls()->Swap(&source_tool_calls); + } + + if (!result.finish_reason.empty()) { + choice->mutable_finish_reason()->swap(result.finish_reason); + } + } else { + set_output_and_finish_reason(); } } @@ -231,7 +319,8 @@ void ChatServiceImpl::process_async_impl(std::shared_ptr call) { include_usage = include_usage, first_message_sent = std::unordered_set(), request_id = request_params.request_id, - created_time = absl::ToUnixSeconds(absl::Now())]( + created_time = absl::ToUnixSeconds(absl::Now()), + json_tools = request_params.tools]( const RequestOutput& req_output) mutable -> bool { if (req_output.status.has_value()) { const auto& status = req_output.status.value(); @@ -250,17 +339,48 @@ void ChatServiceImpl::process_async_impl(std::shared_ptr call) { master->get_rate_limiter()->decrease_one_request(); } + const std::string parser_format = + master->options().tool_call_parser().value_or(""); + const bool has_tool_support = + !json_tools.empty() && !parser_format.empty(); + if (stream) { - return send_delta_to_client_brpc(call, - include_usage, - &first_message_sent, - request_id, - created_time, - model, - req_output); + if (has_tool_support) { + // TODO: Support tool call streaming output + LOG(ERROR) << "Tool call does not support streaming output"; + return send_delta_to_client_brpc(call, + include_usage, + &first_message_sent, + request_id, + created_time, + model, + req_output); + } else { + // Stream response without tool support + return send_delta_to_client_brpc(call, + include_usage, + &first_message_sent, + request_id, + created_time, + model, + req_output); + } + } else { + if (has_tool_support) { + // Non-stream response with tool support + return send_result_to_client_brpc(call, + request_id, + created_time, + model, + req_output, + parser_format, + json_tools); + } else { + // Non-stream response without tool support + return send_result_to_client_brpc( + call, request_id, created_time, model, req_output); + } } - return send_result_to_client_brpc( - call, request_id, created_time, model, req_output); }); } @@ -337,6 +457,7 @@ void MMChatServiceImpl::process_async(std::shared_ptr call) { } if (stream) { + // send delta to client return send_delta_to_client_brpc(call, include_usage, &first_message_sent, diff --git a/xllm/core/common/global_flags.cpp b/xllm/core/common/global_flags.cpp index e858a411..78a3fee7 100644 --- a/xllm/core/common/global_flags.cpp +++ b/xllm/core/common/global_flags.cpp @@ -174,6 +174,13 @@ DEFINE_bool(enable_atb_comm_multiprocess, false, "whether to use multiprocess mode."); +// --- function call config --- + +DEFINE_string(tool_call_parser, + "", + "Specify the parser for handling tool-call interactions. " + "Options include: 'qwen25'"); + DEFINE_bool(enable_atb_spec_kernel, false, "whether to use ATB speculative kernel."); diff --git a/xllm/core/common/global_flags.h b/xllm/core/common/global_flags.h index 67d903f2..d69b41cc 100644 --- a/xllm/core/common/global_flags.h +++ b/xllm/core/common/global_flags.h @@ -98,6 +98,8 @@ DECLARE_bool(disable_custom_kernels); DECLARE_bool(enable_atb_comm_multiprocess); +DECLARE_string(tool_call_parser); + DECLARE_bool(enable_atb_spec_kernel); DECLARE_string(etcd_addr); diff --git a/xllm/core/common/options.h b/xllm/core/common/options.h index 60219d9a..4a2043f7 100644 --- a/xllm/core/common/options.h +++ b/xllm/core/common/options.h @@ -95,6 +95,8 @@ class Options { PROPERTY(std::optional, etcd_addr); PROPERTY(bool, enable_service_routing) = false; + + PROPERTY(std::optional, tool_call_parser); }; } // namespace xllm diff --git a/xllm/core/common/types.h b/xllm/core/common/types.h index 8d3e6e82..b7d9bd1a 100644 --- a/xllm/core/common/types.h +++ b/xllm/core/common/types.h @@ -214,4 +214,26 @@ struct DeviceStats { // TODO: add more device stats }; +// Function call related types +struct JsonFunction { + std::string name; + std::string description; + nlohmann::json parameters; + + JsonFunction() = default; + JsonFunction(const std::string& func_name, + const std::string& desc, + const nlohmann::json& params) + : name(func_name), description(desc), parameters(params) {} +}; + +struct JsonTool { + std::string type; // "function" + JsonFunction function; + + JsonTool() : type("function") {} + JsonTool(const std::string& tool_type, const JsonFunction& func) + : type(tool_type), function(func) {} +}; + } // namespace xllm diff --git a/xllm/core/framework/chat_template/jinja_chat_template.cpp b/xllm/core/framework/chat_template/jinja_chat_template.cpp index 119b7a2f..5e114b1c 100644 --- a/xllm/core/framework/chat_template/jinja_chat_template.cpp +++ b/xllm/core/framework/chat_template/jinja_chat_template.cpp @@ -23,6 +23,20 @@ JinjaChatTemplate::JinjaChatTemplate(const TokenizerArgs& args) : args_(args) { std::optional JinjaChatTemplate::apply( const ChatMessages& messages) const { + const std::vector empty_tools; + return apply(messages, empty_tools); +} + +std::optional JinjaChatTemplate::apply( + nlohmann::ordered_json& messages) const { + // Call the overloaded method with empty tools + nlohmann::ordered_json empty_tools = nlohmann::json::array(); + return apply(messages, empty_tools); +} + +std::optional JinjaChatTemplate::apply( + const ChatMessages& messages, + const std::vector& json_tools) const { // convert the messages to json object nlohmann::ordered_json messages_json = nlohmann::json::array(); for (const auto& message : messages) { @@ -38,14 +52,30 @@ std::optional JinjaChatTemplate::apply( messages_json.push_back(message_json); } + + nlohmann::ordered_json tools_json = nlohmann::json::array(); + for (const auto& json_tool : json_tools) { + nlohmann::ordered_json tool_json; + tool_json["type"] = json_tool.type; + + nlohmann::ordered_json function_json; + function_json["name"] = json_tool.function.name; + function_json["description"] = json_tool.function.description; + function_json["parameters"] = json_tool.function.parameters; + + tool_json["function"] = function_json; + tools_json.push_back(tool_json); + } // apply the template - return apply(messages_json); + return apply(messages_json, tools_json); } std::optional JinjaChatTemplate::apply( - nlohmann::ordered_json& messages) const { + nlohmann::ordered_json& messages, + const nlohmann::ordered_json& tools) const { minja::chat_template_inputs input; input.messages = messages; + input.tools = tools; input.add_generation_prompt = true; minja::chat_template_options options; diff --git a/xllm/core/framework/chat_template/jinja_chat_template.h b/xllm/core/framework/chat_template/jinja_chat_template.h index cd2d7e4b..fcb95b1f 100644 --- a/xllm/core/framework/chat_template/jinja_chat_template.h +++ b/xllm/core/framework/chat_template/jinja_chat_template.h @@ -7,6 +7,7 @@ #include #include +#include "core/common/types.h" #include "framework/tokenizer/tokenizer_args.h" namespace xllm { @@ -52,10 +53,17 @@ class JinjaChatTemplate { std::optional apply(const ChatMessages& messages) const; + std::optional apply( + const ChatMessages& messages, + const std::vector& json_tools) const; + // expose this function for testing // apply the template to the values in the json object std::optional apply(nlohmann::ordered_json& messages) const; + std::optional apply(nlohmann::ordered_json& messages, + const nlohmann::ordered_json& tools) const; + private: nlohmann::ordered_json get_mm_content(const Message::MMContentVec& vec) const; diff --git a/xllm/core/framework/request/request_params.cpp b/xllm/core/framework/request/request_params.cpp index 75211217..10fd3b9b 100644 --- a/xllm/core/framework/request/request_params.cpp +++ b/xllm/core/framework/request/request_params.cpp @@ -94,6 +94,77 @@ RequestParams::RequestParams(const proto::CompletionRequest& request, } namespace { + +nlohmann::json proto_value_to_json(const google::protobuf::Value& pb_value); + +nlohmann::json proto_struct_to_json(const google::protobuf::Struct& pb_struct) { + nlohmann::json result = nlohmann::json::object(); + + for (const auto& field : pb_struct.fields()) { + result[field.first] = proto_value_to_json(field.second); + } + + return result; +} + +nlohmann::json proto_value_to_json(const google::protobuf::Value& pb_value) { + switch (pb_value.kind_case()) { + case google::protobuf::Value::kNullValue: + return nlohmann::json(nullptr); + + case google::protobuf::Value::kNumberValue: + return nlohmann::json(pb_value.number_value()); + + case google::protobuf::Value::kStringValue: + return nlohmann::json(pb_value.string_value()); + + case google::protobuf::Value::kBoolValue: + return nlohmann::json(pb_value.bool_value()); + + case google::protobuf::Value::kStructValue: + return proto_struct_to_json(pb_value.struct_value()); + + case google::protobuf::Value::kListValue: { + nlohmann::json array = nlohmann::json::array(); + const auto& list = pb_value.list_value(); + for (const auto& item : list.values()) { + array.push_back(proto_value_to_json(item)); + } + return array; + } + + case google::protobuf::Value::KIND_NOT_SET: + default: + return nlohmann::json(nullptr); + } +} + +std::vector parse_tools_from_proto( + const google::protobuf::RepeatedPtrField& proto_tools) { + std::vector tools; + tools.clear(); + tools.reserve(proto_tools.size()); + + for (const auto& proto_tool : proto_tools) { + xllm::JsonTool json_tool; + json_tool.type = proto_tool.type(); + + const auto& proto_function = proto_tool.function(); + json_tool.function.name = proto_function.name(); + json_tool.function.description = proto_function.description(); + + if (proto_function.has_parameters()) { + json_tool.function.parameters = + proto_struct_to_json(proto_function.parameters()); + } else { + json_tool.function.parameters = nlohmann::json::object(); + } + + tools.emplace_back(std::move(json_tool)); + } + return tools; +} + template void InitFromChatRequest(RequestParams& params, const ChatRequest& request) { if (request.has_request_id()) { @@ -157,6 +228,17 @@ void InitFromChatRequest(RequestParams& params, const ChatRequest& request) { params.streaming = false; } } + + // Parse tools from proto request + if (request.tools_size() > 0) { + params.tools = parse_tools_from_proto(request.tools()); + + if (request.has_tool_choice()) { + params.tool_choice = request.tool_choice(); + } else { + params.tool_choice = "auto"; + } + } } } // namespace diff --git a/xllm/core/framework/request/request_params.h b/xllm/core/framework/request/request_params.h index 225ece91..f10089bb 100644 --- a/xllm/core/framework/request/request_params.h +++ b/xllm/core/framework/request/request_params.h @@ -1,12 +1,15 @@ #pragma once #include +#include #include #include #include #include "chat.pb.h" +#include "common/macros.h" #include "completion.pb.h" #include "core/common/macros.h" +#include "core/common/types.h" #include "embedding.pb.h" #include "multimodal.pb.h" #include "request_output.h" @@ -100,6 +103,11 @@ struct RequestParams { // decode address. std::string decode_address; + + // JSON-based tools (replacing proto_tools) + std::vector tools; + std::string tool_choice = "auto"; + bool has_tools() const { return !tools.empty(); } }; } // namespace xllm diff --git a/xllm/core/runtime/llm_master.cpp b/xllm/core/runtime/llm_master.cpp index ec2af81b..183ea5c7 100644 --- a/xllm/core/runtime/llm_master.cpp +++ b/xllm/core/runtime/llm_master.cpp @@ -422,7 +422,13 @@ std::shared_ptr LLMMaster::generate_request( const RequestParams& sp, OutputCallback callback) { Timer timer; - auto prompt = chat_template_->apply(messages); + std::optional prompt; + if (sp.has_tools()) { + prompt = chat_template_->apply(messages, sp.tools); + } else { + prompt = chat_template_->apply(messages); + } + if (!prompt.has_value()) { CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, "Failed to construct prompt from messages"); diff --git a/xllm/function_call/CMakeLists.txt b/xllm/function_call/CMakeLists.txt new file mode 100644 index 00000000..eb246a6a --- /dev/null +++ b/xllm/function_call/CMakeLists.txt @@ -0,0 +1,43 @@ +include(cc_library) +include(cc_test) + +cc_library ( + NAME + function_call + HDRS + core_types.h + base_format_detector.h + qwen25_detector.h + kimik2_detector.h + deepseekv3_detector.h + function_call_parser.h + function_call.h + SRCS + base_format_detector.cpp + qwen25_detector.cpp + kimik2_detector.cpp + deepseekv3_detector.cpp + function_call_parser.cpp + DEPS + nlohmann_json::nlohmann_json + glog::glog + proto::xllm_proto +) + +function(add_detector_test TEST_NAME) + cc_test( + NAME + ${TEST_NAME} + SRCS + ${TEST_NAME}.cpp + DEPS + :function_call + GTest::gtest + GTest::gtest_main + nlohmann_json::nlohmann_json + ) +endfunction() + +add_detector_test(qwen25_detector_test) +add_detector_test(kimik2_detector_test) +add_detector_test(deepseekv3_detector_test) diff --git a/xllm/function_call/base_format_detector.cpp b/xllm/function_call/base_format_detector.cpp new file mode 100644 index 00000000..7d031e11 --- /dev/null +++ b/xllm/function_call/base_format_detector.cpp @@ -0,0 +1,99 @@ +#include "base_format_detector.h" + +#include +#include +#include + +namespace xllm { +namespace function_call { + +BaseFormatDetector::BaseFormatDetector() + : current_tool_id_(-1), + current_tool_name_sent_(false), + bot_token_(""), + eot_token_(""), + tool_call_separator_(", ") {} + +std::unordered_map BaseFormatDetector::get_tool_indices( + const std::vector& tools) { + std::unordered_map indices; + for (size_t i = 0; i < tools.size(); ++i) { + if (!tools[i].function.name.empty()) { + indices[tools[i].function.name] = static_cast(i); + } else { + LOG(ERROR) << "Tool at index " << i + << " has empty function name, skipping"; + } + } + return indices; +} + +std::vector BaseFormatDetector::parse_base_json( + const nlohmann::json& json_obj, + const std::vector& tools) { + auto tool_indices = get_tool_indices(tools); + std::vector results; + + std::vector actions; + if (json_obj.is_array()) { + for (const auto& item : json_obj) { + actions.emplace_back(item); + } + } else { + actions.emplace_back(json_obj); + } + + for (const auto& act : actions) { + if (!act.is_object()) { + LOG(ERROR) << "Invalid tool call item, expected object, got: " + << act.type_name(); + continue; + } + + std::string name; + if (act.contains("name") && act["name"].is_string()) { + name = act["name"].get(); + } else { + LOG(ERROR) << "Invalid tool call: missing 'name' field or invalid type"; + continue; + } + + if (tool_indices.find(name) == tool_indices.end()) { + LOG(ERROR) << "Model attempted to call undefined function: " << name; + continue; + } + + nlohmann::json parameters = nlohmann::json::object(); + + if (act.contains("parameters")) { + parameters = act["parameters"]; + } else if (act.contains("arguments")) { + parameters = act["arguments"]; + } else { + LOG(ERROR) << "No parameters or arguments field found for tool: " << name; + } + + if (!parameters.is_object()) { + LOG(ERROR) << "Invalid arguments type for tool: " << name + << ", expected object, got: " << parameters.type_name(); + parameters = nlohmann::json::object(); + } + + std::string parameters_str; + try { + parameters_str = parameters.dump( + -1, ' ', false, nlohmann::json::error_handler_t::ignore); + } catch (const std::exception& e) { + LOG(ERROR) << "Failed to serialize arguments for tool: " << name + << ", error: " << e.what(); + parameters_str = "{}"; + } + + results.emplace_back(-1, name, parameters_str); + } + + return results; +} + +} // namespace function_call +} // namespace xllm \ No newline at end of file diff --git a/xllm/function_call/base_format_detector.h b/xllm/function_call/base_format_detector.h new file mode 100644 index 00000000..cf6c68f2 --- /dev/null +++ b/xllm/function_call/base_format_detector.h @@ -0,0 +1,75 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +#include "chat.pb.h" +#include "core_types.h" + +namespace xllm { +namespace function_call { + +class BaseFormatDetector { + public: + BaseFormatDetector(); + virtual ~BaseFormatDetector() = default; + + BaseFormatDetector(const BaseFormatDetector&) = delete; + BaseFormatDetector& operator=(const BaseFormatDetector&) = delete; + + protected: + // Streaming state management + // Buffer for accumulating incomplete patterns that arrive across multiple + // streaming chunks + std::string buffer_; + + // Stores complete tool call info (name and arguments) for each tool being + // parsed. Used by serving layer for completion handling when streaming ends. + // Format: [{"name": str, "arguments": dict}, ...] + std::vector> prev_tool_call_arr_; + + // Index of currently streaming tool call. Starts at -1 (no active tool), + // increments as each tool completes. Tracks which tool's arguments are + // streaming. + int current_tool_id_; + + // Flag for whether current tool's name has been sent to client. + // Tool names sent first with empty parameters, then arguments stream + // incrementally. + bool current_tool_name_sent_; + + // Tracks raw JSON string content streamed to client for each tool's + // arguments. Critical for serving layer to calculate remaining content when + // streaming ends. Each index corresponds to a tool_id. Example: + // ['{"location": "San Francisco"', '{"temp": 72'] + std::vector streamed_args_for_tool_; + + // Token configuration (override in subclasses) + std::string bot_token_; + std::string eot_token_; + std::string tool_call_separator_; + + // Tool indices cache + std::unordered_map tool_indices_; + + public: + std::unordered_map get_tool_indices( + const std::vector& tools); + + std::vector parse_base_json(const nlohmann::json& json_obj, + const std::vector& tools); + + virtual StreamingParseResult detect_and_parse( + const std::string& text, + const std::vector& tools) = 0; + + virtual bool has_tool_call(const std::string& text) = 0; +}; + +} // namespace function_call +} // namespace xllm \ No newline at end of file diff --git a/xllm/function_call/core_types.h b/xllm/function_call/core_types.h new file mode 100644 index 00000000..2f7b17cc --- /dev/null +++ b/xllm/function_call/core_types.h @@ -0,0 +1,69 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "core/common/types.h" + +namespace xllm { +namespace function_call { + +using JsonFunction = xllm::JsonFunction; +using JsonTool = xllm::JsonTool; + +struct ToolCallItem { + int tool_index; + std::optional name; + std::string parameters; // JSON string + + ToolCallItem() : tool_index(-1), parameters("") {} + + ToolCallItem(int index, + const std::optional& func_name, + const std::string& params) + : tool_index(index), name(func_name), parameters(params) {} +}; + +struct StreamingParseResult { + std::string normal_text; + std::vector calls; + + StreamingParseResult() = default; + + StreamingParseResult(const std::string& text) : normal_text(text) {} + + StreamingParseResult(const std::vector& tool_calls) + : calls(tool_calls) {} + + StreamingParseResult(const std::string& text, + const std::vector& tool_calls) + : normal_text(text), calls(tool_calls) {} + + bool has_calls() const { return !calls.empty(); } + + void clear() { + normal_text.clear(); + calls.clear(); + } +}; + +struct StructureInfo { + std::string begin; + std::string end; + std::string trigger; + + StructureInfo() = default; + + StructureInfo(const std::string& begin_str, + const std::string& end_str, + const std::string& trigger_str) + : begin(begin_str), end(end_str), trigger(trigger_str) {} +}; + +using GetInfoFunc = std::function; + +} // namespace function_call +} // namespace xllm \ No newline at end of file diff --git a/xllm/function_call/deepseekv3_detector.cpp b/xllm/function_call/deepseekv3_detector.cpp new file mode 100644 index 00000000..3c9ea919 --- /dev/null +++ b/xllm/function_call/deepseekv3_detector.cpp @@ -0,0 +1,175 @@ +#include "deepseekv3_detector.h" + +#include +#include +#include +#include + +namespace xllm { +namespace function_call { + +DeepSeekV3Detector::DeepSeekV3Detector() : BaseFormatDetector() { + bot_token_ = "<|tool▁calls▁begin|>"; + eot_token_ = "<|tool▁calls▁end|>"; + tool_call_separator_ = ""; +} + +bool DeepSeekV3Detector::has_tool_call(const std::string& text) { + return text.find(bot_token_) != std::string::npos; +} + +std::string_view DeepSeekV3Detector::trim_whitespace( + std::string_view str) const { + const char* whitespace = " \t\n\r"; + + size_t start = str.find_first_not_of(whitespace); + if (start == std::string_view::npos) { + return std::string_view{}; + } + + size_t end = str.find_last_not_of(whitespace); + + return str.substr(start, end - start + 1); +} + +std::vector> +DeepSeekV3Detector::find_tool_call_ranges(const std::string& text) const { + std::vector> ranges; + ranges.reserve(4); + + const std::string call_begin = "<|tool▁call▁begin|>"; + const std::string call_end = "<|tool▁call▁end|>"; + + size_t search_pos = 0; + const size_t call_begin_len = call_begin.length(); + const size_t call_end_len = call_end.length(); + + while (search_pos < text.length()) { + size_t start_pos = text.find(call_begin, search_pos); + if (start_pos == std::string::npos) { + break; + } + + size_t content_start = start_pos + call_begin_len; + size_t end_pos = text.find(call_end, content_start); + if (end_pos == std::string::npos) { + break; + } + + ranges.emplace_back(content_start, end_pos); + search_pos = end_pos + call_end_len; + } + + return ranges; +} + +StreamingParseResult DeepSeekV3Detector::detect_and_parse( + const std::string& text, + const std::vector& tools) { + size_t bot_token_pos = text.find(bot_token_); + + std::string normal_text; + if (bot_token_pos != std::string::npos) { + std::string_view normal_text_view(text.data(), bot_token_pos); + std::string_view trimmed = trim_whitespace(normal_text_view); + normal_text = std::string(trimmed); + } else { + std::string_view trimmed = trim_whitespace(text); + normal_text = std::string(trimmed); + return StreamingParseResult(normal_text); + } + + auto tool_call_ranges = find_tool_call_ranges(text); + + std::vector calls; + calls.reserve(tool_call_ranges.size()); + + for (const auto& range : tool_call_ranges) { + std::string_view content_view(text.data() + range.first, + range.second - range.first); + std::string_view trimmed_content = trim_whitespace(content_view); + + if (trimmed_content.empty()) { + continue; + } + + try { + // Parse DeepSeek V3 format: function_name\n```json\n{args}\n``` + const std::string tool_sep = "<|tool▁sep|>"; + size_t sep_pos = trimmed_content.find(tool_sep); + if (sep_pos == std::string_view::npos) { + LOG(ERROR) << "Failed to find tool separator in: " + << std::string(trimmed_content); + continue; + } + + // Extract function name (between tool_sep and first newline) + size_t name_start = sep_pos + tool_sep.length(); + size_t name_end = trimmed_content.find('\n', name_start); + if (name_end == std::string_view::npos) { + LOG(ERROR) << "Failed to find function name end in: " + << std::string(trimmed_content); + continue; + } + + std::string_view func_name_view = + trimmed_content.substr(name_start, name_end - name_start); + std::string_view func_name_trimmed = trim_whitespace(func_name_view); + std::string func_name(func_name_trimmed); + + // Find JSON block (between ```json\n and \n```) + const std::string json_start = "```json\n"; + const std::string json_end = "\n```"; + + size_t json_start_pos = trimmed_content.find(json_start, name_end); + if (json_start_pos == std::string_view::npos) { + LOG(ERROR) << "Failed to find JSON start in: " + << std::string(trimmed_content); + continue; + } + + size_t json_content_start = json_start_pos + json_start.length(); + size_t json_end_pos = trimmed_content.find(json_end, json_content_start); + if (json_end_pos == std::string_view::npos) { + LOG(ERROR) << "Failed to find JSON end in: " + << std::string(trimmed_content); + continue; + } + + std::string_view json_view = trimmed_content.substr( + json_content_start, json_end_pos - json_content_start); + std::string_view json_trimmed = trim_whitespace(json_view); + + // Parse JSON arguments + nlohmann::json func_args; + try { + std::string json_content(json_trimmed); + func_args = nlohmann::json::parse(json_content); + } catch (const nlohmann::json::parse_error& e) { + LOG(ERROR) << "Failed to parse JSON arguments: " + << std::string(json_trimmed) + << ", JSON parse error: " << e.what(); + continue; + } + + // Create JSON object for parse_base_json + nlohmann::json match_json; + match_json["name"] = func_name; + match_json["parameters"] = func_args; + + auto parsed_calls = parse_base_json(match_json, tools); + calls.insert(calls.end(), + std::make_move_iterator(parsed_calls.begin()), + std::make_move_iterator(parsed_calls.end())); + } catch (const std::exception& e) { + LOG(ERROR) << "Failed to parse tool call: " + << std::string(trimmed_content) << ", error: " << e.what(); + continue; + } + } + + return StreamingParseResult(std::move(normal_text), std::move(calls)); +} + +} // namespace function_call +} // namespace xllm \ No newline at end of file diff --git a/xllm/function_call/deepseekv3_detector.h b/xllm/function_call/deepseekv3_detector.h new file mode 100644 index 00000000..41e82c61 --- /dev/null +++ b/xllm/function_call/deepseekv3_detector.h @@ -0,0 +1,31 @@ +#pragma once + +#include + +#include "base_format_detector.h" + +namespace xllm { +namespace function_call { + +class DeepSeekV3Detector : public BaseFormatDetector { + public: + DeepSeekV3Detector(); + + virtual ~DeepSeekV3Detector() = default; + + bool has_tool_call(const std::string& text) override; + + StreamingParseResult detect_and_parse( + const std::string& text, + const std::vector& tools) override; + + private: + std::string func_call_regex_; + std::string func_detail_regex_; + std::string_view trim_whitespace(std::string_view str) const; + std::vector> find_tool_call_ranges( + const std::string& text) const; +}; + +} // namespace function_call +} // namespace xllm \ No newline at end of file diff --git a/xllm/function_call/deepseekv3_detector_test.cpp b/xllm/function_call/deepseekv3_detector_test.cpp new file mode 100644 index 00000000..df014b65 --- /dev/null +++ b/xllm/function_call/deepseekv3_detector_test.cpp @@ -0,0 +1,418 @@ +#include "deepseekv3_detector.h" + +#include + +#include +#include +#include + +namespace xllm { +namespace function_call { + +class DeepSeekV3DetectorTest : public ::testing::Test { + protected: + void SetUp() override { + detector_ = std::make_unique(); + + // Setup test tools + nlohmann::json weather_params = { + {"type", "object"}, + {"properties", + {{"location", + {{"type", "string"}, + {"description", "The city and state, e.g. San Francisco, CA"}}}, + {"unit", {{"type", "string"}, {"enum", {"celsius", "fahrenheit"}}}}}}, + {"required", {"location"}}}; + + JsonFunction weather_func("get_current_weather", + "Get the current weather in a given location", + weather_params); + weather_tool_ = JsonTool("function", weather_func); + + nlohmann::json calculator_params = { + {"type", "object"}, + {"properties", + {{"expression", + {{"type", "string"}, + {"description", "Mathematical expression to evaluate"}}}}}, + {"required", {"expression"}}}; + + JsonFunction calculator_func( + "calculate", "Calculate mathematical expressions", calculator_params); + calculator_tool_ = JsonTool("function", calculator_func); + + tools_ = {weather_tool_, calculator_tool_}; + } + + std::unique_ptr detector_; + JsonTool weather_tool_; + JsonTool calculator_tool_; + std::vector tools_; +}; + +// Test constructor and basic properties +TEST_F(DeepSeekV3DetectorTest, ConstructorInitializesCorrectly) { + EXPECT_NE(detector_, nullptr); + + // Test basic token detection + std::string text_with_tool_call = + "Some text " + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>test\n`" + "``json\n{\"name\": " + "\"test\"}\n```<|tool▁call▁end|><|tool▁calls▁end|>"; + std::string text_without_tool_call = + "Just normal text without any tool calls"; + + EXPECT_TRUE(detector_->has_tool_call(text_with_tool_call)); + EXPECT_FALSE(detector_->has_tool_call(text_without_tool_call)); +} + +// Test has_tool_call method +TEST_F(DeepSeekV3DetectorTest, HasToolCallDetection) { + // Test text containing tool calls + EXPECT_TRUE(detector_->has_tool_call("<|tool▁calls▁begin|>")); + EXPECT_TRUE(detector_->has_tool_call( + "Previous text <|tool▁calls▁begin|>Following content")); + EXPECT_TRUE(detector_->has_tool_call( + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>test\n`" + "``json\n{\"name\": " + "\"test\"}\n```<|tool▁call▁end|><|tool▁calls▁end|>")); + + // Test text not containing tool calls + EXPECT_FALSE(detector_->has_tool_call("")); + EXPECT_FALSE(detector_->has_tool_call("Regular text")); + EXPECT_FALSE(detector_->has_tool_call("tool_calls without special tokens")); + EXPECT_FALSE(detector_->has_tool_call(" without unicode tokens")); +} + +// Test single tool call parsing +TEST_F(DeepSeekV3DetectorTest, SingleToolCallParsing) { + std::string text = + "Please help me check the weather " + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_" + "current_weather\n```json\n{\"location\": \"Beijing\", \"unit\": " + "\"celsius\"}\n```<|tool▁call▁end|><|tool▁calls▁end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Please help me check the weather"); + EXPECT_EQ(result.calls.size(), 1); + + const auto& call = result.calls[0]; + EXPECT_EQ(call.tool_index, -1); // Base class always returns -1 + EXPECT_TRUE(call.name.has_value()); + EXPECT_EQ(call.name.value(), "get_current_weather"); + + // Verify parameter JSON + nlohmann::json params = nlohmann::json::parse(call.parameters); + EXPECT_EQ(params["location"], "Beijing"); + EXPECT_EQ(params["unit"], "celsius"); +} + +// Test multiple tool calls parsing +TEST_F(DeepSeekV3DetectorTest, MultipleToolCallsParsing) { + std::string text = + "Please help me check the weather and calculate an expression " + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_" + "current_weather\n```json\n{\"location\": " + "\"Shanghai\"}\n```<|tool▁call▁end|>\n<|tool▁call▁begin|>function<" + "|tool▁sep|>calculate\n```json\n{\"expression\": \"2 + 3 * " + "4\"}\n```<|tool▁call▁end|><|tool▁calls▁end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, + "Please help me check the weather and calculate an expression"); + EXPECT_EQ(result.calls.size(), 2); + + // Verify first tool call + const auto& call1 = result.calls[0]; + EXPECT_EQ(call1.tool_index, -1); // Base class always returns -1 + EXPECT_TRUE(call1.name.has_value()); + EXPECT_EQ(call1.name.value(), "get_current_weather"); + + nlohmann::json params1 = nlohmann::json::parse(call1.parameters); + EXPECT_EQ(params1["location"], "Shanghai"); + + // Verify second tool call + const auto& call2 = result.calls[1]; + EXPECT_EQ(call2.tool_index, -1); // Base class always returns -1 + EXPECT_TRUE(call2.name.has_value()); + EXPECT_EQ(call2.name.value(), "calculate"); + + nlohmann::json params2 = nlohmann::json::parse(call2.parameters); + EXPECT_EQ(params2["expression"], "2 + 3 * 4"); +} + +// Test DeepSeekV3 specific format with exact tokens +TEST_F(DeepSeekV3DetectorTest, DeepSeekV3SpecificFormat) { + std::string text = + "I need weather info " + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_" + "current_weather\n```json\n{\"location\": " + "\"Tokyo\"}\n```<|tool▁call▁end|><|tool▁calls▁end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "I need weather info"); + EXPECT_EQ(result.calls.size(), 1); + + const auto& call = result.calls[0]; + EXPECT_TRUE(call.name.has_value()); + EXPECT_EQ(call.name.value(), "get_current_weather"); + + nlohmann::json params = nlohmann::json::parse(call.parameters); + EXPECT_EQ(params["location"], "Tokyo"); +} + +// Test invalid JSON handling +TEST_F(DeepSeekV3DetectorTest, InvalidJsonHandling) { + std::string text = + "Test invalid JSON " + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_" + "current_weather\n```json\n{\"location\": \"Beijing\", " + "invalid_json}\n```<|tool▁call▁end|><|tool▁calls▁end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Test invalid JSON"); + EXPECT_EQ(result.calls.size(), 0); // Invalid JSON should be ignored +} + +// Test empty tool call content +TEST_F(DeepSeekV3DetectorTest, EmptyToolCallContent) { + std::string text = + "Test empty content " + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>\n```" + "json\n \t\n \n```<|tool▁call▁end|><|tool▁calls▁end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Test empty content"); + EXPECT_EQ(result.calls.size(), 0); // Empty content should be ignored +} + +// Test incomplete tool call (only start tag) +TEST_F(DeepSeekV3DetectorTest, IncompleteToolCall) { + std::string text = + "Incomplete tool call " + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_" + "current_weather\n```json\n{\"location\": \"Beijing\"}"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Incomplete tool call"); + EXPECT_EQ(result.calls.size(), 0); // Incomplete calls should be ignored +} + +// Test unknown tool name handling +TEST_F(DeepSeekV3DetectorTest, UnknownToolName) { + std::string text = + "Unknown tool " + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>" + "unknown_tool\n```json\n{\"param\": " + "\"value\"}\n```<|tool▁call▁end|><|tool▁calls▁end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Unknown tool"); + // Base class will skip unknown tools, so should be 0 calls + EXPECT_EQ(result.calls.size(), 0); +} + +// Test case with only normal text +TEST_F(DeepSeekV3DetectorTest, OnlyNormalText) { + std::string text = "This is a regular text without any tool calls."; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, + "This is a regular text without any tool calls."); + EXPECT_EQ(result.calls.size(), 0); + EXPECT_FALSE(result.has_calls()); +} + +// Test empty string input +TEST_F(DeepSeekV3DetectorTest, EmptyStringInput) { + std::string text = ""; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, ""); + EXPECT_EQ(result.calls.size(), 0); + EXPECT_FALSE(result.has_calls()); +} + +// Test whitespace-only input +TEST_F(DeepSeekV3DetectorTest, WhitespaceOnlyInput) { + std::string text = " \t\n\r "; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, ""); + EXPECT_EQ(result.calls.size(), 0); +} + +// Test complex nested JSON parameters +TEST_F(DeepSeekV3DetectorTest, ComplexNestedJsonParameters) { + std::string text = + "Complex parameter test " + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_" + "current_weather\n```json\n{\"location\": \"Beijing\", \"options\": " + "{\"include_forecast\": true, \"days\": 7, \"details\": " + "[\"temperature\", \"humidity\", " + "\"wind\"]}}\n```<|tool▁call▁end|><|tool▁calls▁end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Complex parameter test"); + EXPECT_EQ(result.calls.size(), 1); + + const auto& call = result.calls[0]; + EXPECT_EQ(call.tool_index, -1); // Base class always returns -1 + + nlohmann::json params = nlohmann::json::parse(call.parameters); + EXPECT_EQ(params["location"], "Beijing"); + EXPECT_TRUE(params["options"]["include_forecast"]); + EXPECT_EQ(params["options"]["days"], 7); + EXPECT_EQ(params["options"]["details"].size(), 3); +} + +// Test tool call in the middle of text +TEST_F(DeepSeekV3DetectorTest, ToolCallInMiddleOfText) { + std::string text = + "Previous text " + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>" + "calculate\n```json\n{\"expression\": " + "\"1+1\"}\n```<|tool▁call▁end|><|tool▁calls▁end|> Following text"; + + auto result = detector_->detect_and_parse(text, tools_); + + // Note: According to implementation, only text before tool call is preserved + // as normal_text + EXPECT_EQ(result.normal_text, "Previous text"); + EXPECT_EQ(result.calls.size(), 1); + + const auto& call = result.calls[0]; + EXPECT_EQ(call.tool_index, -1); // Base class always returns -1 + EXPECT_EQ(call.name.value(), "calculate"); +} + +// Test special characters handling +TEST_F(DeepSeekV3DetectorTest, SpecialCharactersHandling) { + std::string text = + "Special characters test " + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_" + "current_weather\n```json\n{\"location\": \"New York City\", \"note\": " + "\"Contains " + "symbols!@#$%^&*()_+=\"}\n```<|tool▁call▁end|><|tool▁calls▁end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Special characters test"); + EXPECT_EQ(result.calls.size(), 1); + + const auto& call = result.calls[0]; + EXPECT_EQ(call.tool_index, -1); // Base class always returns -1 + nlohmann::json params = nlohmann::json::parse(call.parameters); + EXPECT_EQ(params["location"], "New York City"); + EXPECT_EQ(params["note"], "Contains symbols!@#$%^&*()_+="); +} + +// Test whitespace trimming +TEST_F(DeepSeekV3DetectorTest, WhitespaceTrimming) { + std::string text_with_whitespace = + " \t\nPrevious text\r\n " + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_" + "current_weather\n```json\n{\"location\": " + "\"Beijing\"}\n```<|tool▁call▁end|><|tool▁calls▁end|> \t\r\n"; + + auto result = detector_->detect_and_parse(text_with_whitespace, tools_); + + // Verify normal text is correctly trimmed + EXPECT_EQ(result.normal_text, "Previous text"); + + // Verify tool call is correctly parsed + EXPECT_EQ(result.calls.size(), 1); + EXPECT_EQ(result.calls[0].tool_index, -1); // Base class always returns -1 +} + +// Test regex pattern matching edge cases +TEST_F(DeepSeekV3DetectorTest, RegexPatternEdgeCases) { + // Test with newlines in function name (should fail) + std::string text1 = + "Test " + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_" + "current\nweather\n```json\n{\"location\": " + "\"Beijing\"}\n```<|tool▁call▁end|><|tool▁calls▁end|>"; + auto result1 = detector_->detect_and_parse(text1, tools_); + EXPECT_EQ(result1.calls.size(), 0); // Should fail to match + + // Test with missing json markers + std::string text2 = + "Test " + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_" + "current_weather\n{\"location\": " + "\"Beijing\"}\n<|tool▁call▁end|><|tool▁calls▁end|>"; + auto result2 = detector_->detect_and_parse(text2, tools_); + EXPECT_EQ(result2.calls.size(), + 0); // Should fail to match without ```json``` markers +} + +// Performance test: multiple tool calls +TEST_F(DeepSeekV3DetectorTest, PerformanceWithMultipleToolCalls) { + std::string text = "Performance test"; + + // Build text containing multiple tool calls + for (int i = 0; i < 10000; ++i) { + text += + " <|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>" + "calculate\n```json\n{\"expression\": \"" + + std::to_string(i) + " + " + std::to_string(i + 1) + + "\"}\n```<|tool▁call▁end|><|tool▁calls▁end|>"; + } + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Performance test"); + EXPECT_EQ(result.calls.size(), 10000); + + // Verify each tool call is correctly parsed + for (int i = 0; i < 10000; ++i) { + const auto& call = result.calls[i]; + EXPECT_EQ(call.tool_index, -1); // Base class always returns -1 + EXPECT_EQ(call.name.value(), "calculate"); + + nlohmann::json params = nlohmann::json::parse(call.parameters); + std::string expected_expr = + std::to_string(i) + " + " + std::to_string(i + 1); + EXPECT_EQ(params["expression"], expected_expr); + } +} + +// Test error handling with malformed tokens +TEST_F(DeepSeekV3DetectorTest, MalformedTokensHandling) { + // Test with incomplete start token + std::string text1 = + "Test " + "<|tool▁calls▁begi><|tool▁call▁begin|>function<|tool▁sep|>test\n```" + "json\n{}\n```<|tool▁call▁end|><|tool▁calls▁end|>"; + auto result1 = detector_->detect_and_parse(text1, tools_); + EXPECT_EQ(result1.normal_text, + "Test " + "<|tool▁calls▁begi><|tool▁call▁begin|>function<|tool▁sep|>" + "test\n```json\n{}\n```<|tool▁call▁end|><|tool▁calls▁end|>"); + EXPECT_EQ(result1.calls.size(), 0); + + // Test with incomplete end token + std::string text2 = + "Test " + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>test\n`" + "``json\n{}\n```<|tool▁call▁end|><|tool▁calls▁en|>"; + auto result2 = detector_->detect_and_parse(text2, tools_); + EXPECT_EQ(result2.normal_text, "Test"); + EXPECT_EQ(result2.calls.size(), 0); // Should not match incomplete pattern +} + +} // namespace function_call +} // namespace xllm \ No newline at end of file diff --git a/xllm/function_call/function_call.h b/xllm/function_call/function_call.h new file mode 100644 index 00000000..015f614a --- /dev/null +++ b/xllm/function_call/function_call.h @@ -0,0 +1,25 @@ +#pragma once + +#include "base_format_detector.h" +#include "core_types.h" +#include "deepseekv3_detector.h" +#include "function_call_parser.h" +#include "kimik2_detector.h" +#include "qwen25_detector.h" + +namespace xllm { +namespace function_call { + +inline std::vector parse(const std::string& text, + const std::vector& tools, + const std::string& format = "qwen25") { + return utils::parse_function_calls(text, tools, format); +} + +inline bool has_calls(const std::string& text, + const std::string& format = "qwen25") { + return utils::has_function_calls(text, format); +} + +} // namespace function_call +} // namespace xllm \ No newline at end of file diff --git a/xllm/function_call/function_call_parser.cpp b/xllm/function_call/function_call_parser.cpp new file mode 100644 index 00000000..b4643973 --- /dev/null +++ b/xllm/function_call/function_call_parser.cpp @@ -0,0 +1,123 @@ +#include "function_call_parser.h" + +#include +#include + +#include "core/util/uuid.h" +#include "deepseekv3_detector.h" +#include "kimik2_detector.h" +#include "qwen25_detector.h" +namespace xllm { +namespace function_call { + +const std::unordered_map + FunctionCallParser::ToolCallParserEnum = { + {"qwen25", "qwen25"}, + {"qwen3", "qwen25"}, + {"kimi_k2", "kimi_k2"}, + {"deepseekv3", "deepseekv3"}, + // TODO + // {"llama3", "llama3"}, + // {"mistral", "mistral"}, + // {"pythonic", "pythonic"}, + // {"qwen3_coder", "qwen3_coder"}, + // {"glm45", "glm45"}, + // {"step3", "step3"}, +}; + +FunctionCallParser::FunctionCallParser(const std::vector& tools, + const std::string& tool_call_parser) + : tools_(tools) { + detector_ = create_detector(tool_call_parser); + CHECK(detector_ != nullptr) + << "Unsupported tool_call_parser: " << tool_call_parser + << ". Supported parsers are: " << [this]() { + std::string supported; + for (const auto& [key, value] : ToolCallParserEnum) { + if (!supported.empty()) supported += ", "; + supported += key; + } + return supported; + }(); +} + +bool FunctionCallParser::has_tool_call(const std::string& text) const { + return detector_->has_tool_call(text); +} + +std::tuple> +FunctionCallParser::parse_non_stream(const std::string& full_text) { + StreamingParseResult parsed_result = + detector_->detect_and_parse(full_text, tools_); + + if (!parsed_result.calls.empty()) { + return std::make_tuple(parsed_result.normal_text, parsed_result.calls); + } else { + return std::make_tuple(full_text, std::vector()); + } +} + +std::unique_ptr FunctionCallParser::create_detector( + const std::string& tool_call_parser) { + auto it = ToolCallParserEnum.find(tool_call_parser); + if (it == ToolCallParserEnum.end()) { + return nullptr; + } + + if (it->second == "qwen25") { + return std::make_unique(); + } + + if (it->second == "kimi_k2") { + return std::make_unique(); + } + + if (it->second == "deepseekv3") { + return std::make_unique(); + } + + // if (tool_call_parser == "llama3") { + // return std::make_unique(); + // } + // if (tool_call_parser == "mistral") { + // return std::make_unique(); + // } + + return nullptr; +} + +namespace utils { + +std::vector parse_function_calls( + const std::string& text, + const std::vector& tools, + const std::string& parser_type) { + try { + FunctionCallParser parser(tools, parser_type); + auto [normal_text, calls] = parser.parse_non_stream(text); + return calls; + } catch (const std::exception& e) { + LOG(ERROR) << "Error parsing function calls: " << e.what(); + return {}; + } +} + +bool has_function_calls(const std::string& text, + const std::string& parser_type) { + try { + FunctionCallParser parser({}, parser_type); + return parser.has_tool_call(text); + } catch (const std::exception& e) { + LOG(ERROR) << "Error checking function calls: " << e.what(); + return false; + } +} + +thread_local ShortUUID short_uuid; + +std::string generate_tool_call_id() { return "call_" + short_uuid.random(); } + +} // namespace utils + +} // namespace function_call +} // namespace xllm \ No newline at end of file diff --git a/xllm/function_call/function_call_parser.h b/xllm/function_call/function_call_parser.h new file mode 100644 index 00000000..920b2a64 --- /dev/null +++ b/xllm/function_call/function_call_parser.h @@ -0,0 +1,63 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "base_format_detector.h" +#include "core_types.h" + +namespace xllm { +namespace function_call { + +class FunctionCallParser { + public: + static const std::unordered_map ToolCallParserEnum; + + private: + std::unique_ptr detector_; + std::vector tools_; + + public: + FunctionCallParser(const std::vector& tools, + const std::string& tool_call_parser); + + ~FunctionCallParser() = default; + + FunctionCallParser(const FunctionCallParser&) = delete; + FunctionCallParser& operator=(const FunctionCallParser&) = delete; + + bool has_tool_call(const std::string& text) const; + + std::tuple> parse_non_stream( + const std::string& full_text); + + // StructuralTagResponseFormat get_structure_tag(); + + // std::tuple get_structure_constraint(const + // std::string& tool_choice); + + BaseFormatDetector* get_detector() const { return detector_.get(); } + + private: + std::unique_ptr create_detector( + const std::string& tool_call_parser); +}; + +namespace utils { + +std::vector parse_function_calls( + const std::string& text, + const std::vector& tools, + const std::string& parser_type = "qwen25"); + +bool has_function_calls(const std::string& text, + const std::string& parser_type = "qwen25"); + +std::string generate_tool_call_id(); +} // namespace utils + +} // namespace function_call +} // namespace xllm \ No newline at end of file diff --git a/xllm/function_call/kimik2_detector.cpp b/xllm/function_call/kimik2_detector.cpp new file mode 100644 index 00000000..e79e15ba --- /dev/null +++ b/xllm/function_call/kimik2_detector.cpp @@ -0,0 +1,152 @@ +#include "kimik2_detector.h" + +#include +#include +#include + +namespace xllm { +namespace function_call { + +KimiK2Detector::KimiK2Detector() : BaseFormatDetector() { + // Initialize KimiK2 specific tokens + bot_token_ = "<|tool_calls_section_begin|>"; + eot_token_ = "<|tool_calls_section_end|>"; + tool_call_start_token_ = "<|tool_call_begin|>"; + tool_call_end_token_ = "<|tool_call_end|>"; + tool_call_argument_begin_token_ = "<|tool_call_argument_begin|>"; + + // Regex pattern for parsing tool calls with the following format: + // <|tool_call_begin|>functions.{func_name}:{index} + // <|tool_call_argument_begin|>{json_args}<|tool_call_end|> + // Note: C++ regex doesn't support named groups, so we use numbered groups: + // Group 1: tool_call_id (functions.{func_name}:{index}) + // Group 2: function_arguments ({json_args}) + std::string pattern = + R"(<\|tool_call_begin\|>\s*([\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(\{.*?\})\s*<\|tool_call_end\|>)"; + + try { + tool_call_regex_ = std::regex(pattern, std::regex_constants::ECMAScript); + } catch (const std::regex_error& e) { + LOG(ERROR) << "Failed to compile KimiK2 regex pattern: " << e.what(); + throw; + } + + last_arguments_ = ""; +} + +bool KimiK2Detector::has_tool_call(const std::string& text) { + // Check if the text contains the KimiK2 tool call section begin token + return text.find(bot_token_) != std::string::npos; +} + +StreamingParseResult KimiK2Detector::detect_and_parse( + const std::string& text, + const std::vector& tools) { + size_t bot_pos = text.find(bot_token_); + + std::string normal_text = + (bot_pos != std::string::npos) ? text.substr(0, bot_pos) : text; + + if (bot_pos == std::string::npos) { + return StreamingParseResult(normal_text); + } + + std::vector calls; + + try { + std::sregex_iterator iter( + text.begin() + bot_pos, text.end(), tool_call_regex_); + std::sregex_iterator end; + + for (; iter != end; ++iter) { + std::smatch match = *iter; + + if (match.size() >= 3) { + std::string tool_call_id = match[1].str(); + std::string function_arguments = match[2].str(); + + std::string function_name = extract_function_name(tool_call_id); + int function_index = extract_function_index(tool_call_id); + + calls.emplace_back(function_index, // Use the call index in the + // response, not tool position + function_name, // Function name + function_arguments // JSON parameters + ); + } + } + + return StreamingParseResult(normal_text, calls); + + } catch (const std::exception& e) { + LOG(ERROR) << "Error in KimiK2 detect_and_parse: " << e.what(); + // Return the normal text if parsing fails + return StreamingParseResult(normal_text); + } +} + +std::string KimiK2Detector::extract_function_name( + const std::string& tool_call_id) const { + // tool_call_id format: functions.{func_name}:{index} + // Example: functions.get_weather:0 + + try { + // Find the position of "functions." + size_t functions_pos = tool_call_id.find("functions."); + if (functions_pos == std::string::npos) { + LOG(WARNING) + << "Invalid tool_call_id format, missing 'functions.' prefix: " + << tool_call_id; + return ""; + } + + // Skip "functions." (10 characters) + size_t start_pos = functions_pos + 10; + + // Find the position of the last colon + size_t colon_pos = tool_call_id.find_last_of(':'); + if (colon_pos == std::string::npos || colon_pos <= start_pos) { + LOG(WARNING) << "Invalid tool_call_id format, missing ':' separator: " + << tool_call_id; + return ""; + } + + // Extract function name between "functions." and ":" + return tool_call_id.substr(start_pos, colon_pos - start_pos); + + } catch (const std::exception& e) { + LOG(ERROR) << "Error extracting function name from tool_call_id: " + << tool_call_id << ", error: " << e.what(); + return ""; + } +} + +int KimiK2Detector::extract_function_index( + const std::string& tool_call_id) const { + // tool_call_id format: functions.{func_name}:{index} + // Example: functions.get_weather:0 + + try { + // Find the position of the last colon + size_t colon_pos = tool_call_id.find_last_of(':'); + if (colon_pos == std::string::npos) { + LOG(WARNING) << "Invalid tool_call_id format, missing ':' separator: " + << tool_call_id; + return 0; + } + + // Extract index string after the colon + std::string index_str = tool_call_id.substr(colon_pos + 1); + + // Convert to integer + return std::stoi(index_str); + + } catch (const std::exception& e) { + LOG(ERROR) << "Error extracting function index from tool_call_id: " + << tool_call_id << ", error: " << e.what(); + return 0; + } +} + +} // namespace function_call +} // namespace xllm \ No newline at end of file diff --git a/xllm/function_call/kimik2_detector.h b/xllm/function_call/kimik2_detector.h new file mode 100644 index 00000000..9cfa6627 --- /dev/null +++ b/xllm/function_call/kimik2_detector.h @@ -0,0 +1,54 @@ +#pragma once + +#include +#include + +#include "base_format_detector.h" + +namespace xllm { +namespace function_call { + +/** + * Detector for Kimi K2 model function call format. + * + * Format Structure: + * ``` + * <|tool_calls_section_begin|> + * <|tool_call_begin|>functions.{func_name}:{index} + * <|tool_call_argument_begin|>{json_args}<|tool_call_end|> + * <|tool_calls_section_end|> + * ``` + * + * Reference: + * https://huggingface.co/moonshotai/Kimi-K2-Instruct/blob/main/docs/tool_call_guidance.md + */ +class KimiK2Detector : public BaseFormatDetector { + public: + KimiK2Detector(); + + virtual ~KimiK2Detector() = default; + + private: + std::string tool_call_start_token_; + std::string tool_call_end_token_; + std::string tool_call_argument_begin_token_; + + std::regex tool_call_regex_; + + std::string last_arguments_; + + public: + bool has_tool_call(const std::string& text) override; + + StreamingParseResult detect_and_parse( + const std::string& text, + const std::vector& tools) override; + + private: + std::string extract_function_name(const std::string& tool_call_id) const; + + int extract_function_index(const std::string& tool_call_id) const; +}; + +} // namespace function_call +} // namespace xllm \ No newline at end of file diff --git a/xllm/function_call/kimik2_detector_test.cpp b/xllm/function_call/kimik2_detector_test.cpp new file mode 100644 index 00000000..024183da --- /dev/null +++ b/xllm/function_call/kimik2_detector_test.cpp @@ -0,0 +1,455 @@ +#include "kimik2_detector.h" + +#include + +#include +#include +#include + +namespace xllm { +namespace function_call { + +class KimiK2DetectorTest : public ::testing::Test { + protected: + void SetUp() override { + detector_ = std::make_unique(); + + // Setup test tools + nlohmann::json weather_params = { + {"type", "object"}, + {"properties", + {{"location", + {{"type", "string"}, + {"description", "The city and state, e.g. San Francisco, CA"}}}, + {"unit", {{"type", "string"}, {"enum", {"celsius", "fahrenheit"}}}}}}, + {"required", {"location"}}}; + + JsonFunction weather_func("get_current_weather", + "Get the current weather in a given location", + weather_params); + weather_tool_ = JsonTool("function", weather_func); + + nlohmann::json calculator_params = { + {"type", "object"}, + {"properties", + {{"expression", + {{"type", "string"}, + {"description", "Mathematical expression to evaluate"}}}}}, + {"required", {"expression"}}}; + + JsonFunction calculator_func( + "calculate", "Calculate mathematical expressions", calculator_params); + calculator_tool_ = JsonTool("function", calculator_func); + + tools_ = {weather_tool_, calculator_tool_}; + } + + std::unique_ptr detector_; + JsonTool weather_tool_; + JsonTool calculator_tool_; + std::vector tools_; +}; + +// Test constructor and basic properties +TEST_F(KimiK2DetectorTest, ConstructorInitializesCorrectly) { + EXPECT_NE(detector_, nullptr); + + // Test basic token detection + std::string text_with_tool_call = + "Some text " + "<|tool_calls_section_begin|><|tool_call_begin|>functions.test:0 " + "<|tool_call_argument_begin|>{\"param\": " + "\"value\"}<|tool_call_end|><|tool_calls_section_end|>"; + std::string text_without_tool_call = + "Just normal text without any tool calls"; + + EXPECT_TRUE(detector_->has_tool_call(text_with_tool_call)); + EXPECT_FALSE(detector_->has_tool_call(text_without_tool_call)); +} + +// Test has_tool_call method +TEST_F(KimiK2DetectorTest, HasToolCallDetection) { + // Test text containing tool calls + EXPECT_TRUE(detector_->has_tool_call("<|tool_calls_section_begin|>")); + EXPECT_TRUE(detector_->has_tool_call( + "Previous text <|tool_calls_section_begin|>Following content")); + EXPECT_TRUE(detector_->has_tool_call( + "<|tool_calls_section_begin|><|tool_call_begin|>functions.test:0 " + "<|tool_call_argument_begin|>{\"param\": " + "\"value\"}<|tool_call_end|><|tool_calls_section_end|>")); + + // Test text not containing tool calls + EXPECT_FALSE(detector_->has_tool_call("")); + EXPECT_FALSE(detector_->has_tool_call("Regular text")); + EXPECT_FALSE( + detector_->has_tool_call("tool_calls_section_begin without brackets")); + EXPECT_FALSE( + detector_->has_tool_call("<|tool_call_begin|>functions.get_current_" + "weather:0 <|tool_call_argument_begin|>{\"location\": " + "\"Beijing\"}<|tool_call_end|><|tool_calls_section_end|> \t\r\n"; + + auto result = detector_->detect_and_parse(text_with_whitespace, tools_); + + // Verify normal text is correctly extracted + EXPECT_EQ(result.normal_text, " \t\nPrevious text\r\n "); + + // Verify tool call is correctly parsed + EXPECT_EQ(result.calls.size(), 1); + EXPECT_EQ(result.calls[0].tool_index, 0); +} + +// Test single tool call parsing +TEST_F(KimiK2DetectorTest, SingleToolCallParsing) { + std::string text = + "Please help me check the weather " + "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_current_" + "weather:0 <|tool_call_argument_begin|>{\"location\": \"Beijing\", " + "\"unit\": \"celsius\"}<|tool_call_end|><|tool_calls_section_end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Please help me check the weather "); + EXPECT_EQ(result.calls.size(), 1); + + const auto& call = result.calls[0]; + EXPECT_EQ(call.tool_index, 0); + EXPECT_TRUE(call.name.has_value()); + EXPECT_EQ(call.name.value(), "get_current_weather"); + + // Verify parameter JSON + nlohmann::json params = nlohmann::json::parse(call.parameters); + EXPECT_EQ(params["location"], "Beijing"); + EXPECT_EQ(params["unit"], "celsius"); +} + +// Test multiple tool calls parsing +TEST_F(KimiK2DetectorTest, MultipleToolCallsParsing) { + std::string text = + "Please help me check the weather and calculate an expression " + "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_current_" + "weather:0 <|tool_call_argument_begin|>{\"location\": " + "\"Shanghai\"}<|tool_call_end|><|tool_call_begin|>functions.calculate:1 " + "<|tool_call_argument_begin|>{\"expression\": \"2 + 3 * " + "4\"}<|tool_call_end|><|tool_calls_section_end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, + "Please help me check the weather and calculate an expression "); + EXPECT_EQ(result.calls.size(), 2); + + // Verify first tool call + const auto& call1 = result.calls[0]; + EXPECT_EQ(call1.tool_index, 0); + EXPECT_TRUE(call1.name.has_value()); + EXPECT_EQ(call1.name.value(), "get_current_weather"); + + nlohmann::json params1 = nlohmann::json::parse(call1.parameters); + EXPECT_EQ(params1["location"], "Shanghai"); + + // Verify second tool call + const auto& call2 = result.calls[1]; + EXPECT_EQ(call2.tool_index, 1); + EXPECT_TRUE(call2.name.has_value()); + EXPECT_EQ(call2.name.value(), "calculate"); + + nlohmann::json params2 = nlohmann::json::parse(call2.parameters); + EXPECT_EQ(params2["expression"], "2 + 3 * 4"); +} + +// Test invalid JSON handling +TEST_F(KimiK2DetectorTest, InvalidJsonHandling) { + std::string text = + "Test invalid JSON " + "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_current_" + "weather:0 <|tool_call_argument_begin|>{\"location\": \"Beijing\", " + "invalid_json}<|tool_call_end|><|tool_calls_section_end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Test invalid JSON "); + // KimiK2 detector should still parse the call even with invalid JSON, leaving + // JSON validation to higher levels + EXPECT_EQ(result.calls.size(), 1); + EXPECT_EQ(result.calls[0].name.value(), "get_current_weather"); +} + +// Test empty tool call content +TEST_F(KimiK2DetectorTest, EmptyToolCallContent) { + std::string text = + "Test empty content " + "<|tool_calls_section_begin|><|tool_calls_section_end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Test empty content "); + EXPECT_EQ(result.calls.size(), 0); // Empty content should be ignored +} + +// Test incomplete tool call (only start tag) +TEST_F(KimiK2DetectorTest, IncompleteToolCall) { + std::string text = + "Incomplete tool call " + "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_current_" + "weather:0"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Incomplete tool call "); + EXPECT_EQ(result.calls.size(), 0); // Incomplete calls should be ignored +} + +// Test unknown tool name handling +TEST_F(KimiK2DetectorTest, UnknownToolName) { + std::string text = + "Unknown tool " + "<|tool_calls_section_begin|><|tool_call_begin|>functions.unknown_tool:0 " + "<|tool_call_argument_begin|>{\"param\": " + "\"value\"}<|tool_call_end|><|tool_calls_section_end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Unknown tool "); + // KimiK2 detector should parse the call regardless of whether the tool is + // known + EXPECT_EQ(result.calls.size(), 1); + EXPECT_EQ(result.calls[0].name.value(), "unknown_tool"); +} + +// Test case with only normal text +TEST_F(KimiK2DetectorTest, OnlyNormalText) { + std::string text = "This is a regular text without any tool calls."; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, + "This is a regular text without any tool calls."); + EXPECT_EQ(result.calls.size(), 0); + EXPECT_FALSE(result.has_calls()); +} + +// Test empty string input +TEST_F(KimiK2DetectorTest, EmptyStringInput) { + std::string text = ""; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, ""); + EXPECT_EQ(result.calls.size(), 0); + EXPECT_FALSE(result.has_calls()); +} + +// Test whitespace-only input +TEST_F(KimiK2DetectorTest, WhitespaceOnlyInput) { + std::string text = " \t\n\r "; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, " \t\n\r "); + EXPECT_EQ(result.calls.size(), 0); +} + +// Test complex nested JSON parameters +TEST_F(KimiK2DetectorTest, ComplexNestedJsonParameters) { + std::string text = + "Complex parameter test " + "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_current_" + "weather:0 <|tool_call_argument_begin|>{\"location\": \"Beijing\", " + "\"options\": {\"include_forecast\": true, \"days\": 7, \"details\": " + "[\"temperature\", \"humidity\", " + "\"wind\"]}}<|tool_call_end|><|tool_calls_section_end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Complex parameter test "); + EXPECT_EQ(result.calls.size(), 1); + + const auto& call = result.calls[0]; + EXPECT_EQ(call.tool_index, 0); + + nlohmann::json params = nlohmann::json::parse(call.parameters); + EXPECT_EQ(params["location"], "Beijing"); + EXPECT_TRUE(params["options"]["include_forecast"]); + EXPECT_EQ(params["options"]["days"], 7); + EXPECT_EQ(params["options"]["details"].size(), 3); +} + +// Test tool call in the middle of text +TEST_F(KimiK2DetectorTest, ToolCallInMiddleOfText) { + std::string text = + "Previous text " + "<|tool_calls_section_begin|><|tool_call_begin|>functions.calculate:0 " + "<|tool_call_argument_begin|>{\"expression\": " + "\"1+1\"}<|tool_call_end|><|tool_calls_section_end|> Following text"; + + auto result = detector_->detect_and_parse(text, tools_); + + // Note: According to KimiK2 implementation, only text before tool call + // section is preserved as normal_text + EXPECT_EQ(result.normal_text, "Previous text "); + EXPECT_EQ(result.calls.size(), 1); + + const auto& call = result.calls[0]; + EXPECT_EQ(call.tool_index, 0); + EXPECT_EQ(call.name.value(), "calculate"); +} + +// Test special characters handling +TEST_F(KimiK2DetectorTest, SpecialCharactersHandling) { + std::string text = + "Special characters test " + "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_current_" + "weather:0 <|tool_call_argument_begin|>{\"location\": \"New York City\", " + "\"note\": \"Contains " + "symbols!@#$%^&*()_+=\"}<|tool_call_end|><|tool_calls_section_end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Special characters test "); + EXPECT_EQ(result.calls.size(), 1); + + const auto& call = result.calls[0]; + EXPECT_EQ(call.tool_index, 0); + nlohmann::json params = nlohmann::json::parse(call.parameters); + EXPECT_EQ(params["location"], "New York City"); + EXPECT_EQ(params["note"], "Contains symbols!@#$%^&*()_+="); +} + +// Test function name extraction +TEST_F(KimiK2DetectorTest, FunctionNameExtraction) { + std::string text = + "Function name test " + "<|tool_calls_section_begin|><|tool_call_begin|>functions.my_custom_" + "function:5 <|tool_call_argument_begin|>{\"param\": " + "\"value\"}<|tool_call_end|><|tool_calls_section_end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Function name test "); + EXPECT_EQ(result.calls.size(), 1); + + const auto& call = result.calls[0]; + EXPECT_EQ(call.tool_index, 5); // Should extract the index correctly + EXPECT_EQ(call.name.value(), "my_custom_function"); +} + +// Test malformed function ID handling +TEST_F(KimiK2DetectorTest, MalformedFunctionIdHandling) { + std::string text = + "Malformed ID test " + "<|tool_calls_section_begin|><|tool_call_begin|>invalid_format " + "<|tool_call_argument_begin|>{\"param\": " + "\"value\"}<|tool_call_end|><|tool_calls_section_end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Malformed ID test "); + // Malformed format that doesn't match regex should result in no calls + EXPECT_EQ(result.calls.size(), 0); +} + +// Test malformed function ID that matches regex but has invalid format +TEST_F(KimiK2DetectorTest, MalformedButMatchingFunctionId) { + std::string text = + "Malformed but matching test " + "<|tool_calls_section_begin|><|tool_call_begin|>invalid.format:0 " + "<|tool_call_argument_begin|>{\"param\": " + "\"value\"}<|tool_call_end|><|tool_calls_section_end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Malformed but matching test "); + // Should parse but with empty function name due to missing "functions." + // prefix + EXPECT_EQ(result.calls.size(), 1); + const auto& call = result.calls[0]; + EXPECT_TRUE(call.name.has_value()); + EXPECT_EQ(call.name.value(), ""); // Empty function name for malformed ID + EXPECT_EQ(call.tool_index, 0); +} + +// Test multiple sections (edge case) +TEST_F(KimiK2DetectorTest, MultipleSections) { + std::string text = + "First section " + "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_current_" + "weather:0 <|tool_call_argument_begin|>{\"location\": " + "\"Beijing\"}<|tool_call_end|><|tool_calls_section_end|> Middle text " + "<|tool_calls_section_begin|><|tool_call_begin|>functions.calculate:1 " + "<|tool_call_argument_begin|>{\"expression\": " + "\"1+1\"}<|tool_call_end|><|tool_calls_section_end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + // Should extract text before first section + EXPECT_EQ(result.normal_text, "First section "); + // Should parse all tool calls from all sections + EXPECT_EQ(result.calls.size(), 2); + + EXPECT_EQ(result.calls[0].name.value(), "get_current_weather"); + EXPECT_EQ(result.calls[1].name.value(), "calculate"); +} + +// Performance test: many tool calls +TEST_F(KimiK2DetectorTest, PerformanceWithManyToolCalls) { + std::string text = "Performance test <|tool_calls_section_begin|>"; + + // Build text containing multiple tool calls + for (int i = 0; i < 10000; ++i) { + text += "<|tool_call_begin|>functions.calculate:" + std::to_string(i) + + " <|tool_call_argument_begin|>{\"expression\": \"" + + std::to_string(i) + " + " + std::to_string(i + 1) + + "\"}<|tool_call_end|>"; + } + text += "<|tool_calls_section_end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Performance test "); + EXPECT_EQ(result.calls.size(), 10000); + + // Verify each tool call is correctly parsed + for (int i = 0; i < 10000; ++i) { + const auto& call = result.calls[i]; + EXPECT_EQ(call.tool_index, i); + EXPECT_EQ(call.name.value(), "calculate"); + + nlohmann::json params = nlohmann::json::parse(call.parameters); + std::string expected_expr = + std::to_string(i) + " + " + std::to_string(i + 1); + EXPECT_EQ(params["expression"], expected_expr); + } +} + +// Test edge case: nested braces in JSON +TEST_F(KimiK2DetectorTest, NestedBracesInJson) { + std::string text = + "Nested braces test " + "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_current_" + "weather:0 <|tool_call_argument_begin|>{\"location\": \"Beijing\", " + "\"config\": {\"nested\": {\"deep\": " + "\"value\"}}}<|tool_call_end|><|tool_calls_section_end|>"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Nested braces test "); + EXPECT_EQ(result.calls.size(), 1); + + const auto& call = result.calls[0]; + EXPECT_EQ(call.tool_index, 0); + EXPECT_EQ(call.name.value(), "get_current_weather"); + + nlohmann::json params = nlohmann::json::parse(call.parameters); + EXPECT_EQ(params["location"], "Beijing"); + EXPECT_EQ(params["config"]["nested"]["deep"], "value"); +} + +} // namespace function_call +} // namespace xllm \ No newline at end of file diff --git a/xllm/function_call/qwen25_detector.cpp b/xllm/function_call/qwen25_detector.cpp new file mode 100644 index 00000000..224820bf --- /dev/null +++ b/xllm/function_call/qwen25_detector.cpp @@ -0,0 +1,116 @@ +#include "qwen25_detector.h" + +#include +#include +#include + +namespace xllm { +namespace function_call { + +Qwen25Detector::Qwen25Detector() : BaseFormatDetector() { + bot_token_ = "\n"; + eot_token_ = "\n"; + tool_call_separator_ = "\n"; + + std::string pattern = bot_token_ + "([\\s\\S]*?)" + eot_token_; + tool_call_regex_ = std::regex( + pattern, + std::regex_constants::ECMAScript | std::regex_constants::optimize); +} + +bool Qwen25Detector::has_tool_call(const std::string& text) { + return text.find(bot_token_) != std::string::npos; +} + +std::string_view Qwen25Detector::trim_whitespace(std::string_view str) const { + const char* whitespace = " \t\n\r"; + + size_t start = str.find_first_not_of(whitespace); + if (start == std::string_view::npos) { + return std::string_view{}; + } + + size_t end = str.find_last_not_of(whitespace); + + return str.substr(start, end - start + 1); +} + +std::vector> Qwen25Detector::find_tool_call_ranges( + const std::string& text) const { + std::vector> ranges; + ranges.reserve(4); + + size_t search_pos = 0; + const size_t bot_token_len = bot_token_.length(); + const size_t eot_token_len = eot_token_.length(); + + while (search_pos < text.length()) { + size_t start_pos = text.find(bot_token_, search_pos); + if (start_pos == std::string::npos) { + break; + } + + size_t content_start = start_pos + bot_token_len; + size_t end_pos = text.find(eot_token_, content_start); + if (end_pos == std::string::npos) { + break; + } + + ranges.emplace_back(content_start, end_pos); + search_pos = end_pos + eot_token_len; + } + + return ranges; +} + +StreamingParseResult Qwen25Detector::detect_and_parse( + const std::string& text, + const std::vector& tools) { + size_t bot_token_pos = text.find(bot_token_); + + std::string normal_text; + if (bot_token_pos != std::string::npos) { + std::string_view normal_text_view(text.data(), bot_token_pos); + std::string_view trimmed = trim_whitespace(normal_text_view); + normal_text = std::string(trimmed); + } else { + std::string_view trimmed = trim_whitespace(text); + normal_text = std::string(trimmed); + return StreamingParseResult(normal_text); + } + + auto tool_call_ranges = find_tool_call_ranges(text); + + std::vector calls; + calls.reserve(tool_call_ranges.size()); + + for (const auto& range : tool_call_ranges) { + std::string_view content_view(text.data() + range.first, + range.second - range.first); + std::string_view trimmed_content = trim_whitespace(content_view); + + if (trimmed_content.empty()) { + continue; + } + + try { + std::string json_content(trimmed_content); + auto json_obj = nlohmann::json::parse(json_content); + auto parsed_calls = parse_base_json(json_obj, tools); + + calls.insert(calls.end(), + std::make_move_iterator(parsed_calls.begin()), + std::make_move_iterator(parsed_calls.end())); + } catch (const std::exception& e) { + LOG(ERROR) << "Failed to parse JSON part: " + << std::string(trimmed_content) + << ", JSON parse error: " << e.what(); + continue; + } + } + + return StreamingParseResult(std::move(normal_text), std::move(calls)); +} + +} // namespace function_call +} // namespace xllm \ No newline at end of file diff --git a/xllm/function_call/qwen25_detector.h b/xllm/function_call/qwen25_detector.h new file mode 100644 index 00000000..9f40e959 --- /dev/null +++ b/xllm/function_call/qwen25_detector.h @@ -0,0 +1,37 @@ +#pragma once + +#include +#include +#include + +#include "base_format_detector.h" + +namespace xllm { +namespace function_call { + +class Qwen25Detector : public BaseFormatDetector { + public: + Qwen25Detector(); + + virtual ~Qwen25Detector() = default; + + private: + std::string normal_text_buffer_; + + std::regex tool_call_regex_; + + std::string_view trim_whitespace(std::string_view str) const; + + std::vector> find_tool_call_ranges( + const std::string& text) const; + + public: + bool has_tool_call(const std::string& text) override; + + StreamingParseResult detect_and_parse( + const std::string& text, + const std::vector& tools) override; +}; + +} // namespace function_call +} // namespace xllm \ No newline at end of file diff --git a/xllm/function_call/qwen25_detector_test.cpp b/xllm/function_call/qwen25_detector_test.cpp new file mode 100644 index 00000000..732bb69c --- /dev/null +++ b/xllm/function_call/qwen25_detector_test.cpp @@ -0,0 +1,328 @@ +#include "qwen25_detector.h" + +#include + +#include +#include +#include + +namespace xllm { +namespace function_call { + +class Qwen25DetectorTest : public ::testing::Test { + protected: + void SetUp() override { + detector_ = std::make_unique(); + + // Setup test tools + nlohmann::json weather_params = { + {"type", "object"}, + {"properties", + {{"location", + {{"type", "string"}, + {"description", "The city and state, e.g. San Francisco, CA"}}}, + {"unit", {{"type", "string"}, {"enum", {"celsius", "fahrenheit"}}}}}}, + {"required", {"location"}}}; + + JsonFunction weather_func("get_current_weather", + "Get the current weather in a given location", + weather_params); + weather_tool_ = JsonTool("function", weather_func); + + nlohmann::json calculator_params = { + {"type", "object"}, + {"properties", + {{"expression", + {{"type", "string"}, + {"description", "Mathematical expression to evaluate"}}}}}, + {"required", {"expression"}}}; + + JsonFunction calculator_func( + "calculate", "Calculate mathematical expressions", calculator_params); + calculator_tool_ = JsonTool("function", calculator_func); + + tools_ = {weather_tool_, calculator_tool_}; + } + + std::unique_ptr detector_; + JsonTool weather_tool_; + JsonTool calculator_tool_; + std::vector tools_; +}; + +// Test constructor and basic properties +TEST_F(Qwen25DetectorTest, ConstructorInitializesCorrectly) { + EXPECT_NE(detector_, nullptr); + + // Test basic token detection + std::string text_with_tool_call = + "Some text \n{\"name\": \"test\"}\n"; + std::string text_without_tool_call = + "Just normal text without any tool calls"; + + EXPECT_TRUE(detector_->has_tool_call(text_with_tool_call)); + EXPECT_FALSE(detector_->has_tool_call(text_without_tool_call)); +} + +// Test has_tool_call method +TEST_F(Qwen25DetectorTest, HasToolCallDetection) { + // Test text containing tool calls + EXPECT_TRUE(detector_->has_tool_call("\n")); + EXPECT_TRUE( + detector_->has_tool_call("Previous text \nFollowing content")); + EXPECT_TRUE(detector_->has_tool_call( + "\n{\"name\": \"test\"}\n")); + + // Test text not containing tool calls + EXPECT_FALSE(detector_->has_tool_call("")); + EXPECT_FALSE(detector_->has_tool_call("Regular text")); + EXPECT_FALSE(detector_->has_tool_call("tool_call without brackets")); + EXPECT_FALSE(detector_->has_tool_call("\n {\"name\": " + "\"get_current_weather\", \"arguments\": {\"location\": \"Beijing\"}} " + "\n \t\r\n"; + + auto result = detector_->detect_and_parse(text_with_whitespace, tools_); + + // Verify normal text is correctly trimmed + EXPECT_EQ(result.normal_text, "Previous text"); + + // Verify tool call is correctly parsed + EXPECT_EQ(result.calls.size(), 1); + EXPECT_EQ(result.calls[0].tool_index, -1); // Base class always returns -1 +} + +// Test single tool call parsing +TEST_F(Qwen25DetectorTest, SingleToolCallParsing) { + std::string text = + "Please help me check the weather \n{\"name\": " + "\"get_current_weather\", \"arguments\": {\"location\": \"Beijing\", " + "\"unit\": \"celsius\"}}\n"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Please help me check the weather"); + EXPECT_EQ(result.calls.size(), 1); + + const auto& call = result.calls[0]; + EXPECT_EQ(call.tool_index, -1); // Base class always returns -1 + EXPECT_TRUE(call.name.has_value()); + EXPECT_EQ(call.name.value(), "get_current_weather"); + + // Verify parameter JSON + nlohmann::json params = nlohmann::json::parse(call.parameters); + EXPECT_EQ(params["location"], "Beijing"); + EXPECT_EQ(params["unit"], "celsius"); +} + +// Test multiple tool calls parsing +TEST_F(Qwen25DetectorTest, MultipleToolCallsParsing) { + std::string text = + "Please help me check the weather and calculate an expression " + "\n{\"name\": \"get_current_weather\", \"arguments\": " + "{\"location\": \"Shanghai\"}}\n\n\n{\"name\": " + "\"calculate\", \"arguments\": {\"expression\": \"2 + 3 * " + "4\"}}\n"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, + "Please help me check the weather and calculate an expression"); + EXPECT_EQ(result.calls.size(), 2); + + // Verify first tool call + const auto& call1 = result.calls[0]; + EXPECT_EQ(call1.tool_index, -1); // Base class always returns -1 + EXPECT_TRUE(call1.name.has_value()); + EXPECT_EQ(call1.name.value(), "get_current_weather"); + + nlohmann::json params1 = nlohmann::json::parse(call1.parameters); + EXPECT_EQ(params1["location"], "Shanghai"); + + // Verify second tool call + const auto& call2 = result.calls[1]; + EXPECT_EQ(call2.tool_index, -1); // Base class always returns -1 + EXPECT_TRUE(call2.name.has_value()); + EXPECT_EQ(call2.name.value(), "calculate"); + + nlohmann::json params2 = nlohmann::json::parse(call2.parameters); + EXPECT_EQ(params2["expression"], "2 + 3 * 4"); +} + +// Test invalid JSON handling +TEST_F(Qwen25DetectorTest, InvalidJsonHandling) { + std::string text = + "Test invalid JSON \n{\"name\": \"get_current_weather\", " + "\"arguments\": {\"location\": \"Beijing\", invalid_json}}\n"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Test invalid JSON"); + EXPECT_EQ(result.calls.size(), 0); // Invalid JSON should be ignored +} + +// Test empty tool call content +TEST_F(Qwen25DetectorTest, EmptyToolCallContent) { + std::string text = "Test empty content \n \t\n \n"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Test empty content"); + EXPECT_EQ(result.calls.size(), 0); // Empty content should be ignored +} + +// Test incomplete tool call (only start tag) +TEST_F(Qwen25DetectorTest, IncompleteToolCall) { + std::string text = + "Incomplete tool call \n{\"name\": \"get_current_weather\""; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Incomplete tool call"); + EXPECT_EQ(result.calls.size(), 0); // Incomplete calls should be ignored +} + +// Test unknown tool name handling +TEST_F(Qwen25DetectorTest, UnknownToolName) { + std::string text = + "Unknown tool \n{\"name\": \"unknown_tool\", \"arguments\": " + "{\"param\": \"value\"}}\n"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Unknown tool"); + // Base class will skip unknown tools, so should be 0 calls + EXPECT_EQ(result.calls.size(), 0); +} + +// Test case with only normal text +TEST_F(Qwen25DetectorTest, OnlyNormalText) { + std::string text = "This is a regular text without any tool calls."; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, + "This is a regular text without any tool calls."); + EXPECT_EQ(result.calls.size(), 0); + EXPECT_FALSE(result.has_calls()); +} + +// Test empty string input +TEST_F(Qwen25DetectorTest, EmptyStringInput) { + std::string text = ""; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, ""); + EXPECT_EQ(result.calls.size(), 0); + EXPECT_FALSE(result.has_calls()); +} + +// Test whitespace-only input +TEST_F(Qwen25DetectorTest, WhitespaceOnlyInput) { + std::string text = " \t\n\r "; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, ""); + EXPECT_EQ(result.calls.size(), 0); +} + +// Test complex nested JSON parameters +TEST_F(Qwen25DetectorTest, ComplexNestedJsonParameters) { + std::string text = + "Complex parameter test \n{\"name\": \"get_current_weather\", " + "\"arguments\": {\"location\": \"Beijing\", \"options\": " + "{\"include_forecast\": true, \"days\": 7, \"details\": " + "[\"temperature\", \"humidity\", \"wind\"]}}}\n"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Complex parameter test"); + EXPECT_EQ(result.calls.size(), 1); + + const auto& call = result.calls[0]; + EXPECT_EQ(call.tool_index, -1); // Base class always returns -1 + + nlohmann::json params = nlohmann::json::parse(call.parameters); + EXPECT_EQ(params["location"], "Beijing"); + EXPECT_TRUE(params["options"]["include_forecast"]); + EXPECT_EQ(params["options"]["days"], 7); + EXPECT_EQ(params["options"]["details"].size(), 3); +} + +// Test tool call in the middle of text +TEST_F(Qwen25DetectorTest, ToolCallInMiddleOfText) { + std::string text = + "Previous text \n{\"name\": \"calculate\", \"arguments\": " + "{\"expression\": \"1+1\"}}\n Following text"; + + auto result = detector_->detect_and_parse(text, tools_); + + // Note: According to implementation, only text before tool call is preserved + // as normal_text + EXPECT_EQ(result.normal_text, "Previous text"); + EXPECT_EQ(result.calls.size(), 1); + + const auto& call = result.calls[0]; + EXPECT_EQ(call.tool_index, -1); // Base class always returns -1 + EXPECT_EQ(call.name.value(), "calculate"); +} + +// Test special characters handling +TEST_F(Qwen25DetectorTest, SpecialCharactersHandling) { + std::string text = + "Special characters test \n{\"name\": " + "\"get_current_weather\", \"arguments\": {\"location\": \"New York " + "City\", \"note\": \"Contains symbols!@#$%^&*()_+=\"}}\n"; + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Special characters test"); + EXPECT_EQ(result.calls.size(), 1); + + const auto& call = result.calls[0]; + EXPECT_EQ(call.tool_index, -1); // Base class always returns -1 + nlohmann::json params = nlohmann::json::parse(call.parameters); + EXPECT_EQ(params["location"], "New York City"); + EXPECT_EQ(params["note"], "Contains symbols!@#$%^&*()_+="); +} + +// Performance test: many tool calls +TEST_F(Qwen25DetectorTest, PerformanceWithManyToolCalls) { + std::string text = "Performance test"; + + // Build text containing multiple tool calls + for (int i = 0; i < 10000; ++i) { + text += + " \n{\"name\": \"calculate\", \"arguments\": " + "{\"expression\": \"" + + std::to_string(i) + " + " + std::to_string(i + 1) + + "\"}}\n"; + } + + auto result = detector_->detect_and_parse(text, tools_); + + EXPECT_EQ(result.normal_text, "Performance test"); + EXPECT_EQ(result.calls.size(), 10000); + + // Verify each tool call is correctly parsed + for (int i = 0; i < 10000; ++i) { + const auto& call = result.calls[i]; + EXPECT_EQ(call.tool_index, -1); // Base class always returns -1 + EXPECT_EQ(call.name.value(), "calculate"); + + nlohmann::json params = nlohmann::json::parse(call.parameters); + std::string expected_expr = + std::to_string(i) + " + " + std::to_string(i + 1); + EXPECT_EQ(params["expression"], expected_expr); + } +} + +} // namespace function_call +} // namespace xllm \ No newline at end of file diff --git a/xllm/proto/chat.proto b/xllm/proto/chat.proto index e59ab705..c41557fa 100644 --- a/xllm/proto/chat.proto +++ b/xllm/proto/chat.proto @@ -17,6 +17,8 @@ message ChatMessage { // TODO: add function call support // FunctionCall function_call = 4; + repeated ToolCall tool_calls = 3; + optional string tool_call_id = 4; } @@ -107,6 +109,9 @@ message ChatRequest { optional string service_request_id = 25; Routing routing = 26; + + repeated Tool tools = 27; + optional string tool_choice = 28; } message ChatLogProbData { diff --git a/xllm/proto/common.proto b/xllm/proto/common.proto index 6e621339..6897f6d0 100644 --- a/xllm/proto/common.proto +++ b/xllm/proto/common.proto @@ -3,6 +3,31 @@ syntax = "proto3"; option go_package = "jd.com/jd-infer/xllm;xllm"; package xllm.proto; +import "google/protobuf/struct.proto"; + +message Function { + string name = 1; + string description = 2; + google.protobuf.Struct parameters = 3; +} + +message Tool { + string type = 1; // "function" + Function function = 2; +} + +message ToolCall { + optional uint32 index = 1; + optional string id = 2; + string type = 3; // "function" + FunctionCall function = 4; +} + +message FunctionCall { + string name = 1; + string arguments = 2; // JSON string +} + message Usage { // the number of tokens in the prompt. optional int32 prompt_tokens = 1 [json_name="prompt_tokens"]; diff --git a/xllm/proto/multimodal.proto b/xllm/proto/multimodal.proto index 3e738154..41a3cfb2 100644 --- a/xllm/proto/multimodal.proto +++ b/xllm/proto/multimodal.proto @@ -31,6 +31,9 @@ message MMChatMessage { // the content of the message. null for assistant messages with function calls. repeated MMInputData content = 2; + + repeated ToolCall tool_calls = 3; + optional string tool_call_id = 4; } // Next Id: 27 @@ -120,5 +123,9 @@ message MMChatRequest { optional string service_request_id = 25; + Routing routing = 26; + + repeated Tool tools = 27; + optional string tool_choice = 28; } diff --git a/xllm/xllm.cpp b/xllm/xllm.cpp index 018d15eb..1868d9bc 100644 --- a/xllm/xllm.cpp +++ b/xllm/xllm.cpp @@ -110,7 +110,8 @@ int run() { .enable_schedule_overlap(FLAGS_enable_schedule_overlap) .kv_cache_transfer_mode(FLAGS_kv_cache_transfer_mode) .etcd_addr(FLAGS_etcd_addr) - .enable_service_routing(FLAGS_enable_service_routing); + .enable_service_routing(FLAGS_enable_service_routing) + .tool_call_parser(FLAGS_tool_call_parser); InstanceName::name()->set_name(options.instance_name().value_or(""));