
[Examples] Add support for Min-SNR weighting strategy for better convergence #2899

Merged: 27 commits merged into main on Apr 6, 2023

Conversation

@sayakpaul (Member) commented on Mar 30, 2023

[1] introduces the Min-SNR weighting strategy to rebalance the loss when training diffusion models for faster convergence. The authors attribute the slow convergence of diffusion models to conflicting optimization directions across the different timesteps of the noise schedule. So, they introduce a simple way to rebalance the losses of the individual samples.

This PR follows [1] and [2] to incorporate the Min-SNR weighting strategy into the text-to-image fine-tuning script. I believe this rebalancing can be incorporated into other examples too, wherever we fine-tune the diffusion model (i.e., the UNet).

My experimentation results are available here: https://wandb.ai/sayakpaul/text2image-finetune-minsnr

Overall, this strategy yields a noticeably smoother training-loss curve: https://wandb.ai/sayakpaul/text2image-finetune-minsnr/reports/train_loss-23-04-04-08-49-34---VmlldzozOTY3ODQ2

The number of training samples can certainly make the effect of this strategy less pronounced. But overall, I think it's nice to have, as it directly targets the training instabilities of diffusion models.

@patil-suraj if we're okay with this change, I will add a section on it in the README and also in https://huggingface.co/docs/diffusers/training/text2image. Let me know.

References

[1] Paper reference: https://arxiv.org/abs/2303.09556
[2] Code: https://github.com/TiankaiHang/Min-SNR-Diffusion-Training

@HuggingFaceDocBuilderDev commented on Mar 30, 2023

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten (Contributor) left a comment:

@patil-suraj can you take a look here?

@sayakpaul (Member Author) replied:

> @patil-suraj can you take a look here?

Actually, I would suggest not to. I am still gathering evidence to see if the PR is worth merging or even reviewing. After I am done, I will update here. Till then, don't worry about it.

@sayakpaul sayakpaul marked this pull request as ready for review April 4, 2023 03:14
return fn


def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
Contributor:

is this related to this PR title or does it just add logging?

Member Author:

The PR could have done without this utility, but it was important to include here because otherwise it was difficult to validate the effectiveness of the method.

FWIW, I am not a fan of adding unrelated changes to a PR, but this one seemed important.

Contributor:

Ok for me!

@patrickvonplaten (Contributor) commented:

It would be great to stick to the more imperative coding style that we prefer in libraries like diffusers and transformers (https://stackoverflow.com/questions/21895525/python-programming-functional-vs-imperative-code).

I don't think it's very easy to follow when a function returns another function, etc.; this is a bit too JAX/TF-like for me.

@sayakpaul (Member Author) replied:

@patrickvonplaten the latest changes should have addressed all your concerns. PTAL.

@patil-suraj (Contributor) left a comment:

Thanks a lot for adding support for this! The loss curve does look smoother than without Min-SNR weighting. I just left a small comment about v-prediction.

It would also be cool to add this to train_unconditional.py; it would be easy to verify it there since we train from scratch.

examples/text_to_image/train_text_to_image.py (outdated; resolved)
Comment on lines +838 to +846
mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
Contributor:

Nice!

Comment on lines +834 to +835
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
Contributor:

some models (sd2.1 and above) use v-prediction, does this formulation also work with v-prediction?


@sayakpaul (Member Author) replied:

> It would also be cool to add this to train_unconditional.py; it would be easy to verify it there since we train from scratch.

Will run an experiment and add a PR for that.

Thanks for approving. Now, I will:

add a section on it in the README and also in https://huggingface.co/docs/diffusers/training/text2image.

@sayakpaul (Member Author) commented:

@patil-suraj, when you get a moment, could you review the changes introduced in 96e7254? All of them are related to documentation.

I think then we can merge. @patrickvonplaten maybe you also want to take a look.

}


def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
Member Author:

Added this method as part of the PR as well. It handles EMA offload and unload properly to ensure that inference is done with the EMA'd checkpoints.

@patil-suraj (Contributor) left a comment:

Thanks a lot for adding the doc. Looks good!

docs/source/en/training/text2image.mdx (outdated; resolved)

We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence
by rebalancing the loss. In order to use it, one needs to set the `--snr_gamma` argument. The recommended
value when using it is 5.0.
Contributor:

Out of curiosity, where is this value proposed? Is there a rule of thumb when choosing a value for this?

Member Author:

It's reported in the paper. A gamma of 5.0 consistently led to better results in the experiments presented by the authors.


<Tip warning={true}>

Training with Min-SNR weighting strategy is only supported in PyTorch.
Contributor:

For a future PR: it could be cool to add this in JAX as well; it will be useful for the JAX event.

Member Author:

@yiyixuxu could you take a look?

@patrickvonplaten (Contributor) left a comment:

Very cool! Thanks for iterating

@sayakpaul sayakpaul merged commit 2494731 into main Apr 6, 2023
@sayakpaul sayakpaul deleted the feat/better-convergence branch April 6, 2023 13:38
w4ffl35 pushed a commit to w4ffl35/diffusers that referenced this pull request Apr 14, 2023
[Examples] Add support for Min-SNR weighting strategy for better convergence (huggingface#2899)

* improve stable unclip doc.

* feat: support for applying min-snr weighting for faster convergence.

* add: support for validation logging with wandb

* make  not a required arg.

* fix: arg name.

* fix: cli args.

* fix: tracker config.

* fix: loss calculation.

* fix: validation logging.

* fix: unwrap call.

* fix: validation logging.

* fix: internval.

* fix: checkpointing push to hub.

* fix: https://github.com/huggingface/diffusers/commit/c8a2856c6d5e45577bf4c24dee06b1a4a2f5c050#commitcomment-106913193

* fix: norm group test for UNet3D.

* address PR comments.

* remove unneeded code.

* add: entry in the readme and docs.

* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>

---------

Co-authored-by: Suraj Patil <surajp815@gmail.com>
dg845 pushed a commit to dg845/diffusers that referenced this pull request May 6, 2023
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024