From 8d5618ea8d0ddd1235498fc3740bc5da3bcbfcae Mon Sep 17 00:00:00 2001 From: James Date: Tue, 3 Dec 2024 22:58:23 +0700 Subject: [PATCH 1/3] feat: add thread --- ...e_response.h => delete_success_response.h} | 2 +- engine/common/message.h | 14 +- engine/common/message_attachment.h | 4 +- engine/common/message_attachment_factory.h | 6 +- engine/common/message_content.h | 4 +- engine/common/message_content_factory.h | 4 +- engine/common/message_content_image_file.h | 4 +- engine/common/message_content_image_url.h | 4 +- engine/common/message_content_refusal.h | 4 +- engine/common/message_content_text.h | 4 +- engine/common/message_incomplete_detail.h | 4 +- engine/common/message_role.h | 4 +- engine/common/message_status.h | 4 +- engine/common/repository/message_repository.h | 19 +- engine/common/repository/thread_repository.h | 25 ++ engine/common/thread.h | 102 ++++++++ engine/common/thread_tool_resources.h | 50 ++++ engine/controllers/messages.cc | 22 +- engine/controllers/threads.cc | 220 ++++++++++++++++++ engine/controllers/threads.h | 57 +++++ engine/main.cc | 41 ++-- engine/repositories/message_fs_repository.cc | 142 ++++++----- engine/repositories/message_fs_repository.h | 46 +++- engine/repositories/thread_fs_repository.cc | 165 +++++++++++++ engine/repositories/thread_fs_repository.h | 62 +++++ engine/services/message_service.cc | 62 +++-- engine/services/message_service.h | 23 +- engine/services/thread_service.cc | 81 +++++++ engine/services/thread_service.h | 36 +++ 29 files changed, 1023 insertions(+), 192 deletions(-) rename engine/common/api-dto/{messages/delete_message_response.h => delete_success_response.h} (87%) create mode 100644 engine/common/repository/thread_repository.h create mode 100644 engine/common/thread.h create mode 100644 engine/common/thread_tool_resources.h create mode 100644 engine/controllers/threads.cc create mode 100644 engine/controllers/threads.h create mode 100644 engine/repositories/thread_fs_repository.cc create mode 100644 engine/repositories/thread_fs_repository.h create mode 100644 engine/services/thread_service.cc create mode 100644 engine/services/thread_service.h diff --git a/engine/common/api-dto/messages/delete_message_response.h b/engine/common/api-dto/delete_success_response.h similarity index 87% rename from engine/common/api-dto/messages/delete_message_response.h rename to engine/common/api-dto/delete_success_response.h index 79447c93a..ebb8f36f0 100644 --- a/engine/common/api-dto/messages/delete_message_response.h +++ b/engine/common/api-dto/delete_success_response.h @@ -3,7 +3,7 @@ #include "common/json_serializable.h" namespace api_response { -struct DeleteMessageResponse : JsonSerializable { +struct DeleteSuccessResponse : JsonSerializable { std::string id; std::string object; bool deleted; diff --git a/engine/common/message.h b/engine/common/message.h index e5685f3bb..cfd069515 100644 --- a/engine/common/message.h +++ b/engine/common/message.h @@ -17,20 +17,10 @@ #include "utils/logging_utils.h" #include "utils/result.hpp" -namespace ThreadMessage { +namespace OpenAi { // Represents a message within a thread. struct Message : JsonSerializable { - Message() = default; - - Message(Message&&) = default; - - Message& operator=(Message&&) = default; - - Message(const Message&) = delete; - - Message& operator=(const Message&) = delete; - // The identifier, which can be referenced in API endpoints. std::string id; @@ -210,4 +200,4 @@ struct Message : JsonSerializable { } } }; -}; // namespace ThreadMessage +}; // namespace OpenAi diff --git a/engine/common/message_attachment.h b/engine/common/message_attachment.h index ea809990e..767ec9bea 100644 --- a/engine/common/message_attachment.h +++ b/engine/common/message_attachment.h @@ -3,7 +3,7 @@ #include #include "common/json_serializable.h" -namespace ThreadMessage { +namespace OpenAi { // The tools to add this file to. struct Tool { @@ -47,4 +47,4 @@ struct Attachment : JsonSerializable { } } }; -}; // namespace ThreadMessage +}; // namespace OpenAi diff --git a/engine/common/message_attachment_factory.h b/engine/common/message_attachment_factory.h index d9f1b8d2e..ce4eef60b 100644 --- a/engine/common/message_attachment_factory.h +++ b/engine/common/message_attachment_factory.h @@ -1,8 +1,10 @@ +#pragma once + #include #include "common/message_attachment.h" #include "utils/result.hpp" -namespace ThreadMessage { +namespace OpenAi { inline cpp::result ParseAttachment( Json::Value&& json) { if (json.empty()) { @@ -45,4 +47,4 @@ ParseAttachments(Json::Value&& json) { return attachments; } -}; // namespace ThreadMessage +}; // namespace OpenAi diff --git a/engine/common/message_content.h b/engine/common/message_content.h index 6e76b01a8..a86dc58ed 100644 --- a/engine/common/message_content.h +++ b/engine/common/message_content.h @@ -3,7 +3,7 @@ #include #include "common/json_serializable.h" -namespace ThreadMessage { +namespace OpenAi { struct Content : JsonSerializable { std::string type; @@ -20,4 +20,4 @@ struct Content : JsonSerializable { virtual ~Content() = default; }; -}; // namespace ThreadMessage +}; // namespace OpenAi diff --git a/engine/common/message_content_factory.h b/engine/common/message_content_factory.h index 854f6efd8..6f8fcb4fe 100644 --- a/engine/common/message_content_factory.h +++ b/engine/common/message_content_factory.h @@ -8,7 +8,7 @@ #include "utils/logging_utils.h" #include "utils/result.hpp" -namespace ThreadMessage { +namespace OpenAi { inline cpp::result, std::string> ParseContent( Json::Value&& json) { if (json.empty()) { @@ -74,4 +74,4 @@ ParseContents(Json::Value&& json) { } return contents; } -} // namespace ThreadMessage +} // namespace OpenAi diff --git a/engine/common/message_content_image_file.h b/engine/common/message_content_image_file.h index 1807dec1e..c3ec57853 100644 --- a/engine/common/message_content_image_file.h +++ b/engine/common/message_content_image_file.h @@ -2,7 +2,7 @@ #include "common/message_content.h" -namespace ThreadMessage { +namespace OpenAi { struct ImageFile { // The File ID of the image in the message content. Set purpose="vision" when uploading the File if you need to later display the file content. std::string file_id; @@ -66,4 +66,4 @@ struct ImageFileContent : Content { } } }; -} // namespace ThreadMessage +} // namespace OpenAi diff --git a/engine/common/message_content_image_url.h b/engine/common/message_content_image_url.h index eae6a7aa6..b86544e38 100644 --- a/engine/common/message_content_image_url.h +++ b/engine/common/message_content_image_url.h @@ -2,7 +2,7 @@ #include "common/message_content.h" -namespace ThreadMessage { +namespace OpenAi { struct ImageUrl { // The external URL of the image, must be a supported image types: jpeg, jpg, png, gif, webp. @@ -68,4 +68,4 @@ struct ImageUrlContent : Content { } } }; -} // namespace ThreadMessage +} // namespace OpenAi diff --git a/engine/common/message_content_refusal.h b/engine/common/message_content_refusal.h index 8353c3a85..c2537ccbf 100644 --- a/engine/common/message_content_refusal.h +++ b/engine/common/message_content_refusal.h @@ -2,7 +2,7 @@ #include "common/message_content.h" -namespace ThreadMessage { +namespace OpenAi { // The refusal content generated by the assistant. struct Refusal : Content { @@ -43,4 +43,4 @@ struct Refusal : Content { } } }; -} // namespace ThreadMessage +} // namespace OpenAi diff --git a/engine/common/message_content_text.h b/engine/common/message_content_text.h index 124d4a878..ea6aab1ab 100644 --- a/engine/common/message_content_text.h +++ b/engine/common/message_content_text.h @@ -3,7 +3,7 @@ #include "common/message_content.h" #include "utils/logging_utils.h" -namespace ThreadMessage { +namespace OpenAi { struct Annotation : JsonSerializable { std::string type; @@ -239,4 +239,4 @@ struct TextContent : Content { } } }; -} // namespace ThreadMessage +} // namespace OpenAi diff --git a/engine/common/message_incomplete_detail.h b/engine/common/message_incomplete_detail.h index 25e9c1169..98e6ff56b 100644 --- a/engine/common/message_incomplete_detail.h +++ b/engine/common/message_incomplete_detail.h @@ -2,7 +2,7 @@ #include "common/json_serializable.h" -namespace ThreadMessage { +namespace OpenAi { // On an incomplete message, details about why the message is incomplete. struct IncompleteDetail : JsonSerializable { @@ -29,4 +29,4 @@ struct IncompleteDetail : JsonSerializable { } } }; -} // namespace ThreadMessage +} // namespace OpenAi diff --git a/engine/common/message_role.h b/engine/common/message_role.h index 9d428eddc..504e2e5f6 100644 --- a/engine/common/message_role.h +++ b/engine/common/message_role.h @@ -3,7 +3,7 @@ #include #include "utils/string_utils.h" -namespace ThreadMessage { +namespace OpenAi { // The entity that produced the message. One of user or assistant. enum class Role { USER, ASSISTANT }; @@ -27,4 +27,4 @@ inline Role RoleFromString(const std::string& input) { return Role::ASSISTANT; } } -}; // namespace ThreadMessage +}; // namespace OpenAi diff --git a/engine/common/message_status.h b/engine/common/message_status.h index e8844ee13..453617363 100644 --- a/engine/common/message_status.h +++ b/engine/common/message_status.h @@ -3,7 +3,7 @@ #include #include "utils/string_utils.h" -namespace ThreadMessage { +namespace OpenAi { // The status of the message, which can be either in_progress, incomplete, or completed. enum class Status { IN_PROGRESS, INCOMPLETE, COMPLETED }; @@ -31,4 +31,4 @@ inline Status StatusFromString(const std::string& input) { return Status::COMPLETED; } } -}; // namespace ThreadMessage +}; // namespace OpenAi diff --git a/engine/common/repository/message_repository.h b/engine/common/repository/message_repository.h index cffc73675..a8a971fd8 100644 --- a/engine/common/repository/message_repository.h +++ b/engine/common/repository/message_repository.h @@ -6,22 +6,25 @@ class MessageRepository { public: virtual cpp::result CreateMessage( - ThreadMessage::Message& message) = 0; + OpenAi::Message& message) = 0; - virtual cpp::result, std::string> - ListMessages(const std::string& thread_id, uint8_t limit = 20, - const std::string& order = "desc", const std::string& after = "", - const std::string& before = "", - const std::string& run_id = "") const = 0; + virtual cpp::result, std::string> ListMessages( + const std::string& thread_id, uint8_t limit, const std::string& order, + const std::string& after, const std::string& before, + const std::string& run_id) const = 0; - virtual cpp::result RetrieveMessage( + virtual cpp::result RetrieveMessage( const std::string& thread_id, const std::string& message_id) const = 0; virtual cpp::result ModifyMessage( - ThreadMessage::Message& message) = 0; + OpenAi::Message& message) = 0; virtual cpp::result DeleteMessage( const std::string& thread_id, const std::string& message_id) = 0; + virtual cpp::result InitializeMessages( + const std::string& thread_id, + std::optional> messages) = 0; + virtual ~MessageRepository() = default; }; diff --git a/engine/common/repository/thread_repository.h b/engine/common/repository/thread_repository.h new file mode 100644 index 000000000..c7bb9e7cf --- /dev/null +++ b/engine/common/repository/thread_repository.h @@ -0,0 +1,25 @@ +#pragma once + +#include "common/thread.h" +#include "utils/result.hpp" + +class ThreadRepository { + public: + virtual cpp::result CreateThread( + OpenAi::Thread& thread) = 0; + + virtual cpp::result, std::string> ListThreads( + uint8_t limit, const std::string& order, const std::string&, + const std::string& before) const = 0; + + virtual cpp::result RetrieveThread( + const std::string& thread_id) const = 0; + + virtual cpp::result ModifyThread( + OpenAi::Thread& thread) = 0; + + virtual cpp::result DeleteThread( + const std::string& thread_id) = 0; + + virtual ~ThreadRepository() = default; +}; diff --git a/engine/common/thread.h b/engine/common/thread.h new file mode 100644 index 000000000..434ecd50e --- /dev/null +++ b/engine/common/thread.h @@ -0,0 +1,102 @@ +#pragma once + +#include +#include +#include +#include "common/thread_tool_resources.h" +#include "common/variant_map.h" +#include "json_serializable.h" +#include "utils/logging_utils.h" + +namespace OpenAi { + +/** + * Represents a thread that contains messages. + */ +struct Thread : JsonSerializable { + /** + * The identifier, which can be referenced in API endpoints. + */ + std::string id; + + /** + * The object type, which is always thread. + */ + std::string object = "thread"; + + /** + * The Unix timestamp (in seconds) for when the thread was created. + */ + uint64_t created_at; + + /** + * A set of resources that are made available to the assistant's + * tools in this thread. The resources are specific to the type + * of tool. For example, the code_interpreter tool requires a list of + * file IDs, while the file_search tool requires a list of vector store IDs. + */ + std::optional> tool_resources; + + /** + * Set of 16 key-value pairs that can be attached to an object. + * This can be useful for storing additional information about the object + * in a structured format. + * + * Keys can be a maximum of 64 characters long and values can be a maximum + * of 512 characters long. + */ + Cortex::VariantMap metadata; + + static cpp::result FromJson(const Json::Value& json) { + Thread thread; + + thread.id = json["id"].asString(); + thread.object = "thread"; + thread.created_at = json["created_at"].asUInt(); + if (thread.created_at == 0 && json["created"].asUInt64() != 0) { + thread.created_at = json["created"].asUInt64() / 1000; + } + // TODO: namh parse tool_resources + + if (json["metadata"].isObject() && !json["metadata"].empty()) { + auto res = Cortex::ConvertJsonValueToMap(json["metadata"]); + if (res.has_error()) { + CTL_WRN("Failed to convert metadata to map: " + res.error()); + } else { + thread.metadata = res.value(); + } + } + + return thread; + } + + cpp::result ToJson() override { + try { + Json::Value json; + + json["id"] = id; + json["object"] = object; + json["created_at"] = created_at; + // TODO: namh handle tool_resources + + Json::Value metadata_json{Json::objectValue}; + for (const auto& [key, value] : metadata) { + if (std::holds_alternative(value)) { + metadata_json[key] = std::get(value); + } else if (std::holds_alternative(value)) { + metadata_json[key] = std::get(value); + } else if (std::holds_alternative(value)) { + metadata_json[key] = std::get(value); + } else { + metadata_json[key] = std::get(value); + } + } + json["metadata"] = metadata_json; + + return json; + } catch (const std::exception& e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + } +}; +} // namespace OpenAi diff --git a/engine/common/thread_tool_resources.h b/engine/common/thread_tool_resources.h new file mode 100644 index 000000000..3c22a4480 --- /dev/null +++ b/engine/common/thread_tool_resources.h @@ -0,0 +1,50 @@ +#pragma once + +#include +#include +#include "common/json_serializable.h" + +namespace OpenAi { + +struct ThreadToolResources : JsonSerializable { + ~ThreadToolResources() = default; + + virtual cpp::result ToJson() override = 0; +}; + +struct ThreadCodeInterpreter : ThreadToolResources { + std::vector file_ids; + + cpp::result ToJson() override { + try { + Json::Value json; + Json::Value file_ids_json{Json::arrayValue}; + for (auto& file_id : file_ids) { + file_ids_json.append(file_id); + } + json["file_ids"] = file_ids_json; + return json; + } catch (const std::exception& e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + } +}; + +struct ThreadFileSearch : ThreadToolResources { + std::vector vector_store_ids; + + cpp::result ToJson() override { + try { + Json::Value json; + Json::Value vector_store_ids_json{Json::arrayValue}; + for (auto& vector_store_id : vector_store_ids) { + vector_store_ids_json.append(vector_store_id); + } + json["vector_store_ids"] = vector_store_ids_json; + return json; + } catch (const std::exception& e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + } +}; +} // namespace OpenAi diff --git a/engine/controllers/messages.cc b/engine/controllers/messages.cc index 55d9f6370..ef82b3412 100644 --- a/engine/controllers/messages.cc +++ b/engine/controllers/messages.cc @@ -1,5 +1,5 @@ #include "messages.h" -#include "common/api-dto/messages/delete_message_response.h" +#include "common/api-dto/delete_success_response.h" #include "common/message_content.h" #include "common/message_role.h" #include "common/variant_map.h" @@ -75,16 +75,13 @@ void Messages::CreateMessage( return; } - ThreadMessage::Role role = role_str == "user" - ? ThreadMessage::Role::USER - : ThreadMessage::Role::ASSISTANT; + auto role = role_str == "user" ? OpenAi::Role::USER : OpenAi::Role::ASSISTANT; - std::variant>> + std::variant>> content; if (json_body->get("content", "").isArray()) { - auto result = ThreadMessage::ParseContents(json_body->get("content", "")); + auto result = OpenAi::ParseContents(json_body->get("content", "")); if (result.has_error()) { Json::Value ret; ret["message"] = "Failed to parse content array: " + result.error(); @@ -128,12 +125,11 @@ void Messages::CreateMessage( } // attachments - std::optional> attachments = - std::nullopt; + std::optional> attachments = std::nullopt; if (json_body->get("attachments", "").isArray()) { - attachments = ThreadMessage::ParseAttachments( - std::move(json_body->get("attachments", ""))) - .value(); + attachments = + OpenAi::ParseAttachments(std::move(json_body->get("attachments", ""))) + .value(); } std::optional metadata = std::nullopt; @@ -287,7 +283,7 @@ void Messages::DeleteMessage( return; } - api_response::DeleteMessageResponse response; + api_response::DeleteSuccessResponse response; response.id = message_id; response.object = "thread.message.deleted"; response.deleted = true; diff --git a/engine/controllers/threads.cc b/engine/controllers/threads.cc new file mode 100644 index 000000000..b8aec1fd5 --- /dev/null +++ b/engine/controllers/threads.cc @@ -0,0 +1,220 @@ +#include "threads.h" +#include "common/api-dto/delete_success_response.h" +#include "common/variant_map.h" +#include "utils/cortex_utils.h" +#include "utils/logging_utils.h" + +void Threads::ListThreads( + const HttpRequestPtr& req, + std::function&& callback, + std::optional limit, std::optional order, + std::optional after, std::optional before) const { + CTL_INF("ListThreads"); + auto res = + thread_service_->ListThreads(limit.value_or(20), order.value_or("desc"), + after.value_or(""), before.value_or("")); + + if (res.has_error()) { + Json::Value root; + root["message"] = res.error(); + auto response = cortex_utils::CreateCortexHttpJsonResponse(root); + response->setStatusCode(k400BadRequest); + callback(response); + return; + } + Json::Value msg_arr(Json::arrayValue); + for (auto& msg : res.value()) { + if (auto it = msg.ToJson(); it.has_value()) { + msg_arr.append(it.value()); + } else { + CTL_WRN("Failed to convert message to json: " + it.error()); + } + } + + Json::Value root; + root["object"] = "list"; + root["data"] = msg_arr; + auto response = cortex_utils::CreateCortexHttpJsonResponse(root); + response->setStatusCode(k200OK); + callback(response); +} + +void Threads::CreateThread( + const HttpRequestPtr& req, + std::function&& callback) { + auto json_body = req->getJsonObject(); + if (json_body == nullptr) { + Json::Value ret; + ret["message"] = "Request body can't be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + // TODO: namh handle tool_resources + // TODO: namh handle messages + + std::optional metadata = std::nullopt; + if (json_body->get("metadata", "").isObject()) { + auto res = Cortex::ConvertJsonValueToMap(json_body->get("metadata", "")); + if (res.has_error()) { + CTL_WRN("Failed to convert metadata to map: " + res.error()); + } else { + metadata = res.value(); + } + } + + auto res = thread_service_->CreateThread(std::nullopt, metadata); + + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto init_msg_res = + message_service_->InitializeMessages(res->id, std::nullopt); + + if (res.has_error()) { + CTL_ERR("Failed to convert message to json: " + res.error()); + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(res->ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); + } + } +} + +void Threads::RetrieveThread( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id) const { + auto res = thread_service_->RetrieveThread(thread_id); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto thread_to_json = res->ToJson(); + if (thread_to_json.has_error()) { + CTL_ERR("Failed to convert message to json: " + thread_to_json.error()); + Json::Value ret; + ret["message"] = thread_to_json.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(res->ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); + } + } +} + +void Threads::ModifyThread( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id) { + auto json_body = req->getJsonObject(); + if (json_body == nullptr) { + Json::Value ret; + ret["message"] = "Request body can't be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + std::optional metadata = std::nullopt; + if (auto it = json_body->get("metadata", ""); it) { + if (it.empty()) { + Json::Value ret; + ret["message"] = "Metadata can't be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + auto convert_res = Cortex::ConvertJsonValueToMap(it); + if (convert_res.has_error()) { + Json::Value ret; + ret["message"] = + "Failed to convert metadata to map: " + convert_res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + metadata = convert_res.value(); + } + + if (!metadata.has_value()) { + Json::Value ret; + ret["message"] = "Metadata is mandatory"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + // TODO: namh handle tools + auto res = + thread_service_->ModifyThread(thread_id, std::nullopt, metadata.value()); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto message_to_json = res->ToJson(); + if (message_to_json.has_error()) { + CTL_ERR("Failed to convert message to json: " + message_to_json.error()); + Json::Value ret; + ret["message"] = message_to_json.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(res->ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); + } + } +} + +void Threads::DeleteThread( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id) { + auto res = thread_service_->DeleteThread(thread_id); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + api_response::DeleteSuccessResponse response; + response.id = thread_id; + response.object = "thread.deleted"; + response.deleted = true; + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(response.ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); +} diff --git a/engine/controllers/threads.h b/engine/controllers/threads.h new file mode 100644 index 000000000..92c509525 --- /dev/null +++ b/engine/controllers/threads.h @@ -0,0 +1,57 @@ +#pragma once + +#include +#include +#include "services/message_service.h" +#include "services/thread_service.h" + +using namespace drogon; + +class Threads : public drogon::HttpController { + public: + METHOD_LIST_BEGIN + ADD_METHOD_TO(Threads::CreateThread, "/v1/threads", Options, Post); + + ADD_METHOD_TO(Threads::ListThreads, + "/v1/" + "threads?limit={limit}&order={order}&after={after}&before={" + "before}", + Get); + + ADD_METHOD_TO(Threads::RetrieveThread, "/v1/threads/{thread_id}", Get); + ADD_METHOD_TO(Threads::ModifyThread, "/v1/threads/{thread_id}", Options, + Post); + ADD_METHOD_TO(Threads::DeleteThread, "/v1/threads/{thread_id}", Options, + Delete); + METHOD_LIST_END + + explicit Threads(std::shared_ptr thread_srv, + std::shared_ptr msg_srv) + : thread_service_{thread_srv}, message_service_{msg_srv} {} + + void CreateThread(const HttpRequestPtr& req, + std::function&& callback); + + void ListThreads(const HttpRequestPtr& req, + std::function&& callback, + std::optional limit, + std::optional order, + std::optional after, + std::optional before) const; + + void RetrieveThread(const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id) const; + + void ModifyThread(const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id); + + void DeleteThread(const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id); + + private: + std::shared_ptr thread_service_; + std::shared_ptr message_service_; +}; diff --git a/engine/main.cc b/engine/main.cc index d076c02bd..0177a2143 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -1,7 +1,6 @@ #include #include #include -#include "common/repository/message_repository.h" #include "controllers/configs.h" #include "controllers/engines.h" #include "controllers/events.h" @@ -10,24 +9,23 @@ #include "controllers/models.h" #include "controllers/process_manager.h" #include "controllers/server.h" -#include "cortex-common/cortexpythoni.h" +#include "controllers/threads.h" #include "database/database.h" #include "migrations/migration_manager.h" #include "repositories/message_fs_repository.h" +#include "repositories/thread_fs_repository.h" #include "services/config_service.h" #include "services/file_watcher_service.h" #include "services/message_service.h" #include "services/model_service.h" +#include "services/thread_service.h" #include "utils/archive_utils.h" #include "utils/cortex_utils.h" -#include "utils/dylib.h" #include "utils/event_processor.h" #include "utils/file_logger.h" #include "utils/file_manager_utils.h" -#include "utils/hardware/gguf/gguf_file_estimate.h" #include "utils/logging_utils.h" #include "utils/system_info_utils.h" -#include "utils/widechar_conv.h" #if defined(__APPLE__) && defined(__MACH__) #include // for dirname() @@ -40,6 +38,7 @@ #include // for readlink() #elif defined(_WIN32) #include +#include "utils/widechar_conv.h" #undef max #else #error "Unsupported platform!" @@ -120,9 +119,14 @@ void RunServer(std::optional port, bool ignore_cout) { auto event_queue_ptr = std::make_shared(); cortex::event::EventProcessor event_processor(event_queue_ptr); - std::shared_ptr msg_repo = - std::make_shared(); + auto msg_repo = std::make_shared( + file_manager_utils::GetCortexDataPath()); + auto thread_repo = std::make_shared( + file_manager_utils::GetCortexDataPath()); + + auto thread_srv = std::make_shared(thread_repo); auto message_srv = std::make_shared(msg_repo); + auto model_dir_path = file_manager_utils::GetModelsContainerPath(); auto config_service = std::make_shared(); auto download_service = @@ -138,6 +142,7 @@ void RunServer(std::optional port, bool ignore_cout) { file_watcher_srv->start(); // initialize custom controllers + auto thread_ctl = std::make_shared(thread_srv, message_srv); auto message_ctl = std::make_shared(message_srv); auto engine_ctl = std::make_shared(engine_service); auto model_ctl = std::make_shared(model_service, engine_service); @@ -148,6 +153,7 @@ void RunServer(std::optional port, bool ignore_cout) { std::make_shared(inference_svc, engine_service); auto config_ctl = std::make_shared(config_service); + drogon::app().registerController(thread_ctl); drogon::app().registerController(message_ctl); drogon::app().registerController(engine_ctl); drogon::app().registerController(model_ctl); @@ -318,27 +324,6 @@ int main(int argc, char* argv[]) { } } - // // Check if this process is for python execution - // if (argc > 1) { - // if (strcmp(argv[1], "--run_python_file") == 0) { - // std::string py_home_path = (argc > 3) ? argv[3] : ""; - // std::unique_ptr dl; - // try { - // std::string abs_path = - // cortex_utils::GetCurrentPath() + kPythonRuntimeLibPath; - // dl = std::make_unique(abs_path, "engine"); - // } catch (const cortex_cpp::dylib::load_error& e) { - // LOG_ERROR << "Could not load engine: " << e.what(); - // return 1; - // } - - // auto func = dl->get_function("get_engine"); - // auto e = func(); - // e->ExecutePythonFile(argv[0], argv[2], py_home_path); - // return 0; - // } - // } - RunServer(server_port, ignore_cout_log); return 0; } diff --git a/engine/repositories/message_fs_repository.cc b/engine/repositories/message_fs_repository.cc index 60cc0b5bf..d0db0e400 100644 --- a/engine/repositories/message_fs_repository.cc +++ b/engine/repositories/message_fs_repository.cc @@ -1,32 +1,21 @@ #include "message_fs_repository.h" -#include "utils/file_manager_utils.h" +#include #include "utils/result.hpp" -namespace { -constexpr static const std::string_view kMessageFile = "messages.jsonl"; - -inline cpp::result GetMessageFileAbsPath( - const std::string& thread_id) { - auto path = - file_manager_utils::GetThreadsContainerPath() / thread_id / kMessageFile; - if (!std::filesystem::exists(path)) { - return cpp::fail("Message file not exist at path: " + path.string()); - } - return path; +std::filesystem::path MessageFsRepository::GetMessagePath( + const std::string& thread_id) const { + return data_folder_path_ / kThreadContainerFolderName / thread_id / + kMessageFile; } -} // namespace cpp::result MessageFsRepository::CreateMessage( - ThreadMessage::Message& message) { + OpenAi::Message& message) { CTL_INF("CreateMessage for thread " + message.thread_id); - auto path = GetMessageFileAbsPath(message.thread_id); - if (path.has_error()) { - return cpp::fail(path.error()); - } + auto path = GetMessagePath(message.thread_id); - std::ofstream file(path->string(), std::ios::app); + std::ofstream file(path, std::ios::app); if (!file) { - return cpp::fail("Failed to open file for writing: " + path->string()); + return cpp::fail("Failed to open file for writing: " + path.string()); } auto mutex = GrabMutex(message.thread_id); @@ -40,27 +29,24 @@ cpp::result MessageFsRepository::CreateMessage( file.flush(); if (file.fail()) { - return cpp::fail("Failed to write to file: " + path->string()); + return cpp::fail("Failed to write to file: " + path.string()); } file.close(); if (file.fail()) { - return cpp::fail("Failed to close file after writing: " + path->string()); + return cpp::fail("Failed to close file after writing: " + path.string()); } return {}; } -cpp::result, std::string> +cpp::result, std::string> MessageFsRepository::ListMessages(const std::string& thread_id, uint8_t limit, const std::string& order, const std::string& after, const std::string& before, const std::string& run_id) const { CTL_INF("Listing messages for thread " + thread_id); - auto path = GetMessageFileAbsPath(thread_id); - if (path.has_error()) { - return cpp::fail(path.error()); - } + auto path = GetMessagePath(thread_id); auto mutex = GrabMutex(thread_id); std::shared_lock lock(*mutex); @@ -68,13 +54,9 @@ MessageFsRepository::ListMessages(const std::string& thread_id, uint8_t limit, return ReadMessageFromFile(thread_id); } -cpp::result -MessageFsRepository::RetrieveMessage(const std::string& thread_id, - const std::string& message_id) const { - auto path = GetMessageFileAbsPath(thread_id); - if (path.has_error()) { - return cpp::fail(path.error()); - } +cpp::result MessageFsRepository::RetrieveMessage( + const std::string& thread_id, const std::string& message_id) const { + auto path = GetMessagePath(thread_id); auto mutex = GrabMutex(thread_id); std::unique_lock lock(*mutex); @@ -94,11 +76,8 @@ MessageFsRepository::RetrieveMessage(const std::string& thread_id, } cpp::result MessageFsRepository::ModifyMessage( - ThreadMessage::Message& message) { - auto path = GetMessageFileAbsPath(message.thread_id); - if (path.has_error()) { - return cpp::fail(path.error()); - } + OpenAi::Message& message) { + auto path = GetMessagePath(message.thread_id); auto mutex = GrabMutex(message.thread_id); std::unique_lock lock(*mutex); @@ -108,10 +87,9 @@ cpp::result MessageFsRepository::ModifyMessage( return cpp::fail(messages.error()); } - std::ofstream file(path.value().string(), std::ios::trunc); + std::ofstream file(path, std::ios::trunc); if (!file) { - return cpp::fail("Failed to open file for writing: " + - path.value().string()); + return cpp::fail("Failed to open file for writing: " + path.string()); } bool found = false; @@ -126,11 +104,11 @@ cpp::result MessageFsRepository::ModifyMessage( file.flush(); if (file.fail()) { - return cpp::fail("Failed to write to file: " + path->string()); + return cpp::fail("Failed to write to file: " + path.string()); } file.close(); if (file.fail()) { - return cpp::fail("Failed to close file after writing: " + path->string()); + return cpp::fail("Failed to close file after writing: " + path.string()); } if (!found) { @@ -141,10 +119,7 @@ cpp::result MessageFsRepository::ModifyMessage( cpp::result MessageFsRepository::DeleteMessage( const std::string& thread_id, const std::string& message_id) { - auto path = GetMessageFileAbsPath(thread_id); - if (path.has_error()) { - return cpp::fail(path.error()); - } + auto path = GetMessagePath(thread_id); auto mutex = GrabMutex(thread_id); std::unique_lock lock(*mutex); @@ -153,10 +128,9 @@ cpp::result MessageFsRepository::DeleteMessage( return cpp::fail(messages.error()); } - std::ofstream file(path.value().string(), std::ios::trunc); + std::ofstream file(path, std::ios::trunc); if (!file) { - return cpp::fail("Failed to open file for writing: " + - path.value().string()); + return cpp::fail("Failed to open file for writing: " + path.string()); } bool found = false; @@ -170,11 +144,11 @@ cpp::result MessageFsRepository::DeleteMessage( file.flush(); if (file.fail()) { - return cpp::fail("Failed to write to file: " + path->string()); + return cpp::fail("Failed to write to file: " + path.string()); } file.close(); if (file.fail()) { - return cpp::fail("Failed to close file after writing: " + path->string()); + return cpp::fail("Failed to close file after writing: " + path.string()); } if (!found) { @@ -184,26 +158,22 @@ cpp::result MessageFsRepository::DeleteMessage( return {}; } -cpp::result, std::string> +cpp::result, std::string> MessageFsRepository::ReadMessageFromFile(const std::string& thread_id) const { LOG_TRACE << "Reading messages from file for thread " << thread_id; - auto path = GetMessageFileAbsPath(thread_id); - if (path.has_error()) { - return cpp::fail(path.error()); - } + auto path = GetMessagePath(thread_id); - std::ifstream file(path.value()); + std::ifstream file(path); if (!file) { - return cpp::fail("Failed to open file: " + path->string()); + return cpp::fail("Failed to open file: " + path.string()); } - std::vector messages; + std::vector messages; std::string line; while (std::getline(file, line)) { if (line.empty()) continue; - auto msg_parse_result = - ThreadMessage::Message::FromJsonString(std::move(line)); + auto msg_parse_result = OpenAi::Message::FromJsonString(std::move(line)); if (msg_parse_result.has_error()) { CTL_WRN("Failed to parse message: " + msg_parse_result.error()); continue; @@ -224,3 +194,49 @@ std::shared_mutex* MessageFsRepository::GrabMutex( } return thread_mutex.get(); } + +cpp::result MessageFsRepository::InitializeMessages( + const std::string& thread_id, + std::optional> messages) { + CTL_INF("Initializing messages for thread " + thread_id); + + auto path = GetMessagePath(thread_id); + + if (!std::filesystem::exists(path.parent_path())) { + return cpp::fail( + "Failed to initialize messages, thread is not created yet! Path does " + "not exist: " + + path.parent_path().string()); + } + + auto mutex = GrabMutex(thread_id); + std::unique_lock lock(*mutex); + + std::ofstream file(path, std::ios::trunc); + if (!file) { + return cpp::fail("Failed to create message file: " + path.string()); + } + + if (messages.has_value()) { + for (auto& message : messages.value()) { + auto json_str = message.ToSingleLineJsonString(); + if (json_str.has_error()) { + CTL_WRN("Failed to serialize message: " + json_str.error()); + continue; + } + file << json_str.value(); + } + } + + file.flush(); + if (file.fail()) { + return cpp::fail("Failed to write to file: " + path.string()); + } + + file.close(); + if (file.fail()) { + return cpp::fail("Failed to close file after writing: " + path.string()); + } + + return {}; +} diff --git a/engine/repositories/message_fs_repository.h b/engine/repositories/message_fs_repository.h index d8bcd02a7..2146778bf 100644 --- a/engine/repositories/message_fs_repository.h +++ b/engine/repositories/message_fs_repository.h @@ -1,39 +1,63 @@ #pragma once +#include #include #include #include "common/repository/message_repository.h" class MessageFsRepository : public MessageRepository { + constexpr static auto kMessageFile = "messages.jsonl"; + constexpr static auto kThreadContainerFolderName = "threads"; + public: cpp::result CreateMessage( - ThreadMessage::Message& message) override; + OpenAi::Message& message) override; - cpp::result, std::string> ListMessages( - const std::string& thread_id, uint8_t limit = 20, - const std::string& order = "desc", const std::string& after = "", - const std::string& before = "", - const std::string& run_id = "") const override; + cpp::result, std::string> ListMessages( + const std::string& thread_id, uint8_t limit, const std::string& order, + const std::string& after, const std::string& before, + const std::string& run_id) const override; - cpp::result RetrieveMessage( + cpp::result RetrieveMessage( const std::string& thread_id, const std::string& message_id) const override; cpp::result ModifyMessage( - ThreadMessage::Message& message) override; + OpenAi::Message& message) override; cpp::result DeleteMessage( const std::string& thread_id, const std::string& message_id) override; + cpp::result InitializeMessages( + const std::string& thread_id, + std::optional> messages) override; + + explicit MessageFsRepository(std::filesystem::path data_folder_path) + : data_folder_path_{data_folder_path} { + CTL_INF("Constructing MessageFsRepository.."); + auto thread_container_path = data_folder_path_ / kThreadContainerFolderName; + + if (!std::filesystem::exists(thread_container_path)) { + std::filesystem::create_directories(thread_container_path); + } + } + ~MessageFsRepository() = default; private: - cpp::result, std::string> - ReadMessageFromFile(const std::string& thread_id) const; + cpp::result, std::string> ReadMessageFromFile( + const std::string& thread_id) const; + + /** + * The path to the data folder. + */ + std::filesystem::path data_folder_path_; + + std::filesystem::path GetMessagePath(const std::string& thread_id) const; std::shared_mutex* GrabMutex(const std::string& thread_id) const; + mutable std::mutex mutex_map_mutex_; mutable std::unordered_map> thread_mutexes_; - mutable std::mutex mutex_map_mutex_; }; diff --git a/engine/repositories/thread_fs_repository.cc b/engine/repositories/thread_fs_repository.cc new file mode 100644 index 000000000..0f805dd7e --- /dev/null +++ b/engine/repositories/thread_fs_repository.cc @@ -0,0 +1,165 @@ +#include "thread_fs_repository.h" +#include + +cpp::result, std::string> +ThreadFsRepository::ListThreads(uint8_t limit, const std::string& order, + const std::string& after, + const std::string& before) const { + CTL_INF("ListThreads: limit=" + std::to_string(limit) + ", order=" + order + + ", after=" + after + ", before=" + before); + std::vector threads; + + try { + auto thread_container_path = data_folder_path_ / kThreadContainerFolderName; + for (const auto& entry : + std::filesystem::directory_iterator(thread_container_path)) { + if (!entry.is_directory()) + continue; + + if (!std::filesystem::exists(entry.path() / kThreadFileName)) + continue; + + auto current_thread_id = entry.path().filename().string(); + CTL_INF("ListThreads: Found thread: " + current_thread_id); + std::shared_lock thread_lock(GrabThreadMutex(current_thread_id)); + + auto thread_result = LoadThread(current_thread_id); + if (thread_result.has_value()) { + threads.push_back(std::move(thread_result.value())); + } + + thread_lock.unlock(); + } + + return threads; + } catch (const std::exception& e) { + return cpp::fail(std::string("Failed to list threads: ") + e.what()); + } +} + +std::shared_mutex& ThreadFsRepository::GrabThreadMutex( + const std::string& thread_id) const { + std::shared_lock map_lock(map_mutex_); + auto it = thread_mutexes_.find(thread_id); + if (it != thread_mutexes_.end()) { + return *it->second; + } + + map_lock.unlock(); + std::unique_lock map_write_lock(map_mutex_); + return *thread_mutexes_ + .try_emplace(thread_id, std::make_unique()) + .first->second; +} + +std::filesystem::path ThreadFsRepository::GetThreadPath( + const std::string& thread_id) const { + return data_folder_path_ / kThreadContainerFolderName / thread_id; +} + +cpp::result ThreadFsRepository::LoadThread( + const std::string& thread_id) const { + auto path = GetThreadPath(thread_id) / kThreadFileName; + if (!std::filesystem::exists(path)) { + return cpp::fail("Path does not exist: " + path.string()); + } + + try { + std::ifstream file(path); + if (!file.is_open()) { + return cpp::fail("Failed to open file: " + path.string()); + } + + Json::Value root; + Json::CharReaderBuilder builder; + JSONCPP_STRING errs; + + if (!parseFromStream(builder, file, &root, &errs)) { + return cpp::fail("Failed to parse JSON: " + errs); + } + + return OpenAi::Thread::FromJson(root); + } catch (const std::exception& e) { + return cpp::fail("Failed to load thread: " + std::string(e.what())); + } +} + +cpp::result ThreadFsRepository::CreateThread( + OpenAi::Thread& thread) { + CTL_INF("CreateThread: " + thread.id); + std::unique_lock lock(GrabThreadMutex(thread.id)); + auto thread_path = GetThreadPath(thread.id); + + if (std::filesystem::exists(thread_path)) { + return cpp::fail("Thread exists: " + thread.id); + } + + std::filesystem::create_directories(thread_path); + auto thread_file_path = thread_path / kThreadFileName; + std::ofstream thread_file(thread_file_path); + thread_file.close(); + + return SaveThread(thread); +} + +cpp::result ThreadFsRepository::SaveThread( + OpenAi::Thread& thread) { + auto path = GetThreadPath(thread.id) / kThreadFileName; + if (!std::filesystem::exists(path)) { + return cpp::fail("Path does not exist: " + path.string()); + } + + std::ofstream file(path); + try { + if (!file) { + return cpp::fail("Failed to open file: " + path.string()); + } + file << thread.ToJson()->toStyledString(); + file.flush(); + file.close(); + return {}; + } catch (const std::exception& e) { + file.close(); + return cpp::fail("Failed to save thread: " + std::string(e.what())); + } +} + +cpp::result ThreadFsRepository::RetrieveThread( + const std::string& thread_id) const { + std::shared_lock lock(GrabThreadMutex(thread_id)); + return LoadThread(thread_id); +} + +cpp::result ThreadFsRepository::ModifyThread( + OpenAi::Thread& thread) { + std::unique_lock lock(GrabThreadMutex(thread.id)); + auto thread_path = GetThreadPath(thread.id); + + if (!std::filesystem::exists(thread_path)) { + return cpp::fail("Thread doesn't exist: " + thread.id); + } + + return SaveThread(thread); +} + +cpp::result ThreadFsRepository::DeleteThread( + const std::string& thread_id) { + CTL_INF("DeleteThread: " + thread_id); + + { + std::unique_lock thread_lock(GrabThreadMutex(thread_id)); + auto path = GetThreadPath(thread_id); + if (!std::filesystem::exists(path)) { + return cpp::fail("Thread doesn't exist: " + thread_id); + } + try { + std::filesystem::remove_all(path); + } catch (const std::exception& e) { + return cpp::failure(std::string("Failed to delete thread: ") + e.what()); + } + } + + std::unique_lock map_lock(map_mutex_); + thread_mutexes_.erase(thread_id); + return {}; +} diff --git a/engine/repositories/thread_fs_repository.h b/engine/repositories/thread_fs_repository.h new file mode 100644 index 000000000..230440153 --- /dev/null +++ b/engine/repositories/thread_fs_repository.h @@ -0,0 +1,62 @@ +#pragma once + +#include +#include +#include +#include "common/repository/thread_repository.h" +#include "common/thread.h" +#include "utils/logging_utils.h" + +class ThreadFsRepository : public ThreadRepository { + private: + constexpr static auto kThreadFileName = "thread.json"; + constexpr static auto kThreadContainerFolderName = "threads"; + + mutable std::shared_mutex map_mutex_; + mutable std::unordered_map> + thread_mutexes_; + + /** + * The path to the data folder. + */ + std::filesystem::path data_folder_path_; + + std::shared_mutex& GrabThreadMutex(const std::string& thread_id) const; + + std::filesystem::path GetThreadPath(const std::string& thread_id) const; + + /** + * Read the thread file and parse to Thread from the file system. + */ + cpp::result LoadThread( + const std::string& thread_id) const; + + cpp::result SaveThread(OpenAi::Thread& thread); + + public: + explicit ThreadFsRepository(std::filesystem::path data_folder_path) + : data_folder_path_{data_folder_path} { + CTL_INF("Constructing ThreadFsRepository.."); + auto thread_container_path = data_folder_path_ / kThreadContainerFolderName; + + if (!std::filesystem::exists(thread_container_path)) { + std::filesystem::create_directories(thread_container_path); + } + } + + cpp::result CreateThread(OpenAi::Thread& thread) override; + + cpp::result, std::string> ListThreads( + uint8_t limit, const std::string& order, const std::string& after, + const std::string& before) const override; + + cpp::result RetrieveThread( + const std::string& thread_id) const override; + + cpp::result ModifyThread(OpenAi::Thread& thread) override; + + cpp::result DeleteThread( + const std::string& thread_id) override; + + ~ThreadFsRepository() = default; +}; diff --git a/engine/services/message_service.cc b/engine/services/message_service.cc index 31ae38420..dfad74236 100644 --- a/engine/services/message_service.cc +++ b/engine/services/message_service.cc @@ -3,40 +3,39 @@ #include "utils/result.hpp" #include "utils/ulid/ulid.hh" -cpp::result MessageService::CreateMessage( - const std::string& thread_id, const ThreadMessage::Role& role, - std::variant>>&& +cpp::result MessageService::CreateMessage( + const std::string& thread_id, const OpenAi::Role& role, + std::variant>>&& content, - std::optional> attachments, + std::optional> attachments, std::optional metadata) { LOG_TRACE << "CreateMessage for thread " << thread_id; - auto now = std::chrono::system_clock::now(); + auto seconds_since_epoch = - std::chrono::duration_cast(now.time_since_epoch()) + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) .count(); - std::vector> content_list{}; + std::vector> content_list{}; + // if content is string if (std::holds_alternative(content)) { - auto text_content = std::make_unique(); + auto text_content = std::make_unique(); text_content->text.value = std::get(content); content_list.push_back(std::move(text_content)); } else { content_list = std::move( - std::get>>( - content)); + std::get>>(content)); } - ulid::ULID ulid = ulid::Create(seconds_since_epoch, []() { return 4; }); - std::string str = ulid::Marshal(ulid); - LOG_TRACE << "Generated message ID: " << str; + auto ulid = ulid::CreateNowRand(); + auto msg_id = ulid::Marshal(ulid); - ThreadMessage::Message msg; - msg.id = str; + OpenAi::Message msg; + msg.id = msg_id; msg.object = "thread.message"; msg.created_at = 0; msg.thread_id = thread_id; - msg.status = ThreadMessage::Status::COMPLETED; + msg.status = OpenAi::Status::COMPLETED; msg.completed_at = seconds_since_epoch; msg.incomplete_at = std::nullopt; msg.incomplete_details = std::nullopt; @@ -54,23 +53,23 @@ cpp::result MessageService::CreateMessage( } } -cpp::result, std::string> +cpp::result, std::string> MessageService::ListMessages(const std::string& thread_id, uint8_t limit, const std::string& order, const std::string& after, const std::string& before, const std::string& run_id) const { CTL_INF("ListMessages for thread " + thread_id); - return message_repository_->ListMessages(thread_id); + return message_repository_->ListMessages(thread_id, limit, order, after, + before, run_id); } -cpp::result -MessageService::RetrieveMessage(const std::string& thread_id, - const std::string& message_id) const { +cpp::result MessageService::RetrieveMessage( + const std::string& thread_id, const std::string& message_id) const { CTL_INF("RetrieveMessage for thread " + thread_id); return message_repository_->RetrieveMessage(thread_id, message_id); } -cpp::result MessageService::ModifyMessage( +cpp::result MessageService::ModifyMessage( const std::string& thread_id, const std::string& message_id, std::optional metadata) { LOG_TRACE << "ModifyMessage for thread " << thread_id << ", message " @@ -103,3 +102,20 @@ cpp::result MessageService::DeleteMessage( return message_id; } } + +cpp::result MessageService::InitializeMessages( + const std::string& thread_id, + std::optional> messages) { + CTL_INF("InitializeMessages for thread " + thread_id); + + if (messages.has_value()) { + CTL_INF("Prepopulated messages length: " + + std::to_string(messages->size())); + } else { + + CTL_INF("Prepopulated with empty messages"); + } + + return message_repository_->InitializeMessages(thread_id, + std::move(messages)); +} diff --git a/engine/services/message_service.h b/engine/services/message_service.h index e62970b54..6c4880f32 100644 --- a/engine/services/message_service.h +++ b/engine/services/message_service.h @@ -9,27 +9,28 @@ class MessageService { explicit MessageService(std::shared_ptr message_repository) : message_repository_{message_repository} {} - cpp::result CreateMessage( - const std::string& thread_id, const ThreadMessage::Role& role, - std::variant>>&& + cpp::result CreateMessage( + const std::string& thread_id, const OpenAi::Role& role, + std::variant>>&& content, - std::optional> attachments, + std::optional> attachments, std::optional metadata); - cpp::result, std::string> ListMessages( + cpp::result InitializeMessages( + const std::string& thread_id, + std::optional> messages); + + cpp::result, std::string> ListMessages( const std::string& thread_id, uint8_t limit = 20, const std::string& order = "desc", const std::string& after = "", const std::string& before = "", const std::string& run_id = "") const; - cpp::result RetrieveMessage( + cpp::result RetrieveMessage( const std::string& thread_id, const std::string& message_id) const; - cpp::result ModifyMessage( + cpp::result ModifyMessage( const std::string& thread_id, const std::string& message_id, - std::optional>> - metadata); + std::optional metadata); cpp::result DeleteMessage( const std::string& thread_id, const std::string& message_id); diff --git a/engine/services/thread_service.cc b/engine/services/thread_service.cc new file mode 100644 index 000000000..685edbb97 --- /dev/null +++ b/engine/services/thread_service.cc @@ -0,0 +1,81 @@ +#include "thread_service.h" +#include "utils/logging_utils.h" +#include "utils/ulid/ulid.hh" + +cpp::result ThreadService::CreateThread( + std::optional> tool_resources, + std::optional metadata) { + LOG_TRACE << "CreateThread"; + + auto seconds_since_epoch = + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); + + auto ulid = ulid::CreateNowRand(); + auto thread_id = ulid::Marshal(ulid); + + OpenAi::Thread thread; + thread.id = thread_id; + thread.object = "thread"; + thread.created_at = seconds_since_epoch; + + if (tool_resources.has_value()) { + thread.tool_resources = std::move(tool_resources.value()); + } + thread.metadata = metadata.value_or(Cortex::VariantMap{}); + + if (auto res = thread_repository_->CreateThread(thread); res.has_error()) { + return cpp::fail("Failed to create message: " + res.error()); + } + + return thread; +} + +cpp::result, std::string> +ThreadService::ListThreads(uint8_t limit, const std::string& order, + const std::string& after, + const std::string& before) const { + CTL_INF("ListThreads"); + return thread_repository_->ListThreads(limit, order, after, before); +} + +cpp::result ThreadService::RetrieveThread( + const std::string& thread_id) const { + LOG_TRACE << "RetriveThread: " << thread_id; + return thread_repository_->RetrieveThread(thread_id); +} + +cpp::result ThreadService::ModifyThread( + const std::string& thread_id, + std::optional> tool_resources, + std::optional metadata) { + LOG_TRACE << "ModifyThread " << thread_id; + auto retrieve_res = RetrieveThread(thread_id); + if (retrieve_res.has_error()) { + return cpp::fail("Failed to retrieve thread: " + retrieve_res.error()); + } + + retrieve_res->tool_resources = std::move(tool_resources.value()); + retrieve_res->metadata = std::move(metadata.value()); + + auto res = thread_repository_->ModifyThread(retrieve_res.value()); + if (res.has_error()) { + CTL_ERR("Failed to modify thread: " + res.error()); + return cpp::fail("Failed to modify thread: " + res.error()); + } else { + return RetrieveThread(thread_id); + } +} + +cpp::result ThreadService::DeleteThread( + const std::string& thread_id) { + LOG_TRACE << "DeleteThread: " + thread_id; + auto res = thread_repository_->DeleteThread(thread_id); + if (res.has_error()) { + LOG_ERROR << "Failed to delete thread: " + res.error(); + return cpp::fail("Failed to delete thread: " + res.error()); + } else { + return thread_id; + } +} diff --git a/engine/services/thread_service.h b/engine/services/thread_service.h new file mode 100644 index 000000000..608fba22b --- /dev/null +++ b/engine/services/thread_service.h @@ -0,0 +1,36 @@ +#pragma once + +#include "common/repository/thread_repository.h" +#include "common/thread_tool_resources.h" +#include "common/variant_map.h" +#include "utils/result.hpp" + +class ThreadService { + public: + explicit ThreadService(std::shared_ptr thread_repository) + : thread_repository_{thread_repository} {} + + cpp::result CreateThread( + std::optional> + tool_resources, + std::optional metadata); + + cpp::result, std::string> ListThreads( + uint8_t limit, const std::string& order, const std::string& after, + const std::string& before) const; + + cpp::result RetrieveThread( + const std::string& thread_id) const; + + cpp::result ModifyThread( + const std::string& thread_id, + std::optional> + tool_resources, + std::optional metadata); + + cpp::result DeleteThread( + const std::string& thread_id); + + private: + std::shared_ptr thread_repository_; +}; From 263382d8fac5a9b66362385575a8cf0a2608f9b5 Mon Sep 17 00:00:00 2001 From: James Date: Wed, 4 Dec 2024 19:39:50 +0700 Subject: [PATCH 2/3] fix build windows Signed-off-by: James --- engine/common/message.h | 8 ++++++++ engine/common/thread.h | 1 + engine/repositories/thread_fs_repository.cc | 2 +- 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/engine/common/message.h b/engine/common/message.h index cfd069515..909a843ee 100644 --- a/engine/common/message.h +++ b/engine/common/message.h @@ -21,6 +21,14 @@ namespace OpenAi { // Represents a message within a thread. struct Message : JsonSerializable { + Message() = default; + // Delete copy operations + Message(const Message&) = delete; + Message& operator=(const Message&) = delete; + // Allow move operations + Message(Message&&) = default; + Message& operator=(Message&&) = default; + // The identifier, which can be referenced in API endpoints. std::string id; diff --git a/engine/common/thread.h b/engine/common/thread.h index 434ecd50e..d442179c2 100644 --- a/engine/common/thread.h +++ b/engine/common/thread.h @@ -3,6 +3,7 @@ #include #include #include +#include #include "common/thread_tool_resources.h" #include "common/variant_map.h" #include "json_serializable.h" diff --git a/engine/repositories/thread_fs_repository.cc b/engine/repositories/thread_fs_repository.cc index 0f805dd7e..824982e47 100644 --- a/engine/repositories/thread_fs_repository.cc +++ b/engine/repositories/thread_fs_repository.cc @@ -155,7 +155,7 @@ cpp::result ThreadFsRepository::DeleteThread( try { std::filesystem::remove_all(path); } catch (const std::exception& e) { - return cpp::failure(std::string("Failed to delete thread: ") + e.what()); + return cpp::fail(std::string("Failed to delete thread: ") + e.what()); } } From 0bd7178c196c1cde0179c7bf1d3ca2f5917c5020 Mon Sep 17 00:00:00 2001 From: James Date: Wed, 4 Dec 2024 20:05:41 +0700 Subject: [PATCH 3/3] fix ci --- engine/common/thread.h | 47 ++++++++++++++++++-- engine/controllers/threads.cc | 4 +- engine/repositories/message_fs_repository.cc | 1 + engine/repositories/thread_fs_repository.cc | 1 + engine/repositories/thread_fs_repository.h | 2 +- engine/services/thread_service.cc | 14 +++--- engine/services/thread_service.h | 7 ++- 7 files changed, 59 insertions(+), 17 deletions(-) diff --git a/engine/common/thread.h b/engine/common/thread.h index d442179c2..20672ff72 100644 --- a/engine/common/thread.h +++ b/engine/common/thread.h @@ -3,7 +3,6 @@ #include #include #include -#include #include "common/thread_tool_resources.h" #include "common/variant_map.h" #include "json_serializable.h" @@ -36,7 +35,7 @@ struct Thread : JsonSerializable { * of tool. For example, the code_interpreter tool requires a list of * file IDs, while the file_search tool requires a list of vector store IDs. */ - std::optional> tool_resources; + std::unique_ptr tool_resources; /** * Set of 16 key-value pairs that can be attached to an object. @@ -57,7 +56,30 @@ struct Thread : JsonSerializable { if (thread.created_at == 0 && json["created"].asUInt64() != 0) { thread.created_at = json["created"].asUInt64() / 1000; } - // TODO: namh parse tool_resources + + if (json.isMember("tool_resources") && !json["tool_resources"].isNull()) { + const auto& tool_json = json["tool_resources"]; + + if (tool_json.isMember("code_interpreter")) { + auto code_interpreter = std::make_unique(); + const auto& file_ids = tool_json["code_interpreter"]["file_ids"]; + if (file_ids.isArray()) { + for (const auto& file_id : file_ids) { + code_interpreter->file_ids.push_back(file_id.asString()); + } + } + thread.tool_resources = std::move(code_interpreter); + } else if (tool_json.isMember("file_search")) { + auto file_search = std::make_unique(); + const auto& store_ids = tool_json["file_search"]["vector_store_ids"]; + if (store_ids.isArray()) { + for (const auto& store_id : store_ids) { + file_search->vector_store_ids.push_back(store_id.asString()); + } + } + thread.tool_resources = std::move(file_search); + } + } if (json["metadata"].isObject() && !json["metadata"].empty()) { auto res = Cortex::ConvertJsonValueToMap(json["metadata"]); @@ -78,7 +100,24 @@ struct Thread : JsonSerializable { json["id"] = id; json["object"] = object; json["created_at"] = created_at; - // TODO: namh handle tool_resources + + if (tool_resources) { + auto tool_result = tool_resources->ToJson(); + if (tool_result.has_error()) { + return cpp::fail("Failed to serialize tool_resources: " + + tool_result.error()); + } + + Json::Value tool_json; + if (auto code_interpreter = + dynamic_cast(tool_resources.get())) { + tool_json["code_interpreter"] = tool_result.value(); + } else if (auto file_search = + dynamic_cast(tool_resources.get())) { + tool_json["file_search"] = tool_result.value(); + } + json["tool_resources"] = tool_json; + } Json::Value metadata_json{Json::objectValue}; for (const auto& [key, value] : metadata) { diff --git a/engine/controllers/threads.cc b/engine/controllers/threads.cc index b8aec1fd5..a11c1071b 100644 --- a/engine/controllers/threads.cc +++ b/engine/controllers/threads.cc @@ -65,7 +65,7 @@ void Threads::CreateThread( } } - auto res = thread_service_->CreateThread(std::nullopt, metadata); + auto res = thread_service_->CreateThread(nullptr, metadata); if (res.has_error()) { Json::Value ret; @@ -170,7 +170,7 @@ void Threads::ModifyThread( // TODO: namh handle tools auto res = - thread_service_->ModifyThread(thread_id, std::nullopt, metadata.value()); + thread_service_->ModifyThread(thread_id, nullptr, metadata.value()); if (res.has_error()) { Json::Value ret; ret["message"] = res.error(); diff --git a/engine/repositories/message_fs_repository.cc b/engine/repositories/message_fs_repository.cc index d0db0e400..e576a7695 100644 --- a/engine/repositories/message_fs_repository.cc +++ b/engine/repositories/message_fs_repository.cc @@ -1,5 +1,6 @@ #include "message_fs_repository.h" #include +#include #include "utils/result.hpp" std::filesystem::path MessageFsRepository::GetMessagePath( diff --git a/engine/repositories/thread_fs_repository.cc b/engine/repositories/thread_fs_repository.cc index 824982e47..64dad6ea5 100644 --- a/engine/repositories/thread_fs_repository.cc +++ b/engine/repositories/thread_fs_repository.cc @@ -1,5 +1,6 @@ #include "thread_fs_repository.h" #include +#include cpp::result, std::string> ThreadFsRepository::ListThreads(uint8_t limit, const std::string& order, diff --git a/engine/repositories/thread_fs_repository.h b/engine/repositories/thread_fs_repository.h index 230440153..d834b8e44 100644 --- a/engine/repositories/thread_fs_repository.h +++ b/engine/repositories/thread_fs_repository.h @@ -34,7 +34,7 @@ class ThreadFsRepository : public ThreadRepository { cpp::result SaveThread(OpenAi::Thread& thread); public: - explicit ThreadFsRepository(std::filesystem::path data_folder_path) + explicit ThreadFsRepository(const std::filesystem::path& data_folder_path) : data_folder_path_{data_folder_path} { CTL_INF("Constructing ThreadFsRepository.."); auto thread_container_path = data_folder_path_ / kThreadContainerFolderName; diff --git a/engine/services/thread_service.cc b/engine/services/thread_service.cc index 685edbb97..25784c2ee 100644 --- a/engine/services/thread_service.cc +++ b/engine/services/thread_service.cc @@ -3,7 +3,7 @@ #include "utils/ulid/ulid.hh" cpp::result ThreadService::CreateThread( - std::optional> tool_resources, + std::unique_ptr tool_resources, std::optional metadata) { LOG_TRACE << "CreateThread"; @@ -20,8 +20,8 @@ cpp::result ThreadService::CreateThread( thread.object = "thread"; thread.created_at = seconds_since_epoch; - if (tool_resources.has_value()) { - thread.tool_resources = std::move(tool_resources.value()); + if (tool_resources) { + thread.tool_resources = std::move(tool_resources); } thread.metadata = metadata.value_or(Cortex::VariantMap{}); @@ -42,13 +42,13 @@ ThreadService::ListThreads(uint8_t limit, const std::string& order, cpp::result ThreadService::RetrieveThread( const std::string& thread_id) const { - LOG_TRACE << "RetriveThread: " << thread_id; + CTL_INF("RetrieveThread: " + thread_id); return thread_repository_->RetrieveThread(thread_id); } cpp::result ThreadService::ModifyThread( const std::string& thread_id, - std::optional> tool_resources, + std::unique_ptr tool_resources, std::optional metadata) { LOG_TRACE << "ModifyThread " << thread_id; auto retrieve_res = RetrieveThread(thread_id); @@ -56,7 +56,9 @@ cpp::result ThreadService::ModifyThread( return cpp::fail("Failed to retrieve thread: " + retrieve_res.error()); } - retrieve_res->tool_resources = std::move(tool_resources.value()); + if (tool_resources) { + retrieve_res->tool_resources = std::move(tool_resources); + } retrieve_res->metadata = std::move(metadata.value()); auto res = thread_repository_->ModifyThread(retrieve_res.value()); diff --git a/engine/services/thread_service.h b/engine/services/thread_service.h index 608fba22b..966b0ab01 100644 --- a/engine/services/thread_service.h +++ b/engine/services/thread_service.h @@ -1,5 +1,6 @@ #pragma once +#include #include "common/repository/thread_repository.h" #include "common/thread_tool_resources.h" #include "common/variant_map.h" @@ -11,8 +12,7 @@ class ThreadService { : thread_repository_{thread_repository} {} cpp::result CreateThread( - std::optional> - tool_resources, + std::unique_ptr tool_resources, std::optional metadata); cpp::result, std::string> ListThreads( @@ -24,8 +24,7 @@ class ThreadService { cpp::result ModifyThread( const std::string& thread_id, - std::optional> - tool_resources, + std::unique_ptr tool_resources, std::optional metadata); cpp::result DeleteThread(