diff --git a/examples/flax/language-modeling/run_t5_mlm_flax.py b/examples/flax/language-modeling/run_t5_mlm_flax.py index 2845d9ce58f0..e65fc5fa9dd6 100755 --- a/examples/flax/language-modeling/run_t5_mlm_flax.py +++ b/examples/flax/language-modeling/run_t5_mlm_flax.py @@ -368,7 +368,7 @@ def filter_input_ids(self, input_ids, sentinel_ids): batch_size = input_ids.shape[0] input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids) - input_ids = input_ids_full[input_ids_full > 0].reshape((batch_size, -1)) + input_ids = input_ids_full[input_ids_full >= 0].reshape((batch_size, -1)) input_ids = np.concatenate( [input_ids, np.full((batch_size, 1), self.tokenizer.eos_token_id, dtype=np.int32)], axis=-1 )