Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions monarch_hyperactor/src/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ pub struct PythonMessage {
method: String,
message: ByteBuf,
response_port: Option<PortId>,
rank_in_response: bool,
rank: Option<usize>,
}

impl std::fmt::Debug for PythonMessage {
Expand Down Expand Up @@ -209,18 +209,18 @@ impl Bind for PythonMessage {
#[pymethods]
impl PythonMessage {
#[new]
#[pyo3(signature = (method, message, response_port = None, rank_in_response = false))]
#[pyo3(signature = (method, message, response_port, rank))]
fn new(
method: String,
message: Vec<u8>,
response_port: Option<crate::mailbox::PyPortId>,
rank_in_response: bool,
rank: Option<usize>,
) -> Self {
Self {
method,
message: ByteBuf::from(message),
response_port: response_port.map(Into::into),
rank_in_response,
rank,
}
}

Expand All @@ -240,8 +240,8 @@ impl PythonMessage {
}

#[getter]
fn rank_in_response(&self) -> bool {
self.rank_in_response
fn rank(&self) -> Option<usize> {
return self.rank;
}
}

Expand Down Expand Up @@ -521,7 +521,7 @@ mod tests {
method: "test".to_string(),
message: ByteBuf::from(vec![1, 2, 3]),
response_port: Some(id!(world[0].client[0][123])),
rank_in_response: false,
rank: None,
};
{
let unbound = message.clone().unbind().unwrap();
Expand Down
10 changes: 5 additions & 5 deletions python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import abc

from typing import final, List, Protocol
from typing import final, List, Optional, Protocol

from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox, PortId
from monarch._rust_bindings.monarch_hyperactor.proc import ActorId, Proc, Serialized
Expand Down Expand Up @@ -103,8 +103,8 @@ class PythonMessage:
self,
method: str,
message: bytes,
response_port: PortId | None = None,
rank_in_response: bool = False,
response_port: Optional[PortId],
rank: int | None,
) -> None: ...
@property
def method(self) -> str:
Expand All @@ -122,8 +122,8 @@ class PythonMessage:
...

@property
def rank_in_response(self) -> bool:
"""Whether or not to include the rank of the handling actor in the response."""
def rank(self) -> Optional[int]:
"""If this message is a response, the rank of the actor in the original broadcast that send the request."""
...

@final
Expand Down
59 changes: 35 additions & 24 deletions python/monarch/actor_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,11 +276,11 @@ def call_one(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
return self.choose(*args, **kwargs)

def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[ValueMesh[R]]":
p: PortId
r: PortReceiver[R]
p, r = port(self)
p: Port[R]
r: RankedPortReceiver[R]
p, r = ranked_port(self)
# pyre-ignore
send(self, args, kwargs, port=p, rank_in_response=True)
send(self, args, kwargs, port=p)

async def process() -> ValueMesh[R]:
results: List[R] = [None] * len(self._actor_mesh) # pyre-fixme[9]
Expand Down Expand Up @@ -375,9 +375,8 @@ def send(
endpoint: Endpoint[P, R],
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
port: "Optional[PortId]" = None,
port: "Optional[Port]" = None,
selection: Selection = "all",
rank_in_response: bool = False,
) -> None:
"""
Fire-and-forget broadcast invocation of the endpoint across all actors in the mesh.
Expand All @@ -386,7 +385,10 @@ def send(
"""
endpoint._signature.bind(None, *args, **kwargs)
message = PythonMessage(
endpoint._name, _pickle((args, kwargs)), port, rank_in_response
endpoint._name,
_pickle((args, kwargs)),
None if port is None else port._port,
None,
)
endpoint._actor_mesh.cast(message, selection)

Expand All @@ -408,18 +410,16 @@ def endpoint(
return EndpointProperty(method)


class Port:
def __init__(self, port: PortId, mailbox: Mailbox, rank_in_response: bool) -> None:
class Port(Generic[R]):
def __init__(self, port: PortId, mailbox: Mailbox, rank: Optional[int]) -> None:
self._port = port
self._mailbox = mailbox
self._rank_in_response = rank_in_response
self._rank = rank

def send(self, method: str, obj: object) -> None:
if self._rank_in_response:
obj = (MonarchContext.get().point.rank, obj)
def send(self, method: str, obj: R) -> None:
self._mailbox.post(
self._port,
PythonMessage(method, _pickle(obj), None),
PythonMessage(method, _pickle(obj), None, self._rank),
)


Expand All @@ -428,12 +428,21 @@ def send(self, method: str, obj: object) -> None:
# and handles concerns is different.
def port(
endpoint: Endpoint[P, R], once: bool = False
) -> Tuple["PortId", "PortReceiver[R]"]:
) -> Tuple["Port[R]", "PortReceiver[R]"]:
handle, receiver = (
endpoint._mailbox.open_once_port() if once else endpoint._mailbox.open_port()
)
port_id: PortId = handle.bind()
return port_id, PortReceiver(endpoint._mailbox, receiver)
return Port(port_id, endpoint._mailbox, rank=None), PortReceiver(
endpoint._mailbox, receiver
)


def ranked_port(
endpoint: Endpoint[P, R], once: bool = False
) -> Tuple["Port[R]", "RankedPortReceiver[R]"]:
p, receiver = port(endpoint, once)
return p, RankedPortReceiver[R](receiver._mailbox, receiver._receiver)


class PortReceiver(Generic[R]):
Expand All @@ -458,18 +467,20 @@ def _process(self, msg: PythonMessage) -> R:
return payload
else:
assert msg.method == "exception"
if isinstance(payload, tuple):
# If the payload is a tuple, it's because we requested the rank
# to be included in the response; just ignore it.
raise payload[1]
else:
# pyre-ignore
raise payload
# pyre-ignore
raise payload

def recv(self) -> "Future[R]":
return Future(lambda: self._recv(), self._blocking_recv)


class RankedPortReceiver(PortReceiver[Tuple[int, R]]):
def _process(self, msg: PythonMessage) -> Tuple[int, R]:
if msg.rank is None:
raise ValueError("RankedPort receiver got a message without a rank")
return msg.rank, super()._process(msg)


singleton_shape = Shape([], NDSlice(offset=0, sizes=[], strides=[]))


Expand All @@ -493,7 +504,7 @@ def handle_cast(
panic_flag: PanicFlag,
) -> Optional[Coroutine[Any, Any, Any]]:
port = (
Port(message.response_port, mailbox, message.rank_in_response)
Port(message.response_port, mailbox, rank)
if message.response_port
else None
)
Expand Down