Skip to content

Commit

Permalink
[train] make Trainable storage optional (ray-project#38853)
Browse files: browse the repository at this point in the history
* [train] make Trainable storage optional

Signed-off-by: Matthew Deng <matt@anyscale.com>
Signed-off-by: e428265 <arvind.chandramouli@lmco.com>
  • Loading branch information
matthewdeng authored and arvind-chandra committed Aug 31, 2023
1 parent 6079d1d commit 021e945
Showing 1 changed file with 23 additions and 12 deletions.
35 changes: 23 additions & 12 deletions python/ray/tune/trainable/trainable.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -177,8 +177,7 @@ def __init__(

self._storage = storage

if _use_storage_context():
assert storage
if _use_storage_context() and storage:
assert storage.trial_fs_path
logger.debug(f"StorageContext on the TRAINABLE:\n{storage}")

Expand Down Expand Up @@ -525,17 +524,29 @@ def save(
)

local_checkpoint = NewCheckpoint.from_directory(checkpoint_dir)
persisted_checkpoint = self._storage.persist_current_checkpoint(
local_checkpoint
)
# The checkpoint index needs to be incremented.
# NOTE: This is no longer using "iteration" as the folder indexing
# to be consistent with fn trainables.
self._storage.current_checkpoint_index += 1

checkpoint_result = _TrainingResult(
checkpoint=persisted_checkpoint, metrics=self._last_result.copy()
)
if self._storage:
persisted_checkpoint = self._storage.persist_current_checkpoint(
local_checkpoint
)
# The checkpoint index needs to be incremented.
# NOTE: This is no longer using "iteration" as the folder indexing
# to be consistent with fn trainables.
self._storage.current_checkpoint_index += 1

checkpoint_result = _TrainingResult(
checkpoint=persisted_checkpoint,
metrics=self._last_result.copy(),
)
else:
# `storage=None` only happens when initializing the
# Trainable manually, outside of Tune/Train.
# In this case, no storage is set, so the default behavior
# is to just not upload anything and report a local checkpoint.
# This is fine for the main use case of local debugging.
checkpoint_result = _TrainingResult(
checkpoint=local_checkpoint, metrics=self._last_result.copy()
)

else:
checkpoint_result: _TrainingResult = checkpoint_dict_or_path
Expand Down

0 comments on commit 021e945

Please sign in to comment.