## Idea
- a function turns a `dataarray` with an attached `hash` into a new `dataarray` with a new attached `hash`
- the resulting `dataarray` can then be saved to the `cache` directory

In [1]:
import collections
import functools
import hashlib
import inspect
import json
from pathlib import Path

import numpy as np
import scipy.integrate as si
import xarray as xr

def get_hash(x):
    """Return the hexadecimal SHA-1 digest of `x`.

    `x` is handed straight to `hashlib.sha1`, so it must be a bytes-like
    object (e.g. `bytes` or an array exposing its buffer).
    """
    return hashlib.sha1(x).hexdigest()

In [2]:
# Attribute/key names used throughout. Every field holds its own name as
# value, e.g. `keys.hash == "hash"`.
keys = collections.namedtuple(
    "keys",
    (
        "hash",
        "hash_input",
        "hash_function",
        "cache",
        "function_name",
        "function_signature",
    ),
)

# replace the class by an instance whose values mirror the field names
keys = keys(*keys._fields)
keys

keys(hash='hash', hash_input='hash_input', hash_function='hash_function', cache='cache', function_name='function_name', function_signature='function_signature')

In [3]:
# Input dataarray: y = sin(x) sampled at x = 0, 1, ..., 9

x = np.arange(10)
y = np.sin(x)

# initial hash, computed from the buffer of the stacked (x, y) values
h = get_hash(np.array([x, y]))

# attach the hash as attribute right at construction
da = xr.DataArray(
    y,
    dims="x",
    coords={"x": x},
    name="some_array",
    attrs={keys.hash: h},
)
da

<xarray.DataArray 'some_array' (x: 10)>
array([ 0.      ,  0.841471,  0.909297,  0.14112 , -0.756802, -0.958924,
       -0.279415,  0.656987,  0.989358,  0.412118])
Coordinates:
  * x        (x) int64 0 1 2 3 4 5 6 7 8 9
Attributes:
    hash:     040a8e4bb7f8cbfc196f77b7243a95819e09c3d2

In [4]:
# write wrapper for caching

# cache folder (relative path) and, inside it, the JSON logfile that maps
# cache filenames to hashes
_path_cache = Path(keys.cache)
_path_log = _path_cache.joinpath(f"{keys.hash}.json")


def log_update(filename, hash_str):
    """Update the hash logfile with the current `filename: hash_str` pair.

    The logfile (`_path_log`) stores one JSON object mapping filenames to
    hashes. An existing entry for `filename` is overwritten, all other
    entries are kept. The logfile is created if it does not exist yet.
    """
    path = _path_log
    log = {}
    if path.exists():
        # read through a context manager so the file handle is closed again
        with path.open() as f:
            log = json.load(f)
    log[filename] = hash_str

    # make sure the folder exists so the write cannot fail on a fresh cache
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w") as f:
        json.dump(log, f, indent=2)


def log_lookup(filename):
    """Look up the hash recorded for `filename` in the hash logfile.

    Returns the stored hash, or `None` if the logfile does not exist or
    contains no entry for `filename`.
    """
    path = _path_log
    if not path.exists():
        return None

    # read through a context manager so the file handle is closed again
    with path.open() as f:
        log = json.load(f)
    return log.get(filename)


def function_name_from_attrs(attrs):
    """helper: read the function name out of the JSON signature stored in attrs"""
    signature_json = attrs[keys.function_signature]
    return json.loads(signature_json)["__name__"]


def cache_filename(array, function_name):
    """return the cache filename `<function_name>.nc`

    `array` is only type-checked; it does not influence the filename.
    """
    assert isinstance(array, xr.DataArray)
    stem = str(function_name)
    return stem + ".nc"


def cache_write(array, filename):
    """store array as netCDF file in the cache folder and log its hash"""
    # read the hash first: an array without hash raises before anything is written
    hash_str = array.attrs[keys.hash]

    target = _path_cache / filename
    target.parent.mkdir(exist_ok=True)  # create the cache folder on first use
    array.to_netcdf(target)

    # keep hash.json in sync with the file just written
    log_update(filename, hash_str)


def cache_read(filename):
    """load the dataarray stored under filename in the cache folder"""
    return xr.load_dataarray(_path_cache / filename)


def get_signature_dict(func, array, kwargs):
    """build the dictionary that describes one call of func

    Key order: "__name__", the parameters of func with their defaults,
    any additional keys from kwargs, "__type__" (type of the array as string).
    Values in kwargs override the defaults. A literal "kwargs" entry is dropped.
    """
    defaults = {
        param_name: param.default
        for param_name, param in inspect.signature(func).parameters.items()
    }
    signature = {
        "__name__": func.__name__,
        **defaults,
        **kwargs,  # explicit arguments take precedence over the defaults
        "__type__": str(type(array)),  # add data type
    }
    # drop the literal "kwargs" entry (present if func declares **kwargs)
    signature.pop("kwargs", None)
    return signature


def cached(func):
    """Provides on-disk cache for a function that accepts and returns an xarray.DataArray

    The wrapped function is called as `func(array=array, **kwargs)`, i.e. the
    array can be given positionally to the wrapper, everything else must be
    passed by keyword.

    For each call, an output hash is derived from
    - the hash of the function signature (name, defaults, passed keyword
      arguments and the type of `array`, serialized as JSON), and
    - the hash attached to the input array (empty string if there is none).

    If the hash logged for the cache file `<function name>.nc` equals this
    output hash, the cached file is read and returned. Otherwise the function
    is evaluated, the hashes and the signature are attached to the result's
    `attrs`, and the result is written to the cache.
    """
    name = func.__name__

    # keep __name__, docstring etc. of the decorated function
    @functools.wraps(func)
    def _func(array, **kwargs):
        signature = get_signature_dict(func, array, kwargs)
        # keyword arguments must be JSON serializable, json.dumps raises otherwise
        signature_str = json.dumps(signature)

        # create hashes
        # NOTE: an input without hash attribute contributes "" here, so two
        # different hash-less inputs cannot be told apart by the cache
        hash_input = array.attrs.get(keys.hash, "")  # hash input array
        hash_function = get_hash(signature_str.encode())  # hash the function
        input_function_str = hash_function.encode() + hash_input.encode()
        hash_output = get_hash(input_function_str)  # hash output array

        filename = cache_filename(array, name)

        # look up if the result is already cached in the local cache folder
        hash_lookup = log_lookup(filename)

        if hash_lookup == hash_output:
            return cache_read(filename)  # return the cached result

        # otherwise compute and attach metainfo
        result = func(array=array, **kwargs)
        attrs = {
            keys.hash_input: hash_input,
            keys.function_signature: signature_str,
            keys.hash_function: hash_function,
            keys.hash: hash_output,
        }
        result.attrs.update(attrs)

        cache_write(result, filename)

        return result

    return _func

In [5]:
# function that performs integration on the data
@cached
def cumtrapz(array=None, dx: float = 1.0, axis: int = 0):
    """Cumulatively integrate `array` along `axis` with the trapezoidal rule.

    Parameters
    ----------
    array : xarray.DataArray
        Data to integrate. The coordinate of the integrated dimension
        provides the sample points.
    dx : float
        Forwarded to SciPy. SciPy only uses it when no sample points are
        given, so it has no effect on the result here.
    axis : int
        Axis (position in `array.dims`) to integrate along.

    Returns
    -------
    xarray.DataArray
        Cumulative integral (starting at 0) with the dims, coords and attrs
        of `array`.
    """
    # SciPy renamed `cumtrapz` to `cumulative_trapezoid` and dropped the old
    # name in recent releases; prefer the new name, fall back for old SciPy
    integrate = getattr(si, "cumulative_trapezoid", None) or si.cumtrapz

    y = np.asarray(array)
    # sample points must belong to the dimension that is integrated over
    x = np.asarray(array.coords[array.dims[axis]])
    i = integrate(y, x=x, dx=dx, axis=axis, initial=0)

    # create dataarray
    return xr.DataArray(i, dims=array.dims, coords=array.coords, attrs=array.attrs)


cumtrapz(da)

<xarray.DataArray (x: 10)>
array([0.      , 0.420735, 1.29612 , 1.821328, 1.513487, 0.655624, 0.036454,
       0.225239, 1.048412, 1.74915 ])
Coordinates:
  * x        (x) int64 0 1 2 3 4 5 6 7 8 9
Attributes:
    hash:                ebf61ebc07e549598eeb651d73a8505b899d7d1e
    hash_input:          040a8e4bb7f8cbfc196f77b7243a95819e09c3d2
    function_signature:  {"__name__": "cumtrapz", "array": null, "dx": 1.0, "...
    hash_function:       374631568a488a70417cd6f9beecbaf15c21306c

In [6]:
def stupid_function(array=None):
    """Deliberately slow toy computation (nested loop over all element pairs).

    Works on a copy, so the input is left untouched. For every position
    `row`, each current element value plus `row` is added to element `row`.
    """
    result = array.copy()
    for row, _ in enumerate(result):
        # values are read while `result` is being modified, so later
        # iterations see the already updated entries
        for _, value in enumerate(result):
            result[row] += row + value
    return result


@cached
def cached_stupid_function(array=None):
    """`stupid_function` wrapped with the on-disk cache."""
    return stupid_function(array=array)

In [7]:
%%timeit
stupid_function(da)

79.8 ms ± 1.85 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [8]:
%%timeit
cached_stupid_function(da)

1.73 ms ± 123 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
