From b42fc58537ae93a64b0f38b596f3d07274dadf19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Mendelski?= Date: Sun, 4 Dec 2022 10:41:24 +0100 Subject: [PATCH] Fix context scanner --- build.gradle | 4 +- .../com/coditory/quark/context/ClassPath.java | 367 ++++++++++++++++++ .../quark/context/ClassPathScanner.java | 56 +++ .../quark/context/ClasspathScanner.java | 90 ----- .../quark/context/ContextBuilder.java | 2 +- 5 files changed, 426 insertions(+), 93 deletions(-) create mode 100644 src/main/java/com/coditory/quark/context/ClassPath.java create mode 100644 src/main/java/com/coditory/quark/context/ClassPathScanner.java delete mode 100644 src/main/java/com/coditory/quark/context/ClasspathScanner.java diff --git a/build.gradle b/build.gradle index 2160236..a7fb6d8 100644 --- a/build.gradle +++ b/build.gradle @@ -12,10 +12,10 @@ group = 'com.coditory.quark' description = 'Coditory Quark Context Library' dependencies { - api 'org.slf4j:slf4j-api:2.0.3' + api 'org.slf4j:slf4j-api:2.0.5' api 'org.jetbrains:annotations:23.0.0' api 'com.coditory.quark:quark-eventbus:0.0.5' - testImplementation 'ch.qos.logback:logback-classic:1.4.4' + testImplementation 'ch.qos.logback:logback-classic:1.4.5' testImplementation 'org.spockframework:spock-core:2.3-groovy-4.0' testImplementation 'org.skyscreamer:jsonassert:1.5.1' } diff --git a/src/main/java/com/coditory/quark/context/ClassPath.java b/src/main/java/com/coditory/quark/context/ClassPath.java new file mode 100644 index 0000000..cf6c08a --- /dev/null +++ b/src/main/java/com/coditory/quark/context/ClassPath.java @@ -0,0 +1,367 @@ +package com.coditory.quark.context; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.net.MalformedURLException; +import java.net.URISyntaxException; +import java.net.URL; +import java.net.URLClassLoader; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Enumeration; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.jar.Attributes; +import java.util.jar.JarEntry; +import java.util.jar.JarFile; +import java.util.jar.Manifest; + +import static java.util.Collections.unmodifiableList; +import static java.util.Collections.unmodifiableMap; +import static java.util.Collections.unmodifiableSet; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.toSet; + +final class ClassPath { + private static final Logger logger = LoggerFactory.getLogger(ClassPath.class.getName()); + private static final String CLASS_FILE_NAME_EXTENSION = ".class"; + private static final String PATH_SEPARATOR_SYS_PROP = System.getProperty("path.separator"); + private static final String JAVA_CLASS_PATH_SYS_PROP = System.getProperty("java.class.path"); + + private final Set resources; + + private ClassPath(Set resources) { + this.resources = resources; + } + + public static ClassPath from(ClassLoader classloader) throws IOException { + requireNonNull(classloader); + Set locations = locationsFrom(classloader); + Set scanned = new LinkedHashSet<>(); + for (LocationInfo location : locations) { + scanned.add(location.file()); + } + Set resources = new LinkedHashSet<>(); + for (LocationInfo location : locations) { + resources.addAll(location.scanResources(scanned)); + } + return new ClassPath(resources); + } + + public Set getTopLevelClasses() { + return resources.stream() + .filter(r -> r instanceof ClassInfo) + .map(r -> (ClassInfo) r) + .filter(ClassInfo::isTopLevel) + .collect(toSet()); + } + + public Set getTopLevelClassesRecursive(String packageName) { + requireNonNull(packageName); + String packagePrefix = packageName + '.'; + Set classes = new LinkedHashSet<>(); + for (ClassInfo classInfo : getTopLevelClasses()) { + if (classInfo.getName().startsWith(packagePrefix)) { + classes.add(classInfo); + } + } + return unmodifiableSet(classes); + } + + public static class ResourceInfo { + private final File file; + private final String resourceName; + + final ClassLoader loader; + + static ResourceInfo of(File file, String resourceName, ClassLoader loader) { + return resourceName.endsWith(CLASS_FILE_NAME_EXTENSION) + ? new ClassInfo(file, resourceName, loader) + : new ResourceInfo(file, resourceName, loader); + } + + ResourceInfo(File file, String resourceName, ClassLoader loader) { + this.file = requireNonNull(file); + this.resourceName = requireNonNull(resourceName); + this.loader = requireNonNull(loader); + } + + public File getFile() { + return file; + } + + @Override + public int hashCode() { + return resourceName.hashCode(); + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof ResourceInfo that) { + return resourceName.equals(that.resourceName) + && loader == that.loader; + } + return false; + } + + @Override + public String toString() { + return resourceName; + } + } + + public static final class ClassInfo extends ResourceInfo { + private final String className; + + ClassInfo(File file, String resourceName, ClassLoader loader) { + super(file, resourceName, loader); + this.className = getClassName(resourceName); + } + + public String getName() { + return className; + } + + public boolean isTopLevel() { + return className.indexOf('$') == -1; + } + + @Override + public String toString() { + return className; + } + } + + static Set locationsFrom(ClassLoader classloader) { + Set locations = new LinkedHashSet<>(); + for (Map.Entry entry : getClassPathEntries(classloader).entrySet()) { + locations.add(new LocationInfo(entry.getKey(), entry.getValue())); + } + return unmodifiableSet(locations); + } + + static final class LocationInfo { + final File home; + private final ClassLoader classloader; + + LocationInfo(File home, ClassLoader classloader) { + this.home = requireNonNull(home); + this.classloader = requireNonNull(classloader); + } + + public File file() { + return home; + } + + public Set scanResources(Set scannedFiles) throws IOException { + Set resources = new LinkedHashSet<>(); + scannedFiles.add(home); + scan(home, scannedFiles, resources); + return unmodifiableSet(resources); + } + + private void scan(File file, Set scannedUris, Set result) + throws IOException { + try { + if (!file.exists()) { + return; + } + } catch (SecurityException e) { + logger.warn("Cannot access " + file + ": " + e); + return; + } + if (file.isDirectory()) { + scanDirectory(file, result); + } else { + scanJar(file, scannedUris, result); + } + } + + private void scanJar(File file, Set scannedUris, Set result) throws IOException { + JarFile jarFile; + try { + jarFile = new JarFile(file); + } catch (IOException e) { + // Not a jar file + return; + } + try { + for (File path : getClassPathFromManifest(file, jarFile.getManifest())) { + // We only scan each file once independent of the classloader that file might be + // associated with. + if (scannedUris.add(path.getCanonicalFile())) { + scan(path, scannedUris, result); + } + } + scanJarFile(jarFile, result); + } finally { + try { + jarFile.close(); + } catch (IOException ignored) { // similar to try-with-resources, but don't fail scanning + } + } + } + + private void scanJarFile(JarFile file, Set result) { + Enumeration entries = file.entries(); + while (entries.hasMoreElements()) { + JarEntry entry = entries.nextElement(); + if (entry.isDirectory() || entry.getName().equals(JarFile.MANIFEST_NAME)) { + continue; + } + result.add(ResourceInfo.of(new File(file.getName()), entry.getName(), classloader)); + } + } + + private void scanDirectory(File directory, Set result) + throws IOException { + Set currentPath = new HashSet<>(); + currentPath.add(directory.getCanonicalFile()); + scanDirectory(directory, "", currentPath, result); + } + + private void scanDirectory( + File directory, + String packagePrefix, + Set currentPath, + Set builder + ) throws IOException { + File[] files = directory.listFiles(); + if (files == null) { + logger.warn("Cannot read directory " + directory); + // IO error, just skip the directory + return; + } + for (File f : files) { + String name = f.getName(); + if (f.isDirectory()) { + File deref = f.getCanonicalFile(); + if (currentPath.add(deref)) { + scanDirectory(deref, packagePrefix + name + "/", currentPath, builder); + currentPath.remove(deref); + } + } else { + String resourceName = packagePrefix + name; + if (!resourceName.equals(JarFile.MANIFEST_NAME)) { + builder.add(ResourceInfo.of(f, resourceName, classloader)); + } + } + } + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof LocationInfo that) { + return home.equals(that.home) && classloader.equals(that.classloader); + } + return false; + } + + @Override + public int hashCode() { + return home.hashCode(); + } + + @Override + public String toString() { + return home.toString(); + } + } + + static Set getClassPathFromManifest(File jarFile, Manifest manifest) { + if (manifest == null) { + return Set.of(); + } + Set result = new LinkedHashSet<>(); + String classpathAttribute = manifest + .getMainAttributes() + .getValue(Attributes.Name.CLASS_PATH.toString()); + if (classpathAttribute != null) { + for (String path : classpathAttribute.split(" ")) { + if (path.isBlank()) { + continue; + } + URL url; + try { + url = getClassPathEntry(jarFile, path); + } catch (MalformedURLException e) { + // Ignore bad entry + logger.warn("Invalid Class-Path entry: " + path); + continue; + } + if (url.getProtocol().equals("file")) { + result.add(toFile(url)); + } + } + } + return unmodifiableSet(result); + } + + static Map getClassPathEntries(ClassLoader classloader) { + LinkedHashMap entries = new LinkedHashMap<>(); + // Search parent first, since it's the order ClassLoader#loadClass() uses. + ClassLoader parent = classloader.getParent(); + if (parent != null) { + entries.putAll(getClassPathEntries(parent)); + } + for (URL url : getClassLoaderUrls(classloader)) { + if (url.getProtocol().equals("file")) { + File file = toFile(url); + if (!entries.containsKey(file)) { + entries.put(file, classloader); + } + } + } + return unmodifiableMap(entries); + } + + private static List getClassLoaderUrls(ClassLoader classloader) { + if (classloader instanceof URLClassLoader) { + return Arrays.asList(((URLClassLoader) classloader).getURLs()); + } + if (classloader.equals(ClassLoader.getSystemClassLoader())) { + return parseJavaClassPath(); + } + return List.of(); + } + + private static List parseJavaClassPath() { + List urls = new ArrayList<>(); + for (String entry : JAVA_CLASS_PATH_SYS_PROP.split(PATH_SEPARATOR_SYS_PROP)) { + try { + try { + urls.add(new File(entry).toURI().toURL()); + } catch (SecurityException e) { // File.toURI checks to see if the file is a directory + urls.add(new URL("file", null, new File(entry).getAbsolutePath())); + } + } catch (MalformedURLException e) { + logger.warn("Malformed classpath entry: " + entry, e); + } + } + return unmodifiableList(urls); + } + + private static URL getClassPathEntry(File jarFile, String path) throws MalformedURLException { + return new URL(jarFile.toURI().toURL(), path); + } + + private static String getClassName(String filename) { + int classNameEnd = filename.length() - CLASS_FILE_NAME_EXTENSION.length(); + return filename.substring(0, classNameEnd).replace('/', '.'); + } + + private static File toFile(URL url) { + try { + return new File(url.toURI()); // Accepts escaped characters like %20. + } catch (URISyntaxException e) { // URL.toURI() doesn't escape chars. + return new File(url.getPath()); // Accepts non-escaped chars like space. + } + } +} diff --git a/src/main/java/com/coditory/quark/context/ClassPathScanner.java b/src/main/java/com/coditory/quark/context/ClassPathScanner.java new file mode 100644 index 0000000..969ef68 --- /dev/null +++ b/src/main/java/com/coditory/quark/context/ClassPathScanner.java @@ -0,0 +1,56 @@ +package com.coditory.quark.context; + +import java.io.IOException; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Queue; +import java.util.function.Predicate; + +final class ClassPathScanner implements Iterator> { + static ClassPathScanner scanPackageAndSubPackages(String packageName, Predicate filter, ClassLoader classLoader) { + try { + List classes = getClasses(packageName, filter, classLoader); + return new ClassPathScanner(classes, classLoader); + } catch (IOException e) { + throw new RuntimeException("Could not scan classpath", e); + } + } + + private static List getClasses(String packageName, Predicate filter, ClassLoader classLoader) + throws IOException { + return ClassPath.from(classLoader) + .getTopLevelClassesRecursive(packageName) + .stream() + .map(ClassPath.ClassInfo::getName) + .filter(filter) + .toList(); + } + + private final ClassLoader classLoader; + private final Queue classesToScan; + + ClassPathScanner(List classesToScan, ClassLoader classLoader) { + this.classLoader = classLoader; + this.classesToScan = new LinkedList<>(classesToScan); + } + + @Override + public boolean hasNext() { + return !classesToScan.isEmpty(); + } + + @Override + public Class next() { + String className = classesToScan.poll(); + return loadClass(className); + } + + private Class loadClass(String canonicalName) { + try { + return classLoader.loadClass(canonicalName); + } catch (ClassNotFoundException e) { + throw new RuntimeException("Could not load class: " + canonicalName, e); + } + } +} diff --git a/src/main/java/com/coditory/quark/context/ClasspathScanner.java b/src/main/java/com/coditory/quark/context/ClasspathScanner.java deleted file mode 100644 index 72e79ad..0000000 --- a/src/main/java/com/coditory/quark/context/ClasspathScanner.java +++ /dev/null @@ -1,90 +0,0 @@ -package com.coditory.quark.context; - -import java.io.File; -import java.io.IOException; -import java.net.URL; -import java.util.ArrayList; -import java.util.Enumeration; -import java.util.Iterator; -import java.util.LinkedList; -import java.util.List; -import java.util.Queue; -import java.util.function.Predicate; - -final class ClasspathScanner implements Iterator> { - static ClasspathScanner scanPackageAndSubPackages(String packageName, Predicate filter, ClassLoader classLoader) { - try { - return new ClasspathScanner(getClasses(packageName, filter), classLoader); - } catch (IOException e) { - throw new RuntimeException("Could not scan classpath", e); - } - } - - private static List getClasses(String packageName, Predicate filter) - throws IOException { - ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); - assert classLoader != null; - String path = packageName.replace('.', '/'); - Enumeration resources = classLoader.getResources(path); - List dirs = new ArrayList<>(); - while (resources.hasMoreElements()) { - URL resource = resources.nextElement(); - dirs.add(new File(resource.getFile())); - } - ArrayList classes = new ArrayList<>(); - for (File directory : dirs) { - classes.addAll(findClasses(directory, packageName, filter)); - } - return classes; - } - - private static List findClasses(File directory, String packageName, Predicate filter) { - List classes = new ArrayList<>(); - if (!directory.exists()) { - return classes; - } - File[] files = directory.listFiles(); - if (files == null) { - return classes; - } - for (File file : files) { - if (file.isDirectory()) { - assert !file.getName().contains("."); - classes.addAll(findClasses(file, packageName + "." + file.getName(), filter)); - } else if (file.getName().endsWith(".class")) { - String canonicalName = packageName + '.' + file.getName().substring(0, file.getName().length() - 6); - if (filter.test(canonicalName)) { - classes.add(canonicalName); - } - } - } - return classes; - } - - private final ClassLoader classLoader; - private final Queue classesToScan; - - ClasspathScanner(List classesToScan, ClassLoader classLoader) { - this.classLoader = classLoader; - this.classesToScan = new LinkedList<>(classesToScan); - } - - @Override - public boolean hasNext() { - return !classesToScan.isEmpty(); - } - - @Override - public Class next() { - String className = classesToScan.poll(); - return loadClass(className); - } - - private Class loadClass(String canonicalName) { - try { - return classLoader.loadClass(canonicalName); - } catch (ClassNotFoundException e) { - throw new RuntimeException("Could not load class: " + canonicalName, e); - } - } -} diff --git a/src/main/java/com/coditory/quark/context/ContextBuilder.java b/src/main/java/com/coditory/quark/context/ContextBuilder.java index c66453d..af843d6 100644 --- a/src/main/java/com/coditory/quark/context/ContextBuilder.java +++ b/src/main/java/com/coditory/quark/context/ContextBuilder.java @@ -165,7 +165,7 @@ public ContextBuilder scanPackage(@NotNull String packageName) { public ContextBuilder scanPackage(@NotNull String packageName, @NotNull Predicate canonicalNameFilter) { expectNonNull(packageName, "packageName"); expectNonNull(canonicalNameFilter, "canonicalNameFilter"); - classpathScanners.add(() -> ClasspathScanner.scanPackageAndSubPackages(packageName, canonicalNameFilter, classLoader)); + classpathScanners.add(() -> ClassPathScanner.scanPackageAndSubPackages(packageName, canonicalNameFilter, classLoader)); return this; }