From 730804c12be64f75fbeb3354803c689645353f50 Mon Sep 17 00:00:00 2001 From: Joe Mayer Date: Fri, 10 Mar 2023 21:13:28 +0000 Subject: [PATCH 1/4] Adding torch.optim.Adagrad --- deepspeed/runtime/engine.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 13c669b6de90..a1ff6ab0877d 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1324,7 +1324,7 @@ def _configure_basic_optimizer(self, model_parameters): "'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details" ) - if self.optimizer_name() in [ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER]: + if self.optimizer_name() in [ADAM_OPTIMIZER, ADAMW_OPTIMIZER]: torch_adam = optimizer_parameters.pop(TORCH_ADAM_PARAM, False) adam_w_mode = optimizer_parameters.pop(ADAM_W_MODE, ADAM_W_MODE_DEFAULT) @@ -1341,15 +1341,10 @@ def _configure_basic_optimizer(self, model_parameters): **optimizer_parameters) else: if self.zero_use_cpu_optimizer(): - if self.optimizer_name() == ADAGRAD_OPTIMIZER: - from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad - optimizer = DeepSpeedCPUAdagrad(model_parameters, - **optimizer_parameters) - else: - from deepspeed.ops.adam import DeepSpeedCPUAdam - optimizer = DeepSpeedCPUAdam(model_parameters, - **optimizer_parameters, - adamw_mode=effective_adam_w_mode) + from deepspeed.ops.adam import DeepSpeedCPUAdam + optimizer = DeepSpeedCPUAdam(model_parameters, + **optimizer_parameters, + adamw_mode=effective_adam_w_mode) else: from deepspeed.ops.adam import FusedAdam @@ -1359,6 +1354,12 @@ def _configure_basic_optimizer(self, model_parameters): adam_w_mode=effective_adam_w_mode, ) + elif self.optimizer_name() == ADAGRAD_OPTIMIZER: + if self.zero_use_cpu_optimizer(): + from deepspeed.ops.adagrad import 
DeepSpeedCPUAdagrad + optimizer = DeepSpeedCPUAdagrad(model_parameters, **optimizer_parameters) + else: + optimizer = torch.optim.Adagrad(model_parameters, **optimizer_parameters) elif self.optimizer_name() == LAMB_OPTIMIZER: from deepspeed.ops.lamb import FusedLamb From 291fa3de4674f8a26fa23460568c35d77cec5c47 Mon Sep 17 00:00:00 2001 From: Joe Mayer Date: Wed, 19 Apr 2023 21:50:38 +0000 Subject: [PATCH 2/4] adding adagrad for zero 1 2 --- deepspeed/runtime/zero/stage_1_and_2.py | 5 ++++- deepspeed/runtime/zero/utils.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 0e7a6115b091..370e9ca3f95e 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -611,7 +611,10 @@ def initialize_optimizer_states(self): self.single_partition_of_fp32_groups[i].grad = get_accelerator().pin_memory( single_grad_partition) if self.cpu_offload else single_grad_partition - self.optimizer.step() + if isinstance(self.optimizer, torch.optim.Adagrad): + self.optimizer = torch.optim.Adagrad(self.single_partition_of_fp32_groups, **self.optimizer.defaults) + else: + self.optimizer.step() if not self.cpu_offload: for group in self.single_partition_of_fp32_groups: diff --git a/deepspeed/runtime/zero/utils.py b/deepspeed/runtime/zero/utils.py index 81a301c8d782..e86bab7f4080 100755 --- a/deepspeed/runtime/zero/utils.py +++ b/deepspeed/runtime/zero/utils.py @@ -35,7 +35,7 @@ class ZeRORuntimeException(Exception): pass -ZERO_SUPPORTED_OPTIMIZERS = [torch.optim.Adam, torch.optim.AdamW, FusedAdam, DeepSpeedCPUAdam] +ZERO_SUPPORTED_OPTIMIZERS = [torch.optim.Adam, torch.optim.AdamW, FusedAdam, DeepSpeedCPUAdam, torch.optim.Adagrad] # Add apex FusedAdam to supported list if apex is installed try: From afa608702c4943c9c293e7876d73f0de5b376d8a Mon Sep 17 00:00:00 2001 From: Joe Mayer Date: Thu, 27 Apr 2023 23:54:41 +0000 Subject: [PATCH 3/4] Adding 
Adagrad support to zero 3. --- deepspeed/runtime/zero/stage3.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 969b7ecdf675..4923049172b9 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -858,6 +858,7 @@ def initialize_optimizer_states(self): gradient_buffer = torch.zeros(int(largest_numel), dtype=gradient_dtype, device=self.device) timer_names = set() + is_adagrad = isinstance(self.optimizer, torch.optim.Adagrad) if self.swap_optimizer: self.optimizer_swapper.init_timers() @@ -888,7 +889,8 @@ def initialize_optimizer_states(self): else: self.fp32_partitioned_groups_flat[i].grad = gradient_buffer.narrow(0, 0, num_elements) - self._optimizer_step(i) + if not is_adagrad: + self._optimizer_step(i) if swappable_param_subgroup: self._partitioned_params_swap_out(i) @@ -900,6 +902,9 @@ def initialize_optimizer_states(self): f'[End] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}', force=False) + if is_adagrad: + self.optimizer = torch.optim.Adagrad(self.fp32_partitioned_groups_flat, **self.optimizer.defaults) + self.stop_timers([INIT_OPTIMIZER_TIMER]) self.log_timers(timer_names) From f80d7911849ec9d4d9619a09cf3f2f3a0c65173a Mon Sep 17 00:00:00 2001 From: Joe Mayer Date: Tue, 2 May 2023 17:13:35 +0000 Subject: [PATCH 4/4] Adding documentation and DeepSpeedCPUAdagrad to list. 
--- deepspeed/runtime/zero/stage3.py | 5 +++++ deepspeed/runtime/zero/stage_1_and_2.py | 3 +++ deepspeed/runtime/zero/utils.py | 5 ++++- 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 4923049172b9..4f873327b9dc 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -858,6 +858,9 @@ def initialize_optimizer_states(self): gradient_buffer = torch.zeros(int(largest_numel), dtype=gradient_dtype, device=self.device) timer_names = set() + + # State initialization for the Adagrad optimizer occurs at construction as opposed to other optimizers + # which do lazy initialization of the state at the first call to step. is_adagrad = isinstance(self.optimizer, torch.optim.Adagrad) if self.swap_optimizer: @@ -889,6 +892,7 @@ def initialize_optimizer_states(self): else: self.fp32_partitioned_groups_flat[i].grad = gradient_buffer.narrow(0, 0, num_elements) + # Initialize the optimizer states with the flattened fp32 partition. if not is_adagrad: self._optimizer_step(i) @@ -902,6 +906,7 @@ def initialize_optimizer_states(self): f'[End] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}', force=False) + # Initialize the optimizer states with the flattened fp32 partition. 
if is_adagrad: self.optimizer = torch.optim.Adagrad(self.fp32_partitioned_groups_flat, **self.optimizer.defaults) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 1fa4d0da7171..f5ab3982b4d6 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -611,6 +611,9 @@ def initialize_optimizer_states(self): self.single_partition_of_fp32_groups[i].grad = get_accelerator().pin_memory( single_grad_partition) if self.cpu_offload else single_grad_partition + # Initialize the optimizer states with the flattened fp32 partition. + # State initialization for the Adagrad optimizer occurs at construction as opposed to other optimizers + # which do lazy initialization of the state at the first call to step. if isinstance(self.optimizer, torch.optim.Adagrad): self.optimizer = torch.optim.Adagrad(self.single_partition_of_fp32_groups, **self.optimizer.defaults) else: diff --git a/deepspeed/runtime/zero/utils.py b/deepspeed/runtime/zero/utils.py index e86bab7f4080..0250796f793d 100755 --- a/deepspeed/runtime/zero/utils.py +++ b/deepspeed/runtime/zero/utils.py @@ -10,6 +10,7 @@ from deepspeed import comm as dist from deepspeed.utils import logger from deepspeed.ops.adam import DeepSpeedCPUAdam +from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad from deepspeed.ops.adam import FusedAdam from deepspeed.utils.nvtx import instrument_w_nvtx from deepspeed.accelerator import get_accelerator @@ -35,7 +36,9 @@ class ZeRORuntimeException(Exception): pass -ZERO_SUPPORTED_OPTIMIZERS = [torch.optim.Adam, torch.optim.AdamW, FusedAdam, DeepSpeedCPUAdam, torch.optim.Adagrad] +ZERO_SUPPORTED_OPTIMIZERS = [ + torch.optim.Adam, torch.optim.AdamW, FusedAdam, DeepSpeedCPUAdam, torch.optim.Adagrad, DeepSpeedCPUAdagrad +] # Add apex FusedAdam to supported list if apex is installed try: