Question regarding queries dimensionality in Perceiver IO #53
pcicales:

Hi @lucidrains,

I think I may be missing something - why do we define the Perceiver IO queries tensor to have a batch dimension (i.e. queries = torch.randn(1, 128, 32))? Was this just to make the code work nicely? Shouldn't we be using queries = torch.randn(128, 32)? I expect to use the same embedding for all of my batch elements, which, if I understand correctly, is what your code is doing.
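For context, here is a minimal sketch of the two query shapes under discussion. The constructor arguments follow the names in lucidrains' perceiver-pytorch README, but the exact values are illustrative, and whether a given commit accepts a 2-dim queries tensor directly is version-dependent, so the shared-query variant adds the batch dimension explicitly:

```python
import torch
from perceiver_pytorch import PerceiverIO

# Constructor values are illustrative, not the repo's defaults.
model = PerceiverIO(
    dim = 32,          # channel dimension of the encoded input
    queries_dim = 32,  # channel dimension of the decoder queries
    logits_dim = 100,  # channel dimension of the decoded output
    depth = 6,
    num_latents = 256,
    latent_dim = 512
)

seq = torch.randn(1, 512, 32)  # (batch, sequence length, dim)

# Variant 1: queries carry an explicit batch dimension, as in the
# example the question refers to - one query set per batch element.
queries_batched = torch.randn(1, 128, 32)
out = model(seq, queries = queries_batched)   # (1, 128, 100)

# Variant 2: a single query set shared across the batch, broadcast
# manually to every batch element.
queries_shared = torch.randn(128, 32)
queries_shared = queries_shared.unsqueeze(0).expand(seq.shape[0], -1, -1)
out = model(seq, queries = queries_shared)    # (1, 128, 100)
```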
lucidrains:

@pcicales Hi Pietro! That is correct - I was envisioning that a batch of encoded inputs could be used as decoding queries, but I think decoding with queries shared across the batch will also work perfectly fine.

pcicales:

Hey @lucidrains! Thanks so much - I also implemented the following solution (keep in mind that I believe this is not 100% consistent with the paper, so if you chose to incorporate it, it could be an optional feature in your code): I simply generate my queries with a linear layer that takes the original input signal as its input. The idea is that this preserves some of the low-level information from the input signal that may otherwise be lost through the various self-attention operations, which should let us use a deeper Perceiver and reduce the risk of the performance issues that other users have associated with the architecture. I will have to experiment with this concept when I have more time, but it seems to be working well for my applications.

This also has the added benefit of doing what you suggest in your reply, since there is an embedding per batch element. It does, however, reduce the flexibility of the output dimensionality of the Perceiver IO; that could be remedied with convs, but doing so was outside the scope of my current project :)

Here is what the code would look like in your example, using code from your previous commit:
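The snippet itself did not survive extraction; what follows is a minimal sketch of the idea as described - decoder queries produced by a linear projection of the raw input - not the exact code from the comment. The layer name to_queries and all dimensions are illustrative, and the PerceiverIO arguments are assumed to match the repo's README:

```python
import torch
import torch.nn as nn
from perceiver_pytorch import PerceiverIO

input_dim = 32    # channel dimension of the raw input signal
queries_dim = 32  # channel dimension of the decoder queries

model = PerceiverIO(
    dim = input_dim,
    queries_dim = queries_dim,
    logits_dim = 100,
    depth = 6,
    num_latents = 256,
    latent_dim = 512
)

# Generate the decoder queries from the raw input with a linear layer
# instead of learning a free-standing query embedding. This ties the
# queries to each batch element and carries low-level input information
# directly into the decoder cross-attention.
to_queries = nn.Linear(input_dim, queries_dim)

seq = torch.randn(1, 512, input_dim)
queries = to_queries(seq)               # (1, 512, queries_dim)

logits = model(seq, queries = queries)  # (1, 512, 100)
```

Note that with this scheme the number of output positions is tied to the input length (512 here), which is the loss of output-shape flexibility mentioned above.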