[feat] pp 1/n (Lightning-AI#5016)
* Added changes for RPC plugin

* Add missing kwargs

* Fix code format

* Loading refactors by introducing is_distributed var, fix optimizer step flow

* Add rpc guard

* Added docstrings and typing

* resolve comments

* Add additional rpc hook, refactor name of exit process hook for clarity

* remove annotation

* Modify behaviour to allow optional return, add test for rpc plugin

* resolve tests

* rename is_ddp_based

* update

* update for windows

* update

* resolve test

* code smell

* Revert back to init_ddp_connection for backwards compat

* Swap to explicit name for property

* Add missing speed parity increase for CI variability, fix call counts for child process

Co-authored-by: tchaton <thomas@grid.ai>
SeanNaren and tchaton committed Dec 8, 2020
1 parent ddd3eda commit ee9b3fe
Showing 23 changed files with 560 additions and 60 deletions.
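
Note: the new plugin module, pytorch_lightning/plugins/rpc_plugin.py, is among the changed files not reproduced below. As an aid to reading the accelerator diffs, here is a minimal sketch (not part of the diff) of the hook surface those accelerators call into, inferred purely from the call sites in this commit; the DDPPlugin base and the no-op bodies are assumptions rather than the shipped implementation.

# Sketch of the RPCPlugin hooks referenced by the accelerator diffs below.
# Only the names come from call sites; the bodies are placeholders.
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin


class RPCPluginSketch(DDPPlugin):

    @property
    def is_main_rpc_process(self) -> bool:
        # Worker RPC processes exit ddp_train early; the main RPC process
        # continues with the usual DDP training flow.
        return True

    @property
    def return_after_exit_rpc_process(self) -> bool:
        # If True, a worker rank returns from ddp_train right after exit_rpc_process().
        return False

    def on_accelerator_exit_rpc_process(self, trainer) -> None:
        # Called on non-main RPC processes before they leave ddp_train.
        pass

    def exit_rpc_process(self) -> None:
        # Shut down the RPC framework on this rank.
        pass

    def on_main_rpc_connection(self, trainer) -> None:
        # Called on the main RPC process once the DDP connection is established.
        pass

    def on_after_setup_optimizers(self, trainer) -> None:
        # Fired right after the accelerator builds optimizers and LR schedulers.
        pass

    def barrier(self) -> None:
        # RPC-aware replacement for torch.distributed.barrier() on RPC runs.
        pass
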
1 change: 1 addition & 0 deletions benchmarks/test_sharded_parity.py
@@ -161,6 +161,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
gpus=2,
accelerator='ddp_spawn',
model_cls=SeedTrainLoaderManualModel,
max_percent_speed_diff=0.25 # Increase speed diff since only 2 GPUs sharding 2 optimizers
)


17 changes: 13 additions & 4 deletions pytorch_lightning/accelerators/accelerator.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager
from enum import Enum
from typing import Any, Optional, Union
@@ -21,10 +20,8 @@
from torch.optim import Optimizer

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict

if torch.distributed.is_available():
@@ -222,6 +219,18 @@ def __setstate__(self, d):
def on_save(self, checkpoint):
return checkpoint

@property
def rpc_enabled(self):
return self.ddp_plugin is not None and isinstance(self.ddp_plugin, RPCPlugin)

@property
def distributed_sampler_kwargs(self):
raise NotImplementedError

@property
def require_distributed_sampler(self):
raise NotImplementedError

@contextmanager
def block_ddp_plugin_sync_behaviour(self):
"""
4 changes: 4 additions & 0 deletions pytorch_lightning/accelerators/cpu_accelerator.py
@@ -90,3 +90,7 @@ def sync_tensor(self,
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
return tensor

@property
def require_distributed_sampler(self):
return False
37 changes: 34 additions & 3 deletions pytorch_lightning/accelerators/ddp2_accelerator.py
@@ -23,6 +23,7 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available

@@ -101,9 +102,11 @@ def set_world_ranks(self, process_idx):
def broadcast(self, obj, src=0):
return self.dist.broadcast(obj)

def model_to_device(self, model, process_idx):
def init_device(self, process_idx):
self.trainer.root_gpu = process_idx
torch.cuda.set_device(self.trainer.root_gpu)

def model_to_device(self, model):
model.cuda(self.trainer.root_gpu)

def get_device_ids(self):
@@ -133,6 +136,9 @@ def ddp_train(self, process_idx, mp_queue, model):
# set warning rank
rank_zero_only.rank = self.trainer.global_rank

# Initialize cuda device
self.init_device(process_idx)

# set up server using proc 0's ip address
# try to init for 20 times at max in case ports are taken
# where to store ip_table
@@ -143,6 +149,15 @@ def ddp_train(self, process_idx, mp_queue, model):
self.trainer.is_slurm_managing_tasks
)

if isinstance(self.ddp_plugin, RPCPlugin):
if not self.ddp_plugin.is_main_rpc_process:
self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer)
self.ddp_plugin.exit_rpc_process()
if self.ddp_plugin.return_after_exit_rpc_process:
return
else:
self.ddp_plugin.on_main_rpc_connection(self.trainer)

# call setup after the ddp process has connected
self.trainer.call_setup_hook(model)

@@ -158,12 +173,14 @@ def ddp_train(self, process_idx, mp_queue, model):
model = self.configure_sync_batchnorm(model)

# move the model to the correct device
self.model_to_device(model, process_idx)
self.model_to_device(model)

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.setup_optimizers(model)

self.ddp_plugin.on_after_setup_optimizers(self.trainer)

# set model properties before going into wrapper
self.trainer.model_connector.copy_trainer_model_properties(model)

@@ -189,7 +206,7 @@ def ddp_train(self, process_idx, mp_queue, model):
return results

def configure_ddp(
self, model: LightningModule, device_ids: List[int]
self, model: LightningModule, device_ids: List[int]
) -> DistributedDataParallel:
model = self.ddp_plugin.configure_ddp(model, device_ids)
return model
@@ -219,3 +236,17 @@ def sync_tensor(self,

def get_reference_model(self, model) -> LightningModule:
return self.ddp_plugin.get_model_from_plugin(model)

@property
def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(
num_replicas=self.trainer.num_nodes,
rank=self.trainer.global_rank
)
if self.ddp_plugin is not None:
distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs)
return distributed_sampler_kwargs

@property
def require_distributed_sampler(self):
return True
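
The two properties just added (require_distributed_sampler and distributed_sampler_kwargs) are presumably consumed when the trainer sets up dataloaders; that wiring sits in files not shown on this page. A hypothetical sketch of that consumption, assuming only the accelerator interface defined in this diff (the helper name and structure are illustrative, not the actual trainer code):

# Hypothetical trainer-side helper; everything except the two accelerator
# properties is illustrative.
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def maybe_add_distributed_sampler(dataloader: DataLoader, accelerator) -> DataLoader:
    if not accelerator.require_distributed_sampler:
        return dataloader
    # num_replicas / rank come from the accelerator and may be adjusted by its
    # ddp_plugin (e.g. an RPC pipeline plugin shrinking the data-parallel world).
    sampler = DistributedSampler(dataloader.dataset, **accelerator.distributed_sampler_kwargs)
    return DataLoader(
        dataloader.dataset,
        batch_size=dataloader.batch_size,
        sampler=sampler,
        num_workers=dataloader.num_workers,
    )
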
46 changes: 39 additions & 7 deletions pytorch_lightning/accelerators/ddp_accelerator.py
@@ -27,6 +27,7 @@
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import find_free_network_port, rank_zero_only, sync_ddp_if_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -162,8 +163,11 @@ def _step(self, args):
return output

def barrier(self, name: Optional[str] = None):
if torch_distrib.is_initialized():
torch_distrib.barrier()
if self.rpc_enabled:
# Allow RPC to handle barrier on main RPC processes
self.ddp_plugin.barrier()
elif torch_distrib.is_initialized():
torch_distrib.barrier(group=self.ddp_plugin.data_parallel_group)

def _check_can_spawn_children(self):
if self._has_spawned_children:
@@ -177,9 +181,11 @@ def set_world_ranks(self, process_idx):
self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes

def model_to_device(self, model, process_idx):
def init_device(self, process_idx):
self.trainer.root_gpu = self.trainer.data_parallel_device_ids[self.trainer.local_rank]
torch.cuda.set_device(self.trainer.root_gpu)

def model_to_device(self, model):
model.cuda(self.trainer.root_gpu)

def get_device_ids(self):
@@ -192,12 +198,12 @@ def on_train_end(self):
def early_stopping_should_stop(self, pl_module):
stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
torch_distrib.all_reduce(stop, op=torch_distrib.reduce_op.SUM)
torch_distrib.barrier()
self.barrier('early_stopping')
should_stop = stop == self.trainer.world_size
return should_stop

def broadcast(self, obj, src=0):
return self.dist.broadcast(obj)
return self.dist.broadcast(obj, group=self.ddp_plugin.data_parallel_group)

def ddp_train(self, process_idx, model):
"""
@@ -226,6 +232,9 @@ def ddp_train(self, process_idx, model):
# set warning rank
rank_zero_only.rank = self.trainer.global_rank

# Initialize cuda device
self.init_device(process_idx)

# set up server using proc 0's ip address
# try to init for 20 times at max in case ports are taken
# where to store ip_table
@@ -236,6 +245,15 @@ def ddp_train(self, process_idx, model):
self.trainer.is_slurm_managing_tasks
)

if isinstance(self.ddp_plugin, RPCPlugin):
if not self.ddp_plugin.is_main_rpc_process:
self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer)
self.ddp_plugin.exit_rpc_process()
if self.ddp_plugin.return_after_exit_rpc_process:
return
else:
self.ddp_plugin.on_main_rpc_connection(self.trainer)

# call setup after the ddp process has connected
self.trainer.call_setup_hook(model)

@@ -251,7 +269,7 @@ def ddp_train(self, process_idx, model):
model = self.configure_sync_batchnorm(model)

# move the model to the correct device
self.model_to_device(model, process_idx)
self.model_to_device(model)

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
@@ -284,7 +302,7 @@ def ddp_train(self, process_idx, model):
return results

def configure_ddp(
self, model: LightningModule, device_ids: List[int]
self, model: LightningModule, device_ids: List[int]
) -> DistributedDataParallel:
model = self.ddp_plugin.configure_ddp(model, device_ids)
return model
@@ -317,3 +335,17 @@ def sync_tensor(self,

def get_reference_model(self, model) -> LightningModule:
return self.ddp_plugin.get_model_from_plugin(model)

@property
def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(
num_replicas=self.trainer.num_nodes * self.trainer.num_processes,
rank=self.trainer.global_rank
)
if self.ddp_plugin is not None:
distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs)
return distributed_sampler_kwargs

@property
def require_distributed_sampler(self):
return True
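
Two extension points in the hunks above are worth calling out: barrier()/broadcast() now pass group=self.ddp_plugin.data_parallel_group (with an RPC-aware barrier on RPC runs), and distributed_sampler_kwargs is routed through the plugin before being returned. A rough, illustrative-only sketch of a plugin using both, under the assumption that the DDPPlugin base exposes these hooks the way the accelerator code implies (this class is not part of the commit):

import torch.distributed as torch_distrib

from pytorch_lightning.plugins.ddp_plugin import DDPPlugin


class PipelineAwareDDPPluginSketch(DDPPlugin):
    """Illustrative only: keeps one rank per pipeline in the data-parallel group."""

    def __init__(self, num_pipeline_stages: int = 2, **kwargs):
        super().__init__(**kwargs)
        self.num_pipeline_stages = num_pipeline_stages
        self._dp_group = None  # created once torch.distributed is initialized

    @property
    def data_parallel_group(self):
        # Handed by the accelerator to torch_distrib.barrier()/broadcast();
        # None falls back to the default process group.
        return self._dp_group

    def init_data_parallel_group(self):
        dp_ranks = list(range(0, torch_distrib.get_world_size(), self.num_pipeline_stages))
        self._dp_group = torch_distrib.new_group(ranks=dp_ranks)

    def distributed_sampler_kwargs(self, distributed_sampler_kwargs):
        # Shard data across data-parallel ranks only, not across pipeline stages.
        distributed_sampler_kwargs["num_replicas"] //= self.num_pipeline_stages
        distributed_sampler_kwargs["rank"] //= self.num_pipeline_stages
        return distributed_sampler_kwargs
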
28 changes: 27 additions & 1 deletion pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
@@ -24,6 +24,7 @@
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import (
find_free_network_port,
@@ -107,6 +108,15 @@ def ddp_train(self, process_idx, mp_queue, model):
self.trainer.is_slurm_managing_tasks
)

if isinstance(self.ddp_plugin, RPCPlugin):
if not self.ddp_plugin.is_main_rpc_process:
self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer)
self.ddp_plugin.exit_rpc_process()
if self.ddp_plugin.return_after_exit_rpc_process:
return
else:
self.ddp_plugin.on_main_rpc_connection(self.trainer)

# call setup after the ddp process has connected
self.trainer.call_setup_hook(model)

@@ -128,6 +138,8 @@ def ddp_train(self, process_idx, mp_queue, model):
# allow for lr schedulers as well
self.setup_optimizers(model)

self.ddp_plugin.on_after_setup_optimizers(self.trainer)

# set model properties before going into wrapper
self.trainer.model_connector.copy_trainer_model_properties(model)

@@ -221,7 +233,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
mp_queue.put(results)

def configure_ddp(
self, model: LightningModule, device_ids: List[int]
self, model: LightningModule, device_ids: List[int]
) -> DistributedDataParallel:
model = self.ddp_plugin.configure_ddp(model, device_ids)
return model
@@ -251,3 +263,17 @@ def sync_tensor(self,

def get_reference_model(self, model) -> LightningModule:
return self.ddp_plugin.get_model_from_plugin(model)

@property
def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(
num_replicas=self.trainer.num_nodes * self.trainer.num_processes,
rank=self.trainer.global_rank
)
if self.ddp_plugin is not None:
distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs)
return distributed_sampler_kwargs

@property
def require_distributed_sampler(self):
return True
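
For reference, the num_replicas / rank values built above follow directly from the rank arithmetic in set_world_ranks. A small self-contained example of that arithmetic for a hypothetical 2-node, 4-process-per-node job (values only, no distributed setup required):

# Worked example of the global-rank / world-size arithmetic used by the DDP
# accelerators in this diff.
num_nodes, num_processes = 2, 4
world_size = num_nodes * num_processes

for node_rank in range(num_nodes):
    for process_idx in range(num_processes):
        global_rank = node_rank * num_processes + process_idx
        distributed_sampler_kwargs = dict(num_replicas=world_size, rank=global_rank)
        print(f"node={node_rank} proc={process_idx} -> {distributed_sampler_kwargs}")
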
34 changes: 31 additions & 3 deletions pytorch_lightning/accelerators/ddp_hpc_accelerator.py
@@ -23,6 +23,7 @@
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available

@@ -62,9 +63,11 @@ def set_world_ranks(self, process_idx):
self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes

def model_to_device(self, model, process_idx):
def init_device(self, process_idx):
self.trainer.root_gpu = process_idx
torch.cuda.set_device(self.trainer.root_gpu)

def model_to_device(self, model):
model.cuda(self.trainer.root_gpu)

def get_device_ids(self):
@@ -136,6 +139,15 @@ def ddp_train(self, process_idx, model):
self.trainer.is_slurm_managing_tasks
)

if isinstance(self.ddp_plugin, RPCPlugin):
if not self.ddp_plugin.is_main_rpc_process:
self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer)
self.ddp_plugin.exit_rpc_process()
if self.ddp_plugin.return_after_exit_rpc_process:
return
else:
self.ddp_plugin.on_main_rpc_connection(self.trainer)

# call setup after the ddp process has connected
self.trainer.call_setup_hook(model)

@@ -151,12 +163,14 @@ def ddp_train(self, process_idx, model):
model = self.configure_sync_batchnorm(model)

# move the model to the correct device
self.model_to_device(model, process_idx)
self.model_to_device(model)

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.setup_optimizers(model)

self.ddp_plugin.on_after_setup_optimizers(self.trainer)

# set model properties before going into wrapper
self.trainer.model_connector.copy_trainer_model_properties(model)

@@ -183,7 +197,7 @@ def ddp_train(self, process_idx, model):
return results

def configure_ddp(
self, model: LightningModule, device_ids: List[int]
self, model: LightningModule, device_ids: List[int]
) -> DistributedDataParallel:
model = self.ddp_plugin.configure_ddp(model, device_ids)
return model
Expand Down Expand Up @@ -213,3 +227,17 @@ def sync_tensor(self,

def get_reference_model(self, model) -> LightningModule:
return self.ddp_plugin.get_model_from_plugin(model)

@property
def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(
num_replicas=self.trainer.num_nodes * self.trainer.num_processes,
rank=self.trainer.global_rank
)
if self.ddp_plugin is not None:
distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs)
return distributed_sampler_kwargs

@property
def require_distributed_sampler(self):
return True
(Diffs for the remaining 16 changed files are not shown on this page.)
