Don't check overflow for bf16 data type (#4512)
Always check overflow for fp16.
bf16's dynamic range is similar to fp32's, so don't check for overflow by default.
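The rationale can be illustrated with `torch.finfo`: bf16 keeps fp32's 8 exponent bits, so its representable range is close to fp32's, while fp16 tops out near 65504 and can genuinely overflow during training. A minimal sketch (not part of the commit):

```python
import torch

# torch.finfo reports each floating-point dtype's range and precision.
fp16 = torch.finfo(torch.float16)
bf16 = torch.finfo(torch.bfloat16)
fp32 = torch.finfo(torch.float32)

# fp16 overflows past ~65504, so loss scaling must watch for inf/nan.
print(fp16.max)  # 65504.0

# bf16 shares fp32's 8 exponent bits: both max values are ~3.4e38,
# so gradient overflow is essentially a non-issue for bf16.
print(f"{bf16.max:.2e}")
print(f"{fp32.max:.2e}")
```

This is why the commit gates `check_overflow()` on `self.dtype == torch.float16`: the check (and the collective communication it implies) buys nothing for bf16.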

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
BacharL and tjruwase committed Oct 27, 2023
1 parent 8f168c2 commit 244040c
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 2 additions & 1 deletion deepspeed/runtime/zero/stage3.py
@@ -1844,7 +1844,8 @@ def _overflow_clean_up(self, prev_scale):
     def _overflow_check_and_loss_scale_update(self):

         # First compute norm for all group so we know if there is overflow
-        self.check_overflow()
+        if self.dtype == torch.float16:
+            self.check_overflow()

         #loss scaling related computation
         prev_scale = self.loss_scale
3 changes: 2 additions & 1 deletion deepspeed/runtime/zero/stage_1_and_2.py
@@ -1705,7 +1705,8 @@ def step(self, closure=None):
         see_memory_usage(f"In step before checking overflow")

         # First compute norm for all group so we know if there is overflow
-        self.check_overflow()
+        if self.dtype == torch.float16:
+            self.check_overflow()

         prev_scale = self.loss_scale
         self._update_scale(self.overflow)
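The behavior of the gate added in both files can be sketched in isolation. `TinyOptimizerSketch` below is a hypothetical stand-in, not DeepSpeed's actual optimizer class; it only shows that after this change an inf/nan gradient trips the overflow flag under fp16 but is ignored under bf16:

```python
import torch

class TinyOptimizerSketch:
    """Hypothetical sketch of the dtype gate this commit adds.
    Not DeepSpeed's real ZeRO optimizer."""

    def __init__(self, dtype):
        self.dtype = dtype
        self.overflow = False

    def check_overflow(self, grads):
        # Flag overflow if any gradient tensor contains inf or nan.
        self.overflow = any(not torch.isfinite(g).all() for g in grads)

    def overflow_check_and_loss_scale_update(self, grads):
        # The commit's change: only fp16 runs the overflow scan;
        # bf16 (and fp32) skip it, so self.overflow stays False.
        if self.dtype == torch.float16:
            self.check_overflow(grads)
        return self.overflow

bad_grads = [torch.tensor([1.0, float("inf")])]
print(TinyOptimizerSketch(torch.float16).overflow_check_and_loss_scale_update(bad_grads))   # True
print(TinyOptimizerSketch(torch.bfloat16).overflow_check_and_loss_scale_update(bad_grads))  # False
```

The design choice mirrors the commit message: skipping the check for bf16 avoids pointless work (including any cross-rank reduction of the overflow flag) for a dtype that is very unlikely to overflow.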
