Spark/Lightning: fix the usage of checkpoint callback #3186

Merged
merged 1 commit on Sep 30, 2021
4 changes: 3 additions & 1 deletion examples/spark/pytorch/pytorch_lightning_spark_mnist.py
@@ -43,6 +43,8 @@
help='temporary working directory to write intermediate files (prefix with hdfs:// to use HDFS)')
parser.add_argument('--data-dir', default='/tmp',
help='location of the training dataset in the local filesystem (will be downloaded if needed)')
parser.add_argument('--enable-profiler', action='store_true',
help='Enable profiler')


def train_model(args):
@@ -195,7 +197,7 @@ def on_train_end(self, trainer, model):
validation=0.1,
verbose=1,
callbacks=callbacks,
profiler="simple")
profiler="simple" if args.enable_profiler else None)

torch_model = torch_estimator.fit(train_df).setOutputCols(['label_prob'])

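For reference, a minimal sketch of how the new `--enable-profiler` flag is meant to be consumed when building the estimator. The `model`, `store`, and `epochs` arguments below are illustrative placeholders, not part of this diff:

    import argparse

    from horovod.spark.lightning import TorchEstimator

    parser = argparse.ArgumentParser()
    parser.add_argument('--enable-profiler', action='store_true',
                        help='Enable profiler')
    args = parser.parse_args()

    # Attach Lightning's "simple" profiler only when explicitly requested;
    # passing None keeps profiling disabled for the default run.
    torch_estimator = TorchEstimator(
        model=model,    # placeholder: a LightningModule defined earlier in the script
        store=store,    # placeholder: a Horovod Store for intermediate data
        epochs=2,       # placeholder value
        profiler="simple" if args.enable_profiler else None)
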
16 changes: 14 additions & 2 deletions horovod/spark/lightning/estimator.py
@@ -206,6 +206,9 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,

profiler = Param(Params._dummy(), 'profiler', 'lightning profiler to use')

checkpoint_callback = Param(Params._dummy(), 'checkpoint_callback',
'model checkpointing callback')

@keyword_only
def __init__(self,
num_proc=None,
@@ -246,7 +249,8 @@ def __init__(self,
data_module=None,
loader_num_epochs=None,
terminate_on_nan=False,
profiler=None):
profiler=None,
checkpoint_callback=None):

super(TorchEstimator, self).__init__()
self._setDefault(loss_constructors=None,
@@ -260,7 +264,8 @@
data_module=None,
loader_num_epochs=None,
terminate_on_nan=False,
profiler=None)
profiler=None,
checkpoint_callback=None)

kwargs = self._input_kwargs

@@ -333,6 +338,12 @@ def setTerminateOnNan(self, value):
def getTerminateOnNan(self):
return self.getOrDefault(self.terminate_on_nan)

def setCheckpointCallback(self, value):
return self._set(checkpoint_callback=value)

def getCheckpointCallback(self):
return self.getOrDefault(self.checkpoint_callback)

def getProfiler(self):
return self.getOrDefault(self.profiler)

@@ -401,6 +412,7 @@ def _fit_on_prepared_data(self, backend, train_rows, val_rows, metadata, avg_row
validation=self.getValidation())

serialized_model = serialize_fn()(model)
# FIXME: checkpoint bytes should be loaded into serialized_model, same as Keras Estimator.
ckpt_bytes = self._read_checkpoint(run_id) if self._has_checkpoint(run_id) else None
trainer = remote.RemoteTrainer(self,
metadata=metadata,
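A hedged usage sketch of the new `checkpoint_callback` parameter, assuming a caller wants the best model by validation loss rather than only the last epoch. Arguments other than `checkpoint_callback` are placeholders, and note that the remote trainer (see remote.py below) redirects the callback's directory and filename into the run output dir:

    from pytorch_lightning.callbacks import ModelCheckpoint

    from horovod.spark.lightning import TorchEstimator

    # Keep the single best checkpoint as measured by validation loss.
    checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min',
                                          save_top_k=1, verbose=True)

    torch_estimator = TorchEstimator(
        model=model,    # placeholder: a LightningModule
        store=store,    # placeholder: a Horovod Store
        epochs=5,       # placeholder value
        checkpoint_callback=checkpoint_callback)

    # The Param-style accessors added above can also be used after construction:
    torch_estimator.setCheckpointCallback(checkpoint_callback)
    assert torch_estimator.getCheckpointCallback() is checkpoint_callback
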
71 changes: 39 additions & 32 deletions horovod/spark/lightning/remote.py
@@ -53,6 +53,7 @@ def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx, train_ro
transformation = transformation_fn if transformation_fn else None
inmemory_cache_all = estimator.getInMemoryCacheAll()
callbacks = estimator.getCallbacks() or []
checkpoint_callback = estimator.getCheckpointCallback()
train_steps_per_epoch = estimator.getTrainStepsPerEpoch()
val_steps_per_epoch = estimator.getValidationStepsPerEpoch()
num_gpus = estimator.getNumGPUs()
@@ -88,16 +89,12 @@ def train(serialized_model):
# Horovod: initialize library.
hvd.init()

with tempfile.TemporaryDirectory() as last_ckpt_dir, remote_store.get_local_output_dir() as run_output_dir:
last_ckpt_file = os.path.join(last_ckpt_dir, 'last.ckpt')
if ckpt_bytes:
with open(last_ckpt_file, 'wb') as f:
f.write(ckpt_bytes)

# TODO: Pass the logger from estimator constructor
with remote_store.get_local_output_dir() as run_output_dir:
logs_path = os.path.join(run_output_dir, remote_store.logs_subdir)
os.makedirs(logs_path, exist_ok=True)
print(f"Made directory {logs_path} for horovod rank {hvd.rank()}")
ckpt_dir = run_output_dir
ckpt_filename = remote_store.checkpoint_filename

# Use default logger if no logger is supplied
train_logger = logger
@@ -106,22 +103,25 @@
if train_logger is None:
train_logger = TensorBoardLogger(logs_path)

# TODO: find out a way to use ckpt_path created from remote store, but all other parameters ingest from estimator config
# ckpt_path = os.path.join(run_output_dir, remote_store.checkpoint_filename)
# os.makedirs(ckpt_path, exist_ok=True)
# model_checkpoint_callback = ModelCheckpoint(dirpath=ckpt_path)
# callbacks.append(model_checkpoint_callback)

is_model_checkpoint_callback_exist = False
for cb in callbacks:
if isinstance(cb, ModelCheckpoint):
is_model_checkpoint_callback_exist = True
break
# Lightning requires a checkpoint callback to be registered on all ranks;
# otherwise training can hang.
_checkpoint_callback = checkpoint_callback
if _checkpoint_callback:
_checkpoint_callback.dir_path = ckpt_dir
_checkpoint_callback.filename = ckpt_filename
else:
# By default 'monitor' is None, which saves a checkpoint only for the last epoch.
_checkpoint_callback = ModelCheckpoint(dirpath=ckpt_dir,
filename=ckpt_filename,
verbose=True)
callbacks.append(_checkpoint_callback)

if remote_store.saving_runs and hvd.rank() == 0:
# Horovod: sync checkpoint and logging files only on rank 0 to
# prevent other ranks from corrupting them.
class _SyncCallback(Callback):
def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
remote_store.sync(logs_path)
remote_store.sync(run_output_dir)

callbacks.append(_SyncCallback())

@@ -133,7 +133,11 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
_val_steps_per_epoch = val_steps_per_epoch if val_steps_per_epoch else \
int(math.floor(float(val_rows) / val_batch_size / hvd.size()))

print(f"Training data of rank[{hvd.local_rank()}]: train_rows:{train_rows}, batch_size:{batch_size}, _train_steps_per_epoch:{_train_steps_per_epoch}.")
if verbose:
print(f"Training data of rank[{hvd.local_rank()}]: Epochs: {epochs}\n"
f"Train rows: {train_rows}, Train batch size: {batch_size}, Train_steps_per_epoch: {_train_steps_per_epoch}\n"
f"Val rows: {val_rows}, Val batch size: {val_batch_size}, Val_steps_per_epoch: {_val_steps_per_epoch}\n"
f"Checkpoint file: {remote_store.checkpoint_path}, Logs dir: {remote_store.logs_path}\n")

cuda_available = torch.cuda.is_available()
# We need to check that all ranks have the same device type for training.
@@ -158,8 +162,6 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
'max_epochs': epochs,
'logger': train_logger,
'log_every_n_steps': log_every_n_steps,
'resume_from_checkpoint': (last_ckpt_file if ckpt_bytes else None),
'checkpoint_callback': is_model_checkpoint_callback_exist,
'num_sanity_val_steps': 0,
'reload_dataloaders_every_epoch': False,
'progress_bar_refresh_rate': _train_steps_per_epoch // 10,
@@ -172,6 +174,9 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
if trainer.profiler:
print(f"Set profiler's logs_path to {logs_path}")
trainer.profiler.dirpath = logs_path
# Save profiler results to this filename instead of printing to stdout.
# The .txt extension will be added automatically.
trainer.profiler.filename = "profile"

print(f"pytorch_lightning version={pl.__version__}")

@@ -191,19 +196,21 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
verbose=verbose)
trainer.fit(model, dataset)

serialized_checkpoint = io.BytesIO()
module = model if not is_legacy else model._model
if hvd.rank() == 0:
if remote_store.saving_runs and trainer.profiler:
# One more file sync to push profiler result.
remote_store.sync(logs_path)

# TODO: find a way to pass trainer.logged_metrics out.
output = {'model': module.state_dict()}
# rank 0 overwrites model with best checkpoint and returns.
best_model = model.load_from_checkpoint(_checkpoint_callback.best_model_path)
serialized_checkpoint = io.BytesIO()
module = best_model if not is_legacy else best_model._model

torch.save(output, serialized_checkpoint)

if remote_store.saving_runs and hvd.rank() == 0:
remote_store.sync(logs_path)
# TODO: find a way to pass trainer.logged_metrics out.
output = {'model': module.state_dict()}

serialized_checkpoint.seek(0)
return serialized_checkpoint
torch.save(output, serialized_checkpoint)
return serialized_checkpoint
return train


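To summarize the behavior introduced above, here is a condensed, hedged sketch of the checkpoint-callback resolution the remote train function now performs. The helper name `resolve_checkpoint_callback` and its standalone form are illustrative; in the diff this logic runs inline inside train(), and attribute names such as `dir_path` mirror the diff as written:

    from pytorch_lightning.callbacks import ModelCheckpoint

    def resolve_checkpoint_callback(checkpoint_callback, callbacks, ckpt_dir, ckpt_filename):
        # Every rank must register a checkpoint callback, otherwise distributed
        # training can hang waiting on ranks that never reach the checkpoint hook.
        if checkpoint_callback:
            # Redirect a user-supplied callback so its files land in the run output dir.
            checkpoint_callback.dir_path = ckpt_dir
            checkpoint_callback.filename = ckpt_filename
        else:
            # Default callback: monitor is None, so only the last epoch's checkpoint is kept.
            checkpoint_callback = ModelCheckpoint(dirpath=ckpt_dir,
                                                  filename=ckpt_filename,
                                                  verbose=True)
        callbacks.append(checkpoint_callback)
        return checkpoint_callback

After training, rank 0 reloads the weights from `_checkpoint_callback.best_model_path`, serializes that state dict into a BytesIO buffer, and returns it, so the model the estimator hands back corresponds to the best (or last) saved checkpoint rather than whatever happened to be in memory when fit() finished.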