
L2 attention is implemented wrong! #14

Closed
PeterL1n opened this issue Jun 12, 2023 · 10 comments

Comments

@PeterL1n

PeterL1n commented Jun 12, 2023

From the paper (The Lipschitz Constant of Self-Attention):
https://arxiv.org/pdf/2006.04710.pdf

First, each token needs to attend to itself to ensure the Lipschitz property!

[screenshot from the paper]

Second, torch.cdist is not the correct way to do it. Follow the original paper.

[screenshot of the L2 attention formulation from the paper]

for tied qk

AB = torch.matmul(qk, qk.transpose(-1, -2))   # pairwise dot products <q_i, k_j>
AA = torch.sum(qk ** 2, -1, keepdim=True)     # squared norms ||q_i||^2
BB = AA.transpose(-1, -2)                     # ||k_j||^2; same as AA since query and key are tied
attn = -(AA - 2 * AB + BB)                    # negative squared L2 distance -||q_i - k_j||^2
attn = attn.mul(self.scale).softmax(-1)

for separate qk

AB = torch.matmul(q, k.transpose(-1, -2))                     # pairwise dot products <q_i, k_j>
AA = torch.sum(q ** 2, -1, keepdim=True)                      # squared query norms ||q_i||^2
BB = torch.sum(k ** 2, -1, keepdim=True).transpose(-1, -2)    # squared key norms ||k_j||^2
attn = -(AA - 2 * AB + BB)                                    # negative squared L2 distance -||q_i - k_j||^2
attn = attn.mul(self.scale).softmax(-1)

The AA - 2 * AB + BB term is basically torch.cdist().square(), but more efficient, and it supports double backward for R1 regularization.
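A quick sanity check of the expansion against torch.cdist, as a minimal sketch with made-up shapes and a made-up scale value (the q, k, and scale names just mirror the snippet above):

import torch

# made-up shapes for the check: (batch, heads, tokens, head_dim)
q = torch.randn(2, 4, 16, 8)
k = torch.randn(2, 4, 16, 8)
scale = 8 ** -0.5   # assumed scaling; use whatever the module's self.scale actually is

# expanded form: ||q_i - k_j||^2 = ||q_i||^2 - 2 <q_i, k_j> + ||k_j||^2
AB = torch.matmul(q, k.transpose(-1, -2))
AA = torch.sum(q ** 2, -1, keepdim=True)
BB = torch.sum(k ** 2, -1, keepdim=True).transpose(-1, -2)
neg_sq_dist = -(AA - 2 * AB + BB)

# reference: torch.cdist returns the (unsquared) pairwise L2 distances
assert torch.allclose(neg_sq_dist, -torch.cdist(q, k).square(), atol=1e-4)

attn = (neg_sq_dist * scale).softmax(dim=-1)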

Last, I believe the paper only used L2 self-attention in the discriminator. The generator should still use dot-product attention.
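A hedged sketch of that split (the attention_scores helper and the use_l2 flag are hypothetical names for illustration, not this repo's API): the discriminator's self-attention would use the negative squared L2 scores above, while the generator keeps dot-product scores for both self- and cross-attention.

import torch

def attention_scores(q, k, scale, use_l2):
    # hypothetical helper: returns pre-softmax attention scores
    # use_l2=True  -> negative squared L2 distance (discriminator self-attention)
    # use_l2=False -> plain scaled dot product (generator self/cross attention)
    if use_l2:
        AB = torch.matmul(q, k.transpose(-1, -2))
        AA = torch.sum(q ** 2, -1, keepdim=True)
        BB = torch.sum(k ** 2, -1, keepdim=True).transpose(-1, -2)
        scores = -(AA - 2 * AB + BB)
    else:
        scores = torch.matmul(q, k.transpose(-1, -2))
    return scores * scale

# usage sketch with made-up shapes: tied qk in the discriminator, separate q/k in the generator
qk = torch.randn(1, 8, 64, 32)   # (batch, heads, tokens, head_dim)
disc_attn = attention_scores(qk, qk, 32 ** -0.5, use_l2=True).softmax(dim=-1)

q = torch.randn(1, 8, 64, 32)
k = torch.randn(1, 8, 77, 32)    # e.g. cross-attention to 77 conditioning tokens
gen_attn = attention_scores(q, k, 32 ** -0.5, use_l2=False).softmax(dim=-1)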

@PeterL1n PeterL1n changed the title L2 self attention needs to attent itself L2 attention is implemented wrong! Jun 13, 2023
@lucidrains
Owner

@PeterL1n thanks Peter! will get this all resolved this weekend

did they end up using tied qk for their final model?

@lucidrains
Owner

@PeterL1n this is news to me that they are using the squared euclidean distance; i will reread the original paper, thank you!

@lucidrains
Owner

@PeterL1n if the token attends to itself, wouldn't it always have a distance of 0 and attend to itself the most? maybe it works out for their Lipschitz proof, but how does this make sense in the tied scenario?

@PeterL1n
Author

PeterL1n commented Jun 14, 2023

The paper only proved that L2 self-attention with tied qk is Lipschitz; it must be tied to be Lipschitz! It is also not Lipschitz for cross-attention, which is why GigaGAN's discriminator uses only self-attention. The generator, however, uses both self- and cross-attention, and since the generator can't be Lipschitz anyway, there is no point in using L2 attention there, so I believe they used regular dot-product attention for the generator.

You are correct that with tied qk, a token's distance to itself is always zero, so it is always the most similar; its own value is therefore always included in the attention. Other positions can still have a small L2 distance and take attention weight away from the self token (a tiny numeric sketch below illustrates this).

This is my understanding.
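A tiny numeric sketch of that behavior (made-up 2-d features, single head, scale omitted for simplicity):

import torch

# three tokens with tied q = k; token 1 is very close to token 0, token 2 is far away
qk = torch.tensor([[0.0, 0.0],
                   [0.1, 0.0],
                   [3.0, 4.0]])

dist_sq = torch.cdist(qk, qk).square()   # diagonal entries are each token's self-distance, i.e. zero
attn = (-dist_sq).softmax(dim=-1)

print(dist_sq[0])   # tensor([ 0.0000,  0.0100, 25.0000])
print(attn[0])      # self token gets the largest weight, token 1 is nearly as large,
                    # and the distant token 2 gets essentially nothing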

@lucidrains
Owner

Let's roll with that! Thank you Peter for the review 🙏

@lucidrains
Owner

lucidrains commented Jun 14, 2023

it would be the euclidean distance squared, so another token would have to be quite close to take much weight from the self token. that is strange. just thinking out loud

lucidrains added a commit that referenced this issue Jun 14, 2023
@lucidrains
Owner

lucidrains commented Jun 14, 2023

@PeterL1n do you want to see if 0.0.18 unblocks you for your research / startup?

@lucidrains
Owner

@PeterL1n i will get back to wiring up the training code later this month

lucidrains added a commit that referenced this issue Jun 14, 2023
@lucidrains
Owner

@PeterL1n reviewed the old DeepMind paper and indeed it is the squared distance! thanks for catching this and correcting my misunderstanding

@lucidrains
Owner

closing as it should be resolved, feel free to reopen if you note any further issues
