Skip to content

Commit

Permalink
Merge pull request from GHSA-6mjq-h674-j845
Browse files Browse the repository at this point in the history
Motivation:

In theory the ClientHello can span multiple records and so reach the length of 16MB. This can result in high memory usage, we should allow the user to define a maximum length.

Modifications:

Add new constructor which allows to limit the maximum length of a ClientHello message.

Result:

Be able to guard against high memory usage when parsing ClientHello messages
  • Loading branch information
normanmaurer committed Jun 20, 2023
1 parent 1bb825b commit 535da17
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 11 deletions.
13 changes: 11 additions & 2 deletions handler/src/main/java/io/netty/handler/ssl/AbstractSniHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,23 @@ private static String extractSniHostname(ByteBuf in) {
private String hostname;

/**
* @param handshakeTimeoutMillis the handshake timeout in milliseconds
* @param handshakeTimeoutMillis the handshake timeout in milliseconds
*/
protected AbstractSniHandler(long handshakeTimeoutMillis) {
this(0, handshakeTimeoutMillis);
}

/**
* @paramm maxClientHelloLength the maximum length of the client hello message.
* @param handshakeTimeoutMillis the handshake timeout in milliseconds
*/
protected AbstractSniHandler(int maxClientHelloLength, long handshakeTimeoutMillis) {
super(maxClientHelloLength);
this.handshakeTimeoutMillis = checkPositiveOrZero(handshakeTimeoutMillis, "handshakeTimeoutMillis");
}

public AbstractSniHandler() {
this(0L);
this(0, 0L);
}

@Override
Expand Down
36 changes: 31 additions & 5 deletions handler/src/main/java/io/netty/handler/ssl/SniHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,12 @@ public SniHandler(Mapping<? super String, ? extends SslContext> mapping) {
* maintained by {@link Mapping}
*
* @param mapping the mapping of domain name to {@link SslContext}
* @param maxClientHelloLength the maximum length of the client hello message
* @param handshakeTimeoutMillis the handshake timeout in milliseconds
*/
public SniHandler(Mapping<? super String, ? extends SslContext> mapping, long handshakeTimeoutMillis) {
this(new AsyncMappingAdapter(mapping), handshakeTimeoutMillis);
public SniHandler(Mapping<? super String, ? extends SslContext> mapping,
int maxClientHelloLength, long handshakeTimeoutMillis) {
this(new AsyncMappingAdapter(mapping), maxClientHelloLength, handshakeTimeoutMillis);
}

/**
Expand All @@ -80,22 +82,46 @@ public SniHandler(DomainNameMapping<? extends SslContext> mapping) {
*/
@SuppressWarnings("unchecked")
public SniHandler(AsyncMapping<? super String, ? extends SslContext> mapping) {
this(mapping, 0L);
this(mapping, 0, 0L);
}

/**
* Creates a SNI detection handler with configured {@link SslContext}
* maintained by {@link AsyncMapping}
*
* @param mapping the mapping of domain name to {@link SslContext}
* @param maxClientHelloLength the maximum length of the client hello message
* @param handshakeTimeoutMillis the handshake timeout in milliseconds
*/
@SuppressWarnings("unchecked")
public SniHandler(AsyncMapping<? super String, ? extends SslContext> mapping, long handshakeTimeoutMillis) {
super(handshakeTimeoutMillis);
public SniHandler(AsyncMapping<? super String, ? extends SslContext> mapping,
int maxClientHelloLength, long handshakeTimeoutMillis) {
super(maxClientHelloLength, handshakeTimeoutMillis);
this.mapping = (AsyncMapping<String, SslContext>) ObjectUtil.checkNotNull(mapping, "mapping");
}

/**
* Creates a SNI detection handler with configured {@link SslContext}
* maintained by {@link Mapping}
*
* @param mapping the mapping of domain name to {@link SslContext}
* @param handshakeTimeoutMillis the handshake timeout in milliseconds
*/
public SniHandler(Mapping<? super String, ? extends SslContext> mapping, long handshakeTimeoutMillis) {
this(new AsyncMappingAdapter(mapping), handshakeTimeoutMillis);
}

/**
* Creates a SNI detection handler with configured {@link SslContext}
* maintained by {@link AsyncMapping}
*
* @param mapping the mapping of domain name to {@link SslContext}
* @param handshakeTimeoutMillis the handshake timeout in milliseconds
*/
public SniHandler(AsyncMapping<? super String, ? extends SslContext> mapping, long handshakeTimeoutMillis) {
this(mapping, 0, handshakeTimeoutMillis);
}

/**
* @return the selected hostname
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.TooLongFrameException;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener;
import io.netty.util.internal.ObjectUtil;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
Expand All @@ -36,14 +38,32 @@
*/
public abstract class SslClientHelloHandler<T> extends ByteToMessageDecoder implements ChannelOutboundHandler {

/**
* The maximum length of client hello message as defined by
* <a href="https://www.rfc-editor.org/rfc/rfc5246#section-6.2.1">RFC5246</a>.
*/
public static final int MAX_CLIENT_HELLO_LENGTH = 0xFFFFFF;

private static final InternalLogger logger =
InternalLoggerFactory.getInstance(SslClientHelloHandler.class);

private final int maxClientHelloLength;
private boolean handshakeFailed;
private boolean suppressRead;
private boolean readPending;
private ByteBuf handshakeBuffer;

public SslClientHelloHandler() {
this(MAX_CLIENT_HELLO_LENGTH);
}

protected SslClientHelloHandler(int maxClientHelloLength) {
// 16MB is the maximum as per RFC:
// See https://www.rfc-editor.org/rfc/rfc5246#section-6.2.1
this.maxClientHelloLength =
ObjectUtil.checkInRange(maxClientHelloLength, 0, MAX_CLIENT_HELLO_LENGTH, "maxClientHelloLength");
}

@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
if (!suppressRead && !handshakeFailed) {
Expand Down Expand Up @@ -117,6 +137,15 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) t
handshakeLength = in.getUnsignedMedium(readerIndex +
SslUtils.SSL_RECORD_HEADER_LENGTH + 1);

if (handshakeLength > maxClientHelloLength && maxClientHelloLength != 0) {
TooLongFrameException e = new TooLongFrameException(
"ClientHello length exceeds " + maxClientHelloLength +
": " + handshakeLength);
in.skipBytes(in.readableBytes());
ctx.fireUserEventTriggered(new SniCompletionEvent(e));
SslUtils.handleHandshakeFailure(ctx, e, true);
throw e;
}
// Consume handshakeType and handshakeLength (this sums up as 4 bytes)
readerIndex += 4;
packetLength -= 4;
Expand Down Expand Up @@ -161,6 +190,9 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) t
} catch (NotSslRecordException e) {
// Just rethrow as in this case we also closed the channel and this is consistent with SslHandler.
throw e;
} catch (TooLongFrameException e) {
// Just rethrow as in this case we also closed the channel
throw e;
} catch (Exception e) {
// unexpected encoding, ignore sni and use default
if (logger.isDebugEnabled()) {
Expand Down
57 changes: 53 additions & 4 deletions handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;

import javax.net.ssl.HandshakeCompletedEvent;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;

import io.netty.handler.codec.TooLongFrameException;
import io.netty.util.concurrent.Future;

import io.netty.bootstrap.Bootstrap;
Expand Down Expand Up @@ -715,14 +714,64 @@ private static List<ByteBuf> split(ByteBuf clientHello, int maxSize) {
return result;
}

@Test
public void testSniHandlerFailsOnTooBigClientHello() throws Exception {
SniHandler handler = new SniHandler(new Mapping<String, SslContext>() {
@Override
public SslContext map(String input) {
throw new UnsupportedOperationException("Should not be called");
}
}, 10, 0);

final AtomicReference<SniCompletionEvent> completionEventRef =
new AtomicReference<SniCompletionEvent>();
final EmbeddedChannel ch = new EmbeddedChannel(handler, new ChannelInboundHandlerAdapter() {
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt instanceof SniCompletionEvent) {
completionEventRef.set((SniCompletionEvent) evt);
}
}
});
final ByteBuf buffer = ch.alloc().buffer();
buffer.writeByte(0x16); // Content Type: Handshake
buffer.writeShort((short) 0x0303); // TLS 1.2
buffer.writeShort((short) 0x0006); // Packet length

// 16_777_215
buffer.writeByte((byte) 0x01); // Client Hello
buffer.writeMedium(0xFFFFFF); // Length
buffer.writeShort((short) 0x0303); // TLS 1.2

assertThrows(TooLongFrameException.class, new Executable() {
@Override
public void execute() throws Throwable {
ch.writeInbound(buffer);
}
});
try {
while (completionEventRef.get() == null) {
Thread.sleep(100);
// We need to run all pending tasks as the handshake timeout is scheduled on the EventLoop.
ch.runPendingTasks();
}
SniCompletionEvent completionEvent = completionEventRef.get();
assertNotNull(completionEvent);
assertNotNull(completionEvent.cause());
assertEquals(TooLongFrameException.class, completionEvent.cause().getClass());
} finally {
ch.finishAndReleaseAll();
}
}

@Test
public void testSniHandlerFiresHandshakeTimeout() throws Exception {
SniHandler handler = new SniHandler(new Mapping<String, SslContext>() {
@Override
public SslContext map(String input) {
throw new UnsupportedOperationException("Should not be called");
}
}, 10);
}, 0, 10);

final AtomicReference<SniCompletionEvent> completionEventRef =
new AtomicReference<SniCompletionEvent>();
Expand Down Expand Up @@ -758,7 +807,7 @@ public void testSslHandlerFiresHandshakeTimeout(SslProvider provider) throws Exc
public SslContext map(String input) {
return context;
}
}, 100);
}, 0, 100);

final AtomicReference<SniCompletionEvent> sniCompletionEventRef =
new AtomicReference<SniCompletionEvent>();
Expand Down

0 comments on commit 535da17

Please sign in to comment.