diff --git a/lldb/include/lldb/Host/JSONTransport.h b/lldb/include/lldb/Host/JSONTransport.h index 210f33edace6e..c73021d204258 100644 --- a/lldb/include/lldb/Host/JSONTransport.h +++ b/lldb/include/lldb/Host/JSONTransport.h @@ -100,22 +100,21 @@ template class Transport { virtual llvm::Expected RegisterMessageHandler(MainLoop &loop, MessageHandler &handler) = 0; - // FIXME: Refactor mcp::Server to not directly access log on the transport. - // protected: +protected: template inline auto Logv(const char *Fmt, Ts &&...Vals) { Log(llvm::formatv(Fmt, std::forward(Vals)...).str()); } virtual void Log(llvm::StringRef message) = 0; }; -/// A JSONTransport will encode and decode messages using JSON. +/// An IOTransport sends and receives messages using an IOObject. template -class JSONTransport : public Transport { +class IOTransport : public Transport { public: using Transport::Transport; using MessageHandler = typename Transport::MessageHandler; - JSONTransport(lldb::IOObjectSP in, lldb::IOObjectSP out) + IOTransport(lldb::IOObjectSP in, lldb::IOObjectSP out) : m_in(in), m_out(out) {} llvm::Error Send(const Evt &evt) override { return Write(evt); } @@ -127,7 +126,7 @@ class JSONTransport : public Transport { Status status; MainLoop::ReadHandleUP read_handle = loop.RegisterReadObject( m_in, - std::bind(&JSONTransport::OnRead, this, std::placeholders::_1, + std::bind(&IOTransport::OnRead, this, std::placeholders::_1, std::ref(handler)), status); if (status.Fail()) { @@ -203,9 +202,9 @@ class JSONTransport : public Transport { /// A transport class for JSON with a HTTP header. template -class HTTPDelimitedJSONTransport : public JSONTransport { +class HTTPDelimitedJSONTransport : public IOTransport { public: - using JSONTransport::JSONTransport; + using IOTransport::IOTransport; protected: /// Encodes messages based on @@ -270,9 +269,9 @@ class HTTPDelimitedJSONTransport : public JSONTransport { /// A transport class for JSON RPC. template -class JSONRPCTransport : public JSONTransport { +class JSONRPCTransport : public IOTransport { public: - using JSONTransport::JSONTransport; + using IOTransport::IOTransport; protected: std::string Encode(const llvm::json::Value &message) override { diff --git a/lldb/include/lldb/Protocol/MCP/Server.h b/lldb/include/lldb/Protocol/MCP/Server.h index 1f916ae525b5c..970980d075ea6 100644 --- a/lldb/include/lldb/Protocol/MCP/Server.h +++ b/lldb/include/lldb/Protocol/MCP/Server.h @@ -29,10 +29,11 @@ namespace lldb_protocol::mcp { class Server : public MCPTransport::MessageHandler { + using ClosedCallback = llvm::unique_function; + public: - Server(std::string name, std::string version, - std::unique_ptr transport_up, - lldb_private::MainLoop &loop); + Server(std::string name, std::string version, MCPTransport &client, + LogCallback log_callback = {}, ClosedCallback closed_callback = {}); ~Server() = default; using NotificationHandler = std::function; @@ -42,8 +43,6 @@ class Server : public MCPTransport::MessageHandler { void AddNotificationHandler(llvm::StringRef method, NotificationHandler handler); - llvm::Error Run(); - protected: ServerCapabilities GetCapabilities(); @@ -73,14 +72,16 @@ class Server : public MCPTransport::MessageHandler { void OnError(llvm::Error) override; void OnClosed() override; - void TerminateLoop(); +protected: + void Log(llvm::StringRef); private: const std::string m_name; const std::string m_version; - std::unique_ptr m_transport_up; - lldb_private::MainLoop &m_loop; + MCPTransport &m_client; + LogCallback m_log_callback; + ClosedCallback m_closed_callback; llvm::StringMap> m_tools; std::vector> m_resource_providers; diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp index d3af3cf25c4a1..d7293fc28c524 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp @@ -66,7 +66,7 @@ void ProtocolServerMCP::Extend(lldb_protocol::mcp::Server &server) const { void ProtocolServerMCP::AcceptCallback(std::unique_ptr socket) { Log *log = GetLog(LLDBLog::Host); - std::string client_name = llvm::formatv("client_{0}", m_instances.size() + 1); + std::string client_name = llvm::formatv("client_{0}", ++m_client_count); LLDB_LOG(log, "New MCP client connected: {0}", client_name); lldb::IOObjectSP io_sp = std::move(socket); @@ -74,16 +74,26 @@ void ProtocolServerMCP::AcceptCallback(std::unique_ptr socket) { io_sp, io_sp, [client_name](llvm::StringRef message) { LLDB_LOG(GetLog(LLDBLog::Host), "{0}: {1}", client_name, message); }); + MCPTransport *transport_ptr = transport_up.get(); auto instance_up = std::make_unique( - std::string(kName), std::string(kVersion), std::move(transport_up), - m_loop); + std::string(kName), std::string(kVersion), *transport_up, + /*log_callback=*/ + [client_name](llvm::StringRef message) { + LLDB_LOG(GetLog(LLDBLog::Host), "{0} Server: {1}", client_name, + message); + }, + /*closed_callback=*/ + [this, transport_ptr]() { m_instances.erase(transport_ptr); }); Extend(*instance_up); - llvm::Error error = instance_up->Run(); - if (error) { - LLDB_LOG_ERROR(log, std::move(error), "Failed to run MCP server: {0}"); + llvm::Expected handle = + transport_up->RegisterMessageHandler(m_loop, *instance_up); + if (!handle) { + LLDB_LOG_ERROR(log, handle.takeError(), "Failed to run MCP server: {0}"); return; } - m_instances.push_back(std::move(instance_up)); + m_instances[transport_ptr] = + std::make_tuple( + std::move(instance_up), std::move(*handle), std::move(transport_up)); } llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h index 0251664a2acc4..b325a3681bccb 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h @@ -12,13 +12,21 @@ #include "lldb/Core/ProtocolServer.h" #include "lldb/Host/MainLoop.h" #include "lldb/Host/Socket.h" -#include "lldb/Protocol/MCP/Protocol.h" #include "lldb/Protocol/MCP/Server.h" +#include "lldb/Protocol/MCP/Transport.h" +#include +#include #include +#include +#include namespace lldb_private::mcp { class ProtocolServerMCP : public ProtocolServer { + using ReadHandleUP = MainLoopBase::ReadHandleUP; + using TransportUP = std::unique_ptr; + using ServerUP = std::unique_ptr; + public: ProtocolServerMCP(); virtual ~ProtocolServerMCP() override; @@ -52,11 +60,14 @@ class ProtocolServerMCP : public ProtocolServer { lldb_private::MainLoop m_loop; std::thread m_loop_thread; std::mutex m_mutex; + size_t m_client_count = 0; std::unique_ptr m_listener; - std::vector m_listen_handlers; - std::vector> m_instances; + std::vector m_listen_handlers; + std::map> + m_instances; }; } // namespace lldb_private::mcp diff --git a/lldb/source/Protocol/MCP/Server.cpp b/lldb/source/Protocol/MCP/Server.cpp index a08874e7321af..19030a3a4e5d6 100644 --- a/lldb/source/Protocol/MCP/Server.cpp +++ b/lldb/source/Protocol/MCP/Server.cpp @@ -10,7 +10,6 @@ #include "lldb/Host/File.h" #include "lldb/Host/FileSystem.h" #include "lldb/Host/HostInfo.h" -#include "lldb/Host/JSONTransport.h" #include "lldb/Protocol/MCP/MCPError.h" #include "lldb/Protocol/MCP/Protocol.h" #include "llvm/ADT/SmallString.h" @@ -109,11 +108,11 @@ Expected> ServerInfo::Load() { return infos; } -Server::Server(std::string name, std::string version, - std::unique_ptr transport_up, - lldb_private::MainLoop &loop) - : m_name(std::move(name)), m_version(std::move(version)), - m_transport_up(std::move(transport_up)), m_loop(loop) { +Server::Server(std::string name, std::string version, MCPTransport &client, + LogCallback log_callback, ClosedCallback closed_callback) + : m_name(std::move(name)), m_version(std::move(version)), m_client(client), + m_log_callback(std::move(log_callback)), + m_closed_callback(std::move(closed_callback)) { AddRequestHandlers(); } @@ -287,22 +286,15 @@ ServerCapabilities Server::GetCapabilities() { return capabilities; } -llvm::Error Server::Run() { - auto handle = m_transport_up->RegisterMessageHandler(m_loop, *this); - if (!handle) - return handle.takeError(); - - lldb_private::Status status = m_loop.Run(); - if (status.Fail()) - return status.takeError(); - - return llvm::Error::success(); +void Server::Log(llvm::StringRef message) { + if (m_log_callback) + m_log_callback(message); } void Server::Received(const Request &request) { auto SendResponse = [this](const Response &response) { - if (llvm::Error error = m_transport_up->Send(response)) - m_transport_up->Log(llvm::toString(std::move(error))); + if (llvm::Error error = m_client.Send(response)) + Log(llvm::toString(std::move(error))); }; llvm::Expected response = Handle(request); @@ -324,7 +316,7 @@ void Server::Received(const Request &request) { } void Server::Received(const Response &response) { - m_transport_up->Log("unexpected MCP message: response"); + Log("unexpected MCP message: response"); } void Server::Received(const Notification ¬ification) { @@ -332,16 +324,11 @@ void Server::Received(const Notification ¬ification) { } void Server::OnError(llvm::Error error) { - m_transport_up->Log(llvm::toString(std::move(error))); - TerminateLoop(); + Log(llvm::toString(std::move(error))); } void Server::OnClosed() { - m_transport_up->Log("EOF"); - TerminateLoop(); -} - -void Server::TerminateLoop() { - m_loop.AddPendingCallback( - [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); + Log("EOF"); + if (m_closed_callback) + m_closed_callback(); } diff --git a/lldb/unittests/Host/JSONTransportTest.cpp b/lldb/unittests/Host/JSONTransportTest.cpp index 445674f402252..3a36bf21f07ff 100644 --- a/lldb/unittests/Host/JSONTransportTest.cpp +++ b/lldb/unittests/Host/JSONTransportTest.cpp @@ -413,7 +413,7 @@ TEST_F(JSONRPCTransportTest, ReadAcrossMultipleChunks) { // Use a string longer than the chunk size to ensure we split the message // across the chunk boundary. std::string long_str = - std::string(JSONTransport::kReadBufferSize * 2, 'x'); + std::string(IOTransport::kReadBufferSize * 2, 'x'); Write(Req{long_str}); EXPECT_CALL(message_handler, Received(Req{long_str})); ASSERT_THAT_ERROR(Run(), Succeeded()); diff --git a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp index f686255c6d41d..f3ca4cfc01788 100644 --- a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp +++ b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp @@ -122,53 +122,73 @@ class ProtocolServerMCPTest : public PipePairTest { public: SubsystemRAII subsystems; - std::unique_ptr transport_up; - std::unique_ptr server_up; MainLoop loop; + + std::unique_ptr from_client; + std::unique_ptr to_client; + MainLoopBase::ReadHandleUP handles[2]; + + std::unique_ptr server_up; MockMessageHandler message_handler; llvm::Error Write(llvm::StringRef message) { llvm::Expected value = json::parse(message); if (!value) return value.takeError(); - return transport_up->Write(*value); + return from_client->Write(*value); } - llvm::Error Write(json::Value value) { return transport_up->Write(value); } + llvm::Error Write(json::Value value) { return from_client->Write(value); } /// Run the transport MainLoop and return any messages received. - llvm::Error - Run(std::chrono::milliseconds timeout = std::chrono::milliseconds(200)) { + llvm::Error Run() { loop.AddCallback([](MainLoopBase &loop) { loop.RequestTermination(); }, - timeout); - auto handle = transport_up->RegisterMessageHandler(loop, message_handler); - if (!handle) - return handle.takeError(); - - return server_up->Run(); + std::chrono::milliseconds(10)); + return loop.Run().takeError(); } void SetUp() override { PipePairTest::SetUp(); - transport_up = std::make_unique( + from_client = std::make_unique( std::make_shared(input.GetReadFileDescriptor(), File::eOpenOptionReadOnly, NativeFile::Unowned), std::make_shared(output.GetWriteFileDescriptor(), File::eOpenOptionWriteOnly, - NativeFile::Unowned)); - - server_up = std::make_unique( - "lldb-mcp", "0.1.0", - std::make_unique( - std::make_shared(output.GetReadFileDescriptor(), - File::eOpenOptionReadOnly, - NativeFile::Unowned), - std::make_shared(input.GetWriteFileDescriptor(), - File::eOpenOptionWriteOnly, - NativeFile::Unowned)), - loop); + NativeFile::Unowned), + [](StringRef message) { + // Uncomment for debugging + // llvm::errs() << "from_client: " << message << '\n'; + }); + to_client = std::make_unique( + std::make_shared(output.GetReadFileDescriptor(), + File::eOpenOptionReadOnly, + NativeFile::Unowned), + std::make_shared(input.GetWriteFileDescriptor(), + File::eOpenOptionWriteOnly, + NativeFile::Unowned), + [](StringRef message) { + // Uncomment for debugging + // llvm::errs() << "to_client: " << message << '\n'; + }); + + server_up = std::make_unique("lldb-mcp", "0.1.0", *to_client, + [](StringRef message) { + // Uncomment for debugging + // llvm::errs() << "server: " << + // message << '\n'; + }); + + auto maybe_from_client_handle = + from_client->RegisterMessageHandler(loop, message_handler); + EXPECT_THAT_EXPECTED(maybe_from_client_handle, Succeeded()); + handles[0] = std::move(*maybe_from_client_handle); + + auto maybe_to_client_handle = + to_client->RegisterMessageHandler(loop, *server_up); + EXPECT_THAT_EXPECTED(maybe_to_client_handle, Succeeded()); + handles[1] = std::move(*maybe_to_client_handle); } };