diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc
index 7e616f9b2..008520342 100644
--- a/controllers/llamaCPP.cc
+++ b/controllers/llamaCPP.cc
@@ -49,7 +49,7 @@ std::shared_ptr<inferenceState> create_inference_state(llamaCPP* instance) {
  * @param callback the function to return message to user
  */
 bool llamaCPP::CheckModelLoaded(
-    std::function<void(const HttpResponsePtr&)>& callback) {
+    const std::function<void(const HttpResponsePtr&)>& callback) {
   if (!llama.model_loaded_external) {
     LOG_ERROR << "Model has not been loaded";
     Json::Value jsonResp;
@@ -180,13 +180,13 @@ void llamaCPP::ChatCompletion(
   if (CheckModelLoaded(callback)) {
     // Model is loaded
     // Do Inference
-    InferenceImpl(std::move(completion), callback);
+    InferenceImpl(std::move(completion), std::move(callback));
   }
 }
 
 void llamaCPP::InferenceImpl(
     inferences::ChatCompletionRequest&& completion,
-    std::function<void(const HttpResponsePtr&)>& callback) {
+    std::function<void(const HttpResponsePtr&)>&& callback) {
   std::string formatted_output = pre_prompt;
   int request_id = ++no_of_requests;
   LOG_INFO_REQUEST(request_id) << "Generating reponse for inference request";
@@ -404,14 +404,14 @@ void llamaCPP::InferenceImpl(
     };
     // Queued task
     state->instance->queue->runTaskInQueue(
-        [callback, state, data, chunked_content_provider, request_id]() {
+        [cb = std::move(callback), state, data, chunked_content_provider, request_id]() {
          state->task_id =
              state->instance->llama.request_completion(data, false, false, -1);
 
          // Start streaming response
          auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
                                                       "chat_completions.txt");
-          callback(resp);
+          cb(resp);
 
          int retries = 0;
 
@@ -434,28 +434,31 @@ void llamaCPP::InferenceImpl(
          LOG_INFO_REQUEST(request_id) << "Inference completed";
        });
   } else {
-    Json::Value respData;
-    int task_id = llama.request_completion(data, false, false, -1);
-    LOG_INFO_REQUEST(request_id) << "Non stream, waiting for respone";
-    if (!json_value(data, "stream", false)) {
-      std::string completion_text;
-      task_result result = llama.next_result(task_id);
-      if (!result.error && result.stop) {
-        int prompt_tokens = result.result_json["tokens_evaluated"];
-        int predicted_tokens = result.result_json["tokens_predicted"];
-        std::string to_send = result.result_json["content"];
-        nitro_utils::ltrim(to_send);
-        respData = create_full_return_json(
-            nitro_utils::generate_random_string(20), "_", to_send, "_",
-            prompt_tokens, predicted_tokens);
-      } else {
-        respData["message"] = "Internal error during inference";
-        LOG_ERROR_REQUEST(request_id) << "Error during inference";
-      }
-      auto resp = nitro_utils::nitroHttpJsonResponse(respData);
-      callback(resp);
-      LOG_INFO_REQUEST(request_id) << "Inference completed";
-    }
+    queue->runTaskInQueue(
+        [this, request_id, cb = std::move(callback), d = std::move(data)]() {
+          Json::Value respData;
+          int task_id = llama.request_completion(d, false, false, -1);
+          LOG_INFO_REQUEST(request_id) << "Non stream, waiting for respone";
+          if (!json_value(d, "stream", false)) {
+            std::string completion_text;
+            task_result result = llama.next_result(task_id);
+            if (!result.error && result.stop) {
+              int prompt_tokens = result.result_json["tokens_evaluated"];
+              int predicted_tokens = result.result_json["tokens_predicted"];
+              std::string to_send = result.result_json["content"];
+              nitro_utils::ltrim(to_send);
+              respData = create_full_return_json(
+                  nitro_utils::generate_random_string(20), "_", to_send, "_",
+                  prompt_tokens, predicted_tokens);
+            } else {
+              respData["message"] = "Internal error during inference";
+              LOG_ERROR_REQUEST(request_id) << "Error during inference";
+            }
+            auto resp = nitro_utils::nitroHttpJsonResponse(respData);
+            cb(resp);
+            LOG_INFO_REQUEST(request_id) << "Inference completed";
+          }
+        });
   }
 }
 
@@ -467,14 +470,14 @@ void llamaCPP::Embedding(
     // Model is loaded
     const auto& jsonBody = req->getJsonObject();
     // Run embedding
-    EmbeddingImpl(jsonBody, callback);
+    EmbeddingImpl(jsonBody, std::move(callback));
     return;
   }
 }
 
 void llamaCPP::EmbeddingImpl(
     std::shared_ptr<Json::Value> jsonBody,
-    std::function<void(const HttpResponsePtr&)>& callback) {
+    std::function<void(const HttpResponsePtr&)>&& callback) {
   int request_id = ++no_of_requests;
   LOG_INFO_REQUEST(request_id) << "Generating reponse for embedding request";
   // Queue embedding task
diff --git a/controllers/llamaCPP.h b/controllers/llamaCPP.h
index a00d25538..900786c79 100644
--- a/controllers/llamaCPP.h
+++ b/controllers/llamaCPP.h
@@ -97,10 +97,10 @@ class llamaCPP : public drogon::HttpController<llamaCPP>,
   bool LoadModelImpl(std::shared_ptr<Json::Value> jsonBody);
 
   void InferenceImpl(inferences::ChatCompletionRequest&& completion,
-                     std::function<void(const HttpResponsePtr&)>& callback);
+                     std::function<void(const HttpResponsePtr&)>&& callback);
   void EmbeddingImpl(std::shared_ptr<Json::Value> jsonBody,
-                     std::function<void(const HttpResponsePtr&)>& callback);
-  bool CheckModelLoaded(std::function<void(const HttpResponsePtr&)>& callback);
+                     std::function<void(const HttpResponsePtr&)>&& callback);
+  bool CheckModelLoaded(const std::function<void(const HttpResponsePtr&)>& callback);
   void WarmupModel();
   void BackgroundTask();
   void StopBackgroundTask();
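
The crux of this change is callback ownership: InferenceImpl and EmbeddingImpl now take the response callback by rvalue reference and move it into the lambda that runs on the task queue (cb = std::move(callback)), so the queued non-stream path no longer invokes a std::function through a reference that may be dangling by the time the task runs. Below is a minimal, self-contained sketch of that pattern, assuming a trivial single-task queue; TaskQueue and RespondAsync are illustrative stand-ins, not nitro or drogon APIs.

#include <functional>
#include <iostream>
#include <string>
#include <thread>
#include <utility>

// Toy stand-in for the controller's task queue: it runs a single task on a
// background thread and joins it on destruction.
class TaskQueue {
 public:
  void runTaskInQueue(std::function<void()>&& task) {
    worker_ = std::thread(std::move(task));
  }
  ~TaskQueue() {
    if (worker_.joinable()) worker_.join();
  }

 private:
  std::thread worker_;
};

// Mirrors the shape of InferenceImpl after this diff: the callback arrives as
// an rvalue and is captured by move, so the queued task owns it outright and
// can safely invoke it after the caller's stack frame is gone.
void RespondAsync(TaskQueue& queue,
                  std::function<void(const std::string&)>&& callback) {
  queue.runTaskInQueue([cb = std::move(callback)]() {
    // Simulate waiting for a non-streaming completion result, then reply.
    cb("inference result");
  });
}

int main() {
  TaskQueue queue;
  RespondAsync(queue, [](const std::string& result) {
    std::cout << "callback received: " << result << "\n";
  });
  return 0;  // ~TaskQueue joins the worker, so the callback runs before exit.
}

CheckModelLoaded, by contrast, only needs const std::function& because it invokes the callback synchronously (for the "model not loaded" error response) before returning, so no ownership transfer is required there.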