diff --git a/WORKSPACE b/WORKSPACE index e76d66e59..767fb61b8 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -3,8 +3,11 @@ workspace(name = "io_bazel_rules_scala") load("//scala:scala.bzl", "scala_repositories", "scala_mvn_artifact") scala_repositories() +load("//twitter_scrooge:twitter_scrooge.bzl", "twitter_scrooge", "scrooge_scala_library") +twitter_scrooge() + # test adding a scala jar: maven_jar( name = "com_twitter__scalding_date", artifact = scala_mvn_artifact("com.twitter:scalding-date:0.16.0-RC4") -) +) \ No newline at end of file diff --git a/scala/scala.bzl b/scala/scala.bzl index 45516972f..724c044b8 100644 --- a/scala/scala.bzl +++ b/scala/scala.bzl @@ -66,7 +66,8 @@ touch -t 198001010000 {manifest} progress_message="scala %s" % ctx.label, arguments=[]) -def _compile(ctx, jars, dep_srcjars, buildijar): +def _compile(ctx, _jars, dep_srcjars, buildijar): + jars = _jars res_cmd = _add_resources_cmd(ctx) ijar_cmd = "" if buildijar: @@ -75,14 +76,9 @@ def _compile(ctx, jars, dep_srcjars, buildijar): out=ctx.outputs.jar.path, ijar_out=ctx.outputs.ijar.path) - sources = [] - srcjars = [] - for f in ctx.files.srcs: - #TODO this is gross but we aren't given a good "filterNot" - if len(_srcjar_filetype.filter([f])) == 0: - sources.append(f) - else: - srcjars.append(f) + sources = _scala_filetype.filter(ctx.files.srcs) + srcjars = _srcjar_filetype.filter(ctx.files.srcs) + all_srcjars = set(srcjars + list(dep_srcjars)) # Set up the args to pass to scalac because they can be too long for bash scalac_args_file = ctx.new_file(ctx.outputs.jar, ctx.outputs.jar.short_path + "scalac_args") @@ -95,29 +91,22 @@ def _compile(ctx, jars, dep_srcjars, buildijar): ) ctx.file_action(output = scalac_args_file, content = scalac_args) - all_srcjars = srcjars + list(dep_srcjars) srcjar_cmd = "" if len(all_srcjars) > 0: + srcjar_cmd = "\nmkdir -p {out}_tmp_expand_srcjars\n" for srcjar in all_srcjars: # Note: this is double escaped because we need to do one format call # per each srcjar, but then we are going to include this in the bigger format # call that is done to generate the full command - # Note: unzip has -o set (overriding files), and all of the files are unzipped into the same directory. - # I feel this is ok because everything should be from the same source tree, so it should have to be consistent - # (the whole point of bazel is resolving diamonds). That said, a good TODO might be to ensure that - # if there are duplicate files, they are identical (and error otherwise) - #TODO would like to be able to switch >/dev/null, -v, etc based on the user's settings srcjar_cmd += """ -rm -rf {{out}}_tmp_expand_srcjars -mkdir -p {{out}}_tmp_expand_srcjars unzip -o {srcjar} -d {{out}}_tmp_expand_srcjars >/dev/null -echo " " >> {{out}}_args/files_from_jar -find {{out}}_tmp_expand_srcjars -type f -name "*.scala" >> {{out}}_args/files_from_jar """.format(srcjar = srcjar.path) + srcjar_cmd += """find {out}_tmp_expand_srcjars -type f -name "*.scala" > {out}_args/files_from_jar\n""" cmd = """ +rm -rf {out}_tmp_expand_srcjars rm -rf {out}_tmp set -e rm -rf {out}_args @@ -148,6 +137,8 @@ rm -rf {out}_tmp ctx.action( inputs=list(jars) + list(dep_srcjars) + + list(srcjars) + + list(sources) + ctx.files.srcs + ctx.files.resources + ctx.files._jdk + @@ -185,15 +176,16 @@ def write_manifest(ctx): content = manifest) def _write_launcher(ctx, jars): + classpath = ':'.join(["$0.runfiles/" + f.short_path for f in jars]) content = """#!/bin/bash -cd $0.runfiles -{java} -cp {cp} {name} "$@" -""" - content = content.format( - java=ctx.file._java.path, - name=ctx.attr.main_class, - deploy_jar=ctx.outputs.jar.path, - cp=":".join([j.short_path for j in jars])) +export CLASSPATH={classpath} +$0.runfiles/{java} {name} "$@" +""".format( + java=ctx.file._java.path, + name=ctx.attr.main_class, + deploy_jar=ctx.outputs.jar.path, + classpath=classpath, + ) ctx.file_action( output=ctx.outputs.executable, content=content) @@ -214,11 +206,11 @@ def _write_test_launcher(ctx, jars): output=ctx.outputs.executable, content=content) -def _collect_srcjars(targets): +def collect_srcjars(targets): srcjars = set() for target in targets: - if hasattr(target, "srcjar"): - srcjars += [target.srcjar] + if hasattr(target, "srcjars"): + srcjars += [target.srcjars.srcjar] return srcjars def _collect_jars(targets): @@ -251,7 +243,7 @@ def _collect_jars(targets): def _lib(ctx, non_macro_lib): # This will be used to pick up srcjars from non-scala library # targets (like thrift code generation) - srcjars = _collect_srcjars(ctx.attr.deps) + srcjars = collect_srcjars(ctx.attr.deps) jars = _collect_jars(ctx.attr.deps) (cjars, rjars) = (jars.compiletime, jars.runtime) write_manifest(ctx) @@ -275,7 +267,25 @@ def _lib(ctx, non_macro_lib): collect_data = True) return struct( scala = scalaattr, - runfiles=runfiles) + runfiles=runfiles, + # This is a free monoid given to the graph for the purpose of + # extensibility. This is necessary when one wants to create + # new targets which want to leverage a scala_library. For example, + # new_target1 -> scala_library -> new_target2. There might be + # information that new_target2 needs to get from new_target1, + # but we do not want to ohave to change scala_library to pass + # this information through. extra_information allows passing + # this information through, and it is up to the new_targets + # to filter and make sense of this information. + extra_information=_collect_extra_information(ctx.attr.deps), + ) + +def _collect_extra_information(targets): + r = [] + for target in targets: + if hasattr(target, 'extra_information'): + r.extend(target.extra_information) + return r def _scala_library_impl(ctx): return _lib(ctx, True) @@ -315,18 +325,17 @@ def _scala_test_impl(ctx): _write_test_launcher(ctx, rjars) return _scala_binary_common(ctx, cjars, rjars) -def implicit_deps(): - return { - "_ijar": attr.label(executable=True, default=Label("@bazel_tools//tools/jdk:ijar"), single_file=True, allow_files=True), - "_scalac": attr.label(executable=True, default=Label("@scala//:bin/scalac"), single_file=True, allow_files=True), - "_scalalib": attr.label(default=Label("@scala//:lib/scala-library.jar"), single_file=True, allow_files=True), - "_scalaxml": attr.label(default=Label("@scala//:lib/scala-xml_2.11-1.0.4.jar"), single_file=True, allow_files=True), - "_scalasdk": attr.label(default=Label("@scala//:sdk"), allow_files=True), - "_scalareflect": attr.label(default=Label("@scala//:lib/scala-reflect.jar"), single_file=True, allow_files=True), - "_java": attr.label(executable=True, default=Label("@bazel_tools//tools/jdk:java"), single_file=True, allow_files=True), - "_jar": attr.label(executable=True, default=Label("@bazel_tools//tools/jdk:jar"), single_file=True, allow_files=True), - "_jdk": attr.label(default=Label("//tools/defaults:jdk"), allow_files=True), - } +_implicit_deps = { + "_ijar": attr.label(executable=True, default=Label("@bazel_tools//tools/jdk:ijar"), single_file=True, allow_files=True), + "_scalac": attr.label(executable=True, default=Label("@scala//:bin/scalac"), single_file=True, allow_files=True), + "_scalalib": attr.label(default=Label("@scala//:lib/scala-library.jar"), single_file=True, allow_files=True), + "_scalaxml": attr.label(default=Label("@scala//:lib/scala-xml_2.11-1.0.4.jar"), single_file=True, allow_files=True), + "_scalasdk": attr.label(default=Label("@scala//:sdk"), allow_files=True), + "_scalareflect": attr.label(default=Label("@scala//:lib/scala-reflect.jar"), single_file=True, allow_files=True), + "_java": attr.label(executable=True, default=Label("@bazel_tools//tools/jdk:java"), single_file=True, allow_files=True), + "_jar": attr.label(executable=True, default=Label("@bazel_tools//tools/jdk:jar"), single_file=True, allow_files=True), + "_jdk": attr.label(default=Label("//tools/defaults:jdk"), allow_files=True), +} # Common attributes reused across multiple rules. _common_attrs = { @@ -345,7 +354,7 @@ scala_library = rule( attrs={ "main_class": attr.string(), "exports": attr.label_list(allow_files=False), - } + implicit_deps() + _common_attrs, + } + _implicit_deps + _common_attrs, outputs={ "jar": "%{name}_deploy.jar", "ijar": "%{name}_ijar.jar", @@ -358,7 +367,7 @@ scala_macro_library = rule( attrs={ "main_class": attr.string(), "exports": attr.label_list(allow_files=False), - } + implicit_deps() + _common_attrs, + } + _implicit_deps + _common_attrs, outputs={ "jar": "%{name}_deploy.jar", "manifest": "%{name}_MANIFEST.MF", @@ -369,7 +378,7 @@ scala_binary = rule( implementation=_scala_binary_impl, attrs={ "main_class": attr.string(mandatory=True), - } + implicit_deps() + _common_attrs, + } + _implicit_deps + _common_attrs, outputs={ "jar": "%{name}_deploy.jar", "manifest": "%{name}_MANIFEST.MF", @@ -384,7 +393,7 @@ scala_test = rule( "suites": attr.string_list(), "_scalatest": attr.label(executable=True, default=Label("@scalatest//file"), single_file=True, allow_files=True), "_scalatest_reporter": attr.label(default=Label("//scala/support:test_reporter")), - } + implicit_deps() + _common_attrs, + } + _implicit_deps + _common_attrs, outputs={ "jar": "%{name}_deploy.jar", "manifest": "%{name}_MANIFEST.MF", diff --git a/src/scala/scripts/BUILD b/src/scala/scripts/BUILD new file mode 100644 index 000000000..d064db9fd --- /dev/null +++ b/src/scala/scripts/BUILD @@ -0,0 +1,19 @@ +load("//scala:scala.bzl", "scala_binary") + +scala_binary( + name = "generator", + srcs = ["TwitterScroogeGenerator.scala"], + main_class = "scripts.ScroogeGenerator", + deps = [ + "@scrooge_generator//jar", + "@util_core//jar", + "@util_logging//jar", + ":scala_parsers", + ], + visibility = ["//visibility:public"], +) + +java_import( + name = "scala_parsers", + jars = ["@scala//:lib/scala-parser-combinators_2.11-1.0.4.jar"], +) \ No newline at end of file diff --git a/src/scala/scripts/TwitterScroogeGenerator.scala b/src/scala/scripts/TwitterScroogeGenerator.scala new file mode 100644 index 000000000..060847550 --- /dev/null +++ b/src/scala/scripts/TwitterScroogeGenerator.scala @@ -0,0 +1,171 @@ +package scripts + +import com.twitter.scrooge.Compiler + +import scala.collection.mutable.Buffer +import scala.io.Source + +import java.io.{ File, FileOutputStream, IOException } +import java.nio.file.{ Files, SimpleFileVisitor, FileVisitResult, Path, Paths } +import java.nio.file.attribute.{ BasicFileAttributes, FileTime } +import java.util.jar.{ JarEntry, JarFile, JarOutputStream } + +object FinalJarCreator { + val gm = """(\S+) -> (\S+)""".r + + def apply(dest: Path, owned: Set[Path], genFileMap: Path, scroogeDir: Path) { + val genmap = Source.fromFile(genFileMap.toString) + .getLines + .foldLeft(Map.empty[String, Set[String]]) { case (m, gm(thrift, gen)) => + m.+((thrift, m.getOrElse(thrift, Set.empty[String]) + gen)) + } + val shouldMove = + owned.map(_.toString).foldLeft(Set.empty[String]) { (s, n) => + genmap.get(n).fold(s) { s ++ _ } + }.map { Paths.get(_).normalize } + val jar = new JarOutputStream(new FileOutputStream(dest.toFile)) + Files.walkFileTree( + scroogeDir, + FinalJarCreator(scroogeDir, jar, shouldMove) + ) + jar.close() + } +} +case class FinalJarCreator(_baseDir: Path, jar: JarOutputStream, shouldMove: Set[Path]) extends SimpleFileVisitor[Path] { + val baseDir = _baseDir.normalize + + // We return the path of the file to add to the jar + def shouldVisitFile(file: Path): Option[Path] = + if (shouldMove.contains(file)) Some(baseDir.relativize(file)) + else None + + override def visitFile(file: Path, attr: BasicFileAttributes) = { + shouldVisitFile(file).foreach { _file => + val entry = new JarEntry(_file.toString) + entry.setTime(198001010000L) + jar.putNextEntry(entry) + Files.copy(file, jar) + } + FileVisitResult.CONTINUE + } +} + +object DeleteRecursively extends SimpleFileVisitor[Path] { + override def visitFile(file: Path, attr: BasicFileAttributes) = { + Files.delete(file) + FileVisitResult.CONTINUE + } + + override def postVisitDirectory(dir: Path, e: IOException) = { + if (e != null) throw e + Files.delete(dir) + FileVisitResult.CONTINUE + } +} + +case class ForeachFile(f: Path => Unit) extends SimpleFileVisitor[Path] { + override def visitFile(file: Path, attr: BasicFileAttributes) = { + f(file) + FileVisitResult.CONTINUE + } +} + +object ScroogeGenerator { + def deleteDir(path: Path) { + try { + Files.walkFileTree(path, DeleteRecursively) + } catch { + case e: Exception => + } + } + + def extractJarTo(_jar: Path, _dest: Path): List[Path] = { + val files = Buffer[Path]() + val jar = new JarFile(_jar.toFile) + val enumEntries = jar.entries() + while (enumEntries.hasMoreElements) { + val file = enumEntries.nextElement().asInstanceOf[JarEntry] + val path = _dest.resolve(file.getName) + if (file.isDirectory) Files.createDirectories(path) + else { + val is = jar.getInputStream(file) + + try Files.copy(is, path) // Will error out if path already exists + finally is.close() + + files += path + } + } + files.toList + } + + def readLinesAsPaths(path: Path): Set[Path] = + Source.fromFile(path.toString).getLines.map(Paths.get(_)).toSet + + def main(args: Array[String]) { + if (args.length != 4) sys.error("Need to ensure enough arguments! " + + "Required 4 arguments: onlyTransitiveThriftSrcs immediateThriftSrcs " + + "jarOutput remoteJarsFile. Received: " + args) + + val onlyTransitiveThriftSrcsFile = Paths.get(args(0)) + val immediateThriftSrcsFile = Paths.get(args(1)) + val jarOutput = Paths.get(args(2)) + val remoteJarsFile = Paths.get(args(3)) + + val tmp = Paths.get(Option(System.getenv("TMPDIR")).getOrElse("/tmp")) + val scroogeOutput = Files.createTempDirectory(tmp, "scrooge") + + // These are all of the files to include when generating scrooge + // Should not include anything in immediateThriftSrcs + val onlyTransitiveThriftSrcJars = readLinesAsPaths(onlyTransitiveThriftSrcsFile) + + // These are the files whose output we want + val immediateThriftSrcJars = readLinesAsPaths(immediateThriftSrcsFile) + + val genFileMap = scroogeOutput.resolve("gen-file-map.txt") + + val scrooge = new Compiler + + // we need to extract into the same tree, as that is the only way to get relative imports between them working.. + // AS SUCH, we are just going to try extracting EVERYTHING to the same tree, and we will just error if there + // are more than one. + val _tmp = Files.createTempDirectory(tmp, "jar") + // This will only be meaningful if they have absolute_prefix set + scrooge.includePaths += _tmp.toString + + def extract(jars: Set[Path]): Set[Path] = + jars.flatMap { jar => + val files = extractJarTo(jar, _tmp) + files.foreach { scrooge.includePaths += _.toString } + files + } + + val immediateThriftSrcs = extract(immediateThriftSrcJars) + + immediateThriftSrcs.foreach { scrooge.thriftFiles += _.toString } + + val onlyTransitiveThriftSrcs = extract(onlyTransitiveThriftSrcJars) + + val intersect = onlyTransitiveThriftSrcs.intersect(immediateThriftSrcs) + + if (intersect.nonEmpty) + sys.error("onlyTransitiveThriftSrcs and immediateThriftSrcs should " + + s"have not intersection, found: ${intersect.mkString(",")}") + + val remoteSrcJars = readLinesAsPaths(remoteJarsFile) + extract(remoteSrcJars) + + val dirsToDelete = Set(scroogeOutput, _tmp) + + scrooge.destFolder = scroogeOutput.toString + scrooge.fileMapPath = Some(genFileMap.toString) + //TODO we should make this configurable + scrooge.strict = false + scrooge.run() + + FinalJarCreator(jarOutput, immediateThriftSrcs, genFileMap, scroogeOutput) + + // Clean it out to be idempotent + dirsToDelete.foreach { deleteDir(_) } + } +} \ No newline at end of file diff --git a/test/src/main/scala/scala/test/srcjars/BUILD b/test/src/main/scala/scala/test/srcjars/BUILD index dca98f4d4..9d9129d51 100644 --- a/test/src/main/scala/scala/test/srcjars/BUILD +++ b/test/src/main/scala/scala/test/srcjars/BUILD @@ -7,7 +7,7 @@ load("//scala:scala.bzl", "scala_library") scala_library( name = "source_jar", # SourceJar1.jar was created by: - # jar -cfM test/src/main/scala/scala/test/srcjars/SourceJar1.sources.jar \ + # jar -cfM test/src/main/scala/scala/test/srcjars/SourceJar1.srcjar \ # test/src/main/scala/scala/test/srcjars/SourceJar1.scala srcs = ["SourceJar1.srcjar"], ) diff --git a/test/src/main/scala/scala/test/twitter_scrooge/BUILD b/test/src/main/scala/scala/test/twitter_scrooge/BUILD new file mode 100644 index 000000000..cd79723ee --- /dev/null +++ b/test/src/main/scala/scala/test/twitter_scrooge/BUILD @@ -0,0 +1,123 @@ +load("//scala:scala.bzl", "scala_binary", "scala_library") +load("//twitter_scrooge:twitter_scrooge.bzl", "scrooge_scala_library") + +scrooge_scala_library( + name = "scrooge1", + deps = [ + "//test/src/main/scala/scala/test/twitter_scrooge/thrift", + ":scrooge2_a", + ":scrooge2_b", + ":scrooge3", + ], + visibility = ["//visibility:public"], +) + +scrooge_scala_library( + name = "scrooge2_a", + deps = [ + "//test/src/main/scala/scala/test/twitter_scrooge/thrift/thrift2:thrift2_a", + ":scrooge3", + ], + visibility = ["//visibility:public"], +) + +scrooge_scala_library( + name = "scrooge2_b", + deps = [ + "//test/src/main/scala/scala/test/twitter_scrooge/thrift/thrift2:thrift2_b", + ":scrooge3", + ], + visibility = ["//visibility:public"], +) + +scrooge_scala_library( + name = "scrooge3", + deps = ["//test/src/main/scala/scala/test/twitter_scrooge/thrift/thrift2/thrift3"], + visibility = ["//visibility:public"], +) + +scrooge_scala_library( + name = "scrooge2", + deps = [ + "//test/src/main/scala/scala/test/twitter_scrooge/thrift/thrift2:thrift2_a", + "//test/src/main/scala/scala/test/twitter_scrooge/thrift/thrift2:thrift2_b", + ":scrooge3", + ], + visibility = ["//visibility:public"], +) + +scala_library( + name = "justscrooge1", + srcs = ["JustScrooge1.scala"], + deps = [":scrooge1"], +) + +scala_library( + name = "justscrooge2a", + srcs = ["JustScrooge2a.scala"], + deps = [":scrooge2_a"], +) + +scala_library( + name = "justscrooge2b", + srcs = ["JustScrooge2b.scala"], + deps = [":scrooge2_b"], +) + +scala_library( + name = "justscrooge3", + srcs = ["JustScrooge3.scala"], + deps = [":scrooge3"], +) + +scala_library( + name = "scrooge2_both", + srcs = ["Scrooge2.scala"], + deps = [":scrooge2"], +) + +scala_library( + name = "mixed", + srcs = ["Mixed.scala"], + deps = [ + ":justscrooge1", + ":justscrooge2a", + ":justscrooge2b", + ":justscrooge3", + ], +) + +scala_library( + name = "twodeep", + srcs = ["Twodeep.scala"], + deps = [":justscrooge3"], +) + +scala_binary( + name = "twodeep_binary", + deps = [":twodeep"], + main_class = "scala.test.twitter_scrooge.Twodeep", +) + +scala_binary( + name = "justscrooge2b_binary", + deps = [":justscrooge2b"], + main_class = "scala.test.twitter_scrooge.JustScrooge2b" +) + +scala_library( + name = "allscrooges", + deps = [ + ":scrooge1", + ":scrooge2_a", + ":scrooge2_b", + ":scrooge3", + ], +) + +scala_binary( + name = "justscrooges", + srcs = ["JustScrooge1.scala"], + deps = [":allscrooges"], + main_class = "scala.test.twitter_scrooge.JustScrooge1", +) \ No newline at end of file diff --git a/test/src/main/scala/scala/test/twitter_scrooge/JustScrooge1.scala b/test/src/main/scala/scala/test/twitter_scrooge/JustScrooge1.scala new file mode 100644 index 000000000..dcc110c5d --- /dev/null +++ b/test/src/main/scala/scala/test/twitter_scrooge/JustScrooge1.scala @@ -0,0 +1,14 @@ +package scala.test.twitter_scrooge + +import scala.test.twitter_scrooge.thrift.Struct1 +import scala.test.twitter_scrooge.thrift.thrift2.Struct2A +import scala.test.twitter_scrooge.thrift.thrift2.Struct2B +import scala.test.twitter_scrooge.thrift.thrift2.thrift3.Struct3 + +object JustScrooge1 { + val classes = Seq(classOf[Struct1], classOf[Struct2A], classOf[Struct2B], classOf[Struct3]) + + def main(args: Array[String]) { + print(s"classes ${classes.mkString(",")}") + } +} \ No newline at end of file diff --git a/test/src/main/scala/scala/test/twitter_scrooge/JustScrooge2a.scala b/test/src/main/scala/scala/test/twitter_scrooge/JustScrooge2a.scala new file mode 100644 index 000000000..14acfcf2a --- /dev/null +++ b/test/src/main/scala/scala/test/twitter_scrooge/JustScrooge2a.scala @@ -0,0 +1,8 @@ +package scala.test.twitter_scrooge + +import scala.test.twitter_scrooge.thrift.thrift2.Struct2A +import scala.test.twitter_scrooge.thrift.thrift2.thrift3.Struct3 + +object JustScrooge2a { + val classes = Seq(classOf[Struct2A], classOf[Struct3]) +} \ No newline at end of file diff --git a/test/src/main/scala/scala/test/twitter_scrooge/JustScrooge2b.scala b/test/src/main/scala/scala/test/twitter_scrooge/JustScrooge2b.scala new file mode 100644 index 000000000..ddbe3b477 --- /dev/null +++ b/test/src/main/scala/scala/test/twitter_scrooge/JustScrooge2b.scala @@ -0,0 +1,12 @@ +package scala.test.twitter_scrooge + +import scala.test.twitter_scrooge.thrift.thrift2.Struct2B +import scala.test.twitter_scrooge.thrift.thrift2.thrift3.Struct3 + +object JustScrooge2b { + val classes = Seq(classOf[Struct2B], classOf[Struct3]) + + def main(args: Array[String]) { + classes foreach println + } +} \ No newline at end of file diff --git a/test/src/main/scala/scala/test/twitter_scrooge/JustScrooge3.scala b/test/src/main/scala/scala/test/twitter_scrooge/JustScrooge3.scala new file mode 100644 index 000000000..654850079 --- /dev/null +++ b/test/src/main/scala/scala/test/twitter_scrooge/JustScrooge3.scala @@ -0,0 +1,7 @@ +package scala.test.twitter_scrooge + +import scala.test.twitter_scrooge.thrift.thrift2.thrift3.Struct3 + +object JustScrooge3 { + val classes = Seq(classOf[Struct3]) +} \ No newline at end of file diff --git a/test/src/main/scala/scala/test/twitter_scrooge/Mixed.scala b/test/src/main/scala/scala/test/twitter_scrooge/Mixed.scala new file mode 100644 index 000000000..337eb5f0a --- /dev/null +++ b/test/src/main/scala/scala/test/twitter_scrooge/Mixed.scala @@ -0,0 +1,20 @@ +package scala.test.twitter_scrooge + +import scala.test.twitter_scrooge.thrift.Struct1 +import scala.test.twitter_scrooge.thrift.thrift2.Struct2A +import scala.test.twitter_scrooge.thrift.thrift2.Struct2B +import scala.test.twitter_scrooge.thrift.thrift2.thrift3.Struct3 + +object Mixed { + val classes = + Seq( + classOf[Struct1], + classOf[Struct2A], + classOf[Struct2B], + classOf[Struct3], + JustScrooge1.getClass, + JustScrooge2a.getClass, + JustScrooge2b.getClass, + JustScrooge3.getClass + ) +} diff --git a/test/src/main/scala/scala/test/twitter_scrooge/Scrooge2.scala b/test/src/main/scala/scala/test/twitter_scrooge/Scrooge2.scala new file mode 100644 index 000000000..a83483966 --- /dev/null +++ b/test/src/main/scala/scala/test/twitter_scrooge/Scrooge2.scala @@ -0,0 +1,8 @@ +package scala.test.twitter_scrooge + +import scala.test.twitter_scrooge.thrift.thrift2.{Struct2A, Struct2B} +import scala.test.twitter_scrooge.thrift.thrift2.thrift3.Struct3 + +object Scrooge2 { + val classes = Seq(classOf[Struct2A], classOf[Struct2B], classOf[Struct3]) +} \ No newline at end of file diff --git a/test/src/main/scala/scala/test/twitter_scrooge/Twodeep.scala b/test/src/main/scala/scala/test/twitter_scrooge/Twodeep.scala new file mode 100644 index 000000000..64f512fd5 --- /dev/null +++ b/test/src/main/scala/scala/test/twitter_scrooge/Twodeep.scala @@ -0,0 +1,15 @@ +package scala.test.twitter_scrooge + +import scala.test.twitter_scrooge.thrift.thrift2.thrift3.Struct3 + +object Twodeep { + val classes = + Seq( + classOf[Struct3], + JustScrooge3.getClass + ) + + def main(args: Array[String]) { + classes foreach println + } +} diff --git a/test/src/main/scala/scala/test/twitter_scrooge/thrift/thrift2/thrift3/Thrift3.thrift b/test/src/main/scala/scala/test/twitter_scrooge/thrift/thrift2/thrift3/Thrift3.thrift index 6f23891cd..d5bb12b2f 100644 --- a/test/src/main/scala/scala/test/twitter_scrooge/thrift/thrift2/thrift3/Thrift3.thrift +++ b/test/src/main/scala/scala/test/twitter_scrooge/thrift/thrift2/thrift3/Thrift3.thrift @@ -2,4 +2,8 @@ namespace java scala.test.twitter_scrooge.thrift.thrift2.thrift3 struct Struct3 { 1: string msg +} + +struct Struct3Extra { + 1: string msg } \ No newline at end of file diff --git a/test_run.sh b/test_run.sh index 32b4f57f0..c722f03e2 100755 --- a/test_run.sh +++ b/test_run.sh @@ -22,6 +22,7 @@ bazel build test/... \ && bazel run test:ScalaLibBinary \ && bazel run test:JavaBinary \ && bazel run test:JavaBinary2 \ + && bazel run test/src/main/scala/scala/test/twitter_scrooge:justscrooges \ && bazel test test/... \ && find -L ./bazel-testlogs -iname "*.xml" \ && (find -L ./bazel-testlogs -iname "*.xml" | xargs -n1 xmllint > /dev/null) \ diff --git a/thrift/thrift.bzl b/thrift/thrift.bzl index 4fb0fe283..e68f5d922 100644 --- a/thrift/thrift.bzl +++ b/thrift/thrift.bzl @@ -3,37 +3,40 @@ _thrift_filetype = FileType([".thrift"]) def _thrift_library_impl(ctx): + prefix = ctx.attr.absolute_prefix + jarcmd = "{jar} cMf {out} -C {out}_tmp ." + if prefix != '': + jarcmd = "{{jar}} cMf {{out}} -C {{out}}_tmp/{prefix} .".format(prefix=prefix) + _valid_thrift_deps(ctx.attr.deps) # We move the files and touch them so that the output file is a purely deterministic # product of the _content_ of the inputs - #TODO is rsync an acceptable dependency here? cmd = """ rm -rf {out}_tmp mkdir -p {out}_tmp {jar} cMf {out}_tmp/tmp.jar $@ -unzip -o {out}_tmp/tmp.jar -d {out}_tmp >/dev/null +unzip -o {out}_tmp/tmp.jar -d {out}_tmp 2>/dev/null rm -rf {out}_tmp/tmp.jar find {out}_tmp -exec touch -t 198001010000 {{}} \; -{jar} cMf {out} -C {out}_tmp . -rm -rf {out}_tmp""".format(out=ctx.outputs.libarchive.path, - jar=ctx.file._jar.path) +""" + jarcmd + """ +rm -rf {out}_tmp""" + + cmd = cmd.format(out=ctx.outputs.libarchive.path, + jar=ctx.file._jar.path) ctx.action( - inputs = ctx.files.srcs + ctx.files._jar + ctx.files._jdk, + inputs = ctx.files.srcs + ctx.files._jar, outputs = [ctx.outputs.libarchive], command = cmd, progress_message = "making thrift archive %s" % ctx.label, arguments = [f.path for f in ctx.files.srcs], ) - transitive_archives = _collect_thrift_tars(ctx.attr.deps) - transitive_archives += [ctx.outputs.libarchive] transitive_srcs = _collect_thrift_srcs(ctx.attr.deps) - transitive_srcs += ctx.files.srcs + transitive_srcs += [ctx.outputs.libarchive] return struct( - srcs = ctx.files.srcs, thrift = struct( + srcs = ctx.outputs.libarchive, transitive_srcs = transitive_srcs, - transitive_archives = transitive_archives, ), ) @@ -43,9 +46,6 @@ def _collect_thrift_attr(targets, attr): s += getattr(target.thrift, attr) return s -def _collect_thrift_tars(targets): - return _collect_thrift_attr(targets, "transitive_archives") - def _collect_thrift_srcs(targets): return _collect_thrift_attr(targets, "transitive_srcs") @@ -54,21 +54,35 @@ def _valid_thrift_deps(targets): if not hasattr(target, "thrift"): fail("thrift_library can only depend on thrift_library", target) -# Some notes on the raison d'etre of thrift_library vs. scrooge_scala_library. -# The idea is to be able to separate concerns -- thrift_library is concerned -# with the ownership of thrift files, and organizing them into targets. It is -# not concerned with how the process of converting thrifts into sources. Thus, -# the scrooge_scala_library is what handles the specifics of code generation... -# this is useful because it means that if there are different code generation -# targets, we don't need to have a whole separate tree of targets organizing -# the thrifts. +# Some notes on the raison d'etre of thrift_library vs. code gen specific +# targets. The idea is to be able to separate concerns -- thrift_library is +# concerned purely with the ownership and organization of thrift files. It +# is not concerned with what to do with them. Thus, the code gen specific +# targets will take the graph of thrift_libraries and use them to generate +# code. This organization is useful because it means that if there are +# different code generation targets, we don't need to have a whole separate +# tree of targets organizing the thrifts per code gen paradigm. thrift_library = rule( implementation = _thrift_library_impl, attrs = { "srcs": attr.label_list(allow_files=_thrift_filetype), "deps": attr.label_list(), + #TODO this is not necessarily the best way to do this... the goal + # is that we want thrifts to be able to be imported via an absolute + # path. But the thrift files have no clue what part of their path + # should serve as the base for the import... for example, if a file is + # in src/main/thrift/com/hello/World.thrift, if something depends on that + # via "include 'com/hello/World.thrift'", there is no way to know what + # path that should be relative to. One option is to just search for anything + # that matches that, but that could create correctness issues if there are more + # than one in different parts of the tree. Another option is to take an argument + # that references namespace, and base the tree off of that. The downside + # to that is that thrift_library then gets enmeshed in the details of code + # generation. This could also be something punted to scrooge_scala_library + # or whatever, but I think that we should make it such that the archive + # created by this is created in such a way that absolute imports work... + "absolute_prefix": attr.string(default='', mandatory=False), "_jar": attr.label(executable=True, default=Label("@bazel_tools//tools/jdk:jar"), single_file=True, allow_files=True), - "_jdk": attr.label(default=Label("//tools/defaults:jdk"), allow_files=True), }, outputs={"libarchive": "lib%{name}.jar"}, ) \ No newline at end of file diff --git a/twitter_scrooge/BUILD b/twitter_scrooge/BUILD new file mode 100644 index 000000000..e69de29bb diff --git a/twitter_scrooge/twitter_scrooge.bzl b/twitter_scrooge/twitter_scrooge.bzl new file mode 100644 index 000000000..ea903a970 --- /dev/null +++ b/twitter_scrooge/twitter_scrooge.bzl @@ -0,0 +1,223 @@ +_jar_filetype = FileType([".jar"]) + +load("//scala:scala.bzl", + "scala_mvn_artifact", + "scala_library", + "write_manifest", + "collect_srcjars") + +def twitter_scrooge(): + native.maven_jar( + name = "libthrift", + artifact = "org.apache.thrift:libthrift:0.8.0", + sha1 = "2203b4df04943f4d52c53b9608cef60c08786ef2", + ) + native.maven_jar( + name = "scrooge_core", + artifact = scala_mvn_artifact("com.twitter:scrooge-core:4.6.0"), + sha1 = "84b86c2e082aba6e0c780b3c76281703b891a2c8", + ) + + #scrooge-generator related dependencies + native.maven_jar( + name = "scrooge_generator", + artifact = scala_mvn_artifact("com.twitter:scrooge-generator:4.6.0"), + sha1 = "cacf72eedeb5309ca02b2d8325c587198ecaac82", + ) + native.maven_jar( + name = "util_core", + artifact = scala_mvn_artifact("com.twitter:util-core:6.33.0"), + sha1 = "bb49fa66a3ca9b7db8cd764d0b26ce498bbccc83", + ) + native.maven_jar( + name = "util_logging", + artifact = scala_mvn_artifact("com.twitter:util-logging:6.33.0"), + sha1 = "3d28e46f8ee3b7ad1b98a51b98089fc01c9755dd", + ) + +def _collect_transitive_srcs(targets): + r = set() + for target in targets: + if hasattr(target, "thrift"): + r += target.thrift.transitive_srcs + return r + +def _collect_owned_srcs(targets): + r = set() + for _target in targets: + if hasattr(_target, "extra_information"): + for target in _target.extra_information: + if hasattr(target, "scrooge_srcjar"): + r += target.scrooge_srcjar.transitive_owned_srcs + return r + +def collect_extra_srcjars(targets): + srcjars = set() + for target in targets: + if hasattr(target, "extra_information"): + for _target in target.extra_information: + srcjars += [_target.srcjars.srcjar] + srcjars += _target.srcjars.transitive_srcjars + return srcjars + +def _collect_immediate_srcs(targets): + r = set() + for target in targets: + if hasattr(target, "thrift"): + r += [target.thrift.srcs] + return r + +def _assert_set_is_subset(left, right): + missing = set() + for l in left: + if l not in right: + missing += [l] + if len(missing) > 0: + fail('scrooge_srcjar target must depend on scrooge_srcjar targets sufficient to ' + + 'cover the transitive graph of thrift files. Uncovered sources: ' + missing) + +def _path_newline(data): + return '\n'.join([f.path for f in data]) + +def _gen_scrooge_srcjar_impl(ctx): + remote_jars = set() + for target in ctx.attr.remote_jars: + remote_jars += _jar_filetype.filter(target.files) + + # These are the thrift sources whose generated code we will "own" as a target + immediate_thrift_srcs = _collect_immediate_srcs(ctx.attr.deps) + + # This is the set of sources which is covered by any scala_library + # or scala_scrooge_gen targets that are depended on by this. This is + # necessary as we only compile the sources we own, and rely on other + # targets compiling the rest (for the benefit of caching and correctness). + transitive_owned_srcs = _collect_owned_srcs(ctx.attr.deps) + + # These are the thrift sources in the dependency graph. They are necessary + # to generate the code, but are not "owned" by this target and will not + # be in the resultant source jar + + transitive_thrift_srcs = transitive_owned_srcs + _collect_transitive_srcs(ctx.attr.deps) + + only_transitive_thrift_srcs = set() + for src in transitive_thrift_srcs: + if src not in immediate_thrift_srcs: + only_transitive_thrift_srcs += [src] + + # We want to ensure that the thrift sources which we do not own (but need + # in order to generate code) have targets which will compile them. + _assert_set_is_subset(only_transitive_thrift_srcs, transitive_owned_srcs) + + remote_jars_file = ctx.new_file(ctx.outputs.srcjar, ctx.outputs.srcjar.short_path + "_remote_jars") + ctx.file_action(output=remote_jars_file, content=_path_newline(remote_jars)) + + only_transitive_thrift_srcs_file = ctx.new_file(ctx.outputs.srcjar, ctx.outputs.srcjar.short_path + "_only_transitive_thrift_srcs") + ctx.file_action(output = only_transitive_thrift_srcs_file, content = _path_newline(only_transitive_thrift_srcs)) + + immediate_thrift_srcs_file = ctx.new_file(ctx.outputs.srcjar, ctx.outputs.srcjar.short_path + "_immediate_thrift_srcs") + ctx.file_action(output = immediate_thrift_srcs_file, content = _path_newline(immediate_thrift_srcs)) + + ctx.action( + executable = ctx.executable._pluck_scrooge_scala, + inputs = list(remote_jars) + + list(only_transitive_thrift_srcs) + + list(immediate_thrift_srcs) + + [remote_jars_file, + only_transitive_thrift_srcs_file, + immediate_thrift_srcs_file], + outputs = [ctx.outputs.srcjar], + arguments = [ + only_transitive_thrift_srcs_file.path, + immediate_thrift_srcs_file.path, + ctx.outputs.srcjar.path, + remote_jars_file.path, + ], + progress_message = "creating scrooge files %s" % ctx.label, + ) + + jars = _collect_scalaattr(ctx.attr.deps) + + scalaattr = struct(outputs = None, + transitive_runtime_deps = jars.transitive_runtime_deps, + transitive_compile_exports = jars.transitive_compile_exports, + transitive_runtime_exports = jars.transitive_runtime_exports, + ) + + transitive_srcjars = collect_srcjars(ctx.attr.deps) + collect_extra_srcjars(ctx.attr.deps) + + srcjarsattr = struct( + srcjar = ctx.outputs.srcjar, + transitive_srcjars = transitive_srcjars, + ) + + return struct( + scala = scalaattr, + srcjars=srcjarsattr, + extra_information=[struct( + srcjars=srcjarsattr, + scrooge_srcjar=struct(transitive_owned_srcs = transitive_owned_srcs + immediate_thrift_srcs), + )], + ) + +def _collect_scalaattr(targets): + transitive_runtime_deps = set() + transitive_compile_exports = set() + transitive_runtime_exports = set() + for target in targets: + if hasattr(target, "scala"): + transitive_runtime_deps += target.scala.transitive_runtime_deps + transitive_compile_exports += target.scala.transitive_compile_exports + if hasattr(target.scala.outputs, "ijar"): + transitive_compile_exports += [target.scala.outputs.ijar] + transitive_runtime_exports += target.scala.transitive_runtime_exports + + return struct( + transitive_runtime_deps = transitive_runtime_deps, + transitive_compile_exports = transitive_compile_exports, + transitive_runtime_exports = transitive_runtime_exports, + ) + +scrooge_scala_srcjar = rule( + _gen_scrooge_srcjar_impl, + attrs={ + "deps": attr.label_list(mandatory=True), + #TODO we should think more about how we want to deal + # with these sorts of things... this basically + # is saying that we have a jar with a bunch + # of thrifts that we want to depend on. Seems like + # that should be a concern of thrift_library? we have + # it here through becuase we need to show that it is + # "covered," as well as needing the thrifts to + # do the code gen. + "remote_jars": attr.label_list(), + "_pluck_scrooge_scala": attr.label( + executable=True, + default=Label("//src/scala/scripts:generator"), + allow_files=True), + }, + outputs={ + "srcjar": "lib%{name}.srcjar", + }, +) + +def scrooge_scala_library(name, deps=[], remote_jars=[], jvm_flags=[], visibility=None): + scrooge_scala_srcjar( + name = name + '_srcjar', + deps = deps, + remote_jars = remote_jars, + visibility = visibility, + ) + scala_library( + name = name, + deps = remote_jars + [ + name + '_srcjar', + "@libthrift//jar", + "@scrooge_core//jar", + ], + exports = [ + "@libthrift//jar", + "@scrooge_core//jar", + ], + jvm_flags = jvm_flags, + visibility = visibility, + )