Skip to content

Commit

Permalink
Correctly handle non handshake commands when using SniHandler
Browse files Browse the repository at this point in the history
Motivation:

As we can only handle handshake commands to parse SNI we should try to skip alert and change cipher spec commands a few times before we fallback to use a default SslContext.

Modifications:

- Use default SslContext if no application data command was received
- Use default SslContext if after 4 commands we not received a handshake command
- Simplify code
- Eliminate multiple volatile fields
- Rename SslConstants to SslUtils
- Share code between SslHandler and SniHandler by moving stuff to SslUtils

Result:

Correct handling of non handshake commands and cleaner code.
  • Loading branch information
normanmaurer committed Jan 14, 2016
1 parent 951bacb commit 27118ec
Show file tree
Hide file tree
Showing 4 changed files with 273 additions and 224 deletions.
241 changes: 139 additions & 102 deletions handler/src/main/java/io/netty/handler/ssl/SniHandler.java
Expand Up @@ -37,14 +37,19 @@
*/
public class SniHandler extends ByteToMessageDecoder {

// Maximal number of ssl records to inspect before fallback to the default SslContext.
private static final int MAX_SSL_RECORDS = 4;

private static final InternalLogger logger =
InternalLoggerFactory.getInstance(SniHandler.class);

private static final Selection EMPTY_SELECTION = new Selection(null, null);

private final DomainNameMapping<SslContext> mapping;

private boolean handshaken;
private volatile String hostname;
private volatile SslContext selectedContext;
private boolean handshakeFailed;

private volatile Selection selection = EMPTY_SELECTION;

/**
* Create a SNI detection handler with configured {@link SslContext}
Expand All @@ -59,127 +64,159 @@ public SniHandler(DomainNameMapping<? extends SslContext> mapping) {
}

this.mapping = (DomainNameMapping<SslContext>) mapping;
handshaken = false;
}

/**
* @return the selected hostname
*/
public String hostname() {
return hostname;
return selection.hostname;
}

/**
* @return the selected sslcontext
*/
public SslContext sslContext() {
return selectedContext;
return selection.context;
}

@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
if (!handshaken && in.readableBytes() >= 5) {
String hostname = sniHostNameFromHandshakeInfo(in);
if (hostname != null) {
hostname = IDN.toASCII(hostname, IDN.ALLOW_UNASSIGNED).toLowerCase(Locale.US);
if (!handshakeFailed && in.readableBytes() >= SslUtils.SSL_RECORD_HEADER_LENGTH) {
int writerIndex = in.writerIndex();
int readerIndex = in.readerIndex();
try {
loop: for (int i = 0; i < MAX_SSL_RECORDS; i++) {
int command = in.getUnsignedByte(readerIndex);

// tls, but not handshake command
switch (command) {
case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
case SslUtils.SSL_CONTENT_TYPE_ALERT:
int len = SslUtils.getEncryptedPacketLength(in, readerIndex);

// Not an SSL/TLS packet
if (len == -1) {
handshakeFailed = true;
NotSslRecordException e = new NotSslRecordException(
"not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
in.skipBytes(in.readableBytes());
ctx.fireExceptionCaught(e);

SslUtils.notifyHandshakeFailure(ctx, e);
return;
}
if (writerIndex - readerIndex - SslUtils.SSL_RECORD_HEADER_LENGTH < len) {
// Not enough data
return;
}
// increase readerIndex and try again.
readerIndex += len;
continue;
case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE:
int majorVersion = in.getUnsignedByte(readerIndex + 1);

// SSLv3 or TLS
if (majorVersion == 3) {
int packetLength = in.getUnsignedShort(readerIndex + 3)
+ SslUtils.SSL_RECORD_HEADER_LENGTH;

if (in.readableBytes() < packetLength) {
// client hello incomplete try again to decode once more data is ready.
return;
}
// See https://tools.ietf.org/html/rfc5246#section-7.4.1.2
//
// Decode the ssl client hello packet.
// We have to skip bytes until SessionID (which sum to 43 bytes).
//
// struct {
// ProtocolVersion client_version;
// Random random;
// SessionID session_id;
// CipherSuite cipher_suites<2..2^16-2>;
// CompressionMethod compression_methods<1..2^8-1>;
// select (extensions_present) {
// case false:
// struct {};
// case true:
// Extension extensions<0..2^16-1>;
// };
// } ClientHello;
//
int offset = readerIndex + 43;

int sessionIdLength = in.getUnsignedByte(offset);
offset += sessionIdLength + 1;

int cipherSuitesLength = in.getUnsignedShort(offset);
offset += cipherSuitesLength + 2;

int compressionMethodLength = in.getUnsignedByte(offset);
offset += compressionMethodLength + 1;

int extensionsLength = in.getUnsignedShort(offset);
offset += 2;
int extensionsLimit = offset + extensionsLength;

while (offset < extensionsLimit) {
int extensionType = in.getUnsignedShort(offset);
offset += 2;

int extensionLength = in.getUnsignedShort(offset);
offset += 2;

// SNI
// See https://tools.ietf.org/html/rfc6066#page-6
if (extensionType == 0) {
int serverNameType = in.getUnsignedByte(offset + 2);
if (serverNameType == 0) {
int serverNameLength = in.getUnsignedShort(offset + 3);
String hostname = in.toString(offset + 5, serverNameLength,
CharsetUtil.UTF_8);
select(ctx, IDN.toASCII(hostname,
IDN.ALLOW_UNASSIGNED).toLowerCase(Locale.US));
return;
} else {
// invalid enum value
break loop;
}
}

offset += extensionLength;
}
}
// Fall-through
default:
//not tls, ssl or application data, do not try sni
break loop;
}
}
} catch (Throwable e) {
// unexpected encoding, ignore sni and use default
if (logger.isDebugEnabled()) {
logger.debug("Unexpected client hello packet: " + ByteBufUtil.hexDump(in), e);
}
}
this.hostname = hostname;

// the mapping will return default context when this.hostname is null
selectedContext = mapping.map(hostname);
}

if (handshaken) {
SslHandler sslHandler = selectedContext.newHandler(ctx.alloc());
ctx.pipeline().replace(this, SslHandler.class.getName(), sslHandler);
// Just select the default SslContext
select(ctx, null);
}
}

private String sniHostNameFromHandshakeInfo(ByteBuf in) {
int readerIndex = in.readerIndex();
try {
int command = in.getUnsignedByte(readerIndex);

// tls, but not handshake command
switch (command) {
case SslConstants.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
case SslConstants.SSL_CONTENT_TYPE_ALERT:
case SslConstants.SSL_CONTENT_TYPE_APPLICATION_DATA:
return null;
case SslConstants.SSL_CONTENT_TYPE_HANDSHAKE:
break;
default:
//not tls or sslv3, do not try sni
handshaken = true;
return null;
}

int majorVersion = in.getUnsignedByte(readerIndex + 1);

// SSLv3 or TLS
if (majorVersion == 3) {

int packetLength = in.getUnsignedShort(readerIndex + 3) + 5;

if (in.readableBytes() >= packetLength) {
// decode the ssl client hello packet
// we have to skip some var-length fields
int offset = readerIndex + 43;

int sessionIdLength = in.getUnsignedByte(offset);
offset += sessionIdLength + 1;

int cipherSuitesLength = in.getUnsignedShort(offset);
offset += cipherSuitesLength + 2;

int compressionMethodLength = in.getUnsignedByte(offset);
offset += compressionMethodLength + 1;

int extensionsLength = in.getUnsignedShort(offset);
offset += 2;
int extensionsLimit = offset + extensionsLength;

while (offset < extensionsLimit) {
int extensionType = in.getUnsignedShort(offset);
offset += 2;

int extensionLength = in.getUnsignedShort(offset);
offset += 2;

// SNI
if (extensionType == 0) {
handshaken = true;
int serverNameType = in.getUnsignedByte(offset + 2);
if (serverNameType == 0) {
int serverNameLength = in.getUnsignedShort(offset + 3);
return in.toString(offset + 5, serverNameLength,
CharsetUtil.UTF_8);
} else {
// invalid enum value
return null;
}
}
private void select(ChannelHandlerContext ctx, String hostname) {
SslContext selectedContext = mapping.map(hostname);
selection = new Selection(selectedContext, hostname);
SslHandler sslHandler = selectedContext.newHandler(ctx.alloc());
ctx.pipeline().replace(this, SslHandler.class.getName(), sslHandler);
}

offset += extensionLength;
}
private static final class Selection {
final SslContext context;
final String hostname;

handshaken = true;
return null;
} else {
// client hello incomplete
return null;
}
} else {
handshaken = true;
return null;
}
} catch (Throwable e) {
// unexpected encoding, ignore sni and use default
if (logger.isDebugEnabled()) {
logger.debug("Unexpected client hello packet: " + ByteBufUtil.hexDump(in), e);
}
handshaken = true;
return null;
Selection(SslContext context, String hostname) {
this.context = context;
this.hostname = hostname;
}
}
}
45 changes: 0 additions & 45 deletions handler/src/main/java/io/netty/handler/ssl/SslConstants.java

This file was deleted.

0 comments on commit 27118ec

Please sign in to comment.