torchft/process_group.py: 103 changes (91 additions, 12 deletions)
@@ -4,26 +4,33 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC
import logging
from typing import Type, List, Optional, Callable, Tuple
from datetime import timedelta
import threading
from abc import ABC
from datetime import timedelta
from typing import Callable, List, Optional, Tuple, Type

from torch.futures import Future
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch._C._distributed_c10d import (
_register_process_group,
_unregister_process_group,
)
from torch.distributed import (
ProcessGroup as BaseProcessGroup,
Store,
TCPStore,
PrefixStore,
BroadcastOptions,
DeviceMesh,
get_rank,
PrefixStore,
ProcessGroup as BaseProcessGroup,
ProcessGroupGloo as BaseProcessGroupGloo,
ProcessGroupNCCL as BaseProcessGroupNCCL,
Store,
TCPStore,
)
import torch.distributed as dist
from torch.distributed.distributed_c10d import Work
import torch
import torch.multiprocessing as mp
from torch.distributed.distributed_c10d import _world, Work

from torch.futures import Future

logger = logging.getLogger(__name__)

@@ -62,6 +69,11 @@ def create_store(store_addr: str) -> Store:


class ProcessGroup(BaseProcessGroup):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

self._group_name = None

def configure(self, store_addr: str, rank: int, world_size: int) -> None:
raise NotImplementedError("not implemented")

@@ -90,6 +102,44 @@ def size(self) -> int:
def getBackendName(self) -> str:
raise NotImplementedError("not implemented")

def register(self, name: str) -> None:
"""
Registers the process group with the global registry. This enables
usage with APIs such as functional_collectives, which are compilable.

This should only be called once.

Args:
name: a unique name for this process group
"""

self._group_name = f"{self.getBackendName()}:{name}"
_register_process_group(self.group_name, self)

# This is needed for DeviceMesh to work
# Resizable worlds don't fit well into DeviceMesh so we register a world
# size 1 PG.
_world.pg_map[self] = (None, None)
_world.pg_names[self] = self._group_name
_world.pg_to_tag[self] = self._group_name
_world.tags_to_pg.setdefault(self._group_name, []).append(self)
# these PGs can be resized so we lie about the rank mapping
_world.pg_group_ranks[self] = {get_rank(): 0}

@property
def group_name(self) -> str:
if self._group_name is None:
raise ValueError("ProcessGroup name not set")
return self._group_name

def unregister(self) -> None:
"""
Unregisters the process group from the global registry.

Must be registered first.
"""
_unregister_process_group(self.group_name)
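
A minimal sketch of the intended flow (configure, then register, then collectives, then unregister), assuming a single-rank Gloo setup; the store address and the group name "example" are illustrative, and the same path is exercised by test_functional_collectives in the test file below.

import torch
from torch.distributed import TCPStore, _functional_collectives

from torchft.process_group import ProcessGroupGloo

# Single-rank store; port=0 lets the OS pick a free port.
store = TCPStore(host_name="localhost", port=0, is_master=True, wait_for_workers=False)

pg = ProcessGroupGloo()
pg.configure(f"localhost:{store.port}/prefix", rank=0, world_size=1)
pg.register("example")  # pg.group_name is now "torchft-gloo:example"

# Once registered, the PG can be used with compilable functional collectives.
t = torch.ones(8)
_functional_collectives.all_reduce(t, "sum", pg).wait()

pg.unregister()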


class ProcessGroupWrapper(ProcessGroup):
PG_CLASS: Type[BaseProcessGroup]
@@ -458,3 +508,32 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):

def getBackendName(self):
return "torchft-baby-nccl"


def extend_device_mesh(
mesh: DeviceMesh, pg: ProcessGroup, name: str = "dp", dim: int = 0
) -> DeviceMesh:
"""
Extends a traditional DeviceMesh with a torchft ProcessGroup so it can be used
with DeviceMesh-based APIs such as FSDPv2 with hybrid sharding.

Resizable PGs aren't natively supported by DeviceMesh so we lie to
DeviceMesh and say the PG is world size 1. This is fine as long as any
numeric scaling is handled at the PG level.

Args:
mesh: The DeviceMesh to extend
pg: The ProcessGroup to add to the mesh
name: The name of the new dimension
dim: The dimension to add the ProcessGroup to
"""
groups = mesh.get_all_groups()
groups.insert(dim, pg)
mesh_dim_names = list(mesh.mesh_dim_names)
mesh_dim_names.insert(dim, name)

return DeviceMesh.from_group(
group=groups,
device_type=mesh.device_type,
mesh=mesh.mesh.unsqueeze(dim),
mesh_dim_names=mesh_dim_names,
)
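
A minimal usage sketch for extend_device_mesh, mirroring test_device_mesh below; the single-rank environment variables, the store address, and the dimension name "dp" are illustrative, and the resulting 2-D mesh is what would be handed to a DeviceMesh-based API such as FSDPv2 with hybrid sharding.

import os

from torch.distributed import TCPStore
from torch.distributed.device_mesh import init_device_mesh

from torchft.process_group import ProcessGroupGloo, extend_device_mesh

# Single-rank defaults so init_device_mesh can bootstrap a world-size-1 mesh.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "0")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

mesh_1d = init_device_mesh("cpu", mesh_shape=(1,), mesh_dim_names=("tp",))

store = TCPStore(host_name="localhost", port=0, is_master=True, wait_for_workers=False)
pg = ProcessGroupGloo()
pg.register("dp_example")
pg.configure(f"localhost:{store.port}/prefix", 0, 1)

# Prepend the resizable torchft PG as a new "dp" dimension at dim 0.
mesh_2d = extend_device_mesh(mesh_1d, pg, name="dp", dim=0)
assert mesh_2d.ndim == 2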
torchft/process_group_test.py: 48 changes (48 additions, 0 deletions)
@@ -6,11 +6,17 @@

from unittest import TestCase, skipUnless
from concurrent.futures import ThreadPoolExecutor
import os

import torch
from torch.distributed import TCPStore, ReduceOp
import torch.distributed as dist
from torch import nn
from torch._C._distributed_c10d import (
_resolve_process_group,
)
from torch.distributed import _functional_collectives
from torch.distributed.device_mesh import init_device_mesh

from torchft.process_group import (
ProcessGroupBabyGloo,
@@ -19,6 +25,7 @@
ProcessGroupNCCL,
ProcessGroupDummy,
ProcessGroup,
extend_device_mesh,
)


@@ -140,3 +147,44 @@ def run(rank: int) -> None:
b_work.get_future().wait()

torch.testing.assert_close(at.cpu(), bt.cpu())

def test_device_mesh(self) -> None:
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(0)
os.environ["RANK"] = str(0)
os.environ["WORLD_SIZE"] = str(1)

mesh_1d = init_device_mesh("cpu", mesh_shape=(1,), mesh_dim_names=("tp",))

store = TCPStore(
host_name="localhost", port=0, is_master=True, wait_for_workers=False
)
store_addr = f"localhost:{store.port}/prefix"

pg = ProcessGroupGloo()
pg.register("test_device_mesh")
pg.configure(store_addr, 0, 1)

mesh_2d = extend_device_mesh(mesh_1d, pg)
assert mesh_2d.ndim == 2

def test_functional_collectives(self) -> None:
store = TCPStore(
host_name="localhost", port=0, is_master=True, wait_for_workers=False
)
store_addr = f"localhost:{store.port}/prefix"

pg = ProcessGroupGloo()
pg.configure(store_addr, 0, 1)

pg.register("test_func_col")

self.assertEqual(pg.group_name, "torchft-gloo:test_func_col")

self.assertIs(_resolve_process_group(pg.group_name), pg)

try:
t = torch.zeros(10)
_functional_collectives.all_reduce(t, "sum", pg).wait()
finally:
pg.unregister()