From e60b302201b3066f2d7c20284a62788fe2605be7 Mon Sep 17 00:00:00 2001 From: liangzhiwei20 Date: Thu, 23 Oct 2025 17:07:14 +0800 Subject: [PATCH] feat: support v1/rerank interface for embedding model. --- xllm/api_service/CMakeLists.txt | 4 + xllm/api_service/api_service.cpp | 44 +++++ xllm/api_service/api_service.h | 13 ++ xllm/api_service/rerank_service_impl.cpp | 170 ++++++++++++++++++ xllm/api_service/rerank_service_impl.h | 41 +++++ .../core/framework/request/request_params.cpp | 19 ++ xllm/core/framework/request/request_params.h | 4 + xllm/proto/CMakeLists.txt | 1 + xllm/proto/rerank.proto | 41 +++++ xllm/proto/xllm_service.proto | 4 + xllm/server/xllm_server.cpp | 1 + 11 files changed, 342 insertions(+) create mode 100644 xllm/api_service/rerank_service_impl.cpp create mode 100644 xllm/api_service/rerank_service_impl.h create mode 100644 xllm/proto/rerank.proto diff --git a/xllm/api_service/CMakeLists.txt b/xllm/api_service/CMakeLists.txt index e1083ca46..ca69f113b 100644 --- a/xllm/api_service/CMakeLists.txt +++ b/xllm/api_service/CMakeLists.txt @@ -11,6 +11,7 @@ cc_library( chat_service_impl.h embedding_service_impl.h image_generation_service_impl.h + rerank_service_impl.h non_stream_call.h service_impl_factory.h stream_call.h @@ -23,6 +24,7 @@ cc_library( embedding_service_impl.cpp image_generation_service_impl.cpp models_service_impl.cpp + rerank_service_impl.cpp DEPS :master :chat_template @@ -32,5 +34,7 @@ cc_library( absl::flat_hash_set absl::random_random :function_call + torch + $<$<BOOL:${USE_NPU}>:torch_npu> ) diff --git a/xllm/api_service/api_service.cpp b/xllm/api_service/api_service.cpp index cb0bbb068..9a5aed245 100644 --- a/xllm/api_service/api_service.cpp +++ b/xllm/api_service/api_service.cpp @@ -51,6 +51,9 @@ APIService::APIService(Master* master, embedding_service_impl_ = ServiceImplFactory<EmbeddingServiceImpl>::create_service_impl( llm_master, model_names); + rerank_service_impl_ = + ServiceImplFactory<RerankServiceImpl>::create_service_impl(llm_master, + model_names); } else if (FLAGS_backend
== "vlm") { auto vlm_master = dynamic_cast(master); mm_chat_service_impl_ = @@ -260,6 +263,47 @@ void APIService::ImageGenerationHttp( image_generation_service_impl_->process_async(call); } +void APIService::Rerank(::google::protobuf::RpcController* controller, + const proto::RerankRequest* request, + proto::RerankResponse* response, + ::google::protobuf::Closure* done) { + // TODO with xllm-service +} + +void APIService::RerankHttp(::google::protobuf::RpcController* controller, + const proto::HttpRequest* request, + proto::HttpResponse* response, + ::google::protobuf::Closure* done) { + xllm::ClosureGuard done_guard( + done, + std::bind(request_in_metric, nullptr), + std::bind(request_out_metric, (void*)controller)); + if (!request || !response || !controller) { + LOG(ERROR) << "brpc request | respose | controller is null"; + return; + } + + auto arena = response->GetArena(); + auto req_pb = + google::protobuf::Arena::CreateMessage(arena); + auto resp_pb = + google::protobuf::Arena::CreateMessage(arena); + + auto ctrl = reinterpret_cast(controller); + std::string attachment = std::move(ctrl->request_attachment().to_string()); + std::string error; + auto st = json2pb::JsonToProtoMessage(attachment, req_pb, &error); + if (!st) { + ctrl->SetFailed(error); + LOG(ERROR) << "parse json to proto failed: " << error; + return; + } + + std::shared_ptr call = + std::make_shared(ctrl, done_guard.release(), req_pb, resp_pb); + rerank_service_impl_->process_async(call); +} + void APIService::Models(::google::protobuf::RpcController* controller, const proto::ModelListRequest* request, proto::ModelListResponse* response, diff --git a/xllm/api_service/api_service.h b/xllm/api_service/api_service.h index 656611ed3..0911fb058 100644 --- a/xllm/api_service/api_service.h +++ b/xllm/api_service/api_service.h @@ -20,6 +20,7 @@ limitations under the License. 
#include "embedding_service_impl.h" #include "image_generation_service_impl.h" #include "models_service_impl.h" +#include "rerank_service_impl.h" #include "xllm_service.pb.h" namespace xllm { @@ -70,6 +71,17 @@ class APIService : public proto::XllmAPIService { const proto::HttpRequest* request, proto::HttpResponse* response, ::google::protobuf::Closure* done) override; + + void Rerank(::google::protobuf::RpcController* controller, + const proto::RerankRequest* request, + proto::RerankResponse* response, + ::google::protobuf::Closure* done) override; + + void RerankHttp(::google::protobuf::RpcController* controller, + const proto::HttpRequest* request, + proto::HttpResponse* response, + ::google::protobuf::Closure* done) override; + void Models(::google::protobuf::RpcController* controller, const proto::ModelListRequest* request, proto::ModelListResponse* response, @@ -109,6 +121,7 @@ class APIService : public proto::XllmAPIService { std::unique_ptr embedding_service_impl_; std::unique_ptr models_service_impl_; std::unique_ptr image_generation_service_impl_; + std::unique_ptr rerank_service_impl_; }; } // namespace xllm diff --git a/xllm/api_service/rerank_service_impl.cpp b/xllm/api_service/rerank_service_impl.cpp new file mode 100644 index 000000000..4f3400787 --- /dev/null +++ b/xllm/api_service/rerank_service_impl.cpp @@ -0,0 +1,170 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "rerank_service_impl.h" + +#include <absl/time/clock.h> +#include <absl/time/time.h> + +#include <torch/torch.h> + +#include "common/instance_name.h" +#include "framework/request/request_params.h" +#include "runtime/llm_master.h" +#include "util/blocking_counter.h" +#include "util/utils.h" +#include "util/uuid.h" + +namespace xllm { +namespace { + +struct RerankRequestOutput { + int32_t index = 0; + std::string document = ""; + float score = 0.0f; + + RerankRequestOutput(int32_t index, std::string document, float score) + : index(index), document(std::move(document)), score(score) {} +}; + +bool send_result_to_client_brpc(std::shared_ptr<RerankCall> call, + const std::string& request_id, + int64_t created_time, + const std::string& model, + const std::vector<std::string>& documents, + int32_t top_n, + const std::vector<RequestOutput>& req_outputs) { + auto& response = call->response(); + response.set_id(request_id); + response.set_model(model); + + // calculate cosine similarity + size_t doc_size = documents.size() - 1; + std::string query = documents[doc_size]; + auto query_embed = req_outputs[doc_size].outputs[0].embeddings.value(); + auto query_tensor = torch::from_blob( + query_embed.data(), {static_cast<int64_t>(query_embed.size())}, torch::kFloat32); + + std::vector<RerankRequestOutput> rerank_outputs; + rerank_outputs.reserve(doc_size); + for (size_t i = 0; i < doc_size; ++i) { + if (req_outputs[i].outputs[0].embeddings.has_value()) { + auto doc_embed = req_outputs[i].outputs[0].embeddings.value(); + auto doc_tensor = torch::from_blob( + doc_embed.data(), {static_cast<int64_t>(doc_embed.size())}, torch::kFloat32); + auto score = + torch::cosine_similarity(query_tensor, doc_tensor, 0).item<float>(); + rerank_outputs.emplace_back(i, documents[i], score); + } + } + + std::sort(rerank_outputs.begin(), + rerank_outputs.end(), + [](const RerankRequestOutput& a, const RerankRequestOutput& b) { + return a.score > b.score; + }); + + // add top_n results + int32_t valid_top_n = std::min(top_n, static_cast<int32_t>(doc_size)); + 
response.mutable_results()->Reserve(valid_top_n); + for (int32_t i = 0; i < valid_top_n; ++i) { + auto* result = response.add_results(); + result->set_index(rerank_outputs[i].index); + auto* document = result->mutable_document(); + document->set_text(rerank_outputs[i].document); + result->set_relevance_score(rerank_outputs[i].score); + } + + // add usage statistics + int32_t num_prompt_tokens = 0; + int32_t num_generated_tokens = 0; + int32_t num_total_tokens = 0; + for (const auto& req_output : req_outputs) { + if (req_output.usage.has_value()) { + const auto& usage = req_output.usage.value(); + num_prompt_tokens += usage.num_prompt_tokens; + num_generated_tokens += usage.num_generated_tokens; + num_total_tokens += usage.num_total_tokens; + } + } + if (num_total_tokens > 0) { + auto* proto_usage = response.mutable_usage(); + proto_usage->set_prompt_tokens(num_prompt_tokens); + proto_usage->set_completion_tokens(num_generated_tokens); + proto_usage->set_total_tokens(num_total_tokens); + } + + return call->write_and_finish(response); +} + +} // namespace + +RerankServiceImpl::RerankServiceImpl(LLMMaster* master, + const std::vector<std::string>& models) + : APIServiceImpl(models), master_(master) { + CHECK(master_ != nullptr); +} + +// rerank_async for brpc +void RerankServiceImpl::process_async_impl(std::shared_ptr<RerankCall> call) { + const auto& rpc_request = call->request(); + // check if model is supported + const auto& model = rpc_request.model(); + if (!models_.contains(model)) { + call->finish_with_error(StatusCode::UNKNOWN, "Model not supported"); + return; + } + + std::vector<std::string> documents; + if (rpc_request.documents_size() > 0) { + documents = std::vector<std::string>(rpc_request.documents().begin(), + rpc_request.documents().end()); + } + documents.emplace_back(rpc_request.query()); + + // create RequestParams for rerank request + RequestParams request_params( + rpc_request, call->get_x_request_id(), call->get_x_request_time()); + std::vector<RequestParams> sps(documents.size(), request_params); + auto request_id = 
request_params.request_id; + auto created_time = absl::ToUnixSeconds(absl::Now()); + + // schedule the request + std::vector<RequestOutput> req_outputs; + req_outputs.resize(documents.size()); + BlockingCounter counter(documents.size()); + + auto batch_callback = [&req_outputs, &counter](size_t index, + RequestOutput output) -> bool { + req_outputs[index] = std::move(output); + counter.decrement_count(); + return true; + }; + + master_->handle_batch_request(documents, sps, batch_callback); + + // Wait for all tasks to complete + counter.wait(); + + int32_t top_n = documents.size() - 1; + if (rpc_request.has_top_n()) { + top_n = rpc_request.top_n(); + } + + send_result_to_client_brpc( + call, request_id, created_time, model, documents, top_n, req_outputs); +} + +} // namespace xllm diff --git a/xllm/api_service/rerank_service_impl.h b/xllm/api_service/rerank_service_impl.h new file mode 100644 index 000000000..948301e0d --- /dev/null +++ b/xllm/api_service/rerank_service_impl.h @@ -0,0 +1,41 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#pragma once +#include <vector> + +#include "api_service/api_service_impl.h" +#include "api_service/call.h" +#include "api_service/non_stream_call.h" +#include "rerank.pb.h" + +namespace xllm { + +using RerankCall = NonStreamCall<proto::RerankRequest, proto::RerankResponse>; + +// a class to handle rerank requests +class RerankServiceImpl final : public APIServiceImpl<RerankCall> { + public: + RerankServiceImpl(LLMMaster* master, const std::vector<std::string>& models); + + // brpc call_data needs to use shared_ptr + void process_async_impl(std::shared_ptr<RerankCall> call); + + private: + DISALLOW_COPY_AND_ASSIGN(RerankServiceImpl); + LLMMaster* master_ = nullptr; +}; + +} // namespace xllm diff --git a/xllm/core/framework/request/request_params.cpp b/xllm/core/framework/request/request_params.cpp index d92365dca..8f2378d73 100644 --- a/xllm/core/framework/request/request_params.cpp +++ b/xllm/core/framework/request/request_params.cpp @@ -39,6 +39,11 @@ std::string generate_chat_request_id() { short_uuid.random(); } +std::string generate_rerank_request_id() { + return "rerankcmpl-" + InstanceName::name()->get_name_hash() + "-" + + short_uuid.random(); +} + } // namespace RequestParams::RequestParams(const proto::CompletionRequest& request, @@ -332,6 +337,20 @@ RequestParams::RequestParams(const proto::EmbeddingRequest& request, streaming = false; } +RequestParams::RequestParams(const proto::RerankRequest& request, + const std::string& x_rid, + const std::string& x_rtime) { + request_id = generate_rerank_request_id(); + if (request.has_service_request_id()) { + service_request_id = request.service_request_id(); + } + x_request_id = x_rid; + x_request_time = x_rtime; + is_embeddings = true; + max_tokens = 1; + streaming = false; +} + bool RequestParams::verify_params(OutputCallback callback) const { if (n == 0) { CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, diff --git a/xllm/core/framework/request/request_params.h b/xllm/core/framework/request/request_params.h 
index b6df6dcef..88ebe8652 100644 --- a/xllm/core/framework/request/request_params.h +++ b/xllm/core/framework/request/request_params.h @@ -31,6 +31,7 @@ limitations under the License. #include "multimodal.pb.h" #include "request.h" #include "request_output.h" +#include "rerank.pb.h" namespace xllm { @@ -48,6 +49,9 @@ struct RequestParams { RequestParams(const proto::EmbeddingRequest& request, const std::string& x_rid, const std::string& x_rtime); + RequestParams(const proto::RerankRequest& request, + const std::string& x_rid, + const std::string& x_rtime); bool verify_params(OutputCallback callback) const; diff --git a/xllm/proto/CMakeLists.txt b/xllm/proto/CMakeLists.txt index e1f20ba04..38be2b75d 100644 --- a/xllm/proto/CMakeLists.txt +++ b/xllm/proto/CMakeLists.txt @@ -10,6 +10,7 @@ proto_library( chat.proto multimodal.proto embedding.proto + rerank.proto models.proto worker.proto disagg_pd.proto diff --git a/xllm/proto/rerank.proto b/xllm/proto/rerank.proto new file mode 100644 index 000000000..2426b5dab --- /dev/null +++ b/xllm/proto/rerank.proto @@ -0,0 +1,41 @@ +syntax = "proto3"; + +option go_package = "jd.com/jd-infer/xllm;xllm"; +package xllm.proto; + +import "common.proto"; + +message RerankRequest { + string model = 1; + string query = 2; + repeated string documents = 3; + optional int32 top_n = 4; + optional int32 truncate_prompt_tokens = 5; + + optional string user = 6; + + optional string service_request_id = 7; +} + +message RerankDocument { + string text = 1; +} + +message RerankResult { + int32 index = 1; + + RerankDocument document = 2; + + float relevance_score = 3; +} + +message RerankResponse { + string id = 1; + + string model = 2; + + Usage usage = 3; + + repeated RerankResult results = 4; +} + diff --git a/xllm/proto/xllm_service.proto b/xllm/proto/xllm_service.proto index 435e996ff..edd6a4a52 100644 --- a/xllm/proto/xllm_service.proto +++ b/xllm/proto/xllm_service.proto @@ -9,6 +9,7 @@ import "completion.proto"; import "chat.proto"; 
import "embedding.proto"; import "image_generation.proto"; +import "rerank.proto"; import "models.proto"; message HttpRequest { @@ -52,6 +53,9 @@ service XllmAPIService { rpc ImageGeneration(ImageGenerationRequest) returns (ImageGenerationResponse); rpc ImageGenerationHttp(HttpRequest) returns (HttpResponse); + + rpc Rerank (RerankRequest) returns (RerankResponse); + rpc RerankHttp (HttpRequest) returns (HttpResponse); rpc Models (ModelListRequest) returns (ModelListResponse); rpc ModelsHttp (HttpRequest) returns (HttpResponse); diff --git a/xllm/server/xllm_server.cpp b/xllm/server/xllm_server.cpp index de092e446..6b8496a6d 100644 --- a/xllm/server/xllm_server.cpp +++ b/xllm/server/xllm_server.cpp @@ -40,6 +40,7 @@ bool XllmServer::start(std::unique_ptr service) { "v1/embeddings => EmbeddingsHttp," "v1/models => ModelsHttp," "v1/image/generation => ImageGenerationHttp," + "v1/rerank => RerankHttp," "get_cache_info => GetCacheInfo," "link_cluster => LinkCluster," "unlink_cluster => UnlinkCluster,"