Skip to content

Commit

Permalink
Add support for running JAX tests under Bazel.
Browse files Browse the repository at this point in the history
This is an alternative method for running the tests that some users may prefer: pytest is and will remain fully supported.

To use this, one creates a .bazelrc by running the existing `build.py` script, and then one can run the tests by running:
```
bazel test -c opt //tests/...
```

Issue #7323

PiperOrigin-RevId: 458551208
  • Loading branch information
hawkinsp authored and jax authors committed Jul 1, 2022
1 parent 270f73e commit 1fc9afd
Show file tree
Hide file tree
Showing 6 changed files with 1,227 additions and 6 deletions.
239 changes: 239 additions & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# JAX is Autograd and XLA

load(
"//jaxlib:jax.bzl",
"absl_logging_py_deps",
"absl_testing_py_deps",
"jax_extra_deps",
"jax_internal_packages",
"jax_test_util_visibility",
"loops_visibility",
"numpy_py_deps",
"py_library_providing_imports_info",
"pytype_library",
"scipy_py_deps",
"sharded_jit_visibility",
)

licenses(["notice"])

package(default_visibility = [":internal"])

exports_files([
"LICENSE",
"version.py",
])

# Packages that have access to JAX-internal implementation details.
package_group(
name = "internal",
packages = [
"//...",
] + jax_internal_packages,
)

# JAX-private test utilities.
py_library(
# This build target is required in order to use private test utilities in jax._src.test_util,
# and its visibility is intentionally restricted to discourage its use outside JAX itself.
# JAX does provide some public test utilities (see jax/test_util.py);
# these are available in jax.test_util via the standard :jax target.
name = "test_util",
testonly = 1,
srcs = [
"_src/test_util.py",
],
visibility = [
":internal",
] + jax_test_util_visibility,
deps = [
":jax",
] + absl_testing_py_deps + numpy_py_deps,
)

py_library_providing_imports_info(
name = "jax",
srcs = glob(
[
"*.py",
"_src/**/*.py",
"image/**/*.py",
"interpreters/**/*.py",
"lax/**/*.py",
"lib/**/*.py",
"nn/**/*.py",
"numpy/**/*.py",
"ops/**/*.py",
"scipy/**/*.py",
"third_party/**/*.py",
],
exclude = [
"_src/test_util.py",
"*_test.py",
"**/*_test.py",
"interpreters/sharded_jit.py",
],
) + [
# until new parallelism APIs are moved out of experimental
"experimental/maps.py",
"experimental/pjit.py",
"experimental/global_device_array.py",
"experimental/array.py",
"experimental/sharding.py",
"experimental/multihost_utils.py",
# until checkify is moved out of experimental
"experimental/checkify.py",
# to avoid circular dependencies
"experimental/compilation_cache/compilation_cache.py",
"experimental/compilation_cache/gfile_cache.py",
"experimental/compilation_cache/cache_interface.py",
],
lib_rule = pytype_library,
visibility = ["//visibility:public"],
deps = [
"//jaxlib",
] + numpy_py_deps + scipy_py_deps + jax_extra_deps,
)

py_library_providing_imports_info(
name = "experimental",
srcs = glob([
"experimental/*.py",
"example_libraries/*.py",
]),
visibility = ["//visibility:public"],
deps = [
":jax",
":sharded_jit",
] + absl_logging_py_deps + numpy_py_deps,
)

pytype_library(
name = "stax",
srcs = [
"example_libraries/stax.py",
"experimental/stax.py",
],
visibility = ["//visibility:public"],
deps = [":jax"],
)

pytype_library(
name = "experimental_sparse",
srcs = glob([
"experimental/sparse/*.py",
]),
visibility = ["//visibility:public"],
deps = [":jax"],
)

# sharded_jit is deprecated. Please do not add any more projects to the visibility.
pytype_library(
name = "sharded_jit",
srcs = ["interpreters/sharded_jit.py"],
visibility = [
":internal",
] + sharded_jit_visibility,
deps = [":jax"],
)

pytype_library(
name = "optimizers",
srcs = [
"example_libraries/optimizers.py",
"experimental/optimizers.py",
],
visibility = ["//visibility:public"],
deps = [":jax"],
)

pytype_library(
name = "ode",
srcs = ["experimental/ode.py"],
visibility = ["//visibility:public"],
deps = [":jax"],
)

# loops is deprecated. Please do not add any more projects to the visibility.
pytype_library(
name = "loops",
srcs = ["experimental/loops.py"],
visibility = [":internal"] + loops_visibility,
deps = [":jax"],
)

pytype_library(
name = "callback",
srcs = ["experimental/callback.py"],
visibility = ["//visibility:public"],
deps = [":jax"],
)

# TODO(apaszke): Remove this target
pytype_library(
name = "maps",
srcs = ["experimental/maps.py"],
visibility = ["//visibility:public"],
deps = [":jax"],
)

# TODO(apaszke): Remove this target
pytype_library(
name = "pjit",
srcs = ["experimental/pjit.py"],
visibility = ["//visibility:public"],
deps = [
":experimental",
":jax",
],
)

pytype_library(
name = "jet",
srcs = ["experimental/jet.py"],
visibility = ["//visibility:public"],
deps = [":jax"],
)

pytype_library(
name = "experimental_host_callback",
srcs = ["experimental/host_callback.py"],
visibility = ["//visibility:public"],
deps = [
":jax",
],
)

pytype_library(
name = "compilation_cache",
srcs = [
"experimental/compilation_cache/compilation_cache.py",
"experimental/compilation_cache/gfile_cache.py",
],
visibility = ["//visibility:public"],
deps = [":jax"],
)

pytype_library(
name = "mesh_utils",
srcs = ["experimental/mesh_utils.py"],
visibility = ["//visibility:public"],
deps = [
":experimental",
":jax",
],
)
1 change: 0 additions & 1 deletion jax/BUILD.bazel

This file was deleted.

48 changes: 48 additions & 0 deletions jax/experimental/jax2tf/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load(
"//jaxlib:jax.bzl",
"jax2tf_deps",
"numpy_py_deps",
"tensorflow_py_deps",
)

licenses(["notice"]) # Apache 2

package(
default_visibility = ["//visibility:private"],
)

py_library(
name = "jax2tf",
srcs = ["__init__.py"],
srcs_version = "PY3",
visibility = ["//visibility:public"],
deps = [":jax2tf_internal"],
)

py_library(
name = "jax2tf_internal",
srcs = [
"call_tf.py",
"impl_no_xla.py",
"jax2tf.py",
"shape_poly.py",
],
srcs_version = "PY3",
deps = [
"//jax",
] + numpy_py_deps + tensorflow_py_deps + jax2tf_deps,
)
14 changes: 9 additions & 5 deletions jax/tools/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

load(
"//jaxlib:jax.bzl",
"tensorflow_py_deps",
)

licenses(["notice"])

package(default_visibility = ["//visibility:public"])
Expand All @@ -24,16 +29,15 @@ py_library(
"ignore_for_dep=third_party.py.tensorflow",
],
deps = [
"//third_party/py/jax",
"//jax",
],
)

py_library(
name = "jax_to_ir_with_tensorflow",
srcs = ["jax_to_ir.py"],
deps = [
"//third_party/py/jax",
"//third_party/py/jax/experimental/jax2tf",
"//third_party/py/tensorflow",
],
"//jax",
"//jax/experimental/jax2tf",
] + tensorflow_py_deps,
)
54 changes: 54 additions & 0 deletions jaxlib/jax.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,26 @@ if_rocm_is_configured = _if_rocm_is_configured
if_windows = _if_windows
flatbuffer_cc_library = _flatbuffer_cc_library

jax_internal_packages = []
jax_test_util_visibility = []
loops_visibility = []
sharded_jit_visibility = []

absl_logging_py_deps = []
absl_testing_py_deps = []
cloudpickle_py_deps = []
numpy_py_deps = []
pil_py_deps = []
portpicker_py_deps = []
scipy_py_deps = []
tensorflow_py_deps = []

jax_extra_deps = []
jax2tf_deps = []

def py_library_providing_imports_info(*, name, lib_rule = native.py_library, **kwargs):
lib_rule(name = name, **kwargs)

def py_extension(name, srcs, copts, deps):
pybind_extension(name, srcs = srcs, copts = copts, deps = deps, module_name = name)

Expand Down Expand Up @@ -100,3 +120,37 @@ def windows_cc_shared_mlir_library(name, out, deps = [], srcs = []):
interface_library = interface_library_file,
shared_library = out,
)

def jax_test(
name,
srcs,
args = [],
shard_count = None,
deps = [],
disable_backends = None, # buildifier: disable=unused-variable
backend_tags = {}, # buildifier: disable=unused-variable
disable_configs = None, # buildifier: disable=unused-variable
enable_configs = None, # buildifier: disable=unused-variable
tags = [],
main = None):
if shard_count == None or type(shard_count) == type(0):
shards = shard_count
else:
shards = shard_count.get("cpu", None)
native.py_test(
name = name,
srcs = srcs,
args = args,
deps = [
"//jax",
"//jax:test_util",
] + deps,
shard_count = shards,
tags = tags,
main = main,
)

def jax_generate_backend_suites():
pass

jax_test_file_visibility = []

0 comments on commit 1fc9afd

Please sign in to comment.