diff --git a/lldb/include/lldb/Protocol/MCP/Server.h b/lldb/include/lldb/Protocol/MCP/Server.h index 382f9a4731dd4..2ac05880de86b 100644 --- a/lldb/include/lldb/Protocol/MCP/Server.h +++ b/lldb/include/lldb/Protocol/MCP/Server.h @@ -9,8 +9,6 @@ #ifndef LLDB_PROTOCOL_MCP_SERVER_H #define LLDB_PROTOCOL_MCP_SERVER_H -#include "lldb/Host/JSONTransport.h" -#include "lldb/Host/MainLoop.h" #include "lldb/Protocol/MCP/Protocol.h" #include "lldb/Protocol/MCP/Resource.h" #include "lldb/Protocol/MCP/Tool.h" @@ -20,52 +18,26 @@ namespace lldb_protocol::mcp { -class MCPTransport final - : public lldb_private::JSONRPCTransport { +class Server { public: - using LogCallback = std::function; - - MCPTransport(lldb::IOObjectSP in, lldb::IOObjectSP out, - std::string client_name, LogCallback log_callback = {}) - : JSONRPCTransport(in, out), m_client_name(std::move(client_name)), - m_log_callback(log_callback) {} - virtual ~MCPTransport() = default; - - void Log(llvm::StringRef message) override { - if (m_log_callback) - m_log_callback(llvm::formatv("{0}: {1}", m_client_name, message).str()); - } - -private: - std::string m_client_name; - LogCallback m_log_callback; -}; - -class Server : public MCPTransport::MessageHandler { -public: - Server(std::string name, std::string version, - std::unique_ptr transport_up, - lldb_private::MainLoop &loop); - ~Server() = default; - - using NotificationHandler = std::function; + Server(std::string name, std::string version); + virtual ~Server() = default; void AddTool(std::unique_ptr tool); void AddResourceProvider(std::unique_ptr resource_provider); - void AddNotificationHandler(llvm::StringRef method, - NotificationHandler handler); - - llvm::Error Run(); protected: - Capabilities GetCapabilities(); + virtual Capabilities GetCapabilities() = 0; using RequestHandler = std::function(const Request &)>; + using NotificationHandler = std::function; void AddRequestHandlers(); void AddRequestHandler(llvm::StringRef method, RequestHandler handler); + void AddNotificationHandler(llvm::StringRef method, + NotificationHandler handler); llvm::Expected> HandleData(llvm::StringRef data); @@ -80,23 +52,12 @@ class Server : public MCPTransport::MessageHandler { llvm::Expected ResourcesListHandler(const Request &); llvm::Expected ResourcesReadHandler(const Request &); - void Received(const Request &) override; - void Received(const Response &) override; - void Received(const Notification &) override; - void OnError(llvm::Error) override; - void OnClosed() override; - - void TerminateLoop(); - std::mutex m_mutex; private: const std::string m_name; const std::string m_version; - std::unique_ptr m_transport_up; - lldb_private::MainLoop &m_loop; - llvm::StringMap> m_tools; std::vector> m_resource_providers; diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp index 57132534cf680..c359663239dcc 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp @@ -26,10 +26,24 @@ using namespace llvm; LLDB_PLUGIN_DEFINE(ProtocolServerMCP) +static constexpr size_t kChunkSize = 1024; static constexpr llvm::StringLiteral kName = "lldb-mcp"; static constexpr llvm::StringLiteral kVersion = "0.1.0"; -ProtocolServerMCP::ProtocolServerMCP() : ProtocolServer() {} +ProtocolServerMCP::ProtocolServerMCP() + : ProtocolServer(), + lldb_protocol::mcp::Server(std::string(kName), std::string(kVersion)) { + AddNotificationHandler("notifications/initialized", + [](const lldb_protocol::mcp::Notification &) { + LLDB_LOG(GetLog(LLDBLog::Host), + "MCP initialization complete"); + }); + + AddTool( + std::make_unique("lldb_command", "Run an lldb command.")); + + AddResourceProvider(std::make_unique()); +} ProtocolServerMCP::~ProtocolServerMCP() { llvm::consumeError(Stop()); } @@ -50,37 +64,57 @@ llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() { return "MCP Server."; } -void ProtocolServerMCP::Extend(lldb_protocol::mcp::Server &server) const { - server.AddNotificationHandler("notifications/initialized", - [](const lldb_protocol::mcp::Notification &) { - LLDB_LOG(GetLog(LLDBLog::Host), - "MCP initialization complete"); - }); - server.AddTool( - std::make_unique("lldb_command", "Run an lldb command.")); - server.AddResourceProvider(std::make_unique()); -} - void ProtocolServerMCP::AcceptCallback(std::unique_ptr socket) { - Log *log = GetLog(LLDBLog::Host); - std::string client_name = llvm::formatv("client_{0}", m_instances.size() + 1); - LLDB_LOG(log, "New MCP client connected: {0}", client_name); + LLDB_LOG(GetLog(LLDBLog::Host), "New MCP client ({0}) connected", + m_clients.size() + 1); lldb::IOObjectSP io_sp = std::move(socket); - auto transport_up = std::make_unique( - io_sp, io_sp, std::move(client_name), [&](llvm::StringRef message) { - LLDB_LOG(GetLog(LLDBLog::Host), "{0}", message); - }); - auto instance_up = std::make_unique( - std::string(kName), std::string(kVersion), std::move(transport_up), - m_loop); - Extend(*instance_up); - llvm::Error error = instance_up->Run(); - if (error) { - LLDB_LOG_ERROR(log, std::move(error), "Failed to run MCP server: {0}"); + auto client_up = std::make_unique(); + client_up->io_sp = io_sp; + Client *client = client_up.get(); + + Status status; + auto read_handle_up = m_loop.RegisterReadObject( + io_sp, + [this, client](MainLoopBase &loop) { + if (llvm::Error error = ReadCallback(*client)) { + LLDB_LOG_ERROR(GetLog(LLDBLog::Host), std::move(error), "{0}"); + client->read_handle_up.reset(); + } + }, + status); + if (status.Fail()) return; + + client_up->read_handle_up = std::move(read_handle_up); + m_clients.emplace_back(std::move(client_up)); +} + +llvm::Error ProtocolServerMCP::ReadCallback(Client &client) { + char chunk[kChunkSize]; + size_t bytes_read = sizeof(chunk); + if (Status status = client.io_sp->Read(chunk, bytes_read); status.Fail()) + return status.takeError(); + client.buffer.append(chunk, bytes_read); + + for (std::string::size_type pos; + (pos = client.buffer.find('\n')) != std::string::npos;) { + llvm::Expected> message = + HandleData(StringRef(client.buffer.data(), pos)); + client.buffer = client.buffer.erase(0, pos + 1); + if (!message) + return message.takeError(); + + if (*message) { + std::string Output; + llvm::raw_string_ostream OS(Output); + OS << llvm::formatv("{0}", toJSON(**message)) << '\n'; + size_t num_bytes = Output.size(); + return client.io_sp->Write(Output.data(), num_bytes).takeError(); + } } - m_instances.push_back(std::move(instance_up)); + + return llvm::Error::success(); } llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { @@ -124,11 +158,27 @@ llvm::Error ProtocolServerMCP::Stop() { // Stop the main loop. m_loop.AddPendingCallback( - [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); + [](MainLoopBase &loop) { loop.RequestTermination(); }); // Wait for the main loop to exit. if (m_loop_thread.joinable()) m_loop_thread.join(); + { + std::lock_guard guard(m_mutex); + m_listener.reset(); + m_listen_handlers.clear(); + m_clients.clear(); + } + return llvm::Error::success(); } + +lldb_protocol::mcp::Capabilities ProtocolServerMCP::GetCapabilities() { + lldb_protocol::mcp::Capabilities capabilities; + capabilities.tools.listChanged = true; + // FIXME: Support sending notifications when a debugger/target are + // added/removed. + capabilities.resources.listChanged = false; + return capabilities; +} diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h index fc650ffe0dfa7..7fe909a728b85 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h @@ -18,7 +18,8 @@ namespace lldb_private::mcp { -class ProtocolServerMCP : public ProtocolServer { +class ProtocolServerMCP : public ProtocolServer, + public lldb_protocol::mcp::Server { public: ProtocolServerMCP(); virtual ~ProtocolServerMCP() override; @@ -38,24 +39,26 @@ class ProtocolServerMCP : public ProtocolServer { Socket *GetSocket() const override { return m_listener.get(); } -protected: - // This adds tools and resource providers that - // are specific to this server. Overridable by the unit tests. - virtual void Extend(lldb_protocol::mcp::Server &server) const; - private: void AcceptCallback(std::unique_ptr socket); + lldb_protocol::mcp::Capabilities GetCapabilities() override; + bool m_running = false; - lldb_private::MainLoop m_loop; + MainLoop m_loop; std::thread m_loop_thread; - std::mutex m_mutex; std::unique_ptr m_listener; - std::vector m_listen_handlers; - std::vector> m_instances; + + struct Client { + lldb::IOObjectSP io_sp; + MainLoopBase::ReadHandleUP read_handle_up; + std::string buffer; + }; + llvm::Error ReadCallback(Client &client); + std::vector> m_clients; }; } // namespace lldb_private::mcp diff --git a/lldb/source/Protocol/MCP/Server.cpp b/lldb/source/Protocol/MCP/Server.cpp index 3713e8e46c5d6..a9c1482e3e378 100644 --- a/lldb/source/Protocol/MCP/Server.cpp +++ b/lldb/source/Protocol/MCP/Server.cpp @@ -12,11 +12,8 @@ using namespace lldb_protocol::mcp; using namespace llvm; -Server::Server(std::string name, std::string version, - std::unique_ptr transport_up, - lldb_private::MainLoop &loop) - : m_name(std::move(name)), m_version(std::move(version)), - m_transport_up(std::move(transport_up)), m_loop(loop) { +Server::Server(std::string name, std::string version) + : m_name(std::move(name)), m_version(std::move(version)) { AddRequestHandlers(); } @@ -235,71 +232,3 @@ llvm::Expected Server::ResourcesReadHandler(const Request &request) { llvm::formatv("no resource handler for uri: {0}", uri_str).str(), MCPError::kResourceNotFound); } - -Capabilities Server::GetCapabilities() { - lldb_protocol::mcp::Capabilities capabilities; - capabilities.tools.listChanged = true; - // FIXME: Support sending notifications when a debugger/target are - // added/removed. - capabilities.resources.listChanged = false; - return capabilities; -} - -llvm::Error Server::Run() { - auto handle = m_transport_up->RegisterMessageHandler(m_loop, *this); - if (!handle) - return handle.takeError(); - - lldb_private::Status status = m_loop.Run(); - if (status.Fail()) - return status.takeError(); - - return llvm::Error::success(); -} - -void Server::Received(const Request &request) { - auto SendResponse = [this](const Response &response) { - if (llvm::Error error = m_transport_up->Send(response)) - m_transport_up->Log(llvm::toString(std::move(error))); - }; - - llvm::Expected response = Handle(request); - if (response) - return SendResponse(*response); - - lldb_protocol::mcp::Error protocol_error; - llvm::handleAllErrors( - response.takeError(), - [&](const MCPError &err) { protocol_error = err.toProtocolError(); }, - [&](const llvm::ErrorInfoBase &err) { - protocol_error.code = MCPError::kInternalError; - protocol_error.message = err.message(); - }); - Response error_response; - error_response.id = request.id; - error_response.result = std::move(protocol_error); - SendResponse(error_response); -} - -void Server::Received(const Response &response) { - m_transport_up->Log("unexpected MCP message: response"); -} - -void Server::Received(const Notification ¬ification) { - Handle(notification); -} - -void Server::OnError(llvm::Error error) { - m_transport_up->Log(llvm::toString(std::move(error))); - TerminateLoop(); -} - -void Server::OnClosed() { - m_transport_up->Log("EOF"); - TerminateLoop(); -} - -void Server::TerminateLoop() { - m_loop.AddPendingCallback( - [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); -} diff --git a/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp b/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp index 83a42bfb6970c..18112428950ce 100644 --- a/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp +++ b/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp @@ -39,20 +39,12 @@ using testing::_; namespace { class TestProtocolServerMCP : public lldb_private::mcp::ProtocolServerMCP { public: + using ProtocolServerMCP::AddNotificationHandler; + using ProtocolServerMCP::AddRequestHandler; + using ProtocolServerMCP::AddResourceProvider; + using ProtocolServerMCP::AddTool; using ProtocolServerMCP::GetSocket; using ProtocolServerMCP::ProtocolServerMCP; - - using ExtendCallback = - std::function; - - virtual void Extend(lldb_protocol::mcp::Server &server) const override { - if (m_extend_callback) - m_extend_callback(server); - }; - - void Extend(ExtendCallback callback) { m_extend_callback = callback; } - - ExtendCallback m_extend_callback; }; using Message = typename Transport::Message; @@ -191,10 +183,8 @@ class ProtocolServerMCPTest : public ::testing::Test { connection.protocol = Socket::SocketProtocol::ProtocolTcp; connection.name = llvm::formatv("{0}:0", k_localhost).str(); m_server_up = std::make_unique(); - m_server_up->Extend([&](auto &server) { - server.AddTool(std::make_unique("test", "test tool")); - server.AddResourceProvider(std::make_unique()); - }); + m_server_up->AddTool(std::make_unique("test", "test tool")); + m_server_up->AddResourceProvider(std::make_unique()); ASSERT_THAT_ERROR(m_server_up->Start(connection), llvm::Succeeded()); // Connect to the server over a TCP socket. @@ -243,10 +233,20 @@ TEST_F(ProtocolServerMCPTest, ToolsList) { test_tool.description = "test tool"; test_tool.inputSchema = json::Object{{"type", "object"}}; + ToolDefinition lldb_command_tool; + lldb_command_tool.description = "Run an lldb command."; + lldb_command_tool.name = "lldb_command"; + lldb_command_tool.inputSchema = json::Object{ + {"type", "object"}, + {"properties", + json::Object{{"arguments", json::Object{{"type", "string"}}}, + {"debugger_id", json::Object{{"type", "number"}}}}}, + {"required", json::Array{"debugger_id"}}}; Response response; response.id = "one"; response.result = json::Object{ - {"tools", json::Array{std::move(test_tool)}}, + {"tools", + json::Array{std::move(test_tool), std::move(lldb_command_tool)}}, }; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); @@ -281,9 +281,7 @@ TEST_F(ProtocolServerMCPTest, ToolsCall) { } TEST_F(ProtocolServerMCPTest, ToolsCallError) { - m_server_up->Extend([&](auto &server) { - server.AddTool(std::make_unique("error", "error tool")); - }); + m_server_up->AddTool(std::make_unique("error", "error tool")); llvm::StringLiteral request = R"json({"method":"tools/call","params":{"name":"error","arguments":{"arguments":"foo","debugger_id":0}},"jsonrpc":"2.0","id":11})json"; @@ -298,9 +296,7 @@ TEST_F(ProtocolServerMCPTest, ToolsCallError) { } TEST_F(ProtocolServerMCPTest, ToolsCallFail) { - m_server_up->Extend([&](auto &server) { - server.AddTool(std::make_unique("fail", "fail tool")); - }); + m_server_up->AddTool(std::make_unique("fail", "fail tool")); llvm::StringLiteral request = R"json({"method":"tools/call","params":{"name":"fail","arguments":{"arguments":"foo","debugger_id":0}},"jsonrpc":"2.0","id":11})json"; @@ -319,16 +315,14 @@ TEST_F(ProtocolServerMCPTest, NotificationInitialized) { std::condition_variable cv; std::mutex mutex; - m_server_up->Extend([&](auto &server) { - server.AddNotificationHandler("notifications/initialized", - [&](const Notification ¬ification) { - { - std::lock_guard lock(mutex); - handler_called = true; - } - cv.notify_all(); - }); - }); + m_server_up->AddNotificationHandler( + "notifications/initialized", [&](const Notification ¬ification) { + { + std::lock_guard lock(mutex); + handler_called = true; + } + cv.notify_all(); + }); llvm::StringLiteral request = R"json({"method":"notifications/initialized","jsonrpc":"2.0"})json";