Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ALPN for TDSS connections #1795

Merged
merged 2 commits into from
Apr 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
import javax.net.SocketFactory;
import javax.net.ssl.KeyManager;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
Expand Down Expand Up @@ -1763,6 +1764,11 @@ else if (con.getTrustManagerClass() != null) {

if (isTDSS) {
sslSocket = (SSLSocket) sslContext.getSocketFactory().createSocket(host, port);

// set ALPN values
SSLParameters sslParam = sslSocket.getSSLParameters();
sslParam.setApplicationProtocols(new String[] {"tds", "/", "8.0"});
sslSocket.setSSLParameters(sslParam);
} else {
// don't close proxy when SSL socket is closed
sslSocket = (SSLSocket) sslContext.getSocketFactory().createSocket(proxySocket, host, port, false);
Expand All @@ -1772,18 +1778,26 @@ else if (con.getTrustManagerClass() != null) {
if (logger.isLoggable(Level.FINER))
logger.finer(toString() + " Starting SSL handshake");

// TLS 1.2 intermittent exception happens here.
// TLS 1.2 intermittent exception may happen here.
handshakeState = SSLHandhsakeState.SSL_HANDHSAKE_STARTED;
sslSocket.startHandshake();
handshakeState = SSLHandhsakeState.SSL_HANDHSAKE_COMPLETE;

// After SSL handshake is complete, rewire proxy socket to use raw TCP/IP streams ...
if (isTDSS) {
if (logger.isLoggable(Level.FINEST)) {
String negotiatedProtocol = sslSocket.getApplicationProtocol();
logger.finest(toString() + " Application Protocol negotiated: "
+ ((negotiatedProtocol == null) ? "null" : negotiatedProtocol));
}
}

// After SSL handshake is complete, re-wire proxy socket to use raw TCP/IP streams ...
if (logger.isLoggable(Level.FINEST))
logger.finest(toString() + " Rewiring proxy streams after handshake");

proxySocket.setStreams(inputStream, outputStream);

// ... and rewire TDSChannel to use SSL streams.
// ... and re-wire TDSChannel to use SSL streams.
if (logger.isLoggable(Level.FINEST))
logger.finest(toString() + " Getting SSL InputStream");

Expand Down Expand Up @@ -2669,15 +2683,17 @@ private SocketFactory getSocketFactory() throws IOException {

/**
* Helper function which traverses through list of InetAddresses to find a resolved one
* @param hostName
*
* @param hostName
*
* @param portNumber
* Port Number
* @return First resolved address or unresolved address if none found
* @throws IOException
* @throws SQLServerException
*/
private InetSocketAddress getInetAddressByIPPreference(String hostName, int portNumber) throws IOException, SQLServerException {
private InetSocketAddress getInetAddressByIPPreference(String hostName,
int portNumber) throws IOException, SQLServerException {
InetSocketAddress addr = InetSocketAddress.createUnresolved(hostName, portNumber);
for (int i = 0; i < addressList.size(); i++) {
addr = new InetSocketAddress(addressList.get(i), portNumber);
Expand Down Expand Up @@ -2760,8 +2776,11 @@ private Socket getSocketByIPPreference(String hostName, int portNumber, int time

/**
* Fills static array of IP Addresses with addresses of the preferred protocol version.
* @param addresses Array of all addresses
* @param ipv6first Boolean switch for IPv6 first
*
* @param addresses
* Array of all addresses
* @param ipv6first
* Boolean switch for IPv6 first
*/
private void fillAddressList(InetAddress[] addresses, boolean ipv6first) {
addressList.clear();
Expand Down
38 changes: 21 additions & 17 deletions src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -1765,14 +1765,14 @@ private void registerKeyStoreProviderOnConnection(String keyStoreAuth, String ke
}
}

private void setKeyStoreSecretAndLocation(String keyStoreSecret, String keyStoreLocation) throws SQLServerException {
private void setKeyStoreSecretAndLocation(String keyStoreSecret,
String keyStoreLocation) throws SQLServerException {
// both secret and location must be set for JKS.
if ((null == keyStoreSecret) || (null == keyStoreLocation)) {
throw new SQLServerException(
SQLServerException.getErrString("R_keyStoreSecretOrLocationNotSet"), null);
throw new SQLServerException(SQLServerException.getErrString("R_keyStoreSecretOrLocationNotSet"), null);
} else {
SQLServerColumnEncryptionJavaKeyStoreProvider provider = new SQLServerColumnEncryptionJavaKeyStoreProvider(
keyStoreLocation, keyStoreSecret.toCharArray());
keyStoreLocation, keyStoreSecret.toCharArray());
systemColumnEncryptionKeyStoreProvider.put(provider.getName(), provider);
}
}
Expand Down Expand Up @@ -1959,14 +1959,15 @@ Connection connectInternal(Properties propsIn,
if (null != sPropValuePort) {
trustedServerNameAE += ":" + sPropValuePort;
}

sPropKey = SQLServerDriverStringProperty.IPADDRESS_PREFERENCE.toString();
sPropValue = activeConnectionProperties.getProperty(sPropKey);
if (null == sPropValue) {
sPropValue = SQLServerDriverStringProperty.IPADDRESS_PREFERENCE.getDefaultValue();
activeConnectionProperties.setProperty(sPropKey, sPropValue);
} else {
activeConnectionProperties.setProperty(sPropKey, IPAddressPreference.valueOfString(sPropValue).toString());
activeConnectionProperties.setProperty(sPropKey,
IPAddressPreference.valueOfString(sPropValue).toString());
}

sPropKey = SQLServerDriverStringProperty.APPLICATION_NAME.toString();
Expand Down Expand Up @@ -2016,8 +2017,8 @@ Connection connectInternal(Properties propsIn,

// enclave requires columnEncryption=enabled, enclaveAttestationUrl and enclaveAttestationProtocol
if (
// An attestation URL requires a protocol
(null != enclaveAttestationUrl && !enclaveAttestationUrl.isEmpty()
// An attestation URL requires a protocol
(null != enclaveAttestationUrl && !enclaveAttestationUrl.isEmpty()
&& (null == enclaveAttestationProtocol || enclaveAttestationProtocol.isEmpty()))

// An attestation protocol that is not NONE requires a URL
Expand Down Expand Up @@ -2073,7 +2074,8 @@ Connection connectInternal(Properties propsIn,
SQLServerException.getErrString("R_keyVaultProviderClientKeyNotSet"), null);
}
String keyVaultColumnEncryptionProviderClientKey = sPropValue;
setKeyVaultProvider(keyVaultColumnEncryptionProviderClientId, keyVaultColumnEncryptionProviderClientKey);
setKeyVaultProvider(keyVaultColumnEncryptionProviderClientId,
keyVaultColumnEncryptionProviderClientKey);
}

sPropKey = SQLServerDriverBooleanProperty.MULTI_SUBNET_FAILOVER.toString();
Expand Down Expand Up @@ -3265,7 +3267,8 @@ private InetSocketAddress connectHelper(ServerPortPlaceHolder serverInfo, int ti

// if the timeout is infinite slices are infinite too.
tdsChannel = new TDSChannel(this);
String iPAddressPreference = activeConnectionProperties.getProperty(SQLServerDriverStringProperty.IPADDRESS_PREFERENCE.toString());
String iPAddressPreference = activeConnectionProperties
.getProperty(SQLServerDriverStringProperty.IPADDRESS_PREFERENCE.toString());

InetSocketAddress inetSocketAddress = tdsChannel.open(serverInfo.getParsedServerName(),
serverInfo.getPortNumber(), (0 == timeOutFullInSeconds) ? 0 : timeOutSliceInMillis, useParallel,
Expand Down Expand Up @@ -3688,7 +3691,7 @@ void prelogin(String serverName, int portNumber) throws SQLServerException {
// If we say we don't support SSL and the server doesn't accept unencrypted connections,
// then terminate the connection.
if (TDS.ENCRYPT_NOT_SUP == requestedEncryptionLevel
&& TDS.ENCRYPT_NOT_SUP != negotiatedEncryptionLevel) {
&& TDS.ENCRYPT_NOT_SUP != negotiatedEncryptionLevel && !isTDSS) {
// If the server required an encrypted connection then terminate with an appropriate error.
if (TDS.ENCRYPT_REQ == negotiatedEncryptionLevel)
terminate(SQLServerException.DRIVER_ERROR_SSL_FAILED,
Expand Down Expand Up @@ -4558,8 +4561,8 @@ int writeAEFeatureRequest(boolean write, /* if false just calculates the length
if (write) {
tdsWriter.writeByte(TDS.TDS_FEATURE_EXT_AE); // FEATUREEXT_TC
tdsWriter.writeInt(1); // length of version
if (null == enclaveAttestationUrl || enclaveAttestationUrl.isEmpty() || (enclaveAttestationProtocol != null
&& !enclaveAttestationProtocol.equalsIgnoreCase(AttestationProtocol.NONE.toString()))) {
if (null == enclaveAttestationUrl || enclaveAttestationUrl.isEmpty() || (enclaveAttestationProtocol != null
&& !enclaveAttestationProtocol.equalsIgnoreCase(AttestationProtocol.NONE.toString()))) {
tdsWriter.writeByte(TDS.COLUMNENCRYPTION_VERSION1);
} else {
tdsWriter.writeByte(TDS.COLUMNENCRYPTION_VERSION2);
Expand Down Expand Up @@ -5678,8 +5681,8 @@ private void onFeatureExtAck(byte featureId, byte[] data) throws SQLServerExcept

serverColumnEncryptionVersion = ColumnEncryptionVersion.AE_V1;

if (null != enclaveAttestationUrl || (enclaveAttestationProtocol != null
&& enclaveAttestationProtocol.equalsIgnoreCase(AttestationProtocol.NONE.toString()))) {
if (null != enclaveAttestationUrl || (enclaveAttestationProtocol != null
&& enclaveAttestationProtocol.equalsIgnoreCase(AttestationProtocol.NONE.toString()))) {
if (aeVersion < TDS.COLUMNENCRYPTION_VERSION2) {
throw new SQLServerException(SQLServerException.getErrString("R_enclaveNotSupported"), null);
} else {
Expand Down Expand Up @@ -7666,8 +7669,9 @@ String getServerName() {

@Override
public void setIPAddressPreference(String iPAddressPreference) {
activeConnectionProperties.setProperty(SQLServerDriverStringProperty.IPADDRESS_PREFERENCE.toString(), iPAddressPreference);

activeConnectionProperties.setProperty(SQLServerDriverStringProperty.IPADDRESS_PREFERENCE.toString(),
iPAddressPreference);

}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ protected Object[][] getContents() {
{"R_InvalidIPAddressPreference", "IP address preference {0} is not valid."},
{"R_UnableLoadAuthDll", "Unable to load authentication DLL {0}"},
{"R_illegalArgumentTrustManager", "Interal error. Peer certificate chain or key exchange algorithem can not be null or empty."},
{"R_serverCertError", "Error validating Server Certificate: {0}."},
{"R_serverCertError", "Error validating Server Certificate: {0}: {1}."},
{"R_SecureStringInitFailed", "Failed to initialize SecureStringUtil to store secure strings"},
};
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package com.microsoft.sqlserver.jdbc;

import java.security.cert.CertificateException;
import java.security.cert.CertificateExpiredException;
import java.security.cert.X509Certificate;
import java.text.MessageFormat;
import java.util.Locale;
Expand Down Expand Up @@ -154,8 +153,8 @@ public void checkServerTrusted(X509Certificate[] chain, String authType) throws
}
} catch (Exception e) {
MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_serverCertError"));
Object[] msgArgs = {serverCert, e.getMessage()};
throw new CertificateExpiredException(form.format(msgArgs));
Object[] msgArgs = {serverCert != null ? serverCert : hostName, e.getMessage()};
throw new CertificateException(form.format(msgArgs));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,14 @@ public static void setupTests() throws Exception {

/**
* Test connection properties with SQLServerDataSource
*
* @throws SQLServerException
*/
@Test
public void testDataSource() {
public void testDataSource() throws SQLServerException {
SQLServerDataSource ds = new SQLServerDataSource();
String stringPropValue = "stringPropValue";
boolean booleanPropValue = true;
String booleanStringValue = "true";
int intPropValue = 1;

ds.setInstanceName(stringPropValue);
Expand All @@ -82,7 +83,7 @@ public void testDataSource() {

ds.setPortNumber(intPropValue);
assertEquals(intPropValue, ds.getPortNumber(), TestResource.getResource("R_valuesAreDifferent"));

ds.setIPAddressPreference(stringPropValue);
assertEquals(stringPropValue, ds.getIPAddressPreference(), TestResource.getResource("R_valuesAreDifferent"));

Expand Down Expand Up @@ -162,8 +163,29 @@ public void testDataSource() {
ds.setTrustStorePassword(stringPropValue);
assertEquals(stringPropValue, ds.getTrustStorePassword(), TestResource.getResource("R_valuesAreDifferent"));

ds.setEncrypt(booleanStringValue);
assertEquals(booleanStringValue, ds.getEncrypt(), TestResource.getResource("R_valuesAreDifferent"));
// verify encrypt=true options
ds.setEncrypt(EncryptOption.Mandatory.toString());
assertEquals("True", EncryptOption.valueOfString(ds.getEncrypt()).toString(),
TestResource.getResource("R_valuesAreDifferent"));
ds.setEncrypt(EncryptOption.True.toString());
assertEquals("True", EncryptOption.valueOfString(ds.getEncrypt()).toString(),
TestResource.getResource("R_valuesAreDifferent"));

// verify encrypt=false options
ds.setEncrypt(EncryptOption.Optional.toString());
assertEquals("False", EncryptOption.valueOfString(ds.getEncrypt()).toString(),
TestResource.getResource("R_valuesAreDifferent"));
ds.setEncrypt(EncryptOption.False.toString());
assertEquals("False", EncryptOption.valueOfString(ds.getEncrypt()).toString(),
TestResource.getResource("R_valuesAreDifferent"));
ds.setEncrypt(EncryptOption.No.toString());
assertEquals("False", EncryptOption.valueOfString(ds.getEncrypt()).toString(),
TestResource.getResource("R_valuesAreDifferent"));

// verify enrypt=strict options
ds.setEncrypt(EncryptOption.Strict.toString());
assertEquals("Strict", EncryptOption.valueOfString(ds.getEncrypt()).toString(),
TestResource.getResource("R_valuesAreDifferent"));

ds.setEncrypt(booleanPropValue);
assertEquals(Boolean.toString(booleanPropValue), ds.getEncrypt(),
Expand Down