14 changes: 9 additions & 5 deletions python/monarch/common/device_mesh.py
@@ -244,24 +244,24 @@ def __call__(self, **kwargs) -> "DeviceMesh":
def rotate(self, **kwargs: Dict[str, int]):
raise NotImplementedError()

def rank(self, dims: Union[str, Sequence[str]]) -> int:
def rank(self, dims: Union[str, Sequence[str]]) -> torch.Tensor:
self.define_remotely()
if isinstance(dims, str):
if dims not in self.names:
raise KeyError(f"{self} does not have dimension {repr(dims)}")
return _remote(
"monarch.worker.worker._rank",
_rank,
propagate=lambda _self, _dims: torch.full((), 0, dtype=torch.long),
)(self, dims)

combined_rank = 0
combined_rank: Any = 0
for dim in dims:
combined_rank *= self.size(dim)
combined_rank += self.rank(dim)
return combined_rank

@property
def ranks(self) -> dict[str, int]:
def ranks(self) -> dict[str, torch.Tensor]:
return {dim: self.rank(dim) for dim in self.names}

def process_idx(self):
@@ -334,6 +334,10 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
return _remote(func, propagate=func)(*args, **kwargs)


def _rank(mesh, dim):
return torch.full((), mesh.dims[dim].rank, dtype=torch.long)


@contextmanager
def _dispatch():
global _dispatch_enabled
@@ -401,7 +405,7 @@ def _to_mesh(tensor: Union["Tensor", "MeshSliceTensor"]) -> "Tensor":

def slice_mesh(
tensors: Any,
**kwargs: Dict[str, Union[int, slice]],
**kwargs: Union[int, slice],
) -> Any:
"""
Performs the slice_mesh operation for each tensor in tensors.
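The `rank()` change above makes the per-dimension rank a scalar `torch.Tensor` computed on the worker (via the new `_rank` helper), and for a sequence of dimensions it folds them together in row-major order. A minimal sketch of that folding, using hypothetical sizes and coordinates in place of a real `DeviceMesh`:

```python
import torch

# Hypothetical per-dimension sizes and this worker's coordinates; a real
# DeviceMesh would supply these through size(dim) and rank(dim).
sizes = {"hosts": 2, "gpus": 4}
coords = {"hosts": 1, "gpus": 3}

def combined_rank(dims):
    # Mirrors the loop in DeviceMesh.rank: for each dimension in order,
    # combined = combined * size(dim) + rank(dim).
    combined = torch.zeros((), dtype=torch.long)
    for dim in dims:
        combined = combined * sizes[dim] + coords[dim]
    return combined

print(combined_rank(["hosts", "gpus"]))  # tensor(7) == 1 * 4 + 3
```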
3 changes: 3 additions & 0 deletions python/monarch/common/shape.py
@@ -44,6 +44,9 @@ def _ndslice(self) -> NDSlice: ...
@abstractmethod
def _labels(self) -> Tuple[str, ...]: ...

# mesh trait guarantees that its own calls to _new_with_shape
# will only ever select a shape that is a subspace of the
# current _ndslice.
@abstractmethod
def _new_with_shape(self, shape: Shape) -> Self: ...

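The new comment documents a contract rather than behavior: implementers of `_new_with_shape` may assume the incoming shape only selects ranks already covered by the current `_ndslice`. A toy illustration of that invariant, with flat rank lists standing in for the real `Shape`/`NDSlice` bindings:

```python
# Hypothetical flat rank lists standing in for NDSlice contents.
current_ranks = [0, 1, 2, 3, 4, 5, 6, 7]   # e.g. hosts=2 x gpus=4
sliced_ranks = [1, 5]                      # e.g. the gpus=1 column

def is_subspace(sub, full):
    # The guarantee MeshTrait gives to _new_with_shape implementations.
    return set(sub) <= set(full)

assert is_subspace(sliced_ranks, current_ranks)
```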
29 changes: 25 additions & 4 deletions python/monarch/common/tensor.py
@@ -7,17 +7,20 @@
# pyre-unsafe
import itertools
import traceback
import typing
import warnings
from collections import defaultdict
from typing import (
Any,
Callable,
cast,
Dict,
Iterable,
List,
Literal,
NamedTuple,
Optional,
runtime_checkable,
Sequence,
TYPE_CHECKING,
TypeVar,
@@ -35,7 +38,8 @@
from .borrows import StorageAliases

if TYPE_CHECKING:
from .device_mesh import DeviceMesh
from monarch.common.device_mesh import DeviceMesh

from .fake import fake_call
from .function import Propagator, ResolvableFunction
from .invocation import Invocation
@@ -52,6 +56,12 @@
T = TypeVar("T")


@runtime_checkable
class HasDeviceMesh(typing.Protocol):
@property
def _device_mesh(self) -> "DeviceMesh": ...


class DropLocation(NamedTuple):
tensor_id: int
traceback: List[traceback.FrameSummary]
@@ -167,7 +177,11 @@ def _use(self):
self._on_first_use(self)
self._on_first_use = None

def to_mesh(self, mesh: "DeviceMesh", stream: Optional["Stream"] = None):
def to_mesh(
self,
mesh: Union["DeviceMesh", "HasDeviceMesh"],
stream: Optional["Stream"] = None,
):
"""
Move data between one device mesh and another. Sizes of named dimensions must match.
If mesh has dimensions that self.mesh does not, it will broadcast to those dimensions.
@@ -177,6 +191,8 @@ def to_mesh(self, mesh: "DeviceMesh", stream: Optional["Stream"] = None):
t.slice_mesh(batch=0).to_mesh(t.mesh)

"""
if isinstance(mesh, HasDeviceMesh):
mesh = mesh._device_mesh
return MeshSliceTensor(self, self.mesh).to_mesh(mesh, stream)

def reduce_(
@@ -344,7 +360,7 @@ def reduce(
)
return r

def slice_mesh(self, **kwargs: Dict[str, Union[int, slice]]) -> "MeshSliceTensor":
def slice_mesh(self, **kwargs: Union[int, slice]) -> "MeshSliceTensor":
# technically a slice of a device mesh and a device mesh are not same thing
# because a device mesh also has caches for doing collectives.
# but this is an easy way to create a MeshSliceTensor until we optimize
@@ -368,8 +384,13 @@ def __init__(self, tensor: "Tensor", slicing: "DeviceMesh"):
self.slicing = slicing

def to_mesh(
self, mesh: "DeviceMesh", stream: Optional["Stream"] = None
self,
mesh: Union["DeviceMesh", "HasDeviceMesh"],
stream: Optional["Stream"] = None,
) -> "Tensor":
if isinstance(mesh, HasDeviceMesh):
mesh = mesh._device_mesh

if stream is None:
stream = self.tensor.stream

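The `HasDeviceMesh` protocol is what lets callers pass a `ProcMesh` (or anything else exposing a `_device_mesh` property) straight to `to_mesh`. A small sketch of the unwrapping step, with a hypothetical stand-in class rather than a real `ProcMesh`:

```python
import typing
from typing import runtime_checkable

@runtime_checkable
class HasDeviceMesh(typing.Protocol):
    @property
    def _device_mesh(self): ...

class FakeProcMesh:
    # Hypothetical stand-in: any object with a _device_mesh property
    # satisfies the protocol, no inheritance required.
    def __init__(self, mesh):
        self._mesh = mesh

    @property
    def _device_mesh(self):
        return self._mesh

def resolve(mesh):
    # Mirrors the check added to Tensor.to_mesh / MeshSliceTensor.to_mesh.
    if isinstance(mesh, HasDeviceMesh):
        mesh = mesh._device_mesh
    return mesh

assert resolve(FakeProcMesh("device-mesh")) == "device-mesh"
assert resolve("device-mesh") == "device-mesh"
```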
15 changes: 10 additions & 5 deletions python/monarch/mesh_controller.py
@@ -11,7 +11,7 @@
import traceback
from collections import deque
from logging import Logger
from typing import List, NamedTuple, Optional, Union
from typing import List, NamedTuple, Optional, TYPE_CHECKING, Union

import torch.utils._python_dispatch

@@ -24,7 +24,13 @@
from monarch._rust_bindings.monarch_hyperactor.proc import ( # @manual=//monarch/monarch_extension:monarch_extension
ActorId,
)
from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ProcMesh as HyProcMesh

if TYPE_CHECKING:
from monarch._rust_bindings.monarch_hyperactor.proc_mesh import (
ProcMesh as HyProcMesh,
)
from monarch.proc_mesh import ProcMesh

from monarch._rust_bindings.monarch_hyperactor.shape import Point

from monarch._rust_bindings.monarch_messages.debugger import DebuggerAction
@@ -33,15 +39,14 @@
from monarch.common.device_mesh import DeviceMesh, no_mesh
from monarch.common.invocation import DeviceException, RemoteException
from monarch.controller.debugger import read as debugger_read, write as debugger_write
from monarch.proc_mesh import ProcMesh
from monarch.rust_local_mesh import _get_worker_exec_info
from pyre_extensions import none_throws

logger: Logger = logging.getLogger(__name__)


class Controller(_Controller):
def __init__(self, workers: HyProcMesh) -> None:
def __init__(self, workers: "HyProcMesh") -> None:
super().__init__()
# Buffer for messages unrelated to debugging that are received while a
# debugger session is active.
@@ -250,7 +255,7 @@ def shutdown(
self.inner.drain_and_stop()


def spawn_tensor_engine(proc_mesh: ProcMesh) -> DeviceMesh:
def spawn_tensor_engine(proc_mesh: "ProcMesh") -> DeviceMesh:
# This argument to Controller
# is currently only used for debug printing. It should be fixed to
# report the proc ID instead of the rank it currently does.
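Moving the `ProcMesh`/`HyProcMesh` imports under `TYPE_CHECKING` avoids a runtime import cycle now that `monarch.proc_mesh` imports `spawn_tensor_engine` from this module; the names survive only as quoted annotations. A generic sketch of the pattern (module names here are hypothetical):

```python
# a_module.py -- hypothetical illustration of the TYPE_CHECKING pattern.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers, so importing a_module never pulls in
    # b_module at runtime and the cycle is broken.
    from b_module import BThing

def use_b(thing: "BThing") -> None:
    # The quoted annotation means BThing is never evaluated at runtime.
    print(thing)

use_b("any runtime value works here")
```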
70 changes: 62 additions & 8 deletions python/monarch/proc_mesh.py
@@ -7,8 +7,22 @@
# pyre-strict

import sys
from contextlib import AbstractContextManager

from typing import (
Any,
cast,
Dict,
List,
Optional,
Sequence,
Type,
TYPE_CHECKING,
TypeVar,
)

from typing import Any, cast, List, Optional, Type, TypeVar
if TYPE_CHECKING:
import torch

import monarch
from monarch import ActorFuture as Future
@@ -24,7 +38,9 @@
from monarch.actor_mesh import _Actor, _ActorMeshRefImpl, Actor, ActorMeshRef

from monarch.common._device_utils import _local_device_count
from monarch.common.device_mesh import DeviceMesh
from monarch.common.shape import MeshTrait
from monarch.mesh_controller import spawn_tensor_engine
from monarch.rdma import RDMAManager

T = TypeVar("T")
@@ -45,25 +61,43 @@ def _allocate_blocking(alloc: Alloc) -> "ProcMesh":


class ProcMesh(MeshTrait):
def __init__(self, hy_proc_mesh: HyProcMesh) -> None:
def __init__(
self,
hy_proc_mesh: HyProcMesh,
_mock_shape: Optional[Shape] = None,
_device_mesh: Optional[DeviceMesh] = None,
) -> None:
self._proc_mesh = hy_proc_mesh
self._mock_shape: Optional[Shape] = _mock_shape
self._mailbox: Mailbox = self._proc_mesh.client
self._rdma_manager: RDMAManager = self._spawn_blocking(
"rdma_manager", RDMAManager
)
self._rdma_manager: Optional[RDMAManager] = None
self._maybe_device_mesh: Optional[DeviceMesh] = _device_mesh
if _mock_shape is None:
self._rdma_manager = self._spawn_blocking("rdma_manager", RDMAManager)

@property
def _shape(self) -> Shape:
return self._proc_mesh.shape if self._mock_shape is None else self._mock_shape

@property
def _ndslice(self) -> Slice:
return self._proc_mesh.shape.ndslice
return self._shape.ndslice

@property
def _labels(self) -> List[str]:
return self._proc_mesh.shape.labels
return self._shape.labels

def _new_with_shape(self, shape: Shape) -> "ProcMesh":
raise NotImplementedError("ProcMesh slicing is not implemeted yet.")
device_mesh = (
None
if self._maybe_device_mesh is None
else self._maybe_device_mesh._new_with_shape(shape)
)
return ProcMesh(self._proc_mesh, _mock_shape=shape, _device_mesh=device_mesh)

def spawn(self, name: str, Class: Type[T], *args: Any, **kwargs: Any) -> Future[T]:
if self._mock_shape is not None:
raise NotImplementedError("NYI: spawn on slice of a proc mesh.")
return Future(
lambda: self._spawn_nonblocking(name, Class, *args, **kwargs),
lambda: self._spawn_blocking(name, Class, *args, **kwargs),
@@ -120,6 +154,26 @@ async def _spawn_nonblocking(
service._create(args, kwargs)
return cast(T, service)

@property
def _device_mesh(self) -> "DeviceMesh":
if self._maybe_device_mesh is None:
if self._mock_shape is not None:
raise NotImplementedError(
"NYI: activating a proc mesh must first happen on the root proc_mesh until we fix spawning on submeshes."
)
self._maybe_device_mesh = spawn_tensor_engine(self)
return self._maybe_device_mesh

# pyre-ignore
def activate(self) -> AbstractContextManager:
return self._device_mesh.activate()

def rank_tensor(self, dim: str | Sequence[str]) -> "torch.Tensor":
return self._device_mesh.rank(dim)

def rank_tensors(self) -> Dict[str, "torch.Tensor"]:
return self._device_mesh.ranks


async def local_proc_mesh_nonblocking(
*, gpus: Optional[int] = None, hosts: int = 1
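The `_mock_shape` / `_device_mesh` constructor arguments make slices of a `ProcMesh` cheap views over the same underlying mesh, while the tensor engine is spawned lazily and only on the root. A simplified sketch of that lazy-view design, under the assumption that the real work happens in `spawn_tensor_engine` (class and helper names below are illustrative, not the monarch API):

```python
from typing import Optional

class LazyMeshView:
    """Simplified stand-in for the new ProcMesh behaviour."""

    def __init__(self, handle, mock_shape=None, engine=None):
        self._handle = handle              # shared backing resource
        self._mock_shape = mock_shape      # set only on sliced views
        self._maybe_engine: Optional[str] = engine

    @property
    def engine(self):
        # Spawned lazily, and only on the root (non-view) object.
        if self._maybe_engine is None:
            if self._mock_shape is not None:
                raise NotImplementedError("activate the root mesh first")
            self._maybe_engine = f"engine({self._handle})"
        return self._maybe_engine

    def new_with_shape(self, shape):
        # A view shares the handle and, if present, a sliced engine.
        engine = None if self._maybe_engine is None else f"{self.engine}[{shape}]"
        return LazyMeshView(self._handle, mock_shape=shape, engine=engine)

root = LazyMeshView("pm")
print(root.engine)                    # engine(pm), created on first use
view = root.new_with_shape("gpus=1")
print(view.engine)                    # engine(pm)[gpus=1], no new spawn
```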
22 changes: 21 additions & 1 deletion python/tests/test_python_actors.py
@@ -391,10 +391,13 @@ def check(module, path):
check(bindings, "monarch._rust_bindings")


@pytest.mark.skipif(
two_gpu = pytest.mark.skipif(
torch.cuda.device_count() < 2,
reason="Not enough GPUs, this test requires at least 2 GPUs",
)


@two_gpu
def test_tensor_engine() -> None:
pm = proc_mesh(gpus=2).get()

@@ -591,3 +594,20 @@ async def test_actor_tls() -> None:

assert 2 == await am.get.call_one()
# assert 4 == await am.get_async.call_one()


@two_gpu
def test_proc_mesh_tensor_engine() -> None:
pm = proc_mesh(gpus=2).get()
with pm.activate():
f = 10 * pm.rank_tensor("gpus").cuda()
a = monarch.inspect(f, hosts=0, gpus=0)
b = monarch.inspect(f, hosts=0, gpus=1)

one = pm.slice(gpus=1)
with one.activate():
sliced_b = monarch.slice_mesh(f, gpus=1).to_mesh(one)
c = monarch.inspect(sliced_b * 10)
assert a == 0
assert b == 10
assert c == 100