Percevier_IO mask #49

Open

xesdiny opened this issue Sep 13, 2021 · 2 comments

Comments

xesdiny commented Sep 13, 2021

Hi. I'm studying your code, specifically the attention mask code in Perceiver IO.

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h = h) # mask shape: (b*h, 1, j)
            sim.masked_fill_(~mask, max_neg_value)

In the encode step this is easy to understand: the mask controls which parts of the input array get mapped into the latent space. But in the decode part, the logic of this code is not explained.
What does it mean to apply a mask when the latent array is mapped to the output array along the latent dimension?
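
To make the encode case concrete, here is a minimal shape sketch of what the quoted snippet does in the encoder cross-attention (the sizes and variable names are just for illustration, not taken from the repo):

import torch
from einops import rearrange, repeat

b, h = 2, 8                                        # batch size, attention heads (made up)
n_latents, seq_len, dim_head = 64, 128, 32

q = torch.randn(b * h, n_latents, dim_head)        # latent queries
k = torch.randn(b * h, seq_len, dim_head)          # keys from the input array
sim = torch.einsum('b i d, b j d -> b i j', q, k)  # (b*h, n_latents, seq_len)

mask = torch.ones(b, seq_len, dtype = torch.bool)  # padding mask over the input array
mask = rearrange(mask, 'b ... -> b (...)')
mask = repeat(mask, 'b j -> (b h) () j', h = h)    # (b*h, 1, seq_len)

sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
attn = sim.softmax(dim = -1)                       # masked input positions get ~zero weight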

@lucidrains
Owner

Nintorac commented Sep 19, 2021

Hi, not sure if I'm missing a key detail here, but from what I can see the mask in this implementation does not work like the mask in a standard transformer.

The mask as applied here allows you to control which latents get information from which part of the input sequence (i.e. the mask is b x n_latents x src_seq_len).

To match the usual transformer notion of an attention mask (or at least the one used in autoregressive LMs), the mask would need to be b x trg_seq_len x src_seq_len. You would then need a separate set of latents for each unique row of the mask with respect to the trg_seq_len.

something like

import torch
from einops import repeat

src_seq_len = 3
trg_seq_len = 5
batch_size = 2
features = 512

data = torch.randn(batch_size, src_seq_len, features)
queries = torch.randn(batch_size, trg_seq_len, features)

# one row of source mask per target position: trg_seq_len x src_seq_len
mask = torch.tensor([
    [1, 0, 0],
    [1, 0, 0],
    [1, 1, 0],
    [1, 1, 0],
    [1, 1, 1]
], dtype = torch.bool)
mask = repeat(mask, 'i j -> b i j', b = batch_size)  # add the batch dimension

# a separate copy of the latents for every target position
x = repeat(self.latents, 'n d -> b trg_len n d', b = batch_size, trg_len = trg_seq_len)

x = cross_attention(x, data, mask)  # b x trg_seq_len x n_latents x features
# we now have trg_seq_len sets of latent values; each set has no information
# about the source items it has been masked from

for layer in self_attn_layers:  # placeholder for the latent self-attention / feedforward blocks
    ...

latents = self.decoder_cross_attn(queries, context = x)
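
Inside that hypothetical cross_attention, the per-target-position mask would then be broadcast over the heads and the latent dimension before being applied to the similarity scores. A rough sketch with made-up names and sizes:

import torch
from einops import repeat

b, h = 2, 8
n_latents, src_seq_len, trg_seq_len = 4, 3, 5

# similarities between the replicated latents (queries) and the input keys:
# one set of latent queries per target position
sim = torch.randn(b * h, trg_seq_len, n_latents, src_seq_len)

# b x trg_seq_len x src_seq_len mask, as above
mask = torch.randint(0, 2, (b, trg_seq_len, src_seq_len)).bool()
mask[..., 0] = True  # keep at least one visible source position per row

attn_mask = repeat(mask, 'b i j -> (b h) i () j', h = h)  # (b*h, trg_seq_len, 1, src_seq_len)
sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)
attn = sim.softmax(dim = -1)  # each set of latents only attends to its allowed source items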

Does this make sense, or am I missing something?
