Commit
Adjust *is_causal documentation
According to pytorch#97214, as discussed in pytorch#97166.
janEbert committed Apr 4, 2023
1 parent 4934dde commit 92a350d
Showing 1 changed file with 41 additions and 10 deletions.
51 changes: 41 additions & 10 deletions torch/nn/modules/transformer.py
@@ -203,9 +203,13 @@ def forward(
 Args:
 src: the sequence to the encoder (required).
 mask: the mask for the src sequence (optional).
-is_causal: If specified, applies a causal mask as mask (optional)
-and ignores attn_mask for computing scaled dot product attention.
+is_causal: If specified, applies a causal mask as mask.
 Default: ``False``.
+Warning:
+``is_causal`` provides a hint that ``mask`` is the
+causal mask. Providing incorrect hints can result in
+incorrect execution, including forward and backward
+compatibility.
 src_key_padding_mask: the mask for the src keys per batch (optional).
 Shape:
@@ -361,9 +365,20 @@ def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None
 tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
 memory_key_padding_mask: the mask for the memory keys per batch (optional).
 tgt_is_causal: If specified, applies a causal mask as tgt mask.
-Mutually exclusive with providing tgt_mask. Default: ``False``.
-memory_is_causal: If specified, applies a causal mask as memory mask.
-Mutually exclusive with providing memory_mask. Default: ``False``.
+Default: ``False``.
+Warning:
+``tgt_is_causal`` provides a hint that ``tgt_mask`` is
+the causal mask. Providing incorrect hints can result in
+incorrect execution, including forward and backward
+compatibility.
+memory_is_causal: If specified, applies a causal mask as
+memory mask.
+Default: ``False``.
+Warning:
+``memory_is_causal`` provides a hint that
+``memory_mask`` is the causal mask. Providing incorrect
+hints can result in incorrect execution, including
+forward and backward compatibility.
 Shape:
 see the docs in Transformer class.
@@ -493,8 +508,13 @@ def forward(
 Args:
 src: the sequence to the encoder layer (required).
 src_mask: the mask for the src sequence (optional).
-is_causal: If specified, applies a causal mask as src_mask.
-Default: ``False``.
+is_causal: If specified, applies a causal mask as src mask.
+Default: ``False``.
+Warning:
+``is_causal`` provides a hint that ``src_mask`` is the
+causal mask. Providing incorrect hints can result in
+incorrect execution, including forward and backward
+compatibility.
 src_key_padding_mask: the mask for the src keys per batch (optional).
 Shape:
@@ -706,9 +726,20 @@ def forward(
 tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
 memory_key_padding_mask: the mask for the memory keys per batch (optional).
 tgt_is_causal: If specified, applies a causal mask as tgt mask.
-Mutually exclusive with providing tgt_mask. Default: ``False``.
-memory_is_causal: If specified, applies a causal mask as memory mask.
-Mutually exclusive with providing memory_mask. Default: ``False``.
+Default: ``False``.
+Warning:
+``tgt_is_causal`` provides a hint that ``tgt_mask`` is
+the causal mask. Providing incorrect hints can result in
+incorrect execution, including forward and backward
+compatibility.
+memory_is_causal: If specified, applies a causal mask as
+memory mask.
+Default: ``False``.
+Warning:
+``memory_is_causal`` provides a hint that
+``memory_mask`` is the causal mask. Providing incorrect
+hints can result in incorrect execution, including
+forward and backward compatibility.
 Shape:
 see the docs in Transformer class.
 """
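Note on the new warnings: the revised docstrings describe the `*is_causal` flags as hints that the accompanying mask already is the causal mask, rather than as a replacement for it. A caller who wants to guard against a wrong hint could check the mask before setting the flag. The following is a minimal, framework-free sketch of such a check; the helper names `causal_mask` and `is_causal_mask` are hypothetical and are not part of this commit or of PyTorch. It assumes the additive mask convention (0.0 where attention is allowed, `-inf` where it is blocked), which is the layout that `nn.Transformer.generate_square_subsequent_mask` is understood to produce.

```python
import math


def causal_mask(size):
    """Build a size x size additive causal mask as nested lists.

    Entry [i][j] is 0.0 when position i may attend to position j (j <= i)
    and -inf when j lies in the future of i.
    """
    return [
        [0.0 if j <= i else -math.inf for j in range(size)]
        for i in range(size)
    ]


def is_causal_mask(mask):
    """Return True only if `mask` is square and equals the causal mask.

    Hypothetical guard: only pass an is_causal-style hint when this holds.
    """
    size = len(mask)
    if any(len(row) != size for row in mask):
        return False
    return mask == causal_mask(size)


if __name__ == "__main__":
    mask = causal_mask(3)
    # Derive the hint from the mask instead of asserting it blindly.
    hint = is_causal_mask(mask)
    print(hint)
```

With tensors, the same comparison would be done elementwise against a freshly generated causal mask of matching size; that variant is left out here because it depends on the installed PyTorch version.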
