diff --git a/backend/cpp/llama-cpp/Makefile b/backend/cpp/llama-cpp/Makefile index 091c3386dd6a..c93bd61f19cb 100644 --- a/backend/cpp/llama-cpp/Makefile +++ b/backend/cpp/llama-cpp/Makefile @@ -1,5 +1,5 @@ -LLAMA_VERSION?=7d019cff744b73084b15ca81ba9916f3efab1223 +LLAMA_VERSION?=c4abcb2457217198efdd67d02675f5fddb7071c2 LLAMA_REPO?=https://github.com/ggerganov/llama.cpp CMAKE_ARGS?= diff --git a/backend/cpp/llama-cpp/grpc-server.cpp b/backend/cpp/llama-cpp/grpc-server.cpp index a71f43aec6c8..72ed1b09b568 100644 --- a/backend/cpp/llama-cpp/grpc-server.cpp +++ b/backend/cpp/llama-cpp/grpc-server.cpp @@ -579,7 +579,8 @@ class BackendServiceImpl final : public backend::Backend::Service { auto completion_id = gen_chatcmplid(); - std::unordered_set task_ids; + // need to store the reader as a pointer, so that it won't be destroyed when the handle returns + const auto rd = std::make_shared(ctx_server); try { std::vector tasks; @@ -871,18 +872,77 @@ class BackendServiceImpl final : public backend::Backend::Service { tasks.push_back(std::move(task)); } - task_ids = server_task::get_list_id(tasks); - ctx_server.queue_results.add_waiting_tasks(tasks); - ctx_server.queue_tasks.post(std::move(tasks)); + rd->post_tasks(std::move(tasks)); } catch (const std::exception & e) { return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, e.what()); } - ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool { + // Get first result for error checking (following server.cpp pattern) + server_task_result_ptr first_result = rd->next([&context]() { return context->IsCancelled(); }); + if (first_result == nullptr) { + // connection is closed + return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client"); + } else if (first_result->is_error()) { + json error_json = first_result->to_json(); + backend::Reply reply; + reply.set_message(error_json.value("message", "")); + writer->Write(reply); + return grpc::Status(grpc::StatusCode::INTERNAL, error_json.value("message", "Error occurred")); + } + + // Process first result + json first_res_json = first_result->to_json(); + if (first_res_json.is_array()) { + for (const auto & res : first_res_json) { + std::string completion_text = res.value("content", ""); + + backend::Reply reply; + reply.set_message(completion_text); + int32_t tokens_predicted = res.value("tokens_predicted", 0); + reply.set_tokens(tokens_predicted); + int32_t tokens_evaluated = res.value("tokens_evaluated", 0); + reply.set_prompt_tokens(tokens_evaluated); + + if (res.contains("timings")) { + double timing_prompt_processing = res.at("timings").value("prompt_ms", 0.0); + reply.set_timing_prompt_processing(timing_prompt_processing); + double timing_token_generation = res.at("timings").value("predicted_ms", 0.0); + reply.set_timing_token_generation(timing_token_generation); + } + + writer->Write(reply); + } + } else { + std::string completion_text = first_res_json.value("content", ""); + + backend::Reply reply; + reply.set_message(completion_text); + int32_t tokens_predicted = first_res_json.value("tokens_predicted", 0); + reply.set_tokens(tokens_predicted); + int32_t tokens_evaluated = first_res_json.value("tokens_evaluated", 0); + reply.set_prompt_tokens(tokens_evaluated); + + if (first_res_json.contains("timings")) { + double timing_prompt_processing = first_res_json.at("timings").value("prompt_ms", 0.0); + reply.set_timing_prompt_processing(timing_prompt_processing); + double timing_token_generation = first_res_json.at("timings").value("predicted_ms", 0.0); + reply.set_timing_token_generation(timing_token_generation); + } + + writer->Write(reply); + } + + // Process subsequent results + while (rd->has_next()) { // Check if context is cancelled before processing result if (context->IsCancelled()) { - ctx_server.cancel_tasks(task_ids); - return false; + break; + } + + auto result = rd->next([&context]() { return context->IsCancelled(); }); + if (result == nullptr) { + // connection is closed + break; } json res_json = result->to_json(); @@ -904,9 +964,6 @@ class BackendServiceImpl final : public backend::Backend::Service { reply.set_timing_token_generation(timing_token_generation); } - // Log Request Correlation Id - - // Send the reply writer->Write(reply); } } else { @@ -926,24 +983,9 @@ class BackendServiceImpl final : public backend::Backend::Service { reply.set_timing_token_generation(timing_token_generation); } - - - // Send the reply - writer->Write(reply); - + writer->Write(reply); } - return true; - }, [&](const json & error_data) { - backend::Reply reply; - reply.set_message(error_data.value("content", "")); - writer->Write(reply); - return true; - }, [&context]() { - // Check if the gRPC context is cancelled - return context->IsCancelled(); - }); - - ctx_server.queue_results.remove_waiting_task_ids(task_ids); + } // Check if context was cancelled during processing if (context->IsCancelled()) { @@ -963,7 +1005,7 @@ class BackendServiceImpl final : public backend::Backend::Service { } std::cout << "[PREDICT] Received result: " << data.dump(2) << std::endl; auto completion_id = gen_chatcmplid(); - std::unordered_set task_ids; + const auto rd = std::make_shared(ctx_server); try { std::vector tasks; @@ -1261,9 +1303,7 @@ class BackendServiceImpl final : public backend::Backend::Service { tasks.push_back(std::move(task)); } - task_ids = server_task::get_list_id(tasks); - ctx_server.queue_results.add_waiting_tasks(tasks); - ctx_server.queue_tasks.post(std::move(tasks)); + rd->post_tasks(std::move(tasks)); } catch (const std::exception & e) { return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, e.what()); } @@ -1271,51 +1311,45 @@ class BackendServiceImpl final : public backend::Backend::Service { std::cout << "[DEBUG] Waiting for results..." << std::endl; - // Check cancellation before waiting for results - if (context->IsCancelled()) { - ctx_server.cancel_tasks(task_ids); - ctx_server.queue_results.remove_waiting_task_ids(task_ids); + // Wait for all results + auto all_results = rd->wait_for_all([&context]() { return context->IsCancelled(); }); + + if (all_results.is_terminated) { return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client"); - } - - ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { - std::cout << "[DEBUG] Received " << results.size() << " results" << std::endl; - if (results.size() == 1) { + } else if (all_results.error) { + std::cout << "[DEBUG] Error in results: " << all_results.error->to_json().value("message", "") << std::endl; + reply->set_message(all_results.error->to_json().value("message", "")); + return grpc::Status(grpc::StatusCode::INTERNAL, all_results.error->to_json().value("message", "Error occurred")); + } else { + std::cout << "[DEBUG] Received " << all_results.results.size() << " results" << std::endl; + if (all_results.results.size() == 1) { // single result - reply->set_message(results[0]->to_json().value("content", "")); + GGML_ASSERT(dynamic_cast(all_results.results[0].get()) != nullptr); + reply->set_message(all_results.results[0]->to_json().value("content", "")); - int32_t tokens_predicted = results[0]->to_json().value("tokens_predicted", 0); + int32_t tokens_predicted = all_results.results[0]->to_json().value("tokens_predicted", 0); reply->set_tokens(tokens_predicted); - int32_t tokens_evaluated = results[0]->to_json().value("tokens_evaluated", 0); + int32_t tokens_evaluated = all_results.results[0]->to_json().value("tokens_evaluated", 0); reply->set_prompt_tokens(tokens_evaluated); - if (results[0]->to_json().contains("timings")) { - double timing_prompt_processing = results[0]->to_json().at("timings").value("prompt_ms", 0.0); + if (all_results.results[0]->to_json().contains("timings")) { + double timing_prompt_processing = all_results.results[0]->to_json().at("timings").value("prompt_ms", 0.0); reply->set_timing_prompt_processing(timing_prompt_processing); - double timing_token_generation = results[0]->to_json().at("timings").value("predicted_ms", 0.0); + double timing_token_generation = all_results.results[0]->to_json().at("timings").value("predicted_ms", 0.0); reply->set_timing_token_generation(timing_token_generation); } } else { // multiple results (multitask) json arr = json::array(); - for (auto & res : results) { + for (auto & res : all_results.results) { + GGML_ASSERT(dynamic_cast(res.get()) != nullptr); arr.push_back(res->to_json().value("content", "")); } reply->set_message(arr); } - - - }, [&](const json & error_data) { - std::cout << "[DEBUG] Error in results: " << error_data.value("content", "") << std::endl; - reply->set_message(error_data.value("content", "")); - }, [&context]() { - // Check if the gRPC context is cancelled - // This is checked every HTTP_POLLING_SECONDS (1 second) during receive_multi_results - return context->IsCancelled(); - }); - - ctx_server.queue_results.remove_waiting_task_ids(task_ids); + } + std::cout << "[DEBUG] Predict request completed successfully" << std::endl; // Check if context was cancelled during processing @@ -1352,9 +1386,7 @@ class BackendServiceImpl final : public backend::Backend::Service { int embd_normalize = 2; // default to Euclidean/L2 norm // create and queue the task - json responses = json::array(); - bool error = false; - std::unordered_set task_ids; + const auto rd = std::make_shared(ctx_server); { std::vector tasks; for (size_t i = 0; i < tokenized_prompts.size(); i++) { @@ -1369,40 +1401,23 @@ class BackendServiceImpl final : public backend::Backend::Service { tasks.push_back(std::move(task)); } - task_ids = server_task::get_list_id(tasks); - ctx_server.queue_results.add_waiting_tasks(tasks); - ctx_server.queue_tasks.post(std::move(tasks)); + rd->post_tasks(std::move(tasks)); } - // Check cancellation before waiting for results - if (context->IsCancelled()) { - ctx_server.cancel_tasks(task_ids); - ctx_server.queue_results.remove_waiting_task_ids(task_ids); - return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client"); - } - - // get the result - ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { - for (auto & res : results) { - GGML_ASSERT(dynamic_cast(res.get()) != nullptr); - responses.push_back(res->to_json()); - } - }, [&](const json & error_data) { - error = true; - }, [&context]() { - // Check if the gRPC context is cancelled - return context->IsCancelled(); - }); - - ctx_server.queue_results.remove_waiting_task_ids(task_ids); - - // Check if context was cancelled during processing - if (context->IsCancelled()) { + // Wait for all results + auto all_results = rd->wait_for_all([&context]() { return context->IsCancelled(); }); + + if (all_results.is_terminated) { return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client"); + } else if (all_results.error) { + return grpc::Status(grpc::StatusCode::INTERNAL, all_results.error->to_json().value("message", "Error in receiving results")); } - if (error) { - return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results"); + // Collect responses + json responses = json::array(); + for (auto & res : all_results.results) { + GGML_ASSERT(dynamic_cast(res.get()) != nullptr); + responses.push_back(res->to_json()); } std::cout << "[DEBUG] Responses size: " << responses.size() << std::endl; @@ -1453,9 +1468,7 @@ class BackendServiceImpl final : public backend::Backend::Service { } // Create and queue the task - json responses = json::array(); - bool error = false; - std::unordered_set task_ids; + const auto rd = std::make_shared(ctx_server); { std::vector tasks; std::vector documents; @@ -1473,40 +1486,23 @@ class BackendServiceImpl final : public backend::Backend::Service { tasks.push_back(std::move(task)); } - task_ids = server_task::get_list_id(tasks); - ctx_server.queue_results.add_waiting_tasks(tasks); - ctx_server.queue_tasks.post(std::move(tasks)); - } - - // Check cancellation before waiting for results - if (context->IsCancelled()) { - ctx_server.cancel_tasks(task_ids); - ctx_server.queue_results.remove_waiting_task_ids(task_ids); - return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client"); + rd->post_tasks(std::move(tasks)); } - // Get the results - ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { - for (auto & res : results) { - GGML_ASSERT(dynamic_cast(res.get()) != nullptr); - responses.push_back(res->to_json()); - } - }, [&](const json & error_data) { - error = true; - }, [&context]() { - // Check if the gRPC context is cancelled - return context->IsCancelled(); - }); - - ctx_server.queue_results.remove_waiting_task_ids(task_ids); - - // Check if context was cancelled during processing - if (context->IsCancelled()) { + // Wait for all results + auto all_results = rd->wait_for_all([&context]() { return context->IsCancelled(); }); + + if (all_results.is_terminated) { return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client"); + } else if (all_results.error) { + return grpc::Status(grpc::StatusCode::INTERNAL, all_results.error->to_json().value("message", "Error in receiving results")); } - if (error) { - return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results"); + // Collect responses + json responses = json::array(); + for (auto & res : all_results.results) { + GGML_ASSERT(dynamic_cast(res.get()) != nullptr); + responses.push_back(res->to_json()); } // Sort responses by score in descending order std::sort(responses.begin(), responses.end(), [](const json& a, const json& b) { diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index 975aac3e872a..6385997bf4fb 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -591,7 +591,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator // NOTE: this is a workaround as fasthttp // context cancellation does not fire in non-streaming requests - handleConnectionCancellation(c, input.Cancel, input.Context) + // handleConnectionCancellation(c, input.Cancel, input.Context) result, tokenUsage, err := ComputeChoices( input, diff --git a/core/http/endpoints/openai/mcp.go b/core/http/endpoints/openai/mcp.go index fe018bbbd09c..efb3c6d29096 100644 --- a/core/http/endpoints/openai/mcp.go +++ b/core/http/endpoints/openai/mcp.go @@ -80,7 +80,7 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, ctxWithCancellation, cancel := context.WithCancel(ctx) defer cancel() - handleConnectionCancellation(c, cancel, ctxWithCancellation) + //handleConnectionCancellation(c, cancel, ctxWithCancellation) // TODO: instead of connecting to the API, we should just wire this internally // and act like completion.go. // We can do this as cogito expects an interface and we can create one that