Skip to content

Commit

Permalink
Future-proof use of isinstance(x, jnp.ndarray)
Browse files Browse the repository at this point in the history
The next release of JAX will subtly change the instancecheck behavior for jnp.ndarray; this will only affect libraries that make use of jax-internal objects like ad.UndefinedPrimal. This update future-proofs neural_tangents to this change in instancecheck behavior.

PiperOrigin-RevId: 476587404
  • Loading branch information
Jake VanderPlas authored and romanngg committed Oct 18, 2022
1 parent c2e8d07 commit 64931a3
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions neural_tangents/_src/empirical.py
Original file line number Diff line number Diff line change
Expand Up @@ -1674,14 +1674,19 @@ def _get_dims(
return df_dy_dims_1, df_dy_dims_2, out_dims


def _is_abstract_array(x) -> bool:
return isinstance(x, np.ndarray) or isinstance(
getattr(x, 'aval', None), core.ShapedArray)


def _vmap(f: Callable, in_axes, out_axes, squeeze_out: bool = True) -> Callable:
"""An expand-then-squeeze `vmap` for `f` expecting/returning batch dims."""
in_axes_plus_1 = tree_map(lambda x: x if x in (None, -1) else x + 1, in_axes)

@utils.wraps(f)
def f_vmapped(*args):
args = tree_map(_expand_dims, args, in_axes_plus_1,
is_leaf=lambda x: isinstance(x, np.ndarray))
args = tree_map(
_expand_dims, args, in_axes_plus_1, is_leaf=_is_abstract_array)
out = vmap(f, in_axes, out_axes)(*args)
if squeeze_out:
out_axes_plus_1 = tree_map(
Expand Down

0 comments on commit 64931a3

Please sign in to comment.