Skip to content

Commit

Permalink
[NOID] More fixes for issues related to URLAccessChecker
Browse files Browse the repository at this point in the history
* Geocode does these checks as well, and that was missing.
* UtilIT rewritten again to mock URLAccessChecker so as to not rely on coredb internals
  • Loading branch information
mnd999 committed Nov 24, 2023
1 parent d67ba23 commit 902f198
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 72 deletions.
6 changes: 0 additions & 6 deletions common/src/main/java/apoc/ApocConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ public class ApocConfig extends LifecycleAdapter {
private static ApocConfig theInstance;
private GraphDatabaseService systemDb;

private List<IPAddressString> blockedIpRanges = List.of();

private boolean expandCommands;

private Duration commandEvaluationTimeout;
Expand All @@ -125,7 +123,6 @@ public ApocConfig(
GlobalProcedures globalProceduresRegistry,
DatabaseManagementService databaseManagementService) {
this.neo4jConfig = neo4jConfig;
this.blockedIpRanges = neo4jConfig.get(GraphDatabaseInternalSettings.cypher_ip_blocklist);
this.commandEvaluationTimeout =
neo4jConfig.get(GraphDatabaseInternalSettings.config_command_evaluation_timeout);
if (this.commandEvaluationTimeout == null) {
Expand All @@ -145,9 +142,6 @@ public ApocConfig(
// use only for unit tests
public ApocConfig(Config neo4jConfig) {
this.neo4jConfig = neo4jConfig;
if (neo4jConfig != null) {
this.blockedIpRanges = neo4jConfig.get(GraphDatabaseInternalSettings.cypher_ip_blocklist);
}
this.log = NullLog.getInstance();
this.databaseManagementService = null;
theInstance = this;
Expand Down
4 changes: 2 additions & 2 deletions common/src/main/java/apoc/util/JsonUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ public static Stream<Object> loadJson(
}
}

public static Stream<Object> loadJson(String url) {
return loadJson(url, null, null, "", true, null, null);
public static Stream<Object> loadJson(String url, URLAccessChecker urlAccessChecker) {
return loadJson(url, null, null, "", true, null, urlAccessChecker);
}

public static <T> T parse(String json, String path, Class<T> type) {
Expand Down
37 changes: 21 additions & 16 deletions core/src/main/java/apoc/spatial/Geocode.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
import java.util.Map;
import java.util.stream.Stream;
import org.apache.commons.configuration2.Configuration;

import org.neo4j.graphdb.security.URLAccessChecker;
import org.neo4j.procedure.*;

public class Geocode {
Expand All @@ -42,10 +44,13 @@ public class Geocode {
@Context
public TerminationGuard terminationGuard;

@Context
public URLAccessChecker urlAccessChecker;

interface GeocodeSupplier {
Stream<GeoCodeResult> geocode(String params, long maxResults);
Stream<GeoCodeResult> geocode(String params, long maxResults, URLAccessChecker urlAccessChecker);

Stream<GeoCodeResult> reverseGeocode(Double latitude, Double longitude);
Stream<GeoCodeResult> reverseGeocode(Double latitude, Double longitude, URLAccessChecker urlAccessChecker);
}

private static class Throttler {
Expand Down Expand Up @@ -123,13 +128,13 @@ public SupplierWithKey(Configuration config, TerminationGuard terminationGuard,
}

@SuppressWarnings("unchecked")
public Stream<GeoCodeResult> geocode(String address, long maxResults) {
public Stream<GeoCodeResult> geocode(String address, long maxResults, URLAccessChecker urlAccessChecker) {
if (address.isEmpty()) {
return Stream.empty();
}
throttler.waitForThrottle();
String url = urlTemplate.replace("PLACE", Util.encodeUrlComponent(address));
Object value = JsonUtil.loadJson(url).findFirst().orElse(null);
Object value = JsonUtil.loadJson(url, urlAccessChecker).findFirst().orElse(null);
if (value instanceof List) {
return findResults((List<Map<String, Object>>) value, maxResults);
} else if (value instanceof Map) {
Expand All @@ -142,13 +147,13 @@ public Stream<GeoCodeResult> geocode(String address, long maxResults) {
}

@Override
public Stream<GeoCodeResult> reverseGeocode(Double latitude, Double longitude) {
public Stream<GeoCodeResult> reverseGeocode(Double latitude, Double longitude, URLAccessChecker urlAccessChecker) {
if (latitude == null || longitude == null) {
return Stream.empty();
}
throttler.waitForThrottle();
String url = urlTemplateReverse.replace("LAT", latitude.toString()).replace("LNG", longitude.toString());
Object value = JsonUtil.loadJson(url).findFirst().orElse(null);
Object value = JsonUtil.loadJson(url, urlAccessChecker).findFirst().orElse(null);
if (value instanceof List) {
return findResults((List<Map<String, Object>>) value, 1);
} else if (value instanceof Map) {
Expand Down Expand Up @@ -203,12 +208,12 @@ public OSMSupplier(Configuration config, TerminationGuard terminationGuard) {
}

@SuppressWarnings("unchecked")
public Stream<GeoCodeResult> geocode(String address, long maxResults) {
public Stream<GeoCodeResult> geocode(String address, long maxResults, URLAccessChecker urlAccessChecker) {
if (address.isEmpty()) {
return Stream.empty();
}
throttler.waitForThrottle();
Object value = JsonUtil.loadJson(OSM_URL_GEOCODE + Util.encodeUrlComponent(address))
Object value = JsonUtil.loadJson(OSM_URL_GEOCODE + Util.encodeUrlComponent(address), urlAccessChecker)
.findFirst()
.orElse(null);
if (value instanceof List) {
Expand All @@ -225,14 +230,14 @@ public Stream<GeoCodeResult> geocode(String address, long maxResults) {
}

@Override
public Stream<GeoCodeResult> reverseGeocode(Double latitude, Double longitude) {
public Stream<GeoCodeResult> reverseGeocode(Double latitude, Double longitude, URLAccessChecker urlAccessChecker) {
if (latitude == null || longitude == null) {
return Stream.empty();
}
throttler.waitForThrottle();

Object value = JsonUtil.loadJson(
OSM_URL_REVERSE_GEOCODE + String.format("lat=%s&lon=%s", latitude, longitude))
OSM_URL_REVERSE_GEOCODE + String.format("lat=%s&lon=%s", latitude, longitude), urlAccessChecker)
.findFirst()
.orElse(null);
if (value instanceof Map) {
Expand Down Expand Up @@ -275,13 +280,13 @@ private String credentials(Configuration config) {
}

@SuppressWarnings("unchecked")
public Stream<GeoCodeResult> geocode(String address, long maxResults) {
public Stream<GeoCodeResult> geocode(String address, long maxResults, URLAccessChecker urlAccessChecker) {
if (address.isEmpty()) {
return Stream.empty();
}
throttler.waitForThrottle();
Object value = JsonUtil.loadJson(
String.format(GEOCODE_URL, credentials(this.config)) + Util.encodeUrlComponent(address))
String.format(GEOCODE_URL, credentials(this.config)) + Util.encodeUrlComponent(address), urlAccessChecker)
.findFirst()
.orElse(null);
if (value instanceof Map) {
Expand All @@ -306,13 +311,13 @@ public Stream<GeoCodeResult> geocode(String address, long maxResults) {
}

@Override
public Stream<GeoCodeResult> reverseGeocode(Double latitude, Double longitude) {
public Stream<GeoCodeResult> reverseGeocode(Double latitude, Double longitude, URLAccessChecker urlAccessChecker) {
if (latitude == null || longitude == null) {
return Stream.empty();
}
throttler.waitForThrottle();
Object value = JsonUtil.loadJson(String.format(REVERSE_GEOCODE_URL, credentials(this.config))
+ Util.encodeUrlComponent(latitude + "," + longitude))
+ Util.encodeUrlComponent(latitude + "," + longitude), urlAccessChecker)
.findFirst()
.orElse(null);
if (value instanceof Map) {
Expand Down Expand Up @@ -397,7 +402,7 @@ public Stream<GeoCodeResult> geocode(
return getSupplier(config)
.geocode(
address,
maxResults == 0 ? MAX_RESULTS : Math.min(Math.max(maxResults, 1), MAX_RESULTS));
maxResults == 0 ? MAX_RESULTS : Math.min(Math.max(maxResults, 1), MAX_RESULTS), urlAccessChecker);
} catch (IllegalStateException re) {
if (!quotaException && re.getMessage().startsWith("QUOTA_EXCEEDED")) return Stream.empty();
throw re;
Expand All @@ -415,7 +420,7 @@ public Stream<GeoCodeResult> reverseGeocode(
@Name(value = "quotaException", defaultValue = "false") boolean quotaException,
@Name(value = "config", defaultValue = "{}") Map<String, Object> config) {
try {
return getSupplier(config).reverseGeocode(latitude, longitude);
return getSupplier(config).reverseGeocode(latitude, longitude, urlAccessChecker);
} catch (IllegalStateException re) {
if (!quotaException && re.getMessage().startsWith("QUOTA_EXCEEDED")) return Stream.empty();
throw re;
Expand Down
2 changes: 1 addition & 1 deletion core/src/test/java/apoc/spatial/SpatialTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ public void setUp() {
TestUtil.registerProcedure(db, MockGeocode.class);
apocConfig().setProperty(APOC_IMPORT_FILE_ENABLED, true);
URL url = ClassLoader.getSystemResource("spatial.json");
Map tests = (Map) JsonUtil.loadJson(url.toString()).findFirst().orElse(null);
Map tests = (Map) JsonUtil.loadJson(url.toString(), null).findFirst().orElse(null);
for (Object event : (List) tests.get("events")) {
addEventData((Map) event);
}
Expand Down
100 changes: 53 additions & 47 deletions it/src/test/java/apoc/it/common/UtilIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,35 +20,42 @@

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import apoc.ApocConfig;
import apoc.util.Util;
import inet.ipaddr.IPAddressString;
import java.io.IOException;
import java.net.HttpURLConnection;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import junit.framework.TestCase;
import org.apache.commons.io.IOUtils;
import org.junit.Assert;
import org.junit.Test;
import org.junit.jupiter.api.AfterEach;
import org.neo4j.configuration.Config;
import org.neo4j.configuration.GraphDatabaseInternalSettings;
import org.neo4j.kernel.impl.security.WebURLAccessRule;
import org.neo4j.graphdb.security.URLAccessValidationError;
import org.neo4j.kernel.impl.security.WebUrlAccessChecker;

import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.testcontainers.containers.GenericContainer;

public class UtilIT {
private GenericContainer httpServer;

private GenericContainer setUpServer(Config neo4jConfig, String redirectURL) {
new ApocConfig(neo4jConfig);
public UtilIT() throws Exception {
googleUrl = new URL( "https://www.google.com" );
}

private GenericContainer setUpServer(String redirectURL) {
new ApocConfig(null);
GenericContainer httpServer = new GenericContainer("alpine")
.withCommand(
"/bin/sh",
Expand All @@ -66,51 +73,48 @@ public void tearDown() {
httpServer.stop();
}

private final URL googleUrl;

@Test
public void redirectShouldWorkWhenProtocolNotChangesWithUrlLocation() throws IOException {
Config neo4jConfig = mock(Config.class);
when(neo4jConfig.get(GraphDatabaseInternalSettings.cypher_ip_blocklist)).thenReturn(Collections.emptyList());
httpServer = setUpServer(neo4jConfig, "https://www.google.com");
public void redirectShouldWorkWhenProtocolNotChangesWithUrlLocation() throws Exception {
WebUrlAccessChecker mockChecker = mock(WebUrlAccessChecker.class);
httpServer = setUpServer("https://www.google.com");

// given
String url = getServerUrl(httpServer);
URL url = getServerUrl(httpServer);
when( mockChecker.checkURL( url ) ).thenReturn( url );
when( mockChecker.checkURL( googleUrl ) ).thenReturn( googleUrl );

// when
String page = IOUtils.toString( Util.openInputStream(url, null, null, null, new WebURLAccessRule(neo4jConfig)), StandardCharsets.UTF_8);
String page = IOUtils.toString( Util.openInputStream(url.toString(), null, null, null, mockChecker ), StandardCharsets.UTF_8);

// then
assertTrue(page.contains("<title>Google</title>"));
}

@Test
public void redirectWithBlockedIPsWithUrlLocation() {
List<IPAddressString> blockedIPs = List.of(new IPAddressString("127.168.0.1/8"));

Config neo4jConfig = mock(Config.class);
when(neo4jConfig.get(GraphDatabaseInternalSettings.cypher_ip_blocklist)).thenReturn(blockedIPs);
public void redirectWithBlockedIPsWithUrlLocation() throws Exception{
WebUrlAccessChecker mockChecker = mock(WebUrlAccessChecker.class);

httpServer = setUpServer(neo4jConfig, "http://127.168.0.1");
String url = getServerUrl(httpServer);
httpServer = setUpServer("http://127.168.0.1");
URL url = getServerUrl(httpServer);
when( mockChecker.checkURL( url ) ).thenReturn( url );
when( mockChecker.checkURL( new URL("http://127.168.0.1") ) ).thenThrow( new URLAccessValidationError( "no" ) );

IOException e = Assert.assertThrows(IOException.class, () -> Util.openInputStream(url, null, null, null, new WebURLAccessRule(neo4jConfig)));
TestCase.assertTrue(
e.getMessage()
.contains(
"access to /127.168.0.1 is blocked via the configuration property internal.dbms.cypher_ip_blocklist"));
IOException e = Assert.assertThrows(IOException.class, () -> Util.openInputStream(url.toString(), null, null, null, mockChecker));
TestCase.assertTrue(e.getMessage().contains("no"));
}

@Test
public void redirectWithProtocolUpgradeIsAllowed() throws IOException {
List<IPAddressString> blockedIPs = List.of(new IPAddressString("127.168.0.1/8"));

Config neo4jConfig = mock(Config.class);
when(neo4jConfig.get(GraphDatabaseInternalSettings.cypher_ip_blocklist)).thenReturn(blockedIPs);

httpServer = setUpServer(neo4jConfig, "https://www.google.com");
String url = getServerUrl(httpServer);
public void redirectWithProtocolUpgradeIsAllowed() throws Exception {
WebUrlAccessChecker mockChecker = mock(WebUrlAccessChecker.class);
httpServer = setUpServer("https://www.google.com");
URL url = getServerUrl(httpServer);
when( mockChecker.checkURL( url ) ).thenReturn( url );
when( mockChecker.checkURL( googleUrl ) ).thenReturn( googleUrl );

// when
String page = IOUtils.toString( Util.openInputStream(url, null, null, null, new WebURLAccessRule(neo4jConfig)), StandardCharsets.UTF_8 );
String page = IOUtils.toString( Util.openInputStream(url.toString(), null, null, null, mockChecker), StandardCharsets.UTF_8 );

// then
assertTrue(page.contains("<title>Google</title>"));
Expand All @@ -129,22 +133,21 @@ public void redirectWithProtocolDowngradeIsNotAllowed() throws IOException {
}

@Test
public void shouldFailForExceedingRedirectLimit() {
Config neo4jConfig = mock(Config.class);
when(neo4jConfig.get(GraphDatabaseInternalSettings.cypher_ip_blocklist)).thenReturn(Collections.emptyList());

httpServer = setUpServer(neo4jConfig, "https://127.0.0.0");
String url = getServerUrl(httpServer);
public void shouldFailForExceedingRedirectLimit() throws Exception {
WebUrlAccessChecker mockChecker = mock(WebUrlAccessChecker.class);
httpServer = setUpServer("https://127.0.0.0");
URL url = getServerUrl(httpServer);
when( mockChecker.checkURL( any() ) ).thenAnswer( (Answer<URL>) invocation -> (URL) invocation.getArguments()[0] );

ArrayList<GenericContainer> servers = new ArrayList<>();
for (int i = 1; i <= 10; i++) {
GenericContainer server = setUpServer(neo4jConfig, url);
GenericContainer server = setUpServer(url.toString());
servers.add(server);
url = getServerUrl(server);
}

String finalUrl = url;
IOException e = Assert.assertThrows(IOException.class, () -> Util.openInputStream(finalUrl, null, null, null, new WebURLAccessRule(neo4jConfig)));
URL finalUrl = url;
IOException e = Assert.assertThrows(IOException.class, () -> Util.openInputStream(finalUrl.toString(), null, null, null, mockChecker));

TestCase.assertTrue(e.getMessage().contains("Redirect limit exceeded"));

Expand All @@ -154,21 +157,24 @@ public void shouldFailForExceedingRedirectLimit() {
}

@Test
public void redirectShouldThrowExceptionWhenProtocolChangesWithFileLocation() {
httpServer = setUpServer(null, "file:/etc/passwd");
public void redirectShouldThrowExceptionWhenProtocolChangesWithFileLocation() throws Exception {
WebUrlAccessChecker mockChecker = mock(WebUrlAccessChecker.class);
httpServer = setUpServer("file:/etc/passwd");
// given
String url = getServerUrl(httpServer);
URL url = getServerUrl(httpServer);
when( mockChecker.checkURL( url ) ).thenReturn( url );
Config neo4jConfig = mock(Config.class);
when(neo4jConfig.get(GraphDatabaseInternalSettings.cypher_ip_blocklist)).thenReturn(Collections.emptyList());

// when
RuntimeException e =
Assert.assertThrows(RuntimeException.class, () -> Util.openInputStream(url, null, null, null, new WebURLAccessRule(neo4jConfig)));
Assert.assertThrows(RuntimeException.class, () -> Util.openInputStream(url.toString(), null, null, null, mockChecker));

assertEquals("The redirect URI has a different protocol: file:/etc/passwd", e.getMessage());
}

private String getServerUrl(GenericContainer httpServer) {
return String.format("http://%s:%s", httpServer.getContainerIpAddress(), httpServer.getMappedPort(8000));
private URL getServerUrl(GenericContainer httpServer) throws MalformedURLException
{
return new URL(String.format("http://%s:%s", httpServer.getContainerIpAddress(), httpServer.getMappedPort(8000)));
}
}

0 comments on commit 902f198

Please sign in to comment.