diff --git a/invokeai/app/services/object_serializer/object_serializer_disk.py b/invokeai/app/services/object_serializer/object_serializer_disk.py index 935fec30605..8edd29e1505 100644 --- a/invokeai/app/services/object_serializer/object_serializer_disk.py +++ b/invokeai/app/services/object_serializer/object_serializer_disk.py @@ -1,6 +1,6 @@ +import shutil import tempfile import typing -from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Optional, TypeVar @@ -17,12 +17,6 @@ T = TypeVar("T") -@dataclass -class DeleteAllResult: - deleted_count: int - freed_space_bytes: float - - class ObjectSerializerDisk(ObjectSerializerBase[T]): """Disk-backed storage for arbitrary python objects. Serialization is handled by `torch.save` and `torch.load`. @@ -35,6 +29,12 @@ def __init__(self, output_dir: Path, ephemeral: bool = False): self._ephemeral = ephemeral self._base_output_dir = output_dir self._base_output_dir.mkdir(parents=True, exist_ok=True) + + if self._ephemeral: + # Remove dangling tempdirs that might have been left over from an earlier unplanned shutdown. + for temp_dir in filter(Path.is_dir, self._base_output_dir.glob("tmp*")): + shutil.rmtree(temp_dir) + # Must specify `ignore_cleanup_errors` to avoid fatal errors during cleanup on Windows self._tempdir = ( tempfile.TemporaryDirectory(dir=self._base_output_dir, ignore_cleanup_errors=True) if ephemeral else None diff --git a/tests/test_object_serializer_disk.py b/tests/test_object_serializer_disk.py index 125534c5002..84c6e876fcb 100644 --- a/tests/test_object_serializer_disk.py +++ b/tests/test_object_serializer_disk.py @@ -99,6 +99,20 @@ def test_obj_serializer_ephemeral_writes_to_tempdir(tmp_path: Path): assert not Path(tmp_path, obj_1_name).exists() +def test_obj_serializer_ephemeral_deletes_dangling_tempdirs_on_init(tmp_path: Path): + tempdir = tmp_path / "tmpdir" + tempdir.mkdir() + ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True) + assert not tempdir.exists() + + +def test_obj_serializer_does_not_delete_tempdirs_on_init(tmp_path: Path): + tempdir = tmp_path / "tmpdir" + tempdir.mkdir() + ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=False) + assert tempdir.exists() + + def test_obj_serializer_disk_different_types(tmp_path: Path): obj_serializer_1 = ObjectSerializerDisk[MockDataclass](tmp_path) obj_1 = MockDataclass(foo="bar")