diff --git a/python/monarch/common/device_mesh.py b/python/monarch/common/device_mesh.py
index 96e76816d..da02da055 100644
--- a/python/monarch/common/device_mesh.py
+++ b/python/monarch/common/device_mesh.py
@@ -244,24 +244,24 @@ def __call__(self, **kwargs) -> "DeviceMesh":
     def rotate(self, **kwargs: Dict[str, int]):
         raise NotImplementedError()
 
-    def rank(self, dims: Union[str, Sequence[str]]) -> int:
+    def rank(self, dims: Union[str, Sequence[str]]) -> torch.Tensor:
         self.define_remotely()
         if isinstance(dims, str):
             if dims not in self.names:
                 raise KeyError(f"{self} does not have dimension {repr(dims)}")
             return _remote(
-                "monarch.worker.worker._rank",
+                _rank,
                 propagate=lambda _self, _dims: torch.full((), 0, dtype=torch.long),
             )(self, dims)
 
-        combined_rank = 0
+        combined_rank: Any = 0
         for dim in dims:
             combined_rank *= self.size(dim)
             combined_rank += self.rank(dim)
         return combined_rank
 
     @property
-    def ranks(self) -> dict[str, int]:
+    def ranks(self) -> dict[str, torch.Tensor]:
         return {dim: self.rank(dim) for dim in self.names}
 
     def process_idx(self):
@@ -334,6 +334,10 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         return _remote(func, propagate=func)(*args, **kwargs)
 
 
+def _rank(mesh, dim):
+    return torch.full((), mesh.dims[dim].rank, dtype=torch.long)
+
+
 @contextmanager
 def _dispatch():
     global _dispatch_enabled
@@ -401,7 +405,7 @@ def _to_mesh(tensor: Union["Tensor", "MeshSliceTensor"]) -> "Tensor":
 
 def slice_mesh(
     tensors: Any,
-    **kwargs: Dict[str, Union[int, slice]],
+    **kwargs: Union[int, slice],
 ) -> Any:
     """
     Performs the slice_mesh operation for each tensor in tensors.
diff --git a/python/monarch/common/shape.py b/python/monarch/common/shape.py
index 6b65b28a7..41afd8672 100644
--- a/python/monarch/common/shape.py
+++ b/python/monarch/common/shape.py
@@ -44,6 +44,9 @@ def _ndslice(self) -> NDSlice: ...
     @abstractmethod
     def _labels(self) -> Tuple[str, ...]: ...
 
+    # The mesh trait guarantees that its own calls to _new_with_shape
+    # will only ever select a shape that is a subspace of the
+    # current _ndslice.
     @abstractmethod
     def _new_with_shape(self, shape: Shape) -> Self: ...
 
diff --git a/python/monarch/common/tensor.py b/python/monarch/common/tensor.py
index 1a0c5ffad..c379f4af8 100644
--- a/python/monarch/common/tensor.py
+++ b/python/monarch/common/tensor.py
@@ -7,17 +7,20 @@
 # pyre-unsafe
 import itertools
 import traceback
+import typing
 import warnings
 from collections import defaultdict
 from typing import (
     Any,
     Callable,
+    cast,
     Dict,
     Iterable,
     List,
     Literal,
     NamedTuple,
     Optional,
+    runtime_checkable,
     Sequence,
     TYPE_CHECKING,
     TypeVar,
@@ -35,7 +38,8 @@
 from .borrows import StorageAliases
 
 if TYPE_CHECKING:
-    from .device_mesh import DeviceMesh
+    from monarch.common.device_mesh import DeviceMesh
+
 from .fake import fake_call
 from .function import Propagator, ResolvableFunction
 from .invocation import Invocation
@@ -52,6 +56,12 @@
 T = TypeVar("T")
 
 
+@runtime_checkable
+class HasDeviceMesh(typing.Protocol):
+    @property
+    def _device_mesh(self) -> "DeviceMesh": ...
+
+
 class DropLocation(NamedTuple):
     tensor_id: int
     traceback: List[traceback.FrameSummary]
@@ -167,7 +177,11 @@ def _use(self):
             self._on_first_use(self)
             self._on_first_use = None
 
-    def to_mesh(self, mesh: "DeviceMesh", stream: Optional["Stream"] = None):
+    def to_mesh(
+        self,
+        mesh: Union["DeviceMesh", "HasDeviceMesh"],
+        stream: Optional["Stream"] = None,
+    ):
         """
         Move data between one device mesh and another. Sizes of named dimensions must match.
         If mesh has dimensions that self.mesh does not, it will broadcast to those dimensions.
@@ -177,6 +191,8 @@ def to_mesh(self, mesh: "DeviceMesh", stream: Optional["Stream"] = None):
 
             t.slice_mesh(batch=0).to_mesh(t.mesh)
         """
+        if isinstance(mesh, HasDeviceMesh):
+            mesh = mesh._device_mesh
         return MeshSliceTensor(self, self.mesh).to_mesh(mesh, stream)
 
     def reduce_(
@@ -344,7 +360,7 @@ def reduce(
         )
         return r
 
-    def slice_mesh(self, **kwargs: Dict[str, Union[int, slice]]) -> "MeshSliceTensor":
+    def slice_mesh(self, **kwargs: Union[int, slice]) -> "MeshSliceTensor":
         # technically a slice of a device mesh and a device mesh are not same thing
         # because a device mesh also has caches for doing collectives.
         # but this is an easy way to create a MeshSliceTensor until we optimize
@@ -368,8 +384,13 @@ def __init__(self, tensor: "Tensor", slicing: "DeviceMesh"):
         self.slicing = slicing
 
     def to_mesh(
-        self, mesh: "DeviceMesh", stream: Optional["Stream"] = None
+        self,
+        mesh: Union["DeviceMesh", "HasDeviceMesh"],
+        stream: Optional["Stream"] = None,
     ) -> "Tensor":
+        if isinstance(mesh, HasDeviceMesh):
+            mesh = mesh._device_mesh
+
         if stream is None:
             stream = self.tensor.stream
 
diff --git a/python/monarch/mesh_controller.py b/python/monarch/mesh_controller.py
index 1335c88ef..499f45a59 100644
--- a/python/monarch/mesh_controller.py
+++ b/python/monarch/mesh_controller.py
@@ -11,7 +11,7 @@
 import traceback
 from collections import deque
 from logging import Logger
-from typing import List, NamedTuple, Optional, Union
+from typing import List, NamedTuple, Optional, TYPE_CHECKING, Union
 
 import torch.utils._python_dispatch
 
@@ -24,7 +24,13 @@
 from monarch._rust_bindings.monarch_hyperactor.proc import (  # @manual=//monarch/monarch_extension:monarch_extension
     ActorId,
 )
-from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ProcMesh as HyProcMesh
+
+if TYPE_CHECKING:
+    from monarch._rust_bindings.monarch_hyperactor.proc_mesh import (
+        ProcMesh as HyProcMesh,
+    )
+    from monarch.proc_mesh import ProcMesh
+
 from monarch._rust_bindings.monarch_hyperactor.shape import Point
 from monarch._rust_bindings.monarch_messages.debugger import DebuggerAction
 
@@ -33,7 +39,6 @@
 from monarch.common.device_mesh import DeviceMesh, no_mesh
 from monarch.common.invocation import DeviceException, RemoteException
 from monarch.controller.debugger import read as debugger_read, write as debugger_write
-from monarch.proc_mesh import ProcMesh
 from monarch.rust_local_mesh import _get_worker_exec_info
 from pyre_extensions import none_throws
 
@@ -41,7 +46,7 @@
 
 
 class Controller(_Controller):
-    def __init__(self, workers: HyProcMesh) -> None:
+    def __init__(self, workers: "HyProcMesh") -> None:
         super().__init__()
         # Buffer for messages unrelated to debugging that are received while a
         # debugger session is active.
@@ -250,7 +255,7 @@ def shutdown(
         self.inner.drain_and_stop()
 
 
-def spawn_tensor_engine(proc_mesh: ProcMesh) -> DeviceMesh:
+def spawn_tensor_engine(proc_mesh: "ProcMesh") -> DeviceMesh:
     # This argument to Controller
     # is currently only used for debug printing. It should be fixed to
     # report the proc ID instead of the rank it currently does.
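The tensor.py hunks above add a runtime-checkable HasDeviceMesh protocol so that Tensor.to_mesh and MeshSliceTensor.to_mesh accept anything exposing a _device_mesh property; that is what lets the ProcMesh changes later in this diff be passed to them directly. Below is a minimal standalone sketch of that dispatch; the Fake* classes are stand-ins invented for illustration, not monarch types.

import typing
from typing import Union, runtime_checkable


@runtime_checkable
class HasDeviceMesh(typing.Protocol):
    @property
    def _device_mesh(self) -> "FakeDeviceMesh": ...


class FakeDeviceMesh:
    """Stand-in for monarch.common.device_mesh.DeviceMesh (illustration only)."""


class FakeProcMesh:
    """Stand-in for a wrapper (like ProcMesh in this diff) that satisfies HasDeviceMesh."""

    def __init__(self) -> None:
        self._mesh = FakeDeviceMesh()

    @property
    def _device_mesh(self) -> FakeDeviceMesh:
        return self._mesh


def to_mesh(mesh: Union[FakeDeviceMesh, HasDeviceMesh]) -> FakeDeviceMesh:
    # Same unwrapping step Tensor.to_mesh / MeshSliceTensor.to_mesh now perform.
    if isinstance(mesh, HasDeviceMesh):
        mesh = mesh._device_mesh
    return mesh


assert isinstance(to_mesh(FakeProcMesh()), FakeDeviceMesh)
assert isinstance(to_mesh(FakeDeviceMesh()), FakeDeviceMesh)

This is the design choice the diff leans on: ProcMesh never subclasses DeviceMesh, it only exposes one through _device_mesh.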
diff --git a/python/monarch/proc_mesh.py b/python/monarch/proc_mesh.py
index 065f680b5..8819e2adf 100644
--- a/python/monarch/proc_mesh.py
+++ b/python/monarch/proc_mesh.py
@@ -7,8 +7,22 @@
 # pyre-strict
 import sys
+from contextlib import AbstractContextManager
+
+from typing import (
+    Any,
+    cast,
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    Type,
+    TYPE_CHECKING,
+    TypeVar,
+)
 
-from typing import Any, cast, List, Optional, Type, TypeVar
+if TYPE_CHECKING:
+    import torch
 
 import monarch
 from monarch import ActorFuture as Future
 
@@ -24,7 +38,9 @@
 
 from monarch.actor_mesh import _Actor, _ActorMeshRefImpl, Actor, ActorMeshRef
 from monarch.common._device_utils import _local_device_count
+from monarch.common.device_mesh import DeviceMesh
 from monarch.common.shape import MeshTrait
+from monarch.mesh_controller import spawn_tensor_engine
 from monarch.rdma import RDMAManager
 
 T = TypeVar("T")
@@ -45,25 +61,43 @@ def _allocate_blocking(alloc: Alloc) -> "ProcMesh":
 
 
 class ProcMesh(MeshTrait):
-    def __init__(self, hy_proc_mesh: HyProcMesh) -> None:
+    def __init__(
+        self,
+        hy_proc_mesh: HyProcMesh,
+        _mock_shape: Optional[Shape] = None,
+        _device_mesh: Optional[DeviceMesh] = None,
+    ) -> None:
         self._proc_mesh = hy_proc_mesh
+        self._mock_shape: Optional[Shape] = _mock_shape
         self._mailbox: Mailbox = self._proc_mesh.client
-        self._rdma_manager: RDMAManager = self._spawn_blocking(
-            "rdma_manager", RDMAManager
-        )
+        self._rdma_manager: Optional[RDMAManager] = None
+        self._maybe_device_mesh: Optional[DeviceMesh] = _device_mesh
+        if _mock_shape is None:
+            self._rdma_manager = self._spawn_blocking("rdma_manager", RDMAManager)
+
+    @property
+    def _shape(self) -> Shape:
+        return self._proc_mesh.shape if self._mock_shape is None else self._mock_shape
 
     @property
     def _ndslice(self) -> Slice:
-        return self._proc_mesh.shape.ndslice
+        return self._shape.ndslice
 
     @property
     def _labels(self) -> List[str]:
-        return self._proc_mesh.shape.labels
+        return self._shape.labels
 
     def _new_with_shape(self, shape: Shape) -> "ProcMesh":
-        raise NotImplementedError("ProcMesh slicing is not implemeted yet.")
+        device_mesh = (
+            None
+            if self._maybe_device_mesh is None
+            else self._maybe_device_mesh._new_with_shape(shape)
+        )
+        return ProcMesh(self._proc_mesh, _mock_shape=shape, _device_mesh=device_mesh)
 
     def spawn(self, name: str, Class: Type[T], *args: Any, **kwargs: Any) -> Future[T]:
+        if self._mock_shape is not None:
+            raise NotImplementedError("NYI: spawn on slice of a proc mesh.")
         return Future(
             lambda: self._spawn_nonblocking(name, Class, *args, **kwargs),
             lambda: self._spawn_blocking(name, Class, *args, **kwargs),
@@ -120,6 +154,26 @@ async def _spawn_nonblocking(
         service._create(args, kwargs)
         return cast(T, service)
 
+    @property
+    def _device_mesh(self) -> "DeviceMesh":
+        if self._maybe_device_mesh is None:
+            if self._mock_shape is not None:
+                raise NotImplementedError(
+                    "NYI: activating a proc mesh must first happen on the root proc_mesh until we fix spawning on submeshes."
+                )
+            self._maybe_device_mesh = spawn_tensor_engine(self)
+        return self._maybe_device_mesh
+
+    # pyre-ignore
+    def activate(self) -> AbstractContextManager:
+        return self._device_mesh.activate()
+
+    def rank_tensor(self, dim: str | Sequence[str]) -> "torch.Tensor":
+        return self._device_mesh.rank(dim)
+
+    def rank_tensors(self) -> Dict[str, "torch.Tensor"]:
+        return self._device_mesh.ranks
+
 
 async def local_proc_mesh_nonblocking(
     *, gpus: Optional[int] = None, hosts: int = 1
diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py
index 45de81e79..b7da6df05 100644
--- a/python/tests/test_python_actors.py
+++ b/python/tests/test_python_actors.py
@@ -391,10 +391,13 @@ def check(module, path):
     check(bindings, "monarch._rust_bindings")
 
 
-@pytest.mark.skipif(
+two_gpu = pytest.mark.skipif(
     torch.cuda.device_count() < 2,
     reason="Not enough GPUs, this test requires at least 2 GPUs",
 )
+
+
+@two_gpu
 def test_tensor_engine() -> None:
     pm = proc_mesh(gpus=2).get()
 
@@ -591,3 +594,20 @@ async def test_actor_tls() -> None:
 
     assert 2 == await am.get.call_one()
     # assert 4 == await am.get_async.call_one()
+
+
+@two_gpu
+def test_proc_mesh_tensor_engine() -> None:
+    pm = proc_mesh(gpus=2).get()
+    with pm.activate():
+        f = 10 * pm.rank_tensor("gpus").cuda()
+        a = monarch.inspect(f, hosts=0, gpus=0)
+        b = monarch.inspect(f, hosts=0, gpus=1)
+
+    one = pm.slice(gpus=1)
+    with one.activate():
+        sliced_b = monarch.slice_mesh(f, gpus=1).to_mesh(one)
+        c = monarch.inspect(sliced_b * 10)
+    assert a == 0
+    assert b == 10
+    assert c == 100
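A closing note on the DeviceMesh.rank change at the top of this diff: each per-dimension rank is now a 0-dim torch.Tensor produced on the workers by the new _rank helper, and when a sequence of dimensions is passed the per-dimension ranks are folded together row-major, which is why combined_rank is annotated Any (it starts as the int 0 and becomes a tensor after the first addition). A standalone sketch of that folding, with made-up dimension names and sizes:

import torch

# Hypothetical 2-host x 4-GPU mesh; these per-dimension ranks describe one worker.
sizes = {"hosts": 2, "gpus": 4}
ranks = {"hosts": torch.tensor(1), "gpus": torch.tensor(3)}

combined = 0  # starts as an int, becomes a tensor after the first += (hence the Any annotation)
for dim in ("hosts", "gpus"):
    combined *= sizes[dim]
    combined += ranks[dim]

print(combined)  # tensor(7): the last worker of the 2x4 mesh in row-major order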