Skip to content

Commit

Permalink
Implement is_causal API for Transformer
Browse files Browse the repository at this point in the history
As discussed in pytorch#97166.
  • Loading branch information
janEbert committed Apr 4, 2023
1 parent 92a350d commit bdbbdbf
Showing 1 changed file with 28 additions and 3 deletions.
31 changes: 28 additions & 3 deletions torch/nn/modules/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int =

def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None,
src_is_causal: bool = False, tgt_is_causal: bool = False, memory_is_causal: bool = False) -> Tensor:
r"""Take in and process masked source/target sequences.
Args:
Expand All @@ -97,6 +98,28 @@ def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, t
src_key_padding_mask: the Tensor mask for src keys per batch (optional).
tgt_key_padding_mask: the Tensor mask for tgt keys per batch (optional).
memory_key_padding_mask: the Tensor mask for memory keys per batch (optional).
src_is_causal: If specified, applies a causal mask as src mask.
Default: ``False``.
Warning:
``src_is_causal`` provides a hint that ``src_mask`` is
the causal mask. Providing incorrect hints can result in
incorrect execution, as well as forward and backward
compatibility issues.
tgt_is_causal: If specified, applies a causal mask as tgt mask.
Default: ``False``.
Warning:
``tgt_is_causal`` provides a hint that ``tgt_mask`` is
the causal mask. Providing incorrect hints can result in
incorrect execution, as well as forward and backward
compatibility issues.
memory_is_causal: If specified, applies a causal mask as
memory mask.
Default: ``False``.
Warning:
``memory_is_causal`` provides a hint that
``memory_mask`` is the causal mask. Providing incorrect
hints can result in incorrect execution, as well as
forward and backward compatibility issues.
Shape:
- src: :math:`(S, E)` for unbatched input, :math:`(S, N, E)` if `batch_first=False` or
Expand Down Expand Up @@ -142,10 +165,12 @@ def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, t
if src.size(-1) != self.d_model or tgt.size(-1) != self.d_model:
raise RuntimeError("the feature number of src and tgt must be equal to d_model")

memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask,
is_causal=src_is_causal)
output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask)
memory_key_padding_mask=memory_key_padding_mask,
tgt_is_causal=tgt_is_causal, memory_is_causal=memory_is_causal)
return output

@staticmethod
Expand Down

0 comments on commit bdbbdbf

Please sign in to comment.