Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate http request and response headers when using the webclient #6515

Merged
merged 10 commits into from
Jun 28, 2023
Merged
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022 Oracle and/or its affiliates.
* Copyright (c) 2022, 2023 Oracle and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022 Oracle and/or its affiliates.
* Copyright (c) 2022, 2023 Oracle and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down
5 changes: 5 additions & 0 deletions common/http/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@
<artifactId>junit-jupiter-api</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<artifactId>junit-jupiter-params</artifactId>
<groupId>org.junit.jupiter</groupId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.hamcrest</groupId>
<artifactId>hamcrest-all</artifactId>
Expand Down
36 changes: 35 additions & 1 deletion common/http/src/main/java/io/helidon/common/http/Http.java
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,6 @@ public static Method create(String name) {
return method;
}


/**
* Create a predicate for the provided methods.
*
Expand Down Expand Up @@ -960,6 +959,41 @@ default void writeHttp1Header(BufferData buffer) {
}
}

/**
* Check validity of header name and values.
*
* @throws IllegalArgumentException in case the HeaderValue is not valid
*/
default void validate() throws IllegalArgumentException {
String name = name();
// validate that header name only contains valid characters
HttpToken.validate(name);
// Validate header value
validateValue(name, values());
}


// validate header value based on https://www.rfc-editor.org/rfc/rfc7230#section-3.2 and throws IllegalArgumentException
// if invalid.
private static void validateValue(String name, String value) throws IllegalArgumentException {
char[] vChars = value.toCharArray();
int vLength = vChars.length;
for (int i = 0; i < vLength; i++) {
char vChar = vChars[i];
if (i == 0) {
if (vChar < '!' || vChar == '\u007f') {
throw new IllegalArgumentException("First character of the header value is invalid"
+ " for header '" + name + "'");
}
} else {
if (vChar < ' ' && vChar != '\t' || vChar == '\u007f') {
throw new IllegalArgumentException("Character at position " + (i + 1) + " of the header value is invalid"
+ " for header '" + name + "'");
}
}
}
}

private void writeHeader(BufferData buffer, byte[] nameBytes, byte[] valueBytes) {
// header name
buffer.write(nameBytes);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@ public static WritableHeaders<?> readHeaders(DataReader reader, int maxHeadersSi
reader.skip(2);
maxLength -= eol + 1;

headers.add(Http.Header.create(header, value));
Http.HeaderValue headerValue = Http.Header.create(header, value);
headers.add(headerValue);
if (validate) {
headerValue.validate();
}
if (maxLength < 0) {
throw new IllegalStateException("Header size exceeded");
}
Expand Down Expand Up @@ -113,9 +117,6 @@ private static Http.HeaderName readHeaderName(DataReader reader,
}

String headerName = reader.readAsciiString(col);
if (validate) {
HttpToken.validate(headerName);
}
Http.HeaderName header = Http.Header.create(headerName);
reader.skip(1); // skip the colon character

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,27 @@
package io.helidon.common.http;

import java.nio.charset.StandardCharsets;
import java.util.stream.Stream;

import io.helidon.common.buffers.DataReader;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

import static org.hamcrest.CoreMatchers.hasItems;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.params.provider.Arguments.arguments;

class Http1HeadersParserTest {
private static String CUSTOM_HEADER_NAME = "Custom-Header-Name";
private static String CUSTOM_HEADER_VALUE = "Custom-Header-Value";
public static final String VALID_HEADER_VALUE = "Valid-Header-Value";
public static final String VALID_HEADER_NAME = "Valid-Header-Name";

@Test
void testHeadersAreCaseInsensitive() {
DataReader reader = new DataReader(() -> (
Expand All @@ -44,6 +55,66 @@ void testHeadersAreCaseInsensitive() {
testHeader(headers, "HeADer", "hv1", "hv2", "hv3");
}

@ParameterizedTest
@MethodSource("headers")
void testHeadersWithValidationEnabled(String headerName, String headerValue, boolean expectsValid) {
// retrieve headers with validation enabled
WritableHeaders<?> headers;
if (expectsValid) {
headers = getHeaders(headerName, headerValue, true);
String responseHeaderValue = headers.get(Http.Header.create(headerName)).values();
// returned header values WhiteSpaces are trimmed so need to be tested with trimmed values
assertThat(responseHeaderValue, is(headerValue.trim()));
} else {
Assertions.assertThrows(IllegalArgumentException.class,
() -> getHeaders(headerName, headerValue, true));
}
}

@ParameterizedTest
@MethodSource("headers")
void testHeadersWithValidationDisabled(String headerValue) {
// retrieve headers without validating
WritableHeaders<?> headers = getHeaders(CUSTOM_HEADER_NAME, headerValue, false);
String responseHeaderValue = headers.get(Http.Header.create(CUSTOM_HEADER_NAME)).values();
// returned header values WhiteSpaces are trimmed so need to be tested with trimmed values
assertThat(responseHeaderValue, is(headerValue.trim()));
}

private static WritableHeaders<?> getHeaders(String headerName, String headerValue, boolean validate) {
DataReader reader =
new DataReader(() -> (headerName + ":" + headerValue + "\r\n" + "\r\n").getBytes(StandardCharsets.US_ASCII));
return Http1HeadersParser.readHeaders(reader, 1024, validate);
}

private static Stream<Arguments> headers() {
return Stream.of(
// Invalid header names
arguments("Header\u001aName", VALID_HEADER_VALUE, false),
arguments("Header\u000EName", VALID_HEADER_VALUE, false),
arguments("HeaderName\r\n", VALID_HEADER_VALUE, false),
arguments("(Header:Name)", VALID_HEADER_VALUE, false),
arguments("<Header?Name>", VALID_HEADER_VALUE, false),
arguments("{Header=Name}", VALID_HEADER_VALUE, false),
arguments("\"HeaderName\"", VALID_HEADER_VALUE, false),
arguments("[\\HeaderName]", VALID_HEADER_VALUE, false),
arguments("@Header,Name;", VALID_HEADER_VALUE, false),
// Valid header names
arguments("!#$Custom~%&\'*Header+^`|", VALID_HEADER_VALUE, true),
arguments("Custom_0-9_a-z_A-Z_Header", VALID_HEADER_VALUE, true),
// Valid header values
arguments(VALID_HEADER_NAME, "Header Value", true),
arguments(VALID_HEADER_NAME, "HeaderValue1\u0009, Header=Value2", true),
arguments(VALID_HEADER_NAME, "Header\tValue", true),
arguments(VALID_HEADER_NAME, " Header Value ", true),
// Invalid header values
arguments(VALID_HEADER_NAME, "H\u001ceaderValue1", false),
arguments(VALID_HEADER_NAME, "HeaderValue1, Header\u007fValue", false),
arguments(VALID_HEADER_NAME, "HeaderValue1\u001f, HeaderValue2", false)
);
}


private void testHeader(Headers headers, String header, String... values) {
Http.HeaderName headerName = Http.Header.create(header);
assertThat("Headers should contain header: " + headerName.lowerCase(),
Expand All @@ -53,4 +124,4 @@ private void testHeader(Headers headers, String header, String... values) {
headers.get(headerName).allValues(),
hasItems(values));
}
}
}
41 changes: 40 additions & 1 deletion common/http/src/test/java/io/helidon/common/http/HttpTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018, 2022 Oracle and/or its affiliates.
* Copyright (c) 2018, 2023 Oracle and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,13 +16,20 @@

package io.helidon.common.http;

import java.util.stream.Stream;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

import static io.helidon.common.http.Http.Status.TEMPORARY_REDIRECT_307;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.CoreMatchers.sameInstance;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.params.provider.Arguments.arguments;

/**
* Unit test for {@link Http}.
Expand Down Expand Up @@ -50,4 +57,36 @@ void testResposneStatusCustomReason() {
assertThat(rs.code(), is(TEMPORARY_REDIRECT_307.code()));
assertThat(rs.family(), is(TEMPORARY_REDIRECT_307.family()));
}

@ParameterizedTest
@MethodSource("headers")
void testHeaderValidation(String headerName, String headerValues, boolean expectsValid) {
Http.HeaderValue header = Http.Header.create(Http.Header.create(headerName), headerValues);
if (expectsValid) {
header.validate();
} else {
Assertions.assertThrows(IllegalArgumentException.class, () -> header.validate());
}
}

private static Stream<Arguments> headers() {
return Stream.of(
// Valid headers
arguments("!#$Custom~%&\'*Header+^`|", "!Header\tValue~", true),
arguments("Custom_0-9_a-z_A-Z_Header", "\u0080Header Value\u00ff", true),
// Invalid headers
arguments("Valid-Header-Name", "H\u001ceaderValue1", false),
arguments("Valid-Header-Name", "HeaderValue1, Header\u007fValue", false),
arguments("Valid-Header-Name", "HeaderValue1\u001f, HeaderValue2", false),
arguments("Header\u001aName", "Valid-Header-Value", false),
arguments("Header\u000EName", "Valid-Header-Value", false),
arguments("HeaderName\r\n", "Valid-Header-Value", false),
arguments("(Header:Name)", "Valid-Header-Value", false),
arguments("<Header?Name>", "Valid-Header-Value", false),
arguments("{Header=Name}", "Valid-Header-Value", false),
arguments("\"HeaderName\"", "Valid-Header-Value", false),
arguments("[\\HeaderName]", "Valid-Header-Value", false),
arguments("@Header,Name;", "Valid-Header-Value", false)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,11 @@ abstract class HttpCallChainBase implements WebClientService.Chain {
this.tls = tls;
}

static void writeHeaders(Headers headers, BufferData bufferData) {
static void writeHeaders(Headers headers, BufferData bufferData, boolean validate) {
for (Http.HeaderValue header : headers) {
if (validate) {
header.validate();
}
header.writeHttp1Header(bufferData);
}
bufferData.write(Bytes.CR_BYTE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,13 @@
import io.helidon.common.http.Http;
import io.helidon.nima.common.tls.Tls;
import io.helidon.nima.http.media.EntityWriter;
import io.helidon.nima.http.media.MediaContext;
import io.helidon.nima.webclient.ClientConnection;
import io.helidon.nima.webclient.WebClientServiceRequest;
import io.helidon.nima.webclient.WebClientServiceResponse;

class HttpCallEntityChain extends HttpCallChainBase {

private final MediaContext mediaContext;
private final int maxStatusLineLength;
private final Http1ClientConfig clientConfig;
private final CompletableFuture<WebClientServiceRequest> whenSent;
private final CompletableFuture<WebClientServiceResponse> whenComplete;
private final Object entity;
Expand All @@ -48,8 +46,7 @@ class HttpCallEntityChain extends HttpCallChainBase {
CompletableFuture<WebClientServiceResponse> whenComplete,
Object entity) {
super(clientConfig, connection, tls);
this.mediaContext = clientConfig.mediaContext();
this.maxStatusLineLength = clientConfig.maxStatusLineLength();
this.clientConfig = clientConfig;
this.whenSent = whenSent;
this.whenComplete = whenComplete;
this.entity = entity;
Expand All @@ -71,8 +68,7 @@ public WebClientServiceResponse doProceed(ClientConnection connection,

headers.set(Http.Header.create(Http.Header.CONTENT_LENGTH, entityBytes.length));

// todo validate request headers
writeHeaders(headers, writeBuffer);
writeHeaders(headers, writeBuffer, clientConfig.validateHeaders());
// we have completed writing the headers
whenSent.complete(serviceRequest);

Expand All @@ -81,7 +77,7 @@ public WebClientServiceResponse doProceed(ClientConnection connection,
}
writer.write(writeBuffer);

Http.Status responseStatus = Http1StatusParser.readStatus(reader, maxStatusLineLength);
Http.Status responseStatus = Http1StatusParser.readStatus(reader, clientConfig.maxStatusLineLength());
ClientResponseHeaders responseHeaders = readHeaders(reader);

return WebClientServiceResponse.builder()
Expand All @@ -99,7 +95,7 @@ byte[] entityBytes(Object entity, ClientRequestHeaders headers) {
return (byte[]) entity;
}
GenericType<Object> genericType = GenericType.create(entity);
EntityWriter<Object> writer = mediaContext.writer(genericType, headers);
EntityWriter<Object> writer = clientConfig.mediaContext().writer(genericType, headers);

// todo this should use output stream of client, but that would require delaying header write
// to first byte written
Expand Down