diff --git a/engine/config/gguf_parser.cc b/engine/config/gguf_parser.cc index f2f058336..160734468 100644 --- a/engine/config/gguf_parser.cc +++ b/engine/config/gguf_parser.cc @@ -409,7 +409,23 @@ void GGUFHandler::ModelConfigFromMetadata() { model_config_.created = std::time(nullptr); model_config_.model = "model"; model_config_.owned_by = ""; - model_config_.version; + model_config_.seed = -1; + model_config_.dynatemp_range = 0.0f; + model_config_.dynatemp_exponent = 1.0f; + model_config_.top_k = 40; + model_config_.min_p = 0.05f; + model_config_.tfs_z = 1.0f; + model_config_.typ_p = 1.0f; + model_config_.repeat_last_n = 64; + model_config_.repeat_penalty = 1.0f; + model_config_.mirostat = false; + model_config_.mirostat_tau = 5.0f; + model_config_.mirostat_eta = 0.1f; + model_config_.penalize_nl = false; + model_config_.ignore_eos = false; + model_config_.n_probs = 0; + model_config_.min_keep = 0; + model_config_.grammar = ""; // Get version, bos, eos id, contex_len, ngl from meta data for (const auto& [key, value] : metadata_uint8_) { @@ -522,7 +538,7 @@ void GGUFHandler::ModelConfigFromMetadata() { for (const auto& [key, value] : metadata_string_) { if (key.compare("general.name") == 0) { name = std::regex_replace(value, std::regex(" "), "-"); - } else if (key.compare("tokenizer.chat_template") == 0) { + } else if (key.find("chat_template") != std::string::npos) { if (value.compare(ZEPHYR_JINJA) == 0) { chat_template = "<|system|>\n{system_message}\n<|user|>\n{prompt} eos_token) { + eos_string = tokens[eos_token]; + stop.push_back(std::move(eos_string)); + } else { + LOG_ERROR << "Can't find stop token"; + } + } catch (const std::exception& e) { + LOG_ERROR << "Can't find stop token"; + } model_config_.stop = std::move(stop); - + if (chat_template.empty()) + chat_template = + "[INST] <>\n{system_message}\n<>\n{prompt}[/INST]"; model_config_.prompt_template = std::move(chat_template); model_config_.name = name; model_config_.model = name; diff --git a/engine/config/model_config.h b/engine/config/model_config.h index a8d7f4fda..f61f9e9ba 100644 --- a/engine/config/model_config.h +++ b/engine/config/model_config.h @@ -36,5 +36,23 @@ struct ModelConfig { std::size_t created; std::string object; std::string owned_by = ""; + + int seed = -1; + float dynatemp_range = 0.0f; + float dynatemp_exponent = 1.0f; + int top_k = 40; + float min_p = 0.05f; + float tfs_z = 1.0f; + float typ_p = 1.0f; + int repeat_last_n = 64; + float repeat_penalty = 1.0f; + bool mirostat = false; + float mirostat_tau = 5.0f; + float mirostat_eta = 0.1f; + bool penalize_nl = false; + bool ignore_eos = false; + int n_probs = 0; + int min_keep = 0; + std::string grammar; }; } // namespace config diff --git a/engine/config/yaml_config.cc b/engine/config/yaml_config.cc index 81f62df9d..85f086d46 100644 --- a/engine/config/yaml_config.cc +++ b/engine/config/yaml_config.cc @@ -24,6 +24,7 @@ void YamlHandler::ReadYamlFile(const std::string& file_path) { std::vector v; if (yaml_node_["engine"] && yaml_node_["engine"].as() == "cortex.llamacpp") { + // TODO: change prefix to models:// with source from cortexso v.emplace_back(s.substr(0, s.find_last_of('/')) + "/model.gguf"); } else { v.emplace_back(s.substr(0, s.find_last_of('/'))); @@ -36,7 +37,6 @@ void YamlHandler::ReadYamlFile(const std::string& file_path) { std::cerr << "Failed to read file: " << e.what() << std::endl; throw; } - ModelConfigFromYaml(); } void YamlHandler::SplitPromptTemplate(ModelConfig& mc) { if (mc.prompt_template.size() > 0) { @@ -119,6 +119,41 @@ void YamlHandler::ModelConfigFromYaml() { tmp.files = yaml_node_["files"].as>(); if (yaml_node_["created"]) tmp.created = yaml_node_["created"].as(); + + if (yaml_node_["seed"]) + tmp.seed = yaml_node_["seed"].as(); + if (yaml_node_["dynatemp_range"]) + tmp.dynatemp_range = yaml_node_["dynatemp_range"].as(); + if (yaml_node_["dynatemp_exponent"]) + tmp.dynatemp_exponent = yaml_node_["dynatemp_exponent"].as(); + if (yaml_node_["top_k"]) + tmp.top_k = yaml_node_["top_k"].as(); + if (yaml_node_["min_p"]) + tmp.min_p = yaml_node_["min_p"].as(); + if (yaml_node_["tfs_z"]) + tmp.tfs_z = yaml_node_["tfs_z"].as(); + if (yaml_node_["typ_p"]) + tmp.typ_p = yaml_node_["typ_p"].as(); + if (yaml_node_["repeat_last_n"]) + tmp.repeat_last_n = yaml_node_["repeat_last_n"].as(); + if (yaml_node_["repeat_penalty"]) + tmp.repeat_penalty = yaml_node_["repeat_penalty"].as(); + if (yaml_node_["mirostat"]) + tmp.mirostat = yaml_node_["mirostat"].as(); + if (yaml_node_["mirostat_tau"]) + tmp.mirostat_tau = yaml_node_["mirostat_tau"].as(); + if (yaml_node_["mirostat_eta"]) + tmp.mirostat_eta = yaml_node_["mirostat_eta"].as(); + if (yaml_node_["penalize_nl"]) + tmp.penalize_nl = yaml_node_["penalize_nl"].as(); + if (yaml_node_["ignore_eos"]) + tmp.ignore_eos = yaml_node_["ignore_eos"].as(); + if (yaml_node_["n_probs"]) + tmp.n_probs = yaml_node_["n_probs"].as(); + if (yaml_node_["min_keep"]) + tmp.min_keep = yaml_node_["min_keep"].as(); + if (yaml_node_["grammar"]) + tmp.grammar = yaml_node_["grammar"].as(); } catch (const std::exception& e) { std::cerr << "Error when load model config : " << e.what() << std::endl; std::cerr << "Revert ..." << std::endl; @@ -185,6 +220,42 @@ void YamlHandler::UpdateModelConfig(ModelConfig new_model_config) { yaml_node_["stop"] = model_config_.stop; if (model_config_.files.size() > 0) yaml_node_["files"] = model_config_.files; + + if (!std::isnan(static_cast(model_config_.seed))) + yaml_node_["seed"] = model_config_.seed; + if (!std::isnan(model_config_.dynatemp_range)) + yaml_node_["dynatemp_range"] = model_config_.dynatemp_range; + if (!std::isnan(model_config_.dynatemp_exponent)) + yaml_node_["dynatemp_exponent"] = model_config_.dynatemp_exponent; + if (!std::isnan(static_cast(model_config_.top_k))) + yaml_node_["top_k"] = model_config_.top_k; + if (!std::isnan(model_config_.min_p)) + yaml_node_["min_p"] = model_config_.min_p; + if (!std::isnan(model_config_.tfs_z)) + yaml_node_["tfs_z"] = model_config_.tfs_z; + if (!std::isnan(model_config_.typ_p)) + yaml_node_["typ_p"] = model_config_.typ_p; + if (!std::isnan(static_cast(model_config_.repeat_last_n))) + yaml_node_["repeat_last_n"] = model_config_.repeat_last_n; + if (!std::isnan(model_config_.repeat_penalty)) + yaml_node_["repeat_penalty"] = model_config_.repeat_penalty; + if (!std::isnan(static_cast(model_config_.mirostat))) + yaml_node_["mirostat"] = model_config_.mirostat; + if (!std::isnan(model_config_.mirostat_tau)) + yaml_node_["mirostat_tau"] = model_config_.mirostat_tau; + if (!std::isnan(model_config_.mirostat_eta)) + yaml_node_["mirostat_eta"] = model_config_.mirostat_eta; + if (!std::isnan(static_cast(model_config_.penalize_nl))) + yaml_node_["penalize_nl"] = model_config_.penalize_nl; + if (!std::isnan(static_cast(model_config_.ignore_eos))) + yaml_node_["ignore_eos"] = model_config_.ignore_eos; + if (!std::isnan(static_cast(model_config_.n_probs))) + yaml_node_["n_probs"] = model_config_.n_probs; + if (!std::isnan(static_cast(model_config_.min_keep))) + yaml_node_["min_keep"] = model_config_.min_keep; + if (!model_config_.grammar.empty()) + yaml_node_["grammar"] = model_config_.grammar; + yaml_node_["created"] = std::time(nullptr); } catch (const std::exception& e) { std::cerr << "Error when update model config : " << e.what() << std::endl; @@ -200,7 +271,97 @@ void YamlHandler::WriteYamlFile(const std::string& file_path) const { if (!outFile) { throw std::runtime_error("Failed to open output file."); } - outFile << yaml_node_; + // Helper function to write a key-value pair with an optional comment + auto writeKeyValue = [&](const std::string& key, const YAML::Node& value, + const std::string& comment = "") { + if (!value) + return; + outFile << key << ": " << value; + if (!comment.empty()) { + outFile << " # " << comment; + } + outFile << "\n"; + }; + + // Write GENERAL GGUF METADATA + outFile << "# BEGIN GENERAL GGUF METADATA\n"; + writeKeyValue("id", yaml_node_["id"], + "Model ID unique between models (author / quantization)"); + writeKeyValue("model", yaml_node_["model"], + "Model ID which is used for request construct - should be " + "unique between models (author / quantization)"); + writeKeyValue("name", yaml_node_["name"], "metadata.general.name"); + writeKeyValue("version", yaml_node_["version"], "metadata.version"); + if (yaml_node_["files"] && yaml_node_["files"].size()) { + outFile << "files: # can be universal protocol (models://) " + "OR absolute local file path (file://) OR https remote URL " + "(https://)\n"; + for (const auto& source : yaml_node_["files"]) { + outFile << " - " << source << "\n"; + } + } + + outFile << "# END GENERAL GGUF METADATA\n"; + outFile << "\n"; + // Write INFERENCE PARAMETERS + outFile << "# BEGIN INFERENCE PARAMETERS\n"; + outFile << "# BEGIN REQUIRED\n"; + if (yaml_node_["stop"] && yaml_node_["stop"].size()) { + outFile << "stop: # tokenizer.ggml.eos_token_id\n"; + for (const auto& stop : yaml_node_["stop"]) { + outFile << " - " << stop << "\n"; + } + } + + outFile << "# END REQUIRED\n"; + outFile << "\n"; + outFile << "# BEGIN OPTIONAL\n"; + writeKeyValue("stream", yaml_node_["stream"], "Default true?"); + writeKeyValue("top_p", yaml_node_["top_p"], "Ranges: 0 to 1"); + writeKeyValue("temperature", yaml_node_["temperature"], "Ranges: 0 to 1"); + writeKeyValue("frequency_penalty", yaml_node_["frequency_penalty"], + "Ranges: 0 to 1"); + writeKeyValue("presence_penalty", yaml_node_["presence_penalty"], + "Ranges: 0 to 1"); + writeKeyValue("max_tokens", yaml_node_["max_tokens"], + "Should be default to context length"); + writeKeyValue("seed", yaml_node_["seed"]); + writeKeyValue("dynatemp_range", yaml_node_["dynatemp_range"]); + writeKeyValue("dynatemp_exponent", yaml_node_["dynatemp_exponent"]); + writeKeyValue("top_k", yaml_node_["top_k"]); + writeKeyValue("min_p", yaml_node_["min_p"]); + writeKeyValue("tfs_z", yaml_node_["tfs_z"]); + writeKeyValue("typ_p", yaml_node_["typ_p"]); + writeKeyValue("repeat_last_n", yaml_node_["repeat_last_n"]); + writeKeyValue("repeat_penalty", yaml_node_["repeat_penalty"]); + writeKeyValue("mirostat", yaml_node_["mirostat"]); + writeKeyValue("mirostat_tau", yaml_node_["mirostat_tau"]); + writeKeyValue("mirostat_eta", yaml_node_["mirostat_eta"]); + writeKeyValue("penalize_nl", yaml_node_["penalize_nl"]); + writeKeyValue("ignore_eos", yaml_node_["ignore_eos"]); + writeKeyValue("n_probs", yaml_node_["n_probs"]); + writeKeyValue("min_keep", yaml_node_["min_keep"]); + writeKeyValue("grammar", yaml_node_["grammar"]); + outFile << "# END OPTIONAL\n"; + outFile << "# END INFERENCE PARAMETERS\n"; + outFile << "\n"; + // Write MODEL LOAD PARAMETERS + outFile << "# BEGIN MODEL LOAD PARAMETERS\n"; + outFile << "# BEGIN REQUIRED\n"; + writeKeyValue("engine", yaml_node_["engine"], "engine to run model"); + outFile << "prompt_template:"; + outFile << " " << yaml_node_["prompt_template"] << "\n"; + outFile << "# END REQUIRED\n"; + outFile << "\n"; + outFile << "# BEGIN OPTIONAL\n"; + writeKeyValue("ctx_len", yaml_node_["ctx_len"], + "llama.context_length | 0 or undefined = loaded from model"); + writeKeyValue("ngl", yaml_node_["ngl"], "Undefined = loaded from model"); + outFile << "# END OPTIONAL\n"; + outFile << "# END MODEL LOAD PARAMETERS\n"; + + // Write new configuration parameters + outFile.close(); } catch (const std::exception& e) { std::cerr << "Error writing to file: " << e.what() << std::endl; diff --git a/engine/controllers/command_line_parser.cc b/engine/controllers/command_line_parser.cc index 9c4b5713f..b55887ebd 100644 --- a/engine/controllers/command_line_parser.cc +++ b/engine/controllers/command_line_parser.cc @@ -294,7 +294,6 @@ void CommandLineParser::EngineGet(CLI::App* parent) { std::string desc = "Get " + engine_name + " status"; auto engine_get_cmd = get_cmd->add_subcommand(engine_name, desc); - engine_get_cmd->require_option(); engine_get_cmd->callback( [engine_name] { commands::EngineGetCmd().Exec(engine_name); }); } diff --git a/engine/e2e-test/test_cli_engine_get.py b/engine/e2e-test/test_cli_engine_get.py index 38c235b30..6b5270eba 100644 --- a/engine/e2e-test/test_cli_engine_get.py +++ b/engine/e2e-test/test_cli_engine_get.py @@ -52,5 +52,6 @@ def test_engines_get_onnx_should_be_incompatible_on_macos(self): @pytest.mark.skipif(platform.system() != "Linux", reason="Linux-specific test") def test_engines_get_onnx_should_be_incompatible_on_linux(self): exit_code, output, error = run("Get engine", ["engines", "get", "cortex.onnx"]) + print(output) assert exit_code == 0, f"Get engine failed with error: {error}" assert "Incompatible" in output, "cortex.onnx should be Incompatible on Linux" diff --git a/engine/e2e-test/test_cli_engine_install.py b/engine/e2e-test/test_cli_engine_install.py index 15cd8deb3..89d49401d 100644 --- a/engine/e2e-test/test_cli_engine_install.py +++ b/engine/e2e-test/test_cli_engine_install.py @@ -31,7 +31,7 @@ def test_engines_install_onnx_on_tensorrt_should_be_failed(self): def test_engines_install_pre_release_llamacpp(self): exit_code, output, error = run( - "Install Engine", ["engines", "install", "cortex.llamacpp", "-v", "v0.1.29"], timeout=60 + "Install Engine", ["engines", "install", "cortex.llamacpp", "-v", "v0.1.29"], timeout=None ) assert "Start downloading" in output, "Should display downloading message" assert exit_code == 0, f"Install engine failed with error: {error}" diff --git a/engine/e2e-test/test_cli_engine_uninstall.py b/engine/e2e-test/test_cli_engine_uninstall.py index 03078a1e6..525c0ad63 100644 --- a/engine/e2e-test/test_cli_engine_uninstall.py +++ b/engine/e2e-test/test_cli_engine_uninstall.py @@ -8,7 +8,7 @@ class TestCliEngineUninstall: def setup_and_teardown(self): # Setup # Preinstall llamacpp engine - run("Install Engine", ["engines", "install", "cortex.llamacpp"]) + run("Install Engine", ["engines", "install", "cortex.llamacpp"],timeout=None) yield diff --git a/engine/test/components/CMakeLists.txt b/engine/test/components/CMakeLists.txt index 32ee36f09..db810ad26 100644 --- a/engine/test/components/CMakeLists.txt +++ b/engine/test/components/CMakeLists.txt @@ -3,13 +3,14 @@ project(test-components) enable_testing() -add_executable(${PROJECT_NAME} ${SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/../../utils/modellist_utils.cc) +add_executable(${PROJECT_NAME} ${SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/../../utils/modellist_utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../config/yaml_config.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../config/gguf_parser.cc) find_package(Drogon CONFIG REQUIRED) find_package(GTest CONFIG REQUIRED) find_package(yaml-cpp CONFIG REQUIRED) +find_package(jinja2cpp CONFIG REQUIRED) -target_link_libraries(${PROJECT_NAME} PRIVATE Drogon::Drogon GTest::gtest GTest::gtest_main yaml-cpp::yaml-cpp +target_link_libraries(${PROJECT_NAME} PRIVATE Drogon::Drogon GTest::gtest GTest::gtest_main yaml-cpp::yaml-cpp jinja2cpp ${CMAKE_THREAD_LIBS_INIT}) target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../) diff --git a/engine/test/components/test_gguf_parser.cc b/engine/test/components/test_gguf_parser.cc new file mode 100644 index 000000000..6c5c61486 --- /dev/null +++ b/engine/test/components/test_gguf_parser.cc @@ -0,0 +1,158 @@ +#include "gtest/gtest.h" +#include "config/gguf_parser.h" +#include "config/yaml_config.h" +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 +#include +#else +#include +#endif + +class GGUFParserTest : public ::testing::Test { +protected: + void SetUp() override { + gguf_handler = std::make_unique(); + yaml_handler = std::make_unique< config::YamlHandler>(); + } + + void TearDown() override { + } + + std::unique_ptr gguf_handler; + std::unique_ptr yaml_handler; + + std::string getTempFilePath(const std::string& prefix, const std::string& extension) { + #ifdef _WIN32 + char temp_path[MAX_PATH]; + char file_name[MAX_PATH]; + GetTempPathA(MAX_PATH, temp_path); + GetTempFileNameA(temp_path, prefix.c_str(), 0, file_name); + std::string path(file_name); + DeleteFileA(file_name); // Delete the file created by GetTempFileNameA + return path + extension; + #else + std::string path = "/tmp/" + prefix + "XXXXXX" + extension; + char* temp = strdup(path.c_str()); + int fd = mkstemps(temp, extension.length()); + if (fd == -1) { + free(temp); + throw std::runtime_error("Failed to create temporary file"); + } + close(fd); + std::string result(temp); + free(temp); + return result; + #endif + } + + std::string createMockGGUFFile() { + std::string gguf_path = getTempFilePath("mock_tinyllama-model", ".gguf"); + std::ofstream file(gguf_path, std::ios::binary); + + if (!file.is_open()) { + throw std::runtime_error("Failed to create mock GGUF file"); + } + + try { + // GGUF magic number + uint32_t magic = 0x46554747; + file.write(reinterpret_cast(&magic), sizeof(magic)); + + // Version + uint32_t version = 2; + file.write(reinterpret_cast(&version), sizeof(version)); + + // Tensor count (not important for this test) + uint64_t tensor_count = 0; + file.write(reinterpret_cast(&tensor_count), sizeof(tensor_count)); + + // Metadata key-value count + uint64_t kv_count = 2; + file.write(reinterpret_cast(&kv_count), sizeof(kv_count)); + + // Helper function to write a string + auto writeString = [&file](const std::string& str) { + uint64_t length = str.length(); + file.write(reinterpret_cast(&length), sizeof(length)); + file.write(str.c_str(), length); + }; + + // Helper function to write a key-value pair + auto writeKV = [&](const std::string& key, uint32_t type, const auto& value) { + writeString(key); + file.write(reinterpret_cast(&type), sizeof(type)); + if constexpr (std::is_same_v) { + writeString(value); + } else { + file.write(reinterpret_cast(&value), sizeof(value)); + } + }; + + // Write metadata + writeKV("general.name", 8, std::string("tinyllama 1B")); + writeKV("llama.context_length", 4, uint32_t(4096)); + + file.close(); + + } catch (const std::exception& e) { + file.close(); + std::remove(gguf_path.c_str()); + throw std::runtime_error(std::string("Failed to write mock GGUF file: ") + e.what()); + } + + return gguf_path; + } +}; + +TEST_F(GGUFParserTest, ParseMockTinyLlamaModel) { + std::string gguf_path; + std::string yaml_path; + try { + // Create a mock GGUF file + gguf_path = createMockGGUFFile(); + + // Parse the GGUF file + gguf_handler->Parse(gguf_path); + + const config::ModelConfig& gguf_config = gguf_handler->GetModelConfig(); + + // Load the expected configuration from YAML + std::string yaml_content = R"( +name: tinyllama-1B +ctx_len: 4096 + )"; + + yaml_path = getTempFilePath("expected_config", ".yaml"); + std::ofstream yaml_file(yaml_path); + yaml_file << yaml_content; + yaml_file.close(); + + yaml_handler->ModelConfigFromFile(yaml_path); + + const config::ModelConfig& yaml_config = yaml_handler->GetModelConfig(); + + // Compare GGUF parsed config with YAML config + EXPECT_EQ(gguf_config.name, yaml_config.name); + EXPECT_EQ(gguf_config.ctx_len, yaml_config.ctx_len); + + // Clean up + std::remove(gguf_path.c_str()); + std::remove(yaml_path.c_str()); + } + catch (const std::exception& e) { + // If an exception was thrown, make sure to clean up the files + if (!gguf_path.empty()) { + std::remove(gguf_path.c_str()); + } + if (!yaml_path.empty()) { + std::remove(yaml_path.c_str()); + } + FAIL() << "Exception thrown: " << e.what(); + } +} \ No newline at end of file diff --git a/engine/test/components/test_yaml_handler.cc b/engine/test/components/test_yaml_handler.cc new file mode 100644 index 000000000..d65ab6b49 --- /dev/null +++ b/engine/test/components/test_yaml_handler.cc @@ -0,0 +1,199 @@ +#include +#include +#include +#include +#include "config/yaml_config.h" +#include "gtest/gtest.h" + +#ifdef _WIN32 +#include +#include +#else +#include +#endif + +class YamlHandlerTest : public ::testing::Test { + protected: + void SetUp() override { handler = new config::YamlHandler(); } + + void TearDown() override { delete handler; } + + config::YamlHandler* handler; + + // Helper function to create a temporary YAML file + std::string createTempYamlFile(const std::string& content) { + std::string filename; +#ifdef _WIN32 + char tempPath[MAX_PATH]; + char tempFileName[MAX_PATH]; + GetTempPathA(MAX_PATH, tempPath); + GetTempFileNameA(tempPath, "yaml", 0, tempFileName); + filename = tempFileName; +#else + char tempFileName[] = "/tmp/yaml_test_XXXXXX"; + int fd = mkstemp(tempFileName); + if (fd == -1) { + throw std::runtime_error("Failed to create temporary file"); + } + close(fd); + filename = tempFileName; +#endif + + std::ofstream file(filename); + file << content; + file.close(); + return filename; + } + + // Helper function to remove a file + void removeFile(const std::string& filename) { + std::remove(filename.c_str()); + } +}; + +TEST_F(YamlHandlerTest, ModelConfigFromFile) { + std::string yaml_content = R"( +name: test_model +model: test_model_v1 +version: 1.0 +engine: test_engine +prompt_template: "Test prompt {system_message} {prompt}" +top_p: 0.9 +temperature: 0.7 +max_tokens: 100 +stream: true +stop: + - "END" +files: + - "test_file.gguf" + )"; + + std::string filename = createTempYamlFile(yaml_content); + + handler->ModelConfigFromFile(filename); + const config::ModelConfig& config = handler->GetModelConfig(); + + EXPECT_EQ(config.name, "test_model"); + EXPECT_EQ(config.model, "test_model_v1"); + EXPECT_EQ(config.version, "1.0"); + EXPECT_EQ(config.engine, "test_engine"); + EXPECT_EQ(config.prompt_template, "Test prompt {system_message} {prompt}"); + EXPECT_FLOAT_EQ(config.top_p, 0.9f); + EXPECT_FLOAT_EQ(config.temperature, 0.7f); + EXPECT_EQ(config.max_tokens, 100); + EXPECT_TRUE(config.stream); + EXPECT_EQ(config.stop.size(), 1); + EXPECT_EQ(config.stop[0], "END"); + EXPECT_EQ(config.files.size(), 1); + EXPECT_EQ(config.files[0], "test_file.gguf"); + + removeFile(filename); +} + +TEST_F(YamlHandlerTest, UpdateModelConfig) { + config::ModelConfig new_config; + new_config.name = "updated_model"; + new_config.model = "updated_model_v2"; + new_config.version = "2.0"; + new_config.engine = "updated_engine"; + new_config.prompt_template = "Updated prompt {system_message} {prompt}"; + new_config.top_p = 0.95f; + new_config.temperature = 0.8f; + new_config.max_tokens = 200; + new_config.stream = false; + new_config.stop = {"STOP", "END"}; + new_config.files = {"updated_file1.gguf", "updated_file2.gguf"}; + + handler->UpdateModelConfig(new_config); + const config::ModelConfig& config = handler->GetModelConfig(); + + EXPECT_EQ(config.name, "updated_model"); + EXPECT_EQ(config.model, "updated_model_v2"); + EXPECT_EQ(config.version, "2.0"); + EXPECT_EQ(config.engine, "updated_engine"); + EXPECT_EQ(config.prompt_template, "Updated prompt {system_message} {prompt}"); + EXPECT_FLOAT_EQ(config.top_p, 0.95f); + EXPECT_FLOAT_EQ(config.temperature, 0.8f); + EXPECT_EQ(config.max_tokens, 200); + EXPECT_FALSE(config.stream); + EXPECT_EQ(config.stop.size(), 2); + EXPECT_EQ(config.stop[0], "STOP"); + EXPECT_EQ(config.stop[1], "END"); + EXPECT_EQ(config.files.size(), 2); + EXPECT_EQ(config.files[0], "updated_file1.gguf"); + EXPECT_EQ(config.files[1], "updated_file2.gguf"); +} + +TEST_F(YamlHandlerTest, WriteYamlFile) { + config::ModelConfig new_config; + new_config.name = "write_test_model"; + new_config.model = "write_test_model_v1"; + new_config.version = "1.0"; + new_config.engine = "write_test_engine"; + new_config.prompt_template = "Write test prompt {system_message} {prompt}"; + new_config.top_p = 0.85f; + new_config.temperature = 0.6f; + new_config.max_tokens = 150; + new_config.stream = true; + new_config.stop = {"HALT"}; + new_config.files = {"write_test_file.gguf"}; + + handler->UpdateModelConfig(new_config); + + std::string filename = createTempYamlFile(""); // Create empty file + handler->WriteYamlFile(filename); + + // Read the file back and verify its contents + config::YamlHandler read_handler; + read_handler.ModelConfigFromFile(filename); + const config::ModelConfig& read_config = read_handler.GetModelConfig(); + + EXPECT_EQ(read_config.name, "write_test_model"); + EXPECT_EQ(read_config.model, "write_test_model_v1"); + EXPECT_EQ(read_config.version, "1.0"); + EXPECT_EQ(read_config.engine, "write_test_engine"); + EXPECT_EQ(read_config.prompt_template, + "Write test prompt {system_message} {prompt}"); + EXPECT_FLOAT_EQ(read_config.top_p, 0.85f); + EXPECT_FLOAT_EQ(read_config.temperature, 0.6f); + EXPECT_EQ(read_config.max_tokens, 150); + EXPECT_TRUE(read_config.stream); + EXPECT_EQ(read_config.stop.size(), 1); + EXPECT_EQ(read_config.stop[0], "HALT"); + EXPECT_EQ(read_config.files.size(), 1); + EXPECT_EQ(read_config.files[0], "write_test_file.gguf"); + + removeFile(filename); +} + +TEST_F(YamlHandlerTest, Reset) { + config::ModelConfig new_config; + new_config.name = "test_reset_model"; + new_config.model = "test_reset_model_v1"; + handler->UpdateModelConfig(new_config); + + handler->Reset(); + const config::ModelConfig& config = handler->GetModelConfig(); + + EXPECT_TRUE(config.name.empty()); + EXPECT_TRUE(config.model.empty()); +} + +TEST_F(YamlHandlerTest, InvalidYamlFile) { + std::string invalid_yaml_content = R"( +name: test_model +model: test_model_v1 +version: 1.0 +engine: test_engine +prompt_template: "Test prompt {system_message} {prompt}" +top_p: not_a_float +seed: also_not_an_int + )"; + + std::string filename = createTempYamlFile(invalid_yaml_content); + handler->ModelConfigFromFile(filename); + config::ModelConfig new_config = handler->GetModelConfig(); + EXPECT_EQ(new_config.seed, -1); + + removeFile(filename); +} \ No newline at end of file