Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Async Executor Service Dependencies Refactor #2497

Merged
merged 1 commit into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.ThreadContext;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
Expand All @@ -24,6 +27,7 @@
import org.opensearch.sql.common.utils.QueryContext;
import org.opensearch.sql.legacy.executor.format.ErrorMessageFactory;
import org.opensearch.sql.legacy.metrics.Metrics;
import org.opensearch.threadpool.ThreadPool;

/**
* Currently this interface is for node level. Cluster level is coming up soon.
Expand Down Expand Up @@ -69,8 +73,11 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli

try {
return channel ->
channel.sendResponse(
new BytesRestResponse(RestStatus.OK, Metrics.getInstance().collectToJSON()));
schedule(
client,
() ->
channel.sendResponse(
new BytesRestResponse(RestStatus.OK, Metrics.getInstance().collectToJSON())));
} catch (Exception e) {
LOG.error("Failed during Query SQL STATS Action.", e);

Expand All @@ -91,4 +98,17 @@ protected Set<String> responseParams() {
"sql", "flat", "separator", "_score", "_type", "_id", "newLine", "format", "sanitize"));
return responseParams;
}

/**
 * Hands {@code task} to the thread pool obtained from {@code client}, scheduling it with a zero
 * delay on the executor named {@code "sql-worker"}.
 *
 * <p>The task is wrapped by {@link #withCurrentContext(Runnable)} before it is submitted.
 *
 * @param client node client whose thread pool is used for scheduling
 * @param task work to run
 */
private void schedule(NodeClient client, Runnable task) {
  client.threadPool().schedule(withCurrentContext(task), new TimeValue(0), "sql-worker");
}

/**
 * Wraps {@code task} so that the Log4j {@link ThreadContext} entries visible to the caller are
 * copied into the {@link ThreadContext} of whichever thread eventually runs the wrapper.
 *
 * @param task work to run after the context entries have been copied
 * @return a runnable that copies the captured context entries and then runs {@code task}
 */
private Runnable withCurrentContext(final Runnable task) {
  // The snapshot is taken here, when the wrapper is created, not when it runs.
  final Map<String, String> callerContext = ThreadContext.getImmutableContext();
  return new Runnable() {
    @Override
    public void run() {
      ThreadContext.putAll(callerContext);
      task.run();
    }
  };
}
}
106 changes: 9 additions & 97 deletions plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,14 @@

package org.opensearch.sql.plugin;

import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG;
import static java.util.Collections.singletonList;
import static org.opensearch.sql.datasource.model.DataSourceMetadata.defaultOpenSearchDataSourceMetadata;
import static org.opensearch.sql.spark.execution.statestore.StateStore.ALL_DATASOURCE;

import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.services.emrserverless.AWSEMRServerless;
import com.amazonaws.services.emrserverless.AWSEMRServerlessClientBuilder;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.time.Clock;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;
Expand Down Expand Up @@ -68,7 +61,6 @@
import org.opensearch.sql.datasources.transport.*;
import org.opensearch.sql.legacy.esdomain.LocalClusterState;
import org.opensearch.sql.legacy.executor.AsyncRestExecutor;
import org.opensearch.sql.legacy.metrics.GaugeMetric;
import org.opensearch.sql.legacy.metrics.Metrics;
import org.opensearch.sql.legacy.plugin.RestSqlAction;
import org.opensearch.sql.legacy.plugin.RestSqlStatsAction;
Expand All @@ -87,26 +79,13 @@
import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse;
import org.opensearch.sql.prometheus.storage.PrometheusStorageFactory;
import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService;
import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl;
import org.opensearch.sql.spark.asyncquery.AsyncQueryJobMetadataStorageService;
import org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryJobMetadataStorageService;
import org.opensearch.sql.spark.client.EMRServerlessClient;
import org.opensearch.sql.spark.client.EmrServerlessClientImpl;
import org.opensearch.sql.spark.cluster.ClusterManagerEventListener;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplierImpl;
import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher;
import org.opensearch.sql.spark.execution.session.SessionManager;
import org.opensearch.sql.spark.execution.statestore.StateStore;
import org.opensearch.sql.spark.flint.FlintIndexMetadataReaderImpl;
import org.opensearch.sql.spark.leasemanager.DefaultLeaseManager;
import org.opensearch.sql.spark.response.JobExecutionResponseReader;
import org.opensearch.sql.spark.rest.RestAsyncQueryManagementAction;
import org.opensearch.sql.spark.storage.SparkStorageFactory;
import org.opensearch.sql.spark.transport.TransportCancelAsyncQueryRequestAction;
import org.opensearch.sql.spark.transport.TransportCreateAsyncQueryRequestAction;
import org.opensearch.sql.spark.transport.TransportGetAsyncQueryResultAction;
import org.opensearch.sql.spark.transport.config.AsyncExecutorServiceModule;
import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse;
import org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionResponse;
import org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionResponse;
Expand All @@ -127,7 +106,6 @@ public class SQLPlugin extends Plugin implements ActionPlugin, ScriptPlugin {

private NodeClient client;
private DataSourceServiceImpl dataSourceService;
private AsyncQueryExecutorService asyncQueryExecutorService;
private Injector injector;

public String name() {
Expand Down Expand Up @@ -223,32 +201,16 @@ public Collection<Object> createComponents(
dataSourceService.createDataSource(defaultOpenSearchDataSourceMetadata());
LocalClusterState.state().setClusterService(clusterService);
LocalClusterState.state().setPluginSettings((OpenSearchSettings) pluginSettings);
SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier =
new SparkExecutionEngineConfigSupplierImpl(pluginSettings);
SparkExecutionEngineConfig sparkExecutionEngineConfig =
sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig();
if (StringUtils.isEmpty(sparkExecutionEngineConfig.getRegion())) {
LOGGER.warn(
String.format(
"Async Query APIs are disabled as %s is not configured properly in cluster settings. "
+ "Please configure and restart the domain to enable Async Query APIs",
SPARK_EXECUTION_ENGINE_CONFIG.getKeyValue()));
this.asyncQueryExecutorService = new AsyncQueryExecutorServiceImpl();
} else {
this.asyncQueryExecutorService =
createAsyncQueryExecutorService(
sparkExecutionEngineConfigSupplier, sparkExecutionEngineConfig);
}

ModulesBuilder modules = new ModulesBuilder();
modules.add(new OpenSearchPluginModule());
modules.add(
b -> {
b.bind(NodeClient.class).toInstance((NodeClient) client);
b.bind(org.opensearch.sql.common.setting.Settings.class).toInstance(pluginSettings);
b.bind(DataSourceService.class).toInstance(dataSourceService);
b.bind(ClusterService.class).toInstance(clusterService);
});

modules.add(new AsyncExecutorServiceModule());
injector = modules.createInjector();
ClusterManagerEventListener clusterManagerEventListener =
new ClusterManagerEventListener(
Expand All @@ -261,12 +223,15 @@ public Collection<Object> createComponents(
OpenSearchSettings.AUTO_INDEX_MANAGEMENT_ENABLED_SETTING,
environment.settings());
return ImmutableList.of(
dataSourceService, asyncQueryExecutorService, clusterManagerEventListener, pluginSettings);
dataSourceService,
injector.getInstance(AsyncQueryExecutorService.class),
clusterManagerEventListener,
pluginSettings);
}

@Override
public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
return Collections.singletonList(
return singletonList(
new FixedExecutorBuilder(
settings,
AsyncRestExecutor.SQL_WORKER_THREAD_POOL_NAME,
Expand Down Expand Up @@ -318,57 +283,4 @@ private DataSourceServiceImpl createDataSourceService() {
dataSourceMetadataStorage,
dataSourceUserAuthorizationHelper);
}

/**
 * Wires together, by hand, every collaborator needed by {@link AsyncQueryExecutorServiceImpl} and
 * returns the assembled service.
 *
 * <p>A single {@link StateStore} is created, its gauges are registered through {@link
 * #registerStateStoreMetrics(StateStore)}, and the same instance is then shared by the job
 * metadata storage, the session manager, the lease manager and the dispatcher.
 *
 * @param sparkExecutionEngineConfigSupplier supplier handed to the resulting service
 * @param sparkExecutionEngineConfig configuration whose region is used to build the EMR
 *     Serverless client
 * @return the assembled async query executor service
 */
private AsyncQueryExecutorService createAsyncQueryExecutorService(
    SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier,
    SparkExecutionEngineConfig sparkExecutionEngineConfig) {
  StateStore sharedStateStore = new StateStore(client, clusterService);
  registerStateStoreMetrics(sharedStateStore);

  AsyncQueryJobMetadataStorageService jobMetadataStorage =
      new OpensearchAsyncQueryJobMetadataStorageService(sharedStateStore);
  EMRServerlessClient emrClient =
      createEMRServerlessClient(sparkExecutionEngineConfig.getRegion());
  JobExecutionResponseReader responseReader = new JobExecutionResponseReader(client);

  // Collaborators are created in the same order in which the dispatcher receives them.
  DataSourceUserAuthorizationHelperImpl authorizationHelper =
      new DataSourceUserAuthorizationHelperImpl(client);
  FlintIndexMetadataReaderImpl flintIndexMetadataReader =
      new FlintIndexMetadataReaderImpl(client);
  SessionManager sessionManager =
      new SessionManager(sharedStateStore, emrClient, pluginSettings);
  DefaultLeaseManager leaseManager = new DefaultLeaseManager(pluginSettings, sharedStateStore);

  SparkQueryDispatcher dispatcher =
      new SparkQueryDispatcher(
          emrClient,
          this.dataSourceService,
          authorizationHelper,
          responseReader,
          flintIndexMetadataReader,
          client,
          sessionManager,
          leaseManager,
          sharedStateStore);

  return new AsyncQueryExecutorServiceImpl(
      jobMetadataStorage, dispatcher, sparkExecutionEngineConfigSupplier);
}

/**
 * Registers two gauges with the {@link Metrics} singleton: {@code
 * active_async_query_sessions_count} and {@code active_async_query_statements_count}. Both are
 * backed by the given state store and cover {@code ALL_DATASOURCE}.
 *
 * @param stateStore state store the gauge values are read from
 */
private void registerStateStoreMetrics(StateStore stateStore) {
  GaugeMetric<Long> sessionGauge =
      new GaugeMetric<>(
          "active_async_query_sessions_count",
          StateStore.activeSessionsCount(stateStore, ALL_DATASOURCE));
  GaugeMetric<Long> statementGauge =
      new GaugeMetric<>(
          "active_async_query_statements_count",
          StateStore.activeStatementsCount(stateStore, ALL_DATASOURCE));

  Metrics registry = Metrics.getInstance();
  registry.registerMetric(sessionGauge);
  registry.registerMetric(statementGauge);
}

/**
 * Builds an {@link EMRServerlessClient} for the given AWS region.
 *
 * <p>The AWS SDK client is created with {@link DefaultAWSCredentialsProviderChain} and the whole
 * construction runs inside {@link AccessController#doPrivileged(PrivilegedAction)}.
 *
 * @param region AWS region passed to the SDK client builder
 * @return a client wrapping the newly built AWS EMR Serverless SDK client
 */
private EMRServerlessClient createEMRServerlessClient(String region) {
  PrivilegedAction<EMRServerlessClient> buildClient =
      () ->
          new EmrServerlessClientImpl(
              AWSEMRServerlessClientBuilder.standard()
                  .withRegion(region)
                  .withCredentials(new DefaultAWSCredentialsProviderChain())
                  .build());
  return AccessController.doPrivileged(buildClient);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.opensearch.rest.RestController;
import org.opensearch.rest.RestRequest;
import org.opensearch.sql.common.utils.QueryContext;
import org.opensearch.sql.datasources.utils.Scheduler;
import org.opensearch.sql.legacy.executor.format.ErrorMessageFactory;
import org.opensearch.sql.legacy.metrics.Metrics;

Expand Down Expand Up @@ -67,8 +68,11 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli

try {
return channel ->
channel.sendResponse(
new BytesRestResponse(RestStatus.OK, Metrics.getInstance().collectToJSON()));
Scheduler.schedule(
client,
() ->
channel.sendResponse(
new BytesRestResponse(RestStatus.OK, Metrics.getInstance().collectToJSON())));
} catch (Exception e) {
LOG.error("Failed during Query PPL STATS Action.", e);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

package org.opensearch.sql.spark.asyncquery;

import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG;
import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD;
import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD;

Expand Down Expand Up @@ -34,26 +33,10 @@ public class AsyncQueryExecutorServiceImpl implements AsyncQueryExecutorService
private AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService;
private SparkQueryDispatcher sparkQueryDispatcher;
private SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier;
private Boolean isSparkJobExecutionEnabled;

/**
 * Creates a service with Spark job execution marked as disabled. No collaborators are assigned by
 * this constructor.
 */
public AsyncQueryExecutorServiceImpl() {
  isSparkJobExecutionEnabled = false;
}

/**
 * Creates a service with Spark job execution marked as enabled.
 *
 * @param asyncQueryJobMetadataStorageService storage used to look up async query job metadata
 * @param sparkQueryDispatcher dispatcher stored for use by this service
 * @param sparkExecutionEngineConfigSupplier supplier of the Spark execution engine configuration
 */
public AsyncQueryExecutorServiceImpl(
    AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService,
    SparkQueryDispatcher sparkQueryDispatcher,
    SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier) {
  this.asyncQueryJobMetadataStorageService = asyncQueryJobMetadataStorageService;
  this.sparkQueryDispatcher = sparkQueryDispatcher;
  this.sparkExecutionEngineConfigSupplier = sparkExecutionEngineConfigSupplier;
  this.isSparkJobExecutionEnabled = true;
}

@Override
public CreateAsyncQueryResponse createAsyncQuery(
CreateAsyncQueryRequest createAsyncQueryRequest) {
validateSparkExecutionEngineSettings();
SparkExecutionEngineConfig sparkExecutionEngineConfig =
sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig();
DispatchQueryResponse dispatchQueryResponse =
Expand All @@ -80,7 +63,6 @@ public CreateAsyncQueryResponse createAsyncQuery(

@Override
public AsyncQueryExecutionResponse getAsyncQueryResults(String queryId) {
validateSparkExecutionEngineSettings();
Optional<AsyncQueryJobMetadata> jobMetadata =
asyncQueryJobMetadataStorageService.getJobMetadata(queryId);
if (jobMetadata.isPresent()) {
Expand Down Expand Up @@ -120,14 +102,4 @@ public String cancelQuery(String queryId) {
}
throw new AsyncQueryNotFoundException(String.format("QueryId: %s not found", queryId));
}

/**
 * Guards the async query entry points: returns silently when Spark job execution is enabled and
 * otherwise rejects the call.
 *
 * @throws IllegalArgumentException if Spark job execution is not enabled; the message names the
 *     cluster setting key that has to be configured
 */
private void validateSparkExecutionEngineSettings() {
  if (isSparkJobExecutionEnabled) {
    return;
  }
  String settingKey = SPARK_EXECUTION_ENGINE_CONFIG.getKeyValue();
  throw new IllegalArgumentException(
      String.format(
          "Async Query APIs are disabled as %s is not configured in cluster settings. Please"
              + " configure the setting and restart the domain to enable Async Query APIs",
          settingKey));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.spark.client;

/**
 * Factory interface for creating instances of {@link EMRServerlessClient}.
 *
 * <p>Callers depend on this abstraction instead of constructing a client themselves. The
 * interface does not specify how the client is configured, or whether each call yields a new
 * instance or a reused one; that is left to the implementation.
 */
public interface EMRServerlessClientFactory {

/**
* Gets an instance of {@link EMRServerlessClient}.
*
* <p>Takes no arguments, so any configuration the client needs has to be known to the
* implementation already.
*
* @return An {@link EMRServerlessClient} instance.
*/
EMRServerlessClient getClient();
}
Loading
Loading