Skip to content

Commit

Permalink
Fix resource leak in AE
Browse files Browse the repository at this point in the history
  • Loading branch information
rene-ye committed Jan 21, 2020
1 parent 5c27621 commit 427911c
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 172 deletions.
2 changes: 1 addition & 1 deletion src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java
Original file line number Diff line number Diff line change
Expand Up @@ -6202,7 +6202,7 @@ void writeRPCReaderUnicode(String sName, Reader re, long reLength, boolean bOut,

void sendEnclavePackage(String sql, ArrayList<byte[]> enclaveCEKs) throws SQLServerException {
if (null != con && con.isAEv2()) {
if (null != sql && "".equals(sql) && null != enclaveCEKs && 0 < enclaveCEKs.size() && con.enclaveEstablished()) {
if (null != sql && !"".equals(sql) && null != enclaveCEKs && 0 < enclaveCEKs.size() && con.enclaveEstablished()) {
byte[] b = con.generateEnclavePackage(sql, enclaveCEKs);
if (null != b && 0 != b.length) {
this.writeShort((short) b.length);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,47 +102,42 @@ public EnclaveSession getEnclaveSession() {
return enclaveSession;
}

private AASAttestationResponse validateAttestationResponse(AASAttestationResponse ar) throws SQLServerException {
try {
ar.validateToken(attestationURL, aasParams.getNonce());
ar.validateDHPublicKey(aasParams.getNonce());
} catch (GeneralSecurityException e) {
SQLServerException.makeFromDriverError(null, this, e.getLocalizedMessage(), "0", false);
private void validateAttestationResponse() throws SQLServerException {
if (null != hgsResponse) {
try {
hgsResponse.validateToken(attestationURL, aasParams.getNonce());
hgsResponse.validateDHPublicKey(aasParams.getNonce());
} catch (GeneralSecurityException e) {
SQLServerException.makeFromDriverError(null, this, e.getLocalizedMessage(), "0", false);
}
}
return ar;
}

private ArrayList<byte[]> describeParameterEncryption(SQLServerConnection connection, String userSql,
String preparedTypeDefinitions, Parameter[] params,
ArrayList<String> parameterNames) throws SQLServerException {
ArrayList<byte[]> enclaveRequestedCEKs = new ArrayList<>();
ResultSet rs = null;
try (PreparedStatement stmt = connection.prepareStatement(connection.enclaveEstablished() ? SDPE1 : SDPE2)) {
if (connection.enclaveEstablished()) {
rs = executeSDPEv1(stmt, userSql, preparedTypeDefinitions);
} else {
rs = executeSDPEv2(stmt, userSql, preparedTypeDefinitions, aasParams);
}
if (null == rs) {
// No results. Meaning no parameter.
// Should never happen.
return enclaveRequestedCEKs;
}
processSDPEv1(userSql, preparedTypeDefinitions, params, parameterNames, connection, stmt, rs,
enclaveRequestedCEKs);
// Process the third resultset.
if (connection.isAEv2() && stmt.getMoreResults()) {
rs = (SQLServerResultSet) stmt.getResultSet();
while (rs.next()) {
hgsResponse = new AASAttestationResponse(rs.getBytes(1));
// This validates and establishes the enclave session if valid
if (!connection.enclaveEstablished()) {
hgsResponse = validateAttestationResponse(hgsResponse);
try (ResultSet rs = connection.enclaveEstablished() ? executeSDPEv1(stmt, userSql,
preparedTypeDefinitions) : executeSDPEv2(stmt, userSql, preparedTypeDefinitions, aasParams)) {
if (null == rs) {
// No results. Meaning no parameter.
// Should never happen.
return enclaveRequestedCEKs;
}
processSDPEv1(userSql, preparedTypeDefinitions, params, parameterNames, connection, stmt, rs,
enclaveRequestedCEKs);
// Process the third resultset.
if (connection.isAEv2() && stmt.getMoreResults()) {
try (ResultSet hgsRs = (SQLServerResultSet) stmt.getResultSet()) {
if (hgsRs.next()) {
hgsResponse = new AASAttestationResponse(hgsRs.getBytes(1));
// This validates and establishes the enclave session if valid
validateAttestationResponse();
}
}
}
}
// Null check for rs is done already.
rs.close();
} catch (SQLException | IOException e) {
if (e instanceof SQLServerException) {
throw (SQLServerException) e;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -855,113 +855,106 @@ private void getParameterEncryptionMetadata(Parameter[] params) throws SQLServer
* not need to be the same as in the table definition. Also, when string is sent to an int field, the parameter
* is defined as nvarchar(<size of string>). Same for varchar datatypes, exact length is used.
*/
SQLServerResultSet rs = null;
SQLServerCallableStatement stmt = null;

assert connection != null : "Connection should not be null";

try {
try (Statement stmt = connection.prepareCall("exec sp_describe_parameter_encryption ?,?")) {
if (getStatementLogger().isLoggable(java.util.logging.Level.FINE)) {
getStatementLogger().fine(
"Calling stored procedure sp_describe_parameter_encryption to get parameter encryption information.");
}
((SQLServerCallableStatement) stmt).isInternalEncryptionQuery = true;
((SQLServerCallableStatement) stmt).setNString(1, preparedSQL);
((SQLServerCallableStatement) stmt).setNString(2, preparedTypeDefinitions);
try (ResultSet rs = ((SQLServerCallableStatement) stmt).executeQueryInternal()) {
if (null == rs) {
// No results. Meaning no parameter.
// Should never happen.
return;
}

stmt = (SQLServerCallableStatement) connection.prepareCall("exec sp_describe_parameter_encryption ?,?");
stmt.isInternalEncryptionQuery = true;
stmt.setNString(1, preparedSQL);
stmt.setNString(2, preparedTypeDefinitions);
rs = (SQLServerResultSet) stmt.executeQueryInternal();
} catch (SQLException e) {
if (e instanceof SQLServerException) {
throw (SQLServerException) e;
} else {
throw new SQLServerException(SQLServerException.getErrString("R_UnableRetrieveParameterMetadata"), null,
0, e);
}
}

if (null == rs) {
// No results. Meaning no parameter.
// Should never happen.
return;
}
Map<Integer, CekTableEntry> cekList = new HashMap<>();
CekTableEntry cekEntry = null;
while (rs.next()) {
int currentOrdinal = rs.getInt(DescribeParameterEncryptionResultSet1.KeyOrdinal.value());
if (!cekList.containsKey(currentOrdinal)) {
cekEntry = new CekTableEntry(currentOrdinal);
cekList.put(cekEntry.ordinal, cekEntry);
} else {
cekEntry = cekList.get(currentOrdinal);
}
cekEntry.add(rs.getBytes(DescribeParameterEncryptionResultSet1.EncryptedKey.value()),
rs.getInt(DescribeParameterEncryptionResultSet1.DbId.value()),
rs.getInt(DescribeParameterEncryptionResultSet1.KeyId.value()),
rs.getInt(DescribeParameterEncryptionResultSet1.KeyVersion.value()),
rs.getBytes(DescribeParameterEncryptionResultSet1.KeyMdVersion.value()),
rs.getString(DescribeParameterEncryptionResultSet1.KeyPath.value()),
rs.getString(DescribeParameterEncryptionResultSet1.ProviderName.value()),
rs.getString(DescribeParameterEncryptionResultSet1.KeyEncryptionAlgorithm.value()));
}
if (getStatementLogger().isLoggable(java.util.logging.Level.FINE)) {
getStatementLogger().fine("Matadata of CEKs is retrieved.");
}

Map<Integer, CekTableEntry> cekList = new HashMap<>();
CekTableEntry cekEntry = null;
try {
while (rs.next()) {
int currentOrdinal = rs.getInt(DescribeParameterEncryptionResultSet1.KeyOrdinal.value());
if (!cekList.containsKey(currentOrdinal)) {
cekEntry = new CekTableEntry(currentOrdinal);
cekList.put(cekEntry.ordinal, cekEntry);
} else {
cekEntry = cekList.get(currentOrdinal);
// Process the second resultset.
if (!stmt.getMoreResults()) {
throw new SQLServerException(this,
SQLServerException.getErrString("R_UnexpectedDescribeParamFormat"), null, 0, false);
}
cekEntry.add(rs.getBytes(DescribeParameterEncryptionResultSet1.EncryptedKey.value()),
rs.getInt(DescribeParameterEncryptionResultSet1.DbId.value()),
rs.getInt(DescribeParameterEncryptionResultSet1.KeyId.value()),
rs.getInt(DescribeParameterEncryptionResultSet1.KeyVersion.value()),
rs.getBytes(DescribeParameterEncryptionResultSet1.KeyMdVersion.value()),
rs.getString(DescribeParameterEncryptionResultSet1.KeyPath.value()),
rs.getString(DescribeParameterEncryptionResultSet1.ProviderName.value()),
rs.getString(DescribeParameterEncryptionResultSet1.KeyEncryptionAlgorithm.value()));
}
if (getStatementLogger().isLoggable(java.util.logging.Level.FINE)) {
getStatementLogger().fine("Matadata of CEKs is retrieved.");
}
} catch (SQLException e) {
if (e instanceof SQLServerException) {
throw (SQLServerException) e;
} else {
throw new SQLServerException(SQLServerException.getErrString("R_UnableRetrieveParameterMetadata"), null,
0, e);
}
}

// Process the second resultset.
if (!stmt.getMoreResults()) {
throw new SQLServerException(this, SQLServerException.getErrString("R_UnexpectedDescribeParamFormat"), null,
0, false);
}
// Parameter count in the result set.
int paramCount = 0;
try (ResultSet secondRs = stmt.getResultSet()) {
while (secondRs.next()) {
paramCount++;
String paramName = secondRs
.getString(DescribeParameterEncryptionResultSet2.ParameterName.value());
int paramIndex = parameterNames.indexOf(paramName);
int cekOrdinal = secondRs
.getInt(DescribeParameterEncryptionResultSet2.ColumnEncryptionKeyOrdinal.value());
cekEntry = cekList.get(cekOrdinal);

// cekEntry will be null if none of the parameters are encrypted.
if ((null != cekEntry) && (cekList.size() < cekOrdinal)) {
MessageFormat form = new MessageFormat(
SQLServerException.getErrString("R_InvalidEncryptionKeyOrdinal"));
Object[] msgArgs = {cekOrdinal, cekEntry.getSize()};
throw new SQLServerException(this, form.format(msgArgs), null, 0, false);
}
SQLServerEncryptionType encType = SQLServerEncryptionType.of((byte) secondRs
.getInt(DescribeParameterEncryptionResultSet2.ColumnEncrytionType.value()));
if (SQLServerEncryptionType.PlainText != encType) {
params[paramIndex].cryptoMeta = new CryptoMetadata(cekEntry, (short) cekOrdinal,
(byte) secondRs.getInt(
DescribeParameterEncryptionResultSet2.ColumnEncryptionAlgorithm.value()),
null, encType.value, (byte) secondRs.getInt(
DescribeParameterEncryptionResultSet2.NormalizationRuleVersion.value()));
// Decrypt the symmetric key.(This will also validate and throw if needed).
SQLServerSecurityUtility.decryptSymmetricKey(params[paramIndex].cryptoMeta, connection);
} else {
if (params[paramIndex].getForceEncryption()) {
MessageFormat form = new MessageFormat(SQLServerException
.getErrString("R_ForceEncryptionTrue_HonorAETrue_UnencryptedColumn"));
Object[] msgArgs = {userSQL, paramIndex + 1};
SQLServerException.makeFromDriverError(connection, this, form.format(msgArgs), null,
true);
}
}
}
if (getStatementLogger().isLoggable(java.util.logging.Level.FINE)) {
getStatementLogger().fine("Parameter encryption metadata is set.");
}
}

// Parameter count in the result set.
int paramCount = 0;
try {
rs = (SQLServerResultSet) stmt.getResultSet();
while (rs.next()) {
paramCount++;
String paramName = rs.getString(DescribeParameterEncryptionResultSet2.ParameterName.value());
int paramIndex = parameterNames.indexOf(paramName);
int cekOrdinal = rs.getInt(DescribeParameterEncryptionResultSet2.ColumnEncryptionKeyOrdinal.value());
cekEntry = cekList.get(cekOrdinal);

// cekEntry will be null if none of the parameters are encrypted.
if ((null != cekEntry) && (cekList.size() < cekOrdinal)) {
if (paramCount != params.length) {
// Encryption metadata wasn't sent by the server.
// We expect the metadata to be sent for all the parameters in the original
// sp_describe_parameter_encryption.
// For parameters that don't need encryption, the encryption type is set to plaintext.
MessageFormat form = new MessageFormat(
SQLServerException.getErrString("R_InvalidEncryptionKeyOrdinal"));
Object[] msgArgs = {cekOrdinal, cekEntry.getSize()};
SQLServerException.getErrString("R_MissingParamEncryptionMetadata"));
Object[] msgArgs = {userSQL};
throw new SQLServerException(this, form.format(msgArgs), null, 0, false);
}
SQLServerEncryptionType encType = SQLServerEncryptionType
.of((byte) rs.getInt(DescribeParameterEncryptionResultSet2.ColumnEncrytionType.value()));
if (SQLServerEncryptionType.PlainText != encType) {
params[paramIndex].cryptoMeta = new CryptoMetadata(cekEntry, (short) cekOrdinal,
(byte) rs.getInt(DescribeParameterEncryptionResultSet2.ColumnEncryptionAlgorithm.value()),
null, encType.value,
(byte) rs.getInt(DescribeParameterEncryptionResultSet2.NormalizationRuleVersion.value()));
// Decrypt the symmetric key.(This will also validate and throw if needed).
SQLServerSecurityUtility.decryptSymmetricKey(params[paramIndex].cryptoMeta, connection);
} else {
if (params[paramIndex].getForceEncryption()) {
MessageFormat form = new MessageFormat(
SQLServerException.getErrString("R_ForceEncryptionTrue_HonorAETrue_UnencryptedColumn"));
Object[] msgArgs = {userSQL, paramIndex + 1};
SQLServerException.makeFromDriverError(connection, this, form.format(msgArgs), null, true);
}
}
}
if (getStatementLogger().isLoggable(java.util.logging.Level.FINE)) {
getStatementLogger().fine("Parameter encryption metadata is set.");
}
} catch (SQLException e) {
if (e instanceof SQLServerException) {
Expand All @@ -972,22 +965,6 @@ private void getParameterEncryptionMetadata(Parameter[] params) throws SQLServer
}
}

if (paramCount != params.length) {
// Encryption metadata wasn't sent by the server.
// We expect the metadata to be sent for all the parameters in the original
// sp_describe_parameter_encryption.
// For parameters that don't need encryption, the encryption type is set to plaintext.
MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_MissingParamEncryptionMetadata"));
Object[] msgArgs = {userSQL};
throw new SQLServerException(this, form.format(msgArgs), null, 0, false);
}

// Null check for rs is done already.
rs.close();

if (null != stmt) {
stmt.close();
}
connection.resetCurrentCommand();
}

Expand Down
Loading

0 comments on commit 427911c

Please sign in to comment.