Skip to content

Commit

Permalink
stub in where symlinking/copying might happen
Browse files Browse the repository at this point in the history
  • Loading branch information
suchenzang committed Mar 13, 2023
1 parent da2b120 commit 9c92e4e
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 12 deletions.
10 changes: 1 addition & 9 deletions metaseq/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,17 +95,9 @@ def save_checkpoint(
extra_state,
training_finished=training_finished,
async_callback_fn=async_callback_fn if save_to_NFS else None,
files_to_symlink_to=checkpoints[1:] if len(checkpoints) > 1 else None,
)

# if len(checkpoints) > 1:
# # Create symlink between identical checkpoints (differing in naming for epoch/update/last).
# for other_checkpoint in checkpoints[1:]:
# if PathManager.islink(other_checkpoint):
# PathManager.rm(other_checkpoint)
# assert PathManager.symlink(
# checkpoints[0], other_checkpoint
# ), f"Failed to symlink {checkpoints[0]} to {other_checkpoint}"

write_timer.stop()
logger.info(
f"Saved checkpoint {checkpoints[0]} (epoch {epoch} @ {updates} updates) "
Expand Down
17 changes: 16 additions & 1 deletion metaseq/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def _checkpoint_add_directory(basename):
return m[1], f"checkpoint{m[3]}"


def post_checkpoint_callback(cfg, num_updates, training_finished, filename):
def post_checkpoint_callback(cfg, num_updates, training_finished, filename, files_to_symlink_to):
if cfg.checkpoint.cloud_upload_path is not None:
if "blob.core.windows.net" in cfg.checkpoint.cloud_upload_path:
azcopy_logs = filename + "_azcopy_logs"
Expand Down Expand Up @@ -540,6 +540,9 @@ def post_checkpoint_callback(cfg, num_updates, training_finished, filename):
f"Successfully copied {filename} to {cfg.checkpoint.cloud_upload_path}"
)
os.remove(filename)

# TODO[Susan]: Add symlink logic here? Check what cloud_upload_path is being used for Uriel's jobs.

elif cfg.checkpoint.cloud_upload_path.startswith("nfs:"):
path, basename = os.path.split(filename)
checkpoint_dir, checkpoint_file = _checkpoint_add_directory(basename)
Expand Down Expand Up @@ -579,6 +582,8 @@ def post_checkpoint_callback(cfg, num_updates, training_finished, filename):
)
os.remove(filename)

# TODO[Susan]: Add symlink logic here.

# Start running evals on uploaded checkpoint
nfs_evaluation(
cfg,
Expand All @@ -602,6 +607,16 @@ def post_checkpoint_callback(cfg, num_updates, training_finished, filename):
except (FileNotFoundError, AssertionError) as e:
logger.info(f"could not upload {filename}: {e}")

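# TODO[Susan]: Add symlink logic here.
# NOTE(review): files_to_symlink_to is already checkpoints[1:] (the caller excludes the
# primary checkpoint), so a list with a single entry still needs a symlink; test for a
# non-empty list rather than len(...) > 1. The upload branches above call
# os.remove(filename), so confirm the symlink target still exists before enabling this.

# if files_to_symlink_to: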
# for other_checkpoint in files_to_symlink_to:
# if PathManager.islink(other_checkpoint):
# PathManager.rm(other_checkpoint)
# assert PathManager.symlink(
# filename, other_checkpoint
# ), f"Failed to symlink {filename} to {other_checkpoint}"


def nfs_evaluation(
cfg, num_updates, training_finished, checkpoint_dir, destination_checkpoints_dir
Expand Down
4 changes: 2 additions & 2 deletions metaseq/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def state_dict(self, filename, training_finished=False) -> Dict[str, Dict]:
return state_dicts

def save_checkpoint(
self, filename, extra_state, training_finished=False, async_callback_fn=None
self, filename, extra_state, training_finished=False, async_callback_fn=None, files_to_symlink_to=None
):
"""Save all training state in a checkpoint file."""

Expand All @@ -445,7 +445,7 @@ def save_checkpoint(
def perform_save():
try:
logger.info(f"Beginning asynchronous torch.save to {filename}")
async_callback_fn(filename)
async_callback_fn(filename, files_to_symlink_to)
logger.info(f"Asynchronous torch.save to {filename} complete.")
except Exception as e:
logger.exception(f"Asynchronous save failed: {e}")
Expand Down

0 comments on commit 9c92e4e

Please sign in to comment.