From 9c92e4ecb3308d485e34406323bf76daa5170cf1 Mon Sep 17 00:00:00 2001 From: suchenzang Date: Sun, 12 Mar 2023 21:30:02 -0400 Subject: [PATCH] stub in where symlinking/copying might happen --- metaseq/checkpoint_utils.py | 10 +--------- metaseq/cli/train.py | 17 ++++++++++++++++- metaseq/trainer.py | 4 ++-- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index 99aa3fb16..48b78357f 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -95,17 +95,9 @@ def save_checkpoint( extra_state, training_finished=training_finished, async_callback_fn=async_callback_fn if save_to_NFS else None, + files_to_symlink_to=checkpoints[1:] if len(checkpoints) > 1 else None, ) - # if len(checkpoints) > 1: - # # Create symlink between identical checkpoints (differing in naming for epoch/update/last). - # for other_checkpoint in checkpoints[1:]: - # if PathManager.islink(other_checkpoint): - # PathManager.rm(other_checkpoint) - # assert PathManager.symlink( - # checkpoints[0], other_checkpoint - # ), f"Failed to symlink {checkpoints[0]} to {other_checkpoint}" - write_timer.stop() logger.info( f"Saved checkpoint {checkpoints[0]} (epoch {epoch} @ {updates} updates) " diff --git a/metaseq/cli/train.py b/metaseq/cli/train.py index c131a270c..b7b8a3ba9 100644 --- a/metaseq/cli/train.py +++ b/metaseq/cli/train.py @@ -511,7 +511,7 @@ def _checkpoint_add_directory(basename): return m[1], f"checkpoint{m[3]}" -def post_checkpoint_callback(cfg, num_updates, training_finished, filename): +def post_checkpoint_callback(cfg, num_updates, training_finished, filename, files_to_symlink_to): if cfg.checkpoint.cloud_upload_path is not None: if "blob.core.windows.net" in cfg.checkpoint.cloud_upload_path: azcopy_logs = filename + "_azcopy_logs" @@ -540,6 +540,9 @@ def post_checkpoint_callback(cfg, num_updates, training_finished, filename): f"Successfully copied {filename} to {cfg.checkpoint.cloud_upload_path}" ) os.remove(filename) + + # TODO[Susan]: Add symlink logic here? Check what cloud_upload_path is being used for Uriel's jobs. + elif cfg.checkpoint.cloud_upload_path.startswith("nfs:"): path, basename = os.path.split(filename) checkpoint_dir, checkpoint_file = _checkpoint_add_directory(basename) @@ -579,6 +582,8 @@ def post_checkpoint_callback(cfg, num_updates, training_finished, filename): ) os.remove(filename) + # TODO[Susan]: Add symlink logic here. + # Start running evals on uploaded checkpoint nfs_evaluation( cfg, @@ -602,6 +607,16 @@ def post_checkpoint_callback(cfg, num_updates, training_finished, filename): except (FileNotFoundError, AssertionError) as e: logger.info(f"could not upload {filename}: {e}") + # TODO[Susan]: Add symlink logic here. + + # if files_to_symlink_to is not None and len(files_to_symlink_to) > 1: + # for other_checkpoint in files_to_symlink_to: + # if PathManager.islink(other_checkpoint): + # PathManager.rm(other_checkpoint) + # assert PathManager.symlink( + # filename, other_checkpoint + # ), f"Failed to symlink {filename} to {other_checkpoint}" + def nfs_evaluation( cfg, num_updates, training_finished, checkpoint_dir, destination_checkpoints_dir diff --git a/metaseq/trainer.py b/metaseq/trainer.py index faa888d11..d28b3b5cb 100644 --- a/metaseq/trainer.py +++ b/metaseq/trainer.py @@ -421,7 +421,7 @@ def state_dict(self, filename, training_finished=False) -> Dict[str, Dict]: return state_dicts def save_checkpoint( - self, filename, extra_state, training_finished=False, async_callback_fn=None + self, filename, extra_state, training_finished=False, async_callback_fn=None, files_to_symlink_to=None ): """Save all training state in a checkpoint file.""" @@ -445,7 +445,7 @@ def save_checkpoint( def perform_save(): try: logger.info(f"Beginning asynchronous torch.save to {filename}") - async_callback_fn(filename) + async_callback_fn(filename, files_to_symlink_to) logger.info(f"Asynchronous torch.save to {filename} complete.") except Exception as e: logger.exception(f"Asynchronous save failed: {e}")