diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/xml/ResourceEntityResolver.java b/spring-beans/src/main/java/org/springframework/beans/factory/xml/ResourceEntityResolver.java index 81abce016538..1b348693c9b7 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/xml/ResourceEntityResolver.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/xml/ResourceEntityResolver.java @@ -120,18 +120,20 @@ else if (systemId.endsWith(DTD_SUFFIX) || systemId.endsWith(XSD_SUFFIX)) { /** * A fallback method for {@link #resolveEntity(String, String)} that is used when a * "schema" entity (DTD or XSD) cannot be resolved as a local resource. The default - * behavior is to perform a remote resolution over HTTPS. + * behavior is to perform remote resolution over HTTPS. *

Subclasses can override this method to change the default behavior. *

* @param publicId the public identifier of the external entity being referenced, * or null if none was supplied - * @param systemId the system identifier of the external entity being referenced + * @param systemId the system identifier of the external entity being referenced, + * representing the URL of the DTD or XSD * @return an InputSource object describing the new input source, or null to request - * that the parser open a regular URI connection to the system identifier. + * that the parser open a regular URI connection to the system identifier + * @since 6.0.4 */ @Nullable protected InputSource resolveSchemaEntity(@Nullable String publicId, String systemId) { diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/xml/ResourceEntityResolverTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/xml/ResourceEntityResolverTests.java index 2b1b0b326057..b6904fce9f4e 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/xml/ResourceEntityResolverTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/xml/ResourceEntityResolverTests.java @@ -16,88 +16,68 @@ package org.springframework.beans.factory.xml; -import java.io.IOException; - -import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.mockito.Mockito; import org.xml.sax.InputSource; -import org.xml.sax.SAXException; import org.springframework.core.io.Resource; import org.springframework.core.io.ResourceLoader; import org.springframework.lang.Nullable; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatIllegalStateException; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; /** + * Unit tests for ResourceEntityResolver. + * * @author Simon Baslé + * @author Sam Brannen + * @since 6.0.4 */ class ResourceEntityResolverTests { - @Test - void resolveEntityCallsFallbackWithNullOnDtd() throws IOException, SAXException { - ResourceEntityResolver resolver = new FallingBackEntityResolver(false, null); - - assertThat(resolver.resolveEntity("testPublicId", "https://example.org/exampleschema.dtd")) - .isNull(); - } - - @Test - void resolveEntityCallsFallbackWithNullOnXsd() throws IOException, SAXException { - ResourceEntityResolver resolver = new FallingBackEntityResolver(false, null); + @ParameterizedTest + @ValueSource(strings = { "https://example.org/schema/", "https://example.org/schema.xml" }) + void resolveEntityDoesNotCallFallbackIfNotSchema(String systemId) throws Exception { + ConfigurableFallbackEntityResolver resolver = new ConfigurableFallbackEntityResolver(true); - assertThat(resolver.resolveEntity("testPublicId", "https://example.org/exampleschema.xsd")) - .isNull(); + assertThat(resolver.resolveEntity("testPublicId", systemId)).isNull(); + assertThat(resolver.fallbackInvoked).isFalse(); } - @Test - void resolveEntityCallsFallbackWithThrowOnDtd() { - ResourceEntityResolver resolver = new FallingBackEntityResolver(true, null); + @ParameterizedTest + @ValueSource(strings = { "https://example.org/schema.dtd", "https://example.org/schema.xsd" }) + void resolveEntityCallsFallbackThatReturnsNull(String systemId) throws Exception { + ConfigurableFallbackEntityResolver resolver = new ConfigurableFallbackEntityResolver(null); - assertThatIllegalStateException().isThrownBy( - () -> resolver.resolveEntity("testPublicId", "https://example.org/exampleschema.dtd")) - .withMessage("FallingBackEntityResolver that throws"); + assertThat(resolver.resolveEntity("testPublicId", systemId)).isNull(); + assertThat(resolver.fallbackInvoked).isTrue(); } - @Test - void resolveEntityCallsFallbackWithThrowOnXsd() { - ResourceEntityResolver resolver = new FallingBackEntityResolver(true, null); + @ParameterizedTest + @ValueSource(strings = { "https://example.org/schema.dtd", "https://example.org/schema.xsd" }) + void resolveEntityCallsFallbackThatThrowsException(String systemId) { + ConfigurableFallbackEntityResolver resolver = new ConfigurableFallbackEntityResolver(true); - assertThatIllegalStateException().isThrownBy( - () -> resolver.resolveEntity("testPublicId", "https://example.org/exampleschema.xsd")) - .withMessage("FallingBackEntityResolver that throws"); + assertThatExceptionOfType(ResolutionRejectedException.class) + .isThrownBy(() -> resolver.resolveEntity("testPublicId", systemId)); + assertThat(resolver.fallbackInvoked).isTrue(); } - @Test - void resolveEntityCallsFallbackWithInputSourceOnDtd() throws IOException, SAXException { + @ParameterizedTest + @ValueSource(strings = { "https://example.org/schema.dtd", "https://example.org/schema.xsd" }) + void resolveEntityCallsFallbackThatReturnsInputSource(String systemId) throws Exception { InputSource expected = Mockito.mock(InputSource.class); - ResourceEntityResolver resolver = new FallingBackEntityResolver(false, expected); + ConfigurableFallbackEntityResolver resolver = new ConfigurableFallbackEntityResolver(expected); - assertThat(resolver.resolveEntity("testPublicId", "https://example.org/exampleschema.dtd")) - .isNotNull() - .isSameAs(expected); + assertThat(resolver.resolveEntity("testPublicId", systemId)).isSameAs(expected); + assertThat(resolver.fallbackInvoked).isTrue(); } - @Test - void resolveEntityCallsFallbackWithInputSourceOnXsd() throws IOException, SAXException { - InputSource expected = Mockito.mock(InputSource.class); - ResourceEntityResolver resolver = new FallingBackEntityResolver(false, expected); - - assertThat(resolver.resolveEntity("testPublicId", "https://example.org/exampleschema.xsd")) - .isNotNull() - .isSameAs(expected); - } - - @Test - void resolveEntityDoesntCallFallbackIfNotSchema() throws IOException, SAXException { - ResourceEntityResolver resolver = new FallingBackEntityResolver(true, null); - - assertThat(resolver.resolveEntity("testPublicId", "https://example.org/example.xml")) - .isNull(); - } private static final class NoOpResourceLoader implements ResourceLoader { + @Override public Resource getResource(String location) { return null; @@ -109,23 +89,40 @@ public ClassLoader getClassLoader() { } } - private static class FallingBackEntityResolver extends ResourceEntityResolver { + private static class ConfigurableFallbackEntityResolver extends ResourceEntityResolver { private final boolean shouldThrow; + @Nullable private final InputSource returnValue; - private FallingBackEntityResolver(boolean shouldThrow, @Nullable InputSource returnValue) { + boolean fallbackInvoked = false; + + + private ConfigurableFallbackEntityResolver(boolean shouldThrow) { super(new NoOpResourceLoader()); this.shouldThrow = shouldThrow; + this.returnValue = null; + } + + private ConfigurableFallbackEntityResolver(@Nullable InputSource returnValue) { + super(new NoOpResourceLoader()); + this.shouldThrow = false; this.returnValue = returnValue; } + @Nullable @Override protected InputSource resolveSchemaEntity(String publicId, String systemId) { - if (shouldThrow) throw new IllegalStateException("FallingBackEntityResolver that throws"); + this.fallbackInvoked = true; + if (this.shouldThrow) { + throw new ResolutionRejectedException(); + } return this.returnValue; } } + + static class ResolutionRejectedException extends RuntimeException {} + }