Skip to content

Commit

Permalink
Ensured that all Pallas GPU tests depend on :pallas_gpu
Browse files Browse the repository at this point in the history
This dependency is added implicitly by Google-internal infra, but we need
it to be explicit for Bazel builds to avoid ImportErrors at lowering time.

PiperOrigin-RevId: 633147268
  • Loading branch information
superbobry authored and jax authors committed May 13, 2024
1 parent ba8480a commit 1c6855a
Showing 1 changed file with 9 additions and 12 deletions.
21 changes: 9 additions & 12 deletions tests/pallas/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@ jax_test(
srcs = [
"pallas_test.py",
],
backend_tags = {
"gpu": ["noasan"], # https://github.com/openai/triton/issues/2918
},
config_tags_overrides = {
"gpu_a100_x32": {
"ondemand": False, # Include in presubmit.
Expand All @@ -50,6 +47,7 @@ jax_test(
"gpu_x32",
"gpu_a100",
"gpu_h100",
"gpu_p100",
"gpu_p100_x32",
],
enable_configs = [
Expand All @@ -58,6 +56,7 @@ jax_test(
],
shard_count = 4,
deps = [
"//jax:pallas",
"//jax:pallas_gpu",
] + py_deps("absl/testing") + py_deps("numpy"),
)
Expand All @@ -67,9 +66,6 @@ jax_test(
srcs = [
"gpu_attention_test.py",
],
backend_tags = {
"gpu": ["noasan"], # https://github.com/openai/triton/issues/2918
},
config_tags_overrides = {
"gpu_a100_x32": {
"ondemand": False, # Include in presubmit.
Expand All @@ -83,17 +79,18 @@ jax_test(
"gpu",
"gpu_x32",
"gpu_p100",
"gpu_p100_x32",
"gpu_a100",
"gpu_h100",
],
enable_configs = [
"gpu_a100_x32",
"gpu_p100_x32",
"gpu_h100_x32",
],
shard_count = 1,
deps = [
"//jax:pallas_gpu",
"//jax:pallas",
"//jax:pallas_gpu", # build_cleaner: keep
"//jax:pallas_gpu_ops",
] + py_deps("absl/testing") + py_deps("numpy"),
)
Expand All @@ -103,9 +100,6 @@ jax_test(
srcs = [
"ops_test.py",
],
backend_tags = {
"gpu": ["noasan"], # https://github.com/openai/triton/issues/2918
},
config_tags_overrides = {
"gpu_a100_x32": {
"ondemand": False, # Include in presubmit.
Expand All @@ -128,6 +122,7 @@ jax_test(
],
deps = [
"//jax:pallas",
"//jax:pallas_gpu", # build_cleaner: keep
] + py_deps("absl/testing") + py_deps("numpy"),
)

Expand Down Expand Up @@ -276,6 +271,7 @@ jax_test(
"gpu_x32",
"gpu_a100",
"gpu_h100",
"gpu_p100",
"gpu_p100_x32",
"gpu_pjrt_c_api",
],
Expand All @@ -287,6 +283,7 @@ jax_test(
deps = [
"//jax:internal_export_back_compat_test_data",
"//jax:internal_export_back_compat_test_util",
"//jax:pallas_gpu",
"//jax:pallas",
"//jax:pallas_gpu", # build_cleaner: keep
],
)

0 comments on commit 1c6855a

Please sign in to comment.