Skip to content

Commit

Permalink
Explicitly separate JAX and non-JAX data during Jittable serialization.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 458051922
  • Loading branch information
Jake Bruce authored and DistraxDev committed Jul 13, 2022
1 parent fb73118 commit 4359409
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 15 deletions.
45 changes: 35 additions & 10 deletions distrax/_src/utils/jittable.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,46 @@ class Jittable(metaclass=abc.ABCMeta):
"""ABC that can be passed as an arg to a jitted fn, with readable state."""

def __new__(cls, *args, **kwargs):
del args, kwargs
try:
registered_cls = jax.tree_util.register_pytree_node_class(cls)
except ValueError:
registered_cls = cls # already registered
instance = super(Jittable, cls).__new__(registered_cls)
instance._args = args
instance._kwargs = kwargs
return instance
registered_cls = cls # Already registered.
return object.__new__(registered_cls)

def tree_flatten(self):
return ((), ((self._args, self._kwargs), self.__dict__))
leaves, treedef = jax.tree_flatten(self.__dict__)
switch = list(map(_is_jax_data, leaves))
children = [leaf if s else None for leaf, s in zip(leaves, switch)]
metadata = [None if s else leaf for leaf, s in zip(leaves, switch)]
return children, (metadata, switch, treedef)

@classmethod
def tree_unflatten(cls, aux_data, _):
(args, kwargs), state_dict = aux_data
obj = cls(*args, **kwargs)
obj.__dict__ = state_dict
def tree_unflatten(cls, aux_data, children):
metadata, switch, treedef = aux_data
leaves = [j if s else p for j, p, s in zip(children, metadata, switch)]
obj = object.__new__(cls)
obj.__dict__ = jax.tree_unflatten(treedef, leaves)
return obj


def _is_jax_data(x):
"""Check whether `x` is an instance of a JAX-compatible type."""
# If it's a tracer, then it's already been converted by JAX.
if isinstance(x, jax.core.Tracer):
return True

# `jax.vmap` replaces vmappable leaves with `object()` during serialization.
if type(x) is object: # pylint: disable=unidiomatic-typecheck
return True

# Primitive types (e.g. shape tuples) are treated as metadata for distrax.
if isinstance(x, (bool, int, float)) or x is None:
return False

# Otherwise, try to make it into a tracer. If it succeeds, then it's JAX data.
try:
jax.xla.abstractify(x)
return True
except TypeError:
return False
68 changes: 63 additions & 5 deletions distrax/_src/utils/jittable_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,22 @@ def get_params(obj):
obj = DummyJittable(jnp.ones((5,)))
np.testing.assert_array_equal(get_params(obj), obj.data['params'])

def test_vmappable(self):
def do_sum(obj):
return obj.data['params'].sum()
obj = DummyJittable(jnp.array([[1, 2, 3], [4, 5, 6]]))

with self.subTest('no vmap'):
np.testing.assert_array_equal(do_sum(obj), obj.data['params'].sum())

with self.subTest('in_axes=0'):
np.testing.assert_array_equal(
jax.vmap(do_sum, in_axes=0)(obj), obj.data['params'].sum(axis=1))

with self.subTest('in_axes=1'):
np.testing.assert_array_equal(
jax.vmap(do_sum, in_axes=1)(obj), obj.data['params'].sum(axis=0))

def test_traceable(self):
@jax.jit
def inner_fn(obj):
Expand All @@ -50,13 +66,55 @@ def loss_fn(params):
obj.data['params'] *= 2 # Modification before passing to jitted fn.
return inner_fn(obj)

params = jnp.ones((5,))
with self.subTest('numpy'):
params = np.ones((5,))
# Both modifications will be traced if data tree is correctly traversed.
grad_expected = params * 2 * 3
grad = jax.grad(loss_fn)(params)
np.testing.assert_array_equal(grad, grad_expected)

with self.subTest('jax.numpy'):
params = jnp.ones((5,))
# Both modifications will be traced if data tree is correctly traversed.
grad_expected = params * 2 * 3
grad = jax.grad(loss_fn)(params)
np.testing.assert_array_equal(grad, grad_expected)

def test_different_jittables_to_compiled_function(self):
@jax.jit
def add_one_to_params(obj):
obj.data['params'] = obj.data['params'] + 1
return obj

with self.subTest('numpy'):
add_one_to_params(DummyJittable(np.zeros((5,))))
add_one_to_params(DummyJittable(np.ones((5,))))

# Both modifications will be traced if data tree is correctly traversed.
grad_expected = params * 2 * 3
grad = jax.grad(loss_fn)(params)
with self.subTest('jax.numpy'):
add_one_to_params(DummyJittable(jnp.zeros((5,))))
add_one_to_params(DummyJittable(jnp.ones((5,))))

def test_modifying_object_data_does_not_leak_tracers(self):
@jax.jit
def add_one_to_params(obj):
obj.data['params'] = obj.data['params'] + 1
return obj

dummy = DummyJittable(jnp.ones((5,)))
dummy_out = add_one_to_params(dummy)
dummy_out.data['params'] -= 1

def test_metadata_modification_statements_are_removed_by_compilation(self):
@jax.jit
def add_char_to_name(obj):
obj.name += '_x'
return obj

np.testing.assert_array_equal(grad, grad_expected)
dummy = DummyJittable(jnp.ones((5,)))
dummy_out = add_char_to_name(dummy)
dummy_out = add_char_to_name(dummy) # `name` change has been compiled out.
dummy_out.name += 'y'
self.assertEqual(dummy_out.name, 'dummy_xy')


if __name__ == '__main__':
Expand Down

0 comments on commit 4359409

Please sign in to comment.