
Commit
Fix attn_mask shape in MHA docstrings (#441)
Summary:
Pull Request resolved: #441

n/a

Reviewed By: ankitade, pikapecan

Differential Revision: D47998161

fbshipit-source-id: 48804597c69979f76bcad92ade24a632dfda7f9c
ebsmothers authored and facebook-github-bot committed Aug 2, 2023
1 parent 1aa2ed2 commit 81e281c
1 changed file: torchmultimodal/modules/layers/multi_head_attention.py (5 additions, 2 deletions)
@@ -46,7 +46,9 @@ def forward(
         """
         Args:
             query (Tensor): input query of shape bsz x seq_len x embed_dim
-            attn_mask (optional Tensor): attention mask of shape bsz x seq_len x seq_len. Two types of masks are supported.
+            attn_mask (optional Tensor): attention mask of shape bsz x num_heads x seq_len x seq_len.
+                Note that the num_heads dimension can equal 1 and the mask will be broadcasted to all heads.
+                Two types of masks are supported.
                 A boolean mask where a value of True indicates that the element should take part in attention.
                 A float mask of the same type as query that is added to the attention score.
             is_causal (bool): If true, does causal attention masking. attn_mask should be set to None if this is set to True
@@ -124,7 +126,8 @@ def forward(
             query (Tensor): input query of shape bsz x target_seq_len x embed_dim
             key (Tensor): key of shape bsz x source_seq_len x embed_dim
             value (Tensor): value of shape bsz x source_seq_len x embed_dim
-            attn_mask (optional Tensor): Attention mask of shape bsz x target_seq_len x source_seq_len.
+            attn_mask (optional Tensor): Attention mask of shape bsz x num_heads x target_seq_len x source_seq_len.
+                Note that the num_heads dimension can equal 1 and the mask will be broadcasted to all heads.
                 Two types of masks are supported. A boolean mask where a value of True
                 indicates that the element *should* take part in attention.
                 A float mask of the same type as query, key, value that is added to the attention score.
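The corrected docstrings describe a mask of shape bsz x num_heads x seq_len x seq_len whose num_heads dimension may be 1, in which case it broadcasts to all heads, and which may be either boolean (True means "attend") or float (added to the attention scores). A minimal NumPy sketch of those semantics, not the TorchMultimodal implementation (the function and variable names here are illustrative):

```python
import numpy as np

def scaled_dot_product_attention(q, k, v, attn_mask=None):
    """Toy attention with the mask conventions from the docstring.

    q: (bsz, num_heads, tgt_len, head_dim)
    k, v: (bsz, num_heads, src_len, head_dim)
    attn_mask: (bsz, num_heads, tgt_len, src_len), where num_heads may be 1
        and broadcasts to all heads. Boolean: True means the position takes
        part in attention. Float: added to the raw attention scores.
    """
    # Raw attention scores: (bsz, num_heads, tgt_len, src_len)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(q.shape[-1])
    if attn_mask is not None:
        if attn_mask.dtype == np.bool_:
            # Boolean mask: positions with False are excluded from attention.
            scores = np.where(attn_mask, scores, -1e9)
        else:
            # Float mask: added directly to the scores (broadcasts over heads).
            scores = scores + attn_mask
    # Numerically stable softmax over the source dimension.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v  # (bsz, num_heads, tgt_len, head_dim)
```

Because the mask participates in ordinary broadcasting, passing a (bsz, 1, tgt_len, src_len) mask yields the same result as explicitly expanding it to all heads.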
