diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index 29ec8c0ed..eb4a325c9 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -402,8 +402,11 @@ void Models::StartModel( return; auto config = file_manager_utils::GetCortexConfig(); auto model_handle = (*(req->getJsonObject())).get("model", "").asString(); + auto custom_prompt_template = + (*(req->getJsonObject())).get("prompt_template", "").asString(); auto result = model_service_->StartModel( - config.apiServerHost, std::stoi(config.apiServerPort), model_handle); + config.apiServerHost, std::stoi(config.apiServerPort), model_handle, + custom_prompt_template); if (result.has_error()) { Json::Value ret; ret["message"] = result.error(); diff --git a/engine/controllers/swagger.cc b/engine/controllers/swagger.cc index 27d9b7b15..527ae60b8 100644 --- a/engine/controllers/swagger.cc +++ b/engine/controllers/swagger.cc @@ -202,7 +202,8 @@ Json::Value SwaggerController::generateOpenAPISpec() { responses["200"]["description"] = "Model details retrieved successfully"; Json::Value& schema = responses["200"]["content"]["application/json"]["schema"]; - responses["responses"]["400"]["description"] = "Failed to get model information"; + responses["responses"]["400"]["description"] = + "Failed to get model information"; responses["400"]["description"] = "Failed to get model information"; responses["400"]["content"]["application/json"]["schema"]["type"] = @@ -450,6 +451,8 @@ Json::Value SwaggerController::generateOpenAPISpec() { "object"; start["requestBody"]["content"]["application/json"]["schema"]["properties"] ["model"]["type"] = "string"; + start["requestBody"]["content"]["application/json"]["schema"]["properties"] + ["prompt_template"]["type"] = "string"; start["requestBody"]["content"]["application/json"]["schema"]["required"] = Json::Value(Json::arrayValue); start["requestBody"]["content"]["application/json"]["schema"]["required"] @@ -458,12 +461,12 @@ Json::Value 
SwaggerController::generateOpenAPISpec() { start["responses"]["400"]["description"] = "Failed to start model"; // Stop Model - Json::Value& stop = spec["paths"]["/v1/models/stop"]["post"]; + Json::Value& stop = spec["paths"]["/v1/models/stop"]["post"]; stop["summary"] = "Stop model"; stop["requestBody"]["content"]["application/json"]["schema"]["type"] = "object"; stop["requestBody"]["content"]["application/json"]["schema"]["properties"] - ["model"]["type"] = "string"; + ["model"]["type"] = "string"; stop["requestBody"]["content"]["application/json"]["schema"]["required"] = Json::Value(Json::arrayValue); stop["requestBody"]["content"]["application/json"]["schema"]["required"] diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index df159c84f..bf6f5e9af 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -550,7 +550,8 @@ cpp::result ModelService::DeleteModel( } cpp::result ModelService::StartModel( - const std::string& host, int port, const std::string& model_handle) { + const std::string& host, int port, const std::string& model_handle, + std::optional custom_prompt_template) { namespace fs = std::filesystem; namespace fmu = file_manager_utils; cortex::db::Models modellist_handler; @@ -580,9 +581,17 @@ cpp::result ModelService::StartModel( return false; } json_data["model"] = model_handle; - json_data["system_prompt"] = mc.system_template; - json_data["user_prompt"] = mc.user_template; - json_data["ai_prompt"] = mc.ai_template; + if (!custom_prompt_template.value_or("").empty()) { + auto parse_prompt_result = + string_utils::ParsePrompt(custom_prompt_template.value()); + json_data["system_prompt"] = parse_prompt_result.system_prompt; + json_data["user_prompt"] = parse_prompt_result.user_prompt; + json_data["ai_prompt"] = parse_prompt_result.ai_prompt; + } else { + json_data["system_prompt"] = mc.system_template; + json_data["user_prompt"] = mc.user_template; + json_data["ai_prompt"] = mc.ai_template; 
+ } auto data_str = json_data.toStyledString(); CTL_INF(data_str); diff --git a/engine/services/model_service.h b/engine/services/model_service.h index ca33a7796..c791c7fd7 100644 --- a/engine/services/model_service.h +++ b/engine/services/model_service.h @@ -1,10 +1,10 @@ #pragma once #include +#include #include #include "config/model_config.h" #include "services/download_service.h" - class ModelService { public: constexpr auto static kHuggingFaceHost = "huggingface.co"; @@ -34,8 +34,9 @@ class ModelService { */ cpp::result DeleteModel(const std::string& model_handle); - cpp::result StartModel(const std::string& host, int port, - const std::string& model_handle); + cpp::result StartModel( + const std::string& host, int port, const std::string& model_handle, + std::optional custom_prompt_template = std::nullopt); cpp::result StopModel(const std::string& host, int port, const std::string& model_handle); diff --git a/engine/test/components/test_string_utils.cc b/engine/test/components/test_string_utils.cc index a4f484e36..3d6abeddf 100644 --- a/engine/test/components/test_string_utils.cc +++ b/engine/test/components/test_string_utils.cc @@ -3,6 +3,22 @@ #include "gtest/gtest.h" #include "utils/string_utils.h" class StringUtilsTestSuite : public ::testing::Test {}; +TEST_F(StringUtilsTestSuite, ParsePrompt) { + { + std::string prompt = + "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{" + "system_message}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{" + "prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"; + auto result = string_utils::ParsePrompt(prompt); + EXPECT_EQ(result.user_prompt, + "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"); + EXPECT_EQ(result.ai_prompt, + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"); + EXPECT_EQ( + result.system_prompt, + "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"); + } +} TEST_F(StringUtilsTestSuite, TestSplitBy) { auto input = "this is a test"; 
/// Literal text segments that ParsePrompt() extracts from a prompt template.
struct ParsePromptResult {
  /// Text between the first '}' and the last '{' of the template; empty when
  /// the template does not contain two separate placeholders.
  std::string user_prompt;
  /// Text preceding the first '{' of the template (the whole template when it
  /// contains no '{').
  std::string system_prompt;
  /// Text following the last '}' of the template; empty when the template
  /// contains no placeholder.
  std::string ai_prompt;
};

/// Splits a prompt template around its brace-delimited placeholders.
///
/// For a template shaped like `A{first}B{last}C` the result is
/// `system_prompt == "A"`, `user_prompt == "B"` and `ai_prompt == "C"`.
/// Placeholder names are not inspected; only the brace positions matter.
///
/// Degenerate templates are handled explicitly instead of relying on
/// `std::string::npos` arithmetic:
///  - no `{` or no `}`: only `system_prompt` is filled (text before the first
///    `{`, or the entire template if there is none);
///  - a single placeholder (`A{x}C`): `user_prompt` stays empty rather than
///    duplicating the trailing text.
///
/// @param prompt Prompt template to split.
/// @return The three literal segments described above.
inline ParsePromptResult ParsePrompt(const std::string& prompt) {
  constexpr auto kNotFound = std::string::npos;
  ParsePromptResult result;

  const auto first_open = prompt.find('{');
  const auto first_close = prompt.find('}');

  // substr(0, npos) yields the whole string, so a template without '{' ends
  // up entirely in system_prompt.
  result.system_prompt = prompt.substr(0, first_open);

  // Without both an opening and a closing brace there is no placeholder to
  // split around; leave user_prompt / ai_prompt empty.
  if (first_open == kNotFound || first_close == kNotFound) {
    return result;
  }

  // Both lookups succeed here because at least one '{' and one '}' exist.
  const auto last_open = prompt.rfind('{');
  const auto last_close = prompt.rfind('}');

  // Only take the middle segment when the last '{' really comes after the
  // first '}'; otherwise the unsigned length below would wrap around.
  if (last_open > first_close) {
    result.user_prompt =
        prompt.substr(first_close + 1, last_open - first_close - 1);
  }

  result.ai_prompt = prompt.substr(last_close + 1);
  return result;
}

/// Returns true when `str` begins with `prefix` (an empty prefix always
/// matches).
inline bool StartsWith(const std::string& str, const std::string& prefix) {
  // rfind anchored at position 0 can only match at the very start of `str`.
  return str.rfind(prefix, 0) == 0;
}