Skip to content

Commit

Permalink
Move jax.jaxpr_util to jax._src.jaxpr_util, and split it into a separ…
Browse files Browse the repository at this point in the history
…ate build target.

Change jaxpr_util_test to be a py_test(), since there's no point testing it on every hardware configuration.

PiperOrigin-RevId: 554861284
  • Loading branch information
hawkinsp authored and jax authors committed Aug 8, 2023
1 parent b024e01 commit afd56c1
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Remember to align the itemized text with the first line of an item within a list
for details and for mechanisms to override the default.
* The option `--jax_coordination_service` has been removed. It is now always
`True`.
* `jax.jaxpr_util` has been removed from the public JAX namespace.

## jaxlib 0.4.15

Expand Down
12 changes: 12 additions & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ py_library_providing_imports_info(
":dtypes",
":effects",
":environment_info",
":jaxpr_util",
":lazy_loader",
":mesh",
":mlir",
Expand Down Expand Up @@ -401,6 +402,17 @@ pytype_strict_library(
srcs = ["_src/lazy_loader.py"],
)

pytype_strict_library(
name = "jaxpr_util",
srcs = ["_src/jaxpr_util.py"],
deps = [
":core",
":source_info_util",
":util",
"//jax/_src/lib",
],
)

pytype_strict_library(
name = "mesh",
srcs = ["_src/mesh.py"],
Expand Down
File renamed without changes.
7 changes: 6 additions & 1 deletion tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -266,9 +266,14 @@ py_test(
] + py_deps("tensorflow_core"),
)

jax_test(
py_test(
name = "jaxpr_util_test",
srcs = ["jaxpr_util_test.py"],
deps = [
"//jax",
"//jax:jaxpr_util",
"//jax:test_util",
],
)

jax_test(
Expand Down
7 changes: 4 additions & 3 deletions tests/jaxpr_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest

import os
import gzip
import json

from absl.testing import absltest

import jax
from jax import jaxpr_util, jit, make_jaxpr, numpy as jnp
from jax import jit, make_jaxpr, numpy as jnp
from jax._src import jaxpr_util
from jax._src.lib import xla_client
from jax._src import test_util as jtu
from jax import config
Expand Down

0 comments on commit afd56c1

Please sign in to comment.