diff --git a/engine/config/gguf_parser.cc b/engine/config/gguf_parser.cc
index f2f058336..160734468 100644
--- a/engine/config/gguf_parser.cc
+++ b/engine/config/gguf_parser.cc
@@ -409,7 +409,23 @@ void GGUFHandler::ModelConfigFromMetadata() {
model_config_.created = std::time(nullptr);
model_config_.model = "model";
model_config_.owned_by = "";
- model_config_.version;
+ model_config_.seed = -1;
+ model_config_.dynatemp_range = 0.0f;
+ model_config_.dynatemp_exponent = 1.0f;
+ model_config_.top_k = 40;
+ model_config_.min_p = 0.05f;
+ model_config_.tfs_z = 1.0f;
+ model_config_.typ_p = 1.0f;
+ model_config_.repeat_last_n = 64;
+ model_config_.repeat_penalty = 1.0f;
+ model_config_.mirostat = false;
+ model_config_.mirostat_tau = 5.0f;
+ model_config_.mirostat_eta = 0.1f;
+ model_config_.penalize_nl = false;
+ model_config_.ignore_eos = false;
+ model_config_.n_probs = 0;
+ model_config_.min_keep = 0;
+ model_config_.grammar = "";
// Get version, bos, eos id, context_len, ngl from metadata
for (const auto& [key, value] : metadata_uint8_) {
@@ -522,7 +538,7 @@ void GGUFHandler::ModelConfigFromMetadata() {
for (const auto& [key, value] : metadata_string_) {
if (key.compare("general.name") == 0) {
name = std::regex_replace(value, std::regex(" "), "-");
- } else if (key.compare("tokenizer.chat_template") == 0) {
+ } else if (key.find("chat_template") != std::string::npos) {
if (value.compare(ZEPHYR_JINJA) == 0) {
chat_template =
"<|system|>\n{system_message}\n<|user|>\n{prompt}"
@@ -564,12 +580,21 @@ void GGUFHandler::ModelConfigFromMetadata() {
}
}
- eos_string = tokens[eos_token];
- bos_string = tokens[bos_token];
- stop.push_back(std::move(eos_string));
+ try {
+ if (tokens.size() > eos_token) {
+ eos_string = tokens[eos_token];
+ stop.push_back(std::move(eos_string));
+ } else {
+ LOG_ERROR << "Can't find stop token";
+ }
+ } catch (const std::exception& e) {
+ LOG_ERROR << "Can't find stop token";
+ }
model_config_.stop = std::move(stop);
-
+ if (chat_template.empty())
+ chat_template =
+ "[INST] <>\n{system_message}\n<>\n{prompt}[/INST]";
model_config_.prompt_template = std::move(chat_template);
model_config_.name = name;
model_config_.model = name;
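
The guarded lookup above replaces unchecked indexing into tokens with a metadata-supplied eos_token id, which could read out of bounds on a malformed GGUF file. A minimal standalone sketch of the same pattern, assuming tokens and eos_token mirror the locals in ModelConfigFromMetadata() (the names here are illustrative, not the project's API):

    // Bounds-check a metadata-supplied token index before dereferencing.
    #include <cstdint>
    #include <iostream>
    #include <string>
    #include <vector>

    std::vector<std::string> FindStopTokens(
        const std::vector<std::string>& tokens, uint64_t eos_token) {
      std::vector<std::string> stop;
      if (eos_token < tokens.size()) {  // guard instead of blind tokens[eos_token]
        stop.push_back(tokens[eos_token]);
      } else {
        std::cerr << "Can't find stop token\n";
      }
      return stop;
    }

    int main() {
      auto stop = FindStopTokens({"<s>", "</s>"}, 1);
      std::cout << stop[0] << "\n";  // prints "</s>"
    }
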
diff --git a/engine/config/model_config.h b/engine/config/model_config.h
index a8d7f4fda..f61f9e9ba 100644
--- a/engine/config/model_config.h
+++ b/engine/config/model_config.h
@@ -36,5 +36,23 @@ struct ModelConfig {
std::size_t created;
std::string object;
std::string owned_by = "";
+
+ int seed = -1;
+ float dynatemp_range = 0.0f;
+ float dynatemp_exponent = 1.0f;
+ int top_k = 40;
+ float min_p = 0.05f;
+ float tfs_z = 1.0f;
+ float typ_p = 1.0f;
+ int repeat_last_n = 64;
+ float repeat_penalty = 1.0f;
+ bool mirostat = false;
+ float mirostat_tau = 5.0f;
+ float mirostat_eta = 0.1f;
+ bool penalize_nl = false;
+ bool ignore_eos = false;
+ int n_probs = 0;
+ int min_keep = 0;
+ std::string grammar;
};
} // namespace config
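
The defaults on the new ModelConfig fields appear to track llama.cpp's stock sampling parameters (top_k 40, min_p 0.05, repeat_last_n 64, mirostat off), so a config that omits them should behave like a plain llama.cpp run. A hedged usage sketch, assuming model_config.h is on the include path:

    #include <iostream>
    #include "config/model_config.h"

    int main() {
      config::ModelConfig mc;  // sampling fields carry in-class defaults
      std::cout << "seed=" << mc.seed              // -1 (random)
                << " top_k=" << mc.top_k           // 40
                << " min_p=" << mc.min_p << "\n";  // 0.05
    }
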
diff --git a/engine/config/yaml_config.cc b/engine/config/yaml_config.cc
index 81f62df9d..85f086d46 100644
--- a/engine/config/yaml_config.cc
+++ b/engine/config/yaml_config.cc
@@ -24,6 +24,7 @@ void YamlHandler::ReadYamlFile(const std::string& file_path) {
std::vector<std::string> v;
if (yaml_node_["engine"] &&
yaml_node_["engine"].as() == "cortex.llamacpp") {
+ // TODO: change prefix to models:// with source from cortexso
v.emplace_back(s.substr(0, s.find_last_of('/')) + "/model.gguf");
} else {
v.emplace_back(s.substr(0, s.find_last_of('/')));
@@ -36,7 +37,6 @@ void YamlHandler::ReadYamlFile(const std::string& file_path) {
std::cerr << "Failed to read file: " << e.what() << std::endl;
throw;
}
- ModelConfigFromYaml();
}
void YamlHandler::SplitPromptTemplate(ModelConfig& mc) {
if (mc.prompt_template.size() > 0) {
@@ -119,6 +119,41 @@ void YamlHandler::ModelConfigFromYaml() {
tmp.files = yaml_node_["files"].as>();
if (yaml_node_["created"])
tmp.created = yaml_node_["created"].as();
+
+ if (yaml_node_["seed"])
+ tmp.seed = yaml_node_["seed"].as();
+ if (yaml_node_["dynatemp_range"])
+ tmp.dynatemp_range = yaml_node_["dynatemp_range"].as();
+ if (yaml_node_["dynatemp_exponent"])
+ tmp.dynatemp_exponent = yaml_node_["dynatemp_exponent"].as();
+ if (yaml_node_["top_k"])
+ tmp.top_k = yaml_node_["top_k"].as();
+ if (yaml_node_["min_p"])
+ tmp.min_p = yaml_node_["min_p"].as();
+ if (yaml_node_["tfs_z"])
+ tmp.tfs_z = yaml_node_["tfs_z"].as();
+ if (yaml_node_["typ_p"])
+ tmp.typ_p = yaml_node_["typ_p"].as();
+ if (yaml_node_["repeat_last_n"])
+ tmp.repeat_last_n = yaml_node_["repeat_last_n"].as();
+ if (yaml_node_["repeat_penalty"])
+ tmp.repeat_penalty = yaml_node_["repeat_penalty"].as();
+ if (yaml_node_["mirostat"])
+ tmp.mirostat = yaml_node_["mirostat"].as();
+ if (yaml_node_["mirostat_tau"])
+ tmp.mirostat_tau = yaml_node_["mirostat_tau"].as();
+ if (yaml_node_["mirostat_eta"])
+ tmp.mirostat_eta = yaml_node_["mirostat_eta"].as();
+ if (yaml_node_["penalize_nl"])
+ tmp.penalize_nl = yaml_node_["penalize_nl"].as();
+ if (yaml_node_["ignore_eos"])
+ tmp.ignore_eos = yaml_node_["ignore_eos"].as();
+ if (yaml_node_["n_probs"])
+ tmp.n_probs = yaml_node_["n_probs"].as();
+ if (yaml_node_["min_keep"])
+ tmp.min_keep = yaml_node_["min_keep"].as();
+ if (yaml_node_["grammar"])
+ tmp.grammar = yaml_node_["grammar"].as();
} catch (const std::exception& e) {
std::cerr << "Error when load model config : " << e.what() << std::endl;
std::cerr << "Revert ..." << std::endl;
@@ -185,6 +220,42 @@ void YamlHandler::UpdateModelConfig(ModelConfig new_model_config) {
yaml_node_["stop"] = model_config_.stop;
if (model_config_.files.size() > 0)
yaml_node_["files"] = model_config_.files;
+
+ if (!std::isnan(static_cast<float>(model_config_.seed)))
+ yaml_node_["seed"] = model_config_.seed;
+ if (!std::isnan(model_config_.dynatemp_range))
+ yaml_node_["dynatemp_range"] = model_config_.dynatemp_range;
+ if (!std::isnan(model_config_.dynatemp_exponent))
+ yaml_node_["dynatemp_exponent"] = model_config_.dynatemp_exponent;
+ if (!std::isnan(static_cast<float>(model_config_.top_k)))
+ yaml_node_["top_k"] = model_config_.top_k;
+ if (!std::isnan(model_config_.min_p))
+ yaml_node_["min_p"] = model_config_.min_p;
+ if (!std::isnan(model_config_.tfs_z))
+ yaml_node_["tfs_z"] = model_config_.tfs_z;
+ if (!std::isnan(model_config_.typ_p))
+ yaml_node_["typ_p"] = model_config_.typ_p;
+ if (!std::isnan(static_cast<float>(model_config_.repeat_last_n)))
+ yaml_node_["repeat_last_n"] = model_config_.repeat_last_n;
+ if (!std::isnan(model_config_.repeat_penalty))
+ yaml_node_["repeat_penalty"] = model_config_.repeat_penalty;
+ if (!std::isnan(static_cast<float>(model_config_.mirostat)))
+ yaml_node_["mirostat"] = model_config_.mirostat;
+ if (!std::isnan(model_config_.mirostat_tau))
+ yaml_node_["mirostat_tau"] = model_config_.mirostat_tau;
+ if (!std::isnan(model_config_.mirostat_eta))
+ yaml_node_["mirostat_eta"] = model_config_.mirostat_eta;
+ if (!std::isnan(static_cast<float>(model_config_.penalize_nl)))
+ yaml_node_["penalize_nl"] = model_config_.penalize_nl;
+ if (!std::isnan(static_cast<float>(model_config_.ignore_eos)))
+ yaml_node_["ignore_eos"] = model_config_.ignore_eos;
+ if (!std::isnan(static_cast<float>(model_config_.n_probs)))
+ yaml_node_["n_probs"] = model_config_.n_probs;
+ if (!std::isnan(static_cast<float>(model_config_.min_keep)))
+ yaml_node_["min_keep"] = model_config_.min_keep;
+ if (!model_config_.grammar.empty())
+ yaml_node_["grammar"] = model_config_.grammar;
+
yaml_node_["created"] = std::time(nullptr);
} catch (const std::exception& e) {
std::cerr << "Error when update model config : " << e.what() << std::endl;
@@ -200,7 +271,97 @@ void YamlHandler::WriteYamlFile(const std::string& file_path) const {
if (!outFile) {
throw std::runtime_error("Failed to open output file.");
}
- outFile << yaml_node_;
+ // Helper function to write a key-value pair with an optional comment
+ auto writeKeyValue = [&](const std::string& key, const YAML::Node& value,
+ const std::string& comment = "") {
+ if (!value)
+ return;
+ outFile << key << ": " << value;
+ if (!comment.empty()) {
+ outFile << " # " << comment;
+ }
+ outFile << "\n";
+ };
+
+ // Write GENERAL GGUF METADATA
+ outFile << "# BEGIN GENERAL GGUF METADATA\n";
+ writeKeyValue("id", yaml_node_["id"],
+ "Model ID unique between models (author / quantization)");
+ writeKeyValue("model", yaml_node_["model"],
+ "Model ID which is used for request construct - should be "
+ "unique between models (author / quantization)");
+ writeKeyValue("name", yaml_node_["name"], "metadata.general.name");
+ writeKeyValue("version", yaml_node_["version"], "metadata.version");
+ if (yaml_node_["files"] && yaml_node_["files"].size()) {
+ outFile << "files: # can be universal protocol (models://) "
+ "OR absolute local file path (file://) OR https remote URL "
+ "(https://)\n";
+ for (const auto& source : yaml_node_["files"]) {
+ outFile << " - " << source << "\n";
+ }
+ }
+
+ outFile << "# END GENERAL GGUF METADATA\n";
+ outFile << "\n";
+ // Write INFERENCE PARAMETERS
+ outFile << "# BEGIN INFERENCE PARAMETERS\n";
+ outFile << "# BEGIN REQUIRED\n";
+ if (yaml_node_["stop"] && yaml_node_["stop"].size()) {
+ outFile << "stop: # tokenizer.ggml.eos_token_id\n";
+ for (const auto& stop : yaml_node_["stop"]) {
+ outFile << " - " << stop << "\n";
+ }
+ }
+
+ outFile << "# END REQUIRED\n";
+ outFile << "\n";
+ outFile << "# BEGIN OPTIONAL\n";
+ writeKeyValue("stream", yaml_node_["stream"], "Default true?");
+ writeKeyValue("top_p", yaml_node_["top_p"], "Ranges: 0 to 1");
+ writeKeyValue("temperature", yaml_node_["temperature"], "Ranges: 0 to 1");
+ writeKeyValue("frequency_penalty", yaml_node_["frequency_penalty"],
+ "Ranges: 0 to 1");
+ writeKeyValue("presence_penalty", yaml_node_["presence_penalty"],
+ "Ranges: 0 to 1");
+ writeKeyValue("max_tokens", yaml_node_["max_tokens"],
+ "Should be default to context length");
+ writeKeyValue("seed", yaml_node_["seed"]);
+ writeKeyValue("dynatemp_range", yaml_node_["dynatemp_range"]);
+ writeKeyValue("dynatemp_exponent", yaml_node_["dynatemp_exponent"]);
+ writeKeyValue("top_k", yaml_node_["top_k"]);
+ writeKeyValue("min_p", yaml_node_["min_p"]);
+ writeKeyValue("tfs_z", yaml_node_["tfs_z"]);
+ writeKeyValue("typ_p", yaml_node_["typ_p"]);
+ writeKeyValue("repeat_last_n", yaml_node_["repeat_last_n"]);
+ writeKeyValue("repeat_penalty", yaml_node_["repeat_penalty"]);
+ writeKeyValue("mirostat", yaml_node_["mirostat"]);
+ writeKeyValue("mirostat_tau", yaml_node_["mirostat_tau"]);
+ writeKeyValue("mirostat_eta", yaml_node_["mirostat_eta"]);
+ writeKeyValue("penalize_nl", yaml_node_["penalize_nl"]);
+ writeKeyValue("ignore_eos", yaml_node_["ignore_eos"]);
+ writeKeyValue("n_probs", yaml_node_["n_probs"]);
+ writeKeyValue("min_keep", yaml_node_["min_keep"]);
+ writeKeyValue("grammar", yaml_node_["grammar"]);
+ outFile << "# END OPTIONAL\n";
+ outFile << "# END INFERENCE PARAMETERS\n";
+ outFile << "\n";
+ // Write MODEL LOAD PARAMETERS
+ outFile << "# BEGIN MODEL LOAD PARAMETERS\n";
+ outFile << "# BEGIN REQUIRED\n";
+ writeKeyValue("engine", yaml_node_["engine"], "engine to run model");
+ outFile << "prompt_template:";
+ outFile << " " << yaml_node_["prompt_template"] << "\n";
+ outFile << "# END REQUIRED\n";
+ outFile << "\n";
+ outFile << "# BEGIN OPTIONAL\n";
+ writeKeyValue("ctx_len", yaml_node_["ctx_len"],
+ "llama.context_length | 0 or undefined = loaded from model");
+ writeKeyValue("ngl", yaml_node_["ngl"], "Undefined = loaded from model");
+ outFile << "# END OPTIONAL\n";
+ outFile << "# END MODEL LOAD PARAMETERS\n";
+
+ // Write new configuration parameters
+
outFile.close();
} catch (const std::exception& e) {
std::cerr << "Error writing to file: " << e.what() << std::endl;
diff --git a/engine/controllers/command_line_parser.cc b/engine/controllers/command_line_parser.cc
index 9c4b5713f..b55887ebd 100644
--- a/engine/controllers/command_line_parser.cc
+++ b/engine/controllers/command_line_parser.cc
@@ -294,7 +294,6 @@ void CommandLineParser::EngineGet(CLI::App* parent) {
std::string desc = "Get " + engine_name + " status";
auto engine_get_cmd = get_cmd->add_subcommand(engine_name, desc);
- engine_get_cmd->require_option();
engine_get_cmd->callback(
[engine_name] { commands::EngineGetCmd().Exec(engine_name); });
}
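
In CLI11, require_option() makes a (sub)command fail unless at least one option is supplied, so the removed line was rejecting the bare `cortex engines get <engine>` invocation that the callback handles. A toy CLI11 sketch of the difference (hypothetical program, not the project's parser):

    #include <CLI/CLI.hpp>
    #include <iostream>

    int main(int argc, char** argv) {
      CLI::App app{"sketch"};
      auto* get_cmd = app.add_subcommand("get", "Get engine status");
      // get_cmd->require_option();  // would make plain "sketch get" an error
      get_cmd->callback([] { std::cout << "engine status\n"; });
      CLI11_PARSE(app, argc, argv);
      return 0;
    }
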
diff --git a/engine/e2e-test/test_cli_engine_get.py b/engine/e2e-test/test_cli_engine_get.py
index 38c235b30..6b5270eba 100644
--- a/engine/e2e-test/test_cli_engine_get.py
+++ b/engine/e2e-test/test_cli_engine_get.py
@@ -52,5 +52,6 @@ def test_engines_get_onnx_should_be_incompatible_on_macos(self):
@pytest.mark.skipif(platform.system() != "Linux", reason="Linux-specific test")
def test_engines_get_onnx_should_be_incompatible_on_linux(self):
exit_code, output, error = run("Get engine", ["engines", "get", "cortex.onnx"])
+ print(output)
assert exit_code == 0, f"Get engine failed with error: {error}"
assert "Incompatible" in output, "cortex.onnx should be Incompatible on Linux"
diff --git a/engine/e2e-test/test_cli_engine_install.py b/engine/e2e-test/test_cli_engine_install.py
index 15cd8deb3..89d49401d 100644
--- a/engine/e2e-test/test_cli_engine_install.py
+++ b/engine/e2e-test/test_cli_engine_install.py
@@ -31,7 +31,7 @@ def test_engines_install_onnx_on_tensorrt_should_be_failed(self):
def test_engines_install_pre_release_llamacpp(self):
exit_code, output, error = run(
- "Install Engine", ["engines", "install", "cortex.llamacpp", "-v", "v0.1.29"], timeout=60
+ "Install Engine", ["engines", "install", "cortex.llamacpp", "-v", "v0.1.29"], timeout=None
)
assert "Start downloading" in output, "Should display downloading message"
assert exit_code == 0, f"Install engine failed with error: {error}"
diff --git a/engine/e2e-test/test_cli_engine_uninstall.py b/engine/e2e-test/test_cli_engine_uninstall.py
index 03078a1e6..525c0ad63 100644
--- a/engine/e2e-test/test_cli_engine_uninstall.py
+++ b/engine/e2e-test/test_cli_engine_uninstall.py
@@ -8,7 +8,7 @@ class TestCliEngineUninstall:
def setup_and_teardown(self):
# Setup
# Preinstall llamacpp engine
- run("Install Engine", ["engines", "install", "cortex.llamacpp"])
+ run("Install Engine", ["engines", "install", "cortex.llamacpp"],timeout=None)
yield
diff --git a/engine/test/components/CMakeLists.txt b/engine/test/components/CMakeLists.txt
index 32ee36f09..db810ad26 100644
--- a/engine/test/components/CMakeLists.txt
+++ b/engine/test/components/CMakeLists.txt
@@ -3,13 +3,14 @@ project(test-components)
enable_testing()
-add_executable(${PROJECT_NAME} ${SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/../../utils/modellist_utils.cc)
+add_executable(${PROJECT_NAME} ${SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/../../utils/modellist_utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../config/yaml_config.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../config/gguf_parser.cc)
find_package(Drogon CONFIG REQUIRED)
find_package(GTest CONFIG REQUIRED)
find_package(yaml-cpp CONFIG REQUIRED)
+find_package(jinja2cpp CONFIG REQUIRED)
-target_link_libraries(${PROJECT_NAME} PRIVATE Drogon::Drogon GTest::gtest GTest::gtest_main yaml-cpp::yaml-cpp
+target_link_libraries(${PROJECT_NAME} PRIVATE Drogon::Drogon GTest::gtest GTest::gtest_main yaml-cpp::yaml-cpp jinja2cpp
${CMAKE_THREAD_LIBS_INIT})
target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../)
diff --git a/engine/test/components/test_gguf_parser.cc b/engine/test/components/test_gguf_parser.cc
new file mode 100644
index 000000000..6c5c61486
--- /dev/null
+++ b/engine/test/components/test_gguf_parser.cc
@@ -0,0 +1,158 @@
+#include "gtest/gtest.h"
+#include "config/gguf_parser.h"
+#include "config/yaml_config.h"
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <fstream>
+#include <memory>
+#include <stdexcept>
+#include <string>
+#include <type_traits>
+
+#ifdef _WIN32
+#include <windows.h>
+#else
+#include <unistd.h>
+#endif
+
+class GGUFParserTest : public ::testing::Test {
+protected:
+ void SetUp() override {
+ gguf_handler = std::make_unique<config::GGUFHandler>();
+ yaml_handler = std::make_unique<config::YamlHandler>();
+ }
+
+ void TearDown() override {
+ }
+
+ std::unique_ptr<config::GGUFHandler> gguf_handler;
+ std::unique_ptr<config::YamlHandler> yaml_handler;
+
+ std::string getTempFilePath(const std::string& prefix, const std::string& extension) {
+ #ifdef _WIN32
+ char temp_path[MAX_PATH];
+ char file_name[MAX_PATH];
+ GetTempPathA(MAX_PATH, temp_path);
+ GetTempFileNameA(temp_path, prefix.c_str(), 0, file_name);
+ std::string path(file_name);
+ DeleteFileA(file_name); // Delete the file created by GetTempFileNameA
+ return path + extension;
+ #else
+ std::string path = "/tmp/" + prefix + "XXXXXX" + extension;
+ char* temp = strdup(path.c_str());
+ int fd = mkstemps(temp, extension.length());
+ if (fd == -1) {
+ free(temp);
+ throw std::runtime_error("Failed to create temporary file");
+ }
+ close(fd);
+ std::string result(temp);
+ free(temp);
+ return result;
+ #endif
+ }
+
+ std::string createMockGGUFFile() {
+ std::string gguf_path = getTempFilePath("mock_tinyllama-model", ".gguf");
+ std::ofstream file(gguf_path, std::ios::binary);
+
+ if (!file.is_open()) {
+ throw std::runtime_error("Failed to create mock GGUF file");
+ }
+
+ try {
+ // GGUF magic number
+ uint32_t magic = 0x46554747;
+ file.write(reinterpret_cast<const char*>(&magic), sizeof(magic));
+
+ // Version
+ uint32_t version = 2;
+ file.write(reinterpret_cast<const char*>(&version), sizeof(version));
+
+ // Tensor count (not important for this test)
+ uint64_t tensor_count = 0;
+ file.write(reinterpret_cast<const char*>(&tensor_count), sizeof(tensor_count));
+
+ // Metadata key-value count
+ uint64_t kv_count = 2;
+ file.write(reinterpret_cast<const char*>(&kv_count), sizeof(kv_count));
+
+ // Helper function to write a string
+ auto writeString = [&file](const std::string& str) {
+ uint64_t length = str.length();
+ file.write(reinterpret_cast<const char*>(&length), sizeof(length));
+ file.write(str.c_str(), length);
+ };
+
+ // Helper function to write a key-value pair
+ auto writeKV = [&](const std::string& key, uint32_t type, const auto& value) {
+ writeString(key);
+ file.write(reinterpret_cast<const char*>(&type), sizeof(type));
+ if constexpr (std::is_same_v<std::decay_t<decltype(value)>, std::string>) {
+ writeString(value);
+ } else {
+ file.write(reinterpret_cast<const char*>(&value), sizeof(value));
+ }
+ };
+
+ // Write metadata
+ writeKV("general.name", 8, std::string("tinyllama 1B"));
+ writeKV("llama.context_length", 4, uint32_t(4096));
+
+ file.close();
+
+ } catch (const std::exception& e) {
+ file.close();
+ std::remove(gguf_path.c_str());
+ throw std::runtime_error(std::string("Failed to write mock GGUF file: ") + e.what());
+ }
+
+ return gguf_path;
+ }
+};
+
+TEST_F(GGUFParserTest, ParseMockTinyLlamaModel) {
+ std::string gguf_path;
+ std::string yaml_path;
+ try {
+ // Create a mock GGUF file
+ gguf_path = createMockGGUFFile();
+
+ // Parse the GGUF file
+ gguf_handler->Parse(gguf_path);
+
+ const config::ModelConfig& gguf_config = gguf_handler->GetModelConfig();
+
+ // Load the expected configuration from YAML
+ std::string yaml_content = R"(
+name: tinyllama-1B
+ctx_len: 4096
+ )";
+
+ yaml_path = getTempFilePath("expected_config", ".yaml");
+ std::ofstream yaml_file(yaml_path);
+ yaml_file << yaml_content;
+ yaml_file.close();
+
+ yaml_handler->ModelConfigFromFile(yaml_path);
+
+ const config::ModelConfig& yaml_config = yaml_handler->GetModelConfig();
+
+ // Compare GGUF parsed config with YAML config
+ EXPECT_EQ(gguf_config.name, yaml_config.name);
+ EXPECT_EQ(gguf_config.ctx_len, yaml_config.ctx_len);
+
+ // Clean up
+ std::remove(gguf_path.c_str());
+ std::remove(yaml_path.c_str());
+ }
+ catch (const std::exception& e) {
+ // If an exception was thrown, make sure to clean up the files
+ if (!gguf_path.empty()) {
+ std::remove(gguf_path.c_str());
+ }
+ if (!yaml_path.empty()) {
+ std::remove(yaml_path.c_str());
+ }
+ FAIL() << "Exception thrown: " << e.what();
+ }
+}
\ No newline at end of file
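
The mock file exercises the GGUF v2 layout the parser expects: magic, version, tensor count, kv count, then length-prefixed keys with a uint32 type tag (8 = string, 4 = uint32 in the test). A hedged sketch of reading one pair back under the same little-endian assumptions (file name hypothetical):

    #include <cstdint>
    #include <fstream>
    #include <iostream>
    #include <string>

    // Read a length-prefixed string as written by the test's writeString().
    std::string ReadGGUFString(std::ifstream& in) {
      uint64_t len = 0;
      in.read(reinterpret_cast<char*>(&len), sizeof(len));
      std::string s(len, '\0');
      in.read(&s[0], static_cast<std::streamsize>(len));
      return s;
    }

    int main() {
      std::ifstream in("mock.gguf", std::ios::binary);
      in.seekg(4 + 4 + 8 + 8);  // skip magic, version, tensor and kv counts
      std::string key = ReadGGUFString(in);  // "general.name"
      uint32_t type = 0;
      in.read(reinterpret_cast<char*>(&type), sizeof(type));
      if (type == 8)  // string-typed value follows
        std::cout << key << " = " << ReadGGUFString(in) << "\n";
    }
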
diff --git a/engine/test/components/test_yaml_handler.cc b/engine/test/components/test_yaml_handler.cc
new file mode 100644
index 000000000..d65ab6b49
--- /dev/null
+++ b/engine/test/components/test_yaml_handler.cc
@@ -0,0 +1,199 @@
+#include <cstdio>
+#include <fstream>
+#include <stdexcept>
+#include <string>
+#include "config/yaml_config.h"
+#include "gtest/gtest.h"
+
+#ifdef _WIN32
+#include <windows.h>
+#include <io.h>
+#else
+#include <unistd.h>
+#endif
+
+class YamlHandlerTest : public ::testing::Test {
+ protected:
+ void SetUp() override { handler = new config::YamlHandler(); }
+
+ void TearDown() override { delete handler; }
+
+ config::YamlHandler* handler;
+
+ // Helper function to create a temporary YAML file
+ std::string createTempYamlFile(const std::string& content) {
+ std::string filename;
+#ifdef _WIN32
+ char tempPath[MAX_PATH];
+ char tempFileName[MAX_PATH];
+ GetTempPathA(MAX_PATH, tempPath);
+ GetTempFileNameA(tempPath, "yaml", 0, tempFileName);
+ filename = tempFileName;
+#else
+ char tempFileName[] = "/tmp/yaml_test_XXXXXX";
+ int fd = mkstemp(tempFileName);
+ if (fd == -1) {
+ throw std::runtime_error("Failed to create temporary file");
+ }
+ close(fd);
+ filename = tempFileName;
+#endif
+
+ std::ofstream file(filename);
+ file << content;
+ file.close();
+ return filename;
+ }
+
+ // Helper function to remove a file
+ void removeFile(const std::string& filename) {
+ std::remove(filename.c_str());
+ }
+};
+
+TEST_F(YamlHandlerTest, ModelConfigFromFile) {
+ std::string yaml_content = R"(
+name: test_model
+model: test_model_v1
+version: 1.0
+engine: test_engine
+prompt_template: "Test prompt {system_message} {prompt}"
+top_p: 0.9
+temperature: 0.7
+max_tokens: 100
+stream: true
+stop:
+ - "END"
+files:
+ - "test_file.gguf"
+ )";
+
+ std::string filename = createTempYamlFile(yaml_content);
+
+ handler->ModelConfigFromFile(filename);
+ const config::ModelConfig& config = handler->GetModelConfig();
+
+ EXPECT_EQ(config.name, "test_model");
+ EXPECT_EQ(config.model, "test_model_v1");
+ EXPECT_EQ(config.version, "1.0");
+ EXPECT_EQ(config.engine, "test_engine");
+ EXPECT_EQ(config.prompt_template, "Test prompt {system_message} {prompt}");
+ EXPECT_FLOAT_EQ(config.top_p, 0.9f);
+ EXPECT_FLOAT_EQ(config.temperature, 0.7f);
+ EXPECT_EQ(config.max_tokens, 100);
+ EXPECT_TRUE(config.stream);
+ EXPECT_EQ(config.stop.size(), 1);
+ EXPECT_EQ(config.stop[0], "END");
+ EXPECT_EQ(config.files.size(), 1);
+ EXPECT_EQ(config.files[0], "test_file.gguf");
+
+ removeFile(filename);
+}
+
+TEST_F(YamlHandlerTest, UpdateModelConfig) {
+ config::ModelConfig new_config;
+ new_config.name = "updated_model";
+ new_config.model = "updated_model_v2";
+ new_config.version = "2.0";
+ new_config.engine = "updated_engine";
+ new_config.prompt_template = "Updated prompt {system_message} {prompt}";
+ new_config.top_p = 0.95f;
+ new_config.temperature = 0.8f;
+ new_config.max_tokens = 200;
+ new_config.stream = false;
+ new_config.stop = {"STOP", "END"};
+ new_config.files = {"updated_file1.gguf", "updated_file2.gguf"};
+
+ handler->UpdateModelConfig(new_config);
+ const config::ModelConfig& config = handler->GetModelConfig();
+
+ EXPECT_EQ(config.name, "updated_model");
+ EXPECT_EQ(config.model, "updated_model_v2");
+ EXPECT_EQ(config.version, "2.0");
+ EXPECT_EQ(config.engine, "updated_engine");
+ EXPECT_EQ(config.prompt_template, "Updated prompt {system_message} {prompt}");
+ EXPECT_FLOAT_EQ(config.top_p, 0.95f);
+ EXPECT_FLOAT_EQ(config.temperature, 0.8f);
+ EXPECT_EQ(config.max_tokens, 200);
+ EXPECT_FALSE(config.stream);
+ EXPECT_EQ(config.stop.size(), 2);
+ EXPECT_EQ(config.stop[0], "STOP");
+ EXPECT_EQ(config.stop[1], "END");
+ EXPECT_EQ(config.files.size(), 2);
+ EXPECT_EQ(config.files[0], "updated_file1.gguf");
+ EXPECT_EQ(config.files[1], "updated_file2.gguf");
+}
+
+TEST_F(YamlHandlerTest, WriteYamlFile) {
+ config::ModelConfig new_config;
+ new_config.name = "write_test_model";
+ new_config.model = "write_test_model_v1";
+ new_config.version = "1.0";
+ new_config.engine = "write_test_engine";
+ new_config.prompt_template = "Write test prompt {system_message} {prompt}";
+ new_config.top_p = 0.85f;
+ new_config.temperature = 0.6f;
+ new_config.max_tokens = 150;
+ new_config.stream = true;
+ new_config.stop = {"HALT"};
+ new_config.files = {"write_test_file.gguf"};
+
+ handler->UpdateModelConfig(new_config);
+
+ std::string filename = createTempYamlFile(""); // Create empty file
+ handler->WriteYamlFile(filename);
+
+ // Read the file back and verify its contents
+ config::YamlHandler read_handler;
+ read_handler.ModelConfigFromFile(filename);
+ const config::ModelConfig& read_config = read_handler.GetModelConfig();
+
+ EXPECT_EQ(read_config.name, "write_test_model");
+ EXPECT_EQ(read_config.model, "write_test_model_v1");
+ EXPECT_EQ(read_config.version, "1.0");
+ EXPECT_EQ(read_config.engine, "write_test_engine");
+ EXPECT_EQ(read_config.prompt_template,
+ "Write test prompt {system_message} {prompt}");
+ EXPECT_FLOAT_EQ(read_config.top_p, 0.85f);
+ EXPECT_FLOAT_EQ(read_config.temperature, 0.6f);
+ EXPECT_EQ(read_config.max_tokens, 150);
+ EXPECT_TRUE(read_config.stream);
+ EXPECT_EQ(read_config.stop.size(), 1);
+ EXPECT_EQ(read_config.stop[0], "HALT");
+ EXPECT_EQ(read_config.files.size(), 1);
+ EXPECT_EQ(read_config.files[0], "write_test_file.gguf");
+
+ removeFile(filename);
+}
+
+TEST_F(YamlHandlerTest, Reset) {
+ config::ModelConfig new_config;
+ new_config.name = "test_reset_model";
+ new_config.model = "test_reset_model_v1";
+ handler->UpdateModelConfig(new_config);
+
+ handler->Reset();
+ const config::ModelConfig& config = handler->GetModelConfig();
+
+ EXPECT_TRUE(config.name.empty());
+ EXPECT_TRUE(config.model.empty());
+}
+
+TEST_F(YamlHandlerTest, InvalidYamlFile) {
+ std::string invalid_yaml_content = R"(
+name: test_model
+model: test_model_v1
+version: 1.0
+engine: test_engine
+prompt_template: "Test prompt {system_message} {prompt}"
+top_p: not_a_float
+seed: also_not_an_int
+ )";
+
+ std::string filename = createTempYamlFile(invalid_yaml_content);
+ handler->ModelConfigFromFile(filename);
+ config::ModelConfig new_config = handler->GetModelConfig();
+ EXPECT_EQ(new_config.seed, -1);
+
+ removeFile(filename);
+}
\ No newline at end of file
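
The InvalidYamlFile test leans on yaml-cpp throwing YAML::BadConversion when as<T>() meets a scalar it cannot parse; the parse aborts into ModelConfigFromYaml's catch block before the field is committed, so seed keeps its in-class default of -1. A minimal sketch of that failure mode:

    #include <iostream>
    #include <yaml-cpp/yaml.h>

    int main() {
      YAML::Node node = YAML::Load("seed: also_not_an_int");
      int seed = -1;  // default, as in ModelConfig
      try {
        seed = node["seed"].as<int>();
      } catch (const YAML::BadConversion& e) {
        std::cerr << "conversion failed, keeping default: " << e.what() << "\n";
      }
      std::cout << "seed=" << seed << "\n";  // seed=-1
    }
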