Make padding_mask Optional in SequenceBatch. (#106)
kauterry committed Oct 17, 2023
1 parent 8442edc commit e9abbb3
Showing 1 changed file with 4 additions and 1 deletion.
src/fairseq2/models/sequence.py (4 additions, 1 deletion)
@@ -37,7 +37,7 @@ class SequenceBatch:
     size, :math:`S` is the sequence length, and :math:`*` is any number of
     sequence-specific dimensions including none."""
 
-    padding_mask: PaddingMask
+    padding_mask: Optional[PaddingMask]
     """The padding mask of ``seqs``. *Shape:* :math:`(N,S)`, where :math:`N` is
     the batch size and :math:`S` is the sequence length."""

@@ -51,6 +51,9 @@ def batch_size(self) -> int:
 
     def compute_num_tokens(self) -> Tensor:
         """Compute the number of tokens in this batch."""
+        if self.padding_mask is None:
+            return torch.full((), self.seqs.numel(), device=self.seqs.device)
+
         return self.padding_mask.seq_lens.sum()


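Below is a minimal usage sketch (not part of the commit) of what this change enables. It assumes SequenceBatch is importable from fairseq2.models.sequence, matching the file path above, and that the dataclass fields are exactly those shown in the diff.

import torch

from fairseq2.models.sequence import SequenceBatch

# A batch of 4 token sequences of length 16 that contains no padding,
# so no padding mask is needed.
seqs = torch.randint(0, 100, (4, 16))

# With this change, padding_mask may be None, meaning no position is padded.
batch = SequenceBatch(seqs, padding_mask=None)

# compute_num_tokens() now falls back to seqs.numel() when the mask is
# None, so every element counts: 4 * 16 = 64.
print(int(batch.compute_num_tokens()))  # 64

When a PaddingMask is present, the method still sums padding_mask.seq_lens, so the two paths agree for a fully unpadded batch.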
