In AutoregressiveWrapper, if no attention mask is supplied, create a lower triangular one #74


@galatolofederico (Author)

Hi @lucidrains
Thank you for your amazing work on all of your repositories!
I don't know if this behavior fits the minimal philosophy of this implementation, but usually when training in an autoregressive fashion the future tokens are masked to prevent the transformer from "seeing into the future".
I added a default lower triangular attention mask to the AutoregressiveWrapper forward logic to implement this idea.
I tested it in a decoder-only architecture like the one from the enwik8 example, and it works.
Reading the code, it should work in an encoder-decoder architecture with cross_attend = True too, but I haven't tested it.
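
For context, the kind of mask I mean is the standard lower triangular (causal) one. A minimal sketch of how such a mask is typically built and applied in PyTorch (illustrative only, not the exact code from this PR):

import torch

# Causal (lower triangular) mask for a sequence of length n: position i may only
# attend to positions j <= i, so everything above the diagonal is masked out.
n = 8
causal_mask = torch.ones((n, n), dtype = torch.bool).tril()  # True = attend, False = mask

# In scaled dot-product attention, masked positions are filled with a large negative
# value before the softmax so they receive (near-)zero attention weight.
scores = torch.randn(n, n)
scores = scores.masked_fill(~causal_mask, float('-inf'))
attn = scores.softmax(dim = -1)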

@lucidrains (Owner)

Hey! Yup, that's already taken care of automatically if you set causal to True:

https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py#L621

@galatolofederico (Author)

Oops, my bad, I didn't notice that.
Maybe it would be a good idea to set causal=True in the enwik8 example, since it is an example of causal language modelling?

@lucidrains (Owner)

As long as you use the Decoder class, it is already set for you!
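
For reference, the Decoder class in x_transformers.py is roughly the following (paraphrased here only to illustrate why causality is already handled):

# Roughly how Decoder is defined in x_transformers.py (paraphrased for illustration):
# it simply forwards to AttentionLayers with causal = True, so the causal mask is
# created and applied inside attention without the caller having to supply one.
class Decoder(AttentionLayers):
    def __init__(self, **kwargs):
        assert 'causal' not in kwargs, 'cannot set causality on decoder'
        super().__init__(causal = True, **kwargs)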

@galatolofederico (Author)

Ok, got it, thanks! So in theory, passing the attention mask as I did in the PR should have no effect at all, right?

@galatolofederico (Author)

I was experimenting with a decoder-only TransformerWrapper trained using the AutoregressiveWrapper, and I get very different results.
[image: training loss curves]
The yellow curve is without any mask and the orange one is with the lower triangular attention mask.

This is the network I used; it is basically the same network as in the documentation:

self.transformer = TransformerWrapper(
    num_tokens = self.vqvae.hparams.codebook_size,
    max_seq_len = self.hparams.signal_size // self.hparams.window_size,
    attn_layers = Decoder(
        dim = self.hparams.transformer_dim,
        depth = self.hparams.transformer_depth,
        heads = self.hparams.transformer_heads,
        rel_pos_bias = True
    )
)
self.transformer = AutoregressiveWrapper(self.transformer)

@lucidrains (Owner)

Yup, it should be the same; I'm not sure why you are seeing that!

@galatolofederico (Author)

Let's leave the PR open for now. In the next few days, as soon as I have some free time, I will try to write a reproducible MWE so that we can figure out whether or not there is something weird going on.
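
As a starting point, the skeleton for such an MWE would probably look something like the sketch below (the model sizes are placeholders, and the variant with the explicit lower triangular mask from the PR is only indicated as a comment):

import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper

# Hypothetical skeleton of a minimal reproducible example (placeholder sizes).
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 128,
    attn_layers = Decoder(dim = 64, depth = 2, heads = 4, rel_pos_bias = True)
)
model = AutoregressiveWrapper(model)

x = torch.randint(0, 256, (4, 128))
loss = model(x)    # Decoder is causal by default, so no mask is passed here
loss.backward()

# A second run would train the same model while also passing the explicit lower
# triangular mask from the PR, and the two training curves would be compared.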

@lucidrains (Owner) commented Feb 1, 2022

@galatolofederico if you use the Decoder, it is pretty unlikely that there is a bug with the causal mask.

The repository is used in a lot of different labs, and is even used in production at some consultancy companies.
