L2 attention is implemented wrong! #14
Comments
@PeterL1n thanks Peter! will get this all resolved this weekend. did they end up using tied qk for their final model?
@PeterL1n this is news to me that they are using the square of the euclidean distance; i will reread the original paper, thank you!
@PeterL1n if the token attends to itself, wouldn't it always have a distance of 0 and attend to itself the most? maybe it works out for their Lipschitz proof, but how does this make sense in the tied scenario?
The paper only proved that L2 attention with tied qk is Lipschitz for self-attention. It must be tied to be Lipschitz! It is also not Lipschitz for cross-attention, which is why GigaGAN's discriminator uses only self-attention. However, they used both self- and cross-attention in the generator; knowing the generator can't be Lipschitz, there is no point in using L2 attention there, so I believe they used regular dot-product attention for the generator. You are correct that with tied qk a token's distance to itself is always zero, and therefore it is always the most similar, so the token's own value is always included in the attention. Other positions can still have small L2 distances and take probability mass away from the self token. This is my understanding.
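A minimal sketch of the tied-qk point above (plain PyTorch, with the paper's temperature/scaling omitted; not code from this repo or thread): with k = q, each token's squared distance to itself is exactly zero, so its own logit is the row maximum and softmax always puts the largest single weight on the token itself.

```python
import torch

torch.manual_seed(0)
q = torch.randn(1, 5, 8)                    # (batch, tokens, dim); tied qk means k = q

dist_sq = torch.cdist(q, q, p=2).square()   # (1, 5, 5) pairwise squared L2 distances
logits = -dist_sq                           # L2 attention logits (scaling omitted)
attn = logits.softmax(dim=-1)

print(torch.diagonal(dist_sq, dim1=-2, dim2=-1))  # diagonal is 0: self-distance is zero
print(attn.argmax(dim=-1))                        # each token attends to itself most
```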
Let's roll with that! Thank you Peter for the review 🙏
it would be euclidean distance squared, so it would have to be quite close. that is strange. just thinking out loud
@PeterL1n do you want to see if 0.0.18 unblocks you for your research / startup?
@PeterL1n i will get back to wiring up the training code later this month
@PeterL1n reviewed the old deepmind paper and indeed it is the squared distance! thanks for catching this and correcting my misunderstanding
closing as it should be resolved, feel free to reopen if you note any further issues |
From the paper: https://arxiv.org/pdf/2006.04710.pdf

First, a token needs to attend to itself to ensure Lipschitz!

Second, `torch.cdist` is not the correct way to do it. Follow the original paper, which gives the formulation both for tied qk and for separate qk (see the sketch after this post). It is basically `torch.cdist().square()`, but more efficient, and it supports double backward for R1 regularization.

Last, I believe the paper only used L2 self-attention in the discriminator. The generator should still use dot-product attention.
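A sketch of the computation I understand the post to be suggesting (my reconstruction, not the original post's code; the function name and shapes are illustrative, and the paper's temperature is omitted): compute the logits as negative squared Euclidean distances via the expansion ||q_i - k_j||^2 = ||q_i||^2 - 2 q_i·k_j + ||k_j||^2, which matches -torch.cdist(q, k).square() but is cheaper and plays well with the double backward needed for the R1 penalty.

```python
import torch
from torch import einsum

def l2_attn_logits(q, k=None):
    """Negative squared-L2 attention logits; pass only q for tied qk (k = q)."""
    if k is None:                     # tied qk, as required for the Lipschitz result
        k = q
    q_sq = q.square().sum(dim=-1)                    # (b, i)  ||q_i||^2
    k_sq = k.square().sum(dim=-1)                    # (b, j)  ||k_j||^2
    dots = einsum('b i d, b j d -> b i j', q, k)     # (b, i, j)  q_i . k_j
    return 2 * dots - q_sq.unsqueeze(-1) - k_sq.unsqueeze(-2)   # = -||q_i - k_j||^2

# sanity check against the cdist formulation mentioned above
q = torch.randn(2, 6, 16)
assert torch.allclose(l2_attn_logits(q), -torch.cdist(q, q, p=2).square(), atol=1e-4)
```

The softmax over these logits then gives the attention weights as usual.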