From c00764c26b7432a30ff598177ce14641f159da5e Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Mon, 9 Dec 2024 17:13:21 +0700 Subject: [PATCH 1/7] fix: improve streaming message for remote engine --- engine/config/model_config.h | 41 ++++++++++++ .../extensions/remote-engine/remote_engine.cc | 67 ++++++++++++++++--- .../extensions/remote-engine/remote_engine.h | 4 +- 3 files changed, 101 insertions(+), 11 deletions(-) diff --git a/engine/config/model_config.h b/engine/config/model_config.h index 84e175d54..d319145ad 100644 --- a/engine/config/model_config.h +++ b/engine/config/model_config.h @@ -52,6 +52,37 @@ const std::string kAnthropicTransformRespTemplate = R"({ }, "system_fingerprint": "fp_6b68a8204b" })"; + +const std::string kAnthropicTransformStreamRespTemplate = R"( +{ + "object": "chat.completion.chunk", + {% if type == "message_start" %} + "model": "{{ message.model }}", + {% endif %} + "choices": [ + { + "index": 0, + "delta": { + {% if type == "message_start" %} + "role": "assistant", + "content": "" + {% else if type == "content_block_delta" %} + "role": "assistant", + "content": "{{ delta.text }}" + {% else if type == "content_block_stop" %} + "role": "assistant", + "content": "" + {% endif %} + }, + {% if type == "content_block_stop" %} + "finish_reason": "stop" + {% else %} + "finish_reason": null + {% endif %} + } + ] +} +)"; } // namespace struct RemoteModelConfig { @@ -108,6 +139,16 @@ struct RemoteModelConfig { kOpenAITransformRespTemplate; } } + + if (TransformResp["chat_completions"]["stream_template"].isNull()) { + if (is_anthropic(model)) { + TransformResp["chat_completions"]["stream_template"] = + kAnthropicTransformStreamRespTemplate; + } else { + TransformResp["chat_completions"]["stream_template"] = + kOpenAITransformRespTemplate; + } + } metadata = json.get("metadata", metadata); } diff --git a/engine/extensions/remote-engine/remote_engine.cc b/engine/extensions/remote-engine/remote_engine.cc index 04effb457..8b9bc8200 100644 --- a/engine/extensions/remote-engine/remote_engine.cc +++ b/engine/extensions/remote-engine/remote_engine.cc @@ -16,6 +16,10 @@ bool is_anthropic(const std::string& model) { return model.find("claude") != std::string::npos; } +bool is_openai(const std::string& model) { + return model.find("gpt") != std::string::npos; +} + struct AnthropicChunk { std::string type; std::string id; @@ -78,6 +82,35 @@ struct AnthropicChunk { } }; +const std::string kOpenaiChatStreamTemplate = R"( +{ + "object": "chat.completion.chunk", + "model": "{{ model }}", + "choices": [ + { + "index": 0, + "delta": { + {% if type == "message_start" %} + "role": "assistant", + "content": "" + {% else if type == "content_block_delta" %} + "role": "assistant", + "content": "{{ delta.text }}" + {% else if type == "content_block_stop" %} + "role": "assistant", + "content": "" + {% endif %} + }, + {% if type == "content_block_stop" %} + "finish_reason": "stop" + {% else %} + "finish_reason": null + {% endif %} + } + ] +} +)"; + } // namespace size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb, @@ -120,17 +153,18 @@ size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb, // Parse the JSON Json::Value chunk_json; - if (is_anthropic(context->model)) { - AnthropicChunk ac(line); - if (ac.should_ignore) + if (!is_openai(context->model)) { + std::string s = line.substr(6); + try { + auto root = json_helper::ParseJsonString(s); + root["model"] = context->model; + root["id"] = context->id; + auto result = context->renderer.Render(context->stream_template, 
root); + chunk_json["data"] = "data: " + result + "\n\n"; + } catch (const std::exception& e) { + CTL_WRN("JSON parse error: " << e.what()); continue; - ac.model = context->model; - if (ac.type == "message_start") { - context->id = ac.id; - } else { - ac.id = context->id; } - chunk_json["data"] = ac.ToOpenAiFormatString() + "\n\n"; } else { chunk_json["data"] = line + "\n\n"; } @@ -178,10 +212,23 @@ CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest( headers = curl_slist_append(headers, "Cache-Control: no-cache"); headers = curl_slist_append(headers, "Connection: keep-alive"); + std::string stream_template = kOpenaiChatStreamTemplate; + if (config.transform_resp["chat_completions"]["stream_template"]) { + stream_template = + config.transform_resp["chat_completions"]["stream_template"] + .as(); + } else { + CTL_WRN("stream_template does not exist"); + } + StreamContext context{ std::make_shared>( callback), - "", "", config.model}; + "", + "", + config.model, + renderer_, + stream_template}; curl_easy_setopt(curl, CURLOPT_URL, full_url.c_str()); curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); diff --git a/engine/extensions/remote-engine/remote_engine.h b/engine/extensions/remote-engine/remote_engine.h index 8ce6fa652..9132b4ada 100644 --- a/engine/extensions/remote-engine/remote_engine.h +++ b/engine/extensions/remote-engine/remote_engine.h @@ -18,12 +18,15 @@ inline bool IsRemoteEngine(std::string_view e) { return e == kAnthropicEngine || e == kOpenAiEngine; } + struct StreamContext { std::shared_ptr> callback; std::string buffer; // Cache value for Anthropic std::string id; std::string model; + TemplateRenderer& renderer; + std::string stream_template; }; struct CurlResponse { std::string body; @@ -50,7 +53,6 @@ class RemoteEngine : public RemoteEngineI { TemplateRenderer renderer_; Json::Value metadata_; std::string api_key_template_; - std::unique_ptr async_file_logger_; // Helper functions CurlResponse MakeChatCompletionRequest(const ModelConfig& config, From 80e6fdec8dbc7f4704c1072fe035a6c86262835b Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Wed, 11 Dec 2024 06:36:06 +0700 Subject: [PATCH 2/7] feat: improve remote engine --- engine/common/engine_servicei.h | 2 + engine/config/model_config.h | 39 ---- engine/controllers/models.cc | 10 +- .../remote-engine/anthropic_engine.cc | 1 + .../remote-engine/anthropic_engine.h | 8 +- .../extensions/remote-engine/remote_engine.cc | 179 +++++------------- .../extensions/remote-engine/remote_engine.h | 6 +- engine/services/engine_service.cc | 29 +-- engine/services/engine_service.h | 2 + engine/services/model_service.cc | 2 +- engine/utils/engine_constants.h | 3 + 11 files changed, 91 insertions(+), 190 deletions(-) diff --git a/engine/common/engine_servicei.h b/engine/common/engine_servicei.h index 85fa87d76..a4b0c8732 100644 --- a/engine/common/engine_servicei.h +++ b/engine/common/engine_servicei.h @@ -58,4 +58,6 @@ class EngineServiceI { GetEngineByNameAndVariant( const std::string& engine_name, const std::optional variant = std::nullopt) = 0; + + virtual bool IsRemoteEngine(const std::string& engine_name) = 0; }; diff --git a/engine/config/model_config.h b/engine/config/model_config.h index d319145ad..a6af81974 100644 --- a/engine/config/model_config.h +++ b/engine/config/model_config.h @@ -53,36 +53,6 @@ const std::string kAnthropicTransformRespTemplate = R"({ "system_fingerprint": "fp_6b68a8204b" })"; -const std::string kAnthropicTransformStreamRespTemplate = R"( -{ - "object": "chat.completion.chunk", - {% if type == 
"message_start" %} - "model": "{{ message.model }}", - {% endif %} - "choices": [ - { - "index": 0, - "delta": { - {% if type == "message_start" %} - "role": "assistant", - "content": "" - {% else if type == "content_block_delta" %} - "role": "assistant", - "content": "{{ delta.text }}" - {% else if type == "content_block_stop" %} - "role": "assistant", - "content": "" - {% endif %} - }, - {% if type == "content_block_stop" %} - "finish_reason": "stop" - {% else %} - "finish_reason": null - {% endif %} - } - ] -} -)"; } // namespace struct RemoteModelConfig { @@ -140,15 +110,6 @@ struct RemoteModelConfig { } } - if (TransformResp["chat_completions"]["stream_template"].isNull()) { - if (is_anthropic(model)) { - TransformResp["chat_completions"]["stream_template"] = - kAnthropicTransformStreamRespTemplate; - } else { - TransformResp["chat_completions"]["stream_template"] = - kOpenAITransformRespTemplate; - } - } metadata = json.get("metadata", metadata); } diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index 3f91da848..ff612226c 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -178,7 +178,7 @@ void Models::ListModel( .string()); auto model_config = yaml_handler.GetModelConfig(); - if (!remote_engine::IsRemoteEngine(model_config.engine)) { + if (!engine_service_->IsRemoteEngine(model_config.engine)) { Json::Value obj = model_config.ToJson(); obj["id"] = model_entry.model; obj["model"] = model_entry.model; @@ -610,7 +610,7 @@ void Models::GetRemoteModels( const HttpRequestPtr& req, std::function&& callback, const std::string& engine_id) { - if (!remote_engine::IsRemoteEngine(engine_id)) { + if (!engine_service_->IsRemoteEngine(engine_id)) { Json::Value ret; ret["message"] = "Not a remote engine: " + engine_id; auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); @@ -646,8 +646,7 @@ void Models::AddRemoteModel( auto model_handle = (*(req->getJsonObject())).get("model", "").asString(); auto engine_name = (*(req->getJsonObject())).get("engine", "").asString(); - /* To do: uncomment when remote engine is ready - + auto engine_validate = engine_service_->IsEngineReady(engine_name); if (engine_validate.has_error()) { Json::Value ret; @@ -657,6 +656,7 @@ void Models::AddRemoteModel( callback(resp); return; } + if (!engine_validate.value()) { Json::Value ret; ret["message"] = "Engine is not ready! 
Please install first!"; @@ -665,7 +665,7 @@ void Models::AddRemoteModel( callback(resp); return; } - */ + config::RemoteModelConfig model_config; model_config.LoadFromJson(*(req->getJsonObject())); cortex::db::Models modellist_utils_obj; diff --git a/engine/extensions/remote-engine/anthropic_engine.cc b/engine/extensions/remote-engine/anthropic_engine.cc index 847cba566..78bbffc2c 100644 --- a/engine/extensions/remote-engine/anthropic_engine.cc +++ b/engine/extensions/remote-engine/anthropic_engine.cc @@ -10,6 +10,7 @@ constexpr const std::array kAnthropicModels = { "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307"}; } + void AnthropicEngine::GetModels( std::shared_ptr json_body, std::function&& callback) { diff --git a/engine/extensions/remote-engine/anthropic_engine.h b/engine/extensions/remote-engine/anthropic_engine.h index bcd3dfaf7..44bf76046 100644 --- a/engine/extensions/remote-engine/anthropic_engine.h +++ b/engine/extensions/remote-engine/anthropic_engine.h @@ -2,12 +2,12 @@ #include "remote_engine.h" namespace remote_engine { - class AnthropicEngine: public RemoteEngine { -public: +class AnthropicEngine : public RemoteEngine { + public: void GetModels( std::shared_ptr json_body, std::function&& callback) override; Json::Value GetRemoteModels() override; - }; -} \ No newline at end of file +}; +} // namespace remote_engine \ No newline at end of file diff --git a/engine/extensions/remote-engine/remote_engine.cc b/engine/extensions/remote-engine/remote_engine.cc index 8b9bc8200..4d9d90d0b 100644 --- a/engine/extensions/remote-engine/remote_engine.cc +++ b/engine/extensions/remote-engine/remote_engine.cc @@ -20,97 +20,6 @@ bool is_openai(const std::string& model) { return model.find("gpt") != std::string::npos; } -struct AnthropicChunk { - std::string type; - std::string id; - int index; - std::string msg; - std::string model; - std::string stop_reason; - bool should_ignore = false; - - AnthropicChunk(const std::string& str) { - if (str.size() > 6) { - std::string s = str.substr(6); - try { - auto root = json_helper::ParseJsonString(s); - type = root["type"].asString(); - if (type == "message_start") { - id = root["message"]["id"].asString(); - model = root["message"]["model"].asString(); - } else if (type == "content_block_delta") { - index = root["index"].asInt(); - if (root["delta"]["type"].asString() == "text_delta") { - msg = root["delta"]["text"].asString(); - } - } else if (type == "message_delta") { - stop_reason = root["delta"]["stop_reason"].asString(); - } else { - // ignore other messages - should_ignore = true; - } - } catch (const std::exception& e) { - should_ignore = true; - CTL_WRN("JSON parse error: " << e.what()); - } - } else { - should_ignore = true; - } - } - - std::string ToOpenAiFormatString() { - Json::Value root; - root["id"] = id; - root["object"] = "chat.completion.chunk"; - root["created"] = Json::Value(); - root["model"] = model; - root["system_fingerprint"] = "fp_e76890f0c3"; - Json::Value choices(Json::arrayValue); - Json::Value choice; - Json::Value content; - choice["index"] = 0; - content["content"] = msg; - if (type == "message_start") { - content["role"] = "assistant"; - content["refusal"] = Json::Value(); - } - choice["delta"] = content; - choice["finish_reason"] = stop_reason.empty() ? 
Json::Value() : stop_reason; - choices.append(choice); - root["choices"] = choices; - return "data: " + json_helper::DumpJsonString(root); - } -}; - -const std::string kOpenaiChatStreamTemplate = R"( -{ - "object": "chat.completion.chunk", - "model": "{{ model }}", - "choices": [ - { - "index": 0, - "delta": { - {% if type == "message_start" %} - "role": "assistant", - "content": "" - {% else if type == "content_block_delta" %} - "role": "assistant", - "content": "{{ delta.text }}" - {% else if type == "content_block_stop" %} - "role": "assistant", - "content": "" - {% endif %} - }, - {% if type == "content_block_stop" %} - "finish_reason": "stop" - {% else %} - "finish_reason": null - {% endif %} - } - ] -} -)"; - } // namespace size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb, @@ -125,21 +34,16 @@ size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb, while ((pos = context->buffer.find('\n')) != std::string::npos) { std::string line = context->buffer.substr(0, pos); context->buffer = context->buffer.substr(pos + 1); - CTL_TRC(line); + // CTL_INF(line); // Skip empty lines if (line.empty() || line == "\r" || line.find("event:") != std::string::npos) continue; - // Remove "data: " prefix if present - // if (line.substr(0, 6) == "data: ") - // { - // line = line.substr(6); - // } - // Skip [DONE] message // std::cout << line << std::endl; + CTL_DBG(line); if (line == "data: [DONE]" || line.find("message_stop") != std::string::npos) { Json::Value status; @@ -159,7 +63,9 @@ size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb, auto root = json_helper::ParseJsonString(s); root["model"] = context->model; root["id"] = context->id; + root["stream"] = true; auto result = context->renderer.Render(context->stream_template, root); + CTL_DBG(result); chunk_json["data"] = "data: " + result + "\n\n"; } catch (const std::exception& e) { CTL_WRN("JSON parse error: " << e.what()); @@ -212,14 +118,7 @@ CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest( headers = curl_slist_append(headers, "Cache-Control: no-cache"); headers = curl_slist_append(headers, "Connection: keep-alive"); - std::string stream_template = kOpenaiChatStreamTemplate; - if (config.transform_resp["chat_completions"]["stream_template"]) { - stream_template = - config.transform_resp["chat_completions"]["stream_template"] - .as(); - } else { - CTL_WRN("stream_template does not exist"); - } + std::string stream_template = chat_res_template_; StreamContext context{ std::make_shared>( @@ -478,6 +377,21 @@ void RemoteEngine::LoadModel( } if (json_body->isMember("metadata")) { metadata_ = (*json_body)["metadata"]; + if (!metadata_["TransformReq"].isNull() && + !metadata_["TransformReq"]["chat_completions"].isNull() && + !metadata_["TransformReq"]["chat_completions"]["template"].isNull()) { + chat_req_template_ = + metadata_["TransformReq"]["chat_completions"]["template"].asString(); + CTL_INF(chat_req_template_); + } + + if (!metadata_["TransformResp"].isNull() && + !metadata_["TransformResp"]["chat_completions"].isNull() && + !metadata_["TransformResp"]["chat_completions"]["template"].isNull()) { + chat_res_template_ = + metadata_["TransformResp"]["chat_completions"]["template"].asString(); + CTL_INF(chat_res_template_); + } } Json::Value response; @@ -648,33 +562,42 @@ void RemoteEngine::HandleChatCompletion( // Transform Response std::string response_str; try { - // Check if required YAML nodes exist - if (!model_config->transform_resp["chat_completions"]) { - throw std::runtime_error( - "Missing 
'chat_completions' node in transform_resp"); - } - if (!model_config->transform_resp["chat_completions"]["template"]) { - throw std::runtime_error("Missing 'template' node in chat_completions"); - } + std::string template_str; + if (!chat_res_template_.empty()) { + CTL_DBG( + "Use engine transform response template: " << chat_res_template_); + template_str = chat_res_template_; + } else { + // Check if required YAML nodes exist + if (!model_config->transform_resp["chat_completions"]) { + throw std::runtime_error( + "Missing 'chat_completions' node in transform_resp"); + } + if (!model_config->transform_resp["chat_completions"]["template"]) { + throw std::runtime_error( + "Missing 'template' node in chat_completions"); + } - // Validate JSON body - if (!response_json || response_json.isNull()) { - throw std::runtime_error("Invalid or null JSON body"); - } + // Validate JSON body + if (!response_json || response_json.isNull()) { + throw std::runtime_error("Invalid or null JSON body"); + } - // Get template string with error check - std::string template_str; - try { - template_str = - model_config->transform_resp["chat_completions"]["template"] - .as(); - } catch (const YAML::BadConversion& e) { - throw std::runtime_error("Failed to convert template node to string: " + - std::string(e.what())); + // Get template string with error check + + try { + template_str = + model_config->transform_resp["chat_completions"]["template"] + .as(); + } catch (const YAML::BadConversion& e) { + throw std::runtime_error( + "Failed to convert template node to string: " + + std::string(e.what())); + } } - // Render with error handling try { + response_json["stream"] = false; response_str = renderer_.Render(template_str, response_json); } catch (const std::exception& e) { throw std::runtime_error("Template rendering error: " + diff --git a/engine/extensions/remote-engine/remote_engine.h b/engine/extensions/remote-engine/remote_engine.h index 9132b4ada..d404fb363 100644 --- a/engine/extensions/remote-engine/remote_engine.h +++ b/engine/extensions/remote-engine/remote_engine.h @@ -14,11 +14,11 @@ // Helper for CURL response namespace remote_engine { +// TODO(sang) remove this after we have all engine saved in DB inline bool IsRemoteEngine(std::string_view e) { return e == kAnthropicEngine || e == kOpenAiEngine; } - struct StreamContext { std::shared_ptr> callback; std::string buffer; @@ -52,6 +52,8 @@ class RemoteEngine : public RemoteEngineI { std::unordered_map models_; TemplateRenderer renderer_; Json::Value metadata_; + std::string chat_req_template_; + std::string chat_res_template_; std::string api_key_template_; // Helper functions @@ -97,7 +99,7 @@ class RemoteEngine : public RemoteEngineI { void HandleEmbedding( std::shared_ptr json_body, std::function&& callback) override; - + Json::Value GetRemoteModels() override; }; diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index fe5317c7d..f3167fb3c 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -185,7 +185,7 @@ cpp::result EngineService::UninstallEngineVariant( // TODO: handle uninstall remote engine // only delete a remote engine if no model are using it auto exist_engine = GetEngineByNameAndVariant(engine); - if (exist_engine.has_value() && exist_engine.value().type == "remote") { + if (exist_engine.has_value() && exist_engine.value().type == kRemote) { auto result = DeleteEngine(exist_engine.value().id); if (!result.empty()) { // This mean no error when delete model CTL_ERR("Failed 
to delete engine: " << result); @@ -331,15 +331,9 @@ cpp::result EngineService::DownloadEngine( } else { CTL_INF("Set default engine variant: " << res.value().variant); } - auto create_res = - EngineService::UpsertEngine(engine, // engine_name - "local", // todo - luke - "", // todo - luke - "", // todo - luke - normalize_version, variant.value(), - "Default", // todo - luke - "" // todo - luke - ); + auto create_res = EngineService::UpsertEngine( + engine, // engine_name + kLocal, "", "", normalize_version, variant.value(), "Default", ""); if (create_res.has_value()) { CTL_ERR("Failed to create engine entry: " << create_res->engine_name); @@ -681,7 +675,7 @@ cpp::result EngineService::LoadEngine( } // Check for remote engine - if (remote_engine::IsRemoteEngine(engine_name)) { + if (IsRemoteEngine(engine_name)) { auto exist_engine = GetEngineByNameAndVariant(engine_name); if (exist_engine.has_error()) { return cpp::fail("Remote engine '" + engine_name + "' is not installed"); @@ -1097,4 +1091,17 @@ cpp::result EngineService::GetRemoteModels( } else { return res; } +} + +bool EngineService::IsRemoteEngine(const std::string& engine_name) { + cortex::db::Engines e_db; + auto res = e_db.GetEngines(); + if (res) { + for (auto const& e : *res) { + if (e.engine_name == engine_name && e.type == kRemote) { + return true; + } + } + } + return false; } \ No newline at end of file diff --git a/engine/services/engine_service.h b/engine/services/engine_service.h index ab274825d..9a050fb70 100644 --- a/engine/services/engine_service.h +++ b/engine/services/engine_service.h @@ -149,6 +149,8 @@ class EngineService : public EngineServiceI { cpp::result GetRemoteModels( const std::string& engine_name); + bool IsRemoteEngine(const std::string& engine_name) override; + private: bool IsEngineLoaded(const std::string& engine); diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index 7f79ddaf7..95daf74d6 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -722,7 +722,7 @@ cpp::result ModelService::StartModel( auto mc = yaml_handler.GetModelConfig(); // Running remote model - if (remote_engine::IsRemoteEngine(mc.engine)) { + if (engine_svc_->IsRemoteEngine(mc.engine)) { config::RemoteModelConfig remote_mc; remote_mc.LoadFromYamlFile( diff --git a/engine/utils/engine_constants.h b/engine/utils/engine_constants.h index 020109fd8..dcdf6a443 100644 --- a/engine/utils/engine_constants.h +++ b/engine/utils/engine_constants.h @@ -6,6 +6,9 @@ constexpr const auto kTrtLlmEngine = "tensorrt-llm"; constexpr const auto kOpenAiEngine = "openai"; constexpr const auto kAnthropicEngine = "anthropic"; +constexpr const auto kRemote = "remote"; +constexpr const auto kLocal = "local"; + constexpr const auto kOnnxRepo = "cortex.onnx"; constexpr const auto kLlamaRepo = "cortex.llamacpp"; constexpr const auto kTrtLlmRepo = "cortex.tensorrt-llm"; From 28933474b09a8d0f9f99721372410e983fb62d76 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Thu, 12 Dec 2024 06:31:06 +0700 Subject: [PATCH 3/7] chore: cleanup --- engine/CMakeLists.txt | 1 - engine/cli/CMakeLists.txt | 1 - .../remote-engine/anthropic_engine.cc | 32 ----------- .../remote-engine/anthropic_engine.h | 3 -- .../extensions/remote-engine/openai_engine.cc | 54 ------------------- .../extensions/remote-engine/openai_engine.h | 14 ----- .../extensions/remote-engine/remote_engine.cc | 44 +++++++++++++-- engine/services/engine_service.cc | 9 ++-- 8 files changed, 45 insertions(+), 113 deletions(-) delete mode 100644 
engine/extensions/remote-engine/openai_engine.cc delete mode 100644 engine/extensions/remote-engine/openai_engine.h diff --git a/engine/CMakeLists.txt b/engine/CMakeLists.txt index 7cac3421c..23c8b78ac 100644 --- a/engine/CMakeLists.txt +++ b/engine/CMakeLists.txt @@ -143,7 +143,6 @@ add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/cpuid/cpu_info.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/file_logger.cc ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/remote_engine.cc - ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/openai_engine.cc ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/anthropic_engine.cc ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/template_renderer.cc ) diff --git a/engine/cli/CMakeLists.txt b/engine/cli/CMakeLists.txt index 51382dc13..7083dc0d9 100644 --- a/engine/cli/CMakeLists.txt +++ b/engine/cli/CMakeLists.txt @@ -83,7 +83,6 @@ add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/inference_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/hardware_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/remote_engine.cc - ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/openai_engine.cc ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/anthropic_engine.cc ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/template_renderer.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/easywsclient.cc diff --git a/engine/extensions/remote-engine/anthropic_engine.cc b/engine/extensions/remote-engine/anthropic_engine.cc index 78bbffc2c..823ac4747 100644 --- a/engine/extensions/remote-engine/anthropic_engine.cc +++ b/engine/extensions/remote-engine/anthropic_engine.cc @@ -11,38 +11,6 @@ constexpr const std::array kAnthropicModels = { "claude-3-haiku-20240307"}; } -void AnthropicEngine::GetModels( - std::shared_ptr json_body, - std::function&& callback) { - Json::Value json_resp; - Json::Value model_array(Json::arrayValue); - { - std::shared_lock l(models_mtx_); - for (const auto& [m, _] : models_) { - Json::Value val; - val["id"] = m; - val["engine"] = "anthropic"; - val["start_time"] = "_"; - val["model_size"] = "_"; - val["vram"] = "_"; - val["ram"] = "_"; - val["object"] = "model"; - model_array.append(val); - } - } - - json_resp["object"] = "list"; - json_resp["data"] = model_array; - - Json::Value status; - status["is_done"] = true; - status["has_error"] = false; - status["is_stream"] = false; - status["status_code"] = 200; - callback(std::move(status), std::move(json_resp)); - CTL_INF("Running models responded"); -} - Json::Value AnthropicEngine::GetRemoteModels() { Json::Value json_resp; Json::Value model_array(Json::arrayValue); diff --git a/engine/extensions/remote-engine/anthropic_engine.h b/engine/extensions/remote-engine/anthropic_engine.h index 44bf76046..b4b98d2e4 100644 --- a/engine/extensions/remote-engine/anthropic_engine.h +++ b/engine/extensions/remote-engine/anthropic_engine.h @@ -4,9 +4,6 @@ namespace remote_engine { class AnthropicEngine : public RemoteEngine { public: - void GetModels( - std::shared_ptr json_body, - std::function&& callback) override; Json::Value GetRemoteModels() override; }; diff --git a/engine/extensions/remote-engine/openai_engine.cc b/engine/extensions/remote-engine/openai_engine.cc deleted file mode 100644 index 7c7d70385..000000000 --- a/engine/extensions/remote-engine/openai_engine.cc +++ /dev/null @@ -1,54 +0,0 @@ -#include "openai_engine.h" -#include "utils/logging_utils.h" - -namespace remote_engine { - -void OpenAiEngine::GetModels( - std::shared_ptr 
json_body,
-    std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
-  Json::Value json_resp;
-  Json::Value model_array(Json::arrayValue);
-  {
-    std::shared_lock l(models_mtx_);
-    for (const auto& [m, _] : models_) {
-      Json::Value val;
-      val["id"] = m;
-      val["engine"] = "openai";
-      val["start_time"] = "_";
-      val["model_size"] = "_";
-      val["vram"] = "_";
-      val["ram"] = "_";
-      val["object"] = "model";
-      model_array.append(val);
-    }
-  }
-
-  json_resp["object"] = "list";
-  json_resp["data"] = model_array;
-
-  Json::Value status;
-  status["is_done"] = true;
-  status["has_error"] = false;
-  status["is_stream"] = false;
-  status["status_code"] = 200;
-  callback(std::move(status), std::move(json_resp));
-  CTL_INF("Running models responded");
-}
-
-Json::Value OpenAiEngine::GetRemoteModels() {
-  auto response = MakeGetModelsRequest();
-  if (response.error) {
-    Json::Value error;
-    error["error"] = response.error_message;
-    return error;
-  }
-  Json::Value response_json;
-  Json::Reader reader;
-  if (!reader.parse(response.body, response_json)) {
-    Json::Value error;
-    error["error"] = "Failed to parse response";
-    return error;
-  }
-  return response_json;
-}
-}  // namespace remote_engine
\ No newline at end of file
diff --git a/engine/extensions/remote-engine/openai_engine.h b/engine/extensions/remote-engine/openai_engine.h
deleted file mode 100644
index 61dc68f0c..000000000
--- a/engine/extensions/remote-engine/openai_engine.h
+++ /dev/null
@@ -1,14 +0,0 @@
-#pragma once
-
-#include "remote_engine.h"
-
-namespace remote_engine {
-class OpenAiEngine : public RemoteEngine {
- public:
-  void GetModels(
-      std::shared_ptr<Json::Value> json_body,
-      std::function<void(Json::Value&&, Json::Value&&)>&& callback) override;
-
-  Json::Value GetRemoteModels() override;
-};
-}  // namespace remote_engine
\ No newline at end of file
diff --git a/engine/extensions/remote-engine/remote_engine.cc b/engine/extensions/remote-engine/remote_engine.cc
index 4d9d90d0b..4b3f0be79 100644
--- a/engine/extensions/remote-engine/remote_engine.cc
+++ b/engine/extensions/remote-engine/remote_engine.cc
@@ -341,7 +341,33 @@ bool RemoteEngine::LoadModelConfig(const std::string& model,
 void RemoteEngine::GetModels(
     std::shared_ptr<Json::Value> json_body,
     std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
-  CTL_WRN("Not implemented yet!");
+  Json::Value json_resp;
+  Json::Value model_array(Json::arrayValue);
+  {
+    std::shared_lock l(models_mtx_);
+    for (const auto& [m, _] : models_) {
+      Json::Value val;
+      val["id"] = m;
+      val["engine"] = "openai";
+      val["start_time"] = "_";
+      val["model_size"] = "_";
+      val["vram"] = "_";
+      val["ram"] = "_";
+      val["object"] = "model";
+      model_array.append(val);
+    }
+  }
+
+  json_resp["object"] = "list";
+  json_resp["data"] = model_array;
+
+  Json::Value status;
+  status["is_done"] = true;
+  status["has_error"] = false;
+  status["is_stream"] = false;
+  status["status_code"] = 200;
+  callback(std::move(status), std::move(json_resp));
+  CTL_INF("Running models responded");
 }
 
 void RemoteEngine::LoadModel(
@@ -675,8 +701,20 @@ void RemoteEngine::HandleEmbedding(
 }
 
 Json::Value RemoteEngine::GetRemoteModels() {
-  CTL_WRN("Not implemented yet!");
-  return {};
+  auto response = MakeGetModelsRequest();
+  if (response.error) {
+    Json::Value error;
+    error["error"] = response.error_message;
+    return error;
+  }
+  Json::Value response_json;
+  Json::Reader reader;
+  if (!reader.parse(response.body, response_json)) {
+    Json::Value error;
+    error["error"] = "Failed to parse response";
+    return error;
+  }
+  return response_json;
 }
 
 }  // namespace remote_engine
\ No newline at end of file
diff --git a/engine/services/engine_service.cc
b/engine/services/engine_service.cc
index 063294216..2f7576f7f 100644
--- a/engine/services/engine_service.cc
+++ b/engine/services/engine_service.cc
@@ -7,7 +7,6 @@
 #include "algorithm"
 #include "database/engines.h"
 #include "extensions/remote-engine/anthropic_engine.h"
-#include "extensions/remote-engine/openai_engine.h"
 #include "utils/archive_utils.h"
 #include "utils/engine_constants.h"
 #include "utils/engine_matcher_utils.h"
@@ -683,8 +682,8 @@ cpp::result<void, std::string> EngineService::LoadEngine(
     return cpp::fail("Remote engine '" + engine_name + "' is not installed");
   }
 
-  if (engine_name == kOpenAiEngine) {
-    engines_[engine_name].engine = new remote_engine::OpenAiEngine();
+  if (engine_name != kAnthropicEngine) {
+    engines_[engine_name].engine = new remote_engine::RemoteEngine();
   } else {
     engines_[engine_name].engine = new remote_engine::AnthropicEngine();
   }
@@ -1041,8 +1040,8 @@ cpp::result<Json::Value, std::string> EngineService::GetRemoteModels(
   if (exist_engine.has_error()) {
     return cpp::fail("Remote engine '" + engine_name + "' is not installed");
   }
-  if (engine_name == kOpenAiEngine) {
-    engines_[engine_name].engine = new remote_engine::OpenAiEngine();
+  if (engine_name != kAnthropicEngine) {
+    engines_[engine_name].engine = new remote_engine::RemoteEngine();
   } else {
     engines_[engine_name].engine = new remote_engine::AnthropicEngine();
   }

From bad9f11f3889acd963d09604b2ca4cf3d1d4eb59 Mon Sep 17 00:00:00 2001
From: vansangpfiev
Date: Thu, 12 Dec 2024 06:40:16 +0700
Subject: [PATCH 4/7] fix: correct remote engine check

---
 engine/extensions/remote-engine/remote_engine.h |  4 ----
 engine/services/engine_service.cc               | 17 +++++++----------
 2 files changed, 7 insertions(+), 14 deletions(-)

diff --git a/engine/extensions/remote-engine/remote_engine.h b/engine/extensions/remote-engine/remote_engine.h
index d404fb363..92d9e8126 100644
--- a/engine/extensions/remote-engine/remote_engine.h
+++ b/engine/extensions/remote-engine/remote_engine.h
@@ -14,10 +14,6 @@
 // Helper for CURL response
 namespace remote_engine {
-// TODO(sang) remove this after we have all engine saved in DB
-inline bool IsRemoteEngine(std::string_view e) {
-  return e == kAnthropicEngine || e == kOpenAiEngine;
-}
 
 struct StreamContext {
   std::shared_ptr<std::function<void(Json::Value&&, Json::Value&&)>> callback;
diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc
index 2f7576f7f..69f7a136b 100644
--- a/engine/services/engine_service.cc
+++ b/engine/services/engine_service.cc
@@ -864,7 +864,7 @@ cpp::result<bool, std::string> EngineService::IsEngineReady(
   auto ne = NormalizeEngine(engine);
 
   // Check for remote engine
-  if (remote_engine::IsRemoteEngine(engine)) {
+  if (IsRemoteEngine(engine)) {
     auto exist_engine = GetEngineByNameAndVariant(engine);
     if (exist_engine.has_error()) {
       return cpp::fail("Remote engine '" + engine + "' is not installed");
@@ -1058,16 +1058,13 @@ cpp::result<Json::Value, std::string> EngineService::GetRemoteModels(
 }
 
 bool EngineService::IsRemoteEngine(const std::string& engine_name) {
-  cortex::db::Engines e_db;
-  auto res = e_db.GetEngines();
-  if (res) {
-    for (auto const& e : *res) {
-      if (e.engine_name == engine_name && e.type == kRemote) {
-        return true;
-      }
-    }
+  auto ne = Repo2Engine(engine_name);
+  auto local_engines = file_manager_utils::GetCortexConfig().supportedEngines;
+  for (auto const& le : local_engines) {
+    if (le == ne)
+      return false;
   }
-  return false;
+  return true;
 }
 
 cpp::result<std::vector<EngineVariantResponse>, std::string>

From 223eef69670109825393985794a3a82282e8bdc1 Mon Sep 17 00:00:00 2001
From: vansangpfiev
Date: Thu, 12 Dec 2024 14:49:36 +0700
Subject: [PATCH 5/7] chore: add unit tests

---
engine/test/components/CMakeLists.txt | 1 + engine/test/components/main.cc | 4 + engine/test/components/test_remote_engine.cc | 81 ++++++++++++++++++++ 3 files changed, 86 insertions(+) create mode 100644 engine/test/components/test_remote_engine.cc diff --git a/engine/test/components/CMakeLists.txt b/engine/test/components/CMakeLists.txt index 58c5d83d6..0df46cfc2 100644 --- a/engine/test/components/CMakeLists.txt +++ b/engine/test/components/CMakeLists.txt @@ -16,6 +16,7 @@ add_executable(${PROJECT_NAME} ${CMAKE_CURRENT_SOURCE_DIR}/../../utils/file_manager_utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../utils/curl_utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../utils/system_info_utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../extensions/remote-engine/template_renderer.cc ) find_package(Drogon CONFIG REQUIRED) diff --git a/engine/test/components/main.cc b/engine/test/components/main.cc index 08080680e..ba24a3e01 100644 --- a/engine/test/components/main.cc +++ b/engine/test/components/main.cc @@ -4,11 +4,15 @@ int main(int argc, char** argv) { ::testing::InitGoogleTest(&argc, argv); +#if defined(NDEBUG) ::testing::GTEST_FLAG(filter) = "-FileManagerConfigTest.*"; int ret = RUN_ALL_TESTS(); if (ret != 0) return ret; ::testing::GTEST_FLAG(filter) = "FileManagerConfigTest.*"; ret = RUN_ALL_TESTS(); +#else + int ret = RUN_ALL_TESTS(); +#endif return ret; } diff --git a/engine/test/components/test_remote_engine.cc b/engine/test/components/test_remote_engine.cc new file mode 100644 index 000000000..bfac76f49 --- /dev/null +++ b/engine/test/components/test_remote_engine.cc @@ -0,0 +1,81 @@ +#include "extensions/remote-engine/template_renderer.h" +#include "gtest/gtest.h" +#include "utils/json_helper.h" + +class RemoteEngineTest : public ::testing::Test {}; + +TEST_F(RemoteEngineTest, OpenAiToAnthropicRequest) { + std::string tpl = + R"({ + {% for key, value in input_request %} + {% if key == "messages" %} + {% if input_request.messages.0.role == "system" %} + "system": "{{ input_request.messages.0.content }}", + "messages": [ + {% for message in input_request.messages %} + {% if not loop.is_first %} + {"role": "{{ message.role }}", "content": "{{ message.content }}" } {% if not loop.is_last %},{% endif %} + {% endif %} + {% endfor %} + ] + {% else %} + "messages": [ + {% for message in input_request.messages %} + {"role": " {{ message.role}}", "content": "{{ message.content }}" } {% if not loop.is_last %},{% endif %} + {% endfor %} + ] + {% endif %} + {% else if key == "system" or key == "model" or key == "temperature" or key == "store" or key == "max_tokens" or key == "stream" or key == "presence_penalty" or key == "metadata" or key == "frequency_penalty" or key == "tools" or key == "tool_choice" or key == "logprobs" or key == "top_logprobs" or key == "logit_bias" or key == "n" or key == "modalities" or key == "prediction" or key == "response_format" or key == "service_tier" or key == "seed" or key == "stop" or key == "stream_options" or key == "top_p" or key == "parallel_tool_calls" or key == "user" %} + "{{ key }}": {{ tojson(value) }} + {% endif %} + {% if not loop.is_last %},{% endif %} + {% endfor %} })"; + { + std::string message_with_system = R"({ + "messages": [ + {"role": "system", "content": "You are a seasoned data scientist at a Fortune 500 company."}, + {"role": "user", "content": "Hello, world"} + ], + "model": "claude-3-5-sonnet-20241022", + "max_tokens": 1024, +})"; + + auto data = json_helper::ParseJsonString(message_with_system); + + remote_engine::TemplateRenderer rdr; + auto res = 
rdr.Render(tpl, data); + + auto res_json = json_helper::ParseJsonString(res); + EXPECT_EQ(data["model"].asString(), res_json["model"].asString()); + EXPECT_EQ(data["max_tokens"].asInt(), res_json["max_tokens"].asInt()); + for (auto const& msg : data["messages"]) { + if (msg["role"].asString() == "system") { + EXPECT_EQ(msg["content"].asString(), res_json["system"].asString()); + } else if (msg["role"].asString() == "user") { + EXPECT_EQ(msg["content"].asString(), + res_json["messages"][0]["content"].asString()); + } + } + } + + { + std::string message_without_system = R"({ + "messages": [ + {"role": "user", "content": "Hello, world"} + ], + "model": "claude-3-5-sonnet-20241022", + "max_tokens": 1024, +})"; + + auto data = json_helper::ParseJsonString(message_without_system); + + remote_engine::TemplateRenderer rdr; + auto res = rdr.Render(tpl, data); + + auto res_json = json_helper::ParseJsonString(res); + EXPECT_EQ(data["model"].asString(), res_json["model"].asString()); + EXPECT_EQ(data["max_tokens"].asInt(), res_json["max_tokens"].asInt()); + EXPECT_EQ(data["messages"][0]["content"].asString(), + res_json["messages"][0]["content"].asString()); + } +} \ No newline at end of file From 742cb027c2389552362093931d9b456abcce11c4 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Thu, 12 Dec 2024 14:55:45 +0700 Subject: [PATCH 6/7] fix: cleanup --- engine/CMakeLists.txt | 1 - engine/cli/CMakeLists.txt | 1 - engine/config/model_config.h | 29 ++++++- .../remote-engine/anthropic_engine.cc | 31 -------- .../remote-engine/anthropic_engine.h | 10 --- .../extensions/remote-engine/remote_engine.cc | 77 +++++++++++-------- .../extensions/remote-engine/remote_engine.h | 3 +- engine/services/engine_service.cc | 14 +--- 8 files changed, 74 insertions(+), 92 deletions(-) delete mode 100644 engine/extensions/remote-engine/anthropic_engine.cc delete mode 100644 engine/extensions/remote-engine/anthropic_engine.h diff --git a/engine/CMakeLists.txt b/engine/CMakeLists.txt index 23c8b78ac..b7f577669 100644 --- a/engine/CMakeLists.txt +++ b/engine/CMakeLists.txt @@ -143,7 +143,6 @@ add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/cpuid/cpu_info.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/file_logger.cc ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/remote_engine.cc - ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/anthropic_engine.cc ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/template_renderer.cc ) diff --git a/engine/cli/CMakeLists.txt b/engine/cli/CMakeLists.txt index 7083dc0d9..a0a69592a 100644 --- a/engine/cli/CMakeLists.txt +++ b/engine/cli/CMakeLists.txt @@ -83,7 +83,6 @@ add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/inference_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/hardware_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/remote_engine.cc - ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/anthropic_engine.cc ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/template_renderer.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/easywsclient.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/download_progress.cc diff --git a/engine/config/model_config.h b/engine/config/model_config.h index a6af81974..084f4519d 100644 --- a/engine/config/model_config.h +++ b/engine/config/model_config.h @@ -15,11 +15,34 @@ namespace config { namespace { const std::string kOpenAITransformReqTemplate = - R"({ {% set first = true %} {% for key, value in input_request %} {% if key == \"messages\" or key == \"model\" or key == \"temperature\" or key == 
\"store\" or key == \"max_tokens\" or key == \"stream\" or key == \"presence_penalty\" or key == \"metadata\" or key == \"frequency_penalty\" or key == \"tools\" or key == \"tool_choice\" or key == \"logprobs\" or key == \"top_logprobs\" or key == \"logit_bias\" or key == \"n\" or key == \"modalities\" or key == \"prediction\" or key == \"response_format\" or key == \"service_tier\" or key == \"seed\" or key == \"stop\" or key == \"stream_options\" or key == \"top_p\" or key == \"parallel_tool_calls\" or key == \"user\" %} {% if not first %},{% endif %} \"{{ key }}\": {{ tojson(value) }} {% set first = false %} {% endif %} {% endfor %} })"; + R"({ {% set first = true %} {% for key, value in input_request %} {% if key == "messages" or key == "model" or key == "temperature" or key == "store" or key == "max_tokens" or key == "stream" or key == "presence_penalty" or key == "metadata" or key == "frequency_penalty" or key == "tools" or key == "tool_choice" or key == "logprobs" or key == "top_logprobs" or key == "logit_bias" or key == "n" or key == "modalities" or key == "prediction" or key == "response_format" or key == "service_tier" or key == "seed" or key == "stop" or key == "stream_options" or key == "top_p" or key == "parallel_tool_calls" or key == "user" %} {% if not first %},{% endif %} "{{ key }}": {{ tojson(value) }} {% set first = false %} {% endif %} {% endfor %} })"; const std::string kOpenAITransformRespTemplate = - R"({ {%- set first = true -%} {%- for key, value in input_request -%} {%- if key == \"id\" or key == \"choices\" or key == \"created\" or key == \"model\" or key == \"service_tier\" or key == \"system_fingerprint\" or key == \"object\" or key == \"usage\" -%} {%- if not first -%},{%- endif -%} \"{{ key }}\": {{ tojson(value) }} {%- set first = false -%} {%- endif -%} {%- endfor -%} })"; + R"({ {%- set first = true -%} {%- for key, value in input_request -%} {%- if key == "id" or key == "choices" or key == "created" or key == "model" or key == "service_tier" or key == "system_fingerprint" or key == "object" or key == "usage" -%} {%- if not first -%},{%- endif -%} "{{ key }}": {{ tojson(value) }} {%- set first = false -%} {%- endif -%} {%- endfor -%} })"; const std::string kAnthropicTransformReqTemplate = - R"({ {% set first = true %} {% for key, value in input_request %} {% if key == \"system\" or key == \"messages\" or key == \"model\" or key == \"temperature\" or key == \"store\" or key == \"max_tokens\" or key == \"stream\" or key == \"presence_penalty\" or key == \"metadata\" or key == \"frequency_penalty\" or key == \"tools\" or key == \"tool_choice\" or key == \"logprobs\" or key == \"top_logprobs\" or key == \"logit_bias\" or key == \"n\" or key == \"modalities\" or key == \"prediction\" or key == \"response_format\" or key == \"service_tier\" or key == \"seed\" or key == \"stop\" or key == \"stream_options\" or key == \"top_p\" or key == \"parallel_tool_calls\" or key == \"user\" %} {% if not first %},{% endif %} \"{{ key }}\": {{ tojson(value) }} {% set first = false %} {% endif %} {% endfor %} })"; + R"({ + {% for key, value in input_request %} + {% if key == "messages" %} + {% if input_request.messages.0.role == "system" %} + "system": "{{ input_request.messages.0.content }}", + "messages": [ + {% for message in input_request.messages %} + {% if not loop.is_first %} + {"role": "{{ message.role }}", "content": "{{ message.content }}" } {% if not loop.is_last %},{% endif %} + {% endif %} + {% endfor %} + ] + {% else %} + "messages": [ + {% for message in 
input_request.messages %} + {"role": " {{ message.role}}", "content": "{{ message.content }}" } {% if not loop.is_last %},{% endif %} + {% endfor %} + ] + {% endif %} + {% else if key == "system" or key == "model" or key == "temperature" or key == "store" or key == "max_tokens" or key == "stream" or key == "presence_penalty" or key == "metadata" or key == "frequency_penalty" or key == "tools" or key == "tool_choice" or key == "logprobs" or key == "top_logprobs" or key == "logit_bias" or key == "n" or key == "modalities" or key == "prediction" or key == "response_format" or key == "service_tier" or key == "seed" or key == "stop" or key == "stream_options" or key == "top_p" or key == "parallel_tool_calls" or key == "user" %} + "{{ key }}": {{ tojson(value) }} + {% endif %} + {% if not loop.is_last %},{% endif %} + {% endfor %} })"; const std::string kAnthropicTransformRespTemplate = R"({ "id": "{{ input_request.id }}", "created": null, diff --git a/engine/extensions/remote-engine/anthropic_engine.cc b/engine/extensions/remote-engine/anthropic_engine.cc deleted file mode 100644 index 823ac4747..000000000 --- a/engine/extensions/remote-engine/anthropic_engine.cc +++ /dev/null @@ -1,31 +0,0 @@ -#include "anthropic_engine.h" -#include -#include -#include "utils/logging_utils.h" - -namespace remote_engine { -namespace { -constexpr const std::array kAnthropicModels = { - "claude-3-5-sonnet-20241022", "claude-3-5-haiku-20241022", - "claude-3-opus-20240229", "claude-3-sonnet-20240229", - "claude-3-haiku-20240307"}; -} - -Json::Value AnthropicEngine::GetRemoteModels() { - Json::Value json_resp; - Json::Value model_array(Json::arrayValue); - for (const auto& m : kAnthropicModels) { - Json::Value val; - val["id"] = std::string(m); - val["engine"] = "anthropic"; - val["created"] = "_"; - val["object"] = "model"; - model_array.append(val); - } - - json_resp["object"] = "list"; - json_resp["data"] = model_array; - CTL_INF("Remote models responded"); - return json_resp; -} -} // namespace remote_engine \ No newline at end of file diff --git a/engine/extensions/remote-engine/anthropic_engine.h b/engine/extensions/remote-engine/anthropic_engine.h deleted file mode 100644 index b4b98d2e4..000000000 --- a/engine/extensions/remote-engine/anthropic_engine.h +++ /dev/null @@ -1,10 +0,0 @@ -#pragma once -#include "remote_engine.h" - -namespace remote_engine { -class AnthropicEngine : public RemoteEngine { - public: - - Json::Value GetRemoteModels() override; -}; -} // namespace remote_engine \ No newline at end of file diff --git a/engine/extensions/remote-engine/remote_engine.cc b/engine/extensions/remote-engine/remote_engine.cc index 4b3f0be79..6361077dd 100644 --- a/engine/extensions/remote-engine/remote_engine.cc +++ b/engine/extensions/remote-engine/remote_engine.cc @@ -20,6 +20,11 @@ bool is_openai(const std::string& model) { return model.find("gpt") != std::string::npos; } +constexpr const std::array kAnthropicModels = { + "claude-3-5-sonnet-20241022", "claude-3-5-haiku-20241022", + "claude-3-opus-20240229", "claude-3-sonnet-20240229", + "claude-3-haiku-20240307"}; + } // namespace size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb, @@ -34,15 +39,12 @@ size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb, while ((pos = context->buffer.find('\n')) != std::string::npos) { std::string line = context->buffer.substr(0, pos); context->buffer = context->buffer.substr(pos + 1); - // CTL_INF(line); // Skip empty lines if (line.empty() || line == "\r" || line.find("event:") != 
std::string::npos) continue; - // Skip [DONE] message - // std::cout << line << std::endl; CTL_DBG(line); if (line == "data: [DONE]" || line.find("message_stop") != std::string::npos) { @@ -178,7 +180,8 @@ static size_t WriteCallback(char* ptr, size_t size, size_t nmemb, return size * nmemb; } -RemoteEngine::RemoteEngine() { +RemoteEngine::RemoteEngine(const std::string& engine_name) + : engine_name_(engine_name) { curl_global_init(CURL_GLOBAL_ALL); } @@ -522,23 +525,6 @@ void RemoteEngine::HandleChatCompletion( std::string(e.what())); } - // Parse system for anthropic - if (is_anthropic(model)) { - bool has_system = false; - Json::Value msgs(Json::arrayValue); - for (auto& kv : (*json_body)["messages"]) { - if (kv["role"].asString() == "system") { - (*json_body)["system"] = kv["content"].asString(); - has_system = true; - } else { - msgs.append(kv); - } - } - if (has_system) { - (*json_body)["messages"] = msgs; - } - } - // Render with error handling try { result = renderer_.Render(template_str, *json_body); @@ -701,20 +687,43 @@ void RemoteEngine::HandleEmbedding( } Json::Value RemoteEngine::GetRemoteModels() { - auto response = MakeGetModelsRequest(); - if (response.error) { - Json::Value error; - error["error"] = response.error_message; - return error; - } - Json::Value response_json; - Json::Reader reader; - if (!reader.parse(response.body, response_json)) { - Json::Value error; - error["error"] = "Failed to parse response"; - return error; + if (metadata_["get_models_url"].isNull() || + metadata_["get_models_url"].asString().empty()) { + if (engine_name_ == kAnthropicEngine) { + Json::Value json_resp; + Json::Value model_array(Json::arrayValue); + for (const auto& m : kAnthropicModels) { + Json::Value val; + val["id"] = std::string(m); + val["engine"] = "anthropic"; + val["created"] = "_"; + val["object"] = "model"; + model_array.append(val); + } + + json_resp["object"] = "list"; + json_resp["data"] = model_array; + CTL_INF("Remote models responded"); + return json_resp; + } else { + return Json::Value(); + } + } else { + auto response = MakeGetModelsRequest(); + if (response.error) { + Json::Value error; + error["error"] = response.error_message; + return error; + } + Json::Value response_json; + Json::Reader reader; + if (!reader.parse(response.body, response_json)) { + Json::Value error; + error["error"] = "Failed to parse response"; + return error; + } + return response_json; } - return response_json; } } // namespace remote_engine \ No newline at end of file diff --git a/engine/extensions/remote-engine/remote_engine.h b/engine/extensions/remote-engine/remote_engine.h index 92d9e8126..d8dfbad61 100644 --- a/engine/extensions/remote-engine/remote_engine.h +++ b/engine/extensions/remote-engine/remote_engine.h @@ -51,6 +51,7 @@ class RemoteEngine : public RemoteEngineI { std::string chat_req_template_; std::string chat_res_template_; std::string api_key_template_; + std::string engine_name_; // Helper functions CurlResponse MakeChatCompletionRequest(const ModelConfig& config, @@ -67,7 +68,7 @@ class RemoteEngine : public RemoteEngineI { ModelConfig* GetModelConfig(const std::string& model); public: - RemoteEngine(); + explicit RemoteEngine(const std::string& engine_name); virtual ~RemoteEngine(); // Main interface implementations diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index 69f7a136b..adf926db2 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -6,7 +6,7 @@ #include #include "algorithm" #include 
"database/engines.h" -#include "extensions/remote-engine/anthropic_engine.h" +#include "extensions/remote-engine/remote_engine.h" #include "utils/archive_utils.h" #include "utils/engine_constants.h" #include "utils/engine_matcher_utils.h" @@ -682,11 +682,7 @@ cpp::result EngineService::LoadEngine( return cpp::fail("Remote engine '" + engine_name + "' is not installed"); } - if (engine_name != kAnthropicEngine) { - engines_[engine_name].engine = new remote_engine::RemoteEngine(); - } else { - engines_[engine_name].engine = new remote_engine::AnthropicEngine(); - } + engines_[engine_name].engine = new remote_engine::RemoteEngine(engine_name); CTL_INF("Loaded engine: " << engine_name); return {}; @@ -1040,11 +1036,7 @@ cpp::result EngineService::GetRemoteModels( if (exist_engine.has_error()) { return cpp::fail("Remote engine '" + engine_name + "' is not installed"); } - if (engine_name != kAnthropicEngine) { - engines_[engine_name].engine = new remote_engine::RemoteEngine(); - } else { - engines_[engine_name].engine = new remote_engine::AnthropicEngine(); - } + engines_[engine_name].engine = new remote_engine::RemoteEngine(engine_name); CTL_INF("Loaded engine: " << engine_name); } From 127b8fa0852fe4efadc7c6e85191000fb035c5c2 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Mon, 16 Dec 2024 09:11:44 +0700 Subject: [PATCH 7/7] chore: cleanup --- engine/config/model_config.h | 66 +-------------------------------- engine/config/remote_template.h | 66 +++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 65 deletions(-) create mode 100644 engine/config/remote_template.h diff --git a/engine/config/model_config.h b/engine/config/model_config.h index 084f4519d..a799adb27 100644 --- a/engine/config/model_config.h +++ b/engine/config/model_config.h @@ -8,76 +8,12 @@ #include #include #include +#include "config/remote_template.h" #include "utils/format_utils.h" #include "utils/remote_models_utils.h" namespace config { -namespace { -const std::string kOpenAITransformReqTemplate = - R"({ {% set first = true %} {% for key, value in input_request %} {% if key == "messages" or key == "model" or key == "temperature" or key == "store" or key == "max_tokens" or key == "stream" or key == "presence_penalty" or key == "metadata" or key == "frequency_penalty" or key == "tools" or key == "tool_choice" or key == "logprobs" or key == "top_logprobs" or key == "logit_bias" or key == "n" or key == "modalities" or key == "prediction" or key == "response_format" or key == "service_tier" or key == "seed" or key == "stop" or key == "stream_options" or key == "top_p" or key == "parallel_tool_calls" or key == "user" %} {% if not first %},{% endif %} "{{ key }}": {{ tojson(value) }} {% set first = false %} {% endif %} {% endfor %} })"; -const std::string kOpenAITransformRespTemplate = - R"({ {%- set first = true -%} {%- for key, value in input_request -%} {%- if key == "id" or key == "choices" or key == "created" or key == "model" or key == "service_tier" or key == "system_fingerprint" or key == "object" or key == "usage" -%} {%- if not first -%},{%- endif -%} "{{ key }}": {{ tojson(value) }} {%- set first = false -%} {%- endif -%} {%- endfor -%} })"; -const std::string kAnthropicTransformReqTemplate = - R"({ - {% for key, value in input_request %} - {% if key == "messages" %} - {% if input_request.messages.0.role == "system" %} - "system": "{{ input_request.messages.0.content }}", - "messages": [ - {% for message in input_request.messages %} - {% if not loop.is_first %} - {"role": "{{ message.role }}", 
"content": "{{ message.content }}" } {% if not loop.is_last %},{% endif %} - {% endif %} - {% endfor %} - ] - {% else %} - "messages": [ - {% for message in input_request.messages %} - {"role": " {{ message.role}}", "content": "{{ message.content }}" } {% if not loop.is_last %},{% endif %} - {% endfor %} - ] - {% endif %} - {% else if key == "system" or key == "model" or key == "temperature" or key == "store" or key == "max_tokens" or key == "stream" or key == "presence_penalty" or key == "metadata" or key == "frequency_penalty" or key == "tools" or key == "tool_choice" or key == "logprobs" or key == "top_logprobs" or key == "logit_bias" or key == "n" or key == "modalities" or key == "prediction" or key == "response_format" or key == "service_tier" or key == "seed" or key == "stop" or key == "stream_options" or key == "top_p" or key == "parallel_tool_calls" or key == "user" %} - "{{ key }}": {{ tojson(value) }} - {% endif %} - {% if not loop.is_last %},{% endif %} - {% endfor %} })"; -const std::string kAnthropicTransformRespTemplate = R"({ - "id": "{{ input_request.id }}", - "created": null, - "object": "chat.completion", - "model": "{{ input_request.model }}", - "choices": [ - { - "index": 0, - "message": { - "role": "{{ input_request.role }}", - "content": "{% if input_request.content and input_request.content.0.type == "text" %} {{input_request.content.0.text}} {% endif %}", - "refusal": null - }, - "logprobs": null, - "finish_reason": "{{ input_request.stop_reason }}" - } - ], - "usage": { - "prompt_tokens": {{ input_request.usage.input_tokens }}, - "completion_tokens": {{ input_request.usage.output_tokens }}, - "total_tokens": {{ input_request.usage.input_tokens + input_request.usage.output_tokens }}, - "prompt_tokens_details": { - "cached_tokens": 0 - }, - "completion_tokens_details": { - "reasoning_tokens": 0, - "accepted_prediction_tokens": 0, - "rejected_prediction_tokens": 0 - } - }, - "system_fingerprint": "fp_6b68a8204b" - })"; - -} // namespace - struct RemoteModelConfig { std::string model; std::string api_key_template; diff --git a/engine/config/remote_template.h b/engine/config/remote_template.h new file mode 100644 index 000000000..8a17aaa9a --- /dev/null +++ b/engine/config/remote_template.h @@ -0,0 +1,66 @@ +#include + +namespace config { +const std::string kOpenAITransformReqTemplate = + R"({ {% set first = true %} {% for key, value in input_request %} {% if key == "messages" or key == "model" or key == "temperature" or key == "store" or key == "max_tokens" or key == "stream" or key == "presence_penalty" or key == "metadata" or key == "frequency_penalty" or key == "tools" or key == "tool_choice" or key == "logprobs" or key == "top_logprobs" or key == "logit_bias" or key == "n" or key == "modalities" or key == "prediction" or key == "response_format" or key == "service_tier" or key == "seed" or key == "stop" or key == "stream_options" or key == "top_p" or key == "parallel_tool_calls" or key == "user" %} {% if not first %},{% endif %} "{{ key }}": {{ tojson(value) }} {% set first = false %} {% endif %} {% endfor %} })"; +const std::string kOpenAITransformRespTemplate = + R"({ {%- set first = true -%} {%- for key, value in input_request -%} {%- if key == "id" or key == "choices" or key == "created" or key == "model" or key == "service_tier" or key == "system_fingerprint" or key == "object" or key == "usage" -%} {%- if not first -%},{%- endif -%} "{{ key }}": {{ tojson(value) }} {%- set first = false -%} {%- endif -%} {%- endfor -%} })"; +const std::string 
kAnthropicTransformReqTemplate = + R"({ + {% for key, value in input_request %} + {% if key == "messages" %} + {% if input_request.messages.0.role == "system" %} + "system": "{{ input_request.messages.0.content }}", + "messages": [ + {% for message in input_request.messages %} + {% if not loop.is_first %} + {"role": "{{ message.role }}", "content": "{{ message.content }}" } {% if not loop.is_last %},{% endif %} + {% endif %} + {% endfor %} + ] + {% else %} + "messages": [ + {% for message in input_request.messages %} + {"role": " {{ message.role}}", "content": "{{ message.content }}" } {% if not loop.is_last %},{% endif %} + {% endfor %} + ] + {% endif %} + {% else if key == "system" or key == "model" or key == "temperature" or key == "store" or key == "max_tokens" or key == "stream" or key == "presence_penalty" or key == "metadata" or key == "frequency_penalty" or key == "tools" or key == "tool_choice" or key == "logprobs" or key == "top_logprobs" or key == "logit_bias" or key == "n" or key == "modalities" or key == "prediction" or key == "response_format" or key == "service_tier" or key == "seed" or key == "stop" or key == "stream_options" or key == "top_p" or key == "parallel_tool_calls" or key == "user" %} + "{{ key }}": {{ tojson(value) }} + {% endif %} + {% if not loop.is_last %},{% endif %} + {% endfor %} })"; +const std::string kAnthropicTransformRespTemplate = R"({ + "id": "{{ input_request.id }}", + "created": null, + "object": "chat.completion", + "model": "{{ input_request.model }}", + "choices": [ + { + "index": 0, + "message": { + "role": "{{ input_request.role }}", + "content": "{% if input_request.content and input_request.content.0.type == "text" %} {{input_request.content.0.text}} {% endif %}", + "refusal": null + }, + "logprobs": null, + "finish_reason": "{{ input_request.stop_reason }}" + } + ], + "usage": { + "prompt_tokens": {{ input_request.usage.input_tokens }}, + "completion_tokens": {{ input_request.usage.output_tokens }}, + "total_tokens": {{ input_request.usage.input_tokens + input_request.usage.output_tokens }}, + "prompt_tokens_details": { + "cached_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0 + } + }, + "system_fingerprint": "fp_6b68a8204b" + })"; + +} // namespace config \ No newline at end of file
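
Note on the transform templates above: they are inja templates, rendered by the engine's TemplateRenderer with the request or response JSON bound to input_request (or, for streaming, the raw chunk fields such as type and delta), and with a tojson callback registered so values can be emitted as raw JSON. A minimal standalone sketch of the same mechanism, using the inja and nlohmann::json libraries directly (the template here is a trimmed illustration of the streaming transform, not the exact string shipped in remote_template.h):

// render_demo.cc - sketch: turn one Anthropic SSE payload into an
// OpenAI-style chat.completion.chunk, the way the stream template does.
#include <iostream>
#include <string>
#include <inja/inja.hpp>
#include <nlohmann/json.hpp>

int main() {
  inja::Environment env;
  // The engine's TemplateRenderer registers a "tojson" callback so template
  // values can be dumped as raw JSON; a minimal equivalent just echoes it.
  env.add_callback("tojson", 1,
                   [](inja::Arguments& args) { return *args.at(0); });

  // Trimmed stand-in for the stream-response template above, reduced to the
  // content_block_delta / content_block_stop cases.
  const std::string tpl = R"({
  "object": "chat.completion.chunk",
  "choices": [{
    "index": 0,
    "delta": {"role": "assistant",
              "content": "{% if type == "content_block_delta" %}{{ delta.text }}{% endif %}"},
    "finish_reason": {% if type == "content_block_stop" %}"stop"{% else %}null{% endif %}
  }]
})";

  // A typical Anthropic "content_block_delta" payload, after the "data: "
  // prefix has been stripped by the stream write callback.
  nlohmann::json chunk = {
      {"type", "content_block_delta"},
      {"index", 0},
      {"delta", {{"type", "text_delta"}, {"text", "Hello"}}}};

  std::cout << env.render(tpl, chunk) << "\n";
}

Rendering the message_start and content_block_stop variants works the same way; the real stream callback additionally prepends "data: " and appends the double newline required by SSE framing before handing the chunk to the caller.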
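The SSE framing that feeds those templates is worth a sketch as well. This condenses the buffering loop of StreamWriteCallback in remote_engine.cc; the callback type is simplified here to a plain std::function, whereas the real code hands the rendered chunk to the shared Json::Value callback and also tracks the message id and model:

// sse_framing_demo.cc - condensed sketch of StreamWriteCallback's framing:
// network chunks are buffered and consumed one complete line at a time.
#include <cstddef>
#include <functional>
#include <string>

void FeedSseBytes(std::string& buffer, const char* ptr, size_t len,
                  const std::function<void(const std::string&)>& on_json) {
  buffer.append(ptr, len);
  size_t pos;
  while ((pos = buffer.find('\n')) != std::string::npos) {
    std::string line = buffer.substr(0, pos);
    buffer.erase(0, pos + 1);
    // Skip keep-alive blanks and SSE "event:" lines, as the real code does.
    if (line.empty() || line == "\r" ||
        line.find("event:") != std::string::npos)
      continue;
    // End-of-stream markers: "data: [DONE]" for OpenAI-style streams,
    // "message_stop" for Anthropic.
    if (line == "data: [DONE]" ||
        line.find("message_stop") != std::string::npos)
      return;
    if (line.size() > 6)
      on_json(line.substr(6));  // drop "data: ", pass the JSON payload on
  }
}

Buffering before splitting matters because libcurl delivers the response body in arbitrary-sized pieces, so a single SSE line can span several invocations of the write callback.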