[air] pyarrow.fs persistence (7/n): ray.train.Checkpoint restore: Auto-recovery fault tolerance (ray-project#38141)

This PR handles auto-restoration fault tolerance for the new `Checkpoint` API:
- The latest `_TrainingResult(checkpoint, metrics)` data saved in the trial state on the driver gets sent to the workers for restoration.
- No checkpoint data gets downloaded during restoration.
- The user can access the restored checkpoint's contents with `to_directory` and `as_directory` (sketched below).
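
For reference, here's roughly what that access looks like from the worker's training function. This is a minimal sketch rather than code from this PR: the `checkpoint.pkl` / `{"iter": i}` layout is borrowed from the `train_fn` used in `test_new_persistence.py` below, and `ray.train.get_checkpoint()` is assumed as the session accessor.

```python
import os
import pickle

from ray import train


def train_fn(config):
    # On auto-recovery, the driver sends the latest _TrainingResult checkpoint
    # back to the workers; nothing is downloaded until the user asks for it.
    start = 0
    checkpoint = train.get_checkpoint()
    if checkpoint:
        # `as_directory` downloads the checkpoint contents (for non-local
        # filesystems) and cleans up the temporary directory afterwards.
        with checkpoint.as_directory() as checkpoint_dir:
            with open(os.path.join(checkpoint_dir, "checkpoint.pkl"), "rb") as f:
                start = pickle.load(f)["iter"] + 1

    for i in range(start, config.get("num_iterations", 5)):
        ...  # one training step, then train.report(metrics, checkpoint=...)
```

Since no checkpoint data is downloaded during restoration itself, the download cost is only paid here, when the user actually opens the checkpoint.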

This PR also fixes a race condition in `as_directory`: the deletion lock must be set *before* the internal call to `to_directory`. Otherwise, worker 1 can exit the context and delete the directory while worker 2 is still waiting for the download to finish. Then, once worker 1 releases the download lock, the directory has already been deleted, so worker 2 errors out.
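
The pattern, reduced to its essentials (a standalone sketch, not the code in this PR: `shared_download`, the `.del_lock_` / `.download.lock` names, and `filelock.FileLock` are illustrative stand-ins for the checkpoint internals and `TempFileLock`):

```python
import os
import shutil
import tempfile
from contextlib import contextmanager

from filelock import FileLock  # stand-in for Ray's TempFileLock


@contextmanager
def shared_download(cache_dir: str):
    os.makedirs(cache_dir, exist_ok=True)

    # 1) Register this consumer's deletion lock *before* touching the download
    #    lock, so a consumer that finishes early cannot conclude "no one else
    #    needs this directory" while we are still waiting to download.
    fd, del_lock_path = tempfile.mkstemp(dir=cache_dir, prefix=".del_lock_")
    os.close(fd)

    try:
        # 2) Only one consumer actually downloads; the others block here.
        with FileLock(os.path.join(cache_dir, ".download.lock")):
            ...  # download files into cache_dir if they are not there yet
        yield cache_dir
    finally:
        # 3) Drop our deletion lock; the last consumer out removes the directory.
        #    (The real code additionally guards the rmtree with a non-blocking
        #    file lock.)
        os.remove(del_lock_path)
        if os.path.isdir(cache_dir) and not any(
            f.startswith(".del_lock_") for f in os.listdir(cache_dir)
        ):
            shutil.rmtree(cache_dir, ignore_errors=True)
```

With the old ordering, step 1 only happened after the download returned, which is exactly the window the race above describes.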

### Other comments

Here are some other ideas that were considered for restoring the checkpoint index:
1. Store it inside the `_TrainingResult` when saving the checkpoint. Then, pass this index along with the checkpoint all the way to the worker, and use it to initialize the starting checkpoint number.
2. Save it inside the Trial's storage context. Previously, the storage context saved on the driver never set the `checkpoint_index`, because that indexing was handled entirely on the trainable/worker side. **This is what we're doing now: the driver's `trial.storage.current_checkpoint_index` gets incremented on every reported checkpoint, to stay in sync with the worker/trainable (see the sketch below).**
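
To make option 2 concrete, here is a toy sketch of the bookkeeping. `TinyStorageContext` and the `s3://` path are made up for illustration; the directory-name helper mirrors `StorageContext._make_checkpoint_dir_name` added in `storage.py` below.

```python
import os


def _make_checkpoint_dir_name(index: int) -> str:
    # Mirrors the helper added to StorageContext in this PR.
    return f"checkpoint_{index:06d}"


class TinyStorageContext:
    """Toy stand-in for StorageContext, only to show the index bookkeeping."""

    def __init__(self, trial_fs_path: str, current_checkpoint_index: int = 0):
        self.trial_fs_path = trial_fs_path
        self.current_checkpoint_index = current_checkpoint_index

    @property
    def checkpoint_fs_path(self) -> str:
        return os.path.join(
            self.trial_fs_path,
            _make_checkpoint_dir_name(self.current_checkpoint_index),
        )


driver = TinyStorageContext("s3://bucket/exp_name/trial_0")
worker = TinyStorageContext("s3://bucket/exp_name/trial_0")

for _ in range(3):  # three reported checkpoints
    # Both sides independently compute checkpoint_000000, checkpoint_000001, ...
    assert driver.checkpoint_fs_path == worker.checkpoint_fs_path
    worker.current_checkpoint_index += 1  # worker: after persisting a checkpoint
    driver.current_checkpoint_index += 1  # driver: on every reported checkpoint
```
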
justinvyu authored and harborn committed Aug 17, 2023
1 parent d7f64dc commit baaaa06
Showing 13 changed files with 229 additions and 100 deletions.
53 changes: 28 additions & 25 deletions python/ray/train/_checkpoint.py
@@ -211,35 +211,38 @@ def as_directory(self) -> Iterator[str]:
         if isinstance(self.filesystem, pyarrow.fs.LocalFileSystem):
             yield self.path
         else:
-            temp_dir = self.to_directory()
-            del_lock_path = _get_del_lock_path(temp_dir)
+            del_lock_path = _get_del_lock_path(self._get_temporary_checkpoint_dir())
             open(del_lock_path, "a").close()
 
-            yield temp_dir
-
-            # Cleanup
             try:
-                os.remove(del_lock_path)
-            except Exception:
-                logger.warning(
-                    f"Could not remove {del_lock_path} deletion file lock. "
-                    f"Traceback:\n{traceback.format_exc()}"
-                )
-
-            # In the edge case (process crash before del lock file is removed),
-            # we do not remove the directory at all.
-            # Since it's in /tmp, this is not that big of a deal.
-            # check if any lock files are remaining
-            remaining_locks = _list_existing_del_locks(temp_dir)
-            if not remaining_locks:
+                temp_dir = self.to_directory()
+                yield temp_dir
+            finally:
+                # Always cleanup the del lock after we're done with the directory.
+                # This avoids leaving a lock file behind in the case of an exception
+                # in the user code.
                 try:
-                    # Timeout 0 means there will be only one attempt to acquire
-                    # the file lock. If it cannot be acquired, a TimeoutError
-                    # will be thrown.
-                    with TempFileLock(f"{temp_dir}.lock", timeout=0):
-                        shutil.rmtree(temp_dir, ignore_errors=True)
-                except TimeoutError:
-                    pass
+                    os.remove(del_lock_path)
+                except Exception:
+                    logger.warning(
+                        f"Could not remove {del_lock_path} deletion file lock. "
+                        f"Traceback:\n{traceback.format_exc()}"
+                    )
+
+                # In the edge case (process crash before del lock file is removed),
+                # we do not remove the directory at all.
+                # Since it's in /tmp, this is not that big of a deal.
+                # check if any lock files are remaining
+                remaining_locks = _list_existing_del_locks(temp_dir)
+                if not remaining_locks:
+                    try:
+                        # Timeout 0 means there will be only one attempt to acquire
+                        # the file lock. If it cannot be acquired, a TimeoutError
+                        # will be thrown.
+                        with TempFileLock(temp_dir, timeout=0):
+                            shutil.rmtree(temp_dir, ignore_errors=True)
+                    except TimeoutError:
+                        pass
 
     def _get_temporary_checkpoint_dir(self) -> str:
         """Return the name for the temporary checkpoint dir that this checkpoint
22 changes: 11 additions & 11 deletions python/ray/train/_internal/storage.py
@@ -400,7 +400,7 @@ def __init__(
         experiment_dir_name: str,
         storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
         trial_dir_name: Optional[str] = None,
-        current_checkpoint_index: Optional[int] = None,
+        current_checkpoint_index: int = 0,
     ):
         storage_path_provided = storage_path is not None
 
@@ -591,18 +591,13 @@ def trial_fs_path(self) -> str:
 
     @property
     def checkpoint_fs_path(self) -> str:
-        """The trial directory path on the `storage_filesystem`.
+        """The current checkpoint directory path on the `storage_filesystem`.
 
-        Raises a ValueError if `current_checkpoint_index` is not set beforehand.
+        "Current" refers to the checkpoint that is currently being created/persisted.
+        The user of this class is responsible for setting the `current_checkpoint_index`
+        (e.g., incrementing when needed).
         """
-        from ray.tune.trainable.util import TrainableUtil
-
-        if self.current_checkpoint_index is None:
-            raise RuntimeError(
-                "Should not access `checkpoint_fs_path` without setting "
-                "`current_checkpoint_index`"
-            )
-        checkpoint_dir_name = TrainableUtil._make_checkpoint_dir_name(
+        checkpoint_dir_name = StorageContext._make_checkpoint_dir_name(
             self.current_checkpoint_index
         )
         return os.path.join(self.trial_fs_path, checkpoint_dir_name)
@@ -620,6 +615,11 @@ def get_experiment_dir_name(run_obj: Union[str, Callable, Type]) -> str:
         dir_name = "{}_{}".format(run_identifier, date_str())
         return dir_name
 
+    @staticmethod
+    def _make_checkpoint_dir_name(index: int):
+        """Get the name of the checkpoint directory, given an index."""
+        return f"checkpoint_{index:06d}"
+
 
 _storage_context: Optional[StorageContext] = None
 
37 changes: 21 additions & 16 deletions python/ray/train/base_trainer.py
@@ -22,6 +22,7 @@
 from ray.air.config import RunConfig, ScalingConfig
 from ray.air.result import Result
 from ray.train._internal import session
+from ray.train._internal.storage import _use_storage_context
 from ray.train.constants import TRAIN_DATASET_KEY
 from ray.util import PublicAPI
 from ray.util.annotations import DeveloperAPI
@@ -191,7 +192,7 @@ def __init__(
         self.run_config = run_config if run_config is not None else RunConfig()
         self.datasets = datasets if datasets is not None else {}
         self.preprocessor = preprocessor
-        self.resume_from_checkpoint = resume_from_checkpoint
+        self.starting_checkpoint = resume_from_checkpoint
 
         # This path should only be set through restore
         self._restore_path = None
@@ -377,7 +378,7 @@ def __repr__(self):
             "run_config": RunConfig(),
             "datasets": {},
             "preprocessor": None,
-            "resume_from_checkpoint": None,
+            "starting_checkpoint": None,
         }
 
         non_default_arguments = []
@@ -452,13 +453,13 @@ def _validate_attributes(self):
                 f"found {type(self.preprocessor)} with value `{self.preprocessor}`."
             )
 
-        if self.resume_from_checkpoint is not None and not isinstance(
-            self.resume_from_checkpoint, ray.air.Checkpoint
+        if self.starting_checkpoint is not None and not isinstance(
+            self.starting_checkpoint, ray.air.Checkpoint
         ):
             raise ValueError(
                 f"`resume_from_checkpoint` should be an instance of "
-                f"`ray.train.Checkpoint`, found {type(self.resume_from_checkpoint)} "
-                f"with value `{self.resume_from_checkpoint}`."
+                f"`ray.train.Checkpoint`, found {type(self.starting_checkpoint)} "
+                f"with value `{self.starting_checkpoint}`."
             )
 
     @classmethod
@@ -700,18 +701,22 @@ def train_func(config):
             # Instantiate new Trainer in Trainable.
             trainer = trainer_cls(**config)
 
-            # Get the checkpoint from the train context, and use it to initialize
-            # the restored trainer.
-            # This handles both worker-level and cluster-level restoration
-            # of the Train experiment.
+            # Get the checkpoint from Tune and pass it to workers later on.
             checkpoint = session.get_checkpoint()
             if checkpoint:
-                trainer.resume_from_checkpoint = checkpoint
-                # Always load the preprocessor from an available checkpoint
-                # Unless we are restoring the experiment and have explicitly
-                # passed in a new preprocessor
-                if not (restored and trainer.preprocessor):
-                    trainer.preprocessor = checkpoint.get_preprocessor()
+                # Set `starting_checkpoint` for auto-recovery fault-tolerance
+                # as well as manual restoration.
+                trainer.starting_checkpoint = checkpoint
+
+                # TODO(justinvyu): Remove this when Preprocessor is removed from Trainer
+                if not _use_storage_context():
+                    # Always load the preprocessor from an available checkpoint
+                    # Unless we are restoring the experiment and have explicitly
+                    # passed in a new preprocessor
+                    if not (restored and trainer.preprocessor):
+                        trainer.preprocessor = checkpoint.get_preprocessor()
+                # Else: Train will restore from the user-provided
+                # `resume_from_checkpoint` == `starting_checkpoint`.
 
             trainer.setup()
             trainer.preprocess_datasets()
2 changes: 1 addition & 1 deletion python/ray/train/data_parallel_trainer.py
@@ -527,7 +527,7 @@ def clear_lazy_checkpoint_marker():
             datasets=self.datasets,
             data_config=self._data_config,
             checkpoint_manager=checkpoint_manager,
-            checkpoint=self.resume_from_checkpoint,
+            checkpoint=self.starting_checkpoint,
             checkpoint_strategy=checkpoint_strategy,
             storage_path=self.run_config.storage_path,
         )
4 changes: 2 additions & 2 deletions python/ray/train/gbdt_trainer.py
@@ -280,8 +280,8 @@ def training_loop(self) -> None:
         evals_result = {}
 
         init_model = None
-        if self.resume_from_checkpoint:
-            init_model, _ = self._load_checkpoint(self.resume_from_checkpoint)
+        if self.starting_checkpoint:
+            init_model, _ = self._load_checkpoint(self.starting_checkpoint)
 
         config.setdefault("verbose_eval", False)
         config.setdefault("callbacks", [])
2 changes: 1 addition & 1 deletion python/ray/train/tests/test_base_trainer.py
@@ -398,7 +398,7 @@ def test_large_params(ray_start_4_cpus):
     array_size = int(1e8)
 
     def training_loop(self):
-        checkpoint = self.resume_from_checkpoint.to_dict()["ckpt"]
+        checkpoint = self.starting_checkpoint.to_dict()["ckpt"]
         assert len(checkpoint) == array_size
 
     checkpoint = Checkpoint.from_dict({"ckpt": np.zeros(shape=array_size)})
15 changes: 15 additions & 0 deletions python/ray/train/tests/test_checkpoint.py
@@ -156,6 +156,21 @@ def test_multiprocess_as_directory(checkpoint: Checkpoint, monkeypatch):
     assert not Path(checkpoint_dir_1).exists()
 
 
+def test_as_directory_lock_cleanup(checkpoint: Checkpoint):
+    """Errors when accessing a checkpoint with `as_directory`
+    shouldn't leave behind lock files.
+    """
+    with pytest.raises(RuntimeError):
+        with checkpoint.as_directory() as checkpoint_dir:
+            raise RuntimeError
+
+    assert not _list_existing_del_locks(checkpoint_dir)
+
+    is_local_checkpoint = isinstance(checkpoint.filesystem, pyarrow.fs.LocalFileSystem)
+    if not is_local_checkpoint:
+        assert not Path(checkpoint_dir).exists()
+
+
 def test_metadata(checkpoint: Checkpoint):
     assert checkpoint.get_metadata() == {}
 
26 changes: 19 additions & 7 deletions python/ray/train/tests/test_new_persistence.py
@@ -143,24 +143,28 @@ def train_fn(config):
     for i in range(start, config.get("num_iterations", 5)):
         time.sleep(0.25)
 
-        checkpoint_file_name = "checkpoint.pkl"
+        temp_dir = tempfile.mkdtemp()
+        with open(os.path.join(temp_dir, "checkpoint.pkl"), "wb") as f:
+            pickle.dump({"iter": i}, f)
+
         artifact_file_name = f"artifact-iter={i}.txt"
         if in_trainer:
             rank = train.get_context().get_world_rank()
-            checkpoint_file_name = f"checkpoint_shard-rank={rank}.pkl"
             artifact_file_name = f"artifact-rank={rank}-iter={i}.txt"
 
+            checkpoint_file_name = f"checkpoint_shard-rank={rank}.pkl"
+            with open(os.path.join(temp_dir, checkpoint_file_name), "wb") as f:
+                pickle.dump({"iter": i}, f)
+
         with open(artifact_file_name, "w") as f:
             f.write(f"{i}")
 
-        temp_dir = tempfile.mkdtemp()
-        with open(os.path.join(temp_dir, checkpoint_file_name), "wb") as f:
-            pickle.dump({"iter": i}, f)
-
         train.report(
             {"iter": i, _SCORE_KEY: i},
             checkpoint=NewCheckpoint.from_directory(temp_dir),
         )
+        if i in config.get("fail_iters", []):
+            raise RuntimeError(f"Failing on iter={i}!!")
 
 
 @pytest.mark.parametrize("storage_path_type", [None, "nfs", "cloud", "custom_fs"])
@@ -287,6 +291,7 @@ def test_trainer(
     ├── progress.csv
     ├── result.json
     ├── checkpoint_000000
+    │   ├── checkpoint.pkl                <- Shared checkpoint file
     │   ├── checkpoint_shard-rank=0.pkl   <- Worker checkpoint shards
     │   └── checkpoint_shard-rank=1.pkl
     ├── ...
@@ -309,14 +314,19 @@ def test_trainer(
     NUM_WORKERS = 2
     trainer = DataParallelTrainer(
         train_fn,
-        train_loop_config={"in_trainer": True, "num_iterations": NUM_ITERATIONS},
+        train_loop_config={
+            "in_trainer": True,
+            "num_iterations": NUM_ITERATIONS,
+            "fail_iters": [2, 4],
+        },
         scaling_config=train.ScalingConfig(num_workers=2),
         run_config=train.RunConfig(
             storage_path=storage_path,
             storage_filesystem=storage_filesystem,
             name=exp_name,
             verbose=0,
             checkpoint_config=checkpoint_config,
+            failure_config=train.FailureConfig(max_failures=2),
         ),
     )
     result = trainer.fit()
@@ -352,6 +362,8 @@ def test_trainer(
 
     assert len(list(trial_dir.glob("checkpoint_*"))) == expected_num_checkpoints
     for checkpoint_dir in trial_dir.glob("checkpoint_*"):
+        # 1 shared checkpoint.pkl file, written by all workers.
+        assert len(list(checkpoint_dir.glob("checkpoint.pkl"))) == 1
         # 1 checkpoint shard per worker.
         assert (
             len(list(checkpoint_dir.glob("checkpoint_shard-*.pkl"))) == NUM_WORKERS
23 changes: 9 additions & 14 deletions python/ray/train/trainer.py
@@ -71,9 +71,9 @@ def __init__(
         # TrainingResult event. There's no need to do these one at a time.
         self._checkpoint_to_report = None
 
-        # TODO(justinvyu): Is this the best way to do this? Need to save this
-        # as part of checkpoint metadata and load it back on restore.
-        self._latest_checkpoint_index = 0
+        self._storage = None
+        if _use_storage_context():
+            self._storage = get_storage_context()
 
         self._start_training(
             train_func=train_func,
@@ -103,7 +103,10 @@ def _start_training(
             run_dir=run_dir,
             latest_checkpoint_id=latest_checkpoint_id,
         )
-        checkpoint = self._checkpoint_manager._load_checkpoint(checkpoint)
+
+        if not _use_storage_context():
+            checkpoint = self._checkpoint_manager._load_checkpoint(checkpoint)
+
         self._run_with_error_handling(
             lambda: self._backend_executor.start_training(
                 train_func=train_func,
@@ -119,18 +122,10 @@ def _send_next_checkpoint_path_to_workers(self):
         # NOTE: Always upload to storage from workers in the new persistence path
         # (no need to check for the `checkpoint_upload_from_workers` flag)
         if _use_storage_context():
-            storage = get_storage_context()
-
-            # NOTE: Idea: this checkpoint dir name should be customizable
-            # and created on the fly when the checkpoint is reported with metrics.
-            # Ex: lambda metrics: f"checkpoint_iter={metrics['training_iteration']}"
-            storage.current_checkpoint_index = self._latest_checkpoint_index
-
             self._backend_executor._set_checkpoint_index(
-                storage.current_checkpoint_index
+                self._storage.current_checkpoint_index
             )
-
-            self._latest_checkpoint_index += 1
+            self._storage.current_checkpoint_index += 1
 
         elif self._checkpoint_strategy._checkpoint_upload_from_workers:
             self._backend_executor._set_legacy_checkpoint_uri(
26 changes: 22 additions & 4 deletions python/ray/tune/execution/tune_controller.py
@@ -1956,11 +1956,29 @@ def _checkpoint_trial_if_needed(self, trial, force=False):
     ###
     # RESTORE
     def _schedule_trial_restore(self, trial: Trial) -> bool:
-        checkpoint = trial.checkpoint
-
         if _use_storage_context():
-            # TODO(justinvyu): Skipping restoration altogether for now.
-            return False
+            checkpoint_result = trial.checkpoint_manager.latest_checkpoint_result
+
+            if not checkpoint_result:
+                logger.debug(f"Not restoring trial {trial}: No checkpoint found.")
+                return False
+
+            # TODO(justinvyu): Is this really needed?
+            trial.restoring_from = checkpoint_result
+
+            method_name = "restore"
+            args = (checkpoint_result,)
+            self._schedule_trial_task(
+                trial=trial,
+                method_name=method_name,
+                args=args,
+                kwargs={},
+                on_result=self._on_restoring_result,
+                on_error=self._trial_task_failure,
+            )
+            return True
+
+        checkpoint = trial.checkpoint
 
         if checkpoint.dir_or_data is None:
             logger.debug(f"Not restoring trial {trial}: No checkpoint found.")