diff --git a/scala/private/common.bzl b/scala/private/common.bzl index 493071cd5..a6019876b 100644 --- a/scala/private/common.bzl +++ b/scala/private/common.bzl @@ -1,3 +1,5 @@ +load("@io_bazel_rules_scala//scala:providers.bzl", "JarsToLabels") + def write_manifest(ctx): # TODO(bazel-team): I don't think this classpath is what you want manifest = "Class-Path: \n" @@ -104,7 +106,7 @@ def _add_label_of_indirect_jar_to(jars2labels, dependency, jar): # skylark exposes only labels of direct dependencies. # to get labels of indirect dependencies we collect them from the providers transitively if _provider_of_dependency_contains_label_of(dependency, jar): - jars2labels[jar.path] = dependency.jars_to_labels[jar.path] + jars2labels[jar.path] = dependency[JarsToLabels].lookup[jar.path] else: jars2labels[jar.path] = "Unknown label of file {jar_path} which came from {dependency_label}".format( jar_path = jar.path, @@ -115,7 +117,7 @@ def _label_already_exists(jars2labels, jar): return jar.path in jars2labels def _provider_of_dependency_contains_label_of(dependency, jar): - return hasattr(dependency, "jars_to_labels") and jar.path in dependency.jars_to_labels + return JarsToLabels in dependency and jar.path in dependency[JarsToLabels].lookup def create_java_provider(scalaattr, transitive_compile_time_jars): # This is needed because Bazel >=0.7.0 requires ctx.actions and a Java diff --git a/scala/private/rule_impls.bzl b/scala/private/rule_impls.bzl index d8df45262..a1a50e402 100644 --- a/scala/private/rule_impls.bzl +++ b/scala/private/rule_impls.bzl @@ -14,7 +14,7 @@ """Rules for supporting the Scala language.""" load("@io_bazel_rules_scala//scala:scala_toolchain.bzl", "scala_toolchain") -load("@io_bazel_rules_scala//scala:providers.bzl", "create_scala_provider") +load("@io_bazel_rules_scala//scala:providers.bzl", "create_scala_provider", "JarsToLabels") load(":common.bzl", "add_labels_of_jars_to", "create_java_provider", @@ -122,12 +122,14 @@ def _collect_plugin_paths(plugins): for p in plugins: if hasattr(p, "path"): paths.append(p) + elif p[JavaInfo] and p[JavaInfo].full_compile_jars: + paths.extend(p[JavaInfo].full_compile_jars.to_list()) elif hasattr(p, "scala"): paths.append(p.scala.outputs.jar) elif hasattr(p, "java"): paths.extend([j.class_jar for j in p.java.outputs.jars]) - # support http_file pointed at a jar. http_jar uses ijar, - # which breaks scala macros + # support http_file pointed at a jar. http_jar uses ijar, + # which breaks scala macros elif hasattr(p, "files"): paths.extend([f for f in p.files if not_sources_jar(f.basename)]) return depset(paths) @@ -606,10 +608,11 @@ def _scala_binary_common(ctx, cjars, rjars, transitive_compile_time_jars, jars2l return struct( files=depset([ctx.outputs.executable]), - providers = [java_provider], + providers = [ + JarsToLabels(lookup = jars2labels), + java_provider], scala = scalaattr, - transitive_rjars = rjars, #calling rules need this for the classpath in the launcher - jars_to_labels = jars2labels, + transitive_rjars = rjars, #calling rules need this for the classpath in the launcher runfiles=runfiles) def scala_binary_impl(ctx): diff --git a/scala/providers.bzl b/scala/providers.bzl index f275d5ef2..aa5baefc7 100644 --- a/scala/providers.bzl +++ b/scala/providers.bzl @@ -30,3 +30,10 @@ def create_scala_provider( transitive_runtime_jars = transitive_runtime_jars, transitive_exports = [] #needed by intellij plugin ) + +JarsToLabels = provider( + doc = 'provides a mapping from jar files to defining labels for improved end user experience', + fields = { + 'lookup' : 'dictionary with jar files as keys and labels as values', + }, +) diff --git a/scala/scala.bzl b/scala/scala.bzl index 1e1059463..9829b4c81 100644 --- a/scala/scala.bzl +++ b/scala/scala.bzl @@ -9,7 +9,7 @@ load("@io_bazel_rules_scala//scala/private:rule_impls.bzl", load( "@io_bazel_rules_scala//specs2:specs2_junit.bzl", - "specs2_junit_dependencies" + "specs2_junit_dependencies", ) _jar_filetype = FileType([".jar"]) diff --git a/scala/scala_import.bzl b/scala/scala_import.bzl index 67727509b..f32d65b02 100644 --- a/scala/scala_import.bzl +++ b/scala/scala_import.bzl @@ -1,112 +1,155 @@ -#intellij part is tested manually, tread lightly when changing there -#if you change make sure to manually re-import an intellij project and see imports -#are resolved (not red) and clickable +load(":providers.bzl", "JarsToLabels") + +# Note to future authors: +# +# Tread lightly when modifying this code! IntelliJ support needs +# to be tested manually: manually [re-]import an intellij project +# and ensure imports are resolved (not red) and clickable +# + def _scala_import_impl(ctx): - target_data = _code_jars_and_intellij_metadata_from(ctx.attr.jars) - (current_target_compile_jars, intellij_metadata) = (target_data.code_jars, target_data.intellij_metadata) - current_jars = depset(current_target_compile_jars) - exports = _collect(ctx.attr.exports) - transitive_runtime_jars = _collect_runtime(ctx.attr.runtime_deps) - jars = _collect(ctx.attr.deps) - jars2labels = {} - _collect_labels(ctx.attr.deps, jars2labels) - _collect_labels(ctx.attr.exports, jars2labels) #untested - _add_labels_of_current_code_jars(depset(transitive=[current_jars, exports.compile_jars]), ctx.label, jars2labels) #last to override the label of the export compile jars to the current target + + direct_binary_jars = [] + all_jar_files = [] + for jar in ctx.attr.jars: + for file in jar.files.to_list(): + all_jar_files.append(file) + if not file.basename.endswith("-sources.jar"): + direct_binary_jars += [file] + + default_info = DefaultInfo( + files = depset(all_jar_files) + ) + + source_jar = None + if (ctx.attr.srcjar): + source_jar = ctx.file.srcjar + return struct( - scala = struct( - outputs = struct ( - jars = intellij_metadata - ), - ), - jars_to_labels = jars2labels, + scala = _create_intellij_provider(direct_binary_jars, source_jar), providers = [ - _create_provider(current_jars, transitive_runtime_jars, jars, exports) - ], + default_info, + _scala_import_java_info(ctx, direct_binary_jars, source_jar), + _scala_import_jars_to_labels(ctx, direct_binary_jars), + ] + ) + +# The IntelliJ plugin currently does not support JavaInfo. It has its own +# provider. We build that provider and return it in addition to JavaInfo. +# From reading the IntelliJ plugin code, best I can tell it expects a provider +# that looks like this. +# { +# scala: { +# annotation_processing: { +# # see https://docs.bazel.build/versions/master/skylark/lib/java_annotation_processing.html +# }, +# outputs: { +# # see https://docs.bazel.build/versions/master/skylark/lib/java_output_jars.html +# jdeps: +# jars: [ +# { +# # see https://docs.bazel.build/versions/master/skylark/lib/java_output.html +# class_jar: , +# ijar: +# source_jar: +# source_jars: [...] +# } +# ] +# } +# }, +# } +def _create_intellij_provider(jars, source_jar): + return struct( + # TODO: should we support annotation_processing and jdeps? + outputs = struct( + jars = [_create_intellij_output(jar, source_jar) for jar in jars] + ) + ) + +def _create_intellij_output(class_jar, source_jar): + source_jars = [source_jar] if source_jar else [] + return struct( + class_jar = class_jar, + ijar = None, + source_jar = source_jar, + source_jars = source_jars, ) -def _create_provider(current_target_compile_jars, transitive_runtime_jars, jars, exports): - test_provider = java_common.create_provider() - if hasattr(test_provider, "full_compile_jars"): - return java_common.create_provider( - use_ijar = False, - compile_time_jars = depset(transitive = [current_target_compile_jars, exports.compile_jars]), - transitive_compile_time_jars = depset(transitive = [jars.transitive_compile_jars, current_target_compile_jars, exports.transitive_compile_jars]) , - transitive_runtime_jars = depset(transitive = [transitive_runtime_jars, jars.transitive_runtime_jars, current_target_compile_jars, exports.transitive_runtime_jars]) , - ) - else: - return java_common.create_provider( - compile_time_jars = current_target_compile_jars, - runtime_jars = transitive_runtime_jars + jars.transitive_runtime_jars, - transitive_compile_time_jars = jars.transitive_compile_jars + current_target_compile_jars, - transitive_runtime_jars = transitive_runtime_jars + jars.transitive_runtime_jars + current_target_compile_jars, - ) - -def _add_labels_of_current_code_jars(code_jars, label, jars2labels): - for jar in code_jars.to_list(): - jars2labels[jar.path] = label - -def _code_jars_and_intellij_metadata_from(jars): - code_jars = [] - intellij_metadata = [] - for jar in jars: - current_jar_code_jars = _filter_out_non_code_jars(jar.files) - code_jars += current_jar_code_jars - for current_class_jar in current_jar_code_jars: #intellij, untested - intellij_metadata.append(struct( - ijar = None, - class_jar = current_class_jar, - source_jar = None, - source_jars = [], - ) - ) - return struct(code_jars = code_jars, intellij_metadata = intellij_metadata) - -def _filter_out_non_code_jars(files): - return [file for file in files.to_list() if not _is_source_jar(file)] - -def _is_source_jar(file): - return file.basename.endswith("-sources.jar") - -def _collect(deps): - transitive_compile_jars = [] - runtime_jars = [] - compile_jars = [] - - for dep_target in deps: - java_provider = dep_target[java_common.provider] - compile_jars.append(java_provider.compile_jars) - transitive_compile_jars.append(java_provider.transitive_compile_time_jars) - runtime_jars.append(java_provider.transitive_runtime_jars) - - return struct(transitive_runtime_jars = depset(transitive = runtime_jars), - transitive_compile_jars = depset(transitive = transitive_compile_jars), - compile_jars = depset(transitive = compile_jars)) - -def _collect_labels(deps, jars2labels): - for dep_target in deps: - java_provider = dep_target[java_common.provider] - _transitively_accumulate_labels(dep_target, java_provider,jars2labels) - -def _transitively_accumulate_labels(dep_target, java_provider, jars2labels): - if hasattr(dep_target, "jars_to_labels"): - jars2labels.update(dep_target.jars_to_labels) - #scala_library doesn't add labels to the direct dependency itself - for jar in java_provider.compile_jars.to_list(): - jars2labels[jar.path] = dep_target.label - -def _collect_runtime(runtime_deps): - jar_deps = [] - for dep_target in runtime_deps: - java_provider = dep_target[java_common.provider] - jar_deps.append(java_provider.transitive_runtime_jars) - - return depset(transitive = jar_deps) + +def _scala_import_java_info(ctx, direct_binary_jars, source_jar = None): + s_deps = java_common.merge(_collect(JavaInfo, ctx.attr.deps)) + s_exports = java_common.merge(_collect(JavaInfo, ctx.attr.exports)) + s_runtime_deps = java_common.merge(_collect(JavaInfo, ctx.attr.runtime_deps)) + + # build up our final JavaInfo provider + + compile_time_jars = depset( + direct = direct_binary_jars, + transitive = [ + s_exports.transitive_compile_time_jars]) + + transitive_compile_time_jars = depset( + transitive = [ + compile_time_jars, + s_deps.transitive_compile_time_jars, + s_exports.transitive_compile_time_jars]) + + transitive_runtime_jars = depset( + transitive = [ + compile_time_jars, + s_deps.transitive_runtime_jars, + s_exports.transitive_runtime_jars, + s_runtime_deps.transitive_runtime_jars]) + + source_jars = [source_jar] if source_jar else [] + + return java_common.create_provider( + ctx.actions, + use_ijar = False, + compile_time_jars = compile_time_jars, + transitive_compile_time_jars = transitive_compile_time_jars, + transitive_runtime_jars = transitive_runtime_jars, + source_jars = source_jars) + +def _scala_import_jars_to_labels(ctx, direct_binary_jars): + # build up JarsToLabels + # note: consider moving this to an aspect + + lookup = {} + for jar in direct_binary_jars: + lookup[jar.path] = ctx.label + + for entry in ctx.attr.deps: + if JavaInfo in entry: + for jar in entry[JavaInfo].compile_jars: + lookup[jar.path] = entry.label + if JarsToLabels in entry: + lookup.update(entry[JarsToLabels].lookup) + + for entry in ctx.attr.exports: + if JavaInfo in entry: + for jar in entry[JavaInfo].compile_jars.to_list(): + lookup[jar.path] = entry.label + if JarsToLabels in entry: + lookup.update(entry[JarsToLabels].lookup) + + return JarsToLabels(lookup = lookup) + +# Filters an iterable for entries that contain a particular +# index and returns a collection of the indexed values. +def _collect(index, iterable): + return [ + entry[index] + for entry in iterable + if index in entry + ] scala_import = rule( - implementation=_scala_import_impl, - attrs={ - "jars": attr.label_list(allow_files=True), #current hidden assumption is that these point to full, not ijar'd jars - "deps": attr.label_list(), - "runtime_deps": attr.label_list(), - "exports": attr.label_list() - }, + implementation = _scala_import_impl, + attrs = { + "jars": attr.label_list(allow_files=True), #current hidden assumption is that these point to full, not ijar'd jars + "srcjar": attr.label(allow_single_file=True), + "deps": attr.label_list(), + "runtime_deps": attr.label_list(), + "exports": attr.label_list(), + }, ) diff --git a/scala_proto/scala_proto.bzl b/scala_proto/scala_proto.bzl index 4d6ff09f8..0a936d44b 100644 --- a/scala_proto/scala_proto.bzl +++ b/scala_proto/scala_proto.bzl @@ -336,11 +336,12 @@ def _gen_proto_srcjar_impl(ctx): acc_imports.append(target.proto.transitive_sources) #inline this if after 0.12.0 is the oldest supported version if hasattr(target.proto, 'transitive_proto_path'): - transitive_proto_paths.append(target.proto.transitive_proto_path) + transitive_proto_paths.append(target.proto.transitive_proto_path) else: jvm_deps.append(target) acc_imports = depset(transitive = acc_imports) + transitive_proto_paths = depset(transitive = transitive_proto_paths) if "java_conversions" in ctx.attr.flags and len(jvm_deps) == 0: fail("must have at least one jvm dependency if with_java is True (java_conversions is turned on)") @@ -352,7 +353,7 @@ def _gen_proto_srcjar_impl(ctx): # Command line args to worker cannot be empty so using padding flags_arg = "-" + ",".join(ctx.attr.flags), # Command line args to worker cannot be empty so using padding - packages = "-" + ":".join(transitive_proto_paths) + packages = "-" + ":".join(transitive_proto_paths.to_list()) ) argfile = ctx.actions.declare_file("%s_worker_input" % ctx.label.name, sibling = ctx.outputs.srcjar) ctx.actions.write(output=argfile, content=worker_content)