From fc899b95be3f89e1e63a5850379b9762e0d3cd78 Mon Sep 17 00:00:00 2001
From: vansangpfiev
Date: Tue, 16 Apr 2024 08:41:12 +0700
Subject: [PATCH 1/2] fix: reject completion request for embedding model

---
 controllers/llamaCPP.cc | 18 ++++++++++++++++--
 controllers/llamaCPP.h  | 12 ++++++++----
 2 files changed, 24 insertions(+), 6 deletions(-)

diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc
index 3e3015c2a..fdf482f57 100644
--- a/controllers/llamaCPP.cc
+++ b/controllers/llamaCPP.cc
@@ -187,6 +187,16 @@ void llamaCPP::ChatCompletion(
 void llamaCPP::InferenceImpl(
     inferences::ChatCompletionRequest&& completion,
     std::function<void(const HttpResponsePtr&)>&& callback) {
+  if (model_type_ == ModelType::EMBEDDING) {
+    LOG_WARN << "Not support completion for embedding model";
+    Json::Value jsonResp;
+    jsonResp["message"] =
+        "Not support completion for embedding model";
+    auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
+    resp->setStatusCode(drogon::k400BadRequest);
+    callback(resp);
+    return;
+  }
   std::string formatted_output = pre_prompt;
   int request_id = ++no_of_requests;
   LOG_INFO_REQUEST(request_id) << "Generating reponse for inference request";
@@ -653,6 +663,11 @@ bool llamaCPP::LoadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
   params.n_ctx = jsonBody->get("ctx_len", 2048).asInt();
   params.embedding = jsonBody->get("embedding", true).asBool();
   model_type = jsonBody->get("model_type", "llm").asString();
+  if (model_type == "llm") {
+    model_type_ = ModelType::LLM;
+  } else {
+    model_type_ = ModelType::EMBEDDING;
+  }
   // Check if n_parallel exists in jsonBody, if not, set to drogon_thread
   params.n_batch = jsonBody->get("n_batch", 512).asInt();
   params.n_parallel = jsonBody->get("n_parallel", 1).asInt();
@@ -712,8 +727,7 @@ bool llamaCPP::LoadModelImpl(std::shared_ptr<Json::Value> jsonBody) {

   // For model like nomic-embed-text-v1.5.f16.gguf, etc, we don't need to warm up model.
   // So we use this variable to differentiate with other models
-  // TODO: in case embedded model only, we should reject completion request from user?
-  if (model_type == "llm") {
+  if (model_type_ == ModelType::LLM) {
     WarmupModel();
   }
   return true;
diff --git a/controllers/llamaCPP.h b/controllers/llamaCPP.h
index 900786c79..caf11dce3 100644
--- a/controllers/llamaCPP.h
+++ b/controllers/llamaCPP.h
@@ -59,13 +59,14 @@ class llamaCPP : public drogon::HttpController<llamaCPP>,
   // PATH_ADD("/llama/chat_completion", Post);
   METHOD_LIST_END
   void ChatCompletion(
-      inferences::ChatCompletionRequest &&completion,
+      inferences::ChatCompletionRequest&& completion,
       std::function<void(const HttpResponsePtr&)>&& callback) override;
   void Embedding(
       const HttpRequestPtr& req,
       std::function<void(const HttpResponsePtr&)>&& callback) override;
-  void LoadModel(const HttpRequestPtr& req,
-                 std::function<void(const HttpResponsePtr&)>&& callback) override;
+  void LoadModel(
+      const HttpRequestPtr& req,
+      std::function<void(const HttpResponsePtr&)>&& callback) override;
   void UnloadModel(
       const HttpRequestPtr& req,
       std::function<void(const HttpResponsePtr&)>&& callback) override;
@@ -90,6 +91,8 @@ class llamaCPP : public drogon::HttpController<llamaCPP>,
   int clean_cache_threshold;
   std::string grammar_file_content;

+  enum class ModelType { LLM = 0, EMBEDDING } model_type_;
+
   /**
    * Queue to handle the inference tasks
    */
@@ -100,7 +103,8 @@ class llamaCPP : public drogon::HttpController<llamaCPP>,
                      std::function<void(const HttpResponsePtr&)>&& callback);
   void EmbeddingImpl(std::shared_ptr<Json::Value> jsonBody,
                      std::function<void(const HttpResponsePtr&)>&& callback);
-  bool CheckModelLoaded(const std::function<void(const HttpResponsePtr&)>& callback);
+  bool CheckModelLoaded(
+      const std::function<void(const HttpResponsePtr&)>& callback);
   void WarmupModel();
   void BackgroundTask();
   void StopBackgroundTask();

From a9d24db83deaddec65bcd6843715c7ee5840923f Mon Sep 17 00:00:00 2001
From: vansangpfiev
Date: Tue, 16 Apr 2024 10:40:07 +0700
Subject: [PATCH 2/2] fix: move model_type to llama_server_context

---
 context/llama_server_context.h | 3 +++
 controllers/llamaCPP.cc        | 8 ++++----
 controllers/llamaCPP.h         | 2 --
 3 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/context/llama_server_context.h b/context/llama_server_context.h
index 3839e4b3f..21792f11b 100644
--- a/context/llama_server_context.h
+++ b/context/llama_server_context.h
@@ -211,6 +211,8 @@ enum stop_type {
   STOP_PARTIAL,
 };

+enum class ModelType { LLM = 0, EMBEDDING };
+
 static bool ends_with(const std::string& str, const std::string& suffix) {
   return str.size() >= suffix.size() &&
          0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
 }
@@ -502,6 +504,7 @@ struct llama_server_context {
   std::condition_variable condition_tasks;
   std::mutex mutex_results;
   std::condition_variable condition_results;
+  ModelType model_type = ModelType::LLM;

   ~llama_server_context() {
     if (ctx) {
diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc
index fdf482f57..e4f33a27e 100644
--- a/controllers/llamaCPP.cc
+++ b/controllers/llamaCPP.cc
@@ -187,7 +187,7 @@ void llamaCPP::ChatCompletion(
 void llamaCPP::InferenceImpl(
     inferences::ChatCompletionRequest&& completion,
     std::function<void(const HttpResponsePtr&)>&& callback) {
-  if (model_type_ == ModelType::EMBEDDING) {
+  if (llama.model_type == ModelType::EMBEDDING) {
     LOG_WARN << "Not support completion for embedding model";
     Json::Value jsonResp;
     jsonResp["message"] =
         "Not support completion for embedding model";
     auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
     resp->setStatusCode(drogon::k400BadRequest);
     callback(resp);
     return;
   }
   std::string formatted_output = pre_prompt;
   int request_id = ++no_of_requests;
   LOG_INFO_REQUEST(request_id) << "Generating reponse for inference request";
@@ -664,9 +664,9 @@ bool llamaCPP::LoadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
   params.embedding = jsonBody->get("embedding", true).asBool();
   model_type = jsonBody->get("model_type", "llm").asString();
   if (model_type == "llm") {
-    model_type_ = ModelType::LLM;
+    llama.model_type = ModelType::LLM;
   } else {
-    model_type_ = ModelType::EMBEDDING;
+    llama.model_type = ModelType::EMBEDDING;
   }
   // Check if n_parallel exists in jsonBody, if not, set to drogon_thread
   params.n_batch = jsonBody->get("n_batch", 512).asInt();
   params.n_parallel = jsonBody->get("n_parallel", 1).asInt();
@@ -727,7 +727,7 @@ bool llamaCPP::LoadModelImpl(std::shared_ptr<Json::Value> jsonBody) {

   // For model like nomic-embed-text-v1.5.f16.gguf, etc, we don't need to warm up model.
   // So we use this variable to differentiate with other models
-  if (model_type_ == ModelType::LLM) {
+  if (llama.model_type == ModelType::LLM) {
     WarmupModel();
   }
   return true;
diff --git a/controllers/llamaCPP.h b/controllers/llamaCPP.h
index caf11dce3..531c18b20 100644
--- a/controllers/llamaCPP.h
+++ b/controllers/llamaCPP.h
@@ -91,8 +91,6 @@ class llamaCPP : public drogon::HttpController<llamaCPP>,
   int clean_cache_threshold;
   std::string grammar_file_content;

-  enum class ModelType { LLM = 0, EMBEDDING } model_type_;
-
   /**
    * Queue to handle the inference tasks
    */
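
Taken together, the two patches make LoadModelImpl record whether the loaded model is an LLM or an embedding model (from the "model_type" field of the load request, defaulting to "llm") and make InferenceImpl answer completion requests with HTTP 400 when an embedding model is loaded; the second patch then moves that flag from the controller onto llama_server_context so it lives with the rest of the per-model state. Below is a minimal, self-contained C++ sketch of the same guard logic for illustration only; the names ModelGuard, Load, Complete and Result are made-up stand-ins, not the actual nitro/drogon API.

// Illustrative sketch of the guard these patches add; not nitro code.
#include <iostream>
#include <string>

enum class ModelType { LLM = 0, EMBEDDING };

struct Result {
  int status;           // HTTP-style status code
  std::string message;  // response body
};

struct ModelGuard {
  ModelType model_type = ModelType::LLM;

  // Mirrors LoadModelImpl: anything other than "llm" is treated as embedding.
  void Load(const std::string& model_type_str) {
    model_type =
        (model_type_str == "llm") ? ModelType::LLM : ModelType::EMBEDDING;
  }

  // Mirrors InferenceImpl: completion against an embedding model is rejected
  // with 400 before any prompt formatting or task queueing happens.
  Result Complete(const std::string& prompt) {
    if (model_type == ModelType::EMBEDDING) {
      return {400, "Not support completion for embedding model"};
    }
    return {200, "generated text for: " + prompt};
  }
};

int main() {
  ModelGuard guard;
  guard.Load("embedding");  // e.g. nomic-embed-text-v1.5.f16.gguf
  std::cout << guard.Complete("hello").status << '\n';  // prints 400
  guard.Load("llm");
  std::cout << guard.Complete("hello").status << '\n';  // prints 200
}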