Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates the S3 sink to use the AWS Plugin for loading AWS credentials #2787

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions data-prepper-plugins/s3-sink/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
dependencies {
implementation project(':data-prepper-api')
implementation project(path: ':data-prepper-plugins:common')
implementation project(':data-prepper-plugins:aws-plugin-api')
implementation 'io.micrometer:micrometer-core'
implementation 'com.fasterxml.jackson.core:jackson-core'
implementation 'com.fasterxml.jackson.core:jackson-databind'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,6 @@ public void setUp() {
when(thresholdOptions.getMaximumSize()).thenReturn(ByteCount.parse("2mb"));
when(thresholdOptions.getEventCollectTimeOut()).thenReturn(Duration.parse("PT3M"));
when(s3SinkConfig.getThresholdOptions()).thenReturn(thresholdOptions);
when(s3SinkConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions);
when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.of(s3region));

lenient().when(pluginMetrics.counter(S3SinkService.OBJECTS_SUCCEEDED)).thenReturn(snapshotSuccessCounter);
lenient().when(pluginMetrics.counter(S3SinkService.OBJECTS_FAILED)).thenReturn(snapshotFailedCounter);
Expand Down Expand Up @@ -136,7 +134,7 @@ void verify_flushed_records_into_s3_bucket() {
}

private S3SinkService createObjectUnderTest() {
return new S3SinkService(s3SinkConfig, bufferFactory, codec, pluginMetrics);
return new S3SinkService(s3SinkConfig, bufferFactory, codec, s3Client, pluginMetrics);
}

private int gets3ObjectCount() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.sink;

import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.plugins.sink.configuration.AwsAuthenticationOptions;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.services.s3.S3Client;

public final class ClientFactory {
private ClientFactory() { }

static S3Client createS3Client(final S3SinkConfig s3SinkConfig, final AwsCredentialsSupplier awsCredentialsSupplier) {
final AwsCredentialsOptions awsCredentialsOptions = convertToCredentialsOptions(s3SinkConfig.getAwsAuthenticationOptions());
final AwsCredentialsProvider awsCredentialsProvider = awsCredentialsSupplier.getProvider(awsCredentialsOptions);

return S3Client.builder()
.region(s3SinkConfig.getAwsAuthenticationOptions().getAwsRegion())
.credentialsProvider(awsCredentialsProvider)
.overrideConfiguration(createOverrideConfiguration(s3SinkConfig)).build();
}

private static ClientOverrideConfiguration createOverrideConfiguration(final S3SinkConfig s3SinkConfig) {
final RetryPolicy retryPolicy = RetryPolicy.builder().numRetries(s3SinkConfig.getMaxConnectionRetries()).build();
return ClientOverrideConfiguration.builder()
.retryPolicy(retryPolicy)
.build();
}

private static AwsCredentialsOptions convertToCredentialsOptions(final AwsAuthenticationOptions awsAuthenticationOptions) {
return AwsCredentialsOptions.builder()
.withRegion(awsAuthenticationOptions.getAwsRegion())
.withStsRoleArn(awsAuthenticationOptions.getAwsStsRoleArn())
.withStsHeaderOverrides(awsAuthenticationOptions.getAwsStsHeaderOverrides())
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.dataprepper.plugins.sink;

import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin;
import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor;
import org.opensearch.dataprepper.model.configuration.PluginModel;
Expand All @@ -22,6 +23,8 @@
import org.opensearch.dataprepper.plugins.sink.codec.Codec;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.s3.S3Client;

import java.util.Collection;

/**
Expand All @@ -35,7 +38,7 @@ public class S3Sink extends AbstractSink<Record<Event>> {
private final S3SinkConfig s3SinkConfig;
private final Codec codec;
private volatile boolean sinkInitialized;
private S3SinkService s3SinkService;
private final S3SinkService s3SinkService;
private final BufferFactory bufferFactory;

/**
Expand All @@ -44,8 +47,10 @@ public class S3Sink extends AbstractSink<Record<Event>> {
* @param pluginFactory dp plugin factory.
*/
@DataPrepperPluginConstructor
public S3Sink(final PluginSetting pluginSetting, final S3SinkConfig s3SinkConfig,
final PluginFactory pluginFactory) {
public S3Sink(final PluginSetting pluginSetting,
final S3SinkConfig s3SinkConfig,
final PluginFactory pluginFactory,
final AwsCredentialsSupplier awsCredentialsSupplier) {
super(pluginSetting);
this.s3SinkConfig = s3SinkConfig;
final PluginModel codecConfiguration = s3SinkConfig.getCodec();
Expand All @@ -59,6 +64,8 @@ public S3Sink(final PluginSetting pluginSetting, final S3SinkConfig s3SinkConfig
} else {
bufferFactory = new InMemoryBufferFactory();
}
final S3Client s3Client = ClientFactory.createS3Client(s3SinkConfig, awsCredentialsSupplier);
s3SinkService = new S3SinkService(s3SinkConfig, bufferFactory, codec, s3Client, pluginMetrics);
}

@Override
Expand All @@ -85,7 +92,6 @@ public void doInitialize() {
* Initialize {@link S3SinkService}
*/
private void doInitializeInternal() {
s3SinkService = new S3SinkService(s3SinkConfig, bufferFactory, codec, pluginMetrics);
sinkInitialized = Boolean.TRUE;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.awscore.exception.AwsServiceException;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.services.s3.S3Client;

import java.io.IOException;
Expand All @@ -47,6 +45,7 @@ public class S3SinkService {
private final BufferFactory bufferFactory;
private final Collection<EventHandle> bufferedEventHandles;
private final Codec codec;
private final S3Client s3Client;
private Buffer currentBuffer;
private final int maxEvents;
private final ByteCount maxBytes;
Expand All @@ -61,15 +60,17 @@ public class S3SinkService {

/**
* @param s3SinkConfig s3 sink related configuration.
* @param bufferFactory factory of buffer.
* @param bufferFactory factory of buffer.
* @param codec parser.
* @param s3Client
* @param pluginMetrics metrics.
*/
public S3SinkService(final S3SinkConfig s3SinkConfig, final BufferFactory bufferFactory,
final Codec codec, final PluginMetrics pluginMetrics) {
final Codec codec, final S3Client s3Client, final PluginMetrics pluginMetrics) {
this.s3SinkConfig = s3SinkConfig;
this.bufferFactory = bufferFactory;
this.codec = codec;
this.s3Client = s3Client;
reentrantLock = new ReentrantLock();

bufferedEventHandles = new LinkedList<>();
Expand Down Expand Up @@ -154,7 +155,7 @@ protected boolean retryFlushToS3(final Buffer currentBuffer, final String s3Key)
int retryCount = maxRetries;
do {
try {
currentBuffer.flushToS3(createS3Client(), bucket, s3Key);
currentBuffer.flushToS3(s3Client, bucket, s3Key);
isUploadedToS3 = Boolean.TRUE;
} catch (AwsServiceException | SdkClientException e) {
LOG.error("Exception occurred while uploading records to s3 bucket. Retry countdown : {} | exception:",
Expand All @@ -179,15 +180,4 @@ protected String generateKey() {
final String namePattern = ObjectKey.objectFileName(s3SinkConfig);
return (!pathPrefix.isEmpty()) ? pathPrefix + namePattern : namePattern;
}

/**
* create s3 client instance.
* @return {@link S3Client}
*/
public S3Client createS3Client() {
return S3Client.builder().region(s3SinkConfig.getAwsAuthenticationOptions().getAwsRegion())
.credentialsProvider(s3SinkConfig.getAwsAuthenticationOptions().authenticateAwsConfiguration())
.overrideConfiguration(ClientOverrideConfiguration.builder().retryPolicy(RetryPolicy.builder()
.numRetries(s3SinkConfig.getMaxConnectionRetries()).build()).build()).build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,11 @@

import com.fasterxml.jackson.annotation.JsonProperty;
import jakarta.validation.constraints.Size;
import software.amazon.awssdk.arns.Arn;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider;
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest;

import java.util.Map;
import java.util.Optional;
import java.util.UUID;

public class AwsAuthenticationOptions {
private static final String AWS_IAM_ROLE = "role";
private static final String AWS_IAM = "iam";

@JsonProperty("region")
@Size(min = 1, message = "Region cannot be empty string")
private String awsRegion;
Expand All @@ -35,58 +24,15 @@ public class AwsAuthenticationOptions {
@Size(max = 5, message = "sts_header_overrides supports a maximum of 5 headers to override")
private Map<String, String> awsStsHeaderOverrides;

private void validateStsRoleArn() {
final Arn arn = getArn();
if (!AWS_IAM.equals(arn.service())) {
throw new IllegalArgumentException("sts_role_arn must be an IAM Role");
}
final Optional<String> resourceType = arn.resource().resourceType();
if (resourceType.isEmpty() || !resourceType.get().equals(AWS_IAM_ROLE)) {
throw new IllegalArgumentException("sts_role_arn must be an IAM Role");
}
}

private Arn getArn() {
try {
return Arn.fromString(awsStsRoleArn);
} catch (final Exception e) {
throw new IllegalArgumentException(String.format("Invalid ARN format for awsStsRoleArn. Check the format of %s", awsStsRoleArn));
}
}

public Region getAwsRegion() {
return awsRegion != null ? Region.of(awsRegion) : null;
}

public AwsCredentialsProvider authenticateAwsConfiguration() {

final AwsCredentialsProvider awsCredentialsProvider;
if (awsStsRoleArn != null && !awsStsRoleArn.isEmpty()) {

validateStsRoleArn();

final StsClient stsClient = StsClient.builder()
.region(getAwsRegion())
.build();

AssumeRoleRequest.Builder assumeRoleRequestBuilder = AssumeRoleRequest.builder()
.roleSessionName("S3-Sink-" + UUID.randomUUID())
.roleArn(awsStsRoleArn);
if(awsStsHeaderOverrides != null && !awsStsHeaderOverrides.isEmpty()) {
assumeRoleRequestBuilder = assumeRoleRequestBuilder
.overrideConfiguration(configuration -> awsStsHeaderOverrides.forEach(configuration::putHeader));
}

awsCredentialsProvider = StsAssumeRoleCredentialsProvider.builder()
.stsClient(stsClient)
.refreshRequest(assumeRoleRequestBuilder.build())
.build();

} else {
// use default credential provider
awsCredentialsProvider = DefaultCredentialsProvider.create();
}
public String getAwsStsRoleArn() {
return awsStsRoleArn;
}

return awsCredentialsProvider;
public Map<String, String> getAwsStsHeaderOverrides() {
return awsStsHeaderOverrides;
Comment on lines +31 to +36
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are no unit tests in AwsAuthenticationOptionsTest.java file for these new methods

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cmanning09 , I just pushed a change which adds testing for these getters.

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.sink;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockedStatic;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.plugins.sink.configuration.AwsAuthenticationOptions;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.S3ClientBuilder;

import java.util.Map;
import java.util.UUID;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
class ClientFactoryTest {
@Mock
private S3SinkConfig s3SinkConfig;
@Mock
private AwsCredentialsSupplier awsCredentialsSupplier;

@Mock
private AwsAuthenticationOptions awsAuthenticationOptions;

@BeforeEach
void setUp() {
when(s3SinkConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions);
}

@Test
void createS3Client_with_real_S3Client() {
when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.US_EAST_1);
final S3Client s3Client = ClientFactory.createS3Client(s3SinkConfig, awsCredentialsSupplier);

assertThat(s3Client, notNullValue());
}

@ParameterizedTest
@ValueSource(strings = {"us-east-1", "us-west-2", "eu-central-1"})
void createS3Client_provides_correct_inputs(final String regionString) {
final Region region = Region.of(regionString);
final String stsRoleArn = UUID.randomUUID().toString();
final Map<String, String> stsHeaderOverrides = Map.of(UUID.randomUUID().toString(), UUID.randomUUID().toString());
when(awsAuthenticationOptions.getAwsRegion()).thenReturn(region);
when(awsAuthenticationOptions.getAwsStsRoleArn()).thenReturn(stsRoleArn);
when(awsAuthenticationOptions.getAwsStsHeaderOverrides()).thenReturn(stsHeaderOverrides);

final AwsCredentialsProvider expectedCredentialsProvider = mock(AwsCredentialsProvider.class);
when(awsCredentialsSupplier.getProvider(any())).thenReturn(expectedCredentialsProvider);

final S3ClientBuilder s3ClientBuilder = mock(S3ClientBuilder.class);
when(s3ClientBuilder.region(region)).thenReturn(s3ClientBuilder);
when(s3ClientBuilder.credentialsProvider(any())).thenReturn(s3ClientBuilder);
when(s3ClientBuilder.overrideConfiguration(any(ClientOverrideConfiguration.class))).thenReturn(s3ClientBuilder);
try(final MockedStatic<S3Client> s3ClientMockedStatic = mockStatic(S3Client.class)) {
s3ClientMockedStatic.when(S3Client::builder)
.thenReturn(s3ClientBuilder);
ClientFactory.createS3Client(s3SinkConfig, awsCredentialsSupplier);
}

final ArgumentCaptor<AwsCredentialsProvider> credentialsProviderArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsProvider.class);
verify(s3ClientBuilder).credentialsProvider(credentialsProviderArgumentCaptor.capture());

final AwsCredentialsProvider actualCredentialsProvider = credentialsProviderArgumentCaptor.getValue();

assertThat(actualCredentialsProvider, equalTo(expectedCredentialsProvider));

final ArgumentCaptor<AwsCredentialsOptions> optionsArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsOptions.class);
verify(awsCredentialsSupplier).getProvider(optionsArgumentCaptor.capture());

final AwsCredentialsOptions actualCredentialsOptions = optionsArgumentCaptor.getValue();
assertThat(actualCredentialsOptions.getRegion(), equalTo(region));
assertThat(actualCredentialsOptions.getStsRoleArn(), equalTo(stsRoleArn));
assertThat(actualCredentialsOptions.getStsHeaderOverrides(), equalTo(stsHeaderOverrides));
}
}
Loading
Loading