Skip to content

Commit

Permalink
Async Executor Service Dependencies Refactor
Browse files Browse the repository at this point in the history
Signed-off-by: Vamsi Manohar <reddyvam@amazon.com>
  • Loading branch information
vamsi-amazon committed Jan 29, 2024
1 parent 6fcf31b commit 1c6ef2b
Show file tree
Hide file tree
Showing 8 changed files with 281 additions and 121 deletions.
106 changes: 2 additions & 104 deletions plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,10 @@

package org.opensearch.sql.plugin;

import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG;
import static org.opensearch.sql.datasource.model.DataSourceMetadata.defaultOpenSearchDataSourceMetadata;
import static org.opensearch.sql.spark.execution.statestore.StateStore.ALL_DATASOURCE;

import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.services.emrserverless.AWSEMRServerless;
import com.amazonaws.services.emrserverless.AWSEMRServerlessClientBuilder;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.time.Clock;
import java.util.Arrays;
import java.util.Collection;
Expand Down Expand Up @@ -68,7 +61,6 @@
import org.opensearch.sql.datasources.transport.*;
import org.opensearch.sql.legacy.esdomain.LocalClusterState;
import org.opensearch.sql.legacy.executor.AsyncRestExecutor;
import org.opensearch.sql.legacy.metrics.GaugeMetric;
import org.opensearch.sql.legacy.metrics.Metrics;
import org.opensearch.sql.legacy.plugin.RestSqlAction;
import org.opensearch.sql.legacy.plugin.RestSqlStatsAction;
Expand All @@ -86,22 +78,7 @@
import org.opensearch.sql.plugin.transport.TransportPPLQueryAction;
import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse;
import org.opensearch.sql.prometheus.storage.PrometheusStorageFactory;
import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService;
import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl;
import org.opensearch.sql.spark.asyncquery.AsyncQueryJobMetadataStorageService;
import org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryJobMetadataStorageService;
import org.opensearch.sql.spark.client.EMRServerlessClient;
import org.opensearch.sql.spark.client.EmrServerlessClientImpl;
import org.opensearch.sql.spark.cluster.ClusterManagerEventListener;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplierImpl;
import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher;
import org.opensearch.sql.spark.execution.session.SessionManager;
import org.opensearch.sql.spark.execution.statestore.StateStore;
import org.opensearch.sql.spark.flint.FlintIndexMetadataReaderImpl;
import org.opensearch.sql.spark.leasemanager.DefaultLeaseManager;
import org.opensearch.sql.spark.response.JobExecutionResponseReader;
import org.opensearch.sql.spark.rest.RestAsyncQueryManagementAction;
import org.opensearch.sql.spark.storage.SparkStorageFactory;
import org.opensearch.sql.spark.transport.TransportCancelAsyncQueryRequestAction;
Expand All @@ -127,9 +104,7 @@ public class SQLPlugin extends Plugin implements ActionPlugin, ScriptPlugin {

private NodeClient client;
private DataSourceServiceImpl dataSourceService;
private AsyncQueryExecutorService asyncQueryExecutorService;
private Injector injector;

public String name() {
return "sql";
}
Expand Down Expand Up @@ -192,15 +167,7 @@ public List<RestHandler> getRestHandlers(
new ActionHandler<>(
new ActionType<>(
TransportCreateAsyncQueryRequestAction.NAME, CreateAsyncQueryActionResponse::new),
TransportCreateAsyncQueryRequestAction.class),
new ActionHandler<>(
new ActionType<>(
TransportGetAsyncQueryResultAction.NAME, GetAsyncQueryResultActionResponse::new),
TransportGetAsyncQueryResultAction.class),
new ActionHandler<>(
new ActionType<>(
TransportCancelAsyncQueryRequestAction.NAME, CancelAsyncQueryActionResponse::new),
TransportCancelAsyncQueryRequestAction.class));
TransportCreateAsyncQueryRequestAction.class));
}

@Override
Expand All @@ -223,23 +190,6 @@ public Collection<Object> createComponents(
dataSourceService.createDataSource(defaultOpenSearchDataSourceMetadata());
LocalClusterState.state().setClusterService(clusterService);
LocalClusterState.state().setPluginSettings((OpenSearchSettings) pluginSettings);
SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier =
new SparkExecutionEngineConfigSupplierImpl(pluginSettings);
SparkExecutionEngineConfig sparkExecutionEngineConfig =
sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig();
if (StringUtils.isEmpty(sparkExecutionEngineConfig.getRegion())) {
LOGGER.warn(
String.format(
"Async Query APIs are disabled as %s is not configured properly in cluster settings. "
+ "Please configure and restart the domain to enable Async Query APIs",
SPARK_EXECUTION_ENGINE_CONFIG.getKeyValue()));
this.asyncQueryExecutorService = new AsyncQueryExecutorServiceImpl();
} else {
this.asyncQueryExecutorService =
createAsyncQueryExecutorService(
sparkExecutionEngineConfigSupplier, sparkExecutionEngineConfig);
}

ModulesBuilder modules = new ModulesBuilder();
modules.add(new OpenSearchPluginModule());
modules.add(
Expand All @@ -261,7 +211,7 @@ public Collection<Object> createComponents(
OpenSearchSettings.AUTO_INDEX_MANAGEMENT_ENABLED_SETTING,
environment.settings());
return ImmutableList.of(
dataSourceService, asyncQueryExecutorService, clusterManagerEventListener, pluginSettings);
dataSourceService, clusterManagerEventListener, pluginSettings);
}

@Override
Expand Down Expand Up @@ -319,56 +269,4 @@ private DataSourceServiceImpl createDataSourceService() {
dataSourceUserAuthorizationHelper);
}

/**
 * Wires together the full async-query execution stack and returns the service facade.
 *
 * <p>Construction order: a {@link StateStore} (backed by the node client and cluster service),
 * whose session/statement gauges are registered as metrics; an OpenSearch-backed job-metadata
 * storage service; an EMR Serverless client for the configured region; and a
 * {@link SparkQueryDispatcher} that combines all of the above with a session manager and a
 * default lease manager.
 *
 * @param sparkExecutionEngineConfigSupplier supplier handed through to the service impl
 * @param sparkExecutionEngineConfig config whose region is used to build the EMR client
 * @return a fully wired {@link AsyncQueryExecutorServiceImpl}
 */
private AsyncQueryExecutorService createAsyncQueryExecutorService(
SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier,
SparkExecutionEngineConfig sparkExecutionEngineConfig) {
// Single StateStore instance shared by metadata storage, session and lease managers below.
StateStore stateStore = new StateStore(client, clusterService);
registerStateStoreMetrics(stateStore);
AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService =
new OpensearchAsyncQueryJobMetadataStorageService(stateStore);
EMRServerlessClient emrServerlessClient =
createEMRServerlessClient(sparkExecutionEngineConfig.getRegion());
JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client);
SparkQueryDispatcher sparkQueryDispatcher =
new SparkQueryDispatcher(
emrServerlessClient,
this.dataSourceService,
new DataSourceUserAuthorizationHelperImpl(client),
jobExecutionResponseReader,
new FlintIndexMetadataReaderImpl(client),
client,
new SessionManager(stateStore, emrServerlessClient, pluginSettings),
new DefaultLeaseManager(pluginSettings, stateStore),
stateStore);
return new AsyncQueryExecutorServiceImpl(
asyncQueryJobMetadataStorageService,
sparkQueryDispatcher,
sparkExecutionEngineConfigSupplier);
}

/**
 * Registers two gauge metrics on the global {@link Metrics} singleton that report the number of
 * active async-query sessions and statements, aggregated across all datasources
 * ({@code ALL_DATASOURCE}), as counted by the given {@link StateStore}.
 *
 * @param stateStore the state store whose counters back the gauges
 */
private void registerStateStoreMetrics(StateStore stateStore) {
GaugeMetric<Long> activeSessionMetric =
new GaugeMetric<>(
"active_async_query_sessions_count",
StateStore.activeSessionsCount(stateStore, ALL_DATASOURCE));
GaugeMetric<Long> activeStatementMetric =
new GaugeMetric<>(
"active_async_query_statements_count",
StateStore.activeStatementsCount(stateStore, ALL_DATASOURCE));
Metrics.getInstance().registerMetric(activeSessionMetric);
Metrics.getInstance().registerMetric(activeStatementMetric);
}

/**
 * Builds an AWS EMR Serverless client for the given region using the default AWS credentials
 * provider chain, wrapped in {@link EmrServerlessClientImpl}.
 *
 * <p>Construction runs inside {@link AccessController#doPrivileged(PrivilegedAction)} —
 * presumably because plugin code executes under the OpenSearch security manager and the AWS SDK
 * needs elevated permissions (e.g. reading env/system credentials); TODO confirm.
 *
 * @param region AWS region the client targets
 * @return an {@link EMRServerlessClient} ready for job submission
 */
private EMRServerlessClient createEMRServerlessClient(String region) {
return AccessController.doPrivileged(
(PrivilegedAction<EMRServerlessClient>)
() -> {
AWSEMRServerless awsemrServerless =
AWSEMRServerlessClientBuilder.standard()
.withRegion(region)
.withCredentials(new DefaultAWSCredentialsProviderChain())
.build();
return new EmrServerlessClientImpl(awsemrServerless);
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,18 @@
import org.opensearch.action.ActionType;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.inject.Injector;
import org.opensearch.common.inject.ModulesBuilder;
import org.opensearch.core.action.ActionListener;
import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl;
import org.opensearch.sql.datasource.DataSourceService;
import org.opensearch.sql.datasources.service.DataSourceServiceImpl;
import org.opensearch.sql.opensearch.security.SecurityAccess;
import org.opensearch.sql.opensearch.setting.OpenSearchSettings;
import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService;
import org.opensearch.sql.spark.transport.config.AsyncExecutorServiceModule;
import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionRequest;
import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse;
import org.opensearch.tasks.Task;
Expand All @@ -22,17 +31,28 @@ public class TransportCancelAsyncQueryRequestAction
extends HandledTransportAction<CancelAsyncQueryActionRequest, CancelAsyncQueryActionResponse> {

public static final String NAME = "cluster:admin/opensearch/ql/async_query/delete";
private final AsyncQueryExecutorServiceImpl asyncQueryExecutorService;
private final Injector injector;
public static final ActionType<CancelAsyncQueryActionResponse> ACTION_TYPE =
new ActionType<>(NAME, CancelAsyncQueryActionResponse::new);

@Inject
public TransportCancelAsyncQueryRequestAction(
TransportService transportService,
ActionFilters actionFilters,
AsyncQueryExecutorServiceImpl asyncQueryExecutorService) {
NodeClient client,
ClusterService clusterService,
DataSourceServiceImpl dataSourceService) {
super(NAME, transportService, actionFilters, CancelAsyncQueryActionRequest::new);
this.asyncQueryExecutorService = asyncQueryExecutorService;
ModulesBuilder modules = new ModulesBuilder();
modules.add(
b -> {
b.bind(NodeClient.class).toInstance(client);
b.bind(org.opensearch.sql.common.setting.Settings.class)
.toInstance(new OpenSearchSettings(clusterService.getClusterSettings()));
b.bind(DataSourceService.class).toInstance(dataSourceService);
});
modules.add(new AsyncExecutorServiceModule());
this.injector = modules.createInjector();
}

@Override
Expand All @@ -41,6 +61,8 @@ protected void doExecute(
CancelAsyncQueryActionRequest request,
ActionListener<CancelAsyncQueryActionResponse> listener) {
try {
AsyncQueryExecutorService asyncQueryExecutorService =
SecurityAccess.doPrivileged(() -> injector.getInstance(AsyncQueryExecutorService.class));
String jobId = asyncQueryExecutorService.cancelQuery(request.getQueryId());
listener.onResponse(
new CancelAsyncQueryActionResponse(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,35 @@

package org.opensearch.sql.spark.transport;

import java.util.ArrayList;
import java.util.List;
import org.opensearch.action.ActionType;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Guice;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.inject.Injector;
import org.opensearch.common.inject.Module;
import org.opensearch.core.action.ActionListener;
import org.opensearch.sql.datasource.DataSourceService;
import org.opensearch.sql.datasources.service.DataSourceServiceImpl;
import org.opensearch.sql.opensearch.security.SecurityAccess;
import org.opensearch.sql.opensearch.setting.OpenSearchSettings;
import org.opensearch.sql.protocol.response.format.JsonResponseFormatter;
import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService;
import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl;
import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest;
import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse;
import org.opensearch.sql.spark.transport.config.AsyncExecutorServiceModule;
import org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionRequest;
import org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionResponse;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

public class TransportCreateAsyncQueryRequestAction
extends HandledTransportAction<CreateAsyncQueryActionRequest, CreateAsyncQueryActionResponse> {

private final AsyncQueryExecutorService asyncQueryExecutorService;
private final Injector injector;

public static final String NAME = "cluster:admin/opensearch/ql/async_query/create";
public static final ActionType<CreateAsyncQueryActionResponse> ACTION_TYPE =
Expand All @@ -35,9 +45,21 @@ public class TransportCreateAsyncQueryRequestAction
public TransportCreateAsyncQueryRequestAction(
TransportService transportService,
ActionFilters actionFilters,
AsyncQueryExecutorServiceImpl jobManagementService) {
NodeClient client,
ClusterService clusterService,
DataSourceServiceImpl dataSourceService) {
super(NAME, transportService, actionFilters, CreateAsyncQueryActionRequest::new);
this.asyncQueryExecutorService = jobManagementService;
List<Module> moduleList = new ArrayList<>();
moduleList.add(new AsyncExecutorServiceModule());
moduleList.add(
b -> {
b.bind(NodeClient.class).toInstance(client);
b.bind(org.opensearch.sql.common.setting.Settings.class)
.toInstance(new OpenSearchSettings(clusterService.getClusterSettings()));
b.bind(DataSourceService.class).toInstance(dataSourceService);
b.bind(ClusterService.class).toInstance(clusterService);
});
this.injector = Guice.createInjector(moduleList);
}

@Override
Expand All @@ -47,6 +69,8 @@ protected void doExecute(
ActionListener<CreateAsyncQueryActionResponse> listener) {
try {
CreateAsyncQueryRequest createAsyncQueryRequest = request.getCreateAsyncQueryRequest();
AsyncQueryExecutorService asyncQueryExecutorService =
SecurityAccess.doPrivileged(() -> injector.getInstance(AsyncQueryExecutorService.class));
CreateAsyncQueryResponse createAsyncQueryResponse =
asyncQueryExecutorService.createAsyncQuery(createAsyncQueryRequest);
String responseContent =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,23 @@
import org.opensearch.action.ActionType;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.inject.Injector;
import org.opensearch.common.inject.ModulesBuilder;
import org.opensearch.core.action.ActionListener;
import org.opensearch.sql.datasource.DataSourceService;
import org.opensearch.sql.datasources.service.DataSourceServiceImpl;
import org.opensearch.sql.executor.pagination.Cursor;
import org.opensearch.sql.opensearch.security.SecurityAccess;
import org.opensearch.sql.opensearch.setting.OpenSearchSettings;
import org.opensearch.sql.protocol.response.format.JsonResponseFormatter;
import org.opensearch.sql.protocol.response.format.ResponseFormatter;
import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService;
import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryResult;
import org.opensearch.sql.spark.transport.config.AsyncExecutorServiceModule;
import org.opensearch.sql.spark.transport.format.AsyncQueryResultResponseFormatter;
import org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionRequest;
import org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionResponse;
Expand All @@ -29,7 +37,7 @@ public class TransportGetAsyncQueryResultAction
extends HandledTransportAction<
GetAsyncQueryResultActionRequest, GetAsyncQueryResultActionResponse> {

private final AsyncQueryExecutorService asyncQueryExecutorService;
private final Injector injector;

public static final String NAME = "cluster:admin/opensearch/ql/async_query/result";
public static final ActionType<GetAsyncQueryResultActionResponse> ACTION_TYPE =
Expand All @@ -39,9 +47,20 @@ public class TransportGetAsyncQueryResultAction
public TransportGetAsyncQueryResultAction(
TransportService transportService,
ActionFilters actionFilters,
AsyncQueryExecutorServiceImpl jobManagementService) {
NodeClient client,
ClusterService clusterService,
DataSourceServiceImpl dataSourceService) {
super(NAME, transportService, actionFilters, GetAsyncQueryResultActionRequest::new);
this.asyncQueryExecutorService = jobManagementService;
ModulesBuilder modules = new ModulesBuilder();
modules.add(
b -> {
b.bind(NodeClient.class).toInstance(client);
b.bind(org.opensearch.sql.common.setting.Settings.class)
.toInstance(new OpenSearchSettings(clusterService.getClusterSettings()));
b.bind(DataSourceService.class).toInstance(dataSourceService);
});
modules.add(new AsyncExecutorServiceModule());
this.injector = modules.createInjector();
}

@Override
Expand All @@ -51,6 +70,8 @@ protected void doExecute(
ActionListener<GetAsyncQueryResultActionResponse> listener) {
try {
String jobId = request.getQueryId();
AsyncQueryExecutorService asyncQueryExecutorService =
SecurityAccess.doPrivileged(() -> injector.getInstance(AsyncQueryExecutorService.class));
AsyncQueryExecutionResponse asyncQueryExecutionResponse =
asyncQueryExecutorService.getAsyncQueryResults(jobId);
ResponseFormatter<AsyncQueryResult> formatter =
Expand Down
Loading

0 comments on commit 1c6ef2b

Please sign in to comment.