diff --git a/python/ray/tune/trainable/trainable.py b/python/ray/tune/trainable/trainable.py index 79c4cfca37f56..730a5c5d1d910 100644 --- a/python/ray/tune/trainable/trainable.py +++ b/python/ray/tune/trainable/trainable.py @@ -177,8 +177,7 @@ def __init__( self._storage = storage - if _use_storage_context(): - assert storage + if _use_storage_context() and storage: assert storage.trial_fs_path logger.debug(f"StorageContext on the TRAINABLE:\n{storage}") @@ -525,17 +524,29 @@ def save( ) local_checkpoint = NewCheckpoint.from_directory(checkpoint_dir) - persisted_checkpoint = self._storage.persist_current_checkpoint( - local_checkpoint - ) - # The checkpoint index needs to be incremented. - # NOTE: This is no longer using "iteration" as the folder indexing - # to be consistent with fn trainables. - self._storage.current_checkpoint_index += 1 - checkpoint_result = _TrainingResult( - checkpoint=persisted_checkpoint, metrics=self._last_result.copy() - ) + if self._storage: + persisted_checkpoint = self._storage.persist_current_checkpoint( + local_checkpoint + ) + # The checkpoint index needs to be incremented. + # NOTE: This is no longer using "iteration" as the folder indexing + # to be consistent with fn trainables. + self._storage.current_checkpoint_index += 1 + + checkpoint_result = _TrainingResult( + checkpoint=persisted_checkpoint, + metrics=self._last_result.copy(), + ) + else: + # `storage=None` only happens when initializing the + # Trainable manually, outside of Tune/Train. + # In this case, no storage is set, so the default behavior + # is to just not upload anything and report a local checkpoint. + # This is fine for the main use case of local debugging. + checkpoint_result = _TrainingResult( + checkpoint=local_checkpoint, metrics=self._last_result.copy() + ) else: checkpoint_result: _TrainingResult = checkpoint_dict_or_path