diff --git a/context/llama_server_context.h b/context/llama_server_context.h
index 3839e4b3f..21792f11b 100644
--- a/context/llama_server_context.h
+++ b/context/llama_server_context.h
@@ -211,6 +211,8 @@ enum stop_type {
   STOP_PARTIAL,
 };
 
+enum class ModelType { LLM = 0, EMBEDDING };
+
 static bool ends_with(const std::string& str, const std::string& suffix) {
   return str.size() >= suffix.size() &&
          0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
@@ -502,6 +504,7 @@ struct llama_server_context {
   std::condition_variable condition_tasks;
   std::mutex mutex_results;
   std::condition_variable condition_results;
+  ModelType model_type = ModelType::LLM;
 
   ~llama_server_context() {
     if (ctx) {
diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc
index 3e3015c2a..e4f33a27e 100644
--- a/controllers/llamaCPP.cc
+++ b/controllers/llamaCPP.cc
@@ -187,6 +187,16 @@ void llamaCPP::ChatCompletion(
 void llamaCPP::InferenceImpl(
     inferences::ChatCompletionRequest&& completion,
     std::function<void(const HttpResponsePtr&)>&& callback) {
+  if (llama.model_type == ModelType::EMBEDDING) {
+    LOG_WARN << "Not support completion for embedding model";
+    Json::Value jsonResp;
+    jsonResp["message"] =
+        "Not support completion for embedding model";
+    auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
+    resp->setStatusCode(drogon::k400BadRequest);
+    callback(resp);
+    return;
+  }
   std::string formatted_output = pre_prompt;
   int request_id = ++no_of_requests;
   LOG_INFO_REQUEST(request_id) << "Generating reponse for inference request";
@@ -653,6 +663,11 @@ bool llamaCPP::LoadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
   params.n_ctx = jsonBody->get("ctx_len", 2048).asInt();
   params.embedding = jsonBody->get("embedding", true).asBool();
   model_type = jsonBody->get("model_type", "llm").asString();
+  if (model_type == "llm") {
+    llama.model_type = ModelType::LLM;
+  } else {
+    llama.model_type = ModelType::EMBEDDING;
+  }
   // Check if n_parallel exists in jsonBody, if not, set to drogon_thread
   params.n_batch = jsonBody->get("n_batch", 512).asInt();
   params.n_parallel = jsonBody->get("n_parallel", 1).asInt();
@@ -712,8 +727,7 @@ bool llamaCPP::LoadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
 
   // For model like nomic-embed-text-v1.5.f16.gguf, etc, we don't need to warm up model.
   // So we use this variable to differentiate with other models
-  // TODO: in case embedded model only, we should reject completion request from user?
-  if (model_type == "llm") {
+  if (llama.model_type == ModelType::LLM) {
     WarmupModel();
   }
   return true;
diff --git a/controllers/llamaCPP.h b/controllers/llamaCPP.h
index 900786c79..531c18b20 100644
--- a/controllers/llamaCPP.h
+++ b/controllers/llamaCPP.h
@@ -59,13 +59,14 @@ class llamaCPP : public drogon::HttpController<llamaCPP>,
   // PATH_ADD("/llama/chat_completion", Post);
   METHOD_LIST_END
   void ChatCompletion(
-      inferences::ChatCompletionRequest &&completion,
+      inferences::ChatCompletionRequest&& completion,
       std::function<void(const HttpResponsePtr&)>&& callback) override;
   void Embedding(
       const HttpRequestPtr& req,
       std::function<void(const HttpResponsePtr&)>&& callback) override;
-  void LoadModel(const HttpRequestPtr& req,
-                 std::function<void(const HttpResponsePtr&)>&& callback) override;
+  void LoadModel(
+      const HttpRequestPtr& req,
+      std::function<void(const HttpResponsePtr&)>&& callback) override;
   void UnloadModel(
       const HttpRequestPtr& req,
       std::function<void(const HttpResponsePtr&)>&& callback) override;
@@ -100,7 +101,8 @@ class llamaCPP : public drogon::HttpController<llamaCPP>,
                      std::function<void(const HttpResponsePtr&)>&& callback);
   void EmbeddingImpl(std::shared_ptr<Json::Value> jsonBody,
                      std::function<void(const HttpResponsePtr&)>&& callback);
-  bool CheckModelLoaded(const std::function<void(const HttpResponsePtr&)>& callback);
+  bool CheckModelLoaded(
+      const std::function<void(const HttpResponsePtr&)>& callback);
   void WarmupModel();
   void BackgroundTask();
   void StopBackgroundTask();