diff --git a/scala/providers.bzl b/scala/providers.bzl new file mode 100644 index 000000000..c23196177 --- /dev/null +++ b/scala/providers.bzl @@ -0,0 +1,6 @@ +JarsToLabels = provider( + doc = 'provides a mapping from jar files to defining labels for improved end user experience', + fields = { + 'lookup' : 'dictionary with jar files as keys and labels as values', + }, +) diff --git a/scala/scala.bzl b/scala/scala.bzl index c843c29b9..ab602d158 100644 --- a/scala/scala.bzl +++ b/scala/scala.bzl @@ -17,6 +17,7 @@ load("//specs2:specs2_junit.bzl", "specs2_junit_dependencies") load(":scala_cross_version.bzl", "scala_version", "scala_mvn_artifact") load("@io_bazel_rules_scala//scala:scala_toolchain.bzl", "scala_toolchain") +load(":providers.bzl", "JarsToLabels") _jar_filetype = FileType([".jar"]) _java_filetype = FileType([".java"]) @@ -494,7 +495,7 @@ def add_label_of_indirect_jar_to(jars2labels, dependency, jar): # skylark exposes only labels of direct dependencies. # to get labels of indirect dependencies we collect them from the providers transitively if provider_of_dependency_contains_label_of(dependency, jar): - jars2labels[jar.path] = dependency.jars_to_labels[jar.path] + jars2labels[jar.path] = dependency[JarsToLabels].lookup[jar.path] else: jars2labels[jar.path] = "Unknown label of file {jar_path} which came from {dependency_label}".format( jar_path = jar.path, @@ -505,7 +506,7 @@ def label_already_exists(jars2labels, jar): return jar.path in jars2labels def provider_of_dependency_contains_label_of(dependency, jar): - return hasattr(dependency, "jars_to_labels") and jar.path in dependency.jars_to_labels + return JarsToLabels in dependency and jar.path in dependency[JarsToLabels].lookup def dep_target_contains_ijar(dep_target): return (hasattr(dep_target, 'scala') and hasattr(dep_target.scala, 'outputs') and @@ -728,7 +729,9 @@ def _lib(ctx, non_macro_lib): return struct( files = depset([ctx.outputs.jar]), # Here is the default output scala = scalaattr, - providers = [java_provider], + providers = [ + JarsToLabels(lookup = jars.jars2labels), + java_provider], runfiles = runfiles, # This is a free monoid given to the graph for the purpose of # extensibility. This is necessary when one wants to create @@ -740,7 +743,6 @@ def _lib(ctx, non_macro_lib): # this information through, and it is up to the new_targets # to filter and make sense of this information. extra_information=_collect_extra_information(ctx.attr.deps), - jars_to_labels = jars.jars2labels, ) @@ -783,10 +785,11 @@ def _scala_binary_common(ctx, cjars, rjars, transitive_compile_time_jars, jars2l return struct( files=depset([ctx.outputs.executable]), - providers = [java_provider], + providers = [ + JarsToLabels(lookup = jars2labels), + java_provider], scala = scalaattr, transitive_rjars = rjars, #calling rules need this for the classpath in the launcher - jars_to_labels = jars2labels, runfiles=runfiles) def _scala_binary_impl(ctx): diff --git a/scala/scala_import.bzl b/scala/scala_import.bzl index 03d4b4e83..cb00e75a6 100644 --- a/scala/scala_import.bzl +++ b/scala/scala_import.bzl @@ -1,121 +1,105 @@ -#intellij part is tested manually, tread lightly when changing there -#if you change make sure to manually re-import an intellij project and see imports -#are resolved (not red) and clickable +load(":providers.bzl", "JarsToLabels") + +# Note to future authors: +# +# Tread lightly when modifying this code! IntelliJ support needs +# to be tested manually: manually [re-]import an intellij project +# and ensure imports are resolved (not red) and clickable +# + def _scala_import_impl(ctx): - target_data = _code_jars_and_intellij_metadata_from(ctx.attr.jars) - (current_target_compile_jars, intellij_metadata) = (target_data.code_jars, target_data.intellij_metadata) - current_jars = depset(current_target_compile_jars) - exports = _collect(ctx.attr.exports) - transitive_runtime_jars = _collect_runtime(ctx.attr.runtime_deps) - jars = _collect(ctx.attr.deps) - jars2labels = {} - _collect_labels(ctx.attr.deps, jars2labels) - _collect_labels(ctx.attr.exports, jars2labels) #untested - _add_labels_of_current_code_jars(depset(transitive=[current_jars, exports.compile_jars]), ctx.label, jars2labels) #last to override the label of the export compile jars to the current target - return struct( - scala = struct( - outputs = struct ( - jars = intellij_metadata - ), - ), - jars_to_labels = jars2labels, - providers = [ - _create_provider(current_jars, transitive_runtime_jars, jars, exports) - ], + + direct_binary_jars = [] + all_jar_files = [] + for jar in ctx.attr.jars: + for file in jar.files: + all_jar_files.append(file) + if not file.basename.endswith("-sources.jar"): + direct_binary_jars += [file] + + default_info = DefaultInfo( + files = depset(all_jar_files) ) -def _create_provider(current_target_compile_jars, transitive_runtime_jars, jars, exports): - test_provider = java_common.create_provider() - if hasattr(test_provider, "full_compile_jars"): - return java_common.create_provider( - use_ijar = False, - compile_time_jars = depset(transitive = [current_target_compile_jars, exports.compile_jars]), - transitive_compile_time_jars = depset(transitive = [jars.transitive_compile_jars, current_target_compile_jars, exports.transitive_compile_jars]) , - transitive_runtime_jars = depset(transitive = [transitive_runtime_jars, jars.transitive_runtime_jars, current_target_compile_jars, exports.transitive_runtime_jars]) , - ) - else: - return java_common.create_provider( - compile_time_jars = current_target_compile_jars, - runtime_jars = transitive_runtime_jars + jars.transitive_runtime_jars, - transitive_compile_time_jars = jars.transitive_compile_jars + current_target_compile_jars, - transitive_runtime_jars = transitive_runtime_jars + jars.transitive_runtime_jars + current_target_compile_jars, - ) - -def _add_labels_of_current_code_jars(code_jars, label, jars2labels): - for jar in code_jars.to_list(): - jars2labels[jar.path] = label - -def _code_jars_and_intellij_metadata_from(jars): - code_jars = [] - intellij_metadata = [] - for jar in jars: - current_jar_code_jars = _filter_out_non_code_jars(jar.files) - code_jars += current_jar_code_jars - for current_class_jar in current_jar_code_jars: #intellij, untested - intellij_metadata.append(struct( - ijar = None, - class_jar = current_class_jar, - source_jar = None, - source_jars = [], - ) - ) - return struct(code_jars = code_jars, intellij_metadata = intellij_metadata) - -def _filter_out_non_code_jars(files): - return [file for file in files.to_list() if not _is_source_jar(file)] - -def _is_source_jar(file): - return file.basename.endswith("-sources.jar") - -def _collect(deps): - transitive_compile_jars = [] - runtime_jars = [] - compile_jars = [] - - for dep_target in deps: - java_provider = dep_target[java_common.provider] - compile_jars.append(java_provider.compile_jars) - transitive_compile_jars.append(java_provider.transitive_compile_time_jars) - runtime_jars.append(java_provider.transitive_runtime_jars) - - return struct(transitive_runtime_jars = depset(transitive = runtime_jars), - transitive_compile_jars = depset(transitive = transitive_compile_jars), - compile_jars = depset(transitive = compile_jars)) - -def _collect_labels(deps, jars2labels): - for dep_target in deps: - java_provider = dep_target[java_common.provider] - _transitively_accumulate_labels(dep_target, java_provider,jars2labels) - -def _transitively_accumulate_labels(dep_target, java_provider, jars2labels): - if hasattr(dep_target, "jars_to_labels"): - jars2labels.update(dep_target.jars_to_labels) - #scala_library doesn't add labels to the direct dependency itself - for jar in java_provider.compile_jars.to_list(): - jars2labels[jar.path] = dep_target.label - -def _collect_runtime(runtime_deps): - jar_deps = [] - for dep_target in runtime_deps: - java_provider = dep_target[java_common.provider] - jar_deps.append(java_provider.transitive_runtime_jars) - - return depset(transitive = jar_deps) - -def _collect_exports(exports): - exported_jars = depset() - - for dep_target in exports: - java_provider = dep_target[java_common.provider] - exported_jars += java_provider.full_compile_jars - - return exported_jars + + return [ + default_info, + _scala_import_java_info(ctx, direct_binary_jars), + _scala_import_jars_to_labels(ctx, direct_binary_jars), + ] + +def _scala_import_java_info(ctx, direct_binary_jars): + # merge all deps, exports, and runtime deps into single JavaInfo instances + + s_deps = java_common.merge(_collect(JavaInfo, ctx.attr.deps)) + s_exports = java_common.merge(_collect(JavaInfo, ctx.attr.exports)) + s_runtime_deps = java_common.merge(_collect(JavaInfo, ctx.attr.runtime_deps)) + + # build up our final JavaInfo provider + + compile_time_jars = depset( + direct = direct_binary_jars, + transitive = [ + s_exports.transitive_compile_time_jars]) + + transitive_compile_time_jars = depset( + transitive = [ + compile_time_jars, + s_deps.transitive_compile_time_jars, + s_exports.transitive_compile_time_jars]) + + transitive_runtime_jars = depset( + transitive = [ + compile_time_jars, + s_deps.transitive_runtime_jars, + s_exports.transitive_runtime_jars, + s_runtime_deps.transitive_runtime_jars]) + + return java_common.create_provider( + ctx.actions, + use_ijar = False, + compile_time_jars = compile_time_jars, + transitive_compile_time_jars = transitive_compile_time_jars, + transitive_runtime_jars = transitive_runtime_jars) + +def _scala_import_jars_to_labels(ctx, direct_binary_jars): + # build up JarsToLabels + # note: consider moving this to an aspect + + lookup = {} + for jar in direct_binary_jars: + lookup[jar.path] = ctx.label + + for entry in ctx.attr.deps: + if JavaInfo in entry: + for jar in entry[JavaInfo].compile_jars: + lookup[jar.path] = entry.label + if JarsToLabels in entry: + lookup.update(entry[JarsToLabels].lookup) + + for entry in ctx.attr.exports: + if JavaInfo in entry: + for jar in entry[JavaInfo].compile_jars: + lookup[jar.path] = entry.label + if JarsToLabels in entry: + lookup.update(entry[JarsToLabels].lookup) + + return JarsToLabels(lookup = lookup) + +# Filters an iterable for entries that contain a particular +# index and returns a collection of the indexed values. +def _collect(index, iterable): + return [ + entry[index] + for entry in iterable + if index in entry + ] scala_import = rule( - implementation=_scala_import_impl, - attrs={ - "jars": attr.label_list(allow_files=True), #current hidden assumption is that these point to full, not ijar'd jars - "deps": attr.label_list(), - "runtime_deps": attr.label_list(), - "exports": attr.label_list() - }, + implementation = _scala_import_impl, + attrs = { + "jars": attr.label_list(allow_files=True), #current hidden assumption is that these point to full, not ijar'd jars + "deps": attr.label_list(), + "runtime_deps": attr.label_list(), + "exports": attr.label_list() + }, )