
[air] pyarrow.fs persistence: Don't automatically delete the local checkpoint (ray-project#38507)

Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: e428265 <arvind.chandramouli@lmco.com>
justinvyu authored and arvind-chandra committed Aug 31, 2023
1 parent 7d517ec commit 1f07d60
Showing 3 changed files with 31 additions and 27 deletions.
17 changes: 6 additions & 11 deletions python/ray/train/_internal/storage.py
@@ -530,9 +530,9 @@ def persist_current_checkpoint(self, checkpoint: "Checkpoint") -> "Checkpoint":
         "Current" is defined by the `current_checkpoint_index` attribute of the
         storage context.
-        This method copies the checkpoint files to the storage location,
-        drops a marker at the storage path to indicate that the checkpoint
-        is completely uploaded, then deletes the original checkpoint directory.
+        This method copies the checkpoint files to the storage location.
+        It's up to the user to delete the original checkpoint files if desired.
+        For example, the original directory is typically a local temp directory.
         Args:
@@ -561,17 +561,12 @@ def persist_current_checkpoint(self, checkpoint: "Checkpoint") -> "Checkpoint":
             destination_filesystem=self.storage_filesystem,
         )

-        # Delete local checkpoint files.
-        # TODO(justinvyu): What if checkpoint.path == self.checkpoint_fs_path?
-        # TODO(justinvyu): What if users don't want to delete the local checkpoint?
-        checkpoint.filesystem.delete_dir(checkpoint.path)
-
-        uploaded_checkpoint = Checkpoint(
+        persisted_checkpoint = Checkpoint(
             filesystem=self.storage_filesystem,
             path=self.checkpoint_fs_path,
         )
-        logger.debug(f"Checkpoint successfully created at: {uploaded_checkpoint}")
-        return uploaded_checkpoint
+        logger.debug(f"Checkpoint successfully created at: {persisted_checkpoint}")
+        return persisted_checkpoint

     @property
     def experiment_fs_path(self) -> str:
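
Since `persist_current_checkpoint` no longer deletes the source directory, callers that relied on the old auto-cleanup must now remove the local files themselves. A minimal caller-side sketch of the new contract (the `storage` parameter stands in for a StorageContext instance, and the file name is an illustrative assumption, not code from this commit):

    import shutil
    import tempfile

    from ray.train import Checkpoint

    def persist_and_cleanup(storage) -> Checkpoint:
        # `storage` is assumed to be a StorageContext (see storage.py above).
        temp_dir = tempfile.mkdtemp()
        with open(f"{temp_dir}/model.pkl", "wb") as f:
            f.write(b"fake model bytes")

        # Copies the checkpoint files to the storage location; as of this
        # commit, the local copy is left intact.
        persisted = storage.persist_current_checkpoint(
            Checkpoint.from_directory(temp_dir)
        )

        # Deleting the local files is now the caller's job.
        shutil.rmtree(temp_dir, ignore_errors=True)
        return persisted
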
35 changes: 19 additions & 16 deletions python/ray/train/tests/test_new_persistence.py
@@ -165,26 +165,29 @@ def train_fn(config):
     for i in range(start, config.get("num_iterations", 5)):
         time.sleep(0.25)

-        temp_dir = tempfile.mkdtemp()
-        with open(os.path.join(temp_dir, "checkpoint.pkl"), "wb") as f:
-            pickle.dump({"iter": i}, f)
+        with tempfile.TemporaryDirectory() as temp_dir:
+            with open(os.path.join(temp_dir, "checkpoint.pkl"), "wb") as f:
+                pickle.dump({"iter": i}, f)

-        artifact_file_name = f"artifact-iter={i}.txt"
-        if in_trainer:
-            rank = train.get_context().get_world_rank()
-            artifact_file_name = f"artifact-rank={rank}-iter={i}.txt"
+            artifact_file_name = f"artifact-iter={i}.txt"
+            if in_trainer:
+                rank = train.get_context().get_world_rank()
+                artifact_file_name = f"artifact-rank={rank}-iter={i}.txt"

-            checkpoint_file_name = f"checkpoint_shard-rank={rank}.pkl"
-            with open(os.path.join(temp_dir, checkpoint_file_name), "wb") as f:
-                pickle.dump({"iter": i}, f)
+                checkpoint_file_name = f"checkpoint_shard-rank={rank}.pkl"
+                with open(os.path.join(temp_dir, checkpoint_file_name), "wb") as f:
+                    pickle.dump({"iter": i}, f)

-        with open(artifact_file_name, "w") as f:
-            f.write(f"{i}")
+            with open(artifact_file_name, "w") as f:
+                f.write(f"{i}")

-        train.report(
-            {"iter": i, _SCORE_KEY: i},
-            checkpoint=NewCheckpoint.from_directory(temp_dir),
-        )
+            train.report(
+                {"iter": i, _SCORE_KEY: i},
+                checkpoint=NewCheckpoint.from_directory(temp_dir),
+            )
+            # `train.report` should not have deleted this!
+            assert os.path.exists(temp_dir)

         if i in config.get("fail_iters", []):
             raise RuntimeError(f"Failing on iter={i}!!")
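
The test's switch from `tempfile.mkdtemp()` to `tempfile.TemporaryDirectory()` is what keeps the new no-delete behavior from leaking directories: the context manager removes the files when the block exits, even if the iteration raises. A standalone stdlib illustration of the difference (not part of the diff):

    import os
    import shutil
    import tempfile

    # mkdtemp() leaves the directory behind; the caller owns cleanup.
    leaked = tempfile.mkdtemp()
    assert os.path.isdir(leaked)
    shutil.rmtree(leaked)  # must be removed explicitly

    # TemporaryDirectory() deletes the directory when the block exits.
    with tempfile.TemporaryDirectory() as managed:
        assert os.path.isdir(managed)
    assert not os.path.isdir(managed)
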
6 changes: 6 additions & 0 deletions python/ray/tune/trainable/trainable.py
@@ -468,6 +468,12 @@ def get_state(self):
     def _create_checkpoint_dir(
         self, checkpoint_dir: Optional[str] = None
     ) -> Optional[str]:
+        if _use_storage_context():
+            # NOTE: There's no need to supply the checkpoint directory inside
+            # the local trial dir, since it'll get persisted to the right location.
+            checkpoint_dir = tempfile.mkdtemp()
+            return checkpoint_dir
+
         # Create checkpoint_xxxxx directory and drop checkpoint marker
         checkpoint_dir = TrainableUtil.make_checkpoint_dir(
             checkpoint_dir or self.logdir, index=self.iteration, override=True
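
Under the storage context, the checkpoint directory handed to the Trainable is just scratch space: the StorageContext later copies its contents to `checkpoint_fs_path`, so the local directory's name and placement no longer matter. A rough, hypothetical condensation of the two branches (the legacy `checkpoint_<index>` naming is paraphrased, not copied from `TrainableUtil.make_checkpoint_dir`):

    import os
    import tempfile

    def create_checkpoint_dir(use_storage_context: bool, logdir: str, iteration: int) -> str:
        if use_storage_context:
            # New path: any temp dir works, since the files are persisted
            # to their final location afterwards (see storage.py above).
            return tempfile.mkdtemp()
        # Legacy path: a checkpoint_<iteration> directory inside the trial dir.
        checkpoint_dir = os.path.join(logdir, f"checkpoint_{iteration:06d}")
        os.makedirs(checkpoint_dir, exist_ok=True)
        return checkpoint_dir
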
