Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions xllm/api_service/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ cc_library(
chat_service_impl.h
embedding_service_impl.h
image_generation_service_impl.h
rerank_service_impl.h
non_stream_call.h
service_impl_factory.h
stream_call.h
Expand All @@ -23,6 +24,7 @@ cc_library(
embedding_service_impl.cpp
image_generation_service_impl.cpp
models_service_impl.cpp
rerank_service_impl.cpp
DEPS
:master
:chat_template
Expand All @@ -32,5 +34,7 @@ cc_library(
absl::flat_hash_set
absl::random_random
:function_call
torch
$<$<BOOL:${USE_NPU}>:torch_npu>
)

44 changes: 44 additions & 0 deletions xllm/api_service/api_service.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ APIService::APIService(Master* master,
embedding_service_impl_ =
ServiceImplFactory<EmbeddingServiceImpl>::create_service_impl(
llm_master, model_names);
rerank_service_impl_ =
ServiceImplFactory<RerankServiceImpl>::create_service_impl(llm_master,
model_names);
} else if (FLAGS_backend == "vlm") {
auto vlm_master = dynamic_cast<VLMMaster*>(master);
mm_chat_service_impl_ =
Expand Down Expand Up @@ -260,6 +263,47 @@ void APIService::ImageGenerationHttp(
image_generation_service_impl_->process_async(call);
}

// Native-protobuf RPC entry point for rerank requests (used by the
// xllm-service protocol, as opposed to the HTTP/JSON path below).
void APIService::Rerank(::google::protobuf::RpcController* controller,
                        const proto::RerankRequest* request,
                        proto::RerankResponse* response,
                        ::google::protobuf::Closure* done) {
  // TODO with xllm-service: intentionally unimplemented for now;
  // RerankHttp is the only supported entry point at this point.
}

// HTTP entry point for rerank requests: parses the JSON body attached to
// the brpc controller into a proto::RerankRequest and hands the call off
// asynchronously to the rerank service implementation.
void APIService::RerankHttp(::google::protobuf::RpcController* controller,
                            const proto::HttpRequest* request,
                            proto::HttpResponse* response,
                            ::google::protobuf::Closure* done) {
  // The guard invokes `done` on every early return and records the
  // request in/out metrics around the handler.
  xllm::ClosureGuard done_guard(
      done,
      std::bind(request_in_metric, nullptr),
      std::bind(request_out_metric, (void*)controller));
  if (!request || !response || !controller) {
    LOG(ERROR) << "brpc request | response | controller is null";
    return;
  }
  // The rerank service is only created for the "llm" backend; fail fast
  // instead of dereferencing a null impl below.
  if (!rerank_service_impl_) {
    LOG(ERROR) << "rerank service is not initialized";
    return;
  }

  // Allocate the proto messages on the response's arena (if any) so their
  // lifetime is tied to the response.
  auto arena = response->GetArena();
  auto req_pb =
      google::protobuf::Arena::CreateMessage<proto::RerankRequest>(arena);
  auto resp_pb =
      google::protobuf::Arena::CreateMessage<proto::RerankResponse>(arena);

  auto ctrl = reinterpret_cast<brpc::Controller*>(controller);
  // to_string() already yields a temporary; std::move here was a
  // pessimizing no-op and has been dropped.
  std::string attachment = ctrl->request_attachment().to_string();
  std::string error;
  if (!json2pb::JsonToProtoMessage(attachment, req_pb, &error)) {
    ctrl->SetFailed(error);
    LOG(ERROR) << "parse json to proto failed: " << error;
    return;
  }

  // The call takes ownership of the released closure and completes it
  // asynchronously once the rerank result is ready.
  std::shared_ptr<Call> call =
      std::make_shared<RerankCall>(ctrl, done_guard.release(), req_pb, resp_pb);
  rerank_service_impl_->process_async(call);
}

void APIService::Models(::google::protobuf::RpcController* controller,
const proto::ModelListRequest* request,
proto::ModelListResponse* response,
Expand Down
13 changes: 13 additions & 0 deletions xllm/api_service/api_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include "embedding_service_impl.h"
#include "image_generation_service_impl.h"
#include "models_service_impl.h"
#include "rerank_service_impl.h"
#include "xllm_service.pb.h"

namespace xllm {
Expand Down Expand Up @@ -70,6 +71,17 @@ class APIService : public proto::XllmAPIService {
const proto::HttpRequest* request,
proto::HttpResponse* response,
::google::protobuf::Closure* done) override;

void Rerank(::google::protobuf::RpcController* controller,
const proto::RerankRequest* request,
proto::RerankResponse* response,
::google::protobuf::Closure* done) override;

void RerankHttp(::google::protobuf::RpcController* controller,
const proto::HttpRequest* request,
proto::HttpResponse* response,
::google::protobuf::Closure* done) override;

void Models(::google::protobuf::RpcController* controller,
const proto::ModelListRequest* request,
proto::ModelListResponse* response,
Expand Down Expand Up @@ -109,6 +121,7 @@ class APIService : public proto::XllmAPIService {
std::unique_ptr<EmbeddingServiceImpl> embedding_service_impl_;
std::unique_ptr<ModelsServiceImpl> models_service_impl_;
std::unique_ptr<ImageGenerationServiceImpl> image_generation_service_impl_;
std::unique_ptr<RerankServiceImpl> rerank_service_impl_;
};

} // namespace xllm
170 changes: 170 additions & 0 deletions xllm/api_service/rerank_service_impl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "rerank_service_impl.h"

#include <glog/logging.h>
#include <torch/torch.h>

#include <string>

#include "common/instance_name.h"
#include "framework/request/request_params.h"
#include "runtime/llm_master.h"
#include "util/blocking_counter.h"
#include "util/utils.h"
#include "util/uuid.h"

namespace xllm {
namespace {

// One scored candidate produced by the rerank pipeline.
struct RerankRequestOutput {
  int32_t index = 0;     // position of the document in the original request
  std::string document;  // document text (default-init; "" was redundant)
  float score = 0.0f;    // cosine-similarity relevance score vs. the query

  RerankRequestOutput(int32_t index, std::string document, float score)
      : index(index), document(std::move(document)), score(score) {}
};

bool send_result_to_client_brpc(std::shared_ptr<RerankCall> call,
const std::string& request_id,
int64_t created_time,
const std::string& model,
const std::vector<std::string>& documents,
int32_t top_n,
const std::vector<RequestOutput>& req_outputs) {
auto& response = call->response();
response.set_id(request_id);
response.set_model(model);

// calculate cosine similarity
size_t doc_size = documents.size() - 1;
std::string query = documents[doc_size];
auto query_embed = req_outputs[doc_size].outputs[0].embeddings.value();
auto query_tensor = torch::from_blob(
query_embed.data(), {query_embed.size()}, torch::kFloat32);

std::vector<RerankRequestOutput> rerank_outputs;
rerank_outputs.reserve(doc_size);
for (size_t i = 0; i < doc_size; ++i) {
if (req_outputs[i].outputs[0].embeddings.has_value()) {
auto doc_embed = req_outputs[i].outputs[0].embeddings.value();
auto doc_tensor = torch::from_blob(
doc_embed.data(), {doc_embed.size()}, torch::kFloat32);
auto score =
torch::cosine_similarity(query_tensor, doc_tensor, 0).item<float>();
rerank_outputs.emplace_back(i, documents[i], score);
}
}

std::sort(rerank_outputs.begin(),
rerank_outputs.end(),
[](const RerankRequestOutput& a, const RerankRequestOutput& b) {
return a.score > b.score;
});

// add top_n results
int32_t valid_top_n = std::min(top_n, static_cast<int32_t>(doc_size));
response.mutable_results()->Reserve(valid_top_n);
for (int32_t i = 0; i < valid_top_n; ++i) {
auto* result = response.add_results();
result->set_index(rerank_outputs[i].index);
auto* document = result->mutable_document();
document->set_text(rerank_outputs[i].document);
result->set_relevance_score(rerank_outputs[i].score);
}

// add usage statistics
int32_t num_prompt_tokens = 0;
int32_t num_generated_tokens = 0;
int32_t num_total_tokens = 0;
for (auto req_output : req_outputs) {
if (req_output.usage.has_value()) {
const auto& usage = req_output.usage.value();
num_prompt_tokens += usage.num_prompt_tokens;
num_generated_tokens += usage.num_generated_tokens;
num_total_tokens += usage.num_total_tokens;
}
}
if (num_total_tokens > 0) {
auto* proto_usage = response.mutable_usage();
proto_usage->set_prompt_tokens(num_prompt_tokens);
proto_usage->set_completion_tokens(num_generated_tokens);
proto_usage->set_total_tokens(num_total_tokens);
}

return call->write_and_finish(response);
}

} // namespace

// Constructs the rerank service. `master` executes the underlying embedding
// requests; `models` lists the model names this endpoint is allowed to serve.
RerankServiceImpl::RerankServiceImpl(LLMMaster* master,
                                     const std::vector<std::string>& models)
    : APIServiceImpl(models), master_(master) {
  // Fail fast: a null master would otherwise only crash on the request path.
  CHECK(master_ != nullptr);
}

// rerank_async for brpc
void RerankServiceImpl::process_async_impl(std::shared_ptr<RerankCall> call) {
const auto& rpc_request = call->request();
// check if model is supported
const auto& model = rpc_request.model();
if (!models_.contains(model)) {
call->finish_with_error(StatusCode::UNKNOWN, "Model not supported");
return;
}

std::vector<std::string> documents;
if (rpc_request.documents_size() > 0) {
documents = std::vector<std::string>(rpc_request.documents().begin(),
rpc_request.documents().end());
}
documents.emplace_back(rpc_request.query());

// create RequestParams for rerank request
RequestParams request_params(
rpc_request, call->get_x_request_id(), call->get_x_request_time());
std::vector<RequestParams> sps(documents.size(), request_params);
auto request_id = request_params.request_id;
auto created_time = absl::ToUnixSeconds(absl::Now());

// schedule the request
std::vector<RequestOutput> req_outputs;
req_outputs.resize(documents.size());
BlockingCounter counter(documents.size());

auto batch_callback = [&req_outputs, &counter](size_t index,
RequestOutput output) -> bool {
req_outputs[index] = std::move(output);
counter.decrement_count();
return true;
};

master_->handle_batch_request(documents, sps, batch_callback);

// Wait for all tasks to complete
counter.wait();

int32_t top_n = documents.size() - 1;
if (rpc_request.has_top_n()) {
top_n = rpc_request.top_n();
}

send_result_to_client_brpc(
call, request_id, created_time, model, documents, top_n, req_outputs);
}

} // namespace xllm
41 changes: 41 additions & 0 deletions xllm/api_service/rerank_service_impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once
#include <absl/container/flat_hash_set.h>

#include "api_service/api_service_impl.h"
#include "api_service/call.h"
#include "api_service/non_stream_call.h"
#include "rerank.pb.h"

namespace xllm {

// Rerank uses the non-streaming request/response call wrapper.
using RerankCall = NonStreamCall<proto::RerankRequest, proto::RerankResponse>;

// A class to handle rerank requests: embeds the query and candidate
// documents, scores them by similarity, and returns the top results.
// (Previous comment said "completion requests" — copy-paste leftover.)
class RerankServiceImpl final : public APIServiceImpl<RerankCall> {
 public:
  RerankServiceImpl(LLMMaster* master, const std::vector<std::string>& models);

  // brpc call_data needs to use shared_ptr
  void process_async_impl(std::shared_ptr<RerankCall> call);

 private:
  DISALLOW_COPY_AND_ASSIGN(RerankServiceImpl);
  // Non-owning; executes the underlying embedding requests.
  LLMMaster* master_ = nullptr;
};

}  // namespace xllm
19 changes: 19 additions & 0 deletions xllm/core/framework/request/request_params.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ std::string generate_chat_request_id() {
short_uuid.random();
}

// Builds a globally unique id for a rerank request, shaped
// "rerankcmpl-<instance-name-hash>-<random-suffix>".
std::string generate_rerank_request_id() {
  std::string id("rerankcmpl-");
  id += InstanceName::name()->get_name_hash();
  id += '-';
  id += short_uuid.random();
  return id;
}

} // namespace

RequestParams::RequestParams(const proto::CompletionRequest& request,
Expand Down Expand Up @@ -332,6 +337,20 @@ RequestParams::RequestParams(const proto::EmbeddingRequest& request,
streaming = false;
}

// Builds the per-request parameters for a rerank call. Rerank runs as an
// embedding request under the hood: a single forward pass per input, no
// token generation and no streaming.
RequestParams::RequestParams(const proto::RerankRequest& request,
                             const std::string& x_rid,
                             const std::string& x_rtime) {
  request_id = generate_rerank_request_id();
  x_request_id = x_rid;
  x_request_time = x_rtime;
  // Propagate the upstream service's request id when one was supplied.
  if (request.has_service_request_id()) {
    service_request_id = request.service_request_id();
  }
  // Embedding-only pipeline.
  is_embeddings = true;
  max_tokens = 1;
  streaming = false;
}

bool RequestParams::verify_params(OutputCallback callback) const {
if (n == 0) {
CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT,
Expand Down
4 changes: 4 additions & 0 deletions xllm/core/framework/request/request_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ limitations under the License.
#include "multimodal.pb.h"
#include "request.h"
#include "request_output.h"
#include "rerank.pb.h"

namespace xllm {

Expand All @@ -48,6 +49,9 @@ struct RequestParams {
RequestParams(const proto::EmbeddingRequest& request,
const std::string& x_rid,
const std::string& x_rtime);
RequestParams(const proto::RerankRequest& request,
const std::string& x_rid,
const std::string& x_rtime);

bool verify_params(OutputCallback callback) const;

Expand Down
1 change: 1 addition & 0 deletions xllm/proto/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ proto_library(
chat.proto
multimodal.proto
embedding.proto
rerank.proto
models.proto
worker.proto
disagg_pd.proto
Expand Down
Loading