Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

okhttp: add socketFactory method to channel builder #5378

Merged
merged 5 commits into from
Feb 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
46 changes: 37 additions & 9 deletions okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
import javax.net.SocketFactory;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
Expand Down Expand Up @@ -120,6 +121,7 @@ public static OkHttpChannelBuilder forTarget(String target) {
private Executor transportExecutor;
private ScheduledExecutorService scheduledExecutorService;

private SocketFactory socketFactory;
private SSLSocketFactory sslSocketFactory;
private HostnameVerifier hostnameVerifier;
private ConnectionSpec connectionSpec = INTERNAL_DEFAULT_CONNECTION_SPEC;
Expand Down Expand Up @@ -156,6 +158,17 @@ public final OkHttpChannelBuilder transportExecutor(@Nullable Executor transport
return this;
}

/**
* Override the default {@link SocketFactory} used to create sockets. If the socket factory is not
* set or set to null, a default one will be used.
*
* @since 1.20.0
*/
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add @since 1.20.0

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

public final OkHttpChannelBuilder socketFactory(@Nullable SocketFactory socketFactory) {
this.socketFactory = socketFactory;
return this;
}

/**
* Sets the negotiation type for the HTTP/2 connection.
*
Expand Down Expand Up @@ -397,10 +410,21 @@ public OkHttpChannelBuilder maxInboundMetadataSize(int bytes) {
@Internal
protected final ClientTransportFactory buildTransportFactory() {
boolean enableKeepAlive = keepAliveTimeNanos != KEEPALIVE_TIME_NANOS_DISABLED;
return new OkHttpTransportFactory(transportExecutor, scheduledExecutorService,
createSocketFactory(), hostnameVerifier, connectionSpec, maxInboundMessageSize(),
enableKeepAlive, keepAliveTimeNanos, keepAliveTimeoutNanos, flowControlWindow,
keepAliveWithoutCalls, maxInboundMetadataSize, transportTracerFactory);
return new OkHttpTransportFactory(
transportExecutor,
scheduledExecutorService,
socketFactory,
createSslSocketFactory(),
hostnameVerifier,
connectionSpec,
maxInboundMessageSize(),
enableKeepAlive,
keepAliveTimeNanos,
keepAliveTimeoutNanos,
flowControlWindow,
keepAliveWithoutCalls,
maxInboundMetadataSize,
transportTracerFactory);
}

@Override
Expand All @@ -417,7 +441,7 @@ protected int getDefaultPort() {

@VisibleForTesting
@Nullable
SSLSocketFactory createSocketFactory() {
SSLSocketFactory createSslSocketFactory() {
switch (negotiationType) {
case TLS:
try {
Expand Down Expand Up @@ -463,8 +487,8 @@ static final class OkHttpTransportFactory implements ClientTransportFactory {
private final boolean usingSharedExecutor;
private final boolean usingSharedScheduler;
private final TransportTracer.Factory transportTracerFactory;
@Nullable
private final SSLSocketFactory socketFactory;
private final SocketFactory socketFactory;
@Nullable private final SSLSocketFactory sslSocketFactory;
@Nullable
private final HostnameVerifier hostnameVerifier;
private final ConnectionSpec connectionSpec;
Expand All @@ -478,9 +502,11 @@ static final class OkHttpTransportFactory implements ClientTransportFactory {
private final ScheduledExecutorService timeoutService;
private boolean closed;

private OkHttpTransportFactory(Executor executor,
private OkHttpTransportFactory(
Executor executor,
@Nullable ScheduledExecutorService timeoutService,
@Nullable SSLSocketFactory socketFactory,
@Nullable SocketFactory socketFactory,
@Nullable SSLSocketFactory sslSocketFactory,
@Nullable HostnameVerifier hostnameVerifier,
ConnectionSpec connectionSpec,
int maxMessageSize,
Expand All @@ -495,6 +521,7 @@ private OkHttpTransportFactory(Executor executor,
this.timeoutService = usingSharedScheduler
? SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE) : timeoutService;
this.socketFactory = socketFactory;
this.sslSocketFactory = sslSocketFactory;
this.hostnameVerifier = hostnameVerifier;
this.connectionSpec = connectionSpec;
this.maxMessageSize = maxMessageSize;
Expand Down Expand Up @@ -536,6 +563,7 @@ public void run() {
options.getUserAgent(),
executor,
socketFactory,
sslSocketFactory,
hostnameVerifier,
connectionSpec,
maxMessageSize,
Expand Down
35 changes: 27 additions & 8 deletions okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
import java.util.logging.Logger;
import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;
import javax.net.SocketFactory;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocket;
Expand Down Expand Up @@ -175,6 +176,7 @@ private static Map<ErrorCode, Status> buildErrorCodeToStatusMap() {
private boolean stopped;
@GuardedBy("lock")
private boolean hasStream;
private final SocketFactory socketFactory;
private SSLSocketFactory sslSocketFactory;
private HostnameVerifier hostnameVerifier;
private Socket socket;
Expand Down Expand Up @@ -219,12 +221,21 @@ protected void handleNotInUse() {
Runnable connectingCallback;
SettableFuture<Void> connectedFuture;

OkHttpClientTransport(InetSocketAddress address, String authority, @Nullable String userAgent,
Executor executor, @Nullable SSLSocketFactory sslSocketFactory,
@Nullable HostnameVerifier hostnameVerifier, ConnectionSpec connectionSpec,
int maxMessageSize, int initialWindowSize,
OkHttpClientTransport(
InetSocketAddress address,
String authority,
@Nullable String userAgent,
Executor executor,
@Nullable SocketFactory socketFactory,
@Nullable SSLSocketFactory sslSocketFactory,
@Nullable HostnameVerifier hostnameVerifier,
ConnectionSpec connectionSpec,
int maxMessageSize,
int initialWindowSize,
@Nullable HttpConnectProxiedSocketAddress proxiedAddr,
Runnable tooManyPingsRunnable, int maxInboundMetadataSize, TransportTracer transportTracer) {
Runnable tooManyPingsRunnable,
int maxInboundMetadataSize,
TransportTracer transportTracer) {
this.address = Preconditions.checkNotNull(address, "address");
this.defaultAuthority = authority;
this.maxMessageSize = maxMessageSize;
Expand All @@ -234,6 +245,7 @@ protected void handleNotInUse() {
// Client initiated streams are odd, server initiated ones are even. Server should not need to
// use it. We start clients at 3 to avoid conflicting with HTTP negotiation.
nextStreamId = 3;
this.socketFactory = socketFactory == null ? SocketFactory.getDefault() : socketFactory;
this.sslSocketFactory = sslSocketFactory;
this.hostnameVerifier = hostnameVerifier;
this.connectionSpec = Preconditions.checkNotNull(connectionSpec, "connectionSpec");
Expand Down Expand Up @@ -273,6 +285,7 @@ protected void handleNotInUse() {
this.userAgent = GrpcUtil.getGrpcUserAgent("okhttp", userAgent);
this.executor = Preconditions.checkNotNull(executor, "executor");
serializingExecutor = new SerializingExecutor(executor);
this.socketFactory = SocketFactory.getDefault();
this.testFrameReader = Preconditions.checkNotNull(frameReader, "frameReader");
this.testFrameWriter = Preconditions.checkNotNull(testFrameWriter, "testFrameWriter");
this.socket = Preconditions.checkNotNull(socket, "socket");
Expand Down Expand Up @@ -506,7 +519,7 @@ public void close() {
SSLSession sslSession = null;
try {
if (proxiedAddr == null) {
sock = new Socket(address.getAddress(), address.getPort());
sock = socketFactory.createSocket(address.getAddress(), address.getPort());
} else {
if (proxiedAddr.getProxyAddress() instanceof InetSocketAddress) {
sock = createHttpProxySocket(
Expand Down Expand Up @@ -584,9 +597,10 @@ private Socket createHttpProxySocket(InetSocketAddress address, InetSocketAddres
Socket sock;
// The proxy address may not be resolved
if (proxyAddress.getAddress() != null) {
sock = new Socket(proxyAddress.getAddress(), proxyAddress.getPort());
sock = socketFactory.createSocket(proxyAddress.getAddress(), proxyAddress.getPort());
} else {
sock = new Socket(proxyAddress.getHostName(), proxyAddress.getPort());
sock =
socketFactory.createSocket(proxyAddress.getHostName(), proxyAddress.getPort());
}
sock.setTcpNoDelay(true);

Expand Down Expand Up @@ -771,6 +785,11 @@ ClientFrameHandler getHandler() {
return clientFrameHandler;
}

@VisibleForTesting
SocketFactory getSocketFactory() {
return socketFactory;
}

@VisibleForTesting
int getPendingStreamSize() {
synchronized (lock) {
Expand Down
60 changes: 57 additions & 3 deletions okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,11 @@
import io.grpc.internal.FakeClock;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.SharedResourceHolder;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.util.concurrent.ScheduledExecutorService;
import javax.net.SocketFactory;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
Expand Down Expand Up @@ -125,10 +128,10 @@ public void usePlaintextDefaultPort() {
@Test
public void usePlaintextCreatesNullSocketFactory() {
OkHttpChannelBuilder builder = OkHttpChannelBuilder.forAddress("host", 1234);
assertNotNull(builder.createSocketFactory());
assertNotNull(builder.createSslSocketFactory());

builder.usePlaintext();
assertNull(builder.createSocketFactory());
assertNull(builder.createSslSocketFactory());
}

@Test
Expand Down Expand Up @@ -159,5 +162,56 @@ public void scheduledExecutorService_custom() {

clientTransportFactory.close();
}
}

@Test
public void socketFactory_default() {
OkHttpChannelBuilder builder = OkHttpChannelBuilder.forTarget("foo");
ClientTransportFactory transportFactory = builder.buildTransportFactory();
OkHttpClientTransport transport =
(OkHttpClientTransport)
transportFactory.newClientTransport(
new InetSocketAddress(5678), new ClientTransportFactory.ClientTransportOptions());

assertSame(SocketFactory.getDefault(), transport.getSocketFactory());

transportFactory.close();
}

@Test
public void socketFactory_custom() {
SocketFactory socketFactory =
new SocketFactory() {
@Override
public Socket createSocket(String s, int i) {
return null;
}

@Override
public Socket createSocket(String s, int i, InetAddress inetAddress, int i1) {
return null;
}

@Override
public Socket createSocket(InetAddress inetAddress, int i) {
return null;
}

@Override
public Socket createSocket(
InetAddress inetAddress, int i, InetAddress inetAddress1, int i1) {
return null;
}
};
OkHttpChannelBuilder builder =
OkHttpChannelBuilder.forTarget("foo").socketFactory(socketFactory);
ClientTransportFactory transportFactory = builder.buildTransportFactory();
OkHttpClientTransport transport =
(OkHttpClientTransport)
transportFactory.newClientTransport(
new InetSocketAddress(5678), new ClientTransportFactory.ClientTransportOptions());

assertSame(socketFactory, transport.getSocketFactory());

transportFactory.close();
}
}