diff --git a/Cargo.toml b/Cargo.toml index 93aaa71c1..9c6235299 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,3 +14,6 @@ members = [ "nccl-sys", "torch-sys", ] + +[profile.release] +incremental = true diff --git a/monarch_extension/Cargo.toml b/monarch_extension/Cargo.toml index 9b138b341..ca5190970 100644 --- a/monarch_extension/Cargo.toml +++ b/monarch_extension/Cargo.toml @@ -15,6 +15,7 @@ crate-type = ["cdylib"] [dependencies] anyhow = "1.0.98" +async-trait = "0.1.86" bincode = "1.3.3" clap = { version = "4.5.38", features = ["derive", "env", "string", "unicode", "wrap_help"] } controller = { version = "0.0.0", path = "../controller", optional = true } diff --git a/monarch_extension/src/mesh_controller.rs b/monarch_extension/src/mesh_controller.rs index 9fc141734..3cef61d81 100644 --- a/monarch_extension/src/mesh_controller.rs +++ b/monarch_extension/src/mesh_controller.rs @@ -9,293 +9,191 @@ use std::collections::BTreeMap; use std::collections::HashMap; use std::collections::HashSet; -use std::collections::VecDeque; -use std::iter::repeat_n; +use std::error::Error; +use std::fmt::Debug; +use std::fmt::Formatter; +use std::ops::DerefMut; use std::sync; use std::sync::Arc; use std::sync::atomic; use std::sync::atomic::AtomicUsize; +use async_trait::async_trait; +use hyperactor::Actor; +use hyperactor::ActorHandle; use hyperactor::ActorRef; -use hyperactor::data::Serialized; -use hyperactor_mesh::actor_mesh::ActorMesh; +use hyperactor::Context; +use hyperactor::HandleClient; +use hyperactor::Handler; +use hyperactor::Instance; +use hyperactor::PortRef; +use hyperactor::cap::CanSend; +use hyperactor::mailbox::MailboxSenderError; +use hyperactor_mesh::Mesh; +use hyperactor_mesh::ProcMesh; use hyperactor_mesh::actor_mesh::RootActorMesh; use hyperactor_mesh::shared_cell::SharedCell; +use hyperactor_mesh::shared_cell::SharedCellRef; +use monarch_hyperactor::actor::PythonMessage; +use monarch_hyperactor::mailbox::PyPortId; use monarch_hyperactor::ndslice::PySlice; -use monarch_hyperactor::proc::InstanceWrapper; -use monarch_hyperactor::proc::PyActorId; -use monarch_hyperactor::proc::PyProc; use monarch_hyperactor::proc_mesh::PyProcMesh; +use monarch_hyperactor::proc_mesh::TrackedProcMesh; use monarch_hyperactor::runtime::signal_safe_block_on; -use monarch_messages::client::Exception; use monarch_messages::controller::ControllerActor; use monarch_messages::controller::ControllerMessage; use monarch_messages::controller::Seq; -use monarch_messages::debugger::DebuggerAction; -use monarch_messages::debugger::DebuggerActor; -use monarch_messages::debugger::DebuggerMessage; +use monarch_messages::controller::WorkerError; use monarch_messages::worker::Ref; use monarch_messages::worker::WorkerMessage; use monarch_messages::worker::WorkerParams; use monarch_tensor_worker::AssignRankMessage; use monarch_tensor_worker::WorkerActor; use ndslice::Slice; -use pyo3::IntoPyObjectExt; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use tokio::sync::Mutex; use crate::convert::convert; +pub(crate) fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> { + module.add_class::<_Controller>()?; + Ok(()) +} + +/// The rust-side implementation of monarch.mesh_controller.Controller +/// It exports the API that interacts with the controller actor (MeshControllerActor) #[pyclass( subclass, module = "monarch._rust_bindings.monarch_extension.mesh_controller" )] struct _Controller { - controller_instance: Arc>>, - workers: SharedCell>, - pending_messages: VecDeque, - history: History, + 
controller_handle: Arc>>, + all_ranks: Slice, } -impl _Controller { - fn add_responses( - &mut self, - py: Python<'_>, - responses: Vec<( - monarch_messages::controller::Seq, - Option>, - )>, - ) -> PyResult<()> { - for (seq, response) in responses { - let message = crate::client::WorkerResponse::new(seq, response); - self.pending_messages.push_back(message.into_py_any(py)?); - } - Ok(()) - } - fn fill_messages<'py>(&mut self, py: Python<'py>, timeout_msec: Option) -> PyResult<()> { - let instance = self.controller_instance.clone(); - let result = signal_safe_block_on(py, async move { - instance.lock().await.next_message(timeout_msec).await - })??; - result.map(|m| self.add_message(m)).transpose()?; - Ok(()) - } +static NEXT_ID: AtomicUsize = AtomicUsize::new(0); - fn add_message(&mut self, message: ControllerMessage) -> PyResult<()> { - Python::with_gil(|py| -> PyResult<()> { - match message { - ControllerMessage::DebuggerMessage { - debugger_actor_id, - action, - } => { - let dm = crate::client::DebuggerMessage::new(debugger_actor_id.into(), action)? - .into_py_any(py)?; - self.pending_messages.push_back(dm); - } - ControllerMessage::Status { - seq, - worker_actor_id, - controller: false, - } => { - let rank = worker_actor_id.rank(); - let responses = self.history.rank_completed(rank, seq); - self.add_responses(py, responses)?; - } - ControllerMessage::RemoteFunctionFailed { seq, error } => { - let responses = self - .history - .propagate_exception(seq, Exception::Error(seq, seq, error)); - self.add_responses(py, responses)?; - } - ControllerMessage::FetchResult { - seq, - value: Ok(value), - } => { - self.history.set_result(seq, value); - } - ControllerMessage::FetchResult { - seq, - value: Err(error), - } => { - let responses = self - .history - .propagate_exception(seq, Exception::Error(seq, seq, error)); - self.add_responses(py, responses)?; - } - message => { - panic!("unexpected message: {:?}", message); - } - }; - Ok(()) - }) - } - fn send_slice(&mut self, slice: Slice, message: WorkerMessage) -> PyResult<()> { - self.workers - .borrow() - .map_err(anyhow::Error::msg)? 
- .cast_slices(vec![slice], message) - .map_err(|err| PyErr::new::(err.to_string())) - // let shape = Shape::new( - // (0..slice.sizes().len()).map(|i| format!("d{i}")).collect(), - // slice, - // ) - // .unwrap(); - // println!("SENDING TO {:?} {:?}", &shape, &message); - // let worker_slice = SlicedActorMesh::new(&self.workers, shape); - // worker_slice - // .cast(ndslice::Selection::True, message) - // .map_err(|err| PyErr::new::(err.to_string())) - } +fn to_py_error(e: T) -> PyErr +where + T: Error, +{ + PyErr::new::(e.to_string()) } -static NEXT_ID: AtomicUsize = AtomicUsize::new(0); - #[pymethods] impl _Controller { #[new] fn new(py: Python, py_proc_mesh: &PyProcMesh) -> PyResult { - let proc_mesh = py_proc_mesh.try_inner()?; - let id = NEXT_ID.fetch_add(1, atomic::Ordering::Relaxed); - let controller_instance: InstanceWrapper = InstanceWrapper::new( - &PyProc::new_from_proc(proc_mesh.client_proc().clone()), - &format!("tensor_engine_controller_{}", id), - )?; - - let controller_actor_ref = - ActorRef::::attest(controller_instance.actor_id().clone()); - - let slice = proc_mesh.shape().slice(); + let proc_mesh: SharedCell = py_proc_mesh.inner.clone(); + let proc_mesh_ref = proc_mesh.borrow().unwrap(); + let shape = proc_mesh_ref.shape(); + let slice = shape.slice(); + let all_ranks = shape.slice().clone(); if !slice.is_contiguous() || slice.offset() != 0 { return Err(PyValueError::new_err( "NYI: proc mesh for workers must be contiguous and start at offset 0", )); } - let world_size = slice.len(); - let param = WorkerParams { - world_size, - // Rank assignment is consistent with proc indices. - rank: 0, - device_index: Some(0), - controller_actor: controller_actor_ref, - }; - - let py_proc_mesh = py_proc_mesh.try_inner()?; - let shape = py_proc_mesh.shape().clone(); - let workers: anyhow::Result>> = + let id = NEXT_ID.fetch_add(1, atomic::Ordering::Relaxed); + let controller_handle: Arc>> = signal_safe_block_on(py, async move { - let workers = py_proc_mesh - .spawn(&format!("tensor_engine_workers_{}", id), ¶m) + let controller_handle = proc_mesh + .borrow() + .unwrap() + .client_proc() + .spawn( + &format!("tensor_engine_controller_{}", id), + MeshControllerActorParams { proc_mesh, id }, + ) .await?; - //workers.cast(ndslice::Selection::True, )?; - workers - .borrow()? - .cast_slices(vec![shape.slice().clone()], AssignRankMessage::AssignRank())?; - Ok(workers) - })?; + let r: Result>>, anyhow::Error> = + Ok(Arc::new(Mutex::new(controller_handle))); + r + })??; + Ok(Self { - workers: workers?, - controller_instance: Arc::new(Mutex::new(controller_instance)), - pending_messages: VecDeque::new(), - history: History::new(world_size), + controller_handle, + all_ranks, }) } + #[pyo3(signature = (seq, defs, uses, response_port, tracebacks))] fn node<'py>( &mut self, seq: u64, defs: Bound<'py, PyAny>, uses: Bound<'py, PyAny>, + response_port: Option<(PyPortId, PySlice)>, + tracebacks: Py, ) -> PyResult<()> { - let failures = self.history.add_invocation( - seq.into(), - uses.try_iter()? + let response_port: Option = response_port.map(|(port, ranks)| PortInfo { + port: PortRef::attest(port.into()), + ranks: ranks.into(), + }); + let msg = ClientToControllerMessage::Node { + seq: seq.into(), + defs: defs + .try_iter()? .map(|x| Ref::from_py_object(&x?)) .collect::>>()?, - defs.try_iter()? + uses: uses + .try_iter()? 
.map(|x| Ref::from_py_object(&x?)) .collect::>>()?, - ); - self.add_responses(defs.py(), failures)?; - Ok(()) + tracebacks, + response_port, + }; + self.controller_handle + .blocking_lock() + .send(msg) + .map_err(to_py_error) } - fn drop_refs(&mut self, refs: Vec) { - self.history.drop_refs(refs); + fn drop_refs(&mut self, refs: Vec) -> PyResult<()> { + self.controller_handle + .blocking_lock() + .send(ClientToControllerMessage::DropRefs { refs }) + .map_err(to_py_error) + } + + fn sync_at_exit(&mut self, port: PyPortId) -> PyResult<()> { + self.controller_handle + .blocking_lock() + .send(ClientToControllerMessage::SyncAtExit { + port: PortRef::attest(port.into()), + }) + .map_err(to_py_error) } fn send<'py>(&mut self, ranks: Bound<'py, PyAny>, message: Bound<'py, PyAny>) -> PyResult<()> { - let message: WorkerMessage = convert(message)?; - if let Ok(slice) = ranks.extract::() { - self.send_slice(slice.into(), message)?; + let slices = if let Ok(slice) = ranks.extract::() { + vec![slice.into()] } else { let slices = ranks.extract::>()?; - for (slice, message) in slices.iter().zip(repeat_n(message, slices.len())) { - self.send_slice(slice.into(), message)?; - } + slices.iter().map(|x| x.into()).collect::>() }; - Ok(()) - } - - #[pyo3(signature = (*, timeout_msec = None))] - fn _get_next_message<'py>( - &mut self, - py: Python<'py>, - timeout_msec: Option, - ) -> PyResult> { - if self.pending_messages.is_empty() { - self.fill_messages(py, timeout_msec)?; - } - Ok(self.pending_messages.pop_front()) - } - - fn _debugger_attach(&mut self, pdb_actor: PyActorId) -> PyResult<()> { - let pdb_actor: ActorRef = ActorRef::attest(pdb_actor.into()); - pdb_actor - .send( - self.controller_instance.blocking_lock().mailbox(), - DebuggerMessage::Action { - action: DebuggerAction::Attach(), - }, - ) - .map_err(|err| PyErr::new::(err.to_string()))?; - Ok(()) - } - - fn _debugger_write(&mut self, pdb_actor: PyActorId, bytes: Vec) -> PyResult<()> { - let pdb_actor: ActorRef = ActorRef::attest(pdb_actor.into()); - pdb_actor - .send( - self.controller_instance.blocking_lock().mailbox(), - DebuggerMessage::Action { - action: DebuggerAction::Write { bytes }, - }, - ) - .map_err(|err| PyErr::new::(err.to_string()))?; - Ok(()) + let message: WorkerMessage = convert(message)?; + self.controller_handle + .blocking_lock() + .send(ClientToControllerMessage::Send { slices, message }) + .map_err(to_py_error) } - fn _drain_and_stop(&mut self, py: Python<'_>) -> PyResult<()> { - self.send_slice( - self.workers - .borrow() - .map_err(anyhow::Error::msg)? - .proc_mesh() - .shape() - .slice() - .clone(), - WorkerMessage::Exit { error: None }, - )?; - let instance = self.controller_instance.clone(); - let _ = signal_safe_block_on(py, async move { instance.lock().await.drain_and_stop() })??; - Ok(()) + fn _drain_and_stop(&mut self) -> PyResult<()> { + self.controller_handle + .blocking_lock() + .send(ClientToControllerMessage::Send { + slices: vec![self.all_ranks.clone()], + message: WorkerMessage::Exit { error: None }, + }) + .map_err(to_py_error)?; + self.controller_handle + .blocking_lock() + .drain_and_stop() + .map_err(to_py_error) } } -pub(crate) fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> { - module.add_class::<_Controller>()?; - Ok(()) -} - /// An invocation tracks a discrete node in the graph of operations executed by /// the worker based on instructions from the client. 
/// It is useful for tracking the dependencies of an operation and propagating @@ -304,12 +202,26 @@ pub(crate) fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult #[derive(Debug)] enum Status { - Errored(Exception), - Complete(), + Errored { + exception: Arc, + }, + Complete {}, /// When incomplete this holds this list of users of this invocation, /// so a future error can be propagated to them., - Incomplete(HashMap>>), + Incomplete { + users: HashMap>>, + results: Vec, + }, +} + +impl Status { + fn incomplete() -> Status { + Self::Incomplete { + users: HashMap::new(), + results: vec![], + } + } } #[derive(Debug)] struct Invocation { @@ -320,85 +232,124 @@ struct Invocation { /// Result reported to a future if this invocation was a fetch /// Not all Invocations will be fetched so sometimes a Invocation will complete with /// both result and error == None - result: Option, + response_port: Option, + tracebacks: Py, } impl Invocation { - fn new(seq: Seq) -> Self { + fn new(seq: Seq, tracebacks: Py, response_port: Option) -> Self { Self { seq, - status: Status::Incomplete(HashMap::new()), - result: None, + status: Status::incomplete(), + response_port, + tracebacks, } } - fn add_user(&mut self, user: Arc>) { + fn add_user( + &mut self, + sender: &impl CanSend, + unreported_exception: &mut Option>, + user: Arc>, + ) -> Result<(), MailboxSenderError> { match &mut self.status { - Status::Complete() => {} - Status::Incomplete(users) => { + Status::Complete {} => {} + Status::Incomplete { users, .. } => { let seq = user.lock().unwrap().seq; users.insert(seq, user); } - Status::Errored(err) => { - user.lock().unwrap().set_exception(err.clone()); + Status::Errored { exception } => { + user.lock().unwrap().set_exception( + sender, + unreported_exception, + exception.clone(), + )?; } } + Ok(()) } /// Invocation results can only go from valid to failed, or be /// set if the invocation result is empty. - fn set_result(&mut self, result: Serialized) { - if self.result.is_none() { - self.result = Some(result); + fn set_result(&mut self, result: PythonMessage) { + match &mut self.status { + Status::Incomplete { results, .. } => { + results.push(result); + } + Status::Errored { .. } => {} + Status::Complete {} => { + panic!("setting result on a complete seq"); + } } } - fn succeed(&mut self) { - match self.status { - Status::Incomplete(_) => self.status = Status::Complete(), - _ => {} + fn complete(&mut self, sender: &impl CanSend) -> Result<(), MailboxSenderError> { + let old_status = std::mem::replace(&mut self.status, Status::Complete {}); + match old_status { + Status::Incomplete { results, .. } => match &self.response_port { + Some(PortInfo { port, ranks }) => { + assert!(ranks.len() == results.iter().len()); + for result in results.into_iter() { + port.send(sender, result)?; + } + } + None => {} + }, + _ => { + self.status = old_status; + } } + Ok(()) } - fn set_exception(&mut self, exception: Exception) -> Vec>> { - match exception { - Exception::Error(_, caused_by_new, error) => { - let err = Status::Errored(Exception::Error(self.seq, caused_by_new, error)); - match &self.status { - Status::Errored(Exception::Error(_, caused_by_current, _)) - if caused_by_new < *caused_by_current => - { - self.status = err; - } - Status::Incomplete(users) => { - let users = users.values().cloned().collect(); - self.status = err; - return users; + /// Changes the status of this invocation to an Errored. If this invocation was + /// Incomplete, it may have users that will also become errored. 
This function + /// will return those users so the error can be propagated. It does not autmoatically + /// propagate the error to avoid deep recursive invocations. + fn set_exception( + &mut self, + sender: &impl CanSend, + unreported_exception: &mut Option>, + exception: Arc, + ) -> Result<(), MailboxSenderError> { + let mut process = + |invocation: &mut Invocation, queue: &mut Vec>>| { + let err = Status::Errored { + exception: exception.clone(), + }; + let old_status = std::mem::replace(&mut invocation.status, err); + match old_status { + Status::Incomplete { users, .. } => { + match &invocation.response_port { + Some(PortInfo { port, ranks }) => { + *unreported_exception = None; + for rank in ranks.iter() { + let msg = exception.as_ref().clone().with_rank(rank); + port.send(sender, msg)?; + } + } + None => {} + }; + queue.extend(users.into_values()); } - Status::Complete() => { + Status::Complete {} => { panic!("Complete invocation getting an exception set") } - _ => {} + Status::Errored { .. } => invocation.status = old_status, } - } - Exception::Failure(_) => { - tracing::error!( - "system failures {:?} can never be assigned for an invocation", - exception - ); - } - } - vec![] - } - - fn msg_result(&self) -> Option> { - match &self.status { - Status::Complete() => self.result.clone().map(Ok), - Status::Errored(err) => Some(Err(err.clone())), - Status::Incomplete(_) => { - panic!("Incomplete invocation doesn't have a result yet") - } + Ok(()) + }; + let mut queue = vec![]; + let mut visited = HashSet::new(); + process(self, &mut queue)?; + while let Some(invocation) = queue.pop() { + let mut invocation = invocation.lock().unwrap(); + if !visited.insert(invocation.seq) { + continue; + }; + process(invocation.deref_mut(), &mut queue)?; } + Ok(()) } } @@ -423,6 +374,8 @@ struct History { // no new sequence numbers should be below this bound. use for // sanity checking. seq_lower_bound: Seq, + unreported_exception: Option>, + exit_port: Option>, } /// A vector that keeps track of the minimum value. @@ -473,6 +426,8 @@ impl History { invocation_for_ref: HashMap::new(), inflight_invocations: HashMap::new(), seq_lower_bound: 0.into(), + unreported_exception: None, + exit_port: None, } } @@ -490,10 +445,13 @@ impl History { /// Add an invocation to the history. pub fn add_invocation( &mut self, + sender: &impl CanSend, seq: Seq, uses: Vec, defs: Vec, - ) -> Vec<(Seq, Option>)> { + tracebacks: Py, + response_port: Option, + ) -> Result<(), MailboxSenderError> { assert!( seq >= self.seq_lower_bound, "nonmonotonic seq: {:?}; current lower bound: {:?}", @@ -501,76 +459,293 @@ impl History { self.seq_lower_bound, ); self.seq_lower_bound = seq; - let invocation = Arc::new(sync::Mutex::new(Invocation::new(seq))); + let invocation = Arc::new(sync::Mutex::new(Invocation::new( + seq, + tracebacks, + response_port, + ))); self.inflight_invocations.insert(seq, invocation.clone()); for ref use_ in uses { let producer = self.invocation_for_ref.get(use_).unwrap(); - producer.lock().unwrap().add_user(invocation.clone()); + producer.lock().unwrap().add_user( + sender, + &mut self.unreported_exception, + invocation.clone(), + )?; } for def in defs { self.invocation_for_ref.insert(def, invocation.clone()); } - let invocation = invocation.lock().unwrap(); - if matches!(invocation.status, Status::Errored(_)) { - vec![(seq, invocation.msg_result())] - } else { - vec![] - } + Ok(()) } /// Propagate worker error to the invocation with the given Seq. 
This will also propagate /// to all seqs that depend on this seq directly or indirectly. pub fn propagate_exception( &mut self, + sender: &impl CanSend, seq: Seq, - exception: Exception, - ) -> Vec<(Seq, Option>)> { - let mut results = Vec::new(); - let invocation = self.inflight_invocations.get(&seq).unwrap().clone(); + exception: WorkerError, + ) -> Result<(), MailboxSenderError> { + // TODO: supplement PythonMessage with the stack trace we have in invocation + let rank = exception.worker_actor_id.rank(); - let mut queue: Vec>> = vec![invocation]; - let mut visited = HashSet::new(); + let invocation = self.inflight_invocations.get(&seq).unwrap().clone(); - while let Some(invocation) = queue.pop() { - let mut invocation = invocation.lock().unwrap(); - if !visited.insert(invocation.seq) { - continue; - }; - queue.extend(invocation.set_exception(exception.clone())); - results.push((seq, invocation.msg_result())); + let python_message = Arc::new(Python::with_gil(|py| { + let traceback = invocation + .lock() + .unwrap() + .tracebacks + .bind(py) + .get_item(0) + .unwrap(); + let remote_exception = py + .import("monarch.mesh_controller") + .unwrap() + .getattr("RemoteException") + .unwrap(); + let pickle = py + .import("monarch.actor_mesh") + .unwrap() + .getattr("_pickle") + .unwrap(); + let exe = remote_exception + .call1((exception.backtrace, traceback, rank)) + .unwrap(); + let data: Vec = pickle.call1((exe,)).unwrap().extract().unwrap(); + PythonMessage::new_from_buf("exception".to_string(), data, None, Some(rank)) + })); + + let mut invocation = invocation.lock().unwrap(); + + if let Status::Incomplete { .. } = &invocation.status { + self.unreported_exception = Some(python_message.clone()); } - results + + invocation.set_exception( + sender, + &mut self.unreported_exception, + python_message.clone(), + )?; + + Ok(()) } /// Mark the given rank as completed up to but excluding the given Seq. This will also purge history for /// any Seqs that are no longer relevant (completed on all ranks). 
pub fn rank_completed( &mut self, + sender: &impl CanSend, rank: usize, seq: Seq, - ) -> Vec<(Seq, Option>)> { + ) -> Result<(), MailboxSenderError> { self.first_incomplete_seqs.set(rank, seq); let prev = self.min_incomplete_seq; self.min_incomplete_seq = self.first_incomplete_seqs.min(); - let mut results: Vec<(Seq, Option>)> = Vec::new(); for i in Seq::iter_between(prev, self.min_incomplete_seq) { - let invocation = self.inflight_invocations.remove(&i).unwrap(); - let mut invocation = invocation.lock().unwrap(); - - if matches!(invocation.status, Status::Errored(_)) { - // we already reported output early when it errored - continue; + if let Some(invocation) = self.inflight_invocations.remove(&i) { + let mut invocation = invocation.lock().unwrap(); + invocation.complete(sender)?; + } + } + if let Some(port) = &self.exit_port { + if self.min_incomplete_seq >= self.seq_lower_bound { + let result = match &self.unreported_exception { + Some(exception) => exception.as_ref().clone(), + None => { + // the byte string is just a Python None + PythonMessage::new("result".to_string(), b"\x80\x04N.", None, None) + } + }; + port.send(sender, result)?; + self.exit_port = None; } - invocation.succeed(); - results.push((i, invocation.msg_result())); } - results + Ok(()) } - pub fn set_result(&mut self, seq: Seq, result: Serialized) { + pub fn set_result(&mut self, seq: Seq, result: PythonMessage) { let invocation = self.inflight_invocations.get(&seq).unwrap(); invocation.lock().unwrap().set_result(result); } + + fn report_exit(&mut self, port: PortRef) { + self.exit_port = Some(port); + } +} + +#[derive(Debug)] +struct PortInfo { + port: PortRef, + // the slice of ranks expected to respond + // to the port. used for error reporting. + ranks: Slice, +} + +#[derive(Debug, Handler, HandleClient)] +enum ClientToControllerMessage { + Send { + slices: Vec, + message: WorkerMessage, + }, + Node { + seq: Seq, + defs: Vec, + uses: Vec, + tracebacks: Py, + response_port: Option, + }, + DropRefs { + refs: Vec, + }, + SyncAtExit { + port: PortRef, + }, +} + +struct MeshControllerActor { + proc_mesh: SharedCell, + workers: Option>>, + history: History, + id: usize, +} + +impl MeshControllerActor { + fn workers(&self) -> SharedCellRef> { + self.workers.as_ref().unwrap().borrow().unwrap() + } +} + +impl Debug for MeshControllerActor { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MeshControllerActor").finish() + } +} + +struct MeshControllerActorParams { + proc_mesh: SharedCell, + id: usize, +} + +#[async_trait] +impl Actor for MeshControllerActor { + type Params = MeshControllerActorParams; + async fn new( + MeshControllerActorParams { proc_mesh, id }: Self::Params, + ) -> Result { + let world_size = proc_mesh.borrow().unwrap().shape().slice().len(); + Ok(MeshControllerActor { + proc_mesh: proc_mesh.clone(), + workers: None, + history: History::new(world_size), + id, + }) + } + async fn init(&mut self, this: &Instance) -> Result<(), anyhow::Error> { + let controller_actor_ref: ActorRef = this.bind(); + let proc_mesh = self.proc_mesh.borrow().unwrap(); + let slice = proc_mesh.shape().slice(); + let world_size = slice.len(); + let param = WorkerParams { + world_size, + // Rank assignment is consistent with proc indices. 
+ rank: 0, + device_index: Some(0), + controller_actor: controller_actor_ref, + }; + + let workers = proc_mesh + .spawn(&format!("tensor_engine_workers_{}", self.id), ¶m) + .await?; + workers + .borrow() + .unwrap() + .cast_slices(vec![slice.clone()], AssignRankMessage::AssignRank())?; + self.workers = Some(workers); + Ok(()) + } +} + +#[async_trait] +impl Handler for MeshControllerActor { + async fn handle( + &mut self, + this: &Context, + message: ControllerMessage, + ) -> anyhow::Result<()> { + match message { + ControllerMessage::DebuggerMessage { + debugger_actor_id, + action, + } => { + let dm = crate::client::DebuggerMessage::new(debugger_actor_id.into(), action)?; + panic!("NYI: debugger message handling"); + } + ControllerMessage::Status { + seq, + worker_actor_id, + controller: false, + } => { + let rank = worker_actor_id.rank(); + self.history.rank_completed(this, rank, seq)?; + } + ControllerMessage::FetchResult { + seq, + value: Ok(value), + } => { + let msg: PythonMessage = value.deserialized().unwrap(); + self.history.set_result(seq, msg); + } + ControllerMessage::RemoteFunctionFailed { seq, error } => { + self.history.propagate_exception(this, seq, error)?; + } + message => { + panic!("unexpected message: {:?}", message); + } + }; + Ok(()) + } +} + +#[async_trait] +impl Handler for MeshControllerActor { + async fn handle( + &mut self, + this: &Context, + message: ClientToControllerMessage, + ) -> anyhow::Result<()> { + match message { + ClientToControllerMessage::Send { slices, message } => { + self.workers().cast_slices(slices, message)?; + } + ClientToControllerMessage::Node { + seq, + defs, + uses, + tracebacks, + response_port, + } => { + self.history + .add_invocation(this, seq, uses, defs, tracebacks, response_port)?; + } + ClientToControllerMessage::DropRefs { refs } => { + self.history.drop_refs(refs); + } + ClientToControllerMessage::SyncAtExit { port } => { + let all_ranks = vec![self.workers().shape().slice().clone()]; + self.workers().cast_slices( + all_ranks, + WorkerMessage::RequestStatus { + seq: self.history.seq_lower_bound, + controller: false, + }, + )?; + self.history.report_exit(port); + } + } + Ok(()) + } } diff --git a/monarch_hyperactor/src/actor.rs b/monarch_hyperactor/src/actor.rs index 2fa461bc9..6fd7c51f1 100644 --- a/monarch_hyperactor/src/actor.rs +++ b/monarch_hyperactor/src/actor.rs @@ -180,6 +180,28 @@ pub struct PythonMessage { rank: Option, } +impl PythonMessage { + pub fn with_rank(self, rank: usize) -> PythonMessage { + PythonMessage { + rank: Some(rank), + ..self + } + } + pub fn new_from_buf( + method: String, + message: Vec, + response_port: Option, + rank: Option, + ) -> Self { + Self { + method, + message: message.into(), + response_port, + rank, + } + } +} + impl std::fmt::Debug for PythonMessage { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("PythonMessage") @@ -208,18 +230,13 @@ impl Bind for PythonMessage { impl PythonMessage { #[new] #[pyo3(signature = (method, message, response_port, rank))] - fn new( + pub fn new( method: String, message: &[u8], response_port: Option, rank: Option, ) -> Self { - Self { - method, - message: ByteBuf::from(message), - response_port, - rank, - } + Self::new_from_buf(method, message.into(), response_port, rank) } #[getter] diff --git a/monarch_hyperactor/src/mailbox.rs b/monarch_hyperactor/src/mailbox.rs index 3247de3ec..e281c2207 100644 --- a/monarch_hyperactor/src/mailbox.rs +++ b/monarch_hyperactor/src/mailbox.rs @@ -329,6 +329,11 @@ impl PythonPortRef { fn 
__repr__(&self) -> String { self.inner.to_string() } + + #[getter] + fn port_id(&self) -> PyResult { + Ok(self.inner.port_id().clone().into()) + } } impl From> for PythonPortRef { @@ -470,6 +475,11 @@ impl PythonOncePortRef { .as_ref() .map_or("OncePortRef is already used".to_string(), |r| r.to_string()) } + + #[getter] + fn port_id(&self) -> PyResult { + Ok(self.inner.as_ref().unwrap().port_id().clone().into()) + } } impl From> for PythonOncePortRef { @@ -520,7 +530,7 @@ impl PythonOncePortReceiver { FromPyObject, IntoPyObject )] -pub(super) enum EitherPortRef { +pub enum EitherPortRef { Unbounded(PythonPortRef), Once(PythonOncePortRef), } diff --git a/monarch_hyperactor/src/proc_mesh.rs b/monarch_hyperactor/src/proc_mesh.rs index d0c52b2fc..c457e130a 100644 --- a/monarch_hyperactor/src/proc_mesh.rs +++ b/monarch_hyperactor/src/proc_mesh.rs @@ -111,7 +111,7 @@ impl TrackedProcMesh { module = "monarch._rust_bindings.monarch_hyperactor.proc_mesh" )] pub struct PyProcMesh { - inner: SharedCell, + pub inner: SharedCell, keepalive: Keepalive, proc_events: SharedCell>, stop_monitor_sender: mpsc::Sender, diff --git a/monarch_messages/src/worker.rs b/monarch_messages/src/worker.rs index c3236b876..bc9199709 100644 --- a/monarch_messages/src/worker.rs +++ b/monarch_messages/src/worker.rs @@ -234,18 +234,32 @@ impl FunctionPath { } pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult> { - let (module_fqn, function_name) = self.path.rsplit_once(".").with_context(|| { + let (start, rest) = self.path.split_once(".").with_context(|| { format!( "invalid function path {}: paths must be fully qualified", self.path ) })?; - let module = PyModule::import(py, module_fqn)?; - let mut function = module.getattr(function_name)?; - if function.hasattr("_remote_impl")? { - function = function.getattr("_remote_impl")?; + if start == "torch" { + let mut cur = py.import("torch")?.into_any(); + for p in rest.split(".") { + cur = cur.getattr(p)?; + } + Ok(cur) + } else { + let (module_fqn, function_name) = self.path.rsplit_once(".").with_context(|| { + format!( + "invalid function path {}: paths must be fully qualified", + self.path + ) + })?; + let module = PyModule::import(py, module_fqn)?; + let mut function = module.getattr(function_name)?; + if function.hasattr("_remote_impl")? { + function = function.getattr("_remote_impl")?; + } + Ok(function.downcast_into()?) } - Ok(function.downcast_into()?) 
} } diff --git a/monarch_tensor_worker/src/lib.rs b/monarch_tensor_worker/src/lib.rs index 6b42beff5..004feb32b 100644 --- a/monarch_tensor_worker/src/lib.rs +++ b/monarch_tensor_worker/src/lib.rs @@ -186,6 +186,7 @@ pub struct WorkerActor { send_recv_comms: HashMap<(StreamRef, StreamRef), Arc>>, recordings: HashMap, defining_recording: Option, + respond_with_python_message: bool, } impl WorkerActor { @@ -251,6 +252,7 @@ impl Actor for WorkerActor { send_recv_comms: HashMap::new(), recordings: HashMap::new(), defining_recording: None, + respond_with_python_message: false, }) } @@ -266,6 +268,7 @@ impl Handler for WorkerActor { ) -> anyhow::Result<()> { let (rank, shape) = this.cast_info()?; self.rank = rank; + self.respond_with_python_message = true; Python::with_gil(|py| { let mesh_controller = py.import("monarch.mesh_controller").unwrap(); let shape: PyShape = shape.into(); @@ -449,6 +452,7 @@ impl WorkerMessageHandler for WorkerActor { id: result, device: self.device, controller_actor: self.controller_actor.clone(), + respond_with_python_message: self.respond_with_python_message, }, ) .await?; diff --git a/monarch_tensor_worker/src/stream.rs b/monarch_tensor_worker/src/stream.rs index 55077f30a..4861ce792 100644 --- a/monarch_tensor_worker/src/stream.rs +++ b/monarch_tensor_worker/src/stream.rs @@ -36,6 +36,7 @@ use hyperactor::mailbox::Mailbox; use hyperactor::mailbox::OncePortHandle; use hyperactor::mailbox::PortReceiver; use hyperactor::proc::Proc; +use monarch_hyperactor::actor::PythonMessage; use monarch_messages::controller::ControllerMessageClient; use monarch_messages::controller::Seq; use monarch_messages::controller::WorkerError; @@ -424,6 +425,7 @@ pub struct StreamActor { remote_process_groups: HashMap, recordings: HashMap, active_recording: Option, + respond_with_python_message: bool, } /// Parameters for creating a [`Stream`]. @@ -440,6 +442,7 @@ pub struct StreamParams { pub device: Option, /// Actor ref of the controller that created this stream. pub controller_actor: ActorRef, + pub respond_with_python_message: bool, } #[async_trait] @@ -453,6 +456,7 @@ impl Actor for StreamActor { device, controller_actor, creation_mode, + respond_with_python_message, }: Self::Params, ) -> Result { Ok(Self { @@ -467,6 +471,7 @@ impl Actor for StreamActor { remote_process_groups: HashMap::new(), recordings: HashMap::new(), active_recording: None, + respond_with_python_message, }) } @@ -710,10 +715,11 @@ impl StreamActor { }) } - fn call_python_fn( + fn call_python_fn<'py>( &mut self, + py: Python<'py>, this: &Instance, - function: ResolvableFunction, + function: Option, args: Vec, kwargs: HashMap, mutates: &[Ref], @@ -722,149 +728,176 @@ impl StreamActor { Ref, (DeviceMesh, Vec, Arc>), >, - ) -> Result, CallFunctionError> { - Python::with_gil(|py| { - let function = function.resolve(py).map_err(|e| { - CallFunctionError::InvalidRemoteFunction(format!( - "failed to resolve function {}: {}", - function, - SerializablePyErr::from(py, &e) - )) - })?; - - let remote_process_groups = remote_process_groups - .into_iter() - .map(|(gref, (mesh, dims, comm))| { - let group = match self.remote_process_groups.entry(gref) { - Entry::Occupied(ent) => ent.get().clone_ref(py), - Entry::Vacant(ent) => { - // We need to run `init_process_group` before any - // remote process groups can get created. - torch_sys::backend::ensure_init_process_group( - py, - self.world_size, - self.rank, - )?; - - // Create a backend object to wrap the comm and use - // it to create a new torch group. 
- let ranks = mesh.get_ranks_for_dim_slice(&dims)?; - let group_size = ranks.len(); - let backend = CommBackend::new( - comm, - Mailbox::new_detached(this.self_id().clone()), - self.rank, - group_size, - self.world_size, - ); - ent.insert(torch_sys::backend::new_group(py, ranks, backend)?.unbind()) - .clone_ref(py) - } - }; - PyResult::Ok((gref, group)) + ) -> Result, CallFunctionError> { + let function = function + .map(|function| { + function.resolve(py).map_err(|e| { + CallFunctionError::InvalidRemoteFunction(format!( + "failed to resolve function {}: {}", + function, + SerializablePyErr::from(py, &e) + )) }) - .collect::, _>>() - .map_err(SerializablePyErr::from_fn(py))?; - - // SAFETY: We will be making an unchecked clone of each tensor to pass to to - // C++, so we need to hold a borrow of each input tensor for the duration of - // this function. - let mut multiborrow = MultiBorrow::new(); - - let resolve = |val: WireValue| { - val.into_py_object() - .map_err(|e| { - CallFunctionError::UnsupportedArgType( - format!("{:?}", function), - format!("{:?}", e), - ) - })? - .unpickle(py) - .map_err(SerializablePyErr::from_fn(py))? - .extract::>() - .map_err(SerializablePyErr::from_fn(py))? - .try_into_map(|obj| { - Ok(if let Ok(ref_) = Ref::from_py_object(obj.bind(py)) { - if let Some(mesh) = device_meshes.get(&ref_) { - PyArg::DeviceMesh(mesh) - } else if let Some(pg) = remote_process_groups.get(&ref_) { - PyArg::PyObject(pg.clone_ref(py)) - } else { - let rval = self.ref_to_rvalue(&ref_)?; - PyArg::RValue(rval) - } - } else { - PyArg::PyObject(obj) - }) - }) - }; + }) + .transpose()?; - // Resolve refs - let py_args: Vec> = args - .into_iter() - .map(resolve) - .collect::>()?; - let py_kwargs: HashMap<_, PyTree> = kwargs - .into_iter() - .map(|(k, object)| Ok((k, resolve(object)?))) - .collect::>()?; - - // Add a shared-borrow for each rvalue reference. - py_args - .iter() - .chain(py_kwargs.values()) - .flat_map(|o| o.iter()) - .for_each(|arg| { - if let PyArg::RValue(rval) = arg { - multiborrow.add(rval, BorrowType::Shared); + let remote_process_groups = remote_process_groups + .into_iter() + .map(|(gref, (mesh, dims, comm))| { + let group = match self.remote_process_groups.entry(gref) { + Entry::Occupied(ent) => ent.get().clone_ref(py), + Entry::Vacant(ent) => { + // We need to run `init_process_group` before any + // remote process groups can get created. + torch_sys::backend::ensure_init_process_group( + py, + self.world_size, + self.rank, + )?; + + // Create a backend object to wrap the comm and use + // it to create a new torch group. + let ranks = mesh.get_ranks_for_dim_slice(&dims)?; + let group_size = ranks.len(); + let backend = CommBackend::new( + comm, + Mailbox::new_detached(this.self_id().clone()), + self.rank, + group_size, + self.world_size, + ); + ent.insert(torch_sys::backend::new_group(py, ranks, backend)?.unbind()) + .clone_ref(py) } - }); - - // Add mutable borrows for params we're mutating. - let mutates: Vec<_> = mutates - .iter() - .map(|r| self.ref_to_rvalue(r)) - .collect::>()?; - mutates - .iter() - .for_each(|rval| multiborrow.add(rval, BorrowType::Mutable)); + }; + PyResult::Ok((gref, group)) + }) + .collect::, _>>() + .map_err(SerializablePyErr::from_fn(py))?; + + // SAFETY: We will be making an unchecked clone of each tensor to pass to to + // C++, so we need to hold a borrow of each input tensor for the duration of + // this function. 
+ let mut multiborrow = MultiBorrow::new(); + + let resolve = |val: WireValue| { + val.into_py_object() + .map_err(|e| { + CallFunctionError::UnsupportedArgType( + format!("{:?}", function), + format!("{:?}", e), + ) + })? + .unpickle(py) + .map_err(SerializablePyErr::from_fn(py))? + .extract::>() + .map_err(SerializablePyErr::from_fn(py))? + .try_into_map(|obj| { + Ok(if let Ok(ref_) = Ref::from_py_object(obj.bind(py)) { + if let Some(mesh) = device_meshes.get(&ref_) { + PyArg::DeviceMesh(mesh) + } else if let Some(pg) = remote_process_groups.get(&ref_) { + PyArg::PyObject(pg.clone_ref(py)) + } else { + let rval = self.ref_to_rvalue(&ref_)?; + PyArg::RValue(rval) + } + } else { + PyArg::PyObject(obj) + }) + }) + }; - // Execute the borrow. - let _borrow = multiborrow.borrow()?; + // Resolve refs + let py_args: Vec> = args + .into_iter() + .map(resolve) + .collect::>()?; + let py_kwargs: HashMap<_, PyTree> = kwargs + .into_iter() + .map(|(k, object)| Ok((k, resolve(object)?))) + .collect::>()?; + + // Add a shared-borrow for each rvalue reference. + py_args + .iter() + .chain(py_kwargs.values()) + .flat_map(|o| o.iter()) + .for_each(|arg| { + if let PyArg::RValue(rval) = arg { + multiborrow.add(rval, BorrowType::Shared); + } + }); - // Call function. - // Use custom subscriber to route Worker messages to stdout. - let scoped_subscriber = Subscriber::builder().with_writer(std::io::stdout).finish(); - let result: Bound<'_, PyAny> = - tracing::subscriber::with_default(scoped_subscriber, || { + // Add mutable borrows for params we're mutating. + let mutates: Vec<_> = mutates + .iter() + .map(|r| self.ref_to_rvalue(r)) + .collect::>()?; + mutates + .iter() + .for_each(|rval| multiborrow.add(rval, BorrowType::Mutable)); + + // Execute the borrow. + let _borrow = multiborrow.borrow()?; + + // Call function. + // Use custom subscriber to route Worker messages to stdout. + let scoped_subscriber = Subscriber::builder().with_writer(std::io::stdout).finish(); + let result: Bound<'_, PyAny> = + tracing::subscriber::with_default(scoped_subscriber, || { + // SAFETY: The borrows above guard the unchecked clones done by + // `rvalue_to_ivalue`. This may result in multiple mutable + // references to tensor data, but the Python side is responsible + // for making sure that is safe + // TODO(agallagher): The args/kwargs conversion traits generate + // the appropriate types here, but they get casted to `PyAny`. + // It'd be nice to make `TryToPyObjectUnsafe` take a template + // arg for the converted py object to avoid this downcast. + let args = unsafe { py_args.try_to_object_unsafe(py) } + .map_err(SerializablePyErr::from_fn(py))?; + // SAFETY: above + let kwargs = &unsafe { py_kwargs.try_to_object_unsafe(py) } + .map_err(SerializablePyErr::from_fn(py))?; + + if let Some(function) = function { function - .call( - // SAFETY: The borrows above guard the unchecked clones done by - // `rvalue_to_ivalue`. This may result in multiple mutable - // references to tensor data, but the Python side is responsible - // for making sure that is safe - // TODO(agallagher): The args/kwargs conversion traits generate - // the appropriate types here, but they get casted to `PyAny`. - // It'd be nice to make `TryToPyObjectUnsafe` take a template - // arg for the converted py object to avoid this downcast. - unsafe { py_args.try_to_object_unsafe(py) } - .map_err(SerializablePyErr::from_fn(py))?, - Some( - // SAFETY: Same. 
- &unsafe { py_kwargs.try_to_object_unsafe(py) } - .map_err(SerializablePyErr::from_fn(py))?, - ), - ) + .call(args, Some(kwargs)) .map_err(SerializablePyErr::from_fn(py)) - })?; + } else { + Ok(args.get_item(0).unwrap()) + } + })?; + Ok(result) + } - // Parse the python result as an `Object`, which should preserve the - // original Python object structure, while providing access to the - // leaves as `RValue`s. + fn call_python_fn_pytree( + &mut self, + this: &Instance, + function: ResolvableFunction, + args: Vec, + kwargs: HashMap, + mutates: &[Ref], + device_meshes: HashMap, + remote_process_groups: HashMap< + Ref, + (DeviceMesh, Vec, Arc>), + >, + ) -> Result, CallFunctionError> { + Python::with_gil(|py| { + let result = self.call_python_fn( + py, + this, + Some(function), + args, + kwargs, + mutates, + device_meshes, + remote_process_groups, + )?; Ok(PyTree::::extract_bound(&result).map_err(SerializablePyErr::from_fn(py))?) }) } - /// Retrieve `ref_` or create a fake value with the provided factory if it /// is an error. We use this for collective calls, where even if there was /// an upstream failure, we still have participate in the collective to @@ -919,6 +952,79 @@ impl StreamActor { } Ok(None) } + async fn send_value_python_message( + &mut self, + this: &Instance, + seq: Seq, + worker_actor_id: ActorId, + mutates: Vec, + function: Option, + args: Vec, + kwargs: HashMap, + device_meshes: HashMap, + ) -> Result<()> { + let result = Python::with_gil(|py| { + let result = tokio::task::block_in_place(|| { + self.call_python_fn( + py, + this, + function, + args, + kwargs, + &mutates, + device_meshes, + HashMap::new(), + ) + }); + result + .map_err(|err| { + let err = Arc::new(err); + for ref_ in mutates { + self.env.insert(ref_, Err(err.clone())); + } + let err = err.unwrap_dependent_error().unwrap_or(err); + WorkerError { + backtrace: format!("{:?}", err), + worker_actor_id: worker_actor_id.clone(), + } + }) + .and_then(|result| -> Result { + let pickle = py + .import("monarch.actor_mesh") + .unwrap() + .getattr("_pickle") + .unwrap(); + let data: Vec = pickle + .call1((result,)) + .map_err(|pyerr| WorkerError { + backtrace: SerializablePyErr::from(py, &pyerr).to_string(), + worker_actor_id: worker_actor_id.clone(), + })? + .extract() + .unwrap(); + Ok(PythonMessage::new_from_buf( + "result".to_string(), + data, + None, + Some(worker_actor_id.rank()), + )) + }) + }); + match result { + Ok(value) => { + let ser = Serialized::serialize(&value).unwrap(); + self.controller_actor + .fetch_result(this, seq, Ok(ser)) + .await?; + } + Err(e) => { + self.controller_actor + .remote_function_failed(this, seq, e) + .await?; + } + } + Ok(()) + } } #[async_trait] @@ -955,7 +1061,7 @@ impl StreamMessageHandler for StreamActor { // Use block-in-place to allow nested callbacks to re-enter the // runtime to run async code. tokio::task::block_in_place(|| { - self.call_python_fn( + self.call_python_fn_pytree( this, params.function, params.args, @@ -1428,6 +1534,20 @@ impl StreamMessageHandler for StreamActor { device_meshes: HashMap, pipe: Option>, ) -> Result<()> { + if self.respond_with_python_message && pipe.is_none() { + return self + .send_value_python_message( + this, + seq, + worker_actor_id, + mutates, + function, + args, + kwargs, + device_meshes, + ) + .await; + } let result = if let Some(function) = function { // If a function was provided, use that to resolve the value. 
match function.as_torch_op() { @@ -1457,7 +1577,7 @@ impl StreamMessageHandler for StreamActor { // Use block-in-place to allow nested callbacks to re-enter the // runtime to run async code. _ => tokio::task::block_in_place(|| { - self.call_python_fn( + self.call_python_fn_pytree( this, function, args, @@ -2016,6 +2136,7 @@ mod tests { id: 0.into(), device: Some(CudaDevice::new(0.into())), controller_actor: controller_actor.clone(), + respond_with_python_message: false, }, ) .await?; @@ -2200,6 +2321,7 @@ mod tests { id: 0.into(), device: None, controller_actor: controller_ref, + respond_with_python_message: false, }; let mut actor = StreamActor::new(param).await.unwrap(); @@ -3039,6 +3161,7 @@ mod tests { id: 1.into(), device: Some(CudaDevice::new(0.into())), controller_actor: test_setup.controller_actor.clone(), + respond_with_python_message: false, }, ) .await?; @@ -3656,6 +3779,7 @@ mod tests { id: 1.into(), device: Some(CudaDevice::new(1.into())), controller_actor: test_setup.controller_actor.clone(), + respond_with_python_message: false, }, ) .await?; diff --git a/python/monarch/_rust_bindings/monarch_extension/mesh_controller.pyi b/python/monarch/_rust_bindings/monarch_extension/mesh_controller.pyi index f50b7da28..dbdba70e7 100644 --- a/python/monarch/_rust_bindings/monarch_extension/mesh_controller.pyi +++ b/python/monarch/_rust_bindings/monarch_extension/mesh_controller.pyi @@ -4,9 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import List, NamedTuple, Sequence, Union +from traceback import FrameSummary +from typing import List, NamedTuple, Sequence, Tuple, Union from monarch._rust_bindings.monarch_extension import client +from monarch._rust_bindings.monarch_hyperactor.mailbox import PortId from monarch._rust_bindings.monarch_hyperactor.proc import ActorId from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ProcMesh @@ -15,7 +17,12 @@ from monarch._rust_bindings.monarch_hyperactor.shape import Slice as NDSlice class _Controller: def __init__(self) -> None: ... def node( - self, seq: int, defs: Sequence[object], uses: Sequence[object] + self, + seq: int, + defs: Sequence[object], + uses: Sequence[object], + port: Tuple[PortId, NDSlice] | None, + tracebacks: List[List[FrameSummary]], ) -> None: ... def drop_refs(self, refs: Sequence[object]) -> None: ... def send( @@ -23,11 +30,15 @@ class _Controller: ranks: Union[NDSlice, List[NDSlice]], msg: NamedTuple, ) -> None: ... - def _get_next_message( - self, *, timeout_msec: int | None = None - ) -> client.WorkerResponse | client.DebuggerMessage | None: ... def _debugger_attach(self, debugger_actor_id: ActorId) -> None: ... def _debugger_write(self, debugger_actor_id: ActorId, data: bytes) -> None: ... def _drain_and_stop( self, ) -> List[client.LogMessage | client.WorkerResponse | client.DebuggerMessage]: ... + def sync_at_exit(self, port: PortId) -> None: + """ + Controller waits until all nodes that were added are complete, then replies on the + given port. The port will get an exception if there was a known error that was not reported + to any future. + """ + ... 
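The `.pyi` stub above captures the new client-facing contract introduced by this patch: a `node` call now carries an optional `(PortId, NDSlice)` response port plus tracebacks, so fetched values and remote exceptions flow straight back over that port instead of being polled via the removed `_get_next_message`, and `sync_at_exit` replaces the old drain loop. Below is a minimal illustrative sketch (not part of the patch) of how a client might drive that contract, assuming only names that appear elsewhere in this diff (`PortTuple.create`, `_port_ref.port_id`, `NDSlice.new_row_major`, `receiver.recv().get`); the helper names `fetch_example` and `exit_example` are hypothetical.

    import traceback
    from monarch.actor_mesh import PortTuple
    from monarch.common.shape import NDSlice

    def fetch_example(controller, mailbox, seq, defs, uses):
        # Open a once-port; its PortId plus the slice of responding ranks is
        # handed to the controller, which sends the fetched value (or a
        # pickled remote exception) back on that port when the seq completes.
        sender, receiver = PortTuple.create(mailbox, once=True)
        ranks = NDSlice.new_row_major([])  # empty slice, as in new_node_nocoalesce
        controller.node(
            seq,
            defs,
            uses,
            (sender._port_ref.port_id, ranks),
            [traceback.extract_stack()],  # tracebacks used when building error messages
        )
        return receiver

    def exit_example(controller, mailbox):
        # At shutdown, ask the controller to reply once every outstanding seq
        # has completed on all ranks; an unreported error arrives here instead.
        sender, receiver = PortTuple.create(mailbox, once=True)
        controller.sync_at_exit(sender._port_ref.port_id)
        receiver.recv().get(timeout=60)

This mirrors what `MeshClient.fetch` and `MeshClient.shutdown` do in the `python/monarch/mesh_controller.py` hunk later in this diff.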
diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/mailbox.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/mailbox.pyi index 072137363..35743679d 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/mailbox.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/mailbox.pyi @@ -70,6 +70,9 @@ class PortRef: def send(self, mailbox: Mailbox, message: PythonMessage) -> None: """Send a single message to the port's receiver.""" ... + + @property + def port_id(self) -> PortId: ... def __repr__(self) -> str: ... @final @@ -119,6 +122,9 @@ class OncePortRef: def send(self, mailbox: Mailbox, message: PythonMessage) -> None: """Send a single message to the port's receiver.""" ... + + @property + def port_id(self) -> PortId: ... def __repr__(self) -> str: ... @final diff --git a/python/monarch/_testing.py b/python/monarch/_testing.py index 72c1b3e4e..67bd02f87 100644 --- a/python/monarch/_testing.py +++ b/python/monarch/_testing.py @@ -228,3 +228,4 @@ def exit( class BackendType: PY = "py" RS = "rs" + MESH = "mesh" diff --git a/python/monarch/actor_mesh.py b/python/monarch/actor_mesh.py index 613fda8b0..a7f00cffb 100644 --- a/python/monarch/actor_mesh.py +++ b/python/monarch/actor_mesh.py @@ -10,12 +10,13 @@ import contextvars import functools import inspect - +import io import itertools import logging import random import sys import traceback +from contextlib import contextmanager from dataclasses import dataclass from traceback import extract_tb, StackSummary @@ -31,6 +32,7 @@ Iterable, List, Literal, + NamedTuple, Optional, ParamSpec, Tuple, @@ -40,6 +42,8 @@ ) import monarch + +import torch from monarch import ActorFuture as Future from monarch._rust_bindings.hyperactor_extension.telemetry import enter_span, exit_span @@ -410,19 +414,44 @@ def send(self, method: str, obj: R) -> None: ) +R = TypeVar("R") + +T = TypeVar("T") + +if TYPE_CHECKING: + # Python <= 3.10 cannot inherit from Generic[R] and NamedTuple at the same time. + # we only need it for type checking though, so copypasta it until 3.11. + class PortTuple(NamedTuple, Generic[R]): + sender: "Port[R]" + receiver: "PortReceiver[R]" + + @staticmethod + def create(mailbox: Mailbox, once: bool = False) -> "PortTuple[Any]": + handle, receiver = mailbox.open_once_port() if once else mailbox.open_port() + port_ref = handle.bind() + return PortTuple( + Port(port_ref, mailbox, rank=None), PortReceiver(mailbox, receiver) + ) +else: + + class PortTuple(NamedTuple): + sender: "Port[Any]" + receiver: "PortReceiver[Any]" + + @staticmethod + def create(mailbox: Mailbox, once: bool = False) -> "PortTuple[Any]": + handle, receiver = mailbox.open_once_port() if once else mailbox.open_port() + port_ref = handle.bind() + return PortTuple( + Port(port_ref, mailbox, rank=None), PortReceiver(mailbox, receiver) + ) + + # advance lower-level API for sending messages. This is intentially # not part of the Endpoint API because they way it accepts arguments # and handles concerns is different. 
-def port( - endpoint: Endpoint[P, R], once: bool = False -) -> Tuple["Port[R]", "PortReceiver[R]"]: - handle, receiver = ( - endpoint._mailbox.open_once_port() if once else endpoint._mailbox.open_port() - ) - port_ref: PortRef | OncePortRef = handle.bind() - return Port(port_ref, endpoint._mailbox, rank=None), PortReceiver( - endpoint._mailbox, receiver - ) +def port(endpoint: Endpoint[P, R], once: bool = False) -> "PortTuple[R]": + return PortTuple.create(endpoint._mailbox, once) def ranked_port( @@ -599,10 +628,25 @@ def _pickle(obj: object) -> bytes: return msg +@contextmanager +def _load_tensors_on_cpu(): + # Ensure that any tensors load from CPU via monkeypatching how Storages are + # loaded. + old = torch.storage._load_from_bytes + try: + torch.storage._load_from_bytes = lambda b: torch.load( + io.BytesIO(b), map_location="cpu", weights_only=False + ) + yield + finally: + torch.storage._load_from_bytes = old + + def _unpickle(data: bytes, mailbox: Mailbox) -> Any: - # regardless of the mailboxes of the remote objects - # they all become the local mailbox. - return unflatten(data, itertools.repeat(mailbox)) + with _load_tensors_on_cpu(): + # regardless of the mailboxes of the remote objects + # they all become the local mailbox. + return unflatten(data, itertools.repeat(mailbox)) class Actor(MeshTrait): diff --git a/python/monarch/common/pickle_flatten.py b/python/monarch/common/pickle_flatten.py index 8033387d1..557b5399d 100644 --- a/python/monarch/common/pickle_flatten.py +++ b/python/monarch/common/pickle_flatten.py @@ -9,6 +9,7 @@ from typing import Any, Callable, Iterable, List, Tuple import cloudpickle +import torch class _Pickler(cloudpickle.Pickler): @@ -44,5 +45,6 @@ def flatten(obj: Any, filter: Callable[[Any], bool]) -> Tuple[List[Any], bytes]: def unflatten(data: bytes, values: Iterable[Any]) -> Any: - up = _Unpickler(data, values) - return up.load() + with torch.utils._python_dispatch._disable_current_modes(): + up = _Unpickler(data, values) + return up.load() diff --git a/python/monarch/common/remote.py b/python/monarch/common/remote.py index bdd0c1d23..01f55e804 100644 --- a/python/monarch/common/remote.py +++ b/python/monarch/common/remote.py @@ -8,7 +8,6 @@ import functools import logging -import warnings from logging import Logger from typing import ( @@ -29,7 +28,7 @@ import torch -from monarch.common import _coalescing, device_mesh, messages, stream +from monarch.common import _coalescing, device_mesh, stream if TYPE_CHECKING: from monarch.common.client import Client @@ -62,8 +61,6 @@ R = TypeVar("R") T = TypeVar("T") -Propagator = Callable | Literal["mocked", "cached", "inspect"] | None - class Remote(Generic[P, R]): def __init__(self, impl: Any, propagator_arg: Propagator): @@ -180,15 +177,9 @@ def _call_on_shard_and_fetch( client: "Client" = mesh.client if _coalescing.is_active(client): raise NotImplementedError("NYI: fetching results during a coalescing block") + stream_ref = stream._active._to_ref(client) return client.fetch( - mesh, - stream._active._to_ref(client), - shard, - preprocess_message, - args, - kwargs, - mutates, - dtensors, + mesh, stream_ref, shard, preprocess_message, args, kwargs, mutates, dtensors ) diff --git a/python/monarch/mesh_controller.py b/python/monarch/mesh_controller.py index de485bd94..622ddd762 100644 --- a/python/monarch/mesh_controller.py +++ b/python/monarch/mesh_controller.py @@ -7,23 +7,39 @@ import atexit import logging import os -import time import traceback from collections import deque from logging import Logger -from 
typing import List, NamedTuple, Optional, TYPE_CHECKING, Union +from typing import ( + Any, + cast, + List, + NamedTuple, + Optional, + Sequence, + Tuple, + TYPE_CHECKING, + Union, +) import torch.utils._python_dispatch -from monarch import NDSlice -from monarch._rust_bindings.monarch_extension import client, debugger +from monarch._rust_bindings.monarch_extension import client from monarch._rust_bindings.monarch_extension.client import ( # @manual=//monarch/monarch_extension:monarch_extension WorldState, ) from monarch._rust_bindings.monarch_extension.mesh_controller import _Controller +from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox from monarch._rust_bindings.monarch_hyperactor.proc import ( # @manual=//monarch/monarch_extension:monarch_extension ActorId, ) +from monarch.actor_mesh import Port, PortTuple +from monarch.common import messages +from monarch.common.controller_api import TController +from monarch.common.invocation import Seq +from monarch.common.shape import NDSlice +from monarch.common.stream import StreamRef +from monarch.common.tensor import Tensor if TYPE_CHECKING: from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ( @@ -33,14 +49,12 @@ from monarch._rust_bindings.monarch_hyperactor.shape import Point -from monarch._rust_bindings.monarch_messages.debugger import DebuggerAction from monarch.common.client import Client from monarch.common.controller_api import LogMessage, MessageResult -from monarch.common.device_mesh import DeviceMesh, no_mesh +from monarch.common.device_mesh import DeviceMesh +from monarch.common.future import Future as OldFuture from monarch.common.invocation import DeviceException, RemoteException -from monarch.controller.debugger import read as debugger_read, write as debugger_write from monarch.rust_local_mesh import _get_worker_exec_info -from pyre_extensions import none_throws logger: Logger = logging.getLogger(__name__) @@ -48,6 +62,7 @@ class Controller(_Controller): def __init__(self, workers: "HyProcMesh") -> None: super().__init__() + self._mailbox: Mailbox = workers.client # Buffer for messages unrelated to debugging that are received while a # debugger session is active. 
self._non_debugger_pending_messages: deque[ @@ -58,19 +73,9 @@ def __init__(self, workers: "HyProcMesh") -> None: def next_message( self, timeout: Optional[float] ) -> Optional[LogMessage | MessageResult]: - if self._non_debugger_pending_messages: - msg = self._non_debugger_pending_messages.popleft() - else: - msg = self._get_next_message(timeout_msec=int((timeout or 0.0) * 1000.0)) - if msg is None: - return None - - if isinstance(msg, client.WorkerResponse): - return _worker_response_to_result(msg) - elif isinstance(msg, client.LogMessage): - return LogMessage(msg.level, msg.message) - elif isinstance(msg, client.DebuggerMessage): - self._run_debugger_loop(msg) + raise RuntimeError( + "internal error: tensor engine does not produce futures that call next_message" + ) def send( self, @@ -86,56 +91,6 @@ def drain_and_stop( self._drain_and_stop() return [] - def _run_debugger_loop(self, message: client.DebuggerMessage) -> None: - if not isinstance(message.action, DebuggerAction.Paused): - raise RuntimeError( - f"Unexpected debugger message {message} when no debugger session is running" - ) - - self._pending_debugger_sessions.append(message.debugger_actor_id) - while self._pending_debugger_sessions: - debugger_actor_id = self._pending_debugger_sessions.popleft() - rank = debugger_actor_id.rank - proc_id = debugger_actor_id.proc_id - debugger_write( - f"pdb attached to proc {proc_id} with rank {rank}, debugger actor {debugger_actor_id} \n" - ) - - self._debugger_attach(debugger_actor_id) - while True: - # TODO: Add appropriate timeout. - msg = self._get_next_message(timeout_msec=None) - - if not isinstance(msg, client.DebuggerMessage): - self._non_debugger_pending_messages.append(msg) - continue - - if msg.debugger_actor_id != debugger_actor_id: - if isinstance(msg.action, DebuggerAction.Paused): - self._pending_debugger_sessions.append(msg.debugger_actor_id) - continue - else: - raise RuntimeError( - f"unexpected debugger message {msg} from rank {msg.debugger_actor_id.rank} " - f"when debugging rank {debugger_actor_id.rank}" - ) - - action = msg.action - if isinstance(action, DebuggerAction.Detach): - break - elif isinstance(action, DebuggerAction.Read): - self._debugger_write( - debugger_actor_id, debugger_read(action.requested_size) - ) - elif isinstance(action, DebuggerAction.Write): - debugger_write( - debugger.get_bytes_from_write_action(action).decode() - ) - else: - raise RuntimeError( - f"unexpected debugger message {msg} when debugging rank {debugger_actor_id.rank}" - ) - def worker_world_state(self) -> WorldState: raise NotImplementedError("worker world state") @@ -145,54 +100,6 @@ def stop_mesh(self): pass -# TODO: Handling conversion of the response can move to a separate module over time -# especially as we have structured error messages. -def _worker_response_to_result(result: client.WorkerResponse) -> MessageResult: - if not result.is_exception(): - # The result of the message needs to be unwrapped on a real device. - # Staying as a fake tensor will fail the tensor deserialization. 
- with no_mesh.activate(): - return MessageResult(result.seq, result.result(), None) - exc = none_throws(result.exception()) - if isinstance(exc, client.Error): - worker_frames = [ - traceback.FrameSummary("", None, frame) - for frame in exc.backtrace.split("\\n") - ] - return MessageResult( - seq=result.seq, - result=None, - error=RemoteException( - seq=exc.caused_by_seq, - exception=RuntimeError(exc.backtrace), - controller_frame_index=0, # TODO: T225205291 fix this once we have recording support in rust - controller_frames=None, - worker_frames=worker_frames, - source_actor_id=exc.actor_id, - message=f"Remote function in {exc.actor_id} errored.", - ), - ) - elif isinstance(exc, client.Failure): - frames = [ - traceback.FrameSummary("", None, frame) - for frame in exc.backtrace.split("\n") - ] - reason = f"Actor {exc.actor_id} crashed on {exc.address}, check the host log for details" - logger.error(reason) - return MessageResult( - seq=0, # seq is not consumed for DeviceException; it will be directly thrown by the client - result=None, - error=DeviceException( - exception=RuntimeError(reason), - frames=frames, - source_actor_id=exc.actor_id, - message=reason, - ), - ) - else: - raise RuntimeError(f"Unknown exception type: {type(exc)}") - - def _initialize_env(worker_point: Point, proc_id: str) -> None: worker_rank = worker_point.rank try: @@ -219,6 +126,40 @@ def _initialize_env(worker_point: Point, proc_id: str) -> None: class MeshClient(Client): + def fetch( + self, + mesh: "DeviceMesh", + stream: "StreamRef", + shard, + preprocess_message, + args, + kwargs, + defs: Tuple["Tensor", ...], + uses: Tuple["Tensor", ...], + ) -> "OldFuture": # the OldFuture is a lie + sender, receiver = PortTuple.create(self._mesh_controller._mailbox, once=True) + + ident = self.new_node(defs, uses, cast("OldFuture", sender)) + process = mesh._process(shard) + self.send( + process, + messages.SendValue( + ident, + None, + defs, + preprocess_message, + args, + kwargs, + stream, + ), + ) + # we have to ask for status updates + # from workers to be sure they have finished + # enough work to count this future as finished, + # and all potential errors have been reported + self._request_status() + return cast("OldFuture", receiver.recv()) + def shutdown( self, destroy_pg: bool = True, @@ -232,27 +173,42 @@ def shutdown( atexit.unregister(self._atexit) self._shutdown = True - # ensure all pending work is finished. 
- # all errors must be messaged back at this point - self.new_node_nocoalesce([], [], None, []) - self._request_status() - - ttl = 60 - start_time = time.time() - end_time = start_time + ttl - while ttl > 0 and self.last_assigned_seq > self.last_processed_seq: - ttl = end_time - time.time() - self.handle_next_message(ttl) - if self._pending_shutdown_error: - raise self._pending_shutdown_error - - if ttl <= 0: - raise RuntimeError("shutdown timed out") - + sender, receiver = PortTuple.create(self._mesh_controller._mailbox, once=True) + self._mesh_controller.sync_at_exit(sender._port_ref.port_id) + receiver.recv().get(timeout=60) # we are not expecting anything more now, because we already # waited for the responses self.inner.drain_and_stop() + @property + def _mesh_controller(self) -> Controller: + return cast(Controller, self.inner) + + def new_node_nocoalesce( + self, + defs: Sequence["Tensor"], + uses: Sequence["Tensor"], + future: Optional["OldFuture"], + tracebacks: List[List[traceback.FrameSummary]], + ) -> Seq: + seq = self._next_seq() + for d in defs: + d._seq = seq + response_port = None + if future is not None: + # method annotation is a lie to make Client happy + port = cast("Port[Any]", future) + slice = NDSlice.new_row_major([]) + response_port = (port._port_ref.port_id, slice) + self._mesh_controller.node(seq, defs, uses, response_port, tracebacks) + return seq + + def handle_next_message(self, timeout: Optional[float]) -> bool: + """ + Mesh controller message loop is handled by the tokio event loop. + """ + return False + def spawn_tensor_engine(proc_mesh: "ProcMesh") -> DeviceMesh: # This argument to Controller @@ -260,7 +216,7 @@ def spawn_tensor_engine(proc_mesh: "ProcMesh") -> DeviceMesh: # report the proc ID instead of the rank it currently does. 
gpus = proc_mesh.sizes.get("gpus", 1) backend_ctrl = Controller(proc_mesh._proc_mesh) - client = MeshClient(backend_ctrl, proc_mesh.size(), gpus) + client = MeshClient(cast("TController", backend_ctrl), proc_mesh.size(), gpus) dm = DeviceMesh( client, NDSlice.new_row_major(list(proc_mesh.sizes.values())), @@ -268,3 +224,28 @@ def spawn_tensor_engine(proc_mesh: "ProcMesh") -> DeviceMesh: ) dm.exit = lambda: client.shutdown() return dm + + +class RemoteException(Exception): + def __init__( + self, + worker_error_string: str, # this should really be an exception + stacktrace but + # worker code needs major refactor to make this possible + controller_frames: List[traceback.FrameSummary], + rank: int, + ): + self.worker_error_string = worker_error_string + self.controller_frames = controller_frames + self.rank = rank + + def __str__(self): + try: + controller_tb = "".join(traceback.format_list(self.controller_frames)) + return ( + f"A remote function has failed asynchronously on rank {self.rank}.\n" + f"Traceback of where the remote function was issued on controller (most recent call last):\n{controller_tb}" + f"Error as reported from worker!!!!!!!:\n{self.worker_error_string}" + ) + except Exception: + traceback.print_exc() + return "" diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index ff86453a7..b01248227 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -158,7 +158,8 @@ async def test_proc_mesh_rdma(): x = await client_gpu.get_buffer.call_one() buffer_gpu = x.view(torch.float32).view(10, 10) assert torch.sum(buffer_gpu) == 0 - assert buffer_gpu.device.type == "cuda" + # copying a tensor across hosts moves it to CPU + assert buffer_gpu.device.type == "cpu" # Modify server state again await server.update.call_one() diff --git a/python/tests/test_remote_functions.py b/python/tests/test_remote_functions.py index 058b8dfab..7b1b33983 100644 --- a/python/tests/test_remote_functions.py +++ b/python/tests/test_remote_functions.py @@ -25,9 +25,10 @@ Pipe, remote, remote_generator, - RemoteException, + RemoteException as OldRemoteException, Stream, ) + from monarch._testing import BackendType, TestingContext from monarch.builtins.log import log_remote from monarch.builtins.random import set_manual_seed_remote @@ -35,6 +36,7 @@ from monarch.common import remote as remote_module from monarch.common.device_mesh import DeviceMesh from monarch.common.remote import Remote +from monarch.mesh_controller import RemoteException as NewRemoteException from monarch.opaque_module import OpaqueModule from monarch.opaque_object import opaque_method, OpaqueObject @@ -57,6 +59,8 @@ from monarch_supervisor.logging import fix_exception_lines from torch.distributed import ReduceOp +RemoteException = (NewRemoteException, OldRemoteException) + def custom_excepthook(exc_type, exc_value, exc_traceback): tb_lines = fix_exception_lines( @@ -326,7 +330,7 @@ def test_eager_remote_function_failed(self, backend_type): _ = fetch_shard(a).result(timeout=40) def test_set_device_inside_udf_fails_with_explanation(self, backend_type): - if backend_type == BackendType.PY: + if backend_type != BackendType.RS: pytest.skip("Python support not planned for this test") with self.local_device_mesh(2, 2, backend_type): t = set_device_udf(2) diff --git a/python/tests/test_tensor_engine.py b/python/tests/test_tensor_engine.py index 36d69cdf2..6098c32c7 100644 --- a/python/tests/test_tensor_engine.py +++ b/python/tests/test_tensor_engine.py @@ -7,6 +7,7 @@ import 
monarch
 import pytest
 import torch
+from monarch import remote
 from monarch.mesh_controller import spawn_tensor_engine
 from monarch.proc_mesh import proc_mesh
@@ -32,6 +33,14 @@ def test_tensor_engine() -> None:
     assert torch.allclose(torch.zeros(3, 4), r)
     assert torch.allclose(torch.zeros(3, 4), f)
+    @remote(propagate=lambda x: x)
+    def nope(x):
+        raise ValueError("nope")
+
+    with pytest.raises(monarch.mesh_controller.RemoteException):
+        with dm.activate():
+            monarch.inspect(nope(torch.zeros(3, 4)))
+
     dm.exit()
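
For context, a minimal sketch of how the new mesh-controller path in this diff can be exercised end to end: spawning a tensor engine on a proc mesh and observing a worker failure surface as the new monarch.mesh_controller.RemoteException, mirroring the added test above. This is illustrative only; the helper names `failing` and `main`, and the assumption that a local two-GPU proc mesh is available via proc_mesh(gpus=2).get(), are not part of the change.

import monarch
import torch
from monarch import remote
from monarch.mesh_controller import RemoteException, spawn_tensor_engine
from monarch.proc_mesh import proc_mesh


@remote(propagate=lambda x: x)  # propagate runs only for shape inference on the controller; the body executes on workers
def failing(x):
    raise ValueError("boom")


def main() -> None:
    pm = proc_mesh(gpus=2).get()   # assumes a local proc mesh with 2 GPUs is available
    dm = spawn_tensor_engine(pm)   # DeviceMesh backed by MeshClient and the rust-side _Controller
    try:
        with dm.activate():
            # inspect() resolves its value through the client's fetch path, so the
            # worker's ValueError comes back asynchronously and is re-raised here.
            monarch.inspect(failing(torch.zeros(3, 4)))
    except RemoteException as e:
        print(f"rank {e.rank} failed:\n{e.worker_error_string}")
    finally:
        dm.exit()


if __name__ == "__main__":
    main()

Because MeshClient.fetch routes each result through a one-shot PortTuple on the controller's mailbox rather than the old next_message polling loop, failures reach the caller through the returned future; Controller.next_message itself now unconditionally raises, reflecting that the tensor engine no longer produces futures that poll it.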