
AdaFactor: avoid updating group["lr"] attributes #9751

Merged (1 commit into huggingface:master) on Feb 1, 2021

Conversation

@ceshine (Contributor) commented on Jan 22, 2021

This affects Adafactor with relative_step=False and scale_parameter=True.

Updating group["lr"] makes the result of ._get_lr() depend on the previous call, i.e., on the scale of other parameters. This isn't supposed to happen.

What does this PR do?

I've observed weird behavior when using Adafactor with relative_step=False and scale_parameter=True together with an LR scheduler. I think the problem is that the code updates the lr attribute of the current parameter group and then uses that updated value as the starting point for the next parameter's learning-rate calculation. I don't think this is supposed to happen.

A simple fix is to replace the in-place update with an assignment to a local variable, as sketched below.
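
For reference, here is roughly what the helper looks like after the change. This is a sketch based on my description above, not the exact diff; the real code is the `_get_lr` staticmethod on `Adafactor` in `transformers.optimization`.

```python
import math

# Sketch of Adafactor._get_lr after the change. Before this PR the last two
# lines were roughly:
#     param_group["lr"] = rel_step_sz * param_scale
#     return param_group["lr"]
# so every call overwrote group["lr"] with the previous parameter's scaled value.
def _get_lr(param_group, param_state):
    rel_step_sz = param_group["lr"]
    if param_group["relative_step"]:
        min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
        rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
    param_scale = 1.0
    if param_group["scale_parameter"]:
        param_scale = max(param_group["eps"][1], param_state["RMS"])
    # Keep the result in a local value instead of writing it back into
    # param_group["lr"], so the next parameter starts from the user-supplied lr.
    return param_scale * rel_step_sz
```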

I'm not entirely sure if I understand the problem correctly, so I apologize in advance if this is a stupid PR. I'd appreciate it if someone could point out where I am wrong. Thanks!

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@moscow25 @sshleifer

@sshleifer (Contributor) commented:

Can you provide evidence that supports the following:

> Updating group["lr"] makes the result of ._get_lr() depend on the previous call, i.e., on the scale of other parameters. This isn't supposed to happen.

Thanks!

@ceshine (Contributor, Author) commented on Jan 22, 2021

> Can you provide evidence that supports the following:
>
> Updating group["lr"] makes the result of ._get_lr() depend on the previous call, i.e., on the scale of other parameters. This isn't supposed to happen.
>
> Thanks!

Hi,

Thanks for the quick reply.

This is taken from the Adafactor paper:

[Figure: screenshots of the paper's definitions of the relative step size ρ_t and the step size α_t.]
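
In symbols, my transcription of those definitions (the warmup_init variant reflects the implementation rather than the paper):

```latex
% Relative step size (relative_step=True). With warmup_init=True the
% implementation replaces the constant 10^{-2} by 10^{-6} t.
\rho_t = \min\left(10^{-2},\ \frac{1}{\sqrt{t}}\right)

% Step size applied to parameter matrix X (scale_parameter=True):
\alpha_t = \max\left(\epsilon_2,\ \operatorname{RMS}(X_{t-1})\right)\,\rho_t
```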

As you can see, ρ_t depends only on the step number when we use relative steps. If we switch to any other learning-rate schedule (in my case, linear warmup + cosine decay), it doesn't make sense for the ρ part to depend on the scale of the other parameters, nor can I find any reference to this approach in the paper.

If we (loosely) factor the α_t in the original implementation into α_{i,t}, where i indexes the parameters iterated over in the for p in group["params"] loop, the original implementation essentially made α_{i,t} depend on α_{i-1,t} (i.e., it set ρ_{i,t} = α_{i-1,t}).

@ceshine (Contributor, Author) commented on Jan 23, 2021

> I've observed weird behavior when using Adafactor with relative_step=False and scale_parameter=True together with an LR scheduler.

I should probably clarify what I meant by "weird behavior." The model (T5 v1.1) never converged when trained with Adafactor using relative_step=False and scale_parameter=True. After this patch, I managed to get convergence and even better results than with the built-in LR schedule in relative_step=True mode (with warmup_init=True).
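
For context, the setup in question looks roughly like this (illustrative only; the model name and step counts are placeholders rather than my exact configuration):

```python
from transformers import T5ForConditionalGeneration
from transformers.optimization import Adafactor, get_cosine_schedule_with_warmup

model = T5ForConditionalGeneration.from_pretrained("google/t5-v1_1-base")

optimizer = Adafactor(
    model.parameters(),
    lr=1e-3,               # base lr; the external schedule handles warmup/decay
    relative_step=False,   # disable Adafactor's built-in schedule
    scale_parameter=True,  # keep the RMS(X) scaling from the paper
    warmup_init=False,
)
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=500, num_training_steps=10_000
)

# In the training loop: loss.backward(); optimizer.step(); scheduler.step(); optimizer.zero_grad()
```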

@sshleifer (Contributor) commented:

cc @patrickvonplaten @patil-suraj
This looks like a reasonable change to me!

@patrickvonplaten (Contributor) left a comment

I agree very much with your explanation here @ceshine - that's a great fix, thanks!

BTW, if you have some working code for how to train a google/t5v1_1 model I think it would be super helpful to post it here, on the forum or as a community notebook! Many people have been asking for good t5v1_1 training scripts :-)

@stas00 (Contributor) left a comment

Looks good. Thank you!

It doesn't look like any other entry in group gets modified.

Ideally, a situation like this is a great opportunity to add a test that detects the problem, i.e. lack of convergence, though I can imagine this would be quite tricky to accomplish!
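
A cheaper check than a convergence test might be something like the following sketch (hypothetical, not part of this PR): it only asserts that a step never mutates the user-supplied group["lr"].

```python
import torch
from transformers.optimization import Adafactor

def test_adafactor_step_does_not_mutate_group_lr():
    # Two parameters with very different scales; before the fix, one
    # parameter's scaled lr leaked into group["lr"] and affected the other.
    small = torch.nn.Parameter(torch.randn(4) * 1e-3)
    large = torch.nn.Parameter(torch.randn(4) * 1e3)
    optimizer = Adafactor(
        [small, large], lr=1e-3, relative_step=False, scale_parameter=True
    )

    (small.sum() + large.sum()).backward()
    optimizer.step()

    # The stored lr should still be exactly the value the user passed in.
    assert optimizer.param_groups[0]["lr"] == 1e-3
```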

@LysandreJik (Member) left a comment

Thank you for your explanation and for providing references. LGTM.

@sgugger merged commit 8672bcd into huggingface:master on Feb 1, 2021
@ceshine (Contributor, Author) commented on Feb 3, 2021

Thank you all for your time and for accepting the patch! Glad to have made a tiny contribution to this great library.

> BTW, if you have some working code for how to train a google/t5v1_1 model I think it would be super helpful to post it here, on the forum or as a community notebook! Many people have been asking for good t5v1_1 training scripts :-)

I don't have anything that is sufficiently readable yet. Nonetheless, I have these notebooks published on Kaggle that use the patched Adafactor: one for T5 v1.1 and one for mT5. They are based on this GitHub repo, which is quite messy at the moment. The part that sets up the optimizer is located here.
