Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 113 additions & 3 deletions scala_proto/scala_proto.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,9 @@ def _gen_proto_srcjar_impl(ctx):
srcjars = srcjarsattr,
)

"""
Deprecated: use scala_proto_gen instead
"""
scala_proto_srcjar = rule(
_gen_proto_srcjar_impl,
attrs = {
Expand Down Expand Up @@ -573,11 +576,12 @@ def scalapb_proto_library(
flags.append("flat_package")
if with_single_line_to_string:
flags.append("single_line_to_string")
scala_proto_srcjar(

_scalapb_proto_gen_with_jvm_deps(
name = srcjar,
flags = flags,
generator = "@io_bazel_rules_scala//src/scala/scripts:scalapb_generator",
deps = deps,
flags = flags,
plugin = "@io_bazel_rules_scala//src/scala/scripts:scalapb_plugin",
visibility = visibility,
)

Expand All @@ -594,3 +598,109 @@ def scalapb_proto_library(
scalac_jvm_flags = scalac_jvm_flags,
visibility = visibility,
)

def _scala_proto_gen_attrs(deps_providers):
return {
"deps": attr.label_list(mandatory = True, providers = deps_providers),
"blacklisted_protos" : attr.label_list(providers = [["proto"]]),
"flags": attr.string_list(default = []),
"plugin": attr.label(executable = True, cfg = "host"),
"_protoc": attr.label(executable = True, cfg = "host", default = "@com_google_protobuf//:protoc")
}

_scala_proto_gen_outputs = {
"srcjar": "lib%{name}.srcjar",
}

def _scalapb_proto_gen_with_jvm_deps_impl(ctx):
jvm_deps = [p for p in ctx.attr.deps if hasattr(p, "proto") == False]

if "java_conversions" in ctx.attr.flags and len(jvm_deps) == 0:
fail("must have at least one jvm dependency if with_java is True (java_conversions is turned on)")

_scala_proto_gen_impl(ctx)

deps_jars = collect_jars(jvm_deps)

srcjarsattr = struct(srcjar = ctx.outputs.srcjar)
scalaattr = struct(
outputs = None,
compile_jars = deps_jars.compile_jars,
transitive_runtime_jars = deps_jars.transitive_runtime_jars,
)
java_provider = create_java_provider(scalaattr, depset())
return struct(
scala = scalaattr,
providers = [java_provider],
srcjars = srcjarsattr,
)

_scalapb_proto_gen_with_jvm_deps = rule(
_scalapb_proto_gen_with_jvm_deps_impl,
attrs = _scala_proto_gen_attrs([["proto"], [JavaInfo]]),
outputs = _scala_proto_gen_outputs,
)

def _strip_root(file, roots):
"""Strip first matching root which comes from proto_library(proto_source_root)
It assumes that proto_source_root are unique.
It should go away once generation is moved to aspects and roots can be handled for each proto_library individualy.
"""
for root in roots:
prefix = root + "/" if file.is_source else file.root.path + "/" + root + "/"
if file.path.startswith(prefix):
return file.path.replace(prefix, "")
return file.short_path

def _scala_proto_gen_impl(ctx):
protos = [p for p in ctx.attr.deps if hasattr(p, "proto")] # because scalapb_proto_library passes JavaInfo as well
descriptors = depset([f for dep in protos for f in dep.proto.transitive_descriptor_sets]).to_list()
sources = depset([f for dep in protos for f in dep.proto.transitive_sources]).to_list()
roots = depset([f for dep in protos for f in dep.proto.transitive_proto_path]).to_list()
inputs = depset([_strip_root(f, roots) for f in _retained_protos(sources, ctx.attr.blacklisted_protos)]).to_list()

srcdotjar = ctx.actions.declare_file("_" + ctx.label.name + "_src.jar")

ctx.actions.run(
inputs = [ctx.executable._protoc, ctx.executable.plugin] + descriptors,
outputs = [srcdotjar],
arguments = [
"--plugin=protoc-gen-scala=" + ctx.executable.plugin.path,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is going to prevent using workers right? So you have to spin up a new JVM for each file, right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it would prevent using workers. I thought this is what is happening now as well. Now I think my assumptions might be wrong as protoc-bridge is doing some tricks to reuse jvm. I'll dig more and update.

"--scala_out=%s:%s" % (",".join(ctx.attr.flags), srcdotjar.path),
"--descriptor_set_in=" + ":".join([descriptor.path for descriptor in descriptors])]
+ inputs,
executable = ctx.executable._protoc,
mnemonic = "ScalaProtoGen",
use_default_shell_env = True,
)

ctx.actions.run_shell(
command = "cp $1 $2",
inputs = [srcdotjar],
outputs = [ctx.outputs.srcjar],
arguments = [srcdotjar.path, ctx.outputs.srcjar.path])

"""Generates code with scala plugin passed to implicit @com_google_protobuf//:protoc

Example:
scala_proto_gen(
name = "a_proto_scala",
deps = [":a_proto"],
plugin = "@io_bazel_rules_scala//src/scala/scripts:scalapb_plugin")

Args:
deps: List of proto_library rules to generate code for
blacklisted_protos: List of proto_library rules to exclude from protoc inputs
(used for libraries that comes from runtime like any.proto)
flags: list of plugin flags passed to --scala_out
plugin: an executable passed to --plugin=protoc-gen-scala= which implements protoc plugin contract
https://github.com/protocolbuffers/protobuf/blob/master/src/google/protobuf/compiler/plugin.proto

Outputs:
Single srcjar with generated sources for all deps and all the transitives
"""
scala_proto_gen = rule(
_scala_proto_gen_impl,
attrs = _scala_proto_gen_attrs(deps_providers = [["proto"]]),
outputs = _scala_proto_gen_outputs,
)
11 changes: 11 additions & 0 deletions src/scala/scripts/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,14 @@ scala_binary(
":scalapb_generator_lib",
],
)

scala_binary(
name = "scalapb_plugin",
srcs = ["ScalaPBPlugin.scala"],
main_class = "scripts.ScalaPBPlugin",
deps = [
"//external:io_bazel_rules_scala/dependency/proto/scalapb_plugin",
"//external:io_bazel_rules_scala/dependency/com_google_protobuf/protobuf_java",
],
visibility = ["//visibility:public"],
)
10 changes: 10 additions & 0 deletions src/scala/scripts/ScalaPBPlugin.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package scripts

import com.google.protobuf.compiler.PluginProtos.CodeGeneratorRequest.parseFrom
import scalapb.compiler.ProtobufGenerator.handleCodeGeneratorRequest

object ScalaPBPlugin extends App {

handleCodeGeneratorRequest(parseFrom(System.in)).writeTo(System.out)

}
11 changes: 6 additions & 5 deletions test/proto/BUILD
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
load(
"//scala_proto:scala_proto.bzl",
"scalapb_proto_library",
"scala_proto_srcjar"
"scala_proto_srcjar",
"scala_proto_gen",
)

load(
Expand Down Expand Up @@ -96,16 +97,16 @@ scalapb_proto_library(
deps = [":test_service"],
)

scala_proto_srcjar(
scala_proto_gen(
name = "test1_proto_scala",
deps = ["//test/proto2:test"],
generator = "@io_bazel_rules_scala//src/scala/scripts:scalapb_generator")
plugin = "@io_bazel_rules_scala//src/scala/scripts:scalapb_plugin")

scala_proto_srcjar(
scala_proto_gen(
name = "test2_proto_scala_with_blacklisted_test1_proto_scala",
deps = [":test2"],
blacklisted_protos = ["//test/proto2:test"],
generator = "@io_bazel_rules_scala//src/scala/scripts:scalapb_generator")
plugin = "@io_bazel_rules_scala//src/scala/scripts:scalapb_plugin")

scala_library(
name = "lib_scala_should_fail_on_duplicated_sources_unless_duplicates_are_blacklisted",
Expand Down