From e45c57b77bb847d2e7729a75017cdc4e380223f1 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Tue, 29 Oct 2024 16:23:04 +0700 Subject: [PATCH 1/2] fix: add n_parallel to model yaml config --- engine/config/model_config.h | 6 ++++++ engine/config/yaml_config.cc | 5 +++++ engine/test/components/test_yaml_handler.cc | 6 ++++++ 3 files changed, 17 insertions(+) diff --git a/engine/config/model_config.h b/engine/config/model_config.h index ffc158b88..bc3a7ec25 100644 --- a/engine/config/model_config.h +++ b/engine/config/model_config.h @@ -22,6 +22,7 @@ struct ModelConfig { bool stream = std::numeric_limits::quiet_NaN(); int ngl = std::numeric_limits::quiet_NaN(); int ctx_len = std::numeric_limits::quiet_NaN(); + int n_parallel = 1; std::string engine; std::string prompt_template; std::string system_template; @@ -125,6 +126,8 @@ struct ModelConfig { ngl = json["ngl"].asInt(); if (json.isMember("ctx_len")) ctx_len = json["ctx_len"].asInt(); + if (json.isMember("n_parallel")) + n_parallel = json["n_parallel"].asInt(); if (json.isMember("engine")) engine = json["engine"].asString(); if (json.isMember("prompt_template")) @@ -204,6 +207,7 @@ struct ModelConfig { obj["min_keep"] = min_keep; obj["ngl"] = ngl; obj["ctx_len"] = ctx_len; + obj["n_parallel"] = n_parallel; obj["engine"] = engine; obj["prompt_template"] = prompt_template; obj["system_template"] = system_template; @@ -313,6 +317,8 @@ struct ModelConfig { if (ctx_len != std::numeric_limits::quiet_NaN()) oss << format_utils::print_kv("ctx_len", std::to_string(ctx_len), format_utils::MAGENTA); + oss << format_utils::print_kv("n_parallel", std::to_string(n_parallel), + format_utils::MAGENTA); if (ngl != std::numeric_limits::quiet_NaN()) oss << format_utils::print_kv("ngl", std::to_string(ngl), format_utils::MAGENTA); diff --git a/engine/config/yaml_config.cc b/engine/config/yaml_config.cc index 8bd34c109..99f8103d8 100644 --- a/engine/config/yaml_config.cc +++ b/engine/config/yaml_config.cc @@ -113,6 +113,8 @@ void YamlHandler::ModelConfigFromYaml() { tmp.ngl = yaml_node_["ngl"].as(); if (yaml_node_["ctx_len"]) tmp.ctx_len = yaml_node_["ctx_len"].as(); + if (yaml_node_["n_parallel"]) + tmp.n_parallel = yaml_node_["n_parallel"].as(); if (yaml_node_["tp"]) tmp.tp = yaml_node_["tp"].as(); if (yaml_node_["stream"]) @@ -216,6 +218,8 @@ void YamlHandler::UpdateModelConfig(ModelConfig new_model_config) { yaml_node_["ngl"] = model_config_.ngl; if (!std::isnan(static_cast(model_config_.ctx_len))) yaml_node_["ctx_len"] = model_config_.ctx_len; + if (!std::isnan(static_cast(model_config_.n_parallel))) + yaml_node_["n_parallel"] = model_config_.n_parallel; if (!std::isnan(static_cast(model_config_.tp))) yaml_node_["tp"] = model_config_.tp; if (!std::isnan(static_cast(model_config_.stream))) @@ -368,6 +372,7 @@ void YamlHandler::WriteYamlFile(const std::string& file_path) const { outFile << format_utils::writeKeyValue( "ctx_len", yaml_node_["ctx_len"], "llama.context_length | 0 or undefined = loaded from model"); + outFile << format_utils::writeKeyValue("n_parallel", yaml_node_["n_parallel"]); outFile << format_utils::writeKeyValue("ngl", yaml_node_["ngl"], "Undefined = loaded from model"); outFile << "# END OPTIONAL\n"; diff --git a/engine/test/components/test_yaml_handler.cc b/engine/test/components/test_yaml_handler.cc index d65ab6b49..f699e0c6a 100644 --- a/engine/test/components/test_yaml_handler.cc +++ b/engine/test/components/test_yaml_handler.cc @@ -62,6 +62,7 @@ top_p: 0.9 temperature: 0.7 max_tokens: 100 stream: true +n_parallel: 2 stop: - "END" files: @@ -82,6 +83,7 @@ stream: true EXPECT_FLOAT_EQ(config.temperature, 0.7f); EXPECT_EQ(config.max_tokens, 100); EXPECT_TRUE(config.stream); + EXPECT_EQ(config.n_parallel, 2); EXPECT_EQ(config.stop.size(), 1); EXPECT_EQ(config.stop[0], "END"); EXPECT_EQ(config.files.size(), 1); @@ -101,6 +103,7 @@ TEST_F(YamlHandlerTest, UpdateModelConfig) { new_config.temperature = 0.8f; new_config.max_tokens = 200; new_config.stream = false; + new_config.n_parallel = 2; new_config.stop = {"STOP", "END"}; new_config.files = {"updated_file1.gguf", "updated_file2.gguf"}; @@ -116,6 +119,7 @@ TEST_F(YamlHandlerTest, UpdateModelConfig) { EXPECT_FLOAT_EQ(config.temperature, 0.8f); EXPECT_EQ(config.max_tokens, 200); EXPECT_FALSE(config.stream); + EXPECT_EQ(config.n_parallel, 2); EXPECT_EQ(config.stop.size(), 2); EXPECT_EQ(config.stop[0], "STOP"); EXPECT_EQ(config.stop[1], "END"); @@ -135,6 +139,7 @@ TEST_F(YamlHandlerTest, WriteYamlFile) { new_config.temperature = 0.6f; new_config.max_tokens = 150; new_config.stream = true; + new_config.n_parallel = 2; new_config.stop = {"HALT"}; new_config.files = {"write_test_file.gguf"}; @@ -158,6 +163,7 @@ TEST_F(YamlHandlerTest, WriteYamlFile) { EXPECT_FLOAT_EQ(read_config.temperature, 0.6f); EXPECT_EQ(read_config.max_tokens, 150); EXPECT_TRUE(read_config.stream); + EXPECT_EQ(read_config.n_parallel, 2); EXPECT_EQ(read_config.stop.size(), 1); EXPECT_EQ(read_config.stop[0], "HALT"); EXPECT_EQ(read_config.files.size(), 1); From a97194aeb492c39b9a9e60c76dc2d6bca304be40 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Tue, 29 Oct 2024 20:44:06 +0700 Subject: [PATCH 2/2] fix: models update --- engine/cli/command_line_parser.cc | 1 + engine/cli/commands/model_upd_cmd.cc | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/engine/cli/command_line_parser.cc b/engine/cli/command_line_parser.cc index f75625e36..1eea35196 100644 --- a/engine/cli/command_line_parser.cc +++ b/engine/cli/command_line_parser.cc @@ -546,6 +546,7 @@ void CommandLineParser::ModelUpdate(CLI::App* parent) { "stream", "ngl", "ctx_len", + "n_parallel", "engine", "prompt_template", "system_template", diff --git a/engine/cli/commands/model_upd_cmd.cc b/engine/cli/commands/model_upd_cmd.cc index 0d907357f..af37efd5f 100644 --- a/engine/cli/commands/model_upd_cmd.cc +++ b/engine/cli/commands/model_upd_cmd.cc @@ -223,6 +223,12 @@ void ModelUpdCmd::UpdateConfig(Json::Value& data, const std::string& key, data["ctx_len"] = static_cast(f); }); }}, + {"n_parallel", + [this](Json::Value &data, const std::string& k, const std::string& v) { + UpdateNumericField(k, v, [&data](float f) { + data["n_parallel"] = static_cast(f); + }); + }}, {"tp", [this](Json::Value &data, const std::string& k, const std::string& v) { UpdateNumericField(k, v, [&data](float f) {