In [1]:
%load_ext pycodestyle_magic
%flake8_on --max_line_length 120 --ignore W293,E302

# Memoization to disk

These tools provide functions to facilitate the memoization of certain computations, when the cost of their execution exceeds that of their storage and serialization.

The current approach aims to keep the storage of results separate from the check for whether a result has already been memoized.

In [2]:
import notebooks_as_modules

In [3]:
from jupytest import Suite, Report, Magic, summarize_results, fail, assert_, eq, diff, not_, same
from unittest.mock import patch, MagicMock, call

In [4]:
# Test suite shared by the cells below; its results are summarized in the last cell of the notebook.
suite = Suite()
if __name__ == "__main__":
    # Only when the notebook is run directly: attach the Report and Magic components to the suite.
    # NOTE(review): presumably Magic() is what makes the `%%test` cell magic available -- confirm in jupytest.
    suite |= Report()
    suite |= Magic()

## *Serde* -- Serializer-deserializer

Most results can be pickled to disk; let's still make the serde a replaceable part, in case we meet results that cannot.

If results are to be stored to disk, the serde is also responsible for deciding where. Let's make this class-level functionality.

In [5]:
from growing import growing
import gzip
import io
import os
from os.path import realpath, join
import pickle
from typing import Any


@growing
class Serde:
    """
    Serializer-deserializer for memoized results; also decides where results are stored.

    Each result is kept as a gzip-compressed pickle, in a file named after the result's
    signature under the class-level directory ``DIR_STORE``. Subclasses can override
    ``from_file``/``to_file`` to change the format, or ``exists``/``read``/``write``
    to change the storage medium altogether.
    """

    # Directory under which result files live; a class attribute, hence shared by all instances.
    DIR_STORE = realpath(".")

    @classmethod
    def path_to_result(cls, sig: str) -> str:
        """Return the path of the file that holds the result with signature `sig`."""
        return join(cls.DIR_STORE, sig)

    def exists(self, sig: str) -> bool:
        """Tell whether a readable file exists at the path associated to signature `sig`."""
        return os.access(self.path_to_result(sig), os.R_OK)

    def read(self, sig: str) -> Any:
        """Load and return the result stored under signature `sig`."""
        with self._open(sig, "rb") as file:
            return self.from_file(file)

    def from_file(self, file: io.RawIOBase) -> Any:
        """Deserialize one object from the open binary `file`."""
        # SECURITY: unpickling can execute arbitrary code. Only point DIR_STORE at
        # directories whose contents are trusted.
        return pickle.load(file)

    def write(self, sig: str, obj: Any) -> Any:
        """
        Store `obj` under signature `sig` and return `obj`, so the call can be chained.

        If serialization fails or is interrupted, the partially written file is removed
        before the exception propagates: otherwise `exists` would later report the
        truncated file as a valid result and `read` would fail on it.
        """
        # Opening in "wb" mode truncates any previous file, so from here on the file at
        # this path is ours to clean up.
        handle = self._open(sig, "wb")
        try:
            with handle as file:
                self.to_file(file, obj)
        except BaseException:
            # BaseException on purpose: a KeyboardInterrupt in the middle of a dump
            # leaves a truncated file just as a pickling error does.
            try:
                os.remove(self.path_to_result(sig))
            except OSError:
                pass  # Nothing to remove, or not removable: keep the original error.
            raise
        return obj

    def to_file(self, file: io.RawIOBase, obj: Any) -> None:
        """Serialize `obj` into the open binary `file`."""
        pickle.dump(obj, file)

    def _open(self, sig: str, mode: str) -> io.RawIOBase:
        """Open the gzip-compressed file associated to signature `sig` in the given `mode`."""
        return gzip.open(self.path_to_result(sig), mode)

In [6]:
%%test Results will live in the store directory
from os.path import dirname
# Walk up from each result path, one parent at a time, until the store directory is met.
for sig in ["some_result", "some/deeper/result"]:
    path = Serde.path_to_result(sig)
    while path != Serde.DIR_STORE:
        parent = dirname(path)
        if parent == path:
            # dirname() is a fixed point at the filesystem root (and on the empty string):
            # we climbed all the way up without meeting the store directory.
            fail(f"Did not get the storage directory anywhere along {Serde.path_to_result(sig)}")
            break  # fail() is expected to raise; this only guards against looping forever if it returns.
        path = parent

Test [1mResults will live in the store directory[0m passed.


In [7]:
%%test Serializing a result
# gzip.open is patched for the duration of the test; the mock records how it gets called.
with patch("gzip.open") as mock:
    # write() hands back the very object it was asked to store.
    assert_(eq, expected="asdf", actual=Serde().write("some_result", "asdf"))
    # The file is opened once, for binary writing, at the path derived from the signature...
    mock.assert_called_once_with(Serde.path_to_result("some_result"), "wb")
    # ...and the object yielded by the context manager receives the pickled object in a single write.
    mock.return_value.__enter__.return_value.write.assert_called_once_with(pickle.dumps("asdf"))

Test [1mSerializing a result[0m passed.


In [8]:
%%test Deserializing a result
# Both the file opening and the unpickling are mocked: only the plumbing between them is checked.
with patch("gzip.open") as mock_open, patch("pickle.load", return_value="qwerty") as mock_load:
    # read() returns whatever pickle.load produced...
    assert_(eq, expected="qwerty", actual=Serde().read("known_result"))
    # ...and pickle.load was given the object yielded by the opened file's context manager.
    mock_load.assert_called_once_with(mock_open.return_value.__enter__.return_value)

Test [1mDeserializing a result[0m passed.


### Result store

The store is essentially a directory where pickled results live.

In [9]:
import os
from os.path import realpath

@Serde.classmethod
def set_store_directory(cls, path: os.PathLike) -> None:
    """
    Make `path` the directory where results are stored.

    The path is resolved through `realpath` when this is called, and the resolved
    path is kept in the class attribute `DIR_STORE`.
    """
    dir_resolved = realpath(path)
    cls.DIR_STORE = dir_resolved

The following context manager switches the store directory only temporarily: the previous directory is restored on exit.

In [10]:
from contextlib import contextmanager
from typing import ContextManager


@Serde.classmethod(wrapped_in=[contextmanager])
def storing_in_directory(cls, path: os.PathLike) -> ContextManager[None]:
    """
    Context manager: store results under `path` for the duration of the `with` block.

    The value `DIR_STORE` had on entry is set again on exit, whether the block
    finishes normally or raises.
    """
    dir_on_entry = cls.DIR_STORE
    try:
        cls.set_store_directory(path)
        yield None
    finally:
        # Goes through the setter as well, so the restored value is resolved the same way.
        cls.set_store_directory(dir_on_entry)

In [11]:
%%test Temporary switch of the storing directory
# Remember the store directory in effect before the switch.
dir_at_first = Serde.DIR_STORE
with Serde.storing_in_directory(".."):
    # Inside the block, the store is the resolved form of the requested directory.
    assert_(eq, expected=realpath(".."), actual=Serde.DIR_STORE)
# Once the block is left, the initial directory is back.
assert_(eq, actual=Serde.DIR_STORE, expected=dir_at_first)

Test [1mTemporary switch of the storing directory[0m passed.


## Function call signatures

Signatures should incorporate as much as possible of what determines the computation, so that a change in any of these elements prevents a stored result from being reused unduly. We shall take into account:

1. The input parameter names and values: we would consider their most detailed representation as issued by `repr`.
1. The computation's implementation: we will take the source as is, but discard blank lines and trailing whitespace, so that such cosmetic edits do not invalidate stored results.
1. The state of global and closure variables upon function entry.

### Signing the code of a function

In [12]:
from hashlib import sha256
from inspect import getsourcelines
from typing import Callable, Optional


def signature_code(f: Callable, h: Optional[sha256] = None) -> str:
    """
    Sign the implementation of `f`.

    The signature covers, in order of preference:

    1. the source lines of `f`, minus blank lines and trailing whitespace;
    2. when `f` is Python code whose source text cannot be retrieved (`OSError` from
       `getsourcelines`): its compiled code, so that functions with distinct bodies
       still get distinct signatures;
    3. when `f` has no Python code at all (`TypeError` from `getsourcelines`, as for
       builtins): its name.

    Args:
        f: the callable to sign.
        h: hash object to update; a fresh SHA-256 is created when omitted. Passing one
            allows chaining several signatures into a single digest.

    Returns:
        The hexadecimal digest of `h` after the update.
    """
    if h is None:
        h = sha256()
    try:
        source, _ = getsourcelines(f)
    except TypeError:
        # No Python source exists for this object (e.g. a builtin): sign its name.
        h.update(f.__name__.encode("utf-8"))
    except OSError:
        # Source exists in principle but cannot be read back. Signing the name alone
        # would make any two such functions with the same name collide, so sign the
        # compiled code instead.
        code = getattr(f, "__code__", None)
        if code is None:
            raise
        h.update(code.co_code)
        h.update(repr(code.co_consts).encode("utf-8"))
        h.update(repr(code.co_names).encode("utf-8"))
    else:
        for line in source:
            if line.strip():
                h.update(line.rstrip().encode("utf-8"))
    return h.hexdigest()

In [13]:
def is_sha256(s: str) -> bool:
    """
    Tell whether `s` looks like a SHA-256 hex digest: exactly 64 lowercase hexadecimal characters.

    `fullmatch` is used rather than `match` with a `$` anchor, because `$` also matches
    just before a trailing newline and would thus accept a 65-character string.
    """
    import re
    return re.fullmatch(r"[a-f0-9]{64}", s) is not None

In [14]:
%%test Code signature is a SHA-256 hash
def fn(x):
    return x * x

# The signature of an ordinary function has the form of a SHA-256 hex digest.
assert_(is_sha256, signature_code(fn))

Test [1mCode signature is a SHA-256 hash[0m passed.


In [15]:
%%test Code signature is distinct for functions with distinct code
def f(x):
    return x * x

def g(x):
    return x + x

# f and g differ both by name and by body, so their code signatures must differ.
assert_(diff, signature_f=signature_code(f), signature_g=signature_code(g))

Test [1mCode signature is distinct for functions with distinct code[0m passed.


In [16]:
%%test Code signature does not change for two functions with the same code
from inspect import getsource

def make_fn(z):
    def f(x):
        y = x * x
        return y * y + 0.5 * y - z
    return f

# Two closures out of the same factory: distinct function objects that share their source text.
f1 = make_fn(1)
f2 = make_fn(2)
assert_(not_(same), f1=f1, f2=f2)
assert_(eq, source_f1=getsource(f1), source_f2=getsource(f2))
# The closure value (z) differs, but it is not part of the code signature.
assert_(eq, signature_f1=signature_code(f1), signature_f2=signature_code(f2))

Test [1mCode signature does not change for two functions with the same code[0m passed.


In [17]:
%%test Code signature is not impacted by blank lines
from inspect import getsource

def f(x):
    return x * x

f1 = f

def f(x):
    
    return x * x

f2 = f
# f1 and f2 are distinct function objects: the second definition rebinds the name f.
assert_(not_(same), f1=f1, f2=f2)

# Sanity check of the setup: the two sources differ, and only by the blank line
# (line index 1 of the second source), which is dropped to rebuild the first source.
s1 = getsource(f1)
s2 = getsource(f2)
assert_(diff, s1=s1, s2=s2)
s2_mod = "\n".join(line for i, line in enumerate(s2.split("\n")) if i in {0, 2}) + "\n"
assert_(eq, s1=s1, s2_mod=s2_mod)

# The blank line makes no difference to the signature.
assert_(eq, signature_f1=signature_code(f1), signature_f2=signature_code(f2))

Test [1mCode signature is not impacted by blank lines[0m passed.


In [18]:
%%test For functions that have no code, we sign the function's name
from hashlib import sha256
# `int` has no Python source: its signature is the SHA-256 of its name.
assert_(eq, actual=signature_code(int), expected=sha256(b"int").hexdigest())

Test [1mFor functions that have no code, we sign the function's name[0m passed.


### Signing arguments of a function

In [19]:
from hashlib import sha256
from typing import Any, Sequence, Mapping, Optional


def normalize_env(env: Mapping[str, Any]) -> Mapping[str, Any]:
    """
    Return a copy of `env` as a dict whose entries are ordered by key.

    Two mappings with the same entries thus get the same `repr`, whatever the order
    in which the entries were inserted.
    """
    return {name: env[name] for name in sorted(env)}


def bytes_repr(x: Any) -> bytes:
    """Return the `repr` of `x`, encoded as UTF-8 bytes, ready to feed to a hash object."""
    return repr(x).encode("utf-8")


def signature_args(args: Sequence[Any], kwargs: Mapping[str, Any], h: Optional[sha256] = None) -> str:
    """
    Sign the arguments of a call.

    The hash is fed the `repr` of the positional arguments, then the `repr` of the
    keyword arguments ordered by name, so the order in which keyword arguments are
    given does not matter.

    Args:
        args: positional arguments of the call.
        kwargs: keyword arguments of the call.
        h: hash object to update; a fresh SHA-256 is created when omitted.

    Returns:
        The hexadecimal digest of the hash after the update.
    """
    digest = sha256() if h is None else h
    positional = bytes_repr(args)
    digest.update(positional)
    keyword = bytes_repr(normalize_env(kwargs))
    digest.update(keyword)
    return digest.hexdigest()

In [20]:
%%test Get a signature for empty argument lists
# Even with no argument at all, the signature has the form of a SHA-256 hex digest.
assert_(is_sha256, signature_args([], {}))

Test [1mGet a signature for empty argument lists[0m passed.


In [21]:
%%test Same signature for same argument lists
# Same positional arguments, same signature.
assert_(eq, signature_args(["asdf"], {}), signature_args(["asdf"], {}))
# Same keyword arguments, same signature.
assert_(eq, signature_args([], dict(x=552)), signature_args([], dict(x=552)))
# The order in which keyword arguments are given does not affect the signature.
assert_(
    eq,
    signature_args(["some/path", (3, "tuple")], dict(asdf=45, qwer=98, zxcv=234)),
    signature_args(["some/path", (3, "tuple")], dict(qwer=98, asdf=45, zxcv=234))
)

Test [1mSame signature for same argument lists[0m passed.


In [22]:
%%test Arg signature distinct for distinct positional arg lists (although identical keyword args)
# One positional argument against none.
assert_(diff, signature_args(["asdf"], {}), signature_args([], {}))
# Same positional values in a different order: the order of positional arguments matters.
assert_(diff, signature_args(["asdf", "qwer"], dict(x=32, y=56)), signature_args(["qwer", "asdf"], dict(y=56, x=32)))

Test [1mArg signature distinct for distinct positional arg lists (although identical keyword args)[0m passed.


In [23]:
%%test Arg signature distinct for distinct keyword arg lists (although identical positional args)
# Same value under a different name.
assert_(diff, signature_args([], dict(x=32)), signature_args([], dict(y=32)))
# Same name bound to a different value.
assert_(diff, signature_args([], dict(x=32)), signature_args([], dict(x=33)))
# One extra keyword argument.
assert_(
    diff,
    signature_args(["asdf", "qwer"], dict(x=12, y=32)),
    signature_args(["asdf", "qwer"], dict(x=12, y=32, z=45))
)

Test [1mArg signature distinct for distinct keyword arg lists (although identical positional args)[0m passed.


### Signature for relevant environment

In [24]:
from inspect import getclosurevars


def signature_env(fn: Callable, h: Optional[sha256] = None) -> str:
    """
    Sign the environment `fn` depends on: the nonlocal (closure) and global variables
    it refers to, as reported by `inspect.getclosurevars`.

    Each group of variables is ordered by name and signed through its `repr`.
    Callables that are not Python functions or methods (e.g. builtins such as `int`)
    have no such variables to inspect: their environment is signed as empty, which is
    consistent with `signature_code` signing these callables by name instead of failing.

    Args:
        fn: the callable whose environment is signed.
        h: hash object to update; a fresh SHA-256 is created when omitted.

    Returns:
        The hexadecimal digest of the hash after the update.
    """
    if h is None:
        h = sha256()
    try:
        cv = getclosurevars(fn)
        nonlocal_vars, global_vars = cv.nonlocals, cv.globals
    except TypeError:
        # getclosurevars() rejects anything that is not a Python function or method.
        nonlocal_vars, global_vars = {}, {}
    h.update(bytes_repr(normalize_env(nonlocal_vars)))
    h.update(bytes_repr(normalize_env(global_vars)))
    return h.hexdigest()

In [25]:
%%test Env signature is SHA256 even when no closure nor global var
def f():
    return "asdf"

# Sanity check of the setup: f refers to no closure variable and no global variable.
cv = getclosurevars(f)
assert_(eq, actual=len(cv.nonlocals), expected=0)
assert_(eq, actual=len(cv.globals), expected=0)
# The signature of this empty environment still has the form of a SHA-256 hex digest.
assert_(is_sha256, signature_env(f))

Test [1mEnv signature is SHA256 even when no closure nor global var[0m passed.


In [26]:
%%test Env signature for distinct functions with same closures is the same
def make_f(y):
    def f(x):
        return x + y
    return f
    
# Two closures built with the same value for y: distinct, unequal function objects...
f1 = make_f(1)
f2 = make_f(1)
assert_(not_(same), f1=f1, f2=f2)
assert_(diff, f1=f1, f2=f2)
# ...whose environments are signed identically.
assert_(eq, signature_f1=signature_env(f1), signature_f2=signature_env(f2))

Test [1mEnv signature for distinct functions with same closures is the same[0m passed.


In [27]:
%%test Env signature for distinct functions with same globals is the same
def f():
    return 56 + G

def g():
    return 23 - G
    
# G is bound only for the duration of the check, then removed again.
try:
    globals()["G"] = 8
    # Different code, but the same single global variable with the same value.
    assert_(eq, signature_f=signature_env(f), signature_g=signature_env(g))
finally:
    del globals()["G"]

Test [1mEnv signature for distinct functions with same globals is the same[0m passed.


In [28]:
%%test Env signatures for function with distinct closures are distinct
def make_f(y):
    def f(x):
        return x + y
    return f
    
# Same code, but the closure variable y holds 1 in one function and 2 in the other.
f1 = make_f(1)
f2 = make_f(2)
assert_(diff, signature_f1=signature_env(f1), signature_f2=signature_env(f2))

Test [1mEnv signatures for function with distinct closures are distinct[0m passed.


In [29]:
%%test Env signatures for same function but with distinct global bindings are distinct
def f():
    return G + 8
    
# The same function is signed twice, with the global G rebound in between.
try:
    globals()["G"] = 10
    sig1 = signature_env(f)
    globals()["G"] = 20
    sig2 = signature_env(f)
    assert_(diff, sig1=sig1, sig2=sig2)
finally:
    # Remove G again once the check is done.
    del globals()["G"]

Test [1mEnv signatures for same function but with distinct global bindings are distinct[0m passed.


### Signature of a full function call

In [30]:
from typing import Callable, Sequence, Mapping, Any


def signature_call(fn: Callable, args: Sequence[Any], kwargs: Mapping[str, Any]) -> str:
    """
    Sign a full function call: the code of `fn`, the arguments of the call and the
    environment of `fn`, chained in that order into a single SHA-256 hash.

    Args:
        fn: the function being called.
        args: positional arguments of the call.
        kwargs: keyword arguments of the call.

    Returns:
        The hexadecimal digest of the combined hash.
    """
    combined = sha256()
    signature_code(fn, combined)
    signature_args(args, kwargs, combined)
    # The last signer returns the digest of the hash it was handed, i.e. of all three parts.
    return signature_env(fn, combined)

In [31]:
%%test Combined signature is a SHA256 hash
# The combined signature of a call without arguments has the form of a SHA-256 hex digest.
assert_(is_sha256,signature_call(lambda x: x, [], {}))

Test [1mCombined signature is a SHA256 hash[0m passed.


In [32]:
%%test Combined signature is distinct from components
def f():
    return "asdf"

# The combined signature equals none of the three partial signatures taken alone.
assert_(diff, signature_call(f, [], {}), signature_code(f))
assert_(diff, signature_call(f, [], {}), signature_args([], {}))
assert_(diff, signature_call(f, [], {}), signature_env(f))

Test [1mCombined signature is distinct from components[0m passed.


In [33]:
%%test Same combined signature for same code, args and env
def make_f(z):
    def f(x, y=1.0):
        return x * x / y - z ** 2
    return f
    
# Two distinct function objects with the same code and the same closure value.
f1 = make_f(100)
f2 = make_f(100)
assert_(not_(same), f1=f1, f2=f2)
assert_(diff, f1=f1, f2=f2)
from inspect import getclosurevars
# Sanity check of the setup: both closures do hold z == 100.
for f in [f1, f2]:
    assert_(eq, actual=getclosurevars(f).nonlocals.get("z"), expected=100)
# Called with the same arguments, both calls get the same signature.
assert_(eq, signature_call(f1, [8], {"y": 2}), signature_call(f2, [8], {"y": 2}))

Test [1mSame combined signature for same code, args and env[0m passed.


In [34]:
%%test Distinct combined signatures for functions with distinct code, but same args and env
def f(x, y):
    return x + y + G

def g(x, y):
    return x * y + G


# Both functions see the same global G and get the same arguments: only the code differs.
try:
    globals()["G"] = 90
    assert_(diff, signature_call(f, [3, 4], {}), signature_call(g, [3, 4], {}))
finally:
    del globals()["G"]

Test [1mDistinct combined signatures for functions with distinct code, but same args and env[0m passed.


In [35]:
%%test Distinct combined signatures for function with distinct args, but same code and env
def make_f(z):
    def f(x, **kwargs):
        return x * z - sum(kwargs.values())
    return f
    
f1 = make_f(10)
f2 = make_f(10)
# Sanity check of the setup: code and environment signatures are the same for both functions.
assert_(eq, signature_code(f1), signature_code(f2))
assert_(eq, signature_env(f1), signature_env(f2))
# Same keyword values under different names: the call signatures differ.
assert_(diff, signature_call(f1, [4], dict(u=8, o=9)), signature_call(f2, [4], dict(r=8, p=9)))

Test [1mDistinct combined signatures for function with distinct args, but same code and env[0m passed.


In [36]:
%%test Distinct combined signatures for function with distinct env, but same code and args
def make_f(z):
    def f(x, y):
        return x + y - z * T
    return f
    
# The same function is signed for the same arguments, with the global T rebound in between.
try:
    f = make_f(10)
    globals()["T"] = 1
    sig1 = signature_call(f, [10, 21], {})
    globals()["T"] = 2
    sig2 = signature_call(f, [10, 21], {})
    assert_(diff, sig1=sig1, sig2=sig2)
finally:
    # Remove T again once the check is done.
    del globals()["T"]

Test [1mDistinct combined signatures for function with distinct env, but same code and args[0m passed.


## Memoization

In [37]:
from dask.delayed import Delayed
from typing import Callable, Any, Optional


def _is_memoizing():
    return True


def memo(fn: Optional[Callable] = None, serde: Optional[Serde] = None) -> Callable:
    """
    Decorator that memoizes the results of a function through a serde.

    Usable bare (``@memo``) or with parameters (``@memo(serde=...)``). On each call of
    the decorated function, the call is signed (code, arguments and environment of the
    function); if the serde holds a result under that signature, that result is returned
    without running the function. Otherwise the function runs and its result is handed
    to the serde for storage. While `_is_memoizing()` is false, the function is simply
    called and the serde is left alone.

    Args:
        fn: the function to memoize; omitted when the decorator is used with parameters.
        serde: where results are looked up and stored; a default `Serde` when omitted.

    Returns:
        The memoizing wrapper when `fn` is given; otherwise a decorator that produces it.
    """
    from functools import wraps

    # Identity test rather than truthiness: a serde that happens to be falsy is still a serde.
    if serde is None:
        serde = Serde()

    def _process(fn: Callable) -> Callable:
        # Keep the name, docstring and signature of the wrapped function visible on the wrapper.
        @wraps(fn)
        def memoized(*args, **kwargs) -> Any:
            if not _is_memoizing():
                return fn(*args, **kwargs)
            sig = signature_call(fn, args, kwargs)
            if serde.exists(sig):
                return serde.read(sig)
            # write() returns the object it stored.
            return serde.write(sig, fn(*args, **kwargs))
        return memoized

    if fn is None:
        return _process
    return _process(fn)

In [38]:
from typing import Tuple


class SerdeTest(Serde):
    """
    In-memory stand-in for `Serde`, for tests: results are kept in a dictionary, and
    each of the three operations `exists`, `read` and `write` is counted.
    """

    def __init__(self) -> None:
        self.num_exists = 0
        self.num_read = 0
        self.num_write = 0
        # Signature -> stored result.
        self._results = {}

    @property
    def num_ops(self) -> Tuple[int, int, int]:
        """Numbers of calls so far, as a triple (exists, read, write)."""
        return self.num_exists, self.num_read, self.num_write

    def exists(self, sig: str) -> bool:
        """Count the call, then tell whether a result is held under `sig`."""
        self.num_exists += 1
        is_known = sig in self._results
        return is_known

    def read(self, sig: str) -> Any:
        """Count the call, then return the result held under `sig`."""
        self.num_read += 1
        stored = self._results[sig]
        return stored

    def write(self, sig: str, obj: Any) -> Any:
        """Count the call, keep `obj` under `sig`, and return `obj`."""
        self.num_write += 1
        self._results[sig] = obj
        return obj

In [39]:
%%test Successful memoization
# The mock records every actual execution of the function body; the test serde counts its operations.
mock_function = MagicMock()
serde_test = SerdeTest()

def make_f(z):
    global serde_test
    
    @memo(serde=serde_test)
    def f(x, *args, **kwargs):
        global mock_function
        mock_function(x, *args, **kwargs)
        return x * (sum(args) + sum(kwargs.values())) - z

    return f

f = make_f(1)
y = 0
# Ten identical calls in a row.
for _ in range(10):
    y += f(2, 5, 6, p=2, q=3)
# The body ran only once...
mock_function.assert_called_once_with(2, 5, 6, p=2, q=3)
# ...the serde was asked 10 times whether the result exists, wrote it once and read it back 9 times.
assert_(eq, actual=serde_test.num_ops, expected=(10, 9, 1))

Test [1mSuccessful memoization[0m passed.


In [40]:
%%test Memoized result not reused when distinct args
mock_function = MagicMock()
serde_test = SerdeTest()

@memo(serde=serde_test)
def f(x, y, **kwargs):
    global mock_function
    mock_function(x, y, **kwargs)
    return x * sum(kwargs.values()) - y

# Three calls that differ from one another by a positional value or by a keyword name.
results = [f(4, 3, p=2, q=5), f(4, 1, p=2, q=5), f(4, 3, r=2, q=5)]
# The body ran for each of them: no result was read back, three were written.
assert_(eq, actual=mock_function.mock_calls, expected=[call(4, 3, p=2, q=5), call(4, 1, p=2, q=5), call(4, 3, r=2, q=5)])
assert_(eq, actual=serde_test.num_read, expected=0)
assert_(eq, actual=serde_test.num_write, expected=3)

Test [1mMemoized result not reused when distinct args[0m passed.


In [41]:
%%test Memoized result not reused when distinct env
serde_test = SerdeTest()

def make_f(z):
    global serde_test
    @memo(serde=serde_test)
    def f(x):
        return x + z - G
    return f
    
# Three calls with the same argument, each under a distinct environment:
# (z=20, G=10), then (z=10, G=10), then (z=20, G=8).
try:
    globals()["G"] = 10
    f1 = make_f(20)
    f1(10)
    f2 = make_f(10)
    f2(10)
    globals()["G"] = 8
    f1(10)
finally:
    del globals()["G"]

# Each call stored its own result; none was read back.
assert_(eq, actual=serde_test.num_write, expected=3)
assert_(eq, actual=serde_test.num_read, expected=0)

Test [1mMemoized result not reused when distinct env[0m passed.


In [42]:
%%test Memoized result not reused when code changed
serde_test = SerdeTest()

# The function f is defined twice under the same name, with a different body each time,
# and called with the same arguments while the global G keeps the same value.
try:
    globals()["G"] = 2
    
    @memo(serde=serde_test)
    def f(x, y):
        return (x + y) * G

    f(3, 4)

    @memo(serde=serde_test)
    def f(x, y):
        return (x - y) * G
    
    f(3, 4)
finally:
    del globals()["G"]

# Each version stored its own result; none was read back.
assert_(eq, actual=serde_test.num_write, expected=2)
assert_(eq, actual=serde_test.num_read, expected=0)

Test [1mMemoized result not reused when code changed[0m passed.


## Suspending result memoization

In [43]:
IS_MEMOIZING: bool = True

    
def _is_memoizing():
    return IS_MEMOIZING

In [44]:
@contextmanager
def suspending_memoization():
    """
    Context manager: memoization is switched off for the duration of the `with` block.

    On exit, the `IS_MEMOIZING` switch gets back the value it had on entry, rather than
    being forced to True. Nested suspensions thus stay suspended until the outermost
    block is left, and suspending while memoization is already off does not switch it on.
    """
    global IS_MEMOIZING
    was_memoizing = IS_MEMOIZING
    IS_MEMOIZING = False
    try:
        yield
    finally:
        IS_MEMOIZING = was_memoizing

In [45]:
%%test Memoization suspension
# The mock records every actual execution of the function body.
mock = MagicMock()
serde_test = SerdeTest()

@memo(serde=serde_test)
def f(*args, **kwargs):
    global mock
    mock(*args, **kwargs)
    
# Memoization in effect: the second call reads the stored result instead of running the body.
f(8)
f(8)
mock.assert_called_once_with(8)
assert_(eq, actual=serde_test.num_read, expected=1)

# Memoization suspended: the body runs again and the serde is not read.
with suspending_memoization():
    f(8)
    assert_(eq, actual=serde_test.num_read, expected=1)
    assert_(eq, actual=mock.mock_calls, expected=[call(8), call(8)])
    
# Memoization back in effect: the stored result is read once more, and the body does not run.
f(8)
assert_(eq, actual=len(mock.mock_calls), expected=2)
assert_(eq, actual=serde_test.num_read, expected=2)

Test [1mMemoization suspension[0m passed.


# Final test results

In [46]:
# Only when the notebook is run directly: summarize the results gathered in the suite.
if __name__ == "__main__":
    summarize_results(suite)

29 passed, [37m0 failed[0m, [37m0 raised an error[0m
