From e0e982f427ba82888f3083d5982cef33a527bdf1 Mon Sep 17 00:00:00 2001 From: Peng Zhang Date: Mon, 24 Jan 2022 10:18:34 -0800 Subject: [PATCH] fix_resume_logger (#3375) Signed-off-by: Peng Zhang --- horovod/spark/lightning/remote.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/horovod/spark/lightning/remote.py b/horovod/spark/lightning/remote.py index 50bd303cbb..2456b65b90 100644 --- a/horovod/spark/lightning/remote.py +++ b/horovod/spark/lightning/remote.py @@ -123,15 +123,26 @@ def train(serialized_model): train_logger = TensorBoardLogger(logs_path) print(f"Setup logger: Using TensorBoardLogger: {train_logger}") - elif isinstance(logger, CometLogger) and logger._experiment_key is None: - # Resume logger experiment key if passed correctly from CPU. - train_logger = CometLogger( - save_dir=logs_path, - api_key=logger.api_key, - experiment_key=logger_experiment_key, - ) - - print(f"Setup logger: Resume comet logger: {vars(train_logger)}") + elif isinstance(logger, CometLogger): + if logger._experiment_key: + # use logger passed in. + train_logger = logger + train_logger._save_dir = logs_path + print(f"Setup logger: change save_dir of the logger to {logs_path}") + + elif logger_experiment_key: + # Resume logger experiment with new log path if key passed correctly from CPU. + train_logger = CometLogger( + save_dir=logs_path, + api_key=logger.api_key, + experiment_key=logger_experiment_key, + ) + + print(f"Setup logger: Resume comet logger: {vars(train_logger)}") + + else: + print(f"Failed to setup or resume comet logger. origin logger: {vars(logger)}") + else: # use logger passed in. train_logger = logger