BF16_Optimizer: add support for bf16 grad acc #4713

Merged · 3 commits · Dec 8, 2023
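
This PR lets the BF16 optimizer wrapper accumulate gradients in bf16 instead of fp32. Below is a minimal sketch of how that is typically requested, assuming the existing DeepSpeed "bf16" and "data_types"/"grad_accum_dtype" config keys (those keys are an assumption of this sketch, not part of the diff); the engine reads the pair back via get_data_types() and forwards the second element to BF16_Optimizer as grad_acc_dtype, as the engine.py change below shows.

# Hedged sketch: a DeepSpeed config dict requesting bf16 gradient accumulation.
# The "data_types"/"grad_accum_dtype" key is assumed here; "fp32" remains the
# conservative default if the key is omitted.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "bf16": {"enabled": True},
    "data_types": {"grad_accum_dtype": "bf16"},
}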
18 changes: 15 additions & 3 deletions deepspeed/runtime/bf16_optimizer.py
@@ -37,14 +37,19 @@ def __init__(self,
norm_type=2,
allgather_bucket_size=5000000000,
dp_process_group=None,
timers=None):
timers=None,
grad_acc_dtype=None):
super().__init__()
see_memory_usage('begin bf16_optimizer', force=True)
self.timers = timers
self.optimizer = init_optimizer
self.param_names = param_names
self.using_real_optimizer = not isinstance(self.optimizer, DummyOptim)

assert grad_acc_dtype in [torch.float32, torch.bfloat16
], f"BF16Optimizer: Unsupported gradient accumulation data type: {grad_acc_dtype}"
self.grad_acc_dtype = grad_acc_dtype

self.clip_grad = clip_grad
self.norm_type = norm_type
self.mpu = mpu
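
To make the accepted values concrete, here is a small hypothetical helper (not part of the PR) that maps a config-style string to the torch dtype the new grad_acc_dtype argument expects, enforcing the same fp32-or-bf16 constraint as the assert above.

import torch

# Hypothetical helper: resolve a config string to a dtype accepted by grad_acc_dtype.
_GRAD_ACC_DTYPES = {"fp32": torch.float32, "bf16": torch.bfloat16}

def resolve_grad_acc_dtype(name):
    if name not in _GRAD_ACC_DTYPES:
        raise ValueError(f"Unsupported gradient accumulation data type: {name}")
    return _GRAD_ACC_DTYPES[name]

print(resolve_grad_acc_dtype("bf16"))  # torch.bfloat16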
@@ -119,7 +124,8 @@ def _setup_for_real_optimizer(self):
num_elem_list = [t.numel() for t in self.bf16_groups[i]]

# create fp32 gradients
self.fp32_groups_gradients_flat.append(torch.zeros_like(self.bf16_groups_flat[i], dtype=torch.float32))
self.fp32_groups_gradients_flat.append(
torch.zeros_like(self.bf16_groups_flat[i], dtype=self.grad_acc_dtype))

# track individual fp32 gradients for entire model
fp32_gradients = self._split_flat_tensor(flat_tensor=self.fp32_groups_gradients_flat[i],
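
The changed line above is the core of the feature: the flat gradient buffer that per-parameter gradients accumulate into is now allocated in the configured grad_acc_dtype rather than hard-coded fp32. A minimal plain-torch sketch of that buffer-plus-views layout (no DeepSpeed, names are illustrative):

import torch

grad_acc_dtype = torch.bfloat16
params = [torch.randn(4, 4, dtype=torch.bfloat16), torch.randn(8, dtype=torch.bfloat16)]

flat = torch.cat([p.reshape(-1) for p in params])          # flattened bf16 param group
grads_flat = torch.zeros_like(flat, dtype=grad_acc_dtype)  # accumulation buffer in grad_acc_dtype

# carve per-parameter gradient views out of the flat buffer, similar in spirit
# to _split_flat_tensor
views, offset = [], 0
for p in params:
    n = p.numel()
    views.append(grads_flat.narrow(0, offset, n).view_as(p))
    offset += n

views[0].add_(1.0)     # accumulating into a view...
print(grads_flat[:4])  # ...updates the shared flat buffer in place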
@@ -204,10 +210,16 @@ def initialize_optimizer_states(self):
"""
for param_partition, grad_partition in zip(self.fp32_groups_flat_partition,
self.fp32_groups_gradient_flat_partition):
param_partition.grad = grad_partition
# If the grad accumulation dtype differs from FP32, cast up to the higher precision.
param_partition.grad = grad_partition.to(
param_partition.dtype) if grad_partition.dtype != param_partition.dtype else grad_partition

self.optimizer.step()

if self.grad_acc_dtype is not torch.float32:
for param_partition in self.fp32_groups_flat_partition:
param_partition.grad = None

self.clear_hp_grads()

def _split_flat_tensor(self, flat_tensor, num_elem_list):
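
Putting the initialize_optimizer_states change in isolation: when the accumulation buffer is bf16 but the master weights are fp32, the gradient is cast up to fp32 only for the warm-up step that materializes the optimizer state, and the temporary .grad is dropped afterwards. A plain-torch sketch under those assumptions (single parameter, Adam standing in for the wrapped optimizer):

import torch

master_param = torch.zeros(16, dtype=torch.float32, requires_grad=True)
bf16_grad_buffer = torch.zeros(16, dtype=torch.bfloat16)

optimizer = torch.optim.Adam([master_param], lr=1e-3)

# cast the bf16 accumulation buffer up to the master weight dtype for the step
grad = bf16_grad_buffer
master_param.grad = grad.to(master_param.dtype) if grad.dtype != master_param.dtype else grad
optimizer.step()  # lazily allocates the optimizer state

if bf16_grad_buffer.dtype is not torch.float32:
    master_param.grad = None  # release the temporary fp32 cast

print(optimizer.state[master_param].keys())  # exp_avg / exp_avg_sq now exist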
15 changes: 11 additions & 4 deletions deepspeed/runtime/engine.py
@@ -1180,9 +1180,15 @@ def _do_optimizer_sanity_check(self, basic_optimizer):
# data type checks
elif model_dtype == grad_accum_dtype:
if model_dtype == torch.bfloat16:
raise NotImplementedError(
"Bfloat16 wrapper must use a gradient accumulation type of fp32, enable ZeRO to use Bfloat16 gradient accumulation"
)
if self.pipeline_parallelism:
logger.warning(
"**** BF16 gradient accumulation is not safe numerically with large number of accumulation steps, proceed with caution *****"
)
return BFLOAT16
else:
raise NotImplementedError(
"Bfloat16 wrapper must use a gradient accumulation type of fp32, enable ZeRO to use Bfloat16 gradient accumulation"
)
if model_dtype == torch.float16:
return FP16
# else optimizer_wrapper = None
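
For readability, the new sanity-check branch can be summarized as a standalone function (a simplified, hypothetical rendering, not DeepSpeed code): bf16 weights with bf16 gradient accumulation now select the BF16 wrapper when pipeline parallelism is active, with a numerical-safety warning, while the original NotImplementedError is kept otherwise.

import torch

def choose_wrapper(model_dtype, grad_accum_dtype, pipeline_parallelism):
    if model_dtype == grad_accum_dtype == torch.bfloat16:
        if pipeline_parallelism:
            print("warning: bf16 gradient accumulation can lose precision over many steps")
            return "BFLOAT16"
        raise NotImplementedError(
            "Bfloat16 wrapper must use a gradient accumulation type of fp32, "
            "enable ZeRO to use Bfloat16 gradient accumulation")
    if model_dtype == torch.float16:
        return "FP16"
    return None

print(choose_wrapper(torch.bfloat16, torch.bfloat16, pipeline_parallelism=True))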
@@ -1444,7 +1450,8 @@ def _configure_bf16_optimizer(self, optimizer):
clip_grad=clip_grad,
allgather_bucket_size=self.zero_allgather_bucket_size(),
dp_process_group=self.seq_data_parallel_group,
timers=timers)
timers=timers,
grad_acc_dtype=self.get_data_types()[1])

return optimizer
