diff --git a/engine/commands/chat_completion_cmd.cc b/engine/commands/chat_completion_cmd.cc index ef12e575d..be7b9d170 100644 --- a/engine/commands/chat_completion_cmd.cc +++ b/engine/commands/chat_completion_cmd.cc @@ -7,6 +7,7 @@ #include "server_start_cmd.h" #include "utils/engine_constants.h" #include "utils/logging_utils.h" +#include "utils/string_utils.h" namespace commands { namespace { @@ -100,6 +101,8 @@ void ChatCompletionCmd::Exec(const std::string& host, int port, break; } } + + string_utils::Trim(user_input); if (user_input == kExitChat) { break; } diff --git a/engine/test/components/test_string_utils.cc b/engine/test/components/test_string_utils.cc index 3d6abeddf..1b16858c4 100644 --- a/engine/test/components/test_string_utils.cc +++ b/engine/test/components/test_string_utils.cc @@ -1,8 +1,8 @@ -#include -#include #include "gtest/gtest.h" #include "utils/string_utils.h" + class StringUtilsTestSuite : public ::testing::Test {}; + TEST_F(StringUtilsTestSuite, ParsePrompt) { { std::string prompt = @@ -94,3 +94,82 @@ TEST_F(StringUtilsTestSuite, TestEndsWithWithEmptySuffix) { auto suffix = ""; EXPECT_TRUE(string_utils::EndsWith(input, suffix)); } + +TEST_F(StringUtilsTestSuite, EmptyString) { + std::string s = ""; + string_utils::Trim(s); + EXPECT_EQ(s, ""); +} + +TEST_F(StringUtilsTestSuite, NoWhitespace) { + std::string s = "hello"; + string_utils::Trim(s); + EXPECT_EQ(s, "hello"); +} + +TEST_F(StringUtilsTestSuite, LeadingWhitespace) { + std::string s = " hello"; + string_utils::Trim(s); + EXPECT_EQ(s, "hello"); +} + +TEST_F(StringUtilsTestSuite, TrailingWhitespace) { + std::string s = "hello "; + string_utils::Trim(s); + EXPECT_EQ(s, "hello"); +} + +TEST_F(StringUtilsTestSuite, BothEndsWhitespace) { + std::string s = " hello "; + string_utils::Trim(s); + EXPECT_EQ(s, "hello"); +} + +TEST_F(StringUtilsTestSuite, ExitString) { + std::string s = "exit() "; + string_utils::Trim(s); + EXPECT_EQ(s, "exit()"); +} + +TEST_F(StringUtilsTestSuite, AllWhitespace) { + std::string s = " "; + string_utils::Trim(s); + EXPECT_EQ(s, ""); +} + +TEST_F(StringUtilsTestSuite, MixedWhitespace) { + std::string s = " \t\n hello world \r\n "; + string_utils::Trim(s); + EXPECT_EQ(s, "hello world"); +} + +TEST_F(StringUtilsTestSuite, EqualStrings) { + EXPECT_TRUE(string_utils::EqualsIgnoreCase("hello", "hello")); + EXPECT_TRUE(string_utils::EqualsIgnoreCase("WORLD", "WORLD")); +} + +TEST_F(StringUtilsTestSuite, DifferentCaseStrings) { + EXPECT_TRUE(string_utils::EqualsIgnoreCase("Hello", "hElLo")); + EXPECT_TRUE(string_utils::EqualsIgnoreCase("WORLD", "world")); + EXPECT_TRUE(string_utils::EqualsIgnoreCase("MiXeD", "mIxEd")); +} + +TEST_F(StringUtilsTestSuite, EmptyStrings) { + EXPECT_TRUE(string_utils::EqualsIgnoreCase("", "")); +} + +TEST_F(StringUtilsTestSuite, DifferentStrings) { + EXPECT_FALSE(string_utils::EqualsIgnoreCase("hello", "world")); + EXPECT_FALSE(string_utils::EqualsIgnoreCase("HELLO", "hello world")); +} + +TEST_F(StringUtilsTestSuite, DifferentLengthStrings) { + EXPECT_FALSE(string_utils::EqualsIgnoreCase("short", "longer string")); + EXPECT_FALSE(string_utils::EqualsIgnoreCase("LONG STRING", "long")); +} + +TEST_F(StringUtilsTestSuite, SpecialCharacters) { + EXPECT_TRUE(string_utils::EqualsIgnoreCase("Hello!", "hElLo!")); + EXPECT_TRUE(string_utils::EqualsIgnoreCase("123 ABC", "123 abc")); + EXPECT_FALSE(string_utils::EqualsIgnoreCase("Hello!", "Hello")); +} diff --git a/engine/utils/string_utils.h b/engine/utils/string_utils.h index 1be4584d1..3af6dda82 100644 --- a/engine/utils/string_utils.h +++ b/engine/utils/string_utils.h @@ -13,6 +13,16 @@ struct ParsePromptResult { std::string ai_prompt; }; +inline void Trim(std::string& s) { + s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) { + return !std::isspace(ch); + })); + s.erase(std::find_if(s.rbegin(), s.rend(), + [](unsigned char ch) { return !std::isspace(ch); }) + .base(), + s.end()); +} + inline bool EqualsIgnoreCase(const std::string& a, const std::string& b) { return std::equal(a.begin(), a.end(), b.begin(), b.end(), [](char a, char b) { return tolower(a) == tolower(b); });