diff --git a/tools/server/CMakeLists.txt b/tools/server/CMakeLists.txt index c801e84c3d415..1fccfdd17f138 100644 --- a/tools/server/CMakeLists.txt +++ b/tools/server/CMakeLists.txt @@ -14,6 +14,8 @@ endif() set(TARGET_SRCS server.cpp utils.hpp + server-http.cpp + server-http.h ) set(PUBLIC_ASSETS index.html.gz diff --git a/tools/server/server-http.cpp b/tools/server/server-http.cpp new file mode 100644 index 0000000000000..196ced443261a --- /dev/null +++ b/tools/server/server-http.cpp @@ -0,0 +1,385 @@ +#include "utils.hpp" +#include "common.h" +#include "server-http.h" + +#include + +#include +#include +#include + +// auto generated files (see README.md for details) +#include "index.html.gz.hpp" +#include "loading.html.hpp" + +// +// HTTP implementation using cpp-httplib +// + +class server_http_context::Impl { +public: + std::unique_ptr srv; +}; + +server_http_context::server_http_context() + : pimpl(std::make_unique()) +{} + +server_http_context::~server_http_context() = default; + +static void log_server_request(const httplib::Request & req, const httplib::Response & res) { + // skip GH copilot requests when using default port + if (req.path == "/v1/health") { + return; + } + + // reminder: this function is not covered by httplib's exception handler; if someone does more complicated stuff, think about wrapping it in try-catch + + SRV_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status); + + SRV_DBG("request: %s\n", req.body.c_str()); + SRV_DBG("response: %s\n", res.body.c_str()); +} + +bool server_http_context::init(const common_params & params) { + path_prefix = params.api_prefix; + port = params.port; + hostname = params.hostname; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (params.ssl_file_key != "" && params.ssl_file_cert != "") { + LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str()); + svr.reset( + new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str()) + ); + } else { + LOG_INF("Running without SSL\n"); + svr.reset(new httplib::Server()); + } +#else + if (params.ssl_file_key != "" && params.ssl_file_cert != "") { + LOG_ERR("Server is built without SSL support\n"); + return false; + } + pimpl->srv.reset(new httplib::Server()); +#endif + + auto & srv = pimpl->srv; + srv->set_default_headers({{"Server", "llama.cpp"}}); + srv->set_logger(log_server_request); + srv->set_exception_handler([](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) { + // this is fail-safe; exceptions should already handled by `ex_wrapper` + + std::string message; + try { + std::rethrow_exception(ep); + } catch (const std::exception & e) { + message = e.what(); + } catch (...) { + message = "Unknown Exception"; + } + + res.status = 500; + res.set_content(message, "text/plain"); + LOG_ERR("got exception: %s\n", message.c_str()); + }); + + srv->set_error_handler([](const httplib::Request &, httplib::Response & res) { + if (res.status == 404) { + res.set_content( + safe_json_to_str(json { + {"error", { + {"message", "File Not Found"}, + {"type", "not_found_error"}, + {"code", 404} + }} + }), + "application/json; charset=utf-8" + ); + } + // for other error codes, we skip processing here because it's already done by res->error() + }); + + // set timeouts and change hostname and port + srv->set_read_timeout (params.timeout_read); + srv->set_write_timeout(params.timeout_write); + + if (params.api_keys.size() == 1) { + auto key = params.api_keys[0]; + std::string substr = key.substr(std::max((int)(key.length() - 4), 0)); + LOG_INF("%s: api_keys: ****%s\n", __func__, substr.c_str()); + } else if (params.api_keys.size() > 1) { + LOG_INF("%s: api_keys: %zu keys loaded\n", __func__, params.api_keys.size()); + } + + // + // Middlewares + // + + auto middleware_validate_api_key = [api_keys = params.api_keys](const httplib::Request & req, httplib::Response & res) { + static const std::unordered_set public_endpoints = { + "/health", + "/v1/health", + "/models", + "/v1/models", + "/api/tags" + }; + + // If API key is not set, skip validation + if (api_keys.empty()) { + return true; + } + + // If path is public or is static file, skip validation + if (public_endpoints.find(req.path) != public_endpoints.end() || req.path == "/") { + return true; + } + + // Check for API key in the header + auto auth_header = req.get_header_value("Authorization"); + + std::string prefix = "Bearer "; + if (auth_header.substr(0, prefix.size()) == prefix) { + std::string received_api_key = auth_header.substr(prefix.size()); + if (std::find(api_keys.begin(), api_keys.end(), received_api_key) != api_keys.end()) { + return true; // API key is valid + } + } + + // API key is invalid or not provided + res.status = 401; + res.set_content( + safe_json_to_str(json { + {"error", { + {"message", "Invalid API Key"}, + {"type", "authentication_error"}, + {"code", 401} + }} + }), + "application/json; charset=utf-8" + ); + + LOG_WRN("Unauthorized: Invalid API Key\n"); + + return false; + }; + + auto middleware_server_state = [this](const httplib::Request & req, httplib::Response & res) { + bool ready = is_ready.load(); + if (!ready) { + auto tmp = string_split(req.path, '.'); + if (req.path == "/" || tmp.back() == "html") { + res.set_content(reinterpret_cast(loading_html), loading_html_len, "text/html; charset=utf-8"); + res.status = 503; + } else if (req.path == "/models" || req.path == "/v1/models" || req.path == "/api/tags") { + // allow the models endpoint to be accessed during loading + return true; + } else { + res.status = 503; + res.set_content( + safe_json_to_str(json { + {"error", { + {"message", "Loading model"}, + {"type", "unavailable_error"}, + {"code", 503} + }} + }), + "application/json; charset=utf-8" + ); + } + return false; + } + return true; + }; + + // register server middlewares + srv->set_pre_routing_handler([middleware_validate_api_key, middleware_server_state](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + // If this is OPTIONS request, skip validation because browsers don't include Authorization header + if (req.method == "OPTIONS") { + res.set_header("Access-Control-Allow-Credentials", "true"); + res.set_header("Access-Control-Allow-Methods", "GET, POST"); + res.set_header("Access-Control-Allow-Headers", "*"); + res.set_content("", "text/html"); // blank response, no data + return httplib::Server::HandlerResponse::Handled; // skip further processing + } + if (!middleware_server_state(req, res)) { + return httplib::Server::HandlerResponse::Handled; + } + if (!middleware_validate_api_key(req, res)) { + return httplib::Server::HandlerResponse::Handled; + } + return httplib::Server::HandlerResponse::Unhandled; + }); + + int n_threads_http = params.n_threads_http; + if (n_threads_http < 1) { + // +2 threads for monitoring endpoints + n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1); + } + LOG_INF("%s: using %d threads for HTTP server\n", __func__, n_threads_http); + srv->new_task_queue = [n_threads_http] { return new httplib::ThreadPool(n_threads_http); }; + + // + // Web UI setup + // + + if (!params.webui) { + LOG_INF("Web UI is disabled\n"); + } else { + // register static assets routes + if (!params.public_path.empty()) { + // Set the base directory for serving static files + bool is_found = srv->set_mount_point(params.api_prefix + "/", params.public_path); + if (!is_found) { + LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str()); + return 1; + } + } else { + // using embedded static index.html + srv->Get(params.api_prefix + "/", [](const httplib::Request & req, httplib::Response & res) { + if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) { + res.set_content("Error: gzip is not supported by this browser", "text/plain"); + } else { + res.set_header("Content-Encoding", "gzip"); + // COEP and COOP headers, required by pyodide (python interpreter) + res.set_header("Cross-Origin-Embedder-Policy", "require-corp"); + res.set_header("Cross-Origin-Opener-Policy", "same-origin"); + res.set_content(reinterpret_cast(index_html_gz), index_html_gz_len, "text/html; charset=utf-8"); + } + return false; + }); + } + } + return true; +} + +bool server_http_context::start() { + // Bind and listen + + auto & srv = pimpl->srv; + bool was_bound = false; + bool is_sock = false; + if (string_ends_with(std::string(hostname), ".sock")) { + is_sock = true; + LOG_INF("%s: setting address family to AF_UNIX\n", __func__); + srv->set_address_family(AF_UNIX); + // bind_to_port requires a second arg, any value other than 0 should + // simply get ignored + was_bound = srv->bind_to_port(hostname, 8080); + } else { + LOG_INF("%s: binding port with default address family\n", __func__); + // bind HTTP listen port + if (port == 0) { + int bound_port = srv->bind_to_any_port(hostname); + if ((was_bound = (bound_port >= 0))) { + port = bound_port; + } + } else { + was_bound = srv->bind_to_port(hostname, port); + } + } + + if (!was_bound) { + LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, hostname.c_str(), port); + return false; + } + + // run the HTTP server in a thread + thread = std::thread([this]() { pimpl->srv->listen_after_bind(); }); + srv->wait_until_ready(); + + listening_address = is_sock ? string_format("unix://%s", hostname.c_str()) + : string_format("http://%s:%d", hostname.c_str(), port); + return true; +} + +void server_http_context::stop() { + if (pimpl->srv) { + pimpl->srv->stop(); + } +} + +static void set_headers(httplib::Response & res, const std::map & headers) { + for (const auto & [key, value] : headers) { + res.set_header(key, value); + } +} + +static std::map get_params(const httplib::Request & req) { + std::map params; + for (const auto & [key, value] : req.params) { + params[key] = value; + } + for (const auto & [key, value] : req.path_params) { + params[key] = value; + } + return params; +} + +static std::map get_headers(const httplib::Request & req) { + std::map headers; + for (const auto & [key, value] : req.headers) { + headers[key] = value; + } + return headers; +} + +static void process_handler_response(server_http_res_ptr & response, httplib::Response & res) { + if (response->is_stream()) { + res.status = response->status; + set_headers(res, response->headers); + std::string content_type = response->content_type; + // convert to shared_ptr as both chunked_content_provider() and on_complete() need to use it + std::shared_ptr r_ptr = std::move(response); + const auto chunked_content_provider = [response = r_ptr](size_t, httplib::DataSink & sink) -> bool { + std::string chunk; + bool has_next = response->next(chunk); + if (!chunk.empty()) { + // TODO: maybe handle sink.write unsuccessful? for now, we rely on is_connection_closed() + sink.write(chunk.data(), chunk.size()); + SRV_DBG("http: streamed chunk: %s\n", chunk.c_str()); + } + if (!has_next) { + sink.done(); + SRV_DBG("%s", "http: stream ended\n"); + } + return has_next; + }; + const auto on_complete = [response = r_ptr](bool) mutable { + response.reset(); // trigger the destruction of the response object + }; + res.set_chunked_content_provider(content_type, chunked_content_provider, on_complete); + } else { + res.status = response->status; + set_headers(res, response->headers); + res.set_content(response->data, response->content_type); + } +} + +void server_http_context::get(const std::string & path, server_http_context::handler_t handler) { + pimpl->srv->Get(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) { + server_http_res_ptr response = handler(server_http_req{ + get_params(req), + get_headers(req), + req.path, + req.body, + req.is_connection_closed + }); + process_handler_response(response, res); + }); +} + +void server_http_context::post(const std::string & path, server_http_context::handler_t handler) { + pimpl->srv->Post(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) { + server_http_res_ptr response = handler(server_http_req{ + get_params(req), + get_headers(req), + req.path, + req.body, + req.is_connection_closed + }); + process_handler_response(response, res); + }); +} + diff --git a/tools/server/server-http.h b/tools/server/server-http.h new file mode 100644 index 0000000000000..dc6ca92fd8751 --- /dev/null +++ b/tools/server/server-http.h @@ -0,0 +1,77 @@ +#pragma once + +#include "utils.hpp" +#include "common.h" + +#include +#include +#include +#include + +// generator-like API for HTTP response generation +// this object response with one of the 2 modes: +// 1) normal response: `data` contains the full response body +// 2) streaming response: each call to next(output) generates the next chunk +// when next(output) returns false, no more data after the current chunk +// note: some chunks can be empty, in which case no data is sent for that chunk +struct server_http_res { + std::string content_type = "application/json; charset=utf-8"; + int status = 200; + std::string data; + std::map headers; + + // TODO: move this to a virtual function once we have proper polymorphism support + std::function next = nullptr; + bool is_stream() const { + return next != nullptr; + } + + virtual ~server_http_res() = default; +}; + +// unique pointer, used by set_chunked_content_provider +// httplib requires the stream provider to be stored in heap +using server_http_res_ptr = std::unique_ptr; + +struct server_http_req { + std::map params; // path_params + query_params + std::map headers; // reserved for future use + std::string path; // reserved for future use + std::string body; + const std::function & should_stop; + + std::string get_param(const std::string & key, const std::string & def = "") const { + auto it = params.find(key); + if (it != params.end()) { + return it->second; + } + return def; + } +}; + +struct server_http_context { + class Impl; + std::unique_ptr pimpl; + + std::thread thread; // server thread + std::atomic is_ready = false; + + std::string path_prefix; + std::string hostname; + int port; + + server_http_context(); + ~server_http_context(); + + bool init(const common_params & params); + bool start(); + void stop(); + + // note: the handler should never throw exceptions + using handler_t = std::function; + void get(const std::string &, handler_t); + void post(const std::string &, handler_t); + + // for debugging + std::string listening_address; +}; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 535d2c450e21e..1c9e9a58d7daf 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1,5 +1,15 @@ #include "chat.h" #include "utils.hpp" +#include "server-http.h" + +// fix problem with std::min and std::max +#if defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +# define NOMINMAX +#endif +#include +#endif #include "arg.h" #include "common.h" @@ -10,13 +20,6 @@ #include "speculative.h" #include "mtmd.h" -// mime type for sending response -#define MIMETYPE_JSON "application/json; charset=utf-8" - -// auto generated files (see README.md for details) -#include "index.html.gz.hpp" -#include "loading.html.hpp" - #include #include #include @@ -25,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -1671,7 +1675,7 @@ struct server_slot { server_prompt prompt; void prompt_save(server_prompt_cache & prompt_cache) const { - assert(prompt.data.size() == 0); + GGML_ASSERT(prompt.data.size() == 0); const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0); @@ -2382,6 +2386,7 @@ struct server_context { llama_batch_free(batch); } + // load the model and initialize llama_context bool load_model(const common_params & params) { SRV_INF("loading model '%s'\n", params.model.path.c_str()); @@ -2500,6 +2505,7 @@ struct server_context { return true; } + // initialize slots and server-related data void init() { SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel); @@ -2599,6 +2605,11 @@ struct server_context { /* allow_audio */ mctx ? mtmd_support_audio (mctx) : false, /* enable_thinking */ enable_thinking, }; + + // print sample chat example to make it clear which template is used + LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, + common_chat_templates_source(chat_templates.get()), + common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str()); } server_slot * get_slot_by_id(int id) { @@ -4323,6 +4334,7 @@ struct server_context { } }; + // generator-like API for server responses, support pooling connection state and aggregating results struct server_response_reader { std::unordered_set id_tasks; @@ -4421,281 +4433,46 @@ struct server_response_reader { } }; -static void log_server_request(const httplib::Request & req, const httplib::Response & res) { - // skip GH copilot requests when using default port - if (req.path == "/v1/health") { - return; - } - - // reminder: this function is not covered by httplib's exception handler; if someone does more complicated stuff, think about wrapping it in try-catch - - SRV_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status); - - SRV_DBG("request: %s\n", req.body.c_str()); - SRV_DBG("response: %s\n", res.body.c_str()); -} - -static void res_err(httplib::Response & res, const json & error_data) { - json final_response {{"error", error_data}}; - res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON); - res.status = json_value(error_data, "code", 500); -} - -static void res_ok(httplib::Response & res, const json & data) { - res.set_content(safe_json_to_str(data), MIMETYPE_JSON); - res.status = 200; -} - -std::function shutdown_handler; -std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; - -inline void signal_handler(int signal) { - if (is_terminating.test_and_set()) { - // in case it hangs, we can force terminate the server by hitting Ctrl+C twice - // this is for better developer experience, we can remove when the server is stable enough - fprintf(stderr, "Received second interrupt, terminating immediately.\n"); - exit(1); - } - - shutdown_handler(signal); -} - -int main(int argc, char ** argv) { - // own arguments required by this example - common_params params; - - if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) { - return 1; - } - - // TODO: should we have a separate n_parallel parameter for the server? - // https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177 - // TODO: this is a common configuration that is suitable for most local use cases - // however, overriding the parameters is a bit confusing - figure out something more intuitive - if (params.n_parallel == 1 && params.kv_unified == false && !params.has_speculative()) { - LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true (add -kvu to disable this)\n", __func__); - - params.n_parallel = 4; - params.kv_unified = true; - } - - common_init(); - - // struct that contains llama context and inference - server_context ctx_server; - - llama_backend_init(); - llama_numa_init(params.numa); - - LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency()); - LOG_INF("\n"); - LOG_INF("%s\n", common_params_get_system_info(params).c_str()); - LOG_INF("\n"); - - std::unique_ptr svr; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if (params.ssl_file_key != "" && params.ssl_file_cert != "") { - LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str()); - svr.reset( - new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str()) - ); - } else { - LOG_INF("Running without SSL\n"); - svr.reset(new httplib::Server()); +// generator-like API for HTTP response generation +struct server_res_generator : server_http_res { + server_response_reader rd; + server_res_generator(server_context & ctx_server_) : rd(ctx_server_) {} + void ok(const json & response_data) { + status = 200; + data = safe_json_to_str(response_data); } -#else - if (params.ssl_file_key != "" && params.ssl_file_cert != "") { - LOG_ERR("Server is built without SSL support\n"); - return 1; - } - svr.reset(new httplib::Server()); -#endif - - std::atomic state{SERVER_STATE_LOADING_MODEL}; - - svr->set_default_headers({{"Server", "llama.cpp"}}); - svr->set_logger(log_server_request); - svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) { - std::string message; - try { - std::rethrow_exception(ep); - } catch (const std::exception & e) { - message = e.what(); - } catch (...) { - message = "Unknown Exception"; - } - - try { - json formatted_error = format_error_response(message, ERROR_TYPE_SERVER); - LOG_WRN("got exception: %s\n", formatted_error.dump().c_str()); - res_err(res, formatted_error); - } catch (const std::exception & e) { - LOG_ERR("got another exception: %s | while hanlding exception: %s\n", e.what(), message.c_str()); - } - }); - - svr->set_error_handler([](const httplib::Request &, httplib::Response & res) { - if (res.status == 404) { - res_err(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND)); - } - // for other error codes, we skip processing here because it's already done by res_err() - }); - - // set timeouts and change hostname and port - svr->set_read_timeout (params.timeout_read); - svr->set_write_timeout(params.timeout_write); - - std::unordered_map log_data; - - log_data["hostname"] = params.hostname; - log_data["port"] = std::to_string(params.port); - - if (params.api_keys.size() == 1) { - auto key = params.api_keys[0]; - log_data["api_key"] = "api_key: ****" + key.substr(std::max((int)(key.length() - 4), 0)); - } else if (params.api_keys.size() > 1) { - log_data["api_key"] = "api_key: " + std::to_string(params.api_keys.size()) + " keys loaded"; + void error(const json & error_data) { + status = json_value(error_data, "code", 500); + data = safe_json_to_str({{ "error", error_data }}); } +}; - // Necessary similarity of prompt for slot selection - ctx_server.slot_prompt_similarity = params.slot_prompt_similarity; - - // - // Middlewares - // - - auto middleware_validate_api_key = [¶ms](const httplib::Request & req, httplib::Response & res) { - static const std::unordered_set public_endpoints = { - "/health", - "/v1/health", - "/models", - "/v1/models", - "/api/tags" - }; - - // If API key is not set, skip validation - if (params.api_keys.empty()) { - return true; - } - - // If path is public or is static file, skip validation - if (public_endpoints.find(req.path) != public_endpoints.end() || req.path == "/") { - return true; - } - - // Check for API key in the header - auto auth_header = req.get_header_value("Authorization"); - - std::string prefix = "Bearer "; - if (auth_header.substr(0, prefix.size()) == prefix) { - std::string received_api_key = auth_header.substr(prefix.size()); - if (std::find(params.api_keys.begin(), params.api_keys.end(), received_api_key) != params.api_keys.end()) { - return true; // API key is valid - } - } - - // API key is invalid or not provided - res_err(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION)); - - LOG_WRN("Unauthorized: Invalid API Key\n"); - - return false; - }; - - auto middleware_server_state = [&state](const httplib::Request & req, httplib::Response & res) { - server_state current_state = state.load(); - if (current_state == SERVER_STATE_LOADING_MODEL) { - auto tmp = string_split(req.path, '.'); - if (req.path == "/" || tmp.back() == "html") { - res.set_content(reinterpret_cast(loading_html), loading_html_len, "text/html; charset=utf-8"); - res.status = 503; - } else if (req.path == "/models" || req.path == "/v1/models" || req.path == "/api/tags") { - // allow the models endpoint to be accessed during loading - return true; - } else { - res_err(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE)); - } - return false; - } - return true; - }; - - // register server middlewares - svr->set_pre_routing_handler([&middleware_validate_api_key, &middleware_server_state](const httplib::Request & req, httplib::Response & res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - // If this is OPTIONS request, skip validation because browsers don't include Authorization header - if (req.method == "OPTIONS") { - res.set_header("Access-Control-Allow-Credentials", "true"); - res.set_header("Access-Control-Allow-Methods", "GET, POST"); - res.set_header("Access-Control-Allow-Headers", "*"); - res.set_content("", "text/html"); // blank response, no data - return httplib::Server::HandlerResponse::Handled; // skip further processing - } - if (!middleware_server_state(req, res)) { - return httplib::Server::HandlerResponse::Handled; - } - if (!middleware_validate_api_key(req, res)) { - return httplib::Server::HandlerResponse::Handled; - } - return httplib::Server::HandlerResponse::Unhandled; - }); +struct server_routes { + const common_params & params; + server_context & ctx_server; + server_http_context & ctx_http; // for reading is_ready + server_routes(const common_params & params, server_context & ctx_server, server_http_context & ctx_http) + : params(params), ctx_server(ctx_server), ctx_http(ctx_http) {} - // - // Route handlers (or controllers) - // +public: + // handlers using lambda function, so that they can capture `this` without `std::bind` - const auto handle_health = [&](const httplib::Request &, httplib::Response & res) { + server_http_context::handler_t get_health = [this](const server_http_req &) { // error and loading states are handled by middleware - json health = {{"status", "ok"}}; - res_ok(res, health); - }; - - const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) { - if (!params.endpoint_slots) { - res_err(res, format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED)); - return; - } - - // request slots data using task queue - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_METRICS); - task.id = task_id; - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task), true); // high-priority task - } - - // get the result - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res_err(res, result->to_json()); - return; - } - - // TODO: get rid of this dynamic_cast - auto res_task = dynamic_cast(result.get()); - GGML_ASSERT(res_task != nullptr); - - // optionally return "fail_on_no_slot" error - if (req.has_param("fail_on_no_slot")) { - if (res_task->n_idle_slots == 0) { - res_err(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE)); - return; - } - } - - res_ok(res, res_task->slots_data); + auto res = std::make_unique(ctx_server); + res->ok({{"status", "ok"}}); + return res; }; - const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) { + server_http_context::handler_t get_metrics = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); if (!params.endpoint_metrics) { - res_err(res, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED)); - return; + res->error(format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED)); + return res; } // request slots data using task queue + // TODO: use server_response_reader int task_id = ctx_server.queue_tasks.get_new_id(); { server_task task(SERVER_TASK_TYPE_METRICS); @@ -4709,8 +4486,8 @@ int main(int argc, char ** argv) { ctx_server.queue_results.remove_waiting_task_id(task_id); if (result->is_error()) { - res_err(res, result->to_json()); - return; + res->error(result->to_json()); + return res; } // TODO: get rid of this dynamic_cast @@ -4784,130 +4561,86 @@ int main(int argc, char ** argv) { } } - res.set_header("Process-Start-Time-Unix", std::to_string(res_task->t_start)); - - res.set_content(prometheus.str(), "text/plain; version=0.0.4"); - res.status = 200; // HTTP OK - }; - - const auto handle_slots_save = [&ctx_server, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { - json request_data = json::parse(req.body); - std::string filename = request_data.at("filename"); - if (!fs_validate_filename(filename)) { - res_err(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); - return; - } - std::string filepath = params.slot_save_path + filename; - - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_SLOT_SAVE); - task.id = task_id; - task.slot_action.slot_id = id_slot; - task.slot_action.filename = filename; - task.slot_action.filepath = filepath; - - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); - } - - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res_err(res, result->to_json()); - return; - } - - res_ok(res, result->to_json()); + res->headers["Process-Start-Time-Unix"] = std::to_string(res_task->t_start); + res->content_type = "text/plain; version=0.0.4"; + res->ok(prometheus.str()); + return res; }; - const auto handle_slots_restore = [&ctx_server, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { - json request_data = json::parse(req.body); - std::string filename = request_data.at("filename"); - if (!fs_validate_filename(filename)) { - res_err(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); - return; + server_http_context::handler_t get_slots = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + if (!params.endpoint_slots) { + res->error(format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED)); + return res; } - std::string filepath = params.slot_save_path + filename; + // request slots data using task queue int task_id = ctx_server.queue_tasks.get_new_id(); { - server_task task(SERVER_TASK_TYPE_SLOT_RESTORE); + server_task task(SERVER_TASK_TYPE_METRICS); task.id = task_id; - task.slot_action.slot_id = id_slot; - task.slot_action.filename = filename; - task.slot_action.filepath = filepath; - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); + ctx_server.queue_tasks.post(std::move(task), true); // high-priority task } + // get the result server_task_result_ptr result = ctx_server.queue_results.recv(task_id); ctx_server.queue_results.remove_waiting_task_id(task_id); if (result->is_error()) { - res_err(res, result->to_json()); - return; - } - - GGML_ASSERT(dynamic_cast(result.get()) != nullptr); - res_ok(res, result->to_json()); - }; - - const auto handle_slots_erase = [&ctx_server](const httplib::Request & /* req */, httplib::Response & res, int id_slot) { - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_SLOT_ERASE); - task.id = task_id; - task.slot_action.slot_id = id_slot; - - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); + res->error(result->to_json()); + return res; } - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); + // TODO: get rid of this dynamic_cast + auto res_task = dynamic_cast(result.get()); + GGML_ASSERT(res_task != nullptr); - if (result->is_error()) { - res_err(res, result->to_json()); - return; + // optionally return "fail_on_no_slot" error + if (!req.get_param("fail_on_no_slot").empty()) { + if (res_task->n_idle_slots == 0) { + res->error(format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE)); + return res; + } } - GGML_ASSERT(dynamic_cast(result.get()) != nullptr); - res_ok(res, result->to_json()); + res->ok(res_task->slots_data); + return res; }; - const auto handle_slots_action = [¶ms, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { + server_http_context::handler_t post_slots = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); if (params.slot_save_path.empty()) { - res_err(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED)); - return; + res->error(format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED)); + return res; } - std::string id_slot_str = req.path_params.at("id_slot"); + std::string id_slot_str = req.get_param("id_slot"); int id_slot; try { id_slot = std::stoi(id_slot_str); } catch (const std::exception &) { - res_err(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST)); - return; + res->error(format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST)); + return res; } - std::string action = req.get_param_value("action"); + std::string action = req.get_param("action"); if (action == "save") { - handle_slots_save(req, res, id_slot); + return handle_slots_save(req, id_slot); } else if (action == "restore") { - handle_slots_restore(req, res, id_slot); + return handle_slots_restore(req, id_slot); } else if (action == "erase") { - handle_slots_erase(req, res, id_slot); + return handle_slots_erase(req, id_slot); } else { - res_err(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST)); + res->error(format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST)); + return res; } }; - const auto handle_props = [¶ms, &ctx_server](const httplib::Request &, httplib::Response & res) { + server_http_context::handler_t get_props = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); json default_generation_settings_for_props; { @@ -4946,23 +4679,24 @@ int main(int argc, char ** argv) { } } - res_ok(res, data); + res->ok(data); + return res; }; - const auto handle_props_change = [&ctx_server](const httplib::Request & req, httplib::Response & res) { - if (!ctx_server.params_base.endpoint_props) { - res_err(res, format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED)); - return; + server_http_context::handler_t post_props = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); + if (!params.endpoint_props) { + res->error(format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED)); + return res; } - - json data = json::parse(req.body); - // update any props here - res_ok(res, {{ "success", true }}); + res->ok({{ "success", true }}); + return res; }; - const auto handle_api_show = [&ctx_server](const httplib::Request &, httplib::Response & res) { + server_http_context::handler_t get_api_show = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); bool has_mtmd = ctx_server.mctx != nullptr; json data = { { @@ -4988,193 +4722,12 @@ int main(int argc, char ** argv) { {"capabilities", has_mtmd ? json({"completion","multimodal"}) : json({"completion"})} }; - res_ok(res, data); + res->ok(data); + return res; }; - // handle completion-like requests (completion, chat, infill) - // we can optionally provide a custom format for partial results and final results - const auto handle_completions_impl = [&ctx_server]( - server_task_type type, - json & data, - const std::vector & files, - const std::function & is_connection_closed, - httplib::Response & res, - oaicompat_type oaicompat) -> void { - GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); - - auto completion_id = gen_chatcmplid(); - // need to store the reader as a pointer, so that it won't be destroyed when the handle returns - // use shared_ptr as it's shared between the chunked_content_provider() and on_complete() - const auto rd = std::make_shared(ctx_server); - - try { - std::vector tasks; - - const auto & prompt = data.at("prompt"); - // TODO: this log can become very long, put it behind a flag or think about a more compact format - //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get().c_str() : prompt.dump(2).c_str()); - - // process prompt - std::vector inputs; - - if (oaicompat && ctx_server.mctx != nullptr) { - // This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below. - inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get(), files)); - } else { - // Everything else, including multimodal completions. - inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); - } - tasks.reserve(inputs.size()); - for (size_t i = 0; i < inputs.size(); i++) { - server_task task = server_task(type); - - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; - - task.tokens = std::move(inputs[i]); - task.params = server_task::params_from_json_cmpl( - ctx_server.ctx, - ctx_server.params_base, - data); - task.id_slot = json_value(data, "id_slot", -1); - - // OAI-compat - task.params.oaicompat = oaicompat; - task.params.oaicompat_cmpl_id = completion_id; - // oaicompat_model is already populated by params_from_json_cmpl - - tasks.push_back(std::move(task)); - } - - rd->post_tasks(std::move(tasks)); - } catch (const std::exception & e) { - res_err(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); - return; - } - - bool stream = json_value(data, "stream", false); - - if (!stream) { - // non-stream, wait for the results - auto all_results = rd->wait_for_all(is_connection_closed); - if (all_results.is_terminated) { - return; // connection is closed - } else if (all_results.error) { - res_err(res, all_results.error->to_json()); - return; - } else { - json arr = json::array(); - for (auto & res : all_results.results) { - GGML_ASSERT(dynamic_cast(res.get()) != nullptr); - arr.push_back(res->to_json()); - } - // if single request, return single object instead of array - res_ok(res, arr.size() == 1 ? arr[0] : arr); - } - - } else { - // in streaming mode, the first error must be treated as non-stream response - // this is to match the OAI API behavior - // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309 - server_task_result_ptr first_result = rd->next(is_connection_closed); - if (first_result == nullptr) { - return; // connection is closed - } else if (first_result->is_error()) { - res_err(res, first_result->to_json()); - return; - } else { - GGML_ASSERT( - dynamic_cast(first_result.get()) != nullptr - || dynamic_cast(first_result.get()) != nullptr - ); - } - - // next responses are streamed - json first_result_json = first_result->to_json(); - const auto chunked_content_provider = [first_result_json, rd, oaicompat](size_t, httplib::DataSink & sink) mutable -> bool { - // flush the first result as it's not an error - if (!first_result_json.empty()) { - if (!server_sent_event(sink, first_result_json)) { - sink.done(); - return false; // sending failed, go to on_complete() - } - first_result_json.clear(); // mark as sent - } - - // receive subsequent results - auto result = rd->next([&sink]{ return !sink.is_writable(); }); - if (result == nullptr) { - sink.done(); - return false; // connection is closed, go to on_complete() - } - - // send the results - json res_json = result->to_json(); - bool ok = false; - if (result->is_error()) { - ok = server_sent_event(sink, json {{ "error", result->to_json() }}); - sink.done(); - return false; // go to on_complete() - } else { - GGML_ASSERT( - dynamic_cast(result.get()) != nullptr - || dynamic_cast(result.get()) != nullptr - ); - ok = server_sent_event(sink, res_json); - } - - if (!ok) { - sink.done(); - return false; // sending failed, go to on_complete() - } - - // check if there is more data - if (!rd->has_next()) { - if (oaicompat != OAICOMPAT_TYPE_NONE) { - static const std::string ev_done = "data: [DONE]\n\n"; - sink.write(ev_done.data(), ev_done.size()); - } - sink.done(); - return false; // no more data, go to on_complete() - } - - // has next data, continue - return true; - }; - - auto on_complete = [rd](bool) { - rd->stop(); - }; - - res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); - } - }; - - const auto handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) { - json data = json::parse(req.body); - std::vector files; // dummy - handle_completions_impl( - SERVER_TASK_TYPE_COMPLETION, - data, - files, - req.is_connection_closed, - res, - OAICOMPAT_TYPE_NONE); - }; - - const auto handle_completions_oai = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) { - json data = oaicompat_completion_params_parse(json::parse(req.body)); - std::vector files; // dummy - handle_completions_impl( - SERVER_TASK_TYPE_COMPLETION, - data, - files, - req.is_connection_closed, - res, - OAICOMPAT_TYPE_COMPLETION); - }; - - const auto handle_infill = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { + server_http_context::handler_t post_infill = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); // check model compatibility std::string err; if (llama_vocab_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) { @@ -5187,43 +4740,42 @@ int main(int argc, char ** argv) { err += "middle token is missing. "; } if (!err.empty()) { - res_err(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED)); - return; + res->error(format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED)); + return res; } - json data = json::parse(req.body); - // validate input + json data = json::parse(req.body); if (data.contains("prompt") && !data.at("prompt").is_string()) { // prompt is optional - res_err(res, format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST)); + res->error(format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST)); } if (!data.contains("input_prefix")) { - res_err(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST)); + res->error(format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST)); } if (!data.contains("input_suffix")) { - res_err(res, format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST)); + res->error(format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST)); } if (data.contains("input_extra") && !data.at("input_extra").is_array()) { // input_extra is optional - res_err(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST)); - return; + res->error(format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST)); + return res; } json input_extra = json_value(data, "input_extra", json::array()); for (const auto & chunk : input_extra) { // { "text": string, "filename": string } if (!chunk.contains("text") || !chunk.at("text").is_string()) { - res_err(res, format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST)); - return; + res->error(format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST)); + return res; } // filename is optional if (chunk.contains("filename") && !chunk.at("filename").is_string()) { - res_err(res, format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST)); - return; + res->error(format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST)); + return res; } } data["input_extra"] = input_extra; // default to empty array if it's not exist @@ -5244,49 +4796,69 @@ int main(int argc, char ** argv) { ); std::vector files; // dummy - handle_completions_impl( + return handle_completions_impl( SERVER_TASK_TYPE_INFILL, data, files, - req.is_connection_closed, - res, + req.should_stop, OAICOMPAT_TYPE_NONE); // infill is not OAI compatible }; - const auto handle_chat_completions = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { - LOG_DBG("request: %s\n", req.body.c_str()); + server_http_context::handler_t post_completions = [this](const server_http_req & req) { + std::vector files; // dummy + const json body = json::parse(req.body); + return handle_completions_impl( + SERVER_TASK_TYPE_COMPLETION, + body, + files, + req.should_stop, + OAICOMPAT_TYPE_NONE); + }; + + server_http_context::handler_t post_completions_oai = [this](const server_http_req & req) { + std::vector files; // dummy + const json body = json::parse(req.body); + return handle_completions_impl( + SERVER_TASK_TYPE_COMPLETION, + body, + files, + req.should_stop, + OAICOMPAT_TYPE_COMPLETION); + }; - auto body = json::parse(req.body); + server_http_context::handler_t post_chat_completions = [this](const server_http_req & req) { std::vector files; - json data = oaicompat_chat_params_parse( + json body = json::parse(req.body); + json body_parsed = oaicompat_chat_params_parse( body, ctx_server.oai_parser_opt, files); - - handle_completions_impl( + return handle_completions_impl( SERVER_TASK_TYPE_COMPLETION, - data, + body_parsed, files, - req.is_connection_closed, - res, + req.should_stop, OAICOMPAT_TYPE_CHAT); }; // same with handle_chat_completions, but without inference part - const auto handle_apply_template = [&ctx_server](const httplib::Request & req, httplib::Response & res) { - auto body = json::parse(req.body); + server_http_context::handler_t post_apply_template = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); std::vector files; // dummy, unused + json body = json::parse(req.body); json data = oaicompat_chat_params_parse( body, ctx_server.oai_parser_opt, files); - res_ok(res, {{ "prompt", std::move(data.at("prompt")) }}); + res->ok({{ "prompt", std::move(data.at("prompt")) }}); + return res; }; - const auto handle_models = [¶ms, &ctx_server, &state](const httplib::Request &, httplib::Response & res) { - server_state current_state = state.load(); + server_http_context::handler_t get_models = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); + bool is_model_ready = ctx_http.is_ready.load(); json model_meta = nullptr; - if (current_state == SERVER_STATE_READY) { + if (is_model_ready) { model_meta = ctx_server.model_meta(); } bool has_mtmd = ctx_server.mctx != nullptr; @@ -5325,12 +4897,13 @@ int main(int argc, char ** argv) { }} }; - res_ok(res, models); + res->ok(models); + return res; }; - const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { + server_http_context::handler_t post_tokenize = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); const json body = json::parse(req.body); - json tokens_response = json::array(); if (body.count("content") != 0) { const bool add_special = json_value(body, "add_special", false); @@ -5366,10 +4939,12 @@ int main(int argc, char ** argv) { } const json data = format_tokenizer_response(tokens_response); - res_ok(res, data); + res->ok(data); + return res; }; - const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { + server_http_context::handler_t post_detokenize = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); const json body = json::parse(req.body); std::string content; @@ -5379,118 +4954,23 @@ int main(int argc, char ** argv) { } const json data = format_detokenized_response(content); - res_ok(res, data); - }; - - const auto handle_embeddings_impl = [&ctx_server](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) { - if (!ctx_server.params_base.embedding) { - res_err(res, format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); - return; - } - - if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { - res_err(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST)); - return; - } - - const json body = json::parse(req.body); - - // for the shape of input/content, see tokenize_input_prompts() - json prompt; - if (body.count("input") != 0) { - prompt = body.at("input"); - } else if (body.contains("content")) { - oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible - prompt = body.at("content"); - } else { - res_err(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); - return; - } - - bool use_base64 = false; - if (body.count("encoding_format") != 0) { - const std::string& format = body.at("encoding_format"); - if (format == "base64") { - use_base64 = true; - } else if (format != "float") { - res_err(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST)); - return; - } - } - - auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); - for (const auto & tokens : tokenized_prompts) { - // this check is necessary for models that do not add BOS token to the input - if (tokens.empty()) { - res_err(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST)); - return; - } - } - - int embd_normalize = 2; // default to Euclidean/L2 norm - if (body.count("embd_normalize") != 0) { - embd_normalize = body.at("embd_normalize"); - if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { - SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", llama_pooling_type(ctx_server.ctx)); - } - } - - // create and queue the task - json responses = json::array(); - server_response_reader rd(ctx_server); - { - std::vector tasks; - for (size_t i = 0; i < tokenized_prompts.size(); i++) { - server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); - - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; - task.tokens = std::move(tokenized_prompts[i]); - - // OAI-compat - task.params.oaicompat = oaicompat; - task.params.embd_normalize = embd_normalize; - - tasks.push_back(std::move(task)); - } - rd.post_tasks(std::move(tasks)); - } - - // wait for the results - auto all_results = rd.wait_for_all(req.is_connection_closed); - - // collect results - if (all_results.is_terminated) { - return; // connection is closed - } else if (all_results.error) { - res_err(res, all_results.error->to_json()); - return; - } else { - for (auto & res : all_results.results) { - GGML_ASSERT(dynamic_cast(res.get()) != nullptr); - responses.push_back(res->to_json()); - } - } - - // write JSON response - json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING - ? format_embeddings_response_oaicompat(body, responses, use_base64) - : json(responses); - res_ok(res, root); + res->ok(data); + return res; }; - const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) { - handle_embeddings_impl(req, res, OAICOMPAT_TYPE_NONE); + server_http_context::handler_t post_embeddings = [this](const server_http_req & req) { + return handle_embeddings_impl(req, OAICOMPAT_TYPE_NONE); }; - const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) { - handle_embeddings_impl(req, res, OAICOMPAT_TYPE_EMBEDDING); + server_http_context::handler_t post_embeddings_oai = [this](const server_http_req & req) { + return handle_embeddings_impl(req, OAICOMPAT_TYPE_EMBEDDING); }; - const auto handle_rerank = [&ctx_server](const httplib::Request & req, httplib::Response & res) { + server_http_context::handler_t post_rerank = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) { - res_err(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); - return; + res->error(format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); + return res; } const json body = json::parse(req.body); @@ -5504,19 +4984,19 @@ int main(int argc, char ** argv) { if (body.count("query") == 1) { query = body.at("query"); if (!query.is_string()) { - res_err(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST)); - return; + res->error(format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST)); + return res; } } else { - res_err(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST)); - return; + res->error(format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + return res; } std::vector documents = json_value(body, "documents", json_value(body, "texts", std::vector())); if (documents.empty()) { - res_err(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST)); - return; + res->error(format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST)); + return res; } int top_n = json_value(body, "top_n", (int)documents.size()); @@ -5539,14 +5019,14 @@ int main(int argc, char ** argv) { } // wait for the results - auto all_results = rd.wait_for_all(req.is_connection_closed); + auto all_results = rd.wait_for_all(req.should_stop); // collect results if (all_results.is_terminated) { - return; // connection is closed + return res; // connection is closed } else if (all_results.error) { - res_err(res, all_results.error->to_json()); - return; + res->error(all_results.error->to_json()); + return res; } else { for (auto & res : all_results.results) { GGML_ASSERT(dynamic_cast(res.get()) != nullptr); @@ -5562,10 +5042,12 @@ int main(int argc, char ** argv) { documents, top_n); - res_ok(res, root); + res->ok(root); + return res; }; - const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) { + server_http_context::handler_t get_lora_adapters = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); json result = json::array(); const auto & loras = ctx_server.params_base.lora_adapters; for (size_t i = 0; i < loras.size(); ++i) { @@ -5591,15 +5073,16 @@ int main(int argc, char ** argv) { } result.push_back(std::move(entry)); } - res_ok(res, result); - res.status = 200; // HTTP OK + res->ok(result); + return res; }; - const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) { + server_http_context::handler_t post_lora_adapters = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); const json body = json::parse(req.body); if (!body.is_array()) { - res_err(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST)); - return; + res->error(format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST)); + return res; } int task_id = ctx_server.queue_tasks.get_new_id(); @@ -5616,152 +5099,525 @@ int main(int argc, char ** argv) { ctx_server.queue_results.remove_waiting_task_id(task_id); if (result->is_error()) { - res_err(res, result->to_json()); - return; + res->error(result->to_json()); + return res; } GGML_ASSERT(dynamic_cast(result.get()) != nullptr); - res_ok(res, result->to_json()); + res->ok(result->to_json()); + return res; }; - // - // Router - // +private: + std::unique_ptr handle_completions_impl( + server_task_type type, + const json & data, + const std::vector & files, + const std::function & should_stop, + oaicompat_type oaicompat) { + GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); + + auto res = std::make_unique(ctx_server); + auto completion_id = gen_chatcmplid(); + auto & rd = res->rd; - if (!params.webui) { - LOG_INF("Web UI is disabled\n"); - } else { - // register static assets routes - if (!params.public_path.empty()) { - // Set the base directory for serving static files - bool is_found = svr->set_mount_point(params.api_prefix + "/", params.public_path); - if (!is_found) { - LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str()); - return 1; + try { + std::vector tasks; + + const auto & prompt = data.at("prompt"); + // TODO: this log can become very long, put it behind a flag or think about a more compact format + //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get().c_str() : prompt.dump(2).c_str()); + + // process prompt + std::vector inputs; + + if (oaicompat && ctx_server.mctx != nullptr) { + // This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below. + inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get(), files)); + } else { + // Everything else, including multimodal completions. + inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); } + tasks.reserve(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + server_task task = server_task(type); + + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + + task.tokens = std::move(inputs[i]); + task.params = server_task::params_from_json_cmpl( + ctx_server.ctx, + ctx_server.params_base, + data); + task.id_slot = json_value(data, "id_slot", -1); + + // OAI-compat + task.params.oaicompat = oaicompat; + task.params.oaicompat_cmpl_id = completion_id; + // oaicompat_model is already populated by params_from_json_cmpl + + tasks.push_back(std::move(task)); + } + + rd.post_tasks(std::move(tasks)); + } catch (const std::exception & e) { + res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + bool stream = json_value(data, "stream", false); + + if (!stream) { + // non-stream, wait for the results + auto all_results = rd.wait_for_all(should_stop); + if (all_results.is_terminated) { + return res; // connection is closed + } else if (all_results.error) { + res->error(all_results.error->to_json()); + return res; + } else { + json arr = json::array(); + for (auto & res : all_results.results) { + GGML_ASSERT(dynamic_cast(res.get()) != nullptr); + arr.push_back(res->to_json()); + } + // if single request, return single object instead of array + res->ok(arr.size() == 1 ? arr[0] : arr); + } + } else { - // using embedded static index.html - svr->Get(params.api_prefix + "/", [](const httplib::Request & req, httplib::Response & res) { - if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) { - res.set_content("Error: gzip is not supported by this browser", "text/plain"); + // in streaming mode, the first error must be treated as non-stream response + // this is to match the OAI API behavior + // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309 + server_task_result_ptr first_result = rd.next(should_stop); + if (first_result == nullptr) { + return res; // connection is closed + } else if (first_result->is_error()) { + res->error(first_result->to_json()); + return res; + } else { + GGML_ASSERT( + dynamic_cast(first_result.get()) != nullptr + || dynamic_cast(first_result.get()) != nullptr + ); + } + + // next responses are streamed + res->data = format_sse(first_result->to_json()); // to be sent immediately + res->status = 200; + res->content_type = "text/event-stream"; + res->next = [res_this = res.get(), oaicompat, &should_stop](std::string & output) -> bool { + if (should_stop()) { + SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); + return false; // should_stop condition met + } + + if (!res_this->data.empty()) { + // flush the first chunk + output = std::move(res_this->data); + res_this->data.clear(); + return true; + } + + server_response_reader & rd = res_this->rd; + + // check if there is more data + if (!rd.has_next()) { + if (oaicompat != OAICOMPAT_TYPE_NONE) { + output = "data: [DONE]\n\n"; + } else { + output = ""; + } + SRV_DBG("%s", "all results received, terminating stream\n"); + return false; // no more data, terminate + } + + // receive subsequent results + auto result = rd.next(should_stop); + if (result == nullptr) { + SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); + return false; // should_stop condition met + } + + // send the results + json res_json = result->to_json(); + if (result->is_error()) { + output = format_sse(json {{ "error", res_json }}); + SRV_DBG("%s", "error received during streaming, terminating stream\n"); + return false; // terminate on error } else { - res.set_header("Content-Encoding", "gzip"); - // COEP and COOP headers, required by pyodide (python interpreter) - res.set_header("Cross-Origin-Embedder-Policy", "require-corp"); - res.set_header("Cross-Origin-Opener-Policy", "same-origin"); - res.set_content(reinterpret_cast(index_html_gz), index_html_gz_len, "text/html; charset=utf-8"); + GGML_ASSERT( + dynamic_cast(result.get()) != nullptr + || dynamic_cast(result.get()) != nullptr + ); + output = format_sse(res_json); } - return false; - }); + + // has next data, continue + return true; + }; + } + + return res; + } + + std::unique_ptr handle_slots_save(const server_http_req & req, int id_slot) { + auto res = std::make_unique(ctx_server); + const json request_data = json::parse(req.body); + std::string filename = request_data.at("filename"); + if (!fs_validate_filename(filename)) { + res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + std::string filepath = params.slot_save_path + filename; + + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_SLOT_SAVE); + task.id = task_id; + task.slot_action.slot_id = id_slot; + task.slot_action.filename = filename; + task.slot_action.filepath = filepath; + + // TODO: use server_response_reader + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task)); + } + + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res->error(result->to_json()); + return res; + } + + res->ok(result->to_json()); + return res; + } + + std::unique_ptr handle_slots_restore(const server_http_req & req, int id_slot) { + auto res = std::make_unique(ctx_server); + const json request_data = json::parse(req.body); + std::string filename = request_data.at("filename"); + if (!fs_validate_filename(filename)) { + res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + std::string filepath = params.slot_save_path + filename; + + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_SLOT_RESTORE); + task.id = task_id; + task.slot_action.slot_id = id_slot; + task.slot_action.filename = filename; + task.slot_action.filepath = filepath; + + // TODO: use server_response_reader + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task)); + } + + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res->error(result->to_json()); + return res; + } + + GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + res->ok(result->to_json()); + return res; + } + + std::unique_ptr handle_slots_erase(const server_http_req &, int id_slot) { + auto res = std::make_unique(ctx_server); + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_SLOT_ERASE); + task.id = task_id; + task.slot_action.slot_id = id_slot; + + // TODO: use server_response_reader + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task)); + } + + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res->error(result->to_json()); + return res; + } + + GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + res->ok(result->to_json()); + return res; + } + + std::unique_ptr handle_embeddings_impl(const server_http_req & req, oaicompat_type oaicompat) { + auto res = std::make_unique(ctx_server); + if (!ctx_server.params_base.embedding) { + res->error(format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); + return res; + } + + if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { + res->error(format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + const json body = json::parse(req.body); + + // for the shape of input/content, see tokenize_input_prompts() + json prompt; + if (body.count("input") != 0) { + prompt = body.at("input"); + } else if (body.contains("content")) { + oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible + prompt = body.at("content"); + } else { + res->error(format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + bool use_base64 = false; + if (body.count("encoding_format") != 0) { + const std::string& format = body.at("encoding_format"); + if (format == "base64") { + use_base64 = true; + } else if (format != "float") { + res->error(format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + } + + auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); + for (const auto & tokens : tokenized_prompts) { + // this check is necessary for models that do not add BOS token to the input + if (tokens.empty()) { + res->error(format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + } + + int embd_normalize = 2; // default to Euclidean/L2 norm + if (body.count("embd_normalize") != 0) { + embd_normalize = body.at("embd_normalize"); + if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { + SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", llama_pooling_type(ctx_server.ctx)); + } + } + + // create and queue the task + json responses = json::array(); + server_response_reader rd(ctx_server); + { + std::vector tasks; + for (size_t i = 0; i < tokenized_prompts.size(); i++) { + server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); + + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + task.tokens = std::move(tokenized_prompts[i]); + + // OAI-compat + task.params.oaicompat = oaicompat; + task.params.embd_normalize = embd_normalize; + + tasks.push_back(std::move(task)); + } + rd.post_tasks(std::move(tasks)); + } + + // wait for the results + auto all_results = rd.wait_for_all(req.should_stop); + + // collect results + if (all_results.is_terminated) { + return res; // connection is closed + } else if (all_results.error) { + res->error(all_results.error->to_json()); + return res; + } else { + for (auto & res : all_results.results) { + GGML_ASSERT(dynamic_cast(res.get()) != nullptr); + responses.push_back(res->to_json()); + } + } + + // write JSON response + json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING + ? format_embeddings_response_oaicompat(body, responses, use_base64) + : json(responses); + res->ok(root); + return res; + } +}; + +std::function shutdown_handler; +std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; + +inline void signal_handler(int signal) { + if (is_terminating.test_and_set()) { + // in case it hangs, we can force terminate the server by hitting Ctrl+C twice + // this is for better developer experience, we can remove when the server is stable enough + fprintf(stderr, "Received second interrupt, terminating immediately.\n"); + exit(1); + } + + shutdown_handler(signal); +} + +// wrapper function that handles exceptions and logs errors +// this is to make sure handler_t never throws exceptions; instead, it returns an error response +static server_http_context::handler_t ex_wrapper(server_http_context::handler_t func) { + return [func = std::move(func)](const server_http_req & req) -> server_http_res_ptr { + std::string message; + try { + return func(req); + } catch (const std::exception & e) { + message = e.what(); + } catch (...) { + message = "unknown error"; } + + auto res = std::make_unique(); + res->status = 500; + try { + json error_data = format_error_response(message, ERROR_TYPE_SERVER); + res->status = json_value(error_data, "code", 500); + res->data = safe_json_to_str({{ "error", error_data }}); + LOG_WRN("got exception: %s\n", res->data.c_str()); + } catch (const std::exception & e) { + LOG_ERR("got another exception: %s | while hanlding exception: %s\n", e.what(), message.c_str()); + res->data = "Internal Server Error"; + } + return res; + }; +} + +int main(int argc, char ** argv) { + // own arguments required by this example + common_params params; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) { + return 1; } + // TODO: should we have a separate n_parallel parameter for the server? + // https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177 + // TODO: this is a common configuration that is suitable for most local use cases + // however, overriding the parameters is a bit confusing - figure out something more intuitive + if (params.n_parallel == 1 && params.kv_unified == false && !params.has_speculative()) { + LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true (add -kvu to disable this)\n", __func__); + + params.n_parallel = 4; + params.kv_unified = true; + } + + common_init(); + + // struct that contains llama context and inference + server_context ctx_server; + + // Necessary similarity of prompt for slot selection + ctx_server.slot_prompt_similarity = params.slot_prompt_similarity; + + llama_backend_init(); + llama_numa_init(params.numa); + + LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency()); + LOG_INF("\n"); + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); + LOG_INF("\n"); + + server_http_context ctx_http; + if (!ctx_http.init(params)) { + LOG_ERR("%s: failed to initialize HTTP server\n", __func__); + return 1; + } + + // + // Router + // + // register API routes - svr->Get (params.api_prefix + "/health", handle_health); // public endpoint (no API key check) - svr->Get (params.api_prefix + "/v1/health", handle_health); // public endpoint (no API key check) - svr->Get (params.api_prefix + "/metrics", handle_metrics); - svr->Get (params.api_prefix + "/props", handle_props); - svr->Post(params.api_prefix + "/props", handle_props_change); - svr->Post(params.api_prefix + "/api/show", handle_api_show); - svr->Get (params.api_prefix + "/models", handle_models); // public endpoint (no API key check) - svr->Get (params.api_prefix + "/v1/models", handle_models); // public endpoint (no API key check) - svr->Get (params.api_prefix + "/api/tags", handle_models); // ollama specific endpoint. public endpoint (no API key check) - svr->Post(params.api_prefix + "/completion", handle_completions); // legacy - svr->Post(params.api_prefix + "/completions", handle_completions); - svr->Post(params.api_prefix + "/v1/completions", handle_completions_oai); - svr->Post(params.api_prefix + "/chat/completions", handle_chat_completions); - svr->Post(params.api_prefix + "/v1/chat/completions", handle_chat_completions); - svr->Post(params.api_prefix + "/api/chat", handle_chat_completions); // ollama specific endpoint - svr->Post(params.api_prefix + "/infill", handle_infill); - svr->Post(params.api_prefix + "/embedding", handle_embeddings); // legacy - svr->Post(params.api_prefix + "/embeddings", handle_embeddings); - svr->Post(params.api_prefix + "/v1/embeddings", handle_embeddings_oai); - svr->Post(params.api_prefix + "/rerank", handle_rerank); - svr->Post(params.api_prefix + "/reranking", handle_rerank); - svr->Post(params.api_prefix + "/v1/rerank", handle_rerank); - svr->Post(params.api_prefix + "/v1/reranking", handle_rerank); - svr->Post(params.api_prefix + "/tokenize", handle_tokenize); - svr->Post(params.api_prefix + "/detokenize", handle_detokenize); - svr->Post(params.api_prefix + "/apply-template", handle_apply_template); + server_routes routes(params, ctx_server, ctx_http); + + ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check) + ctx_http.get ("/v1/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check) + ctx_http.get ("/metrics", ex_wrapper(routes.get_metrics)); + ctx_http.get ("/props", ex_wrapper(routes.get_props)); + ctx_http.post("/props", ex_wrapper(routes.post_props)); + ctx_http.post("/api/show", ex_wrapper(routes.get_api_show)); + ctx_http.get ("/models", ex_wrapper(routes.get_models)); // public endpoint (no API key check) + ctx_http.get ("/v1/models", ex_wrapper(routes.get_models)); // public endpoint (no API key check) + ctx_http.get ("/api/tags", ex_wrapper(routes.get_models)); // ollama specific endpoint. public endpoint (no API key check) + ctx_http.post("/completion", ex_wrapper(routes.post_completions)); // legacy + ctx_http.post("/completions", ex_wrapper(routes.post_completions)); + ctx_http.post("/v1/completions", ex_wrapper(routes.post_completions_oai)); + ctx_http.post("/chat/completions", ex_wrapper(routes.post_chat_completions)); + ctx_http.post("/v1/chat/completions", ex_wrapper(routes.post_chat_completions)); + ctx_http.post("/api/chat", ex_wrapper(routes.post_chat_completions)); // ollama specific endpoint + ctx_http.post("/infill", ex_wrapper(routes.post_infill)); + ctx_http.post("/embedding", ex_wrapper(routes.post_embeddings)); // legacy + ctx_http.post("/embeddings", ex_wrapper(routes.post_embeddings)); + ctx_http.post("/v1/embeddings", ex_wrapper(routes.post_embeddings_oai)); + ctx_http.post("/rerank", ex_wrapper(routes.post_rerank)); + ctx_http.post("/reranking", ex_wrapper(routes.post_rerank)); + ctx_http.post("/v1/rerank", ex_wrapper(routes.post_rerank)); + ctx_http.post("/v1/reranking", ex_wrapper(routes.post_rerank)); + ctx_http.post("/tokenize", ex_wrapper(routes.post_tokenize)); + ctx_http.post("/detokenize", ex_wrapper(routes.post_detokenize)); + ctx_http.post("/apply-template", ex_wrapper(routes.post_apply_template)); // LoRA adapters hotswap - svr->Get (params.api_prefix + "/lora-adapters", handle_lora_adapters_list); - svr->Post(params.api_prefix + "/lora-adapters", handle_lora_adapters_apply); + ctx_http.get ("/lora-adapters", ex_wrapper(routes.get_lora_adapters)); + ctx_http.post("/lora-adapters", ex_wrapper(routes.post_lora_adapters)); // Save & load slots - svr->Get (params.api_prefix + "/slots", handle_slots); - svr->Post(params.api_prefix + "/slots/:id_slot", handle_slots_action); + ctx_http.get ("/slots", ex_wrapper(routes.get_slots)); + ctx_http.post("/slots/:id_slot", ex_wrapper(routes.post_slots)); // // Start the server // - if (params.n_threads_http < 1) { - // +2 threads for monitoring endpoints - params.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1); - } - log_data["n_threads_http"] = std::to_string(params.n_threads_http); - svr->new_task_queue = [¶ms] { return new httplib::ThreadPool(params.n_threads_http); }; - // clean up function, to be called before exit - auto clean_up = [&svr, &ctx_server]() { + // setup clean up function, to be called before exit + auto clean_up = [&ctx_http, &ctx_server]() { SRV_INF("%s: cleaning up before exit...\n", __func__); - svr->stop(); + ctx_http.stop(); ctx_server.queue_results.terminate(); llama_backend_free(); }; - bool was_bound = false; - bool is_sock = false; - if (string_ends_with(std::string(params.hostname), ".sock")) { - is_sock = true; - LOG_INF("%s: setting address family to AF_UNIX\n", __func__); - svr->set_address_family(AF_UNIX); - // bind_to_port requires a second arg, any value other than 0 should - // simply get ignored - was_bound = svr->bind_to_port(params.hostname, 8080); - } else { - LOG_INF("%s: binding port with default address family\n", __func__); - // bind HTTP listen port - if (params.port == 0) { - int bound_port = svr->bind_to_any_port(params.hostname); - if ((was_bound = (bound_port >= 0))) { - params.port = bound_port; - } - } else { - was_bound = svr->bind_to_port(params.hostname, params.port); - } - } - - if (!was_bound) { - LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, params.hostname.c_str(), params.port); + // start the HTTP server before loading the model to be able to serve /health requests + if (!ctx_http.start()) { clean_up(); + LOG_ERR("%s: exiting due to HTTP server error\n", __func__); return 1; } - // run the HTTP server in a thread - std::thread t([&]() { svr->listen_after_bind(); }); - svr->wait_until_ready(); - - LOG_INF("%s: HTTP server is listening, hostname: %s, port: %d, http threads: %d\n", __func__, params.hostname.c_str(), params.port, params.n_threads_http); - // load the model LOG_INF("%s: loading model\n", __func__); if (!ctx_server.load_model(params)) { clean_up(); - t.join(); + if (ctx_http.thread.joinable()) { + ctx_http.thread.join(); + } LOG_ERR("%s: exiting due to model loading error\n", __func__); return 1; } ctx_server.init(); - state.store(SERVER_STATE_READY); + ctx_http.is_ready.store(true); LOG_INF("%s: model loaded\n", __func__); - // print sample chat example to make it clear which template is used - LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, - common_chat_templates_source(ctx_server.chat_templates.get()), - common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja, ctx_server.params_base.default_template_kwargs).c_str()); - ctx_server.queue_tasks.on_new_task([&ctx_server](server_task && task) { ctx_server.process_single_task(std::move(task)); }); @@ -5789,15 +5645,15 @@ int main(int argc, char ** argv) { SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); #endif - LOG_INF("%s: server is listening on %s - starting the main loop\n", __func__, - is_sock ? string_format("unix://%s", params.hostname.c_str()).c_str() : - string_format("http://%s:%d", params.hostname.c_str(), params.port).c_str()); - + LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str()); + LOG_INF("%s: starting the main loop...\n", __func__); // this call blocks the main thread until queue_tasks.terminate() is called ctx_server.queue_tasks.start_loop(); clean_up(); - t.join(); + if (ctx_http.thread.joinable()) { + ctx_http.thread.join(); + } llama_memory_breakdown_print(ctx_server.ctx); return 0; diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index b1ecc5af5ed0a..bf21726051e55 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -9,8 +9,6 @@ #include "mtmd-helper.h" #include "chat.h" -#include - #define JSON_ASSERT GGML_ASSERT #include @@ -426,6 +424,10 @@ static std::string gen_tool_call_id() { // other common utils // +static std::string safe_json_to_str(const json & data) { + return data.dump(-1, ' ', false, json::error_handler_t::replace); +} + // TODO: reuse llama_detokenize template static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { @@ -453,29 +455,25 @@ static std::string tokens_to_output_formatted_string(const llama_context * ctx, return out; } +// format server-sent event (SSE), return the formatted string to send // note: if data is a json array, it will be sent as multiple events, one per item -static bool server_sent_event(httplib::DataSink & sink, const json & data) { - static auto send_single = [](httplib::DataSink & sink, const json & data) -> bool { - const std::string str = - "data: " + - data.dump(-1, ' ', false, json::error_handler_t::replace) + +static std::string format_sse(const json & data) { + std::ostringstream ss; + auto send_single = [&ss](const json & data) { + ss << "data: " << + safe_json_to_str(data) << "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row). - - LOG_DBG("data stream, to_send: %s", str.c_str()); - return sink.write(str.c_str(), str.size()); }; if (data.is_array()) { for (const auto & item : data) { - if (!send_single(sink, item)) { - return false; - } + send_single(item); } } else { - return send_single(sink, data); + send_single(data); } - return true; + return ss.str(); } // @@ -954,10 +952,6 @@ static json format_logit_bias(const std::vector & logit_bias) return data; } -static std::string safe_json_to_str(const json & data) { - return data.dump(-1, ' ', false, json::error_handler_t::replace); -} - static std::vector get_token_probabilities(llama_context * ctx, int idx) { std::vector cur; const auto * logits = llama_get_logits_ith(ctx, idx);