diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index cbc4aa7e62d6..a442154a03bc 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1002,7 +1002,10 @@ def get_type(self): return self._get_optimizer_param('type') def get_mom(self): - return self._get_optimizer_param('betas') + if self.optimizer_name() in ['SGD', 'RMSprop']: + return self._get_optimizer_param('momentum') + else: + return self._get_optimizer_param('betas') def _report_progress(self, step): lr = self.get_lr()