Skip to content

Commit

Permalink
Add LRU eviction policy to persisent compilation cache
Browse files Browse the repository at this point in the history
  • Loading branch information
colemanliyah committed Jun 16, 2021
1 parent 7a3a160 commit 0e9f7de
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 6 deletions.
43 changes: 37 additions & 6 deletions jax/experimental/compilation_cache/file_system_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
import os
from typing import Optional
import warnings

class FileSystemCache:

def __init__(self, path: str):
def __init__(self, path: str, max_cache_size_bytes=32 * 2**30):
"""Sets up a cache at 'path'. Cached values may already be present."""
os.makedirs(path, exist_ok=True)
self._path = path
self._path = path
self._max_cache_size_bytes = max_cache_size_bytes

def get(self, key: str) -> Optional[bytes]:
"""Returns None if 'key' isn't present."""
Expand All @@ -37,6 +39,35 @@ def put(self, key: str, value: bytes):
"""Adds new cache entry, possibly evicting older entries."""
if not key:
raise ValueError("key cannot be empty")
#TODO(colemanliyah):implement eviction
with open(os.path.join(self._path, key), "wb") as file:
file.write(value)
if self._evict_entries_if_necessary(key, value):
path_to_new_file = os.path.join(self._path, key)
with open(path_to_new_file, "wb") as file:
file.write(value)
else:
warnings.warn(f"Cache value of size {len(value)} is larger than"
f" the max cache size of {self._max_cache_size_bytes}")

def _evict_entries_if_necessary(self, key: str, value: bytes) -> bool:
"""Returns True if there's enough space to add 'value', False otherwise."""
new_file_size = len(value)

if new_file_size >= self._max_cache_size_bytes:
return False

#TODO(colemanliyah): Refactor this section so the whole directory doesn't need to be checked
while new_file_size + self._get_cache_directory_size() > self._max_cache_size_bytes:
last_time = float('inf')
file_to_delete = None
for file_name in os.listdir(self._path):
file_to_inspect = os.path.join(self._path, file_name)
atime = os.stat(file_to_inspect).st_atime
if atime < last_time:
last_time = atime
file_to_delete = file_to_inspect
assert file_to_delete
os.remove(file_to_delete)
return True

def _get_cache_directory_size(self):
"""Retrieves the current size of the directory, self.path"""
return sum(os.path.getsize(f) for f in os.scandir(self._path) if f.is_file())
66 changes: 66 additions & 0 deletions tests/file_system_cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from jax.experimental.compilation_cache.file_system_cache import FileSystemCache
import jax.test_util as jtu
import tempfile
import time

class FileSystemCacheTest(jtu.JaxTestCase):

Expand Down Expand Up @@ -56,5 +57,70 @@ def test_empty_key_get(self):
with self.assertRaisesRegex(ValueError , r"key cannot be empty"):
cache.get("")

def test_size_of_directory(self):
with tempfile.TemporaryDirectory() as tmpdir:
cache = FileSystemCache(tmpdir)
cache.put("foo", b"bar")
self.assertEqual(cache._get_cache_directory_size(), 3)

def test_size_of_empty_directory(self):
with tempfile.TemporaryDirectory() as tmpdir:
cache = FileSystemCache(tmpdir)
self.assertEqual(cache._get_cache_directory_size(), 0)

def test_size_of_existing_directory(self):
with tempfile.TemporaryDirectory() as tmpdir:
cache1 = FileSystemCache(tmpdir)
cache1.put("foo", b"bar")
del cache1
cache2 = FileSystemCache(tmpdir)
self.assertEqual(cache2._get_cache_directory_size(), 3)

def test_cache_is_full(self):
with tempfile.TemporaryDirectory() as tmpdir:
cache = FileSystemCache(tmpdir, max_cache_size_bytes=6)
cache.put("first", b"one")
# Sleep because otherwise these operations execute too fast and
# the access time isn't captured properly.
time.sleep(1)
cache.put("second", b"two")
cache.put("third", b"the")
self.assertEqual(cache.get("first"), None)
self.assertEqual(cache.get("second"), b"two")
self.assertEqual(cache.get("third"), b"the")

def test_delete_multiple_files(self):
with tempfile.TemporaryDirectory() as tmpdir:
cache = FileSystemCache(tmpdir, max_cache_size_bytes=6)
cache.put("first", b"one")
# Sleep because otherwise these operations execute too fast and
# the access time isn't captured properly.
time.sleep(1)
cache.put("second", b"two")
cache.put("third", b"three")
self.assertEqual(cache.get("first"), None)
self.assertEqual(cache.get("second"), None)
self.assertEqual(cache.get("third"), b"three")

def test_least_recently_accessed_file(self):
with tempfile.TemporaryDirectory() as tmpdir:
cache = FileSystemCache(tmpdir, max_cache_size_bytes=6)
cache.put("first", b"one")
cache.put("second", b"two")
# Sleep because otherwise these operations execute too fast and
# the access time isn't captured properly.
time.sleep(1)
cache.get("first")
cache.put("third", b"the")
self.assertEqual(cache.get("first"), b"one")
self.assertEqual(cache.get("second"), None)

@jtu.ignore_warning(message=("Cache value of size 3 is larger than the max cache size of 2"))
def test_file_bigger_than_cache(self):
with tempfile.TemporaryDirectory() as tmpdir:
cache = FileSystemCache(tmpdir, max_cache_size_bytes=2)
cache.put("foo", b"bar")
self.assertEqual(cache.get("foo"), None)

if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 0e9f7de

Please sign in to comment.