Question regarding queries dimensionality in Perceiver IO #53

Open
pcicales opened this issue Oct 9, 2021 · 3 comments
@pcicales

pcicales commented Oct 9, 2021

Hi @lucidrains,

I think I may be missing something - why do we define the Perceiver IO queries tensor with a batch dimension (i.e. queries = torch.randn(1, 128, 32))? Was this just to make the code work nicely? Shouldn't we be using queries = torch.randn(128, 32)? I expect to use the same query embedding for all of my batch elements, which, if I understand correctly, is what your code is doing.
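
For concreteness, a single query tensor shared across the batch could be broadcast manually, e.g. (a rough sketch, with a made-up batch size):

import torch

queries = torch.randn(128, 32)                                # (decoder seq, queries_dim), shared by all batch elements
batch = 8                                                     # hypothetical batch size, just for illustration
batched_queries = queries.unsqueeze(0).expand(batch, -1, -1)  # (8, 128, 32) broadcast view, no extra copy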

@lucidrains
Owner

@pcicales Hi Pietro! That is correct - I was envisioning that a batch of encoded inputs could be used as decoding queries, but I think decoding with queries shared across the batch will also work perfectly fine.

@lucidrains
Owner

@pcicales ok done! 2a1b039
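
Assuming that commit lets queries be passed without a leading batch dimension (i.e. shared across the batch), usage would look roughly like this - a sketch, not taken verbatim from the repo:

import torch
from perceiver_pytorch import PerceiverIO

model = PerceiverIO(
    dim = 32,            # dimension of sequence to be encoded
    queries_dim = 32,    # dimension of decoder queries
    logits_dim = 100,    # dimension of final logits
    depth = 6,           # depth of net
    num_latents = 256,   # number of latents
    latent_dim = 512     # latent dimension
)

seq = torch.randn(1, 512, 32)           # (batch, encoder seq, dim)
queries = torch.randn(128, 32)          # (decoder seq, queries_dim) - no batch dimension, shared across the batch
logits = model(seq, queries = queries)  # expected shape (1, 128, 100)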

@pcicales
Author

pcicales commented Oct 10, 2021

Hey @lucidrains! Thanks so much - I also implemented the following solution (keep in mind I believe this is not 100% consistent with the paper, so if you choose to incorporate it, it could be an optional feature in your code):

I simply generate my queries with a linear layer, taking the original input signal as its input. The idea is that this preserves some of the low-level information from the input signal that may otherwise be lost through the various self-attention operations, which should allow for a deeper Perceiver and reduce the risk of some of the performance issues that other users have associated with the architecture. I will have to experiment with this concept when I have more time, but it seems to be working well for my applications. It also has the added benefit of doing what you suggest in your reply, producing a query embedding per batch element. It does, however, reduce the flexibility of the output dimensionality of Perceiver IO - this could be remedied with convs, but that was outside the scope of my current project :)

Here is what the code would look like in your example, using code from your previous commit:

import torch
from perceiver_pytorch import PerceiverIO

model = PerceiverIO(
    dim = 32,                    # dimension of sequence to be encoded
    queries_dim = 16,            # dimension of decoder queries
    logits_dim = 100,            # dimension of final logits
    depth = 6,                   # depth of net
    num_latents = 256,           # number of latents, or induced set points, or centroids. different papers giving it different names
    latent_dim = 512,            # latent dimension
    cross_heads = 1,             # number of heads for cross attention. paper said 1
    latent_heads = 8,            # number of heads for latent self attention, 8
    cross_dim_head = 64,         # number of dimensions per cross attention head
    latent_dim_head = 64,        # number of dimensions per latent self attention head
    weight_tie_layers = False    # whether to weight tie layers (optional, as indicated in the diagram)
)

seq = torch.randn(1, 512, 32) # our input sequence, (batch, encoder seq, dim)
queries_gen = torch.nn.Linear(32, 16) # (dim, queries_dim)
queries = queries_gen(seq) # (1, 512, 16) - queries derived from the input sequence, one set per batch element

logits = model(seq, queries = queries) # (1, 512, 100) - (batch, decoder seq, logits dim)
