
SDXL Training #1

Closed
scarbain opened this issue Dec 7, 2023 · 4 comments

Comments

scarbain commented Dec 7, 2023

Well, congrats! I'm really impressed that the generations are so well grounded without any extra network or anything!
And you only finetuned SD1.4 for 24K steps? That must have been a pretty fast training, right?

Have you considered training SDXL? I didn't see anything about it in your paper.
Do you have any recommendations on the parameters to use for training SDXL?

zwcolin (Collaborator) commented Dec 7, 2023

Hi Sébastien, finetuning SD1.4 is pretty fast and can be done on a single RTX 3090 (with a batch size of 1, gradient accumulation of 4, and gradient checkpointing).
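
Roughly, that setup looks like the following in a plain PyTorch/diffusers loop (a minimal sketch, not our actual training script; `dataloader`, `compute_loss`, and the learning rate are placeholders):

```python
# Minimal sketch (not the actual training script): SD 1.4 finetuning on a
# 24 GB GPU using gradient checkpointing + gradient accumulation.
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet"
).cuda()
unet.enable_gradient_checkpointing()   # recompute activations in backward to save VRAM
unet.train()

optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-5)  # illustrative LR
grad_accum_steps = 4                   # effective batch size = 1 * 4

for step, batch in enumerate(dataloader):    # placeholder: yields batches of size 1
    loss = compute_loss(unet, batch)         # placeholder: denoising + grounding objectives
    (loss / grad_accum_steps).backward()     # accumulate scaled gradients
    if (step + 1) % grad_accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```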

We also provide SD 2.1 checkpoints finetuned at 768×768 resolution on A6000 GPUs (this requires more than 24 GB of VRAM).

Unfortunately, we don't have checkpoints trained on SDXL because the model is too large and would likely require GPUs such as the A100. Our training pipeline uses more VRAM than a typical training pipeline because we apply additional objectives to the cross-attention maps.
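
As a toy illustration of where the extra memory goes (this is not our actual code): the cross-attention probabilities have to stay in the autograd graph so the grounding objective can backpropagate through them, e.g.:

```python
# Toy example (not the actual pipeline): a grounding-style loss applied to
# cross-attention probabilities forces them to be retained for backprop,
# instead of being freed right after the attention op.
import torch

q = torch.randn(8, 4096, 64, requires_grad=True)   # latent-patch queries
k = torch.randn(8, 77, 64, requires_grad=True)     # text-token keys
attn = torch.softmax(q @ k.transpose(1, 2) / 64 ** 0.5, dim=-1)  # (8, 4096, 77)

# Hypothetical grounding term: push the attention mass of token index 5
# into a dummy segmentation mask.
mask = torch.zeros(8, 4096)
mask[:, :1024] = 1.0                                # pretend the object covers these patches
token_attn = attn[..., 5]                           # (8, 4096)
token_loss = 1.0 - (token_attn * mask).sum() / token_attn.sum()
token_loss.backward()                               # `attn` must be kept in memory for this
```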

If you want to train SDXL, my guess would be:
(1) increase the number of optimization steps (we observed that we needed more steps for SD 2.1 than for SD 1.4);
(2) change the grounding loss ratios, i.e., $\lambda$ and $\gamma$. Essentially, you want the token loss to decrease as much as possible without compromising the denoising objective. At the same time, you want the pixel loss to stay constant or decrease slowly as the model is finetuned with the token loss, in order to preserve good image quality and grounding capabilities. You can experiment with different hyperparameters by watching the training loss curves. Empirically, we use the same set of $\lambda$ and $\gamma$ for training both SD 1.4 and 2.1, which you can also use as a starting point if you want to experiment with SDXL!
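
Roughly, the way I'd write the overall objective (my shorthand here; please check the paper for the exact formulation) is

$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{denoise}} + \lambda\,\mathcal{L}_{\text{token}} + \gamma\,\mathcal{L}_{\text{pixel}},$$

so the tuning exercise is to sweep $\lambda$ and $\gamma$ while watching each term's curve separately.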

scarbain commented Dec 7, 2023

Thanks for your insights! I'm currently trying to reproduce your training on a finetune of SD1.5, but I'm getting errors about missing text keys in the data. I'll keep digging!

As for SDXL, I'll run some tests, but if you're willing to work with me on this, I can probably provide cloud GPUs (depending on the cost, of course). We could then open-source the result; I see a lot of value in this!

scarbain commented Dec 7, 2023

Oh, I just checked your license and it's non-commercial.

zwcolin commented Dec 13, 2023

Hi Sébastien, our research lab currently adopts a non-commercial license for research projects, so it might not be suitable to use our codebase directly in commercial products. I'm closing this issue for now. If you have any further questions, feel free to start a new one or reopen this issue. Thank you!

zwcolin closed this as completed Dec 13, 2023