diff --git a/common/buffers/src/main/java/io/helidon/common/buffers/LazyString.java b/common/buffers/src/main/java/io/helidon/common/buffers/LazyString.java index 911e22b72e7..4843e6bac84 100644 --- a/common/buffers/src/main/java/io/helidon/common/buffers/LazyString.java +++ b/common/buffers/src/main/java/io/helidon/common/buffers/LazyString.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022 Oracle and/or its affiliates. + * Copyright (c) 2022, 2023 Oracle and/or its affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/common/buffers/src/test/java/io/helidon/common/buffers/LazyStringTest.java b/common/buffers/src/test/java/io/helidon/common/buffers/LazyStringTest.java index e54eef6f174..bb66ad7ee69 100644 --- a/common/buffers/src/test/java/io/helidon/common/buffers/LazyStringTest.java +++ b/common/buffers/src/test/java/io/helidon/common/buffers/LazyStringTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022 Oracle and/or its affiliates. + * Copyright (c) 2022, 2023 Oracle and/or its affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/common/http/pom.xml b/common/http/pom.xml index a500b152c98..5ae13e3d451 100644 --- a/common/http/pom.xml +++ b/common/http/pom.xml @@ -74,6 +74,11 @@ junit-jupiter-api test + + junit-jupiter-params + org.junit.jupiter + test + org.hamcrest hamcrest-all diff --git a/common/http/src/main/java/io/helidon/common/http/Http.java b/common/http/src/main/java/io/helidon/common/http/Http.java index ed67d836b97..55024e1c95a 100644 --- a/common/http/src/main/java/io/helidon/common/http/Http.java +++ b/common/http/src/main/java/io/helidon/common/http/Http.java @@ -238,7 +238,6 @@ public static Method create(String name) { return method; } - /** * Create a predicate for the provided methods. * @@ -960,6 +959,41 @@ default void writeHttp1Header(BufferData buffer) { } } + /** + * Check validity of header name and values. + * + * @throws IllegalArgumentException in case the HeaderValue is not valid + */ + default void validate() throws IllegalArgumentException { + String name = name(); + // validate that header name only contains valid characters + HttpToken.validate(name); + // Validate header value + validateValue(name, values()); + } + + + // validate header value based on https://www.rfc-editor.org/rfc/rfc7230#section-3.2 and throws IllegalArgumentException + // if invalid. + private static void validateValue(String name, String value) throws IllegalArgumentException { + char[] vChars = value.toCharArray(); + int vLength = vChars.length; + for (int i = 0; i < vLength; i++) { + char vChar = vChars[i]; + if (i == 0) { + if (vChar < '!' || vChar == '\u007f') { + throw new IllegalArgumentException("First character of the header value is invalid" + + " for header '" + name + "'"); + } + } else { + if (vChar < ' ' && vChar != '\t' || vChar == '\u007f') { + throw new IllegalArgumentException("Character at position " + (i + 1) + " of the header value is invalid" + + " for header '" + name + "'"); + } + } + } + } + private void writeHeader(BufferData buffer, byte[] nameBytes, byte[] valueBytes) { // header name buffer.write(nameBytes); diff --git a/common/http/src/main/java/io/helidon/common/http/Http1HeadersParser.java b/common/http/src/main/java/io/helidon/common/http/Http1HeadersParser.java index 970147e9d60..e7c74620da7 100644 --- a/common/http/src/main/java/io/helidon/common/http/Http1HeadersParser.java +++ b/common/http/src/main/java/io/helidon/common/http/Http1HeadersParser.java @@ -67,7 +67,11 @@ public static WritableHeaders readHeaders(DataReader reader, int maxHeadersSi reader.skip(2); maxLength -= eol + 1; - headers.add(Http.Header.create(header, value)); + Http.HeaderValue headerValue = Http.Header.create(header, value); + headers.add(headerValue); + if (validate) { + headerValue.validate(); + } if (maxLength < 0) { throw new IllegalStateException("Header size exceeded"); } @@ -113,9 +117,6 @@ private static Http.HeaderName readHeaderName(DataReader reader, } String headerName = reader.readAsciiString(col); - if (validate) { - HttpToken.validate(headerName); - } Http.HeaderName header = Http.Header.create(headerName); reader.skip(1); // skip the colon character diff --git a/common/http/src/test/java/io/helidon/common/http/Http1HeadersParserTest.java b/common/http/src/test/java/io/helidon/common/http/Http1HeadersParserTest.java index d7b6b0f71a6..ebf86b2f506 100644 --- a/common/http/src/test/java/io/helidon/common/http/Http1HeadersParserTest.java +++ b/common/http/src/test/java/io/helidon/common/http/Http1HeadersParserTest.java @@ -17,16 +17,27 @@ package io.helidon.common.http; import java.nio.charset.StandardCharsets; +import java.util.stream.Stream; import io.helidon.common.buffers.DataReader; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import static org.hamcrest.CoreMatchers.hasItems; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.params.provider.Arguments.arguments; class Http1HeadersParserTest { + private static String CUSTOM_HEADER_NAME = "Custom-Header-Name"; + private static String CUSTOM_HEADER_VALUE = "Custom-Header-Value"; + public static final String VALID_HEADER_VALUE = "Valid-Header-Value"; + public static final String VALID_HEADER_NAME = "Valid-Header-Name"; + @Test void testHeadersAreCaseInsensitive() { DataReader reader = new DataReader(() -> ( @@ -44,6 +55,66 @@ void testHeadersAreCaseInsensitive() { testHeader(headers, "HeADer", "hv1", "hv2", "hv3"); } + @ParameterizedTest + @MethodSource("headers") + void testHeadersWithValidationEnabled(String headerName, String headerValue, boolean expectsValid) { + // retrieve headers with validation enabled + WritableHeaders headers; + if (expectsValid) { + headers = getHeaders(headerName, headerValue, true); + String responseHeaderValue = headers.get(Http.Header.create(headerName)).values(); + // returned header values WhiteSpaces are trimmed so need to be tested with trimmed values + assertThat(responseHeaderValue, is(headerValue.trim())); + } else { + Assertions.assertThrows(IllegalArgumentException.class, + () -> getHeaders(headerName, headerValue, true)); + } + } + + @ParameterizedTest + @MethodSource("headers") + void testHeadersWithValidationDisabled(String headerValue) { + // retrieve headers without validating + WritableHeaders headers = getHeaders(CUSTOM_HEADER_NAME, headerValue, false); + String responseHeaderValue = headers.get(Http.Header.create(CUSTOM_HEADER_NAME)).values(); + // returned header values WhiteSpaces are trimmed so need to be tested with trimmed values + assertThat(responseHeaderValue, is(headerValue.trim())); + } + + private static WritableHeaders getHeaders(String headerName, String headerValue, boolean validate) { + DataReader reader = + new DataReader(() -> (headerName + ":" + headerValue + "\r\n" + "\r\n").getBytes(StandardCharsets.US_ASCII)); + return Http1HeadersParser.readHeaders(reader, 1024, validate); + } + + private static Stream headers() { + return Stream.of( + // Invalid header names + arguments("Header\u001aName", VALID_HEADER_VALUE, false), + arguments("Header\u000EName", VALID_HEADER_VALUE, false), + arguments("HeaderName\r\n", VALID_HEADER_VALUE, false), + arguments("(Header:Name)", VALID_HEADER_VALUE, false), + arguments("", VALID_HEADER_VALUE, false), + arguments("{Header=Name}", VALID_HEADER_VALUE, false), + arguments("\"HeaderName\"", VALID_HEADER_VALUE, false), + arguments("[\\HeaderName]", VALID_HEADER_VALUE, false), + arguments("@Header,Name;", VALID_HEADER_VALUE, false), + // Valid header names + arguments("!#$Custom~%&\'*Header+^`|", VALID_HEADER_VALUE, true), + arguments("Custom_0-9_a-z_A-Z_Header", VALID_HEADER_VALUE, true), + // Valid header values + arguments(VALID_HEADER_NAME, "Header Value", true), + arguments(VALID_HEADER_NAME, "HeaderValue1\u0009, Header=Value2", true), + arguments(VALID_HEADER_NAME, "Header\tValue", true), + arguments(VALID_HEADER_NAME, " Header Value ", true), + // Invalid header values + arguments(VALID_HEADER_NAME, "H\u001ceaderValue1", false), + arguments(VALID_HEADER_NAME, "HeaderValue1, Header\u007fValue", false), + arguments(VALID_HEADER_NAME, "HeaderValue1\u001f, HeaderValue2", false) + ); + } + + private void testHeader(Headers headers, String header, String... values) { Http.HeaderName headerName = Http.Header.create(header); assertThat("Headers should contain header: " + headerName.lowerCase(), @@ -53,4 +124,4 @@ private void testHeader(Headers headers, String header, String... values) { headers.get(headerName).allValues(), hasItems(values)); } -} \ No newline at end of file +} diff --git a/common/http/src/test/java/io/helidon/common/http/HttpTest.java b/common/http/src/test/java/io/helidon/common/http/HttpTest.java index 6a6b9d21e51..b6e9f225541 100644 --- a/common/http/src/test/java/io/helidon/common/http/HttpTest.java +++ b/common/http/src/test/java/io/helidon/common/http/HttpTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, 2022 Oracle and/or its affiliates. + * Copyright (c) 2018, 2023 Oracle and/or its affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,13 +16,20 @@ package io.helidon.common.http; +import java.util.stream.Stream; + +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import static io.helidon.common.http.Http.Status.TEMPORARY_REDIRECT_307; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.not; import static org.hamcrest.CoreMatchers.sameInstance; import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.params.provider.Arguments.arguments; /** * Unit test for {@link Http}. @@ -50,4 +57,36 @@ void testResposneStatusCustomReason() { assertThat(rs.code(), is(TEMPORARY_REDIRECT_307.code())); assertThat(rs.family(), is(TEMPORARY_REDIRECT_307.family())); } + + @ParameterizedTest + @MethodSource("headers") + void testHeaderValidation(String headerName, String headerValues, boolean expectsValid) { + Http.HeaderValue header = Http.Header.create(Http.Header.create(headerName), headerValues); + if (expectsValid) { + header.validate(); + } else { + Assertions.assertThrows(IllegalArgumentException.class, () -> header.validate()); + } + } + + private static Stream headers() { + return Stream.of( + // Valid headers + arguments("!#$Custom~%&\'*Header+^`|", "!Header\tValue~", true), + arguments("Custom_0-9_a-z_A-Z_Header", "\u0080Header Value\u00ff", true), + // Invalid headers + arguments("Valid-Header-Name", "H\u001ceaderValue1", false), + arguments("Valid-Header-Name", "HeaderValue1, Header\u007fValue", false), + arguments("Valid-Header-Name", "HeaderValue1\u001f, HeaderValue2", false), + arguments("Header\u001aName", "Valid-Header-Value", false), + arguments("Header\u000EName", "Valid-Header-Value", false), + arguments("HeaderName\r\n", "Valid-Header-Value", false), + arguments("(Header:Name)", "Valid-Header-Value", false), + arguments("", "Valid-Header-Value", false), + arguments("{Header=Name}", "Valid-Header-Value", false), + arguments("\"HeaderName\"", "Valid-Header-Value", false), + arguments("[\\HeaderName]", "Valid-Header-Value", false), + arguments("@Header,Name;", "Valid-Header-Value", false) + ); + } } diff --git a/nima/webclient/webclient/src/main/java/io/helidon/nima/webclient/http1/HttpCallChainBase.java b/nima/webclient/webclient/src/main/java/io/helidon/nima/webclient/http1/HttpCallChainBase.java index 98f814eb560..e67af09c502 100644 --- a/nima/webclient/webclient/src/main/java/io/helidon/nima/webclient/http1/HttpCallChainBase.java +++ b/nima/webclient/webclient/src/main/java/io/helidon/nima/webclient/http1/HttpCallChainBase.java @@ -47,8 +47,11 @@ abstract class HttpCallChainBase implements WebClientService.Chain { this.tls = tls; } - static void writeHeaders(Headers headers, BufferData bufferData) { + static void writeHeaders(Headers headers, BufferData bufferData, boolean validate) { for (Http.HeaderValue header : headers) { + if (validate) { + header.validate(); + } header.writeHttp1Header(bufferData); } bufferData.write(Bytes.CR_BYTE); diff --git a/nima/webclient/webclient/src/main/java/io/helidon/nima/webclient/http1/HttpCallEntityChain.java b/nima/webclient/webclient/src/main/java/io/helidon/nima/webclient/http1/HttpCallEntityChain.java index 9c9cf4e2de6..a929b3a17d5 100644 --- a/nima/webclient/webclient/src/main/java/io/helidon/nima/webclient/http1/HttpCallEntityChain.java +++ b/nima/webclient/webclient/src/main/java/io/helidon/nima/webclient/http1/HttpCallEntityChain.java @@ -28,15 +28,13 @@ import io.helidon.common.http.Http; import io.helidon.nima.common.tls.Tls; import io.helidon.nima.http.media.EntityWriter; -import io.helidon.nima.http.media.MediaContext; import io.helidon.nima.webclient.ClientConnection; import io.helidon.nima.webclient.WebClientServiceRequest; import io.helidon.nima.webclient.WebClientServiceResponse; class HttpCallEntityChain extends HttpCallChainBase { - private final MediaContext mediaContext; - private final int maxStatusLineLength; + private final Http1ClientConfig clientConfig; private final CompletableFuture whenSent; private final CompletableFuture whenComplete; private final Object entity; @@ -48,8 +46,7 @@ class HttpCallEntityChain extends HttpCallChainBase { CompletableFuture whenComplete, Object entity) { super(clientConfig, connection, tls); - this.mediaContext = clientConfig.mediaContext(); - this.maxStatusLineLength = clientConfig.maxStatusLineLength(); + this.clientConfig = clientConfig; this.whenSent = whenSent; this.whenComplete = whenComplete; this.entity = entity; @@ -71,8 +68,7 @@ public WebClientServiceResponse doProceed(ClientConnection connection, headers.set(Http.Header.create(Http.Header.CONTENT_LENGTH, entityBytes.length)); - // todo validate request headers - writeHeaders(headers, writeBuffer); + writeHeaders(headers, writeBuffer, clientConfig.validateHeaders()); // we have completed writing the headers whenSent.complete(serviceRequest); @@ -81,7 +77,7 @@ public WebClientServiceResponse doProceed(ClientConnection connection, } writer.write(writeBuffer); - Http.Status responseStatus = Http1StatusParser.readStatus(reader, maxStatusLineLength); + Http.Status responseStatus = Http1StatusParser.readStatus(reader, clientConfig.maxStatusLineLength()); ClientResponseHeaders responseHeaders = readHeaders(reader); return WebClientServiceResponse.builder() @@ -99,7 +95,7 @@ byte[] entityBytes(Object entity, ClientRequestHeaders headers) { return (byte[]) entity; } GenericType genericType = GenericType.create(entity); - EntityWriter writer = mediaContext.writer(genericType, headers); + EntityWriter writer = clientConfig.mediaContext().writer(genericType, headers); // todo this should use output stream of client, but that would require delaying header write // to first byte written diff --git a/nima/webclient/webclient/src/main/java/io/helidon/nima/webclient/http1/HttpCallOutputStreamChain.java b/nima/webclient/webclient/src/main/java/io/helidon/nima/webclient/http1/HttpCallOutputStreamChain.java index 2da84b4086f..5aa2f0830de 100644 --- a/nima/webclient/webclient/src/main/java/io/helidon/nima/webclient/http1/HttpCallOutputStreamChain.java +++ b/nima/webclient/webclient/src/main/java/io/helidon/nima/webclient/http1/HttpCallOutputStreamChain.java @@ -37,8 +37,7 @@ import io.helidon.nima.webclient.WebClientServiceResponse; class HttpCallOutputStreamChain extends HttpCallChainBase { - private final int maxStatusLineLength; - private final boolean sendExpect100Continue; + private final Http1ClientConfig clientConfig; private final CompletableFuture whenSent; private final CompletableFuture whenComplete; private final ClientRequest.OutputStreamHandler osHandler; @@ -50,8 +49,7 @@ class HttpCallOutputStreamChain extends HttpCallChainBase { CompletableFuture whenComplete, ClientRequest.OutputStreamHandler osHandler) { super(clientConfig, connection, tls); - this.maxStatusLineLength = clientConfig.maxStatusLineLength(); - this.sendExpect100Continue = clientConfig.sendExpectContinue(); + this.clientConfig = clientConfig; this.whenSent = whenSent; this.whenComplete = whenComplete; this.osHandler = osHandler; @@ -69,8 +67,7 @@ WebClientServiceResponse doProceed(ClientConnection connection, reader, writeBuffer, headers, - maxStatusLineLength, - sendExpect100Continue, + clientConfig, serviceRequest, whenSent); @@ -84,7 +81,7 @@ WebClientServiceResponse doProceed(ClientConnection connection, throw new IllegalStateException("Output stream was not closed in handler"); } - Http.Status responseStatus = Http1StatusParser.readStatus(reader, maxStatusLineLength); + Http.Status responseStatus = Http1StatusParser.readStatus(reader, clientConfig.maxStatusLineLength()); ClientResponseHeaders responseHeaders = readHeaders(reader); return WebClientServiceResponse.builder() @@ -105,8 +102,7 @@ private static class ClientConnectionOutputStream extends OutputStream { private final DataReader reader; private final WebClientServiceRequest request; private final CompletableFuture whenSent; - private final int maxStatusLineLength; - private final boolean sendExpect100Continue; + private final Http1ClientConfig clientConfig; private final WritableHeaders headers; private final BufferData prologue; @@ -121,16 +117,14 @@ private ClientConnectionOutputStream(DataWriter writer, DataReader reader, BufferData prologue, WritableHeaders headers, - int maxStatusLineLength, - boolean sendExpect100Continue, + Http1ClientConfig clientConfig, WebClientServiceRequest request, CompletableFuture whenSent) { this.writer = writer; this.reader = reader; this.headers = headers; this.prologue = prologue; - this.maxStatusLineLength = maxStatusLineLength; - this.sendExpect100Continue = sendExpect100Continue; + this.clientConfig = clientConfig; this.contentLength = headers.contentLength().orElse(-1); this.chunked = contentLength == -1 || headers.contains(Http.HeaderValues.TRANSFER_ENCODING_CHUNKED); this.request = request; @@ -229,7 +223,7 @@ private void writeContent(BufferData buffer) throws IOException { } private void sendPrologueAndHeader() { - boolean expects100Continue = sendExpect100Continue && chunked && !noData; + boolean expects100Continue = clientConfig.sendExpectContinue() && chunked && !noData; if (expects100Continue) { headers.add(Http.HeaderValues.EXPECT_100); } @@ -251,13 +245,13 @@ private void sendPrologueAndHeader() { // todo validate request headers BufferData headerBuffer = BufferData.growing(128); - writeHeaders(headers, headerBuffer); + writeHeaders(headers, headerBuffer, clientConfig.validateHeaders()); writer.writeNow(headerBuffer); whenSent.complete(request); if (expects100Continue) { - Http.Status responseStatus = Http1StatusParser.readStatus(reader, maxStatusLineLength); + Http.Status responseStatus = Http1StatusParser.readStatus(reader, clientConfig.maxStatusLineLength()); if (responseStatus != Http.Status.CONTINUE_100) { throw new IllegalStateException("Expected a status of '100 Continue' but received a '" + responseStatus + "' instead"); diff --git a/nima/webclient/webclient/src/test/java/io/helidon/nima/webclient/http1/ClientRequestImplTest.java b/nima/webclient/webclient/src/test/java/io/helidon/nima/webclient/http1/ClientRequestImplTest.java index f75c1cadc99..0fefb5fc9ca 100644 --- a/nima/webclient/webclient/src/test/java/io/helidon/nima/webclient/http1/ClientRequestImplTest.java +++ b/nima/webclient/webclient/src/test/java/io/helidon/nima/webclient/http1/ClientRequestImplTest.java @@ -19,6 +19,9 @@ import java.io.OutputStream; import java.net.URI; import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; import java.util.StringTokenizer; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.ExecutorService; @@ -41,10 +44,12 @@ import io.helidon.nima.webclient.ClientConnection; import io.helidon.nima.webclient.WebClient; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; import static io.helidon.common.testing.http.junit5.HttpHeaderMatcher.hasHeader; import static io.helidon.common.testing.http.junit5.HttpHeaderMatcher.noHeader; @@ -56,14 +61,19 @@ import static org.junit.jupiter.params.provider.Arguments.arguments; class ClientRequestImplTest { - private static final Http.HeaderValue REQ_CHUNKED_HEADER = Http.Header.createCached( + public static final String VALID_HEADER_VALUE = "Valid-Header-Value"; + public static final String VALID_HEADER_NAME = "Valid-Header-Name"; + public static final String BAD_HEADER_PATH = "/badHeader"; + private static final Http.HeaderValue REQ_CHUNKED_HEADER = Http.Header.create( Http.Header.create("X-Req-Chunked"), "true"); - private static final Http.HeaderValue REQ_EXPECT_100_HEADER_NAME = Http.Header.createCached( + private static final Http.HeaderValue REQ_EXPECT_100_HEADER_NAME = Http.Header.create( Http.Header.create("X-Req-Expect100"), "true"); private static final Http.HeaderName REQ_CONTENT_LENGTH_HEADER_NAME = Http.Header.create("X-Req-ContentLength"); + private static final Http.HeaderName BAD_HEADER_NAME = Http.Header.create("Bad-Header"); private static final long NO_CONTENT_LENGTH = -1L; private static final Http1Client client = WebClient.builder().build(); private static final int dummyPort = 1234; + public static final String HEADER_NAME_VALUE_DELIMETER = "->"; @Test void testMaxHeaderSizeFail() { @@ -223,6 +233,81 @@ void testRelativeUris(boolean relativeUris, boolean outputStream, String request assertThat(st.nextToken(), startsWith(expectedUriStart)); } + @ParameterizedTest + @MethodSource("headerValues") + void testHeaderValues(List headerValues, boolean expectsValid) { + Http1ClientRequest request = client.get("http://localhost:" + dummyPort + "/test"); + request.header(Http.Header.create(Http.Header.create("HeaderName"), headerValues)); + request.connection(new FakeHttp1ClientConnection()); + if (expectsValid) { + Http1ClientResponse response = request.request(); + assertThat(response.status(), is(Http.Status.OK_200)); + } else { + Assertions.assertThrows(IllegalArgumentException.class, () -> request.request()); + } + } + + @ParameterizedTest + @MethodSource("headers") + void testHeaders(Http.HeaderValue header, boolean expectsValid) { + Http1ClientRequest request = client.get("http://localhost:" + dummyPort + "/test"); + request.connection(new FakeHttp1ClientConnection()); + request.header(header); + if (expectsValid) { + Http1ClientResponse response = request.request(); + assertThat(response.status(), is(Http.Status.OK_200)); + } else { + Assertions.assertThrows(IllegalArgumentException.class, () -> request.request()); + } + } + + @ParameterizedTest + @MethodSource("headers") + void testDisableHeaderValidation(Http.HeaderValue header, boolean expectsValid) { + Http1Client clientWithNoHeaderValidation = WebClient.builder() + .validateHeaders(false) + .build(); + Http1ClientRequest request = clientWithNoHeaderValidation.put("http://localhost:" + dummyPort + "/test"); + request.header(header); + request.connection(new FakeHttp1ClientConnection()); + Http1ClientResponse response = request.submit("Sending Something"); + if (expectsValid) { + assertThat(response.status(), is(Http.Status.OK_200)); + } else { + assertThat(response.status(), is(Http.Status.BAD_REQUEST_400)); + } + } + + @ParameterizedTest + @MethodSource("responseHeaders") + void testHeadersFromResponse(String headerName, String headerValue, boolean expectsValid) { + Http1ClientRequest request = client.get("http://localhost:" + dummyPort + BAD_HEADER_PATH); + request.connection(new FakeHttp1ClientConnection()); + String headerNameAndValue = headerName + HEADER_NAME_VALUE_DELIMETER + headerValue; + if (expectsValid) { + Http1ClientResponse response = request.submit(headerNameAndValue); + assertThat(response.status(), is(Http.Status.OK_200)); + String responseHeaderValue = response.headers().get(Http.Header.create(headerName)).values(); + assertThat(responseHeaderValue, is(headerValue.trim())); + } else { + Assertions.assertThrows(IllegalArgumentException.class, () -> request.submit(headerNameAndValue)); + } + } + + @ParameterizedTest + @MethodSource("responseHeadersForDisabledValidation") + void testDisableValidationForHeadersFromResponse(String headerName, String headerValue) { + Http1Client clientWithNoHeaderValidation = WebClient.builder() + .validateHeaders(false) + .build(); + Http1ClientRequest request = clientWithNoHeaderValidation.put("http://localhost:" + dummyPort + BAD_HEADER_PATH); + request.connection(new FakeHttp1ClientConnection()); + Http1ClientResponse response = request.submit(headerName + HEADER_NAME_VALUE_DELIMETER + headerValue); + assertThat(response.status(), is(Http.Status.OK_200)); + String responseHeaderValue = response.headers().get(Http.Header.create(headerName)).values(); + assertThat(responseHeaderValue, is(headerValue.trim())); + } + private static void validateSuccessfulResponse(Http1Client client, ClientConnection connection) { String requestEntity = "Sending Something"; Http1ClientRequest request = client.put("http://localhost:" + dummyPort + "/test"); @@ -300,6 +385,101 @@ private static Stream relativeUris() { arguments(true, false, "https://www.dummy.com:1111/test", "/test")); } + private static Stream headerValues() { + return Stream.of( + // Valid header values + arguments(Arrays.asList("Header Value"), true), + arguments(Arrays.asList("HeaderValue1", "Header\u0080Value\u00ff2"), true), + arguments(Arrays.asList("HeaderName1\u0009", "Header=Value2"), true), + // Invalid header values + arguments(Arrays.asList(" HeaderValue"), false), + arguments(Arrays.asList("HeaderValue1", "Header\u007fValue"), false), + arguments(Arrays.asList("HeaderValue1\r\n", "HeaderValue2"), false) + ); + } + + private static Stream headers() { + return Stream.of( + // Valid headers + arguments(Http.HeaderValues.ACCEPT_RANGES_BYTES, true), + arguments(Http.HeaderValues.CONNECTION_KEEP_ALIVE, true), + arguments(Http.HeaderValues.CONTENT_TYPE_TEXT_PLAIN, true), + arguments(Http.HeaderValues.ACCEPT_TEXT, true), + arguments(Http.HeaderValues.CACHE_NO_CACHE, true), + arguments(Http.HeaderValues.TE_TRAILERS, true), + arguments(Http.Header.create(Http.Header.create("!#$Custom~%&\'*Header+^`|"), "!Header\tValue~"), true), + arguments(Http.Header.create(Http.Header.create("Custom_0-9_a-z_A-Z_Header"), + "\u0080Header Value\u00ff"), true), + // Invalid headers + arguments(Http.Header.create(Http.Header.create(VALID_HEADER_NAME), "H\u001ceaderValue1"), false), + arguments(Http.Header.create(Http.Header.create(VALID_HEADER_NAME), + "HeaderValue1, Header\u007fValue"), false), + arguments(Http.Header.create(Http.Header.create(VALID_HEADER_NAME), + "HeaderValue1\u001f, HeaderValue2"), false), + arguments(Http.Header.create(Http.Header.create("Header\u001aName"), VALID_HEADER_VALUE), false), + arguments(Http.Header.create(Http.Header.create("Header\u000EName"), VALID_HEADER_VALUE), false), + arguments(Http.Header.create(Http.Header.create("HeaderName\r\n"), VALID_HEADER_VALUE), false), + arguments(Http.Header.create(Http.Header.create("HeaderName\u00FF\u0124"), VALID_HEADER_VALUE), false), + arguments(Http.Header.create(Http.Header.create("(Header:Name)"), VALID_HEADER_VALUE), false), + arguments(Http.Header.create(Http.Header.create(""), VALID_HEADER_VALUE), false), + arguments(Http.Header.create(Http.Header.create("{Header=Name}"), VALID_HEADER_VALUE), false), + arguments(Http.Header.create(Http.Header.create("\"HeaderName\""), VALID_HEADER_VALUE), false), + arguments(Http.Header.create(Http.Header.create("[\\HeaderName]"), VALID_HEADER_VALUE), false), + arguments(Http.Header.create(Http.Header.create("@Header,Name;"), VALID_HEADER_VALUE), false) + ); + } + + private static Stream responseHeaders() { + return Stream.of( + // Invalid header names + arguments("Header\u001aName", VALID_HEADER_VALUE, false), + arguments("Header\u000EName", VALID_HEADER_VALUE, false), + arguments("HeaderName\r\n", VALID_HEADER_VALUE, false), + arguments("(Header:Name)", VALID_HEADER_VALUE, false), + arguments("", VALID_HEADER_VALUE, false), + arguments("{Header=Name}", VALID_HEADER_VALUE, false), + arguments("\"HeaderName\"", VALID_HEADER_VALUE, false), + arguments("[\\HeaderName]", VALID_HEADER_VALUE, false), + arguments("@Header,Name;", VALID_HEADER_VALUE, false), + // Valid header names + arguments("!#$Custom~%&\'*Header+^`|", VALID_HEADER_VALUE, true), + arguments("Custom_0-9_a-z_A-Z_Header", VALID_HEADER_VALUE, true), + // Valid header values + arguments(VALID_HEADER_NAME, "Header Value", true), + arguments(VALID_HEADER_NAME, "HeaderValue1\u0009, Header=Value2", true), + arguments(VALID_HEADER_NAME, "Header\tValue", true), + arguments(VALID_HEADER_NAME, " Header Value ", true), + // Invalid header values + arguments(VALID_HEADER_NAME, "H\u001ceaderValue1", false), + arguments(VALID_HEADER_NAME, "HeaderValue1, Header\u007fValue", false), + arguments(VALID_HEADER_NAME, "HeaderValue1\u001f, HeaderValue2", false) + ); + } + + private static Stream responseHeadersForDisabledValidation() { + return Stream.of( + // Invalid header names + arguments("Header\u001aName", VALID_HEADER_VALUE, false), + arguments("Header\u000EName", VALID_HEADER_VALUE, false), + arguments("{Header=Name}", VALID_HEADER_VALUE, false), + arguments("\"HeaderName\"", VALID_HEADER_VALUE, false), + arguments("[\\HeaderName]", VALID_HEADER_VALUE, false), + arguments("@Header,Name;", VALID_HEADER_VALUE, false), + // Valid header names + arguments("!#$Custom~%&\'*Header+^`|", VALID_HEADER_VALUE, true), + arguments("Custom_0-9_a-z_A-Z_Header", VALID_HEADER_VALUE, true), + // Valid header values + arguments(VALID_HEADER_NAME, "Header Value", true), + arguments(VALID_HEADER_NAME, "HeaderValue1\u0009, Header=Value2", true), + arguments(VALID_HEADER_NAME, "Header\tValue", true), + arguments(VALID_HEADER_NAME, " Header Value ", true), + // Invalid header values + arguments(VALID_HEADER_NAME, "H\u001ceaderValue1", false), + arguments(VALID_HEADER_NAME, "HeaderValue1, Header\u007fValue", false), + arguments(VALID_HEADER_NAME, "HeaderValue1\u001f, HeaderValue2", false) + ); + } + private static class FakeHttp1ClientConnection implements ClientConnection { private final DataReader clientReader; private final DataWriter clientWriter; @@ -423,53 +603,73 @@ private void webServerHandle() { serverReader.skip(2); // skip CRLF } + boolean requestFailed = false; // Read Headers - WritableHeaders reqHeaders = Http1HeadersParser.readHeaders(serverReader, 16384, true); + WritableHeaders reqHeaders = null; + try { + reqHeaders = Http1HeadersParser.readHeaders(serverReader, 16384, false); + for (Iterator it = reqHeaders.iterator(); it.hasNext(); ) { + Http.HeaderValue header = it.next(); + header.validate(); + } + } catch (IllegalArgumentException e) { + requestFailed = true; + } int entitySize = 0; - if (reqHeaders.contains(Http.HeaderValues.TRANSFER_ENCODING_CHUNKED)) { - // Send 100-Continue if requested - if (reqHeaders.contains(Http.HeaderValues.EXPECT_100)) { - serverWriter.write( - BufferData.create("HTTP/1.1 100 Continue\r\n".getBytes(StandardCharsets.UTF_8))); - } + if (!requestFailed) { + if (reqHeaders.contains(Http.HeaderValues.TRANSFER_ENCODING_CHUNKED)) { + // Send 100-Continue if requested + if (reqHeaders.contains(Http.HeaderValues.EXPECT_100)) { + serverWriter.write( + BufferData.create("HTTP/1.1 100 Continue\r\n".getBytes(StandardCharsets.UTF_8))); + } - // Assemble the entity from the chunks - while (true) { - String hex = serverReader.readLine(); - int chunkLength = Integer.parseUnsignedInt(hex, 16); - if (chunkLength == 0) { - serverReader.readLine(); - break; + // Assemble the entity from the chunks + while (true) { + String hex = serverReader.readLine(); + int chunkLength = Integer.parseUnsignedInt(hex, 16); + if (chunkLength == 0) { + serverReader.readLine(); + break; + } + BufferData chunkData = serverReader.readBuffer(chunkLength); + entity.write(chunkData); + serverReader.skip(2); + entitySize += chunkLength; + } + } else if (reqHeaders.contains(Http.Header.CONTENT_LENGTH)) { + entitySize = reqHeaders.get(Http.Header.CONTENT_LENGTH).value(int.class); + if (entitySize > 0) { + entity.write(serverReader.getBuffer(entitySize)); } - BufferData chunkData = serverReader.readBuffer(chunkLength); - entity.write(chunkData); - serverReader.skip(2); - entitySize += chunkLength; - } - } else if (reqHeaders.contains(Http.Header.CONTENT_LENGTH)) { - entitySize = reqHeaders.get(Http.Header.CONTENT_LENGTH).value(int.class); - if (entitySize > 0) { - entity.write(serverReader.getBuffer(entitySize)); } } WritableHeaders resHeaders = WritableHeaders.create(); resHeaders.add(Http.HeaderValues.CONNECTION_KEEP_ALIVE); - // Send headers that can be validated if Expect-100-Continue, Content_Length, and Chunked request headers exist - if (reqHeaders.contains(Http.HeaderValues.EXPECT_100)) { - resHeaders.set(REQ_EXPECT_100_HEADER_NAME); - } - if (reqHeaders.contains(Http.Header.CONTENT_LENGTH)) { - resHeaders.set(REQ_CONTENT_LENGTH_HEADER_NAME, reqHeaders.get(Http.Header.CONTENT_LENGTH).value()); + if (reqHeaders != null) { + // Send headers that can be validated if Expect-100-Continue, Content_Length, and Chunked request headers exist + if (reqHeaders.contains(Http.HeaderValues.EXPECT_100)) { + resHeaders.set(REQ_EXPECT_100_HEADER_NAME); + } + if (reqHeaders.contains(Http.Header.CONTENT_LENGTH)) { + resHeaders.set(REQ_CONTENT_LENGTH_HEADER_NAME, reqHeaders.get(Http.Header.CONTENT_LENGTH).value()); + } + if (reqHeaders.contains(Http.HeaderValues.TRANSFER_ENCODING_CHUNKED)) { + resHeaders.set(REQ_CHUNKED_HEADER); + } } - if (reqHeaders.contains(Http.HeaderValues.TRANSFER_ENCODING_CHUNKED)) { - resHeaders.set(REQ_CHUNKED_HEADER); + + // if prologue contains "/badHeader" path, send back the entity (name and value delimited by ->) as a header + if (getPrologue().contains(BAD_HEADER_PATH)) { + String[] header = entity.readString(entitySize, StandardCharsets.US_ASCII).split(HEADER_NAME_VALUE_DELIMETER); + resHeaders.add(Http.Header.create(Http.Header.create(header[0]), header[1])); } - // Send OK status response - serverWriter.write(BufferData.create("HTTP/1.1 200 OK\r\n".getBytes(StandardCharsets.UTF_8))); + String responseMessage = !requestFailed ? "HTTP/1.1 200 OK\r\n" : "HTTP/1.1 400 Bad Request\r\n"; + serverWriter.write(BufferData.create(responseMessage.getBytes(StandardCharsets.UTF_8))); // Send the headers resHeaders.add(Http.Header.CONTENT_LENGTH, Integer.toString(entitySize));