Resolve v_prediction issue for min-SNR gamma weighted loss function #5096
sayakpaul
left a comment
Good to merge when the comments are addressed.
Thanks for spotting and fixing.
…ut the application of the epsilon code to sample prediction
@sayakpaul thanks. done
sayakpaul
left a comment
Much better readability now. Thanks much!
Let's fix the code quality issues here.
@sayakpaul ah, i'm used to pre-commit hooks doing this for me :D done
Failing tests are flaky!
    if noise_scheduler.config.prediction_type == "v_prediction":
        # Velocity objective needs to be floored to an SNR weight of one.
        mse_loss_weights = base_weight + 1
Why don't we do:
mse_loss_weights = max(base_weight, 1) if we want to "floor" the SNR weight?
max would be ceil, here we're just adding one to ensure a value of zero is a minimum of 1. we don't want to blindly overwrite minimal values with 1's, we want to scale all of them.
mse_loss_weights = max(base_weight, 1) works more like a floor in that case, i.e. mse_loss_weights is always equal to at least 1.
Can you explain in a bit more detail why we want to add 1 here all the time?
certainly. when using v-prediction, there is an infinite SNR at the final timestep. if you do not floor the weight of this to 1, it will lead to a NaN loss value and the model fails to produce useable predictions.
this was observed in the original min-SNR paper, and has since been implemented in the reference code upstream.
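To make the two options in this thread concrete, here is a minimal scalar sketch in plain Python (not the training script's tensor code). It assumes `base_weight` is the usual min-SNR-gamma term `min(snr, gamma) / snr`; the helper names `min_snr_base_weight`, `add_one`, and `clamp_to_one` are hypothetical and only used for illustration. It takes the infinite-SNR case described above at face value. Because `min(snr, gamma) / snr` can never exceed 1, `max(base_weight, 1)` would flatten every weight to exactly 1, while `base_weight + 1` keeps the relative weighting and still guarantees a minimum of 1.

```python
import math


def min_snr_base_weight(snr: float, gamma: float = 5.0) -> float:
    """Assumed base weight: min(SNR, gamma) / SNR (always <= 1 for SNR > 0)."""
    return min(snr, gamma) / snr


def add_one(weight: float) -> float:
    """The approach in this PR: shift every weight up by one."""
    return weight + 1.0


def clamp_to_one(weight: float) -> float:
    """The alternative raised in review: max(base_weight, 1)."""
    return max(weight, 1.0)


if __name__ == "__main__":
    # Illustrative SNR values only; math.inf stands in for the
    # infinite-SNR timestep mentioned in the discussion.
    for snr in [0.5, 5.0, 50.0, math.inf]:
        w = min_snr_base_weight(snr)
        print(f"snr={snr}: base={w}, base+1={add_one(w)}, max(base,1)={clamp_to_one(w)}")
```

Under these assumptions the clamped column is 1 for every SNR, which is the "blindly overwrite with 1's" behaviour the author wanted to avoid.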
…uggingface#5096) * Resolve v_prediction issue for min-SNR gamma weighted loss function * Combine MSE loss calculation of epsilon and velocity, with a note about the application of the epsilon code to sample prediction * style --------- Co-authored-by: bghira <bghira@users.github.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

What does this PR do?
Flooring the SNR weights to 1 results in a solution for the Inf position that arises with a zero-terminal SNR model.
As such, we can safely remove this check, and the parameter can even be used by default with a value of 5.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.