From a8f17353b0ff5cdc98ac6f001523a4e5cefbb735 Mon Sep 17 00:00:00 2001 From: tikikun Date: Fri, 2 Feb 2024 09:02:03 +0700 Subject: [PATCH] feat: add queue for embedding --- controllers/llamaCPP.cc | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc index 474e675c6..416d22cd7 100644 --- a/controllers/llamaCPP.cc +++ b/controllers/llamaCPP.cc @@ -407,17 +407,31 @@ void llamaCPP::embedding( std::function &&callback) { check_model_loaded(llama, req, callback); + auto state = create_inference_state(this); + const auto &jsonBody = req->getJsonObject(); Json::Value responseData(Json::arrayValue); if (jsonBody->isMember("input")) { + // If the single queue is busy, we wait; if not, we just go ahead, + // process, and mark it busy. I'm aware this is not DRY; I have the same + // logic in chat completion as well + if (state->instance->llama.params.n_parallel == 1) { + while (state->instance->single_queue_is_busy) { + LOG_INFO << "Waiting for task to be released status:" + << state->instance->single_queue_is_busy; + std::this_thread::sleep_for( + std::chrono::milliseconds(500)); // Wait in 500-millisecond steps + } + } const Json::Value &input = (*jsonBody)["input"]; if (input.isString()) { // Process the single string input - const int task_id = llama.request_completion( + state->task_id = llama.request_completion( {{"prompt", input.asString()}, {"n_predict", 0}}, false, true, -1); - task_result result = llama.next_result(task_id); + state->instance->single_queue_is_busy = true; + task_result result = llama.next_result(state->task_id); std::vector embedding_result = result.result_json["embedding"]; responseData.append(create_embedding_payload(embedding_result, 0)); } else if (input.isArray()) { @@ -434,6 +448,9 @@ void llamaCPP::embedding( } } + // We already got the result of the embedding, so the queue is no longer busy + state->instance->single_queue_is_busy = false; + auto resp = 
nitro_utils::nitroHttpResponse(); Json::Value root; root["data"] = responseData;