From 98bb46732c518a50885962d595c7ee21ccf2044e Mon Sep 17 00:00:00 2001 From: zdevito Date: Thu, 3 Jul 2025 19:46:43 -0700 Subject: [PATCH] [10/n] restore debugger support to mesh_controller moves the debugger support directly into the controller actor so there is no need to shuttle messages back/forth from the client. Differential Revision: [D77771081](https://our.internmc.facebook.com/intern/diff/D77771081/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D77771081/)! [ghstack-poisoned] --- monarch_extension/src/mesh_controller.rs | 83 ++++++++++++++++++- .../monarch_extension/mesh_controller.pyi | 2 - python/monarch/mesh_controller.py | 6 +- python/tests/test_remote_functions.py | 4 +- 4 files changed, 88 insertions(+), 7 deletions(-) diff --git a/monarch_extension/src/mesh_controller.rs b/monarch_extension/src/mesh_controller.rs index 3cef61d81..b8253eea6 100644 --- a/monarch_extension/src/mesh_controller.rs +++ b/monarch_extension/src/mesh_controller.rs @@ -9,6 +9,7 @@ use std::collections::BTreeMap; use std::collections::HashMap; use std::collections::HashSet; +use std::collections::VecDeque; use std::error::Error; use std::fmt::Debug; use std::fmt::Formatter; @@ -21,6 +22,7 @@ use std::sync::atomic::AtomicUsize; use async_trait::async_trait; use hyperactor::Actor; use hyperactor::ActorHandle; +use hyperactor::ActorId; use hyperactor::ActorRef; use hyperactor::Context; use hyperactor::HandleClient; @@ -30,7 +32,6 @@ use hyperactor::PortRef; use hyperactor::cap::CanSend; use hyperactor::mailbox::MailboxSenderError; use hyperactor_mesh::Mesh; -use hyperactor_mesh::ProcMesh; use hyperactor_mesh::actor_mesh::RootActorMesh; use hyperactor_mesh::shared_cell::SharedCell; use hyperactor_mesh::shared_cell::SharedCellRef; @@ -44,6 +45,9 @@ use monarch_messages::controller::ControllerActor; use monarch_messages::controller::ControllerMessage; use monarch_messages::controller::Seq; use monarch_messages::controller::WorkerError; +use monarch_messages::debugger::DebuggerAction; +use monarch_messages::debugger::DebuggerActor; +use monarch_messages::debugger::DebuggerMessage; use monarch_messages::worker::Ref; use monarch_messages::worker::WorkerMessage; use monarch_messages::worker::WorkerParams; @@ -611,12 +615,84 @@ struct MeshControllerActor { workers: Option>>, history: History, id: usize, + debugger_active: Option>, + debugger_paused: VecDeque>, } impl MeshControllerActor { fn workers(&self) -> SharedCellRef> { self.workers.as_ref().unwrap().borrow().unwrap() } + fn handle_debug( + &mut self, + this: &Context, + debugger_actor_id: ActorId, + action: DebuggerAction, + ) -> anyhow::Result<()> { + if matches!(action, DebuggerAction::Paused()) { + self.debugger_paused + .push_back(ActorRef::attest(debugger_actor_id)); + } else { + let debugger_actor = self + .debugger_active + .as_ref() + .ok_or_else(|| anyhow::anyhow!("no active debugger"))?; + if debugger_actor_id != *debugger_actor.actor_id() { + anyhow::bail!("debugger action for wrong actor"); + } + match action { + DebuggerAction::Detach() => { + self.debugger_active = None; + } + DebuggerAction::Read { requested_size } => { + Python::with_gil(|py| { + let read = py + .import("monarch.controller.debugger") + .unwrap() + .getattr("read") + .unwrap(); + let bytes: Vec = + read.call1((requested_size,)).unwrap().extract().unwrap(); + + debugger_actor.send( + this, + DebuggerMessage::Action { + action: DebuggerAction::Write { bytes }, + }, + ) + })?; + } + DebuggerAction::Write { bytes } => { + Python::with_gil(|py| -> Result<(), anyhow::Error> { + let write = py + .import("monarch.controller.debugger") + .unwrap() + .getattr("write") + .unwrap(); + write.call1((String::from_utf8(bytes)?,)).unwrap(); + Ok(()) + })?; + } + _ => { + anyhow::bail!("unexpected action: {:?}", action); + } + } + } + if self.debugger_active.is_none() { + self.debugger_active = self.debugger_paused.pop_front().and_then(|pdb_actor| { + pdb_actor + .send( + this, + DebuggerMessage::Action { + action: DebuggerAction::Attach(), + }, + ) + .map(|_| pdb_actor) + .ok() + }); + } + Ok(()) + } } impl Debug for MeshControllerActor { @@ -642,6 +718,8 @@ impl Actor for MeshControllerActor { workers: None, history: History::new(world_size), id, + debugger_active: None, + debugger_paused: VecDeque::new(), }) } async fn init(&mut self, this: &Instance) -> Result<(), anyhow::Error> { @@ -681,8 +759,7 @@ impl Handler for MeshControllerActor { debugger_actor_id, action, } => { - let dm = crate::client::DebuggerMessage::new(debugger_actor_id.into(), action)?; - panic!("NYI: debugger message handling"); + self.handle_debug(this, debugger_actor_id, action)?; } ControllerMessage::Status { seq, diff --git a/python/monarch/_rust_bindings/monarch_extension/mesh_controller.pyi b/python/monarch/_rust_bindings/monarch_extension/mesh_controller.pyi index dbdba70e7..a256c556a 100644 --- a/python/monarch/_rust_bindings/monarch_extension/mesh_controller.pyi +++ b/python/monarch/_rust_bindings/monarch_extension/mesh_controller.pyi @@ -30,8 +30,6 @@ class _Controller: ranks: Union[NDSlice, List[NDSlice]], msg: NamedTuple, ) -> None: ... - def _debugger_attach(self, debugger_actor_id: ActorId) -> None: ... - def _debugger_write(self, debugger_actor_id: ActorId, data: bytes) -> None: ... def _drain_and_stop( self, ) -> List[client.LogMessage | client.WorkerResponse | client.DebuggerMessage]: ... diff --git a/python/monarch/mesh_controller.py b/python/monarch/mesh_controller.py index 622ddd762..26d0d1106 100644 --- a/python/monarch/mesh_controller.py +++ b/python/monarch/mesh_controller.py @@ -7,6 +7,8 @@ import atexit import logging import os + +import pdb import traceback from collections import deque from logging import Logger @@ -23,7 +25,6 @@ ) import torch.utils._python_dispatch - from monarch._rust_bindings.monarch_extension import client from monarch._rust_bindings.monarch_extension.client import ( # @manual=//monarch/monarch_extension:monarch_extension WorldState, @@ -41,6 +42,8 @@ from monarch.common.stream import StreamRef from monarch.common.tensor import Tensor +from monarch.tensor_worker_main import _set_trace + if TYPE_CHECKING: from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ( ProcMesh as HyProcMesh, @@ -120,6 +123,7 @@ def _initialize_env(worker_point: Point, proc_id: str) -> None: "LOCAL_WORLD_SIZE": str(gpus_per_host), } os.environ.update(process_env) + pdb.set_trace = _set_trace except Exception: traceback.print_exc() raise diff --git a/python/tests/test_remote_functions.py b/python/tests/test_remote_functions.py index 7b1b33983..bf27588ab 100644 --- a/python/tests/test_remote_functions.py +++ b/python/tests/test_remote_functions.py @@ -185,7 +185,9 @@ def local_device_mesh( # out is not counted as a failure, so we set a more restrictive timeout to # ensure we see a hard failure in CI. @pytest.mark.timeout(120) -@pytest.mark.parametrize("backend_type", [BackendType.PY, BackendType.RS]) +@pytest.mark.parametrize( + "backend_type", [BackendType.PY, BackendType.RS, BackendType.MESH] +) class TestRemoteFunctions(RemoteFunctionsTestBase): @classmethod def do_test_reduce_scatter_tensor(cls, backend_type, reduce_op, expected_tensor):