diff --git a/WORKSPACE b/WORKSPACE index 15fd16df2..2460e4edc 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -56,6 +56,32 @@ load("//test/proto_cross_repo_boundary:repo.bzl", "proto_cross_repo_boundary_rep proto_cross_repo_boundary_repository() +# test sbt testing frameworks +scala_maven_import_external( + name = "org_scalacheck_scalacheck", + artifact = scala_mvn_artifact( + "org.scalacheck:scalacheck:1.14.3", + default_scala_major_version(), + ), + artifact_sha256 = "3cbc95bb615f1a384b8c4406dfc42b225499f08adf7639de11566069e47d44cf", + licenses = ["notice"], # Apache 2.0 + server_urls = [ + "https://repo1.maven.org/maven2/", + "https://mirror.bazel.build/repo1.maven.org/maven2", + ], +) + +scala_maven_import_external( + name = "com_novocode_junit_interface", + artifact = "com.novocode:junit-interface:0.11", + artifact_sha256 = "29e923226a0d10e9142bbd81073ef52f601277001fcf9014389bf0af3dc33dc3", + licenses = ["notice"], # Apache 2.0 + server_urls = [ + "https://repo1.maven.org/maven2/", + "https://mirror.bazel.build/repo1.maven.org/maven2", + ], +) + # test adding a scala jar: jvm_maven_import_external( name = "com_twitter__scalding_date", diff --git a/scala/defs.bzl b/scala/defs.bzl new file mode 100644 index 000000000..03eb6b3ab --- /dev/null +++ b/scala/defs.bzl @@ -0,0 +1,40 @@ +"""Starlark rules for building Scala projects. + +These are the core rules (library, binary, test) under active +development. Their APIs are not guaranteed stable and we anticipate +some breaking changes. + +We do not yet recommend using these APIs for production codebases. Instead, +use the stable rules exported by scala.bzl: + +``` +load( + "@io_bazel_rules_scala//scala:scala.bzl", + "scala_library", + "scala_binary", + "scala_test" +) +``` + +""" + +load( + "@io_bazel_rules_scala//scala/private:rules/scala_binary.bzl", + _make_scala_binary = "make_scala_binary", +) +load( + "@io_bazel_rules_scala//scala/private:rules/scala_library.bzl", + _make_scala_library = "make_scala_library", +) +load( + "@io_bazel_rules_scala//scala/private:rules/unstable_scala_test.bzl", + _make_scala_test = "make_scala_test", +) + +make_scala_library = _make_scala_library +make_scala_binary = _make_scala_binary +make_scala_test = _make_scala_test + +scala_library = _make_scala_library() +scala_binary = _make_scala_binary() +scala_test = _make_scala_test() diff --git a/scala/private/macros/scala_repositories.bzl b/scala/private/macros/scala_repositories.bzl index ad964498e..41e07e49e 100644 --- a/scala/private/macros/scala_repositories.bzl +++ b/scala/private/macros/scala_repositories.bzl @@ -139,6 +139,23 @@ def scala_repositories( fetch_sources = fetch_sources, ) + # used by the experimental scala_test rule + _scala_maven_import_external( + name = "io_bazel_rules_scala_classgraph", + artifact = "io.github.classgraph:classgraph:jar:4.8.60", + artifact_sha256 = "dacf7d7fec4088e674ee98155adbb74f30af2f8b64f8990d37c223d8b9047b72", + licenses = ["notice"], + server_urls = maven_servers, + ) + + _scala_maven_import_external( + name = "io_bazel_rules_scala_test_interface", + artifact = "org.scala-sbt:test-interface:jar:1.0", + artifact_sha256 = "15f70b38bb95f3002fec9aea54030f19bb4ecfbad64c67424b5e5fea09cd749e", + licenses = ["notice"], + server_urls = maven_servers, + ) + if not native.existing_rule("com_google_protobuf"): http_archive( name = "com_google_protobuf", diff --git a/scala/private/phases/phase_collect_jars.bzl b/scala/private/phases/phase_collect_jars.bzl index a353e0ac7..108f766f4 100644 --- a/scala/private/phases/phase_collect_jars.bzl +++ b/scala/private/phases/phase_collect_jars.bzl @@ -18,6 +18,15 @@ def phase_collect_jars_scalatest(ctx, p): ) return _phase_collect_jars_default(ctx, p, args) +def phase_collect_jars_unstable_scala_test(ctx, p): + args = struct( + base_classpath = p.scalac_provider.default_classpath, + extra_runtime_deps = [ + ctx.attr._discover_tests_runner, + ], + ) + return _phase_collect_jars_default(ctx, p, args) + def phase_collect_jars_repl(ctx, p): args = struct( base_classpath = p.scalac_provider.default_repl_classpath, @@ -45,6 +54,21 @@ def phase_collect_jars_common(ctx, p): return _phase_collect_jars_default(ctx, p) def _phase_collect_jars_default(ctx, p, _args = struct()): + extra_deps = [] + extra_runtime_deps = [] + + phase_names = dir(p) + phase_names.remove("to_json") + phase_names.remove("to_proto") + for phase_name in phase_names: + phase = getattr(p, phase_name) + + if hasattr(phase, "extra_deps"): + extra_deps.extend(phase.extra_deps) + + if hasattr(phase, "extra_runtime_deps"): + extra_runtime_deps.extend(phase.extra_runtime_deps) + return _phase_collect_jars( ctx, p, diff --git a/scala/private/phases/phase_discover_tests.bzl b/scala/private/phases/phase_discover_tests.bzl new file mode 100644 index 000000000..225b867b4 --- /dev/null +++ b/scala/private/phases/phase_discover_tests.bzl @@ -0,0 +1,34 @@ +def phase_discover_tests(ctx, p): + worker = ctx.attr._discover_tests_worker + worker_inputs, _, worker_input_manifests = ctx.resolve_command( + tools = [worker], + ) + + output = ctx.actions.declare_file("{}_discovered_tests.bin".format(ctx.label.name)) + + args = ctx.actions.args() + args.set_param_file_format("multiline") + args.use_param_file("@%s", use_always = True) + + args.add(output) + args.add_all(p.compile.files) + args.add("--") + args.add_all(p.collect_jars.transitive_runtime_jars) + + ctx.actions.run( + mnemonic = "DiscoverTests", + inputs = worker_inputs + p.collect_jars.compile_jars.to_list() + p.compile.files.to_list(), + outputs = [output], + executable = worker.files_to_run.executable, + input_manifests = worker_input_manifests, + execution_requirements = {"supports-workers": "1"}, + arguments = [args], + ) + + return struct( + files = depset([output]), + jvm_flags = [ + "-DDiscoveredTestsResult={}".format(output.short_path), + ], + runfiles = depset([output]), + ) diff --git a/scala/private/phases/phase_write_executable.bzl b/scala/private/phases/phase_write_executable.bzl index 3c8c11fbc..18243c115 100644 --- a/scala/private/phases/phase_write_executable.bzl +++ b/scala/private/phases/phase_write_executable.bzl @@ -51,13 +51,24 @@ def phase_write_executable_common(ctx, p): return _phase_write_executable_default(ctx, p) def _phase_write_executable_default(ctx, p, _args = struct()): + jvm_flags = [] + + phase_names = dir(p) + phase_names.remove("to_json") + phase_names.remove("to_proto") + for phase_name in phase_names: + phase = getattr(p, phase_name) + + if hasattr(phase, "jvm_flags"): + jvm_flags.extend(phase.jvm_flags) + return _phase_write_executable( ctx, p, _args.rjars if hasattr(_args, "rjars") else p.compile.rjars, - _args.jvm_flags if hasattr(_args, "jvm_flags") else ctx.attr.jvm_flags, + (_args.jvm_flags if hasattr(_args, "jvm_flags") else ctx.attr.jvm_flags) + jvm_flags, _args.use_jacoco if hasattr(_args, "use_jacoco") else False, - _args.main_class if hasattr(_args, "main_class") else ctx.attr.main_class, + _args.main_class if hasattr(_args, "main_class") else ctx.attr._main_class if hasattr(ctx.attr, "_main_class") else ctx.attr.main_class, ) def _phase_write_executable( diff --git a/scala/private/phases/phases.bzl b/scala/private/phases/phases.bzl index 3f7ff7f06..929a59f7d 100644 --- a/scala/private/phases/phases.bzl +++ b/scala/private/phases/phases.bzl @@ -26,6 +26,7 @@ load( _phase_collect_jars_macro_library = "phase_collect_jars_macro_library", _phase_collect_jars_repl = "phase_collect_jars_repl", _phase_collect_jars_scalatest = "phase_collect_jars_scalatest", + _phase_collect_jars_unstable_scala_test = "phase_collect_jars_unstable_scala_test", ) load( "@io_bazel_rules_scala//scala/private:phases/phase_compile.bzl", @@ -63,6 +64,7 @@ load("@io_bazel_rules_scala//scala/private:phases/phase_declare_executable.bzl", load("@io_bazel_rules_scala//scala/private:phases/phase_merge_jars.bzl", _phase_merge_jars = "phase_merge_jars") load("@io_bazel_rules_scala//scala/private:phases/phase_jvm_flags.bzl", _phase_jvm_flags = "phase_jvm_flags") load("@io_bazel_rules_scala//scala/private:phases/phase_coverage_runfiles.bzl", _phase_coverage_runfiles = "phase_coverage_runfiles") +load("@io_bazel_rules_scala//scala/private:phases/phase_discover_tests.bzl", _phase_discover_tests = "phase_discover_tests") load("@io_bazel_rules_scala//scala/private:phases/phase_scalafmt.bzl", _phase_scalafmt = "phase_scalafmt") # API @@ -112,6 +114,7 @@ phase_java_wrapper_repl = _phase_java_wrapper_repl phase_java_wrapper_common = _phase_java_wrapper_common # collect_jars +phase_collect_jars_unstable_scala_test = _phase_collect_jars_unstable_scala_test phase_collect_jars_scalatest = _phase_collect_jars_scalatest phase_collect_jars_repl = _phase_collect_jars_repl phase_collect_jars_macro_library = _phase_collect_jars_macro_library @@ -136,5 +139,8 @@ phase_runfiles_common = _phase_runfiles_common # default_info phase_default_info = _phase_default_info +# discover_tests +phase_discover_tests = _phase_discover_tests + # scalafmt phase_scalafmt = _phase_scalafmt diff --git a/scala/private/rules/unstable_scala_test.bzl b/scala/private/rules/unstable_scala_test.bzl new file mode 100644 index 000000000..b29dabf37 --- /dev/null +++ b/scala/private/rules/unstable_scala_test.bzl @@ -0,0 +1,98 @@ +"""Rules for writing tests with ScalaTest""" + +load("@bazel_skylib//lib:dicts.bzl", _dicts = "dicts") +load( + "@io_bazel_rules_scala//scala/private:common_attributes.bzl", + "common_attrs", + "implicit_deps", + "launcher_template", +) +load("@io_bazel_rules_scala//scala/private:common.bzl", "sanitize_string_for_usage") +load("@io_bazel_rules_scala//scala/private:common_outputs.bzl", "common_outputs") +load( + "@io_bazel_rules_scala//scala/private:phases/phases.bzl", + "extras_phases", + "phase_collect_jars_unstable_scala_test", + "phase_compile_common", + "phase_coverage_common", + "phase_coverage_runfiles", + "phase_declare_executable", + "phase_default_info", + "phase_dependency_common", + "phase_discover_tests", + "phase_java_wrapper_common", + "phase_merge_jars", + "phase_runfiles_scalatest", + "phase_scalac_provider", + "phase_write_executable_scalatest", + "phase_write_manifest", + "run_phases", +) + +def _scala_test_impl(ctx): + return run_phases( + ctx, + # customizable phases + [ + ("scalac_provider", phase_scalac_provider), + ("write_manifest", phase_write_manifest), + ("dependency", phase_dependency_common), + ("collect_jars", phase_collect_jars_unstable_scala_test), + ("java_wrapper", phase_java_wrapper_common), + ("declare_executable", phase_declare_executable), + # no need to build an ijar for an executable + ("compile", phase_compile_common), + ("coverage", phase_coverage_common), + ("merge_jars", phase_merge_jars), + ("runfiles", phase_runfiles_scalatest), + ("coverage_runfiles", phase_coverage_runfiles), + ("discover_tests", phase_discover_tests), + ("write_executable", phase_write_executable_scalatest), + ("default_info", phase_default_info), + ], + ) + +_scala_test_attrs = { + "_main_class": attr.string( + default = "io.bazel.rules_scala.discover_tests_runner.DiscoverTestsRunner", + ), + "colors": attr.bool(default = True), + "full_stacktraces": attr.bool(default = True), + "jvm_flags": attr.string_list(), + "_jacocorunner": attr.label( + default = Label("@bazel_tools//tools/jdk:JacocoCoverage"), + ), + "_lcov_merger": attr.label( + default = Label("@bazel_tools//tools/test/CoverageOutputGenerator/java/com/google/devtools/coverageoutputgenerator:Main"), + ), + "_discover_tests_worker": attr.label( + default = Label("@io_bazel_rules_scala//src/scala/io/bazel/rules_scala/discover_tests_worker"), + ), + "_discover_tests_runner": attr.label( + default = Label("@io_bazel_rules_scala//src/scala/io/bazel/rules_scala/discover_tests_runner"), + ), +} + +_scala_test_attrs.update(launcher_template) + +_scala_test_attrs.update(implicit_deps) + +_scala_test_attrs.update(common_attrs) + +def make_scala_test(*extras): + return rule( + attrs = _dicts.add( + _scala_test_attrs, + extras_phases(extras), + *[extra["attrs"] for extra in extras if "attrs" in extra] + ), + executable = True, + fragments = ["java"], + outputs = _dicts.add( + common_outputs, + *[extra["outputs"] for extra in extras if "outputs" in extra] + ), + test = True, + toolchains = ["@io_bazel_rules_scala//scala:toolchain_type"], + implementation = _scala_test_impl, + ) diff --git a/src/scala/io/bazel/rules_scala/discover_tests_runner/BUILD b/src/scala/io/bazel/rules_scala/discover_tests_runner/BUILD new file mode 100644 index 000000000..54081261c --- /dev/null +++ b/src/scala/io/bazel/rules_scala/discover_tests_runner/BUILD @@ -0,0 +1,12 @@ +load("//scala:defs.bzl", "scala_library") + +scala_library( + name = "discover_tests_runner", + srcs = ["DiscoverTestsRunner.scala"], + visibility = ["//visibility:public"], + deps = [ + "//external:io_bazel_rules_scala/dependency/com_google_protobuf/protobuf_java", + "//src/scala/io/bazel/rules_scala/discover_tests_worker:discovered_tests_java_proto", + "@io_bazel_rules_scala_test_interface//jar", + ], +) diff --git a/src/scala/io/bazel/rules_scala/discover_tests_runner/DiscoverTestsRunner.scala b/src/scala/io/bazel/rules_scala/discover_tests_runner/DiscoverTestsRunner.scala new file mode 100644 index 000000000..e03734547 --- /dev/null +++ b/src/scala/io/bazel/rules_scala/discover_tests_runner/DiscoverTestsRunner.scala @@ -0,0 +1,106 @@ +package io.bazel.rules_scala.discover_tests_runner + +import io.bazel.rules_scala.discover_tests_worker.DiscoveredTests.Result +import io.bazel.rules_scala.discover_tests_worker.DiscoveredTests.FrameworkDiscovery +import io.bazel.rules_scala.discover_tests_worker.DiscoveredTests.SubclassDiscovery + +import sbt.testing.AnnotatedFingerprint +import sbt.testing.Event +import sbt.testing.EventHandler +import sbt.testing.Fingerprint +import sbt.testing.Framework +import sbt.testing.Logger +import sbt.testing.Runner +import sbt.testing.SubclassFingerprint +import sbt.testing.SuiteSelector +import sbt.testing.Task +import sbt.testing.TaskDef + +import java.io.FileInputStream +import java.nio.file.Paths + +import scala.collection.JavaConverters._ +import scala.annotation.tailrec + +/** + * DiscoverTestsRunner is responsible for running tests discovered by + * the DiscoverTestsWorker. + */ +object DiscoverTestsRunner { + + def main(args: Array[String]): Unit = { + val input = new FileInputStream(Paths.get(sys.props("DiscoveredTestsResult")).toFile) + val result: Result = Result.parseFrom(input) + input.close() + + result.getFrameworkDiscoveriesList.asScala + .foreach(frameworkDiscovery => handleFrameworkDiscovery(frameworkDiscovery, args)) + + sys.exit(0) + } + + def handleFrameworkDiscovery(frameworkDiscovery: FrameworkDiscovery, args: Array[String]): Unit = { + print(s"\n> beginning run of ${frameworkDiscovery.getFramework}\n") + val framework: Framework = Class.forName(frameworkDiscovery.getFramework).newInstance.asInstanceOf[Framework] + val runner: Runner = framework.runner(args, Array.empty, Thread.currentThread.getContextClassLoader) + + val subclassFingerprintMap: Map[(String, Boolean, Boolean), SubclassFingerprint] = + framework.fingerprints.collect { + case fingerprint: SubclassFingerprint => + (fingerprint.superclassName, fingerprint.isModule, fingerprint.requireNoArgConstructor) -> fingerprint + }.toMap + + val annotatedFingerprintMap: Map[(String, Boolean), AnnotatedFingerprint] = + framework.fingerprints.collect { + case fingerprint: AnnotatedFingerprint => + (fingerprint.annotationName, fingerprint.isModule) -> fingerprint + }.toMap + + frameworkDiscovery.getSubclassDiscoveriesList.asScala + .foreach { subclassDiscovery => + val fingerprint = subclassFingerprintMap.get((subclassDiscovery.getSuperclassName, subclassDiscovery.getIsModule, subclassDiscovery.getRequireNoArgConstructor)) + .getOrElse(sys.error(s"Unable to resolve fingerprint instance for $subclassDiscovery")) + + handleTests(runner, fingerprint, subclassDiscovery.getTestsList.asScala.toList) + } + + frameworkDiscovery.getAnnotatedDiscoveriesList.asScala + .foreach { annotatedDiscovery => + val fingerprint = annotatedFingerprintMap.get((annotatedDiscovery.getAnnotationName, annotatedDiscovery.getIsModule)) + .getOrElse(sys.error(s"Unable to resolve fingerprint instance for $annotatedDiscovery")) + + handleTests(runner, fingerprint, annotatedDiscovery.getTestsList.asScala.toList) + } + + print(runner.done()) + print(s"\n< run of ${frameworkDiscovery.getFramework} complete\n") + } + + def handleTests(runner: Runner, fingerprint: Fingerprint, tests: List[String]): Unit = { + val eventHandler: EventHandler = new EventHandler { + def handle(event: Event): Unit = { + //println(s"- $event") + } + } + val loggers: Array[Logger] = Array(new Logger { + def ansiCodesSupported(): Boolean = true + def debug(msg: String): Unit = println(s"debug: $msg") + def error(msg: String): Unit = println(s"error: $msg") + def info(msg: String): Unit = println(s"info: $msg") + def trace(e: Throwable): Unit = e.printStackTrace + def warn(msg: String): Unit = println(s"warn: $msg") + }) + + @tailrec def execute(tasks: List[Task]): Unit = tasks match { + case head :: tail => + execute(head.execute(eventHandler, loggers) ++: tail) + case Nil => + () + } + + execute(runner + .tasks(tests.map(test => new TaskDef(test, fingerprint, true, Array(new SuiteSelector))).toArray).toList) + + } + +} diff --git a/src/scala/io/bazel/rules_scala/discover_tests_worker/BUILD b/src/scala/io/bazel/rules_scala/discover_tests_worker/BUILD new file mode 100644 index 000000000..3e509a775 --- /dev/null +++ b/src/scala/io/bazel/rules_scala/discover_tests_worker/BUILD @@ -0,0 +1,26 @@ +load("//scala:defs.bzl", "scala_binary") + +scala_binary( + name = "discover_tests_worker", + srcs = ["DiscoverTestsWorker.scala"], + main_class = "io.bazel.rules_scala.discover_tests_worker.DiscoverTestsWorker", + visibility = ["//visibility:public"], + deps = [ + ":discovered_tests_java_proto", + "//external:io_bazel_rules_scala/dependency/com_google_protobuf/protobuf_java", + "@io_bazel_rules_scala//src/java/io/bazel/rulesscala/worker", + "@io_bazel_rules_scala_classgraph//jar", + "@io_bazel_rules_scala_test_interface//jar", + ], +) + +proto_library( + name = "discovered_tests_proto", + srcs = ["discovered_tests.proto"], +) + +java_proto_library( + name = "discovered_tests_java_proto", + visibility = ["//visibility:public"], + deps = [":discovered_tests_proto"], +) diff --git a/src/scala/io/bazel/rules_scala/discover_tests_worker/DiscoverTestsWorker.scala b/src/scala/io/bazel/rules_scala/discover_tests_worker/DiscoverTestsWorker.scala new file mode 100644 index 000000000..5ef36d0b8 --- /dev/null +++ b/src/scala/io/bazel/rules_scala/discover_tests_worker/DiscoverTestsWorker.scala @@ -0,0 +1,149 @@ +package io.bazel.rules_scala.discover_tests_worker + +import io.bazel.rulesscala.worker.Worker +import io.bazel.rules_scala.discover_tests_worker.DiscoveredTests.Result +import io.bazel.rules_scala.discover_tests_worker.DiscoveredTests.FrameworkDiscovery +import io.bazel.rules_scala.discover_tests_worker.DiscoveredTests.AnnotatedDiscovery +import io.bazel.rules_scala.discover_tests_worker.DiscoveredTests.SubclassDiscovery + +import io.github.classgraph.ClassGraph +import io.github.classgraph.ClassInfo +import io.github.classgraph.ClassInfoList +import io.github.classgraph.ScanResult + +import sbt.testing.Framework +import sbt.testing.SubclassFingerprint +import sbt.testing.AnnotatedFingerprint + +import java.io.FileOutputStream +import java.net.URLClassLoader +import java.nio.file.Paths + +import scala.collection.JavaConverters._ + +/** + * DiscoverTestsWorker is responsible for scanning jars to indentify + * classes and modules that conform to the SBT testing interface. + * + * Identified tests are written to a protobuf output file so a separate + * test runner can handle test execution. + */ +object DiscoverTestsWorker extends Worker.Interface { + + def main(args: Array[String]): Unit = Worker.workerMain(args, DiscoverTestsWorker) + + def work(args: Array[String]): Unit = { + // argument format: + -- + + val outputFile = Paths.get(args(0)).toFile + val (args0, args1) = args.tail.span(_ != "--") + val testJars = args0.map(f => Paths.get(f).toUri.toURL) + val frameworkJars = args1.tail.map(f => Paths.get(f).toUri.toURL) + + // prep the scanner used to identify testing frameworks + val frameworkClassloader = new URLClassLoader(frameworkJars) + val frameworkScanResult: ScanResult = (new ClassGraph) + .overrideClassLoaders(frameworkClassloader) + .ignoreParentClassLoaders + .enableClassInfo.scan + + // prep the scanner used to find tests + // here we need the full classpath + val testScanResult: ScanResult = (new ClassGraph) + .overrideClassLoaders(new URLClassLoader(testJars ++ frameworkJars)) + .ignoreParentClassLoaders + .enableClassInfo + .enableMethodInfo + .enableAnnotationInfo + .scan + + val resultBuilder: Result.Builder = Result.newBuilder + + // start identifying frameworks and tests + frameworkScanResult + .getClassesImplementing("sbt.testing.Framework").asScala + .foreach(handleFramework(frameworkScanResult, testScanResult, resultBuilder, _)) + + val result: Result = resultBuilder.build + + testScanResult.close() + frameworkScanResult.close() + + val os = new FileOutputStream(outputFile) + result.writeTo(os) + os.close() + } + + private[this] def handleFramework(frameworkScanResult: ScanResult, testScanResult: ScanResult, builder: Result.Builder, framework: ClassInfo): Unit = { + val frameworkInstance = framework.loadClass.newInstance.asInstanceOf[Framework] + + val frameworkDiscoveryBuilder = FrameworkDiscovery.newBuilder.setFramework(framework.getName) + frameworkInstance.fingerprints.foreach { + case sf: SubclassFingerprint => handleSubclassFingerprint(frameworkScanResult, testScanResult, frameworkDiscoveryBuilder, sf) + case af: AnnotatedFingerprint => handleAnnotatedFingerprint(frameworkScanResult, testScanResult, frameworkDiscoveryBuilder, af) + } + builder.addFrameworkDiscoveries(frameworkDiscoveryBuilder.build) + } + + private[this] def handleSubclassFingerprint(frameworkScanResult: ScanResult, testScanResult: ScanResult, builder: FrameworkDiscovery.Builder, fingerprint: SubclassFingerprint): Unit = { + // + // with the ClassGraph API we need to identify tests differently if they're implementing + // an interface instead of a class + // + // this logic is captured as a function so we can call it a few times + val getCandidates: ScanResult => ClassInfoList = + if (frameworkScanResult.getClassInfo(fingerprint.superclassName).isInterface) + _.getClassesImplementing(fingerprint.superclassName) + else + _.getSubclasses(fingerprint.superclassName) + + val candidates: Iterable[ClassInfo] = + getCandidates(testScanResult) + .exclude(getCandidates(frameworkScanResult)) + .asScala + .filter(_.isStandardClass) + + val tests: Iterable[String] = + if (fingerprint.isModule) + candidates + .map(_.getName) + .filter(_.endsWith("$")).map(_.dropRight(1)) + else + candidates + .filter(_.getConstructorInfo.asScala.exists(_.getParameterInfo.isEmpty) == fingerprint.requireNoArgConstructor) + .map(_.getName) + .filterNot(_.endsWith("$")) + + builder.addSubclassDiscoveries( + SubclassDiscovery.newBuilder + .setSuperclassName(fingerprint.superclassName) + .setIsModule(fingerprint.isModule) + .setRequireNoArgConstructor(fingerprint.requireNoArgConstructor) + .addAllTests(tests.asJava) + .build) + } + + private[this] def handleAnnotatedFingerprint(frameworkScanResult: ScanResult, testScanResult: ScanResult, builder: FrameworkDiscovery.Builder, fingerprint: AnnotatedFingerprint): Unit = { + val candidates: Iterable[ClassInfo] = + testScanResult.getClassesWithAnnotation(fingerprint.annotationName) + .union(testScanResult.getClassesWithMethodAnnotation(fingerprint.annotationName)) + .asScala + + // note: "$" is part of Scala's JVM encoding for modules + val tests: Iterable[String] = + if (fingerprint.isModule) + candidates + .map(_.getName) + .filter(_.endsWith("$")).map(_.dropRight(1)) + else + candidates + .map(_.getName) + .filterNot(_.endsWith("$")) + + builder.addAnnotatedDiscoveries( + AnnotatedDiscovery.newBuilder + .setAnnotationName(fingerprint.annotationName) + .setIsModule(fingerprint.isModule) + .addAllTests(tests.asJava) + .build) + } +} diff --git a/src/scala/io/bazel/rules_scala/discover_tests_worker/discovered_tests.proto b/src/scala/io/bazel/rules_scala/discover_tests_worker/discovered_tests.proto new file mode 100644 index 000000000..1c00e2f34 --- /dev/null +++ b/src/scala/io/bazel/rules_scala/discover_tests_worker/discovered_tests.proto @@ -0,0 +1,38 @@ +syntax = "proto3"; + +package io.bazel.rules_scala.discover_tests_worker; + +option java_package = "io.bazel.rules_scala.discover_tests_worker"; + +/* + * The types here are used to list/describe tests conforming + * to the SBT testing interface: + * + * https://github.com/sbt/test-interface/tree/master/src/main/java/sbt/testing + * + * A "result" lists tests associated with any number of testing frameworks + * implementing the SBT testing interface. + */ + +message Result { + repeated FrameworkDiscovery frameworkDiscoveries = 1; +} + +message FrameworkDiscovery { + string framework = 1; + repeated SubclassDiscovery subclassDiscoveries = 2; + repeated AnnotatedDiscovery annotatedDiscoveries = 3; +} + +message SubclassDiscovery { + string superclassName = 1; + bool isModule = 2; + bool requireNoArgConstructor = 3; + repeated string tests = 4; +} + +message AnnotatedDiscovery { + string annotationName = 1; + bool isModule = 2; + repeated string tests = 3; +} \ No newline at end of file diff --git a/test/v2/BUILD b/test/v2/BUILD new file mode 100644 index 000000000..db8398ece --- /dev/null +++ b/test/v2/BUILD @@ -0,0 +1,38 @@ +load( + "//scala:defs.bzl", + "scala_binary", + "scala_library", + "scala_test", +) + +scala_binary( + name = "binary", + srcs = ["binary.scala"], + main_class = "test.v2.Binary", + deps = [":library"], +) + +scala_library( + name = "library", + srcs = ["library.scala"], + deps = [], +) + +scala_test( + name = "test", + srcs = ["test.scala"], + runtime_deps = [ + "@com_novocode_junit_interface", + # this would normally get exported by scalatest if it was set up as a proper dep + # instead of just a jar + "//external:io_bazel_rules_scala/dependency/scala/scala_xml", + # same, but for junit + "@io_bazel_rules_scala_org_hamcrest_hamcrest_core", + ], + deps = [ + ":library", + "//external:io_bazel_rules_scala/dependency/scalatest/scalatest", + "@io_bazel_rules_scala_junit_junit", + "@org_scalacheck_scalacheck", + ], +) diff --git a/test/v2/binary.scala b/test/v2/binary.scala new file mode 100644 index 000000000..56486bffa --- /dev/null +++ b/test/v2/binary.scala @@ -0,0 +1,7 @@ +package test.v2 + +object Binary { + def main(args: Array[String]): Unit = { + println(s"${Library.method1} ${Library.method2}") + } +} diff --git a/test/v2/library.scala b/test/v2/library.scala new file mode 100644 index 000000000..8048a67ca --- /dev/null +++ b/test/v2/library.scala @@ -0,0 +1,6 @@ +package test.v2 + +object Library { + def method1(): String = "hello" + def method2(): String = "world" +} diff --git a/test/v2/test.scala b/test/v2/test.scala new file mode 100644 index 000000000..9aca30054 --- /dev/null +++ b/test/v2/test.scala @@ -0,0 +1,40 @@ +package test.v2 + +import org.scalatest.FunSuite + +import org.scalacheck.Properties +import org.scalacheck.Prop._ + +import org.junit.Test +import org.junit.Assert.assertEquals + +final class TestSuiteClass extends FunSuite { + test("method1") { + assert(Library.method1 == "hello") + } + + test("method2") { + assert(Library.method2 == "world") + } +} + +object TestSuiteObject extends FunSuite { + test("not-supported") { + assert("hello" == "world") + } +} + +final class TestPropertiesClass extends Properties("TestPropertiesClass") { + property("1") = 1 ?= 1 +} + +object TestPropertiesObject extends Properties("TestPropertiesObject") { + property("2") = 2 ?= 2 +} + +final class JUnitTest { + @Test + def testFoo(): Unit = { + assertEquals("Test should pass", true, true) + } +}