diff --git a/handler/src/main/java/io/netty/handler/ssl/SniHandler.java b/handler/src/main/java/io/netty/handler/ssl/SniHandler.java index 25b472ea260..77043fc4270 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SniHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SniHandler.java @@ -37,14 +37,19 @@ */ public class SniHandler extends ByteToMessageDecoder { + // Maximal number of ssl records to inspect before fallback to the default SslContext. + private static final int MAX_SSL_RECORDS = 4; + private static final InternalLogger logger = InternalLoggerFactory.getInstance(SniHandler.class); + private static final Selection EMPTY_SELECTION = new Selection(null, null); + private final DomainNameMapping mapping; - private boolean handshaken; - private volatile String hostname; - private volatile SslContext selectedContext; + private boolean handshakeFailed; + + private volatile Selection selection = EMPTY_SELECTION; /** * Create a SNI detection handler with configured {@link SslContext} @@ -59,127 +64,159 @@ public SniHandler(DomainNameMapping mapping) { } this.mapping = (DomainNameMapping) mapping; - handshaken = false; } /** * @return the selected hostname */ public String hostname() { - return hostname; + return selection.hostname; } /** * @return the selected sslcontext */ public SslContext sslContext() { - return selectedContext; + return selection.context; } @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { - if (!handshaken && in.readableBytes() >= 5) { - String hostname = sniHostNameFromHandshakeInfo(in); - if (hostname != null) { - hostname = IDN.toASCII(hostname, IDN.ALLOW_UNASSIGNED).toLowerCase(Locale.US); + if (!handshakeFailed && in.readableBytes() >= SslUtils.SSL_RECORD_HEADER_LENGTH) { + int writerIndex = in.writerIndex(); + int readerIndex = in.readerIndex(); + try { + loop: for (int i = 0; i < MAX_SSL_RECORDS; i++) { + int command = in.getUnsignedByte(readerIndex); + + // tls, but not handshake command + switch (command) { + case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC: + case SslUtils.SSL_CONTENT_TYPE_ALERT: + int len = SslUtils.getEncryptedPacketLength(in, readerIndex); + + // Not an SSL/TLS packet + if (len == -1) { + handshakeFailed = true; + NotSslRecordException e = new NotSslRecordException( + "not an SSL/TLS record: " + ByteBufUtil.hexDump(in)); + in.skipBytes(in.readableBytes()); + ctx.fireExceptionCaught(e); + + SslUtils.notifyHandshakeFailure(ctx, e); + return; + } + if (writerIndex - readerIndex - SslUtils.SSL_RECORD_HEADER_LENGTH < len) { + // Not enough data + return; + } + // increase readerIndex and try again. + readerIndex += len; + continue; + case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE: + int majorVersion = in.getUnsignedByte(readerIndex + 1); + + // SSLv3 or TLS + if (majorVersion == 3) { + int packetLength = in.getUnsignedShort(readerIndex + 3) + + SslUtils.SSL_RECORD_HEADER_LENGTH; + + if (in.readableBytes() < packetLength) { + // client hello incomplete try again to decode once more data is ready. + return; + } + // See https://tools.ietf.org/html/rfc5246#section-7.4.1.2 + // + // Decode the ssl client hello packet. + // We have to skip bytes until SessionID (which sum to 43 bytes). + // + // struct { + // ProtocolVersion client_version; + // Random random; + // SessionID session_id; + // CipherSuite cipher_suites<2..2^16-2>; + // CompressionMethod compression_methods<1..2^8-1>; + // select (extensions_present) { + // case false: + // struct {}; + // case true: + // Extension extensions<0..2^16-1>; + // }; + // } ClientHello; + // + int offset = readerIndex + 43; + + int sessionIdLength = in.getUnsignedByte(offset); + offset += sessionIdLength + 1; + + int cipherSuitesLength = in.getUnsignedShort(offset); + offset += cipherSuitesLength + 2; + + int compressionMethodLength = in.getUnsignedByte(offset); + offset += compressionMethodLength + 1; + + int extensionsLength = in.getUnsignedShort(offset); + offset += 2; + int extensionsLimit = offset + extensionsLength; + + while (offset < extensionsLimit) { + int extensionType = in.getUnsignedShort(offset); + offset += 2; + + int extensionLength = in.getUnsignedShort(offset); + offset += 2; + + // SNI + // See https://tools.ietf.org/html/rfc6066#page-6 + if (extensionType == 0) { + int serverNameType = in.getUnsignedByte(offset + 2); + if (serverNameType == 0) { + int serverNameLength = in.getUnsignedShort(offset + 3); + String hostname = in.toString(offset + 5, serverNameLength, + CharsetUtil.UTF_8); + select(ctx, IDN.toASCII(hostname, + IDN.ALLOW_UNASSIGNED).toLowerCase(Locale.US)); + return; + } else { + // invalid enum value + break loop; + } + } + + offset += extensionLength; + } + } + // Fall-through + default: + //not tls, ssl or application data, do not try sni + break loop; + } + } + } catch (Throwable e) { + // unexpected encoding, ignore sni and use default + if (logger.isDebugEnabled()) { + logger.debug("Unexpected client hello packet: " + ByteBufUtil.hexDump(in), e); + } } - this.hostname = hostname; - - // the mapping will return default context when this.hostname is null - selectedContext = mapping.map(hostname); - } - - if (handshaken) { - SslHandler sslHandler = selectedContext.newHandler(ctx.alloc()); - ctx.pipeline().replace(this, SslHandler.class.getName(), sslHandler); + // Just select the default SslContext + select(ctx, null); } } - private String sniHostNameFromHandshakeInfo(ByteBuf in) { - int readerIndex = in.readerIndex(); - try { - int command = in.getUnsignedByte(readerIndex); - - // tls, but not handshake command - switch (command) { - case SslConstants.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC: - case SslConstants.SSL_CONTENT_TYPE_ALERT: - case SslConstants.SSL_CONTENT_TYPE_APPLICATION_DATA: - return null; - case SslConstants.SSL_CONTENT_TYPE_HANDSHAKE: - break; - default: - //not tls or sslv3, do not try sni - handshaken = true; - return null; - } - - int majorVersion = in.getUnsignedByte(readerIndex + 1); - - // SSLv3 or TLS - if (majorVersion == 3) { - - int packetLength = in.getUnsignedShort(readerIndex + 3) + 5; - - if (in.readableBytes() >= packetLength) { - // decode the ssl client hello packet - // we have to skip some var-length fields - int offset = readerIndex + 43; - - int sessionIdLength = in.getUnsignedByte(offset); - offset += sessionIdLength + 1; - - int cipherSuitesLength = in.getUnsignedShort(offset); - offset += cipherSuitesLength + 2; - - int compressionMethodLength = in.getUnsignedByte(offset); - offset += compressionMethodLength + 1; - - int extensionsLength = in.getUnsignedShort(offset); - offset += 2; - int extensionsLimit = offset + extensionsLength; - - while (offset < extensionsLimit) { - int extensionType = in.getUnsignedShort(offset); - offset += 2; - - int extensionLength = in.getUnsignedShort(offset); - offset += 2; - - // SNI - if (extensionType == 0) { - handshaken = true; - int serverNameType = in.getUnsignedByte(offset + 2); - if (serverNameType == 0) { - int serverNameLength = in.getUnsignedShort(offset + 3); - return in.toString(offset + 5, serverNameLength, - CharsetUtil.UTF_8); - } else { - // invalid enum value - return null; - } - } + private void select(ChannelHandlerContext ctx, String hostname) { + SslContext selectedContext = mapping.map(hostname); + selection = new Selection(selectedContext, hostname); + SslHandler sslHandler = selectedContext.newHandler(ctx.alloc()); + ctx.pipeline().replace(this, SslHandler.class.getName(), sslHandler); + } - offset += extensionLength; - } + private static final class Selection { + final SslContext context; + final String hostname; - handshaken = true; - return null; - } else { - // client hello incomplete - return null; - } - } else { - handshaken = true; - return null; - } - } catch (Throwable e) { - // unexpected encoding, ignore sni and use default - if (logger.isDebugEnabled()) { - logger.debug("Unexpected client hello packet: " + ByteBufUtil.hexDump(in), e); - } - handshaken = true; - return null; + Selection(SslContext context, String hostname) { + this.context = context; + this.hostname = hostname; } } } diff --git a/handler/src/main/java/io/netty/handler/ssl/SslConstants.java b/handler/src/main/java/io/netty/handler/ssl/SslConstants.java deleted file mode 100644 index ec938598698..00000000000 --- a/handler/src/main/java/io/netty/handler/ssl/SslConstants.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright 2014 The Netty Project - * - * The Netty Project licenses this file to you under the Apache License, - * version 2.0 (the "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at: - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - */ -package io.netty.handler.ssl; - -/** - * Constants for SSL packets. - */ -final class SslConstants { - - /** - * change cipher spec - */ - public static final int SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC = 20; - - /** - * alert - */ - public static final int SSL_CONTENT_TYPE_ALERT = 21; - - /** - * handshake - */ - public static final int SSL_CONTENT_TYPE_HANDSHAKE = 22; - - /** - * application data - */ - public static final int SSL_CONTENT_TYPE_APPLICATION_DATA = 23; - - private SslConstants() { - } -} diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index b91cf238e19..7dde9d0445f 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -66,6 +66,8 @@ import java.util.concurrent.TimeUnit; import java.util.regex.Pattern; +import static io.netty.handler.ssl.SslUtils.getEncryptedPacketLength; + /** * Adds SSL * · TLS and StartTLS support to a {@link Channel}. Please refer @@ -813,84 +815,13 @@ private boolean ignoreException(Throwable t) { * Is thrown if the given {@link ByteBuf} has not at least 5 bytes to read. */ public static boolean isEncrypted(ByteBuf buffer) { - if (buffer.readableBytes() < 5) { - throw new IllegalArgumentException("buffer must have at least 5 readable bytes"); + if (buffer.readableBytes() < SslUtils.SSL_RECORD_HEADER_LENGTH) { + throw new IllegalArgumentException( + "buffer must have at least " + SslUtils.SSL_RECORD_HEADER_LENGTH + " readable bytes"); } return getEncryptedPacketLength(buffer, buffer.readerIndex()) != -1; } - /** - * Return how much bytes can be read out of the encrypted data. Be aware that this method will not increase - * the readerIndex of the given {@link ByteBuf}. - * - * @param buffer - * The {@link ByteBuf} to read from. Be aware that it must have at least 5 bytes to read, - * otherwise it will throw an {@link IllegalArgumentException}. - * @return length - * The length of the encrypted packet that is included in the buffer. This will - * return {@code -1} if the given {@link ByteBuf} is not encrypted at all. - * @throws IllegalArgumentException - * Is thrown if the given {@link ByteBuf} has not at least 5 bytes to read. - */ - private static int getEncryptedPacketLength(ByteBuf buffer, int offset) { - int packetLength = 0; - - // SSLv3 or TLS - Check ContentType - boolean tls; - switch (buffer.getUnsignedByte(offset)) { - case 20: // change_cipher_spec - case 21: // alert - case 22: // handshake - case 23: // application_data - tls = true; - break; - default: - // SSLv2 or bad data - tls = false; - } - - if (tls) { - // SSLv3 or TLS - Check ProtocolVersion - int majorVersion = buffer.getUnsignedByte(offset + 1); - if (majorVersion == 3) { - // SSLv3 or TLS - packetLength = buffer.getUnsignedShort(offset + 3) + 5; - if (packetLength <= 5) { - // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data) - tls = false; - } - } else { - // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data) - tls = false; - } - } - - if (!tls) { - // SSLv2 or bad data - Check the version - boolean sslv2 = true; - int headerLength = (buffer.getUnsignedByte(offset) & 0x80) != 0 ? 2 : 3; - int majorVersion = buffer.getUnsignedByte(offset + headerLength + 1); - if (majorVersion == 2 || majorVersion == 3) { - // SSLv2 - if (headerLength == 2) { - packetLength = (buffer.getShort(offset) & 0x7FFF) + 2; - } else { - packetLength = (buffer.getShort(offset) & 0x3FFF) + 3; - } - if (packetLength <= headerLength) { - sslv2 = false; - } - } else { - sslv2 = false; - } - - if (!sslv2) { - return -1; - } - } - return packetLength; - } - @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws SSLException { final int startOffset = in.readerIndex(); @@ -913,7 +844,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t while (totalLength < OpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH) { final int readableBytes = endOffset - offset; - if (readableBytes < 5) { + if (readableBytes < SslUtils.SSL_RECORD_HEADER_LENGTH) { break; } @@ -1299,8 +1230,7 @@ private void setHandshakeFailure(ChannelHandlerContext ctx, Throwable cause, boo private void notifyHandshakeFailure(Throwable cause) { if (handshakePromise.tryFailure(cause)) { - ctx.fireUserEventTriggered(new SslHandshakeCompletionEvent(cause)); - ctx.close(); + SslUtils.notifyHandshakeFailure(ctx, cause); } } diff --git a/handler/src/main/java/io/netty/handler/ssl/SslUtils.java b/handler/src/main/java/io/netty/handler/ssl/SslUtils.java new file mode 100644 index 00000000000..6bf75d38d19 --- /dev/null +++ b/handler/src/main/java/io/netty/handler/ssl/SslUtils.java @@ -0,0 +1,127 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; + +/** + * Constants for SSL packets. + */ +final class SslUtils { + + /** + * change cipher spec + */ + public static final int SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC = 20; + + /** + * alert + */ + public static final int SSL_CONTENT_TYPE_ALERT = 21; + + /** + * handshake + */ + public static final int SSL_CONTENT_TYPE_HANDSHAKE = 22; + + /** + * application data + */ + public static final int SSL_CONTENT_TYPE_APPLICATION_DATA = 23; + + /** + * the length of the ssl record header (in bytes) + */ + public static final int SSL_RECORD_HEADER_LENGTH = 5; + + /** + * Return how much bytes can be read out of the encrypted data. Be aware that this method will not increase + * the readerIndex of the given {@link ByteBuf}. + * + * @param buffer + * The {@link ByteBuf} to read from. Be aware that it must have at least + * {@link #SSL_RECORD_HEADER_LENGTH} bytes to read, + * otherwise it will throw an {@link IllegalArgumentException}. + * @return length + * The length of the encrypted packet that is included in the buffer. This will + * return {@code -1} if the given {@link ByteBuf} is not encrypted at all. + * @throws IllegalArgumentException + * Is thrown if the given {@link ByteBuf} has not at least {@link #SSL_RECORD_HEADER_LENGTH} + * bytes to read. + */ + static int getEncryptedPacketLength(ByteBuf buffer, int offset) { + int packetLength = 0; + + // SSLv3 or TLS - Check ContentType + boolean tls; + switch (buffer.getUnsignedByte(offset)) { + case SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC: + case SSL_CONTENT_TYPE_ALERT: + case SSL_CONTENT_TYPE_HANDSHAKE: + case SSL_CONTENT_TYPE_APPLICATION_DATA: + tls = true; + break; + default: + // SSLv2 or bad data + tls = false; + } + + if (tls) { + // SSLv3 or TLS - Check ProtocolVersion + int majorVersion = buffer.getUnsignedByte(offset + 1); + if (majorVersion == 3) { + // SSLv3 or TLS + packetLength = buffer.getUnsignedShort(offset + 3) + SSL_RECORD_HEADER_LENGTH; + if (packetLength <= SSL_RECORD_HEADER_LENGTH) { + // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data) + tls = false; + } + } else { + // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data) + tls = false; + } + } + + if (!tls) { + // SSLv2 or bad data - Check the version + int headerLength = (buffer.getUnsignedByte(offset) & 0x80) != 0 ? 2 : 3; + int majorVersion = buffer.getUnsignedByte(offset + headerLength + 1); + if (majorVersion == 2 || majorVersion == 3) { + // SSLv2 + if (headerLength == 2) { + packetLength = (buffer.getShort(offset) & 0x7FFF) + 2; + } else { + packetLength = (buffer.getShort(offset) & 0x3FFF) + 3; + } + if (packetLength <= headerLength) { + return -1; + } + } else { + return -1; + } + } + return packetLength; + } + + static void notifyHandshakeFailure(ChannelHandlerContext ctx, Throwable cause) { + ctx.fireUserEventTriggered(new SslHandshakeCompletionEvent(cause)); + ctx.close(); + } + + private SslUtils() { + } +}