Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion engine/controllers/models.cc
Original file line number Diff line number Diff line change
Expand Up @@ -402,8 +402,11 @@ void Models::StartModel(
return;
auto config = file_manager_utils::GetCortexConfig();
auto model_handle = (*(req->getJsonObject())).get("model", "").asString();
auto custom_prompt_template =
(*(req->getJsonObject())).get("prompt_template", "").asString();
auto result = model_service_->StartModel(
config.apiServerHost, std::stoi(config.apiServerPort), model_handle);
config.apiServerHost, std::stoi(config.apiServerPort), model_handle,
custom_prompt_template);
if (result.has_error()) {
Json::Value ret;
ret["message"] = result.error();
Expand Down
9 changes: 6 additions & 3 deletions engine/controllers/swagger.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ Json::Value SwaggerController::generateOpenAPISpec() {
responses["200"]["description"] = "Model details retrieved successfully";
Json::Value& schema =
responses["200"]["content"]["application/json"]["schema"];
responses["responses"]["400"]["description"] = "Failed to get model information";
responses["responses"]["400"]["description"] =
"Failed to get model information";

responses["400"]["description"] = "Failed to get model information";
responses["400"]["content"]["application/json"]["schema"]["type"] =
Expand Down Expand Up @@ -450,6 +451,8 @@ Json::Value SwaggerController::generateOpenAPISpec() {
"object";
start["requestBody"]["content"]["application/json"]["schema"]["properties"]
["model"]["type"] = "string";
start["requestBody"]["content"]["application/json"]["schema"]["properties"]
["prompt_template"]["type"] = "string";
start["requestBody"]["content"]["application/json"]["schema"]["required"] =
Json::Value(Json::arrayValue);
start["requestBody"]["content"]["application/json"]["schema"]["required"]
Expand All @@ -458,12 +461,12 @@ Json::Value SwaggerController::generateOpenAPISpec() {
start["responses"]["400"]["description"] = "Failed to start model";

// Stop Model
Json::Value& stop = spec["paths"]["/v1/models/stop"]["post"];
Json::Value& stop = spec["paths"]["/v1/models/stop"]["post"];
stop["summary"] = "Stop model";
stop["requestBody"]["content"]["application/json"]["schema"]["type"] =
"object";
stop["requestBody"]["content"]["application/json"]["schema"]["properties"]
["model"]["type"] = "string";
["model"]["type"] = "string";
stop["requestBody"]["content"]["application/json"]["schema"]["required"] =
Json::Value(Json::arrayValue);
stop["requestBody"]["content"]["application/json"]["schema"]["required"]
Expand Down
17 changes: 13 additions & 4 deletions engine/services/model_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,8 @@ cpp::result<void, std::string> ModelService::DeleteModel(
}

cpp::result<bool, std::string> ModelService::StartModel(
const std::string& host, int port, const std::string& model_handle) {
const std::string& host, int port, const std::string& model_handle,
std::optional<std::string> custom_prompt_template) {
namespace fs = std::filesystem;
namespace fmu = file_manager_utils;
cortex::db::Models modellist_handler;
Expand Down Expand Up @@ -580,9 +581,17 @@ cpp::result<bool, std::string> ModelService::StartModel(
return false;
}
json_data["model"] = model_handle;
json_data["system_prompt"] = mc.system_template;
json_data["user_prompt"] = mc.user_template;
json_data["ai_prompt"] = mc.ai_template;
if (!custom_prompt_template.value_or("").empty()) {
auto parse_prompt_result =
string_utils::ParsePrompt(custom_prompt_template.value());
json_data["system_prompt"] = parse_prompt_result.system_prompt;
json_data["user_prompt"] = parse_prompt_result.user_prompt;
json_data["ai_prompt"] = parse_prompt_result.ai_prompt;
} else {
json_data["system_prompt"] = mc.system_template;
json_data["user_prompt"] = mc.user_template;
json_data["ai_prompt"] = mc.ai_template;
}

auto data_str = json_data.toStyledString();
CTL_INF(data_str);
Expand Down
7 changes: 4 additions & 3 deletions engine/services/model_service.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#pragma once

#include <memory>
#include <optional>
#include <string>
#include "config/model_config.h"
#include "services/download_service.h"

class ModelService {
public:
constexpr auto static kHuggingFaceHost = "huggingface.co";
Expand Down Expand Up @@ -34,8 +34,9 @@ class ModelService {
*/
cpp::result<void, std::string> DeleteModel(const std::string& model_handle);

cpp::result<bool, std::string> StartModel(const std::string& host, int port,
const std::string& model_handle);
cpp::result<bool, std::string> StartModel(
const std::string& host, int port, const std::string& model_handle,
std::optional<std::string> custom_prompt_template = std::nullopt);

cpp::result<bool, std::string> StopModel(const std::string& host, int port,
const std::string& model_handle);
Expand Down
16 changes: 16 additions & 0 deletions engine/test/components/test_string_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,22 @@
#include "gtest/gtest.h"
#include "utils/string_utils.h"
class StringUtilsTestSuite : public ::testing::Test {};
TEST_F(StringUtilsTestSuite, ParsePrompt) {
{
std::string prompt =
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{"
"system_message}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{"
"prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n";
auto result = string_utils::ParsePrompt(prompt);
EXPECT_EQ(result.user_prompt,
"<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n");
EXPECT_EQ(result.ai_prompt,
"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
EXPECT_EQ(
result.system_prompt,
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n");
}
}

TEST_F(StringUtilsTestSuite, TestSplitBy) {
auto input = "this is a test";
Expand Down
16 changes: 16 additions & 0 deletions engine/utils/string_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,22 @@

namespace string_utils {

struct ParsePromptResult {
std::string user_prompt;
std::string system_prompt;
std::string ai_prompt;
};

inline ParsePromptResult ParsePrompt(const std::string& prompt) {
auto& pt = prompt;
ParsePromptResult result;
result.user_prompt =
pt.substr(pt.find_first_of('}') + 1,
pt.find_last_of('{') - pt.find_first_of('}') - 1);
result.ai_prompt = pt.substr(pt.find_last_of('}') + 1);
result.system_prompt = pt.substr(0, pt.find_first_of('{'));
return result;
}
inline bool StartsWith(const std::string& str, const std::string& prefix) {
return str.rfind(prefix, 0) == 0;
}
Expand Down
Loading