From 701271d89ad44cd70697b275f686b2661ee895d0 Mon Sep 17 00:00:00 2001
From: tikikun
Date: Thu, 18 Jan 2024 13:53:40 +0700
Subject: [PATCH 1/3] remove redundant temporary impl

---
 controllers/llamaCPP.cc | 28 ++++++----------------------
 1 file changed, 6 insertions(+), 22 deletions(-)

diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc
index 3f7eb9dd0..c0f564e12 100644
--- a/controllers/llamaCPP.cc
+++ b/controllers/llamaCPP.cc
@@ -6,16 +6,17 @@
 using namespace inferences;
 using json = nlohmann::json;
 
-struct State {
+struct inferenceState {
   bool isStopped = false;
   int task_id;
   llamaCPP *instance;
 
-  State(int tid, llamaCPP *inst) : task_id(tid), instance(inst) {}
+  inferenceState(int tid, llamaCPP *inst) : task_id(tid), instance(inst) {}
 };
 
-std::shared_ptr<State> createState(int task_id, llamaCPP *instance) {
-  return std::make_shared<State>(task_id, instance);
+std::shared_ptr<inferenceState> create_inference_state(int task_id,
+                                                        llamaCPP *instance) {
+  return std::make_shared<inferenceState>(task_id, instance);
 }
 
 // --------------------------------------------
@@ -295,36 +296,21 @@ void llamaCPP::chatCompletion(
 #endif
 
   int task_id;
-  if (llama.params.n_parallel == 1) {
-    while (true) {
-      if (!single_queue_is_busy) {
-        task_id = llama.request_completion(data, false, false, -1);
-        single_queue_is_busy = true;
-        break;
-      } else {
-        std::this_thread::sleep_for(
-            std::chrono::milliseconds(500)); // Sleep for 500 milliseconds
-      }
-    }
-  } else {
     task_id = llama.request_completion(data, false, false, -1);
-  }
 
   LOG_INFO << "Resolved request for task_id:" << task_id;
 
   if (is_streamed) {
-    auto state = createState(task_id, this);
+    auto state = create_inference_state(task_id, this);
 
     auto chunked_content_provider =
         [this, state](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
       if (!pBuffer) {
         LOG_INFO << "Connection closed or buffer is null. Reset context";
         state->instance->llama.request_cancel(state->task_id);
-        single_queue_is_busy = false;
         return 0;
       }
       if (state->isStopped) {
-        single_queue_is_busy = false;
         return 0;
       }
 
@@ -357,10 +343,8 @@ void llamaCPP::chatCompletion(
         }
         return nRead;
       } else {
-        single_queue_is_busy = false;
         return 0;
       }
-      single_queue_is_busy = false;
       return 0;
     };
     auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,

From 46ee74dd3a9d92e008aa2f582ed1bbd1e16c2187 Mon Sep 17 00:00:00 2001
From: tikikun
Date: Thu, 18 Jan 2024 17:12:38 +0700
Subject: [PATCH 2/3] feat: upgrade waiting logic in the case of single queue

---
 controllers/llamaCPP.cc | 54 +++++++++++++++++++++++++++++++++++++++++++++-------------
 1 file changed, 41 insertions(+), 13 deletions(-)

diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc
index c0f564e12..ee852669f 100644
--- a/controllers/llamaCPP.cc
+++ b/controllers/llamaCPP.cc
@@ -2,21 +2,24 @@
 #include "llama.h"
 #include "log.h"
 #include "utils/nitro_utils.h"
+#include <chrono>
+#include <thread>
+#include
 
 using namespace inferences;
 using json = nlohmann::json;
 
 struct inferenceState {
-  bool isStopped = false;
+  bool is_stopped = false;
+  bool is_streaming = false;
   int task_id;
   llamaCPP *instance;
 
-  inferenceState(int tid, llamaCPP *inst) : task_id(tid), instance(inst) {}
+  inferenceState(llamaCPP *inst) : instance(inst) {}
 };
 
-std::shared_ptr<inferenceState> create_inference_state(int task_id,
-                                                        llamaCPP *instance) {
-  return std::make_shared<inferenceState>(task_id, instance);
+std::shared_ptr<inferenceState> create_inference_state(llamaCPP *instance) {
+  return std::make_shared<inferenceState>(instance);
 }
 
 // --------------------------------------------
@@ -296,26 +299,35 @@ void llamaCPP::chatCompletion(
 #endif
 
   int task_id;
-    task_id = llama.request_completion(data, false, false, -1);
-
   LOG_INFO << "Resolved request for task_id:" << task_id;
 
   if (is_streamed) {
-    auto state = create_inference_state(task_id, this);
-
+    auto state = create_inference_state(this);
+    state->task_id = task_id;
     auto chunked_content_provider =
-        [this, state](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
+        [state, data](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
+      if (!state->is_streaming) {
+        state->task_id =
+            state->instance->llama.request_completion(data, false, false, -1);
+        state->instance->single_queue_is_busy = true;
+      }
       if (!pBuffer) {
         LOG_INFO << "Connection closed or buffer is null. Reset context";
         state->instance->llama.request_cancel(state->task_id);
+        state->is_streaming = false;
+        state->instance->single_queue_is_busy = false;
         return 0;
       }
-      if (state->isStopped) {
+      if (state->is_stopped) {
+        state->is_streaming = false;
+        state->instance->single_queue_is_busy = false;
        return 0;
       }
 
       task_result result = state->instance->llama.next_result(state->task_id);
       if (!result.error) {
+        // Update streaming state to being streamed
+        state->is_streaming = true;
         const std::string to_send = result.result_json["content"];
         const std::string str =
             "data: " +
@@ -337,14 +349,30 @@ void llamaCPP::chatCompletion(
           std::size_t nRead = std::min(str.size(), nBuffSize);
           memcpy(pBuffer, str.data(), nRead);
           LOG_INFO << "reached result stop";
-          state->isStopped = true;
+          state->is_stopped = true;
           state->instance->llama.request_cancel(state->task_id);
+          state->is_streaming = false;
+          state->instance->single_queue_is_busy = false;
+          return nRead;
 
         }
         return nRead;
       } else {
-        return 0;
+        if (state->instance->llama.params.n_parallel == 1) {
+          while (state->instance->single_queue_is_busy) {
+            LOG_INFO << "Waiting for task to be released status:"
+                     << state->instance->single_queue_is_busy;
+            std::this_thread::sleep_for(std::chrono::milliseconds(500)); // Wait in 500 millisecond steps
+          }
+        }
+        std::string str = "\n\n";
+        std::size_t nRead = str.size();
+        memcpy(pBuffer, str.data(), nRead);
+        LOG_INFO << "Failing retrying now";
+        return nRead;
       }
+      state->is_streaming = false;
+      state->instance->single_queue_is_busy = false;
       return 0;
     };
     auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,

From ebdf4206c775597659544f8672759164ca8cbb55 Mon Sep 17 00:00:00 2001
From: tikikun
Date: Thu, 18 Jan 2024 17:37:48 +0700
Subject: [PATCH 3/3] remove redundant temporary impl

---
 controllers/llamaCPP.cc | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc
index ee852669f..da1cc6554 100644
--- a/controllers/llamaCPP.cc
+++ b/controllers/llamaCPP.cc
@@ -2,9 +2,6 @@
 #include "llama.h"
 #include "log.h"
 #include "utils/nitro_utils.h"
-#include <chrono>
-#include <thread>
-#include
 
 using namespace inferences;
 using json = nlohmann::json;
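
Note for reviewers: the net effect of PATCH 2 is that, for streamed requests, llama.request_completion() is no longer issued up front in chatCompletion(). The chunked_content_provider lambda now issues it lazily on its first invocation, marks single_queue_is_busy, and, when llama.params.n_parallel == 1 and next_result() reports an error, polls that flag in 500 ms steps before writing a "\n\n" chunk so the provider is invoked again and can retry. The sketch below illustrates only that deferred-request / busy-flag polling pattern; it is not the nitro code itself, and request_completion() / finish_previous_stream() here are invented stand-ins for llama.request_completion() and an earlier stream releasing the slot.

// Simplified, self-contained sketch of the single-queue waiting pattern from PATCH 2.
// NOT the nitro/llama.cpp code: an atomic flag stands in for llamaCPP::single_queue_is_busy
// and the two helpers below are hypothetical stand-ins.
#include <atomic>
#include <chrono>
#include <iostream>
#include <thread>

std::atomic<bool> single_queue_is_busy{false};

// Stand-in for llama.request_completion(): claims the single slot.
int request_completion() {
  single_queue_is_busy = true;
  return 42; // pretend task id
}

// Stand-in for an earlier stream that releases the slot when it finishes,
// mirroring what the chunked_content_provider does on stop/disconnect.
void finish_previous_stream() {
  std::this_thread::sleep_for(std::chrono::seconds(2));
  single_queue_is_busy = false;
}

int main() {
  single_queue_is_busy = true; // a stream is already in flight
  std::thread previous(finish_previous_stream);

  // n_parallel == 1 case: poll the busy flag in 500 ms steps before retrying,
  // like the while-loop added in the provider's error branch.
  while (single_queue_is_busy) {
    std::cout << "Waiting for task to be released status:"
              << single_queue_is_busy << "\n";
    std::this_thread::sleep_for(std::chrono::milliseconds(500));
  }

  // Only now is the completion actually requested, as in the lambda's
  // !state->is_streaming branch.
  int task_id = request_completion();
  std::cout << "Resolved request for task_id:" << task_id << "\n";

  previous.join();
  single_queue_is_busy = false; // release the slot when done
  return 0;
}

In the real provider the disconnect (!pBuffer), is_stopped, and stop-result paths all clear single_queue_is_busy, which is what lets the polling loop above eventually observe a free slot.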