This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 2c33000

fix: reject completion request for embedding model (#510)
1 parent 7ae9928 commit 2c33000

File tree

3 files changed (+25 -6 lines)

  context/llama_server_context.h
  controllers/llamaCPP.cc
  controllers/llamaCPP.h

context/llama_server_context.h

Lines changed: 3 additions & 0 deletions
@@ -211,6 +211,8 @@ enum stop_type {
   STOP_PARTIAL,
 };
 
+enum class ModelType { LLM = 0, EMBEDDING };
+
 static bool ends_with(const std::string& str, const std::string& suffix) {
   return str.size() >= suffix.size() &&
          0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);

@@ -502,6 +504,7 @@ struct llama_server_context {
   std::condition_variable condition_tasks;
   std::mutex mutex_results;
   std::condition_variable condition_results;
+  ModelType model_type = ModelType::LLM;
 
   ~llama_server_context() {
     if (ctx) {
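
The new flag defaults to ModelType::LLM, so load requests that never mention a model type keep their current behaviour; only models explicitly loaded as embedding models (see the controller change below) are treated differently. A trimmed-down sketch of how other code can branch on the flag (llama_server_context is reduced here to the one new field, and supports_completion is illustrative, not part of the commit):

enum class ModelType { LLM = 0, EMBEDDING };

struct llama_server_context {
  // Default keeps existing LLM behaviour for callers that do not set a type.
  ModelType model_type = ModelType::LLM;
};

// Illustrative helper: embedding-only models should not serve completions.
static bool supports_completion(const llama_server_context& ctx) {
  return ctx.model_type != ModelType::EMBEDDING;
}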

controllers/llamaCPP.cc

Lines changed: 16 additions & 2 deletions
@@ -187,6 +187,16 @@ void llamaCPP::ChatCompletion(
 void llamaCPP::InferenceImpl(
     inferences::ChatCompletionRequest&& completion,
     std::function<void(const HttpResponsePtr&)>&& callback) {
+  if (llama.model_type == ModelType::EMBEDDING) {
+    LOG_WARN << "Not support completion for embedding model";
+    Json::Value jsonResp;
+    jsonResp["message"] =
+        "Not support completion for embedding model";
+    auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
+    resp->setStatusCode(drogon::k400BadRequest);
+    callback(resp);
+    return;
+  }
   std::string formatted_output = pre_prompt;
   int request_id = ++no_of_requests;
   LOG_INFO_REQUEST(request_id) << "Generating reponse for inference request";

@@ -653,6 +663,11 @@ bool llamaCPP::LoadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
   params.n_ctx = jsonBody->get("ctx_len", 2048).asInt();
   params.embedding = jsonBody->get("embedding", true).asBool();
   model_type = jsonBody->get("model_type", "llm").asString();
+  if (model_type == "llm") {
+    llama.model_type = ModelType::LLM;
+  } else {
+    llama.model_type = ModelType::EMBEDDING;
+  }
   // Check if n_parallel exists in jsonBody, if not, set to drogon_thread
   params.n_batch = jsonBody->get("n_batch", 512).asInt();
   params.n_parallel = jsonBody->get("n_parallel", 1).asInt();

@@ -712,8 +727,7 @@ bool llamaCPP::LoadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
 
   // For model like nomic-embed-text-v1.5.f16.gguf, etc, we don't need to warm up model.
   // So we use this variable to differentiate with other models
-  // TODO: in case embedded model only, we should reject completion request from user?
-  if (model_type == "llm") {
+  if (llama.model_type == ModelType::LLM) {
     WarmupModel();
   }
   return true;
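
Taken together, LoadModelImpl maps the optional "model_type" field of the load request onto the enum (anything other than "llm" counts as an embedding model), and InferenceImpl now rejects completion requests for embedding models with HTTP 400 and a JSON body of {"message": "Not support completion for embedding model"} before any inference work starts; warmup is likewise skipped for embedding models. A self-contained sketch of that flow with the Drogon and JsonCpp plumbing stripped out (ParseModelType and HandleCompletion are illustrative names, not the controller's):

#include <iostream>
#include <string>

enum class ModelType { LLM = 0, EMBEDDING };

// Mirrors LoadModelImpl: "llm" (the default) maps to LLM, anything else is
// treated as an embedding-only model.
ModelType ParseModelType(const std::string& model_type) {
  return model_type == "llm" ? ModelType::LLM : ModelType::EMBEDDING;
}

// Mirrors the guard added to InferenceImpl: embedding models are rejected up
// front; the real handler responds with drogon::k400BadRequest and a JSON body.
bool HandleCompletion(ModelType type) {
  if (type == ModelType::EMBEDDING) {
    std::cout << "400 Bad Request: Not support completion for embedding model\n";
    return false;
  }
  std::cout << "completion accepted\n";
  return true;
}

int main() {
  HandleCompletion(ParseModelType("embedding"));  // rejected with 400
  HandleCompletion(ParseModelType("llm"));        // proceeds to inference
  return 0;
}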

controllers/llamaCPP.h

Lines changed: 6 additions & 4 deletions
@@ -59,13 +59,14 @@ class llamaCPP : public drogon::HttpController<llamaCPP>,
   // PATH_ADD("/llama/chat_completion", Post);
   METHOD_LIST_END
   void ChatCompletion(
-      inferences::ChatCompletionRequest &&completion,
+      inferences::ChatCompletionRequest&& completion,
       std::function<void(const HttpResponsePtr&)>&& callback) override;
   void Embedding(
       const HttpRequestPtr& req,
       std::function<void(const HttpResponsePtr&)>&& callback) override;
-  void LoadModel(const HttpRequestPtr& req,
-                 std::function<void(const HttpResponsePtr&)>&& callback) override;
+  void LoadModel(
+      const HttpRequestPtr& req,
+      std::function<void(const HttpResponsePtr&)>&& callback) override;
   void UnloadModel(
       const HttpRequestPtr& req,
       std::function<void(const HttpResponsePtr&)>&& callback) override;

@@ -100,7 +101,8 @@ class llamaCPP : public drogon::HttpController<llamaCPP>,
       std::function<void(const HttpResponsePtr&)>&& callback);
   void EmbeddingImpl(std::shared_ptr<Json::Value> jsonBody,
                      std::function<void(const HttpResponsePtr&)>&& callback);
-  bool CheckModelLoaded(const std::function<void(const HttpResponsePtr&)>& callback);
+  bool CheckModelLoaded(
+      const std::function<void(const HttpResponsePtr&)>& callback);
   void WarmupModel();
   void BackgroundTask();
   void StopBackgroundTask();