
Commit 1a5e914
fix: causal mask requires fewer position embeddings
borisdayma committed Feb 8, 2022
1 parent 7bb8c97 commit 1a5e914
Showing 1 changed file with 2 additions and 2 deletions.
src/transformers/models/bart/modeling_flax_bart.py (2 additions, 2 deletions)
@@ -1478,11 +1478,11 @@ def prepare_inputs_for_generation(
         # initializing the cache
         batch_size, seq_length = decoder_input_ids.shape

-        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
+        past_key_values = self.init_cache(batch_size, max_length - 1, encoder_outputs)
         # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
         # But since the decoder uses a causal mask, those positions are masked anyways.
         # Thus we can create a single static attention_mask here, which is more efficient for compilation
-        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+        extended_attention_mask = jnp.ones((batch_size, max_length - 1), dtype="i4")
         if decoder_attention_mask is not None:
             position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
             extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
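For readers following the change, here is a minimal, self-contained sketch of the masking and position-id logic around this hunk. The toy batch size, max_length, and example decoder masks are illustrative assumptions, not values from the commit:

    # Hypothetical standalone sketch of the mask/position logic touched by this commit.
    import jax.numpy as jnp
    from jax import lax

    batch_size, max_length = 2, 8                      # toy values, not from the commit
    decoder_attention_mask = jnp.array([[1, 1, 1, 0],  # example prompt masks (right padding)
                                        [1, 1, 0, 0]], dtype="i4")

    # Static all-ones mask sized to the cache; positions past the real inputs are
    # covered by the decoder's causal mask anyway, so no per-step zeroing is needed.
    extended_attention_mask = jnp.ones((batch_size, max_length - 1), dtype="i4")

    # Positions are the cumulative count of attended tokens, zero-indexed.
    position_ids = decoder_attention_mask.cumsum(axis=-1) - 1

    # Overlay the real prompt mask at the start of the static mask.
    extended_attention_mask = lax.dynamic_update_slice(
        extended_attention_mask, decoder_attention_mask, (0, 0)
    )

    print(extended_attention_mask.shape)  # (2, 7)
    print(position_ids)                   # [[0 1 2 2], [0 1 1 1]]

As the in-code comment notes, positions beyond the real inputs are hidden by the causal mask, so a single static all-ones mask (here sized to max_length - 1, matching the change above) can be compiled once and reused, with the real decoder_attention_mask overlaid at the start via lax.dynamic_update_slice.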
