Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xllm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ include_directories(.)

add_subdirectory(api_service)
add_subdirectory(core)
add_subdirectory(function_call)
add_subdirectory(models)
add_subdirectory(proto)
add_subdirectory(processors)
Expand Down
1 change: 1 addition & 0 deletions xllm/api_service/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@ cc_library(
proto::xllm_proto
absl::flat_hash_set
absl::random_random
:function_call
)

17 changes: 8 additions & 9 deletions xllm/api_service/api_service.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "api_service.h"

#include <glog/logging.h>
#include <google/protobuf/util/json_util.h>
#include <json2pb/json_to_pb.h>
#include <json2pb/pb_to_json.h>

Expand All @@ -15,7 +16,6 @@
#include "models.pb.h"
#include "service_impl_factory.h"
#include "xllm_metrics.h"

namespace xllm {

APIService::APIService(Master* master,
Expand Down Expand Up @@ -71,7 +71,6 @@ void APIService::CompletionsHttp(::google::protobuf::RpcController* controller,

auto ctrl = reinterpret_cast<brpc::Controller*>(controller);
std::string attachment = std::move(ctrl->request_attachment().to_string());

std::string error;
auto st = json2pb::JsonToProtoMessage(attachment, req_pb, &error);
if (!st) {
Expand Down Expand Up @@ -106,10 +105,13 @@ void ChatCompletionsImpl(std::unique_ptr<Service>& service,
std::string attachment = std::move(ctrl->request_attachment().to_string());
std::string error;

auto st = json2pb::JsonToProtoMessage(attachment, req_pb, &error);
if (!st) {
ctrl->SetFailed(error);
LOG(ERROR) << "parse json to proto failed: " << error;
google::protobuf::util::JsonParseOptions options;
options.ignore_unknown_fields = true;
auto json_status =
google::protobuf::util::JsonStringToMessage(attachment, req_pb, options);
if (!json_status.ok()) {
ctrl->SetFailed(json_status.ToString());
LOG(ERROR) << "parse json to proto failed: " << json_status.ToString();
return;
}

Expand Down Expand Up @@ -175,7 +177,6 @@ void APIService::EmbeddingsHttp(::google::protobuf::RpcController* controller,

auto ctrl = reinterpret_cast<brpc::Controller*>(controller);
std::string attachment = std::move(ctrl->request_attachment().to_string());

std::string error;
auto st = json2pb::JsonToProtoMessage(attachment, req_pb, &error);
if (!st) {
Expand Down Expand Up @@ -289,7 +290,6 @@ void APIService::LinkCluster(::google::protobuf::RpcController* controller,

auto ctrl = reinterpret_cast<brpc::Controller*>(controller);
std::string attachment = std::move(ctrl->request_attachment().to_string());

std::string error;
auto st = json2pb::JsonToProtoMessage(attachment, req_pb, &error);
if (!st) {
Expand Down Expand Up @@ -344,7 +344,6 @@ void APIService::UnlinkCluster(::google::protobuf::RpcController* controller,

auto ctrl = reinterpret_cast<brpc::Controller*>(controller);
std::string attachment = std::move(ctrl->request_attachment().to_string());

std::string error;
auto st = json2pb::JsonToProtoMessage(attachment, req_pb, &error);
if (!st) {
Expand Down
149 changes: 135 additions & 14 deletions xllm/api_service/chat_service_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <absl/time/clock.h>
#include <absl/time/time.h>
#include <glog/logging.h>
#include <google/protobuf/util/json_util.h>
#include <torch/torch.h>

#include <boost/algorithm/string.hpp>
Expand All @@ -12,16 +13,76 @@
#include <unordered_set>

#include "core/common/instance_name.h"
#include "core/common/types.h"
#include "core/framework/request/mm_input_helper.h"
#include "core/framework/request/request_params.h"
#include "core/runtime/llm_master.h"
#include "core/runtime/vlm_master.h"
#include "core/util/utils.h"
#include "core/util/uuid.h"
#include "function_call/function_call.h"

namespace xllm {
namespace {

// Outcome of running the tool-call parser over one generated completion.
struct ToolCallResult {
  // Tool calls extracted from the completion; stays nullopt unless
  // process_tool_calls() managed to build a list.
  std::optional<google::protobuf::RepeatedPtrField<proto::ToolCall>>
      tool_calls{};

  // Text to place in the response message content.
  std::string text{};

  // Finish reason to report for the choice.
  std::string finish_reason{};
};

// Runs the configured function-call parser over `text` and packages the
// outcome for a non-streaming chat response.
//
// Parameters:
//   text          - generated completion text to inspect.
//   tools         - tool definitions supplied with the request.
//   parser_format - parser format name handed to FunctionCallParser.
//   finish_reason - finish reason reported by generation ("" if none).
//   arena         - accepted for call-site compatibility. The calls are
//                   collected in a function-local repeated field that is
//                   returned by value, so elements are created through that
//                   field instead of on this arena.
//
// Returns:
//   - No tool call detected: `text` and `finish_reason` are passed through
//     and `tool_calls` stays empty (nullopt).
//   - Tool calls parsed: `text` is the parser's remaining text, `tool_calls`
//     holds one entry per parsed call, and a "stop" finish reason is
//     reported as "tool_calls" (any other reason is kept).
//   - Parser threw: the error is logged and the untouched `text` and
//     `finish_reason` are returned, so the caller never emits an empty
//     message that claims to contain tool calls.
ToolCallResult process_tool_calls(
    std::string text,
    const std::vector<xllm::JsonTool>& tools,
    const std::string& parser_format,
    std::string finish_reason,
    [[maybe_unused]] google::protobuf::Arena* arena = nullptr) {
  ToolCallResult result;

  function_call::FunctionCallParser parser(tools, parser_format);

  if (!parser.has_tool_call(text)) {
    result.text = std::move(text);
    result.finish_reason = std::move(finish_reason);
    return result;
  }

  try {
    auto [parsed_text, call_info_list] = parser.parse_non_stream(text);

    google::protobuf::RepeatedPtrField<proto::ToolCall> tool_calls;

    for (const auto& call_info : call_info_list) {
      // Add() lets the repeated field own the element from the start, so
      // nothing leaks if a later step throws and no cross-arena copy is
      // needed.
      proto::ToolCall* tool_call = tool_calls.Add();

      tool_call->set_id(function_call::utils::generate_tool_call_id());
      tool_call->set_type("function");

      auto* function = tool_call->mutable_function();
      if (call_info.name) {
        function->set_name(*call_info.name);
      }
      function->set_arguments(call_info.parameters);
    }

    // Commit to `result` only after everything succeeded, so the fallback
    // path below always starts from a clean result.
    result.text = std::move(parsed_text);
    result.tool_calls = std::move(tool_calls);
    if (finish_reason == "stop") {
      result.finish_reason = "tool_calls";
    } else {
      result.finish_reason = std::move(finish_reason);
    }
  } catch (const std::exception& e) {
    LOG(ERROR) << "Tool call parsing error: " << e.what();
    // Fall back to a plain-text answer with the original finish reason.
    result.text = std::move(text);
    result.finish_reason = std::move(finish_reason);
  }

  return result;
}

void set_logprobs(proto::ChatChoice* choice,
const std::optional<std::vector<LogProb>>& logprobs) {
if (!logprobs.has_value() || logprobs.value().empty()) {
Expand Down Expand Up @@ -140,7 +201,9 @@ bool send_result_to_client_brpc(std::shared_ptr<ChatCall> call,
const std::string& request_id,
int64_t created_time,
const std::string& model,
const RequestOutput& req_output) {
const RequestOutput& req_output,
const std::string& parser_format = "",
const std::vector<xllm::JsonTool>& tools = {}) {
auto& response = call->response();
response.set_object("chat.completion");
response.set_id(request_id);
Expand All @@ -154,9 +217,34 @@ bool send_result_to_client_brpc(std::shared_ptr<ChatCall> call,
set_logprobs(choice, output.logprobs);
auto* message = choice->mutable_message();
message->set_role("assistant");
message->set_content(output.text);
if (output.finish_reason.has_value()) {
choice->set_finish_reason(output.finish_reason.value());

auto set_output_and_finish_reason = [&]() {
message->set_content(output.text);
if (output.finish_reason.has_value()) {
choice->set_finish_reason(output.finish_reason.value());
}
};

if (!tools.empty() && !parser_format.empty()) {
auto* arena = response.GetArena();
auto result = process_tool_calls(output.text,
tools,
parser_format,
output.finish_reason.value_or(""),
arena);

message->mutable_content()->swap(result.text);

if (result.tool_calls) {
auto& source_tool_calls = *result.tool_calls;
message->mutable_tool_calls()->Swap(&source_tool_calls);
}

if (!result.finish_reason.empty()) {
choice->mutable_finish_reason()->swap(result.finish_reason);
}
} else {
set_output_and_finish_reason();
}
}

Expand Down Expand Up @@ -231,7 +319,8 @@ void ChatServiceImpl::process_async_impl(std::shared_ptr<ChatCall> call) {
include_usage = include_usage,
first_message_sent = std::unordered_set<size_t>(),
request_id = request_params.request_id,
created_time = absl::ToUnixSeconds(absl::Now())](
created_time = absl::ToUnixSeconds(absl::Now()),
json_tools = request_params.tools](
const RequestOutput& req_output) mutable -> bool {
if (req_output.status.has_value()) {
const auto& status = req_output.status.value();
Expand All @@ -250,17 +339,48 @@ void ChatServiceImpl::process_async_impl(std::shared_ptr<ChatCall> call) {
master->get_rate_limiter()->decrease_one_request();
}

const std::string parser_format =
master->options().tool_call_parser().value_or("");
const bool has_tool_support =
!json_tools.empty() && !parser_format.empty();

if (stream) {
return send_delta_to_client_brpc(call,
include_usage,
&first_message_sent,
request_id,
created_time,
model,
req_output);
if (has_tool_support) {
// TODO: Support tool call streaming output
LOG(ERROR) << "Tool call does not support streaming output";
return send_delta_to_client_brpc(call,
include_usage,
&first_message_sent,
request_id,
created_time,
model,
req_output);
} else {
// Stream response without tool support
return send_delta_to_client_brpc(call,
include_usage,
&first_message_sent,
request_id,
created_time,
model,
req_output);
}
} else {
if (has_tool_support) {
// Non-stream response with tool support
return send_result_to_client_brpc(call,
request_id,
created_time,
model,
req_output,
parser_format,
json_tools);
} else {
// Non-stream response without tool support
return send_result_to_client_brpc(
call, request_id, created_time, model, req_output);
}
}
return send_result_to_client_brpc(
call, request_id, created_time, model, req_output);
});
}

Expand Down Expand Up @@ -337,6 +457,7 @@ void MMChatServiceImpl::process_async(std::shared_ptr<MMChatCall> call) {
}

if (stream) {
// send delta to client
return send_delta_to_client_brpc(call,
include_usage,
&first_message_sent,
Expand Down
7 changes: 7 additions & 0 deletions xllm/core/common/global_flags.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,13 @@ DEFINE_bool(enable_atb_comm_multiprocess,
false,
"whether to use multiprocess mode.");

// --- function call config ---

// Name of the parser format used for tool-call handling; defaults to the
// empty string.
// NOTE(review): presumably this feeds Options::tool_call_parser, where an
// empty value leaves tool-call parsing off — confirm against the code that
// copies flags into Options.
DEFINE_string(tool_call_parser,
"",
"Specify the parser for handling tool-call interactions. "
"Options include: 'qwen25'");

DEFINE_bool(enable_atb_spec_kernel,
false,
"whether to use ATB speculative kernel.");
Expand Down
2 changes: 2 additions & 0 deletions xllm/core/common/global_flags.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ DECLARE_bool(disable_custom_kernels);

DECLARE_bool(enable_atb_comm_multiprocess);

DECLARE_string(tool_call_parser);

DECLARE_bool(enable_atb_spec_kernel);

DECLARE_string(etcd_addr);
Expand Down
2 changes: 2 additions & 0 deletions xllm/core/common/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ class Options {
PROPERTY(std::optional<std::string>, etcd_addr);

PROPERTY(bool, enable_service_routing) = false;

PROPERTY(std::optional<std::string>, tool_call_parser);
};

} // namespace xllm
22 changes: 22 additions & 0 deletions xllm/core/common/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,4 +214,26 @@ struct DeviceStats {
// TODO: add more device stats
};

// Function call related types
// A function definition: its name, a description, and a JSON value holding
// its parameters.
struct JsonFunction {
  std::string name;
  std::string description;
  nlohmann::json parameters;

  // Leaves every member default-constructed.
  JsonFunction() = default;

  // Copies each argument into the matching member. Parentheses are used on
  // purpose: brace-initialising a nlohmann::json from another json value
  // would not be a plain copy.
  JsonFunction(const std::string& func_name,
               const std::string& desc,
               const nlohmann::json& params)
      : name(func_name),
        description(desc),
        parameters(params) {}
};

// A tool entry: a type tag plus the function definition it wraps.
struct JsonTool {
  // Type tag; a default-constructed tool is tagged "function".
  std::string type = "function";
  JsonFunction function;

  // Keeps the "function" tag and a default-constructed JsonFunction.
  JsonTool() = default;

  // Copies the given type tag and function definition.
  JsonTool(const std::string& tool_type, const JsonFunction& func)
      : type(tool_type),
        function(func) {}
};

} // namespace xllm
Loading