Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions torchft/device_mesh_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import io
import os
from concurrent.futures import ProcessPoolExecutor
from typing import cast
from unittest import TestCase
from unittest.mock import Mock

import torch
import torch.distributed as dist

from torchft.manager import Manager
from torchft.process_group import (
ManagedProcessGroup,
ProcessGroupGloo,
ft_init_device_mesh,
)


class DeviceMeshTest(TestCase):
    """Multi-process test of ``ft_init_device_mesh`` and the mesh it returns."""

    @staticmethod
    def _test_init_device_mesh(world_size: int, rank: int) -> None:
        """Build a managed device mesh on one worker and exercise its API.

        This is submitted to a ``ProcessPoolExecutor`` by
        ``test_init_device_mesh``, once per rank.

        Args:
            world_size: Number of workers; exported as ``WORLD_SIZE`` and used
                as the size of the ``dp_shard`` mesh dimension.
            rank: Rank of this worker; exported as ``RANK``.
        """
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = str(12346)
        os.environ["RANK"] = str(rank)
        # Keep the exported world size in sync with the argument rather than
        # hard-coding it, so the helper stays correct for other sizes.
        os.environ["WORLD_SIZE"] = str(world_size)

        # A static method has no ``self``; use a standalone TestCase purely
        # for its assertion helpers.
        testcase = TestCase()

        manager = Mock(spec=Manager)
        manager._pg = ProcessGroupGloo()
        # Even though we only have `world_size` workers, we can still
        # initialize a (2, world_size) mesh. That's because the replicate
        # group is NOT physically created in the real mesh but is virtually
        # added to the mesh via ManagedDeviceMesh.
        device_mesh = ft_init_device_mesh(
            device_type="cpu",
            mesh_shape=(2, world_size),
            mesh_dim_names=("dp_replicate", "dp_shard"),
            replicate_dim=0,
            manager=manager,
        )

        # Only the replicate dimension is backed by a ManagedProcessGroup.
        testcase.assertIsInstance(
            device_mesh.get_group("dp_replicate"), ManagedProcessGroup
        )
        testcase.assertNotIsInstance(
            device_mesh.get_group("dp_shard"), ManagedProcessGroup
        )
        replicate_group = device_mesh.get_group("dp_replicate")
        testcase.assertEqual(
            cast(ManagedProcessGroup, replicate_group)._manager, manager
        )
        replicate_mesh = device_mesh["dp_replicate"]
        testcase.assertEqual(replicate_mesh.get_group(), replicate_group)

        # The flattened size follows the mocked participant count; zero
        # participants is expected to be reported the same as one.
        flatten_mesh = device_mesh._flatten("dp")
        manager.num_participants.return_value = 0
        testcase.assertEqual(flatten_mesh.size(), world_size)
        manager.num_participants.return_value = 1
        testcase.assertEqual(flatten_mesh.size(), world_size)
        manager.num_participants.return_value = 2
        testcase.assertEqual(flatten_mesh.size(), world_size * 2)

        testcase.assertEqual(flatten_mesh.get_local_rank(), dist.get_rank())

        # Smoke checks: these only verify that the calls do not raise.
        device_mesh.get_coordinate()
        buffer = io.BytesIO()
        torch.save(device_mesh, buffer)
        buffer.seek(0)
        torch.load(buffer, weights_only=False)

    def test_init_device_mesh(self) -> None:
        """Run ``_test_init_device_mesh`` on four worker processes."""
        world_size = 4
        with ProcessPoolExecutor(max_workers=world_size) as executor:
            futures = [
                executor.submit(self._test_init_device_mesh, world_size, rank)
                for rank in range(world_size)
            ]
            # result() re-raises any exception raised inside a worker, which
            # is what turns a worker-side assertion failure into a test failure.
            for future in futures:
                future.result()
54 changes: 37 additions & 17 deletions torchft/process_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import queue
import threading
from datetime import timedelta
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -861,6 +861,8 @@ def extend_device_mesh(


class ManagedDeviceMesh(DeviceMesh):
replicate_pg_singleton: Optional["ManagedProcessGroup"] = None

def __init__(
self,
mesh: Optional[DeviceMesh],
Expand Down Expand Up @@ -889,6 +891,16 @@ def __init__(
self._flatten_mesh_list: Tuple[DeviceMesh, ...] = tuple()
self._thread_id: Optional[int] = None

def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
state["replicate_pg"] = None
return state

def __setstate__(self, state: Dict[str, Any]) -> None:
self.__dict__.update(state)
assert self.replicate_pg_singleton is not None
self.replicate_pg = self.replicate_pg_singleton

def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
if isinstance(mesh_dim_names, str):
if mesh_dim_names == self.replicate_dim_name:
Expand All @@ -906,13 +918,16 @@ def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh
return self.mesh[mesh_dim_names]
else:
assert isinstance(mesh_dim_names, tuple)
if self.replicate_dim_name in mesh_dim_names:
if self.replicate_dim_name not in mesh_dim_names:
assert self.mesh is not None
return self.mesh[mesh_dim_names]
else:
mesh_dim_names_wo_replicate = tuple(
n for n in mesh_dim_names if n != self.replicate_dim_name
)
assert self.mesh is not None
return ManagedDeviceMesh(
self.mesh[mesh_dim_names],
self.mesh[mesh_dim_names_wo_replicate],
mesh_dim_names,
self.replicate_pg,
mesh_dim_names.index(self.replicate_dim_name),
Expand Down Expand Up @@ -947,14 +962,18 @@ def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh":
return flatten_mesh

def size(self, mesh_dim: Optional[int] = None) -> int:
replicate_pg_size = self.replicate_pg.size()
# We have to lie to the users if there are zero particpants.
# This is possible during the initialization stage of training.
replicate_pg_size = 1 if replicate_pg_size == 0 else replicate_pg_size
if mesh_dim is None:
if self.mesh is None:
return self.replicate_pg.size()
return replicate_pg_size
else:
assert self.mesh is not None
return self.mesh.size() * self.replicate_pg.size()
return self.mesh.size() * replicate_pg_size
elif mesh_dim == self.replicate_dim:
return self.replicate_pg.size()
return replicate_pg_size
else:
assert self.mesh is not None
return self.mesh.size(self._real_mesh_dim(mesh_dim))
Expand Down Expand Up @@ -1004,7 +1023,16 @@ def get_coordinate(self) -> Optional[List[int]]:
dimensions of the mesh. If this rank is not part of the mesh, return None.
"""
assert self.mesh is not None
return self.mesh._coordinate_on_dim if self.mesh._coordinate_on_dim else None
coordinate = (
self.mesh._coordinate_on_dim if self.mesh._coordinate_on_dim else None
)
if not coordinate:
return coordinate

# We need to copy be cause we are going to modify the coordinate.
coordinate = coordinate.copy()
coordinate.insert(get_rank(self.replicate_pg), self.replicate_dim)
return coordinate

def get_all_groups(self) -> List[BaseProcessGroup]:
raise NotImplementedError
Expand Down Expand Up @@ -1066,19 +1094,11 @@ def ft_init_device_mesh(
mesh_dim_names=tuple(_mesh_dim_names),
)

if device_type == "cpu":
pg = ProcessGroupGloo()
elif device_type == "cuda":
pg = ProcessGroupNCCL()
else:
raise ValueError()

manager._pg = pg
replicate_pg = ManagedProcessGroup(manager)
# We have to use MultiProcessTestCase, otherwise c10d will complain
# the same backend has been registered.
replicate_pg.register(mesh_dim_names[replicate_dim])

ManagedDeviceMesh.replicate_pg_singleton = replicate_pg

return ManagedDeviceMesh(
mesh=mesh,
mesh_dim_names=mesh_dim_names,
Expand Down
48 changes: 1 addition & 47 deletions torchft/process_group_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import io
import multiprocessing
import os
import unittest
Expand Down Expand Up @@ -369,50 +370,3 @@ def test_managed_process_group(self) -> None:
self.assertEqual(manager.report_error.call_count, 0)
self.assertEqual(manager.wrap_future.call_count, 1)
self.assertEqual(manager.wait_quorum.call_count, 1)


class DeviceMeshTest(TestCase):
    """Multi-process test of ``ft_init_device_mesh`` and the mesh it returns."""

    @staticmethod
    def _test_init_device_mesh(world_size: int, rank: int) -> None:
        """Build a managed device mesh on one worker and exercise its API.

        This is submitted to a ``ProcessPoolExecutor`` by
        ``test_init_device_mesh``, once per rank.

        Args:
            world_size: Number of workers; exported as ``WORLD_SIZE`` and used
                as the size of the ``dp_shard`` mesh dimension.
            rank: Rank of this worker; exported as ``RANK``.
        """
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = str(12346)
        os.environ["RANK"] = str(rank)
        # Keep the exported world size in sync with the argument rather than
        # hard-coding it.
        os.environ["WORLD_SIZE"] = str(world_size)

        # A static method has no ``self``; use a standalone TestCase purely
        # for its assertion helpers.
        testcase = TestCase()

        manager = Mock(spec=Manager)
        # Even though we only have `world_size` workers, we can still
        # initialize a (2, world_size) mesh. That's because the replicate
        # group is NOT physically created in the real mesh but is virtually
        # added to the mesh via ManagedDeviceMesh.
        device_mesh = ft_init_device_mesh(
            device_type="cpu",
            mesh_shape=(2, world_size),
            mesh_dim_names=("dp_replicate", "dp_shard"),
            replicate_dim=0,
            manager=manager,
        )

        # Only the replicate dimension is backed by a ManagedProcessGroup.
        testcase.assertIsInstance(
            device_mesh.get_group("dp_replicate"), ManagedProcessGroup
        )
        testcase.assertNotIsInstance(
            device_mesh.get_group("dp_shard"), ManagedProcessGroup
        )
        replicate_group = device_mesh.get_group("dp_replicate")
        testcase.assertEqual(
            cast(ManagedProcessGroup, replicate_group)._manager, manager
        )
        replicate_mesh = device_mesh["dp_replicate"]
        testcase.assertEqual(replicate_mesh.get_group(), replicate_group)
        flatten_mesh = device_mesh._flatten("dp")
        manager.num_participants.return_value = 1
        testcase.assertEqual(flatten_mesh.size(), world_size)
        testcase.assertEqual(flatten_mesh.get_local_rank(), dist.get_rank())

    def test_init_device_mesh(self) -> None:
        """Run ``_test_init_device_mesh`` on four worker processes."""
        world_size = 4
        with ProcessPoolExecutor(max_workers=world_size) as executor:
            futures = [
                executor.submit(self._test_init_device_mesh, world_size, rank)
                for rank in range(world_size)
            ]
            # Without result(), an exception raised inside a worker stays
            # stored on its future and the test passes regardless.
            for future in futures:
                future.result()