Skip to content

Commit

Permalink
Fix Torchscript typing in transformer_encoder.py (#4847)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhxchen17 committed Nov 8, 2022
1 parent 59d966a commit b8ac3fa
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion fairseq/models/transformer/transformer_encoder.py
Expand Up @@ -202,9 +202,12 @@ def forward_scriptable(
"""
# compute padding mask
encoder_padding_mask = src_tokens.eq(self.padding_idx)
has_pads: Tensor = (
has_pads = (
torch.tensor(src_tokens.device.type == "xla") or encoder_padding_mask.any()
)
# Torchscript doesn't handle bool Tensor correctly, so we need to work around.
if torch.jit.is_scripting():
has_pads = torch.tensor(1) if has_pads else torch.tensor(0)

x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)

Expand Down

0 comments on commit b8ac3fa

Please sign in to comment.