diff --git a/torchft/manager_integ_test.py b/torchft/manager_integ_test.py
index 8a2799a4..d591d0d2 100644
--- a/torchft/manager_integ_test.py
+++ b/torchft/manager_integ_test.py
@@ -2,24 +2,26 @@
 import logging
 import threading
 import time
+import traceback
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from contextlib import ExitStack, contextmanager
 from dataclasses import dataclass, field
 from datetime import timedelta
-from typing import Any, Dict, Generator, List, Protocol, Set, Tuple
+from typing import Any, Dict, Generator, List, Optional, Protocol, Set, Tuple, TypeVar
 from unittest import TestCase
 
 import torch
 import torch.distributed as dist
 from parameterized import parameterized
 from torch import nn, optim
+from torch._dynamo.utils import timed
 
 from torchft._torchft import LighthouseServer
 from torchft.ddp import DistributedDataParallel
 from torchft.local_sgd import DiLoCo, LocalSGD
 from torchft.manager import Manager
 from torchft.optim import OptimizerWrapper
-from torchft.process_group import ProcessGroupGloo
+from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -69,10 +71,14 @@ def check(self, rank: int, step: int) -> None:
                 raise InjectedFailure(f"injected failure {rank=} {step=}")
 
 
-class TrainLoop(Protocol):
+# R for an arbitrary return type
+R = TypeVar("R", covariant=True)
+
+
+class TrainLoop(Protocol[R]):
     def __call__(
         self, rank: int, store_port: int, device: torch.device, runner: "Runner"
-    ) -> Dict[str, Dict[str, object]]: ...
+    ) -> R: ...
 
 
 @dataclass
@@ -81,7 +87,7 @@ class Runner:
     num_replicas: int
     lighthouse_address: str
     failure_injector: FailureInjector
-    train_loop: TrainLoop
+    train_loop: TrainLoop[object]
 
     use_cuda: bool = False
     world_size: int = 1
@@ -89,7 +95,7 @@ class Runner:
     manager_args: Dict[str, object] = field(default_factory=dict)
     train_loop_args: Dict[str, Any] = field(default_factory=dict)
 
-    def _replica_main(self) -> List[Dict[str, Dict[str, object]]]:
+    def _replica_main(self) -> List[object]:
         store = dist.TCPStore(
             host_name="localhost",
             port=0,
@@ -131,7 +137,7 @@ def _replica_main(self) -> List[Dict[str, Dict[str, object]]]:
 
             return [fut.result() for fut in futures]
 
-    def run_replica(self) -> List[Dict[str, Dict[str, object]]]:
+    def run_replica(self) -> List[object]:
         for i in range(self.attempts):
             try:
                 print(
@@ -391,3 +397,92 @@ def test_quorum_timeout(self) -> None:
             "status: Cancelled, message.*Timeout expired",
         ):
             manager.should_commit(timeout=timedelta(seconds=0.01))
+
+    @parameterized.expand(
+        [
+            (True,),  # Test with CUDA
+            (False,),  # Test without CUDA (CPU)
+        ]
+    )
+    def test_manager_allreduce(self, use_cuda: bool) -> None:
+        # Skip the test if use_cuda is True and there are not enough GPUs
+        if use_cuda and torch.cuda.device_count() < 2:
+            self.skipTest("Not enough GPUs for CUDA test")
+
+        # manager supports allreduce but we found an issue where the future callback is getting called
+        # before the allreduce is complete. This test is to ensure that the callback has stream synchronization
+        lighthouse = LighthouseServer(
+            bind="[::]:0",
+            min_replicas=2,
+        )
+        num_replicas = 2
+        futures = []
+
+        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
+            for replica_id in range(num_replicas):
+                failure_injector = FailureInjector()
+                runner = Runner(
+                    replica_id=replica_id,
+                    num_replicas=num_replicas,
+                    lighthouse_address=lighthouse.address(),
+                    failure_injector=failure_injector,
+                    train_loop=all_reduce_callback,
+                    use_cuda=use_cuda,
+                )
+                futures.append(executor.submit(runner.run_replica))
+
+            results = []
+            for fut in as_completed(futures):
+                try:
+                    results.append(fut.result()[0])
+                except Exception as e:
+                    print(e, flush=True)
+                    traceback.print_exc()
+                    raise
+
+        lighthouse.shutdown()
+
+        print(results)
+        r0, r1 = results
+        torch.testing.assert_close(r0, r1, check_device=False)
+
+
+def all_reduce_callback(
+    rank: int,
+    store_port: int,
+    device: torch.device,
+    runner: Runner,
+) -> Optional[torch.Tensor]:
+    with ExitStack() as stack:
+        print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")
+
+        if device.type == "cuda":
+            pg = ProcessGroupBabyNCCL()
+        else:
+            pg = ProcessGroupGloo()
+        manager = Manager(
+            pg=pg,
+            min_replica_size=2,
+            use_async_quorum=False,
+            load_state_dict=lambda x: None,
+            state_dict=lambda: None,
+            replica_id=str(runner.replica_id),
+            store_addr="localhost",
+            store_port=store_port,
+            rank=rank,
+            world_size=runner.world_size,
+            lighthouse_addr=runner.lighthouse_address,
+            port=19530 + runner.replica_id,
+            timeout=timedelta(seconds=10),
+            quorum_timeout=timedelta(seconds=10),
+            # pyre-fixme[6]: Incompatible parameter type
+            **runner.manager_args,
+        )
+        stack.callback(lambda: manager.shutdown(wait=False))
+
+        manager.start_quorum()
+        t1 = torch.ones((1, 3), device=device)
+        fut = manager.allreduce(t1)
+        fut.wait()
+        return t1
+    return None
diff --git a/torchft/process_group.py b/torchft/process_group.py
index 2d008deb..00a0c4a1 100644
--- a/torchft/process_group.py
+++ b/torchft/process_group.py
@@ -1093,10 +1093,12 @@ def _worker(
 
                         args = _PickleSafeOptions.unsafe_args(args)
                         fn = getattr(pg, func_name)
+
                         work[op_id] = _OpMetadata(
                             work=fn(*args, **kwargs),
                             stream=stream,
                         )
+
                 elif cmd == "wait":
                     op_id, timeout = cast(tuple[int, timedelta], op[1:])
 
@@ -1126,15 +1128,29 @@ def _worker(
                         del work[op_id]
                 elif cmd == "future":
                     op_id: int = cast(int, op[1])
+                    metadata: _OpMetadata = work[op_id]
 
-                    def callback(fut: Future[object]) -> None:
+                    def callback(fut: Future[object], metadata: _OpMetadata) -> None:
                         try:
-                            fut.wait()
-                            future_pipe.send((op_id, _FUTURE_RESULT, None))
+                            # create an event after the collective has been issued
+                            # to wait on this before we call "future"
+                            with metadata.set_stream():
+                                fut.wait()
+                                event = (
+                                    torch.cuda.current_stream().record_event(
+                                        torch.cuda.Event(interprocess=True)
+                                    )
+                                    if metadata.stream is not None
+                                    else None
+                                )
+
+                            future_pipe.send((op_id, _FUTURE_RESULT, None, event))
                         except Exception as e:
-                            future_pipe.send((op_id, _FUTURE_EXCEPTION, e))
+                            future_pipe.send((op_id, _FUTURE_EXCEPTION, e, None))
 
-                    work[op_id].work.get_future().add_done_callback(callback)
+                    metadata.work.get_future().add_done_callback(
+                        lambda fut: callback(fut, metadata)
+                    )
                 elif cmd == "num_active_work":
                     req_pipe.send(len(work))
                 else:
@@ -1153,11 +1169,15 @@ def _future_handler(self, future_pipe: _MonitoredPipe) -> None:
                 except TimeoutError:
                     continue
 
-                op_id, mode, data = cast(Tuple[int, str, object], cmd)
+                op_id, mode, data, event = cast(
+                    Tuple[int, str, object, Optional[torch.cuda.Event]], cmd
+                )
                 with self._futures_lock:
                     fut = self._futures[op_id]
                     del self._futures[op_id]
                 if mode == _FUTURE_RESULT:
+                    if event is not None:
+                        event.wait()
                     fut.set_result(data)
                 elif mode == _FUTURE_EXCEPTION:
                     fut.set_exception(data)
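Note on the process_group.py half of this change: the worker-side callback now records a CUDA event on the collective's stream once the work's future completes, sends that event through the future pipe, and _future_handler waits on it before resolving the parent-side future, so the result is only observed after the collective has actually finished on the GPU. The snippet below is a minimal, single-process sketch of that stream/event ordering pattern; it is illustrative only, not torchft code, and the helper names are made up.

# Illustrative sketch only: record an event on the producing stream, and make
# the consumer wait on that event before it touches the result.
import torch


def produce_on_side_stream(x: torch.Tensor) -> tuple[torch.Tensor, torch.cuda.Event]:
    side = torch.cuda.Stream()
    # Order the side stream after whatever produced `x` on the current stream.
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        y = x @ x  # async kernel launch; returning here does not mean `y` is ready
        # Record an event on the side stream *after* the kernel was issued,
        # analogous to record_event() in the worker callback above.
        event = torch.cuda.current_stream().record_event(torch.cuda.Event())
    return y, event


def consume(y: torch.Tensor, event: torch.cuda.Event) -> torch.Tensor:
    # Make the consumer's current stream wait for the event before reading `y`,
    # analogous to event.wait() in _future_handler before fut.set_result().
    event.wait()
    return y.sum()


if torch.cuda.is_available():
    x = torch.randn(2048, 2048, device="cuda")
    y, ev = produce_on_side_stream(x)
    print(consume(y, ev).item())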