Automatic registration rebased (microsoft#164)
* set adamw_mode default true (follows FusedAdam and < 0.3.11 logic) (microsoft#844)

* less scary overflow notice (microsoft#833)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* Add optimizers and schedules to RTD and updated the corresponding part in the website (microsoft#799)

* add optimizers and schedules to rtd

* update ds website and fix links

* add optimizers and schedules to rtd

* update ds website and fix links

* add flops profiler to rtd

* fix

Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>

* small tweaks (microsoft#839)

* Control ZeRO wall clock timers (microsoft#849)

* Control ZeRO wall clock timers

* Disable more ZeRO3 debug prints

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* [WarmupDecayLR] fix log(0) & 1/log(1) bugs (microsoft#772)

* fix log(0) & 1/log(1) bugs

* simplify

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com>
Co-authored-by: Cheng Li <pistasable@gmail.com>

* bump to v0.3.12

* Bug fix: Remove client optimizer param_group list item that does not have 'params' (microsoft#827)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* [doc] pipeline doc typos/improvements (microsoft#659)

Admin merging for pure-doc PR that does not trigger build.

* Samyamr/inference hook fix (microsoft#851)

* Fix mis-aligned-grad

When a parameter is not divisible by world size, the partitioned gradients are mis-aligned due to incorrect padding handling. This PR should fix that.

* Formatting fix

* Adding static_scale test back for Z3, and also changing hidden size to be not divisible by world_size

* also removing alignment from flat fp16 buffers

* Testing for hidden dim alignment

* inference hook fix

* Update stage3.py

* formatting

* [bug-fix] move params to gpu if offload params is turned off

Co-authored-by: Samyam Rajbhandari <samyamr@microsoft.com>
Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* ZeRO Stage 2: Clear reduced gradients (microsoft#856)

* Ensure gradients of other partitions are cleared after reduction

* Remove redundant code

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* Squash stage3 v1 (microsoft#146)

Co-authored-by: Samyam <samyamr@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Samyam Rajbhandari <samyamr@microsoft.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>
Co-authored-by: Shaden Smith <ShadenTSmith@gmail.com>
Co-authored-by: eltonzheng <eltonz@microsoft.com>

* formatting fix (microsoft#150)

* stage3 bugfix (API) update and simplified FP16 Z3 tests (microsoft#151)

* fp16 Z3 API update and bugfix

* revert debug change

* docs

* filling in allocation docs

* better assumption docs

* doc progress

* config json

* major docs edits

* auto registration works for accessed cases

* working on small models.

* debugging large-model discovery?

* fix discovery to first forward pass?

* return obj ext param

* support None parameters in auto-discovery

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Cheng Li <pistasable@gmail.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com>
Co-authored-by: Samyam Rajbhandari <samyamr@microsoft.com>
Co-authored-by: eltonzheng <eltonz@microsoft.com>
8 people committed Mar 19, 2021
1 parent 4985b2f commit 840856d
Showing 27 changed files with 797 additions and 179 deletions.
2 changes: 1 addition & 1 deletion deepspeed/launcher/runner.py
@@ -304,7 +304,7 @@ def main(args=None):
# encode world info as base64 to make it easier to pass via command line
world_info_base64 = encode_world_info(active_resources)

multi_node_exec = len(active_resources) > 1
multi_node_exec = True # len(active_resources) > 1

if multi_node_exec and not shutil.which('pdsh'):
raise RuntimeError("pdsh is not installed, unable to proceed")
2 changes: 1 addition & 1 deletion deepspeed/ops/adam/cpu_adam.py
@@ -74,7 +74,7 @@ def __init__(self,

self.opt_id = DeepSpeedCPUAdam.optimizer_id
DeepSpeedCPUAdam.optimizer_id = DeepSpeedCPUAdam.optimizer_id + 1

self.adam_w_mode = adamw_mode
self.ds_opt_adam = CPUAdamBuilder().load()

self.ds_opt_adam.create_adam(self.opt_id,
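For context, a minimal standalone sketch of constructing the CPU optimizer with the stored flag; the keyword arguments other than adamw_mode are assumed to follow the usual torch optimizer convention and are illustrative only:

import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam

# Toy model, used only to provide parameters for the sketch.
model = torch.nn.Linear(64, 8)

# adamw_mode=True (the new default) selects decoupled weight decay (AdamW);
# adamw_mode=False falls back to classic Adam-style L2 regularization.
optimizer = DeepSpeedCPUAdam(model.parameters(),
                             lr=1e-3,
                             weight_decay=0.01,
                             adamw_mode=True)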
6 changes: 0 additions & 6 deletions deepspeed/profiling/config.py
@@ -9,9 +9,6 @@

class DeepSpeedFlopsProfilerConfig(DeepSpeedConfigObject):
def __init__(self, param_dict):
"""
docstring
"""
super(DeepSpeedFlopsProfilerConfig, self).__init__()

self.enabled = None
@@ -27,9 +24,6 @@ def __init__(self, param_dict):
self._initialize(flops_profiler_dict)

def _initialize(self, flops_profiler_dict):
"""
docstring
"""
self.enabled = get_scalar_param(flops_profiler_dict,
FLOPS_PROFILER_ENABLED,
FLOPS_PROFILER_ENABLED_DEFAULT)
56 changes: 55 additions & 1 deletion deepspeed/profiling/flops_profiler/profiler.py
@@ -12,6 +12,34 @@ class FlopsProfiler(object):
"""Measures the latency, number of estimated floating point operations and parameters of each module in a PyTorch model.
The flops-profiler profiles the forward pass of a PyTorch model and prints the model graph with the measured profile attached to each module. It shows how latency, flops and parameters are spent in the model and which modules or layers could be the bottleneck. It also outputs the names of the top k modules in terms of aggregated latency, flops, and parameters at depth l with k and l specified by the user. The output profile is computed for each batch of input.
The DeepSpeed flops profiler can be used with the DeepSpeed runtime or as a standalone package.
When using DeepSpeed for model training, the flops profiler can be configured in the deepspeed_config file and no user code change is required.
If using the profiler as a standalone package, one imports the flops_profiler package and uses the APIs.
Here is an example of usage in a typical training workflow:
.. code-block:: python
model = Model()
prof = FlopsProfiler(model)
for step, batch in enumerate(data_loader):
if step == profile_step:
prof.start_profile()
loss = model(batch)
if step == profile_step:
flops = prof.get_total_flops(as_string=True)
params = prof.get_total_params(as_string=True)
prof.print_model_profile(profile_step=profile_step)
prof.end_profile()
loss.backward()
optimizer.step()
To profile a trained model in inference, use the `get_model_profile` API.
Args:
model (torch.nn.Module): The PyTorch model to profile.
@@ -118,6 +146,9 @@ def get_total_flops(self, as_string=False):
Args:
as_string (bool, optional): whether to output the flops as string. Defaults to False.
Returns:
The number of multiply-accumulate operations of the model forward pass.
"""
total_flops = get_module_flops(self.model)
return macs_to_string(total_flops) if as_string else total_flops
@@ -127,6 +158,9 @@ def get_total_duration(self, as_string=False):
Args:
as_string (bool, optional): whether to output the duration as string. Defaults to False.
Returns:
The latency of the model forward pass.
"""
total_duration = self.model.__duration__
return duration_to_string(total_duration) if as_string else total_duration
@@ -136,6 +170,9 @@ def get_total_params(self, as_string=False):
Args:
as_string (bool, optional): whether to output the parameters as string. Defaults to False.
Returns:
The number of parameters in the model.
"""
return params_to_string(
self.model.__params__) if as_string else self.model.__params__
@@ -146,6 +183,12 @@ def print_model_profile(self,
top_modules=3,
detailed=True):
"""Prints the model graph with the measured profile attached to each module.
Args:
profile_step (int, optional): The global training step at which to profile. Note that warm up steps are needed for accurate time measurement.
module_depth (int, optional): The depth of the model at which to print the aggregated module information. When set to -1, it prints information on the innermost modules (with the maximum depth).
top_modules (int, optional): Limits the aggregated profile output to the number of top modules specified.
detailed (bool, optional): Whether to print the detailed model profile.
"""

total_flops = self.get_total_flops()
@@ -219,7 +262,7 @@ def del_extra_repr(module):
"\n------------------------------ Detailed Profile ------------------------------"
)
print(
"Each module profile is listed after its name in the follwing order: \nnumber of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency, percentage of total latency, number of floating point operations per second (FLOPS, computed as 2 * MACs / latency)."
"Each module profile is listed after its name in the following order: \nnumber of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency, percentage of total latency, number of floating point operations per second (FLOPS, computed as 2 * MACs / latency)."
)
print(
"Note: \n1. A module can have torch.nn.functional (e.g. to compute logits) along with submodules, thus making the difference between the parent's MACs(or latency) and the sum of its submodules'.\n2. Number of floating point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throught.\n"
@@ -749,6 +792,14 @@ def get_model_profile(
):
"""Returns the total MACs and parameters of a model.
Example:
.. code-block:: python
model = torchvision.models.alexnet()
batch_size = 256
macs, params = get_model_profile(model=model, input_res=(batch_size, 3, 224, 224))
Args:
model ([torch.nn.Module]): the PyTorch model to be profiled.
input_res (tuple): input shape or input to the input_constructor
@@ -760,6 +811,9 @@
warm_up (int, optional): the number of warm-up steps before measuring the latency of each module. Defaults to 1.
as_string (bool, optional): whether to print the output as string. Defaults to True.
ignore_modules ([type], optional): the list of modules to ignore during profiling. Defaults to None.
Returns:
The number of multiply-accumulate operations (MACs) and parameters in the model.
"""
assert type(input_res) is tuple
assert len(input_res) >= 1
4 changes: 4 additions & 0 deletions deepspeed/runtime/config.py
@@ -43,6 +43,10 @@
# extra optimizer parameters for adam/adamw
TORCH_ADAM_PARAM = "torch_adam"

# default to adamw logic for adam/adamw optimizers unless user explicitly opts out
ADAM_W_MODE = "adam_w_mode"
ADAM_W_MODE_DEFAULT = True


class DeepSpeedConfigError(Exception):
pass
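Because adam_w_mode is read from the optimizer params section of the DeepSpeed config (it is popped from optimizer_parameters in engine.py below), a user who wants classic Adam behavior now opts out explicitly. A sketch of such a config, expressed as a Python dict; the surrounding fields are illustrative, not taken from this commit:

# Illustrative DeepSpeed config: opting out of the new AdamW default.
ds_config = {
    "train_batch_size": 8,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-3,
            "weight_decay": 0.01,
            "adam_w_mode": False,  # default is now True (ADAM_W_MODE_DEFAULT)
        },
    },
}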
62 changes: 42 additions & 20 deletions deepspeed/runtime/engine.py
@@ -22,7 +22,7 @@
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.runtime.config import DeepSpeedConfig, DEEPSPEED_OPTIMIZERS, \
ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, \
TORCH_ADAM_PARAM
TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT

from deepspeed.runtime.dataloader import DeepSpeedDataLoader
from deepspeed.runtime.constants import \
@@ -591,6 +591,12 @@ def _configure_distributed_model(self, model):
def _configure_optimizer(self, client_optimizer, model_parameters):

if client_optimizer is not None:
client_optimizer.param_groups[:] = [
pg for pg in client_optimizer.param_groups if len(pg["params"]) != 0
]
logger.info(
"Removing param_group that has no 'params'in the client Optimizer")

basic_optimizer = client_optimizer
if self.global_rank == 0:
logger.info('Using client Optimizer as basic optimizer')
@@ -646,26 +652,30 @@ def _configure_basic_optimizer(self, model_parameters):

if self.optimizer_name() in [ADAM_OPTIMIZER, ADAMW_OPTIMIZER]:
torch_adam = optimizer_parameters.pop(TORCH_ADAM_PARAM, False)
adam_w_mode = self.optimizer_name() == ADAMW_OPTIMIZER
# zero-offload  torch-adam  adam_w_mode  optimizer
#     T|F           T            T       torch.optim.AdamW
#     T|F           T            F       torch.optim.Adam
#      T            F           T|F      DeepSpeedCPUAdam(adam_w_mode)
#      F            F           T|F      FusedAdam(adam_w_mode)
adam_w_mode = optimizer_parameters.pop(ADAM_W_MODE, ADAM_W_MODE_DEFAULT)

# Optimizer name of Adam forces AdamW logic unless adam_w_mode is explicitly set
effective_adam_w_mode = self.optimizer_name(
) == ADAMW_OPTIMIZER or adam_w_mode

if torch_adam:
if adam_w_mode:
optimizer = torch.optim.AdamW(model_parameters,
**optimizer_parameters)
else:
if not effective_adam_w_mode:
optimizer = torch.optim.Adam(model_parameters,
**optimizer_parameters)
elif self.zero_cpu_offload():
optimizer = DeepSpeedCPUAdam(model_parameters,
**optimizer_parameters,
adamw_mode=adam_w_mode)
else:
optimizer = torch.optim.AdamW(model_parameters,
**optimizer_parameters)
else:
optimizer_parameters['adam_w_mode'] = adam_w_mode
optimizer = FusedAdam(model_parameters, **optimizer_parameters)
if self.zero_cpu_offload():
from deepspeed.ops.adam import DeepSpeedCPUAdam
optimizer = DeepSpeedCPUAdam(model_parameters,
**optimizer_parameters,
adamw_mode=effective_adam_w_mode)
else:
from deepspeed.ops.adam import FusedAdam
optimizer = FusedAdam(model_parameters,
**optimizer_parameters,
adam_w_mode=effective_adam_w_mode)

elif self.optimizer_name() == LAMB_OPTIMIZER:
from deepspeed.ops.lamb import FusedLamb
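Condensed, the branches above resolve the Adam variant from three inputs: torch_adam, the effective AdamW mode, and whether ZeRO CPU offload is enabled. A rough standalone restatement of that decision, as a sketch only (names are illustrative):

def pick_adam_variant(torch_adam: bool,
                      effective_adam_w_mode: bool,
                      zero_cpu_offload: bool) -> str:
    """Mirror of the selection logic in _configure_basic_optimizer (sketch)."""
    if torch_adam:
        # Plain PyTorch implementations, no fused kernels.
        return "torch.optim.AdamW" if effective_adam_w_mode else "torch.optim.Adam"
    if zero_cpu_offload:
        # CPU-side optimizer; adamw_mode toggles decoupled weight decay.
        return f"DeepSpeedCPUAdam(adamw_mode={effective_adam_w_mode})"
    # Fused GPU kernel; adam_w_mode toggles decoupled weight decay.
    return f"FusedAdam(adam_w_mode={effective_adam_w_mode})"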
@@ -724,6 +734,7 @@ def _configure_zero_optimizer(self, optimizer):
zero_stage = self.zero_optimization_stage()
log_dist('Creating fp16 ZeRO stage {} optimizer'.format(zero_stage), ranks=[0])
assert not self.allreduce_always_fp32(), "ZeRO does not support 'fp32_allreduce': true"
timers = self.timers if self.wall_clock_breakdown() else None

if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
assert self.zero_reduce_scatter(), 'Stage 1 only supports reduce scatter mode'
@@ -742,7 +753,7 @@ def _configure_zero_optimizer(self, optimizer):
elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS:
optimizer = FP16_DeepSpeedZeroOptimizer(
optimizer,
timers=self.timers,
timers=timers,
static_loss_scale=self.loss_scale(),
dynamic_loss_scale=self.dynamic_loss_scale(),
dynamic_loss_args=self.dynamic_loss_scale_args(),
@@ -764,7 +775,7 @@ def _configure_zero_optimizer(self, optimizer):
optimizer = FP16_DeepSpeedZeroOptimizer_Stage3(
self.module,
optimizer,
timers=self.timers,
timers=timers,
static_loss_scale=self.loss_scale(),
dynamic_loss_scale=self.dynamic_loss_scale(),
dynamic_loss_args=self.dynamic_loss_scale_args(),
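With this change the ZeRO optimizers only receive a timer object when wall-clock breakdown is enabled, making the timing overhead opt-in. A sketch of how that might be switched on in the DeepSpeed config (illustrative values, not taken from this commit):

# Illustrative config snippet: enable wall clock breakdown so the ZeRO
# optimizers above receive a real timers object instead of None.
ds_config_with_timers = {
    "train_batch_size": 8,
    "wall_clock_breakdown": True,
    "zero_optimization": {"stage": 2},
}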
@@ -892,6 +903,13 @@ def forward(self, *inputs, **kwargs):
if self.module.training and self.progressive_layer_drop:
kwargs.update(self.progressive_layer_drop.get_state())

if self.zero_optimization_partition_weights():
# Enable automated discovery of external parameters by indicating that
# we are in a forward pass.
for module in self.module.modules():
module._parameters._in_forward = True
pass

if self.wall_clock_breakdown():
self.timers('forward_microstep').start()
self.timers('forward').start()
@@ -900,11 +918,15 @@
self.tput_timer.start()
loss = self.module(*inputs, **kwargs)

# Reset the ZeRO-3 state if we are only doing forward-passes (ie evaluation).
if self.zero_optimization_partition_weights():
# Reset the ZeRO-3 state if we are only doing forward-passes (ie evaluation).
if not torch._C.is_grad_enabled():
self.optimizer.param_coordinator.reset_step()

# Disable automated discovery of external parameters
for module in self.module.modules():
module._parameters._in_forward = False

if self.wall_clock_breakdown():
self.timers('forward').stop()
self.timers('forward_microstep').stop()
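The _in_forward flag is what lets ZeRO-3 discover "external" parameters — parameters accessed inside a module's forward pass without going through the owning module's own forward — automatically. A toy sketch of the access pattern this targets; the module names are illustrative, and the manual deepspeed.zero.register_external_parameter call mentioned in the comment is the pre-existing fallback rather than something introduced here:

import torch

class Embedder(torch.nn.Module):
    def __init__(self, vocab=1000, dim=128):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab, dim)

    def forward(self, tokens):
        return self.emb(tokens)

class TiedOutputHead(torch.nn.Module):
    """Reuses another module's embedding weight inside its own forward."""
    def __init__(self, embedder: Embedder):
        super().__init__()
        self.embedder = embedder
        # Previously ZeRO-3 needed an explicit registration, e.g.:
        #   deepspeed.zero.register_external_parameter(self, embedder.emb.weight)
        # With automatic discovery, the access below is detected during forward.

    def forward(self, hidden):
        # Direct use of a parameter owned elsewhere -> an "external" parameter.
        return hidden @ self.embedder.emb.weight.t()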
8 changes: 4 additions & 4 deletions deepspeed/runtime/fp16/fused_optimizer.py
@@ -153,10 +153,10 @@ def step_fused_adam(self, closure=None):

if self.overflow:
if self.verbose:
logger.info("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
"scale: {}, reducing to {}".format(
prev_scale,
self.cur_scale))
logger.info(
"[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
"scale: {}, reducing to {}".format(prev_scale,
self.cur_scale))
return self.overflow
combined_scale = self.unscale_and_clip_grads(grads_groups_flat,
norm_groups,
2 changes: 1 addition & 1 deletion deepspeed/runtime/fp16/loss_scaler.py
@@ -213,7 +213,7 @@ def update_scale(self, overflow):
optimizer.step()
# Otherwise, don't do anything -- ie, skip iteration
else:
print('OVERFLOW!')
print('fp16 dynamic loss scale overflow!')
# Update loss scale for next iteration
loss_scaler.update_scale(has_overflow)
16 changes: 8 additions & 8 deletions deepspeed/runtime/fp16/unfused_optimizer.py
@@ -139,10 +139,10 @@ def step_fused_lamb(self, closure=None):
self._update_scale(self.overflow)
if self.overflow:
if self.verbose:
logger.info("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
"scale: {}, reducing to {}".format(
prev_scale,
self.cur_scale))
logger.info(
"[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
"scale: {}, reducing to {}".format(prev_scale,
self.cur_scale))
return self.overflow

combined_scale = self.unscale_and_clip_grads(norm_groups, apply_scale=False)
@@ -165,10 +165,10 @@ def step(self, closure=None):
self._update_scale(self.overflow)
if self.overflow:
if self.verbose:
logger.info("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
"scale: {}, reducing to {}".format(
prev_scale,
self.cur_scale))
logger.info(
"[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
"scale: {}, reducing to {}".format(prev_scale,
self.cur_scale))
return self.overflow

norm_groups = []
4 changes: 2 additions & 2 deletions deepspeed/runtime/lr_schedules.py
@@ -706,8 +706,8 @@ def __init__(self,
self.min_lrs = self._format_param(self.optimizer, warmup_min_lr, "min_lr")
self.max_lrs = self._format_param(self.optimizer, warmup_max_lr, "max_lr")
self.delta_lrs = [big - small for big, small in zip(self.max_lrs, self.min_lrs)]
self.warmup_num_steps = warmup_num_steps
self.inverse_log_warm_up = 1.0 / math.log(warmup_num_steps)
self.warmup_num_steps = max(2, warmup_num_steps)
self.inverse_log_warm_up = 1.0 / math.log(self.warmup_num_steps)
self.last_batch_iteration = last_batch_iteration

def get_lr(self):