Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/monarch/_src/actor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
# pyre-strict

"""
Monarch Actor API
Expand Down
62 changes: 33 additions & 29 deletions python/monarch/_src/actor/actor_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
# pyre-strict

import abc
import collections
Expand Down Expand Up @@ -89,6 +89,8 @@
from monarch._src.actor.sync_state import fake_sync_state
from monarch._src.actor.telemetry import METER
from monarch._src.actor.tensor_engine_shim import actor_rref, actor_send
from opentelemetry.metrics import Counter
from opentelemetry.trace import Tracer
from typing_extensions import Self

if TYPE_CHECKING:
Expand All @@ -99,11 +101,9 @@
from monarch._src.actor.proc_mesh import _ControllerController, ProcMesh
from monarch._src.actor.telemetry import get_monarch_tracer

CallMethod = PythonMessageKind.CallMethod

logger: logging.Logger = logging.getLogger(__name__)

TRACER = get_monarch_tracer()
TRACER: Tracer = get_monarch_tracer()

Allocator = ProcessAllocator | LocalAllocator

Expand Down Expand Up @@ -164,9 +164,9 @@ def proc(self) -> "ProcMesh":

# this property is used to hold the handles to actors and processes launched by this actor
# in order to keep them alive until this actor exits.
_children: "Optional[List[ActorMesh | ProcMesh]]"
_children: "Optional[List[ActorMesh[Any] | ProcMesh]]"

def _add_child(self, child: "ActorMesh | ProcMesh") -> None:
def _add_child(self, child: "ActorMesh[Any] | ProcMesh") -> None:
if self._children is None:
self._children = [child]
else:
Expand Down Expand Up @@ -377,7 +377,7 @@ def stop(self, instance: HyInstance) -> "PythonTask[None]":
raise NotImplementedError("stop()")

def initialized(self) -> "PythonTask[None]":
async def empty():
async def empty() -> None:
pass

return PythonTask.from_coroutine(empty())
Expand All @@ -402,10 +402,10 @@ def __init__(
self._signature: inspect.Signature = inspect.signature(impl)
self._explicit_response_port = explicit_response_port

def _call_name(self) -> Any:
def _call_name(self) -> MethodSpecifier:
return self._name

def _check_arguments(self, args, kwargs):
def _check_arguments(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
if self._explicit_response_port:
self._signature.bind(None, None, *args, **kwargs)
else:
Expand All @@ -415,7 +415,7 @@ def _send(
self,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
port: "Optional[Port]" = None,
port: "Optional[Port[R]]" = None,
selection: Selection = "all",
) -> Extent:
"""
Expand Down Expand Up @@ -449,7 +449,7 @@ def _port(self, once: bool = False) -> "Tuple[Port[R], PortReceiver[R]]":
r._set_monitor(monitor)
return (p, r)

def _rref(self, args, kwargs):
def _rref(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> R:
self._check_arguments(args, kwargs)
refs, buffer = flatten((args, kwargs), _is_ref_or_mailbox)

Expand Down Expand Up @@ -479,7 +479,7 @@ def as_endpoint(
*,
propagate: Propagator = None,
explicit_response_port: bool = False,
):
) -> Any:
if not isinstance(not_an_endpoint, NotAnEndpoint):
raise ValueError("expected an method of a spawned actor")
kind = (
Expand Down Expand Up @@ -557,7 +557,7 @@ def _new_with_shape(self, shape: Shape) -> "ValueMesh[R]":
remapped = [self._hy.get(pos[g]) for g in shape.ranks()]
return ValueMesh(shape, remapped)

def item(self, **kwargs) -> R:
def item(self, **kwargs: int) -> R:
"""
Get the value at the given coordinates.

Expand Down Expand Up @@ -621,10 +621,10 @@ def _ndslice(self) -> NDSlice:
def _labels(self) -> Iterable[str]:
return self._shape.labels

def __getstate__(self):
def __getstate__(self) -> Dict[str, Any]:
return {"shape": self._shape, "values": self._hy.values()}

def __setstate__(self, state):
def __setstate__(self, state: Dict[str, Any]) -> None:
self._shape = state["shape"]
vals = state["values"]
self._hy = HyValueMesh(self._shape, vals)
Expand All @@ -634,7 +634,7 @@ def send(
endpoint: Endpoint[P, R],
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
port: "Optional[Port]" = None,
port: "Optional[Port[R]]" = None,
selection: Selection = "all",
) -> None:
"""
Expand Down Expand Up @@ -690,15 +690,17 @@ def exception(self, obj: Exception) -> None:
PythonMessage(PythonMessageKind.Exception(self._rank), _pickle(obj)),
)

def __reduce__(self):
def __reduce__(self) -> Tuple[Any, Tuple[Any, ...]]:
"""
When Port is sent over the wire, we do not want to send the actor instance
from the current context. Instead, we want to reconstruct the Port with
the receiver's context, since that is where the message will be sent
from through this port.
"""

def _reconstruct_port(port_ref, rank):
def _reconstruct_port(
port_ref: PortRef | OncePortRef, rank: Optional[int]
) -> "Port[R]":
instance = context().actor_instance._as_rust()
return Port(port_ref, instance, rank)

Expand All @@ -714,7 +716,7 @@ class DroppingPort:
Makes sure any exception sent to it causes the actor to report an exception.
"""

def __init__(self):
def __init__(self) -> None:
pass

def send(self, obj: Any) -> None:
Expand Down Expand Up @@ -810,7 +812,7 @@ def recv(self) -> "Future[R]":
def ranked(self) -> "RankedPortReceiver[R]":
return RankedPortReceiver[R](self._mailbox, self._receiver, self._monitor)

def _set_monitor(self, monitor: "Optional[Shared[Exception]]"):
def _set_monitor(self, monitor: "Optional[Shared[Exception]]") -> None:
self._monitor = monitor


Expand All @@ -834,7 +836,7 @@ def _process(self, msg: PythonMessage) -> Tuple[int, R]:
# we need to signal to the consumer of the PythonTask object that the thread really isn't in an async context.
# We do this by blanking out the running event loop during the call to the synchronous actor function.

MESSAGES_HANDLED = METER.create_counter("py_mesages_handled")
MESSAGES_HANDLED: Counter = METER.create_counter("py_mesages_handled")


class _Actor:
Expand Down Expand Up @@ -968,15 +970,15 @@ async def handle(
pass
raise

def _maybe_exit_debugger(self, do_continue=True) -> None:
def _maybe_exit_debugger(self, do_continue: bool = True) -> None:
if (pdb_wrapper := DebugContext.get().pdb_wrapper) is not None:
if do_continue:
pdb_wrapper.clear_all_breaks()
pdb_wrapper.do_continue("")
pdb_wrapper.end_debug_session()
DebugContext.set(DebugContext())

def _post_mortem_debug(self, exc_tb) -> None:
def _post_mortem_debug(self, exc_tb: Any) -> None:
from monarch._src.actor.debugger.debug_controller import debug_controller

if (pdb_wrapper := DebugContext.get().pdb_wrapper) is not None:
Expand Down Expand Up @@ -1005,7 +1007,7 @@ def _handle_undeliverable_message(
else:
return False

def __supervise__(self, cx: Context, *args, **kwargs) -> object:
def __supervise__(self, cx: Context, *args: Any, **kwargs: Any) -> object:
_context.set(cx)
instance = self.instance
if instance is None:
Expand Down Expand Up @@ -1083,7 +1085,7 @@ def _new_with_shape(self, shape: Shape) -> Self:
)

@property
def initialized(self):
def initialized(self) -> Any:
raise NotImplementedError(
"actor implementations are not meshes, but we can't convince the typechecker of it..."
)
Expand Down Expand Up @@ -1164,9 +1166,9 @@ def _endpoint(
self,
name: MethodSpecifier,
impl: Callable[Concatenate[Any, P], Awaitable[R]],
propagator: Any,
propagator: Propagator,
explicit_response_port: bool,
):
) -> Any:
return ActorEndpoint(
self._inner,
self._shape,
Expand Down Expand Up @@ -1215,7 +1217,9 @@ def from_actor_id(
) -> "ActorMesh[T]":
return cls(Class, _SingletonActorAdapator(actor_id), singleton_shape, None)

def __reduce_ex__(self, protocol: ...) -> "Tuple[Type[ActorMesh], Tuple[Any, ...]]":
def __reduce_ex__(
self, protocol: Any
) -> "Tuple[Type[ActorMesh[T]], Tuple[Any, ...]]":
return ActorMesh, (self._class, self._inner, self._shape, self._proc_mesh)

@property
Expand Down Expand Up @@ -1269,7 +1273,7 @@ def __init__(
)
for s in actor_mesh_ref_tb
)
self.exception_formatted = "".join(actor_mesh_ref_tb)
self.exception_formatted: str = "".join(actor_mesh_ref_tb)
self.message = message

def __str__(self) -> str:
Expand Down
9 changes: 5 additions & 4 deletions python/monarch/_src/actor/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
# pyre-strict


from pathlib import Path
Expand All @@ -28,11 +28,12 @@
CA = Union[bytes, Path, Literal["trust_all_connections"]]


def _as_python_task(s: str | Future[str]) -> PythonTask:
def _as_python_task(s: str | Future[str]) -> "PythonTask[str]":
if isinstance(s, str):
s_str: str = s

async def just():
return s
async def just() -> str:
return s_str

return PythonTask.from_coroutine(just())
else:
Expand Down
7 changes: 4 additions & 3 deletions python/monarch/_src/actor/bootstrap_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
# pyre-strict

"""
This is the main function for the boostrapping a new process using a ProcessAllocator.
Expand All @@ -29,13 +29,14 @@
import monarch._rust_bindings # @manual # noqa: F401


async def main():
async def main() -> None:
from monarch._rust_bindings.monarch_hyperactor.bootstrap import bootstrap_main

# pyre-ignore[12]: bootstrap_main is async but imported from Rust bindings
await bootstrap_main()


def invoke_main():
def invoke_main() -> None:
# if this is invoked with the stdout piped somewhere, then print
# changes its buffering behavior. So we default to the standard
# behavior of std out as if it were a terminal.
Expand Down
2 changes: 1 addition & 1 deletion python/monarch/_src/actor/code_sync/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
# pyre-strict

from monarch._rust_bindings.monarch_extension.code_sync import ( # noqa: F401
CodeSyncMeshClient,
Expand Down
Loading