Skip to content

Commit

Permalink
move TlsExtensionType into Platform
Browse files Browse the repository at this point in the history
  • Loading branch information
ericgribkoff committed Jan 19, 2018
1 parent b3350be commit 8a71766
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 209 deletions.
2 changes: 1 addition & 1 deletion SECURITY.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ import org.conscrypt.Conscrypt;
import java.security.Security;
...

Security.addProvider(Conscrypt.newProvider());
Security.insertProviderAt(Conscrypt.newProvider(), 1);
```

## TLS with OpenSSL
Expand Down
65 changes: 8 additions & 57 deletions okhttp/src/main/java/io/grpc/okhttp/OkHttpProtocolNegotiator.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
import com.google.common.annotations.VisibleForTesting;
import io.grpc.okhttp.internal.OptionalMethod;
import io.grpc.okhttp.internal.Platform;
import io.grpc.okhttp.internal.Platform.TlsExtensionType;
import io.grpc.okhttp.internal.Protocol;
import io.grpc.okhttp.internal.Util;
import java.io.IOException;
import java.net.Socket;
import java.security.Security;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
Expand All @@ -41,7 +41,7 @@ class OkHttpProtocolNegotiator {
private static OkHttpProtocolNegotiator NEGOTIATOR =
createNegotiator(OkHttpProtocolNegotiator.class.getClassLoader());

private final Platform platform;
protected final Platform platform;

@VisibleForTesting
OkHttpProtocolNegotiator(Platform platform) {
Expand Down Expand Up @@ -72,7 +72,7 @@ static OkHttpProtocolNegotiator createNegotiator(ClassLoader loader) {
}
}
return android
? new AndroidNegotiator(DEFAULT_PLATFORM, AndroidNegotiator.DEFAULT_TLS_EXTENSION_TYPE)
? new AndroidNegotiator(DEFAULT_PLATFORM)
: new OkHttpProtocolNegotiator(DEFAULT_PLATFORM);
}

Expand Down Expand Up @@ -133,19 +133,8 @@ static final class AndroidNegotiator extends OkHttpProtocolNegotiator {
private static final OptionalMethod<Socket> SET_NPN_PROTOCOLS =
new OptionalMethod<Socket>(null, "setNpnProtocols", byte[].class);

private static final TlsExtensionType DEFAULT_TLS_EXTENSION_TYPE =
pickTlsExtensionType(AndroidNegotiator.class.getClassLoader());

enum TlsExtensionType {
ALPN_AND_NPN,
NPN,
}

private final TlsExtensionType tlsExtensionType;

AndroidNegotiator(Platform platform, TlsExtensionType tlsExtensionType) {
AndroidNegotiator(Platform platform) {
super(platform);
this.tlsExtensionType = checkNotNull(tlsExtensionType, "Unable to pick a TLS extension");
}

@Override
Expand Down Expand Up @@ -174,11 +163,11 @@ protected void configureTlsExtensions(
}

Object[] parameters = {Platform.concatLengthPrefixed(protocols)};
if (tlsExtensionType == TlsExtensionType.ALPN_AND_NPN) {
if (platform.getTlsExtensionType() == TlsExtensionType.ALPN_AND_NPN) {
SET_ALPN_PROTOCOLS.invokeWithoutCheckedException(sslSocket, parameters);
}

if (tlsExtensionType != null) {
if (platform.getTlsExtensionType() != TlsExtensionType.NONE) {
SET_NPN_PROTOCOLS.invokeWithoutCheckedException(sslSocket, parameters);
} else {
throw new RuntimeException("We can not do TLS handshake on this Android version, please"
Expand All @@ -188,7 +177,7 @@ protected void configureTlsExtensions(

@Override
public String getSelectedProtocol(SSLSocket socket) {
if (tlsExtensionType == TlsExtensionType.ALPN_AND_NPN) {
if (platform.getTlsExtensionType() == TlsExtensionType.ALPN_AND_NPN) {
try {
byte[] alpnResult =
(byte[]) GET_ALPN_SELECTED_PROTOCOL.invokeWithoutCheckedException(socket);
Expand All @@ -201,7 +190,7 @@ public String getSelectedProtocol(SSLSocket socket) {
}
}

if (tlsExtensionType != null) {
if (platform.getTlsExtensionType() != TlsExtensionType.NONE) {
try {
byte[] npnResult =
(byte[]) GET_NPN_SELECTED_PROTOCOL.invokeWithoutCheckedException(socket);
Expand All @@ -215,43 +204,5 @@ public String getSelectedProtocol(SSLSocket socket) {
}
return null;
}

@VisibleForTesting
static TlsExtensionType pickTlsExtensionType(ClassLoader loader) {
// Decide which TLS Extension (APLN and NPN) we will use, follow the rules:
// 1. If Google Play Services Security Provider is installed, use both
// 2. If Conscrypt is installed, use both
// 3. If on Android 5.0 or later, use both, else
// 4. If on Android 4.1 or later, use NPN, else
// 5. Fail.

// Check if Google Play Services Security Provider is installed.
if (Security.getProvider("GmsCore_OpenSSL") != null) {
return TlsExtensionType.ALPN_AND_NPN;
}

if (Security.getProvider("Conscrypt") != null) {
return TlsExtensionType.ALPN_AND_NPN;
}

// Check if on Android 5.0 or later.
try {
loader.loadClass("android.net.Network"); // Arbitrary class added in Android 5.0.
return TlsExtensionType.ALPN_AND_NPN;
} catch (ClassNotFoundException e) {
logger.log(Level.FINE, "Can't find class", e);
}

// Check if on Android 4.1 or later.
try {
loader.loadClass("android.app.ActivityOptions"); // Arbitrary class added in Android 4.1.
return TlsExtensionType.NPN;
} catch (ClassNotFoundException e) {
logger.log(Level.FINE, "Can't find class", e);
}

// This will be caught by the constructor.
return null;
}
}
}
122 changes: 7 additions & 115 deletions okhttp/src/test/java/io/grpc/okhttp/OkHttpProtocolNegotiatorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import static com.google.common.base.Charsets.UTF_8;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
Expand All @@ -28,17 +27,13 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import io.grpc.okhttp.OkHttpProtocolNegotiator.AndroidNegotiator;
import io.grpc.okhttp.OkHttpProtocolNegotiator.AndroidNegotiator.TlsExtensionType;
import io.grpc.okhttp.internal.Platform;
import io.grpc.okhttp.internal.Platform.TlsExtensionType;
import io.grpc.okhttp.internal.Protocol;
import java.io.IOException;
import java.security.Provider;
import java.security.Security;
import javax.net.ssl.HandshakeCompletedListener;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocket;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
Expand All @@ -53,22 +48,9 @@
public class OkHttpProtocolNegotiatorTest {
@Rule public final ExpectedException thrown = ExpectedException.none();

private final Provider fakeSecurityProvider = new Provider("GmsCore_OpenSSL", 1.0, "info") {};
private final Provider fakeConscrypt = new Provider("Conscrypt", 1.0, "info") {};
private final SSLSocket sock = mock(SSLSocket.class);
private final Platform platform = mock(Platform.class);

@Before
public void setUp() {
// Tests that depend on android need this to know which protocol negotiation to use.
Security.addProvider(fakeSecurityProvider);
}

@After
public void tearDown() {
Security.removeProvider(fakeSecurityProvider.getName());
}

@Test
public void createNegotiator_isAndroid() {
ClassLoader cl = new ClassLoader(this.getClass().getClassLoader()) {
Expand Down Expand Up @@ -180,103 +162,11 @@ public void negotiate_preferGrpcExp() throws Exception {
verify(platform).afterHandshake(sock);
}

@Test
public void pickTlsExtensionType_securityProvider() throws Exception {
assertNotNull(Security.getProvider(fakeSecurityProvider.getName()));

AndroidNegotiator.TlsExtensionType tlsExtensionType =
AndroidNegotiator.pickTlsExtensionType(getClass().getClassLoader());

assertEquals(TlsExtensionType.ALPN_AND_NPN, tlsExtensionType);
}

@Test
public void pickTlsExtensionType_android50() throws Exception {
Security.removeProvider(fakeSecurityProvider.getName());
ClassLoader cl = new ClassLoader(this.getClass().getClassLoader()) {
@Override
protected Class<?> findClass(String name) throws ClassNotFoundException {
// Just don't throw.
if ("android.net.Network".equals(name)) {
return null;
}
return super.findClass(name);
}
};

AndroidNegotiator.TlsExtensionType tlsExtensionType =
AndroidNegotiator.pickTlsExtensionType(cl);

assertEquals(TlsExtensionType.ALPN_AND_NPN, tlsExtensionType);
}

@Test
public void pickTlsExtensionType_android41() throws Exception {
Security.removeProvider(fakeSecurityProvider.getName());
ClassLoader cl = new ClassLoader(this.getClass().getClassLoader()) {
@Override
protected Class<?> findClass(String name) throws ClassNotFoundException {
// Just don't throw.
if ("android.app.ActivityOptions".equals(name)) {
return null;
}
return super.findClass(name);
}
};

AndroidNegotiator.TlsExtensionType tlsExtensionType =
AndroidNegotiator.pickTlsExtensionType(cl);

assertEquals(TlsExtensionType.NPN, tlsExtensionType);
}

@Test
public void pickTlsExtensionType_android41WithConscrypt() throws Exception {
Security.removeProvider(fakeSecurityProvider.getName());
Security.addProvider(fakeConscrypt);
ClassLoader cl =
new ClassLoader(this.getClass().getClassLoader()) {
@Override
protected Class<?> findClass(String name) throws ClassNotFoundException {
// Just don't throw.
if ("android.app.ActivityOptions".equals(name)) {
return null;
}
return super.findClass(name);
}
};

AndroidNegotiator.TlsExtensionType tlsExtensionType =
AndroidNegotiator.pickTlsExtensionType(cl);

assertEquals(TlsExtensionType.ALPN_AND_NPN, tlsExtensionType);

// Clean up
Security.removeProvider(fakeConscrypt.getName());
}

@Test
public void pickTlsExtensionType_none() throws Exception {
Security.removeProvider(fakeSecurityProvider.getName());

AndroidNegotiator.TlsExtensionType tlsExtensionType =
AndroidNegotiator.pickTlsExtensionType(getClass().getClassLoader());

assertNull(tlsExtensionType);
}

@Test
public void androidNegotiator_failsOnNull() {
thrown.expect(NullPointerException.class);
thrown.expectMessage("Unable to pick a TLS extension");

new AndroidNegotiator(platform, null);
}

// Checks that the super class is properly invoked.
@Test
public void negotiate_android_handshakeFails() throws Exception {
AndroidNegotiator negotiator = new AndroidNegotiator(platform, TlsExtensionType.ALPN_AND_NPN);
when(platform.getTlsExtensionType()).thenReturn(TlsExtensionType.ALPN_AND_NPN);
AndroidNegotiator negotiator = new AndroidNegotiator(platform);

FakeAndroidSslSocket androidSock = new FakeAndroidSslSocket() {
@Override
Expand All @@ -301,7 +191,8 @@ public byte[] getAlpnSelectedProtocol() {

@Test
public void getSelectedProtocol_alpn() throws Exception {
AndroidNegotiator negotiator = new AndroidNegotiator(platform, TlsExtensionType.ALPN_AND_NPN);
when(platform.getTlsExtensionType()).thenReturn(TlsExtensionType.ALPN_AND_NPN);
AndroidNegotiator negotiator = new AndroidNegotiator(platform);
FakeAndroidSslSocket androidSock = new FakeAndroidSslSocketAlpn();

String actual = negotiator.getSelectedProtocol(androidSock);
Expand All @@ -319,7 +210,8 @@ public byte[] getNpnSelectedProtocol() {

@Test
public void getSelectedProtocol_npn() throws Exception {
AndroidNegotiator negotiator = new AndroidNegotiator(platform, TlsExtensionType.NPN);
when(platform.getTlsExtensionType()).thenReturn(TlsExtensionType.NPN);
AndroidNegotiator negotiator = new AndroidNegotiator(platform);
FakeAndroidSslSocket androidSock = new FakeAndroidSslSocketNpn();

String actual = negotiator.getSelectedProtocol(androidSock);
Expand Down
Loading

0 comments on commit 8a71766

Please sign in to comment.