diff --git a/torchft/manager.py b/torchft/manager.py
index 174418d7..8a7d056c 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -197,6 +197,9 @@ def step(self) -> None:
         if not self._use_async_quorum:
             self._quorum_future.result()
 
+            # eagerly apply pending state_dict so we can run the forwards pass
+            self._apply_pending_state_dict()
+
             # we are forcing healing at the beginning so we're in a good state
             # and don't need to zero_grad
             self._healing = False
@@ -236,14 +239,27 @@ def _async_quorum(self) -> None:
             primary_client = ManagerClient(address, timeout=self._timeout)
             checkpoint_server_address = primary_client.checkpoint_address(self._rank)
 
-            state_dict = CheckpointServer.load_from_address(checkpoint_server_address)
-            self._load_state_dict(state_dict["user"])
-            self.load_state_dict(state_dict["torchft"])
+            self._state_dict = CheckpointServer.load_from_address(
+                checkpoint_server_address
+            )
+            self.load_state_dict(self._state_dict["torchft"])
+            # we apply the user state dict only when safe from the main thread
 
         # This isn't strictly needed as loading the state_dict above should
         # restore the correct step but it makes writing tests simpler.
         self._step = max_step
 
+    def _apply_pending_state_dict(self) -> None:
+        assert self._healing, "must be in healing state"
+
+        # synchronize on future
+        self._quorum_future.result()
+
+        assert self._state_dict is not None, "checkpoint was not staged"
+
+        self._load_state_dict(self._state_dict["user"])
+        self._state_dict = None
+
     def should_commit(self) -> bool:
         for work in self._pending_work:
             # check at the beginning of since .wait() may trigger errors
@@ -256,6 +272,10 @@ def should_commit(self) -> bool:
 
         self._pending_work = []
 
+        # apply state_dict if healing
+        if self._healing:
+            self._apply_pending_state_dict()
+
         enough_replicas = self._participating_replicas >= self._min_replica_size
         local_should_commit = enough_replicas and not self._errored
         should_commit = self._client.should_commit(
diff --git a/torchft/process_group.py b/torchft/process_group.py
index 494c687f..6badcfa2 100644
--- a/torchft/process_group.py
+++ b/torchft/process_group.py
@@ -8,6 +8,7 @@
 import logging
 from typing import Type, List, Optional, Callable, Tuple
 from datetime import timedelta
+import threading
 
 from torch.futures import Future
 from torch.distributed import (
@@ -26,6 +27,11 @@
 
 logger = logging.getLogger(__name__)
 
+# TODO: use non strings which are cheaper
+_QUEUE_CLOSE = "queue_close"
+_FUTURE_RESULT = "fut_result"
+_FUTURE_EXCEPTION = "fut_exception"
+
 
 def _get(queue: mp.Queue, timeout) -> object:
     v = queue.get(timeout=timeout)
@@ -208,9 +214,17 @@ def getBackendName(self):
 
 
 class BabyWork(Work):
-    def __init__(self, tx: mp.Queue, rx: mp.Queue, op_id: int, timeout: float):
+    def __init__(
+        self,
+        pg: "ProcessGroupBaby",
+        tx: mp.Queue,
+        rx: mp.Queue,
+        op_id: int,
+        timeout: float,
+    ):
         super().__init__()
 
+        self._pg = pg
         self._tx = tx
         self._rx = rx
         self._op_id = op_id
@@ -221,6 +235,9 @@ def wait(self) -> bool:
         assert _get(self._rx, self._timeout) == self._op_id
         return True
 
+    def get_future(self) -> Future:
+        return self._pg._get_future(self._op_id)
+
 
 class BabyWorkNCCL(BabyWork):
     def wait(self) -> bool:
@@ -255,6 +272,8 @@ def __init__(self, timeout: float = 60.0) -> None:
         self._p = None
         self._tx = None
         self._rx = None
+        self._future_queue = None
+        self._future_thread = None
 
         self._timeout = timeout
 
@@ -264,20 +283,46 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
         self._world_size = world_size
 
+        if self._tx is not None:
+            self._tx.close()
+        if self._rx is not None:
+            self._rx.close()
+        if self._future_queue is not None:
+            self._future_queue.put(_QUEUE_CLOSE)
+            self._future_queue.close()
+
         ctx = mp.get_context("spawn")
         self._tx = ctx.Queue()
         self._rx = ctx.Queue()
+        # futures need thread to fire callbacks
+        self._future_queue = ctx.Queue()
+        # this lock needs to be held when manipulating _futures
+        self._futures_lock = threading.Lock()
+        self._futures = {}
+        self._future_thread = threading.Thread(
+            target=self._future_handler,
+            args=(self._future_queue,),
+            daemon=True,
+        )
+        self._future_thread.start()
+
         self._p = ctx.Process(
             target=self._worker,
-            args=(store_addr, rank, world_size, self._tx, self._rx),
+            args=(store_addr, rank, world_size, self._tx, self._rx, self._future_queue),
             daemon=True,
         )
         self._p.start()
 
     @classmethod
     def _worker(
-        cls, store_addr: str, rank: int, world_size: int, rx: mp.Queue, tx: mp.Queue
+        cls,
+        store_addr: str,
+        rank: int,
+        world_size: int,
+        rx: mp.Queue,
+        tx: mp.Queue,
+        future_queue: mp.Queue,
     ) -> None:
         try:
             store = create_store(store_addr)
@@ -291,8 +336,9 @@ def _worker(
                 op = rx.get()
                 cmd = op[0]
                 if cmd == "func":
-                    func, args, kwargs = op[1:]
-                    work[next_op_id] = getattr(pg, func)(*args, **kwargs)
+                    func_name, args, kwargs = op[1:]
+                    fn = getattr(pg, func_name)
+                    work[next_op_id] = fn(*args, **kwargs)
                     tx.put(next_op_id)
                     next_op_id += 1
                 elif cmd == "wait":
@@ -300,6 +346,18 @@
                     op_id = op[1]
                     work[op_id].wait()
                     del work[op_id]
                     tx.put(op_id)
+                elif cmd == "future":
+                    op_id = op[1]
+
+                    def callback(fut: Future):
+                        try:
+                            fut.wait()
+                            future_queue.put((op_id, _FUTURE_RESULT, None))
+                        except Exception as e:
+                            future_queue.put((op_id, _FUTURE_EXCEPTION, e))
+
+                    work[op_id].get_future().add_done_callback(callback)
+                    tx.put(op_id)
                 elif cmd == "synchronize":
                     # CUDA only, use events instead of waiting on CPU
                     op_id = op[1]
@@ -322,12 +380,41 @@
             logger.exception("worker errored")
             tx.put(e)
 
+    def _future_handler(self, future_queue: mp.Queue) -> None:
+        try:
+            while True:
+                cmd = future_queue.get()
+                if cmd == _QUEUE_CLOSE:
+                    break
+                op_id, mode, data = cmd
+                with self._futures_lock:
+                    fut = self._futures[op_id]
+                    del self._futures[op_id]
+                if mode == _FUTURE_RESULT:
+                    fut.set_result(data)
+                elif mode == _FUTURE_EXCEPTION:
+                    fut.set_exception(data)
+                else:
+                    raise ValueError(f"unknown mode {mode}")
+        except Exception as e:
+            logger.exception(f"got unexpected error in future handler: {e}")
+
+    def _get_future(self, op_id: int) -> Future:
+        with self._futures_lock:
+            fut = Future()
+            self._futures[op_id] = fut
+            self._tx.put(("future", op_id), timeout=self._timeout)
+
+        assert _get(self._rx, self._timeout) == op_id
+        # TODO: return correct tensor instead of None
+        return fut
+
     def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
         self._tx.put(("func", func, args, kwargs), timeout=self._timeout)
         op_id = _get(self._rx, self._timeout)
         assert isinstance(op_id, int), f"invalid return {op_id}"
         return self.WORK_CLASS(
-            tx=self._tx, rx=self._rx, op_id=op_id, timeout=self._timeout
+            pg=self, tx=self._tx, rx=self._rx, op_id=op_id, timeout=self._timeout
         )
 
     def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
@@ -366,7 +453,7 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
     tensors may leak in the current PyTorch implementation. TODO fix
     """
 
-    PG_CLASS = BaseProcessGroupGloo
+    PG_CLASS = BaseProcessGroupNCCL
     WORK_CLASS = BabyWorkNCCL
 
     def getBackendName(self):
diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py
index f741b7f5..5bed565f 100644
--- a/torchft/process_group_test.py
+++ b/torchft/process_group_test.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from unittest import TestCase, skipUnless
+from concurrent.futures import ThreadPoolExecutor
 
 import torch
 from torch.distributed import TCPStore, ReduceOp
@@ -37,6 +38,7 @@ def test_gloo(self) -> None:
 
         a_work = pg.allreduce([at], ReduceOp.SUM)
         a_work.wait()
+        a_work.get_future().wait()
 
         m = nn.Linear(3, 4)
         m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
@@ -58,6 +60,7 @@ def test_nccl(self) -> None:
         at = torch.tensor([2], device=device)
         a_work = pg.allreduce([at], ReduceOp.SUM)
         a_work.wait()
+        a_work.get_future().wait()
 
         m = nn.Linear(3, 4).to(device)
         m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
@@ -95,7 +98,9 @@ def test_baby_gloo(self) -> None:
         b_work = b.allreduce([bt], ReduceOp.SUM)
 
         a_work.wait()
-        b_work.wait()
+        fut = b_work.get_future()
+
+        fut.wait()
 
         torch.testing.assert_close(at, bt)
@@ -113,23 +118,25 @@ def test_baby_nccl(self) -> None:
 
         store_addr = f"localhost:{store.port}/prefix"
 
-        device = "cuda"
+        def run(rank: int):
+            a = ProcessGroupBabyNCCL()
+            a.configure(store_addr, rank, 2)
 
-        a = ProcessGroupBabyNCCL()
-        b = ProcessGroupBabyNCCL()
+            self.assertEqual(a.size(), 2)
 
-        a.configure(store_addr, 0, 2)
-        b.configure(store_addr, 1, 2)
+            at = torch.tensor([rank + 1], device=f"cuda:{rank}")
 
-        self.assertEqual(a.size(), 2)
+            a_work = a.allreduce([at], ReduceOp.SUM)
+            return at, a_work
 
-        at = torch.tensor([1], device=device)
-        bt = torch.tensor([2], device=device)
+        with ThreadPoolExecutor(max_workers=2) as executor:
+            a_fut = executor.submit(run, 0)
+            b_fut = executor.submit(run, 1)
 
-        a_work = a.allreduce([at], ReduceOp.SUM)
-        b_work = b.allreduce([bt], ReduceOp.SUM)
+        at, a_work = a_fut.result()
+        bt, b_work = b_fut.result()
 
         a_work.wait()
-        b_work.wait()
+        b_work.get_future().wait()
 
-        torch.testing.assert_close(at, bt)
+        torch.testing.assert_close(at.cpu(), bt.cpu())
diff --git a/train_ddp.py b/train_ddp.py
index f3994f7f..f211f862 100644
--- a/train_ddp.py
+++ b/train_ddp.py
@@ -18,6 +18,7 @@
 from torchft import (
     Manager,
     ProcessGroupGloo,
+    ProcessGroupBabyNCCL,
     DistributedDataParallel,
     Optimizer,
     DistributedSampler,
@@ -25,112 +26,116 @@
 
 logging.basicConfig(level=logging.INFO)
 
-device = "cpu"
-transform = transforms.Compose(
-    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
-)
-trainset = torchvision.datasets.CIFAR10(
-    root="./cifar", train=True, download=True, transform=transform
-)
-
-# This shards the training set across all ranks and replica groups. We manage
-# the dataloaders on a per replica group basis with the assumption that the
-# majority of groups will be available so few batches will be dropped.
-sampler = DistributedSampler(
-    trainset,
-    replica_group=int(os.environ.get("REPLICA_GROUP_ID", 0)),
-    num_replica_groups=int(os.environ.get("NUM_REPLICA_GROUPS", 2)),
-    rank=0,
-    # for DDP we can use replica groups of size 1, FSDP/PP/CP would need more.
-    num_replicas=1,
-)
-
-# This uses the torchdata StatefulDataLoader to be able to checkpoint and
-# restore the per worker dataloader position.
-trainloader = StatefulDataLoader(trainset, batch_size=2, shuffle=True, num_workers=2)
-
-
-def load_state_dict(state_dict):
-    m.load_state_dict(state_dict["model"])
-    optimizer.load_state_dict(state_dict["optim"])
-
-
-def state_dict():
-    return {
-        "model": m.state_dict(),
-        "optim": optimizer.state_dict(),
-    }
-
-
-manager = Manager(
-    pg=ProcessGroupGloo(),
-    min_replica_size=2,
-    load_state_dict=load_state_dict,
-    state_dict=state_dict,
-)
-
-
-class Net(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.conv1 = nn.Conv2d(3, 6, 5)
-        self.pool = nn.MaxPool2d(2, 2)
-        self.conv2 = nn.Conv2d(6, 16, 5)
-        self.fc1 = nn.Linear(16 * 5 * 5, 120)
-        self.fc2 = nn.Linear(120, 84)
-        self.fc3 = nn.Linear(84, 10)
-
-    def forward(self, x):
-        x = self.pool(F.relu(self.conv1(x)))
-        x = self.pool(F.relu(self.conv2(x)))
-        x = torch.flatten(x, 1)  # flatten all dimensions except batch
-        x = F.relu(self.fc1(x))
-        x = F.relu(self.fc2(x))
-        x = self.fc3(x)
-        return x
-
-
-m = Net().to(device)
-m = DistributedDataParallel(manager, m)
-optimizer = Optimizer(manager, optim.AdamW(m.parameters()))
-criterion = nn.CrossEntropyLoss()
-
-print(m)
-
-# You can use an epoch based training but with faults it's easier to use step
-# based training.
-while True:
-    for i, (inputs, labels) in enumerate(trainloader):
-        inputs = inputs.to(device)
-        labels = labels.to(device)
-
-        # must be called at the beginning of each train loop
-        # Quorum computation is triggered here but only needed in the backwards pass.
-        optimizer.zero_grad()
-
-        out = m(inputs)
-        loss = criterion(out, labels)
-
-        # Gradient allreduce overlaps with the backwards pass.
-        loss.backward()
-
-        # must be called at the end of the train loop
-        # This may not actually step the optimizer if an error occured during grad allreduce.
-        optimizer.step()
-
-        if manager.current_step() % 100 == 0:
-            print(f"[{manager.current_step()}] loss = {loss.item()}")
-
-        # TODO (by the user): periodically checkpoint model, optim, manager and dataloader
-
-        # You typically want to checkpoint dataloader frequently (every step?) to
-        # avoid repeated batches as it's replica group specific.
-
-        # Model, optim and manager checkpoints can be done more infrequently as
-        # they're shared across all groups and will load from existing replicas as
-        # long as not every worker goes down.
-
-        if manager.current_step() >= 10000:
-            # complete training
-            exit()
+def main() -> None:
+    transform = transforms.Compose(
+        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
+    )
+    trainset = torchvision.datasets.CIFAR10(
+        root="./cifar", train=True, download=True, transform=transform
+    )
+
+    # This shards the training set across all ranks and replica groups. We manage
+    # the dataloaders on a per replica group basis with the assumption that the
+    # majority of groups will be available so few batches will be dropped.
+    sampler = DistributedSampler(
+        trainset,
+        replica_group=int(os.environ.get("REPLICA_GROUP_ID", 0)),
+        num_replica_groups=int(os.environ.get("NUM_REPLICA_GROUPS", 2)),
+        rank=0,
+        # for DDP we can use replica groups of size 1, FSDP/PP/CP would need more.
+        num_replicas=1,
+    )
+
+    # This uses the torchdata StatefulDataLoader to be able to checkpoint and
+    # restore the per worker dataloader position.
+    trainloader = StatefulDataLoader(
+        trainset, batch_size=2, shuffle=True, num_workers=2
+    )
+
+    def load_state_dict(state_dict):
+        m.load_state_dict(state_dict["model"])
+        optimizer.load_state_dict(state_dict["optim"])
+
+    def state_dict():
+        return {
+            "model": m.state_dict(),
+            "optim": optimizer.state_dict(),
+        }
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    pg = ProcessGroupBabyNCCL() if torch.cuda.is_available() else ProcessGroupGloo()
+
+    manager = Manager(
+        pg=pg,
+        min_replica_size=2,
+        load_state_dict=load_state_dict,
+        state_dict=state_dict,
+    )
+
+    class Net(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = nn.Conv2d(3, 6, 5)
+            self.pool = nn.MaxPool2d(2, 2)
+            self.conv2 = nn.Conv2d(6, 16, 5)
+            self.fc1 = nn.Linear(16 * 5 * 5, 120)
+            self.fc2 = nn.Linear(120, 84)
+            self.fc3 = nn.Linear(84, 10)
+
+        def forward(self, x):
+            x = self.pool(F.relu(self.conv1(x)))
+            x = self.pool(F.relu(self.conv2(x)))
+            x = torch.flatten(x, 1)  # flatten all dimensions except batch
+            x = F.relu(self.fc1(x))
+            x = F.relu(self.fc2(x))
+            x = self.fc3(x)
+            return x
+
+    m = Net().to(device)
+    m = DistributedDataParallel(manager, m)
+    optimizer = Optimizer(manager, optim.AdamW(m.parameters()))
+    criterion = nn.CrossEntropyLoss()
+
+    print(m)
+
+    # You can use epoch based training but with faults it's easier to use step
+    # based training.
+    while True:
+        for i, (inputs, labels) in enumerate(trainloader):
+            inputs = inputs.to(device)
+            labels = labels.to(device)
+
+            # must be called at the beginning of each train loop
+            # Quorum computation is triggered here but only needed in the backwards pass.
+            optimizer.zero_grad()
+
+            out = m(inputs)
+            loss = criterion(out, labels)
+
+            # Gradient allreduce overlaps with the backwards pass.
+            loss.backward()
+
+            # must be called at the end of the train loop
+            # This may not actually step the optimizer if an error occurred during grad allreduce.
+            optimizer.step()
+
+            if manager.current_step() % 100 == 0:
+                print(f"[{manager.current_step()}] loss = {loss.item()}")
+
+            # TODO (by the user): periodically checkpoint model, optim, manager and dataloader
+
+            # You typically want to checkpoint dataloader frequently (every step?) to
+            # avoid repeated batches as it's replica group specific.
+
+            # Model, optim and manager checkpoints can be done more infrequently as
+            # they're shared across all groups and will load from existing replicas as
+            # long as not every worker goes down.
+
+            if manager.current_step() >= 10000:
+                # complete training
+                exit()
+
+
+if __name__ == "__main__":
+    main()