diff --git a/torchft/process_group.py b/torchft/process_group.py
index 6badcfa2..964ad885 100644
--- a/torchft/process_group.py
+++ b/torchft/process_group.py
@@ -4,26 +4,33 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from abc import ABC
 import logging
-from typing import Type, List, Optional, Callable, Tuple
-from datetime import timedelta
 import threading
+from abc import ABC
+from datetime import timedelta
+from typing import Callable, List, Optional, Tuple, Type
 
-from torch.futures import Future
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from torch._C._distributed_c10d import (
+    _register_process_group,
+    _unregister_process_group,
+)
 from torch.distributed import (
-    ProcessGroup as BaseProcessGroup,
-    Store,
-    TCPStore,
-    PrefixStore,
     BroadcastOptions,
+    DeviceMesh,
+    get_rank,
+    PrefixStore,
+    ProcessGroup as BaseProcessGroup,
     ProcessGroupGloo as BaseProcessGroupGloo,
     ProcessGroupNCCL as BaseProcessGroupNCCL,
+    Store,
+    TCPStore,
 )
-import torch.distributed as dist
-from torch.distributed.distributed_c10d import Work
-import torch
-import torch.multiprocessing as mp
+from torch.distributed.distributed_c10d import _world, Work
+
+from torch.futures import Future
 
 logger = logging.getLogger(__name__)
 
@@ -62,6 +69,11 @@ def create_store(store_addr: str) -> Store:
 
 
 class ProcessGroup(BaseProcessGroup):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+        self._group_name = None
+
     def configure(self, store_addr: str, rank: int, world_size: int) -> None:
         raise NotImplementedError("not implemented")
 
@@ -90,6 +102,44 @@ def size(self) -> int:
     def getBackendName(self) -> str:
         raise NotImplementedError("not implemented")
 
+    def register(self, name: str) -> None:
+        """
+        Registers the process group with the global registry. This enables usage
+        with things like functional_collectives which are compilable.
+
+        This should only be called once.
+
+        Args:
+            name: name must be a unique name for this process group
+        """
+
+        self._group_name = f"{self.getBackendName()}:{name}"
+        _register_process_group(self.group_name, self)
+
+        # This is needed for DeviceMesh to work
+        # Resizable worlds don't fit well into DeviceMesh so we register a world
+        # size 1 PG.
+        _world.pg_map[self] = (None, None)
+        _world.pg_names[self] = self._group_name
+        _world.pg_to_tag[self] = self._group_name
+        _world.tags_to_pg.setdefault(self._group_name, []).append(self)
+        # these PGs can be resized so we lie about the rank mapping
+        _world.pg_group_ranks[self] = {get_rank(): 0}
+
+    @property
+    def group_name(self) -> str:
+        if self._group_name is None:
+            raise ValueError("ProcessGroup name not set")
+        return self._group_name
+
+    def unregister(self) -> None:
+        """
+        Unregisters the process group with the global registry.
+
+        Must be registered first.
+        """
+        _unregister_process_group(self.group_name)
+
 
 class ProcessGroupWrapper(ProcessGroup):
     PG_CLASS: Type[BaseProcessGroup]
@@ -458,3 +508,32 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
 
     def getBackendName(self):
         return "torchft-baby-nccl"
+
+
+def extend_device_mesh(
+    mesh: DeviceMesh, pg: ProcessGroup, name: str = "dp", dim: int = 0
+) -> DeviceMesh:
+    """
+    This is a helper method to extend a traditional DeviceMesh with a torchft ProcessGroup for usage with DeviceMesh based APIs such as FSDPv2 with hybrid sharding.
+
+    Resizable PGs aren't natively supported by DeviceMesh so we lie to
+    DeviceMesh and say the PG is world size 1. This is fine as long as any
+    numeric scaling is handled at the PG level.
+
+    Args:
+        mesh: The DeviceMesh to extend
+        pg: The ProcessGroup to add to the mesh
+        name: The name of the new dimension
+        dim: The dimension to add the ProcessGroup to
+    """
+    groups = mesh.get_all_groups()
+    groups.insert(dim, pg)
+    mesh_dim_names = list(mesh.mesh_dim_names)
+    mesh_dim_names.insert(dim, name)
+
+    return DeviceMesh.from_group(
+        group=groups,
+        device_type=mesh.device_type,
+        mesh=mesh.mesh.unsqueeze(dim),
+        mesh_dim_names=mesh_dim_names,
+    )
diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py
index 98ad6aee..14328c75 100644
--- a/torchft/process_group_test.py
+++ b/torchft/process_group_test.py
@@ -6,11 +6,17 @@
 from unittest import TestCase, skipUnless
 from concurrent.futures import ThreadPoolExecutor
 
+import os
 import torch
 from torch.distributed import TCPStore, ReduceOp
 import torch.distributed as dist
 from torch import nn
+from torch._C._distributed_c10d import (
+    _resolve_process_group,
+)
+from torch.distributed import _functional_collectives
+from torch.distributed.device_mesh import init_device_mesh
 
 from torchft.process_group import (
     ProcessGroupBabyGloo,
@@ -19,6 +25,7 @@
     ProcessGroupNCCL,
     ProcessGroupDummy,
     ProcessGroup,
+    extend_device_mesh,
 )
 
 
@@ -140,3 +147,44 @@ def run(rank: int) -> None:
         b_work.get_future().wait()
 
         torch.testing.assert_close(at.cpu(), bt.cpu())
+
+    def test_device_mesh(self) -> None:
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = str(0)
+        os.environ["RANK"] = str(0)
+        os.environ["WORLD_SIZE"] = str(1)
+
+        mesh_1d = init_device_mesh("cpu", mesh_shape=(1,), mesh_dim_names=("tp",))
+
+        store = TCPStore(
+            host_name="localhost", port=0, is_master=True, wait_for_workers=False
+        )
+        store_addr = f"localhost:{store.port}/prefix"
+
+        pg = ProcessGroupGloo()
+        pg.register("test_device_mesh")
+        pg.configure(store_addr, 0, 1)
+
+        mesh_2d = extend_device_mesh(mesh_1d, pg)
+        assert mesh_2d.ndim == 2
+
+    def test_functional_collectives(self) -> None:
+        store = TCPStore(
+            host_name="localhost", port=0, is_master=True, wait_for_workers=False
+        )
+        store_addr = f"localhost:{store.port}/prefix"
+
+        pg = ProcessGroupGloo()
+        pg.configure(store_addr, 0, 1)
+
+        pg.register("test_func_col")
+
+        self.assertEqual(pg.group_name, "torchft-gloo:test_func_col")
+
+        self.assertIs(_resolve_process_group(pg.group_name), pg)
+
+        try:
+            t = torch.zeros(10)
+            _functional_collectives.all_reduce(t, "sum", pg).wait()
+        finally:
+            pg.unregister()
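
Usage sketch (not part of the patch above): a minimal single-rank walkthrough of how register(), unregister(), and extend_device_mesh() are intended to combine, mirroring the two tests in this diff. The group name "ft_dp", the localhost store, the env-var rendezvous, and the world-size-1 setup are illustrative placeholders, not part of the torchft API.

import os

import torch
from torch.distributed import TCPStore, _functional_collectives
from torch.distributed.device_mesh import init_device_mesh

from torchft.process_group import ProcessGroupGloo, extend_device_mesh

# Single-rank world, mirroring the tests; a real job uses its own rendezvous.
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(0)
os.environ["RANK"] = str(0)
os.environ["WORLD_SIZE"] = str(1)
mesh_tp = init_device_mesh("cpu", mesh_shape=(1,), mesh_dim_names=("tp",))

# Stand up the resizable torchft PG against its own store.
store = TCPStore(host_name="localhost", port=0, is_master=True, wait_for_workers=False)
pg = ProcessGroupGloo()
pg.configure(f"localhost:{store.port}/prefix", 0, 1)
pg.register("ft_dp")  # placeholder name; must be unique per process group

# The registered PG is now addressable by functional collectives.
t = torch.zeros(10)
_functional_collectives.all_reduce(t, "sum", pg).wait()

# It can also extend an existing DeviceMesh as a new "dp" dimension.
mesh_2d = extend_device_mesh(mesh_tp, pg, name="dp", dim=0)
assert mesh_2d.ndim == 2

pg.unregister()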