diff --git a/jax/BUILD b/jax/BUILD index 6ae1664d41df..fe8c36ca41c4 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -17,15 +17,12 @@ load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") load( "//jaxlib:jax.bzl", - "absl_logging_py_deps", - "absl_testing_py_deps", "jax_extra_deps", "jax_internal_packages", "jax_test_util_visibility", - "numpy_py_deps", + "py_deps", "py_library_providing_imports_info", "pytype_library", - "scipy_py_deps", ) licenses(["notice"]) @@ -73,7 +70,7 @@ py_library( ] + jax_test_util_visibility, deps = [ ":jax", - ] + absl_testing_py_deps + numpy_py_deps, + ] + py_deps("absl/testing") + py_deps("numpy"), ) py_library_providing_imports_info( @@ -118,7 +115,7 @@ py_library_providing_imports_info( ":enable_jaxlib_build": [":jaxlib_deps"], "//conditions:default": [], }) + - numpy_py_deps + scipy_py_deps + jax_extra_deps, + py_deps("numpy") + py_deps("scipy") + jax_extra_deps, ) py_library( @@ -137,7 +134,7 @@ py_library_providing_imports_info( visibility = ["//visibility:public"], deps = [ ":jax", - ] + absl_logging_py_deps + numpy_py_deps, + ] + py_deps("absl/logging") + py_deps("numpy"), ) pytype_library( diff --git a/jax/experimental/jax2tf/BUILD b/jax/experimental/jax2tf/BUILD index 60325c8b1977..750c85addd99 100644 --- a/jax/experimental/jax2tf/BUILD +++ b/jax/experimental/jax2tf/BUILD @@ -15,8 +15,7 @@ load( "//jaxlib:jax.bzl", "jax2tf_deps", - "numpy_py_deps", - "tensorflow_py_deps", + "py_deps", ) licenses(["notice"]) # Apache 2 @@ -44,5 +43,5 @@ py_library( srcs_version = "PY3", deps = [ "//jax", - ] + numpy_py_deps + tensorflow_py_deps + jax2tf_deps, + ] + py_deps("numpy") + py_deps("tensorflow") + jax2tf_deps, ) diff --git a/jax/tools/BUILD b/jax/tools/BUILD index 34e5c2d8ee74..e5921d293634 100644 --- a/jax/tools/BUILD +++ b/jax/tools/BUILD @@ -14,7 +14,7 @@ load( "//jaxlib:jax.bzl", - "tensorflow_py_deps", + "py_deps", ) licenses(["notice"]) @@ -39,5 +39,5 @@ py_library( deps = [ "//jax", "//jax/experimental/jax2tf", - ] + tensorflow_py_deps, + ] + py_deps("tensorflow"), ) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index e4fb819db7ed..f8c6bcfa798e 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -36,14 +36,13 @@ jax_internal_packages = [] jax_test_util_visibility = [] loops_visibility = [] -absl_logging_py_deps = [] -absl_testing_py_deps = [] -cloudpickle_py_deps = [] -numpy_py_deps = [] -pil_py_deps = [] -portpicker_py_deps = [] -scipy_py_deps = [] -tensorflow_py_deps = [] +def py_deps(_package): + """Returns the Bazel deps for Python package `package`.""" + + # We assume the user has installed all dependencies in their Python environment. + # This indirection exists because in Google's internal build we build + # dependencies from source with Bazel, but that's not something most people would want. + return [] jax_extra_deps = [] jax2tf_deps = [] diff --git a/tests/BUILD b/tests/BUILD index 683af96fb45a..9d8ad07fbfce 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -14,16 +14,11 @@ load( "//jaxlib:jax.bzl", - "absl_logging_py_deps", - "cloudpickle_py_deps", "jax_generate_backend_suites", "jax_test", "jax_test_file_visibility", - "pil_py_deps", - "portpicker_py_deps", + "py_deps", "pytype_library", - "scipy_py_deps", - "tensorflow_py_deps", ) licenses(["notice"]) # Apache 2 @@ -53,7 +48,7 @@ jax_test( name = "array_interoperability_test", srcs = ["array_interoperability_test.py"], disable_backends = ["tpu"], - deps = tensorflow_py_deps, + deps = py_deps("tensorflow"), ) jax_test( @@ -111,7 +106,7 @@ py_test( deps = [ "//jax", "//jax:test_util", - ] + portpicker_py_deps, + ] + py_deps("portpicker"), ) jax_test( @@ -155,8 +150,7 @@ jax_test( }, deps = [ "//jax:experimental_sparse", - "//third_party/py/matplotlib", - ], + ] + py_deps("matplotlib"), ) jax_test( @@ -241,7 +235,7 @@ jax_test( "tpu": 10, "iree": 10, }, - deps = pil_py_deps + tensorflow_py_deps, + deps = py_deps("pil") + py_deps("tensorflow"), ) jax_test( @@ -290,7 +284,7 @@ py_test( "//jax:test_util", "//jax/experimental/jax2tf", "//jax/tools:jax_to_ir", - ] + tensorflow_py_deps, + ] + py_deps("tensorflow"), ) jax_test( @@ -519,7 +513,7 @@ jax_test( srcs = ["pickle_test.py"], deps = [ "//jax:experimental", - ] + cloudpickle_py_deps, + ] + py_deps("cloudpickle"), ) jax_test( @@ -674,7 +668,7 @@ jax_test( }, deps = [ "//jax:experimental_sparse", - ] + scipy_py_deps, + ] + py_deps("scipy"), ) jax_test( @@ -750,7 +744,7 @@ py_test( deps = [ "//jax", "//jax:test_util", - ] + absl_logging_py_deps, + ] + py_deps("absl/logging"), ) py_test( @@ -820,7 +814,7 @@ jax_test( deps = [ "//jax:experimental_host_callback", "//jax:ode", - ] + tensorflow_py_deps, + ] + py_deps("tensorflow"), ) jax_test(