diff --git a/horovod/spark/lightning/remote.py b/horovod/spark/lightning/remote.py index f4f180955c..2c083810c8 100644 --- a/horovod/spark/lightning/remote.py +++ b/horovod/spark/lightning/remote.py @@ -193,6 +193,8 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - 'gpus': _num_gpus, 'callbacks': callbacks, 'max_epochs': epochs, + 'limit_train_batches': _train_steps_per_epoch, + 'limit_val_batches': _val_steps_per_epoch, 'logger': train_logger, 'log_every_n_steps': log_every_n_steps, 'num_sanity_val_steps': 0,