diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc
index 826675868..6a34ac41c 100644
--- a/controllers/llamaCPP.cc
+++ b/controllers/llamaCPP.cc
@@ -10,7 +10,7 @@ using json = nlohmann::json;
 /**
  * The state of the inference task
  */
-enum InferenceStatus { PENDING, RUNNING, FINISHED };
+enum InferenceStatus { PENDING, RUNNING, EOS, FINISHED };
 
 /**
  * There is a need to save state of current ongoing inference status of a
@@ -21,7 +21,7 @@ enum InferenceStatus { PENDING, RUNNING, FINISHED };
  */
 struct inferenceState {
   int task_id;
-  InferenceStatus inferenceStatus = PENDING;
+  InferenceStatus inference_status = PENDING;
   llamaCPP *instance;
 
   inferenceState(llamaCPP *inst) : instance(inst) {}
@@ -111,7 +111,6 @@ std::string create_full_return_json(const std::string &id,
 std::string create_return_json(const std::string &id, const std::string &model,
                                const std::string &content,
                                Json::Value finish_reason = Json::Value()) {
-
   Json::Value root;
 
   root["id"] = id;
@@ -167,7 +166,6 @@ void llamaCPP::warmupModel() {
 void llamaCPP::inference(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
-
   const auto &jsonBody = req->getJsonObject();
   // Check if model is loaded
   if (checkModelLoaded(callback)) {
@@ -180,7 +178,6 @@ void llamaCPP::inference(
 void llamaCPP::inferenceImpl(
     std::shared_ptr<Json::Value> jsonBody,
     std::function<void(const HttpResponsePtr &)> &callback) {
-
   std::string formatted_output = pre_prompt;
 
   json data;
@@ -218,7 +215,6 @@ void llamaCPP::inferenceImpl(
   };
 
   if (!llama.multimodal) {
-
     for (const auto &message : messages) {
       std::string input_role = message["role"].asString();
       std::string role;
@@ -243,7 +239,6 @@ void llamaCPP::inferenceImpl(
     }
     formatted_output += ai_prompt;
   } else {
-
     data["image_data"] = json::array();
     for (const auto &message : messages) {
       std::string input_role = message["role"].asString();
@@ -327,18 +322,33 @@ void llamaCPP::inferenceImpl(
   auto state = create_inference_state(this);
   auto chunked_content_provider =
       [state, data](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
-    if (state->inferenceStatus == PENDING) {
-      state->inferenceStatus = RUNNING;
-    } else if (state->inferenceStatus == FINISHED) {
+    if (state->inference_status == PENDING) {
+      state->inference_status = RUNNING;
+    } else if (state->inference_status == FINISHED) {
       return 0;
     }
 
     if (!pBuffer) {
       LOG_INFO << "Connection closed or buffer is null. Reset context";
-      state->inferenceStatus = FINISHED;
+      state->inference_status = FINISHED;
       return 0;
     }
 
+    if (state->inference_status == EOS) {
+      LOG_INFO << "End of result";
+      const std::string str =
+          "data: " +
+          create_return_json(nitro_utils::generate_random_string(20), "_", "",
+                             "stop") +
+          "\n\n" + "data: [DONE]" + "\n\n";
+
+      LOG_VERBOSE("data stream", {{"to_send", str}});
+      std::size_t nRead = std::min(str.size(), nBuffSize);
+      memcpy(pBuffer, str.data(), nRead);
+      state->inference_status = FINISHED;
+      return nRead;
+    }
+
     task_result result = state->instance->llama.next_result(state->task_id);
     if (!result.error) {
       const std::string to_send = result.result_json["content"];
@@ -352,28 +362,22 @@ void llamaCPP::inferenceImpl(
       memcpy(pBuffer, str.data(), nRead);
 
       if (result.stop) {
-        const std::string str =
-            "data: " +
-            create_return_json(nitro_utils::generate_random_string(20), "_",
-                               "", "stop") +
-            "\n\n" + "data: [DONE]" + "\n\n";
-
-        LOG_VERBOSE("data stream", {{"to_send", str}});
-        std::size_t nRead = std::min(str.size(), nBuffSize);
-        memcpy(pBuffer, str.data(), nRead);
         LOG_INFO << "reached result stop";
-        state->inferenceStatus = FINISHED;
+        state->inference_status = EOS;
+        return nRead;
       }
 
       // Make sure nBufferSize is not zero
       // Otherwise it stop streaming
       if (!nRead) {
-        state->inferenceStatus = FINISHED;
+        state->inference_status = FINISHED;
       }
 
       return nRead;
+    } else {
+      LOG_INFO << "Error during inference";
     }
-    state->inferenceStatus = FINISHED;
+    state->inference_status = FINISHED;
     return 0;
   };
   // Queued task
@@ -391,16 +395,17 @@ void llamaCPP::inferenceImpl(
 
         // Since this is an async task, we will wait for the task to be
         // completed
-        while (state->inferenceStatus != FINISHED && retries < 10) {
+        while (state->inference_status != FINISHED && retries < 10) {
           // Should wait chunked_content_provider lambda to be called within
           // 3s
-          if (state->inferenceStatus == PENDING) {
+          if (state->inference_status == PENDING) {
             retries += 1;
           }
-          if (state->inferenceStatus != RUNNING)
+          if (state->inference_status != RUNNING)
             LOG_INFO << "Wait for task to be released:" << state->task_id;
           std::this_thread::sleep_for(std::chrono::milliseconds(100));
         }
+        LOG_INFO << "Task completed, release it";
         // Request completed, release it
         state->instance->llama.request_cancel(state->task_id);
       });
@@ -445,7 +450,6 @@ void llamaCPP::embedding(
 void llamaCPP::embeddingImpl(
     std::shared_ptr<Json::Value> jsonBody,
     std::function<void(const HttpResponsePtr &)> &callback) {
-
   // Queue embedding task
   auto state = create_inference_state(this);
 
@@ -532,7 +536,6 @@ void llamaCPP::modelStatus(
 void llamaCPP::loadModel(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
-
   if (llama.model_loaded_external) {
     LOG_INFO << "model loaded";
     Json::Value jsonResp;
@@ -561,7 +564,6 @@ void llamaCPP::loadModel(
 }
 
 bool llamaCPP::loadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
-
   gpt_params params;
   // By default will setting based on number of handlers
   if (jsonBody) {
@@ -570,11 +572,9 @@ bool llamaCPP::loadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
      params.mmproj = jsonBody->operator[]("mmproj").asString();
    }
    if (!jsonBody->operator[]("grp_attn_n").isNull()) {
-
      params.grp_attn_n = jsonBody->operator[]("grp_attn_n").asInt();
    }
    if (!jsonBody->operator[]("grp_attn_w").isNull()) {
-
      params.grp_attn_w = jsonBody->operator[]("grp_attn_w").asInt();
    }
    if (!jsonBody->operator[]("mlock").isNull()) {
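
Note on the streaming change above: previously, when result.stop was true, the provider built the final "stop"/data: [DONE] payload and memcpy'd it into the same buffer it had just filled with the last content chunk, so the stop payload overwrote that chunk within a single invocation. The new EOS state defers the terminator to the next call of chunked_content_provider: on result.stop the provider returns the content chunk and flips to EOS, the following call sends the stop chunk and flips to FINISHED, and the call after that returns 0 to close the stream.

Below is a minimal standalone sketch of that state machine, assuming a Drogon-style pull callback (write at most nBuffSize bytes into pBuffer, return 0 to end the stream). State, provide_chunk, and the token queue are illustrative stand-ins, not Nitro APIs, and the payload is simplified (the real code wraps content in create_return_json):

#include <algorithm>
#include <cstring>
#include <iostream>
#include <queue>
#include <string>

// Mirrors the patched enum: EOS marks "final stop chunk still to be sent".
enum InferenceStatus { PENDING, RUNNING, EOS, FINISHED };

struct State {
  InferenceStatus inference_status = PENDING;
  std::queue<std::string> tokens;  // stand-in for llama.next_result()
};

std::size_t provide_chunk(State &state, char *pBuffer, std::size_t nBuffSize) {
  if (state.inference_status == PENDING) {
    state.inference_status = RUNNING;
  } else if (state.inference_status == FINISHED) {
    return 0;  // stream already terminated
  }
  if (!pBuffer) {  // client disconnected
    state.inference_status = FINISHED;
    return 0;
  }
  // A whole invocation is dedicated to the stop chunk, so it can never
  // overwrite a content chunk written earlier in the same call.
  if (state.inference_status == EOS) {
    const std::string str = "data: [DONE]\n\n";
    std::size_t nRead = std::min(str.size(), nBuffSize);
    std::memcpy(pBuffer, str.data(), nRead);
    state.inference_status = FINISHED;
    return nRead;
  }
  const std::string str = "data: " + state.tokens.front() + "\n\n";
  state.tokens.pop();
  std::size_t nRead = std::min(str.size(), nBuffSize);
  std::memcpy(pBuffer, str.data(), nRead);
  if (state.tokens.empty()) {
    state.inference_status = EOS;  // defer the stop chunk to the next call
  }
  return nRead;
}

int main() {
  State state;
  state.tokens.push("Hello");
  state.tokens.push("world");
  char buf[256];
  std::size_t n;
  while ((n = provide_chunk(state, buf, sizeof(buf))) > 0) {
    std::cout.write(buf, n);
  }
}

Compiled as-is, this prints data: Hello, data: world, and data: [DONE] as three separate chunks. Returning nRead immediately on stop, instead of falling through, is what guarantees the last content chunk reaches the client before the terminator; the same reasoning explains the patch's new "return nRead;" line inside the result.stop branch.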