diff --git a/torchft/device_mesh_test.py b/torchft/device_mesh_test.py
new file mode 100644
index 00000000..ee78c6d5
--- /dev/null
+++ b/torchft/device_mesh_test.py
@@ -0,0 +1,84 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import io
+import os
+from concurrent.futures import ProcessPoolExecutor
+from typing import cast
+from unittest import TestCase
+from unittest.mock import Mock
+
+import torch
+import torch.distributed as dist
+
+from torchft.manager import Manager
+from torchft.process_group import (
+    ManagedProcessGroup,
+    ProcessGroupGloo,
+    ft_init_device_mesh,
+)
+
+
+class DeviceMeshTest(TestCase):
+    @staticmethod
+    def _test_init_device_mesh(world_size: int, rank: int) -> None:
+        os.environ["MASTER_ADDR"] = "127.0.0.1"
+        os.environ["MASTER_PORT"] = str(12346)
+        os.environ["RANK"] = str(rank)
+        os.environ["WORLD_SIZE"] = str(4)
+
+        testcase = TestCase()
+
+        manager = Mock(spec=Manager)
+        manager._pg = ProcessGroupGloo()
+        # Even though we only have 4 workers, we can still initialize a (2, 4) mesh.
+        # That's because the replicate group is NOT physically created in the
+        # real mesh but is virtually added to the mesh via ManagedDeviceMesh.
+        device_mesh = ft_init_device_mesh(
+            device_type="cpu",
+            mesh_shape=(2, world_size),
+            mesh_dim_names=("dp_replicate", "dp_shard"),
+            replicate_dim=0,
+            manager=manager,
+        )
+
+        testcase.assertTrue(
+            isinstance(device_mesh.get_group("dp_replicate"), ManagedProcessGroup)
+        )
+        testcase.assertTrue(
+            not isinstance(device_mesh.get_group("dp_shard"), ManagedProcessGroup)
+        )
+        replicate_group = device_mesh.get_group("dp_replicate")
+        testcase.assertEqual(
+            cast(ManagedProcessGroup, replicate_group)._manager, manager
+        )
+        replicate_mesh = device_mesh["dp_replicate"]
+        testcase.assertEqual(replicate_mesh.get_group(), replicate_group)
+
+        flatten_mesh = device_mesh._flatten("dp")
+        manager.num_participants.return_value = 0
+        testcase.assertEqual(flatten_mesh.size(), world_size)
+        manager.num_participants.return_value = 1
+        testcase.assertEqual(flatten_mesh.size(), world_size)
+        manager.num_participants.return_value = 2
+        testcase.assertEqual(flatten_mesh.size(), world_size * 2)
+
+        testcase.assertEqual(flatten_mesh.get_local_rank(), dist.get_rank())
+
+        device_mesh.get_coordinate()
+        buffer = io.BytesIO()
+        torch.save(device_mesh, buffer)
+        buffer.seek(0)
+        torch.load(buffer, weights_only=False)
+
+    def test_init_device_mesh(self) -> None:
+        with ProcessPoolExecutor(max_workers=4) as executor:
+            futures = []
+            for i in range(4):
+                future = executor.submit(self._test_init_device_mesh, 4, i)
+                futures.append(future)
+            for f in futures:
+                f.result()
diff --git a/torchft/process_group.py b/torchft/process_group.py
index d1d2cbe6..7f1b7e34 100644
--- a/torchft/process_group.py
+++ b/torchft/process_group.py
@@ -20,7 +20,7 @@
 import queue
 import threading
 from datetime import timedelta
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
 
 import torch
 import torch.distributed as dist
@@ -861,6 +861,8 @@
 
 
 class ManagedDeviceMesh(DeviceMesh):
+    replicate_pg_singleton: Optional["ManagedProcessGroup"] = None
+
     def __init__(
         self,
         mesh: Optional[DeviceMesh],
@@ -889,6 +891,16 @@ def __init__(
         self._flatten_mesh_list: Tuple[DeviceMesh, ...] = tuple()
         self._thread_id: Optional[int] = None
 
+    def __getstate__(self) -> Dict[str, Any]:
+        state = self.__dict__.copy()
+        state["replicate_pg"] = None
+        return state
+
+    def __setstate__(self, state: Dict[str, Any]) -> None:
+        self.__dict__.update(state)
+        assert self.replicate_pg_singleton is not None
+        self.replicate_pg = self.replicate_pg_singleton
+
     def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
         if isinstance(mesh_dim_names, str):
             if mesh_dim_names == self.replicate_dim_name:
@@ -906,13 +918,16 @@ def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh
             return self.mesh[mesh_dim_names]
         else:
             assert isinstance(mesh_dim_names, tuple)
-            if self.replicate_dim_name in mesh_dim_names:
+            if self.replicate_dim_name not in mesh_dim_names:
                 assert self.mesh is not None
                 return self.mesh[mesh_dim_names]
             else:
+                mesh_dim_names_wo_replicate = tuple(
+                    n for n in mesh_dim_names if n != self.replicate_dim_name
+                )
                 assert self.mesh is not None
                 return ManagedDeviceMesh(
-                    self.mesh[mesh_dim_names],
+                    self.mesh[mesh_dim_names_wo_replicate],
                     mesh_dim_names,
                     self.replicate_pg,
                     mesh_dim_names.index(self.replicate_dim_name),
@@ -947,14 +962,18 @@ def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh":
         return flatten_mesh
 
     def size(self, mesh_dim: Optional[int] = None) -> int:
+        replicate_pg_size = self.replicate_pg.size()
+        # We have to lie to the users if there are zero participants.
+        # This is possible during the initialization stage of training.
+        replicate_pg_size = 1 if replicate_pg_size == 0 else replicate_pg_size
         if mesh_dim is None:
             if self.mesh is None:
-                return self.replicate_pg.size()
+                return replicate_pg_size
             else:
                 assert self.mesh is not None
-                return self.mesh.size() * self.replicate_pg.size()
+                return self.mesh.size() * replicate_pg_size
         elif mesh_dim == self.replicate_dim:
-            return self.replicate_pg.size()
+            return replicate_pg_size
         else:
             assert self.mesh is not None
             return self.mesh.size(self._real_mesh_dim(mesh_dim))
@@ -1004,7 +1023,16 @@ def get_coordinate(self) -> Optional[List[int]]:
         dimensions of the mesh. If this rank is not part of the mesh, return None.
         """
         assert self.mesh is not None
-        return self.mesh._coordinate_on_dim if self.mesh._coordinate_on_dim else None
+        coordinate = (
+            self.mesh._coordinate_on_dim if self.mesh._coordinate_on_dim else None
+        )
+        if not coordinate:
+            return coordinate
+
+        # We need to copy because we are going to modify the coordinate.
+        coordinate = coordinate.copy()
+        coordinate.insert(self.replicate_dim, get_rank(self.replicate_pg))
+        return coordinate
 
     def get_all_groups(self) -> List[BaseProcessGroup]:
         raise NotImplementedError
@@ -1066,19 +1094,11 @@ def ft_init_device_mesh(
         mesh_dim_names=tuple(_mesh_dim_names),
     )
 
-    if device_type == "cpu":
-        pg = ProcessGroupGloo()
-    elif device_type == "cuda":
-        pg = ProcessGroupNCCL()
-    else:
-        raise ValueError()
-
-    manager._pg = pg
     replicate_pg = ManagedProcessGroup(manager)
-    # We have to use MultiProcessTestCase, otherwise c10d will complain
-    # the same backend has been registered.
     replicate_pg.register(mesh_dim_names[replicate_dim])
 
+    ManagedDeviceMesh.replicate_pg_singleton = replicate_pg
+
     return ManagedDeviceMesh(
         mesh=mesh,
         mesh_dim_names=mesh_dim_names,
diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py
index d24f838f..75a3e537 100644
--- a/torchft/process_group_test.py
+++ b/torchft/process_group_test.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import io
 import multiprocessing
 import os
 import unittest
@@ -369,50 +370,3 @@ def test_managed_process_group(self) -> None:
         self.assertEqual(manager.report_error.call_count, 0)
         self.assertEqual(manager.wrap_future.call_count, 1)
         self.assertEqual(manager.wait_quorum.call_count, 1)
-
-
-class DeviceMeshTest(TestCase):
-    @staticmethod
-    def _test_init_device_mesh(world_size: int, rank: int) -> None:
-        os.environ["MASTER_ADDR"] = "127.0.0.1"
-        os.environ["MASTER_PORT"] = str(12346)
-        os.environ["RANK"] = str(rank)
-        os.environ["WORLD_SIZE"] = str(4)
-
-        testcase = TestCase()
-
-        manager = Mock(spec=Manager)
-        # Even though we only have 4 workers, we can still initialize (2, 4) mesh.
-        # That's because the replicate group is NOT phystically created in the
-        # real mesh but is virtually added to the mesh via ManagedDeviceMesh.
-        device_mesh = ft_init_device_mesh(
-            device_type="cpu",
-            mesh_shape=(2, world_size),
-            mesh_dim_names=("dp_replicate", "dp_shard"),
-            replicate_dim=0,
-            manager=manager,
-        )
-
-        testcase.assertTrue(
-            isinstance(device_mesh.get_group("dp_replicate"), ManagedProcessGroup)
-        )
-        testcase.assertTrue(
-            not isinstance(device_mesh.get_group("dp_shard"), ManagedProcessGroup)
-        )
-        replicate_group = device_mesh.get_group("dp_replicate")
-        testcase.assertEqual(
-            cast(ManagedProcessGroup, replicate_group)._manager, manager
-        )
-        replicate_mesh = device_mesh["dp_replicate"]
-        testcase.assertEqual(replicate_mesh.get_group(), replicate_group)
-        flatten_mesh = device_mesh._flatten("dp")
-        manager.num_participants.return_value = 1
-        testcase.assertEqual(flatten_mesh.size(), world_size)
-        testcase.assertEqual(flatten_mesh.get_local_rank(), dist.get_rank())
-
-    def test_init_device_mesh(self) -> None:
-        with ProcessPoolExecutor(max_workers=4) as executor:
-            futures = []
-            for i in range(4):
-                future = executor.submit(self._test_init_device_mesh, 4, i)
-                futures.append(future)
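
Note: the __getstate__/__setstate__ pair above follows the standard pattern for pickling objects that hold an unpicklable handle: drop the handle when serializing and reattach it from a per-process, class-level singleton when deserializing. A minimal standalone sketch of that pattern (the Resource and Mesh names here are hypothetical illustrations, not part of the patch):

import io
import pickle
from typing import Any, Dict, Optional


class Resource:
    """Stand-in for an unpicklable object such as a live process group."""


class Mesh:
    # Per-process singleton, set once before any unpickling happens.
    resource_singleton: Optional[Resource] = None

    def __init__(self, name: str, resource: Resource) -> None:
        self.name = name
        self.resource = resource

    def __getstate__(self) -> Dict[str, Any]:
        # Copy the instance dict and drop the unpicklable member.
        state = self.__dict__.copy()
        state["resource"] = None
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        # Restore the plain fields, then reattach the live handle.
        self.__dict__.update(state)
        assert Mesh.resource_singleton is not None
        self.resource = Mesh.resource_singleton


resource = Resource()
Mesh.resource_singleton = resource
buffer = io.BytesIO()
pickle.dump(Mesh("dp", resource), buffer)
buffer.seek(0)
restored = pickle.load(buffer)
assert restored.name == "dp" and restored.resource is resource

This mirrors why ft_init_device_mesh stores the replicate group in ManagedDeviceMesh.replicate_pg_singleton: the live group must already exist in the destination process, so unpickling only relinks to it.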
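
Note: get_coordinate() composes the full mesh coordinate by inserting this rank's replicate-group rank at index replicate_dim, since the replicate dimension does not exist in the underlying real mesh. A tiny worked check of that assumed semantics (standalone sketch, values are illustrative only):

# A (2, 4) mesh with dims ("dp_replicate", "dp_shard") and replicate_dim=0.
# The real mesh only covers "dp_shard", so its coordinate has one entry.
shard_coordinate = [3]  # this rank is shard index 3 in the real mesh
replicate_rank = 1      # this rank's rank inside the replicate process group

coordinate = shard_coordinate.copy()
coordinate.insert(0, replicate_rank)  # index 0 == replicate_dim
assert coordinate == [1, 3]  # (dp_replicate, dp_shard)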