diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java
index f0689a0966..ef39e2271f 100644
--- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java
+++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java
@@ -5,17 +5,11 @@
 package org.opensearch.sql.plugin;
-import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG;
+import static java.util.Collections.singletonList;
 import static org.opensearch.sql.datasource.model.DataSourceMetadata.defaultOpenSearchDataSourceMetadata;
-import static org.opensearch.sql.spark.execution.statestore.StateStore.ALL_DATASOURCE;
-import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
-import com.amazonaws.services.emrserverless.AWSEMRServerless;
-import com.amazonaws.services.emrserverless.AWSEMRServerlessClientBuilder;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
-import java.security.AccessController;
-import java.security.PrivilegedAction;
 import java.time.Clock;
 import java.util.Arrays;
 import java.util.Collection;
@@ -68,7 +62,6 @@
 import org.opensearch.sql.datasources.transport.*;
 import org.opensearch.sql.legacy.esdomain.LocalClusterState;
 import org.opensearch.sql.legacy.executor.AsyncRestExecutor;
-import org.opensearch.sql.legacy.metrics.GaugeMetric;
 import org.opensearch.sql.legacy.metrics.Metrics;
 import org.opensearch.sql.legacy.plugin.RestSqlAction;
 import org.opensearch.sql.legacy.plugin.RestSqlStatsAction;
@@ -86,30 +79,11 @@
 import org.opensearch.sql.plugin.transport.TransportPPLQueryAction;
 import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse;
 import org.opensearch.sql.prometheus.storage.PrometheusStorageFactory;
-import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService;
-import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl;
-import org.opensearch.sql.spark.asyncquery.AsyncQueryJobMetadataStorageService;
-import org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryJobMetadataStorageService;
-import org.opensearch.sql.spark.client.EMRServerlessClient;
-import org.opensearch.sql.spark.client.EmrServerlessClientImpl;
 import org.opensearch.sql.spark.cluster.ClusterManagerEventListener;
-import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
-import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;
-import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplierImpl;
-import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher;
-import org.opensearch.sql.spark.execution.session.SessionManager;
-import org.opensearch.sql.spark.execution.statestore.StateStore;
-import org.opensearch.sql.spark.flint.FlintIndexMetadataReaderImpl;
-import org.opensearch.sql.spark.leasemanager.DefaultLeaseManager;
-import org.opensearch.sql.spark.response.JobExecutionResponseReader;
 import org.opensearch.sql.spark.rest.RestAsyncQueryManagementAction;
 import org.opensearch.sql.spark.storage.SparkStorageFactory;
-import org.opensearch.sql.spark.transport.TransportCancelAsyncQueryRequestAction;
 import org.opensearch.sql.spark.transport.TransportCreateAsyncQueryRequestAction;
-import org.opensearch.sql.spark.transport.TransportGetAsyncQueryResultAction;
-import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse;
 import org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionResponse;
-import org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionResponse;
 import org.opensearch.sql.storage.DataSourceFactory;
 import org.opensearch.threadpool.ExecutorBuilder;
 import org.opensearch.threadpool.FixedExecutorBuilder;
@@ -127,9 +101,8 @@ public class SQLPlugin extends Plugin implements ActionPlugin, ScriptPlugin {
   private NodeClient client;
   private DataSourceServiceImpl dataSourceService;
-  private AsyncQueryExecutorService asyncQueryExecutorService;
   private Injector injector;
-
+
   public String name() {
     return "sql";
   }
@@ -192,15 +165,7 @@ public List getRestHandlers(
         new ActionHandler<>(
             new ActionType<>(
                 TransportCreateAsyncQueryRequestAction.NAME, CreateAsyncQueryActionResponse::new),
-            TransportCreateAsyncQueryRequestAction.class),
-        new ActionHandler<>(
-            new ActionType<>(
-                TransportGetAsyncQueryResultAction.NAME, GetAsyncQueryResultActionResponse::new),
-            TransportGetAsyncQueryResultAction.class),
-        new ActionHandler<>(
-            new ActionType<>(
-                TransportCancelAsyncQueryRequestAction.NAME, CancelAsyncQueryActionResponse::new),
-            TransportCancelAsyncQueryRequestAction.class));
+            TransportCreateAsyncQueryRequestAction.class));
   }
   @Override
@@ -223,23 +188,6 @@ public Collection createComponents(
     dataSourceService.createDataSource(defaultOpenSearchDataSourceMetadata());
     LocalClusterState.state().setClusterService(clusterService);
     LocalClusterState.state().setPluginSettings((OpenSearchSettings) pluginSettings);
-    SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier =
-        new SparkExecutionEngineConfigSupplierImpl(pluginSettings);
-    SparkExecutionEngineConfig sparkExecutionEngineConfig =
-        sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig();
-    if (StringUtils.isEmpty(sparkExecutionEngineConfig.getRegion())) {
-      LOGGER.warn(
-          String.format(
-              "Async Query APIs are disabled as %s is not configured properly in cluster settings. "
" - + "Please configure and restart the domain to enable Async Query APIs", - SPARK_EXECUTION_ENGINE_CONFIG.getKeyValue())); - this.asyncQueryExecutorService = new AsyncQueryExecutorServiceImpl(); - } else { - this.asyncQueryExecutorService = - createAsyncQueryExecutorService( - sparkExecutionEngineConfigSupplier, sparkExecutionEngineConfig); - } - ModulesBuilder modules = new ModulesBuilder(); modules.add(new OpenSearchPluginModule()); modules.add( @@ -260,13 +208,12 @@ public Collection createComponents( OpenSearchSettings.RESULT_INDEX_TTL_SETTING, OpenSearchSettings.AUTO_INDEX_MANAGEMENT_ENABLED_SETTING, environment.settings()); - return ImmutableList.of( - dataSourceService, asyncQueryExecutorService, clusterManagerEventListener, pluginSettings); + return ImmutableList.of(dataSourceService, clusterManagerEventListener, pluginSettings); } @Override public List> getExecutorBuilders(Settings settings) { - return Collections.singletonList( + return singletonList( new FixedExecutorBuilder( settings, AsyncRestExecutor.SQL_WORKER_THREAD_POOL_NAME, @@ -318,57 +265,4 @@ private DataSourceServiceImpl createDataSourceService() { dataSourceMetadataStorage, dataSourceUserAuthorizationHelper); } - - private AsyncQueryExecutorService createAsyncQueryExecutorService( - SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier, - SparkExecutionEngineConfig sparkExecutionEngineConfig) { - StateStore stateStore = new StateStore(client, clusterService); - registerStateStoreMetrics(stateStore); - AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService = - new OpensearchAsyncQueryJobMetadataStorageService(stateStore); - EMRServerlessClient emrServerlessClient = - createEMRServerlessClient(sparkExecutionEngineConfig.getRegion()); - JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client); - SparkQueryDispatcher sparkQueryDispatcher = - new SparkQueryDispatcher( - emrServerlessClient, - this.dataSourceService, - new DataSourceUserAuthorizationHelperImpl(client), - jobExecutionResponseReader, - new FlintIndexMetadataReaderImpl(client), - client, - new SessionManager(stateStore, emrServerlessClient, pluginSettings), - new DefaultLeaseManager(pluginSettings, stateStore), - stateStore); - return new AsyncQueryExecutorServiceImpl( - asyncQueryJobMetadataStorageService, - sparkQueryDispatcher, - sparkExecutionEngineConfigSupplier); - } - - private void registerStateStoreMetrics(StateStore stateStore) { - GaugeMetric activeSessionMetric = - new GaugeMetric<>( - "active_async_query_sessions_count", - StateStore.activeSessionsCount(stateStore, ALL_DATASOURCE)); - GaugeMetric activeStatementMetric = - new GaugeMetric<>( - "active_async_query_statements_count", - StateStore.activeStatementsCount(stateStore, ALL_DATASOURCE)); - Metrics.getInstance().registerMetric(activeSessionMetric); - Metrics.getInstance().registerMetric(activeStatementMetric); - } - - private EMRServerlessClient createEMRServerlessClient(String region) { - return AccessController.doPrivileged( - (PrivilegedAction) - () -> { - AWSEMRServerless awsemrServerless = - AWSEMRServerlessClientBuilder.standard() - .withRegion(region) - .withCredentials(new DefaultAWSCredentialsProviderChain()) - .build(); - return new EmrServerlessClientImpl(awsemrServerless); - }); - } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 0aa183335e..e3062f6f12 100644 --- 
+++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java
@@ -18,6 +18,7 @@
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId;
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
 import org.opensearch.sql.spark.client.EMRServerlessClient;
+import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
 import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext;
 import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
 import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse;
@@ -59,6 +60,19 @@ public class SparkQueryDispatcher {
   private StateStore stateStore;
+  private SparkQueryDispatcher(SparkQueryDispatcherBuilder builder) {
+    this.emrServerlessClient = builder.emrServerlessClient;
+    this.dataSourceService = builder.dataSourceService;
+    this.dataSourceUserAuthorizationHelper = builder.dataSourceUserAuthorizationHelper;
+    this.jobExecutionResponseReader = builder.jobExecutionResponseReader;
+    this.flintIndexMetadataReader = builder.flintIndexMetadataReader;
+    this.client = builder.client;
+    this.sessionManager = builder.sessionManager;
+    this.leaseManager = builder.leaseManager;
+    this.stateStore = builder.stateStore;
+  }
+
+
   public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) {
     DataSourceMetadata dataSourceMetadata =
         this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource());
@@ -98,6 +112,10 @@ public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest)
     return asyncQueryHandler.submit(dispatchQueryRequest, contextBuilder.build());
   }
+  private void createEmrServerlessClient(SparkExecutionEngineConfig sparkExecutionEngineConfig) {
+
+  }
+
   public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) {
     if (asyncQueryJobMetadata.getSessionId() != null) {
       return new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager)
@@ -156,4 +174,75 @@ private static Map getDefaultTagsForJobSubmission(
     tags.put(DATASOURCE_TAG_KEY, dispatchQueryRequest.getDatasource());
     return tags;
   }
+
+
+  public static class SparkQueryDispatcherBuilder {
+
+    private EMRServerlessClient emrServerlessClient; // Optional
+    private DataSourceService dataSourceService;
+    private DataSourceUserAuthorizationHelperImpl dataSourceUserAuthorizationHelper;
+    private JobExecutionResponseReader jobExecutionResponseReader;
+    private FlintIndexMetadataReader flintIndexMetadataReader;
+    private Client client;
+    private SessionManager sessionManager;
+    private LeaseManager leaseManager;
+    private StateStore stateStore;
+
+    public SparkQueryDispatcherBuilder() {
+    }
+
+    public SparkQueryDispatcherBuilder setEMRServerlessClient(
+        EMRServerlessClient emrServerlessClient) {
+      this.emrServerlessClient = emrServerlessClient;
+      return this;
+    }
+
+    public SparkQueryDispatcherBuilder setDataSourceService(DataSourceService dataSourceService) {
+      this.dataSourceService = dataSourceService;
+      return this;
+    }
+
+    public SparkQueryDispatcherBuilder setDataSourceUserAuthorizationHelper(
+        DataSourceUserAuthorizationHelperImpl dataSourceUserAuthorizationHelper) {
+      this.dataSourceUserAuthorizationHelper = dataSourceUserAuthorizationHelper;
+      return this;
+    }
+
+    public SparkQueryDispatcherBuilder setJobExecutionResponseReader(
+        JobExecutionResponseReader jobExecutionResponseReader) {
+      this.jobExecutionResponseReader = jobExecutionResponseReader;
+      return this;
this; + } + + public SparkQueryDispatcherBuilder setFlintIndexMetadataReader( + FlintIndexMetadataReader flintIndexMetadataReader) { + this.flintIndexMetadataReader = flintIndexMetadataReader; + return this; + } + + public SparkQueryDispatcherBuilder setClient(Client client) { + this.client = client; + return this; + } + + public SparkQueryDispatcherBuilder setSessionManager(SessionManager sessionManager) { + this.sessionManager = sessionManager; + return this; + } + + public SparkQueryDispatcherBuilder setLeaseManager(LeaseManager leaseManager) { + this.leaseManager = leaseManager; + return this; + } + + public SparkQueryDispatcherBuilder setStateStore(StateStore stateStore) { + this.stateStore = stateStore; + return this; + } + + public SparkQueryDispatcher build() { + return new SparkQueryDispatcher(this); + } + } + } diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java index 232a280db5..de67689b60 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java @@ -7,12 +7,23 @@ package org.opensearch.sql.spark.transport; +import java.util.ArrayList; +import java.util.List; import org.opensearch.action.ActionType; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Guice; import org.opensearch.common.inject.Inject; +import org.opensearch.common.inject.Injector; +import org.opensearch.common.inject.Module; import org.opensearch.core.action.ActionListener; -import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; +import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasources.service.DataSourceServiceImpl; +import org.opensearch.sql.opensearch.setting.OpenSearchSettings; +import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService; +import org.opensearch.sql.spark.transport.config.AsyncExecutorServiceModule; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionRequest; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse; import org.opensearch.tasks.Task; @@ -22,7 +33,7 @@ public class TransportCancelAsyncQueryRequestAction extends HandledTransportAction { public static final String NAME = "cluster:admin/opensearch/ql/async_query/delete"; - private final AsyncQueryExecutorServiceImpl asyncQueryExecutorService; + private final Injector injector; public static final ActionType ACTION_TYPE = new ActionType<>(NAME, CancelAsyncQueryActionResponse::new); @@ -30,9 +41,20 @@ public class TransportCancelAsyncQueryRequestAction public TransportCancelAsyncQueryRequestAction( TransportService transportService, ActionFilters actionFilters, - AsyncQueryExecutorServiceImpl asyncQueryExecutorService) { + NodeClient client, + ClusterService clusterService, + DataSourceServiceImpl dataSourceService) { super(NAME, transportService, actionFilters, CancelAsyncQueryActionRequest::new); - this.asyncQueryExecutorService = asyncQueryExecutorService; + List modules = new ArrayList<>(); + modules.add( + b -> { + b.bind(NodeClient.class).toInstance(client); + 
+          b.bind(org.opensearch.sql.common.setting.Settings.class)
+              .toInstance(new OpenSearchSettings(clusterService.getClusterSettings()));
+          b.bind(DataSourceService.class).toInstance(dataSourceService);
+        });
+    modules.add(new AsyncExecutorServiceModule());
+    this.injector = Guice.createInjector(modules);
   }
   @Override
@@ -41,6 +63,8 @@ protected void doExecute(
       CancelAsyncQueryActionRequest request,
       ActionListener listener) {
     try {
+      AsyncQueryExecutorService asyncQueryExecutorService =
+          injector.getInstance(AsyncQueryExecutorService.class);
       String jobId = asyncQueryExecutorService.cancelQuery(request.getQueryId());
       listener.onResponse(
           new CancelAsyncQueryActionResponse(
diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestAction.java
index 991eafdad9..a2872f297d 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestAction.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestAction.java
@@ -7,16 +7,26 @@
 package org.opensearch.sql.spark.transport;
+import java.util.ArrayList;
+import java.util.List;
 import org.opensearch.action.ActionType;
 import org.opensearch.action.support.ActionFilters;
 import org.opensearch.action.support.HandledTransportAction;
+import org.opensearch.client.node.NodeClient;
+import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.common.inject.Guice;
 import org.opensearch.common.inject.Inject;
+import org.opensearch.common.inject.Injector;
+import org.opensearch.common.inject.Module;
 import org.opensearch.core.action.ActionListener;
+import org.opensearch.sql.datasource.DataSourceService;
+import org.opensearch.sql.datasources.service.DataSourceServiceImpl;
+import org.opensearch.sql.opensearch.setting.OpenSearchSettings;
 import org.opensearch.sql.protocol.response.format.JsonResponseFormatter;
 import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService;
-import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl;
 import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest;
 import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse;
+import org.opensearch.sql.spark.transport.config.AsyncExecutorServiceModule;
 import org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionRequest;
 import org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionResponse;
 import org.opensearch.tasks.Task;
@@ -24,8 +34,7 @@ public class TransportCreateAsyncQueryRequestAction
     extends HandledTransportAction {
-
-  private final AsyncQueryExecutorService asyncQueryExecutorService;
+  private final Injector injector;
   public static final String NAME = "cluster:admin/opensearch/ql/async_query/create";
   public static final ActionType ACTION_TYPE =
@@ -35,9 +44,21 @@ public class TransportCreateAsyncQueryRequestAction
   public TransportCreateAsyncQueryRequestAction(
       TransportService transportService,
      ActionFilters actionFilters,
-      AsyncQueryExecutorServiceImpl jobManagementService) {
+      NodeClient client,
+      ClusterService clusterService,
+      DataSourceServiceImpl dataSourceService) {
     super(NAME, transportService, actionFilters, CreateAsyncQueryActionRequest::new);
-    this.asyncQueryExecutorService = jobManagementService;
+    List<Module> moduleList = new ArrayList<>();
+    moduleList.add(new AsyncExecutorServiceModule());
+    moduleList.add(
+        b -> {
+          b.bind(NodeClient.class).toInstance(client);
+          b.bind(org.opensearch.sql.common.setting.Settings.class)
+              .toInstance(new OpenSearchSettings(clusterService.getClusterSettings()));
+          b.bind(DataSourceService.class).toInstance(dataSourceService);
+          b.bind(ClusterService.class).toInstance(clusterService);
+        });
+    this.injector = Guice.createInjector(moduleList);
   }
   @Override
@@ -47,6 +68,8 @@ protected void doExecute(
       ActionListener listener) {
     try {
       CreateAsyncQueryRequest createAsyncQueryRequest = request.getCreateAsyncQueryRequest();
+      AsyncQueryExecutorService asyncQueryExecutorService =
+          injector.getInstance(AsyncQueryExecutorService.class);
       CreateAsyncQueryResponse createAsyncQueryResponse =
           asyncQueryExecutorService.createAsyncQuery(createAsyncQueryRequest);
       String responseContent =
diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java
index 5c784cf04c..ccb136c924 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java
@@ -7,18 +7,28 @@
 package org.opensearch.sql.spark.transport;
+import java.util.ArrayList;
+import java.util.List;
 import org.opensearch.action.ActionType;
 import org.opensearch.action.support.ActionFilters;
 import org.opensearch.action.support.HandledTransportAction;
+import org.opensearch.client.node.NodeClient;
+import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.common.inject.Guice;
 import org.opensearch.common.inject.Inject;
+import org.opensearch.common.inject.Injector;
+import org.opensearch.common.inject.Module;
 import org.opensearch.core.action.ActionListener;
+import org.opensearch.sql.datasource.DataSourceService;
+import org.opensearch.sql.datasources.service.DataSourceServiceImpl;
 import org.opensearch.sql.executor.pagination.Cursor;
+import org.opensearch.sql.opensearch.setting.OpenSearchSettings;
 import org.opensearch.sql.protocol.response.format.JsonResponseFormatter;
 import org.opensearch.sql.protocol.response.format.ResponseFormatter;
 import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService;
-import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl;
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse;
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryResult;
+import org.opensearch.sql.spark.transport.config.AsyncExecutorServiceModule;
 import org.opensearch.sql.spark.transport.format.AsyncQueryResultResponseFormatter;
 import org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionRequest;
 import org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionResponse;
@@ -29,7 +39,7 @@ public class TransportGetAsyncQueryResultAction
     extends HandledTransportAction<
         GetAsyncQueryResultActionRequest, GetAsyncQueryResultActionResponse> {
-  private final AsyncQueryExecutorService asyncQueryExecutorService;
+  private final Injector injector;
   public static final String NAME = "cluster:admin/opensearch/ql/async_query/result";
   public static final ActionType ACTION_TYPE =
@@ -39,9 +49,20 @@ public class TransportGetAsyncQueryResultAction
   public TransportGetAsyncQueryResultAction(
       TransportService transportService,
       ActionFilters actionFilters,
-      AsyncQueryExecutorServiceImpl jobManagementService) {
+      NodeClient client,
+      ClusterService clusterService,
+      DataSourceServiceImpl dataSourceService) {
     super(NAME, transportService, actionFilters, GetAsyncQueryResultActionRequest::new);
-    this.asyncQueryExecutorService = jobManagementService;
+    List<Module> modules = new ArrayList<>();
+    modules.add(
+        b -> {
+          b.bind(NodeClient.class).toInstance(client);
+          b.bind(org.opensearch.sql.common.setting.Settings.class)
+              .toInstance(new OpenSearchSettings(clusterService.getClusterSettings()));
+          b.bind(DataSourceService.class).toInstance(dataSourceService);
+        });
+    modules.add(new AsyncExecutorServiceModule());
+    this.injector = Guice.createInjector(modules);
   }
   @Override
@@ -51,6 +72,8 @@ protected void doExecute(
       ActionListener listener) {
     try {
       String jobId = request.getQueryId();
+      AsyncQueryExecutorService asyncQueryExecutorService =
+          injector.getInstance(AsyncQueryExecutorService.class);
       AsyncQueryExecutionResponse asyncQueryExecutionResponse =
           asyncQueryExecutorService.getAsyncQueryResults(jobId);
       ResponseFormatter formatter =
diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java b/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java
new file mode 100644
index 0000000000..9a8bec43f9
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java
@@ -0,0 +1,163 @@
+package org.opensearch.sql.spark.transport.config;
+
+import static org.opensearch.sql.spark.execution.statestore.StateStore.ALL_DATASOURCE;
+
+import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
+import com.amazonaws.services.emrserverless.AWSEMRServerless;
+import com.amazonaws.services.emrserverless.AWSEMRServerlessClientBuilder;
+import java.security.AccessController;
+import java.security.PrivilegedAction;
+import lombok.RequiredArgsConstructor;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.opensearch.client.node.NodeClient;
+import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.common.inject.AbstractModule;
+import org.opensearch.common.inject.Provides;
+import org.opensearch.common.inject.Singleton;
+import org.opensearch.sql.common.setting.Settings;
+import org.opensearch.sql.datasource.DataSourceService;
+import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl;
+import org.opensearch.sql.legacy.metrics.GaugeMetric;
+import org.opensearch.sql.legacy.metrics.Metrics;
+import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService;
+import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl;
+import org.opensearch.sql.spark.asyncquery.AsyncQueryJobMetadataStorageService;
+import org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryJobMetadataStorageService;
+import org.opensearch.sql.spark.client.EMRServerlessClient;
+import org.opensearch.sql.spark.client.EmrServerlessClientImpl;
+import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
+import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;
+import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplierImpl;
+import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher;
+import org.opensearch.sql.spark.execution.session.SessionManager;
+import org.opensearch.sql.spark.execution.statestore.StateStore;
+import org.opensearch.sql.spark.flint.FlintIndexMetadataReaderImpl;
+import org.opensearch.sql.spark.leasemanager.DefaultLeaseManager;
+import org.opensearch.sql.spark.response.JobExecutionResponseReader;
+
+@RequiredArgsConstructor
+public class AsyncExecutorServiceModule extends AbstractModule {
+
+  private static final Logger LOG = LogManager.getLogger(AsyncExecutorServiceModule.class);
+
+  @Override
+  protected void configure() {}
+
+  @Provides
+  public AsyncQueryExecutorService asyncQueryExecutorService(
+      AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService,
+      SparkQueryDispatcher sparkQueryDispatcher,
+      SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier) {
+    return new AsyncQueryExecutorServiceImpl(
+        asyncQueryJobMetadataStorageService,
+        sparkQueryDispatcher,
+        sparkExecutionEngineConfigSupplier);
+  }
+
+  @Provides
+  public AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService(
+      StateStore stateStore) {
+    return new OpensearchAsyncQueryJobMetadataStorageService(stateStore);
+  }
+
+  @Provides
+  @Singleton
+  public StateStore stateStore(NodeClient client, ClusterService clusterService) {
+    StateStore stateStore = new StateStore(client, clusterService);
+    registerStateStoreMetrics(stateStore);
+    return stateStore;
+  }
+
+  @Provides
+  public SparkQueryDispatcher sparkQueryDispatcher(
+      EMRServerlessClient emrServerlessClient,
+      DataSourceService dataSourceService,
+      DataSourceUserAuthorizationHelperImpl dataSourceUserAuthorizationHelper,
+      JobExecutionResponseReader jobExecutionResponseReader,
+      FlintIndexMetadataReaderImpl flintIndexMetadataReader,
+      NodeClient client,
+      SessionManager sessionManager,
+      DefaultLeaseManager defaultLeaseManager,
+      StateStore stateStore,
+      ClusterService clusterService) {
+    return new SparkQueryDispatcher(
+        emrServerlessClient,
+        dataSourceService,
+        dataSourceUserAuthorizationHelper,
+        jobExecutionResponseReader,
+        flintIndexMetadataReader,
+        client,
+        sessionManager,
+        defaultLeaseManager,
+        stateStore);
+  }
+
+  @Provides
+  public SessionManager sessionManager(
+      StateStore stateStore, EMRServerlessClient emrServerlessClient, Settings settings) {
+    return new SessionManager(stateStore, emrServerlessClient, settings);
+  }
+
+  @Provides
+  public DefaultLeaseManager defaultLeaseManager(Settings settings, StateStore stateStore) {
+    return new DefaultLeaseManager(settings, stateStore);
+  }
+
+  @Provides
+  public EMRServerlessClient createEMRServerlessClient(
+      SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier) {
+    SparkExecutionEngineConfig sparkExecutionEngineConfig =
+        sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig();
+    if (sparkExecutionEngineConfig.getRegion() != null) {
+      return AccessController.doPrivileged(
+          (PrivilegedAction<EMRServerlessClient>)
+              () -> {
+                AWSEMRServerless awsemrServerless =
+                    AWSEMRServerlessClientBuilder.standard()
+                        .withRegion(sparkExecutionEngineConfig.getRegion())
+                        .withCredentials(new DefaultAWSCredentialsProviderChain())
+                        .build();
+                return new EmrServerlessClientImpl(awsemrServerless);
+              });
+    } else {
+      return null;
+    }
+  }
+
+  @Provides
+  public SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier(Settings settings) {
+    return new SparkExecutionEngineConfigSupplierImpl(settings);
+  }
+
+  @Provides
+  @Singleton
+  public FlintIndexMetadataReaderImpl flintIndexMetadataReader(NodeClient client) {
+    return new FlintIndexMetadataReaderImpl(client);
+  }
+
+  @Provides
+  public JobExecutionResponseReader jobExecutionResponseReader(NodeClient client) {
+    return new JobExecutionResponseReader(client);
+  }
+
+  @Provides
+  public DataSourceUserAuthorizationHelperImpl dataSourceUserAuthorizationHelper(
+      NodeClient client) {
+    return new DataSourceUserAuthorizationHelperImpl(client);
+  }
+
+  private void registerStateStoreMetrics(StateStore stateStore) {
+    GaugeMetric<Long> activeSessionMetric =
+        new GaugeMetric<>(
+            "active_async_query_sessions_count",
+            StateStore.activeSessionsCount(stateStore, ALL_DATASOURCE));
+    GaugeMetric<Long> activeStatementMetric =
+        new GaugeMetric<>(
+            "active_async_query_statements_count",
+            StateStore.activeStatementsCount(stateStore, ALL_DATASOURCE));
+    Metrics.getInstance().registerMetric(activeSessionMetric);
+    Metrics.getInstance().registerMetric(activeStatementMetric);
+  }
+}
diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java
index 2ff76b9b57..50c3206b1f 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java
@@ -22,7 +22,10 @@
 import org.mockito.Mockito;
 import org.mockito.junit.jupiter.MockitoExtension;
 import org.opensearch.action.support.ActionFilters;
+import org.opensearch.client.node.NodeClient;
+import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.core.action.ActionListener;
+import org.opensearch.sql.datasources.service.DataSourceServiceImpl;
 import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl;
 import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionRequest;
 import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse;
@@ -43,12 +46,19 @@ public class TransportCancelAsyncQueryRequestActionTest {
   private ArgumentCaptor deleteJobActionResponseArgumentCaptor;
   @Captor private ArgumentCaptor exceptionArgumentCaptor;
+  @Mock private NodeClient client;
+  @Mock private ClusterService clusterService;
+  @Mock private DataSourceServiceImpl dataSourceService;
   @BeforeEach
   public void setUp() {
     action =
         new TransportCancelAsyncQueryRequestAction(
-            transportService, new ActionFilters(new HashSet<>()), asyncQueryExecutorService);
+            transportService,
+            new ActionFilters(new HashSet<>()),
+            client,
+            clusterService,
+            dataSourceService);
   }
   @Test
diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java
index 36060d3850..0fad5d8bde 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java
@@ -24,7 +24,10 @@
 import org.mockito.Mockito;
 import org.mockito.junit.jupiter.MockitoExtension;
 import org.opensearch.action.support.ActionFilters;
+import org.opensearch.client.node.NodeClient;
+import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.core.action.ActionListener;
+import org.opensearch.sql.datasources.service.DataSourceServiceImpl;
 import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl;
 import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest;
 import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse;
@@ -42,6 +45,9 @@ public class TransportCreateAsyncQueryRequestActionTest {
   @Mock private AsyncQueryExecutorServiceImpl jobExecutorService;
   @Mock private Task task;
   @Mock private ActionListener actionListener;
+  @Mock private NodeClient client;
+  @Mock private ClusterService clusterService;
+  @Mock private DataSourceServiceImpl dataSourceService;
   @Captor
   private ArgumentCaptor createJobActionResponseArgumentCaptor;
@@ -52,7 +58,11 @@ public class TransportCreateAsyncQueryRequestActionTest {
   public void setUp() {
     action =
         new TransportCreateAsyncQueryRequestAction(
-            transportService, new ActionFilters(new HashSet<>()), jobExecutorService);
+            transportService,
+            new ActionFilters(new HashSet<>()),
+            client,
+            clusterService,
+            dataSourceService);
   }
   @Test
diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java
index 34f10b0083..ce9859d326 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java
@@ -28,7 +28,10 @@
 import org.mockito.Mock;
 import org.mockito.junit.jupiter.MockitoExtension;
 import org.opensearch.action.support.ActionFilters;
+import org.opensearch.client.node.NodeClient;
+import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.core.action.ActionListener;
+import org.opensearch.sql.datasources.service.DataSourceServiceImpl;
 import org.opensearch.sql.executor.ExecutionEngine;
 import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl;
 import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException;
@@ -51,12 +54,19 @@ public class TransportGetAsyncQueryResultActionTest {
   private ArgumentCaptor createJobActionResponseArgumentCaptor;
   @Captor private ArgumentCaptor exceptionArgumentCaptor;
+  @Mock private NodeClient client;
+  @Mock private ClusterService clusterService;
+  @Mock private DataSourceServiceImpl dataSourceService;
   @BeforeEach
   public void setUp() {
     action =
         new TransportGetAsyncQueryResultAction(
-            transportService, new ActionFilters(new HashSet<>()), jobExecutorService);
+            transportService,
+            new ActionFilters(new HashSet<>()),
+            client,
+            clusterService,
+            dataSourceService);
   }
   @Test
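Reviewer note (not part of the patch): a minimal usage sketch of the SparkQueryDispatcherBuilder introduced above, assuming it is exposed as a public static nested class. The helper method and its parameters are hypothetical placeholders for the components that AsyncExecutorServiceModule already provides via Guice; only the setter names and the constructors shown elsewhere in this diff are taken from the change itself.

// Hypothetical sketch: how the new builder could replace the 9-argument
// SparkQueryDispatcher constructor call in AsyncExecutorServiceModule.
// All parameters below are placeholders, not values defined by this diff.
private SparkQueryDispatcher buildDispatcher(
    EMRServerlessClient emrServerlessClient,
    DataSourceService dataSourceService,
    NodeClient client,
    StateStore stateStore,
    Settings settings) {
  return new SparkQueryDispatcher.SparkQueryDispatcherBuilder()
      // optional dependency; may be absent when the spark execution engine settings are not configured
      .setEMRServerlessClient(emrServerlessClient)
      .setDataSourceService(dataSourceService)
      .setDataSourceUserAuthorizationHelper(new DataSourceUserAuthorizationHelperImpl(client))
      .setJobExecutionResponseReader(new JobExecutionResponseReader(client))
      .setFlintIndexMetadataReader(new FlintIndexMetadataReaderImpl(client))
      .setClient(client)
      .setSessionManager(new SessionManager(stateStore, emrServerlessClient, settings))
      .setLeaseManager(new DefaultLeaseManager(settings, stateStore))
      .setStateStore(stateStore)
      .build();
}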