diff --git a/engine/controllers/server.cc b/engine/controllers/server.cc
index 961798d2c..83eaddb4e 100644
--- a/engine/controllers/server.cc
+++ b/engine/controllers/server.cc
@@ -129,6 +129,9 @@ void server::FineTuning(
 
 void server::Inference(const HttpRequestPtr& req,
                        std::function<void(const HttpResponsePtr&)>&& callback) {
+
+  auto json_body = req->getJsonObject();
+
   LOG_TRACE << "Start inference";
   auto q = std::make_shared<services::SyncQueue>();
   auto ir = inference_svc_->HandleInference(q, req->getJsonObject());
@@ -141,20 +144,34 @@
     callback(resp);
     return;
   }
+
+  bool is_stream =
+      (*json_body).get("stream", false).asBool() ||
+      (*json_body).get("body", Json::Value()).get("stream", false).asBool();
+
   LOG_TRACE << "Wait to inference";
-  auto [status, res] = q->wait_and_pop();
-  LOG_DEBUG << "response: " << res.toStyledString();
-  auto resp = cortex_utils::CreateCortexHttpJsonResponse(res);
-  resp->setStatusCode(
-      static_cast<drogon::HttpStatusCode>(status["status_code"].asInt()));
-  callback(resp);
-  LOG_TRACE << "Done inference";
+  if (is_stream) {
+    auto model_id = (*json_body).get("model", "invalid_model").asString();
+    auto engine_type = [this, &json_body]() -> std::string {
+      if (!inference_svc_->HasFieldInReq(json_body, "engine")) {
+        return kLlamaRepo;
+      } else {
+        return (*(json_body)).get("engine", kLlamaRepo).asString();
+      }
+    }();
+    ProcessStreamRes(callback, q, engine_type, model_id);
+  } else {
+    ProcessNonStreamRes(callback, *q);
+    LOG_TRACE << "Done inference";
+  }
 }
 
 void server::RouteRequest(
     const HttpRequestPtr& req,
     std::function<void(const HttpResponsePtr&)>&& callback) {
+  auto json_body = req->getJsonObject();
+
   LOG_TRACE << "Start route request";
   auto q = std::make_shared<services::SyncQueue>();
   auto ir = inference_svc_->HandleRouteRequest(q, req->getJsonObject());
@@ -167,14 +184,26 @@ void server::RouteRequest(
     callback(resp);
     return;
   }
+  auto is_stream =
+      (*json_body).get("stream", false).asBool() ||
+      (*json_body).get("body", Json::Value()).get("stream", false).asBool();
   LOG_TRACE << "Wait to route request";
-  auto [status, res] = q->wait_and_pop();
-  LOG_DEBUG << "response: " << res.toStyledString();
-  auto resp = cortex_utils::CreateCortexHttpJsonResponse(res);
-  resp->setStatusCode(
-      static_cast<drogon::HttpStatusCode>(status["status_code"].asInt()));
-  callback(resp);
-  LOG_TRACE << "Done route request";
+  if (is_stream) {
+
+    auto model_id = (*json_body).get("model", "invalid_model").asString();
+    auto engine_type = [this, &json_body]() -> std::string {
+      if (!inference_svc_->HasFieldInReq(json_body, "engine")) {
+        return kLlamaRepo;
+      } else {
+        return (*(json_body)).get("engine", kLlamaRepo).asString();
+      }
+    }();
+    ProcessStreamRes(callback, q, engine_type, model_id);
+  } else {
+    ProcessNonStreamRes(callback, *q);
+    LOG_TRACE << "Done route request";
+  }
+
 }
 
 void server::LoadModel(const HttpRequestPtr& req,
diff --git a/engine/extensions/python-engine/python_engine.cc b/engine/extensions/python-engine/python_engine.cc
index ddf6784e8..9be369bcf 100644
--- a/engine/extensions/python-engine/python_engine.cc
+++ b/engine/extensions/python-engine/python_engine.cc
@@ -16,7 +16,8 @@ static size_t WriteCallback(char* ptr, size_t size, size_t nmemb,
   return size * nmemb;
 }
 
-PythonEngine::PythonEngine() {}
+PythonEngine::PythonEngine() : q_(4 /*n_parallel*/, "python_engine") {}
+
 
 PythonEngine::~PythonEngine() {
   curl_global_cleanup();
@@ -169,7 +170,7 @@ bool PythonEngine::TerminateModelProcess(const std::string& model) {
 }
 CurlResponse PythonEngine::MakeGetRequest(const std::string& model,
                                           const std::string& path) {
-  auto config = models_[model];
+  auto const& config = models_[model];
   std::string full_url = "http://localhost:" + config.port + path;
 
   CurlResponse response;
@@ -184,7 +185,7 @@ CurlResponse PythonEngine::MakeGetRequest(const std::string& model,
 }
 CurlResponse PythonEngine::MakeDeleteRequest(const std::string& model,
                                              const std::string& path) {
-  auto config = models_[model];
+  auto const& config = models_[model];
   std::string full_url = "http://localhost:" + config.port + path;
 
   CurlResponse response;
@@ -203,7 +204,7 @@ CurlResponse PythonEngine::MakeDeleteRequest(const std::string& model,
 CurlResponse PythonEngine::MakePostRequest(const std::string& model,
                                            const std::string& path,
                                            const std::string& body) {
-  auto config = models_[model];
+  auto const& config = models_[model];
   std::string full_url = "http://localhost:" + config.port + path;
 
   CurlResponse response;
@@ -450,6 +451,63 @@ void PythonEngine::HandleChatCompletion(
     std::shared_ptr<Json::Value> json_body,
     std::function<void(Json::Value&&, Json::Value&&)>&& callback) {}
 
+CurlResponse PythonEngine::MakeStreamPostRequest(
+    const std::string& model, const std::string& path, const std::string& body,
+    const std::function<void(Json::Value&&, Json::Value&&)>& callback) {
+  auto const& config = models_[model];
+  CURL* curl = curl_easy_init();
+  CurlResponse response;
+
+  if (!curl) {
+    response.error = true;
+    response.error_message = "Failed to initialize CURL";
+    return response;
+  }
+
+  std::string full_url = "http://localhost:" + config.port + path;
+
+  struct curl_slist* headers = nullptr;
+  headers = curl_slist_append(headers, "Content-Type: application/json");
+  headers = curl_slist_append(headers, "Accept: text/event-stream");
+  headers = curl_slist_append(headers, "Cache-Control: no-cache");
+  headers = curl_slist_append(headers, "Connection: keep-alive");
+
+  StreamContext context{
+      std::make_shared<std::function<void(Json::Value&&, Json::Value&&)>>(
+          callback),
+      ""};
+
+  curl_easy_setopt(curl, CURLOPT_URL, full_url.c_str());
+  curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
+  curl_easy_setopt(curl, CURLOPT_POST, 1L);
+  curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body.c_str());
+  curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, StreamWriteCallback);
+  curl_easy_setopt(curl, CURLOPT_WRITEDATA, &context);
+  curl_easy_setopt(curl, CURLOPT_TRANSFER_ENCODING, 1L);
+
+  CURLcode res = curl_easy_perform(curl);
+
+  if (res != CURLE_OK) {
+    response.error = true;
+    response.error_message = curl_easy_strerror(res);
+
+    Json::Value status;
+    status["is_done"] = true;
+    status["has_error"] = true;
+    status["is_stream"] = true;
+    status["status_code"] = 500;
+
+    Json::Value error;
+    error["error"] = response.error_message;
+    callback(std::move(status), std::move(error));
+  }
+
+  curl_slist_free_all(headers);
+  curl_easy_cleanup(curl);
+  return response;
+}
+
+
 void PythonEngine::HandleInference(
     std::shared_ptr<Json::Value> json_body,
     std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
@@ -485,7 +543,8 @@ void PythonEngine::HandleInference(
 
   // Render with error handling
   try {
-    transformed_request = renderer_.Render(transform_request, *json_body);
+    transformed_request = renderer_.Render(transform_request, body);
+
   } catch (const std::exception& e) {
     throw std::runtime_error("Template rendering error: " +
                              std::string(e.what()));
@@ -504,7 +563,17 @@ void PythonEngine::HandleInference(
 
   CurlResponse response;
   if (method == "post") {
-    response = MakePostRequest(model, path, transformed_request);
+    if (body.isMember("stream") && body["stream"].asBool()) {
+      q_.runTaskInQueue(
+          [this, model, path, transformed_request, cb = std::move(callback)] {
+            MakeStreamPostRequest(model, path, transformed_request, cb);
+          });
+
+      return;
+    } else {
+      response = MakePostRequest(model, path, transformed_request);
+    }
+
   } else if (method == "get") {
     response = MakeGetRequest(model, path);
   } else if (method == "delete") {
diff --git a/engine/extensions/python-engine/python_engine.h b/engine/extensions/python-engine/python_engine.h
index 7b112f435..979ba1fd8 100644
--- a/engine/extensions/python-engine/python_engine.h
+++ b/engine/extensions/python-engine/python_engine.h
@@ -8,6 +8,8 @@
 #include <...>
 #include <...>
 #include "config/model_config.h"
+#include "trantor/utils/ConcurrentTaskQueue.h"
+
 #include "cortex-common/EngineI.h"
 #include "extensions/template_renderer.h"
 #include "utils/file_logger.h"
@@ -44,19 +46,12 @@ static size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb,
   while ((pos = context->buffer.find('\n')) != std::string::npos) {
     std::string line = context->buffer.substr(0, pos);
     context->buffer = context->buffer.substr(pos + 1);
+    LOG_DEBUG << "line: " << line;
@@ ... @@ class PythonEngine : public EngineI {
   std::unique_ptr<trantor::FileLogger> async_file_logger_;
   std::unordered_map<std::string, pid_t> processMap;
+  trantor::ConcurrentTaskQueue q_;
+
   // Helper functions
 
   CurlResponse MakePostRequest(const std::string& model,
@@ -108,6 +105,10 @@ class PythonEngine : public EngineI {
                                 const std::string& path);
   CurlResponse MakeDeleteRequest(const std::string& model,
                                 const std::string& path);
+  CurlResponse MakeStreamPostRequest(
+      const std::string& model, const std::string& path,
+      const std::string& body,
+      const std::function<void(Json::Value&&, Json::Value&&)>& callback);
 
   // Process manager functions
   pid_t SpawnProcess(const std::string& model,
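Note on the streaming write path: curl delivers the response body to StreamWriteCallback in arbitrary chunks, so one call may carry a partial SSE line, exactly one line, or several. That is why StreamContext pairs the completion callback with a string buffer, and why the callback only forwards complete newline-terminated lines. A minimal self-contained sketch of the same buffering pattern (SseLineBuffer and on_line are illustrative names, not part of this patch):

    #include <cstddef>
    #include <functional>
    #include <string>

    // Accumulates raw bytes and emits complete lines, mirroring the
    // buffer handling inside StreamWriteCallback.
    struct SseLineBuffer {
      std::string buffer;
      std::function<void(const std::string&)> on_line;

      void Feed(const char* data, std::size_t len) {
        buffer.append(data, len);
        std::size_t pos;
        while ((pos = buffer.find('\n')) != std::string::npos) {
          std::string line = buffer.substr(0, pos);
          buffer.erase(0, pos + 1);
          if (!line.empty() && line.back() == '\r')
            line.pop_back();  // tolerate CRLF framing
          if (!line.empty())
            on_line(line);  // e.g. "data: {...}" or "data: [DONE]"
        }
      }
    };

The real callback must also return size * nmemb to tell curl the chunk was consumed; any other return value aborts the transfer.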
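Note on the new q_ member: curl_easy_perform inside MakeStreamPostRequest blocks until the upstream server closes the event stream, so running it on the caller's thread (drogon's event loop) would stall every other request. The patch therefore queues the transfer on a trantor::ConcurrentTaskQueue with four workers, which also caps the Python engine at four concurrent streams. A hedged usage sketch follows; the constructor arguments mirror q_(4 /*n_parallel*/, "python_engine") from the patch, while main() and the sleeps are illustrative only:

    #include <chrono>
    #include <thread>
    #include <trantor/utils/ConcurrentTaskQueue.h>
    #include <trantor/utils/Logger.h>

    int main() {
      // Four worker threads, named like the queue in PythonEngine.
      trantor::ConcurrentTaskQueue q(4, "python_engine");
      for (int i = 0; i < 8; ++i) {
        q.runTaskInQueue([i] {
          // Stand-in for a blocking curl_easy_perform(); with four
          // workers, the fifth task waits until a worker frees up.
          LOG_INFO << "stream task " << i;
          std::this_thread::sleep_for(std::chrono::milliseconds(100));
        });
      }
      // Illustration only: let the workers drain before the queue's
      // destructor joins its threads.
      std::this_thread::sleep_for(std::chrono::seconds(1));
      return 0;
    }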