diff --git a/benchmarks/test_sharded_parity.py b/benchmarks/test_sharded_parity.py index 15a5871ca643b..1240710674c59 100644 --- a/benchmarks/test_sharded_parity.py +++ b/benchmarks/test_sharded_parity.py @@ -161,6 +161,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir): gpus=2, accelerator='ddp_spawn', model_cls=SeedTrainLoaderManualModel, + max_percent_speed_diff=0.25 # Increase speed diff since only 2 GPUs sharding 2 optimizers ) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index c5b744c3384ec..11af9e4d8f91e 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from contextlib import contextmanager from enum import Enum from typing import Any, Optional, Union @@ -21,10 +20,8 @@ from torch.optim import Optimizer from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.utilities import AMPType +from pytorch_lightning.plugins.rpc_plugin import RPCPlugin from pytorch_lightning.utilities.apply_func import move_data_to_device -from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict if torch.distributed.is_available(): @@ -222,6 +219,18 @@ def __setstate__(self, d): def on_save(self, checkpoint): return checkpoint + @property + def rpc_enabled(self): + return self.ddp_plugin is not None and isinstance(self.ddp_plugin, RPCPlugin) + + @property + def distributed_sampler_kwargs(self): + raise NotImplementedError + + @property + def require_distributed_sampler(self): + raise NotImplementedError + @contextmanager def block_ddp_plugin_sync_behaviour(self): """ diff --git a/pytorch_lightning/accelerators/cpu_accelerator.py b/pytorch_lightning/accelerators/cpu_accelerator.py index fe0ab59fb554f..9113331ef0a7d 100644 --- a/pytorch_lightning/accelerators/cpu_accelerator.py +++ b/pytorch_lightning/accelerators/cpu_accelerator.py @@ -90,3 +90,7 @@ def sync_tensor(self, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return tensor + + @property + def require_distributed_sampler(self): + return False diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py index f43866881cabb..f47b389faf436 100644 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ b/pytorch_lightning/accelerators/ddp2_accelerator.py @@ -23,6 +23,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.step_result import Result from pytorch_lightning.distributed.dist import LightningDistributed +from pytorch_lightning.plugins.rpc_plugin import RPCPlugin from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available @@ -101,9 +102,11 @@ def set_world_ranks(self, process_idx): def broadcast(self, obj, src=0): return self.dist.broadcast(obj) - def model_to_device(self, model, process_idx): + def init_device(self, process_idx): self.trainer.root_gpu = process_idx torch.cuda.set_device(self.trainer.root_gpu) + + def model_to_device(self, model): model.cuda(self.trainer.root_gpu) def 
get_device_ids(self): @@ -133,6 +136,9 @@ def ddp_train(self, process_idx, mp_queue, model): # set warning rank rank_zero_only.rank = self.trainer.global_rank + # Initialize cuda device + self.init_device(process_idx) + # set up server using proc 0's ip address # try to init for 20 times at max in case ports are taken # where to store ip_table @@ -143,6 +149,15 @@ def ddp_train(self, process_idx, mp_queue, model): self.trainer.is_slurm_managing_tasks ) + if isinstance(self.ddp_plugin, RPCPlugin): + if not self.ddp_plugin.is_main_rpc_process: + self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer) + self.ddp_plugin.exit_rpc_process() + if self.ddp_plugin.return_after_exit_rpc_process: + return + else: + self.ddp_plugin.on_main_rpc_connection(self.trainer) + # call setup after the ddp process has connected self.trainer.call_setup_hook(model) @@ -158,12 +173,14 @@ def ddp_train(self, process_idx, mp_queue, model): model = self.configure_sync_batchnorm(model) # move the model to the correct device - self.model_to_device(model, process_idx) + self.model_to_device(model) # CHOOSE OPTIMIZER # allow for lr schedulers as well self.setup_optimizers(model) + self.ddp_plugin.on_after_setup_optimizers(self.trainer) + # set model properties before going into wrapper self.trainer.model_connector.copy_trainer_model_properties(model) @@ -189,7 +206,7 @@ def ddp_train(self, process_idx, mp_queue, model): return results def configure_ddp( - self, model: LightningModule, device_ids: List[int] + self, model: LightningModule, device_ids: List[int] ) -> DistributedDataParallel: model = self.ddp_plugin.configure_ddp(model, device_ids) return model @@ -219,3 +236,17 @@ def sync_tensor(self, def get_reference_model(self, model) -> LightningModule: return self.ddp_plugin.get_model_from_plugin(model) + + @property + def distributed_sampler_kwargs(self): + distributed_sampler_kwargs = dict( + num_replicas=self.trainer.num_nodes, + rank=self.trainer.global_rank + ) + if self.ddp_plugin is not None: + distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs) + return distributed_sampler_kwargs + + @property + def require_distributed_sampler(self): + return True diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py index 4c2754f4049bd..d3d4c1fa1b766 100644 --- a/pytorch_lightning/accelerators/ddp_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_accelerator.py @@ -27,6 +27,7 @@ from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.distributed.dist import LightningDistributed +from pytorch_lightning.plugins.rpc_plugin import RPCPlugin from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType from pytorch_lightning.utilities.distributed import find_free_network_port, rank_zero_only, sync_ddp_if_available from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -162,8 +163,11 @@ def _step(self, args): return output def barrier(self, name: Optional[str] = None): - if torch_distrib.is_initialized(): - torch_distrib.barrier() + if self.rpc_enabled: + # Allow RPC to handle barrier on main RPC processes + self.ddp_plugin.barrier() + elif torch_distrib.is_initialized(): + torch_distrib.barrier(group=self.ddp_plugin.data_parallel_group) def _check_can_spawn_children(self): if self._has_spawned_children: @@ -177,9 +181,11 @@ def set_world_ranks(self, process_idx): 
self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes - def model_to_device(self, model, process_idx): + def init_device(self, process_idx): self.trainer.root_gpu = self.trainer.data_parallel_device_ids[self.trainer.local_rank] torch.cuda.set_device(self.trainer.root_gpu) + + def model_to_device(self, model): model.cuda(self.trainer.root_gpu) def get_device_ids(self): @@ -192,12 +198,12 @@ def on_train_end(self): def early_stopping_should_stop(self, pl_module): stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device) torch_distrib.all_reduce(stop, op=torch_distrib.reduce_op.SUM) - torch_distrib.barrier() + self.barrier('early_stopping') should_stop = stop == self.trainer.world_size return should_stop def broadcast(self, obj, src=0): - return self.dist.broadcast(obj) + return self.dist.broadcast(obj, group=self.ddp_plugin.data_parallel_group) def ddp_train(self, process_idx, model): """ @@ -226,6 +232,9 @@ def ddp_train(self, process_idx, model): # set warning rank rank_zero_only.rank = self.trainer.global_rank + # Initialize cuda device + self.init_device(process_idx) + # set up server using proc 0's ip address # try to init for 20 times at max in case ports are taken # where to store ip_table @@ -236,6 +245,15 @@ def ddp_train(self, process_idx, model): self.trainer.is_slurm_managing_tasks ) + if isinstance(self.ddp_plugin, RPCPlugin): + if not self.ddp_plugin.is_main_rpc_process: + self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer) + self.ddp_plugin.exit_rpc_process() + if self.ddp_plugin.return_after_exit_rpc_process: + return + else: + self.ddp_plugin.on_main_rpc_connection(self.trainer) + # call setup after the ddp process has connected self.trainer.call_setup_hook(model) @@ -251,7 +269,7 @@ def ddp_train(self, process_idx, model): model = self.configure_sync_batchnorm(model) # move the model to the correct device - self.model_to_device(model, process_idx) + self.model_to_device(model) # CHOOSE OPTIMIZER # allow for lr schedulers as well @@ -284,7 +302,7 @@ def ddp_train(self, process_idx, model): return results def configure_ddp( - self, model: LightningModule, device_ids: List[int] + self, model: LightningModule, device_ids: List[int] ) -> DistributedDataParallel: model = self.ddp_plugin.configure_ddp(model, device_ids) return model @@ -317,3 +335,17 @@ def sync_tensor(self, def get_reference_model(self, model) -> LightningModule: return self.ddp_plugin.get_model_from_plugin(model) + + @property + def distributed_sampler_kwargs(self): + distributed_sampler_kwargs = dict( + num_replicas=self.trainer.num_nodes * self.trainer.num_processes, + rank=self.trainer.global_rank + ) + if self.ddp_plugin is not None: + distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs) + return distributed_sampler_kwargs + + @property + def require_distributed_sampler(self): + return True diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py index 982da2f53216b..50bd1b7ab9051 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py @@ -24,6 +24,7 @@ from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.distributed.dist import LightningDistributed +from 
pytorch_lightning.plugins.rpc_plugin import RPCPlugin from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType from pytorch_lightning.utilities.distributed import ( find_free_network_port, @@ -107,6 +108,15 @@ def ddp_train(self, process_idx, mp_queue, model): self.trainer.is_slurm_managing_tasks ) + if isinstance(self.ddp_plugin, RPCPlugin): + if not self.ddp_plugin.is_main_rpc_process: + self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer) + self.ddp_plugin.exit_rpc_process() + if self.ddp_plugin.return_after_exit_rpc_process: + return + else: + self.ddp_plugin.on_main_rpc_connection(self.trainer) + # call setup after the ddp process has connected self.trainer.call_setup_hook(model) @@ -128,6 +138,8 @@ def ddp_train(self, process_idx, mp_queue, model): # allow for lr schedulers as well self.setup_optimizers(model) + self.ddp_plugin.on_after_setup_optimizers(self.trainer) + # set model properties before going into wrapper self.trainer.model_connector.copy_trainer_model_properties(model) @@ -221,7 +233,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results): mp_queue.put(results) def configure_ddp( - self, model: LightningModule, device_ids: List[int] + self, model: LightningModule, device_ids: List[int] ) -> DistributedDataParallel: model = self.ddp_plugin.configure_ddp(model, device_ids) return model @@ -251,3 +263,17 @@ def sync_tensor(self, def get_reference_model(self, model) -> LightningModule: return self.ddp_plugin.get_model_from_plugin(model) + + @property + def distributed_sampler_kwargs(self): + distributed_sampler_kwargs = dict( + num_replicas=self.trainer.num_nodes * self.trainer.num_processes, + rank=self.trainer.global_rank + ) + if self.ddp_plugin is not None: + distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs) + return distributed_sampler_kwargs + + @property + def require_distributed_sampler(self): + return True diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py index 28817c6845f5b..50267afa525dc 100644 --- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py @@ -23,6 +23,7 @@ from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.distributed.dist import LightningDistributed +from pytorch_lightning.plugins.rpc_plugin import RPCPlugin from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available @@ -62,9 +63,11 @@ def set_world_ranks(self, process_idx): self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes - def model_to_device(self, model, process_idx): + def init_device(self, process_idx): self.trainer.root_gpu = process_idx torch.cuda.set_device(self.trainer.root_gpu) + + def model_to_device(self, model): model.cuda(self.trainer.root_gpu) def get_device_ids(self): @@ -136,6 +139,15 @@ def ddp_train(self, process_idx, model): self.trainer.is_slurm_managing_tasks ) + if isinstance(self.ddp_plugin, RPCPlugin): + if not self.ddp_plugin.is_main_rpc_process: + self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer) + self.ddp_plugin.exit_rpc_process() + if self.ddp_plugin.return_after_exit_rpc_process: + return + else: + 
self.ddp_plugin.on_main_rpc_connection(self.trainer) + # call setup after the ddp process has connected self.trainer.call_setup_hook(model) @@ -151,12 +163,14 @@ def ddp_train(self, process_idx, model): model = self.configure_sync_batchnorm(model) # move the model to the correct device - self.model_to_device(model, process_idx) + self.model_to_device(model) # CHOOSE OPTIMIZER # allow for lr schedulers as well self.setup_optimizers(model) + self.ddp_plugin.on_after_setup_optimizers(self.trainer) + # set model properties before going into wrapper self.trainer.model_connector.copy_trainer_model_properties(model) @@ -183,7 +197,7 @@ def ddp_train(self, process_idx, model): return results def configure_ddp( - self, model: LightningModule, device_ids: List[int] + self, model: LightningModule, device_ids: List[int] ) -> DistributedDataParallel: model = self.ddp_plugin.configure_ddp(model, device_ids) return model @@ -213,3 +227,17 @@ def sync_tensor(self, def get_reference_model(self, model) -> LightningModule: return self.ddp_plugin.get_model_from_plugin(model) + + @property + def distributed_sampler_kwargs(self): + distributed_sampler_kwargs = dict( + num_replicas=self.trainer.num_nodes * self.trainer.num_processes, + rank=self.trainer.global_rank + ) + if self.ddp_plugin is not None: + distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs) + return distributed_sampler_kwargs + + @property + def require_distributed_sampler(self): + return True diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py index a06d0b82d6d15..7db2d5a309d9c 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py @@ -25,6 +25,7 @@ from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.distributed import LightningDistributed +from pytorch_lightning.plugins.rpc_plugin import RPCPlugin from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.cloud_io import load as pl_load @@ -109,6 +110,9 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0 # set warning rank rank_zero_only.rank = self.trainer.global_rank + # Initialize cuda device + self.init_device(process_idx, is_master) + # set up server using proc 0's ip address # try to init for 20 times at max in case ports are taken # where to store ip_table @@ -119,6 +123,15 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0 self.trainer.is_slurm_managing_tasks ) + if isinstance(self.ddp_plugin, RPCPlugin): + if not self.ddp_plugin.is_main_rpc_process: + self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer) + self.ddp_plugin.exit_rpc_process() + if self.ddp_plugin.return_after_exit_rpc_process: + return + else: + self.ddp_plugin.on_main_rpc_connection(self.trainer) + # call setup after the ddp process has connected self.trainer.call_setup_hook(model) @@ -134,12 +147,14 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0 model = self.configure_sync_batchnorm(model) # move the model to the correct device - self.model_to_device(model, process_idx, is_master) + self.model_to_device(model) # CHOOSE OPTIMIZER # allow for lr schedulers as well self.setup_optimizers(model) + 
self.ddp_plugin.on_after_setup_optimizers(self.trainer) + # set model properties before going into wrapper self.trainer.model_connector.copy_trainer_model_properties(model) @@ -174,10 +189,12 @@ def set_world_ranks(self, process_idx): self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes - def model_to_device(self, model, process_idx, is_master): + def init_device(self, process_idx, is_master): gpu_idx = self.trainer.data_parallel_device_ids[self.trainer.local_rank] self.trainer.root_gpu = gpu_idx torch.cuda.set_device(self.trainer.root_gpu) + + def model_to_device(self, model): model.cuda(self.trainer.root_gpu) def get_device_ids(self): @@ -248,7 +265,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results): mp_queue.put(last_path) def configure_ddp( - self, model: LightningModule, device_ids: List[int] + self, model: LightningModule, device_ids: List[int] ) -> DistributedDataParallel: model = self.ddp_plugin.configure_ddp(model, device_ids) return model @@ -278,3 +295,17 @@ def sync_tensor(self, def get_reference_model(self, model) -> LightningModule: return self.ddp_plugin.get_model_from_plugin(model) + + @property + def distributed_sampler_kwargs(self): + distributed_sampler_kwargs = dict( + num_replicas=self.trainer.num_nodes * self.trainer.num_processes, + rank=self.trainer.global_rank + ) + if self.ddp_plugin is not None: + distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs) + return distributed_sampler_kwargs + + @property + def require_distributed_sampler(self): + return True diff --git a/pytorch_lightning/accelerators/dp_accelerator.py b/pytorch_lightning/accelerators/dp_accelerator.py index 4b4e1eac8a66c..a7f3c260e682c 100644 --- a/pytorch_lightning/accelerators/dp_accelerator.py +++ b/pytorch_lightning/accelerators/dp_accelerator.py @@ -181,3 +181,7 @@ def get_reference_model(self, model) -> LightningModule: if isinstance(model, LightningDataParallel): return model.module return model + + @property + def require_distributed_sampler(self): + return False diff --git a/pytorch_lightning/accelerators/gpu_accelerator.py b/pytorch_lightning/accelerators/gpu_accelerator.py index b12d275c8ac26..abc065cd39ed4 100644 --- a/pytorch_lightning/accelerators/gpu_accelerator.py +++ b/pytorch_lightning/accelerators/gpu_accelerator.py @@ -129,3 +129,7 @@ def sync_tensor(self, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return tensor + + @property + def require_distributed_sampler(self): + return False diff --git a/pytorch_lightning/accelerators/horovod_accelerator.py b/pytorch_lightning/accelerators/horovod_accelerator.py index 460f5a83d2582..93983369f17a9 100644 --- a/pytorch_lightning/accelerators/horovod_accelerator.py +++ b/pytorch_lightning/accelerators/horovod_accelerator.py @@ -18,7 +18,7 @@ from torch.optim.lr_scheduler import _LRScheduler from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp -from pytorch_lightning.utilities import AMPType, HOROVOD_AVAILABLE +from pytorch_lightning.utilities import HOROVOD_AVAILABLE, AMPType from pytorch_lightning.utilities.distributed import rank_zero_only if HOROVOD_AVAILABLE: @@ -206,3 +206,11 @@ def sync_tensor(self, # sync all processes before reduction hvd.join() return hvd.allreduce(tensor, op=reduce_op) + + @property + def distributed_sampler_kwargs(self): + return 
dict(num_replicas=hvd.size(), rank=hvd.rank()) + + @property + def require_distributed_sampler(self): + return True diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py index cd6b99fa64eef..a7752e42a96cf 100644 --- a/pytorch_lightning/accelerators/tpu_accelerator.py +++ b/pytorch_lightning/accelerators/tpu_accelerator.py @@ -364,3 +364,11 @@ def on_save(self, checkpoint): https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors """ return move_data_to_device(checkpoint, torch.device("cpu")) + + @property + def distributed_sampler_kwargs(self): + return dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) + + @property + def require_distributed_sampler(self): + return True diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index eb669736ada3a..1354f7f5056b3 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -33,6 +33,7 @@ from pytorch_lightning import _logger as log from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn +from pytorch_lightning.plugins.rpc_plugin import RPCPlugin from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -548,7 +549,13 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath) ) last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}") - self._save_model(last_filepath, trainer, pl_module) + accelerator_backend = trainer.accelerator_backend + + if accelerator_backend is not None and accelerator_backend.rpc_enabled: + # RPCPlugin manages saving all model states + accelerator_backend.ddp_plugin.rpc_save_model(self._save_model, last_filepath, trainer, pl_module) + else: + self._save_model(last_filepath, trainer, pl_module) if ( self.last_model_path and self.last_model_path != last_filepath diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index f07f467810c09..dc63231ba6ccb 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -113,6 +113,17 @@ def __optimizer_step(self, *args, closure: Optional[Callable] = None, profiler_n with trainer.profiler.profile(profiler_name): optimizer.step(closure=closure, *args, **kwargs) + accelerator_backend = trainer.accelerator_backend + if accelerator_backend is not None and accelerator_backend.rpc_enabled: + if accelerator_backend.ddp_plugin.is_main_rpc_process: + # Initialize optimizer step on main process + accelerator_backend.ddp_plugin.worker_optimizer_step( + model=model, + opt_idx=self._optimizer_idx, + *args, + **kwargs + ) + trainer.train_loop.on_before_zero_grad(self) model.optimizer_zero_grad( diff --git a/pytorch_lightning/distributed/dist.py b/pytorch_lightning/distributed/dist.py index 37706523c8fdd..429121f71feeb 100644 --- a/pytorch_lightning/distributed/dist.py +++ b/pytorch_lightning/distributed/dist.py @@ -12,10 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import io -import torch from typing import Any + +import torch from torch import distributed as torch_distrib +from pytorch_lightning.utilities import GROUP_AVAILABLE + +WORLD = None +if GROUP_AVAILABLE: + from torch.distributed import group + WORLD = group.WORLD + class LightningDistributed: @@ -23,27 +31,32 @@ def __init__(self, rank=None, device=None): self.rank = rank self.device = device - def broadcast(self, obj: Any): + def broadcast(self, obj: Any, group=WORLD): if self.rank == 0: - self._emit(obj) + self._emit(obj, group) else: - obj = self._receive() + obj = self._receive(group) return obj - def _emit(self, obj): + def _broadcast(self, tensor, src=0, group=WORLD): + if group is None: + return torch_distrib.broadcast(tensor, src=src) + return torch_distrib.broadcast(tensor, src=0, group=group) + + def _emit(self, obj: Any, group=WORLD): buffer = io.BytesIO() torch.save(obj, buffer) data = bytearray(buffer.getbuffer()) length_tensor = torch.tensor([len(data)]).long().to(self.device) - length_tensor = torch_distrib.broadcast(length_tensor, src=0) + length_tensor = self._broadcast(length_tensor, src=0, group=group) data_tensor = torch.ByteTensor(data).to(self.device) - data_tensor = torch_distrib.broadcast(data_tensor, src=0) + data_tensor = self._broadcast(data_tensor, src=0, group=group) - def _receive(self): + def _receive(self, group=WORLD): length_tensor = torch.tensor([0]).long().to(self.device) - torch_distrib.broadcast(length_tensor, src=0) + self._broadcast(length_tensor, src=0, group=group) data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8).to(self.device) - torch_distrib.broadcast(data_tensor, src=0) + self._broadcast(data_tensor, src=0, group=group) buffer = io.BytesIO(data_tensor.cpu().numpy()) obj = torch.load(buffer) return obj diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py index 7e481dfade421..281074cb37813 100644 --- a/pytorch_lightning/plugins/ddp_plugin.py +++ b/pytorch_lightning/plugins/ddp_plugin.py @@ -1,6 +1,6 @@ import os from contextlib import contextmanager -from typing import Any, Dict, List, Union, Optional +from typing import Any, Dict, List, Optional, Union import torch.distributed as torch_distrib from torch.optim import Optimizer @@ -112,6 +112,12 @@ def on_before_forward(self, model, *args): def optimizer_state(self, optimizer: Optimizer) -> dict: return optimizer.state_dict() + def on_after_setup_optimizers(self, trainer): + """ + Called after optimizers have been set-up. This is useful for doing any configuration options in RPC, or + state sharding. + """ + def get_model_from_plugin( self, model: Union[LightningDistributedDataParallel, LightningModule] @@ -148,3 +154,15 @@ def on_before_manual_backward(self, model: LightningDistributedDataParallel, out def on_after_manual_backward(self, model: LightningDistributedDataParallel): model.reducer_reset_hooks() + + def distributed_sampler_kwargs(self, distributed_sampler_kwargs): + return distributed_sampler_kwargs + + @property + def data_parallel_group(self): + """ + Return the group that this process exists in. By default, this is the world size. + Useful for when additional parallel groups have been created, to select certain processes. + Returns: The ProcessGroup this process exists in. 
+ """ + return torch_distrib.group.WORLD diff --git a/pytorch_lightning/plugins/rpc_plugin.py b/pytorch_lightning/plugins/rpc_plugin.py new file mode 100644 index 0000000000000..776ac17c3d4eb --- /dev/null +++ b/pytorch_lightning/plugins/rpc_plugin.py @@ -0,0 +1,118 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Optional + +import torch + +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.plugins.ddp_plugin import DDPPlugin +from pytorch_lightning.utilities import RPC_AVAILABLE + +if RPC_AVAILABLE: + from torch.distributed import rpc + + +class RPCPlugin(DDPPlugin): + """ + Backbone for RPC Plugins built on top of DDP. + RPC introduces different communication behaviour than DDP. Unlike DDP, processes potentially are not + required to run the same code as the main process. + This leads to edge cases where logic needs to be re-defined. This class contains special cases + that need to be addressed when using RPC communication when building custom RPC Plugins. + """ + + def __init__(self, **kwargs): + self.rpc_initialized = False + super().__init__(**kwargs) + + def init_rpc_connection(self, + global_rank: int, + world_size: int) -> None: + os.environ['MASTER_PORT'] = os.getenv('RPC_MASTER_PORT', '15000') + rpc.init_rpc(f"worker{global_rank}", rank=global_rank, world_size=world_size) + self.rpc_initialized = True + + def rpc_save_model(self, + save_model_fn, + last_filepath, + trainer, + pl_module) -> None: + """ + Override to save model to disk. + This is required as the main process will be required to handle aggregating model states from RPC processes. + Args: + save_model_fn: The saving function to save final model. + last_filepath: The filepath to save the model to. + trainer: The trainer object. + pl_module: The LightningModule. + """ + raise NotImplementedError + + def on_main_rpc_connection(self, trainer) -> None: + """ + Called when main rpc connection has been established. + Args: + trainer: The trainer object. + """ + raise NotImplementedError + + def on_accelerator_exit_rpc_process(self, trainer) -> None: + """ + Called to exit RPC process within the accelerator, that is being managed by main process. + Args: + trainer: The trainer object. + """ + self.exit_rpc_process() + + def exit_rpc_process(self): + if self.rpc_initialized: + torch.distributed.rpc.shutdown() + self.rpc_initialized = False + + @property + def return_after_exit_rpc_process(self) -> bool: + """ + Override to decide whether to skip train/test function after shutdown completed. + Usually RPC shutdown is a join/exit function, afterwards we want to exit the process. + Returns: Whether to return after rpc exit. + """ + raise NotImplementedError + + def worker_optimizer_step(self, + model: LightningModule, + opt_idx: int, + *args, + **kwargs) -> None: + """ + Called when optimizer step is run on the main process. Used to signal any RPC workers to run optimizer step. + Args: + model: The LightningModule. 
+ opt_idx: The idx of the optimizer to carry out step on. + """ + raise NotImplementedError + + @property + def is_main_rpc_process(self) -> bool: + """ + Override to add logic to determine current process is main RPC process. + """ + raise NotImplementedError + + def barrier(self, name: Optional[str] = None) -> None: + """ + Override to define distributed sync communication. This needs to be handled differently due to + the RPC connection managing certain processes at the same time. + """ + raise NotImplementedError diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 4a7b14d0b1fe9..3bb444622cebc 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -16,25 +16,19 @@ import platform from abc import ABC from copy import deepcopy -from typing import Union, List, Tuple, Callable, Optional, Iterable +from typing import Callable, Iterable, List, Optional, Tuple, Union from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.core import LightningModule -from pytorch_lightning.utilities import rank_zero_warn, TPU_AVAILABLE, HOROVOD_AVAILABLE +from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.data import has_iterable_dataset, has_len from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_utils import is_overridden -if TPU_AVAILABLE: - import torch_xla.core.xla_model as xm - -if HOROVOD_AVAILABLE: - import horovod.torch as hvd - class TrainerDataLoadingMixin(ABC): @@ -100,8 +94,7 @@ def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader: if not is_dataloader or is_iterable_ds: return dataloader - is_in_dist = self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu - need_dist_sampler = is_in_dist and not isinstance(dataloader.sampler, DistributedSampler) + need_dist_sampler = self.require_distributed_sampler and not isinstance(dataloader.sampler, DistributedSampler) if self.replace_sampler_ddp and need_dist_sampler: if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)): raise MisconfigurationException( @@ -131,20 +124,7 @@ def replace_sampler(self, dataloader, sampler): return dataloader def _get_distributed_sampler(self, dataloader, shuffle): - if self.use_tpu: - kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) - elif self.use_horovod: - kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank()) - else: - world_size = { - "ddp": self.num_nodes * self.num_processes, - "ddp_spawn": self.num_nodes * self.num_processes, - "ddp2": self.num_nodes, - "ddp_cpu": self.num_processes * self.num_nodes - } - assert self.distributed_backend is not None - kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank) - + kwargs = self.distributed_sampler_kwargs kwargs['shuffle'] = shuffle and not self.overfit_batches sampler = DistributedSampler(dataloader.dataset, **kwargs) return sampler diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index f315bf9df819c..355bbad3a037e 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -27,10 +27,16 @@ from 
pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector from pytorch_lightning.trainer.connectors.model_connector import ModelConnector from pytorch_lightning.trainer.states import TrainerState -from pytorch_lightning.utilities import argparse_utils +from pytorch_lightning.utilities import HOROVOD_AVAILABLE, TPU_AVAILABLE, argparse_utils, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.model_utils import is_overridden +if TPU_AVAILABLE: + import torch_xla.core.xla_model as xm + +if HOROVOD_AVAILABLE: + import horovod.torch as hvd + class TrainerProperties(ABC): @@ -242,6 +248,35 @@ def __setstate__(self, d): # wrap optimizers in enable_pl_optimzer is True self.convert_to_lightning_optimizers() + @property + def require_distributed_sampler(self): + if self.accelerator_backend is not None: + return self.accelerator_backend.require_distributed_sampler + return self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu + + @property + def distributed_sampler_kwargs(self): + if self.accelerator_backend is not None: + return self.accelerator_backend.distributed_sampler_kwargs + + if self.use_tpu: + kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) + + elif self.use_horovod: + kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank()) + + else: + world_size = { + "ddp": self.num_nodes * self.num_processes, + "ddp_spawn": self.num_nodes * self.num_processes, + "ddp2": self.num_nodes, + "ddp_cpu": self.num_processes * self.num_nodes + } + assert self.distributed_backend is not None + kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank) + + return kwargs + # Used to represent the concrete type TrainerProperties class methods are called on. 
_T = TypeVar('_T', bound=TrainerProperties) diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 1e2eeea9f456c..3a04a325905a9 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -51,6 +51,8 @@ def _module_available(module_path: str) -> bool: TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() FAIRSCALE_AVAILABLE = platform.system() != 'Windows' and _module_available('fairscale.nn.data_parallel') +RPC_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.rpc') +GROUP_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.group') FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps diff --git a/tests/backends/test_accelerator_connector.py b/tests/backends/test_accelerator_connector.py index 7eeada3d5ddd1..551de95c7e480 100644 --- a/tests/backends/test_accelerator_connector.py +++ b/tests/backends/test_accelerator_connector.py @@ -288,6 +288,7 @@ def test_accelerator_choice_ddp_cpu_custom_cluster(tmpdir): """ Test that we choose the custom cluster even when SLURM or TE flags are around """ + class CustomCluster(ClusterEnvironment): def master_address(self): return 'asdf' @@ -322,7 +323,11 @@ def on_fit_start(self, trainer, pl_module): @mock.patch('torch.cuda.device_count', return_value=0) def test_custom_accelerator(tmpdir): class Accel(Accelerator): - def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True) -> None: + def init_ddp_connection( + self, + global_rank: int, + world_size: int, + is_slurm_managing_tasks: bool = True) -> None: pass class CB(Callback): diff --git a/tests/plugins/test_rpc_plugin.py b/tests/plugins/test_rpc_plugin.py new file mode 100644 index 0000000000000..7411fe9774334 --- /dev/null +++ b/tests/plugins/test_rpc_plugin.py @@ -0,0 +1,124 @@ +import os +from typing import Optional +from unittest import mock + +import pytest +import torch + +from pytorch_lightning import Trainer, LightningModule +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.plugins.rpc_plugin import RPCPlugin +from pytorch_lightning.utilities import RPC_AVAILABLE +from tests.base.boring_model import BoringModel + + +@mock.patch.dict( + os.environ, + { + "CUDA_VISIBLE_DEVICES": "0,1", + "SLURM_NTASKS": "2", + "SLURM_JOB_NAME": "SOME_NAME", + "SLURM_NODEID": "0", + "LOCAL_RANK": "0", + "SLURM_LOCALID": "0", + }, +) +@mock.patch("torch.cuda.device_count", return_value=2) +@pytest.mark.parametrize( + ["ddp_backend", "gpus", "num_processes"], + [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], +) +@pytest.mark.skipif(not RPC_AVAILABLE, reason="RPC is not available") +def test_rpc_choice(tmpdir, ddp_backend, gpus, num_processes): + class CB(Callback): + def on_fit_start(self, trainer, pl_module): + assert isinstance(trainer.accelerator_backend.ddp_plugin, RPCPlugin) + raise RuntimeError('finished plugin check') + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + gpus=gpus, + num_processes=num_processes, + distributed_backend=ddp_backend, + callbacks=[CB()], + plugins=[RPCPlugin()] + ) + + with pytest.raises(RuntimeError, match='finished plugin check'): + trainer.fit(model) + + +class CustomRPCPlugin(RPCPlugin): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.rpc_save_model_count = 0 + self.on_main_rpc_connect_count = 0 + self.worker_optimizer_step_count = 0 + 
self.is_main_rpc_process_count = 0 + self.on_exit_rpc_process_count = 0 + self.return_after_exit_rpc_process_count = 0 + + def on_accelerator_exit_rpc_process(self, trainer) -> None: + self.on_exit_rpc_process_count += 1 + + def rpc_save_model(self, save_model_fn, last_filepath, trainer, pl_module) -> None: + self.rpc_save_model_count += 1 + + def on_main_rpc_connection(self, trainer) -> None: + self.on_main_rpc_connect_count += 1 + + def worker_optimizer_step(self, model: LightningModule, opt_idx: int, *args, **kwargs) -> None: + self.worker_optimizer_step_count += 1 + + @property + def is_main_rpc_process(self) -> bool: + self.is_main_rpc_process_count += 1 + return torch.distributed.get_rank() == 0 + + @property + def return_after_exit_rpc_process(self) -> bool: + self.return_after_exit_rpc_process_count += 1 + return False + + def barrier(self, name: Optional[str] = None) -> None: + return + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(not RPC_AVAILABLE, reason="RPC is not available") +@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', + reason="test should be run outside of pytest") +def test_rpc_function_calls_ddp(tmpdir): + model = BoringModel() + plugin = CustomRPCPlugin() + max_epochs = 2 + limit_train_batches = 2 + trainer = Trainer( + limit_train_batches=limit_train_batches, + limit_val_batches=2, + max_epochs=max_epochs, + gpus=2, + distributed_backend='ddp', + plugins=[plugin] + ) + + trainer.fit(model) + if trainer.global_rank == 0: # Main process + assert plugin.rpc_save_model_count == max_epochs + assert plugin.on_main_rpc_connect_count == 1 + assert plugin.worker_optimizer_step_count == max_epochs * limit_train_batches + # Call once at init, and at optim step + assert plugin.is_main_rpc_process_count == 1 + plugin.worker_optimizer_step_count + assert plugin.on_exit_rpc_process_count == 0 + else: # Worker process + assert plugin.rpc_save_model_count == max_epochs + assert plugin.on_main_rpc_connect_count == 0 + # Never signaled by worker, only by main process + assert plugin.worker_optimizer_step_count == 0 + # Call once at init, and at optim step + assert plugin.is_main_rpc_process_count == 1 + (max_epochs * limit_train_batches) + # Called at init + assert plugin.on_exit_rpc_process_count == 1 diff --git a/tests/special_tests.sh b/tests/special_tests.sh index a87e380dbe275..7ea0f77ca2971 100644 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -15,3 +15,4 @@ export PL_RUNNING_SPECIAL_TESTS=1 DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no" python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp +python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp \ No newline at end of file
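
The RPCPlugin introduced in pytorch_lightning/plugins/rpc_plugin.py above is an abstract base: rpc_save_model, on_main_rpc_connection, worker_optimizer_step, is_main_rpc_process, return_after_exit_rpc_process and barrier all raise NotImplementedError and must be supplied by a concrete subclass before the plugin can be handed to the Trainer via plugins=[...]. The sketch below condenses the CustomRPCPlugin used in tests/plugins/test_rpc_plugin.py from this diff into a minimal subclass; the name MyRPCPlugin and the mostly no-op bodies are illustrative placeholders, not part of the patch.

from typing import Optional

import torch

from pytorch_lightning import Trainer
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin


class MyRPCPlugin(RPCPlugin):
    """Hypothetical subclass filling in the members left abstract by RPCPlugin."""

    def rpc_save_model(self, save_model_fn, last_filepath, trainer, pl_module) -> None:
        # A real plugin would first aggregate model state from its RPC workers;
        # here we simply delegate to the checkpoint callback's save function.
        save_model_fn(last_filepath, trainer, pl_module)

    def on_main_rpc_connection(self, trainer) -> None:
        # Set up RPC workers once the main process has established its connection.
        pass

    def worker_optimizer_step(self, model: LightningModule, opt_idx: int, *args, **kwargs) -> None:
        # Signal RPC workers to run the optimizer step for `opt_idx`.
        pass

    @property
    def is_main_rpc_process(self) -> bool:
        return torch.distributed.get_rank() == 0

    @property
    def return_after_exit_rpc_process(self) -> bool:
        # False: worker processes continue through ddp_train after RPC shutdown
        # instead of returning early.
        return False

    def barrier(self, name: Optional[str] = None) -> None:
        # No-op here; a real plugin would synchronize its main RPC processes.
        pass


# Mirrors the test added in this diff (tests/plugins/test_rpc_plugin.py):
# model = BoringModel()
# trainer = Trainer(gpus=2, distributed_backend='ddp', plugins=[MyRPCPlugin()])
# trainer.fit(model)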