diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc
index 36d8a284b..a26eaa94d 100644
--- a/controllers/llamaCPP.cc
+++ b/controllers/llamaCPP.cc
@@ -21,6 +21,20 @@ std::shared_ptr<inferenceState> create_inference_state(llamaCPP *instance) {
 
 // --------------------------------------------
 
+// Function to check if the model is loaded
+void check_model_loaded(llama_server_context &llama, const HttpRequestPtr &req,
+                        std::function<void(const HttpResponsePtr &)> &callback) {
+  if (!llama.model_loaded_external) {
+    Json::Value jsonResp;
+    jsonResp["message"] =
+        "Model has not been loaded, please load model into nitro";
+    auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
+    resp->setStatusCode(drogon::k409Conflict);
+    callback(resp);
+    return;
+  }
+}
+
 Json::Value create_embedding_payload(const std::vector<float> &embedding,
                                      int prompt_tokens) {
   Json::Value dataItem;
@@ -136,15 +150,8 @@ void llamaCPP::chatCompletion(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
 
-  if (!llama.model_loaded_external) {
-    Json::Value jsonResp;
-    jsonResp["message"] =
-        "Model has not been loaded, please load model into nitro";
-    auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
-    resp->setStatusCode(drogon::k409Conflict);
-    callback(resp);
-    return;
-  }
+  // Check if model is loaded
+  check_model_loaded(llama, req, callback);
 
   const auto &jsonBody = req->getJsonObject();
   std::string formatted_output = pre_prompt;
@@ -402,15 +409,7 @@ void llamaCPP::chatCompletion(
 void llamaCPP::embedding(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
-  if (!llama.model_loaded_external) {
-    Json::Value jsonResp;
-    jsonResp["message"] =
-        "Model has not been loaded, please load model into nitro";
-    auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
-    resp->setStatusCode(drogon::k409Conflict);
-    callback(resp);
-    return;
-  }
+  check_model_loaded(llama, req, callback);
 
   const auto &jsonBody = req->getJsonObject();
 
@@ -623,4 +622,4 @@ void llamaCPP::stopBackgroundTask() {
       backgroundThread.join();
     }
   }
-}
\ No newline at end of file
+}
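
One thing worth noting about the refactor above: check_model_loaded sends the 409 response itself, but since it returns void, the return inside it only exits the helper, and chatCompletion / embedding continue executing after the call even when no model is loaded. Below is a minimal sketch of an alternative, not part of this patch: the bool return type and the usage pattern are assumptions layered on the same drogon/nitro types already used in llamaCPP.cc (llama_server_context, nitro_utils::nitroHttpJsonResponse, drogon::k409Conflict), so the call site can bail out once the 409 has been sent.

// Hypothetical variant of the helper (sketch only): same 409 handling as the
// patch, but the return value tells the caller whether it may proceed.
static bool check_model_loaded(
    llama_server_context &llama, const HttpRequestPtr &req,
    const std::function<void(const HttpResponsePtr &)> &callback) {
  (void)req; // kept only for parity with the patch's signature
  if (llama.model_loaded_external) {
    return true; // model is ready; the handler may continue
  }
  Json::Value jsonResp;
  jsonResp["message"] =
      "Model has not been loaded, please load model into nitro";
  auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
  resp->setStatusCode(drogon::k409Conflict);
  callback(resp);
  return false; // 409 already sent; the handler should return immediately
}

// Usage at a call site such as llamaCPP::chatCompletion:
//   if (!check_model_loaded(llama, req, callback)) {
//     return;
//   }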