[docs] critical API documentation is missing #776

@stas00

Description

It looks like things kind of work, except not quite right, and there are a lot of subtle nuances that are so hard to know about when integrating DeepSpeed. I think all these should be made loud and clear - and perhaps a simple full example of a training loop would help, including showing commented out code where the original training code is removed to do it the DeepSpeed-way.
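For reference, here is a rough sketch of what such a side-by-side example could look like. This is my own minimal, runnable plain-PyTorch skeleton, not code from DeepSpeed's docs; the commented-out lines mark where the DeepSpeed-way would hypothetically replace the original calls (engine, ds_config.json, etc. are placeholder names), based on the behaviors listed below.

```python
import torch

# toy model/data so the skeleton is runnable; substitute your real ones
model = torch.nn.Linear(5, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = [(torch.randn(4, 5), torch.randn(4, 1)) for _ in range(4)]
grad_accum_steps = 2

# DeepSpeed-way (sketch): the engine wraps model/optimizer/scheduler
# engine, optimizer, _, lr_scheduler = deepspeed.initialize(
#     model=model, config="ds_config.json", ...)

for step, (x, y) in enumerate(data):
    loss = torch.nn.functional.mse_loss(model(x), y)

    # plain PyTorch: the client scales the loss itself
    (loss / grad_accum_steps).backward()
    # DeepSpeed-way: do NOT pre-scale -- engine.backward() divides by
    # gradient_accumulation_steps internally and returns the modified loss
    # loss = engine.backward(loss)

    if (step + 1) % grad_accum_steps == 0:
        optimizer.step()       # DeepSpeed-way: call engine.step() on EVERY
        optimizer.zero_grad()  #   step; it must not be skipped on accum steps
        # lr_scheduler.step()  # DeepSpeed-way: engine.step() already did this
```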

As I am trying to figure out how to make gradient_accumulation_steps work correctly, I'm finding all kinds of things I missed when integrating DeepSpeed into the HF Trainer. I will post them here as I find them:

  1. The engine's backward returns the loss, which it modifies when gradient_accumulation_steps > 1, but this is undocumented.

  2. It's also not documented that the "client" must not scale the loss by gradient_accumulation_steps itself, since DeepSpeed already does that inside backward.

  3. The fact that lr_scheduler.step() happens inside the engine's step() is not documented in the API.

  4. The "client" must not skip engine.step() when gradient_accumulation_steps > 1, and since this is an integration of many approaches, it leads to complicated, brittle code:

                if self.deepspeed:
                    self.deepspeed.step()

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    steps_in_epoch <= self.args.gradient_accumulation_steps
                    and (step + 1) == steps_in_epoch
                ):
                    # Gradient clipping
                    if self.args.max_grad_norm is not None and self.args.max_grad_norm > 0 and not self.deepspeed:
                        # deepspeed does its own clipping
                        if self.use_amp:
                            # AMP: gradients need unscaling
                            self.scaler.unscale_(self.optimizer)
                        [...]
                        else:
                            # Revert to normal clipping otherwise, handling Apex or full precision
                            torch.nn.utils.clip_grad_norm_(
                                amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
                                self.args.max_grad_norm,
                            )

                    # Optimizer step
                    if self.deepspeed:
                        pass # called outside the loop
                    [...]
                    else:
                        self.optimizer.step()

                    if not self.deepspeed:
                        self.lr_scheduler.step()

                    model.zero_grad()
                    [...]

After fixing all four of the above, I managed to get identical weights and loss with bs=8/grad_accum=1 and bs=4/grad_accum=2. Yay!
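This equivalence hinges on point 2: if the loss of each micro-batch is divided by gradient_accumulation_steps before backward (which DeepSpeed does internally), the accumulated gradient matches the full-batch gradient exactly. A toy plain-PyTorch check of that arithmetic (my own illustration, not DeepSpeed code):

```python
import torch

torch.manual_seed(0)
# toy linear model and data; any mean-reduced loss behaves the same way
w = torch.randn(5, 1)
x = torch.randn(8, 5)
y = torch.randn(8, 1)

def grad_of(w, xb, yb):
    w = w.clone().requires_grad_(True)
    loss = ((xb @ w - yb) ** 2).mean()
    loss.backward()
    return w.grad

# one step with the full batch of 8 (bs=8, grad_accum=1)
g_full = grad_of(w, x, y)

# two micro-batches of 4 (bs=4, grad_accum=2), each loss scaled by 1/2
w2 = w.clone().requires_grad_(True)
for i in range(2):
    xb, yb = x[i * 4:(i + 1) * 4], y[i * 4:(i + 1) * 4]
    loss = ((xb @ w2 - yb) ** 2).mean() / 2  # the scaling done inside backward
    loss.backward()                          # grads accumulate in w2.grad

assert torch.allclose(g_full, w2.grad, atol=1e-6)
```

If the client scaled the loss a second time, the accumulated gradient would be off by another factor of gradient_accumulation_steps, which is exactly the kind of silent mismatch these undocumented behaviors produce.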
