Skip to content

Commit

Permalink
Merge pull request #24 from khaeru/feature-cache
Browse files Browse the repository at this point in the history
Add file-based caching
  • Loading branch information
khaeru committed Feb 7, 2021
2 parents 91d906d + a4437f5 commit f003e78
Show file tree
Hide file tree
Showing 10 changed files with 245 additions and 17 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ dist
.coverage*
.pytest_cache
htmlcov
myfunc2-*.pkl

# mypy
.mypy_cache
Expand Down
7 changes: 5 additions & 2 deletions doc/api.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
API
***
API reference
*************

.. currentmodule:: genno

Expand Down Expand Up @@ -263,3 +263,6 @@ Utilities

.. automodule:: genno.util
:members:

.. automodule:: genno.caching
:members:
12 changes: 12 additions & 0 deletions doc/config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,18 @@ This section simply makes the output of one task available under another key.
"foo:x-y": "bar:x-y"
"baz:x-y": "bar:x-y"
.. _config-cache:

Caching
-------

Computer-specific configuration that controls the behaviour of functions decorated with :meth:`.Computer.cache`.

- **cache_path** (:class:`pathlib.Path`, optional): base path for cache files. If not provided, defaults to the current working directory.
- **cache_skip** (:class:`bool`, optional): If :obj:`True`, existing cache files are never used; files with the same cache key are overwritten.


``combine:``
------------

Expand Down
6 changes: 4 additions & 2 deletions doc/whatsnew.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ What's new
:backlinks: none
:depth: 1

.. Next release
.. ============
Next release
============

- Add file-based caching via :meth:`.Computer.cache` and :mod:`genno.caching` (:issue:`20`, :pull:`24`).

v0.4.0 (2021-02-07)
===================
Expand Down
75 changes: 75 additions & 0 deletions genno/caching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import json
import logging
import pickle
from hashlib import sha1
from pathlib import Path

log = logging.getLogger(__name__)


class PathEncoder(json.JSONEncoder):
"""JSON encoder that handles :class:`pathlib.Path`.
Used by :func:`.arg_hash`.
"""

def default(self, o):
if isinstance(o, Path):
return str(o)
# Let the base class default method raise the TypeError
return json.JSONEncoder.default(self, o)


def arg_hash(*args, **kwargs):
"""Return a unique hash for `args` and `kwargs`.
Used by :func:`.make_cache_decorator`.
"""
if len(args) + len(kwargs) == 0:
unique = ""
else:
unique = json.dumps(args, cls=PathEncoder) + json.dumps(kwargs, cls=PathEncoder)

# Uncomment for debugging
# log.debug(f"Cache key hashed from: {unique}")

return sha1(unique.encode()).hexdigest()


def make_cache_decorator(computer, func):
"""Helper for :meth:`.Computer.cache`."""
log.debug(f"Wrapping {func.__name__} in Computer.cache()")

# Wrap the call to load_func
def cached_load(*args, **kwargs):
# Path to the cache file
name_parts = [func.__name__, arg_hash(*args, **kwargs)]

cache_path = computer.graph["config"].get("cache_path")

if not cache_path:
cache_path = Path.cwd()
log.warning(f"'cache_path' configuration not set; using {cache_path}")

cache_path = cache_path.joinpath("-".join(name_parts)).with_suffix(".pkl")

# Shorter name for logging
short_name = f"{name_parts[0]}(<{name_parts[1][:8]}…>)"

if (
not computer.graph["config"].get("cache_skip", False)
and cache_path.exists()
):
log.info(f"Cache hit for {short_name}")
with open(cache_path, "rb") as f:
return pickle.load(f)
else:
log.info(f"Cache miss for {short_name}")
data = func(*args, **kwargs)

with open(cache_path, "wb") as f:
pickle.dump(data, f)

return data

return cached_load
2 changes: 2 additions & 0 deletions genno/computations.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
"aggregate",
"apply_units",
"broadcast_map",
"combine",
"concat",
"disaggregate_shares",
"group_sum",
"load_file",
"product",
"ratio",
Expand Down
13 changes: 9 additions & 4 deletions genno/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

HANDLERS = {}

# Keys to be stored with no action.
STORE = set(["cache_path", "cache_skip"])

CALLBACKS: List[Callable] = []


Expand Down Expand Up @@ -75,9 +78,11 @@ def parse_config(c: Computer, data: dict):
try:
handler = HANDLERS[section_name]
except KeyError:
log.warning(
f"No handler for configuration section named {section_name}; ignored"
)
if section_name not in STORE:
log.warning(
f"No handler for configuration section {repr(section_name)}; "
"ignored"
)
continue

if not handler.keep_data:
Expand Down Expand Up @@ -108,7 +113,7 @@ def parse_config(c: Computer, data: dict):
c.add_queue(queue, max_tries=2, fail="raise")

# Store configuration in the graph itself
c.graph["config"] = data
c.graph["config"].update(data)
else:
if len(queue):
raise RuntimeError(
Expand Down
42 changes: 33 additions & 9 deletions genno/core/computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from dask.optimization import cull

from genno import computations
from genno.caching import make_cache_decorator
from genno.util import partial_split

from .describe import describe_recursive
Expand Down Expand Up @@ -81,16 +82,10 @@ def __init__(self, **kwargs):
def configure(self, path=None, **config):
"""Configure the Computer.
Accepts a *path* to a configuration file and/or keyword arguments.
Configuration keys loaded from file are replaced by keyword arguments.
Accepts a `path` to a configuration file and/or keyword arguments.
Configuration keys loaded from file are superseded by keyword arguments.
Valid configuration keys include:
- *default*: the default key; sets :attr:`default_key`.
- *filters*: a :class:`dict`, passed to :meth:`set_filters`.
- *files*: a :class:`list` where every element is a :class:`dict`
of keyword arguments to :meth:`add_file`.
- *alias*: a :class:`dict` mapping aliases to original keys.
See :doc:`config` for a list of all configuration sections and keys.
Warns
-----
Expand Down Expand Up @@ -210,6 +205,35 @@ def add(self, data, *args, **kwargs):
# Some other kind of input
raise TypeError(data)

def cache(self, func):
"""Return a decorator to cache data.
Use this function to decorate another function to be added as the computation/
callable in a task:
.. code-block:: python
c = Computer(cache_path=Path("/some/directory"))
@c.cache
def myfunction(*args, **kwargs):
# Expensive operations, e.g. loading large files
return data
c.add("myvar", (myfunction,))
# Data is cached in /some/directory/myfunction-*.pkl
On the first call of :meth:`get` that invokes the decorated function (directly
or indirectly), the data requested is returned, but also cached in the cache
directory (see :ref:`Configuration → Caching <config-cache>`).
On subsequent calls, if the cache exists, it is used instead of calling the
(possibly slow) method; *unless* the *skip_cache* configuration option is
given, in which case it is loaded again.
"""
return make_cache_decorator(self, func)

def add_queue(self, queue, max_tries=1, fail="raise"):
"""Add tasks from a list or `queue`.
Expand Down
86 changes: 86 additions & 0 deletions genno/tests/core/test_computer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from functools import partial

import numpy as np
import pandas as pd
Expand All @@ -22,6 +23,91 @@
assert_qty_equal,
)

log = logging.getLogger(__name__)


def test_cache(caplog, tmp_path, test_data_path, ureg):
caplog.set_level(logging.INFO)

# Set the cache path
c = Computer(cache_path=tmp_path)

# Arguments and keyword arguments for the computation. These are hashed to make the
# cache key
args = (test_data_path / "input0.csv", "foo")
kwargs = dict(bar="baz")

# Expected value
exp = computations.load_file(test_data_path / "input0.csv")
exp.attrs["args"] = repr(args)
exp.attrs["kwargs"] = repr(kwargs)

def myfunc1(*args, **kwargs):
# Send something to the log for caplog to pick up when the function runs
log.info("myfunc executing")
result = computations.load_file(args[0])
result.attrs["args"] = repr(args)
result.attrs["kwargs"] = repr(kwargs)
return result

# Add to the Computer
c.add("test 1", (partial(myfunc1, *args, **kwargs),))

# Returns the expected result
assert_qty_equal(exp, c.get("test 1"))

# Function was executed
assert "myfunc executing" in caplog.messages

# Same function, but cached
@c.cache
def myfunc2(*args, **kwargs):
return myfunc1(*args, **kwargs)

# Add to the computer
c.add("test 2", (partial(myfunc2, *args, **kwargs),))

# First time computed, returns the expected result
caplog.clear()
assert_qty_equal(exp, c.get("test 2"))

# Function was executed
assert "myfunc executing" in caplog.messages

# 1 cache file was created in the cache_path
files = list(tmp_path.glob("*.pkl"))
assert 1 == len(files)

# File name includes the full hash; retrieve it
hash = files[0].stem.split("-")[-1]

# Cache miss was logged
assert f"Cache miss for myfunc2(<{hash[:8]}…>)" in caplog.messages

# Second time computed, returns the expected result
caplog.clear()
assert_qty_equal(exp, c.get("test 2"))

# Value was loaded from the cache file
assert f"Cache hit for myfunc2(<{hash[:8]}…>)" in caplog.messages
# The function was NOT executed
assert not ("myfunc executing" in caplog.messages)

# With cache_skip
caplog.clear()
c.configure(cache_skip=True)
c.get("test 2")

# Function is executed
assert "myfunc executing" in caplog.messages

# With no cache_path set
c.graph["config"].pop("cache_path")

caplog.clear()
c.get("test 2")
assert "'cache_path' configuration not set; using " in caplog.messages[0]


def test_get():
"""Computer.get() using a default key."""
Expand Down
18 changes: 18 additions & 0 deletions genno/tests/test_caching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from pathlib import Path

import pytest

from genno.caching import PathEncoder, arg_hash


def test_PathEncoder():
# Encodes pathlib.Path or subclass
PathEncoder().default(Path.cwd())

with pytest.raises(TypeError):
PathEncoder().default(lambda foo: foo)


def test_arg_hash():
# Expected value with no arguments
assert "da39a3ee5e6b4b0d3255bfef95601890afd80709" == arg_hash()

0 comments on commit f003e78

Please sign in to comment.