Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 41 additions & 6 deletions monarch_hyperactor/src/supervision.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,52 @@ use hyperactor::Bind;
use hyperactor::Named;
use hyperactor::Unbind;
use hyperactor::supervision::ActorSupervisionEvent;
use pyo3::create_exception;
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use serde::Deserialize;
use serde::Serialize;

create_exception!(
monarch._rust_bindings.monarch_hyperactor.supervision,
SupervisionError,
PyRuntimeError
);
/// Supervision failure exposed to Python as the `SupervisionError` class.
///
/// The class is registered under
/// `monarch._rust_bindings.monarch_hyperactor.supervision` and extends
/// Python's `RuntimeError` (via `PyRuntimeError`).
#[pyclass(
name = "SupervisionError",
module = "monarch._rust_bindings.monarch_hyperactor.supervision",
extends = PyRuntimeError
)]
#[derive(Clone, Debug)]
pub struct SupervisionError {
/// Name of the endpoint whose call failed, if known. `#[pyo3(set)]` makes
/// this attribute assignable from Python; no getter is declared here.
#[pyo3(set)]
pub endpoint: Option<String>,
/// Description of the failure.
pub message: String,
}

#[pymethods]
impl SupervisionError {
#[new]
#[pyo3(signature = (message, endpoint=None))]
fn new(message: String, endpoint: Option<String>) -> Self {
SupervisionError { endpoint, message }
}

#[staticmethod]
pub fn new_err(message: String) -> PyErr {
PyRuntimeError::new_err(message)
}

fn __str__(&self) -> String {
if let Some(ep) = &self.endpoint {
format!("Endpoint call {} failed, {}", ep, self.message)
} else {
self.message.clone()
}
}

fn __repr__(&self) -> String {
if let Some(ep) = &self.endpoint {
format!("SupervisionError(endpoint='{}', '{}')", ep, self.message)
} else {
format!("SupervisionError('{}')", self.message)
}
}
}

#[derive(Clone, Debug, Serialize, Deserialize, Named, PartialEq, Bind, Unbind)]
pub struct SupervisionFailureMessage {
Expand Down
6 changes: 3 additions & 3 deletions monarch_hyperactor/src/v1/actor_mesh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -711,12 +711,12 @@ impl ActorMeshProtocol for PythonActorMeshImpl {
.unwrap_or_else(|e| e.into_inner())
{
Unhealthy::StreamClosed => {
return Err(SupervisionError::new_err(
return Err(PyErr::new::<SupervisionError, _>(
"actor mesh is stopped due to proc mesh shutdown".to_string(),
));
}
Unhealthy::Crashed(event) => {
return Err(SupervisionError::new_err(format!(
return Err(PyErr::new::<SupervisionError, _>(format!(
"Actor {} is unhealthy with reason: {}",
event.actor_id, event.actor_status
)));
Expand All @@ -730,7 +730,7 @@ impl ActorMeshProtocol for PythonActorMeshImpl {
.get(&rank)
.map(|entry| entry.value().clone())
}) {
return Err(SupervisionError::new_err(format!(
return Err(PyErr::new::<SupervisionError, _>(format!(
"Actor {} is unhealthy with reason: {}",
event.actor_id, event.actor_status
)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@

# pyre-unsafe

from typing import final
from typing import final, Optional

@final
class SupervisionError(RuntimeError):
"""
Custom exception for supervision-related errors in monarch_hyperactor.
"""

...
endpoint: str | None # Settable attribute

# TODO: Make this an exception subclass
@final
Expand Down
103 changes: 88 additions & 15 deletions python/monarch/_src/actor/actor_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
Region,
Shape,
)
from monarch._rust_bindings.monarch_hyperactor.supervision import SupervisionError
from monarch._rust_bindings.monarch_hyperactor.v1.logging import log_endpoint_exception
from monarch._rust_bindings.monarch_hyperactor.value_mesh import (
ValueMesh as HyValueMesh,
Expand Down Expand Up @@ -488,6 +489,7 @@ class ActorEndpoint(Endpoint[P, R]):
def __init__(
self,
actor_mesh: "ActorMeshProtocol",
mesh_name: str,
shape: Shape,
proc_mesh: "Optional[ProcMesh]",
name: MethodSpecifier,
Expand All @@ -497,6 +499,7 @@ def __init__(
) -> None:
super().__init__(propagator)
self._actor_mesh = actor_mesh
self._mesh_name = mesh_name
self._name = name
self._shape = shape
self._proc_mesh = proc_mesh
Expand Down Expand Up @@ -541,13 +544,25 @@ def _send(
shape = self._shape
return Extent(shape.labels, shape.ndslice.sizes)

def _full_name(self) -> str:
method_name = "unknown"
match self._name:
case MethodSpecifier.Init():
method_name = "__init__"
case MethodSpecifier.ReturnsResponse(name=method_name):
pass
case MethodSpecifier.ExplicitPort(name=method_name):
pass
return f"{self._mesh_name}.{method_name}()"

def _port(self, once: bool = False) -> "Tuple[Port[R], PortReceiver[R]]":
p, r = super()._port(once=once)
instance = context().actor_instance._as_rust()
monitor: Optional[Shared[Exception]] = self._actor_mesh.supervision_event(
instance
)
r._set_monitor(monitor)

r._attach_supervision(monitor, self._full_name())
return (p, r)

def _rref(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> R:
Expand Down Expand Up @@ -882,19 +897,33 @@ def __init__(
mailbox: Mailbox,
receiver: "PortReceiverBase",
monitor: "Optional[Shared[Exception]]" = None,
endpoint: Optional[str] = None,
) -> None:
self._mailbox: Mailbox = mailbox
self._monitor = monitor
self._receiver = receiver
self._endpoint = endpoint

def _tag_supervision_error(self, error: Exception) -> None:
    """Record this receiver's endpoint name on ``error`` when it is a SupervisionError."""
    endpoint = self._endpoint
    if endpoint is None:
        # No endpoint was attached to this receiver; leave the error untouched.
        return
    if isinstance(error, SupervisionError):
        error.endpoint = endpoint

async def _recv(self) -> R:
awaitable = self._receiver.recv_task()
if self._monitor is None:
result = await awaitable
else:
# type: ignore
result, i = await PythonTask.select_one([self._monitor.task(), awaitable])
try:
result, i = await PythonTask.select_one(
# type: ignore
[self._monitor.task(), awaitable]
)
except Exception as e:
self._tag_supervision_error(e)
raise e
if i == 0:
self._tag_supervision_error(result)
raise result
return self._process(result)

Expand All @@ -905,18 +934,35 @@ def _process(self, msg: PythonMessage) -> R:
case PythonMessageKind.Result():
return payload
case PythonMessageKind.Exception():
raise cast(Exception, payload)
e = cast(Exception, payload)
self._tag_supervision_error(e)
raise e
case _:
raise ValueError(f"Unexpected message kind: {msg.kind}")

def recv(self) -> "Future[R]":
return Future(coro=self._recv())

def ranked(self) -> "RankedPortReceiver[R]":
return RankedPortReceiver[R](self._mailbox, self._receiver, self._monitor)
return RankedPortReceiver[R](
self._mailbox, self._receiver, self._monitor, self._endpoint
)

def _attach_supervision(
self, monitor: "Optional[Shared[Exception]]", endpoint: str
) -> None:
"""
Attach supervision monitoring to this port receiver.

Enables the receiver to detect and report errors on any supervision events.

def _set_monitor(self, monitor: "Optional[Shared[Exception]]") -> None:
Args:
monitor: Shared exception monitor that signals supervision errors
from the actor mesh. None if supervision is not enabled.
endpoint: Full endpoint name
"""
self._monitor = monitor
self._endpoint = endpoint


class RankedPortReceiver(PortReceiver[Tuple[int, R]]):
Expand Down Expand Up @@ -973,8 +1019,13 @@ async def handle(
) -> None:
method_name = None
MESSAGES_HANDLED.add(1)

# Initialize method_name before try block so it's always defined
method_name = "unknown"

# response_port can be None. If so, then sending to port will drop the response,
# and raise any exceptions to the caller.

try:
_set_context(ctx)

Expand All @@ -984,6 +1035,7 @@ async def handle(

match method:
case MethodSpecifier.Init():
method_name = "__init__"
ins = ctx.actor_instance
(
Class,
Expand All @@ -1000,7 +1052,7 @@ async def handle(
self._maybe_exit_debugger()
except Exception as e:
self._saved_error = ActorError(
e, f"Remote actor {Class}.__init__ call failed."
e, f"Actor call {ins.name}.{method_name} failed."
)
raise
response_port.send(None)
Expand All @@ -1024,7 +1076,7 @@ async def handle(
# message delivery mechanism, or the framework accidentally
# mixed the usage of cast and direct send.

error_message = f"Actor object is missing when executing method {method_name} on actor {ctx.actor_instance.actor_id}."
error_message = f'Actor object is missing when executing method "{method_name}" on actor {ctx.actor_instance.actor_id}.'
if self._saved_error is not None:
error_message += (
f" This is likely due to an earlier error: {self._saved_error}"
Expand Down Expand Up @@ -1064,14 +1116,24 @@ async def handle(
except Exception as e:
log_endpoint_exception(e, method_name)
self._post_mortem_debug(e.__traceback__)
response_port.exception(ActorError(e))
response_port.exception(
ActorError(
e,
f"Actor call {ctx.actor_instance.name}.{method_name} failed.",
)
)
except BaseException as e:
self._post_mortem_debug(e.__traceback__)
# A BaseException can be thrown in the case of a Rust panic.
# In this case, we need a way to signal the panic to the Rust side.
# See [Panics in async endpoints]
try:
panic_flag.signal_panic(e)
panic_flag.signal_panic(
ActorError(
e,
f"Actor call {ctx.actor_instance.name}.{method_name} failed with BaseException.",
)
)
except Exception:
# The channel might be closed if the Rust side has already detected the error
pass
Expand Down Expand Up @@ -1245,11 +1307,15 @@ class ActorMesh(MeshTrait, Generic[T]):
def __init__(
self,
Class: Type[T],
name: str,
inner: "ActorMeshProtocol",
shape: Shape,
proc_mesh: "Optional[ProcMesh]",
) -> None:
# Class name of the actor.
self.__name__: str = Class.__name__
# The name user gives when spawning the mesh
self._mesh_name = name
self._class: Type[T] = Class
self._inner: "ActorMeshProtocol" = inner
self._shape = shape
Expand Down Expand Up @@ -1318,6 +1384,7 @@ def _endpoint(
) -> Any:
return ActorEndpoint(
self._inner,
self._mesh_name,
self._shape,
self._proc_mesh,
name,
Expand All @@ -1340,7 +1407,7 @@ def _create(
*args: Any,
**kwargs: Any,
) -> "ActorMesh[T]":
mesh = cls(Class, actor_mesh, shape, proc_mesh)
mesh = cls(Class, name, actor_mesh, shape, proc_mesh)

# We don't start the supervision polling loop until the first call to
# supervision_event, which needs an Instance. Initialize here so events
Expand Down Expand Up @@ -1383,12 +1450,18 @@ def from_actor_id(
Class: Type[T],
actor_id: ActorId,
) -> "ActorMesh[T]":
return cls(Class, _SingletonActorAdapator(actor_id), singleton_shape, None)
return cls(Class, "", _SingletonActorAdapator(actor_id), singleton_shape, None)

def __reduce_ex__(
self, protocol: Any
) -> "Tuple[Type[ActorMesh[T]], Tuple[Any, ...]]":
return ActorMesh, (self._class, self._inner, self._shape, self._proc_mesh)
return ActorMesh, (
self._class,
self._mesh_name,
self._inner,
self._shape,
self._proc_mesh,
)

@property
def _ndslice(self) -> NDSlice:
Expand All @@ -1400,7 +1473,7 @@ def _labels(self) -> Iterable[str]:

def _new_with_shape(self, shape: Shape) -> "ActorMesh[T]":
sliced = self._inner.new_with_region(shape.region)
return ActorMesh(self._class, sliced, shape, self._proc_mesh)
return ActorMesh(self._class, self._mesh_name, sliced, shape, self._proc_mesh)

def __repr__(self) -> str:
return f"ActorMesh(class={self._class}, shape={self._shape}), inner={type(self._inner)})"
Expand All @@ -1423,7 +1496,7 @@ class ActorError(Exception):

def __init__(
self,
exception: Exception,
exception: BaseException,
message: str = "A remote actor call has failed.",
) -> None:
self.exception = exception
Expand Down
12 changes: 10 additions & 2 deletions python/tests/test_actor_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,18 +725,26 @@ async def test_supervision_with_sending_error() -> None:

# The host mesh agent sends or the proc mesh agent sends might break.
# Either case is an error that tells us that the send failed.
error_msg = (
".*Actor .* (is unhealthy with reason|exited because of the following reason)|"
error_msg_regx = (
"Actor .* (is unhealthy with reason|exited because of the following reason)|"
"actor mesh is stopped due to proc mesh shutdown"
)

# send a large payload to trigger send timeout error
error_msg = (
r"Endpoint call healthy\.check_with_payload\(\) failed, " + error_msg_regx
)
with pytest.raises(SupervisionError, match=error_msg):
await actor_mesh.check_with_payload.call(payload="a" * 55000000)

# new call should fail with check of health state of actor mesh
error_msg = r"Endpoint call healthy\.check\(\) failed, " + error_msg_regx
with pytest.raises(SupervisionError, match=error_msg):
await actor_mesh.check.call()

error_msg = (
r"Endpoint call healthy\.check_with_payload\(\) failed, " + error_msg_regx
)
with pytest.raises(SupervisionError, match=error_msg):
await actor_mesh.check_with_payload.call(payload="a")

Expand Down
Loading