[tune] Remove temporary checkpoint directories after restore (ray-project#37173)

`FunctionTrainable.restore_from_object` creates a temporary checkpoint directory.

This directory is kept around because we don't control how the user interacts with the checkpoint - they might load it several times, or not at all.

Once a new checkpoint is tracked in the status reporter, there is no longer any need to keep the temporary directory around.

In this PR, we add functionality to remove these temporary directories. Additionally, we set the number of checkpoints to keep in `pytorch_pbt_failure` to 10 to reduce disk pressure in the release test; that pressure appears to have led to recent failures of the test. By capping the total number of checkpoints and fixing the temporary-directory issue, we should see much lower disk usage.
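For illustration, a minimal sketch of the lifecycle this change addresses, mirroring the updated test below. It is a sketch rather than a runnable script: `train_fn` stands for any Tune function trainable that saves a checkpoint on every iteration, and its body is elided.

```python
from ray.tune.trainable import wrap_function

# `train_fn(config, checkpoint_dir=None)` is assumed to save a checkpoint on
# every training iteration, like the `train` function in the test below.
trainable = wrap_function(train_fn)()
trainable.train()  # produces a first checkpoint

# Serialize the current checkpoint to an in-memory object.
checkpoint_obj = trainable.save_to_object()

# restore_from_object() materializes that object into a temporary
# `checkpoint_tmp*` directory under the trial logdir.
trainable.restore_from_object(checkpoint_obj)

# Previously the temporary directory stayed on disk indefinitely. With this
# change it is deleted as soon as the next checkpoint is tracked in the
# status reporter, i.e. after the next training step that checkpoints.
trainable.train()
trainable.stop()
```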

Signed-off-by: Kai Fricke <kai@anyscale.com>
krfricke authored and Kai Fricke committed Jul 8, 2023
1 parent 89086bf commit 1bde014
Showing 3 changed files with 56 additions and 3 deletions.
46 changes: 44 additions & 2 deletions python/ray/tune/tests/test_function_api.py
@@ -14,7 +14,12 @@
from ray.tune.logger import NoopLogger
from ray.tune.execution.placement_groups import PlacementGroupFactory
from ray.tune.trainable.util import TrainableUtil
from ray.tune.trainable import with_parameters, wrap_function, FuncCheckpointUtil
from ray.tune.trainable import (
with_parameters,
wrap_function,
FuncCheckpointUtil,
FunctionTrainable,
)
from ray.tune.result import DEFAULT_METRIC
from ray.tune.schedulers import ResourceChangingScheduler

@@ -287,10 +292,11 @@ def train(config, checkpoint_dir=None):

new_trainable2 = wrapped(logger_creator=self.logger_creator)
new_trainable2.restore_from_object(checkpoint_obj)
assert sum("tmp" in path for path in os.listdir(self.logdir)) == 1
checkpoint_obj = new_trainable2.save_to_object()
new_trainable2.train()
result = new_trainable2.train()
assert sum("tmp" in path for path in os.listdir(self.logdir)) == 1
assert sum("tmp" in path for path in os.listdir(self.logdir)) == 0
new_trainable2.stop()
assert sum("tmp" in path for path in os.listdir(self.logdir)) == 0
assert result[TRAINING_ITERATION] == 4
@@ -596,6 +602,42 @@ def train(config):
self.assertEqual(trial_2.last_result["m"], 8 + 9)


def test_restore_from_object_delete(tmp_path):
"""Test that temporary checkpoint directories are deleted after restoring.
`FunctionTrainable.restore_from_object` creates a temporary checkpoint directory.
This directory is kept around as we don't control how the user interacts with
the checkpoint - they might load it several times, or not at all.
Once a new checkpoint is tracked in the status reporter, there is no need to keep
the temporary object around anymore. This test asserts that the temporary
checkpoint directories are then deleted.
"""
# Create 2 checkpoints
cp_1 = TrainableUtil.make_checkpoint_dir(str(tmp_path), index=1, override=True)
cp_2 = TrainableUtil.make_checkpoint_dir(str(tmp_path), index=2, override=True)

# Instantiate function trainable
trainable = FunctionTrainable()
trainable._logdir = str(tmp_path)
trainable._status_reporter.set_checkpoint(cp_1)

# Save to object and restore. This will create a temporary checkpoint directory.
cp_obj = trainable.save_to_object()
trainable.restore_from_object(cp_obj)

# Assert there is at least one `checkpoint_tmpxxxxx` directory in the logdir
assert any(path.name.startswith("checkpoint_tmp") for path in tmp_path.iterdir())

# Track a new checkpoint. This should delete the temporary checkpoint directory.
trainable._status_reporter.set_checkpoint(cp_2)

# Directory should have been deleted
assert not any(
path.name.startswith("checkpoint_tmp") for path in tmp_path.iterdir()
)


if __name__ == "__main__":
import pytest

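The new `test_restore_from_object_delete` test can be run on its own with standard pytest filtering, e.g. `pytest python/ray/tune/tests/test_function_api.py -k test_restore_from_object_delete` (assuming a working Ray development environment).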
10 changes: 10 additions & 0 deletions python/ray/tune/trainable/function_trainable.py
@@ -225,10 +225,20 @@ def set_checkpoint(self, checkpoint, is_new=True):
"make_checkpoint_dir."
)
raise
previous_checkpoint = self._last_checkpoint
self._last_checkpoint = checkpoint
if is_new:
self._fresh_checkpoint = True

# Delete temporary checkpoint folder from `restore_from_object`
if previous_checkpoint and FuncCheckpointUtil.is_temp_checkpoint_dir(
previous_checkpoint
):
previous_checkpoint_dir = TrainableUtil.find_checkpoint_dir(
previous_checkpoint
)
shutil.rmtree(previous_checkpoint_dir, ignore_errors=True)

def has_new_checkpoint(self):
return self._fresh_checkpoint

3 changes: 2 additions & 1 deletion (the `pytorch_pbt_failure` release test script)
@@ -5,7 +5,7 @@

import ray
from ray import tune
from ray.air.config import RunConfig, ScalingConfig, FailureConfig
from ray.air.config import CheckpointConfig, FailureConfig, RunConfig, ScalingConfig
from ray.train.examples.pytorch.tune_cifar_torch_pbt_example import train_func
from ray.train.torch import TorchConfig, TorchTrainer
from ray.tune.schedulers import PopulationBasedTraining
@@ -70,6 +70,7 @@
run_config=RunConfig(
stop={"training_iteration": 1} if args.smoke_test else None,
failure_config=FailureConfig(max_failures=-1),
checkpoint_config=CheckpointConfig(num_to_keep=10),
callbacks=[FailureInjectorCallback(time_between_checks=90), ProgressCallback()],
),
)
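For reference, `CheckpointConfig(num_to_keep=10)` keeps only the ten most recent checkpoints per trial and deletes older ones, which bounds disk usage over a long PBT run. A minimal standalone sketch of the option (the training function and search space here are illustrative and not part of this commit):

```python
from ray import tune
from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.air.config import CheckpointConfig, RunConfig


def train_fn(config):
    # Illustrative loop that reports a checkpoint on every iteration.
    for step in range(100):
        session.report(
            {"score": step * config["lr"]},
            checkpoint=Checkpoint.from_dict({"step": step}),
        )


tuner = tune.Tuner(
    train_fn,
    param_space={"lr": tune.grid_search([0.01, 0.1])},
    run_config=RunConfig(
        # Retain only the 10 most recent checkpoints per trial.
        checkpoint_config=CheckpointConfig(num_to_keep=10),
    ),
)
tuner.fit()
```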
