Skip to content

Commit

Permalink
Add weakreaf_lru_cache to prevent caches from pinning jaxprs.
Browse files Browse the repository at this point in the history
To use this cache, the first argument must be some type that is
object identity hashed (like a jaxpr).
  • Loading branch information
pschuh committed Mar 21, 2022
1 parent b4f47c4 commit d0e0da0
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 13 deletions.
56 changes: 54 additions & 2 deletions jax/_src/util.py
Expand Up @@ -16,10 +16,13 @@
import functools
from functools import partial
import itertools as it
from collections import namedtuple
import operator
import types
from typing import (Any, Callable, Iterable, List, Tuple, Generic, TypeVar, Set,
Iterator, Sequence)
import threading
from typing import (Any, Callable, Dict, Iterable, List, Tuple, Generic,
TypeVar, Set, Iterator, Sequence)
import weakref

from absl import logging
import numpy as np
Expand Down Expand Up @@ -216,6 +219,55 @@ def wrapper(*args, **kwargs):

memoize = cache(max_size=None)

_CacheInfo = namedtuple("_CacheInfo", ["hits", "misses", "maxsize", "currsize"])

def weakref_lru_cache(call: Callable, maxsize=2048):
cache: Dict[Any, Any] = {}
hits = misses = 0
lock = threading.Lock()

def remove_key(tctx, args, kwargs, weak_arg):
del cache[(weak_arg, tctx, args, kwargs)]

def wrapped(weak_arg, *args, **kwargs):
nonlocal hits, misses
if config.jax_check_tracer_leaks:
return call(weak_arg, *args, **kwargs)
kwargs_key = tuple(kwargs.items())
tctx = config._trace_context()
k = (weakref.ref(weak_arg,
functools.partial(remove_key, tctx, args, kwargs_key)),
tctx, args, kwargs_key)
with lock:
if k in cache:
hits += 1
result = cache[k]
# del and reinsert to bump key in the insertion order.
del cache[k]
cache[k] = result
return result
misses += 1
result = call(weak_arg, *args, **kwargs)
with lock:
cache[k] = result
while len(cache) > maxsize:
del cache[next(iter(cache))]
return result

def cache_info():
with lock:
return _CacheInfo(hits, misses, maxsize, len(cache))

def cache_clear():
nonlocal hits, misses
with lock:
hits = misses = 0
cache.clear()

wrapped.cache_info = cache_info
wrapped.cache_clear = cache_clear
return wrapped

def prod(xs):
out = 1
for x in xs:
Expand Down
6 changes: 3 additions & 3 deletions jax/core.py
Expand Up @@ -41,8 +41,8 @@

from jax._src import source_info_util
from jax._src.util import (safe_zip, safe_map, curry, prod, tuple_insert,
tuple_delete, cache, as_hashable_function,
HashableFunction)
tuple_delete, as_hashable_function,
HashableFunction, weakref_lru_cache)
import jax._src.pretty_printer as pp

from jax._src import traceback_util
Expand Down Expand Up @@ -2031,7 +2031,7 @@ def do_subst_axis_names_jaxpr(jaxpr: Union[Jaxpr, ClosedJaxpr], subst: AxisSubst
return ClosedJaxpr(new_jaxpr, consts)
return new_jaxpr

@cache()
@weakref_lru_cache
def used_axis_names_jaxpr(jaxpr: Union[Jaxpr, ClosedJaxpr]):
subst = NameGatheringSubst()
do_subst_axis_names_jaxpr(jaxpr, subst)
Expand Down
4 changes: 2 additions & 2 deletions jax/experimental/pjit.py
Expand Up @@ -46,7 +46,7 @@
from jax._src.tree_util import prefix_errors
from jax._src.util import (extend_name_stack, HashableFunction, safe_zip,
wrap_name, wraps, distributed_debug_log,
split_list, cache, tuple_insert)
split_list, cache, tuple_insert, weakref_lru_cache)
xops = xc._xla.ops

class _FromGdaSingleton:
Expand Down Expand Up @@ -609,7 +609,7 @@ def _pjit_call_impl(*args, jaxpr,
return compiled.unsafe_call(*args)
pjit_p.def_impl(_pjit_call_impl)

@cache()
@weakref_lru_cache
def _pjit_lower(
jaxpr: core.ClosedJaxpr,
in_axis_resources: Tuple[CanonicalizedParsedPartitionSpec, ...],
Expand Down
12 changes: 6 additions & 6 deletions jax/interpreters/partial_eval.py
Expand Up @@ -34,8 +34,8 @@
from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten,
tree_leaves)
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
merge_lists, partition_list, cache, OrderedSet,
as_hashable_function)
merge_lists, partition_list, OrderedSet,
as_hashable_function, weakref_lru_cache)
from jax.core import (Trace, Tracer, Jaxpr, Literal, get_aval, AbstractValue,
unit, unitvar, abstract_unit, ClosedJaxpr, new_jaxpr_eqn,
ConcreteArray, raise_to_shaped, Var, Atom, JaxprEqn,
Expand Down Expand Up @@ -743,7 +743,7 @@ def getconstvar(c):
config.jax_enable_checks and core.check_jaxpr(jaxpr)
return jaxpr, const_vals, env_vals

@cache()
@weakref_lru_cache
def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:
"""Moves the constvars to the start of invars."""
config.jax_enable_checks and core.check_jaxpr(jaxpr)
Expand Down Expand Up @@ -806,7 +806,7 @@ def partial_eval_jaxpr(jaxpr: ClosedJaxpr, unknowns: Sequence[bool],
instantiate = tuple(instantiate) if isinstance(instantiate, list) else instantiate
return _partial_eval_jaxpr(jaxpr, tuple(unknowns), instantiate)

@cache()
@weakref_lru_cache
def _partial_eval_jaxpr(jaxpr, unknowns, instantiate):
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))

Expand Down Expand Up @@ -1172,7 +1172,7 @@ def _dce_jaxpr(closed_jaxpr: ClosedJaxpr, outputs: Sequence[bool], drop_outputs=
new_jaxpr = _dce_open_jaxpr(closed_jaxpr.jaxpr, tuple(outputs), drop_outputs)
return core.ClosedJaxpr(new_jaxpr, closed_jaxpr.consts)

@cache()
@weakref_lru_cache
def _dce_open_jaxpr(jaxpr: Jaxpr, outputs: Tuple[bool, ...], drop_outputs=False) -> Jaxpr:
# This dead-code elimination is pretty rudimentary, and in particular doesn't
# nontrivially DCE through scan, call, or other higher-order primitives.
Expand All @@ -1192,7 +1192,7 @@ def _dce_open_jaxpr(jaxpr: Jaxpr, outputs: Tuple[bool, ...], drop_outputs=False)
new_eqns = new_eqns[::-1]
return Jaxpr(jaxpr.constvars, jaxpr.invars, new_outvars, new_eqns)

@cache()
@weakref_lru_cache
def _drop_vars(jaxpr: Jaxpr, drop_ins: Tuple[bool, ...], drop_outs: Tuple[bool, ...]):
return Jaxpr(jaxpr.constvars,
[v for v, d in zip(jaxpr.invars, drop_ins) if not d],
Expand Down

0 comments on commit d0e0da0

Please sign in to comment.