diff --git a/examples/spark/pytorch/pytorch_lightning_spark_mnist.py b/examples/spark/pytorch/pytorch_lightning_spark_mnist.py
index 99e05c981d..d1d164dab3 100644
--- a/examples/spark/pytorch/pytorch_lightning_spark_mnist.py
+++ b/examples/spark/pytorch/pytorch_lightning_spark_mnist.py
@@ -43,6 +43,8 @@
                     help='temporary working directory to write intermediate files (prefix with hdfs:// to use HDFS)')
 parser.add_argument('--data-dir', default='/tmp',
                     help='location of the training dataset in the local filesystem (will be downloaded if needed)')
+parser.add_argument('--enable-profiler', action='store_true',
+                    help='Enable profiler')
 
 
 def train_model(args):
@@ -195,7 +197,7 @@
                                          validation=0.1,
                                          verbose=1,
                                          callbacks=callbacks,
-                                         profiler="simple")
+                                         profiler="simple" if args.enable_profiler else None)
 
     torch_model = torch_estimator.fit(train_df).setOutputCols(['label_prob'])
 
diff --git a/horovod/spark/lightning/estimator.py b/horovod/spark/lightning/estimator.py
index 0b6cedc065..c039ffc751 100644
--- a/horovod/spark/lightning/estimator.py
+++ b/horovod/spark/lightning/estimator.py
@@ -206,6 +206,9 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
     profiler = Param(Params._dummy(), 'profiler',
                      'lightning profiler to use')
 
+    checkpoint_callback = Param(Params._dummy(), 'checkpoint_callback',
+                                'model checkpointing callback')
+
     @keyword_only
     def __init__(self,
                  num_proc=None,
@@ -246,7 +249,8 @@ def __init__(self,
                  data_module=None,
                  loader_num_epochs=None,
                  terminate_on_nan=False,
-                 profiler=None):
+                 profiler=None,
+                 checkpoint_callback=None):
         super(TorchEstimator, self).__init__()
 
         self._setDefault(loss_constructors=None,
@@ -260,7 +264,8 @@ def __init__(self,
                          data_module=None,
                          loader_num_epochs=None,
                          terminate_on_nan=False,
-                         profiler=None)
+                         profiler=None,
+                         checkpoint_callback=None)
 
         kwargs = self._input_kwargs
 
@@ -333,6 +338,12 @@ def setTerminateOnNan(self, value):
     def getTerminateOnNan(self):
         return self.getOrDefault(self.terminate_on_nan)
 
+    def setCheckpointCallback(self, value):
+        return self._set(checkpoint_callback=value)
+
+    def getCheckpointCallback(self):
+        return self.getOrDefault(self.checkpoint_callback)
+
     def getProfiler(self):
         return self.getOrDefault(self.profiler)
 
@@ -401,6 +412,7 @@ def _fit_on_prepared_data(self, backend, train_rows, val_rows, metadata, avg_row
                                     validation=self.getValidation())
 
         serialized_model = serialize_fn()(model)
+        # FIXME: checkpoint bytes should be loaded into serialized_model, same as Keras Estimator.
         ckpt_bytes = self._read_checkpoint(run_id) if self._has_checkpoint(run_id) else None
         trainer = remote.RemoteTrainer(self,
                                        metadata=metadata,
diff --git a/horovod/spark/lightning/remote.py b/horovod/spark/lightning/remote.py
index 348b32f34e..eccc4fabe0 100644
--- a/horovod/spark/lightning/remote.py
+++ b/horovod/spark/lightning/remote.py
@@ -53,6 +53,7 @@ def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx, train_ro
     transformation = transformation_fn if transformation_fn else None
     inmemory_cache_all = estimator.getInMemoryCacheAll()
     callbacks = estimator.getCallbacks() or []
+    checkpoint_callback = estimator.getCheckpointCallback()
     train_steps_per_epoch = estimator.getTrainStepsPerEpoch()
     val_steps_per_epoch = estimator.getValidationStepsPerEpoch()
     num_gpus = estimator.getNumGPUs()
@@ -88,16 +89,12 @@ def train(serialized_model):
         # Horovod: initialize library.
         hvd.init()
 
-        with tempfile.TemporaryDirectory() as last_ckpt_dir, remote_store.get_local_output_dir() as run_output_dir:
-            last_ckpt_file = os.path.join(last_ckpt_dir, 'last.ckpt')
-            if ckpt_bytes:
-                with open(last_ckpt_file, 'wb') as f:
-                    f.write(ckpt_bytes)
-
-            # TODO: Pass the logger from estimator constructor
+        with remote_store.get_local_output_dir() as run_output_dir:
             logs_path = os.path.join(run_output_dir, remote_store.logs_subdir)
             os.makedirs(logs_path, exist_ok=True)
             print(f"Made directory {logs_path} for horovod rank {hvd.rank()}")
+            ckpt_dir = run_output_dir
+            ckpt_filename = remote_store.checkpoint_filename
 
             # Use default logger if no logger is supplied
             train_logger = logger
@@ -106,22 +103,25 @@ def train(serialized_model):
             if train_logger is None:
                 train_logger = TensorBoardLogger(logs_path)
 
-            # TODO: find out a way to use ckpt_path created from remote store, but all other parameters ingest from estimator config
-            # ckpt_path = os.path.join(run_output_dir, remote_store.checkpoint_filename)
-            # os.makedirs(ckpt_path, exist_ok=True)
-            # model_checkpoint_callback = ModelCheckpoint(dirpath=ckpt_path)
-            # callbacks.append(model_checkpoint_callback)
-
-            is_model_checkpoint_callback_exist = False
-            for cb in callbacks:
-                if isinstance(cb, ModelCheckpoint):
-                    is_model_checkpoint_callback_exist = True
-                    break
+            # Lightning requires the checkpoint callback to be added on all ranks;
+            # otherwise we have seen training hang.
+            _checkpoint_callback = checkpoint_callback
+            if _checkpoint_callback:
+                _checkpoint_callback.dirpath = ckpt_dir
+                _checkpoint_callback.filename = ckpt_filename
+            else:
+                # By default 'monitor' is None, which saves a checkpoint only for the last epoch.
+                _checkpoint_callback = ModelCheckpoint(dirpath=ckpt_dir,
+                                                       filename=ckpt_filename,
+                                                       verbose=True)
+            callbacks.append(_checkpoint_callback)
 
             if remote_store.saving_runs and hvd.rank() == 0:
+                # Horovod: sync checkpoint and logging files only on rank 0 to
+                # prevent other ranks from corrupting them.
                 class _SyncCallback(Callback):
                     def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-                        remote_store.sync(logs_path)
+                        remote_store.sync(run_output_dir)
 
                 callbacks.append(_SyncCallback())
 
@@ -133,7 +133,11 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
             _val_steps_per_epoch = val_steps_per_epoch if val_steps_per_epoch else \
                 int(math.floor(float(val_rows) / val_batch_size / hvd.size()))
 
-            print(f"Training data of rank[{hvd.local_rank()}]: train_rows:{train_rows}, batch_size:{batch_size}, _train_steps_per_epoch:{_train_steps_per_epoch}.")
+            if verbose:
+                print(f"Training data of rank[{hvd.local_rank()}]: Epochs: {epochs}\n"
+                      f"Train rows: {train_rows}, Train batch size: {batch_size}, Train_steps_per_epoch: {_train_steps_per_epoch}\n"
+                      f"Val rows: {val_rows}, Val batch size: {val_batch_size}, Val_steps_per_epoch: {_val_steps_per_epoch}\n"
+                      f"Checkpoint file: {remote_store.checkpoint_path}, Logs dir: {remote_store.logs_path}\n")
 
             cuda_available = torch.cuda.is_available()
             # We need to check all ranks have same device type for traning.
@@ -158,8 +162,6 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
                       'max_epochs': epochs,
                       'logger': train_logger,
                       'log_every_n_steps': log_every_n_steps,
-                      'resume_from_checkpoint': (last_ckpt_file if ckpt_bytes else None),
-                      'checkpoint_callback': is_model_checkpoint_callback_exist,
                       'num_sanity_val_steps': 0,
                       'reload_dataloaders_every_epoch': False,
                       'progress_bar_refresh_rate': _train_steps_per_epoch // 10,
@@ -172,6 +174,9 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
             if trainer.profiler:
                 print(f"Set profiler's logs_path to {logs_path}")
                 trainer.profiler.dirpath = logs_path
+                # Filename where the profiler results will be saved instead of
+                # printing to stdout. The .txt extension will be used automatically.
+                trainer.profiler.filename = "profile"
 
             print(f"pytorch_lightning version={pl.__version__}")
 
@@ -191,19 +196,21 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
                                       verbose=verbose)
 
             trainer.fit(model, dataset)
 
-            serialized_checkpoint = io.BytesIO()
-            module = model if not is_legacy else model._model
+            if hvd.rank() == 0:
+                if remote_store.saving_runs and trainer.profiler:
+                    # One more file sync to push the profiler result.
+                    remote_store.sync(logs_path)
 
-            # TODO: find a way to pass trainer.logged_metrics out.
-            output = {'model': module.state_dict()}
+                # Rank 0 overwrites the model with the best checkpoint and returns it.
+                best_model = model.load_from_checkpoint(_checkpoint_callback.best_model_path)
+                serialized_checkpoint = io.BytesIO()
+                module = best_model if not is_legacy else best_model._model
 
-            torch.save(output, serialized_checkpoint)
-
-            if remote_store.saving_runs and hvd.rank() == 0:
-                remote_store.sync(logs_path)
+                # TODO: find a way to pass trainer.logged_metrics out.
+                output = {'model': module.state_dict()}
 
-            serialized_checkpoint.seek(0)
-            return serialized_checkpoint
+                torch.save(output, serialized_checkpoint)
+                return serialized_checkpoint
 
     return train
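
Below is a minimal usage sketch for the new checkpoint_callback estimator parameter and the example's --enable-profiler flag. It is illustrative only: backend, store, model, train_df, and args are placeholders assumed to be created elsewhere (they are not part of this patch), and it assumes the LightningModule logs a 'val_loss' metric for the callback to monitor.

    from pytorch_lightning.callbacks import ModelCheckpoint
    from horovod.spark.lightning import TorchEstimator

    # Keep only the best checkpoint by validation loss; remote.py overrides the
    # callback's dirpath/filename so the file lands in the run output directory.
    ckpt_callback = ModelCheckpoint(monitor='val_loss', mode='min',
                                    save_top_k=1, verbose=True)

    torch_estimator = TorchEstimator(
        backend=backend,        # placeholder: Horovod Spark backend
        store=store,            # placeholder: Horovod store (HDFS, S3, local FS)
        model=model,            # placeholder: LightningModule
        feature_cols=['features'],
        label_cols=['label'],
        batch_size=64,
        epochs=10,
        validation=0.1,
        verbose=1,
        checkpoint_callback=ckpt_callback,
        # Mirrors the example script: enable the Lightning profiler only on demand.
        profiler='simple' if args.enable_profiler else None)

    torch_model = torch_estimator.fit(train_df)

Because rank 0 reloads _checkpoint_callback.best_model_path after trainer.fit, the monitor/save_top_k settings of a user-supplied callback determine which weights end up in the returned Spark Transformer; when no callback is supplied, the default ModelCheckpoint saves only the last epoch.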