diff --git a/lldb/include/lldb/Protocol/MCP/Server.h b/lldb/include/lldb/Protocol/MCP/Server.h index 254b7d9680cd8..b674d58159550 100644 --- a/lldb/include/lldb/Protocol/MCP/Server.h +++ b/lldb/include/lldb/Protocol/MCP/Server.h @@ -15,9 +15,12 @@ #include "lldb/Protocol/MCP/Resource.h" #include "lldb/Protocol/MCP/Tool.h" #include "lldb/Protocol/MCP/Transport.h" +#include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Support/Error.h" #include "llvm/Support/JSON.h" +#include "llvm/Support/Signals.h" #include #include #include @@ -25,15 +28,6 @@ namespace lldb_protocol::mcp { -/// Information about this instance of lldb's MCP server for lldb-mcp to use to -/// coordinate connecting an lldb-mcp client. -struct ServerInfo { - std::string connection_uri; - lldb::pid_t pid; -}; -llvm::json::Value toJSON(const ServerInfo &); -bool fromJSON(const llvm::json::Value &, ServerInfo &, llvm::json::Path); - class Server : public MCPTransport::MessageHandler { public: Server(std::string name, std::string version, @@ -95,6 +89,42 @@ class Server : public MCPTransport::MessageHandler { llvm::StringMap m_notification_handlers; }; +class ServerInfoHandle; + +/// Information about this instance of lldb's MCP server for lldb-mcp to use to +/// coordinate connecting an lldb-mcp client. +struct ServerInfo { + std::string connection_uri; + + /// Writes the server info into a unique file in `~/.lldb`. + static llvm::Expected Write(const ServerInfo &); + /// Loads any server info saved in `~/.lldb`. + static llvm::Expected> Load(); +}; +llvm::json::Value toJSON(const ServerInfo &); +bool fromJSON(const llvm::json::Value &, ServerInfo &, llvm::json::Path); + +/// A handle that tracks the server info on disk and cleans up the disk record +/// once it is no longer referenced. +class ServerInfoHandle { +public: + ServerInfoHandle(); + explicit ServerInfoHandle(llvm::StringRef filename); + ~ServerInfoHandle(); + + ServerInfoHandle(ServerInfoHandle &&other); + ServerInfoHandle &operator=(ServerInfoHandle &&other) noexcept; + + /// ServerIinfoHandle is not copyable. + /// @{ + ServerInfoHandle(const ServerInfoHandle &) = delete; + ServerInfoHandle &operator=(const ServerInfoHandle &) = delete; + /// @} + +private: + llvm::SmallString<128> m_filename; +}; + } // namespace lldb_protocol::mcp #endif diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp index 2b004c19e88a6..dc18c8e06803a 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp @@ -113,34 +113,13 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { std::string address = llvm::join(m_listener->GetListeningConnectionURI(), ", "); - FileSpec user_lldb_dir = HostInfo::GetUserLLDBDir(); - - Status error(llvm::sys::fs::create_directory(user_lldb_dir.GetPath())); - if (error.Fail()) - return error.takeError(); - - m_mcp_registry_entry_path = user_lldb_dir.CopyByAppendingPathComponent( - formatv("lldb-mcp-{0}.json", getpid()).str()); - - ServerInfo info; - info.connection_uri = listening_uris[0]; - info.pid = getpid(); - - std::string buf = formatv("{0}", toJSON(info)).str(); - size_t num_bytes = buf.size(); - - const File::OpenOptions flags = File::eOpenOptionWriteOnly | - File::eOpenOptionCanCreate | - File::eOpenOptionTruncate; - llvm::Expected file = - FileSystem::Instance().Open(m_mcp_registry_entry_path, flags, - lldb::eFilePermissionsFileDefault, false); - if (!file) - return file.takeError(); - if (llvm::Error error = (*file)->Write(buf.data(), num_bytes).takeError()) - return error; + ServerInfo info{listening_uris[0]}; + llvm::Expected handle = ServerInfo::Write(info); + if (!handle) + return handle.takeError(); m_running = true; + m_server_info_handle = std::move(*handle); m_listen_handlers = std::move(*handles); m_loop_thread = std::thread([=] { llvm::set_thread_name("protocol-server.mcp"); @@ -158,10 +137,6 @@ llvm::Error ProtocolServerMCP::Stop() { m_running = false; } - if (!m_mcp_registry_entry_path.GetPath().empty()) - FileSystem::Instance().RemoveFile(m_mcp_registry_entry_path); - m_mcp_registry_entry_path.Clear(); - // Stop the main loop. m_loop.AddPendingCallback( [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); @@ -170,5 +145,9 @@ llvm::Error ProtocolServerMCP::Stop() { if (m_loop_thread.joinable()) m_loop_thread.join(); + m_listen_handlers.clear(); + m_server_info_handle = ServerInfoHandle(); + m_instances.clear(); + return llvm::Error::success(); } diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h index 004fa3c2d05a8..0251664a2acc4 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h @@ -48,7 +48,7 @@ class ProtocolServerMCP : public ProtocolServer { bool m_running = false; - FileSpec m_mcp_registry_entry_path; + lldb_protocol::mcp::ServerInfoHandle m_server_info_handle; lldb_private::MainLoop m_loop; std::thread m_loop_thread; std::mutex m_mutex; diff --git a/lldb/source/Protocol/MCP/Server.cpp b/lldb/source/Protocol/MCP/Server.cpp index 0381b7f745e98..f3489c620832f 100644 --- a/lldb/source/Protocol/MCP/Server.cpp +++ b/lldb/source/Protocol/MCP/Server.cpp @@ -7,23 +7,108 @@ //===----------------------------------------------------------------------===// #include "lldb/Protocol/MCP/Server.h" +#include "lldb/Host/File.h" +#include "lldb/Host/FileSystem.h" +#include "lldb/Host/HostInfo.h" +#include "lldb/Host/JSONTransport.h" #include "lldb/Protocol/MCP/MCPError.h" #include "lldb/Protocol/MCP/Protocol.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/Support/FileSystem.h" #include "llvm/Support/JSON.h" +#include "llvm/Support/Signals.h" -using namespace lldb_protocol::mcp; using namespace llvm; +using namespace lldb_private; +using namespace lldb_protocol::mcp; + +ServerInfoHandle::ServerInfoHandle() : ServerInfoHandle("") {} + +ServerInfoHandle::ServerInfoHandle(StringRef filename) : m_filename(filename) { + if (!m_filename.empty()) + sys::RemoveFileOnSignal(m_filename); +} + +ServerInfoHandle::~ServerInfoHandle() { + if (m_filename.empty()) + return; + + sys::fs::remove(m_filename); + sys::DontRemoveFileOnSignal(m_filename); + m_filename.clear(); +} + +ServerInfoHandle::ServerInfoHandle(ServerInfoHandle &&other) + : m_filename(other.m_filename) { + *this = std::move(other); +} -llvm::json::Value lldb_protocol::mcp::toJSON(const ServerInfo &SM) { - return llvm::json::Object{{"connection_uri", SM.connection_uri}, - {"pid", SM.pid}}; +ServerInfoHandle & +ServerInfoHandle::operator=(ServerInfoHandle &&other) noexcept { + m_filename = other.m_filename; + other.m_filename.clear(); + return *this; } -bool lldb_protocol::mcp::fromJSON(const llvm::json::Value &V, ServerInfo &SM, - llvm::json::Path P) { - llvm::json::ObjectMapper O(V, P); - return O && O.map("connection_uri", SM.connection_uri) && - O.map("pid", SM.pid); +json::Value lldb_protocol::mcp::toJSON(const ServerInfo &SM) { + return json::Object{{"connection_uri", SM.connection_uri}}; +} + +bool lldb_protocol::mcp::fromJSON(const json::Value &V, ServerInfo &SM, + json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("connection_uri", SM.connection_uri); +} + +Expected ServerInfo::Write(const ServerInfo &info) { + std::string buf = formatv("{0}", toJSON(info)).str(); + size_t num_bytes = buf.size(); + + FileSpec user_lldb_dir = HostInfo::GetUserLLDBDir(); + + Status error(sys::fs::create_directory(user_lldb_dir.GetPath())); + if (error.Fail()) + return error.takeError(); + + FileSpec mcp_registry_entry_path = user_lldb_dir.CopyByAppendingPathComponent( + formatv("lldb-mcp-{0}.json", getpid()).str()); + + const File::OpenOptions flags = File::eOpenOptionWriteOnly | + File::eOpenOptionCanCreate | + File::eOpenOptionTruncate; + Expected file = + FileSystem::Instance().Open(mcp_registry_entry_path, flags); + if (!file) + return file.takeError(); + if (llvm::Error error = (*file)->Write(buf.data(), num_bytes).takeError()) + return error; + return ServerInfoHandle{mcp_registry_entry_path.GetPath()}; +} + +Expected> ServerInfo::Load() { + namespace path = llvm::sys::path; + FileSpec user_lldb_dir = HostInfo::GetUserLLDBDir(); + FileSystem &fs = FileSystem::Instance(); + std::error_code EC; + vfs::directory_iterator it = fs.DirBegin(user_lldb_dir, EC); + vfs::directory_iterator end; + std::vector infos; + for (; it != end && !EC; it.increment(EC)) { + auto &entry = *it; + auto path = entry.path(); + auto name = path::filename(path); + if (!name.starts_with("lldb-mcp-") || !name.ends_with(".json")) + continue; + + auto buffer = fs.CreateDataBuffer(path); + auto info = json::parse(toStringRef(buffer->GetData())); + if (!info) + return info.takeError(); + + infos.emplace_back(std::move(*info)); + } + + return infos; } Server::Server(std::string name, std::string version, diff --git a/lldb/tools/lldb-mcp/CMakeLists.txt b/lldb/tools/lldb-mcp/CMakeLists.txt index 7fe3301ab3081..5f61a1993cea3 100644 --- a/lldb/tools/lldb-mcp/CMakeLists.txt +++ b/lldb/tools/lldb-mcp/CMakeLists.txt @@ -6,6 +6,7 @@ add_lldb_tool(lldb-mcp Support LINK_LIBS liblldb + lldbInitialization lldbHost lldbProtocolMCP ) diff --git a/lldb/tools/lldb-mcp/lldb-mcp.cpp b/lldb/tools/lldb-mcp/lldb-mcp.cpp index 6c4ebbaa5f230..12545dcf3a3cc 100644 --- a/lldb/tools/lldb-mcp/lldb-mcp.cpp +++ b/lldb/tools/lldb-mcp/lldb-mcp.cpp @@ -10,17 +10,29 @@ #include "lldb/Host/File.h" #include "lldb/Host/MainLoop.h" #include "lldb/Host/MainLoopBase.h" -#include "lldb/Protocol/MCP/Protocol.h" +#include "lldb/Host/Socket.h" +#include "lldb/Initialization/SystemInitializerCommon.h" +#include "lldb/Initialization/SystemLifetimeManager.h" #include "lldb/Protocol/MCP/Server.h" +#include "lldb/Utility/Status.h" +#include "lldb/Utility/UriParser.h" +#include "lldb/lldb-forward.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/Error.h" #include "llvm/Support/InitLLVM.h" +#include "llvm/Support/ManagedStatic.h" #include "llvm/Support/Signals.h" #include "llvm/Support/WithColor.h" +#include +#include #if defined(_WIN32) #include #endif +using namespace llvm; +using namespace lldb; using namespace lldb_protocol::mcp; using lldb_private::File; @@ -28,8 +40,78 @@ using lldb_private::MainLoop; using lldb_private::MainLoopBase; using lldb_private::NativeFile; -static constexpr llvm::StringLiteral kName = "lldb-mcp"; -static constexpr llvm::StringLiteral kVersion = "0.1.0"; +namespace { + +inline void exitWithError(llvm::Error Err, StringRef Prefix = "") { + handleAllErrors(std::move(Err), [&](ErrorInfoBase &Info) { + WithColor::error(errs(), Prefix) << Info.message() << '\n'; + }); + std::exit(EXIT_FAILURE); +} + +constexpr size_t kForwardIOBufferSize = 1024; + +void forwardIO(lldb_private::MainLoopBase &loop, lldb::IOObjectSP &from, + lldb::IOObjectSP &to) { + char buf[kForwardIOBufferSize]; + size_t num_bytes = sizeof(buf); + + if (llvm::Error err = from->Read(buf, num_bytes).takeError()) + exitWithError(std::move(err)); + + // EOF reached. + if (num_bytes == 0) + return loop.RequestTermination(); + + if (llvm::Error err = to->Write(buf, num_bytes).takeError()) + exitWithError(std::move(err)); +} + +void connectAndForwardIO(lldb_private::MainLoop &loop, ServerInfo &info, + IOObjectSP &input_sp, IOObjectSP &output_sp) { + auto uri = lldb_private::URI::Parse(info.connection_uri); + if (!uri) + exitWithError(createStringError("invalid connection_uri")); + + std::optional protocol_and_mode = + lldb_private::Socket::GetProtocolAndMode(uri->scheme); + + lldb_private::Status status; + std::unique_ptr sock = + lldb_private::Socket::Create(protocol_and_mode->first, status); + + if (status.Fail()) + exitWithError(status.takeError()); + + if (uri->port && !uri->hostname.empty()) + status = sock->Connect( + llvm::formatv("[{0}]:{1}", uri->hostname, *uri->port).str()); + else + status = sock->Connect(uri->path); + if (status.Fail()) + exitWithError(status.takeError()); + + IOObjectSP sock_sp = std::move(sock); + auto input_handle = loop.RegisterReadObject( + input_sp, std::bind(forwardIO, std::placeholders::_1, input_sp, sock_sp), + status); + if (status.Fail()) + exitWithError(status.takeError()); + + auto socket_handle = loop.RegisterReadObject( + sock_sp, std::bind(forwardIO, std::placeholders::_1, sock_sp, output_sp), + status); + if (status.Fail()) + exitWithError(status.takeError()); + + status = loop.Run(); + if (status.Fail()) + exitWithError(status.takeError()); +} + +llvm::ManagedStatic g_debugger_lifetime; + +} // namespace int main(int argc, char *argv[]) { llvm::InitLLVM IL(argc, argv, /*InstallPipeSignalExitHandler=*/false); @@ -53,33 +135,42 @@ int main(int argc, char *argv[]) { assert(result); #endif - lldb::IOObjectSP input = std::make_shared( + if (llvm::Error err = g_debugger_lifetime->Initialize( + std::make_unique(nullptr))) + exitWithError(std::move(err)); + + auto cleanup = make_scope_exit([] { g_debugger_lifetime->Terminate(); }); + + IOObjectSP input_sp = std::make_shared( fileno(stdin), File::eOpenOptionReadOnly, NativeFile::Unowned); - lldb::IOObjectSP output = std::make_shared( + IOObjectSP output_sp = std::make_shared( fileno(stdout), File::eOpenOptionWriteOnly, NativeFile::Unowned); - constexpr llvm::StringLiteral client_name = "stdio"; static MainLoop loop; - llvm::sys::SetInterruptFunction([]() { + sys::SetInterruptFunction([]() { loop.AddPendingCallback( [](MainLoopBase &loop) { loop.RequestTermination(); }); }); - auto transport_up = std::make_unique( - input, output, [&](llvm::StringRef message) { - llvm::errs() << formatv("{0}: {1}", client_name, message) << '\n'; - }); + auto existing_servers = ServerInfo::Load(); + + if (!existing_servers) + exitWithError(existing_servers.takeError()); + + // FIXME: Launch `lldb -o 'protocol start MCP'`. + if (existing_servers->empty()) + exitWithError(createStringError("No MCP servers running")); - auto instance_up = std::make_unique( - std::string(kName), std::string(kVersion), std::move(transport_up), loop); + // FIXME: Support selecting a specific server. + if (existing_servers->size() != 1) + exitWithError( + createStringError("To many MCP servers running, picking a specific " + "one is not yet implemented.")); - if (llvm::Error error = instance_up->Run()) { - llvm::logAllUnhandledErrors(std::move(error), llvm::WithColor::error(), - "MCP error: "); - return EXIT_FAILURE; - } + ServerInfo &info = existing_servers->front(); + connectAndForwardIO(loop, info, input_sp, output_sp); return EXIT_SUCCESS; }