Skip to content

Commit

Permalink
fix: save merged metadata for sharded ckpts locally (#5657)
Browse files Browse the repository at this point in the history
  • Loading branch information
aciborowska authored Jan 3, 2023
1 parent d611afc commit 27b3bcd
Showing 1 changed file with 12 additions and 27 deletions.
39 changes: 12 additions & 27 deletions harness/determined/core/_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,12 +296,9 @@ def _upload_sharded(
if conflicts:
self._try_resolving_conflicts(ckpt_dir, conflicts)

# The lowest rank that is uploading ckpt_dir collects
# and saves metadata.
metadata_writer_rank = ckpt_dir_mask.index(True)
all_metadata = self._merge_and_save_metadata(
ckpt_dir, writer_rank=metadata_writer_rank, metadata=metadata or {}
)
# Merge and save merged metadata locally for each rank to avoid conflicts
# after pausing and unpausing experiment.
all_metadata = self._merge_and_save_metadata(ckpt_dir, metadata=metadata or {})

if want_upload:
assert ckpt_dir
Expand Down Expand Up @@ -333,8 +330,9 @@ def _try_resolving_conflicts(
if len(set(md5_ranks)) == 1:
# All files have the same md5 checksum, which means there is no conflict.
all_conflicts.pop(fname)
else:
self._print_conflict_error(all_conflicts, "files")

if len(all_conflicts) > 0:
self._print_conflict_error(all_conflicts, "files")

def _print_conflict_error(self, conflicts: Dict[str, List], conflict_dtype: str) -> None:
# Try to keep the logs easier to read; print the whole failure only on the chief.
Expand Down Expand Up @@ -482,11 +480,7 @@ def _store_path_sharded(
if self._dist.rank == 0:
resources = self._storage_manager._list_directory(ckpt_dir)

# Metadata should not be counted among resources.
# Chief handles merging and saving metadata to ckpt_dir.
all_metadata = self._merge_and_save_metadata(
ckpt_dir, writer_rank=0, metadata=metadata or {}
)
all_metadata = self._merge_and_save_metadata(ckpt_dir, metadata=metadata or {})

if self._dist.rank == 0:
self._report_checkpoint(storage_id, resources, all_metadata)
Expand Down Expand Up @@ -515,16 +509,9 @@ def _store_path_sharded(
if conflicts:
self._try_resolving_conflicts(ckpt_dir, conflicts)

# Chief gathers metadata across workers, checks for conflicts, saves metadata.
all_metadata = self._merge_and_save_metadata(
ckpt_dir, writer_rank=0, metadata=metadata or {}
)
# My assumption is that chief (local_rank=0) typically uploads stuff anyway.
# If this assumption does not hold, then we can also do this:
# upload_mask = self._dist.allgather(want_upload)
# metadata_writer_rank = upload_mask.index(True)
# all_metadata = self._merge_and_save_metadata(
# ckpt_dir, writer_rank=metadata_writer_rank, metadata=metadata or {})
# Merge and save merged metadata locally for each rank to avoid conflicts
# after pausing and unpausing experiment.
all_metadata = self._merge_and_save_metadata(ckpt_dir, metadata=metadata or {})

if want_upload:
# Use post_store_path to upload and clean up ckpt_dir after uploading.
Expand All @@ -541,17 +528,15 @@ def _store_path_sharded(
def _merge_and_save_metadata(
self,
ckpt_dir: Optional[str],
writer_rank: int,
metadata: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
# Gather metadata across nodes.
all_metadata = self._dist.allgather(metadata or {})
# Merge metadata. If a metadata key repeats, raise error.
# Merge metadata and report errors when the same keys have different values.
merged_metadata, conflicts = merge_metadata(all_metadata)
if conflicts:
self._print_conflict_error(conflicts, "metadata")
if self._dist.rank == writer_rank:
assert ckpt_dir
if ckpt_dir is not None and self._dist.local_rank == 0:
self._write_metadata_file(ckpt_dir, merged_metadata)
return merged_metadata

Expand Down

0 comments on commit 27b3bcd

Please sign in to comment.