Add mechanism for early http header validation (#92220)
This introduces a way to validate HTTP headers prior to reading
the request body.

Co-authored-by: Albert Zaharovits <albert.zaharovits@elastic.co>
Tim-Brooks and albertzaharovits committed Mar 31, 2023
1 parent d5e93a8 commit 2ac20c0
Showing 9 changed files with 890 additions and 25 deletions.
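For context, the validator wired in by this commit is a TriConsumer<HttpRequest, Channel, ActionListener<Void>>: it receives the decoded request head and the channel, and completes the listener to accept or reject the request before any body content is processed. The sketch below is purely illustrative and not part of this commit (the class and constant names are hypothetical); it shows what a custom validator plugged into this mechanism might look like, rejecting requests that lack an Authorization header.

import io.netty.channel.Channel;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpRequest;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.TriConsumer;

// Hypothetical example validator, not part of this commit.
public class ExampleHeaderValidator {

    // Accepts the request only if an Authorization header is present; otherwise fails the
    // listener, which causes Netty4HttpHeaderValidator to drop the body and forward the head
    // downstream with a DecoderResult failure.
    public static final TriConsumer<HttpRequest, Channel, ActionListener<Void>> REQUIRE_AUTH_HEADER = (
        httpRequest,
        channel,
        listener
    ) -> {
        if (httpRequest.headers().contains(HttpHeaderNames.AUTHORIZATION)) {
            listener.onResponse(null);
        } else {
            listener.onFailure(new RuntimeException("missing Authorization header"));
        }
    };
}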
@@ -0,0 +1,238 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.http.netty4;

import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.DecoderResult;
import io.netty.handler.codec.http.HttpContent;
import io.netty.handler.codec.http.HttpObject;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.util.ReferenceCountUtil;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.TriConsumer;

import java.util.ArrayDeque;

import static org.elasticsearch.http.netty4.Netty4HttpHeaderValidator.State.DROPPING_DATA_PERMANENTLY;
import static org.elasticsearch.http.netty4.Netty4HttpHeaderValidator.State.DROPPING_DATA_UNTIL_NEXT_REQUEST;
import static org.elasticsearch.http.netty4.Netty4HttpHeaderValidator.State.FORWARDING_DATA_UNTIL_NEXT_REQUEST;
import static org.elasticsearch.http.netty4.Netty4HttpHeaderValidator.State.QUEUEING_DATA;
import static org.elasticsearch.http.netty4.Netty4HttpHeaderValidator.State.WAITING_TO_START;

public class Netty4HttpHeaderValidator extends ChannelInboundHandlerAdapter {

    public static final TriConsumer<HttpRequest, Channel, ActionListener<Void>> NOOP_VALIDATOR = ((
        httpRequest,
        channel,
        listener) -> listener.onResponse(null));

    private final TriConsumer<HttpRequest, Channel, ActionListener<Void>> validator;
    private ArrayDeque<HttpObject> pending = new ArrayDeque<>(4);
    private State state = WAITING_TO_START;

    public Netty4HttpHeaderValidator(TriConsumer<HttpRequest, Channel, ActionListener<Void>> validator) {
        this.validator = validator;
    }

    State getState() {
        return state;
    }

    @SuppressWarnings("fallthrough")
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        assert msg instanceof HttpObject;
        final HttpObject httpObject = (HttpObject) msg;

        switch (state) {
            case WAITING_TO_START:
                assert pending.isEmpty();
                pending.add(ReferenceCountUtil.retain(httpObject));
                requestStart(ctx);
                assert state == QUEUEING_DATA;
                break;
            case QUEUEING_DATA:
                pending.add(ReferenceCountUtil.retain(httpObject));
                break;
            case FORWARDING_DATA_UNTIL_NEXT_REQUEST:
                assert pending.isEmpty();
                if (httpObject instanceof LastHttpContent) {
                    state = WAITING_TO_START;
                }
                ctx.fireChannelRead(httpObject);
                break;
            case DROPPING_DATA_UNTIL_NEXT_REQUEST:
                assert pending.isEmpty();
                if (httpObject instanceof LastHttpContent) {
                    state = WAITING_TO_START;
                }
                // fall-through
            case DROPPING_DATA_PERMANENTLY:
                assert pending.isEmpty();
                ReferenceCountUtil.release(httpObject); // consume without enqueuing
                break;
        }

        setAutoReadForState(ctx, state);
    }

    private void requestStart(ChannelHandlerContext ctx) {
        assert state == WAITING_TO_START;

        if (pending.isEmpty()) {
            return;
        }

        final HttpObject httpObject = pending.getFirst();
        final HttpRequest httpRequest;
        if (httpObject instanceof HttpRequest && httpObject.decoderResult().isSuccess()) {
            // a properly decoded HTTP start message is expected to begin validation
            // anything else is probably an error that the downstream HTTP message aggregator will have to handle
            httpRequest = (HttpRequest) httpObject;
        } else {
            httpRequest = null;
        }

        state = QUEUEING_DATA;

        if (httpRequest == null) {
            // this looks like a malformed request and will forward without validation
            ctx.channel().eventLoop().submit(() -> forwardFullRequest(ctx));
        } else {
            validator.apply(httpRequest, ctx.channel(), new ActionListener<>() {
                @Override
                public void onResponse(Void unused) {
                    // Always use "Submit" to prevent reentrancy concerns if we are still on event loop
                    ctx.channel().eventLoop().submit(() -> forwardFullRequest(ctx));
                }

                @Override
                public void onFailure(Exception e) {
                    // Always use "Submit" to prevent reentrancy concerns if we are still on event loop
                    ctx.channel().eventLoop().submit(() -> forwardRequestWithDecoderExceptionAndNoContent(ctx, e));
                }
            });
        }
    }

    private void forwardFullRequest(ChannelHandlerContext ctx) {
        assert ctx.channel().eventLoop().inEventLoop();
        assert ctx.channel().config().isAutoRead() == false;
        assert state == QUEUEING_DATA;

        boolean fullRequestForwarded = forwardData(ctx);

        assert fullRequestForwarded || pending.isEmpty();
        if (fullRequestForwarded) {
            state = WAITING_TO_START;
            requestStart(ctx);
        } else {
            state = FORWARDING_DATA_UNTIL_NEXT_REQUEST;
        }

        assert state == WAITING_TO_START || state == QUEUEING_DATA || state == FORWARDING_DATA_UNTIL_NEXT_REQUEST;
        setAutoReadForState(ctx, state);
    }

    private void forwardRequestWithDecoderExceptionAndNoContent(ChannelHandlerContext ctx, Exception e) {
        assert ctx.channel().eventLoop().inEventLoop();
        assert ctx.channel().config().isAutoRead() == false;
        assert state == QUEUEING_DATA;

        HttpObject messageToForward = pending.getFirst();
        boolean fullRequestDropped = dropData();
        if (messageToForward instanceof HttpContent toReplace) {
            // if the request to forward contained data (which got dropped), replace with empty data
            messageToForward = toReplace.replace(Unpooled.EMPTY_BUFFER);
        }
        messageToForward.setDecoderResult(DecoderResult.failure(e));
        ctx.fireChannelRead(messageToForward);

        assert fullRequestDropped || pending.isEmpty();
        if (fullRequestDropped) {
            state = WAITING_TO_START;
            requestStart(ctx);
        } else {
            state = DROPPING_DATA_UNTIL_NEXT_REQUEST;
        }

        assert state == WAITING_TO_START || state == QUEUEING_DATA || state == DROPPING_DATA_UNTIL_NEXT_REQUEST;
        setAutoReadForState(ctx, state);
    }

    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        state = DROPPING_DATA_PERMANENTLY;
        while (true) {
            if (dropData() == false) {
                break;
            }
        }
        super.channelInactive(ctx);
    }

    private boolean forwardData(ChannelHandlerContext ctx) {
        final int pendingMessages = pending.size();
        try {
            HttpObject toForward;
            while ((toForward = pending.poll()) != null) {
                ctx.fireChannelRead(toForward);
                ReferenceCountUtil.release(toForward); // reference count was incremented when enqueued
                if (toForward instanceof LastHttpContent) {
                    return true;
                }
            }
            return false;
        } finally {
            maybeResizePendingDown(pendingMessages);
        }
    }

    private boolean dropData() {
        final int pendingMessages = pending.size();
        try {
            HttpObject toDrop;
            while ((toDrop = pending.poll()) != null) {
                ReferenceCountUtil.release(toDrop, 2); // 1 for enqueuing, 1 for consuming
                if (toDrop instanceof LastHttpContent) {
                    return true;
                }
            }
            return false;
        } finally {
            maybeResizePendingDown(pendingMessages);
        }
    }

    private void maybeResizePendingDown(int largeSize) {
        if (pending.size() <= 4 && largeSize > 32) {
            // Prevent the ArrayDeque from becoming forever large due to a single large message.
            // Replacing the field with a fresh copy shrinks the backing array; reassigning a local copy would not.
            pending = new ArrayDeque<>(pending);
        }
    }

    private static void setAutoReadForState(ChannelHandlerContext ctx, State state) {
        ctx.channel().config().setAutoRead((state == QUEUEING_DATA || state == DROPPING_DATA_PERMANENTLY) == false);
    }

    enum State {
        WAITING_TO_START,
        QUEUEING_DATA,
        FORWARDING_DATA_UNTIL_NEXT_REQUEST,
        DROPPING_DATA_UNTIL_NEXT_REQUEST,
        DROPPING_DATA_PERMANENTLY
    }
}
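As a rough illustration of the state machine above, here is a standalone sketch (not part of this commit) that drives the handler with Netty's EmbeddedChannel and the built-in NOOP_VALIDATOR; the class name and request URI are made up for the example.

import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpObject;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.LastHttpContent;

// Illustrative sketch only: exercise the queue-validate-forward cycle with the no-op validator.
public class Netty4HttpHeaderValidatorSketch {

    public static void main(String[] args) {
        EmbeddedChannel channel = new EmbeddedChannel(
            new Netty4HttpHeaderValidator(Netty4HttpHeaderValidator.NOOP_VALIDATOR)
        );

        // The request head arrives: it is queued while validation runs; the no-op validator accepts
        // immediately and the forwarding task is submitted to the event loop.
        channel.writeInbound(new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/_bulk"));
        channel.runPendingTasks();
        HttpObject head = channel.readInbound(); // the validated head, forwarded downstream

        // Until the last content piece arrives, further pieces are forwarded without queueing;
        // LastHttpContent resets the state machine so the next request is validated again.
        channel.writeInbound(LastHttpContent.EMPTY_LAST_CONTENT);
        HttpObject last = channel.readInbound();

        System.out.println("forwarded: " + head + " and " + last);
    }
}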
@@ -23,6 +23,7 @@
import io.netty.handler.codec.http.HttpContentCompressor;
import io.netty.handler.codec.http.HttpContentDecompressor;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpRequestDecoder;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseEncoder;
@@ -35,6 +36,8 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.TriConsumer;
import org.elasticsearch.common.network.CloseableChannel;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.settings.ClusterSettings;
@@ -143,6 +146,7 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport {
private final RecvByteBufAllocator recvByteBufAllocator;
private final TLSConfig tlsConfig;
private final AcceptChannelHandler.AcceptPredicate acceptChannelPredicate;
private final TriConsumer<HttpRequest, Channel, ActionListener<Void>> headerValidator;
private final int readTimeoutMillis;

private final int maxCompositeBufferComponents;
@@ -160,8 +164,8 @@ public Netty4HttpServerTransport(
SharedGroupFactory sharedGroupFactory,
Tracer tracer,
TLSConfig tlsConfig,
@Nullable AcceptChannelHandler.AcceptPredicate acceptChannelPredicate

@Nullable AcceptChannelHandler.AcceptPredicate acceptChannelPredicate,
@Nullable TriConsumer<HttpRequest, Channel, ActionListener<Void>> headerValidator
) {
super(
settings,
@@ -178,6 +182,7 @@ public Netty4HttpServerTransport(
this.sharedGroupFactory = sharedGroupFactory;
this.tlsConfig = tlsConfig;
this.acceptChannelPredicate = acceptChannelPredicate;
this.headerValidator = headerValidator;

this.pipeliningMaxEvents = SETTING_PIPELINING_MAX_EVENTS.get(settings);

@@ -323,7 +328,7 @@ public void onException(HttpChannel channel, Exception cause) {
}

public ChannelHandler configureServerChannelHandler() {
return new HttpChannelHandler(this, handlingSettings, tlsConfig, acceptChannelPredicate);
return new HttpChannelHandler(this, handlingSettings, tlsConfig, acceptChannelPredicate, headerValidator);
}

static final AttributeKey<Netty4HttpChannel> HTTP_CHANNEL_KEY = AttributeKey.newInstance("es-http-channel");
@@ -335,17 +340,20 @@ protected static class HttpChannelHandler extends ChannelInitializer<Channel> {
private final HttpHandlingSettings handlingSettings;
private final TLSConfig tlsConfig;
private final BiPredicate<String, InetSocketAddress> acceptChannelPredicate;
private final TriConsumer<HttpRequest, Channel, ActionListener<Void>> headerValidator;

protected HttpChannelHandler(
final Netty4HttpServerTransport transport,
final HttpHandlingSettings handlingSettings,
final TLSConfig tlsConfig,
@Nullable final BiPredicate<String, InetSocketAddress> acceptChannelPredicate
@Nullable final BiPredicate<String, InetSocketAddress> acceptChannelPredicate,
@Nullable final TriConsumer<HttpRequest, Channel, ActionListener<Void>> headerValidator
) {
this.transport = transport;
this.handlingSettings = handlingSettings;
this.tlsConfig = tlsConfig;
this.acceptChannelPredicate = acceptChannelPredicate;
this.headerValidator = headerValidator;
}

@Override
@@ -374,11 +382,17 @@ protected void initChannel(Channel ch) throws Exception {
handlingSettings.maxChunkSize()
);
decoder.setCumulator(ByteToMessageDecoder.COMPOSITE_CUMULATOR);
ch.pipeline().addLast("decoder", decoder); // parses the HTTP bytes request into HTTP message pieces
if (headerValidator != null) {
// runs a validation function on the first HTTP message piece which contains all the headers
// if validation passes, the pieces of that particular request are forwarded, otherwise they are discarded
ch.pipeline().addLast("header_validator", new Netty4HttpHeaderValidator(headerValidator));
}
// combines the HTTP message pieces into a single full HTTP request (with headers and body)
final HttpObjectAggregator aggregator = new HttpObjectAggregator(handlingSettings.maxContentLength());
aggregator.setMaxCumulationBufferComponents(transport.maxCompositeBufferComponents);
ch.pipeline()
.addLast("decoder", decoder)
.addLast("decoder_compress", new HttpContentDecompressor())
.addLast("decoder_compress", new HttpContentDecompressor()) // this handles request body decompression
.addLast("encoder", new HttpResponseEncoder() {
@Override
protected boolean isContentAlwaysEmpty(HttpResponse msg) {
@@ -114,6 +114,7 @@ public Map<String, Supplier<HttpServerTransport>> getHttpTransports(
getSharedGroupFactory(settings),
tracer,
TLSConfig.noTLS(),
null,
null
)
);
@@ -87,7 +87,8 @@ public void dispatchBadRequest(RestChannel channel, ThreadContext threadContext,
new SharedGroupFactory(Settings.EMPTY),
Tracer.NOOP,
TLSConfig.noTLS(),
null
null,
randomFrom(Netty4HttpHeaderValidator.NOOP_VALIDATOR, null)
)
) {
httpServerTransport.start();