Skip to content

Commit

Permalink
Increase bazel sharding of GPU tests.
Browse files Browse the repository at this point in the history
Reduces the maximum time for some test shards to avoid flaky timeouts.
  • Loading branch information
hawkinsp committed Jul 11, 2022
1 parent b666f66 commit 64e0b5d
Showing 1 changed file with 22 additions and 9 deletions.
31 changes: 22 additions & 9 deletions tests/BUILD
Expand Up @@ -35,7 +35,7 @@ jax_generate_backend_suites()
jax_test(
name = "api_test",
srcs = ["api_test.py"],
shard_count = 5,
shard_count = 10,
)

jax_test(
Expand All @@ -53,6 +53,9 @@ jax_test(
jax_test(
name = "batching_test",
srcs = ["batching_test.py"],
shard_count = {
"gpu": 5,
},
)

jax_test(
Expand All @@ -66,6 +69,7 @@ jax_test(
srcs = ["core_test.py"],
shard_count = {
"cpu": 5,
"gpu": 10,
},
)

Expand Down Expand Up @@ -157,6 +161,7 @@ jax_test(
srcs = ["xmap_test.py"],
shard_count = {
"cpu": 10,
"gpu": 4,
"tpu": 4,
},
tags = ["multiaccelerator"],
Expand Down Expand Up @@ -210,7 +215,7 @@ jax_test(
srcs = ["image_test.py"],
shard_count = {
"cpu": 10,
"gpu": 10,
"gpu": 20,
"tpu": 10,
"iree": 10,
},
Expand Down Expand Up @@ -276,6 +281,7 @@ jax_test(
srcs = ["jet_test.py"],
shard_count = {
"cpu": 10,
"gpu": 10,
},
deps = [
"//jax:jet",
Expand All @@ -288,7 +294,7 @@ jax_test(
srcs = ["lax_control_flow_test.py"],
shard_count = {
"cpu": 10,
"gpu": 10,
"gpu": 20,
"tpu": 10,
"iree": 10,
},
Expand Down Expand Up @@ -361,7 +367,7 @@ jax_test(
},
shard_count = {
"cpu": 10,
"gpu": 10,
"gpu": 40,
"tpu": 10,
"iree": 10,
},
Expand Down Expand Up @@ -446,7 +452,7 @@ jax_test(
},
shard_count = {
"cpu": 20,
"gpu": 20,
"gpu": 40,
"tpu": 10,
"iree": 20,
},
Expand Down Expand Up @@ -504,7 +510,7 @@ jax_test(
srcs = ["pmap_test.py"],
shard_count = {
"cpu": 5,
"gpu": 5,
"gpu": 10,
"tpu": 5,
},
tags = ["multiaccelerator"],
Expand Down Expand Up @@ -579,7 +585,7 @@ jax_test(
main = "random_test.py",
shard_count = {
"cpu": 30,
"gpu": 20,
"gpu": 40,
"tpu": 20,
"iree": 20,
},
Expand Down Expand Up @@ -618,6 +624,7 @@ jax_test(
], # Test times out under asan/tsan.
},
shard_count = {
"gpu": 10,
"tpu": 5,
},
)
Expand All @@ -627,7 +634,7 @@ jax_test(
srcs = ["scipy_stats_test.py"],
shard_count = {
"cpu": 10,
"gpu": 10,
"gpu": 20,
"tpu": 10,
"iree": 10,
},
Expand Down Expand Up @@ -655,7 +662,7 @@ jax_test(
},
shard_count = {
"cpu": 10,
"gpu": 20,
"gpu": 40,
"tpu": 10,
"iree": 10,
},
Expand All @@ -668,6 +675,9 @@ jax_test(
name = "sparsify_test",
srcs = ["sparsify_test.py"],
args = ["--jax_bcoo_cusparse_lowering=true"],
shard_count = {
"gpu": 20,
},
deps = [
"//jax:experimental_sparse",
],
Expand Down Expand Up @@ -787,6 +797,9 @@ jax_test(
"tpu", # On TPU we always use outfeed
],
main = "host_callback_test.py",
shard_count = {
"gpu": 5,
},
deps = [
"//jax:experimental",
"//jax:experimental_host_callback",
Expand Down

0 comments on commit 64e0b5d

Please sign in to comment.