In [1]:
%load_ext pycodestyle_magic
%flake8_on --max_line_length 120 --ignore W293,E302

# Memoization to disk

These tools facilitate the memoization of computations whose execution costs more than serializing and storing their results.

The current approach separates result storage (the *serde*) from memoization checking (the call *signatures*).

In [2]:
from jupytest import Suite, Report, summarize_results, fail
from unittest.mock import patch, MagicMock, call

In [3]:
# Suite on which every %%test cell below is registered; its tally is summarized in the last cell.
# When the notebook runs as __main__, a Report is attached so each test's outcome is printed.
suite = Suite()
if __name__ == "__main__":
    suite |= Report()

## Result store

The store is essentially a directory where results can be serialized to files.

In [4]:
import os
from os.path import isdir, realpath


# Directory where results are stored; defaults to the working directory at the time this cell runs.
DIR_STORE = realpath(".")


def set_store_directory(path: os.PathLike) -> None:
    """Make ``path``, resolved to a canonical absolute path, the directory where results are stored."""
    global DIR_STORE
    resolved = realpath(path)
    DIR_STORE = resolved

The following context manager switches the result store to another directory only temporarily, restoring the previous one on exit.

In [5]:
from contextlib import contextmanager
from typing import ContextManager


@contextmanager
def storing_in_directory(path: os.PathLike) -> ContextManager[None]:
    dir_orig = DIR_STORE
    try:
        set_store_directory(path)
        yield
    finally:
        set_store_directory(dir_orig)

Each result is stored at a path identified by its **unique** *signature*.

In [6]:
import os
from os.path import join


def path_to_result(sig: str) -> os.PathLike:
    """Return the file path, inside the current store directory, of the result signed ``sig``."""
    store = DIR_STORE
    return os.path.join(store, sig)

In [7]:
%%test Results will live in the store directory
from os.path import dirname
for sig in ["some_result", "some/deeper/result"]:
    path = path_to_result(sig)
    while path:
        if path == DIR_STORE:
            break
        path = dirname(path)
    else:
        fail(f"Did not get the storage directory anywhere along {path_to_result(sig)}")

Test [1mResults will live in the store directory[0m passed.


## *Serde* -- Serializer-deserializer

Most results can be pickled to disk; still, let's make the serde a pluggable part, in case we meet results that cannot.

In [8]:
import gzip
import io
import os
import pickle
from typing import Any


class Serde:
    """Serializer-deserializer storing each result as a gzip-compressed pickle in the result store.

    The result for signature ``sig`` lives in the file ``path_to_result(sig)``. Subclasses can change
    the encoding by overriding ``from_file``/``to_file``, or the file container by overriding ``_open``.

    Security: ``read`` unpickles file contents, which can execute arbitrary code; only use a store
    directory whose contents you trust.
    """

    def exists(self, sig: str) -> bool:
        """Tell whether a readable result file exists for signature ``sig``."""
        return os.access(path_to_result(sig), os.R_OK)

    def read(self, sig: str) -> Any:
        """Load and return the result stored under signature ``sig``."""
        with self._open(sig, "rb") as file:
            return self.from_file(file)

    def from_file(self, file: io.BufferedIOBase) -> Any:
        """Deserialize one object from the binary ``file`` opened for reading."""
        return pickle.load(file)

    def write(self, sig: str, obj: Any) -> Any:
        """Store ``obj`` under signature ``sig``, then return ``obj`` unchanged.

        If serializing fails (e.g. ``obj`` is not picklable) or is interrupted once the file has been
        opened, the partially written file is deleted before the exception propagates: otherwise
        ``exists`` would report it as a stored result and every later ``read`` would fail.
        """
        opened = False
        try:
            with self._open(sig, "wb") as file:
                opened = True
                self.to_file(file, obj)
        except BaseException:
            # Only clean up a file this call opened: a failed open must not delete an existing result.
            if opened:
                self._discard(sig)
            raise
        return obj

    def to_file(self, file: io.BufferedIOBase, obj: Any) -> None:
        """Serialize ``obj`` into the binary ``file`` opened for writing."""
        pickle.dump(obj, file)

    def _open(self, sig: str, mode: str) -> io.BufferedIOBase:
        """Open the gzip-compressed result file for ``sig`` in binary ``mode``."""
        return gzip.open(path_to_result(sig), mode)

    def _discard(self, sig: str) -> None:
        """Remove the result file for ``sig`` on a best-effort basis.

        Called while another exception is propagating: failing to remove (e.g. the file is already
        gone) must not mask that original error.
        """
        try:
            os.remove(path_to_result(sig))
        except OSError:
            pass

In [9]:
%%test Serializing a result
# Serde.write must open the result path in "wb" mode, write the pickled object and return the object.
# gzip.open is mocked, so no file is created.
with patch("gzip.open") as mock:
    assert Serde().write("some_result", "asdf") == "asdf"
    mock.assert_called_once_with(path_to_result("some_result"), "wb")
    mock.return_value.__enter__.return_value.write.assert_called_once_with(pickle.dumps("asdf"))

Test [1mSerializing a result[0m passed.


In [10]:
%%test Deserializing a result
# Serde.read must unpickle from the file object obtained by entering gzip.open(...); both are mocked.
with patch("gzip.open") as mock_open, patch("pickle.load", return_value="qwerty") as mock_load:
    assert Serde().read("known_result") == "qwerty"
    mock_load.assert_called_once_with(mock_open.return_value.__enter__.return_value)

Test [1mDeserializing a result[0m passed.


## Function call signatures

Signatures should incorporate as many of the computation's artifacts as possible, so that a change in any of them prevents a stale result from being reused. We shall take into account:

1. The input parameter names and values: we use their most detailed representation, as issued by `repr`.
1. The computation's implementation: we take the source as is, but discard blank lines and trailing whitespace, which makes the signature robust to such cosmetic edits.
1. The state of global and closure variables upon function entry.

### Signing the code of a function

In [11]:
from hashlib import sha256
from inspect import getsourcelines
from typing import Callable, Optional


def signature_code(f: Callable, h: Optional[sha256] = None) -> str:
    """Fold the implementation of callable ``f`` into a SHA-256 hash and return its hex digest.

    The source lines of ``f`` are hashed with trailing whitespace stripped and blank lines dropped,
    so that such cosmetic edits do not change the signature. When ``h`` is given, it is updated in
    place so that callers can chain several signatures into a single hash.

    Fallbacks when no source is available:
    - built-ins (e.g. ``int``), on which ``getsourcelines`` raises ``TypeError``: the name only;
    - functions whose source cannot be retrieved (e.g. created through ``exec``), on which it raises
      ``OSError``: the name plus the compiled bytecode. Bytecode distinguishes most structural
      changes, but not a change of constants alone.
    """
    h = h or sha256()
    try:
        source, _ = getsourcelines(f)
    except TypeError:
        h.update(f.__name__.encode("utf-8"))
    except OSError:
        h.update(f.__name__.encode("utf-8"))
        code = getattr(f, "__code__", None)
        if code is not None:
            h.update(code.co_code)
    else:
        h.update(b"".join(line.rstrip().encode("utf-8") for line in source if line.strip()))
    return h.hexdigest()

In [12]:
def is_sha256(s: str) -> bool:
    """Tell whether ``s`` is exactly a SHA-256 hex digest: 64 lowercase hex digits and nothing else."""
    import re
    # fullmatch rather than ^...$: ``$`` also matches just before a trailing newline.
    return re.fullmatch(r"[a-f0-9]{64}", s) is not None

In [13]:
%%test Code signature is a SHA-256 hash
# signature_code returns a 64-character lowercase hex digest.
def fn(x):
    return x * x

assert is_sha256(signature_code(fn))

Test [1mCode signature is a SHA-256 hash[0m passed.


In [14]:
%%test Code signature is distinct for functions with distinct code
# Functions whose source differs get different code signatures.
def f(x):
    return x * x

def g(x):
    return x + x

assert signature_code(f) != signature_code(g)

Test [1mCode signature is distinct for functions with distinct code[0m passed.


In [15]:
%%test Code signature does not change for two functions with the same code
# Two closures built by the same factory share their source text, hence their code signature,
# even though they close over different values (those are covered by signature_env).
from inspect import getsource

def make_fn(z):
    def f(x):
        y = x * x
        return y * y + 0.5 * y - z
    return f

f1 = make_fn(1)
f2 = make_fn(2)
assert f1 is not f2
assert getsource(f1) == getsource(f2)
assert signature_code(f1) == signature_code(f2)

Test [1mCode signature does not change for two functions with the same code[0m passed.


In [16]:
%%test Code signature is not impacted by blank lines
# f is defined twice, the second time with an extra blank line: the sources differ, but the code
# signatures must not, since signature_code drops blank lines.
from inspect import getsource

def f(x):
    return x * x

f1 = f

def f(x):
    
    return x * x

f2 = f
assert f1 is not f2
assert getsource(f1) != getsource(f2)
assert signature_code(f1) == signature_code(f2)

Test [1mCode signature is not impacted by blank lines[0m passed.


In [17]:
%%test For functions that have no code, we sign the function's name
# int has no Python source: its code signature is the SHA-256 of its name.
from hashlib import sha256
assert signature_code(int) == sha256(b"int").hexdigest()

Test [1mFor functions that have no code, we sign the function's name[0m passed.


### Signing arguments of a function

In [18]:
from hashlib import sha256
from typing import Any, Sequence, Mapping, Optional


def normalize_env(env: Mapping[str, Any]) -> Mapping[str, Any]:
    """Return a plain dict copy of ``env`` with its entries ordered by key.

    Ordering makes the ``repr`` of the mapping independent of the order entries were inserted in.
    """
    ordered_items = sorted(env.items())
    return {name: value for name, value in ordered_items}


def bytes_repr(x: Any) -> bytes:
    """Return the ``repr`` of ``x`` encoded as UTF-8 bytes, ready to be fed to a hash."""
    text = repr(x)
    return text.encode("utf-8")


def signature_args(args: Sequence[Any], kwargs: Mapping[str, Any], h: Optional[sha256] = None) -> str:
    """Fold a call's arguments into a SHA-256 hash and return its hex digest.

    Positional arguments are hashed through their ``repr``, so their order matters; keyword arguments
    are first sorted by name, so the order they were passed in does not. When ``h`` is given, it is
    updated in place so that callers can chain several signatures into a single hash.
    """
    hasher = h if h else sha256()
    positional = bytes_repr(args)
    keywords = bytes_repr(normalize_env(kwargs))
    for chunk in (positional, keywords):
        hasher.update(chunk)
    return hasher.hexdigest()

In [19]:
%%test Get a signature for empty argument lists
# Even without any argument, signature_args yields a SHA-256 hex digest.
assert is_sha256(signature_args([], {}))

Test [1mGet a signature for empty argument lists[0m passed.


In [20]:
%%test Same signature for same argument lists
# Equal arguments give equal signatures; the order of keyword arguments does not matter.
assert signature_args(["asdf"], {}) == signature_args(["asdf"], {})
assert signature_args([], dict(x=552)) == signature_args([], dict(x=552))
assert signature_args(["some/path", (3, "tuple")], dict(asdf=45, qwer=98, zxcv=234)) == \
    signature_args(["some/path", (3, "tuple")], dict(qwer=98, asdf=45, zxcv=234))

Test [1mSame signature for same argument lists[0m passed.


In [21]:
%%test Arg signature distinct for distinct positional arg lists (although identical keyword args)
# Positional arguments matter, including their order.
assert signature_args(["asdf"], {}) != signature_args([], {})
assert signature_args(["asdf", "qwer"], dict(x=32, y=56)) != signature_args(["qwer", "asdf"], dict(y=56, x=32))

Test [1mArg signature distinct for distinct positional arg lists (although identical keyword args)[0m passed.


In [22]:
%%test Arg signature distinct for distinct keyword arg lists (although identical positional args)
# Keyword names, keyword values and the set of keywords passed all matter.
assert signature_args([], dict(x=32)) != signature_args([], dict(y=32))
assert signature_args([], dict(x=32)) != signature_args([], dict(x=33))
assert signature_args(["asdf", "qwer"], dict(x=12, y=32)) != signature_args(["asdf", "qwer"], dict(x=12, y=32, z=45))

Test [1mArg signature distinct for distinct keyword arg lists (although identical positional args)[0m passed.


### Signature for relevant environment

In [23]:
from inspect import getclosurevars, isfunction, unwrap


def _env_value(value: Any) -> Any:
    """Return a stand-in for environment ``value`` whose ``repr`` is stable across kernel sessions.

    A plain function ``repr``s as ``<function name at 0x...>``: its memory address differs in every
    session, so hashing it would prevent stored results from ever being reused later. It is replaced
    by its qualified name and code signature, which also invalidates results when the referenced
    function's code changes. Any other value is returned as is.
    """
    if not isfunction(value):
        return value
    # Sign the function underneath decorators that set __wrapped__ (functools.wraps).
    target = unwrap(value)
    try:
        code = signature_code(target)
    except OSError:
        code = ""  # source unavailable: the qualified name is the best stable identity left
    return f"<function {target.__module__}.{target.__qualname__} {code}>"


def signature_env(fn: Callable, h: Optional[sha256] = None) -> str:
    """Fold the closure and global variables read by ``fn`` into a SHA-256 hash; return its hex digest.

    Nonlocals, then globals, are each hashed as the ``repr`` of a key-sorted dict. Functions among them
    are represented by name and code (see ``_env_value``) rather than by their address-bearing ``repr``.
    Callables without Python-level variables (built-ins such as ``int``, on which ``getclosurevars``
    raises ``TypeError``) hash as two empty environments. When ``h`` is given, it is updated in place
    so that callers can chain several signatures into a single hash.
    """
    h = h or sha256()
    try:
        cv = getclosurevars(fn)
        environments = (cv.nonlocals, cv.globals)
    except TypeError:
        environments = ({}, {})
    for env in environments:
        stable = {name: _env_value(value) for name, value in env.items()}
        h.update(bytes_repr(normalize_env(stable)))
    return h.hexdigest()

In [24]:
%%test Env signature is SHA256 even when no closure nor global var
# f reads neither closure nor global variables; its env signature is still a valid digest.
def f():
    return "asdf"

cv = getclosurevars(f)
assert len(cv.nonlocals) == 0
assert len(cv.globals) == 0
assert is_sha256(signature_env(f))

Test [1mEnv signature is SHA256 even when no closure nor global var[0m passed.


In [25]:
%%test Env signature for distinct functions with same closures is the same
# Two closures built by the same factory over equal values have the same env signature.
def make_f(y):
    def f(x):
        return x + y
    return f
    
f1 = make_f(1)
f2 = make_f(1)
assert f1 is not f2 and f1 != f2
assert signature_env(f1) == signature_env(f2)

Test [1mEnv signature for distinct functions with same closures is the same[0m passed.


In [26]:
%%test Env signature for distinct functions with same globals is the same
# f and g read the same global G, which is set for the test only and deleted afterwards.
def f():
    return 56 + G

def g():
    return 23 - G
    
try:
    globals()["G"] = 8
    assert signature_env(f) == signature_env(g)
finally:
    del globals()["G"]

Test [1mEnv signature for distinct functions with same globals is the same[0m passed.


In [27]:
%%test Env signatures for function with distinct closures are distinct
# Closures over different values have different env signatures.
def make_f(y):
    def f(x):
        return x + y
    return f
    
f1 = make_f(1)
f2 = make_f(2)
assert signature_env(f1) != signature_env(f2)

Test [1mEnv signatures for function with distinct closures are distinct[0m passed.


In [28]:
%%test Env signatures for same function but with distinct global bindings are distinct
# Rebinding the global G that f reads changes f's env signature.
def f():
    return G + 8
    
try:
    globals()["G"] = 10
    sig1 = signature_env(f)
    globals()["G"] = 20
    sig2 = signature_env(f)
    assert sig1 != sig2
finally:
    del globals()["G"]

Test [1mEnv signatures for same function but with distinct global bindings are distinct[0m passed.


### Signature of a full function call

In [29]:
from typing import Callable, Sequence, Mapping, Any


def signature_call(fn: Callable, args: Sequence[Any], kwargs: Mapping[str, Any]) -> str:
    """Sign the complete call ``fn(*args, **kwargs)`` and return the SHA-256 hex digest.

    The code of ``fn``, then the arguments, then the environment of ``fn`` are all folded into one
    running hash, so a change in any of them changes the resulting signature.
    """
    running = sha256()
    signature_code(fn, h=running)
    signature_args(args, kwargs, h=running)
    signature_env(fn, h=running)
    digest = running.hexdigest()
    return digest

In [30]:
%%test Combined signature is a SHA256 hash
# signature_call on a lambda called without arguments yields a SHA-256 hex digest.
assert is_sha256(signature_call(lambda x: x, [], {}))

Test [1mCombined signature is a SHA256 hash[0m passed.


In [31]:
%%test Combined signature is distinct from components
# The combined signature hashes the three components together, so it equals none of them alone.
def f():
    return "asdf"

assert signature_call(f, [], {}) != signature_code(f)
assert signature_call(f, [], {}) != signature_args([], {})
assert signature_call(f, [], {}) != signature_env(f)

Test [1mCombined signature is distinct from components[0m passed.


In [32]:
%%test Same combined signature for same code, args and env
# Same code (same factory), same closure value z and same arguments give the same call signature.
def make_f(z):
    def f(x, y=1.0):
        return x * x / y - z ** 2
    return f
    
f1 = make_f(100)
f2 = make_f(100)
assert f1 is not f2 and f1 != f2
from inspect import getclosurevars
for f in [f1, f2]:
    assert getclosurevars(f).nonlocals.get("z") == 100
assert signature_call(f1, [8], {"y": 2}) == signature_call(f2, [8], {"y": 2})

Test [1mSame combined signature for same code, args and env[0m passed.


In [33]:
%%test Distinct combined signatures for functions with distinct code, but same args and env
# f and g differ only in their code; the arguments and the global G they read are identical.
def f(x, y):
    return x + y + G

def g(x, y):
    return x * y + G


try:
    globals()["G"] = 90
    assert signature_call(f, [3, 4], {}) != signature_call(g, [3, 4], {})
finally:
    del globals()["G"]

Test [1mDistinct combined signatures for functions with distinct code, but same args and env[0m passed.


In [34]:
%%test Distinct combined signatures for function with distinct args, but same code and env
# f1 and f2 share code and env; only the keyword names passed to them differ.
def make_f(z):
    def f(x, **kwargs):
        return x * z - sum(kwargs.values())
    return f
    
f1 = make_f(10)
f2 = make_f(10)
assert signature_code(f1) == signature_code(f2)
assert signature_env(f1) == signature_env(f2)
assert signature_call(f1, [4], dict(u=8, o=9)) != signature_call(f2, [4], dict(r=8, p=9))

Test [1mDistinct combined signatures for function with distinct args, but same code and env[0m passed.


In [35]:
%%test Distinct combined signatures for function with distinct env, but same code and args
# Same function and arguments; only the global T changes between the two signatures.
def make_f(z):
    def f(x, y):
        return x + y - z * T
    return f
    
try:
    f = make_f(10)
    globals()["T"] = 1
    sig1 = signature_call(f, [10, 21], {})
    globals()["T"] = 2
    sig2 = signature_call(f, [10, 21], {})
    assert sig1 != sig2
finally:
    del globals()["T"]

Test [1mDistinct combined signatures for function with distinct env, but same code and args[0m passed.


## Memoization

In [36]:
from dask.delayed import Delayed
from typing import Callable, Any, Optional


def memo(fn: Optional[Callable] = None, serde: Optional[Serde] = None) -> Callable:
    """Memoize the calls of ``fn`` through ``serde``; usable as ``@memo`` or ``@memo(serde=...)``.

    Each call is identified by ``signature_call(fn, args, kwargs)``. If ``serde`` already holds a
    result under that signature, it is read back; otherwise ``fn`` is called, and its result is
    written through ``serde`` and returned. Without ``serde``, a default ``Serde()`` is used.
    """
    from functools import wraps

    # ``is None`` rather than ``or``: a serde defining __len__/__bool__ may legitimately be falsy
    # (e.g. while empty) and must not be silently replaced by the default one.
    serde = serde if serde is not None else Serde()

    def _process(fn: Callable) -> Callable:
        # wraps keeps fn's name and docstring on the wrapper and exposes fn as __wrapped__.
        @wraps(fn)
        def memoized(*args, **kwargs) -> Any:
            sig = signature_call(fn, args, kwargs)
            if serde.exists(sig):
                return serde.read(sig)
            return serde.write(sig, fn(*args, **kwargs))
        return memoized

    if fn is None:
        return _process
    return _process(fn)

In [37]:
from typing import Tuple


class SerdeTest(Serde):
    """In-memory stand-in for ``Serde`` that counts every ``exists``, ``read`` and ``write`` call.

    Results are kept in a dict keyed by signature, so tests can check how memoization used the serde
    without touching the disk.
    """

    def __init__(self) -> None:
        self._results = {}
        self.num_exists, self.num_read, self.num_write = 0, 0, 0

    @property
    def num_ops(self) -> Tuple[int, int, int]:
        """Number of ``(exists, read, write)`` calls made so far."""
        counts = (self.num_exists, self.num_read, self.num_write)
        return counts

    def exists(self, sig: str) -> bool:
        """Count the call, then tell whether a result is held for ``sig``."""
        self.num_exists += 1
        found = sig in self._results
        return found

    def read(self, sig: str) -> Any:
        """Count the call, then return the result held for ``sig``."""
        self.num_read += 1
        stored = self._results[sig]
        return stored

    def write(self, sig: str, obj: Any) -> Any:
        """Count the call, hold ``obj`` under ``sig`` and return it."""
        self.num_write += 1
        self._results.update({sig: obj})
        return obj

In [38]:
%%test Successful memoization
# Ten calls with identical arguments: the function body runs once (1 write), the 9 other calls read
# the held result back, and exists() is checked on every call.
mock_function = MagicMock()
serde_test = SerdeTest()

def make_f(z):
    global serde_test
    
    @memo(serde=serde_test)
    def f(x, *args, **kwargs):
        global mock_function
        mock_function(x, *args, **kwargs)
        return x * (sum(args) + sum(kwargs.values())) - z

    return f

f = make_f(1)
y = 0
for _ in range(10):
    y += f(2, 5, 6, p=2, q=3)
mock_function.assert_called_once_with(2, 5, 6, p=2, q=3)
assert serde_test.num_ops == (10, 9, 1)

Test [1mSuccessful memoization[0m passed.


In [39]:
%%test Memoized result not reused when distinct args
# Three calls differing in one positional value or one keyword name: each is computed and written.
mock_function = MagicMock()
serde_test = SerdeTest()

@memo(serde=serde_test)
def f(x, y, **kwargs):
    global mock_function
    mock_function(x, y, **kwargs)
    return x * sum(kwargs.values()) - y

results = [f(4, 3, p=2, q=5), f(4, 1, p=2, q=5), f(4, 3, r=2, q=5)]
assert mock_function.mock_calls == [call(4, 3, p=2, q=5), call(4, 1, p=2, q=5), call(4, 3, r=2, q=5)]
assert serde_test.num_read == 0
assert serde_test.num_write == 3

Test [1mMemoized result not reused when distinct args[0m passed.


In [40]:
%%test Memoized result not reused when distinct env
# f1 and f2 differ in their closure value z, then f1 is called again after global G changes:
# all three calls are computed and written, none is read back.
serde_test = SerdeTest()

def make_f(z):
    global serde_test
    @memo(serde=serde_test)
    def f(x):
        return x + z - G
    return f
    
try:
    globals()["G"] = 10
    f1 = make_f(20)
    f1(10)
    f2 = make_f(10)
    f2(10)
    globals()["G"] = 8
    f1(10)
finally:
    del globals()["G"]

assert serde_test.num_write == 3
assert serde_test.num_read == 0

Test [1mMemoized result not reused when distinct env[0m passed.


In [41]:
%%test Memoized result not reused when code changed
# f is redefined under the same name with different code; both versions are computed and written.
serde_test = SerdeTest()

try:
    globals()["G"] = 2
    
    @memo(serde=serde_test)
    def f(x, y):
        return (x + y) * G

    f(3, 4)

    @memo(serde=serde_test)
    def f(x, y):
        return (x - y) * G
    
    f(3, 4)
finally:
    del globals()["G"]

assert serde_test.num_write == 2
assert serde_test.num_read == 0

Test [1mMemoized result not reused when code changed[0m passed.


# Final test results

In [42]:
# When run as the main notebook, print the tally of passed, failed and errored tests registered on `suite`.
if __name__ == "__main__":
    summarize_results(suite)

27 passed, [37m0 failed[0m, [37m0 raised an error[0m
