Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions jmh/jmh.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,35 @@ load("//scala:scala.bzl", "scala_binary", "scala_library")
def jmh_repositories():
native.maven_jar(
name = "io_bazel_rules_scala_org_openjdk_jmh_jmh_core",
artifact = "org.openjdk.jmh:jmh-core:1.17.4",
sha1 = "126d989b196070a8b3653b5389e602a48fe6bb2f",
artifact = "org.openjdk.jmh:jmh-core:1.20",
sha1 = "5f9f9839bda2332e9acd06ce31ad94afa7d6d447",
)
native.bind(
name = 'io_bazel_rules_scala/dependency/jmh/jmh_core',
actual = '@io_bazel_rules_scala_org_openjdk_jmh_jmh_core//jar',
)
native.maven_jar(
name = "io_bazel_rules_scala_org_openjdk_jmh_jmh_generator_asm",
artifact = "org.openjdk.jmh:jmh-generator-asm:1.17.4",
sha1 = "c85c3d8cfa194872b260e89770d41e2084ce2cb6",
artifact = "org.openjdk.jmh:jmh-generator-asm:1.20",
sha1 = "3c43040e08ae68905657a375e669f11a7352f9db",
)
native.bind(
name = 'io_bazel_rules_scala/dependency/jmh/jmh_generator_asm',
actual = '@io_bazel_rules_scala_org_openjdk_jmh_jmh_generator_asm//jar',
)
native.maven_jar(
name = "io_bazel_rules_scala_org_openjdk_jmh_jmh_generator_reflection",
artifact = "org.openjdk.jmh:jmh-generator-reflection:1.17.4",
sha1 = "f75a7823c9fcf03feed6d74aa44ea61fc70a8439",
artifact = "org.openjdk.jmh:jmh-generator-reflection:1.20",
sha1 = "f2154437b42426a48d5dac0b3df59002f86aed26",
)
native.bind(
name = 'io_bazel_rules_scala/dependency/jmh/jmh_generator_reflection',
actual = '@io_bazel_rules_scala_org_openjdk_jmh_jmh_generator_reflection//jar',
)
native.maven_jar(
name = "io_bazel_rules_scala_org_ows2_asm_asm",
artifact = "org.ow2.asm:asm:5.0.4",
sha1 = "0da08b8cce7bbf903602a25a3a163ae252435795",
artifact = "org.ow2.asm:asm:6.1.1",
sha1 = "264754515362d92acd39e8d40395f6b8dee7bc08",
)
native.bind(
name = "io_bazel_rules_scala/dependency/jmh/org_ows2_asm_asm",
Expand Down Expand Up @@ -78,14 +78,15 @@ def _scala_generate_benchmark(ctx):
outputs = [ctx.outputs.src_jar, ctx.outputs.resource_jar],
inputs = [class_jar] + classpath,
executable = ctx.executable._generator,
arguments = [f.path for f in [class_jar, ctx.outputs.src_jar, ctx.outputs.resource_jar] + classpath],
arguments = [ctx.attr.generator_type] + [f.path for f in [class_jar, ctx.outputs.src_jar, ctx.outputs.resource_jar] + classpath],
progress_message = "Generating benchmark code for %s" % ctx.label,
)

scala_generate_benchmark = rule(
implementation = _scala_generate_benchmark,
attrs = {
"src": attr.label(allow_single_file=True, mandatory=True),
"generator_type": attr.string(default='reflection', mandatory=False),
"_generator": attr.label(executable=True, cfg="host", default=Label("//src/scala/io/bazel/rules_scala/jmh_support:benchmark_generator"))
},
outputs = {
Expand All @@ -98,6 +99,7 @@ def scala_benchmark_jmh(**kw):
name = kw["name"]
deps = kw.get("deps", [])
srcs = kw["srcs"]
generator_type = kw.get("generator_type", "reflection")
lib = "%s_generator" % name
scalacopts = kw.get("scalacopts", [])
main_class = kw.get("main_class", "org.openjdk.jmh.Main")
Expand All @@ -115,7 +117,7 @@ def scala_benchmark_jmh(**kw):
)

codegen = name + "_codegen"
scala_generate_benchmark(name=codegen, src=lib)
scala_generate_benchmark(name=codegen, src=lib, generator_type=generator_type)
compiled_lib = name + "_compiled_benchmark_lib"
scala_library(
name = compiled_lib,
Expand Down
93 changes: 71 additions & 22 deletions src/scala/io/bazel/rules_scala/jmh_support/BenchmarkGenerator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ import java.net.URLClassLoader

import scala.annotation.tailrec
import scala.collection.JavaConverters._

import org.openjdk.jmh.generators.core.{ BenchmarkGenerator => JMHGenerator, FileSystemDestination }
import org.openjdk.jmh.generators.core.{FileSystemDestination, GeneratorSource, BenchmarkGenerator => JMHGenerator}
import org.openjdk.jmh.generators.asm.ASMGeneratorSource
import org.openjdk.jmh.runner.{ Runner, RunnerException }
import org.openjdk.jmh.runner.options.{ Options, OptionsBuilder }

import org.openjdk.jmh.generators.reflection.RFGeneratorSource
import org.openjdk.jmh.runner.{Runner, RunnerException}
import org.openjdk.jmh.runner.options.{Options, OptionsBuilder}
import java.net.URI

import scala.collection.JavaConverters._
import java.nio.file.{Files, FileSystems, Path}
import java.nio.file.{FileSystems, Files, Path, Paths}

import io.bazel.rulesscala.jar.JarCreator

Expand All @@ -27,7 +27,14 @@ import io.bazel.rulesscala.jar.JarCreator
*/
object BenchmarkGenerator {

case class BenchmarkGeneratorArgs(
private sealed trait GeneratorType

private case object ReflectionGenerator extends GeneratorType

private case object AsmGenerator extends GeneratorType

private case class BenchmarkGeneratorArgs(
generatorType: GeneratorType,
inputJar: Path,
resultSourceJar: Path,
resultResourceJar: Path,
Expand All @@ -37,6 +44,7 @@ object BenchmarkGenerator {
def main(argv: Array[String]): Unit = {
val args = parseArgs(argv)
generateJmhBenchmark(
args.generatorType,
args.resultSourceJar,
args.resultResourceJar,
args.inputJar,
Expand All @@ -47,17 +55,18 @@ object BenchmarkGenerator {
private def parseArgs(argv: Array[String]): BenchmarkGeneratorArgs = {
if (argv.length < 3) {
System.err.println(
"Usage: BenchmarkGenerator INPUT_JAR RESULT_JAR RESOURCE_JAR [CLASSPATH_ELEMENT] [CLASSPATH_ELEMENT...]"
"Usage: BenchmarkGenerator GENERATOR_TYPE INPUT_JAR RESULT_JAR RESOURCE_JAR [CLASSPATH_ELEMENT] [CLASSPATH_ELEMENT...]"
)
System.exit(1)
}
val fs = FileSystems.getDefault

BenchmarkGeneratorArgs(
fs.getPath(argv(0)),
if ("asm".equalsIgnoreCase(argv(0))) AsmGenerator else ReflectionGenerator,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we error if it is not asm or reflection?

fs.getPath(argv(1)),
fs.getPath(argv(2)),
argv.slice(3, argv.length).map { s => fs.getPath(s) }.toList
fs.getPath(argv(3)),
argv.slice(4, argv.length).map { s => fs.getPath(s) }.toList
)
}

Expand Down Expand Up @@ -88,13 +97,13 @@ object BenchmarkGenerator {
}

// Courtesy of Doug Tangren (https://groups.google.com/forum/#!topic/simple-build-tool/CYeLHcJjHyA)
private def withClassLoader[A](cp: Seq[Path])(f: => A): A = {
private def withClassLoader[A](cp: Seq[Path])(f: ClassLoader => A): A = {
val originalLoader = Thread.currentThread.getContextClassLoader
val jmhLoader = classOf[JMHGenerator].getClassLoader
val classLoader = new URLClassLoader(cp.map(_.toUri.toURL).toArray, jmhLoader)
try {
Thread.currentThread.setContextClassLoader(classLoader)
f
f(classLoader)
} finally {
Thread.currentThread.setContextClassLoader(originalLoader)
}
Expand All @@ -119,6 +128,7 @@ object BenchmarkGenerator {
}

private def generateJmhBenchmark(
generatorType: GeneratorType,
sourceJarOut: Path,
resourceJarOut: Path,
benchmarkJarPath: Path,
Expand All @@ -131,17 +141,33 @@ object BenchmarkGenerator {
tmpResourceDir.toFile.mkdir()
tmpSourceDir.toFile.mkdir()

withClassLoader(benchmarkJarPath :: classpath) {
val source = new ASMGeneratorSource
val destination = new FileSystemDestination(tmpResourceDir.toFile, tmpSourceDir.toFile)
val generator = new JMHGenerator

collectClassesFromJar(benchmarkJarPath).foreach { path =>
// this would fail due to https://github.com/bazelbuild/rules_scala/issues/295
// let's throw a useful message instead
sys.error("jmh in rules_scala doesn't work with Java 8 bytecode: https://github.com/bazelbuild/rules_scala/issues/295")
source.processClass(Files.newInputStream(path))
withClassLoader(benchmarkJarPath :: classpath) { isolatedClassLoader =>

val source: GeneratorSource = generatorType match {
case AsmGenerator =>
val generatorSource = new ASMGeneratorSource
try {
generatorSource.processClasses(collectClassesFromJar(benchmarkJarPath).map(_.toFile).asJavaCollection)
} catch {
case _: ArrayIndexOutOfBoundsException =>
// this would fail due to https://github.com/bazelbuild/rules_scala/issues/295
// let's throw a useful message instead
sys.error("jmh in rules_scala doesn't work with Java 8 bytecode: https://github.com/bazelbuild/rules_scala/issues/295")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we change the error to say asm mode does not work with java 8 bytecode and instead suggest using reflection?

}
generatorSource

case ReflectionGenerator =>
val generatorSource = new RFGeneratorSource
generatorSource.processClasses(
collectClassesFromJar(benchmarkJarPath)
.flatMap(classByPath(_, isolatedClassLoader))
.asJavaCollection
)
generatorSource
}

val generator = new JMHGenerator
val destination = new FileSystemDestination(tmpResourceDir.toFile, tmpSourceDir.toFile)
generator.generate(source, destination)
generator.complete(source, destination)
if (destination.hasErrors) {
Expand All @@ -156,6 +182,29 @@ object BenchmarkGenerator {
}
}

private def classByPath(path: Path, cl: ClassLoader): Option[Class[_]] = {
val separator = path.getFileSystem.getSeparator
var s = path.toString
.stripPrefix(separator)
.stripSuffix(".class")
.replace(separator, ".")

var index = -1
do {
s = s.substring(index + 1)
try {
return Some(Class.forName(s, false, cl))
} catch {
case _: ClassNotFoundException =>
// ignore and try next one
index = s.indexOf('.')
}
} while (index != -1)

log(s"Failed to find class for path $path")
None
}

private def log(str: String): Unit = {
System.err.println(s"JMH benchmark generation: $str")
}
Expand Down
11 changes: 5 additions & 6 deletions test/jmh/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ scala_library(
visibility = ["//visibility:public"],
)

# Disable the jmh test due to https://github.com/bazelbuild/rules_scala/issues/295
# scala_benchmark_jmh(
# name = "test_benchmark",
# srcs = ["TestBenchmark.scala"],
# deps = [":add_numbers"],
# )
scala_benchmark_jmh(
name = "test_benchmark",
srcs = ["TestBenchmark.scala"],
deps = [":add_numbers"],
)