diff --git a/torchft/local_sgd.py b/torchft/local_sgd.py
index ec3cf822..37602f77 100644
--- a/torchft/local_sgd.py
+++ b/torchft/local_sgd.py
@@ -10,7 +10,7 @@
 """
 import logging
 from types import TracebackType
-from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Type
+from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type
 
 import torch
 from torch import nn, optim
@@ -59,8 +59,6 @@ def __init__(
         model: nn.Module,
         optimizer: optim.Optimizer,
         sync_every: int,
-        backup_device: Optional[torch.device] = None,
-        pin_memory: bool = True,
     ) -> None:
         """
         Args:
@@ -78,21 +76,8 @@ def __init__(
         self._local_step = 0
         self._sync_every = sync_every
         assert sync_every >= 1, "sync_every must be greater than or equal to 1"
-        device = backup_device or torch.device("cpu")
-        self._backup_parameters: Dict[str, torch.Tensor] = {}
-        for name, p in self._model.named_parameters():
-            t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=device)
-            if (
-                pin_memory
-                and t.device == torch.device("cpu")
-                and torch.cuda.is_available()
-            ):
-                t = t.pin_memory()
-            self._backup_parameters[name] = t
         self._hooks: List[RemovableHandle] = []
 
-        # Need to copy the parameters to the host to be safe if we are on the first step.
-        self._save_parameters()
 
     def __enter__(self) -> "LocalSGD":
         # Add optimizer hook which increments the local step counter and syncs if necessary
@@ -108,9 +93,6 @@ def __exit__(
         traceback: Optional[TracebackType],
     ) -> bool:
         # Handle any cleanup or error handling here
-        if exc_type is not None:
-            # If an exception occurred, restore parameters
-            self._restore_parameters()
         # Clean up hooks
         for hook in self._hooks:
             hook.remove()
@@ -118,20 +100,8 @@ def __exit__(
 
         return False  # Propagate exceptions
 
-    def _save_parameters(self) -> None:
-        with torch.no_grad():
-            # TODO: consider running copy on a separate stream
-            for name, p in self._model.named_parameters():
-                self._backup_parameters[name].copy_(p.data, non_blocking=True)
-
-    def _restore_parameters(self) -> None:
-        with torch.no_grad():
-            # TODO: consider running copy on a separate stream
-            for name, p in self._model.named_parameters():
-                p.data.copy_(self._backup_parameters[name], non_blocking=False)
-
     def _step_post_hook(
-        self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
+        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
    ) -> None:
         """
         This hook is registered on the optimizer and is called after the optimizer step.
@@ -151,30 +121,31 @@ def sync(self) -> None:
     def _perform_sync(self) -> None:
         """
         Performs the synchronization of the model weights across the manager.
-        This method is intended to be overridden by subclasses to implement custom
-        synchronization logic.
         """
-        self._average()
+        averaged_parameters = self._average()
         if self._manager.should_commit():
-            self._save_parameters()
-        else:
-            # commit failed, restore from the backup parameters
-            self._restore_parameters()
-
-    def _average(self) -> None:
-        # TODO: do we need to broadcast buffers like DDP does?
+            # Update the model parameters with the averaged values
+            for param, avg_param in zip(self._model.parameters(), averaged_parameters):
+                param.data.copy_(avg_param)
+
+    def _average(self) -> List[torch.Tensor]:
+        """
+        Averages the model parameters across the manager and returns the averaged parameters.
+        """
         works = []
-
+        averaged_parameters = []
         for p in self._model.parameters():
-            # TODO: bucketize parameters
-            works.append(self._manager.allreduce(p.data.detach()))
-
+            # Create a new tensor to store the averaged parameter
+            p.data.grad = None
+            avg_param = p.data.clone()
+            works.append(self._manager.allreduce(avg_param))
+            averaged_parameters.append(avg_param)
         for work in works:
             work.wait()
+        return averaged_parameters
 
 
-class DiLoCo(LocalSGD):
+class DiLoCo:
     """
     DiLoCo averages and synchronizes the pseudogradients (the delta between the previous
     global weights and the current local weights) rather than the weights themselves.
@@ -197,27 +168,96 @@ def __init__(
                 "Using DiLoCo requires synchronous quorum to be enabled. "
                 "Ensure that the manager is initialized with use_async_quorum=False"
             )
-        super().__init__(
-            manager, model, inner_optimizer, sync_every, backup_device, pin_memory
-        )
+        super().__init__()
+        self._manager = manager
+        self._model = model
+        self._local_optimizer = inner_optimizer
+        self._local_step = 0
+        self._sync_every = sync_every
+        assert sync_every >= 1, "sync_every must be greater than or equal to 1"
+        self._backup_device = backup_device
+        self._pin_memory = pin_memory
+
+        self._hooks: List[RemovableHandle] = []
         self._outer_optimizer = outer_optimizer
+        self.original_parameters: Dict[str, torch.Tensor] = {}
+        for name, p in self._model.named_parameters():
+            t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=self._backup_device)
+            if (
+                self._pin_memory
+                and t.device == torch.device("cpu")
+                and torch.cuda.is_available()
+            ):
+                t = t.pin_memory()
+            self.original_parameters[name] = t
+
+        # Need to copy the parameters to the host to be safe if we are on the first step.
+        self._save_parameters()
+
+    def _save_parameters(self) -> None:
+        with torch.no_grad():
+            # TODO: consider running copy on a separate stream
+            for name, p in self._model.named_parameters():
+                self.original_parameters[name].copy_(p.data, non_blocking=True)
+
+    def _restore_parameters(self) -> None:
+        with torch.no_grad():
+            # TODO: consider running copy on a separate stream
+            for name, p in self._model.named_parameters():
+                p.data.copy_(self.original_parameters[name], non_blocking=False)
+
+    def __enter__(self) -> "DiLoCo":
+        # Add optimizer hook which increments the local step counter and syncs if necessary
+        self._hooks.append(
+            self._local_optimizer.register_step_post_hook(self._step_post_hook)
+        )
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> bool:
+        # Handle any cleanup or error handling here
+        # Clean up hooks
+        for hook in self._hooks:
+            hook.remove()
+        self._hooks.clear()
+
+        return False  # Propagate exceptions
+
+    def _step_post_hook(
+        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
+    ) -> None:
+        """
+        This hook is registered on the optimizer and is called after the optimizer step.
+        """
+        self._local_step += 1
+        if self._local_step >= self._sync_every:
+            self.sync()
+
+    def sync(self) -> None:
+        """
+        Synchronizes and averages the model weights across the manager.
+        """
+        self._manager.start_quorum()
+        self._perform_sync()
+        self._local_step = 0
 
     def _perform_sync(self) -> None:
         """
         Calculates the pseudogradients, averages them across
         the manager group, and steps using the outer optimizer.
         """
-        # Set the .grad field of each parameter to its pseudogradient
         for name, p in self._model.named_parameters():
-            assert name in self._backup_parameters
-            pseudogradient = p.data - self._backup_parameters[name]
+            pseudogradient = p.data - self.original_parameters[name]
             p.grad = pseudogradient
 
         self._average_grads()
         # Restore the parameters back to the previous state
         self._restore_parameters()
-
         if self._manager.should_commit():
             # Use the outer optimizer to update the model parameters
             self._outer_optimizer.step()
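The reworked `_perform_sync` above forms a pseudogradient for each parameter (the current local weights minus the snapshot saved in `original_parameters`), stores it in `p.grad`, allreduces the gradients, rewinds the weights to the snapshot, and lets the outer optimizer apply the update. A minimal sketch of that outer step on a single tensor, with the manager and the allreduce elided; the names and the SGD settings here are illustrative, not torchft APIs:

    import torch
    from torch import nn, optim

    param = nn.Parameter(torch.ones(3))         # current local weights
    original = torch.full((3,), 0.5)            # snapshot of the last synced weights
    outer_opt = optim.SGD([param], lr=0.9)      # stand-in for the outer optimizer

    with torch.no_grad():
        pseudogradient = param.data - original  # drift accumulated since the last sync
    param.grad = pseudogradient                 # would be allreduced across replicas here
    param.data.copy_(original)                  # rewind to the last synced weights
    outer_opt.step()                            # global update driven by the pseudogradient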
""" - # Set the .grad field of each parameter to its pseudogradient for name, p in self._model.named_parameters(): - assert name in self._backup_parameters - pseudogradient = p.data - self._backup_parameters[name] + pseudogradient = p.data - self.original_parameters[name] p.grad = pseudogradient self._average_grads() # Restore the parameters back to the previous state self._restore_parameters() - if self._manager.should_commit(): # Use the outer optimizer to update the model parameters self._outer_optimizer.step() diff --git a/torchft/local_sgd_integ_test.py b/torchft/local_sgd_integ_test.py index 55ca5c34..e39031e5 100644 --- a/torchft/local_sgd_integ_test.py +++ b/torchft/local_sgd_integ_test.py @@ -1,18 +1,22 @@ import copy import logging +import re +import traceback from concurrent.futures import ThreadPoolExecutor, as_completed from contextlib import ExitStack +from datetime import timedelta from typing import Any, Dict from unittest import TestCase import torch +from parameterized import parameterized from torch import nn, optim from torchft._torchft import LighthouseServer from torchft.local_sgd import DiLoCo, LocalSGD from torchft.manager import Manager from torchft.manager_integ_test import FailureInjector, MyModel, Runner -from torchft.process_group import ProcessGroupGloo +from torchft.process_group import ProcessGroupGloo, ProcessGroupNCCL logger: logging.Logger = logging.getLogger(__name__) @@ -20,6 +24,7 @@ def local_sgd_train_loop( rank: int, store_port: int, + device: torch.device, runner: Runner, ) -> Dict[str, Dict[str, object]]: with ExitStack() as stack: @@ -49,24 +54,24 @@ def state_dict() -> Dict[str, Dict[str, object]]: world_size=runner.world_size, lighthouse_addr=runner.lighthouse_address, port=19530 + runner.replica_id, + timeout=timedelta(seconds=10), # pyre-fixme[6]: Incompatible parameter type **runner.manager_args, ) stack.callback(lambda: manager.shutdown(wait=False)) - m: nn.Module = MyModel() + m: nn.Module = MyModel().to(device) optimizer: optim.Optimizer = optim.Adam(m.parameters()) criterion = nn.CrossEntropyLoss() - with LocalSGD(manager, m, optimizer, sync_every=2): + with LocalSGD(manager, m, optimizer, sync_every=2) as local_sgd: while True: - inputs = torch.rand(2, 3) - labels = torch.randint(4, (2,)) + inputs = torch.rand(2, 3).to(device) + labels = torch.randint(4, (2,)).to(device) optimizer.zero_grad() out = m(inputs) loss = criterion(out, labels) - loss.backward() optimizer.step() @@ -78,18 +83,21 @@ def state_dict() -> Dict[str, Dict[str, object]]: # return state_dict so we can check consistency return state_dict() + return {} def diloco_train_loop( rank: int, store_port: int, + device: torch.device, runner: Runner, ) -> Dict[str, Dict[str, object]]: with ExitStack() as stack: # Declare the model and optimizers - m: nn.Module = MyModel() + m: nn.Module = MyModel(2, 3) model_state_dict: Dict[str, Any] = runner.train_loop_args["model_state_dict"] m.load_state_dict(model_state_dict) + m.to(device) # Setup optimizers inner_optimizer: optim.Optimizer = torch.optim.AdamW( @@ -102,15 +110,14 @@ def diloco_train_loop( # pyre-ignore[53] def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None: m.load_state_dict(state_dict["model"]) - # TODO: make this cleaner so we don't have to save this - diloco._backup_parameters = state_dict["backup_params"] + diloco.original_parameters = state_dict["original_params"] inner_optimizer.load_state_dict(state_dict["inner_optim"]) outer_optimizer.load_state_dict(state_dict["outer_optim"]) def state_dict() -> 
Dict[str, Dict[str, object]]: # pyre-ignore[53] return { "model": m.state_dict(), - "backup_params": copy.deepcopy(diloco._backup_parameters), + "original_params": diloco.original_parameters, "inner_optim": inner_optimizer.state_dict(), "outer_optim": outer_optimizer.state_dict(), } @@ -131,6 +138,7 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53] world_size=runner.world_size, lighthouse_addr=runner.lighthouse_address, port=19530 + runner.replica_id, + timeout=timedelta(seconds=10), # pyre-fixme[6]: Incompatible parameter type **runner.manager_args, ) @@ -139,20 +147,25 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53] criterion = nn.CrossEntropyLoss() all_state_dicts = {} with DiLoCo( - manager, m, inner_optimizer, outer_optimizer, sync_every=2 + manager, + m, + inner_optimizer, + outer_optimizer, + backup_device=device, + sync_every=2, ) as diloco: while True: - inputs = torch.rand(2, 3) - labels = torch.randint(4, (2,)) + batch_size = 1 + inputs = m.get_rand_inputs(batch_size).to(device) + labels = m.get_rand_labels(batch_size).to(device) out = m(inputs) loss = criterion(out, labels) inner_optimizer.zero_grad() loss.backward() + all_state_dicts[str(manager.current_step())] = state_dict() inner_optimizer.step() - manager_step_str = str(manager.current_step()) - all_state_dicts[manager_step_str] = state_dict() # after 4 model updates then break if manager.current_step() >= 4: @@ -162,10 +175,16 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53] # return state_dict so we can check consistency return all_state_dicts + return {} -class ManagerIntegTest(TestCase): - def test_local_sgd_recovery(self) -> None: +class LocalSGDIntegTest(TestCase): + @parameterized.expand( + [ + (False,), + ] + ) + def test_local_sgd_recovery(self, use_cuda: bool) -> None: lighthouse = LighthouseServer( bind="[::]:0", min_replicas=2, @@ -184,9 +203,11 @@ def test_local_sgd_recovery(self) -> None: ): runner = Runner( replica_id=replica_id, + num_replicas=num_replicas, lighthouse_address=lighthouse.address(), failure_injector=failure_injector, train_loop=local_sgd_train_loop, + use_cuda=use_cuda, manager_args={ "use_async_quorum": False, }, @@ -208,54 +229,65 @@ def test_local_sgd_recovery(self) -> None: # LocalSGD only guarantees that the model is consistent across # replicas but uses separate optimizer states. 
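The consistency checks below pass `check_device=False` to `torch.testing.assert_close`, which moves mismatched tensors to a common device before comparing values, so a CUDA replica's state can be checked against a CPU replica's. A small illustration (the CUDA branch only runs where a GPU is available):

    import torch

    a = torch.ones(2)
    b = torch.ones(2, device="cuda") if torch.cuda.is_available() else torch.ones(2)
    torch.testing.assert_close(a, b, check_device=False)  # values match; device ignored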
@@ -208,54 +229,65 @@ def test_local_sgd_recovery(self) -> None:
 
         # LocalSGD only guarantees that the model is consistent across
         # replicas but uses separate optimizer states.
         torch.testing.assert_close(
-            state_dict[0]["model"], state_dicts[0][0]["model"]
+            state_dict[0]["model"], state_dicts[0][0]["model"], check_device=False
         )
         self.assertEqual(failure_injectors[1].count, 1)
 
-    def test_diloco_healthy(self) -> None:
-        lighthouse = LighthouseServer(
-            bind="[::]:0",
-            min_replicas=2,
-        )
+    @parameterized.expand(
+        [
+            (False,),
+        ]
+    )
+    def test_diloco_healthy(self, use_cuda: bool) -> None:
+        lighthouse = LighthouseServer(bind="[::]:0", min_replicas=2)
         num_replicas = 2
         futures = []
 
         torch.manual_seed(42)
         # Initialize the model so we can pass in the state_dict
-        m: nn.Module = MyModel()
+        m: nn.Module = MyModel(2, 3)
 
         with ThreadPoolExecutor(max_workers=num_replicas) as executor:
             for replica_id in range(num_replicas):
                 failure_injector = FailureInjector()
                 runner = Runner(
                     replica_id=replica_id,
+                    num_replicas=num_replicas,
                     lighthouse_address=lighthouse.address(),
                     failure_injector=failure_injector,
                     train_loop=diloco_train_loop,
+                    use_cuda=use_cuda,
                     train_loop_args={
                         "model_state_dict": m.state_dict(),
                     },
                 )
                 futures.append(executor.submit(runner.run_replica))
 
-        state_dicts = []
-
-        for fut in as_completed(futures):
-            state_dicts.append(fut.result()[0])
+        state_dicts = []
+        for fut in as_completed(futures):
+            try:
+                state_dicts.append(fut.result()[0])
+            except Exception as e:
+                print(e, flush=True)
+                traceback.print_exc()
+                raise
 
         lighthouse.shutdown()
 
-        for replica_group in state_dicts:
-            for step, state_dict in replica_group.items():
-                # inner optimizer will be different, outer optimizer and model should be the same
-                torch.testing.assert_close(
-                    state_dict["backup_params"],
-                    state_dicts[0][str(step)]["backup_params"],
-                )
-                torch.testing.assert_close(
-                    state_dict["outer_optim"], state_dicts[0][str(step)]["outer_optim"]
-                )
+        rep0, rep1 = state_dicts
+        for step, state_dict in rep1.items():
+            # inner optimizer will be different, outer optimizer and model should be the same
+            torch.testing.assert_close(
+                state_dict["model"],
+                rep0[step]["model"],
+                check_device=False,
+            )
+            torch.testing.assert_close(
+                state_dict["outer_optim"],
+                rep0[step]["outer_optim"],
+                check_device=False,
+            )
 
     def test_diloco_recovery(self) -> None:
         lighthouse = LighthouseServer(
@@ -272,7 +304,7 @@ def test_diloco_recovery(self) -> None:
         torch.manual_seed(42)
 
         # Initialize the model so we can pass in the state_dict
-        m: nn.Module = MyModel()
+        m: nn.Module = MyModel(2, 3)
 
         with ThreadPoolExecutor(max_workers=num_replicas) as executor:
             for replica_id, failure_injector in zip(
@@ -280,6 +312,7 @@ def test_diloco_recovery(self) -> None:
             ):
                 runner = Runner(
                     replica_id=replica_id,
+                    num_replicas=num_replicas,
                     lighthouse_address=lighthouse.address(),
                     failure_injector=failure_injector,
                     train_loop=diloco_train_loop,
@@ -299,18 +332,19 @@ def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
                 raise
 
         lighthouse.shutdown()
 
-        for replica_group in state_dicts:
-            for step, state_dict in replica_group.items():
-                str_step = str(step)
-                if str_step in state_dicts[0]:
-                    # inner optimizer will be different, outer optimizer and model should be the same
-                    torch.testing.assert_close(
-                        state_dict["backup_params"],
-                        state_dicts[0][str_step]["backup_params"],
-                    )
-                    torch.testing.assert_close(
-                        state_dict["outer_optim"],
-                        state_dicts[0][str_step]["outer_optim"],
-                    )
+        rep0, rep1 = state_dicts
+
+        for step in rep0.keys():
+            # Inner optimizer will be different, outer optimizer and model should be the same
+            torch.testing.assert_close(
+                rep1[step]["model"],
+                rep0[step]["model"],
+                check_device=False,
+            )
+            torch.testing.assert_close(
+                rep1[step]["outer_optim"],
+                rep0[step]["outer_optim"],
+                check_device=False,
+            )
 
         self.assertEqual(failure_injectors[1].count, 1)
diff --git a/torchft/local_sgd_test.py b/torchft/local_sgd_test.py
index 05f88b7a..d26b316b 100644
--- a/torchft/local_sgd_test.py
+++ b/torchft/local_sgd_test.py
@@ -45,9 +45,6 @@ def test_local_sgd_healthy(self) -> None:
         manager = create_autospec(Manager)
         with LocalSGD(manager, model, optimizer, sync_every=2) as local_sgd:
             self.assertEqual(local_sgd._local_step, 0)
-            torch.testing.assert_close(
-                local_sgd._backup_parameters, _params_dict(model)
-            )
             inp = torch.rand(2, 3)
             loss = model(inp).mean()
             loss.backward()
@@ -62,9 +59,6 @@ def test_local_sgd_healthy(self) -> None:
             manager.should_commit.return_value = True
 
             self.assertEqual(local_sgd._local_step, 0)
-            torch.testing.assert_close(
-                local_sgd._backup_parameters, _params_dict(model)
-            )
             self.assertEqual(manager.should_commit.call_count, 1)
             self.assertEqual(manager.allreduce.call_count, 4)
 
@@ -74,11 +68,7 @@ def test_local_sgd_recovery(self) -> None:
         manager = create_autospec(Manager)
 
         with LocalSGD(manager, model, optimizer, sync_every=2) as local_sgd:
-            torch.testing.assert_close(
-                local_sgd._backup_parameters, _params_dict(model)
-            )
             og_state_dict = _copy_state_dict(model.state_dict())
-            print(og_state_dict)
 
             inp = torch.rand(2, 3)
 
@@ -95,11 +85,6 @@ def test_local_sgd_recovery(self) -> None:
             )
             self.assertEqual(local_sgd._local_step, 1)
 
-            local_sgd._restore_parameters()
-            torch.testing.assert_close(
-                local_sgd._backup_parameters, _params_dict(model)
-            )
-
 
 class DiLoCoTest(TestCase):
     def test_diloco_healthy(self) -> None:
@@ -123,7 +108,7 @@ def test_diloco_healthy(self) -> None:
             self.assertEqual(initial_outer_opt_state["state"], {})
 
             self.assertEqual(diloco._local_step, 0)
-            torch.testing.assert_close(diloco._backup_parameters, _params_dict(model))
+            torch.testing.assert_close(diloco.original_parameters, _params_dict(model))
             inp = torch.rand(2, 3)
             loss = model(inp).mean()
             loss.backward()
@@ -138,7 +123,7 @@ def test_diloco_healthy(self) -> None:
             manager.should_commit.return_value = True
 
             self.assertEqual(diloco._local_step, 0)
-            torch.testing.assert_close(diloco._backup_parameters, _params_dict(model))
+            torch.testing.assert_close(diloco.original_parameters, _params_dict(model))
             self.assertEqual(manager.should_commit.call_count, 1)
             self.assertEqual(manager.allreduce.call_count, parameter_count)
diff --git a/torchft/manager_integ_test.py b/torchft/manager_integ_test.py
index 0f69a7b6..8a2799a4 100644
--- a/torchft/manager_integ_test.py
+++ b/torchft/manager_integ_test.py
@@ -25,16 +25,24 @@
 class MyModel(nn.Module):
-    def __init__(self) -> None:
+    def __init__(self, in_dim: int = 3, out_dim: int = 4) -> None:
         super().__init__()
+        self.in_dim = in_dim
+        self.out_dim = out_dim
         self.model = nn.Sequential(
-            nn.Linear(3, 4),
+            nn.Linear(in_dim, out_dim),
             nn.Sigmoid(),
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.model(x)
 
+    def get_rand_inputs(self, batch_size: int) -> torch.Tensor:
+        return torch.rand(batch_size, self.in_dim)
+
+    def get_rand_labels(self, batch_size: int) -> torch.Tensor:
+        return torch.randint(self.out_dim, (batch_size,))
+
 
 class InjectedFailure(Exception):
     pass
@@ -63,17 +71,19 @@ def check(self, rank: int, step: int) -> None:
 
 class TrainLoop(Protocol):
     def __call__(
-        self, rank: int, store_port: int, runner: "Runner"
+        self, rank: int, store_port: int, device: torch.device, runner: "Runner"
     ) -> Dict[str, Dict[str, object]]: ...
 
 
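The `MyModel` helpers above generate batches that match the configured dimensions, and the `TrainLoop` protocol now threads an explicit `torch.device` into every train loop. A short usage sketch, assuming the integration-test module is importable:

    from torchft.manager_integ_test import MyModel

    m = MyModel(2, 3)
    x = m.get_rand_inputs(8)     # float inputs of shape (8, 2)
    y = m.get_rand_labels(8)     # integer labels of shape (8,) in [0, 3)
    assert m(x).shape == (8, 3)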
@dataclass
class Runner:
    replica_id: int
+    num_replicas: int
    lighthouse_address: str
    failure_injector: FailureInjector
    train_loop: TrainLoop
+    use_cuda: bool = False
    world_size: int = 1
    attempts: int = 3
    manager_args: Dict[str, object] = field(default_factory=dict)
@@ -92,11 +102,22 @@ def _replica_main(self) -> List[Dict[str, Dict[str, object]]]:
        ) as executor:
            futures = []
            for rank in range(self.world_size):
+                if self.use_cuda:
+                    num_cuda_devices = torch.cuda.device_count()
+                    assert num_cuda_devices >= self.num_replicas
+                    device_index = (
+                        num_cuda_devices // self.num_replicas
+                    ) * self.replica_id + rank
+                    device = torch.device(f"cuda:{device_index}")
+                else:
+                    device = torch.device("cpu")
+
                futures.append(
                    executor.submit(
                        self.train_loop,
                        rank=rank,
                        store_port=store.port,
+                        device=device,
                        runner=self,
                    )
                )
@@ -129,6 +150,7 @@ def run_replica(self) -> List[Dict[str, Dict[str, object]]]:
 def ddp_train_loop(
     rank: int,
     store_port: int,
+    device: torch.device,
     runner: Runner,
 ) -> Dict[str, Dict[str, object]]:
     with ExitStack() as stack:
@@ -213,6 +235,7 @@ def test_ddp_healthy(self) -> None:
            failure_injector = FailureInjector()
            runner = Runner(
                replica_id=replica_id,
+                num_replicas=num_replicas,
                lighthouse_address=lighthouse.address(),
                failure_injector=failure_injector,
                train_loop=ddp_train_loop,
@@ -260,6 +283,7 @@ def test_ddp_recovery(self, name: str, use_async_quorum: bool) -> None:
        ):
            runner = Runner(
                replica_id=replica_id,
+                num_replicas=num_replicas,
                lighthouse_address=lighthouse.address(),
                failure_injector=failure_injector,
                manager_args={
@@ -301,6 +325,7 @@ def test_ddp_recovery_multi_rank(self) -> None:
        ):
            runner = Runner(
                replica_id=replica_id,
+                num_replicas=num_replicas,
                lighthouse_address=lighthouse.address(),
                failure_injector=failure_injector,
                world_size=world_size,
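For reference, the device assignment in `_replica_main` above splits the visible GPUs evenly across replicas and offsets by the local rank within each replica. A worked example with hypothetical numbers (4 visible GPUs, 2 replicas, world_size 2):

    # num_cuda_devices = 4, num_replicas = 2 -> 2 GPUs per replica
    for replica_id in range(2):
        for rank in range(2):
            device_index = (4 // 2) * replica_id + rank
            print(f"replica {replica_id}, rank {rank} -> cuda:{device_index}")
    # replica 0 -> cuda:0, cuda:1; replica 1 -> cuda:2, cuda:3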