diff --git a/megatron/training.py b/megatron/training.py index 592a1d4f6..a518923f0 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -1071,7 +1071,7 @@ def build_train_valid_test_data_iterators( if test_ds is not None else [] # Flags to know if we need to do training/validation/testing. - do_train = (train_dataloader is not None and args.train_iters > 0) and not args.eval_only + do_train = train_dataloader is not None and args.train_iters > 0 and not args.eval_only # Need to broadcast num_tokens and num_type_tokens. flags = torch.cuda.LongTensor([