Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit d4d0d13

Browse files
Feat/support custom prompt template (#1495)
* chore: change update to patch * fix: swagger * fix: pull api * feat: support custom prompt template * fix comment and add unitests * fix: comment * chore: format code --------- Co-authored-by: vansangpfiev <vansangpfiev@gmail.com>
1 parent 2800e79 commit d4d0d13

File tree

6 files changed

+59
-11
lines changed

6 files changed

+59
-11
lines changed

engine/controllers/models.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,8 +402,11 @@ void Models::StartModel(
402402
return;
403403
auto config = file_manager_utils::GetCortexConfig();
404404
auto model_handle = (*(req->getJsonObject())).get("model", "").asString();
405+
auto custom_prompt_template =
406+
(*(req->getJsonObject())).get("prompt_template", "").asString();
405407
auto result = model_service_->StartModel(
406-
config.apiServerHost, std::stoi(config.apiServerPort), model_handle);
408+
config.apiServerHost, std::stoi(config.apiServerPort), model_handle,
409+
custom_prompt_template);
407410
if (result.has_error()) {
408411
Json::Value ret;
409412
ret["message"] = result.error();

engine/controllers/swagger.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,8 @@ Json::Value SwaggerController::generateOpenAPISpec() {
202202
responses["200"]["description"] = "Model details retrieved successfully";
203203
Json::Value& schema =
204204
responses["200"]["content"]["application/json"]["schema"];
205-
responses["responses"]["400"]["description"] = "Failed to get model information";
205+
responses["responses"]["400"]["description"] =
206+
"Failed to get model information";
206207

207208
responses["400"]["description"] = "Failed to get model information";
208209
responses["400"]["content"]["application/json"]["schema"]["type"] =
@@ -450,6 +451,8 @@ Json::Value SwaggerController::generateOpenAPISpec() {
450451
"object";
451452
start["requestBody"]["content"]["application/json"]["schema"]["properties"]
452453
["model"]["type"] = "string";
454+
start["requestBody"]["content"]["application/json"]["schema"]["properties"]
455+
["prompt_template"]["type"] = "string";
453456
start["requestBody"]["content"]["application/json"]["schema"]["required"] =
454457
Json::Value(Json::arrayValue);
455458
start["requestBody"]["content"]["application/json"]["schema"]["required"]
@@ -458,12 +461,12 @@ Json::Value SwaggerController::generateOpenAPISpec() {
458461
start["responses"]["400"]["description"] = "Failed to start model";
459462

460463
// Stop Model
461-
Json::Value& stop = spec["paths"]["/v1/models/stop"]["post"];
464+
Json::Value& stop = spec["paths"]["/v1/models/stop"]["post"];
462465
stop["summary"] = "Stop model";
463466
stop["requestBody"]["content"]["application/json"]["schema"]["type"] =
464467
"object";
465468
stop["requestBody"]["content"]["application/json"]["schema"]["properties"]
466-
["model"]["type"] = "string";
469+
["model"]["type"] = "string";
467470
stop["requestBody"]["content"]["application/json"]["schema"]["required"] =
468471
Json::Value(Json::arrayValue);
469472
stop["requestBody"]["content"]["application/json"]["schema"]["required"]

engine/services/model_service.cc

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,8 @@ cpp::result<void, std::string> ModelService::DeleteModel(
550550
}
551551

552552
cpp::result<bool, std::string> ModelService::StartModel(
553-
const std::string& host, int port, const std::string& model_handle) {
553+
const std::string& host, int port, const std::string& model_handle,
554+
std::optional<std::string> custom_prompt_template) {
554555
namespace fs = std::filesystem;
555556
namespace fmu = file_manager_utils;
556557
cortex::db::Models modellist_handler;
@@ -580,9 +581,17 @@ cpp::result<bool, std::string> ModelService::StartModel(
580581
return false;
581582
}
582583
json_data["model"] = model_handle;
583-
json_data["system_prompt"] = mc.system_template;
584-
json_data["user_prompt"] = mc.user_template;
585-
json_data["ai_prompt"] = mc.ai_template;
584+
if (!custom_prompt_template.value_or("").empty()) {
585+
auto parse_prompt_result =
586+
string_utils::ParsePrompt(custom_prompt_template.value());
587+
json_data["system_prompt"] = parse_prompt_result.system_prompt;
588+
json_data["user_prompt"] = parse_prompt_result.user_prompt;
589+
json_data["ai_prompt"] = parse_prompt_result.ai_prompt;
590+
} else {
591+
json_data["system_prompt"] = mc.system_template;
592+
json_data["user_prompt"] = mc.user_template;
593+
json_data["ai_prompt"] = mc.ai_template;
594+
}
586595

587596
auto data_str = json_data.toStyledString();
588597
CTL_INF(data_str);

engine/services/model_service.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
#pragma once
22

33
#include <memory>
4+
#include <optional>
45
#include <string>
56
#include "config/model_config.h"
67
#include "services/download_service.h"
7-
88
class ModelService {
99
public:
1010
constexpr auto static kHuggingFaceHost = "huggingface.co";
@@ -34,8 +34,9 @@ class ModelService {
3434
*/
3535
cpp::result<void, std::string> DeleteModel(const std::string& model_handle);
3636

37-
cpp::result<bool, std::string> StartModel(const std::string& host, int port,
38-
const std::string& model_handle);
37+
cpp::result<bool, std::string> StartModel(
38+
const std::string& host, int port, const std::string& model_handle,
39+
std::optional<std::string> custom_prompt_template = std::nullopt);
3940

4041
cpp::result<bool, std::string> StopModel(const std::string& host, int port,
4142
const std::string& model_handle);

engine/test/components/test_string_utils.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,22 @@
33
#include "gtest/gtest.h"
44
#include "utils/string_utils.h"
55
class StringUtilsTestSuite : public ::testing::Test {};
6+
TEST_F(StringUtilsTestSuite, ParsePrompt) {
7+
{
8+
std::string prompt =
9+
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{"
10+
"system_message}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{"
11+
"prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n";
12+
auto result = string_utils::ParsePrompt(prompt);
13+
EXPECT_EQ(result.user_prompt,
14+
"<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n");
15+
EXPECT_EQ(result.ai_prompt,
16+
"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
17+
EXPECT_EQ(
18+
result.system_prompt,
19+
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n");
20+
}
21+
}
622

723
TEST_F(StringUtilsTestSuite, TestSplitBy) {
824
auto input = "this is a test";

engine/utils/string_utils.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,22 @@
77

88
namespace string_utils {
99

10+
struct ParsePromptResult {
11+
std::string user_prompt;
12+
std::string system_prompt;
13+
std::string ai_prompt;
14+
};
15+
16+
inline ParsePromptResult ParsePrompt(const std::string& prompt) {
17+
auto& pt = prompt;
18+
ParsePromptResult result;
19+
result.user_prompt =
20+
pt.substr(pt.find_first_of('}') + 1,
21+
pt.find_last_of('{') - pt.find_first_of('}') - 1);
22+
result.ai_prompt = pt.substr(pt.find_last_of('}') + 1);
23+
result.system_prompt = pt.substr(0, pt.find_first_of('{'));
24+
return result;
25+
}
1026
inline bool StartsWith(const std::string& str, const std::string& prefix) {
1127
return str.rfind(prefix, 0) == 0;
1228
}

0 commit comments

Comments
 (0)