diff --git a/monarch_hyperactor/src/supervision.rs b/monarch_hyperactor/src/supervision.rs index 5703ea79d..fb82e67be 100644 --- a/monarch_hyperactor/src/supervision.rs +++ b/monarch_hyperactor/src/supervision.rs @@ -10,17 +10,52 @@ use hyperactor::Bind; use hyperactor::Named; use hyperactor::Unbind; use hyperactor::supervision::ActorSupervisionEvent; -use pyo3::create_exception; use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; use serde::Deserialize; use serde::Serialize; -create_exception!( - monarch._rust_bindings.monarch_hyperactor.supervision, - SupervisionError, - PyRuntimeError -); +#[pyclass( + name = "SupervisionError", + module = "monarch._rust_bindings.monarch_hyperactor.supervision", + extends = PyRuntimeError +)] +#[derive(Clone, Debug)] +pub struct SupervisionError { + #[pyo3(set)] + pub endpoint: Option, + pub message: String, +} + +#[pymethods] +impl SupervisionError { + #[new] + #[pyo3(signature = (message, endpoint=None))] + fn new(message: String, endpoint: Option) -> Self { + SupervisionError { endpoint, message } + } + + #[staticmethod] + pub fn new_err(message: String) -> PyErr { + PyRuntimeError::new_err(message) + } + + fn __str__(&self) -> String { + if let Some(ep) = &self.endpoint { + format!("Endpoint call {} failed, {}", ep, self.message) + } else { + self.message.clone() + } + } + + fn __repr__(&self) -> String { + if let Some(ep) = &self.endpoint { + format!("SupervisionError(endpoint='{}', '{}')", ep, self.message) + } else { + format!("SupervisionError('{}')", self.message) + } + } +} #[derive(Clone, Debug, Serialize, Deserialize, Named, PartialEq, Bind, Unbind)] pub struct SupervisionFailureMessage { diff --git a/monarch_hyperactor/src/v1/actor_mesh.rs b/monarch_hyperactor/src/v1/actor_mesh.rs index 01fa7484b..366c5f670 100644 --- a/monarch_hyperactor/src/v1/actor_mesh.rs +++ b/monarch_hyperactor/src/v1/actor_mesh.rs @@ -711,12 +711,12 @@ impl ActorMeshProtocol for PythonActorMeshImpl { .unwrap_or_else(|e| e.into_inner()) { Unhealthy::StreamClosed => { - return Err(SupervisionError::new_err( + return Err(PyErr::new::( "actor mesh is stopped due to proc mesh shutdown".to_string(), )); } Unhealthy::Crashed(event) => { - return Err(SupervisionError::new_err(format!( + return Err(PyErr::new::(format!( "Actor {} is unhealthy with reason: {}", event.actor_id, event.actor_status ))); @@ -730,7 +730,7 @@ impl ActorMeshProtocol for PythonActorMeshImpl { .get(&rank) .map(|entry| entry.value().clone()) }) { - return Err(SupervisionError::new_err(format!( + return Err(PyErr::new::(format!( "Actor {} is unhealthy with reason: {}", event.actor_id, event.actor_status ))); diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/supervision.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/supervision.pyi index bedf161b6..7813342d5 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/supervision.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/supervision.pyi @@ -6,7 +6,7 @@ # pyre-unsafe -from typing import final +from typing import final, Optional @final class SupervisionError(RuntimeError): @@ -14,7 +14,7 @@ class SupervisionError(RuntimeError): Custom exception for supervision-related errors in monarch_hyperactor. """ - ... + endpoint: str | None # Settable attribute # TODO: Make this an exception subclass @final diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index 1d7221f4c..60b0b4ffa 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -72,6 +72,7 @@ Region, Shape, ) +from monarch._rust_bindings.monarch_hyperactor.supervision import SupervisionError from monarch._rust_bindings.monarch_hyperactor.v1.logging import log_endpoint_exception from monarch._rust_bindings.monarch_hyperactor.value_mesh import ( ValueMesh as HyValueMesh, @@ -488,6 +489,7 @@ class ActorEndpoint(Endpoint[P, R]): def __init__( self, actor_mesh: "ActorMeshProtocol", + mesh_name: str, shape: Shape, proc_mesh: "Optional[ProcMesh]", name: MethodSpecifier, @@ -497,6 +499,7 @@ def __init__( ) -> None: super().__init__(propagator) self._actor_mesh = actor_mesh + self._mesh_name = mesh_name self._name = name self._shape = shape self._proc_mesh = proc_mesh @@ -541,13 +544,25 @@ def _send( shape = self._shape return Extent(shape.labels, shape.ndslice.sizes) + def _full_name(self) -> str: + method_name = "unknown" + match self._name: + case MethodSpecifier.Init(): + method_name = "__init__" + case MethodSpecifier.ReturnsResponse(name=method_name): + pass + case MethodSpecifier.ExplicitPort(name=method_name): + pass + return f"{self._mesh_name}.{method_name}()" + def _port(self, once: bool = False) -> "Tuple[Port[R], PortReceiver[R]]": p, r = super()._port(once=once) instance = context().actor_instance._as_rust() monitor: Optional[Shared[Exception]] = self._actor_mesh.supervision_event( instance ) - r._set_monitor(monitor) + + r._attach_supervision(monitor, self._full_name()) return (p, r) def _rref(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> R: @@ -882,19 +897,33 @@ def __init__( mailbox: Mailbox, receiver: "PortReceiverBase", monitor: "Optional[Shared[Exception]]" = None, + endpoint: Optional[str] = None, ) -> None: self._mailbox: Mailbox = mailbox self._monitor = monitor self._receiver = receiver + self._endpoint = endpoint + + def _tag_supervision_error(self, error: Exception) -> None: + """Tag supervision error with endpoint name if available.""" + if self._endpoint is not None and isinstance(error, SupervisionError): + error.endpoint = self._endpoint async def _recv(self) -> R: awaitable = self._receiver.recv_task() if self._monitor is None: result = await awaitable else: - # type: ignore - result, i = await PythonTask.select_one([self._monitor.task(), awaitable]) + try: + result, i = await PythonTask.select_one( + # type: ignore + [self._monitor.task(), awaitable] + ) + except Exception as e: + self._tag_supervision_error(e) + raise e if i == 0: + self._tag_supervision_error(result) raise result return self._process(result) @@ -905,7 +934,9 @@ def _process(self, msg: PythonMessage) -> R: case PythonMessageKind.Result(): return payload case PythonMessageKind.Exception(): - raise cast(Exception, payload) + e = cast(Exception, payload) + self._tag_supervision_error(e) + raise e case _: raise ValueError(f"Unexpected message kind: {msg.kind}") @@ -913,10 +944,25 @@ def recv(self) -> "Future[R]": return Future(coro=self._recv()) def ranked(self) -> "RankedPortReceiver[R]": - return RankedPortReceiver[R](self._mailbox, self._receiver, self._monitor) + return RankedPortReceiver[R]( + self._mailbox, self._receiver, self._monitor, self._endpoint + ) + + def _attach_supervision( + self, monitor: "Optional[Shared[Exception]]", endpoint: str + ) -> None: + """ + Attach supervision monitoring to this port receiver. + + Enables the receiver to detect and report errors on any supervision events. - def _set_monitor(self, monitor: "Optional[Shared[Exception]]") -> None: + Args: + monitor: Shared exception monitor that signals supervision errors + from the actor mesh. None if supervision is not enabled. + endpoint: Full endpoint name + """ self._monitor = monitor + self._endpoint = endpoint class RankedPortReceiver(PortReceiver[Tuple[int, R]]): @@ -973,8 +1019,13 @@ async def handle( ) -> None: method_name = None MESSAGES_HANDLED.add(1) + + # Initialize method_name before try block so it's always defined + method_name = "unknown" + # response_port can be None. If so, then sending to port will drop the response, # and raise any exceptions to the caller. + try: _set_context(ctx) @@ -984,6 +1035,7 @@ async def handle( match method: case MethodSpecifier.Init(): + method_name = "__init__" ins = ctx.actor_instance ( Class, @@ -1000,7 +1052,7 @@ async def handle( self._maybe_exit_debugger() except Exception as e: self._saved_error = ActorError( - e, f"Remote actor {Class}.__init__ call failed." + e, f"Actor call {ins.name}.{method_name} failed." ) raise response_port.send(None) @@ -1024,7 +1076,7 @@ async def handle( # message delivery mechanism, or the framework accidentally # mixed the usage of cast and direct send. - error_message = f"Actor object is missing when executing method {method_name} on actor {ctx.actor_instance.actor_id}." + error_message = f'Actor object is missing when executing method "{method_name}" on actor {ctx.actor_instance.actor_id}.' if self._saved_error is not None: error_message += ( f" This is likely due to an earlier error: {self._saved_error}" @@ -1064,14 +1116,24 @@ async def handle( except Exception as e: log_endpoint_exception(e, method_name) self._post_mortem_debug(e.__traceback__) - response_port.exception(ActorError(e)) + response_port.exception( + ActorError( + e, + f"Actor call {ctx.actor_instance.name}.{method_name} failed.", + ) + ) except BaseException as e: self._post_mortem_debug(e.__traceback__) # A BaseException can be thrown in the case of a Rust panic. # In this case, we need a way to signal the panic to the Rust side. # See [Panics in async endpoints] try: - panic_flag.signal_panic(e) + panic_flag.signal_panic( + ActorError( + e, + f"Actor call {ctx.actor_instance.name}.{method_name} failed with BaseException.", + ) + ) except Exception: # The channel might be closed if the Rust side has already detected the error pass @@ -1245,11 +1307,15 @@ class ActorMesh(MeshTrait, Generic[T]): def __init__( self, Class: Type[T], + name: str, inner: "ActorMeshProtocol", shape: Shape, proc_mesh: "Optional[ProcMesh]", ) -> None: + # Class name of the actor. self.__name__: str = Class.__name__ + # The name user gives when spawning the mesh + self._mesh_name = name self._class: Type[T] = Class self._inner: "ActorMeshProtocol" = inner self._shape = shape @@ -1318,6 +1384,7 @@ def _endpoint( ) -> Any: return ActorEndpoint( self._inner, + self._mesh_name, self._shape, self._proc_mesh, name, @@ -1340,7 +1407,7 @@ def _create( *args: Any, **kwargs: Any, ) -> "ActorMesh[T]": - mesh = cls(Class, actor_mesh, shape, proc_mesh) + mesh = cls(Class, name, actor_mesh, shape, proc_mesh) # We don't start the supervision polling loop until the first call to # supervision_event, which needs an Instance. Initialize here so events @@ -1383,12 +1450,18 @@ def from_actor_id( Class: Type[T], actor_id: ActorId, ) -> "ActorMesh[T]": - return cls(Class, _SingletonActorAdapator(actor_id), singleton_shape, None) + return cls(Class, "", _SingletonActorAdapator(actor_id), singleton_shape, None) def __reduce_ex__( self, protocol: Any ) -> "Tuple[Type[ActorMesh[T]], Tuple[Any, ...]]": - return ActorMesh, (self._class, self._inner, self._shape, self._proc_mesh) + return ActorMesh, ( + self._class, + self._mesh_name, + self._inner, + self._shape, + self._proc_mesh, + ) @property def _ndslice(self) -> NDSlice: @@ -1400,7 +1473,7 @@ def _labels(self) -> Iterable[str]: def _new_with_shape(self, shape: Shape) -> "ActorMesh[T]": sliced = self._inner.new_with_region(shape.region) - return ActorMesh(self._class, sliced, shape, self._proc_mesh) + return ActorMesh(self._class, self._mesh_name, sliced, shape, self._proc_mesh) def __repr__(self) -> str: return f"ActorMesh(class={self._class}, shape={self._shape}), inner={type(self._inner)})" @@ -1423,7 +1496,7 @@ class ActorError(Exception): def __init__( self, - exception: Exception, + exception: BaseException, message: str = "A remote actor call has failed.", ) -> None: self.exception = exception diff --git a/python/tests/test_actor_error.py b/python/tests/test_actor_error.py index a1c9bdfa5..53d105123 100644 --- a/python/tests/test_actor_error.py +++ b/python/tests/test_actor_error.py @@ -725,18 +725,26 @@ async def test_supervision_with_sending_error() -> None: # The host mesh agent sends or the proc mesh agent sends might break. # Either case is an error that tells us that the send failed. - error_msg = ( - ".*Actor .* (is unhealthy with reason|exited because of the following reason)|" + error_msg_regx = ( + "Actor .* (is unhealthy with reason|exited because of the following reason)|" "actor mesh is stopped due to proc mesh shutdown" ) # send a large payload to trigger send timeout error + error_msg = ( + r"Endpoint call healthy\.check_with_payload\(\) failed, " + error_msg_regx + ) with pytest.raises(SupervisionError, match=error_msg): await actor_mesh.check_with_payload.call(payload="a" * 55000000) # new call should fail with check of health state of actor mesh + error_msg = r"Endpoint call healthy\.check\(\) failed, " + error_msg_regx with pytest.raises(SupervisionError, match=error_msg): await actor_mesh.check.call() + + error_msg = ( + r"Endpoint call healthy\.check_with_payload\(\) failed, " + error_msg_regx + ) with pytest.raises(SupervisionError, match=error_msg): await actor_mesh.check_with_payload.call(payload="a")