Skip to content

Commit

Permalink
Expose SniHandler's replaceHandler() so that users can implement cust…
Browse files Browse the repository at this point in the history
…om behavior.

Motivation

The SniHandler is currently hiding its replaceHandler() method and everything that comes with it. The user has no easy way of getting a hold onto the SslContext for the purpose of reference counting for example. The SniHandler does have getter methods for the SslContext and hostname but they're not very practical or useful. For one the SniHandler will remove itself from the pipeline and we'd have to track a reference of it externally and as we saw in #5745 it'll possibly leave its internal "selection" object with the "EMPTY_SELECTION" value (i.e. we've just lost track of the SslContext).

Modifications

Expose replaceHandler() and allow the user to override it and get a hold onto the hostname, SslContext and SslHandler that will replace the SniHandler.

Result

It's possible to get a hold onto the SslContext, the hostname and the SslHandler that is about to replace the SniHandler. Users can add additional behavior.
  • Loading branch information
Roger Kapsi authored and Scottmitch committed Aug 29, 2016
1 parent 1208b90 commit f97866d
Show file tree
Hide file tree
Showing 2 changed files with 206 additions and 29 deletions.
52 changes: 37 additions & 15 deletions handler/src/main/java/io/netty/handler/ssl/SniHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@
*/
package io.netty.handler.ssl;

import java.net.IDN;
import java.net.SocketAddress;
import java.util.List;
import java.util.Locale;

import javax.net.ssl.SSLEngine;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.channel.ChannelHandlerContext;
Expand All @@ -34,13 +41,6 @@
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;

import java.net.IDN;
import java.net.SocketAddress;
import java.util.List;
import java.util.Locale;

import javax.net.ssl.SSLEngine;

/**
* <p>Enables <a href="https://tools.ietf.org/html/rfc3546#section-3.1">SNI
* (Server Name Indication)</a> extension for server side SSL. For clients
Expand Down Expand Up @@ -277,7 +277,7 @@ private void select(final ChannelHandlerContext ctx, final String hostname) {
Future<SslContext> future = mapping.map(hostname, ctx.executor().<SslContext>newPromise());
if (future.isDone()) {
if (future.isSuccess()) {
replaceHandler(ctx, new Selection(future.getNow(), hostname));
onSslContext(ctx, hostname, future.getNow());
} else {
throw new DecoderException("failed to get the SslContext for " + hostname, future.cause());
}
Expand All @@ -289,7 +289,7 @@ public void operationComplete(Future<SslContext> future) throws Exception {
try {
suppressRead = false;
if (future.isSuccess()) {
replaceHandler(ctx, new Selection(future.getNow(), hostname));
onSslContext(ctx, hostname, future.getNow());
} else {
ctx.fireExceptionCaught(new DecoderException("failed to get the SslContext for "
+ hostname, future.cause()));
Expand All @@ -305,19 +305,41 @@ public void operationComplete(Future<SslContext> future) throws Exception {
}
}

private void replaceHandler(ChannelHandlerContext ctx, Selection selection) {
SSLEngine sslEngine = null;
this.selection = selection;
/**
* Called upon successful completion of the {@link AsyncMapping}'s {@link Future}.
*
* @see #select(ChannelHandlerContext, String)
*/
private void onSslContext(ChannelHandlerContext ctx, String hostname, SslContext sslContext) {
this.selection = new Selection(sslContext, hostname);
try {
sslEngine = selection.context.newEngine(ctx.alloc());
ctx.pipeline().replace(this, SslHandler.class.getName(), selection.context.newHandler(sslEngine));
replaceHandler(ctx, hostname, sslContext);
} catch (Throwable cause) {
this.selection = EMPTY_SELECTION;
ctx.fireExceptionCaught(cause);
}
}

/**
* The default implementation of this method will simply replace {@code this} {@link SniHandler}
* instance with a {@link SslHandler}. Users may override this method to implement custom behavior.
*
* Please be aware that this method may get called after a client has already disconnected and
* custom implementations must take it into consideration when overriding this method.
*
* It's also possible for the hostname argument to be {@code null}.
*/
protected void replaceHandler(ChannelHandlerContext ctx, String hostname, SslContext sslContext) throws Exception {
SSLEngine sslEngine = null;
try {
sslEngine = sslContext.newEngine(ctx.alloc());
ctx.pipeline().replace(this, SslHandler.class.getName(), SslContext.newHandler(sslEngine));
sslEngine = null;
} finally {
// Since the SslHandler was not inserted into the pipeline the ownership of the SSLEngine was not
// transferred to the SslHandler.
// See https://github.com/netty/netty/issues/5678
ReferenceCountUtil.safeRelease(sslEngine);
ctx.fireExceptionCaught(cause);
}
}

Expand Down
183 changes: 169 additions & 14 deletions handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,36 +16,52 @@

package io.netty.handler.ssl;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.junit.Assume.assumeTrue;

import java.io.File;
import java.net.InetSocketAddress;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import javax.net.ssl.SSLEngine;
import javax.xml.bind.DatatypeConverter;

import org.junit.Test;

import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.util.DomainNameMapping;
import io.netty.util.DomainNameMappingBuilder;
import io.netty.util.Mapping;
import io.netty.util.ReferenceCountUtil;
import org.junit.Test;

import javax.xml.bind.DatatypeConverter;
import java.io.File;
import java.net.InetSocketAddress;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import io.netty.util.ReferenceCounted;
import io.netty.util.concurrent.Promise;
import io.netty.util.internal.ObjectUtil;

public class SniHandlerTest {

Expand All @@ -60,10 +76,16 @@ private static ApplicationProtocolConfig newApnConfig() {
}

private static SslContext makeSslContext() throws Exception {
return makeSslContext(null);
}

private static SslContext makeSslContext(SslProvider provider) throws Exception {
File keyFile = new File(SniHandlerTest.class.getResource("test_encrypted.pem").getFile());
File crtFile = new File(SniHandlerTest.class.getResource("test.crt").getFile());

return SslContextBuilder.forServer(crtFile, keyFile, "12345").applicationProtocolConfig(newApnConfig()).build();
return SslContextBuilder.forServer(crtFile, keyFile, "12345")
.sslProvider(provider)
.applicationProtocolConfig(newApnConfig()).build();
}

private static SslContext makeSslClientContext() throws Exception {
Expand Down Expand Up @@ -229,4 +251,137 @@ protected void configurePipeline(ChannelHandlerContext ctx, String protocol) thr
group.shutdownGracefully(0, 0, TimeUnit.MICROSECONDS);
}
}

@Test(timeout = 10L * 1000L)
public void testReplaceHandler() throws Exception {

assumeTrue(OpenSsl.isAvailable());

final String sniHost = "sni.netty.io";
LocalAddress address = new LocalAddress("testReplaceHandler-" + Math.random());
EventLoopGroup group = new DefaultEventLoopGroup(1);
Channel sc = null;
Channel cc = null;

SelfSignedCertificate cert = new SelfSignedCertificate();

try {
final SslContext sslServerContext = SslContextBuilder
.forServer(cert.key(), cert.cert())
.sslProvider(SslProvider.OPENSSL)
.build();

final Mapping<String, SslContext> mapping = new Mapping<String, SslContext>() {
@Override
public SslContext map(String input) {
return sslServerContext;
}
};

final Promise<Void> releasePromise = group.next().newPromise();

final SniHandler handler = new SniHandler(mapping) {
@Override
protected void replaceHandler(ChannelHandlerContext ctx,
String hostname, final SslContext sslContext)
throws Exception {

boolean success = false;
try {
// The SniHandler's replaceHandler() method allows us to implement custom behavior.
// As an example, we want to release() the SslContext upon channelInactive() or rather
// when the SslHandler closes it's SslEngine. If you take a close look at SslHandler
// you'll see that it's doing it in the #handlerRemoved0() method.

SSLEngine sslEngine = sslContext.newEngine(ctx.alloc());
try {
SslHandler customSslHandler = new CustomSslHandler(sslContext, sslEngine) {
@Override
public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
try {
super.handlerRemoved0(ctx);
} finally {
releasePromise.trySuccess(null);
}
}
};
ctx.pipeline().replace(this, CustomSslHandler.class.getName(), customSslHandler);
success = true;
} finally {
if (!success) {
ReferenceCountUtil.safeRelease(sslEngine);
}
}
} finally {
if (!success) {
ReferenceCountUtil.safeRelease(sslContext);
releasePromise.cancel(true);
}
}
}
};

ServerBootstrap sb = new ServerBootstrap();
sc = sb.group(group).channel(LocalServerChannel.class).childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ch.pipeline().addFirst(handler);
}
}).bind(address).syncUninterruptibly().channel();

SslContext sslContext = SslContextBuilder.forClient().trustManager(InsecureTrustManagerFactory.INSTANCE)
.build();

Bootstrap cb = new Bootstrap();
cc = cb.group(group).channel(LocalChannel.class).handler(new SslHandler(
sslContext.newEngine(ByteBufAllocator.DEFAULT, sniHost, -1)))
.connect(address).syncUninterruptibly().channel();

cc.writeAndFlush(Unpooled.wrappedBuffer("Hello, World!".getBytes()))
.syncUninterruptibly();

// Notice how the server's SslContext refCnt is 1
assertEquals(1, ((ReferenceCounted) sslServerContext).refCnt());

// The client disconnects
cc.close().syncUninterruptibly();
if (!releasePromise.awaitUninterruptibly(10L, TimeUnit.SECONDS)) {
throw new IllegalStateException("It doesn't seem #replaceHandler() got called.");
}

// We should have successfully release() the SslContext
assertEquals(0, ((ReferenceCounted) sslServerContext).refCnt());
} finally {
if (cc != null) {
cc.close().syncUninterruptibly();
}
if (sc != null) {
sc.close().syncUninterruptibly();
}
group.shutdownGracefully();

cert.delete();
}
}

/**
* This is a {@link SslHandler} that will call {@code release()} on the {@link SslContext} when
* the client disconnects.
*
* @see SniHandlerTest#testReplaceHandler()
*/
private static class CustomSslHandler extends SslHandler {
private final SslContext sslContext;

public CustomSslHandler(SslContext sslContext, SSLEngine sslEngine) {
super(sslEngine);
this.sslContext = ObjectUtil.checkNotNull(sslContext, "sslContext");
}

@Override
public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
super.handlerRemoved0(ctx);
ReferenceCountUtil.release(sslContext);
}
}
}

0 comments on commit f97866d

Please sign in to comment.