From 095067611e2d2607985d06e2d66869281822e257 Mon Sep 17 00:00:00 2001 From: Robert Rusch Date: Thu, 30 Oct 2025 09:08:09 -0700 Subject: [PATCH] Migrator monarch/python/monarch/_src/actor to pyre-strict Summary: Pyre had a couple of gaps in how we handled endpoints that interfered with fully turning on typing; these are mostly addressed with pyrefly. This is part 1 of a series of diffs to turn on pyre-strict, leaning somewhat generously on `Any` for now, which will then be followed up by enabling pyrefly. Differential Revision: D85356082 --- python/monarch/_src/actor/__init__.py | 2 +- python/monarch/_src/actor/actor_mesh.py | 62 ++++++------ python/monarch/_src/actor/bootstrap.py | 9 +- python/monarch/_src/actor/bootstrap_main.py | 7 +- .../monarch/_src/actor/code_sync/__init__.py | 2 +- .../_src/actor/code_sync/auto_reload.py | 75 +++++++++------ .../monarch/_src/actor/debugger/__init__.py | 2 +- .../monarch/_src/actor/debugger/breakpoint.py | 2 +- .../_src/actor/debugger/debug_command.py | 15 +-- .../_src/actor/debugger/debug_controller.py | 12 +-- .../monarch/_src/actor/debugger/debug_io.py | 8 +- .../_src/actor/debugger/debug_session.py | 39 +++++--- .../_src/actor/debugger/pdb_wrapper.py | 47 +++++----- python/monarch/_src/actor/device_utils.py | 4 +- python/monarch/_src/actor/endpoint.py | 94 ++++++++++++------- python/monarch/_src/actor/event_loop.py | 6 +- python/monarch/_src/actor/future.py | 18 ++-- python/monarch/_src/actor/host_mesh.py | 6 +- python/monarch/_src/actor/pickle.py | 45 +++++---- .../_src/actor/python_extension_methods.py | 4 +- python/monarch/_src/actor/shape.py | 12 +-- python/monarch/_src/actor/source_loader.py | 4 +- python/monarch/_src/actor/sync_state.py | 5 +- .../monarch/_src/actor/tensor_engine_shim.py | 50 ++++++++-- python/monarch/_src/actor/v1/__init__.py | 4 +- 25 files changed, 318 insertions(+), 216 deletions(-) diff --git a/python/monarch/_src/actor/__init__.py b/python/monarch/_src/actor/__init__.py index e785d856e..4de33298f
100644 --- a/python/monarch/_src/actor/__init__.py +++ b/python/monarch/_src/actor/__init__.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +# pyre-strict """ Monarch Actor API diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index 938d19773..3ebe39c13 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +# pyre-strict import abc import collections @@ -89,6 +89,8 @@ from monarch._src.actor.sync_state import fake_sync_state from monarch._src.actor.telemetry import METER from monarch._src.actor.tensor_engine_shim import actor_rref, actor_send +from opentelemetry.metrics import Counter +from opentelemetry.trace import Tracer from typing_extensions import Self if TYPE_CHECKING: @@ -99,11 +101,9 @@ from monarch._src.actor.proc_mesh import _ControllerController, ProcMesh from monarch._src.actor.telemetry import get_monarch_tracer -CallMethod = PythonMessageKind.CallMethod - logger: logging.Logger = logging.getLogger(__name__) -TRACER = get_monarch_tracer() +TRACER: Tracer = get_monarch_tracer() Allocator = ProcessAllocator | LocalAllocator @@ -164,9 +164,9 @@ def proc(self) -> "ProcMesh": # this property is used to hold the handles to actors and processes launched by this actor # in order to keep them alive until this actor exits. 
- _children: "Optional[List[ActorMesh | ProcMesh]]" + _children: "Optional[List[ActorMesh[Any] | ProcMesh]]" - def _add_child(self, child: "ActorMesh | ProcMesh") -> None: + def _add_child(self, child: "ActorMesh[Any] | ProcMesh") -> None: if self._children is None: self._children = [child] else: @@ -377,7 +377,7 @@ def stop(self, instance: HyInstance) -> "PythonTask[None]": raise NotImplementedError("stop()") def initialized(self) -> "PythonTask[None]": - async def empty(): + async def empty() -> None: pass return PythonTask.from_coroutine(empty()) @@ -402,10 +402,10 @@ def __init__( self._signature: inspect.Signature = inspect.signature(impl) self._explicit_response_port = explicit_response_port - def _call_name(self) -> Any: + def _call_name(self) -> MethodSpecifier: return self._name - def _check_arguments(self, args, kwargs): + def _check_arguments(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None: if self._explicit_response_port: self._signature.bind(None, None, *args, **kwargs) else: @@ -415,7 +415,7 @@ def _send( self, args: Tuple[Any, ...], kwargs: Dict[str, Any], - port: "Optional[Port]" = None, + port: "Optional[Port[R]]" = None, selection: Selection = "all", ) -> Extent: """ @@ -449,7 +449,7 @@ def _port(self, once: bool = False) -> "Tuple[Port[R], PortReceiver[R]]": r._set_monitor(monitor) return (p, r) - def _rref(self, args, kwargs): + def _rref(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> R: self._check_arguments(args, kwargs) refs, buffer = flatten((args, kwargs), _is_ref_or_mailbox) @@ -479,7 +479,7 @@ def as_endpoint( *, propagate: Propagator = None, explicit_response_port: bool = False, -): +) -> Any: if not isinstance(not_an_endpoint, NotAnEndpoint): raise ValueError("expected an method of a spawned actor") kind = ( @@ -557,7 +557,7 @@ def _new_with_shape(self, shape: Shape) -> "ValueMesh[R]": remapped = [self._hy.get(pos[g]) for g in shape.ranks()] return ValueMesh(shape, remapped) - def item(self, **kwargs) -> R: + def 
item(self, **kwargs: int) -> R: """ Get the value at the given coordinates. @@ -621,10 +621,10 @@ def _ndslice(self) -> NDSlice: def _labels(self) -> Iterable[str]: return self._shape.labels - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: return {"shape": self._shape, "values": self._hy.values()} - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, Any]) -> None: self._shape = state["shape"] vals = state["values"] self._hy = HyValueMesh(self._shape, vals) @@ -634,7 +634,7 @@ def send( endpoint: Endpoint[P, R], args: Tuple[Any, ...], kwargs: Dict[str, Any], - port: "Optional[Port]" = None, + port: "Optional[Port[R]]" = None, selection: Selection = "all", ) -> None: """ @@ -690,7 +690,7 @@ def exception(self, obj: Exception) -> None: PythonMessage(PythonMessageKind.Exception(self._rank), _pickle(obj)), ) - def __reduce__(self): + def __reduce__(self) -> Tuple[Any, Tuple[Any, ...]]: """ When Port is sent over the wire, we do not want to send the actor instance from the current context. Instead, we want to reconstruct the Port with @@ -698,7 +698,9 @@ def __reduce__(self): from through this port. """ - def _reconstruct_port(port_ref, rank): + def _reconstruct_port( + port_ref: PortRef | OncePortRef, rank: Optional[int] + ) -> "Port[R]": instance = context().actor_instance._as_rust() return Port(port_ref, instance, rank) @@ -714,7 +716,7 @@ class DroppingPort: Makes sure any exception sent to it causes the actor to report an exception. 
""" - def __init__(self): + def __init__(self) -> None: pass def send(self, obj: Any) -> None: @@ -810,7 +812,7 @@ def recv(self) -> "Future[R]": def ranked(self) -> "RankedPortReceiver[R]": return RankedPortReceiver[R](self._mailbox, self._receiver, self._monitor) - def _set_monitor(self, monitor: "Optional[Shared[Exception]]"): + def _set_monitor(self, monitor: "Optional[Shared[Exception]]") -> None: self._monitor = monitor @@ -834,7 +836,7 @@ def _process(self, msg: PythonMessage) -> Tuple[int, R]: # we need to signal to the consumer of the PythonTask object that the thread really isn't in an async context. # We do this by blanking out the running event loop during the call to the synchronous actor function. -MESSAGES_HANDLED = METER.create_counter("py_mesages_handled") +MESSAGES_HANDLED: Counter = METER.create_counter("py_mesages_handled") class _Actor: @@ -968,7 +970,7 @@ async def handle( pass raise - def _maybe_exit_debugger(self, do_continue=True) -> None: + def _maybe_exit_debugger(self, do_continue: bool = True) -> None: if (pdb_wrapper := DebugContext.get().pdb_wrapper) is not None: if do_continue: pdb_wrapper.clear_all_breaks() @@ -976,7 +978,7 @@ def _maybe_exit_debugger(self, do_continue=True) -> None: pdb_wrapper.end_debug_session() DebugContext.set(DebugContext()) - def _post_mortem_debug(self, exc_tb) -> None: + def _post_mortem_debug(self, exc_tb: Any) -> None: from monarch._src.actor.debugger.debug_controller import debug_controller if (pdb_wrapper := DebugContext.get().pdb_wrapper) is not None: @@ -1005,7 +1007,7 @@ def _handle_undeliverable_message( else: return False - def __supervise__(self, cx: Context, *args, **kwargs) -> object: + def __supervise__(self, cx: Context, *args: Any, **kwargs: Any) -> object: _context.set(cx) instance = self.instance if instance is None: @@ -1083,7 +1085,7 @@ def _new_with_shape(self, shape: Shape) -> Self: ) @property - def initialized(self): + def initialized(self) -> Any: raise NotImplementedError( "actor 
implementations are not meshes, but we can't convince the typechecker of it..." ) @@ -1164,9 +1166,9 @@ def _endpoint( self, name: MethodSpecifier, impl: Callable[Concatenate[Any, P], Awaitable[R]], - propagator: Any, + propagator: Propagator, explicit_response_port: bool, - ): + ) -> Any: return ActorEndpoint( self._inner, self._shape, @@ -1215,7 +1217,9 @@ def from_actor_id( ) -> "ActorMesh[T]": return cls(Class, _SingletonActorAdapator(actor_id), singleton_shape, None) - def __reduce_ex__(self, protocol: ...) -> "Tuple[Type[ActorMesh], Tuple[Any, ...]]": + def __reduce_ex__( + self, protocol: Any + ) -> "Tuple[Type[ActorMesh[T]], Tuple[Any, ...]]": return ActorMesh, (self._class, self._inner, self._shape, self._proc_mesh) @property @@ -1269,7 +1273,7 @@ def __init__( ) for s in actor_mesh_ref_tb ) - self.exception_formatted = "".join(actor_mesh_ref_tb) + self.exception_formatted: str = "".join(actor_mesh_ref_tb) self.message = message def __str__(self) -> str: diff --git a/python/monarch/_src/actor/bootstrap.py b/python/monarch/_src/actor/bootstrap.py index 0d5f5005f..7d2631854 100644 --- a/python/monarch/_src/actor/bootstrap.py +++ b/python/monarch/_src/actor/bootstrap.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# pyre-unsafe +# pyre-strict from pathlib import Path @@ -28,11 +28,12 @@ CA = Union[bytes, Path, Literal["trust_all_connections"]] -def _as_python_task(s: str | Future[str]) -> PythonTask: +def _as_python_task(s: str | Future[str]) -> "PythonTask[str]": if isinstance(s, str): + s_str: str = s - async def just(): - return s + async def just() -> str: + return s_str return PythonTask.from_coroutine(just()) else: diff --git a/python/monarch/_src/actor/bootstrap_main.py b/python/monarch/_src/actor/bootstrap_main.py index fa8d3e093..a8c41ba0d 100644 --- a/python/monarch/_src/actor/bootstrap_main.py +++ b/python/monarch/_src/actor/bootstrap_main.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +# pyre-strict """ This is the main function for the boostrapping a new process using a ProcessAllocator. @@ -29,13 +29,14 @@ import monarch._rust_bindings # @manual # noqa: F401 -async def main(): +async def main() -> None: from monarch._rust_bindings.monarch_hyperactor.bootstrap import bootstrap_main + # pyre-ignore[12]: bootstrap_main is async but imported from Rust bindings await bootstrap_main() -def invoke_main(): +def invoke_main() -> None: # if this is invoked with the stdout piped somewhere, then print # changes its buffering behavior. So we default to the standard # behavior of std out as if it were a terminal. diff --git a/python/monarch/_src/actor/code_sync/__init__.py b/python/monarch/_src/actor/code_sync/__init__.py index 1b3fcbf84..a20ac98f3 100644 --- a/python/monarch/_src/actor/code_sync/__init__.py +++ b/python/monarch/_src/actor/code_sync/__init__.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# pyre-unsafe +# pyre-strict from monarch._rust_bindings.monarch_extension.code_sync import ( # noqa: F401 CodeSyncMeshClient, diff --git a/python/monarch/_src/actor/code_sync/auto_reload.py b/python/monarch/_src/actor/code_sync/auto_reload.py index 675f36cb8..97f7a8f99 100644 --- a/python/monarch/_src/actor/code_sync/auto_reload.py +++ b/python/monarch/_src/actor/code_sync/auto_reload.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +# pyre-strict import contextlib import dataclasses @@ -16,33 +16,40 @@ import sys import threading from pathlib import Path -from types import ModuleType -from typing import Dict, List, Optional, Tuple +from types import ModuleType, TracebackType +from typing import Any, Callable, Dict, List, Optional, Tuple, Type from monarch._src.actor.actor_mesh import Actor from monarch._src.actor.endpoint import endpoint -class SysAuditHookGuard(contextlib.AbstractContextManager): +class SysAuditHookGuard(contextlib.AbstractContextManager["SysAuditHookGuard", None]): """ A guard (and context manager), which will unregister an import hook when closed or deleted. """ - def __init__(self, hooks, idx): - self._hooks = hooks - self._idx = idx + def __init__( + self, hooks: Dict[int, Callable[[str, tuple[Any, ...]], None]], idx: int + ) -> None: + self._hooks: Dict[int, Callable[[str, tuple[Any, ...]], None]] = hooks + self._idx: int = idx - def close(self): + def close(self) -> None: self._hooks.pop(self._idx, None) - def __enter__(self): + def __enter__(self) -> "SysAuditHookGuard": return self - def __exit__(self, *args): + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: self.close() - def __del__(self): + def __del__(self) -> None: self.close() @@ -55,29 +62,30 @@ class SysAuditHookMultiplexer: removal. 
""" - def __init__(self): - self._idx = itertools.count() - self._hooks = {} + def __init__(self) -> None: + self._idx: itertools.count[int] = itertools.count() + self._hooks: Dict[int, Callable[[str, tuple[Any, ...]], None]] = {} - def _callback(self, event, args): + def _callback(self, event: str, args: tuple[Any, ...]) -> None: for hook in self._hooks.values(): hook(event, args) - def add(self, hook) -> SysAuditHookGuard: + def add(self, hook: Callable[[str, tuple[Any, ...]], None]) -> SysAuditHookGuard: idx = next(self._idx) self._hooks[idx] = hook return SysAuditHookGuard(self._hooks, idx) _instance_lock = threading.Lock() - _instance = None + _instance: Optional["SysAuditHookMultiplexer"] = None @classmethod - def singleton(cls): + def singleton(cls) -> "SysAuditHookMultiplexer": if cls._instance is None: with cls._instance_lock: if cls._instance is None: cls._instance = SysAuditHookMultiplexer() sys.addaudithook(cls._instance._callback) + assert cls._instance is not None return cls._instance @@ -93,12 +101,12 @@ class SysAuditImportHook: imported. 
""" - def __init__(self, callback): - self._callback = callback + def __init__(self, callback: Callable[[str, ModuleType], None]) -> None: + self._callback: Callable[[str, ModuleType], None] = callback self._state = ThreadLocalState() @classmethod - def install(cls, callback) -> SysAuditHookGuard: + def install(cls, callback: Callable[[str, ModuleType], None]) -> SysAuditHookGuard: return SysAuditHookMultiplexer.singleton().add(SysAuditImportHook(callback)) def _py_filename(self, filename: Path) -> Path: @@ -106,7 +114,7 @@ def _py_filename(self, filename: Path) -> Path: return filename.with_suffix(".py") return filename - def __call__(self, event, args): + def __call__(self, event: str, args: tuple[Any, ...]) -> None: if event == "import": # While `filename` is specific as an argument to the import event, it's # almost always `None`, so we need to wait for a subsequent exec event @@ -123,13 +131,14 @@ def __call__(self, event, args): module = sys.modules.get(module_name) if module is None: return - if getattr(module, "__file__", None) is None: + module_file = getattr(module, "__file__", None) + if module_file is None: return (code_obj,) = args if code_obj.co_filename is None: return # code objects store the original source name, not the pyc - if self._py_filename(Path(module.__file__)) != Path(code_obj.co_filename): + if self._py_filename(Path(module_file)) != Path(code_obj.co_filename): return self._callback(module_name, module) @@ -150,12 +159,14 @@ class AutoReloader: Track changes to modules and reloads them when they change. 
""" - def __init__(self, reload=importlib.reload): - self._reload = reload + def __init__( + self, reload: Callable[[ModuleType], ModuleType] = importlib.reload + ) -> None: + self._reload: Callable[[ModuleType], ModuleType] = reload self._tracked_modules: Dict[str, Tuple[Path, Fingerprint]] = {} self._track_all_imported() - def _maybe_track_module(self, name: str, module: ModuleType): + def _maybe_track_module(self, name: str, module: ModuleType) -> None: filename = getattr(module, "__file__", None) if filename is None: return @@ -181,13 +192,13 @@ def _maybe_track_module(self, name: str, module: ModuleType): Fingerprint.for_path(filename), ) - def _track_all_imported(self): + def _track_all_imported(self) -> None: for name, module in sys.modules.items(): if module is None: continue self._maybe_track_module(name, module) - def import_callback(self, name: str, module: ModuleType): + def import_callback(self, name: str, module: ModuleType) -> None: """ Callback for when a module has been imported. """ @@ -215,9 +226,11 @@ def reload_changes(self) -> List[str]: class AutoReloadActor(Actor): - def __init__(self): + def __init__(self) -> None: self._reloader = AutoReloader() - self._hook_guard = SysAuditImportHook.install(self._reloader.import_callback) + self._hook_guard: SysAuditHookGuard = SysAuditImportHook.install( + self._reloader.import_callback + ) @endpoint async def reload(self) -> None: diff --git a/python/monarch/_src/actor/debugger/__init__.py b/python/monarch/_src/actor/debugger/__init__.py index 6ac1a72bd..581f84e46 100644 --- a/python/monarch/_src/actor/debugger/__init__.py +++ b/python/monarch/_src/actor/debugger/__init__.py @@ -4,4 +4,4 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# pyre-unsafe +# pyre-strict diff --git a/python/monarch/_src/actor/debugger/breakpoint.py b/python/monarch/_src/actor/debugger/breakpoint.py index c50f84e9a..127d36f2d 100644 --- a/python/monarch/_src/actor/debugger/breakpoint.py +++ b/python/monarch/_src/actor/debugger/breakpoint.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +# pyre-strict import inspect from monarch._src.actor.actor_mesh import context, DebugContext diff --git a/python/monarch/_src/actor/debugger/debug_command.py b/python/monarch/_src/actor/debugger/debug_command.py index da5e34905..e19f9b21a 100644 --- a/python/monarch/_src/actor/debugger/debug_command.py +++ b/python/monarch/_src/actor/debugger/debug_command.py @@ -4,21 +4,24 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +# pyre-strict import sys from dataclasses import dataclass -from typing import cast, Dict, List, Tuple, Union +from typing import cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union + +if TYPE_CHECKING: + from lark import Lark, Transformer from monarch._src.actor.debugger.debug_io import DebugIO RanksType = Union[int, List[int], range, Dict[str, Union[range, List[int], int]]] -_debug_input_parser = None +_debug_input_parser: "Optional[Lark]" = None # Wrap the parser in a function so that jobs don't have to import lark # unless they want to use the debugger. -def _get_debug_input_parser(): +def _get_debug_input_parser() -> "Lark": global _debug_input_parser if _debug_input_parser is None: from lark import Lark @@ -55,12 +58,12 @@ def _get_debug_input_parser(): return _debug_input_parser -_debug_input_transformer = None +_debug_input_transformer: "Optional[Transformer]" = None # Wrap the transformer in a function so that jobs don't have to import lark # unless they want to use the debugger. 
-def _get_debug_input_transformer(): +def _get_debug_input_transformer() -> "Transformer": global _debug_input_transformer if _debug_input_transformer is None: from lark import Transformer diff --git a/python/monarch/_src/actor/debugger/debug_controller.py b/python/monarch/_src/actor/debugger/debug_controller.py index 245a70d80..d157f3061 100644 --- a/python/monarch/_src/actor/debugger/debug_controller.py +++ b/python/monarch/_src/actor/debugger/debug_controller.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +# pyre-strict import asyncio import functools from typing import Dict, List, Optional, Tuple @@ -54,10 +54,10 @@ class DebugController(Actor): def __init__(self) -> None: self.sessions = DebugSessions() self._task_lock = asyncio.Lock() - self._task: asyncio.Task | None = None + self._task: Optional[asyncio.Task[None]] = None self._debug_io: DebugIO = DebugStdIO() - self._server = asyncio.Future() - self._server_task = asyncio.create_task(self._serve()) + self._server: asyncio.Future[asyncio.Server] = asyncio.Future() + self._server_task: asyncio.Task[None] = asyncio.create_task(self._serve()) async def _serve(self) -> None: try: @@ -98,12 +98,12 @@ async def _handle_client( self._task = asyncio.create_task(self._enter()) @endpoint - async def wait_pending_session(self): + async def wait_pending_session(self) -> None: while len(self.sessions) == 0: await asyncio.sleep(1) @endpoint - async def list(self, print_output=True) -> List[DebugSessionInfo]: + async def list(self, print_output: bool = True) -> List[DebugSessionInfo]: session_info = sorted(self.sessions.info()) if print_output: await self._debug_io.output( diff --git a/python/monarch/_src/actor/debugger/debug_io.py b/python/monarch/_src/actor/debugger/debug_io.py index 26452e02e..c0d87fba1 100644 --- a/python/monarch/_src/actor/debugger/debug_io.py +++ 
b/python/monarch/_src/actor/debugger/debug_io.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +# pyre-strict import asyncio import sys from abc import abstractmethod @@ -34,12 +34,14 @@ async def quit(self) -> None: class DebugIOError(RuntimeError): - def __init__(self): + def __init__(self) -> None: super().__init__("Error encountered during debugger I/O operation.") class DebugCliIO(DebugIO): - def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + def __init__( + self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter + ) -> None: self._reader = reader self._writer = writer diff --git a/python/monarch/_src/actor/debugger/debug_session.py b/python/monarch/_src/actor/debugger/debug_session.py index 2d7a7e2af..0688860ba 100644 --- a/python/monarch/_src/actor/debugger/debug_session.py +++ b/python/monarch/_src/actor/debugger/debug_session.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# pyre-unsafe +# pyre-strict import asyncio from dataclasses import dataclass from typing import Dict, Generator, List, Optional, Tuple @@ -24,7 +24,7 @@ class DebugSessionInfo: function: str | None lineno: int | None - def __lt__(self, other): + def __lt__(self, other: "DebugSessionInfo") -> bool: if self.actor_name < other.actor_name: return True elif self.actor_name == other.actor_name: @@ -38,20 +38,27 @@ class DebugSession: def __init__( self, rank: int, coords: Dict[str, int], hostname: str, actor_name: str - ): + ) -> None: self.rank = rank self.coords = coords self.hostname = hostname self.actor_name = actor_name self._active = False - self._message_queue = asyncio.Queue() - self._task = None - self._pending_send_to_actor = asyncio.Queue() - self._outputs_since_last_input = [] - self._function_lineno = None + self._message_queue: asyncio.Queue[str | Tuple[str, DebuggerWrite]] = ( + asyncio.Queue() + ) + self._task: Optional[asyncio.Task[None]] = None + self._pending_send_to_actor: asyncio.Queue[bytes] = asyncio.Queue() + self._outputs_since_last_input: List[DebuggerWrite] = [] + self._function_lineno: Optional[Tuple[str, int]] = None self._need_read = False - async def _event_loop(self, debug_io: DebugIO, line=None, suppress_output=False): + async def _event_loop( + self, + debug_io: DebugIO, + line: Optional[str] = None, + suppress_output: bool = False, + ) -> None: if not suppress_output: # If the user had previously attached to this debug session, # then it would have printed various messages from the @@ -105,6 +112,7 @@ async def _event_loop(self, debug_io: DebugIO, line=None, suppress_output=False) raise elif message[0] == "write": output = message[1] + assert isinstance(output, DebuggerWrite) # If the user sees this output but then detaches from the session, # its useful to store all outputs since the last input so that # they can be printed again when the user re-attaches. 
@@ -117,7 +125,7 @@ async def _event_loop(self, debug_io: DebugIO, line=None, suppress_output=False) f"Detaching from debug session for {self.actor_name} {self.rank} ({self.hostname})\n" ) - def get_info(self): + def get_info(self) -> DebugSessionInfo: function = lineno = None if self._function_lineno is not None: function, lineno = self._function_lineno @@ -125,7 +133,12 @@ def get_info(self): self.actor_name, self.rank, self.coords, self.hostname, function, lineno ) - async def attach(self, debug_io: DebugIO, line=None, suppress_output=False): + async def attach( + self, + debug_io: DebugIO, + line: Optional[str] = None, + suppress_output: bool = False, + ) -> None: self._active = True if not suppress_output: await debug_io.output( @@ -141,7 +154,7 @@ async def attach(self, debug_io: DebugIO, line=None, suppress_output=False): ) self._active = False - async def detach(self): + async def detach(self) -> None: if self._active: await self._message_queue.put("detach") @@ -159,7 +172,7 @@ async def debugger_write(self, write: DebuggerWrite) -> None: class DebugSessions: - def __init__(self): + def __init__(self) -> None: self._sessions: Dict[str, Dict[int, DebugSession]] = {} def insert(self, session: DebugSession) -> None: diff --git a/python/monarch/_src/actor/debugger/pdb_wrapper.py b/python/monarch/_src/actor/debugger/pdb_wrapper.py index 82ae559a1..3590e5a3b 100644 --- a/python/monarch/_src/actor/debugger/pdb_wrapper.py +++ b/python/monarch/_src/actor/debugger/pdb_wrapper.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# pyre-unsafe +# pyre-strict import bdb import inspect import io @@ -15,8 +15,9 @@ import sys from contextlib import contextmanager from dataclasses import dataclass +from types import FrameType, TracebackType -from typing import Dict, TYPE_CHECKING +from typing import Any, Dict, Generator, Optional, TYPE_CHECKING from monarch._rust_bindings.monarch_hyperactor.proc import ActorId from monarch._src.actor.sync_state import fake_sync_state @@ -33,7 +34,7 @@ class DebuggerWrite: @contextmanager -def _debug_controller_request_ctx(): +def _debug_controller_request_ctx() -> Generator[None, None, None]: try: with fake_sync_state(): yield @@ -49,7 +50,7 @@ def __init__( actor_id: ActorId, controller: "DebugController", header: str | None = None, - ): + ) -> None: self.rank = rank self.coords = coords self.header = header @@ -59,7 +60,7 @@ def __init__( super().__init__(stdout=WriteWrapper(self), stdin=ReadWrapper.create(self)) self._first = True - def set_trace(self, frame=None): + def set_trace(self, frame: Optional[FrameType] = None) -> None: with _debug_controller_request_ctx(): self.controller.debugger_session_start.call_one( self.rank, @@ -71,7 +72,7 @@ def set_trace(self, frame=None): self.message(self.header) super().set_trace(frame) - def do_clear(self, arg): + def do_clear(self, arg: str) -> None: if not arg: # Sending `clear` without any argument specified will # request confirmation from the user using the `input` function, @@ -82,31 +83,31 @@ def do_clear(self, arg): else: super().do_clear(arg) - def lookupmodule(self, filename): - filename = super().lookupmodule(filename) + def lookupmodule(self, filename: str) -> Optional[str]: + result = super().lookupmodule(filename) if ( - filename is not None - and not os.path.exists(filename) - and filename not in linecache.cache + result is not None + and not os.path.exists(result) + and result not in linecache.cache ): from monarch._src.actor.actor_mesh import ActorError from monarch._src.actor.source_loader import 
load_remote_source try: with fake_sync_state(): - source = load_remote_source(filename) + source = load_remote_source(result) if source: - linecache.cache[filename] = ( + linecache.cache[result] = ( len(source), None, source.splitlines(keepends=True), - filename, + result, ) except ActorError as e: self.error(f"Failed querying root client host for source code: {e}") - return filename + return result - def end_debug_session(self): + def end_debug_session(self) -> None: with _debug_controller_request_ctx(): self.controller.debugger_session_end.call_one( self.actor_id.actor_name, self.rank @@ -117,7 +118,7 @@ def end_debug_session(self): self.stdin = sys.stdin self.stdout = sys.stdout - def post_mortem(self, exc_tb): + def post_mortem(self, exc_tb: TracebackType) -> None: self._first = False # See builtin implementation of pdb.post_mortem() for reference. self.reset() @@ -125,10 +126,10 @@ def post_mortem(self, exc_tb): class ReadWrapper(io.RawIOBase): - def __init__(self, session: "PdbWrapper"): + def __init__(self, session: "PdbWrapper") -> None: self.session = session - def readinto(self, b): + def readinto(self, b: Any) -> int: with _debug_controller_request_ctx(): response = self.session.controller.debugger_read.call_one( self.session.actor_id.actor_name, self.session.rank, len(b) @@ -147,18 +148,18 @@ def readable(self) -> bool: return True @classmethod - def create(cls, session: "PdbWrapper"): + def create(cls, session: "PdbWrapper") -> io.TextIOWrapper: return io.TextIOWrapper(io.BufferedReader(cls(session))) class WriteWrapper: - def __init__(self, session: "PdbWrapper"): + def __init__(self, session: "PdbWrapper") -> None: self.session = session def writable(self) -> bool: return True - def write(self, s: str): + def write(self, s: str) -> None: function = None lineno = None if self.session.curframe is not None: @@ -177,5 +178,5 @@ def write(self, s: str): ), ).get() - def flush(self): + def flush(self) -> None: pass diff --git 
a/python/monarch/_src/actor/device_utils.py b/python/monarch/_src/actor/device_utils.py index 36f1d4331..c3ccffb62 100644 --- a/python/monarch/_src/actor/device_utils.py +++ b/python/monarch/_src/actor/device_utils.py @@ -4,14 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +# pyre-strict import os import re from pathlib import Path -def _local_device_count(): +def _local_device_count() -> int: if "CUDA_VISIBLE_DEVICES" in os.environ: return len(os.environ["CUDA_VISIBLE_DEVICES"].split(",")) dev_path = Path("/dev") diff --git a/python/monarch/_src/actor/endpoint.py b/python/monarch/_src/actor/endpoint.py index 90de2478d..ddec09759 100644 --- a/python/monarch/_src/actor/endpoint.py +++ b/python/monarch/_src/actor/endpoint.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +# pyre-strict import functools from abc import ABC, abstractmethod @@ -26,16 +26,20 @@ Tuple, TYPE_CHECKING, TypeVar, + Union, ) +from monarch._rust_bindings.monarch_hyperactor.actor import MethodSpecifier from monarch._rust_bindings.monarch_hyperactor.shape import Extent from monarch._src.actor.future import Future from monarch._src.actor.telemetry import METER from monarch._src.actor.tensor_engine_shim import _cached_propagation, fake_call +from opentelemetry.metrics import Histogram + # Histogram for measuring endpoint call latency -endpoint_call_latency_histogram = METER.create_histogram( +endpoint_call_latency_histogram: Histogram = METER.create_histogram( name="endpoint_call_latency.us", description="Latency of endpoint call operations in microseconds", ) @@ -53,20 +57,20 @@ Selection = Literal["all", "choose"] -Propagator = Any +Propagator = Union[None, Literal["cached", "inspect", "mocked"], Callable[..., Any]] class Endpoint(ABC, Generic[P, R]): def __init__(self, 
propagator: Propagator) -> None: self._propagator_arg = propagator - self._cache: Optional[dict] = None + self._cache: Optional[Dict[Any, Any]] = None @abstractmethod def _send( self, args: Tuple[Any, ...], kwargs: Dict[str, Any], - port: "Optional[Port]" = None, + port: "Optional[Port[R]]" = None, selection: Selection = "all", ) -> Extent: """ @@ -86,13 +90,15 @@ def _port(self, once: bool = False) -> "Tuple[Port[R], PortReceiver[R]]": return Channel[R].open(once) @abstractmethod - def _call_name(self) -> Any: + def _call_name(self) -> MethodSpecifier: """ Something to use in InputChecker to represent calling this thingy. """ pass - def _supervise(self, r: "HyPortReceiver | HyOncePortReceiver") -> Any: + def _supervise( + self, r: "HyPortReceiver | HyOncePortReceiver" + ) -> "HyPortReceiver | HyOncePortReceiver": return r # the following are all 'adverbs' or different ways to handle the @@ -108,13 +114,13 @@ def choose(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]: """ p, r = self._port(once=True) - # pyre-ignore + # pyre-ignore[6]: ParamSpec kwargs is compatible with Dict[str, Any] self._send(args, kwargs, port=p, selection="choose") return r.recv() def call_one(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]: p, r = self._port(once=True) - # pyre-ignore + # pyre-ignore[6]: ParamSpec kwargs is compatible with Dict[str, Any] extent = self._send(args, kwargs, port=p, selection="choose") if extent.nelements != 1: raise ValueError( @@ -123,19 +129,20 @@ def call_one(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]: return r.recv() def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[ValueMesh[R]]": - from monarch._src.actor.actor_mesh import ValueMesh + from monarch._src.actor.actor_mesh import RankedPortReceiver, ValueMesh - start_time = datetime.now() + start_time: datetime = datetime.now() p, unranked = self._port() - r = unranked.ranked() - # pyre-ignore - extent = self._send(args, kwargs, port=p) + r: RankedPortReceiver[R] = 
unranked.ranked() + # pyre-ignore[6]: ParamSpec kwargs is compatible with Dict[str, Any] + extent: Extent = self._send(args, kwargs, port=p) method_specifier = self._call_name() if hasattr(method_specifier, "name"): - method_name = method_specifier.name + # pyre-ignore[16]: MethodSpecifier subclasses ReturnsResponse and ExplicitPort have .name + method_name: str = method_specifier.name else: - method_name = "unknown" + method_name: str = "unknown" async def process() -> "ValueMesh[R]": from monarch._rust_bindings.monarch_hyperactor.shape import Shape @@ -172,11 +179,12 @@ def stream( This enables processing results from multiple actors incrementally as they become available. Returns an async generator of response values. """ - p, r = self._port() - # type: ignore - extent = self._send(args, kwargs, port=p) + p, r_port = self._port() + # pyre-ignore[6]: ParamSpec kwargs is compatible with Dict[str, Any] + extent: Extent = self._send(args, kwargs, port=p) + r: "PortReceiver[R]" = r_port - def _stream(): + def _stream() -> Generator[Future[R], None, None]: for _ in range(extent.nelements): yield r.recv() @@ -192,16 +200,23 @@ def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None: """ from monarch._src.actor.actor_mesh import send - # pyre-ignore + # pyre-ignore[6]: ParamSpec kwargs is compatible with Dict[str, Any] send(self, args, kwargs) @abstractmethod - def _rref(self, args, kwargs) -> Any: ... + def _rref(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> R: ... 
def rref(self, *args: P.args, **kwargs: P.kwargs) -> R: + # pyre-ignore[6]: ParamSpec kwargs is compatible with Dict[str, Any] return self._rref(args, kwargs) - def _propagate(self, args, kwargs, fake_args, fake_kwargs): + def _propagate( + self, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + fake_args: Tuple[Any, ...], + fake_kwargs: Dict[str, Any], + ) -> Any: if self._propagator_arg is None or self._propagator_arg == "cached": if self._cache is None: self._cache = {} @@ -218,12 +233,24 @@ def _propagate(self, args, kwargs, fake_args, fake_kwargs): else: return fake_call(self._propagator_arg, *fake_args, **fake_kwargs) - def _fetch_propagate(self, args, kwargs, fake_args, fake_kwargs): + def _fetch_propagate( + self, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + fake_args: Tuple[Any, ...], + fake_kwargs: Dict[str, Any], + ) -> Any: if self._propagator_arg is None: return # no propgator provided, so we just assume no mutations return self._propagate(args, kwargs, fake_args, fake_kwargs) - def _pipe_propagate(self, args, kwargs, fake_args, fake_kwargs): + def _pipe_propagate( + self, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + fake_args: Tuple[Any, ...], + fake_kwargs: Dict[str, Any], + ) -> Any: if not callable(self._propagator_arg): raise ValueError("Must specify explicit callable for pipe") return self._propagate(args, kwargs, fake_args, fake_kwargs) @@ -260,7 +287,7 @@ def __init__( self._explicit_response_port = explicit_response_port self._instrument = instrument - def __get__(self, instance, owner) -> Endpoint[P, R]: + def __get__(self, instance: Any, owner: Any) -> Endpoint[P, R]: # this is a total lie, but we have to actually # recognize this was defined as an endpoint, # and also lookup the method @@ -274,11 +301,11 @@ class NotAnEndpoint: and to provide the oppurtunity for someone to do endpoint(x.foo) on something that wasn't marked as an endpoint. 
""" - def __init__(self, ref: "ActorMesh", name: str): + def __init__(self, ref: "ActorMesh[Any]", name: str) -> None: self._ref = ref self._name = name - def __call__(self, *args, **kwargs) -> None: + def __call__(self, *args: Any, **kwargs: Any) -> None: raise RuntimeError( f"Actor {self._ref._class}.{self._name} is not annotated as an endpoint. To call it as one, add a @endpoint decorator to it, or directly wrap it in one as_endpoint(obj.method).call(...)" ) @@ -291,12 +318,13 @@ class EndpointIfy: def __call__( self, function: Callable[Concatenate[Any, P], Awaitable[R]] ) -> Endpoint[P, R]: ... + @overload def __call__( self, function: Callable[Concatenate[Any, P], R] ) -> Endpoint[P, R]: ... - def __call__(self, function: Any): + def __call__(self, function: Any) -> Any: pass @@ -312,7 +340,7 @@ def __call__( self, function: Callable[Concatenate[Any, "Port[R]", P], None] ) -> Endpoint[P, R]: ... - def __call__(self, function: Any): + def __call__(self, function: Any) -> Any: pass @@ -375,12 +403,12 @@ def endpoint( def endpoint( - method=None, + method: Any = None, *, - propagate=None, + propagate: Any = None, explicit_response_port: bool = False, instrument: bool = False, -): +) -> Any: if method is None: return functools.partial( endpoint, diff --git a/python/monarch/_src/actor/event_loop.py b/python/monarch/_src/actor/event_loop.py index c0a2fbe0d..9b1c0b6b7 100644 --- a/python/monarch/_src/actor/event_loop.py +++ b/python/monarch/_src/actor/event_loop.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +# pyre-strict """ Module for managing the event loop used by Monarch Python actors. 
@@ -18,7 +18,7 @@ from pyre_extensions import none_throws -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) _event_loop: Optional[asyncio.AbstractEventLoop] = None _lock = threading.Lock() @@ -35,7 +35,7 @@ def _initialize_event_loop() -> None: return # Create a new thread that will host our event loop - def event_loop_thread(): + def event_loop_thread() -> None: """Target function for the event loop thread.""" global _event_loop, _ready try: diff --git a/python/monarch/_src/actor/future.py b/python/monarch/_src/actor/future.py index e38fa32fb..6f4afadaa 100644 --- a/python/monarch/_src/actor/future.py +++ b/python/monarch/_src/actor/future.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +# pyre-strict import asyncio from typing import ( @@ -27,7 +27,7 @@ R = TypeVar("R") -async def _aincomplete(impl, self): +async def _aincomplete(impl: Any, self: Any) -> Any: try: return self._set_result(await impl()) except Exception as e: @@ -64,7 +64,7 @@ async def _aincomplete(impl, self): class _Unawaited(NamedTuple): - coro: PythonTask + coro: PythonTask[Any] class _Complete(NamedTuple): @@ -80,7 +80,7 @@ class _Asyncio(NamedTuple): class _Tokio(NamedTuple): - shared: Shared + shared: Shared[Any] _Status = _Unawaited | _Complete | _Exception | _Asyncio | _Tokio @@ -98,7 +98,7 @@ class Future(Generic[R]): """ - def __init__(self, *, coro: "Coroutine[Any, Any, R] | PythonTask[R]"): + def __init__(self, *, coro: "Coroutine[Any, Any, R] | PythonTask[R]") -> None: self._status: _Status = _Unawaited( coro if isinstance(coro, PythonTask) else PythonTask.from_coroutine(coro) ) @@ -138,15 +138,15 @@ def __await__(self) -> Generator[Any, Any, R]: fut = loop.create_future() self._status = _Asyncio(fut) - def set_result(fut, value): + def set_result(fut: asyncio.Future[R], value: R) -> None: if not fut.cancelled(): 
fut.set_result(value) - def set_exception(fut, e): + def set_exception(fut: asyncio.Future[R], e: Exception) -> None: if not fut.cancelled(): fut.set_exception(e) - async def mark_complete(): + async def mark_complete() -> None: try: func, value = set_result, await coro except Exception as e: @@ -192,7 +192,7 @@ async def mark_complete(): def result(self, timeout: Optional[float] = None) -> R: return self.get(timeout) - def exception(self, timeout: Optional[float] = None): + def exception(self, timeout: Optional[float] = None) -> Optional[Exception]: try: self.get(timeout) return None diff --git a/python/monarch/_src/actor/host_mesh.py b/python/monarch/_src/actor/host_mesh.py index 57ee4d209..9943c5bb2 100644 --- a/python/monarch/_src/actor/host_mesh.py +++ b/python/monarch/_src/actor/host_mesh.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +# pyre-strict import warnings from math import prod @@ -80,7 +80,7 @@ def __init__( shape: Shape, allocator: AllocateMixin, alloc_constraints: Optional[AllocConstraints] = None, - ): + ) -> None: warnings.warn( ( "DEPRECATION WARNING: using a deprecated version of HostMesh. This is going be removed imminently. " @@ -186,7 +186,7 @@ def fake_in_process_host_v0() -> "HostMeshV0": return HostMeshV0(Shape.unity(), LocalAllocator()) -def hosts_from_config_v0(name: str): +def hosts_from_config_v0(name: str) -> HostMeshV0: """ Get the host mesh 'name' from the monarch configuration for the project. diff --git a/python/monarch/_src/actor/pickle.py b/python/monarch/_src/actor/pickle.py index 23e6779b8..e430522e2 100644 --- a/python/monarch/_src/actor/pickle.py +++ b/python/monarch/_src/actor/pickle.py @@ -4,22 +4,23 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# pyre-unsafe +# pyre-strict import collections.abc as abc import io import os import pickle import sys +import types from collections import ChainMap from contextlib import ExitStack -from typing import Any, Callable, Iterable, List, Tuple +from typing import Any, Callable, Iterable, List, Sequence, Tuple import cloudpickle from monarch._rust_bindings.monarch_hyperactor.buffers import Buffer, FrozenBuffer -def maybe_torch(): +def maybe_torch() -> types.ModuleType | None: """ Returns the torch module if it has been loaded, otherwise None. """ @@ -37,7 +38,7 @@ def maybe_torch(): return None -_orig_function_getstate = cloudpickle.cloudpickle._function_getstate +_orig_function_getstate: Any = cloudpickle.cloudpickle._function_getstate # To ensure that the debugger and tracebacks work on remote hosts @@ -47,7 +48,7 @@ def maybe_torch(): # to load the source code for the function, it will use the RemoteImportLoader # to retrieve the source code from the root client, where it *ostensibly* # exists. 
-def _function_getstate(func): +def _function_getstate(func: Any) -> Any: from monarch._src.actor.source_loader import RemoteImportLoader state, slotstate = _orig_function_getstate(func) @@ -60,10 +61,11 @@ def _function_getstate(func): cloudpickle.cloudpickle._function_getstate = _function_getstate -def _load_from_bytes(b): +def _load_from_bytes(b: bytes | Buffer) -> object: import torch # if we haven't loaded it # we have to now + # pyre-ignore[16]: dynamic torch load causing problems return torch.load( io.BytesIO(b) if isinstance(b, bytes) else b, map_location="cpu", @@ -71,29 +73,32 @@ def _load_from_bytes(b): ) -def _torch_storage(obj): +def _torch_storage(obj: Any) -> Any: import torch # we only get here if torch is already imported b = io.BytesIO() + # pyre-ignore[16]: dynamic torch load causing problems torch.save(obj, b, _use_new_zipfile_serialization=False) return (_load_from_bytes, (b.getvalue(),)) class _Pickler(cloudpickle.Pickler): _torch_initialized = False - _dispatch_table = {} + _dispatch_table: dict[Any, Any] = {} - dispatch_table = ChainMap(_dispatch_table, cloudpickle.Pickler.dispatch_table) + dispatch_table: ChainMap[Any, Any] = ChainMap( + _dispatch_table, cloudpickle.Pickler.dispatch_table + ) - def __init__(self, filter, f: Buffer | io.BytesIO): + def __init__(self, filter: Callable[[Any], bool], f: Buffer | io.BytesIO) -> None: self.f = f super().__init__(self.f) - self._filter = filter - self._saved = [] + self._filter: Callable[[Any], bool] = filter + self._saved: List[Any] = [] _Pickler._init_torch_dispatch() @classmethod - def _init_torch_dispatch(cls): + def _init_torch_dispatch(cls) -> None: # already initialized if cls._torch_initialized: return @@ -108,7 +113,7 @@ def _init_torch_dispatch(cls): cls._dispatch_table[key] = _torch_storage cls._torch_initialized = True - def persistent_id(self, obj): + def persistent_id(self, obj: Any) -> int | None: if not self._filter(obj): return None self._saved.append(obj) @@ -116,18 +121,18 @@ 
def persistent_id(self, obj): class _Unpickler(pickle.Unpickler): - def __init__(self, data: bytes | FrozenBuffer, sequence: Iterable[Any]): + def __init__(self, data: bytes | FrozenBuffer, sequence: Iterable[Any]) -> None: if isinstance(data, FrozenBuffer): super().__init__(data) else: super().__init__(io.BytesIO(data)) - self._iter = iter(sequence) - self._values = [] + self._iter: abc.Iterator[Any] = iter(sequence) + self._values: List[Any] = [] - def persistent_load(self, id): - while id >= len(self._values): + def persistent_load(self, pid: Any) -> Any: + while pid >= len(self._values): self._values.append(next(self._iter)) - return self._values[id] + return self._values[pid] def flatten(obj: Any, filter: Callable[[Any], bool]) -> Tuple[List[Any], FrozenBuffer]: diff --git a/python/monarch/_src/actor/python_extension_methods.py b/python/monarch/_src/actor/python_extension_methods.py index 27fd746e4..f3ff7d920 100644 --- a/python/monarch/_src/actor/python_extension_methods.py +++ b/python/monarch/_src/actor/python_extension_methods.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +# pyre-strict import importlib @@ -15,7 +15,7 @@ class PatchRustClass: - def __init__(self, rust_class: Type): + def __init__(self, rust_class: Type[T]) -> None: self.rust_class = rust_class def __call__(self, python_class: Type[T]) -> Type[T]: diff --git a/python/monarch/_src/actor/shape.py b/python/monarch/_src/actor/shape.py index df75d7ecc..a1216d61f 100644 --- a/python/monarch/_src/actor/shape.py +++ b/python/monarch/_src/actor/shape.py @@ -4,13 +4,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# pyre-unsafe +# pyre-strict import itertools import operator from abc import ABC, abstractmethod -from typing import Dict, Generator, Sequence, Tuple, Union +from typing import Any, Dict, Generator, Sequence, Tuple, Union from monarch._rust_bindings.monarch_hyperactor.shape import Extent, Shape, Slice @@ -38,7 +38,7 @@ class ShapeExt: functionality.""" @staticmethod - def slice(shape: Shape, **kwargs) -> Shape: + def slice(shape: Shape, **kwargs: Any) -> Shape: """Select along named dimensions. Integer values remove dimensions, slice objects keep dimensions but restrict them. @@ -78,7 +78,7 @@ def _labels(self) -> Tuple[str, ...]: ... @abstractmethod def _new_with_shape(self, shape: Shape) -> Self: ... - def slice(self, **kwargs) -> Self: + def slice(self, **kwargs: Any) -> Self: """Select along named dimensions. Integer values remove dimensions, slice objects keep dimensions but restrict them. @@ -87,7 +87,7 @@ def slice(self, **kwargs) -> Self: shape = Shape(list(self._labels), self._ndslice) return self._new_with_shape(ShapeExt.slice(shape, **kwargs)) - def split(self, **kwargs) -> Self: + def split(self, **kwargs: Any) -> Self: """ Returns a new device mesh with some dimensions of this mesh split. For instance, this call splits the host dimension into dp and pp dimensions, @@ -195,7 +195,7 @@ def flatten(self, name: str) -> Self: ) ) - def rename(self, **kwargs) -> Self: + def rename(self, **kwargs: Any) -> Self: """ Returns a new device mesh with some of dimensions renamed. Dimensions not mentioned are retained: diff --git a/python/monarch/_src/actor/source_loader.py b/python/monarch/_src/actor/source_loader.py index 0765ad27c..b6a956a5b 100644 --- a/python/monarch/_src/actor/source_loader.py +++ b/python/monarch/_src/actor/source_loader.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# pyre-unsafe +# pyre-strict import functools import importlib import importlib.abc @@ -35,7 +35,7 @@ def load_remote_source(filename: str) -> str: class RemoteImportLoader(importlib.abc.Loader): - def __init__(self, filename: str): + def __init__(self, filename: str) -> None: self._filename = filename def get_source(self, _module_name: str) -> str: diff --git a/python/monarch/_src/actor/sync_state.py b/python/monarch/_src/actor/sync_state.py index 721eaa8a7..6cf814c10 100644 --- a/python/monarch/_src/actor/sync_state.py +++ b/python/monarch/_src/actor/sync_state.py @@ -4,14 +4,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +# pyre-strict import asyncio from contextlib import contextmanager +from typing import Any, Generator @contextmanager -def fake_sync_state(): +def fake_sync_state() -> Generator[None, None, None]: prev_loop = asyncio.events._get_running_loop() asyncio._set_running_loop(None) try: diff --git a/python/monarch/_src/actor/tensor_engine_shim.py b/python/monarch/_src/actor/tensor_engine_shim.py index 71117f740..6e41bf39c 100644 --- a/python/monarch/_src/actor/tensor_engine_shim.py +++ b/python/monarch/_src/actor/tensor_engine_shim.py @@ -4,11 +4,20 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# pyre-unsafe +# pyre-strict import importlib from functools import partial -from typing import Any, Optional, Sequence, TYPE_CHECKING +from typing import ( + Any, + Callable, + Optional, + overload, + ParamSpec, + Sequence, + TYPE_CHECKING, + TypeVar, +) """ This file provides a type annoated shim for using tensor engine functions @@ -24,17 +33,34 @@ from monarch._rust_bindings.monarch_hyperactor.buffers import FrozenBuffer +P = ParamSpec("P") +F = TypeVar("F", bound=Callable[..., Any]) -def shim(fn=None, *, module=None): + +@overload +def shim(fn: F, *, module: Optional[str] = None) -> F: ... + + +@overload +def shim( + fn: None = None, *, module: Optional[str] = None +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: ... + + +def shim( + fn: Optional[Callable[..., Any]] = None, *, module: Optional[str] = None +) -> Any: if fn is None: return partial(shim, module=module) - impl = None - name = fn.__name__ + impl: Optional[Callable[..., Any]] = None + name: str = fn.__name__ - def wrap(*args, **kwargs): + def wrap(*args: Any, **kwargs: Any) -> Any: nonlocal impl if impl is None: + # TODO: See if there's a reasonable way to assert that the module name is not none + # pyre-ignore Incompatible parameter type [6]: In call `importlib.import_module`, for 1st positional argument, expected `str` but got `Optional[str]` impl = getattr(importlib.import_module(module), name) return impl(*args, **kwargs) @@ -43,7 +69,7 @@ def wrap(*args, **kwargs): @shim(module="monarch.mesh_controller") def actor_send( - endpoint: "ActorEndpoint", + endpoint: "ActorEndpoint[..., ...]", args_kwargs_tuple: bytes, refs: "Sequence[Any]", port: "Optional[Port[Any]]", @@ -52,12 +78,16 @@ def actor_send( @shim(module="monarch.mesh_controller") -def actor_rref(endpoint, args_kwargs_tuple: FrozenBuffer, refs: Sequence[Any]): ... +def actor_rref( + endpoint: Any, + args_kwargs_tuple: FrozenBuffer, + refs: Sequence[Any], +) -> Any: ... 
@shim(module="monarch.common.remote") -def _cached_propagation(_cache, rfunction, args, kwargs) -> Any: ... +def _cached_propagation(_cache: Any, rfunction: Any, args: Any, kwargs: Any) -> Any: ... @shim(module="monarch.common.fake") -def fake_call(fn, *args, **kwargs): ... +def fake_call(fn: Any, *args: Any, **kwargs: Any) -> Any: ... diff --git a/python/monarch/_src/actor/v1/__init__.py b/python/monarch/_src/actor/v1/__init__.py index 3d8dea9c3..c85c1bdf7 100644 --- a/python/monarch/_src/actor/v1/__init__.py +++ b/python/monarch/_src/actor/v1/__init__.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +# pyre-strict import os -enabled = os.environ.get("MONARCH_V0_WORKAROUND_DO_NOT_USE", "0") != "1" +enabled: bool = os.environ.get("MONARCH_V0_WORKAROUND_DO_NOT_USE", "0") != "1"