Skip to content

Commit

Permalink
Refactor the ChannelProvider to handle the SslHandler internally inst…
Browse files Browse the repository at this point in the history
…ead of having it handled in the HttpChannelConnector and the NetClientImpl
  • Loading branch information
vietj committed Oct 30, 2018
1 parent 390c69d commit 669da46
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 111 deletions.
82 changes: 26 additions & 56 deletions src/main/java/io/vertx/core/http/impl/HttpChannelConnector.java
Expand Up @@ -134,78 +134,56 @@ private void doConnect(

applyConnectionOptions(bootstrap);

ChannelProvider channelProvider;
boolean useAlpn = options.isUseAlpn();

// http proxy requests are handled in HttpClientImpl, everything else can use netty proxy handler
if (options.getProxyOptions() == null || !ssl && options.getProxyOptions().getType()== ProxyType.HTTP ) {
channelProvider = ChannelProvider.INSTANCE;
ChannelProvider channelProvider;
if (options.getProxyOptions() == null || !ssl && options.getProxyOptions().getType()== ProxyType.HTTP) {
channelProvider = new ChannelProvider(bootstrap, sslHelper, context, options.getProxyOptions());
} else {
channelProvider = ProxyChannelProvider.INSTANCE;
channelProvider = new ProxyChannelProvider(bootstrap, sslHelper, context, options.getProxyOptions());
}

boolean useAlpn = options.isUseAlpn();
Handler<Channel> channelInitializer = ch -> {

// Configure pipeline
ChannelPipeline pipeline = ch.pipeline();
if (ssl) {
SslHandler sslHandler = new SslHandler(sslHelper.createEngine(client.getVertx(), peerHost, port, options.isForceSni() ? peerHost : null));
ch.pipeline().addLast("ssl", sslHandler);
// TCP connected, so now we must do the SSL handshake
sslHandler.handshakeFuture().addListener(fut -> {
if (fut.isSuccess()) {
String protocol = sslHandler.applicationProtocol();
if (useAlpn) {
if ("h2".equals(protocol)) {
applyHttp2ConnectionOptions(ch.pipeline());
http2Connected(listener, context, ch, future);
} else {
applyHttp1xConnectionOptions(ch.pipeline());
HttpVersion fallbackProtocol = "http/1.0".equals(protocol) ?
HttpVersion.HTTP_1_0 : HttpVersion.HTTP_1_1;
http1xConnected(listener, fallbackProtocol, host, port, true, context, ch, http1Weight, future);
}
Handler<AsyncResult<Channel>> channelHandler = res -> {
if (res.succeeded()) {
Channel ch = res.result();
if (ssl) {
String protocol = channelProvider.applicationProtocol();
if (useAlpn) {
if ("h2".equals(protocol)) {
applyHttp2ConnectionOptions(ch.pipeline());
http2Connected(listener, context, ch, future);
} else {
applyHttp1xConnectionOptions(ch.pipeline());
http1xConnected(listener, version, host, port, true, context, ch, http1Weight, future);
HttpVersion fallbackProtocol = "http/1.0".equals(protocol) ?
HttpVersion.HTTP_1_0 : HttpVersion.HTTP_1_1;
http1xConnected(listener, fallbackProtocol, host, port, true, context, ch, http1Weight, future);
}
} else {
handshakeFailure(ch, fut.cause(), listener, future);
}
});
} else {
if (version == HttpVersion.HTTP_2) {
if (options.isHttp2ClearTextUpgrade()) {
applyHttp1xConnectionOptions(pipeline);
} else {
applyHttp2ConnectionOptions(pipeline);
applyHttp1xConnectionOptions(ch.pipeline());
http1xConnected(listener, version, host, port, true, context, ch, http1Weight, future);
}
} else {
applyHttp1xConnectionOptions(pipeline);
}
}
};

Handler<AsyncResult<Channel>> channelHandler = res -> {

if (res.succeeded()) {
Channel ch = res.result();
if (!ssl) {
ChannelPipeline pipeline = ch.pipeline();
if (version == HttpVersion.HTTP_2) {
if (options.isHttp2ClearTextUpgrade()) {
applyHttp1xConnectionOptions(pipeline);
http1xConnected(listener, version, host, port, false, context, ch, http2Weight, future);
} else {
applyHttp2ConnectionOptions(pipeline);
http2Connected(listener, context, ch, future);
}
} else {
applyHttp1xConnectionOptions(pipeline);
http1xConnected(listener, version, host, port, false, context, ch, http1Weight, future);
}
}
} else {
connectFailed(null, listener, res.cause(), future);
connectFailed(channelProvider.channel(), listener, res.cause(), future);
}
};

channelProvider.connect(context, bootstrap, options.getProxyOptions(), SocketAddress.inetSocketAddress(port, host), channelInitializer, channelHandler);
channelProvider.connect(ssl, SocketAddress.inetSocketAddress(port, host), peerHost, options.isForceSni(), channelHandler);
}

private void applyConnectionOptions(Bootstrap bootstrap) {
Expand All @@ -232,14 +210,6 @@ private void applyHttp1xConnectionOptions(ChannelPipeline pipeline) {
}
}

private void handshakeFailure(Channel ch, Throwable cause, ConnectionListener<HttpClientConnection> listener, Future<ConnectResult<HttpClientConnection>> future) {
SSLHandshakeException sslException = new SSLHandshakeException("Failed to create SSL connection");
if (cause != null) {
sslException.initCause(cause);
}
connectFailed(ch, listener, sslException, future);
}

private void http1xConnected(ConnectionListener<HttpClientConnection> listener,
HttpVersion version,
String host,
Expand Down
73 changes: 63 additions & 10 deletions src/main/java/io/vertx/core/net/impl/ChannelProvider.java
Expand Up @@ -15,6 +15,7 @@
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import io.vertx.core.AsyncResult;
Expand All @@ -25,6 +26,8 @@
import io.vertx.core.net.ProxyOptions;
import io.vertx.core.net.SocketAddress;

import javax.net.ssl.SSLHandshakeException;

/**
* The logic for connecting to an host, this implementations performs a connection
* to the host after resolving its internet address.
Expand All @@ -35,14 +38,37 @@
*/
public class ChannelProvider {

public static final ChannelProvider INSTANCE = new ChannelProvider();
final Bootstrap bootstrap;
final SSLHelper sslHelper;
final ContextInternal context;
final ProxyOptions options;
private String applicationProtocol;
private Channel channel;

public ChannelProvider(Bootstrap bootstrap,
SSLHelper sslHelper,
ContextInternal context,
ProxyOptions options) {
this.bootstrap = bootstrap;
this.context = context;
this.sslHelper = sslHelper;
this.options = options;
}

public String applicationProtocol() {
return applicationProtocol;
}

protected ChannelProvider() {
public Channel channel() {
return channel;
}

public final void connect(ContextInternal context, Bootstrap bootstrap, ProxyOptions options, SocketAddress remoteAddress,
Handler<Channel> channelInitializer, Handler<AsyncResult<Channel>> channelHandler) {
doConnect(context, bootstrap, options, remoteAddress, channelInitializer, res -> {
public final void connect(boolean ssl,
SocketAddress remoteAddress,
String peerHost,
boolean forceSNI,
Handler<AsyncResult<Channel>> channelHandler) {
doConnect(ssl, remoteAddress, peerHost, forceSNI, res -> {
if (Context.isOnEventLoopThread()) {
channelHandler.handle(res);
} else {
Expand All @@ -52,24 +78,51 @@ public final void connect(ContextInternal context, Bootstrap bootstrap, ProxyOpt
});
}

protected void initialize(boolean ssl, SocketAddress remoteAddress, String peerHost, boolean forceSNI, Channel ch) {
if (ssl) {
SslHandler sslHandler = new SslHandler(sslHelper.createEngine(context.owner(), peerHost, remoteAddress.port(), forceSNI ? peerHost : null));
ch.pipeline().addLast("ssl", sslHandler);
}
}


public void doConnect(ContextInternal context, Bootstrap bootstrap, ProxyOptions options, SocketAddress remoteAddress,
Handler<Channel> channelInitializer, Handler<AsyncResult<Channel>> channelHandler) {
public void doConnect(boolean ssl, SocketAddress remoteAddress, String peerHost, boolean forceSNI, Handler<AsyncResult<Channel>> channelHandler) {
VertxInternal vertx = context.owner();
bootstrap.resolver(vertx.nettyAddressResolverGroup());
bootstrap.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel channel) throws Exception {
channelInitializer.handle(channel);
protected void initChannel(Channel ch) {
initialize(ssl, remoteAddress, peerHost, forceSNI, ch);
}
});
ChannelFuture fut = bootstrap.connect(vertx.transport().convert(remoteAddress, false));
fut.addListener(res -> {
if (res.isSuccess()) {
channelHandler.handle(io.vertx.core.Future.succeededFuture(fut.channel()));
cont(fut.channel(), channelHandler);
} else {
channelHandler.handle(io.vertx.core.Future.failedFuture(res.cause()));
}
});
}

protected void cont(Channel ch, Handler<AsyncResult<Channel>> channelHandler) {
channel = ch;
// TCP connected, so now we must do the SSL handshake if any
SslHandler sslHandler = channel.pipeline().get(SslHandler.class);
if (sslHandler != null) {
sslHandler.handshakeFuture().addListener(future -> {
if (future.isSuccess()) {
applicationProtocol = sslHandler.applicationProtocol();
channelHandler.handle(io.vertx.core.Future.succeededFuture(channel));
} else {
SSLHandshakeException sslException = new SSLHandshakeException("Failed to create SSL connection");
sslException.initCause(future.cause());
channelHandler.handle(io.vertx.core.Future.failedFuture(sslException));
}
});
} else {
channelHandler.handle(io.vertx.core.Future.succeededFuture(channel));
}
}

}
38 changes: 8 additions & 30 deletions src/main/java/io/vertx/core/net/impl/NetClientImpl.java
Expand Up @@ -35,6 +35,7 @@
import io.vertx.core.spi.metrics.TCPMetrics;
import io.vertx.core.spi.metrics.VertxMetrics;

import java.net.ConnectException;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
Expand Down Expand Up @@ -181,41 +182,18 @@ protected void doConnect(SocketAddress remoteAddress, String serverName, Handler

ChannelProvider channelProvider;
if (options.getProxyOptions() == null) {
channelProvider = ChannelProvider.INSTANCE;
channelProvider = new ChannelProvider(bootstrap, sslHelper, context, options.getProxyOptions());
} else {
channelProvider = ProxyChannelProvider.INSTANCE;
channelProvider = new ProxyChannelProvider(bootstrap, sslHelper, context, options.getProxyOptions());
}

Handler<Channel> channelInitializer = ch -> {
if (sslHelper.isSSL()) {
SslHandler sslHandler = new SslHandler(sslHelper.createEngine(vertx, remoteAddress, serverName));
ch.pipeline().addLast("ssl", sslHandler);
}
};

Handler<AsyncResult<Channel>> channelHandler = res -> {
if (res.succeeded()) {

Channel ch = res.result();

if (sslHelper.isSSL()) {
// TCP connected, so now we must do the SSL handshake
SslHandler sslHandler = (SslHandler) ch.pipeline().get("ssl");

io.netty.util.concurrent.Future<Channel> fut = sslHandler.handshakeFuture();
fut.addListener(future2 -> {
if (future2.isSuccess()) {
connected(context, ch, connectHandler, remoteAddress);
} else {
failed(context, ch, future2.cause(), connectHandler);
}
});
} else {
connected(context, ch, connectHandler, remoteAddress);
}

connected(context, ch, connectHandler, remoteAddress);
} else {
if (remainingAttempts > 0 || remainingAttempts == -1) {
Throwable cause = res.cause();
if (cause instanceof ConnectException && (remainingAttempts > 0 || remainingAttempts == -1)) {
context.executeFromIO(v -> {
log.debug("Failed to create connection. Will retry in " + options.getReconnectInterval() + " milliseconds");
//Set a timer to retry connection
Expand All @@ -224,12 +202,12 @@ protected void doConnect(SocketAddress remoteAddress, String serverName, Handler
);
});
} else {
failed(context, null, res.cause(), connectHandler);
failed(context, null, cause, connectHandler);
}
}
};

channelProvider.connect(context, bootstrap, options.getProxyOptions(), remoteAddress, channelInitializer, channelHandler);
channelProvider.connect(sslHelper.isSSL(), remoteAddress, remoteAddress.host(), false, channelHandler);
}

private void connected(ContextInternal context, Channel ch, Handler<AsyncResult<NetSocket>> connectHandler, SocketAddress remoteAddress) {
Expand Down
21 changes: 7 additions & 14 deletions src/main/java/io/vertx/core/net/impl/ProxyChannelProvider.java
Expand Up @@ -43,18 +43,12 @@
*/
public class ProxyChannelProvider extends ChannelProvider {

public static final ChannelProvider INSTANCE = new ProxyChannelProvider();

private ProxyChannelProvider() {
public ProxyChannelProvider(Bootstrap bootstrap, SSLHelper sslHelper, ContextInternal context, ProxyOptions options) {
super(bootstrap, sslHelper, context, options);
}

@Override
public void doConnect(ContextInternal context,
Bootstrap bootstrap,
ProxyOptions options,
SocketAddress remoteAddress,
Handler<Channel> channelInitializer,
Handler<AsyncResult<Channel>> channelHandler) {
public void doConnect(boolean ssl, SocketAddress remoteAddress, String peerHost, boolean forceSNI, Handler<AsyncResult<Channel>> channelHandler) {

final VertxInternal vertx = context.owner();
final String proxyHost = options.getHost();
Expand Down Expand Up @@ -96,19 +90,18 @@ protected void initChannel(Channel ch) throws Exception {
pipeline.addFirst("proxy", proxy);
pipeline.addLast(new ChannelInboundHandlerAdapter() {
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt instanceof ProxyConnectionEvent) {
pipeline.remove(proxy);
pipeline.remove(this);
channelInitializer.handle(ch);
channelHandler.handle(Future.succeededFuture(ch));
initialize(ssl, remoteAddress, peerHost, forceSNI, ch);
cont(ch, channelHandler);
}
ctx.fireUserEventTriggered(evt);
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
throws Exception {
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
channelHandler.handle(Future.failedFuture(cause));
}
});
Expand Down
2 changes: 1 addition & 1 deletion src/test/java/io/vertx/core/http/HttpTLSTest.java
Expand Up @@ -1129,7 +1129,7 @@ TLSTest run(boolean shouldPass) {
AtomicInteger count = new AtomicInteger();
server.exceptionHandler(err -> {
if (shouldPass) {
fail();
HttpTLSTest.this.fail(err);
} else {
if (count.incrementAndGet() == 1) {
complete();
Expand Down
1 change: 1 addition & 0 deletions src/test/java/io/vertx/core/net/NetTest.java
Expand Up @@ -41,6 +41,7 @@
import io.vertx.test.proxy.SocksProxy;
import io.vertx.test.proxy.TestProxyBase;
import org.junit.Assume;
import org.junit.Ignore;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
Expand Down

0 comments on commit 669da46

Please sign in to comment.