Skip to content

Commit

Permalink
distrax: migrate from deprecated jax.linear_util to jax.extend.linear…
Browse files Browse the repository at this point in the history
…_util

PiperOrigin-RevId: 577275354
  • Loading branch information
Jake VanderPlas authored and DistraxDev committed Oct 27, 2023
1 parent a6b19ec commit 0fcc1fe
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion distrax/_src/utils/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@
import jax
import jax.numpy as jnp

try:
# jax >= 0.4.16
from jax.extend import linear_util as lu # pylint: disable=g-import-not-at-top
except ImportError:
from jax import linear_util as lu # pylint: disable=g-import-not-at-top


_inverse_registry = {
# unary ops
Expand Down Expand Up @@ -258,7 +264,7 @@ def write(var, val):
# if primitive is an xla_call, get subexpressions and evaluate recursively
call_jaxpr, params = _extract_call_jaxpr(eqn.primitive, params)
if call_jaxpr:
subfuns = [jax.linear_util.wrap_init(
subfuns = [lu.wrap_init(
functools.partial(_interpret_inverse, call_jaxpr, ()))]
prim_inv = eqn.primitive

Expand Down

0 comments on commit 0fcc1fe

Please sign in to comment.