Skip to content

Commit

Permalink
Add subdir for NDArrayLoader to prevent collision of cache files (#78)
Browse files Browse the repository at this point in the history
* Add subdir for NDArrayLoader to prevent collision of cache files

* Release model when cleanup
  • Loading branch information
Wh1isper committed Dec 19, 2023
1 parent 03b7bdc commit 788284a
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 4 deletions.
23 changes: 20 additions & 3 deletions sdgx/models/components/optimize/ndarray_loader.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from __future__ import annotations

import os
import shutil
from functools import cached_property
from pathlib import Path
from typing import Generator
from uuid import uuid4

import numpy as np
from numpy import ndarray

# Root directory under which ndarray cache files are written. The
# SDG_NDARRAY_CACHE_ROOT environment variable overrides the built-in default;
# the lookup happens once, when this module is imported.
DEFAULT_CACHE_ROOT = os.environ.get("SDG_NDARRAY_CACHE_ROOT", "./.ndarry_cache")


class NDArrayLoader:
"""
Expand All @@ -16,10 +20,23 @@ class NDArrayLoader:
Support for storing two-dimensional data by columns.
"""

# NOTE(review): this span is a diff hunk whose +/- markers were lost in
# extraction, so it shows both versions of the constructor back to back.
# Presumably the pre-commit signature (removed): the argument was the cache
# directory itself.
def __init__(self, cache_dir: str | Path = "./.ndarry_cache") -> None:
# Presumably the post-commit signature (added): the argument is now a root
# directory, defaulting to DEFAULT_CACHE_ROOT.
def __init__(self, cache_root: str | Path = DEFAULT_CACHE_ROOT) -> None:
self.store_index = 0
# Presumably removed: the user-supplied path was expanded, resolved and created
# directly as the cache directory.
self.cache_dir = Path(cache_dir).expanduser().resolve()
self.cache_dir.mkdir(exist_ok=True, parents=True)
# Presumably added: the same expand/resolve/mkdir handling, applied to the
# cache root instead.
self.cache_root = Path(cache_root).expanduser().resolve()
self.cache_root.mkdir(exist_ok=True, parents=True)

@cached_property
def subdir(self) -> str:
    """
    Name of this loader's own subdirectory.

    A random UUID4 hex string, generated on first access and cached
    afterwards; it exists to prevent collision of cache files.
    """
    unique_id = uuid4()
    return unique_id.hex

@cached_property
def cache_dir(self) -> Path:
    """Directory for storing ndarray files: ``cache_root`` joined with ``subdir``."""
    root = self.cache_root
    return root / self.subdir

def _get_cache_filename(self, index: int) -> Path:
    """Return the path of the ``<index>.npy`` file inside ``cache_dir``."""
    filename = f"{index}.npy"
    return self.cache_dir / filename
Expand Down
5 changes: 5 additions & 0 deletions sdgx/synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,8 @@ def generator_sample_caller():
def cleanup(self):
    """
    Release the resources held by this object; safe to call repeatedly.

    Finalizes the dataloader (asking it to clear its cache) when one is set,
    then drops the reference to the model.
    """
    # getattr rather than plain attribute access: __del__ also runs on an
    # instance whose __init__ raised before ``dataloader`` was assigned.
    dataloader = getattr(self, "dataloader", None)
    if dataloader:
        dataloader.finalize(clear_cache=True)
    # Release resources. __del__ calls cleanup() again after an explicit
    # cleanup(), so the model attribute may already have been deleted.
    try:
        del self.model
    except AttributeError:
        pass

def __del__(self):
    """Run cleanup when the object is being destroyed."""
    self.cleanup()
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os

# Must be assigned before the ndarray loader module is imported: that module
# reads SDG_NDARRAY_CACHE_ROOT once, at import time, to build its default
# cache root, so a later assignment would have no effect on the default.
os.environ["SDG_NDARRAY_CACHE_ROOT"] = "/tmp/sdgx/ndarray_cache"
import shutil

import pytest
Expand Down
2 changes: 1 addition & 1 deletion tests/optmize/test_ndarry_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
@pytest.fixture
def ndarray_loader(tmp_path, ndarray_list):
cache_dir = tmp_path / "ndarrycache"
loader = NDArrayLoader(cache_dir=cache_dir)
loader = NDArrayLoader(cache_root=cache_dir)
for ndarray in ndarray_list:
loader.store(ndarray)
yield loader
Expand Down

0 comments on commit 788284a

Please sign in to comment.