It looks like things kind of work, except not quite right, and there are a lot of subtle nuances that are so hard to know about when integrating DeepSpeed. I think all these should be made loud and clear - and perhaps a simple full example of a training loop would help, including showing commented out code where the original training code is removed to do it the DeepSpeed-way.
As I am trying to figure out how to make gradient_accumulation_steps work correctly I'm finding all kinds of things I have missed when integrating DeepSpeed into HF Trainer. I will post them here as I find such things:
- engine's `backward` returns the `loss`, which it modifies under `gradient_accumulation_steps > 1`, but this is undocumented:
  - neither in the API docstring: https://github.com/microsoft/DeepSpeed/blob/e60e92eb0a06673748c4cb63fbcf713ddd12fc22/deepspeed/runtime/engine.py#L852-L858
  - nor in the main docs: https://www.deepspeed.ai/getting-started/#training
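To make the behavior concrete, here is a minimal pure-Python mock of what is being described — a DeepSpeed-style `backward` that returns the loss already divided by `gradient_accumulation_steps`. The class and names are illustrative only, not DeepSpeed's actual implementation:

```python
# Hypothetical mock illustrating the undocumented behavior: under
# gradient_accumulation_steps > 1, engine.backward(loss) returns a loss
# that has already been scaled by 1/gradient_accumulation_steps.
# (illustrative sketch, not DeepSpeed's real code)

class MockEngine:
    def __init__(self, gradient_accumulation_steps):
        self.gradient_accumulation_steps = gradient_accumulation_steps

    def backward(self, loss):
        # DeepSpeed scales the loss internally when accumulating,
        # so the client must not scale it again.
        if self.gradient_accumulation_steps > 1:
            loss = loss / self.gradient_accumulation_steps
        # (a real engine would also run the actual backward pass here)
        return loss

engine = MockEngine(gradient_accumulation_steps=2)
returned = engine.backward(1.0)
print(returned)  # 0.5 -- the raw loss was 1.0
```

So if the client logs the value returned by `backward`, it is already the scaled loss, not the raw one.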
- Also it's not documented that the "client" shouldn't scale the loss by `gradient_accumulation_steps`, since DeepSpeed already does it in `backward`.
- the fact that `lr_scheduler.step` happens inside engine's `step` is not documented in the API: https://github.com/microsoft/DeepSpeed/blob/e60e92eb0a06673748c4cb63fbcf713ddd12fc22/deepspeed/runtime/engine.py#L993-L996
  - but it is documented in https://www.deepspeed.ai/getting-started/#training
  - it might be a good idea to also add an explicit note: make sure to remove `lr_scheduler.step()` from your code if using DeepSpeed's scheduler.
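A small sketch of the pitfall, using hypothetical mock classes (not DeepSpeed's API): the engine's `step` advances the scheduler internally, so a leftover client-side `lr_scheduler.step()` would double-advance the schedule.

```python
# Illustrative mock: engine.step() advances the LR scheduler internally,
# so the client must NOT also call lr_scheduler.step().
# (hypothetical names, for demonstration only)

class MockScheduler:
    def __init__(self):
        self.steps = 0

    def step(self):
        self.steps += 1

class MockEngine:
    def __init__(self, lr_scheduler):
        self.lr_scheduler = lr_scheduler

    def step(self):
        # ... optimizer step would happen here ...
        self.lr_scheduler.step()  # scheduler advanced inside engine.step()

sched = MockScheduler()
engine = MockEngine(sched)
for _ in range(3):
    engine.step()
    # sched.step()  # <- removed: keeping it would advance the LR twice per step

print(sched.steps)  # 3
```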
- the "client" must not skip `engine.step()` when `gradient_accumulation_steps > 1`, and since this is an integration of many methods it leads to complicated, brittle code:
```python
if self.deepspeed:
    self.deepspeed.step()

if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
    # last step in epoch but step is always smaller than gradient_accumulation_steps
    steps_in_epoch <= self.args.gradient_accumulation_steps
    and (step + 1) == steps_in_epoch
):
    # Gradient clipping
    if self.args.max_grad_norm is not None and self.args.max_grad_norm > 0 and not self.deepspeed:
        # deepspeed does its own clipping
        if self.use_amp:
            # AMP: gradients need unscaling
            self.scaler.unscale_(self.optimizer)
            [...]
        else:
            # Revert to normal clipping otherwise, handling Apex or full precision
            torch.nn.utils.clip_grad_norm_(
                amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
                self.args.max_grad_norm,
            )

    # Optimizer step
    if self.deepspeed:
        pass  # called outside the loop
    [...]
    else:
        self.optimizer.step()

    if not self.deepspeed:
        self.lr_scheduler.step()

    model.zero_grad()
    [...]
```
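The core of the control flow above can be reduced to a tiny mock (hypothetical names, not DeepSpeed's real engine): `engine.step()` is called on *every* micro-step and the engine internally decides when to actually run the optimizer, whereas the vanilla path only steps the optimizer at accumulation boundaries.

```python
# Minimal sketch of the invariant: under accumulation the engine's step()
# must be called unconditionally on each micro-batch; the engine counts
# micro-steps and only applies the optimizer at accumulation boundaries.
# (hypothetical mock, illustrative only)

class MockDeepSpeedEngine:
    def __init__(self, gradient_accumulation_steps):
        self.gas = gradient_accumulation_steps
        self.micro_steps = 0
        self.optimizer_steps = 0

    def step(self):
        self.micro_steps += 1
        # boundary detection happens INSIDE the engine, not in client code
        if self.micro_steps % self.gas == 0:
            self.optimizer_steps += 1

engine = MockDeepSpeedEngine(gradient_accumulation_steps=2)
for step in range(8):  # 8 micro-batches
    engine.step()      # called every micro-step -- never skipped by the client

print(engine.micro_steps, engine.optimizer_steps)  # 8 4
```

Skipping `engine.step()` on non-boundary micro-steps would desynchronize the engine's internal micro-step counter, which is exactly the bug the bullet above warns about.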
After fixing the above 4 issues I managed to get the same weights and loss with bs=8/grad_accum=1 and bs=4/grad_accum=2. Yay!
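The equivalence just claimed can be sanity-checked numerically: for a mean-reduced loss, the gradient over one batch of 8 equals the average of the gradients over two micro-batches of 4 — which is what correct accumulation (loss scaled by `1/gradient_accumulation_steps` before backward) computes. A plain-Python check with a one-parameter linear model and made-up data:

```python
# Numeric check: grad over batch of 8 == average of grads over two
# micro-batches of 4, for a mean-reduced squared-error loss.
# (illustrative data; d/dw of mean((w*x - y)^2) is mean(2*(w*x - y)*x))

def grad(w, xs, ys):
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

xs = [0.5, -1.0, 2.0, 0.3, -0.7, 1.5, -2.2, 0.9]
ys = [1.0, 0.0, -1.5, 2.0, 0.4, -0.3, 1.1, -0.8]
w = 0.1

g_bs8 = grad(w, xs, ys)  # one step at batch size 8
# two micro-batches of 4, gradients averaged -- equivalent to scaling each
# micro-batch loss by 1/2 (the accumulation-steps factor) before backward
g_accum = (grad(w, xs[:4], ys[:4]) + grad(w, xs[4:], ys[4:])) / 2

assert abs(g_bs8 - g_accum) < 1e-12
```

This also shows why double-scaling the loss (once by the client, once inside `backward`) breaks the equivalence: the accumulated gradient would be off by a factor of `gradient_accumulation_steps`.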