Skip to content

Commit

Permalink
Document how to run tests using Bazel.
Browse files Browse the repository at this point in the history
* Add a new --configure_only option to build.py to allow build.py to generate a .bazelrc without necessarily building jaxlib.
* Add a bazel flag that make the dependency of //jax on //jaxlib optional. If //jaxlib isn't built by bazel, then tests will implicitly use a preinstalled jaxlib.
  • Loading branch information
hawkinsp committed Jul 6, 2022
1 parent 118db40 commit 1c75eee
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 52 deletions.
84 changes: 52 additions & 32 deletions build/build.py
Expand Up @@ -213,11 +213,13 @@ def get_bazel_version(bazel_path):
return tuple(int(x) for x in match.group(1).split("."))


def write_bazelrc(python_bin_path=None, remote_build=None,
cuda_toolkit_path=None, cudnn_install_path=None,
cuda_version=None, cudnn_version=None, rocm_toolkit_path=None,
cpu=None, cuda_compute_capabilities=None,
rocm_amdgpu_targets=None):
def write_bazelrc(*, python_bin_path, remote_build,
cuda_toolkit_path, cudnn_install_path,
cuda_version, cudnn_version, rocm_toolkit_path,
cpu, cuda_compute_capabilities,
rocm_amdgpu_targets, bazel_options, target_cpu_features,
wheel_cpu, enable_mkl_dnn, enable_cuda, enable_nccl,
enable_tpu, enable_remote_tpu, enable_rocm):
tf_cuda_paths = []

with open("../.jax_configure.bazelrc", "w") as f:
Expand Down Expand Up @@ -263,6 +265,32 @@ def write_bazelrc(python_bin_path=None, remote_build=None,
else:
f.write("build --distinct_host_configuration=false\n")

for o in bazel_options:
f.write(f"common {o}\n")
if target_cpu_features == "release":
if wheel_cpu == "x86_64":
f.write("build --config=avx_windows\n" if is_windows()
else "build --config=avx_posix\n")
elif target_cpu_features == "native":
if is_windows():
print("--target_cpu_features=native is not supported on Windows; ignoring.")
else:
f.write("build --config=native_arch_posix\n")

if enable_mkl_dnn:
f.write("build --config=mkl_open_source_only\n")
if enable_cuda:
f.write("build --config=cuda\n")
if not enable_nccl:
f.write("build --config=nonccl\n")
if enable_tpu:
f.write("build --config=tpu\n")
if enable_remote_tpu:
f.write("build --//build:enable_remote_tpu=true\n")
if enable_rocm:
f.write("build --config=rocm\n")
if not enable_nccl:
f.write("build --config=nonccl\n")

BANNER = r"""
_ _ __ __
Expand Down Expand Up @@ -362,7 +390,7 @@ def main():
parser,
"remote_build",
default=False,
help_str="Should we build with RBE.")
help_str="Should we build with RBE (Remote Build Environment)?")
parser.add_argument(
"--cuda_path",
default=None,
Expand Down Expand Up @@ -410,6 +438,11 @@ def main():
default=None,
help="CPU platform to target. Default is the same as the host machine. "
"Currently supported values are 'darwin_arm64' and 'darwin_x86_64'.")
add_boolean_argument(
parser,
"configure_only",
default=False,
help_str="If true, writes a .bazelrc file but does not build jaxlib.")
args = parser.parse_args()

if is_windows() and args.enable_cuda:
Expand Down Expand Up @@ -491,38 +524,25 @@ def main():
cpu=args.target_cpu,
cuda_compute_capabilities=args.cuda_compute_capabilities,
rocm_amdgpu_targets=args.rocm_amdgpu_targets,
bazel_options=args.bazel_options,
target_cpu_features=args.target_cpu_features,
wheel_cpu=wheel_cpu,
enable_mkl_dnn=args.enable_mkl_dnn,
enable_cuda=args.enable_cuda,
enable_nccl=args.enable_nccl,
enable_tpu=args.enable_tpu,
enable_remote_tpu=args.enable_remote_tpu,
enable_rocm=args.enable_rocm,
)

print("\nBuilding XLA and installing it in the jaxlib source tree...")
if args.configure_only:
return

config_args = args.bazel_options
if args.target_cpu_features == "release":
if wheel_cpu == "x86_64":
config_args += ["--config=avx_windows" if is_windows()
else "--config=avx_posix"]
elif args.target_cpu_features == "native":
if is_windows():
print("--target_cpu_features=native is not supported on Windows; ignoring.")
else:
config_args += ["--config=native_arch_posix"]
print("\nBuilding XLA and installing it in the jaxlib source tree...")

if args.enable_mkl_dnn:
config_args += ["--config=mkl_open_source_only"]
if args.enable_cuda:
config_args += ["--config=cuda"]
if not args.enable_nccl:
config_args += ["--config=nonccl"]
if args.enable_tpu:
config_args += ["--config=tpu"]
if args.enable_remote_tpu:
config_args += ["--//build:enable_remote_tpu=true"]
if args.enable_rocm:
config_args += ["--config=rocm"]
if not args.enable_nccl:
config_args += ["--config=nonccl"]

command = ([bazel_path] + args.bazel_startup_options +
["run", "--verbose_failures=true"] + config_args +
["run", "--verbose_failures=true"] +
[":build_wheel", "--",
f"--output_path={output_path}",
f"--cpu={wheel_cpu}"])
Expand Down
68 changes: 51 additions & 17 deletions docs/developer.md
Expand Up @@ -147,24 +147,68 @@ sets up symbolic links from site-packages into the repository.

# Running the tests

To run all the JAX tests, we recommend using `pytest-xdist`, which can run tests in
parallel. First, install `pytest-xdist` and `pytest-benchmark` by running
`pip install -r build/test-requirements.txt`.
There are two supported mechanisms for running the JAX tests, either using Bazel
or using pytest.

## Using Bazel

First, configure the JAX build by running:
```
python build/build.py --configure_only
```

You may pass additional options to `build.py` to configure the build; see the
`jaxlib` build documentation for details.

By default the Bazel build runs the JAX tests using `jaxlib` built form source.
To run JAX tests, run:

```
bazel test //tests/...
```

To use a preinstalled `jaxlib` instead of building `jaxlib` from source, run

```
bazel test --//jax:build_jaxlib=false //tests/...
```


A number of test behaviors can be controlled using environment variables (see
below). Environment variables may be passed to JAX tests using the
`--test_env=FLAG=value` flag to Bazel.

## Using pytest

To run all the JAX tests using `pytest`, we recommend using `pytest-xdist`,
which can run tests in parallel. First, install `pytest-xdist` and
`pytest-benchmark` by running `pip install -r build/test-requirements.txt`.
Then, from the repository root directory run:

```
pytest -n auto tests
```

## Controlling test behavior

JAX generates test cases combinatorially, and you can control the number of
cases that are generated and checked for each test (default is 10). The automated tests
currently use 25:
cases that are generated and checked for each test (default is 10) using the
`JAX_NUM_GENERATED_CASES` environment variable. The automated tests
currently use 25 by default.

For example, one might write
```
# Bazel
bazel test //tests/... --test_env=JAX_NUM_GENERATED_CASES=25`
```
or
```
# pytest
JAX_NUM_GENERATED_CASES=25 pytest -n auto tests
```

The automated tests also run the tests with default 64-bit floats and ints:
The automated tests also run the tests with default 64-bit floats and ints
(`JAX_ENABLE_X64`):

```
JAX_ENABLE_X64=1 JAX_NUM_GENERATED_CASES=25 pytest -n auto tests
Expand All @@ -179,7 +223,7 @@ file directly to see more detailed information about the cases being run:
python tests/lax_numpy_test.py --num_generated_cases=5
```

You can skip a few tests known as slow, by passing environment variable
You can skip a few tests known to be slow, by passing environment variable
JAX_SKIP_SLOW_TESTS=1.

To specify a particular set of tests to run from a test file, you can pass a string
Expand All @@ -192,16 +236,6 @@ python tests/lax_numpy_test.py --test_targets="testPad"

The Colab notebooks are tested for errors as part of the documentation build.

Note that to run the full pmap tests on a (multi-core) CPU-only machine, you
can run:

```
pytest tests/pmap_tests.py
```

I.e. don't use the `-n auto` option, since that effectively runs each test on a
single-core worker.

## Doctests
JAX uses pytest in doctest mode to test the code examples within the documentation.
You can run this using
Expand Down
1 change: 1 addition & 0 deletions examples/jax_cpp/BUILD
Expand Up @@ -35,4 +35,5 @@ tf_cc_binary(
"@org_tensorflow//tensorflow/core/platform:logging",
"@org_tensorflow//tensorflow/core/platform:platform_port",
],
tags = ["manual"],
)
23 changes: 20 additions & 3 deletions jax/BUILD
Expand Up @@ -29,10 +29,25 @@ load(
"sharded_jit_visibility",
)

load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")

licenses(["notice"])

package(default_visibility = [":internal"])

bool_flag(
name = "build_jaxlib",
build_setting_default = True,
)


config_setting(
name = "enable_jaxlib_build",
flag_values = {
":build_jaxlib": "True",
},
)

exports_files([
"LICENSE",
"version.py",
Expand Down Expand Up @@ -104,9 +119,11 @@ py_library_providing_imports_info(
],
lib_rule = pytype_library,
visibility = ["//visibility:public"],
deps = [
"//jaxlib",
] + numpy_py_deps + scipy_py_deps + jax_extra_deps,
deps = select({
":enable_jaxlib_build": ["//jaxlib"],
"//conditions:default": [],
}) +
numpy_py_deps + scipy_py_deps + jax_extra_deps,
)

py_library_providing_imports_info(
Expand Down
6 changes: 6 additions & 0 deletions jaxlib/jax.bzl
Expand Up @@ -77,6 +77,7 @@ def windows_cc_shared_mlir_library(name, out, deps = [], srcs = []):
linkshared = 1,
linkstatic = 1,
deps = deps,
target_compatible_with = ["@platforms//os:windows"],
)

# .def file with all symbols, not usable
Expand All @@ -85,6 +86,7 @@ def windows_cc_shared_mlir_library(name, out, deps = [], srcs = []):
name = full_def_name,
srcs = [dummy_library_name],
output_group = "def_file",
target_compatible_with = ["@platforms//os:windows"],
)

# filtered def_file, only the needed symbols are included
Expand All @@ -95,6 +97,7 @@ def windows_cc_shared_mlir_library(name, out, deps = [], srcs = []):
srcs = [full_def_name],
outs = [filtered_def_file],
cmd = """echo 'LIBRARY {}\nEXPORTS ' > $@ && grep '^\\W*mlir' $(location :{}) >> $@""".format(out, full_def_name),
target_compatible_with = ["@platforms//os:windows"],
)

# create the desired library
Expand All @@ -103,6 +106,7 @@ def windows_cc_shared_mlir_library(name, out, deps = [], srcs = []):
linkshared = 1,
deps = deps,
win_def_file = filtered_def_file,
target_compatible_with = ["@platforms//os:windows"],
)

# however, the created cc_library (a shared library) cannot be correctly
Expand All @@ -112,13 +116,15 @@ def windows_cc_shared_mlir_library(name, out, deps = [], srcs = []):
name = interface_library_file,
srcs = [out],
output_group = "interface_library",
target_compatible_with = ["@platforms//os:windows"],
)

# but this one can be correctly consumed, this is our final product
native.cc_import(
name = name,
interface_library = interface_library_file,
shared_library = out,
target_compatible_with = ["@platforms//os:windows"],
)

def jax_test(
Expand Down

0 comments on commit 1c75eee

Please sign in to comment.