attn_mask #8
I have the same question. It seems like `attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)` is not right. I'd welcome a discussion.
@skyerhxx This is not the causal mask; it is a mask that prevents the CLS token from attending to PAD tokens in the batch. We add PAD tokens to the text batch because text examples have different lengths but the tensor has a fixed dimension, so to concatenate them into one batch tensor we must pad the end of each sequence with a dummy token, i.e. a PAD token. However, since we append the CLS token at the very end, it would attend to the entire sequence, including the PAD tokens, which we don't want. So we mask them out.
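For anyone tracing the shapes, here is a minimal sketch of what that `F.pad` call produces, assuming a toy batch and a hypothetical `pad_id` of 0 (not the repo's actual forward pass, just the mask construction):

```python
import torch
import torch.nn.functional as F

pad_id = 0                                      # hypothetical pad token id
text = torch.tensor([[5, 7, pad_id, pad_id]])   # toy batch: 2 real tokens, 2 PADs
b, seq = text.shape

# True where the key is a real token, False where it is a PAD token
cls_mask = (text != pad_id).unsqueeze(1)        # shape (b, 1, seq)

# Grow to (b, seq + 1, seq + 1): one extra key column for the CLS token
# (padded with True) and `seq` extra query rows on top (also all True).
attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)

print(attn_mask[0].int())
# tensor([[1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 1],
#         [1, 1, 0, 0, 1]])
# Only the last row (the CLS query) blocks the PAD keys.
```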
@pldlgb we only mask the last row of `sim` (the row where the CLS token is the query). We don't need to mask the other queries because we don't care what PAD queries attend to - their outputs will be masked out when we compute the CE loss. We also don't need to mask the text queries, since they are already constrained by the causal mask and can only look backwards at other text tokens.
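To make that division of labor concrete, here is a hedged sketch (toy logits and hand-built masks mirroring the shapes above, not the repo's actual attention code) showing how the pad-aware mask and the causal mask combine:

```python
import torch

n = 5                                   # seq + 1: 4 text positions + CLS
sim = torch.randn(1, n, n)              # toy attention logits (b, queries, keys)

# Pad-aware mask: all True except the last (CLS) row, where the PAD
# keys (positions 2 and 3 in the toy batch above) are blocked.
attn_mask = torch.ones(1, n, n, dtype=torch.bool)
attn_mask[:, -1, 2:4] = False

mask_value = -torch.finfo(sim.dtype).max
sim = sim.masked_fill(~attn_mask, mask_value)

# Causal mask constrains every text query: query i only sees keys j <= i.
causal = torch.ones(n, n, dtype=torch.bool).triu(1)
sim = sim.masked_fill(causal, mask_value)

attn = sim.softmax(dim=-1)
print(attn[0, -1])  # CLS row: (near-)zero weight on the two PAD keys

# PAD *queries* need no attention mask at all: their predictions are
# simply dropped when the loss is computed with something like
# F.cross_entropy(logits, labels, ignore_index=pad_id).
```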
Hello, I am confused by the implementation of `attn_mask`. I think this padding function can only mask the last row of `sim`. Could you please explain it? Perhaps it's a very foolish question. Thank you so much.