diff --git a/horovod/spark/lightning/estimator.py b/horovod/spark/lightning/estimator.py
index c039ffc751..f60608d4d7 100644
--- a/horovod/spark/lightning/estimator.py
+++ b/horovod/spark/lightning/estimator.py
@@ -206,9 +206,6 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
 
     profiler = Param(Params._dummy(), 'profiler', 'lightning profiler to use')
 
-    checkpoint_callback = Param(Params._dummy(), 'checkpoint_callback',
-                                'model checkpointing callback')
-
     @keyword_only
     def __init__(self,
                  num_proc=None,
@@ -249,8 +246,7 @@ def __init__(self,
                  data_module=None,
                  loader_num_epochs=None,
                  terminate_on_nan=False,
-                 profiler=None,
-                 checkpoint_callback=None):
+                 profiler=None):
         super(TorchEstimator, self).__init__()
         self._setDefault(loss_constructors=None,
@@ -264,8 +260,7 @@ def __init__(self,
                          data_module=None,
                          loader_num_epochs=None,
                          terminate_on_nan=False,
-                         profiler=None,
-                         checkpoint_callback=None)
+                         profiler=None)
 
         kwargs = self._input_kwargs
 
@@ -338,12 +333,6 @@ def setTerminateOnNan(self, value):
     def getTerminateOnNan(self):
         return self.getOrDefault(self.terminate_on_nan)
 
-    def setCheckpointCallback(self, value):
-        return self._set(checkpoint_callback=value)
-
-    def getCheckpointCallback(self):
-        return self.getOrDefault(self.checkpoint_callback)
-
     def getProfiler(self):
         return self.getOrDefault(self.profiler)
 
diff --git a/horovod/spark/lightning/remote.py b/horovod/spark/lightning/remote.py
index 648ba68bd8..f4f3a76c06 100644
--- a/horovod/spark/lightning/remote.py
+++ b/horovod/spark/lightning/remote.py
@@ -53,7 +53,6 @@ def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx, train_ro
     transformation = transformation_fn if transformation_fn else None
     inmemory_cache_all = estimator.getInMemoryCacheAll()
     callbacks = estimator.getCallbacks() or []
-    checkpoint_callback = estimator.getCheckpointCallback()
     train_steps_per_epoch = estimator.getTrainStepsPerEpoch()
     val_steps_per_epoch = estimator.getValidationStepsPerEpoch()
     num_gpus = estimator.getNumGPUs()
@@ -99,6 +98,7 @@ def train(serialized_model):
         import horovod.torch as hvd
         # Horovod: initialize library.
         hvd.init()
+        _checkpoint_callback = None
 
         with remote_store.get_local_output_dir() as run_output_dir:
             logs_path = os.path.join(run_output_dir, remote_store.logs_subdir)
@@ -127,16 +127,15 @@ def train(serialized_model):
 
             # Lightning requires to add checkpoint callbacks for all ranks.
             # Otherwise we are seeing hanging in training.
-            _checkpoint_callback = checkpoint_callback
+            for i, cb in enumerate(callbacks):
+                if isinstance(cb, ModelCheckpoint):
+                    _checkpoint_callback = cb
+                    _checkpoint_callback.dir_path = ckpt_dir
+                    _checkpoint_callback.filename = ckpt_filename
+                    break
             if _checkpoint_callback:
-                _checkpoint_callback.dir_path = ckpt_dir
-                _checkpoint_callback.filename = ckpt_filename
-            else:
-                # By default 'monitor'=None which saves a checkpoint only for the last epoch.
-                _checkpoint_callback = ModelCheckpoint(dirpath=ckpt_dir,
-                                                       filename=ckpt_filename,
-                                                       verbose=True)
-            callbacks.append(_checkpoint_callback)
+                callbacks.pop(i)
+                callbacks.append(_checkpoint_callback)
 
             if remote_store.saving_runs and hvd.rank() == 0:
                 # Horovod: sync checkpoint and logging files only on rank 0 to
@@ -224,7 +223,10 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
                     remote_store.sync(logs_path)
 
                 # rank 0 overwrites model with best checkpoint and returns.
-                best_model = model.load_from_checkpoint(_checkpoint_callback.best_model_path)
+                if _checkpoint_callback:
+                    best_model = model.load_from_checkpoint(_checkpoint_callback.best_model_path)
+                else:
+                    best_model = model
                 serialized_checkpoint = io.BytesIO()
                 module = best_model if not is_legacy else best_model._model
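Not part of the patch, but for reviewers: a minimal usage sketch of how a checkpoint callback is supplied after this change. The dedicated `checkpoint_callback` estimator param is removed; a `ModelCheckpoint` now travels in the regular `callbacks` list, where the per-rank training loop above finds it via `isinstance` and repoints it at the run's checkpoint directory. The `backend`, `store`, and `model` objects and the column names below are illustrative placeholders assumed to be constructed elsewhere.

```python
# Usage sketch (assumes `backend`, `store`, and `model` -- a Horovod Spark
# backend, a Store, and a LightningModule -- already exist; input shapes and
# column names are placeholders for illustration).
from pytorch_lightning.callbacks import ModelCheckpoint
from horovod.spark.lightning import TorchEstimator

# Previously passed as checkpoint_callback=...; now it rides in `callbacks`,
# and the remote trainer relocates it to the per-run checkpoint directory.
ckpt = ModelCheckpoint(monitor='val_loss', save_top_k=1, verbose=True)

torch_estimator = TorchEstimator(
    backend=backend,
    store=store,
    model=model,
    input_shapes=[[-1, 1, 28, 28]],
    feature_cols=['features'],
    label_cols=['label'],
    batch_size=32,
    epochs=5,
    callbacks=[ckpt],   # picked up by the isinstance(cb, ModelCheckpoint) scan
    verbose=1)
```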