diff --git a/engine/CMakeLists.txt b/engine/CMakeLists.txt
index 41ebb3dd6..25c0783b1 100644
--- a/engine/CMakeLists.txt
+++ b/engine/CMakeLists.txt
@@ -144,8 +144,6 @@ add_executable(${TARGET_NAME} main.cc
   ${CMAKE_CURRENT_SOURCE_DIR}/utils/file_logger.cc
   ${CMAKE_CURRENT_SOURCE_DIR}/utils/dylib_path_manager.cc
   ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/remote_engine.cc
-  ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/openai_engine.cc
-  ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/anthropic_engine.cc
   ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/template_renderer.cc
 )
 
diff --git a/engine/cli/CMakeLists.txt b/engine/cli/CMakeLists.txt
index 237596f21..df4f1a76b 100644
--- a/engine/cli/CMakeLists.txt
+++ b/engine/cli/CMakeLists.txt
@@ -84,8 +84,6 @@ add_executable(${TARGET_NAME} main.cc
   ${CMAKE_CURRENT_SOURCE_DIR}/../services/inference_service.cc
   ${CMAKE_CURRENT_SOURCE_DIR}/../services/hardware_service.cc
   ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/remote_engine.cc
-  ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/openai_engine.cc
-  ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/anthropic_engine.cc
   ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/template_renderer.cc
   ${CMAKE_CURRENT_SOURCE_DIR}/utils/easywsclient.cc
   ${CMAKE_CURRENT_SOURCE_DIR}/utils/download_progress.cc
diff --git a/engine/common/engine_servicei.h b/engine/common/engine_servicei.h
index 85fa87d76..a4b0c8732 100644
--- a/engine/common/engine_servicei.h
+++ b/engine/common/engine_servicei.h
@@ -58,4 +58,6 @@ class EngineServiceI {
   GetEngineByNameAndVariant(
       const std::string& engine_name,
       const std::optional<std::string> variant = std::nullopt) = 0;
+
+  virtual bool IsRemoteEngine(const std::string& engine_name) = 0;
 };
diff --git a/engine/config/model_config.h b/engine/config/model_config.h
index 84e175d54..a799adb27 100644
--- a/engine/config/model_config.h
+++ b/engine/config/model_config.h
@@ -8,52 +8,12 @@
 #include
 #include
 #include
+#include "config/remote_template.h"
 #include "utils/format_utils.h"
 #include "utils/remote_models_utils.h"
 
 namespace config {
 
-namespace {
-const std::string kOpenAITransformReqTemplate =
-    R"({ {% set first = true %} {% for key, value in input_request %} {% if key == \"messages\" or key == \"model\" or key == \"temperature\" or key == \"store\" or key == \"max_tokens\" or key == \"stream\" or key == \"presence_penalty\" or key == \"metadata\" or key == \"frequency_penalty\" or key == \"tools\" or key == \"tool_choice\" or key == \"logprobs\" or key == \"top_logprobs\" or key == \"logit_bias\" or key == \"n\" or key == \"modalities\" or key == \"prediction\" or key == \"response_format\" or key == \"service_tier\" or key == \"seed\" or key == \"stop\" or key == \"stream_options\" or key == \"top_p\" or key == \"parallel_tool_calls\" or key == \"user\" %} {% if not first %},{% endif %} \"{{ key }}\": {{ tojson(value) }} {% set first = false %} {% endif %} {% endfor %} })";
-const std::string kOpenAITransformRespTemplate =
-    R"({ {%- set first = true -%} {%- for key, value in input_request -%} {%- if key == \"id\" or key == \"choices\" or key == \"created\" or key == \"model\" or key == \"service_tier\" or key == \"system_fingerprint\" or key == \"object\" or key == \"usage\" -%} {%- if not first -%},{%- endif -%} \"{{ key }}\": {{ tojson(value) }} {%- set first = false -%} {%- endif -%} {%- endfor -%} })";
-const std::string kAnthropicTransformReqTemplate =
-    R"({ {% set first = true %} {% for key, value in input_request %} {% if key == \"system\" or key == \"messages\" or key == \"model\" or key == \"temperature\" or key == \"store\" or key == \"max_tokens\" or key == \"stream\" or key == \"presence_penalty\" or key == \"metadata\" or key == \"frequency_penalty\" or key == \"tools\" or key == \"tool_choice\" or key == \"logprobs\" or key == \"top_logprobs\" or key == \"logit_bias\" or key == \"n\" or key == \"modalities\" or key == \"prediction\" or key == \"response_format\" or key == \"service_tier\" or key == \"seed\" or key == \"stop\" or key == \"stream_options\" or key == \"top_p\" or key == \"parallel_tool_calls\" or key == \"user\" %} {% if not first %},{% endif %} \"{{ key }}\": {{ tojson(value) }} {% set first = false %} {% endif %} {% endfor %} })";
-const std::string kAnthropicTransformRespTemplate = R"({
-        "id": "{{ input_request.id }}",
-        "created": null,
-        "object": "chat.completion",
-        "model": "{{ input_request.model }}",
-        "choices": [
-          {
-            "index": 0,
-            "message": {
-              "role": "{{ input_request.role }}",
-              "content": "{% if input_request.content and input_request.content.0.type == "text" %} {{input_request.content.0.text}} {% endif %}",
-              "refusal": null
-            },
-            "logprobs": null,
-            "finish_reason": "{{ input_request.stop_reason }}"
-          }
-        ],
-        "usage": {
-          "prompt_tokens": {{ input_request.usage.input_tokens }},
-          "completion_tokens": {{ input_request.usage.output_tokens }},
-          "total_tokens": {{ input_request.usage.input_tokens + input_request.usage.output_tokens }},
-          "prompt_tokens_details": {
-            "cached_tokens": 0
-          },
-          "completion_tokens_details": {
-            "reasoning_tokens": 0,
-            "accepted_prediction_tokens": 0,
-            "rejected_prediction_tokens": 0
-          }
-        },
-        "system_fingerprint": "fp_6b68a8204b"
-      })";
-}  // namespace
-
 struct RemoteModelConfig {
   std::string model;
   std::string api_key_template;
@@ -108,6 +68,7 @@ struct RemoteModelConfig {
             kOpenAITransformRespTemplate;
       }
     }
+    metadata = json.get("metadata", metadata);
   }
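Review note: this diff only shows that `LoadFromJson` now persists a `metadata` blob; the keys that matter downstream are the ones `RemoteEngine::LoadModel` and `RemoteEngine::GetRemoteModels` read later in this patch (`TransformReq`/`TransformResp` chat templates and `get_models_url`). A sketch of a payload that would exercise all three paths — the values are illustrative, not prescribed by the code:

    // Illustrative metadata blob for a remote model entry. Only the keys
    // shown here are consumed by this diff; everything travels through
    // Json::Value, so any extra keys are preserved untouched.
    const std::string kExampleMetadata = R"({
      "metadata": {
        "get_models_url": "https://api.openai.com/v1/models",
        "TransformReq":  { "chat_completions": { "template": "<inja request template>"  } },
        "TransformResp": { "chat_completions": { "template": "<inja response template>" } }
      }
    })";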
\"system\" or key == \"messages\" or key == \"model\" or key == \"temperature\" or key == \"store\" or key == \"max_tokens\" or key == \"stream\" or key == \"presence_penalty\" or key == \"metadata\" or key == \"frequency_penalty\" or key == \"tools\" or key == \"tool_choice\" or key == \"logprobs\" or key == \"top_logprobs\" or key == \"logit_bias\" or key == \"n\" or key == \"modalities\" or key == \"prediction\" or key == \"response_format\" or key == \"service_tier\" or key == \"seed\" or key == \"stop\" or key == \"stream_options\" or key == \"top_p\" or key == \"parallel_tool_calls\" or key == \"user\" %} {% if not first %},{% endif %} \"{{ key }}\": {{ tojson(value) }} {% set first = false %} {% endif %} {% endfor %} })"; -const std::string kAnthropicTransformRespTemplate = R"({ - "id": "{{ input_request.id }}", - "created": null, - "object": "chat.completion", - "model": "{{ input_request.model }}", - "choices": [ - { - "index": 0, - "message": { - "role": "{{ input_request.role }}", - "content": "{% if input_request.content and input_request.content.0.type == "text" %} {{input_request.content.0.text}} {% endif %}", - "refusal": null - }, - "logprobs": null, - "finish_reason": "{{ input_request.stop_reason }}" - } - ], - "usage": { - "prompt_tokens": {{ input_request.usage.input_tokens }}, - "completion_tokens": {{ input_request.usage.output_tokens }}, - "total_tokens": {{ input_request.usage.input_tokens + input_request.usage.output_tokens }}, - "prompt_tokens_details": { - "cached_tokens": 0 - }, - "completion_tokens_details": { - "reasoning_tokens": 0, - "accepted_prediction_tokens": 0, - "rejected_prediction_tokens": 0 - } - }, - "system_fingerprint": "fp_6b68a8204b" - })"; -} // namespace - struct RemoteModelConfig { std::string model; std::string api_key_template; @@ -108,6 +68,7 @@ struct RemoteModelConfig { kOpenAITransformRespTemplate; } } + metadata = json.get("metadata", metadata); } diff --git a/engine/config/remote_template.h b/engine/config/remote_template.h new file mode 100644 index 000000000..8a17aaa9a --- /dev/null +++ b/engine/config/remote_template.h @@ -0,0 +1,66 @@ +#include + +namespace config { +const std::string kOpenAITransformReqTemplate = + R"({ {% set first = true %} {% for key, value in input_request %} {% if key == "messages" or key == "model" or key == "temperature" or key == "store" or key == "max_tokens" or key == "stream" or key == "presence_penalty" or key == "metadata" or key == "frequency_penalty" or key == "tools" or key == "tool_choice" or key == "logprobs" or key == "top_logprobs" or key == "logit_bias" or key == "n" or key == "modalities" or key == "prediction" or key == "response_format" or key == "service_tier" or key == "seed" or key == "stop" or key == "stream_options" or key == "top_p" or key == "parallel_tool_calls" or key == "user" %} {% if not first %},{% endif %} "{{ key }}": {{ tojson(value) }} {% set first = false %} {% endif %} {% endfor %} })"; +const std::string kOpenAITransformRespTemplate = + R"({ {%- set first = true -%} {%- for key, value in input_request -%} {%- if key == "id" or key == "choices" or key == "created" or key == "model" or key == "service_tier" or key == "system_fingerprint" or key == "object" or key == "usage" -%} {%- if not first -%},{%- endif -%} "{{ key }}": {{ tojson(value) }} {%- set first = false -%} {%- endif -%} {%- endfor -%} })"; +const std::string kAnthropicTransformReqTemplate = + R"({ + {% for key, value in input_request %} + {% if key == "messages" %} + {% if input_request.messages.0.role == 
"system" %} + "system": "{{ input_request.messages.0.content }}", + "messages": [ + {% for message in input_request.messages %} + {% if not loop.is_first %} + {"role": "{{ message.role }}", "content": "{{ message.content }}" } {% if not loop.is_last %},{% endif %} + {% endif %} + {% endfor %} + ] + {% else %} + "messages": [ + {% for message in input_request.messages %} + {"role": " {{ message.role}}", "content": "{{ message.content }}" } {% if not loop.is_last %},{% endif %} + {% endfor %} + ] + {% endif %} + {% else if key == "system" or key == "model" or key == "temperature" or key == "store" or key == "max_tokens" or key == "stream" or key == "presence_penalty" or key == "metadata" or key == "frequency_penalty" or key == "tools" or key == "tool_choice" or key == "logprobs" or key == "top_logprobs" or key == "logit_bias" or key == "n" or key == "modalities" or key == "prediction" or key == "response_format" or key == "service_tier" or key == "seed" or key == "stop" or key == "stream_options" or key == "top_p" or key == "parallel_tool_calls" or key == "user" %} + "{{ key }}": {{ tojson(value) }} + {% endif %} + {% if not loop.is_last %},{% endif %} + {% endfor %} })"; +const std::string kAnthropicTransformRespTemplate = R"({ + "id": "{{ input_request.id }}", + "created": null, + "object": "chat.completion", + "model": "{{ input_request.model }}", + "choices": [ + { + "index": 0, + "message": { + "role": "{{ input_request.role }}", + "content": "{% if input_request.content and input_request.content.0.type == "text" %} {{input_request.content.0.text}} {% endif %}", + "refusal": null + }, + "logprobs": null, + "finish_reason": "{{ input_request.stop_reason }}" + } + ], + "usage": { + "prompt_tokens": {{ input_request.usage.input_tokens }}, + "completion_tokens": {{ input_request.usage.output_tokens }}, + "total_tokens": {{ input_request.usage.input_tokens + input_request.usage.output_tokens }}, + "prompt_tokens_details": { + "cached_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0 + } + }, + "system_fingerprint": "fp_6b68a8204b" + })"; + +} // namespace config \ No newline at end of file diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index affa45d52..59793b2a6 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -200,7 +200,7 @@ void Models::ListModel( .string()); auto model_config = yaml_handler.GetModelConfig(); - if (!remote_engine::IsRemoteEngine(model_config.engine)) { + if (!engine_service_->IsRemoteEngine(model_config.engine)) { Json::Value obj = model_config.ToJson(); obj["id"] = model_entry.model; obj["model"] = model_entry.model; @@ -632,7 +632,7 @@ void Models::GetRemoteModels( const HttpRequestPtr& req, std::function&& callback, const std::string& engine_id) { - if (!remote_engine::IsRemoteEngine(engine_id)) { + if (!engine_service_->IsRemoteEngine(engine_id)) { Json::Value ret; ret["message"] = "Not a remote engine: " + engine_id; auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); @@ -668,8 +668,7 @@ void Models::AddRemoteModel( auto model_handle = (*(req->getJsonObject())).get("model", "").asString(); auto engine_name = (*(req->getJsonObject())).get("engine", "").asString(); - /* To do: uncomment when remote engine is ready - + auto engine_validate = engine_service_->IsEngineReady(engine_name); if (engine_validate.has_error()) { Json::Value ret; @@ -679,6 +678,7 @@ void Models::AddRemoteModel( callback(resp); return; } + 
diff --git a/engine/extensions/remote-engine/anthropic_engine.cc b/engine/extensions/remote-engine/anthropic_engine.cc
deleted file mode 100644
index 847cba566..000000000
--- a/engine/extensions/remote-engine/anthropic_engine.cc
+++ /dev/null
@@ -1,62 +0,0 @@
-#include "anthropic_engine.h"
-#include <array>
-#include <shared_mutex>
-#include "utils/logging_utils.h"
-
-namespace remote_engine {
-namespace {
-constexpr const std::array<const char*, 5> kAnthropicModels = {
-    "claude-3-5-sonnet-20241022", "claude-3-5-haiku-20241022",
-    "claude-3-opus-20240229", "claude-3-sonnet-20240229",
-    "claude-3-haiku-20240307"};
-}
-void AnthropicEngine::GetModels(
-    std::shared_ptr<Json::Value> json_body,
-    std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
-  Json::Value json_resp;
-  Json::Value model_array(Json::arrayValue);
-  {
-    std::shared_lock l(models_mtx_);
-    for (const auto& [m, _] : models_) {
-      Json::Value val;
-      val["id"] = m;
-      val["engine"] = "anthropic";
-      val["start_time"] = "_";
-      val["model_size"] = "_";
-      val["vram"] = "_";
-      val["ram"] = "_";
-      val["object"] = "model";
-      model_array.append(val);
-    }
-  }
-
-  json_resp["object"] = "list";
-  json_resp["data"] = model_array;
-
-  Json::Value status;
-  status["is_done"] = true;
-  status["has_error"] = false;
-  status["is_stream"] = false;
-  status["status_code"] = 200;
-  callback(std::move(status), std::move(json_resp));
-  CTL_INF("Running models responded");
-}
-
-Json::Value AnthropicEngine::GetRemoteModels() {
-  Json::Value json_resp;
-  Json::Value model_array(Json::arrayValue);
-  for (const auto& m : kAnthropicModels) {
-    Json::Value val;
-    val["id"] = std::string(m);
-    val["engine"] = "anthropic";
-    val["created"] = "_";
-    val["object"] = "model";
-    model_array.append(val);
-  }
-
-  json_resp["object"] = "list";
-  json_resp["data"] = model_array;
-  CTL_INF("Remote models responded");
-  return json_resp;
-}
-}  // namespace remote_engine
\ No newline at end of file
diff --git a/engine/extensions/remote-engine/anthropic_engine.h b/engine/extensions/remote-engine/anthropic_engine.h
deleted file mode 100644
index bcd3dfaf7..000000000
--- a/engine/extensions/remote-engine/anthropic_engine.h
+++ /dev/null
@@ -1,13 +0,0 @@
-#pragma once
-#include "remote_engine.h"
-
-namespace remote_engine {
-  class AnthropicEngine: public RemoteEngine {
-public:
-  void GetModels(
-      std::shared_ptr<Json::Value> json_body,
-      std::function<void(Json::Value&&, Json::Value&&)>&& callback) override;
-
-  Json::Value GetRemoteModels() override;
-  };
-}
\ No newline at end of file
diff --git a/engine/extensions/remote-engine/openai_engine.cc b/engine/extensions/remote-engine/openai_engine.cc
deleted file mode 100644
index 7c7d70385..000000000
--- a/engine/extensions/remote-engine/openai_engine.cc
+++ /dev/null
@@ -1,54 +0,0 @@
-#include "openai_engine.h"
-#include "utils/logging_utils.h"
-
-namespace remote_engine {
-
-void OpenAiEngine::GetModels(
-    std::shared_ptr<Json::Value> json_body,
-    std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
-  Json::Value json_resp;
-  Json::Value model_array(Json::arrayValue);
-  {
-    std::shared_lock l(models_mtx_);
-    for (const auto& [m, _] : models_) {
-      Json::Value val;
-      val["id"] = m;
-      val["engine"] = "openai";
-      val["start_time"] = "_";
-      val["model_size"] = "_";
-      val["vram"] = "_";
-      val["ram"] = "_";
-      val["object"] = "model";
-      model_array.append(val);
-    }
-  }
-
-  json_resp["object"] = "list";
-  json_resp["data"] = model_array;
-
-  Json::Value status;
-  status["is_done"] = true;
-  status["has_error"] = false;
-  status["is_stream"] = false;
-  status["status_code"] = 200;
-  callback(std::move(status), std::move(json_resp));
-  CTL_INF("Running models responded");
-}
-
-Json::Value OpenAiEngine::GetRemoteModels() {
-  auto response = MakeGetModelsRequest();
-  if (response.error) {
-    Json::Value error;
-    error["error"] = response.error_message;
-    return error;
-  }
-  Json::Value response_json;
-  Json::Reader reader;
-  if (!reader.parse(response.body, response_json)) {
-    Json::Value error;
-    error["error"] = "Failed to parse response";
-    return error;
-  }
-  return response_json;
-}
-}  // namespace remote_engine
\ No newline at end of file
diff --git a/engine/extensions/remote-engine/openai_engine.h b/engine/extensions/remote-engine/openai_engine.h
deleted file mode 100644
index 61dc68f0c..000000000
--- a/engine/extensions/remote-engine/openai_engine.h
+++ /dev/null
@@ -1,14 +0,0 @@
-#pragma once
-
-#include "remote_engine.h"
-
-namespace remote_engine {
-class OpenAiEngine : public RemoteEngine {
- public:
-  void GetModels(
-      std::shared_ptr<Json::Value> json_body,
-      std::function<void(Json::Value&&, Json::Value&&)>&& callback) override;
-
-  Json::Value GetRemoteModels() override;
-};
-}  // namespace remote_engine
\ No newline at end of file
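Review note: both provider subclasses are deleted outright; none of their behavior survives in subclass form. Provider dispatch moves from the type system into data — `EngineService` (below) constructs one `RemoteEngine` per engine and passes the name in, and the class branches on that name where providers still differ:

    // Before this diff: one subclass per provider.
    //   engines_[name].engine = new remote_engine::OpenAiEngine();
    //   engines_[name].engine = new remote_engine::AnthropicEngine();
    // After: a single class, parameterized by the engine name.
    engines_[engine_name].engine = new remote_engine::RemoteEngine(engine_name);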
- json_resp["object"] = "list"; - json_resp["data"] = model_array; - - Json::Value status; - status["is_done"] = true; - status["has_error"] = false; - status["is_stream"] = false; - status["status_code"] = 200; - callback(std::move(status), std::move(json_resp)); - CTL_INF("Running models responded"); -} - -Json::Value OpenAiEngine::GetRemoteModels() { - auto response = MakeGetModelsRequest(); - if (response.error) { - Json::Value error; - error["error"] = response.error_message; - return error; - } - Json::Value response_json; - Json::Reader reader; - if (!reader.parse(response.body, response_json)) { - Json::Value error; - error["error"] = "Failed to parse response"; - return error; - } - return response_json; -} -} // namespace remote_engine \ No newline at end of file diff --git a/engine/extensions/remote-engine/openai_engine.h b/engine/extensions/remote-engine/openai_engine.h deleted file mode 100644 index 61dc68f0c..000000000 --- a/engine/extensions/remote-engine/openai_engine.h +++ /dev/null @@ -1,14 +0,0 @@ -#pragma once - -#include "remote_engine.h" - -namespace remote_engine { -class OpenAiEngine : public RemoteEngine { - public: - void GetModels( - std::shared_ptr json_body, - std::function&& callback) override; - - Json::Value GetRemoteModels() override; -}; -} // namespace remote_engine \ No newline at end of file diff --git a/engine/extensions/remote-engine/remote_engine.cc b/engine/extensions/remote-engine/remote_engine.cc index 04effb457..6361077dd 100644 --- a/engine/extensions/remote-engine/remote_engine.cc +++ b/engine/extensions/remote-engine/remote_engine.cc @@ -16,67 +16,14 @@ bool is_anthropic(const std::string& model) { return model.find("claude") != std::string::npos; } -struct AnthropicChunk { - std::string type; - std::string id; - int index; - std::string msg; - std::string model; - std::string stop_reason; - bool should_ignore = false; - - AnthropicChunk(const std::string& str) { - if (str.size() > 6) { - std::string s = str.substr(6); - try { - auto root = json_helper::ParseJsonString(s); - type = root["type"].asString(); - if (type == "message_start") { - id = root["message"]["id"].asString(); - model = root["message"]["model"].asString(); - } else if (type == "content_block_delta") { - index = root["index"].asInt(); - if (root["delta"]["type"].asString() == "text_delta") { - msg = root["delta"]["text"].asString(); - } - } else if (type == "message_delta") { - stop_reason = root["delta"]["stop_reason"].asString(); - } else { - // ignore other messages - should_ignore = true; - } - } catch (const std::exception& e) { - should_ignore = true; - CTL_WRN("JSON parse error: " << e.what()); - } - } else { - should_ignore = true; - } - } +bool is_openai(const std::string& model) { + return model.find("gpt") != std::string::npos; +} - std::string ToOpenAiFormatString() { - Json::Value root; - root["id"] = id; - root["object"] = "chat.completion.chunk"; - root["created"] = Json::Value(); - root["model"] = model; - root["system_fingerprint"] = "fp_e76890f0c3"; - Json::Value choices(Json::arrayValue); - Json::Value choice; - Json::Value content; - choice["index"] = 0; - content["content"] = msg; - if (type == "message_start") { - content["role"] = "assistant"; - content["refusal"] = Json::Value(); - } - choice["delta"] = content; - choice["finish_reason"] = stop_reason.empty() ? 
@@ -92,21 +39,13 @@ size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb,
   while ((pos = context->buffer.find('\n')) != std::string::npos) {
     std::string line = context->buffer.substr(0, pos);
     context->buffer = context->buffer.substr(pos + 1);
-    CTL_TRC(line);
 
     // Skip empty lines
     if (line.empty() || line == "\r" ||
         line.find("event:") != std::string::npos)
       continue;
 
-    // Remove "data: " prefix if present
-    // if (line.substr(0, 6) == "data: ")
-    // {
-    //     line = line.substr(6);
-    // }
-
-    // Skip [DONE] message
-    // std::cout << line << std::endl;
+    CTL_DBG(line);
     if (line == "data: [DONE]" ||
         line.find("message_stop") != std::string::npos) {
       Json::Value status;
@@ -120,17 +59,20 @@ size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb,
 
     // Parse the JSON
     Json::Value chunk_json;
-    if (is_anthropic(context->model)) {
-      AnthropicChunk ac(line);
-      if (ac.should_ignore)
+    if (!is_openai(context->model)) {
+      std::string s = line.substr(6);
+      try {
+        auto root = json_helper::ParseJsonString(s);
+        root["model"] = context->model;
+        root["id"] = context->id;
+        root["stream"] = true;
+        auto result = context->renderer.Render(context->stream_template, root);
+        CTL_DBG(result);
+        chunk_json["data"] = "data: " + result + "\n\n";
+      } catch (const std::exception& e) {
+        CTL_WRN("JSON parse error: " << e.what());
         continue;
-      ac.model = context->model;
-      if (ac.type == "message_start") {
-        context->id = ac.id;
-      } else {
-        ac.id = context->id;
       }
-      chunk_json["data"] = ac.ToOpenAiFormatString() + "\n\n";
     } else {
       chunk_json["data"] = line + "\n\n";
     }
@@ -178,10 +120,16 @@ CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest(
   headers = curl_slist_append(headers, "Cache-Control: no-cache");
   headers = curl_slist_append(headers, "Connection: keep-alive");
 
+  std::string stream_template = chat_res_template_;
+
   StreamContext context{
       std::make_shared<std::function<void(Json::Value&&, Json::Value&&)>>(
          callback),
-      "", "", config.model};
+      "",
+      "",
+      config.model,
+      renderer_,
+      stream_template};
 
   curl_easy_setopt(curl, CURLOPT_URL, full_url.c_str());
   curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
@@ -232,7 +180,8 @@ static size_t WriteCallback(char* ptr, size_t size, size_t nmemb,
   return size * nmemb;
 }
 
-RemoteEngine::RemoteEngine() {
+RemoteEngine::RemoteEngine(const std::string& engine_name)
+    : engine_name_(engine_name) {
   curl_global_init(CURL_GLOBAL_ALL);
 }
 
@@ -395,7 +344,33 @@ bool RemoteEngine::LoadModelConfig(const std::string& model,
 void RemoteEngine::GetModels(
     std::shared_ptr<Json::Value> json_body,
     std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
-  CTL_WRN("Not implemented yet!");
+  Json::Value json_resp;
+  Json::Value model_array(Json::arrayValue);
+  {
+    std::shared_lock l(models_mtx_);
+    for (const auto& [m, _] : models_) {
+      Json::Value val;
+      val["id"] = m;
+      val["engine"] = "openai";
+      val["start_time"] = "_";
+      val["model_size"] = "_";
+      val["vram"] = "_";
+      val["ram"] = "_";
+      val["object"] = "model";
+      model_array.append(val);
+    }
+  }
+
+  json_resp["object"] = "list";
+  json_resp["data"] = model_array;
+
+  Json::Value status;
+  status["is_done"] = true;
+  status["has_error"] = false;
+  status["is_stream"] = false;
+  status["status_code"] = 200;
+  callback(std::move(status), std::move(json_resp));
+  CTL_INF("Running models responded");
 }
 
 void RemoteEngine::LoadModel(
@@ -431,6 +406,21 @@ void RemoteEngine::LoadModel(
   }
   if (json_body->isMember("metadata")) {
     metadata_ = (*json_body)["metadata"];
+    if (!metadata_["TransformReq"].isNull() &&
+        !metadata_["TransformReq"]["chat_completions"].isNull() &&
+        !metadata_["TransformReq"]["chat_completions"]["template"].isNull()) {
+      chat_req_template_ =
+          metadata_["TransformReq"]["chat_completions"]["template"].asString();
+      CTL_INF(chat_req_template_);
+    }
+
+    if (!metadata_["TransformResp"].isNull() &&
+        !metadata_["TransformResp"]["chat_completions"].isNull() &&
+        !metadata_["TransformResp"]["chat_completions"]["template"].isNull()) {
+      chat_res_template_ =
+          metadata_["TransformResp"]["chat_completions"]["template"].asString();
+      CTL_INF(chat_res_template_);
+    }
   }
 
   Json::Value response;
@@ -535,23 +525,6 @@ void RemoteEngine::HandleChatCompletion(
         std::string(e.what()));
   }
 
-  // Parse system for anthropic
-  if (is_anthropic(model)) {
-    bool has_system = false;
-    Json::Value msgs(Json::arrayValue);
-    for (auto& kv : (*json_body)["messages"]) {
-      if (kv["role"].asString() == "system") {
-        (*json_body)["system"] = kv["content"].asString();
-        has_system = true;
-      } else {
-        msgs.append(kv);
-      }
-    }
-    if (has_system) {
-      (*json_body)["messages"] = msgs;
-    }
-  }
-
   // Render with error handling
   try {
     result = renderer_.Render(template_str, *json_body);
@@ -601,33 +574,42 @@ void RemoteEngine::HandleChatCompletion(
   // Transform Response
   std::string response_str;
   try {
-    // Check if required YAML nodes exist
-    if (!model_config->transform_resp["chat_completions"]) {
-      throw std::runtime_error(
-          "Missing 'chat_completions' node in transform_resp");
-    }
-    if (!model_config->transform_resp["chat_completions"]["template"]) {
-      throw std::runtime_error("Missing 'template' node in chat_completions");
-    }
+    std::string template_str;
+    if (!chat_res_template_.empty()) {
+      CTL_DBG(
+          "Use engine transform response template: " << chat_res_template_);
+      template_str = chat_res_template_;
+    } else {
+      // Check if required YAML nodes exist
+      if (!model_config->transform_resp["chat_completions"]) {
+        throw std::runtime_error(
+            "Missing 'chat_completions' node in transform_resp");
+      }
+      if (!model_config->transform_resp["chat_completions"]["template"]) {
+        throw std::runtime_error(
+            "Missing 'template' node in chat_completions");
+      }
 
-    // Validate JSON body
-    if (!response_json || response_json.isNull()) {
-      throw std::runtime_error("Invalid or null JSON body");
-    }
+      // Validate JSON body
+      if (!response_json || response_json.isNull()) {
+        throw std::runtime_error("Invalid or null JSON body");
+      }
 
-    // Get template string with error check
-    std::string template_str;
-    try {
-      template_str =
-          model_config->transform_resp["chat_completions"]["template"]
-              .as<std::string>();
-    } catch (const YAML::BadConversion& e) {
-      throw std::runtime_error("Failed to convert template node to string: " +
-                               std::string(e.what()));
+      // Get template string with error check
+
+      try {
+        template_str =
+            model_config->transform_resp["chat_completions"]["template"]
+                .as<std::string>();
+      } catch (const YAML::BadConversion& e) {
+        throw std::runtime_error(
+            "Failed to convert template node to string: " +
+            std::string(e.what()));
+      }
     }
 
-    // Render with error handling
     try {
+      response_json["stream"] = false;
       response_str = renderer_.Render(template_str, response_json);
     } catch (const std::exception& e) {
       throw std::runtime_error("Template rendering error: " +
@@ -705,8 +687,43 @@ void RemoteEngine::HandleEmbedding(
 }
 
 Json::Value RemoteEngine::GetRemoteModels() {
-  CTL_WRN("Not implemented yet!");
-  return {};
+  if (metadata_["get_models_url"].isNull() ||
+      metadata_["get_models_url"].asString().empty()) {
+    if (engine_name_ == kAnthropicEngine) {
+      Json::Value json_resp;
+      Json::Value model_array(Json::arrayValue);
+      for (const auto& m : kAnthropicModels) {
+        Json::Value val;
+        val["id"] = std::string(m);
+        val["engine"] = "anthropic";
+        val["created"] = "_";
+        val["object"] = "model";
+        model_array.append(val);
+      }
+
+      json_resp["object"] = "list";
+      json_resp["data"] = model_array;
+      CTL_INF("Remote models responded");
+      return json_resp;
+    } else {
+      return Json::Value();
+    }
+  } else {
+    auto response = MakeGetModelsRequest();
+    if (response.error) {
+      Json::Value error;
+      error["error"] = response.error_message;
+      return error;
+    }
+    Json::Value response_json;
+    Json::Reader reader;
+    if (!reader.parse(response.body, response_json)) {
+      Json::Value error;
+      error["error"] = "Failed to parse response";
+      return error;
+    }
+    return response_json;
+  }
 }
 }  // namespace remote_engine
\ No newline at end of file
diff --git a/engine/extensions/remote-engine/remote_engine.h b/engine/extensions/remote-engine/remote_engine.h
index 8ce6fa652..d8dfbad61 100644
--- a/engine/extensions/remote-engine/remote_engine.h
+++ b/engine/extensions/remote-engine/remote_engine.h
@@ -14,9 +14,6 @@
 // Helper for CURL response
 namespace remote_engine {
-inline bool IsRemoteEngine(std::string_view e) {
-  return e == kAnthropicEngine || e == kOpenAiEngine;
-}
 
 struct StreamContext {
   std::shared_ptr<std::function<void(Json::Value&&, Json::Value&&)>> callback;
@@ -24,6 +21,8 @@ struct StreamContext {
   // Cache value for Anthropic
   std::string id;
   std::string model;
+  TemplateRenderer& renderer;
+  std::string stream_template;
 };
 struct CurlResponse {
   std::string body;
@@ -49,8 +48,10 @@ class RemoteEngine : public RemoteEngineI {
   std::unordered_map<std::string, ModelConfig> models_;
   TemplateRenderer renderer_;
   Json::Value metadata_;
+  std::string chat_req_template_;
+  std::string chat_res_template_;
   std::string api_key_template_;
-  std::unique_ptr<trantor::FileLogger> async_file_logger_;
+  std::string engine_name_;
 
   // Helper functions
   CurlResponse MakeChatCompletionRequest(const ModelConfig& config,
@@ -67,7 +68,7 @@ class RemoteEngine : public RemoteEngineI {
   ModelConfig* GetModelConfig(const std::string& model);
 
  public:
-  RemoteEngine();
+  explicit RemoteEngine(const std::string& engine_name);
   virtual ~RemoteEngine();
 
   // Main interface implementations
@@ -95,7 +96,7 @@ class RemoteEngine : public RemoteEngineI {
   void HandleEmbedding(
       std::shared_ptr<Json::Value> json_body,
       std::function<void(Json::Value&&, Json::Value&&)>&& callback) override;
-
+
   Json::Value GetRemoteModels() override;
 };
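Review note: streaming now goes through the same template machinery as non-streaming responses — `StreamContext` carries a reference to the engine's `TemplateRenderer` plus the response template, and `StreamWriteCallback` renders every non-OpenAI SSE line instead of hand-building chunks in the removed `AnthropicChunk`. The per-line core, condensed from the callback above (error handling elided):

    // One provider-native SSE line in, one OpenAI-compatible SSE line out.
    std::string s = line.substr(6);               // strip the "data: " prefix
    auto root = json_helper::ParseJsonString(s);  // provider-native chunk
    root["model"] = context->model;               // enrich before rendering
    root["id"] = context->id;
    root["stream"] = true;                        // template branches on this
    auto result = context->renderer.Render(context->stream_template, root);
    chunk_json["data"] = "data: " + result + "\n\n";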
diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc
index 035ef4a4e..1f3e4d81c 100644
--- a/engine/services/engine_service.cc
+++ b/engine/services/engine_service.cc
@@ -6,8 +6,7 @@
 #include
 #include "algorithm"
 #include "database/engines.h"
-#include "extensions/remote-engine/anthropic_engine.h"
-#include "extensions/remote-engine/openai_engine.h"
+#include "extensions/remote-engine/remote_engine.h"
 #include "utils/archive_utils.h"
 #include "utils/engine_constants.h"
 #include "utils/engine_matcher_utils.h"
@@ -187,7 +186,7 @@ cpp::result<void, std::string> EngineService::UninstallEngineVariant(
   // TODO: handle uninstall remote engine
   // only delete a remote engine if no model are using it
   auto exist_engine = GetEngineByNameAndVariant(engine);
-  if (exist_engine.has_value() && exist_engine.value().type == "remote") {
+  if (exist_engine.has_value() && exist_engine.value().type == kRemote) {
     auto result = DeleteEngine(exist_engine.value().id);
     if (!result.empty()) {  // This mean no error when delete model
       CTL_ERR("Failed to delete engine: " << result);
@@ -333,15 +332,9 @@ cpp::result<void, std::string> EngineService::DownloadEngine(
   } else {
     CTL_INF("Set default engine variant: " << res.value().variant);
   }
-  auto create_res =
-      EngineService::UpsertEngine(engine,  // engine_name
-                                  "local", // todo - luke
-                                  "",      // todo - luke
-                                  "",      // todo - luke
-                                  normalize_version, variant.value(),
-                                  "Default",  // todo - luke
-                                  ""          // todo - luke
-      );
+  auto create_res = EngineService::UpsertEngine(
+      engine,  // engine_name
+      kLocal, "", "", normalize_version, variant.value(), "Default", "");
 
   if (create_res.has_value()) {
     CTL_ERR("Failed to create engine entry: " << create_res->engine_name);
@@ -683,17 +676,13 @@ cpp::result<void, std::string> EngineService::LoadEngine(
   }
 
   // Check for remote engine
-  if (remote_engine::IsRemoteEngine(engine_name)) {
+  if (IsRemoteEngine(engine_name)) {
     auto exist_engine = GetEngineByNameAndVariant(engine_name);
     if (exist_engine.has_error()) {
       return cpp::fail("Remote engine '" + engine_name + "' is not installed");
     }
 
-    if (engine_name == kOpenAiEngine) {
-      engines_[engine_name].engine = new remote_engine::OpenAiEngine();
-    } else {
-      engines_[engine_name].engine = new remote_engine::AnthropicEngine();
-    }
+    engines_[engine_name].engine = new remote_engine::RemoteEngine(engine_name);
 
     CTL_INF("Loaded engine: " << engine_name);
     return {};
@@ -899,7 +888,7 @@ cpp::result<bool, std::string> EngineService::IsEngineReady(
   auto ne = NormalizeEngine(engine);
 
   // Check for remote engine
-  if (remote_engine::IsRemoteEngine(engine)) {
+  if (IsRemoteEngine(engine)) {
     auto exist_engine = GetEngineByNameAndVariant(engine);
     if (exist_engine.has_error()) {
       return cpp::fail("Remote engine '" + engine + "' is not installed");
@@ -1075,11 +1064,7 @@ cpp::result<Json::Value, std::string> EngineService::GetRemoteModels(
     if (exist_engine.has_error()) {
       return cpp::fail("Remote engine '" + engine_name + "' is not installed");
     }
-    if (engine_name == kOpenAiEngine) {
-      engines_[engine_name].engine = new remote_engine::OpenAiEngine();
-    } else {
-      engines_[engine_name].engine = new remote_engine::AnthropicEngine();
-    }
+    engines_[engine_name].engine = new remote_engine::RemoteEngine(engine_name);
 
     CTL_INF("Loaded engine: " << engine_name);
   }
@@ -1092,6 +1077,16 @@ cpp::result<Json::Value, std::string> EngineService::GetRemoteModels(
   }
 }
 
+bool EngineService::IsRemoteEngine(const std::string& engine_name) {
+  auto ne = Repo2Engine(engine_name);
+  auto local_engines = file_manager_utils::GetCortexConfig().supportedEngines;
+  for (auto const& le : local_engines) {
+    if (le == ne)
+      return false;
+  }
+  return true;
+}
+
 cpp::result<std::vector<std::string>, std::string>
 EngineService::GetSupportedEngineNames() {
   return file_manager_utils::GetCortexConfig().supportedEngines;
diff --git a/engine/services/engine_service.h b/engine/services/engine_service.h
index 9253eccf1..527123cb5 100644
--- a/engine/services/engine_service.h
+++ b/engine/services/engine_service.h
@@ -153,6 +153,8 @@ class EngineService : public EngineServiceI {
 
   void RegisterEngineLibPath();
 
+  bool IsRemoteEngine(const std::string& engine_name) override;
+
  private:
   bool IsEngineLoaded(const std::string& engine);
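Review note: `IsRemoteEngine` is deliberately the complement of the local-engine list rather than an allowlist of known remote providers — any name absent from `supportedEngines` counts as remote. Assuming `supportedEngines` holds only the built-in local engines (`llama-cpp`, `onnxruntime`, `tensorrt-llm`), behavior looks like this:

    // svc is an EngineService wired with the usual config.
    svc.IsRemoteEngine("llama-cpp");        // false: listed in supportedEngines
    svc.IsRemoteEngine("openai");           // true: not a local engine
    svc.IsRemoteEngine("anthropic");        // true
    svc.IsRemoteEngine("my-custom-proxy");  // true: unknown names count as remote

One consequence worth flagging: a misspelled local engine name is now silently classified as remote rather than rejected.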
diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc
index 6a45733d3..ce83152c4 100644
--- a/engine/services/model_service.cc
+++ b/engine/services/model_service.cc
@@ -773,7 +773,7 @@ cpp::result<StartModelResult, std::string> ModelService::StartModel(
     auto mc = yaml_handler.GetModelConfig();
 
     // Running remote model
-    if (remote_engine::IsRemoteEngine(mc.engine)) {
+    if (engine_svc_->IsRemoteEngine(mc.engine)) {
 
       config::RemoteModelConfig remote_mc;
       remote_mc.LoadFromYamlFile(
diff --git a/engine/test/components/CMakeLists.txt b/engine/test/components/CMakeLists.txt
index 58c5d83d6..0df46cfc2 100644
--- a/engine/test/components/CMakeLists.txt
+++ b/engine/test/components/CMakeLists.txt
@@ -16,6 +16,7 @@ add_executable(${PROJECT_NAME}
     ${CMAKE_CURRENT_SOURCE_DIR}/../../utils/file_manager_utils.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/../../utils/curl_utils.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/../../utils/system_info_utils.cc
+    ${CMAKE_CURRENT_SOURCE_DIR}/../../extensions/remote-engine/template_renderer.cc
 )
 
 find_package(Drogon CONFIG REQUIRED)
diff --git a/engine/test/components/main.cc b/engine/test/components/main.cc
index 08080680e..ba24a3e01 100644
--- a/engine/test/components/main.cc
+++ b/engine/test/components/main.cc
@@ -4,11 +4,15 @@
 
 int main(int argc, char** argv) {
   ::testing::InitGoogleTest(&argc, argv);
+#if defined(NDEBUG)
   ::testing::GTEST_FLAG(filter) = "-FileManagerConfigTest.*";
   int ret = RUN_ALL_TESTS();
   if (ret != 0)
     return ret;
   ::testing::GTEST_FLAG(filter) = "FileManagerConfigTest.*";
   ret = RUN_ALL_TESTS();
+#else
+  int ret = RUN_ALL_TESTS();
+#endif
   return ret;
 }
diff --git a/engine/test/components/test_remote_engine.cc b/engine/test/components/test_remote_engine.cc
new file mode 100644
index 000000000..bfac76f49
--- /dev/null
+++ b/engine/test/components/test_remote_engine.cc
@@ -0,0 +1,81 @@
+#include "extensions/remote-engine/template_renderer.h"
+#include "gtest/gtest.h"
+#include "utils/json_helper.h"
+
+class RemoteEngineTest : public ::testing::Test {};
+
+TEST_F(RemoteEngineTest, OpenAiToAnthropicRequest) {
+  std::string tpl =
+      R"({
+  {% for key, value in input_request %}
+  {% if key == "messages" %}
+    {% if input_request.messages.0.role == "system" %}
+      "system": "{{ input_request.messages.0.content }}",
+      "messages": [
+        {% for message in input_request.messages %}
+          {% if not loop.is_first %}
+            {"role": "{{ message.role }}", "content": "{{ message.content }}" } {% if not loop.is_last %},{% endif %}
+          {% endif %}
+        {% endfor %}
+      ]
+    {% else %}
+      "messages": [
+        {% for message in input_request.messages %}
+          {"role": "{{ message.role }}", "content": "{{ message.content }}" } {% if not loop.is_last %},{% endif %}
+        {% endfor %}
+      ]
+    {% endif %}
+  {% else if key == "system" or key == "model" or key == "temperature" or key == "store" or key == "max_tokens" or key == "stream" or key == "presence_penalty" or key == "metadata" or key == "frequency_penalty" or key == "tools" or key == "tool_choice" or key == "logprobs" or key == "top_logprobs" or key == "logit_bias" or key == "n" or key == "modalities" or key == "prediction" or key == "response_format" or key == "service_tier" or key == "seed" or key == "stop" or key == "stream_options" or key == "top_p" or key == "parallel_tool_calls" or key == "user" %}
+    "{{ key }}": {{ tojson(value) }}
+  {% endif %}
+  {% if not loop.is_last %},{% endif %}
+  {% endfor %} })";
+  {
+    std::string message_with_system = R"({
+  "messages": [
+    {"role": "system", "content": "You are a seasoned data scientist at a Fortune 500 company."},
+    {"role": "user", "content": "Hello, world"}
+  ],
+  "model": "claude-3-5-sonnet-20241022",
+  "max_tokens": 1024
+})";
+
+    auto data = json_helper::ParseJsonString(message_with_system);
+
+    remote_engine::TemplateRenderer rdr;
+    auto res = rdr.Render(tpl, data);
+
+    auto res_json = json_helper::ParseJsonString(res);
+    EXPECT_EQ(data["model"].asString(), res_json["model"].asString());
+    EXPECT_EQ(data["max_tokens"].asInt(), res_json["max_tokens"].asInt());
+    for (auto const& msg : data["messages"]) {
+      if (msg["role"].asString() == "system") {
+        EXPECT_EQ(msg["content"].asString(), res_json["system"].asString());
+      } else if (msg["role"].asString() == "user") {
+        EXPECT_EQ(msg["content"].asString(),
+                  res_json["messages"][0]["content"].asString());
+      }
+    }
+  }
+
+  {
+    std::string message_without_system = R"({
+  "messages": [
+    {"role": "user", "content": "Hello, world"}
+  ],
+  "model": "claude-3-5-sonnet-20241022",
+  "max_tokens": 1024
+})";
+
+    auto data = json_helper::ParseJsonString(message_without_system);
+
+    remote_engine::TemplateRenderer rdr;
+    auto res = rdr.Render(tpl, data);
+
+    auto res_json = json_helper::ParseJsonString(res);
+    EXPECT_EQ(data["model"].asString(), res_json["model"].asString());
+    EXPECT_EQ(data["max_tokens"].asInt(), res_json["max_tokens"].asInt());
+    EXPECT_EQ(data["messages"][0]["content"].asString(),
+              res_json["messages"][0]["content"].asString());
+  }
+}
\ No newline at end of file
diff --git a/engine/utils/engine_constants.h b/engine/utils/engine_constants.h
index 020109fd8..dcdf6a443 100644
--- a/engine/utils/engine_constants.h
+++ b/engine/utils/engine_constants.h
@@ -6,6 +6,9 @@ constexpr const auto kTrtLlmEngine = "tensorrt-llm";
 constexpr const auto kOpenAiEngine = "openai";
 constexpr const auto kAnthropicEngine = "anthropic";
 
+constexpr const auto kRemote = "remote";
+constexpr const auto kLocal = "local";
+
 constexpr const auto kOnnxRepo = "cortex.onnx";
 constexpr const auto kLlamaRepo = "cortex.llamacpp";
 constexpr const auto kTrtLlmRepo = "cortex.tensorrt-llm";