## Iteration 1

In [52]:
from inspect import signature
from typing import Callable

from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Tuple, Union

@dataclass(frozen=True)
class FunctionCall:
    """Immutable record of one call: function name, positional args, frozen kwargs."""
    func_name: str
    args: tuple
    kwargs: frozenset

    def __repr__(self):
        """Render the call compactly, e.g. ``fib(300)`` for one positional argument.

        With keyword arguments the raw ``args`` tuple and ``kwargs`` frozenset
        are shown; without them the tuple is printed directly after the name,
        except that a single argument is shown without the trailing comma.
        """
        name, positional = self.func_name, self.args
        if len(self.kwargs) != 0:
            return f"{name}({positional}, {self.kwargs})"
        if len(positional) != 1:
            return f"{name}{positional}"
        return f"{name}({positional[0]})"

# Key of one cache entry: (positional args tuple, frozenset of the kwargs items).
ArgsHash = Tuple[Tuple, frozenset]

class CacheGraph:
    """Memoise decorated functions while recording which call needed which result.

    An instance is used as a decorator factory (``@cg()`` or
    ``@cg(storage_func)``). Each decorated call is identified by a
    ``FunctionCall`` and its result is kept in ``caches`` under the key
    ``(args, frozenset(kwargs.items()))``, so all arguments must be hashable.
    While calls run, the caller -> callee edges and the most recent caller of
    every call are recorded. ``optimize`` turns the latter into an eviction
    plan (``can_clear``) that later runs use to drop a result as soon as the
    caller recorded as its last consumer has finished.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Replace all recorded state, including the eviction plan, with empty containers."""
        self.stack: list[FunctionCall] = [] # what function is currently being called
        self.caches: defaultdict[str, dict[ArgsHash, Any]] = defaultdict(dict) # Results of function calls
        self.graph: defaultdict[FunctionCall, set[FunctionCall]] = defaultdict(set) # Call graph, graph[caller] = {callee1, callee2, ...}
        # Typically aggregated results for a function at a timestep.
        self.stored_results: defaultdict[str, dict[int, Any]] = defaultdict(dict)
        # What is the last function that needs the result of a function? Used to help in clearing the cache
        self.last_needed_by: dict[FunctionCall, FunctionCall] = {}
        # can_clear[caller] = [callee1, callee2, ...] means that caller can clear the cache of callee1 and callee2
        self.can_clear: dict[FunctionCall, list[FunctionCall]] = defaultdict(list)

    def check_if_cached(self, function_call: FunctionCall):
        """Return True when a result for exactly this call is currently cached."""
        # Test the name first so the lookup does not create an empty entry in the defaultdict.
        name_in_cache = function_call.func_name in self.caches
        return name_in_cache and (function_call.args, function_call.kwargs) in self.caches[function_call.func_name]

    def optimize(self):
        """Rebuild ``can_clear`` by inverting ``last_needed_by`` (callee -> caller)."""
        self.can_clear = defaultdict(list)
        for callee, caller in self.last_needed_by.items():
            self.can_clear[caller].append(callee)

    def __call__(self, storage_func: Union[Callable, None] = None):
        """Return a decorator that memoises a function through this graph.

        Args:
            storage_func: Optional callable applied to every freshly computed
                result; its output is saved in ``stored_results`` (see
                ``_store_result``).

        Returns:
            A decorator that wraps the function and returns a ``_Cache`` handle.
        """
        # Local import: the cell defining this class does not import ``wraps``.
        from functools import wraps

        def decorator_factory(func):
            # ``wraps`` keeps ``func.__name__`` on the wrapper. ``_Cache`` uses
            # that name to find the per-function cache, so without it the
            # handle would be bound to ``caches["wrapper"]``.
            @wraps(func)
            def wrapper(*args, **kwargs):
                frozen_kwargs = frozenset(kwargs.items())
                function_call = FunctionCall(func.__name__, args, frozen_kwargs)
                if self.stack:
                    caller = self.stack[-1]
                    self.graph[caller].add(function_call)
                    # Overwritten on every request, so the final value is the last caller.
                    self.last_needed_by[function_call] = caller
                if not self.check_if_cached(function_call):
                    self.stack.append(function_call)
                    try:
                        result = func(*args, **kwargs)
                        # ``.get`` avoids inserting an empty list for every call, and
                        # ``pop`` tolerates an entry that this run never cached.
                        for clearable_call in self.can_clear.get(function_call, ()):
                            self.caches[clearable_call.func_name].pop(
                                (clearable_call.args, clearable_call.kwargs), None
                            )
                        self.caches[func.__name__][(args, frozen_kwargs)] = result
                        self._store_result(storage_func, func, args, kwargs, result)
                    finally:
                        # Unwind even when ``func`` raises, so later calls are not
                        # attributed to a caller that is no longer running.
                        self.stack.pop()
                return self.caches[func.__name__][(args, frozen_kwargs)]
            decorator = _Cache(self, wrapper)
            return decorator
        return decorator_factory

    def _store_result(self, storage_func: Union[Callable, None], func: Callable, args: tuple, kwargs: dict, raw_result: Any):
        """Save ``storage_func(raw_result)`` under ``stored_results[func.__name__][timestep]``.

        Does nothing when ``storage_func`` is None. The timestep is the single
        positional integer argument of the call.
        """
        if storage_func is None:
            return
        # These conditions should not trigger, which is why we assert instead of raising an exception
        assert len(args) == 1 and isinstance(args[0], int)
        assert len(kwargs) == 0
        # store the processed result
        timestep = args[0]
        stored_result = storage_func(raw_result)
        self.stored_results[func.__name__][timestep] = stored_result

    def size(self):
        """Return the total number of cached results across all functions."""
        return sum(len(cache) for cache in self.caches.values())

class _Cache:
    def __init__(self, cache_graph: CacheGraph, func: Callable):
        self.cache = cache_graph.caches[func.__name__]
        self._func = func

    def __setitem__(self, key, value):
        if isinstance(key, int):
            self.cache[((key,), frozenset())] = value
        else:
            self.cache[(key, frozenset())] = value

    def __repr__(self):
        return f"<Cache Function: {self._func.__name__} Size: {len(self.cache)}>"
    
    def __call__(self, *args: Any, **kwds: Any) -> Any:
        return self._func(*args, **kwds)
    
    
def check_if_single_parameter_t(func: Callable):
    """Return True when ``func`` declares exactly one parameter and it is named ``t``."""
    # The parameter mapping iterates over names in declaration order.
    return list(signature(func).parameters) == ['t']

In [53]:
# Two-pass demo of plan-driven cache eviction.
cg = CacheGraph()
@cg()
def fib(n):
    if n <= 1:
        return n
    return fib(n-1) + fib(n-2)
# Pass 1: the eviction plan is still empty, so every result stays cached while
# the graph records which caller requested each call last.
fib(300)
# Invert the "last needed by" record into the per-caller eviction plan.
cg.optimize()
print(f"{cg.can_clear=}")
# reset() also replaces can_clear, so keep the plan and put it back afterwards.
can_clear = cg.can_clear
cg.reset()
cg.can_clear = can_clear
# Pass 2: same computation, now dropping each planned result once the call
# that owns it in the plan has finished.
fib(300)
# Number of results still cached after the evicting run.
print(cg.size())
# Uncomment to inspect what is left: cg.caches


cg.can_clear=defaultdict(<class 'list'>, {fib(300): [fib(299), fib(298)], fib(299): [fib(297)], fib(298): [fib(296)], fib(297): [fib(295)], fib(296): [fib(294)], fib(295): [fib(293)], fib(294): [fib(292)], fib(293): [fib(291)], fib(292): [fib(290)], fib(291): [fib(289)], fib(290): [fib(288)], fib(289): [fib(287)], fib(288): [fib(286)], fib(287): [fib(285)], fib(286): [fib(284)], fib(285): [fib(283)], fib(284): [fib(282)], fib(283): [fib(281)], fib(282): [fib(280)], fib(281): [fib(279)], fib(280): [fib(278)], fib(279): [fib(277)], fib(278): [fib(276)], fib(277): [fib(275)], fib(276): [fib(274)], fib(275): [fib(273)], fib(274): [fib(272)], fib(273): [fib(271)], fib(272): [fib(270)], fib(271): [fib(269)], fib(270): [fib(268)], fib(269): [fib(267)], fib(268): [fib(266)], fib(267): [fib(265)], fib(266): [fib(264)], fib(265): [fib(263)], fib(264): [fib(262)], fib(263): [fib(261)], fib(262): [fib(260)], fib(261): [fib(259)], fib(260): [fib(258)], fib(259): [fib(257)], fib(258): [fib(256)], fi

In [12]:
cg.caches

defaultdict(dict,
            {'fib': {((1,), frozenset()): 1,
              ((0,), frozenset()): 0,
              ((2,), frozenset()): 1,
              ((3,), frozenset()): 2,
              ((4,), frozenset()): 3,
              ((5,), frozenset()): 5,
              ((6,), frozenset()): 8,
              ((7,), frozenset()): 13,
              ((8,), frozenset()): 21,
              ((9,), frozenset()): 34,
              ((10,), frozenset()): 55,
              ((11,), frozenset()): 89,
              ((12,), frozenset()): 144,
              ((13,), frozenset()): 233,
              ((14,), frozenset()): 377,
              ((15,), frozenset()): 610,
              ((16,), frozenset()): 987,
              ((17,), frozenset()): 1597,
              ((18,), frozenset()): 2584,
              ((19,), frozenset()): 4181,
              ((20,), frozenset()): 6765,
              ((21,), frozenset()): 10946,
              ((22,), frozenset()): 17711,
              ((23,), frozenset()): 28657,
          

## Iteration 2

In [None]:
from collections import defaultdict
from dataclasses import dataclass
from functools import wraps
from typing import Any, Callable, Tuple, Union
from inspect import signature

@dataclass(frozen=True)
class FunctionCall:
    """Immutable description of one call to a cached function; used as a node of the call graph."""
    # Name of the called function (the decorators pass ``func.__name__``).
    func_name: str
    # Positional arguments exactly as the decorator received them.
    args: tuple
    # Keyword arguments, frozen by the decorators as ``frozenset(kwargs.items())``.
    kwargs: frozenset

# Key of one cache entry: (positional args tuple, frozenset of the kwargs items).
ArgsHash = Tuple[Tuple, frozenset]

class CacheGraph:
    """Shared state for memoised functions: results, call graph, indegrees and stored summaries."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Swap every piece of state for a fresh, empty container."""
        # Calls that are executing right now, innermost last.
        self.stack: list[FunctionCall] = []
        # caches[func_name][(args, frozen_kwargs)] -> memoised result.
        self.caches: defaultdict[str, dict[ArgsHash, Any]] = defaultdict(dict)
        # graph[caller] -> set of calls it requested.
        self.graph: defaultdict[FunctionCall, set[FunctionCall]] = defaultdict(set)
        # indegrees[call] -> how many times the call was requested by another call.
        self.indegrees: defaultdict[FunctionCall, int] = defaultdict(int)
        # stored_results[func_name][timestep] -> processed (typically aggregated) result.
        self.stored_results: defaultdict[str, dict[int, Any]] = defaultdict(dict)

    def check_if_cached(self, function_call: FunctionCall):
        """Tell whether a result for exactly this call is in the cache."""
        # ``.get`` keeps the defaultdict from growing an empty entry on a miss.
        per_function = self.caches.get(function_call.func_name)
        if per_function is None:
            return False
        return (function_call.args, function_call.kwargs) in per_function

    def _store_result(self, storage_func: Union[Callable, None], func: Callable, args: tuple, kwargs: dict, raw_result: Any):
        """Record ``storage_func(raw_result)`` for the call's timestep; no-op without a storage function."""
        if storage_func is not None:
            # Internal invariants rather than user errors, hence asserts:
            # exactly one positional integer (the timestep) and no keywords.
            assert len(args) == 1 and isinstance(args[0], int)
            assert len(kwargs) == 0
            timestep = args[0]
            processed = storage_func(raw_result)
            self.stored_results[func.__name__][timestep] = processed

class CacheDecorator:
    """Callable wrapper that memoises ``func`` through a shared ``CacheGraph``.

    Each call is recorded in the graph's call graph and indegree table.
    Results are kept in the graph's cache for ``func.__name__`` under the key
    ``(args, frozenset(kwargs.items()))``.

    Args:
        func: The function to memoise.
        cache_graph: The graph that owns the stack, caches and bookkeeping.
        storage_func: Optional callable applied to each freshly computed
            result of a function whose only parameter is ``t``; its output is
            saved in ``cache_graph.stored_results``.
    """

    def __init__(self, func: Callable, cache_graph: CacheGraph, storage_func: Union[Callable, None] = None):
        self.func = func
        # Every method goes through ``self.cache_graph``. Keeping the graph
        # itself, not snapshots of its containers, also keeps this wrapper
        # valid after ``cache_graph.reset()`` rebinds them.
        self.cache_graph = cache_graph
        self.storage_func = storage_func
        sig = signature(func)
        self.single_param_timesteps = 't' in sig.parameters and len(sig.parameters) == 1

    @property
    def graph(self):
        """The owning graph's current caller -> callees mapping."""
        return self.cache_graph.graph

    @property
    def indegrees(self):
        """The owning graph's current indegree table."""
        return self.cache_graph.indegrees

    @property
    def cached_values(self):
        """The owning graph's current result dict for ``func``."""
        return self.cache_graph.caches[self.func.__name__]

    def __call__(self, *args, **kwargs):
        """Return the cached result for these arguments, computing it on a miss."""
        frozen_kwargs = frozenset(kwargs.items())
        function_call = FunctionCall(self.func.__name__, args, frozen_kwargs)
        if self.cache_graph.stack:
            self.cache_graph.graph[self.cache_graph.stack[-1]].add(function_call)
            self.cache_graph.indegrees[function_call] += 1
        if not self.cache_graph.check_if_cached(function_call):
            self.cache_graph.stack.append(function_call)
            try:
                result = self.func(*args, **kwargs)
                self.cached_values[(args, frozen_kwargs)] = result
                self._store_result(args, result)
            finally:
                # Unwind even when ``func`` raises so later calls are not
                # attributed to a caller that is no longer running.
                self.cache_graph.stack.pop()
        return self.cached_values[(args, frozen_kwargs)]

    def _store_result(self, args: tuple, raw_result: Any):
        """Save ``storage_func(raw_result)`` for the call's timestep.

        Does nothing unless a storage function was given and ``func`` takes a
        single parameter named ``t``. The timestep is read from ``args[0]``,
        i.e. it is expected to be passed positionally.
        """
        if self.storage_func is None:
            return
        if self.single_param_timesteps:
            timestep = args[0]
            stored_result = self.storage_func(raw_result)
            self.cache_graph.stored_results[self.func.__name__][timestep] = stored_result

class CacheEvictingDecorator:
    """Callable wrapper that memoises ``func`` through a shared ``CacheGraph``.

    Each call is recorded in the graph's call graph and indegree table.
    Results are kept in the graph's cache for ``func.__name__`` under the key
    ``(args, frozenset(kwargs.items()))``. Despite the name, this class does
    not evict anything yet.

    Args:
        func: The function to memoise.
        cache_graph: The graph that owns the stack, caches and bookkeeping.
        storage_func: Optional callable applied to each freshly computed
            result of a function whose only parameter is ``t``; its output is
            saved in ``cache_graph.stored_results``.
    """

    def __init__(self, func: Callable, cache_graph: CacheGraph, storage_func: Union[Callable, None] = None):
        self.func = func
        self.cache_graph = cache_graph
        self.storage_func = storage_func
        sig = signature(func)
        self.single_param_timesteps = 't' in sig.parameters and len(sig.parameters) == 1

    @property
    def cached_values(self):
        """The owning graph's current result dict for ``func``.

        Looked up on every access because ``cache_graph.reset()`` rebinds
        ``caches``. A dict captured once would be orphaned, and
        ``check_if_cached`` would then never see the stored results.
        """
        return self.cache_graph.caches[self.func.__name__]

    def __call__(self, *args, **kwargs):
        """Return the cached result for these arguments, computing it on a miss."""
        frozen_kwargs = frozenset(kwargs.items())
        function_call = FunctionCall(self.func.__name__, args, frozen_kwargs)
        if self.cache_graph.stack:
            self.cache_graph.graph[self.cache_graph.stack[-1]].add(function_call)
            self.cache_graph.indegrees[function_call] += 1
        if not self.cache_graph.check_if_cached(function_call):
            self.cache_graph.stack.append(function_call)
            try:
                result = self.func(*args, **kwargs)
                self.cached_values[(args, frozen_kwargs)] = result
                # ``_store_result`` takes only (args, raw_result) and reads the
                # storage function and wrapped function from ``self``.
                self._store_result(args, result)
            finally:
                # Unwind even when ``func`` raises so later calls are not
                # attributed to a caller that is no longer running.
                self.cache_graph.stack.pop()
        return self.cached_values[(args, frozen_kwargs)]

    def _store_result(self, args: tuple, raw_result: Any):
        """Save ``storage_func(raw_result)`` for the call's timestep.

        Does nothing unless a storage function was given and ``func`` takes a
        single parameter named ``t``. The timestep is read from ``args[0]``,
        i.e. it is expected to be passed positionally.
        """
        if self.storage_func is None:
            return
        if self.single_param_timesteps:
            timestep = args[0]
            stored_result = self.storage_func(raw_result)
            self.cache_graph.stored_results[self.func.__name__][timestep] = stored_result

In [8]:
# import signature
from inspect import signature
from functools import wraps

class _Cache:
    """Cache provides controllable memoization for model methods"""

    def __init__(self, func):
        self.func = func
        self.param_len = len(signature(func).parameters)
        self.has_one_param = self.param_len == 1
        self._store = dict()
        self.__name__ = "Cache: " + func.__name__

    def __call__(self, *arg):
        if arg in self._store:
            return self._store[arg]
        else:
            result = self.func(*arg)
            self._store[arg] = result
            return result

    def __repr__(self):
        return f"<Cache Function: {self.func.__name__} Size: {len(self._store)}>"
    
    def sum(self):
        """return the sum of all values in the Cache Function"""
        return sum(self._store.values())
    
    @property
    def values(self):
        return list(self._store.values())

In [11]:
def fibonacci(n):
    """hello"""
    # The one-word docstring is kept as is: the statements after this
    # function read ``__doc__`` through the cache wrapper.
    # Naive double recursion; any n <= 1 (negatives included) is returned unchanged.
    return n if n <= 1 else fibonacci(n - 1) + fibonacci(n - 2)

# Wrap the plain function, then display the wrapper's ``__doc__`` to see which
# docstring it exposes. The recursive calls inside ``fibonacci`` still go to the
# undecorated function, because the name ``fibonacci`` is not rebound here.
cached_fib = _Cache(fibonacci)
cached_fib.__doc__

'Cache provides controllable memoization for model methods'