diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc
index 021c79faa..c2872ad77 100644
--- a/controllers/llamaCPP.cc
+++ b/controllers/llamaCPP.cc
@@ -542,33 +542,42 @@ void llamaCPP::ModelStatus(
 }
 
 void llamaCPP::LoadModel(
-    const HttpRequestPtr& req,
-    std::function<void(const HttpResponsePtr&)>&& callback) {
-  if (llama.model_loaded_external) {
-    LOG_INFO << "model loaded";
-    Json::Value jsonResp;
-    jsonResp["message"] = "Model already loaded";
-    auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
-    resp->setStatusCode(drogon::k409Conflict);
-    callback(resp);
+    const HttpRequestPtr &req,
+    std::function<void(const HttpResponsePtr &)> &&callback) {
+
+  if(llama.model_loaded_external.load(std::memory_order_acquire)){
+    ModelLoadedResponse(callback);
     return;
   }
-
+
+  bool modelLoadedSuccess;
   const auto& jsonBody = req->getJsonObject();
-  if (!LoadModelImpl(jsonBody)) {
-    // Error occurred during model loading
-    Json::Value jsonResp;
-    jsonResp["message"] = "Failed to load model";
-    auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
-    resp->setStatusCode(drogon::k500InternalServerError);
-    callback(resp);
-  } else {
-    // Model loaded successfully
-    Json::Value jsonResp;
-    jsonResp["message"] = "Model loaded successfully";
-    auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
-    callback(resp);
+
+  {
+    std::lock_guard<std::mutex> lck{load_model_mutex};
+    if (llama.model_loaded_external.load(std::memory_order_relaxed)) {
+      ModelLoadedResponse(callback);
+      return;
+    }
+    modelLoadedSuccess = LoadModelImpl(jsonBody);
+    llama.model_loaded_external.store(modelLoadedSuccess, std::memory_order_release);
   }
+  if (modelLoadedSuccess) {
+    // Model loaded successfully
+    Json::Value jsonResp;
+    jsonResp["message"] = "Model loaded successfully";
+    auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
+    callback(resp);
+    return;
+  }
+
+  // Error occurred during model loading
+  Json::Value jsonResp;
+  jsonResp["message"] = "Failed to load model";
+  auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
+  resp->setStatusCode(drogon::k500InternalServerError);
+  callback(resp);
+
 }
 
 bool llamaCPP::LoadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
@@ -662,7 +671,6 @@ bool llamaCPP::LoadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
   queue = new trantor::ConcurrentTaskQueue(llama.params.n_parallel,
                                            "llamaCPP");
 
-  llama.model_loaded_external = true;
   LOG_INFO << "Started background task here!";
   backgroundThread = std::thread(&llamaCPP::BackgroundTask, this);
 
@@ -691,3 +699,12 @@ void llamaCPP::StopBackgroundTask() {
     }
   }
 }
+void llamaCPP::ModelLoadedResponse(
+    std::function<void(const HttpResponsePtr&)> callback) {
+  LOG_INFO << "model loaded";
+  Json::Value jsonResp;
+  jsonResp["message"] = "Model already loaded";
+  auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
+  resp->setStatusCode(drogon::k409Conflict);
+  callback(resp);
+}
diff --git a/controllers/llamaCPP.h b/controllers/llamaCPP.h
index 75e597658..8f5e7a8dd 100644
--- a/controllers/llamaCPP.h
+++ b/controllers/llamaCPP.h
@@ -88,10 +88,12 @@ class llamaCPP : public drogon::HttpController,
   std::atomic<int> no_of_chats = 0;
   int clean_cache_threshold;
   std::string grammar_file_content;
+  std::mutex load_model_mutex;
 
   /**
    * Queue to handle the inference tasks
   */
+
   trantor::ConcurrentTaskQueue* queue;
 
   bool LoadModelImpl(std::shared_ptr<Json::Value> jsonBody);
@@ -103,5 +105,7 @@ class llamaCPP : public drogon::HttpController,
   void WarmupModel();
   void BackgroundTask();
   void StopBackgroundTask();
+  void ModelLoadedResponse(
+      std::function<void(const HttpResponsePtr&)> function);
 };
 }; // namespace inferences
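
Reviewer note: the LoadModel change above is the double-checked locking idiom: a lock-free acquire-load fast path for the common "already loaded" case, a mutex so only one request runs LoadModelImpl, a re-check under the lock for threads that raced past the first check, and a release-store that publishes the result. Moving `llama.model_loaded_external = true;` out of LoadModelImpl means that store under the lock is the only place the flag becomes true. The standalone sketch below shows the same pattern in isolation; the names (Loader, EnsureLoaded, DoLoad) are illustrative only and are not part of the nitro codebase.

// Minimal sketch of the double-checked locking pattern used in LoadModel.
// Hypothetical example code, not taken from the nitro sources.
#include <atomic>
#include <iostream>
#include <mutex>
#include <thread>

class Loader {
 public:
  // Loads at most once, even when called from many threads concurrently.
  bool EnsureLoaded() {
    // Fast path: after the first successful load, no lock is taken.
    if (loaded_.load(std::memory_order_acquire)) {
      return true;
    }
    std::lock_guard<std::mutex> lock(mutex_);
    // Re-check under the lock: another thread may have finished loading
    // between our first check and acquiring the mutex.
    if (loaded_.load(std::memory_order_relaxed)) {
      return true;
    }
    bool ok = DoLoad();
    // The release store pairs with the acquire load on the fast path, so a
    // thread that sees `true` also sees everything DoLoad() wrote.
    loaded_.store(ok, std::memory_order_release);
    return ok;
  }

 private:
  bool DoLoad() { return true; }  // placeholder for the expensive load step
  std::atomic<bool> loaded_{false};
  std::mutex mutex_;
};

int main() {
  Loader loader;
  std::thread t1([&] { loader.EnsureLoaded(); });
  std::thread t2([&] { loader.EnsureLoaded(); });
  t1.join();
  t2.join();
  std::cout << "loaded exactly once\n";
  return 0;
}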