[jax2tf] Improve the TF constant sharing
Use fewer cache tables for constants: one per top-level converted function,
and a separate table for the gradient.

Bug: #7992
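
For context, a minimal sketch of the behavior this change targets, assuming numpy, TensorFlow, jax, and jax2tf are importable (it mirrors the test_shared_constants test updated below): a NumPy constant captured by the converted function and used several times should be embedded in the TF graph only once.

import re
import numpy as np
import tensorflow as tf
from jax.experimental import jax2tf

const = np.arange(16, dtype=np.float32)

def f(x):
  # The captured constant is used four times; with sharing it should become one embedded constant.
  return x + const + const + const + const

f_tf = tf.function(jax2tf.convert(f), autograph=False)
graph_def = f_tf.get_concrete_function(const).graph.as_graph_def()
# Large constants are serialized with a "tensor_content" field; count the embedded ones.
print(len(re.findall("tensor_content", str(graph_def))))
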
gnecula committed Nov 11, 2021
1 parent 7f3609f commit 9175ed6
Showing 3 changed files with 52 additions and 10 deletions.
22 changes: 16 additions & 6 deletions jax/experimental/jax2tf/jax2tf.py
@@ -408,15 +408,16 @@ def fix_in_ct(in_ct, arg_aval: core.ShapedArray):
@tf.custom_gradient
def converted_fun_flat_with_custom_gradient(*args_flat: TfVal) -> TfVal:
out_with_avals = _interpret_fun(flat_fun, args_flat, args_avals_flat,
name_stack)
name_stack,
fresh_constant_cache=True)
outs, out_avals = util.unzip2(out_with_avals)
return (tuple(outs),
partial(converted_grad_fn, _out_cts_avals=tuple(out_avals)))

out_flat = converted_fun_flat_with_custom_gradient(*args_flat)
else:
out_with_avals = _interpret_fun(flat_fun, args_flat, args_avals_flat,
name_stack)
name_stack, fresh_constant_cache=True)
outs, out_avals = util.unzip2(out_with_avals)
message = ("The jax2tf-converted function does not support gradients. "
"Use `with_gradient` parameter to enable gradients")
@@ -470,13 +471,16 @@ def _extended_name_stack(extra_name_stack: Optional[str]):
def _interpret_fun(
fun: lu.WrappedFun, in_vals: Sequence[TfVal],
in_avals: Sequence[core.ShapedArray],
extra_name_stack: Optional[str]
extra_name_stack: Optional[str],
fresh_constant_cache: bool = False
) -> Sequence[Tuple[TfVal, core.ShapedArray]]:
try:
prev_constant_cache = _thread_local_state.constant_cache
prev_constant_cache_keys = set(prev_constant_cache.keys()) if prev_constant_cache is not None else set()
# Start a new cache, so that we don't share constants across tf.function
# boundaries.
_thread_local_state.constant_cache = {}
if fresh_constant_cache:
_thread_local_state.constant_cache = {}

with core.new_base_main(TensorFlowTrace) as main: # type: ignore
fun = _interpret_subtrace(fun, main, in_avals)
@@ -486,6 +490,11 @@
fun.call_wrapped(*in_vals)
del main
finally:
if prev_constant_cache is not None and not fresh_constant_cache:
newly_added_keys = set(prev_constant_cache.keys()) - prev_constant_cache_keys
# Delete the newly added keys
for k in newly_added_keys:
del prev_constant_cache[k]
_thread_local_state.constant_cache = prev_constant_cache

return tuple(out_vals)
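
Reading the change above: _interpret_fun now starts a fresh constant cache only when asked to (as at the top-level call sites in the first hunk), and otherwise reuses the caller's cache but rolls back whatever keys the nested interpretation added, presumably because constants created while interpreting a nested computation (for example a tf.cond branch) live in an inner graph and are not safe to reuse afterwards. A standalone sketch of that save-and-rollback pattern, with hypothetical names rather than jax2tf's API:

import threading

_state = threading.local()
_state.cache = None  # key -> cached value, or None when caching is disabled

def run_with_scoped_cache(fn, *, fresh_cache=False):
  # Mirrors the try/finally structure above: remember the caller's cache and the
  # keys it already held, optionally swap in a fresh cache, and on exit drop any
  # keys that a shared cache picked up while fn() ran.
  prev_cache = _state.cache
  prev_keys = set(prev_cache) if prev_cache is not None else set()
  if fresh_cache:
    _state.cache = {}
  try:
    return fn()
  finally:
    if prev_cache is not None and not fresh_cache:
      for k in set(prev_cache) - prev_keys:
        del prev_cache[k]
    _state.cache = prev_cache
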
@@ -614,7 +623,8 @@ def _tfval_to_tensor_jax_dtype(val: TfVal,
# collected and reused for a different value, which would create correctness
# issues. We keep the `val` alive by storing in the cache the pair
# `(val, tf_val)`.
if memoize_constants and _thread_local_state.constant_cache is not None:
do_memoize = (memoize_constants and np.shape(val) and _thread_local_state.constant_cache is not None)
if do_memoize:
_, tf_val = _thread_local_state.constant_cache.get(const_key, (None, None))
else:
tf_val = None
@@ -624,7 +634,7 @@
if jax_dtype == dtypes.float0:
val = np.zeros(np.shape(val), conversion_dtype.as_numpy_dtype)
tf_val = tf.convert_to_tensor(val, dtype=conversion_dtype)
if memoize_constants and _thread_local_state.constant_cache is not None:
if do_memoize:
_thread_local_state.constant_cache[const_key] = (val, tf_val)
return tf_val, jax_dtype
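
Two details in the hunk above are easy to miss. First, the new do_memoize condition skips zero-dimensional values, since np.shape(val) is an empty, falsy tuple for scalars, which are cheap to materialize inline. Second, the comment about keeping val alive matters because the cache key appears to be derived from the value's identity; if the cache stored only tf_val, the original array could be garbage collected and its id reused by an unrelated value. A toy illustration of id reuse in CPython (not jax2tf code):

import numpy as np

def id_of_temporary():
  # The array is garbage collected as soon as this returns.
  return id(np.arange(1024, dtype=np.float32))

# Two distinct, short-lived arrays frequently reuse the same memory, hence the same id,
# so an id-based cache key is only safe if the keyed value is kept alive.
print(id_of_temporary() == id_of_temporary())  # often True on CPython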

35 changes: 31 additions & 4 deletions jax/experimental/jax2tf/tests/jax2tf_test.py
@@ -15,7 +15,6 @@
Specific JAX primitive conversion tests are in primitives_test."""

import re
from typing import Dict, Tuple

from absl.testing import absltest
@@ -753,11 +752,39 @@ def test_shared_constants(self):
def f(x):
return x + const + const + const + const

f_tf_graph = tf.function(jax2tf.convert(f), autograph=False).get_concrete_function(const).graph.as_graph_def()
f_tf_graph_nr_consts = len(re.findall(r'op:\s*"Const"', str(f_tf_graph)))
f_tf_nr_consts = self.CountTfConstants(jax2tf.convert(f), const)
# It seems that there is already a shape constant in the graph; we want to
# make sure our 4 instances of "const" are shared.
self.assertEqual(f_tf_graph_nr_consts, 2)
self.assertEqual(f_tf_nr_consts, 2)

def test_shared_constants_under_cond(self):
# Check that the constants are shared properly in converted functions
# See https://github.com/google/jax/issues/7992.
const = np.arange(16, dtype=np.float32)
x = np.ones((16,), dtype=np.float32)
def f1(x):
return lax.cond(x[0] >= 0., lambda x: x + const, lambda x: x * const, x) + const
def f2(x):
return f1(x) + const # The extra const should not cost anything
f1_nr_consts = self.CountTfConstants(jax2tf.convert(f1), x)
f2_nr_consts = self.CountTfConstants(jax2tf.convert(f2), x)
self.assertEqual(f1_nr_consts, f2_nr_consts)

def test_shared_constants_under_scan(self):
# See https://github.com/google/jax/issues/7992.
const = np.arange(16, dtype=np.float32)
xs = np.ones((8, 16), dtype=np.float32)
def f1(xs):
res, _ = lax.scan(lambda carry, x: (carry + x + const, None),
np.zeros((16,), dtype=np.float32), xs)
return res

def f2(xs):
return f1(xs) + const # The extra const should not be saved

f1_nr_consts = self.CountTfConstants(jax2tf.convert(f1), xs)
f2_nr_consts = self.CountTfConstants(jax2tf.convert(f2), xs)
self.assertEqual(f1_nr_consts, f2_nr_consts)

def test_weak_types(self):
mul = jax.jit(jnp.multiply)
5 changes: 5 additions & 0 deletions jax/experimental/jax2tf/tests/tf_test_util.py
@@ -15,6 +15,7 @@
import contextlib
import dataclasses
import logging
import re
import os

from typing import Any, Callable, List, Optional, Sequence, Tuple
@@ -421,6 +422,10 @@ def polymorphic_shape_to_tensorspec(poly_shape: str) -> tf.TensorSpec:

return tree_util.tree_multimap(polymorphic_shape_to_tensorspec, polymorphic_shapes)

def CountTfConstants(self, tf_fun: Callable, *args):
f_tf_graph = tf.function(tf_fun, autograph=False).get_concrete_function(*args).graph.as_graph_def()
return len(re.findall("tensor_content", str(f_tf_graph)))

def CheckOpMetadata(self, jax_fun, x,
expected: Sequence[OpMetadataGraph],
include_xla_op_metadata=True):
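
One note on the CountTfConstants heuristic above: in TensorFlow's graph serialization, multi-element constants are typically stored in the raw tensor_content bytes of a Const node's TensorProto, while scalar constants use typed fields such as float_val, so counting "tensor_content" occurrences counts only the non-trivial embedded constants. A quick check, assuming TensorFlow is available:

import numpy as np
import tensorflow as tf

@tf.function(autograph=False)
def g():
  return tf.constant(np.arange(16, dtype=np.float32)) + tf.constant(1.0)

graph_def = str(g.get_concrete_function().graph.as_graph_def())
print(graph_def.count("tensor_content"))  # typically 1: only the 16-element array is stored as raw bytes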
