From 0a30e13a98b1d9405ffa9e58efa34208266ccb41 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Mon, 22 Mar 2021 19:01:15 -0700
Subject: [PATCH 1/7] GatheredParameters can now handle a list of params

---
 .../runtime/zero/partition_parameters.py | 34 ++++++++++++-------
 docs/_tutorials/getting-started.md       |  4 +--
 2 files changed, 23 insertions(+), 15 deletions(-)

diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index e6cb9199899a..8ff5a83f05e8 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -850,13 +850,13 @@ def _partition_gradient(self, param, partition_buffer=None, accumulate=False):
 
 
 class GatheredParameters:
-    def __init__(self, param, modifier_rank=None, fwd_module=None, enabled=True):
+    def __init__(self, params, modifier_rank=None, fwd_module=None, enabled=True):
         """A context that collects a parameter that was partitioned via a
         :class:`deepspeed.zero.Init` context. The parameter is partitioned
         again upon exit.
 
         Args:
-            param (``torch.nn.Parameter``): The parameter to collect.
+            params (``torch.nn.Parameter``): The parameter to collect.
             modifier_rank (int, optional): If specified, this rank's parameter will be
                 broadcasted after the context. This argument is required if ``param`` is
                 modified all processes should have a consistent view of the data. Defaults
@@ -903,35 +903,43 @@ def forward(self, input):
         if not enabled:
             return
 
-        # This is a no-op, just return.
-        if not is_zero_param(param):
+        if not isinstance(params, list):
+            params = [params]
+
+        # enable if at least one is zero-param, otherwise a noop
+        if not any(is_zero_param(p) for p in params):
             self.enabled = False
             return
 
-        self.param = param
+        self.params = params
         self.src_rank = None
         if modifier_rank is not None:
             if self.param.ds_process_group == torch.distributed.group.WORLD:
                 self.src_rank = modifier_rank
             else:
                 # A group was specified; convert DP rank to global rank
-                self.src_rank = _get_global_rank(self.param.ds_process_group,
+                # XXX: is it safe to use 0th param?
+                self.src_rank = _get_global_rank(self.params[0].ds_process_group,
                                                  modifier_rank)
         self.fwd_module = fwd_module
         if self.fwd_module is not None:
             # is a no-op if already registered
-            register_external_parameter(self.fwd_module, self.param)
+            for p in self.params:
+                register_external_parameter(self.fwd_module, p)
 
     def __enter__(self):
         if not self.enabled:
             return
-        self.param.all_gather()
+        for p in self.params:
+            p.all_gather()
 
     def __exit__(self, *exc):
         if not self.enabled:
             return
-        if self.src_rank is not None:
-            torch.distributed.broadcast(self.param,
-                                        self.src_rank,
-                                        group=self.param.ds_process_group)
-        self.param.partition(has_been_updated=self.src_rank is not None)
+        for p in self.params:
+            # XXX: can this be done on the list?
+            if self.src_rank is not None:
+                torch.distributed.broadcast(p,
+                                            self.src_rank,
+                                            group=self.param.ds_process_group)
+            p.partition(has_been_updated=self.src_rank is not None)
diff --git a/docs/_tutorials/getting-started.md b/docs/_tutorials/getting-started.md
index e12388aaf973..e9b9aa0e627e 100644
--- a/docs/_tutorials/getting-started.md
+++ b/docs/_tutorials/getting-started.md
@@ -265,8 +265,8 @@ local machine to discover the number of slots available.
 The `--include` and `--exclude` arguments work as normal, but the user should
 specify 'localhost' as the hostname.
 
-Also note that `CUDA_VISIBLE_DEVICES` can't be used with DeepSpeed to control
-which devices should be used. For example, to use only gpu1 of the current
+Also note that `CUDA_VISIBLE_DEVICES` can't be used with DeepSpeed to control
+which devices should be used. For example, to use only gpu1 of the current
 node, do:
 ```bash
 deepspeed --include localhost:1 ...

From 5a9e789c877f129fe707cf64f0c8ec0b086cef71 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Mon, 22 Mar 2021 19:32:40 -0700
Subject: [PATCH 2/7] fix

---
 deepspeed/runtime/zero/partition_parameters.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index 8ff5a83f05e8..c9d8e8a7918b 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -914,7 +914,7 @@ def forward(self, input):
         self.params = params
         self.src_rank = None
         if modifier_rank is not None:
-            if self.param.ds_process_group == torch.distributed.group.WORLD:
+            if self.params[0].ds_process_group == torch.distributed.group.WORLD:
                 self.src_rank = modifier_rank
             else:
                 # A group was specified; convert DP rank to global rank
@@ -941,5 +941,5 @@ def __exit__(self, *exc):
             if self.src_rank is not None:
                 torch.distributed.broadcast(p,
                                             self.src_rank,
-                                            group=self.param.ds_process_group)
+                                            group=self.params[0].ds_process_group)
             p.partition(has_been_updated=self.src_rank is not None)

From 3c574585ff9798a012e44a5bb47a84e07ad6601b Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Mon, 22 Mar 2021 20:45:51 -0700
Subject: [PATCH 3/7] add Pretrained model loading example

---
 .../runtime/zero/partition_parameters.py | 43 +++++++++++++++----
 1 file changed, 35 insertions(+), 8 deletions(-)

diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index c9d8e8a7918b..354f61aa17bd 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -851,18 +851,19 @@ def _partition_gradient(self, param, partition_buffer=None, accumulate=False):
 
 class GatheredParameters:
     def __init__(self, params, modifier_rank=None, fwd_module=None, enabled=True):
-        """A context that collects a parameter that was partitioned via a
-        :class:`deepspeed.zero.Init` context. The parameter is partitioned
+        """A context that collects parameters that were partitioned via a
+        :class:`deepspeed.zero.Init` context. The parameters are partitioned
         again upon exit.
 
         Args:
-            params (``torch.nn.Parameter``): The parameter to collect.
+            params (``torch.nn.Parameter``): A single parameter or a list of parameters to collect.
+                It's assumed that all parameters are zero params.
             modifier_rank (int, optional): If specified, this rank's parameter will be
-                broadcasted after the context. This argument is required if ``param`` is
-                modified all processes should have a consistent view of the data. Defaults
+                broadcasted on exit from the context. This argument is required if ``params`` are
+                modified, so that all processes have a consistent view of the data. Defaults
                 to ``None``.
-            fwd_module (``torch.nn.Module``, optional): If specified, ``param`` will be
-                registered as an external parameter of ``fwd_module``. See :meth:`deepspeed.zero.register_external_parameter`.
+            fwd_module (``torch.nn.Module``, optional): If specified, ``params`` will be
+                registered as external parameters of ``fwd_module``. See :meth:`deepspeed.zero.register_external_parameter`.
             enabled (bool, optional): If ``False``, this context is a no-op. Defaults
                 to ``True``.
         Examples
@@ -897,6 +898,33 @@ def forward(self, input):
                                                            fwd_module=self):
                         y = self.layer2(x, self.layer1.weight)
                     return y
+
+
+        #. Pretrained model loading
+
+            .. code-block:: python
+
+                with deepspeed.zero.Init():
+                    model = MyModel()
+
+                state_dict = torch.load(model_path, map_location="cpu")
+
+                def load(module: nn.Module, prefix=""):
+                    # because zero3 puts placeholders in model params, this context
+                    # manager gathers (unpartitions) the params of the current layer, then loads from
+                    # the state dict and then re-partitions them again
+                    with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
+                        if torch.distributed.get_rank() == 0:
+                            module._load_from_state_dict(state_dict, prefix)
+
+                    for name, child in module._modules.items():
+                        if child is not None:
+                            load(child, prefix + name + ".")
+
+                load(model, prefix="")
+
+        If this approach is not used, then the full model will first get copied to each GPU. For models
+        bigger than the memory of a single gpu this method is required.
         """
 
         self.enabled = enabled
@@ -918,7 +946,6 @@ def forward(self, input):
                 self.src_rank = modifier_rank
             else:
                 # A group was specified; convert DP rank to global rank
-                # XXX: is it safe to use 0th param?
                 self.src_rank = _get_global_rank(self.params[0].ds_process_group,
                                                  modifier_rank)
         self.fwd_module = fwd_module

From a983ac33d50bee854531c123e74e1352d7945610 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Mon, 22 Mar 2021 21:44:50 -0700
Subject: [PATCH 4/7] optimizer - thanks @samyam

---
 deepspeed/runtime/zero/partition_parameters.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index 354f61aa17bd..7f17ef4859f5 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -957,16 +957,18 @@ def load(module: nn.Module, prefix=""):
     def __enter__(self):
         if not self.enabled:
             return
-        for p in self.params:
-            p.all_gather()
+        self.params[0].all_gather(param_list=self.params)
 
     def __exit__(self, *exc):
         if not self.enabled:
             return
-        for p in self.params:
-            # XXX: can this be done on the list?
-            if self.src_rank is not None:
+        if self.src_rank is not None:
+            handles = [
                 torch.distributed.broadcast(p,
                                             self.src_rank,
-                                            group=self.params[0].ds_process_group)
-            p.partition(has_been_updated=self.src_rank is not None)
+                                            group=p.ds_process_group,
+                                            async_op=True) for p in self.params
+            ]
+            for h in handles:
+                h.wait()
+        self.params[0].partition(has_been_updated=True)

From b834faed6850f1a323101c403a1d80faaab55354 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Mon, 22 Mar 2021 21:48:34 -0700
Subject: [PATCH 5/7] simplify

---
 .../runtime/zero/partition_parameters.py | 22 ++++++++++---------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index 7f17ef4859f5..260069b28b51 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -962,13 +962,15 @@ def __enter__(self):
     def __exit__(self, *exc):
         if not self.enabled:
             return
-        if self.src_rank is not None:
-            handles = [
-                torch.distributed.broadcast(p,
-                                            self.src_rank,
-                                            group=p.ds_process_group,
-                                            async_op=True) for p in self.params
-            ]
-            for h in handles:
-                h.wait()
-        self.params[0].partition(has_been_updated=True)
+        if self.src_rank is None:
+            return
+
+        handles = [
+            torch.distributeds.broadcast(p,
+                                         self.src_rank,
+                                         group=p.ds_process_group,
+                                         async_op=True) for p in self.params
+        ]
+        for h in handles:
+            h.wait()
+        self.params[0].partition(has_been_updated=True)

From 6dfcd1ddff01d2ae1f47db2304b14700c1e2b4cc Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Mon, 22 Mar 2021 21:51:39 -0700
Subject: [PATCH 6/7] fix

---
 deepspeed/runtime/zero/partition_parameters.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index 260069b28b51..b2f6c2d218ad 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -966,11 +966,11 @@ def __exit__(self, *exc):
             return
 
         handles = [
-            torch.distributeds.broadcast(p,
+            torch.distributed.broadcast(p,
                                          self.src_rank,
                                          group=p.ds_process_group,
                                          async_op=True) for p in self.params
         ]
         for h in handles:
             h.wait()
-        self.params[0].partition(has_been_updated=True)
+        self.params[0].partition(param_list=self.params, has_been_updated=True)

From d5ca20072cfe213ab4ec46c7d1525a95ed7f3efe Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Mon, 22 Mar 2021 22:11:27 -0700
Subject: [PATCH 7/7] style

---
 deepspeed/runtime/zero/partition_parameters.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index b2f6c2d218ad..e65253736984 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -967,9 +967,9 @@ def __exit__(self, *exc):
 
         handles = [
             torch.distributed.broadcast(p,
-                                         self.src_rank,
-                                         group=p.ds_process_group,
-                                         async_op=True) for p in self.params
+                                        self.src_rank,
+                                        group=p.ds_process_group,
+                                        async_op=True) for p in self.params
         ]
         for h in handles:
             h.wait()
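
Taken together, the series turns `GatheredParameters` from a single-parameter context manager into one that accepts either one parameter or a list: all listed parameters are gathered on entry, optionally broadcast from `modifier_rank` on exit, and then re-partitioned. The following is a minimal usage sketch of the new list form; it is not part of the patches themselves and assumes a ZeRO stage-3 setup in which `torch.distributed` has already been initialized (for example via `deepspeed.init_distributed()`):

```python
# Sketch only (assumes an already-initialized distributed / ZeRO-3 environment).
import torch
import deepspeed

with deepspeed.zero.Init():
    # After this block the layer's weights are partitioned placeholders.
    net = torch.nn.Linear(1000, 1000)

# New in this series: pass a list of params instead of a single param.
# All of them are gathered at once; on exit rank 0's values are broadcast
# to the other ranks and the params are partitioned again.
with deepspeed.zero.GatheredParameters(list(net.parameters(recurse=False)),
                                       modifier_rank=0):
    if torch.distributed.get_rank() == 0:
        torch.nn.init.xavier_uniform_(net.weight)
        torch.nn.init.zeros_(net.bias)
```

Handling the parameters as one list is also what lets `__exit__` (after patches 4-7) issue the broadcasts with `async_op=True` and overlap them, rather than synchronizing one parameter at a time.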