Skip to content
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.gradle.internal.dependencies.patches;

import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.ClassWriter;

import java.util.Arrays;
import java.util.HexFormat;
import java.util.function.Function;

public final class PatcherInfo {
private final String jarEntryName;
private final byte[] classSha256;
private final Function<ClassWriter, ClassVisitor> visitorFactory;

private PatcherInfo(String jarEntryName, byte[] classSha256, Function<ClassWriter, ClassVisitor> visitorFactory) {
this.jarEntryName = jarEntryName;
this.classSha256 = classSha256;
this.visitorFactory = visitorFactory;
}

/**
* Creates a patcher info entry, linking a jar entry path name and its SHA256 digest to a patcher factory (a factory to create an ASM
* visitor)
*
* @param jarEntryName the jar entry path, as a string
* @param classSha256 the SHA256 digest of the class bytes, as a HEX string
* @param visitorFactory the factory to create an ASM visitor from a ASM writer
*/
public static PatcherInfo classPatcher(String jarEntryName, String classSha256, Function<ClassWriter, ClassVisitor> visitorFactory) {
return new PatcherInfo(jarEntryName, HexFormat.of().parseHex(classSha256), visitorFactory);
}

boolean matches(byte[] otherClassSha256) {
return Arrays.equals(this.classSha256, otherClassSha256);
}

public String jarEntryName() {
return jarEntryName;
}

public byte[] classSha256() {
return classSha256;
}

public ClassVisitor createVisitor(ClassWriter classWriter) {
return visitorFactory.apply(classWriter);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.gradle.internal.dependencies.patches;

import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassWriter;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Enumeration;
import java.util.HexFormat;
import java.util.Locale;
import java.util.function.Function;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.jar.JarOutputStream;
import java.util.stream.Collectors;

import static org.objectweb.asm.ClassWriter.COMPUTE_FRAMES;
import static org.objectweb.asm.ClassWriter.COMPUTE_MAXS;

public class Utils {

private static final MessageDigest SHA_256;

static {
try {
SHA_256 = MessageDigest.getInstance("SHA-256");
} catch (NoSuchAlgorithmException e) {
throw new RuntimeException(e);
}
}

private record MismatchInfo(String jarEntryName, String expectedClassSha256, String foundClassSha256) {
@Override
public String toString() {
return "[class='"
+ jarEntryName
+ '\''
+ ", expected='"
+ expectedClassSha256
+ '\''
+ ", found='"
+ foundClassSha256
+ '\''
+ ']';
}
}

/**
* Patches the classes in the input JAR file, using the collection of patchers. Each patcher specifies a target class (its jar entry
* name) and the SHA256 digest on the class bytes.
* This digest is checked against the class bytes in the JAR, and if it does not match, an IllegalArgumentException is thrown.
* If the input file does not contain all the classes to patch specified in the patcher info collection, an IllegalArgumentException
* is also thrown.
* @param inputFile the JAR file to patch
* @param outputFile the output (patched) JAR file
* @param patchers list of patcher info (classes to patch (jar entry name + optional SHA256 digest) and ASM visitor to transform them)
*/
public static void patchJar(File inputFile, File outputFile, Collection<PatcherInfo> patchers) {
var classPatchers = patchers.stream().collect(Collectors.toMap(PatcherInfo::jarEntryName, Function.identity()));
var mismatchedClasses = new ArrayList<MismatchInfo>();
try (JarFile jarFile = new JarFile(inputFile); JarOutputStream jos = new JarOutputStream(new FileOutputStream(outputFile))) {
Enumeration<JarEntry> entries = jarFile.entries();
while (entries.hasMoreElements()) {
JarEntry entry = entries.nextElement();
String entryName = entry.getName();
// Add the entry to the new JAR file
jos.putNextEntry(new JarEntry(entryName));

var classPatcher = classPatchers.remove(entryName);
if (classPatcher != null) {
byte[] classToPatch = jarFile.getInputStream(entry).readAllBytes();
var classSha256 = SHA_256.digest(classToPatch);

if (classPatcher.matches(classSha256)) {
ClassReader classReader = new ClassReader(classToPatch);
ClassWriter classWriter = new ClassWriter(classReader, COMPUTE_MAXS | COMPUTE_FRAMES);
classReader.accept(classPatcher.createVisitor(classWriter), 0);
jos.write(classWriter.toByteArray());
} else {
mismatchedClasses.add(
new MismatchInfo(
classPatcher.jarEntryName(),
HexFormat.of().formatHex(classPatcher.classSha256()),
HexFormat.of().formatHex(classSha256)
)
);
}
} else {
// Read the entry's data and write it to the new JAR
try (InputStream is = jarFile.getInputStream(entry)) {
is.transferTo(jos);
}
}
jos.closeEntry();
}
} catch (IOException ex) {
throw new RuntimeException(ex);
}

if (mismatchedClasses.isEmpty() == false) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"""
Error patching JAR [%s]: SHA256 digest mismatch (%s). This JAR was updated to a version that contains different \
classes, for which this patcher was not designed. Please check if the patcher still \
applies correctly, and update the SHA256 digest(s).""",
inputFile.getName(),
mismatchedClasses.stream().map(MismatchInfo::toString).collect(Collectors.joining())
)
);
}

if (classPatchers.isEmpty() == false) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"error patching [%s]: the jar does not contain [%s]",
inputFile.getName(),
String.join(", ", classPatchers.keySet())
)
);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

package org.elasticsearch.gradle.internal.dependencies.patches.hdfs;

import org.elasticsearch.gradle.internal.dependencies.patches.PatcherInfo;
import org.elasticsearch.gradle.internal.dependencies.patches.Utils;
import org.gradle.api.artifacts.transform.CacheableTransform;
import org.gradle.api.artifacts.transform.InputArtifact;
import org.gradle.api.artifacts.transform.TransformAction;
Expand All @@ -20,52 +22,85 @@
import org.gradle.api.tasks.Input;
import org.gradle.api.tasks.Optional;
import org.jetbrains.annotations.NotNull;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.ClassWriter;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.Function;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.jar.JarOutputStream;
import java.util.regex.Pattern;

import static java.util.Map.entry;
import static org.objectweb.asm.ClassWriter.COMPUTE_FRAMES;
import static org.objectweb.asm.ClassWriter.COMPUTE_MAXS;
import static org.elasticsearch.gradle.internal.dependencies.patches.PatcherInfo.classPatcher;

@CacheableTransform
public abstract class HdfsClassPatcher implements TransformAction<HdfsClassPatcher.Parameters> {

record JarPatchers(String artifactTag, Pattern artifactPattern, Map<String, Function<ClassWriter, ClassVisitor>> jarPatchers) {}
record JarPatchers(String artifactTag, Pattern artifactPattern, List<PatcherInfo> jarPatchers) {}

static final List<JarPatchers> allPatchers = List.of(
new JarPatchers(
"hadoop-common",
Pattern.compile("hadoop-common-(?!.*tests)"),
Map.ofEntries(
entry("org/apache/hadoop/util/ShutdownHookManager.class", ShutdownHookManagerPatcher::new),
entry("org/apache/hadoop/util/Shell.class", ShellPatcher::new),
entry("org/apache/hadoop/security/UserGroupInformation.class", SubjectGetSubjectPatcher::new)
"hadoop2-common",
Pattern.compile("hadoop-common-2(?!.*tests)"),
List.of(
classPatcher(
"org/apache/hadoop/util/ShutdownHookManager.class",
"3912451f02da9199dae7dba3f1420e0d951067addabbb235e7551de52234a0ef",
ShutdownHookManagerPatcher::new
),
classPatcher(
"org/apache/hadoop/util/Shell.class",
"60400dc800e7c3e1a5fc499793033d877f5319bbd7633fee05d5a1d96b947bbd",
ShellPatcher::new
),
classPatcher(
"org/apache/hadoop/security/UserGroupInformation.class",
"218078b8c77838f93d015c843775985a71f3c7a8128e2a9394410f0cd1da5f53",
SubjectGetSubjectPatcher::new
)
)
),
new JarPatchers(
"hadoop3-common",
Pattern.compile("hadoop-common-3(?!.*tests)"),
List.of(
classPatcher(
"org/apache/hadoop/util/ShutdownHookManager.class",
"7720e8545a02de6fd03f4170f0e471d1301ef73d7d6a09097bad361f9e31f819",
ShutdownHookManagerPatcher::new
),
classPatcher(
"org/apache/hadoop/util/Shell.class",
"856d0b829cf550df826387af15fa1c772bc7d26d6461535b17b9d5114d308dc4",
ShellPatcher::new
),
classPatcher(
"org/apache/hadoop/security/UserGroupInformation.class",
"52f5973f35a282908d48a573a03c04f240a22c9f6007d7c5e7852aff1c641420",
SubjectGetSubjectPatcher::new
)
)
),
new JarPatchers(
"hadoop-client-api",
Pattern.compile("hadoop-client-api.*"),
Map.ofEntries(
entry("org/apache/hadoop/util/ShutdownHookManager.class", ShutdownHookManagerPatcher::new),
entry("org/apache/hadoop/util/Shell.class", ShellPatcher::new),
entry("org/apache/hadoop/security/UserGroupInformation.class", SubjectGetSubjectPatcher::new),
entry("org/apache/hadoop/security/authentication/client/KerberosAuthenticator.class", SubjectGetSubjectPatcher::new)
List.of(
classPatcher(
"org/apache/hadoop/util/ShutdownHookManager.class",
"90641e0726fc9372479728ef9b7ae2be20fb7ab4cddd4938e55ffecadddd4d94",
ShutdownHookManagerPatcher::new
),
classPatcher(
"org/apache/hadoop/util/Shell.class",
"8837c7f3eeda3f658fc3d6595f18e77a4558220ff0becdf3e175fa4397a6fd0c",
ShellPatcher::new
),
classPatcher(
"org/apache/hadoop/security/UserGroupInformation.class",
"3c34bbc2716a6c8f4e356e78550599b0a4f01882712b4f7787d032fb10527212",
SubjectGetSubjectPatcher::new
),
classPatcher(
"org/apache/hadoop/security/authentication/client/KerberosAuthenticator.class",
"6bab26c1032a38621c20050ec92067226d1d67972d0d370e412ca25f1df96b76",
SubjectGetSubjectPatcher::new
)
)
)
);
Expand Down Expand Up @@ -95,55 +130,9 @@ public void transform(@NotNull TransformOutputs outputs) {
} else {
patchersToApply.forEach(patchers -> {
System.out.println("Patching " + inputFile.getName());

Map<String, Function<ClassWriter, ClassVisitor>> jarPatchers = new HashMap<>(patchers.jarPatchers());
File outputFile = outputs.file(inputFile.getName().replace(".jar", "-patched.jar"));

patchJar(inputFile, outputFile, jarPatchers);

if (jarPatchers.isEmpty() == false) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"error patching [%s] with [%s]: the jar does not contain [%s]",
inputFile.getName(),
patchers.artifactPattern().toString(),
String.join(", ", jarPatchers.keySet())
)
);
}
Utils.patchJar(inputFile, outputFile, patchers.jarPatchers());
});
}
}

private static void patchJar(File inputFile, File outputFile, Map<String, Function<ClassWriter, ClassVisitor>> jarPatchers) {
try (JarFile jarFile = new JarFile(inputFile); JarOutputStream jos = new JarOutputStream(new FileOutputStream(outputFile))) {
Enumeration<JarEntry> entries = jarFile.entries();
while (entries.hasMoreElements()) {
JarEntry entry = entries.nextElement();
String entryName = entry.getName();
// Add the entry to the new JAR file
jos.putNextEntry(new JarEntry(entryName));

Function<ClassWriter, ClassVisitor> classPatcher = jarPatchers.remove(entryName);
if (classPatcher != null) {
byte[] classToPatch = jarFile.getInputStream(entry).readAllBytes();

ClassReader classReader = new ClassReader(classToPatch);
ClassWriter classWriter = new ClassWriter(classReader, COMPUTE_FRAMES | COMPUTE_MAXS);
classReader.accept(classPatcher.apply(classWriter), 0);

jos.write(classWriter.toByteArray());
} else {
// Read the entry's data and write it to the new JAR
try (InputStream is = jarFile.getInputStream(entry)) {
is.transferTo(jos);
}
}
jos.closeEntry();
}
} catch (IOException ex) {
throw new RuntimeException(ex);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class SubjectGetSubjectPatcher extends ClassVisitor {

@Override
public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
return new ReplaceCallMethodVisitor(super.visitMethod(access, name, descriptor, signature, exceptions), name, access, descriptor);
return new ReplaceCallMethodVisitor(super.visitMethod(access, name, descriptor, signature, exceptions));
}

/**
Expand All @@ -35,7 +35,7 @@ private static class ReplaceCallMethodVisitor extends MethodVisitor {
private static final String SUBJECT_CLASS_INTERNAL_NAME = "javax/security/auth/Subject";
private static final String METHOD_NAME = "getSubject";

ReplaceCallMethodVisitor(MethodVisitor methodVisitor, String name, int access, String descriptor) {
ReplaceCallMethodVisitor(MethodVisitor methodVisitor) {
super(ASM9, methodVisitor);
}

Expand Down
Loading