From f966eec37a1ed064fa4835b9064affe9561f3ebf Mon Sep 17 00:00:00 2001 From: James Date: Mon, 9 Dec 2024 13:15:45 +0700 Subject: [PATCH] fix: sort messages by its ulid instead of created_at --- engine/repositories/message_fs_repository.cc | 73 +++++++++++--------- 1 file changed, 42 insertions(+), 31 deletions(-) diff --git a/engine/repositories/message_fs_repository.cc b/engine/repositories/message_fs_repository.cc index 388409390..422242e3a 100644 --- a/engine/repositories/message_fs_repository.cc +++ b/engine/repositories/message_fs_repository.cc @@ -48,7 +48,14 @@ MessageFsRepository::ListMessages(const std::string& thread_id, uint8_t limit, const std::string& before, const std::string& run_id) const { CTL_INF("Listing messages for thread " + thread_id); - auto path = GetMessagePath(thread_id); + + // Early validation + if (limit == 0) { + return std::vector(); + } + if (!after.empty() && !before.empty() && after >= before) { + return cpp::fail("Invalid range: 'after' must be less than 'before'"); + } auto mutex = GrabMutex(thread_id); std::shared_lock lock(*mutex); @@ -60,6 +67,11 @@ MessageFsRepository::ListMessages(const std::string& thread_id, uint8_t limit, std::vector messages = std::move(read_result.value()); + if (messages.empty()) { + return messages; + } + + // Filter by run_id if (!run_id.empty()) { messages.erase(std::remove_if(messages.begin(), messages.end(), [&run_id](const OpenAi::Message& msg) { @@ -68,52 +80,52 @@ MessageFsRepository::ListMessages(const std::string& thread_id, uint8_t limit, messages.end()); } - std::sort(messages.begin(), messages.end(), - [&order](const OpenAi::Message& a, const OpenAi::Message& b) { - if (order == "desc") { - return a.created_at > b.created_at; - } - return a.created_at < b.created_at; - }); + const bool is_descending = (order == "desc"); + std::sort( + messages.begin(), messages.end(), + [is_descending](const OpenAi::Message& a, const OpenAi::Message& b) { + return is_descending ? (a.id > b.id) : (a.id < b.id); + }); auto start_it = messages.begin(); auto end_it = messages.end(); if (!after.empty()) { - start_it = std::find_if( - messages.begin(), messages.end(), - [&after](const OpenAi::Message& msg) { return msg.id == after; }); - if (start_it != messages.end()) { - ++start_it; // Start from the message after the 'after' message - } else { - start_it = messages.begin(); + start_it = std::lower_bound( + messages.begin(), messages.end(), after, + [is_descending](const OpenAi::Message& msg, const std::string& value) { + return is_descending ? (msg.id > value) : (msg.id < value); + }); + + if (start_it != messages.end() && start_it->id == after) { + ++start_it; } } if (!before.empty()) { - end_it = std::find_if( - messages.begin(), messages.end(), - [&before](const OpenAi::Message& msg) { return msg.id == before; }); + end_it = std::upper_bound( + start_it, messages.end(), before, + [is_descending](const std::string& value, const OpenAi::Message& msg) { + return is_descending ? (value > msg.id) : (value < msg.id); + }); } - std::vector result; - size_t distance = std::distance(start_it, end_it); - size_t limit_size = static_cast(limit); - CTL_INF("Distance: " + std::to_string(distance) + - ", limit_size: " + std::to_string(limit_size)); - result.reserve(distance < limit_size ? distance : limit_size); + const size_t available_messages = std::distance(start_it, end_it); + const size_t result_size = + std::min(static_cast(limit), available_messages); - for (auto it = start_it; it != end_it && result.size() < limit_size; ++it) { - result.push_back(std::move(*it)); - } + CTL_INF("Available messages: " + std::to_string(available_messages) + + ", result size: " + std::to_string(result_size)); + + std::vector result; + result.reserve(result_size); + std::move(start_it, start_it + result_size, std::back_inserter(result)); return result; } cpp::result MessageFsRepository::RetrieveMessage( const std::string& thread_id, const std::string& message_id) const { - auto path = GetMessagePath(thread_id); - auto mutex = GrabMutex(thread_id); std::unique_lock lock(*mutex); @@ -133,8 +145,6 @@ cpp::result MessageFsRepository::RetrieveMessage( cpp::result MessageFsRepository::ModifyMessage( OpenAi::Message& message) { - auto path = GetMessagePath(message.thread_id); - auto mutex = GrabMutex(message.thread_id); std::unique_lock lock(*mutex); @@ -143,6 +153,7 @@ cpp::result MessageFsRepository::ModifyMessage( return cpp::fail(messages.error()); } + auto path = GetMessagePath(message.thread_id); std::ofstream file(path, std::ios::trunc); if (!file) { return cpp::fail("Failed to open file for writing: " + path.string());