Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate http request and response headers when using the webclient #6515

Merged
merged 10 commits into from
Jun 28, 2023
Merged
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022 Oracle and/or its affiliates.
* Copyright (c) 2022, 2023 Oracle and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -29,6 +29,7 @@ public class LazyString {

private String stringValue;
private String owsLessValue;
private Validator validator;

/**
* New instance.
Expand Down Expand Up @@ -58,6 +59,15 @@ public LazyString(byte[] buffer, Charset charset) {
this.charset = charset;
}

/**
* Sets a custom validator for the LazyString Value when retrieved.
*
* @param validator custom validator implementation
*/
public void setValidator(Validator validator) {
this.validator = validator;
}

/**
* Strip optional whitespace(s) from beginning and end of the String.
* Defined by the HTTP specification, OWS is a sequence of zero to n space and/or horizontal tab characters.
Expand Down Expand Up @@ -87,6 +97,7 @@ public String stripOws() {
}
newLength = Math.max(newLength, 0);
owsLessValue = new String(buffer, newOffset, newLength, charset);
validate(owsLessValue);
}

return owsLessValue;
Expand All @@ -96,6 +107,7 @@ public String stripOws() {
public String toString() {
if (stringValue == null) {
stringValue = new String(buffer, offset, length, charset);
validate(stringValue);
}
return stringValue;
}
Expand All @@ -106,4 +118,23 @@ private boolean isOws(byte aByte) {
default -> false;
};
}

// Trigger validator only if it is set
private void validate(String value) {
if (validator != null) {
validator.validate(value);
}
}

/**
* Allows custom validator to be created.
*/
public interface Validator {
/**
* Validate the value and if it fails, then an implementation specific runtime exception may be thrown.
*
* @param value The value to validate.
*/
void validate(String value);
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022 Oracle and/or its affiliates.
* Copyright (c) 2022, 2023 Oracle and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,20 +18,48 @@

import java.util.stream.Stream;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

import static java.nio.charset.StandardCharsets.US_ASCII;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.params.provider.Arguments.arguments;

class LazyStringTest {
private static String INVALID_VALUE = "invalid-value";

@ParameterizedTest
@MethodSource("owsData")
void testOwsHandling(OwsTestData data) {
assertThat(data.string().stripOws(), is(data.expected()));
}

@ParameterizedTest
@MethodSource("valuesToValidate")
void testValidator(String value, boolean validate, boolean expectsValid) {
LazyString lazyString = new LazyString(value.getBytes(US_ASCII), US_ASCII);
// Supply custom validator when validate is true that will throw IllegalArgumentException if the value is invalid
if (validate) {
lazyString.setValidator(valueToValidate -> {
if (valueToValidate.contains(INVALID_VALUE)) {
throw new IllegalArgumentException("Found an invalid value");
}
});
}
// If there is no validator or validator does not encounter a problem, expect the value retrieval to succeed. Otherwise,
// expect that an IllegalArgumentException will be thrown.
if (expectsValid) {
assertThat(lazyString.stripOws(), is(value));
assertThat(lazyString.toString(), is(value));
} else {
Assertions.assertThrows(IllegalArgumentException.class, () -> lazyString.stripOws());
Assertions.assertThrows(IllegalArgumentException.class, () -> lazyString.toString());
}
}

private static Stream<OwsTestData> owsData() {
return Stream.of(
new OwsTestData(new LazyString("some-value".getBytes(US_ASCII), US_ASCII), "some-value"),
Expand All @@ -55,6 +83,20 @@ private static Stream<OwsTestData> owsData() {
);
}

private static Stream<Arguments> valuesToValidate() {
return Stream.of(
// Invalid value with validator set, expects that value retrieval will fail
arguments("first-" + INVALID_VALUE, true, false),
arguments(INVALID_VALUE + "-second", true, false),
// Valid value with validator set, expects that value retrieval will succeed
arguments("valid-third", true, true),
arguments("fourth-valid", true, true),
// Valid or Invalid value with no validator set, expects that value retrieval will succeed
arguments("valid-fifth", false, true),
arguments("sixth" + INVALID_VALUE, false, true)
);
}

record OwsTestData(LazyString string, String expected) {
@Override
public String toString() {
Expand Down
5 changes: 5 additions & 0 deletions common/http/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@
<artifactId>junit-jupiter-api</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<artifactId>junit-jupiter-params</artifactId>
<groupId>org.junit.jupiter</groupId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.hamcrest</groupId>
<artifactId>hamcrest-all</artifactId>
Expand Down
49 changes: 48 additions & 1 deletion common/http/src/main/java/io/helidon/common/http/Http.java
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,6 @@ public static Method create(String name) {
return method;
}


/**
* Create a predicate for the provided methods.
*
Expand Down Expand Up @@ -960,6 +959,54 @@ default void writeHttp1Header(BufferData buffer) {
}
}

// validate header value based on https://www.rfc-editor.org/rfc/rfc7230#section-3.2 and returns the error message if
// invalid or null if otherwise.
private static String validateValueErrorMessage(String value) {
char[] vChars = value.toCharArray();
int vLength = vChars.length;
for (int i = 0; i < vLength; i++) {
char vChar = vChars[i];
if (i == 0) {
if (vChar < '!' || vChar == '\u007f') {
return "First character of the header value is invalid";
}
} else {
if (vChar < ' ' && vChar != '\t' || vChar == '\u007f') {
return "Character at position " + (i + 1) + " of the header value is invalid";
}
}
}
return null;
}

/**
* Check validity of a header value.
*
* @param value header value
* @throws IllegalArgumentException in case the HeaderValue is not valid
*/
static void validate(String value) {
String errorMessage = validateValueErrorMessage(value);
if (errorMessage != null) {
throw new IllegalArgumentException(errorMessage);
}
}

/**
* Check validity of header name and values.
*
* @throws IllegalArgumentException in case the HeaderValue is not valid
*/
default void validate() throws IllegalArgumentException {
String name = name();
// validate that header name only contains valid characters
HttpToken.validate(name);
String errorMessage = validateValueErrorMessage(values());
if (errorMessage != null) {
throw new IllegalArgumentException(errorMessage + " for header '" + name + "'");
}
}

private void writeHeader(BufferData buffer, byte[] nameBytes, byte[] valueBytes) {
// header name
buffer.write(nameBytes);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ public final class Http1HeadersParser {

private static final byte[] HD_USER_AGENT =
(HeaderEnum.USER_AGENT.defaultCase() + ":").getBytes(StandardCharsets.UTF_8);
// Instantiates LazyString.Validator and overrides validate() to call Http.HeaderValue.validate(value) which will then be
// triggered when the Header value is retrieved.
private static final LazyString.Validator HTTP_LAZY_STRING_VALIDATOR = Http.HeaderValue::validate;

private Http1HeadersParser() {
}
Expand Down Expand Up @@ -62,8 +65,13 @@ public static WritableHeaders<?> readHeaders(DataReader reader, int maxHeadersSi
if (eol == maxLength) {
throw new IllegalStateException("Header size exceeded");
}
// we do not need the string until somebody asks for this header (unless validation is on)
// we do not need the string until somebody asks for this header
LazyString value = reader.readLazyString(StandardCharsets.US_ASCII, eol);
if (validate) {
// if validation is on, use the default http value validator which will only be triggered when the string
// value is retrieved
value.setValidator(HTTP_LAZY_STRING_VALIDATOR);
}
reader.skip(2);
maxLength -= eol + 1;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,25 @@
package io.helidon.common.http;

import java.nio.charset.StandardCharsets;
import java.util.stream.Stream;

import io.helidon.common.buffers.DataReader;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

import static org.hamcrest.CoreMatchers.hasItems;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.params.provider.Arguments.arguments;

class Http1HeadersParserTest {
private static String CUSTOM_HEADER_NAME = "Custom-Header-Name";
private static String CUSTOM_HEADER_VALUE = "Custom-Header-Value";

@Test
void testHeadersAreCaseInsensitive() {
DataReader reader = new DataReader(() -> (
Expand All @@ -44,6 +53,93 @@ void testHeadersAreCaseInsensitive() {
testHeader(headers, "HeADer", "hv1", "hv2", "hv3");
}

@ParameterizedTest
@MethodSource("headerValues")
void testHeaderValuesWithValidationEnabled(String headerValue, boolean expectsValid) {
// retrieve headers with validation enabled
WritableHeaders<?> headers = getHeaders(CUSTOM_HEADER_NAME, headerValue, true);
if (expectsValid) {
String responseHeaderValue = headers.get(Http.Header.create(CUSTOM_HEADER_NAME)).values();
// returned header values WhiteSpaces are trimmed so need to be tested with trimmed values
assertThat(responseHeaderValue, is(headerValue.trim()));
} else {
Assertions.assertThrows(IllegalArgumentException.class,
() -> headers.get(Http.Header.create(CUSTOM_HEADER_NAME)).values());
}
}

@ParameterizedTest
@MethodSource("headerValues")
void testHeaderValuesWithValidationDisabled(String headerValue) {
// retrieve headers without validating
WritableHeaders<?> headers = getHeaders(CUSTOM_HEADER_NAME, headerValue, false);
String responseHeaderValue = headers.get(Http.Header.create(CUSTOM_HEADER_NAME)).values();
// returned header values WhiteSpaces are trimmed so need to be tested with trimmed values
assertThat(responseHeaderValue, is(headerValue.trim()));
}

@ParameterizedTest
@MethodSource("headerNames")
void testHeaderNamesWithValidationEnabled(String headerName, boolean expectsValid) {
boolean validate = true;
if (expectsValid) {
WritableHeaders<?> headers = getHeaders(headerName, CUSTOM_HEADER_VALUE, validate);
String responseHeaderValue = headers.get(Http.Header.create(headerName)).values();
assertThat(responseHeaderValue, is(CUSTOM_HEADER_VALUE));
} else {
Assertions.assertThrows(IllegalArgumentException.class,
() -> getHeaders(headerName, CUSTOM_HEADER_VALUE, validate));
}
}

@ParameterizedTest
@MethodSource("headerValues")
void testHeaderNamesWithValidationDisabled(String headerName) {
// retrieve headers without validating
WritableHeaders<?> headers = getHeaders(headerName, CUSTOM_HEADER_VALUE, false);
String responseHeaderValue = headers.get(Http.Header.create(headerName)).values();
// returned header values WhiteSpaces are trimmed so need to be tested with trimmed values
assertThat(responseHeaderValue, is(CUSTOM_HEADER_VALUE));
}

private static WritableHeaders<?> getHeaders(String headerName, String headerValue, boolean validate) {
DataReader reader =
new DataReader(() -> (headerName + ":" + headerValue + "\r\n" + "\r\n").getBytes(StandardCharsets.US_ASCII));
return Http1HeadersParser.readHeaders(reader, 1024, validate);
}

private static Stream<Arguments> headerValues() {
return Stream.of(
// Valid header values
arguments("Header Value", true),
arguments("HeaderValue1\u0009, Header=Value2", true),
arguments("Header\tValue", true),
arguments(" Header Value ", true),
// Invalid header values
arguments("H\u001ceaderValue1", false),
arguments("HeaderValue1, Header\u007fValue", false),
arguments("HeaderValue1\u001f, HeaderValue2", false)
);
}

private static Stream<Arguments> headerNames() {
return Stream.of(
// Invalid header names
arguments("Header\u001aName", false),
arguments("Header\u000EName", false),
arguments("HeaderName\r\n", false),
arguments("(Header:Name)", false),
arguments("<Header?Name>", false),
arguments("{Header=Name}", false),
arguments("\"HeaderName\"", false),
arguments("[\\HeaderName]", false),
arguments("@Header,Name;", false),
// Valid header names
arguments("!#$Custom~%&\'*Header+^`|", true),
arguments("Custom_0-9_a-z_A-Z_Header", true)
);
}

private void testHeader(Headers headers, String header, String... values) {
Http.HeaderName headerName = Http.Header.create(header);
assertThat("Headers should contain header: " + headerName.lowerCase(),
Expand All @@ -53,4 +149,4 @@ private void testHeader(Headers headers, String header, String... values) {
headers.get(headerName).allValues(),
hasItems(values));
}
}
}
Loading