Skip to content

Commit

Permalink
Refactor async executor service dependencies using guice framework
Browse files Browse the repository at this point in the history
Signed-off-by: Vamsi Manohar <reddyvam@amazon.com>
  • Loading branch information
vamsi-amazon committed Feb 1, 2024
1 parent e59bf75 commit 3d17b63
Show file tree
Hide file tree
Showing 21 changed files with 599 additions and 214 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.ThreadContext;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
Expand All @@ -24,6 +27,7 @@
import org.opensearch.sql.common.utils.QueryContext;
import org.opensearch.sql.legacy.executor.format.ErrorMessageFactory;
import org.opensearch.sql.legacy.metrics.Metrics;
import org.opensearch.threadpool.ThreadPool;

/**
* Currently this interface is for node level. Cluster level is coming up soon.
Expand Down Expand Up @@ -69,8 +73,11 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli

try {
return channel ->
channel.sendResponse(
new BytesRestResponse(RestStatus.OK, Metrics.getInstance().collectToJSON()));
schedule(
client,
() ->
channel.sendResponse(
new BytesRestResponse(RestStatus.OK, Metrics.getInstance().collectToJSON())));
} catch (Exception e) {
LOG.error("Failed during Query SQL STATS Action.", e);

Expand All @@ -91,4 +98,17 @@ protected Set<String> responseParams() {
"sql", "flat", "separator", "_score", "_type", "_id", "newLine", "format", "sanitize"));
return responseParams;
}

private void schedule(NodeClient client, Runnable task) {
ThreadPool threadPool = client.threadPool();
threadPool.schedule(withCurrentContext(task), new TimeValue(0), "sql-worker");
}

private Runnable withCurrentContext(final Runnable task) {
final Map<String, String> currentContext = ThreadContext.getImmutableContext();
return () -> {
ThreadContext.putAll(currentContext);
task.run();
};
}
}
106 changes: 9 additions & 97 deletions plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,14 @@

package org.opensearch.sql.plugin;

import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG;
import static java.util.Collections.singletonList;
import static org.opensearch.sql.datasource.model.DataSourceMetadata.defaultOpenSearchDataSourceMetadata;
import static org.opensearch.sql.spark.execution.statestore.StateStore.ALL_DATASOURCE;

import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.services.emrserverless.AWSEMRServerless;
import com.amazonaws.services.emrserverless.AWSEMRServerlessClientBuilder;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.time.Clock;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;
Expand Down Expand Up @@ -68,7 +61,6 @@
import org.opensearch.sql.datasources.transport.*;
import org.opensearch.sql.legacy.esdomain.LocalClusterState;
import org.opensearch.sql.legacy.executor.AsyncRestExecutor;
import org.opensearch.sql.legacy.metrics.GaugeMetric;
import org.opensearch.sql.legacy.metrics.Metrics;
import org.opensearch.sql.legacy.plugin.RestSqlAction;
import org.opensearch.sql.legacy.plugin.RestSqlStatsAction;
Expand All @@ -87,26 +79,13 @@
import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse;
import org.opensearch.sql.prometheus.storage.PrometheusStorageFactory;
import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService;
import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl;
import org.opensearch.sql.spark.asyncquery.AsyncQueryJobMetadataStorageService;
import org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryJobMetadataStorageService;
import org.opensearch.sql.spark.client.EMRServerlessClient;
import org.opensearch.sql.spark.client.EmrServerlessClientImpl;
import org.opensearch.sql.spark.cluster.ClusterManagerEventListener;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplierImpl;
import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher;
import org.opensearch.sql.spark.execution.session.SessionManager;
import org.opensearch.sql.spark.execution.statestore.StateStore;
import org.opensearch.sql.spark.flint.FlintIndexMetadataReaderImpl;
import org.opensearch.sql.spark.leasemanager.DefaultLeaseManager;
import org.opensearch.sql.spark.response.JobExecutionResponseReader;
import org.opensearch.sql.spark.rest.RestAsyncQueryManagementAction;
import org.opensearch.sql.spark.storage.SparkStorageFactory;
import org.opensearch.sql.spark.transport.TransportCancelAsyncQueryRequestAction;
import org.opensearch.sql.spark.transport.TransportCreateAsyncQueryRequestAction;
import org.opensearch.sql.spark.transport.TransportGetAsyncQueryResultAction;
import org.opensearch.sql.spark.transport.config.AsyncExecutorServiceModule;
import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse;
import org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionResponse;
import org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionResponse;
Expand All @@ -127,7 +106,6 @@ public class SQLPlugin extends Plugin implements ActionPlugin, ScriptPlugin {

private NodeClient client;
private DataSourceServiceImpl dataSourceService;
private AsyncQueryExecutorService asyncQueryExecutorService;
private Injector injector;

public String name() {
Expand Down Expand Up @@ -223,32 +201,16 @@ public Collection<Object> createComponents(
dataSourceService.createDataSource(defaultOpenSearchDataSourceMetadata());
LocalClusterState.state().setClusterService(clusterService);
LocalClusterState.state().setPluginSettings((OpenSearchSettings) pluginSettings);
SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier =
new SparkExecutionEngineConfigSupplierImpl(pluginSettings);
SparkExecutionEngineConfig sparkExecutionEngineConfig =
sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig();
if (StringUtils.isEmpty(sparkExecutionEngineConfig.getRegion())) {
LOGGER.warn(
String.format(
"Async Query APIs are disabled as %s is not configured properly in cluster settings. "
+ "Please configure and restart the domain to enable Async Query APIs",
SPARK_EXECUTION_ENGINE_CONFIG.getKeyValue()));
this.asyncQueryExecutorService = new AsyncQueryExecutorServiceImpl();
} else {
this.asyncQueryExecutorService =
createAsyncQueryExecutorService(
sparkExecutionEngineConfigSupplier, sparkExecutionEngineConfig);
}

ModulesBuilder modules = new ModulesBuilder();
modules.add(new OpenSearchPluginModule());
modules.add(
b -> {
b.bind(NodeClient.class).toInstance((NodeClient) client);
b.bind(org.opensearch.sql.common.setting.Settings.class).toInstance(pluginSettings);
b.bind(DataSourceService.class).toInstance(dataSourceService);
b.bind(ClusterService.class).toInstance(clusterService);
});

modules.add(new AsyncExecutorServiceModule());
injector = modules.createInjector();
ClusterManagerEventListener clusterManagerEventListener =
new ClusterManagerEventListener(
Expand All @@ -261,12 +223,15 @@ public Collection<Object> createComponents(
OpenSearchSettings.AUTO_INDEX_MANAGEMENT_ENABLED_SETTING,
environment.settings());
return ImmutableList.of(
dataSourceService, asyncQueryExecutorService, clusterManagerEventListener, pluginSettings);
dataSourceService,
injector.getInstance(AsyncQueryExecutorService.class),
clusterManagerEventListener,
pluginSettings);
}

@Override
public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
return Collections.singletonList(
return singletonList(
new FixedExecutorBuilder(
settings,
AsyncRestExecutor.SQL_WORKER_THREAD_POOL_NAME,
Expand Down Expand Up @@ -318,57 +283,4 @@ private DataSourceServiceImpl createDataSourceService() {
dataSourceMetadataStorage,
dataSourceUserAuthorizationHelper);
}

private AsyncQueryExecutorService createAsyncQueryExecutorService(
SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier,
SparkExecutionEngineConfig sparkExecutionEngineConfig) {
StateStore stateStore = new StateStore(client, clusterService);
registerStateStoreMetrics(stateStore);
AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService =
new OpensearchAsyncQueryJobMetadataStorageService(stateStore);
EMRServerlessClient emrServerlessClient =
createEMRServerlessClient(sparkExecutionEngineConfig.getRegion());
JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client);
SparkQueryDispatcher sparkQueryDispatcher =
new SparkQueryDispatcher(
emrServerlessClient,
this.dataSourceService,
new DataSourceUserAuthorizationHelperImpl(client),
jobExecutionResponseReader,
new FlintIndexMetadataReaderImpl(client),
client,
new SessionManager(stateStore, emrServerlessClient, pluginSettings),
new DefaultLeaseManager(pluginSettings, stateStore),
stateStore);
return new AsyncQueryExecutorServiceImpl(
asyncQueryJobMetadataStorageService,
sparkQueryDispatcher,
sparkExecutionEngineConfigSupplier);
}

private void registerStateStoreMetrics(StateStore stateStore) {
GaugeMetric<Long> activeSessionMetric =
new GaugeMetric<>(
"active_async_query_sessions_count",
StateStore.activeSessionsCount(stateStore, ALL_DATASOURCE));
GaugeMetric<Long> activeStatementMetric =
new GaugeMetric<>(
"active_async_query_statements_count",
StateStore.activeStatementsCount(stateStore, ALL_DATASOURCE));
Metrics.getInstance().registerMetric(activeSessionMetric);
Metrics.getInstance().registerMetric(activeStatementMetric);
}

private EMRServerlessClient createEMRServerlessClient(String region) {
return AccessController.doPrivileged(
(PrivilegedAction<EMRServerlessClient>)
() -> {
AWSEMRServerless awsemrServerless =
AWSEMRServerlessClientBuilder.standard()
.withRegion(region)
.withCredentials(new DefaultAWSCredentialsProviderChain())
.build();
return new EmrServerlessClientImpl(awsemrServerless);
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.opensearch.rest.RestController;
import org.opensearch.rest.RestRequest;
import org.opensearch.sql.common.utils.QueryContext;
import org.opensearch.sql.datasources.utils.Scheduler;
import org.opensearch.sql.legacy.executor.format.ErrorMessageFactory;
import org.opensearch.sql.legacy.metrics.Metrics;

Expand Down Expand Up @@ -67,8 +68,11 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli

try {
return channel ->
channel.sendResponse(
new BytesRestResponse(RestStatus.OK, Metrics.getInstance().collectToJSON()));
Scheduler.schedule(
client,
() ->
channel.sendResponse(
new BytesRestResponse(RestStatus.OK, Metrics.getInstance().collectToJSON())));
} catch (Exception e) {
LOG.error("Failed during Query PPL STATS Action.", e);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

package org.opensearch.sql.spark.asyncquery;

import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG;
import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD;
import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD;

Expand Down Expand Up @@ -34,26 +33,10 @@ public class AsyncQueryExecutorServiceImpl implements AsyncQueryExecutorService
private AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService;
private SparkQueryDispatcher sparkQueryDispatcher;
private SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier;
private Boolean isSparkJobExecutionEnabled;

public AsyncQueryExecutorServiceImpl() {
this.isSparkJobExecutionEnabled = Boolean.FALSE;
}

public AsyncQueryExecutorServiceImpl(
AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService,
SparkQueryDispatcher sparkQueryDispatcher,
SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier) {
this.isSparkJobExecutionEnabled = Boolean.TRUE;
this.asyncQueryJobMetadataStorageService = asyncQueryJobMetadataStorageService;
this.sparkQueryDispatcher = sparkQueryDispatcher;
this.sparkExecutionEngineConfigSupplier = sparkExecutionEngineConfigSupplier;
}

@Override
public CreateAsyncQueryResponse createAsyncQuery(
CreateAsyncQueryRequest createAsyncQueryRequest) {
validateSparkExecutionEngineSettings();
SparkExecutionEngineConfig sparkExecutionEngineConfig =
sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig();
DispatchQueryResponse dispatchQueryResponse =
Expand All @@ -80,7 +63,6 @@ public CreateAsyncQueryResponse createAsyncQuery(

@Override
public AsyncQueryExecutionResponse getAsyncQueryResults(String queryId) {
validateSparkExecutionEngineSettings();
Optional<AsyncQueryJobMetadata> jobMetadata =
asyncQueryJobMetadataStorageService.getJobMetadata(queryId);
if (jobMetadata.isPresent()) {
Expand Down Expand Up @@ -120,14 +102,4 @@ public String cancelQuery(String queryId) {
}
throw new AsyncQueryNotFoundException(String.format("QueryId: %s not found", queryId));
}

private void validateSparkExecutionEngineSettings() {
if (!isSparkJobExecutionEnabled) {
throw new IllegalArgumentException(
String.format(
"Async Query APIs are disabled as %s is not configured in cluster settings. Please"
+ " configure the setting and restart the domain to enable Async Query APIs",
SPARK_EXECUTION_ENGINE_CONFIG.getKeyValue()));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package org.opensearch.sql.spark.client;

/** Factory interface for creating instances of {@link EMRServerlessClient}. */
public interface EMRServerlessClientFactory {

/**
* Gets an instance of {@link EMRServerlessClient}.
*
* @return An {@link EMRServerlessClient} instance.
*/
EMRServerlessClient getClient();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package org.opensearch.sql.spark.client;

import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG;

import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.services.emrserverless.AWSEMRServerless;
import com.amazonaws.services.emrserverless.AWSEMRServerlessClientBuilder;
import java.security.AccessController;
import java.security.PrivilegedAction;
import lombok.RequiredArgsConstructor;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;

/** Implementation of {@link EMRServerlessClientFactory}. */
@RequiredArgsConstructor
public class EMRServerlessClientFactoryImpl implements EMRServerlessClientFactory {

private final SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier;
private EMRServerlessClient emrServerlessClient;
private String region;

/**
* Gets an instance of {@link EMRServerlessClient}.
*
* @return An {@link EMRServerlessClient} instance.
*/
@Override
public EMRServerlessClient getClient() {
SparkExecutionEngineConfig sparkExecutionEngineConfig =
this.sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig();
validateSparkExecutionEngineConfig(sparkExecutionEngineConfig);
if (isNewClientCreationRequired(sparkExecutionEngineConfig.getRegion())) {
region = sparkExecutionEngineConfig.getRegion();
this.emrServerlessClient = createEMRServerlessClient(this.region);
}
return this.emrServerlessClient;
}

private boolean isNewClientCreationRequired(String region) {
return !region.equals(this.region);
}

private void validateSparkExecutionEngineConfig(
SparkExecutionEngineConfig sparkExecutionEngineConfig) {
if (sparkExecutionEngineConfig == null || sparkExecutionEngineConfig.getRegion() == null) {
throw new IllegalArgumentException(
String.format(
"Async Query APIs are disabled. Please configure %s in cluster settings to enable"
+ " them.",
SPARK_EXECUTION_ENGINE_CONFIG.getKeyValue()));
}
}

private EMRServerlessClient createEMRServerlessClient(String awsRegion) {
return AccessController.doPrivileged(
(PrivilegedAction<EMRServerlessClient>)
() -> {
AWSEMRServerless awsemrServerless =
AWSEMRServerlessClientBuilder.standard()
.withRegion(awsRegion)
.withCredentials(new DefaultAWSCredentialsProviderChain())
.build();
return new EmrServerlessClientImpl(awsemrServerless);
});
}
}
Loading

0 comments on commit 3d17b63

Please sign in to comment.