Isn't loss only supposed to be calculated on masked tokens? #14
I've tried both strategies for a simple MaskGIT on CIFAR10, but the generation quality is still bad. I suspect the authors' training scheme relies on tricks that the paper doesn't mention.
I have the same issue. Why is the loss calculated on all tokens?
@Lamikins I believe the training issues come from an error in the masking formula. I've amended it in #16.
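For reference, here is a minimal sketch of training-time masking as the MaskGIT paper describes it. It samples r uniformly from [0, 1), masks ⌈γ(r)·N⌉ of the N tokens using the cosine schedule γ(r) = cos(πr/2), and replaces only those positions with a special [MASK] id. This is an illustration, not the change made in #16. `MASK_TOKEN_ID`, `cosine_mask_ratio`, and `mask_tokens` are hypothetical names, and plain Python lists stand in for tensors.

```python
import math
import random

MASK_TOKEN_ID = 1024  # hypothetical: id of the [MASK] token, one past a 1024-entry codebook


def cosine_mask_ratio(r: float) -> float:
    """Cosine schedule gamma(r) = cos(pi * r / 2): maps r in [0, 1) to a mask ratio in (0, 1]."""
    return math.cos(math.pi * r / 2)


def mask_tokens(tokens, rng=random):
    """Return (masked_input, mask); mask[i] is True where tokens[i] was replaced by [MASK]."""
    n = len(tokens)
    r = rng.random()  # r ~ U[0, 1)
    # Number of positions to mask; always at least one so there is something to predict.
    num_masked = max(1, math.ceil(cosine_mask_ratio(r) * n))
    masked_positions = set(rng.sample(range(n), num_masked))
    mask = [i in masked_positions for i in range(n)]
    # Only masked positions get the [MASK] id; unmasked positions keep their original token.
    masked_input = [MASK_TOKEN_ID if m else t for t, m in zip(tokens, mask)]
    return masked_input, mask
```

The returned `mask` marks exactly the positions the model has to predict, so it is also what the loss should be restricted to.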
@EmaadKhwaja
@xuesongnie It's because the computed mask is applied to the wrong values. The other option would be to do
Hi, bro. I'm seeing poor performance after modifying
In the training loop we have:
However, the output of the transformer is:
which means the returned target is the original, unmasked sequence of image tokens.
The MaskGIT paper seems to suggest that the loss is calculated only on the masked tokens.
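A minimal sketch of what the paper's description would look like: cross-entropy averaged only over the masked positions, so unmasked tokens (whose true value the model already sees in its input) don't contribute. It uses plain Python lists so it runs without dependencies. `log_softmax` and `masked_cross_entropy` are hypothetical helpers, not code from this repository.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]


def masked_cross_entropy(logits, targets, mask):
    """Mean cross-entropy over positions where mask is True.

    logits:  per-position score lists, shape [seq_len][vocab_size]
    targets: ground-truth token ids, shape [seq_len]
    mask:    bools, True where the input token was replaced by [MASK]
    """
    losses = [
        -log_softmax(row)[t]
        for row, t, m in zip(logits, targets, mask)
        if m  # skip unmasked positions: the model already sees their true token
    ]
    if not losses:
        return 0.0  # no masked positions, nothing to predict
    return sum(losses) / len(losses)
```

In PyTorch the same restriction is commonly written with boolean indexing, e.g. `F.cross_entropy(logits[mask], target[mask])` for `logits` of shape `[B, N, V]` and a boolean `mask` of shape `[B, N]`. Averaging over all tokens instead lets the easy, unmasked positions pull the loss down without teaching the model anything about the masked ones.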