diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index f0689a0966..b6317c7ae4 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -5,17 +5,10 @@ package org.opensearch.sql.plugin; -import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG; import static org.opensearch.sql.datasource.model.DataSourceMetadata.defaultOpenSearchDataSourceMetadata; -import static org.opensearch.sql.spark.execution.statestore.StateStore.ALL_DATASOURCE; -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; -import com.amazonaws.services.emrserverless.AWSEMRServerless; -import com.amazonaws.services.emrserverless.AWSEMRServerlessClientBuilder; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import java.security.AccessController; -import java.security.PrivilegedAction; import java.time.Clock; import java.util.Arrays; import java.util.Collection; @@ -68,7 +61,6 @@ import org.opensearch.sql.datasources.transport.*; import org.opensearch.sql.legacy.esdomain.LocalClusterState; import org.opensearch.sql.legacy.executor.AsyncRestExecutor; -import org.opensearch.sql.legacy.metrics.GaugeMetric; import org.opensearch.sql.legacy.metrics.Metrics; import org.opensearch.sql.legacy.plugin.RestSqlAction; import org.opensearch.sql.legacy.plugin.RestSqlStatsAction; @@ -86,22 +78,7 @@ import org.opensearch.sql.plugin.transport.TransportPPLQueryAction; import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse; import org.opensearch.sql.prometheus.storage.PrometheusStorageFactory; -import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService; -import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; -import org.opensearch.sql.spark.asyncquery.AsyncQueryJobMetadataStorageService; -import org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryJobMetadataStorageService; -import org.opensearch.sql.spark.client.EMRServerlessClient; -import org.opensearch.sql.spark.client.EmrServerlessClientImpl; import org.opensearch.sql.spark.cluster.ClusterManagerEventListener; -import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; -import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; -import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplierImpl; -import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; -import org.opensearch.sql.spark.execution.session.SessionManager; -import org.opensearch.sql.spark.execution.statestore.StateStore; -import org.opensearch.sql.spark.flint.FlintIndexMetadataReaderImpl; -import org.opensearch.sql.spark.leasemanager.DefaultLeaseManager; -import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.RestAsyncQueryManagementAction; import org.opensearch.sql.spark.storage.SparkStorageFactory; import org.opensearch.sql.spark.transport.TransportCancelAsyncQueryRequestAction; @@ -127,9 +104,7 @@ public class SQLPlugin extends Plugin implements ActionPlugin, ScriptPlugin { private NodeClient client; private DataSourceServiceImpl dataSourceService; - private AsyncQueryExecutorService asyncQueryExecutorService; private Injector injector; - public String name() { return "sql"; } @@ -192,15 +167,7 @@ public List getRestHandlers( new ActionHandler<>( new ActionType<>( TransportCreateAsyncQueryRequestAction.NAME, CreateAsyncQueryActionResponse::new), - TransportCreateAsyncQueryRequestAction.class), - new ActionHandler<>( - new ActionType<>( - TransportGetAsyncQueryResultAction.NAME, GetAsyncQueryResultActionResponse::new), - TransportGetAsyncQueryResultAction.class), - new ActionHandler<>( - new ActionType<>( - TransportCancelAsyncQueryRequestAction.NAME, CancelAsyncQueryActionResponse::new), - TransportCancelAsyncQueryRequestAction.class)); + TransportCreateAsyncQueryRequestAction.class)); } @Override @@ -223,23 +190,6 @@ public Collection createComponents( dataSourceService.createDataSource(defaultOpenSearchDataSourceMetadata()); LocalClusterState.state().setClusterService(clusterService); LocalClusterState.state().setPluginSettings((OpenSearchSettings) pluginSettings); - SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier = - new SparkExecutionEngineConfigSupplierImpl(pluginSettings); - SparkExecutionEngineConfig sparkExecutionEngineConfig = - sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(); - if (StringUtils.isEmpty(sparkExecutionEngineConfig.getRegion())) { - LOGGER.warn( - String.format( - "Async Query APIs are disabled as %s is not configured properly in cluster settings. " - + "Please configure and restart the domain to enable Async Query APIs", - SPARK_EXECUTION_ENGINE_CONFIG.getKeyValue())); - this.asyncQueryExecutorService = new AsyncQueryExecutorServiceImpl(); - } else { - this.asyncQueryExecutorService = - createAsyncQueryExecutorService( - sparkExecutionEngineConfigSupplier, sparkExecutionEngineConfig); - } - ModulesBuilder modules = new ModulesBuilder(); modules.add(new OpenSearchPluginModule()); modules.add( @@ -261,7 +211,7 @@ public Collection createComponents( OpenSearchSettings.AUTO_INDEX_MANAGEMENT_ENABLED_SETTING, environment.settings()); return ImmutableList.of( - dataSourceService, asyncQueryExecutorService, clusterManagerEventListener, pluginSettings); + dataSourceService, clusterManagerEventListener, pluginSettings); } @Override @@ -319,56 +269,4 @@ private DataSourceServiceImpl createDataSourceService() { dataSourceUserAuthorizationHelper); } - private AsyncQueryExecutorService createAsyncQueryExecutorService( - SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier, - SparkExecutionEngineConfig sparkExecutionEngineConfig) { - StateStore stateStore = new StateStore(client, clusterService); - registerStateStoreMetrics(stateStore); - AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService = - new OpensearchAsyncQueryJobMetadataStorageService(stateStore); - EMRServerlessClient emrServerlessClient = - createEMRServerlessClient(sparkExecutionEngineConfig.getRegion()); - JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client); - SparkQueryDispatcher sparkQueryDispatcher = - new SparkQueryDispatcher( - emrServerlessClient, - this.dataSourceService, - new DataSourceUserAuthorizationHelperImpl(client), - jobExecutionResponseReader, - new FlintIndexMetadataReaderImpl(client), - client, - new SessionManager(stateStore, emrServerlessClient, pluginSettings), - new DefaultLeaseManager(pluginSettings, stateStore), - stateStore); - return new AsyncQueryExecutorServiceImpl( - asyncQueryJobMetadataStorageService, - sparkQueryDispatcher, - sparkExecutionEngineConfigSupplier); - } - - private void registerStateStoreMetrics(StateStore stateStore) { - GaugeMetric activeSessionMetric = - new GaugeMetric<>( - "active_async_query_sessions_count", - StateStore.activeSessionsCount(stateStore, ALL_DATASOURCE)); - GaugeMetric activeStatementMetric = - new GaugeMetric<>( - "active_async_query_statements_count", - StateStore.activeStatementsCount(stateStore, ALL_DATASOURCE)); - Metrics.getInstance().registerMetric(activeSessionMetric); - Metrics.getInstance().registerMetric(activeStatementMetric); - } - - private EMRServerlessClient createEMRServerlessClient(String region) { - return AccessController.doPrivileged( - (PrivilegedAction) - () -> { - AWSEMRServerless awsemrServerless = - AWSEMRServerlessClientBuilder.standard() - .withRegion(region) - .withCredentials(new DefaultAWSCredentialsProviderChain()) - .build(); - return new EmrServerlessClientImpl(awsemrServerless); - }); - } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java index 232a280db5..5cf78cd130 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java @@ -10,9 +10,18 @@ import org.opensearch.action.ActionType; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.inject.Injector; +import org.opensearch.common.inject.ModulesBuilder; import org.opensearch.core.action.ActionListener; -import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; +import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasources.service.DataSourceServiceImpl; +import org.opensearch.sql.opensearch.security.SecurityAccess; +import org.opensearch.sql.opensearch.setting.OpenSearchSettings; +import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService; +import org.opensearch.sql.spark.transport.config.AsyncExecutorServiceModule; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionRequest; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse; import org.opensearch.tasks.Task; @@ -22,7 +31,7 @@ public class TransportCancelAsyncQueryRequestAction extends HandledTransportAction { public static final String NAME = "cluster:admin/opensearch/ql/async_query/delete"; - private final AsyncQueryExecutorServiceImpl asyncQueryExecutorService; + private final Injector injector; public static final ActionType ACTION_TYPE = new ActionType<>(NAME, CancelAsyncQueryActionResponse::new); @@ -30,9 +39,20 @@ public class TransportCancelAsyncQueryRequestAction public TransportCancelAsyncQueryRequestAction( TransportService transportService, ActionFilters actionFilters, - AsyncQueryExecutorServiceImpl asyncQueryExecutorService) { + NodeClient client, + ClusterService clusterService, + DataSourceServiceImpl dataSourceService) { super(NAME, transportService, actionFilters, CancelAsyncQueryActionRequest::new); - this.asyncQueryExecutorService = asyncQueryExecutorService; + ModulesBuilder modules = new ModulesBuilder(); + modules.add( + b -> { + b.bind(NodeClient.class).toInstance(client); + b.bind(org.opensearch.sql.common.setting.Settings.class) + .toInstance(new OpenSearchSettings(clusterService.getClusterSettings())); + b.bind(DataSourceService.class).toInstance(dataSourceService); + }); + modules.add(new AsyncExecutorServiceModule()); + this.injector = modules.createInjector(); } @Override @@ -41,6 +61,8 @@ protected void doExecute( CancelAsyncQueryActionRequest request, ActionListener listener) { try { + AsyncQueryExecutorService asyncQueryExecutorService = + SecurityAccess.doPrivileged(() -> injector.getInstance(AsyncQueryExecutorService.class)); String jobId = asyncQueryExecutorService.cancelQuery(request.getQueryId()); listener.onResponse( new CancelAsyncQueryActionResponse( diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestAction.java index 991eafdad9..0fc2e6895f 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestAction.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestAction.java @@ -7,16 +7,27 @@ package org.opensearch.sql.spark.transport; +import java.util.ArrayList; +import java.util.List; import org.opensearch.action.ActionType; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Guice; import org.opensearch.common.inject.Inject; +import org.opensearch.common.inject.Injector; +import org.opensearch.common.inject.Module; import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasources.service.DataSourceServiceImpl; +import org.opensearch.sql.opensearch.security.SecurityAccess; +import org.opensearch.sql.opensearch.setting.OpenSearchSettings; import org.opensearch.sql.protocol.response.format.JsonResponseFormatter; import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService; -import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; +import org.opensearch.sql.spark.transport.config.AsyncExecutorServiceModule; import org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionRequest; import org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionResponse; import org.opensearch.tasks.Task; @@ -24,8 +35,7 @@ public class TransportCreateAsyncQueryRequestAction extends HandledTransportAction { - - private final AsyncQueryExecutorService asyncQueryExecutorService; + private final Injector injector; public static final String NAME = "cluster:admin/opensearch/ql/async_query/create"; public static final ActionType ACTION_TYPE = @@ -35,9 +45,21 @@ public class TransportCreateAsyncQueryRequestAction public TransportCreateAsyncQueryRequestAction( TransportService transportService, ActionFilters actionFilters, - AsyncQueryExecutorServiceImpl jobManagementService) { + NodeClient client, + ClusterService clusterService, + DataSourceServiceImpl dataSourceService) { super(NAME, transportService, actionFilters, CreateAsyncQueryActionRequest::new); - this.asyncQueryExecutorService = jobManagementService; + List moduleList = new ArrayList<>(); + moduleList.add(new AsyncExecutorServiceModule()); + moduleList.add( + b -> { + b.bind(NodeClient.class).toInstance(client); + b.bind(org.opensearch.sql.common.setting.Settings.class) + .toInstance(new OpenSearchSettings(clusterService.getClusterSettings())); + b.bind(DataSourceService.class).toInstance(dataSourceService); + b.bind(ClusterService.class).toInstance(clusterService); + }); + this.injector = Guice.createInjector(moduleList); } @Override @@ -47,6 +69,8 @@ protected void doExecute( ActionListener listener) { try { CreateAsyncQueryRequest createAsyncQueryRequest = request.getCreateAsyncQueryRequest(); + AsyncQueryExecutorService asyncQueryExecutorService = + SecurityAccess.doPrivileged(() -> injector.getInstance(AsyncQueryExecutorService.class)); CreateAsyncQueryResponse createAsyncQueryResponse = asyncQueryExecutorService.createAsyncQuery(createAsyncQueryRequest); String responseContent = diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java index 5c784cf04c..4db710728d 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java @@ -10,15 +10,23 @@ import org.opensearch.action.ActionType; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.inject.Injector; +import org.opensearch.common.inject.ModulesBuilder; import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasources.service.DataSourceServiceImpl; import org.opensearch.sql.executor.pagination.Cursor; +import org.opensearch.sql.opensearch.security.SecurityAccess; +import org.opensearch.sql.opensearch.setting.OpenSearchSettings; import org.opensearch.sql.protocol.response.format.JsonResponseFormatter; import org.opensearch.sql.protocol.response.format.ResponseFormatter; import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService; -import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryResult; +import org.opensearch.sql.spark.transport.config.AsyncExecutorServiceModule; import org.opensearch.sql.spark.transport.format.AsyncQueryResultResponseFormatter; import org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionRequest; import org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionResponse; @@ -29,7 +37,7 @@ public class TransportGetAsyncQueryResultAction extends HandledTransportAction< GetAsyncQueryResultActionRequest, GetAsyncQueryResultActionResponse> { - private final AsyncQueryExecutorService asyncQueryExecutorService; + private final Injector injector; public static final String NAME = "cluster:admin/opensearch/ql/async_query/result"; public static final ActionType ACTION_TYPE = @@ -39,9 +47,20 @@ public class TransportGetAsyncQueryResultAction public TransportGetAsyncQueryResultAction( TransportService transportService, ActionFilters actionFilters, - AsyncQueryExecutorServiceImpl jobManagementService) { + NodeClient client, + ClusterService clusterService, + DataSourceServiceImpl dataSourceService) { super(NAME, transportService, actionFilters, GetAsyncQueryResultActionRequest::new); - this.asyncQueryExecutorService = jobManagementService; + ModulesBuilder modules = new ModulesBuilder(); + modules.add( + b -> { + b.bind(NodeClient.class).toInstance(client); + b.bind(org.opensearch.sql.common.setting.Settings.class) + .toInstance(new OpenSearchSettings(clusterService.getClusterSettings())); + b.bind(DataSourceService.class).toInstance(dataSourceService); + }); + modules.add(new AsyncExecutorServiceModule()); + this.injector = modules.createInjector(); } @Override @@ -51,6 +70,8 @@ protected void doExecute( ActionListener listener) { try { String jobId = request.getQueryId(); + AsyncQueryExecutorService asyncQueryExecutorService = + SecurityAccess.doPrivileged(() -> injector.getInstance(AsyncQueryExecutorService.class)); AsyncQueryExecutionResponse asyncQueryExecutionResponse = asyncQueryExecutorService.getAsyncQueryResults(jobId); ResponseFormatter formatter = diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java b/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java new file mode 100644 index 0000000000..f57255f8b3 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java @@ -0,0 +1,175 @@ +package org.opensearch.sql.spark.transport.config; + +import static org.opensearch.sql.spark.execution.statestore.StateStore.ALL_DATASOURCE; + +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; +import com.amazonaws.services.emrserverless.AWSEMRServerless; +import com.amazonaws.services.emrserverless.AWSEMRServerlessClientBuilder; +import java.security.AccessController; +import java.security.PrivilegedAction; +import lombok.RequiredArgsConstructor; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.AbstractModule; +import org.opensearch.common.inject.Provides; +import org.opensearch.common.inject.Singleton; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl; +import org.opensearch.sql.legacy.metrics.GaugeMetric; +import org.opensearch.sql.legacy.metrics.Metrics; +import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService; +import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; +import org.opensearch.sql.spark.asyncquery.AsyncQueryJobMetadataStorageService; +import org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryJobMetadataStorageService; +import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.EmrServerlessClientImpl; +import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; +import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; +import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplierImpl; +import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; +import org.opensearch.sql.spark.execution.session.SessionManager; +import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.flint.FlintIndexMetadataReaderImpl; +import org.opensearch.sql.spark.leasemanager.DefaultLeaseManager; +import org.opensearch.sql.spark.response.JobExecutionResponseReader; +import org.opensearch.sql.spark.rest.RestAsyncQueryManagementAction; + +@RequiredArgsConstructor +public class AsyncExecutorServiceModule extends AbstractModule { + + private static final Logger LOG = LogManager.getLogger(AsyncExecutorServiceModule.class); + + @Override + protected void configure() { + } + + + @Provides + public AsyncQueryExecutorService asyncQueryExecutorService(AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService, + SparkQueryDispatcher sparkQueryDispatcher, + SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier) { + LOG.error("asyncQueryExecutorService"); + return new AsyncQueryExecutorServiceImpl( + asyncQueryJobMetadataStorageService, + sparkQueryDispatcher, + sparkExecutionEngineConfigSupplier); + } + + + @Provides + public AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService(StateStore stateStore) { + LOG.error("asyncQueryJobMetadataStorageService"); + return new OpensearchAsyncQueryJobMetadataStorageService(stateStore); + } + + @Provides + public StateStore stateStore(NodeClient client, ClusterService clusterService) { + StateStore stateStore = new StateStore(client, clusterService); + registerStateStoreMetrics(stateStore); + LOG.error("stateStore"); + return stateStore; + } + + @Provides + public SparkQueryDispatcher sparkQueryDispatcher(EMRServerlessClient emrServerlessClient, + DataSourceService dataSourceService, + DataSourceUserAuthorizationHelperImpl dataSourceUserAuthorizationHelper, + JobExecutionResponseReader jobExecutionResponseReader, + FlintIndexMetadataReaderImpl flintIndexMetadataReader, + NodeClient client, + SessionManager sessionManager, + DefaultLeaseManager defaultLeaseManager, + StateStore stateStore + ) { + LOG.error("SparkQueryDispatcher"); + return new SparkQueryDispatcher( + emrServerlessClient, + dataSourceService, + dataSourceUserAuthorizationHelper, + jobExecutionResponseReader, + flintIndexMetadataReader, + client, + sessionManager, + defaultLeaseManager, + stateStore); + } + + @Provides + public SessionManager sessionManager(StateStore stateStore, + EMRServerlessClient emrServerlessClient, + Settings settings) { + LOG.error("sessionManager"); + return new SessionManager(stateStore, emrServerlessClient, settings); + } + + @Provides + public DefaultLeaseManager defaultLeaseManager(Settings settings, StateStore stateStore){ + LOG.error("defaultLeaseManager"); + return new DefaultLeaseManager(settings, stateStore); + } + + @Provides + public EMRServerlessClient createEMRServerlessClient(SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier) { + LOG.error("createEMRServerlessClient"); + SparkExecutionEngineConfig sparkExecutionEngineConfig = + sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(); + if(sparkExecutionEngineConfig.getRegion() != null) { + return AccessController.doPrivileged( + (PrivilegedAction) + () -> { + AWSEMRServerless awsemrServerless = + AWSEMRServerlessClientBuilder.standard() + .withRegion(sparkExecutionEngineConfig.getRegion()) + .withCredentials(new DefaultAWSCredentialsProviderChain()) + .build(); + return new EmrServerlessClientImpl(awsemrServerless); + }); + } else { + return null; + } + } + + @Provides + public SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier(Settings settings) { + LOG.error("sparkExecutionEngineConfigSupplier"); + return new SparkExecutionEngineConfigSupplierImpl(settings); + } + + @Provides + @Singleton + public FlintIndexMetadataReaderImpl flintIndexMetadataReader(NodeClient client) { + LOG.error("flintIndexMetadataReader"); + return new FlintIndexMetadataReaderImpl(client); + } + + @Provides + public JobExecutionResponseReader jobExecutionResponseReader(NodeClient client) { + LOG.error("jobExecutionResponseReader"); + return new JobExecutionResponseReader(client); + } + + @Provides + public DataSourceUserAuthorizationHelperImpl dataSourceUserAuthorizationHelper(NodeClient client) { + LOG.error("dataSourceUserAuthorizationHelper"); + return new DataSourceUserAuthorizationHelperImpl(client); + } + + + private void registerStateStoreMetrics(StateStore stateStore) { + GaugeMetric activeSessionMetric = + new GaugeMetric<>( + "active_async_query_sessions_count", + StateStore.activeSessionsCount(stateStore, ALL_DATASOURCE)); + GaugeMetric activeStatementMetric = + new GaugeMetric<>( + "active_async_query_statements_count", + StateStore.activeStatementsCount(stateStore, ALL_DATASOURCE)); + Metrics.getInstance().registerMetric(activeSessionMetric); + Metrics.getInstance().registerMetric(activeStatementMetric); + } + + +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java index 2ff76b9b57..6c79a27cd2 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java @@ -22,7 +22,10 @@ import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.datasources.service.DataSourceServiceImpl; import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionRequest; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse; @@ -43,12 +46,16 @@ public class TransportCancelAsyncQueryRequestActionTest { private ArgumentCaptor deleteJobActionResponseArgumentCaptor; @Captor private ArgumentCaptor exceptionArgumentCaptor; + @Mock private NodeClient client; + @Mock private ClusterService clusterService; + @Mock private DataSourceServiceImpl dataSourceService; + @BeforeEach public void setUp() { action = new TransportCancelAsyncQueryRequestAction( - transportService, new ActionFilters(new HashSet<>()), asyncQueryExecutorService); + transportService, new ActionFilters(new HashSet<>()), client, clusterService, dataSourceService); } @Test diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java index 36060d3850..eb26beef91 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java @@ -24,7 +24,11 @@ import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Injector; import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.datasources.service.DataSourceServiceImpl; import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; @@ -42,7 +46,9 @@ public class TransportCreateAsyncQueryRequestActionTest { @Mock private AsyncQueryExecutorServiceImpl jobExecutorService; @Mock private Task task; @Mock private ActionListener actionListener; - + @Mock private NodeClient client; + @Mock private ClusterService clusterService; + @Mock private DataSourceServiceImpl dataSourceService; @Captor private ArgumentCaptor createJobActionResponseArgumentCaptor; @@ -52,7 +58,7 @@ public class TransportCreateAsyncQueryRequestActionTest { public void setUp() { action = new TransportCreateAsyncQueryRequestAction( - transportService, new ActionFilters(new HashSet<>()), jobExecutorService); + transportService, new ActionFilters(new HashSet<>()), client, clusterService, dataSourceService); } @Test diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java index 34f10b0083..82ee292945 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java @@ -28,7 +28,10 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.datasources.service.DataSourceServiceImpl; import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException; @@ -51,12 +54,16 @@ public class TransportGetAsyncQueryResultActionTest { private ArgumentCaptor createJobActionResponseArgumentCaptor; @Captor private ArgumentCaptor exceptionArgumentCaptor; + @Mock private NodeClient client; + @Mock private ClusterService clusterService; + @Mock private DataSourceServiceImpl dataSourceService; + @BeforeEach public void setUp() { action = new TransportGetAsyncQueryResultAction( - transportService, new ActionFilters(new HashSet<>()), jobExecutorService); + transportService, new ActionFilters(new HashSet<>()), client, clusterService, dataSourceService); } @Test