[train+tune] Local directory refactor (1/n): Write launcher state files (`tuner.pkl`, `trainer.pkl`) directly to storage (ray-project#43369)

This PR updates `Trainer`s and the `Tuner` to upload their state directly to `storage_path`, rather than dumping it into a local directory and relying on driver syncing to upload it.
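
For reference, the pattern the diff moves to is: resolve a (filesystem, path) pair once, then write state straight through it. A minimal sketch of that idea, using `pyarrow.fs.FileSystem.from_uri` as a stand-in for Ray's internal `get_fs_and_path` helper; the `storage_path`, experiment name, and payload below are hypothetical:

```python
from pathlib import Path
import pickle

import pyarrow.fs

storage_path = "/tmp/ray_results"      # hypothetical; could also be e.g. "s3://bucket/prefix"
experiment_dir_name = "my_experiment"  # hypothetical run name
_TRAINER_PKL = "trainer.pkl"

# Resolve the destination filesystem and base path once...
fs, base_path = pyarrow.fs.FileSystem.from_uri(storage_path)
experiment_fs_path = Path(base_path, experiment_dir_name).as_posix()

# ...then write the pickled launcher state directly to it, with no local
# staging directory and no driver sync step.
fs.create_dir(experiment_fs_path)
with fs.open_output_stream(Path(experiment_fs_path, _TRAINER_PKL).as_posix()) as f:
    f.write(pickle.dumps({"example": "launcher state"}))
```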

---------

Signed-off-by: Justin Yu <justinvyu@anyscale.com>
justinvyu authored and hebiao064 committed Mar 12, 2024
1 parent b965f4e commit 992ec0e
Showing 10 changed files with 84 additions and 125 deletions.
2 changes: 1 addition & 1 deletion doc/source/train/doc_code/key_concepts.py
@@ -129,7 +129,7 @@ def train_fn(config):
result_path: str = result.path
result_filesystem: pyarrow.fs.FileSystem = result.filesystem

print("Results location (fs, path) = ({result_filesystem}, {result_path})")
print(f"Results location (fs, path) = ({result_filesystem}, {result_path})")
# __result_path_end__


2 changes: 1 addition & 1 deletion doc/source/train/doc_code/tuner.py
@@ -68,7 +68,7 @@

tuner = Tuner(
trainable=trainer,
run_config=RunConfig(name="test_tuner"),
run_config=RunConfig(name="test_tuner_xgboost"),
param_space=param_space,
tune_config=tune.TuneConfig(
mode="min", metric="train-logloss", num_samples=2, max_concurrent_trials=2
3 changes: 2 additions & 1 deletion python/ray/air/tests/test_errors.py
@@ -15,6 +15,7 @@
- Assert how errors from the trainable/Trainer get propagated to the user.
- Assert how errors from the Tune driver get propagated to the user.
"""

import gc
import threading
import time
@@ -198,7 +199,7 @@ def test_driver_error_with_tuner(ray_start_4_cpus, error_on):
tuner.fit()

# TODO(ml-team): Assert the cause error type once driver error propagation is fixed
assert "_TestSpecificError" in str(exc_info.value.__cause__)
assert "_TestSpecificError" in str(exc_info.value)


@pytest.mark.parametrize("error_on", ["on_trial_result"])
41 changes: 27 additions & 14 deletions python/ray/train/base_trainer.py
@@ -21,7 +21,11 @@
from ray.air.result import Result
from ray.train import Checkpoint
from ray.train._internal.session import _get_session
from ray.train._internal.storage import _exists_at_fs_path, get_fs_and_path
from ray.train._internal.storage import (
StorageContext,
_exists_at_fs_path,
get_fs_and_path,
)
from ray.util import PublicAPI
from ray.util.annotations import DeveloperAPI

@@ -226,7 +230,9 @@ def __init__(
self.scaling_config = (
scaling_config if scaling_config is not None else ScalingConfig()
)
self.run_config = run_config if run_config is not None else RunConfig()
self.run_config = (
copy.copy(run_config) if run_config is not None else RunConfig()
)
self.metadata = metadata
self.datasets = datasets if datasets is not None else {}
self.starting_checkpoint = resume_from_checkpoint
@@ -569,11 +575,23 @@ def fit(self) -> Result:
``self.as_trainable()``, or during the Tune execution loop.
"""
from ray.tune import ResumeConfig, TuneError
from ray.tune.tuner import Tuner, TunerInternal
from ray.tune.tuner import Tuner

trainable = self.as_trainable()
param_space = self._extract_fields_for_tuner_param_space()

self.run_config.name = (
self.run_config.name or StorageContext.get_experiment_dir_name(trainable)
)
# The storage context here is only used to access the resolved
# storage fs and experiment path, in order to avoid duplicating that logic.
# This is NOT the storage context object that gets passed to remote workers.
storage = StorageContext(
storage_path=self.run_config.storage_path,
experiment_dir_name=self.run_config.name,
storage_filesystem=self.run_config.storage_filesystem,
)

if self._restore_path:
tuner = Tuner.restore(
path=self._restore_path,
@@ -594,16 +612,11 @@
_entrypoint=AirEntrypoint.TRAINER,
)

experiment_local_path, _ = TunerInternal.setup_create_experiment_checkpoint_dir(
trainable, self.run_config
)

experiment_local_path = Path(experiment_local_path)
self._save(experiment_local_path)
self._save(storage.storage_filesystem, storage.experiment_fs_path)

restore_msg = TrainingFailedError._RESTORE_MSG.format(
trainer_cls_name=self.__class__.__name__,
path=str(experiment_local_path),
path=str(storage.experiment_fs_path),
)

try:
@@ -627,7 +640,7 @@
) from result.error
return result

def _save(self, experiment_path: Union[str, Path]):
def _save(self, fs: pyarrow.fs.FileSystem, experiment_path: str):
"""Saves the current trainer's class along with the `param_dict` of
parameters passed to this trainer's constructor.
@@ -656,9 +669,9 @@ def raise_fn():

cls_and_param_dict = (self.__class__, param_dict)

experiment_path = Path(experiment_path)
with open(experiment_path / _TRAINER_PKL, "wb") as fp:
pickle.dump(cls_and_param_dict, fp)
fs.create_dir(experiment_path)
with fs.open_output_stream(Path(experiment_path, _TRAINER_PKL).as_posix()) as f:
f.write(pickle.dumps(cls_and_param_dict))

def _extract_fields_for_tuner_param_space(self) -> Dict:
"""Extracts fields to be included in `Tuner.param_space`.
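
The `_save` hunk above only covers the write side; restoring reads the same file back through a filesystem handle. A rough sketch of that read, assuming a local filesystem and a hypothetical experiment directory (the actual restore code path in Ray may differ):

```python
from pathlib import Path
import pickle

import pyarrow.fs

_TRAINER_PKL = "trainer.pkl"
fs = pyarrow.fs.LocalFileSystem()
experiment_fs_path = "/tmp/ray_results/my_experiment"  # hypothetical

# Mirror of `_save`: recover the pickled (trainer class, param dict) tuple.
with fs.open_input_stream(Path(experiment_fs_path, _TRAINER_PKL).as_posix()) as f:
    trainer_cls, param_dict = pickle.loads(f.readall())
```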
25 changes: 12 additions & 13 deletions python/ray/train/tests/test_trainer_restore.py
@@ -3,6 +3,7 @@
from pathlib import Path
from typing import Dict, List

import pyarrow.fs
import pytest

import ray
@@ -185,26 +186,24 @@ def test_gbdt_trainer_restore(ray_start_6_cpus, tmp_path, trainer_cls, monkeypat
assert tmp_path / exp_name in Path(result.path).parents


@pytest.mark.parametrize("name", [None, "restore_from_uri"])
def test_restore_from_uri_s3(
ray_start_4_cpus, tmp_path, monkeypatch, mock_s3_bucket_uri
ray_start_4_cpus, tmp_path, monkeypatch, mock_s3_bucket_uri, name
):
"""Restoration from S3 should work."""
monkeypatch.setenv("RAY_AIR_LOCAL_CACHE_DIR", str(tmp_path))
trainer = DataParallelTrainer(
train_loop_per_worker=lambda config: train.report({"score": 1}),
scaling_config=ScalingConfig(num_workers=2),
run_config=RunConfig(name="restore_from_uri", storage_path=mock_s3_bucket_uri),
run_config=RunConfig(name=name, storage_path=mock_s3_bucket_uri),
)
trainer.fit()
result = trainer.fit()

# Restore from local dir
DataParallelTrainer.restore(str(tmp_path / "restore_from_uri"))
if name is None:
name = Path(result.path).parent.name

# Restore from S3
assert DataParallelTrainer.can_restore(
str(URI(mock_s3_bucket_uri) / "restore_from_uri")
)
DataParallelTrainer.restore(str(URI(mock_s3_bucket_uri) / "restore_from_uri"))
assert DataParallelTrainer.can_restore(str(URI(mock_s3_bucket_uri) / name))
DataParallelTrainer.restore(str(URI(mock_s3_bucket_uri) / name))


def test_restore_with_datasets(ray_start_4_cpus, tmpdir):
@@ -220,7 +219,7 @@ def test_restore_with_datasets(ray_start_4_cpus, tmpdir):
scaling_config=ScalingConfig(num_workers=2),
run_config=RunConfig(name="datasets_respecify_test", local_dir=tmpdir),
)
trainer._save(tmpdir)
trainer._save(pyarrow.fs.LocalFileSystem(), str(tmpdir))

# Restore should complain, if all the datasets don't get passed in again
with pytest.raises(ValueError):
@@ -246,7 +245,7 @@ def test_restore_with_different_trainer(tmpdir):
scaling_config=ScalingConfig(num_workers=1),
run_config=RunConfig(name="restore_with_diff_trainer"),
)
trainer._save(tmpdir)
trainer._save(pyarrow.fs.LocalFileSystem(), str(tmpdir))

def attempt_restore(trainer_cls, should_warn: bool, should_raise: bool):
def check_for_raise():
@@ -299,7 +298,7 @@ def test_trainer_can_restore_utility(tmp_path):
scaling_config=ScalingConfig(num_workers=1),
)
(tmp_path / name).mkdir(exist_ok=True)
trainer._save(tmp_path / name)
trainer._save(pyarrow.fs.LocalFileSystem(), str(tmp_path / name))

assert DataParallelTrainer.can_restore(path)

3 changes: 1 addition & 2 deletions python/ray/train/tests/test_tune.py
@@ -284,7 +284,7 @@ def test_run_config_in_trainer_and_tuner(
run_config=trainer_run_config,
)
with caplog.at_level(logging.INFO, logger="ray.tune.impl.tuner_internal"):
tuner = Tuner(trainer, run_config=tuner_run_config)
Tuner(trainer, run_config=tuner_run_config)

both_msg = (
"`RunConfig` was passed to both the `Tuner` and the `DataParallelTrainer`"
@@ -302,7 +302,6 @@
assert not (tmp_path / "trainer").exists()
assert both_msg not in caplog.text
else:
assert tuner._local_tuner.get_run_config() == RunConfig()
assert both_msg not in caplog.text


79 changes: 24 additions & 55 deletions python/ray/tune/impl/tuner_internal.py
@@ -1,6 +1,5 @@
import copy
import io
import os
import math
import logging
from pathlib import Path
@@ -27,7 +26,6 @@
from ray.tune import Experiment, ExperimentAnalysis, ResumeConfig, TuneError
from ray.tune.tune import _Config
from ray.tune.registry import is_function_trainable
from ray.tune.result import _get_defaults_results_dir
from ray.tune.result_grid import ResultGrid
from ray.tune.trainable import Trainable
from ray.tune.tune import run
@@ -102,7 +100,7 @@ def __init__(
)

self._tune_config = tune_config or TuneConfig()
self._run_config = run_config or RunConfig()
self._run_config = copy.copy(run_config) or RunConfig()
self._entrypoint = _entrypoint

# Restore from Tuner checkpoint.
@@ -129,23 +127,27 @@ def __init__(
self._resume_config = None
self._is_restored = False
self._tuner_kwargs = copy.deepcopy(_tuner_kwargs) or {}
(
self._local_experiment_dir,
self._experiment_dir_name,
) = self.setup_create_experiment_checkpoint_dir(
self.converted_trainable, self._run_config
)
self._experiment_analysis = None

# This needs to happen before `tune.run()` is kicked in.
# This is because currently tune does not exit gracefully if
# run in ray client mode - if crash happens, it just exits immediately
# without allowing for checkpointing tuner and trainable.
# Thus this has to happen before tune.run() so that we can have something
# to restore from.
experiment_checkpoint_path = Path(self._local_experiment_dir, _TUNER_PKL)
with open(experiment_checkpoint_path, "wb") as fp:
pickle.dump(self.__getstate__(), fp)
self._run_config.name = (
self._run_config.name
or StorageContext.get_experiment_dir_name(self.converted_trainable)
)
# The storage context here is only used to access the resolved
# storage fs and experiment path, in order to avoid duplicating that logic.
# This is NOT the storage context object that gets passed to remote workers.
storage = StorageContext(
storage_path=self._run_config.storage_path,
experiment_dir_name=self._run_config.name,
storage_filesystem=self._run_config.storage_filesystem,
)

fs = storage.storage_filesystem
fs.create_dir(storage.experiment_fs_path)
with fs.open_output_stream(
Path(storage.experiment_fs_path, _TUNER_PKL).as_posix()
) as f:
f.write(pickle.dumps(self.__getstate__()))

def get_run_config(self) -> RunConfig:
return self._run_config
@@ -349,20 +351,16 @@ def _restore_from_path_or_uri(
# Ex: s3://bucket/exp_name -> s3://bucket, exp_name
self._run_config.name = path_or_uri_obj.name
self._run_config.storage_path = str(path_or_uri_obj.parent)

(
self._local_experiment_dir,
self._experiment_dir_name,
) = self.setup_create_experiment_checkpoint_dir(
self.converted_trainable, self._run_config
)
# Update the storage_filesystem with the one passed in on restoration, if any.
self._run_config.storage_filesystem = storage_filesystem

# Load the experiment results at the point where it left off.
try:
self._experiment_analysis = ExperimentAnalysis(
experiment_checkpoint_path=path_or_uri,
default_metric=self._tune_config.metric,
default_mode=self._tune_config.mode,
storage_filesystem=storage_filesystem,
)
except Exception:
self._experiment_analysis = None
@@ -426,35 +424,6 @@ def _process_scaling_config(self) -> None:
return
self._param_space["scaling_config"] = scaling_config.__dict__.copy()

@classmethod
def setup_create_experiment_checkpoint_dir(
cls, trainable: TrainableType, run_config: Optional[RunConfig]
) -> Tuple[str, str]:
"""Sets up and creates the local experiment checkpoint dir.
This is so that the `tuner.pkl` file gets stored in the same directory
and gets synced with other experiment results.
Returns:
Tuple: (experiment_path, experiment_dir_name)
"""
# TODO(justinvyu): Move this logic into StorageContext somehow
experiment_dir_name = run_config.name or StorageContext.get_experiment_dir_name(
trainable
)
storage_local_path = _get_defaults_results_dir()
experiment_path = (
Path(storage_local_path).joinpath(experiment_dir_name).as_posix()
)

os.makedirs(experiment_path, exist_ok=True)
return experiment_path, experiment_dir_name

# This has to be done through a function signature (@property won't do).
def get_experiment_checkpoint_dir(self) -> str:
# TODO(justinvyu): This is used to populate an error message.
# This should point to the storage path experiment dir instead.
return self._local_experiment_dir

@property
def trainable(self) -> TrainableTypeOrTrainer:
return self._trainable
@@ -583,7 +552,7 @@ def _get_tune_run_arguments(self, trainable: TrainableType) -> Dict[str, Any]:
return dict(
storage_path=self._run_config.storage_path,
storage_filesystem=self._run_config.storage_filesystem,
name=self._experiment_dir_name,
name=self._run_config.name,
mode=self._tune_config.mode,
metric=self._tune_config.metric,
callbacks=self._run_config.callbacks,
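
On restore, `_restore_from_path_or_uri` above splits the experiment path back into a storage path and an experiment name (`s3://bucket/exp_name -> s3://bucket, exp_name`). A small sketch of that split, using plain string handling rather than Ray's internal URI helper (an assumption for illustration only):

```python
def split_experiment_path(path_or_uri: str) -> tuple[str, str]:
    # Split an experiment path or URI into (storage_path, experiment name).
    trimmed = path_or_uri.rstrip("/")
    storage_path, _, name = trimmed.rpartition("/")
    return storage_path, name

print(split_experiment_path("s3://bucket/exp_name"))       # ('s3://bucket', 'exp_name')
print(split_experiment_path("/tmp/ray_results/exp_name"))  # ('/tmp/ray_results', 'exp_name')
```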
8 changes: 5 additions & 3 deletions python/ray/tune/tests/test_tuner_restore.py
@@ -713,7 +713,7 @@ def create_trainable_with_params():
)
return trainable_with_params

exp_name = "restore_with_params"
exp_name = f"restore_with_params-{use_function_trainable=}"
fail_marker = tmp_path / "fail_marker"
fail_marker.write_text("", encoding="utf-8")

@@ -943,11 +943,13 @@ def get_checkpoints(experiment_dir):
else:
raise ValueError(f"Invalid trainable type: {trainable_type}")

exp_name = f"{trainable_type=}"

tuner = Tuner(
trainable,
tune_config=TuneConfig(num_samples=1),
run_config=RunConfig(
name="exp_name",
name=exp_name,
storage_path=str(tmp_path),
checkpoint_config=checkpoint_config,
),
@@ -966,7 +968,7 @@

fail_marker.unlink()
tuner = Tuner.restore(
str(tmp_path / "exp_name"), trainable=trainable, resume_errored=True
str(tmp_path / exp_name), trainable=trainable, resume_errored=True
)
results = tuner.fit()

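
The experiment names above lean on Python's f-string `=` specifier, which renders both the expression and its value so each parametrization gets a unique, readable directory name. A tiny illustration:

```python
use_function_trainable = True
trainable_type = "function"

# The "=" specifier prints the expression text along with its (repr'd) value.
print(f"restore_with_params-{use_function_trainable=}")  # restore_with_params-use_function_trainable=True
print(f"{trainable_type=}")                              # trainable_type='function'
```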
4 changes: 2 additions & 2 deletions python/ray/tune/tune.py
@@ -512,9 +512,9 @@ def run(

if _entrypoint == AirEntrypoint.TRAINER:
error_message_map = {
"entrypoint": "Trainer(...)",
"entrypoint": "<FrameworkTrainer>(...)",
"search_space_arg": "param_space",
"restore_entrypoint": 'Trainer.restore(path="{path}", ...)',
"restore_entrypoint": '<FrameworkTrainer>.restore(path="{path}", ...)',
}
elif _entrypoint == AirEntrypoint.TUNER:
error_message_map = {
