diff --git a/scala_proto/scala_proto.bzl b/scala_proto/scala_proto.bzl index 81fe08eb8..a6a9bab90 100644 --- a/scala_proto/scala_proto.bzl +++ b/scala_proto/scala_proto.bzl @@ -482,6 +482,9 @@ def _gen_proto_srcjar_impl(ctx): srcjars = srcjarsattr, ) +""" + Deprecated: use scala_proto_gen instead +""" scala_proto_srcjar = rule( _gen_proto_srcjar_impl, attrs = { @@ -573,11 +576,12 @@ def scalapb_proto_library( flags.append("flat_package") if with_single_line_to_string: flags.append("single_line_to_string") - scala_proto_srcjar( + + _scalapb_proto_gen_with_jvm_deps( name = srcjar, - flags = flags, - generator = "@io_bazel_rules_scala//src/scala/scripts:scalapb_generator", deps = deps, + flags = flags, + plugin = "@io_bazel_rules_scala//src/scala/scripts:scalapb_plugin", visibility = visibility, ) @@ -594,3 +598,109 @@ def scalapb_proto_library( scalac_jvm_flags = scalac_jvm_flags, visibility = visibility, ) + +def _scala_proto_gen_attrs(deps_providers): + return { + "deps": attr.label_list(mandatory = True, providers = deps_providers), + "blacklisted_protos" : attr.label_list(providers = [["proto"]]), + "flags": attr.string_list(default = []), + "plugin": attr.label(executable = True, cfg = "host"), + "_protoc": attr.label(executable = True, cfg = "host", default = "@com_google_protobuf//:protoc") + } + +_scala_proto_gen_outputs = { + "srcjar": "lib%{name}.srcjar", +} + +def _scalapb_proto_gen_with_jvm_deps_impl(ctx): + jvm_deps = [p for p in ctx.attr.deps if hasattr(p, "proto") == False] + + if "java_conversions" in ctx.attr.flags and len(jvm_deps) == 0: + fail("must have at least one jvm dependency if with_java is True (java_conversions is turned on)") + + _scala_proto_gen_impl(ctx) + + deps_jars = collect_jars(jvm_deps) + + srcjarsattr = struct(srcjar = ctx.outputs.srcjar) + scalaattr = struct( + outputs = None, + compile_jars = deps_jars.compile_jars, + transitive_runtime_jars = deps_jars.transitive_runtime_jars, + ) + java_provider = create_java_provider(scalaattr, depset()) + return struct( + scala = scalaattr, + providers = [java_provider], + srcjars = srcjarsattr, + ) + +_scalapb_proto_gen_with_jvm_deps = rule( + _scalapb_proto_gen_with_jvm_deps_impl, + attrs = _scala_proto_gen_attrs([["proto"], [JavaInfo]]), + outputs = _scala_proto_gen_outputs, +) + +def _strip_root(file, roots): + """Strip first matching root which comes from proto_library(proto_source_root) + It assumes that proto_source_root are unique. + It should go away once generation is moved to aspects and roots can be handled for each proto_library individualy. + """ + for root in roots: + prefix = root + "/" if file.is_source else file.root.path + "/" + root + "/" + if file.path.startswith(prefix): + return file.path.replace(prefix, "") + return file.short_path + +def _scala_proto_gen_impl(ctx): + protos = [p for p in ctx.attr.deps if hasattr(p, "proto")] # because scalapb_proto_library passes JavaInfo as well + descriptors = depset([f for dep in protos for f in dep.proto.transitive_descriptor_sets]).to_list() + sources = depset([f for dep in protos for f in dep.proto.transitive_sources]).to_list() + roots = depset([f for dep in protos for f in dep.proto.transitive_proto_path]).to_list() + inputs = depset([_strip_root(f, roots) for f in _retained_protos(sources, ctx.attr.blacklisted_protos)]).to_list() + + srcdotjar = ctx.actions.declare_file("_" + ctx.label.name + "_src.jar") + + ctx.actions.run( + inputs = [ctx.executable._protoc, ctx.executable.plugin] + descriptors, + outputs = [srcdotjar], + arguments = [ + "--plugin=protoc-gen-scala=" + ctx.executable.plugin.path, + "--scala_out=%s:%s" % (",".join(ctx.attr.flags), srcdotjar.path), + "--descriptor_set_in=" + ":".join([descriptor.path for descriptor in descriptors])] + + inputs, + executable = ctx.executable._protoc, + mnemonic = "ScalaProtoGen", + use_default_shell_env = True, + ) + + ctx.actions.run_shell( + command = "cp $1 $2", + inputs = [srcdotjar], + outputs = [ctx.outputs.srcjar], + arguments = [srcdotjar.path, ctx.outputs.srcjar.path]) + +"""Generates code with scala plugin passed to implicit @com_google_protobuf//:protoc + +Example: + scala_proto_gen( + name = "a_proto_scala", + deps = [":a_proto"], + plugin = "@io_bazel_rules_scala//src/scala/scripts:scalapb_plugin") + +Args: + deps: List of proto_library rules to generate code for + blacklisted_protos: List of proto_library rules to exclude from protoc inputs + (used for libraries that comes from runtime like any.proto) + flags: list of plugin flags passed to --scala_out + plugin: an executable passed to --plugin=protoc-gen-scala= which implements protoc plugin contract + https://github.com/protocolbuffers/protobuf/blob/master/src/google/protobuf/compiler/plugin.proto + +Outputs: + Single srcjar with generated sources for all deps and all the transitives +""" +scala_proto_gen = rule( + _scala_proto_gen_impl, + attrs = _scala_proto_gen_attrs(deps_providers = [["proto"]]), + outputs = _scala_proto_gen_outputs, +) diff --git a/src/scala/scripts/BUILD b/src/scala/scripts/BUILD index e4ba98d4a..a610112c2 100644 --- a/src/scala/scripts/BUILD +++ b/src/scala/scripts/BUILD @@ -55,3 +55,14 @@ scala_binary( ":scalapb_generator_lib", ], ) + +scala_binary( + name = "scalapb_plugin", + srcs = ["ScalaPBPlugin.scala"], + main_class = "scripts.ScalaPBPlugin", + deps = [ + "//external:io_bazel_rules_scala/dependency/proto/scalapb_plugin", + "//external:io_bazel_rules_scala/dependency/com_google_protobuf/protobuf_java", + ], + visibility = ["//visibility:public"], +) diff --git a/src/scala/scripts/ScalaPBPlugin.scala b/src/scala/scripts/ScalaPBPlugin.scala new file mode 100644 index 000000000..15ab6019b --- /dev/null +++ b/src/scala/scripts/ScalaPBPlugin.scala @@ -0,0 +1,10 @@ +package scripts + +import com.google.protobuf.compiler.PluginProtos.CodeGeneratorRequest.parseFrom +import scalapb.compiler.ProtobufGenerator.handleCodeGeneratorRequest + +object ScalaPBPlugin extends App { + + handleCodeGeneratorRequest(parseFrom(System.in)).writeTo(System.out) + +} diff --git a/test/proto/BUILD b/test/proto/BUILD index 278440091..a823b3bc2 100644 --- a/test/proto/BUILD +++ b/test/proto/BUILD @@ -1,7 +1,8 @@ load( "//scala_proto:scala_proto.bzl", "scalapb_proto_library", - "scala_proto_srcjar" + "scala_proto_srcjar", + "scala_proto_gen", ) load( @@ -96,16 +97,16 @@ scalapb_proto_library( deps = [":test_service"], ) -scala_proto_srcjar( +scala_proto_gen( name = "test1_proto_scala", deps = ["//test/proto2:test"], - generator = "@io_bazel_rules_scala//src/scala/scripts:scalapb_generator") + plugin = "@io_bazel_rules_scala//src/scala/scripts:scalapb_plugin") -scala_proto_srcjar( +scala_proto_gen( name = "test2_proto_scala_with_blacklisted_test1_proto_scala", deps = [":test2"], blacklisted_protos = ["//test/proto2:test"], - generator = "@io_bazel_rules_scala//src/scala/scripts:scalapb_generator") + plugin = "@io_bazel_rules_scala//src/scala/scripts:scalapb_plugin") scala_library( name = "lib_scala_should_fail_on_duplicated_sources_unless_duplicates_are_blacklisted",