diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc
index 3f7eb9dd0..da1cc6554 100644
--- a/controllers/llamaCPP.cc
+++ b/controllers/llamaCPP.cc
@@ -6,16 +6,17 @@
 using namespace inferences;
 using json = nlohmann::json;
 
-struct State {
-  bool isStopped = false;
+struct inferenceState {
+  bool is_stopped = false;
+  bool is_streaming = false;
   int task_id;
   llamaCPP *instance;
 
-  State(int tid, llamaCPP *inst) : task_id(tid), instance(inst) {}
+  inferenceState(llamaCPP *inst) : instance(inst) {}
 };
 
-std::shared_ptr<State> createState(int task_id, llamaCPP *instance) {
-  return std::make_shared<State>(task_id, instance);
+std::shared_ptr<inferenceState> create_inference_state(llamaCPP *instance) {
+  return std::make_shared<inferenceState>(instance);
 }
 
 // --------------------------------------------
@@ -295,41 +296,35 @@ void llamaCPP::chatCompletion(
 #endif
   int task_id;
 
-  if (llama.params.n_parallel == 1) {
-    while (true) {
-      if (!single_queue_is_busy) {
-        task_id = llama.request_completion(data, false, false, -1);
-        single_queue_is_busy = true;
-        break;
-      } else {
-        std::this_thread::sleep_for(
-            std::chrono::milliseconds(500)); // Sleep for 500 milliseconds
-      }
-    }
-  } else {
-    task_id = llama.request_completion(data, false, false, -1);
-  }
-
   LOG_INFO << "Resolved request for task_id:" << task_id;
 
   if (is_streamed) {
-    auto state = createState(task_id, this);
-
+    auto state = create_inference_state(this);
+    state->task_id = task_id;
     auto chunked_content_provider =
-        [this, state](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
+        [state, data](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
+      if (!state->is_streaming) {
+        state->task_id =
+            state->instance->llama.request_completion(data, false, false, -1);
+        state->instance->single_queue_is_busy = true;
+      }
       if (!pBuffer) {
         LOG_INFO << "Connection closed or buffer is null. "
                     "Reset context";
         state->instance->llama.request_cancel(state->task_id);
-        single_queue_is_busy = false;
+        state->is_streaming = false;
+        state->instance->single_queue_is_busy = false;
         return 0;
       }
-      if (state->isStopped) {
-        single_queue_is_busy = false;
+      if (state->is_stopped) {
+        state->is_streaming = false;
+        state->instance->single_queue_is_busy = false;
         return 0;
       }
       task_result result = state->instance->llama.next_result(state->task_id);
       if (!result.error) {
+        // Mark the request as actively streaming
+        state->is_streaming = true;
        const std::string to_send = result.result_json["content"];
        const std::string str =
            "data: " +
@@ -351,16 +346,30 @@ void llamaCPP::chatCompletion(
           std::size_t nRead = std::min(str.size(), nBuffSize);
           memcpy(pBuffer, str.data(), nRead);
           LOG_INFO << "reached result stop";
-          state->isStopped = true;
+          state->is_stopped = true;
           state->instance->llama.request_cancel(state->task_id);
+          state->is_streaming = false;
+          state->instance->single_queue_is_busy = false;
+
           return nRead;
         }
         return nRead;
       } else {
-        single_queue_is_busy = false;
-        return 0;
+        if (state->instance->llama.params.n_parallel == 1) {
+          while (state->instance->single_queue_is_busy) {
+            LOG_INFO << "Waiting for task to be released, status: "
+                     << state->instance->single_queue_is_busy;
+            std::this_thread::sleep_for(std::chrono::milliseconds(500)); // Waiting in 500 millisecond steps
+          }
+        }
+        std::string str = "\n\n";
+        std::size_t nRead = str.size();
+        memcpy(pBuffer, str.data(), nRead);
+        LOG_INFO << "Failing, retrying now";
+        return nRead;
       }
-      single_queue_is_busy = false;
+      state->is_streaming = false;
+      state->instance->single_queue_is_busy = false;
       return 0;
     };
     auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,