
Support new optimizer Schedule free #1250

Merged
merged 9 commits on May 4, 2024

Conversation

@sdbds (Contributor) commented Apr 9, 2024

From facebookresearch; no need to set any schedule.
https://github.com/facebookresearch/schedule_free

For SGD, a learning rate 10x-50x larger than classical rates seems to be a good starting point.
For AdamW, learning rates in the range 1x-10x larger than with schedule based approaches seem to work.

Do not use lr_scheduler warmup; instead, pass warmup_steps in the optimizer args:

warmup_steps (int): Enables a linear learning rate warmup (default 0).

Other optimizer args for this:

r (float): Use polynomial weighting in the average 
            with power r (default 0).
weight_lr_power (float): During warmup, the weights in the average will
            be equal to lr raised to this power. Set to 0 for no weighting
            (default 2.0)

Dependencies

pip install schedulefree

Requires the latest accelerate:
pip install git+https://github.com/huggingface/accelerate.git@main#egg=accelerate


Also fixes a typo in library/train_util.py line 3090: "grandient" to "gradient".

@kohya-ss (Owner) commented Apr 9, 2024

Thank you for this! I'll merge it into dev or main once accelerate supports this feature in the release version!

@feffy380 (Contributor) commented Apr 9, 2024

This implementation is not correct. optimizer.train() needs to be called somewhere before the training loop, but this PR blindly inserts it after every unet.train(), which is wrong (especially in train_network.py, where we're not even training the unet). It's most obvious in the textual inversion scripts, where optimizer.train() is never called at all.

I think you also need to call optimizer.eval() before saving the model because calling train() and eval() modifies the weights. Based on the example script, we need to save the eval() version of the weights.

And the "support" added by accelerate is just a convenience function that passes the train()/eval() calls to the internal optimizer object. You can call optimizer.optimizer.train() directly instead of waiting for accelerate to update.
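The placement feffy380 describes can be sketched as follows. All names here are illustrative stubs, not the real sd-scripts objects; the stub optimizer only records the call order to show the intended bracketing: train() once before the loop, eval() around every save so the averaged weights are written, then back to train() before resuming.

```python
# Illustrative sketch of correct train()/eval() placement for a
# schedule-free optimizer. StubScheduleFreeOptimizer is a fake that
# records the call order; it is not the schedulefree API.
calls = []

class StubScheduleFreeOptimizer:
    def train(self): calls.append("train")
    def eval(self): calls.append("eval")
    def zero_grad(self): pass
    def step(self): calls.append("step")

def training_loop(optimizer, num_epochs, steps_per_epoch, save_fn):
    optimizer.train()                # once, before training starts
    for _ in range(num_epochs):
        for _ in range(steps_per_epoch):
            optimizer.zero_grad()
            # ... forward pass and loss.backward() would go here ...
            optimizer.step()
        optimizer.eval()             # swap in the averaged weights...
        save_fn()                    # ...so the eval() weights are saved
        optimizer.train()            # swap back before resuming training

training_loop(StubScheduleFreeOptimizer(), num_epochs=1, steps_per_epoch=2,
              save_fn=lambda: calls.append("save"))
# calls is now ["train", "step", "step", "eval", "save", "train"]
```

With accelerate's wrapper, each `optimizer.train()`/`optimizer.eval()` call would be `optimizer.optimizer.train()`/`optimizer.optimizer.eval()` until the release version forwards them.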

@gesen2egee (Contributor) commented Apr 10, 2024

Edit
I see the problem now.
facebookresearch/schedule_free#5

===

Currently, running results in this error (accelerate 0.30.0.dev0, schedulefree 1.2.1):

steps: 0%| | 0/3000 [00:00<?, ?it/s]
epoch 1/9
warnings.warn(
Traceback (most recent call last):
File "D:\SDXL\sd-scripts\train_network.py", line 1128, in
trainer.train(args)
File "D:\SDXL\sd-scripts\train_network.py", line 921, in train
optimizer.step()
File "D:\SDXL\sd-scripts\venv\lib\site-packages\accelerate\optimizer.py", line 148, in step
self.scaler.step(self.optimizer, closure)
File "D:\SDXL\sd-scripts\venv\lib\site-packages\torch\cuda\amp\grad_scaler.py", line 416, in step
retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)
File "D:\SDXL\sd-scripts\venv\lib\site-packages\torch\cuda\amp\grad_scaler.py", line 315, in _maybe_opt_step
retval = optimizer.step(*args, **kwargs)
File "D:\SDXL\sd-scripts\venv\lib\site-packages\accelerate\optimizer.py", line 203, in patched_step
return method(*args, **kwargs)
File "D:\SDXL\sd-scripts\venv\lib\site-packages\torch\optim\lr_scheduler.py", line 68, in wrapper
return wrapped(*args, **kwargs)
File "D:\SDXL\sd-scripts\venv\lib\site-packages\torch\optim\optimizer.py", line 373, in wrapper
out = func(*args, **kwargs)
File "D:\SDXL\sd-scripts\venv\lib\site-packages\schedulefree\adamw_schedulefree.py", line 122, in step
ckp1 = weight/weight_sum
ZeroDivisionError: float division by zero
steps: 0%| | 0/3000 [00:36<?, ?it/s]

@sdbds (Contributor, Author) commented Apr 10, 2024

> This implementation is not correct. optimizer.train() needs to be called somewhere before the training loop ... It's most obvious in the textual inversion scripts where optimizer.train() is never called at all.
>
> I think you also need to call optimizer.eval() before saving the model ... we need to save the eval() version of the weights.
>
> And the "support" added by accelerate is just a convenience function ... You can call optimizer.optimizer.train() directly instead of waiting for accelerate to update.

Thank you for reviewing!
I did make some mistakes, but let me explain my thinking.

1. I changed sdxl_train.py first; it uses training_models to hold the unet and, if selected, the text encoders. The call I inserted runs at the start of every epoch, before the training loop, so that part is fine: optimizer.train() just switches the optimizer's state between train and eval (train mode is the default). (OK, for greater clarity, I will move it before the training_models loop so it is called just once.)

2. I actually forgot to call eval() before saving models. I will add it before every model save and every image sampling.

3. I missed textual inversion... I'll add it soon.

4. I know about the accelerate convenience function because that update was my request; it keeps the code consistent.

@sdbds (Contributor, Author) commented Apr 10, 2024

> Edit: I see the problem now. facebookresearch/schedule_free#5
>
> Currently running will result in such an error. accelerate 0.30.0.dev0, schedulefree 1.2.1 ... ZeroDivisionError: float division by zero

Setting lr_scheduler=null and not using lr_warmup will fix this.
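The cause of the ZeroDivisionError can be illustrated with a simplified reconstruction of the averaging-weight update in schedulefree's adamw_schedulefree.step (this is not the library's exact code; the default values for `r` and `weight_lr_power` match the docstrings quoted earlier). When an external warmup scheduler starts at lr = 0, the averaging weight is 0 on the first step, so weight_sum stays 0 and `ckp1 = weight / weight_sum` divides by zero.

```python
# Simplified reconstruction (not the library's exact code) of the
# averaging-weight update in adamw_schedulefree.step, showing why an
# external warmup scheduler that starts at lr = 0 divides by zero.
def averaging_coeff(lr, step, weight_sum, r=0.0, weight_lr_power=2.0):
    weight = (step ** r) * (lr ** weight_lr_power)
    weight_sum += weight
    return weight / weight_sum, weight_sum   # ckp1 = weight / weight_sum

# The optimizer's own warmup_steps keeps lr nonzero from the first step:
ckp1, ws = averaging_coeff(lr=1e-4, step=1, weight_sum=0.0)  # ckp1 == 1.0

# An external scheduler warming up from lr = 0 leaves weight_sum at 0:
try:
    averaging_coeff(lr=0.0, step=1, weight_sum=0.0)
    raised = False
except ZeroDivisionError:   # the error in the traceback above
    raised = True
```

This is why letting the optimizer handle warmup itself (via its `warmup_steps` argument) avoids the crash.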

@sdbds (Contributor, Author) commented Apr 10, 2024


I added eval() inside the training loop because, with args.save_every_n_steps, a save can happen at any step; I call it before saving and after training.

At the same time, I had to move train() into the training loop, before training starts, because an eval() call inside the loop may otherwise leave the train() state unrestored.

Added the TI part.

@sdbds (Contributor, Author) commented Apr 12, 2024

> Thank you for this! I'll merge it into dev or main once accelerate supports this feature in the release version!

@kohya-ss

accelerate has released its newest version with schedule-free support, and this PR has now been checked and some mistakes fixed.

https://github.com/huggingface/accelerate/releases/tag/v0.29.2

@kohya-ss (Owner) commented

Thank you for the update! I will review and merge as soon as I have time. I will also consider whether the code can be made common.

@gesen2egee (Contributor) commented

args.optimizer_type.lower().endswith("scheduleFree"):
I'm not sure if this is a problem.

@sdbds (Contributor, Author) commented Apr 18, 2024

> args.optimizer_type.lower().endswith("scheduleFree"): I'm not sure if this is a problem.

It doesn't really matter, but I modified it for aesthetic purposes.

@feffy380 (Contributor) commented Apr 18, 2024

It does matter, because optimizer_type.lower() is all lowercase and will never match "scheduleFree".
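feffy380's point can be verified in two lines (the example optimizer name is illustrative; the suffix strings are the ones discussed above):

```python
# A lowercased string can never end with a mixed-case suffix,
# so the "scheduleFree" branch is dead code.
optimizer_type = "AdamWScheduleFree"   # example value

assert not optimizer_type.lower().endswith("scheduleFree")  # never matches
assert optimizer_type.lower().endswith("schedulefree")      # correct check
```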

@rockerBOO (Contributor) commented

When it goes to do sample_images, should it also set the optimizer to eval? I see big spikes in the loss after generating samples.

@sdbds (Contributor, Author) commented Apr 30, 2024

> When it goes to do sample_images it should also set the optimizer to eval as well? I see big spikes in the loss after generating samples.

Actually, the eval() switch happens before saving, and the sampler runs after saving.
And train() is called before the start of each training step.
I guess the loss changes because of the switch from eval() (during sampling) back to train().

@rockerBOO (Contributor) commented May 1, 2024

> Actually eval() switching is before saving, sampler is processed after saving. And train() is before the start of each training step. I guess it change because from eval() sampler to train().

Ahh my bad, I see that now. Must be something odd with my logging.

@kohya-ss kohya-ss changed the base branch from main to scheduler-free-opt May 4, 2024 09:55
@kohya-ss kohya-ss merged commit c687126 into kohya-ss:scheduler-free-opt May 4, 2024
1 check passed
@kohya-ss (Owner) commented May 4, 2024

I merged this into scheduler-free-opt branch. Sorry for the delay.

I also made some modifications to simplify the code. I'll do some more testing and plan to merge it into the dev branch.
