Skip to content

Commit

Permalink
Refactor BUILD files to avoid individually naming Python dependencies.
Browse files Browse the repository at this point in the history
Add a parametric py_deps() macro for adding Python package dependencies for Bazel rules.

Fix build failure with dangling matplotlib reference.

PiperOrigin-RevId: 465562141
  • Loading branch information
hawkinsp authored and jax authors committed Aug 5, 2022
1 parent f0b6478 commit b865111
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 36 deletions.
11 changes: 4 additions & 7 deletions jax/BUILD
Expand Up @@ -17,15 +17,12 @@
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
load(
"//jaxlib:jax.bzl",
"absl_logging_py_deps",
"absl_testing_py_deps",
"jax_extra_deps",
"jax_internal_packages",
"jax_test_util_visibility",
"numpy_py_deps",
"py_deps",
"py_library_providing_imports_info",
"pytype_library",
"scipy_py_deps",
)

licenses(["notice"])
Expand Down Expand Up @@ -73,7 +70,7 @@ py_library(
] + jax_test_util_visibility,
deps = [
":jax",
] + absl_testing_py_deps + numpy_py_deps,
] + py_deps("absl/testing") + py_deps("numpy"),
)

py_library_providing_imports_info(
Expand Down Expand Up @@ -118,7 +115,7 @@ py_library_providing_imports_info(
":enable_jaxlib_build": [":jaxlib_deps"],
"//conditions:default": [],
}) +
numpy_py_deps + scipy_py_deps + jax_extra_deps,
py_deps("numpy") + py_deps("scipy") + jax_extra_deps,
)

py_library(
Expand All @@ -137,7 +134,7 @@ py_library_providing_imports_info(
visibility = ["//visibility:public"],
deps = [
":jax",
] + absl_logging_py_deps + numpy_py_deps,
] + py_deps("absl/logging") + py_deps("numpy"),
)

pytype_library(
Expand Down
5 changes: 2 additions & 3 deletions jax/experimental/jax2tf/BUILD
Expand Up @@ -15,8 +15,7 @@
load(
"//jaxlib:jax.bzl",
"jax2tf_deps",
"numpy_py_deps",
"tensorflow_py_deps",
"py_deps",
)

licenses(["notice"]) # Apache 2
Expand Down Expand Up @@ -44,5 +43,5 @@ py_library(
srcs_version = "PY3",
deps = [
"//jax",
] + numpy_py_deps + tensorflow_py_deps + jax2tf_deps,
] + py_deps("numpy") + py_deps("tensorflow") + jax2tf_deps,
)
4 changes: 2 additions & 2 deletions jax/tools/BUILD
Expand Up @@ -14,7 +14,7 @@

load(
"//jaxlib:jax.bzl",
"tensorflow_py_deps",
"py_deps",
)

licenses(["notice"])
Expand All @@ -39,5 +39,5 @@ py_library(
deps = [
"//jax",
"//jax/experimental/jax2tf",
] + tensorflow_py_deps,
] + py_deps("tensorflow"),
)
15 changes: 7 additions & 8 deletions jaxlib/jax.bzl
Expand Up @@ -36,14 +36,13 @@ jax_internal_packages = []
jax_test_util_visibility = []
loops_visibility = []

absl_logging_py_deps = []
absl_testing_py_deps = []
cloudpickle_py_deps = []
numpy_py_deps = []
pil_py_deps = []
portpicker_py_deps = []
scipy_py_deps = []
tensorflow_py_deps = []
def py_deps(_package):
"""Returns the Bazel deps for Python package `package`."""

# We assume the user has installed all dependencies in their Python environment.
# This indirection exists because in Google's internal build we build
# dependencies from source with Bazel, but that's not something most people would want.
return []

jax_extra_deps = []
jax2tf_deps = []
Expand Down
26 changes: 10 additions & 16 deletions tests/BUILD
Expand Up @@ -14,16 +14,11 @@

load(
"//jaxlib:jax.bzl",
"absl_logging_py_deps",
"cloudpickle_py_deps",
"jax_generate_backend_suites",
"jax_test",
"jax_test_file_visibility",
"pil_py_deps",
"portpicker_py_deps",
"py_deps",
"pytype_library",
"scipy_py_deps",
"tensorflow_py_deps",
)

licenses(["notice"]) # Apache 2
Expand Down Expand Up @@ -53,7 +48,7 @@ jax_test(
name = "array_interoperability_test",
srcs = ["array_interoperability_test.py"],
disable_backends = ["tpu"],
deps = tensorflow_py_deps,
deps = py_deps("tensorflow"),
)

jax_test(
Expand Down Expand Up @@ -111,7 +106,7 @@ py_test(
deps = [
"//jax",
"//jax:test_util",
] + portpicker_py_deps,
] + py_deps("portpicker"),
)

jax_test(
Expand Down Expand Up @@ -155,8 +150,7 @@ jax_test(
},
deps = [
"//jax:experimental_sparse",
"//third_party/py/matplotlib",
],
] + py_deps("matplotlib"),
)

jax_test(
Expand Down Expand Up @@ -241,7 +235,7 @@ jax_test(
"tpu": 10,
"iree": 10,
},
deps = pil_py_deps + tensorflow_py_deps,
deps = py_deps("pil") + py_deps("tensorflow"),
)

jax_test(
Expand Down Expand Up @@ -290,7 +284,7 @@ py_test(
"//jax:test_util",
"//jax/experimental/jax2tf",
"//jax/tools:jax_to_ir",
] + tensorflow_py_deps,
] + py_deps("tensorflow"),
)

jax_test(
Expand Down Expand Up @@ -519,7 +513,7 @@ jax_test(
srcs = ["pickle_test.py"],
deps = [
"//jax:experimental",
] + cloudpickle_py_deps,
] + py_deps("cloudpickle"),
)

jax_test(
Expand Down Expand Up @@ -674,7 +668,7 @@ jax_test(
},
deps = [
"//jax:experimental_sparse",
] + scipy_py_deps,
] + py_deps("scipy"),
)

jax_test(
Expand Down Expand Up @@ -750,7 +744,7 @@ py_test(
deps = [
"//jax",
"//jax:test_util",
] + absl_logging_py_deps,
] + py_deps("absl/logging"),
)

py_test(
Expand Down Expand Up @@ -820,7 +814,7 @@ jax_test(
deps = [
"//jax:experimental_host_callback",
"//jax:ode",
] + tensorflow_py_deps,
] + py_deps("tensorflow"),
)

jax_test(
Expand Down

0 comments on commit b865111

Please sign in to comment.