Skip to content

Commit

Permalink
Explicitly separate JAX and non-JAX data during Jittable serialization.
Browse files Browse the repository at this point in the history
Previously, all properties of Jittable objects were considered non-JAX metadata for the purposes of JAX serialization. This kept bookkeeping to a minimum, but had several undesirable effects:

1. `jax.tree_map` treated Jittable objects as pytree nodes, but did not recur into their parameters, resulting in the Jittable being invisible to the `tree_map`.
2. Jitted functions would fail when called with a different distribution than the one they were initially compiled for, as the JAX parameters that differ between the instances would be considered different static metadata.
3. Modifying Jittable properties inside a jitted function resulted in a tracer, which would then be added back to the object's metadata on exiting the function, subsequently leaking into the rest of the code.

More detail on these issues is reported in this Github issue: #162

This change addresses the issue by modifying the way that we serialize Jittables (which includes Distributions and Bijectors) to explicitly separate all JAX data in the `self.__dict__` from any metadata such as strings and primitives.

PiperOrigin-RevId: 462347605
  • Loading branch information
Jake Bruce authored and DistraxDev committed Jul 21, 2022
1 parent b1cbbf8 commit 0ecad05
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 15 deletions.
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,22 @@ print(jitted_kl(dist_0, dist_1))
print(dist_0.kl_divergence(dist_1))
```

##### A note about `vmap` and `pmap`

The serialization logic that enables Distrax objects to be passed as arguments
to jitted functions also enables functions to map over them as data using
`jax.vmap` and `jax.pmap`.

However, ***support for this behavior is experimental and incomplete. Use
caution when applying `jax.vmap` or `jax.pmap` to functions that take Distrax
objects as arguments, or return Distrax objects.***

Simple objects such as `distrax.Categorical` may behave correctly under these
transformations, but more complex objects such as
`distrax.MultivariateNormalDiag` may generate exceptions when used as inputs or
outputs of a `vmap`-ed or `pmap`-ed function.


### Subclassing Distributions and Bijectors

User-defined distributions can be created by subclassing `distrax.Distribution`.
Expand Down
47 changes: 37 additions & 10 deletions distrax/_src/utils/jittable.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,48 @@ class Jittable(metaclass=abc.ABCMeta):
"""ABC that can be passed as an arg to a jitted fn, with readable state."""

def __new__(cls, *args, **kwargs):
# Discard the parameters to this function because the constructor is not
# called during serialization: its `__dict__` gets repopulated directly.
del args, kwargs
try:
registered_cls = jax.tree_util.register_pytree_node_class(cls)
except ValueError:
registered_cls = cls # already registered
instance = super(Jittable, cls).__new__(registered_cls)
instance._args = args
instance._kwargs = kwargs
return instance
registered_cls = cls # Already registered.
return object.__new__(registered_cls)

def tree_flatten(self):
return ((), ((self._args, self._kwargs), self.__dict__))
leaves, treedef = jax.tree_flatten(self.__dict__)
switch = list(map(_is_jax_data, leaves))
children = [leaf if s else None for leaf, s in zip(leaves, switch)]
metadata = [None if s else leaf for leaf, s in zip(leaves, switch)]
return children, (metadata, switch, treedef)

@classmethod
def tree_unflatten(cls, aux_data, _):
(args, kwargs), state_dict = aux_data
obj = cls(*args, **kwargs)
obj.__dict__ = state_dict
def tree_unflatten(cls, aux_data, children):
metadata, switch, treedef = aux_data
leaves = [j if s else p for j, p, s in zip(children, metadata, switch)]
obj = object.__new__(cls)
obj.__dict__ = jax.tree_unflatten(treedef, leaves)
return obj


def _is_jax_data(x):
"""Check whether `x` is an instance of a JAX-compatible type."""
# If it's a tracer, then it's already been converted by JAX.
if isinstance(x, jax.core.Tracer):
return True

# `jax.vmap` replaces vmappable leaves with `object()` during serialization.
if type(x) is object: # pylint: disable=unidiomatic-typecheck
return True

# Primitive types (e.g. shape tuples) are treated as metadata for Distrax.
if isinstance(x, (bool, int, float)) or x is None:
return False

# Otherwise, try to make it into a tracer. If it succeeds, then it's JAX data.
try:
jax.xla.abstractify(x)
return True
except TypeError:
return False
68 changes: 63 additions & 5 deletions distrax/_src/utils/jittable_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,22 @@ def get_params(obj):
obj = DummyJittable(jnp.ones((5,)))
np.testing.assert_array_equal(get_params(obj), obj.data['params'])

def test_vmappable(self):
def do_sum(obj):
return obj.data['params'].sum()
obj = DummyJittable(jnp.array([[1, 2, 3], [4, 5, 6]]))

with self.subTest('no vmap'):
np.testing.assert_array_equal(do_sum(obj), obj.data['params'].sum())

with self.subTest('in_axes=0'):
np.testing.assert_array_equal(
jax.vmap(do_sum, in_axes=0)(obj), obj.data['params'].sum(axis=1))

with self.subTest('in_axes=1'):
np.testing.assert_array_equal(
jax.vmap(do_sum, in_axes=1)(obj), obj.data['params'].sum(axis=0))

def test_traceable(self):
@jax.jit
def inner_fn(obj):
Expand All @@ -50,13 +66,55 @@ def loss_fn(params):
obj.data['params'] *= 2 # Modification before passing to jitted fn.
return inner_fn(obj)

params = jnp.ones((5,))
with self.subTest('numpy'):
params = np.ones((5,))
# Both modifications will be traced if data tree is correctly traversed.
grad_expected = params * 2 * 3
grad = jax.grad(loss_fn)(params)
np.testing.assert_array_equal(grad, grad_expected)

with self.subTest('jax.numpy'):
params = jnp.ones((5,))
# Both modifications will be traced if data tree is correctly traversed.
grad_expected = params * 2 * 3
grad = jax.grad(loss_fn)(params)
np.testing.assert_array_equal(grad, grad_expected)

def test_different_jittables_to_compiled_function(self):
@jax.jit
def add_one_to_params(obj):
obj.data['params'] = obj.data['params'] + 1
return obj

with self.subTest('numpy'):
add_one_to_params(DummyJittable(np.zeros((5,))))
add_one_to_params(DummyJittable(np.ones((5,))))

# Both modifications will be traced if data tree is correctly traversed.
grad_expected = params * 2 * 3
grad = jax.grad(loss_fn)(params)
with self.subTest('jax.numpy'):
add_one_to_params(DummyJittable(jnp.zeros((5,))))
add_one_to_params(DummyJittable(jnp.ones((5,))))

def test_modifying_object_data_does_not_leak_tracers(self):
@jax.jit
def add_one_to_params(obj):
obj.data['params'] = obj.data['params'] + 1
return obj

dummy = DummyJittable(jnp.ones((5,)))
dummy_out = add_one_to_params(dummy)
dummy_out.data['params'] -= 1

def test_metadata_modification_statements_are_removed_by_compilation(self):
@jax.jit
def add_char_to_name(obj):
obj.name += '_x'
return obj

np.testing.assert_array_equal(grad, grad_expected)
dummy = DummyJittable(jnp.ones((5,)))
dummy_out = add_char_to_name(dummy)
dummy_out = add_char_to_name(dummy) # `name` change has been compiled out.
dummy_out.name += 'y'
self.assertEqual(dummy_out.name, 'dummy_xy')


if __name__ == '__main__':
Expand Down

0 comments on commit 0ecad05

Please sign in to comment.