Skip to content

Commit

Permalink
Use public import for device_put_p
Browse files Browse the repository at this point in the history
jax._src is a non-public, non-stable import path and should be avoided by downstream packages.

PiperOrigin-RevId: 513364978
  • Loading branch information
Jake VanderPlas authored and romanngg committed Mar 9, 2023
1 parent d225284 commit 65004a5
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 7 deletions.
6 changes: 2 additions & 4 deletions neural_tangents/_src/utils/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import jax
from jax import lax
from jax.core import JaxprEqn, ShapedArray, Primitive, Jaxpr, Var, AbstractValue, Literal
from jax._src import dispatch as jax_dispatch
from jax.interpreters import ad
import jax.numpy as np
import numpy as onp
Expand Down Expand Up @@ -1111,9 +1110,8 @@ def _squeeze_s(
JACOBIAN_RULES[lax.convert_element_type_p] = _eye_j


device_put_p = jax_dispatch.device_put_p
STRUCTURE_RULES[device_put_p] = _eye_s
JACOBIAN_RULES[device_put_p] = _eye_j
STRUCTURE_RULES[lax.device_put_p] = _eye_s
JACOBIAN_RULES[lax.device_put_p] = _eye_j


copy_p = jax.lax.copy_p
Expand Down
5 changes: 2 additions & 3 deletions tests/rules_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from absl.testing import absltest
import jax
from jax import lax
from jax._src import dispatch as jax_dispatch
from jax.config import config
from jax.core import Primitive
from jax.core import ShapedArray
Expand Down Expand Up @@ -334,7 +333,7 @@ def _concat_shapes(max_n_args: int = 4, *shapes):
'dimensions': d
} for d in more_itertools.powerset(range(len(s)))],

jax_dispatch.device_put_p:
lax.device_put_p:
lambda s, _: [{}], # Test cases generated elsewhere.

lax.pad_p:
Expand Down Expand Up @@ -580,7 +579,7 @@ def _test_primitive(
for params in _UNARY_PRIMITIVES[primitive](shape, dtype)
)
def test_unary(self, primitive: Optional[Primitive], shape, dtype, params):
if primitive == jax_dispatch.device_put_p:
if primitive == lax.device_put_p:
# Can't instantiate devices at test generation time; using subtests.
for device in [None] + jax.devices() + jax.devices('cpu'):
with self.subTest(device=device):
Expand Down

0 comments on commit 65004a5

Please sign in to comment.