From 712987401d77c0b03259365b63dcfec3fec0b485 Mon Sep 17 00:00:00 2001
From: vansangpfiev
Date: Sat, 13 Apr 2024 15:55:49 +0700
Subject: [PATCH 1/3] fix: race condition in inference between stream and non-stream

---
 controllers/llamaCPP.cc | 44 +++++++++++++++++++++--------------------
 1 file changed, 23 insertions(+), 21 deletions(-)

diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc
index 7e616f9b2..4361c5c3a 100644
--- a/controllers/llamaCPP.cc
+++ b/controllers/llamaCPP.cc
@@ -434,28 +434,30 @@ void llamaCPP::InferenceImpl(
         LOG_INFO_REQUEST(request_id) << "Inference completed";
       });
   } else {
-    Json::Value respData;
-    int task_id = llama.request_completion(data, false, false, -1);
-    LOG_INFO_REQUEST(request_id) << "Non stream, waiting for respone";
-    if (!json_value(data, "stream", false)) {
-      std::string completion_text;
-      task_result result = llama.next_result(task_id);
-      if (!result.error && result.stop) {
-        int prompt_tokens = result.result_json["tokens_evaluated"];
-        int predicted_tokens = result.result_json["tokens_predicted"];
-        std::string to_send = result.result_json["content"];
-        nitro_utils::ltrim(to_send);
-        respData = create_full_return_json(
-            nitro_utils::generate_random_string(20), "_", to_send, "_",
-            prompt_tokens, predicted_tokens);
-      } else {
-        respData["message"] = "Internal error during inference";
-        LOG_ERROR_REQUEST(request_id) << "Error during inference";
+    queue->runTaskInQueue([this, request_id, data, callback]() {
+      Json::Value respData;
+      int task_id = llama.request_completion(data, false, false, -1);
+      LOG_INFO_REQUEST(request_id) << "Non stream, waiting for respone";
+      if (!json_value(data, "stream", false)) {
+        std::string completion_text;
+        task_result result = llama.next_result(task_id);
+        if (!result.error && result.stop) {
+          int prompt_tokens = result.result_json["tokens_evaluated"];
+          int predicted_tokens = result.result_json["tokens_predicted"];
+          std::string to_send = result.result_json["content"];
+          nitro_utils::ltrim(to_send);
+          respData = create_full_return_json(
+              nitro_utils::generate_random_string(20), "_", to_send, "_",
+              prompt_tokens, predicted_tokens);
+        } else {
+          respData["message"] = "Internal error during inference";
+          LOG_ERROR_REQUEST(request_id) << "Error during inference";
+        }
+        auto resp = nitro_utils::nitroHttpJsonResponse(respData);
+        callback(resp);
+        LOG_INFO_REQUEST(request_id) << "Inference completed";
       }
-      auto resp = nitro_utils::nitroHttpJsonResponse(respData);
-      callback(resp);
-      LOG_INFO_REQUEST(request_id) << "Inference completed";
-    }
+    });
   }
 }
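Note on PATCH 1/3: before this change the non-stream branch called llama.request_completion() and llama.next_result() directly on the HTTP handler thread, while the streaming branch already ran its work inside queue->runTaskInQueue(); the fix routes the non-stream branch through that same task queue so both kinds of request reach the llama object from queued tasks rather than from an arbitrary request thread. The sketch below shows the general pattern with a deliberately simplified single-worker queue; TaskQueue and FakeEngine are illustrative stand-ins, not nitro types.

// Minimal standalone sketch: all work, streaming or not, goes through one
// task queue with a single worker, so the shared engine is never entered
// concurrently.
#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <string>
#include <thread>

class TaskQueue {
 public:
  TaskQueue() : worker_([this] { Run(); }) {}
  ~TaskQueue() {
    {
      std::lock_guard<std::mutex> lk(mtx_);
      done_ = true;
    }
    cv_.notify_one();
    worker_.join();
  }
  void RunTaskInQueue(std::function<void()> task) {
    {
      std::lock_guard<std::mutex> lk(mtx_);
      tasks_.push(std::move(task));
    }
    cv_.notify_one();
  }

 private:
  void Run() {
    for (;;) {
      std::function<void()> task;
      {
        std::unique_lock<std::mutex> lk(mtx_);
        cv_.wait(lk, [this] { return done_ || !tasks_.empty(); });
        if (tasks_.empty()) return;  // done_ set and queue drained
        task = std::move(tasks_.front());
        tasks_.pop();
      }
      task();  // tasks never overlap: one worker drains the queue
    }
  }
  std::mutex mtx_;
  std::condition_variable cv_;
  std::queue<std::function<void()>> tasks_;
  bool done_ = false;
  std::thread worker_;
};

// Stand-in for the llama server object, which must not be driven from two
// threads at the same time.
struct FakeEngine {
  int RequestCompletion() { return ++next_task_id_; }
  std::string NextResult(int id) { return "result for task " + std::to_string(id); }
  int next_task_id_ = 0;
};

int main() {
  FakeEngine engine;
  TaskQueue queue;

  // Streaming-style request: already queued, like the if-branch in InferenceImpl.
  queue.RunTaskInQueue([&engine] {
    int id = engine.RequestCompletion();
    std::cout << "stream: " << engine.NextResult(id) << "\n";
  });

  // Non-stream request: before the patch this ran directly on the HTTP thread;
  // after the patch it goes through the same queue and cannot race with the
  // streaming task above.
  queue.RunTaskInQueue([&engine] {
    int id = engine.RequestCompletion();
    std::cout << "non-stream: " << engine.NextResult(id) << "\n";
  });
}

Because a single worker drains the queue in this sketch, the two tasks may run in either order but never at the same time.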
From 44ab290586a1145317ea7086a0490c4b35196065 Mon Sep 17 00:00:00 2001
From: vansangpfiev
Date: Sat, 13 Apr 2024 16:23:00 +0700
Subject: [PATCH 2/3] refactor: small improvement

---
 controllers/llamaCPP.cc | 49 +++++++++++++++++++++--------------------
 1 file changed, 25 insertions(+), 24 deletions(-)

diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc
index 4361c5c3a..a5fdd943b 100644
--- a/controllers/llamaCPP.cc
+++ b/controllers/llamaCPP.cc
@@ -434,30 +434,31 @@ void llamaCPP::InferenceImpl(
         LOG_INFO_REQUEST(request_id) << "Inference completed";
       });
   } else {
-    queue->runTaskInQueue([this, request_id, data, callback]() {
-      Json::Value respData;
-      int task_id = llama.request_completion(data, false, false, -1);
-      LOG_INFO_REQUEST(request_id) << "Non stream, waiting for respone";
-      if (!json_value(data, "stream", false)) {
-        std::string completion_text;
-        task_result result = llama.next_result(task_id);
-        if (!result.error && result.stop) {
-          int prompt_tokens = result.result_json["tokens_evaluated"];
-          int predicted_tokens = result.result_json["tokens_predicted"];
-          std::string to_send = result.result_json["content"];
-          nitro_utils::ltrim(to_send);
-          respData = create_full_return_json(
-              nitro_utils::generate_random_string(20), "_", to_send, "_",
-              prompt_tokens, predicted_tokens);
-        } else {
-          respData["message"] = "Internal error during inference";
-          LOG_ERROR_REQUEST(request_id) << "Error during inference";
-        }
-        auto resp = nitro_utils::nitroHttpJsonResponse(respData);
-        callback(resp);
-        LOG_INFO_REQUEST(request_id) << "Inference completed";
-      }
-    });
+    queue->runTaskInQueue(
+        [this, request_id, cb = std::move(callback), d = std::move(data)]() {
+          Json::Value respData;
+          int task_id = llama.request_completion(d, false, false, -1);
+          LOG_INFO_REQUEST(request_id) << "Non stream, waiting for respone";
+          if (!json_value(d, "stream", false)) {
+            std::string completion_text;
+            task_result result = llama.next_result(task_id);
+            if (!result.error && result.stop) {
+              int prompt_tokens = result.result_json["tokens_evaluated"];
+              int predicted_tokens = result.result_json["tokens_predicted"];
+              std::string to_send = result.result_json["content"];
+              nitro_utils::ltrim(to_send);
+              respData = create_full_return_json(
+                  nitro_utils::generate_random_string(20), "_", to_send, "_",
+                  prompt_tokens, predicted_tokens);
+            } else {
+              respData["message"] = "Internal error during inference";
+              LOG_ERROR_REQUEST(request_id) << "Error during inference";
+            }
+            auto resp = nitro_utils::nitroHttpJsonResponse(respData);
+            cb(resp);
+            LOG_INFO_REQUEST(request_id) << "Inference completed";
+          }
+        });
   }
 }
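Note on PATCH 2/3: besides re-indenting the block, the change switches the queued lambda to C++14 init-captures, cb = std::move(callback) and d = std::move(data), so the callback object and the request JSON are moved into the closure instead of being copied once per request. A small self-contained example of that capture style; Payload and enqueue() are made up for illustration, not nitro code:

#include <functional>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct Payload {
  std::string body;
  std::vector<char> big_buffer;  // something expensive to copy
};

// Stand-in for queue->runTaskInQueue: here the task simply runs immediately.
void enqueue(std::function<void()> task) {
  task();
}

void handle_request(Payload data, std::function<void(const std::string&)> callback) {
  // A plain [data, callback] capture would copy both into the closure.
  // Init-capture moves them, mirroring
  // [cb = std::move(callback), d = std::move(data)] in the patch.
  enqueue([cb = std::move(callback), d = std::move(data)]() {
    cb("handled " + d.body + " (" + std::to_string(d.big_buffer.size()) + " bytes)");
  });
}

int main() {
  Payload p{"chat completion", std::vector<char>(1 << 20)};
  handle_request(std::move(p),
                 [](const std::string& msg) { std::cout << msg << "\n"; });
}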
From ede0e06c83eb2e9f8f464a33fc6cca08f70fc1f4 Mon Sep 17 00:00:00 2001
From: vansangpfiev
Date: Sat, 13 Apr 2024 16:37:37 +0700
Subject: [PATCH 3/3] refactor: small things continue

---
 controllers/llamaCPP.cc | 14 +++++++-------
 controllers/llamaCPP.h  |  6 +++---
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc
index a5fdd943b..008520342 100644
--- a/controllers/llamaCPP.cc
+++ b/controllers/llamaCPP.cc
@@ -49,7 +49,7 @@ std::shared_ptr<inferenceState> create_inference_state(llamaCPP* instance) {
  * @param callback the function to return message to user
  */
 bool llamaCPP::CheckModelLoaded(
-    std::function<void(const HttpResponsePtr&)>& callback) {
+    const std::function<void(const HttpResponsePtr&)>& callback) {
   if (!llama.model_loaded_external) {
     LOG_ERROR << "Model has not been loaded";
     Json::Value jsonResp;
@@ -180,13 +180,13 @@ void llamaCPP::ChatCompletion(
   if (CheckModelLoaded(callback)) {
     // Model is loaded
     // Do Inference
-    InferenceImpl(std::move(completion), callback);
+    InferenceImpl(std::move(completion), std::move(callback));
   }
 }
 
 void llamaCPP::InferenceImpl(
     inferences::ChatCompletionRequest&& completion,
-    std::function<void(const HttpResponsePtr&)>& callback) {
+    std::function<void(const HttpResponsePtr&)>&& callback) {
   std::string formatted_output = pre_prompt;
   int request_id = ++no_of_requests;
   LOG_INFO_REQUEST(request_id) << "Generating reponse for inference request";
@@ -404,14 +404,14 @@ void llamaCPP::InferenceImpl(
   };
   // Queued task
   state->instance->queue->runTaskInQueue(
-      [callback, state, data, chunked_content_provider, request_id]() {
+      [cb = std::move(callback), state, data, chunked_content_provider, request_id]() {
         state->task_id =
             state->instance->llama.request_completion(data, false, false, -1);
 
         // Start streaming response
         auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
                                                      "chat_completions.txt");
-        callback(resp);
+        cb(resp);
 
         int retries = 0;
@@ -470,14 +470,14 @@ void llamaCPP::Embedding(
     // Model is loaded
     const auto& jsonBody = req->getJsonObject();
     // Run embedding
-    EmbeddingImpl(jsonBody, callback);
+    EmbeddingImpl(jsonBody, std::move(callback));
     return;
   }
 }
 
 void llamaCPP::EmbeddingImpl(
     std::shared_ptr<Json::Value> jsonBody,
-    std::function<void(const HttpResponsePtr&)>& callback) {
+    std::function<void(const HttpResponsePtr&)>&& callback) {
   int request_id = ++no_of_requests;
   LOG_INFO_REQUEST(request_id) << "Generating reponse for embedding request";
   // Queue embedding task
diff --git a/controllers/llamaCPP.h b/controllers/llamaCPP.h
index a00d25538..900786c79 100644
--- a/controllers/llamaCPP.h
+++ b/controllers/llamaCPP.h
@@ -97,10 +97,10 @@ class llamaCPP : public drogon::HttpController<llamaCPP>,
   bool LoadModelImpl(std::shared_ptr<Json::Value> jsonBody);
 
   void InferenceImpl(inferences::ChatCompletionRequest&& completion,
-                     std::function<void(const HttpResponsePtr&)>& callback);
+                     std::function<void(const HttpResponsePtr&)>&& callback);
   void EmbeddingImpl(std::shared_ptr<Json::Value> jsonBody,
-                     std::function<void(const HttpResponsePtr&)>& callback);
-  bool CheckModelLoaded(std::function<void(const HttpResponsePtr&)>& callback);
+                     std::function<void(const HttpResponsePtr&)>&& callback);
+  bool CheckModelLoaded(const std::function<void(const HttpResponsePtr&)>& callback);
   void WarmupModel();
   void BackgroundTask();
   void StopBackgroundTask();
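Note on PATCH 3/3: the callback is now handed down the call chain by rvalue reference (InferenceImpl, EmbeddingImpl) with an explicit std::move at each call site, while CheckModelLoaded, which only invokes the callback and never stores it, takes it by const reference. A rough sketch of that convention follows; the names check_model_loaded, inference_impl, and chat_completion mirror but are not the nitro functions:

#include <functional>
#include <iostream>
#include <string>
#include <utility>

using Callback = std::function<void(const std::string&)>;

// Only calls the callback on failure and never keeps it: const& is enough,
// mirroring the new CheckModelLoaded signature.
bool check_model_loaded(bool loaded, const Callback& callback) {
  if (!loaded) {
    callback("error: model has not been loaded");
    return false;
  }
  return true;
}

// Takes ownership of the callback because it outlives this call inside a
// queued task, so it accepts an rvalue reference, mirroring InferenceImpl.
void inference_impl(std::string prompt, Callback&& callback) {
  // In nitro this closure would be pushed onto the task queue.
  auto task = [cb = std::move(callback), p = std::move(prompt)] {
    cb("completion for: " + p);
  };
  task();
}

// Entry point, mirroring ChatCompletion: checks first, then forwards the
// callback with std::move once it is no longer needed locally.
void chat_completion(std::string prompt, Callback callback) {
  if (check_model_loaded(true, callback)) {
    inference_impl(std::move(prompt), std::move(callback));
  }
}

int main() {
  chat_completion("hello", [](const std::string& out) { std::cout << out << "\n"; });
}

The rule of thumb illustrated here: take const Function& when the callee only invokes the callable, and take Function&& (moving it onward) when the callee keeps it alive beyond the call.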