diff --git a/lldb/examples/mcp/README.md b/lldb/examples/mcp/README.md new file mode 100644 index 0000000000000..2b8977c95951e --- /dev/null +++ b/lldb/examples/mcp/README.md @@ -0,0 +1,46 @@ +# lldb-mcp backport + +A backport of the lldb-mcp protocol for older releases of lldb. + +To load the backport use: + +``` +(lldb) command script import --allow-reload server.py +(lldb) start_mcp +``` + +Then you can use the `./lldb-mcp` script in this directory to launch a client +for the running server. + +For example, + +```json +{ + "mcpServers": { + "lldb": { + "command": "/lldb-mcp", + "args": ["--log-file=/tmp/lldb-mcp.log", "--timeout=30.0"] + } + } +} +``` + +## Development + +For getting started with making changes to this backport, use the +[MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) to run the +binary. + +In one terminal, start the lldb server: + +``` +$ lldb +(lldb) command script import --allow-reload server.py +(lldb) start_mcp --log-file=/tmp/lldb-mcp-server.log +``` + +Then launch the inspector to run specific operations. + +```sh +$ npx @modelcontextprotocol/inspector ./lldb-mcp --log-file=/tmp/lldb-mcp.log +``` diff --git a/lldb/examples/mcp/lldb-mcp b/lldb/examples/mcp/lldb-mcp new file mode 100755 index 0000000000000..d91b48be6092a --- /dev/null +++ b/lldb/examples/mcp/lldb-mcp @@ -0,0 +1,12 @@ +#!/bin/sh + +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" + +PYTHONPATH="$(lldb -P)" +export PYTHONPATH + +if [ "$(uname)" == "Darwin" ]; then + exec xcrun python3 $SCRIPT_DIR/lldb-mcp.py $@ +else + exec python3 $SCRIPT_DIR/lldb-mcp.py $@ +fi diff --git a/lldb/examples/mcp/lldb-mcp.py b/lldb/examples/mcp/lldb-mcp.py new file mode 100644 index 0000000000000..808155b179858 --- /dev/null +++ b/lldb/examples/mcp/lldb-mcp.py @@ -0,0 +1,162 @@ +import atexit +import logging +import argparse +import pathlib +import asyncio +import os +import sys +import signal +import transport +import protocol +from typing import Optional + +logger = logging.getLogger("lldb-mcp") + + +class MCPClient(transport.MessageHandler): + initialize = protocol.initialize.invoker + initialized = protocol.initialized.invoker + toolsList = protocol.toolsList.invoker + toolsCall = protocol.toolsCall.invoker + + +def parse(uri: str) -> tuple[str, int]: + assert uri.startswith("connection://") + uri = uri.removeprefix("connection://") + host, port = uri.rsplit(":", maxsplit=1) + if host != "[::1]": + host = host.removeprefix("[").removesuffix("]") + return (host, int(port)) + + +async def test_client(uri: str): + host, port = parse(uri) + print("connecting to", host, port) + reader, writer = await asyncio.open_connection(host, int(port)) + with transport.Transport(reader, writer) as conn: + async with MCPClient(conn) as client: + _ = await client.initialize() + client.initialized() + + tools_list_result = await client.toolsList() + for tool in tools_list_result["tools"]: + print("tool", tool) + + await client.toolsCall( + name="command", + arguments={ + "command": "bt", + "debugger": "lldb://debugger/1", + }, + ) + await client.toolsCall( + name="debugger_list", + arguments=None, + ) + + +async def launchLLDB(log_file: Optional[str] = None): + dir = os.path.dirname(os.path.abspath(__file__)) + server_script = os.path.join(dir, "server.py") + args = [ + "lldb", + "-O", + f"command script import --allow-reload {server_script}", + "-O", + "start_mcp" + " --log-file=" + str(log_file) if log_file else "", + ] + process = await asyncio.subprocess.create_subprocess_exec( + 
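+        # Launch lldb in the background: the two -O one-line commands import
+        # server.py and run start_mcp, so the new server registers itself under
+        # ~/.lldb where protocol.load() can discover it.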
*args, + stdout=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.DEVNULL, + ) + + def shutdown(): + try: + if process.returncode is None: + process.send_signal(signal.SIGHUP) + os.waitpid(process.pid, 0) + except: + pass + + atexit.register(shutdown) + + +async def main() -> None: + parser = argparse.ArgumentParser("lldb-mcp") + parser.add_argument("-l", "--log-file", type=pathlib.Path) + parser.add_argument("-t", "--timeout", type=float, default=30.0) + parser.add_argument("--test", action="store_true") + opts = parser.parse_args() + if opts.log_file or opts.test: + logging.basicConfig( + filename=opts.log_file, + format="%(created)f:%(process)d:%(levelname)s:%(name)s:%(message)s", + level=logging.DEBUG, + ) + logger.info("Loading lldb-mcp server configurations...") + loop = asyncio.get_running_loop() + + launched = False + deadline: float = loop.time() + opts.timeout + servers: list[protocol.ServerInfo] = [] + while not servers and loop.time() < deadline: + logger.info("loading host server details") + servers = protocol.load() + + if not servers and not launched: + launched = True + logger.info("Starting lldb with server loaded...") + await launchLLDB(log_file=opts.log_file) + continue + + if not servers: + logger.info("Waiting for server to start...") + await asyncio.sleep(1.0) + continue + + if len(servers) != 1: + logger.error("to many lldb-mcp servers detected, exiting...") + sys.exit( + "Multiple servers detected, selecting a single server is not yet supported." + ) + + break + + assert servers + + if opts.test: + for server in servers: + await test_client(server["connection_uri"]) + return + + logger.info("Forwarding stdio to first server %r", servers[0]) + try: + server_info = servers[0] + host, port = parse(server_info["connection_uri"]) + cr, cw = await asyncio.open_connection(host, port) + loop = asyncio.get_event_loop() + + def forward(): + buf = sys.stdin.buffer.read(4096) + if not buf: # eof detected + cr.feed_eof() + loop.remove_reader(sys.stdin) + return + logger.info("--> %s", buf.decode().strip()) + cw.write(buf) + + os.set_blocking(sys.stdin.fileno(), False) + loop.add_reader(sys.stdin, forward) + async for f in cr: + logger.info("<-- %s", f.decode().strip()) + sys.stdout.buffer.write(f) + sys.stdout.buffer.flush() + except: + logger.exception("forwarding client failed") + finally: + logger.info("lldb-mcp client shut down") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/lldb/examples/mcp/protocol.py b/lldb/examples/mcp/protocol.py new file mode 100644 index 0000000000000..067b03af7221b --- /dev/null +++ b/lldb/examples/mcp/protocol.py @@ -0,0 +1,218 @@ +import os +import io +import sys +import json +import logging +import ctypes +import ctypes.util +from typing import TypedDict, Any, Literal, Optional +from transport import RequestDescriptor, EventDescriptor + +logger = logging.getLogger(__name__) + +PROC_PIDPATHINFO_MAXSIZE = 4 * 1024 + + +def _is_valid_lldb_process(pid: int) -> bool: + logger.info("checking if process %d is alive and is an lldb process", pid) + try: + # raises ProcessLookupError if pid does not exist. 
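+        # Signal 0 only performs the existence/permission check; it does not
+        # deliver a signal to the process.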
+ os.kill(pid, 0) + if sys.platform == "darwin": + libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True) + assert libc + proc_pidpath = libc.proc_pidpath + proc_pidpath.restype = ctypes.c_int + proc_pidpath.argtypes = [ctypes.c_int, ctypes.c_void_p, ctypes.c_uint32] + buf = ctypes.create_string_buffer(PROC_PIDPATHINFO_MAXSIZE) + if proc_pidpath(pid, buf, PROC_PIDPATHINFO_MAXSIZE) <= 0: + raise OSError(ctypes.get_errno()) + path = bytes(buf.value).decode() + logger.info("path=%r", path) + if "lldb" not in os.path.basename(path): + logger.info("pid %d is invalid", pid) + return False + logger.info("pid %d is valid", pid) + return True + except ProcessLookupError: + logger.info("pid %d is not alive", pid) + return False + except: + logger.exception("failed to validate pid %d", pid) + return False + + +class ServerInfo(TypedDict): + connection_uri: str + + +def load() -> list[ServerInfo]: + dir = os.path.expanduser("~/.lldb") + contents = os.listdir(dir) + server_infos = [] + for file in contents: + if not file.startswith("lldb-mcp-") or not file.endswith(".json"): + continue + + filename = os.path.join(dir, file) + pid = int(file.removeprefix("lldb-mcp-").removesuffix(".json")) + if not _is_valid_lldb_process(pid): + # Process is dead, clean up the stale file. + os.remove(filename) + continue + + with open(filename) as f: + server_infos.append(json.load(f)) + return server_infos + + +def cleanup(): + server_info_config = os.path.expanduser(f"~/.lldb/lldb-mcp-{os.getpid()}.json") + if os.path.exists(server_info_config): + os.remove(server_info_config) + + +def save(uri: str): + server_info: ServerInfo = {"connection_uri": uri} + with open(os.path.expanduser(f"~/.lldb/lldb-mcp-{os.getpid()}.json"), "w+") as f: + json.dump(server_info, f) + + +class URI: + scheme: str + host: Optional[str] + port: Optional[int] + path: str + + def __init__( + self, + *, + scheme="", + host: Optional[str] = None, + port: Optional[int] = None, + path="", + ): + self.scheme = scheme + self.host = host + self.port = port + self.path = path + + @classmethod + def parse(cls, input: str) -> "URI": + assert ":" in input + uri = URI() + uri.scheme, rest = input.split(":", maxsplit=1) + assert uri.scheme.isascii() + if rest.startswith("//"): + rest = rest.removeprefix("//") + if "/" in rest: + uri.host, rest = rest.split("/", maxsplit=1) + else: + uri.host = rest + rest = "" + uri.path = rest + if uri.host is not None and ":" in uri.host: + uri.host, raw_port = uri.host.rsplit(":", maxsplit=1) + assert raw_port.isdigit() + uri.port = int(raw_port) + return uri + + def append(self, path: str) -> "URI": + return URI( + scheme=self.scheme, + host=self.host, + port=self.port, + path=os.path.join(self.path, path), + ) + + def __str__(self): + os = io.StringIO() + os.write(self.scheme) + os.write(":") + if self.host or self.port: + os.write("//") + if self.host: + os.write(self.host) + if self.port: + os.write(":") + os.write(self.port) + if self.path and self.path != "/": + os.write(self.path) + return os.getvalue() + + +class ImplementationVersion(TypedDict): + name: str + version: str + + +class Tool(TypedDict): + name: str + title: str + description: str + inputSchema: dict + + +class Resource(TypedDict): + uri: str + name: str + + +class ListToolsResult(TypedDict): + tools: list[Tool] + + +class CallToolParams(TypedDict): + name: str + arguments: Any + + +class TextContent(TypedDict): + type: Literal["text"] + text: str + + +class CallToolResult(TypedDict): + content: list[TextContent] + isError: bool + + 
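+# For reference, an illustrative "tools/call" exchange using the shapes above
+# (newline-delimited JSON-RPC 2.0 as framed by transport.py; the values are
+# made up rather than captured from a real session):
+#
+#   --> {"jsonrpc": "2.0", "id": 2, "method": "tools/call",
+#        "params": {"name": "command", "arguments": {"command": "bt"}}}
+#   <-- {"jsonrpc": "2.0", "id": 2,
+#        "result": {"content": [{"type": "text", "text": "..."}],
+#                   "isError": false}}
+
+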
+class ComponentCapabilities(TypedDict, total=False): + listChanged: bool + subscribe: bool + + +class ServerCapabilities(TypedDict): + tools: ComponentCapabilities + + +class InitializeParams(TypedDict): + capabilities: dict + clientInfo: ImplementationVersion + protocolVersion: str + + +class InitializeResult(TypedDict): + capabilities: ServerCapabilities + protocolVersion: str + serverInfo: ImplementationVersion + + +initialize = RequestDescriptor[InitializeParams, InitializeResult]( + "initialize", + defaults={ + "protocolVersion": "2024-11-05", + "clientInfo": { + "name": "lldb-mcp", + "version": "0.0.1", + }, + "capabilities": { + "roots": {"listChanged": True}, + "sampling": {}, + "elicitation": {}, + }, + }, +) +initialized = EventDescriptor[None](name="initialized") +toolsList = RequestDescriptor[None, ListToolsResult](name="tools/list") +toolsCall = RequestDescriptor[CallToolParams, CallToolResult](name="tools/call") diff --git a/lldb/examples/mcp/server.py b/lldb/examples/mcp/server.py new file mode 100644 index 0000000000000..af05800a87e8f --- /dev/null +++ b/lldb/examples/mcp/server.py @@ -0,0 +1,277 @@ +""" +An implementation of the lldb-mcp server. +""" + +from typing import Any, Optional +import argparse +import asyncio +import lldb +import logging +import protocol +import queue +import pathlib +import traceback +import shlex +import threading +import transport + +logger = logging.getLogger(__name__) + + +SCHEME = "lldb-mcp" +DEBUGGER_HOST = "debugger" +BASE_DEBUGGER_URI = protocol.URI(scheme=SCHEME, host=DEBUGGER_HOST, path="/") + + +class Tool: + name: str + title: str + description: str + inputSchema: dict + + def to_protocol(self) -> protocol.Tool: + return { + "name": self.name, + "title": self.title, + "description": self.description, + "inputSchema": self.inputSchema, + } + + async def call(self, **kwargs) -> protocol.CallToolResult: + assert False, "Implement in a subclass." + + +class CommandTool(Tool): + name = "command" + title = "LLDB Command" + description = "Evaluates an lldb command." + inputSchema = { + "type": "object", + "properties": { + "debugger": { + "type": "string", + "description": "The debugger ID or URI to a specific debug session. 
If not specified, the first debugger will be used.", + }, + "command": { + "type": "string", + "description": "An lldb command to run.", + }, + }, + } + + async def call( + self, *, command: Optional[str] = None, debugger: Optional[str] = None, **kwargs + ) -> protocol.CallToolResult: + if debugger: + if debugger.isdigit(): + id = int(debugger) + else: + logger.info("Parsing %s", debugger) + uri = protocol.URI.parse(debugger) + logger.info("Parsed URI: %s", uri) + assert uri.scheme == SCHEME + assert uri.host == DEBUGGER_HOST + raw_id = uri.path.removeprefix("/") + assert raw_id.isdigit() + id = int(raw_id) + dbg_inst = lldb.SBDebugger.FindDebuggerWithID(id) + else: + for i in range(100): + dbg_inst = lldb.SBDebugger.FindDebuggerWithID(i) + if dbg_inst.IsValid(): + break + assert dbg_inst.IsValid() + result = lldb.SBCommandReturnObject() + dbg_inst.GetCommandInterpreter().HandleCommand(command, result) + contents: list[protocol.TextContent] = [] + if result.GetOutputSize(): + contents.append({"type": "text", "text": str(result.GetOutput())}) + if result.GetErrorSize(): + contents.append({"type": "text", "text": str(result.GetError())}) + return { + "content": contents, + "isError": not result.Succeeded(), + } + + +class DebuggerList(Tool): + name = "debugger_list" + title = "List Debuggers" + description = "List debuggers associated with this server." + inputSchema = {"type": "object"} + + async def call(self, **_kwargs) -> protocol.CallToolResult: + out = "" + + for i in range(100): + debugger = lldb.SBDebugger.FindDebuggerWithID(i) + if debugger.IsValid(): + uri = BASE_DEBUGGER_URI.append(str(i)) + out += f"- {uri}\n" + + return { + "content": [ + {"type": "text", "text": out}, + ], + "isError": False, + } + + +class MCPServer(transport.MessageHandler): + tools: dict[str, Tool] + + def __init__( + self, transport: transport.Transport, tools=[CommandTool(), DebuggerList()] + ): + super().__init__(transport) + self.tools = {tool.name: tool for tool in tools} + + @protocol.initialize.handler() + async def initialize( + self, **params: protocol.InitializeParams + ) -> protocol.InitializeResult: + return protocol.InitializeResult( + capabilities={"tools": {"listChanged": True}}, + protocolVersion="2024-11-05", + serverInfo={"name": "lldb-mcp", "version": "0.0.1"}, + ) + + @protocol.initialized.handler() + def initialized(self): + print("Client initialized...") + + @protocol.toolsList.handler() + async def listTools(self) -> protocol.ListToolsResult: + return {"tools": [tool.to_protocol() for tool in self.tools.values()]} + + @protocol.toolsCall.handler() + async def callTool( + self, name: str, arguments: Optional[Any] = None + ) -> protocol.CallToolResult: + tool = self.tools[name] + if arguments is None: + arguments = {} + return await tool.call(**arguments) + + +server: Optional[asyncio.AbstractServer] = None + + +def get_parser(): + parser = argparse.ArgumentParser("lldb-mcp") + parser.add_argument("-l", "--log-file", type=pathlib.Path) + parser.add_argument("-t", "--timeout", type=float, default=30.0) + parser.add_argument("connection", nargs="?", default="listen://[127.0.0.1]:0") + return parser + + +async def run(opts: argparse.Namespace, notify: Optional[queue.Queue] = None): + global server + conn: str = opts.connection + assert conn.startswith("listen://"), "Invalid connection specifier" + hostname, port = conn.removeprefix("listen://").split(":") + hostname = hostname.removeprefix("[").removesuffix("]") + + logging.basicConfig(filename=opts.log_file, level=logging.DEBUG, 
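+        # force=True removes any handlers already attached to the root logger,
+        # so this configuration takes effect even if logging was configured
+        # earlier in the same session.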
force=True) + + server = await asyncio.start_server(MCPServer.acceptClient, hostname, int(port)) + addrs = ", ".join(str(sock.getsockname()) for sock in server.sockets) + if notify: + notify.put(addrs) + else: + print(f"Serving on {addrs}") + + sock_name = server.sockets[0].getsockname() + (h, p) = sock_name[0], sock_name[1] + protocol.save(f"connection://[{h}]:{p}") + + async with server: + await server.serve_forever() + + +# A registration count, if this module is loaded for multiple lldb.SBDebugger +# instances then we should only stop the global server if all registrations have +# been removed. This could happen with lldb-rpc-server or lldb-dap in server +# mode. +registration_count = 0 + + +def stop(): + """Stop the server, if one exists.""" + global server + protocol.cleanup() + if not server: + return + server.close() # Stop accepting new connections + server = None + + +class CommandStart: + # The CommandStart is being used to track when the interpreter exits. lldb + # does not call `Py_Finalize()`, so `atexit` calls are never invoked. In + # order to ensure we shutdown the server and clean up the server info + # records we use the `__del__` method to trigger the clean up as a best + # effort attempt at a clean shutdown. + def __init__(self, debugger, internal_dict): + global registration_count + registration_count += 1 + + def __del__(self): + global registration_count + registration_count -= 1 + if registration_count == 0: + stop() + + def __call__(self, debugger, command, exe_ctx, result): + """Start an MCP server in a background thread.""" + global server + + if server is not None: + print("Server already running.", file=result) + return + + command_args = shlex.split(command) + opts = get_parser().parse_args(command_args) + + print("Starting LLDB MCP Server...", file=result) + + notify = queue.Queue() + + def start_server(): + asyncio.run(run(opts, notify)) + + thr = threading.Thread(target=start_server) + thr.start() + + addrs = notify.get() + print(f"Serving on {addrs}", file=result) + result.SetStatus(lldb.eReturnStatusSuccessFinishNoResult) + + +def lldb_stop(debugger, command, exe_ctx, result, internal_dict): + """Stop an MCP server.""" + global server + try: + if server is None: + print("Server is stopped.", file=result) + result.SetStatus(lldb.eReturnStatusSuccessFinishNoResult) + return + + print("Server stopping...", file=result) + stop() + print("Server stopped.", file=result) + + result.SetStatus(lldb.eReturnStatusSuccessFinishNoResult) + except: + logging.exception("failed to stop MCP server") + traceback.print_exc(file=result) + result.SetStatus(lldb.eReturnStatusFailed) + + +def __lldb_init_module( + debugger: lldb.SBDebugger, + internal_dict: dict[Any, Any], +) -> None: + debugger.HandleCommand("command script add -o -c server.CommandStart start_mcp") + debugger.HandleCommand("command script add -o -f server.lldb_stop stop_mcp") + print("Registered command 'start_mcp' and 'stop_mcp'.") diff --git a/lldb/examples/mcp/transport.py b/lldb/examples/mcp/transport.py new file mode 100644 index 0000000000000..1fdb3f4d494b2 --- /dev/null +++ b/lldb/examples/mcp/transport.py @@ -0,0 +1,337 @@ +import asyncio +import dataclasses +import enum +import functools +import json +import logging +import pprint +import sys +import traceback +from typing import ( + Any, + Awaitable, + Generic, + TypeVar, + Callable, + Union, + Optional, +) + +logger = logging.getLogger(__name__) + + +@enum.unique +class MessageType(enum.Enum): + REQ = enum.auto() + RESP = enum.auto() + NOTE = 
enum.auto() + + +@dataclasses.dataclass(frozen=True, repr=False) +class Message: + """Wrapper around the JSON payload of a MCP message.""" + + payload: dict[str, Any] + + @classmethod + def from_dict(cls, payload: dict[str, Any]) -> "Message": + # Ensure the jsonrpc field is always set. + payload["jsonrpc"] = "2.0" + return cls(payload=payload) + + @classmethod + def decode(cls, bytes: Union[str, bytes, bytearray]) -> "Message": + # Ensure the jsonrpc field is always set. + return cls(payload=json.loads(bytes)) + + def __str__(self): + return json.dumps(self.payload, indent=None, separators=(",", ":")) + + def __repr__(self): + return "{}: {}".format( + self.message_type.name.title(), + pprint.pformat(self.payload, sort_dicts=False), + ) + + @functools.cached_property + def message_type(self) -> MessageType: + if "id" in self.payload and "method" in self.payload: + return MessageType.REQ + elif "id" in self.payload: + return MessageType.RESP + elif "method" in self.payload: + return MessageType.NOTE + assert False, f"Unknown message type: {self.payload}" + + def encode(self) -> bytes: + msg = json.dumps(self.payload, indent=None, separators=(",", ":")) + return f"{msg}\n".encode() + + def matches(self, other: Optional["Message"]) -> bool: + """Returns true iff other is a subset of this message.""" + if not other: + return True + + # The other payload must be a subset of this payload, meaning if we were to + # add its payload to ours, the payload is the same. + return self.payload | other.payload == self.payload + + # Various typed wrappers around self.payload['field'] + + @property + def method(self) -> str: + return self.payload["method"] + + @property + def id(self) -> int: + return self.payload["id"] + + @property + def params(self) -> dict[str, Any]: + return self.payload.get("params", {}) + + @property + def result(self) -> dict[str, Any]: + return self.payload.get("result", {}) + + @property + def error(self) -> dict[str, Any]: + return self.payload.get("error", {}) + + @property + def success(self) -> bool: + return not hasattr(self.payload, "error") + + +class Transport: + r: asyncio.StreamReader + w: asyncio.StreamWriter + + def __init__(self, r: asyncio.StreamReader, w: asyncio.StreamWriter): + self.r = r + self.w = w + + def write(self, message: Message): + logger.info("--> %s", message) + self.w.write(message.encode()) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.w.close() + + def __aiter__(self): + return self + + async def __anext__(self): + line = await self.r.readline() + if line == b"": + raise StopAsyncIteration + return Message.decode(line) + + +class Invoker: + name: str + message_type: MessageType + defaults: dict + transport: Optional[Transport] = None + handler: Optional["MessageHandler"] = None + + def __init__(self, name, message_type: MessageType, defaults: dict = {}): + self.name = name + self.message_type = message_type + self.defaults = defaults + + def __call__(self, **kwargs): + assert self.transport and self.handler + if self.message_type == MessageType.REQ: + return self.handler.request(self.name, params=self.defaults | kwargs) + elif self.message_type == MessageType.NOTE: + return self.handler.event(self.name, params=self.defaults | kwargs) + + +class Handler: + name: str + message_type: MessageType + + def __init__(self, name: str, message_type: MessageType): + self.name = name + self.message_type = message_type + + def __call__(self): + def wrap(fn): + return RequestWrapper(self.name, fn) + + return 
wrap + + +Params = TypeVar("Params", bound=dict) +Result = TypeVar("Result") + + +class EventDescriptor(Generic[Params]): + invoker: Callable[Params, None] + handler: Callable[Params, None] + + def __init__(self, name: str): + self.name = name + + self.invoker = Invoker(name=name, message_type=MessageType.NOTE) + self.handler = Handler(name, MessageType.NOTE) + + +class RequestDescriptor(Generic[Params, Result]): + invoker: Callable[Params, Awaitable[Result]] + handler: Callable[Params, Awaitable[Result]] + + def __init__(self, name: str, defaults: Params = {}): + self.name = name + + self.invoker = Invoker(name, MessageType.REQ, defaults) + self.handler = Handler(name, MessageType.REQ) + + +class RequestWrapper: + name: str + fn: Callable + handler: "MessageHandler" + + def __init__(self, name, fn): + self.name = name + self.fn = fn + + def __call__(self, *args, **kwargs): + assert self.handler is not None + _ = kwargs.pop("_meta", None) + return self.fn(self.handler, *args, **kwargs) + + +class MessageHandler: + seq: int = 0 + handlers: dict[str, RequestWrapper] = {} + invokers: dict[str, Invoker] = {} + inflight: dict[int, asyncio.Future] = {} + transport: Transport + + def __init_subclass__(cls): + super().__init_subclass__() + for i in dir(cls): + attr = getattr(cls, i) + if isinstance(attr, RequestWrapper): + cls.handlers[attr.name] = attr + if isinstance(attr, Invoker): + cls.invokers[attr.name] = attr + + def __init__(self, transport: Transport): + self.transport = transport + for invoker in self.invokers.values(): + invoker.transport = transport + invoker.handler = self + for handlers in self.handlers.values(): + handlers.handler = self + + _handler: Optional[asyncio.Task] = None + + async def __aenter__(self): + self._handler = asyncio.create_task(self.run()) + return self + + async def run(self): + async for message in self.transport: + logger.info("<-- %s", message) + if message.message_type == MessageType.REQ: + handler = self.handlers.get(message.method) + if not handler: + self.transport.write( + Message.from_dict( + { + "id": message.id, + "error": { + "code": -32601, + "message": "Method not found", + }, + } + ) + ) + continue + try: + result = await handler(**message.params) + self.transport.write( + Message.from_dict( + { + "id": message.id, + "result": result, + } + ) + ) + except Exception as e: + print("Internal error:", file=sys.stderr) + traceback.print_exc(file=sys.stderr) + self.transport.write( + Message.from_dict( + { + "id": message.id, + "error": { + "code": -32603, + "message": "Internal error", + }, + } + ) + ) + elif message.message_type == MessageType.RESP: + future = self.inflight.pop(message.id, None) + if not future: + continue + future.set_result(message.result) + elif message.message_type == MessageType.NOTE: + fn = self.handlers.get(message.method) + if fn: + fn(**message.params) + else: + logger.info("no handler for %s", message.method) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if self._handler: + self._handler.cancel() + try: + await self._handler + except asyncio.CancelledError: + pass + self._handler = None + + async def request(self, name: str, params: dict): + self.seq += 1 + msg = Message.from_dict( + { + "id": self.seq, + "method": name, + "params": params, + } + ) + resp_future = asyncio.get_running_loop().create_future() + self.inflight[self.seq] = resp_future + self.transport.write(msg) + return await resp_future + + def event(self, name: str, params: dict): + msg = Message.from_dict( + { + "method": name, + "params": 
params, + } + ) + self.transport.write(msg) + + @classmethod + async def acceptClient( + cls, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ): + try: + with Transport(reader, writer) as client: + server = cls(client) + await server.run() + except: + logger.exception("mcp client failed", exc_info=True)