Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport] Added credential caching for Managed Identity Credential and Default Azure Credential (2415) #2426

Merged
merged 1 commit into from
May 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 104 additions & 16 deletions src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,22 @@
package com.microsoft.sqlserver.jdbc;

import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.text.MessageFormat;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Optional;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;

import com.azure.core.credential.AccessToken;
import com.azure.core.credential.TokenCredential;
import com.azure.core.credential.TokenRequestContext;
import com.azure.identity.ManagedIdentityCredential;
import com.azure.identity.ManagedIdentityCredentialBuilder;
Expand Down Expand Up @@ -46,6 +51,11 @@ class SQLServerSecurityUtility {
// Environment variable for additionally allowed tenants. The tenantIds are comma delimited
private static final String ADDITIONALLY_ALLOWED_TENANTS = "ADDITIONALLY_ALLOWED_TENANTS";

// Credential Cache for ManagedIdentityCredential and DefaultAzureCredential
private static final HashMap<String, Credential> CREDENTIAL_CACHE = new HashMap<>();

private static final Lock CREDENTIAL_LOCK = new ReentrantLock();

private SQLServerSecurityUtility() {
throw new UnsupportedOperationException(SQLServerException.getErrString("R_notSupported"));
}
Expand Down Expand Up @@ -331,16 +341,35 @@ static void verifyColumnMasterKeyMetadata(SQLServerConnection connection, SQLSer
*/
static SqlAuthenticationToken getManagedIdentityCredAuthToken(String resource,
String managedIdentityClientId) throws SQLServerException {
ManagedIdentityCredential mic = null;

if (logger.isLoggable(java.util.logging.Level.FINEST)) {
logger.finest("Getting Managed Identity authentication token for: " + managedIdentityClientId);
}

if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
mic = new ManagedIdentityCredentialBuilder().clientId(managedIdentityClientId).build();
} else {
mic = new ManagedIdentityCredentialBuilder().build();
String key = getHashedSecret(
new String[] {managedIdentityClientId, ManagedIdentityCredential.class.getSimpleName()});
ManagedIdentityCredential mic = (ManagedIdentityCredential) getCredentialFromCache(key);

if (null == mic) {
CREDENTIAL_LOCK.lock();

try {
mic = (ManagedIdentityCredential) getCredentialFromCache(key);
if (null == mic) {
ManagedIdentityCredentialBuilder micBuilder = new ManagedIdentityCredentialBuilder();

if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
mic = micBuilder.clientId(managedIdentityClientId).build();
} else {
mic = micBuilder.build();
}

Credential credential = new Credential(mic);
CREDENTIAL_CACHE.put(key, credential);
}
} finally {
CREDENTIAL_LOCK.unlock();
}
}

TokenRequestContext tokenRequestContext = new TokenRequestContext();
Expand Down Expand Up @@ -383,22 +412,49 @@ static SqlAuthenticationToken getDefaultAzureCredAuthToken(String resource,
String intellijKeepassPath = System.getenv(INTELLIJ_KEEPASS_PASS);
String[] additionallyAllowedTenants = getAdditonallyAllowedTenants();

DefaultAzureCredentialBuilder dacBuilder = new DefaultAzureCredentialBuilder();
DefaultAzureCredential dac = null;
int secretsLength = null == additionallyAllowedTenants ? 3 : additionallyAllowedTenants.length + 3;
String[] secrets = new String[secretsLength];

if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
dacBuilder.managedIdentityClientId(managedIdentityClientId);
if (null != additionallyAllowedTenants && additionallyAllowedTenants.length != 0) {
System.arraycopy(additionallyAllowedTenants, 0, secrets, 3, additionallyAllowedTenants.length);
}

if (null != intellijKeepassPath && !intellijKeepassPath.isEmpty()) {
dacBuilder.intelliJKeePassDatabasePath(intellijKeepassPath);
}
secrets[0] = DefaultAzureCredential.class.getSimpleName();
secrets[1] = managedIdentityClientId;
secrets[2] = intellijKeepassPath;

if (null != additionallyAllowedTenants && additionallyAllowedTenants.length != 0) {
dacBuilder.additionallyAllowedTenants(additionallyAllowedTenants);
}
String key = getHashedSecret(secrets);
DefaultAzureCredential dac = (DefaultAzureCredential) getCredentialFromCache(key);

if (null == dac) {
CREDENTIAL_LOCK.lock();

try {
dac = (DefaultAzureCredential) getCredentialFromCache(key);
if (null == dac) {
DefaultAzureCredentialBuilder dacBuilder = new DefaultAzureCredentialBuilder();

if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
dacBuilder.managedIdentityClientId(managedIdentityClientId);
}

if (null != intellijKeepassPath && !intellijKeepassPath.isEmpty()) {
dacBuilder.intelliJKeePassDatabasePath(intellijKeepassPath);
}

if (null != additionallyAllowedTenants && additionallyAllowedTenants.length != 0) {
dacBuilder.additionallyAllowedTenants(additionallyAllowedTenants);
}

dac = dacBuilder.build();

dac = dacBuilder.build();
Credential credential = new Credential(dac);
CREDENTIAL_CACHE.put(key, credential);
}
} finally {
CREDENTIAL_LOCK.unlock();
}
}

TokenRequestContext tokenRequestContext = new TokenRequestContext();
String scope = resource.endsWith(SQLServerMSAL4JUtils.SLASH_DEFAULT) ? resource : resource
Expand Down Expand Up @@ -430,4 +486,36 @@ private static String[] getAdditonallyAllowedTenants() {

return null;
}

private static TokenCredential getCredentialFromCache(String key) {
Credential credential = CREDENTIAL_CACHE.get(key);

if (null != credential) {
return credential.tokenCredential;
}

return null;
}

private static class Credential {
TokenCredential tokenCredential;

public Credential(TokenCredential tokenCredential) {
this.tokenCredential = tokenCredential;
}
}

private static String getHashedSecret(String[] secrets) throws SQLServerException {
try {
MessageDigest md = MessageDigest.getInstance("SHA-256");
for (String secret : secrets) {
if (null != secret) {
md.update(secret.getBytes(java.nio.charset.StandardCharsets.UTF_16LE));
}
}
return new String(md.digest());
} catch (NoSuchAlgorithmException e) {
throw new SQLServerException(SQLServerException.getErrString("R_NoSHA256Algorithm"), e);
}
}
}