From cb8f25f8fa2ad7cd1c69f89c17a9151aae985187 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Mon, 7 Oct 2024 15:25:57 +0700 Subject: [PATCH 1/7] chore: change update to patch --- engine/controllers/models.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/engine/controllers/models.h b/engine/controllers/models.h index c2c804170..167d4bb36 100644 --- a/engine/controllers/models.h +++ b/engine/controllers/models.h @@ -14,7 +14,7 @@ class Models : public drogon::HttpController { METHOD_ADD(Models::PullModel, "/pull", Post); METHOD_ADD(Models::ListModel, "", Get); METHOD_ADD(Models::GetModel, "/{1}", Get); - METHOD_ADD(Models::UpdateModel, "/{1}", Post); + METHOD_ADD(Models::UpdateModel, "/{1}", Patch); METHOD_ADD(Models::ImportModel, "/import", Post); METHOD_ADD(Models::DeleteModel, "/{1}", Delete); METHOD_ADD(Models::SetModelAlias, "/alias", Post); @@ -24,7 +24,7 @@ class Models : public drogon::HttpController { ADD_METHOD_TO(Models::PullModel, "/v1/models/pull", Post); ADD_METHOD_TO(Models::ListModel, "/v1/models", Get); ADD_METHOD_TO(Models::GetModel, "/v1/models/{1}", Get); - ADD_METHOD_TO(Models::UpdateModel, "/v1/models/{1}", Post); + ADD_METHOD_TO(Models::UpdateModel, "/v1/models/{1}", Patch); ADD_METHOD_TO(Models::ImportModel, "/v1/models/import", Post); ADD_METHOD_TO(Models::DeleteModel, "/v1/models/{1}", Delete); ADD_METHOD_TO(Models::SetModelAlias, "/v1/models/alias", Post); From 7f0349ce36e2e42fb51c9a9631185a9ba3ee77f1 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Mon, 7 Oct 2024 15:28:25 +0700 Subject: [PATCH 2/7] fix: swagger --- engine/controllers/swagger.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engine/controllers/swagger.cc b/engine/controllers/swagger.cc index 89733b65c..27d9b7b15 100644 --- a/engine/controllers/swagger.cc +++ b/engine/controllers/swagger.cc @@ -211,7 +211,7 @@ Json::Value SwaggerController::generateOpenAPISpec() { ["message"]["type"] = "string"; // UpdateModel Endpoint - Json::Value& update = spec["paths"]["/v1/models/{model}"]["post"]; + Json::Value& update = spec["paths"]["/v1/models/{model}"]["patch"]; update["summary"] = "Update model details"; update["description"] = "Update various attributes of a model based on the ModelConfig " From 8d8846d1ed41600f5e20cfb990e9c6daa5d9a75f Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Mon, 7 Oct 2024 16:57:38 +0700 Subject: [PATCH 3/7] fix: pull api --- engine/services/model_service.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index d9c2aa48f..1b0946133 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -310,7 +310,7 @@ cpp::result ModelService::DownloadModelFromCortexso( } std::string model_id{name + ":" + branch}; - auto on_finished = [&](const DownloadTask& finishedTask) { + auto on_finished = [&, model_id](const DownloadTask& finishedTask) { const DownloadItem* model_yml_item = nullptr; auto need_parse_gguf = true; From 054eef631979f21563b73c3ec9dc268719b18f93 Mon Sep 17 00:00:00 2001 From: nguyenhoangthuan99 Date: Wed, 16 Oct 2024 10:54:13 +0700 Subject: [PATCH 4/7] feat: support custom prompt template --- engine/controllers/models.cc | 5 ++++- engine/controllers/swagger.cc | 9 ++++++--- engine/services/model_service.cc | 19 +++++++++++++++---- engine/services/model_service.h | 7 ++++--- 4 files changed, 29 insertions(+), 11 deletions(-) diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index eefc0a941..ef219fbdc 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -369,8 +369,11 @@ void Models::StartModel( return; auto config = file_manager_utils::GetCortexConfig(); auto model_handle = (*(req->getJsonObject())).get("model", "").asString(); + auto custom_prompt_template = + (*(req->getJsonObject())).get("prompt_template", "").asString(); auto result = model_service_->StartModel( - config.apiServerHost, std::stoi(config.apiServerPort), model_handle); + config.apiServerHost, std::stoi(config.apiServerPort), model_handle, + custom_prompt_template); if (result.has_error()) { Json::Value ret; ret["message"] = result.error(); diff --git a/engine/controllers/swagger.cc b/engine/controllers/swagger.cc index 27d9b7b15..527ae60b8 100644 --- a/engine/controllers/swagger.cc +++ b/engine/controllers/swagger.cc @@ -202,7 +202,8 @@ Json::Value SwaggerController::generateOpenAPISpec() { responses["200"]["description"] = "Model details retrieved successfully"; Json::Value& schema = responses["200"]["content"]["application/json"]["schema"]; - responses["responses"]["400"]["description"] = "Failed to get model information"; + responses["responses"]["400"]["description"] = + "Failed to get model information"; responses["400"]["description"] = "Failed to get model information"; responses["400"]["content"]["application/json"]["schema"]["type"] = @@ -450,6 +451,8 @@ Json::Value SwaggerController::generateOpenAPISpec() { "object"; start["requestBody"]["content"]["application/json"]["schema"]["properties"] ["model"]["type"] = "string"; + start["requestBody"]["content"]["application/json"]["schema"]["properties"] + ["prompt_template"]["type"] = "string"; start["requestBody"]["content"]["application/json"]["schema"]["required"] = Json::Value(Json::arrayValue); start["requestBody"]["content"]["application/json"]["schema"]["required"] @@ -458,12 +461,12 @@ Json::Value SwaggerController::generateOpenAPISpec() { start["responses"]["400"]["description"] = "Failed to start model"; // Stop Model - Json::Value& stop = spec["paths"]["/v1/models/stop"]["post"]; + Json::Value& stop = spec["paths"]["/v1/models/stop"]["post"]; stop["summary"] = "Stop model"; stop["requestBody"]["content"]["application/json"]["schema"]["type"] = "object"; stop["requestBody"]["content"]["application/json"]["schema"]["properties"] - ["model"]["type"] = "string"; + ["model"]["type"] = "string"; stop["requestBody"]["content"]["application/json"]["schema"]["required"] = Json::Value(Json::arrayValue); stop["requestBody"]["content"]["application/json"]["schema"]["required"] diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index 1d970ae89..374f3242c 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -446,7 +446,8 @@ cpp::result ModelService::DeleteModel( } cpp::result ModelService::StartModel( - const std::string& host, int port, const std::string& model_handle) { + const std::string& host, int port, const std::string& model_handle, + std::optional custom_prompt_template) { namespace fs = std::filesystem; namespace fmu = file_manager_utils; cortex::db::Models modellist_handler; @@ -476,9 +477,19 @@ cpp::result ModelService::StartModel( return false; } json_data["model"] = model_handle; - json_data["system_prompt"] = mc.system_template; - json_data["user_prompt"] = mc.user_template; - json_data["ai_prompt"] = mc.ai_template; + if (custom_prompt_template.has_value() && + !custom_prompt_template.value_or("").empty()) { + auto& pt = custom_prompt_template.value(); + json_data["system_prompt"] = pt.substr(0, pt.find_first_of('{')); + json_data["user_prompt"] = + pt.substr(pt.find_first_of('}') + 1, + pt.find_last_of('{') - pt.find_first_of('}') - 1); + json_data["ai_prompt"] = pt.substr(pt.find_last_of('}') + 1); + } else { + json_data["system_prompt"] = mc.system_template; + json_data["user_prompt"] = mc.user_template; + json_data["ai_prompt"] = mc.ai_template; + } auto data_str = json_data.toStyledString(); CTL_INF(data_str); diff --git a/engine/services/model_service.h b/engine/services/model_service.h index 319260584..cebc494cd 100644 --- a/engine/services/model_service.h +++ b/engine/services/model_service.h @@ -1,10 +1,10 @@ #pragma once #include +#include #include #include "config/model_config.h" #include "services/download_service.h" - class ModelService { public: constexpr auto static kHuggingFaceHost = "huggingface.co"; @@ -31,8 +31,9 @@ class ModelService { */ cpp::result DeleteModel(const std::string& model_handle); - cpp::result StartModel(const std::string& host, int port, - const std::string& model_handle); + cpp::result StartModel( + const std::string& host, int port, const std::string& model_handle, + std::optional custom_prompt_template = std::nullopt); cpp::result StopModel(const std::string& host, int port, const std::string& model_handle); From c18b3b1c33ae26e246b6f921d5143f7c90f059ba Mon Sep 17 00:00:00 2001 From: nguyenhoangthuan99 Date: Wed, 16 Oct 2024 11:50:05 +0700 Subject: [PATCH 5/7] fix comment and add unitests --- engine/services/model_service.cc | 14 +++++------ engine/test/components/test_string_utils.cc | 28 ++++++++++++++++++++- engine/utils/string_utils.h | 13 ++++++++++ 3 files changed, 47 insertions(+), 8 deletions(-) diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index 374f3242c..58b8852b3 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -477,14 +477,14 @@ cpp::result ModelService::StartModel( return false; } json_data["model"] = model_handle; - if (custom_prompt_template.has_value() && - !custom_prompt_template.value_or("").empty()) { - auto& pt = custom_prompt_template.value(); - json_data["system_prompt"] = pt.substr(0, pt.find_first_of('{')); + if (!custom_prompt_template.value_or("").empty()) { + + json_data["system_prompt"] = + string_utils::ParseSystemPrompt(custom_prompt_template.value()); json_data["user_prompt"] = - pt.substr(pt.find_first_of('}') + 1, - pt.find_last_of('{') - pt.find_first_of('}') - 1); - json_data["ai_prompt"] = pt.substr(pt.find_last_of('}') + 1); + string_utils::ParseUserPrompt(custom_prompt_template.value()); + json_data["ai_prompt"] = + string_utils::ParseAIPrompt(custom_prompt_template.value()); } else { json_data["system_prompt"] = mc.system_template; json_data["user_prompt"] = mc.user_template; diff --git a/engine/test/components/test_string_utils.cc b/engine/test/components/test_string_utils.cc index a4f484e36..708d408ec 100644 --- a/engine/test/components/test_string_utils.cc +++ b/engine/test/components/test_string_utils.cc @@ -2,8 +2,34 @@ #include #include "gtest/gtest.h" #include "utils/string_utils.h" -class StringUtilsTestSuite : public ::testing::Test {}; +class StringUtilsTestSuite : public ::testing::Test { + protected: + std::string prompt = + "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_" + "message}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|" + "eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"; +}; +TEST_F(StringUtilsTestSuite, ParseUserPrompt) { + { + EXPECT_EQ(string_utils::ParseUserPrompt(prompt), + "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"); + } +} + +TEST_F(StringUtilsTestSuite, ParseSystemPrompt) { + { + EXPECT_EQ( + string_utils::ParseSystemPrompt(prompt), + "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"); + } +} +TEST_F(StringUtilsTestSuite, ParseAIPrompt) { + { + EXPECT_EQ(string_utils::ParseAIPrompt(prompt), + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"); + } +} TEST_F(StringUtilsTestSuite, TestSplitBy) { auto input = "this is a test"; std::string delimiter{' '}; diff --git a/engine/utils/string_utils.h b/engine/utils/string_utils.h index 5295745b0..9337b4f32 100644 --- a/engine/utils/string_utils.h +++ b/engine/utils/string_utils.h @@ -7,6 +7,19 @@ namespace string_utils { +inline std::string ParseUserPrompt(const std::string& prompt) { + auto& pt = prompt; + return pt.substr(pt.find_first_of('}') + 1, + pt.find_last_of('{') - pt.find_first_of('}') - 1); +} +inline std::string ParseSystemPrompt(const std::string& prompt) { + auto& pt = prompt; + return pt.substr(0, pt.find_first_of('{')); +} +inline std::string ParseAIPrompt(const std::string& prompt) { + auto& pt = prompt; + return pt.substr(pt.find_last_of('}') + 1); +} inline bool StartsWith(const std::string& str, const std::string& prefix) { return str.rfind(prefix, 0) == 0; } From 58dcf7b443f6f9339e8d4d0e058dec226790df26 Mon Sep 17 00:00:00 2001 From: nguyenhoangthuan99 Date: Wed, 16 Oct 2024 12:00:44 +0700 Subject: [PATCH 6/7] fix: comment --- engine/services/model_service.cc | 12 +++++----- engine/test/components/test_string_utils.cc | 20 +++++------------ engine/utils/string_utils.h | 25 ++++++++++++--------- 3 files changed, 25 insertions(+), 32 deletions(-) diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index e8e49ff94..bf6f5e9af 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -582,13 +582,11 @@ cpp::result ModelService::StartModel( } json_data["model"] = model_handle; if (!custom_prompt_template.value_or("").empty()) { - - json_data["system_prompt"] = - string_utils::ParseSystemPrompt(custom_prompt_template.value()); - json_data["user_prompt"] = - string_utils::ParseUserPrompt(custom_prompt_template.value()); - json_data["ai_prompt"] = - string_utils::ParseAIPrompt(custom_prompt_template.value()); + auto parse_prompt_result = + string_utils::ParsePrompt(custom_prompt_template.value()); + json_data["system_prompt"] = parse_prompt_result.system_prompt; + json_data["user_prompt"] = parse_prompt_result.user_prompt; + json_data["ai_prompt"] = parse_prompt_result.ai_prompt; } else { json_data["system_prompt"] = mc.system_template; json_data["user_prompt"] = mc.user_template; diff --git a/engine/test/components/test_string_utils.cc b/engine/test/components/test_string_utils.cc index 708d408ec..e0d403489 100644 --- a/engine/test/components/test_string_utils.cc +++ b/engine/test/components/test_string_utils.cc @@ -9,27 +9,19 @@ class StringUtilsTestSuite : public ::testing::Test { "message}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|" "eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"; }; -TEST_F(StringUtilsTestSuite, ParseUserPrompt) { +TEST_F(StringUtilsTestSuite, ParsePrompt) { { - EXPECT_EQ(string_utils::ParseUserPrompt(prompt), + auto result = string_utils::ParsePrompt(prompt); + EXPECT_EQ(result.user_prompt, "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"); - } -} - -TEST_F(StringUtilsTestSuite, ParseSystemPrompt) { - { + EXPECT_EQ(result.ai_prompt, + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"); EXPECT_EQ( - string_utils::ParseSystemPrompt(prompt), + result.system_prompt, "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"); } } -TEST_F(StringUtilsTestSuite, ParseAIPrompt) { - { - EXPECT_EQ(string_utils::ParseAIPrompt(prompt), - "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"); - } -} TEST_F(StringUtilsTestSuite, TestSplitBy) { auto input = "this is a test"; std::string delimiter{' '}; diff --git a/engine/utils/string_utils.h b/engine/utils/string_utils.h index 9337b4f32..4682623c1 100644 --- a/engine/utils/string_utils.h +++ b/engine/utils/string_utils.h @@ -7,18 +7,21 @@ namespace string_utils { -inline std::string ParseUserPrompt(const std::string& prompt) { - auto& pt = prompt; - return pt.substr(pt.find_first_of('}') + 1, - pt.find_last_of('{') - pt.find_first_of('}') - 1); -} -inline std::string ParseSystemPrompt(const std::string& prompt) { - auto& pt = prompt; - return pt.substr(0, pt.find_first_of('{')); -} -inline std::string ParseAIPrompt(const std::string& prompt) { +struct ParsePromptResult { + std::string user_prompt; + std::string system_prompt; + std::string ai_prompt; +}; + +inline ParsePromptResult ParsePrompt(const std::string& prompt) { auto& pt = prompt; - return pt.substr(pt.find_last_of('}') + 1); + ParsePromptResult result; + result.user_prompt = + pt.substr(pt.find_first_of('}') + 1, + pt.find_last_of('{') - pt.find_first_of('}') - 1); + result.ai_prompt = pt.substr(pt.find_last_of('}') + 1); + result.system_prompt = pt.substr(0, pt.find_first_of('{')); + return result; } inline bool StartsWith(const std::string& str, const std::string& prefix) { return str.rfind(prefix, 0) == 0; From c60ec19dddb856446dc54d6392242ba37e885be9 Mon Sep 17 00:00:00 2001 From: nguyenhoangthuan99 Date: Wed, 16 Oct 2024 12:09:31 +0700 Subject: [PATCH 7/7] chore: format code --- engine/test/components/test_string_utils.cc | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/engine/test/components/test_string_utils.cc b/engine/test/components/test_string_utils.cc index e0d403489..3d6abeddf 100644 --- a/engine/test/components/test_string_utils.cc +++ b/engine/test/components/test_string_utils.cc @@ -2,15 +2,13 @@ #include #include "gtest/gtest.h" #include "utils/string_utils.h" -class StringUtilsTestSuite : public ::testing::Test { - protected: - std::string prompt = - "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_" - "message}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|" - "eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"; -}; +class StringUtilsTestSuite : public ::testing::Test {}; TEST_F(StringUtilsTestSuite, ParsePrompt) { { + std::string prompt = + "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{" + "system_message}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{" + "prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"; auto result = string_utils::ParsePrompt(prompt); EXPECT_EQ(result.user_prompt, "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n");