Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1200,7 +1200,7 @@ def _configure_basic_optimizer(self, model_parameters):
"'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details"
)

if self.optimizer_name() in [ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER]:
if self.optimizer_name() in [ADAM_OPTIMIZER, ADAMW_OPTIMIZER]:
torch_adam = optimizer_parameters.pop(TORCH_ADAM_PARAM, False)
adam_w_mode = optimizer_parameters.pop(ADAM_W_MODE, ADAM_W_MODE_DEFAULT)

Expand All @@ -1214,14 +1214,10 @@ def _configure_basic_optimizer(self, model_parameters):
optimizer = torch.optim.AdamW(model_parameters, **optimizer_parameters)
else:
if self.zero_use_cpu_optimizer():
if self.optimizer_name() == ADAGRAD_OPTIMIZER:
from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad
optimizer = DeepSpeedCPUAdagrad(model_parameters, **optimizer_parameters)
else:
from deepspeed.ops.adam import DeepSpeedCPUAdam
optimizer = DeepSpeedCPUAdam(model_parameters,
**optimizer_parameters,
adamw_mode=effective_adam_w_mode)
from deepspeed.ops.adam import DeepSpeedCPUAdam
optimizer = DeepSpeedCPUAdam(model_parameters,
**optimizer_parameters,
adamw_mode=effective_adam_w_mode)
else:
from deepspeed.ops.adam import FusedAdam

Expand All @@ -1231,6 +1227,12 @@ def _configure_basic_optimizer(self, model_parameters):
adam_w_mode=effective_adam_w_mode,
)

elif self.optimizer_name() == ADAGRAD_OPTIMIZER:
if self.zero_use_cpu_optimizer():
from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad
optimizer = DeepSpeedCPUAdagrad(model_parameters, **optimizer_parameters)
else:
optimizer = torch.optim.Adagrad(model_parameters, **optimizer_parameters)
elif self.optimizer_name() == LAMB_OPTIMIZER:
from deepspeed.ops.lamb import FusedLamb

Expand Down
12 changes: 11 additions & 1 deletion deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,10 @@ def initialize_optimizer_states(self):

timer_names = set()

# State initialization for the Adagrad optimizer occurs at construction as opposed to other optimizers
# which do lazy initialization of the state at the first call to step.
is_adagrad = isinstance(self.optimizer, torch.optim.Adagrad)

if self.swap_optimizer:
self.optimizer_swapper.init_timers()

Expand Down Expand Up @@ -888,7 +892,9 @@ def initialize_optimizer_states(self):
else:
self.fp32_partitioned_groups_flat[i].grad = gradient_buffer.narrow(0, 0, num_elements)

self._optimizer_step(i)
# Initialize the optimizer states with the flattended fp32 partition.
if not is_adagrad:
self._optimizer_step(i)

if swappable_param_subgroup:
self._partitioned_params_swap_out(i)
Expand All @@ -900,6 +906,10 @@ def initialize_optimizer_states(self):
f'[End] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}',
force=False)

# Initialize the optimizer states with the flattended fp32 partition.
if is_adagrad:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add documentation for why we do this for only Adagrad, and not all optimizers.

self.optimizer = torch.optim.Adagrad(self.fp32_partitioned_groups_flat, **self.optimizer.defaults)

self.stop_timers([INIT_OPTIMIZER_TIMER])
self.log_timers(timer_names)

Expand Down
8 changes: 7 additions & 1 deletion deepspeed/runtime/zero/stage_1_and_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,13 @@ def initialize_optimizer_states(self):
self.single_partition_of_fp32_groups[i].grad = get_accelerator().pin_memory(
single_grad_partition) if self.cpu_offload else single_grad_partition

self.optimizer.step()
# Initialize the optimizer states with the flattended fp32 partition.
# State initialization for the Adagrad optimizer occurs at construction as opposed to other optimizers
# which do lazy initialization of the state at the first call to step.
if isinstance(self.optimizer, torch.optim.Adagrad):
Comment thread
tjruwase marked this conversation as resolved.
self.optimizer = torch.optim.Adagrad(self.single_partition_of_fp32_groups, **self.optimizer.defaults)
else:
self.optimizer.step()

if not self.cpu_offload:
for group in self.single_partition_of_fp32_groups:
Expand Down
5 changes: 4 additions & 1 deletion deepspeed/runtime/zero/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from deepspeed import comm as dist
from deepspeed.utils import logger
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad
from deepspeed.ops.adam import FusedAdam
from deepspeed.utils.nvtx import instrument_w_nvtx
from deepspeed.accelerator import get_accelerator
Expand All @@ -35,7 +36,9 @@ class ZeRORuntimeException(Exception):
pass


ZERO_SUPPORTED_OPTIMIZERS = [torch.optim.Adam, torch.optim.AdamW, FusedAdam, DeepSpeedCPUAdam]
ZERO_SUPPORTED_OPTIMIZERS = [
torch.optim.Adam, torch.optim.AdamW, FusedAdam, DeepSpeedCPUAdam, torch.optim.Adagrad, DeepSpeedCPUAdagrad
]

# Add apex FusedAdam to supported list if apex is installed
try:
Expand Down