diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc
index 8e2f2f22c..5e9c1a9b4 100644
--- a/controllers/llamaCPP.cc
+++ b/controllers/llamaCPP.cc
@@ -1,4 +1,7 @@
 #include "llamaCPP.h"
+
+#include <trantor/utils/SerialTaskQueue.h>
+
 #include "llama.h"
 #include "log.h"
 #include "utils/nitro_utils.h"
@@ -6,6 +9,21 @@
 using namespace inferences;
 using json = nlohmann::json;
 
+/**
+ * Queue to handle the inference tasks; this ensures that inference
+ * requests are handled sequentially.
+ */
+static trantor::SerialTaskQueue queue("worker");
+
+/**
+ * The state of the inference task
+ */
+enum InferenceStatus {
+  PENDING,
+  RUNNING,
+  FINISHED
+};
+
 /**
  * There is a need to save state of current ongoing inference status of a
  * handler, this struct is to solve that issue
@@ -15,8 +33,8 @@ using json = nlohmann::json;
  */
 struct inferenceState {
   bool is_stopped = false;
-  bool is_streaming = false;
   int task_id;
+  InferenceStatus inferenceStatus = PENDING;
   llamaCPP *instance;
 
   inferenceState(llamaCPP *inst) : instance(inst) {}
@@ -35,7 +53,7 @@ std::shared_ptr<inferenceState> create_inference_state(llamaCPP *instance) {
  * Check if model already loaded if not return message to user
  * @param callback the function to return message to user
  */
-void llamaCPP::checkModelLoaded(
+bool llamaCPP::checkModelLoaded(
    std::function<void(const HttpResponsePtr &)> &callback) {
   if (!llama.model_loaded_external) {
     Json::Value jsonResp;
@@ -44,8 +62,9 @@
     auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
     resp->setStatusCode(drogon::k409Conflict);
     callback(resp);
-    return;
+    return false;
   }
+  return true;
 }
 
 Json::Value create_embedding_payload(const std::vector<float> &embedding,
@@ -70,7 +89,6 @@
 std::string create_full_return_json(const std::string &id,
                                     const std::string &system_fingerprint,
                                     int prompt_tokens, int completion_tokens,
                                     Json::Value finish_reason = Json::Value()) {
-
   Json::Value root;
   root["id"] = id;
@@ -163,9 +181,11 @@ void llamaCPP::inference(
   const auto &jsonBody = req->getJsonObject();
 
   // Check if model is loaded
-  checkModelLoaded(callback);
-
-  inferenceImpl(jsonBody, callback);
+  if (checkModelLoaded(callback)) {
+    // Model is loaded
+    // Do inference
+    inferenceImpl(jsonBody, callback);
+  }
 }
 
 void llamaCPP::inferenceImpl(
@@ -318,28 +338,24 @@
     auto state = create_inference_state(this);
     auto chunked_content_provider =
        [state, data](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
-      if (!state->is_streaming) {
-        state->task_id =
-            state->instance->llama.request_completion(data, false, false, -1);
-        state->instance->single_queue_is_busy = true;
+
+      if (state->inferenceStatus == PENDING) {
+        state->inferenceStatus = RUNNING;
       }
+
       if (!pBuffer) {
         LOG_INFO << "Connection closed or buffer is null. Reset context";
         state->instance->llama.request_cancel(state->task_id);
-        state->is_streaming = false;
         state->instance->single_queue_is_busy = false;
         return 0;
       }
       if (state->is_stopped) {
-        state->is_streaming = false;
         state->instance->single_queue_is_busy = false;
         return 0;
       }
 
       task_result result = state->instance->llama.next_result(state->task_id);
       if (!result.error) {
-        // Update streaming state to being streamed
-        state->is_streaming = true;
         const std::string to_send = result.result_json["content"];
         const std::string str =
            "data: " +
@@ -363,35 +379,48 @@
           LOG_INFO << "reached result stop";
           state->is_stopped = true;
           state->instance->llama.request_cancel(state->task_id);
-          state->is_streaming = false;
           state->instance->single_queue_is_busy = false;
-
-          return nRead;
         }
-        return nRead;
-      } else {
-        if (state->instance->llama.params.n_parallel == 1) {
-          while (state->instance->single_queue_is_busy) {
-            LOG_INFO << "Waiting for task to be released status:"
-                     << state->instance->single_queue_is_busy;
-            std::this_thread::sleep_for(std::chrono::milliseconds(
-                500)); // Waiting in 500 miliseconds step
-          }
+
+        // If nRead is zero the stream will stop,
+        // so release the single queue before returning
+        if (!nRead) {
+          state->instance->single_queue_is_busy = false;
         }
-        std::string str = "\n\n";
-        std::size_t nRead = str.size();
-        memcpy(pBuffer, str.data(), nRead);
-        LOG_INFO << "Failing retrying now";
+
         return nRead;
       }
-      state->is_streaming = false;
       state->instance->single_queue_is_busy = false;
       return 0;
     };
-    auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
-                                                 "chat_completions.txt");
-    callback(resp);
+
+    // Run the inference task in the serial queue
+    queue.runTaskInQueue([callback, state, data,
+                          chunked_content_provider]() {
+      state->task_id =
+          state->instance->llama.request_completion(data, false, false, -1);
+
+      state->instance->single_queue_is_busy = true;
+
+      // Start streaming the response
+      auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
+                                                   "chat_completions.txt");
+      callback(resp);
+      int retries = 0;
+
+      // Since this is an async task, wait for it to be completed
+      while (state->instance->single_queue_is_busy && retries < 10) {
+        // The chunked_content_provider lambda should be called within ~3s
+        if (state->inferenceStatus == PENDING) {
+          retries += 1;
+        }
+        LOG_INFO << "Wait for task to be released:" << state->task_id;
+        std::this_thread::sleep_for(std::chrono::milliseconds(300));
+      }
+
+      state->inferenceStatus = FINISHED;
+    });
 
     return;
   } else {
     Json::Value respData;
@@ -423,11 +452,14 @@
 void llamaCPP::embedding(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
-  checkModelLoaded(callback);
-  const auto &jsonBody = req->getJsonObject();
-
-  embeddingImpl(jsonBody, callback);
-  return;
+  // Check if model is loaded
+  if (checkModelLoaded(callback)) {
+    // Model is loaded
+    const auto &jsonBody = req->getJsonObject();
+    // Run embedding
+    embeddingImpl(jsonBody, callback);
+    return;
+  }
 }
 
 void llamaCPP::embeddingImpl(
diff --git a/controllers/llamaCPP.h b/controllers/llamaCPP.h
index 82704a613..5f2be54b3 100644
--- a/controllers/llamaCPP.h
+++ b/controllers/llamaCPP.h
@@ -2571,7 +2571,7 @@ class llamaCPP : public drogon::HttpController<llamaCPP>, public ChatProvider {
                     std::function<void(const HttpResponsePtr &)> &callback);
   void embeddingImpl(std::shared_ptr<Json::Value> jsonBody,
                     std::function<void(const HttpResponsePtr &)> &callback);
-  void checkModelLoaded(std::function<void(const HttpResponsePtr &)> &callback);
+  bool checkModelLoaded(std::function<void(const HttpResponsePtr &)> &callback);
  void warmupModel();
  void backgroundTask();
  void stopBackgroundTask();
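
Note on the serialization mechanism (not part of the patch): the fix funnels every streaming completion through one static trantor::SerialTaskQueue, so requests run one at a time even though the drogon HTTP handlers themselves are concurrent. The sketch below illustrates that pattern in isolation. It assumes trantor is available (it ships with drogon, which the controller already uses) and that the SerialTaskQueue header lives at trantor/utils/SerialTaskQueue.h, the same header assumed in the reconstructed include above; the queue name, the simulated work, and the timings are illustrative only, not taken from the patch.

#include <trantor/utils/SerialTaskQueue.h>

#include <atomic>
#include <chrono>
#include <iostream>
#include <thread>

int main() {
  // Tasks pushed into a SerialTaskQueue run one at a time, in FIFO order,
  // on the queue's own worker thread -- the property the patch relies on
  // to keep completion requests from overlapping.
  trantor::SerialTaskQueue queue("worker");
  std::atomic<int> finished{0};

  for (int i = 0; i < 3; ++i) {
    queue.runTaskInQueue([i, &finished]() {
      std::cout << "start task " << i << std::endl;
      // Stand-in for the streaming inference work
      std::this_thread::sleep_for(std::chrono::milliseconds(100));
      std::cout << "end task " << i << std::endl;
      ++finished;
    });
  }

  // Poll until every task has run, mirroring the retry loop in the patch.
  while (finished < 3) {
    std::this_thread::sleep_for(std::chrono::milliseconds(50));
  }
  return 0;
}

Even if several threads call runTaskInQueue at once, the printed start/end pairs never interleave, which is exactly the behavior the patch needs when n_parallel is 1.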