From 30e70dafd713c2410eed90591bd4c340db1a0380 Mon Sep 17 00:00:00 2001 From: zdevito Date: Wed, 18 Jun 2025 12:12:46 -0700 Subject: [PATCH] [9/n] Unify Tensor engine's future with ActorFuture This replaces the "get_next_message()" polling that tensor engine client does with ActorFutures and Ports. The major improvement is that the futures between actors and tensor engine are now the same, meaning they can be used together to wait for different actions. This re-introduces a true controller actor that is reponsible for ensuring results are finished and that some things do not have an error before sending results to the ports. The client is now no longer responsible for polling for anything, the workers and controller will make progress on their own. The controller actor will directly send result values to Port given to it when it is ready. `fetch_shard` family is now implemented by creating a Port and returning port.recv() as the future. To unify further, the next PR will implement all the adverbs for free-standing remote functions by implementing `monarch.actor_mesh.send()` for free functions, and re-implement fetch_shard family of functions in terms of those primitives. This is possible because the controller can now route multiple results for the same seq to the result port. Differential Revision: [D76918699](https://our.internmc.facebook.com/intern/diff/D76918699/) [ghstack-poisoned] --- monarch_extension/Cargo.toml | 1 + monarch_extension/src/mesh_controller.rs | 684 +++++++++++------- monarch_hyperactor/src/actor.rs | 13 +- monarch_messages/src/worker.rs | 26 +- monarch_tensor_worker/src/lib.rs | 4 + monarch_tensor_worker/src/stream.rs | 392 ++++++---- .../monarch_extension/mesh_controller.pyi | 17 +- python/monarch/actor_mesh.py | 32 +- python/monarch/common/pickle_flatten.py | 6 +- python/monarch/common/remote.py | 10 +- python/monarch/mesh_controller.py | 130 +++- python/tests/test_python_actors.py | 9 + 12 files changed, 862 insertions(+), 462 deletions(-) diff --git a/monarch_extension/Cargo.toml b/monarch_extension/Cargo.toml index a8ed29df9..67aa95ff7 100644 --- a/monarch_extension/Cargo.toml +++ b/monarch_extension/Cargo.toml @@ -15,6 +15,7 @@ crate-type = ["cdylib"] [dependencies] anyhow = "1.0.95" +async-trait = "0.1.86" clap = { version = "4.5.38", features = ["derive", "env", "string", "unicode", "wrap_help"] } controller = { version = "0.0.0", path = "../controller", optional = true } hyperactor = { version = "0.0.0", path = "../hyperactor" } diff --git a/monarch_extension/src/mesh_controller.rs b/monarch_extension/src/mesh_controller.rs index 3cab9f030..cd127f818 100644 --- a/monarch_extension/src/mesh_controller.rs +++ b/monarch_extension/src/mesh_controller.rs @@ -9,31 +9,36 @@ use std::collections::BTreeMap; use std::collections::HashMap; use std::collections::HashSet; -use std::collections::VecDeque; -use std::iter::repeat_n; +use std::error::Error; +use std::fmt::Debug; +use std::fmt::Formatter; use std::sync; use std::sync::Arc; use std::sync::atomic; use std::sync::atomic::AtomicUsize; +use async_trait::async_trait; +use hyperactor::Actor; +use hyperactor::ActorHandle; use hyperactor::ActorRef; -use hyperactor::data::Serialized; -use hyperactor_mesh::actor_mesh::ActorMesh; +use hyperactor::HandleClient; +use hyperactor::Handler; +use hyperactor::Instance; +use hyperactor::PortRef; +use hyperactor::cap::CanSend; +use hyperactor::mailbox::MailboxSenderError; +use hyperactor_mesh::ProcMesh; use hyperactor_mesh::actor_mesh::RootActorMesh; use hyperactor_mesh::proc_mesh::SharedSpawnable; +use monarch_hyperactor::actor::PythonMessage; +use monarch_hyperactor::mailbox::PyPortId; use monarch_hyperactor::ndslice::PySlice; -use monarch_hyperactor::proc::InstanceWrapper; -use monarch_hyperactor::proc::PyActorId; -use monarch_hyperactor::proc::PyProc; use monarch_hyperactor::proc_mesh::PyProcMesh; use monarch_hyperactor::runtime::signal_safe_block_on; -use monarch_messages::client::Exception; use monarch_messages::controller::ControllerActor; use monarch_messages::controller::ControllerMessage; use monarch_messages::controller::Seq; -use monarch_messages::debugger::DebuggerAction; -use monarch_messages::debugger::DebuggerActor; -use monarch_messages::debugger::DebuggerMessage; +use monarch_messages::controller::WorkerError; use monarch_messages::worker::Ref; use monarch_messages::worker::WorkerMessage; use monarch_messages::worker::WorkerParams; @@ -46,121 +51,36 @@ use tokio::sync::Mutex; use crate::convert::convert; +pub(crate) fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> { + module.add_class::<_Controller>()?; + Ok(()) +} + +/// The rust-side implementation of monarch.mesh_controller.Controller +/// It exports the API that interacts with the controller actor (MeshControllerActor) #[pyclass( subclass, module = "monarch._rust_bindings.monarch_extension.mesh_controller" )] struct _Controller { - controller_instance: Arc>>, - workers: RootActorMesh<'static, WorkerActor>, - pending_messages: VecDeque, - history: History, + controller_handle: Arc>>, + all_ranks: Slice, } -impl _Controller { - fn add_responses( - &mut self, - py: Python<'_>, - responses: Vec<( - monarch_messages::controller::Seq, - Option>, - )>, - ) -> PyResult<()> { - for (seq, response) in responses { - let message = crate::client::WorkerResponse::new(seq, response); - self.pending_messages.push_back(message.into_py(py)); - } - Ok(()) - } - fn fill_messages<'py>(&mut self, py: Python<'py>, timeout_msec: Option) -> PyResult<()> { - let instance = self.controller_instance.clone(); - let result = signal_safe_block_on(py, async move { - instance.lock().await.next_message(timeout_msec).await - })??; - result.map(|m| self.add_message(m)).transpose()?; - Ok(()) - } +static NEXT_ID: AtomicUsize = AtomicUsize::new(0); - fn add_message(&mut self, message: ControllerMessage) -> PyResult<()> { - Python::with_gil(|py| -> PyResult<()> { - match message { - ControllerMessage::DebuggerMessage { - debugger_actor_id, - action, - } => { - let dm = crate::client::DebuggerMessage::new(debugger_actor_id.into(), action)? - .into_py(py); - self.pending_messages.push_back(dm); - } - ControllerMessage::Status { - seq, - worker_actor_id, - controller: false, - } => { - let rank = worker_actor_id.rank(); - let responses = self.history.rank_completed(rank, seq); - self.add_responses(py, responses)?; - } - ControllerMessage::RemoteFunctionFailed { seq, error } => { - let responses = self - .history - .propagate_exception(seq, Exception::Error(seq, seq, error)); - self.add_responses(py, responses)?; - } - ControllerMessage::FetchResult { - seq, - value: Ok(value), - } => { - self.history.set_result(seq, value); - } - ControllerMessage::FetchResult { - seq, - value: Err(error), - } => { - let responses = self - .history - .propagate_exception(seq, Exception::Error(seq, seq, error)); - self.add_responses(py, responses)?; - } - message => { - panic!("unexpected message: {:?}", message); - } - }; - Ok(()) - }) - } - fn send_slice(&mut self, slice: Slice, message: WorkerMessage) -> PyResult<()> { - self.workers - .cast_slices(vec![slice], message) - .map_err(|err| PyErr::new::(err.to_string())) - // let shape = Shape::new( - // (0..slice.sizes().len()).map(|i| format!("d{i}")).collect(), - // slice, - // ) - // .unwrap(); - // println!("SENDING TO {:?} {:?}", &shape, &message); - // let worker_slice = SlicedActorMesh::new(&self.workers, shape); - // worker_slice - // .cast(ndslice::Selection::True, message) - // .map_err(|err| PyErr::new::(err.to_string())) - } +fn to_py_error(e: T) -> PyErr +where + T: Error, +{ + return PyErr::new::(e.to_string()); } -static NEXT_ID: AtomicUsize = AtomicUsize::new(0); - #[pymethods] impl _Controller { #[new] fn new(py: Python, py_proc_mesh: &PyProcMesh) -> PyResult { - let proc_mesh = py_proc_mesh.inner.as_ref(); - let id = NEXT_ID.fetch_add(1, atomic::Ordering::Relaxed); - let controller_instance: InstanceWrapper = InstanceWrapper::new( - &PyProc::new_from_proc(proc_mesh.client_proc().clone()), - &format!("tensor_engine_controller_{}", id), - )?; - - let controller_actor_ref = - ActorRef::::attest(controller_instance.actor_id().clone()); + let proc_mesh = py_proc_mesh.inner.clone(); let slice = proc_mesh.shape().slice(); if !slice.is_contiguous() || slice.offset() != 0 { @@ -168,125 +88,106 @@ impl _Controller { "NYI: proc mesh for workers must be contiguous and start at offset 0", )); } - let world_size = slice.len(); - let param = WorkerParams { - world_size, - // Rank assignment is consistent with proc indices. - rank: 0, - device_index: Some(0), - controller_actor: controller_actor_ref, - }; + let all_ranks = proc_mesh.shape().slice().clone(); + let id = NEXT_ID.fetch_add(1, atomic::Ordering::Relaxed); - let py_proc_mesh = Arc::clone(&py_proc_mesh.inner); - let workers: anyhow::Result> = + let controller_handle: Arc>> = signal_safe_block_on(py, async move { - let workers = py_proc_mesh - .spawn(&format!("tensor_engine_workers_{}", id), ¶m) + let controller_handle = proc_mesh + .client_proc() + .spawn( + &format!("tensor_engine_controller_{}", id), + MeshControllerActorParams { + proc_mesh: proc_mesh.clone(), + id, + }, + ) .await?; - //workers.cast(ndslice::Selection::True, )?; - workers.cast_slices( - vec![py_proc_mesh.shape().slice().clone()], - AssignRankMessage::AssignRank(), - )?; - Ok(workers) - })?; + let r: Result>>, anyhow::Error> = + Ok(Arc::new(Mutex::new(controller_handle))); + r + })??; + Ok(Self { - workers: workers?, - controller_instance: Arc::new(Mutex::new(controller_instance)), - pending_messages: VecDeque::new(), - history: History::new(world_size), + controller_handle, + all_ranks, }) } + #[pyo3(signature = (seq, defs, uses, response_port, tracebacks))] fn node<'py>( &mut self, seq: u64, defs: Bound<'py, PyAny>, uses: Bound<'py, PyAny>, + response_port: Option<(PyPortId, PySlice)>, + tracebacks: Py, ) -> PyResult<()> { - let failures = self.history.add_invocation( - seq.into(), - uses.try_iter()? + let response_port: Option = response_port.map(|(port, ranks)| PortInfo { + port: PortRef::attest(port.into()), + ranks: ranks.into(), + }); + let msg = ClientToControllerMessage::Node { + seq: seq.into(), + defs: defs + .try_iter()? .map(|x| Ref::from_py_object(&x?)) .collect::>>()?, - defs.try_iter()? + uses: uses + .try_iter()? .map(|x| Ref::from_py_object(&x?)) .collect::>>()?, - ); - self.add_responses(defs.py(), failures)?; - Ok(()) + tracebacks, + response_port, + }; + self.controller_handle + .blocking_lock() + .send(msg) + .map_err(to_py_error) + } + + fn drop_refs(&mut self, refs: Vec) -> PyResult<()> { + self.controller_handle + .blocking_lock() + .send(ClientToControllerMessage::DropRefs { refs }) + .map_err(to_py_error) } - fn drop_refs(&mut self, refs: Vec) { - self.history.drop_refs(refs); + fn exit(&mut self, seq: Seq) -> PyResult<()> { + self.controller_handle + .blocking_lock() + .send(ClientToControllerMessage::Exit { seq }) + .map_err(to_py_error) } fn send<'py>(&mut self, ranks: Bound<'py, PyAny>, message: Bound<'py, PyAny>) -> PyResult<()> { - let message: WorkerMessage = convert(message)?; - if let Ok(slice) = ranks.extract::() { - self.send_slice(slice.into(), message)?; + let slices = if let Ok(slice) = ranks.extract::() { + vec![slice.into()] } else { let slices = ranks.extract::>()?; - for (slice, message) in slices.iter().zip(repeat_n(message, slices.len())) { - self.send_slice(slice.into(), message)?; - } + slices.iter().map(|x| x.into()).collect::>() }; - Ok(()) - } - - #[pyo3(signature = (*, timeout_msec = None))] - fn _get_next_message<'py>( - &mut self, - py: Python<'py>, - timeout_msec: Option, - ) -> PyResult> { - if self.pending_messages.is_empty() { - self.fill_messages(py, timeout_msec)?; - } - Ok(self.pending_messages.pop_front()) - } - - fn _debugger_attach(&mut self, pdb_actor: PyActorId) -> PyResult<()> { - let pdb_actor: ActorRef = ActorRef::attest(pdb_actor.into()); - pdb_actor - .send( - self.controller_instance.blocking_lock().mailbox(), - DebuggerMessage::Action { - action: DebuggerAction::Attach(), - }, - ) - .map_err(|err| PyErr::new::(err.to_string()))?; - Ok(()) - } - - fn _debugger_write(&mut self, pdb_actor: PyActorId, bytes: Vec) -> PyResult<()> { - let pdb_actor: ActorRef = ActorRef::attest(pdb_actor.into()); - pdb_actor - .send( - self.controller_instance.blocking_lock().mailbox(), - DebuggerMessage::Action { - action: DebuggerAction::Write { bytes }, - }, - ) - .map_err(|err| PyErr::new::(err.to_string()))?; - Ok(()) + let message: WorkerMessage = convert(message)?; + self.controller_handle + .blocking_lock() + .send(ClientToControllerMessage::Send { slices, message }) + .map_err(to_py_error) } - fn _drain_and_stop(&mut self, py: Python<'_>) -> PyResult<()> { - self.send_slice( - self.workers.proc_mesh().shape().slice().clone(), - WorkerMessage::Exit { error: None }, - )?; - let instance = self.controller_instance.clone(); - let _ = signal_safe_block_on(py, async move { instance.lock().await.drain_and_stop() })??; - Ok(()) + fn _drain_and_stop(&mut self) -> PyResult<()> { + self.controller_handle + .blocking_lock() + .send(ClientToControllerMessage::Send { + slices: vec![self.all_ranks.clone()], + message: WorkerMessage::Exit { error: None }, + }) + .map_err(to_py_error)?; + self.controller_handle + .blocking_lock() + .drain_and_stop() + .map_err(to_py_error) } } -pub(crate) fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> { - module.add_class::<_Controller>()?; - Ok(()) -} - /// An invocation tracks a discrete node in the graph of operations executed by /// the worker based on instructions from the client. /// It is useful for tracking the dependencies of an operation and propagating @@ -295,12 +196,26 @@ pub(crate) fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult #[derive(Debug)] enum Status { - Errored(Exception), - Complete(), + Errored { + exception: Arc, + }, + Complete {}, /// When incomplete this holds this list of users of this invocation, /// so a future error can be propagated to them., - Incomplete(HashMap>>), + Incomplete { + users: HashMap>>, + results: Vec, + }, +} + +impl Status { + fn incomplete() -> Status { + Self::Incomplete { + users: HashMap::new(), + results: vec![], + } + } } #[derive(Debug)] struct Invocation { @@ -311,85 +226,101 @@ struct Invocation { /// Result reported to a future if this invocation was a fetch /// Not all Invocations will be fetched so sometimes a Invocation will complete with /// both result and error == None - result: Option, + response_port: Option, + tracebacks: Py, } impl Invocation { - fn new(seq: Seq) -> Self { + fn new(seq: Seq, tracebacks: Py, response_port: Option) -> Self { Self { seq, - status: Status::Incomplete(HashMap::new()), - result: None, + status: Status::incomplete(), + response_port, + tracebacks, } } - fn add_user(&mut self, user: Arc>) { + fn add_user( + &mut self, + sender: &impl CanSend, + user: Arc>, + ) -> Result<(), MailboxSenderError> { match &mut self.status { - Status::Complete() => {} - Status::Incomplete(users) => { + Status::Complete {} => {} + Status::Incomplete { users, .. } => { let seq = user.lock().unwrap().seq; users.insert(seq, user); } - Status::Errored(err) => { - user.lock().unwrap().set_exception(err.clone()); + Status::Errored { exception } => { + user.lock() + .unwrap() + .set_exception(sender, exception.clone())?; } } + Ok(()) } /// Invocation results can only go from valid to failed, or be /// set if the invocation result is empty. - fn set_result(&mut self, result: Serialized) { - if self.result.is_none() { - self.result = Some(result); - } - } - - fn succeed(&mut self) { - match self.status { - Status::Incomplete(_) => self.status = Status::Complete(), - _ => {} + fn set_result(&mut self, result: PythonMessage) { + match &mut self.status { + Status::Incomplete { results, .. } => { + results.push(result); + } + Status::Errored { .. } => {} + Status::Complete {} => { + panic!("setting result on a complete seq"); + } } } - fn set_exception(&mut self, exception: Exception) -> Vec>> { - match exception { - Exception::Error(_, caused_by_new, error) => { - let err = Status::Errored(Exception::Error(self.seq, caused_by_new, error)); - match &self.status { - Status::Errored(Exception::Error(_, caused_by_current, _)) - if caused_by_new < *caused_by_current => - { - self.status = err; - } - Status::Incomplete(users) => { - let users = users.values().cloned().collect(); - self.status = err; - return users; - } - Status::Complete() => { - panic!("Complete invocation getting an exception set") + fn complete(&mut self, sender: &impl CanSend) -> Result<(), MailboxSenderError> { + let old_status = std::mem::replace(&mut self.status, Status::Complete {}); + match old_status { + Status::Incomplete { results, .. } => match &self.response_port { + Some(PortInfo { port, ranks }) => { + assert!(ranks.len() == results.iter().len()); + for result in results.into_iter() { + port.send(sender, result)?; } - _ => {} } - } - Exception::Failure(_) => { - tracing::error!( - "system failures {:?} can never be assigned for an invocation", - exception - ); + None => {} + }, + _ => { + self.status = old_status; } } - return vec![]; + Ok(()) } - fn msg_result(&self) -> Option> { - match &self.status { - Status::Complete() => self.result.clone().map(|x| Ok(x)), - Status::Errored(err) => Some(Err(err.clone())), - Status::Incomplete(_) => { - panic!("Incomplete invocation doesn't have a result yet") + fn set_exception( + &mut self, + sender: &impl CanSend, + exception: Arc, + ) -> Result>>, MailboxSenderError> { + let err = Status::Errored { + exception: exception.clone(), + }; + let old_status = std::mem::replace(&mut self.status, err); + match old_status { + Status::Incomplete { users, .. } => { + match &self.response_port { + Some(PortInfo { port, ranks }) => { + for rank in ranks.iter() { + let msg = exception.as_ref().clone().with_rank(rank); + port.send(sender, msg)?; + } + } + None => {} + }; + return Ok(users.into_values().collect()); } + Status::Complete {} => { + panic!("Complete invocation getting an exception set") + } + Status::Errored { .. } => self.status = old_status, } + Ok(vec![]) } } @@ -481,10 +412,13 @@ impl History { /// Add an invocation to the history. pub fn add_invocation( &mut self, + sender: &impl CanSend, seq: Seq, uses: Vec, defs: Vec, - ) -> Vec<(Seq, Option>)> { + tracebacks: Py, + response_port: Option, + ) -> Result<(), MailboxSenderError> { assert!( seq >= self.seq_lower_bound, "nonmonotonic seq: {:?}; current lower bound: {:?}", @@ -492,34 +426,64 @@ impl History { self.seq_lower_bound, ); self.seq_lower_bound = seq; - let invocation = Arc::new(sync::Mutex::new(Invocation::new(seq))); + let invocation = Arc::new(sync::Mutex::new(Invocation::new( + seq, + tracebacks, + response_port, + ))); self.inflight_invocations.insert(seq, invocation.clone()); for ref use_ in uses { let producer = self.invocation_for_ref.get(use_).unwrap(); - producer.lock().unwrap().add_user(invocation.clone()); + producer + .lock() + .unwrap() + .add_user(sender, invocation.clone())?; } for def in defs { self.invocation_for_ref.insert(def, invocation.clone()); } - let invocation = invocation.lock().unwrap(); - if matches!(invocation.status, Status::Errored(_)) { - vec![(seq, invocation.msg_result())] - } else { - vec![] - } + Ok(()) } /// Propagate worker error to the invocation with the given Seq. This will also propagate /// to all seqs that depend on this seq directly or indirectly. pub fn propagate_exception( &mut self, + sender: &impl CanSend, seq: Seq, - exception: Exception, - ) -> Vec<(Seq, Option>)> { - let mut results = Vec::new(); + exception: WorkerError, + ) -> Result<(), MailboxSenderError> { + // TODO: supplement PythonMessage with the stack trace we have in invocation + let rank = exception.worker_actor_id.rank(); + let invocation = self.inflight_invocations.get(&seq).unwrap().clone(); + let python_message = Arc::new(Python::with_gil(|py| { + let traceback = invocation + .lock() + .unwrap() + .tracebacks + .bind(py) + .get_item(0) + .unwrap(); + let remote_exception = py + .import("monarch.mesh_controller") + .unwrap() + .getattr("RemoteException") + .unwrap(); + let pickle = py + .import("monarch.actor_mesh") + .unwrap() + .getattr("_pickle") + .unwrap(); + let exe = remote_exception + .call1((exception.backtrace, traceback, rank)) + .unwrap(); + let data: Vec = pickle.call1((exe,)).unwrap().extract().unwrap(); + PythonMessage::new("exception".to_string(), data, None, Some(rank)) + })); + let mut queue: Vec>> = vec![invocation]; let mut visited = HashSet::new(); @@ -528,40 +492,200 @@ impl History { if !visited.insert(invocation.seq) { continue; }; - queue.extend(invocation.set_exception(exception.clone())); - results.push((seq, invocation.msg_result())); + queue.extend(invocation.set_exception(sender, python_message.clone())?); } - results + Ok(()) } /// Mark the given rank as completed up to but excluding the given Seq. This will also purge history for /// any Seqs that are no longer relevant (completed on all ranks). pub fn rank_completed( &mut self, + sender: &impl CanSend, rank: usize, seq: Seq, - ) -> Vec<(Seq, Option>)> { + ) -> Result<(), MailboxSenderError> { self.first_incomplete_seqs.set(rank, seq); let prev = self.min_incomplete_seq; self.min_incomplete_seq = self.first_incomplete_seqs.min(); - let mut results: Vec<(Seq, Option>)> = Vec::new(); for i in Seq::iter_between(prev, self.min_incomplete_seq) { let invocation = self.inflight_invocations.remove(&i).unwrap(); let mut invocation = invocation.lock().unwrap(); - if matches!(invocation.status, Status::Errored(_)) { - // we already reported output early when it errored - continue; - } - invocation.succeed(); - results.push((i, invocation.msg_result())); + invocation.complete(sender)?; } - results + Ok(()) } - pub fn set_result(&mut self, seq: Seq, result: Serialized) { + pub fn set_result(&mut self, seq: Seq, result: PythonMessage) { let invocation = self.inflight_invocations.get(&seq).unwrap(); invocation.lock().unwrap().set_result(result); } } + +#[derive(Debug)] +struct PortInfo { + port: PortRef, + // the slice of ranks expected to respond + // to the port. used for error reporting. + ranks: Slice, +} + +#[derive(Debug, Handler, HandleClient)] +enum ClientToControllerMessage { + Send { + slices: Vec, + message: WorkerMessage, + }, + Node { + seq: Seq, + defs: Vec, + uses: Vec, + tracebacks: Py, + response_port: Option, + }, + DropRefs { + refs: Vec, + }, + Exit { + seq: Seq, + }, +} + +struct MeshControllerActor { + proc_mesh: Arc, + workers: Option>, + history: History, + id: usize, +} + +impl MeshControllerActor { + fn workers(&self) -> &RootActorMesh<'static, WorkerActor> { + self.workers.as_ref().unwrap() + } +} + +impl Debug for MeshControllerActor { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MeshControllerActor").finish() + } +} + +struct MeshControllerActorParams { + proc_mesh: Arc, + id: usize, +} + +#[async_trait] +impl Actor for MeshControllerActor { + type Params = MeshControllerActorParams; + async fn new( + MeshControllerActorParams { proc_mesh, id }: Self::Params, + ) -> Result { + let world_size = proc_mesh.shape().slice().len(); + Ok(MeshControllerActor { + proc_mesh, + workers: None, + history: History::new(world_size), + id, + }) + } + async fn init(&mut self, this: &Instance) -> Result<(), anyhow::Error> { + let controller_actor_ref: ActorRef = this.bind(); + let slice = self.proc_mesh.shape().slice(); + let world_size = slice.len(); + let param = WorkerParams { + world_size, + // Rank assignment is consistent with proc indices. + rank: 0, + device_index: Some(0), + controller_actor: controller_actor_ref, + }; + + let workers = self + .proc_mesh + .spawn(&format!("tensor_engine_workers_{}", self.id), ¶m) + .await?; + workers.cast_slices(vec![slice.clone()], AssignRankMessage::AssignRank())?; + self.workers = Some(workers); + Ok(()) + } +} + +#[async_trait] +impl Handler for MeshControllerActor { + async fn handle( + &mut self, + this: &Instance, + message: ControllerMessage, + ) -> anyhow::Result<()> { + match message { + ControllerMessage::DebuggerMessage { + debugger_actor_id, + action, + } => { + let dm = crate::client::DebuggerMessage::new(debugger_actor_id.into(), action)?; + panic!("NYI: debugger message handling"); + } + ControllerMessage::Status { + seq, + worker_actor_id, + controller: false, + } => { + let rank = worker_actor_id.rank(); + self.history.rank_completed(this, rank, seq)?; + } + ControllerMessage::FetchResult { + seq, + value: Ok(value), + } => { + let msg: PythonMessage = value.deserialized().unwrap(); + self.history.set_result(seq, msg); + } + ControllerMessage::RemoteFunctionFailed { seq, error } => { + self.history.propagate_exception(this, seq, error)?; + } + message => { + panic!("unexpected message: {:?}", message); + } + }; + Ok(()) + } +} + +#[async_trait] +impl Handler for MeshControllerActor { + async fn handle( + &mut self, + this: &Instance, + message: ClientToControllerMessage, + ) -> anyhow::Result<()> { + match message { + ClientToControllerMessage::Send { slices, message } => { + self.workers().cast_slices(slices, message)?; + } + ClientToControllerMessage::Node { + seq, + defs, + uses, + tracebacks, + response_port, + } => { + self.history + .add_invocation(this, seq, uses, defs, tracebacks, response_port)?; + } + ClientToControllerMessage::DropRefs { refs } => { + self.history.drop_refs(refs); + } + ClientToControllerMessage::Exit { seq } => { + // the byte string is just a Python None + let result = + PythonMessage::new("result".to_string(), b"\x80\x04N.".to_vec(), None, None); + + self.history.set_result(seq, result); + } + } + Ok(()) + } +} diff --git a/monarch_hyperactor/src/actor.rs b/monarch_hyperactor/src/actor.rs index 305a46925..bbc295b59 100644 --- a/monarch_hyperactor/src/actor.rs +++ b/monarch_hyperactor/src/actor.rs @@ -173,12 +173,21 @@ impl PickledMessageClientActor { #[pyclass(frozen, module = "monarch._rust_bindings.monarch_hyperactor.actor")] #[derive(Clone, Serialize, Deserialize, Named, PartialEq)] pub struct PythonMessage { - method: String, + pub method: String, message: ByteBuf, response_port: Option, rank: Option, } +impl PythonMessage { + pub fn with_rank(self, rank: usize) -> PythonMessage { + PythonMessage { + rank: Some(rank), + ..self + } + } +} + impl std::fmt::Debug for PythonMessage { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("PythonMessage") @@ -214,7 +223,7 @@ impl Bind for PythonMessage { impl PythonMessage { #[new] #[pyo3(signature = (method, message, response_port, rank))] - fn new( + pub fn new( method: String, message: Vec, response_port: Option, diff --git a/monarch_messages/src/worker.rs b/monarch_messages/src/worker.rs index ede54cf09..ad57b7317 100644 --- a/monarch_messages/src/worker.rs +++ b/monarch_messages/src/worker.rs @@ -236,18 +236,32 @@ impl FunctionPath { } pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult> { - let (module_fqn, function_name) = self.path.rsplit_once(".").with_context(|| { + let (start, rest) = self.path.split_once(".").with_context(|| { format!( "invalid function path {}: paths must be fully qualified", self.path ) })?; - let module = PyModule::import(py, module_fqn)?; - let mut function = module.getattr(function_name)?; - if function.hasattr("_remote_impl")? { - function = function.getattr("_remote_impl")?; + if start == "torch" { + let mut cur = py.import("torch")?.into_any(); + for p in rest.split(".") { + cur = cur.getattr(p)?; + } + Ok(cur) + } else { + let (module_fqn, function_name) = self.path.rsplit_once(".").with_context(|| { + format!( + "invalid function path {}: paths must be fully qualified", + self.path + ) + })?; + let module = PyModule::import(py, module_fqn)?; + let mut function = module.getattr(function_name)?; + if function.hasattr("_remote_impl")? { + function = function.getattr("_remote_impl")?; + } + Ok(function.downcast_into()?) } - Ok(function.downcast_into()?) } } diff --git a/monarch_tensor_worker/src/lib.rs b/monarch_tensor_worker/src/lib.rs index 37ba83376..b9292a94b 100644 --- a/monarch_tensor_worker/src/lib.rs +++ b/monarch_tensor_worker/src/lib.rs @@ -183,6 +183,7 @@ pub struct WorkerActor { send_recv_comms: HashMap<(StreamRef, StreamRef), Arc>>, recordings: HashMap, defining_recording: Option, + respond_with_python_message: bool, } impl WorkerActor { @@ -248,6 +249,7 @@ impl Actor for WorkerActor { send_recv_comms: HashMap::new(), recordings: HashMap::new(), defining_recording: None, + respond_with_python_message: false, }) } @@ -262,6 +264,7 @@ impl Handler> for WorkerActor { message: Cast, ) -> anyhow::Result<()> { self.rank = message.rank.0; + self.respond_with_python_message = true; Python::with_gil(|py| { let mesh_controller = py.import_bound("monarch.mesh_controller").unwrap(); let shape: PyShape = message.shape.into(); @@ -469,6 +472,7 @@ impl WorkerMessageHandler for WorkerActor { id: result, device: self.device, controller_actor: self.controller_actor.clone(), + respond_with_python_message: self.respond_with_python_message, }, ) .await?; diff --git a/monarch_tensor_worker/src/stream.rs b/monarch_tensor_worker/src/stream.rs index db60d9d95..1fd1a2c31 100644 --- a/monarch_tensor_worker/src/stream.rs +++ b/monarch_tensor_worker/src/stream.rs @@ -35,6 +35,7 @@ use hyperactor::mailbox::Mailbox; use hyperactor::mailbox::OncePortHandle; use hyperactor::mailbox::PortReceiver; use hyperactor::proc::Proc; +use monarch_hyperactor::actor::PythonMessage; use monarch_messages::controller::ControllerMessageClient; use monarch_messages::controller::Seq; use monarch_messages::controller::WorkerError; @@ -423,6 +424,7 @@ pub struct StreamActor { remote_process_groups: HashMap, recordings: HashMap, active_recording: Option, + respond_with_python_message: bool, } /// Parameters for creating a [`Stream`]. @@ -439,6 +441,7 @@ pub struct StreamParams { pub device: Option, /// Actor ref of the controller that created this stream. pub controller_actor: ActorRef, + pub respond_with_python_message: bool, } #[async_trait] @@ -452,6 +455,7 @@ impl Actor for StreamActor { device, controller_actor, creation_mode, + respond_with_python_message, }: Self::Params, ) -> Result { Ok(Self { @@ -466,6 +470,7 @@ impl Actor for StreamActor { remote_process_groups: HashMap::new(), recordings: HashMap::new(), active_recording: None, + respond_with_python_message, }) } @@ -709,10 +714,11 @@ impl StreamActor { }) } - fn call_python_fn( + fn call_python_fn<'py>( &mut self, + py: Python<'py>, this: &Instance, - function: ResolvableFunction, + function: Option, args: Vec, kwargs: HashMap, mutates: &[Ref], @@ -721,149 +727,176 @@ impl StreamActor { Ref, (DeviceMesh, Vec, Arc>), >, - ) -> Result, CallFunctionError> { - Python::with_gil(|py| { - let function = function.resolve(py).map_err(|e| { - CallFunctionError::InvalidRemoteFunction(format!( - "failed to resolve function {}: {}", - function, - SerializablePyErr::from(py, &e) - )) - })?; - - let remote_process_groups = remote_process_groups - .into_iter() - .map(|(gref, (mesh, dims, comm))| { - let group = match self.remote_process_groups.entry(gref) { - Entry::Occupied(ent) => ent.get().clone_ref(py), - Entry::Vacant(ent) => { - // We need to run `init_process_group` before any - // remote process groups can get created. - torch_sys::backend::ensure_init_process_group( - py, - self.world_size, - self.rank, - )?; - - // Create a backend object to wrap the comm and use - // it to create a new torch group. - let ranks = mesh.get_ranks_for_dim_slice(&dims)?; - let group_size = ranks.len(); - let backend = CommBackend::new( - comm, - Mailbox::new_detached(this.self_id().clone()), - self.rank, - group_size, - self.world_size, - ); - ent.insert(torch_sys::backend::new_group(py, ranks, backend)?.unbind()) - .clone_ref(py) - } - }; - PyResult::Ok((gref, group)) + ) -> Result, CallFunctionError> { + let function = function + .map(|function| { + function.resolve(py).map_err(|e| { + CallFunctionError::InvalidRemoteFunction(format!( + "failed to resolve function {}: {}", + function, + SerializablePyErr::from(py, &e) + )) }) - .collect::, _>>() - .map_err(SerializablePyErr::from_fn(py))?; - - // SAFETY: We will be making an unchecked clone of each tensor to pass to to - // C++, so we need to hold a borrow of each input tensor for the duration of - // this function. - let mut multiborrow = MultiBorrow::new(); - - let resolve = |val: WireValue| { - val.into_py_object() - .map_err(|e| { - CallFunctionError::UnsupportedArgType( - format!("{:?}", function), - format!("{:?}", e), - ) - })? - .unpickle(py) - .map_err(SerializablePyErr::from_fn(py))? - .extract::>() - .map_err(SerializablePyErr::from_fn(py))? - .try_into_map(|obj| { - Ok(if let Ok(ref_) = Ref::from_py_object(obj.bind(py)) { - if let Some(mesh) = device_meshes.get(&ref_) { - PyArg::DeviceMesh(mesh) - } else if let Some(pg) = remote_process_groups.get(&ref_) { - PyArg::PyObject(pg.clone_ref(py)) - } else { - let rval = self.ref_to_rvalue(&ref_)?; - PyArg::RValue(rval) - } - } else { - PyArg::PyObject(obj) - }) - }) - }; + }) + .transpose()?; - // Resolve refs - let py_args: Vec> = args - .into_iter() - .map(resolve) - .collect::>()?; - let py_kwargs: HashMap<_, PyTree> = kwargs - .into_iter() - .map(|(k, object)| Ok((k, resolve(object)?))) - .collect::>()?; - - // Add a shared-borrow for each rvalue reference. - py_args - .iter() - .chain(py_kwargs.values()) - .flat_map(|o| o.iter()) - .for_each(|arg| { - if let PyArg::RValue(rval) = arg { - multiborrow.add(rval, BorrowType::Shared); + let remote_process_groups = remote_process_groups + .into_iter() + .map(|(gref, (mesh, dims, comm))| { + let group = match self.remote_process_groups.entry(gref) { + Entry::Occupied(ent) => ent.get().clone_ref(py), + Entry::Vacant(ent) => { + // We need to run `init_process_group` before any + // remote process groups can get created. + torch_sys::backend::ensure_init_process_group( + py, + self.world_size, + self.rank, + )?; + + // Create a backend object to wrap the comm and use + // it to create a new torch group. + let ranks = mesh.get_ranks_for_dim_slice(&dims)?; + let group_size = ranks.len(); + let backend = CommBackend::new( + comm, + Mailbox::new_detached(this.self_id().clone()), + self.rank, + group_size, + self.world_size, + ); + ent.insert(torch_sys::backend::new_group(py, ranks, backend)?.unbind()) + .clone_ref(py) } - }); - - // Add mutable borrows for params we're mutating. - let mutates: Vec<_> = mutates - .iter() - .map(|r| self.ref_to_rvalue(r)) - .collect::>()?; - mutates - .iter() - .for_each(|rval| multiborrow.add(rval, BorrowType::Mutable)); + }; + PyResult::Ok((gref, group)) + }) + .collect::, _>>() + .map_err(SerializablePyErr::from_fn(py))?; + + // SAFETY: We will be making an unchecked clone of each tensor to pass to to + // C++, so we need to hold a borrow of each input tensor for the duration of + // this function. + let mut multiborrow = MultiBorrow::new(); + + let resolve = |val: WireValue| { + val.into_py_object() + .map_err(|e| { + CallFunctionError::UnsupportedArgType( + format!("{:?}", function), + format!("{:?}", e), + ) + })? + .unpickle(py) + .map_err(SerializablePyErr::from_fn(py))? + .extract::>() + .map_err(SerializablePyErr::from_fn(py))? + .try_into_map(|obj| { + Ok(if let Ok(ref_) = Ref::from_py_object(obj.bind(py)) { + if let Some(mesh) = device_meshes.get(&ref_) { + PyArg::DeviceMesh(mesh) + } else if let Some(pg) = remote_process_groups.get(&ref_) { + PyArg::PyObject(pg.clone_ref(py)) + } else { + let rval = self.ref_to_rvalue(&ref_)?; + PyArg::RValue(rval) + } + } else { + PyArg::PyObject(obj) + }) + }) + }; - // Execute the borrow. - let _borrow = multiborrow.borrow()?; + // Resolve refs + let py_args: Vec> = args + .into_iter() + .map(resolve) + .collect::>()?; + let py_kwargs: HashMap<_, PyTree> = kwargs + .into_iter() + .map(|(k, object)| Ok((k, resolve(object)?))) + .collect::>()?; + + // Add a shared-borrow for each rvalue reference. + py_args + .iter() + .chain(py_kwargs.values()) + .flat_map(|o| o.iter()) + .for_each(|arg| { + if let PyArg::RValue(rval) = arg { + multiborrow.add(rval, BorrowType::Shared); + } + }); - // Call function. - // Use custom subscriber to route Worker messages to stdout. - let scoped_subscriber = Subscriber::builder().with_writer(std::io::stdout).finish(); - let result: Bound<'_, PyAny> = - tracing::subscriber::with_default(scoped_subscriber, || { + // Add mutable borrows for params we're mutating. + let mutates: Vec<_> = mutates + .iter() + .map(|r| self.ref_to_rvalue(r)) + .collect::>()?; + mutates + .iter() + .for_each(|rval| multiborrow.add(rval, BorrowType::Mutable)); + + // Execute the borrow. + let _borrow = multiborrow.borrow()?; + + // Call function. + // Use custom subscriber to route Worker messages to stdout. + let scoped_subscriber = Subscriber::builder().with_writer(std::io::stdout).finish(); + let result: Bound<'_, PyAny> = + tracing::subscriber::with_default(scoped_subscriber, || { + // SAFETY: The borrows above guard the unchecked clones done by + // `rvalue_to_ivalue`. This may result in multiple mutable + // references to tensor data, but the Python side is responsible + // for making sure that is safe + // TODO(agallagher): The args/kwargs conversion traits generate + // the appropriate types here, but they get casted to `PyAny`. + // It'd be nice to make `TryToPyObjectUnsafe` take a template + // arg for the converted py object to avoid this downcast. + let args = unsafe { py_args.try_to_object_unsafe(py) } + .map_err(SerializablePyErr::from_fn(py))?; + // SAFETY: above + let kwargs = &unsafe { py_kwargs.try_to_object_unsafe(py) } + .map_err(SerializablePyErr::from_fn(py))?; + + if let Some(function) = function { function - .call( - // SAFETY: The borrows above guard the unchecked clones done by - // `rvalue_to_ivalue`. This may result in multiple mutable - // references to tensor data, but the Python side is responsible - // for making sure that is safe - // TODO(agallagher): The args/kwargs conversion traits generate - // the appropriate types here, but they get casted to `PyAny`. - // It'd be nice to make `TryToPyObjectUnsafe` take a template - // arg for the converted py object to avoid this downcast. - unsafe { py_args.try_to_object_unsafe(py) } - .map_err(SerializablePyErr::from_fn(py))?, - Some( - // SAFETY: Same. - &unsafe { py_kwargs.try_to_object_unsafe(py) } - .map_err(SerializablePyErr::from_fn(py))?, - ), - ) + .call(args, Some(kwargs)) .map_err(SerializablePyErr::from_fn(py)) - })?; + } else { + Ok(args.get_item(0).unwrap()) + } + })?; + Ok(result) + } - // Parse the python result as an `Object`, which should preserve the - // original Python object structure, while providing access to the - // leaves as `RValue`s. + fn call_python_fn_pytree( + &mut self, + this: &Instance, + function: ResolvableFunction, + args: Vec, + kwargs: HashMap, + mutates: &[Ref], + device_meshes: HashMap, + remote_process_groups: HashMap< + Ref, + (DeviceMesh, Vec, Arc>), + >, + ) -> Result, CallFunctionError> { + Python::with_gil(|py| { + let result = self.call_python_fn( + py, + this, + Some(function), + args, + kwargs, + mutates, + device_meshes, + remote_process_groups, + )?; Ok(PyTree::::extract_bound(&result).map_err(SerializablePyErr::from_fn(py))?) }) } - /// Retrieve `ref_` or create a fake value with the provided factory if it /// is an error. We use this for collective calls, where even if there was /// an upstream failure, we still have participate in the collective to @@ -918,6 +951,79 @@ impl StreamActor { } Ok(None) } + async fn send_value_python_message( + &mut self, + this: &Instance, + seq: Seq, + worker_actor_id: ActorId, + mutates: Vec, + function: Option, + args: Vec, + kwargs: HashMap, + device_meshes: HashMap, + ) -> Result<()> { + let result = Python::with_gil(|py| { + let result = tokio::task::block_in_place(|| { + self.call_python_fn( + py, + this, + function, + args, + kwargs, + &mutates, + device_meshes, + HashMap::new(), + ) + }); + result + .map_err(|err| { + let err = Arc::new(err); + for ref_ in mutates { + self.env.insert(ref_, Err(err.clone())); + } + let err = err.unwrap_dependent_error().unwrap_or(err); + WorkerError { + backtrace: format!("{:?}", err), + worker_actor_id: worker_actor_id.clone(), + } + }) + .and_then(|result| -> Result { + let pickle = py + .import("monarch.actor_mesh") + .unwrap() + .getattr("_pickle") + .unwrap(); + let data: Vec = pickle + .call1((result,)) + .map_err(|pyerr| WorkerError { + backtrace: SerializablePyErr::from(py, &pyerr).to_string(), + worker_actor_id: worker_actor_id.clone(), + })? + .extract() + .unwrap(); + Ok(PythonMessage::new( + "result".to_string(), + data, + None, + Some(worker_actor_id.rank()), + )) + }) + }); + match result { + Ok(value) => { + let ser = Serialized::serialize(&value).unwrap(); + self.controller_actor + .fetch_result(this, seq, Ok(ser)) + .await?; + } + Err(e) => { + self.controller_actor + .remote_function_failed(this, seq, e) + .await?; + } + } + Ok(()) + } } #[async_trait] @@ -954,7 +1060,7 @@ impl StreamMessageHandler for StreamActor { // Use block-in-place to allow nested callbacks to re-enter the // runtime to run async code. tokio::task::block_in_place(|| { - self.call_python_fn( + self.call_python_fn_pytree( this, params.function, params.args, @@ -1427,6 +1533,20 @@ impl StreamMessageHandler for StreamActor { device_meshes: HashMap, pipe: Option>, ) -> Result<()> { + if self.respond_with_python_message && pipe.is_none() { + return self + .send_value_python_message( + this, + seq, + worker_actor_id, + mutates, + function, + args, + kwargs, + device_meshes, + ) + .await; + } let result = if let Some(function) = function { // If a function was provided, use that to resolve the value. match function.as_torch_op() { @@ -1456,7 +1576,7 @@ impl StreamMessageHandler for StreamActor { // Use block-in-place to allow nested callbacks to re-enter the // runtime to run async code. _ => tokio::task::block_in_place(|| { - self.call_python_fn( + self.call_python_fn_pytree( this, function, args, diff --git a/python/monarch/_rust_bindings/monarch_extension/mesh_controller.pyi b/python/monarch/_rust_bindings/monarch_extension/mesh_controller.pyi index f50b7da28..29c975bf0 100644 --- a/python/monarch/_rust_bindings/monarch_extension/mesh_controller.pyi +++ b/python/monarch/_rust_bindings/monarch_extension/mesh_controller.pyi @@ -4,9 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import List, NamedTuple, Sequence, Union +from traceback import FrameSummary +from typing import List, NamedTuple, Sequence, Tuple, Union from monarch._rust_bindings.monarch_extension import client +from monarch._rust_bindings.monarch_hyperactor.mailbox import PortId from monarch._rust_bindings.monarch_hyperactor.proc import ActorId from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ProcMesh @@ -15,7 +17,12 @@ from monarch._rust_bindings.monarch_hyperactor.shape import Slice as NDSlice class _Controller: def __init__(self) -> None: ... def node( - self, seq: int, defs: Sequence[object], uses: Sequence[object] + self, + seq: int, + defs: Sequence[object], + uses: Sequence[object], + port: Tuple[PortId, NDSlice] | None, + tracebacks: List[List[FrameSummary]], ) -> None: ... def drop_refs(self, refs: Sequence[object]) -> None: ... def send( @@ -31,3 +38,9 @@ class _Controller: def _drain_and_stop( self, ) -> List[client.LogMessage | client.WorkerResponse | client.DebuggerMessage]: ... + def exit(self, seq: Seq) -> None: + """ + Treat seq as a barrier for exit. It will recieve None on succesfully reaching + seq, and throw an exception if there remote failures that were never reported to a future. + """ + ... diff --git a/python/monarch/actor_mesh.py b/python/monarch/actor_mesh.py index e8cec3a87..f855857be 100644 --- a/python/monarch/actor_mesh.py +++ b/python/monarch/actor_mesh.py @@ -398,19 +398,31 @@ def send(self, method: str, obj: R) -> None: ) +from typing import NamedTuple + +R = TypeVar("R") + +T = TypeVar("T") + + +class PortTuple(NamedTuple, Generic[R]): + sender: "Port[R]" + receiver: "PortReceiver[R]" + + @staticmethod + def create(mailbox: Mailbox, once: bool = False) -> "PortTuple[Any]": + handle, receiver = mailbox.open_once_port() if once else mailbox.open_port() + port_id: PortId = handle.bind() + return PortTuple( + Port(port_id, mailbox, rank=None), PortReceiver(mailbox, receiver) + ) + + # advance lower-level API for sending messages. This is intentially # not part of the Endpoint API because they way it accepts arguments # and handles concerns is different. -def port( - endpoint: Endpoint[P, R], once: bool = False -) -> Tuple["Port[R]", "PortReceiver[R]"]: - handle, receiver = ( - endpoint._mailbox.open_once_port() if once else endpoint._mailbox.open_port() - ) - port_id: PortId = handle.bind() - return Port(port_id, endpoint._mailbox, rank=None), PortReceiver( - endpoint._mailbox, receiver - ) +def port(endpoint: Endpoint[P, R], once: bool = False) -> "PortTuple[R]": + return PortTuple.create(endpoint._mailbox, once) def ranked_port( diff --git a/python/monarch/common/pickle_flatten.py b/python/monarch/common/pickle_flatten.py index 8033387d1..557b5399d 100644 --- a/python/monarch/common/pickle_flatten.py +++ b/python/monarch/common/pickle_flatten.py @@ -9,6 +9,7 @@ from typing import Any, Callable, Iterable, List, Tuple import cloudpickle +import torch class _Pickler(cloudpickle.Pickler): @@ -44,5 +45,6 @@ def flatten(obj: Any, filter: Callable[[Any], bool]) -> Tuple[List[Any], bytes]: def unflatten(data: bytes, values: Iterable[Any]) -> Any: - up = _Unpickler(data, values) - return up.load() + with torch.utils._python_dispatch._disable_current_modes(): + up = _Unpickler(data, values) + return up.load() diff --git a/python/monarch/common/remote.py b/python/monarch/common/remote.py index bdd0c1d23..5e3b9278b 100644 --- a/python/monarch/common/remote.py +++ b/python/monarch/common/remote.py @@ -180,15 +180,9 @@ def _call_on_shard_and_fetch( client: "Client" = mesh.client if _coalescing.is_active(client): raise NotImplementedError("NYI: fetching results during a coalescing block") + stream_ref = stream._active._to_ref(client) return client.fetch( - mesh, - stream._active._to_ref(client), - shard, - preprocess_message, - args, - kwargs, - mutates, - dtensors, + mesh, stream_ref, shard, preprocess_message, args, kwargs, mutates, dtensors ) diff --git a/python/monarch/mesh_controller.py b/python/monarch/mesh_controller.py index 499f45a59..4086f5f65 100644 --- a/python/monarch/mesh_controller.py +++ b/python/monarch/mesh_controller.py @@ -11,19 +11,35 @@ import traceback from collections import deque from logging import Logger -from typing import List, NamedTuple, Optional, TYPE_CHECKING, Union +from typing import ( + Any, + cast, + List, + NamedTuple, + Optional, + Sequence, + Tuple, + TYPE_CHECKING, + Union, +) import torch.utils._python_dispatch -from monarch import NDSlice +from monarch import NDSlice, Stream from monarch._rust_bindings.monarch_extension import client, debugger from monarch._rust_bindings.monarch_extension.client import ( # @manual=//monarch/monarch_extension:monarch_extension WorldState, ) from monarch._rust_bindings.monarch_extension.mesh_controller import _Controller +from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox from monarch._rust_bindings.monarch_hyperactor.proc import ( # @manual=//monarch/monarch_extension:monarch_extension ActorId, ) +from monarch.actor_mesh import Port, PortTuple +from monarch.common import messages +from monarch.common.invocation import Seq +from monarch.common.stream import StreamRef +from monarch.common.tensor import Tensor if TYPE_CHECKING: from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ( @@ -37,8 +53,10 @@ from monarch.common.client import Client from monarch.common.controller_api import LogMessage, MessageResult from monarch.common.device_mesh import DeviceMesh, no_mesh +from monarch.common.future import Future as OldFuture from monarch.common.invocation import DeviceException, RemoteException from monarch.controller.debugger import read as debugger_read, write as debugger_write +from monarch.future import ActorFuture from monarch.rust_local_mesh import _get_worker_exec_info from pyre_extensions import none_throws @@ -48,6 +66,7 @@ class Controller(_Controller): def __init__(self, workers: "HyProcMesh") -> None: super().__init__() + self._mailbox: Mailbox = workers.client # Buffer for messages unrelated to debugging that are received while a # debugger session is active. self._non_debugger_pending_messages: deque[ @@ -220,6 +239,40 @@ def _initialize_env(worker_point: Point, proc_id: str) -> None: class MeshClient(Client): + def fetch( + self, + mesh: "DeviceMesh", + stream: "StreamRef", + shard, + preprocess_message, + args, + kwargs, + defs: Tuple["Tensor", ...], + uses: Tuple["Tensor", ...], + ) -> "OldFuture": # the OldFuture is a lie + sender, receiver = PortTuple.create(self._mesh_controller._mailbox, once=True) + + ident = self.new_node(defs, uses, cast("OldFuture", sender)) + process = mesh._process(shard) + self.send( + process, + messages.SendValue( + ident, + None, + defs, + preprocess_message, + args, + kwargs, + stream, + ), + ) + # we have to ask for status updates + # from workers to be sure they have finished + # enough work to count this future as finished, + # and all potential errors have been reported + self._request_status() + return cast("OldFuture", receiver.recv()) + def shutdown( self, destroy_pg: bool = True, @@ -233,27 +286,47 @@ def shutdown( atexit.unregister(self._atexit) self._shutdown = True + sender, receiver = PortTuple.create(self._mesh_controller._mailbox, once=True) + # ensure all pending work is finished. # all errors must be messaged back at this point - self.new_node_nocoalesce([], [], None, []) + seq = self.new_node_nocoalesce([], [], cast("OldFuture", sender), []) + self._mesh_controller.exit(seq) self._request_status() - - ttl = 60 - start_time = time.time() - end_time = start_time + ttl - while ttl > 0 and self.last_assigned_seq > self.last_processed_seq: - ttl = end_time - time.time() - self.handle_next_message(ttl) - if self._pending_shutdown_error: - raise self._pending_shutdown_error - - if ttl <= 0: - raise RuntimeError("shutdown timed out") - + receiver.recv().get(timeout=60) # we are not expecting anything more now, because we already # waited for the responses self.inner.drain_and_stop() + @property + def _mesh_controller(self) -> Controller: + return cast(Controller, self.inner) + + def new_node_nocoalesce( + self, + defs: Sequence["Tensor"], + uses: Sequence["Tensor"], + future: Optional["OldFuture"], + tracebacks: List[List[traceback.FrameSummary]], + ) -> Seq: + seq = self._next_seq() + for d in defs: + d._seq = seq + response_port = None + if future is not None: + # method annotation is a lie to make Client happy + port = cast("Port[Any]", future) + slice = NDSlice.new_row_major([]) + response_port = (port._port, slice) + self._mesh_controller.node(seq, defs, uses, response_port, tracebacks) + return seq + + def handle_next_message(self, timeout: Optional[float]) -> bool: + """ + Mesh controller message loop is handled by the tokio event loop. + """ + return False + def spawn_tensor_engine(proc_mesh: "ProcMesh") -> DeviceMesh: # This argument to Controller @@ -269,3 +342,28 @@ def spawn_tensor_engine(proc_mesh: "ProcMesh") -> DeviceMesh: ) dm.exit = lambda: client.shutdown() return dm + + +class RemoteException(Exception): + def __init__( + self, + worker_error_string: str, # this should really be an exception + stacktrace but + # worker code needs major refactor to make this possible + controller_frames: List[traceback.FrameSummary], + rank: int, + ): + self.worker_error_string = worker_error_string + self.controller_frames = controller_frames + self.rank = rank + + def __str__(self): + try: + controller_tb = "".join(traceback.format_list(self.controller_frames)) + return ( + f"A remote function has failed asynchronously on rank {self.rank}.\n" + f"Traceback of where the remote function was issued on controller (most recent call last):\n{controller_tb}" + f"Error as reported from worker:\n{self.worker_error_string}" + ) + except Exception: + traceback.print_exc() + return "" diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index 2e468b3cc..ef65fbc20 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -17,6 +17,7 @@ import pytest import torch +from monarch import remote from monarch.actor_mesh import ( Accumulator, @@ -413,6 +414,14 @@ def test_tensor_engine() -> None: assert torch.allclose(torch.zeros(3, 4), r) assert torch.allclose(torch.zeros(3, 4), f) + @remote(propagate=lambda x: x) + def nope(x): + raise ValueError("nope") + + with pytest.raises(monarch.mesh_controller.RemoteException): + with dm.activate(): + monarch.inspect(nope(torch.zeros(3, 4))) + dm.exit()