diff --git a/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java b/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java index 6bf7e379de5e..a275f8165bf5 100644 --- a/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java +++ b/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -591,11 +591,14 @@ public void setServerName(String serverName) { @Override public String getServerName() { - String host = getHeader(HOST_HEADER); + String rawHostHeader = getHeader(HOST_HEADER); + String host = rawHostHeader; if (host != null) { host = host.trim(); if (host.startsWith("[")) { - host = host.substring(1, host.indexOf(']')); + int indexOfClosingBracket = host.indexOf(']'); + Assert.state(indexOfClosingBracket > -1, "Invalid Host header: " + rawHostHeader); + host = host.substring(0, indexOfClosingBracket + 1); } else if (host.contains(":")) { host = host.substring(0, host.indexOf(':')); @@ -613,12 +616,15 @@ public void setServerPort(int serverPort) { @Override public int getServerPort() { - String host = getHeader(HOST_HEADER); + String rawHostHeader = getHeader(HOST_HEADER); + String host = rawHostHeader; if (host != null) { host = host.trim(); int idx; if (host.startsWith("[")) { - idx = host.indexOf(':', host.indexOf(']')); + int indexOfClosingBracket = host.indexOf(']'); + Assert.state(indexOfClosingBracket > -1, "Invalid Host header: " + rawHostHeader); + idx = host.indexOf(':', indexOfClosingBracket); } else { idx = host.indexOf(':'); diff --git a/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletRequestTests.java b/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletRequestTests.java index 6b2bd6037830..44d004780463 100644 --- a/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletRequestTests.java +++ b/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletRequestTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.mock.web; import java.io.IOException; +import java.net.URL; import java.nio.charset.Charset; import java.util.ArrayList; import java.util.Arrays; @@ -29,10 +30,13 @@ import java.util.Map; import javax.servlet.http.Cookie; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import org.springframework.util.StreamUtils; +import static org.hamcrest.CoreMatchers.startsWith; import static org.junit.Assert.*; /** @@ -54,6 +58,9 @@ public class MockHttpServletRequestTests { private static final String IF_MODIFIED_SINCE = "If-Modified-Since"; private final MockHttpServletRequest request = new MockHttpServletRequest(); + + @Rule + public ExpectedException exception = ExpectedException.none(); @Test @@ -289,16 +296,23 @@ public void getServerNameViaHostHeaderWithPort() { @Test public void getServerNameViaHostHeaderAsIpv6AddressWithoutPort() { - String ipv6Address = "[2001:db8:0:1]"; - request.addHeader(HOST, ipv6Address); - assertEquals("2001:db8:0:1", request.getServerName()); + String host = "[2001:db8:0:1]"; + request.addHeader(HOST, host); + assertEquals(host, request.getServerName()); } @Test public void getServerNameViaHostHeaderAsIpv6AddressWithPort() { - String ipv6Address = "[2001:db8:0:1]:8081"; - request.addHeader(HOST, ipv6Address); - assertEquals("2001:db8:0:1", request.getServerName()); + request.addHeader(HOST, "[2001:db8:0:1]:8081"); + assertEquals("[2001:db8:0:1]", request.getServerName()); + } + + @Test + public void getServerNameWithInvalidIpv6AddressViaHostHeader() { + request.addHeader(HOST, "[::ffff:abcd:abcd"); // missing closing bracket + exception.expect(IllegalStateException.class); + exception.expectMessage(startsWith("Invalid Host header: ")); + request.getServerName(); } @Test @@ -312,6 +326,22 @@ public void getServerPortWithCustomPort() { assertEquals(8080, request.getServerPort()); } + @Test + public void getServerPortWithInvalidIpv6AddressViaHostHeader() { + request.addHeader(HOST, "[::ffff:abcd:abcd:8080"); // missing closing bracket + exception.expect(IllegalStateException.class); + exception.expectMessage(startsWith("Invalid Host header: ")); + request.getServerPort(); + } + + @Test + public void getServerPortWithIpv6AddressAndInvalidPortViaHostHeader() { + request.addHeader(HOST, "[::ffff:abcd:abcd]:bogus"); // "bogus" is not a port number + exception.expect(NumberFormatException.class); + exception.expectMessage("bogus"); + request.getServerPort(); + } + @Test public void getServerPortViaHostHeaderAsIpv6AddressWithoutPort() { String testServer = "[2001:db8:0:1]"; @@ -376,6 +406,43 @@ public void getRequestURLWithHostHeaderAndPort() { assertEquals("http://" + testServer, requestURL.toString()); } + @Test + public void getRequestURLWithIpv6AddressViaServerNameWithoutPort() throws Exception { + request.setServerName("[::ffff:abcd:abcd]"); + URL url = new java.net.URL(request.getRequestURL().toString()); + assertEquals("http://[::ffff:abcd:abcd]", url.toString()); + } + + @Test + public void getRequestURLWithIpv6AddressViaServerNameWithPort() throws Exception { + request.setServerName("[::ffff:abcd:abcd]"); + request.setServerPort(9999); + URL url = new java.net.URL(request.getRequestURL().toString()); + assertEquals("http://[::ffff:abcd:abcd]:9999", url.toString()); + } + + @Test + public void getRequestURLWithInvalidIpv6AddressViaHostHeader() { + request.addHeader(HOST, "[::ffff:abcd:abcd"); // missing closing bracket + exception.expect(IllegalStateException.class); + exception.expectMessage(startsWith("Invalid Host header: ")); + request.getRequestURL(); + } + + @Test + public void getRequestURLWithIpv6AddressViaHostHeaderWithoutPort() throws Exception { + request.addHeader(HOST, "[::ffff:abcd:abcd]"); + URL url = new java.net.URL(request.getRequestURL().toString()); + assertEquals("http://[::ffff:abcd:abcd]", url.toString()); + } + + @Test + public void getRequestURLWithIpv6AddressViaHostHeaderWithPort() throws Exception { + request.addHeader(HOST, "[::ffff:abcd:abcd]:9999"); + URL url = new java.net.URL(request.getRequestURL().toString()); + assertEquals("http://[::ffff:abcd:abcd]:9999", url.toString()); + } + @Test public void getRequestURLWithNullRequestUri() { request.setRequestURI(null); diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletRequest.java b/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletRequest.java index e62ef123af2c..1b4c92869ddd 100644 --- a/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletRequest.java +++ b/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -591,11 +591,14 @@ public void setServerName(String serverName) { @Override public String getServerName() { - String host = getHeader(HOST_HEADER); + String rawHostHeader = getHeader(HOST_HEADER); + String host = rawHostHeader; if (host != null) { host = host.trim(); if (host.startsWith("[")) { - host = host.substring(1, host.indexOf(']')); + int indexOfClosingBracket = host.indexOf(']'); + Assert.state(indexOfClosingBracket > -1, "Invalid Host header: " + rawHostHeader); + host = host.substring(0, indexOfClosingBracket + 1); } else if (host.contains(":")) { host = host.substring(0, host.indexOf(':')); @@ -613,12 +616,15 @@ public void setServerPort(int serverPort) { @Override public int getServerPort() { - String host = getHeader(HOST_HEADER); + String rawHostHeader = getHeader(HOST_HEADER); + String host = rawHostHeader; if (host != null) { host = host.trim(); int idx; if (host.startsWith("[")) { - idx = host.indexOf(':', host.indexOf(']')); + int indexOfClosingBracket = host.indexOf(']'); + Assert.state(indexOfClosingBracket > -1, "Invalid Host header: " + rawHostHeader); + idx = host.indexOf(':', indexOfClosingBracket); } else { idx = host.indexOf(':');