diff --git a/monarch_hyperactor/src/actor.rs b/monarch_hyperactor/src/actor.rs index ff7b36fbb..769f50e3e 100644 --- a/monarch_hyperactor/src/actor.rs +++ b/monarch_hyperactor/src/actor.rs @@ -172,7 +172,7 @@ pub struct PythonMessage { method: String, message: ByteBuf, response_port: Option, - rank_in_response: bool, + rank: Option, } impl std::fmt::Debug for PythonMessage { @@ -209,18 +209,18 @@ impl Bind for PythonMessage { #[pymethods] impl PythonMessage { #[new] - #[pyo3(signature = (method, message, response_port = None, rank_in_response = false))] + #[pyo3(signature = (method, message, response_port, rank))] fn new( method: String, message: Vec, response_port: Option, - rank_in_response: bool, + rank: Option, ) -> Self { Self { method, message: ByteBuf::from(message), response_port: response_port.map(Into::into), - rank_in_response, + rank, } } @@ -240,8 +240,8 @@ impl PythonMessage { } #[getter] - fn rank_in_response(&self) -> bool { - self.rank_in_response + fn rank(&self) -> Option { + return self.rank; } } @@ -521,7 +521,7 @@ mod tests { method: "test".to_string(), message: ByteBuf::from(vec![1, 2, 3]), response_port: Some(id!(world[0].client[0][123])), - rank_in_response: false, + rank: None, }; { let unbound = message.clone().unbind().unwrap(); diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi index a80a79618..376dd0425 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi @@ -8,7 +8,7 @@ import abc -from typing import final, List, Protocol +from typing import final, List, Optional, Protocol from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox, PortId from monarch._rust_bindings.monarch_hyperactor.proc import ActorId, Proc, Serialized @@ -103,8 +103,8 @@ class PythonMessage: self, method: str, message: bytes, - response_port: PortId | None = None, - rank_in_response: bool = False, + response_port: Optional[PortId], + rank: int | None, ) -> None: ... @property def method(self) -> str: @@ -122,8 +122,8 @@ class PythonMessage: ... @property - def rank_in_response(self) -> bool: - """Whether or not to include the rank of the handling actor in the response.""" + def rank(self) -> Optional[int]: + """If this message is a response, the rank of the actor in the original broadcast that send the request.""" ... @final diff --git a/python/monarch/actor_mesh.py b/python/monarch/actor_mesh.py index 278e81d9a..445ab68d6 100644 --- a/python/monarch/actor_mesh.py +++ b/python/monarch/actor_mesh.py @@ -276,11 +276,11 @@ def call_one(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]: return self.choose(*args, **kwargs) def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[ValueMesh[R]]": - p: PortId - r: PortReceiver[R] - p, r = port(self) + p: Port[R] + r: RankedPortReceiver[R] + p, r = ranked_port(self) # pyre-ignore - send(self, args, kwargs, port=p, rank_in_response=True) + send(self, args, kwargs, port=p) async def process() -> ValueMesh[R]: results: List[R] = [None] * len(self._actor_mesh) # pyre-fixme[9] @@ -375,9 +375,8 @@ def send( endpoint: Endpoint[P, R], args: Tuple[Any, ...], kwargs: Dict[str, Any], - port: "Optional[PortId]" = None, + port: "Optional[Port]" = None, selection: Selection = "all", - rank_in_response: bool = False, ) -> None: """ Fire-and-forget broadcast invocation of the endpoint across all actors in the mesh. @@ -386,7 +385,10 @@ def send( """ endpoint._signature.bind(None, *args, **kwargs) message = PythonMessage( - endpoint._name, _pickle((args, kwargs)), port, rank_in_response + endpoint._name, + _pickle((args, kwargs)), + None if port is None else port._port, + None, ) endpoint._actor_mesh.cast(message, selection) @@ -408,18 +410,16 @@ def endpoint( return EndpointProperty(method) -class Port: - def __init__(self, port: PortId, mailbox: Mailbox, rank_in_response: bool) -> None: +class Port(Generic[R]): + def __init__(self, port: PortId, mailbox: Mailbox, rank: Optional[int]) -> None: self._port = port self._mailbox = mailbox - self._rank_in_response = rank_in_response + self._rank = rank - def send(self, method: str, obj: object) -> None: - if self._rank_in_response: - obj = (MonarchContext.get().point.rank, obj) + def send(self, method: str, obj: R) -> None: self._mailbox.post( self._port, - PythonMessage(method, _pickle(obj), None), + PythonMessage(method, _pickle(obj), None, self._rank), ) @@ -428,12 +428,21 @@ def send(self, method: str, obj: object) -> None: # and handles concerns is different. def port( endpoint: Endpoint[P, R], once: bool = False -) -> Tuple["PortId", "PortReceiver[R]"]: +) -> Tuple["Port[R]", "PortReceiver[R]"]: handle, receiver = ( endpoint._mailbox.open_once_port() if once else endpoint._mailbox.open_port() ) port_id: PortId = handle.bind() - return port_id, PortReceiver(endpoint._mailbox, receiver) + return Port(port_id, endpoint._mailbox, rank=None), PortReceiver( + endpoint._mailbox, receiver + ) + + +def ranked_port( + endpoint: Endpoint[P, R], once: bool = False +) -> Tuple["Port[R]", "RankedPortReceiver[R]"]: + p, receiver = port(endpoint, once) + return p, RankedPortReceiver[R](receiver._mailbox, receiver._receiver) class PortReceiver(Generic[R]): @@ -458,18 +467,20 @@ def _process(self, msg: PythonMessage) -> R: return payload else: assert msg.method == "exception" - if isinstance(payload, tuple): - # If the payload is a tuple, it's because we requested the rank - # to be included in the response; just ignore it. - raise payload[1] - else: - # pyre-ignore - raise payload + # pyre-ignore + raise payload def recv(self) -> "Future[R]": return Future(lambda: self._recv(), self._blocking_recv) +class RankedPortReceiver(PortReceiver[Tuple[int, R]]): + def _process(self, msg: PythonMessage) -> Tuple[int, R]: + if msg.rank is None: + raise ValueError("RankedPort receiver got a message without a rank") + return msg.rank, super()._process(msg) + + singleton_shape = Shape([], NDSlice(offset=0, sizes=[], strides=[])) @@ -493,7 +504,7 @@ def handle_cast( panic_flag: PanicFlag, ) -> Optional[Coroutine[Any, Any, Any]]: port = ( - Port(message.response_port, mailbox, message.rank_in_response) + Port(message.response_port, mailbox, rank) if message.response_port else None )