diff --git a/docs/scala_toolchain.md b/docs/scala_toolchain.md
index d0e2d75b3..9bcc2b82c 100644
--- a/docs/scala_toolchain.md
+++ b/docs/scala_toolchain.md
@@ -63,9 +63,6 @@ scala_register_toolchains()
Extra compiler options for this binary to be passed to scalac.
-
- This is overridden by the `scalac_jvm_flags` attribute on individual targets.
-
@@ -75,6 +72,21 @@ scala_register_toolchains()
List of JVM flags to be passed to scalac. For example ["-Xmx5G"] could be passed to control memory usage of Scalac.
+
+ This is overridden by the scalac_jvm_flags attribute on individual targets.
+
+
+
+
+ scala_test_jvm_flags |
+
+ List of strings; optional
+
+ List of JVM flags to be passed to the ScalaTest runner. For example ["-Xmx5G"] could be passed to control memory usage of the ScalaTest runner.
+
+
+ This is overridden by the jvm_flags attribute on individual targets.
+
|
diff --git a/scala/private/rule_impls.bzl b/scala/private/rule_impls.bzl
index 96a96bae4..c92542482 100644
--- a/scala/private/rule_impls.bzl
+++ b/scala/private/rule_impls.bzl
@@ -133,6 +133,13 @@ def _expand_location(ctx, flags):
def _join_path(args, sep = ","):
return sep.join([f.path for f in args])
+# Return the first non-empty arg. If all are empty, return the last.
+def _first_non_empty(*args):
+ for arg in args:
+ if arg:
+ return arg
+ return args[-1]
+
def compile_scala(
ctx,
target_label,
@@ -296,10 +303,10 @@ StatsfileOutput: {statsfile_output}
# scalac_jvm_flags passed in on the target override scalac_jvm_flags passed in on the
# toolchain
- if scalac_jvm_flags:
- final_scalac_jvm_flags = _expand_location(ctx, scalac_jvm_flags)
- else:
- final_scalac_jvm_flags = ctx.toolchains["@io_bazel_rules_scala//scala:toolchain_type"].scalac_jvm_flags
+ final_scalac_jvm_flags = _first_non_empty(
+ scalac_jvm_flags,
+ ctx.toolchains["@io_bazel_rules_scala//scala:toolchain_type"].scalac_jvm_flags
+ )
ctx.actions.run(
inputs = ins,
@@ -319,7 +326,7 @@ StatsfileOutput: {statsfile_output}
# consume the flags on startup.
arguments = [
"--jvm_flag=%s" % f
- for f in final_scalac_jvm_flags
+ for f in _expand_location(ctx, final_scalac_jvm_flags)
] + ["@" + argfile.path],
)
@@ -1202,13 +1209,20 @@ def scala_test_impl(ctx):
])
coverage_runfiles = ctx.files._jacocorunner + ctx.files._lcov_merger + coverage_replacements.values()
+ # jvm_flags passed in on the target override scala_test_jvm_flags passed in on the
+ # toolchain
+ final_jvm_flags = _first_non_empty(
+ ctx.attr.jvm_flags,
+ ctx.toolchains["@io_bazel_rules_scala//scala:toolchain_type"].scala_test_jvm_flags
+ )
+
coverage_runfiles.extend(_write_executable(
ctx = ctx,
executable = executable,
jvm_flags = [
"-DRULES_SCALA_MAIN_WS_NAME=%s" % ctx.workspace_name,
"-DRULES_SCALA_ARGS_FILE=%s" % argsFile.short_path,
- ] + ctx.attr.jvm_flags,
+ ] + _expand_location(ctx, final_jvm_flags),
main_class = ctx.attr.main_class,
rjars = rjars,
use_jacoco = ctx.configuration.coverage_enabled,
diff --git a/scala/scala_toolchain.bzl b/scala/scala_toolchain.bzl
index 0faab9151..f57e23302 100644
--- a/scala/scala_toolchain.bzl
+++ b/scala/scala_toolchain.bzl
@@ -11,6 +11,7 @@ def _scala_toolchain_impl(ctx):
plus_one_deps_mode = ctx.attr.plus_one_deps_mode,
enable_code_coverage_aspect = ctx.attr.enable_code_coverage_aspect,
scalac_jvm_flags = ctx.attr.scalac_jvm_flags,
+ scala_test_jvm_flags = ctx.attr.scala_test_jvm_flags,
)
return [toolchain]
@@ -35,5 +36,6 @@ scala_toolchain = rule(
values = ["off", "on"],
),
"scalac_jvm_flags": attr.string_list(),
+ "scala_test_jvm_flags": attr.string_list(),
},
)
diff --git a/test_expect_failure/scala_test_jvm_flags/BUILD b/test_expect_failure/scala_test_jvm_flags/BUILD
new file mode 100644
index 000000000..f2bb261c2
--- /dev/null
+++ b/test_expect_failure/scala_test_jvm_flags/BUILD
@@ -0,0 +1,43 @@
+load("//scala:scala_toolchain.bzl", "scala_toolchain")
+load("//scala:scala.bzl", "scala_test")
+
+scala_toolchain(
+ name = "failing_toolchain_impl",
+ # This will fail because 1M isn't enough
+ scala_test_jvm_flags = ["-Xmx1M"],
+ visibility = ["//visibility:public"],
+)
+
+scala_toolchain(
+ name = "passing_toolchain_impl",
+ # This will pass because 1G is enough
+ scala_test_jvm_flags = ["-Xmx1G"],
+ visibility = ["//visibility:public"],
+)
+
+toolchain(
+ name = "failing_scala_toolchain",
+ toolchain = "failing_toolchain_impl",
+ toolchain_type = "@io_bazel_rules_scala//scala:toolchain_type",
+ visibility = ["//visibility:public"],
+)
+
+toolchain(
+ name = "passing_scala_toolchain",
+ toolchain = "passing_toolchain_impl",
+ toolchain_type = "@io_bazel_rules_scala//scala:toolchain_type",
+ visibility = ["//visibility:public"],
+)
+
+scala_test(
+ name = "empty_test",
+ srcs = ["EmptyTest.scala"],
+)
+
+scala_test(
+ name = "empty_overriding_test",
+ srcs = ["EmptyTest.scala"],
+ # This overrides the option passed in on the toolchain, and should BUILD, even if
+ # the `failing_scala_toolchain` is used.
+ jvm_flags = ["-Xmx1G"],
+)
diff --git a/test_expect_failure/scala_test_jvm_flags/EmptyTest.scala b/test_expect_failure/scala_test_jvm_flags/EmptyTest.scala
new file mode 100644
index 000000000..d1fbfc7a0
--- /dev/null
+++ b/test_expect_failure/scala_test_jvm_flags/EmptyTest.scala
@@ -0,0 +1,9 @@
+package test_expect_failure.scala_test_jvm_flags
+
+import org.scalatest.FunSuite
+
+class EmptyTest extends FunSuite {
+ test("empty test") {
+ assert(true)
+ }
+}
\ No newline at end of file
diff --git a/test_rules_scala.sh b/test_rules_scala.sh
index f76542e91..9304097cd 100755
--- a/test_rules_scala.sh
+++ b/test_rules_scala.sh
@@ -837,6 +837,18 @@ test_scalac_jvm_flags_on_target_overrides_toolchain_passes() {
bazel build --extra_toolchains="//test_expect_failure/scalac_jvm_opts:failing_scala_toolchain" //test_expect_failure/scalac_jvm_opts:empty_overriding_build
}
+test_scala_test_jvm_flags_from_scala_toolchain_fails() {
+ action_should_fail test --extra_toolchains="//test_expect_failure/scala_test_jvm_flags:failing_scala_toolchain" //test_expect_failure/scala_test_jvm_flags:empty_test
+}
+
+test_scala_test_jvm_flags_from_scala_toolchain_passes() {
+ bazel test --extra_toolchains="//test_expect_failure/scala_test_jvm_flags:passing_scala_toolchain" //test_expect_failure/scala_test_jvm_flags:empty_test
+}
+
+test_scala_test_jvm_flags_on_target_overrides_toolchain_passes() {
+ bazel test --extra_toolchains="//test_expect_failure/scala_test_jvm_flags:failing_scala_toolchain" //test_expect_failure/scala_test_jvm_flags:empty_overriding_test
+}
+
test_unused_dependency_checker_mode_set_in_rule() {
action_should_fail build //test_expect_failure/unused_dependency_checker:failing_build
}
@@ -1118,4 +1130,7 @@ $runner test_coverage_on
$runner scala_pb_library_targets_do_not_have_host_deps
$runner test_scalac_jvm_flags_on_target_overrides_toolchain_passes
$runner test_scalac_jvm_flags_from_scala_toolchain_passes
-$runner test_scalac_jvm_flags_from_scala_toolchain_fails
\ No newline at end of file
+$runner test_scalac_jvm_flags_from_scala_toolchain_fails
+$runner test_scala_test_jvm_flags_on_target_overrides_toolchain_passes
+$runner test_scala_test_jvm_flags_from_scala_toolchain_passes
+$runner test_scala_test_jvm_flags_from_scala_toolchain_fails
\ No newline at end of file