
Why does DeepSpeed have a clip_grad_norm that's just an alias to torch.nn.utils.clip_grad_norm_? #611

@stas00


Is this a leftover from older times? It results in some odd code: there is a dedicated method, but it doesn't do anything special. It is simply an alias to torch.nn.utils.clip_grad_norm_.

transformers is integrating various other engines, and some of them (e.g. fairscale) do have a clip_grad_norm method with this signature:

```python
clip_grad_norm(self, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0) -> torch.Tensor
```
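The mismatch can be illustrated with a small, dependency-free sketch. It assumes a hypothetical pure-Python stand-in, clip_grad_norm_, that mirrors the leading parameters of torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=2.0), plus two hypothetical optimizer classes: one with the fairscale-style method above, and one that merely aliases the free function (assumed here to be stored as an instance attribute). The alias still expects the parameters as its first argument, so calling it the fairscale way, with only max_norm, raises a TypeError.

```python
from typing import Iterable, List


def clip_grad_norm_(parameters: Iterable[List[float]], max_norm: float, norm_type: float = 2.0) -> float:
    """Hypothetical stand-in for torch.nn.utils.clip_grad_norm_: each "parameter"
    is a plain list of gradient values, scaled in place when the total norm exceeds max_norm."""
    params = list(parameters)
    total_norm = sum(abs(g) ** norm_type for p in params for g in p) ** (1.0 / norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        for p in params:
            for i in range(len(p)):
                p[i] *= clip_coef
    return total_norm


class ShardedStyleOptimizer:
    """Hypothetical optimizer with the fairscale-style signature
    clip_grad_norm(max_norm, norm_type=2.0); it already knows its parameters."""

    def __init__(self, params: List[List[float]]):
        self.params = params

    def clip_grad_norm(self, max_norm: float, norm_type: float = 2.0) -> float:
        return clip_grad_norm_(self.params, max_norm, norm_type)


class AliasStyleOptimizer:
    """Hypothetical optimizer that only aliases the free function
    (assumed to be an instance attribute, so no self is bound)."""

    def __init__(self, params: List[List[float]]):
        self.params = params
        self.clip_grad_norm = clip_grad_norm_


grads = [[3.0, 4.0]]  # total L2 norm is 5.0
print(ShardedStyleOptimizer(grads).clip_grad_norm(1.0))  # 5.0; grads are rescaled to norm ~1.0

try:
    # Same call shape as the Trainer code: only max_norm is passed.
    AliasStyleOptimizer([[3.0, 4.0]]).clip_grad_norm(1.0)
except TypeError as exc:
    # 1.0 was bound to `parameters`, and `max_norm` is missing.
    print("TypeError:", exc)
```

In other words, the two methods share a name but not a calling convention, which is why a caller cannot treat them interchangeably.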

DeepSpeed having a method with the same name but a different signature is awkward, so I had to write it as:

```python
# [...]
if hasattr(self.optimizer, "clip_grad_norm") and not self.args.deepspeed:
    # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
    # deepspeed has clip_grad_norm aliased to torch.nn.utils.clip_grad_norm_
    self.optimizer.clip_grad_norm(self.args.max_grad_norm)
else:
    # Revert to normal clipping otherwise, handling Apex or full precision
    torch.nn.utils.clip_grad_norm_(
        amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
        self.args.max_grad_norm,
    )
```
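One way such a branch could avoid keying on args.deepspeed is to check whether the optimizer's clip_grad_norm is merely the free function itself. This is a minimal sketch, assuming the alias is stored as the exact same function object (an instance attribute). clip_grad_norm_ is again a hypothetical pure-Python stand-in for torch.nn.utils.clip_grad_norm_ (with torch, the identity check would compare against torch.nn.utils.clip_grad_norm_), and clip_gradients is a hypothetical helper name.

```python
from typing import Any, Iterable, List


def clip_grad_norm_(parameters: Iterable[List[float]], max_norm: float, norm_type: float = 2.0) -> float:
    """Hypothetical stand-in with the leading parameters of torch.nn.utils.clip_grad_norm_."""
    params = list(parameters)
    total_norm = sum(abs(g) ** norm_type for p in params for g in p) ** (1.0 / norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        for p in params:
            for i in range(len(p)):
                p[i] *= clip_coef
    return total_norm


def clip_gradients(optimizer: Any, parameters: Iterable[List[float]], max_norm: float) -> float:
    """Hypothetical helper: prefer the optimizer's own clip_grad_norm(max_norm),
    unless it is missing or is just an alias of the free function."""
    method = getattr(optimizer, "clip_grad_norm", None)
    if method is not None and method is not clip_grad_norm_:
        # Engine-specific clipping (e.g. a sharded optimizer) that owns its parameters.
        return method(max_norm)
    # Fall back to the generic function, which needs the parameters passed in.
    return clip_grad_norm_(parameters, max_norm)
```

The trade-off is that the check relies on the alias being the identical object; if the alias were ever wrapped or bound differently, the identity test would no longer catch it.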

If this method isn't overridden anywhere, perhaps it could be removed?

On the other hand, I realize there is no standard requiring engines to use identical signatures for methods with the same name and purpose, so you are entirely within your rights to define it differently.

I just wanted to ask whether this alias serves a special purpose; if it does, we can of course leave the slightly odd code as it is.

Thank you!
