
[Experimental Feature] FP8 weight dtype for base model when running train_network (or sdxl_train_network) #1057

Merged: 31 commits into kohya-ss:dev on Jan 20, 2024

Conversation

KohakuBlueleaf
Contributor

Based on the sd-webui PR on utilizing FP8, we can assume that FP8 can also be applied to the base model in train_network, since we never update its weights and only need it for the forward computation.

So I implemented a first version of FP8 support in your framework, and it works well!

I actually uploaded an experimental model for FP8 training quite early; it consumes only about 6.x GB of VRAM when training SDXL with LyCORIS/LoRA.
If we also cache the latents and the text encoder outputs, we can train everything with as little as 4.4 GB of VRAM, which is incredible.
(All of the above experiments were done in a 1024x1024, batch size 1 setup.)

I think this is good for the SDXL community.

BTW, my implementation relies on autocast right now, which may be good news for old GPU users or IPEX users. (Although I think IPEX actually has autocast support, it is just slower than a manual cast.)

If you think it is a good idea, I can also try to make a PR for manual casting. I have already verified that it can be used for training, but it may need some modification.
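
To make the idea concrete, here is a minimal sketch (not the code in this PR; `load_sdxl_unet`, `latents`, `timesteps`, and `text_embeds` are placeholders, and it assumes a PyTorch build with the float8 dtypes):

```python
import torch

# Frozen base model: store its weights in FP8; they are never updated.
unet = load_sdxl_unet()                      # placeholder loader for any frozen nn.Module
unet.requires_grad_(False)
unet.to(device="cuda", dtype=torch.float8_e4m3fn)

# Autocast upcasts the FP8 weights to the compute dtype for matmul/conv ops,
# which is what this approach relies on.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    noise_pred = unet(latents, timesteps, text_embeds)
```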

@FurkanGozukara

Amazing work. Have you compared results with full BF16 training?

@KohakuBlueleaf
Contributor Author

Amazing work. Have you compared results with full BF16 training?

I did some comparisons a few months ago.

They do have some subtle differences, but it is hard to say whether those are quality differences or performance differences.

They are just... different.

@FurkanGozukara

Thanks for the reply. So it will train in FP8 and then save as FP16 as usual?

@KohakuBlueleaf
Contributor Author

Thanks for the reply. So it will train in FP8 and then save as FP16 as usual?

The trainable part will not be converted to FP8.

@FurkanGozukara

Thanks for the reply. So it will train in FP8 and then save as FP16 as usual?

The trainable part will not be converted to FP8.

Can you elaborate more? For example, when training SDXL with DreamBooth we train both the UNet and the text encoder, all parts I think? Or am I missing something? Thank you.

@kohya-ss
Owner

Thank you for this PR! The changes are fewer than I expected. I will check as soon as possible.

@KohakuBlueleaf
Contributor Author

Thank you for this PR! The changes are fewer than I expected. I will check as soon as possible.

I tested it a lot, and basically what we did in the past is the same idea as FP8:
in the past we had an fp16 base + fp32 network; now we just change to an fp8 base.

The caveat is also similar: you need autocast.

So any part of the computation that doesn't run under autocast may have problems, but they can be solved easily.
(For example caching the TE outputs, though I think that should be fine. If you find a problem with it, maybe we can consider letting the user enable autocast for the TE caching procedure?)
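
As a rough sketch of that TE caching point (a hypothetical helper, not the sd-scripts code): if the text encoder is also stored in FP8, the caching pass can simply be wrapped in autocast so the cached outputs come out in a normal compute dtype.

```python
import torch

@torch.no_grad()
def cache_text_encoder_outputs(text_encoder, token_batches, device="cuda"):
    # text_encoder may hold FP8 weights; autocast upcasts them during the forward pass.
    cached = []
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        for tokens in token_batches:
            hidden = text_encoder(tokens.to(device))[0]   # e.g. last_hidden_state
            cached.append(hidden.float().cpu())           # keep the cache in fp32 on CPU
    return cached
```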

@KohakuBlueleaf
Contributor Author

Thanks for the reply. So it will train in FP8 and then save as FP16 as usual?

The trainable part will not be converted to FP8.

Can you elaborate more? For example, when training SDXL with DreamBooth we train both the UNet and the text encoder, all parts I think? Or am I missing something? Thank you.

This PR is for LoRA/LyCORIS/hypernetwork (losalina) training,
which means the base model (UNet/TE) is frozen. We only train the additional network, and that trainable part should stay in higher precision like fp16/bf16/fp32. But since these trainable parts are small compared to the original UNet/TE, it doesn't matter.
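
As a toy illustration of that split (stand-in modules, not the real training loop): the frozen base is only stored in FP8, while the small trainable adapter stays in full precision and is the only thing handed to the optimizer.

```python
import torch
import torch.nn as nn

# Stand-ins: "base" plays the role of the frozen UNet/TE, "adapter" the LoRA/LyCORIS network.
base = nn.Linear(16, 16)
adapter = nn.Linear(16, 16)

base.requires_grad_(False)
base.to(torch.float8_e4m3fn)      # FP8 storage only; no gradients or updates for the base

adapter.requires_grad_(True)      # trainable part stays in fp32 (or fp16/bf16)
optimizer = torch.optim.AdamW(adapter.parameters(), lr=1e-4)
```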

@laksjdjf
Contributor

Hi, I have a question about this PR.
Is float8_e4m3fn better than float8_e5m2?

@KohakuBlueleaf
Contributor Author

Hi, I have a question about this PR. Is float8_e4m3fn better than float8_e5m2?

Yes.
Some papers even claim e3m4 or e2m5 are better.

I chose e4m3 based on my experiments.

If we had a better scaling method, maybe we could consider e5m2, but since we don't use FP8 for the computation here, I think the better precision is more important.
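
For reference, the numeric trade-off between the two formats can be inspected directly (assumes a PyTorch build with the float8 dtypes): e5m2 has more range, e4m3 has finer precision, and precision is what matters when the weights are only stored in FP8 rather than multiplied in FP8.

```python
import torch

for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    info = torch.finfo(dtype)
    # float8_e4m3fn: max = 448, eps = 0.125; float8_e5m2: max = 57344, eps = 0.25
    print(dtype, "max:", info.max, "eps:", info.eps)
```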

@laksjdjf
Contributor

Thanks!

@sdbds
Contributor

sdbds commented Jan 18, 2024

good job

kohya-ss merged commit 9cfa68c into kohya-ss:dev on Jan 20, 2024
1 check passed
@kohya-ss
Owner

Thank you again for the great work!

kohya-ss added a commit that referenced this pull request Jan 20, 2024
Disty0 pushed a commit to Disty0/sd-scripts that referenced this pull request Jan 28, 2024
[Experimental Feature] FP8 weight dtype for base model when running train_network (or sdxl_train_network) (kohya-ss#1057)

* Add fp8 support

* remove some debug prints

* Better implementation for te

* Fix some misunderstanding

* as same as unet, add explicit convert

* better impl for convert TE to fp8

* fp8 for not only unet

* Better cache TE and TE lr

* match arg name

* Fix with list

* Add timeout settings

* Fix arg style

* Add custom seperator

* Fix typo

* Fix typo again

* Fix dtype error

* Fix gradient problem

* Fix req grad

* fix merge

* Fix merge

* Resolve merge

* arrangement and document

* Resolve merge error

* Add assert for mixed precision
Disty0 pushed a commit to Disty0/sd-scripts that referenced this pull request Jan 28, 2024
wkpark pushed a commit to wkpark/sd-scripts that referenced this pull request Feb 27, 2024