Skip to content

Commit

Permalink
[aws#36] Adding a proxy to delegate to appropriate IAMSaslClientProvi…
Browse files Browse the repository at this point in the history
…der based on ClassLoader
  • Loading branch information
dannycranmer committed Oct 10, 2022
1 parent 5e2cb95 commit 6efbfaa
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 21 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ bin
internals
build
lombok.config
out/
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package software.amazon.msk.auth.iam;

import software.amazon.msk.auth.iam.internals.ClassLoaderAwareIAMSaslClientProvider;
import software.amazon.msk.auth.iam.internals.IAMSaslClientProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -35,6 +36,7 @@ public class IAMLoginModule implements LoginModule {
private static final Logger log = LoggerFactory.getLogger(IAMLoginModule.class);

static {
ClassLoaderAwareIAMSaslClientProvider.initialize();
IAMSaslClientProvider.initialize();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslClientFactory;
import javax.security.sasl.SaslException;
Expand Down Expand Up @@ -203,28 +204,58 @@ private static boolean isChallengeEmpty(byte[] challenge) {
return true;
}

public static class ClassLoaderAwareIAMSaslClientFactory implements SaslClientFactory {

@Override
public SaslClient createSaslClient(String[] mechanisms,
String authorizationId,
String protocol,
String serverName,
Map<String, ?> props,
CallbackHandler cbh) throws SaslException {
String mechanismName = getMechanismNameForClassLoader(cbh.getClass().getClassLoader());

// Creat a client by delegating to the SaslClientFactory for the classloader of the CallbackHandler
return Sasl.createSaslClient(
new String[] { mechanismName },
authorizationId, protocol, serverName, props, cbh);
}

@Override
public String[] getMechanismNames(Map<String, ?> props) {
return new String[] { IAMLoginModule.MECHANISM };
}
}

public static class IAMSaslClientFactory implements SaslClientFactory {

@Override
public SaslClient createSaslClient(String[] mechanisms,
String authorizationId,
String protocol,
String serverName,
Map<String, ?> props,
CallbackHandler cbh) throws SaslException {
String mechanismName = getMechanismNameForClassLoader(getClass().getClassLoader());

for (String mechanism : mechanisms) {
if (IAMLoginModule.MECHANISM.equals(mechanism)) {
if (mechanismName.equals(mechanism)) {
return new IAMSaslClient(mechanism, cbh, serverName, new AWS4SignedPayloadGenerator());
}
}

throw new SaslException(
"Requested mechanisms " + Arrays.asList(mechanisms) + " not supported. The supported" +
"mechanism is " + IAMLoginModule.MECHANISM);
"Requested mechanisms " + Arrays.asList(mechanisms) + " not supported. " +
"The supported mechanism is " + mechanismName);
}

@Override
public String[] getMechanismNames(Map<String, ?> props) {
return new String[] { IAMLoginModule.MECHANISM };
return new String[] { getMechanismNameForClassLoader(getClass().getClassLoader()) };
}
}

public static String getMechanismNameForClassLoader(ClassLoader classLoader) {
return IAMLoginModule.MECHANISM + "." + classLoader.hashCode();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,23 @@
*/
package software.amazon.msk.auth.iam.internals;

import software.amazon.msk.auth.iam.IAMLoginModule;
import software.amazon.msk.auth.iam.internals.IAMSaslClient.IAMSaslClientFactory;

import java.security.Provider;
import java.security.Security;

import static software.amazon.msk.auth.iam.internals.IAMSaslClient.getMechanismNameForClassLoader;

public class IAMSaslClientProvider extends Provider {
/**
* Constructs a IAM Sasl Client provider with a fixed name, version number,
* and information.
*/
protected IAMSaslClientProvider() {
super("SASL/IAM Client Provider", 1.0, "SASL/IAM Client Provider for Kafka");
put("SaslClientFactory." + IAMLoginModule.MECHANISM, IAMSaslClient.IAMSaslClientFactory.class.getName());
super("SASL/IAM Client Provider (" +
IAMSaslClientProvider.class.getClassLoader().hashCode(), 1.0,
") SASL/IAM Client Provider for Kafka");
put("SaslClientFactory." + getMechanismNameForClassLoader(getClass().getClassLoader()), IAMSaslClientFactory.class.getName());
}

public static void initialize() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
package software.amazon.msk.auth.iam.internals;

import com.amazonaws.auth.BasicAWSCredentials;
import org.junit.jupiter.api.BeforeEach;
import software.amazon.msk.auth.iam.IAMClientCallbackHandler;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.commons.lang3.RandomStringUtils;
import org.apache.kafka.common.errors.IllegalSaslStateException;
import org.junit.jupiter.api.Test;
import software.amazon.msk.auth.iam.internals.IAMSaslClient.ClassLoaderAwareIAMSaslClientFactory;

import static java.util.Collections.emptyMap;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
Expand All @@ -47,14 +50,19 @@ public class IAMSaslClientTest {

private static final BasicAWSCredentials BASIC_AWS_CREDENTIALS = new BasicAWSCredentials(ACCESS_KEY_VALUE, SECRET_KEY_VALUE);

@BeforeEach
public void setUp() {
IAMSaslClientProvider.initialize();
}

@Test
public void testCompleteValidExchange() throws IOException, ParseException {
public void testCompleteValidExchange() throws IOException {
IAMSaslClient saslClient = getSuccessfulIAMClient(getIamClientCallbackHandler());
runValidExchangeForSaslClient(saslClient, ACCESS_KEY_VALUE, SECRET_KEY_VALUE);
}

private void runValidExchangeForSaslClient(IAMSaslClient saslClient, String accessKey, String secretKey) {
assertEquals(AWS_MSK_IAM, saslClient.getMechanismName());
assertEquals(getMechanismName(), saslClient.getMechanismName());
assertTrue(saslClient.hasInitialResponse());
SystemPropertyCredentialsUtils.runTestWithSystemPropertyCredentials(() -> {
try {
Expand Down Expand Up @@ -97,7 +105,7 @@ public void testMultipleSaslClients() throws IOException, ParseException {

private IAMClientCallbackHandler getIamClientCallbackHandler() {
IAMClientCallbackHandler cbh = new IAMClientCallbackHandler();
cbh.configure(Collections.EMPTY_MAP, AWS_MSK_IAM, Collections.emptyList());
cbh.configure(emptyMap(), AWS_MSK_IAM, Collections.emptyList());
return cbh;
}

Expand Down Expand Up @@ -127,11 +135,11 @@ public void testThrowingCallback() throws SaslException {
@Test
public void testInvalidServerResponse() throws SaslException {
SaslClient saslClient = getSuccessfulIAMClient(getIamClientCallbackHandler());
assertEquals(AWS_MSK_IAM, saslClient.getMechanismName());
assertEquals(getMechanismName(), saslClient.getMechanismName());
assertTrue(saslClient.hasInitialResponse());
SystemPropertyCredentialsUtils.runTestWithSystemPropertyCredentials(() -> {
try {
byte[] response = saslClient.evaluateChallenge(new byte[]{});
saslClient.evaluateChallenge(new byte[]{});
} catch (SaslException e) {
throw new RuntimeException("Test failed", e);
}
Expand All @@ -149,7 +157,7 @@ public void testInvalidResponseVersion() throws SaslException {
SaslClient saslClient = getSuccessfulIAMClient(getIamClientCallbackHandler());
SystemPropertyCredentialsUtils.runTestWithSystemPropertyCredentials(() -> {
try {
byte[] response = saslClient.evaluateChallenge(new byte[]{});
saslClient.evaluateChallenge(new byte[]{});
} catch (SaslException e) {
throw new RuntimeException("Test failed", e);
}
Expand All @@ -174,11 +182,11 @@ private byte[] getResponseWithInvalidVersion() {
@Test
public void testEmptyServerResponse() throws SaslException {
SaslClient saslClient = getSuccessfulIAMClient(getIamClientCallbackHandler());
assertEquals(AWS_MSK_IAM, saslClient.getMechanismName());
assertEquals(getMechanismName(), saslClient.getMechanismName());
assertTrue(saslClient.hasInitialResponse());
SystemPropertyCredentialsUtils.runTestWithSystemPropertyCredentials(() -> {
try {
byte[] response = saslClient.evaluateChallenge(new byte[]{});
saslClient.evaluateChallenge(new byte[]{});
} catch (SaslException e) {
throw new RuntimeException("Test failed", e);
}
Expand All @@ -191,19 +199,25 @@ public void testEmptyServerResponse() throws SaslException {

@Test
public void testFactoryMechanisms() {
assertArrayEquals(new String[]{AWS_MSK_IAM},
new IAMSaslClient.IAMSaslClientFactory().getMechanismNames(Collections.emptyMap()));
assertArrayEquals(new String[] { getMechanismName() },
new IAMSaslClient.IAMSaslClientFactory().getMechanismNames(emptyMap()));
}

@Test
public void testInvalidMechanism() {

assertThrows(SaslException.class, () -> new IAMSaslClient.IAMSaslClientFactory()
.createSaslClient(new String[]{AWS_MSK_IAM + "BAD"}, "AUTH_ID", "PROTOCOL", VALID_HOSTNAME,
Collections.emptyMap(),
emptyMap(),
new SuccessfulIAMCallbackHandler(BASIC_AWS_CREDENTIALS)));
}

@Test
public void testClassLoaderAwareIAMSaslClientFactoryMechanisms() {
assertArrayEquals(new String[] { AWS_MSK_IAM },
new ClassLoaderAwareIAMSaslClientFactory().getMechanismNames(emptyMap()));
}

private static class SuccessfulIAMCallbackHandler extends IAMClientCallbackHandler {
private final BasicAWSCredentials basicAWSCredentials;

Expand Down Expand Up @@ -240,10 +254,14 @@ protected void handleCallback(AWSCredentialsCallback callback) throws IOExceptio
}

private IAMSaslClient getIAMClient(Supplier<IAMClientCallbackHandler> handlerSupplier) throws SaslException {
return (IAMSaslClient )new IAMSaslClient.IAMSaslClientFactory()
.createSaslClient(new String[]{AWS_MSK_IAM}, "AUTH_ID", "PROTOCOL", VALID_HOSTNAME,
Collections.emptyMap(),
return (IAMSaslClient) new IAMSaslClient.ClassLoaderAwareIAMSaslClientFactory()
.createSaslClient(new String[] { AWS_MSK_IAM }, "AUTH_ID", "PROTOCOL", VALID_HOSTNAME,
emptyMap(),
handlerSupplier.get());
}

private String getMechanismName() {
return AWS_MSK_IAM + "." + getClass().getClassLoader().hashCode();
}

}

0 comments on commit 6efbfaa

Please sign in to comment.