Skip to content

Commit

Permalink
Async Executor Service Dependencies Refactor
Browse files Browse the repository at this point in the history
Signed-off-by: Vamsi Manohar <reddyvam@amazon.com>
  • Loading branch information
vamsi-amazon committed Jan 30, 2024
1 parent 6fcf31b commit d26b157
Show file tree
Hide file tree
Showing 9 changed files with 373 additions and 127 deletions.
116 changes: 5 additions & 111 deletions plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,11 @@

package org.opensearch.sql.plugin;

import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG;
import static java.util.Collections.singletonList;
import static org.opensearch.sql.datasource.model.DataSourceMetadata.defaultOpenSearchDataSourceMetadata;
import static org.opensearch.sql.spark.execution.statestore.StateStore.ALL_DATASOURCE;

import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.services.emrserverless.AWSEMRServerless;
import com.amazonaws.services.emrserverless.AWSEMRServerlessClientBuilder;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.time.Clock;
import java.util.Arrays;
import java.util.Collection;
Expand Down Expand Up @@ -68,7 +62,6 @@
import org.opensearch.sql.datasources.transport.*;
import org.opensearch.sql.legacy.esdomain.LocalClusterState;
import org.opensearch.sql.legacy.executor.AsyncRestExecutor;
import org.opensearch.sql.legacy.metrics.GaugeMetric;
import org.opensearch.sql.legacy.metrics.Metrics;
import org.opensearch.sql.legacy.plugin.RestSqlAction;
import org.opensearch.sql.legacy.plugin.RestSqlStatsAction;
Expand All @@ -86,30 +79,11 @@
import org.opensearch.sql.plugin.transport.TransportPPLQueryAction;
import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse;
import org.opensearch.sql.prometheus.storage.PrometheusStorageFactory;
import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService;
import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl;
import org.opensearch.sql.spark.asyncquery.AsyncQueryJobMetadataStorageService;
import org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryJobMetadataStorageService;
import org.opensearch.sql.spark.client.EMRServerlessClient;
import org.opensearch.sql.spark.client.EmrServerlessClientImpl;
import org.opensearch.sql.spark.cluster.ClusterManagerEventListener;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplierImpl;
import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher;
import org.opensearch.sql.spark.execution.session.SessionManager;
import org.opensearch.sql.spark.execution.statestore.StateStore;
import org.opensearch.sql.spark.flint.FlintIndexMetadataReaderImpl;
import org.opensearch.sql.spark.leasemanager.DefaultLeaseManager;
import org.opensearch.sql.spark.response.JobExecutionResponseReader;
import org.opensearch.sql.spark.rest.RestAsyncQueryManagementAction;
import org.opensearch.sql.spark.storage.SparkStorageFactory;
import org.opensearch.sql.spark.transport.TransportCancelAsyncQueryRequestAction;
import org.opensearch.sql.spark.transport.TransportCreateAsyncQueryRequestAction;
import org.opensearch.sql.spark.transport.TransportGetAsyncQueryResultAction;
import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse;
import org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionResponse;
import org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionResponse;
import org.opensearch.sql.storage.DataSourceFactory;
import org.opensearch.threadpool.ExecutorBuilder;
import org.opensearch.threadpool.FixedExecutorBuilder;
Expand All @@ -127,9 +101,8 @@ public class SQLPlugin extends Plugin implements ActionPlugin, ScriptPlugin {

private NodeClient client;
private DataSourceServiceImpl dataSourceService;
private AsyncQueryExecutorService asyncQueryExecutorService;
private Injector injector;

/** Plugin name under which this SQL engine registers itself. */
public String name() {
return "sql";
}
Expand Down Expand Up @@ -192,15 +165,7 @@ public List<RestHandler> getRestHandlers(
new ActionHandler<>(
new ActionType<>(
TransportCreateAsyncQueryRequestAction.NAME, CreateAsyncQueryActionResponse::new),
TransportCreateAsyncQueryRequestAction.class),
new ActionHandler<>(
new ActionType<>(
TransportGetAsyncQueryResultAction.NAME, GetAsyncQueryResultActionResponse::new),
TransportGetAsyncQueryResultAction.class),
new ActionHandler<>(
new ActionType<>(
TransportCancelAsyncQueryRequestAction.NAME, CancelAsyncQueryActionResponse::new),
TransportCancelAsyncQueryRequestAction.class));
TransportCreateAsyncQueryRequestAction.class));
}

@Override
Expand All @@ -223,23 +188,6 @@ public Collection<Object> createComponents(
dataSourceService.createDataSource(defaultOpenSearchDataSourceMetadata());
LocalClusterState.state().setClusterService(clusterService);
LocalClusterState.state().setPluginSettings((OpenSearchSettings) pluginSettings);
SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier =
new SparkExecutionEngineConfigSupplierImpl(pluginSettings);
SparkExecutionEngineConfig sparkExecutionEngineConfig =
sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig();
if (StringUtils.isEmpty(sparkExecutionEngineConfig.getRegion())) {
LOGGER.warn(
String.format(
"Async Query APIs are disabled as %s is not configured properly in cluster settings. "
+ "Please configure and restart the domain to enable Async Query APIs",
SPARK_EXECUTION_ENGINE_CONFIG.getKeyValue()));
this.asyncQueryExecutorService = new AsyncQueryExecutorServiceImpl();
} else {
this.asyncQueryExecutorService =
createAsyncQueryExecutorService(
sparkExecutionEngineConfigSupplier, sparkExecutionEngineConfig);
}

ModulesBuilder modules = new ModulesBuilder();
modules.add(new OpenSearchPluginModule());
modules.add(
Expand All @@ -260,13 +208,12 @@ public Collection<Object> createComponents(
OpenSearchSettings.RESULT_INDEX_TTL_SETTING,
OpenSearchSettings.AUTO_INDEX_MANAGEMENT_ENABLED_SETTING,
environment.settings());
return ImmutableList.of(
dataSourceService, asyncQueryExecutorService, clusterManagerEventListener, pluginSettings);
return ImmutableList.of(dataSourceService, clusterManagerEventListener, pluginSettings);
}

@Override
public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
return Collections.singletonList(
return singletonList(
new FixedExecutorBuilder(
settings,
AsyncRestExecutor.SQL_WORKER_THREAD_POOL_NAME,
Expand Down Expand Up @@ -318,57 +265,4 @@ private DataSourceServiceImpl createDataSourceService() {
dataSourceMetadataStorage,
dataSourceUserAuthorizationHelper);
}

/**
 * Wires together the async query execution stack: state store (plus its gauge metrics),
 * job metadata storage, the EMR Serverless client for the configured region, and the
 * Spark query dispatcher with its session/lease management collaborators.
 *
 * @param sparkExecutionEngineConfigSupplier supplier handed through to the service so it
 *     can re-read engine configuration at query time
 * @param sparkExecutionEngineConfig resolved engine config; only the region is read here
 * @return a fully wired {@link AsyncQueryExecutorServiceImpl}
 */
private AsyncQueryExecutorService createAsyncQueryExecutorService(
    SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier,
    SparkExecutionEngineConfig sparkExecutionEngineConfig) {
  StateStore store = new StateStore(client, clusterService);
  registerStateStoreMetrics(store);
  EMRServerlessClient emrClient =
      createEMRServerlessClient(sparkExecutionEngineConfig.getRegion());
  AsyncQueryJobMetadataStorageService metadataStorage =
      new OpensearchAsyncQueryJobMetadataStorageService(store);
  JobExecutionResponseReader responseReader = new JobExecutionResponseReader(client);
  SparkQueryDispatcher dispatcher =
      new SparkQueryDispatcher(
          emrClient,
          this.dataSourceService,
          new DataSourceUserAuthorizationHelperImpl(client),
          responseReader,
          new FlintIndexMetadataReaderImpl(client),
          client,
          new SessionManager(store, emrClient, pluginSettings),
          new DefaultLeaseManager(pluginSettings, store),
          store);
  return new AsyncQueryExecutorServiceImpl(
      metadataStorage, dispatcher, sparkExecutionEngineConfigSupplier);
}

/**
 * Registers gauge metrics reporting how many async query sessions and statements are
 * currently active across all data sources in the given state store.
 *
 * @param stateStore state store backing the session/statement counters
 */
private void registerStateStoreMetrics(StateStore stateStore) {
  Metrics metrics = Metrics.getInstance();
  metrics.registerMetric(
      new GaugeMetric<>(
          "active_async_query_sessions_count",
          StateStore.activeSessionsCount(stateStore, ALL_DATASOURCE)));
  metrics.registerMetric(
      new GaugeMetric<>(
          "active_async_query_statements_count",
          StateStore.activeStatementsCount(stateStore, ALL_DATASOURCE)));
}

/**
 * Builds an EMR Serverless client for the given AWS region using the default AWS
 * credentials provider chain. Construction runs inside a privileged block because the
 * AWS SDK performs actions (credential lookup, classloading) that require elevated
 * permissions under the plugin security policy.
 *
 * @param region AWS region the client targets
 * @return an {@link EmrServerlessClientImpl} wrapping the SDK client
 */
private EMRServerlessClient createEMRServerlessClient(String region) {
  PrivilegedAction<EMRServerlessClient> buildClient =
      () -> {
        AWSEMRServerless sdkClient =
            AWSEMRServerlessClientBuilder.standard()
                .withRegion(region)
                .withCredentials(new DefaultAWSCredentialsProviderChain())
                .build();
        return new EmrServerlessClientImpl(sdkClient);
      };
  return AccessController.doPrivileged(buildClient);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.client.EMRServerlessClient;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse;
Expand Down Expand Up @@ -59,6 +60,19 @@ public class SparkQueryDispatcher {

private StateStore stateStore;

/**
 * Copies every collaborator configured on the builder into this dispatcher.
 * Private: instances are obtained through {@code SparkQueryDispatcherBuilder#build()}.
 */
private SparkQueryDispatcher(SparkQueryDispatcherBuilder b) {
  // Execution and transport collaborators.
  emrServerlessClient = b.emrServerlessClient;
  client = b.client;
  // Data source resolution and authorization.
  dataSourceService = b.dataSourceService;
  dataSourceUserAuthorizationHelper = b.dataSourceUserAuthorizationHelper;
  // Query lifecycle support.
  jobExecutionResponseReader = b.jobExecutionResponseReader;
  flintIndexMetadataReader = b.flintIndexMetadataReader;
  sessionManager = b.sessionManager;
  leaseManager = b.leaseManager;
  stateStore = b.stateStore;
}


public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) {
DataSourceMetadata dataSourceMetadata =
this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource());
Expand Down Expand Up @@ -98,6 +112,10 @@ public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest)
return asyncQueryHandler.submit(dispatchQueryRequest, contextBuilder.build());
}

// NOTE(review): dead code — this private method has an empty body, is never called from
// any visible code in this class, and its parameter is unused. It looks like a leftover
// stub from the executor-service dependency refactor (the EMR client is now supplied via
// the builder). Recommend deleting it, or implementing the client construction it was
// evidently meant to hold.
private void createEmrServerlessClient(SparkExecutionEngineConfig sparkExecutionEngineConfig) {

}

public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) {
if (asyncQueryJobMetadata.getSessionId() != null) {
return new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager)
Expand Down Expand Up @@ -156,4 +174,75 @@ private static Map<String, String> getDefaultTagsForJobSubmission(
tags.put(DATASOURCE_TAG_KEY, dispatchQueryRequest.getDatasource());
return tags;
}


class SparkQueryDispatcherBuilder {

private EMRServerlessClient emrServerlessClient; // Optional
private DataSourceService dataSourceService;
private DataSourceUserAuthorizationHelperImpl dataSourceUserAuthorizationHelper;
private JobExecutionResponseReader jobExecutionResponseReader;
private FlintIndexMetadataReader flintIndexMetadataReader;
private Client client;
private SessionManager sessionManager;
private LeaseManager leaseManager;
private StateStore stateStore;

public SparkQueryDispatcherBuilder() {
}

public SparkQueryDispatcherBuilder setEMRServerlessClient(
EMRServerlessClient emrServerlessClient) {
this.emrServerlessClient = emrServerlessClient;
return this;
}

public SparkQueryDispatcherBuilder setDataSourceService(DataSourceService dataSourceService) {
this.dataSourceService = dataSourceService;
return this;
}

public SparkQueryDispatcherBuilder setDataSourceUserAuthorizationHelper(
DataSourceUserAuthorizationHelperImpl dataSourceUserAuthorizationHelper) {
this.dataSourceUserAuthorizationHelper = dataSourceUserAuthorizationHelper;
return this;
}

public SparkQueryDispatcherBuilder setJobExecutionResponseReader(
JobExecutionResponseReader jobExecutionResponseReader) {
this.jobExecutionResponseReader = jobExecutionResponseReader;
return this;
}

public SparkQueryDispatcherBuilder setFlintIndexMetadataReader(
FlintIndexMetadataReader flintIndexMetadataReader) {
this.flintIndexMetadataReader = flintIndexMetadataReader;
return this;
}

public SparkQueryDispatcherBuilder setClient(Client client) {
this.client = client;
return this;
}

public SparkQueryDispatcherBuilder setSessionManager(SessionManager sessionManager) {
this.sessionManager = sessionManager;
return this;
}

public SparkQueryDispatcherBuilder setLeaseManager(LeaseManager leaseManager) {
this.leaseManager = leaseManager;
return this;
}

public SparkQueryDispatcherBuilder setStateStore(StateStore stateStore) {
this.stateStore = stateStore;
return this;
}

public SparkQueryDispatcher build() {
return new SparkQueryDispatcher(this);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,23 @@

package org.opensearch.sql.spark.transport;

import java.util.ArrayList;
import java.util.List;
import org.opensearch.action.ActionType;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Guice;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.inject.Injector;
import org.opensearch.common.inject.Module;
import org.opensearch.core.action.ActionListener;
import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl;
import org.opensearch.sql.datasource.DataSourceService;
import org.opensearch.sql.datasources.service.DataSourceServiceImpl;
import org.opensearch.sql.opensearch.setting.OpenSearchSettings;
import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService;
import org.opensearch.sql.spark.transport.config.AsyncExecutorServiceModule;
import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionRequest;
import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse;
import org.opensearch.tasks.Task;
Expand All @@ -22,17 +33,28 @@ public class TransportCancelAsyncQueryRequestAction
extends HandledTransportAction<CancelAsyncQueryActionRequest, CancelAsyncQueryActionResponse> {

public static final String NAME = "cluster:admin/opensearch/ql/async_query/delete";
private final AsyncQueryExecutorServiceImpl asyncQueryExecutorService;
private final Injector injector;
public static final ActionType<CancelAsyncQueryActionResponse> ACTION_TYPE =
new ActionType<>(NAME, CancelAsyncQueryActionResponse::new);

/**
 * Builds a Guice injector scoped to this transport action, binding the node client,
 * plugin settings (derived from cluster settings), and the data source service that
 * {@link AsyncExecutorServiceModule} needs to assemble the async query executor service.
 *
 * @param transportService transport layer this handled action registers with
 * @param actionFilters standard action filters applied to every request
 * @param client node client bound into the injector for downstream components
 * @param clusterService source of the cluster settings used to build plugin settings
 * @param dataSourceService data source service bound into the injector
 */
@Inject
public TransportCancelAsyncQueryRequestAction(
    TransportService transportService,
    ActionFilters actionFilters,
    NodeClient client,
    ClusterService clusterService,
    DataSourceServiceImpl dataSourceService) {
  super(NAME, transportService, actionFilters, CancelAsyncQueryActionRequest::new);
  List<Module> modules = new ArrayList<>();
  modules.add(
      b -> {
        b.bind(NodeClient.class).toInstance(client);
        b.bind(org.opensearch.sql.common.setting.Settings.class)
            .toInstance(new OpenSearchSettings(clusterService.getClusterSettings()));
        b.bind(DataSourceService.class).toInstance(dataSourceService);
      });
  modules.add(new AsyncExecutorServiceModule());
  // BUG FIX: Guice.createInjector() was previously called with no arguments, silently
  // discarding every binding assembled above; the injector must receive the module list.
  this.injector = Guice.createInjector(modules);
}

@Override
Expand All @@ -41,6 +63,8 @@ protected void doExecute(
CancelAsyncQueryActionRequest request,
ActionListener<CancelAsyncQueryActionResponse> listener) {
try {
AsyncQueryExecutorService asyncQueryExecutorService =
injector.getInstance(AsyncQueryExecutorService.class);
String jobId = asyncQueryExecutorService.cancelQuery(request.getQueryId());
listener.onResponse(
new CancelAsyncQueryActionResponse(
Expand Down
Loading

0 comments on commit d26b157

Please sign in to comment.