Skip to content

Commit

Permalink
Simplify TestClassLoader
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrannen committed May 8, 2023
1 parent f16982e commit ce8dcbb
Showing 1 changed file with 11 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
package org.junit.platform.commons.test;

import java.lang.StackWalker.Option;
import java.lang.StackWalker.StackFrame;
import java.net.URL;
import java.net.URLClassLoader;
import java.security.CodeSource;
Expand All @@ -32,22 +31,20 @@
*/
public class TestClassLoader extends URLClassLoader {

private static final StackWalker stackWalker = StackWalker.getInstance(Option.RETAIN_CLASS_REFERENCE);

static {
ClassLoader.registerAsParallelCapable();
}

private static final Predicate<Class<?>> notTestClassLoader = clazz -> !clazz.equals(TestClassLoader.class);

private final Predicate<String> classNameFilter;

/**
* Create a {@link TestClassLoader} that filters the provided classes.
*
* @see #forClassNamePrefix(String)
*/
public static TestClassLoader forClasses(Class<?>... classes) {
Predicate<String> classNameFilter = name -> Arrays.stream(classes).map(Class::getName).anyMatch(name::equals);
return new TestClassLoader(classNameFilter);
return new TestClassLoader(getCodeSourceUrl(stackWalker.getCallerClass()), classNameFilter);
}

/**
Expand All @@ -57,11 +54,13 @@ public static TestClassLoader forClasses(Class<?>... classes) {
* @see #forClasses(Class...)
*/
public static TestClassLoader forClassNamePrefix(String prefix) {
return new TestClassLoader(name -> name.startsWith(prefix));
return new TestClassLoader(getCodeSourceUrl(stackWalker.getCallerClass()), name -> name.startsWith(prefix));
}

public TestClassLoader(Predicate<String> classNameFilter) {
super(new URL[] { getCodeSourceUrl() }, ClassLoaderUtils.getDefaultClassLoader());
private final Predicate<String> classNameFilter;

private TestClassLoader(URL codeSourceUrl, Predicate<String> classNameFilter) {
super(new URL[] { codeSourceUrl }, ClassLoaderUtils.getDefaultClassLoader());

this.classNameFilter = classNameFilter;
}
Expand All @@ -78,22 +77,10 @@ public Class<?> loadClass(String name) throws ClassNotFoundException {
}

/**
* Get the {@link CodeSource} {@link URL} of the class that instantiated the
* {@code TestClassLoader}.
* Get the {@link CodeSource} {@link URL} of the supplied class.
*/
private static URL getCodeSourceUrl() {
StackWalker walker = StackWalker.getInstance(Option.RETAIN_CLASS_REFERENCE);

// @formatter:off
Class<?> callerClass = walker.walk(stream -> stream
.map(StackFrame::getDeclaringClass)
.filter(notTestClassLoader)
.findFirst()
.get()
);
// @formatter:on

return callerClass.getProtectionDomain().getCodeSource().getLocation();
private static URL getCodeSourceUrl(Class<?> clazz) {
return clazz.getProtectionDomain().getCodeSource().getLocation();
}

}

0 comments on commit ce8dcbb

Please sign in to comment.