Skip to content

Commit

Permalink
Merge pull request from GHSA-35fr-h7jr-hh86
Browse files Browse the repository at this point in the history
Motivation:

An `HttpService` can produce a malformed HTTP response when a user
specified a malformed HTTP header values, such as:

    ResponseHeaders.of(HttpStatus.OK
                       "my-header", "foo\r\nbad-header: bar");

Modification:

- Add strict header value validation to `HttpHeadersBase`
- Add strict header name validation to `HttpHeaderNames.of()`, which is
  used by `HttpHeadersBase`.

Result:

- It is not possible anymore to send a bad header value which can be
  misused for sending additional headers or injecting arbitrary content.
  • Loading branch information
trustin committed Dec 5, 2019
1 parent 80310b3 commit b597f7a
Show file tree
Hide file tree
Showing 6 changed files with 395 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@

import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.BitSet;
import java.util.Map;

import com.google.common.base.Ascii;
import com.google.common.collect.ImmutableMap;
import com.google.common.math.IntMath;

import io.netty.util.AsciiString;

Expand Down Expand Up @@ -65,6 +67,35 @@ public final class HttpHeaderNames {
// - Sec-Fetch-User
// - Sec-Metadata

private static final int PROHIBITED_NAME_CHAR_MASK = ~63;
private static final BitSet PROHIBITED_NAME_CHARS = new BitSet(~PROHIBITED_NAME_CHAR_MASK + 1);
private static final String[] PROHIBITED_NAME_CHAR_NAMES = new String[~PROHIBITED_NAME_CHAR_MASK + 1];

static {
PROHIBITED_NAME_CHARS.set(0);
PROHIBITED_NAME_CHARS.set('\t');
PROHIBITED_NAME_CHARS.set('\n');
PROHIBITED_NAME_CHARS.set(0xB);
PROHIBITED_NAME_CHARS.set('\f');
PROHIBITED_NAME_CHARS.set('\r');
PROHIBITED_NAME_CHARS.set(' ');
PROHIBITED_NAME_CHARS.set(',');
PROHIBITED_NAME_CHARS.set(':');
PROHIBITED_NAME_CHARS.set(';');
PROHIBITED_NAME_CHARS.set('=');
PROHIBITED_NAME_CHAR_NAMES[0] = "<NUL>";
PROHIBITED_NAME_CHAR_NAMES['\t'] = "<TAB>";
PROHIBITED_NAME_CHAR_NAMES['\n'] = "<LF>";
PROHIBITED_NAME_CHAR_NAMES[0xB] = "<VT>";
PROHIBITED_NAME_CHAR_NAMES['\f'] = "<FF>";
PROHIBITED_NAME_CHAR_NAMES['\r'] = "<CR>";
PROHIBITED_NAME_CHAR_NAMES[' '] = "<SP>";
PROHIBITED_NAME_CHAR_NAMES[','] = ",";
PROHIBITED_NAME_CHAR_NAMES[':'] = ":";
PROHIBITED_NAME_CHAR_NAMES[';'] = ";";
PROHIBITED_NAME_CHAR_NAMES['='] = "=";
}

// Pseudo-headers

/**
Expand Down Expand Up @@ -564,10 +595,16 @@ public final class HttpHeaderNames {
map = builder.build();
}

private static AsciiString create(String name) {
return AsciiString.cached(Ascii.toLowerCase(name));
}

/**
* Lower-cases and converts the specified header name into an {@link AsciiString}. If {@code "name"} is
* a known header name, this method will return a pre-instantiated {@link AsciiString} to reduce
* the allocation rate of {@link AsciiString}.
*
* @throws IllegalArgumentException if the specified {@code name} is not a valid header name.
*/
public static AsciiString of(CharSequence name) {
if (name instanceof AsciiString) {
Expand All @@ -576,22 +613,71 @@ public static AsciiString of(CharSequence name) {

final String lowerCased = Ascii.toLowerCase(requireNonNull(name, "name"));
final AsciiString cached = map.get(lowerCased);
return cached != null ? cached : AsciiString.cached(lowerCased);
if (cached != null) {
return cached;
}

return validate(AsciiString.cached(lowerCased));
}

/**
* Lower-cases and converts the specified header name into an {@link AsciiString}. If {@code "name"} is
* a known header name, this method will return a pre-instantiated {@link AsciiString} to reduce
* the allocation rate of {@link AsciiString}.
*
* @throws IllegalArgumentException if the specified {@code name} is not a valid header name.
*/
public static AsciiString of(AsciiString name) {
final AsciiString lowerCased = name.toLowerCase();
final AsciiString cached = map.get(lowerCased);
return cached != null ? cached : lowerCased;
if (cached != null) {
return cached;
}

return validate(lowerCased);
}

private static AsciiString create(String name) {
return AsciiString.cached(Ascii.toLowerCase(name));
private static AsciiString validate(AsciiString name) {
if (name.isEmpty()) {
throw new IllegalArgumentException("malformed header name: <EMPTY>");
}

final int lastIndex;
try {
lastIndex = name.forEachByte(value -> {
if ((value & PROHIBITED_NAME_CHAR_MASK) != 0) { // value >= 64
return true;
}

// value < 64
return !PROHIBITED_NAME_CHARS.get(value);
});
} catch (Exception e) {
throw new Error(e);
}

if (lastIndex >= 0) {
throw new IllegalArgumentException(malformedHeaderNameMessage(name));
}

return name;
}

private static String malformedHeaderNameMessage(AsciiString name) {
final StringBuilder buf = new StringBuilder(IntMath.saturatedAdd(name.length(), 64));
buf.append("malformed header name: ");

final int nameLength = name.length();
for (int i = 0; i < nameLength; i++) {
final char ch = name.charAt(i);
if (PROHIBITED_NAME_CHARS.get(ch)) {
buf.append(PROHIBITED_NAME_CHAR_NAMES[ch]);
} else {
buf.append(ch);
}
}

return buf.toString();
}

private HttpHeaderNames() {}
Expand Down
78 changes: 62 additions & 16 deletions core/src/main/java/com/linecorp/armeria/common/HttpHeadersBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
*/
package com.linecorp.armeria.common;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.linecorp.armeria.internal.ArmeriaHttpUtil.isAbsoluteUri;
import static io.netty.handler.codec.http2.Http2Headers.PseudoHeaderName.hasPseudoHeaderFormat;
Expand All @@ -41,6 +40,7 @@
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Date;
import java.util.Iterator;
import java.util.List;
Expand All @@ -58,6 +58,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterators;
import com.google.common.math.IntMath;

import io.netty.handler.codec.DateFormatter;
import io.netty.util.AsciiString;
Expand All @@ -67,6 +68,23 @@
*/
class HttpHeadersBase implements HttpHeaderGetters {

private static final int PROHIBITED_VALUE_CHAR_MASK = ~15;
private static final BitSet PROHIBITED_VALUE_CHARS = new BitSet(~PROHIBITED_VALUE_CHAR_MASK + 1);
private static final String[] PROHIBITED_VALUE_CHAR_NAMES = new String[~PROHIBITED_VALUE_CHAR_MASK + 1];

static {
PROHIBITED_VALUE_CHARS.set(0);
PROHIBITED_VALUE_CHARS.set('\n');
PROHIBITED_VALUE_CHARS.set(0xB);
PROHIBITED_VALUE_CHARS.set('\f');
PROHIBITED_VALUE_CHARS.set('\r');
PROHIBITED_VALUE_CHAR_NAMES[0] = "<NUL>";
PROHIBITED_VALUE_CHAR_NAMES['\n'] = "<LF>";
PROHIBITED_VALUE_CHAR_NAMES[0xB] = "<VT>";
PROHIBITED_VALUE_CHAR_NAMES['\f'] = "<FF>";
PROHIBITED_VALUE_CHAR_NAMES['\r'] = "<CR>";
}

static final int DEFAULT_SIZE_HINT = 16;

/**
Expand Down Expand Up @@ -545,15 +563,15 @@ final long getTimeMillisAndRemove(CharSequence name, long defaultValue) {
}

final void add(CharSequence name, String value) {
final AsciiString normalizedName = normalizeName(name);
final AsciiString normalizedName = HttpHeaderNames.of(name);
requireNonNull(value, "value");
final int h = normalizedName.hashCode();
final int i = index(h);
add0(h, i, normalizedName, value);
}

final void add(CharSequence name, Iterable<String> values) {
final AsciiString normalizedName = normalizeName(name);
final AsciiString normalizedName = HttpHeaderNames.of(name);
requireNonNull(values, "values");
final int h = normalizedName.hashCode();
final int i = index(h);
Expand All @@ -564,7 +582,7 @@ final void add(CharSequence name, Iterable<String> values) {
}

final void add(CharSequence name, String... values) {
final AsciiString normalizedName = normalizeName(name);
final AsciiString normalizedName = HttpHeaderNames.of(name);
requireNonNull(values, "values");
final int h = normalizedName.hashCode();
final int i = index(h);
Expand All @@ -590,7 +608,7 @@ final void addObject(CharSequence name, Object value) {
}

final void addObject(CharSequence name, Iterable<?> values) {
final AsciiString normalizedName = normalizeName(name);
final AsciiString normalizedName = HttpHeaderNames.of(name);
requireNonNull(values, "values");
for (Object v : values) {
requireNonNullElement(values, v);
Expand All @@ -599,7 +617,7 @@ final void addObject(CharSequence name, Iterable<?> values) {
}

final void addObject(CharSequence name, Object... values) {
final AsciiString normalizedName = normalizeName(name);
final AsciiString normalizedName = HttpHeaderNames.of(name);
requireNonNull(values, "values");
for (Object v : values) {
requireNonNullElement(values, v);
Expand Down Expand Up @@ -638,7 +656,7 @@ final void addTimeMillis(CharSequence name, long value) {
}

final void set(CharSequence name, String value) {
final AsciiString normalizedName = normalizeName(name);
final AsciiString normalizedName = HttpHeaderNames.of(name);
requireNonNull(value, "value");
final int h = normalizedName.hashCode();
final int i = index(h);
Expand All @@ -647,7 +665,7 @@ final void set(CharSequence name, String value) {
}

final void set(CharSequence name, Iterable<String> values) {
final AsciiString normalizedName = normalizeName(name);
final AsciiString normalizedName = HttpHeaderNames.of(name);
requireNonNull(values, "values");

final int h = normalizedName.hashCode();
Expand All @@ -661,7 +679,7 @@ final void set(CharSequence name, Iterable<String> values) {
}

final void set(CharSequence name, String... values) {
final AsciiString normalizedName = normalizeName(name);
final AsciiString normalizedName = HttpHeaderNames.of(name);
requireNonNull(values, "values");

final int h = normalizedName.hashCode();
Expand Down Expand Up @@ -739,7 +757,7 @@ final void setObject(CharSequence name, Object value) {
}

final void setObject(CharSequence name, Iterable<?> values) {
final AsciiString normalizedName = normalizeName(name);
final AsciiString normalizedName = HttpHeaderNames.of(name);
requireNonNull(values, "values");

final int h = normalizedName.hashCode();
Expand All @@ -753,7 +771,7 @@ final void setObject(CharSequence name, Iterable<?> values) {
}

final void setObject(CharSequence name, Object... values) {
final AsciiString normalizedName = normalizeName(name);
final AsciiString normalizedName = HttpHeaderNames.of(name);
requireNonNull(values, "values");

final int h = normalizedName.hashCode();
Expand Down Expand Up @@ -813,11 +831,6 @@ final void clear() {
size = 0;
}

private static AsciiString normalizeName(CharSequence name) {
checkArgument(requireNonNull(name, "name").length() > 0, "name is empty.");
return HttpHeaderNames.of(name);
}

private static void requireNonNullElement(Object values, @Nullable Object e) {
if (e == null) {
throw new NullPointerException("values contains null: " + values);
Expand All @@ -829,11 +842,44 @@ private int index(int hash) {
}

private void add0(int h, int i, AsciiString name, String value) {
validateValue(value);
// Update the hash table.
entries[i] = new HeaderEntry(h, name, value, entries[i]);
++size;
}

private static void validateValue(String value) {
final int valueLength = value.length();
for (int i = 0; i < valueLength; i++) {
final char ch = value.charAt(i);
if ((ch & PROHIBITED_VALUE_CHAR_MASK) != 0) { // ch >= 16
continue;
}

// ch < 16
if (PROHIBITED_VALUE_CHARS.get(ch)) {
throw new IllegalArgumentException(malformedHeaderValueMessage(value));
}
}
}

private static String malformedHeaderValueMessage(String value) {
final StringBuilder buf = new StringBuilder(IntMath.saturatedAdd(value.length(), 64));
buf.append("malformed header value: ");

final int valueLength = value.length();
for (int i = 0; i < valueLength; i++) {
final char ch = value.charAt(i);
if (PROHIBITED_VALUE_CHARS.get(ch)) {
buf.append(PROHIBITED_VALUE_CHAR_NAMES[ch]);
} else {
buf.append(ch);
}
}

return buf.toString();
}

private boolean addFast(Iterable<? extends Entry<? extends CharSequence, ?>> headers) {
if (!(headers instanceof HttpHeadersBase)) {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,7 @@ private static CharSequenceMap toLowercaseMap(Iterator<? extends CharSequence> v
final CharSequenceMap result = new CharSequenceMap(arraySizeHint);

while (valuesIter.hasNext()) {
final AsciiString lowerCased = HttpHeaderNames.of(valuesIter.next()).toLowerCase();
final AsciiString lowerCased = AsciiString.of(valuesIter.next()).toLowerCase();
try {
int index = lowerCased.forEachByte(FIND_COMMA);
if (index != -1) {
Expand Down
Loading

0 comments on commit b597f7a

Please sign in to comment.