diff --git a/torchft/_test/diloco_trainer.py b/torchft/_test/diloco_trainer.py
index 1a581f65..812604e1 100644
--- a/torchft/_test/diloco_trainer.py
+++ b/torchft/_test/diloco_trainer.py
@@ -1,15 +1,13 @@
 import copy
 import logging
 import os
-from contextlib import ExitStack
 from datetime import timedelta
-from typing import Any, cast, Dict, List
+from typing import Any, Dict
 
 import torch
 from torch import nn
-from torch.distributed.tensor import DTensor
+from torch.distributed.tensor import DeviceMesh, DTensor
 
-from torchft.device_mesh import ft_init_device_mesh, ManagedDeviceMesh
 from torchft.local_sgd import DiLoCo
 from torchft.manager import Manager
 from torchft.manager_integ_test import MyModel, Runner
@@ -113,7 +111,7 @@ def __init__(
 
         self.manager: Manager = self.setup_manager()
 
-        self.ft_device_mesh: None | ManagedDeviceMesh = None
+        self.device_mesh: None | DeviceMesh = None
         self.setup_distributed()
 
         self.criterion: nn.CrossEntropyLoss = nn.CrossEntropyLoss()
@@ -197,12 +195,9 @@ def setup_distributed(self) -> None:
         os.environ["WORLD_SIZE"] = str(self.runner.world_size)
         os.environ["RANK"] = str(self.rank)
 
-        self.ft_device_mesh = ft_init_device_mesh(
-            device_type=self.device.type,
-            mesh_shape=(self.runner.world_size, 1),
-            mesh_dim_names=("replicate", "none"),
-            replicate_dim=0,
-            manager=self.manager,
+        self.device_mesh = DeviceMesh(
+            self.device.type,
+            torch.arange(self.runner.world_size),
         )
 
         # Convert model parameters to DTensor
@@ -211,7 +206,7 @@ def setup_distributed(self) -> None:
         for param in layer.parameters():
             param = DTensor.from_local(
                 param,
-                device_mesh=self.ft_device_mesh,
+                device_mesh=self.device_mesh,
             )
 
     def load_state_dict(self, state_dict: Dict[str, Dict[str, object]]) -> None:
diff --git a/torchft/device_mesh.py b/torchft/device_mesh.py
deleted file mode 100644
index bb255743..00000000
--- a/torchft/device_mesh.py
+++ /dev/null
@@ -1,340 +0,0 @@
-from __future__ import annotations
-
-from typing import Any, Dict, Optional, TYPE_CHECKING, Union
-
-import torch
-from torch._C._distributed_c10d import Backend as C10dBackend
-from torch.distributed import (
-    DeviceMesh,
-    get_rank,
-    init_device_mesh,
-    ProcessGroup as BaseProcessGroup,
-)
-from torch.distributed._mesh_layout import _MeshLayout
-from torch.distributed.tensor.device_mesh import _mesh_resources
-
-from torchft.manager import Manager
-
-if TYPE_CHECKING:
-    from torchft.process_group import ManagedProcessGroup, ProcessGroup
-
-
-def extend_device_mesh(
-    mesh: DeviceMesh, pg: ProcessGroup, name: str = "dp", dim: int = 0
-) -> DeviceMesh:
-    """
-    This is a helper method to extend a traditional DeviceMesh with a torchft ProcessGroup for usage with DeviceMesh based APIs such as FSDPv2 with hybrid sharding.
-
-    Resizable PGs aren't natively supported by DeviceMesh so we lie to
-    DeviceMesh and say the PG is world size 1. This is fine as long as any
-    numeric scaling is handled at the PG level.
-
-    Args:
-        mesh: The DeviceMesh to extend
-        pg: The ProcessGroup to add to the mesh
-        name: The name of the new dimension
-        dim: The dimension to add the ProcessGroup to
-    """
-    groups = mesh.get_all_groups()
-    groups.insert(dim, pg)
-    mesh_dim_names = list(mesh.mesh_dim_names or [])
-    mesh_dim_names.insert(dim, name)
-
-    return DeviceMesh.from_group(
-        group=groups,
-        device_type=mesh.device_type,
-        mesh=mesh.mesh.unsqueeze(dim),
-        mesh_dim_names=tuple(mesh_dim_names),
-    )
-
-
-class ManagedDeviceMesh(DeviceMesh):
-    replicate_pg_singleton: Optional["ManagedProcessGroup"] = None
-
-    def __init__(
-        self,
-        mesh: Optional[DeviceMesh],
-        mesh_dim_names: tuple[str, ...],
-        replicate_pg: ManagedProcessGroup,
-        replicate_dim: int,
-        parent: Optional["ManagedDeviceMesh"],
-    ) -> None:
-        if mesh is None and parent is None:
-            raise ValueError(
-                "ManagedDeviceMesh doesn't support both mesh and parent are None."
-            )
-        self._mesh = mesh
-        self._mesh_dim_names = mesh_dim_names
-        self.replicate_pg = replicate_pg
-        self.replicate_dim = replicate_dim
-        self.replicate_dim_name: str = mesh_dim_names[replicate_dim]
-        self.parent = parent
-        self.flatten_meshes: Dict[str, DeviceMesh] = {}
-        self._flatten_mapping: Dict[str, "DeviceMesh"] = {}
-        self._device_type: str
-        if mesh is not None:
-            self._device_type = mesh.device_type
-            self._layout: _MeshLayout = mesh._layout
-        else:
-            assert parent is not None
-            self._device_type = parent.device_type
-            self._layout: _MeshLayout = parent._layout
-        self._flatten_mesh_list: tuple[DeviceMesh, ...] = tuple()
-        self._thread_id: Optional[int] = None
-        self._hash: Optional[int] = None
-
-    def __getstate__(self) -> Dict[str, Any]:
-        state = self.__dict__.copy()
-        state["replicate_pg"] = None
-        return state
-
-    def __setstate__(self, state: Dict[str, Any]) -> None:
-        self.__dict__.update(state)
-        assert self.replicate_pg_singleton is not None
-        self.replicate_pg = self.replicate_pg_singleton
-
-    def __getitem__(self, mesh_dim_names: Union[str, tuple[str, ...]]) -> DeviceMesh:
-        if isinstance(mesh_dim_names, str):
-            if mesh_dim_names == self.replicate_dim_name:
-                res_submesh = ManagedDeviceMesh(
-                    mesh=None,
-                    mesh_dim_names=(mesh_dim_names,),
-                    replicate_pg=self.replicate_pg,
-                    replicate_dim=0,
-                    parent=self,
-                )
-            elif mesh_dim_names in self.flatten_meshes:
-                res_submesh = self.flatten_meshes[mesh_dim_names]
-            else:
-                assert self._mesh is not None
-                res_submesh = self._mesh[mesh_dim_names]
-        else:
-            assert isinstance(mesh_dim_names, tuple)
-            if self.replicate_dim_name not in mesh_dim_names:
-                assert self._mesh is not None
-                res_submesh = self._mesh[mesh_dim_names]
-            else:
-                mesh_dim_names_wo_replicate = tuple(
-                    n for n in mesh_dim_names if n != self.replicate_dim_name
-                )
-                assert self._mesh is not None
-                res_submesh = ManagedDeviceMesh(
-                    self._mesh[mesh_dim_names_wo_replicate],
-                    mesh_dim_names,
-                    self.replicate_pg,
-                    mesh_dim_names.index(self.replicate_dim_name),
-                    parent=self,
-                )
-
-        # TODO: find a better way to do this that doesn't depend on device mesh
-        # internals
-        root = _mesh_resources.get_root_mesh(self)
-        res_submesh._root_mesh = root
-
-        return res_submesh
-
-    def _real_mesh_dim(self, mesh_dim: int) -> int:
-        return mesh_dim - 1 if mesh_dim > self.replicate_dim else mesh_dim
-
-    def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> BaseProcessGroup:
-        if isinstance(mesh_dim, str):
-            dim = self._mesh_dim_names.index(mesh_dim)
-        else:
-            dim = 0 if mesh_dim is None else int(mesh_dim)
-
-        if mesh_dim is None:
-            return self.replicate_pg
-        elif dim == self.replicate_dim:
-            return self.replicate_pg
-        else:
-            assert self._mesh is not None
-            return self._mesh.get_group(self._real_mesh_dim(dim))
-
-    def _flatten(
-        self,
-        mesh_dim_name: Optional[str] = None,
-        backend_override: Union[
-            None, str, C10dBackend.Options, tuple[str, C10dBackend.Options]
-        ] = None,
-    ) -> "DeviceMesh":
-        flatten_mesh = _FlattenDeviceMesh(self)
-        if mesh_dim_name is None:
-            raise ValueError("ManagedDeviceMesh._flatten requires `mesh_dim_name`")
-        if self.parent is None:
-            self.flatten_meshes[mesh_dim_name] = flatten_mesh
-        else:
-            self.parent.flatten_meshes[mesh_dim_name] = flatten_mesh
-        return flatten_mesh
-
-    def size(self, mesh_dim: Optional[int] = None) -> int:
-        replicate_pg_size = self.replicate_pg.size()
-        # We have to lie to the users if there are zero particpants.
-        # This is possible during the initialization stage of training.
-        replicate_pg_size = 1 if replicate_pg_size == 0 else replicate_pg_size
-        if mesh_dim is None:
-            if self._mesh is None:
-                return replicate_pg_size
-            else:
-                assert self._mesh is not None
-                return self._mesh.size() * replicate_pg_size
-        elif mesh_dim == self.replicate_dim:
-            return replicate_pg_size
-        else:
-            assert self._mesh is not None
-            return self._mesh.size(self._real_mesh_dim(mesh_dim))
-
-    @property
-    def ndim(self) -> int:
-        assert self._mesh is not None
-        return self._mesh.ndim + 1
-
-    @property
-    def shape(self) -> tuple[int, ...]:
-        assert self._mesh is not None
-        ret: list[int] = list(self._mesh.shape)
-        ret.insert(self.replicate_dim, self.replicate_pg.size())
-        return tuple(ret)
-
-    def get_rank(self) -> int:
-        assert self._mesh is not None
-        return self._mesh.get_rank()
-
-    def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
-        if isinstance(mesh_dim, str):
-            dim = self._mesh_dim_names.index(mesh_dim)
-        else:
-            dim = 0 if mesh_dim is None else int(mesh_dim)
-
-        if mesh_dim is None:
-            if self._mesh is None:
-                return get_rank(self.replicate_pg)
-
-            assert self.replicate_dim == 0, "replicate_dim must be the first one"
-            assert self._mesh is not None
-            other_dim_size = self._mesh.size()
-            assert self._mesh is not None
-            other_dim_rank = self._mesh.get_local_rank()
-            replicate_pg_rank = get_rank(self.replicate_pg)
-            return other_dim_size * replicate_pg_rank + other_dim_rank
-        elif dim == self.replicate_dim:
-            return get_rank(self.replicate_pg)
-        else:
-            assert self._mesh is not None
-            return self._mesh.get_local_rank(self._real_mesh_dim(dim))
-
-    def get_coordinate(self) -> Optional[list[int]]:
-        """
-        Return the relative indices of this rank relative to all
-        dimensions of the mesh. If this rank is not part of the mesh, return None.
-        """
-        assert self._mesh is not None
-        coordinate = (
-            self._mesh._coordinate_on_dim if self._mesh._coordinate_on_dim else None
-        )
-        if not coordinate:
-            return coordinate
-
-        # We need to copy be cause we are going to modify the coordinate.
-        coordinate = coordinate.copy()
-        coordinate.insert(get_rank(self.replicate_pg), self.replicate_dim)
-        return coordinate
-
-    def get_all_groups(self) -> list[BaseProcessGroup]:
-        raise NotImplementedError
-
-    def __repr__(self) -> str:
-        return f"ManagedDeviceMesh(mesh={self._mesh})"
-
-    def __hash__(self) -> int:
-        # lazily compute hash
-        if not self._hash:
-            self._hash = hash(
-                (
-                    self._mesh,
-                    self._mesh_dim_names,
-                    self.replicate_pg,
-                    self.replicate_dim,
-                    self.replicate_dim_name,
-                    self.parent,
-                    self._device_type,
-                )
-            )
-        return self._hash
-
-
-class _FlattenDeviceMesh(DeviceMesh):
-    def __init__(self, managed_mesh: ManagedDeviceMesh) -> None:
-        self.managed_mesh = managed_mesh
-
-    def __getitem__(self, mesh_dim_names: Union[str, tuple[str, ...]]) -> DeviceMesh:
-        raise NotImplementedError
-
-    def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> BaseProcessGroup:
-        raise NotImplementedError
-
-    def _flatten(
-        self,
-        mesh_dim_name: Optional[str] = None,
-        backend_override: Union[
-            None, str, C10dBackend.Options, tuple[str, C10dBackend.Options]
-        ] = None,
-    ) -> "DeviceMesh":
-        raise NotImplementedError
-
-    def size(self, mesh_dim: Optional[int] = None) -> int:
-        assert mesh_dim is None
-        return self.managed_mesh.size()
-
-    @property
-    def ndim(self) -> int:
-        raise NotImplementedError
-
-    @property
-    def shape(self) -> tuple[int, ...]:
-        raise NotImplementedError
-
-    def get_rank(self) -> int:
-        raise NotImplementedError
-
-    def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
-        assert mesh_dim is None
-        return self.managed_mesh.get_local_rank()
-
-    def get_all_groups(self) -> list[BaseProcessGroup]:
-        raise NotImplementedError
-
-
-def ft_init_device_mesh(
-    *,
-    device_type: str,
-    mesh_shape: Union[tuple[int, ...], list[int]],
-    mesh_dim_names: Union[tuple[str, ...], list[str]],
-    replicate_dim: int,
-    manager: "Manager",
-) -> "ManagedDeviceMesh":
-    # We need to mislead DeviceMesh into thinking that replicate_dim has only
-    # 1 rank.
-    _mesh_shape = list(mesh_shape)
-    _mesh_shape.pop(replicate_dim)
-    _mesh_dim_names = list(mesh_dim_names)
-    _mesh_dim_names.pop(replicate_dim)
-    mesh = init_device_mesh(
-        device_type,
-        mesh_shape=tuple(_mesh_shape),
-        mesh_dim_names=tuple(_mesh_dim_names),
-    )
-
-    from torchft.process_group import ManagedProcessGroup
-
-    replicate_pg = ManagedProcessGroup(manager)
-    replicate_pg.register(mesh_dim_names[replicate_dim])
-
-    ManagedDeviceMesh.replicate_pg_singleton = replicate_pg
-
-    return ManagedDeviceMesh(
-        mesh=mesh,
-        mesh_dim_names=tuple(mesh_dim_names),
-        replicate_pg=replicate_pg,
-        replicate_dim=replicate_dim,
-        parent=None,
-    )
diff --git a/torchft/device_mesh_test.py b/torchft/device_mesh_test.py
deleted file mode 100644
index 757cab5e..00000000
--- a/torchft/device_mesh_test.py
+++ /dev/null
@@ -1,115 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import io
-import os
-from concurrent.futures import ProcessPoolExecutor
-from typing import cast
-from unittest import TestCase
-from unittest.mock import Mock
-
-import torch
-import torch.distributed as dist
-
-from torchft.manager import Manager
-from torchft.process_group import (
-    ft_init_device_mesh,
-    ManagedProcessGroup,
-    ProcessGroupGloo,
-)
-
-
-class DeviceMeshTest(TestCase):
-    @staticmethod
-    def _test_init_device_mesh(world_size: int, rank: int) -> None:
-        os.environ["MASTER_ADDR"] = "127.0.0.1"
-        os.environ["MASTER_PORT"] = str(12346)
-        os.environ["RANK"] = str(rank)
-        os.environ["WORLD_SIZE"] = str(4)
-
-        testcase = TestCase()
-
-        manager = Mock(spec=Manager)
-        manager._pg = ProcessGroupGloo()
-        # Even though we only have 4 workers, we can still initialize (2, 4) mesh.
-        # That's because the replicate group is NOT phystically created in the
-        # real mesh but is virtually added to the mesh via ManagedDeviceMesh.
-        device_mesh = ft_init_device_mesh(
-            device_type="cpu",
-            mesh_shape=(2, world_size),
-            mesh_dim_names=("dp_replicate", "dp_shard"),
-            replicate_dim=0,
-            manager=manager,
-        )
-
-        testcase.assertTrue(
-            isinstance(device_mesh.get_group("dp_replicate"), ManagedProcessGroup)
-        )
-        testcase.assertTrue(
-            not isinstance(device_mesh.get_group("dp_shard"), ManagedProcessGroup)
-        )
-        replicate_group = device_mesh.get_group("dp_replicate")
-        testcase.assertEqual(
-            cast(ManagedProcessGroup, replicate_group)._manager, manager
-        )
-        replicate_mesh = device_mesh["dp_replicate"]
-        testcase.assertEqual(replicate_mesh.get_group(), replicate_group)
-
-        flatten_mesh = device_mesh._flatten("dp")
-        manager.num_participants.return_value = 0
-        testcase.assertEqual(flatten_mesh.size(), world_size)
-        manager.num_participants.return_value = 1
-        testcase.assertEqual(flatten_mesh.size(), world_size)
-        manager.num_participants.return_value = 2
-        testcase.assertEqual(flatten_mesh.size(), world_size * 2)
-
-        testcase.assertEqual(flatten_mesh.get_local_rank(), dist.get_rank())
-
-        device_mesh.get_coordinate()
-        buffer = io.BytesIO()
-        torch.save(device_mesh, buffer)
-        buffer.seek(0)
-        torch.load(buffer, weights_only=False)
-
-    def test_init_device_mesh(self) -> None:
-        if dist.is_initialized():
-            dist.destroy_process_group()
-
-        with ProcessPoolExecutor(max_workers=4) as executor:
-            futures = []
-            for i in range(4):
-                future = executor.submit(self._test_init_device_mesh, 4, i)
-                futures.append(future)
-            for f in futures:
-                f.result()
-
-    def test_repr_hash(self) -> None:
-        if dist.is_initialized():
-            dist.destroy_process_group()
-
-        os.environ["MASTER_ADDR"] = "127.0.0.1"
-        os.environ["MASTER_PORT"] = str(12346)
-        os.environ["RANK"] = str(0)
-        os.environ["WORLD_SIZE"] = str(1)
-
-        manager = Mock(spec=Manager)
-        manager._pg = ProcessGroupGloo()
-
-        for container in [tuple, list]:
-            device_mesh = ft_init_device_mesh(
-                device_type="cpu",
-                mesh_shape=container((1, 1)),
-                mesh_dim_names=container((f"dp_replicate_{container}", "dp_shard")),
-                replicate_dim=0,
-                manager=manager,
-            )
-
-            self.assertIsInstance(repr(device_mesh), str)
-            self.assertIsInstance(str(device_mesh), str)
-            self.assertEqual(hash(device_mesh), hash(device_mesh))
-            self.assertIsInstance(hash(device_mesh), int)
-
-        dist.destroy_process_group()
diff --git a/torchft/diloco_regression_test.py b/torchft/diloco_regression_test.py
index 714f87ed..3d2236a5 100644
--- a/torchft/diloco_regression_test.py
+++ b/torchft/diloco_regression_test.py
@@ -18,7 +18,6 @@
 
 from torchft._test.diloco_trainer import DiLoCoTrainer, MultiModel
 from torchft._torchft import LighthouseServer
-from torchft.device_mesh import ft_init_device_mesh
 from torchft.local_sgd import DiLoCo
 from torchft.manager import Manager
 from torchft.manager_integ_test import EventInjector, EventInjectorEvent, Runner
diff --git a/torchft/fsdp_test.py b/torchft/fsdp_test.py
index f7a5c2ec..fd672f73 100644
--- a/torchft/fsdp_test.py
+++ b/torchft/fsdp_test.py
@@ -8,42 +8,18 @@
 import os
 import unittest
 from concurrent.futures import ProcessPoolExecutor
-from typing import Any, Dict, Tuple
 from unittest.mock import Mock
 
 import torch
 import torch.distributed as dist
 from torch import nn
-from torch._C._distributed_c10d import (
-    _resolve_process_group,
-    AllgatherOptions,
-    AllreduceOptions,
-    BroadcastOptions,
-    ReduceOp,
-)
-from torch.distributed import (
-    _functional_collectives,
-    get_world_size,
-    ReduceOp,
-    TCPStore,
-    Work,
-)
-from torch.distributed._composable.fsdp import fully_shard
-from torch.distributed.device_mesh import init_device_mesh
-from torch.distributed.tensor.parallel import (
-    ColwiseParallel,
-    parallelize_module,
-    PrepareModuleInput,
-    RowwiseParallel,
-    SequenceParallel,
-)
+from torch._C._distributed_c10d import ReduceOp
+from torch.distributed._composable.fsdp import FSDPModule, fully_shard
+from torch.distributed.tensor import init_device_mesh
+from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
 
 from torchft.manager import Manager
-from torchft.process_group import (
-    ft_init_device_mesh,
-    ManagedProcessGroup,
-    ProcessGroupGloo,
-)
+from torchft.process_group import ProcessGroupGloo
 
 
 class FSDPTest(unittest.TestCase):
@@ -67,19 +43,24 @@ def _test_fsdp(
         os.environ["WORLD_SIZE"] = str(group_size)
 
         manager = Mock(spec=Manager)
-        manager._pg = ProcessGroupGloo()
-        device_mesh = ft_init_device_mesh(
+        pg: ProcessGroupGloo = Mock(spec=ProcessGroupGloo)
+        device_mesh = init_device_mesh(
             device_type="cuda",
-            mesh_shape=(dp_replicate, dp_shard, tp),
-            mesh_dim_names=("dp_replicate", "dp_shard", "tp"),
-            replicate_dim=0,
-            manager=manager,
+            mesh_shape=(dp_shard, tp),
+            mesh_dim_names=("dp_shard", "tp"),
         )
         manager.num_participants.return_value = 1
         model = nn.Linear(128, 128).cuda()
         batch = torch.randn(4, 128).cuda()
 
-        fsdp_mesh = device_mesh["dp_replicate", "dp_shard"]
+        fsdp_mesh = device_mesh["dp_shard"]
+
+        def all_reduce_hook(output: torch.Tensor) -> None:
+            dist.all_reduce(output, group=pg, op=ReduceOp.AVG)
+
+        def apply_set_all_reduce_hook(m: nn.Module) -> None:
+            assert isinstance(m, FSDPModule)
+            m.set_all_reduce_hook(all_reduce_hook)
 
         if tp > 1:
             tp_mesh = device_mesh["tp"]
@@ -89,6 +70,7 @@ def _test_fsdp(
                 ColwiseParallel(),
             )
         shard_model = fully_shard(model, mesh=fsdp_mesh)
+        shard_model.apply(apply_set_all_reduce_hook)
         shard_model(batch).mean().backward()
 
         # pyre-ignore[56]: Pyre was not able to infer the type of argument
diff --git a/torchft/local_sgd_integ_test.py b/torchft/local_sgd_integ_test.py
index cef8d43b..25df3622 100644
--- a/torchft/local_sgd_integ_test.py
+++ b/torchft/local_sgd_integ_test.py
@@ -20,7 +20,6 @@
 
 from torchft._test.diloco_trainer import DiLoCoTrainer, MultiMyModel
 from torchft._torchft import LighthouseServer
-from torchft.device_mesh import ft_init_device_mesh
 from torchft.local_sgd import DiLoCo, LocalSGD
 from torchft.manager import Manager
 from torchft.manager_integ_test import (
diff --git a/torchft/manager.py b/torchft/manager.py
index f3aaff75..5e3e721f 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -446,7 +446,7 @@ def allreduce(
             # on the Future
             @torch.profiler.record_function("torchft::manager::allreduce::callback")
             def callback(
-                fut: torch.futures.Future[list[torch.Tensor]],
+                fut: torch.futures.Future[torch.Tensor],
             ) -> torch.Tensor:
                 nonlocal tensor
                 if reduce_op == ReduceOp.AVG:
@@ -455,6 +455,7 @@ def callback(
 
             managed_work = _ManagedWork(self, work, tensor)
             fut = managed_work.get_future()
+            fut = cast(torch.futures.Future[torch.Tensor], fut)
             fut = fut.then(callback)
 
             return managed_work
diff --git a/torchft/process_group.py b/torchft/process_group.py
index c462928e..eed40ec7 100644
--- a/torchft/process_group.py
+++ b/torchft/process_group.py
@@ -45,7 +45,6 @@
 
 # pyre-fixme[21]: no attribute ProcessGroupGloo
 from torch.distributed import (
-    DeviceMesh,
     PrefixStore,
     ProcessGroup as BaseProcessGroup,
     ProcessGroupGloo as BaseProcessGroupGloo,
@@ -67,7 +66,6 @@
 from torch.utils._pytree import tree_any
 
 # We import these for backwards compatibility
-from torchft.device_mesh import *  # noqa: F401
 from torchft.futures import context_timeout, stream_timeout
 from torchft.multiprocessing import _MonitoredPipe
 from torchft.utils import get_stream_context, record_event, synchronize
diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py
index 468e0816..2f323ad5 100644
--- a/torchft/process_group_test.py
+++ b/torchft/process_group_test.py
@@ -34,14 +34,10 @@
     ReduceOp,
     TCPStore,
 )
-from torch.distributed.device_mesh import init_device_mesh
-
 from torchft.manager import Manager
 from torchft.process_group import (
     _ErrorSwallowingWork,
     ErrorSwallowingProcessGroupWrapper,
-    extend_device_mesh,
-    ft_init_device_mesh,
     ManagedProcessGroup,
     ProcessGroup,
     ProcessGroupBabyGloo,
@@ -729,29 +725,6 @@ def test_dummy(self) -> None:
         m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
         m(torch.rand(2, 3))
 
-    def test_device_mesh(self) -> None:
-        os.environ["MASTER_ADDR"] = "localhost"
-        os.environ["MASTER_PORT"] = str(0)
-        os.environ["RANK"] = str(0)
-        os.environ["WORLD_SIZE"] = str(1)
-
-        mesh_1d = init_device_mesh("cpu", mesh_shape=(1,), mesh_dim_names=("tp",))
-
-        store = TCPStore(
-            host_name="localhost", port=0, is_master=True, wait_for_workers=False
-        )
-        store_addr = f"localhost:{store.port}/prefix"
-
-        pg = ProcessGroupGloo()
-        pg.register("test_device_mesh")
-        pg.configure(store_addr, "0", 0, 1)
-
-        mesh_2d = extend_device_mesh(mesh_1d, pg)
-        mesh_2d.get_group("dp")
-        assert mesh_2d.ndim == 2
-
-        pg.unregister()
-
     def test_functional_collectives(self) -> None:
         dummy_init_pg()
 
@@ -813,53 +786,6 @@ def test_managed_process_group(self) -> None:
 
         self.assertEqual(manager.allreduce.call_count, 2)
 
-
-class DeviceMeshTest(TestCase):
-    @staticmethod
-    def _test_init_device_mesh(world_size: int, rank: int) -> None:
-        os.environ["MASTER_ADDR"] = "127.0.0.1"
-        os.environ["MASTER_PORT"] = str(12346)
-        os.environ["RANK"] = str(rank)
-        os.environ["WORLD_SIZE"] = str(4)
-
-        testcase = TestCase()
-
-        manager = Mock(spec=Manager)
-        # Even though we only have 4 workers, we can still initialize (2, 4) mesh.
-        # That's because the replicate group is NOT phystically created in the
-        # real mesh but is virtually added to the mesh via ManagedDeviceMesh.
-        device_mesh = ft_init_device_mesh(
-            device_type="cpu",
-            mesh_shape=(2, world_size),
-            mesh_dim_names=("dp_replicate", "dp_shard"),
-            replicate_dim=0,
-            manager=manager,
-        )
-
-        testcase.assertTrue(
-            isinstance(device_mesh.get_group("dp_replicate"), ManagedProcessGroup)
-        )
-        testcase.assertTrue(
-            not isinstance(device_mesh.get_group("dp_shard"), ManagedProcessGroup)
-        )
-        replicate_group = device_mesh.get_group("dp_replicate")
-        testcase.assertEqual(
-            cast(ManagedProcessGroup, replicate_group)._manager, manager
-        )
-        replicate_mesh = device_mesh["dp_replicate"]
-        testcase.assertEqual(replicate_mesh.get_group(), replicate_group)
-        flatten_mesh = device_mesh._flatten("dp")
-        manager.num_participants.return_value = 1
-        testcase.assertEqual(flatten_mesh.size(), world_size)
-        testcase.assertEqual(flatten_mesh.get_local_rank(), dist.get_rank())
-
-    def test_init_device_mesh(self) -> None:
-        with ProcessPoolExecutor(max_workers=4) as executor:
-            futures = []
-            for i in range(4):
-                future = executor.submit(self._test_init_device_mesh, 4, i)
-                futures.append(future)
-
-
 class MultiPgBaseTest(TestCase):
     """
     A base test that creates N processes (via ThreadPoolExecutor) sharing