Skip to content

Commit

Permalink
Rework how maven project jars are calculated (#475)
Browse files Browse the repository at this point in the history
Previously, we assumed that all the code to be included in a project
jar was inside the main repository. This was put
in place largely to prevent code depending on `rules_proto` from
pulling in the classes from protobuf itself.

Issue #448 demonstrated that this assumption wasn't correct.

We now gather all the JavaInfos that should be included in the
artifact, as well as those of dependencies of the artifact (called
"dep_infos" in the MavenInfo provider). We then calculate the
difference of those sets to determine the input files to add to the
generated project jar.

However, this is not sufficient on its own in the case of
protobuf. As such, we also scan for classes to exclude from the
project jar by examining the contents of the "dep_infos" runtime jars.

Furthermore, jars that are created by aspects have no way of
expressing what their maven dependencies are. To work around this,
rules may now express `MavenHintInfo` which will be used to help
calculate the contents of the artifact jars and maven dependencies.
  • Loading branch information
shs96c committed Nov 13, 2020
1 parent 08a3160 commit 6862295
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 77 deletions.
142 changes: 87 additions & 55 deletions private/rules/has_maven_deps.bzl
Original file line number Diff line number Diff line change
@@ -1,27 +1,34 @@
MavenInfo = provider(
fields = {
# Fields to do with maven coordinates
"coordinates": "Maven coordinates for the project, which may be None",
"artifact_jars": "Depset of runtime jars that are unique to the artifact",
"jars_from_maven_deps": "Depset of jars from all transitive maven dependencies",
"artifact_source_jars": "Depset of source jars that unique to the artifact",
"source_jars_from_maven_deps": "Depset of jars from all transitive maven dependencies",
"maven_deps": "Depset of first order maven dependencies",
"maven_deps": "Depset of first-order maven dependencies",
"as_maven_dep": "Depset of this project if used as a maven dependency",
"deps_java_infos": "Depset of JavaInfo instances of dependencies not included in the project",
"transitive_infos": "Dict of label to JavaInfos",

# Fields used for generating artifacts
"artifact_infos": "Depset of JavaInfo instances of targets to include in the maven artifact",
"dep_infos": "Depset of JavaInfo instances of dependencies that the maven artifact depends on",
"label_to_javainfo": "Dict mapping a label to the JavaInfo that label produces",
},
)

# Provider a rule can return to pass extra MavenInfo instances to the
# `has_maven_deps` aspect. While walking a target's dependencies, the aspect
# checks each one for this provider and records every entry of `maven_infos`
# (without an accompanying JavaInfo) in the data it gathers.
MavenHintInfo = provider(
doc = """Provides hints to the `has_maven_deps` aspect about additional dependencies.
This is particularly useful if outputs are generated from aspects, and so may not be able
to offer `tags` to be used to infer maven information.
""",
fields = {
"maven_infos": "Depset of MavenInfo instances to also consider as dependencies",
},
)

_EMPTY_INFO = MavenInfo(
coordinates = None,
artifact_jars = depset(),
artifact_source_jars = depset(),
jars_from_maven_deps = depset(),
source_jars_from_maven_deps = depset(),
maven_deps = depset(),
as_maven_dep = depset(),
deps_java_infos = depset(),
transitive_infos = {},
artifact_infos = depset(),
dep_infos = depset(),
label_to_javainfo = {},
)

_MAVEN_PREFIX = "maven_coordinates="
Expand Down Expand Up @@ -56,8 +63,48 @@ def _set_diff(first, second):

return [item for item in first if item not in second]

def _filter_external_jars(workspace_name, items):
    """Keeps the items whose owner's workspace is "" or `workspace_name`.

    Args:
      workspace_name: workspace name to accept in addition to the empty string.
      items: iterable of values exposing `owner.workspace_name`.

    Returns:
      A list of the accepted items, in their original order.
    """
    accepted_workspaces = ["", workspace_name]
    kept = []
    for candidate in items:
        if candidate.owner.workspace_name in accepted_workspaces:
            kept.append(candidate)
    return kept
def _flatten(array_of_depsets):
    """Returns the distinct items found across a list of depsets.

    Every depset is expanded with `to_list()` and each item is stored as a
    dict key, so repeated items collapse into a single entry.

    Args:
      array_of_depsets: list of depsets to expand.

    Returns:
      The keys of the collecting dict, i.e. each item once.
    """
    seen = {}
    for current in array_of_depsets:
        for entry in current.to_list():
            seen[entry] = True

    return seen.keys()

def calculate_artifact_jars(maven_info):
    """Calculate the runtime jars to include in a maven artifact.

    Args:
      maven_info: value exposing `artifact_infos` and `dep_infos`, each a
        depset of JavaInfo instances.

    Returns:
      The transitive runtime jars of `artifact_infos` with those that also
      appear in the transitive runtime jars of `dep_infos` removed.
    """
    included = []
    for java_info in maven_info.artifact_infos.to_list():
        included.append(java_info.transitive_runtime_jars)

    from_deps = []
    for java_info in maven_info.dep_infos.to_list():
        from_deps.append(java_info.transitive_runtime_jars)

    return _set_diff(_flatten(included), _flatten(from_deps))

def calculate_artifact_source_jars(maven_info):
    """Calculate the source jars to include in a maven artifact.

    Args:
      maven_info: value exposing `artifact_infos` and `dep_infos`, each a
        depset of JavaInfo instances.

    Returns:
      The transitive source jars of `artifact_infos` with those that also
      appear in the transitive source jars of `dep_infos` removed.
    """
    included = []
    for java_info in maven_info.artifact_infos.to_list():
        included.append(java_info.transitive_source_jars)

    from_deps = []
    for java_info in maven_info.dep_infos.to_list():
        from_deps.append(java_info.transitive_source_jars)

    return _set_diff(_flatten(included), _flatten(from_deps))

# Used to gather maven data
#
# Private record that `_has_maven_deps_impl` creates once per target and that
# `_extract_from` fills in as each dependency is visited:
#   all_infos         - list of every MavenInfo that was seen
#   label_to_javainfo - dict merged from each MavenInfo's `label_to_javainfo`
#   artifact_infos    - list of JavaInfos from deps without maven coordinates
#   dep_infos         - list of JavaInfos from deps that have maven coordinates
_gathered = provider(
fields = [
"all_infos",
"label_to_javainfo",
"artifact_infos",
"dep_infos",
],
)

def _extract_from(gathered, maven_info, dep):
    """Records one dependency's maven data in `gathered`.

    Args:
      gathered: a `_gathered` record whose lists and dict are updated in place.
      maven_info: the MavenInfo to record.
      dep: the dependency that produced `maven_info`, or None when there is
        no target to read a JavaInfo from.
    """
    java_info = None
    if dep and JavaInfo in dep:
        java_info = dep[JavaInfo]

    gathered.all_infos.append(maven_info)
    gathered.label_to_javainfo.update(maven_info.label_to_javainfo)

    if not java_info:
        return

    # Deps with coordinates are separate maven dependencies; the rest are
    # folded into the artifact being built.
    bucket = gathered.dep_infos if maven_info.coordinates else gathered.artifact_infos
    bucket.append(java_info)

def _has_maven_deps_impl(target, ctx):
if not JavaInfo in target:
Expand All @@ -72,54 +119,39 @@ def _has_maven_deps_impl(target, ctx):
for attr in _ASPECT_ATTRS:
all_deps.extend(getattr(ctx.rule.attr, attr, []))

all_infos = []
first_order_java_infos = []
coordinates = _read_coordinates(ctx.rule.attr.tags)
label_to_javainfo = {target.label: target[JavaInfo]}

gathered = _gathered(
all_infos = [],
artifact_infos = [target[JavaInfo]],
dep_infos = [],
label_to_javainfo = {target.label: target[JavaInfo]},
)
for dep in all_deps:
if MavenHintInfo in dep:
for info in dep[MavenHintInfo].maven_infos.to_list():
_extract_from(gathered, info, None)

if not MavenInfo in dep:
continue

all_infos.append(dep[MavenInfo])

if JavaInfo in dep and dep[MavenInfo].coordinates:
first_order_java_infos.append(dep[JavaInfo])

deps_java_infos = depset(
direct = first_order_java_infos,
transitive = [dep.deps_java_infos for dep in all_infos])

all_jars = target[JavaInfo].transitive_runtime_jars
jars_from_maven_deps = depset(transitive = [info.jars_from_maven_deps for info in all_infos])
items = _set_diff(all_jars.to_list(), jars_from_maven_deps.to_list())
artifact_jars = depset(_filter_external_jars(ctx.workspace_name, items))

all_source_jars = target[JavaInfo].transitive_source_jars
source_jars_from_maven_deps = depset(transitive = [jpi.source_jars_from_maven_deps for jpi in all_infos])
items = _set_diff(all_source_jars.to_list(), source_jars_from_maven_deps.to_list())
artifact_source_jars = depset(_filter_external_jars(ctx.workspace_name, items))

coordinates = _read_coordinates(ctx.rule.attr.tags)

first_order_maven_deps = depset(transitive = [jpi.as_maven_dep for jpi in all_infos])

# If we have coordinates our current `all_jars` is also our `maven_dep_jars`.
# Otherwise, we need to collect them from the MavenInfos we depend
# upon.
maven_dep_jars = all_jars if coordinates else jars_from_maven_deps
info = dep[MavenInfo]
_extract_from(gathered, info, dep)

transitive_infos = {target.label: target[JavaInfo]}
for mi in all_infos:
transitive_infos.update(mi.transitive_infos)
all_infos = gathered.all_infos
artifact_infos = gathered.artifact_infos
dep_infos = gathered.dep_infos
label_to_javainfo = gathered.label_to_javainfo
maven_deps = depset(transitive = [i.as_maven_dep for i in all_infos])

info = MavenInfo(
coordinates = coordinates,
artifact_jars = artifact_jars,
jars_from_maven_deps = all_jars if coordinates else jars_from_maven_deps,
artifact_source_jars = artifact_source_jars,
source_jars_from_maven_deps = all_source_jars if coordinates else source_jars_from_maven_deps,
maven_deps = first_order_maven_deps,
as_maven_dep = depset([coordinates]) if coordinates else first_order_maven_deps,
deps_java_infos = deps_java_infos,
transitive_infos = transitive_infos,
maven_deps = maven_deps,
as_maven_dep = depset([coordinates]) if coordinates else maven_deps,
artifact_infos = depset(direct = artifact_infos),
dep_infos = depset(direct = dep_infos, transitive = [i.dep_infos for i in all_infos]),
label_to_javainfo = label_to_javainfo,
)

return [
Expand Down
35 changes: 27 additions & 8 deletions private/rules/maven_project_jar.bzl
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
load(":has_maven_deps.bzl", "MavenInfo", "has_maven_deps")
load(":has_maven_deps.bzl", "MavenInfo", "calculate_artifact_jars", "calculate_artifact_source_jars", "has_maven_deps")

def _combine_jars(ctx, merge_jars, inputs, output):
def _combine_jars(ctx, merge_jars, inputs, excludes, output):
args = ctx.actions.args()
args.add("--output", output)
args.add_all(inputs, before_each = "--sources")
args.add_all(excludes, before_each = "--exclude")

ctx.actions.run(
mnemonic = "MergeJars",
inputs = inputs,
inputs = inputs + excludes,
outputs = [output],
executable = merge_jars,
arguments = [args],
Expand All @@ -17,12 +18,26 @@ def _maven_project_jar_impl(ctx):
target = ctx.attr.target
info = target[MavenInfo]

# Identify the subset of JavaInfo to include in the artifact
artifact_jars = calculate_artifact_jars(info)
artifact_srcs = calculate_artifact_source_jars(info)

# Merge together all the binary jars
bin_jar = ctx.actions.declare_file("%s.jar" % ctx.label.name)
_combine_jars(ctx, ctx.executable._merge_jars, info.artifact_jars.to_list(), bin_jar)
_combine_jars(
ctx,
ctx.executable._merge_jars,
artifact_jars,
depset(transitive = [ji.transitive_runtime_jars for ji in info.dep_infos.to_list()]).to_list(),
bin_jar)

src_jar = ctx.actions.declare_file("%s-src.jar" % ctx.label.name)
_combine_jars(ctx, ctx.executable._merge_jars, info.artifact_source_jars.to_list(), src_jar)
_combine_jars(
ctx,
ctx.executable._merge_jars,
artifact_srcs,
depset(transitive = [ji.transitive_source_jars for ji in info.dep_infos.to_list()]).to_list(),
src_jar)

java_toolchain = ctx.attr._java_toolchain[java_common.JavaToolchainInfo]
ijar = java_common.run_ijar(
Expand All @@ -34,8 +49,12 @@ def _maven_project_jar_impl(ctx):

# Grab the exported javainfos
exported_infos = []
for label in target[JavaInfo].transitive_exports.to_list():
export_info = target[MavenInfo].transitive_infos.get(label)
targets = [] + target[JavaInfo].transitive_exports.to_list()
for i in info.artifact_infos.to_list():
targets.extend(i.transitive_exports.to_list())

for label in targets:
export_info = info.label_to_javainfo.get(label)
if export_info != None:
exported_infos.append(export_info)

Expand All @@ -45,7 +64,7 @@ def _maven_project_jar_impl(ctx):
source_jar = src_jar,

# TODO: calculate runtime_deps too
deps = info.deps_java_infos.to_list(),
deps = info.dep_infos.to_list(),
exports = exported_infos,
)

Expand Down
55 changes: 44 additions & 11 deletions private/tools/java/rules/jvm/external/jar/MergeJars.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import rules.jvm.external.ByteStreams;

import java.io.BufferedInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
Expand Down Expand Up @@ -66,6 +67,7 @@ public static void main(String[] args) throws IOException {
Path out = null;
// Insertion order may matter
Set<Path> sources = new LinkedHashSet<>();
Set<Path> excludes = new HashSet<>();
DuplicateEntryStrategy onDuplicate = LAST_IN_WINS;

for (int i = 0; i < args.length; i++) {
Expand All @@ -79,16 +81,16 @@ public static void main(String[] args) throws IOException {
onDuplicate = DuplicateEntryStrategy.fromShortName(args[++i]);
break;

case "--exclude":
excludes.add(isValid(Paths.get(args[++i])));
break;

case "--output":
out = Paths.get(args[++i]);
break;

case "--sources":
Path path = Paths.get(args[++i]);
if (!Files.exists(path) || !Files.isReadable(path)) {
throw new IllegalArgumentException("Source must a readable file: " + path);
}
sources.add(path);
sources.add(isValid(Paths.get(args[++i])));
break;

default:
Expand All @@ -105,6 +107,9 @@ public static void main(String[] args) throws IOException {
return;
}

// Remove any jars from sources that we've been told to exclude
sources.removeIf(excludes::contains);

// To keep things simple, we expand all the inputs jars into a single directory,
// merge the manifests, and then create our own zip.
Path temp = Files.createTempDirectory("mergejars");
Expand All @@ -113,8 +118,8 @@ public static void main(String[] args) throws IOException {
manifest.getMainAttributes().put(Attributes.Name.MANIFEST_VERSION, "1.0");

Map<String, Set<String>> allServices = new TreeMap<>();

Map<Path, SortedMap<Path, Path>> allPaths = new TreeMap<>();
Set<String> excludedPaths = readExcludedFileNames(excludes);

for (Path source : sources) {
try (InputStream fis = Files.newInputStream(source);
Expand All @@ -128,8 +133,8 @@ public static void main(String[] args) throws IOException {
continue;
}

if (entry.isDirectory()) {
skipEntry(zis);
if (entry.isDirectory() ||
(!entry.getName().startsWith("META-INF/") && excludedPaths.contains(entry.getName()))) {
continue;
}

Expand Down Expand Up @@ -240,9 +245,37 @@ public static void main(String[] args) throws IOException {
delete(temp);
}

@SuppressWarnings("CheckReturnValue")
private static void skipEntry(ZipInputStream zis) throws IOException {
ByteStreams.toByteArray(zis);
/**
 * Collects the names of all non-directory entries found in the given zip/jar files.
 *
 * @param excludes the archives whose entry names should be gathered
 * @return the entry names found across every archive in {@code excludes}
 * @throws IOException if an archive cannot be opened or read
 */
private static Set<String> readExcludedFileNames(Set<Path> excludes) throws IOException {
  Set<String> names = new HashSet<>();

  for (Path archive : excludes) {
    try (InputStream raw = Files.newInputStream(archive);
         BufferedInputStream buffered = new BufferedInputStream(raw);
         ZipInputStream zip = new ZipInputStream(buffered)) {
      for (ZipEntry current = zip.getNextEntry();
           current != null;
           current = zip.getNextEntry()) {
        // Only file entries are of interest; directory entries are ignored.
        if (!current.isDirectory()) {
          names.add(current.getName());
        }
      }
    }
  }

  return names;
}

/**
 * Checks that {@code path} exists and is readable, then hands it back unchanged.
 *
 * @param path the file to check
 * @return {@code path} itself, so the call can be used inline
 * @throws IllegalArgumentException if the file does not exist or is not readable
 */
private static Path isValid(Path path) {
  String problem = null;
  if (!Files.exists(path)) {
    problem = "File does not exist: ";
  } else if (!Files.isReadable(path)) {
    problem = "File is not readable: ";
  }

  if (problem != null) {
    throw new IllegalArgumentException(problem + path);
  }
  return path;
}

private static void delete(Path toDelete) throws IOException {
Expand Down
Loading

0 comments on commit 6862295

Please sign in to comment.