diff --git a/examples/bert/run_pretraining.py b/examples/bert/run_pretraining.py index 5bfaf723cf..f9c95a1290 100644 --- a/examples/bert/run_pretraining.py +++ b/examples/bert/run_pretraining.py @@ -382,6 +382,7 @@ def main(_): num_parallel_calls=tf.data.experimental.AUTOTUNE, ) dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True) + dataset = dataset.repeat() # Create a BERT model the input config. model = BertModel(