Skip to content

Commit

Permalink
Remove references to deprecated jax aliases
Browse files Browse the repository at this point in the history
jax version 0.4.7 deprecated several top-level aliases; these have been raising warnings since March 2023, and will soon be removed. They include

- jax.ad -> jax.interpreters.ad
- jax.flatten_fun_nokwargs - jax.api_util.flatten_fun_nokwargs
- jax.partial_eval -> jax.interpreters.partial_eval
- jax.pxla -> jax.interpreters.pxla
- jax.xla -> jax.interpreters.xla

This change renames these references in preparation for them to be removed from JAX

PiperOrigin-RevId: 545986604
  • Loading branch information
Jake VanderPlas authored and jaehlee committed Aug 2, 2023
1 parent 1e82970 commit 4de699a
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
3 changes: 2 additions & 1 deletion neural_tangents/_src/empirical.py
Original file line number Diff line number Diff line change
Expand Up @@ -1768,7 +1768,8 @@ def _get_primals_out_and_pullback(
cotangents), but collects and returns other quantities.
"""
primals_in_flat, in_tree = tree_flatten(primals_in)
fn_flat, out_tree = jax.flatten_fun_nokwargs(lu.wrap_init(fn), in_tree)
fn_flat, out_tree = jax.api_util.flatten_fun_nokwargs(
lu.wrap_init(fn), in_tree)

# TODO(romann): handle call primitives more gracefully.
with jax.disable_jit():
Expand Down
4 changes: 2 additions & 2 deletions neural_tangents/_src/utils/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,8 +1045,8 @@ def _zeros_like_j(
) -> np.ndarray:
return np.zeros(cts_in.shape + invals[idx].shape, cts_in.dtype) # pytype: disable=unsupported-operands # always-use-return-annotations

STRUCTURE_RULES[jax.ad.zeros_like_p] = _eye_s
JACOBIAN_RULES[jax.ad.zeros_like_p] = _zeros_like_j
STRUCTURE_RULES[jax.interpreters.ad.zeros_like_p] = _eye_s
JACOBIAN_RULES[jax.interpreters.ad.zeros_like_p] = _zeros_like_j


def _transpose_s(
Expand Down

0 comments on commit 4de699a

Please sign in to comment.