From dadca5354e326b9181c5d494df8fc6d6b5a6cbcb Mon Sep 17 00:00:00 2001
From: chongxiaoc <74630762+chongxiaoc@users.noreply.github.com>
Date: Wed, 6 Oct 2021 19:05:28 -0700
Subject: [PATCH] Spark/Lightning: don't overwrite model with checkpoint by default (#3201)

The Lightning estimator still saves a checkpoint by default if no checkpoint
callback is specified; however, the model is no longer overwritten with that
checkpoint file in that case.

Signed-off-by: Chongxiao Cao
---
 .../pytorch/pytorch_lightning_spark_mnist.py |  7 +++--
 horovod/spark/lightning/estimator.py         | 15 ++--------
 horovod/spark/lightning/remote.py            | 28 +++++++++++++------
 3 files changed, 26 insertions(+), 24 deletions(-)

diff --git a/examples/spark/pytorch/pytorch_lightning_spark_mnist.py b/examples/spark/pytorch/pytorch_lightning_spark_mnist.py
index d1d164dab3..369a4bd0e6 100644
--- a/examples/spark/pytorch/pytorch_lightning_spark_mnist.py
+++ b/examples/spark/pytorch/pytorch_lightning_spark_mnist.py
@@ -177,14 +177,15 @@ def on_train_end(self, trainer, model):

     # added EarlyStopping and ModelCheckpoint
     from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
-    callbacks.append(ModelCheckpoint(dirpath=args.work_dir))
+    callbacks.append(ModelCheckpoint(monitor='val_loss', mode="min",
+                                     save_top_k=1, verbose=True))

     from pytorch_lightning.callbacks.early_stopping import EarlyStopping
     callbacks.append(EarlyStopping(monitor='val_loss',
-                                   min_delta=0.00,
+                                   min_delta=0.001,
                                    patience=3,
                                    verbose=True,
-                                   mode='max'))
+                                   mode='min'))

     torch_estimator = hvd.TorchEstimator(backend=backend,
                                          store=store,
diff --git a/horovod/spark/lightning/estimator.py b/horovod/spark/lightning/estimator.py
index c039ffc751..f60608d4d7 100644
--- a/horovod/spark/lightning/estimator.py
+++ b/horovod/spark/lightning/estimator.py
@@ -206,9 +206,6 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
     profiler = Param(Params._dummy(), 'profiler', 'lightning profiler to use')

-    checkpoint_callback = Param(Params._dummy(), 'checkpoint_callback',
-                                'model checkpointing callback')
-
     @keyword_only
     def __init__(self,
                  num_proc=None,
@@ -249,8 +246,7 @@ def __init__(self,
                  data_module=None,
                  loader_num_epochs=None,
                  terminate_on_nan=False,
-                 profiler=None,
-                 checkpoint_callback=None):
+                 profiler=None):
         super(TorchEstimator, self).__init__()

         self._setDefault(loss_constructors=None,
@@ -264,8 +260,7 @@ def __init__(self,
                          data_module=None,
                          loader_num_epochs=None,
                          terminate_on_nan=False,
-                         profiler=None,
-                         checkpoint_callback=None)
+                         profiler=None)

         kwargs = self._input_kwargs

@@ -338,12 +333,6 @@ def setTerminateOnNan(self, value):
     def getTerminateOnNan(self):
         return self.getOrDefault(self.terminate_on_nan)

-    def setCheckpointCallback(self, value):
-        return self._set(checkpoint_callback=value)
-
-    def getCheckpointCallback(self):
-        return self.getOrDefault(self.checkpoint_callback)
-
     def getProfiler(self):
         return self.getOrDefault(self.profiler)

diff --git a/horovod/spark/lightning/remote.py b/horovod/spark/lightning/remote.py
index 648ba68bd8..35800b79ae 100644
--- a/horovod/spark/lightning/remote.py
+++ b/horovod/spark/lightning/remote.py
@@ -53,7 +53,6 @@ def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx, train_ro
     transformation = transformation_fn if transformation_fn else None
     inmemory_cache_all = estimator.getInMemoryCacheAll()
     callbacks = estimator.getCallbacks() or []
-    checkpoint_callback = estimator.getCheckpointCallback()
     train_steps_per_epoch = estimator.getTrainStepsPerEpoch()
     val_steps_per_epoch = estimator.getValidationStepsPerEpoch()
     num_gpus = estimator.getNumGPUs()
@@ -99,6 +98,8 @@ def train(serialized_model):
         import horovod.torch as hvd
         # Horovod: initialize library.
         hvd.init()
+        _checkpoint_callback = None
+        require_checkpoint = False

         with remote_store.get_local_output_dir() as run_output_dir:
             logs_path = os.path.join(run_output_dir, remote_store.logs_subdir)
@@ -115,6 +116,7 @@ def train(serialized_model):
             elif isinstance(logger, CometLogger) and logger._experiment_key is None:
                 # Resume logger experiment key if passed correctly from CPU.
                 train_logger = CometLogger(
+                    save_dir=logs_path,
                     api_key=logger.api_key,
                     experiment_key=logger_experiment_key,
                 )
@@ -123,20 +125,24 @@ def train(serialized_model):
             else:
                 # use logger passed in.
                 train_logger = logger
+                train_logger.save_dir = logs_path
                 print(f"Setup logger: Using logger passed from estimator: {train_logger}")

             # Lightning requires to add checkpoint callbacks for all ranks.
             # Otherwise we are seeing hanging in training.
-            _checkpoint_callback = checkpoint_callback
-            if _checkpoint_callback:
-                _checkpoint_callback.dir_path = ckpt_dir
-                _checkpoint_callback.filename = ckpt_filename
-            else:
+            for cb in callbacks:
+                if isinstance(cb, ModelCheckpoint):
+                    cb.dir_path = ckpt_dir
+                    cb.filename = ckpt_filename
+                    _checkpoint_callback = cb
+                    require_checkpoint = True
+                    break
+            if not _checkpoint_callback:
                 # By default 'monitor'=None which saves a checkpoint only for the last epoch.
                 _checkpoint_callback = ModelCheckpoint(dirpath=ckpt_dir,
                                                        filename=ckpt_filename,
                                                        verbose=True)
-            callbacks.append(_checkpoint_callback)
+                callbacks.append(_checkpoint_callback)

             if remote_store.saving_runs and hvd.rank() == 0:
                 # Horovod: sync checkpoint and logging files only on rank 0 to
@@ -224,7 +230,13 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
                     remote_store.sync(logs_path)

             # rank 0 overwrites model with best checkpoint and returns.
-            best_model = model.load_from_checkpoint(_checkpoint_callback.best_model_path)
+            if require_checkpoint:
+                if verbose:
+                    print("load from checkpoint best model path:",
+                          _checkpoint_callback.best_model_path)
+                best_model = model.load_from_checkpoint(_checkpoint_callback.best_model_path)
+            else:
+                best_model = model
             serialized_checkpoint = io.BytesIO()
             module = best_model if not is_legacy else best_model._model
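
With the checkpoint_callback estimator param removed, a user-supplied ModelCheckpoint now
travels through the regular callbacks list, as in the updated MNIST example above. The sketch
below illustrates the intended usage under that assumption; backend, store, model, args, and
train_df are placeholders standing in for the objects set up earlier in the MNIST example,
not names introduced by this patch.

    # Minimal usage sketch after this change. Placeholders: backend, store, model,
    # args, and train_df are assumed to be created as in the MNIST example.
    import horovod.spark.lightning as hvd
    from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
    from pytorch_lightning.callbacks.early_stopping import EarlyStopping

    callbacks = [
        # A ModelCheckpoint in `callbacks` is picked up on the workers, pointed at the
        # run's checkpoint directory, and rank 0 reloads the best checkpoint into the
        # returned model (the require_checkpoint path in remote.py).
        ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1, verbose=True),
        EarlyStopping(monitor='val_loss', min_delta=0.001, patience=3,
                      verbose=True, mode='min'),
    ]

    torch_estimator = hvd.TorchEstimator(backend=backend,   # SparkBackend instance
                                         store=store,       # e.g. Store.create(args.work_dir)
                                         model=model,       # LightningModule to train
                                         input_shapes=[[-1, 1, 28, 28]],
                                         feature_cols=['features'],
                                         label_cols=['label'],
                                         batch_size=args.batch_size,
                                         epochs=args.epochs,
                                         validation=0.1,
                                         callbacks=callbacks,
                                         verbose=1)

    # Without a ModelCheckpoint in `callbacks`, the workers still write a default
    # last-epoch checkpoint, but the fitted model is no longer overwritten with it.
    torch_model = torch_estimator.fit(train_df)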