Skip to content

Commit

Permalink
Convert String to Class using custom ClassLoader for @ParameterizedTest
Browse files Browse the repository at this point in the history
The @ParameterizedTest infrastructure has built-in support for
converting from a fully-qualified class name (String) to a Class;
however, if the named class is not visible in JUnit's default
ClassLoader, the conversion will fail.

This commit addresses this issue by refactoring the internals of the
@ParameterizedTest infrastructure so that the ClassLoader of the class
in which the @ParameterizedTest method is declared is used to resolve
classes instead of the default ClassLoader

Closes #3291
  • Loading branch information
sbrannen committed May 9, 2023
1 parent 3571869 commit f6e73ac
Show file tree
Hide file tree
Showing 8 changed files with 173 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ repository on GitHub.

==== Bug Fixes

* ❓
* When converting an argument for a `@ParameterizedTest` method from a fully-qualified
class name (`String`) to a `Class`, the `ClassLoader` of the class in which the
`@ParameterizedTest` method is declared is now used to resolve the `Class` instead of
the _default_ `ClassLoader`.

==== Deprecations and Breaking Changes

Expand Down
1 change: 1 addition & 0 deletions junit-jupiter-params/junit-jupiter-params.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dependencies {
testImplementation(projects.junitJupiterEngine)
testImplementation(projects.junitPlatformLauncher)
testImplementation(projects.junitPlatformSuiteEngine)
testImplementation(testFixtures(projects.junitPlatformCommons))
testImplementation(testFixtures(projects.junitJupiterEngine))

compileOnly(kotlin("stdlib"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ static class Aggregator implements Resolver {

@Override
public Object resolve(ParameterContext parameterContext, Object[] arguments, int invocationIndex) {
ArgumentsAccessor accessor = new DefaultArgumentsAccessor(invocationIndex, arguments);
ArgumentsAccessor accessor = new DefaultArgumentsAccessor(parameterContext, invocationIndex, arguments);
try {
return this.argumentsAggregator.aggregateArguments(accessor, parameterContext);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.util.List;

import org.apiguardian.api.API;
import org.junit.jupiter.api.extension.ParameterContext;
import org.junit.jupiter.params.converter.DefaultArgumentConverter;
import org.junit.platform.commons.util.ClassUtils;
import org.junit.platform.commons.util.Preconditions;
Expand All @@ -35,12 +36,15 @@
@API(status = INTERNAL, since = "5.2")
public class DefaultArgumentsAccessor implements ArgumentsAccessor {

private final ParameterContext parameterContext;
private final int invocationIndex;
private final Object[] arguments;

public DefaultArgumentsAccessor(int invocationIndex, Object... arguments) {
public DefaultArgumentsAccessor(ParameterContext parameterContext, int invocationIndex, Object... arguments) {
Preconditions.notNull(parameterContext, "ParameterContext must not be null");
Preconditions.condition(invocationIndex >= 1, () -> "invocation index must be >= 1");
Preconditions.notNull(arguments, "Arguments array must not be null");
this.parameterContext = parameterContext;
this.invocationIndex = invocationIndex;
this.arguments = arguments;
}
Expand All @@ -57,7 +61,8 @@ public <T> T get(int index, Class<T> requiredType) {
Preconditions.notNull(requiredType, "requiredType must not be null");
Object value = get(index);
try {
Object convertedValue = DefaultArgumentConverter.INSTANCE.convert(value, requiredType);
Object convertedValue = DefaultArgumentConverter.INSTANCE.convert(value, requiredType,
this.parameterContext);
return requiredType.cast(convertedValue);
}
catch (Exception ex) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
import java.util.function.Function;

import org.apiguardian.api.API;
import org.junit.jupiter.api.extension.ParameterContext;
import org.junit.platform.commons.util.ClassLoaderUtils;
import org.junit.platform.commons.util.Preconditions;
import org.junit.platform.commons.util.ReflectionUtils;

Expand All @@ -70,14 +72,15 @@
* @see org.junit.jupiter.params.converter.ArgumentConverter
*/
@API(status = INTERNAL, since = "5.0")
public class DefaultArgumentConverter extends SimpleArgumentConverter {
public class DefaultArgumentConverter implements ArgumentConverter {

public static final DefaultArgumentConverter INSTANCE = new DefaultArgumentConverter();

private static final List<StringToObjectConverter> stringToObjectConverters = unmodifiableList(asList( //
new StringToBooleanConverter(), //
new StringToCharacterConverter(), //
new StringToNumberConverter(), //
new StringToClassConverter(), //
new StringToEnumConverter(), //
new StringToJavaTimeConverter(), //
new StringToCommonJavaTypesConverter(), //
Expand All @@ -89,7 +92,12 @@ private DefaultArgumentConverter() {
}

@Override
public Object convert(Object source, Class<?> targetType) {
public final Object convert(Object source, ParameterContext context) {
Class<?> targetType = context.getParameter().getType();
return convert(source, targetType, context);
}

public final Object convert(Object source, Class<?> targetType, ParameterContext context) {
if (source == null) {
if (targetType.isPrimitive()) {
throw new ArgumentConversionException(
Expand All @@ -102,17 +110,17 @@ public Object convert(Object source, Class<?> targetType) {
return source;
}

return convertToTargetType(source, targetType);
}

private Object convertToTargetType(Object source, Class<?> targetType) {
if (source instanceof String) {
Class<?> targetTypeToUse = toWrapperType(targetType);
Optional<StringToObjectConverter> converter = stringToObjectConverters.stream().filter(
candidate -> candidate.canConvert(targetTypeToUse)).findFirst();
if (converter.isPresent()) {
ClassLoader classLoader = context.getDeclaringExecutable().getDeclaringClass().getClassLoader();
if (classLoader == null) {
classLoader = ClassLoaderUtils.getDefaultClassLoader();
}
try {
return converter.get().convert((String) source, targetTypeToUse);
return converter.get().convert((String) source, targetTypeToUse, classLoader);
}
catch (Exception ex) {
if (ex instanceof ArgumentConversionException) {
Expand Down Expand Up @@ -146,6 +154,10 @@ interface StringToObjectConverter {

Object convert(String source, Class<?> targetType) throws Exception;

default Object convert(String source, Class<?> targetType, ClassLoader classLoader) throws Exception {
return convert(source, targetType);
}

}

private static class StringToBooleanConverter implements StringToObjectConverter {
Expand Down Expand Up @@ -203,6 +215,29 @@ public Object convert(String source, Class<?> targetType) {
}
}

private static class StringToClassConverter implements StringToObjectConverter {

@Override
public boolean canConvert(Class<?> targetType) {
return targetType == Class.class;
}

@Override
public Object convert(String source, Class<?> targetType) throws Exception {
throw new UnsupportedOperationException();
}

@Override
public Object convert(String className, Class<?> targetType, ClassLoader classLoader) throws Exception {
// @formatter:off
return ReflectionUtils.tryToLoadClass(className, classLoader)
.getOrThrow(cause -> new ArgumentConversionException(
"Failed to convert String \"" + className + "\" to type java.lang.Class", cause));
// @formatter:on
}

}

private static class StringToEnumConverter implements StringToObjectConverter {

@Override
Expand Down Expand Up @@ -261,8 +296,6 @@ private static class StringToCommonJavaTypesConverter implements StringToObjectC
static {
Map<Class<?>, Function<String, ?>> converters = new HashMap<>();

// java.lang
converters.put(Class.class, StringToCommonJavaTypesConverter::toClass);
// java.io and java.nio
converters.put(File.class, File::new);
converters.put(Charset.class, Charset::forName);
Expand Down Expand Up @@ -291,14 +324,6 @@ public Object convert(String source, Class<?> targetType) throws Exception {
return CONVERTERS.get(targetType).apply(source);
}

private static Class<?> toClass(String type) {
// @formatter:off
return ReflectionUtils.tryToLoadClass(type)
.getOrThrow(cause -> new ArgumentConversionException(
"Failed to convert String \"" + type + "\" to type java.lang.Class", cause));
// @formatter:on
}

private static URL toURL(String url) {
try {
return URI.create(url).toURL();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,16 @@
import static org.junit.jupiter.api.Assertions.assertIterableEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.lang.reflect.Method;
import java.util.Arrays;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ParameterContext;
import org.junit.platform.commons.PreconditionViolationException;
import org.junit.platform.commons.util.ReflectionUtils;

/**
* Unit tests for {@link DefaultArgumentsAccessor}.
Expand All @@ -31,113 +36,113 @@ class DefaultArgumentsAccessorTests {

@Test
void argumentsMustNotBeNull() {
assertThrows(PreconditionViolationException.class, () -> new DefaultArgumentsAccessor(1, (Object[]) null));
assertThrows(PreconditionViolationException.class, () -> defaultArgumentsAccessor(1, (Object[]) null));
}

@Test
void indexMustNotBeNegative() {
ArgumentsAccessor arguments = new DefaultArgumentsAccessor(1, 1, 2);
ArgumentsAccessor arguments = defaultArgumentsAccessor(1, 1, 2);
Exception exception = assertThrows(PreconditionViolationException.class, () -> arguments.get(-1));
assertThat(exception.getMessage()).containsSubsequence("index must be", ">= 0");
}

@Test
void indexMustBeSmallerThanLength() {
ArgumentsAccessor arguments = new DefaultArgumentsAccessor(1, 1, 2);
ArgumentsAccessor arguments = defaultArgumentsAccessor(1, 1, 2);
Exception exception = assertThrows(PreconditionViolationException.class, () -> arguments.get(2));
assertThat(exception.getMessage()).containsSubsequence("index must be", "< 2");
}

@Test
void getNull() {
assertNull(new DefaultArgumentsAccessor(1, new Object[] { null }).get(0));
assertNull(defaultArgumentsAccessor(1, new Object[] { null }).get(0));
}

@Test
void getWithNullCastToWrapperType() {
assertNull(new DefaultArgumentsAccessor(1, (Object[]) new Integer[] { null }).get(0, Integer.class));
assertNull(defaultArgumentsAccessor(1, (Object[]) new Integer[] { null }).get(0, Integer.class));
}

@Test
void get() {
assertEquals(1, new DefaultArgumentsAccessor(1, 1).get(0));
assertEquals(1, defaultArgumentsAccessor(1, 1).get(0));
}

@Test
void getWithCast() {
assertEquals(Integer.valueOf(1), new DefaultArgumentsAccessor(1, 1).get(0, Integer.class));
assertEquals(Character.valueOf('A'), new DefaultArgumentsAccessor(1, 'A').get(0, Character.class));
assertEquals(Integer.valueOf(1), defaultArgumentsAccessor(1, 1).get(0, Integer.class));
assertEquals(Character.valueOf('A'), defaultArgumentsAccessor(1, 'A').get(0, Character.class));
}

@Test
void getWithCastToPrimitiveType() {
Exception exception = assertThrows(ArgumentAccessException.class,
() -> new DefaultArgumentsAccessor(1, 1).get(0, int.class));
() -> defaultArgumentsAccessor(1, 1).get(0, int.class));
assertThat(exception.getMessage()).isEqualTo(
"Argument at index [0] with value [1] and type [java.lang.Integer] could not be converted or cast to type [int].");

exception = assertThrows(ArgumentAccessException.class,
() -> new DefaultArgumentsAccessor(1, new Object[] { null }).get(0, int.class));
() -> defaultArgumentsAccessor(1, new Object[] { null }).get(0, int.class));
assertThat(exception.getMessage()).isEqualTo(
"Argument at index [0] with value [null] and type [null] could not be converted or cast to type [int].");
}

@Test
void getWithCastToIncompatibleType() {
Exception exception = assertThrows(ArgumentAccessException.class,
() -> new DefaultArgumentsAccessor(1, 1).get(0, Character.class));
() -> defaultArgumentsAccessor(1, 1).get(0, Character.class));
assertThat(exception.getMessage()).isEqualTo(
"Argument at index [0] with value [1] and type [java.lang.Integer] could not be converted or cast to type [java.lang.Character].");
}

@Test
void getCharacter() {
assertEquals(Character.valueOf('A'), new DefaultArgumentsAccessor(1, 'A', 'B').getCharacter(0));
assertEquals(Character.valueOf('A'), defaultArgumentsAccessor(1, 'A', 'B').getCharacter(0));
}

@Test
void getBoolean() {
assertEquals(Boolean.TRUE, new DefaultArgumentsAccessor(1, true, false).getBoolean(0));
assertEquals(Boolean.TRUE, defaultArgumentsAccessor(1, true, false).getBoolean(0));
}

@Test
void getByte() {
assertEquals(Byte.valueOf((byte) 42), new DefaultArgumentsAccessor(1, (byte) 42).getByte(0));
assertEquals(Byte.valueOf((byte) 42), defaultArgumentsAccessor(1, (byte) 42).getByte(0));
}

@Test
void getShort() {
assertEquals(Short.valueOf((short) 42), new DefaultArgumentsAccessor(1, (short) 42).getShort(0));
assertEquals(Short.valueOf((short) 42), defaultArgumentsAccessor(1, (short) 42).getShort(0));
}

@Test
void getInteger() {
assertEquals(Integer.valueOf(42), new DefaultArgumentsAccessor(1, 42).getInteger(0));
assertEquals(Integer.valueOf(42), defaultArgumentsAccessor(1, 42).getInteger(0));
}

@Test
void getLong() {
assertEquals(Long.valueOf(42L), new DefaultArgumentsAccessor(1, 42L).getLong(0));
assertEquals(Long.valueOf(42L), defaultArgumentsAccessor(1, 42L).getLong(0));
}

@Test
void getFloat() {
assertEquals(Float.valueOf(42.0f), new DefaultArgumentsAccessor(1, 42.0f).getFloat(0));
assertEquals(Float.valueOf(42.0f), defaultArgumentsAccessor(1, 42.0f).getFloat(0));
}

@Test
void getDouble() {
assertEquals(Double.valueOf(42.0), new DefaultArgumentsAccessor(1, 42.0).getDouble(0));
assertEquals(Double.valueOf(42.0), defaultArgumentsAccessor(1, 42.0).getDouble(0));
}

@Test
void getString() {
assertEquals("foo", new DefaultArgumentsAccessor(1, "foo", "bar").getString(0));
assertEquals("foo", defaultArgumentsAccessor(1, "foo", "bar").getString(0));
}

@Test
void toArray() {
var arguments = new DefaultArgumentsAccessor(1, "foo", "bar");
var arguments = defaultArgumentsAccessor(1, "foo", "bar");
var copy = arguments.toArray();
assertArrayEquals(new String[] { "foo", "bar" }, copy);

Expand All @@ -148,7 +153,7 @@ void toArray() {

@Test
void toList() {
var arguments = new DefaultArgumentsAccessor(1, "foo", "bar");
var arguments = defaultArgumentsAccessor(1, "foo", "bar");
var copy = arguments.toList();
assertIterableEquals(Arrays.asList("foo", "bar"), copy);

Expand All @@ -158,9 +163,24 @@ void toList() {

@Test
void size() {
assertEquals(0, new DefaultArgumentsAccessor(1).size());
assertEquals(1, new DefaultArgumentsAccessor(1, 42).size());
assertEquals(5, new DefaultArgumentsAccessor(1, 'a', 'b', 'c', 'd', 'e').size());
assertEquals(0, defaultArgumentsAccessor(1).size());
assertEquals(1, defaultArgumentsAccessor(1, 42).size());
assertEquals(5, defaultArgumentsAccessor(1, 'a', 'b', 'c', 'd', 'e').size());
}

private static DefaultArgumentsAccessor defaultArgumentsAccessor(int invocationIndex, Object... arguments) {
return new DefaultArgumentsAccessor(parameterContext(), invocationIndex, arguments);
}

private static ParameterContext parameterContext() {
Method declaringExecutable = ReflectionUtils.findMethod(DefaultArgumentsAccessorTests.class, "foo").get();
ParameterContext parameterContext = mock();
when(parameterContext.getDeclaringExecutable()).thenReturn(declaringExecutable);
return parameterContext;
}

@SuppressWarnings("unused")
private static void foo() {
}

}
Loading

0 comments on commit f6e73ac

Please sign in to comment.