From 902f198c636d03f4711d12b929905eb7afb09202 Mon Sep 17 00:00:00 2001 From: Mark Dixon <1756429+mnd999@users.noreply.github.com> Date: Fri, 24 Nov 2023 17:45:43 +0000 Subject: [PATCH] [NOID] More fixes for issues related to URLAccessChecker * Geocode does these checks as well, and that was missing. * UtilIT rewritten again to mock URLAccessChecker so as to not rely on coredb internals --- common/src/main/java/apoc/ApocConfig.java | 6 -- common/src/main/java/apoc/util/JsonUtil.java | 4 +- core/src/main/java/apoc/spatial/Geocode.java | 37 ++++--- .../test/java/apoc/spatial/SpatialTest.java | 2 +- it/src/test/java/apoc/it/common/UtilIT.java | 100 ++++++++++-------- 5 files changed, 77 insertions(+), 72 deletions(-) diff --git a/common/src/main/java/apoc/ApocConfig.java b/common/src/main/java/apoc/ApocConfig.java index 0630b01e8..7b1cec4cf 100644 --- a/common/src/main/java/apoc/ApocConfig.java +++ b/common/src/main/java/apoc/ApocConfig.java @@ -113,8 +113,6 @@ public class ApocConfig extends LifecycleAdapter { private static ApocConfig theInstance; private GraphDatabaseService systemDb; - private List blockedIpRanges = List.of(); - private boolean expandCommands; private Duration commandEvaluationTimeout; @@ -125,7 +123,6 @@ public ApocConfig( GlobalProcedures globalProceduresRegistry, DatabaseManagementService databaseManagementService) { this.neo4jConfig = neo4jConfig; - this.blockedIpRanges = neo4jConfig.get(GraphDatabaseInternalSettings.cypher_ip_blocklist); this.commandEvaluationTimeout = neo4jConfig.get(GraphDatabaseInternalSettings.config_command_evaluation_timeout); if (this.commandEvaluationTimeout == null) { @@ -145,9 +142,6 @@ public ApocConfig( // use only for unit tests public ApocConfig(Config neo4jConfig) { this.neo4jConfig = neo4jConfig; - if (neo4jConfig != null) { - this.blockedIpRanges = neo4jConfig.get(GraphDatabaseInternalSettings.cypher_ip_blocklist); - } this.log = NullLog.getInstance(); this.databaseManagementService = null; theInstance = this; diff --git a/common/src/main/java/apoc/util/JsonUtil.java b/common/src/main/java/apoc/util/JsonUtil.java index e45ffb93c..d3c7b21bc 100755 --- a/common/src/main/java/apoc/util/JsonUtil.java +++ b/common/src/main/java/apoc/util/JsonUtil.java @@ -132,8 +132,8 @@ public static Stream loadJson( } } - public static Stream loadJson(String url) { - return loadJson(url, null, null, "", true, null, null); + public static Stream loadJson(String url, URLAccessChecker urlAccessChecker) { + return loadJson(url, null, null, "", true, null, urlAccessChecker); } public static T parse(String json, String path, Class type) { diff --git a/core/src/main/java/apoc/spatial/Geocode.java b/core/src/main/java/apoc/spatial/Geocode.java index d1c0c4392..a91c0ea6d 100755 --- a/core/src/main/java/apoc/spatial/Geocode.java +++ b/core/src/main/java/apoc/spatial/Geocode.java @@ -32,6 +32,8 @@ import java.util.Map; import java.util.stream.Stream; import org.apache.commons.configuration2.Configuration; + +import org.neo4j.graphdb.security.URLAccessChecker; import org.neo4j.procedure.*; public class Geocode { @@ -42,10 +44,13 @@ public class Geocode { @Context public TerminationGuard terminationGuard; + @Context + public URLAccessChecker urlAccessChecker; + interface GeocodeSupplier { - Stream geocode(String params, long maxResults); + Stream geocode(String params, long maxResults, URLAccessChecker urlAccessChecker); - Stream reverseGeocode(Double latitude, Double longitude); + Stream reverseGeocode(Double latitude, Double longitude, URLAccessChecker urlAccessChecker); } private static class Throttler { @@ -123,13 +128,13 @@ public SupplierWithKey(Configuration config, TerminationGuard terminationGuard, } @SuppressWarnings("unchecked") - public Stream geocode(String address, long maxResults) { + public Stream geocode(String address, long maxResults, URLAccessChecker urlAccessChecker) { if (address.isEmpty()) { return Stream.empty(); } throttler.waitForThrottle(); String url = urlTemplate.replace("PLACE", Util.encodeUrlComponent(address)); - Object value = JsonUtil.loadJson(url).findFirst().orElse(null); + Object value = JsonUtil.loadJson(url, urlAccessChecker).findFirst().orElse(null); if (value instanceof List) { return findResults((List>) value, maxResults); } else if (value instanceof Map) { @@ -142,13 +147,13 @@ public Stream geocode(String address, long maxResults) { } @Override - public Stream reverseGeocode(Double latitude, Double longitude) { + public Stream reverseGeocode(Double latitude, Double longitude, URLAccessChecker urlAccessChecker) { if (latitude == null || longitude == null) { return Stream.empty(); } throttler.waitForThrottle(); String url = urlTemplateReverse.replace("LAT", latitude.toString()).replace("LNG", longitude.toString()); - Object value = JsonUtil.loadJson(url).findFirst().orElse(null); + Object value = JsonUtil.loadJson(url, urlAccessChecker).findFirst().orElse(null); if (value instanceof List) { return findResults((List>) value, 1); } else if (value instanceof Map) { @@ -203,12 +208,12 @@ public OSMSupplier(Configuration config, TerminationGuard terminationGuard) { } @SuppressWarnings("unchecked") - public Stream geocode(String address, long maxResults) { + public Stream geocode(String address, long maxResults, URLAccessChecker urlAccessChecker) { if (address.isEmpty()) { return Stream.empty(); } throttler.waitForThrottle(); - Object value = JsonUtil.loadJson(OSM_URL_GEOCODE + Util.encodeUrlComponent(address)) + Object value = JsonUtil.loadJson(OSM_URL_GEOCODE + Util.encodeUrlComponent(address), urlAccessChecker) .findFirst() .orElse(null); if (value instanceof List) { @@ -225,14 +230,14 @@ public Stream geocode(String address, long maxResults) { } @Override - public Stream reverseGeocode(Double latitude, Double longitude) { + public Stream reverseGeocode(Double latitude, Double longitude, URLAccessChecker urlAccessChecker) { if (latitude == null || longitude == null) { return Stream.empty(); } throttler.waitForThrottle(); Object value = JsonUtil.loadJson( - OSM_URL_REVERSE_GEOCODE + String.format("lat=%s&lon=%s", latitude, longitude)) + OSM_URL_REVERSE_GEOCODE + String.format("lat=%s&lon=%s", latitude, longitude), urlAccessChecker) .findFirst() .orElse(null); if (value instanceof Map) { @@ -275,13 +280,13 @@ private String credentials(Configuration config) { } @SuppressWarnings("unchecked") - public Stream geocode(String address, long maxResults) { + public Stream geocode(String address, long maxResults, URLAccessChecker urlAccessChecker) { if (address.isEmpty()) { return Stream.empty(); } throttler.waitForThrottle(); Object value = JsonUtil.loadJson( - String.format(GEOCODE_URL, credentials(this.config)) + Util.encodeUrlComponent(address)) + String.format(GEOCODE_URL, credentials(this.config)) + Util.encodeUrlComponent(address), urlAccessChecker) .findFirst() .orElse(null); if (value instanceof Map) { @@ -306,13 +311,13 @@ public Stream geocode(String address, long maxResults) { } @Override - public Stream reverseGeocode(Double latitude, Double longitude) { + public Stream reverseGeocode(Double latitude, Double longitude, URLAccessChecker urlAccessChecker) { if (latitude == null || longitude == null) { return Stream.empty(); } throttler.waitForThrottle(); Object value = JsonUtil.loadJson(String.format(REVERSE_GEOCODE_URL, credentials(this.config)) - + Util.encodeUrlComponent(latitude + "," + longitude)) + + Util.encodeUrlComponent(latitude + "," + longitude), urlAccessChecker) .findFirst() .orElse(null); if (value instanceof Map) { @@ -397,7 +402,7 @@ public Stream geocode( return getSupplier(config) .geocode( address, - maxResults == 0 ? MAX_RESULTS : Math.min(Math.max(maxResults, 1), MAX_RESULTS)); + maxResults == 0 ? MAX_RESULTS : Math.min(Math.max(maxResults, 1), MAX_RESULTS), urlAccessChecker); } catch (IllegalStateException re) { if (!quotaException && re.getMessage().startsWith("QUOTA_EXCEEDED")) return Stream.empty(); throw re; @@ -415,7 +420,7 @@ public Stream reverseGeocode( @Name(value = "quotaException", defaultValue = "false") boolean quotaException, @Name(value = "config", defaultValue = "{}") Map config) { try { - return getSupplier(config).reverseGeocode(latitude, longitude); + return getSupplier(config).reverseGeocode(latitude, longitude, urlAccessChecker); } catch (IllegalStateException re) { if (!quotaException && re.getMessage().startsWith("QUOTA_EXCEEDED")) return Stream.empty(); throw re; diff --git a/core/src/test/java/apoc/spatial/SpatialTest.java b/core/src/test/java/apoc/spatial/SpatialTest.java index e35ed2d8b..591458fae 100644 --- a/core/src/test/java/apoc/spatial/SpatialTest.java +++ b/core/src/test/java/apoc/spatial/SpatialTest.java @@ -145,7 +145,7 @@ public void setUp() { TestUtil.registerProcedure(db, MockGeocode.class); apocConfig().setProperty(APOC_IMPORT_FILE_ENABLED, true); URL url = ClassLoader.getSystemResource("spatial.json"); - Map tests = (Map) JsonUtil.loadJson(url.toString()).findFirst().orElse(null); + Map tests = (Map) JsonUtil.loadJson(url.toString(), null).findFirst().orElse(null); for (Object event : (List) tests.get("events")) { addEventData((Map) event); } diff --git a/it/src/test/java/apoc/it/common/UtilIT.java b/it/src/test/java/apoc/it/common/UtilIT.java index 2da29ed8f..479bed3bd 100644 --- a/it/src/test/java/apoc/it/common/UtilIT.java +++ b/it/src/test/java/apoc/it/common/UtilIT.java @@ -20,19 +20,19 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import apoc.ApocConfig; import apoc.util.Util; -import inet.ipaddr.IPAddressString; import java.io.IOException; import java.net.HttpURLConnection; +import java.net.MalformedURLException; import java.net.URL; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collections; -import java.util.List; import junit.framework.TestCase; import org.apache.commons.io.IOUtils; import org.junit.Assert; @@ -40,15 +40,22 @@ import org.junit.jupiter.api.AfterEach; import org.neo4j.configuration.Config; import org.neo4j.configuration.GraphDatabaseInternalSettings; -import org.neo4j.kernel.impl.security.WebURLAccessRule; +import org.neo4j.graphdb.security.URLAccessValidationError; +import org.neo4j.kernel.impl.security.WebUrlAccessChecker; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; import org.testcontainers.containers.GenericContainer; public class UtilIT { private GenericContainer httpServer; - private GenericContainer setUpServer(Config neo4jConfig, String redirectURL) { - new ApocConfig(neo4jConfig); + public UtilIT() throws Exception { + googleUrl = new URL( "https://www.google.com" ); + } + + private GenericContainer setUpServer(String redirectURL) { + new ApocConfig(null); GenericContainer httpServer = new GenericContainer("alpine") .withCommand( "/bin/sh", @@ -66,51 +73,48 @@ public void tearDown() { httpServer.stop(); } + private final URL googleUrl; + @Test - public void redirectShouldWorkWhenProtocolNotChangesWithUrlLocation() throws IOException { - Config neo4jConfig = mock(Config.class); - when(neo4jConfig.get(GraphDatabaseInternalSettings.cypher_ip_blocklist)).thenReturn(Collections.emptyList()); - httpServer = setUpServer(neo4jConfig, "https://www.google.com"); + public void redirectShouldWorkWhenProtocolNotChangesWithUrlLocation() throws Exception { + WebUrlAccessChecker mockChecker = mock(WebUrlAccessChecker.class); + httpServer = setUpServer("https://www.google.com"); // given - String url = getServerUrl(httpServer); + URL url = getServerUrl(httpServer); + when( mockChecker.checkURL( url ) ).thenReturn( url ); + when( mockChecker.checkURL( googleUrl ) ).thenReturn( googleUrl ); // when - String page = IOUtils.toString( Util.openInputStream(url, null, null, null, new WebURLAccessRule(neo4jConfig)), StandardCharsets.UTF_8); + String page = IOUtils.toString( Util.openInputStream(url.toString(), null, null, null, mockChecker ), StandardCharsets.UTF_8); // then assertTrue(page.contains("Google")); } @Test - public void redirectWithBlockedIPsWithUrlLocation() { - List blockedIPs = List.of(new IPAddressString("127.168.0.1/8")); - - Config neo4jConfig = mock(Config.class); - when(neo4jConfig.get(GraphDatabaseInternalSettings.cypher_ip_blocklist)).thenReturn(blockedIPs); + public void redirectWithBlockedIPsWithUrlLocation() throws Exception{ + WebUrlAccessChecker mockChecker = mock(WebUrlAccessChecker.class); - httpServer = setUpServer(neo4jConfig, "http://127.168.0.1"); - String url = getServerUrl(httpServer); + httpServer = setUpServer("http://127.168.0.1"); + URL url = getServerUrl(httpServer); + when( mockChecker.checkURL( url ) ).thenReturn( url ); + when( mockChecker.checkURL( new URL("http://127.168.0.1") ) ).thenThrow( new URLAccessValidationError( "no" ) ); - IOException e = Assert.assertThrows(IOException.class, () -> Util.openInputStream(url, null, null, null, new WebURLAccessRule(neo4jConfig))); - TestCase.assertTrue( - e.getMessage() - .contains( - "access to /127.168.0.1 is blocked via the configuration property internal.dbms.cypher_ip_blocklist")); + IOException e = Assert.assertThrows(IOException.class, () -> Util.openInputStream(url.toString(), null, null, null, mockChecker)); + TestCase.assertTrue(e.getMessage().contains("no")); } @Test - public void redirectWithProtocolUpgradeIsAllowed() throws IOException { - List blockedIPs = List.of(new IPAddressString("127.168.0.1/8")); - - Config neo4jConfig = mock(Config.class); - when(neo4jConfig.get(GraphDatabaseInternalSettings.cypher_ip_blocklist)).thenReturn(blockedIPs); - - httpServer = setUpServer(neo4jConfig, "https://www.google.com"); - String url = getServerUrl(httpServer); + public void redirectWithProtocolUpgradeIsAllowed() throws Exception { + WebUrlAccessChecker mockChecker = mock(WebUrlAccessChecker.class); + httpServer = setUpServer("https://www.google.com"); + URL url = getServerUrl(httpServer); + when( mockChecker.checkURL( url ) ).thenReturn( url ); + when( mockChecker.checkURL( googleUrl ) ).thenReturn( googleUrl ); // when - String page = IOUtils.toString( Util.openInputStream(url, null, null, null, new WebURLAccessRule(neo4jConfig)), StandardCharsets.UTF_8 ); + String page = IOUtils.toString( Util.openInputStream(url.toString(), null, null, null, mockChecker), StandardCharsets.UTF_8 ); // then assertTrue(page.contains("Google")); @@ -129,22 +133,21 @@ public void redirectWithProtocolDowngradeIsNotAllowed() throws IOException { } @Test - public void shouldFailForExceedingRedirectLimit() { - Config neo4jConfig = mock(Config.class); - when(neo4jConfig.get(GraphDatabaseInternalSettings.cypher_ip_blocklist)).thenReturn(Collections.emptyList()); - - httpServer = setUpServer(neo4jConfig, "https://127.0.0.0"); - String url = getServerUrl(httpServer); + public void shouldFailForExceedingRedirectLimit() throws Exception { + WebUrlAccessChecker mockChecker = mock(WebUrlAccessChecker.class); + httpServer = setUpServer("https://127.0.0.0"); + URL url = getServerUrl(httpServer); + when( mockChecker.checkURL( any() ) ).thenAnswer( (Answer) invocation -> (URL) invocation.getArguments()[0] ); ArrayList servers = new ArrayList<>(); for (int i = 1; i <= 10; i++) { - GenericContainer server = setUpServer(neo4jConfig, url); + GenericContainer server = setUpServer(url.toString()); servers.add(server); url = getServerUrl(server); } - String finalUrl = url; - IOException e = Assert.assertThrows(IOException.class, () -> Util.openInputStream(finalUrl, null, null, null, new WebURLAccessRule(neo4jConfig))); + URL finalUrl = url; + IOException e = Assert.assertThrows(IOException.class, () -> Util.openInputStream(finalUrl.toString(), null, null, null, mockChecker)); TestCase.assertTrue(e.getMessage().contains("Redirect limit exceeded")); @@ -154,21 +157,24 @@ public void shouldFailForExceedingRedirectLimit() { } @Test - public void redirectShouldThrowExceptionWhenProtocolChangesWithFileLocation() { - httpServer = setUpServer(null, "file:/etc/passwd"); + public void redirectShouldThrowExceptionWhenProtocolChangesWithFileLocation() throws Exception { + WebUrlAccessChecker mockChecker = mock(WebUrlAccessChecker.class); + httpServer = setUpServer("file:/etc/passwd"); // given - String url = getServerUrl(httpServer); + URL url = getServerUrl(httpServer); + when( mockChecker.checkURL( url ) ).thenReturn( url ); Config neo4jConfig = mock(Config.class); when(neo4jConfig.get(GraphDatabaseInternalSettings.cypher_ip_blocklist)).thenReturn(Collections.emptyList()); // when RuntimeException e = - Assert.assertThrows(RuntimeException.class, () -> Util.openInputStream(url, null, null, null, new WebURLAccessRule(neo4jConfig))); + Assert.assertThrows(RuntimeException.class, () -> Util.openInputStream(url.toString(), null, null, null, mockChecker)); assertEquals("The redirect URI has a different protocol: file:/etc/passwd", e.getMessage()); } - private String getServerUrl(GenericContainer httpServer) { - return String.format("http://%s:%s", httpServer.getContainerIpAddress(), httpServer.getMappedPort(8000)); + private URL getServerUrl(GenericContainer httpServer) throws MalformedURLException + { + return new URL(String.format("http://%s:%s", httpServer.getContainerIpAddress(), httpServer.getMappedPort(8000))); } }