Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds sigv4 support to Elasticsearch client #3305

Merged
merged 4 commits into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion data-prepper-plugins/aws-plugin-api/build.gradle
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

dependencies {
implementation 'software.amazon.awssdk:auth'
implementation 'software.amazon.awssdk:apache-client'
}

test {
Expand All @@ -12,7 +13,7 @@ jacocoTestCoverageVerification {
violationRules {
rule {
limit {
minimum = 1.0
minimum = 0.99
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,8 @@
/*
* Copyright OpenSearch Contributors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with
* the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
* CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.dataprepper.plugins.sink.opensearch;
package org.opensearch.dataprepper.aws.api;

import org.apache.http.Header;
import org.apache.http.HttpEntityEnclosingRequest;
Expand Down Expand Up @@ -48,7 +40,7 @@
* An {@link HttpRequestInterceptor} that signs requests using any AWS {@link Signer}
* and {@link AwsCredentialsProvider}.
*/
final class AwsRequestSigningApacheInterceptor implements HttpRequestInterceptor {
public final class AwsRequestSigningApache4Interceptor implements HttpRequestInterceptor {

/**
* Constant to check content-length
Expand Down Expand Up @@ -90,10 +82,10 @@ final class AwsRequestSigningApacheInterceptor implements HttpRequestInterceptor
* @param awsCredentialsProvider source of AWS credentials for signing
* @param region signing region
*/
public AwsRequestSigningApacheInterceptor(final String service,
final Signer signer,
final AwsCredentialsProvider awsCredentialsProvider,
final Region region) {
public AwsRequestSigningApache4Interceptor(final String service,
final Signer signer,
final AwsCredentialsProvider awsCredentialsProvider,
final Region region) {
this.service = Objects.requireNonNull(service);
this.signer = Objects.requireNonNull(signer);
this.awsCredentialsProvider = Objects.requireNonNull(awsCredentialsProvider);
Expand All @@ -107,10 +99,10 @@ public AwsRequestSigningApacheInterceptor(final String service,
* @param awsCredentialsProvider source of AWS credentials for signing
* @param region signing region
*/
public AwsRequestSigningApacheInterceptor(final String service,
final Signer signer,
final AwsCredentialsProvider awsCredentialsProvider,
final String region) {
public AwsRequestSigningApache4Interceptor(final String service,
final Signer signer,
final AwsCredentialsProvider awsCredentialsProvider,
final String region) {
this(service, signer, awsCredentialsProvider, Region.of(region));
}

Expand Down Expand Up @@ -177,7 +169,7 @@ private URI buildUri(final HttpContext context, URIBuilder uriBuilder) throws IO
}

return uriBuilder.build();
} catch (URISyntaxException e) {
} catch (final Exception e) {
throw new IOException("Invalid URI", e);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.aws.api;

import org.apache.http.HttpEntity;
import org.apache.http.HttpEntityEnclosingRequest;
import org.apache.http.HttpHost;
import org.apache.http.RequestLine;
import org.apache.http.message.BasicHeader;
import org.apache.http.protocol.HttpContext;
import org.apache.http.protocol.HttpCoreContext;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.core.signer.Signer;
import software.amazon.awssdk.http.ContentStreamProvider;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.regions.Region;

import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
public class AwsRequestSigningApache4InterceptorTest {

@Mock
private Signer signer;

@Mock
private AwsCredentialsProvider awsCredentialsProvider;

@Mock
private HttpEntityEnclosingRequest httpRequest;

@Mock
private HttpContext httpContext;

private AwsRequestSigningApache4Interceptor createObjectUnderTest() {
return new AwsRequestSigningApache4Interceptor("es", signer, awsCredentialsProvider, Region.US_EAST_1);
}

@Test
void invalidURI_throws_IOException() {

final RequestLine requestLine = mock(RequestLine.class);
when(requestLine.getUri()).thenReturn("http://invalid-uri.com/file[/].html\n");

when(httpRequest.getRequestLine()).thenReturn(requestLine);

final AwsRequestSigningApache4Interceptor objectUnderTest = new AwsRequestSigningApache4Interceptor("es", signer, awsCredentialsProvider, "us-east-1");

assertThrows(IOException.class, () -> objectUnderTest.process(httpRequest, httpContext));
}

@Test
void IOException_is_thrown_when_buildURI_throws_exception() {
final RequestLine requestLine = mock(RequestLine.class);
when(requestLine.getMethod()).thenReturn("GET");
when(requestLine.getUri()).thenReturn("http://localhost?param=test");
when(httpRequest.getRequestLine()).thenReturn(requestLine);

when(httpContext.getAttribute(HttpCoreContext.HTTP_TARGET_HOST)).thenThrow(RuntimeException.class);

final AwsRequestSigningApache4Interceptor objectUnderTest = createObjectUnderTest();

assertThrows(IOException.class, () -> objectUnderTest.process(httpRequest, httpContext));
}

@Test
void empty_contentStreamProvider_throws_IllegalStateException() throws IOException {
final RequestLine requestLine = mock(RequestLine.class);
when(requestLine.getMethod()).thenReturn("GET");
when(requestLine.getUri()).thenReturn("http://localhost?param=test");
when(httpRequest.getRequestLine()).thenReturn(requestLine);
when(httpRequest.getAllHeaders()).thenReturn(new BasicHeader[]{
new BasicHeader("test-name", "test-value"),
new BasicHeader("content-length", "0")
});

final HttpEntity httpEntity = mock(HttpEntity.class);
final InputStream inputStream = mock(InputStream.class);
when(httpEntity.getContent()).thenReturn(inputStream);

when((httpRequest).getEntity()).thenReturn(httpEntity);

final HttpHost httpHost = HttpHost.create("http://localhost?param=test");
when(httpContext.getAttribute(HttpCoreContext.HTTP_TARGET_HOST)).thenReturn(httpHost);

final SdkHttpFullRequest signedRequest = mock(SdkHttpFullRequest.class);
when(signedRequest.headers()).thenReturn(Map.of("test-name", List.of("test-value")));
when(signedRequest.contentStreamProvider()).thenReturn(Optional.empty());
when(signer.sign(any(SdkHttpFullRequest.class), any(ExecutionAttributes.class)))
.thenReturn(signedRequest);

final AwsRequestSigningApache4Interceptor objectUnderTest = createObjectUnderTest();

assertThrows(IllegalStateException.class, () -> objectUnderTest.process(httpRequest, httpContext));
}

@Test
void testHappyPath() throws IOException {
final RequestLine requestLine = mock(RequestLine.class);
when(requestLine.getMethod()).thenReturn("GET");
when(requestLine.getUri()).thenReturn("http://localhost?param=test");
when(httpRequest.getRequestLine()).thenReturn(requestLine);
when(httpRequest.getAllHeaders()).thenReturn(new BasicHeader[]{
new BasicHeader("test-name", "test-value"),
new BasicHeader("content-length", "0")
});

final HttpEntity httpEntity = mock(HttpEntity.class);
final InputStream inputStream = mock(InputStream.class);
when(httpEntity.getContent()).thenReturn(inputStream);

when((httpRequest).getEntity()).thenReturn(httpEntity);

final HttpHost httpHost = HttpHost.create("http://localhost?param=test");
when(httpContext.getAttribute(HttpCoreContext.HTTP_TARGET_HOST)).thenReturn(httpHost);

final SdkHttpFullRequest signedRequest = mock(SdkHttpFullRequest.class);
when(signedRequest.headers()).thenReturn(Map.of("test-name", List.of("test-value")));
final ContentStreamProvider contentStreamProvider = mock(ContentStreamProvider.class);
final InputStream contentInputStream = mock(InputStream.class);
when(contentStreamProvider.newStream()).thenReturn(contentInputStream);
when(signedRequest.contentStreamProvider()).thenReturn(Optional.of(contentStreamProvider));
when(signer.sign(any(SdkHttpFullRequest.class), any(ExecutionAttributes.class)))
.thenReturn(signedRequest);
createObjectUnderTest().process(httpRequest, httpContext);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ public void run() {
indexPartition.get(), sourceCoordinator);

openSearchSourcePluginMetrics.getIndicesProcessedCounter().increment();
LOG.info("Completed processing for index: '{}'", indexPartition.get().getPartitionKey());
} catch (final PartitionUpdateException | PartitionNotFoundException | PartitionNotOwnedException e) {
LOG.warn("The search_after worker received an exception from the source coordinator. There is a potential for duplicate data for index {}, giving up partition and getting next partition: {}", indexPartition.get().getPartitionKey(), e.getMessage());
sourceCoordinator.giveUpPartitions();
Expand Down Expand Up @@ -125,6 +126,8 @@ public void run() {
private void processIndex(final SourcePartition<OpenSearchIndexProgressState> openSearchIndexPartition,
final AcknowledgementSet acknowledgementSet) {
final String indexName = openSearchIndexPartition.getPartitionKey();
LOG.info("Started processing for index: '{}'", indexName);

Optional<OpenSearchIndexProgressState> openSearchIndexProgressStateOptional = openSearchIndexPartition.getPartitionState();

if (openSearchIndexProgressStateOptional.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ public void run() {
indexPartition.get(), sourceCoordinator);

openSearchSourcePluginMetrics.getIndicesProcessedCounter().increment();
LOG.info("Completed processing for index: '{}'", indexPartition.get().getPartitionKey());
} catch (final PartitionUpdateException | PartitionNotFoundException | PartitionNotOwnedException e) {
LOG.warn("PitWorker received an exception from the source coordinator. There is a potential for duplicate data for index {}, giving up partition and getting next partition: {}", indexPartition.get().getPartitionKey(), e.getMessage());
sourceCoordinator.giveUpPartitions();
Expand Down Expand Up @@ -149,6 +150,8 @@ public void run() {
private void processIndex(final SourcePartition<OpenSearchIndexProgressState> openSearchIndexPartition,
final AcknowledgementSet acknowledgementSet) {
final String indexName = openSearchIndexPartition.getPartitionKey();

LOG.info("Starting processing for index: '{}'", indexName);
Optional<OpenSearchIndexProgressState> openSearchIndexProgressStateOptional = openSearchIndexPartition.getPartitionState();

if (openSearchIndexProgressStateOptional.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ public void run() {
indexPartition.get(), sourceCoordinator);

openSearchSourcePluginMetrics.getIndicesProcessedCounter().increment();
LOG.info("Completed processing for index: '{}'", indexPartition.get().getPartitionKey());
} catch (final PartitionUpdateException | PartitionNotFoundException | PartitionNotOwnedException e) {
LOG.warn("ScrollWorker received an exception from the source coordinator. There is a potential for duplicate data for index {}, giving up partition and getting next partition: {}", indexPartition.get().getPartitionKey(), e.getMessage());
sourceCoordinator.giveUpPartitions();
Expand Down Expand Up @@ -142,6 +143,7 @@ public void run() {
private void processIndex(final SourcePartition<OpenSearchIndexProgressState> openSearchIndexPartition,
final AcknowledgementSet acknowledgementSet) {
final String indexName = openSearchIndexPartition.getPartitionKey();
LOG.info("Started processing for index: '{}'", indexName);

final Integer batchSize = openSearchSourceConfiguration.getSearchConfiguration().getBatchSize();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import co.elastic.clients.transport.ElasticsearchTransport;
import org.apache.http.Header;
import org.apache.http.HttpHost;
import org.apache.http.HttpRequestInterceptor;
import org.apache.http.HttpResponseInterceptor;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.UsernamePasswordCredentials;
Expand All @@ -31,11 +32,13 @@
import org.opensearch.client.transport.rest_client.RestClientTransport;
import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.aws.api.AwsRequestSigningApache4Interceptor;
import org.opensearch.dataprepper.plugins.source.opensearch.OpenSearchSourceConfiguration;
import org.opensearch.dataprepper.plugins.source.opensearch.configuration.ConnectionConfiguration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.signer.Aws4Signer;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.apache.ApacheHttpClient;

Expand Down Expand Up @@ -165,12 +168,38 @@ private org.elasticsearch.client.RestClient createElasticSearchRestClient(final
new BasicHeader("Content-type", "application/json")
});

attachBasicAuth(restClientBuilder, openSearchSourceConfiguration);
if (Objects.nonNull(openSearchSourceConfiguration.getAwsAuthenticationOptions())) {
attachSigV4ForElasticsearchClient(restClientBuilder, openSearchSourceConfiguration);
} else {
attachBasicAuth(restClientBuilder, openSearchSourceConfiguration);
}
setConnectAndSocketTimeout(restClientBuilder, openSearchSourceConfiguration);

return restClientBuilder.build();
}

private void attachSigV4ForElasticsearchClient(final org.elasticsearch.client.RestClientBuilder restClientBuilder,
final OpenSearchSourceConfiguration openSearchSourceConfiguration) {
final AwsCredentialsProvider awsCredentialsProvider = awsCredentialsSupplier.getProvider(AwsCredentialsOptions.builder()
.withRegion(openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsRegion())
.withStsRoleArn(openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsStsRoleArn())
.withStsExternalId(openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsStsExternalId())
.withStsHeaderOverrides(openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsStsHeaderOverrides())
.build());
final Aws4Signer aws4Signer = Aws4Signer.create();
final HttpRequestInterceptor httpRequestInterceptor = new AwsRequestSigningApache4Interceptor(AOS_SERVICE_NAME, aws4Signer,
awsCredentialsProvider, openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsRegion());
restClientBuilder.setHttpClientConfigCallback(httpClientBuilder -> {
httpClientBuilder.addInterceptorLast(httpRequestInterceptor);
attachSSLContext(httpClientBuilder, openSearchSourceConfiguration);
httpClientBuilder.addInterceptorLast(
(HttpResponseInterceptor)
(response, context) ->
response.addHeader("X-Elastic-Product", "Elasticsearch"));
return httpClientBuilder;
});
}

private void attachBasicAuth(final RestClientBuilder restClientBuilder, final OpenSearchSourceConfiguration openSearchSourceConfiguration) {

restClientBuilder.setHttpClientConfigCallback(httpClientBuilder -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,33 @@ void provideElasticSearchClient_with_username_and_password() {
verifyNoInteractions(awsCredentialsSupplier);
}

@Test
void provideElasticSearchClient_with_aws_auth() {
when(connectionConfiguration.getCertPath()).thenReturn(null);
when(connectionConfiguration.getSocketTimeout()).thenReturn(null);
when(connectionConfiguration.getConnectTimeout()).thenReturn(null);

final AwsAuthenticationConfiguration awsAuthenticationConfiguration = mock(AwsAuthenticationConfiguration.class);
when(awsAuthenticationConfiguration.getAwsRegion()).thenReturn(Region.US_EAST_1);
final String stsRoleArn = "arn:aws:iam::123456789012:role/my-role";
when(awsAuthenticationConfiguration.getAwsStsRoleArn()).thenReturn(stsRoleArn);
when(awsAuthenticationConfiguration.getAwsStsHeaderOverrides()).thenReturn(Collections.emptyMap());
when(openSearchSourceConfiguration.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationConfiguration);

final ArgumentCaptor<AwsCredentialsOptions> awsCredentialsOptionsArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsOptions.class);
final AwsCredentialsProvider awsCredentialsProvider = mock(AwsCredentialsProvider.class);
when(awsCredentialsSupplier.getProvider(awsCredentialsOptionsArgumentCaptor.capture())).thenReturn(awsCredentialsProvider);

final ElasticsearchClient elasticsearchClient = createObjectUnderTest().provideElasticSearchClient(openSearchSourceConfiguration);
assertThat(elasticsearchClient, notNullValue());

final AwsCredentialsOptions awsCredentialsOptions = awsCredentialsOptionsArgumentCaptor.getValue();
assertThat(awsCredentialsOptions, notNullValue());
assertThat(awsCredentialsOptions.getRegion(), equalTo(Region.US_EAST_1));
assertThat(awsCredentialsOptions.getStsHeaderOverrides(), equalTo(Collections.emptyMap()));
assertThat(awsCredentialsOptions.getStsRoleArn(), equalTo(stsRoleArn));
}

@Test
void provideOpenSearchClient_with_aws_auth() {
when(connectionConfiguration.getCertPath()).thenReturn(null);
Expand Down
Loading
Loading