Instability when resuming trains #13

Closed · angusturner opened this issue Feb 26, 2023 · 22 comments

@angusturner

Hi, I have been testing this out on some diffusion models I am training.

Convergence seems decent (somewhat faster than AdamW, using 1/10th the learning rate and 10x the weight decay). However, I recently paused a few experiments and tried to resume, and the loss explodes immediately. I do not face this issue when resuming AdamW runs.

I have also found it necessary to use an LR warm-up period in my runs (even with the 1/10th learning rate), which, again, is not required with AdamW. I'll try to do a bit more digging to see if I can track down the source of the instability. However, for resuming experiments, surely if I load the optimizer state correctly things should resume as expected?

My only thought is whether something could be going wrong with the saving of the EMA / moving-average statistics? If I get a chance to dig into this more I'll let you know what I find. (Possibly I am doing something wrong.)
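A minimal sketch of the swap described above, assuming the `Lion` class from the `lion-pytorch` package; the model and the exact hyperparameter values are placeholders:

```python
import torch
from lion_pytorch import Lion

model = torch.nn.Linear(512, 512)  # stand-in for the actual diffusion model

# AdamW baseline for comparison (illustrative values)
# optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-2)

# Lion with roughly 1/10th the learning rate and 10x the weight decay of the AdamW run
optimizer = Lion(model.parameters(), lr=3e-5, weight_decay=1e-1)
```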

@xiangning-chen (Contributor)

Hi, thanks for testing the Lion optimizer.
Just wondering, what are the betas for AdamW and Lion? Are they both at their defaults, i.e. (0.9, 0.999) for AdamW and (0.9, 0.99) for Lion?

@angusturner (Author)

No worries! It's a really cool idea, so it would be great if it can consistently improve on Adam, though it is too early to say from my own experiments. For Lion I am using the default (0.9, 0.99). For AdamW I have had successful runs with the defaults (0.9, 0.999) and also with (0.9, 0.99). Still tuning, so I'm not sure what the optimal values are.

@xiangning-chen (Contributor)

Sounds good. Another betas setting, (0.95, 0.98), can help with training stability.
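A small sketch of passing the suggested betas, again assuming the `lion-pytorch` API; everything other than the betas is illustrative:

```python
import torch
from lion_pytorch import Lion

model = torch.nn.Linear(512, 512)  # stand-in model

# Lion defaults to betas=(0.9, 0.99); (0.95, 0.98) is the setting suggested above
# for improved training stability
optimizer = Lion(model.parameters(), lr=1e-4, betas=(0.95, 0.98), weight_decay=1e-2)
```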

@lucidrains (Owner)

@xiangning-chen have you run into this stability issue yourself?

@xiangning-chen (Contributor)

Not on diffusion models; when I encountered instability I just lowered the learning rate. On language modeling, I found that lowering beta2 improves stability for both AdamW and Lion.

@lucidrains (Owner)

@angusturner have you tried the suggestions?

@clementpoiret

Betas of (0.95, 0.98) considerably reduced my instabilities, thanks for the tip @xiangning-chen!
I still get heavy instabilities as soon as I unfreeze some layers during fine-tuning sessions, so I just try to lower the LR.

@lucidrains (Owner)

@clementpoiret nice! 🙏

@xiangning-chen (Contributor)

@clementpoiret Thanks for the update! For fine-tuning, are you referring to using Lion to fine-tune an AdamW-trained model?

@clementpoiret

@xiangning-chen, Yup, it's a pretrained EfficientNet from timm. I replaced the classifier with my own MLP.
For the first quarter of the epochs I train only my MLP, then I gradually unfreeze the last two blocks based on an EarlyStopping strategy. You can see in the plot below that everything explodes; this does not happen with AdamW, or when I only train the MLP without fine-tuning any backbone layers. I also use gradient clipping.
[Screenshot: training curves showing the loss exploding after the backbone blocks are unfrozen]
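A rough sketch of the schedule described above; the backbone variant, head sizes, unfreeze point, and optimizer values are hypothetical stand-ins rather than the actual configuration:

```python
import timm
import torch
from torch import nn
from lion_pytorch import Lion

# pretrained EfficientNet backbone from timm, classifier removed (num_classes=0)
backbone = timm.create_model("efficientnet_b0", pretrained=True, num_classes=0)
head = nn.Sequential(nn.Linear(backbone.num_features, 256), nn.ReLU(), nn.Linear(256, 10))
model = nn.Sequential(backbone, head)

# phase 1: freeze the backbone and train only the MLP head
for p in backbone.parameters():
    p.requires_grad = False
optimizer = Lion((p for p in model.parameters() if p.requires_grad), lr=1e-5, weight_decay=1e-1)

# later (e.g. when early stopping triggers): unfreeze the last two blocks;
# the newly trainable parameters must also be registered with the optimizer,
# otherwise they will never be updated
for p in backbone.blocks[-2:].parameters():
    p.requires_grad = True
optimizer.add_param_group({"params": backbone.blocks[-2:].parameters()})

# gradient clipping is applied per step, after loss.backward() and before optimizer.step()
# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```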

@xiangning-chen (Contributor)

Did you load the AdamW optimizer state from pre-training?

@clementpoiret

Not at all, I just loaded the weights from timm. (Please note that the strange double descent at the start is normal; I also see it with other optimizers.)

@lucidrains (Owner)

@xiangning-chen do you have any experiments showing that loading adam momentum into lion for fine tuning is better? happy to add that feature, provided it isn't just a hunch you have

@lucidrains (Owner)

oh, actually, loading adam optimizer state works fine as is, ok no worries

@xiangning-chen (Contributor)

> do you have any experiments showing that loading adam momentum into lion for fine tuning is better? happy to add that feature, provided it isn't just a hunch you have

Oh, I meant that when Adam is used for both pre-training and fine-tuning, loading the 1st and 2nd moments is helpful. I never tried loading the Adam momentum into Lion, as their EMA parameters are different.
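For reference, a minimal sketch of the checkpoint round trip being discussed, using standard PyTorch state dicts (the model and file name are placeholders). AdamW stores two moment buffers per parameter, while Lion keeps only a single momentum EMA, which is why the two optimizer states are not interchangeable:

```python
import torch
from torch import nn

model = nn.Linear(16, 16)                           # stand-in model
optimizer = torch.optim.AdamW(model.parameters())   # or Lion(...) when resuming a Lion run

# saving: include the optimizer state (moment buffers) alongside the weights
torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict()}, "checkpoint.pt")

# resuming: restore both before continuing training
checkpoint = torch.load("checkpoint.pt")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
```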

@flymark2010 (commented Mar 20, 2023)

> Hi, I have been testing this out on some diffusion models I am training.
>
> Convergence seems decent (somewhat faster than AdamW, using 1/10th the learning rate and 10x the weight decay). However, I recently paused a few experiments and tried to resume, and the loss explodes immediately. I do not face this issue when resuming AdamW runs.
>
> I have also found it necessary to use an LR warm-up period in my runs (even with the 1/10th learning rate), which, again, is not required with AdamW. I'll try to do a bit more digging to see if I can track down the source of the instability. However, for resuming experiments, surely if I load the optimizer state correctly things should resume as expected?
>
> My only thought is whether something could be going wrong with the saving of the EMA / moving-average statistics? If I get a chance to dig into this more I'll let you know what I find. (Possibly I am doing something wrong.)

I hit the same problem: the loss explodes immediately when resuming from a checkpoint with use_triton=True on an NLP task. After setting use_triton=False, the loss behaves normally.
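The workaround described, as a sketch (assuming the `use_triton` flag on the `Lion` constructor; the model and other values are illustrative):

```python
import torch
from lion_pytorch import Lion

model = torch.nn.Linear(512, 512)  # stand-in model

# the fused Triton update path (use_triton=True) is what triggered the explosion on resume;
# falling back to the plain PyTorch update avoided it
optimizer = Lion(model.parameters(), lr=1e-4, weight_decay=1e-2, use_triton=False)
```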

@mitchellnw

I think this is because of #24 (comment)

@ipoletaev (commented May 9, 2023)

To your point @mitchellnw, from the Triton docs:

> When all the configurations are evaluated, the kernel will run multiple times. This means that whatever value the kernel updates will be updated multiple times. To avoid this undesired behavior, you can use the reset_to_zero argument, which resets the value of the provided tensor to zero before running any configuration.

Looks like that is what happens: during autotuning, multiple kernel launches update the model weights several times in a row. Also, shall we close this one as a duplicate of #20?

UPD: removing autotune and setting a fixed BLOCK_SIZE=1024 resolved the issue.
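To illustrate why autotuning interacts badly with an in-place kernel, here is a toy Triton kernel written with a fixed BLOCK_SIZE (a sketch, not the actual Lion update kernel). With a fixed block size the kernel body runs exactly once per call; under @triton.autotune each candidate configuration would re-run it and apply the in-place update repeatedly unless something like reset_to_zero is used:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def inplace_add_one_kernel(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # each program instance handles one contiguous block of the tensor
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(x_ptr + offsets, x + 1.0, mask=mask)  # in-place update of the input


def inplace_add_one(x: torch.Tensor):
    # x is assumed to be a CUDA tensor
    n_elements = x.numel()
    BLOCK_SIZE = 1024  # fixed block size: no autotuning, so the update runs exactly once
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    inplace_add_one_kernel[grid](x, n_elements, BLOCK_SIZE=BLOCK_SIZE)
```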

@lucidrains (Owner)

@ipoletaev ah yes, Mitchell filled me in on this issue through email this morning

do you want to see if 6ab873a resolves it? the more permanent solution would be to submit a PR to Triton to clone all tensors prior to auto-tuning

@ipoletaev

IMO it makes sense to just remove autotune and keep it simple: let the user specify the block size they need.

@lucidrains (Owner)

@ipoletaev yea true, i'll do that if this hack doesn't do the trick

@lucidrains (Owner)

@ipoletaev actually, you are right, this battle isn't worth fighting

2226ec8
